Mirror of https://github.com/Significant-Gravitas/AutoGPT.git
Synced 2026-01-14 09:38:00 -05:00

Compare commits: gmail-repl ... fixes-to-t (3 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 9b71c03d96 | |
| | b20b00a441 | |
| | 84810ce0af | |
.github/copilot-instructions.md (vendored, 244 lines changed)

@@ -1,244 +0,0 @@
# GitHub Copilot Instructions for AutoGPT

This file provides comprehensive onboarding information for the GitHub Copilot coding agent to work efficiently with the AutoGPT repository.

## Repository Overview

**AutoGPT** is a powerful platform for creating, deploying, and managing continuous AI agents that automate complex workflows. This is a large monorepo (~150MB) containing multiple components:

- **AutoGPT Platform** (`autogpt_platform/`) - Main focus: Modern AI agent platform (Polyform Shield License)
- **Classic AutoGPT** (`classic/`) - Legacy agent system (MIT License)
- **Documentation** (`docs/`) - MkDocs-based documentation site
- **Infrastructure** - Docker configurations, CI/CD, and development tools

**Primary Languages & Frameworks:**

- **Backend**: Python 3.10-3.13, FastAPI, Prisma ORM, PostgreSQL, RabbitMQ
- **Frontend**: TypeScript, Next.js 15, React, Tailwind CSS, Radix UI
- **Development**: Docker, Poetry, pnpm, Playwright, Storybook

## Build and Validation Instructions

### Essential Setup Commands

**Always run these commands in the correct directory and in this order:**

1. **Initial Setup** (required once):

   ```bash
   # Clone and enter repository
   git clone <repo> && cd AutoGPT

   # Start all services (database, redis, rabbitmq, clamav)
   cd autogpt_platform && docker compose --profile local up deps --build --detach
   ```

2. **Backend Setup** (always run before backend development):

   ```bash
   cd autogpt_platform/backend
   poetry install                   # Install dependencies
   poetry run prisma migrate dev    # Run database migrations
   poetry run prisma generate       # Generate Prisma client
   ```

3. **Frontend Setup** (always run before frontend development):

   ```bash
   cd autogpt_platform/frontend
   pnpm install                     # Install dependencies
   ```

### Runtime Requirements

**Critical:** Always ensure Docker services are running before starting development:

```bash
cd autogpt_platform && docker compose --profile local up deps --build --detach
```

**Python Version:** Use Python 3.11 (required; managed by Poetry via pyproject.toml)
**Node.js Version:** Use Node.js 21+ with the pnpm package manager

### Development Commands

**Backend Development:**

```bash
cd autogpt_platform/backend
poetry run serve                    # Start development server (port 8000)
poetry run test                     # Run all tests (requires ~5 minutes)
poetry run pytest path/to/test.py   # Run a specific test
poetry run format                   # Format code (Black + isort) - always run first
poetry run lint                     # Lint code (ruff) - run after format
```

**Frontend Development:**

```bash
cd autogpt_platform/frontend
pnpm dev           # Start development server (port 3000) - use for active development
pnpm build         # Build for production (only needed for E2E tests or deployment)
pnpm test          # Run Playwright E2E tests (requires build first)
pnpm test-ui       # Run tests with UI
pnpm format        # Format and lint code
pnpm storybook     # Start component development server
```

### Testing Strategy

**Backend Tests:**

- **Block Tests**: `poetry run pytest backend/blocks/test/test_block.py -xvs` (validates all blocks)
- **Specific Block**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[BlockName]' -xvs`
- **Snapshot Tests**: Use `--snapshot-update` when output changes; always review the result with `git diff`

**Frontend Tests:**

- **E2E Tests**: Always run `pnpm dev` before `pnpm test` (Playwright requires a running instance)
- **Component Tests**: Use Storybook for isolated component development

### Critical Validation Steps

**Before committing changes:**

1. Run `poetry run format` (backend) and `pnpm format` (frontend)
2. Ensure all tests pass in modified areas
3. Verify Docker services are still running
4. Check that database migrations apply cleanly

**Common Issues & Workarounds:**

- **Prisma issues**: Run `poetry run prisma generate` after schema changes
- **Permission errors**: Ensure Docker has proper permissions
- **Port conflicts**: Check the `docker-compose.yml` file for the current list of exposed ports. You can list all mapped ports with `docker compose ps`.
- **Test timeouts**: Backend tests can take 5+ minutes; use the `-x` flag to stop on the first failure

## Project Layout & Architecture

### Core Architecture

**AutoGPT Platform** (`autogpt_platform/`):

- `backend/` - FastAPI server with async support
  - `backend/backend/` - Core API logic
  - `backend/blocks/` - Agent execution blocks
  - `backend/data/` - Database models and schemas
  - `schema.prisma` - Database schema definition
- `frontend/` - Next.js application
  - `src/app/` - App Router pages and layouts
  - `src/components/` - Reusable React components
  - `src/lib/` - Utilities and configurations
- `autogpt_libs/` - Shared Python utilities
- `docker-compose.yml` - Development stack orchestration

**Key Configuration Files:**

- `pyproject.toml` - Python dependencies and tooling
- `package.json` - Node.js dependencies and scripts
- `schema.prisma` - Database schema and migrations
- `next.config.mjs` - Next.js configuration
- `tailwind.config.ts` - Styling configuration

### Security & Middleware

**Cache Protection**: The backend includes middleware preventing sensitive data from being cached in browsers/proxies
**Authentication**: JWT-based with Supabase integration
**User ID Validation**: All data access requires user ID checks - verify this for any `data/*.py` changes

### Development Workflow

**GitHub Actions**: Multiple CI/CD workflows in `.github/workflows/`

- `platform-backend-ci.yml` - Backend testing and validation
- `platform-frontend-ci.yml` - Frontend testing and validation
- `platform-fullstack-ci.yml` - End-to-end integration tests

**Pre-commit Hooks**: Run linting and formatting checks
**Conventional Commits**: Use the format `type(scope): description` (e.g., `feat(backend): add API`)

### Key Source Files

**Backend Entry Points:**

- `backend/backend/server/server.py` - FastAPI application setup
- `backend/backend/data/` - Database models and user management
- `backend/blocks/` - Agent execution blocks and logic

**Frontend Entry Points:**

- `frontend/src/app/layout.tsx` - Root application layout
- `frontend/src/app/page.tsx` - Home page
- `frontend/src/lib/supabase/` - Authentication and database client

**Protected Routes**: Update `frontend/lib/supabase/middleware.ts` when adding protected routes

### Agent Block System

Agents are built using a visual block-based system where each block performs a single action. Blocks are defined in `backend/blocks/` and must include:

- Block definition with input/output schemas
- Execution logic with proper error handling
- Tests validating functionality

### Database & ORM

**Prisma ORM** with a PostgreSQL backend, including pgvector for embeddings:

- Schema in `schema.prisma`
- Migrations in `backend/migrations/`
- Always run `prisma migrate dev` and `prisma generate` after schema changes
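After `prisma generate`, backend code talks to PostgreSQL through the generated Python client. Below is a minimal usage sketch; the `agentgraph` model and its field names are illustrative assumptions, not taken from the actual `schema.prisma`, so check the schema for the real models.

```python
# Minimal sketch of using the generated Prisma client (prisma-client-py).
# Model and field names ("agentgraph", "userId", "createdAt") are assumptions
# for illustration only; consult schema.prisma for the real definitions.
import asyncio

from prisma import Prisma


async def main() -> None:
    db = Prisma()
    await db.connect()
    try:
        # Hypothetical query: fetch graphs owned by one user, newest first.
        graphs = await db.agentgraph.find_many(
            where={"userId": "user-123"},
            order={"createdAt": "desc"},
            take=10,
        )
        print(f"Fetched {len(graphs)} graphs")
    finally:
        await db.disconnect()


if __name__ == "__main__":
    asyncio.run(main())
```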

## Environment Configuration

### Configuration Files Priority Order

1. **Backend**: `/backend/.env.default` → `/backend/.env` (user overrides)
2. **Frontend**: `/frontend/.env.default` → `/frontend/.env` (user overrides)
3. **Platform**: `/.env.default` (Supabase/shared) → `/.env` (user overrides)
4. Docker Compose `environment:` sections override file-based config
5. Shell environment variables have highest precedence
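As a mental model, the layering behaves like repeated dictionary updates where later sources win. The sketch below is illustrative only and does not reproduce the platform's actual config loader.

```python
# Illustrative sketch of the precedence order described above (not the real loader).
import os
from pathlib import Path


def read_env_file(path: str) -> dict[str, str]:
    """Parse a simple KEY=VALUE file, ignoring blanks and comments."""
    values: dict[str, str] = {}
    file = Path(path)
    if not file.exists():
        return values
    for line in file.read_text().splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        values[key.strip()] = value.strip()
    return values


def effective_backend_config() -> dict[str, str]:
    config: dict[str, str] = {}
    # 1. Checked-in defaults, then user overrides (later update() calls win).
    config.update(read_env_file("backend/.env.default"))
    config.update(read_env_file("backend/.env"))
    # 2. Docker Compose `environment:` entries and shell variables arrive via
    #    the process environment, which has the highest precedence.
    config.update(os.environ)
    return config
```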

### Docker Environment Setup

- All services use hardcoded defaults (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Copy `.env.default` files to `.env` for local development customization

## Advanced Development Patterns

### Adding New Blocks

1. Create a file in `/backend/backend/blocks/`
2. Inherit from the `Block` base class with input/output schemas
3. Implement the `run` method with proper error handling
4. Generate the block UUID using `uuid.uuid4()`
5. Register in the block registry
6. Write tests alongside the block implementation
7. Consider how inputs/outputs connect with other blocks in the graph editor (see the sketch after this list)
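A minimal sketch of the shape such a block takes, modeled on the SDK-style blocks that appear in the deleted Meeting BaaS files further down this diff. The block name, category, UUID, and logic are placeholders rather than an existing block in the repository.

```python
# Minimal sketch of a new block, following the pattern visible elsewhere in
# this diff. Names, the UUID, the category, and the logic are placeholders.
from backend.sdk import Block, BlockCategory, BlockOutput, BlockSchema, SchemaField


class WordCountBlock(Block):
    """Count the words in a piece of text."""

    class Input(BlockSchema):
        text: str = SchemaField(description="Text to analyze")

    class Output(BlockSchema):
        word_count: int = SchemaField(description="Number of words in the text")

    def __init__(self):
        super().__init__(
            # Generate once with uuid.uuid4() and keep it stable afterwards.
            id="00000000-0000-4000-8000-000000000000",
            description="Count the words in the provided text",
            categories={BlockCategory.TEXT},  # assumed category for illustration
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # Each yield is an (output_name, value) pair matching the Output schema.
        yield "word_count", len(input_data.text.split())
```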

### API Development

1. Update routes in `/backend/backend/server/routers/` (see the sketch after this list)
2. Add/update Pydantic models in the same directory
3. Write tests alongside the route files
4. For `data/*.py` changes, validate user ID checks
5. Run `poetry run test` to verify changes
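A minimal sketch of what a new route with a user ID check might look like, assuming FastAPI's `APIRouter`. The auth dependency, response model, and data-layer helper are illustrative assumptions, not the repository's actual code.

```python
# Illustrative sketch of a new route with a user ID check. The auth dependency,
# response model, and data-layer helper are assumptions for illustration only.
from fastapi import APIRouter, Depends
from pydantic import BaseModel

router = APIRouter(prefix="/api/example", tags=["example"])


class ExampleItem(BaseModel):
    id: str
    name: str


async def get_user_id() -> str:
    """Placeholder for the platform's real JWT-based auth dependency."""
    raise NotImplementedError


async def fetch_items_for_user(user_id: str) -> list[ExampleItem]:
    # Stand-in for a query in backend/data/*.py that is filtered by user_id.
    return []


@router.get("/items", response_model=list[ExampleItem])
async def list_items(user_id: str = Depends(get_user_id)) -> list[ExampleItem]:
    # Every data access is scoped to the authenticated user's ID,
    # mirroring the "User ID Validation" rule above.
    return await fetch_items_for_user(user_id)
```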

### Frontend Development

1. Components live in `/frontend/src/components/`
2. Use existing UI components from `/frontend/src/components/ui/`
3. Add Storybook stories for component development
4. Test user-facing features with Playwright E2E tests
5. Update protected routes in the middleware when needed

### Security Guidelines

**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`):

- Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow-list approach for cacheable paths (static assets, health checks, public pages)
- Prevents sensitive data caching in browsers/proxies
- Add new cacheable endpoints to `CACHEABLE_PATHS` (a sketch of the approach follows this list)
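A sketch of how such an allow-list cache-control middleware can be structured in the Starlette/FastAPI style. This illustrates the approach described above; it is not the contents of `security.py`, and the `CACHEABLE_PATHS` entries here are examples only.

```python
# Sketch of an allow-list cache-control middleware (Starlette/FastAPI style).
# Not the repository's actual security.py; example paths are placeholders.
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response

CACHEABLE_PATHS = {"/health", "/static", "/favicon.ico"}  # example entries only


class CacheControlMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next) -> Response:
        response = await call_next(request)
        path = request.url.path
        if not any(path == p or path.startswith(p + "/") for p in CACHEABLE_PATHS):
            # Default: never cache, so sensitive API responses stay out of
            # browser and proxy caches.
            response.headers["Cache-Control"] = (
                "no-store, no-cache, must-revalidate, private"
            )
        return response
```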

### CI/CD Alignment

The repository has comprehensive CI workflows that test:

- **Backend**: Python 3.11-3.13, services (Redis/RabbitMQ/ClamAV), Prisma migrations, Poetry lock validation
- **Frontend**: Node.js 21, pnpm, Playwright with Docker Compose stack, API schema validation
- **Integration**: Full-stack type checking and E2E testing

Match these patterns when developing locally - the copilot setup environment mirrors these CI configurations.

## Collaboration with Other AI Assistants

This repository is actively developed with assistance from Claude (via CLAUDE.md files). When working on this codebase:

- Check for existing CLAUDE.md files that provide additional context
- Follow established patterns and conventions already in the codebase
- Maintain consistency with existing code style and architecture
- Consider that changes may be reviewed and extended by both human developers and AI assistants

## Trust These Instructions

These instructions are comprehensive and tested. Only perform additional searches if:

1. Information here is incomplete for your specific task
2. You encounter errors not covered by the workarounds
3. You need to understand implementation details not covered above

For detailed platform development patterns, refer to `autogpt_platform/CLAUDE.md` and `AGENTS.md` in the repository root.
.github/workflows/copilot-setup-steps.yml (vendored, 302 lines changed)

@@ -1,302 +0,0 @@
name: "Copilot Setup Steps"
|
|
||||||
|
|
||||||
# Automatically run the setup steps when they are changed to allow for easy validation, and
|
|
||||||
# allow manual testing through the repository's "Actions" tab
|
|
||||||
on:
|
|
||||||
workflow_dispatch:
|
|
||||||
push:
|
|
||||||
paths:
|
|
||||||
- .github/workflows/copilot-setup-steps.yml
|
|
||||||
pull_request:
|
|
||||||
paths:
|
|
||||||
- .github/workflows/copilot-setup-steps.yml
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
# The job MUST be called `copilot-setup-steps` or it will not be picked up by Copilot.
|
|
||||||
copilot-setup-steps:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
timeout-minutes: 45
|
|
||||||
|
|
||||||
# Set the permissions to the lowest permissions possible needed for your steps.
|
|
||||||
# Copilot will be given its own token for its operations.
|
|
||||||
permissions:
|
|
||||||
# If you want to clone the repository as part of your setup steps, for example to install dependencies, you'll need the `contents: read` permission. If you don't clone the repository in your setup steps, Copilot will do this for you automatically after the steps complete.
|
|
||||||
contents: read
|
|
||||||
|
|
||||||
# You can define any steps you want, and they will run before the agent starts.
|
|
||||||
# If you do not check out your code, Copilot will do this for you.
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
fetch-depth: 0
|
|
||||||
submodules: true
|
|
||||||
|
|
||||||
# Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
|
|
||||||
- name: Set up Python
|
|
||||||
uses: actions/setup-python@v5
|
|
||||||
with:
|
|
||||||
python-version: "3.11" # Use standard version matching CI
|
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
|
||||||
uses: actions/cache@v4
|
|
||||||
with:
|
|
||||||
path: ~/.cache/pypoetry
|
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
|
||||||
|
|
||||||
- name: Install Poetry
|
|
||||||
run: |
|
|
||||||
# Extract Poetry version from backend/poetry.lock (matches CI)
|
|
||||||
cd autogpt_platform/backend
|
|
||||||
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
|
||||||
echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"
|
|
||||||
|
|
||||||
# Install Poetry
|
|
||||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
|
||||||
|
|
||||||
# Add Poetry to PATH
|
|
||||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
|
||||||
|
|
||||||
- name: Check poetry.lock
|
|
||||||
working-directory: autogpt_platform/backend
|
|
||||||
run: |
|
|
||||||
poetry lock
|
|
||||||
if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
|
|
||||||
echo "Warning: poetry.lock not up to date, but continuing for setup"
|
|
||||||
git checkout poetry.lock # Reset for clean setup
|
|
||||||
fi
|
|
||||||
|
|
||||||
- name: Install Python dependencies
|
|
||||||
working-directory: autogpt_platform/backend
|
|
||||||
run: poetry install
|
|
||||||
|
|
||||||
- name: Generate Prisma Client
|
|
||||||
working-directory: autogpt_platform/backend
|
|
||||||
run: poetry run prisma generate
|
|
||||||
|
|
||||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
|
||||||
- name: Set up Node.js
|
|
||||||
uses: actions/setup-node@v4
|
|
||||||
with:
|
|
||||||
node-version: "21"
|
|
||||||
|
|
||||||
- name: Enable corepack
|
|
||||||
run: corepack enable
|
|
||||||
|
|
||||||
- name: Set pnpm store directory
|
|
||||||
run: |
|
|
||||||
pnpm config set store-dir ~/.pnpm-store
|
|
||||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
|
||||||
|
|
||||||
- name: Cache frontend dependencies
|
|
||||||
uses: actions/cache@v4
|
|
||||||
with:
|
|
||||||
path: ~/.pnpm-store
|
|
||||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
|
||||||
${{ runner.os }}-pnpm-
|
|
||||||
|
|
||||||
- name: Install JavaScript dependencies
|
|
||||||
working-directory: autogpt_platform/frontend
|
|
||||||
run: pnpm install --frozen-lockfile
|
|
||||||
|
|
||||||
# Install Playwright browsers for frontend testing
|
|
||||||
# NOTE: Disabled to save ~1 minute of setup time. Re-enable if Copilot needs browser automation (e.g., for MCP)
|
|
||||||
# - name: Install Playwright browsers
|
|
||||||
# working-directory: autogpt_platform/frontend
|
|
||||||
# run: pnpm playwright install --with-deps chromium
|
|
||||||
|
|
||||||
# Docker setup for development environment
|
|
||||||
- name: Set up Docker Buildx
|
|
||||||
uses: docker/setup-buildx-action@v3
|
|
||||||
|
|
||||||
- name: Copy default environment files
|
|
||||||
working-directory: autogpt_platform
|
|
||||||
run: |
|
|
||||||
# Copy default environment files for development
|
|
||||||
cp .env.default .env
|
|
||||||
cp backend/.env.default backend/.env
|
|
||||||
cp frontend/.env.default frontend/.env
|
|
||||||
|
|
||||||
# Phase 1: Cache and load Docker images for faster setup
|
|
||||||
- name: Set up Docker image cache
|
|
||||||
id: docker-cache
|
|
||||||
uses: actions/cache@v4
|
|
||||||
with:
|
|
||||||
path: ~/docker-cache
|
|
||||||
# Use a versioned key for cache invalidation when image list changes
|
|
||||||
key: docker-images-v2-${{ runner.os }}-${{ hashFiles('.github/workflows/copilot-setup-steps.yml') }}
|
|
||||||
restore-keys: |
|
|
||||||
docker-images-v2-${{ runner.os }}-
|
|
||||||
docker-images-v1-${{ runner.os }}-
|
|
||||||
|
|
||||||
- name: Load or pull Docker images
|
|
||||||
working-directory: autogpt_platform
|
|
||||||
run: |
|
|
||||||
mkdir -p ~/docker-cache
|
|
||||||
|
|
||||||
# Define image list for easy maintenance
|
|
||||||
IMAGES=(
|
|
||||||
"redis:latest"
|
|
||||||
"rabbitmq:management"
|
|
||||||
"clamav/clamav-debian:latest"
|
|
||||||
"busybox:latest"
|
|
||||||
"kong:2.8.1"
|
|
||||||
"supabase/gotrue:v2.170.0"
|
|
||||||
"supabase/postgres:15.8.1.049"
|
|
||||||
"supabase/postgres-meta:v0.86.1"
|
|
||||||
"supabase/studio:20250224-d10db0f"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if any cached tar files exist (more reliable than cache-hit)
|
|
||||||
if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
|
|
||||||
echo "Docker cache found, loading images in parallel..."
|
|
||||||
for image in "${IMAGES[@]}"; do
|
|
||||||
# Convert image name to filename (replace : and / with -)
|
|
||||||
filename=$(echo "$image" | tr ':/' '--')
|
|
||||||
if [ -f ~/docker-cache/${filename}.tar ]; then
|
|
||||||
echo "Loading $image..."
|
|
||||||
docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
wait
|
|
||||||
echo "All cached images loaded"
|
|
||||||
else
|
|
||||||
echo "No Docker cache found, pulling images in parallel..."
|
|
||||||
# Pull all images in parallel
|
|
||||||
for image in "${IMAGES[@]}"; do
|
|
||||||
docker pull "$image" &
|
|
||||||
done
|
|
||||||
wait
|
|
||||||
|
|
||||||
# Only save cache on main branches (not PRs) to avoid cache pollution
|
|
||||||
if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
|
|
||||||
echo "Saving Docker images to cache in parallel..."
|
|
||||||
for image in "${IMAGES[@]}"; do
|
|
||||||
# Convert image name to filename (replace : and / with -)
|
|
||||||
filename=$(echo "$image" | tr ':/' '--')
|
|
||||||
echo "Saving $image..."
|
|
||||||
docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
|
|
||||||
done
|
|
||||||
wait
|
|
||||||
echo "Docker image cache saved"
|
|
||||||
else
|
|
||||||
echo "Skipping cache save for PR/feature branch"
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo "Docker images ready for use"
|
|
||||||
|
|
||||||
# Phase 2: Build migrate service with GitHub Actions cache
|
|
||||||
- name: Build migrate Docker image with cache
|
|
||||||
working-directory: autogpt_platform
|
|
||||||
run: |
|
|
||||||
# Build the migrate image with buildx for GHA caching
|
|
||||||
docker buildx build \
|
|
||||||
--cache-from type=gha \
|
|
||||||
--cache-to type=gha,mode=max \
|
|
||||||
--target migrate \
|
|
||||||
--tag autogpt_platform-migrate:latest \
|
|
||||||
--load \
|
|
||||||
-f backend/Dockerfile \
|
|
||||||
..
|
|
||||||
|
|
||||||
# Start services using pre-built images
|
|
||||||
- name: Start Docker services for development
|
|
||||||
working-directory: autogpt_platform
|
|
||||||
run: |
|
|
||||||
# Start essential services (migrate image already built with correct tag)
|
|
||||||
docker compose --profile local up deps --no-build --detach
|
|
||||||
echo "Waiting for services to be ready..."
|
|
||||||
|
|
||||||
# Wait for database to be ready
|
|
||||||
echo "Checking database readiness..."
|
|
||||||
timeout 30 sh -c 'until docker compose exec -T db pg_isready -U postgres 2>/dev/null; do
|
|
||||||
echo " Waiting for database..."
|
|
||||||
sleep 2
|
|
||||||
done' && echo "✅ Database is ready" || echo "⚠️ Database ready check timeout after 30s, continuing..."
|
|
||||||
|
|
||||||
# Check migrate service status
|
|
||||||
echo "Checking migration status..."
|
|
||||||
docker compose ps migrate || echo " Migrate service not visible in ps output"
|
|
||||||
|
|
||||||
# Wait for migrate service to complete
|
|
||||||
echo "Waiting for migrations to complete..."
|
|
||||||
timeout 30 bash -c '
|
|
||||||
ATTEMPTS=0
|
|
||||||
while [ $ATTEMPTS -lt 15 ]; do
|
|
||||||
ATTEMPTS=$((ATTEMPTS + 1))
|
|
||||||
|
|
||||||
# Check using docker directly (more reliable than docker compose ps)
|
|
||||||
CONTAINER_STATUS=$(docker ps -a --filter "label=com.docker.compose.service=migrate" --format "{{.Status}}" | head -1)
|
|
||||||
|
|
||||||
if [ -z "$CONTAINER_STATUS" ]; then
|
|
||||||
echo " Attempt $ATTEMPTS: Migrate container not found yet..."
|
|
||||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited (0)"; then
|
|
||||||
echo "✅ Migrations completed successfully"
|
|
||||||
docker compose logs migrate --tail=5 2>/dev/null || true
|
|
||||||
exit 0
|
|
||||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited ([1-9]"; then
|
|
||||||
EXIT_CODE=$(echo "$CONTAINER_STATUS" | grep -oE "Exited \([0-9]+\)" | grep -oE "[0-9]+")
|
|
||||||
echo "❌ Migrations failed with exit code: $EXIT_CODE"
|
|
||||||
echo "Migration logs:"
|
|
||||||
docker compose logs migrate --tail=20 2>/dev/null || true
|
|
||||||
exit 1
|
|
||||||
elif echo "$CONTAINER_STATUS" | grep -q "Up"; then
|
|
||||||
echo " Attempt $ATTEMPTS: Migrate container is running... ($CONTAINER_STATUS)"
|
|
||||||
else
|
|
||||||
echo " Attempt $ATTEMPTS: Migrate container status: $CONTAINER_STATUS"
|
|
||||||
fi
|
|
||||||
|
|
||||||
sleep 2
|
|
||||||
done
|
|
||||||
|
|
||||||
echo "⚠️ Timeout: Could not determine migration status after 30 seconds"
|
|
||||||
echo "Final container check:"
|
|
||||||
docker ps -a --filter "label=com.docker.compose.service=migrate" || true
|
|
||||||
echo "Migration logs (if available):"
|
|
||||||
docker compose logs migrate --tail=10 2>/dev/null || echo " No logs available"
|
|
||||||
' || echo "⚠️ Migration check completed with warnings, continuing..."
|
|
||||||
|
|
||||||
# Brief wait for other services to stabilize
|
|
||||||
echo "Waiting 5 seconds for other services to stabilize..."
|
|
||||||
sleep 5
|
|
||||||
|
|
||||||
# Verify installations and provide environment info
|
|
||||||
- name: Verify setup and show environment info
|
|
||||||
run: |
|
|
||||||
echo "=== Python Setup ==="
|
|
||||||
python --version
|
|
||||||
poetry --version
|
|
||||||
|
|
||||||
echo "=== Node.js Setup ==="
|
|
||||||
node --version
|
|
||||||
pnpm --version
|
|
||||||
|
|
||||||
echo "=== Additional Tools ==="
|
|
||||||
docker --version
|
|
||||||
docker compose version
|
|
||||||
gh --version || true
|
|
||||||
|
|
||||||
echo "=== Services Status ==="
|
|
||||||
cd autogpt_platform
|
|
||||||
docker compose ps || true
|
|
||||||
|
|
||||||
echo "=== Backend Dependencies ==="
|
|
||||||
cd backend
|
|
||||||
poetry show | head -10 || true
|
|
||||||
|
|
||||||
echo "=== Frontend Dependencies ==="
|
|
||||||
cd ../frontend
|
|
||||||
pnpm list --depth=0 | head -10 || true
|
|
||||||
|
|
||||||
echo "=== Environment Files ==="
|
|
||||||
ls -la ../.env* || true
|
|
||||||
ls -la .env* || true
|
|
||||||
ls -la ../backend/.env* || true
|
|
||||||
|
|
||||||
echo "✅ AutoGPT Platform development environment setup complete!"
|
|
||||||
echo "🚀 Ready for development with Docker services running"
|
|
||||||
echo "📝 Backend server: poetry run serve (port 8000)"
|
|
||||||
echo "🌐 Frontend server: pnpm dev (port 3000)"
|
|
||||||

.github/workflows/platform-backend-ci.yml (vendored, 2 lines changed)
@@ -32,7 +32,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.11", "3.12", "3.13"]
+        python-version: ["3.11"]
     runs-on: ubuntu-latest

     services:

@@ -156,7 +156,7 @@ Key models (defined in `/backend/schema.prisma`):
 5. Register in block registry
 6. Generate the block uuid using `uuid.uuid4()`

-Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
+Note: when making many new blocks analyze the interfaces for each of these blcoks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
 ex: do the inputs and outputs tie well together?

 **Modifying the API:**
@@ -10,8 +10,8 @@ from starlette.status import HTTP_401_UNAUTHORIZED
 from .config import settings
 from .jwt_utils import parse_jwt_token

+security = HTTPBearer()
 logger = logging.getLogger(__name__)
-bearer_auth = HTTPBearer(auto_error=False)


 async def auth_middleware(request: Request):
@@ -20,10 +20,11 @@ async def auth_middleware(request: Request):
         logger.warning("Auth disabled")
         return {}

-    credentials = await bearer_auth(request)
+    security = HTTPBearer()
+    credentials = await security(request)

     if not credentials:
-        raise HTTPException(status_code=401, detail="Not authenticated")
+        raise HTTPException(status_code=401, detail="Authorization header is missing")

     try:
         payload = parse_jwt_token(credentials.credentials)
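For readers of this hunk: the two sides differ in where the `HTTPBearer` dependency is constructed and how a missing Authorization header is reported. Below is a condensed sketch of the pattern on the `-` side, assuming FastAPI's `HTTPBearer` with `auto_error=False` so the middleware can raise its own 401; it is an illustration of the pattern, not the full file, and the JWT parsing is stubbed.

```python
# Condensed sketch of the auto_error=False pattern shown on the "-" side of the
# hunk above. Illustrative only: settings handling and JWT parsing are stubbed.
import logging

from fastapi import HTTPException, Request
from fastapi.security import HTTPBearer

logger = logging.getLogger(__name__)
bearer_auth = HTTPBearer(auto_error=False)  # returns None instead of auto-raising


async def auth_middleware(request: Request) -> dict:
    credentials = await bearer_auth(request)
    if not credentials:
        # With auto_error=False, the middleware controls the error response itself.
        raise HTTPException(status_code=401, detail="Not authenticated")
    return parse_jwt_token(credentials.credentials)


def parse_jwt_token(token: str) -> dict:
    # Stand-in for the real implementation in .jwt_utils.
    raise NotImplementedError
```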

autogpt_platform/autogpt_libs/poetry.lock (generated, 41 lines changed)
@@ -1253,31 +1253,30 @@ pyasn1 = ">=0.1.3"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ruff"
|
name = "ruff"
|
||||||
version = "0.12.9"
|
version = "0.12.3"
|
||||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
groups = ["dev"]
|
groups = ["dev"]
|
||||||
files = [
|
files = [
|
||||||
{file = "ruff-0.12.9-py3-none-linux_armv6l.whl", hash = "sha256:fcebc6c79fcae3f220d05585229463621f5dbf24d79fdc4936d9302e177cfa3e"},
|
{file = "ruff-0.12.3-py3-none-linux_armv6l.whl", hash = "sha256:47552138f7206454eaf0c4fe827e546e9ddac62c2a3d2585ca54d29a890137a2"},
|
||||||
{file = "ruff-0.12.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:aed9d15f8c5755c0e74467731a007fcad41f19bcce41cd75f768bbd687f8535f"},
|
{file = "ruff-0.12.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:0a9153b000c6fe169bb307f5bd1b691221c4286c133407b8827c406a55282041"},
|
||||||
{file = "ruff-0.12.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5b15ea354c6ff0d7423814ba6d44be2807644d0c05e9ed60caca87e963e93f70"},
|
{file = "ruff-0.12.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fa6b24600cf3b750e48ddb6057e901dd5b9aa426e316addb2a1af185a7509882"},
|
||||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d596c2d0393c2502eaabfef723bd74ca35348a8dac4267d18a94910087807c53"},
|
{file = "ruff-0.12.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2506961bf6ead54887ba3562604d69cb430f59b42133d36976421bc8bd45901"},
|
||||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1b15599931a1a7a03c388b9c5df1bfa62be7ede6eb7ef753b272381f39c3d0ff"},
|
{file = "ruff-0.12.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c4faaff1f90cea9d3033cbbcdf1acf5d7fb11d8180758feb31337391691f3df0"},
|
||||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3d02faa2977fb6f3f32ddb7828e212b7dd499c59eb896ae6c03ea5c303575756"},
|
{file = "ruff-0.12.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40dced4a79d7c264389de1c59467d5d5cefd79e7e06d1dfa2c75497b5269a5a6"},
|
||||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:17d5b6b0b3a25259b69ebcba87908496e6830e03acfb929ef9fd4c58675fa2ea"},
|
{file = "ruff-0.12.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:0262d50ba2767ed0fe212aa7e62112a1dcbfd46b858c5bf7bbd11f326998bafc"},
|
||||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:72db7521860e246adbb43f6ef464dd2a532ef2ef1f5dd0d470455b8d9f1773e0"},
|
{file = "ruff-0.12.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:12371aec33e1a3758597c5c631bae9a5286f3c963bdfb4d17acdd2d395406687"},
|
||||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a03242c1522b4e0885af63320ad754d53983c9599157ee33e77d748363c561ce"},
|
{file = "ruff-0.12.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:560f13b6baa49785665276c963edc363f8ad4b4fc910a883e2625bdb14a83a9e"},
|
||||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fc83e4e9751e6c13b5046d7162f205d0a7bac5840183c5beebf824b08a27340"},
|
{file = "ruff-0.12.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:023040a3499f6f974ae9091bcdd0385dd9e9eb4942f231c23c57708147b06311"},
|
||||||
{file = "ruff-0.12.9-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:881465ed56ba4dd26a691954650de6ad389a2d1fdb130fe51ff18a25639fe4bb"},
|
{file = "ruff-0.12.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:883d844967bffff5ab28bba1a4d246c1a1b2933f48cb9840f3fdc5111c603b07"},
|
||||||
{file = "ruff-0.12.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:43f07a3ccfc62cdb4d3a3348bf0588358a66da756aa113e071b8ca8c3b9826af"},
|
{file = "ruff-0.12.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2120d3aa855ff385e0e562fdee14d564c9675edbe41625c87eeab744a7830d12"},
|
||||||
{file = "ruff-0.12.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:07adb221c54b6bba24387911e5734357f042e5669fa5718920ee728aba3cbadc"},
|
{file = "ruff-0.12.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6b16647cbb470eaf4750d27dddc6ebf7758b918887b56d39e9c22cce2049082b"},
|
||||||
{file = "ruff-0.12.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f5cd34fabfdea3933ab85d72359f118035882a01bff15bd1d2b15261d85d5f66"},
|
{file = "ruff-0.12.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e1417051edb436230023575b149e8ff843a324557fe0a265863b7602df86722f"},
|
||||||
{file = "ruff-0.12.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:f6be1d2ca0686c54564da8e7ee9e25f93bdd6868263805f8c0b8fc6a449db6d7"},
|
{file = "ruff-0.12.3-py3-none-win32.whl", hash = "sha256:dfd45e6e926deb6409d0616078a666ebce93e55e07f0fb0228d4b2608b2c248d"},
|
||||||
{file = "ruff-0.12.9-py3-none-win32.whl", hash = "sha256:cc7a37bd2509974379d0115cc5608a1a4a6c4bff1b452ea69db83c8855d53f93"},
|
{file = "ruff-0.12.3-py3-none-win_amd64.whl", hash = "sha256:a946cf1e7ba3209bdef039eb97647f1c77f6f540e5845ec9c114d3af8df873e7"},
|
||||||
{file = "ruff-0.12.9-py3-none-win_amd64.whl", hash = "sha256:6fb15b1977309741d7d098c8a3cb7a30bc112760a00fb6efb7abc85f00ba5908"},
|
{file = "ruff-0.12.3-py3-none-win_arm64.whl", hash = "sha256:5f9c7c9c8f84c2d7f27e93674d27136fbf489720251544c4da7fb3d742e011b1"},
|
||||||
{file = "ruff-0.12.9-py3-none-win_arm64.whl", hash = "sha256:63c8c819739d86b96d500cce885956a1a48ab056bbcbc61b747ad494b2485089"},
|
{file = "ruff-0.12.3.tar.gz", hash = "sha256:f1b5a4b6668fd7b7ea3697d8d98857390b40c1320a63a178eee6be0899ea2d77"},
|
||||||
{file = "ruff-0.12.9.tar.gz", hash = "sha256:fbd94b2e3c623f659962934e52c2bea6fc6da11f667a427a368adaf3af2c866a"},
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1615,4 +1614,4 @@ type = ["pytest-mypy"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.1"
|
lock-version = "2.1"
|
||||||
python-versions = ">=3.10,<4.0"
|
python-versions = ">=3.10,<4.0"
|
||||||
content-hash = "4cc687aabe5865665fb8c4ccc0ea7e0af80b41e401ca37919f57efa6e0b5be00"
|
content-hash = "f67db13e6f68b1d67a55eee908c1c560bfa44da8509f98f842889a7570a9830f"
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ supabase = "^2.16.0"
|
|||||||
uvicorn = "^0.35.0"
|
uvicorn = "^0.35.0"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
ruff = "^0.12.9"
|
ruff = "^0.12.3"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core"]
|
requires = ["poetry-core"]
|
||||||
|
|||||||
@@ -106,15 +106,6 @@ TODOIST_CLIENT_SECRET=
|
|||||||
|
|
||||||
NOTION_CLIENT_ID=
|
NOTION_CLIENT_ID=
|
||||||
NOTION_CLIENT_SECRET=
|
NOTION_CLIENT_SECRET=
|
||||||
|
|
||||||
# Discord OAuth App credentials
|
|
||||||
# 1. Go to https://discord.com/developers/applications
|
|
||||||
# 2. Create a new application
|
|
||||||
# 3. Go to OAuth2 section and add redirect URI: http://localhost:3000/auth/integrations/oauth_callback
|
|
||||||
# 4. Copy Client ID and Client Secret below
|
|
||||||
DISCORD_CLIENT_ID=
|
|
||||||
DISCORD_CLIENT_SECRET=
|
|
||||||
|
|
||||||
REDDIT_CLIENT_ID=
|
REDDIT_CLIENT_ID=
|
||||||
REDDIT_CLIENT_SECRET=
|
REDDIT_CLIENT_SECRET=
|
||||||
|
|
||||||
|
|||||||
@@ -1,21 +1,16 @@
|
|||||||
FROM debian:13-slim AS builder
|
FROM python:3.11.10-slim-bookworm AS builder
|
||||||
|
|
||||||
# Set environment variables
|
# Set environment variables
|
||||||
ENV PYTHONDONTWRITEBYTECODE=1
|
ENV PYTHONDONTWRITEBYTECODE 1
|
||||||
ENV PYTHONUNBUFFERED=1
|
ENV PYTHONUNBUFFERED 1
|
||||||
ENV DEBIAN_FRONTEND=noninteractive
|
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
RUN echo 'Acquire::http::Pipeline-Depth 0;\nAcquire::http::No-Cache true;\nAcquire::BrokenProxy true;\n' > /etc/apt/apt.conf.d/99fixbadproxy
|
RUN echo 'Acquire::http::Pipeline-Depth 0;\nAcquire::http::No-Cache true;\nAcquire::BrokenProxy true;\n' > /etc/apt/apt.conf.d/99fixbadproxy
|
||||||
|
|
||||||
# Update package list and install Python and build dependencies
|
# Update package list and install build dependencies in a single layer
|
||||||
RUN apt-get update --allow-releaseinfo-change --fix-missing \
|
RUN apt-get update --allow-releaseinfo-change --fix-missing \
|
||||||
&& apt-get install -y \
|
&& apt-get install -y \
|
||||||
python3.13 \
|
|
||||||
python3.13-dev \
|
|
||||||
python3.13-venv \
|
|
||||||
python3-pip \
|
|
||||||
build-essential \
|
build-essential \
|
||||||
libpq5 \
|
libpq5 \
|
||||||
libz-dev \
|
libz-dev \
|
||||||
@@ -24,11 +19,13 @@ RUN apt-get update --allow-releaseinfo-change --fix-missing \
|
|||||||
|
|
||||||
ENV POETRY_HOME=/opt/poetry
|
ENV POETRY_HOME=/opt/poetry
|
||||||
ENV POETRY_NO_INTERACTION=1
|
ENV POETRY_NO_INTERACTION=1
|
||||||
ENV POETRY_VIRTUALENVS_CREATE=true
|
ENV POETRY_VIRTUALENVS_CREATE=false
|
||||||
ENV POETRY_VIRTUALENVS_IN_PROJECT=true
|
|
||||||
ENV PATH=/opt/poetry/bin:$PATH
|
ENV PATH=/opt/poetry/bin:$PATH
|
||||||
|
|
||||||
RUN pip3 install poetry --break-system-packages
|
# Upgrade pip and setuptools to fix security vulnerabilities
|
||||||
|
RUN pip3 install --upgrade pip setuptools
|
||||||
|
|
||||||
|
RUN pip3 install poetry
|
||||||
|
|
||||||
# Copy and install dependencies
|
# Copy and install dependencies
|
||||||
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
|
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
|
||||||
@@ -40,30 +37,27 @@ RUN poetry install --no-ansi --no-root
|
|||||||
COPY autogpt_platform/backend/schema.prisma ./
|
COPY autogpt_platform/backend/schema.prisma ./
|
||||||
RUN poetry run prisma generate
|
RUN poetry run prisma generate
|
||||||
|
|
||||||
FROM debian:13-slim AS server_dependencies
|
FROM python:3.11.10-slim-bookworm AS server_dependencies
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
ENV POETRY_HOME=/opt/poetry \
|
ENV POETRY_HOME=/opt/poetry \
|
||||||
POETRY_NO_INTERACTION=1 \
|
POETRY_NO_INTERACTION=1 \
|
||||||
POETRY_VIRTUALENVS_CREATE=true \
|
POETRY_VIRTUALENVS_CREATE=false
|
||||||
POETRY_VIRTUALENVS_IN_PROJECT=true \
|
|
||||||
DEBIAN_FRONTEND=noninteractive
|
|
||||||
ENV PATH=/opt/poetry/bin:$PATH
|
ENV PATH=/opt/poetry/bin:$PATH
|
||||||
|
|
||||||
# Install Python without upgrading system-managed packages
|
# Upgrade pip and setuptools to fix security vulnerabilities
|
||||||
RUN apt-get update && apt-get install -y \
|
RUN pip3 install --upgrade pip setuptools
|
||||||
python3.13 \
|
|
||||||
python3-pip
|
|
||||||
|
|
||||||
# Copy only necessary files from builder
|
# Copy only necessary files from builder
|
||||||
COPY --from=builder /app /app
|
COPY --from=builder /app /app
|
||||||
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
|
COPY --from=builder /usr/local/lib/python3.11 /usr/local/lib/python3.11
|
||||||
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
|
COPY --from=builder /usr/local/bin /usr/local/bin
|
||||||
# Copy Prisma binaries
|
# Copy Prisma binaries
|
||||||
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||||
|
|
||||||
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
|
|
||||||
|
ENV PATH="/app/.venv/bin:$PATH"
|
||||||
|
|
||||||
RUN mkdir -p /app/autogpt_platform/autogpt_libs
|
RUN mkdir -p /app/autogpt_platform/autogpt_libs
|
||||||
RUN mkdir -p /app/autogpt_platform/backend
|
RUN mkdir -p /app/autogpt_platform/backend
|
||||||
|
|||||||
@@ -25,9 +25,6 @@ class AgentExecutorBlock(Block):
|
|||||||
user_id: str = SchemaField(description="User ID")
|
user_id: str = SchemaField(description="User ID")
|
||||||
graph_id: str = SchemaField(description="Graph ID")
|
graph_id: str = SchemaField(description="Graph ID")
|
||||||
graph_version: int = SchemaField(description="Graph Version")
|
graph_version: int = SchemaField(description="Graph Version")
|
||||||
agent_name: Optional[str] = SchemaField(
|
|
||||||
default=None, description="Name to display in the Builder UI"
|
|
||||||
)
|
|
||||||
|
|
||||||
inputs: BlockInput = SchemaField(description="Input data for the graph")
|
inputs: BlockInput = SchemaField(description="Input data for the graph")
|
||||||
input_schema: dict = SchemaField(description="Input schema for the graph")
|
input_schema: dict = SchemaField(description="Input schema for the graph")
|
||||||
|
|||||||
@@ -166,7 +166,7 @@ class AIMusicGeneratorBlock(Block):
|
|||||||
output_format=input_data.output_format,
|
output_format=input_data.output_format,
|
||||||
normalization_strategy=input_data.normalization_strategy,
|
normalization_strategy=input_data.normalization_strategy,
|
||||||
)
|
)
|
||||||
if result and isinstance(result, str) and result.startswith("http"):
|
if result and result != "No output received":
|
||||||
yield "result", result
|
yield "result", result
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,205 +0,0 @@
|
|||||||
"""
|
|
||||||
Meeting BaaS API client module.
|
|
||||||
All API calls centralized for consistency and maintainability.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
from backend.sdk import Requests
|
|
||||||
|
|
||||||
|
|
||||||
class MeetingBaasAPI:
|
|
||||||
"""Client for Meeting BaaS API endpoints."""
|
|
||||||
|
|
||||||
BASE_URL = "https://api.meetingbaas.com"
|
|
||||||
|
|
||||||
def __init__(self, api_key: str):
|
|
||||||
"""Initialize API client with authentication key."""
|
|
||||||
self.api_key = api_key
|
|
||||||
self.headers = {"x-meeting-baas-api-key": api_key}
|
|
||||||
self.requests = Requests()
|
|
||||||
|
|
||||||
# Bot Management Endpoints
|
|
||||||
|
|
||||||
async def join_meeting(
|
|
||||||
self,
|
|
||||||
bot_name: str,
|
|
||||||
meeting_url: str,
|
|
||||||
reserved: bool = False,
|
|
||||||
bot_image: Optional[str] = None,
|
|
||||||
entry_message: Optional[str] = None,
|
|
||||||
start_time: Optional[int] = None,
|
|
||||||
speech_to_text: Optional[Dict[str, Any]] = None,
|
|
||||||
webhook_url: Optional[str] = None,
|
|
||||||
automatic_leave: Optional[Dict[str, Any]] = None,
|
|
||||||
extra: Optional[Dict[str, Any]] = None,
|
|
||||||
recording_mode: str = "speaker_view",
|
|
||||||
streaming: Optional[Dict[str, Any]] = None,
|
|
||||||
deduplication_key: Optional[str] = None,
|
|
||||||
zoom_sdk_id: Optional[str] = None,
|
|
||||||
zoom_sdk_pwd: Optional[str] = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Deploy a bot to join and record a meeting.
|
|
||||||
|
|
||||||
POST /bots
|
|
||||||
"""
|
|
||||||
body = {
|
|
||||||
"bot_name": bot_name,
|
|
||||||
"meeting_url": meeting_url,
|
|
||||||
"reserved": reserved,
|
|
||||||
"recording_mode": recording_mode,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add optional fields if provided
|
|
||||||
if bot_image is not None:
|
|
||||||
body["bot_image"] = bot_image
|
|
||||||
if entry_message is not None:
|
|
||||||
body["entry_message"] = entry_message
|
|
||||||
if start_time is not None:
|
|
||||||
body["start_time"] = start_time
|
|
||||||
if speech_to_text is not None:
|
|
||||||
body["speech_to_text"] = speech_to_text
|
|
||||||
if webhook_url is not None:
|
|
||||||
body["webhook_url"] = webhook_url
|
|
||||||
if automatic_leave is not None:
|
|
||||||
body["automatic_leave"] = automatic_leave
|
|
||||||
if extra is not None:
|
|
||||||
body["extra"] = extra
|
|
||||||
if streaming is not None:
|
|
||||||
body["streaming"] = streaming
|
|
||||||
if deduplication_key is not None:
|
|
||||||
body["deduplication_key"] = deduplication_key
|
|
||||||
if zoom_sdk_id is not None:
|
|
||||||
body["zoom_sdk_id"] = zoom_sdk_id
|
|
||||||
if zoom_sdk_pwd is not None:
|
|
||||||
body["zoom_sdk_pwd"] = zoom_sdk_pwd
|
|
||||||
|
|
||||||
response = await self.requests.post(
|
|
||||||
f"{self.BASE_URL}/bots",
|
|
||||||
headers=self.headers,
|
|
||||||
json=body,
|
|
||||||
)
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
async def leave_meeting(self, bot_id: str) -> bool:
|
|
||||||
"""
|
|
||||||
Remove a bot from an ongoing meeting.
|
|
||||||
|
|
||||||
DELETE /bots/{uuid}
|
|
||||||
"""
|
|
||||||
response = await self.requests.delete(
|
|
||||||
f"{self.BASE_URL}/bots/{bot_id}",
|
|
||||||
headers=self.headers,
|
|
||||||
)
|
|
||||||
return response.status in [200, 204]
|
|
||||||
|
|
||||||
async def retranscribe(
|
|
||||||
self,
|
|
||||||
bot_uuid: str,
|
|
||||||
speech_to_text: Optional[Dict[str, Any]] = None,
|
|
||||||
webhook_url: Optional[str] = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Re-run transcription on a bot's audio.
|
|
||||||
|
|
||||||
POST /bots/retranscribe
|
|
||||||
"""
|
|
||||||
body: Dict[str, Any] = {"bot_uuid": bot_uuid}
|
|
||||||
|
|
||||||
if speech_to_text is not None:
|
|
||||||
body["speech_to_text"] = speech_to_text
|
|
||||||
if webhook_url is not None:
|
|
||||||
body["webhook_url"] = webhook_url
|
|
||||||
|
|
||||||
response = await self.requests.post(
|
|
||||||
f"{self.BASE_URL}/bots/retranscribe",
|
|
||||||
headers=self.headers,
|
|
||||||
json=body,
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.status == 202:
|
|
||||||
return {"accepted": True}
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
# Data Retrieval Endpoints
|
|
||||||
|
|
||||||
async def get_meeting_data(
|
|
||||||
self, bot_id: str, include_transcripts: bool = True
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Retrieve meeting data including recording and transcripts.
|
|
||||||
|
|
||||||
GET /bots/meeting_data
|
|
||||||
"""
|
|
||||||
params = {
|
|
||||||
"bot_id": bot_id,
|
|
||||||
"include_transcripts": str(include_transcripts).lower(),
|
|
||||||
}
|
|
||||||
|
|
||||||
response = await self.requests.get(
|
|
||||||
f"{self.BASE_URL}/bots/meeting_data",
|
|
||||||
headers=self.headers,
|
|
||||||
params=params,
|
|
||||||
)
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
async def get_screenshots(self, bot_id: str) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Retrieve screenshots captured during a meeting.
|
|
||||||
|
|
||||||
GET /bots/{uuid}/screenshots
|
|
||||||
"""
|
|
||||||
response = await self.requests.get(
|
|
||||||
f"{self.BASE_URL}/bots/{bot_id}/screenshots",
|
|
||||||
headers=self.headers,
|
|
||||||
)
|
|
||||||
result = response.json()
|
|
||||||
# Ensure we return a list
|
|
||||||
if isinstance(result, list):
|
|
||||||
return result
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def delete_data(self, bot_id: str) -> bool:
|
|
||||||
"""
|
|
||||||
Delete a bot's recorded data.
|
|
||||||
|
|
||||||
POST /bots/{uuid}/delete_data
|
|
||||||
"""
|
|
||||||
response = await self.requests.post(
|
|
||||||
f"{self.BASE_URL}/bots/{bot_id}/delete_data",
|
|
||||||
headers=self.headers,
|
|
||||||
)
|
|
||||||
return response.status == 200
|
|
||||||
|
|
||||||
async def list_bots_with_metadata(
|
|
||||||
self,
|
|
||||||
limit: Optional[int] = None,
|
|
||||||
offset: Optional[int] = None,
|
|
||||||
sort_by: Optional[str] = None,
|
|
||||||
sort_order: Optional[str] = None,
|
|
||||||
filter_by: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
List bots with metadata including IDs, names, and meeting details.
|
|
||||||
|
|
||||||
GET /bots/bots_with_metadata
|
|
||||||
"""
|
|
||||||
params = {}
|
|
||||||
if limit is not None:
|
|
||||||
params["limit"] = limit
|
|
||||||
if offset is not None:
|
|
||||||
params["offset"] = offset
|
|
||||||
if sort_by is not None:
|
|
||||||
params["sort_by"] = sort_by
|
|
||||||
if sort_order is not None:
|
|
||||||
params["sort_order"] = sort_order
|
|
||||||
if filter_by is not None:
|
|
||||||
params.update(filter_by)
|
|
||||||
|
|
||||||
response = await self.requests.get(
|
|
||||||
f"{self.BASE_URL}/bots/bots_with_metadata",
|
|
||||||
headers=self.headers,
|
|
||||||
params=params,
|
|
||||||
)
|
|
||||||
return response.json()
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
"""
|
|
||||||
Shared configuration for all Meeting BaaS blocks using the SDK pattern.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from backend.sdk import BlockCostType, ProviderBuilder
|
|
||||||
|
|
||||||
# Configure the Meeting BaaS provider with API key authentication
|
|
||||||
baas = (
|
|
||||||
ProviderBuilder("baas")
|
|
||||||
.with_api_key("MEETING_BAAS_API_KEY", "Meeting BaaS API Key")
|
|
||||||
.with_base_cost(5, BlockCostType.RUN) # Higher cost for meeting recording service
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
@@ -1,217 +0,0 @@
|
|||||||
"""
|
|
||||||
Meeting BaaS bot (recording) blocks.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from backend.sdk import (
|
|
||||||
APIKeyCredentials,
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchema,
|
|
||||||
CredentialsMetaInput,
|
|
||||||
SchemaField,
|
|
||||||
)
|
|
||||||
|
|
||||||
from ._api import MeetingBaasAPI
|
|
||||||
from ._config import baas
|
|
||||||
|
|
||||||
|
|
||||||
class BaasBotJoinMeetingBlock(Block):
|
|
||||||
"""
|
|
||||||
Deploy a bot immediately or at a scheduled start_time to join and record a meeting.
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Input(BlockSchema):
|
|
||||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
|
||||||
description="Meeting BaaS API credentials"
|
|
||||||
)
|
|
||||||
meeting_url: str = SchemaField(
|
|
||||||
description="The URL of the meeting the bot should join"
|
|
||||||
)
|
|
||||||
bot_name: str = SchemaField(
|
|
||||||
description="Display name for the bot in the meeting"
|
|
||||||
)
|
|
||||||
bot_image: str = SchemaField(
|
|
||||||
description="URL to an image for the bot's avatar (16:9 ratio recommended)",
|
|
||||||
default="",
|
|
||||||
)
|
|
||||||
entry_message: str = SchemaField(
|
|
||||||
description="Chat message the bot will post upon entry", default=""
|
|
||||||
)
|
|
||||||
reserved: bool = SchemaField(
|
|
||||||
description="Use a reserved bot slot (joins 4 min before meeting)",
|
|
||||||
default=False,
|
|
||||||
)
|
|
||||||
start_time: Optional[int] = SchemaField(
|
|
||||||
description="Unix timestamp (ms) when bot should join", default=None
|
|
||||||
)
|
|
||||||
webhook_url: str | None = SchemaField(
|
|
||||||
description="URL to receive webhook events for this bot", default=None
|
|
||||||
)
|
|
||||||
timeouts: dict = SchemaField(
|
|
||||||
description="Automatic leave timeouts configuration", default={}
|
|
||||||
)
|
|
||||||
extra: dict = SchemaField(
|
|
||||||
description="Custom metadata to attach to the bot", default={}
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchema):
|
|
||||||
bot_id: str = SchemaField(description="UUID of the deployed bot")
|
|
||||||
join_response: dict = SchemaField(
|
|
||||||
description="Full response from join operation"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="377d1a6a-a99b-46cf-9af3-1d1b12758e04",
|
|
||||||
description="Deploy a bot to join and record a meeting",
|
|
||||||
categories={BlockCategory.COMMUNICATION},
|
|
||||||
input_schema=self.Input,
|
|
||||||
output_schema=self.Output,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
|
||||||
) -> BlockOutput:
|
|
||||||
api_key = credentials.api_key.get_secret_value()
|
|
||||||
api = MeetingBaasAPI(api_key)
|
|
||||||
|
|
||||||
# Call API with all parameters
|
|
||||||
data = await api.join_meeting(
|
|
||||||
bot_name=input_data.bot_name,
|
|
||||||
meeting_url=input_data.meeting_url,
|
|
||||||
reserved=input_data.reserved,
|
|
||||||
bot_image=input_data.bot_image if input_data.bot_image else None,
|
|
||||||
entry_message=(
|
|
||||||
input_data.entry_message if input_data.entry_message else None
|
|
||||||
),
|
|
||||||
start_time=input_data.start_time,
|
|
||||||
            speech_to_text={"provider": "Default"},
            webhook_url=input_data.webhook_url if input_data.webhook_url else None,
            automatic_leave=input_data.timeouts if input_data.timeouts else None,
            extra=input_data.extra if input_data.extra else None,
        )

        yield "bot_id", data.get("bot_id", "")
        yield "join_response", data


class BaasBotLeaveMeetingBlock(Block):
    """
    Force the bot to exit the call.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = baas.credentials_field(
            description="Meeting BaaS API credentials"
        )
        bot_id: str = SchemaField(description="UUID of the bot to remove from meeting")

    class Output(BlockSchema):
        left: bool = SchemaField(description="Whether the bot successfully left")

    def __init__(self):
        super().__init__(
            id="bf77d128-8b25-4280-b5c7-2d553ba7e482",
            description="Remove a bot from an ongoing meeting",
            categories={BlockCategory.COMMUNICATION},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        api_key = credentials.api_key.get_secret_value()
        api = MeetingBaasAPI(api_key)

        # Leave meeting
        left = await api.leave_meeting(input_data.bot_id)

        yield "left", left


class BaasBotFetchMeetingDataBlock(Block):
    """
    Pull MP4 URL, transcript & metadata for a completed meeting.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = baas.credentials_field(
            description="Meeting BaaS API credentials"
        )
        bot_id: str = SchemaField(description="UUID of the bot whose data to fetch")
        include_transcripts: bool = SchemaField(
            description="Include transcript data in response", default=True
        )

    class Output(BlockSchema):
        mp4_url: str = SchemaField(
            description="URL to download the meeting recording (time-limited)"
        )
        transcript: list = SchemaField(description="Meeting transcript data")
        metadata: dict = SchemaField(description="Meeting metadata and bot information")

    def __init__(self):
        super().__init__(
            id="ea7c1309-303c-4da1-893f-89c0e9d64e78",
            description="Retrieve recorded meeting data",
            categories={BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        api_key = credentials.api_key.get_secret_value()
        api = MeetingBaasAPI(api_key)

        # Fetch meeting data
        data = await api.get_meeting_data(
            bot_id=input_data.bot_id,
            include_transcripts=input_data.include_transcripts,
        )

        yield "mp4_url", data.get("mp4", "")
        yield "transcript", data.get("bot_data", {}).get("transcripts", [])
        yield "metadata", data.get("bot_data", {}).get("bot", {})


class BaasBotDeleteRecordingBlock(Block):
    """
    Purge MP4 + transcript data for privacy or storage management.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = baas.credentials_field(
            description="Meeting BaaS API credentials"
        )
        bot_id: str = SchemaField(description="UUID of the bot whose data to delete")

    class Output(BlockSchema):
        deleted: bool = SchemaField(
            description="Whether the data was successfully deleted"
        )

    def __init__(self):
        super().__init__(
            id="bf8d1aa6-42d8-4944-b6bd-6bac554c0d3b",
            description="Permanently delete a meeting's recorded data",
            categories={BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        api_key = credentials.api_key.get_secret_value()
        api = MeetingBaasAPI(api_key)

        # Delete recording data
        deleted = await api.delete_data(input_data.bot_id)

        yield "deleted", deleted
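Each of these blocks reports results by yielding `(output_name, value)` pairs from an async-generator `run` method. A minimal sketch of consuming one directly is shown below; calling `run` by hand is an assumption for illustration only, since in the platform the executor drives `run` and routes the yielded pairs to downstream nodes.

```python
import asyncio


async def collect_outputs(block, input_data, credentials):
    # Gather the named outputs the block yields, e.g. ("left", True).
    outputs = {}
    async for name, value in block.run(input_data, credentials=credentials):
        outputs[name] = value
    return outputs

# asyncio.run(collect_outputs(BaasBotLeaveMeetingBlock(), leave_input, api_key_credentials))
```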
@@ -1,178 +0,0 @@
"""
DataForSEO API client with async support using the SDK patterns.
"""

import base64
from typing import Any, Dict, List, Optional

from backend.sdk import Requests, UserPasswordCredentials


class DataForSeoClient:
    """Client for the DataForSEO API using async requests."""

    API_URL = "https://api.dataforseo.com"

    def __init__(self, credentials: UserPasswordCredentials):
        self.credentials = credentials
        self.requests = Requests(
            trusted_origins=["https://api.dataforseo.com"],
            raise_for_status=False,
        )

    def _get_headers(self) -> Dict[str, str]:
        """Generate the authorization header using Basic Auth."""
        username = self.credentials.username.get_secret_value()
        password = self.credentials.password.get_secret_value()
        credentials_str = f"{username}:{password}"
        encoded = base64.b64encode(credentials_str.encode("ascii")).decode("ascii")
        return {
            "Authorization": f"Basic {encoded}",
            "Content-Type": "application/json",
        }

    async def keyword_suggestions(
        self,
        keyword: str,
        location_code: Optional[int] = None,
        language_code: Optional[str] = None,
        include_seed_keyword: bool = True,
        include_serp_info: bool = False,
        include_clickstream_data: bool = False,
        limit: int = 100,
    ) -> List[Dict[str, Any]]:
        """
        Get keyword suggestions from DataForSEO Labs.

        Args:
            keyword: Seed keyword
            location_code: Location code for targeting
            language_code: Language code (e.g., "en")
            include_seed_keyword: Include seed keyword in results
            include_serp_info: Include SERP data
            include_clickstream_data: Include clickstream metrics
            limit: Maximum number of results (up to 3000)

        Returns:
            API response with keyword suggestions
        """
        endpoint = f"{self.API_URL}/v3/dataforseo_labs/google/keyword_suggestions/live"

        # Build payload only with non-None values to avoid sending null fields
        task_data: dict[str, Any] = {
            "keyword": keyword,
        }

        if location_code is not None:
            task_data["location_code"] = location_code
        if language_code is not None:
            task_data["language_code"] = language_code
        if include_seed_keyword is not None:
            task_data["include_seed_keyword"] = include_seed_keyword
        if include_serp_info is not None:
            task_data["include_serp_info"] = include_serp_info
        if include_clickstream_data is not None:
            task_data["include_clickstream_data"] = include_clickstream_data
        if limit is not None:
            task_data["limit"] = limit

        payload = [task_data]

        response = await self.requests.post(
            endpoint,
            headers=self._get_headers(),
            json=payload,
        )

        data = response.json()

        # Check for API errors
        if response.status != 200:
            error_message = data.get("status_message", "Unknown error")
            raise Exception(
                f"DataForSEO API error ({response.status}): {error_message}"
            )

        # Extract the results from the response
        if data.get("tasks") and len(data["tasks"]) > 0:
            task = data["tasks"][0]
            if task.get("status_code") == 20000:  # Success code
                return task.get("result", [])
            else:
                error_msg = task.get("status_message", "Task failed")
                raise Exception(f"DataForSEO task error: {error_msg}")

        return []

    async def related_keywords(
        self,
        keyword: str,
        location_code: Optional[int] = None,
        language_code: Optional[str] = None,
        include_seed_keyword: bool = True,
        include_serp_info: bool = False,
        include_clickstream_data: bool = False,
        limit: int = 100,
    ) -> List[Dict[str, Any]]:
        """
        Get related keywords from DataForSEO Labs.

        Args:
            keyword: Seed keyword
            location_code: Location code for targeting
            language_code: Language code (e.g., "en")
            include_seed_keyword: Include seed keyword in results
            include_serp_info: Include SERP data
            include_clickstream_data: Include clickstream metrics
            limit: Maximum number of results (up to 3000)

        Returns:
            API response with related keywords
        """
        endpoint = f"{self.API_URL}/v3/dataforseo_labs/google/related_keywords/live"

        # Build payload only with non-None values to avoid sending null fields
        task_data: dict[str, Any] = {
            "keyword": keyword,
        }

        if location_code is not None:
            task_data["location_code"] = location_code
        if language_code is not None:
            task_data["language_code"] = language_code
        if include_seed_keyword is not None:
            task_data["include_seed_keyword"] = include_seed_keyword
        if include_serp_info is not None:
            task_data["include_serp_info"] = include_serp_info
        if include_clickstream_data is not None:
            task_data["include_clickstream_data"] = include_clickstream_data
        if limit is not None:
            task_data["limit"] = limit

        payload = [task_data]

        response = await self.requests.post(
            endpoint,
            headers=self._get_headers(),
            json=payload,
        )

        data = response.json()

        # Check for API errors
        if response.status != 200:
            error_message = data.get("status_message", "Unknown error")
            raise Exception(
                f"DataForSEO API error ({response.status}): {error_message}"
            )

        # Extract the results from the response
        if data.get("tasks") and len(data["tasks"]) > 0:
            task = data["tasks"][0]
            if task.get("status_code") == 20000:  # Success code
                return task.get("result", [])
            else:
                error_msg = task.get("status_message", "Task failed")
                raise Exception(f"DataForSEO task error: {error_msg}")

        return []
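A minimal usage sketch of the client above. The credential values are placeholders, and constructing `UserPasswordCredentials` with `SecretStr` fields is an assumption inferred from the `get_secret_value()` calls in `_get_headers`, not a confirmed constructor signature.

```python
import asyncio

from pydantic import SecretStr

from backend.sdk import UserPasswordCredentials


async def main():
    # Hypothetical credentials; real values come from the provider configuration.
    creds = UserPasswordCredentials(
        username=SecretStr("dataforseo-login"),
        password=SecretStr("dataforseo-password"),
        title="DataForSEO Credentials",
    )
    client = DataForSeoClient(creds)
    # Returns the task "result" list; each entry carries an "items" array.
    results = await client.keyword_suggestions("digital marketing", limit=10)
    print(results)

# asyncio.run(main())
```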
@@ -1,17 +0,0 @@
"""
Configuration for all DataForSEO blocks using the new SDK pattern.
"""

from backend.sdk import BlockCostType, ProviderBuilder

# Build the DataForSEO provider with username/password authentication
dataforseo = (
    ProviderBuilder("dataforseo")
    .with_user_password(
        username_env_var="DATAFORSEO_USERNAME",
        password_env_var="DATAFORSEO_PASSWORD",
        title="DataForSEO Credentials",
    )
    .with_base_cost(1, BlockCostType.RUN)
    .build()
)
@@ -1,273 +0,0 @@
"""
DataForSEO Google Keyword Suggestions block.
"""

from typing import Any, Dict, List, Optional

from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    CredentialsMetaInput,
    SchemaField,
    UserPasswordCredentials,
)

from ._api import DataForSeoClient
from ._config import dataforseo


class KeywordSuggestion(BlockSchema):
    """Schema for a keyword suggestion result."""

    keyword: str = SchemaField(description="The keyword suggestion")
    search_volume: Optional[int] = SchemaField(
        description="Monthly search volume", default=None
    )
    competition: Optional[float] = SchemaField(
        description="Competition level (0-1)", default=None
    )
    cpc: Optional[float] = SchemaField(
        description="Cost per click in USD", default=None
    )
    keyword_difficulty: Optional[int] = SchemaField(
        description="Keyword difficulty score", default=None
    )
    serp_info: Optional[Dict[str, Any]] = SchemaField(
        description="data from SERP for each keyword", default=None
    )
    clickstream_data: Optional[Dict[str, Any]] = SchemaField(
        description="Clickstream data metrics", default=None
    )


class DataForSeoKeywordSuggestionsBlock(Block):
    """Block for getting keyword suggestions from DataForSEO Labs."""

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = dataforseo.credentials_field(
            description="DataForSEO credentials (username and password)"
        )
        keyword: str = SchemaField(description="Seed keyword to get suggestions for")
        location_code: Optional[int] = SchemaField(
            description="Location code for targeting (e.g., 2840 for USA)",
            default=2840,  # USA
        )
        language_code: Optional[str] = SchemaField(
            description="Language code (e.g., 'en' for English)",
            default="en",
        )
        include_seed_keyword: bool = SchemaField(
            description="Include the seed keyword in results",
            default=True,
        )
        include_serp_info: bool = SchemaField(
            description="Include SERP information",
            default=False,
        )
        include_clickstream_data: bool = SchemaField(
            description="Include clickstream metrics",
            default=False,
        )
        limit: int = SchemaField(
            description="Maximum number of results (up to 3000)",
            default=100,
            ge=1,
            le=3000,
        )

    class Output(BlockSchema):
        suggestions: List[KeywordSuggestion] = SchemaField(
            description="List of keyword suggestions with metrics"
        )
        suggestion: KeywordSuggestion = SchemaField(
            description="A single keyword suggestion with metrics"
        )
        total_count: int = SchemaField(
            description="Total number of suggestions returned"
        )
        seed_keyword: str = SchemaField(
            description="The seed keyword used for the query"
        )

    def __init__(self):
        super().__init__(
            id="73c3e7c4-2b3f-4e9f-9e3e-8f7a5c3e2d45",
            description="Get keyword suggestions from DataForSEO Labs Google API",
            categories={BlockCategory.SEARCH, BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
            test_input={
                "credentials": dataforseo.get_test_credentials().model_dump(),
                "keyword": "digital marketing",
                "location_code": 2840,
                "language_code": "en",
                "limit": 1,
            },
            test_credentials=dataforseo.get_test_credentials(),
            test_output=[
                (
                    "suggestion",
                    lambda x: hasattr(x, "keyword")
                    and x.keyword == "digital marketing strategy",
                ),
                ("suggestions", lambda x: isinstance(x, list) and len(x) == 1),
                ("total_count", 1),
                ("seed_keyword", "digital marketing"),
            ],
            test_mock={
                "_fetch_keyword_suggestions": lambda *args, **kwargs: [
                    {
                        "items": [
                            {
                                "keyword": "digital marketing strategy",
                                "keyword_info": {
                                    "search_volume": 10000,
                                    "competition": 0.5,
                                    "cpc": 2.5,
                                },
                                "keyword_properties": {
                                    "keyword_difficulty": 50,
                                },
                            }
                        ]
                    }
                ]
            },
        )

    async def _fetch_keyword_suggestions(
        self,
        client: DataForSeoClient,
        input_data: Input,
    ) -> Any:
        """Private method to fetch keyword suggestions - can be mocked for testing."""
        return await client.keyword_suggestions(
            keyword=input_data.keyword,
            location_code=input_data.location_code,
            language_code=input_data.language_code,
            include_seed_keyword=input_data.include_seed_keyword,
            include_serp_info=input_data.include_serp_info,
            include_clickstream_data=input_data.include_clickstream_data,
            limit=input_data.limit,
        )

    async def run(
        self,
        input_data: Input,
        *,
        credentials: UserPasswordCredentials,
        **kwargs,
    ) -> BlockOutput:
        """Execute the keyword suggestions query."""
        client = DataForSeoClient(credentials)

        results = await self._fetch_keyword_suggestions(client, input_data)

        # Process and format the results
        suggestions = []
        if results and len(results) > 0:
            # results is a list, get the first element
            first_result = results[0] if isinstance(results, list) else results
            items = (
                first_result.get("items", []) if isinstance(first_result, dict) else []
            )
            for item in items:
                # Create the KeywordSuggestion object
                suggestion = KeywordSuggestion(
                    keyword=item.get("keyword", ""),
                    search_volume=item.get("keyword_info", {}).get("search_volume"),
                    competition=item.get("keyword_info", {}).get("competition"),
                    cpc=item.get("keyword_info", {}).get("cpc"),
                    keyword_difficulty=item.get("keyword_properties", {}).get(
                        "keyword_difficulty"
                    ),
                    serp_info=(
                        item.get("serp_info") if input_data.include_serp_info else None
                    ),
                    clickstream_data=(
                        item.get("clickstream_keyword_info")
                        if input_data.include_clickstream_data
                        else None
                    ),
                )
                yield "suggestion", suggestion
                suggestions.append(suggestion)

        yield "suggestions", suggestions
        yield "total_count", len(suggestions)
        yield "seed_keyword", input_data.keyword


class KeywordSuggestionExtractorBlock(Block):
    """Extracts individual fields from a KeywordSuggestion object."""

    class Input(BlockSchema):
        suggestion: KeywordSuggestion = SchemaField(
            description="The keyword suggestion object to extract fields from"
        )

    class Output(BlockSchema):
        keyword: str = SchemaField(description="The keyword suggestion")
        search_volume: Optional[int] = SchemaField(
            description="Monthly search volume", default=None
        )
        competition: Optional[float] = SchemaField(
            description="Competition level (0-1)", default=None
        )
        cpc: Optional[float] = SchemaField(
            description="Cost per click in USD", default=None
        )
        keyword_difficulty: Optional[int] = SchemaField(
            description="Keyword difficulty score", default=None
        )
        serp_info: Optional[Dict[str, Any]] = SchemaField(
            description="data from SERP for each keyword", default=None
        )
        clickstream_data: Optional[Dict[str, Any]] = SchemaField(
            description="Clickstream data metrics", default=None
        )

    def __init__(self):
        super().__init__(
            id="4193cb94-677c-48b0-9eec-6ac72fffd0f2",
            description="Extract individual fields from a KeywordSuggestion object",
            categories={BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
            test_input={
                "suggestion": KeywordSuggestion(
                    keyword="test keyword",
                    search_volume=1000,
                    competition=0.5,
                    cpc=2.5,
                    keyword_difficulty=60,
                ).model_dump()
            },
            test_output=[
                ("keyword", "test keyword"),
                ("search_volume", 1000),
                ("competition", 0.5),
                ("cpc", 2.5),
                ("keyword_difficulty", 60),
                ("serp_info", None),
                ("clickstream_data", None),
            ],
        )

    async def run(
        self,
        input_data: Input,
        **kwargs,
    ) -> BlockOutput:
        """Extract fields from the KeywordSuggestion object."""
        suggestion = input_data.suggestion

        yield "keyword", suggestion.keyword
        yield "search_volume", suggestion.search_volume
        yield "competition", suggestion.competition
        yield "cpc", suggestion.cpc
        yield "keyword_difficulty", suggestion.keyword_difficulty
        yield "serp_info", suggestion.serp_info
        yield "clickstream_data", suggestion.clickstream_data
@@ -1,283 +0,0 @@
"""
DataForSEO Google Related Keywords block.
"""

from typing import Any, Dict, List, Optional

from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    CredentialsMetaInput,
    SchemaField,
    UserPasswordCredentials,
)

from ._api import DataForSeoClient
from ._config import dataforseo


class RelatedKeyword(BlockSchema):
    """Schema for a related keyword result."""

    keyword: str = SchemaField(description="The related keyword")
    search_volume: Optional[int] = SchemaField(
        description="Monthly search volume", default=None
    )
    competition: Optional[float] = SchemaField(
        description="Competition level (0-1)", default=None
    )
    cpc: Optional[float] = SchemaField(
        description="Cost per click in USD", default=None
    )
    keyword_difficulty: Optional[int] = SchemaField(
        description="Keyword difficulty score", default=None
    )
    serp_info: Optional[Dict[str, Any]] = SchemaField(
        description="SERP data for the keyword", default=None
    )
    clickstream_data: Optional[Dict[str, Any]] = SchemaField(
        description="Clickstream data metrics", default=None
    )


class DataForSeoRelatedKeywordsBlock(Block):
    """Block for getting related keywords from DataForSEO Labs."""

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = dataforseo.credentials_field(
            description="DataForSEO credentials (username and password)"
        )
        keyword: str = SchemaField(
            description="Seed keyword to find related keywords for"
        )
        location_code: Optional[int] = SchemaField(
            description="Location code for targeting (e.g., 2840 for USA)",
            default=2840,  # USA
        )
        language_code: Optional[str] = SchemaField(
            description="Language code (e.g., 'en' for English)",
            default="en",
        )
        include_seed_keyword: bool = SchemaField(
            description="Include the seed keyword in results",
            default=True,
        )
        include_serp_info: bool = SchemaField(
            description="Include SERP information",
            default=False,
        )
        include_clickstream_data: bool = SchemaField(
            description="Include clickstream metrics",
            default=False,
        )
        limit: int = SchemaField(
            description="Maximum number of results (up to 3000)",
            default=100,
            ge=1,
            le=3000,
        )

    class Output(BlockSchema):
        related_keywords: List[RelatedKeyword] = SchemaField(
            description="List of related keywords with metrics"
        )
        related_keyword: RelatedKeyword = SchemaField(
            description="A related keyword with metrics"
        )
        total_count: int = SchemaField(
            description="Total number of related keywords returned"
        )
        seed_keyword: str = SchemaField(
            description="The seed keyword used for the query"
        )

    def __init__(self):
        super().__init__(
            id="8f2e4d6a-1b3c-4a5e-9d7f-2c8e6a4b3f1d",
            description="Get related keywords from DataForSEO Labs Google API",
            categories={BlockCategory.SEARCH, BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
            test_input={
                "credentials": dataforseo.get_test_credentials().model_dump(),
                "keyword": "content marketing",
                "location_code": 2840,
                "language_code": "en",
                "limit": 1,
            },
            test_credentials=dataforseo.get_test_credentials(),
            test_output=[
                (
                    "related_keyword",
                    lambda x: hasattr(x, "keyword") and x.keyword == "content strategy",
                ),
                ("related_keywords", lambda x: isinstance(x, list) and len(x) == 1),
                ("total_count", 1),
                ("seed_keyword", "content marketing"),
            ],
            test_mock={
                "_fetch_related_keywords": lambda *args, **kwargs: [
                    {
                        "items": [
                            {
                                "keyword_data": {
                                    "keyword": "content strategy",
                                    "keyword_info": {
                                        "search_volume": 8000,
                                        "competition": 0.4,
                                        "cpc": 3.0,
                                    },
                                    "keyword_properties": {
                                        "keyword_difficulty": 45,
                                    },
                                }
                            }
                        ]
                    }
                ]
            },
        )

    async def _fetch_related_keywords(
        self,
        client: DataForSeoClient,
        input_data: Input,
    ) -> Any:
        """Private method to fetch related keywords - can be mocked for testing."""
        return await client.related_keywords(
            keyword=input_data.keyword,
            location_code=input_data.location_code,
            language_code=input_data.language_code,
            include_seed_keyword=input_data.include_seed_keyword,
            include_serp_info=input_data.include_serp_info,
            include_clickstream_data=input_data.include_clickstream_data,
            limit=input_data.limit,
        )

    async def run(
        self,
        input_data: Input,
        *,
        credentials: UserPasswordCredentials,
        **kwargs,
    ) -> BlockOutput:
        """Execute the related keywords query."""
        client = DataForSeoClient(credentials)

        results = await self._fetch_related_keywords(client, input_data)

        # Process and format the results
        related_keywords = []
        if results and len(results) > 0:
            # results is a list, get the first element
            first_result = results[0] if isinstance(results, list) else results
            items = (
                first_result.get("items", []) if isinstance(first_result, dict) else []
            )
            for item in items:
                # Extract keyword_data from the item
                keyword_data = item.get("keyword_data", {})

                # Create the RelatedKeyword object
                keyword = RelatedKeyword(
                    keyword=keyword_data.get("keyword", ""),
                    search_volume=keyword_data.get("keyword_info", {}).get(
                        "search_volume"
                    ),
                    competition=keyword_data.get("keyword_info", {}).get("competition"),
                    cpc=keyword_data.get("keyword_info", {}).get("cpc"),
                    keyword_difficulty=keyword_data.get("keyword_properties", {}).get(
                        "keyword_difficulty"
                    ),
                    serp_info=(
                        keyword_data.get("serp_info")
                        if input_data.include_serp_info
                        else None
                    ),
                    clickstream_data=(
                        keyword_data.get("clickstream_keyword_info")
                        if input_data.include_clickstream_data
                        else None
                    ),
                )
                yield "related_keyword", keyword
                related_keywords.append(keyword)

        yield "related_keywords", related_keywords
        yield "total_count", len(related_keywords)
        yield "seed_keyword", input_data.keyword


class RelatedKeywordExtractorBlock(Block):
    """Extracts individual fields from a RelatedKeyword object."""

    class Input(BlockSchema):
        related_keyword: RelatedKeyword = SchemaField(
            description="The related keyword object to extract fields from"
        )

    class Output(BlockSchema):
        keyword: str = SchemaField(description="The related keyword")
        search_volume: Optional[int] = SchemaField(
            description="Monthly search volume", default=None
        )
        competition: Optional[float] = SchemaField(
            description="Competition level (0-1)", default=None
        )
        cpc: Optional[float] = SchemaField(
            description="Cost per click in USD", default=None
        )
        keyword_difficulty: Optional[int] = SchemaField(
            description="Keyword difficulty score", default=None
        )
        serp_info: Optional[Dict[str, Any]] = SchemaField(
            description="SERP data for the keyword", default=None
        )
        clickstream_data: Optional[Dict[str, Any]] = SchemaField(
            description="Clickstream data metrics", default=None
        )

    def __init__(self):
        super().__init__(
            id="98342061-09d2-4952-bf77-0761fc8cc9a8",
            description="Extract individual fields from a RelatedKeyword object",
            categories={BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
            test_input={
                "related_keyword": RelatedKeyword(
                    keyword="test related keyword",
                    search_volume=800,
                    competition=0.4,
                    cpc=3.0,
                    keyword_difficulty=55,
                ).model_dump()
            },
            test_output=[
                ("keyword", "test related keyword"),
                ("search_volume", 800),
                ("competition", 0.4),
                ("cpc", 3.0),
                ("keyword_difficulty", 55),
                ("serp_info", None),
                ("clickstream_data", None),
            ],
        )

    async def run(
        self,
        input_data: Input,
        **kwargs,
    ) -> BlockOutput:
        """Extract fields from the RelatedKeyword object."""
        related_keyword = input_data.related_keyword

        yield "keyword", related_keyword.keyword
        yield "search_volume", related_keyword.search_volume
        yield "competition", related_keyword.competition
        yield "cpc", related_keyword.cpc
        yield "keyword_difficulty", related_keyword.keyword_difficulty
        yield "serp_info", related_keyword.serp_info
        yield "clickstream_data", related_keyword.clickstream_data
@@ -2,29 +2,45 @@ import base64
 import io
 import mimetypes
 from pathlib import Path
-from typing import Any
+from typing import Any, Literal
 
 import aiohttp
 import discord
 from pydantic import SecretStr
 
 from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
-from backend.data.model import APIKeyCredentials, SchemaField
+from backend.data.model import (
+    APIKeyCredentials,
+    CredentialsField,
+    CredentialsMetaInput,
+    SchemaField,
+)
+from backend.integrations.providers import ProviderName
 from backend.util.file import store_media_file
 from backend.util.type import MediaFileType
 
-from ._auth import (
-    TEST_BOT_CREDENTIALS,
-    TEST_BOT_CREDENTIALS_INPUT,
-    DiscordBotCredentialsField,
-    DiscordBotCredentialsInput,
-)
+DiscordCredentials = CredentialsMetaInput[
+    Literal[ProviderName.DISCORD], Literal["api_key"]
+]
 
-# Keep backward compatibility alias
-DiscordCredentials = DiscordBotCredentialsInput
-DiscordCredentialsField = DiscordBotCredentialsField
-TEST_CREDENTIALS = TEST_BOT_CREDENTIALS
-TEST_CREDENTIALS_INPUT = TEST_BOT_CREDENTIALS_INPUT
+
+def DiscordCredentialsField() -> DiscordCredentials:
+    return CredentialsField(description="Discord bot token")
+
+
+TEST_CREDENTIALS = APIKeyCredentials(
+    id="01234567-89ab-cdef-0123-456789abcdef",
+    provider="discord",
+    api_key=SecretStr("test_api_key"),
+    title="Mock Discord API key",
+    expires_at=None,
+)
+TEST_CREDENTIALS_INPUT = {
+    "provider": TEST_CREDENTIALS.provider,
+    "id": TEST_CREDENTIALS.id,
+    "type": TEST_CREDENTIALS.type,
+    "title": TEST_CREDENTIALS.type,
+}
 
 
 class ReadDiscordMessagesBlock(Block):
@@ -1,117 +0,0 @@
"""
Discord API helper functions for making authenticated requests.
"""

import logging
from typing import Optional

from pydantic import BaseModel

from backend.data.model import OAuth2Credentials
from backend.util.request import Requests

logger = logging.getLogger(__name__)


class DiscordAPIException(Exception):
    """Exception raised for Discord API errors."""

    def __init__(self, message: str, status_code: int):
        super().__init__(message)
        self.status_code = status_code


class DiscordOAuthUser(BaseModel):
    """Model for Discord OAuth user response."""

    user_id: str
    username: str
    avatar_url: str
    banner: Optional[str] = None
    accent_color: Optional[int] = None


def get_api(credentials: OAuth2Credentials) -> Requests:
    """
    Create a Requests instance configured for Discord API calls with OAuth2 credentials.

    Args:
        credentials: The OAuth2 credentials containing the access token.

    Returns:
        A configured Requests instance for Discord API calls.
    """
    return Requests(
        trusted_origins=[],
        extra_headers={
            "Authorization": f"Bearer {credentials.access_token.get_secret_value()}",
            "Content-Type": "application/json",
        },
        raise_for_status=False,
    )


async def get_current_user(credentials: OAuth2Credentials) -> DiscordOAuthUser:
    """
    Fetch the current user's information using Discord OAuth2 API.

    Reference: https://discord.com/developers/docs/resources/user#get-current-user

    Args:
        credentials: The OAuth2 credentials.

    Returns:
        A model containing user data with avatar URL.

    Raises:
        DiscordAPIException: If the API request fails.
    """
    api = get_api(credentials)
    response = await api.get("https://discord.com/api/oauth2/@me")

    if not response.ok:
        error_text = response.text()
        raise DiscordAPIException(
            f"Failed to fetch user info: {response.status} - {error_text}",
            response.status,
        )

    data = response.json()
    logger.info(f"Discord OAuth2 API Response: {data}")

    # The /api/oauth2/@me endpoint returns a user object nested in the response
    user_info = data.get("user", {})
    logger.info(f"User info extracted: {user_info}")

    # Build avatar URL
    user_id = user_info.get("id")
    avatar_hash = user_info.get("avatar")
    if avatar_hash:
        # Custom avatar
        avatar_ext = "gif" if avatar_hash.startswith("a_") else "png"
        avatar_url = (
            f"https://cdn.discordapp.com/avatars/{user_id}/{avatar_hash}.{avatar_ext}"
        )
    else:
        # Default avatar based on discriminator or user ID
        discriminator = user_info.get("discriminator", "0")
        if discriminator == "0":
            # New username system - use user ID for default avatar
            default_avatar_index = (int(user_id) >> 22) % 6
        else:
            # Legacy discriminator system
            default_avatar_index = int(discriminator) % 5
        avatar_url = (
            f"https://cdn.discordapp.com/embed/avatars/{default_avatar_index}.png"
        )

    result = DiscordOAuthUser(
        user_id=user_id,
        username=user_info.get("username", ""),
        avatar_url=avatar_url,
        banner=user_info.get("banner"),
        accent_color=user_info.get("accent_color"),
    )

    logger.info(f"Returning user data: {result.model_dump()}")
    return result
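The default-avatar branch above follows Discord's published rule: accounts on the new username system use `(user_id >> 22) % 6`, while legacy accounts use `discriminator % 5`. A small standalone check of that arithmetic; the sample IDs are made up.

```python
def default_avatar_index(user_id: str, discriminator: str = "0") -> int:
    # Mirrors the logic in get_current_user above.
    if discriminator == "0":
        return (int(user_id) >> 22) % 6
    return int(discriminator) % 5


print(default_avatar_index("123456789012345678"))          # new-style account
print(default_avatar_index("123456789012345678", "1337"))  # legacy: 1337 % 5 == 2
```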
@@ -1,74 +0,0 @@
from typing import Literal

from pydantic import SecretStr

from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
    CredentialsMetaInput,
    OAuth2Credentials,
)
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets

secrets = Secrets()
DISCORD_OAUTH_IS_CONFIGURED = bool(
    secrets.discord_client_id and secrets.discord_client_secret
)

# Bot token credentials (existing)
DiscordBotCredentials = APIKeyCredentials
DiscordBotCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.DISCORD], Literal["api_key"]
]

# OAuth2 credentials (new)
DiscordOAuthCredentials = OAuth2Credentials
DiscordOAuthCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.DISCORD], Literal["oauth2"]
]


def DiscordBotCredentialsField() -> DiscordBotCredentialsInput:
    """Creates a Discord bot token credentials field."""
    return CredentialsField(description="Discord bot token")


def DiscordOAuthCredentialsField(scopes: list[str]) -> DiscordOAuthCredentialsInput:
    """Creates a Discord OAuth2 credentials field."""
    return CredentialsField(
        description="Discord OAuth2 credentials",
        required_scopes=set(scopes) | {"identify"},  # Basic user info scope
    )


# Test credentials for bot tokens
TEST_BOT_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="discord",
    api_key=SecretStr("test_api_key"),
    title="Mock Discord API key",
    expires_at=None,
)
TEST_BOT_CREDENTIALS_INPUT = {
    "provider": TEST_BOT_CREDENTIALS.provider,
    "id": TEST_BOT_CREDENTIALS.id,
    "type": TEST_BOT_CREDENTIALS.type,
    "title": TEST_BOT_CREDENTIALS.type,
}

# Test credentials for OAuth2
TEST_OAUTH_CREDENTIALS = OAuth2Credentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="discord",
    access_token=SecretStr("test_access_token"),
    title="Mock Discord OAuth",
    scopes=["identify"],
    username="testuser",
)
TEST_OAUTH_CREDENTIALS_INPUT = {
    "provider": TEST_OAUTH_CREDENTIALS.provider,
    "id": TEST_OAUTH_CREDENTIALS.id,
    "type": TEST_OAUTH_CREDENTIALS.type,
    "title": TEST_OAUTH_CREDENTIALS.type,
}
@@ -1,99 +0,0 @@
"""
Discord OAuth-based blocks.
"""

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import OAuth2Credentials, SchemaField

from ._api import DiscordOAuthUser, get_current_user
from ._auth import (
    DISCORD_OAUTH_IS_CONFIGURED,
    TEST_OAUTH_CREDENTIALS,
    TEST_OAUTH_CREDENTIALS_INPUT,
    DiscordOAuthCredentialsField,
    DiscordOAuthCredentialsInput,
)


class DiscordGetCurrentUserBlock(Block):
    """
    Gets information about the currently authenticated Discord user using OAuth2.
    This block requires Discord OAuth2 credentials (not bot tokens).
    """

    class Input(BlockSchema):
        credentials: DiscordOAuthCredentialsInput = DiscordOAuthCredentialsField(
            ["identify"]
        )

    class Output(BlockSchema):
        user_id: str = SchemaField(description="The authenticated user's Discord ID")
        username: str = SchemaField(description="The user's username")
        avatar_url: str = SchemaField(description="URL to the user's avatar image")
        banner_url: str = SchemaField(
            description="URL to the user's banner image (if set)", default=""
        )
        accent_color: int = SchemaField(
            description="The user's accent color as an integer", default=0
        )

    def __init__(self):
        super().__init__(
            id="8c7e39b8-4e9d-4f3a-b4e1-2a8c9d5f6e3b",
            input_schema=DiscordGetCurrentUserBlock.Input,
            output_schema=DiscordGetCurrentUserBlock.Output,
            description="Gets information about the currently authenticated Discord user using OAuth2 credentials.",
            categories={BlockCategory.SOCIAL},
            disabled=not DISCORD_OAUTH_IS_CONFIGURED,
            test_input={
                "credentials": TEST_OAUTH_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_OAUTH_CREDENTIALS,
            test_output=[
                ("user_id", "123456789012345678"),
                ("username", "testuser"),
                (
                    "avatar_url",
                    "https://cdn.discordapp.com/avatars/123456789012345678/avatar.png",
                ),
                ("banner_url", ""),
                ("accent_color", 0),
            ],
            test_mock={
                "get_user": lambda _: DiscordOAuthUser(
                    user_id="123456789012345678",
                    username="testuser",
                    avatar_url="https://cdn.discordapp.com/avatars/123456789012345678/avatar.png",
                    banner=None,
                    accent_color=0,
                )
            },
        )

    @staticmethod
    async def get_user(credentials: OAuth2Credentials) -> DiscordOAuthUser:
        user_info = await get_current_user(credentials)
        return user_info

    async def run(
        self, input_data: Input, *, credentials: OAuth2Credentials, **kwargs
    ) -> BlockOutput:
        try:
            result = await self.get_user(credentials)

            # Yield each output field
            yield "user_id", result.user_id
            yield "username", result.username
            yield "avatar_url", result.avatar_url

            # Handle banner URL if banner hash exists
            if result.banner:
                banner_url = f"https://cdn.discordapp.com/banners/{result.user_id}/{result.banner}.png"
                yield "banner_url", banner_url
            else:
                yield "banner_url", ""

            yield "accent_color", result.accent_color or 0

        except Exception as e:
            raise ValueError(f"Failed to get Discord user info: {e}")
@@ -39,18 +39,6 @@ def serialize_email_recipients(recipients: list[str]) -> str:
     return ", ".join(recipients)
 
 
-def deduplicate_email_addresses(addresses: list[str]) -> list[str]:
-    """Deduplicate email addresses while preserving order.
-
-    Args:
-        addresses: List of email addresses that may contain duplicates or None values
-
-    Returns:
-        List of unique email addresses with None values filtered out
-    """
-    return list(dict.fromkeys(filter(None, addresses)))
-
-
 def _make_mime_text(
     body: str,
     content_type: Optional[Literal["auto", "plain", "html"]] = None,
@@ -1275,8 +1263,11 @@ class GmailReplyBlock(GmailBase):
                 recipients += [
                     addr for _, addr in getaddresses([headers.get("cc", "")])
                 ]
-                # Deduplicate recipients while preserving order
-                input_data.to = deduplicate_email_addresses(recipients)
+                dedup: list[str] = []
+                for r in recipients:
+                    if r and r not in dedup:
+                        dedup.append(r)
+                input_data.to = dedup
             else:
                 sender = parseaddr(headers.get("reply-to", headers.get("from", "")))[1]
                 input_data.to = [sender] if sender else []
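Both sides of this hunk do the same thing: drop falsy entries and de-duplicate while keeping first-seen order. The removed helper leans on `dict.fromkeys`, which preserves insertion order in Python 3.7+. A quick standalone check of the equivalence; the sample addresses are made up.

```python
addresses = ["a@example.com", None, "b@example.com", "a@example.com", ""]

# Removed helper's idiom
helper_result = list(dict.fromkeys(filter(None, addresses)))

# Inline loop that replaces it
dedup: list[str] = []
for r in addresses:
    if r and r not in dedup:
        dedup.append(r)

assert helper_result == dedup == ["a@example.com", "b@example.com"]
```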
@@ -1326,224 +1317,6 @@ class GmailReplyBlock(GmailBase):
         )
 
 
-class GmailCreateDraftReplyBlock(GmailBase):
-    """
-    Creates draft replies to Gmail threads with intelligent content type detection.
-
-    Features:
-    - Automatic HTML detection: Draft replies containing HTML tags are formatted as text/html
-    - No hard-wrap for plain text: Plain text drafts preserve natural line flow
-    - Manual content type override: Use content_type parameter to force specific format
-    - Reply-all functionality: Option to draft reply to all original recipients
-    - Thread preservation: Maintains proper email threading with headers
-    - Full Unicode/emoji support with UTF-8 encoding
-    - Attachment support for multiple files
-    """
-
-    class Input(BlockSchema):
-        credentials: GoogleCredentialsInput = GoogleCredentialsField(
-            [
-                "https://www.googleapis.com/auth/gmail.modify",
-                "https://www.googleapis.com/auth/gmail.readonly",
-            ]
-        )
-        threadId: str = SchemaField(description="Thread ID to reply in")
-        parentMessageId: str = SchemaField(
-            description="ID of the message being replied to"
-        )
-        to: list[str] = SchemaField(description="To recipients", default_factory=list)
-        cc: list[str] = SchemaField(description="CC recipients", default_factory=list)
-        bcc: list[str] = SchemaField(description="BCC recipients", default_factory=list)
-        replyAll: bool = SchemaField(
-            description="Reply to all original recipients", default=False
-        )
-        subject: str = SchemaField(description="Email subject", default="")
-        body: str = SchemaField(description="Email body (plain text or HTML)")
-        content_type: Optional[Literal["auto", "plain", "html"]] = SchemaField(
-            description="Content type: 'auto' (default - detects HTML), 'plain', or 'html'",
-            default=None,
-            advanced=True,
-        )
-        attachments: list[MediaFileType] = SchemaField(
-            description="Files to attach", default_factory=list, advanced=True
-        )
-
-    class Output(BlockSchema):
-        draftId: str = SchemaField(description="Created draft ID")
-        messageId: str = SchemaField(description="Draft message ID")
-        threadId: str = SchemaField(description="Thread ID")
-        status: str = SchemaField(description="Draft creation status")
-        error: str = SchemaField(description="Error message if any")
-
-    def __init__(self):
-        super().__init__(
-            id="8f2e9d3c-4b1a-4c7e-9a2f-1d3e5f7a9b1c",
-            description="Create draft replies to Gmail threads with automatic HTML detection and proper text formatting. Drafts maintain proper email threading and can be edited before sending.",
-            categories={BlockCategory.COMMUNICATION},
-            input_schema=GmailCreateDraftReplyBlock.Input,
-            output_schema=GmailCreateDraftReplyBlock.Output,
-            disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
-            test_input={
-                "threadId": "t1",
-                "parentMessageId": "m1",
-                "body": "Thanks for your message. I'll draft a response.",
-                "replyAll": False,
-                "credentials": TEST_CREDENTIALS_INPUT,
-            },
-            test_credentials=TEST_CREDENTIALS,
-            test_output=[
-                ("draftId", "draft1"),
-                ("messageId", "msg1"),
-                ("threadId", "t1"),
-                ("status", "draft_reply_created"),
-            ],
-            test_mock={
-                "_create_draft_reply": lambda *args, **kwargs: {
-                    "id": "draft1",
-                    "message": {"id": "msg1", "threadId": "t1"},
-                },
-            },
-        )
-
-    async def run(
-        self,
-        input_data: Input,
-        *,
-        credentials: GoogleCredentials,
-        graph_exec_id: str,
-        user_id: str,
-        **kwargs,
-    ) -> BlockOutput:
-        service = self._build_service(credentials, **kwargs)
-        result = await self._create_draft_reply(
-            service,
-            input_data,
-            graph_exec_id,
-            user_id,
-        )
-        yield "draftId", result["id"]
-        yield "messageId", result["message"]["id"]
-        yield "threadId", result["message"].get("threadId", input_data.threadId)
-        yield "status", "draft_reply_created"
-
-    async def _create_draft_reply(
-        self, service, input_data: Input, graph_exec_id: str, user_id: str
-    ) -> dict:
-        # Fetch parent message metadata
-        parent = await asyncio.to_thread(
-            lambda: service.users()
-            .messages()
-            .get(
-                userId="me",
-                id=input_data.parentMessageId,
-                format="metadata",
-                metadataHeaders=[
-                    "Subject",
-                    "References",
-                    "Message-ID",
-                    "From",
-                    "To",
-                    "Cc",
-                    "Reply-To",
-                ],
-            )
-            .execute()
-        )
-
-        headers = {
-            h["name"].lower(): h["value"]
-            for h in parent.get("payload", {}).get("headers", [])
-        }
-
-        # Auto-populate recipients if not provided
-        if not (input_data.to or input_data.cc or input_data.bcc):
-            if input_data.replyAll:
-                # Reply all - include all original recipients
-                recipients = [parseaddr(headers.get("from", ""))[1]]
-                recipients += [
-                    addr for _, addr in getaddresses([headers.get("to", "")])
-                ]
-                recipients += [
-                    addr for _, addr in getaddresses([headers.get("cc", "")])
-                ]
-                # Deduplicate recipients
-                dedup: list[str] = []
-                for r in recipients:
-                    if r and r not in dedup:
-                        dedup.append(r)
-                input_data.to = dedup
-            else:
-                # Reply to sender only
-                sender = parseaddr(headers.get("reply-to", headers.get("from", "")))[1]
-                input_data.to = [sender] if sender else []
-
-        # Generate subject with Re: prefix if needed
-        subject = input_data.subject or (f"Re: {headers.get('subject', '')}".strip())
-
-        # Build References header chain
-        references = headers.get("references", "").split()
-        if headers.get("message-id"):
-            references.append(headers["message-id"])
-
-        # Create MIME message with threading headers
-        msg = MIMEMultipart()
-        if input_data.to:
-            msg["To"] = ", ".join(input_data.to)
-        if input_data.cc:
-            msg["Cc"] = ", ".join(input_data.cc)
-        if input_data.bcc:
-            msg["Bcc"] = ", ".join(input_data.bcc)
-        msg["Subject"] = subject
-
-        # Set threading headers for proper conversation grouping
-        if headers.get("message-id"):
-            msg["In-Reply-To"] = headers["message-id"]
-        if references:
-            msg["References"] = " ".join(references)
-
-        # Add body with proper content type handling
-        msg.attach(_make_mime_text(input_data.body, input_data.content_type))
-
-        # Handle attachments if any
-        for attach in input_data.attachments:
-            local_path = await store_media_file(
-                user_id=user_id,
-                graph_exec_id=graph_exec_id,
-                file=attach,
-                return_content=False,
-            )
-            abs_path = get_exec_file_path(graph_exec_id, local_path)
-            part = MIMEBase("application", "octet-stream")
-            with open(abs_path, "rb") as f:
-                part.set_payload(f.read())
-            encoders.encode_base64(part)
-            part.add_header(
-                "Content-Disposition", f"attachment; filename={Path(abs_path).name}"
-            )
-            msg.attach(part)
-
-        # Encode message for Gmail API
-        raw = base64.urlsafe_b64encode(msg.as_bytes()).decode("utf-8")
-
-        # Create draft with threadId to ensure it appears as a reply
-        draft = await asyncio.to_thread(
-            lambda: service.users()
-            .drafts()
-            .create(
-                userId="me",
-                body={
-                    "message": {
-                        "threadId": input_data.threadId,
-                        "raw": raw,
-                    }
-                },
-            )
-            .execute()
-        )
-
-        return draft
-
-
 class GmailGetProfileBlock(GmailBase):
     class Input(BlockSchema):
         credentials: GoogleCredentialsInput = GoogleCredentialsField(
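The removed `_create_draft_reply` keeps Gmail conversation threading intact by carrying the parent's `Message-ID` into `In-Reply-To`, appending it to the `References` chain, and prefixing the subject with `Re:` when no explicit subject is given. A condensed sketch of just that header logic; the sample header values are made up.

```python
headers = {
    "subject": "Quarterly report",
    "message-id": "<parent-id@mail.example.com>",
    "references": "<root-id@mail.example.com>",
}

# Same pattern as the removed code: fall back to "Re: <parent subject>".
subject = "" or f"Re: {headers.get('subject', '')}".strip()

references = headers.get("references", "").split()
if headers.get("message-id"):
    references.append(headers["message-id"])
in_reply_to = headers.get("message-id", "")

print(subject)               # Re: Quarterly report
print(" ".join(references))  # <root-id@mail.example.com> <parent-id@mail.example.com>
print(in_reply_to)           # <parent-id@mail.example.com>
```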
@@ -1 +0,0 @@
@@ -1,283 +0,0 @@
import logging
from typing import Any

from pydantic import BaseModel

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.clients import get_database_manager_async_client

logger = logging.getLogger(__name__)


# Duplicate pydantic models for store data so we don't accidently change the data shape in the blocks unintentionally when editing the backend
class LibraryAgent(BaseModel):
    """Model representing an agent in the user's library."""

    library_agent_id: str = ""
    agent_id: str = ""
    agent_version: int = 0
    agent_name: str = ""
    description: str = ""
    creator: str = ""
    is_archived: bool = False
    categories: list[str] = []


class AddToLibraryFromStoreBlock(Block):
    """
    Block that adds an agent from the store to the user's library.
    This enables users to easily import agents from the marketplace into their personal collection.
    """

    class Input(BlockSchema):
        store_listing_version_id: str = SchemaField(
            description="The ID of the store listing version to add to library"
        )
        agent_name: str | None = SchemaField(
            description="Optional custom name for the agent in your library",
            default=None,
        )

    class Output(BlockSchema):
        success: bool = SchemaField(
            description="Whether the agent was successfully added to library"
        )
        library_agent_id: str = SchemaField(
            description="The ID of the library agent entry"
        )
        agent_id: str = SchemaField(description="The ID of the agent graph")
        agent_version: int = SchemaField(
            description="The version number of the agent graph"
        )
        agent_name: str = SchemaField(description="The name of the agent")
        message: str = SchemaField(description="Success or error message")

    def __init__(self):
        super().__init__(
            id="2602a7b1-3f4d-4e5f-9c8b-1a2b3c4d5e6f",
            description="Add an agent from the store to your personal library",
            categories={BlockCategory.BASIC},
            input_schema=AddToLibraryFromStoreBlock.Input,
            output_schema=AddToLibraryFromStoreBlock.Output,
            test_input={
                "store_listing_version_id": "test-listing-id",
                "agent_name": "My Custom Agent",
            },
            test_output=[
                ("success", True),
                ("library_agent_id", "test-library-id"),
                ("agent_id", "test-agent-id"),
                ("agent_version", 1),
                ("agent_name", "Test Agent"),
                ("message", "Agent successfully added to library"),
            ],
            test_mock={
                "_add_to_library": lambda *_, **__: LibraryAgent(
                    library_agent_id="test-library-id",
                    agent_id="test-agent-id",
                    agent_version=1,
                    agent_name="Test Agent",
                )
            },
        )

    async def run(
        self,
        input_data: Input,
        *,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        library_agent = await self._add_to_library(
            user_id=user_id,
            store_listing_version_id=input_data.store_listing_version_id,
            custom_name=input_data.agent_name,
        )

        yield "success", True
        yield "library_agent_id", library_agent.library_agent_id
        yield "agent_id", library_agent.agent_id
        yield "agent_version", library_agent.agent_version
        yield "agent_name", library_agent.agent_name
        yield "message", "Agent successfully added to library"

    async def _add_to_library(
        self,
        user_id: str,
        store_listing_version_id: str,
        custom_name: str | None = None,
    ) -> LibraryAgent:
        """
        Add a store agent to the user's library using the existing library database function.
        """
        library_agent = (
            await get_database_manager_async_client().add_store_agent_to_library(
                store_listing_version_id=store_listing_version_id, user_id=user_id
            )
        )

        # If custom name is provided, we could update the library agent name here
        # For now, we'll just return the agent info
        agent_name = custom_name if custom_name else library_agent.name

        return LibraryAgent(
            library_agent_id=library_agent.id,
            agent_id=library_agent.graph_id,
            agent_version=library_agent.graph_version,
            agent_name=agent_name,
        )


class ListLibraryAgentsBlock(Block):
    """
    Block that lists all agents in the user's library.
    """

    class Input(BlockSchema):
        search_query: str | None = SchemaField(
            description="Optional search query to filter agents", default=None
        )
        limit: int = SchemaField(
            description="Maximum number of agents to return", default=50, ge=1, le=100
        )
        page: int = SchemaField(
            description="Page number for pagination", default=1, ge=1
        )

    class Output(BlockSchema):
        agents: list[LibraryAgent] = SchemaField(
            description="List of agents in the library",
            default_factory=list,
        )
        agent: LibraryAgent = SchemaField(
            description="Individual library agent (yielded for each agent)"
        )
        total_count: int = SchemaField(
            description="Total number of agents in library", default=0
        )
        page: int = SchemaField(description="Current page number", default=1)
        total_pages: int = SchemaField(description="Total number of pages", default=1)

    def __init__(self):
        super().__init__(
            id="082602d3-a74d-4600-9e9c-15b3af7eae98",
            description="List all agents in your personal library",
            categories={BlockCategory.BASIC, BlockCategory.DATA},
            input_schema=ListLibraryAgentsBlock.Input,
            output_schema=ListLibraryAgentsBlock.Output,
            test_input={
                "search_query": None,
                "limit": 10,
                "page": 1,
            },
            test_output=[
                (
                    "agents",
                    [
                        LibraryAgent(
                            library_agent_id="test-lib-id",
                            agent_id="test-agent-id",
                            agent_version=1,
                            agent_name="Test Library Agent",
                            description="A test agent in library",
                            creator="Test User",
                        ),
                    ],
                ),
                ("total_count", 1),
                ("page", 1),
                ("total_pages", 1),
                (
                    "agent",
                    LibraryAgent(
                        library_agent_id="test-lib-id",
                        agent_id="test-agent-id",
                        agent_version=1,
                        agent_name="Test Library Agent",
                        description="A test agent in library",
                        creator="Test User",
                    ),
                ),
            ],
            test_mock={
                "_list_library_agents": lambda *_, **__: {
                    "agents": [
                        LibraryAgent(
                            library_agent_id="test-lib-id",
                            agent_id="test-agent-id",
                            agent_version=1,
                            agent_name="Test Library Agent",
                            description="A test agent in library",
                            creator="Test User",
                        )
                    ],
                    "total": 1,
                    "page": 1,
                    "total_pages": 1,
                }
            },
        )

    async def run(
        self,
        input_data: Input,
        *,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        result = await self._list_library_agents(
            user_id=user_id,
            search_query=input_data.search_query,
            limit=input_data.limit,
            page=input_data.page,
        )

        agents = result["agents"]

        yield "agents", agents
        yield "total_count", result["total"]
        yield "page", result["page"]
        yield "total_pages", result["total_pages"]

        # Yield each agent individually for better graph connectivity
        for agent in agents:
            yield "agent", agent

    async def _list_library_agents(
        self,
        user_id: str,
        search_query: str | None = None,
        limit: int = 50,
        page: int = 1,
    ) -> dict[str, Any]:
        """
        List agents in the user's library using the database client.
        """
        result = await get_database_manager_async_client().list_library_agents(
            user_id=user_id,
            search_term=search_query,
            page=page,
            page_size=limit,
        )

        agents = [
            LibraryAgent(
                library_agent_id=agent.id,
                agent_id=agent.graph_id,
                agent_version=agent.graph_version,
                agent_name=agent.name,
                description=getattr(agent, "description", ""),
                creator=getattr(agent, "creator", ""),
                is_archived=getattr(agent, "is_archived", False),
                categories=getattr(agent, "categories", []),
            )
            for agent in result.agents
        ]

        return {
            "agents": agents,
            "total": result.pagination.total_items,
            "page": result.pagination.current_page,
            "total_pages": result.pagination.total_pages,
        }
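Both blocks in the removed module implement `run` as an async generator that yields `(output_name, value)` pairs — the `BlockOutput` convention. A minimal, self-contained sketch of how such output is consumed; the stand-in generator below is illustrative, not AutoGPT code:

```python
import asyncio
from typing import Any, AsyncIterator


async def run_block() -> AsyncIterator[tuple[str, Any]]:
    # Stand-in for Block.run(): yields named outputs one at a time.
    yield "success", True
    yield "message", "Agent successfully added to library"


async def main() -> None:
    outputs: dict[str, Any] = {}
    async for name, value in run_block():
        outputs[name] = value
    print(outputs)  # {'success': True, 'message': 'Agent successfully added to library'}


asyncio.run(main())
```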
@@ -1,311 +0,0 @@
import logging
from typing import Literal

from pydantic import BaseModel

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.clients import get_database_manager_async_client

logger = logging.getLogger(__name__)


# Duplicate pydantic models for store data so we don't accidently change the data shape in the blocks unintentionally when editing the backend
class StoreAgent(BaseModel):
    """Model representing a store agent."""

    slug: str = ""
    name: str = ""
    description: str = ""
    creator: str = ""
    rating: float = 0.0
    runs: int = 0
    categories: list[str] = []


class StoreAgentDict(BaseModel):
    """Dictionary representation of a store agent."""

    slug: str
    name: str
    description: str
    creator: str
    rating: float
    runs: int


class SearchAgentsResponse(BaseModel):
    """Response from searching store agents."""

    agents: list[StoreAgentDict]
    total_count: int


class StoreAgentDetails(BaseModel):
    """Detailed information about a store agent."""

    found: bool
    store_listing_version_id: str = ""
    agent_name: str = ""
    description: str = ""
    creator: str = ""
    categories: list[str] = []
    runs: int = 0
    rating: float = 0.0


class GetStoreAgentDetailsBlock(Block):
    """
    Block that retrieves detailed information about an agent from the store.
    """

    class Input(BlockSchema):
        creator: str = SchemaField(description="The username of the agent creator")
        slug: str = SchemaField(description="The name of the agent")

    class Output(BlockSchema):
        found: bool = SchemaField(
            description="Whether the agent was found in the store"
        )
        store_listing_version_id: str = SchemaField(
            description="The store listing version ID"
        )
        agent_name: str = SchemaField(description="Name of the agent")
        description: str = SchemaField(description="Description of the agent")
        creator: str = SchemaField(description="Creator of the agent")
        categories: list[str] = SchemaField(
            description="Categories the agent belongs to", default_factory=list
        )
        runs: int = SchemaField(
            description="Number of times the agent has been run", default=0
        )
        rating: float = SchemaField(
            description="Average rating of the agent", default=0.0
        )

    def __init__(self):
        super().__init__(
            id="b604f0ec-6e0d-40a7-bf55-9fd09997cced",
            description="Get detailed information about an agent from the store",
            categories={BlockCategory.BASIC, BlockCategory.DATA},
            input_schema=GetStoreAgentDetailsBlock.Input,
            output_schema=GetStoreAgentDetailsBlock.Output,
            test_input={"creator": "test-creator", "slug": "test-agent-slug"},
            test_output=[
                ("found", True),
                ("store_listing_version_id", "test-listing-id"),
                ("agent_name", "Test Agent"),
                ("description", "A test agent"),
                ("creator", "Test Creator"),
                ("categories", ["productivity", "automation"]),
                ("runs", 100),
                ("rating", 4.5),
            ],
            test_mock={
                "_get_agent_details": lambda *_, **__: StoreAgentDetails(
                    found=True,
                    store_listing_version_id="test-listing-id",
                    agent_name="Test Agent",
                    description="A test agent",
                    creator="Test Creator",
                    categories=["productivity", "automation"],
                    runs=100,
                    rating=4.5,
                )
            },
            static_output=True,
        )

    async def run(
        self,
        input_data: Input,
        **kwargs,
    ) -> BlockOutput:
        details = await self._get_agent_details(
            creator=input_data.creator, slug=input_data.slug
        )
        yield "found", details.found
        yield "store_listing_version_id", details.store_listing_version_id
        yield "agent_name", details.agent_name
        yield "description", details.description
        yield "creator", details.creator
        yield "categories", details.categories
        yield "runs", details.runs
        yield "rating", details.rating

    async def _get_agent_details(self, creator: str, slug: str) -> StoreAgentDetails:
        """
        Retrieve detailed information about a store agent.
        """
        # Get by specific version ID
        agent_details = (
            await get_database_manager_async_client().get_store_agent_details(
                username=creator, agent_name=slug
            )
        )

        return StoreAgentDetails(
            found=True,
            store_listing_version_id=agent_details.store_listing_version_id,
            agent_name=agent_details.agent_name,
            description=agent_details.description,
            creator=agent_details.creator,
            categories=(
                agent_details.categories if hasattr(agent_details, "categories") else []
            ),
            runs=agent_details.runs,
            rating=agent_details.rating,
        )


class SearchStoreAgentsBlock(Block):
    """
    Block that searches for agents in the store based on various criteria.
    """

    class Input(BlockSchema):
        query: str | None = SchemaField(
            description="Search query to find agents", default=None
        )
        category: str | None = SchemaField(
            description="Filter by category", default=None
        )
        sort_by: Literal["rating", "runs", "name", "recent"] = SchemaField(
            description="How to sort the results", default="rating"
        )
        limit: int = SchemaField(
            description="Maximum number of results to return", default=10, ge=1, le=100
        )

    class Output(BlockSchema):
        agents: list[StoreAgent] = SchemaField(
            description="List of agents matching the search criteria",
            default_factory=list,
        )
        agent: StoreAgent = SchemaField(description="Basic information of the agent")
        total_count: int = SchemaField(
            description="Total number of agents found", default=0
        )

    def __init__(self):
        super().__init__(
            id="39524701-026c-4328-87cc-1b88c8e2cb4c",
            description="Search for agents in the store",
            categories={BlockCategory.BASIC, BlockCategory.DATA},
            input_schema=SearchStoreAgentsBlock.Input,
            output_schema=SearchStoreAgentsBlock.Output,
            test_input={
                "query": "productivity",
                "category": None,
                "sort_by": "rating",
                "limit": 10,
            },
            test_output=[
                (
                    "agents",
                    [
                        {
                            "slug": "test-agent",
                            "name": "Test Agent",
                            "description": "A test agent",
                            "creator": "Test Creator",
                            "rating": 4.5,
                            "runs": 100,
                        }
                    ],
                ),
                ("total_count", 1),
                (
                    "agent",
                    {
                        "slug": "test-agent",
                        "name": "Test Agent",
                        "description": "A test agent",
                        "creator": "Test Creator",
                        "rating": 4.5,
                        "runs": 100,
                    },
                ),
            ],
            test_mock={
                "_search_agents": lambda *_, **__: SearchAgentsResponse(
                    agents=[
                        StoreAgentDict(
                            slug="test-agent",
                            name="Test Agent",
                            description="A test agent",
                            creator="Test Creator",
                            rating=4.5,
                            runs=100,
                        )
                    ],
                    total_count=1,
                )
            },
        )

    async def run(
        self,
        input_data: Input,
        **kwargs,
    ) -> BlockOutput:
        result = await self._search_agents(
            query=input_data.query,
            category=input_data.category,
            sort_by=input_data.sort_by,
            limit=input_data.limit,
        )

        agents = result.agents
        total_count = result.total_count

        # Convert to dict for output
        agents_as_dicts = [agent.model_dump() for agent in agents]

        yield "agents", agents_as_dicts
        yield "total_count", total_count

        for agent_dict in agents_as_dicts:
            yield "agent", agent_dict

    async def _search_agents(
        self,
        query: str | None = None,
        category: str | None = None,
        sort_by: str = "rating",
        limit: int = 10,
    ) -> SearchAgentsResponse:
        """
        Search for agents in the store using the existing store database function.
        """
        # Map our sort_by to the store's sorted_by parameter
        sorted_by_map = {
            "rating": "most_popular",
            "runs": "most_runs",
            "name": "alphabetical",
            "recent": "recently_updated",
        }

        result = await get_database_manager_async_client().get_store_agents(
            featured=False,
            creators=None,
            sorted_by=sorted_by_map.get(sort_by, "most_popular"),
            search_query=query,
            category=category,
            page=1,
            page_size=limit,
        )

        agents = [
            StoreAgentDict(
                slug=agent.slug,
                name=agent.agent_name,
                description=agent.description,
                creator=agent.creator,
                rating=agent.rating,
                runs=agent.runs,
            )
            for agent in result.agents
        ]

        return SearchAgentsResponse(agents=agents, total_count=len(agents))
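`SearchStoreAgentsBlock.run` converts the Pydantic results to plain dicts with `model_dump()` before yielding them. A trimmed-down sketch of that conversion (the model is reduced to three fields here purely for brevity):

```python
from pydantic import BaseModel


class StoreAgentDict(BaseModel):
    slug: str
    name: str
    rating: float


agent = StoreAgentDict(slug="test-agent", name="Test Agent", rating=4.5)
print(agent.model_dump())  # {'slug': 'test-agent', 'name': 'Test Agent', 'rating': 4.5}
```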
@@ -1,155 +0,0 @@
from unittest.mock import MagicMock

import pytest

from backend.blocks.system.library_operations import (
    AddToLibraryFromStoreBlock,
    LibraryAgent,
)
from backend.blocks.system.store_operations import (
    GetStoreAgentDetailsBlock,
    SearchAgentsResponse,
    SearchStoreAgentsBlock,
    StoreAgentDetails,
    StoreAgentDict,
)


@pytest.mark.asyncio
async def test_add_to_library_from_store_block_success(mocker):
    """Test successful addition of agent from store to library."""
    block = AddToLibraryFromStoreBlock()

    # Mock the library agent response
    mock_library_agent = MagicMock()
    mock_library_agent.id = "lib-agent-123"
    mock_library_agent.graph_id = "graph-456"
    mock_library_agent.graph_version = 1
    mock_library_agent.name = "Test Agent"

    mocker.patch.object(
        block,
        "_add_to_library",
        return_value=LibraryAgent(
            library_agent_id="lib-agent-123",
            agent_id="graph-456",
            agent_version=1,
            agent_name="Test Agent",
        ),
    )

    input_data = block.Input(
        store_listing_version_id="store-listing-v1", agent_name="Custom Agent Name"
    )

    outputs = {}
    async for name, value in block.run(input_data, user_id="test-user"):
        outputs[name] = value

    assert outputs["success"] is True
    assert outputs["library_agent_id"] == "lib-agent-123"
    assert outputs["agent_id"] == "graph-456"
    assert outputs["agent_version"] == 1
    assert outputs["agent_name"] == "Test Agent"
    assert outputs["message"] == "Agent successfully added to library"


@pytest.mark.asyncio
async def test_get_store_agent_details_block_success(mocker):
    """Test successful retrieval of store agent details."""
    block = GetStoreAgentDetailsBlock()

    mocker.patch.object(
        block,
        "_get_agent_details",
        return_value=StoreAgentDetails(
            found=True,
            store_listing_version_id="version-123",
            agent_name="Test Agent",
            description="A test agent for testing",
            creator="Test Creator",
            categories=["productivity", "automation"],
            runs=100,
            rating=4.5,
        ),
    )

    input_data = block.Input(creator="Test Creator", slug="test-slug")
    outputs = {}
    async for name, value in block.run(input_data):
        outputs[name] = value

    assert outputs["found"] is True
    assert outputs["store_listing_version_id"] == "version-123"
    assert outputs["agent_name"] == "Test Agent"
    assert outputs["description"] == "A test agent for testing"
    assert outputs["creator"] == "Test Creator"
    assert outputs["categories"] == ["productivity", "automation"]
    assert outputs["runs"] == 100
    assert outputs["rating"] == 4.5


@pytest.mark.asyncio
async def test_search_store_agents_block(mocker):
    """Test searching for store agents."""
    block = SearchStoreAgentsBlock()

    mocker.patch.object(
        block,
        "_search_agents",
        return_value=SearchAgentsResponse(
            agents=[
                StoreAgentDict(
                    slug="creator1/agent1",
                    name="Agent One",
                    description="First test agent",
                    creator="Creator 1",
                    rating=4.8,
                    runs=500,
                ),
                StoreAgentDict(
                    slug="creator2/agent2",
                    name="Agent Two",
                    description="Second test agent",
                    creator="Creator 2",
                    rating=4.2,
                    runs=200,
                ),
            ],
            total_count=2,
        ),
    )

    input_data = block.Input(
        query="test", category="productivity", sort_by="rating", limit=10
    )

    outputs = {}
    async for name, value in block.run(input_data):
        outputs[name] = value

    assert len(outputs["agents"]) == 2
    assert outputs["total_count"] == 2
    assert outputs["agents"][0]["name"] == "Agent One"
    assert outputs["agents"][0]["rating"] == 4.8


@pytest.mark.asyncio
async def test_search_store_agents_block_empty_results(mocker):
    """Test searching with no results."""
    block = SearchStoreAgentsBlock()

    mocker.patch.object(
        block,
        "_search_agents",
        return_value=SearchAgentsResponse(agents=[], total_count=0),
    )

    input_data = block.Input(query="nonexistent", limit=10)

    outputs = {}
    async for name, value in block.run(input_data):
        outputs[name] = value

    assert outputs["agents"] == []
    assert outputs["total_count"] == 0
@@ -1,5 +1,4 @@
 import asyncio
-import logging
 import time
 from datetime import datetime, timedelta
 from typing import Any, Literal, Union
@@ -8,7 +7,6 @@ from zoneinfo import ZoneInfo
 from pydantic import BaseModel

 from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
-from backend.data.execution import UserContext
 from backend.data.model import SchemaField

 # Shared timezone literal type for all time/date blocks
@@ -53,80 +51,16 @@ TimezoneLiteral = Literal[
     "Etc/GMT+12",  # UTC-12:00
 ]

-logger = logging.getLogger(__name__)
-
-
-def _get_timezone(
-    format_type: Any,  # Any format type with timezone and use_user_timezone attributes
-    user_timezone: str | None,
-) -> ZoneInfo:
-    """
-    Determine which timezone to use based on format settings and user context.
-
-    Args:
-        format_type: The format configuration containing timezone settings
-        user_timezone: The user's timezone from context
-
-    Returns:
-        ZoneInfo object for the determined timezone
-    """
-    if format_type.use_user_timezone and user_timezone:
-        tz = ZoneInfo(user_timezone)
-        logger.debug(f"Using user timezone: {user_timezone}")
-    else:
-        tz = ZoneInfo(format_type.timezone)
-        logger.debug(f"Using specified timezone: {format_type.timezone}")
-    return tz
-
-
-def _format_datetime_iso8601(dt: datetime, include_microseconds: bool = False) -> str:
-    """
-    Format a datetime object to ISO8601 string.
-
-    Args:
-        dt: The datetime object to format
-        include_microseconds: Whether to include microseconds in the output
-
-    Returns:
-        ISO8601 formatted string
-    """
-    if include_microseconds:
-        return dt.isoformat()
-    else:
-        return dt.isoformat(timespec="seconds")
-
-
-# BACKWARDS COMPATIBILITY NOTE:
-# The timezone field is kept at the format level (not block level) for backwards compatibility.
-# Existing graphs have timezone saved within format_type, moving it would break them.
-#
-# The use_user_timezone flag was added to allow using the user's profile timezone.
-# Default is False to maintain backwards compatibility - existing graphs will continue
-# using their specified timezone.
-#
-# KNOWN ISSUE: If a user switches between format types (strftime <-> iso8601),
-# the timezone setting doesn't carry over. This is a UX issue but fixing it would
-# require either:
-# 1. Moving timezone to block level (breaking change, needs migration)
-# 2. Complex state management to sync timezone across format types
-#
-# Future migration path: When we do a major version bump, consider moving timezone
-# to the block Input level for better UX.
-
-
 class TimeStrftimeFormat(BaseModel):
     discriminator: Literal["strftime"]
     format: str = "%H:%M:%S"
     timezone: TimezoneLiteral = "UTC"
-    # When True, overrides timezone with user's profile timezone
-    use_user_timezone: bool = False


 class TimeISO8601Format(BaseModel):
     discriminator: Literal["iso8601"]
     timezone: TimezoneLiteral = "UTC"
-    # When True, overrides timezone with user's profile timezone
-    use_user_timezone: bool = False
     include_microseconds: bool = False

@@ -181,27 +115,25 @@ class GetCurrentTimeBlock(Block):
             ],
         )

-    async def run(
-        self, input_data: Input, *, user_context: UserContext, **kwargs
-    ) -> BlockOutput:
-        # Extract timezone from user_context (always present)
-        effective_timezone = user_context.timezone
-
-        # Get the appropriate timezone
-        tz = _get_timezone(input_data.format_type, effective_timezone)
-        dt = datetime.now(tz=tz)
-
+    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
         if isinstance(input_data.format_type, TimeISO8601Format):
+            # ISO 8601 format for time only (extract time portion from full ISO datetime)
+            tz = ZoneInfo(input_data.format_type.timezone)
+            dt = datetime.now(tz=tz)

             # Get the full ISO format and extract just the time portion with timezone
-            full_iso = _format_datetime_iso8601(
-                dt, input_data.format_type.include_microseconds
-            )
+            if input_data.format_type.include_microseconds:
+                full_iso = dt.isoformat()
+            else:
+                full_iso = dt.isoformat(timespec="seconds")

             # Extract time portion (everything after 'T')
             current_time = full_iso.split("T")[1] if "T" in full_iso else full_iso
             current_time = f"T{current_time}"  # Add T prefix for ISO 8601 time format
         else:  # TimeStrftimeFormat
+            tz = ZoneInfo(input_data.format_type.timezone)
+            dt = datetime.now(tz=tz)
             current_time = dt.strftime(input_data.format_type.format)

         yield "time", current_time

@@ -209,15 +141,11 @@ class DateStrftimeFormat(BaseModel):
     discriminator: Literal["strftime"]
     format: str = "%Y-%m-%d"
     timezone: TimezoneLiteral = "UTC"
-    # When True, overrides timezone with user's profile timezone
-    use_user_timezone: bool = False


 class DateISO8601Format(BaseModel):
     discriminator: Literal["iso8601"]
     timezone: TimezoneLiteral = "UTC"
-    # When True, overrides timezone with user's profile timezone
-    use_user_timezone: bool = False


 class GetCurrentDateBlock(Block):
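The new code paths resolve the timezone with `ZoneInfo` and format with `datetime.isoformat`, passing `timespec="seconds"` when microseconds are not wanted. A standard-library sketch of both behaviours (the example zone is arbitrary; the blocks default to "UTC"):

```python
from datetime import datetime
from zoneinfo import ZoneInfo

tz = ZoneInfo("America/New_York")  # arbitrary example zone
dt = datetime.now(tz=tz)

print(dt.isoformat())                    # e.g. 2024-06-01T09:30:15.123456-04:00
print(dt.isoformat(timespec="seconds"))  # e.g. 2024-06-01T09:30:15-04:00

# GetCurrentTimeBlock keeps only the time portion, re-adding the leading "T":
full_iso = dt.isoformat(timespec="seconds")
time_part = full_iso.split("T")[1] if "T" in full_iso else full_iso
print(f"T{time_part}")                   # e.g. T09:30:15-04:00
```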
@@ -289,23 +217,20 @@ class GetCurrentDateBlock(Block):
         )

     async def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        # Extract timezone from user_context (required keyword argument)
-        user_context: UserContext = kwargs["user_context"]
-        effective_timezone = user_context.timezone
-
         try:
             offset = int(input_data.offset)
         except ValueError:
             offset = 0

-        # Get the appropriate timezone
-        tz = _get_timezone(input_data.format_type, effective_timezone)
-        current_date = datetime.now(tz=tz) - timedelta(days=offset)
-
         if isinstance(input_data.format_type, DateISO8601Format):
+            # ISO 8601 format for date only (YYYY-MM-DD)
+            tz = ZoneInfo(input_data.format_type.timezone)
+            current_date = datetime.now(tz=tz) - timedelta(days=offset)
             # ISO 8601 date format is YYYY-MM-DD
             date_str = current_date.date().isoformat()
         else:  # DateStrftimeFormat
+            tz = ZoneInfo(input_data.format_type.timezone)
+            current_date = datetime.now(tz=tz) - timedelta(days=offset)
             date_str = current_date.strftime(input_data.format_type.format)

         yield "date", date_str
@@ -315,15 +240,11 @@ class StrftimeFormat(BaseModel):
     discriminator: Literal["strftime"]
     format: str = "%Y-%m-%d %H:%M:%S"
    timezone: TimezoneLiteral = "UTC"
-    # When True, overrides timezone with user's profile timezone
-    use_user_timezone: bool = False


 class ISO8601Format(BaseModel):
     discriminator: Literal["iso8601"]
     timezone: TimezoneLiteral = "UTC"
-    # When True, overrides timezone with user's profile timezone
-    use_user_timezone: bool = False
     include_microseconds: bool = False

@@ -395,22 +316,20 @@ class GetCurrentDateAndTimeBlock(Block):
         )

     async def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        # Extract timezone from user_context (required keyword argument)
-        user_context: UserContext = kwargs["user_context"]
-        effective_timezone = user_context.timezone
-
-        # Get the appropriate timezone
-        tz = _get_timezone(input_data.format_type, effective_timezone)
-        dt = datetime.now(tz=tz)
-
         if isinstance(input_data.format_type, ISO8601Format):
             # ISO 8601 format with specified timezone (also RFC3339-compliant)
-            current_date_time = _format_datetime_iso8601(
-                dt, input_data.format_type.include_microseconds
-            )
-        else:  # StrftimeFormat
-            current_date_time = dt.strftime(input_data.format_type.format)
-
+            tz = ZoneInfo(input_data.format_type.timezone)
+            dt = datetime.now(tz=tz)
+
+            # Format with or without microseconds
+            if input_data.format_type.include_microseconds:
+                current_date_time = dt.isoformat()
+            else:
+                current_date_time = dt.isoformat(timespec="seconds")
+        else:  # StrftimeFormat
+            tz = ZoneInfo(input_data.format_type.timezone)
+            dt = datetime.now(tz=tz)
+            current_date_time = dt.strftime(input_data.format_type.format)
         yield "date_time", current_date_time

@@ -7,7 +7,7 @@ from prisma.models import CreditTransaction
 from backend.blocks.llm import AITextGeneratorBlock
 from backend.data.block import get_block
 from backend.data.credit import BetaUserCredit, UsageTransactionMetadata
-from backend.data.execution import NodeExecutionEntry, UserContext
+from backend.data.execution import NodeExecutionEntry
 from backend.data.user import DEFAULT_USER_ID
 from backend.executor.utils import block_usage_cost
 from backend.integrations.credentials_store import openai_credentials
@@ -75,7 +75,6 @@ async def test_block_credit_usage(server: SpinTestServer):
                     "type": openai_credentials.type,
                 },
            },
-            user_context=UserContext(timezone="UTC"),
        ),
    )
    assert spending_amount_1 > 0
@@ -89,7 +88,6 @@ async def test_block_credit_usage(server: SpinTestServer):
            node_exec_id="test_node_exec",
            block_id=AITextGeneratorBlock().id,
            inputs={"model": "gpt-4-turbo", "api_key": "owned_api_key"},
-            user_context=UserContext(timezone="UTC"),
        ),
    )
    assert spending_amount_2 == 0
@@ -39,7 +39,6 @@ from pydantic.fields import Field
 from backend.server.v2.store.exceptions import DatabaseError
 from backend.util import type as type_utils
 from backend.util.json import SafeJson
-from backend.util.models import Pagination
 from backend.util.retry import func_retry
 from backend.util.settings import Config
 from backend.util.truncate import truncate
@@ -90,7 +89,6 @@ ExecutionStatus = AgentExecutionStatus


 class GraphExecutionMeta(BaseDbModel):
-    id: str  # type: ignore # Override base class to make this required
     user_id: str
     graph_id: str
     graph_version: int
@@ -292,14 +290,13 @@ class GraphExecutionWithNodes(GraphExecution):
             node_executions=node_executions,
         )

-    def to_graph_execution_entry(self, user_context: "UserContext"):
+    def to_graph_execution_entry(self):
         return GraphExecutionEntry(
             user_id=self.user_id,
             graph_id=self.graph_id,
             graph_version=self.graph_version or 0,
             graph_exec_id=self.id,
             nodes_input_masks={},  # FIXME: store credentials on AgentGraphExecution
-            user_context=user_context,
         )

@@ -371,9 +368,7 @@ class NodeExecutionResult(BaseModel):
             end_time=_node_exec.endedTime,
         )

-    def to_node_execution_entry(
-        self, user_context: "UserContext"
-    ) -> "NodeExecutionEntry":
+    def to_node_execution_entry(self) -> "NodeExecutionEntry":
         return NodeExecutionEntry(
             user_id=self.user_id,
             graph_exec_id=self.graph_exec_id,
@@ -382,7 +377,6 @@ class NodeExecutionResult(BaseModel):
             node_id=self.node_id,
             block_id=self.block_id,
             inputs=self.input_data,
-            user_context=user_context,
         )

@@ -390,13 +384,13 @@ class NodeExecutionResult(BaseModel):


 async def get_graph_executions(
-    graph_exec_id: Optional[str] = None,
-    graph_id: Optional[str] = None,
-    user_id: Optional[str] = None,
-    statuses: Optional[list[ExecutionStatus]] = None,
-    created_time_gte: Optional[datetime] = None,
-    created_time_lte: Optional[datetime] = None,
-    limit: Optional[int] = None,
+    graph_exec_id: str | None = None,
+    graph_id: str | None = None,
+    user_id: str | None = None,
+    statuses: list[ExecutionStatus] | None = None,
+    created_time_gte: datetime | None = None,
+    created_time_lte: datetime | None = None,
+    limit: int | None = None,
 ) -> list[GraphExecutionMeta]:
     """⚠️ **Optional `user_id` check**: MUST USE check in user-facing endpoints."""
     where_filter: AgentGraphExecutionWhereInput = {
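The signature change above only swaps `typing.Optional[X]` for the equivalent `X | None` union syntax available since Python 3.10; behaviour is unchanged. A minimal illustration:

```python
from typing import Optional


def before(limit: Optional[int] = None) -> list[str]:
    return [] if limit is None else ["item"] * limit


def after(limit: int | None = None) -> list[str]:  # same meaning, PEP 604 syntax
    return [] if limit is None else ["item"] * limit


assert before(2) == after(2) == ["item", "item"]
```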
@@ -424,60 +418,6 @@ async def get_graph_executions(
     return [GraphExecutionMeta.from_db(execution) for execution in executions]


-class GraphExecutionsPaginated(BaseModel):
-    """Response schema for paginated graph executions."""
-
-    executions: list[GraphExecutionMeta]
-    pagination: Pagination
-
-
-async def get_graph_executions_paginated(
-    user_id: str,
-    graph_id: Optional[str] = None,
-    page: int = 1,
-    page_size: int = 25,
-    statuses: Optional[list[ExecutionStatus]] = None,
-    created_time_gte: Optional[datetime] = None,
-    created_time_lte: Optional[datetime] = None,
-) -> GraphExecutionsPaginated:
-    """Get paginated graph executions for a specific graph."""
-    where_filter: AgentGraphExecutionWhereInput = {
-        "isDeleted": False,
-        "userId": user_id,
-    }
-
-    if graph_id:
-        where_filter["agentGraphId"] = graph_id
-    if created_time_gte or created_time_lte:
-        where_filter["createdAt"] = {
-            "gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc),
-            "lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc),
-        }
-    if statuses:
-        where_filter["OR"] = [{"executionStatus": status} for status in statuses]
-
-    total_count = await AgentGraphExecution.prisma().count(where=where_filter)
-    total_pages = (total_count + page_size - 1) // page_size
-
-    offset = (page - 1) * page_size
-    executions = await AgentGraphExecution.prisma().find_many(
-        where=where_filter,
-        order={"createdAt": "desc"},
-        take=page_size,
-        skip=offset,
-    )
-
-    return GraphExecutionsPaginated(
-        executions=[GraphExecutionMeta.from_db(execution) for execution in executions],
-        pagination=Pagination(
-            total_items=total_count,
-            total_pages=total_pages,
-            current_page=page,
-            page_size=page_size,
-        ),
-    )
-
-
 async def get_graph_execution_meta(
     user_id: str, execution_id: str
 ) -> GraphExecutionMeta | None:
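The removed `get_graph_executions_paginated` helper used standard offset pagination: a ceiling division for the page count and `(page - 1) * page_size` for the query offset. A small sketch of that arithmetic (the function name is illustrative):

```python
def page_window(total_count: int, page: int, page_size: int) -> tuple[int, int]:
    """Return (total_pages, skip) the same way the deleted helper computed them."""
    total_pages = (total_count + page_size - 1) // page_size  # ceiling division
    skip = (page - 1) * page_size
    return total_pages, skip


assert page_window(total_count=51, page=3, page_size=25) == (3, 50)
```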
@@ -877,19 +817,12 @@ async def get_latest_node_execution(
 # ----------------- Execution Infrastructure ----------------- #


-class UserContext(BaseModel):
-    """Generic user context for graph execution containing user-specific settings."""
-
-    timezone: str
-
-
 class GraphExecutionEntry(BaseModel):
     user_id: str
     graph_exec_id: str
     graph_id: str
     graph_version: int
     nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None
-    user_context: UserContext


 class NodeExecutionEntry(BaseModel):
@@ -900,7 +833,6 @@ class NodeExecutionEntry(BaseModel):
     node_id: str
     block_id: str
     inputs: BlockInput
-    user_context: UserContext


 class ExecutionQueue(Generic[T]):
@@ -96,12 +96,6 @@ class User(BaseModel):
         default=True, description="Notify on monthly summary"
     )

-    # User timezone for scheduling and time display
-    timezone: str = Field(
-        default="not-set",
-        description="User timezone (IANA timezone identifier or 'not-set')",
-    )
-
     @classmethod
     def from_db(cls, prisma_user: "PrismaUser") -> "User":
         """Convert a database User object to application User model."""
@@ -155,7 +149,6 @@ class User(BaseModel):
             notify_on_daily_summary=prisma_user.notifyOnDailySummary or True,
             notify_on_weekly_summary=prisma_user.notifyOnWeeklySummary or True,
             notify_on_monthly_summary=prisma_user.notifyOnMonthlySummary or True,
-            timezone=prisma_user.timezone or "not-set",
         )

@@ -54,6 +54,19 @@ class AgentRunData(BaseNotificationData):


 class ZeroBalanceData(BaseNotificationData):
+    last_transaction: float
+    last_transaction_time: datetime
+    top_up_link: str
+
+    @field_validator("last_transaction_time")
+    @classmethod
+    def validate_timezone(cls, value: datetime):
+        if value.tzinfo is None:
+            raise ValueError("datetime must have timezone information")
+        return value
+
+
+class LowBalanceData(BaseNotificationData):
     agent_name: str = Field(..., description="Name of the agent")
     current_balance: float = Field(
         ..., description="Current balance in credits (100 = $1)"
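The fields added to `ZeroBalanceData` reuse the same `field_validator` pattern as the other notification models: reject naive datetimes so every timestamp carries a timezone. A self-contained sketch of that validator on a hypothetical model:

```python
from datetime import datetime, timezone

from pydantic import BaseModel, field_validator


class Event(BaseModel):
    occurred_at: datetime

    @field_validator("occurred_at")
    @classmethod
    def require_timezone(cls, value: datetime) -> datetime:
        if value.tzinfo is None:
            raise ValueError("datetime must have timezone information")
        return value


Event(occurred_at=datetime.now(timezone.utc))  # accepted
# Event(occurred_at=datetime.now())            # rejected: naive datetime
```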
@@ -62,13 +75,6 @@ class ZeroBalanceData(BaseNotificationData):
     shortfall: float = Field(..., description="Amount of credits needed to continue")


-class LowBalanceData(BaseNotificationData):
-    current_balance: float = Field(
-        ..., description="Current balance in credits (100 = $1)"
-    )
-    billing_page_link: str = Field(..., description="Link to billing page")
-
-
 class BlockExecutionFailedData(BaseNotificationData):
     block_name: str
     block_id: str
@@ -175,42 +181,6 @@ class RefundRequestData(BaseNotificationData):
     balance: int


-class AgentApprovalData(BaseNotificationData):
-    agent_name: str
-    agent_id: str
-    agent_version: int
-    reviewer_name: str
-    reviewer_email: str
-    comments: str
-    reviewed_at: datetime
-    store_url: str
-
-    @field_validator("reviewed_at")
-    @classmethod
-    def validate_timezone(cls, value: datetime):
-        if value.tzinfo is None:
-            raise ValueError("datetime must have timezone information")
-        return value
-
-
-class AgentRejectionData(BaseNotificationData):
-    agent_name: str
-    agent_id: str
-    agent_version: int
-    reviewer_name: str
-    reviewer_email: str
-    comments: str
-    reviewed_at: datetime
-    resubmit_url: str
-
-    @field_validator("reviewed_at")
-    @classmethod
-    def validate_timezone(cls, value: datetime):
-        if value.tzinfo is None:
-            raise ValueError("datetime must have timezone information")
-        return value
-
-
 NotificationData = Annotated[
     Union[
         AgentRunData,
@@ -270,8 +240,6 @@ def get_notif_data_type(
         NotificationType.MONTHLY_SUMMARY: MonthlySummaryData,
         NotificationType.REFUND_REQUEST: RefundRequestData,
         NotificationType.REFUND_PROCESSED: RefundRequestData,
-        NotificationType.AGENT_APPROVED: AgentApprovalData,
-        NotificationType.AGENT_REJECTED: AgentRejectionData,
     }[notification_type]

@@ -306,7 +274,7 @@ class NotificationTypeOverride:
             # These are batched by the notification service
             NotificationType.AGENT_RUN: QueueType.BATCH,
             # These are batched by the notification service, but with a backoff strategy
-            NotificationType.ZERO_BALANCE: QueueType.IMMEDIATE,
+            NotificationType.ZERO_BALANCE: QueueType.BACKOFF,
             NotificationType.LOW_BALANCE: QueueType.IMMEDIATE,
             NotificationType.BLOCK_EXECUTION_FAILED: QueueType.BACKOFF,
             NotificationType.CONTINUOUS_AGENT_ERROR: QueueType.BACKOFF,
@@ -315,8 +283,6 @@ class NotificationTypeOverride:
             NotificationType.MONTHLY_SUMMARY: QueueType.SUMMARY,
             NotificationType.REFUND_REQUEST: QueueType.ADMIN,
             NotificationType.REFUND_PROCESSED: QueueType.ADMIN,
-            NotificationType.AGENT_APPROVED: QueueType.IMMEDIATE,
-            NotificationType.AGENT_REJECTED: QueueType.IMMEDIATE,
         }
         return BATCHING_RULES.get(self.notification_type, QueueType.IMMEDIATE)
@@ -334,8 +300,6 @@ class NotificationTypeOverride:
             NotificationType.MONTHLY_SUMMARY: "monthly_summary.html",
             NotificationType.REFUND_REQUEST: "refund_request.html",
             NotificationType.REFUND_PROCESSED: "refund_processed.html",
-            NotificationType.AGENT_APPROVED: "agent_approved.html",
-            NotificationType.AGENT_REJECTED: "agent_rejected.html",
         }[self.notification_type]

     @property
@@ -351,8 +315,6 @@ class NotificationTypeOverride:
             NotificationType.MONTHLY_SUMMARY: "We did a lot this month!",
             NotificationType.REFUND_REQUEST: "[ACTION REQUIRED] You got a ${{data.amount / 100}} refund request from {{data.user_name}}",
             NotificationType.REFUND_PROCESSED: "Refund for ${{data.amount / 100}} to {{data.user_name}} has been processed",
-            NotificationType.AGENT_APPROVED: "🎉 Your agent '{{data.agent_name}}' has been approved!",
-            NotificationType.AGENT_REJECTED: "Your agent '{{data.agent_name}}' needs some updates",
         }[self.notification_type]

@@ -1,151 +0,0 @@
"""Tests for notification data models."""

from datetime import datetime, timezone

import pytest
from pydantic import ValidationError

from backend.data.notifications import AgentApprovalData, AgentRejectionData


class TestAgentApprovalData:
    """Test cases for AgentApprovalData model."""

    def test_valid_agent_approval_data(self):
        """Test creating valid AgentApprovalData."""
        data = AgentApprovalData(
            agent_name="Test Agent",
            agent_id="test-agent-123",
            agent_version=1,
            reviewer_name="John Doe",
            reviewer_email="john@example.com",
            comments="Great agent, approved!",
            reviewed_at=datetime.now(timezone.utc),
            store_url="https://app.autogpt.com/store/test-agent-123",
        )

        assert data.agent_name == "Test Agent"
        assert data.agent_id == "test-agent-123"
        assert data.agent_version == 1
        assert data.reviewer_name == "John Doe"
        assert data.reviewer_email == "john@example.com"
        assert data.comments == "Great agent, approved!"
        assert data.store_url == "https://app.autogpt.com/store/test-agent-123"
        assert data.reviewed_at.tzinfo is not None

    def test_agent_approval_data_without_timezone_raises_error(self):
        """Test that AgentApprovalData raises error without timezone."""
        with pytest.raises(
            ValidationError, match="datetime must have timezone information"
        ):
            AgentApprovalData(
                agent_name="Test Agent",
                agent_id="test-agent-123",
                agent_version=1,
                reviewer_name="John Doe",
                reviewer_email="john@example.com",
                comments="Great agent, approved!",
                reviewed_at=datetime.now(),  # No timezone
                store_url="https://app.autogpt.com/store/test-agent-123",
            )

    def test_agent_approval_data_with_empty_comments(self):
        """Test AgentApprovalData with empty comments."""
        data = AgentApprovalData(
            agent_name="Test Agent",
            agent_id="test-agent-123",
            agent_version=1,
            reviewer_name="John Doe",
            reviewer_email="john@example.com",
            comments="",  # Empty comments
            reviewed_at=datetime.now(timezone.utc),
            store_url="https://app.autogpt.com/store/test-agent-123",
        )

        assert data.comments == ""


class TestAgentRejectionData:
    """Test cases for AgentRejectionData model."""

    def test_valid_agent_rejection_data(self):
        """Test creating valid AgentRejectionData."""
        data = AgentRejectionData(
            agent_name="Test Agent",
            agent_id="test-agent-123",
            agent_version=1,
            reviewer_name="Jane Doe",
            reviewer_email="jane@example.com",
            comments="Please fix the security issues before resubmitting.",
            reviewed_at=datetime.now(timezone.utc),
            resubmit_url="https://app.autogpt.com/build/test-agent-123",
        )

        assert data.agent_name == "Test Agent"
        assert data.agent_id == "test-agent-123"
        assert data.agent_version == 1
        assert data.reviewer_name == "Jane Doe"
        assert data.reviewer_email == "jane@example.com"
        assert data.comments == "Please fix the security issues before resubmitting."
        assert data.resubmit_url == "https://app.autogpt.com/build/test-agent-123"
        assert data.reviewed_at.tzinfo is not None

    def test_agent_rejection_data_without_timezone_raises_error(self):
        """Test that AgentRejectionData raises error without timezone."""
        with pytest.raises(
            ValidationError, match="datetime must have timezone information"
        ):
            AgentRejectionData(
                agent_name="Test Agent",
                agent_id="test-agent-123",
                agent_version=1,
                reviewer_name="Jane Doe",
                reviewer_email="jane@example.com",
                comments="Please fix the security issues.",
                reviewed_at=datetime.now(),  # No timezone
                resubmit_url="https://app.autogpt.com/build/test-agent-123",
            )

    def test_agent_rejection_data_with_long_comments(self):
        """Test AgentRejectionData with long comments."""
        long_comment = "A" * 1000  # Very long comment
        data = AgentRejectionData(
            agent_name="Test Agent",
            agent_id="test-agent-123",
            agent_version=1,
            reviewer_name="Jane Doe",
            reviewer_email="jane@example.com",
            comments=long_comment,
            reviewed_at=datetime.now(timezone.utc),
            resubmit_url="https://app.autogpt.com/build/test-agent-123",
        )

        assert data.comments == long_comment

    def test_model_serialization(self):
        """Test that models can be serialized and deserialized."""
        original_data = AgentRejectionData(
            agent_name="Test Agent",
            agent_id="test-agent-123",
            agent_version=1,
            reviewer_name="Jane Doe",
            reviewer_email="jane@example.com",
            comments="Please fix the issues.",
            reviewed_at=datetime.now(timezone.utc),
            resubmit_url="https://app.autogpt.com/build/test-agent-123",
        )

        # Serialize to dict
        data_dict = original_data.model_dump()

        # Deserialize back
        restored_data = AgentRejectionData.model_validate(data_dict)

        assert restored_data.agent_name == original_data.agent_name
        assert restored_data.agent_id == original_data.agent_id
        assert restored_data.agent_version == original_data.agent_version
        assert restored_data.reviewer_name == original_data.reviewer_name
        assert restored_data.reviewer_email == original_data.reviewer_email
        assert restored_data.comments == original_data.comments
        assert restored_data.reviewed_at == original_data.reviewed_at
        assert restored_data.resubmit_url == original_data.resubmit_url
@@ -208,8 +208,6 @@ async def get_user_notification_preference(user_id: str) -> NotificationPreferen
         NotificationType.DAILY_SUMMARY: user.notifyOnDailySummary or False,
         NotificationType.WEEKLY_SUMMARY: user.notifyOnWeeklySummary or False,
         NotificationType.MONTHLY_SUMMARY: user.notifyOnMonthlySummary or False,
-        NotificationType.AGENT_APPROVED: user.notifyOnAgentApproved or False,
-        NotificationType.AGENT_REJECTED: user.notifyOnAgentRejected or False,
     }
     daily_limit = user.maxEmailsPerDay or 3
     notification_preference = NotificationPreference(
@@ -268,14 +266,6 @@ async def update_user_notification_preference(
         update_data["notifyOnMonthlySummary"] = data.preferences[
             NotificationType.MONTHLY_SUMMARY
         ]
-    if NotificationType.AGENT_APPROVED in data.preferences:
-        update_data["notifyOnAgentApproved"] = data.preferences[
-            NotificationType.AGENT_APPROVED
-        ]
-    if NotificationType.AGENT_REJECTED in data.preferences:
-        update_data["notifyOnAgentRejected"] = data.preferences[
-            NotificationType.AGENT_REJECTED
-        ]
     if data.daily_limit:
         update_data["maxEmailsPerDay"] = data.daily_limit

@@ -296,8 +286,6 @@ async def update_user_notification_preference(
         NotificationType.DAILY_SUMMARY: user.notifyOnDailySummary or True,
         NotificationType.WEEKLY_SUMMARY: user.notifyOnWeeklySummary or True,
         NotificationType.MONTHLY_SUMMARY: user.notifyOnMonthlySummary or True,
-        NotificationType.AGENT_APPROVED: user.notifyOnAgentApproved or True,
-        NotificationType.AGENT_REJECTED: user.notifyOnAgentRejected or True,
     }
     notification_preference = NotificationPreference(
         user_id=user.id,
@@ -396,17 +384,3 @@ async def unsubscribe_user_by_token(token: str) -> None:
         )
     except Exception as e:
         raise DatabaseError(f"Failed to unsubscribe user by token {token}: {e}") from e
-
-
-async def update_user_timezone(user_id: str, timezone: str) -> User:
-    """Update a user's timezone setting."""
-    try:
-        user = await PrismaUser.prisma().update(
-            where={"id": user_id},
-            data={"timezone": timezone},
-        )
-        if not user:
-            raise ValueError(f"User not found with ID: {user_id}")
-        return User.from_db(user)
-    except Exception as e:
-        raise DatabaseError(f"Failed to update timezone for user {user_id}: {e}") from e
@@ -42,8 +42,6 @@ from backend.data.user import (
     get_user_notification_preference,
     update_user_integrations,
 )
-from backend.server.v2.library.db import add_store_agent_to_library, list_library_agents
-from backend.server.v2.store.db import get_store_agent_details, get_store_agents
 from backend.util.service import (
     AppService,
     AppServiceClient,
@@ -147,14 +145,6 @@ class DatabaseManager(AppService):
         get_user_notification_oldest_message_in_batch
     )

-    # Library
-    list_library_agents = _(list_library_agents)
-    add_store_agent_to_library = _(add_store_agent_to_library)
-
-    # Store
-    get_store_agents = _(get_store_agents)
-    get_store_agent_details = _(get_store_agent_details)
-
     # Summary data - async
     get_user_execution_summary_data = _(get_user_execution_summary_data)

@@ -183,20 +173,12 @@ class DatabaseManagerClient(AppServiceClient):
     spend_credits = _(d.spend_credits)
     get_credits = _(d.get_credits)

+    # Summary data - async
+    get_user_execution_summary_data = _(d.get_user_execution_summary_data)

     # Block error monitoring
     get_block_error_stats = _(d.get_block_error_stats)

-    # User Emails
-    get_user_email_by_id = _(d.get_user_email_by_id)
-
-    # Library
-    list_library_agents = _(d.list_library_agents)
-    add_store_agent_to_library = _(d.add_store_agent_to_library)
-
-    # Store
-    get_store_agents = _(d.get_store_agents)
-    get_store_agent_details = _(d.get_store_agent_details)
-

 class DatabaseManagerAsyncClient(AppServiceClient):
     d = DatabaseManager
@@ -241,13 +223,5 @@ class DatabaseManagerAsyncClient(AppServiceClient):
         d.get_user_notification_oldest_message_in_batch
     )

-    # Library
-    list_library_agents = d.list_library_agents
-    add_store_agent_to_library = d.add_store_agent_to_library
-
-    # Store
-    get_store_agents = d.get_store_agents
-    get_store_agent_details = d.get_store_agent_details
-
     # Summary data
     get_user_execution_summary_data = d.get_user_execution_summary_data
@@ -20,7 +20,6 @@ from backend.data.notifications import (
     LowBalanceData,
     NotificationEventModel,
     NotificationType,
-    ZeroBalanceData,
 )
 from backend.data.rabbitmq import SyncRabbitMQ
 from backend.executor.activity_status_generator import (
@@ -52,7 +51,6 @@ from backend.data.execution import (
     GraphExecutionEntry,
     NodeExecutionEntry,
     NodeExecutionResult,
-    UserContext,
 )
 from backend.data.graph import Link, Node
 from backend.executor.utils import (
@@ -76,7 +74,6 @@ from backend.util.clients import (
     get_database_manager_async_client,
     get_database_manager_client,
     get_execution_event_bus,
-    get_notification_manager_client,
 )
 from backend.util.decorator import (
     async_error_logged,
@@ -86,7 +83,6 @@ from backend.util.decorator import (
 )
 from backend.util.file import clean_exec_files
 from backend.util.logging import TruncatedLogger, configure_logging
-from backend.util.metrics import DiscordChannel
 from backend.util.process import AppProcess, set_service_name
 from backend.util.retry import continuous_retry, func_retry
 from backend.util.settings import Settings
@@ -194,9 +190,6 @@ async def execute_node(
         "user_id": user_id,
     }

-    # Add user context from NodeExecutionEntry
-    extra_exec_kwargs["user_context"] = data.user_context
-
     # Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
     # changes during execution. ⚠️ This means a set of credentials can only be used by
     # one (running) block at a time; simultaneous execution of blocks using same
@@ -243,7 +236,6 @@ async def _enqueue_next_nodes(
     graph_id: str,
     log_metadata: LogMetadata,
     nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
-    user_context: UserContext,
 ) -> list[NodeExecutionEntry]:
     async def add_enqueued_execution(
         node_exec_id: str, node_id: str, block_id: str, data: BlockInput
@@ -262,7 +254,6 @@ async def _enqueue_next_nodes(
             node_id=node_id,
             block_id=block_id,
             inputs=data,
-            user_context=user_context,
         )

     async def register_next_executions(node_link: Link) -> list[NodeExecutionEntry]:
@@ -687,20 +678,19 @@ class ExecutionProcessor:
         self,
         node_exec: NodeExecutionEntry,
         execution_count: int,
-    ) -> tuple[int, int]:
+    ) -> int:
         total_cost = 0
-        remaining_balance = 0
         db_client = get_db_client()
         block = get_block(node_exec.block_id)
         if not block:
             logger.error(f"Block {node_exec.block_id} not found.")
-            return total_cost, 0
+            return total_cost

         cost, matching_filter = block_usage_cost(
             block=block, input_data=node_exec.inputs
         )
         if cost > 0:
-            remaining_balance = db_client.spend_credits(
+            db_client.spend_credits(
                 user_id=node_exec.user_id,
                 cost=cost,
                 metadata=UsageTransactionMetadata(
@@ -718,7 +708,7 @@ class ExecutionProcessor:

         cost, usage_count = execution_usage_cost(execution_count)
         if cost > 0:
-            remaining_balance = db_client.spend_credits(
+            db_client.spend_credits(
                 user_id=node_exec.user_id,
                 cost=cost,
                 metadata=UsageTransactionMetadata(
@@ -733,7 +723,7 @@ class ExecutionProcessor:
             )
             total_cost += cost

-        return total_cost, remaining_balance
+        return total_cost

     @time_measured
     def _on_graph_execution(
@@ -797,8 +787,7 @@ class ExecutionProcessor:
                 ExecutionStatus.TERMINATED,
             ],
         ):
-            node_entry = node_exec.to_node_execution_entry(graph_exec.user_context)
-            execution_queue.add(node_entry)
+            execution_queue.add(node_exec.to_node_execution_entry())

         # ------------------------------------------------------------
         # Main dispatch / polling loop -----------------------------
@@ -816,19 +805,12 @@ class ExecutionProcessor:

                 # Charge usage (may raise) ------------------------------
                 try:
-                    cost, remaining_balance = self._charge_usage(
+                    cost = self._charge_usage(
                         node_exec=queued_node_exec,
                         execution_count=increment_execution_count(graph_exec.user_id),
                     )
                     with execution_stats_lock:
                         execution_stats.cost += cost
-                    # Check if we crossed the low balance threshold
-                    self._handle_low_balance(
-                        db_client=db_client,
-                        user_id=graph_exec.user_id,
-                        current_balance=remaining_balance,
-                        transaction_cost=cost,
-                    )
                 except InsufficientBalanceError as balance_error:
                     error = balance_error  # Set error to trigger FAILED status
                     node_exec_id = queued_node_exec.node_exec_id
@@ -843,10 +825,11 @@ class ExecutionProcessor:
                         status=ExecutionStatus.FAILED,
                     )

-                    self._handle_insufficient_funds_notif(
+                    self._handle_low_balance_notif(
                         db_client,
                         graph_exec.user_id,
                         graph_exec.graph_id,
+                        execution_stats,
                         error,
                     )
                     # Gracefully stop the execution loop
@@ -1069,7 +1052,6 @@ class ExecutionProcessor:
         db_client = get_db_async_client()

         log_metadata.debug(f"Enqueue nodes for {node_id}: {output}")
-
         for next_execution in await _enqueue_next_nodes(
             db_client=db_client,
             node=output.node,
@@ -1079,7 +1061,6 @@ class ExecutionProcessor:
             graph_id=graph_exec.graph_id,
             log_metadata=log_metadata,
             nodes_input_masks=nodes_input_masks,
-            user_context=graph_exec.user_context,
         ):
             execution_queue.add(next_execution)

@@ -1120,25 +1101,25 @@ class ExecutionProcessor:
             )
         )

-    def _handle_insufficient_funds_notif(
+    def _handle_low_balance_notif(
         self,
         db_client: "DatabaseManagerClient",
         user_id: str,
         graph_id: str,
+        exec_stats: GraphExecutionStats,
         e: InsufficientBalanceError,
     ):
-        shortfall = abs(e.amount) - e.balance
+        shortfall = e.balance - e.amount
         metadata = db_client.get_graph_metadata(graph_id)
         base_url = (
             settings.config.frontend_base_url or settings.config.platform_base_url
         )

         queue_notification(
             NotificationEventModel(
                 user_id=user_id,
-                type=NotificationType.ZERO_BALANCE,
-                data=ZeroBalanceData(
-                    current_balance=e.balance,
+                type=NotificationType.LOW_BALANCE,
+                data=LowBalanceData(
+                    current_balance=exec_stats.cost,
                     billing_page_link=f"{base_url}/profile/credits",
                     shortfall=shortfall,
                     agent_name=metadata.name if metadata else "Unknown Agent",
@@ -1146,73 +1127,6 @@ class ExecutionProcessor:
             )
         )

-        try:
-            user_email = db_client.get_user_email_by_id(user_id)
-
-            alert_message = (
-                f"❌ **Insufficient Funds Alert**\n"
-                f"User: {user_email or user_id}\n"
-                f"Agent: {metadata.name if metadata else 'Unknown Agent'}\n"
-                f"Current balance: ${e.balance/100:.2f}\n"
-                f"Attempted cost: ${abs(e.amount)/100:.2f}\n"
-                f"Shortfall: ${abs(shortfall)/100:.2f}\n"
-                f"[View User Details]({base_url}/admin/spending?search={user_email})"
-            )
-
-            get_notification_manager_client().discord_system_alert(
-                alert_message, DiscordChannel.PRODUCT
-            )
-        except Exception as alert_error:
-            logger.error(
-                f"Failed to send insufficient funds Discord alert: {alert_error}"
-            )
-
-    def _handle_low_balance(
-        self,
-        db_client: "DatabaseManagerClient",
-        user_id: str,
-        current_balance: int,
-        transaction_cost: int,
-    ):
-        """Check and handle low balance scenarios after a transaction"""
-        LOW_BALANCE_THRESHOLD = settings.config.low_balance_threshold
-
-        balance_before = current_balance + transaction_cost
-
-        if (
-            current_balance < LOW_BALANCE_THRESHOLD
-            and balance_before >= LOW_BALANCE_THRESHOLD
-        ):
-            base_url = (
-                settings.config.frontend_base_url or settings.config.platform_base_url
-            )
-            queue_notification(
-                NotificationEventModel(
-                    user_id=user_id,
-                    type=NotificationType.LOW_BALANCE,
-                    data=LowBalanceData(
-                        current_balance=current_balance,
-                        billing_page_link=f"{base_url}/profile/credits",
-                    ),
-                )
-            )
-
-            try:
-                user_email = db_client.get_user_email_by_id(user_id)
-                alert_message = (
-                    f"⚠️ **Low Balance Alert**\n"
-                    f"User: {user_email or user_id}\n"
-                    f"Balance dropped below ${LOW_BALANCE_THRESHOLD/100:.2f}\n"
-                    f"Current balance: ${current_balance/100:.2f}\n"
-                    f"Transaction cost: ${transaction_cost/100:.2f}\n"
-                    f"[View User Details]({base_url}/admin/spending?search={user_email})"
-                )
-                get_notification_manager_client().discord_system_alert(
-                    alert_message, DiscordChannel.PRODUCT
-                )
-            except Exception as e:
-                logger.error(f"Failed to send low balance Discord alert: {e}")
-
-
 class ExecutionManager(AppProcess):
     def __init__(self):
@@ -1,149 +0,0 @@
-from unittest.mock import MagicMock, patch
-
-import pytest
-from prisma.enums import NotificationType
-
-from backend.data.notifications import LowBalanceData
-from backend.executor.manager import ExecutionProcessor
-from backend.util.test import SpinTestServer
-
-
-@pytest.mark.asyncio(loop_scope="session")
-async def test_handle_low_balance_threshold_crossing(server: SpinTestServer):
-    """Test that _handle_low_balance triggers notification when crossing threshold."""
-
-    execution_processor = ExecutionProcessor()
-    user_id = "test-user-123"
-    current_balance = 400  # $4 - below $5 threshold
-    transaction_cost = 600  # $6 transaction
-
-    # Mock dependencies
-    with patch(
-        "backend.executor.manager.queue_notification"
-    ) as mock_queue_notif, patch(
-        "backend.executor.manager.get_notification_manager_client"
-    ) as mock_get_client, patch(
-        "backend.executor.manager.settings"
-    ) as mock_settings:
-
-        # Setup mocks
-        mock_client = MagicMock()
-        mock_get_client.return_value = mock_client
-        mock_settings.config.low_balance_threshold = 500  # $5 threshold
-        mock_settings.config.frontend_base_url = "https://test.com"
-
-        # Create mock database client
-        mock_db_client = MagicMock()
-        mock_db_client.get_user_email_by_id.return_value = "test@example.com"
-
-        # Test the low balance handler
-        execution_processor._handle_low_balance(
-            db_client=mock_db_client,
-            user_id=user_id,
-            current_balance=current_balance,
-            transaction_cost=transaction_cost,
-        )
-
-        # Verify notification was queued
-        mock_queue_notif.assert_called_once()
-        notification_call = mock_queue_notif.call_args[0][0]
-
-        # Verify notification details
-        assert notification_call.type == NotificationType.LOW_BALANCE
-        assert notification_call.user_id == user_id
-        assert isinstance(notification_call.data, LowBalanceData)
-        assert notification_call.data.current_balance == current_balance
-
-        # Verify Discord alert was sent
-        mock_client.discord_system_alert.assert_called_once()
-        discord_message = mock_client.discord_system_alert.call_args[0][0]
-        assert "Low Balance Alert" in discord_message
-        assert "test@example.com" in discord_message
-        assert "$4.00" in discord_message
-        assert "$6.00" in discord_message
-
-
-@pytest.mark.asyncio(loop_scope="session")
-async def test_handle_low_balance_no_notification_when_not_crossing(
-    server: SpinTestServer,
-):
-    """Test that no notification is sent when not crossing the threshold."""
-
-    execution_processor = ExecutionProcessor()
-    user_id = "test-user-123"
-    current_balance = 600  # $6 - above $5 threshold
-    transaction_cost = (
-        100  # $1 transaction (balance before was $7, still above threshold)
-    )
-
-    # Mock dependencies
-    with patch(
-        "backend.executor.manager.queue_notification"
-    ) as mock_queue_notif, patch(
-        "backend.executor.manager.get_notification_manager_client"
-    ) as mock_get_client, patch(
-        "backend.executor.manager.settings"
-    ) as mock_settings:
-
-        # Setup mocks
-        mock_client = MagicMock()
-        mock_get_client.return_value = mock_client
-        mock_settings.config.low_balance_threshold = 500  # $5 threshold
-
-        # Create mock database client
-        mock_db_client = MagicMock()
-
-        # Test the low balance handler
-        execution_processor._handle_low_balance(
-            db_client=mock_db_client,
-            user_id=user_id,
-            current_balance=current_balance,
-            transaction_cost=transaction_cost,
-        )
-
-        # Verify no notification was sent
-        mock_queue_notif.assert_not_called()
-        mock_client.discord_system_alert.assert_not_called()
-
-
-@pytest.mark.asyncio(loop_scope="session")
-async def test_handle_low_balance_no_duplicate_when_already_below(
-    server: SpinTestServer,
-):
-    """Test that no notification is sent when already below threshold."""
-
-    execution_processor = ExecutionProcessor()
-    user_id = "test-user-123"
-    current_balance = 300  # $3 - below $5 threshold
-    transaction_cost = (
-        100  # $1 transaction (balance before was $4, also below threshold)
-    )
-
-    # Mock dependencies
-    with patch(
-        "backend.executor.manager.queue_notification"
-    ) as mock_queue_notif, patch(
-        "backend.executor.manager.get_notification_manager_client"
-    ) as mock_get_client, patch(
-        "backend.executor.manager.settings"
-    ) as mock_settings:
-
-        # Setup mocks
-        mock_client = MagicMock()
-        mock_get_client.return_value = mock_client
-        mock_settings.config.low_balance_threshold = 500  # $5 threshold
-
-        # Create mock database client
-        mock_db_client = MagicMock()
-
-        # Test the low balance handler
-        execution_processor._handle_low_balance(
-            db_client=mock_db_client,
-            user_id=user_id,
-            current_balance=current_balance,
-            transaction_cost=transaction_cost,
-        )
-
-        # Verify no notification was sent (user was already below threshold)
-        mock_queue_notif.assert_not_called()
-        mock_client.discord_system_alert.assert_not_called()
@@ -17,7 +17,6 @@ from apscheduler.jobstores.memory import MemoryJobStore
 from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
 from apscheduler.schedulers.background import BackgroundScheduler
 from apscheduler.triggers.cron import CronTrigger
-from apscheduler.util import ZoneInfo
 from dotenv import load_dotenv
 from pydantic import BaseModel, Field, ValidationError
 from sqlalchemy import MetaData, create_engine
@@ -304,7 +303,6 @@ class Scheduler(AppService):
                 Jobstores.WEEKLY_NOTIFICATIONS.value: MemoryJobStore(),
             },
             logger=apscheduler_logger,
-            timezone=ZoneInfo("UTC"),
         )

         if self.register_system_tasks:
@@ -408,8 +406,6 @@ class Scheduler(AppService):
             )
         )

-        logger.info(f"Scheduling job for user {user_id} in UTC (cron: {cron})")
-
         job_args = GraphExecutionJobArgs(
             user_id=user_id,
             graph_id=graph_id,
@@ -422,12 +418,12 @@ class Scheduler(AppService):
             execute_graph,
             kwargs=job_args.model_dump(),
             name=name,
-            trigger=CronTrigger.from_crontab(cron, timezone="UTC"),
+            trigger=CronTrigger.from_crontab(cron),
             jobstore=Jobstores.EXECUTION.value,
             replace_existing=True,
         )
         logger.info(
-            f"Added job {job.id} with cron schedule '{cron}' in UTC, input data: {input_data}"
+            f"Added job {job.id} with cron schedule '{cron}' input data: {input_data}"
        )
         return GraphExecutionJobInfo.from_db(job_args, job)

@@ -18,12 +18,10 @@ from backend.data.execution import (
     ExecutionStatus,
     GraphExecutionStats,
     GraphExecutionWithNodes,
-    UserContext,
 )
 from backend.data.graph import GraphModel, Node
 from backend.data.model import CredentialsMetaInput
 from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
-from backend.data.user import get_user_by_id
 from backend.util.clients import (
     get_async_execution_event_bus,
     get_async_execution_queue,
@@ -36,27 +34,6 @@ from backend.util.mock import MockObject
 from backend.util.settings import Config
 from backend.util.type import convert

-
-async def get_user_context(user_id: str) -> UserContext:
-    """
-    Get UserContext for a user, always returns a valid context with timezone.
-    Defaults to UTC if user has no timezone set.
-    """
-    user_context = UserContext(timezone="UTC")  # Default to UTC
-    try:
-        user = await get_user_by_id(user_id)
-        if user and user.timezone and user.timezone != "not-set":
-            user_context.timezone = user.timezone
-            logger.debug(f"Retrieved user context: timezone={user.timezone}")
-        else:
-            logger.debug("User has no timezone set, using UTC")
-    except Exception as e:
-        logger.warning(f"Could not fetch user timezone: {e}")
-        # Continue with UTC as default
-
-    return user_context
-
-
 config = Config()
 logger = TruncatedLogger(logging.getLogger(__name__), prefix="[GraphExecutorUtil]")

@@ -900,11 +877,8 @@ async def add_graph_execution(
         preset_id=preset_id,
     )

-    # Fetch user context for the graph execution
-    user_context = await get_user_context(user_id)
-
     queue = await get_async_execution_queue()
-    graph_exec_entry = graph_exec.to_graph_execution_entry(user_context)
+    graph_exec_entry = graph_exec.to_graph_execution_entry()
     if nodes_input_masks:
         graph_exec_entry.nodes_input_masks = nodes_input_masks

@@ -4,7 +4,6 @@ from pydantic import BaseModel

 from backend.integrations.oauth.todoist import TodoistOAuthHandler

-from .discord import DiscordOAuthHandler
 from .github import GitHubOAuthHandler
 from .google import GoogleOAuthHandler
 from .notion import NotionOAuthHandler
@@ -16,7 +15,6 @@ if TYPE_CHECKING:
 # --8<-- [start:HANDLERS_BY_NAMEExample]
 # Build handlers dict with string keys for compatibility with SDK auto-registration
 _ORIGINAL_HANDLERS = [
-    DiscordOAuthHandler,
     GitHubOAuthHandler,
     GoogleOAuthHandler,
     NotionOAuthHandler,
@@ -1,175 +0,0 @@
-import time
-from typing import Optional
-from urllib.parse import urlencode
-
-from backend.data.model import OAuth2Credentials
-from backend.integrations.providers import ProviderName
-from backend.util.request import Requests
-
-from .base import BaseOAuthHandler
-
-
-class DiscordOAuthHandler(BaseOAuthHandler):
-    """
-    Discord OAuth2 handler implementation.
-
-    Based on the documentation at:
-    - https://discord.com/developers/docs/topics/oauth2
-
-    Discord OAuth2 tokens expire after 7 days by default and include refresh tokens.
-    """
-
-    PROVIDER_NAME = ProviderName.DISCORD
-    DEFAULT_SCOPES = ["identify"]  # Basic user information
-
-    def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
-        self.client_id = client_id
-        self.client_secret = client_secret
-        self.redirect_uri = redirect_uri
-        self.auth_base_url = "https://discord.com/oauth2/authorize"
-        self.token_url = "https://discord.com/api/oauth2/token"
-        self.revoke_url = "https://discord.com/api/oauth2/token/revoke"
-
-    def get_login_url(
-        self, scopes: list[str], state: str, code_challenge: Optional[str]
-    ) -> str:
-        # Handle default scopes
-        scopes = self.handle_default_scopes(scopes)
-
-        params = {
-            "client_id": self.client_id,
-            "redirect_uri": self.redirect_uri,
-            "response_type": "code",
-            "scope": " ".join(scopes),
-            "state": state,
-        }
-
-        # Discord supports PKCE
-        if code_challenge:
-            params["code_challenge"] = code_challenge
-            params["code_challenge_method"] = "S256"
-
-        return f"{self.auth_base_url}?{urlencode(params)}"
-
-    async def exchange_code_for_tokens(
-        self, code: str, scopes: list[str], code_verifier: Optional[str]
-    ) -> OAuth2Credentials:
-        params = {
-            "code": code,
-            "redirect_uri": self.redirect_uri,
-            "grant_type": "authorization_code",
-        }
-
-        # Include PKCE verifier if provided
-        if code_verifier:
-            params["code_verifier"] = code_verifier
-
-        return await self._request_tokens(params)
-
-    async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
-        if not credentials.access_token:
-            raise ValueError("No access token to revoke")
-
-        # Discord requires client authentication for token revocation
-        data = {
-            "token": credentials.access_token.get_secret_value(),
-            "token_type_hint": "access_token",
-        }
-
-        headers = {
-            "Content-Type": "application/x-www-form-urlencoded",
-        }
-
-        response = await Requests().post(
-            url=self.revoke_url,
-            data=data,
-            headers=headers,
-            auth=(self.client_id, self.client_secret),
-        )
-
-        # Discord returns 200 OK for successful revocation
-        return response.status == 200
-
-    async def _refresh_tokens(
-        self, credentials: OAuth2Credentials
-    ) -> OAuth2Credentials:
-        if not credentials.refresh_token:
-            return credentials
-
-        return await self._request_tokens(
-            {
-                "refresh_token": credentials.refresh_token.get_secret_value(),
-                "grant_type": "refresh_token",
-            },
-            current_credentials=credentials,
-        )
-
-    async def _request_tokens(
-        self,
-        params: dict[str, str],
-        current_credentials: Optional[OAuth2Credentials] = None,
-    ) -> OAuth2Credentials:
-        request_body = {
-            "client_id": self.client_id,
-            "client_secret": self.client_secret,
-            **params,
-        }
-
-        headers = {
-            "Content-Type": "application/x-www-form-urlencoded",
-        }
-
-        response = await Requests().post(
-            self.token_url, data=request_body, headers=headers
-        )
-        token_data: dict = response.json()
-
-        # Get username if this is a new token request
-        username = None
-        if "access_token" in token_data:
-            username = await self._request_username(token_data["access_token"])
-
-        now = int(time.time())
-        new_credentials = OAuth2Credentials(
-            provider=self.PROVIDER_NAME,
-            title=current_credentials.title if current_credentials else None,
-            username=username,
-            access_token=token_data["access_token"],
-            scopes=token_data.get("scope", "").split()
-            or (current_credentials.scopes if current_credentials else []),
-            refresh_token=token_data.get("refresh_token"),
-            # Discord tokens expire after expires_in seconds (typically 7 days)
-            access_token_expires_at=(
-                now + expires_in
-                if (expires_in := token_data.get("expires_in", None))
-                else None
-            ),
-            # Discord doesn't provide separate refresh token expiration
-            refresh_token_expires_at=None,
-        )
-
-        if current_credentials:
-            new_credentials.id = current_credentials.id
-
-        return new_credentials
-
-    async def _request_username(self, access_token: str) -> str | None:
-        """
-        Fetch the username using the Discord OAuth2 @me endpoint.
-        """
-        url = "https://discord.com/api/oauth2/@me"
-        headers = {
-            "Authorization": f"Bearer {access_token}",
-        }
-
-        response = await Requests().get(url, headers=headers)
-
-        if not response.ok:
-            return None
-
-        # Get user info from the response
-        data = response.json()
-        user_info = data.get("user", {})
-
-        # Return username (without discriminator)
-        return user_info.get("username")
@@ -29,7 +29,7 @@ from backend.data.user import generate_unsubscribe_link
 from backend.notifications.email import EmailSender
 from backend.util.clients import get_database_manager_async_client
 from backend.util.logging import TruncatedLogger
-from backend.util.metrics import DiscordChannel, discord_send_alert
+from backend.util.metrics import discord_send_alert
 from backend.util.retry import continuous_retry
 from backend.util.service import (
     AppService,
@@ -382,10 +382,8 @@ class NotificationManager(AppService):
         }

     @expose
-    async def discord_system_alert(
-        self, content: str, channel: DiscordChannel = DiscordChannel.PLATFORM
-    ):
-        await discord_send_alert(content, channel)
+    async def discord_system_alert(self, content: str):
+        await discord_send_alert(content)

     async def _queue_scheduled_notification(self, event: SummaryParamsEventModel):
         """Queue a scheduled notification - exposed method for other services to call"""
@@ -1,73 +0,0 @@
-{# Agent Approved Notification Email Template #}
-{#
-Template variables:
-data.agent_name: the name of the approved agent
-data.agent_id: the ID of the agent
-data.agent_version: the version of the agent
-data.reviewer_name: the name of the reviewer who approved it
-data.reviewer_email: the email of the reviewer
-data.comments: comments from the reviewer
-data.reviewed_at: when the agent was reviewed
-data.store_url: URL to view the agent in the store
-
-Subject: 🎉 Your agent '{{ data.agent_name }}' has been approved!
-#}
-
-{% block content %}
-<h1 style="color: #28a745; font-size: 32px; font-weight: 700; margin: 0 0 24px 0; text-align: center;">
-  🎉 Congratulations!
-</h1>
-
-<p style="color: #586069; font-size: 18px; text-align: center; margin: 0 0 24px 0;">
-  Your agent <strong>'{{ data.agent_name }}'</strong> has been approved and is now live in the store!
-</p>
-
-<div style="height: 32px; background: transparent;"></div>
-
-{% if data.comments %}
-<div style="background: #d4edda; border: 1px solid #c3e6cb; border-radius: 8px; padding: 20px; margin: 0;">
-  <h3 style="color: #155724; font-size: 16px; font-weight: 600; margin: 0 0 12px 0;">
-    💬 Creator feedback area
-  </h3>
-  <p style="color: #155724; margin: 0; font-size: 16px; line-height: 1.5;">
-    {{ data.comments }}
-  </p>
-</div>
-
-<div style="height: 40px; background: transparent;"></div>
-{% endif %}
-
-<div style="background: #d1ecf1; border: 1px solid #bee5eb; border-radius: 8px; padding: 20px; margin: 0;">
-  <h3 style="color: #0c5460; font-size: 16px; font-weight: 600; margin: 0 0 12px 0;">
-    What's Next?
-  </h3>
-  <ul style="color: #0c5460; margin: 0; padding-left: 18px; font-size: 16px; line-height: 1.6;">
-    <li>Your agent is now live and discoverable in the AutoGPT Store</li>
-    <li>Users can find, install, and run your agent</li>
-    <li>You can update your agent anytime by submitting a new version</li>
-  </ul>
-</div>
-
-<div style="height: 32px; background: transparent;"></div>
-
-<div style="text-align: center; margin: 24px 0;">
-  <a href="{{ data.store_url }}" style="display: inline-block; background: linear-gradient(135deg, #7c3aed 0%, #5b21b6 100%); color: black; text-decoration: none; padding: 14px 28px; border-radius: 6px; font-weight: 600; font-size: 16px;">
-    View Your Agent in Store
-  </a>
-</div>
-
-<div style="height: 32px; background: transparent;"></div>
-
-<div style="background: #fff3cd; border: 1px solid #ffeaa7; border-radius: 6px; padding: 16px; margin: 24px 0; text-align: center;">
-  <p style="margin: 0; color: #856404; font-size: 14px;">
-    <strong>💡 Pro Tip:</strong> Share your agent with the community! Post about it on social media, forums, or your blog to help more users discover and benefit from your creation.
-  </p>
-</div>
-
-<div style="height: 32px; background: transparent;"></div>
-
-<p style="color: #6a737d; font-size: 14px; text-align: center; margin: 24px 0;">
-  Thank you for contributing to the AutoGPT ecosystem! 🚀
-</p>
-
-{% endblock %}
@@ -1,77 +0,0 @@
-{# Agent Rejected Notification Email Template #}
-{#
-Template variables:
-data.agent_name: the name of the rejected agent
-data.agent_id: the ID of the agent
-data.agent_version: the version of the agent
-data.reviewer_name: the name of the reviewer who rejected it
-data.reviewer_email: the email of the reviewer
-data.comments: comments from the reviewer explaining the rejection
-data.reviewed_at: when the agent was reviewed
-data.resubmit_url: URL to resubmit the agent
-
-Subject: Your agent '{{ data.agent_name }}' needs some updates
-#}
-
-
-{% block content %}
-<h1 style="color: #d73a49; font-size: 32px; font-weight: 700; margin: 0 0 24px 0; text-align: center;">
-  📝 Review Complete
-</h1>
-
-<p style="color: #586069; font-size: 18px; text-align: center; margin: 0 0 24px 0;">
-  Your agent <strong>'{{ data.agent_name }}'</strong> needs some updates before approval.
-</p>
-
-<div style="height: 32px; background: transparent;"></div>
-
-<div style="background: #f8d7da; border: 1px solid #f5c6cb; border-radius: 8px; padding: 20px; margin: 0 0 24px 0;">
-  <h3 style="color: #721c24; font-size: 16px; font-weight: 600; margin: 0 0 12px 0;">
-    💬 Creator feedback area
-  </h3>
-  <p style="color: #721c24; margin: 0; font-size: 16px; line-height: 1.5;">
-    {{ data.comments }}
-  </p>
-</div>
-
-<div style="height: 40px; background: transparent;"></div>
-
-<div style="background: #d4edda; border: 1px solid #c3e6cb; border-radius: 8px; padding: 20px; margin: 0 0 24px 0;">
-  <h3 style="color: #155724; font-size: 16px; font-weight: 600; margin: 0 0 12px 0;">
-    ☑ Steps to Resubmit:
-  </h3>
-  <ul style="color: #155724; margin: 0; padding-left: 18px; font-size: 16px; line-height: 1.6;">
-    <li>Review the feedback provided above carefully</li>
-    <li>Make the necessary updates to your agent</li>
-    <li>Test your agent thoroughly to ensure it works as expected</li>
-    <li>Submit your updated agent for review</li>
-  </ul>
-</div>
-
-<div style="height: 32px; background: transparent;"></div>
-
-<div style="background: #fff3cd; border: 1px solid #ffeaa7; border-radius: 6px; padding: 12px; margin: 0 0 24px 0; text-align: center;">
-  <p style="margin: 0; color: #856404; font-size: 14px;">
-    <strong>💡 Tip:</strong> Address all the points mentioned in the feedback to increase your chances of approval in the next review.
-  </p>
-</div>
-
-<div style="text-align: center; margin: 32px 0;">
-  <a href="{{ data.resubmit_url }}" style="display: inline-block; background: linear-gradient(135deg, #7c3aed 0%, #5b21b6 100%); color: black; text-decoration: none; padding: 14px 28px; border-radius: 6px; font-weight: 600; font-size: 16px;">
-    Update & Resubmit Agent
-  </a>
-</div>
-
-<div style="background: #d1ecf1; border: 1px solid #bee5eb; border-radius: 6px; padding: 16px; margin: 24px 0;">
-  <p style="margin: 0; color: #0c5460; font-size: 14px; text-align: center;">
-    <strong>🌟 Don't Give Up!</strong> Many successful agents go through multiple iterations before approval. Our review team is here to help you succeed!
-  </p>
-</div>
-
-<div style="height: 32px; background: transparent;"></div>
-
-<p style="color: #6a737d; font-size: 14px; text-align: center; margin: 32px 0 24px 0;">
-  We're excited to see your improved agent submission! 🚀
-</p>
-
-{% endblock %}
@@ -1,7 +1,9 @@
 {# Low Balance Notification Email Template #}
 {# Template variables:
+data.agent_name: the name of the agent
 data.current_balance: the current balance of the user
 data.billing_page_link: the link to the billing page
+data.shortfall: the shortfall amount
 #}

 <p style="
@@ -23,7 +25,7 @@ data.billing_page_link: the link to the billing page
   margin-top: 0;
   margin-bottom: 20px;
 ">
-  Your account balance has dropped below the recommended threshold.
+  Your agent "<strong>{{ data.agent_name }}</strong>" has been stopped due to low balance.
 </p>

 <div style="
@@ -42,6 +44,15 @@ data.billing_page_link: the link to the billing page
   ">
     <strong>Current Balance:</strong> ${{ "{:.2f}".format((data.current_balance|float)/100) }}
   </p>
+  <p style="
+    font-family: 'Poppins', sans-serif;
+    color: #070629;
+    font-size: 16px;
+    margin-top: 0;
+    margin-bottom: 10px;
+  ">
+    <strong>Shortfall:</strong> ${{ "{:.2f}".format((data.shortfall|float)/100) }}
+  </p>
 </div>


@@ -68,7 +79,7 @@ data.billing_page_link: the link to the billing page
   margin-top: 0;
   margin-bottom: 5px;
 ">
-  Your account requires additional credits to continue running agents. Please add credits to your account to avoid service interruption.
+  Your agent "<strong>{{ data.agent_name }}</strong>" requires additional credits to continue running. The current operation has been canceled until your balance is replenished.
 </p>
 </div>

@@ -99,5 +110,5 @@ data.billing_page_link: the link to the billing page
   margin-bottom: 10px;
   font-style: italic;
 ">
-  This is an automated low balance notification. Consider adding credits soon to avoid service interruption.
+  This is an automated notification. Your agent is stopped and will need manually restarted unless set to trigger automatically.
 </p>
@@ -1,114 +0,0 @@
-{# Low Balance Notification Email Template #}
-{# Template variables:
-data.agent_name: the name of the agent
-data.current_balance: the current balance of the user
-data.billing_page_link: the link to the billing page
-data.shortfall: the shortfall amount
-#}
-
-<p style="
-  font-family: 'Poppins', sans-serif;
-  color: #070629;
-  font-size: 16px;
-  line-height: 165%;
-  margin-top: 0;
-  margin-bottom: 10px;
-">
-  <strong>Zero Balance Warning</strong>
-</p>
-
-<p style="
-  font-family: 'Poppins', sans-serif;
-  color: #070629;
-  font-size: 16px;
-  line-height: 165%;
-  margin-top: 0;
-  margin-bottom: 20px;
-">
-  Your agent "<strong>{{ data.agent_name }}</strong>" has been stopped due to low balance.
-</p>
-
-<div style="
-  margin-left: 15px;
-  margin-bottom: 20px;
-  padding: 15px;
-  border-left: 4px solid #5D23BB;
-  background-color: #f8f8ff;
-">
-  <p style="
-    font-family: 'Poppins', sans-serif;
-    color: #070629;
-    font-size: 16px;
-    margin-top: 0;
-    margin-bottom: 10px;
-  ">
-    <strong>Current Balance:</strong> ${{ "{:.2f}".format((data.current_balance|float)/100) }}
-  </p>
-  <p style="
-    font-family: 'Poppins', sans-serif;
-    color: #070629;
-    font-size: 16px;
-    margin-top: 0;
-    margin-bottom: 10px;
-  ">
-    <strong>Shortfall:</strong> ${{ "{:.2f}".format((data.shortfall|float)/100) }}
-  </p>
-</div>
-
-
-<div style="
-  margin-left: 15px;
-  margin-bottom: 20px;
-  padding: 15px;
-  border-left: 4px solid #FF6B6B;
-  background-color: #FFF0F0;
-">
-  <p style="
-    font-family: 'Poppins', sans-serif;
-    color: #070629;
-    font-size: 16px;
-    margin-top: 0;
-    margin-bottom: 10px;
-  ">
-    <strong>Low Balance:</strong>
-  </p>
-  <p style="
-    font-family: 'Poppins', sans-serif;
-    color: #070629;
-    font-size: 16px;
-    margin-top: 0;
-    margin-bottom: 5px;
-  ">
-    Your agent "<strong>{{ data.agent_name }}</strong>" requires additional credits to continue running. The current operation has been canceled until your balance is replenished.
-  </p>
-</div>
-
-<div style="
-  text-align: center;
-  margin: 30px 0;
-">
-  <a href="{{ data.billing_page_link }}" style="
-    font-family: 'Poppins', sans-serif;
-    background-color: #5D23BB;
-    color: white;
-    padding: 12px 24px;
-    text-decoration: none;
-    border-radius: 4px;
-    font-weight: 500;
-    display: inline-block;
-  ">
-    Manage Billing
-  </a>
-</div>
-
-<p style="
-  font-family: 'Poppins', sans-serif;
-  color: #070629;
-  font-size: 16px;
-  line-height: 150%;
-  margin-top: 30px;
-  margin-bottom: 10px;
-  font-style: italic;
-">
-  This is an automated notification. Your agent is stopped and will need manually restarted unless set to trigger automatically.
-</p>
@@ -14,8 +14,6 @@ from backend.data.model import (
     CredentialsField,
     CredentialsMetaInput,
     CredentialsType,
-    OAuth2Credentials,
-    UserPasswordCredentials,
 )
 from backend.integrations.oauth.base import BaseOAuthHandler
 from backend.integrations.webhooks._base import BaseWebhooksManager
@@ -106,39 +104,14 @@ class Provider:
         )

     def get_test_credentials(self) -> Credentials:
-        """Get test credentials for the provider based on supported auth types."""
-        test_id = str(self.test_credentials_uuid)
-
-        # Return credentials based on the first supported auth type
-        if "user_password" in self.supported_auth_types:
-            return UserPasswordCredentials(
-                id=test_id,
-                provider=self.name,
-                username=SecretStr(f"mock-{self.name}-username"),
-                password=SecretStr(f"mock-{self.name}-password"),
-                title=f"Mock {self.name.title()} credentials",
-            )
-        elif "oauth2" in self.supported_auth_types:
-            return OAuth2Credentials(
-                id=test_id,
-                provider=self.name,
-                username=f"mock-{self.name}-username",
-                access_token=SecretStr(f"mock-{self.name}-access-token"),
-                access_token_expires_at=None,
-                refresh_token=SecretStr(f"mock-{self.name}-refresh-token"),
-                refresh_token_expires_at=None,
-                scopes=[f"mock-{self.name}-scope"],
-                title=f"Mock {self.name.title()} OAuth credentials",
-            )
-        else:
-            # Default to API key credentials
-            return APIKeyCredentials(
-                id=test_id,
-                provider=self.name,
-                api_key=SecretStr(f"mock-{self.name}-api-key"),
-                title=f"Mock {self.name.title()} API key",
-                expires_at=None,
-            )
+        """Get test credentials for the provider."""
+        return APIKeyCredentials(
+            id=str(self.test_credentials_uuid),
+            provider=self.name,
+            api_key=SecretStr("mock-api-key"),
+            title=f"Mock {self.name.title()} API key",
+            expires_at=None,
+        )

     def get_api(self, credentials: Credentials) -> Any:
         """Get API client instance for the given credentials."""
@@ -5,7 +5,6 @@ import pydantic
 
 from backend.data.api_key import APIKeyPermission, APIKeyWithoutHash
 from backend.data.graph import Graph
-from backend.util.timezone_name import TimeZoneName
 
 
 class WSMethod(enum.Enum):
@@ -71,12 +70,3 @@ class UploadFileResponse(pydantic.BaseModel):
     size: int
     content_type: str
     expires_in_hours: int
-
-
-class TimezoneResponse(pydantic.BaseModel):
-    # Allow "not-set" as a special value, or any valid IANA timezone
-    timezone: TimeZoneName | str
-
-
-class UpdateTimezoneRequest(pydantic.BaseModel):
-    timezone: TimeZoneName
@@ -20,8 +20,6 @@ import backend.server.routers.postmark.postmark
 import backend.server.routers.v1
 import backend.server.v2.admin.credit_admin_routes
 import backend.server.v2.admin.store_admin_routes
-import backend.server.v2.builder
-import backend.server.v2.builder.routes
 import backend.server.v2.library.db
 import backend.server.v2.library.model
 import backend.server.v2.library.routes
@@ -197,9 +195,6 @@ app.include_router(backend.server.routers.v1.v1_router, tags=["v1"], prefix="/ap
 app.include_router(
     backend.server.v2.store.routes.router, tags=["v2"], prefix="/api/store"
 )
-app.include_router(
-    backend.server.v2.builder.routes.router, tags=["v2"], prefix="/api/builder"
-)
 app.include_router(
     backend.server.v2.admin.store_admin_routes.router,
     tags=["v2", "admin"],
@@ -15,7 +15,6 @@ from fastapi import (
     File,
     HTTPException,
     Path,
-    Query,
     Request,
     Response,
     UploadFile,
@@ -61,11 +60,9 @@ from backend.data.onboarding import (
 )
 from backend.data.user import (
     get_or_create_user,
-    get_user_by_id,
     get_user_notification_preference,
     update_user_email,
     update_user_notification_preference,
-    update_user_timezone,
 )
 from backend.executor import scheduler
 from backend.executor import utils as execution_utils
@@ -80,21 +77,15 @@ from backend.server.model import (
     ExecuteGraphResponse,
     RequestTopUp,
     SetGraphActiveVersion,
-    TimezoneResponse,
     UpdatePermissionsRequest,
-    UpdateTimezoneRequest,
     UploadFileResponse,
 )
 from backend.server.utils import get_user_id
 from backend.util.clients import get_scheduler_client
 from backend.util.cloud_storage import get_cloud_storage_handler
 from backend.util.exceptions import GraphValidationError, NotFoundError
+from backend.util.feature_flag import feature_flag
 from backend.util.settings import Settings
-from backend.util.timezone_utils import (
-    convert_cron_to_utc,
-    convert_utc_time_to_user_timezone,
-    get_user_timezone_or_utc,
-)
 from backend.util.virus_scanner import scan_content_safe
 
 
@@ -158,35 +149,6 @@ async def update_user_email_route(
     return {"email": email}
 
 
-@v1_router.get(
-    "/auth/user/timezone",
-    summary="Get user timezone",
-    tags=["auth"],
-    dependencies=[Depends(auth_middleware)],
-)
-async def get_user_timezone_route(
-    user_data: dict = Depends(auth_middleware),
-) -> TimezoneResponse:
-    """Get user timezone setting."""
-    user = await get_or_create_user(user_data)
-    return TimezoneResponse(timezone=user.timezone)
-
-
-@v1_router.post(
-    "/auth/user/timezone",
-    summary="Update user timezone",
-    tags=["auth"],
-    dependencies=[Depends(auth_middleware)],
-    response_model=TimezoneResponse,
-)
-async def update_user_timezone_route(
-    user_id: Annotated[str, Depends(get_user_id)], request: UpdateTimezoneRequest
-) -> TimezoneResponse:
-    """Update user timezone. The timezone should be a valid IANA timezone identifier."""
-    user = await update_user_timezone(user_id, str(request.timezone))
-    return TimezoneResponse(timezone=user.timezone)
-
-
 @v1_router.get(
     "/auth/user/preferences",
     summary="Get notification preferences",
@@ -858,11 +820,11 @@ async def _stop_graph_run(
 
 @v1_router.get(
     path="/executions",
-    summary="List all executions",
+    summary="Get all executions",
     tags=["graphs"],
     dependencies=[Depends(auth_middleware)],
 )
-async def list_graphs_executions(
+async def get_graphs_executions(
     user_id: Annotated[str, Depends(get_user_id)],
 ) -> list[execution_db.GraphExecutionMeta]:
     return await execution_db.get_graph_executions(user_id=user_id)
@@ -870,24 +832,15 @@ async def list_graphs_executions(
 
 @v1_router.get(
     path="/graphs/{graph_id}/executions",
-    summary="List graph executions",
+    summary="Get graph executions",
     tags=["graphs"],
     dependencies=[Depends(auth_middleware)],
 )
-async def list_graph_executions(
+async def get_graph_executions(
     graph_id: str,
     user_id: Annotated[str, Depends(get_user_id)],
-    page: int = Query(1, ge=1, description="Page number (1-indexed)"),
-    page_size: int = Query(
-        25, ge=1, le=100, description="Number of executions per page"
-    ),
-) -> execution_db.GraphExecutionsPaginated:
-    return await execution_db.get_graph_executions_paginated(
-        graph_id=graph_id,
-        user_id=user_id,
-        page=page,
-        page_size=page_size,
-    )
+) -> list[execution_db.GraphExecutionMeta]:
+    return await execution_db.get_graph_executions(graph_id=graph_id, user_id=user_id)
 
 
 @v1_router.get(
@@ -971,36 +924,16 @@ async def create_graph_execution_schedule(
             detail=f"Graph #{graph_id} v{schedule_params.graph_version} not found.",
         )
 
-    user = await get_user_by_id(user_id)
-    user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
-
-    # Convert cron expression from user timezone to UTC
-    try:
-        utc_cron = convert_cron_to_utc(schedule_params.cron, user_timezone)
-    except ValueError as e:
-        raise HTTPException(
-            status_code=400,
-            detail=f"Invalid cron expression for timezone {user_timezone}: {e}",
-        )
-
-    result = await get_scheduler_client().add_execution_schedule(
+    return await get_scheduler_client().add_execution_schedule(
         user_id=user_id,
         graph_id=graph_id,
         graph_version=graph.version,
         name=schedule_params.name,
-        cron=utc_cron,  # Send UTC cron to scheduler
+        cron=schedule_params.cron,
         input_data=schedule_params.inputs,
         input_credentials=schedule_params.credentials,
     )
-
-    # Convert the next_run_time back to user timezone for display
-    if result.next_run_time:
-        result.next_run_time = convert_utc_time_to_user_timezone(
-            result.next_run_time, user_timezone
-        )
-
-    return result
 
 
 @v1_router.get(
     path="/graphs/{graph_id}/schedules",
@@ -1012,24 +945,11 @@ async def list_graph_execution_schedules(
     user_id: Annotated[str, Depends(get_user_id)],
     graph_id: str = Path(),
 ) -> list[scheduler.GraphExecutionJobInfo]:
-    schedules = await get_scheduler_client().get_execution_schedules(
+    return await get_scheduler_client().get_execution_schedules(
         user_id=user_id,
         graph_id=graph_id,
     )
-
-    # Get user timezone for conversion
-    user = await get_user_by_id(user_id)
-    user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
-
-    # Convert next_run_time to user timezone for display
-    for schedule in schedules:
-        if schedule.next_run_time:
-            schedule.next_run_time = convert_utc_time_to_user_timezone(
-                schedule.next_run_time, user_timezone
-            )
-
-    return schedules
 
 
 @v1_router.get(
     path="/schedules",
@@ -1040,20 +960,7 @@ async def list_graph_execution_schedules(
 async def list_all_graphs_execution_schedules(
     user_id: Annotated[str, Depends(get_user_id)],
 ) -> list[scheduler.GraphExecutionJobInfo]:
-    schedules = await get_scheduler_client().get_execution_schedules(user_id=user_id)
-
-    # Get user timezone for conversion
-    user = await get_user_by_id(user_id)
-    user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
-
-    # Convert UTC next_run_time to user timezone for display
-    for schedule in schedules:
-        if schedule.next_run_time:
-            schedule.next_run_time = convert_utc_time_to_user_timezone(
-                schedule.next_run_time, user_timezone
-            )
-
-    return schedules
+    return await get_scheduler_client().get_execution_schedules(user_id=user_id)
 
 
 @v1_router.delete(
@@ -83,7 +83,7 @@ class AutoModManager:
                 f"Moderating inputs for graph execution {graph_exec.graph_exec_id}"
             )
             try:
-                moderation_passed, content_id = await self._moderate_content(
+                moderation_passed = await self._moderate_content(
                     content,
                     {
                         "user_id": graph_exec.user_id,
@@ -99,7 +99,7 @@ class AutoModManager:
                 )
                 # Update node statuses for frontend display before raising error
                 await self._update_failed_nodes_for_moderation(
-                    db_client, graph_exec.graph_exec_id, "input", content_id
+                    db_client, graph_exec.graph_exec_id, "input"
                 )
 
                 return ModerationError(
@@ -107,7 +107,6 @@ class AutoModManager:
                     user_id=graph_exec.user_id,
                     graph_exec_id=graph_exec.graph_exec_id,
                     moderation_type="input",
-                    content_id=content_id,
                 )
 
         return None
@@ -168,7 +167,7 @@ class AutoModManager:
         # Run moderation
         logger.warning(f"Moderating outputs for graph execution {graph_exec_id}")
         try:
-            moderation_passed, content_id = await self._moderate_content(
+            moderation_passed = await self._moderate_content(
                 content,
                 {
                     "user_id": user_id,
@@ -182,7 +181,7 @@ class AutoModManager:
             logger.warning(f"Moderation failed for graph execution {graph_exec_id}")
             # Update node statuses for frontend display before raising error
             await self._update_failed_nodes_for_moderation(
-                db_client, graph_exec_id, "output", content_id
+                db_client, graph_exec_id, "output"
             )
 
             return ModerationError(
@@ -190,7 +189,6 @@ class AutoModManager:
                 user_id=user_id,
                 graph_exec_id=graph_exec_id,
                 moderation_type="output",
-                content_id=content_id,
             )
 
         return None
@@ -214,7 +212,6 @@ class AutoModManager:
         db_client: "DatabaseManagerAsyncClient",
         graph_exec_id: str,
         moderation_type: Literal["input", "output"],
-        content_id: str | None = None,
     ):
         """Update node execution statuses for frontend display when moderation fails"""
         # Import here to avoid circular imports
@@ -239,11 +236,6 @@ class AutoModManager:
         if not executions_to_update:
             return
 
-        # Create error message with content_id if available
-        error_message = "Failed due to content moderation"
-        if content_id:
-            error_message += f" (Moderation ID: {content_id})"
-
         # Prepare database update tasks
         exec_updates = []
         for exec_entry in executions_to_update:
@@ -253,11 +245,11 @@ class AutoModManager:
 
             if exec_entry.input_data:
                 for name in exec_entry.input_data.keys():
-                    cleared_inputs[name] = [error_message]
+                    cleared_inputs[name] = ["Failed due to content moderation"]
 
             if exec_entry.output_data:
                 for name in exec_entry.output_data.keys():
-                    cleared_outputs[name] = [error_message]
+                    cleared_outputs[name] = ["Failed due to content moderation"]
 
             # Add update task to list
             exec_updates.append(
@@ -265,7 +257,7 @@ class AutoModManager:
                 exec_entry.node_exec_id,
                 status=ExecutionStatus.FAILED,
                 stats={
-                    "error": error_message,
+                    "error": "Failed due to content moderation",
                     "cleared_inputs": cleared_inputs,
                     "cleared_outputs": cleared_outputs,
                 },
@@ -283,15 +275,12 @@ class AutoModManager:
             ]
         )
 
-    async def _moderate_content(
-        self, content: str, metadata: dict[str, Any]
-    ) -> tuple[bool, str | None]:
+    async def _moderate_content(self, content: str, metadata: dict[str, Any]) -> bool:
         """Moderate content using AutoMod API
 
         Returns:
-            Tuple of (approval_status, content_id)
-            - approval_status: True if approved or timeout occurred, False if rejected
-            - content_id: Reference ID from moderation API, or None if not available
+            True: Content approved or timeout occurred
+            False: Content rejected by moderation
 
         Raises:
             asyncio.TimeoutError: When moderation times out (should be bypassed)
@@ -309,12 +298,12 @@ class AutoModManager:
                 logger.debug(
                     f"Content approved for {metadata.get('graph_exec_id', 'unknown')}"
                 )
-                return True, response.content_id
+                return True
             else:
                 reasons = [r.reason for r in response.moderation_results if r.reason]
                 error_msg = f"Content rejected by AutoMod: {'; '.join(reasons)}"
                 logger.warning(f"Content rejected: {error_msg}")
-                return False, response.content_id
+                return False
 
         except asyncio.TimeoutError:
             # Re-raise timeout to be handled by calling methods
@@ -324,7 +313,7 @@ class AutoModManager:
             raise
         except Exception as e:
             logger.error(f"AutoMod moderation error: {e}")
-            return self.config.fail_open, None
+            return self.config.fail_open
 
     async def _make_request(self, request_data: AutoModRequest) -> AutoModResponse:
         """Make HTTP request to AutoMod API using the standard request utility"""
@@ -26,9 +26,6 @@ class AutoModResponse(BaseModel):
     """Response model for AutoMod API"""
 
     success: bool = Field(..., description="Whether the request was successful")
-    content_id: str = Field(
-        ..., description="Unique reference ID for this moderation request"
-    )
     status: str = Field(
         ..., description="Overall status: 'approved', 'rejected', 'flagged', 'pending'"
     )
|
|||||||
@@ -1,376 +0,0 @@
|
|||||||
import functools
|
|
||||||
import logging
|
|
||||||
from datetime import datetime, timedelta, timezone
|
|
||||||
|
|
||||||
import prisma
|
|
||||||
|
|
||||||
import backend.data.block
|
|
||||||
from backend.blocks import load_all_blocks
|
|
||||||
from backend.blocks.llm import LlmModel
|
|
||||||
from backend.data.block import Block, BlockCategory, BlockSchema
|
|
||||||
from backend.data.credit import get_block_costs
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
from backend.server.v2.builder.model import (
|
|
||||||
BlockCategoryResponse,
|
|
||||||
BlockData,
|
|
||||||
BlockResponse,
|
|
||||||
BlockType,
|
|
||||||
CountResponse,
|
|
||||||
Provider,
|
|
||||||
ProviderResponse,
|
|
||||||
SearchBlocksResponse,
|
|
||||||
)
|
|
||||||
from backend.util.models import Pagination
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
|
|
||||||
_static_counts_cache: dict | None = None
|
|
||||||
_suggested_blocks: list[BlockData] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse]:
|
|
||||||
categories: dict[BlockCategory, BlockCategoryResponse] = {}
|
|
||||||
|
|
||||||
for block_type in load_all_blocks().values():
|
|
||||||
block: Block[BlockSchema, BlockSchema] = block_type()
|
|
||||||
# Skip disabled blocks
|
|
||||||
if block.disabled:
|
|
||||||
continue
|
|
||||||
# Skip blocks that don't have categories (all should have at least one)
|
|
||||||
if not block.categories:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Add block to the categories
|
|
||||||
for category in block.categories:
|
|
||||||
if category not in categories:
|
|
||||||
categories[category] = BlockCategoryResponse(
|
|
||||||
name=category.name.lower(),
|
|
||||||
total_blocks=0,
|
|
||||||
blocks=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
categories[category].total_blocks += 1
|
|
||||||
|
|
||||||
# Append if the category has less than the specified number of blocks
|
|
||||||
if len(categories[category].blocks) < category_blocks:
|
|
||||||
categories[category].blocks.append(block.to_dict())
|
|
||||||
|
|
||||||
# Sort categories by name
|
|
||||||
return sorted(categories.values(), key=lambda x: x.name)
|
|
||||||
|
|
||||||
|
|
||||||
def get_blocks(
|
|
||||||
*,
|
|
||||||
category: str | None = None,
|
|
||||||
type: BlockType | None = None,
|
|
||||||
provider: ProviderName | None = None,
|
|
||||||
page: int = 1,
|
|
||||||
page_size: int = 50,
|
|
||||||
) -> BlockResponse:
|
|
||||||
"""
|
|
||||||
Get blocks based on either category, type or provider.
|
|
||||||
Providing nothing fetches all block types.
|
|
||||||
"""
|
|
||||||
# Only one of category, type, or provider can be specified
|
|
||||||
if (category and type) or (category and provider) or (type and provider):
|
|
||||||
raise ValueError("Only one of category, type, or provider can be specified")
|
|
||||||
|
|
||||||
blocks: list[Block[BlockSchema, BlockSchema]] = []
|
|
||||||
skip = (page - 1) * page_size
|
|
||||||
take = page_size
|
|
||||||
total = 0
|
|
||||||
|
|
||||||
for block_type in load_all_blocks().values():
|
|
||||||
block: Block[BlockSchema, BlockSchema] = block_type()
|
|
||||||
# Skip disabled blocks
|
|
||||||
if block.disabled:
|
|
||||||
continue
|
|
||||||
# Skip blocks that don't match the category
|
|
||||||
if category and category not in {c.name.lower() for c in block.categories}:
|
|
||||||
continue
|
|
||||||
# Skip blocks that don't match the type
|
|
||||||
if (
|
|
||||||
(type == "input" and block.block_type.value != "Input")
|
|
||||||
or (type == "output" and block.block_type.value != "Output")
|
|
||||||
or (type == "action" and block.block_type.value in ("Input", "Output"))
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
# Skip blocks that don't match the provider
|
|
||||||
if provider:
|
|
||||||
credentials_info = block.input_schema.get_credentials_fields_info().values()
|
|
||||||
if not any(provider in info.provider for info in credentials_info):
|
|
||||||
continue
|
|
||||||
|
|
||||||
total += 1
|
|
||||||
if skip > 0:
|
|
||||||
skip -= 1
|
|
||||||
continue
|
|
||||||
if take > 0:
|
|
||||||
take -= 1
|
|
||||||
blocks.append(block)
|
|
||||||
|
|
||||||
costs = get_block_costs()
|
|
||||||
|
|
||||||
return BlockResponse(
|
|
||||||
blocks=[{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks],
|
|
||||||
pagination=Pagination(
|
|
||||||
total_items=total,
|
|
||||||
total_pages=(total + page_size - 1) // page_size,
|
|
||||||
current_page=page,
|
|
||||||
page_size=page_size,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def search_blocks(
|
|
||||||
include_blocks: bool = True,
|
|
||||||
include_integrations: bool = True,
|
|
||||||
query: str = "",
|
|
||||||
page: int = 1,
|
|
||||||
page_size: int = 50,
|
|
||||||
) -> SearchBlocksResponse:
|
|
||||||
"""
|
|
||||||
Get blocks based on the filter and query.
|
|
||||||
`providers` only applies for `integrations` filter.
|
|
||||||
"""
|
|
||||||
blocks: list[Block[BlockSchema, BlockSchema]] = []
|
|
||||||
query = query.lower()
|
|
||||||
|
|
||||||
total = 0
|
|
||||||
skip = (page - 1) * page_size
|
|
||||||
take = page_size
|
|
||||||
block_count = 0
|
|
||||||
integration_count = 0
|
|
||||||
|
|
||||||
for block_type in load_all_blocks().values():
|
|
||||||
block: Block[BlockSchema, BlockSchema] = block_type()
|
|
||||||
# Skip disabled blocks
|
|
||||||
if block.disabled:
|
|
||||||
continue
|
|
||||||
# Skip blocks that don't match the query
|
|
||||||
if (
|
|
||||||
query not in block.name.lower()
|
|
||||||
and query not in block.description.lower()
|
|
||||||
and not _matches_llm_model(block.input_schema, query)
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
keep = False
|
|
||||||
credentials = list(block.input_schema.get_credentials_fields().values())
|
|
||||||
if include_integrations and len(credentials) > 0:
|
|
||||||
keep = True
|
|
||||||
integration_count += 1
|
|
||||||
if include_blocks and len(credentials) == 0:
|
|
||||||
keep = True
|
|
||||||
block_count += 1
|
|
||||||
|
|
||||||
if not keep:
|
|
||||||
continue
|
|
||||||
|
|
||||||
total += 1
|
|
||||||
if skip > 0:
|
|
||||||
skip -= 1
|
|
||||||
continue
|
|
||||||
if take > 0:
|
|
||||||
take -= 1
|
|
||||||
blocks.append(block)
|
|
||||||
|
|
||||||
costs = get_block_costs()
|
|
||||||
|
|
||||||
return SearchBlocksResponse(
|
|
||||||
blocks=BlockResponse(
|
|
||||||
blocks=[{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks],
|
|
||||||
pagination=Pagination(
|
|
||||||
total_items=total,
|
|
||||||
total_pages=(total + page_size - 1) // page_size,
|
|
||||||
current_page=page,
|
|
||||||
page_size=page_size,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
total_block_count=block_count,
|
|
||||||
total_integration_count=integration_count,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_providers(
|
|
||||||
query: str = "",
|
|
||||||
page: int = 1,
|
|
||||||
page_size: int = 50,
|
|
||||||
) -> ProviderResponse:
|
|
||||||
providers = []
|
|
||||||
query = query.lower()
|
|
||||||
|
|
||||||
skip = (page - 1) * page_size
|
|
||||||
take = page_size
|
|
||||||
|
|
||||||
all_providers = _get_all_providers()
|
|
||||||
|
|
||||||
for provider in all_providers.values():
|
|
||||||
if (
|
|
||||||
query not in provider.name.value.lower()
|
|
||||||
and query not in provider.description.lower()
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
if skip > 0:
|
|
||||||
skip -= 1
|
|
||||||
continue
|
|
||||||
if take > 0:
|
|
||||||
take -= 1
|
|
||||||
providers.append(provider)
|
|
||||||
|
|
||||||
total = len(all_providers)
|
|
||||||
|
|
||||||
return ProviderResponse(
|
|
||||||
providers=providers,
|
|
||||||
pagination=Pagination(
|
|
||||||
total_items=total,
|
|
||||||
total_pages=(total + page_size - 1) // page_size,
|
|
||||||
current_page=page,
|
|
||||||
page_size=page_size,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_counts(user_id: str) -> CountResponse:
|
|
||||||
my_agents = await prisma.models.LibraryAgent.prisma().count(
|
|
||||||
where={
|
|
||||||
"userId": user_id,
|
|
||||||
"isDeleted": False,
|
|
||||||
"isArchived": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
counts = await _get_static_counts()
|
|
||||||
return CountResponse(
|
|
||||||
my_agents=my_agents,
|
|
||||||
**counts,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_static_counts():
|
|
||||||
"""
|
|
||||||
Get counts of blocks, integrations, and marketplace agents.
|
|
||||||
This is cached to avoid unnecessary database queries and calculations.
|
|
||||||
Can't use functools.cache here because the function is async.
|
|
||||||
"""
|
|
||||||
global _static_counts_cache
|
|
||||||
if _static_counts_cache is not None:
|
|
||||||
return _static_counts_cache
|
|
||||||
|
|
||||||
all_blocks = 0
|
|
||||||
input_blocks = 0
|
|
||||||
action_blocks = 0
|
|
||||||
output_blocks = 0
|
|
||||||
integrations = 0
|
|
||||||
|
|
||||||
for block_type in load_all_blocks().values():
|
|
||||||
block: Block[BlockSchema, BlockSchema] = block_type()
|
|
||||||
if block.disabled:
|
|
||||||
continue
|
|
||||||
|
|
||||||
all_blocks += 1
|
|
||||||
|
|
||||||
if block.block_type.value == "Input":
|
|
||||||
input_blocks += 1
|
|
||||||
elif block.block_type.value == "Output":
|
|
||||||
output_blocks += 1
|
|
||||||
else:
|
|
||||||
action_blocks += 1
|
|
||||||
|
|
||||||
credentials = list(block.input_schema.get_credentials_fields().values())
|
|
||||||
if len(credentials) > 0:
|
|
||||||
integrations += 1
|
|
||||||
|
|
||||||
marketplace_agents = await prisma.models.StoreAgent.prisma().count()
|
|
||||||
|
|
||||||
_static_counts_cache = {
|
|
||||||
"all_blocks": all_blocks,
|
|
||||||
"input_blocks": input_blocks,
|
|
||||||
"action_blocks": action_blocks,
|
|
||||||
"output_blocks": output_blocks,
|
|
||||||
"integrations": integrations,
|
|
||||||
"marketplace_agents": marketplace_agents,
|
|
||||||
}
|
|
||||||
|
|
||||||
return _static_counts_cache
|
|
||||||
|
|
||||||
|
|
||||||
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
|
|
||||||
for field in schema_cls.model_fields.values():
|
|
||||||
if field.annotation == LlmModel:
|
|
||||||
# Check if query matches any value in llm_models
|
|
||||||
if any(query in name for name in llm_models):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
@functools.cache
|
|
||||||
def _get_all_providers() -> dict[ProviderName, Provider]:
|
|
||||||
providers: dict[ProviderName, Provider] = {}
|
|
||||||
|
|
||||||
for block_type in load_all_blocks().values():
|
|
||||||
block: Block[BlockSchema, BlockSchema] = block_type()
|
|
||||||
if block.disabled:
|
|
||||||
continue
|
|
||||||
|
|
||||||
credentials_info = block.input_schema.get_credentials_fields_info().values()
|
|
||||||
for info in credentials_info:
|
|
||||||
for provider in info.provider: # provider is a ProviderName enum member
|
|
||||||
if provider in providers:
|
|
||||||
providers[provider].integration_count += 1
|
|
||||||
else:
|
|
||||||
providers[provider] = Provider(
|
|
||||||
name=provider, description="", integration_count=1
|
|
||||||
)
|
|
||||||
return providers
|
|
||||||
|
|
||||||
|
|
||||||
async def get_suggested_blocks(count: int = 5) -> list[BlockData]:
|
|
||||||
global _suggested_blocks
|
|
||||||
|
|
||||||
if _suggested_blocks is not None and len(_suggested_blocks) >= count:
|
|
||||||
return _suggested_blocks[:count]
|
|
||||||
|
|
||||||
_suggested_blocks = []
|
|
||||||
# Sum the number of executions for each block type
|
|
||||||
# Prisma cannot group by nested relations, so we do a raw query
|
|
||||||
# Calculate the cutoff timestamp
|
|
||||||
timestamp_threshold = datetime.now(timezone.utc) - timedelta(days=30)
|
|
||||||
|
|
||||||
results = await prisma.get_client().query_raw(
|
|
||||||
"""
|
|
||||||
SELECT
|
|
||||||
agent_node."agentBlockId" AS block_id,
|
|
||||||
COUNT(execution.id) AS execution_count
|
|
||||||
FROM "AgentNodeExecution" execution
|
|
||||||
JOIN "AgentNode" agent_node ON execution."agentNodeId" = agent_node.id
|
|
||||||
WHERE execution."endedTime" >= $1::timestamp
|
|
||||||
GROUP BY agent_node."agentBlockId"
|
|
||||||
ORDER BY execution_count DESC;
|
|
||||||
""",
|
|
||||||
timestamp_threshold,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get the top blocks based on execution count
|
|
||||||
# But ignore Input and Output blocks
|
|
||||||
blocks: list[tuple[BlockData, int]] = []
|
|
||||||
|
|
||||||
for block_type in load_all_blocks().values():
|
|
||||||
block: Block[BlockSchema, BlockSchema] = block_type()
|
|
||||||
if block.disabled or block.block_type in (
|
|
||||||
backend.data.block.BlockType.INPUT,
|
|
||||||
backend.data.block.BlockType.OUTPUT,
|
|
||||||
backend.data.block.BlockType.AGENT,
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
# Find the execution count for this block
|
|
||||||
execution_count = next(
|
|
||||||
(row["execution_count"] for row in results if row["block_id"] == block.id),
|
|
||||||
0,
|
|
||||||
)
|
|
||||||
blocks.append((block.to_dict(), execution_count))
|
|
||||||
# Sort blocks by execution count
|
|
||||||
blocks.sort(key=lambda x: x[1], reverse=True)
|
|
||||||
|
|
||||||
_suggested_blocks = [block[0] for block in blocks]
|
|
||||||
|
|
||||||
# Return the top blocks
|
|
||||||
return _suggested_blocks[:count]
|
|
||||||
@@ -1,87 +0,0 @@
|
|||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
import backend.server.v2.library.model as library_model
|
|
||||||
import backend.server.v2.store.model as store_model
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
from backend.util.models import Pagination
|
|
||||||
|
|
||||||
FilterType = Literal[
|
|
||||||
"blocks",
|
|
||||||
"integrations",
|
|
||||||
"marketplace_agents",
|
|
||||||
"my_agents",
|
|
||||||
]
|
|
||||||
|
|
||||||
BlockType = Literal["all", "input", "action", "output"]
|
|
||||||
|
|
||||||
BlockData = dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
# Suggestions
|
|
||||||
class SuggestionsResponse(BaseModel):
|
|
||||||
otto_suggestions: list[str]
|
|
||||||
recent_searches: list[str]
|
|
||||||
providers: list[ProviderName]
|
|
||||||
top_blocks: list[BlockData]
|
|
||||||
|
|
||||||
|
|
||||||
# All blocks
|
|
||||||
class BlockCategoryResponse(BaseModel):
|
|
||||||
name: str
|
|
||||||
total_blocks: int
|
|
||||||
blocks: list[BlockData]
|
|
||||||
|
|
||||||
model_config = {"use_enum_values": False} # <== use enum names like "AI"
|
|
||||||
|
|
||||||
|
|
||||||
# Input/Action/Output and see all for block categories
|
|
||||||
class BlockResponse(BaseModel):
|
|
||||||
blocks: list[BlockData]
|
|
||||||
pagination: Pagination
|
|
||||||
|
|
||||||
|
|
||||||
# Providers
|
|
||||||
class Provider(BaseModel):
|
|
||||||
name: ProviderName
|
|
||||||
description: str
|
|
||||||
integration_count: int
|
|
||||||
|
|
||||||
|
|
||||||
class ProviderResponse(BaseModel):
|
|
||||||
providers: list[Provider]
|
|
||||||
pagination: Pagination
|
|
||||||
|
|
||||||
|
|
||||||
# Search
|
|
||||||
class SearchRequest(BaseModel):
|
|
||||||
search_query: str | None = None
|
|
||||||
filter: list[FilterType] | None = None
|
|
||||||
by_creator: list[str] | None = None
|
|
||||||
search_id: str | None = None
|
|
||||||
page: int | None = None
|
|
||||||
page_size: int | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class SearchBlocksResponse(BaseModel):
|
|
||||||
blocks: BlockResponse
|
|
||||||
total_block_count: int
|
|
||||||
total_integration_count: int
|
|
||||||
|
|
||||||
|
|
||||||
class SearchResponse(BaseModel):
|
|
||||||
items: list[BlockData | library_model.LibraryAgent | store_model.StoreAgent]
|
|
||||||
total_items: dict[FilterType, int]
|
|
||||||
page: int
|
|
||||||
more_pages: bool
|
|
||||||
|
|
||||||
|
|
||||||
class CountResponse(BaseModel):
|
|
||||||
all_blocks: int
|
|
||||||
input_blocks: int
|
|
||||||
action_blocks: int
|
|
||||||
output_blocks: int
|
|
||||||
integrations: int
|
|
||||||
marketplace_agents: int
|
|
||||||
my_agents: int
|
|
||||||
@@ -1,239 +0,0 @@
|
|||||||
import logging
|
|
||||||
from typing import Annotated, Sequence
|
|
||||||
|
|
||||||
import fastapi
|
|
||||||
from autogpt_libs.auth.depends import auth_middleware, get_user_id
|
|
||||||
|
|
||||||
import backend.server.v2.builder.db as builder_db
|
|
||||||
import backend.server.v2.builder.model as builder_model
|
|
||||||
import backend.server.v2.library.db as library_db
|
|
||||||
import backend.server.v2.library.model as library_model
|
|
||||||
import backend.server.v2.store.db as store_db
|
|
||||||
import backend.server.v2.store.model as store_model
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
from backend.util.models import Pagination
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = fastapi.APIRouter()
|
|
||||||
|
|
||||||
|
|
||||||
# Taken from backend/server/v2/store/db.py
|
|
||||||
def sanitize_query(query: str | None) -> str | None:
|
|
||||||
if query is None:
|
|
||||||
return query
|
|
||||||
query = query.strip()[:100]
|
|
||||||
return (
|
|
||||||
query.replace("\\", "\\\\")
|
|
||||||
.replace("%", "\\%")
|
|
||||||
.replace("_", "\\_")
|
|
||||||
.replace("[", "\\[")
|
|
||||||
.replace("]", "\\]")
|
|
||||||
.replace("'", "\\'")
|
|
||||||
.replace('"', '\\"')
|
|
||||||
.replace(";", "\\;")
|
|
||||||
.replace("--", "\\--")
|
|
||||||
.replace("/*", "\\/*")
|
|
||||||
.replace("*/", "\\*/")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/suggestions",
|
|
||||||
summary="Get Builder suggestions",
|
|
||||||
dependencies=[fastapi.Depends(auth_middleware)],
|
|
||||||
response_model=builder_model.SuggestionsResponse,
|
|
||||||
)
|
|
||||||
async def get_suggestions(
|
|
||||||
user_id: Annotated[str, fastapi.Depends(get_user_id)],
|
|
||||||
) -> builder_model.SuggestionsResponse:
|
|
||||||
"""
|
|
||||||
Get all suggestions for the Blocks Menu.
|
|
||||||
"""
|
|
||||||
return builder_model.SuggestionsResponse(
|
|
||||||
otto_suggestions=[
|
|
||||||
"What blocks do I need to get started?",
|
|
||||||
"Help me create a list",
|
|
||||||
"Help me feed my data to Google Maps",
|
|
||||||
],
|
|
||||||
recent_searches=[
|
|
||||||
"image generation",
|
|
||||||
"deepfake",
|
|
||||||
"competitor analysis",
|
|
||||||
],
|
|
||||||
providers=[
|
|
||||||
ProviderName.TWITTER,
|
|
||||||
ProviderName.GITHUB,
|
|
||||||
ProviderName.NOTION,
|
|
||||||
ProviderName.GOOGLE,
|
|
||||||
ProviderName.DISCORD,
|
|
||||||
ProviderName.GOOGLE_MAPS,
|
|
||||||
],
|
|
||||||
top_blocks=await builder_db.get_suggested_blocks(),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/categories",
|
|
||||||
summary="Get Builder block categories",
|
|
||||||
dependencies=[fastapi.Depends(auth_middleware)],
|
|
||||||
response_model=Sequence[builder_model.BlockCategoryResponse],
|
|
||||||
)
|
|
||||||
async def get_block_categories(
|
|
||||||
blocks_per_category: Annotated[int, fastapi.Query()] = 3,
|
|
||||||
) -> Sequence[builder_model.BlockCategoryResponse]:
|
|
||||||
"""
|
|
||||||
Get all block categories with a specified number of blocks per category.
|
|
||||||
"""
|
|
||||||
return builder_db.get_block_categories(blocks_per_category)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/blocks",
|
|
||||||
summary="Get Builder blocks",
|
|
||||||
dependencies=[fastapi.Depends(auth_middleware)],
|
|
||||||
response_model=builder_model.BlockResponse,
|
|
||||||
)
|
|
||||||
async def get_blocks(
|
|
||||||
category: Annotated[str | None, fastapi.Query()] = None,
|
|
||||||
type: Annotated[builder_model.BlockType | None, fastapi.Query()] = None,
|
|
||||||
provider: Annotated[ProviderName | None, fastapi.Query()] = None,
|
|
||||||
page: Annotated[int, fastapi.Query()] = 1,
|
|
||||||
page_size: Annotated[int, fastapi.Query()] = 50,
|
|
||||||
) -> builder_model.BlockResponse:
|
|
||||||
"""
|
|
||||||
Get blocks based on either category, type, or provider.
|
|
||||||
"""
|
|
||||||
return builder_db.get_blocks(
|
|
||||||
category=category,
|
|
||||||
type=type,
|
|
||||||
provider=provider,
|
|
||||||
page=page,
|
|
||||||
page_size=page_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/providers",
|
|
||||||
summary="Get Builder integration providers",
|
|
||||||
dependencies=[fastapi.Depends(auth_middleware)],
|
|
||||||
response_model=builder_model.ProviderResponse,
|
|
||||||
)
|
|
||||||
async def get_providers(
|
|
||||||
page: Annotated[int, fastapi.Query()] = 1,
|
|
||||||
page_size: Annotated[int, fastapi.Query()] = 50,
|
|
||||||
) -> builder_model.ProviderResponse:
|
|
||||||
"""
|
|
||||||
Get all integration providers with their block counts.
|
|
||||||
"""
|
|
||||||
return builder_db.get_providers(
|
|
||||||
page=page,
|
|
||||||
page_size=page_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/search",
|
|
||||||
summary="Builder search",
|
|
||||||
tags=["store", "private"],
|
|
||||||
dependencies=[fastapi.Depends(auth_middleware)],
|
|
||||||
response_model=builder_model.SearchResponse,
|
|
||||||
)
|
|
||||||
async def search(
|
|
||||||
options: builder_model.SearchRequest,
|
|
||||||
user_id: Annotated[str, fastapi.Depends(get_user_id)],
|
|
||||||
) -> builder_model.SearchResponse:
|
|
||||||
"""
|
|
||||||
Search for blocks (including integrations), marketplace agents, and user library agents.
|
|
||||||
"""
|
|
||||||
# If no filters are provided, then we will return all types
|
|
||||||
if not options.filter:
|
|
||||||
options.filter = [
|
|
||||||
"blocks",
|
|
||||||
"integrations",
|
|
||||||
"marketplace_agents",
|
|
||||||
"my_agents",
|
|
||||||
]
|
|
||||||
options.search_query = sanitize_query(options.search_query)
|
|
||||||
options.page = options.page or 1
|
|
||||||
options.page_size = options.page_size or 50
|
|
||||||
|
|
||||||
# Blocks&Integrations
|
|
||||||
blocks = builder_model.SearchBlocksResponse(
|
|
||||||
blocks=builder_model.BlockResponse(
|
|
||||||
blocks=[],
|
|
||||||
pagination=Pagination.empty(),
|
|
||||||
),
|
|
||||||
total_block_count=0,
|
|
||||||
total_integration_count=0,
|
|
||||||
)
|
|
||||||
if "blocks" in options.filter or "integrations" in options.filter:
|
|
||||||
blocks = builder_db.search_blocks(
|
|
||||||
include_blocks="blocks" in options.filter,
|
|
||||||
include_integrations="integrations" in options.filter,
|
|
||||||
query=options.search_query or "",
|
|
||||||
page=options.page,
|
|
||||||
page_size=options.page_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Library Agents
|
|
||||||
my_agents = library_model.LibraryAgentResponse(
|
|
||||||
agents=[],
|
|
||||||
pagination=Pagination.empty(),
|
|
||||||
)
|
|
||||||
if "my_agents" in options.filter:
|
|
||||||
my_agents = await library_db.list_library_agents(
|
|
||||||
user_id=user_id,
|
|
||||||
search_term=options.search_query,
|
|
||||||
page=options.page,
|
|
||||||
page_size=options.page_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Marketplace Agents
|
|
||||||
marketplace_agents = store_model.StoreAgentsResponse(
|
|
||||||
agents=[],
|
|
||||||
pagination=Pagination.empty(),
|
|
||||||
)
|
|
||||||
if "marketplace_agents" in options.filter:
|
|
||||||
marketplace_agents = await store_db.get_store_agents(
|
|
||||||
creators=options.by_creator,
|
|
||||||
search_query=options.search_query,
|
|
||||||
page=options.page,
|
|
||||||
page_size=options.page_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
more_pages = False
|
|
||||||
if (
|
|
||||||
blocks.blocks.pagination.current_page < blocks.blocks.pagination.total_pages
|
|
||||||
or my_agents.pagination.current_page < my_agents.pagination.total_pages
|
|
||||||
or marketplace_agents.pagination.current_page
|
|
||||||
< marketplace_agents.pagination.total_pages
|
|
||||||
):
|
|
||||||
more_pages = True
|
|
||||||
|
|
||||||
return builder_model.SearchResponse(
|
|
||||||
items=blocks.blocks.blocks + my_agents.agents + marketplace_agents.agents,
|
|
||||||
total_items={
|
|
||||||
"blocks": blocks.total_block_count,
|
|
||||||
"integrations": blocks.total_integration_count,
|
|
||||||
"marketplace_agents": marketplace_agents.pagination.total_items,
|
|
||||||
"my_agents": my_agents.pagination.total_items,
|
|
||||||
},
|
|
||||||
page=options.page,
|
|
||||||
more_pages=more_pages,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/counts",
|
|
||||||
summary="Get Builder item counts",
|
|
||||||
dependencies=[fastapi.Depends(auth_middleware)],
|
|
||||||
response_model=builder_model.CountResponse,
|
|
||||||
)
|
|
||||||
async def get_counts(
|
|
||||||
user_id: Annotated[str, fastapi.Depends(get_user_id)],
|
|
||||||
) -> builder_model.CountResponse:
|
|
||||||
"""
|
|
||||||
Get item counts for the menu categories in the Blocks Menu.
|
|
||||||
"""
|
|
||||||
return await builder_db.get_counts(user_id)
|
|
||||||
@@ -51,7 +51,6 @@ class LibraryAgent(pydantic.BaseModel):
    description: str
 
     input_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend
-    output_schema: dict[str, Any]
     credentials_input_schema: dict[str, Any] | None = pydantic.Field(
         description="Input schema for credentials required by the agent",
     )
@@ -127,7 +126,6 @@ class LibraryAgent(pydantic.BaseModel):
             name=graph.name,
             description=graph.description,
             input_schema=graph.input_schema,
-            output_schema=graph.output_schema,
             credentials_input_schema=(
                 graph.credentials_input_schema if sub_graphs is not None else None
             ),
@@ -50,7 +50,6 @@ async def test_get_library_agents_success(
         creator_name="Test Creator",
         creator_image_url="",
         input_schema={"type": "object", "properties": {}},
-        output_schema={"type": "object", "properties": {}},
         credentials_input_schema={"type": "object", "properties": {}},
         has_external_trigger=False,
         status=library_model.LibraryAgentStatus.COMPLETED,
@@ -69,7 +68,6 @@ async def test_get_library_agents_success(
         creator_name="Test Creator",
         creator_image_url="",
         input_schema={"type": "object", "properties": {}},
-        output_schema={"type": "object", "properties": {}},
         credentials_input_schema={"type": "object", "properties": {}},
         has_external_trigger=False,
         status=library_model.LibraryAgentStatus.COMPLETED,
@@ -134,7 +132,6 @@ def test_add_agent_to_library_success(mocker: pytest_mock.MockFixture):
         creator_name="Test Creator",
         creator_image_url="",
         input_schema={"type": "object", "properties": {}},
-        output_schema={"type": "object", "properties": {}},
         credentials_input_schema={"type": "object", "properties": {}},
         has_external_trigger=False,
         status=library_model.LibraryAgentStatus.COMPLETED,
@@ -17,21 +17,8 @@ from backend.data.graph import (
     get_sub_graphs,
 )
 from backend.data.includes import AGENT_GRAPH_INCLUDE
-from backend.data.notifications import (
-    AgentApprovalData,
-    AgentRejectionData,
-    NotificationEventModel,
-)
-from backend.notifications.notifications import queue_notification_async
-from backend.util.settings import Settings
 
 logger = logging.getLogger(__name__)
-settings = Settings()
 
 
-# Constants for default admin values
-DEFAULT_ADMIN_NAME = "AutoGPT Admin"
-DEFAULT_ADMIN_EMAIL = "admin@autogpt.co"
-
-
 def sanitize_query(query: str | None) -> str | None:
@@ -55,7 +42,7 @@ def sanitize_query(query: str | None) -> str | None:
|
|||||||
|
|
||||||
async def get_store_agents(
|
async def get_store_agents(
|
||||||
featured: bool = False,
|
featured: bool = False,
|
||||||
creators: list[str] | None = None,
|
creator: str | None = None,
|
||||||
sorted_by: str | None = None,
|
sorted_by: str | None = None,
|
||||||
search_query: str | None = None,
|
search_query: str | None = None,
|
||||||
category: str | None = None,
|
category: str | None = None,
|
||||||
@@ -66,15 +53,15 @@ async def get_store_agents(
|
|||||||
Get PUBLIC store agents from the StoreAgent view
|
Get PUBLIC store agents from the StoreAgent view
|
||||||
"""
|
"""
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
|
f"Getting store agents. featured={featured}, creator={creator}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
|
||||||
)
|
)
|
||||||
sanitized_query = sanitize_query(search_query)
|
sanitized_query = sanitize_query(search_query)
|
||||||
|
|
||||||
where_clause = {}
|
where_clause = {}
|
||||||
if featured:
|
if featured:
|
||||||
where_clause["featured"] = featured
|
where_clause["featured"] = featured
|
||||||
if creators:
|
if creator:
|
||||||
where_clause["creator_username"] = {"in": creators}
|
where_clause["creator_username"] = creator
|
||||||
if category:
|
if category:
|
||||||
where_clause["categories"] = {"has": category}
|
where_clause["categories"] = {"has": category}
|
||||||
|
|
||||||
@@ -1254,8 +1241,7 @@ async def review_store_submission(
|
|||||||
where={"id": store_listing_version_id},
|
where={"id": store_listing_version_id},
|
||||||
include={
|
include={
|
||||||
"StoreListing": True,
|
"StoreListing": True,
|
||||||
"AgentGraph": {"include": {**AGENT_GRAPH_INCLUDE, "User": True}},
|
"AgentGraph": {"include": AGENT_GRAPH_INCLUDE},
|
||||||
"Reviewer": True,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -1266,13 +1252,6 @@ async def review_store_submission(
detail=f"Store listing version {store_listing_version_id} not found",
)

-# Check if we're rejecting an already approved agent
-is_rejecting_approved = (
-not is_approved
-and store_listing_version.submissionStatus
-== prisma.enums.SubmissionStatus.APPROVED
-)

# If approving, update the listing to indicate it has an approved version
if is_approved and store_listing_version.AgentGraph:
heading = f"Sub-graph of {store_listing_version.name}v{store_listing_version.agentGraphVersion}"
@@ -1303,37 +1282,6 @@ async def review_store_submission(
},
)

-# If rejecting an approved agent, update the StoreListing accordingly
-if is_rejecting_approved:
-# Check if there are other approved versions
-other_approved = (
-await prisma.models.StoreListingVersion.prisma().find_first(
-where={
-"storeListingId": store_listing_version.StoreListing.id,
-"id": {"not": store_listing_version_id},
-"submissionStatus": prisma.enums.SubmissionStatus.APPROVED,
-}
-)
-)
-
-if not other_approved:
-# No other approved versions, update hasApprovedVersion to False
-await prisma.models.StoreListing.prisma().update(
-where={"id": store_listing_version.StoreListing.id},
-data={
-"hasApprovedVersion": False,
-"ActiveVersion": {"disconnect": True},
-},
-)
-else:
-# Set the most recent other approved version as active
-await prisma.models.StoreListing.prisma().update(
-where={"id": store_listing_version.StoreListing.id},
-data={
-"ActiveVersion": {"connect": {"id": other_approved.id}},
-},
-)
-
submission_status = (
prisma.enums.SubmissionStatus.APPROVED
if is_approved
@@ -1362,89 +1310,6 @@ async def review_store_submission(
f"Failed to update store listing version {store_listing_version_id}"
)

-# Send email notification to the agent creator
-if store_listing_version.AgentGraph and store_listing_version.AgentGraph.User:
-agent_creator = store_listing_version.AgentGraph.User
-reviewer = (
-store_listing_version.Reviewer
-if store_listing_version.Reviewer
-else None
-)
-
-try:
-base_url = (
-settings.config.frontend_base_url
-or settings.config.platform_base_url
-)
-
-if is_approved:
-store_agent = (
-await prisma.models.StoreAgent.prisma().find_first_or_raise(
-where={"storeListingVersionId": submission.id}
-)
-)
-
-# Send approval notification
-notification_data = AgentApprovalData(
-agent_name=submission.name,
-agent_id=submission.agentGraphId,
-agent_version=submission.agentGraphVersion,
-reviewer_name=(
-reviewer.name
-if reviewer and reviewer.name
-else DEFAULT_ADMIN_NAME
-),
-reviewer_email=(
-reviewer.email if reviewer else DEFAULT_ADMIN_EMAIL
-),
-comments=external_comments,
-reviewed_at=submission.reviewedAt
-or datetime.now(tz=timezone.utc),
-store_url=f"{base_url}/marketplace/agent/{store_agent.creator_username}/{store_agent.slug}",
-)
-
-notification_event = NotificationEventModel[AgentApprovalData](
-user_id=agent_creator.id,
-type=prisma.enums.NotificationType.AGENT_APPROVED,
-data=notification_data,
-)
-else:
-# Send rejection notification
-notification_data = AgentRejectionData(
-agent_name=submission.name,
-agent_id=submission.agentGraphId,
-agent_version=submission.agentGraphVersion,
-reviewer_name=(
-reviewer.name
-if reviewer and reviewer.name
-else DEFAULT_ADMIN_NAME
-),
-reviewer_email=(
-reviewer.email if reviewer else DEFAULT_ADMIN_EMAIL
-),
-comments=external_comments,
-reviewed_at=submission.reviewedAt
-or datetime.now(tz=timezone.utc),
-resubmit_url=f"{base_url}/build?flowID={submission.agentGraphId}",
-)
-
-notification_event = NotificationEventModel[AgentRejectionData](
-user_id=agent_creator.id,
-type=prisma.enums.NotificationType.AGENT_REJECTED,
-data=notification_data,
-)
-
-# Queue the notification for immediate sending
-await queue_notification_async(notification_event)
-logger.info(
-f"Queued {'approval' if is_approved else 'rejection'} notification for user {agent_creator.id} and agent {submission.name}"
-)
-
-except Exception as e:
-logger.error(f"Failed to send email notification for agent review: {e}")
-# Don't fail the review process if email sending fails
-pass
-
# Convert to Pydantic model for consistency
return backend.server.v2.store.model.StoreSubmission(
agent_id=submission.agentGraphId,
@@ -162,7 +162,7 @@ async def get_agents(
try:
agents = await backend.server.v2.store.db.get_store_agents(
featured=featured,
-creators=[creator] if creator else None,
+creator=creator,
sorted_by=sorted_by,
search_query=search_query,
category=category,
@@ -66,7 +66,7 @@ def test_get_agents_defaults(
snapshot.assert_match(json.dumps(response.json(), indent=2), "def_agts")
mock_db_call.assert_called_once_with(
featured=False,
-creators=None,
+creator=None,
sorted_by=None,
search_query=None,
category=None,
@@ -113,7 +113,7 @@ def test_get_agents_featured(
snapshot.assert_match(json.dumps(response.json(), indent=2), "feat_agts")
mock_db_call.assert_called_once_with(
featured=True,
-creators=None,
+creator=None,
sorted_by=None,
search_query=None,
category=None,
@@ -160,7 +160,7 @@ def test_get_agents_by_creator(
snapshot.assert_match(json.dumps(response.json(), indent=2), "agts_by_creator")
mock_db_call.assert_called_once_with(
featured=False,
-creators=["specific-creator"],
+creator="specific-creator",
sorted_by=None,
search_query=None,
category=None,
@@ -207,7 +207,7 @@ def test_get_agents_sorted(
snapshot.assert_match(json.dumps(response.json(), indent=2), "agts_sorted")
mock_db_call.assert_called_once_with(
featured=False,
-creators=None,
+creator=None,
sorted_by="runs",
search_query=None,
category=None,
@@ -254,7 +254,7 @@ def test_get_agents_search(
snapshot.assert_match(json.dumps(response.json(), indent=2), "agts_search")
mock_db_call.assert_called_once_with(
featured=False,
-creators=None,
+creator=None,
sorted_by=None,
search_query="specific",
category=None,
@@ -300,7 +300,7 @@ def test_get_agents_category(
snapshot.assert_match(json.dumps(response.json(), indent=2), "agts_category")
mock_db_call.assert_called_once_with(
featured=False,
-creators=None,
+creator=None,
sorted_by=None,
search_query=None,
category="test-category",
@@ -349,7 +349,7 @@ def test_get_agents_pagination(
snapshot.assert_match(json.dumps(response.json(), indent=2), "agts_pagination")
mock_db_call.assert_called_once_with(
featured=False,
-creators=None,
+creator=None,
sorted_by=None,
search_query=None,
category=None,
@@ -40,7 +40,6 @@ class ModerationError(ValueError):
message: str
graph_exec_id: str
moderation_type: str
-content_id: str | None

def __init__(
self,
@@ -48,20 +47,16 @@ class ModerationError(ValueError):
user_id: str,
graph_exec_id: str,
moderation_type: str = "content",
-content_id: str | None = None,
):
super().__init__(message)
-self.args = (message, user_id, graph_exec_id, moderation_type, content_id)
+self.args = (message, user_id, graph_exec_id, moderation_type)
self.message = message
self.user_id = user_id
self.graph_exec_id = graph_exec_id
self.moderation_type = moderation_type
-self.content_id = content_id

def __str__(self):
"""Used to display the error message in the frontend, because we str() the error when sending the execution update"""
-if self.content_id:
-return f"{self.message} (Moderation ID: {self.content_id})"
return self.message

@@ -1,5 +1,4 @@
import logging
-from enum import Enum

import sentry_sdk
from pydantic import SecretStr
@@ -11,11 +10,6 @@ from backend.util.settings import Settings
settings = Settings()


-class DiscordChannel(str, Enum):
-PLATFORM = "platform" # For platform/system alerts
-PRODUCT = "product" # For product alerts (low balance, zero balance, etc.)
-
-
def sentry_init():
sentry_dsn = settings.secrets.sentry_dsn
sentry_sdk.init(
@@ -38,10 +32,8 @@ def sentry_capture_error(error: Exception):
sentry_sdk.flush()


-async def discord_send_alert(
-content: str, channel: DiscordChannel = DiscordChannel.PLATFORM
-):
-from backend.blocks.discord.bot_blocks import SendDiscordMessageBlock
+async def discord_send_alert(content: str):
+from backend.blocks.discord import SendDiscordMessageBlock
from backend.data.model import APIKeyCredentials, CredentialsMetaInput, ProviderName

creds = APIKeyCredentials(
@@ -51,14 +43,6 @@ async def discord_send_alert(
expires_at=None,
)

-# Select channel based on enum
-if channel == DiscordChannel.PLATFORM:
-channel_name = settings.config.platform_alert_discord_channel
-elif channel == DiscordChannel.PRODUCT:
-channel_name = settings.config.product_alert_discord_channel
-else:
-channel_name = settings.config.platform_alert_discord_channel
-
return await SendDiscordMessageBlock().run_once(
SendDiscordMessageBlock.Input(
credentials=CredentialsMetaInput(
@@ -68,7 +52,7 @@ async def discord_send_alert(
provider=ProviderName.DISCORD,
),
message_content=content,
-channel_name=channel_name,
+channel_name=settings.config.platform_alert_discord_channel,
),
"status",
credentials=creds,
@@ -18,12 +18,3 @@ class Pagination(pydantic.BaseModel):
page_size: int = pydantic.Field(
description="Number of items per page.", examples=[25]
)

-@staticmethod
-def empty() -> "Pagination":
-return Pagination(
-total_items=0,
-total_pages=0,
-current_page=0,
-page_size=0,
-)
@@ -48,9 +48,10 @@ class AppProcess(ABC):
"""
pass

+@classmethod
@property
-def service_name(self) -> str:
-return self.__class__.__name__
+def service_name(cls) -> str:
+return cls.__name__

@abstractmethod
def cleanup(self):
@@ -97,11 +97,11 @@ class BaseAppService(AppProcess, ABC):
@classmethod
def get_host(cls) -> str:
source_host = os.environ.get(f"{get_service_name().upper()}_HOST", api_host)
-target_host = os.environ.get(f"{cls.__name__.upper()}_HOST", api_host)
+target_host = os.environ.get(f"{cls.service_name.upper()}_HOST", api_host)

if source_host == target_host and source_host != api_host:
logger.warning(
-f"Service {cls.__name__} is the same host as the source service."
+f"Service {cls.service_name} is the same host as the source service."
f"Use the localhost of {api_host} instead."
)
return api_host
@@ -95,10 +95,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default=500,
description="Maximum number of credits above the balance to be auto-approved.",
)
-low_balance_threshold: int = Field(
-default=500,
-description="Credit threshold for low balance notifications (100 = $1, default 500 = $5)",
-)
refund_notification_email: str = Field(
default="refund@agpt.co",
description="Email address to send refund notifications to.",
@@ -254,10 +250,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default="local-alerts",
description="The Discord channel for the platform",
)
-product_alert_discord_channel: str = Field(
-default="product-alerts",
-description="The Discord channel for product alerts (low balance, zero balance, etc.)",
-)

clamav_service_host: str = Field(
default="localhost",
@@ -473,10 +465,6 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
twitter_client_secret: str = Field(
default="", description="Twitter/X OAuth client secret"
)
-discord_client_id: str = Field(default="", description="Discord OAuth client ID")
-discord_client_secret: str = Field(
-default="", description="Discord OAuth client secret"
-)

openai_api_key: str = Field(default="", description="OpenAI API key")
aiml_api_key: str = Field(default="", description="'AI/ML API' key")
@@ -9,7 +9,6 @@ from backend.data.block import Block, BlockSchema, initialize_blocks
from backend.data.execution import (
ExecutionStatus,
NodeExecutionResult,
-UserContext,
get_graph_execution,
)
from backend.data.model import _BaseCredentials
@@ -139,7 +138,6 @@ async def execute_block_test(block: Block):
"graph_exec_id": str(uuid.uuid4()),
"node_exec_id": str(uuid.uuid4()),
"user_id": str(uuid.uuid4()),
-"user_context": UserContext(timezone="UTC"), # Default for tests
}
input_model = cast(type[BlockSchema], block.input_schema)
credentials_input_fields = input_model.get_credentials_fields()
@@ -1,115 +0,0 @@
-"""
-Time zone name validation and serialization.
-
-This file is adapted from pydantic-extra-types:
-https://github.com/pydantic/pydantic-extra-types/blob/main/pydantic_extra_types/timezone_name.py
-
-The MIT License (MIT)
-
-Copyright (c) 2023 Samuel Colvin and other contributors
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-Modifications:
-- Modified to always use pytz for timezone data to ensure consistency across environments
-- Removed zoneinfo support to prevent environment-specific timezone lists
-"""
-
-from __future__ import annotations
-
-from typing import Any, Callable, cast
-
-import pytz
-from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
-from pydantic_core import PydanticCustomError, core_schema
-
-# Cache the timezones at module level to avoid repeated computation
-ALL_TIMEZONES: set[str] = set(pytz.all_timezones)
-
-
-def get_timezones() -> set[str]:
-"""Get timezones from pytz for consistency across all environments."""
-# Return cached timezone set
-return ALL_TIMEZONES
-
-
-class TimeZoneNameSettings(type):
-def __new__(
-cls, name: str, bases: tuple[type, ...], dct: dict[str, Any], **kwargs: Any
-) -> type[TimeZoneName]:
-dct["strict"] = kwargs.pop("strict", True)
-return cast("type[TimeZoneName]", super().__new__(cls, name, bases, dct))
-
-def __init__(
-cls, name: str, bases: tuple[type, ...], dct: dict[str, Any], **kwargs: Any
-) -> None:
-super().__init__(name, bases, dct)
-cls.strict = kwargs.get("strict", True)
-
-
-def timezone_name_settings(
-**kwargs: Any,
-) -> Callable[[type[TimeZoneName]], type[TimeZoneName]]:
-def wrapper(cls: type[TimeZoneName]) -> type[TimeZoneName]:
-cls.strict = kwargs.get("strict", True)
-return cls
-
-return wrapper
-
-
-@timezone_name_settings(strict=True)
-class TimeZoneName(str):
-"""TimeZoneName is a custom string subclass for validating and serializing timezone names."""
-
-__slots__: list[str] = []
-allowed_values: set[str] = set(get_timezones())
-allowed_values_list: list[str] = sorted(allowed_values)
-allowed_values_upper_to_correct: dict[str, str] = {
-val.upper(): val for val in allowed_values
-}
-strict: bool
-
-@classmethod
-def _validate(
-cls, __input_value: str, _: core_schema.ValidationInfo
-) -> TimeZoneName:
-if __input_value not in cls.allowed_values:
-if not cls.strict:
-upper_value = __input_value.strip().upper()
-if upper_value in cls.allowed_values_upper_to_correct:
-return cls(cls.allowed_values_upper_to_correct[upper_value])
-raise PydanticCustomError("TimeZoneName", "Invalid timezone name.")
-return cls(__input_value)
-
-@classmethod
-def __get_pydantic_core_schema__(
-cls, _: type[Any], __: GetCoreSchemaHandler
-) -> core_schema.AfterValidatorFunctionSchema:
-return core_schema.with_info_after_validator_function(
-cls._validate,
-core_schema.str_schema(min_length=1),
-)
-
-@classmethod
-def __get_pydantic_json_schema__(
-cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
-) -> dict[str, Any]:
-json_schema = handler(schema)
-json_schema.update({"enum": cls.allowed_values_list})
-return json_schema
@@ -1,148 +0,0 @@
-"""
-Timezone conversion utilities for API endpoints.
-Handles conversion between user timezones and UTC for scheduler operations.
-"""
-
-import logging
-from datetime import datetime
-from typing import Optional
-from zoneinfo import ZoneInfo
-
-from croniter import croniter
-
-logger = logging.getLogger(__name__)
-
-
-def convert_cron_to_utc(cron_expr: str, user_timezone: str) -> str:
-"""
-Convert a cron expression from user timezone to UTC.
-
-NOTE: This is a simplified conversion that only adjusts minute and hour fields.
-Complex cron expressions with specific day/month/weekday patterns may not
-convert accurately due to timezone offset variations throughout the year.
-
-Args:
-cron_expr: Cron expression in user timezone
-user_timezone: User's IANA timezone identifier
-
-Returns:
-Cron expression adjusted for UTC execution
-
-Raises:
-ValueError: If timezone or cron expression is invalid
-"""
-try:
-user_tz = ZoneInfo(user_timezone)
-utc_tz = ZoneInfo("UTC")
-
-# Split the cron expression into its five fields
-cron_fields = cron_expr.strip().split()
-if len(cron_fields) != 5:
-raise ValueError(
-"Cron expression must have 5 fields (minute hour day month weekday)"
-)
-
-# Get the current time in the user's timezone
-now_user = datetime.now(user_tz)
-
-# Get the next scheduled time in user timezone
-cron = croniter(cron_expr, now_user)
-next_user_time = cron.get_next(datetime)
-
-# Convert to UTC
-next_utc_time = next_user_time.astimezone(utc_tz)
-
-# Adjust minute and hour fields for UTC, keep day/month/weekday as in original
-utc_cron_parts = [
-str(next_utc_time.minute),
-str(next_utc_time.hour),
-cron_fields[2], # day of month
-cron_fields[3], # month
-cron_fields[4], # day of week
-]
-
-utc_cron = " ".join(utc_cron_parts)
-
-logger.debug(
-f"Converted cron '{cron_expr}' from {user_timezone} to UTC: '{utc_cron}'"
-)
-return utc_cron
-
-except Exception as e:
-logger.error(
-f"Failed to convert cron expression '{cron_expr}' from {user_timezone} to UTC: {e}"
-)
-raise ValueError(f"Invalid cron expression or timezone: {e}")
-
-
-def convert_utc_time_to_user_timezone(utc_time_str: str, user_timezone: str) -> str:
-"""
-Convert a UTC datetime string to user timezone.
-
-Args:
-utc_time_str: ISO format datetime string in UTC
-user_timezone: User's IANA timezone identifier
-
-Returns:
-ISO format datetime string in user timezone
-"""
-try:
-# Parse the time string
-parsed_time = datetime.fromisoformat(utc_time_str.replace("Z", "+00:00"))
-
-user_tz = ZoneInfo(user_timezone)
-
-# If the time already has timezone info, convert it to user timezone
-if parsed_time.tzinfo is not None:
-# Convert to user timezone regardless of source timezone
-user_time = parsed_time.astimezone(user_tz)
-return user_time.isoformat()
-
-# If no timezone info, treat as UTC and convert to user timezone
-parsed_time = parsed_time.replace(tzinfo=ZoneInfo("UTC"))
-user_time = parsed_time.astimezone(user_tz)
-return user_time.isoformat()
-
-except Exception as e:
-logger.error(
-f"Failed to convert UTC time '{utc_time_str}' to {user_timezone}: {e}"
-)
-# Return original time if conversion fails
-return utc_time_str
-
-
-def validate_timezone(timezone: str) -> bool:
-"""
-Validate if a timezone string is a valid IANA timezone identifier.
-
-Args:
-timezone: Timezone string to validate
-
-Returns:
-True if valid, False otherwise
-"""
-try:
-ZoneInfo(timezone)
-return True
-except Exception:
-return False
-
-
-def get_user_timezone_or_utc(user_timezone: Optional[str]) -> str:
-"""
-Get user timezone or default to UTC if invalid/missing.
-
-Args:
-user_timezone: User's timezone preference
-
-Returns:
-Valid timezone string (user's preference or UTC fallback)
-"""
-if not user_timezone or user_timezone == "not-set":
-return "UTC"
-
-if validate_timezone(user_timezone):
-return user_timezone
-
-logger.warning(f"Invalid user timezone '{user_timezone}', falling back to UTC")
-return "UTC"
@@ -1,3 +0,0 @@
--- AlterTable
-ALTER TABLE "User" ADD COLUMN "timezone" TEXT NOT NULL DEFAULT 'not-set'
-CHECK (timezone = 'not-set' OR now() AT TIME ZONE timezone IS NOT NULL);
@@ -1,14 +0,0 @@
--- AlterEnum
--- This migration adds more than one value to an enum.
--- With PostgreSQL versions 11 and earlier, this is not possible
--- in a single migration. This can be worked around by creating
--- multiple migrations, each migration adding only one value to
--- the enum.
-
-
-ALTER TYPE "NotificationType" ADD VALUE 'AGENT_APPROVED';
-ALTER TYPE "NotificationType" ADD VALUE 'AGENT_REJECTED';
-
--- AlterTable
-ALTER TABLE "User" ADD COLUMN "notifyOnAgentApproved" BOOLEAN NOT NULL DEFAULT true,
-ADD COLUMN "notifyOnAgentRejected" BOOLEAN NOT NULL DEFAULT true;
93
autogpt_platform/backend/poetry.lock
generated
@@ -331,66 +331,6 @@ docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphi
tests = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"]
tests-mypy = ["mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\""]

-[[package]]
-name = "audioop-lts"
-version = "0.2.2"
-description = "LTS Port of Python audioop"
-optional = false
-python-versions = ">=3.13"
-groups = ["main"]
-markers = "python_version >= \"3.13\""
-files = [
{file = "audioop_lts-0.2.2-cp313-abi3-macosx_10_13_universal2.whl", hash = "sha256:fd3d4602dc64914d462924a08c1a9816435a2155d74f325853c1f1ac3b2d9800"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-abi3-macosx_10_13_x86_64.whl", hash = "sha256:550c114a8df0aafe9a05442a1162dfc8fec37e9af1d625ae6060fed6e756f303"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-abi3-macosx_11_0_arm64.whl", hash = "sha256:9a13dc409f2564de15dd68be65b462ba0dde01b19663720c68c1140c782d1d75"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-abi3-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:51c916108c56aa6e426ce611946f901badac950ee2ddaf302b7ed35d9958970d"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:47eba38322370347b1c47024defbd36374a211e8dd5b0dcbce7b34fdb6f8847b"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-abi3-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:ba7c3a7e5f23e215cb271516197030c32aef2e754252c4c70a50aaff7031a2c8"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-abi3-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:def246fe9e180626731b26e89816e79aae2276f825420a07b4a647abaa84becc"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-abi3-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e160bf9df356d841bb6c180eeeea1834085464626dc1b68fa4e1d59070affdc3"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:4b4cd51a57b698b2d06cb9993b7ac8dfe89a3b2878e96bc7948e9f19ff51dba6"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-abi3-musllinux_1_2_ppc64le.whl", hash = "sha256:4a53aa7c16a60a6857e6b0b165261436396ef7293f8b5c9c828a3a203147ed4a"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-abi3-musllinux_1_2_riscv64.whl", hash = "sha256:3fc38008969796f0f689f1453722a0f463da1b8a6fbee11987830bfbb664f623"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-abi3-musllinux_1_2_s390x.whl", hash = "sha256:15ab25dd3e620790f40e9ead897f91e79c0d3ce65fe193c8ed6c26cffdd24be7"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:03f061a1915538fd96272bac9551841859dbb2e3bf73ebe4a23ef043766f5449"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-abi3-win32.whl", hash = "sha256:3bcddaaf6cc5935a300a8387c99f7a7fbbe212a11568ec6cf6e4bc458c048636"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-abi3-win_amd64.whl", hash = "sha256:a2c2a947fae7d1062ef08c4e369e0ba2086049a5e598fda41122535557012e9e"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-abi3-win_arm64.whl", hash = "sha256:5f93a5db13927a37d2d09637ccca4b2b6b48c19cd9eda7b17a2e9f77edee6a6f"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:73f80bf4cd5d2ca7814da30a120de1f9408ee0619cc75da87d0641273d202a09"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:106753a83a25ee4d6f473f2be6b0966fc1c9af7e0017192f5531a3e7463dce58"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:fbdd522624141e40948ab3e8cdae6e04c748d78710e9f0f8d4dae2750831de19"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:143fad0311e8209ece30a8dbddab3b65ab419cbe8c0dde6e8828da25999be911"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:dfbbc74ec68a0fd08cfec1f4b5e8cca3d3cd7de5501b01c4b5d209995033cde9"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:cfcac6aa6f42397471e4943e0feb2244549db5c5d01efcd02725b96af417f3fe"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:752d76472d9804ac60f0078c79cdae8b956f293177acd2316cd1e15149aee132"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:83c381767e2cc10e93e40281a04852facc4cd9334550e0f392f72d1c0a9c5753"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:c0022283e9556e0f3643b7c3c03f05063ca72b3063291834cca43234f20c60bb"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:a2d4f1513d63c795e82948e1305f31a6d530626e5f9f2605408b300ae6095093"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:c9c8e68d8b4a56fda8c025e538e639f8c5953f5073886b596c93ec9b620055e7"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:96f19de485a2925314f5020e85911fb447ff5fbef56e8c7c6927851b95533a1c"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:e541c3ef484852ef36545f66209444c48b28661e864ccadb29daddb6a4b8e5f5"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-cp313t-win32.whl", hash = "sha256:d5e73fa573e273e4f2e5ff96f9043858a5e9311e94ffefd88a3186a910c70917"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-cp313t-win_amd64.whl", hash = "sha256:9191d68659eda01e448188f60364c7763a7ca6653ed3f87ebb165822153a8547"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp313-cp313t-win_arm64.whl", hash = "sha256:c174e322bb5783c099aaf87faeb240c8d210686b04bd61dfd05a8e5a83d88969"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:f9ee9b52f5f857fbaf9d605a360884f034c92c1c23021fb90b2e39b8e64bede6"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:49ee1a41738a23e98d98b937a0638357a2477bc99e61b0f768a8f654f45d9b7a"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:5b00be98ccd0fc123dcfad31d50030d25fcf31488cde9e61692029cd7394733b"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:a6d2e0f9f7a69403e388894d4ca5ada5c47230716a03f2847cfc7bd1ecb589d6"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f9b0b8a03ef474f56d1a842af1a2e01398b8f7654009823c6d9e0ecff4d5cfbf"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2b267b70747d82125f1a021506565bdc5609a2b24bcb4773c16d79d2bb260bbd"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:0337d658f9b81f4cd0fdb1f47635070cc084871a3d4646d9de74fdf4e7c3d24a"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:167d3b62586faef8b6b2275c3218796b12621a60e43f7e9d5845d627b9c9b80e"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:0d9385e96f9f6da847f4d571ce3cb15b5091140edf3db97276872647ce37efd7"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:48159d96962674eccdca9a3df280e864e8ac75e40a577cc97c5c42667ffabfc5"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:8fefe5868cd082db1186f2837d64cfbfa78b548ea0d0543e9b28935ccce81ce9"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:58cf54380c3884fb49fdd37dfb7a772632b6701d28edd3e2904743c5e1773602"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:088327f00488cdeed296edd9215ca159f3a5a5034741465789cad403fcf4bec0"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp314-cp314t-win32.whl", hash = "sha256:068aa17a38b4e0e7de771c62c60bbca2455924b67a8814f3b0dee92b5820c0b3"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp314-cp314t-win_amd64.whl", hash = "sha256:a5bf613e96f49712073de86f20dbdd4014ca18efd4d34ed18c75bd808337851b"},
|
|
||||||
{file = "audioop_lts-0.2.2-cp314-cp314t-win_arm64.whl", hash = "sha256:b492c3b040153e68b9fdaff5913305aaaba5bb433d8a7f73d5cf6a64ed3cc1dd"},
|
|
||||||
{file = "audioop_lts-0.2.2.tar.gz", hash = "sha256:64d0c62d88e67b98a1a5e71987b7aa7b5bcffc7dcee65b635823dbdd0a8dbbd0"},
|
|
||||||
-]

[[package]]
name = "autogpt-libs"
version = "0.2.0"
@@ -862,22 +802,6 @@ files = [
{file = "crashtest-0.4.1.tar.gz", hash = "sha256:80d7b1f316ebfbd429f648076d6275c877ba30ba48979de4191714a75266f0ce"},
]

-[[package]]
-name = "croniter"
-version = "6.0.0"
-description = "croniter provides iteration for datetime object with cron like format"
-optional = false
-python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.6"
-groups = ["main"]
-files = [
-{file = "croniter-6.0.0-py2.py3-none-any.whl", hash = "sha256:2f878c3856f17896979b2a4379ba1f09c83e374931ea15cc835c5dd2eee9b368"},
-{file = "croniter-6.0.0.tar.gz", hash = "sha256:37c504b313956114a983ece2c2b07790b1f1094fe9d81cc94739214748255577"},
-]
-
-[package.dependencies]
-python-dateutil = "*"
-pytz = ">2021.1"
-
[[package]]
name = "cryptography"
version = "43.0.3"
@@ -981,7 +905,6 @@ files = [

[package.dependencies]
aiohttp = ">=3.7.4,<4"
-audioop-lts = {version = "*", markers = "python_version >= \"3.13\""}

[package.extras]
dev = ["black (==22.6)", "typing_extensions (>=4.3,<5)"]
@@ -1519,10 +1442,7 @@ grpcio-status = [
{version = ">=1.33.2,<2.0.0", optional = true, markers = "extra == \"grpc\""},
{version = ">=1.49.1,<2.0.0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
]
-proto-plus = [
-{version = ">=1.22.3,<2.0.0"},
-{version = ">=1.25.0,<2.0.0", markers = "python_version >= \"3.13\""},
-]
+proto-plus = ">=1.22.3,<2.0.0"
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0"
requests = ">=2.18.0,<3.0.0"

@@ -1628,10 +1548,7 @@ files = [
[package.dependencies]
google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0", extras = ["grpc"]}
google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0"
-proto-plus = [
-{version = ">=1.22.3,<2.0.0"},
-{version = ">=1.25.0,<2.0.0", markers = "python_version >= \"3.13\""},
-]
+proto-plus = ">=1.22.3,<2.0.0"
protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0"

[[package]]
@@ -1692,7 +1609,6 @@ opentelemetry-api = ">=1.9.0"
proto-plus = [
{version = ">=1.22.0,<2.0.0"},
{version = ">=1.22.2,<2.0.0", markers = "python_version >= \"3.11\""},
-{version = ">=1.25.0,<2.0.0", markers = "python_version >= \"3.13\""},
]
protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0"

@@ -4853,7 +4769,6 @@ grpcio = ">=1.41.0"
httpx = {version = ">=0.20.0", extras = ["http2"]}
numpy = [
{version = ">=1.21", markers = "python_version >= \"3.10\" and python_version < \"3.12\""},
-{version = ">=2.1.0", markers = "python_version >= \"3.13\""},
{version = ">=1.26", markers = "python_version == \"3.12\""},
]
portalocker = ">=2.7.0,<3.0.0"
@@ -6821,5 +6736,5 @@ cffi = ["cffi (>=1.11)"]

[metadata]
lock-version = "2.1"
-python-versions = ">=3.10,<3.14"
+python-versions = ">=3.10,<3.13"
-content-hash = "e780199a6b02f5fef3f930a4f1d69443af1977b591172c3a18a299166345c37a"
+content-hash = "795414d7ce8f288ea6c65893268b5c29a7c9a60ad75cde28ac7bcdb65f426dfe"
@@ -1,6 +1,6 @@
[tool.poetry]
name = "autogpt-platform-backend"
-version = "0.6.22"
+version = "0.4.9"
description = "A platform for building AI-powered agentic workflows"
authors = ["AutoGPT <info@agpt.co>"]
readme = "README.md"
@@ -8,7 +8,7 @@ packages = [{ include = "backend", format = "sdist" }]


[tool.poetry.dependencies]
-python = ">=3.10,<3.14"
+python = ">=3.10,<3.13"
aio-pika = "^9.5.5"
aiohttp = "^3.10.0"
aiodns = "^3.5.0"
@@ -77,7 +77,6 @@ gcloud-aio-storage = "^9.5.0"
pandas = "^2.3.1"
firecrawl-py = "^2.16.3"
exa-py = "^1.14.20"
-croniter = "^6.0.0"

[tool.poetry.group.dev.dependencies]
aiohappyeyeballs = "^2.6.1"
@@ -33,10 +33,6 @@ model User {
notifyOnDailySummary Boolean @default(true)
notifyOnWeeklySummary Boolean @default(true)
notifyOnMonthlySummary Boolean @default(true)
-notifyOnAgentApproved Boolean @default(true)
-notifyOnAgentRejected Boolean @default(true)
-
-timezone String @default("not-set")

// Relations

@@ -191,8 +187,6 @@ enum NotificationType {
MONTHLY_SUMMARY
REFUND_REQUEST
REFUND_PROCESSED
-AGENT_APPROVED
-AGENT_REJECTED
}

model NotificationEvent {
@@ -15,10 +15,6 @@
"type": "object",
"properties": {}
},
-"output_schema": {
-"type": "object",
-"properties": {}
-},
"credentials_input_schema": {
"type": "object",
"properties": {}
@@ -44,10 +40,6 @@
"type": "object",
"properties": {}
},
-"output_schema": {
-"type": "object",
-"properties": {}
-},
"credentials_input_schema": {
"type": "object",
"properties": {}
@@ -27,7 +27,6 @@ export default defineConfig({
usePrefetch: true,
// Will add more as their use cases arise
},
-useDates: true,
operations: {
"getV2List library agents": {
query: {
@@ -35,12 +34,6 @@ export default defineConfig({
useInfiniteQueryParam: "page",
},
},
-"getV1List graph executions": {
-query: {
-useInfinite: true,
-useInfiniteQueryParam: "page",
-},
-},
},
},
},
@@ -41,7 +41,6 @@ import LoadingBox, { LoadingSpinner } from "@/components/ui/loading";
|
|||||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||||
import { AgentRunDetailsView } from "./components/agent-run-details-view";
|
import { AgentRunDetailsView } from "./components/agent-run-details-view";
|
||||||
import { AgentRunDraftView } from "./components/agent-run-draft-view";
|
import { AgentRunDraftView } from "./components/agent-run-draft-view";
|
||||||
import { useAgentRunsInfinite } from "../use-agent-runs";
|
|
||||||
import { AgentRunsSelectorList } from "./components/agent-runs-selector-list";
|
import { AgentRunsSelectorList } from "./components/agent-runs-selector-list";
|
||||||
import { AgentScheduleDetailsView } from "./components/agent-schedule-details-view";
|
import { AgentScheduleDetailsView } from "./components/agent-schedule-details-view";
|
||||||
|
|
||||||
@@ -56,8 +55,7 @@ export function OldAgentLibraryView() {
|
|||||||
|
|
||||||
const [graph, setGraph] = useState<Graph | null>(null); // Graph version corresponding to LibraryAgent
|
const [graph, setGraph] = useState<Graph | null>(null); // Graph version corresponding to LibraryAgent
|
||||||
const [agent, setAgent] = useState<LibraryAgent | null>(null);
|
const [agent, setAgent] = useState<LibraryAgent | null>(null);
|
||||||
const agentRunsQuery = useAgentRunsInfinite(graph?.id); // only runs once graph.id is known
|
const [agentRuns, setAgentRuns] = useState<GraphExecutionMeta[]>([]);
|
||||||
const agentRuns = agentRunsQuery.agentRuns;
|
|
||||||
const [agentPresets, setAgentPresets] = useState<LibraryAgentPreset[]>([]);
|
const [agentPresets, setAgentPresets] = useState<LibraryAgentPreset[]>([]);
|
||||||
const [schedules, setSchedules] = useState<Schedule[]>([]);
|
const [schedules, setSchedules] = useState<Schedule[]>([]);
|
||||||
const [selectedView, selectView] = useState<
|
const [selectedView, selectView] = useState<
|
||||||
@@ -158,22 +156,19 @@ export function OldAgentLibraryView() {
|
|||||||
(graph && graph.version == _graph.version) || setGraph(_graph),
|
(graph && graph.version == _graph.version) || setGraph(_graph),
|
||||||
);
|
);
|
||||||
Promise.all([
|
Promise.all([
|
||||||
agentRunsQuery.refetchRuns(),
|
+ api.getGraphExecutions(agent.graph_id),
api.listLibraryAgentPresets({
graph_id: agent.graph_id,
page_size: 100,
}),
- ]).then(([runsQueryResult, presets]) => {
+ ]).then(([runs, presets]) => {
+ setAgentRuns(runs);
setAgentPresets(presets.presets);

- const newestAgentRunsResponse = runsQueryResult.data?.pages[0];
- if (!newestAgentRunsResponse || newestAgentRunsResponse.status != 200)
- return;
- const newestAgentRuns = newestAgentRunsResponse.data.executions;
// Preload the corresponding graph versions for the latest 10 runs
- new Set(
- newestAgentRuns.slice(0, 10).map((run) => run.graph_version),
- ).forEach((version) => getGraphVersion(agent.graph_id, version));
+ new Set(runs.slice(0, 10).map((run) => run.graph_version)).forEach(
+ (version) => getGraphVersion(agent.graph_id, version),
+ );
});
});
}, [api, agentID, getGraphVersion, graph]);

@@ -192,7 +187,7 @@ export function OldAgentLibraryView() {
else if (!latest.started_at) return latest;
return latest.started_at > current.started_at ? latest : current;
}, agentRuns[0]);
- selectRun(latestRun.id as GraphExecutionID);
+ selectRun(latestRun.id);
} else {
// select top preset
const latestPreset = agentPresets.toSorted(

@@ -282,7 +277,18 @@ export function OldAgentLibraryView() {
incrementRuns();
}

- agentRunsQuery.upsertAgentRun(data);
+ setAgentRuns((prev) => {
+ const index = prev.findIndex((run) => run.id === data.id);
+ if (index === -1) {
+ return [...prev, data];
+ }
+ const newRuns = [...prev];
+ newRuns[index] = { ...newRuns[index], ...data };
+ return newRuns;
+ });
+ if (data.id === selectedView.id) {
+ setSelectedRun((prev) => ({ ...prev, ...data }));
+ }
},
);

@@ -338,7 +344,7 @@ export function OldAgentLibraryView() {
if (selectedView.type == "run" && selectedView.id == run.id) {
openRunDraftView();
}
- agentRunsQuery.removeAgentRun(run.id);
+ setAgentRuns((runs) => runs.filter((r) => r.id !== run.id));
},
[api, selectedView, openRunDraftView],
);

@@ -475,9 +481,9 @@ export function OldAgentLibraryView() {
{/* Sidebar w/ list of runs */}
{/* TODO: render this below header in sm and md layouts */}
<AgentRunsSelectorList
- className="agpt-div w-full border-b pb-2 lg:w-auto lg:border-b-0 lg:border-r lg:pb-0"
+ className="agpt-div w-full border-b lg:w-auto lg:border-b-0 lg:border-r"
agent={agent}
- agentRunsQuery={agentRunsQuery}
+ agentRuns={agentRuns}
agentPresets={agentPresets}
schedules={schedules}
selectedView={selectedView}
@@ -14,19 +14,16 @@ import {
import { cn } from "@/lib/utils";

import { Badge } from "@/components/ui/badge";
- import { Button } from "@/components/atoms/Button/Button";
- import LoadingBox from "@/components/ui/loading";
- import { InfiniteScroll } from "@/components/contextual/InfiniteScroll/InfiniteScroll";
+ import { ScrollArea } from "@/components/ui/scroll-area";
import { Separator } from "@/components/ui/separator";

import { agentRunStatusMap } from "@/components/agents/agent-run-status-chip";
import AgentRunSummaryCard from "@/components/agents/agent-run-summary-card";
- import { AgentRunsQuery } from "../../use-agent-runs";
- import { ScrollArea } from "@/components/ui/scroll-area";
+ import { Button } from "../../../../../../../../components/atoms/Button/Button";

interface AgentRunsSelectorListProps {
agent: LibraryAgent;
- agentRunsQuery: AgentRunsQuery;
+ agentRuns: GraphExecutionMeta[];
agentPresets: LibraryAgentPreset[];
schedules: Schedule[];
selectedView: { type: "run" | "preset" | "schedule"; id?: string };

@@ -43,13 +40,7 @@ interface AgentRunsSelectorListProps {

export function AgentRunsSelectorList({
agent,
- agentRunsQuery: {
- agentRuns,
- agentRunsLoading,
- hasMoreRuns,
- fetchMoreRuns,
- isFetchingMoreRuns,
- },
+ agentRuns,
agentPresets,
schedules,
selectedView,

@@ -75,7 +66,7 @@ export function AgentRunsSelectorList({
}
}, [selectedView]);

- const listItemClasses = "h-28 w-72 lg:w-full lg:h-32";
+ const listItemClasses = "h-28 w-72 lg:h-32 xl:w-80";

return (
<aside className={cn("flex flex-col gap-4", className)}>

@@ -110,99 +101,84 @@ export function AgentRunsSelectorList({
</div>

{/* Runs / Schedules list */}
- {agentRunsLoading && activeListTab === "runs" ? (
- <LoadingBox className="h-28 w-full lg:h-[calc(100vh-300px)] lg:w-72 xl:w-80" />
- ) : (
- <ScrollArea
- className="w-full lg:h-[calc(100vh-300px)] lg:w-72 xl:w-80"
- orientation={window.innerWidth >= 1024 ? "vertical" : "horizontal"}
- >
- <InfiniteScroll
- direction={window.innerWidth >= 1024 ? "vertical" : "horizontal"}
- hasNextPage={hasMoreRuns}
- fetchNextPage={fetchMoreRuns}
- isFetchingNextPage={isFetchingMoreRuns}
- >
- <div className="flex items-center gap-2 lg:flex-col">
- {/* New Run button - only in small layouts */}
- {allowDraftNewRun && (
- <Button
- size="large"
- className={
- "flex h-12 w-40 items-center gap-2 py-6 lg:hidden " +
- (selectedView.type == "run" && !selectedView.id
- ? "agpt-card-selected text-accent"
- : "")
- }
- onClick={onSelectDraftNewRun}
- leftIcon={<Plus className="h-6 w-6" />}
- >
- New {agent.has_external_trigger ? "trigger" : "run"}
- </Button>
- )}
+ <ScrollArea className="lg:h-[calc(100vh-200px)]">
+ <div className="flex gap-2 lg:flex-col">
+ {/* New Run button - only in small layouts */}
+ {allowDraftNewRun && (
+ <Button
+ size="large"
+ className={
+ "flex h-28 w-40 items-center gap-2 py-6 lg:hidden " +
+ (selectedView.type == "run" && !selectedView.id
+ ? "agpt-card-selected text-accent"
+ : "")
+ }
+ onClick={onSelectDraftNewRun}
+ leftIcon={<Plus className="h-6 w-6" />}
+ >
+ New {agent.has_external_trigger ? "trigger" : "run"}
+ </Button>
+ )}

{activeListTab === "runs" ? (
<>
{agentPresets
.toSorted(
(a, b) => b.updated_at.getTime() - a.updated_at.getTime(),
)
.map((preset) => (
<AgentRunSummaryCard
className={cn(listItemClasses, "lg:h-auto")}
key={preset.id}
type="preset"
status={preset.is_active ? "active" : "inactive"}
title={preset.name}
// timestamp={preset.last_run_time} // TODO: implement this
selected={selectedView.id === preset.id}
onClick={() => onSelectPreset(preset.id)}
onDelete={() => doDeletePreset(preset.id)}
/>
))}
{agentPresets.length > 0 && <Separator className="my-1" />}
{agentRuns
.toSorted(
(a, b) => b.started_at.getTime() - a.started_at.getTime(),
)
.map((run) => (
- <AgentRunSummaryCard
- className={listItemClasses}
- key={run.id}
- type="run"
- status={agentRunStatusMap[run.status]}
- title={
- (run.preset_id
- ? agentPresets.find((p) => p.id == run.preset_id)
- ?.name
- : null) ?? agent.name
- }
- timestamp={run.started_at}
- selected={selectedView.id === run.id}
- onClick={() => onSelectRun(run.id)}
- onDelete={() => doDeleteRun(run)}
- />
- ))}
- </>
- ) : (
- schedules.map((schedule) => (
<AgentRunSummaryCard
className={listItemClasses}
- key={schedule.id}
- type="schedule"
- status="scheduled" // TODO: implement active/inactive status for schedules
- title={schedule.name}
- timestamp={schedule.next_run_time}
- selected={selectedView.id === schedule.id}
- onClick={() => onSelectSchedule(schedule.id)}
- onDelete={() => doDeleteSchedule(schedule.id)}
+ key={run.id}
+ type="run"
+ status={agentRunStatusMap[run.status]}
+ title={
+ (run.preset_id
+ ? agentPresets.find((p) => p.id == run.preset_id)?.name
+ : null) ?? agent.name
+ }
+ timestamp={run.started_at}
+ selected={selectedView.id === run.id}
+ onClick={() => onSelectRun(run.id)}
+ onDelete={() => doDeleteRun(run)}
/>
- ))
- )}
- </div>
- </InfiniteScroll>
- </ScrollArea>
- )}
+ ))}
+ </>
+ ) : (
+ schedules.map((schedule) => (
+ <AgentRunSummaryCard
+ className={listItemClasses}
+ key={schedule.id}
+ type="schedule"
+ status="scheduled" // TODO: implement active/inactive status for schedules
+ title={schedule.name}
+ timestamp={schedule.next_run_time}
+ selected={selectedView.id === schedule.id}
+ onClick={() => onSelectSchedule(schedule.id)}
+ onDelete={() => doDeleteSchedule(schedule.id)}
+ />
+ ))
+ )}
+ </div>
+ </ScrollArea>
</aside>
);
}
@@ -18,8 +18,6 @@ import { Input } from "@/components/ui/input";
import LoadingBox from "@/components/ui/loading";
import { useToastOnFail } from "@/components/molecules/Toast/use-toast";
import { humanizeCronExpression } from "@/lib/cron-expression-utils";
- import { formatScheduleTime } from "@/lib/timezone-utils";
- import { useGetV1GetUserTimezone } from "@/app/api/__generated__/endpoints/auth/auth";
import { PlayIcon } from "lucide-react";

export function AgentScheduleDetailsView({

@@ -41,10 +39,6 @@ export function AgentScheduleDetailsView({

const toastOnFail = useToastOnFail();

- // Get user's timezone for displaying schedule times
- const { data: timezoneData } = useGetV1GetUserTimezone();
- const userTimezone = timezoneData?.data?.timezone || "UTC";

const infoStats: { label: string; value: React.ReactNode }[] = useMemo(() => {
return [
{

@@ -55,14 +49,14 @@ export function AgentScheduleDetailsView({
},
{
label: "Schedule",
- value: humanizeCronExpression(schedule.cron, userTimezone),
+ value: humanizeCronExpression(schedule.cron),
},
{
label: "Next run",
- value: formatScheduleTime(schedule.next_run_time, userTimezone),
+ value: schedule.next_run_time.toLocaleString(),
},
];
- }, [schedule, selectedRunStatus, userTimezone]);
+ }, [schedule, selectedRunStatus]);

const agentRunInputs: Record<
string,
@@ -1,201 +0,0 @@
- import {
- getV1ListGraphExecutionsResponse,
- useGetV1ListGraphExecutionsInfinite,
- } from "@/app/api/__generated__/endpoints/graphs/graphs";
- import { GraphExecutionsPaginated } from "@/app/api/__generated__/models/graphExecutionsPaginated";
- import { getQueryClient } from "@/lib/react-query/queryClient";
- import {
- GraphExecutionMeta as LegacyGraphExecutionMeta,
- GraphID,
- GraphExecutionID,
- } from "@/lib/autogpt-server-api";
- import { GraphExecutionMeta as RawGraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
-
- export type GraphExecutionMeta = Omit<
- RawGraphExecutionMeta,
- "id" | "user_id" | "graph_id" | "preset_id" | "stats"
- > &
- Pick<
- LegacyGraphExecutionMeta,
- "id" | "user_id" | "graph_id" | "preset_id" | "stats"
- >;
-
- /** Hook to fetch runs for a specific graph, with support for infinite scroll.
- *
- * @param graphID - The ID of the graph to fetch agent runs for. This parameter is
- * optional in the sense that the hook doesn't run unless it is passed.
- * This way, it can be used in components where the graph ID is not
- * immediately available.
- */
- export const useAgentRunsInfinite = (graphID?: GraphID) => {
- const queryClient = getQueryClient();
- const {
- data: queryResults,
- refetch: refetchRuns,
- isPending: agentRunsLoading,
- isRefetching: agentRunsReloading,
- hasNextPage: hasMoreRuns,
- fetchNextPage: fetchMoreRuns,
- isFetchingNextPage: isFetchingMoreRuns,
- queryKey,
- } = useGetV1ListGraphExecutionsInfinite(
- graphID!,
- { page: 1, page_size: 20 },
- {
- query: {
- getNextPageParam: (lastPage) => {
- const pagination = (lastPage.data as GraphExecutionsPaginated)
- .pagination;
- const hasMore =
- pagination.current_page * pagination.page_size <
- pagination.total_items;
-
- return hasMore ? pagination.current_page + 1 : undefined;
- },
-
- // Prevent query from running if graphID is not available (yet)
- ...(!graphID
- ? {
- enabled: false,
- queryFn: () =>
- // Fake empty response if graphID is not available (yet)
- Promise.resolve({
- status: 200,
- data: {
- executions: [],
- pagination: {
- current_page: 1,
- page_size: 20,
- total_items: 0,
- total_pages: 0,
- },
- },
- headers: new Headers(),
- } satisfies getV1ListGraphExecutionsResponse),
- }
- : {}),
- },
- },
- queryClient,
- );
-
- const agentRuns =
- queryResults?.pages.flatMap((page) => {
- const response = page.data as GraphExecutionsPaginated;
- return response.executions;
- }) ?? [];
-
- const agentRunCount = queryResults?.pages[-1]
- ? (queryResults.pages[-1].data as GraphExecutionsPaginated).pagination
- .total_items
- : 0;
-
- const upsertAgentRun = (newAgentRun: GraphExecutionMeta) => {
- queryClient.setQueryData(
- [queryKey, { page: 1, page_size: 20 }],
- (currentQueryData: typeof queryResults) => {
- if (!currentQueryData?.pages) return currentQueryData;
-
- const exists = currentQueryData.pages.some((page) => {
- const response = page.data as GraphExecutionsPaginated;
- return response.executions.some((run) => run.id === newAgentRun.id);
- });
- if (exists) {
- // If the run already exists, we update it
- return {
- ...currentQueryData,
- pages: currentQueryData.pages.map((page) => {
- const response = page.data as GraphExecutionsPaginated;
- const executions = response.executions;
-
- const index = executions.findIndex(
- (run) => run.id === newAgentRun.id,
- );
- if (index === -1) return page;
-
- const newExecutions = [...executions];
- newExecutions[index] = newAgentRun;
-
- return {
- ...page,
- data: {
- ...response,
- executions: newExecutions,
- },
- };
- }),
- };
- }
-
- // If the run does not exist, we add it to the first page
- const page = currentQueryData.pages[0];
- const updatedExecutions = [
- newAgentRun,
- ...(page.data as GraphExecutionsPaginated).executions,
- ];
- const updatedPage = {
- ...page,
- data: {
- ...page.data,
- executions: updatedExecutions,
- },
- };
- const updatedPages = [updatedPage, ...currentQueryData.pages.slice(1)];
- return {
- ...currentQueryData,
- pages: updatedPages,
- };
- },
- );
- };
-
- const removeAgentRun = (runID: GraphExecutionID) => {
- queryClient.setQueryData(
- [queryKey, { page: 1, page_size: 20 }],
- (currentQueryData: typeof queryResults) => {
- if (!currentQueryData?.pages) return currentQueryData;
-
- let found = false;
- return {
- ...currentQueryData,
- pages: currentQueryData.pages.map((page) => {
- const response = page.data as GraphExecutionsPaginated;
- const filteredExecutions = response.executions.filter(
- (run) => run.id !== runID,
- );
- if (filteredExecutions.length < response.executions.length) {
- found = true;
- }
-
- return {
- ...page,
- data: {
- ...response,
- executions: filteredExecutions,
- pagination: {
- ...response.pagination,
- total_items:
- response.pagination.total_items - (found ? 1 : 0),
- },
- },
- };
- }),
- };
- },
- );
- };
-
- return {
- agentRuns: agentRuns as GraphExecutionMeta[],
- refetchRuns,
- agentRunCount,
- agentRunsLoading: agentRunsLoading || agentRunsReloading,
- hasMoreRuns,
- fetchMoreRuns,
- isFetchingMoreRuns,
- upsertAgentRun,
- removeAgentRun,
- };
- };
-
- export type AgentRunsQuery = ReturnType<typeof useAgentRunsInfinite>;
@@ -28,6 +28,7 @@ export default function LibraryAgentList() {
</div>
) : (
<InfiniteScroll
+ dataLength={agents.length}
isFetchingNextPage={isFetchingNextPage}
fetchNextPage={fetchNextPage}
hasNextPage={hasNextPage}
@@ -45,7 +45,6 @@ export default function LibraryUploadAgentDialog(): React.ReactNode {
form,
setisDroped,
agentObject,
- clearAgentFile,
} = useLibraryUploadAgentDialog();
return (
<Dialog open={isOpen} onOpenChange={setIsOpen}>

@@ -106,7 +105,9 @@ export default function LibraryUploadAgentDialog(): React.ReactNode {
<div className="relative flex rounded-[10px] border p-2 font-sans text-sm font-medium text-[#525252] outline-none">
<span className="line-clamp-1">{field.value.name}</span>
<Button
- onClick={clearAgentFile}
+ onClick={() =>
+ form.setValue("agentFile", undefined as any)
+ }
className="absolute left-[-10px] top-[-16px] mt-2 h-fit border-none bg-red-200 p-1"
>
<X
@@ -102,22 +102,6 @@ export const useLibraryUploadAgentDialog = () => {
setisDroped(false);
};

- const clearAgentFile = () => {
- const currentName = form.getValues("agentName");
- const currentDescription = form.getValues("agentDescription");
- const prevAgent = agentObject;
-
- form.setValue("agentFile", undefined as any);
- if (prevAgent && currentName === prevAgent.name) {
- form.setValue("agentName", "");
- }
- if (prevAgent && currentDescription === prevAgent.description) {
- form.setValue("agentDescription", "");
- }
-
- setAgentObject(null);
- };

return {
onSubmit,
isUploading,

@@ -128,6 +112,5 @@ export const useLibraryUploadAgentDialog = () => {
isDroped,
handleChange,
setisDroped,
- clearAgentFile,
};
};
@@ -1,4 +1,11 @@
"use client";
+ import Link from "next/link";
+
+ import { Alert, AlertDescription } from "@/components/ui/alert";
+ import {
+ ArrowBottomRightIcon,
+ QuestionMarkCircledIcon,
+ } from "@radix-ui/react-icons";
+
import LibraryActionHeader from "./components/LibraryActionHeader/LibraryActionHeader";
import LibraryAgentList from "./components/LibraryAgentList/LibraryAgentList";

@@ -15,6 +22,21 @@ export default function LibraryPage() {
<LibraryActionHeader />
<LibraryAgentList />
</LibraryPageStateProvider>

+ <Alert
+ variant="default"
+ className="fixed bottom-2 left-1/2 hidden max-w-4xl -translate-x-1/2 md:block"
+ >
+ <AlertDescription className="text-center">
+ Prefer the old experience? Click{" "}
+ <Link href="/monitoring" className="underline">
+ here
+ </Link>{" "}
+ to go to it. Please do let us know why by clicking the{" "}
+ <QuestionMarkCircledIcon className="inline-block size-6 rounded-full bg-[rgba(65,65,64,1)] p-1 align-bottom text-neutral-50" />{" "}
+ in the bottom right corner <ArrowBottomRightIcon className="inline" />
+ </AlertDescription>
+ </Alert>
</main>
);
}
@@ -13,7 +13,7 @@ export interface AgentTableCardProps {
sub_heading: string;
description: string;
imageSrc: string[];
- dateSubmitted: Date;
+ dateSubmitted: string;
status: SubmissionStatus;
runs: number;
rating: number;

@@ -80,7 +80,7 @@ export const AgentTableCard = ({
<div className="mt-4 flex flex-wrap gap-4">
<Status status={status} />
<div className="text-sm text-neutral-600 dark:text-neutral-400">
- {dateSubmitted.toLocaleDateString()}
+ {dateSubmitted}
</div>
<div className="text-sm text-neutral-600 dark:text-neutral-400">
{runs.toLocaleString()} runs
@@ -25,10 +25,11 @@ export interface AgentTableRowProps {
sub_heading: string;
description: string;
imageSrc: string[];
- dateSubmitted: Date;
+ date_submitted: string;
status: SubmissionStatus;
runs: number;
rating: number;
+ dateSubmitted: string;
id: number;
video_url?: string;
categories?: string[];

@@ -129,7 +130,7 @@ export const AgentTableRow = ({

{/* Date column */}
<div className="text-sm text-neutral-600 dark:text-neutral-400">
- {dateSubmitted.toLocaleDateString()}
+ {dateSubmitted}
</div>

{/* Status column */}
@@ -18,7 +18,7 @@ interface useAgentTableRowProps {
sub_heading: string;
description: string;
imageSrc: string[];
- dateSubmitted: Date;
+ dateSubmitted: string;
status: SubmissionStatus;
runs: number;
rating: number;
@@ -86,10 +86,13 @@ export const MainDashboardPage = () => {
agent_id: submission.agent_id,
agent_version: submission.agent_version,
sub_heading: submission.sub_heading,
+ date_submitted: submission.date_submitted,
agentName: submission.name,
description: submission.description,
imageSrc: submission.image_urls || [""],
- dateSubmitted: submission.date_submitted,
+ dateSubmitted: new Date(
+ submission.date_submitted,
+ ).toLocaleDateString(),
status: submission.status,
runs: submission.runs,
rating: submission.rating,
@@ -5,25 +5,17 @@ import { NotificationPreference } from "@/app/api/__generated__/models/notificat
import { User } from "@supabase/supabase-js";
import { EmailForm } from "./components/EmailForm/EmailForm";
import { NotificationForm } from "./components/NotificationForm/NotificationForm";
- import { TimezoneForm } from "./components/TimezoneForm/TimezoneForm";

type SettingsFormProps = {
preferences: NotificationPreference;
user: User;
- timezone?: string;
};

- export function SettingsForm({
- preferences,
- user,
- timezone,
- }: SettingsFormProps) {
+ export function SettingsForm({ preferences, user }: SettingsFormProps) {
return (
<div className="flex flex-col gap-8">
<EmailForm user={user} />
<Separator />
- <TimezoneForm user={user} currentTimezone={timezone} />
- <Separator />
<NotificationForm preferences={preferences} user={user} />
</div>
);
@@ -106,60 +106,6 @@ export function NotificationForm({ preferences, user }: NotificationFormProps) {
/>
</div>

- {/* Store Notifications */}
- <div className="flex flex-col gap-6">
- <Text variant="h4" size="body-medium" className="text-slate-400">
- Store Notifications
- </Text>
- <FormField
- control={form.control}
- name="notifyOnAgentApproved"
- render={({ field }) => (
- <FormItem className="flex flex-row items-center justify-between">
- <div className="space-y-0.5">
- <Text variant="h4" size="body-medium">
- Agent Approved
- </Text>
- <Text variant="body">
- Get notified when your submitted agent is approved for the
- store
- </Text>
- </div>
- <FormControl>
- <Switch
- checked={field.value}
- onCheckedChange={field.onChange}
- />
- </FormControl>
- </FormItem>
- )}
- />
-
- <FormField
- control={form.control}
- name="notifyOnAgentRejected"
- render={({ field }) => (
- <FormItem className="flex flex-row items-center justify-between">
- <div className="space-y-0.5">
- <Text variant="h4" size="body-medium">
- Agent Rejected
- </Text>
- <Text variant="body">
- Receive notifications when your agent submission needs
- updates
- </Text>
- </div>
- <FormControl>
- <Switch
- checked={field.value}
- onCheckedChange={field.onChange}
- />
- </FormControl>
- </FormItem>
- )}
- />
- </div>

{/* Balance Notifications */}
<div className="flex flex-col gap-4">
<Text variant="h4" size="body-medium" className="text-slate-400">
@@ -18,8 +18,6 @@ const notificationFormSchema = z.object({
notifyOnDailySummary: z.boolean(),
notifyOnWeeklySummary: z.boolean(),
notifyOnMonthlySummary: z.boolean(),
- notifyOnAgentApproved: z.boolean(),
- notifyOnAgentRejected: z.boolean(),
});

function createNotificationDefaultValues(preferences: {

@@ -36,8 +34,6 @@ function createNotificationDefaultValues(preferences: {
notifyOnDailySummary: preferences.preferences?.DAILY_SUMMARY,
notifyOnWeeklySummary: preferences.preferences?.WEEKLY_SUMMARY,
notifyOnMonthlySummary: preferences.preferences?.MONTHLY_SUMMARY,
- notifyOnAgentApproved: preferences.preferences?.AGENT_APPROVED,
- notifyOnAgentRejected: preferences.preferences?.AGENT_REJECTED,
};
}

@@ -84,8 +80,6 @@ export function useNotificationForm({
DAILY_SUMMARY: values.notifyOnDailySummary,
WEEKLY_SUMMARY: values.notifyOnWeeklySummary,
MONTHLY_SUMMARY: values.notifyOnMonthlySummary,
- AGENT_APPROVED: values.notifyOnAgentApproved,
- AGENT_REJECTED: values.notifyOnAgentRejected,
},
daily_limit: 0,
};
@@ -1,126 +0,0 @@
- "use client";
-
- import * as React from "react";
- import { useTimezoneForm } from "./useTimezoneForm";
- import { User } from "@supabase/supabase-js";
- import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
- import { Button } from "@/components/atoms/Button/Button";
- import {
- Select,
- SelectContent,
- SelectItem,
- SelectTrigger,
- SelectValue,
- } from "@/components/ui/select";
- import {
- Form,
- FormControl,
- FormField,
- FormItem,
- FormLabel,
- FormMessage,
- } from "@/components/ui/form";
-
- type TimezoneFormProps = {
- user: User;
- currentTimezone?: string;
- };
-
- // Common timezones list - can be expanded later
- const TIMEZONES = [
- { value: "UTC", label: "UTC (Coordinated Universal Time)" },
- { value: "America/New_York", label: "Eastern Time (US & Canada)" },
- { value: "America/Chicago", label: "Central Time (US & Canada)" },
- { value: "America/Denver", label: "Mountain Time (US & Canada)" },
- { value: "America/Los_Angeles", label: "Pacific Time (US & Canada)" },
- { value: "America/Phoenix", label: "Arizona (US)" },
- { value: "America/Anchorage", label: "Alaska (US)" },
- { value: "Pacific/Honolulu", label: "Hawaii (US)" },
- { value: "Europe/London", label: "London (UK)" },
- { value: "Europe/Paris", label: "Paris (France)" },
- { value: "Europe/Berlin", label: "Berlin (Germany)" },
- { value: "Europe/Moscow", label: "Moscow (Russia)" },
- { value: "Asia/Dubai", label: "Dubai (UAE)" },
- { value: "Asia/Kolkata", label: "India Standard Time" },
- { value: "Asia/Shanghai", label: "China Standard Time" },
- { value: "Asia/Tokyo", label: "Tokyo (Japan)" },
- { value: "Asia/Seoul", label: "Seoul (South Korea)" },
- { value: "Asia/Singapore", label: "Singapore" },
- { value: "Australia/Sydney", label: "Sydney (Australia)" },
- { value: "Australia/Melbourne", label: "Melbourne (Australia)" },
- { value: "Pacific/Auckland", label: "Auckland (New Zealand)" },
- { value: "America/Toronto", label: "Toronto (Canada)" },
- { value: "America/Vancouver", label: "Vancouver (Canada)" },
- { value: "America/Mexico_City", label: "Mexico City (Mexico)" },
- { value: "America/Sao_Paulo", label: "São Paulo (Brazil)" },
- { value: "America/Buenos_Aires", label: "Buenos Aires (Argentina)" },
- { value: "Africa/Cairo", label: "Cairo (Egypt)" },
- { value: "Africa/Johannesburg", label: "Johannesburg (South Africa)" },
- ];
-
- export function TimezoneForm({
- user,
- currentTimezone = "not-set",
- }: TimezoneFormProps) {
- // If timezone is not set, try to detect it from the browser
- const effectiveTimezone = React.useMemo(() => {
- if (currentTimezone === "not-set") {
- // Try to get browser timezone as a suggestion
- try {
- return Intl.DateTimeFormat().resolvedOptions().timeZone || "UTC";
- } catch {
- return "UTC";
- }
- }
- return currentTimezone;
- }, [currentTimezone]);
-
- const { form, onSubmit, isLoading } = useTimezoneForm({
- user,
- currentTimezone: effectiveTimezone,
- });
-
- return (
- <Card>
- <CardHeader>
- <CardTitle>Timezone</CardTitle>
- </CardHeader>
- <CardContent>
- <Form {...form}>
- <form onSubmit={form.handleSubmit(onSubmit)} className="space-y-6">
- <FormField
- control={form.control}
- name="timezone"
- render={({ field }) => (
- <FormItem>
- <FormLabel>Select your timezone</FormLabel>
- <Select
- onValueChange={field.onChange}
- defaultValue={field.value}
- >
- <FormControl>
- <SelectTrigger>
- <SelectValue placeholder="Select a timezone" />
- </SelectTrigger>
- </FormControl>
- <SelectContent>
- {TIMEZONES.map((tz) => (
- <SelectItem key={tz.value} value={tz.value}>
- {tz.label}
- </SelectItem>
- ))}
- </SelectContent>
- </Select>
- <FormMessage />
- </FormItem>
- )}
- />
- <Button type="submit" disabled={isLoading}>
- {isLoading ? "Saving..." : "Save timezone"}
- </Button>
- </form>
- </Form>
- </CardContent>
- </Card>
- );
- }
@@ -1,73 +0,0 @@
- "use client";
-
- import { useState } from "react";
- import { useForm } from "react-hook-form";
- import { z } from "zod";
- import { zodResolver } from "@hookform/resolvers/zod";
- import { useToast } from "@/components/molecules/Toast/use-toast";
- import { User } from "@supabase/supabase-js";
- import {
- usePostV1UpdateUserTimezone,
- getGetV1GetUserTimezoneQueryKey,
- } from "@/app/api/__generated__/endpoints/auth/auth";
- import { useQueryClient } from "@tanstack/react-query";
-
- const formSchema = z.object({
- timezone: z.string().min(1, "Please select a timezone"),
- });
-
- type FormData = z.infer<typeof formSchema>;
-
- type UseTimezoneFormProps = {
- user: User;
- currentTimezone: string;
- };
-
- export const useTimezoneForm = ({ currentTimezone }: UseTimezoneFormProps) => {
- const [isLoading, setIsLoading] = useState(false);
- const { toast } = useToast();
- const queryClient = useQueryClient();
-
- const form = useForm<FormData>({
- resolver: zodResolver(formSchema),
- defaultValues: {
- timezone: currentTimezone,
- },
- });
-
- const updateTimezone = usePostV1UpdateUserTimezone();
-
- const onSubmit = async (data: FormData) => {
- setIsLoading(true);
- try {
- await updateTimezone.mutateAsync({
- data: { timezone: data.timezone } as any,
- });
-
- // Invalidate the timezone query to refetch the updated value
- await queryClient.invalidateQueries({
- queryKey: getGetV1GetUserTimezoneQueryKey(),
- });
-
- toast({
- title: "Success",
- description: "Your timezone has been updated successfully.",
- variant: "success",
- });
- } catch {
- toast({
- title: "Error",
- description: "Failed to update timezone. Please try again.",
- variant: "destructive",
- });
- } finally {
- setIsLoading(false);
- }
- };
-
- return {
- form,
- onSubmit,
- isLoading,
- };
- };
@@ -1,11 +1,7 @@
"use client";
- import {
- useGetV1GetNotificationPreferences,
- useGetV1GetUserTimezone,
- } from "@/app/api/__generated__/endpoints/auth/auth";
+ import { useGetV1GetNotificationPreferences } from "@/app/api/__generated__/endpoints/auth/auth";
import { SettingsForm } from "@/app/(platform)/profile/(user)/settings/components/SettingsForm/SettingsForm";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
- import { useTimezoneDetection } from "@/hooks/useTimezoneDetection";
import * as React from "react";
import SettingsLoading from "./loading";
import { redirect } from "next/navigation";

@@ -14,8 +10,8 @@ import { Text } from "@/components/atoms/Text/Text";
export default function SettingsPage() {
const {
data: preferences,
- isError: preferencesError,
- isLoading: preferencesLoading,
+ isError,
+ isLoading,
} = useGetV1GetNotificationPreferences({
query: {
select: (res) => {

@@ -24,24 +20,9 @@ export default function SettingsPage() {
},
});

- const { data: timezoneData, isLoading: timezoneLoading } =
- useGetV1GetUserTimezone({
- query: {
- select: (res) => {
- return res.data;
- },
- },
- });
-
const { user, isUserLoading } = useSupabase();

- // Auto-detect timezone if it's not set
- const timezone = timezoneData?.timezone
- ? String(timezoneData.timezone)
- : "not-set";
- useTimezoneDetection(timezone);
-
- if (preferencesLoading || isUserLoading || timezoneLoading) {
+ if (isLoading || isUserLoading) {
return <SettingsLoading />;
}

@@ -49,7 +30,7 @@ export default function SettingsPage() {
redirect("/login");
}

- if (preferencesError || !preferences || !preferences.preferences) {
+ if (isError || !preferences || !preferences.preferences) {
return "Errror..."; // TODO: Will use a Error reusable components from Block Menu redesign
}

@@ -61,7 +42,7 @@ export default function SettingsPage() {
Manage your account settings and preferences.
</Text>
</div>
- <SettingsForm preferences={preferences} user={user} timezone={timezone} />
+ <SettingsForm preferences={preferences} user={user} />
</div>
);
}
@@ -5,8 +5,6 @@ import {
import { isServerSide } from "@/lib/utils/is-server-side";
import { getAgptServerBaseUrl } from "@/lib/env-config";

- import { transformDates } from "./date-transformer";
-
const FRONTEND_BASE_URL =
process.env.NEXT_PUBLIC_FRONTEND_BASE_URL || "http://localhost:3000";
const API_PROXY_BASE_URL = `${FRONTEND_BASE_URL}/api/proxy`; // Sending request via nextjs Server

@@ -100,12 +98,9 @@ export const customMutator = async <T = any>(

const response_data = await getBody<T>(response);

- // Transform ISO date strings to Date objects in the response data
- const transformedData = transformDates(response_data);
-
return {
status: response.status,
- data: transformedData,
+ data: response_data,
headers: response.headers,
} as T;
};
@@ -1,51 +0,0 @@
- /**
- * Date transformation utility for converting ISO date strings to Date objects
- * in API responses. This handles the conversion recursively for nested objects.
- */
-
- // ISO date regex pattern to match strings that look like ISO dates
- const ISO_DATE_REGEX = /^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(\.\d+)?Z?$/;
-
- /**
- * Validates if a string is a valid ISO date and can be parsed
- */
- function isValidISODate(dateString: string): boolean {
- if (!ISO_DATE_REGEX.test(dateString)) {
- return false;
- }
-
- const date = new Date(dateString);
- return !isNaN(date.getTime());
- }
-
- /**
- * Recursively transforms ISO date strings to Date objects in an object or array
- * @param obj - The object or array to transform
- * @returns The transformed object with Date objects
- */
- export function transformDates<T>(obj: T): T {
- if (typeof obj !== "object" || obj === null) return obj;
-
- // Handle arrays
- if (Array.isArray(obj)) {
- return obj.map(transformDates) as T;
- }
-
- // Handle objects
- const transformed = {} as T;
-
- for (const [key, value] of Object.entries(obj)) {
- if (typeof value === "string" && isValidISODate(value)) {
- // Convert ISO date string to Date object
- (transformed as any)[key] = new Date(value);
- } else if (typeof value === "object") {
- // Recursively transform nested objects/arrays
- (transformed as any)[key] = transformDates(value);
- } else {
- // Keep primitive values as-is
- (transformed as any)[key] = value;
- }
- }
-
- return transformed;
- }
File diff suppressed because it is too large
@@ -98,23 +98,8 @@ function createResponse(
}
}

- function createErrorResponse(
- error: unknown,
- path: string,
- method: string,
- ): NextResponse {
- if (
- error &&
- typeof error === "object" &&
- "status" in error &&
- [401, 403].includes(error.status as number)
- ) {
- // Log this since it indicates a potential frontend bug
- console.warn(
- `Authentication error in API proxy for ${method} ${path}:`,
- "message" in error ? error.message : error,
- );
- }
-
+ function createErrorResponse(error: unknown): NextResponse {
+ console.error("API proxy error:", error);
// If it's our custom ApiError, preserve the original status and response
if (error instanceof ApiError) {

@@ -162,6 +147,7 @@ async function handler(
const contentType = req.headers.get("Content-Type");

let responseBody: any;
+ const responseStatus: number = 200;
const responseHeaders: Record<string, string> = {
"Content-Type": "application/json",
};

@@ -180,13 +166,9 @@ async function handler(
return createUnsupportedContentTypeResponse(contentType);
}

- return createResponse(responseBody, 200, responseHeaders);
+ return createResponse(responseBody, responseStatus, responseHeaders);
} catch (error) {
- return createErrorResponse(
- error,
- path.map((s) => `/${s}`).join(""),
- method,
- );
+ return createErrorResponse(error);
}
}
Some files were not shown because too many files have changed in this diff