Mirror of https://github.com/Significant-Gravitas/AutoGPT.git
Synced 2026-01-12 16:48:06 -05:00

Compare commits: toggle-cor...fix/sql-in
211 Commits
Commits (SHA1):

224411abd3, 6b241af79e, 320fb7d83a, 54552248f7, d8a5780ea2, 377657f8a1, ff71c940c9, 9967b3a7ce,
9db443960a, 9316100864, cbe0cee0fc, 7cbb1ed859, e06e7ff33f, acb946801b, 48ff225837, e2a9923f30,
39792d517e, a6a2f71458, 788b861bb7, e203e65dc4, bd03697ff2, efd37b7a36, bb0b45d7f7, 04df981115,
d25997b4f2, 11d55f6055, 063dc5cf65, b7646f3e58, 0befaf0a47, 93f58dec5e, 3da595f599, e5e60921a3,
90af8f8e1a, eba67e0a4b, 47bb89caeb, 271a520afa, 3988057032, a6c6e48f00, e72ce2f9e7, bd7a79a920,
3f546ae845, 097a19141d, c958c95d6b, 3e50cbd2cb, 1b69f1644d, d9035a233c, 972cbfc3de, 8f861b1bb2,
fa2731bb8b, 2dc0c97a52, 0bb2b87c32, a1d9b45238, 29895c290f, 73c0b6899a, 4c853a54d7, dfdd632161,
1ed224d481, 3b5d919399, 3c16de22ef, e4bc728d40, 2c6d85d15e, 8258338caf, 374f35874c, e62a56e8ba,
f3f9a60157, 3ed1c93ec0, 773f545cfd, 84ad4a9f95, 8610118ddc, 9469b9e2eb, ebb4ebb025, cb532e1c4d,
b7ae2c2fd2, 794aee25ab, 8b995c2394, 12b1067017, ba53cb78dc, f9778cc87e, b230b1b5cf, 1925e77733,
9bc9b53b99, adfa75eca8, 0f19d01483, 112c39f6a6, 22946f4617, 938834ac83, 934cb3a9c7, 7b8499ec69,
63076a67e1, 41260a7b4a, 5f2d4643f8, 9c8652b273, 58ef687a54, c7dcbc64ec, 99ac206272, f67d78df3e,
e32c509ccc, 20acd8b51d, a49c957467, cf6e724e99, b67555391d, 05a72f4185, 36f634c417, 18e169aa51,
c5b90f7b09, a446c1acc9, 59d242f69c, a2cd5d9c1f, df5b348676, 4856bd1f3a, 2e1d3dd185, ff72343035,
7982c34450, 59c27fe248, c7575dc579, 73603a8ce5, e562ca37aa, f906fd9298, 9e79add436, de6f4fca23,
fb4b8ed9fc, f3900127d7, 7c47f54e25, 927042d93e, 4244979a45, aa27365e7f, b86aa8b14e, e7ab2626f5,
ff58ce174b, 2d8ab6b7c0, a7306970b8, c42f94ce2a, 4e1557e498, 7f8cf36ceb, 0978566089, 8b4eb6f87c,
4b7d17b9d2, 0fc6a44389, f5ee579ab2, 57a06f7088, 258bf0b1a5, 4a1cb6d64b, 7c9db7419a, 18bbd8e572,
047f011520, d11917eb10, 4663066e65, 48a0faa611, 70d00b4104, aad0434cb2, f33ec1f2ec, e68b873bcf,
4530e97e59, 477c261488, 8ac2228e1e, 91dd9364bb, f314fbf14f, a97ff641c3, 114f604d7b, 3abea1ed96,
da6e1ad26d, 634fffb967, f3ec426c82, 0b267f573e, 7bd571d9ce, 7a331651ba, 5bc69adc33, f4bcc8494f,
4c000086e6, 9c6cc5b29d, b34973ca47, 2bc6a56877, 87c773d03a, ebeefc96e8, 83fe8d5b94, 50689218ed,
ddff09a8e4, 0c363a1cea, e5d870a348, 3f19cba28f, a978e91271, f283e6c514, 9fc2101e7e, 634f826d82,
6d6bf308fc, dd84fb5c66, 33679f3ffe, fc8c5ccbb6, 7d2ab61546, c2f11dbcfa, f82adeb959, 6f08a1cca7,
1ddf92eed4, 4c0dd27157, 17fcf68f2e, 381558342a, 1fdc02467b, f262bb9307, 5a6978b07d, 339ec733cb,
6575b655f0, 7c2df24d7c, 23eafa178c, 27fccdbf31, fb8fbc9d1f, 6a86e70fd6, 6a2d7e0fb0, 3d6ea3088e,
64b4480b1e, f490b01abb, e56a4a135d
94 changes: .github/copilot-instructions.md (vendored)
@@ -12,6 +12,7 @@ This file provides comprehensive onboarding information for GitHub Copilot codin
- **Infrastructure** - Docker configurations, CI/CD, and development tools

**Primary Languages & Frameworks:**

- **Backend**: Python 3.10-3.13, FastAPI, Prisma ORM, PostgreSQL, RabbitMQ
- **Frontend**: TypeScript, Next.js 15, React, Tailwind CSS, Radix UI
- **Development**: Docker, Poetry, pnpm, Playwright, Storybook

@@ -23,15 +24,17 @@ This file provides comprehensive onboarding information for GitHub Copilot codin
**Always run these commands in the correct directory and in this order:**

1. **Initial Setup** (required once):

   ```bash
   # Clone and enter repository
   git clone <repo> && cd AutoGPT

   # Start all services (database, redis, rabbitmq, clamav)
   cd autogpt_platform && docker compose --profile local up deps --build --detach
   ```

2. **Backend Setup** (always run before backend development):

   ```bash
   cd autogpt_platform/backend
   poetry install # Install dependencies

@@ -48,6 +51,7 @@ This file provides comprehensive onboarding information for GitHub Copilot codin
### Runtime Requirements

**Critical:** Always ensure Docker services are running before starting development:

```bash
cd autogpt_platform && docker compose --profile local up deps --build --detach
```

@@ -58,6 +62,7 @@ cd autogpt_platform && docker compose --profile local up deps --build --detach
### Development Commands

**Backend Development:**

```bash
cd autogpt_platform/backend
poetry run serve # Start development server (port 8000)
@@ -68,6 +73,7 @@ poetry run lint # Lint code (ruff) - run after format
```

**Frontend Development:**

```bash
cd autogpt_platform/frontend
pnpm dev # Start development server (port 3000) - use for active development
@@ -81,23 +87,27 @@ pnpm storybook # Start component development server
### Testing Strategy

**Backend Tests:**

- **Block Tests**: `poetry run pytest backend/blocks/test/test_block.py -xvs` (validates all blocks)
- **Specific Block**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[BlockName]' -xvs`
- **Snapshot Tests**: Use `--snapshot-update` when output changes, always review with `git diff`

**Frontend Tests:**

- **E2E Tests**: Always run `pnpm dev` before `pnpm test` (Playwright requires running instance)
- **Component Tests**: Use Storybook for isolated component development

### Critical Validation Steps

**Before committing changes:**

1. Run `poetry run format` (backend) and `pnpm format` (frontend)
2. Ensure all tests pass in modified areas
3. Verify Docker services are still running
4. Check that database migrations apply cleanly

**Common Issues & Workarounds:**

- **Prisma issues**: Run `poetry run prisma generate` after schema changes
- **Permission errors**: Ensure Docker has proper permissions
- **Port conflicts**: Check the `docker-compose.yml` file for the current list of exposed ports. You can list all mapped ports with:
@@ -108,6 +118,7 @@ pnpm storybook # Start component development server
### Core Architecture

**AutoGPT Platform** (`autogpt_platform/`):

- `backend/` - FastAPI server with async support
  - `backend/backend/` - Core API logic
  - `backend/blocks/` - Agent execution blocks
@@ -121,6 +132,7 @@ pnpm storybook # Start component development server
- `docker-compose.yml` - Development stack orchestration

**Key Configuration Files:**

- `pyproject.toml` - Python dependencies and tooling
- `package.json` - Node.js dependencies and scripts
- `schema.prisma` - Database schema and migrations
@@ -136,6 +148,7 @@ pnpm storybook # Start component development server
### Development Workflow

**GitHub Actions**: Multiple CI/CD workflows in `.github/workflows/`

- `platform-backend-ci.yml` - Backend testing and validation
- `platform-frontend-ci.yml` - Frontend testing and validation
- `platform-fullstack-ci.yml` - End-to-end integration tests
@@ -146,11 +159,13 @@ pnpm storybook # Start component development server
### Key Source Files

**Backend Entry Points:**

- `backend/backend/server/server.py` - FastAPI application setup
- `backend/backend/data/` - Database models and user management
- `backend/blocks/` - Agent execution blocks and logic

**Frontend Entry Points:**

- `frontend/src/app/layout.tsx` - Root application layout
- `frontend/src/app/page.tsx` - Home page
- `frontend/src/lib/supabase/` - Authentication and database client
@@ -160,6 +175,7 @@ pnpm storybook # Start component development server
### Agent Block System

Agents are built using a visual block-based system where each block performs a single action. Blocks are defined in `backend/blocks/` and must include:

- Block definition with input/output schemas
- Execution logic with proper error handling
- Tests validating functionality
@@ -167,6 +183,7 @@ Agents are built using a visual block-based system where each block performs a s
### Database & ORM

**Prisma ORM** with PostgreSQL backend including pgvector for embeddings:

- Schema in `schema.prisma`
- Migrations in `backend/migrations/`
- Always run `prisma migrate dev` and `prisma generate` after schema changes
@@ -174,13 +191,15 @@ Agents are built using a visual block-based system where each block performs a s
## Environment Configuration

### Configuration Files Priority Order

1. **Backend**: `/backend/.env.default` → `/backend/.env` (user overrides)
2. **Frontend**: `/frontend/.env.default` → `/frontend/.env` (user overrides)
3. **Platform**: `/.env.default` (Supabase/shared) → `/.env` (user overrides)
4. Docker Compose `environment:` sections override file-based config
5. Shell environment variables have highest precedence
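A minimal sketch of this layering for the backend, assuming the `python-dotenv` package; the dict-merge order mirrors the precedence list above (later sources win):

```python
# Hypothetical illustration of the precedence chain; not the repo's loader.
import os

from dotenv import dotenv_values  # assumes python-dotenv is installed

config = {
    **dotenv_values("backend/.env.default"),  # shipped defaults (lowest)
    **dotenv_values("backend/.env"),          # user overrides
    **os.environ,                             # shell environment wins
}
```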
### Docker Environment Setup

- All services use hardcoded defaults (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
@@ -189,6 +208,7 @@ Agents are built using a visual block-based system where each block performs a s
## Advanced Development Patterns

### Adding New Blocks

1. Create file in `/backend/backend/blocks/`
2. Inherit from `Block` base class with input/output schemas
3. Implement `run` method with proper error handling
@@ -198,6 +218,7 @@ Agents are built using a visual block-based system where each block performs a s
7. Consider how inputs/outputs connect with other blocks in graph editor
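In outline, steps 1-3 produce something like the following; the class and field names (`EchoBlock`, `text`) are illustrative, and the exact `Block`/schema base-class APIs come from the platform SDK, so treat this as a hedged sketch rather than the canonical interface:

```python
# Illustrative only: a minimal block following steps 1-3 above.
# Import paths and signatures are assumptions, not code from this diff.
from backend.data.block import Block, BlockOutput, BlockSchema  # assumed path
from backend.data.model import SchemaField  # assumed path


class EchoBlock(Block):  # hypothetical example block
    class Input(BlockSchema):
        text: str = SchemaField(description="Text to echo back")

    class Output(BlockSchema):
        result: str = SchemaField(description="The echoed text")

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # Step 3: execution logic with proper error handling.
        if not input_data.text:
            raise ValueError("text must not be empty")
        yield "result", input_data.text
```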
### API Development

1. Update routes in `/backend/backend/server/routers/`
2. Add/update Pydantic models in same directory
3. Write tests alongside route files
@@ -205,21 +226,76 @@ Agents are built using a visual block-based system where each block performs a s
5. Run `poetry run test` to verify changes

### Frontend Development

1. Components in `/frontend/src/components/`
2. Use existing UI components from `/frontend/src/components/ui/`
3. Add Storybook stories for component development
4. Test user-facing features with Playwright E2E tests
5. Update protected routes in middleware when needed

**📖 Complete Frontend Guide**: See `autogpt_platform/frontend/CONTRIBUTING.md` and `autogpt_platform/frontend/.cursorrules` for comprehensive patterns and conventions.

**Quick Reference:**

**Component Structure:**

- Separate render logic from data/behavior
- Structure: `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
- Exception: Small components (3-4 lines of logic) can be inline
- Render-only components can be direct files without folders

**Data Fetching:**

- Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Generated via Orval from backend OpenAPI spec
- Pattern: `use{Method}{Version}{OperationName}`
- Example: `useGetV2ListLibraryAgents`
- Regenerate with: `pnpm generate:api`
- **Never** use deprecated `BackendAPI` or `src/lib/autogpt-server-api/*`

**Code Conventions:**

- Use function declarations for components and handlers (not arrow functions)
- Only arrow functions for small inline lambdas (map, filter, etc.)
- Components: `PascalCase`, Hooks: `camelCase` with `use` prefix
- No barrel files or `index.ts` re-exports
- Minimal comments (code should be self-documenting)

**Styling:**

- Use Tailwind CSS utilities only
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Never use `src/components/__legacy__/*`
- Only use Phosphor Icons (`@phosphor-icons/react`)
- Prefer design tokens over hardcoded values

**Error Handling:**

- Render errors: Use `<ErrorCard />` component
- Mutation errors: Display with toast notifications
- Manual exceptions: Use `Sentry.captureException()`
- Global error boundaries already configured

**Testing:**

- Add/update Storybook stories for UI components (`pnpm storybook`)
- Run Playwright E2E tests with `pnpm test`
- Verify in Chromatic after PR

**Architecture:**

- Default to client components ("use client")
- Server components only for SEO or extreme TTFB needs
- Use React Query for server state (via generated hooks)
- Co-locate UI state in components/hooks

### Security Guidelines

**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`):

- Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses allow list approach for cacheable paths (static assets, health checks, public pages)
- Prevents sensitive data caching in browsers/proxies
- Add new cacheable endpoints to `CACHEABLE_PATHS`
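A hedged sketch of the allow-list pattern this middleware implements (not the repo's exact code; the paths in `CACHEABLE_PATHS` here are placeholders):

```python
# Illustrative sketch of the cache-protection pattern described above.
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

CACHEABLE_PATHS = {"/health", "/static"}  # placeholder allow list


class CacheControlMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        # Default-deny: only allow-listed path prefixes may be cached.
        if not any(request.url.path.startswith(p) for p in CACHEABLE_PATHS):
            response.headers["Cache-Control"] = (
                "no-store, no-cache, must-revalidate, private"
            )
        return response
```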
### CI/CD Alignment

The repository has comprehensive CI workflows that test:

- **Backend**: Python 3.11-3.13, services (Redis/RabbitMQ/ClamAV), Prisma migrations, Poetry lock validation
- **Frontend**: Node.js 21, pnpm, Playwright with Docker Compose stack, API schema validation
- **Integration**: Full-stack type checking and E2E testing
@@ -229,6 +305,7 @@ Match these patterns when developing locally - the copilot setup environment mir
## Collaboration with Other AI Assistants

This repository is actively developed with assistance from Claude (via CLAUDE.md files). When working on this codebase:

- Check for existing CLAUDE.md files that provide additional context
- Follow established patterns and conventions already in the codebase
- Maintain consistency with existing code style and architecture
@@ -237,8 +314,9 @@ This repository is actively developed with assistance from Claude (via CLAUDE.md
## Trust These Instructions

These instructions are comprehensive and tested. Only perform additional searches if:

1. Information here is incomplete for your specific task
2. You encounter errors not covered by the workarounds
3. You need to understand implementation details not covered above

For detailed platform development patterns, refer to `autogpt_platform/CLAUDE.md` and `AGENTS.md` in the repository root.
@@ -5,6 +5,13 @@ on:
    branches: [ dev ]
    paths:
      - 'autogpt_platform/**'
+  workflow_dispatch:
+    inputs:
+      git_ref:
+        description: 'Git ref (branch/tag) of AutoGPT to deploy'
+        required: true
+        default: 'master'
+        type: string

permissions:
  contents: 'read'

@@ -19,6 +26,8 @@ jobs:
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
+        with:
+          ref: ${{ github.event.inputs.git_ref || github.ref_name }}

      - name: Set up Python
        uses: actions/setup-python@v5

@@ -48,4 +57,4 @@ jobs:
          token: ${{ secrets.DEPLOY_TOKEN }}
          repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
          event-type: build_deploy_dev
-          client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "repository": "${{ github.repository }}"}'
+          client-payload: '{"ref": "${{ github.event.inputs.git_ref || github.ref }}", "repository": "${{ github.repository }}"}'

@@ -3,6 +3,7 @@ name: AutoGPT Platform - Deploy Prod Environment
on:
  release:
    types: [published]
+  workflow_dispatch:

permissions:
  contents: 'read'

@@ -17,6 +18,8 @@ jobs:
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
+        with:
+          ref: ${{ github.ref_name || 'master' }}

      - name: Set up Python
        uses: actions/setup-python@v5

@@ -36,7 +39,7 @@ jobs:
          DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
          DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}

  trigger:
    needs: migrate
    runs-on: ubuntu-latest

@@ -47,4 +50,5 @@ jobs:
          token: ${{ secrets.DEPLOY_TOKEN }}
          repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
          event-type: build_deploy_prod
-          client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "repository": "${{ github.repository }}"}'
+          client-payload: |
+            {"ref": "${{ github.ref_name || 'master' }}", "repository": "${{ github.repository }}"}
5 changes: .github/workflows/platform-backend-ci.yml (vendored)
@@ -37,9 +37,7 @@ jobs:

    services:
      redis:
-        image: bitnami/redis:6.2
-        env:
-          REDIS_PASSWORD: testpassword
+        image: redis:latest
        ports:
          - 6379:6379
      rabbitmq:

@@ -204,7 +202,6 @@ jobs:
      JWT_VERIFY_KEY: ${{ steps.supabase.outputs.JWT_SECRET }}
      REDIS_HOST: "localhost"
      REDIS_PORT: "6379"
-      REDIS_PASSWORD: "testpassword"
      ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!

    env:

@@ -1,6 +1,3 @@
-[pr_reviewer]
-num_code_suggestions=0
-
[pr_code_suggestions]
commitable_code_suggestions=false
num_code_suggestions=0
@@ -63,6 +63,9 @@ poetry run pytest path/to/test.py --snapshot-update
# Install dependencies
cd frontend && pnpm i

+# Generate API client from OpenAPI spec
+pnpm generate:api
+
# Start development server
pnpm dev

@@ -75,12 +78,23 @@ pnpm storybook
# Build production
pnpm build

# Format and lint
pnpm format

# Type checking
pnpm types
```

-We have a components library in autogpt_platform/frontend/src/components/atoms that should be used when adding new pages and components.
+**📖 Complete Guide**: See `/frontend/CONTRIBUTING.md` and `/frontend/.cursorrules` for comprehensive frontend patterns.
+
+**Key Frontend Conventions:**
+
+- Separate render logic from data/behavior in components
+- Use generated API hooks from `@/app/api/__generated__/endpoints/`
+- Use function declarations (not arrow functions) for components/handlers
+- Use design system components from `src/components/` (atoms, molecules, organisms)
+- Only use Phosphor Icons
+- Never use `src/components/__legacy__/*` or deprecated `BackendAPI`

## Architecture Overview

@@ -95,11 +109,16 @@ We have a components library in autogpt_platform/frontend/src/components/atoms t

### Frontend Architecture

-- **Framework**: Next.js App Router with React Server Components
-- **State Management**: React hooks + Supabase client for real-time updates
+- **Framework**: Next.js 15 App Router (client-first approach)
+- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
+- **State Management**: React Query for server state, co-located UI state in components/hooks
+- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
- **Workflow Builder**: Visual graph editor using @xyflow/react
-- **UI Components**: Radix UI primitives with Tailwind CSS styling
+- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
+- **Icons**: Phosphor Icons only
+- **Feature Flags**: LaunchDarkly integration
+- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
- **Testing**: Playwright for E2E, Storybook for component development

### Key Concepts

@@ -153,6 +172,7 @@ Key models (defined in `/backend/schema.prisma`):
**Adding a new block:**

Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers:

- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
@@ -160,6 +180,7 @@ Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-
- File organization

Quick steps:

1. Create new file in `/backend/backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
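Step 2 typically looks something like this sketch; the fluent `ProviderBuilder` pattern is the one the Block SDK Guide describes, but the provider slug, env-var key, and import path here are illustrative assumptions:

```python
# _config.py - hedged sketch; names are placeholders, not repo code.
from backend.sdk import ProviderBuilder  # assumed import path

my_service = (
    ProviderBuilder("my_service")  # hypothetical provider slug
    .with_api_key("MY_SERVICE_API_KEY", "My Service API Key")  # env var + display name
    .build()
)
```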
@@ -180,10 +201,20 @@ ex: do the inputs and outputs tie well together?

**Frontend feature development:**

-1. Components go in `/frontend/src/components/`
-2. Use existing UI components from `/frontend/src/components/ui/`
-3. Add Storybook stories for new components
-4. Test with Playwright if user-facing
+See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
+
+1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
+   - Add `usePageName.ts` hook for logic
+   - Put sub-components in local `components/` folder
+2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
+   - Use design system components from `src/components/` (atoms, molecules, organisms)
+   - Never use `src/components/__legacy__/*`
+3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
+   - Regenerate with `pnpm generate:api`
+   - Pattern: `use{Method}{Version}{OperationName}`
+4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
+5. **Testing**: Add Storybook stories for new components, Playwright for E2E
+6. **Code conventions**: Function declarations (not arrow functions) for components/handlers

### Security Implementation
57 changes: autogpt_platform/Makefile (new file)
@@ -0,0 +1,57 @@
.PHONY: start-core stop-core logs-core format lint migrate run-backend run-frontend

# Run just Supabase + Redis + RabbitMQ
start-core:
	docker compose up -d deps

# Stop core services
stop-core:
	docker compose stop deps

reset-db:
	rm -rf db/docker/volumes/db/data
	cd backend && poetry run prisma migrate deploy
	cd backend && poetry run prisma generate

# View logs for core services
logs-core:
	docker compose logs -f deps

# Run formatting and linting for backend and frontend
format:
	cd backend && poetry run format
	cd frontend && pnpm format
	cd frontend && pnpm lint

init-env:
	cp -n .env.default .env || true
	cd backend && cp -n .env.default .env || true
	cd frontend && cp -n .env.default .env || true

# Run migrations for backend
migrate:
	cd backend && poetry run prisma migrate deploy
	cd backend && poetry run prisma generate

run-backend:
	cd backend && poetry run app

run-frontend:
	cd frontend && pnpm dev

test-data:
	cd backend && poetry run python test/test_data_creator.py

help:
	@echo "Usage: make <target>"
	@echo "Targets:"
	@echo " start-core - Start just the core services (Supabase, Redis, RabbitMQ) in background"
	@echo " stop-core - Stop the core services"
	@echo " reset-db - Reset the database by deleting the volume"
	@echo " logs-core - Tail the logs for core services"
	@echo " format - Format & lint backend (Python) and frontend (TypeScript) code"
	@echo " migrate - Run backend database migrations"
	@echo " run-backend - Run the backend FastAPI server"
	@echo " run-frontend - Run the frontend Next.js development server"
	@echo " test-data - Run the test data creator"
@@ -38,6 +38,37 @@ To run the AutoGPT Platform, follow these steps:

4. After all the services are in ready state, open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.

### Running Just Core services

You can now run the following to enable just the core services.

```
# For help
make help

# Run just Supabase + Redis + RabbitMQ
make start-core

# Stop core services
make stop-core

# View logs from core services
make logs-core

# Run formatting and linting for backend and frontend
make format

# Run migrations for backend database
make migrate

# Run backend server
make run-backend

# Run frontend development server
make run-frontend
```

### Docker Compose Commands

Here are some useful Docker Compose commands for managing your AutoGPT Platform:
@@ -10,7 +10,7 @@ from .jwt_utils import get_jwt_payload, verify_user
from .models import User


-def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
+async def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
    """
    FastAPI dependency that requires a valid authenticated user.

@@ -20,7 +20,9 @@ def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User
    return verify_user(jwt_payload, admin_only=False)


-def requires_admin_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
+async def requires_admin_user(
+    jwt_payload: dict = fastapi.Security(get_jwt_payload),
+) -> User:
    """
    FastAPI dependency that requires a valid admin user.

@@ -30,7 +32,7 @@ def requires_admin_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -
    return verify_user(jwt_payload, admin_only=True)


-def get_user_id(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> str:
+async def get_user_id(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> str:
    """
    FastAPI dependency that returns the ID of the authenticated user.
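Because these dependencies are now coroutines, FastAPI awaits them automatically when they are wired in with `Security`/`Depends`. A hedged usage sketch (the endpoint path and app instance are illustrative, not part of this diff):

```python
# Illustrative route using the async dependency above.
import fastapi

app = fastapi.FastAPI()


@app.get("/me")  # hypothetical endpoint
async def read_current_user(
    user: User = fastapi.Security(requires_user),
) -> dict:
    # FastAPI resolves (and awaits) `requires_user` before the route body
    # runs; an invalid or missing token raises HTTP 401 upstream.
    return {"user_id": user.user_id, "role": user.role}
```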
@@ -45,7 +45,7 @@ class TestAuthDependencies:
        """Create a test client."""
        return TestClient(app)

-    def test_requires_user_with_valid_jwt_payload(self, mocker: MockerFixture):
+    async def test_requires_user_with_valid_jwt_payload(self, mocker: MockerFixture):
        """Test requires_user with valid JWT payload."""
        jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}

@@ -53,12 +53,12 @@ class TestAuthDependencies:
        mocker.patch(
            "autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
        )
-        user = requires_user(jwt_payload)
+        user = await requires_user(jwt_payload)
        assert isinstance(user, User)
        assert user.user_id == "user-123"
        assert user.role == "user"

-    def test_requires_user_with_admin_jwt_payload(self, mocker: MockerFixture):
+    async def test_requires_user_with_admin_jwt_payload(self, mocker: MockerFixture):
        """Test requires_user accepts admin users."""
        jwt_payload = {
            "sub": "admin-456",
@@ -69,28 +69,28 @@ class TestAuthDependencies:
        mocker.patch(
            "autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
        )
-        user = requires_user(jwt_payload)
+        user = await requires_user(jwt_payload)
        assert user.user_id == "admin-456"
        assert user.role == "admin"

-    def test_requires_user_missing_sub(self):
+    async def test_requires_user_missing_sub(self):
        """Test requires_user with missing user ID."""
        jwt_payload = {"role": "user", "email": "user@example.com"}

        with pytest.raises(HTTPException) as exc_info:
-            requires_user(jwt_payload)
+            await requires_user(jwt_payload)
        assert exc_info.value.status_code == 401
        assert "User ID not found" in exc_info.value.detail

-    def test_requires_user_empty_sub(self):
+    async def test_requires_user_empty_sub(self):
        """Test requires_user with empty user ID."""
        jwt_payload = {"sub": "", "role": "user"}

        with pytest.raises(HTTPException) as exc_info:
-            requires_user(jwt_payload)
+            await requires_user(jwt_payload)
        assert exc_info.value.status_code == 401

-    def test_requires_admin_user_with_admin(self, mocker: MockerFixture):
+    async def test_requires_admin_user_with_admin(self, mocker: MockerFixture):
        """Test requires_admin_user with admin role."""
        jwt_payload = {
            "sub": "admin-789",
@@ -101,51 +101,51 @@ class TestAuthDependencies:
        mocker.patch(
            "autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
        )
-        user = requires_admin_user(jwt_payload)
+        user = await requires_admin_user(jwt_payload)
        assert user.user_id == "admin-789"
        assert user.role == "admin"

-    def test_requires_admin_user_with_regular_user(self):
+    async def test_requires_admin_user_with_regular_user(self):
        """Test requires_admin_user rejects regular users."""
        jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}

        with pytest.raises(HTTPException) as exc_info:
-            requires_admin_user(jwt_payload)
+            await requires_admin_user(jwt_payload)
        assert exc_info.value.status_code == 403
        assert "Admin access required" in exc_info.value.detail

-    def test_requires_admin_user_missing_role(self):
+    async def test_requires_admin_user_missing_role(self):
        """Test requires_admin_user with missing role."""
        jwt_payload = {"sub": "user-123", "email": "user@example.com"}

        with pytest.raises(KeyError):
-            requires_admin_user(jwt_payload)
+            await requires_admin_user(jwt_payload)

-    def test_get_user_id_with_valid_payload(self, mocker: MockerFixture):
+    async def test_get_user_id_with_valid_payload(self, mocker: MockerFixture):
        """Test get_user_id extracts user ID correctly."""
        jwt_payload = {"sub": "user-id-xyz", "role": "user"}

        mocker.patch(
            "autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
        )
-        user_id = get_user_id(jwt_payload)
+        user_id = await get_user_id(jwt_payload)
        assert user_id == "user-id-xyz"

-    def test_get_user_id_missing_sub(self):
+    async def test_get_user_id_missing_sub(self):
        """Test get_user_id with missing user ID."""
        jwt_payload = {"role": "user"}

        with pytest.raises(HTTPException) as exc_info:
-            get_user_id(jwt_payload)
+            await get_user_id(jwt_payload)
        assert exc_info.value.status_code == 401
        assert "User ID not found" in exc_info.value.detail

-    def test_get_user_id_none_sub(self):
+    async def test_get_user_id_none_sub(self):
        """Test get_user_id with None user ID."""
        jwt_payload = {"sub": None, "role": "user"}

        with pytest.raises(HTTPException) as exc_info:
-            get_user_id(jwt_payload)
+            await get_user_id(jwt_payload)
        assert exc_info.value.status_code == 401


@@ -170,7 +170,7 @@ class TestAuthDependenciesIntegration:

        return _create_token

-    def test_endpoint_auth_enabled_no_token(self):
+    async def test_endpoint_auth_enabled_no_token(self):
        """Test endpoints require token when auth is enabled."""
        app = FastAPI()

@@ -184,7 +184,7 @@ class TestAuthDependenciesIntegration:
        response = client.get("/test")
        assert response.status_code == 401

-    def test_endpoint_with_valid_token(self, create_token):
+    async def test_endpoint_with_valid_token(self, create_token):
        """Test endpoint with valid JWT token."""
        app = FastAPI()

@@ -203,7 +203,7 @@ class TestAuthDependenciesIntegration:
        assert response.status_code == 200
        assert response.json()["user_id"] == "test-user"

-    def test_admin_endpoint_requires_admin_role(self, create_token):
+    async def test_admin_endpoint_requires_admin_role(self, create_token):
        """Test admin endpoint rejects non-admin users."""
        app = FastAPI()

@@ -240,7 +240,7 @@ class TestAuthDependenciesIntegration:
class TestAuthDependenciesEdgeCases:
    """Edge case tests for authentication dependencies."""

-    def test_dependency_with_complex_payload(self):
+    async def test_dependency_with_complex_payload(self):
        """Test dependencies handle complex JWT payloads."""
        complex_payload = {
            "sub": "user-123",
@@ -256,14 +256,14 @@ class TestAuthDependenciesEdgeCases:
            "exp": 9999999999,
        }

-        user = requires_user(complex_payload)
+        user = await requires_user(complex_payload)
        assert user.user_id == "user-123"
        assert user.email == "test@example.com"

-        admin = requires_admin_user(complex_payload)
+        admin = await requires_admin_user(complex_payload)
        assert admin.role == "admin"

-    def test_dependency_with_unicode_in_payload(self):
+    async def test_dependency_with_unicode_in_payload(self):
        """Test dependencies handle unicode in JWT payloads."""
        unicode_payload = {
            "sub": "user-😀-123",
@@ -272,11 +272,11 @@ class TestAuthDependenciesEdgeCases:
            "name": "日本語",
        }

-        user = requires_user(unicode_payload)
+        user = await requires_user(unicode_payload)
        assert "😀" in user.user_id
        assert user.email == "测试@example.com"

-    def test_dependency_with_null_values(self):
+    async def test_dependency_with_null_values(self):
        """Test dependencies handle null values in payload."""
        null_payload = {
            "sub": "user-123",
@@ -286,18 +286,18 @@ class TestAuthDependenciesEdgeCases:
            "metadata": None,
        }

-        user = requires_user(null_payload)
+        user = await requires_user(null_payload)
        assert user.user_id == "user-123"
        assert user.email is None

-    def test_concurrent_requests_isolation(self):
+    async def test_concurrent_requests_isolation(self):
        """Test that concurrent requests don't interfere with each other."""
        payload1 = {"sub": "user-1", "role": "user"}
        payload2 = {"sub": "user-2", "role": "admin"}

        # Simulate concurrent processing
-        user1 = requires_user(payload1)
-        user2 = requires_admin_user(payload2)
+        user1 = await requires_user(payload1)
+        user2 = await requires_admin_user(payload2)

        assert user1.user_id == "user-1"
        assert user2.user_id == "user-2"
@@ -314,7 +314,7 @@ class TestAuthDependenciesEdgeCases:
            ({"sub": "user", "role": "user"}, "Admin access required", True),
        ],
    )
-    def test_dependency_error_cases(
+    async def test_dependency_error_cases(
        self, payload, expected_error: str, admin_only: bool
    ):
        """Test that errors propagate correctly through dependencies."""
@@ -325,7 +325,7 @@ class TestAuthDependenciesEdgeCases:
            verify_user(payload, admin_only=admin_only)
        assert expected_error in exc_info.value.detail

-    def test_dependency_valid_user(self):
+    async def test_dependency_valid_user(self):
        """Test valid user case for dependency."""
        # Import verify_user to test it directly since dependencies use FastAPI Security
        from autogpt_libs.auth.jwt_utils import verify_user
@@ -16,7 +16,7 @@ bearer_jwt_auth = HTTPBearer(
)


-def get_jwt_payload(
+async def get_jwt_payload(
    credentials: HTTPAuthorizationCredentials | None = Security(bearer_jwt_auth),
) -> dict[str, Any]:
    """

@@ -116,32 +116,32 @@ def test_parse_jwt_token_missing_audience():
    assert "Invalid token" in str(exc_info.value)


-def test_get_jwt_payload_with_valid_token():
+async def test_get_jwt_payload_with_valid_token():
    """Test extracting JWT payload with valid bearer token."""
    token = create_token(TEST_USER_PAYLOAD)
    credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)

-    result = jwt_utils.get_jwt_payload(credentials)
+    result = await jwt_utils.get_jwt_payload(credentials)
    assert result["sub"] == "test-user-id"
    assert result["role"] == "user"


-def test_get_jwt_payload_no_credentials():
+async def test_get_jwt_payload_no_credentials():
    """Test JWT payload when no credentials provided."""
    with pytest.raises(HTTPException) as exc_info:
-        jwt_utils.get_jwt_payload(None)
+        await jwt_utils.get_jwt_payload(None)
    assert exc_info.value.status_code == 401
    assert "Authorization header is missing" in exc_info.value.detail


-def test_get_jwt_payload_invalid_token():
+async def test_get_jwt_payload_invalid_token():
    """Test JWT payload extraction with invalid token."""
    credentials = HTTPAuthorizationCredentials(
        scheme="Bearer", credentials="invalid.token.here"
    )

    with pytest.raises(HTTPException) as exc_info:
-        jwt_utils.get_jwt_payload(credentials)
+        await jwt_utils.get_jwt_payload(credentials)
    assert exc_info.value.status_code == 401
    assert "Invalid token" in exc_info.value.detail
@@ -4,6 +4,7 @@ import logging
import os
import socket
import sys
+from logging.handlers import RotatingFileHandler
from pathlib import Path

from pydantic import Field, field_validator
@@ -93,42 +94,36 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
    config = LoggingConfig()
    log_handlers: list[logging.Handler] = []

+    structured_logging = config.enable_cloud_logging or force_cloud_logging
+
    # Console output handlers
-    stdout = logging.StreamHandler(stream=sys.stdout)
-    stdout.setLevel(config.level)
-    stdout.addFilter(BelowLevelFilter(logging.WARNING))
-    if config.level == logging.DEBUG:
-        stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
-    else:
-        stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
+    if not structured_logging:
+        stdout = logging.StreamHandler(stream=sys.stdout)
+        stdout.setLevel(config.level)
+        stdout.addFilter(BelowLevelFilter(logging.WARNING))
+        if config.level == logging.DEBUG:
+            stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
+        else:
+            stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))

-    stderr = logging.StreamHandler()
-    stderr.setLevel(logging.WARNING)
-    if config.level == logging.DEBUG:
-        stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
-    else:
-        stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
+        stderr = logging.StreamHandler()
+        stderr.setLevel(logging.WARNING)
+        if config.level == logging.DEBUG:
+            stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
+        else:
+            stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))

-    log_handlers += [stdout, stderr]
+        log_handlers += [stdout, stderr]

-    # Cloud logging setup
-    if config.enable_cloud_logging or force_cloud_logging:
-        import google.cloud.logging
-        from google.cloud.logging.handlers import CloudLoggingHandler
-        from google.cloud.logging_v2.handlers.transports import (
-            BackgroundThreadTransport,
-        )
+    else:
+        # Use Google Cloud Structured Log Handler. Log entries are printed to stdout
+        # in a JSON format which is automatically picked up by Google Cloud Logging.
+        from google.cloud.logging.handlers import StructuredLogHandler

-        client = google.cloud.logging.Client()
-        # Use BackgroundThreadTransport to prevent blocking the main thread
-        # and deadlocks when gRPC calls to Google Cloud Logging hang
-        cloud_handler = CloudLoggingHandler(
-            client,
-            name="autogpt_logs",
-            transport=BackgroundThreadTransport,
-        )
-        cloud_handler.setLevel(config.level)
-        log_handlers.append(cloud_handler)
+        structured_log_handler = StructuredLogHandler(stream=sys.stdout)
+        structured_log_handler.setLevel(config.level)
+        log_handlers.append(structured_log_handler)

    # File logging setup
    if config.enable_file_logging:
@@ -139,8 +134,13 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
        print(f"Log directory: {config.log_dir}")

        # Activity log handler (INFO and above)
-        activity_log_handler = logging.FileHandler(
-            config.log_dir / LOG_FILE, "a", "utf-8"
+        # Security fix: Use RotatingFileHandler with size limits to prevent disk exhaustion
+        activity_log_handler = RotatingFileHandler(
+            config.log_dir / LOG_FILE,
+            mode="a",
+            encoding="utf-8",
+            maxBytes=10 * 1024 * 1024,  # 10MB per file
+            backupCount=3,  # Keep 3 backup files (40MB total)
        )
        activity_log_handler.setLevel(config.level)
        activity_log_handler.setFormatter(
@@ -150,8 +150,13 @@ def configure_logging(force_cloud_logging: bool = False) -> None:

        if config.level == logging.DEBUG:
            # Debug log handler (all levels)
-            debug_log_handler = logging.FileHandler(
-                config.log_dir / DEBUG_LOG_FILE, "a", "utf-8"
+            # Security fix: Use RotatingFileHandler with size limits
+            debug_log_handler = RotatingFileHandler(
+                config.log_dir / DEBUG_LOG_FILE,
+                mode="a",
+                encoding="utf-8",
+                maxBytes=10 * 1024 * 1024,  # 10MB per file
+                backupCount=3,  # Keep 3 backup files (40MB total)
            )
            debug_log_handler.setLevel(logging.DEBUG)
            debug_log_handler.setFormatter(
@@ -160,8 +165,13 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
            log_handlers.append(debug_log_handler)

        # Error log handler (ERROR and above)
-        error_log_handler = logging.FileHandler(
-            config.log_dir / ERROR_LOG_FILE, "a", "utf-8"
+        # Security fix: Use RotatingFileHandler with size limits
+        error_log_handler = RotatingFileHandler(
+            config.log_dir / ERROR_LOG_FILE,
+            mode="a",
+            encoding="utf-8",
+            maxBytes=10 * 1024 * 1024,  # 10MB per file
+            backupCount=3,  # Keep 3 backup files (40MB total)
        )
        error_log_handler.setLevel(logging.ERROR)
        error_log_handler.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT, no_color=True))
@@ -169,7 +179,13 @@ def configure_logging(force_cloud_logging: bool = False) -> None:

    # Configure the root logger
    logging.basicConfig(
-        format=DEBUG_LOG_FORMAT if config.level == logging.DEBUG else SIMPLE_LOG_FORMAT,
+        format=(
+            "%(levelname)s %(message)s"
+            if structured_logging
+            else (
+                DEBUG_LOG_FORMAT if config.level == logging.DEBUG else SIMPLE_LOG_FORMAT
+            )
+        ),
        level=config.level,
        handlers=log_handlers,
    )
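A hedged usage sketch of the switch above; the import path for `configure_logging` is an assumption:

```python
# Hypothetical import path; configure_logging is the function patched above.
import logging

from backend.util.logging import configure_logging  # assumed module path

configure_logging(force_cloud_logging=True)
# With structured logging forced on, the console StreamHandlers are skipped
# and records go to StructuredLogHandler, which prints JSON to stdout for
# Google Cloud Logging to ingest.
logging.getLogger(__name__).info("hello from a worker")
```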
@@ -1,3 +1,5 @@
+from typing import Optional
+
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict

@@ -13,8 +15,8 @@ class RateLimitSettings(BaseSettings):
        default="6379", description="Redis port", validation_alias="REDIS_PORT"
    )

-    redis_password: str = Field(
-        default="password",
+    redis_password: Optional[str] = Field(
+        default=None,
        description="Redis password",
        validation_alias="REDIS_PASSWORD",
    )

@@ -11,7 +11,7 @@ class RateLimiter:
        self,
        redis_host: str = RATE_LIMIT_SETTINGS.redis_host,
        redis_port: str = RATE_LIMIT_SETTINGS.redis_port,
-        redis_password: str = RATE_LIMIT_SETTINGS.redis_password,
+        redis_password: str | None = RATE_LIMIT_SETTINGS.redis_password,
        requests_per_minute: int = RATE_LIMIT_SETTINGS.requests_per_minute,
    ):
        self.redis = Redis(
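A hedged sketch of what the `Optional` change means in practice, assuming the pydantic-settings environment loading configured above: an unset `REDIS_PASSWORD` now yields `None` instead of the placeholder string `"password"`.

```python
import os

# Illustrative only; RateLimitSettings is the class from the diff above.
os.environ.pop("REDIS_PASSWORD", None)  # simulate an unset variable
assert RateLimitSettings().redis_password is None  # was "password" before

os.environ["REDIS_PASSWORD"] = "s3cret"
assert RateLimitSettings().redis_password == "s3cret"
```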
@@ -1,266 +0,0 @@
import inspect
import logging
import threading
import time
from functools import wraps
from typing import (
    Awaitable,
    Callable,
    ParamSpec,
    Protocol,
    Tuple,
    TypeVar,
    cast,
    overload,
    runtime_checkable,
)

P = ParamSpec("P")
R = TypeVar("R")

logger = logging.getLogger(__name__)


@overload
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
    pass


@overload
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
    pass


def thread_cached(
    func: Callable[P, R] | Callable[P, Awaitable[R]],
) -> Callable[P, R] | Callable[P, Awaitable[R]]:
    thread_local = threading.local()

    def _clear():
        if hasattr(thread_local, "cache"):
            del thread_local.cache

    if inspect.iscoroutinefunction(func):

        async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            cache = getattr(thread_local, "cache", None)
            if cache is None:
                cache = thread_local.cache = {}
            key = (args, tuple(sorted(kwargs.items())))
            if key not in cache:
                cache[key] = await cast(Callable[P, Awaitable[R]], func)(
                    *args, **kwargs
                )
            return cache[key]

        setattr(async_wrapper, "clear_cache", _clear)
        return async_wrapper

    else:

        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            cache = getattr(thread_local, "cache", None)
            if cache is None:
                cache = thread_local.cache = {}
            key = (args, tuple(sorted(kwargs.items())))
            if key not in cache:
                cache[key] = func(*args, **kwargs)
            return cache[key]

        setattr(sync_wrapper, "clear_cache", _clear)
        return sync_wrapper


def clear_thread_cache(func: Callable) -> None:
    if clear := getattr(func, "clear_cache", None):
        clear()


FuncT = TypeVar("FuncT")


R_co = TypeVar("R_co", covariant=True)


@runtime_checkable
class AsyncCachedFunction(Protocol[P, R_co]):
    """Protocol for async functions with cache management methods."""

    def cache_clear(self) -> None:
        """Clear all cached entries."""
        return None

    def cache_info(self) -> dict[str, int | None]:
        """Get cache statistics."""
        return {}

    async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
        """Call the cached function."""
        return None  # type: ignore


def async_ttl_cache(
    maxsize: int = 128, ttl_seconds: int | None = None
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
    """
    TTL (Time To Live) cache decorator for async functions.

    Similar to functools.lru_cache but works with async functions and includes optional TTL.

    Args:
        maxsize: Maximum number of cached entries
        ttl_seconds: Time to live in seconds. If None, entries never expire (like lru_cache)

    Returns:
        Decorator function

    Example:
        # With TTL
        @async_ttl_cache(maxsize=1000, ttl_seconds=300)
        async def api_call(param: str) -> dict:
            return {"result": param}

        # Without TTL (permanent cache like lru_cache)
        @async_ttl_cache(maxsize=1000)
        async def expensive_computation(param: str) -> dict:
            return {"result": param}
    """

    def decorator(
        async_func: Callable[P, Awaitable[R]],
    ) -> AsyncCachedFunction[P, R]:
        # Cache storage - use union type to handle both cases
        cache_storage: dict[tuple, R | Tuple[R, float]] = {}

        @wraps(async_func)
        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            # Create cache key from arguments
            key = (args, tuple(sorted(kwargs.items())))
            current_time = time.time()

            # Check if we have a valid cached entry
            if key in cache_storage:
                if ttl_seconds is None:
                    # No TTL - return cached result directly
                    logger.debug(
                        f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
                    )
                    return cast(R, cache_storage[key])
                else:
                    # With TTL - check expiration
                    cached_data = cache_storage[key]
                    if isinstance(cached_data, tuple):
                        result, timestamp = cached_data
                        if current_time - timestamp < ttl_seconds:
                            logger.debug(
                                f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
                            )
                            return cast(R, result)
                        else:
                            # Expired entry
                            del cache_storage[key]
                            logger.debug(
                                f"Cache entry expired for {async_func.__name__}"
                            )

            # Cache miss or expired - fetch fresh data
            logger.debug(
                f"Cache miss for {async_func.__name__} with key: {str(key)[:50]}"
            )
            result = await async_func(*args, **kwargs)

            # Store in cache
            if ttl_seconds is None:
                cache_storage[key] = result
            else:
                cache_storage[key] = (result, current_time)

            # Simple cleanup when cache gets too large
            if len(cache_storage) > maxsize:
                # Remove oldest entries (simple FIFO cleanup)
                cutoff = maxsize // 2
                oldest_keys = list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
                for old_key in oldest_keys:
                    cache_storage.pop(old_key, None)
                logger.debug(
                    f"Cache cleanup: removed {len(oldest_keys)} entries for {async_func.__name__}"
                )

            return result

        # Add cache management methods (similar to functools.lru_cache)
        def cache_clear() -> None:
            cache_storage.clear()

        def cache_info() -> dict[str, int | None]:
            return {
                "size": len(cache_storage),
                "maxsize": maxsize,
                "ttl_seconds": ttl_seconds,
            }

        # Attach methods to wrapper
        setattr(wrapper, "cache_clear", cache_clear)
        setattr(wrapper, "cache_info", cache_info)

        return cast(AsyncCachedFunction[P, R], wrapper)

    return decorator


@overload
def async_cache(
    func: Callable[P, Awaitable[R]],
) -> AsyncCachedFunction[P, R]:
    pass


@overload
def async_cache(
    func: None = None,
    *,
    maxsize: int = 128,
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
    pass


def async_cache(
    func: Callable[P, Awaitable[R]] | None = None,
    *,
    maxsize: int = 128,
) -> (
    AsyncCachedFunction[P, R]
    | Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]
):
    """
    Process-level cache decorator for async functions (no TTL).

    Similar to functools.lru_cache but works with async functions.
    This is a convenience wrapper around async_ttl_cache with ttl_seconds=None.

    Args:
        func: The async function to cache (when used without parentheses)
        maxsize: Maximum number of cached entries

    Returns:
        Decorated function or decorator

    Example:
        # Without parentheses (uses default maxsize=128)
        @async_cache
        async def get_data(param: str) -> dict:
            return {"result": param}

        # With parentheses and custom maxsize
        @async_cache(maxsize=1000)
        async def expensive_computation(param: str) -> dict:
            # Expensive computation here
            return {"result": param}
    """
    if func is None:
        # Called with parentheses @async_cache() or @async_cache(maxsize=...)
        return async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
    else:
        # Called without parentheses @async_cache
        decorator = async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
        return decorator(func)
@@ -1,705 +0,0 @@
"""Tests for the @thread_cached decorator.

This module tests the thread-local caching functionality including:
- Basic caching for sync and async functions
- Thread isolation (each thread has its own cache)
- Cache clearing functionality
- Exception handling (exceptions are not cached)
- Argument handling (positional vs keyword arguments)
"""

import asyncio
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import Mock

import pytest

from autogpt_libs.utils.cache import (
    async_cache,
    async_ttl_cache,
    clear_thread_cache,
    thread_cached,
)


class TestThreadCached:
    def test_sync_function_caching(self):
        call_count = 0

        @thread_cached
        def expensive_function(x: int, y: int = 0) -> int:
            nonlocal call_count
            call_count += 1
            return x + y

        assert expensive_function(1, 2) == 3
        assert call_count == 1

        assert expensive_function(1, 2) == 3
        assert call_count == 1

        assert expensive_function(1, y=2) == 3
        assert call_count == 2

        assert expensive_function(2, 3) == 5
        assert call_count == 3

        assert expensive_function(1) == 1
        assert call_count == 4

    @pytest.mark.asyncio
    async def test_async_function_caching(self):
        call_count = 0

        @thread_cached
        async def expensive_async_function(x: int, y: int = 0) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)
            return x + y

        assert await expensive_async_function(1, 2) == 3
        assert call_count == 1

        assert await expensive_async_function(1, 2) == 3
        assert call_count == 1

        assert await expensive_async_function(1, y=2) == 3
        assert call_count == 2

        assert await expensive_async_function(2, 3) == 5
        assert call_count == 3

    def test_thread_isolation(self):
        call_count = 0
        results = {}

        @thread_cached
        def thread_specific_function(x: int) -> str:
            nonlocal call_count
            call_count += 1
            return f"{threading.current_thread().name}-{x}"

        def worker(thread_id: int):
            result1 = thread_specific_function(1)
            result2 = thread_specific_function(1)
            result3 = thread_specific_function(2)
            results[thread_id] = (result1, result2, result3)

        with ThreadPoolExecutor(max_workers=3) as executor:
            futures = [executor.submit(worker, i) for i in range(3)]
            for future in futures:
                future.result()

        assert call_count >= 2

        for thread_id, (r1, r2, r3) in results.items():
            assert r1 == r2
            assert r1 != r3

    @pytest.mark.asyncio
    async def test_async_thread_isolation(self):
        call_count = 0
        results = {}

        @thread_cached
        async def async_thread_specific_function(x: int) -> str:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)
            return f"{threading.current_thread().name}-{x}"

        async def async_worker(worker_id: int):
            result1 = await async_thread_specific_function(1)
            result2 = await async_thread_specific_function(1)
            result3 = await async_thread_specific_function(2)
            results[worker_id] = (result1, result2, result3)

        tasks = [async_worker(i) for i in range(3)]
        await asyncio.gather(*tasks)

        for worker_id, (r1, r2, r3) in results.items():
            assert r1 == r2
            assert r1 != r3

    def test_clear_cache_sync(self):
        call_count = 0

        @thread_cached
        def clearable_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 2

        assert clearable_function(5) == 10
        assert call_count == 1

        assert clearable_function(5) == 10
        assert call_count == 1

        clear_thread_cache(clearable_function)

        assert clearable_function(5) == 10
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_clear_cache_async(self):
        call_count = 0

        @thread_cached
        async def clearable_async_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)
            return x * 2

        assert await clearable_async_function(5) == 10
        assert call_count == 1

        assert await clearable_async_function(5) == 10
        assert call_count == 1

        clear_thread_cache(clearable_async_function)

        assert await clearable_async_function(5) == 10
        assert call_count == 2

    def test_simple_arguments(self):
        call_count = 0

        @thread_cached
        def simple_function(a: str, b: int, c: str = "default") -> str:
            nonlocal call_count
            call_count += 1
            return f"{a}-{b}-{c}"

        # First call with all positional args
        result1 = simple_function("test", 42, "custom")
        assert call_count == 1

        # Same args, all positional - should hit cache
        result2 = simple_function("test", 42, "custom")
        assert call_count == 1
        assert result1 == result2

        # Same values but last arg as keyword - creates different cache key
        result3 = simple_function("test", 42, c="custom")
        assert call_count == 2
        assert result1 == result3  # Same result, different cache entry

        # Different value - new cache entry
        result4 = simple_function("test", 43, "custom")
        assert call_count == 3
        assert result1 != result4

    def test_positional_vs_keyword_args(self):
        """Test that positional and keyword arguments create different cache entries."""
        call_count = 0

        @thread_cached
        def func(a: int, b: int = 10) -> str:
            nonlocal call_count
            call_count += 1
            return f"result-{a}-{b}"

        # All positional
        result1 = func(1, 2)
        assert call_count == 1
        assert result1 == "result-1-2"

        # Same values, but second arg as keyword
        result2 = func(1, b=2)
        assert call_count == 2  # Different cache key!
        assert result2 == "result-1-2"  # Same result

        # Verify both are cached separately
        func(1, 2)  # Uses first cache entry
        assert call_count == 2

        func(1, b=2)  # Uses second cache entry
        assert call_count == 2

    def test_exception_handling(self):
        call_count = 0

        @thread_cached
        def failing_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            if x < 0:
                raise ValueError("Negative value")
            return x * 2

        assert failing_function(5) == 10
        assert call_count == 1

        with pytest.raises(ValueError):
            failing_function(-1)
        assert call_count == 2

        with pytest.raises(ValueError):
            failing_function(-1)
        assert call_count == 3

        assert failing_function(5) == 10
        assert call_count == 3

    @pytest.mark.asyncio
    async def test_async_exception_handling(self):
        call_count = 0

        @thread_cached
        async def async_failing_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)
            if x < 0:
                raise ValueError("Negative value")
            return x * 2

        assert await async_failing_function(5) == 10
        assert call_count == 1

        with pytest.raises(ValueError):
            await async_failing_function(-1)
        assert call_count == 2

        with pytest.raises(ValueError):
            await async_failing_function(-1)
        assert call_count == 3

    def test_sync_caching_performance(self):
        @thread_cached
        def slow_function(x: int) -> int:
            print(f"slow_function called with x={x}")
            time.sleep(0.1)
            return x * 2

        start = time.time()
        result1 = slow_function(5)
        first_call_time = time.time() - start
        print(f"First call took {first_call_time:.4f} seconds")

        start = time.time()
        result2 = slow_function(5)
        second_call_time = time.time() - start
        print(f"Second call took {second_call_time:.4f} seconds")

        assert result1 == result2 == 10
        assert first_call_time > 0.09
        assert second_call_time < 0.01

    @pytest.mark.asyncio
    async def test_async_caching_performance(self):
        @thread_cached
        async def slow_async_function(x: int) -> int:
            print(f"slow_async_function called with x={x}")
            await asyncio.sleep(0.1)
            return x * 2

        start = time.time()
        result1 = await slow_async_function(5)
        first_call_time = time.time() - start
        print(f"First async call took {first_call_time:.4f} seconds")

        start = time.time()
        result2 = await slow_async_function(5)
        second_call_time = time.time() - start
        print(f"Second async call took {second_call_time:.4f} seconds")

        assert result1 == result2 == 10
        assert first_call_time > 0.09
        assert second_call_time < 0.01

    def test_with_mock_objects(self):
        mock = Mock(return_value=42)

        @thread_cached
        def function_using_mock(x: int) -> int:
            return mock(x)

        assert function_using_mock(1) == 42
        assert mock.call_count == 1

        assert function_using_mock(1) == 42
        assert mock.call_count == 1

        assert function_using_mock(2) == 42
        assert mock.call_count == 2


class TestAsyncTTLCache:
    """Tests for the @async_ttl_cache decorator."""

    @pytest.mark.asyncio
    async def test_basic_caching(self):
        """Test basic caching functionality."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def cached_function(x: int, y: int = 0) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)  # Simulate async work
            return x + y

        # First call
        result1 = await cached_function(1, 2)
        assert result1 == 3
        assert call_count == 1

        # Second call with same args - should use cache
        result2 = await cached_function(1, 2)
        assert result2 == 3
        assert call_count == 1  # No additional call

        # Different args - should call function again
        result3 = await cached_function(2, 3)
        assert result3 == 5
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_ttl_expiration(self):
        """Test that cache entries expire after TTL."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=1)  # Short TTL
        async def short_lived_cache(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 2

        # First call
        result1 = await short_lived_cache(5)
        assert result1 == 10
        assert call_count == 1

        # Second call immediately - should use cache
        result2 = await short_lived_cache(5)
        assert result2 == 10
        assert call_count == 1

        # Wait for TTL to expire
        await asyncio.sleep(1.1)

        # Third call after expiration - should call function again
        result3 = await short_lived_cache(5)
        assert result3 == 10
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_cache_info(self):
        """Test cache info functionality."""

        @async_ttl_cache(maxsize=5, ttl_seconds=300)
        async def info_test_function(x: int) -> int:
            return x * 3

        # Check initial cache info
        info = info_test_function.cache_info()
        assert info["size"] == 0
        assert info["maxsize"] == 5
        assert info["ttl_seconds"] == 300

        # Add an entry
        await info_test_function(1)
        info = info_test_function.cache_info()
        assert info["size"] == 1

    @pytest.mark.asyncio
    async def test_cache_clear(self):
        """Test cache clearing functionality."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def clearable_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 4

        # First call
        result1 = await clearable_function(2)
        assert result1 == 8
        assert call_count == 1

        # Second call - should use cache
        result2 = await clearable_function(2)
        assert result2 == 8
        assert call_count == 1

        # Clear cache
        clearable_function.cache_clear()

        # Third call after clear - should call function again
        result3 = await clearable_function(2)
        assert result3 == 8
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_maxsize_cleanup(self):
        """Test that cache cleans up when maxsize is exceeded."""
        call_count = 0

        @async_ttl_cache(maxsize=3, ttl_seconds=60)
        async def size_limited_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x**2

        # Fill cache to maxsize
        await size_limited_function(1)  # call_count: 1
        await size_limited_function(2)  # call_count: 2
        await size_limited_function(3)  # call_count: 3

        info = size_limited_function.cache_info()
        assert info["size"] == 3

        # Add one more entry - should trigger cleanup
        await size_limited_function(4)  # call_count: 4

        # Cache size should be reduced (cleanup removes oldest entries)
        info = size_limited_function.cache_info()
        assert info["size"] is not None and info["size"] <= 3  # Should be cleaned up

    @pytest.mark.asyncio
    async def test_argument_variations(self):
        """Test caching with different argument patterns."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def arg_test_function(a: int, b: str = "default", *, c: int = 100) -> str:
            nonlocal call_count
            call_count += 1
            return f"{a}-{b}-{c}"

        # Different ways to call with same logical arguments
        result1 = await arg_test_function(1, "test", c=200)
        assert call_count == 1

        # Same arguments, same order - should use cache
        result2 = await arg_test_function(1, "test", c=200)
        assert call_count == 1
        assert result1 == result2

        # Different arguments - should call function
        result3 = await arg_test_function(2, "test", c=200)
        assert call_count == 2
        assert result1 != result3

    @pytest.mark.asyncio
    async def test_exception_handling(self):
        """Test that exceptions are not cached."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def exception_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            if x < 0:
                raise ValueError("Negative value not allowed")
            return x * 2

        # Successful call - should be cached
        result1 = await exception_function(5)
        assert result1 == 10
        assert call_count == 1

        # Same successful call - should use cache
        result2 = await exception_function(5)
        assert result2 == 10
        assert call_count == 1

        # Exception call - should not be cached
        with pytest.raises(ValueError):
            await exception_function(-1)
        assert call_count == 2

        # Same exception call - should call again (not cached)
        with pytest.raises(ValueError):
            await exception_function(-1)
        assert call_count == 3

    @pytest.mark.asyncio
    async def test_concurrent_calls(self):
        """Test caching behavior with concurrent calls."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def concurrent_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.05)  # Simulate work
            return x * x

        # Launch concurrent calls with same arguments
        tasks = [concurrent_function(3) for _ in range(5)]
        results = await asyncio.gather(*tasks)

        # All results should be the same
        assert all(result == 9 for result in results)

        # Note: Due to race conditions, call_count might be up to 5 for concurrent calls
        # This tests that the cache doesn't break under concurrent access
        assert 1 <= call_count <= 5


class TestAsyncCache:
    """Tests for the @async_cache decorator (no TTL)."""

    @pytest.mark.asyncio
    async def test_basic_caching_no_ttl(self):
        """Test basic caching functionality without TTL."""
        call_count = 0

        @async_cache(maxsize=10)
        async def cached_function(x: int, y: int = 0) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)  # Simulate async work
            return x + y

        # First call
        result1 = await cached_function(1, 2)
        assert result1 == 3
        assert call_count == 1

        # Second call with same args - should use cache
        result2 = await cached_function(1, 2)
        assert result2 == 3
        assert call_count == 1  # No additional call

        # Third call after some time - should still use cache (no TTL)
        await asyncio.sleep(0.05)
        result3 = await cached_function(1, 2)
        assert result3 == 3
        assert call_count == 1  # Still no additional call

        # Different args - should call function again
        result4 = await cached_function(2, 3)
        assert result4 == 5
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_no_ttl_vs_ttl_behavior(self):
        """Test the difference between TTL and no-TTL caching."""
        ttl_call_count = 0
        no_ttl_call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=1)  # Short TTL
        async def ttl_function(x: int) -> int:
            nonlocal ttl_call_count
            ttl_call_count += 1
            return x * 2

        @async_cache(maxsize=10)  # No TTL
        async def no_ttl_function(x: int) -> int:
            nonlocal no_ttl_call_count
            no_ttl_call_count += 1
            return x * 2

        # First calls
        await ttl_function(5)
        await no_ttl_function(5)
        assert ttl_call_count == 1
        assert no_ttl_call_count == 1

        # Wait for TTL to expire
        await asyncio.sleep(1.1)

        # Second calls after TTL expiry
        await ttl_function(5)  # Should call function again (TTL expired)
        await no_ttl_function(5)  # Should use cache (no TTL)
        assert ttl_call_count == 2  # TTL function called again
        assert no_ttl_call_count == 1  # No-TTL function still cached

    @pytest.mark.asyncio
    async def test_async_cache_info(self):
        """Test cache info for no-TTL cache."""

        @async_cache(maxsize=5)
        async def info_test_function(x: int) -> int:
            return x * 3

        # Check initial cache info
        info = info_test_function.cache_info()
        assert info["size"] == 0
        assert info["maxsize"] == 5
        assert info["ttl_seconds"] is None  # No TTL

        # Add an entry
        await info_test_function(1)
        info = info_test_function.cache_info()
        assert info["size"] == 1


class TestTTLOptional:
    """Tests for optional TTL functionality."""

    @pytest.mark.asyncio
    async def test_ttl_none_behavior(self):
        """Test that ttl_seconds=None works like no TTL."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=None)
        async def no_ttl_via_none(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x**2

        # First call
        result1 = await no_ttl_via_none(3)
        assert result1 == 9
        assert call_count == 1

        # Wait (would expire if there was TTL)
        await asyncio.sleep(0.1)

        # Second call - should still use cache
        result2 = await no_ttl_via_none(3)
        assert result2 == 9
        assert call_count == 1  # No additional call

        # Check cache info
        info = no_ttl_via_none.cache_info()
        assert info["ttl_seconds"] is None

    @pytest.mark.asyncio
    async def test_cache_options_comparison(self):
        """Test different cache options work as expected."""
        ttl_calls = 0
        no_ttl_calls = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=1)  # With TTL
        async def ttl_function(x: int) -> int:
            nonlocal ttl_calls
            ttl_calls += 1
            return x * 10

        @async_cache(maxsize=10)  # Process-level cache (no TTL)
        async def process_function(x: int) -> int:
            nonlocal no_ttl_calls
            no_ttl_calls += 1
            return x * 10

        # Both should cache initially
        await ttl_function(3)
        await process_function(3)
        assert ttl_calls == 1
        assert no_ttl_calls == 1

        # Immediate second calls - both should use cache
        await ttl_function(3)
        await process_function(3)
        assert ttl_calls == 1
        assert no_ttl_calls == 1

        # Wait for TTL to expire
        await asyncio.sleep(1.1)

        # After TTL expiry
        await ttl_function(3)  # Should call function again
        await process_function(3)  # Should still use cache
        assert ttl_calls == 2  # TTL cache expired, called again
        assert no_ttl_calls == 1  # Process cache never expires
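The positional-vs-keyword distinction exercised by the tests above falls out of keying the cache on the raw (args, kwargs) pair; a minimal illustrative sketch of such a key (not the library's actual key function):

def make_cache_key(args: tuple, kwargs: dict) -> tuple:
    # (1, 2) with empty kwargs and (1,) with {"b": 2} produce different keys,
    # so they occupy separate cache entries even though the results are equal.
    return (args, tuple(sorted(kwargs.items())))


assert make_cache_key((1, 2), {}) != make_cache_key((1,), {"b": 2})
assert make_cache_key((1, 2), {}) == make_cache_key((1, 2), {})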
@@ -21,7 +21,7 @@ PRISMA_SCHEMA="postgres/schema.prisma"
# Redis Configuration
REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_PASSWORD=password
# REDIS_PASSWORD=

# RabbitMQ Credentials
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
@@ -66,6 +66,11 @@ NVIDIA_API_KEY=
GITHUB_CLIENT_ID=
GITHUB_CLIENT_SECRET=

# Notion OAuth App server credentials - https://developers.notion.com/docs/authorization
# Configure a public integration
NOTION_CLIENT_ID=
NOTION_CLIENT_SECRET=

# Google OAuth App server credentials - https://console.cloud.google.com/apis/credentials, and enable gmail api and set scopes
# https://console.cloud.google.com/apis/credentials/consent ?project=<your_project_id>
# You'll need to add/enable the following scopes (minimum):
10 autogpt_platform/backend/.gitignore vendored
@@ -9,4 +9,12 @@ secrets/*
!secrets/.gitkeep

*.ignore.*
*.ign.*

# Load test results and reports
load-tests/*_RESULTS.md
load-tests/*_REPORT.md
load-tests/results/
load-tests/*.json
load-tests/*.log
load-tests/node_modules/*
@@ -9,8 +9,15 @@ WORKDIR /app

RUN echo 'Acquire::http::Pipeline-Depth 0;\nAcquire::http::No-Cache true;\nAcquire::BrokenProxy true;\n' > /etc/apt/apt.conf.d/99fixbadproxy

# Update package list and install Python and build dependencies
# Install Node.js repository key and setup
RUN apt-get update --allow-releaseinfo-change --fix-missing \
    && apt-get install -y curl ca-certificates gnupg \
    && mkdir -p /etc/apt/keyrings \
    && curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg \
    && echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_20.x nodistro main" | tee /etc/apt/sources.list.d/nodesource.list

# Update package list and install Python, Node.js, and build dependencies
RUN apt-get update \
    && apt-get install -y \
    python3.13 \
    python3.13-dev \
@@ -20,7 +27,9 @@ RUN apt-get update --allow-releaseinfo-change --fix-missing \
    libpq5 \
    libz-dev \
    libssl-dev \
    postgresql-client
    postgresql-client \
    nodejs \
    && rm -rf /var/lib/apt/lists/*

ENV POETRY_HOME=/opt/poetry
ENV POETRY_NO_INTERACTION=1
@@ -38,6 +47,7 @@ RUN poetry install --no-ansi --no-root

# Generate Prisma client
COPY autogpt_platform/backend/schema.prisma ./
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
RUN poetry run prisma generate

FROM debian:13-slim AS server_dependencies
@@ -54,13 +64,18 @@ ENV PATH=/opt/poetry/bin:$PATH
# Install Python without upgrading system-managed packages
RUN apt-get update && apt-get install -y \
    python3.13 \
    python3-pip
    python3-pip \
    && rm -rf /var/lib/apt/lists/*

# Copy only necessary files from builder
COPY --from=builder /app /app
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
# Copy Prisma binaries
# Copy Node.js installation for Prisma
COPY --from=builder /usr/bin/node /usr/bin/node
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
COPY --from=builder /usr/bin/npm /usr/bin/npm
COPY --from=builder /usr/bin/npx /usr/bin/npx
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries

ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
@@ -78,6 +93,7 @@ FROM server_dependencies AS migrate

# Migration stage only needs schema and migrations - much lighter than full backend
COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
COPY autogpt_platform/backend/backend/data/partial_types.py /app/autogpt_platform/backend/backend/data/partial_types.py
COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations

FROM server_dependencies AS server
@@ -1,4 +1,3 @@
import functools
import importlib
import logging
import os
@@ -6,6 +5,8 @@ import re
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar

from backend.util.cache import cached

logger = logging.getLogger(__name__)


@@ -15,7 +16,7 @@ if TYPE_CHECKING:
T = TypeVar("T")


@functools.cache
@cached(ttl_seconds=3600)
def load_all_blocks() -> dict[str, type["Block"]]:
    from backend.data.block import Block
    from backend.util.settings import Config
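The hunk above swaps @functools.cache, which memoizes for the life of the process, for @cached(ttl_seconds=3600), so the block registry can be re-resolved roughly hourly. A minimal sketch of the semantic difference, assuming a simple time-based cache (the real @cached lives in backend.util.cache and is not shown here):

import time
from typing import Any, Callable


def ttl_cached(ttl_seconds: float) -> Callable:
    # Toy TTL cache for illustration only; not the backend.util.cache implementation.
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        store: dict[tuple, tuple[float, Any]] = {}

        def wrapper(*args: Any) -> Any:
            now = time.monotonic()
            entry = store.get(args)
            if entry is not None and now - entry[0] < ttl_seconds:
                return entry[1]  # entry still fresh: serve from cache
            value = func(*args)
            store[args] = (now, value)  # missing or expired: recompute and store
            return value

        return wrapper

    return decorator

Unlike functools.cache, a stale entry is recomputed on the next call after ttl_seconds, which is what lets load_all_blocks pick up changes without a restart.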
214 autogpt_platform/backend/backend/blocks/ai_condition.py Normal file
@@ -0,0 +1,214 @@
from typing import Any

from backend.blocks.llm import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    AIBlockBase,
    AICredentials,
    AICredentialsField,
    LlmModel,
    LLMResponse,
    llm_call,
)
from backend.data.block import BlockCategory, BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField


class AIConditionBlock(AIBlockBase):
    """
    An AI-powered condition block that uses natural language to evaluate conditions.

    This block allows users to define conditions in plain English (e.g., "the input is an email address",
    "the input is a city in the USA") and uses AI to determine if the input satisfies the condition.
    It provides the same yes/no data pass-through functionality as the standard ConditionBlock.
    """

    class Input(BlockSchema):
        input_value: Any = SchemaField(
            description="The input value to evaluate with the AI condition",
            placeholder="Enter the value to be evaluated (text, number, or any data)",
        )
        condition: str = SchemaField(
            description="A plaintext English description of the condition to evaluate",
            placeholder="E.g., 'the input is the body of an email', 'the input is a City in the USA', 'the input is an error or a refusal'",
        )
        yes_value: Any = SchemaField(
            description="(Optional) Value to output if the condition is true. If not provided, input_value will be used.",
            placeholder="Leave empty to use input_value, or enter a specific value",
            default=None,
        )
        no_value: Any = SchemaField(
            description="(Optional) Value to output if the condition is false. If not provided, input_value will be used.",
            placeholder="Leave empty to use input_value, or enter a specific value",
            default=None,
        )
        model: LlmModel = SchemaField(
            title="LLM Model",
            default=LlmModel.GPT4O,
            description="The language model to use for evaluating the condition.",
            advanced=False,
        )
        credentials: AICredentials = AICredentialsField()

    class Output(BlockSchema):
        result: bool = SchemaField(
            description="The result of the AI condition evaluation (True or False)"
        )
        yes_output: Any = SchemaField(
            description="The output value if the condition is true"
        )
        no_output: Any = SchemaField(
            description="The output value if the condition is false"
        )
        error: str = SchemaField(
            description="Error message if the AI evaluation is uncertain or fails"
        )

    def __init__(self):
        super().__init__(
            id="553ec5b8-6c45-4299-8d75-b394d05f72ff",
            input_schema=AIConditionBlock.Input,
            output_schema=AIConditionBlock.Output,
            description="Uses AI to evaluate natural language conditions and provide conditional outputs",
            categories={BlockCategory.AI, BlockCategory.LOGIC},
            test_input={
                "input_value": "john@example.com",
                "condition": "the input is an email address",
                "yes_value": "Valid email",
                "no_value": "Not an email",
                "model": LlmModel.GPT4O,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("result", True),
                ("yes_output", "Valid email"),
            ],
            test_mock={
                "llm_call": lambda *args, **kwargs: LLMResponse(
                    raw_response="",
                    prompt=[],
                    response="true",
                    tool_calls=None,
                    prompt_tokens=50,
                    completion_tokens=10,
                    reasoning=None,
                )
            },
        )

    async def llm_call(
        self,
        credentials: APIKeyCredentials,
        llm_model: LlmModel,
        prompt: list,
        max_tokens: int,
    ) -> LLMResponse:
        """Wrapper method for llm_call to enable mocking in tests."""
        return await llm_call(
            credentials=credentials,
            llm_model=llm_model,
            prompt=prompt,
            force_json_output=False,
            max_tokens=max_tokens,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        Evaluate the AI condition and return appropriate outputs.
        """
        # Prepare the yes and no values, using input_value as default
        yes_value = (
            input_data.yes_value
            if input_data.yes_value is not None
            else input_data.input_value
        )
        no_value = (
            input_data.no_value
            if input_data.no_value is not None
            else input_data.input_value
        )

        # Convert input_value to string for AI evaluation
        input_str = str(input_data.input_value)

        # Create the prompt for AI evaluation
        prompt = [
            {
                "role": "system",
                "content": (
                    "You are an AI assistant that evaluates conditions based on input data. "
                    "You must respond with only 'true' or 'false' (lowercase) to indicate whether "
                    "the given condition is met by the input value. Be accurate and consider the "
                    "context and meaning of both the input and the condition."
                ),
            },
            {
                "role": "user",
                "content": (
                    f"Input value: {input_str}\n"
                    f"Condition to evaluate: {input_data.condition}\n\n"
                    f"Does the input value satisfy the condition? Respond with only 'true' or 'false'."
                ),
            },
        ]

        # Call the LLM
        try:
            response = await self.llm_call(
                credentials=credentials,
                llm_model=input_data.model,
                prompt=prompt,
                max_tokens=10,  # We only expect a true/false response
            )

            # Extract the boolean result from the response
            response_text = response.response.strip().lower()
            if response_text == "true":
                result = True
            elif response_text == "false":
                result = False
            else:
                # If the response is not clear, try to interpret it using word boundaries
                import re

                # Use word boundaries to avoid false positives like 'untrue' or '10'
                tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", response_text))

                if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
                    result = True
                elif tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
                    result = False
                else:
                    # Unclear or conflicting response - default to False and yield error
                    result = False
                    yield "error", f"Unclear AI response: '{response.response}'"

            # Update internal stats
            self.merge_stats(
                NodeExecutionStats(
                    input_token_count=response.prompt_tokens,
                    output_token_count=response.completion_tokens,
                )
            )
            self.prompt = response.prompt

        except Exception as e:
            # In case of any error, default to False to be safe
            result = False
            # Log the error but don't fail the block execution
            import logging

            logger = logging.getLogger(__name__)
            logger.error(f"AI condition evaluation failed: {str(e)}")
            yield "error", f"AI evaluation failed: {str(e)}"

        # Yield results
        yield "result", result

        if result:
            yield "yes_output", yes_value
        else:
            yield "no_output", no_value
@@ -241,6 +241,7 @@ class AirtableCreateRecordsBlock(Block):

    class Output(BlockSchema):
        records: list[dict] = SchemaField(description="Array of created record objects")
        details: dict = SchemaField(description="Details of the created records")

    def __init__(self):
        super().__init__(
@@ -279,6 +280,9 @@ class AirtableCreateRecordsBlock(Block):
        result_records = normalized_data["records"]

        yield "records", result_records
        details = data.get("details", None)
        if details:
            yield "details", details


class AirtableUpdateRecordsBlock(Block):
@@ -1,8 +1,10 @@
from enum import Enum
from typing import Literal
from typing import Any, Literal, Optional

from e2b_code_interpreter import AsyncSandbox
from pydantic import SecretStr
from e2b_code_interpreter import Result as E2BExecutionResult
from e2b_code_interpreter.charts import Chart as E2BExecutionResultChart
from pydantic import BaseModel, JsonValue, SecretStr

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
@@ -36,14 +38,135 @@ class ProgrammingLanguage(Enum):
    JAVA = "java"


class CodeExecutionBlock(Block):
class MainCodeExecutionResult(BaseModel):
    """
    *Pydantic model mirroring `e2b_code_interpreter.Result`*

    Represents the data to be displayed as a result of executing a cell in a Jupyter notebook.
    The result is similar to the structure returned by ipython kernel: https://ipython.readthedocs.io/en/stable/development/execution.html#execution-semantics

    The result can contain multiple types of data, such as text, images, plots, etc. Each type of data is represented
    as a string, and the result can contain multiple types of data. The display calls don't have to have text representation,
    for the actual result the representation is always present for the result, the other representations are always optional.
    """  # noqa

    class Chart(BaseModel, E2BExecutionResultChart):
        pass

    text: Optional[str] = None
    html: Optional[str] = None
    markdown: Optional[str] = None
    svg: Optional[str] = None
    png: Optional[str] = None
    jpeg: Optional[str] = None
    pdf: Optional[str] = None
    latex: Optional[str] = None
    json: Optional[JsonValue] = None  # type: ignore (reportIncompatibleMethodOverride)
    javascript: Optional[str] = None
    data: Optional[dict] = None
    chart: Optional[Chart] = None
    extra: Optional[dict] = None
    """Extra data that can be included. Not part of the standard types."""


class CodeExecutionResult(MainCodeExecutionResult):
    __doc__ = MainCodeExecutionResult.__doc__

    is_main_result: bool = False
    """Whether this data is the main result of the cell. Data can be produced by display calls of which can be multiple in a cell."""  # noqa


class BaseE2BExecutorMixin:
    """Shared implementation methods for E2B executor blocks."""

    async def execute_code(
        self,
        api_key: str,
        code: str,
        language: ProgrammingLanguage,
        template_id: str = "",
        setup_commands: Optional[list[str]] = None,
        timeout: Optional[int] = None,
        sandbox_id: Optional[str] = None,
        dispose_sandbox: bool = False,
    ):
        """
        Unified code execution method that handles all three use cases:
        1. Create new sandbox and execute (ExecuteCodeBlock)
        2. Create new sandbox, execute, and return sandbox_id (InstantiateCodeSandboxBlock)
        3. Connect to existing sandbox and execute (ExecuteCodeStepBlock)
        """  # noqa
        sandbox = None
        try:
            if sandbox_id:
                # Connect to existing sandbox (ExecuteCodeStepBlock case)
                sandbox = await AsyncSandbox.connect(
                    sandbox_id=sandbox_id, api_key=api_key
                )
            else:
                # Create new sandbox (ExecuteCodeBlock/InstantiateCodeSandboxBlock case)
                sandbox = await AsyncSandbox.create(
                    api_key=api_key, template=template_id, timeout=timeout
                )
                if setup_commands:
                    for cmd in setup_commands:
                        await sandbox.commands.run(cmd)

            # Execute the code
            execution = await sandbox.run_code(
                code,
                language=language.value,
                on_error=lambda e: sandbox.kill(),  # Kill the sandbox on error
            )

            if execution.error:
                raise Exception(execution.error)

            results = execution.results
            text_output = execution.text
            stdout_logs = "".join(execution.logs.stdout)
            stderr_logs = "".join(execution.logs.stderr)

            return results, text_output, stdout_logs, stderr_logs, sandbox.sandbox_id
        finally:
            # Dispose of sandbox if requested to reduce usage costs
            if dispose_sandbox and sandbox:
                await sandbox.kill()

    def process_execution_results(
        self, results: list[E2BExecutionResult]
    ) -> tuple[dict[str, Any] | None, list[dict[str, Any]]]:
        """Process and filter execution results."""
        # Filter out empty formats and convert to dicts
        processed_results = [
            {
                f: value
                for f in [*r.formats(), "extra", "is_main_result"]
                if (value := getattr(r, f, None)) is not None
            }
            for r in results
        ]
        if main_result := next(
            (r for r in processed_results if r.get("is_main_result")), None
        ):
            # Make main_result a copy we can modify & remove is_main_result
            (main_result := {**main_result}).pop("is_main_result")

        return main_result, processed_results
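To see what process_execution_results produces, a toy illustration with stubbed result objects (the real inputs are e2b_code_interpreter.Result instances; StubResult here is purely hypothetical):

class StubResult:
    # Stand-in for e2b_code_interpreter.Result, for this sketch only.
    def __init__(self, is_main_result: bool, **formats: str) -> None:
        self.is_main_result = is_main_result
        self.extra = None
        for name, value in formats.items():
            setattr(self, name, value)
        self._formats = list(formats)

    def formats(self) -> list[str]:
        return self._formats


results = [StubResult(True, text="42"), StubResult(False, png="iVBOR...")]
processed = [
    {
        f: value
        for f in [*r.formats(), "extra", "is_main_result"]
        if (value := getattr(r, f, None)) is not None
    }
    for r in results
]
main = next((r for r in processed if r.get("is_main_result")), None)
# main == {"text": "42", "is_main_result": True}; the method then copies it
# and pops the flag, leaving {"text": "42"} as the main_result output.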

class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
    # TODO : Add support to upload and download files
    # Currently, You can customized the CPU and Memory, only by creating a pre customized sandbox template
    # NOTE: Currently, you can only customize the CPU and Memory
    # by creating a pre customized sandbox template
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[
            Literal[ProviderName.E2B], Literal["api_key"]
        ] = CredentialsField(
            description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
            description=(
                "Enter your API key for the E2B platform. "
                "You can get it in here - https://e2b.dev/docs"
            ),
        )

        # Todo : Option to run commond in background
@@ -76,6 +199,14 @@ class CodeExecutionBlock(Block):
            description="Execution timeout in seconds", default=300
        )

        dispose_sandbox: bool = SchemaField(
            description=(
                "Whether to dispose of the sandbox immediately after execution. "
                "If disabled, the sandbox will run until its timeout expires."
            ),
            default=True,
        )

        template_id: str = SchemaField(
            description=(
                "You can use an E2B sandbox template by entering its ID here. "
@@ -87,7 +218,16 @@ class CodeExecutionBlock(Block):
        )

    class Output(BlockSchema):
        response: str = SchemaField(description="Response from code execution")
        main_result: MainCodeExecutionResult = SchemaField(
            title="Main Result", description="The main result from the code execution"
        )
        results: list[CodeExecutionResult] = SchemaField(
            description="List of results from the code execution"
        )
        response: str = SchemaField(
            title="Main Text Output",
            description="Text output (if any) of the main execution result",
        )
        stdout_logs: str = SchemaField(
            description="Standard output logs from execution"
        )
@@ -97,10 +237,10 @@ class CodeExecutionBlock(Block):
    def __init__(self):
        super().__init__(
            id="0b02b072-abe7-11ef-8372-fb5d162dd712",
            description="Executes code in an isolated sandbox environment with internet access.",
            description="Executes code in a sandbox environment with internet access.",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=CodeExecutionBlock.Input,
            output_schema=CodeExecutionBlock.Output,
            input_schema=ExecuteCodeBlock.Input,
            output_schema=ExecuteCodeBlock.Output,
            test_credentials=TEST_CREDENTIALS,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
@@ -111,91 +251,59 @@ class CodeExecutionBlock(Block):
                "template_id": "",
            },
            test_output=[
                ("results", []),
                ("response", "Hello World"),
                ("stdout_logs", "Hello World\n"),
            ],
            test_mock={
                "execute_code": lambda code, language, setup_commands, timeout, api_key, template_id: (
                    "Hello World",
                    "Hello World\n",
                    "",
                "execute_code": lambda api_key, code, language, template_id, setup_commands, timeout, dispose_sandbox: (  # noqa
                    [],  # results
                    "Hello World",  # text_output
                    "Hello World\n",  # stdout_logs
                    "",  # stderr_logs
                    "sandbox_id",  # sandbox_id
                ),
            },
        )

    async def execute_code(
        self,
        code: str,
        language: ProgrammingLanguage,
        setup_commands: list[str],
        timeout: int,
        api_key: str,
        template_id: str,
    ):
        try:
            sandbox = None
            if template_id:
                sandbox = await AsyncSandbox.create(
                    template=template_id, api_key=api_key, timeout=timeout
                )
            else:
                sandbox = await AsyncSandbox.create(api_key=api_key, timeout=timeout)

            if not sandbox:
                raise Exception("Sandbox not created")

            # Running setup commands
            for cmd in setup_commands:
                await sandbox.commands.run(cmd)

            # Executing the code
            execution = await sandbox.run_code(
                code,
                language=language.value,
                on_error=lambda e: sandbox.kill(),  # Kill the sandbox if there is an error
            )

            if execution.error:
                raise Exception(execution.error)

            response = execution.text
            stdout_logs = "".join(execution.logs.stdout)
            stderr_logs = "".join(execution.logs.stderr)

            return response, stdout_logs, stderr_logs

        except Exception as e:
            raise e

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        try:
            response, stdout_logs, stderr_logs = await self.execute_code(
                input_data.code,
                input_data.language,
                input_data.setup_commands,
                input_data.timeout,
                credentials.api_key.get_secret_value(),
                input_data.template_id,
            results, text_output, stdout, stderr, _ = await self.execute_code(
                api_key=credentials.api_key.get_secret_value(),
                code=input_data.code,
                language=input_data.language,
                template_id=input_data.template_id,
                setup_commands=input_data.setup_commands,
                timeout=input_data.timeout,
                dispose_sandbox=input_data.dispose_sandbox,
            )

            if response:
                yield "response", response
            if stdout_logs:
                yield "stdout_logs", stdout_logs
            if stderr_logs:
                yield "stderr_logs", stderr_logs
            # Determine result object shape & filter out empty formats
            main_result, results = self.process_execution_results(results)
            if main_result:
                yield "main_result", main_result
            yield "results", results
            if text_output:
                yield "response", text_output
            if stdout:
                yield "stdout_logs", stdout
            if stderr:
                yield "stderr_logs", stderr
        except Exception as e:
            yield "error", str(e)


class InstantiationBlock(Block):
class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[
            Literal[ProviderName.E2B], Literal["api_key"]
        ] = CredentialsField(
            description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
            description=(
                "Enter your API key for the E2B platform. "
                "You can get it in here - https://e2b.dev/docs"
            )
        )

        # Todo : Option to run commond in background
@@ -240,7 +348,10 @@ class InstantiationBlock(Block):

    class Output(BlockSchema):
        sandbox_id: str = SchemaField(description="ID of the sandbox instance")
        response: str = SchemaField(description="Response from code execution")
        response: str = SchemaField(
            title="Text Result",
            description="Text result (if any) of the setup code execution",
        )
        stdout_logs: str = SchemaField(
            description="Standard output logs from execution"
        )
@@ -250,10 +361,13 @@ class InstantiationBlock(Block):
    def __init__(self):
        super().__init__(
            id="ff0861c9-1726-4aec-9e5b-bf53f3622112",
            description="Instantiate an isolated sandbox environment with internet access where to execute code in.",
            description=(
                "Instantiate a sandbox environment with internet access "
                "in which you can execute code with the Execute Code Step block."
            ),
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=InstantiationBlock.Input,
            output_schema=InstantiationBlock.Output,
            input_schema=InstantiateCodeSandboxBlock.Input,
            output_schema=InstantiateCodeSandboxBlock.Output,
            test_credentials=TEST_CREDENTIALS,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
@@ -269,11 +383,12 @@ class InstantiationBlock(Block):
                ("stdout_logs", "Hello World\n"),
            ],
            test_mock={
                "execute_code": lambda setup_code, language, setup_commands, timeout, api_key, template_id: (
                    "sandbox_id",
                    "Hello World",
                    "Hello World\n",
                    "",
                "execute_code": lambda api_key, code, language, template_id, setup_commands, timeout: (  # noqa
                    [],  # results
                    "Hello World",  # text_output
                    "Hello World\n",  # stdout_logs
                    "",  # stderr_logs
                    "sandbox_id",  # sandbox_id
                ),
            },
        )
@@ -282,78 +397,38 @@ class InstantiationBlock(Block):
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        try:
            sandbox_id, response, stdout_logs, stderr_logs = await self.execute_code(
                input_data.setup_code,
                input_data.language,
                input_data.setup_commands,
                input_data.timeout,
                credentials.api_key.get_secret_value(),
                input_data.template_id,
            _, text_output, stdout, stderr, sandbox_id = await self.execute_code(
                api_key=credentials.api_key.get_secret_value(),
                code=input_data.setup_code,
                language=input_data.language,
                template_id=input_data.template_id,
                setup_commands=input_data.setup_commands,
                timeout=input_data.timeout,
            )
            if sandbox_id:
                yield "sandbox_id", sandbox_id
            else:
                yield "error", "Sandbox ID not found"
            if response:
                yield "response", response
            if stdout_logs:
                yield "stdout_logs", stdout_logs
            if stderr_logs:
                yield "stderr_logs", stderr_logs

            if text_output:
                yield "response", text_output
            if stdout:
                yield "stdout_logs", stdout
            if stderr:
                yield "stderr_logs", stderr
        except Exception as e:
            yield "error", str(e)

    async def execute_code(
        self,
        code: str,
        language: ProgrammingLanguage,
        setup_commands: list[str],
        timeout: int,
        api_key: str,
        template_id: str,
    ):
        try:
            sandbox = None
            if template_id:
                sandbox = await AsyncSandbox.create(
                    template=template_id, api_key=api_key, timeout=timeout
                )
            else:
                sandbox = await AsyncSandbox.create(api_key=api_key, timeout=timeout)

            if not sandbox:
                raise Exception("Sandbox not created")

            # Running setup commands
            for cmd in setup_commands:
                await sandbox.commands.run(cmd)

            # Executing the code
            execution = await sandbox.run_code(
                code,
                language=language.value,
                on_error=lambda e: sandbox.kill(),  # Kill the sandbox if there is an error
            )

            if execution.error:
                raise Exception(execution.error)

            response = execution.text
            stdout_logs = "".join(execution.logs.stdout)
            stderr_logs = "".join(execution.logs.stderr)

            return sandbox.sandbox_id, response, stdout_logs, stderr_logs

        except Exception as e:
            raise e


class StepExecutionBlock(Block):
class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[
            Literal[ProviderName.E2B], Literal["api_key"]
        ] = CredentialsField(
            description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
            description=(
                "Enter your API key for the E2B platform. "
                "You can get it in here - https://e2b.dev/docs"
            ),
        )

        sandbox_id: str = SchemaField(
@@ -374,8 +449,22 @@ class StepExecutionBlock(Block):
            advanced=False,
        )

        dispose_sandbox: bool = SchemaField(
            description="Whether to dispose of the sandbox after executing this code.",
            default=False,
        )

    class Output(BlockSchema):
        response: str = SchemaField(description="Response from code execution")
        main_result: MainCodeExecutionResult = SchemaField(
            title="Main Result", description="The main result from the code execution"
        )
        results: list[CodeExecutionResult] = SchemaField(
            description="List of results from the code execution"
        )
        response: str = SchemaField(
            title="Main Text Output",
            description="Text output (if any) of the main execution result",
        )
        stdout_logs: str = SchemaField(
            description="Standard output logs from execution"
        )
@@ -385,10 +474,10 @@ class StepExecutionBlock(Block):
    def __init__(self):
        super().__init__(
            id="82b59b8e-ea10-4d57-9161-8b169b0adba6",
            description="Execute code in a previously instantiated sandbox environment.",
            description="Execute code in a previously instantiated sandbox.",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=StepExecutionBlock.Input,
            output_schema=StepExecutionBlock.Output,
            input_schema=ExecuteCodeStepBlock.Input,
            output_schema=ExecuteCodeStepBlock.Output,
            test_credentials=TEST_CREDENTIALS,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
@@ -397,61 +486,43 @@ class StepExecutionBlock(Block):
                "language": ProgrammingLanguage.PYTHON.value,
            },
            test_output=[
                ("results", []),
                ("response", "Hello World"),
                ("stdout_logs", "Hello World\n"),
            ],
            test_mock={
                "execute_step_code": lambda sandbox_id, step_code, language, api_key: (
                    "Hello World",
                    "Hello World\n",
                    "",
                "execute_code": lambda api_key, code, language, sandbox_id, dispose_sandbox: (  # noqa
                    [],  # results
                    "Hello World",  # text_output
                    "Hello World\n",  # stdout_logs
                    "",  # stderr_logs
                    sandbox_id,  # sandbox_id
                ),
            },
        )

    async def execute_step_code(
        self,
        sandbox_id: str,
        code: str,
        language: ProgrammingLanguage,
        api_key: str,
    ):
        try:
            sandbox = await AsyncSandbox.connect(sandbox_id=sandbox_id, api_key=api_key)
            if not sandbox:
                raise Exception("Sandbox not found")

            # Executing the code
            execution = await sandbox.run_code(code, language=language.value)

            if execution.error:
                raise Exception(execution.error)

            response = execution.text
            stdout_logs = "".join(execution.logs.stdout)
            stderr_logs = "".join(execution.logs.stderr)

            return response, stdout_logs, stderr_logs

        except Exception as e:
            raise e

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        try:
            response, stdout_logs, stderr_logs = await self.execute_step_code(
                input_data.sandbox_id,
                input_data.step_code,
                input_data.language,
                credentials.api_key.get_secret_value(),
            results, text_output, stdout, stderr, _ = await self.execute_code(
                api_key=credentials.api_key.get_secret_value(),
                code=input_data.step_code,
                language=input_data.language,
                sandbox_id=input_data.sandbox_id,
                dispose_sandbox=input_data.dispose_sandbox,
            )

            if response:
                yield "response", response
            if stdout_logs:
                yield "stdout_logs", stdout_logs
            if stderr_logs:
                yield "stderr_logs", stderr_logs
            # Determine result object shape & filter out empty formats
            main_result, results = self.process_execution_results(results)
            if main_result:
                yield "main_result", main_result
            yield "results", results
            if text_output:
                yield "response", text_output
            if stdout:
                yield "stdout_logs", stdout
            if stderr:
                yield "stderr_logs", stderr
        except Exception as e:
            yield "error", str(e)
@@ -90,7 +90,7 @@ class CodeExtractionBlock(Block):
            for aliases in language_aliases.values()
            for alias in aliases
        )
        + r")\s+[\s\S]*?```"
        + r")[ \t]*\n[\s\S]*?```"
    )

    remaining_text = re.sub(pattern, "", input_data.text).strip()
@@ -103,7 +103,9 @@ class CodeExtractionBlock(Block):
        # Escape special regex characters in the language string
        language = re.escape(language)
        # Extract all code blocks enclosed in ```language``` blocks
        pattern = re.compile(rf"```{language}\s+(.*?)```", re.DOTALL | re.IGNORECASE)
        pattern = re.compile(
            rf"```{language}[ \t]*\n(.*?)\n```", re.DOTALL | re.IGNORECASE
        )
        matches = pattern.finditer(text)
        # Combine all code blocks for this language with newlines between them
        code_blocks = [match.group(1).strip() for match in matches]
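The tightened pattern requires a newline right after the language tag, so a real fenced block still matches while a single-line pseudo-fence (or a language tag that merely prefixes other text) no longer does. A quick demonstration of the old vs. new behaviour:

import re

text = "```python\nprint('hi')\n```\n\n```python x = 1```"

old = re.compile(r"```python\s+(.*?)```", re.DOTALL | re.IGNORECASE)
new = re.compile(r"```python[ \t]*\n(.*?)\n```", re.DOTALL | re.IGNORECASE)

print(old.findall(text))  # ["print('hi')\n", 'x = 1'] - the one-liner leaks in
print(new.findall(text))  # ["print('hi')"] - only the real fenced block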
@@ -66,6 +66,7 @@ class AddToDictionaryBlock(Block):
        dictionary: dict[Any, Any] = SchemaField(
            default_factory=dict,
            description="The dictionary to add the entry to. If not provided, a new dictionary will be created.",
            advanced=False,
        )
        key: str = SchemaField(
            default="",

@@ -113,6 +113,7 @@ class DataForSeoClient:
        include_serp_info: bool = False,
        include_clickstream_data: bool = False,
        limit: int = 100,
        depth: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """
        Get related keywords from DataForSEO Labs.
@@ -125,6 +126,7 @@ class DataForSeoClient:
            include_serp_info: Include SERP data
            include_clickstream_data: Include clickstream metrics
            limit: Maximum number of results (up to 3000)
            depth: Keyword search depth (0-4), controls number of returned keywords

        Returns:
            API response with related keywords
@@ -148,6 +150,8 @@ class DataForSeoClient:
            task_data["include_clickstream_data"] = include_clickstream_data
        if limit is not None:
            task_data["limit"] = limit
        if depth is not None:
            task_data["depth"] = depth

        payload = [task_data]

@@ -90,6 +90,7 @@ class DataForSeoKeywordSuggestionsBlock(Block):
        seed_keyword: str = SchemaField(
            description="The seed keyword used for the query"
        )
        error: str = SchemaField(description="Error message if the API call failed")

    def __init__(self):
        super().__init__(
@@ -161,43 +162,52 @@ class DataForSeoKeywordSuggestionsBlock(Block):
        **kwargs,
    ) -> BlockOutput:
        """Execute the keyword suggestions query."""
        client = DataForSeoClient(credentials)
        try:
            client = DataForSeoClient(credentials)

            results = await self._fetch_keyword_suggestions(client, input_data)
        results = await self._fetch_keyword_suggestions(client, input_data)

        # Process and format the results
        suggestions = []
        if results and len(results) > 0:
            # results is a list, get the first element
            first_result = results[0] if isinstance(results, list) else results
            items = (
                first_result.get("items", []) if isinstance(first_result, dict) else []
            )
            for item in items:
                # Create the KeywordSuggestion object
                suggestion = KeywordSuggestion(
                    keyword=item.get("keyword", ""),
                    search_volume=item.get("keyword_info", {}).get("search_volume"),
                    competition=item.get("keyword_info", {}).get("competition"),
                    cpc=item.get("keyword_info", {}).get("cpc"),
                    keyword_difficulty=item.get("keyword_properties", {}).get(
                        "keyword_difficulty"
                    ),
                    serp_info=(
                        item.get("serp_info") if input_data.include_serp_info else None
                    ),
                    clickstream_data=(
                        item.get("clickstream_keyword_info")
                        if input_data.include_clickstream_data
                        else None
                    ),
            # Process and format the results
            suggestions = []
            if results and len(results) > 0:
                # results is a list, get the first element
                first_result = results[0] if isinstance(results, list) else results
                items = (
                    first_result.get("items", [])
                    if isinstance(first_result, dict)
                    else []
                )
                yield "suggestion", suggestion
                suggestions.append(suggestion)
                if items is None:
                    items = []
                for item in items:
                    # Create the KeywordSuggestion object
                    suggestion = KeywordSuggestion(
                        keyword=item.get("keyword", ""),
                        search_volume=item.get("keyword_info", {}).get("search_volume"),
                        competition=item.get("keyword_info", {}).get("competition"),
                        cpc=item.get("keyword_info", {}).get("cpc"),
                        keyword_difficulty=item.get("keyword_properties", {}).get(
                            "keyword_difficulty"
                        ),
                        serp_info=(
                            item.get("serp_info")
                            if input_data.include_serp_info
                            else None
                        ),
                        clickstream_data=(
                            item.get("clickstream_keyword_info")
                            if input_data.include_clickstream_data
                            else None
                        ),
                    )
                    yield "suggestion", suggestion
                    suggestions.append(suggestion)

        yield "suggestions", suggestions
        yield "total_count", len(suggestions)
        yield "seed_keyword", input_data.keyword
            yield "suggestions", suggestions
            yield "total_count", len(suggestions)
            yield "seed_keyword", input_data.keyword
        except Exception as e:
            yield "error", f"Failed to fetch keyword suggestions: {str(e)}"


class KeywordSuggestionExtractorBlock(Block):

@@ -78,6 +78,12 @@ class DataForSeoRelatedKeywordsBlock(Block):
            ge=1,
            le=3000,
        )
        depth: int = SchemaField(
            description="Keyword search depth (0-4). Controls the number of returned keywords: 0=1 keyword, 1=~8 keywords, 2=~72 keywords, 3=~584 keywords, 4=~4680 keywords",
            default=1,
            ge=0,
            le=4,
        )

    class Output(BlockSchema):
        related_keywords: List[RelatedKeyword] = SchemaField(
@@ -92,6 +98,7 @@ class DataForSeoRelatedKeywordsBlock(Block):
        seed_keyword: str = SchemaField(
            description="The seed keyword used for the query"
        )
        error: str = SchemaField(description="Error message if the API call failed")

    def __init__(self):
        super().__init__(
@@ -154,6 +161,7 @@ class DataForSeoRelatedKeywordsBlock(Block):
            include_serp_info=input_data.include_serp_info,
            include_clickstream_data=input_data.include_clickstream_data,
            limit=input_data.limit,
            depth=input_data.depth,
        )

    async def run(
@@ -164,50 +172,60 @@ class DataForSeoRelatedKeywordsBlock(Block):
        **kwargs,
    ) -> BlockOutput:
        """Execute the related keywords query."""
        client = DataForSeoClient(credentials)
        try:
            client = DataForSeoClient(credentials)

            results = await self._fetch_related_keywords(client, input_data)
        results = await self._fetch_related_keywords(client, input_data)

        # Process and format the results
        related_keywords = []
        if results and len(results) > 0:
            # results is a list, get the first element
            first_result = results[0] if isinstance(results, list) else results
            items = (
                first_result.get("items", []) if isinstance(first_result, dict) else []
            )
            for item in items:
                # Extract keyword_data from the item
                keyword_data = item.get("keyword_data", {})

                # Create the RelatedKeyword object
                keyword = RelatedKeyword(
                    keyword=keyword_data.get("keyword", ""),
                    search_volume=keyword_data.get("keyword_info", {}).get(
                        "search_volume"
                    ),
                    competition=keyword_data.get("keyword_info", {}).get("competition"),
                    cpc=keyword_data.get("keyword_info", {}).get("cpc"),
                    keyword_difficulty=keyword_data.get("keyword_properties", {}).get(
                        "keyword_difficulty"
                    ),
                    serp_info=(
                        keyword_data.get("serp_info")
                        if input_data.include_serp_info
                        else None
                    ),
                    clickstream_data=(
                        keyword_data.get("clickstream_keyword_info")
                        if input_data.include_clickstream_data
                        else None
                    ),
            # Process and format the results
            related_keywords = []
            if results and len(results) > 0:
                # results is a list, get the first element
                first_result = results[0] if isinstance(results, list) else results
                items = (
                    first_result.get("items", [])
                    if isinstance(first_result, dict)
                    else []
                )
                yield "related_keyword", keyword
                related_keywords.append(keyword)
                # Ensure items is never None
                if items is None:
                    items = []
                for item in items:
                    # Extract keyword_data from the item
                    keyword_data = item.get("keyword_data", {})

        yield "related_keywords", related_keywords
        yield "total_count", len(related_keywords)
        yield "seed_keyword", input_data.keyword
                    # Create the RelatedKeyword object
                    keyword = RelatedKeyword(
                        keyword=keyword_data.get("keyword", ""),
                        search_volume=keyword_data.get("keyword_info", {}).get(
                            "search_volume"
                        ),
                        competition=keyword_data.get("keyword_info", {}).get(
                            "competition"
                        ),
                        cpc=keyword_data.get("keyword_info", {}).get("cpc"),
                        keyword_difficulty=keyword_data.get(
                            "keyword_properties", {}
                        ).get("keyword_difficulty"),
                        serp_info=(
                            keyword_data.get("serp_info")
                            if input_data.include_serp_info
                            else None
                        ),
                        clickstream_data=(
                            keyword_data.get("clickstream_keyword_info")
                            if input_data.include_clickstream_data
                            else None
                        ),
                    )
                    yield "related_keyword", keyword
                    related_keywords.append(keyword)

            yield "related_keywords", related_keywords
            yield "total_count", len(related_keywords)
            yield "seed_keyword", input_data.keyword
        except Exception as e:
            yield "error", f"Failed to fetch related keywords: {str(e)}"


class RelatedKeywordExtractorBlock(Block):

@@ -4,13 +4,13 @@ import mimetypes
from pathlib import Path
from typing import Any

import aiohttp
import discord
from pydantic import SecretStr

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, SchemaField
from backend.util.file import store_media_file
from backend.util.request import Requests
from backend.util.type import MediaFileType

from ._auth import (
@@ -114,10 +114,9 @@ class ReadDiscordMessagesBlock(Block):
                if message.attachments:
                    attachment = message.attachments[0]  # Process the first attachment
                    if attachment.filename.endswith((".txt", ".py")):
                        async with aiohttp.ClientSession() as session:
                            async with session.get(attachment.url) as response:
                                file_content = response.text()
                                self.output_data += f"\n\nFile from user: {attachment.filename}\nContent: {file_content}"
                        response = await Requests().get(attachment.url)
                        file_content = response.text()
                        self.output_data += f"\n\nFile from user: {attachment.filename}\nContent: {file_content}"

                await client.close()

@@ -171,11 +170,11 @@ class SendDiscordMessageBlock(Block):
            description="The content of the message to send"
        )
        channel_name: str = SchemaField(
            description="The name of the channel the message will be sent to"
            description="Channel ID or channel name to send the message to"
        )
        server_name: str = SchemaField(
            description="The name of the server where the channel is located",
            advanced=True,  # Optional field for server name
            description="Server name (only needed if using channel name)",
            advanced=True,
            default="",
        )

@@ -231,25 +230,49 @@ class SendDiscordMessageBlock(Block):
        @client.event
        async def on_ready():
            print(f"Logged in as {client.user}")
            for guild in client.guilds:
                if server_name and guild.name != server_name:
                    continue
                for channel in guild.text_channels:
                    if channel.name == channel_name:
                        # Split message into chunks if it exceeds 2000 characters
                        chunks = self.chunk_message(message_content)
                        last_message = None
                        for chunk in chunks:
                            last_message = await channel.send(chunk)
                        result["status"] = "Message sent"
                        result["message_id"] = (
                            str(last_message.id) if last_message else ""
                        )
                        result["channel_id"] = str(channel.id)
                        await client.close()
                        return
            channel = None

            result["status"] = "Channel not found"
            # Try to parse as channel ID first
            try:
                channel_id = int(channel_name)
                channel = client.get_channel(channel_id)
            except ValueError:
                # Not a valid ID, will try name lookup
                pass

            # If not found by ID (or not an ID), try name lookup
            if not channel:
                for guild in client.guilds:
                    if server_name and guild.name != server_name:
                        continue
                    for ch in guild.text_channels:
                        if ch.name == channel_name:
                            channel = ch
                            break
                    if channel:
                        break

            if not channel:
                result["status"] = f"Channel not found: {channel_name}"
                await client.close()
                return

            # Type check - ensure it's a text channel that can send messages
            if not hasattr(channel, "send"):
                result["status"] = (
                    f"Channel {channel_name} cannot receive messages (not a text channel)"
                )
                await client.close()
                return

            # Split message into chunks if it exceeds 2000 characters
            chunks = self.chunk_message(message_content)
            last_message = None
            for chunk in chunks:
                last_message = await channel.send(chunk)  # type: ignore
            result["status"] = "Message sent"
            result["message_id"] = str(last_message.id) if last_message else ""
            result["channel_id"] = str(channel.id)
            await client.close()

        await client.start(token)
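The new lookup treats the `channel_name` input as a snowflake ID first and only falls back to a name scan. A condensed sketch of that resolution order, assuming a connected `discord.Client` (this helper and its name are illustrative, not part of the patch):

import discord


def resolve_channel(client: discord.Client, channel_ref: str, server_name: str = ""):
    """Resolve a channel by numeric ID first, then by name (illustrative)."""
    # Snowflake IDs are plain integers, so int() cleanly separates the cases.
    try:
        channel = client.get_channel(int(channel_ref))
        if channel is not None:
            return channel
    except ValueError:
        pass  # not an ID; fall through to the name scan

    for guild in client.guilds:
        if server_name and guild.name != server_name:
            continue
        for ch in guild.text_channels:
            if ch.name == channel_ref:
                return ch
    return None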
@@ -675,16 +698,15 @@ class SendDiscordFileBlock(Block):

            elif file.startswith(("http://", "https://")):
                # URL - download the file
                async with aiohttp.ClientSession() as session:
                    async with session.get(file) as response:
                        file_bytes = await response.read()
                response = await Requests().get(file)
                file_bytes = response.content

                        # Try to get filename from URL if not provided
                        if not filename:
                            from urllib.parse import urlparse
                # Try to get filename from URL if not provided
                if not filename:
                    from urllib.parse import urlparse

                            path = urlparse(file).path
                            detected_filename = Path(path).name or "download"
                    path = urlparse(file).path
                    detected_filename = Path(path).name or "download"
            else:
                # Local file path - read from stored media file
                # This would be a path from a previous block's output

@@ -0,0 +1,12 @@
from enum import Enum


class ScrapeFormat(Enum):
    MARKDOWN = "markdown"
    HTML = "html"
    RAW_HTML = "rawHtml"
    LINKS = "links"
    SCREENSHOT = "screenshot"
    SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
    JSON = "json"
    CHANGE_TRACKING = "changeTracking"

@@ -0,0 +1,28 @@
"""Utility functions for converting between our ScrapeFormat enum and firecrawl FormatOption types."""

from typing import List

from firecrawl.v2.types import FormatOption, ScreenshotFormat

from backend.blocks.firecrawl._api import ScrapeFormat


def convert_to_format_options(
    formats: List[ScrapeFormat],
) -> List[FormatOption]:
    """Convert our ScrapeFormat enum values to firecrawl FormatOption types.

    Handles special cases like screenshot@fullPage which needs to be converted
    to a ScreenshotFormat object.
    """
    result: List[FormatOption] = []

    for format_enum in formats:
        if format_enum.value == "screenshot@fullPage":
            # Special case: convert to ScreenshotFormat with full_page=True
            result.append(ScreenshotFormat(type="screenshot", full_page=True))
        else:
            # Regular string literals
            result.append(format_enum.value)

    return result
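As a usage sketch: plain string formats pass straight through, and only the full-page screenshot becomes an object. This assumes the new module lands next to `_api.py` as `_format_utils.py`, matching the relative import used by the blocks below:

from backend.blocks.firecrawl._api import ScrapeFormat
from backend.blocks.firecrawl._format_utils import convert_to_format_options

options = convert_to_format_options(
    [ScrapeFormat.MARKDOWN, ScrapeFormat.SCREENSHOT_FULL_PAGE]
)
# -> ["markdown", ScreenshotFormat(type="screenshot", full_page=True)]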
@@ -1,8 +1,9 @@
from enum import Enum
from typing import Any

from firecrawl import FirecrawlApp, ScrapeOptions
from firecrawl import FirecrawlApp
from firecrawl.v2.types import ScrapeOptions

from backend.blocks.firecrawl._api import ScrapeFormat
from backend.sdk import (
    APIKeyCredentials,
    Block,
@@ -14,21 +15,10 @@ from backend.sdk import (
)

from ._config import firecrawl


class ScrapeFormat(Enum):
    MARKDOWN = "markdown"
    HTML = "html"
    RAW_HTML = "rawHtml"
    LINKS = "links"
    SCREENSHOT = "screenshot"
    SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
    JSON = "json"
    CHANGE_TRACKING = "changeTracking"
from ._format_utils import convert_to_format_options


class FirecrawlCrawlBlock(Block):

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = firecrawl.credentials_field()
        url: str = SchemaField(description="The URL to crawl")
@@ -78,18 +68,17 @@ class FirecrawlCrawlBlock(Block):
    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:

        app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())

        # Sync call
        crawl_result = app.crawl_url(
        crawl_result = app.crawl(
            input_data.url,
            limit=input_data.limit,
            scrape_options=ScrapeOptions(
                formats=[format.value for format in input_data.formats],
                onlyMainContent=input_data.only_main_content,
                maxAge=input_data.max_age,
                waitFor=input_data.wait_for,
                formats=convert_to_format_options(input_data.formats),
                only_main_content=input_data.only_main_content,
                max_age=input_data.max_age,
                wait_for=input_data.wait_for,
            ),
        )
        yield "data", crawl_result.data
@@ -101,7 +90,7 @@ class FirecrawlCrawlBlock(Block):
            elif f == ScrapeFormat.HTML:
                yield "html", data.html
            elif f == ScrapeFormat.RAW_HTML:
                yield "raw_html", data.rawHtml
                yield "raw_html", data.raw_html
            elif f == ScrapeFormat.LINKS:
                yield "links", data.links
            elif f == ScrapeFormat.SCREENSHOT:
@@ -109,6 +98,6 @@ class FirecrawlCrawlBlock(Block):
            elif f == ScrapeFormat.SCREENSHOT_FULL_PAGE:
                yield "screenshot_full_page", data.screenshot
            elif f == ScrapeFormat.CHANGE_TRACKING:
                yield "change_tracking", data.changeTracking
                yield "change_tracking", data.change_tracking
            elif f == ScrapeFormat.JSON:
                yield "json", data.json

@@ -20,7 +20,6 @@ from ._config import firecrawl

@cost(BlockCost(2, BlockCostType.RUN))
class FirecrawlExtractBlock(Block):

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = firecrawl.credentials_field()
        urls: list[str] = SchemaField(
@@ -53,7 +52,6 @@ class FirecrawlExtractBlock(Block):
    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:

        app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())

        extract_result = app.extract(

@@ -1,3 +1,5 @@
from typing import Any

from firecrawl import FirecrawlApp

from backend.sdk import (
@@ -14,14 +16,16 @@ from ._config import firecrawl


class FirecrawlMapWebsiteBlock(Block):

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = firecrawl.credentials_field()

        url: str = SchemaField(description="The website url to map")

    class Output(BlockSchema):
        links: list[str] = SchemaField(description="The links of the website")
        links: list[str] = SchemaField(description="List of URLs found on the website")
        results: list[dict[str, Any]] = SchemaField(
            description="List of search results with url, title, and description"
        )

    def __init__(self):
        super().__init__(
@@ -35,12 +39,22 @@ class FirecrawlMapWebsiteBlock(Block):
    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:

        app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())

        # Sync call
        map_result = app.map_url(
        map_result = app.map(
            url=input_data.url,
        )

        yield "links", map_result.links
        # Convert SearchResult objects to dicts
        results_data = [
            {
                "url": link.url,
                "title": link.title,
                "description": link.description,
            }
            for link in map_result.links
        ]

        yield "links", [link.url for link in map_result.links]
        yield "results", results_data

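Note the output shape change: in the v2 SDK, `map()` returns link objects rather than bare strings, which is why the block now yields plain URLs on `links` and dicts on `results`. A small sketch with stand-in data (`SimpleNamespace` stands in for the SDK's result type, which is an assumption here):

from types import SimpleNamespace

links = [
    SimpleNamespace(url="https://example.com/a", title="Page A", description="First page")
]

urls = [link.url for link in links]  # -> ["https://example.com/a"]
results = [
    {"url": link.url, "title": link.title, "description": link.description}
    for link in links
]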
@@ -1,8 +1,8 @@
from enum import Enum
from typing import Any

from firecrawl import FirecrawlApp

from backend.blocks.firecrawl._api import ScrapeFormat
from backend.sdk import (
    APIKeyCredentials,
    Block,
@@ -14,21 +14,10 @@ from backend.sdk import (
)

from ._config import firecrawl


class ScrapeFormat(Enum):
    MARKDOWN = "markdown"
    HTML = "html"
    RAW_HTML = "rawHtml"
    LINKS = "links"
    SCREENSHOT = "screenshot"
    SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
    JSON = "json"
    CHANGE_TRACKING = "changeTracking"
from ._format_utils import convert_to_format_options


class FirecrawlScrapeBlock(Block):

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = firecrawl.credentials_field()
        url: str = SchemaField(description="The URL to crawl")
@@ -78,12 +67,11 @@ class FirecrawlScrapeBlock(Block):
    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:

        app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())

        scrape_result = app.scrape_url(
        scrape_result = app.scrape(
            input_data.url,
            formats=[format.value for format in input_data.formats],
            formats=convert_to_format_options(input_data.formats),
            only_main_content=input_data.only_main_content,
            max_age=input_data.max_age,
            wait_for=input_data.wait_for,
@@ -96,7 +84,7 @@ class FirecrawlScrapeBlock(Block):
            elif f == ScrapeFormat.HTML:
                yield "html", scrape_result.html
            elif f == ScrapeFormat.RAW_HTML:
                yield "raw_html", scrape_result.rawHtml
                yield "raw_html", scrape_result.raw_html
            elif f == ScrapeFormat.LINKS:
                yield "links", scrape_result.links
            elif f == ScrapeFormat.SCREENSHOT:
@@ -104,6 +92,6 @@ class FirecrawlScrapeBlock(Block):
            elif f == ScrapeFormat.SCREENSHOT_FULL_PAGE:
                yield "screenshot_full_page", scrape_result.screenshot
            elif f == ScrapeFormat.CHANGE_TRACKING:
                yield "change_tracking", scrape_result.changeTracking
                yield "change_tracking", scrape_result.change_tracking
            elif f == ScrapeFormat.JSON:
                yield "json", scrape_result.json

@@ -1,8 +1,9 @@
from enum import Enum
from typing import Any

from firecrawl import FirecrawlApp, ScrapeOptions
from firecrawl import FirecrawlApp
from firecrawl.v2.types import ScrapeOptions

from backend.blocks.firecrawl._api import ScrapeFormat
from backend.sdk import (
    APIKeyCredentials,
    Block,
@@ -14,21 +15,10 @@ from backend.sdk import (
)

from ._config import firecrawl


class ScrapeFormat(Enum):
    MARKDOWN = "markdown"
    HTML = "html"
    RAW_HTML = "rawHtml"
    LINKS = "links"
    SCREENSHOT = "screenshot"
    SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
    JSON = "json"
    CHANGE_TRACKING = "changeTracking"
from ._format_utils import convert_to_format_options


class FirecrawlSearchBlock(Block):

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = firecrawl.credentials_field()
        query: str = SchemaField(description="The query to search for")
@@ -61,7 +51,6 @@ class FirecrawlSearchBlock(Block):
    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:

        app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())

        # Sync call
@@ -69,11 +58,12 @@ class FirecrawlSearchBlock(Block):
            input_data.query,
            limit=input_data.limit,
            scrape_options=ScrapeOptions(
                formats=[format.value for format in input_data.formats],
                maxAge=input_data.max_age,
                waitFor=input_data.wait_for,
                formats=convert_to_format_options(input_data.formats) or None,
                max_age=input_data.max_age,
                wait_for=input_data.wait_for,
            ),
        )
        yield "data", scrape_result
        for site in scrape_result.data:
            yield "site", site
        if hasattr(scrape_result, "web") and scrape_result.web:
            for site in scrape_result.web:
                yield "site", site

@@ -10,7 +10,6 @@ from backend.util.settings import Config
from backend.util.text import TextFormatter
from backend.util.type import LongTextType, MediaFileType, ShortTextType

formatter = TextFormatter()
config = Config()


@@ -132,6 +131,11 @@ class AgentOutputBlock(Block):
            default="",
            advanced=True,
        )
        escape_html: bool = SchemaField(
            default=False,
            advanced=True,
            description="Whether to escape special characters in the inserted values to be HTML-safe. Enable for HTML output, disable for plain text.",
        )
        advanced: bool = SchemaField(
            description="Whether to treat the output as advanced.",
            default=False,
@@ -193,6 +197,7 @@ class AgentOutputBlock(Block):
        """
        if input_data.format:
            try:
                formatter = TextFormatter(autoescape=input_data.escape_html)
                yield "output", formatter.format_string(
                    input_data.format, {input_data.name: input_data.value}
                )
@@ -549,6 +554,89 @@ class AgentToggleInputBlock(AgentInputBlock):
    )


class AgentTableInputBlock(AgentInputBlock):
    """
    This block allows users to input data in a table format.

    Configure the table columns at build time, then users can input
    rows of data at runtime. Each row is output as a dictionary
    with column names as keys.
    """

    class Input(AgentInputBlock.Input):
        value: Optional[list[dict[str, Any]]] = SchemaField(
            description="The table data as a list of dictionaries.",
            default=None,
            advanced=False,
            title="Default Value",
        )
        column_headers: list[str] = SchemaField(
            description="Column headers for the table.",
            default_factory=lambda: ["Column 1", "Column 2", "Column 3"],
            advanced=False,
            title="Column Headers",
        )

        def generate_schema(self):
            """Generate schema for the value field with table format."""
            schema = super().generate_schema()
            schema["type"] = "array"
            schema["format"] = "table"
            schema["items"] = {
                "type": "object",
                "properties": {
                    header: {"type": "string"}
                    for header in (
                        self.column_headers or ["Column 1", "Column 2", "Column 3"]
                    )
                },
            }
            if self.value is not None:
                schema["default"] = self.value
            return schema

    class Output(AgentInputBlock.Output):
        result: list[dict[str, Any]] = SchemaField(
            description="The table data as a list of dictionaries with headers as keys."
        )

    def __init__(self):
        super().__init__(
            id="5603b273-f41e-4020-af7d-fbc9c6a8d928",
            description="Block for table data input with customizable headers.",
            disabled=not config.enable_agent_input_subtype_blocks,
            input_schema=AgentTableInputBlock.Input,
            output_schema=AgentTableInputBlock.Output,
            test_input=[
                {
                    "name": "test_table",
                    "column_headers": ["Name", "Age", "City"],
                    "value": [
                        {"Name": "John", "Age": "30", "City": "New York"},
                        {"Name": "Jane", "Age": "25", "City": "London"},
                    ],
                    "description": "Example table input",
                }
            ],
            test_output=[
                (
                    "result",
                    [
                        {"Name": "John", "Age": "30", "City": "New York"},
                        {"Name": "Jane", "Age": "25", "City": "London"},
                    ],
                )
            ],
        )

    async def run(self, input_data: Input, *args, **kwargs) -> BlockOutput:
        """
        Yields the table data as a list of dictionaries.
        """
        # Pass through the value, defaulting to empty list if None
        yield "result", input_data.value if input_data.value is not None else []


IO_BLOCK_IDs = [
    AgentInputBlock().id,
    AgentOutputBlock().id,
@@ -560,4 +648,5 @@ IO_BLOCK_IDs = [
    AgentFileInputBlock().id,
    AgentDropdownInputBlock().id,
    AgentToggleInputBlock().id,
    AgentTableInputBlock().id,
]

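For reference, roughly what `generate_schema()` would produce for the test headers above; the fields set by `super().generate_schema()` are elided, so this is a partial sketch rather than the exact payload:

# Approximate schema for column_headers == ["Name", "Age", "City"]:
{
    "type": "array",
    "format": "table",
    "items": {
        "type": "object",
        "properties": {
            "Name": {"type": "string"},
            "Age": {"type": "string"},
            "City": {"type": "string"},
        },
    },
}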
@@ -2,7 +2,7 @@ from typing import Any

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.json import json
from backend.util.json import loads


class StepThroughItemsBlock(Block):
@@ -54,20 +54,43 @@ class StepThroughItemsBlock(Block):
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # Security fix: Add limits to prevent DoS from large iterations
        MAX_ITEMS = 10000  # Maximum items to iterate
        MAX_ITEM_SIZE = 1024 * 1024  # 1MB per item

        for data in [input_data.items, input_data.items_object, input_data.items_str]:
            if not data:
                continue

            # Limit string size before parsing
            if isinstance(data, str):
                items = json.loads(data)
                if len(data) > MAX_ITEM_SIZE:
                    raise ValueError(
                        f"Input too large: {len(data)} bytes > {MAX_ITEM_SIZE} bytes"
                    )
                items = loads(data)
            else:
                items = data

            # Check total item count
            if isinstance(items, (list, dict)):
                if len(items) > MAX_ITEMS:
                    raise ValueError(f"Too many items: {len(items)} > {MAX_ITEMS}")

            iteration_count = 0
            if isinstance(items, dict):
                # If items is a dictionary, iterate over its values
                for item in items.values():
                    yield "item", item
                    yield "key", item
                for key, value in items.items():
                    if iteration_count >= MAX_ITEMS:
                        break
                    yield "item", value
                    yield "key", key  # Fixed: should yield key, not item
                    iteration_count += 1
            else:
                # If items is a list, iterate over the list
                for index, item in enumerate(items):
                    if iteration_count >= MAX_ITEMS:
                        break
                    yield "item", item
                    yield "key", index
                    iteration_count += 1

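A self-contained sketch of the same guard pattern, to make the order of checks explicit (size before parse, count before iterate); the constants mirror the ones above, the helper itself is illustrative:

import json

MAX_ITEMS = 10_000
MAX_ITEM_SIZE = 1024 * 1024  # 1 MB


def iter_items(data):
    """Yield (key, item) pairs with the same DoS guards as the block above."""
    if isinstance(data, str):
        if len(data) > MAX_ITEM_SIZE:
            raise ValueError(f"Input too large: {len(data)} bytes > {MAX_ITEM_SIZE} bytes")
        data = json.loads(data)  # parse only after the size check

    if isinstance(data, (list, dict)) and len(data) > MAX_ITEMS:
        raise ValueError(f"Too many items: {len(data)} > {MAX_ITEMS}")

    if isinstance(data, dict):
        yield from data.items()
    else:
        yield from enumerate(data)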
@@ -1,5 +1,8 @@
from typing import List
from urllib.parse import quote

from typing_extensions import TypedDict

from backend.blocks.jina._auth import (
    JinaCredentials,
    JinaCredentialsField,
@@ -10,6 +13,12 @@ from backend.data.model import SchemaField
from backend.util.request import Requests


class Reference(TypedDict):
    url: str
    keyQuote: str
    isSupportive: bool


class FactCheckerBlock(Block):
    class Input(BlockSchema):
        statement: str = SchemaField(
@@ -23,6 +32,10 @@ class FactCheckerBlock(Block):
        )
        result: bool = SchemaField(description="The result of the factuality check")
        reason: str = SchemaField(description="The reason for the factuality result")
        references: List[Reference] = SchemaField(
            description="List of references supporting or contradicting the statement",
            default=[],
        )
        error: str = SchemaField(description="Error message if the check fails")

    def __init__(self):
@@ -53,5 +66,11 @@ class FactCheckerBlock(Block):
            yield "factuality", data["factuality"]
            yield "result", data["result"]
            yield "reason", data["reason"]

            # Yield references if present in the response
            if "references" in data:
                yield "references", data["references"]
            else:
                yield "references", []
        else:
            raise RuntimeError(f"Expected 'data' key not found in response: {data}")

@@ -62,10 +62,10 @@ TEST_CREDENTIALS_OAUTH = OAuth2Credentials(
    title="Mock Linear API key",
    username="mock-linear-username",
    access_token=SecretStr("mock-linear-access-token"),
    access_token_expires_at=None,
    access_token_expires_at=1672531200,  # Mock expiration time for short-lived token
    refresh_token=SecretStr("mock-linear-refresh-token"),
    refresh_token_expires_at=None,
    scopes=["mock-linear-scopes"],
    scopes=["read", "write"],
)

TEST_CREDENTIALS_API_KEY = APIKeyCredentials(

@@ -2,7 +2,9 @@
Linear OAuth handler implementation.
"""

import base64
import json
import time
from typing import Optional
from urllib.parse import urlencode

@@ -38,8 +40,9 @@ class LinearOAuthHandler(BaseOAuthHandler):
        self.client_secret = client_secret
        self.redirect_uri = redirect_uri
        self.auth_base_url = "https://linear.app/oauth/authorize"
        self.token_url = "https://api.linear.app/oauth/token"  # Correct token URL
        self.token_url = "https://api.linear.app/oauth/token"
        self.revoke_url = "https://api.linear.app/oauth/revoke"
        self.migrate_url = "https://api.linear.app/oauth/migrate_old_token"

    def get_login_url(
        self, scopes: list[str], state: str, code_challenge: Optional[str]
@@ -82,19 +85,84 @@ class LinearOAuthHandler(BaseOAuthHandler):

        return True  # Linear doesn't return JSON on successful revoke

    async def migrate_old_token(
        self, credentials: OAuth2Credentials
    ) -> OAuth2Credentials:
        """
        Migrate an old long-lived token to a new short-lived token with refresh token.

        This uses Linear's /oauth/migrate_old_token endpoint to exchange current
        long-lived tokens for short-lived tokens with refresh tokens without
        requiring users to re-authorize.
        """
        if not credentials.access_token:
            raise ValueError("No access token to migrate")

        request_body = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
        }

        headers = {
            "Authorization": f"Bearer {credentials.access_token.get_secret_value()}",
            "Content-Type": "application/x-www-form-urlencoded",
        }

        response = await Requests().post(
            self.migrate_url, data=request_body, headers=headers
        )

        if not response.ok:
            try:
                error_data = response.json()
                error_message = error_data.get("error", "Unknown error")
                error_description = error_data.get("error_description", "")
                if error_description:
                    error_message = f"{error_message}: {error_description}"
            except json.JSONDecodeError:
                error_message = response.text
            raise LinearAPIException(
                f"Failed to migrate Linear token ({response.status}): {error_message}",
                response.status,
            )

        token_data = response.json()

        # Extract token expiration
        now = int(time.time())
        expires_in = token_data.get("expires_in")
        access_token_expires_at = None
        if expires_in:
            access_token_expires_at = now + expires_in

        new_credentials = OAuth2Credentials(
            provider=self.PROVIDER_NAME,
            title=credentials.title,
            username=credentials.username,
            access_token=token_data["access_token"],
            scopes=credentials.scopes,  # Preserve original scopes
            refresh_token=token_data.get("refresh_token"),
            access_token_expires_at=access_token_expires_at,
            refresh_token_expires_at=None,
        )

        new_credentials.id = credentials.id
        return new_credentials

    async def _refresh_tokens(
        self, credentials: OAuth2Credentials
    ) -> OAuth2Credentials:
        if not credentials.refresh_token:
            raise ValueError(
                "No refresh token available."
            )  # Linear uses non-expiring tokens
                "No refresh token available. Token may need to be migrated to the new refresh token system."
            )

        return await self._request_tokens(
            {
                "refresh_token": credentials.refresh_token.get_secret_value(),
                "grant_type": "refresh_token",
            }
            },
            current_credentials=credentials,
        )

    async def _request_tokens(
@@ -102,16 +170,33 @@ class LinearOAuthHandler(BaseOAuthHandler):
        params: dict[str, str],
        current_credentials: Optional[OAuth2Credentials] = None,
    ) -> OAuth2Credentials:
        # Determine if this is a refresh token request
        is_refresh = params.get("grant_type") == "refresh_token"

        # Build request body with appropriate grant_type
        request_body = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "grant_type": "authorization_code",  # Ensure grant_type is correct
            **params,
        }

        headers = {
            "Content-Type": "application/x-www-form-urlencoded"
        }  # Correct header for token request
        # Set default grant_type if not provided
        if "grant_type" not in request_body:
            request_body["grant_type"] = "authorization_code"

        headers = {"Content-Type": "application/x-www-form-urlencoded"}

        # For refresh token requests, support HTTP Basic Authentication as recommended
        if is_refresh:
            # Option 1: Use HTTP Basic Auth (preferred by Linear)
            client_credentials = f"{self.client_id}:{self.client_secret}"
            encoded_credentials = base64.b64encode(client_credentials.encode()).decode()
            headers["Authorization"] = f"Basic {encoded_credentials}"

            # Remove client credentials from body when using Basic Auth
            request_body.pop("client_id", None)
            request_body.pop("client_secret", None)

        response = await Requests().post(
            self.token_url, data=request_body, headers=headers
        )
@@ -120,6 +205,9 @@ class LinearOAuthHandler(BaseOAuthHandler):
            try:
                error_data = response.json()
                error_message = error_data.get("error", "Unknown error")
                error_description = error_data.get("error_description", "")
                if error_description:
                    error_message = f"{error_message}: {error_description}"
            except json.JSONDecodeError:
                error_message = response.text
            raise LinearAPIException(
@@ -129,27 +217,84 @@ class LinearOAuthHandler(BaseOAuthHandler):

        token_data = response.json()

        # Note: Linear access tokens do not expire, so we set expires_at to None
        # Extract token expiration if provided (for new refresh token implementation)
        now = int(time.time())
        expires_in = token_data.get("expires_in")
        access_token_expires_at = None
        if expires_in:
            access_token_expires_at = now + expires_in

        # Get username - preserve from current credentials if refreshing
        username = None
        if current_credentials and is_refresh:
            username = current_credentials.username
        elif "user" in token_data:
            username = token_data["user"].get("name", "Unknown User")
        else:
            # Fetch username using the access token
            username = await self._request_username(token_data["access_token"])

        new_credentials = OAuth2Credentials(
            provider=self.PROVIDER_NAME,
            title=current_credentials.title if current_credentials else None,
            username=token_data.get("user", {}).get(
                "name", "Unknown User"
            ),  # extract name or set appropriate
            username=username or "Unknown User",
            access_token=token_data["access_token"],
            scopes=token_data["scope"].split(
                ","
            ),  # Linear returns comma-separated scopes
            refresh_token=token_data.get(
                "refresh_token"
            ),  # Linear uses non-expiring tokens so this might be null
            access_token_expires_at=None,
            refresh_token_expires_at=None,
            scopes=(
                token_data["scope"].split(",")
                if "scope" in token_data
                else (current_credentials.scopes if current_credentials else [])
            ),
            refresh_token=token_data.get("refresh_token"),
            access_token_expires_at=access_token_expires_at,
            refresh_token_expires_at=None,  # Linear doesn't provide refresh token expiration
        )

        if current_credentials:
            new_credentials.id = current_credentials.id

        return new_credentials

    async def get_access_token(self, credentials: OAuth2Credentials) -> str:
        """
        Returns a valid access token, handling migration and refresh as needed.

        This overrides the base implementation to handle Linear's token migration
        from old long-lived tokens to new short-lived tokens with refresh tokens.
        """
        # If token has no expiration and no refresh token, it might be an old token
        # that needs migration
        if (
            credentials.access_token_expires_at is None
            and credentials.refresh_token is None
        ):
            try:
                # Attempt to migrate the old token
                migrated_credentials = await self.migrate_old_token(credentials)
                # Updating the credentials store needs to be handled by the caller;
                # for now, use the migrated credentials for this request
                credentials = migrated_credentials
            except LinearAPIException:
                # Migration failed, try to use the old token as-is
                # This maintains backward compatibility
                pass

        # Use the standard refresh logic from the base class
        if self.needs_refresh(credentials):
            credentials = await self.refresh_tokens(credentials)

        return credentials.access_token.get_secret_value()

    def needs_migration(self, credentials: OAuth2Credentials) -> bool:
        """
        Check if credentials represent an old long-lived token that needs migration.

        Old tokens have no expiration time and no refresh token.
        """
        return (
            credentials.access_token_expires_at is None
            and credentials.refresh_token is None
        )

    async def _request_username(self, access_token: str) -> Optional[str]:
        # Use the LinearClient to fetch user details using GraphQL
        from ._api import LinearClient

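The migration decision reduces to two fields on the stored credentials, and the new expiry is derived from Linear's relative `expires_in`. A compact sketch of the logic `get_access_token` relies on (types simplified, functions illustrative):

import time


def needs_migration(expires_at, refresh_token) -> bool:
    # Old Linear tokens were long-lived: no expiry and no refresh token.
    return expires_at is None and refresh_token is None


def expiry_from_response(token_data: dict):
    # New tokens report a relative expires_in (seconds); store it as an
    # absolute unix timestamp, or None when the field is absent.
    expires_in = token_data.get("expires_in")
    return int(time.time()) + expires_in if expires_in else None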
@@ -37,5 +37,5 @@ class Project(BaseModel):
    name: str
    description: str
    priority: int
    progress: int
    content: str
    progress: float
    content: str | None

@@ -1,5 +1,9 @@
# This file contains a lot of prompt block strings that would trigger "line too long"
# flake8: noqa: E501
import ast
import logging
import re
import secrets
from abc import ABC
from enum import Enum, EnumMeta
from json import JSONDecodeError
@@ -27,7 +31,7 @@ from backend.util.prompt import compress_prompt, estimate_token_count
from backend.util.text import TextFormatter

logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
fmt = TextFormatter()
fmt = TextFormatter(autoescape=False)

LLMProviderName = Literal[
    ProviderName.AIML_API,
@@ -97,9 +101,9 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
    CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
    CLAUDE_4_OPUS = "claude-opus-4-20250514"
    CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
    CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
    CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
    CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
    CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest"
    CLAUDE_3_5_HAIKU = "claude-3-5-haiku-latest"
    CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
    # AI/ML API models
    AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
@@ -204,20 +208,20 @@ MODEL_METADATA = {
        "anthropic", 200000, 32000
    ),  # claude-opus-4-1-20250805
    LlmModel.CLAUDE_4_OPUS: ModelMetadata(
        "anthropic", 200000, 8192
        "anthropic", 200000, 32000
    ),  # claude-4-opus-20250514
    LlmModel.CLAUDE_4_SONNET: ModelMetadata(
        "anthropic", 200000, 8192
        "anthropic", 200000, 64000
    ),  # claude-4-sonnet-20250514
    LlmModel.CLAUDE_4_5_SONNET: ModelMetadata(
        "anthropic", 200000, 64000
    ),  # claude-sonnet-4-5-20250929
    LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
        "anthropic", 200000, 64000
    ),  # claude-haiku-4-5-20251001
    LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
        "anthropic", 200000, 8192
        "anthropic", 200000, 64000
    ),  # claude-3-7-sonnet-20250219
    LlmModel.CLAUDE_3_5_SONNET: ModelMetadata(
        "anthropic", 200000, 8192
    ),  # claude-3-5-sonnet-20241022
    LlmModel.CLAUDE_3_5_HAIKU: ModelMetadata(
        "anthropic", 200000, 8192
    ),  # claude-3-5-haiku-20241022
    LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
        "anthropic", 200000, 4096
    ),  # claude-3-haiku-20240307
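The metadata rows read as (provider, context window, max output tokens), so the hunk above raises the output ceilings for the Claude 4 family. A hedged sketch of how a caller might clamp a request against those limits; the `NamedTuple` here is a stand-in that assumes `ModelMetadata` takes those three positional values:

from typing import NamedTuple


class ModelMetadata(NamedTuple):
    provider: str
    context_window: int
    max_output_tokens: int


CLAUDE_4_SONNET = ModelMetadata("anthropic", 200000, 64000)


def clamp_max_tokens(requested: int, meta: ModelMetadata) -> int:
    # Never ask the API for more completion tokens than the model allows.
    return min(requested, meta.max_output_tokens)


assert clamp_max_tokens(100_000, CLAUDE_4_SONNET) == 64000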
@@ -382,7 +386,9 @@ def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
    return None


def get_parallel_tool_calls_param(llm_model: LlmModel, parallel_tool_calls):
def get_parallel_tool_calls_param(
    llm_model: LlmModel, parallel_tool_calls: bool | None
):
    """Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
    if llm_model.startswith("o") or parallel_tool_calls is None:
        return openai.NOT_GIVEN
@@ -393,8 +399,8 @@ async def llm_call(
    credentials: APIKeyCredentials,
    llm_model: LlmModel,
    prompt: list[dict],
    json_format: bool,
    max_tokens: int | None,
    force_json_output: bool = False,
    tools: list[dict] | None = None,
    ollama_host: str = "localhost:11434",
    parallel_tool_calls=None,
@@ -407,7 +413,7 @@ async def llm_call(
        credentials: The API key credentials to use.
        llm_model: The LLM model to use.
        prompt: The prompt to send to the LLM.
        json_format: Whether the response should be in JSON format.
        force_json_output: Whether the response should be in JSON format.
        max_tokens: The maximum number of tokens to generate in the chat completion.
        tools: The tools to use in the chat completion.
        ollama_host: The host for ollama to use.
@@ -446,7 +452,7 @@ async def llm_call(
            llm_model, parallel_tool_calls
        )

        if json_format:
        if force_json_output:
            response_format = {"type": "json_object"}

        response = await oai_client.chat.completions.create(
@@ -559,7 +565,7 @@ async def llm_call(
            raise ValueError("Groq does not support tools.")

        client = AsyncGroq(api_key=credentials.api_key.get_secret_value())
        response_format = {"type": "json_object"} if json_format else None
        response_format = {"type": "json_object"} if force_json_output else None
        response = await client.chat.completions.create(
            model=llm_model.value,
            messages=prompt,  # type: ignore
@@ -717,7 +723,7 @@ async def llm_call(
        )

        response_format = None
        if json_format:
        if force_json_output:
            response_format = {"type": "json_object"}

        parallel_tool_calls_param = get_parallel_tool_calls_param(
@@ -780,6 +786,17 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
            description="The language model to use for answering the prompt.",
            advanced=False,
        )
        force_json_output: bool = SchemaField(
            title="Restrict LLM to pure JSON output",
            default=False,
            description=(
                "Whether to force the LLM to produce a JSON-only response. "
                "This can increase the block's reliability, "
                "but may also reduce the quality of the response "
                "because it prohibits the LLM from reasoning "
                "before providing its JSON response."
            ),
        )
        credentials: AICredentials = AICredentialsField()
        sys_prompt: str = SchemaField(
            title="System Prompt",
@@ -848,17 +865,18 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
                "llm_call": lambda *args, **kwargs: LLMResponse(
                    raw_response="",
                    prompt=[""],
                    response=json.dumps(
                        {
                            "key1": "key1Value",
                            "key2": "key2Value",
                        }
                    response=(
                        '<json_output id="test123456">{\n'
                        '  "key1": "key1Value",\n'
                        '  "key2": "key2Value"\n'
                        "}</json_output>"
                    ),
                    tool_calls=None,
                    prompt_tokens=0,
                    completion_tokens=0,
                    reasoning=None,
                )
                ),
                "get_collision_proof_output_tag_id": lambda *args: "test123456",
            },
        )

@@ -867,9 +885,9 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
        credentials: APIKeyCredentials,
        llm_model: LlmModel,
        prompt: list[dict],
        json_format: bool,
        compress_prompt_to_fit: bool,
        max_tokens: int | None,
        force_json_output: bool = False,
        compress_prompt_to_fit: bool = True,
        tools: list[dict] | None = None,
        ollama_host: str = "localhost:11434",
    ) -> LLMResponse:
@@ -882,8 +900,8 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
            credentials=credentials,
            llm_model=llm_model,
            prompt=prompt,
            json_format=json_format,
            max_tokens=max_tokens,
            force_json_output=force_json_output,
            tools=tools,
            ollama_host=ollama_host,
            compress_prompt_to_fit=compress_prompt_to_fit,
@@ -895,10 +913,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
        logger.debug(f"Calling LLM with input data: {input_data}")
        prompt = [json.to_dict(p) for p in input_data.conversation_history]

        def trim_prompt(s: str) -> str:
            lines = s.strip().split("\n")
            return "\n".join([line.strip().lstrip("|") for line in lines])

        values = input_data.prompt_values
        if values:
            input_data.prompt = fmt.format_string(input_data.prompt, values)
@@ -907,27 +921,15 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
        if input_data.sys_prompt:
            prompt.append({"role": "system", "content": input_data.sys_prompt})

        # Use a one-time unique tag to prevent collisions with user/LLM content
        output_tag_id = self.get_collision_proof_output_tag_id()
        output_tag_start = f'<json_output id="{output_tag_id}">'
        if input_data.expected_format:
            expected_format = [
                f'"{k}": "{v}"' for k, v in input_data.expected_format.items()
            ]
            if input_data.list_result:
                format_prompt = (
                    f'"results": [\n  {{\n    {", ".join(expected_format)}\n  }}\n]'
                )
            else:
                format_prompt = "\n  ".join(expected_format)

            sys_prompt = trim_prompt(
                f"""
                |Reply strictly only in the following JSON format:
                |{{
                |  {format_prompt}
                |}}
                |
                |Ensure the response is valid JSON. Do not include any additional text outside of the JSON.
                |If you cannot provide all the keys, provide an empty string for the values you cannot answer.
                """
            sys_prompt = self.response_format_instructions(
                input_data.expected_format,
                list_mode=input_data.list_result,
                pure_json_mode=input_data.force_json_output,
                output_tag_start=output_tag_start,
            )
            prompt.append({"role": "system", "content": sys_prompt})

@@ -945,18 +947,21 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
            except JSONDecodeError as e:
                return f"JSON decode error: {e}"

        logger.debug(f"LLM request: {prompt}")
        retry_prompt = ""
        error_feedback_message = ""
        llm_model = input_data.model

        for retry_count in range(input_data.retry):
            logger.debug(f"LLM request: {prompt}")
            try:
                llm_response = await self.llm_call(
                    credentials=credentials,
                    llm_model=llm_model,
                    prompt=prompt,
                    compress_prompt_to_fit=input_data.compress_prompt_to_fit,
                    json_format=bool(input_data.expected_format),
                    force_json_output=(
                        input_data.force_json_output
                        and bool(input_data.expected_format)
                    ),
                    ollama_host=input_data.ollama_host,
                    max_tokens=input_data.max_tokens,
                )
@@ -970,16 +975,55 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
                logger.debug(f"LLM attempt-{retry_count} response: {response_text}")

                if input_data.expected_format:
                    try:
                        response_obj = self.get_json_from_response(
                            response_text,
                            pure_json_mode=input_data.force_json_output,
                            output_tag_start=output_tag_start,
                        )
                    except (ValueError, JSONDecodeError) as parse_error:
                        censored_response = re.sub(r"[A-Za-z0-9]", "*", response_text)
                        response_snippet = (
                            f"{censored_response[:50]}...{censored_response[-30:]}"
                        )
                        logger.warning(
                            f"Error getting JSON from LLM response: {parse_error}\n\n"
                            f"Response start+end: `{response_snippet}`"
                        )
                        prompt.append({"role": "assistant", "content": response_text})

                    response_obj = json.loads(response_text)
                        error_feedback_message = self.invalid_response_feedback(
                            parse_error,
                            was_parseable=False,
                            list_mode=input_data.list_result,
                            pure_json_mode=input_data.force_json_output,
                            output_tag_start=output_tag_start,
                        )
                        prompt.append(
                            {"role": "user", "content": error_feedback_message}
                        )
                        continue

                    # Handle object response for `force_json_output`+`list_result`
                    if input_data.list_result and isinstance(response_obj, dict):
                        if "results" in response_obj:
                            response_obj = response_obj.get("results", [])
                        elif len(response_obj) == 1:
                            response_obj = list(response_obj.values())
                        if "results" in response_obj and isinstance(
                            response_obj["results"], list
                        ):
                            response_obj = response_obj["results"]
                        else:
                            error_feedback_message = (
                                "Expected an array of objects in the 'results' key, "
                                f"but got: {response_obj}"
                            )
                            prompt.append(
                                {"role": "assistant", "content": response_text}
                            )
                            prompt.append(
                                {"role": "user", "content": error_feedback_message}
                            )
                            continue

                    response_error = "\n".join(
                    validation_errors = "\n".join(
                        [
                            validation_error
                            for response_item in (
@@ -991,7 +1035,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
                        ]
                    )

                    if not response_error:
                    if not validation_errors:
                        self.merge_stats(
                            NodeExecutionStats(
                                llm_call_count=retry_count + 1,
@@ -1001,6 +1045,16 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
                        yield "response", response_obj
                        yield "prompt", self.prompt
                        return

                    prompt.append({"role": "assistant", "content": response_text})
                    error_feedback_message = self.invalid_response_feedback(
                        validation_errors,
                        was_parseable=True,
                        list_mode=input_data.list_result,
                        pure_json_mode=input_data.force_json_output,
                        output_tag_start=output_tag_start,
                    )
                    prompt.append({"role": "user", "content": error_feedback_message})
                else:
                    self.merge_stats(
                        NodeExecutionStats(
@@ -1011,21 +1065,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
                    yield "response", {"response": response_text}
                    yield "prompt", self.prompt
                    return

                retry_prompt = trim_prompt(
                    f"""
                    |This is your previous error response:
                    |--
                    |{response_text}
                    |--
                    |
                    |And this is the error:
                    |--
                    |{response_error}
                    |--
                    """
                )
                prompt.append({"role": "user", "content": retry_prompt})
            except Exception as e:
                logger.exception(f"Error calling LLM: {e}")
                if (
@@ -1038,9 +1077,133 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
                    logger.debug(
                        f"Reducing max_tokens to {input_data.max_tokens} for next attempt"
                    )
                    retry_prompt = f"Error calling LLM: {e}"
                    # Don't add retry prompt for token limit errors,
                    # just retry with lower maximum output tokens

        raise RuntimeError(retry_prompt)
                    error_feedback_message = f"Error calling LLM: {e}"

        raise RuntimeError(error_feedback_message)

    def response_format_instructions(
        self,
        expected_object_format: dict[str, str],
        *,
        list_mode: bool,
        pure_json_mode: bool,
        output_tag_start: str,
    ) -> str:
        expected_output_format = json.dumps(expected_object_format, indent=2)
        output_type = "object" if not list_mode else "array"
        outer_output_type = "object" if pure_json_mode else output_type

        if output_type == "array":
            indented_obj_format = expected_output_format.replace("\n", "\n  ")
            expected_output_format = f"[\n  {indented_obj_format},\n  ...\n]"
        if pure_json_mode:
            indented_list_format = expected_output_format.replace("\n", "\n  ")
            expected_output_format = (
                "{\n"
                '  "reasoning": "... (optional)",\n'  # for better performance
                f'  "results": {indented_list_format}\n'
                "}"
            )

        # Preserve indentation in prompt
        expected_output_format = expected_output_format.replace("\n", "\n|")

        # Prepare prompt
        if not pure_json_mode:
            expected_output_format = (
                f"{output_tag_start}\n{expected_output_format}\n</json_output>"
            )

        instructions = f"""
        |In your response you MUST include a valid JSON {outer_output_type} strictly following this format:
        |{expected_output_format}
        |
        |If you cannot provide all the keys, you MUST provide an empty string for the values you cannot answer.
        """.strip()

        if not pure_json_mode:
            instructions += f"""
            |
            |You MUST enclose your final JSON answer in {output_tag_start}...</json_output> tags, even if the user specifies a different tag.
            |There MUST be exactly ONE {output_tag_start}...</json_output> block in your response, which MUST ONLY contain the JSON {outer_output_type} and nothing else. Other text outside this block is allowed.
            """.strip()

        return trim_prompt(instructions)

    def invalid_response_feedback(
        self,
        error,
        *,
        was_parseable: bool,
        list_mode: bool,
        pure_json_mode: bool,
        output_tag_start: str,
    ) -> str:
        outer_output_type = "object" if not list_mode or pure_json_mode else "array"

        if was_parseable:
            complaint = f"Your previous response did not match the expected {outer_output_type} format."
        else:
            complaint = f"Your previous response did not contain a parseable JSON {outer_output_type}."

        indented_parse_error = str(error).replace("\n", "\n|")

        instruction = (
            f"Please provide a {output_tag_start}...</json_output> block containing a"
            if not pure_json_mode
            else "Please provide a"
        ) + f" valid JSON {outer_output_type} that matches the expected format."
|
||||
|
||||
return trim_prompt(
|
||||
f"""
|
||||
|{complaint}
|
||||
|
|
||||
|{indented_parse_error}
|
||||
|
|
||||
|{instruction}
|
||||
"""
|
||||
)
|
||||
|
||||
def get_json_from_response(
|
||||
self, response_text: str, *, pure_json_mode: bool, output_tag_start: str
|
||||
) -> dict[str, Any] | list[dict[str, Any]]:
|
||||
if pure_json_mode:
|
||||
# Handle pure JSON responses
|
||||
try:
|
||||
return json.loads(response_text)
|
||||
except JSONDecodeError as first_parse_error:
|
||||
# If that didn't work, try finding the { and } to deal with possible ```json fences etc.
|
||||
json_start = response_text.find("{")
|
||||
json_end = response_text.rfind("}")
|
||||
try:
|
||||
return json.loads(response_text[json_start : json_end + 1])
|
||||
except JSONDecodeError:
|
||||
# Raise the original error, as it's more likely to be relevant
|
||||
raise first_parse_error from None
|
||||
|
||||
if output_tag_start not in response_text:
|
||||
raise ValueError(
|
||||
"Response does not contain the expected "
|
||||
f"{output_tag_start}...</json_output> block."
|
||||
)
|
||||
json_output = (
|
||||
response_text.split(output_tag_start, 1)[1]
|
||||
.rsplit("</json_output>", 1)[0]
|
||||
.strip()
|
||||
)
|
||||
return json.loads(json_output)
|
||||
|
||||
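The tagged-extraction path above is easy to exercise in isolation. Below is a minimal sketch, assuming a fixed `<json_output>` tag for illustration (the real block generates a collision-proof tag suffix via `get_collision_proof_output_tag_id` below); `extract_tagged_json` and the demo string are hypothetical, not part of this diff:

```python
import json

# Hypothetical standalone helper mirroring get_json_from_response's tagged path.
def extract_tagged_json(response_text: str, tag_start: str = "<json_output>"):
    if tag_start not in response_text:
        raise ValueError(f"Response does not contain a {tag_start}...</json_output> block.")
    # Take everything between the first opening tag and the last closing tag,
    # so stray text before/after the block is tolerated.
    body = response_text.split(tag_start, 1)[1].rsplit("</json_output>", 1)[0]
    return json.loads(body.strip())

demo = 'Some reasoning first.\n<json_output>\n{"answer": 42}\n</json_output>'
print(extract_tagged_json(demo))  # -> {'answer': 42}
```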
def get_collision_proof_output_tag_id(self) -> str:
    return secrets.token_hex(8)


def trim_prompt(s: str) -> str:
    """Removes indentation up to and including `|` from a multi-line prompt."""
    lines = s.strip().split("\n")
    return "\n".join([line.strip().lstrip("|") for line in lines])


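Since several of the prompts above rely on `trim_prompt`'s `|` margin convention, a quick self-contained demonstration (copying the function verbatim from this diff; the sample prompt text is illustrative):

```python
def trim_prompt(s: str) -> str:
    """Removes indentation up to and including `|` from a multi-line prompt."""
    lines = s.strip().split("\n")
    return "\n".join([line.strip().lstrip("|") for line in lines])

prompt = """
    |This is your previous error response:
    |--
    |(response text here)
    |--
"""
print(trim_prompt(prompt))
# This is your previous error response:
# --
# (response text here)
# --
```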
class AITextGeneratorBlock(AIBlockBase):
@@ -1237,11 +1400,27 @@ class AITextSummarizerBlock(AIBlockBase):

    @staticmethod
    def _split_text(text: str, max_tokens: int, overlap: int) -> list[str]:
        # Security fix: Add validation to prevent DoS attacks
        # Limit text size to prevent memory exhaustion
        MAX_TEXT_LENGTH = 1_000_000  # 1MB character limit
        MAX_CHUNKS = 100  # Maximum number of chunks to prevent excessive memory use

        if len(text) > MAX_TEXT_LENGTH:
            text = text[:MAX_TEXT_LENGTH]

        # Ensure chunk_size is at least 1 to prevent infinite loops
        chunk_size = max(1, max_tokens - overlap)

        # Ensure overlap is less than max_tokens to prevent invalid configurations
        if overlap >= max_tokens:
            overlap = max(0, max_tokens - 1)

        words = text.split()
        chunks = []
        chunk_size = max_tokens - overlap

        for i in range(0, len(words), chunk_size):
            if len(chunks) >= MAX_CHUNKS:
                break  # Limit the number of chunks to prevent memory exhaustion
            chunk = " ".join(words[i : i + max_tokens])
            chunks.append(chunk)

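A worked run of the sliding-window arithmetic above may help: stepping by `max_tokens - overlap` while joining `max_tokens` words per chunk leaves `overlap` words shared between neighbouring chunks. This is a simplified copy for illustration, ignoring the MAX_* caps in the diff:

```python
def split_words(text: str, max_tokens: int, overlap: int) -> list[str]:
    # Same windowing as _split_text above, without the size caps.
    words = text.split()
    step = max(1, max_tokens - overlap)
    return [" ".join(words[i : i + max_tokens]) for i in range(0, len(words), step)]

print(split_words("a b c d e f", max_tokens=4, overlap=2))
# ['a b c d', 'c d e f', 'e f']
```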
@@ -1375,7 +1554,9 @@ class AIConversationBlock(AIBlockBase):
            ("prompt", list),
        ],
        test_mock={
            "llm_call": lambda *args, **kwargs: "The 2020 World Series was played at Globe Life Field in Arlington, Texas."
            "llm_call": lambda *args, **kwargs: dict(
                response="The 2020 World Series was played at Globe Life Field in Arlington, Texas."
            )
        },
    )

@@ -1404,7 +1585,7 @@ class AIConversationBlock(AIBlockBase):
        ),
        credentials=credentials,
    )
    yield "response", response
    yield "response", response["response"]
    yield "prompt", self.prompt


536  autogpt_platform/backend/backend/blocks/notion/_api.py  Normal file
@@ -0,0 +1,536 @@
"""
|
||||
Notion API helper functions and client for making authenticated requests.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.util.request import Requests
|
||||
|
||||
NOTION_VERSION = "2022-06-28"
|
||||
|
||||
|
||||
class NotionAPIException(Exception):
|
||||
"""Exception raised for Notion API errors."""
|
||||
|
||||
def __init__(self, message: str, status_code: int):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class NotionClient:
|
||||
"""Client for interacting with the Notion API."""
|
||||
|
||||
def __init__(self, credentials: OAuth2Credentials):
|
||||
self.credentials = credentials
|
||||
self.headers = {
|
||||
"Authorization": credentials.auth_header(),
|
||||
"Notion-Version": NOTION_VERSION,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
self.requests = Requests()
|
||||
|
||||
async def get_page(self, page_id: str) -> dict:
|
||||
"""
|
||||
Fetch a page by ID.
|
||||
|
||||
Args:
|
||||
page_id: The ID of the page to fetch.
|
||||
|
||||
Returns:
|
||||
The page object from Notion API.
|
||||
"""
|
||||
url = f"https://api.notion.com/v1/pages/{page_id}"
|
||||
response = await self.requests.get(url, headers=self.headers)
|
||||
|
||||
if not response.ok:
|
||||
raise NotionAPIException(
|
||||
f"Failed to fetch page: {response.status} - {response.text()}",
|
||||
response.status,
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
async def get_blocks(self, block_id: str, recursive: bool = True) -> List[dict]:
|
||||
"""
|
||||
Fetch all blocks from a page or block.
|
||||
|
||||
Args:
|
||||
block_id: The ID of the page or block to fetch children from.
|
||||
recursive: Whether to fetch nested blocks recursively.
|
||||
|
||||
Returns:
|
||||
List of block objects.
|
||||
"""
|
||||
blocks = []
|
||||
cursor = None
|
||||
|
||||
while True:
|
||||
url = f"https://api.notion.com/v1/blocks/{block_id}/children"
|
||||
params = {"page_size": 100}
|
||||
if cursor:
|
||||
params["start_cursor"] = cursor
|
||||
|
||||
response = await self.requests.get(url, headers=self.headers, params=params)
|
||||
|
||||
if not response.ok:
|
||||
raise NotionAPIException(
|
||||
f"Failed to fetch blocks: {response.status} - {response.text()}",
|
||||
response.status,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
current_blocks = data.get("results", [])
|
||||
|
||||
# If recursive, fetch children for blocks that have them
|
||||
if recursive:
|
||||
for block in current_blocks:
|
||||
if block.get("has_children"):
|
||||
block["children"] = await self.get_blocks(
|
||||
block["id"], recursive=True
|
||||
)
|
||||
|
||||
blocks.extend(current_blocks)
|
||||
|
||||
if not data.get("has_more"):
|
||||
break
|
||||
cursor = data.get("next_cursor")
|
||||
|
||||
return blocks
|
||||
|
||||
async def query_database(
|
||||
self,
|
||||
database_id: str,
|
||||
filter_obj: Optional[dict] = None,
|
||||
sorts: Optional[List[dict]] = None,
|
||||
page_size: int = 100,
|
||||
) -> dict:
|
||||
"""
|
||||
Query a database with optional filters and sorts.
|
||||
|
||||
Args:
|
||||
database_id: The ID of the database to query.
|
||||
filter_obj: Optional filter object for the query.
|
||||
sorts: Optional list of sort objects.
|
||||
page_size: Number of results per page.
|
||||
|
||||
Returns:
|
||||
Query results including pages and pagination info.
|
||||
"""
|
||||
url = f"https://api.notion.com/v1/databases/{database_id}/query"
|
||||
|
||||
payload: Dict[str, Any] = {"page_size": page_size}
|
||||
if filter_obj:
|
||||
payload["filter"] = filter_obj
|
||||
if sorts:
|
||||
payload["sorts"] = sorts
|
||||
|
||||
response = await self.requests.post(url, headers=self.headers, json=payload)
|
||||
|
||||
if not response.ok:
|
||||
raise NotionAPIException(
|
||||
f"Failed to query database: {response.status} - {response.text()}",
|
||||
response.status,
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
async def create_page(
|
||||
self,
|
||||
parent: dict,
|
||||
properties: dict,
|
||||
children: Optional[List[dict]] = None,
|
||||
icon: Optional[dict] = None,
|
||||
cover: Optional[dict] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Create a new page.
|
||||
|
||||
Args:
|
||||
parent: Parent object (page_id or database_id).
|
||||
properties: Page properties.
|
||||
children: Optional list of block children.
|
||||
icon: Optional icon object.
|
||||
cover: Optional cover object.
|
||||
|
||||
Returns:
|
||||
The created page object.
|
||||
"""
|
||||
url = "https://api.notion.com/v1/pages"
|
||||
|
||||
payload: Dict[str, Any] = {"parent": parent, "properties": properties}
|
||||
|
||||
if children:
|
||||
payload["children"] = children
|
||||
if icon:
|
||||
payload["icon"] = icon
|
||||
if cover:
|
||||
payload["cover"] = cover
|
||||
|
||||
response = await self.requests.post(url, headers=self.headers, json=payload)
|
||||
|
||||
if not response.ok:
|
||||
raise NotionAPIException(
|
||||
f"Failed to create page: {response.status} - {response.text()}",
|
||||
response.status,
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
async def update_page(self, page_id: str, properties: dict) -> dict:
|
||||
"""
|
||||
Update a page's properties.
|
||||
|
||||
Args:
|
||||
page_id: The ID of the page to update.
|
||||
properties: Properties to update.
|
||||
|
||||
Returns:
|
||||
The updated page object.
|
||||
"""
|
||||
url = f"https://api.notion.com/v1/pages/{page_id}"
|
||||
|
||||
response = await self.requests.patch(
|
||||
url, headers=self.headers, json={"properties": properties}
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
raise NotionAPIException(
|
||||
f"Failed to update page: {response.status} - {response.text()}",
|
||||
response.status,
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
async def append_blocks(self, block_id: str, children: List[dict]) -> dict:
|
||||
"""
|
||||
Append blocks to a page or block.
|
||||
|
||||
Args:
|
||||
block_id: The ID of the page or block to append to.
|
||||
children: List of block objects to append.
|
||||
|
||||
Returns:
|
||||
Response with the created blocks.
|
||||
"""
|
||||
url = f"https://api.notion.com/v1/blocks/{block_id}/children"
|
||||
|
||||
response = await self.requests.patch(
|
||||
url, headers=self.headers, json={"children": children}
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
raise NotionAPIException(
|
||||
f"Failed to append blocks: {response.status} - {response.text()}",
|
||||
response.status,
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str = "",
|
||||
filter_obj: Optional[dict] = None,
|
||||
sort: Optional[dict] = None,
|
||||
page_size: int = 100,
|
||||
) -> dict:
|
||||
"""
|
||||
Search for pages and databases.
|
||||
|
||||
Args:
|
||||
query: Search query text.
|
||||
filter_obj: Optional filter object.
|
||||
sort: Optional sort object.
|
||||
page_size: Number of results per page.
|
||||
|
||||
Returns:
|
||||
Search results.
|
||||
"""
|
||||
url = "https://api.notion.com/v1/search"
|
||||
|
||||
payload: Dict[str, Any] = {"page_size": page_size}
|
||||
if query:
|
||||
payload["query"] = query
|
||||
if filter_obj:
|
||||
payload["filter"] = filter_obj
|
||||
if sort:
|
||||
payload["sort"] = sort
|
||||
|
||||
response = await self.requests.post(url, headers=self.headers, json=payload)
|
||||
|
||||
if not response.ok:
|
||||
raise NotionAPIException(
|
||||
f"Search failed: {response.status} - {response.text()}", response.status
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
|
||||
# Conversion helper functions
|
||||
|
||||
|
||||
def parse_rich_text(rich_text_array: List[dict]) -> str:
|
||||
"""
|
||||
Extract plain text from a Notion rich text array.
|
||||
|
||||
Args:
|
||||
rich_text_array: Array of rich text objects from Notion.
|
||||
|
||||
Returns:
|
||||
Plain text string.
|
||||
"""
|
||||
if not rich_text_array:
|
||||
return ""
|
||||
|
||||
text_parts = []
|
||||
for text_obj in rich_text_array:
|
||||
if "plain_text" in text_obj:
|
||||
text_parts.append(text_obj["plain_text"])
|
||||
|
||||
return "".join(text_parts)
|
||||
|
||||
|
||||
def rich_text_to_markdown(rich_text_array: List[dict]) -> str:
|
||||
"""
|
||||
Convert Notion rich text array to markdown with formatting.
|
||||
|
||||
Args:
|
||||
rich_text_array: Array of rich text objects from Notion.
|
||||
|
||||
Returns:
|
||||
Markdown formatted string.
|
||||
"""
|
||||
if not rich_text_array:
|
||||
return ""
|
||||
|
||||
markdown_parts = []
|
||||
|
||||
for text_obj in rich_text_array:
|
||||
text = text_obj.get("plain_text", "")
|
||||
annotations = text_obj.get("annotations", {})
|
||||
|
||||
# Apply formatting based on annotations
|
||||
if annotations.get("code"):
|
||||
text = f"`{text}`"
|
||||
else:
|
||||
if annotations.get("bold"):
|
||||
text = f"**{text}**"
|
||||
if annotations.get("italic"):
|
||||
text = f"*{text}*"
|
||||
if annotations.get("strikethrough"):
|
||||
text = f"~~{text}~~"
|
||||
if annotations.get("underline"):
|
||||
text = f"<u>{text}</u>"
|
||||
|
||||
# Handle links
|
||||
if text_obj.get("href"):
|
||||
text = f"[{text}]({text_obj['href']})"
|
||||
|
||||
markdown_parts.append(text)
|
||||
|
||||
return "".join(markdown_parts)
|
||||
|
||||
|
||||
def block_to_markdown(block: dict, indent_level: int = 0) -> str:
|
||||
"""
|
||||
Convert a single Notion block to markdown.
|
||||
|
||||
Args:
|
||||
block: Block object from Notion API.
|
||||
indent_level: Current indentation level for nested blocks.
|
||||
|
||||
Returns:
|
||||
Markdown string representation of the block.
|
||||
"""
|
||||
block_type = block.get("type")
|
||||
indent = " " * indent_level
|
||||
markdown_lines = []
|
||||
|
||||
# Handle different block types
|
||||
if block_type == "paragraph":
|
||||
text = rich_text_to_markdown(block["paragraph"].get("rich_text", []))
|
||||
if text:
|
||||
markdown_lines.append(f"{indent}{text}")
|
||||
|
||||
elif block_type == "heading_1":
|
||||
text = parse_rich_text(block["heading_1"].get("rich_text", []))
|
||||
markdown_lines.append(f"{indent}# {text}")
|
||||
|
||||
elif block_type == "heading_2":
|
||||
text = parse_rich_text(block["heading_2"].get("rich_text", []))
|
||||
markdown_lines.append(f"{indent}## {text}")
|
||||
|
||||
elif block_type == "heading_3":
|
||||
text = parse_rich_text(block["heading_3"].get("rich_text", []))
|
||||
markdown_lines.append(f"{indent}### {text}")
|
||||
|
||||
elif block_type == "bulleted_list_item":
|
||||
text = rich_text_to_markdown(block["bulleted_list_item"].get("rich_text", []))
|
||||
markdown_lines.append(f"{indent}- {text}")
|
||||
|
||||
elif block_type == "numbered_list_item":
|
||||
text = rich_text_to_markdown(block["numbered_list_item"].get("rich_text", []))
|
||||
# Note: This is simplified - proper numbering would need context
|
||||
markdown_lines.append(f"{indent}1. {text}")
|
||||
|
||||
elif block_type == "to_do":
|
||||
text = rich_text_to_markdown(block["to_do"].get("rich_text", []))
|
||||
checked = "x" if block["to_do"].get("checked") else " "
|
||||
markdown_lines.append(f"{indent}- [{checked}] {text}")
|
||||
|
||||
elif block_type == "toggle":
|
||||
text = rich_text_to_markdown(block["toggle"].get("rich_text", []))
|
||||
markdown_lines.append(f"{indent}<details>")
|
||||
markdown_lines.append(f"{indent}<summary>{text}</summary>")
|
||||
markdown_lines.append(f"{indent}")
|
||||
# Process children if they exist
|
||||
if block.get("children"):
|
||||
for child in block["children"]:
|
||||
child_markdown = block_to_markdown(child, indent_level + 1)
|
||||
if child_markdown:
|
||||
markdown_lines.append(child_markdown)
|
||||
markdown_lines.append(f"{indent}</details>")
|
||||
|
||||
elif block_type == "code":
|
||||
code = parse_rich_text(block["code"].get("rich_text", []))
|
||||
language = block["code"].get("language", "")
|
||||
markdown_lines.append(f"{indent}```{language}")
|
||||
markdown_lines.append(f"{indent}{code}")
|
||||
markdown_lines.append(f"{indent}```")
|
||||
|
||||
elif block_type == "quote":
|
||||
text = rich_text_to_markdown(block["quote"].get("rich_text", []))
|
||||
markdown_lines.append(f"{indent}> {text}")
|
||||
|
||||
elif block_type == "divider":
|
||||
markdown_lines.append(f"{indent}---")
|
||||
|
||||
elif block_type == "image":
|
||||
image = block["image"]
|
||||
url = image.get("external", {}).get("url") or image.get("file", {}).get(
|
||||
"url", ""
|
||||
)
|
||||
caption = parse_rich_text(image.get("caption", []))
|
||||
alt_text = caption if caption else "Image"
|
||||
markdown_lines.append(f"{indent}")
|
||||
if caption:
|
||||
markdown_lines.append(f"{indent}*{caption}*")
|
||||
|
||||
elif block_type == "video":
|
||||
video = block["video"]
|
||||
url = video.get("external", {}).get("url") or video.get("file", {}).get(
|
||||
"url", ""
|
||||
)
|
||||
caption = parse_rich_text(video.get("caption", []))
|
||||
markdown_lines.append(f"{indent}[Video]({url})")
|
||||
if caption:
|
||||
markdown_lines.append(f"{indent}*{caption}*")
|
||||
|
||||
elif block_type == "file":
|
||||
file = block["file"]
|
||||
url = file.get("external", {}).get("url") or file.get("file", {}).get("url", "")
|
||||
caption = parse_rich_text(file.get("caption", []))
|
||||
name = caption if caption else "File"
|
||||
markdown_lines.append(f"{indent}[{name}]({url})")
|
||||
|
||||
elif block_type == "bookmark":
|
||||
url = block["bookmark"].get("url", "")
|
||||
caption = parse_rich_text(block["bookmark"].get("caption", []))
|
||||
markdown_lines.append(f"{indent}[{caption if caption else url}]({url})")
|
||||
|
||||
elif block_type == "equation":
|
||||
expression = block["equation"].get("expression", "")
|
||||
markdown_lines.append(f"{indent}$${expression}$$")
|
||||
|
||||
elif block_type == "callout":
|
||||
text = rich_text_to_markdown(block["callout"].get("rich_text", []))
|
||||
icon = block["callout"].get("icon", {})
|
||||
if icon.get("emoji"):
|
||||
markdown_lines.append(f"{indent}> {icon['emoji']} {text}")
|
||||
else:
|
||||
markdown_lines.append(f"{indent}> ℹ️ {text}")
|
||||
|
||||
elif block_type == "child_page":
|
||||
title = block["child_page"].get("title", "Untitled")
|
||||
markdown_lines.append(f"{indent}📄 [{title}](notion://page/{block['id']})")
|
||||
|
||||
elif block_type == "child_database":
|
||||
title = block["child_database"].get("title", "Untitled Database")
|
||||
markdown_lines.append(f"{indent}🗂️ [{title}](notion://database/{block['id']})")
|
||||
|
||||
elif block_type == "table":
|
||||
# Tables are complex - for now just indicate there's a table
|
||||
markdown_lines.append(
|
||||
f"{indent}[Table with {block['table'].get('table_width', 0)} columns]"
|
||||
)
|
||||
|
||||
elif block_type == "column_list":
|
||||
# Process columns
|
||||
if block.get("children"):
|
||||
markdown_lines.append(f"{indent}<div style='display: flex'>")
|
||||
for column in block["children"]:
|
||||
markdown_lines.append(f"{indent}<div style='flex: 1'>")
|
||||
if column.get("children"):
|
||||
for child in column["children"]:
|
||||
child_markdown = block_to_markdown(child, indent_level + 1)
|
||||
if child_markdown:
|
||||
markdown_lines.append(child_markdown)
|
||||
markdown_lines.append(f"{indent}</div>")
|
||||
markdown_lines.append(f"{indent}</div>")
|
||||
|
||||
# Handle children for blocks that haven't been processed yet
|
||||
elif block.get("children") and block_type not in ["toggle", "column_list"]:
|
||||
for child in block["children"]:
|
||||
child_markdown = block_to_markdown(child, indent_level)
|
||||
if child_markdown:
|
||||
markdown_lines.append(child_markdown)
|
||||
|
||||
return "\n".join(markdown_lines) if markdown_lines else ""
|
||||
|
||||
|
||||
def blocks_to_markdown(blocks: List[dict]) -> str:
|
||||
"""
|
||||
Convert a list of Notion blocks to a markdown document.
|
||||
|
||||
Args:
|
||||
blocks: List of block objects from Notion API.
|
||||
|
||||
Returns:
|
||||
Complete markdown document as a string.
|
||||
"""
|
||||
markdown_parts = []
|
||||
|
||||
for i, block in enumerate(blocks):
|
||||
markdown = block_to_markdown(block)
|
||||
if markdown:
|
||||
markdown_parts.append(markdown)
|
||||
# Add spacing between top-level blocks (except lists)
|
||||
if i < len(blocks) - 1:
|
||||
next_type = blocks[i + 1].get("type", "")
|
||||
current_type = block.get("type", "")
|
||||
# Don't add extra spacing between list items
|
||||
list_types = {"bulleted_list_item", "numbered_list_item", "to_do"}
|
||||
if not (current_type in list_types and next_type in list_types):
|
||||
markdown_parts.append("")
|
||||
|
||||
return "\n".join(markdown_parts)
|
||||
|
||||
|
||||
def extract_page_title(page: dict) -> str:
|
||||
"""
|
||||
Extract the title from a Notion page object.
|
||||
|
||||
Args:
|
||||
page: Page object from Notion API.
|
||||
|
||||
Returns:
|
||||
Page title as a string.
|
||||
"""
|
||||
properties = page.get("properties", {})
|
||||
|
||||
# Find the title property (it has type "title")
|
||||
for prop_name, prop_value in properties.items():
|
||||
if prop_value.get("type") == "title":
|
||||
return parse_rich_text(prop_value.get("title", []))
|
||||
|
||||
return "Untitled"
|
||||
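As a sanity check of the converters above, here is a hedged usage sketch; it assumes the `backend` package is importable (as it is for the blocks in this diff) and uses hand-built block dicts that follow the Notion block schema:

```python
from backend.blocks.notion._api import blocks_to_markdown

blocks = [
    {"type": "heading_1", "heading_1": {"rich_text": [{"plain_text": "Title"}]}},
    {"type": "paragraph", "paragraph": {"rich_text": [
        {"plain_text": "bold", "annotations": {"bold": True}},
        {"plain_text": " text", "annotations": {}},
    ]}},
    {"type": "to_do", "to_do": {"rich_text": [{"plain_text": "ship it"}], "checked": True}},
]
print(blocks_to_markdown(blocks))
# # Title
#
# **bold** text
#
# - [x] ship it
```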
42  autogpt_platform/backend/backend/blocks/notion/_auth.py  Normal file
@@ -0,0 +1,42 @@
from typing import Literal

from pydantic import SecretStr

from backend.data.model import CredentialsField, CredentialsMetaInput, OAuth2Credentials
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets

secrets = Secrets()
NOTION_OAUTH_IS_CONFIGURED = bool(
    secrets.notion_client_id and secrets.notion_client_secret
)

NotionCredentials = OAuth2Credentials
NotionCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.NOTION], Literal["oauth2"]
]


def NotionCredentialsField() -> NotionCredentialsInput:
    """Creates a Notion OAuth2 credentials field."""
    return CredentialsField(
        description="Connect your Notion account. Ensure the pages/databases are shared with the integration."
    )


# Test credentials for Notion OAuth2
TEST_CREDENTIALS = OAuth2Credentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="notion",
    access_token=SecretStr("test_access_token"),
    title="Mock Notion OAuth",
    scopes=["read_content", "insert_content", "update_content"],
    username="testuser",
)

TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.title,
}
360  autogpt_platform/backend/backend/blocks/notion/create_page.py  Normal file
@@ -0,0 +1,360 @@
from __future__ import annotations

from typing import Any, Dict, List, Optional

from pydantic import model_validator

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import OAuth2Credentials, SchemaField

from ._api import NotionClient
from ._auth import (
    NOTION_OAUTH_IS_CONFIGURED,
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    NotionCredentialsField,
    NotionCredentialsInput,
)


class NotionCreatePageBlock(Block):
    """Create a new page in Notion with content."""

    class Input(BlockSchema):
        credentials: NotionCredentialsInput = NotionCredentialsField()
        parent_page_id: Optional[str] = SchemaField(
            description="Parent page ID to create the page under. Either this OR parent_database_id is required.",
            default=None,
        )
        parent_database_id: Optional[str] = SchemaField(
            description="Parent database ID to create the page in. Either this OR parent_page_id is required.",
            default=None,
        )
        title: str = SchemaField(
            description="Title of the new page",
        )
        content: Optional[str] = SchemaField(
            description="Content for the page. Can be plain text or markdown - will be converted to Notion blocks.",
            default=None,
        )
        properties: Optional[Dict[str, Any]] = SchemaField(
            description="Additional properties for database pages (e.g., {'Status': 'In Progress', 'Priority': 'High'})",
            default=None,
        )
        icon_emoji: Optional[str] = SchemaField(
            description="Emoji to use as the page icon (e.g., '📄', '🚀')", default=None
        )

        @model_validator(mode="after")
        def validate_parent(self):
            """Ensure either parent_page_id or parent_database_id is provided."""
            if not self.parent_page_id and not self.parent_database_id:
                raise ValueError(
                    "Either parent_page_id or parent_database_id must be provided"
                )
            if self.parent_page_id and self.parent_database_id:
                raise ValueError(
                    "Only one of parent_page_id or parent_database_id should be provided, not both"
                )
            return self

    class Output(BlockSchema):
        page_id: str = SchemaField(description="ID of the created page.")
        page_url: str = SchemaField(description="URL of the created page.")
        error: str = SchemaField(description="Error message if the operation failed.")

    def __init__(self):
        super().__init__(
            id="c15febe0-66ce-4c6f-aebd-5ab351653804",
            description="Create a new page in Notion. Requires EITHER a parent_page_id OR parent_database_id. Supports markdown content.",
            categories={BlockCategory.PRODUCTIVITY},
            input_schema=NotionCreatePageBlock.Input,
            output_schema=NotionCreatePageBlock.Output,
            disabled=not NOTION_OAUTH_IS_CONFIGURED,
            test_input={
                "parent_page_id": "00000000-0000-0000-0000-000000000000",
                "title": "Test Page",
                "content": "This is test content.",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_output=[
                ("page_id", "12345678-1234-1234-1234-123456789012"),
                (
                    "page_url",
                    "https://notion.so/Test-Page-12345678123412341234123456789012",
                ),
            ],
            test_credentials=TEST_CREDENTIALS,
            test_mock={
                "create_page": lambda *args, **kwargs: (
                    "12345678-1234-1234-1234-123456789012",
                    "https://notion.so/Test-Page-12345678123412341234123456789012",
                )
            },
        )

    @staticmethod
    def _markdown_to_blocks(content: str) -> List[dict]:
        """Convert markdown content to Notion block objects."""
        if not content:
            return []

        blocks = []
        lines = content.split("\n")
        i = 0

        while i < len(lines):
            line = lines[i]

            # Skip empty lines
            if not line.strip():
                i += 1
                continue

            # Headings
            if line.startswith("### "):
                blocks.append(
                    {
                        "type": "heading_3",
                        "heading_3": {
                            "rich_text": [
                                {"type": "text", "text": {"content": line[4:].strip()}}
                            ]
                        },
                    }
                )
            elif line.startswith("## "):
                blocks.append(
                    {
                        "type": "heading_2",
                        "heading_2": {
                            "rich_text": [
                                {"type": "text", "text": {"content": line[3:].strip()}}
                            ]
                        },
                    }
                )
            elif line.startswith("# "):
                blocks.append(
                    {
                        "type": "heading_1",
                        "heading_1": {
                            "rich_text": [
                                {"type": "text", "text": {"content": line[2:].strip()}}
                            ]
                        },
                    }
                )
            # Bullet points
            elif line.strip().startswith("- "):
                blocks.append(
                    {
                        "type": "bulleted_list_item",
                        "bulleted_list_item": {
                            "rich_text": [
                                {
                                    "type": "text",
                                    "text": {"content": line.strip()[2:].strip()},
                                }
                            ]
                        },
                    }
                )
            # Numbered list
            elif line.strip() and line.strip()[0].isdigit() and ". " in line:
                content_start = line.find(". ") + 2
                blocks.append(
                    {
                        "type": "numbered_list_item",
                        "numbered_list_item": {
                            "rich_text": [
                                {
                                    "type": "text",
                                    "text": {"content": line[content_start:].strip()},
                                }
                            ]
                        },
                    }
                )
            # Code block
            elif line.strip().startswith("```"):
                code_lines = []
                language = line[3:].strip() or "plain text"
                i += 1
                while i < len(lines) and not lines[i].strip().startswith("```"):
                    code_lines.append(lines[i])
                    i += 1
                blocks.append(
                    {
                        "type": "code",
                        "code": {
                            "rich_text": [
                                {
                                    "type": "text",
                                    "text": {"content": "\n".join(code_lines)},
                                }
                            ],
                            "language": language,
                        },
                    }
                )
            # Quote
            elif line.strip().startswith("> "):
                blocks.append(
                    {
                        "type": "quote",
                        "quote": {
                            "rich_text": [
                                {
                                    "type": "text",
                                    "text": {"content": line.strip()[2:].strip()},
                                }
                            ]
                        },
                    }
                )
            # Horizontal rule
            elif line.strip() in ["---", "***", "___"]:
                blocks.append({"type": "divider", "divider": {}})
            # Regular paragraph
            else:
                # Parse for basic markdown formatting
                text_content = line.strip()
                rich_text = []

                # Simple bold/italic parsing (this is simplified)
                if "**" in text_content or "*" in text_content:
                    # For now, just pass as plain text
                    # A full implementation would parse and create proper annotations
                    rich_text = [{"type": "text", "text": {"content": text_content}}]
                else:
                    rich_text = [{"type": "text", "text": {"content": text_content}}]

                blocks.append(
                    {"type": "paragraph", "paragraph": {"rich_text": rich_text}}
                )

            i += 1

        return blocks

    @staticmethod
    def _build_properties(
        title: str, additional_properties: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Build properties object for page creation."""
        properties: Dict[str, Any] = {
            "title": {"title": [{"type": "text", "text": {"content": title}}]}
        }

        if additional_properties:
            for key, value in additional_properties.items():
                if key.lower() == "title":
                    continue  # Skip title as we already have it

                # Try to intelligently map property types
                if isinstance(value, bool):
                    properties[key] = {"checkbox": value}
                elif isinstance(value, (int, float)):
                    properties[key] = {"number": value}
                elif isinstance(value, list):
                    # Assume multi-select
                    properties[key] = {
                        "multi_select": [{"name": str(item)} for item in value]
                    }
                elif isinstance(value, str):
                    # Could be select, rich_text, or other types
                    # For simplicity, try common patterns
                    if key.lower() in ["status", "priority", "type", "category"]:
                        properties[key] = {"select": {"name": value}}
                    elif key.lower() in ["url", "link"]:
                        properties[key] = {"url": value}
                    elif key.lower() in ["email"]:
                        properties[key] = {"email": value}
                    else:
                        properties[key] = {
                            "rich_text": [{"type": "text", "text": {"content": value}}]
                        }

        return properties

    @staticmethod
    async def create_page(
        credentials: OAuth2Credentials,
        title: str,
        parent_page_id: Optional[str] = None,
        parent_database_id: Optional[str] = None,
        content: Optional[str] = None,
        properties: Optional[Dict[str, Any]] = None,
        icon_emoji: Optional[str] = None,
    ) -> tuple[str, str]:
        """
        Create a new Notion page.

        Returns:
            Tuple of (page_id, page_url)
        """
        if not parent_page_id and not parent_database_id:
            raise ValueError(
                "Either parent_page_id or parent_database_id must be provided"
            )
        if parent_page_id and parent_database_id:
            raise ValueError(
                "Only one of parent_page_id or parent_database_id should be provided, not both"
            )

        client = NotionClient(credentials)

        # Build parent object
        if parent_page_id:
            parent = {"type": "page_id", "page_id": parent_page_id}
        else:
            parent = {"type": "database_id", "database_id": parent_database_id}

        # Build properties
        page_properties = NotionCreatePageBlock._build_properties(title, properties)

        # Convert content to blocks if provided
        children = None
        if content:
            children = NotionCreatePageBlock._markdown_to_blocks(content)

        # Build icon if provided
        icon = None
        if icon_emoji:
            icon = {"type": "emoji", "emoji": icon_emoji}

        # Create the page
        result = await client.create_page(
            parent=parent, properties=page_properties, children=children, icon=icon
        )

        page_id = result.get("id", "")
        page_url = result.get("url", "")

        if not page_id or not page_url:
            raise ValueError("Failed to get page ID or URL from Notion response")

        return page_id, page_url

    async def run(
        self,
        input_data: Input,
        *,
        credentials: OAuth2Credentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            page_id, page_url = await self.create_page(
                credentials,
                input_data.title,
                input_data.parent_page_id,
                input_data.parent_database_id,
                input_data.content,
                input_data.properties,
                input_data.icon_emoji,
            )
            yield "page_id", page_id
            yield "page_url", page_url
        except Exception as e:
            yield "error", str(e) if str(e) else "Unknown error"
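A quick sketch of the markdown conversion above, again assuming the `backend` package is importable; `_markdown_to_blocks` is a private helper, so calling it directly is for illustration only:

```python
from backend.blocks.notion.create_page import NotionCreatePageBlock

# Each markdown line maps to one Notion block per the parser above.
blocks = NotionCreatePageBlock._markdown_to_blocks(
    "# Title\n- first item\n> a quote\n---"
)
print([b["type"] for b in blocks])
# ['heading_1', 'bulleted_list_item', 'quote', 'divider']
```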
285  autogpt_platform/backend/backend/blocks/notion/read_database.py  Normal file
@@ -0,0 +1,285 @@
from __future__ import annotations

from typing import Any, Dict, List, Optional

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import OAuth2Credentials, SchemaField

from ._api import NotionClient, parse_rich_text
from ._auth import (
    NOTION_OAUTH_IS_CONFIGURED,
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    NotionCredentialsField,
    NotionCredentialsInput,
)


class NotionReadDatabaseBlock(Block):
    """Query a Notion database and retrieve entries with their properties."""

    class Input(BlockSchema):
        credentials: NotionCredentialsInput = NotionCredentialsField()
        database_id: str = SchemaField(
            description="Notion database ID. Must be accessible by the connected integration.",
        )
        filter_property: Optional[str] = SchemaField(
            description="Property name to filter by (e.g., 'Status', 'Priority')",
            default=None,
        )
        filter_value: Optional[str] = SchemaField(
            description="Value to filter for in the specified property", default=None
        )
        sort_property: Optional[str] = SchemaField(
            description="Property name to sort by", default=None
        )
        sort_direction: Optional[str] = SchemaField(
            description="Sort direction: 'ascending' or 'descending'",
            default="ascending",
        )
        limit: int = SchemaField(
            description="Maximum number of entries to retrieve",
            default=100,
            ge=1,
            le=100,
        )

    class Output(BlockSchema):
        entries: List[Dict[str, Any]] = SchemaField(
            description="List of database entries with their properties."
        )
        entry: Dict[str, Any] = SchemaField(
            description="Individual database entry (yields one per entry found)."
        )
        entry_ids: List[str] = SchemaField(
            description="List of entry IDs for batch operations."
        )
        entry_id: str = SchemaField(
            description="Individual entry ID (yields one per entry found)."
        )
        count: int = SchemaField(description="Number of entries retrieved.")
        database_title: str = SchemaField(description="Title of the database.")
        error: str = SchemaField(description="Error message if the operation failed.")

    def __init__(self):
        super().__init__(
            id="fcd53135-88c9-4ba3-be50-cc6936286e6c",
            description="Query a Notion database with optional filtering and sorting, returning structured entries.",
            categories={BlockCategory.PRODUCTIVITY},
            input_schema=NotionReadDatabaseBlock.Input,
            output_schema=NotionReadDatabaseBlock.Output,
            disabled=not NOTION_OAUTH_IS_CONFIGURED,
            test_input={
                "database_id": "00000000-0000-0000-0000-000000000000",
                "limit": 10,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_output=[
                (
                    "entries",
                    [{"Name": "Test Entry", "Status": "Active", "_id": "test-123"}],
                ),
                ("entry_ids", ["test-123"]),
                (
                    "entry",
                    {"Name": "Test Entry", "Status": "Active", "_id": "test-123"},
                ),
                ("entry_id", "test-123"),
                ("count", 1),
                ("database_title", "Test Database"),
            ],
            test_credentials=TEST_CREDENTIALS,
            test_mock={
                "query_database": lambda *args, **kwargs: (
                    [{"Name": "Test Entry", "Status": "Active", "_id": "test-123"}],
                    1,
                    "Test Database",
                )
            },
        )

    @staticmethod
    def _parse_property_value(prop: dict) -> Any:
        """Parse a Notion property value into a simple Python type."""
        prop_type = prop.get("type")

        if prop_type == "title":
            return parse_rich_text(prop.get("title", []))
        elif prop_type == "rich_text":
            return parse_rich_text(prop.get("rich_text", []))
        elif prop_type == "number":
            return prop.get("number")
        elif prop_type == "select":
            select = prop.get("select")
            return select.get("name") if select else None
        elif prop_type == "multi_select":
            return [item.get("name") for item in prop.get("multi_select", [])]
        elif prop_type == "date":
            date = prop.get("date")
            if date:
                return date.get("start")
            return None
        elif prop_type == "checkbox":
            return prop.get("checkbox", False)
        elif prop_type == "url":
            return prop.get("url")
        elif prop_type == "email":
            return prop.get("email")
        elif prop_type == "phone_number":
            return prop.get("phone_number")
        elif prop_type == "people":
            return [
                person.get("name", person.get("id"))
                for person in prop.get("people", [])
            ]
        elif prop_type == "files":
            files = prop.get("files", [])
            return [
                f.get(
                    "name",
                    f.get("external", {}).get("url", f.get("file", {}).get("url")),
                )
                for f in files
            ]
        elif prop_type == "relation":
            return [rel.get("id") for rel in prop.get("relation", [])]
        elif prop_type == "formula":
            formula = prop.get("formula", {})
            return formula.get(formula.get("type"))
        elif prop_type == "rollup":
            rollup = prop.get("rollup", {})
            return rollup.get(rollup.get("type"))
        elif prop_type == "created_time":
            return prop.get("created_time")
        elif prop_type == "created_by":
            return prop.get("created_by", {}).get(
                "name", prop.get("created_by", {}).get("id")
            )
        elif prop_type == "last_edited_time":
            return prop.get("last_edited_time")
        elif prop_type == "last_edited_by":
            return prop.get("last_edited_by", {}).get(
                "name", prop.get("last_edited_by", {}).get("id")
            )
        else:
            # Return the raw value for unknown types
            return prop

    @staticmethod
    def _build_filter(property_name: str, value: str) -> dict:
        """Build a simple filter object for a property."""
        # This is a simplified filter - in reality, you'd need to know the property type
        # For now, we'll try common filter types
        return {
            "or": [
                {"property": property_name, "rich_text": {"contains": value}},
                {"property": property_name, "title": {"contains": value}},
                {"property": property_name, "select": {"equals": value}},
                {"property": property_name, "multi_select": {"contains": value}},
            ]
        }

    @staticmethod
    async def query_database(
        credentials: OAuth2Credentials,
        database_id: str,
        filter_property: Optional[str] = None,
        filter_value: Optional[str] = None,
        sort_property: Optional[str] = None,
        sort_direction: str = "ascending",
        limit: int = 100,
    ) -> tuple[List[Dict[str, Any]], int, str]:
        """
        Query a Notion database and parse the results.

        Returns:
            Tuple of (entries_list, count, database_title)
        """
        client = NotionClient(credentials)

        # Build filter if specified
        filter_obj = None
        if filter_property and filter_value:
            filter_obj = NotionReadDatabaseBlock._build_filter(
                filter_property, filter_value
            )

        # Build sorts if specified
        sorts = None
        if sort_property:
            sorts = [{"property": sort_property, "direction": sort_direction}]

        # Query the database
        result = await client.query_database(
            database_id, filter_obj=filter_obj, sorts=sorts, page_size=limit
        )

        # Parse the entries
        entries = []
        for page in result.get("results", []):
            entry = {}
            properties = page.get("properties", {})

            for prop_name, prop_value in properties.items():
                entry[prop_name] = NotionReadDatabaseBlock._parse_property_value(
                    prop_value
                )

            # Add metadata
            entry["_id"] = page.get("id")
            entry["_url"] = page.get("url")
            entry["_created_time"] = page.get("created_time")
            entry["_last_edited_time"] = page.get("last_edited_time")

            entries.append(entry)

        # Get database title (we need to make a separate call for this)
        try:
            database_url = f"https://api.notion.com/v1/databases/{database_id}"
            db_response = await client.requests.get(
                database_url, headers=client.headers
            )
            if db_response.ok:
                db_data = db_response.json()
                db_title = parse_rich_text(db_data.get("title", []))
            else:
                db_title = "Unknown Database"
        except Exception:
            db_title = "Unknown Database"

        return entries, len(entries), db_title

    async def run(
        self,
        input_data: Input,
        *,
        credentials: OAuth2Credentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            entries, count, db_title = await self.query_database(
                credentials,
                input_data.database_id,
                input_data.filter_property,
                input_data.filter_value,
                input_data.sort_property,
                input_data.sort_direction or "ascending",
                input_data.limit,
            )
            # Yield the complete list for batch operations
            yield "entries", entries

            # Extract and yield IDs as a list for batch operations
            entry_ids = [entry["_id"] for entry in entries if "_id" in entry]
            yield "entry_ids", entry_ids

            # Yield each individual entry and its ID for single connections
            for entry in entries:
                yield "entry", entry
                if "_id" in entry:
                    yield "entry_id", entry["_id"]

            yield "count", count
            yield "database_title", db_title
        except Exception as e:
            yield "error", str(e) if str(e) else "Unknown error"
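To see what the flattening in `_parse_property_value` produces, a small hedged sketch (assuming the `backend` package is importable; the payloads follow the Notion property schema handled above):

```python
from backend.blocks.notion.read_database import NotionReadDatabaseBlock

parse = NotionReadDatabaseBlock._parse_property_value
print(parse({"type": "select", "select": {"name": "In Progress"}}))  # In Progress
print(parse({"type": "multi_select", "multi_select": [{"name": "a"}, {"name": "b"}]}))  # ['a', 'b']
print(parse({"type": "checkbox", "checkbox": True}))  # True
print(parse({"type": "date", "date": {"start": "2024-01-01"}}))  # 2024-01-01
```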
64  autogpt_platform/backend/backend/blocks/notion/read_page.py  Normal file
@@ -0,0 +1,64 @@
from __future__ import annotations

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import OAuth2Credentials, SchemaField

from ._api import NotionClient
from ._auth import (
    NOTION_OAUTH_IS_CONFIGURED,
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    NotionCredentialsField,
    NotionCredentialsInput,
)


class NotionReadPageBlock(Block):
    """Read a Notion page by ID and return its raw JSON."""

    class Input(BlockSchema):
        credentials: NotionCredentialsInput = NotionCredentialsField()
        page_id: str = SchemaField(
            description="Notion page ID. Must be accessible by the connected integration. You can get this from the page URL: notion.so/A-Page-586edd711467478da59fe35e29a1ffab would be 586edd711467478da59fe35e29a1ffab",
        )

    class Output(BlockSchema):
        page: dict = SchemaField(description="Raw Notion page JSON.")
        error: str = SchemaField(description="Error message if the operation failed.")

    def __init__(self):
        super().__init__(
            id="5246cc1d-34b7-452b-8fc5-3fb25fd8f542",
            description="Read a Notion page by its ID and return its raw JSON.",
            categories={BlockCategory.PRODUCTIVITY},
            input_schema=NotionReadPageBlock.Input,
            output_schema=NotionReadPageBlock.Output,
            disabled=not NOTION_OAUTH_IS_CONFIGURED,
            test_input={
                "page_id": "00000000-0000-0000-0000-000000000000",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_output=[("page", dict)],
            test_credentials=TEST_CREDENTIALS,
            test_mock={
                "get_page": lambda *args, **kwargs: {"object": "page", "id": "mocked"}
            },
        )

    @staticmethod
    async def get_page(credentials: OAuth2Credentials, page_id: str) -> dict:
        client = NotionClient(credentials)
        return await client.get_page(page_id)

    async def run(
        self,
        input_data: Input,
        *,
        credentials: OAuth2Credentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            page = await self.get_page(credentials, input_data.page_id)
            yield "page", page
        except Exception as e:
            yield "error", str(e) if str(e) else "Unknown error"
@@ -0,0 +1,109 @@
from __future__ import annotations

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import OAuth2Credentials, SchemaField

from ._api import NotionClient, blocks_to_markdown, extract_page_title
from ._auth import (
    NOTION_OAUTH_IS_CONFIGURED,
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    NotionCredentialsField,
    NotionCredentialsInput,
)


class NotionReadPageMarkdownBlock(Block):
    """Read a Notion page and convert it to clean Markdown format."""

    class Input(BlockSchema):
        credentials: NotionCredentialsInput = NotionCredentialsField()
        page_id: str = SchemaField(
            description="Notion page ID. Must be accessible by the connected integration. You can get this from the page URL: notion.so/A-Page-586edd711467478da59fe35e29a1ffab would be 586edd711467478da59fe35e29a1ffab",
        )
        include_title: bool = SchemaField(
            description="Whether to include the page title as a header in the markdown",
            default=True,
        )

    class Output(BlockSchema):
        markdown: str = SchemaField(description="Page content in Markdown format.")
        title: str = SchemaField(description="Page title.")
        error: str = SchemaField(description="Error message if the operation failed.")

    def __init__(self):
        super().__init__(
            id="d1312c4d-fae2-4e70-893d-f4d07cce1d4e",
            description="Read a Notion page and convert it to Markdown format with proper formatting for headings, lists, links, and rich text.",
            categories={BlockCategory.PRODUCTIVITY},
            input_schema=NotionReadPageMarkdownBlock.Input,
            output_schema=NotionReadPageMarkdownBlock.Output,
            disabled=not NOTION_OAUTH_IS_CONFIGURED,
            test_input={
                "page_id": "00000000-0000-0000-0000-000000000000",
                "include_title": True,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_output=[
                ("markdown", "# Test Page\n\nThis is test content."),
                ("title", "Test Page"),
            ],
            test_credentials=TEST_CREDENTIALS,
            test_mock={
                "get_page_markdown": lambda *args, **kwargs: (
                    "# Test Page\n\nThis is test content.",
                    "Test Page",
                )
            },
        )

    @staticmethod
    async def get_page_markdown(
        credentials: OAuth2Credentials, page_id: str, include_title: bool = True
    ) -> tuple[str, str]:
        """
        Get a Notion page and convert it to markdown.

        Args:
            credentials: OAuth2 credentials for Notion.
            page_id: The ID of the page to fetch.
            include_title: Whether to include the page title in the markdown.

        Returns:
            Tuple of (markdown_content, title)
        """
        client = NotionClient(credentials)

        # Get page metadata
        page = await client.get_page(page_id)
        title = extract_page_title(page)

        # Get all blocks from the page
        blocks = await client.get_blocks(page_id, recursive=True)

        # Convert blocks to markdown
        content_markdown = blocks_to_markdown(blocks)

        # Combine title and content if requested
        if include_title and title:
            full_markdown = f"# {title}\n\n{content_markdown}"
        else:
            full_markdown = content_markdown

        return full_markdown, title

    async def run(
        self,
        input_data: Input,
        *,
        credentials: OAuth2Credentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            markdown, title = await self.get_page_markdown(
                credentials, input_data.page_id, input_data.include_title
            )
            yield "markdown", markdown
            yield "title", title
        except Exception as e:
            yield "error", str(e) if str(e) else "Unknown error"
225  autogpt_platform/backend/backend/blocks/notion/search.py  Normal file
@@ -0,0 +1,225 @@
from __future__ import annotations

from typing import List, Optional

from pydantic import BaseModel

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import OAuth2Credentials, SchemaField

from ._api import NotionClient, extract_page_title, parse_rich_text
from ._auth import (
    NOTION_OAUTH_IS_CONFIGURED,
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    NotionCredentialsField,
    NotionCredentialsInput,
)


class NotionSearchResult(BaseModel):
    """Typed model for Notion search results."""

    id: str
    type: str  # 'page' or 'database'
    title: str
    url: str
    created_time: Optional[str] = None
    last_edited_time: Optional[str] = None
    parent_type: Optional[str] = None  # 'page', 'database', or 'workspace'
    parent_id: Optional[str] = None
    icon: Optional[str] = None  # emoji icon if present
    is_inline: Optional[bool] = None  # for databases only


class NotionSearchBlock(Block):
    """Search across your Notion workspace for pages and databases."""

    class Input(BlockSchema):
        credentials: NotionCredentialsInput = NotionCredentialsField()
        query: str = SchemaField(
            description="Search query text. Leave empty to get all accessible pages/databases.",
            default="",
        )
        filter_type: Optional[str] = SchemaField(
            description="Filter results by type: 'page' or 'database'. Leave empty for both.",
            default=None,
        )
        limit: int = SchemaField(
            description="Maximum number of results to return", default=20, ge=1, le=100
        )

    class Output(BlockSchema):
        results: List[NotionSearchResult] = SchemaField(
            description="List of search results with title, type, URL, and metadata."
        )
        result: NotionSearchResult = SchemaField(
            description="Individual search result (yields one per result found)."
        )
        result_ids: List[str] = SchemaField(
            description="List of IDs from search results for batch operations."
        )
        count: int = SchemaField(description="Number of results found.")
        error: str = SchemaField(description="Error message if the operation failed.")

    def __init__(self):
        super().__init__(
            id="313515dd-9848-46ea-9cd6-3c627c892c56",
            description="Search your Notion workspace for pages and databases by text query.",
            categories={BlockCategory.PRODUCTIVITY, BlockCategory.SEARCH},
            input_schema=NotionSearchBlock.Input,
            output_schema=NotionSearchBlock.Output,
            disabled=not NOTION_OAUTH_IS_CONFIGURED,
            test_input={
                "query": "project",
                "limit": 5,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_output=[
                (
                    "results",
                    [
                        NotionSearchResult(
                            id="123",
                            type="page",
                            title="Project Plan",
                            url="https://notion.so/Project-Plan-123",
                        )
                    ],
                ),
                ("result_ids", ["123"]),
                (
                    "result",
                    NotionSearchResult(
                        id="123",
                        type="page",
                        title="Project Plan",
                        url="https://notion.so/Project-Plan-123",
                    ),
                ),
                ("count", 1),
            ],
            test_credentials=TEST_CREDENTIALS,
            test_mock={
                "search_workspace": lambda *args, **kwargs: (
                    [
                        NotionSearchResult(
                            id="123",
                            type="page",
                            title="Project Plan",
                            url="https://notion.so/Project-Plan-123",
                        )
                    ],
                    1,
                )
            },
        )

    @staticmethod
    async def search_workspace(
        credentials: OAuth2Credentials,
        query: str = "",
        filter_type: Optional[str] = None,
        limit: int = 20,
    ) -> tuple[List[NotionSearchResult], int]:
        """
        Search the Notion workspace.

        Returns:
            Tuple of (results_list, count)
        """
        client = NotionClient(credentials)

        # Build filter if type is specified
        filter_obj = None
        if filter_type:
            filter_obj = {"property": "object", "value": filter_type}

        # Execute search
        response = await client.search(
            query=query, filter_obj=filter_obj, page_size=limit
        )

        # Parse results
        results = []
        for item in response.get("results", []):
            result_data = {
                "id": item.get("id", ""),
                "type": item.get("object", ""),
                "url": item.get("url", ""),
                "created_time": item.get("created_time"),
                "last_edited_time": item.get("last_edited_time"),
                "title": "",  # Will be set below
            }

            # Extract title based on type
            if item.get("object") == "page":
                # For pages, get the title from properties
                result_data["title"] = extract_page_title(item)

                # Add parent info
                parent = item.get("parent", {})
                if parent.get("type") == "page_id":
                    result_data["parent_type"] = "page"
                    result_data["parent_id"] = parent.get("page_id")
                elif parent.get("type") == "database_id":
                    result_data["parent_type"] = "database"
                    result_data["parent_id"] = parent.get("database_id")
                elif parent.get("type") == "workspace":
                    result_data["parent_type"] = "workspace"

                # Add icon if present
                icon = item.get("icon")
                if icon and icon.get("type") == "emoji":
                    result_data["icon"] = icon.get("emoji")

            elif item.get("object") == "database":
                # For databases, get title from the title array
                result_data["title"] = parse_rich_text(item.get("title", []))

                # Add database-specific metadata
                result_data["is_inline"] = item.get("is_inline", False)

                # Add parent info
                parent = item.get("parent", {})
                if parent.get("type") == "page_id":
                    result_data["parent_type"] = "page"
                    result_data["parent_id"] = parent.get("page_id")
                elif parent.get("type") == "workspace":
|
||||
result_data["parent_type"] = "workspace"
|
||||
|
||||
# Add icon if present
|
||||
icon = item.get("icon")
|
||||
if icon and icon.get("type") == "emoji":
|
||||
result_data["icon"] = icon.get("emoji")
|
||||
|
||||
results.append(NotionSearchResult(**result_data))
|
||||
|
||||
return results, len(results)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: OAuth2Credentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
results, count = await self.search_workspace(
|
||||
credentials, input_data.query, input_data.filter_type, input_data.limit
|
||||
)
|
||||
|
||||
# Yield the complete list for batch operations
|
||||
yield "results", results
|
||||
|
||||
# Extract and yield IDs as a list for batch operations
|
||||
result_ids = [r.id for r in results]
|
||||
yield "result_ids", result_ids
|
||||
|
||||
# Yield each individual result for single connections
|
||||
for result in results:
|
||||
yield "result", result
|
||||
|
||||
yield "count", count
|
||||
except Exception as e:
|
||||
yield "error", str(e) if str(e) else "Unknown error"
|
||||
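[Editor's note] The filter_obj built in search_workspace matches the payload shape of Notion's public POST /v1/search endpoint. A rough sketch of the request NotionClient.search() presumably issues; the client internals and the httpx usage are assumptions — only the endpoint, payload keys, and Notion-Version header come from Notion's documented API:

import httpx

async def notion_search(
    token: str, query: str = "", filter_obj: dict | None = None, page_size: int = 20
) -> dict:
    payload: dict = {"query": query, "page_size": page_size}
    if filter_obj:
        # e.g. {"property": "object", "value": "page"} restricts results to pages
        payload["filter"] = filter_obj
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            "https://api.notion.com/v1/search",
            json=payload,
            headers={
                "Authorization": f"Bearer {token}",
                "Notion-Version": "2022-06-28",
            },
        )
        resp.raise_for_status()
        return resp.json()  # {"results": [...], "has_more": ..., ...}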
autogpt_platform/backend/backend/blocks/perplexity.py  (new file, 226 lines)
@@ -0,0 +1,226 @@
# flake8: noqa: E501
import logging
from enum import Enum
from typing import Any, Literal

import openai
from pydantic import SecretStr

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
    CredentialsMetaInput,
    NodeExecutionStats,
    SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.logging import TruncatedLogger

logger = TruncatedLogger(logging.getLogger(__name__), "[Perplexity-Block]")


class PerplexityModel(str, Enum):
    """Perplexity sonar models available via OpenRouter"""

    SONAR = "perplexity/sonar"
    SONAR_PRO = "perplexity/sonar-pro"
    SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"


PerplexityCredentials = CredentialsMetaInput[
    Literal[ProviderName.OPEN_ROUTER], Literal["api_key"]
]

TEST_CREDENTIALS = APIKeyCredentials(
    id="test-perplexity-creds",
    provider="open_router",
    api_key=SecretStr("mock-openrouter-api-key"),
    title="Mock OpenRouter API key",
    expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.title,
}


def PerplexityCredentialsField() -> PerplexityCredentials:
    return CredentialsField(
        description="OpenRouter API key for accessing Perplexity models.",
    )


class PerplexityBlock(Block):
    class Input(BlockSchema):
        prompt: str = SchemaField(
            description="The query to send to the Perplexity model.",
            placeholder="Enter your query here...",
        )
        model: PerplexityModel = SchemaField(
            title="Perplexity Model",
            default=PerplexityModel.SONAR,
            description="The Perplexity sonar model to use.",
            advanced=False,
        )
        credentials: PerplexityCredentials = PerplexityCredentialsField()
        system_prompt: str = SchemaField(
            title="System Prompt",
            default="",
            description="Optional system prompt to provide context to the model.",
            advanced=True,
        )
        max_tokens: int | None = SchemaField(
            advanced=True,
            default=None,
            description="The maximum number of tokens to generate.",
        )

    class Output(BlockSchema):
        response: str = SchemaField(
            description="The response from the Perplexity model."
        )
        annotations: list[dict[str, Any]] = SchemaField(
            description="List of URL citations and annotations from the response."
        )
        error: str = SchemaField(description="Error message if the API call failed.")

    def __init__(self):
        super().__init__(
            id="c8a5f2e9-8b3d-4a7e-9f6c-1d5e3c9b7a4f",
            description="Query Perplexity's sonar models with real-time web search capabilities and receive annotated responses with source citations.",
            categories={BlockCategory.AI, BlockCategory.SEARCH},
            input_schema=PerplexityBlock.Input,
            output_schema=PerplexityBlock.Output,
            test_input={
                "prompt": "What is the weather today?",
                "model": PerplexityModel.SONAR,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("response", "The weather varies by location..."),
                ("annotations", list),
            ],
            test_mock={
                "call_perplexity": lambda *args, **kwargs: {
                    "response": "The weather varies by location...",
                    "annotations": [
                        {
                            "type": "url_citation",
                            "url_citation": {
                                "title": "weather.com",
                                "url": "https://weather.com",
                            },
                        }
                    ],
                }
            },
        )
        self.execution_stats = NodeExecutionStats()

    async def call_perplexity(
        self,
        credentials: APIKeyCredentials,
        model: PerplexityModel,
        prompt: str,
        system_prompt: str = "",
        max_tokens: int | None = None,
    ) -> dict[str, Any]:
        """Call Perplexity via OpenRouter and extract annotations."""
        client = openai.AsyncOpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=credentials.api_key.get_secret_value(),
        )

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        try:
            response = await client.chat.completions.create(
                extra_headers={
                    "HTTP-Referer": "https://agpt.co",
                    "X-Title": "AutoGPT",
                },
                model=model.value,
                messages=messages,
                max_tokens=max_tokens,
            )

            if not response.choices:
                raise ValueError("No response from Perplexity via OpenRouter.")

            # Extract the response content
            response_content = response.choices[0].message.content or ""

            # Extract annotations if present in the message
            annotations = []
            if hasattr(response.choices[0].message, "annotations"):
                # If annotations are directly available
                annotations = response.choices[0].message.annotations
            else:
                # Check if there's a raw response with annotations
                raw = getattr(response.choices[0].message, "_raw_response", None)
                if isinstance(raw, dict) and "annotations" in raw:
                    annotations = raw["annotations"]

            if not annotations and hasattr(response, "model_extra"):
                # Check model_extra for annotations
                model_extra = response.model_extra
                if isinstance(model_extra, dict):
                    # Check in choices
                    if "choices" in model_extra and len(model_extra["choices"]) > 0:
                        choice = model_extra["choices"][0]
                        if "message" in choice and "annotations" in choice["message"]:
                            annotations = choice["message"]["annotations"]

            # Also check the raw response object for annotations
            if not annotations:
                raw = getattr(response, "_raw_response", None)
                if isinstance(raw, dict):
                    # Check various possible locations for annotations
                    if "annotations" in raw:
                        annotations = raw["annotations"]
                    elif "choices" in raw and len(raw["choices"]) > 0:
                        choice = raw["choices"][0]
                        if "message" in choice and "annotations" in choice["message"]:
                            annotations = choice["message"]["annotations"]

            # Update execution stats
            if response.usage:
                self.execution_stats.input_token_count = response.usage.prompt_tokens
                self.execution_stats.output_token_count = (
                    response.usage.completion_tokens
                )

            return {"response": response_content, "annotations": annotations or []}

        except Exception as e:
            logger.error(f"Error calling Perplexity: {e}")
            raise

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        logger.debug(f"Running Perplexity block with model: {input_data.model}")

        try:
            result = await self.call_perplexity(
                credentials=credentials,
                model=input_data.model,
                prompt=input_data.prompt,
                system_prompt=input_data.system_prompt,
                max_tokens=input_data.max_tokens,
            )

            yield "response", result["response"]
            yield "annotations", result["annotations"]

        except Exception as e:
            error_msg = f"Error calling Perplexity: {str(e)}"
            logger.error(error_msg)
            yield "error", error_msg
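[Editor's note] For reference, a small helper that pulls cited URLs out of the dict call_perplexity returns, using the url_citation shape shown in the test_mock above; illustrative only:

def extract_citation_urls(result: dict) -> list[str]:
    urls = []
    for ann in result.get("annotations", []):
        if ann.get("type") == "url_citation":
            citation = ann.get("url_citation", {})
            if url := citation.get("url"):
                urls.append(url)
    return urls

# extract_citation_urls({"annotations": [{"type": "url_citation",
#     "url_citation": {"title": "weather.com", "url": "https://weather.com"}}]})
# -> ["https://weather.com"]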
@@ -1,4 +1,5 @@
 import asyncio
+import logging
 from datetime import datetime, timedelta, timezone
 from typing import Any

@@ -7,6 +8,7 @@ import pydantic

 from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import SchemaField
+from backend.util.request import Requests


 class RSSEntry(pydantic.BaseModel):
@@ -100,8 +102,33 @@ class ReadRSSFeedBlock(Block):
     )

     @staticmethod
-    def parse_feed(url: str) -> dict[str, Any]:
-        return feedparser.parse(url)  # type: ignore
+    async def parse_feed(url: str) -> dict[str, Any]:
+        # Security fix: Add protection against memory exhaustion attacks
+        MAX_FEED_SIZE = 10 * 1024 * 1024  # 10MB limit for RSS feeds
+
+        # Download feed content with size limit
+        try:
+            response = await Requests(raise_for_status=True).get(url)
+
+            # Check content length if available
+            content_length = response.headers.get("Content-Length")
+            if content_length and int(content_length) > MAX_FEED_SIZE:
+                raise ValueError(
+                    f"Feed too large: {content_length} bytes exceeds {MAX_FEED_SIZE} limit"
+                )
+
+            # Get content with size limit
+            content = response.content
+            if len(content) > MAX_FEED_SIZE:
+                raise ValueError(f"Feed too large: exceeds {MAX_FEED_SIZE} byte limit")
+
+            # Parse with feedparser using the validated content
+            # feedparser has built-in protection against XML attacks
+            return feedparser.parse(content)  # type: ignore
+        except Exception as e:
+            # Log error and return empty feed
+            logging.warning(f"Failed to parse RSS feed from {url}: {e}")
+            return {"entries": []}

     async def run(self, input_data: Input, **kwargs) -> BlockOutput:
         keep_going = True
@@ -111,7 +138,7 @@ class ReadRSSFeedBlock(Block):
         while keep_going:
             keep_going = input_data.run_continuously

-            feed = self.parse_feed(input_data.rss_url)
+            feed = await self.parse_feed(input_data.rss_url)
             all_entries = []

             for entry in feed["entries"]:
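[Editor's note] The new parse_feed checks the size twice because Content-Length is supplied by the remote server and is optional, so the downloaded bytes must be re-measured. A stricter variant (an assumption, not what this diff does) would stream the body and abort as soon as the cap is crossed, so an attacker can never make the process buffer more than the limit:

import httpx

MAX_FEED_SIZE = 10 * 1024 * 1024  # mirror the limit used above

async def fetch_capped(url: str) -> bytes:
    chunks, total = [], 0
    async with httpx.AsyncClient() as client:
        async with client.stream("GET", url) as resp:
            resp.raise_for_status()
            async for chunk in resp.aiter_bytes():
                total += len(chunk)
                if total > MAX_FEED_SIZE:
                    # abort mid-download instead of buffering the whole body
                    raise ValueError(f"Feed too large: exceeds {MAX_FEED_SIZE} byte limit")
                chunks.append(chunk)
    return b"".join(chunks)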
@@ -13,6 +13,11 @@ from backend.data.block import (
     BlockSchema,
     BlockType,
 )
+from backend.data.dynamic_fields import (
+    extract_base_field_name,
+    get_dynamic_field_description,
+    is_dynamic_field,
+)
 from backend.data.model import NodeExecutionStats, SchemaField
 from backend.util import json
 from backend.util.clients import get_database_manager_async_client
@@ -98,6 +103,22 @@ def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
     return {"role": "tool", "tool_call_id": call_id, "content": content}


+def _convert_raw_response_to_dict(raw_response: Any) -> dict[str, Any]:
+    """
+    Safely convert raw_response to dictionary format for conversation history.
+    Handles different response types from different LLM providers.
+    """
+    if isinstance(raw_response, str):
+        # Ollama returns a string, convert to dict format
+        return {"role": "assistant", "content": raw_response}
+    elif isinstance(raw_response, dict):
+        # Already a dict (from tests or some providers)
+        return raw_response
+    else:
+        # OpenAI/Anthropic return objects, convert with json.to_dict
+        return json.to_dict(raw_response)
+
+
 def get_pending_tool_calls(conversation_history: list[Any]) -> dict[str, int]:
     """
     All the tool calls entry in the conversation history requires a response.
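[Editor's note] A standalone illustration of the three raw_response shapes the helper above normalizes; vars() stands in for backend.util.json.to_dict and SimpleNamespace for a provider response object, both assumptions for the sketch:

from types import SimpleNamespace
from typing import Any

def convert_raw_response(raw: Any) -> dict:
    if isinstance(raw, str):
        return {"role": "assistant", "content": raw}  # Ollama-style string
    if isinstance(raw, dict):
        return raw  # already a dict (tests, some providers)
    return vars(raw)  # stand-in for json.to_dict on OpenAI/Anthropic objects

assert convert_raw_response("hi") == {"role": "assistant", "content": "hi"}
assert convert_raw_response({"role": "assistant", "content": None})["content"] is None
assert convert_raw_response(SimpleNamespace(role="assistant", content="x"))["role"] == "assistant"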
@@ -261,6 +282,7 @@ class SmartDecisionMakerBlock(Block):

     @staticmethod
     def cleanup(s: str):
+        """Clean up block names for use as tool function names."""
         return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()

     @staticmethod
@@ -288,41 +310,66 @@ class SmartDecisionMakerBlock(Block):
         }
         sink_block_input_schema = block.input_schema
         properties = {}
+        field_mapping = {}  # clean_name -> original_name

         for link in links:
-            sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)
+            field_name = link.sink_name
+            is_dynamic = is_dynamic_field(field_name)
+            # Clean property key to ensure Anthropic API compatibility for ALL fields
+            clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)
+            field_mapping[clean_field_name] = field_name

             # Handle dynamic fields (e.g., values_#_*, items_$_*, etc.)
             # These are fields that get merged by the executor into their base field
-            if (
-                "_#_" in link.sink_name
-                or "_$_" in link.sink_name
-                or "_@_" in link.sink_name
-            ):
-                # For dynamic fields, provide a generic string schema
-                # The executor will handle merging these into the appropriate structure
-                properties[sink_name] = {
+            if is_dynamic:
+                # For dynamic fields, use cleaned name but preserve original in description
+                properties[clean_field_name] = {
                     "type": "string",
-                    "description": f"Dynamic value for {link.sink_name}",
+                    "description": get_dynamic_field_description(field_name),
                 }
             else:
-                # For regular fields, use the block's schema
+                # For regular fields, use the block's schema directly
                 try:
-                    properties[sink_name] = sink_block_input_schema.get_field_schema(
-                        link.sink_name
+                    properties[clean_field_name] = (
+                        sink_block_input_schema.get_field_schema(field_name)
                     )
                 except (KeyError, AttributeError):
-                    # If the field doesn't exist in the schema, provide a generic schema
-                    properties[sink_name] = {
+                    # If field doesn't exist in schema, provide a generic one
+                    properties[clean_field_name] = {
                         "type": "string",
-                        "description": f"Value for {link.sink_name}",
+                        "description": f"Value for {field_name}",
                     }

+        # Build the parameters schema using a single unified path
+        base_schema = block.input_schema.jsonschema()
+        base_required = set(base_schema.get("required", []))
+
+        # Compute required fields at the leaf level:
+        # - If a linked field is dynamic and its base is required in the block schema, require the leaf
+        # - If a linked field is regular and is required in the block schema, require the leaf
+        required_fields: set[str] = set()
+        for link in links:
+            field_name = link.sink_name
+            is_dynamic = is_dynamic_field(field_name)
+            # Always use cleaned field name for property key (Anthropic API compliance)
+            clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)
+
+            if is_dynamic:
+                base_name = extract_base_field_name(field_name)
+                if base_name in base_required:
+                    required_fields.add(clean_field_name)
+            else:
+                if field_name in base_required:
+                    required_fields.add(clean_field_name)
+
         tool_function["parameters"] = {
-            **block.input_schema.jsonschema(),
             "type": "object",
             "properties": properties,
             "additionalProperties": False,
+            "required": sorted(required_fields),
         }

+        # Store field mapping for later use in output processing
+        tool_function["_field_mapping"] = field_mapping
+
         return {"type": "function", "function": tool_function}

     @staticmethod
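[Editor's note] For intuition, the cleanup() round-trip that makes the field_mapping above necessary: cleaning is lossy (everything outside [a-zA-Z0-9_-] becomes "_", then lowercased), so the original sink name must be remembered separately to emit the executor-facing key later:

import re

def cleanup(s: str) -> str:
    return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()

field_mapping = {}
for original in ["values_#_key", "items_$_0", "Max Keyword Difficulty"]:
    field_mapping[cleanup(original)] = original

assert cleanup("values_#_key") == "values___key"
# Without field_mapping there is no way back from "values___key" to the
# original dynamic pin name "values_#_key".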
@@ -366,13 +413,12 @@ class SmartDecisionMakerBlock(Block):
             sink_block_properties = sink_block_input_schema.get("properties", {}).get(
                 link.sink_name, {}
             )
-            sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)
             description = (
                 sink_block_properties["description"]
                 if "description" in sink_block_properties
                 else f"The {link.sink_name} of the tool"
             )
-            properties[sink_name] = {
+            properties[link.sink_name] = {
                 "type": "string",
                 "description": description,
                 "default": json.dumps(sink_block_properties.get("default", None)),
@@ -388,24 +434,17 @@ class SmartDecisionMakerBlock(Block):
         return {"type": "function", "function": tool_function}

     @staticmethod
-    async def _create_function_signature(node_id: str) -> list[dict[str, Any]]:
+    async def _create_function_signature(
+        node_id: str,
+    ) -> list[dict[str, Any]]:
         """
-        Creates function signatures for tools linked to a specified node within a graph.
-
-        This method filters the graph links to identify those that are tools and are
-        connected to the given node_id. It then constructs function signatures for each
-        tool based on the metadata and input schema of the linked nodes.
+        Creates function signatures for connected tools.

         Args:
             node_id: The node_id for which to create function signatures.

         Returns:
-            list[dict[str, Any]]: A list of dictionaries, each representing a function signature
-            for a tool, including its name, description, and parameters.
-
-        Raises:
-            ValueError: If no tool links are found for the specified node_id, or if a sink node
-            or its metadata cannot be found.
+            List of function signatures for tools
         """
         db_client = get_database_manager_async_client()
         tools = [
@@ -430,20 +469,116 @@ class SmartDecisionMakerBlock(Block):
             raise ValueError(f"Sink node not found: {links[0].sink_id}")

         if sink_node.block_id == AgentExecutorBlock().id:
-            return_tool_functions.append(
+            tool_func = (
                 await SmartDecisionMakerBlock._create_agent_function_signature(
                     sink_node, links
                 )
             )
+            return_tool_functions.append(tool_func)
         else:
-            return_tool_functions.append(
+            tool_func = (
                 await SmartDecisionMakerBlock._create_block_function_signature(
                     sink_node, links
                 )
             )
+            return_tool_functions.append(tool_func)

         return return_tool_functions

+    async def _attempt_llm_call_with_validation(
+        self,
+        credentials: llm.APIKeyCredentials,
+        input_data: Input,
+        current_prompt: list[dict],
+        tool_functions: list[dict[str, Any]],
+    ):
+        """
+        Attempt a single LLM call with tool validation.
+
+        Returns the response if successful, raises ValueError if validation fails.
+        """
+        resp = await llm.llm_call(
+            credentials=credentials,
+            llm_model=input_data.model,
+            prompt=current_prompt,
+            max_tokens=input_data.max_tokens,
+            tools=tool_functions,
+            ollama_host=input_data.ollama_host,
+            parallel_tool_calls=input_data.multiple_tool_calls,
+        )
+
+        # Track LLM usage stats per call
+        self.merge_stats(
+            NodeExecutionStats(
+                input_token_count=resp.prompt_tokens,
+                output_token_count=resp.completion_tokens,
+                llm_call_count=1,
+            )
+        )
+
+        if not resp.tool_calls:
+            return resp
+        validation_errors_list: list[str] = []
+        for tool_call in resp.tool_calls:
+            tool_name = tool_call.function.name
+            try:
+                tool_args = json.loads(tool_call.function.arguments)
+            except Exception as e:
+                validation_errors_list.append(
+                    f"Tool call '{tool_name}' has invalid JSON arguments: {e}"
+                )
+                continue
+
+            # Find the tool definition to get the expected arguments
+            tool_def = next(
+                (
+                    tool
+                    for tool in tool_functions
+                    if tool["function"]["name"] == tool_name
+                ),
+                None,
+            )
+            if tool_def is None and len(tool_functions) == 1:
+                tool_def = tool_functions[0]
+
+            # Get parameters schema from tool definition
+            if (
+                tool_def
+                and "function" in tool_def
+                and "parameters" in tool_def["function"]
+            ):
+                parameters = tool_def["function"]["parameters"]
+                expected_args = parameters.get("properties", {})
+                required_params = set(parameters.get("required", []))
+            else:
+                expected_args = {arg: {} for arg in tool_args.keys()}
+                required_params = set()
+
+            # Validate tool call arguments
+            provided_args = set(tool_args.keys())
+            expected_args_set = set(expected_args.keys())
+
+            # Check for unexpected arguments (typos)
+            unexpected_args = provided_args - expected_args_set
+            # Only check for missing REQUIRED parameters
+            missing_required_args = required_params - provided_args
+
+            if unexpected_args or missing_required_args:
+                error_msg = f"Tool call '{tool_name}' has parameter errors:"
+                if unexpected_args:
+                    error_msg += f" Unknown parameters: {sorted(unexpected_args)}."
+                if missing_required_args:
+                    error_msg += f" Missing required parameters: {sorted(missing_required_args)}."
+                error_msg += f" Expected parameters: {sorted(expected_args_set)}."
+                if required_params:
+                    error_msg += f" Required parameters: {sorted(required_params)}."
+                validation_errors_list.append(error_msg)

+        if validation_errors_list:
+            raise ValueError("; ".join(validation_errors_list))
+
+        return resp
+
     async def run(
         self,
         input_data: Input,
@@ -466,27 +601,19 @@ class SmartDecisionMakerBlock(Block):
         if pending_tool_calls and input_data.last_tool_output is None:
             raise ValueError(f"Tool call requires an output for {pending_tool_calls}")

+        # Only assign the last tool output to the first pending tool call
         tool_output = []
         if pending_tool_calls and input_data.last_tool_output is not None:
+            # Get the first pending tool call ID
+            first_call_id = next(iter(pending_tool_calls.keys()))
+            tool_output.append(
+                _create_tool_response(first_call_id, input_data.last_tool_output)
+            )

             # Add tool output to prompt right away
             prompt.extend(tool_output)

+            # Check if there are still pending tool calls after handling the first one
+            remaining_pending_calls = get_pending_tool_calls(prompt)

+            # If there are still pending tool calls, yield the conversation and return early
+            if remaining_pending_calls:
+                yield "conversations", prompt
+                return

         # Fallback on adding tool output in the conversation history as user prompt.
         elif input_data.last_tool_output:
             logger.error(
                 f"[SmartDecisionMakerBlock-node_exec_id={node_exec_id}] "
@@ -519,25 +646,33 @@ class SmartDecisionMakerBlock(Block):
         ):
             prompt.append({"role": "user", "content": prefix + input_data.prompt})

-        response = await llm.llm_call(
-            credentials=credentials,
-            llm_model=input_data.model,
-            prompt=prompt,
-            json_format=False,
-            max_tokens=input_data.max_tokens,
-            tools=tool_functions,
-            ollama_host=input_data.ollama_host,
-            parallel_tool_calls=input_data.multiple_tool_calls,
-        )
+        current_prompt = list(prompt)
+        max_attempts = max(1, int(input_data.retry))
+        response = None

-        # Track LLM usage stats
-        self.merge_stats(
-            NodeExecutionStats(
-                input_token_count=response.prompt_tokens,
-                output_token_count=response.completion_tokens,
-                llm_call_count=1,
-            )
-        )
+        last_error = None
+        for attempt in range(max_attempts):
+            try:
+                response = await self._attempt_llm_call_with_validation(
+                    credentials, input_data, current_prompt, tool_functions
+                )
+                break
+
+            except ValueError as e:
+                last_error = e
+                error_feedback = (
+                    "Your tool call had parameter errors. Please fix the following issues and try again:\n"
+                    + f"- {str(e)}\n"
+                    + "\nPlease make sure to use the exact parameter names as specified in the function schema."
+                )
+                current_prompt = list(current_prompt) + [
+                    {"role": "user", "content": error_feedback}
+                ]
+
+        if response is None:
+            raise last_error or ValueError(
+                "Failed to get valid response after all retry attempts"
+            )

         if not response.tool_calls:
             yield "finished", response.response
@@ -547,7 +682,6 @@ class SmartDecisionMakerBlock(Block):
             tool_name = tool_call.function.name
             tool_args = json.loads(tool_call.function.arguments)

-            # Find the tool definition to get the expected arguments
             tool_def = next(
                 (
                     tool
@@ -556,7 +690,6 @@ class SmartDecisionMakerBlock(Block):
                 ),
                 None,
             )
-
             if (
                 tool_def
                 and "function" in tool_def
@@ -564,20 +697,38 @@ class SmartDecisionMakerBlock(Block):
             ):
                 expected_args = tool_def["function"]["parameters"].get("properties", {})
             else:
-                expected_args = tool_args.keys()
+                expected_args = {arg: {} for arg in tool_args.keys()}

-            # Yield provided arguments and None for missing ones
-            for arg_name in expected_args:
-                if arg_name in tool_args:
-                    yield f"tools_^_{tool_name}_~_{arg_name}", tool_args[arg_name]
-                else:
-                    yield f"tools_^_{tool_name}_~_{arg_name}", None
+            # Get field mapping from tool definition
+            field_mapping = (
+                tool_def.get("function", {}).get("_field_mapping", {})
+                if tool_def
+                else {}
+            )
+
+            for clean_arg_name in expected_args:
+                # arg_name is now always the cleaned field name (for Anthropic API compliance)
+                # Get the original field name from field mapping for proper emit key generation
+                original_field_name = field_mapping.get(clean_arg_name, clean_arg_name)
+                arg_value = tool_args.get(clean_arg_name)
+
+                sanitized_tool_name = self.cleanup(tool_name)
+                sanitized_arg_name = self.cleanup(original_field_name)
+                emit_key = f"tools_^_{sanitized_tool_name}_~_{sanitized_arg_name}"
+
+                logger.debug(
+                    "[SmartDecisionMakerBlock|geid:%s|neid:%s] emit %s",
+                    graph_exec_id,
+                    node_exec_id,
+                    emit_key,
+                )
+                yield emit_key, arg_value

         # Add reasoning to conversation history if available
         if response.reasoning:
             prompt.append(
                 {"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
             )

-        prompt.append(response.raw_response)
+        prompt.append(_convert_raw_response_to_dict(response.raw_response))

         yield "conversations", prompt
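[Editor's note] A worked example of the validation-and-retry behavior implemented above, using the set arithmetic from _attempt_llm_call_with_validation; the values come from the parameter-validation test later in this diff:

expected_args = {"query": {}, "max_keyword_difficulty": {}, "optional_param": {}}
required_params = {"query", "max_keyword_difficulty"}
tool_args = {"query": "test", "maximum_keyword_difficulty": 50}  # note the typo

provided_args = set(tool_args)
unexpected_args = provided_args - set(expected_args)
missing_required_args = required_params - provided_args

assert unexpected_args == {"maximum_keyword_difficulty"}
assert missing_required_args == {"max_keyword_difficulty"}
# Both sets are non-empty, so the call raises ValueError; run() appends the
# error text as a user message and retries up to max(1, input_data.retry)
# times before re-raising. Missing *optional* parameters never trigger this.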
@@ -1,6 +1,7 @@
 import logging
 import signal
 import threading
+import warnings
 from contextlib import contextmanager
 from enum import Enum

@@ -26,6 +27,13 @@ from backend.sdk import (
     SchemaField,
 )

+# Suppress false positive cleanup warning of litellm (a dependency of stagehand)
+warnings.filterwarnings(
+    "ignore",
+    message="coroutine 'close_litellm_async_clients' was never awaited",
+    category=RuntimeWarning,
+)
+
 # Store the original method
 original_register_signal_handlers = stagehand.main.Stagehand._register_signal_handlers
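[Editor's note] warnings.filterwarnings matches `message` as a regex against the start of the warning text, so the filter above suppresses only the litellm coroutine warning and nothing else. A self-contained check of that behavior:

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warnings.filterwarnings(
        "ignore",
        message="coroutine 'close_litellm_async_clients' was never awaited",
        category=RuntimeWarning,
    )
    warnings.warn("coroutine 'close_litellm_async_clients' was never awaited", RuntimeWarning)
    warnings.warn("some other RuntimeWarning", RuntimeWarning)

assert len(caught) == 1  # only the unrelated warning is recorded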
@@ -19,7 +19,7 @@ async def test_block_ids_valid(block: Type[Block]):
     # Skip list for blocks with known invalid UUIDs
     skip_blocks = {
         "GetWeatherInformationBlock",
-        "CodeExecutionBlock",
+        "ExecuteCodeBlock",
         "CountdownTimerBlock",
         "TwitterGetListTweetsBlock",
         "TwitterRemoveListMemberBlock",
@@ -0,0 +1,269 @@
"""
Test security fixes for various DoS vulnerabilities.
"""

import asyncio
from unittest.mock import patch

import pytest

from backend.blocks.code_extraction_block import CodeExtractionBlock
from backend.blocks.iteration import StepThroughItemsBlock
from backend.blocks.llm import AITextSummarizerBlock
from backend.blocks.text import ExtractTextInformationBlock
from backend.blocks.xml_parser import XMLParserBlock
from backend.util.file import store_media_file
from backend.util.type import MediaFileType


class TestCodeExtractionBlockSecurity:
    """Test ReDoS fixes in CodeExtractionBlock."""

    async def test_redos_protection(self):
        """Test that the regex patterns don't cause ReDoS."""
        block = CodeExtractionBlock()

        # Test with input that would previously cause ReDoS
        malicious_input = "```python" + " " * 10000  # Large spaces

        result = []
        async for output_name, output_data in block.run(
            CodeExtractionBlock.Input(text=malicious_input)
        ):
            result.append((output_name, output_data))

        # Should complete without hanging
        assert len(result) >= 1
        assert any(name == "remaining_text" for name, _ in result)


class TestAITextSummarizerBlockSecurity:
    """Test memory exhaustion fixes in AITextSummarizerBlock."""

    def test_split_text_limits(self):
        """Test that _split_text has proper limits."""
        # Test text size limit
        large_text = "a" * 2_000_000  # 2MB text
        result = AITextSummarizerBlock._split_text(large_text, 1000, 100)

        # Should be truncated to 1MB
        total_chars = sum(len(chunk) for chunk in result)
        assert total_chars <= 1_000_000 + 1000  # Allow for chunk boundary

        # Test chunk count limit
        result = AITextSummarizerBlock._split_text("word " * 10000, 10, 9)
        assert len(result) <= 100  # MAX_CHUNKS limit

        # Test parameter validation
        result = AITextSummarizerBlock._split_text(
            "test", 10, 15
        )  # overlap > max_tokens
        assert len(result) >= 1  # Should still work


class TestExtractTextInformationBlockSecurity:
    """Test ReDoS and memory exhaustion fixes in ExtractTextInformationBlock."""

    async def test_text_size_limits(self):
        """Test text size limits."""
        block = ExtractTextInformationBlock()

        # Test with large input
        large_text = "a" * 2_000_000  # 2MB

        results = []
        async for output_name, output_data in block.run(
            ExtractTextInformationBlock.Input(
                text=large_text, pattern=r"a+", find_all=True, group=0
            )
        ):
            results.append((output_name, output_data))

        # Should complete and have limits applied
        matched_results = [r for name, r in results if name == "matched_results"]
        if matched_results:
            assert len(matched_results[0]) <= 1000  # MAX_MATCHES limit

    async def test_dangerous_pattern_timeout(self):
        """Test timeout protection for dangerous patterns."""
        block = ExtractTextInformationBlock()

        # Test with potentially dangerous lookahead pattern
        test_input = "a" * 1000

        # This should complete quickly due to timeout protection
        start_time = asyncio.get_event_loop().time()
        results = []
        async for output_name, output_data in block.run(
            ExtractTextInformationBlock.Input(
                text=test_input, pattern=r"(?=.+)", find_all=True, group=0
            )
        ):
            results.append((output_name, output_data))

        end_time = asyncio.get_event_loop().time()
        # Should complete within reasonable time (much less than 5s timeout)
        assert (end_time - start_time) < 10

    async def test_redos_catastrophic_backtracking(self):
        """Test that ReDoS patterns with catastrophic backtracking are handled."""
        block = ExtractTextInformationBlock()

        # Pattern that causes catastrophic backtracking: (a+)+b
        # With input "aaaaaaaaaaaaaaaaaaaaaaaaaaaa" (no 'b'), this causes exponential time
        dangerous_pattern = r"(a+)+b"
        test_input = "a" * 30  # 30 'a's without a 'b' at the end

        # This should be handled by timeout protection or pattern detection
        start_time = asyncio.get_event_loop().time()
        results = []

        async for output_name, output_data in block.run(
            ExtractTextInformationBlock.Input(
                text=test_input, pattern=dangerous_pattern, find_all=True, group=0
            )
        ):
            results.append((output_name, output_data))

        end_time = asyncio.get_event_loop().time()
        elapsed = end_time - start_time

        # Should complete within timeout (6 seconds to be safe)
        # The current threading.Timer approach doesn't work, so this will likely fail
        # demonstrating the need for a fix
        assert elapsed < 6, f"Regex took {elapsed}s, timeout mechanism failed"

        # Should return empty results on timeout or no match
        matched_results = [r for name, r in results if name == "matched_results"]
        assert matched_results[0] == []  # No matches expected
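[Editor's note] Why (a+)+b is catastrophic: on "a"*n with no trailing "b", the regex engine tries essentially every way of splitting the run of a's between the inner a+ and the outer group before failing, roughly 2**(n-1) paths, so match time about doubles per extra character. A rough timing sketch with plain re (no timeout wrapper; an illustration, not the block's protection mechanism):

import re
import time

for n in (18, 20, 22):
    start = time.perf_counter()
    re.match(r"(a+)+b", "a" * n)  # always fails; backtracks ~2**(n-1) times
    print(n, f"{time.perf_counter() - start:.3f}s")  # time roughly 4x per +2 chars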
class TestStepThroughItemsBlockSecurity:
    """Test iteration limits in StepThroughItemsBlock."""

    async def test_item_count_limits(self):
        """Test maximum item count limits."""
        block = StepThroughItemsBlock()

        # Test with too many items
        large_list = list(range(20000))  # Exceeds MAX_ITEMS (10000)

        with pytest.raises(ValueError, match="Too many items"):
            async for _ in block.run(StepThroughItemsBlock.Input(items=large_list)):
                pass

    async def test_string_size_limits(self):
        """Test string input size limits."""
        block = StepThroughItemsBlock()

        # Test with large JSON string
        large_string = '["item"]' * 200000  # Large JSON string

        with pytest.raises(ValueError, match="Input too large"):
            async for _ in block.run(
                StepThroughItemsBlock.Input(items_str=large_string)
            ):
                pass

    async def test_normal_iteration_works(self):
        """Test that normal iteration still works."""
        block = StepThroughItemsBlock()

        results = []
        async for output_name, output_data in block.run(
            StepThroughItemsBlock.Input(items=[1, 2, 3])
        ):
            results.append((output_name, output_data))

        # Should have 6 outputs (item, key for each of 3 items)
        assert len(results) == 6
        items = [data for name, data in results if name == "item"]
        assert items == [1, 2, 3]


class TestXMLParserBlockSecurity:
    """Test XML size limits in XMLParserBlock."""

    async def test_xml_size_limits(self):
        """Test XML input size limits."""
        block = XMLParserBlock()

        # Test with large XML - need to exceed 10MB limit
        # Each "<item>data</item>" is 17 chars, need ~620K items for >10MB
        large_xml = "<root>" + "<item>data</item>" * 620000 + "</root>"

        with pytest.raises(ValueError, match="XML too large"):
            async for _ in block.run(XMLParserBlock.Input(input_xml=large_xml)):
                pass
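[Editor's note] A quick check of the fixture arithmetic in the XML test above:

n = 620_000
total_bytes = len("<root>") + n * len("<item>data</item>") + len("</root>")
assert total_bytes == 10_540_013        # 6 + 620,000 * 17 + 7
assert total_bytes > 10 * 1024 * 1024   # just past the 10,485,760-byte limit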
class TestStoreMediaFileSecurity:
    """Test file storage security limits."""

    @patch("backend.util.file.scan_content_safe")
    @patch("backend.util.file.get_cloud_storage_handler")
    async def test_file_size_limits(self, mock_cloud_storage, mock_scan):
        """Test file size limits."""
        # Mock cloud storage handler - get_cloud_storage_handler is async
        # but is_cloud_path and parse_cloud_path are sync methods
        from unittest.mock import MagicMock

        mock_handler = MagicMock()
        mock_handler.is_cloud_path.return_value = False

        # Make get_cloud_storage_handler an async function that returns the mock handler
        async def async_get_handler():
            return mock_handler

        mock_cloud_storage.side_effect = async_get_handler
        mock_scan.return_value = None

        # Test with large base64 content
        large_content = "a" * (200 * 1024 * 1024)  # 200MB
        large_data_uri = f"data:text/plain;base64,{large_content}"

        with pytest.raises(ValueError, match="File too large"):
            await store_media_file(
                graph_exec_id="test",
                file=MediaFileType(large_data_uri),
                user_id="test_user",
            )

    @patch("backend.util.file.Path")
    @patch("backend.util.file.scan_content_safe")
    @patch("backend.util.file.get_cloud_storage_handler")
    async def test_directory_size_limits(self, mock_cloud_storage, mock_scan, MockPath):
        """Test directory size limits."""
        from unittest.mock import MagicMock

        mock_handler = MagicMock()
        mock_handler.is_cloud_path.return_value = False

        async def async_get_handler():
            return mock_handler

        mock_cloud_storage.side_effect = async_get_handler
        mock_scan.return_value = None

        # Create mock path instance for the execution directory
        mock_path_instance = MagicMock()
        mock_path_instance.exists.return_value = True

        # Mock glob to return files that total > 1GB
        mock_file = MagicMock()
        mock_file.is_file.return_value = True
        mock_file.stat.return_value.st_size = 2 * 1024 * 1024 * 1024  # 2GB
        mock_path_instance.glob.return_value = [mock_file]

        # Make Path() return our mock
        MockPath.return_value = mock_path_instance

        # Should raise an error when directory size exceeds limit
        with pytest.raises(ValueError, match="Disk usage limit exceeded"):
            await store_media_file(
                graph_exec_id="test",
                file=MediaFileType(
                    "data:text/plain;base64,dGVzdA=="
                ),  # Small test file
                user_id="test_user",
            )
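[Editor's note] The tests above make a plain MagicMock awaitable by assigning an async function to side_effect; AsyncMock is the more direct spelling and is what the SmartDecisionMaker tests below use. A minimal equivalent sketch:

import asyncio
from unittest.mock import AsyncMock, MagicMock

handler = MagicMock()
handler.is_cloud_path.return_value = False

get_cloud_storage_handler = AsyncMock(return_value=handler)  # awaitable factory

async def main():
    h = await get_cloud_storage_handler()  # await works without side_effect tricks
    assert h.is_cloud_path("media/test.png") is False

asyncio.run(main())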
@@ -30,7 +30,6 @@ class TestLLMStatsTracking:
             credentials=llm.TEST_CREDENTIALS,
             llm_model=llm.LlmModel.GPT4O,
             prompt=[{"role": "user", "content": "Hello"}],
-            json_format=False,
             max_tokens=100,
         )

@@ -42,6 +41,8 @@ class TestLLMStatsTracking:
     @pytest.mark.asyncio
     async def test_ai_structured_response_block_tracks_stats(self):
         """Test that AIStructuredResponseGeneratorBlock correctly tracks stats."""
+        from unittest.mock import patch
+
         import backend.blocks.llm as llm

         block = llm.AIStructuredResponseGeneratorBlock()
@@ -51,7 +52,7 @@ class TestLLMStatsTracking:
             return llm.LLMResponse(
                 raw_response="",
                 prompt=[],
-                response='{"key1": "value1", "key2": "value2"}',
+                response='<json_output id="test123456">{"key1": "value1", "key2": "value2"}</json_output>',
                 tool_calls=None,
                 prompt_tokens=15,
                 completion_tokens=25,
@@ -69,10 +70,12 @@ class TestLLMStatsTracking:
         )

         outputs = {}
-        async for output_name, output_data in block.run(
-            input_data, credentials=llm.TEST_CREDENTIALS
-        ):
-            outputs[output_name] = output_data
+        # Mock secrets.token_hex to return consistent ID
+        with patch("secrets.token_hex", return_value="test123456"):
+            async for output_name, output_data in block.run(
+                input_data, credentials=llm.TEST_CREDENTIALS
+            ):
+                outputs[output_name] = output_data

         # Check stats
         assert block.execution_stats.input_token_count == 15
@@ -143,7 +146,7 @@ class TestLLMStatsTracking:
             return llm.LLMResponse(
                 raw_response="",
                 prompt=[],
-                response='{"wrong": "format"}',
+                response='<json_output id="test123456">{"wrong": "format"}</json_output>',
                 tool_calls=None,
                 prompt_tokens=10,
                 completion_tokens=15,
@@ -154,7 +157,7 @@ class TestLLMStatsTracking:
             return llm.LLMResponse(
                 raw_response="",
                 prompt=[],
-                response='{"key1": "value1", "key2": "value2"}',
+                response='<json_output id="test123456">{"key1": "value1", "key2": "value2"}</json_output>',
                 tool_calls=None,
                 prompt_tokens=20,
                 completion_tokens=25,
@@ -173,10 +176,12 @@ class TestLLMStatsTracking:
         )

         outputs = {}
-        async for output_name, output_data in block.run(
-            input_data, credentials=llm.TEST_CREDENTIALS
-        ):
-            outputs[output_name] = output_data
+        # Mock secrets.token_hex to return consistent ID
+        with patch("secrets.token_hex", return_value="test123456"):
+            async for output_name, output_data in block.run(
+                input_data, credentials=llm.TEST_CREDENTIALS
+            ):
+                outputs[output_name] = output_data

         # Check stats - should accumulate both calls
         # For 2 attempts: attempt 1 (failed) + attempt 2 (success) = 2 total
@@ -269,7 +274,8 @@ class TestLLMStatsTracking:
         mock_response.choices = [
             MagicMock(
                 message=MagicMock(
-                    content='{"summary": "Test chunk summary"}', tool_calls=None
+                    content='<json_output id="test123456">{"summary": "Test chunk summary"}</json_output>',
+                    tool_calls=None,
                 )
             )
         ]
@@ -277,7 +283,7 @@ class TestLLMStatsTracking:
         mock_response.choices = [
             MagicMock(
                 message=MagicMock(
-                    content='{"final_summary": "Test final summary"}',
+                    content='<json_output id="test123456">{"final_summary": "Test final summary"}</json_output>',
                     tool_calls=None,
                 )
             )
@@ -298,11 +304,13 @@ class TestLLMStatsTracking:
             max_tokens=1000,  # Large enough to avoid chunking
         )

-        outputs = {}
-        async for output_name, output_data in block.run(
-            input_data, credentials=llm.TEST_CREDENTIALS
-        ):
-            outputs[output_name] = output_data
+        # Mock secrets.token_hex to return consistent ID
+        with patch("secrets.token_hex", return_value="test123456"):
+            outputs = {}
+            async for output_name, output_data in block.run(
+                input_data, credentials=llm.TEST_CREDENTIALS
+            ):
+                outputs[output_name] = output_data

         print(f"Actual calls made: {call_count}")
         print(f"Block stats: {block.execution_stats}")
@@ -354,7 +362,7 @@ class TestLLMStatsTracking:
         assert block.execution_stats.llm_call_count == 1

         # Check output
-        assert outputs["response"] == {"response": "AI response to conversation"}
+        assert outputs["response"] == "AI response to conversation"

     @pytest.mark.asyncio
     async def test_ai_list_generator_with_retries(self):
@@ -457,7 +465,7 @@ class TestLLMStatsTracking:
             return llm.LLMResponse(
                 raw_response="",
                 prompt=[],
-                response='{"result": "test"}',
+                response='<json_output id="test123456">{"result": "test"}</json_output>',
                 tool_calls=None,
                 prompt_tokens=10,
                 completion_tokens=20,
@@ -476,10 +484,12 @@ class TestLLMStatsTracking:

         # Run the block
         outputs = {}
-        async for output_name, output_data in block.run(
-            input_data, credentials=llm.TEST_CREDENTIALS
-        ):
-            outputs[output_name] = output_data
+        # Mock secrets.token_hex to return consistent ID
+        with patch("secrets.token_hex", return_value="test123456"):
+            async for output_name, output_data in block.run(
+                input_data, credentials=llm.TEST_CREDENTIALS
+            ):
+                outputs[output_name] = output_data

         # Block finished - now grab and assert stats
         assert block.execution_stats is not None
@@ -216,8 +216,17 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
}
|
||||
|
||||
# Mock the _create_function_signature method to avoid database calls
|
||||
with patch("backend.blocks.llm.llm_call", return_value=mock_response), patch.object(
|
||||
SmartDecisionMakerBlock, "_create_function_signature", return_value=[]
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
|
||||
# Create test input
|
||||
@@ -249,3 +258,471 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
# Verify outputs
|
||||
assert "finished" in outputs # Should have finished since no tool calls
|
||||
assert outputs["finished"] == "I need to think about this."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_parameter_validation():
|
||||
"""Test that SmartDecisionMakerBlock correctly validates tool call parameters."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Mock tool functions with specific parameter schema
|
||||
mock_tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_keywords",
|
||||
"description": "Search for keywords with difficulty filtering",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query"},
|
||||
"max_keyword_difficulty": {
|
||||
"type": "integer",
|
||||
"description": "Maximum keyword difficulty (required)",
|
||||
},
|
||||
"optional_param": {
|
||||
"type": "string",
|
||||
"description": "Optional parameter with default",
|
||||
"default": "default_value",
|
||||
},
|
||||
},
|
||||
"required": ["query", "max_keyword_difficulty"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Test case 1: Tool call with TYPO in parameter name (should retry and eventually fail)
|
||||
mock_tool_call_with_typo = MagicMock()
|
||||
mock_tool_call_with_typo.function.name = "search_keywords"
|
||||
mock_tool_call_with_typo.function.arguments = '{"query": "test", "maximum_keyword_difficulty": 50}' # TYPO: maximum instead of max
|
||||
|
||||
mock_response_with_typo = MagicMock()
|
||||
mock_response_with_typo.response = None
|
||||
mock_response_with_typo.tool_calls = [mock_tool_call_with_typo]
|
||||
mock_response_with_typo.prompt_tokens = 50
|
||||
mock_response_with_typo.completion_tokens = 25
|
||||
mock_response_with_typo.reasoning = None
|
||||
mock_response_with_typo.raw_response = {"role": "assistant", "content": None}
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_with_typo,
|
||||
) as mock_llm_call, patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2, # Set retry to 2 for testing
|
||||
)
|
||||
|
||||
# Should raise ValueError after retries due to typo'd parameter name
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph-id",
|
||||
node_id="test-node-id",
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Verify error message contains details about the typo
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Tool call 'search_keywords' has parameter errors" in error_msg
|
||||
assert "Unknown parameters: ['maximum_keyword_difficulty']" in error_msg
|
||||
|
||||
# Verify that LLM was called the expected number of times (retries)
|
||||
assert mock_llm_call.call_count == 2 # Should retry based on input_data.retry
|
||||
|
||||
# Test case 2: Tool call missing REQUIRED parameter (should raise ValueError)
|
||||
mock_tool_call_missing_required = MagicMock()
|
||||
mock_tool_call_missing_required.function.name = "search_keywords"
|
||||
mock_tool_call_missing_required.function.arguments = (
|
||||
'{"query": "test"}' # Missing required max_keyword_difficulty
|
||||
)
|
||||
|
||||
    mock_response_missing_required = MagicMock()
    mock_response_missing_required.response = None
    mock_response_missing_required.tool_calls = [mock_tool_call_missing_required]
    mock_response_missing_required.prompt_tokens = 50
    mock_response_missing_required.completion_tokens = 25
    mock_response_missing_required.reasoning = None
    mock_response_missing_required.raw_response = {"role": "assistant", "content": None}

    from unittest.mock import AsyncMock

    with patch(
        "backend.blocks.llm.llm_call",
        new_callable=AsyncMock,
        return_value=mock_response_missing_required,
    ), patch.object(
        SmartDecisionMakerBlock,
        "_create_function_signature",
        new_callable=AsyncMock,
        return_value=mock_tool_functions,
    ):
        input_data = SmartDecisionMakerBlock.Input(
            prompt="Search for keywords",
            model=llm_module.LlmModel.GPT4O,
            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
        )

        # Should raise ValueError due to missing required parameter
        with pytest.raises(ValueError) as exc_info:
            outputs = {}
            async for output_name, output_data in block.run(
                input_data,
                credentials=llm_module.TEST_CREDENTIALS,
                graph_id="test-graph-id",
                node_id="test-node-id",
                graph_exec_id="test-exec-id",
                node_exec_id="test-node-exec-id",
                user_id="test-user-id",
            ):
                outputs[output_name] = output_data

        error_msg = str(exc_info.value)
        assert "Tool call 'search_keywords' has parameter errors" in error_msg
        assert "Missing required parameters: ['max_keyword_difficulty']" in error_msg

    # Test case 3: Valid tool call with OPTIONAL parameter missing (should succeed)
    mock_tool_call_valid = MagicMock()
    mock_tool_call_valid.function.name = "search_keywords"
    mock_tool_call_valid.function.arguments = '{"query": "test", "max_keyword_difficulty": 50}'  # optional_param missing, but that's OK

    mock_response_valid = MagicMock()
    mock_response_valid.response = None
    mock_response_valid.tool_calls = [mock_tool_call_valid]
    mock_response_valid.prompt_tokens = 50
    mock_response_valid.completion_tokens = 25
    mock_response_valid.reasoning = None
    mock_response_valid.raw_response = {"role": "assistant", "content": None}

    from unittest.mock import AsyncMock

    with patch(
        "backend.blocks.llm.llm_call",
        new_callable=AsyncMock,
        return_value=mock_response_valid,
    ), patch.object(
        SmartDecisionMakerBlock,
        "_create_function_signature",
        new_callable=AsyncMock,
        return_value=mock_tool_functions,
    ):
        input_data = SmartDecisionMakerBlock.Input(
            prompt="Search for keywords",
            model=llm_module.LlmModel.GPT4O,
            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
        )

        # Should succeed - optional parameter missing is OK
        outputs = {}
        async for output_name, output_data in block.run(
            input_data,
            credentials=llm_module.TEST_CREDENTIALS,
            graph_id="test-graph-id",
            node_id="test-node-id",
            graph_exec_id="test-exec-id",
            node_exec_id="test-node-exec-id",
            user_id="test-user-id",
        ):
            outputs[output_name] = output_data

        # Verify tool outputs were generated correctly
        assert "tools_^_search_keywords_~_query" in outputs
        assert outputs["tools_^_search_keywords_~_query"] == "test"
        assert "tools_^_search_keywords_~_max_keyword_difficulty" in outputs
        assert outputs["tools_^_search_keywords_~_max_keyword_difficulty"] == 50
        # Optional parameter should be None when not provided
        assert "tools_^_search_keywords_~_optional_param" in outputs
        assert outputs["tools_^_search_keywords_~_optional_param"] is None

    # Test case 4: Valid tool call with ALL parameters (should succeed)
    mock_tool_call_all_params = MagicMock()
    mock_tool_call_all_params.function.name = "search_keywords"
    mock_tool_call_all_params.function.arguments = '{"query": "test", "max_keyword_difficulty": 50, "optional_param": "custom_value"}'

    mock_response_all_params = MagicMock()
    mock_response_all_params.response = None
    mock_response_all_params.tool_calls = [mock_tool_call_all_params]
    mock_response_all_params.prompt_tokens = 50
    mock_response_all_params.completion_tokens = 25
    mock_response_all_params.reasoning = None
    mock_response_all_params.raw_response = {"role": "assistant", "content": None}

    from unittest.mock import AsyncMock

    with patch(
        "backend.blocks.llm.llm_call",
        new_callable=AsyncMock,
        return_value=mock_response_all_params,
    ), patch.object(
        SmartDecisionMakerBlock,
        "_create_function_signature",
        new_callable=AsyncMock,
        return_value=mock_tool_functions,
    ):
        input_data = SmartDecisionMakerBlock.Input(
            prompt="Search for keywords",
            model=llm_module.LlmModel.GPT4O,
            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
        )

        # Should succeed with all parameters
        outputs = {}
        async for output_name, output_data in block.run(
            input_data,
            credentials=llm_module.TEST_CREDENTIALS,
            graph_id="test-graph-id",
            node_id="test-node-id",
            graph_exec_id="test-exec-id",
            node_exec_id="test-node-exec-id",
            user_id="test-user-id",
        ):
            outputs[output_name] = output_data

        # Verify all tool outputs were generated correctly
        assert outputs["tools_^_search_keywords_~_query"] == "test"
        assert outputs["tools_^_search_keywords_~_max_keyword_difficulty"] == 50
        assert outputs["tools_^_search_keywords_~_optional_param"] == "custom_value"

@pytest.mark.asyncio
async def test_smart_decision_maker_raw_response_conversion():
    """Test that SmartDecisionMaker correctly handles different raw_response types with the retry mechanism."""
    from unittest.mock import MagicMock, patch

    import backend.blocks.llm as llm_module
    from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock

    block = SmartDecisionMakerBlock()

    # Mock tool functions
    mock_tool_functions = [
        {
            "type": "function",
            "function": {
                "name": "test_tool",
                "parameters": {
                    "type": "object",
                    "properties": {"param": {"type": "string"}},
                    "required": ["param"],
                },
            },
        }
    ]

    # Test case 1: Simulate the ChatCompletionMessage raw_response that caused the original error
    class MockChatCompletionMessage:
        """Simulate OpenAI's ChatCompletionMessage object, which lacks a .get() method."""

        def __init__(self, role, content, tool_calls=None):
            self.role = role
            self.content = content
            self.tool_calls = tool_calls or []

        # This is what caused the error - no .get() method:
        # def get(self, key, default=None):  # Intentionally missing
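    # Editor's note: the fix exercised by this test normalizes raw_response before it
    # is appended to the conversation history. A minimal sketch of such a helper is
    # shown here for clarity; the name and exact shape are assumptions, not the
    # block's actual implementation:
    #
    #     def _to_message_dict(raw) -> dict:
    #         if isinstance(raw, dict):  # some providers/tests already return dicts
    #             return raw
    #         if isinstance(raw, str):  # Ollama returns a plain string
    #             return {"role": "assistant", "content": raw}
    #         # Object responses (e.g. ChatCompletionMessage) expose attributes rather
    #         # than .get(), so read them explicitly instead of treating them as dicts.
    #         return {
    #             "role": getattr(raw, "role", "assistant"),
    #             "content": getattr(raw, "content", None),
    #         }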

    # First response: has invalid parameter name (triggers retry)
    mock_tool_call_invalid = MagicMock()
    mock_tool_call_invalid.function.name = "test_tool"
    mock_tool_call_invalid.function.arguments = (
        '{"wrong_param": "test_value"}'  # Invalid parameter name
    )

    mock_response_retry = MagicMock()
    mock_response_retry.response = None
    mock_response_retry.tool_calls = [mock_tool_call_invalid]
    mock_response_retry.prompt_tokens = 50
    mock_response_retry.completion_tokens = 25
    mock_response_retry.reasoning = None
    # This would cause the original error without our fix
    mock_response_retry.raw_response = MockChatCompletionMessage(
        role="assistant", content=None, tool_calls=[mock_tool_call_invalid]
    )

    # Second response: successful (correct parameter name)
    mock_tool_call_valid = MagicMock()
    mock_tool_call_valid.function.name = "test_tool"
    mock_tool_call_valid.function.arguments = (
        '{"param": "test_value"}'  # Correct parameter name
    )

    mock_response_success = MagicMock()
    mock_response_success.response = None
    mock_response_success.tool_calls = [mock_tool_call_valid]
    mock_response_success.prompt_tokens = 50
    mock_response_success.completion_tokens = 25
    mock_response_success.reasoning = None
    mock_response_success.raw_response = MockChatCompletionMessage(
        role="assistant", content=None, tool_calls=[mock_tool_call_valid]
    )

    # Mock llm_call to return different responses on different calls
    from unittest.mock import AsyncMock

    with patch(
        "backend.blocks.llm.llm_call", new_callable=AsyncMock
    ) as mock_llm_call, patch.object(
        SmartDecisionMakerBlock,
        "_create_function_signature",
        new_callable=AsyncMock,
        return_value=mock_tool_functions,
    ):
        # First call returns a response that triggers a retry due to the validation error;
        # second call returns the successful response
        mock_llm_call.side_effect = [mock_response_retry, mock_response_success]

        input_data = SmartDecisionMakerBlock.Input(
            prompt="Test prompt",
            model=llm_module.LlmModel.GPT4O,
            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
            retry=2,
        )

        # Should succeed after retry, demonstrating our helper function works
        outputs = {}
        async for output_name, output_data in block.run(
            input_data,
            credentials=llm_module.TEST_CREDENTIALS,
            graph_id="test-graph-id",
            node_id="test-node-id",
            graph_exec_id="test-exec-id",
            node_exec_id="test-node-exec-id",
            user_id="test-user-id",
        ):
            outputs[output_name] = output_data

        # Verify the tool output was generated successfully
        assert "tools_^_test_tool_~_param" in outputs
        assert outputs["tools_^_test_tool_~_param"] == "test_value"

        # Verify conversation history was properly maintained
        assert "conversations" in outputs
        conversations = outputs["conversations"]
        assert len(conversations) > 0

        # The conversations should contain properly converted raw_response objects as dicts.
        # This would have failed with the original bug, since ChatCompletionMessage has no .get()
        for msg in conversations:
            assert isinstance(msg, dict), f"Expected dict, got {type(msg)}"
            if msg.get("role") == "assistant":
                # Should have been converted from ChatCompletionMessage to dict
                assert "role" in msg

        # Verify LLM was called twice (initial + 1 retry)
        assert mock_llm_call.call_count == 2

    # Test case 2: Test with different raw_response types (Ollama string, dict)
    # Test Ollama string response
    mock_response_ollama = MagicMock()
    mock_response_ollama.response = "I'll help you with that."
    mock_response_ollama.tool_calls = None
    mock_response_ollama.prompt_tokens = 30
    mock_response_ollama.completion_tokens = 15
    mock_response_ollama.reasoning = None
    mock_response_ollama.raw_response = (
        "I'll help you with that."  # Ollama returns a plain string
    )

    from unittest.mock import AsyncMock

    with patch(
        "backend.blocks.llm.llm_call",
        new_callable=AsyncMock,
        return_value=mock_response_ollama,
    ), patch.object(
        SmartDecisionMakerBlock,
        "_create_function_signature",
        new_callable=AsyncMock,
        return_value=[],  # No tools for this test
    ):
        input_data = SmartDecisionMakerBlock.Input(
            prompt="Simple prompt",
            model=llm_module.LlmModel.GPT4O,
            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
        )

        outputs = {}
        async for output_name, output_data in block.run(
            input_data,
            credentials=llm_module.TEST_CREDENTIALS,
            graph_id="test-graph-id",
            node_id="test-node-id",
            graph_exec_id="test-exec-id",
            node_exec_id="test-node-exec-id",
            user_id="test-user-id",
        ):
            outputs[output_name] = output_data

        # Should finish since there are no tool calls
        assert "finished" in outputs
        assert outputs["finished"] == "I'll help you with that."

    # Test case 3: Test with dict raw_response (some providers/tests)
    mock_response_dict = MagicMock()
    mock_response_dict.response = "Test response"
    mock_response_dict.tool_calls = None
    mock_response_dict.prompt_tokens = 25
    mock_response_dict.completion_tokens = 10
    mock_response_dict.reasoning = None
    mock_response_dict.raw_response = {
        "role": "assistant",
        "content": "Test response",
    }  # Dict format

    from unittest.mock import AsyncMock

    with patch(
        "backend.blocks.llm.llm_call",
        new_callable=AsyncMock,
        return_value=mock_response_dict,
    ), patch.object(
        SmartDecisionMakerBlock,
        "_create_function_signature",
        new_callable=AsyncMock,
        return_value=[],
    ):
        input_data = SmartDecisionMakerBlock.Input(
            prompt="Another test",
            model=llm_module.LlmModel.GPT4O,
            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
        )

        outputs = {}
        async for output_name, output_data in block.run(
            input_data,
            credentials=llm_module.TEST_CREDENTIALS,
            graph_id="test-graph-id",
            node_id="test-node-id",
            graph_exec_id="test-exec-id",
            node_exec_id="test-node-exec-id",
            user_id="test-user-id",
        ):
            outputs[output_name] = output_data

        assert "finished" in outputs
        assert outputs["finished"] == "Test response"

@@ -48,16 +48,24 @@ async def test_smart_decision_maker_handles_dynamic_dict_fields():
     assert "parameters" in signature["function"]
     assert "properties" in signature["function"]["parameters"]
 
-    # Check that dynamic fields are handled
+    # Check that dynamic fields are handled with original names
     properties = signature["function"]["parameters"]["properties"]
     assert len(properties) == 3  # Should have all three fields
 
-    # Each dynamic field should have proper schema
-    for prop_value in properties.values():
+    # Check that field names are cleaned (for Anthropic API compatibility)
+    assert "values___name" in properties
+    assert "values___age" in properties
+    assert "values___city" in properties
+
+    # Each dynamic field should have proper schema with descriptive text
+    for field_name, prop_value in properties.items():
         assert "type" in prop_value
         assert prop_value["type"] == "string"  # Dynamic fields get string type
         assert "description" in prop_value
-        assert "Dynamic value for" in prop_value["description"]
+        # Check that descriptions properly explain the dynamic field
+        if field_name == "values___name":
+            assert "Dictionary field 'name'" in prop_value["description"]
+            assert "values['name']" in prop_value["description"]
 
 
 @pytest.mark.asyncio
@@ -96,10 +104,18 @@ async def test_smart_decision_maker_handles_dynamic_list_fields():
     properties = signature["function"]["parameters"]["properties"]
     assert len(properties) == 2  # Should have both list items
 
-    # Each dynamic field should have proper schema
-    for prop_value in properties.values():
+    # Check that field names are cleaned (for Anthropic API compatibility)
+    assert "entries___0" in properties
+    assert "entries___1" in properties
+
+    # Each dynamic field should have proper schema with descriptive text
+    for field_name, prop_value in properties.items():
         assert prop_value["type"] == "string"
-        assert "Dynamic value for" in prop_value["description"]
         assert "description" in prop_value
+        # Check that descriptions properly explain the list field
+        if field_name == "entries___0":
+            assert "List item 0" in prop_value["description"]
+            assert "entries[0]" in prop_value["description"]
 
 
 @pytest.mark.asyncio

@@ -0,0 +1,553 @@
"""Comprehensive tests for SmartDecisionMakerBlock dynamic field handling."""

import json
from unittest.mock import AsyncMock, Mock, patch

import pytest

from backend.blocks.data_manipulation import AddToListBlock, CreateDictionaryBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.blocks.text import MatchTextPatternBlock
from backend.data.dynamic_fields import get_dynamic_field_description


@pytest.mark.asyncio
async def test_dynamic_field_description_generation():
    """Test that dynamic field descriptions are generated correctly."""
    # Test dictionary field description
    desc = get_dynamic_field_description("values_#_name")
    assert "Dictionary field 'name' for base field 'values'" in desc
    assert "values['name']" in desc

    # Test list field description
    desc = get_dynamic_field_description("items_$_0")
    assert "List item 0 for base field 'items'" in desc
    assert "items[0]" in desc

    # Test object field description
    desc = get_dynamic_field_description("user_@_email")
    assert "Object attribute 'email' for base field 'user'" in desc
    assert "user.email" in desc

    # Test regular field fallback
    desc = get_dynamic_field_description("regular_field")
    assert desc == "Value for regular_field"

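# Editor's note: a minimal sketch consistent with the assertions above; the real
# get_dynamic_field_description may differ. The "_#_", "_$_" and "_@_" delimiters
# are taken from the test inputs:
#
#     def get_dynamic_field_description(field_name: str) -> str:
#         for sep, fmt in (
#             ("_#_", "Dictionary field '{k}' for base field '{b}' ({b}['{k}'])"),
#             ("_$_", "List item {k} for base field '{b}' ({b}[{k}])"),
#             ("_@_", "Object attribute '{k}' for base field '{b}' ({b}.{k})"),
#         ):
#             if sep in field_name:
#                 b, k = field_name.split(sep, 1)
#                 return fmt.format(b=b, k=k)
#         return f"Value for {field_name}"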

@pytest.mark.asyncio
async def test_create_block_function_signature_with_dict_fields():
    """Test that function signatures are created correctly for dictionary dynamic fields."""
    block = SmartDecisionMakerBlock()

    # Create a mock node for CreateDictionaryBlock
    mock_node = Mock()
    mock_node.block = CreateDictionaryBlock()
    mock_node.block_id = CreateDictionaryBlock().id
    mock_node.input_default = {}

    # Create mock links with dynamic dictionary fields (source sanitized, sink original)
    mock_links = [
        Mock(
            source_name="tools_^_create_dict_~_values___name",  # Sanitized source
            sink_name="values_#_name",  # Original sink
            sink_id="dict_node_id",
            source_id="smart_decision_node_id",
        ),
        Mock(
            source_name="tools_^_create_dict_~_values___age",  # Sanitized source
            sink_name="values_#_age",  # Original sink
            sink_id="dict_node_id",
            source_id="smart_decision_node_id",
        ),
        Mock(
            source_name="tools_^_create_dict_~_values___email",  # Sanitized source
            sink_name="values_#_email",  # Original sink
            sink_id="dict_node_id",
            source_id="smart_decision_node_id",
        ),
    ]

    # Generate function signature
    signature = await block._create_block_function_signature(mock_node, mock_links)  # type: ignore

    # Verify the signature structure
    assert signature["type"] == "function"
    assert "function" in signature
    assert "parameters" in signature["function"]
    assert "properties" in signature["function"]["parameters"]

    # Check that dynamic fields are handled with original names
    properties = signature["function"]["parameters"]["properties"]
    assert len(properties) == 3

    # Check cleaned field names (for Anthropic API compatibility)
    assert "values___name" in properties
    assert "values___age" in properties
    assert "values___email" in properties

    # Check descriptions mention they are dictionary fields
    assert "Dictionary field" in properties["values___name"]["description"]
    assert "values['name']" in properties["values___name"]["description"]

    assert "Dictionary field" in properties["values___age"]["description"]
    assert "values['age']" in properties["values___age"]["description"]

    assert "Dictionary field" in properties["values___email"]["description"]
    assert "values['email']" in properties["values___email"]["description"]


@pytest.mark.asyncio
async def test_create_block_function_signature_with_list_fields():
    """Test that function signatures are created correctly for list dynamic fields."""
    block = SmartDecisionMakerBlock()

    # Create a mock node for AddToListBlock
    mock_node = Mock()
    mock_node.block = AddToListBlock()
    mock_node.block_id = AddToListBlock().id
    mock_node.input_default = {}

    # Create mock links with dynamic list fields
    mock_links = [
        Mock(
            source_name="tools_^_add_list_~_0",
            sink_name="entries_$_0",  # Dynamic list field
            sink_id="list_node_id",
            source_id="smart_decision_node_id",
        ),
        Mock(
            source_name="tools_^_add_list_~_1",
            sink_name="entries_$_1",  # Dynamic list field
            sink_id="list_node_id",
            source_id="smart_decision_node_id",
        ),
        Mock(
            source_name="tools_^_add_list_~_2",
            sink_name="entries_$_2",  # Dynamic list field
            sink_id="list_node_id",
            source_id="smart_decision_node_id",
        ),
    ]

    # Generate function signature
    signature = await block._create_block_function_signature(mock_node, mock_links)  # type: ignore

    # Verify the signature structure
    assert signature["type"] == "function"
    properties = signature["function"]["parameters"]["properties"]

    # Check cleaned field names (for Anthropic API compatibility)
    assert "entries___0" in properties
    assert "entries___1" in properties
    assert "entries___2" in properties

    # Check descriptions mention they are list items
    assert "List item 0" in properties["entries___0"]["description"]
    assert "entries[0]" in properties["entries___0"]["description"]

    assert "List item 1" in properties["entries___1"]["description"]
    assert "entries[1]" in properties["entries___1"]["description"]


@pytest.mark.asyncio
async def test_create_block_function_signature_with_object_fields():
    """Test that function signatures are created correctly for object dynamic fields."""
    block = SmartDecisionMakerBlock()

    # Create a mock node for MatchTextPatternBlock (simulating object fields)
    mock_node = Mock()
    mock_node.block = MatchTextPatternBlock()
    mock_node.block_id = MatchTextPatternBlock().id
    mock_node.input_default = {}

    # Create mock links with dynamic object fields
    mock_links = [
        Mock(
            source_name="tools_^_extract_~_user_name",
            sink_name="data_@_user_name",  # Dynamic object field
            sink_id="extract_node_id",
            source_id="smart_decision_node_id",
        ),
        Mock(
            source_name="tools_^_extract_~_user_email",
            sink_name="data_@_user_email",  # Dynamic object field
            sink_id="extract_node_id",
            source_id="smart_decision_node_id",
        ),
    ]

    # Generate function signature
    signature = await block._create_block_function_signature(mock_node, mock_links)  # type: ignore

    # Verify the signature structure
    properties = signature["function"]["parameters"]["properties"]

    # Check cleaned field names (for Anthropic API compatibility)
    assert "data___user_name" in properties
    assert "data___user_email" in properties

    # Check descriptions mention they are object attributes
    assert "Object attribute" in properties["data___user_name"]["description"]
    assert "data.user_name" in properties["data___user_name"]["description"]


@pytest.mark.asyncio
async def test_create_function_signature():
    """Test that the mapping between sanitized and original field names is built correctly."""
    block = SmartDecisionMakerBlock()

    # Mock the database client and connected nodes
    with patch(
        "backend.blocks.smart_decision_maker.get_database_manager_async_client"
    ) as mock_db:
        mock_client = AsyncMock()
        mock_db.return_value = mock_client

        # Create mock nodes and links
        mock_dict_node = Mock()
        mock_dict_node.block = CreateDictionaryBlock()
        mock_dict_node.block_id = CreateDictionaryBlock().id
        mock_dict_node.input_default = {}

        mock_list_node = Mock()
        mock_list_node.block = AddToListBlock()
        mock_list_node.block_id = AddToListBlock().id
        mock_list_node.input_default = {}

        # Mock links with dynamic fields
        dict_link1 = Mock(
            source_name="tools_^_create_dictionary_~_name",
            sink_name="values_#_name",
            sink_id="dict_node_id",
            source_id="test_node_id",
        )
        dict_link2 = Mock(
            source_name="tools_^_create_dictionary_~_age",
            sink_name="values_#_age",
            sink_id="dict_node_id",
            source_id="test_node_id",
        )
        list_link = Mock(
            source_name="tools_^_add_to_list_~_0",
            sink_name="entries_$_0",
            sink_id="list_node_id",
            source_id="test_node_id",
        )

        mock_client.get_connected_output_nodes.return_value = [
            (dict_link1, mock_dict_node),
            (dict_link2, mock_dict_node),
            (list_link, mock_list_node),
        ]

        # Call the method that builds signatures
        tool_functions = await block._create_function_signature("test_node_id")

        # Verify we got 2 tool functions (one for dict, one for list)
        assert len(tool_functions) == 2

        # Verify the tool functions contain the dynamic field names
        dict_tool = next(
            (
                tool
                for tool in tool_functions
                if tool["function"]["name"] == "createdictionaryblock"
            ),
            None,
        )
        assert dict_tool is not None
        dict_properties = dict_tool["function"]["parameters"]["properties"]
        assert "values___name" in dict_properties
        assert "values___age" in dict_properties

        list_tool = next(
            (
                tool
                for tool in tool_functions
                if tool["function"]["name"] == "addtolistblock"
            ),
            None,
        )
        assert list_tool is not None
        list_properties = list_tool["function"]["parameters"]["properties"]
        assert "entries___0" in list_properties


@pytest.mark.asyncio
async def test_output_yielding_with_dynamic_fields():
    """Test that outputs are yielded correctly with dynamic field names mapped back."""
    block = SmartDecisionMakerBlock()

    # No more sanitized mapping needed since we removed sanitization

    # Mock LLM response with tool calls
    mock_response = Mock()
    mock_response.tool_calls = [
        Mock(
            function=Mock(
                arguments=json.dumps(
                    {
                        "values___name": "Alice",
                        "values___age": 30,
                        "values___email": "alice@example.com",
                    }
                ),
            )
        )
    ]
    # Ensure function name is a real string, not a Mock name
    mock_response.tool_calls[0].function.name = "createdictionaryblock"
    mock_response.reasoning = "Creating a dictionary with user information"
    mock_response.raw_response = {"role": "assistant", "content": "test"}
    mock_response.prompt_tokens = 100
    mock_response.completion_tokens = 50

    # Mock the LLM call
    with patch(
        "backend.blocks.smart_decision_maker.llm.llm_call", new_callable=AsyncMock
    ) as mock_llm:
        mock_llm.return_value = mock_response

        # Mock the function signature creation
        with patch.object(
            block, "_create_function_signature", new_callable=AsyncMock
        ) as mock_sig:
            mock_sig.return_value = [
                {
                    "type": "function",
                    "function": {
                        "name": "createdictionaryblock",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "values___name": {"type": "string"},
                                "values___age": {"type": "number"},
                                "values___email": {"type": "string"},
                            },
                        },
                    },
                }
            ]

            # Create input data
            from backend.blocks import llm

            input_data = block.input_schema(
                prompt="Create a user dictionary",
                credentials=llm.TEST_CREDENTIALS_INPUT,
                model=llm.LlmModel.GPT4O,
            )

            # Run the block
            outputs = {}
            async for output_name, output_value in block.run(
                input_data,
                credentials=llm.TEST_CREDENTIALS,
                graph_id="test_graph",
                node_id="test_node",
                graph_exec_id="test_exec",
                node_exec_id="test_node_exec",
                user_id="test_user",
            ):
                outputs[output_name] = output_value

            # Verify the outputs use sanitized field names (matching frontend normalizeToolName)
            assert "tools_^_createdictionaryblock_~_values___name" in outputs
            assert outputs["tools_^_createdictionaryblock_~_values___name"] == "Alice"

            assert "tools_^_createdictionaryblock_~_values___age" in outputs
            assert outputs["tools_^_createdictionaryblock_~_values___age"] == 30

            assert "tools_^_createdictionaryblock_~_values___email" in outputs
            assert (
                outputs["tools_^_createdictionaryblock_~_values___email"]
                == "alice@example.com"
            )

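# Editor's note: the output keys asserted above follow the convention
# "tools_^_<tool name>_~_<parameter name>". Splitting one back apart (the helper
# name is hypothetical; the delimiters are taken from the assertions):
#
#     def split_tool_output_key(key: str) -> tuple[str, str]:
#         prefix, _, rest = key.partition("_^_")
#         assert prefix == "tools"
#         tool, _, field = rest.partition("_~_")
#         return tool, field
#
#     assert split_tool_output_key("tools_^_createdictionaryblock_~_values___name") == (
#         "createdictionaryblock",
#         "values___name",
#     )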

@pytest.mark.asyncio
async def test_mixed_regular_and_dynamic_fields():
    """Test handling of blocks with both regular and dynamic fields."""
    block = SmartDecisionMakerBlock()

    # Create a mock node
    mock_node = Mock()
    mock_node.block = Mock()
    mock_node.block.name = "TestBlock"
    mock_node.block.description = "A test block"
    mock_node.block.input_schema = Mock()

    # Mock the get_field_schema to return a proper schema for regular fields
    def get_field_schema(field_name):
        if field_name == "regular_field":
            return {"type": "string", "description": "A regular field"}
        elif field_name == "values":
            return {"type": "object", "description": "A dictionary field"}
        else:
            raise KeyError(f"Field {field_name} not found")

    mock_node.block.input_schema.get_field_schema = get_field_schema
    mock_node.block.input_schema.jsonschema = Mock(
        return_value={"properties": {}, "required": []}
    )

    # Create links with both regular and dynamic fields
    mock_links = [
        Mock(
            source_name="tools_^_test_~_regular",
            sink_name="regular_field",  # Regular field
            sink_id="test_node_id",
            source_id="smart_decision_node_id",
        ),
        Mock(
            source_name="tools_^_test_~_dict_key",
            sink_name="values_#_key1",  # Dynamic dict field
            sink_id="test_node_id",
            source_id="smart_decision_node_id",
        ),
        Mock(
            source_name="tools_^_test_~_dict_key2",
            sink_name="values_#_key2",  # Dynamic dict field
            sink_id="test_node_id",
            source_id="smart_decision_node_id",
        ),
    ]

    # Generate function signature
    signature = await block._create_block_function_signature(mock_node, mock_links)  # type: ignore

    # Check properties
    properties = signature["function"]["parameters"]["properties"]
    assert len(properties) == 3

    # Regular field should have its original schema
    assert "regular_field" in properties
    assert properties["regular_field"]["description"] == "A regular field"

    # Dynamic fields should have generated descriptions
    assert "values___key1" in properties
    assert "Dictionary field" in properties["values___key1"]["description"]

    assert "values___key2" in properties
    assert "Dictionary field" in properties["values___key2"]["description"]


@pytest.mark.asyncio
async def test_validation_errors_dont_pollute_conversation():
    """Test that validation errors are only used during retries and don't pollute the conversation."""
    block = SmartDecisionMakerBlock()

    # Track conversation history changes
    conversation_snapshots = []

    # Mock response with invalid tool call (missing required parameter)
    invalid_response = Mock()
    invalid_response.tool_calls = [
        Mock(
            function=Mock(
                arguments=json.dumps({"wrong_param": "value"}),  # Wrong parameter name
            )
        )
    ]
    # Ensure function name is a real string, not a Mock name
    invalid_response.tool_calls[0].function.name = "test_tool"
    invalid_response.reasoning = None
    invalid_response.raw_response = {"role": "assistant", "content": "invalid"}
    invalid_response.prompt_tokens = 100
    invalid_response.completion_tokens = 50

    # Mock valid response after retry
    valid_response = Mock()
    valid_response.tool_calls = [
        Mock(function=Mock(arguments=json.dumps({"correct_param": "value"})))
    ]
    # Ensure function name is a real string, not a Mock name
    valid_response.tool_calls[0].function.name = "test_tool"
    valid_response.reasoning = None
    valid_response.raw_response = {"role": "assistant", "content": "valid"}
    valid_response.prompt_tokens = 100
    valid_response.completion_tokens = 50

    call_count = 0

    async def mock_llm_call(**kwargs):
        nonlocal call_count
        # Capture conversation state
        conversation_snapshots.append(kwargs.get("prompt", []).copy())
        call_count += 1
        if call_count == 1:
            return invalid_response
        else:
            return valid_response

    # Mock the LLM call
    with patch(
        "backend.blocks.smart_decision_maker.llm.llm_call", new_callable=AsyncMock
    ) as mock_llm:
        mock_llm.side_effect = mock_llm_call

        # Mock the function signature creation
        with patch.object(
            block, "_create_function_signature", new_callable=AsyncMock
        ) as mock_sig:
            mock_sig.return_value = [
                {
                    "type": "function",
                    "function": {
                        "name": "test_tool",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "correct_param": {
                                    "type": "string",
                                    "description": "The correct parameter",
                                }
                            },
                            "required": ["correct_param"],
                        },
                    },
                }
            ]

            # Create input data
            from backend.blocks import llm

            input_data = block.input_schema(
                prompt="Test prompt",
                credentials=llm.TEST_CREDENTIALS_INPUT,
                model=llm.LlmModel.GPT4O,
                retry=3,  # Allow retries
            )

            # Run the block
            outputs = {}
            async for output_name, output_value in block.run(
                input_data,
                credentials=llm.TEST_CREDENTIALS,
                graph_id="test_graph",
                node_id="test_node",
                graph_exec_id="test_exec",
                node_exec_id="test_node_exec",
                user_id="test_user",
            ):
                outputs[output_name] = output_value

            # Verify we had 2 LLM calls (initial + retry)
            assert call_count == 2

            # Check the final conversation output
            final_conversation = outputs.get("conversations", [])

            # The final conversation should NOT contain the validation error message
            error_messages = [
                msg
                for msg in final_conversation
                if msg.get("role") == "user"
                and "parameter errors" in msg.get("content", "")
            ]
            assert (
                len(error_messages) == 0
            ), "Validation error leaked into final conversation"

            # The final conversation should only have the successful response
            assert final_conversation[-1]["content"] == "valid"

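# Editor's note: the test above implies a retry loop roughly like the sketch below
# (all names are assumptions): the validation-error prompt is appended only to a
# scratch copy used for the retry call, so the conversation that is yielded contains
# just the base history plus the final successful response.
#
#     pending_errors: list[dict] = []
#     for _attempt in range(input_data.retry):
#         response = await llm.llm_call(prompt=base_prompt + pending_errors)
#         errors = validate_tool_calls(response, tool_functions)
#         if not errors:
#             break  # only this response lands in the yielded conversation
#         pending_errors = [{"role": "user", "content": errors}]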
autogpt_platform/backend/backend/blocks/test/test_table_input.py (new file, +131)
@@ -0,0 +1,131 @@
import pytest

from backend.blocks.io import AgentTableInputBlock
from backend.util.test import execute_block_test


@pytest.mark.asyncio
async def test_table_input_block():
    """Test the AgentTableInputBlock with basic input/output."""
    block = AgentTableInputBlock()
    await execute_block_test(block)


@pytest.mark.asyncio
async def test_table_input_with_data():
    """Test AgentTableInputBlock with actual table data."""
    block = AgentTableInputBlock()

    input_data = block.Input(
        name="test_table",
        column_headers=["Name", "Age", "City"],
        value=[
            {"Name": "John", "Age": "30", "City": "New York"},
            {"Name": "Jane", "Age": "25", "City": "London"},
            {"Name": "Bob", "Age": "35", "City": "Paris"},
        ],
    )

    output_data = []
    async for output_name, output_value in block.run(input_data):
        output_data.append((output_name, output_value))

    assert len(output_data) == 1
    assert output_data[0][0] == "result"

    result = output_data[0][1]
    assert len(result) == 3
    assert result[0]["Name"] == "John"
    assert result[1]["Age"] == "25"
    assert result[2]["City"] == "Paris"


@pytest.mark.asyncio
async def test_table_input_empty_data():
    """Test AgentTableInputBlock with empty data."""
    block = AgentTableInputBlock()

    input_data = block.Input(
        name="empty_table", column_headers=["Col1", "Col2"], value=[]
    )

    output_data = []
    async for output_name, output_value in block.run(input_data):
        output_data.append((output_name, output_value))

    assert len(output_data) == 1
    assert output_data[0][0] == "result"
    assert output_data[0][1] == []


@pytest.mark.asyncio
async def test_table_input_with_missing_columns():
    """Test AgentTableInputBlock passes through data with missing columns as-is."""
    block = AgentTableInputBlock()

    input_data = block.Input(
        name="partial_table",
        column_headers=["Name", "Age", "City"],
        value=[
            {"Name": "John", "Age": "30"},  # Missing City
            {"Name": "Jane", "City": "London"},  # Missing Age
            {"Age": "35", "City": "Paris"},  # Missing Name
        ],
    )

    output_data = []
    async for output_name, output_value in block.run(input_data):
        output_data.append((output_name, output_value))

    result = output_data[0][1]
    assert len(result) == 3

    # Check data is passed through as-is
    assert result[0] == {"Name": "John", "Age": "30"}
    assert result[1] == {"Name": "Jane", "City": "London"}
    assert result[2] == {"Age": "35", "City": "Paris"}


@pytest.mark.asyncio
async def test_table_input_none_value():
    """Test AgentTableInputBlock with None value returns empty list."""
    block = AgentTableInputBlock()

    input_data = block.Input(
        name="none_table", column_headers=["Name", "Age"], value=None
    )

    output_data = []
    async for output_name, output_value in block.run(input_data):
        output_data.append((output_name, output_value))

    assert len(output_data) == 1
    assert output_data[0][0] == "result"
    assert output_data[0][1] == []


@pytest.mark.asyncio
async def test_table_input_with_default_headers():
    """Test AgentTableInputBlock with default column headers."""
    block = AgentTableInputBlock()

    # Don't specify column_headers, should use defaults
    input_data = block.Input(
        name="default_headers_table",
        value=[
            {"Column 1": "A", "Column 2": "B", "Column 3": "C"},
            {"Column 1": "D", "Column 2": "E", "Column 3": "F"},
        ],
    )

    output_data = []
    async for output_name, output_value in block.run(input_data):
        output_data.append((output_name, output_value))

    assert len(output_data) == 1
    assert output_data[0][0] == "result"

    result = output_data[0][1]
    assert len(result) == 2
    assert result[0]["Column 1"] == "A"
    assert result[1]["Column 3"] == "F"

@@ -2,6 +2,8 @@ import re
 from pathlib import Path
 from typing import Any
 
+import regex  # Has built-in timeout support
+
 from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import SchemaField
 from backend.util import json, text
@@ -137,6 +139,11 @@ class ExtractTextInformationBlock(Block):
     )
 
     async def run(self, input_data: Input, **kwargs) -> BlockOutput:
+        # Security fix: Add limits to prevent ReDoS and memory exhaustion
+        MAX_TEXT_LENGTH = 1_000_000  # 1MB character limit
+        MAX_MATCHES = 1000  # Maximum number of matches to prevent memory exhaustion
+        MAX_MATCH_LENGTH = 10_000  # Maximum length per match
+
         flags = 0
         if not input_data.case_sensitive:
             flags = flags | re.IGNORECASE
@@ -148,20 +155,85 @@
         else:
             txt = json.dumps(input_data.text)
 
-        matches = [
-            match.group(input_data.group)
-            for match in re.finditer(input_data.pattern, txt, flags)
-            if input_data.group <= len(match.groups())
-        ]
-        if not input_data.find_all:
-            matches = matches[:1]
-        for match in matches:
-            yield "positive", match
-        if not matches:
-            yield "negative", input_data.text
+        # Limit text size to prevent DoS
+        if len(txt) > MAX_TEXT_LENGTH:
+            txt = txt[:MAX_TEXT_LENGTH]
 
-        yield "matched_results", matches
-        yield "matched_count", len(matches)
+        # Validate regex pattern to prevent dangerous patterns
+        dangerous_patterns = [
+            r".*\+.*\+",  # Nested quantifiers
+            r".*\*.*\*",  # Nested quantifiers
+            r"(?=.*\+)",  # Lookahead with quantifier
+            r"(?=.*\*)",  # Lookahead with quantifier
+            r"\(.+\)\+",  # Group with nested quantifier
+            r"\(.+\)\*",  # Group with nested quantifier
+            r"\([^)]+\+\)\+",  # Nested quantifiers like (a+)+
+            r"\([^)]+\*\)\*",  # Nested quantifiers like (a*)*
+        ]
+
+        # Check if pattern is potentially dangerous
+        is_dangerous = any(
+            re.search(dangerous, input_data.pattern) for dangerous in dangerous_patterns
+        )
+
+        # Use regex module with timeout for dangerous patterns
+        # For safe patterns, use standard re module for compatibility
+        try:
+            matches = []
+            match_count = 0
+
+            if is_dangerous:
+                # Use regex module with timeout (5 seconds) for dangerous patterns
+                # The regex module supports timeout parameter in finditer
+                try:
+                    for match in regex.finditer(
+                        input_data.pattern, txt, flags=flags, timeout=5.0
+                    ):
+                        if match_count >= MAX_MATCHES:
+                            break
+                        if input_data.group <= len(match.groups()):
+                            match_text = match.group(input_data.group)
+                            # Limit match length to prevent memory exhaustion
+                            if len(match_text) > MAX_MATCH_LENGTH:
+                                match_text = match_text[:MAX_MATCH_LENGTH]
+                            matches.append(match_text)
+                            match_count += 1
+                except regex.error as e:
+                    # Timeout occurred or regex error
+                    if "timeout" in str(e).lower():
+                        # Timeout - return empty results
+                        pass
+                    else:
+                        # Other regex error
+                        raise
+            else:
+                # Use standard re module for non-dangerous patterns
+                for match in re.finditer(input_data.pattern, txt, flags):
+                    if match_count >= MAX_MATCHES:
+                        break
+                    if input_data.group <= len(match.groups()):
+                        match_text = match.group(input_data.group)
+                        # Limit match length to prevent memory exhaustion
+                        if len(match_text) > MAX_MATCH_LENGTH:
+                            match_text = match_text[:MAX_MATCH_LENGTH]
+                        matches.append(match_text)
+                        match_count += 1
+
+            if not input_data.find_all:
+                matches = matches[:1]
+
+            for match in matches:
+                yield "positive", match
+            if not matches:
+                yield "negative", input_data.text
+
+            yield "matched_results", matches
+            yield "matched_count", len(matches)
+        except Exception:
+            # Return empty results on any regex error
+            yield "negative", input_data.text
+            yield "matched_results", []
+            yield "matched_count", 0

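Editor's note on the hunk above: the third-party regex module accepts a timeout
(in seconds) on its matching functions and aborts long-running matches; my
understanding is that it raises TimeoutError on expiry, which is why the block
also defensively checks the error text. A minimal sketch:

    import regex

    try:
        # A classic catastrophic-backtracking pattern, cut off after 1 second.
        hits = [
            m.group(0)
            for m in regex.finditer(r"(a+)+$", "a" * 40 + "b", timeout=1.0)
        ]
    except TimeoutError:
        hits = []  # treat a timeout as "no matches", as the block above does
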
class FillTextTemplateBlock(Block):
@@ -172,6 +244,11 @@
         format: str = SchemaField(
             description="Template to format the text using `values`. Use Jinja2 syntax."
         )
+        escape_html: bool = SchemaField(
+            default=False,
+            advanced=True,
+            description="Whether to escape special characters in the inserted values to be HTML-safe. Enable for HTML output, disable for plain text.",
+        )
 
     class Output(BlockSchema):
         output: str = SchemaField(description="Formatted text")
@@ -205,6 +282,7 @@ class FillTextTemplateBlock(Block):
     )
 
     async def run(self, input_data: Input, **kwargs) -> BlockOutput:
+        formatter = text.TextFormatter(autoescape=input_data.escape_html)
         yield "output", formatter.format_string(input_data.format, input_data.values)

@@ -270,13 +270,17 @@ class GetCurrentDateBlock(Block):
             test_output=[
                 (
                     "date",
-                    lambda t: abs(datetime.now() - datetime.strptime(t, "%Y-%m-%d"))
-                    < timedelta(days=8),  # 7 days difference + 1 day error margin.
+                    lambda t: abs(
+                        datetime.now().date() - datetime.strptime(t, "%Y-%m-%d").date()
+                    )
+                    <= timedelta(days=8),  # 7 days difference + 1 day error margin.
                 ),
                 (
                     "date",
-                    lambda t: abs(datetime.now() - datetime.strptime(t, "%m/%d/%Y"))
-                    < timedelta(days=8),
+                    lambda t: abs(
+                        datetime.now().date() - datetime.strptime(t, "%m/%d/%Y").date()
+                    )
+                    <= timedelta(days=8),
+                    # 7 days difference + 1 day error margin.
                 ),
                 (
@@ -382,7 +386,7 @@ class GetCurrentDateAndTimeBlock(Block):
                 lambda t: abs(
                     datetime.now().date() - datetime.strptime(t, "%Y/%m/%d").date()
                 )
-                < timedelta(days=1),  # Date format only, no time component
+                <= timedelta(days=1),  # Date format only, no time component
                 ),
                 (
                     "date_time",

@@ -26,6 +26,14 @@ class XMLParserBlock(Block):
     )
 
     async def run(self, input_data: Input, **kwargs) -> BlockOutput:
+        # Security fix: Add size limits to prevent XML bomb attacks
+        MAX_XML_SIZE = 10 * 1024 * 1024  # 10MB limit for XML input
+
+        if len(input_data.input_xml) > MAX_XML_SIZE:
+            raise ValueError(
+                f"XML too large: {len(input_data.input_xml)} bytes > {MAX_XML_SIZE} bytes"
+            )
+
         try:
             tokens = tokenize(input_data.input_xml)
             parser = Parser(tokens)

@@ -1,6 +1,7 @@
 from urllib.parse import parse_qs, urlparse
 
 from youtube_transcript_api._api import YouTubeTranscriptApi
+from youtube_transcript_api._errors import NoTranscriptFound
 from youtube_transcript_api._transcripts import FetchedTranscript
 from youtube_transcript_api.formatters import TextFormatter
@@ -64,7 +65,29 @@ class TranscribeYoutubeVideoBlock(Block):
 
     @staticmethod
     def get_transcript(video_id: str) -> FetchedTranscript:
-        return YouTubeTranscriptApi().fetch(video_id=video_id)
+        """
+        Get transcript for a video, preferring English but falling back to any available language.
+
+        :param video_id: The YouTube video ID
+        :return: The fetched transcript
+        :raises: Any exception except NoTranscriptFound for requested languages
+        """
+        api = YouTubeTranscriptApi()
+        try:
+            # Try to get English transcript first (default behavior)
+            return api.fetch(video_id=video_id)
+        except NoTranscriptFound:
+            # If English is not available, get the first available transcript
+            transcript_list = api.list(video_id)
+            # Try manually created transcripts first, then generated ones
+            available_transcripts = list(
+                transcript_list._manually_created_transcripts.values()
+            ) + list(transcript_list._generated_transcripts.values())
+            if available_transcripts:
+                # Fetch the first available transcript
+                return available_transcripts[0].fetch()
+            # If no transcripts at all, re-raise the original error
+            raise
 
     @staticmethod
     def format_transcript(transcript: FetchedTranscript) -> str:

@@ -45,9 +45,6 @@ class MainApp(AppProcess):
 
         app.main(silent=True)
 
-    def cleanup(self):
-        pass
-
 
 @click.group()
 def main():

@@ -9,6 +9,7 @@ from prisma.models import APIKey as PrismaAPIKey
 from prisma.types import APIKeyWhereUniqueInput
 from pydantic import BaseModel, Field
 
+from backend.data.includes import MAX_USER_API_KEYS_FETCH
 from backend.util.exceptions import NotAuthorizedError, NotFoundError
 
 logger = logging.getLogger(__name__)
@@ -178,9 +179,13 @@ async def revoke_api_key(key_id: str, user_id: str) -> APIKeyInfo:
     return APIKeyInfo.from_db(updated_api_key)
 
 
-async def list_user_api_keys(user_id: str) -> list[APIKeyInfo]:
+async def list_user_api_keys(
+    user_id: str, limit: int = MAX_USER_API_KEYS_FETCH
+) -> list[APIKeyInfo]:
     api_keys = await PrismaAPIKey.prisma().find_many(
-        where={"userId": user_id}, order={"createdAt": "desc"}
+        where={"userId": user_id},
+        order={"createdAt": "desc"},
+        take=limit,
     )
 
     return [APIKeyInfo.from_db(key) for key in api_keys]

@@ -1,4 +1,3 @@
-import functools
 import inspect
 import logging
 import os
@@ -28,6 +27,7 @@ from pydantic import BaseModel
 from backend.data.model import NodeExecutionStats
 from backend.integrations.providers import ProviderName
 from backend.util import json
+from backend.util.cache import cached
 from backend.util.settings import Config
 
 from .model import (
@@ -722,7 +722,7 @@ def get_block(block_id: str) -> Block[BlockSchema, BlockSchema] | None:
     return cls() if cls else None
 
 
-@functools.cache
+@cached(ttl_seconds=3600)
 def get_webhook_block_ids() -> Sequence[str]:
     return [
         id
@@ -731,7 +731,7 @@ def get_webhook_block_ids() -> Sequence[str]:
     ]
 
 
-@functools.cache
+@cached(ttl_seconds=3600)
 def get_io_block_ids() -> Sequence[str]:
     return [
         id

@@ -1,7 +1,11 @@
 from typing import Type
 
 from backend.blocks.ai_music_generator import AIMusicGeneratorBlock
-from backend.blocks.ai_shortform_video_block import AIShortformVideoCreatorBlock
+from backend.blocks.ai_shortform_video_block import (
+    AIAdMakerVideoCreatorBlock,
+    AIScreenshotToVideoAdBlock,
+    AIShortformVideoCreatorBlock,
+)
 from backend.blocks.apollo.organization import SearchOrganizationsBlock
 from backend.blocks.apollo.people import SearchPeopleBlock
 from backend.blocks.apollo.person import GetPersonDetailBlock
@@ -69,9 +73,9 @@ MODEL_COST: dict[LlmModel, int] = {
     LlmModel.CLAUDE_4_1_OPUS: 21,
     LlmModel.CLAUDE_4_OPUS: 21,
     LlmModel.CLAUDE_4_SONNET: 5,
+    LlmModel.CLAUDE_4_5_HAIKU: 4,
+    LlmModel.CLAUDE_4_5_SONNET: 9,
     LlmModel.CLAUDE_3_7_SONNET: 5,
     LlmModel.CLAUDE_3_5_SONNET: 4,
     LlmModel.CLAUDE_3_5_HAIKU: 1,  # $0.80 / $4.00
     LlmModel.CLAUDE_3_HAIKU: 1,
     LlmModel.AIML_API_QWEN2_5_72B: 1,
     LlmModel.AIML_API_LLAMA3_1_70B: 1,
@@ -321,7 +325,31 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
     ],
     AIShortformVideoCreatorBlock: [
         BlockCost(
-            cost_amount=50,
+            cost_amount=307,
             cost_filter={
                 "credentials": {
                     "id": revid_credentials.id,
                     "provider": revid_credentials.provider,
                     "type": revid_credentials.type,
                 }
             },
         )
     ],
+    AIAdMakerVideoCreatorBlock: [
+        BlockCost(
+            cost_amount=714,
+            cost_filter={
+                "credentials": {
+                    "id": revid_credentials.id,
+                    "provider": revid_credentials.provider,
+                    "type": revid_credentials.type,
+                }
+            },
+        )
+    ],
+    AIScreenshotToVideoAdBlock: [
+        BlockCost(
+            cost_amount=612,
+            cost_filter={
+                "credentials": {
+                    "id": revid_credentials.id,

@@ -5,7 +5,6 @@ from datetime import datetime, timezone
 from typing import TYPE_CHECKING, Any, cast
 
 import stripe
-from prisma import Json
 from prisma.enums import (
     CreditRefundRequestStatus,
     CreditTransactionType,
@@ -13,16 +12,13 @@ from prisma.enums import (
     OnboardingStep,
 )
 from prisma.errors import UniqueViolationError
-from prisma.models import CreditRefundRequest, CreditTransaction, User
-from prisma.types import (
-    CreditRefundRequestCreateInput,
-    CreditTransactionCreateInput,
-    CreditTransactionWhereInput,
-)
+from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
+from prisma.types import CreditRefundRequestCreateInput, CreditTransactionWhereInput
 from pydantic import BaseModel
 
 from backend.data import db
 from backend.data.block_cost_config import BLOCK_COSTS
+from backend.data.db import query_raw_with_schema
 from backend.data.includes import MAX_CREDIT_REFUND_REQUESTS_FETCH
 from backend.data.model import (
     AutoTopUpConfig,
     RefundRequest,
@@ -35,7 +31,8 @@ from backend.data.user import get_user_by_id, get_user_email_by_id
 from backend.notifications.notifications import queue_notification_async
 from backend.server.v2.admin.model import UserHistoryResponse
 from backend.util.exceptions import InsufficientBalanceError
-from backend.util.json import SafeJson
+from backend.util.feature_flag import Flag, is_feature_enabled
+from backend.util.json import SafeJson, dumps
 from backend.util.models import Pagination
 from backend.util.retry import func_retry
 from backend.util.settings import Settings
@@ -48,6 +45,10 @@ stripe.api_key = settings.secrets.stripe_api_key
 logger = logging.getLogger(__name__)
 base_url = settings.config.frontend_base_url or settings.config.platform_base_url
 
+# Constants for test compatibility
+POSTGRES_INT_MAX = 2147483647
+POSTGRES_INT_MIN = -2147483648
+
 
 class UsageTransactionMetadata(BaseModel):
     graph_exec_id: str | None = None
@@ -138,14 +139,20 @@ class UserCreditBase(ABC):
         pass
 
     @abstractmethod
-    async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
+    async def onboarding_reward(
+        self, user_id: str, credits: int, step: OnboardingStep
+    ) -> bool:
         """
         Reward the user with credits for completing an onboarding step.
         Won't reward if the user has already received credits for the step.
+
+        Args:
+            user_id (str): The user ID.
+            credits (int): The amount to reward.
+            step (OnboardingStep): The onboarding step.
+
+        Returns:
+            bool: True if rewarded, False if already rewarded.
         """
         pass

@@ -235,6 +242,12 @@ class UserCreditBase(ABC):
|
||||
"""
|
||||
Returns the current balance of the user & the latest balance snapshot time.
|
||||
"""
|
||||
# Check UserBalance first for efficiency and consistency
|
||||
user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
|
||||
if user_balance:
|
||||
return user_balance.balance, user_balance.updatedAt
|
||||
|
||||
# Fallback to transaction history computation if UserBalance doesn't exist
|
||||
top_time = self.time_now()
|
||||
snapshot = await CreditTransaction.prisma().find_first(
|
||||
where={
|
||||
@@ -249,72 +262,86 @@ class UserCreditBase(ABC):
|
||||
snapshot_balance = snapshot.runningBalance or 0 if snapshot else 0
|
||||
snapshot_time = snapshot.createdAt if snapshot else datetime_min
|
||||
|
||||
# Get transactions after the snapshot, this should not exist, but just in case.
|
||||
transactions = await CreditTransaction.prisma().group_by(
|
||||
by=["userId"],
|
||||
sum={"amount": True},
|
||||
max={"createdAt": True},
|
||||
where={
|
||||
"userId": user_id,
|
||||
"createdAt": {
|
||||
"gt": snapshot_time,
|
||||
"lte": top_time,
|
||||
},
|
||||
"isActive": True,
|
||||
},
|
||||
)
|
||||
transaction_balance = (
|
||||
int(transactions[0].get("_sum", {}).get("amount", 0) + snapshot_balance)
|
||||
if transactions
|
||||
else snapshot_balance
|
||||
)
|
||||
transaction_time = (
|
||||
datetime.fromisoformat(
|
||||
str(transactions[0].get("_max", {}).get("createdAt", datetime_min))
|
||||
)
|
||||
if transactions
|
||||
else snapshot_time
|
||||
)
|
||||
return transaction_balance, transaction_time
|
||||
return snapshot_balance, snapshot_time

     @func_retry
     async def _enable_transaction(
         self,
         transaction_key: str,
         user_id: str,
-        metadata: Json,
+        metadata: SafeJson,
         new_transaction_key: str | None = None,
     ):
-        transaction = await CreditTransaction.prisma().find_first_or_raise(
-            where={"transactionKey": transaction_key, "userId": user_id}
+        # First check if transaction exists and is inactive (safety check)
+        transaction = await CreditTransaction.prisma().find_first(
+            where={
+                "transactionKey": transaction_key,
+                "userId": user_id,
+                "isActive": False,
+            }
         )
-        if transaction.isActive:
-            return
+        if not transaction:
+            # Transaction doesn't exist or is already active, return early
+            return None

         async with db.locked_transaction(f"usr_trx_{user_id}"):
-            transaction = await CreditTransaction.prisma().find_first_or_raise(
-                where={"transactionKey": transaction_key, "userId": user_id}
-            )
-            if transaction.isActive:
-                return
+            # Atomic operation to enable transaction and update user balance using UserBalance
+            result = await query_raw_with_schema(
+                """
+                WITH user_balance_lock AS (
+                    SELECT
+                        $2::text as userId,
+                        COALESCE(
+                            (SELECT balance FROM {schema_prefix}"UserBalance" WHERE "userId" = $2 FOR UPDATE),
+                            -- Fallback: compute balance from transaction history if UserBalance doesn't exist
+                            (SELECT COALESCE(ct."runningBalance", 0)
+                             FROM {schema_prefix}"CreditTransaction" ct
+                             WHERE ct."userId" = $2
+                               AND ct."isActive" = true
+                               AND ct."runningBalance" IS NOT NULL
+                             ORDER BY ct."createdAt" DESC
+                             LIMIT 1),
+                            0
+                        ) as balance
+                ),
+                transaction_check AS (
+                    SELECT * FROM {schema_prefix}"CreditTransaction"
+                    WHERE "transactionKey" = $1 AND "userId" = $2 AND "isActive" = false
+                ),
+                balance_update AS (
+                    INSERT INTO {schema_prefix}"UserBalance" ("userId", "balance", "updatedAt")
+                    SELECT
+                        $2::text,
+                        user_balance_lock.balance + transaction_check.amount,
+                        CURRENT_TIMESTAMP
+                    FROM user_balance_lock, transaction_check
+                    ON CONFLICT ("userId") DO UPDATE SET
+                        "balance" = EXCLUDED."balance",
+                        "updatedAt" = EXCLUDED."updatedAt"
+                    RETURNING "balance", "updatedAt"
+                ),
+                transaction_update AS (
+                    UPDATE {schema_prefix}"CreditTransaction"
+                    SET "transactionKey" = COALESCE($4, $1),
+                        "isActive" = true,
+                        "runningBalance" = balance_update.balance,
+                        "createdAt" = balance_update."updatedAt",
+                        "metadata" = $3::jsonb
+                    FROM balance_update, transaction_check
+                    WHERE {schema_prefix}"CreditTransaction"."transactionKey" = transaction_check."transactionKey"
+                      AND {schema_prefix}"CreditTransaction"."userId" = transaction_check."userId"
+                    RETURNING {schema_prefix}"CreditTransaction"."runningBalance"
                )
+                SELECT "runningBalance" as balance FROM transaction_update;
+                """,
+                transaction_key,  # $1
+                user_id,  # $2
+                dumps(metadata.data),  # $3 - use pre-serialized JSON string for JSONB
+                new_transaction_key,  # $4
+            )

-            user_balance, _ = await self._get_credits(user_id)
-            await CreditTransaction.prisma().update(
-                where={
-                    "creditTransactionIdentifier": {
-                        "transactionKey": transaction_key,
-                        "userId": user_id,
-                    }
-                },
-                data={
-                    "transactionKey": new_transaction_key or transaction_key,
-                    "isActive": True,
-                    "runningBalance": user_balance + transaction.amount,
-                    "createdAt": self.time_now(),
-                    "metadata": metadata,
-                },
-            )
+            if result:
+                # UserBalance is already updated by the CTE
+                return result[0]["balance"]
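
`db.locked_transaction(f"usr_trx_{user_id}")` serializes writers per user before the CTE even runs. Its implementation is not shown in this diff; a plausible sketch of such a helper, assuming an asyncpg-style connection and a PostgreSQL transaction-scoped advisory lock, looks like this:

```python
import zlib
from contextlib import asynccontextmanager

@asynccontextmanager
async def locked_transaction(conn, key: str):
    """Hypothetical per-key lock; the real helper lives in backend.data.db."""
    lock_id = zlib.crc32(key.encode())  # map the string key onto a bigint lock id
    async with conn.transaction():
        # Blocks until the lock is free; released automatically on commit/rollback.
        await conn.execute("SELECT pg_advisory_xact_lock($1)", lock_id)
        yield conn
```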

     async def _add_transaction(
         self,
@@ -325,12 +352,54 @@ class UserCreditBase(ABC):
         transaction_key: str | None = None,
         ceiling_balance: int | None = None,
         fail_insufficient_credits: bool = True,
-        metadata: Json = SafeJson({}),
+        metadata: SafeJson = SafeJson({}),
     ) -> tuple[int, str]:
         """
         Add a new transaction for the user.
         This is the only method that should be used to add a new transaction.

+        ATOMIC OPERATION DESIGN DECISION:
+        ================================
+        This method uses PostgreSQL row-level locking (FOR UPDATE) for atomic credit operations.
+        After extensive analysis of concurrency patterns and correctness requirements, we determined
+        that the FOR UPDATE approach is necessary despite the latency overhead.
+
+        WHY FOR UPDATE LOCKING IS REQUIRED:
+        ----------------------------------
+        1. **Data Consistency**: Credit operations must be ACID-compliant. The balance check,
+           calculation, and update must be atomic to prevent race conditions where:
+           - Multiple spend operations could exceed available balance
+           - Lost update problems could occur with concurrent top-ups
+           - Refunds could create negative balances incorrectly
+
+        2. **Serializability**: FOR UPDATE ensures operations are serialized at the database level,
+           guaranteeing that each transaction sees a consistent view of the balance before applying changes.
+
+        3. **Correctness Over Performance**: Financial operations require absolute correctness.
+           The ~10-50ms latency increase from row locking is acceptable for the guarantee that
+           no user will ever have an incorrect balance due to race conditions.
+
+        4. **PostgreSQL Optimization**: Modern PostgreSQL versions optimize row locks efficiently.
+           The performance cost is minimal compared to the complexity and risk of lock-free approaches.
+
+        ALTERNATIVES CONSIDERED AND REJECTED:
+        ------------------------------------
+        - **Optimistic Concurrency**: Using version numbers or timestamps would require complex
+          retry logic and could still fail under high contention scenarios.
+        - **Application-Level Locking**: Redis locks or similar would add network overhead and
+          single points of failure while being less reliable than database locks.
+        - **Event Sourcing**: Would require complete architectural changes and eventual consistency
+          models that don't fit our real-time balance requirements.
+
+        PERFORMANCE CHARACTERISTICS:
+        ---------------------------
+        - Single user operations: 10-50ms latency (acceptable for financial operations)
+        - Concurrent operations on same user: Serialized (prevents data corruption)
+        - Concurrent operations on different users: Fully parallel (no blocking)
+
+        This design prioritizes correctness and data integrity over raw performance,
+        which is the appropriate choice for a credit/payment system.
+
         Args:
             user_id (str): The user ID.
             amount (int): The amount of credits to add.
@@ -344,40 +413,142 @@ class UserCreditBase(ABC):
         Returns:
             tuple[int, str]: The new balance & the transaction key.
         """
-        async with db.locked_transaction(f"usr_trx_{user_id}"):
-            # Get latest balance snapshot
-            user_balance, _ = await self._get_credits(user_id)
-
-            if ceiling_balance and amount > 0 and user_balance >= ceiling_balance:
+        # Quick validation for ceiling balance to avoid unnecessary database operations
+        if ceiling_balance and amount > 0:
+            current_balance, _ = await self._get_credits(user_id)
+            if current_balance >= ceiling_balance:
                 raise ValueError(
-                    f"You already have enough balance of ${user_balance/100}, top-up is not required when you already have at least ${ceiling_balance/100}"
+                    f"You already have enough balance of ${current_balance/100}, top-up is not required when you already have at least ${ceiling_balance/100}"
                 )

-            if amount < 0 and user_balance + amount < 0:
-                if fail_insufficient_credits:
-                    raise InsufficientBalanceError(
-                        message=f"Insufficient balance of ${user_balance/100}, where this will cost ${abs(amount)/100}",
-                        user_id=user_id,
-                        balance=user_balance,
-                        amount=amount,
-                    )
+        # Single unified atomic operation for all transaction types using UserBalance
+        try:
+            result = await query_raw_with_schema(
+                """
+                WITH user_balance_lock AS (
+                    SELECT
+                        $1::text as userId,
+                        -- CRITICAL: FOR UPDATE lock prevents concurrent modifications to the same user's balance
+                        -- This ensures atomic read-modify-write operations and prevents race conditions
+                        COALESCE(
+                            (SELECT balance FROM {schema_prefix}"UserBalance" WHERE "userId" = $1 FOR UPDATE),
+                            -- Fallback: compute balance from transaction history if UserBalance doesn't exist
+                            (SELECT COALESCE(ct."runningBalance", 0)
+                             FROM {schema_prefix}"CreditTransaction" ct
+                             WHERE ct."userId" = $1
+                               AND ct."isActive" = true
+                               AND ct."runningBalance" IS NOT NULL
+                             ORDER BY ct."createdAt" DESC
+                             LIMIT 1),
+                            0
+                        ) as balance
+                ),
+                balance_update AS (
+                    INSERT INTO {schema_prefix}"UserBalance" ("userId", "balance", "updatedAt")
+                    SELECT
+                        $1::text,
+                        CASE
+                            -- For inactive transactions: Don't update balance
+                            WHEN $5::boolean = false THEN user_balance_lock.balance
+                            -- For ceiling balance (amount > 0): Apply ceiling
+                            WHEN $2 > 0 AND $7::int IS NOT NULL AND user_balance_lock.balance > $7::int - $2 THEN $7::int
+                            -- For regular operations: Apply with overflow/underflow protection
+                            WHEN user_balance_lock.balance + $2 > $6::int THEN $6::int
+                            WHEN user_balance_lock.balance + $2 < $10::int THEN $10::int
+                            ELSE user_balance_lock.balance + $2
+                        END,
+                        CURRENT_TIMESTAMP
+                    FROM user_balance_lock
+                    WHERE (
+                        $5::boolean = false OR -- Allow inactive transactions
+                        $2 >= 0 OR -- Allow positive amounts (top-ups, grants)
+                        $8::boolean = false OR -- Allow when insufficient balance check is disabled
+                        user_balance_lock.balance + $2 >= 0 -- Allow spending only when sufficient balance
+                    )
+                    ON CONFLICT ("userId") DO UPDATE SET
+                        "balance" = EXCLUDED."balance",
+                        "updatedAt" = EXCLUDED."updatedAt"
+                    RETURNING "balance", "updatedAt"
+                ),
+                transaction_insert AS (
+                    INSERT INTO {schema_prefix}"CreditTransaction" (
+                        "userId", "amount", "type", "runningBalance",
+                        "metadata", "isActive", "createdAt", "transactionKey"
+                    )
+                    SELECT
+                        $1::text,
+                        $2::int,
+                        $3::text::{schema_prefix}"CreditTransactionType",
+                        CASE
+                            -- For inactive transactions: Set runningBalance to original balance (don't apply the change yet)
+                            WHEN $5::boolean = false THEN user_balance_lock.balance
+                            ELSE COALESCE(balance_update.balance, user_balance_lock.balance)
+                        END,
+                        $4::jsonb,
+                        $5::boolean,
+                        COALESCE(balance_update."updatedAt", CURRENT_TIMESTAMP),
+                        COALESCE($9, gen_random_uuid()::text)
+                    FROM user_balance_lock
+                    LEFT JOIN balance_update ON true
+                    WHERE (
+                        $5::boolean = false OR -- Allow inactive transactions
+                        $2 >= 0 OR -- Allow positive amounts (top-ups, grants)
+                        $8::boolean = false OR -- Allow when insufficient balance check is disabled
+                        user_balance_lock.balance + $2 >= 0 -- Allow spending only when sufficient balance
+                    )
+                    RETURNING "runningBalance", "transactionKey"
+                )
+                SELECT "runningBalance" as balance, "transactionKey" FROM transaction_insert;
+                """,
+                user_id,  # $1
+                amount,  # $2
+                transaction_type.value,  # $3
+                dumps(metadata.data),  # $4 - use pre-serialized JSON string for JSONB
+                is_active,  # $5
+                POSTGRES_INT_MAX,  # $6 - overflow protection
+                ceiling_balance,  # $7 - ceiling balance (nullable)
+                fail_insufficient_credits,  # $8 - check balance for spending
+                transaction_key,  # $9 - transaction key (nullable)
+                POSTGRES_INT_MIN,  # $10 - underflow protection
+            )
+        except Exception as e:
+            # Convert raw SQL unique constraint violations to UniqueViolationError
+            # for consistent exception handling throughout the codebase
+            error_str = str(e).lower()
+            if (
+                "already exists" in error_str
+                or "duplicate key" in error_str
+                or "unique constraint" in error_str
+            ):
+                # Extract table and constraint info for better error messages
+                # Re-raise as a UniqueViolationError but with proper format
+                # Create a minimal data structure that the error constructor expects
+                raise UniqueViolationError({"error": str(e), "user_facing_error": {}})
+            # For any other error, re-raise as-is
+            raise

-            amount = min(-user_balance, 0)
+        if result:
+            new_balance, tx_key = result[0]["balance"], result[0]["transactionKey"]
+            # UserBalance is already updated by the CTE
+            return new_balance, tx_key

-            # Create the transaction
-            transaction_data: CreditTransactionCreateInput = {
-                "userId": user_id,
-                "amount": amount,
-                "runningBalance": user_balance + amount,
-                "type": transaction_type,
-                "metadata": metadata,
-                "isActive": is_active,
-                "createdAt": self.time_now(),
-            }
-            if transaction_key:
-                transaction_data["transactionKey"] = transaction_key
-            tx = await CreditTransaction.prisma().create(data=transaction_data)
-            return user_balance + amount, tx.transactionKey
+        # If no result, either user doesn't exist or insufficient balance
+        user = await User.prisma().find_unique(where={"id": user_id})
+        if not user:
+            raise ValueError(f"User {user_id} not found")
+
+        # Must be insufficient balance for spending operation
+        if amount < 0 and fail_insufficient_credits:
+            current_balance, _ = await self._get_credits(user_id)
+            raise InsufficientBalanceError(
+                message=f"Insufficient balance of ${current_balance/100}, where this will cost ${abs(amount)/100}",
+                user_id=user_id,
+                balance=current_balance,
+                amount=amount,
+            )
+
+        # Unexpected case
+        raise ValueError(f"Transaction failed for user {user_id}, amount {amount}")


class UserCredit(UserCreditBase):
@@ -452,9 +623,10 @@ class UserCredit(UserCreditBase):
                     {"reason": f"Reward for completing {step.value} onboarding step."}
                 ),
             )
+            return True
         except UniqueViolationError:
-            # Already rewarded for this step
-            pass
+            # User already received this reward
+            return False
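
The True/False split above leans on the database's unique constraint for idempotency: a deterministic transaction key turns "claim this reward twice" into a duplicate-key error. A sketch of the pattern, reusing the key format visible in the concurrency test below (`f"REWARD-{user_id}-WELCOME"`); the exact attribute used to build the key is an assumption here:

```python
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError

from backend.util.json import SafeJson

async def grant_once(credits, user_id: str, amount: int, step) -> bool:
    try:
        await credits._add_transaction(
            user_id=user_id,
            amount=amount,
            transaction_type=CreditTransactionType.GRANT,
            transaction_key=f"REWARD-{user_id}-{step.value}",  # deterministic key
            metadata=SafeJson({"reason": f"Reward for {step.value}"}),
        )
        return True  # first claim wins
    except UniqueViolationError:
        return False  # duplicate key: reward was already granted
```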

     async def top_up_refund(
         self, user_id: str, transaction_key: str, metadata: dict[str, str]
@@ -643,7 +815,7 @@ class UserCredit(UserCreditBase):
     ):
         # init metadata, without sharing it with the world
         metadata = metadata or {}
-        if not metadata["reason"]:
+        if not metadata.get("reason"):
             match top_up_type:
                 case TopUpType.MANUAL:
                     metadata["reason"] = {"reason": f"Top up credits for {user_id}"}
@@ -905,7 +1077,9 @@ class UserCredit(UserCreditBase):
             ),
         )

-    async def get_refund_requests(self, user_id: str) -> list[RefundRequest]:
+    async def get_refund_requests(
+        self, user_id: str, limit: int = MAX_CREDIT_REFUND_REQUESTS_FETCH
+    ) -> list[RefundRequest]:
         return [
             RefundRequest(
                 id=r.id,
@@ -921,6 +1095,7 @@ class UserCredit(UserCreditBase):
             for r in await CreditRefundRequest.prisma().find_many(
                 where={"userId": user_id},
                 order={"createdAt": "desc"},
+                take=limit,
             )
         ]

@@ -970,8 +1145,8 @@ class DisabledUserCredit(UserCreditBase):
     async def top_up_credits(self, *args, **kwargs):
         pass

-    async def onboarding_reward(self, *args, **kwargs):
-        pass
+    async def onboarding_reward(self, *args, **kwargs) -> bool:
+        return True

     async def top_up_intent(self, *args, **kwargs) -> str:
         return ""
@@ -989,14 +1164,31 @@ class DisabledUserCredit(UserCreditBase):
         pass


-def get_user_credit_model() -> UserCreditBase:
+async def get_user_credit_model(user_id: str) -> UserCreditBase:
+    """
+    Get the credit model for a user, considering LaunchDarkly flags.
+
+    Args:
+        user_id (str): The user ID to check flags for.
+
+    Returns:
+        UserCreditBase: The appropriate credit model for the user
+    """
     if not settings.config.enable_credit:
         return DisabledUserCredit()

-    if settings.config.enable_beta_monthly_credit:
-        return BetaUserCredit(settings.config.num_user_credits_refill)
+    # Check LaunchDarkly flag for payment pilot users
+    # Default to False (beta monthly credit behavior) to maintain current behavior
+    is_payment_enabled = await is_feature_enabled(
+        Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
+    )

-    return UserCredit()
+    if is_payment_enabled:
+        # Payment enabled users get UserCredit (no monthly refills, enable payments)
+        return UserCredit()
+    else:
+        # Default behavior: users get beta monthly credits
+        return BetaUserCredit(settings.config.num_user_credits_refill)
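
Because the factory is now async and flag-dependent, call sites change shape. A minimal call-site sketch under that assumption (`UsageTransactionMetadata` comes from `backend.data.credit`, as in the tests below; the return value of `spend_credits` is passed through as-is):

```python
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model

async def charge_user(user_id: str, cost: int):
    # Resolve the per-user model first; flag evaluation happens inside.
    credit_model = await get_user_credit_model(user_id)
    return await credit_model.spend_credits(
        user_id, cost, UsageTransactionMetadata(reason="example charge")
    )
```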


def get_block_costs() -> dict[str, list["BlockCost"]]:
@@ -1086,7 +1278,8 @@ async def admin_get_user_history(
             )
             reason = metadata.get("reason", "No reason provided")

-            balance, last_update = await get_user_credit_model()._get_credits(tx.userId)
+            user_credit_model = await get_user_credit_model(tx.userId)
+            balance, _ = await user_credit_model._get_credits(tx.userId)

             history.append(
                 UserTransaction(
autogpt_platform/backend/backend/data/credit_ceiling_test.py (new file, 172 lines)
@@ -0,0 +1,172 @@
"""
Test ceiling balance functionality to ensure auto top-up limits work correctly.

This test was added to cover a previously untested code path that could lead to
incorrect balance capping behavior.
"""

from uuid import uuid4

import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance

from backend.data.credit import UserCredit
from backend.util.json import SafeJson
from backend.util.test import SpinTestServer


async def create_test_user(user_id: str) -> None:
    """Create a test user for ceiling tests."""
    try:
        await User.prisma().create(
            data={
                "id": user_id,
                "email": f"test-{user_id}@example.com",
                "name": f"Test User {user_id[:8]}",
            }
        )
    except UniqueViolationError:
        # User already exists, continue
        pass

    await UserBalance.prisma().upsert(
        where={"userId": user_id},
        data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
    )


async def cleanup_test_user(user_id: str) -> None:
    """Clean up test user and their transactions."""
    try:
        await CreditTransaction.prisma().delete_many(where={"userId": user_id})
        await User.prisma().delete_many(where={"id": user_id})
    except Exception as e:
        # Log cleanup failures but don't fail the test
        print(f"Warning: Failed to cleanup test user {user_id}: {e}")


@pytest.mark.asyncio(loop_scope="session")
async def test_ceiling_balance_rejects_when_above_threshold(server: SpinTestServer):
    """Test that ceiling balance correctly rejects top-ups when balance is above threshold."""
    credit_system = UserCredit()
    user_id = f"ceiling-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Give user balance of 1000 ($10) using internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=1000,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "initial_balance"}),
        )
        current_balance = await credit_system.get_credits(user_id)
        assert current_balance == 1000

        # Try to add 200 more with ceiling of 800 (should reject since 1000 > 800)
        with pytest.raises(ValueError, match="You already have enough balance"):
            await credit_system._add_transaction(
                user_id=user_id,
                amount=200,
                transaction_type=CreditTransactionType.TOP_UP,
                ceiling_balance=800,  # Ceiling lower than current balance
            )

        # Balance should remain unchanged
        final_balance = await credit_system.get_credits(user_id)
        assert final_balance == 1000, f"Balance should remain 1000, got {final_balance}"

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_ceiling_balance_clamps_when_would_exceed(server: SpinTestServer):
    """Test that ceiling balance correctly clamps amounts that would exceed the ceiling."""
    credit_system = UserCredit()
    user_id = f"ceiling-clamp-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Give user balance of 500 ($5) using internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=500,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "initial_balance"}),
        )

        # Add 800 more with ceiling of 1000 (should clamp to 1000, not reach 1300)
        final_balance, _ = await credit_system._add_transaction(
            user_id=user_id,
            amount=800,
            transaction_type=CreditTransactionType.TOP_UP,
            ceiling_balance=1000,  # Ceiling should clamp 500 + 800 = 1300 to 1000
        )

        # Balance should be clamped to ceiling
        assert (
            final_balance == 1000
        ), f"Balance should be clamped to 1000, got {final_balance}"

        # Verify with get_credits too
        stored_balance = await credit_system.get_credits(user_id)
        assert (
            stored_balance == 1000
        ), f"Stored balance should be 1000, got {stored_balance}"

        # Verify transaction shows the clamped amount
        transactions = await CreditTransaction.prisma().find_many(
            where={"userId": user_id, "type": CreditTransactionType.TOP_UP},
            order={"createdAt": "desc"},
        )

        # Should have 2 transactions: 500 + (500 to reach ceiling of 1000)
        assert len(transactions) == 2

        # The second transaction should show it only added 500, not 800
        second_tx = transactions[0]  # Most recent
        assert second_tx.runningBalance == 1000
        # The actual amount recorded could be 800 (what was requested) but balance was clamped

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_ceiling_balance_allows_when_under_threshold(server: SpinTestServer):
    """Test that ceiling balance allows top-ups when balance is under threshold."""
    credit_system = UserCredit()
    user_id = f"ceiling-under-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Give user balance of 300 ($3) using internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=300,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "initial_balance"}),
        )

        # Add 200 more with ceiling of 1000 (should succeed: 300 + 200 = 500 < 1000)
        final_balance, _ = await credit_system._add_transaction(
            user_id=user_id,
            amount=200,
            transaction_type=CreditTransactionType.TOP_UP,
            ceiling_balance=1000,
        )

        # Balance should be exactly 500
        assert final_balance == 500, f"Balance should be 500, got {final_balance}"

        # Verify with get_credits too
        stored_balance = await credit_system.get_credits(user_id)
        assert (
            stored_balance == 500
        ), f"Stored balance should be 500, got {stored_balance}"

    finally:
        await cleanup_test_user(user_id)
autogpt_platform/backend/backend/data/credit_concurrency_test.py (new file, 737 lines)
@@ -0,0 +1,737 @@
"""
Concurrency and atomicity tests for the credit system.

These tests ensure the credit system handles high-concurrency scenarios correctly
without race conditions, deadlocks, or inconsistent state.
"""

import asyncio
import random
from uuid import uuid4

import prisma.enums
import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance

from backend.data.credit import POSTGRES_INT_MAX, UsageTransactionMetadata, UserCredit
from backend.util.exceptions import InsufficientBalanceError
from backend.util.json import SafeJson
from backend.util.test import SpinTestServer

# Test with both UserCredit and BetaUserCredit if needed
credit_system = UserCredit()


async def create_test_user(user_id: str) -> None:
    """Create a test user with initial balance."""
    try:
        await User.prisma().create(
            data={
                "id": user_id,
                "email": f"test-{user_id}@example.com",
                "name": f"Test User {user_id[:8]}",
            }
        )
    except UniqueViolationError:
        # User already exists, continue
        pass

    # Ensure UserBalance record exists
    await UserBalance.prisma().upsert(
        where={"userId": user_id},
        data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
    )


async def cleanup_test_user(user_id: str) -> None:
    """Clean up test user and their transactions."""
    try:
        await CreditTransaction.prisma().delete_many(where={"userId": user_id})
        await UserBalance.prisma().delete_many(where={"userId": user_id})
        await User.prisma().delete_many(where={"id": user_id})
    except Exception as e:
        # Log cleanup failures but don't fail the test
        print(f"Warning: Failed to cleanup test user {user_id}: {e}")


@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_spends_same_user(server: SpinTestServer):
    """Test multiple concurrent spends from the same user don't cause race conditions."""
    user_id = f"concurrent-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Give user initial balance using internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=1000,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "initial_balance"}),
        )

        # Try to spend 10 x $1 concurrently
        async def spend_one_dollar(idx: int):
            try:
                return await credit_system.spend_credits(
                    user_id,
                    100,  # $1
                    UsageTransactionMetadata(
                        graph_exec_id=f"concurrent-{idx}",
                        reason=f"Concurrent spend {idx}",
                    ),
                )
            except InsufficientBalanceError:
                return None

        # Run 10 concurrent spends
        results = await asyncio.gather(
            *[spend_one_dollar(i) for i in range(10)], return_exceptions=True
        )

        # Count successful spends
        successful = [
            r for r in results if r is not None and not isinstance(r, Exception)
        ]
        failed = [r for r in results if isinstance(r, InsufficientBalanceError)]

        # All 10 should succeed since we have exactly $10
        assert len(successful) == 10, f"Expected 10 successful, got {len(successful)}"
        assert len(failed) == 0, f"Expected 0 failures, got {len(failed)}"

        # Final balance should be exactly 0
        final_balance = await credit_system.get_credits(user_id)
        assert final_balance == 0, f"Expected balance 0, got {final_balance}"

        # Verify transaction history is consistent
        transactions = await CreditTransaction.prisma().find_many(
            where={"userId": user_id, "type": prisma.enums.CreditTransactionType.USAGE}
        )
        assert (
            len(transactions) == 10
        ), f"Expected 10 transactions, got {len(transactions)}"

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_spends_insufficient_balance(server: SpinTestServer):
    """Test that concurrent spends correctly enforce balance limits."""
    user_id = f"insufficient-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Give user limited balance using internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=500,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "limited_balance"}),
        )

        # Try to spend 10 x $1 concurrently (but only have $5)
        async def spend_one_dollar(idx: int):
            try:
                return await credit_system.spend_credits(
                    user_id,
                    100,  # $1
                    UsageTransactionMetadata(
                        graph_exec_id=f"insufficient-{idx}",
                        reason=f"Insufficient spend {idx}",
                    ),
                )
            except InsufficientBalanceError:
                return "FAILED"

        # Run 10 concurrent spends
        results = await asyncio.gather(
            *[spend_one_dollar(i) for i in range(10)], return_exceptions=True
        )

        # Count successful vs failed
        successful = [
            r
            for r in results
            if r not in ["FAILED", None] and not isinstance(r, Exception)
        ]
        failed = [r for r in results if r == "FAILED"]

        # Exactly 5 should succeed, 5 should fail
        assert len(successful) == 5, f"Expected 5 successful, got {len(successful)}"
        assert len(failed) == 5, f"Expected 5 failures, got {len(failed)}"

        # Final balance should be exactly 0
        final_balance = await credit_system.get_credits(user_id)
        assert final_balance == 0, f"Expected balance 0, got {final_balance}"

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_mixed_operations(server: SpinTestServer):
    """Test concurrent mix of spends, top-ups, and balance checks."""
    user_id = f"mixed-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Initial balance using internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=1000,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "initial_balance"}),
        )

        # Mix of operations
        async def mixed_operations():
            operations = []

            # 5 spends of $1 each
            for i in range(5):
                operations.append(
                    credit_system.spend_credits(
                        user_id,
                        100,
                        UsageTransactionMetadata(reason=f"Mixed spend {i}"),
                    )
                )

            # 3 top-ups of $2 each using internal method
            for i in range(3):
                operations.append(
                    credit_system._add_transaction(
                        user_id=user_id,
                        amount=200,
                        transaction_type=CreditTransactionType.TOP_UP,
                        metadata=SafeJson({"test": f"concurrent_topup_{i}"}),
                    )
                )

            # 10 balance checks
            for i in range(10):
                operations.append(credit_system.get_credits(user_id))

            return await asyncio.gather(*operations, return_exceptions=True)

        results = await mixed_operations()

        # Check no exceptions occurred
        exceptions = [
            r
            for r in results
            if isinstance(r, Exception) and not isinstance(r, InsufficientBalanceError)
        ]
        assert len(exceptions) == 0, f"Unexpected exceptions: {exceptions}"

        # Final balance should be: 1000 - 500 + 600 = 1100
        final_balance = await credit_system.get_credits(user_id)
        assert final_balance == 1100, f"Expected balance 1100, got {final_balance}"

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_race_condition_exact_balance(server: SpinTestServer):
    """Test spending exact balance amount concurrently doesn't go negative."""
    user_id = f"exact-balance-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Give exact amount using internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=100,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "exact_amount"}),
        )

        # Try to spend $1 twice concurrently
        async def spend_exact():
            try:
                return await credit_system.spend_credits(
                    user_id, 100, UsageTransactionMetadata(reason="Exact spend")
                )
            except InsufficientBalanceError:
                return "FAILED"

        # Both try to spend the full balance
        result1, result2 = await asyncio.gather(spend_exact(), spend_exact())

        # Exactly one should succeed
        results = [result1, result2]
        successful = [
            r for r in results if r != "FAILED" and not isinstance(r, Exception)
        ]
        failed = [r for r in results if r == "FAILED"]

        assert len(successful) == 1, f"Expected 1 success, got {len(successful)}"
        assert len(failed) == 1, f"Expected 1 failure, got {len(failed)}"

        # Balance should be exactly 0, never negative
        final_balance = await credit_system.get_credits(user_id)
        assert final_balance == 0, f"Expected balance 0, got {final_balance}"

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_onboarding_reward_idempotency(server: SpinTestServer):
    """Test that onboarding rewards are idempotent (can't be claimed twice)."""
    user_id = f"onboarding-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Use WELCOME step which is defined in the OnboardingStep enum
        # Try to claim same reward multiple times concurrently
        async def claim_reward():
            try:
                result = await credit_system.onboarding_reward(
                    user_id, 500, prisma.enums.OnboardingStep.WELCOME
                )
                return "SUCCESS" if result else "DUPLICATE"
            except Exception as e:
                print(f"Claim reward failed: {e}")
                return "FAILED"

        # Try 5 concurrent claims of the same reward
        results = await asyncio.gather(*[claim_reward() for _ in range(5)])

        # Count results
        success_count = results.count("SUCCESS")
        failed_count = results.count("FAILED")

        # At least one should succeed, others should be duplicates
        assert success_count >= 1, f"At least one claim should succeed, got {results}"
        assert failed_count == 0, f"No claims should fail, got {results}"

        # Check balance - should only have 500, not 2500
        final_balance = await credit_system.get_credits(user_id)
        assert final_balance == 500, f"Expected balance 500, got {final_balance}"

        # Check only one transaction exists
        transactions = await CreditTransaction.prisma().find_many(
            where={
                "userId": user_id,
                "type": prisma.enums.CreditTransactionType.GRANT,
                "transactionKey": f"REWARD-{user_id}-WELCOME",
            }
        )
        assert (
            len(transactions) == 1
        ), f"Expected 1 reward transaction, got {len(transactions)}"

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_integer_overflow_protection(server: SpinTestServer):
    """Test that integer overflow is prevented by clamping to POSTGRES_INT_MAX."""
    user_id = f"overflow-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Try to add amount that would overflow
        max_int = POSTGRES_INT_MAX

        # First, set balance near max
        await UserBalance.prisma().upsert(
            where={"userId": user_id},
            data={
                "create": {"userId": user_id, "balance": max_int - 100},
                "update": {"balance": max_int - 100},
            },
        )

        # Try to add more than possible - should clamp to POSTGRES_INT_MAX
        await credit_system._add_transaction(
            user_id=user_id,
            amount=200,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "overflow_protection"}),
        )

        # Balance should be clamped to max_int, not overflowed
        final_balance = await credit_system.get_credits(user_id)
        assert (
            final_balance == max_int
        ), f"Balance should be clamped to {max_int}, got {final_balance}"

        # Verify transaction was created with clamped amount
        transactions = await CreditTransaction.prisma().find_many(
            where={
                "userId": user_id,
                "type": prisma.enums.CreditTransactionType.TOP_UP,
            },
            order={"createdAt": "desc"},
        )
        assert len(transactions) > 0, "Transaction should be created"
        assert (
            transactions[0].runningBalance == max_int
        ), "Transaction should show clamped balance"

    finally:
        await cleanup_test_user(user_id)
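
The clamp this overflow test exercises, restated in Python with the usual 32-bit PostgreSQL integer bounds assumed for the two constants:

```python
POSTGRES_INT_MAX = 2_147_483_647   # assumed value of the imported constant
POSTGRES_INT_MIN = -2_147_483_648  # assumed value of the $10 underflow bound

def clamped(balance: int, amount: int) -> int:
    """Mirror of the SQL CASE overflow/underflow branches above."""
    return max(POSTGRES_INT_MIN, min(balance + amount, POSTGRES_INT_MAX))

assert clamped(POSTGRES_INT_MAX - 100, 200) == POSTGRES_INT_MAX
```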


@pytest.mark.asyncio(loop_scope="session")
async def test_high_concurrency_stress(server: SpinTestServer):
    """Stress test with many concurrent operations."""
    user_id = f"stress-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Initial balance using internal method (bypasses Stripe)
        initial_balance = 10000  # $100
        await credit_system._add_transaction(
            user_id=user_id,
            amount=initial_balance,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "stress_test_balance"}),
        )

        # Run many concurrent operations
        async def random_operation(idx: int):
            operation = random.choice(["spend", "check"])

            if operation == "spend":
                amount = random.randint(1, 50)  # $0.01 to $0.50
                try:
                    return (
                        "spend",
                        amount,
                        await credit_system.spend_credits(
                            user_id,
                            amount,
                            UsageTransactionMetadata(reason=f"Stress {idx}"),
                        ),
                    )
                except InsufficientBalanceError:
                    return ("spend_failed", amount, None)
            else:
                balance = await credit_system.get_credits(user_id)
                return ("check", 0, balance)

        # Run 100 concurrent operations
        results = await asyncio.gather(
            *[random_operation(i) for i in range(100)], return_exceptions=True
        )

        # Calculate expected final balance
        total_spent = sum(
            r[1]
            for r in results
            if not isinstance(r, Exception) and isinstance(r, tuple) and r[0] == "spend"
        )
        expected_balance = initial_balance - total_spent

        # Verify final balance
        final_balance = await credit_system.get_credits(user_id)
        assert (
            final_balance == expected_balance
        ), f"Expected {expected_balance}, got {final_balance}"
        assert final_balance >= 0, "Balance went negative!"

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_multiple_spends_sufficient_balance(server: SpinTestServer):
    """Test multiple concurrent spends when there's sufficient balance for all."""
    user_id = f"multi-spend-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Give user 150 balance ($1.50) using internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=150,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "sufficient_balance"}),
        )

        # Track individual timing to see serialization
        timings = {}

        async def spend_with_detailed_timing(amount: int, label: str):
            start = asyncio.get_event_loop().time()
            try:
                await credit_system.spend_credits(
                    user_id,
                    amount,
                    UsageTransactionMetadata(
                        graph_exec_id=f"concurrent-{label}",
                        reason=f"Concurrent spend {label}",
                    ),
                )
                end = asyncio.get_event_loop().time()
                timings[label] = {"start": start, "end": end, "duration": end - start}
                return f"{label}-SUCCESS"
            except Exception as e:
                end = asyncio.get_event_loop().time()
                timings[label] = {
                    "start": start,
                    "end": end,
                    "duration": end - start,
                    "error": str(e),
                }
                return f"{label}-FAILED: {e}"

        # Run concurrent spends: 10, 20, 30 (total 60, well under 150)
        overall_start = asyncio.get_event_loop().time()
        results = await asyncio.gather(
            spend_with_detailed_timing(10, "spend-10"),
            spend_with_detailed_timing(20, "spend-20"),
            spend_with_detailed_timing(30, "spend-30"),
            return_exceptions=True,
        )
        overall_end = asyncio.get_event_loop().time()

        print(f"Results: {results}")
        print(f"Overall duration: {overall_end - overall_start:.4f}s")

        # Analyze timing to detect serialization vs true concurrency
        print("\nTiming analysis:")
        for label, timing in timings.items():
            print(
                f"  {label}: started at {timing['start']:.4f}, ended at {timing['end']:.4f}, duration {timing['duration']:.4f}s"
            )

        # Check if operations overlapped (true concurrency) or were serialized
        sorted_timings = sorted(timings.items(), key=lambda x: x[1]["start"])
        print("\nExecution order by start time:")
        for i, (label, timing) in enumerate(sorted_timings):
            print(f"  {i+1}. {label}: {timing['start']:.4f} -> {timing['end']:.4f}")

        # Check for overlap (true concurrency) vs serialization
        overlaps = []
        for i in range(len(sorted_timings) - 1):
            current = sorted_timings[i]
            next_op = sorted_timings[i + 1]
            if current[1]["end"] > next_op[1]["start"]:
                overlaps.append(f"{current[0]} overlaps with {next_op[0]}")

        if overlaps:
            print(f"✅ TRUE CONCURRENCY detected: {overlaps}")
        else:
            print("🔒 SERIALIZATION detected: No overlapping execution times")

        # Check final balance
        final_balance = await credit_system.get_credits(user_id)
        print(f"Final balance: {final_balance}")

        # Count successes/failures
        successful = [r for r in results if "SUCCESS" in str(r)]
        failed = [r for r in results if "FAILED" in str(r)]

        print(f"Successful: {len(successful)}, Failed: {len(failed)}")

        # All should succeed since 150 - (10 + 20 + 30) = 90 > 0
        assert (
            len(successful) == 3
        ), f"Expected all 3 to succeed, got {len(successful)} successes: {results}"
        assert final_balance == 90, f"Expected balance 90, got {final_balance}"

        # Check transaction timestamps to confirm database-level serialization
        transactions = await CreditTransaction.prisma().find_many(
            where={"userId": user_id, "type": prisma.enums.CreditTransactionType.USAGE},
            order={"createdAt": "asc"},
        )
        print("\nDatabase transaction order (by createdAt):")
        for i, tx in enumerate(transactions):
            print(
                f"  {i+1}. Amount {tx.amount}, Running balance: {tx.runningBalance}, Created: {tx.createdAt}"
            )

        # Verify running balances are chronologically consistent (ordered by createdAt)
        actual_balances = [
            tx.runningBalance for tx in transactions if tx.runningBalance is not None
        ]
        print(f"Running balances: {actual_balances}")

        # The balances should be valid intermediate states regardless of execution order
        # Starting balance: 150, spending 10+20+30=60, so final should be 90
        # The intermediate balances depend on execution order but should all be valid
        expected_possible_balances = {
            # If order is 10, 20, 30: [140, 120, 90]
            # If order is 10, 30, 20: [140, 110, 90]
            # If order is 20, 10, 30: [130, 120, 90]
            # If order is 20, 30, 10: [130, 100, 90]
            # If order is 30, 10, 20: [120, 110, 90]
            # If order is 30, 20, 10: [120, 100, 90]
            90,
            100,
            110,
            120,
            130,
            140,  # All possible intermediate balances
        }

        # Verify all balances are valid intermediate states
        for balance in actual_balances:
            assert (
                balance in expected_possible_balances
            ), f"Invalid balance {balance}, expected one of {expected_possible_balances}"

        # Final balance should always be 90 (150 - 60)
        assert (
            min(actual_balances) == 90
        ), f"Final balance should be 90, got {min(actual_balances)}"

        # The final transaction should always have balance 90
        # The other transactions should have valid intermediate balances
        assert (
            90 in actual_balances
        ), f"Final balance 90 should be in actual_balances: {actual_balances}"

        # All balances should be >= 90 (the final state)
        assert all(
            balance >= 90 for balance in actual_balances
        ), f"All balances should be >= 90, got {actual_balances}"

        # CRITICAL: Transactions are atomic but can complete in any order
        # What matters is that all running balances are valid intermediate states
        # Each balance should be between 90 (final) and 140 (after first transaction)
        for balance in actual_balances:
            assert (
                90 <= balance <= 140
            ), f"Balance {balance} is outside valid range [90, 140]"

        # Final balance (minimum) should always be 90
        assert (
            min(actual_balances) == 90
        ), f"Final balance should be 90, got {min(actual_balances)}"

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_prove_database_locking_behavior(server: SpinTestServer):
    """Definitively prove whether database locking causes waiting vs failures."""
    user_id = f"locking-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Set balance to exact amount that can handle all spends using internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=60,  # Exactly 10+20+30
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "exact_amount_test"}),
        )

        async def spend_with_precise_timing(amount: int, label: str):
            request_start = asyncio.get_event_loop().time()
            db_operation_start = asyncio.get_event_loop().time()
            try:
                # Add a small delay to increase chance of true concurrency
                await asyncio.sleep(0.001)

                db_operation_start = asyncio.get_event_loop().time()
                await credit_system.spend_credits(
                    user_id,
                    amount,
                    UsageTransactionMetadata(
                        graph_exec_id=f"locking-{label}",
                        reason=f"Locking test {label}",
                    ),
                )
                db_operation_end = asyncio.get_event_loop().time()

                return {
                    "label": label,
                    "status": "SUCCESS",
                    "request_start": request_start,
                    "db_start": db_operation_start,
                    "db_end": db_operation_end,
                    "db_duration": db_operation_end - db_operation_start,
                }
            except Exception as e:
                db_operation_end = asyncio.get_event_loop().time()
                return {
                    "label": label,
                    "status": "FAILED",
                    "error": str(e),
                    "request_start": request_start,
                    "db_start": db_operation_start,
                    "db_end": db_operation_end,
                    "db_duration": db_operation_end - db_operation_start,
                }

        # Launch all requests simultaneously
        results = await asyncio.gather(
            spend_with_precise_timing(10, "A"),
            spend_with_precise_timing(20, "B"),
            spend_with_precise_timing(30, "C"),
            return_exceptions=True,
        )

        print("\n🔍 LOCKING BEHAVIOR ANALYSIS:")
        print("=" * 50)

        successful = [
            r for r in results if isinstance(r, dict) and r.get("status") == "SUCCESS"
        ]
        failed = [
            r for r in results if isinstance(r, dict) and r.get("status") == "FAILED"
        ]

        print(f"✅ Successful operations: {len(successful)}")
        print(f"❌ Failed operations: {len(failed)}")

        if len(failed) > 0:
            print(
                "\n🚫 CONCURRENT FAILURES - Some requests failed due to insufficient balance:"
            )
            for result in failed:
                if isinstance(result, dict):
                    print(
                        f"  {result['label']}: {result.get('error', 'Unknown error')}"
                    )

        if len(successful) == 3:
            print(
                "\n🔒 SERIALIZATION CONFIRMED - All requests succeeded, indicating they were queued:"
            )

            # Sort by actual execution time to see order
            dict_results = [r for r in results if isinstance(r, dict)]
            sorted_results = sorted(dict_results, key=lambda x: x["db_start"])

            for i, result in enumerate(sorted_results):
                print(
                    f"  {i+1}. {result['label']}: DB operation took {result['db_duration']:.4f}s"
                )

            # Check if any operations overlapped at the database level
            print("\n⏱️ Database operation timeline:")
            for result in sorted_results:
                print(
                    f"  {result['label']}: {result['db_start']:.4f} -> {result['db_end']:.4f}"
                )

        # Verify final state
        final_balance = await credit_system.get_credits(user_id)
        print(f"\n💰 Final balance: {final_balance}")

        if len(successful) == 3:
            assert (
                final_balance == 0
            ), f"If all succeeded, balance should be 0, got {final_balance}"
            print(
                "✅ CONCLUSION: Database row locking causes requests to WAIT and execute serially"
            )
        else:
            print(
                "❌ CONCLUSION: Some requests failed, indicating different concurrency behavior"
            )

    finally:
        await cleanup_test_user(user_id)
autogpt_platform/backend/backend/data/credit_integration_test.py (new file, 277 lines)
@@ -0,0 +1,277 @@
"""
Integration tests for credit system to catch SQL enum casting issues.

These tests run actual database operations to ensure SQL queries work correctly,
which would have caught the CreditTransactionType enum casting bug.
"""

import pytest
from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction, User, UserBalance

from backend.data.credit import (
    AutoTopUpConfig,
    BetaUserCredit,
    UsageTransactionMetadata,
    get_auto_top_up,
    set_auto_top_up,
)
from backend.util.json import SafeJson


@pytest.fixture
async def cleanup_test_user():
    """Clean up test user data before and after tests."""
    import uuid

    user_id = str(uuid.uuid4())  # Use unique user ID for each test

    # Create the user first
    try:
        await User.prisma().create(
            data={
                "id": user_id,
                "email": f"test-{user_id}@example.com",
                "topUpConfig": SafeJson({}),
                "timezone": "UTC",
            }
        )
    except Exception:
        # User might already exist, that's fine
        pass

    yield user_id

    # Cleanup after test
    await CreditTransaction.prisma().delete_many(where={"userId": user_id})
    await UserBalance.prisma().delete_many(where={"userId": user_id})
    # Clear auto-top-up config before deleting user
    await User.prisma().update(
        where={"id": user_id}, data={"topUpConfig": SafeJson({})}
    )
    await User.prisma().delete(where={"id": user_id})


@pytest.mark.asyncio(loop_scope="session")
async def test_credit_transaction_enum_casting_integration(cleanup_test_user):
    """
    Integration test to verify CreditTransactionType enum casting works in SQL queries.

    This test would have caught the enum casting bug where PostgreSQL expected
    platform."CreditTransactionType" but got "CreditTransactionType".
    """
    user_id = cleanup_test_user
    credit_system = BetaUserCredit(1000)

    # Test each transaction type to ensure enum casting works
    test_cases = [
        (CreditTransactionType.TOP_UP, 100, "Test top-up"),
        (CreditTransactionType.USAGE, -50, "Test usage"),
        (CreditTransactionType.GRANT, 200, "Test grant"),
        (CreditTransactionType.REFUND, -25, "Test refund"),
        (CreditTransactionType.CARD_CHECK, 0, "Test card check"),
    ]

    for transaction_type, amount, reason in test_cases:
        metadata = SafeJson({"reason": reason, "test": "enum_casting"})

        # This call would fail with enum casting error before the fix
        balance, tx_key = await credit_system._add_transaction(
            user_id=user_id,
            amount=amount,
            transaction_type=transaction_type,
            metadata=metadata,
            is_active=True,
        )

        # Verify transaction was created with correct type
        transaction = await CreditTransaction.prisma().find_first(
            where={"userId": user_id, "transactionKey": tx_key}
        )

        assert transaction is not None
        assert transaction.type == transaction_type
        assert transaction.amount == amount
        assert transaction.metadata is not None

        # Verify metadata content
        assert transaction.metadata["reason"] == reason
        assert transaction.metadata["test"] == "enum_casting"


@pytest.mark.asyncio(loop_scope="session")
async def test_auto_top_up_integration(cleanup_test_user, monkeypatch):
    """
    Integration test for auto-top-up functionality that triggers enum casting.

    This tests the complete auto-top-up flow which involves SQL queries with
    CreditTransactionType enums, ensuring enum casting works end-to-end.
    """
    # Enable credits for this test
    from backend.data.credit import settings

    monkeypatch.setattr(settings.config, "enable_credit", True)
    monkeypatch.setattr(settings.config, "enable_beta_monthly_credit", True)
    monkeypatch.setattr(settings.config, "num_user_credits_refill", 1000)

    user_id = cleanup_test_user
    credit_system = BetaUserCredit(1000)

    # First add some initial credits so we can test the configuration and subsequent behavior
    balance, _ = await credit_system._add_transaction(
        user_id=user_id,
        amount=50,  # Below threshold that we'll set
        transaction_type=CreditTransactionType.GRANT,
        metadata=SafeJson({"reason": "Initial credits before auto top-up config"}),
    )
    assert balance == 50

    # Configure auto top-up with threshold above current balance
    config = AutoTopUpConfig(threshold=100, amount=500)
    await set_auto_top_up(user_id, config)

    # Verify configuration was saved but no immediate top-up occurred
    current_balance = await credit_system.get_credits(user_id)
    assert current_balance == 50  # Balance should be unchanged

    # Simulate spending credits that would trigger auto top-up
    # This involves multiple SQL operations with enum casting
    try:
        metadata = UsageTransactionMetadata(reason="Test spend to trigger auto top-up")
        await credit_system.spend_credits(user_id=user_id, cost=10, metadata=metadata)

        # The auto top-up mechanism should have been triggered
        # Verify the transaction types were handled correctly
        transactions = await CreditTransaction.prisma().find_many(
            where={"userId": user_id}, order={"createdAt": "desc"}
        )

        # Should have at least: GRANT (initial), USAGE (spend), and TOP_UP (auto top-up)
        assert len(transactions) >= 3

        # Verify different transaction types exist and enum casting worked
        transaction_types = {t.type for t in transactions}
        assert CreditTransactionType.GRANT in transaction_types
        assert CreditTransactionType.USAGE in transaction_types
        assert (
            CreditTransactionType.TOP_UP in transaction_types
        )  # Auto top-up should have triggered

    except Exception as e:
        # If this fails with enum casting error, the test successfully caught the bug
        if "CreditTransactionType" in str(e) and (
            "cast" in str(e).lower() or "type" in str(e).lower()
        ):
            pytest.fail(f"Enum casting error detected: {e}")
        else:
            # Re-raise other unexpected errors
            raise


@pytest.mark.asyncio(loop_scope="session")
async def test_enable_transaction_enum_casting_integration(cleanup_test_user):
    """
    Integration test for _enable_transaction with enum casting.

    Tests the scenario where inactive transactions are enabled, which also
    involves SQL queries with CreditTransactionType enum casting.
    """
    user_id = cleanup_test_user
    credit_system = BetaUserCredit(1000)

    # Create an inactive transaction
    balance, tx_key = await credit_system._add_transaction(
        user_id=user_id,
        amount=100,
        transaction_type=CreditTransactionType.TOP_UP,
        metadata=SafeJson({"reason": "Inactive transaction test"}),
        is_active=False,  # Create as inactive
    )

    # Balance should be 0 since transaction is inactive
    assert balance == 0

    # Enable the transaction with new metadata
    enable_metadata = SafeJson(
        {
            "payment_method": "test_payment",
            "activation_reason": "Integration test activation",
        }
    )

    # This would fail with enum casting error before the fix
    final_balance = await credit_system._enable_transaction(
        transaction_key=tx_key,
        user_id=user_id,
        metadata=enable_metadata,
    )

    # Now balance should reflect the activated transaction
    assert final_balance == 100

    # Verify transaction was properly enabled with correct enum type
    transaction = await CreditTransaction.prisma().find_first(
        where={"userId": user_id, "transactionKey": tx_key}
    )

    assert transaction is not None
    assert transaction.isActive is True
    assert transaction.type == CreditTransactionType.TOP_UP
    assert transaction.runningBalance == 100

    # Verify metadata was updated
    assert transaction.metadata is not None
    assert transaction.metadata["payment_method"] == "test_payment"
    assert transaction.metadata["activation_reason"] == "Integration test activation"


@pytest.mark.asyncio(loop_scope="session")
async def test_auto_top_up_configuration_storage(cleanup_test_user, monkeypatch):
    """
    Test that auto-top-up configuration is properly stored and retrieved.

    The immediate top-up logic is handled by the API routes, not the core
    set_auto_top_up function. This test verifies the configuration is correctly
    saved and can be retrieved.
    """
    # Enable credits for this test
    from backend.data.credit import settings

    monkeypatch.setattr(settings.config, "enable_credit", True)
    monkeypatch.setattr(settings.config, "enable_beta_monthly_credit", True)
    monkeypatch.setattr(settings.config, "num_user_credits_refill", 1000)

    user_id = cleanup_test_user
    credit_system = BetaUserCredit(1000)

    # Set initial balance
    balance, _ = await credit_system._add_transaction(
        user_id=user_id,
        amount=50,
        transaction_type=CreditTransactionType.GRANT,
        metadata=SafeJson({"reason": "Initial balance for config test"}),
    )

    assert balance == 50

    # Configure auto top-up
    config = AutoTopUpConfig(threshold=100, amount=200)
    await set_auto_top_up(user_id, config)

    # Verify the configuration was saved
    retrieved_config = await get_auto_top_up(user_id)
    assert retrieved_config.threshold == config.threshold
    assert retrieved_config.amount == config.amount

    # Verify balance is unchanged (no immediate top-up from set_auto_top_up)
    final_balance = await credit_system.get_credits(user_id)
    assert final_balance == 50  # Should be unchanged

    # Verify no immediate auto-top-up transaction was created by set_auto_top_up
|
||||
transactions = await CreditTransaction.prisma().find_many(
|
||||
where={"userId": user_id}, order={"createdAt": "desc"}
|
||||
)
|
||||
|
||||
# Should only have the initial GRANT transaction
|
||||
assert len(transactions) == 1
|
||||
assert transactions[0].type == CreditTransactionType.GRANT
|
||||
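The enum-casting failures these tests guard against come from raw SQL paths where the enum is bound as a plain text parameter. A minimal, hypothetical sketch of the query shape under test, assuming a Postgres enum type named "CreditTransactionType" (illustrative only, not the repository's exact statement):

    # Hypothetical sketch: without the explicit cast, Postgres rejects a text
    # parameter where the enum-typed column expects a "CreditTransactionType".
    INSERT_TX_SQL = """
    INSERT INTO "CreditTransaction" ("userId", "amount", "type")
    VALUES ($1, $2, $3::"CreditTransactionType")
    """

    async def insert_transaction(tx, user_id: str, amount: int, tx_type) -> None:
        # tx_type.value is the enum's string form; the ::"CreditTransactionType"
        # cast converts it to the Postgres enum type of the column.
        await tx.execute_raw(INSERT_TX_SQL, user_id, amount, tx_type.value)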
autogpt_platform/backend/backend/data/credit_metadata_test.py (new file, 141 lines)
@@ -0,0 +1,141 @@
"""
|
||||
Tests for credit system metadata handling to ensure JSON casting works correctly.
|
||||
|
||||
This test verifies that metadata parameters are properly serialized when passed
|
||||
to raw SQL queries with JSONB columns.
|
||||
"""
|
||||
|
||||
# type: ignore
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.models import CreditTransaction, UserBalance
|
||||
|
||||
from backend.data.credit import BetaUserCredit
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_test_user():
|
||||
"""Setup test user and cleanup after test."""
|
||||
user_id = DEFAULT_USER_ID
|
||||
|
||||
# Cleanup before test
|
||||
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
|
||||
await UserBalance.prisma().delete_many(where={"userId": user_id})
|
||||
|
||||
yield user_id
|
||||
|
||||
# Cleanup after test
|
||||
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
|
||||
await UserBalance.prisma().delete_many(where={"userId": user_id})
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_metadata_json_serialization(setup_test_user):
|
||||
"""Test that metadata is properly serialized for JSONB column in raw SQL."""
|
||||
user_id = setup_test_user
|
||||
credit_system = BetaUserCredit(1000)
|
||||
|
||||
# Test with complex metadata that would fail if not properly serialized
|
||||
complex_metadata = SafeJson(
|
||||
{
|
||||
"graph_exec_id": "test-12345",
|
||||
"reason": "Testing metadata serialization",
|
||||
"nested_data": {
|
||||
"key1": "value1",
|
||||
"key2": ["array", "of", "values"],
|
||||
"key3": {"deeply": {"nested": "object"}},
|
||||
},
|
||||
"special_chars": "Testing 'quotes' and \"double quotes\" and unicode: 🚀",
|
||||
}
|
||||
)
|
||||
|
||||
# This should work without throwing a JSONB casting error
|
||||
balance, tx_key = await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=500, # $5 top-up
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
metadata=complex_metadata,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
# Verify the transaction was created successfully
|
||||
assert balance == 500
|
||||
|
||||
# Verify the metadata was stored correctly in the database
|
||||
transaction = await CreditTransaction.prisma().find_first(
|
||||
where={"userId": user_id, "transactionKey": tx_key}
|
||||
)
|
||||
|
||||
assert transaction is not None
|
||||
assert transaction.metadata is not None
|
||||
|
||||
# Verify the metadata contains our complex data
|
||||
metadata_dict: dict[str, Any] = dict(transaction.metadata) # type: ignore
|
||||
assert metadata_dict["graph_exec_id"] == "test-12345"
|
||||
assert metadata_dict["reason"] == "Testing metadata serialization"
|
||||
assert metadata_dict["nested_data"]["key1"] == "value1"
|
||||
assert metadata_dict["nested_data"]["key3"]["deeply"]["nested"] == "object"
|
||||
assert (
|
||||
metadata_dict["special_chars"]
|
||||
== "Testing 'quotes' and \"double quotes\" and unicode: 🚀"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_enable_transaction_metadata_serialization(setup_test_user):
|
||||
"""Test that _enable_transaction also handles metadata JSON serialization correctly."""
|
||||
user_id = setup_test_user
|
||||
credit_system = BetaUserCredit(1000)
|
||||
|
||||
# First create an inactive transaction
|
||||
balance, tx_key = await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=300,
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
metadata=SafeJson({"initial": "inactive_transaction"}),
|
||||
is_active=False, # Create as inactive
|
||||
)
|
||||
|
||||
# Initial balance should be 0 because transaction is inactive
|
||||
assert balance == 0
|
||||
|
||||
# Now enable the transaction with new metadata
|
||||
enable_metadata = SafeJson(
|
||||
{
|
||||
"payment_method": "stripe",
|
||||
"payment_intent": "pi_test_12345",
|
||||
"activation_reason": "Payment confirmed",
|
||||
"complex_data": {"array": [1, 2, 3], "boolean": True, "null_value": None},
|
||||
}
|
||||
)
|
||||
|
||||
# This should work without JSONB casting errors
|
||||
final_balance = await credit_system._enable_transaction(
|
||||
transaction_key=tx_key,
|
||||
user_id=user_id,
|
||||
metadata=enable_metadata,
|
||||
)
|
||||
|
||||
# Now balance should reflect the activated transaction
|
||||
assert final_balance == 300
|
||||
|
||||
# Verify the metadata was updated correctly
|
||||
transaction = await CreditTransaction.prisma().find_first(
|
||||
where={"userId": user_id, "transactionKey": tx_key}
|
||||
)
|
||||
|
||||
assert transaction is not None
|
||||
assert transaction.isActive is True
|
||||
|
||||
# Verify the metadata was updated with enable_metadata
|
||||
metadata_dict: dict[str, Any] = dict(transaction.metadata) # type: ignore
|
||||
assert metadata_dict["payment_method"] == "stripe"
|
||||
assert metadata_dict["payment_intent"] == "pi_test_12345"
|
||||
assert metadata_dict["complex_data"]["array"] == [1, 2, 3]
|
||||
assert metadata_dict["complex_data"]["boolean"] is True
|
||||
assert metadata_dict["complex_data"]["null_value"] is None
|
||||
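These tests exercise the matching JSONB path: the dict must be serialized to a JSON string before it is bound to a $n parameter. A minimal sketch of that step, assuming a server-side ::jsonb cast (illustrative shape; SafeJson plays the serialization role in the real code):

    import json

    UPDATE_META_SQL = """
    UPDATE "CreditTransaction" SET "metadata" = $1::jsonb WHERE "transactionKey" = $2
    """

    async def set_metadata(tx, tx_key: str, metadata: dict) -> None:
        # json.dumps handles nested objects, quotes, and unicode such as 🚀;
        # the ::jsonb cast turns the bound text into a JSONB value.
        await tx.execute_raw(UPDATE_META_SQL, json.dumps(metadata), tx_key)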
autogpt_platform/backend/backend/data/credit_refund_test.py (new file, 372 lines)
@@ -0,0 +1,372 @@
"""
|
||||
Tests for credit system refund and dispute operations.
|
||||
|
||||
These tests ensure that refund operations (deduct_credits, handle_dispute)
|
||||
are atomic and maintain data consistency.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import stripe
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
credit_system = UserCredit()
|
||||
|
||||
# Test user ID for refund tests
|
||||
REFUND_TEST_USER_ID = "refund-test-user"
|
||||
|
||||
|
||||
async def setup_test_user_with_topup():
|
||||
"""Create a test user with initial balance and a top-up transaction."""
|
||||
# Clean up any existing data
|
||||
await CreditRefundRequest.prisma().delete_many(
|
||||
where={"userId": REFUND_TEST_USER_ID}
|
||||
)
|
||||
await CreditTransaction.prisma().delete_many(where={"userId": REFUND_TEST_USER_ID})
|
||||
await UserBalance.prisma().delete_many(where={"userId": REFUND_TEST_USER_ID})
|
||||
await User.prisma().delete_many(where={"id": REFUND_TEST_USER_ID})
|
||||
|
||||
# Create user
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": REFUND_TEST_USER_ID,
|
||||
"email": f"{REFUND_TEST_USER_ID}@example.com",
|
||||
"name": "Refund Test User",
|
||||
}
|
||||
)
|
||||
|
||||
# Create user balance
|
||||
await UserBalance.prisma().create(
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"balance": 1000, # $10
|
||||
}
|
||||
)
|
||||
|
||||
# Create a top-up transaction that can be refunded
|
||||
topup_tx = await CreditTransaction.prisma().create(
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 1000,
|
||||
"type": CreditTransactionType.TOP_UP,
|
||||
"transactionKey": "pi_test_12345",
|
||||
"runningBalance": 1000,
|
||||
"isActive": True,
|
||||
"metadata": SafeJson({"stripe_payment_intent": "pi_test_12345"}),
|
||||
}
|
||||
)
|
||||
|
||||
return topup_tx
|
||||
|
||||
|
||||
async def cleanup_test_user():
|
||||
"""Clean up test data."""
|
||||
await CreditRefundRequest.prisma().delete_many(
|
||||
where={"userId": REFUND_TEST_USER_ID}
|
||||
)
|
||||
await CreditTransaction.prisma().delete_many(where={"userId": REFUND_TEST_USER_ID})
|
||||
await UserBalance.prisma().delete_many(where={"userId": REFUND_TEST_USER_ID})
|
||||
await User.prisma().delete_many(where={"id": REFUND_TEST_USER_ID})
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_deduct_credits_atomic(server: SpinTestServer):
    """Test that deduct_credits is atomic and creates the transaction correctly."""
    topup_tx = await setup_test_user_with_topup()

    try:
        # Create a mock refund object
        refund = MagicMock(spec=stripe.Refund)
        refund.id = "re_test_refund_123"
        refund.payment_intent = topup_tx.transactionKey
        refund.amount = 500  # Refund $5 of the $10 top-up
        refund.status = "succeeded"
        refund.reason = "requested_by_customer"
        refund.created = int(datetime.now(timezone.utc).timestamp())

        # Create the refund request record (simulating the webhook flow)
        await CreditRefundRequest.prisma().create(
            data={
                "userId": REFUND_TEST_USER_ID,
                "amount": 500,
                "transactionKey": topup_tx.transactionKey,  # Should match the original transaction
                "reason": "Test refund",
            }
        )

        # Call deduct_credits
        await credit_system.deduct_credits(refund)

        # Verify the user's balance was deducted
        user_balance = await UserBalance.prisma().find_unique(
            where={"userId": REFUND_TEST_USER_ID}
        )
        assert user_balance is not None
        assert (
            user_balance.balance == 500
        ), f"Expected balance 500, got {user_balance.balance}"

        # Verify the refund transaction was created
        refund_tx = await CreditTransaction.prisma().find_first(
            where={
                "userId": REFUND_TEST_USER_ID,
                "type": CreditTransactionType.REFUND,
                "transactionKey": refund.id,
            }
        )
        assert refund_tx is not None
        assert refund_tx.amount == -500
        assert refund_tx.runningBalance == 500
        assert refund_tx.isActive

        # Verify the refund request was updated
        refund_request = await CreditRefundRequest.prisma().find_first(
            where={
                "userId": REFUND_TEST_USER_ID,
                "transactionKey": topup_tx.transactionKey,
            }
        )
        assert refund_request is not None
        assert (
            refund_request.result
            == "The refund request has been approved, the amount will be credited back to your account."
        )

    finally:
        await cleanup_test_user()
@pytest.mark.asyncio(loop_scope="session")
async def test_deduct_credits_user_not_found(server: SpinTestServer):
    """Test that deduct_credits raises an error when the transaction is not found (i.e. the user doesn't exist)."""
    # Create a mock refund object that references a non-existent payment intent
    refund = MagicMock(spec=stripe.Refund)
    refund.id = "re_test_refund_nonexistent"
    refund.payment_intent = "pi_test_nonexistent"  # This payment intent doesn't exist
    refund.amount = 500
    refund.status = "succeeded"
    refund.reason = "requested_by_customer"
    refund.created = int(datetime.now(timezone.utc).timestamp())

    # Should raise an error for the missing transaction
    with pytest.raises(Exception):  # Should raise NotFoundError for the missing transaction
        await credit_system.deduct_credits(refund)
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.data.credit.settings")
@patch("stripe.Dispute.modify")
@patch("backend.data.credit.get_user_by_id")
async def test_handle_dispute_with_sufficient_balance(
    mock_get_user, mock_stripe_modify, mock_settings, server: SpinTestServer
):
    """Test handling a dispute when the user has sufficient balance (the dispute gets closed)."""
    topup_tx = await setup_test_user_with_topup()

    try:
        # Mock settings to have a low tolerance threshold
        mock_settings.config.refund_credit_tolerance_threshold = 0

        # Mock the user lookup
        mock_user = MagicMock()
        mock_user.email = f"{REFUND_TEST_USER_ID}@example.com"
        mock_get_user.return_value = mock_user

        # Create a mock dispute object for a small amount (user has 1000, disputing 100)
        dispute = MagicMock(spec=stripe.Dispute)
        dispute.id = "dp_test_dispute_123"
        dispute.payment_intent = topup_tx.transactionKey
        dispute.amount = 100  # Small dispute amount
        dispute.status = "pending"
        dispute.reason = "fraudulent"
        dispute.created = int(datetime.now(timezone.utc).timestamp())

        # Mock the close method to prevent real API calls
        dispute.close = MagicMock()

        # Handle the dispute
        await credit_system.handle_dispute(dispute)

        # Verify dispute.close() was called (since the user has sufficient balance)
        dispute.close.assert_called_once()

        # Verify no Stripe evidence was added since the dispute was closed
        mock_stripe_modify.assert_not_called()

        # Verify the user's balance was NOT deducted (the dispute was closed)
        user_balance = await UserBalance.prisma().find_unique(
            where={"userId": REFUND_TEST_USER_ID}
        )
        assert user_balance is not None
        assert (
            user_balance.balance == 1000
        ), f"Balance should remain 1000, got {user_balance.balance}"

    finally:
        await cleanup_test_user()
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.data.credit.settings")
@patch("stripe.Dispute.modify")
@patch("backend.data.credit.get_user_by_id")
async def test_handle_dispute_with_insufficient_balance(
    mock_get_user, mock_stripe_modify, mock_settings, server: SpinTestServer
):
    """Test handling a dispute when the user has insufficient balance (evidence gets added)."""
    topup_tx = await setup_test_user_with_topup()

    # Save the original method for restoration before any try blocks
    original_get_history = credit_system.get_transaction_history

    try:
        # Mock settings to have a high tolerance threshold so the dispute isn't closed
        mock_settings.config.refund_credit_tolerance_threshold = 2000

        # Mock the user lookup
        mock_user = MagicMock()
        mock_user.email = f"{REFUND_TEST_USER_ID}@example.com"
        mock_get_user.return_value = mock_user

        # Mock the transaction history method to return an async result
        from unittest.mock import AsyncMock

        mock_history = MagicMock()
        mock_history.transactions = []
        credit_system.get_transaction_history = AsyncMock(return_value=mock_history)

        # Create a mock dispute object for the full amount (user has 1000, disputing 1000)
        dispute = MagicMock(spec=stripe.Dispute)
        dispute.id = "dp_test_dispute_pending"
        dispute.payment_intent = topup_tx.transactionKey
        dispute.amount = 1000
        dispute.status = "warning_needs_response"
        dispute.created = int(datetime.now(timezone.utc).timestamp())

        # Mock the close method to prevent real API calls
        dispute.close = MagicMock()

        # Handle the dispute (evidence should be added)
        await credit_system.handle_dispute(dispute)

        # Verify dispute.close() was NOT called (insufficient balance after tolerance)
        dispute.close.assert_not_called()

        # Verify Stripe evidence was added since the dispute wasn't closed
        mock_stripe_modify.assert_called_once()

        # Verify the user's balance was NOT deducted (handle_dispute doesn't deduct credits)
        user_balance = await UserBalance.prisma().find_unique(
            where={"userId": REFUND_TEST_USER_ID}
        )
        assert user_balance is not None
        assert user_balance.balance == 1000, "Balance should remain unchanged"

    finally:
        credit_system.get_transaction_history = original_get_history
        await cleanup_test_user()
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_refunds(server: SpinTestServer):
    """Test that concurrent refunds are handled atomically."""
    import asyncio

    topup_tx = await setup_test_user_with_topup()

    try:
        # Create multiple refund requests
        refund_requests = []
        for i in range(5):
            req = await CreditRefundRequest.prisma().create(
                data={
                    "userId": REFUND_TEST_USER_ID,
                    "amount": 100,  # $1 each
                    "transactionKey": topup_tx.transactionKey,
                    "reason": f"Test refund {i}",
                }
            )
            refund_requests.append(req)

        # Create refund tasks to run concurrently
        async def process_refund(index: int):
            refund = MagicMock(spec=stripe.Refund)
            refund.id = f"re_test_concurrent_{index}"
            refund.payment_intent = topup_tx.transactionKey
            refund.amount = 100  # $1 refund
            refund.status = "succeeded"
            refund.reason = "requested_by_customer"
            refund.created = int(datetime.now(timezone.utc).timestamp())

            try:
                await credit_system.deduct_credits(refund)
                return "success"
            except Exception as e:
                return f"error: {e}"

        # Run refunds concurrently
        results = await asyncio.gather(
            *[process_refund(i) for i in range(5)], return_exceptions=True
        )

        # All should succeed
        assert all(r == "success" for r in results), f"Some refunds failed: {results}"

        # Verify the final balance. A non-atomic implementation would lose
        # updates under this load and leave the balance too high; an atomic
        # implementation processes all 5 refunds: 1000 - 5 * 100 = 500.
        user_balance = await UserBalance.prisma().find_unique(
            where={"userId": REFUND_TEST_USER_ID}
        )
        assert user_balance is not None
        assert (
            user_balance.balance == 500
        ), f"Expected balance 500 after 5 refunds of 100 each, got {user_balance.balance}"

        # Verify all refund transactions exist
        refund_txs = await CreditTransaction.prisma().find_many(
            where={
                "userId": REFUND_TEST_USER_ID,
                "type": CreditTransactionType.REFUND,
            }
        )
        assert (
            len(refund_txs) == 5
        ), f"Expected 5 refund transactions, got {len(refund_txs)}"

        running_balances: set[int] = {
            tx.runningBalance for tx in refund_txs if tx.runningBalance is not None
        }

        # Verify all balances are valid intermediate states
        for balance in running_balances:
            assert (
                500 <= balance <= 1000
            ), f"Invalid balance {balance}, should be between 500 and 1000"

        # The final balance should be present
        assert (
            500 in running_balances
        ), f"Final balance 500 should be in {running_balances}"

        # All balances should be unique and form a valid sequence
        sorted_balances = sorted(running_balances, reverse=True)
        assert (
            len(sorted_balances) == 5
        ), f"Expected 5 unique balances, got {len(sorted_balances)}"

    finally:
        await cleanup_test_user()
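The concurrency assertions above (five distinct runningBalance values, final balance 500) can only hold if each deduction reads and writes the balance in a single atomic statement rather than a read-modify-write round trip in Python. A minimal sketch of such an update, assuming a Postgres "UserBalance" table (illustrative, not the repository's implementation):

    DEDUCT_SQL = """
    UPDATE "UserBalance"
    SET "balance" = "balance" - $2
    WHERE "userId" = $1
    RETURNING "balance"
    """

    async def deduct(tx, user_id: str, amount: int) -> int:
        # A single UPDATE ... RETURNING lets Postgres serialize concurrent
        # deductions on the row lock, so no decrement can be lost.
        rows = await tx.query_raw(DEDUCT_SQL, user_id, amount)
        return rows[0]["balance"]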
@@ -1,8 +1,8 @@
-from datetime import datetime, timezone
+from datetime import datetime, timedelta, timezone
 
 import pytest
 from prisma.enums import CreditTransactionType
-from prisma.models import CreditTransaction
+from prisma.models import CreditTransaction, UserBalance
 
 from backend.blocks.llm import AITextGeneratorBlock
 from backend.data.block import get_block
@@ -19,14 +19,24 @@ user_credit = BetaUserCredit(REFILL_VALUE)
 
 async def disable_test_user_transactions():
     await CreditTransaction.prisma().delete_many(where={"userId": DEFAULT_USER_ID})
+    # Also reset the balance to 0 and set updatedAt to an old date to trigger a monthly refill
+    old_date = datetime.now(timezone.utc) - timedelta(days=35)  # More than a month ago
+    await UserBalance.prisma().upsert(
+        where={"userId": DEFAULT_USER_ID},
+        data={
+            "create": {"userId": DEFAULT_USER_ID, "balance": 0},
+            "update": {"balance": 0, "updatedAt": old_date},
+        },
+    )
 
 
 async def top_up(amount: int):
-    await user_credit._add_transaction(
+    balance, _ = await user_credit._add_transaction(
         DEFAULT_USER_ID,
         amount,
         CreditTransactionType.TOP_UP,
     )
+    return balance
 
 
 async def spend_credits(entry: NodeExecutionEntry) -> int:
@@ -111,29 +121,90 @@ async def test_block_credit_top_up(server: SpinTestServer):
 
 @pytest.mark.asyncio(loop_scope="session")
 async def test_block_credit_reset(server: SpinTestServer):
+    """Test that BetaUserCredit provides monthly refills correctly."""
     await disable_test_user_transactions()
-    month1 = 1
-    month2 = 2
-
-    # set the calendar to month 2 but use current time from now
-    user_credit.time_now = lambda: datetime.now(timezone.utc).replace(
-        month=month2, day=1
-    )
-    month2credit = await user_credit.get_credits(DEFAULT_USER_ID)
+    # Save the original time_now function for restoration
+    original_time_now = user_credit.time_now
 
-    # Month 1 result should only affect month 1
-    user_credit.time_now = lambda: datetime.now(timezone.utc).replace(
-        month=month1, day=1
-    )
-    month1credit = await user_credit.get_credits(DEFAULT_USER_ID)
-    await top_up(100)
-    assert await user_credit.get_credits(DEFAULT_USER_ID) == month1credit + 100
+    try:
+        # Test month 1 behavior
+        month1 = datetime.now(timezone.utc).replace(month=1, day=1)
+        user_credit.time_now = lambda: month1
 
-    # Month 2 balance is unaffected
-    user_credit.time_now = lambda: datetime.now(timezone.utc).replace(
-        month=month2, day=1
-    )
-    assert await user_credit.get_credits(DEFAULT_USER_ID) == month2credit
+        # First call in month 1 should trigger a refill
+        balance = await user_credit.get_credits(DEFAULT_USER_ID)
+        assert balance == REFILL_VALUE  # Should get 1000 credits
+
+        # Manually create a transaction with a month 1 timestamp to establish history
+        await CreditTransaction.prisma().create(
+            data={
+                "userId": DEFAULT_USER_ID,
+                "amount": 100,
+                "type": CreditTransactionType.TOP_UP,
+                "runningBalance": 1100,
+                "isActive": True,
+                "createdAt": month1,  # Set specific timestamp
+            }
+        )
+
+        # Update the user balance to match
+        await UserBalance.prisma().upsert(
+            where={"userId": DEFAULT_USER_ID},
+            data={
+                "create": {"userId": DEFAULT_USER_ID, "balance": 1100},
+                "update": {"balance": 1100},
+            },
+        )
+
+        # Now test month 2 behavior
+        month2 = datetime.now(timezone.utc).replace(month=2, day=1)
+        user_credit.time_now = lambda: month2
+
+        # In month 2, since balance (1100) > refill (1000), no refill should happen
+        month2_balance = await user_credit.get_credits(DEFAULT_USER_ID)
+        assert month2_balance == 1100  # Balance persists, no reset
+
+        # Now test the refill behavior when the balance is low:
+        # set the balance below the refill threshold
+        await UserBalance.prisma().update(
+            where={"userId": DEFAULT_USER_ID}, data={"balance": 400}
+        )
+
+        # Create a month 2 transaction to update the last transaction time
+        await CreditTransaction.prisma().create(
+            data={
+                "userId": DEFAULT_USER_ID,
+                "amount": -700,  # Spent 700 to get to 400
+                "type": CreditTransactionType.USAGE,
+                "runningBalance": 400,
+                "isActive": True,
+                "createdAt": month2,
+            }
+        )
+
+        # Move to month 3
+        month3 = datetime.now(timezone.utc).replace(month=3, day=1)
+        user_credit.time_now = lambda: month3
+
+        # Should get refilled since balance (400) < refill value (1000)
+        month3_balance = await user_credit.get_credits(DEFAULT_USER_ID)
+        assert month3_balance == REFILL_VALUE  # Should be refilled to 1000
+
+        # Verify the refill transaction was created
+        refill_tx = await CreditTransaction.prisma().find_first(
+            where={
+                "userId": DEFAULT_USER_ID,
+                "type": CreditTransactionType.GRANT,
+                "transactionKey": {"contains": "MONTHLY-CREDIT-TOP-UP"},
+            },
+            order={"createdAt": "desc"},
+        )
+        assert refill_tx is not None, "Monthly refill transaction should be created"
+        assert refill_tx.amount == 600, "Refill should be 600 (1000 - 400)"
+    finally:
+        # Restore the original time_now function
+        user_credit.time_now = original_time_now
 
 
 @pytest.mark.asyncio(loop_scope="session")
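The rewritten reset test pins user_credit.time_now to a fixed-clock lambda and restores it in a finally block so other session-scoped tests are unaffected. The same pattern in isolation (a sketch, assuming time_now is a plain attribute returning an aware datetime):

    from datetime import datetime, timezone

    original_time_now = user_credit.time_now
    try:
        # Pin the clock to a deterministic instant for the assertions.
        fixed = datetime.now(timezone.utc).replace(month=1, day=1)
        user_credit.time_now = lambda: fixed
        # ... exercise the refill logic against the pinned month ...
    finally:
        # Always restore so later tests see the real clock again.
        user_credit.time_now = original_time_now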
autogpt_platform/backend/backend/data/credit_underflow_test.py (new file, 361 lines)
@@ -0,0 +1,361 @@
"""
|
||||
Test underflow protection for cumulative refunds and negative transactions.
|
||||
|
||||
This test ensures that when multiple large refunds are processed, the user balance
|
||||
doesn't underflow below POSTGRES_INT_MIN, which could cause integer wraparound issues.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
|
||||
from backend.data.credit import POSTGRES_INT_MIN, UserCredit
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
|
||||
async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user for underflow tests."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
pass
|
||||
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
)
|
||||
|
||||
|
||||
async def cleanup_test_user(user_id: str) -> None:
|
||||
"""Clean up test user and their transactions."""
|
||||
try:
|
||||
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
|
||||
await UserBalance.prisma().delete_many(where={"userId": user_id})
|
||||
await User.prisma().delete_many(where={"id": user_id})
|
||||
except Exception as e:
|
||||
# Log cleanup failures but don't fail the test
|
||||
print(f"Warning: Failed to cleanup test user {user_id}: {e}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_debug_underflow_step_by_step(server: SpinTestServer):
    """Debug underflow behavior step by step."""
    credit_system = UserCredit()
    user_id = f"debug-underflow-{uuid4()}"
    await create_test_user(user_id)

    try:
        print(f"POSTGRES_INT_MIN: {POSTGRES_INT_MIN}")

        # Test 1: Set up a balance close to the underflow threshold
        print("\n=== Test 1: Setting up balance close to underflow threshold ===")
        # First, manually set the balance to a value very close to POSTGRES_INT_MIN.
        # We'll set it to POSTGRES_INT_MIN + 100, then try to subtract 200.
        # This should trigger underflow protection: (POSTGRES_INT_MIN + 100) + (-200) = POSTGRES_INT_MIN - 100
        initial_balance_target = POSTGRES_INT_MIN + 100

        # Use a direct database update to set the balance close to underflow
        from prisma.models import UserBalance

        await UserBalance.prisma().upsert(
            where={"userId": user_id},
            data={
                "create": {"userId": user_id, "balance": initial_balance_target},
                "update": {"balance": initial_balance_target},
            },
        )

        current_balance = await credit_system.get_credits(user_id)
        print(f"Set balance to: {current_balance}")
        assert current_balance == initial_balance_target

        # Test 2: Apply an amount that should cause underflow
        print("\n=== Test 2: Testing underflow protection ===")
        test_amount = (
            -200
        )  # This should cause underflow: (POSTGRES_INT_MIN + 100) + (-200) = POSTGRES_INT_MIN - 100
        expected_without_protection = current_balance + test_amount
        print(f"Current balance: {current_balance}")
        print(f"Test amount: {test_amount}")
        print(f"Without protection would be: {expected_without_protection}")
        print(f"Should be clamped to POSTGRES_INT_MIN: {POSTGRES_INT_MIN}")

        # Apply the amount that should trigger underflow protection
        balance_result, _ = await credit_system._add_transaction(
            user_id=user_id,
            amount=test_amount,
            transaction_type=CreditTransactionType.REFUND,
            fail_insufficient_credits=False,
        )
        print(f"Actual result: {balance_result}")

        # Check if underflow protection worked
        assert (
            balance_result == POSTGRES_INT_MIN
        ), f"Expected underflow protection to clamp balance to {POSTGRES_INT_MIN}, got {balance_result}"

        # Test 3: Edge case - exactly at POSTGRES_INT_MIN
        print("\n=== Test 3: Testing exact POSTGRES_INT_MIN boundary ===")
        # Set the balance to exactly POSTGRES_INT_MIN
        await UserBalance.prisma().upsert(
            where={"userId": user_id},
            data={
                "create": {"userId": user_id, "balance": POSTGRES_INT_MIN},
                "update": {"balance": POSTGRES_INT_MIN},
            },
        )

        edge_balance = await credit_system.get_credits(user_id)
        print(f"Balance set to exactly POSTGRES_INT_MIN: {edge_balance}")

        # Try to subtract 1 - should stay at POSTGRES_INT_MIN
        edge_result, _ = await credit_system._add_transaction(
            user_id=user_id,
            amount=-1,
            transaction_type=CreditTransactionType.REFUND,
            fail_insufficient_credits=False,
        )
        print(f"After subtracting 1: {edge_result}")

        assert (
            edge_result == POSTGRES_INT_MIN
        ), f"Expected balance to remain clamped at {POSTGRES_INT_MIN}, got {edge_result}"

    finally:
        await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_underflow_protection_large_refunds(server: SpinTestServer):
    """Test that large cumulative refunds don't cause integer underflow."""
    credit_system = UserCredit()
    user_id = f"underflow-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Set up a balance close to the underflow threshold to test the protection:
        # set the balance to POSTGRES_INT_MIN + 1000, then try to subtract 2000.
        # This should trigger underflow protection.
        from prisma.models import UserBalance

        test_balance = POSTGRES_INT_MIN + 1000
        await UserBalance.prisma().upsert(
            where={"userId": user_id},
            data={
                "create": {"userId": user_id, "balance": test_balance},
                "update": {"balance": test_balance},
            },
        )

        current_balance = await credit_system.get_credits(user_id)
        assert current_balance == test_balance

        # Try to deduct an amount that would cause underflow: test_balance + (-2000) = POSTGRES_INT_MIN - 1000
        underflow_amount = -2000
        expected_without_protection = (
            current_balance + underflow_amount
        )  # Would be POSTGRES_INT_MIN - 1000

        # Use _add_transaction directly with an amount that would cause underflow
        final_balance, _ = await credit_system._add_transaction(
            user_id=user_id,
            amount=underflow_amount,
            transaction_type=CreditTransactionType.REFUND,
            fail_insufficient_credits=False,  # Allow going negative for refunds
        )

        # Balance should be clamped to POSTGRES_INT_MIN, not the calculated underflow value
        assert (
            final_balance == POSTGRES_INT_MIN
        ), f"Balance should be clamped to {POSTGRES_INT_MIN}, got {final_balance}"
        assert (
            final_balance > expected_without_protection
        ), f"Balance should be greater than underflow result {expected_without_protection}, got {final_balance}"

        # Verify with get_credits too
        stored_balance = await credit_system.get_credits(user_id)
        assert (
            stored_balance == POSTGRES_INT_MIN
        ), f"Stored balance should be {POSTGRES_INT_MIN}, got {stored_balance}"

        # Verify the transaction was created with the underflow-protected balance
        transactions = await CreditTransaction.prisma().find_many(
            where={"userId": user_id, "type": CreditTransactionType.REFUND},
            order={"createdAt": "desc"},
        )
        assert len(transactions) > 0, "Refund transaction should be created"
        assert (
            transactions[0].runningBalance == POSTGRES_INT_MIN
        ), f"Transaction should show clamped balance {POSTGRES_INT_MIN}, got {transactions[0].runningBalance}"

    finally:
        await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServer):
    """Test that multiple large refunds applied sequentially don't cause underflow."""
    credit_system = UserCredit()
    user_id = f"cumulative-underflow-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Set up a balance close to the underflow threshold
        from prisma.models import UserBalance

        initial_balance = POSTGRES_INT_MIN + 500  # Close to the minimum but with some room
        await UserBalance.prisma().upsert(
            where={"userId": user_id},
            data={
                "create": {"userId": user_id, "balance": initial_balance},
                "update": {"balance": initial_balance},
            },
        )

        # Apply multiple refunds that would cumulatively underflow
        refund_amount = -300  # Each refund would cause underflow once cumulative

        # First refund: (POSTGRES_INT_MIN + 500) + (-300) = POSTGRES_INT_MIN + 200 (still above minimum)
        balance_1, _ = await credit_system._add_transaction(
            user_id=user_id,
            amount=refund_amount,
            transaction_type=CreditTransactionType.REFUND,
            fail_insufficient_credits=False,
        )

        # Should be above the minimum for the first refund
        expected_balance_1 = (
            initial_balance + refund_amount
        )  # Should be POSTGRES_INT_MIN + 200
        assert (
            balance_1 == expected_balance_1
        ), f"First refund should result in {expected_balance_1}, got {balance_1}"
        assert (
            balance_1 >= POSTGRES_INT_MIN
        ), f"First refund should not go below {POSTGRES_INT_MIN}, got {balance_1}"

        # Second refund: (POSTGRES_INT_MIN + 200) + (-300) = POSTGRES_INT_MIN - 100 (would underflow)
        balance_2, _ = await credit_system._add_transaction(
            user_id=user_id,
            amount=refund_amount,
            transaction_type=CreditTransactionType.REFUND,
            fail_insufficient_credits=False,
        )

        # Should be clamped to the minimum due to underflow protection
        assert (
            balance_2 == POSTGRES_INT_MIN
        ), f"Second refund should be clamped to {POSTGRES_INT_MIN}, got {balance_2}"

        # Third refund: should stay at the minimum
        balance_3, _ = await credit_system._add_transaction(
            user_id=user_id,
            amount=refund_amount,
            transaction_type=CreditTransactionType.REFUND,
            fail_insufficient_credits=False,
        )

        # Should still be at the minimum
        assert (
            balance_3 == POSTGRES_INT_MIN
        ), f"Third refund should stay at {POSTGRES_INT_MIN}, got {balance_3}"

        # Final balance check
        final_balance = await credit_system.get_credits(user_id)
        assert (
            final_balance == POSTGRES_INT_MIN
        ), f"Final balance should be {POSTGRES_INT_MIN}, got {final_balance}"

    finally:
        await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_large_refunds_no_underflow(server: SpinTestServer):
    """Test that concurrent large refunds don't cause race-condition underflow."""
    credit_system = UserCredit()
    user_id = f"concurrent-underflow-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Set up a balance close to the underflow threshold
        from prisma.models import UserBalance

        initial_balance = POSTGRES_INT_MIN + 1000  # Close to the minimum
        await UserBalance.prisma().upsert(
            where={"userId": user_id},
            data={
                "create": {"userId": user_id, "balance": initial_balance},
                "update": {"balance": initial_balance},
            },
        )

        async def large_refund(amount: int, label: str):
            try:
                return await credit_system._add_transaction(
                    user_id=user_id,
                    amount=-amount,
                    transaction_type=CreditTransactionType.REFUND,
                    fail_insufficient_credits=False,
                )
            except Exception as e:
                return f"FAILED-{label}: {e}"

        # Run concurrent refunds that would cause underflow if not protected:
        # three refunds of 500 would cumulatively push the balance below POSTGRES_INT_MIN
        refund_amount = 500
        results = await asyncio.gather(
            large_refund(refund_amount, "A"),
            large_refund(refund_amount, "B"),
            large_refund(refund_amount, "C"),
            return_exceptions=True,
        )

        # Check all results are valid and no underflow occurred
        valid_results = []
        for i, result in enumerate(results):
            if isinstance(result, tuple):
                balance, _ = result
                assert (
                    balance >= POSTGRES_INT_MIN
                ), f"Result {i} balance {balance} underflowed below {POSTGRES_INT_MIN}"
                valid_results.append(balance)
            elif isinstance(result, str) and "FAILED" in result:
                # Some operations might fail due to validation, and that's okay
                pass
            else:
                # Unexpected exception
                assert not isinstance(
                    result, Exception
                ), f"Unexpected exception in result {i}: {result}"

        # At least one operation should succeed
        assert (
            len(valid_results) > 0
        ), f"At least one refund should succeed, got results: {results}"

        # All successful results should be >= POSTGRES_INT_MIN
        for balance in valid_results:
            assert (
                balance >= POSTGRES_INT_MIN
            ), f"Balance {balance} should not be below {POSTGRES_INT_MIN}"

        # The final balance should be valid and at or above POSTGRES_INT_MIN
        final_balance = await credit_system.get_credits(user_id)
        assert (
            final_balance >= POSTGRES_INT_MIN
        ), f"Final balance {final_balance} should not underflow below {POSTGRES_INT_MIN}"

    finally:
        await cleanup_test_user(user_id)
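The clamping these tests assert can be expressed in a single SQL statement. A minimal sketch using GREATEST, assuming the balance column is a Postgres 4-byte integer whose lower bound is POSTGRES_INT_MIN = -2147483648 (illustrative, not the repository's exact statement):

    POSTGRES_INT_MIN = -2147483648  # lower bound of a Postgres 4-byte integer

    CLAMPED_UPDATE_SQL = """
    UPDATE "UserBalance"
    SET "balance" = GREATEST("balance" + $2, $3)
    WHERE "userId" = $1
    RETURNING "balance"
    """

    async def apply_delta_clamped(tx, user_id: str, delta: int) -> int:
        # GREATEST keeps the stored value at or above POSTGRES_INT_MIN even
        # when cumulative refunds would otherwise wrap the 32-bit integer.
        rows = await tx.query_raw(CLAMPED_UPDATE_SQL, user_id, delta, POSTGRES_INT_MIN)
        return rows[0]["balance"]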
@@ -0,0 +1,217 @@
"""
Integration test to verify the complete migration from User.balance to the UserBalance table.

This test ensures that:
1. No User.balance queries exist in the system
2. All balance operations go through the UserBalance table
3. User and UserBalance stay synchronized properly
"""

import asyncio
from datetime import datetime

import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance

from backend.data.credit import UsageTransactionMetadata, UserCredit
from backend.util.json import SafeJson
from backend.util.test import SpinTestServer


async def create_test_user(user_id: str) -> None:
    """Create a test user for migration tests."""
    try:
        await User.prisma().create(
            data={
                "id": user_id,
                "email": f"test-{user_id}@example.com",
                "name": f"Test User {user_id[:8]}",
            }
        )
    except UniqueViolationError:
        # User already exists, continue
        pass


async def cleanup_test_user(user_id: str) -> None:
    """Clean up the test user and their data."""
    try:
        await CreditTransaction.prisma().delete_many(where={"userId": user_id})
        await UserBalance.prisma().delete_many(where={"userId": user_id})
        await User.prisma().delete_many(where={"id": user_id})
    except Exception as e:
        # Log cleanup failures but don't fail the test
        print(f"Warning: Failed to cleanup test user {user_id}: {e}")
@pytest.mark.asyncio(loop_scope="session")
async def test_user_balance_migration_complete(server: SpinTestServer):
    """Test that the User table balance is never used and UserBalance is the source of truth."""
    credit_system = UserCredit()
    user_id = f"migration-test-{datetime.now().timestamp()}"
    await create_test_user(user_id)

    try:
        # 1. Verify the User table does NOT have a balance set initially
        user = await User.prisma().find_unique(where={"id": user_id})
        assert user is not None
        # User.balance should not exist, or should be None/0 if it does
        user_balance_attr = getattr(user, "balance", None)
        if user_balance_attr is not None:
            assert (
                user_balance_attr == 0 or user_balance_attr is None
            ), f"User.balance should be 0 or None, got {user_balance_attr}"

        # 2. Perform various credit operations using the internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=1000,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "migration_test"}),
        )
        balance1 = await credit_system.get_credits(user_id)
        assert balance1 == 1000

        await credit_system.spend_credits(
            user_id,
            300,
            UsageTransactionMetadata(
                graph_exec_id="test", reason="Migration test spend"
            ),
        )
        balance2 = await credit_system.get_credits(user_id)
        assert balance2 == 700

        # 3. Verify the UserBalance table has the correct values
        user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
        assert user_balance is not None
        assert (
            user_balance.balance == 700
        ), f"UserBalance should be 700, got {user_balance.balance}"

        # 4. CRITICAL: Verify User.balance is NEVER updated during operations
        user_after = await User.prisma().find_unique(where={"id": user_id})
        assert user_after is not None
        user_balance_after = getattr(user_after, "balance", None)
        if user_balance_after is not None:
            # If User.balance exists, it should still be 0 (never updated)
            assert (
                user_balance_after == 0 or user_balance_after is None
            ), f"User.balance should remain 0/None after operations, got {user_balance_after}. This indicates User.balance is still being used!"

        # 5. Verify get_credits always returns the UserBalance value, not User.balance
        final_balance = await credit_system.get_credits(user_id)
        assert (
            final_balance == user_balance.balance
        ), f"get_credits should return UserBalance value {user_balance.balance}, got {final_balance}"

    finally:
        await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_detect_stale_user_balance_queries(server: SpinTestServer):
    """Test to detect whether any operations still use User.balance instead of UserBalance."""
    credit_system = UserCredit()
    user_id = f"stale-query-test-{datetime.now().timestamp()}"
    await create_test_user(user_id)

    try:
        # Create UserBalance with a specific value
        await UserBalance.prisma().create(
            data={"userId": user_id, "balance": 5000}  # $50
        )

        # Verify that get_credits returns the UserBalance value (5000), not any stale User.balance value
        balance = await credit_system.get_credits(user_id)
        assert (
            balance == 5000
        ), f"Expected get_credits to return 5000 from UserBalance, got {balance}"

        # Verify all operations use UserBalance, using the internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=1000,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "final_verification"}),
        )
        final_balance = await credit_system.get_credits(user_id)
        assert final_balance == 6000, f"Expected 6000, got {final_balance}"

        # Verify the UserBalance table has the correct value
        user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
        assert user_balance is not None
        assert (
            user_balance.balance == 6000
        ), f"UserBalance should be 6000, got {user_balance.balance}"

    finally:
        await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_operations_use_userbalance_only(server: SpinTestServer):
    """Test that concurrent operations all use UserBalance locking, not User.balance."""
    credit_system = UserCredit()
    user_id = f"concurrent-userbalance-test-{datetime.now().timestamp()}"
    await create_test_user(user_id)

    try:
        # Set the initial balance in UserBalance
        await UserBalance.prisma().create(data={"userId": user_id, "balance": 1000})

        # Run concurrent operations to ensure they all use UserBalance atomic operations
        async def concurrent_spend(amount: int, label: str):
            try:
                await credit_system.spend_credits(
                    user_id,
                    amount,
                    UsageTransactionMetadata(
                        graph_exec_id=f"concurrent-{label}",
                        reason=f"Concurrent test {label}",
                    ),
                )
                return f"{label}-SUCCESS"
            except Exception as e:
                return f"{label}-FAILED: {e}"

        # Run concurrent operations
        results = await asyncio.gather(
            concurrent_spend(100, "A"),
            concurrent_spend(200, "B"),
            concurrent_spend(300, "C"),
            return_exceptions=True,
        )

        # All should succeed (1000 >= 100 + 200 + 300)
        successful = [r for r in results if "SUCCESS" in str(r)]
        assert len(successful) == 3, f"All operations should succeed, got {results}"

        # Final balance should be 1000 - 600 = 400
        final_balance = await credit_system.get_credits(user_id)
        assert final_balance == 400, f"Expected final balance 400, got {final_balance}"

        # Verify UserBalance has the correct value
        user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
        assert user_balance is not None
        assert (
            user_balance.balance == 400
        ), f"UserBalance should be 400, got {user_balance.balance}"

        # Critical: if User.balance exists and was used, it might hold the wrong value
        try:
            user = await User.prisma().find_unique(where={"id": user_id})
            user_balance_attr = getattr(user, "balance", None)
            if user_balance_attr is not None:
                # If User.balance exists, it should NOT be used for operations.
                # The correct final balance from UserBalance proves the system is working.
                print(
                    f"✅ User.balance exists ({user_balance_attr}) but UserBalance ({user_balance.balance}) is being used correctly"
                )
        except Exception:
            print("✅ User.balance column doesn't exist - migration is complete")

    finally:
        await cleanup_test_user(user_id)
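These migration tests pin down that reads and writes go through UserBalance alone. A sketch of the read path they imply (illustrative; the real get_credits additionally handles monthly refills):

    from prisma.models import UserBalance

    async def get_balance(user_id: str) -> int:
        # UserBalance is the single source of truth; User.balance, if the
        # column still exists, is never consulted.
        record = await UserBalance.prisma().find_unique(where={"userId": user_id})
        return record.balance if record else 0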
@@ -83,7 +83,7 @@ async def disconnect():
 
 
 # Transaction timeout constant (in milliseconds)
-TRANSACTION_TIMEOUT = 15000  # 15 seconds - Increased from 5s to prevent timeout errors
+TRANSACTION_TIMEOUT = 30000  # 30 seconds - Increased from 15s to prevent timeout errors during graph creation under load
 
 
 @asynccontextmanager
@@ -98,42 +98,6 @@ async def transaction(timeout: int = TRANSACTION_TIMEOUT):
     yield tx
 
 
-@asynccontextmanager
-async def locked_transaction(key: str, timeout: int = TRANSACTION_TIMEOUT):
-    """
-    Create a transaction and take a per-key advisory *transaction* lock.
-
-    - Uses a 64-bit lock id via hashtextextended(key, 0) to avoid 32-bit collisions.
-    - Bound by lock_timeout and statement_timeout so it won't block indefinitely.
-    - Lock is held for the duration of the transaction and auto-released on commit/rollback.
-
-    Args:
-        key: String lock key (e.g., "usr_trx_<uuid>").
-        timeout: Transaction/lock/statement timeout in milliseconds.
-    """
-    async with transaction(timeout=timeout) as tx:
-        # Ensure we don't wait longer than desired
-        # Note: SET LOCAL doesn't support parameterized queries, must use string interpolation
-        await tx.execute_raw(f"SET LOCAL statement_timeout = '{int(timeout)}ms'")  # type: ignore[arg-type]
-        await tx.execute_raw(f"SET LOCAL lock_timeout = '{int(timeout)}ms'")  # type: ignore[arg-type]
-
-        # Block until acquired or lock_timeout hits
-        try:
-            await tx.execute_raw(
-                "SELECT pg_advisory_xact_lock(hashtextextended($1, 0))",
-                key,
-            )
-        except Exception as e:
-            # Normalize PG's lock timeout error to TimeoutError for callers
-            if "lock timeout" in str(e).lower():
-                raise TimeoutError(
-                    f"Could not acquire lock for key={key!r} within {timeout}ms"
-                ) from e
-            raise
-
-        yield tx
-
-
 def get_database_schema() -> str:
     """Extract database schema from DATABASE_URL."""
     parsed_url = urlparse(DATABASE_URL)
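For context on the helper removed above: pg_advisory_xact_lock serializes all transactions that hash the same string key, and the lock releases automatically at commit or rollback. A sketch of the equivalent standalone call, using the same SQL as the deleted code (the function name and key format here are illustrative):

    async def lock_user_transactions(tx, user_id: str) -> None:
        # hashtextextended(key, 0) maps the string key to a 64-bit lock id;
        # pg_advisory_xact_lock holds that id until the transaction ends.
        await tx.execute_raw(
            "SELECT pg_advisory_xact_lock(hashtextextended($1, 0))",
            f"usr_trx_{user_id}",
        )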
autogpt_platform/backend/backend/data/dynamic_fields.py (new file, 284 lines)
@@ -0,0 +1,284 @@
"""
|
||||
Utilities for handling dynamic field names with special delimiters.
|
||||
|
||||
Dynamic fields allow graphs to connect complex data structures using special delimiters:
|
||||
- _#_ for dictionary keys (e.g., "values_#_name" → values["name"])
|
||||
- _$_ for list indices (e.g., "items_$_0" → items[0])
|
||||
- _@_ for object attributes (e.g., "obj_@_attr" → obj.attr)
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from backend.util.mock import MockObject
|
||||
|
||||
# Dynamic field delimiters
|
||||
LIST_SPLIT = "_$_"
|
||||
DICT_SPLIT = "_#_"
|
||||
OBJC_SPLIT = "_@_"
|
||||
|
||||
DYNAMIC_DELIMITERS = (LIST_SPLIT, DICT_SPLIT, OBJC_SPLIT)
|
||||
|
||||
|
||||
def extract_base_field_name(field_name: str) -> str:
|
||||
"""
|
||||
Extract the base field name from a dynamic field name by removing all dynamic suffixes.
|
||||
|
||||
Examples:
|
||||
extract_base_field_name("values_#_name") → "values"
|
||||
extract_base_field_name("items_$_0") → "items"
|
||||
extract_base_field_name("obj_@_attr") → "obj"
|
||||
extract_base_field_name("regular_field") → "regular_field"
|
||||
|
||||
Args:
|
||||
field_name: The field name that may contain dynamic delimiters
|
||||
|
||||
Returns:
|
||||
The base field name without any dynamic suffixes
|
||||
"""
|
||||
base_name = field_name
|
||||
for delimiter in DYNAMIC_DELIMITERS:
|
||||
if delimiter in base_name:
|
||||
base_name = base_name.split(delimiter)[0]
|
||||
return base_name
|
||||
|
||||
|
||||
def is_dynamic_field(field_name: str) -> bool:
|
||||
"""
|
||||
Check if a field name contains dynamic delimiters.
|
||||
|
||||
Args:
|
||||
field_name: The field name to check
|
||||
|
||||
Returns:
|
||||
True if the field contains any dynamic delimiters, False otherwise
|
||||
"""
|
||||
return any(delimiter in field_name for delimiter in DYNAMIC_DELIMITERS)
|
||||
|
||||
|
||||
def get_dynamic_field_description(field_name: str) -> str:
|
||||
"""
|
||||
Generate a description for a dynamic field based on its structure.
|
||||
|
||||
Args:
|
||||
field_name: The full dynamic field name (e.g., "values_#_name")
|
||||
|
||||
Returns:
|
||||
A descriptive string explaining what this dynamic field represents
|
||||
"""
|
||||
base_name = extract_base_field_name(field_name)
|
||||
|
||||
if DICT_SPLIT in field_name:
|
||||
# Extract the key part after _#_
|
||||
parts = field_name.split(DICT_SPLIT)
|
||||
if len(parts) > 1:
|
||||
key = parts[1].split("_")[0] if "_" in parts[1] else parts[1]
|
||||
return f"Dictionary field '{key}' for base field '{base_name}' ({base_name}['{key}'])"
|
||||
elif LIST_SPLIT in field_name:
|
||||
# Extract the index part after _$_
|
||||
parts = field_name.split(LIST_SPLIT)
|
||||
if len(parts) > 1:
|
||||
index = parts[1].split("_")[0] if "_" in parts[1] else parts[1]
|
||||
return (
|
||||
f"List item {index} for base field '{base_name}' ({base_name}[{index}])"
|
||||
)
|
||||
elif OBJC_SPLIT in field_name:
|
||||
# Extract the attribute part after _@_
|
||||
parts = field_name.split(OBJC_SPLIT)
|
||||
if len(parts) > 1:
|
||||
# Get the full attribute name (everything after _@_)
|
||||
attr = parts[1]
|
||||
return f"Object attribute '{attr}' for base field '{base_name}' ({base_name}.{attr})"
|
||||
|
||||
return f"Value for {field_name}"
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Dynamic field parsing and merging utilities
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def _next_delim(s: str) -> tuple[str | None, int]:
|
||||
"""
|
||||
Return the *earliest* delimiter appearing in `s` and its index.
|
||||
|
||||
If none present → (None, -1).
|
||||
"""
|
||||
first: str | None = None
|
||||
pos = len(s) # sentinel: larger than any real index
|
||||
for d in DYNAMIC_DELIMITERS:
|
||||
i = s.find(d)
|
||||
if 0 <= i < pos:
|
||||
first, pos = d, i
|
||||
return first, (pos if first else -1)
|
||||
|
||||
|
||||
def _tokenise(path: str) -> list[tuple[str, str]] | None:
|
||||
"""
|
||||
Convert the raw path string (starting with a delimiter) into
|
||||
[ (delimiter, identifier), … ] or None if the syntax is malformed.
|
||||
"""
|
||||
tokens: list[tuple[str, str]] = []
|
||||
while path:
|
||||
# 1. Which delimiter starts this chunk?
|
||||
delim = next((d for d in DYNAMIC_DELIMITERS if path.startswith(d)), None)
|
||||
if delim is None:
|
||||
return None # invalid syntax
|
||||
|
||||
# 2. Slice off the delimiter, then up to the next delimiter (or EOS)
|
||||
path = path[len(delim) :]
|
||||
nxt_delim, pos = _next_delim(path)
|
||||
token, path = (
|
||||
path[: pos if pos != -1 else len(path)],
|
||||
path[pos if pos != -1 else len(path) :],
|
||||
)
|
||||
if token == "":
|
||||
return None # empty identifier is invalid
|
||||
tokens.append((delim, token))
|
||||
return tokens
|
||||
|
||||
|
||||
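To make the tokenizer's contract concrete, here is roughly what it yields for well-formed and malformed paths (illustrative values only, not part of the module):

# _tokenise expects the path to *start* with a delimiter; the base name
# has already been stripped off by the caller.
_tokenise("_#_name_$_0")  # -> [("_#_", "name"), ("_$_", "0")]
_tokenise("_@_attr")      # -> [("_@_", "attr")]
_tokenise("name_#_x")     # -> None (does not start with a delimiter)
_tokenise("_#__$_0")      # -> None (empty identifier between delimiters)
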
def parse_execution_output(output: tuple[str, Any], name: str) -> Any:
    """
    Retrieve a nested value out of `output` using the flattened *name*.

    On any failure (wrong name, wrong type, out-of-range, bad path)
    returns **None**.

    Args:
        output: Tuple of (base_name, data) representing a block output entry
        name: The flattened field name to extract from the output data

    Returns:
        The value at the specified path, or None if not found/invalid
    """
    base_name, data = output

    # Exact match → whole object
    if name == base_name:
        return data

    # Must start with the expected name
    if not name.startswith(base_name):
        return None
    path = name[len(base_name) :]
    if not path:
        return None  # nothing left to parse

    tokens = _tokenise(path)
    if tokens is None:
        return None

    cur: Any = data
    for delim, ident in tokens:
        if delim == LIST_SPLIT:
            # list[index]
            try:
                idx = int(ident)
            except ValueError:
                return None
            if not isinstance(cur, list) or idx >= len(cur):
                return None
            cur = cur[idx]

        elif delim == DICT_SPLIT:
            if not isinstance(cur, dict) or ident not in cur:
                return None
            cur = cur[ident]

        elif delim == OBJC_SPLIT:
            if not hasattr(cur, ident):
                return None
            cur = getattr(cur, ident)

        else:
            return None  # unreachable

    return cur
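A short illustration of the read path (sample data only, not from the module):

output = ("result", {"users": [{"name": "Alice"}]})

parse_execution_output(output, "result")                     # whole object
parse_execution_output(output, "result_#_users_$_0_#_name")  # -> "Alice"
parse_execution_output(output, "result_#_missing")           # -> None
parse_execution_output(output, "other_#_users")              # -> None (wrong base name)
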
def _assign(container: Any, tokens: list[tuple[str, str]], value: Any) -> Any:
    """
    Recursive helper that *returns* the (possibly new) container with
    `value` assigned along the remaining `tokens` path.
    """
    if not tokens:
        return value  # leaf reached

    delim, ident = tokens[0]
    rest = tokens[1:]

    # ---------- list ----------
    if delim == LIST_SPLIT:
        try:
            idx = int(ident)
        except ValueError:
            raise ValueError("index must be an integer")

        if container is None:
            container = []
        elif not isinstance(container, list):
            container = list(container) if hasattr(container, "__iter__") else []

        while len(container) <= idx:
            container.append(None)
        container[idx] = _assign(container[idx], rest, value)
        return container

    # ---------- dict ----------
    if delim == DICT_SPLIT:
        if container is None:
            container = {}
        elif not isinstance(container, dict):
            container = dict(container) if hasattr(container, "items") else {}
        container[ident] = _assign(container.get(ident), rest, value)
        return container

    # ---------- object ----------
    if delim == OBJC_SPLIT:
        if container is None:
            container = MockObject()
        elif not hasattr(container, "__dict__"):
            # If it's not an object, create a new one
            container = MockObject()
        setattr(
            container,
            ident,
            _assign(getattr(container, ident, None), rest, value),
        )
        return container

    return value  # unreachable


def merge_execution_input(data: dict[str, Any]) -> dict[str, Any]:
    """
    Reconstruct nested objects from a *flattened* dict of key → value.

    Raises ValueError on syntactically invalid list indices.

    Args:
        data: Dictionary with potentially flattened dynamic field keys

    Returns:
        Dictionary with nested objects reconstructed from flattened keys
    """
    merged: dict[str, Any] = {}

    for key, value in data.items():
        # Split off the base name (before the first delimiter, if any)
        delim, pos = _next_delim(key)
        if delim is None:
            merged[key] = value
            continue

        base, path = key[:pos], key[pos:]
        tokens = _tokenise(path)
        if tokens is None:
            # Invalid key; treat as scalar under the raw name
            merged[key] = value
            continue

        merged[base] = _assign(merged.get(base), tokens, value)

    data.update(merged)
    return data
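merge_execution_input is the inverse of parse_execution_output: flattened keys are folded back into nested containers. A small round-trip sketch with made-up data:

flat = {
    "values_#_name": "Alice",
    "values_#_tags_$_0": "admin",
    "count": 3,
}
merged = merge_execution_input(flat)
# merged["values"] == {"name": "Alice", "tags": ["admin"]}
# merged["count"] == 3

Note that the function also mutates and returns the input dict (via data.update(merged)), so the original flattened keys remain present alongside the reconstructed containers.
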
@@ -38,8 +38,8 @@ from prisma.types import (
 from pydantic import BaseModel, ConfigDict, JsonValue, ValidationError
 from pydantic.fields import Field
 
-from backend.server.v2.store.exceptions import DatabaseError
 from backend.util import type as type_utils
+from backend.util.exceptions import DatabaseError
 from backend.util.json import SafeJson
 from backend.util.models import Pagination
 from backend.util.retry import func_retry
@@ -92,6 +92,31 @@ ExecutionStatus = AgentExecutionStatus
 NodeInputMask = Mapping[str, JsonValue]
 NodesInputMasks = Mapping[str, NodeInputMask]
 
+# dest: source
+VALID_STATUS_TRANSITIONS = {
+    ExecutionStatus.QUEUED: [
+        ExecutionStatus.INCOMPLETE,
+    ],
+    ExecutionStatus.RUNNING: [
+        ExecutionStatus.INCOMPLETE,
+        ExecutionStatus.QUEUED,
+        ExecutionStatus.TERMINATED,  # For resuming halted execution
+    ],
+    ExecutionStatus.COMPLETED: [
+        ExecutionStatus.RUNNING,
+    ],
+    ExecutionStatus.FAILED: [
+        ExecutionStatus.INCOMPLETE,
+        ExecutionStatus.QUEUED,
+        ExecutionStatus.RUNNING,
+    ],
+    ExecutionStatus.TERMINATED: [
+        ExecutionStatus.INCOMPLETE,
+        ExecutionStatus.QUEUED,
+        ExecutionStatus.RUNNING,
+    ],
+}
 
 
 class GraphExecutionMeta(BaseDbModel):
     id: str  # type: ignore # Override base class to make this required
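The mapping is keyed by the *destination* status and lists the source statuses it may be reached from. A guard built on it would look like this (hypothetical helper, not part of the diff):

def can_transition(current: ExecutionStatus, new: ExecutionStatus) -> bool:
    # e.g. RUNNING may be entered from INCOMPLETE, QUEUED, or TERMINATED,
    # while COMPLETED may only be entered from RUNNING.
    return current in VALID_STATUS_TRANSITIONS.get(new, [])

update_graph_execution_stats below enforces the same rule atomically at the database level by folding the allowed source statuses into the WHERE clause of the UPDATE.
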
@@ -105,6 +130,8 @@ class GraphExecutionMeta(BaseDbModel):
     status: ExecutionStatus
     started_at: datetime
     ended_at: datetime
+    is_shared: bool = False
+    share_token: Optional[str] = None
 
     class Stats(BaseModel):
         model_config = ConfigDict(
@@ -221,6 +248,8 @@ class GraphExecutionMeta(BaseDbModel):
                 if stats
                 else None
             ),
+            is_shared=_graph_exec.isShared,
+            share_token=_graph_exec.shareToken,
         )

@@ -449,6 +478,48 @@ async def get_graph_executions(
     return [GraphExecutionMeta.from_db(execution) for execution in executions]
 
 
+async def get_graph_executions_count(
+    user_id: Optional[str] = None,
+    graph_id: Optional[str] = None,
+    statuses: Optional[list[ExecutionStatus]] = None,
+    created_time_gte: Optional[datetime] = None,
+    created_time_lte: Optional[datetime] = None,
+) -> int:
+    """
+    Get count of graph executions with optional filters.
+
+    Args:
+        user_id: Optional user ID to filter by
+        graph_id: Optional graph ID to filter by
+        statuses: Optional list of execution statuses to filter by
+        created_time_gte: Optional minimum creation time
+        created_time_lte: Optional maximum creation time
+
+    Returns:
+        Count of matching graph executions
+    """
+    where_filter: AgentGraphExecutionWhereInput = {
+        "isDeleted": False,
+    }
+
+    if user_id:
+        where_filter["userId"] = user_id
+
+    if graph_id:
+        where_filter["agentGraphId"] = graph_id
+
+    if created_time_gte or created_time_lte:
+        where_filter["createdAt"] = {
+            "gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc),
+            "lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc),
+        }
+    if statuses:
+        where_filter["OR"] = [{"executionStatus": status} for status in statuses]
+
+    count = await AgentGraphExecution.prisma().count(where=where_filter)
+    return count
+
+
 class GraphExecutionsPaginated(BaseModel):
     """Response schema for paginated graph executions."""
@@ -580,7 +651,7 @@ async def create_graph_execution(
         data={
             "agentGraphId": graph_id,
             "agentGraphVersion": graph_version,
-            "executionStatus": ExecutionStatus.QUEUED,
+            "executionStatus": ExecutionStatus.INCOMPLETE,
             "inputs": SafeJson(inputs),
             "credentialInputs": (
                 SafeJson(credential_inputs) if credential_inputs else Json({})
@@ -727,6 +798,11 @@ async def update_graph_execution_stats(
     status: ExecutionStatus | None = None,
     stats: GraphExecutionStats | None = None,
 ) -> GraphExecution | None:
+    if not status and not stats:
+        raise ValueError(
+            f"Must provide either status or stats to update for execution {graph_exec_id}"
+        )
+
     update_data: AgentGraphExecutionUpdateManyMutationInput = {}
 
     if stats:
@@ -738,20 +814,25 @@ async def update_graph_execution_stats(
     if status:
         update_data["executionStatus"] = status
 
-    updated_count = await AgentGraphExecution.prisma().update_many(
-        where={
-            "id": graph_exec_id,
-            "OR": [
-                {"executionStatus": ExecutionStatus.RUNNING},
-                {"executionStatus": ExecutionStatus.QUEUED},
-                # Terminated graph can be resumed.
-                {"executionStatus": ExecutionStatus.TERMINATED},
-            ],
-        },
+    where_clause: AgentGraphExecutionWhereInput = {"id": graph_exec_id}
+
+    if status:
+        if allowed_from := VALID_STATUS_TRANSITIONS.get(status, []):
+            # Add OR clause to check if current status is one of the allowed source statuses
+            where_clause["AND"] = [
+                {"id": graph_exec_id},
+                {"OR": [{"executionStatus": s} for s in allowed_from]},
+            ]
+        else:
+            raise ValueError(
+                f"Status {status} cannot be set via update for execution {graph_exec_id}. "
+                f"This status can only be set at creation or is not a valid target status."
+            )
+
+    updated_count = await AgentGraphExecution.prisma().update_many(
+        where=where_clause,
         data=update_data,
     )
     if updated_count == 0:
         return None
 
     graph_exec = await AgentGraphExecution.prisma().find_unique_or_raise(
         where={"id": graph_exec_id},
@@ -759,6 +840,7 @@ async def update_graph_execution_stats(
             [*get_io_block_ids(), *get_webhook_block_ids()]
         ),
     )
 
     return GraphExecution.from_db(graph_exec)

@@ -985,6 +1067,18 @@ class NodeExecutionEvent(NodeExecutionResult):
     )
 
 
+class SharedExecutionResponse(BaseModel):
+    """Public-safe response for shared executions"""
+
+    id: str
+    graph_name: str
+    graph_description: Optional[str]
+    status: ExecutionStatus
+    created_at: datetime
+    outputs: CompletedBlockOutput  # Only the final outputs, no intermediate data
+    # Deliberately exclude: user_id, inputs, credentials, node details
+
+
 ExecutionEvent = Annotated[
     GraphExecutionEvent | NodeExecutionEvent, Field(discriminator="event_type")
 ]
@@ -1162,3 +1256,98 @@ async def get_block_error_stats(
         )
         for row in result
     ]
 
 
+async def update_graph_execution_share_status(
+    execution_id: str,
+    user_id: str,
+    is_shared: bool,
+    share_token: str | None,
+    shared_at: datetime | None,
+) -> None:
+    """Update the sharing status of a graph execution."""
+    await AgentGraphExecution.prisma().update(
+        where={"id": execution_id},
+        data={
+            "isShared": is_shared,
+            "shareToken": share_token,
+            "sharedAt": shared_at,
+        },
+    )
+
+
+async def get_graph_execution_by_share_token(
+    share_token: str,
+) -> SharedExecutionResponse | None:
+    """Get a shared execution with limited public-safe data."""
+    execution = await AgentGraphExecution.prisma().find_first(
+        where={
+            "shareToken": share_token,
+            "isShared": True,
+            "isDeleted": False,
+        },
+        include={
+            "AgentGraph": True,
+            "NodeExecutions": {
+                "include": {
+                    "Output": True,
+                    "Node": {
+                        "include": {
+                            "AgentBlock": True,
+                        }
+                    },
+                },
+            },
+        },
+    )
+
+    if not execution:
+        return None
+
+    # Extract outputs from OUTPUT blocks only (consistent with GraphExecution.from_db)
+    outputs: CompletedBlockOutput = defaultdict(list)
+    if execution.NodeExecutions:
+        for node_exec in execution.NodeExecutions:
+            if node_exec.Node and node_exec.Node.agentBlockId:
+                # Get the block definition to check its type
+                block = get_block(node_exec.Node.agentBlockId)
+
+                if block and block.block_type == BlockType.OUTPUT:
+                    # For OUTPUT blocks, the data is stored in executionData or Input
+                    # The executionData contains the structured input with 'name' and 'value' fields
+                    if hasattr(node_exec, "executionData") and node_exec.executionData:
+                        exec_data = type_utils.convert(
+                            node_exec.executionData, dict[str, Any]
+                        )
+                        if "name" in exec_data:
+                            name = exec_data["name"]
+                            value = exec_data.get("value")
+                            outputs[name].append(value)
+                    elif node_exec.Input:
+                        # Build input_data from Input relation
+                        input_data = {}
+                        for data in node_exec.Input:
+                            if data.name and data.data is not None:
+                                input_data[data.name] = type_utils.convert(
+                                    data.data, JsonValue
+                                )
+
+                        if "name" in input_data:
+                            name = input_data["name"]
+                            value = input_data.get("value")
+                            outputs[name].append(value)
+
+    return SharedExecutionResponse(
+        id=execution.id,
+        graph_name=(
+            execution.AgentGraph.name
+            if (execution.AgentGraph and execution.AgentGraph.name)
+            else "Untitled Agent"
+        ),
+        graph_description=(
+            execution.AgentGraph.description if execution.AgentGraph else None
+        ),
+        status=ExecutionStatus(execution.executionStatus),
+        created_at=execution.createdAt,
+        outputs=outputs,
+    )
@@ -7,7 +7,7 @@ from prisma.enums import AgentExecutionStatus
 from backend.data.execution import get_graph_executions
 from backend.data.graph import get_graph_metadata
 from backend.data.model import UserExecutionSummaryStats
-from backend.server.v2.store.exceptions import DatabaseError
+from backend.util.exceptions import DatabaseError
 from backend.util.logging import TruncatedLogger
 
 logger = TruncatedLogger(logging.getLogger(__name__), prefix="[SummaryData]")
@@ -1,6 +1,7 @@
 import logging
 import uuid
+from collections import defaultdict
 from datetime import datetime, timezone
 from typing import TYPE_CHECKING, Any, Literal, Optional, cast
 
 from prisma.enums import SubmissionStatus
@@ -19,6 +20,8 @@ from backend.blocks.agent import AgentExecutorBlock
 from backend.blocks.io import AgentInputBlock, AgentOutputBlock
 from backend.blocks.llm import LlmModel
 from backend.data.db import prisma as db
+from backend.data.dynamic_fields import extract_base_field_name
+from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
 from backend.data.model import (
     CredentialsField,
     CredentialsFieldInfo,
@@ -28,8 +31,17 @@ from backend.data.model import (
 from backend.integrations.providers import ProviderName
 from backend.util import type as type_utils
 from backend.util.json import SafeJson
+from backend.util.models import Pagination
 
-from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
+from .block import (
+    Block,
+    BlockInput,
+    BlockSchema,
+    BlockType,
+    EmptySchema,
+    get_block,
+    get_blocks,
+)
 from .db import BaseDbModel, query_raw_with_schema, transaction
 from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
@@ -70,12 +82,15 @@ class Node(BaseDbModel):
     output_links: list[Link] = []
 
     @property
-    def block(self) -> Block[BlockSchema, BlockSchema]:
+    def block(self) -> "Block[BlockSchema, BlockSchema] | _UnknownBlockBase":
+        """Get the block for this node. Returns UnknownBlock if block is deleted/missing."""
         block = get_block(self.block_id)
         if not block:
-            raise ValueError(
-                f"Block #{self.block_id} does not exist -> Node #{self.id} is invalid"
+            # Log warning but don't raise exception - return a placeholder block for deleted blocks
+            logger.warning(
+                f"Block #{self.block_id} does not exist for Node #{self.id} (deleted/missing block), using UnknownBlock"
             )
+            return _UnknownBlockBase(self.block_id)
         return block

@@ -114,17 +129,20 @@ class NodeModel(Node):
         Returns a copy of the node model, stripped of any non-transferable properties
         """
         stripped_node = self.model_copy(deep=True)
-        # Remove credentials from node input
+
+        # Remove credentials and other (possible) secrets from node input
         if stripped_node.input_default:
             stripped_node.input_default = NodeModel._filter_secrets_from_node_input(
                 stripped_node.input_default, self.block.input_schema.jsonschema()
             )
 
         # Remove default secret value from secret input nodes
         if (
             stripped_node.block.block_type == BlockType.INPUT
             and stripped_node.input_default.get("secret", False) is True
             and "value" in stripped_node.input_default
         ):
-            stripped_node.input_default["value"] = ""
+            del stripped_node.input_default["value"]
 
         # Remove webhook info
         stripped_node.webhook_id = None
@@ -141,8 +159,10 @@ class NodeModel(Node):
         result = {}
         for key, value in input_data.items():
             field_schema: dict | None = field_schemas.get(key)
-            if (field_schema and field_schema.get("secret", False)) or any(
-                sensitive_key in key.lower() for sensitive_key in sensitive_keys
+            if (field_schema and field_schema.get("secret", False)) or (
+                any(sensitive_key in key.lower() for sensitive_key in sensitive_keys)
+                # Prevent removing `secret` flag on input nodes
+                and type(value) is not bool
             ):
                 # This is a secret value -> filter this key-value pair out
                 continue
@@ -160,6 +180,7 @@ class BaseGraph(BaseDbModel):
     is_active: bool = True
     name: str
     description: str
+    instructions: str | None = None
     recommended_schedule_cron: str | None = None
     nodes: list[Node] = []
     links: list[Link] = []
@@ -381,6 +402,8 @@ class GraphModel(Graph):
     user_id: str
     nodes: list[NodeModel] = []  # type: ignore
 
+    created_at: datetime
+
     @property
     def starting_nodes(self) -> list[NodeModel]:
         outbound_nodes = {link.sink_id for link in self.links}
@@ -393,6 +416,10 @@ class GraphModel(Graph):
             if node.id not in outbound_nodes or node.id in input_nodes
         ]
 
+    @property
+    def webhook_input_node(self) -> NodeModel | None:  # type: ignore
+        return cast(NodeModel, super().webhook_input_node)
+
     def meta(self) -> "GraphMeta":
         """
         Returns a GraphMeta object with metadata about the graph.
@@ -694,9 +721,11 @@ class GraphModel(Graph):
             version=graph.version,
             forked_from_id=graph.forkedFromId,
             forked_from_version=graph.forkedFromVersion,
+            created_at=graph.createdAt,
             is_active=graph.isActive,
             name=graph.name or "",
             description=graph.description or "",
+            instructions=graph.instructions,
             recommended_schedule_cron=graph.recommendedScheduleCron,
             nodes=[NodeModel.from_db(node, for_export) for node in graph.Nodes or []],
             links=list(
@@ -718,7 +747,7 @@ def _is_tool_pin(name: str) -> bool:
 
 
 def _sanitize_pin_name(name: str) -> str:
-    sanitized_name = name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
+    sanitized_name = extract_base_field_name(name)
     if _is_tool_pin(sanitized_name):
         return "tools"
     return sanitized_name
@@ -736,6 +765,13 @@ class GraphMeta(Graph):
         return GraphMeta(**graph.model_dump())
 
 
+class GraphsPaginated(BaseModel):
+    """Response schema for paginated graphs."""
+
+    graphs: list[GraphMeta]
+    pagination: Pagination
+
+
 # --------------------- CRUD functions --------------------- #

@@ -764,31 +800,42 @@ async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:
     return NodeModel.from_db(node)
 
 
-async def list_graphs(
+async def list_graphs_paginated(
     user_id: str,
+    page: int = 1,
+    page_size: int = 25,
     filter_by: Literal["active"] | None = "active",
-) -> list[GraphMeta]:
+) -> GraphsPaginated:
     """
-    Retrieves graph metadata objects.
-    Default behaviour is to get all currently active graphs.
+    Retrieves paginated graph metadata objects.
 
     Args:
+        user_id: The ID of the user that owns the graphs.
+        page: Page number (1-based).
+        page_size: Number of graphs per page.
         filter_by: An optional filter to either select graphs.
-        user_id: The ID of the user that owns the graph.
 
     Returns:
-        list[GraphMeta]: A list of objects representing the retrieved graphs.
+        GraphsPaginated: Paginated list of graph metadata.
     """
     where_clause: AgentGraphWhereInput = {"userId": user_id}
 
     if filter_by == "active":
         where_clause["isActive"] = True
 
+    # Get total count
+    total_count = await AgentGraph.prisma().count(where=where_clause)
+    total_pages = (total_count + page_size - 1) // page_size
+
+    # Get paginated results
+    offset = (page - 1) * page_size
     graphs = await AgentGraph.prisma().find_many(
         where=where_clause,
         distinct=["id"],
         order={"version": "desc"},
         include=AGENT_GRAPH_INCLUDE,
+        skip=offset,
+        take=page_size,
     )
 
     graph_models: list[GraphMeta] = []
@@ -802,7 +849,15 @@ async def list_graphs(
             logger.error(f"Error processing graph {graph.id}: {e}")
             continue
 
-    return graph_models
+    return GraphsPaginated(
+        graphs=graph_models,
+        pagination=Pagination(
+            total_items=total_count,
+            total_pages=total_pages,
+            current_page=page,
+            page_size=page_size,
+        ),
+    )
 
 
 async def get_graph_metadata(graph_id: str, version: int | None = None) -> Graph | None:
@@ -1022,11 +1077,14 @@ async def set_graph_active_version(graph_id: str, version: int, user_id: str) ->
     )
 
 
-async def get_graph_all_versions(graph_id: str, user_id: str) -> list[GraphModel]:
+async def get_graph_all_versions(
+    graph_id: str, user_id: str, limit: int = MAX_GRAPH_VERSIONS_FETCH
+) -> list[GraphModel]:
     graph_versions = await AgentGraph.prisma().find_many(
         where={"id": graph_id, "userId": user_id},
         order={"version": "desc"},
         include=AGENT_GRAPH_INCLUDE,
+        take=limit,
     )
 
     if not graph_versions:
@@ -1144,6 +1202,7 @@ def make_graph_model(creatable_graph: Graph, user_id: str) -> GraphModel:
     return GraphModel(
         **creatable_graph.model_dump(exclude={"nodes"}),
         user_id=user_id,
+        created_at=datetime.now(tz=timezone.utc),
         nodes=[
             NodeModel(
                 **creatable_node.model_dump(),
@@ -1274,3 +1333,34 @@ async def migrate_llm_models(migrate_to: LlmModel):
                 id,
                 path,
             )
 
 
+# Simple placeholder class for deleted/missing blocks
+class _UnknownBlockBase(Block):
+    """
+    Placeholder for deleted/missing blocks that inherits from Block
+    but uses a name that doesn't end with 'Block' to avoid auto-discovery.
+    """
+
+    def __init__(self, block_id: str = "00000000-0000-0000-0000-000000000000"):
+        # Initialize with minimal valid Block parameters
+        super().__init__(
+            id=block_id,
+            description=f"Unknown or deleted block (original ID: {block_id})",
+            disabled=True,
+            input_schema=EmptySchema,
+            output_schema=EmptySchema,
+            categories=set(),
+            contributors=[],
+            static_output=False,
+            block_type=BlockType.STANDARD,
+            webhook_config=None,
+        )
+
+    @property
+    def name(self):
+        return "UnknownBlock"
+
+    async def run(self, input_data, **kwargs):
+        """Always yield an error for missing blocks."""
+        yield "error", f"Block {self.id} no longer exists"
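With this placeholder in place, Node.block degrades gracefully instead of raising. A sketch of the observable behavior, for any node whose block_id no longer resolves via get_block (illustrative, not part of the diff):

block = node.block             # logs a warning, returns _UnknownBlockBase
assert block.name == "UnknownBlock"
assert block.disabled is True
# Executing the node yields a single ("error", ...) output instead of
# crashing graph loading or execution.
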
@@ -201,25 +201,56 @@ async def test_get_input_schema(server: SpinTestServer, snapshot: Snapshot):
 @pytest.mark.asyncio(loop_scope="session")
 async def test_clean_graph(server: SpinTestServer):
     """
-    Test the clean_graph function that:
-    1. Clears input block values
-    2. Removes credentials from nodes
+    Test the stripped_for_export function that:
+    1. Removes sensitive/secret fields from node inputs
+    2. Removes webhook information
+    3. Preserves non-sensitive data including input block values
     """
-    # Create a graph with input blocks and credentials
+    # Create a graph with input blocks containing both sensitive and normal data
     graph = Graph(
         id="test_clean_graph",
         name="Test Clean Graph",
         description="Test graph cleaning",
         nodes=[
             Node(
                 id="input_node",
                 block_id=AgentInputBlock().id,
                 input_default={
+                    "_test_id": "input_node",
                     "name": "test_input",
-                    "value": "test value",
+                    "value": "test value",  # This should be preserved
                     "description": "Test input description",
                 },
             ),
+            Node(
+                block_id=AgentInputBlock().id,
+                input_default={
+                    "_test_id": "input_node_secret",
+                    "name": "secret_input",
+                    "value": "another value",
+                    "secret": True,  # This makes the input secret
+                },
+            ),
+            Node(
+                block_id=StoreValueBlock().id,
+                input_default={
+                    "_test_id": "node_with_secrets",
+                    "input": "normal_value",
+                    "control_test_input": "should be preserved",
+                    "api_key": "secret_api_key_123",  # Should be filtered
+                    "password": "secret_password_456",  # Should be filtered
+                    "token": "secret_token_789",  # Should be filtered
+                    "credentials": {  # Should be filtered
+                        "id": "fake-github-credentials-id",
+                        "provider": "github",
+                        "type": "api_key",
+                    },
+                    "anthropic_credentials": {  # Should be filtered
+                        "id": "fake-anthropic-credentials-id",
+                        "provider": "anthropic",
+                        "type": "api_key",
+                    },
+                },
+            ),
         ],
         links=[],
     )
@@ -231,15 +262,54 @@ async def test_clean_graph(server: SpinTestServer):
     )
 
     # Clean the graph
-    created_graph = await server.agent_server.test_get_graph(
+    cleaned_graph = await server.agent_server.test_get_graph(
         created_graph.id, created_graph.version, DEFAULT_USER_ID, for_export=True
     )
 
-    # # Verify input block value is cleared
+    # Verify sensitive fields are removed but normal fields are preserved
     input_node = next(
-        n for n in created_graph.nodes if n.block_id == AgentInputBlock().id
+        n for n in cleaned_graph.nodes if n.input_default["_test_id"] == "input_node"
     )
-    assert input_node.input_default["value"] == ""
+
+    # Non-sensitive fields should be preserved
+    assert input_node.input_default["name"] == "test_input"
+    assert input_node.input_default["value"] == "test value"  # Should be preserved now
+    assert input_node.input_default["description"] == "Test input description"
+
+    # Sensitive fields should be filtered out
+    assert "api_key" not in input_node.input_default
+    assert "password" not in input_node.input_default
+
+    # Verify secret input node preserves non-sensitive fields but removes secret value
+    secret_node = next(
+        n
+        for n in cleaned_graph.nodes
+        if n.input_default["_test_id"] == "input_node_secret"
+    )
+    assert secret_node.input_default["name"] == "secret_input"
+    assert "value" not in secret_node.input_default  # Secret default should be removed
+    assert secret_node.input_default["secret"] is True
+
+    # Verify sensitive fields are filtered from nodes with secrets
+    secrets_node = next(
+        n
+        for n in cleaned_graph.nodes
+        if n.input_default["_test_id"] == "node_with_secrets"
+    )
+    # Normal fields should be preserved
+    assert secrets_node.input_default["input"] == "normal_value"
+    assert secrets_node.input_default["control_test_input"] == "should be preserved"
+    # Sensitive fields should be filtered out
+    assert "api_key" not in secrets_node.input_default
+    assert "password" not in secrets_node.input_default
+    assert "token" not in secrets_node.input_default
+    assert "credentials" not in secrets_node.input_default
+    assert "anthropic_credentials" not in secrets_node.input_default
+
+    # Verify webhook info is removed (if any nodes had it)
+    for node in cleaned_graph.nodes:
+        assert node.webhook_id is None
+        assert node.webhook is None
 
 
 @pytest.mark.asyncio(loop_scope="session")
@@ -14,6 +14,7 @@ AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
     "Nodes": {"include": AGENT_NODE_INCLUDE}
 }
 
 
 EXECUTION_RESULT_ORDER: list[prisma.types.AgentNodeExecutionOrderByInput] = [
     {"queuedTime": "desc"},
     # Fallback: Incomplete execs has no queuedTime.
@@ -28,6 +29,13 @@ EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
 }
 
 MAX_NODE_EXECUTIONS_FETCH = 1000
+MAX_LIBRARY_AGENT_EXECUTIONS_FETCH = 10
+
+# Default limits for potentially large result sets
+MAX_CREDIT_REFUND_REQUESTS_FETCH = 100
+MAX_INTEGRATION_WEBHOOKS_FETCH = 100
+MAX_USER_API_KEYS_FETCH = 500
+MAX_GRAPH_VERSIONS_FETCH = 50
 
 GRAPH_EXECUTION_INCLUDE_WITH_NODES: prisma.types.AgentGraphExecutionInclude = {
     "NodeExecutions": {
@@ -71,13 +79,56 @@ INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
 }
 
 
-def library_agent_include(user_id: str) -> prisma.types.LibraryAgentInclude:
-    return {
-        "AgentGraph": {
-            "include": {
-                **AGENT_GRAPH_INCLUDE,
-                "Executions": {"where": {"userId": user_id}},
-            }
-        },
-        "Creator": True,
+def library_agent_include(
+    user_id: str,
+    include_nodes: bool = True,
+    include_executions: bool = True,
+    execution_limit: int = MAX_LIBRARY_AGENT_EXECUTIONS_FETCH,
+) -> prisma.types.LibraryAgentInclude:
+    """
+    Fully configurable includes for library agent queries with performance optimization.
+
+    Args:
+        user_id: User ID for filtering user-specific data
+        include_nodes: Whether to include graph nodes (default: True, needed for get_sub_graphs)
+        include_executions: Whether to include executions (default: True, safe with execution_limit)
+        execution_limit: Limit on executions to fetch (default: MAX_LIBRARY_AGENT_EXECUTIONS_FETCH)
+
+    Defaults maintain backward compatibility and safety - includes everything needed for all functionality.
+    For performance optimization, explicitly set include_nodes=False and include_executions=False
+    for listing views where frontend fetches data separately.
+
+    Performance impact:
+    - Default (full nodes + limited executions): Original performance, works everywhere
+    - Listing optimization (no nodes/executions): ~2s for 15 agents vs potential timeouts
+    - Unlimited executions: varies by user (thousands of executions = timeouts)
+    """
+    result: prisma.types.LibraryAgentInclude = {
+        "Creator": True,  # Always needed for creator info
     }
+
+    # Build AgentGraph include based on requested options
+    if include_nodes or include_executions:
+        agent_graph_include = {}
+
+        # Add nodes if requested (always full nodes)
+        if include_nodes:
+            agent_graph_include.update(AGENT_GRAPH_INCLUDE)  # Full nodes
+
+        # Add executions if requested
+        if include_executions:
+            agent_graph_include["Executions"] = {
+                "where": {"userId": user_id},
+                "order_by": {"createdAt": "desc"},
+                "take": execution_limit,
+            }
+
+        result["AgentGraph"] = cast(
+            prisma.types.AgentGraphArgsFromLibraryAgent,
+            {"include": agent_graph_include},
+        )
+    else:
+        # Default: Basic metadata only (fast - recommended for most use cases)
+        result["AgentGraph"] = True  # Basic graph metadata (name, description, id)
+
+    return result
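Call sites can now size the payload explicitly. A hedged sketch of the two main modes (the find_many call is illustrative, not from this diff):

# Full fidelity (default): graph nodes plus the most recent executions
include_full = library_agent_include(user_id)

# Fast listing view: basic graph metadata only
include_light = library_agent_include(
    user_id, include_nodes=False, include_executions=False
)

agents = await LibraryAgent.prisma().find_many(
    where={"userId": user_id}, include=include_light
)
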
@@ -11,7 +11,10 @@ from prisma.types import (
 from pydantic import Field, computed_field
 
 from backend.data.event_bus import AsyncRedisEventBus
-from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
+from backend.data.includes import (
+    INTEGRATION_WEBHOOK_INCLUDE,
+    MAX_INTEGRATION_WEBHOOKS_FETCH,
+)
 from backend.integrations.providers import ProviderName
 from backend.integrations.webhooks.utils import webhook_ingress_url
 from backend.server.v2.library.model import LibraryAgentPreset
@@ -128,22 +131,36 @@ async def get_webhook(
 
 @overload
 async def get_all_webhooks_by_creds(
-    user_id: str, credentials_id: str, *, include_relations: Literal[True]
+    user_id: str,
+    credentials_id: str,
+    *,
+    include_relations: Literal[True],
+    limit: int = MAX_INTEGRATION_WEBHOOKS_FETCH,
 ) -> list[WebhookWithRelations]: ...
 @overload
 async def get_all_webhooks_by_creds(
-    user_id: str, credentials_id: str, *, include_relations: Literal[False] = False
+    user_id: str,
+    credentials_id: str,
+    *,
+    include_relations: Literal[False] = False,
+    limit: int = MAX_INTEGRATION_WEBHOOKS_FETCH,
 ) -> list[Webhook]: ...
 
 
 async def get_all_webhooks_by_creds(
-    user_id: str, credentials_id: str, *, include_relations: bool = False
+    user_id: str,
+    credentials_id: str,
+    *,
+    include_relations: bool = False,
+    limit: int = MAX_INTEGRATION_WEBHOOKS_FETCH,
 ) -> list[Webhook] | list[WebhookWithRelations]:
     if not credentials_id:
         raise ValueError("credentials_id must not be empty")
     webhooks = await IntegrationWebhook.prisma().find_many(
         where={"userId": user_id, "credentialsId": credentials_id},
         include=INTEGRATION_WEBHOOK_INCLUDE if include_relations else None,
+        order={"createdAt": "desc"},
+        take=limit,
     )
     return [
         (WebhookWithRelations if include_relations else Webhook).from_db(webhook)
@@ -270,6 +270,7 @@ def SchemaField(
     min_length: Optional[int] = None,
     max_length: Optional[int] = None,
     discriminator: Optional[str] = None,
+    format: Optional[str] = None,
    json_schema_extra: Optional[dict[str, Any]] = None,
 ) -> T:
     if default is PydanticUndefined and default_factory is None:
@@ -285,6 +286,7 @@ def SchemaField(
             "advanced": advanced,
             "hidden": hidden,
             "depends_on": depends_on,
+            "format": format,
             **(json_schema_extra or {}),
         }.items()
         if v is not None
@@ -345,6 +347,9 @@ class APIKeyCredentials(_BaseCredentials):
     """Unix timestamp (seconds) indicating when the API key expires (if at all)"""
 
     def auth_header(self) -> str:
+        # Linear API keys should not have Bearer prefix
+        if self.provider == "linear":
+            return self.api_key.get_secret_value()
         return f"Bearer {self.api_key.get_secret_value()}"

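In effect the header format is provider-dependent. A rough illustration (credential objects assumed already constructed; sample key values are made up):

linear_creds.auth_header()   # -> "lin_api_..."     (raw key, no prefix)
github_creds.auth_header()   # -> "Bearer ghp_..."  (default Bearer scheme)
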
@@ -15,7 +15,7 @@ from prisma.types import (
 # from backend.notifications.models import NotificationEvent
 from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
 
-from backend.server.v2.store.exceptions import DatabaseError
+from backend.util.exceptions import DatabaseError
 from backend.util.json import SafeJson
 from backend.util.logging import TruncatedLogger
@@ -235,6 +235,7 @@ class BaseEventModel(BaseModel):
 
 
 class NotificationEventModel(BaseEventModel, Generic[NotificationDataType_co]):
+    id: Optional[str] = None  # None when creating, populated when reading from DB
     data: NotificationDataType_co
 
     @property
@@ -378,6 +379,7 @@ class NotificationPreference(BaseModel):
 
 
 class UserNotificationEventDTO(BaseModel):
+    id: str  # Added to track notifications for removal
     type: NotificationType
     data: dict
     created_at: datetime
@@ -386,6 +388,7 @@ class UserNotificationEventDTO(BaseModel):
 
     @staticmethod
     def from_db(model: NotificationEvent) -> "UserNotificationEventDTO":
         return UserNotificationEventDTO(
+            id=model.id,
             type=model.type,
             data=dict(model.data),
             created_at=model.createdAt,
@@ -541,6 +544,79 @@ async def empty_user_notification_batch(
     ) from e
 
 
+async def clear_all_user_notification_batches(user_id: str) -> None:
+    """Clear ALL notification batches for a user across all types.
+
+    Used when user's email is bounced/inactive and we should stop
+    trying to send them ANY emails.
+    """
+    try:
+        # Delete all notification events for this user
+        await NotificationEvent.prisma().delete_many(
+            where={"UserNotificationBatch": {"is": {"userId": user_id}}}
+        )
+
+        # Delete all batches for this user
+        await UserNotificationBatch.prisma().delete_many(where={"userId": user_id})
+
+        logger.info(f"Cleared all notification batches for user {user_id}")
+    except Exception as e:
+        raise DatabaseError(
+            f"Failed to clear all notification batches for user {user_id}: {e}"
+        ) from e
+
+
+async def remove_notifications_from_batch(
+    user_id: str, notification_type: NotificationType, notification_ids: list[str]
+) -> None:
+    """Remove specific notifications from a user's batch by their IDs.
+
+    This is used after successful sending to remove only the
+    sent notifications, preventing duplicates on retry.
+    """
+    if not notification_ids:
+        return
+
+    try:
+        # Delete the specific notification events
+        deleted_count = await NotificationEvent.prisma().delete_many(
+            where={
+                "id": {"in": notification_ids},
+                "UserNotificationBatch": {
+                    "is": {"userId": user_id, "type": notification_type}
+                },
+            }
+        )
+
+        logger.info(
+            f"Removed {deleted_count} notifications from batch for user {user_id}"
+        )
+
+        # Check if batch is now empty and delete it if so
+        remaining = await NotificationEvent.prisma().count(
+            where={
+                "UserNotificationBatch": {
+                    "is": {"userId": user_id, "type": notification_type}
+                }
+            }
+        )
+
+        if remaining == 0:
+            await UserNotificationBatch.prisma().delete_many(
+                where=UserNotificationBatchWhereInput(
+                    userId=user_id,
+                    type=notification_type,
+                )
+            )
+            logger.info(
+                f"Deleted empty batch for user {user_id} and type {notification_type}"
+            )
+    except Exception as e:
+        raise DatabaseError(
+            f"Failed to remove notifications from batch for user {user_id} and type {notification_type}: {e}"
+        ) from e
+
+
 async def get_user_notification_batch(
     user_id: str,
     notification_type: NotificationType,
@@ -1,4 +1,5 @@
 import re
+from datetime import datetime
 from typing import Any, Optional
 
 import prisma
@@ -9,9 +10,9 @@ from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
 
 from backend.data.block import get_blocks
 from backend.data.credit import get_user_credit_model
-from backend.data.graph import GraphModel
 from backend.data.model import CredentialsMetaInput
 from backend.server.v2.store.model import StoreAgentDetails
+from backend.util.cache import cached
 from backend.util.json import SafeJson
 
 # Mapping from user reason id to categories to search for when choosing agent to show
@@ -25,12 +26,10 @@ REASON_MAPPING: dict[str, list[str]] = {
 POINTS_AGENT_COUNT = 50  # Number of agents to calculate points for
 MIN_AGENT_COUNT = 2  # Minimum number of marketplace agents to enable onboarding
 
-user_credit = get_user_credit_model()
-
 
 class UserOnboardingUpdate(pydantic.BaseModel):
     completedSteps: Optional[list[OnboardingStep]] = None
     notificationDot: Optional[bool] = None
     walletShown: Optional[bool] = None
     notified: Optional[list[OnboardingStep]] = None
     usageReason: Optional[str] = None
     integrations: Optional[list[str]] = None
@@ -39,6 +38,8 @@ class UserOnboardingUpdate(pydantic.BaseModel):
     agentInput: Optional[dict[str, Any]] = None
     onboardingAgentExecutionId: Optional[str] = None
     agentRuns: Optional[int] = None
+    lastRunAt: Optional[datetime] = None
+    consecutiveRunDays: Optional[int] = None
 
 
 async def get_user_onboarding(user_id: str):
@@ -57,16 +58,22 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
         update["completedSteps"] = list(set(data.completedSteps))
         for step in (
             OnboardingStep.AGENT_NEW_RUN,
-            OnboardingStep.RUN_AGENTS,
             OnboardingStep.MARKETPLACE_VISIT,
             OnboardingStep.MARKETPLACE_ADD_AGENT,
             OnboardingStep.MARKETPLACE_RUN_AGENT,
             OnboardingStep.BUILDER_SAVE_AGENT,
             OnboardingStep.BUILDER_RUN_AGENT,
+            OnboardingStep.RE_RUN_AGENT,
+            OnboardingStep.SCHEDULE_AGENT,
+            OnboardingStep.RUN_AGENTS,
+            OnboardingStep.RUN_3_DAYS,
+            OnboardingStep.TRIGGER_WEBHOOK,
+            OnboardingStep.RUN_14_DAYS,
+            OnboardingStep.RUN_AGENTS_100,
         ):
             if step in data.completedSteps:
                 await reward_user(user_id, step)
     if data.notificationDot is not None:
         update["notificationDot"] = data.notificationDot
     if data.walletShown is not None:
         update["walletShown"] = data.walletShown
     if data.notified is not None:
         update["notified"] = list(set(data.notified))
     if data.usageReason is not None:
@@ -83,6 +90,10 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
         update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId
     if data.agentRuns is not None:
         update["agentRuns"] = data.agentRuns
+    if data.lastRunAt is not None:
+        update["lastRunAt"] = data.lastRunAt
+    if data.consecutiveRunDays is not None:
+        update["consecutiveRunDays"] = data.consecutiveRunDays
 
     return await UserOnboarding.prisma().upsert(
         where={"userId": user_id},
@@ -101,16 +112,28 @@ async def reward_user(user_id: str, step: OnboardingStep):
         # This is seen as a reward for the GET_RESULTS step in the wallet
         case OnboardingStep.AGENT_NEW_RUN:
             reward = 300
-        case OnboardingStep.RUN_AGENTS:
-            reward = 300
         case OnboardingStep.MARKETPLACE_VISIT:
             reward = 100
         case OnboardingStep.MARKETPLACE_ADD_AGENT:
             reward = 100
         case OnboardingStep.MARKETPLACE_RUN_AGENT:
             reward = 100
         case OnboardingStep.BUILDER_SAVE_AGENT:
             reward = 100
-        case OnboardingStep.BUILDER_RUN_AGENT:
+        case OnboardingStep.RE_RUN_AGENT:
             reward = 100
+        case OnboardingStep.SCHEDULE_AGENT:
+            reward = 100
+        case OnboardingStep.RUN_AGENTS:
+            reward = 300
+        case OnboardingStep.RUN_3_DAYS:
+            reward = 100
+        case OnboardingStep.TRIGGER_WEBHOOK:
+            reward = 100
+        case OnboardingStep.RUN_14_DAYS:
+            reward = 300
+        case OnboardingStep.RUN_AGENTS_100:
+            reward = 300
 
     if reward == 0:
         return
@@ -122,7 +145,8 @@ async def reward_user(user_id: str, step: OnboardingStep):
         return
 
     onboarding.rewardedFor.append(step)
-    await user_credit.onboarding_reward(user_id, reward, step)
+    user_credit_model = await get_user_credit_model(user_id)
+    await user_credit_model.onboarding_reward(user_id, reward, step)
     await UserOnboarding.prisma().update(
         where={"userId": user_id},
         data={
@@ -132,6 +156,22 @@ async def reward_user(user_id: str, step: OnboardingStep):
     )
 
 
+async def complete_webhook_trigger_step(user_id: str):
+    """
+    Completes the TRIGGER_WEBHOOK onboarding step for the user if not already completed.
+    """
+
+    onboarding = await get_user_onboarding(user_id)
+    if OnboardingStep.TRIGGER_WEBHOOK not in onboarding.completedSteps:
+        await update_user_onboarding(
+            user_id,
+            UserOnboardingUpdate(
+                completedSteps=onboarding.completedSteps
+                + [OnboardingStep.TRIGGER_WEBHOOK]
+            ),
+        )
+
+
 def clean_and_split(text: str) -> list[str]:
     """
     Removes all special characters from a string, truncates it to 100 characters,
@@ -236,8 +276,14 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
             for word in user_onboarding.integrations
         ]
 
-    where_clause["is_available"] = True
-
+    # Try to take only agents that are available and allowed for onboarding
     storeAgents = await prisma.models.StoreAgent.prisma().find_many(
-        where=prisma.types.StoreAgentWhereInput(**where_clause),
+        where={
+            "is_available": True,
+            "useForOnboarding": True,
+        },
         order=[
             {"featured": "desc"},
             {"runs": "desc"},
@@ -246,59 +292,16 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
         take=100,
     )
 
-    agentListings = await prisma.models.StoreListingVersion.prisma().find_many(
-        where={
-            "id": {"in": [agent.storeListingVersionId for agent in storeAgents]},
-        },
-        include={"AgentGraph": True},
-    )
-
-    for listing in agentListings:
-        agent = listing.AgentGraph
-        if agent is None:
-            continue
-        graph = GraphModel.from_db(agent)
-        # Remove agents with empty input schema
-        if not graph.input_schema:
-            storeAgents = [
-                a for a in storeAgents if a.storeListingVersionId != listing.id
-            ]
-            continue
-
-        # Remove agents with empty credentials
-        # Get nodes from this agent that have credentials
-        nodes = await prisma.models.AgentNode.prisma().find_many(
-            where={
-                "agentGraphId": agent.id,
-                "agentBlockId": {"in": list(CREDENTIALS_FIELDS.keys())},
-            },
-        )
-        for node in nodes:
-            block_id = node.agentBlockId
-            field_name = CREDENTIALS_FIELDS[block_id]
-            # If there are no credentials or they are empty, remove the agent
-            # FIXME ignores default values
-            if (
-                field_name not in node.constantInput
-                or node.constantInput[field_name] is None
-            ):
-                storeAgents = [
-                    a for a in storeAgents if a.storeListingVersionId != listing.id
-                ]
-                break
-
-    # If there are less than 2 agents, add more agents to the list
+    # If not enough agents found, relax the useForOnboarding filter
     if len(storeAgents) < 2:
-        storeAgents += await prisma.models.StoreAgent.prisma().find_many(
-            where={
-                "listing_id": {"not_in": [agent.listing_id for agent in storeAgents]},
-            },
+        storeAgents = await prisma.models.StoreAgent.prisma().find_many(
+            where=prisma.types.StoreAgentWhereInput(**where_clause),
             order=[
                 {"featured": "desc"},
                 {"runs": "desc"},
                 {"rating": "desc"},
             ],
-            take=2 - len(storeAgents),
+            take=100,
         )
 
     # Calculate points for the first X agents and choose the top 2
|
||||
]
|
||||
|
||||
|
||||
@cached(maxsize=1, ttl_seconds=300) # Cache for 5 minutes since this rarely changes
|
||||
async def onboarding_enabled() -> bool:
|
||||
"""
|
||||
Check if onboarding should be enabled based on store agent count.
|
||||
Cached to prevent repeated slow database queries.
|
||||
"""
|
||||
# Use a more efficient query that stops counting after finding enough agents
|
||||
count = await prisma.models.StoreAgent.prisma().count(take=MIN_AGENT_COUNT + 1)
|
||||
|
||||
# Onboading is enabled if there are at least 2 agents in the store
|
||||
# Onboarding is enabled if there are at least 2 agents in the store
|
||||
return count >= MIN_AGENT_COUNT
|
||||
|
||||
autogpt_platform/backend/backend/data/partial_types.py (new file, 5 lines)
@@ -0,0 +1,5 @@
import prisma.models


class StoreAgentWithRank(prisma.models.StoreAgent):
    rank: float
@@ -1,19 +1,18 @@
 import logging
 import os
-from functools import cache
 
-from autogpt_libs.utils.cache import thread_cached
 from dotenv import load_dotenv
 from redis import Redis
 from redis.asyncio import Redis as AsyncRedis
 
+from backend.util.cache import cached, thread_cached
 from backend.util.retry import conn_retry
 
 load_dotenv()
 
 HOST = os.getenv("REDIS_HOST", "localhost")
 PORT = int(os.getenv("REDIS_PORT", "6379"))
-PASSWORD = os.getenv("REDIS_PASSWORD", "password")
+PASSWORD = os.getenv("REDIS_PASSWORD", None)
 
 logger = logging.getLogger(__name__)
 
@@ -35,7 +34,7 @@ def disconnect():
     get_redis().close()
 
 
-@cache
+@cached(ttl_seconds=3600)
 def get_redis() -> Redis:
     return connect()

@@ -15,15 +15,20 @@ from prisma.types import JsonFilter, UserCreateInput, UserUpdateInput
from backend.data.db import prisma
from backend.data.model import User, UserIntegrations, UserMetadata
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.server.v2.store.exceptions import DatabaseError
from backend.util.cache import cached
from backend.util.encryption import JSONCryptor
from backend.util.exceptions import DatabaseError
from backend.util.json import SafeJson
from backend.util.settings import Settings

logger = logging.getLogger(__name__)
settings = Settings()

# Cache decorator alias for consistent user lookup caching
cache_user_lookup = cached(maxsize=1000, ttl_seconds=300)


@cache_user_lookup
async def get_or_create_user(user_data: dict) -> User:
    try:
        user_id = user_data.get("sub")
@@ -49,6 +54,7 @@ async def get_or_create_user(user_data: dict) -> User:
        raise DatabaseError(f"Failed to get or create user {user_data}: {e}") from e


@cache_user_lookup
async def get_user_by_id(user_id: str) -> User:
    user = await prisma.user.find_unique(where={"id": user_id})
    if not user:
@@ -64,6 +70,7 @@ async def get_user_email_by_id(user_id: str) -> Optional[str]:
        raise DatabaseError(f"Failed to get user email for user {user_id}: {e}") from e


@cache_user_lookup
async def get_user_by_email(email: str) -> Optional[User]:
    try:
        user = await prisma.user.find_unique(where={"email": email})
@@ -74,7 +81,17 @@ async def get_user_by_email(email: str) -> Optional[User]:

async def update_user_email(user_id: str, email: str):
    try:
        # Get old email first for cache invalidation
        old_user = await prisma.user.find_unique(where={"id": user_id})
        old_email = old_user.email if old_user else None

        await prisma.user.update(where={"id": user_id}, data={"email": email})

        # Selectively invalidate only the specific user entries
        get_user_by_id.cache_delete(user_id)
        if old_email:
            get_user_by_email.cache_delete(old_email)
        get_user_by_email.cache_delete(email)
    except Exception as e:
        raise DatabaseError(
            f"Failed to update user email for user {user_id}: {e}"
@@ -114,6 +131,8 @@ async def update_user_integrations(user_id: str, data: UserIntegrations):
        where={"id": user_id},
        data={"integrations": encrypted_data},
    )
    # Invalidate cache for this user
    get_user_by_id.cache_delete(user_id)


async def migrate_and_encrypt_user_integrations():
@@ -285,6 +304,10 @@ async def update_user_notification_preference(
    )
    if not user:
        raise ValueError(f"User not found with ID: {user_id}")

    # Invalidate cache for this user since notification preferences are part of user data
    get_user_by_id.cache_delete(user_id)

    preferences: dict[NotificationType, bool] = {
        NotificationType.AGENT_RUN: user.notifyOnAgentRun or True,
        NotificationType.ZERO_BALANCE: user.notifyOnZeroBalance or True,
@@ -323,12 +346,44 @@ async def set_user_email_verification(user_id: str, verified: bool) -> None:
            where={"id": user_id},
            data={"emailVerified": verified},
        )
        # Invalidate cache for this user
        get_user_by_id.cache_delete(user_id)
    except Exception as e:
        raise DatabaseError(
            f"Failed to set email verification status for user {user_id}: {e}"
        ) from e


async def disable_all_user_notifications(user_id: str) -> None:
    """Disable all notification preferences for a user.

    Used when user's email bounces/is inactive to prevent any future notifications.
    """
    try:
        await PrismaUser.prisma().update(
            where={"id": user_id},
            data={
                "notifyOnAgentRun": False,
                "notifyOnZeroBalance": False,
                "notifyOnLowBalance": False,
                "notifyOnBlockExecutionFailed": False,
                "notifyOnContinuousAgentError": False,
                "notifyOnDailySummary": False,
                "notifyOnWeeklySummary": False,
                "notifyOnMonthlySummary": False,
                "notifyOnAgentApproved": False,
                "notifyOnAgentRejected": False,
            },
        )
        # Invalidate cache for this user
        get_user_by_id.cache_delete(user_id)
        logger.info(f"Disabled all notification preferences for user {user_id}")
    except Exception as e:
        raise DatabaseError(
            f"Failed to disable notifications for user {user_id}: {e}"
        ) from e


async def get_user_email_verification(user_id: str) -> bool:
    """Get the email verification status for a user."""
    try:
@@ -407,6 +462,10 @@ async def update_user_timezone(user_id: str, timezone: str) -> User:
        )
        if not user:
            raise ValueError(f"User not found with ID: {user_id}")

        # Invalidate cache for this user
        get_user_by_id.cache_delete(user_id)

        return User.from_db(user)
    except Exception as e:
        raise DatabaseError(f"Failed to update timezone for user {user_id}: {e}") from e
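
The pattern above — decorate the lookups, then call cache_delete on every write path — only works if the decorator exposes per-key invalidation. A rough sketch of what such a decorator could look like for single-argument async functions (a hypothetical simplification; the real backend.util.cache version handles sizing, TTL eviction, and richer signatures):

    import functools
    import time

    def cached(maxsize: int = 128, ttl_seconds: float = 300):
        def decorator(fn):
            store: dict = {}  # key -> (value, expires_at)

            @functools.wraps(fn)
            async def wrapper(key):
                hit = store.get(key)
                if hit and hit[1] > time.monotonic():
                    return hit[0]
                value = await fn(key)
                if len(store) >= maxsize:
                    store.pop(next(iter(store)))  # crude FIFO eviction
                store[key] = (value, time.monotonic() + ttl_seconds)
                return value

            # Per-key invalidation hook used by the write paths above
            wrapper.cache_delete = lambda key: store.pop(key, None)
            return wrapper
        return decorator
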
@@ -4,7 +4,12 @@ Module for generating AI-based activity status for graph executions.

import json
import logging
from typing import TYPE_CHECKING, Any, NotRequired, TypedDict
from typing import TYPE_CHECKING, Any, TypedDict

try:
    from typing import NotRequired
except ImportError:
    from typing_extensions import NotRequired

from pydantic import SecretStr

@@ -107,7 +112,7 @@ async def generate_activity_status_for_execution(
    # Check if we have OpenAI API key
    try:
        settings = Settings()
        if not settings.secrets.openai_api_key:
        if not settings.secrets.openai_internal_api_key:
            logger.debug(
                "OpenAI API key not configured, skipping activity status generation"
            )
@@ -146,17 +151,35 @@ async def generate_activity_status_for_execution(
                "Focus on the ACTUAL TASK the user wanted done, not the internal workflow steps. "
                "Avoid technical terms like 'workflow', 'execution', 'components', 'nodes', 'processing', etc. "
                "Keep it to 3 sentences maximum. Be conversational and human-friendly.\n\n"
                "UNDERSTAND THE INTENDED PURPOSE:\n"
                "- FIRST: Read the graph description carefully to understand what the user wanted to accomplish\n"
                "- The graph name and description tell you the main goal/intention of this automation\n"
                "- Use this intended purpose as your PRIMARY criteria for success/failure evaluation\n"
                "- Ask yourself: 'Did this execution actually accomplish what the graph was designed to do?'\n\n"
                "CRITICAL OUTPUT ANALYSIS:\n"
                "- Check if blocks that should produce user-facing results actually produced outputs\n"
                "- Blocks with names containing 'Output', 'Post', 'Create', 'Send', 'Publish', 'Generate' are usually meant to produce final results\n"
                "- If these critical blocks have NO outputs (empty recent_outputs), the task likely FAILED even if status shows 'completed'\n"
                "- Sub-agents (AgentExecutorBlock) that produce no outputs usually indicate failed sub-tasks\n"
                "- Most importantly: Does the execution result match what the graph description promised to deliver?\n\n"
                "SUCCESS EVALUATION BASED ON INTENTION:\n"
                "- If the graph is meant to 'create blog posts' → check if blog content was actually created\n"
                "- If the graph is meant to 'send emails' → check if emails were actually sent\n"
                "- If the graph is meant to 'analyze data' → check if analysis results were produced\n"
                "- If the graph is meant to 'generate reports' → check if reports were generated\n"
                "- Technical completion ≠ goal achievement. Focus on whether the USER'S INTENDED OUTCOME was delivered\n\n"
                "IMPORTANT: Be HONEST about what actually happened:\n"
                "- If the input was invalid/nonsensical, say so directly\n"
                "- If the task failed, explain what went wrong in simple terms\n"
                "- If errors occurred, focus on what the user needs to know\n"
                "- Only claim success if the task was genuinely completed\n"
                "- Don't sugar-coat failures or present them as helpful feedback\n\n"
                "- Only claim success if the INTENDED PURPOSE was genuinely accomplished AND produced expected outputs\n"
                "- Don't sugar-coat failures or present them as helpful feedback\n"
                "- ESPECIALLY: If the graph's main purpose wasn't achieved, this is a failure regardless of 'completed' status\n\n"
                "Understanding Errors:\n"
                "- Node errors: Individual steps may fail but the overall task might still complete (e.g., one data source fails but others work)\n"
                "- Graph error (in overall_status.graph_error): This means the entire execution failed and nothing was accomplished\n"
                "- Even if execution shows 'completed', check if critical nodes failed that would prevent the desired outcome\n"
                "- Focus on the end result the user wanted, not whether technical steps completed"
                "- Missing outputs from critical blocks: Even if no errors, this means the task failed to produce expected results\n"
                "- Focus on whether the graph's intended purpose was fulfilled, not whether technical steps completed"
            ),
        },
        {
@@ -165,15 +188,28 @@ async def generate_activity_status_for_execution(
                f"A user ran '{graph_name}' to accomplish something. Based on this execution data, "
                f"write what they achieved in simple, user-friendly terms:\n\n"
                f"{json.dumps(execution_data, indent=2)}\n\n"
                "CRITICAL: Check overall_status.graph_error FIRST - if present, the entire execution failed.\n"
                "Then check individual node errors to understand partial failures.\n\n"
                "ANALYSIS CHECKLIST:\n"
                "1. READ graph_info.description FIRST - this tells you what the user intended to accomplish\n"
                "2. Check overall_status.graph_error - if present, the entire execution failed\n"
                "3. Look for nodes with 'Output', 'Post', 'Create', 'Send', 'Publish', 'Generate' in their block_name\n"
                "4. Check if these critical blocks have empty recent_outputs arrays - this indicates failure\n"
                "5. Look for AgentExecutorBlock (sub-agents) with no outputs - this suggests sub-task failures\n"
                "6. Count how many nodes produced outputs vs total nodes - low ratio suggests problems\n"
                "7. MOST IMPORTANT: Does the execution outcome match what graph_info.description promised?\n\n"
                "INTENTION-BASED EVALUATION:\n"
                "- If description mentions 'blog writing' → did it create blog content?\n"
                "- If description mentions 'email automation' → were emails actually sent?\n"
                "- If description mentions 'data analysis' → were analysis results produced?\n"
                "- If description mentions 'content generation' → was content actually generated?\n"
                "- If description mentions 'social media posting' → were posts actually made?\n"
                "- Match the outputs to the stated intention, not just technical completion\n\n"
                "Write 1-3 sentences about what the user accomplished, such as:\n"
                "- 'I analyzed your resume and provided detailed feedback for the IT industry.'\n"
                "- 'I couldn't analyze your resume because the input was just nonsensical text.'\n"
                "- 'I failed to complete the task due to missing API access.'\n"
                "- 'I couldn't complete the task because critical steps failed to produce any results.'\n"
                "- 'I failed to generate the content you requested due to missing API access.'\n"
                "- 'I extracted key information from your documents and organized it into a summary.'\n"
                "- 'The task failed to run due to system configuration issues.'\n\n"
                "Focus on what ACTUALLY happened, not what was attempted."
                "- 'The task failed because the blog post creation step didn't produce any output.'\n\n"
                "BE CRITICAL: If the graph's intended purpose (from description) wasn't achieved, report this as a failure even if status is 'completed'."
            ),
        },
    ]
@@ -187,7 +223,7 @@ async def generate_activity_status_for_execution(
        credentials = APIKeyCredentials(
            id="openai",
            provider="openai",
            api_key=SecretStr(settings.secrets.openai_api_key),
            api_key=SecretStr(settings.secrets.openai_internal_api_key),
            title="System OpenAI",
        )

@@ -197,6 +233,7 @@ async def generate_activity_status_for_execution(
        logger.debug(
            f"Generated activity status for {graph_exec_id}: {activity_status}"
        )

        return activity_status

    except Exception as e:
@@ -423,7 +460,6 @@ async def _call_llm_direct(
        credentials=credentials,
        llm_model=LlmModel.GPT4O_MINI,
        prompt=prompt,
        json_format=False,
        max_tokens=150,
        compress_prompt_to_fit=True,
    )
@@ -468,7 +468,7 @@ class TestGenerateActivityStatusForExecution:
    ):

        mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
        mock_settings.return_value.secrets.openai_api_key = "test_key"
        mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
        mock_llm.return_value = (
            "I analyzed your data and provided the requested insights."
        )
@@ -520,7 +520,7 @@ class TestGenerateActivityStatusForExecution:
        "backend.executor.activity_status_generator.is_feature_enabled",
        return_value=True,
    ):
        mock_settings.return_value.secrets.openai_api_key = ""
        mock_settings.return_value.secrets.openai_internal_api_key = ""

        result = await generate_activity_status_for_execution(
            graph_exec_id="test_exec",
@@ -546,7 +546,7 @@ class TestGenerateActivityStatusForExecution:
        "backend.executor.activity_status_generator.is_feature_enabled",
        return_value=True,
    ):
        mock_settings.return_value.secrets.openai_api_key = "test_key"
        mock_settings.return_value.secrets.openai_internal_api_key = "test_key"

        result = await generate_activity_status_for_execution(
            graph_exec_id="test_exec",
@@ -581,7 +581,7 @@ class TestGenerateActivityStatusForExecution:
    ):

        mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
        mock_settings.return_value.secrets.openai_api_key = "test_key"
        mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
        mock_llm.return_value = "Agent completed execution."

        result = await generate_activity_status_for_execution(
@@ -633,7 +633,7 @@ class TestIntegration:
    ):

        mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
        mock_settings.return_value.secrets.openai_api_key = "test_key"
        mock_settings.return_value.secrets.openai_internal_api_key = "test_key"

        mock_response = LLMResponse(
            raw_response={},
autogpt_platform/backend/backend/executor/cluster_lock.py  (new file, 115 lines)
@@ -0,0 +1,115 @@
"""Redis-based distributed locking for cluster coordination."""

import logging
import time
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from redis import Redis

logger = logging.getLogger(__name__)


class ClusterLock:
    """Simple Redis-based distributed lock for preventing duplicate execution."""

    def __init__(self, redis: "Redis", key: str, owner_id: str, timeout: int = 300):
        self.redis = redis
        self.key = key
        self.owner_id = owner_id
        self.timeout = timeout
        self._last_refresh = 0.0

    def try_acquire(self) -> str | None:
        """Try to acquire the lock.

        Returns:
            - owner_id (self.owner_id) if successfully acquired
            - different owner_id if someone else holds the lock
            - None if Redis is unavailable or other error
        """
        try:
            success = self.redis.set(self.key, self.owner_id, nx=True, ex=self.timeout)
            if success:
                self._last_refresh = time.time()
                return self.owner_id  # Successfully acquired

            # Failed to acquire, get current owner
            current_value = self.redis.get(self.key)
            if current_value:
                current_owner = (
                    current_value.decode("utf-8")
                    if isinstance(current_value, bytes)
                    else str(current_value)
                )
                return current_owner

            # Key doesn't exist but we failed to set it - race condition or Redis issue
            return None

        except Exception as e:
            logger.error(f"ClusterLock.try_acquire failed for key {self.key}: {e}")
            return None

    def refresh(self) -> bool:
        """Refresh lock TTL if we still own it.

        Rate limited to at most once every timeout/10 seconds (minimum 1 second).
        During rate limiting, still verifies lock existence but skips TTL extension.
        Setting _last_refresh to 0 bypasses rate limiting for testing.
        """
        # Calculate refresh interval: max(timeout // 10, 1)
        refresh_interval = max(self.timeout // 10, 1)
        current_time = time.time()

        # Check if we're within the rate limit period
        # _last_refresh == 0 forces a refresh (bypasses rate limiting for testing)
        is_rate_limited = (
            self._last_refresh > 0
            and (current_time - self._last_refresh) < refresh_interval
        )

        try:
            # Always verify lock existence, even during rate limiting
            current_value = self.redis.get(self.key)
            if not current_value:
                self._last_refresh = 0
                return False

            stored_owner = (
                current_value.decode("utf-8")
                if isinstance(current_value, bytes)
                else str(current_value)
            )
            if stored_owner != self.owner_id:
                self._last_refresh = 0
                return False

            # If rate limited, return True but don't update TTL or timestamp
            if is_rate_limited:
                return True

            # Perform actual refresh
            if self.redis.expire(self.key, self.timeout):
                self._last_refresh = current_time
                return True

            self._last_refresh = 0
            return False

        except Exception as e:
            logger.error(f"ClusterLock.refresh failed for key {self.key}: {e}")
            self._last_refresh = 0
            return False

    def release(self):
        """Release the lock."""
        if self._last_refresh == 0:
            return

        try:
            self.redis.delete(self.key)
        except Exception:
            pass

        self._last_refresh = 0.0
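
Taken together, the intended lifecycle is acquire, periodically refresh while working, then release. A hedged usage sketch (the Redis connection details and the work loop are placeholders, not the platform's actual wiring):

    import uuid

    from redis import Redis

    from backend.executor.cluster_lock import ClusterLock

    def do_work_chunks():
        """Stand-in for real work; yields a few dummy work units."""
        yield from range(3)

    redis_client = Redis(host="localhost", port=6379)  # placeholder connection details
    lock = ClusterLock(
        redis_client,
        key="exec_lock:example-graph-exec-id",  # illustrative key
        owner_id=str(uuid.uuid4()),
        timeout=300,
    )

    if lock.try_acquire() == lock.owner_id:
        try:
            for _ in do_work_chunks():
                lock.refresh()  # cheap to call often; rate-limited internally
        finally:
            lock.release()
    else:
        # Another pod owns this execution (or Redis is down): skip or requeue.
        pass
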
autogpt_platform/backend/backend/executor/cluster_lock_test.py  (new file, 507 lines)
@@ -0,0 +1,507 @@
"""
Integration tests for ClusterLock - Redis-based distributed locking.

Tests the complete lock lifecycle without mocking Redis to ensure
real-world behavior is correct. Covers acquisition, refresh, expiry,
contention, and error scenarios.
"""

import logging
import time
import uuid
from threading import Thread

import pytest
import redis

from .cluster_lock import ClusterLock

logger = logging.getLogger(__name__)


@pytest.fixture
def redis_client():
    """Get Redis client for testing using same config as backend."""
    from backend.data.redis_client import HOST, PASSWORD, PORT

    # Use same config as backend but without decode_responses since ClusterLock needs raw bytes
    client = redis.Redis(
        host=HOST,
        port=PORT,
        password=PASSWORD,
        decode_responses=False,  # ClusterLock needs raw bytes for ownership verification
    )

    # Clean up any existing test keys
    try:
        for key in client.scan_iter(match="test_lock:*"):
            client.delete(key)
    except Exception:
        pass  # Ignore cleanup errors

    return client


@pytest.fixture
def lock_key():
    """Generate unique lock key for each test."""
    return f"test_lock:{uuid.uuid4()}"


@pytest.fixture
def owner_id():
    """Generate unique owner ID for each test."""
    return str(uuid.uuid4())


class TestClusterLockBasic:
    """Basic lock acquisition and release functionality."""

    def test_lock_acquisition_success(self, redis_client, lock_key, owner_id):
        """Test basic lock acquisition succeeds."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)

        # Lock should be acquired successfully
        result = lock.try_acquire()
        assert result == owner_id  # Returns our owner_id when successfully acquired
        assert lock._last_refresh > 0

        # Lock key should exist in Redis
        assert redis_client.exists(lock_key) == 1
        assert redis_client.get(lock_key).decode("utf-8") == owner_id

    def test_lock_acquisition_contention(self, redis_client, lock_key):
        """Test second acquisition fails when lock is held."""
        owner1 = str(uuid.uuid4())
        owner2 = str(uuid.uuid4())

        lock1 = ClusterLock(redis_client, lock_key, owner1, timeout=60)
        lock2 = ClusterLock(redis_client, lock_key, owner2, timeout=60)

        # First lock should succeed
        result1 = lock1.try_acquire()
        assert result1 == owner1  # Successfully acquired, returns our owner_id

        # Second lock should fail and return the first owner
        result2 = lock2.try_acquire()
        assert result2 == owner1  # Returns the current owner (first owner)
        assert lock2._last_refresh == 0

    def test_lock_release_deletes_redis_key(self, redis_client, lock_key, owner_id):
        """Test lock release deletes Redis key and marks locally as released."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)

        lock.try_acquire()
        assert lock._last_refresh > 0
        assert redis_client.exists(lock_key) == 1

        # Release should delete Redis key and mark locally as released
        lock.release()
        assert lock._last_refresh == 0
        assert lock._last_refresh == 0.0

        # Redis key should be deleted for immediate release
        assert redis_client.exists(lock_key) == 0

        # Another lock should be able to acquire immediately
        new_owner_id = str(uuid.uuid4())
        new_lock = ClusterLock(redis_client, lock_key, new_owner_id, timeout=60)
        assert new_lock.try_acquire() == new_owner_id


class TestClusterLockRefresh:
    """Lock refresh and TTL management."""

    def test_lock_refresh_success(self, redis_client, lock_key, owner_id):
        """Test lock refresh extends TTL."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)

        lock.try_acquire()
        original_ttl = redis_client.ttl(lock_key)

        # Wait a bit then refresh
        time.sleep(1)
        lock._last_refresh = 0  # Force refresh past rate limit
        assert lock.refresh() is True

        # TTL should be reset to full timeout (allow for small timing differences)
        new_ttl = redis_client.ttl(lock_key)
        assert new_ttl >= original_ttl or new_ttl >= 58  # Allow for timing variance

    def test_lock_refresh_rate_limiting(self, redis_client, lock_key, owner_id):
        """Test refresh is rate-limited to timeout/10."""
        lock = ClusterLock(
            redis_client, lock_key, owner_id, timeout=100
        )  # 100s timeout

        lock.try_acquire()

        # First refresh should work
        assert lock.refresh() is True
        first_refresh_time = lock._last_refresh

        # Immediate second refresh should be skipped (rate limited) but verify key exists
        assert lock.refresh() is True  # Returns True but skips actual refresh
        assert lock._last_refresh == first_refresh_time  # Time unchanged

    def test_lock_refresh_verifies_existence_during_rate_limit(
        self, redis_client, lock_key, owner_id
    ):
        """Test refresh verifies lock existence even during rate limiting."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=100)

        lock.try_acquire()

        # Manually delete the key (simulates expiry or external deletion)
        redis_client.delete(lock_key)

        # Refresh should detect missing key even during rate limit period
        assert lock.refresh() is False
        assert lock._last_refresh == 0

    def test_lock_refresh_ownership_lost(self, redis_client, lock_key, owner_id):
        """Test refresh fails when ownership is lost."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)

        lock.try_acquire()

        # Simulate another process taking the lock
        different_owner = str(uuid.uuid4())
        redis_client.set(lock_key, different_owner, ex=60)

        # Force refresh past rate limit and verify it fails
        lock._last_refresh = 0  # Force refresh past rate limit
        assert lock.refresh() is False
        assert lock._last_refresh == 0

    def test_lock_refresh_when_not_acquired(self, redis_client, lock_key, owner_id):
        """Test refresh fails when lock was never acquired."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)

        # Refresh without acquiring should fail
        assert lock.refresh() is False


class TestClusterLockExpiry:
    """Lock expiry and timeout behavior."""

    def test_lock_natural_expiry(self, redis_client, lock_key, owner_id):
        """Test lock expires naturally via Redis TTL."""
        lock = ClusterLock(
            redis_client, lock_key, owner_id, timeout=2
        )  # 2 second timeout

        lock.try_acquire()
        assert redis_client.exists(lock_key) == 1

        # Wait for expiry
        time.sleep(3)
        assert redis_client.exists(lock_key) == 0

        # New lock with same key should succeed
        new_lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
        assert new_lock.try_acquire() == owner_id

    def test_lock_refresh_prevents_expiry(self, redis_client, lock_key, owner_id):
        """Test refreshing prevents lock from expiring."""
        lock = ClusterLock(
            redis_client, lock_key, owner_id, timeout=3
        )  # 3 second timeout

        lock.try_acquire()

        # Wait and refresh before expiry
        time.sleep(1)
        lock._last_refresh = 0  # Force refresh past rate limit
        assert lock.refresh() is True

        # Wait beyond original timeout
        time.sleep(2.5)
        assert redis_client.exists(lock_key) == 1  # Should still exist


class TestClusterLockConcurrency:
    """Concurrent access patterns."""

    def test_multiple_threads_contention(self, redis_client, lock_key):
        """Test multiple threads competing for same lock."""
        num_threads = 5
        successful_acquisitions = []

        def try_acquire_lock(thread_id):
            owner_id = f"thread_{thread_id}"
            lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
            if lock.try_acquire() == owner_id:
                successful_acquisitions.append(thread_id)
                time.sleep(0.1)  # Hold lock briefly
                lock.release()

        threads = []
        for i in range(num_threads):
            thread = Thread(target=try_acquire_lock, args=(i,))
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

        # Only one thread should have acquired the lock
        assert len(successful_acquisitions) == 1

    def test_sequential_lock_reuse(self, redis_client, lock_key):
        """Test lock can be reused after natural expiry."""
        owners = [str(uuid.uuid4()) for _ in range(3)]

        for i, owner_id in enumerate(owners):
            lock = ClusterLock(redis_client, lock_key, owner_id, timeout=1)  # 1 second

            assert lock.try_acquire() == owner_id
            time.sleep(1.5)  # Wait for expiry

            # Verify lock expired
            assert redis_client.exists(lock_key) == 0

    def test_refresh_during_concurrent_access(self, redis_client, lock_key):
        """Test lock refresh works correctly during concurrent access attempts."""
        owner1 = str(uuid.uuid4())
        owner2 = str(uuid.uuid4())

        lock1 = ClusterLock(redis_client, lock_key, owner1, timeout=5)
        lock2 = ClusterLock(redis_client, lock_key, owner2, timeout=5)

        # Thread 1 holds lock and refreshes
        assert lock1.try_acquire() == owner1

        def refresh_continuously():
            for _ in range(10):
                lock1._last_refresh = 0  # Force refresh
                lock1.refresh()
                time.sleep(0.1)

        def try_acquire_continuously():
            attempts = 0
            while attempts < 20:
                if lock2.try_acquire() == owner2:
                    return True
                time.sleep(0.1)
                attempts += 1
            return False

        refresh_thread = Thread(target=refresh_continuously)
        acquire_thread = Thread(target=try_acquire_continuously)

        refresh_thread.start()
        acquire_thread.start()

        refresh_thread.join()
        acquire_thread.join()

        # Lock1 should still own the lock due to refreshes
        assert lock1._last_refresh > 0
        assert lock2._last_refresh == 0


class TestClusterLockErrorHandling:
    """Error handling and edge cases."""

    def test_redis_connection_failure_on_acquire(self, lock_key, owner_id):
        """Test graceful handling when Redis is unavailable during acquisition."""
        # Use invalid Redis connection
        bad_redis = redis.Redis(
            host="invalid_host", port=1234, socket_connect_timeout=1
        )
        lock = ClusterLock(bad_redis, lock_key, owner_id, timeout=60)

        # Should return None for Redis connection failures
        result = lock.try_acquire()
        assert result is None  # Returns None when Redis fails
        assert lock._last_refresh == 0

    def test_redis_connection_failure_on_refresh(
        self, redis_client, lock_key, owner_id
    ):
        """Test graceful handling when Redis fails during refresh."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)

        # Acquire normally
        assert lock.try_acquire() == owner_id

        # Replace Redis client with failing one
        lock.redis = redis.Redis(
            host="invalid_host", port=1234, socket_connect_timeout=1
        )

        # Refresh should fail gracefully
        lock._last_refresh = 0  # Force refresh
        assert lock.refresh() is False
        assert lock._last_refresh == 0

    def test_invalid_lock_parameters(self, redis_client):
        """Test validation of lock parameters."""
        owner_id = str(uuid.uuid4())

        # All parameters are now simple - no validation needed
        # Just test basic construction works
        lock = ClusterLock(redis_client, "test_key", owner_id, timeout=60)
        assert lock.key == "test_key"
        assert lock.owner_id == owner_id
        assert lock.timeout == 60

    def test_refresh_after_redis_key_deleted(self, redis_client, lock_key, owner_id):
        """Test refresh behavior when Redis key is manually deleted."""
        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)

        lock.try_acquire()

        # Manually delete the key (simulates external deletion)
        redis_client.delete(lock_key)

        # Refresh should fail and mark as not acquired
        lock._last_refresh = 0  # Force refresh
        assert lock.refresh() is False
        assert lock._last_refresh == 0


class TestClusterLockDynamicRefreshInterval:
    """Dynamic refresh interval based on timeout."""

    def test_refresh_interval_calculation(self, redis_client, lock_key, owner_id):
        """Test refresh interval is calculated as max(timeout/10, 1)."""
        test_cases = [
            (5, 1),  # 5/10 = 0, but minimum is 1
            (10, 1),  # 10/10 = 1
            (30, 3),  # 30/10 = 3
            (100, 10),  # 100/10 = 10
            (200, 20),  # 200/10 = 20
            (1000, 100),  # 1000/10 = 100
        ]

        for timeout, expected_interval in test_cases:
            lock = ClusterLock(
                redis_client, f"{lock_key}_{timeout}", owner_id, timeout=timeout
            )
            lock.try_acquire()

            # Calculate expected interval using same logic as implementation
            refresh_interval = max(timeout // 10, 1)
            assert refresh_interval == expected_interval

            # Test rate limiting works with calculated interval
            assert lock.refresh() is True
            first_refresh_time = lock._last_refresh

            # Sleep less than interval - should be rate limited
            time.sleep(0.1)
            assert lock.refresh() is True
            assert lock._last_refresh == first_refresh_time  # No actual refresh


class TestClusterLockRealWorldScenarios:
    """Real-world usage patterns."""

    def test_execution_coordination_simulation(self, redis_client):
        """Simulate graph execution coordination across multiple pods."""
        graph_exec_id = str(uuid.uuid4())
        lock_key = f"execution:{graph_exec_id}"

        # Simulate 3 pods trying to execute same graph
        pods = [f"pod_{i}" for i in range(3)]
        execution_results = {}

        def execute_graph(pod_id):
            """Simulate graph execution with cluster lock."""
            lock = ClusterLock(redis_client, lock_key, pod_id, timeout=300)

            if lock.try_acquire() == pod_id:
                # Simulate execution work
                execution_results[pod_id] = "executed"
                time.sleep(0.1)
                lock.release()
            else:
                execution_results[pod_id] = "rejected"

        threads = []
        for pod_id in pods:
            thread = Thread(target=execute_graph, args=(pod_id,))
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

        # Only one pod should have executed
        executed_count = sum(
            1 for result in execution_results.values() if result == "executed"
        )
        rejected_count = sum(
            1 for result in execution_results.values() if result == "rejected"
        )

        assert executed_count == 1
        assert rejected_count == 2

    def test_long_running_execution_with_refresh(
        self, redis_client, lock_key, owner_id
    ):
        """Test lock maintains ownership during long execution with periodic refresh."""
        lock = ClusterLock(
            redis_client, lock_key, owner_id, timeout=30
        )  # 30 second timeout, refresh interval = max(30//10, 1) = 3 seconds

        def long_execution_with_refresh():
            """Simulate long-running execution with periodic refresh."""
            assert lock.try_acquire() == owner_id

            # Simulate 10 seconds of work with refreshes every 2 seconds
            # This respects rate limiting - actual refreshes will happen at 0s, 3s, 6s, 9s
            try:
                for i in range(5):  # 5 iterations * 2 seconds = 10 seconds total
                    time.sleep(2)
                    refresh_success = lock.refresh()
                    assert refresh_success is True, f"Refresh failed at iteration {i}"
                return "completed"
            finally:
                lock.release()

        # Should complete successfully without losing lock
        result = long_execution_with_refresh()
        assert result == "completed"

    def test_graceful_degradation_pattern(self, redis_client, lock_key):
        """Test graceful degradation when Redis becomes unavailable."""
        owner_id = str(uuid.uuid4())
        lock = ClusterLock(
            redis_client, lock_key, owner_id, timeout=3
        )  # Use shorter timeout

        # Normal operation
        assert lock.try_acquire() == owner_id
        lock._last_refresh = 0  # Force refresh past rate limit
        assert lock.refresh() is True

        # Simulate Redis becoming unavailable
        original_redis = lock.redis
        lock.redis = redis.Redis(
            host="invalid_host",
            port=1234,
            socket_connect_timeout=1,
            decode_responses=False,
        )

        # Should degrade gracefully
        lock._last_refresh = 0  # Force refresh past rate limit
        assert lock.refresh() is False
        assert lock._last_refresh == 0

        # Restore Redis and verify can acquire again
        lock.redis = original_redis
        # Wait for original lock to expire (use longer wait for 3s timeout)
        time.sleep(4)

        new_lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
        assert new_lock.try_acquire() == owner_id


if __name__ == "__main__":
    # Run specific test for quick validation
    pytest.main([__file__, "-v"])
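
Since these tests exercise a real Redis instance, one must be reachable at the HOST/PORT/PASSWORD the backend config points to. A quick local run could look like the following (the docker command and path are illustrative assumptions, not project tooling):

    # Requires a local Redis, e.g.: docker run -p 6379:6379 redis  (illustrative)
    import pytest

    # Mirrors the file's own __main__ guard; path relative to the backend package root (assumed).
    pytest.main(["backend/executor/cluster_lock_test.py", "-v"])
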
@@ -1,5 +1,6 @@
import logging
from typing import Callable, Concatenate, ParamSpec, TypeVar, cast
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cast

from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
@@ -9,6 +10,7 @@ from backend.data.execution import (
    get_execution_kv_data,
    get_graph_execution_meta,
    get_graph_executions,
    get_graph_executions_count,
    get_latest_node_execution,
    get_node_execution,
    get_node_executions,
@@ -28,14 +30,17 @@ from backend.data.graph import (
    get_node,
)
from backend.data.notifications import (
    clear_all_user_notification_batches,
    create_or_add_to_user_notification_batch,
    empty_user_notification_batch,
    get_all_batches_by_type,
    get_user_notification_batch,
    get_user_notification_oldest_message_in_batch,
    remove_notifications_from_batch,
)
from backend.data.user import (
    get_active_user_ids_in_timerange,
    get_user_by_id,
    get_user_email_by_id,
    get_user_email_verification,
    get_user_integrations,
@@ -53,8 +58,10 @@ from backend.util.service import (
)
from backend.util.settings import Config

if TYPE_CHECKING:
    from fastapi import FastAPI

config = Config()
_user_credit_model = get_user_credit_model()
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
@@ -63,28 +70,41 @@ R = TypeVar("R")
async def _spend_credits(
    user_id: str, cost: int, metadata: UsageTransactionMetadata
) -> int:
    return await _user_credit_model.spend_credits(user_id, cost, metadata)
    user_credit_model = await get_user_credit_model(user_id)
    return await user_credit_model.spend_credits(user_id, cost, metadata)


async def _get_credits(user_id: str) -> int:
    return await _user_credit_model.get_credits(user_id)
    user_credit_model = await get_user_credit_model(user_id)
    return await user_credit_model.get_credits(user_id)


class DatabaseManager(AppService):
    @asynccontextmanager
    async def lifespan(self, app: "FastAPI"):
        async with super().lifespan(app):
            logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
            await db.connect()

    def run_service(self) -> None:
        logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
        self.run_and_wait(db.connect())
        super().run_service()
            logger.info(f"[{self.service_name}] ✅ Ready")
            yield

    def cleanup(self):
        super().cleanup()
        logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
        self.run_and_wait(db.disconnect())
            logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
            await db.disconnect()

    async def health_check(self) -> str:
        if not db.is_connected():
            raise UnhealthyServiceError("Database is not connected")

        try:
            # Test actual database connectivity by executing a simple query
            # This will fail if Prisma query engine is not responding
            result = await db.query_raw_with_schema("SELECT 1 as health_check")
            if not result or result[0].get("health_check") != 1:
                raise UnhealthyServiceError("Database query test failed")
        except Exception as e:
            raise UnhealthyServiceError(f"Database health check failed: {e}")

        return await super().health_check()

    @classmethod
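
The run_service/cleanup pair is being folded into a single async lifespan context manager, which is the shape FastAPI expects: setup runs before the yield, teardown after it. A generic sketch of the pattern (hypothetical service wiring, not the platform's AppService plumbing):

    from contextlib import asynccontextmanager

    from fastapi import FastAPI

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Setup: runs once before the app starts serving requests
        app.state.db = await connect_to_database()  # hypothetical helper
        try:
            yield  # the application serves requests while suspended here
        finally:
            # Teardown: runs once on shutdown
            await app.state.db.disconnect()

    app = FastAPI(lifespan=lifespan)

One advantage of this shape over paired start/stop methods is that teardown cannot be forgotten: it lives in the same scope as the setup it undoes.
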
@@ -101,6 +121,7 @@ class DatabaseManager(AppService):

    # Executions
    get_graph_executions = _(get_graph_executions)
    get_graph_executions_count = _(get_graph_executions_count)
    get_graph_execution_meta = _(get_graph_execution_meta)
    create_graph_execution = _(create_graph_execution)
    get_node_execution = _(get_node_execution)
@@ -132,15 +153,18 @@ class DatabaseManager(AppService):

    # User Comms - async
    get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
    get_user_by_id = _(get_user_by_id)
    get_user_email_by_id = _(get_user_email_by_id)
    get_user_email_verification = _(get_user_email_verification)
    get_user_notification_preference = _(get_user_notification_preference)

    # Notifications - async
    clear_all_user_notification_batches = _(clear_all_user_notification_batches)
    create_or_add_to_user_notification_batch = _(
        create_or_add_to_user_notification_batch
    )
    empty_user_notification_batch = _(empty_user_notification_batch)
    remove_notifications_from_batch = _(remove_notifications_from_batch)
    get_all_batches_by_type = _(get_all_batches_by_type)
    get_user_notification_batch = _(get_user_notification_batch)
    get_user_notification_oldest_message_in_batch = _(
@@ -169,6 +193,7 @@ class DatabaseManagerClient(AppServiceClient):

    # Executions
    get_graph_executions = _(d.get_graph_executions)
    get_graph_executions_count = _(d.get_graph_executions_count)
    get_graph_execution_meta = _(d.get_graph_execution_meta)
    get_node_executions = _(d.get_node_executions)
    update_node_execution_status = _(d.update_node_execution_status)
@@ -214,6 +239,7 @@ class DatabaseManagerAsyncClient(AppServiceClient):
    get_node = d.get_node
    get_node_execution = d.get_node_execution
    get_node_executions = d.get_node_executions
    get_user_by_id = d.get_user_by_id
    get_user_integrations = d.get_user_integrations
    upsert_execution_input = d.upsert_execution_input
    upsert_execution_output = d.upsert_execution_output
@@ -231,10 +257,12 @@ class DatabaseManagerAsyncClient(AppServiceClient):
    get_user_notification_preference = d.get_user_notification_preference

    # Notifications
    clear_all_user_notification_batches = d.clear_all_user_notification_batches
    create_or_add_to_user_notification_batch = (
        d.create_or_add_to_user_notification_batch
    )
    empty_user_notification_batch = d.empty_user_notification_batch
    remove_notifications_from_batch = d.remove_notifications_from_batch
    get_all_batches_by_type = d.get_all_batches_by_type
    get_user_notification_batch = d.get_user_notification_batch
    get_user_notification_oldest_message_in_batch = (
@@ -3,16 +3,42 @@ import logging
import os
import threading
import time
import uuid
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast

import sentry_sdk
from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic, BasicProperties
from redis.asyncio.lock import Lock as RedisLock
from prometheus_client import Gauge, start_http_server
from redis.asyncio.lock import Lock as AsyncRedisLock

from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentOutputBlock
from backend.data import redis_client as redis
from backend.data.block import (
    BlockInput,
    BlockOutput,
    BlockOutputEntry,
    BlockSchema,
    get_block,
)
from backend.data.credit import UsageTransactionMetadata
from backend.data.dynamic_fields import parse_execution_output
from backend.data.execution import (
    ExecutionQueue,
    ExecutionStatus,
    GraphExecution,
    GraphExecutionEntry,
    NodeExecutionEntry,
    NodeExecutionResult,
    NodesInputMasks,
    UserContext,
)
from backend.data.graph import Link, Node
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
    AgentRunData,
@@ -25,50 +51,21 @@ from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.activity_status_generator import (
    generate_activity_status_for_execution,
)
from backend.executor.utils import LogMetadata
from backend.notifications.notifications import queue_notification
from backend.util.exceptions import InsufficientBalanceError, ModerationError

if TYPE_CHECKING:
    from backend.executor import DatabaseManagerClient, DatabaseManagerAsyncClient

from prometheus_client import Gauge, start_http_server

from backend.blocks.agent import AgentExecutorBlock
from backend.data import redis_client as redis
from backend.data.block import (
    BlockInput,
    BlockOutput,
    BlockOutputEntry,
    BlockSchema,
    get_block,
)
from backend.data.credit import UsageTransactionMetadata
from backend.data.execution import (
    ExecutionQueue,
    ExecutionStatus,
    GraphExecution,
    GraphExecutionEntry,
    NodeExecutionEntry,
    NodeExecutionResult,
    NodesInputMasks,
    UserContext,
)
from backend.data.graph import Link, Node
from backend.executor.utils import (
    GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
    GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
    GRAPH_EXECUTION_QUEUE_NAME,
    CancelExecutionEvent,
    ExecutionOutputEntry,
    LogMetadata,
    NodeExecutionProgress,
    block_usage_cost,
    create_execution_queue_config,
    execution_usage_cost,
    parse_execution_output,
    validate_exec,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.notifications.notifications import queue_notification
from backend.server.v2.AutoMod.manager import automod_manager
from backend.util import json
from backend.util.clients import (
@@ -84,13 +81,24 @@ from backend.util.decorator import (
    error_logged,
    time_measured,
)
from backend.util.exceptions import InsufficientBalanceError, ModerationError
from backend.util.file import clean_exec_files
from backend.util.logging import TruncatedLogger, configure_logging
from backend.util.metrics import DiscordChannel
from backend.util.process import AppProcess, set_service_name
from backend.util.retry import continuous_retry, func_retry
from backend.util.retry import (
    continuous_retry,
    func_retry,
    send_rate_limited_discord_alert,
)
from backend.util.settings import Settings

from .cluster_lock import ClusterLock

if TYPE_CHECKING:
    from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient


_logger = logging.getLogger(__name__)
logger = TruncatedLogger(_logger, prefix="[GraphExecutor]")
settings = Settings()
@@ -106,6 +114,7 @@ utilization_gauge = Gauge(
    "Ratio of active graph runs to max graph workers",
)


# Thread-local storage for ExecutionProcessor instances
_tls = threading.local()

@@ -117,10 +126,14 @@ def init_worker():


def execute_graph(
    graph_exec_entry: "GraphExecutionEntry", cancel_event: threading.Event
    graph_exec_entry: "GraphExecutionEntry",
    cancel_event: threading.Event,
    cluster_lock: ClusterLock,
):
    """Execute graph using thread-local ExecutionProcessor instance"""
    return _tls.processor.on_graph_execution(graph_exec_entry, cancel_event)
    return _tls.processor.on_graph_execution(
        graph_exec_entry, cancel_event, cluster_lock
    )


T = TypeVar("T")
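
execute_graph relies on each worker thread having its own ExecutionProcessor stashed in threading.local(), and the pool's initializer hook is what makes that safe. A minimal sketch of the pattern (the processor class and job handling are simplified placeholders):

    import threading
    from concurrent.futures import ThreadPoolExecutor

    _tls = threading.local()

    class Processor:  # placeholder for ExecutionProcessor
        def run(self, job):
            return f"{threading.current_thread().name} processed {job}"

    def init_worker():
        # Runs once per pool thread, so each thread gets a private instance
        _tls.processor = Processor()

    def handle(job):
        return _tls.processor.run(job)

    with ThreadPoolExecutor(max_workers=4, initializer=init_worker) as pool:
        print(list(pool.map(handle, range(8))))
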
@@ -177,6 +190,7 @@ async def execute_node(
        _input_data.inputs = input_data
        if nodes_input_masks:
            _input_data.nodes_input_masks = nodes_input_masks
        _input_data.user_id = user_id
        input_data = _input_data.model_dump()
    data.inputs = input_data

@@ -211,14 +225,37 @@ async def execute_node(
            extra_exec_kwargs[field_name] = credentials

    output_size = 0

    # sentry tracking nonsense to get user counts for blocks because isolation scopes don't work :(
    scope = sentry_sdk.get_current_scope()

    # save the tags
    original_user = scope._user
    original_tags = dict(scope._tags) if scope._tags else {}
    # Set user ID for error tracking
    scope.set_user({"id": user_id})

    scope.set_tag("graph_id", graph_id)
    scope.set_tag("node_id", node_id)
    scope.set_tag("block_name", node_block.name)
    scope.set_tag("block_id", node_block.id)
    for k, v in (data.user_context or UserContext(timezone="UTC")).model_dump().items():
        scope.set_tag(f"user_context.{k}", v)

    try:
        async for output_name, output_data in node_block.execute(
            input_data, **extra_exec_kwargs
        ):
            output_data = json.convert_pydantic_to_json(output_data)
            output_data = json.to_dict(output_data)
            output_size += len(json.dumps(output_data))
            log_metadata.debug("Node produced output", **{output_name: output_data})
            yield output_name, output_data
    except Exception:
        # Capture exception WITH context still set before restoring scope
        sentry_sdk.capture_exception(scope=scope)
        sentry_sdk.flush()  # Ensure it's sent before we restore scope
        # Re-raise to maintain normal error flow
        raise
    finally:
        # Ensure credentials are released even if execution fails
        if creds_lock and (await creds_lock.locked()) and (await creds_lock.owned()):
@@ -233,6 +270,10 @@ async def execute_node(
        execution_stats.input_size = input_size
        execution_stats.output_size = output_size

        # Restore scope AFTER error has been captured
        scope._user = original_user
        scope._tags = original_tags


async def _enqueue_next_nodes(
    db_client: "DatabaseManagerAsyncClient",
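
The save/restore dance above exists because tags set on the current Sentry scope would otherwise leak into unrelated node executions that share the same thread. A stripped-down sketch of the same technique (generic operation names; get_current_scope, set_tag, set_user, and capture_exception are real sentry_sdk calls, while the use of the private _user/_tags attributes simply mirrors what the diff itself does):

    import sentry_sdk

    def run_tagged(user_id: str, op):
        scope = sentry_sdk.get_current_scope()
        # Save current identity/tags so we can undo our changes afterwards
        original_user = scope._user
        original_tags = dict(scope._tags) if scope._tags else {}
        scope.set_user({"id": user_id})
        scope.set_tag("operation", getattr(op, "__name__", "op"))
        try:
            return op()
        except Exception:
            sentry_sdk.capture_exception()  # context is still attached here
            raise
        finally:
            # Restore only after capture, or the event would lose its tags
            scope._user = original_user
            scope._tags = original_tags
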
@@ -429,7 +470,7 @@ class ExecutionProcessor:
            graph_id=node_exec.graph_id,
            node_eid=node_exec.node_exec_id,
            node_id=node_exec.node_id,
            block_name="-",
            block_name=b.name if (b := get_block(node_exec.block_id)) else "-",
        )
        db_client = get_db_async_client()
        node = await db_client.get_node(node_exec.node_id)
@@ -557,7 +598,6 @@ class ExecutionProcessor:
                await persist_output(
                    "error", str(stats.error) or type(stats.error).__name__
                )

        return status

    @func_retry
@@ -583,6 +623,7 @@ class ExecutionProcessor:
        self,
        graph_exec: GraphExecutionEntry,
        cancel: threading.Event,
        cluster_lock: ClusterLock,
    ):
        log_metadata = LogMetadata(
            logger=_logger,
@@ -605,7 +646,7 @@ class ExecutionProcessor:
            )
            return

        if exec_meta.status == ExecutionStatus.QUEUED:
        if exec_meta.status in [ExecutionStatus.QUEUED, ExecutionStatus.INCOMPLETE]:
            log_metadata.info(f"⚙️ Starting graph execution #{graph_exec.graph_exec_id}")
            exec_meta.status = ExecutionStatus.RUNNING
            send_execution_update(
@@ -641,6 +682,7 @@ class ExecutionProcessor:
                cancel=cancel,
                log_metadata=log_metadata,
                execution_stats=exec_stats,
                cluster_lock=cluster_lock,
            )
            exec_stats.walltime += timing_info.wall_time
            exec_stats.cputime += timing_info.cpu_time
@@ -742,6 +784,7 @@ class ExecutionProcessor:
        cancel: threading.Event,
        log_metadata: LogMetadata,
        execution_stats: GraphExecutionStats,
        cluster_lock: ClusterLock,
    ) -> ExecutionStatus:
        """
        Returns:
@@ -927,7 +970,7 @@ class ExecutionProcessor:
                and execution_queue.empty()
                and (running_node_execution or running_node_evaluation)
            ):
                # There is nothing to execute and no output to process; relax for a while.
                cluster_lock.refresh()
                time.sleep(0.1)

        # loop done --------------------------------------------------
@@ -969,16 +1012,31 @@ class ExecutionProcessor:
                if isinstance(e, Exception)
                else Exception(f"{e.__class__.__name__}: {e}")
            )
            if not execution_stats.error:
                execution_stats.error = str(error)

            known_errors = (InsufficientBalanceError, ModerationError)
            if isinstance(error, known_errors):
                execution_stats.error = str(error)
                return ExecutionStatus.FAILED

            execution_status = ExecutionStatus.FAILED
            log_metadata.exception(
                f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
            )

            # Send rate-limited Discord alert for unknown/unexpected errors
            send_rate_limited_discord_alert(
                "graph_execution",
                error,
                "unknown_error",
                f"🚨 **Unknown Graph Execution Error**\n"
                f"User: {graph_exec.user_id}\n"
                f"Graph ID: {graph_exec.graph_id}\n"
                f"Execution ID: {graph_exec.graph_exec_id}\n"
                f"Error Type: {type(error).__name__}\n"
                f"Error: {str(error)[:200]}{'...' if len(str(error)) > 200 else ''}\n",
            )

            raise

        finally:
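
send_rate_limited_discord_alert keeps a flood of identical failures from spamming the alert channel; its real behavior lives in backend.util.retry. As a rough mental model, a once-per-window suppressor looks something like this (illustrative only, not the helper's actual implementation):

    import time

    _last_sent: dict[str, float] = {}

    def rate_limited_alert(bucket: str, message: str,
                           window_seconds: float = 300, send=print):
        """Send at most one alert per bucket per window; drop the rest."""
        now = time.monotonic()
        if now - _last_sent.get(bucket, float("-inf")) < window_seconds:
            return False  # suppressed
        _last_sent[bucket] = now
        send(message)
        return True
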
@@ -1153,9 +1211,9 @@ class ExecutionProcessor:
|
||||
f"❌ **Insufficient Funds Alert**\n"
|
||||
f"User: {user_email or user_id}\n"
|
||||
f"Agent: {metadata.name if metadata else 'Unknown Agent'}\n"
|
||||
f"Current balance: ${e.balance/100:.2f}\n"
|
||||
f"Attempted cost: ${abs(e.amount)/100:.2f}\n"
|
||||
f"Shortfall: ${abs(shortfall)/100:.2f}\n"
|
||||
f"Current balance: ${e.balance / 100:.2f}\n"
|
||||
f"Attempted cost: ${abs(e.amount) / 100:.2f}\n"
|
||||
f"Shortfall: ${abs(shortfall) / 100:.2f}\n"
|
||||
f"[View User Details]({base_url}/admin/spending?search={user_email})"
|
||||
)
|
||||
|
||||
@@ -1202,9 +1260,9 @@ class ExecutionProcessor:
|
||||
alert_message = (
|
||||
f"⚠️ **Low Balance Alert**\n"
|
||||
f"User: {user_email or user_id}\n"
|
||||
f"Balance dropped below ${LOW_BALANCE_THRESHOLD/100:.2f}\n"
|
||||
f"Current balance: ${current_balance/100:.2f}\n"
|
||||
f"Transaction cost: ${transaction_cost/100:.2f}\n"
|
||||
f"Balance dropped below ${LOW_BALANCE_THRESHOLD / 100:.2f}\n"
|
||||
f"Current balance: ${current_balance / 100:.2f}\n"
|
||||
f"Transaction cost: ${transaction_cost / 100:.2f}\n"
|
||||
f"[View User Details]({base_url}/admin/spending?search={user_email})"
|
||||
)
|
||||
get_notification_manager_client().discord_system_alert(
|
||||
@@ -1219,6 +1277,7 @@ class ExecutionManager(AppProcess):
|
||||
super().__init__()
|
||||
self.pool_size = settings.config.num_graph_workers
|
||||
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
|
||||
self.executor_id = str(uuid.uuid4())
|
||||
|
||||
self._executor = None
|
||||
self._stop_consuming = None
|
||||
@@ -1228,6 +1287,8 @@ class ExecutionManager(AppProcess):
|
||||
self._run_thread = None
|
||||
self._run_client = None
|
||||
|
||||
self._execution_locks = {}
|
||||
|
||||
@property
|
||||
def cancel_thread(self) -> threading.Thread:
|
||||
if self._cancel_thread is None:
|
||||
@@ -1432,20 +1493,79 @@ class ExecutionManager(AppProcess):
             return

         graph_exec_id = graph_exec_entry.graph_exec_id
         user_id = graph_exec_entry.user_id
+        graph_id = graph_exec_entry.graph_id
         logger.info(
-            f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}"
+            f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}, user_id={user_id}"
         )
-        if graph_exec_id in self.active_graph_runs:
-            # TODO: Make this check cluster-wide, prevent duplicate runs across executor pods.
-            logger.error(
-                f"[{self.service_name}] Graph {graph_exec_id} already running; rejecting duplicate run."
-            )
-            _ack_message(reject=True, requeue=False)
-            return
+
+        # Check user rate limit before processing
+        try:
+            # Only check executions from the last 24 hours for performance
+            current_running_count = get_db_client().get_graph_executions_count(
+                user_id=user_id,
+                graph_id=graph_id,
+                statuses=[ExecutionStatus.RUNNING],
+                created_time_gte=datetime.now(timezone.utc) - timedelta(hours=24),
+            )
+
+            if (
+                current_running_count
+                >= settings.config.max_concurrent_graph_executions_per_user
+            ):
+                logger.warning(
+                    f"[{self.service_name}] Rate limit exceeded for user {user_id} on graph {graph_id}: "
+                    f"{current_running_count}/{settings.config.max_concurrent_graph_executions_per_user} running executions"
+                )
+                _ack_message(reject=True, requeue=True)
+                return
+
+        except Exception as e:
+            logger.error(
+                f"[{self.service_name}] Failed to check rate limit for user {user_id}: {e}, proceeding with execution"
+            )
+            # If rate limit check fails, proceed to avoid blocking executions
+
+        # Check for local duplicate execution first
+        if graph_exec_id in self.active_graph_runs:
+            logger.warning(
+                f"[{self.service_name}] Graph {graph_exec_id} already running locally; rejecting duplicate."
+            )
+            _ack_message(reject=True, requeue=True)
+            return
+
+        # Try to acquire cluster-wide execution lock
+        cluster_lock = ClusterLock(
+            redis=redis.get_redis(),
+            key=f"exec_lock:{graph_exec_id}",
+            owner_id=self.executor_id,
+            timeout=settings.config.cluster_lock_timeout,
+        )
+        current_owner = cluster_lock.try_acquire()
+        if current_owner != self.executor_id:
+            # Either someone else has it or Redis is unavailable
+            if current_owner is not None:
+                logger.warning(
+                    f"[{self.service_name}] Graph {graph_exec_id} already running on pod {current_owner}"
+                )
+                _ack_message(reject=True, requeue=False)
+            else:
+                logger.warning(
+                    f"[{self.service_name}] Could not acquire lock for {graph_exec_id} - Redis unavailable"
+                )
+                _ack_message(reject=True, requeue=True)
+            return
+        self._execution_locks[graph_exec_id] = cluster_lock
+
+        logger.info(
+            f"[{self.service_name}] Acquired cluster lock for {graph_exec_id} with executor {self.executor_id}"
+        )

         cancel_event = threading.Event()

-        future = self.executor.submit(execute_graph, graph_exec_entry, cancel_event)
+        future = self.executor.submit(
+            execute_graph, graph_exec_entry, cancel_event, cluster_lock
+        )
         self.active_graph_runs[graph_exec_id] = (future, cancel_event)
         self._update_prompt_metrics()
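The acquisition above implies a specific contract: try_acquire() returns this pod's executor_id on success, another pod's id when the execution is already claimed, and None when Redis itself is unreachable (hence the distinct requeue decisions). The ClusterLock implementation is not part of this diff; below is a minimal sketch of how such a lock is typically built on Redis SET NX PX. This is an assumption about the shape of the real class, not its actual code:

import redis

class ClusterLockSketch:
    """Cluster-wide lock sketch: one Redis key holds the current owner's id."""

    def __init__(self, client: redis.Redis, key: str, owner_id: str, timeout: int):
        self.client, self.key, self.owner_id = client, key, owner_id
        self.timeout_ms = timeout * 1000

    def try_acquire(self) -> str | None:
        try:
            # Loop covers the rare race where the holder expires between SET and GET.
            while True:
                # Atomically claim the key only if unheld (NX), with a TTL (PX).
                if self.client.set(self.key, self.owner_id, nx=True, px=self.timeout_ms):
                    return self.owner_id
                holder = self.client.get(self.key)
                if holder is not None:
                    return holder.decode()
        except redis.RedisError:
            return None  # Redis unavailable - caller requeues the message

    def refresh(self) -> None:
        # Re-arm the TTL, but only while we still own the key.
        if self.client.get(self.key) == self.owner_id.encode():
            self.client.pexpire(self.key, self.timeout_ms)

    def release(self) -> None:
        # Check-and-delete; production code would make this atomic with a Lua script.
        if self.client.get(self.key) == self.owner_id.encode():
            self.client.delete(self.key)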
@@ -1464,6 +1584,10 @@ class ExecutionManager(AppProcess):
                     f"[{self.service_name}] Error in run completion callback: {e}"
                 )
             finally:
+                # Release the cluster-wide execution lock
+                if graph_exec_id in self._execution_locks:
+                    self._execution_locks[graph_exec_id].release()
+                    del self._execution_locks[graph_exec_id]
                 self._cleanup_completed_runs()

         future.add_done_callback(_on_run_done)
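Releasing the lock in the future's done callback guarantees it happens exactly once per execution, whether the run finished, raised, or was cancelled. A standalone sketch of the pattern (the names here are illustrative, not from the codebase):

from concurrent.futures import Future
from typing import Any

locks: dict[str, Any] = {}  # exec_id -> lock, mirroring self._execution_locks

def make_on_done(exec_id: str):
    def _on_done(future: Future):
        try:
            future.result()  # re-raise anything the worker threw, for logging
        except Exception as e:
            print(f"run {exec_id} failed: {e}")
        finally:
            if (lock := locks.pop(exec_id, None)) is not None:
                lock.release()  # always return the cluster lock
    return _on_done

# usage: future.add_done_callback(make_on_done(graph_exec_id))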
@@ -1546,6 +1670,10 @@ class ExecutionManager(AppProcess):
                 f"{prefix} ⏳ Still waiting for {len(self.active_graph_runs)} executions: {ids}"
             )

+            for graph_exec_id in self.active_graph_runs:
+                if lock := self._execution_locks.get(graph_exec_id):
+                    lock.refresh()
+
             time.sleep(wait_interval)
             waited += wait_interval

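Refreshing each held lock inside the shutdown wait loop keeps its Redis TTL from lapsing while executions drain; without it, another pod could claim an execution that is still finishing here. A sketch of the heartbeat idea (wait_interval and the lock API as in the sketch above, both assumed):

import time

def wait_for_drain(active_runs: dict, locks: dict, wait_interval: float = 1.0):
    while active_runs:  # entries are removed by the done callbacks elsewhere
        for exec_id in list(active_runs):
            if (lock := locks.get(exec_id)) is not None:
                lock.refresh()  # re-arm the TTL so ownership survives the wait
        time.sleep(wait_interval)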
@@ -1563,6 +1691,15 @@ class ExecutionManager(AppProcess):
         except Exception as e:
             logger.error(f"{prefix} ⚠️ Error during executor shutdown: {type(e)} {e}")

+        # Release remaining execution locks
+        try:
+            for lock in self._execution_locks.values():
+                lock.release()
+            self._execution_locks.clear()
+            logger.info(f"{prefix} ✅ Released execution locks")
+        except Exception as e:
+            logger.warning(f"{prefix} ⚠️ Failed to release all locks: {e}")
+
         # Disconnect the run execution consumer
         self._stop_message_consumers(
             self.run_thread,
@@ -1577,6 +1714,8 @@ class ExecutionManager(AppProcess):

         logger.info(f"{prefix} ✅ Finished GraphExec cleanup")

+        super().cleanup()
+

 # ------- UTILITIES ------- #

@@ -1668,15 +1807,18 @@ def update_graph_execution_state(


 @asynccontextmanager
-async def synchronized(key: str, timeout: int = 60):
+async def synchronized(key: str, timeout: int = settings.config.cluster_lock_timeout):
     r = await redis.get_redis_async()
-    lock: RedisLock = r.lock(f"lock:{key}", timeout=timeout)
+    lock: AsyncRedisLock = r.lock(f"lock:{key}", timeout=timeout)
     try:
         await lock.acquire()
         yield
     finally:
         if await lock.locked() and await lock.owned():
-            await lock.release()
+            try:
+                await lock.release()
+            except Exception as e:
+                logger.warning(f"Failed to release lock for key {key}: {e}")


 def increment_execution_count(user_id: str) -> int:

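The guarded release matters because the lock's TTL can expire between the locked()/owned() check and the release() call, raising an error in an otherwise healthy path; wrapping the release in try/except downgrades that race to a warning. A usage sketch of the context manager (the key and the two balance helpers are hypothetical, not from this diff):

async def add_credits(user_id: str, amount: int):
    # Serialize balance updates for one user across all pods sharing this Redis.
    async with synchronized(f"credits:{user_id}"):
        balance = await get_balance(user_id)          # hypothetical helper
        await set_balance(user_id, balance + amount)  # hypothetical helper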
@@ -191,15 +191,22 @@ class GraphExecutionJobInfo(GraphExecutionJobArgs):
     id: str
     name: str
     next_run_time: str
+    timezone: str = Field(default="UTC", description="Timezone used for scheduling")

     @staticmethod
     def from_db(
         job_args: GraphExecutionJobArgs, job_obj: JobObj
     ) -> "GraphExecutionJobInfo":
+        # Extract timezone from the trigger if it's a CronTrigger
+        timezone_str = "UTC"
+        if hasattr(job_obj.trigger, "timezone"):
+            timezone_str = str(job_obj.trigger.timezone)
+
         return GraphExecutionJobInfo(
             id=job_obj.id,
             name=job_obj.name,
             next_run_time=job_obj.next_run_time.isoformat(),
+            timezone=timezone_str,
             **job_args.model_dump(),
         )

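The hasattr guard keeps from_db working for trigger types that carry no timezone (e.g. date or interval triggers) while reading the zone straight off CronTrigger instances. A quick check of what the extraction yields, assuming APScheduler is installed:

from apscheduler.triggers.cron import CronTrigger

trigger = CronTrigger.from_crontab("0 9 * * *", timezone="America/New_York")
print(str(trigger.timezone))  # should print: America/New_York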
@@ -241,7 +248,7 @@ class Scheduler(AppService):
             raise UnhealthyServiceError("Scheduler is still initializing")

         # Check if we're in the middle of cleanup
-        if self.cleaned_up:
+        if self._shutting_down:
             return await super().health_check()

         # Normal operation - check if scheduler is running
@@ -368,7 +375,6 @@ class Scheduler(AppService):
         super().run_service()

     def cleanup(self):
-        super().cleanup()
         if self.scheduler:
             logger.info("⏳ Shutting down scheduler...")
             self.scheduler.shutdown(wait=True)
@@ -383,7 +389,7 @@ class Scheduler(AppService):
             logger.info("⏳ Waiting for event loop thread to finish...")
             _event_loop_thread.join(timeout=SCHEDULER_OPERATION_TIMEOUT_SECONDS)

-        logger.info("Scheduler cleanup complete.")
+        super().cleanup()

     @expose
     def add_graph_execution_schedule(
@@ -395,6 +401,7 @@ class Scheduler(AppService):
         input_data: BlockInput,
         input_credentials: dict[str, CredentialsMetaInput],
         name: Optional[str] = None,
+        user_timezone: str | None = None,
     ) -> GraphExecutionJobInfo:
         # Validate the graph before scheduling to prevent runtime failures
         # We don't need the return value, just want the validation to run
@@ -408,7 +415,18 @@ class Scheduler(AppService):
             )
         )

-        logger.info(f"Scheduling job for user {user_id} in UTC (cron: {cron})")
+        # Use provided timezone or default to UTC
+        # Note: Timezone should be passed from the client to avoid database lookups
+        if not user_timezone:
+            user_timezone = "UTC"
+            logger.warning(
+                f"No timezone provided for user {user_id}, using UTC for scheduling. "
+                f"Client should pass user's timezone for correct scheduling."
+            )
+
+        logger.info(
+            f"Scheduling job for user {user_id} with timezone {user_timezone} (cron: {cron})"
+        )

         job_args = GraphExecutionJobArgs(
             user_id=user_id,
@@ -422,12 +440,12 @@ class Scheduler(AppService):
             execute_graph,
             kwargs=job_args.model_dump(),
             name=name,
-            trigger=CronTrigger.from_crontab(cron, timezone="UTC"),
+            trigger=CronTrigger.from_crontab(cron, timezone=user_timezone),
             jobstore=Jobstores.EXECUTION.value,
             replace_existing=True,
         )
         logger.info(
-            f"Added job {job.id} with cron schedule '{cron}' in UTC, input data: {input_data}"
+            f"Added job {job.id} with cron schedule '{cron}' in timezone {user_timezone}, input data: {input_data}"
         )
         return GraphExecutionJobInfo.from_db(job_args, job)
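The practical effect of threading user_timezone through to CronTrigger: the same crontab now fires at the user's local wall-clock time, daylight-saving shifts included, instead of at a fixed UTC instant. A small illustration, assuming APScheduler is installed:

from apscheduler.triggers.cron import CronTrigger

# "Every day at 09:00" is interpreted in the trigger's timezone:
utc_trigger = CronTrigger.from_crontab("0 9 * * *", timezone="UTC")
berlin_trigger = CronTrigger.from_crontab("0 9 * * *", timezone="Europe/Berlin")
# The Berlin trigger fires at 08:00 UTC in winter and 07:00 UTC in summer,
# tracking local 09:00 - the behavior users expect from a daily schedule.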