Compare commits

..

1 Commits

Author SHA1 Message Date
Zamil Majdy
e9d846eebb feat(backend): migrate AgentExecutor from ProcessPoolExecutor to ThreadPoolExecutor
- Migrate execution manager from ProcessPoolExecutor to ThreadPoolExecutor for improved performance and resource efficiency
- Rename `Executor` class to `ExecutionProcessor` for better clarity
- Convert classmethods to instance methods following proper OOP design patterns
- Implement thread-local storage using `threading.local()` for thread-safe execution
- Replace process ID tracking with thread ID tracking using `threading.get_ident()`
- Replace `multiprocessing.Manager().Event()` with `threading.Event()`
- Remove signal handling code that doesn't work in worker threads
- Update ExecutionManager to use ThreadPoolExecutor with new `init_worker` initializer

Benefits:
- Performance: Reduced overhead compared to process creation/destruction
- Resource Efficiency: Lower memory footprint and faster startup
- Simplicity: Cleaner implementation using thread-local storage pattern
- Thread Safety: Maintained through isolated ExecutionProcessor instances per thread

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-07 06:53:32 +07:00
579 changed files with 27712 additions and 16271 deletions

View File

@@ -15,7 +15,6 @@
!autogpt_platform/backend/pyproject.toml
!autogpt_platform/backend/poetry.lock
!autogpt_platform/backend/README.md
!autogpt_platform/backend/.env
# Platform - Market
!autogpt_platform/market/market/
@@ -28,7 +27,6 @@
# Platform - Frontend
!autogpt_platform/frontend/src/
!autogpt_platform/frontend/public/
!autogpt_platform/frontend/scripts/
!autogpt_platform/frontend/package.json
!autogpt_platform/frontend/pnpm-lock.yaml
!autogpt_platform/frontend/tsconfig.json
@@ -36,7 +34,6 @@
## config
!autogpt_platform/frontend/*.config.*
!autogpt_platform/frontend/.env.*
!autogpt_platform/frontend/.env
# Classic - AutoGPT
!classic/original_autogpt/autogpt/

View File

@@ -24,8 +24,7 @@
</details>
#### For configuration changes:
- [ ] `.env.default` is updated or already compatible with my changes
- [ ] `.env.example` is updated or already compatible with my changes
- [ ] `docker-compose.yml` is updated or already compatible with my changes
- [ ] I have included a list of my configuration changes in the PR description (under **Changes**)

View File

@@ -82,6 +82,37 @@ jobs:
- name: Run lint
run: pnpm lint
type-check:
runs-on: ubuntu-latest
needs: setup
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "21"
- name: Enable corepack
run: corepack enable
- name: Restore dependencies cache
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Run tsc check
run: pnpm type-check
chromatic:
runs-on: ubuntu-latest
needs: setup
@@ -145,7 +176,11 @@ jobs:
- name: Copy default supabase .env
run: |
cp ../.env.default ../.env
cp ../.env.example ../.env
- name: Copy backend .env
run: |
cp ../backend/.env.example ../backend/.env
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -217,6 +252,15 @@ jobs:
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Setup .env
run: cp .env.example .env
- name: Build frontend
run: pnpm build --turbo
# uses Turbopack, much faster and safe enough for a test pipeline
env:
NEXT_PUBLIC_PW_TEST: true
- name: Install Browser 'chromium'
run: pnpm playwright install --with-deps chromium

View File

@@ -1,132 +0,0 @@
name: AutoGPT Platform - Frontend CI
on:
push:
branches: [master, dev]
paths:
- ".github/workflows/platform-fullstack-ci.yml"
- "autogpt_platform/**"
pull_request:
paths:
- ".github/workflows/platform-fullstack-ci.yml"
- "autogpt_platform/**"
merge_group:
defaults:
run:
shell: bash
working-directory: autogpt_platform/frontend
jobs:
setup:
runs-on: ubuntu-latest
outputs:
cache-key: ${{ steps.cache-key.outputs.key }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "21"
- name: Enable corepack
run: corepack enable
- name: Generate cache key
id: cache-key
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
- name: Cache dependencies
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ steps.cache-key.outputs.key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
types:
runs-on: ubuntu-latest
needs: setup
strategy:
fail-fast: false
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "21"
- name: Enable corepack
run: corepack enable
- name: Copy default supabase .env
run: |
cp ../.env.default ../.env
- name: Copy backend .env
run: |
cp ../backend/.env.default ../backend/.env
- name: Run docker compose
run: |
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
- name: Restore dependencies cache
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Setup .env
run: cp .env.default .env
- name: Wait for services to be ready
run: |
echo "Waiting for rest_server to be ready..."
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
echo "Waiting for database to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
- name: Generate API queries
run: pnpm generate:api:force
- name: Check for API schema changes
run: |
if ! git diff --exit-code src/app/api/openapi.json; then
echo "❌ API schema changes detected in src/app/api/openapi.json"
echo ""
echo "The openapi.json file has been modified after running 'pnpm generate:api-all'."
echo "This usually means changes have been made in the BE endpoints without updating the Frontend."
echo "The API schema is now out of sync with the Front-end queries."
echo ""
echo "To fix this:"
echo "1. Pull the backend 'docker compose pull && docker compose up -d --build --force-recreate'"
echo "2. Run 'pnpm generate:api' locally"
echo "3. Run 'pnpm types' locally"
echo "4. Fix any TypeScript errors that may have been introduced"
echo "5. Commit and push your changes"
echo ""
exit 1
else
echo "✅ No API schema changes detected"
fi
- name: Run Typescript checks
run: pnpm types

3
.gitignore vendored
View File

@@ -5,8 +5,6 @@ classic/original_autogpt/*.json
auto_gpt_workspace/*
*.mpeg
.env
# Root .env files
/.env
azure.yaml
.vscode
.idea/*
@@ -123,6 +121,7 @@ celerybeat.pid
# Environments
.direnv/
.env
.venv
env/
venv*/

View File

@@ -235,7 +235,7 @@ repos:
hooks:
- id: tsc
name: Typecheck - AutoGPT Platform - Frontend
entry: bash -c 'cd autogpt_platform/frontend && pnpm types'
entry: bash -c 'cd autogpt_platform/frontend && pnpm type-check'
files: ^autogpt_platform/frontend/
types: [file]
language: system

View File

@@ -3,16 +3,6 @@
[![Discord Follow](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fautogpt%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&label=total%20members&logo=discord&logoColor=white&color=7289da)](https://discord.gg/autogpt) &ensp;
[![Twitter Follow](https://img.shields.io/twitter/follow/Auto_GPT?style=social)](https://twitter.com/Auto_GPT) &ensp;
<!-- Keep these links. Translations will automatically update with the README. -->
[Deutsch](https://zdoc.app/de/Significant-Gravitas/AutoGPT) |
[Español](https://zdoc.app/es/Significant-Gravitas/AutoGPT) |
[français](https://zdoc.app/fr/Significant-Gravitas/AutoGPT) |
[日本語](https://zdoc.app/ja/Significant-Gravitas/AutoGPT) |
[한국어](https://zdoc.app/ko/Significant-Gravitas/AutoGPT) |
[Português](https://zdoc.app/pt/Significant-Gravitas/AutoGPT) |
[Русский](https://zdoc.app/ru/Significant-Gravitas/AutoGPT) |
[中文](https://zdoc.app/zh/Significant-Gravitas/AutoGPT)
**AutoGPT** is a powerful platform that allows you to create, deploy, and manage continuous AI agents that automate complex workflows.
## Hosting Options

View File

@@ -1,11 +1,9 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Repository Overview
AutoGPT Platform is a monorepo containing:
- **Backend** (`/backend`): Python FastAPI server with async support
- **Frontend** (`/frontend`): Next.js React application
- **Shared Libraries** (`/autogpt_libs`): Common Python utilities
@@ -13,7 +11,6 @@ AutoGPT Platform is a monorepo containing:
## Essential Commands
### Backend Development
```bash
# Install dependencies
cd backend && poetry install
@@ -33,18 +30,11 @@ poetry run test
# Run specific test
poetry run pytest path/to/test_file.py::test_function_name
# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
# Lint and format
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
poetry run format # Black + isort
poetry run lint # ruff
```
More details can be found in TESTING.md
#### Creating/Updating Snapshots
@@ -57,8 +47,8 @@ poetry run pytest path/to/test.py --snapshot-update
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
### Frontend Development
### Frontend Development
```bash
# Install dependencies
cd frontend && npm install
@@ -76,13 +66,12 @@ npm run storybook
npm run build
# Type checking
npm run types
npm run type-check
```
## Architecture Overview
### Backend Architecture
- **API Layer**: FastAPI with REST and WebSocket endpoints
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
- **Queue System**: RabbitMQ for async task processing
@@ -91,7 +80,6 @@ npm run types
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
### Frontend Architecture
- **Framework**: Next.js App Router with React Server Components
- **State Management**: React hooks + Supabase client for real-time updates
- **Workflow Builder**: Visual graph editor using @xyflow/react
@@ -99,7 +87,6 @@ npm run types
- **Feature Flags**: LaunchDarkly integration
### Key Concepts
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks
3. **Integrations**: OAuth and API connections stored per user
@@ -107,16 +94,13 @@ npm run types
5. **Virus Scanning**: ClamAV integration for file upload security
### Testing Approach
- Backend uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
- Frontend uses Playwright for E2E tests
- Component testing via Storybook
### Database Schema
Key models (defined in `/backend/schema.prisma`):
- `User`: Authentication and profile data
- `AgentGraph`: Workflow definitions with version control
- `AgentGraphExecution`: Execution history and results
@@ -124,31 +108,13 @@ Key models (defined in `/backend/schema.prisma`):
- `StoreListing`: Marketplace listings for sharing agents
### Environment Configuration
#### Configuration Files
- **Backend**: `/backend/.env.default` (defaults) → `/backend/.env` (user overrides)
- **Frontend**: `/frontend/.env.default` (defaults) → `/frontend/.env` (user overrides)
- **Platform**: `/.env.default` (Supabase/shared defaults) → `/.env` (user overrides)
#### Docker Environment Loading Order
1. `.env.default` files provide base configuration (tracked in git)
2. `.env` files provide user-specific overrides (gitignored)
3. Docker Compose `environment:` sections provide service-specific overrides
4. Shell environment variables have highest precedence
#### Key Points
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
- Backend: `.env` file in `/backend`
- Frontend: `.env.local` file in `/frontend`
- Both require Supabase credentials and API keys for various services
### Common Development Tasks
**Adding a new block:**
1. Create new file in `/backend/backend/blocks/`
2. Inherit from `Block` base class
3. Define input/output schemas
@@ -156,18 +122,13 @@ Key models (defined in `/backend/schema.prisma`):
5. Register in block registry
6. Generate the block uuid using `uuid.uuid4()`
Note: when making many new blocks analyze the interfaces for each of these blcoks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
ex: do the inputs and outputs tie well together?
**Modifying the API:**
1. Update route in `/backend/backend/server/routers/`
2. Add/update Pydantic models in same directory
3. Write tests alongside the route file
4. Run `poetry run test` to verify
**Frontend feature development:**
1. Components go in `/frontend/src/components/`
2. Use existing UI components from `/frontend/src/components/ui/`
3. Add Storybook stories for new components
@@ -176,7 +137,6 @@ ex: do the inputs and outputs tie well together?
### Security Implementation
**Cache Protection Middleware:**
- Located in `/backend/backend/server/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow list approach - only explicitly permitted paths can be cached
@@ -185,20 +145,14 @@ ex: do the inputs and outputs tie well together?
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both main API server and external API applications
### Creating Pull Requests
### Creating Pull Requests
- Create the PR aginst the `dev` branch of the repository.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)/
- Use conventional commit messages (see below)/
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description/
- Run the github pre-commit hooks to ensure code quality.
### Reviewing/Revising Pull Requests
- When the user runs /pr-comments or tries to fetch them, also run gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews to get the reviews
- Use gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews/[review_id]/comments to get the review contents
- Use gh api /repos/Significant-Gravitas/AutoGPT/issues/9924/comments to get the pr specific comments
### Conventional Commits
Use this format for commit messages and Pull Request titles:

View File

@@ -8,6 +8,7 @@ Welcome to the AutoGPT Platform - a powerful system for creating and running AI
- Docker
- Docker Compose V2 (comes with Docker Desktop, or can be installed separately)
- Node.js & NPM (for running the frontend application)
### Running the System
@@ -23,10 +24,10 @@ To run the AutoGPT Platform, follow these steps:
2. Run the following command:
```
cp .env.default .env
cp .env.example .env
```
This command will copy the `.env.default` file to `.env`. You can modify the `.env` file to add your own environment variables.
This command will copy the `.env.example` file to `.env`. You can modify the `.env` file to add your own environment variables.
3. Run the following command:
@@ -36,7 +37,44 @@ To run the AutoGPT Platform, follow these steps:
This command will start all the necessary backend services defined in the `docker-compose.yml` file in detached mode.
4. After all the services are in ready state, open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
4. Navigate to `frontend` within the `autogpt_platform` directory:
```
cd frontend
```
You will need to run your frontend application separately on your local machine.
5. Run the following command:
```
cp .env.example .env.local
```
This command will copy the `.env.example` file to `.env.local` in the `frontend` directory. You can modify the `.env.local` within this folder to add your own environment variables for the frontend application.
6. Run the following command:
Enable corepack and install dependencies by running:
```
corepack enable
pnpm i
```
Generate the API client (this step is required before running the frontend):
```
pnpm generate:api-client
```
Then start the frontend application in development mode:
```
pnpm dev
```
7. Open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
### Docker Compose Commands
@@ -139,21 +177,20 @@ The platform includes scripts for generating and managing the API client:
- `pnpm fetch:openapi`: Fetches the OpenAPI specification from the backend service (requires backend to be running on port 8006)
- `pnpm generate:api-client`: Generates the TypeScript API client from the OpenAPI specification using Orval
- `pnpm generate:api`: Runs both fetch and generate commands in sequence
- `pnpm generate:api-all`: Runs both fetch and generate commands in sequence
#### Manual API Client Updates
If you need to update the API client after making changes to the backend API:
1. Ensure the backend services are running:
```
docker compose up -d
```
2. Generate the updated API client:
```
pnpm generate:api
pnpm generate:api-all
```
This will fetch the latest OpenAPI specification and regenerate the TypeScript client code.

View File

@@ -7,5 +7,9 @@ class Settings:
self.ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
self.JWT_ALGORITHM: str = "HS256"
@property
def is_configured(self) -> bool:
return bool(self.JWT_SECRET_KEY)
settings = Settings()

View File

@@ -0,0 +1,196 @@
import asyncio
import contextlib
import logging
from functools import wraps
from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar, Union, cast
import ldclient
from fastapi import HTTPException
from ldclient import Context, LDClient
from ldclient.config import Config
from typing_extensions import ParamSpec
from .config import SETTINGS
logger = logging.getLogger(__name__)
P = ParamSpec("P")
T = TypeVar("T")
_is_initialized = False
def get_client() -> LDClient:
"""Get the LaunchDarkly client singleton."""
if not _is_initialized:
initialize_launchdarkly()
return ldclient.get()
def initialize_launchdarkly() -> None:
sdk_key = SETTINGS.launch_darkly_sdk_key
logger.debug(
f"Initializing LaunchDarkly with SDK key: {'present' if sdk_key else 'missing'}"
)
if not sdk_key:
logger.warning("LaunchDarkly SDK key not configured")
return
config = Config(sdk_key)
ldclient.set_config(config)
if ldclient.get().is_initialized():
global _is_initialized
_is_initialized = True
logger.info("LaunchDarkly client initialized successfully")
else:
logger.error("LaunchDarkly client failed to initialize")
def shutdown_launchdarkly() -> None:
"""Shutdown the LaunchDarkly client."""
if ldclient.get().is_initialized():
ldclient.get().close()
logger.info("LaunchDarkly client closed successfully")
def create_context(
user_id: str, additional_attributes: Optional[Dict[str, Any]] = None
) -> Context:
"""Create LaunchDarkly context with optional additional attributes."""
builder = Context.builder(str(user_id)).kind("user")
if additional_attributes:
for key, value in additional_attributes.items():
builder.set(key, value)
return builder.build()
def is_feature_enabled(flag_key: str, user_id: str, default: bool = False) -> bool:
"""
Simple helper to check if a feature flag is enabled for a user.
Args:
flag_key: The LaunchDarkly feature flag key
user_id: The user ID to evaluate the flag for
default: Default value if LaunchDarkly is unavailable or flag evaluation fails
Returns:
True if feature is enabled, False otherwise
"""
try:
client = get_client()
context = create_context(str(user_id))
return client.variation(flag_key, context, default)
except Exception as e:
logger.debug(
f"LaunchDarkly flag evaluation failed for {flag_key}: {e}, using default={default}"
)
return default
def feature_flag(
flag_key: str,
default: bool = False,
) -> Callable[
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
]:
"""
Decorator for feature flag protected endpoints.
"""
def decorator(
func: Callable[P, Union[T, Awaitable[T]]],
) -> Callable[P, Union[T, Awaitable[T]]]:
@wraps(func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
try:
user_id = kwargs.get("user_id")
if not user_id:
raise ValueError("user_id is required")
if not get_client().is_initialized():
logger.warning(
f"LaunchDarkly not initialized, using default={default}"
)
is_enabled = default
else:
context = create_context(str(user_id))
is_enabled = get_client().variation(flag_key, context, default)
if not is_enabled:
raise HTTPException(status_code=404, detail="Feature not available")
result = func(*args, **kwargs)
if asyncio.iscoroutine(result):
return await result
return cast(T, result)
except Exception as e:
logger.error(f"Error evaluating feature flag {flag_key}: {e}")
raise
@wraps(func)
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
try:
user_id = kwargs.get("user_id")
if not user_id:
raise ValueError("user_id is required")
if not get_client().is_initialized():
logger.warning(
f"LaunchDarkly not initialized, using default={default}"
)
is_enabled = default
else:
context = create_context(str(user_id))
is_enabled = get_client().variation(flag_key, context, default)
if not is_enabled:
raise HTTPException(status_code=404, detail="Feature not available")
return cast(T, func(*args, **kwargs))
except Exception as e:
logger.error(f"Error evaluating feature flag {flag_key}: {e}")
raise
return cast(
Callable[P, Union[T, Awaitable[T]]],
async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper,
)
return decorator
def percentage_rollout(
flag_key: str,
default: bool = False,
) -> Callable[
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
]:
"""Decorator for percentage-based rollouts."""
return feature_flag(flag_key, default)
def beta_feature(
flag_key: Optional[str] = None,
unauthorized_response: Any = {"message": "Not available in beta"},
) -> Callable[
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
]:
"""Decorator for beta features."""
actual_key = f"beta-{flag_key}" if flag_key else "beta"
return feature_flag(actual_key, False)
@contextlib.contextmanager
def mock_flag_variation(flag_key: str, return_value: Any):
"""Context manager for testing feature flags."""
original_variation = get_client().variation
get_client().variation = lambda key, context, default: (
return_value if key == flag_key else original_variation(key, context, default)
)
try:
yield
finally:
get_client().variation = original_variation

View File

@@ -0,0 +1,84 @@
import pytest
from ldclient import LDClient
from autogpt_libs.feature_flag.client import (
feature_flag,
is_feature_enabled,
mock_flag_variation,
)
@pytest.fixture
def ld_client(mocker):
client = mocker.Mock(spec=LDClient)
mocker.patch("ldclient.get", return_value=client)
client.is_initialized.return_value = True
return client
@pytest.mark.asyncio
async def test_feature_flag_enabled(ld_client):
ld_client.variation.return_value = True
@feature_flag("test-flag")
async def test_function(user_id: str):
return "success"
result = test_function(user_id="test-user")
assert result == "success"
ld_client.variation.assert_called_once()
@pytest.mark.asyncio
async def test_feature_flag_unauthorized_response(ld_client):
ld_client.variation.return_value = False
@feature_flag("test-flag")
async def test_function(user_id: str):
return "success"
result = test_function(user_id="test-user")
assert result == {"error": "disabled"}
def test_mock_flag_variation(ld_client):
with mock_flag_variation("test-flag", True):
assert ld_client.variation("test-flag", None, False)
with mock_flag_variation("test-flag", False):
assert ld_client.variation("test-flag", None, False)
def test_is_feature_enabled(ld_client):
"""Test the is_feature_enabled helper function."""
ld_client.is_initialized.return_value = True
ld_client.variation.return_value = True
result = is_feature_enabled("test-flag", "user123", default=False)
assert result is True
ld_client.variation.assert_called_once()
call_args = ld_client.variation.call_args
assert call_args[0][0] == "test-flag" # flag_key
assert call_args[0][2] is False # default value
def test_is_feature_enabled_not_initialized(ld_client):
"""Test is_feature_enabled when LaunchDarkly is not initialized."""
ld_client.is_initialized.return_value = False
result = is_feature_enabled("test-flag", "user123", default=True)
assert result is True # Should return default
ld_client.variation.assert_not_called()
def test_is_feature_enabled_exception(mocker):
"""Test is_feature_enabled when get_client() raises an exception."""
mocker.patch(
"autogpt_libs.feature_flag.client.get_client",
side_effect=Exception("Client error"),
)
result = is_feature_enabled("test-flag", "user123", default=True)
assert result is True # Should return default

View File

@@ -0,0 +1,15 @@
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
launch_darkly_sdk_key: str = Field(
default="",
description="The Launch Darkly SDK key",
validation_alias="LAUNCH_DARKLY_SDK_KEY",
)
model_config = SettingsConfigDict(case_sensitive=True, extra="ignore")
SETTINGS = Settings()

View File

@@ -1,8 +1,6 @@
"""Logging module for Auto-GPT."""
import logging
import os
import socket
import sys
from pathlib import Path
@@ -12,15 +10,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
from .filters import BelowLevelFilter
from .formatters import AGPTFormatter
# Configure global socket timeout and gRPC keepalive to prevent deadlocks
# This must be done at import time before any gRPC connections are established
socket.setdefaulttimeout(30) # 30-second socket timeout
# Enable gRPC keepalive to detect dead connections faster
os.environ.setdefault("GRPC_KEEPALIVE_TIME_MS", "30000") # 30 seconds
os.environ.setdefault("GRPC_KEEPALIVE_TIMEOUT_MS", "5000") # 5 seconds
os.environ.setdefault("GRPC_KEEPALIVE_PERMIT_WITHOUT_CALLS", "true")
LOG_DIR = Path(__file__).parent.parent.parent.parent / "logs"
LOG_FILE = "activity.log"
DEBUG_LOG_FILE = "debug.log"
@@ -90,6 +79,7 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
Note: This function is typically called at the start of the application
to set up the logging infrastructure.
"""
config = LoggingConfig()
log_handlers: list[logging.Handler] = []
@@ -115,17 +105,13 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
if config.enable_cloud_logging or force_cloud_logging:
import google.cloud.logging
from google.cloud.logging.handlers import CloudLoggingHandler
from google.cloud.logging_v2.handlers.transports import (
BackgroundThreadTransport,
)
from google.cloud.logging_v2.handlers.transports.sync import SyncTransport
client = google.cloud.logging.Client()
# Use BackgroundThreadTransport to prevent blocking the main thread
# and deadlocks when gRPC calls to Google Cloud Logging hang
cloud_handler = CloudLoggingHandler(
client,
name="autogpt_logs",
transport=BackgroundThreadTransport,
transport=SyncTransport,
)
cloud_handler.setLevel(config.level)
log_handlers.append(cloud_handler)

View File

@@ -1,5 +1,39 @@
import logging
import re
from typing import Any
import uvicorn.config
from colorama import Fore
def remove_color_codes(s: str) -> str:
return re.sub(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])", "", s)
def fmt_kwargs(kwargs: dict) -> str:
return ", ".join(f"{n}={repr(v)}" for n, v in kwargs.items())
def print_attribute(
title: str, value: Any, title_color: str = Fore.GREEN, value_color: str = ""
) -> None:
logger = logging.getLogger()
logger.info(
str(value),
extra={
"title": f"{title.rstrip(':')}:",
"title_color": title_color,
"color": value_color,
},
)
def generate_uvicorn_config():
"""
Generates a uvicorn logging config that silences uvicorn's default logging and tells it to use the native logging module.
"""
log_config = dict(uvicorn.config.LOGGING_CONFIG)
log_config["loggers"]["uvicorn"] = {"handlers": []}
log_config["loggers"]["uvicorn.error"] = {"handlers": []}
log_config["loggers"]["uvicorn.access"] = {"handlers": []}
return log_config

View File

@@ -1,34 +1,17 @@
import inspect
import logging
import threading
import time
from functools import wraps
from typing import (
Awaitable,
Callable,
ParamSpec,
Protocol,
Tuple,
TypeVar,
cast,
overload,
runtime_checkable,
)
from typing import Awaitable, Callable, ParamSpec, TypeVar, cast, overload
P = ParamSpec("P")
R = TypeVar("R")
logger = logging.getLogger(__name__)
@overload
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...
@overload
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
pass
@overload
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
pass
def thread_cached(func: Callable[P, R]) -> Callable[P, R]: ...
def thread_cached(
@@ -74,193 +57,3 @@ def thread_cached(
def clear_thread_cache(func: Callable) -> None:
if clear := getattr(func, "clear_cache", None):
clear()
FuncT = TypeVar("FuncT")
R_co = TypeVar("R_co", covariant=True)
@runtime_checkable
class AsyncCachedFunction(Protocol[P, R_co]):
"""Protocol for async functions with cache management methods."""
def cache_clear(self) -> None:
"""Clear all cached entries."""
return None
def cache_info(self) -> dict[str, int | None]:
"""Get cache statistics."""
return {}
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
"""Call the cached function."""
return None # type: ignore
def async_ttl_cache(
maxsize: int = 128, ttl_seconds: int | None = None
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
"""
TTL (Time To Live) cache decorator for async functions.
Similar to functools.lru_cache but works with async functions and includes optional TTL.
Args:
maxsize: Maximum number of cached entries
ttl_seconds: Time to live in seconds. If None, entries never expire (like lru_cache)
Returns:
Decorator function
Example:
# With TTL
@async_ttl_cache(maxsize=1000, ttl_seconds=300)
async def api_call(param: str) -> dict:
return {"result": param}
# Without TTL (permanent cache like lru_cache)
@async_ttl_cache(maxsize=1000)
async def expensive_computation(param: str) -> dict:
return {"result": param}
"""
def decorator(
async_func: Callable[P, Awaitable[R]],
) -> AsyncCachedFunction[P, R]:
# Cache storage - use union type to handle both cases
cache_storage: dict[tuple, R | Tuple[R, float]] = {}
@wraps(async_func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
# Create cache key from arguments
key = (args, tuple(sorted(kwargs.items())))
current_time = time.time()
# Check if we have a valid cached entry
if key in cache_storage:
if ttl_seconds is None:
# No TTL - return cached result directly
logger.debug(
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
)
return cast(R, cache_storage[key])
else:
# With TTL - check expiration
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
logger.debug(
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
)
return cast(R, result)
else:
# Expired entry
del cache_storage[key]
logger.debug(
f"Cache entry expired for {async_func.__name__}"
)
# Cache miss or expired - fetch fresh data
logger.debug(
f"Cache miss for {async_func.__name__} with key: {str(key)[:50]}"
)
result = await async_func(*args, **kwargs)
# Store in cache
if ttl_seconds is None:
cache_storage[key] = result
else:
cache_storage[key] = (result, current_time)
# Simple cleanup when cache gets too large
if len(cache_storage) > maxsize:
# Remove oldest entries (simple FIFO cleanup)
cutoff = maxsize // 2
oldest_keys = list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
for old_key in oldest_keys:
cache_storage.pop(old_key, None)
logger.debug(
f"Cache cleanup: removed {len(oldest_keys)} entries for {async_func.__name__}"
)
return result
# Add cache management methods (similar to functools.lru_cache)
def cache_clear() -> None:
cache_storage.clear()
def cache_info() -> dict[str, int | None]:
return {
"size": len(cache_storage),
"maxsize": maxsize,
"ttl_seconds": ttl_seconds,
}
# Attach methods to wrapper
setattr(wrapper, "cache_clear", cache_clear)
setattr(wrapper, "cache_info", cache_info)
return cast(AsyncCachedFunction[P, R], wrapper)
return decorator
@overload
def async_cache(
func: Callable[P, Awaitable[R]],
) -> AsyncCachedFunction[P, R]:
pass
@overload
def async_cache(
func: None = None,
*,
maxsize: int = 128,
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
pass
def async_cache(
func: Callable[P, Awaitable[R]] | None = None,
*,
maxsize: int = 128,
) -> (
AsyncCachedFunction[P, R]
| Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]
):
"""
Process-level cache decorator for async functions (no TTL).
Similar to functools.lru_cache but works with async functions.
This is a convenience wrapper around async_ttl_cache with ttl_seconds=None.
Args:
func: The async function to cache (when used without parentheses)
maxsize: Maximum number of cached entries
Returns:
Decorated function or decorator
Example:
# Without parentheses (uses default maxsize=128)
@async_cache
async def get_data(param: str) -> dict:
return {"result": param}
# With parentheses and custom maxsize
@async_cache(maxsize=1000)
async def expensive_computation(param: str) -> dict:
# Expensive computation here
return {"result": param}
"""
if func is None:
# Called with parentheses @async_cache() or @async_cache(maxsize=...)
return async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
else:
# Called without parentheses @async_cache
decorator = async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
return decorator(func)

View File

@@ -16,12 +16,7 @@ from unittest.mock import Mock
import pytest
from autogpt_libs.utils.cache import (
async_cache,
async_ttl_cache,
clear_thread_cache,
thread_cached,
)
from autogpt_libs.utils.cache import clear_thread_cache, thread_cached
class TestThreadCached:
@@ -328,378 +323,3 @@ class TestThreadCached:
assert function_using_mock(2) == 42
assert mock.call_count == 2
class TestAsyncTTLCache:
"""Tests for the @async_ttl_cache decorator."""
@pytest.mark.asyncio
async def test_basic_caching(self):
"""Test basic caching functionality."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def cached_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01) # Simulate async work
return x + y
# First call
result1 = await cached_function(1, 2)
assert result1 == 3
assert call_count == 1
# Second call with same args - should use cache
result2 = await cached_function(1, 2)
assert result2 == 3
assert call_count == 1 # No additional call
# Different args - should call function again
result3 = await cached_function(2, 3)
assert result3 == 5
assert call_count == 2
@pytest.mark.asyncio
async def test_ttl_expiration(self):
"""Test that cache entries expire after TTL."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
async def short_lived_cache(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 2
# First call
result1 = await short_lived_cache(5)
assert result1 == 10
assert call_count == 1
# Second call immediately - should use cache
result2 = await short_lived_cache(5)
assert result2 == 10
assert call_count == 1
# Wait for TTL to expire
await asyncio.sleep(1.1)
# Third call after expiration - should call function again
result3 = await short_lived_cache(5)
assert result3 == 10
assert call_count == 2
@pytest.mark.asyncio
async def test_cache_info(self):
"""Test cache info functionality."""
@async_ttl_cache(maxsize=5, ttl_seconds=300)
async def info_test_function(x: int) -> int:
return x * 3
# Check initial cache info
info = info_test_function.cache_info()
assert info["size"] == 0
assert info["maxsize"] == 5
assert info["ttl_seconds"] == 300
# Add an entry
await info_test_function(1)
info = info_test_function.cache_info()
assert info["size"] == 1
@pytest.mark.asyncio
async def test_cache_clear(self):
"""Test cache clearing functionality."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def clearable_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 4
# First call
result1 = await clearable_function(2)
assert result1 == 8
assert call_count == 1
# Second call - should use cache
result2 = await clearable_function(2)
assert result2 == 8
assert call_count == 1
# Clear cache
clearable_function.cache_clear()
# Third call after clear - should call function again
result3 = await clearable_function(2)
assert result3 == 8
assert call_count == 2
@pytest.mark.asyncio
async def test_maxsize_cleanup(self):
"""Test that cache cleans up when maxsize is exceeded."""
call_count = 0
@async_ttl_cache(maxsize=3, ttl_seconds=60)
async def size_limited_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x**2
# Fill cache to maxsize
await size_limited_function(1) # call_count: 1
await size_limited_function(2) # call_count: 2
await size_limited_function(3) # call_count: 3
info = size_limited_function.cache_info()
assert info["size"] == 3
# Add one more entry - should trigger cleanup
await size_limited_function(4) # call_count: 4
# Cache size should be reduced (cleanup removes oldest entries)
info = size_limited_function.cache_info()
assert info["size"] is not None and info["size"] <= 3 # Should be cleaned up
@pytest.mark.asyncio
async def test_argument_variations(self):
"""Test caching with different argument patterns."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def arg_test_function(a: int, b: str = "default", *, c: int = 100) -> str:
nonlocal call_count
call_count += 1
return f"{a}-{b}-{c}"
# Different ways to call with same logical arguments
result1 = await arg_test_function(1, "test", c=200)
assert call_count == 1
# Same arguments, same order - should use cache
result2 = await arg_test_function(1, "test", c=200)
assert call_count == 1
assert result1 == result2
# Different arguments - should call function
result3 = await arg_test_function(2, "test", c=200)
assert call_count == 2
assert result1 != result3
@pytest.mark.asyncio
async def test_exception_handling(self):
"""Test that exceptions are not cached."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def exception_function(x: int) -> int:
nonlocal call_count
call_count += 1
if x < 0:
raise ValueError("Negative value not allowed")
return x * 2
# Successful call - should be cached
result1 = await exception_function(5)
assert result1 == 10
assert call_count == 1
# Same successful call - should use cache
result2 = await exception_function(5)
assert result2 == 10
assert call_count == 1
# Exception call - should not be cached
with pytest.raises(ValueError):
await exception_function(-1)
assert call_count == 2
# Same exception call - should call again (not cached)
with pytest.raises(ValueError):
await exception_function(-1)
assert call_count == 3
@pytest.mark.asyncio
async def test_concurrent_calls(self):
"""Test caching behavior with concurrent calls."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def concurrent_function(x: int) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.05) # Simulate work
return x * x
# Launch concurrent calls with same arguments
tasks = [concurrent_function(3) for _ in range(5)]
results = await asyncio.gather(*tasks)
# All results should be the same
assert all(result == 9 for result in results)
# Note: Due to race conditions, call_count might be up to 5 for concurrent calls
# This tests that the cache doesn't break under concurrent access
assert 1 <= call_count <= 5
class TestAsyncCache:
"""Tests for the @async_cache decorator (no TTL)."""
@pytest.mark.asyncio
async def test_basic_caching_no_ttl(self):
"""Test basic caching functionality without TTL."""
call_count = 0
@async_cache(maxsize=10)
async def cached_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01) # Simulate async work
return x + y
# First call
result1 = await cached_function(1, 2)
assert result1 == 3
assert call_count == 1
# Second call with same args - should use cache
result2 = await cached_function(1, 2)
assert result2 == 3
assert call_count == 1 # No additional call
# Third call after some time - should still use cache (no TTL)
await asyncio.sleep(0.05)
result3 = await cached_function(1, 2)
assert result3 == 3
assert call_count == 1 # Still no additional call
# Different args - should call function again
result4 = await cached_function(2, 3)
assert result4 == 5
assert call_count == 2
@pytest.mark.asyncio
async def test_no_ttl_vs_ttl_behavior(self):
"""Test the difference between TTL and no-TTL caching."""
ttl_call_count = 0
no_ttl_call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
async def ttl_function(x: int) -> int:
nonlocal ttl_call_count
ttl_call_count += 1
return x * 2
@async_cache(maxsize=10) # No TTL
async def no_ttl_function(x: int) -> int:
nonlocal no_ttl_call_count
no_ttl_call_count += 1
return x * 2
# First calls
await ttl_function(5)
await no_ttl_function(5)
assert ttl_call_count == 1
assert no_ttl_call_count == 1
# Wait for TTL to expire
await asyncio.sleep(1.1)
# Second calls after TTL expiry
await ttl_function(5) # Should call function again (TTL expired)
await no_ttl_function(5) # Should use cache (no TTL)
assert ttl_call_count == 2 # TTL function called again
assert no_ttl_call_count == 1 # No-TTL function still cached
@pytest.mark.asyncio
async def test_async_cache_info(self):
"""Test cache info for no-TTL cache."""
@async_cache(maxsize=5)
async def info_test_function(x: int) -> int:
return x * 3
# Check initial cache info
info = info_test_function.cache_info()
assert info["size"] == 0
assert info["maxsize"] == 5
assert info["ttl_seconds"] is None # No TTL
# Add an entry
await info_test_function(1)
info = info_test_function.cache_info()
assert info["size"] == 1
class TestTTLOptional:
"""Tests for optional TTL functionality."""
@pytest.mark.asyncio
async def test_ttl_none_behavior(self):
"""Test that ttl_seconds=None works like no TTL."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=None)
async def no_ttl_via_none(x: int) -> int:
nonlocal call_count
call_count += 1
return x**2
# First call
result1 = await no_ttl_via_none(3)
assert result1 == 9
assert call_count == 1
# Wait (would expire if there was TTL)
await asyncio.sleep(0.1)
# Second call - should still use cache
result2 = await no_ttl_via_none(3)
assert result2 == 9
assert call_count == 1 # No additional call
# Check cache info
info = no_ttl_via_none.cache_info()
assert info["ttl_seconds"] is None
@pytest.mark.asyncio
async def test_cache_options_comparison(self):
"""Test different cache options work as expected."""
ttl_calls = 0
no_ttl_calls = 0
@async_ttl_cache(maxsize=10, ttl_seconds=1) # With TTL
async def ttl_function(x: int) -> int:
nonlocal ttl_calls
ttl_calls += 1
return x * 10
@async_cache(maxsize=10) # Process-level cache (no TTL)
async def process_function(x: int) -> int:
nonlocal no_ttl_calls
no_ttl_calls += 1
return x * 10
# Both should cache initially
await ttl_function(3)
await process_function(3)
assert ttl_calls == 1
assert no_ttl_calls == 1
# Immediate second calls - both should use cache
await ttl_function(3)
await process_function(3)
assert ttl_calls == 1
assert no_ttl_calls == 1
# Wait for TTL to expire
await asyncio.sleep(1.1)
# After TTL expiry
await ttl_function(3) # Should call function again
await process_function(3) # Should still use cache
assert ttl_calls == 2 # TTL cache expired, called again
assert no_ttl_calls == 1 # Process cache never expires

View File

@@ -1,52 +0,0 @@
# Development and testing files
**/__pycache__
**/*.pyc
**/*.pyo
**/*.pyd
**/.Python
**/env/
**/venv/
**/.venv/
**/pip-log.txt
**/.pytest_cache/
**/test-results/
**/snapshots/
**/test/
# IDE and editor files
**/.vscode/
**/.idea/
**/*.swp
**/*.swo
*~
# OS files
.DS_Store
Thumbs.db
# Logs
**/*.log
**/logs/
# Git
.git/
.gitignore
# Documentation
**/*.md
!README.md
# Local development files
.env
.env.local
**/.env.test
# Build artifacts
**/dist/
**/build/
**/target/
# Docker files (avoid recursion)
Dockerfile*
docker-compose*
.dockerignore

View File

@@ -1,9 +1,3 @@
# Backend Configuration
# This file contains environment variables that MUST be set for the AutoGPT platform
# Variables with working defaults in settings.py are not included here
## ===== REQUIRED DATABASE CONFIGURATION ===== ##
# PostgreSQL Database Connection
DB_USER=postgres
DB_PASS=your-super-secret-and-long-postgres-password
DB_NAME=postgres
@@ -16,50 +10,72 @@ DB_SCHEMA=platform
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
PRISMA_SCHEMA="postgres/schema.prisma"
ENABLE_AUTH=true
## ===== REQUIRED SERVICE CREDENTIALS ===== ##
# Redis Configuration
# EXECUTOR
NUM_GRAPH_WORKERS=10
BACKEND_CORS_ALLOW_ORIGINS=["http://localhost:3000"]
# generate using `from cryptography.fernet import Fernet;Fernet.generate_key().decode()`
ENCRYPTION_KEY='dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw='
UNSUBSCRIBE_SECRET_KEY = 'HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio='
REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_PASSWORD=password
# RabbitMQ Credentials
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
ENABLE_CREDIT=false
STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=
# Supabase Authentication
# What environment things should be logged under: local dev or prod
APP_ENV=local
# What environment to behave as: "local" or "cloud"
BEHAVE_AS=local
PYRO_HOST=localhost
SENTRY_DSN=
# Email For Postmark so we can send emails
POSTMARK_SERVER_API_TOKEN=
POSTMARK_SENDER_EMAIL=invalid@invalid.com
POSTMARK_WEBHOOK_TOKEN=
## User auth with Supabase is required for any of the 3rd party integrations with auth to work.
ENABLE_AUTH=true
SUPABASE_URL=http://localhost:8000
SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
## ===== REQUIRED SECURITY KEYS ===== ##
# Generate using: from cryptography.fernet import Fernet;Fernet.generate_key().decode()
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
# RabbitMQ credentials -- Used for communication between services
RABBITMQ_HOST=localhost
RABBITMQ_PORT=5672
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
# Platform URLs (set these for webhooks and OAuth to work)
PLATFORM_BASE_URL=http://localhost:8000
FRONTEND_BASE_URL=http://localhost:3000
# Media Storage (required for marketplace and library functionality)
## GCS bucket is required for marketplace and library functionality
MEDIA_GCS_BUCKET_NAME=
## ===== API KEYS AND OAUTH CREDENTIALS ===== ##
# All API keys below are optional - only add what you need
## For local development, you may need to set FRONTEND_BASE_URL for the OAuth flow
## for integrations to work. Defaults to the value of PLATFORM_BASE_URL if not set.
# FRONTEND_BASE_URL=http://localhost:3000
# AI/LLM Services
OPENAI_API_KEY=
ANTHROPIC_API_KEY=
GROQ_API_KEY=
LLAMA_API_KEY=
AIML_API_KEY=
V0_API_KEY=
OPEN_ROUTER_API_KEY=
NVIDIA_API_KEY=
## PLATFORM_BASE_URL must be set to a *publicly accessible* URL pointing to your backend
## to use the platform's webhook-related functionality.
## If you are developing locally, you can use something like ngrok to get a publc URL
## and tunnel it to your locally running backend.
PLATFORM_BASE_URL=http://localhost:3000
## Cloudflare Turnstile (CAPTCHA) Configuration
## Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
## This is the backend secret key
TURNSTILE_SECRET_KEY=
## This is the verify URL
TURNSTILE_VERIFY_URL=https://challenges.cloudflare.com/turnstile/v0/siteverify
## == INTEGRATION CREDENTIALS == ##
# Each set of server side credentials is required for the corresponding 3rd party
# integration to work.
# OAuth Credentials
# For the OAuth callback URL, use <your_frontend_url>/auth/integrations/oauth_callback,
# e.g. http://localhost:3000/auth/integrations/oauth_callback
@@ -69,6 +85,7 @@ GITHUB_CLIENT_SECRET=
# Google OAuth App server credentials - https://console.cloud.google.com/apis/credentials, and enable gmail api and set scopes
# https://console.cloud.google.com/apis/credentials/consent ?project=<your_project_id>
# You'll need to add/enable the following scopes (minimum):
# https://console.developers.google.com/apis/api/gmail.googleapis.com/overview ?project=<your_project_id>
# https://console.cloud.google.com/apis/library/sheets.googleapis.com/ ?project=<your_project_id>
@@ -104,66 +121,104 @@ LINEAR_CLIENT_SECRET=
TODOIST_CLIENT_ID=
TODOIST_CLIENT_SECRET=
NOTION_CLIENT_ID=
NOTION_CLIENT_SECRET=
## ===== OPTIONAL API KEYS ===== ##
# LLM
OPENAI_API_KEY=
ANTHROPIC_API_KEY=
AIML_API_KEY=
GROQ_API_KEY=
OPEN_ROUTER_API_KEY=
LLAMA_API_KEY=
# Reddit
# Go to https://www.reddit.com/prefs/apps and create a new app
# Choose "script" for the type
# Fill in the redirect uri as <your_frontend_url>/auth/integrations/oauth_callback, e.g. http://localhost:3000/auth/integrations/oauth_callback
REDDIT_CLIENT_ID=
REDDIT_CLIENT_SECRET=
REDDIT_USER_AGENT="AutoGPT:1.0 (by /u/autogpt)"
# Payment Processing
STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=
# Email Service (for sending notifications and confirmations)
POSTMARK_SERVER_API_TOKEN=
POSTMARK_SENDER_EMAIL=invalid@invalid.com
POSTMARK_WEBHOOK_TOKEN=
# Error Tracking
SENTRY_DSN=
# Cloudflare Turnstile (CAPTCHA) Configuration
# Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
# This is the backend secret key
TURNSTILE_SECRET_KEY=
# This is the verify URL
TURNSTILE_VERIFY_URL=https://challenges.cloudflare.com/turnstile/v0/siteverify
# Feature Flags
LAUNCH_DARKLY_SDK_KEY=
# Content Generation & Media
DID_API_KEY=
FAL_API_KEY=
IDEOGRAM_API_KEY=
REPLICATE_API_KEY=
REVID_API_KEY=
SCREENSHOTONE_API_KEY=
UNREAL_SPEECH_API_KEY=
# Data & Search Services
E2B_API_KEY=
EXA_API_KEY=
JINA_API_KEY=
MEM0_API_KEY=
OPENWEATHERMAP_API_KEY=
GOOGLE_MAPS_API_KEY=
# Communication Services
# Discord
DISCORD_BOT_TOKEN=
MEDIUM_API_KEY=
MEDIUM_AUTHOR_ID=
# SMTP/Email
SMTP_SERVER=
SMTP_PORT=
SMTP_USERNAME=
SMTP_PASSWORD=
# Business & Marketing Tools
# D-ID
DID_API_KEY=
# Open Weather Map
OPENWEATHERMAP_API_KEY=
# SMTP
SMTP_SERVER=
SMTP_PORT=
SMTP_USERNAME=
SMTP_PASSWORD=
# Medium
MEDIUM_API_KEY=
MEDIUM_AUTHOR_ID=
# Google Maps
GOOGLE_MAPS_API_KEY=
# Replicate
REPLICATE_API_KEY=
# Ideogram
IDEOGRAM_API_KEY=
# Fal
FAL_API_KEY=
# Exa
EXA_API_KEY=
# E2B
E2B_API_KEY=
# Mem0
MEM0_API_KEY=
# Nvidia
NVIDIA_API_KEY=
# Apollo
APOLLO_API_KEY=
ENRICHLAYER_API_KEY=
AYRSHARE_API_KEY=
AYRSHARE_JWT_KEY=
# SmartLead
SMARTLEAD_API_KEY=
# ZeroBounce
ZEROBOUNCE_API_KEY=
# Other Services
AUTOMOD_API_KEY=
# Ayrshare
AYRSHARE_API_KEY=
AYRSHARE_JWT_KEY=
## ===== OPTIONAL API KEYS END ===== ##
# Block Error Rate Monitoring
BLOCK_ERROR_RATE_THRESHOLD=0.5
BLOCK_ERROR_RATE_CHECK_INTERVAL_SECS=86400
# Logging Configuration
LOG_LEVEL=INFO
ENABLE_CLOUD_LOGGING=false
ENABLE_FILE_LOGGING=false
# Use to manually set the log directory
# LOG_DIR=./logs
# Example Blocks Configuration
# Set to true to enable example blocks in development
# These blocks are disabled by default in production
ENABLE_EXAMPLE_BLOCKS=false
# Cloud Storage Configuration
# Cleanup interval for expired files (hours between cleanup runs, 1-24 hours)
CLOUD_STORAGE_CLEANUP_INTERVAL_HOURS=6

View File

@@ -1,4 +1,3 @@
.env
database.db
database.db-journal
dev.db

View File

@@ -8,14 +8,14 @@ WORKDIR /app
RUN echo 'Acquire::http::Pipeline-Depth 0;\nAcquire::http::No-Cache true;\nAcquire::BrokenProxy true;\n' > /etc/apt/apt.conf.d/99fixbadproxy
# Update package list and install build dependencies in a single layer
RUN apt-get update --allow-releaseinfo-change --fix-missing \
&& apt-get install -y \
build-essential \
libpq5 \
libz-dev \
libssl-dev \
postgresql-client
RUN apt-get update --allow-releaseinfo-change --fix-missing
# Install build dependencies
RUN apt-get install -y build-essential
RUN apt-get install -y libpq5
RUN apt-get install -y libz-dev
RUN apt-get install -y libssl-dev
RUN apt-get install -y postgresql-client
ENV POETRY_HOME=/opt/poetry
ENV POETRY_NO_INTERACTION=1
@@ -68,12 +68,6 @@ COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.tom
WORKDIR /app/autogpt_platform/backend
FROM server_dependencies AS migrate
# Migration stage only needs schema and migrations - much lighter than full backend
COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations
FROM server_dependencies AS server
COPY autogpt_platform/backend /app/autogpt_platform/backend

View File

@@ -159,7 +159,6 @@ class AirtableOAuthHandler(BaseOAuthHandler):
logger.info("Successfully refreshed tokens")
new_credentials = OAuth2Credentials(
id=credentials.id,
access_token=SecretStr(response.access_token),
refresh_token=SecretStr(response.refresh_token),
access_token_expires_at=int(time.time()) + response.expires_in,

View File

@@ -1,5 +1,3 @@
from enum import Enum
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
@@ -13,12 +11,6 @@ from backend.sdk import (
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
class TikTokVisibility(str, Enum):
PUBLIC = "public"
PRIVATE = "private"
FOLLOWERS = "followers"
class PostToTikTokBlock(Block):
"""Block for posting to TikTok with TikTok-specific options."""
@@ -28,6 +20,7 @@ class PostToTikTokBlock(Block):
# Override post field to include TikTok-specific information
post: str = SchemaField(
description="The post text (max 2,200 chars, empty string allowed). Use @handle to mention users. Line breaks will be ignored.",
default="",
advanced=False,
)
@@ -40,7 +33,7 @@ class PostToTikTokBlock(Block):
# TikTok-specific options
auto_add_music: bool = SchemaField(
description="Whether to automatically add recommended music to the post. If you set this field to true, you can change the music later in the TikTok app.",
description="Automatically add recommended music to image posts",
default=False,
advanced=True,
)
@@ -60,17 +53,17 @@ class PostToTikTokBlock(Block):
advanced=True,
)
is_ai_generated: bool = SchemaField(
description="If you enable the toggle, your video will be labeled as “Creator labeled as AI-generated” once posted and cant be changed. The “Creator labeled as AI-generated” label indicates that the content was completely AI-generated or significantly edited with AI.",
description="Label content as AI-generated (video only)",
default=False,
advanced=True,
)
is_branded_content: bool = SchemaField(
description="Whether to enable the Branded Content toggle. If this field is set to true, the video will be labeled as Branded Content, indicating you are in a paid partnership with a brand. A “Paid partnership” label will be attached to the video.",
description="Label as branded content (paid partnership)",
default=False,
advanced=True,
)
is_brand_organic: bool = SchemaField(
description="Whether to enable the Brand Organic Content toggle. If this field is set to true, the video will be labeled as Brand Organic Content, indicating you are promoting yourself or your own business. A “Promotional content” label will be attached to the video.",
description="Label as brand organic content (promotional)",
default=False,
advanced=True,
)
@@ -87,9 +80,9 @@ class PostToTikTokBlock(Block):
default=0,
advanced=True,
)
visibility: TikTokVisibility = SchemaField(
visibility: str = SchemaField(
description="Post visibility: 'public', 'private', 'followers', or 'friends'",
default=TikTokVisibility.PUBLIC,
default="public",
advanced=True,
)
draft: bool = SchemaField(
@@ -104,6 +97,7 @@ class PostToTikTokBlock(Block):
def __init__(self):
super().__init__(
disabled=True,
id="7faf4b27-96b0-4f05-bf64-e0de54ae74e1",
description="Post to TikTok using Ayrshare",
categories={BlockCategory.SOCIAL},
@@ -166,6 +160,12 @@ class PostToTikTokBlock(Block):
yield "error", f"Image cover index {input_data.image_cover_index} is out of range (max: {len(input_data.media_urls) - 1})"
return
# Validate visibility option
valid_visibility = ["public", "private", "followers", "friends"]
if input_data.visibility not in valid_visibility:
yield "error", f"TikTok visibility must be one of: {', '.join(valid_visibility)}"
return
# Check for PNG files (not supported)
has_png = any(url.lower().endswith(".png") for url in input_data.media_urls)
if has_png:
@@ -218,8 +218,8 @@ class PostToTikTokBlock(Block):
if input_data.title:
tiktok_options["title"] = input_data.title
if input_data.visibility != TikTokVisibility.PUBLIC:
tiktok_options["visibility"] = input_data.visibility.value
if input_data.visibility != "public":
tiktok_options["visibility"] = input_data.visibility
response = await client.create_post(
post=input_data.post,

File diff suppressed because it is too large Load Diff

View File

@@ -1,408 +0,0 @@
"""
API module for Enrichlayer integration.
This module provides a client for interacting with the Enrichlayer API,
which allows fetching LinkedIn profile data and related information.
"""
import datetime
import enum
import logging
from json import JSONDecodeError
from typing import Any, Optional, TypeVar
from pydantic import BaseModel, Field
from backend.data.model import APIKeyCredentials
from backend.util.request import Requests
logger = logging.getLogger(__name__)
T = TypeVar("T")
class EnrichlayerAPIException(Exception):
"""Exception raised for Enrichlayer API errors."""
def __init__(self, message: str, status_code: int):
super().__init__(message)
self.status_code = status_code
class FallbackToCache(enum.Enum):
ON_ERROR = "on-error"
NEVER = "never"
class UseCache(enum.Enum):
IF_PRESENT = "if-present"
NEVER = "never"
class SocialMediaProfiles(BaseModel):
"""Social media profiles model."""
twitter: Optional[str] = None
facebook: Optional[str] = None
github: Optional[str] = None
class Experience(BaseModel):
"""Experience model for LinkedIn profiles."""
company: Optional[str] = None
title: Optional[str] = None
description: Optional[str] = None
location: Optional[str] = None
starts_at: Optional[dict[str, int]] = None
ends_at: Optional[dict[str, int]] = None
company_linkedin_profile_url: Optional[str] = None
class Education(BaseModel):
"""Education model for LinkedIn profiles."""
school: Optional[str] = None
degree_name: Optional[str] = None
field_of_study: Optional[str] = None
starts_at: Optional[dict[str, int]] = None
ends_at: Optional[dict[str, int]] = None
school_linkedin_profile_url: Optional[str] = None
class PersonProfileResponse(BaseModel):
"""Response model for LinkedIn person profile.
This model represents the response from Enrichlayer's LinkedIn profile API.
The API returns comprehensive profile data including work experience,
education, skills, and contact information (when available).
Example API Response:
{
"public_identifier": "johnsmith",
"full_name": "John Smith",
"occupation": "Software Engineer at Tech Corp",
"experiences": [
{
"company": "Tech Corp",
"title": "Software Engineer",
"starts_at": {"year": 2020, "month": 1}
}
],
"education": [...],
"skills": ["Python", "JavaScript", ...]
}
"""
public_identifier: Optional[str] = None
profile_pic_url: Optional[str] = None
full_name: Optional[str] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
occupation: Optional[str] = None
headline: Optional[str] = None
summary: Optional[str] = None
country: Optional[str] = None
country_full_name: Optional[str] = None
city: Optional[str] = None
state: Optional[str] = None
experiences: Optional[list[Experience]] = None
education: Optional[list[Education]] = None
languages: Optional[list[str]] = None
skills: Optional[list[str]] = None
inferred_salary: Optional[dict[str, Any]] = None
personal_email: Optional[str] = None
personal_contact_number: Optional[str] = None
social_media_profiles: Optional[SocialMediaProfiles] = None
extra: Optional[dict[str, Any]] = None
class SimilarProfile(BaseModel):
"""Similar profile model for LinkedIn person lookup."""
similarity: float
linkedin_profile_url: str
class PersonLookupResponse(BaseModel):
"""Response model for LinkedIn person lookup.
This model represents the response from Enrichlayer's person lookup API.
The API returns a LinkedIn profile URL and similarity scores when
searching for a person by name and company.
Example API Response:
{
"url": "https://www.linkedin.com/in/johnsmith/",
"name_similarity_score": 0.95,
"company_similarity_score": 0.88,
"title_similarity_score": 0.75,
"location_similarity_score": 0.60
}
"""
url: str | None = None
name_similarity_score: float | None
company_similarity_score: float | None
title_similarity_score: float | None
location_similarity_score: float | None
last_updated: datetime.datetime | None = None
profile: PersonProfileResponse | None = None
class RoleLookupResponse(BaseModel):
"""Response model for LinkedIn role lookup.
This model represents the response from Enrichlayer's role lookup API.
The API returns LinkedIn profile data for a specific role at a company.
Example API Response:
{
"linkedin_profile_url": "https://www.linkedin.com/in/johnsmith/",
"profile_data": {...} // Full PersonProfileResponse data when enrich_profile=True
}
"""
linkedin_profile_url: Optional[str] = None
profile_data: Optional[PersonProfileResponse] = None
class ProfilePictureResponse(BaseModel):
"""Response model for LinkedIn profile picture.
This model represents the response from Enrichlayer's profile picture API.
The API returns a URL to the person's LinkedIn profile picture.
Example API Response:
{
"tmp_profile_pic_url": "https://media.licdn.com/dms/image/..."
}
"""
tmp_profile_pic_url: str = Field(
..., description="URL of the profile picture", alias="tmp_profile_pic_url"
)
@property
def profile_picture_url(self) -> str:
"""Backward compatibility property for profile_picture_url."""
return self.tmp_profile_pic_url
class EnrichlayerClient:
"""Client for interacting with the Enrichlayer API."""
API_BASE_URL = "https://enrichlayer.com/api/v2"
def __init__(
self,
credentials: Optional[APIKeyCredentials] = None,
custom_requests: Optional[Requests] = None,
):
"""
Initialize the Enrichlayer client.
Args:
credentials: The credentials to use for authentication.
custom_requests: Custom Requests instance for testing.
"""
if custom_requests:
self._requests = custom_requests
else:
headers: dict[str, str] = {
"Content-Type": "application/json",
}
if credentials:
headers["Authorization"] = (
f"Bearer {credentials.api_key.get_secret_value()}"
)
self._requests = Requests(
extra_headers=headers,
raise_for_status=False,
)
async def _handle_response(self, response) -> Any:
"""
Handle API response and check for errors.
Args:
response: The response object from the request.
Returns:
The response data.
Raises:
EnrichlayerAPIException: If the API request fails.
"""
if not response.ok:
try:
error_data = response.json()
error_message = error_data.get("message", "")
except JSONDecodeError:
error_message = response.text
raise EnrichlayerAPIException(
f"Enrichlayer API request failed ({response.status_code}): {error_message}",
response.status_code,
)
return response.json()
async def fetch_profile(
self,
linkedin_url: str,
fallback_to_cache: FallbackToCache = FallbackToCache.ON_ERROR,
use_cache: UseCache = UseCache.IF_PRESENT,
include_skills: bool = False,
include_inferred_salary: bool = False,
include_personal_email: bool = False,
include_personal_contact_number: bool = False,
include_social_media: bool = False,
include_extra: bool = False,
) -> PersonProfileResponse:
"""
Fetch a LinkedIn profile with optional parameters.
Args:
linkedin_url: The LinkedIn profile URL to fetch.
fallback_to_cache: Cache usage if live fetch fails ('on-error' or 'never').
use_cache: Cache utilization ('if-present' or 'never').
include_skills: Whether to include skills data.
include_inferred_salary: Whether to include inferred salary data.
include_personal_email: Whether to include personal email.
include_personal_contact_number: Whether to include personal contact number.
include_social_media: Whether to include social media profiles.
include_extra: Whether to include additional data.
Returns:
The LinkedIn profile data.
Raises:
EnrichlayerAPIException: If the API request fails.
"""
params = {
"url": linkedin_url,
"fallback_to_cache": fallback_to_cache.value.lower(),
"use_cache": use_cache.value.lower(),
}
if include_skills:
params["skills"] = "include"
if include_inferred_salary:
params["inferred_salary"] = "include"
if include_personal_email:
params["personal_email"] = "include"
if include_personal_contact_number:
params["personal_contact_number"] = "include"
if include_social_media:
params["twitter_profile_id"] = "include"
params["facebook_profile_id"] = "include"
params["github_profile_id"] = "include"
if include_extra:
params["extra"] = "include"
response = await self._requests.get(
f"{self.API_BASE_URL}/profile", params=params
)
return PersonProfileResponse(**await self._handle_response(response))
async def lookup_person(
self,
first_name: str,
company_domain: str,
last_name: str | None = None,
location: Optional[str] = None,
title: Optional[str] = None,
include_similarity_checks: bool = False,
enrich_profile: bool = False,
) -> PersonLookupResponse:
"""
Look up a LinkedIn profile by person's information.
Args:
first_name: The person's first name.
last_name: The person's last name.
company_domain: The domain of the company they work for.
location: The person's location.
title: The person's job title.
include_similarity_checks: Whether to include similarity checks.
enrich_profile: Whether to enrich the profile.
Returns:
The LinkedIn profile lookup result.
Raises:
EnrichlayerAPIException: If the API request fails.
"""
params = {"first_name": first_name, "company_domain": company_domain}
if last_name:
params["last_name"] = last_name
if location:
params["location"] = location
if title:
params["title"] = title
if include_similarity_checks:
params["similarity_checks"] = "include"
if enrich_profile:
params["enrich_profile"] = "enrich"
response = await self._requests.get(
f"{self.API_BASE_URL}/profile/resolve", params=params
)
return PersonLookupResponse(**await self._handle_response(response))
async def lookup_role(
self, role: str, company_name: str, enrich_profile: bool = False
) -> RoleLookupResponse:
"""
Look up a LinkedIn profile by role in a company.
Args:
role: The role title (e.g., CEO, CTO).
company_name: The name of the company.
enrich_profile: Whether to enrich the profile.
Returns:
The LinkedIn profile lookup result.
Raises:
EnrichlayerAPIException: If the API request fails.
"""
params = {
"role": role,
"company_name": company_name,
}
if enrich_profile:
params["enrich_profile"] = "enrich"
response = await self._requests.get(
f"{self.API_BASE_URL}/find/company/role", params=params
)
return RoleLookupResponse(**await self._handle_response(response))
async def get_profile_picture(
self, linkedin_profile_url: str
) -> ProfilePictureResponse:
"""
Get a LinkedIn profile picture URL.
Args:
linkedin_profile_url: The LinkedIn profile URL.
Returns:
The profile picture URL.
Raises:
EnrichlayerAPIException: If the API request fails.
"""
params = {
"linkedin_person_profile_url": linkedin_profile_url,
}
response = await self._requests.get(
f"{self.API_BASE_URL}/person/profile-picture", params=params
)
return ProfilePictureResponse(**await self._handle_response(response))

View File

@@ -1,34 +0,0 @@
"""
Authentication module for Enrichlayer API integration.
This module provides credential types and test credentials for the Enrichlayer API.
"""
from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsMetaInput
from backend.integrations.providers import ProviderName
# Define the type of credentials input expected for Enrichlayer API
EnrichlayerCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.ENRICHLAYER], Literal["api_key"]
]
# Mock credentials for testing Enrichlayer API integration
TEST_CREDENTIALS = APIKeyCredentials(
id="1234a567-89bc-4def-ab12-3456cdef7890",
provider="enrichlayer",
api_key=SecretStr("mock-enrichlayer-api-key"),
title="Mock Enrichlayer API key",
expires_at=None,
)
# Dictionary representation of test credentials for input fields
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}

View File

@@ -1,527 +0,0 @@
"""
Block definitions for Enrichlayer API integration.
This module implements blocks for interacting with the Enrichlayer API,
which provides access to LinkedIn profile data and related information.
"""
import logging
from typing import Optional
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
from backend.util.type import MediaFileType
from ._api import (
EnrichlayerClient,
Experience,
FallbackToCache,
PersonLookupResponse,
PersonProfileResponse,
RoleLookupResponse,
UseCache,
)
from ._auth import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, EnrichlayerCredentialsInput
logger = logging.getLogger(__name__)
class GetLinkedinProfileBlock(Block):
"""Block to fetch LinkedIn profile data using Enrichlayer API."""
class Input(BlockSchema):
"""Input schema for GetLinkedinProfileBlock."""
linkedin_url: str = SchemaField(
description="LinkedIn profile URL to fetch data from",
placeholder="https://www.linkedin.com/in/username/",
)
fallback_to_cache: FallbackToCache = SchemaField(
description="Cache usage if live fetch fails",
default=FallbackToCache.ON_ERROR,
advanced=True,
)
use_cache: UseCache = SchemaField(
description="Cache utilization strategy",
default=UseCache.IF_PRESENT,
advanced=True,
)
include_skills: bool = SchemaField(
description="Include skills data",
default=False,
advanced=True,
)
include_inferred_salary: bool = SchemaField(
description="Include inferred salary data",
default=False,
advanced=True,
)
include_personal_email: bool = SchemaField(
description="Include personal email",
default=False,
advanced=True,
)
include_personal_contact_number: bool = SchemaField(
description="Include personal contact number",
default=False,
advanced=True,
)
include_social_media: bool = SchemaField(
description="Include social media profiles",
default=False,
advanced=True,
)
include_extra: bool = SchemaField(
description="Include additional data",
default=False,
advanced=True,
)
credentials: EnrichlayerCredentialsInput = CredentialsField(
description="Enrichlayer API credentials"
)
class Output(BlockSchema):
"""Output schema for GetLinkedinProfileBlock."""
profile: PersonProfileResponse = SchemaField(
description="LinkedIn profile data"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
"""Initialize GetLinkedinProfileBlock."""
super().__init__(
id="f6e0ac73-4f1d-4acb-b4b7-b67066c5984e",
description="Fetch LinkedIn profile data using Enrichlayer",
categories={BlockCategory.SOCIAL},
input_schema=GetLinkedinProfileBlock.Input,
output_schema=GetLinkedinProfileBlock.Output,
test_input={
"linkedin_url": "https://www.linkedin.com/in/williamhgates/",
"include_skills": True,
"include_social_media": True,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
(
"profile",
PersonProfileResponse(
public_identifier="williamhgates",
full_name="Bill Gates",
occupation="Co-chair at Bill & Melinda Gates Foundation",
experiences=[
Experience(
company="Bill & Melinda Gates Foundation",
title="Co-chair",
starts_at={"year": 2000},
)
],
),
)
],
test_credentials=TEST_CREDENTIALS,
test_mock={
"_fetch_profile": lambda *args, **kwargs: PersonProfileResponse(
public_identifier="williamhgates",
full_name="Bill Gates",
occupation="Co-chair at Bill & Melinda Gates Foundation",
experiences=[
Experience(
company="Bill & Melinda Gates Foundation",
title="Co-chair",
starts_at={"year": 2000},
)
],
),
},
)
@staticmethod
async def _fetch_profile(
credentials: APIKeyCredentials,
linkedin_url: str,
fallback_to_cache: FallbackToCache = FallbackToCache.ON_ERROR,
use_cache: UseCache = UseCache.IF_PRESENT,
include_skills: bool = False,
include_inferred_salary: bool = False,
include_personal_email: bool = False,
include_personal_contact_number: bool = False,
include_social_media: bool = False,
include_extra: bool = False,
):
client = EnrichlayerClient(credentials)
profile = await client.fetch_profile(
linkedin_url=linkedin_url,
fallback_to_cache=fallback_to_cache,
use_cache=use_cache,
include_skills=include_skills,
include_inferred_salary=include_inferred_salary,
include_personal_email=include_personal_email,
include_personal_contact_number=include_personal_contact_number,
include_social_media=include_social_media,
include_extra=include_extra,
)
return profile
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
"""
Run the block to fetch LinkedIn profile data.
Args:
input_data: Input parameters for the block
credentials: API key credentials for Enrichlayer
**kwargs: Additional keyword arguments
Yields:
Tuples of (output_name, output_value)
"""
try:
profile = await self._fetch_profile(
credentials=credentials,
linkedin_url=input_data.linkedin_url,
fallback_to_cache=input_data.fallback_to_cache,
use_cache=input_data.use_cache,
include_skills=input_data.include_skills,
include_inferred_salary=input_data.include_inferred_salary,
include_personal_email=input_data.include_personal_email,
include_personal_contact_number=input_data.include_personal_contact_number,
include_social_media=input_data.include_social_media,
include_extra=input_data.include_extra,
)
yield "profile", profile
except Exception as e:
logger.error(f"Error fetching LinkedIn profile: {str(e)}")
yield "error", str(e)
class LinkedinPersonLookupBlock(Block):
"""Block to look up LinkedIn profiles by person's information using Enrichlayer API."""
class Input(BlockSchema):
"""Input schema for LinkedinPersonLookupBlock."""
first_name: str = SchemaField(
description="Person's first name",
placeholder="John",
advanced=False,
)
last_name: str | None = SchemaField(
description="Person's last name",
placeholder="Doe",
default=None,
advanced=False,
)
company_domain: str = SchemaField(
description="Domain of the company they work for (optional)",
placeholder="example.com",
advanced=False,
)
location: Optional[str] = SchemaField(
description="Person's location (optional)",
placeholder="San Francisco",
default=None,
)
title: Optional[str] = SchemaField(
description="Person's job title (optional)",
placeholder="CEO",
default=None,
)
include_similarity_checks: bool = SchemaField(
description="Include similarity checks",
default=False,
advanced=True,
)
enrich_profile: bool = SchemaField(
description="Enrich the profile with additional data",
default=False,
advanced=True,
)
credentials: EnrichlayerCredentialsInput = CredentialsField(
description="Enrichlayer API credentials"
)
class Output(BlockSchema):
"""Output schema for LinkedinPersonLookupBlock."""
lookup_result: PersonLookupResponse = SchemaField(
description="LinkedIn profile lookup result"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
"""Initialize LinkedinPersonLookupBlock."""
super().__init__(
id="d237a98a-5c4b-4a1c-b9e3-e6f9a6c81df7",
description="Look up LinkedIn profiles by person information using Enrichlayer",
categories={BlockCategory.SOCIAL},
input_schema=LinkedinPersonLookupBlock.Input,
output_schema=LinkedinPersonLookupBlock.Output,
test_input={
"first_name": "Bill",
"last_name": "Gates",
"company_domain": "gatesfoundation.org",
"include_similarity_checks": True,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
(
"lookup_result",
PersonLookupResponse(
url="https://www.linkedin.com/in/williamhgates/",
name_similarity_score=0.93,
company_similarity_score=0.83,
title_similarity_score=0.3,
location_similarity_score=0.20,
),
)
],
test_credentials=TEST_CREDENTIALS,
test_mock={
"_lookup_person": lambda *args, **kwargs: PersonLookupResponse(
url="https://www.linkedin.com/in/williamhgates/",
name_similarity_score=0.93,
company_similarity_score=0.83,
title_similarity_score=0.3,
location_similarity_score=0.20,
)
},
)
@staticmethod
async def _lookup_person(
credentials: APIKeyCredentials,
first_name: str,
company_domain: str,
last_name: str | None = None,
location: Optional[str] = None,
title: Optional[str] = None,
include_similarity_checks: bool = False,
enrich_profile: bool = False,
):
client = EnrichlayerClient(credentials=credentials)
lookup_result = await client.lookup_person(
first_name=first_name,
last_name=last_name,
company_domain=company_domain,
location=location,
title=title,
include_similarity_checks=include_similarity_checks,
enrich_profile=enrich_profile,
)
return lookup_result
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
"""
Run the block to look up LinkedIn profiles.
Args:
input_data: Input parameters for the block
credentials: API key credentials for Enrichlayer
**kwargs: Additional keyword arguments
Yields:
Tuples of (output_name, output_value)
"""
try:
lookup_result = await self._lookup_person(
credentials=credentials,
first_name=input_data.first_name,
last_name=input_data.last_name,
company_domain=input_data.company_domain,
location=input_data.location,
title=input_data.title,
include_similarity_checks=input_data.include_similarity_checks,
enrich_profile=input_data.enrich_profile,
)
yield "lookup_result", lookup_result
except Exception as e:
logger.error(f"Error looking up LinkedIn profile: {str(e)}")
yield "error", str(e)
class LinkedinRoleLookupBlock(Block):
"""Block to look up LinkedIn profiles by role in a company using Enrichlayer API."""
class Input(BlockSchema):
"""Input schema for LinkedinRoleLookupBlock."""
role: str = SchemaField(
description="Role title (e.g., CEO, CTO)",
placeholder="CEO",
)
company_name: str = SchemaField(
description="Name of the company",
placeholder="Microsoft",
)
enrich_profile: bool = SchemaField(
description="Enrich the profile with additional data",
default=False,
advanced=True,
)
credentials: EnrichlayerCredentialsInput = CredentialsField(
description="Enrichlayer API credentials"
)
class Output(BlockSchema):
"""Output schema for LinkedinRoleLookupBlock."""
role_lookup_result: RoleLookupResponse = SchemaField(
description="LinkedIn role lookup result"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
"""Initialize LinkedinRoleLookupBlock."""
super().__init__(
id="3b9fc742-06d4-49c7-b5ce-7e302dd7c8a7",
description="Look up LinkedIn profiles by role in a company using Enrichlayer",
categories={BlockCategory.SOCIAL},
input_schema=LinkedinRoleLookupBlock.Input,
output_schema=LinkedinRoleLookupBlock.Output,
test_input={
"role": "Co-chair",
"company_name": "Gates Foundation",
"enrich_profile": True,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
(
"role_lookup_result",
RoleLookupResponse(
linkedin_profile_url="https://www.linkedin.com/in/williamhgates/",
),
)
],
test_credentials=TEST_CREDENTIALS,
test_mock={
"_lookup_role": lambda *args, **kwargs: RoleLookupResponse(
linkedin_profile_url="https://www.linkedin.com/in/williamhgates/",
),
},
)
@staticmethod
async def _lookup_role(
credentials: APIKeyCredentials,
role: str,
company_name: str,
enrich_profile: bool = False,
):
client = EnrichlayerClient(credentials=credentials)
role_lookup_result = await client.lookup_role(
role=role,
company_name=company_name,
enrich_profile=enrich_profile,
)
return role_lookup_result
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
"""
Run the block to look up LinkedIn profiles by role.
Args:
input_data: Input parameters for the block
credentials: API key credentials for Enrichlayer
**kwargs: Additional keyword arguments
Yields:
Tuples of (output_name, output_value)
"""
try:
role_lookup_result = await self._lookup_role(
credentials=credentials,
role=input_data.role,
company_name=input_data.company_name,
enrich_profile=input_data.enrich_profile,
)
yield "role_lookup_result", role_lookup_result
except Exception as e:
logger.error(f"Error looking up role in company: {str(e)}")
yield "error", str(e)
class GetLinkedinProfilePictureBlock(Block):
"""Block to get LinkedIn profile pictures using Enrichlayer API."""
class Input(BlockSchema):
"""Input schema for GetLinkedinProfilePictureBlock."""
linkedin_profile_url: str = SchemaField(
description="LinkedIn profile URL",
placeholder="https://www.linkedin.com/in/username/",
)
credentials: EnrichlayerCredentialsInput = CredentialsField(
description="Enrichlayer API credentials"
)
class Output(BlockSchema):
"""Output schema for GetLinkedinProfilePictureBlock."""
profile_picture_url: MediaFileType = SchemaField(
description="LinkedIn profile picture URL"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
"""Initialize GetLinkedinProfilePictureBlock."""
super().__init__(
id="68d5a942-9b3f-4e9a-b7c1-d96ea4321f0d",
description="Get LinkedIn profile pictures using Enrichlayer",
categories={BlockCategory.SOCIAL},
input_schema=GetLinkedinProfilePictureBlock.Input,
output_schema=GetLinkedinProfilePictureBlock.Output,
test_input={
"linkedin_profile_url": "https://www.linkedin.com/in/williamhgates/",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
(
"profile_picture_url",
"https://media.licdn.com/dms/image/C4D03AQFj-xjuXrLFSQ/profile-displayphoto-shrink_800_800/0/1576881858598?e=1686787200&v=beta&t=zrQC76QwsfQQIWthfOnrKRBMZ5D-qIAvzLXLmWgYvTk",
)
],
test_credentials=TEST_CREDENTIALS,
test_mock={
"_get_profile_picture": lambda *args, **kwargs: "https://media.licdn.com/dms/image/C4D03AQFj-xjuXrLFSQ/profile-displayphoto-shrink_800_800/0/1576881858598?e=1686787200&v=beta&t=zrQC76QwsfQQIWthfOnrKRBMZ5D-qIAvzLXLmWgYvTk",
},
)
@staticmethod
async def _get_profile_picture(
credentials: APIKeyCredentials, linkedin_profile_url: str
):
client = EnrichlayerClient(credentials=credentials)
profile_picture_response = await client.get_profile_picture(
linkedin_profile_url=linkedin_profile_url,
)
return profile_picture_response.profile_picture_url
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
"""
Run the block to get LinkedIn profile pictures.
Args:
input_data: Input parameters for the block
credentials: API key credentials for Enrichlayer
**kwargs: Additional keyword arguments
Yields:
Tuples of (output_name, output_value)
"""
try:
profile_picture = await self._get_profile_picture(
credentials=credentials,
linkedin_profile_url=input_data.linkedin_profile_url,
)
yield "profile_picture_url", profile_picture
except Exception as e:
logger.error(f"Error getting profile picture: {str(e)}")
yield "error", str(e)

View File

@@ -1,247 +0,0 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
# Enum definitions based on available options
class WebsetStatus(str, Enum):
IDLE = "idle"
PENDING = "pending"
RUNNING = "running"
PAUSED = "paused"
class WebsetSearchStatus(str, Enum):
CREATED = "created"
# Add more if known, based on example it's "created"
class ImportStatus(str, Enum):
PENDING = "pending"
# Add more if known
class ImportFormat(str, Enum):
CSV = "csv"
# Add more if known
class EnrichmentStatus(str, Enum):
PENDING = "pending"
# Add more if known
class EnrichmentFormat(str, Enum):
TEXT = "text"
# Add more if known
class MonitorStatus(str, Enum):
ENABLED = "enabled"
# Add more if known
class MonitorBehaviorType(str, Enum):
SEARCH = "search"
# Add more if known
class MonitorRunStatus(str, Enum):
CREATED = "created"
# Add more if known
class CanceledReason(str, Enum):
WEBSET_DELETED = "webset_deleted"
# Add more if known
class FailedReason(str, Enum):
INVALID_FORMAT = "invalid_format"
# Add more if known
class Confidence(str, Enum):
HIGH = "high"
# Add more if known
# Nested models
class Entity(BaseModel):
type: str
class Criterion(BaseModel):
description: str
successRate: Optional[int] = None
class ExcludeItem(BaseModel):
source: str = Field(default="import")
id: str
class Relationship(BaseModel):
definition: str
limit: Optional[float] = None
class ScopeItem(BaseModel):
source: str = Field(default="import")
id: str
relationship: Optional[Relationship] = None
class Progress(BaseModel):
found: int
analyzed: int
completion: int
timeLeft: int
class Bounds(BaseModel):
min: int
max: int
class Expected(BaseModel):
total: int
confidence: str = Field(default="high") # Use str or Confidence enum
bounds: Bounds
class Recall(BaseModel):
expected: Expected
reasoning: str
class WebsetSearch(BaseModel):
id: str
object: str = Field(default="webset_search")
status: str = Field(default="created") # Or use WebsetSearchStatus
websetId: str
query: str
entity: Entity
criteria: List[Criterion]
count: int
behavior: str = Field(default="override")
exclude: List[ExcludeItem]
scope: List[ScopeItem]
progress: Progress
recall: Recall
metadata: Dict[str, Any] = Field(default_factory=dict)
canceledAt: Optional[datetime] = None
canceledReason: Optional[str] = Field(default=None) # Or use CanceledReason
createdAt: datetime
updatedAt: datetime
class ImportEntity(BaseModel):
type: str
class Import(BaseModel):
id: str
object: str = Field(default="import")
status: str = Field(default="pending") # Or use ImportStatus
format: str = Field(default="csv") # Or use ImportFormat
entity: ImportEntity
title: str
count: int
metadata: Dict[str, Any] = Field(default_factory=dict)
failedReason: Optional[str] = Field(default=None) # Or use FailedReason
failedAt: Optional[datetime] = None
failedMessage: Optional[str] = None
createdAt: datetime
updatedAt: datetime
class Option(BaseModel):
label: str
class WebsetEnrichment(BaseModel):
id: str
object: str = Field(default="webset_enrichment")
status: str = Field(default="pending") # Or use EnrichmentStatus
websetId: str
title: str
description: str
format: str = Field(default="text") # Or use EnrichmentFormat
options: List[Option]
instructions: str
metadata: Dict[str, Any] = Field(default_factory=dict)
createdAt: datetime
updatedAt: datetime
class Cadence(BaseModel):
cron: str
timezone: str = Field(default="Etc/UTC")
class BehaviorConfig(BaseModel):
query: Optional[str] = None
criteria: Optional[List[Criterion]] = None
entity: Optional[Entity] = None
count: Optional[int] = None
behavior: Optional[str] = Field(default=None)
class Behavior(BaseModel):
type: str = Field(default="search") # Or use MonitorBehaviorType
config: BehaviorConfig
class MonitorRun(BaseModel):
id: str
object: str = Field(default="monitor_run")
status: str = Field(default="created") # Or use MonitorRunStatus
monitorId: str
type: str = Field(default="search")
completedAt: Optional[datetime] = None
failedAt: Optional[datetime] = None
failedReason: Optional[str] = None
canceledAt: Optional[datetime] = None
createdAt: datetime
updatedAt: datetime
class Monitor(BaseModel):
id: str
object: str = Field(default="monitor")
status: str = Field(default="enabled") # Or use MonitorStatus
websetId: str
cadence: Cadence
behavior: Behavior
lastRun: Optional[MonitorRun] = None
nextRunAt: Optional[datetime] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
createdAt: datetime
updatedAt: datetime
class Webset(BaseModel):
id: str
object: str = Field(default="webset")
status: WebsetStatus
externalId: Optional[str] = None
title: Optional[str] = None
searches: List[WebsetSearch]
imports: List[Import]
enrichments: List[WebsetEnrichment]
monitors: List[Monitor]
streams: List[Any]
createdAt: datetime
updatedAt: datetime
metadata: Dict[str, Any] = Field(default_factory=dict)
class ListWebsets(BaseModel):
data: List[Webset]
hasMore: bool
nextCursor: Optional[str] = None

View File

@@ -114,7 +114,6 @@ class ExaWebsetWebhookBlock(Block):
def __init__(self):
super().__init__(
disabled=True,
id="d0204ed8-8b81-408d-8b8d-ed087a546228",
description="Receive webhook notifications for Exa webset events",
categories={BlockCategory.INPUT},

View File

@@ -1,33 +1,7 @@
from datetime import datetime
from enum import Enum
from typing import Annotated, Any, Dict, List, Optional
from exa_py import Exa
from exa_py.websets.types import (
CreateCriterionParameters,
CreateEnrichmentParameters,
CreateWebsetParameters,
CreateWebsetParametersSearch,
ExcludeItem,
Format,
ImportItem,
ImportSource,
Option,
ScopeItem,
ScopeRelationship,
ScopeSourceType,
WebsetArticleEntity,
WebsetCompanyEntity,
WebsetCustomEntity,
WebsetPersonEntity,
WebsetResearchPaperEntity,
WebsetStatus,
)
from pydantic import Field
from typing import Any, Optional
from backend.sdk import (
APIKeyCredentials,
BaseModel,
Block,
BlockCategory,
BlockOutput,
@@ -38,69 +12,7 @@ from backend.sdk import (
)
from ._config import exa
class SearchEntityType(str, Enum):
COMPANY = "company"
PERSON = "person"
ARTICLE = "article"
RESEARCH_PAPER = "research_paper"
CUSTOM = "custom"
AUTO = "auto"
class SearchType(str, Enum):
IMPORT = "import"
WEBSET = "webset"
class EnrichmentFormat(str, Enum):
TEXT = "text"
DATE = "date"
NUMBER = "number"
OPTIONS = "options"
EMAIL = "email"
PHONE = "phone"
class Webset(BaseModel):
id: str
status: WebsetStatus | None = Field(..., title="WebsetStatus")
"""
The status of the webset
"""
external_id: Annotated[Optional[str], Field(alias="externalId")] = None
"""
The external identifier for the webset
NOTE: Returning dict to avoid ui crashing due to nested objects
"""
searches: List[dict[str, Any]] | None = None
"""
The searches that have been performed on the webset.
NOTE: Returning dict to avoid ui crashing due to nested objects
"""
enrichments: List[dict[str, Any]] | None = None
"""
The Enrichments to apply to the Webset Items.
NOTE: Returning dict to avoid ui crashing due to nested objects
"""
monitors: List[dict[str, Any]] | None = None
"""
The Monitors for the Webset.
NOTE: Returning dict to avoid ui crashing due to nested objects
"""
metadata: Optional[Dict[str, Any]] = {}
"""
Set of key-value pairs you want to associate with this object.
"""
created_at: Annotated[datetime, Field(alias="createdAt")] | None = None
"""
The date and time the webset was created
"""
updated_at: Annotated[datetime, Field(alias="updatedAt")] | None = None
"""
The date and time the webset was last updated
"""
from .helpers import WebsetEnrichmentConfig, WebsetSearchConfig
class ExaCreateWebsetBlock(Block):
@@ -108,121 +20,40 @@ class ExaCreateWebsetBlock(Block):
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
# Search parameters (flattened)
search_query: str = SchemaField(
description="Your search query. Use this to describe what you are looking for. Any URL provided will be crawled and used as context for the search.",
placeholder="Marketing agencies based in the US, that focus on consumer products",
search: WebsetSearchConfig = SchemaField(
description="Initial search configuration for the Webset"
)
search_count: Optional[int] = SchemaField(
default=10,
description="Number of items the search will attempt to find. The actual number of items found may be less than this number depending on the search complexity.",
ge=1,
le=1000,
)
search_entity_type: SearchEntityType = SchemaField(
default=SearchEntityType.AUTO,
description="Entity type: 'company', 'person', 'article', 'research_paper', or 'custom'. If not provided, we automatically detect the entity from the query.",
advanced=True,
)
search_entity_description: Optional[str] = SchemaField(
enrichments: Optional[list[WebsetEnrichmentConfig]] = SchemaField(
default=None,
description="Description for custom entity type (required when search_entity_type is 'custom')",
description="Enrichments to apply to Webset items",
advanced=True,
)
# Search criteria (flattened)
search_criteria: list[str] = SchemaField(
default_factory=list,
description="List of criteria descriptions that every item will be evaluated against. If not provided, we automatically detect the criteria from the query.",
advanced=True,
)
# Search exclude sources (flattened)
search_exclude_sources: list[str] = SchemaField(
default_factory=list,
description="List of source IDs (imports or websets) to exclude from search results",
advanced=True,
)
search_exclude_types: list[SearchType] = SchemaField(
default_factory=list,
description="List of source types corresponding to exclude sources ('import' or 'webset')",
advanced=True,
)
# Search scope sources (flattened)
search_scope_sources: list[str] = SchemaField(
default_factory=list,
description="List of source IDs (imports or websets) to limit search scope to",
advanced=True,
)
search_scope_types: list[SearchType] = SchemaField(
default_factory=list,
description="List of source types corresponding to scope sources ('import' or 'webset')",
advanced=True,
)
search_scope_relationships: list[str] = SchemaField(
default_factory=list,
description="List of relationship definitions for hop searches (optional, one per scope source)",
advanced=True,
)
search_scope_relationship_limits: list[int] = SchemaField(
default_factory=list,
description="List of limits on the number of related entities to find (optional, one per scope relationship)",
advanced=True,
)
# Import parameters (flattened)
import_sources: list[str] = SchemaField(
default_factory=list,
description="List of source IDs to import from",
advanced=True,
)
import_types: list[SearchType] = SchemaField(
default_factory=list,
description="List of source types corresponding to import sources ('import' or 'webset')",
advanced=True,
)
# Enrichment parameters (flattened)
enrichment_descriptions: list[str] = SchemaField(
default_factory=list,
description="List of enrichment task descriptions to perform on each webset item",
advanced=True,
)
enrichment_formats: list[EnrichmentFormat] = SchemaField(
default_factory=list,
description="List of formats for enrichment responses ('text', 'date', 'number', 'options', 'email', 'phone'). If not specified, we automatically select the best format.",
advanced=True,
)
enrichment_options: list[list[str]] = SchemaField(
default_factory=list,
description="List of option lists for enrichments with 'options' format. Each inner list contains the option labels.",
advanced=True,
)
enrichment_metadata: list[dict] = SchemaField(
default_factory=list,
description="List of metadata dictionaries for enrichments",
advanced=True,
)
# Webset metadata
external_id: Optional[str] = SchemaField(
default=None,
description="External identifier for the webset. You can use this to reference the webset by your own internal identifiers.",
description="External identifier for the webset",
placeholder="my-webset-123",
advanced=True,
)
metadata: Optional[dict] = SchemaField(
default_factory=dict,
default=None,
description="Key-value pairs to associate with this webset",
advanced=True,
)
class Output(BlockSchema):
webset: Webset = SchemaField(
webset_id: str = SchemaField(
description="The unique identifier for the created webset"
)
status: str = SchemaField(description="The status of the webset")
external_id: Optional[str] = SchemaField(
description="The external identifier for the webset", default=None
)
created_at: str = SchemaField(
description="The date and time the webset was created"
)
error: str = SchemaField(
description="Error message if the request failed", default=""
)
def __init__(self):
super().__init__(
@@ -236,171 +67,44 @@ class ExaCreateWebsetBlock(Block):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
url = "https://api.exa.ai/websets/v0/websets"
headers = {
"Content-Type": "application/json",
"x-api-key": credentials.api_key.get_secret_value(),
}
exa = Exa(credentials.api_key.get_secret_value())
# Build the payload
payload: dict[str, Any] = {
"search": input_data.search.model_dump(exclude_none=True),
}
# ------------------------------------------------------------
# Build entity (if explicitly provided)
# ------------------------------------------------------------
entity = None
if input_data.search_entity_type == SearchEntityType.COMPANY:
entity = WebsetCompanyEntity(type="company")
elif input_data.search_entity_type == SearchEntityType.PERSON:
entity = WebsetPersonEntity(type="person")
elif input_data.search_entity_type == SearchEntityType.ARTICLE:
entity = WebsetArticleEntity(type="article")
elif input_data.search_entity_type == SearchEntityType.RESEARCH_PAPER:
entity = WebsetResearchPaperEntity(type="research_paper")
elif (
input_data.search_entity_type == SearchEntityType.CUSTOM
and input_data.search_entity_description
):
entity = WebsetCustomEntity(
type="custom", description=input_data.search_entity_description
)
# Convert enrichments to API format
if input_data.enrichments:
enrichments_data = []
for enrichment in input_data.enrichments:
enrichments_data.append(enrichment.model_dump(exclude_none=True))
payload["enrichments"] = enrichments_data
# ------------------------------------------------------------
# Build criteria list
# ------------------------------------------------------------
criteria = None
if input_data.search_criteria:
criteria = [
CreateCriterionParameters(description=item)
for item in input_data.search_criteria
]
if input_data.external_id:
payload["externalId"] = input_data.external_id
# ------------------------------------------------------------
# Build exclude sources list
# ------------------------------------------------------------
exclude_items = None
if input_data.search_exclude_sources:
exclude_items = []
for idx, src_id in enumerate(input_data.search_exclude_sources):
src_type = None
if input_data.search_exclude_types and idx < len(
input_data.search_exclude_types
):
src_type = input_data.search_exclude_types[idx]
# Default to IMPORT if type missing
if src_type == SearchType.WEBSET:
source_enum = ImportSource.webset
else:
source_enum = ImportSource.import_
exclude_items.append(ExcludeItem(source=source_enum, id=src_id))
if input_data.metadata:
payload["metadata"] = input_data.metadata
# ------------------------------------------------------------
# Build scope list
# ------------------------------------------------------------
scope_items = None
if input_data.search_scope_sources:
scope_items = []
for idx, src_id in enumerate(input_data.search_scope_sources):
src_type = None
if input_data.search_scope_types and idx < len(
input_data.search_scope_types
):
src_type = input_data.search_scope_types[idx]
relationship = None
if input_data.search_scope_relationships and idx < len(
input_data.search_scope_relationships
):
rel_def = input_data.search_scope_relationships[idx]
lim = None
if input_data.search_scope_relationship_limits and idx < len(
input_data.search_scope_relationship_limits
):
lim = input_data.search_scope_relationship_limits[idx]
relationship = ScopeRelationship(definition=rel_def, limit=lim)
if src_type == SearchType.WEBSET:
src_enum = ScopeSourceType.webset
else:
src_enum = ScopeSourceType.import_
scope_items.append(
ScopeItem(source=src_enum, id=src_id, relationship=relationship)
)
try:
response = await Requests().post(url, headers=headers, json=payload)
data = response.json()
# ------------------------------------------------------------
# Assemble search parameters (only if a query is provided)
# ------------------------------------------------------------
search_params = None
if input_data.search_query:
search_params = CreateWebsetParametersSearch(
query=input_data.search_query,
count=input_data.search_count,
entity=entity,
criteria=criteria,
exclude=exclude_items,
scope=scope_items,
)
yield "webset_id", data.get("id", "")
yield "status", data.get("status", "")
yield "external_id", data.get("externalId")
yield "created_at", data.get("createdAt", "")
# ------------------------------------------------------------
# Build imports list
# ------------------------------------------------------------
imports_params = None
if input_data.import_sources:
imports_params = []
for idx, src_id in enumerate(input_data.import_sources):
src_type = None
if input_data.import_types and idx < len(input_data.import_types):
src_type = input_data.import_types[idx]
if src_type == SearchType.WEBSET:
source_enum = ImportSource.webset
else:
source_enum = ImportSource.import_
imports_params.append(ImportItem(source=source_enum, id=src_id))
# ------------------------------------------------------------
# Build enrichment list
# ------------------------------------------------------------
enrichments_params = None
if input_data.enrichment_descriptions:
enrichments_params = []
for idx, desc in enumerate(input_data.enrichment_descriptions):
fmt = None
if input_data.enrichment_formats and idx < len(
input_data.enrichment_formats
):
fmt_enum = input_data.enrichment_formats[idx]
if fmt_enum is not None:
fmt = Format(
fmt_enum.value if isinstance(fmt_enum, Enum) else fmt_enum
)
options_list = None
if input_data.enrichment_options and idx < len(
input_data.enrichment_options
):
raw_opts = input_data.enrichment_options[idx]
if raw_opts:
options_list = [Option(label=o) for o in raw_opts]
metadata_obj = None
if input_data.enrichment_metadata and idx < len(
input_data.enrichment_metadata
):
metadata_obj = input_data.enrichment_metadata[idx]
enrichments_params.append(
CreateEnrichmentParameters(
description=desc,
format=fmt,
options=options_list,
metadata=metadata_obj,
)
)
# ------------------------------------------------------------
# Create the webset
# ------------------------------------------------------------
webset = exa.websets.create(
params=CreateWebsetParameters(
search=search_params,
imports=imports_params,
enrichments=enrichments_params,
external_id=input_data.external_id,
metadata=input_data.metadata,
)
)
# Use alias field names returned from Exa SDK so that nested models validate correctly
yield "webset", Webset.model_validate(webset.model_dump(by_alias=True))
except Exception as e:
yield "error", str(e)
yield "webset_id", ""
yield "status", ""
yield "created_at", ""
class ExaUpdateWebsetBlock(Block):
@@ -479,11 +183,6 @@ class ExaListWebsetsBlock(Block):
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
trigger: Any | None = SchemaField(
default=None,
description="Trigger for the webset, value is ignored!",
advanced=False,
)
cursor: Optional[str] = SchemaField(
default=None,
description="Cursor for pagination through results",
@@ -498,9 +197,7 @@ class ExaListWebsetsBlock(Block):
)
class Output(BlockSchema):
websets: list[Webset] = SchemaField(
description="List of websets", default_factory=list
)
websets: list = SchemaField(description="List of websets", default_factory=list)
has_more: bool = SchemaField(
description="Whether there are more results to paginate through",
default=False,
@@ -558,6 +255,9 @@ class ExaGetWebsetBlock(Block):
description="The ID or external ID of the Webset to retrieve",
placeholder="webset-id-or-external-id",
)
expand_items: bool = SchemaField(
default=False, description="Include items in the response", advanced=True
)
class Output(BlockSchema):
webset_id: str = SchemaField(description="The unique identifier for the webset")
@@ -609,8 +309,12 @@ class ExaGetWebsetBlock(Block):
"x-api-key": credentials.api_key.get_secret_value(),
}
params = {}
if input_data.expand_items:
params["expand[]"] = "items"
try:
response = await Requests().get(url, headers=headers)
response = await Requests().get(url, headers=headers, params=params)
data = response.json()
yield "webset_id", data.get("id", "")

View File

@@ -29,8 +29,8 @@ class FirecrawlExtractBlock(Block):
prompt: str | None = SchemaField(
description="The prompt to use for the crawl", default=None, advanced=False
)
output_schema: dict | None = SchemaField(
description="A Json Schema describing the output structure if more rigid structure is desired.",
output_schema: str | None = SchemaField(
description="A more rigid structure if you already know the JSON layout.",
default=None,
)
enable_web_search: bool = SchemaField(
@@ -56,6 +56,7 @@ class FirecrawlExtractBlock(Block):
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
# Sync call
extract_result = app.extract(
urls=input_data.urls,
prompt=input_data.prompt,

View File

@@ -1,388 +0,0 @@
import logging
import re
from enum import Enum
from typing import Optional
from typing_extensions import TypedDict
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
GithubCredentials,
GithubCredentialsField,
GithubCredentialsInput,
)
logger = logging.getLogger(__name__)
class CheckRunStatus(Enum):
QUEUED = "queued"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
class CheckRunConclusion(Enum):
SUCCESS = "success"
FAILURE = "failure"
NEUTRAL = "neutral"
CANCELLED = "cancelled"
SKIPPED = "skipped"
TIMED_OUT = "timed_out"
ACTION_REQUIRED = "action_required"
class GithubGetCIResultsBlock(Block):
class Input(BlockSchema):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo: str = SchemaField(
description="GitHub repository",
placeholder="owner/repo",
)
target: str | int = SchemaField(
description="Commit SHA or PR number to get CI results for",
placeholder="abc123def or 123",
)
search_pattern: Optional[str] = SchemaField(
description="Optional regex pattern to search for in CI logs (e.g., error messages, file names)",
placeholder=".*error.*|.*warning.*",
default=None,
advanced=True,
)
check_name_filter: Optional[str] = SchemaField(
description="Optional filter for specific check names (supports wildcards)",
placeholder="*lint* or build-*",
default=None,
advanced=True,
)
class Output(BlockSchema):
class CheckRunItem(TypedDict, total=False):
id: int
name: str
status: str
conclusion: Optional[str]
started_at: Optional[str]
completed_at: Optional[str]
html_url: str
details_url: Optional[str]
output_title: Optional[str]
output_summary: Optional[str]
output_text: Optional[str]
annotations: list[dict]
class MatchedLine(TypedDict):
check_name: str
line_number: int
line: str
context: list[str]
check_run: CheckRunItem = SchemaField(
title="Check Run",
description="Individual CI check run with details",
)
check_runs: list[CheckRunItem] = SchemaField(
description="List of all CI check runs"
)
matched_line: MatchedLine = SchemaField(
title="Matched Line",
description="Line matching the search pattern with context",
)
matched_lines: list[MatchedLine] = SchemaField(
description="All lines matching the search pattern across all checks"
)
overall_status: str = SchemaField(
description="Overall CI status (pending, success, failure)"
)
overall_conclusion: str = SchemaField(
description="Overall CI conclusion if completed"
)
total_checks: int = SchemaField(description="Total number of CI checks")
passed_checks: int = SchemaField(description="Number of passed checks")
failed_checks: int = SchemaField(description="Number of failed checks")
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="8ad9e103-78f2-4fdb-ba12-3571f2c95e98",
description="This block gets CI results for a commit or PR, with optional search for specific errors/warnings in logs.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubGetCIResultsBlock.Input,
output_schema=GithubGetCIResultsBlock.Output,
test_input={
"repo": "owner/repo",
"target": "abc123def456",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("overall_status", "completed"),
("overall_conclusion", "success"),
("total_checks", 1),
("passed_checks", 1),
("failed_checks", 0),
(
"check_runs",
[
{
"id": 123456,
"name": "build",
"status": "completed",
"conclusion": "success",
"started_at": "2024-01-01T00:00:00Z",
"completed_at": "2024-01-01T00:05:00Z",
"html_url": "https://github.com/owner/repo/runs/123456",
"details_url": None,
"output_title": "Build passed",
"output_summary": "All tests passed",
"output_text": "Build log output...",
"annotations": [],
}
],
),
],
test_mock={
"get_ci_results": lambda *args, **kwargs: {
"check_runs": [
{
"id": 123456,
"name": "build",
"status": "completed",
"conclusion": "success",
"started_at": "2024-01-01T00:00:00Z",
"completed_at": "2024-01-01T00:05:00Z",
"html_url": "https://github.com/owner/repo/runs/123456",
"details_url": None,
"output_title": "Build passed",
"output_summary": "All tests passed",
"output_text": "Build log output...",
"annotations": [],
}
],
"total_count": 1,
}
},
)
@staticmethod
async def get_commit_sha(api, repo: str, target: str | int) -> str:
"""Get commit SHA from either a commit SHA or PR URL."""
# If it's already a SHA, return it
if isinstance(target, str):
if re.match(r"^[0-9a-f]{6,40}$", target, re.IGNORECASE):
return target
# If it's a PR URL, get the head SHA
if isinstance(target, int):
pr_url = f"https://api.github.com/repos/{repo}/pulls/{target}"
response = await api.get(pr_url)
pr_data = response.json()
return pr_data["head"]["sha"]
raise ValueError("Target must be a commit SHA or PR URL")
@staticmethod
async def search_in_logs(
check_runs: list,
pattern: str,
) -> list[Output.MatchedLine]:
"""Search for pattern in check run logs."""
if not pattern:
return []
matched_lines = []
regex = re.compile(pattern, re.IGNORECASE | re.MULTILINE)
for check in check_runs:
output_text = check.get("output_text", "") or ""
if not output_text:
continue
lines = output_text.split("\n")
for i, line in enumerate(lines):
if regex.search(line):
# Get context (2 lines before and after)
start = max(0, i - 2)
end = min(len(lines), i + 3)
context = lines[start:end]
matched_lines.append(
{
"check_name": check["name"],
"line_number": i + 1,
"line": line,
"context": context,
}
)
return matched_lines
@staticmethod
async def get_ci_results(
credentials: GithubCredentials,
repo: str,
target: str | int,
search_pattern: Optional[str] = None,
check_name_filter: Optional[str] = None,
) -> dict:
api = get_api(credentials, convert_urls=False)
# Get the commit SHA
commit_sha = await GithubGetCIResultsBlock.get_commit_sha(api, repo, target)
# Get check runs for the commit
check_runs_url = (
f"https://api.github.com/repos/{repo}/commits/{commit_sha}/check-runs"
)
# Get all pages of check runs
all_check_runs = []
page = 1
per_page = 100
while True:
response = await api.get(
check_runs_url, params={"per_page": per_page, "page": page}
)
data = response.json()
check_runs = data.get("check_runs", [])
all_check_runs.extend(check_runs)
if len(check_runs) < per_page:
break
page += 1
# Filter by check name if specified
if check_name_filter:
import fnmatch
filtered_runs = []
for run in all_check_runs:
if fnmatch.fnmatch(run["name"].lower(), check_name_filter.lower()):
filtered_runs.append(run)
all_check_runs = filtered_runs
# Get check run details with logs
detailed_runs = []
for run in all_check_runs:
# Get detailed output including logs
if run.get("output", {}).get("text"):
# Already has output
detailed_run = {
"id": run["id"],
"name": run["name"],
"status": run["status"],
"conclusion": run.get("conclusion"),
"started_at": run.get("started_at"),
"completed_at": run.get("completed_at"),
"html_url": run["html_url"],
"details_url": run.get("details_url"),
"output_title": run.get("output", {}).get("title"),
"output_summary": run.get("output", {}).get("summary"),
"output_text": run.get("output", {}).get("text"),
"annotations": [],
}
else:
# Try to get logs from the check run
detailed_run = {
"id": run["id"],
"name": run["name"],
"status": run["status"],
"conclusion": run.get("conclusion"),
"started_at": run.get("started_at"),
"completed_at": run.get("completed_at"),
"html_url": run["html_url"],
"details_url": run.get("details_url"),
"output_title": run.get("output", {}).get("title"),
"output_summary": run.get("output", {}).get("summary"),
"output_text": None,
"annotations": [],
}
# Get annotations if available
if run.get("output", {}).get("annotations_count", 0) > 0:
annotations_url = f"https://api.github.com/repos/{repo}/check-runs/{run['id']}/annotations"
try:
ann_response = await api.get(annotations_url)
detailed_run["annotations"] = ann_response.json()
except Exception:
pass
detailed_runs.append(detailed_run)
return {
"check_runs": detailed_runs,
"total_count": len(detailed_runs),
}
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
target = int(input_data.target)
except ValueError:
target = input_data.target
result = await self.get_ci_results(
credentials,
input_data.repo,
target,
input_data.search_pattern,
input_data.check_name_filter,
)
check_runs = result["check_runs"]
# Calculate overall status
if not check_runs:
yield "overall_status", "no_checks"
yield "overall_conclusion", "no_checks"
else:
all_completed = all(run["status"] == "completed" for run in check_runs)
if all_completed:
yield "overall_status", "completed"
# Determine overall conclusion
has_failure = any(
run["conclusion"] in ["failure", "timed_out", "action_required"]
for run in check_runs
)
if has_failure:
yield "overall_conclusion", "failure"
else:
yield "overall_conclusion", "success"
else:
yield "overall_status", "pending"
yield "overall_conclusion", "pending"
# Count checks
total = len(check_runs)
passed = sum(1 for run in check_runs if run.get("conclusion") == "success")
failed = sum(
1 for run in check_runs if run.get("conclusion") in ["failure", "timed_out"]
)
yield "total_checks", total
yield "passed_checks", passed
yield "failed_checks", failed
# Output check runs
yield "check_runs", check_runs
# Search for patterns if specified
if input_data.search_pattern:
matched_lines = await self.search_in_logs(
check_runs, input_data.search_pattern
)
if matched_lines:
yield "matched_lines", matched_lines

View File

@@ -1,840 +0,0 @@
import logging
from enum import Enum
from typing import Any, List, Optional
from typing_extensions import TypedDict
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
GithubCredentials,
GithubCredentialsField,
GithubCredentialsInput,
)
logger = logging.getLogger(__name__)
class ReviewEvent(Enum):
COMMENT = "COMMENT"
APPROVE = "APPROVE"
REQUEST_CHANGES = "REQUEST_CHANGES"
class GithubCreatePRReviewBlock(Block):
class Input(BlockSchema):
class ReviewComment(TypedDict, total=False):
path: str
position: Optional[int]
body: str
line: Optional[int] # Will be used as position if position not provided
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo: str = SchemaField(
description="GitHub repository",
placeholder="owner/repo",
)
pr_number: int = SchemaField(
description="Pull request number",
placeholder="123",
)
body: str = SchemaField(
description="Body of the review comment",
placeholder="Enter your review comment",
)
event: ReviewEvent = SchemaField(
description="The review action to perform",
default=ReviewEvent.COMMENT,
)
create_as_draft: bool = SchemaField(
description="Create the review as a draft (pending) or post it immediately",
default=False,
advanced=False,
)
comments: Optional[List[ReviewComment]] = SchemaField(
description="Optional inline comments to add to specific files/lines. Note: Only path, body, and position are supported. Position is line number in diff from first @@ hunk.",
default=None,
advanced=True,
)
class Output(BlockSchema):
review_id: int = SchemaField(description="ID of the created review")
state: str = SchemaField(
description="State of the review (e.g., PENDING, COMMENTED, APPROVED, CHANGES_REQUESTED)"
)
html_url: str = SchemaField(description="URL of the created review")
error: str = SchemaField(
description="Error message if the review creation failed"
)
def __init__(self):
super().__init__(
id="84754b30-97d2-4c37-a3b8-eb39f268275b",
description="This block creates a review on a GitHub pull request with optional inline comments. You can create it as a draft or post immediately. Note: For inline comments, 'position' should be the line number in the diff (starting from the first @@ hunk header).",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubCreatePRReviewBlock.Input,
output_schema=GithubCreatePRReviewBlock.Output,
test_input={
"repo": "owner/repo",
"pr_number": 1,
"body": "This looks good to me!",
"event": "APPROVE",
"create_as_draft": False,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("review_id", 123456),
("state", "APPROVED"),
(
"html_url",
"https://github.com/owner/repo/pull/1#pullrequestreview-123456",
),
],
test_mock={
"create_review": lambda *args, **kwargs: (
123456,
"APPROVED",
"https://github.com/owner/repo/pull/1#pullrequestreview-123456",
)
},
)
@staticmethod
async def create_review(
credentials: GithubCredentials,
repo: str,
pr_number: int,
body: str,
event: ReviewEvent,
create_as_draft: bool,
comments: Optional[List[Input.ReviewComment]] = None,
) -> tuple[int, str, str]:
api = get_api(credentials, convert_urls=False)
# GitHub API endpoint for creating reviews
reviews_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/reviews"
# Get commit_id if we have comments
commit_id = None
if comments:
# Get PR details to get the head commit for inline comments
pr_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}"
pr_response = await api.get(pr_url)
pr_data = pr_response.json()
commit_id = pr_data["head"]["sha"]
# Prepare the request data
# If create_as_draft is True, omit the event field (creates a PENDING review)
# Otherwise, use the actual event value which will auto-submit the review
data: dict[str, Any] = {"body": body}
# Add commit_id if we have it
if commit_id:
data["commit_id"] = commit_id
# Add comments if provided
if comments:
# Process comments to ensure they have the required fields
processed_comments = []
for comment in comments:
comment_data: dict = {
"path": comment.get("path", ""),
"body": comment.get("body", ""),
}
# Add position or line
# Note: For review comments, only position is supported (not line/side)
if "position" in comment and comment.get("position") is not None:
comment_data["position"] = comment.get("position")
elif "line" in comment and comment.get("line") is not None:
# Note: Using line as position - may not work correctly
# Position should be calculated from the diff
comment_data["position"] = comment.get("line")
# Note: side, start_line, and start_side are NOT supported for review comments
# They are only for standalone PR comments
processed_comments.append(comment_data)
data["comments"] = processed_comments
if not create_as_draft:
# Only add event field if not creating a draft
data["event"] = event.value
# Create the review
response = await api.post(reviews_url, json=data)
review_data = response.json()
return review_data["id"], review_data["state"], review_data["html_url"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
review_id, state, html_url = await self.create_review(
credentials,
input_data.repo,
input_data.pr_number,
input_data.body,
input_data.event,
input_data.create_as_draft,
input_data.comments,
)
yield "review_id", review_id
yield "state", state
yield "html_url", html_url
except Exception as e:
yield "error", str(e)
class GithubListPRReviewsBlock(Block):
class Input(BlockSchema):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo: str = SchemaField(
description="GitHub repository",
placeholder="owner/repo",
)
pr_number: int = SchemaField(
description="Pull request number",
placeholder="123",
)
class Output(BlockSchema):
class ReviewItem(TypedDict):
id: int
user: str
state: str
body: str
html_url: str
review: ReviewItem = SchemaField(
title="Review",
description="Individual review with details",
)
reviews: list[ReviewItem] = SchemaField(
description="List of all reviews on the pull request"
)
error: str = SchemaField(description="Error message if listing reviews failed")
def __init__(self):
super().__init__(
id="f79bc6eb-33c0-4099-9c0f-d664ae1ba4d0",
description="This block lists all reviews for a specified GitHub pull request.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListPRReviewsBlock.Input,
output_schema=GithubListPRReviewsBlock.Output,
test_input={
"repo": "owner/repo",
"pr_number": 1,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"reviews",
[
{
"id": 123456,
"user": "reviewer1",
"state": "APPROVED",
"body": "Looks good!",
"html_url": "https://github.com/owner/repo/pull/1#pullrequestreview-123456",
}
],
),
(
"review",
{
"id": 123456,
"user": "reviewer1",
"state": "APPROVED",
"body": "Looks good!",
"html_url": "https://github.com/owner/repo/pull/1#pullrequestreview-123456",
},
),
],
test_mock={
"list_reviews": lambda *args, **kwargs: [
{
"id": 123456,
"user": "reviewer1",
"state": "APPROVED",
"body": "Looks good!",
"html_url": "https://github.com/owner/repo/pull/1#pullrequestreview-123456",
}
]
},
)
@staticmethod
async def list_reviews(
credentials: GithubCredentials, repo: str, pr_number: int
) -> list[Output.ReviewItem]:
api = get_api(credentials, convert_urls=False)
# GitHub API endpoint for listing reviews
reviews_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/reviews"
response = await api.get(reviews_url)
data = response.json()
reviews: list[GithubListPRReviewsBlock.Output.ReviewItem] = [
{
"id": review["id"],
"user": review["user"]["login"],
"state": review["state"],
"body": review.get("body", ""),
"html_url": review["html_url"],
}
for review in data
]
return reviews
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
reviews = await self.list_reviews(
credentials,
input_data.repo,
input_data.pr_number,
)
yield "reviews", reviews
for review in reviews:
yield "review", review
class GithubSubmitPendingReviewBlock(Block):
class Input(BlockSchema):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo: str = SchemaField(
description="GitHub repository",
placeholder="owner/repo",
)
pr_number: int = SchemaField(
description="Pull request number",
placeholder="123",
)
review_id: int = SchemaField(
description="ID of the pending review to submit",
placeholder="123456",
)
event: ReviewEvent = SchemaField(
description="The review action to perform when submitting",
default=ReviewEvent.COMMENT,
)
class Output(BlockSchema):
state: str = SchemaField(description="State of the submitted review")
html_url: str = SchemaField(description="URL of the submitted review")
error: str = SchemaField(
description="Error message if the review submission failed"
)
def __init__(self):
super().__init__(
id="2e468217-7ca0-4201-9553-36e93eb9357a",
description="This block submits a pending (draft) review on a GitHub pull request.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubSubmitPendingReviewBlock.Input,
output_schema=GithubSubmitPendingReviewBlock.Output,
test_input={
"repo": "owner/repo",
"pr_number": 1,
"review_id": 123456,
"event": "APPROVE",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("state", "APPROVED"),
(
"html_url",
"https://github.com/owner/repo/pull/1#pullrequestreview-123456",
),
],
test_mock={
"submit_review": lambda *args, **kwargs: (
"APPROVED",
"https://github.com/owner/repo/pull/1#pullrequestreview-123456",
)
},
)
@staticmethod
async def submit_review(
credentials: GithubCredentials,
repo: str,
pr_number: int,
review_id: int,
event: ReviewEvent,
) -> tuple[str, str]:
api = get_api(credentials, convert_urls=False)
# GitHub API endpoint for submitting a review
submit_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/reviews/{review_id}/events"
data = {"event": event.value}
response = await api.post(submit_url, json=data)
review_data = response.json()
return review_data["state"], review_data["html_url"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
state, html_url = await self.submit_review(
credentials,
input_data.repo,
input_data.pr_number,
input_data.review_id,
input_data.event,
)
yield "state", state
yield "html_url", html_url
except Exception as e:
yield "error", str(e)
class GithubResolveReviewDiscussionBlock(Block):
class Input(BlockSchema):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo: str = SchemaField(
description="GitHub repository",
placeholder="owner/repo",
)
pr_number: int = SchemaField(
description="Pull request number",
placeholder="123",
)
comment_id: int = SchemaField(
description="ID of the review comment to resolve/unresolve",
placeholder="123456",
)
resolve: bool = SchemaField(
description="Whether to resolve (true) or unresolve (false) the discussion",
default=True,
)
class Output(BlockSchema):
success: bool = SchemaField(description="Whether the operation was successful")
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="b4b8a38c-95ae-4c91-9ef8-c2cffaf2b5d1",
description="This block resolves or unresolves a review discussion thread on a GitHub pull request.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubResolveReviewDiscussionBlock.Input,
output_schema=GithubResolveReviewDiscussionBlock.Output,
test_input={
"repo": "owner/repo",
"pr_number": 1,
"comment_id": 123456,
"resolve": True,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("success", True),
],
test_mock={"resolve_discussion": lambda *args, **kwargs: True},
)
@staticmethod
async def resolve_discussion(
credentials: GithubCredentials,
repo: str,
pr_number: int,
comment_id: int,
resolve: bool,
) -> bool:
api = get_api(credentials, convert_urls=False)
# Extract owner and repo name
parts = repo.split("/")
owner = parts[0]
repo_name = parts[1]
# GitHub GraphQL API is needed for resolving/unresolving discussions
# First, we need to get the node ID of the comment
graphql_url = "https://api.github.com/graphql"
# Query to get the review comment node ID
query = """
query($owner: String!, $repo: String!, $number: Int!) {
repository(owner: $owner, name: $repo) {
pullRequest(number: $number) {
reviewThreads(first: 100) {
nodes {
comments(first: 100) {
nodes {
databaseId
id
}
}
id
isResolved
}
}
}
}
}
"""
variables = {"owner": owner, "repo": repo_name, "number": pr_number}
response = await api.post(
graphql_url, json={"query": query, "variables": variables}
)
data = response.json()
# Find the thread containing our comment
thread_id = None
for thread in data["data"]["repository"]["pullRequest"]["reviewThreads"][
"nodes"
]:
for comment in thread["comments"]["nodes"]:
if comment["databaseId"] == comment_id:
thread_id = thread["id"]
break
if thread_id:
break
if not thread_id:
raise ValueError(f"Comment {comment_id} not found in pull request")
# Now resolve or unresolve the thread
# GitHub's GraphQL API has separate mutations for resolve and unresolve
if resolve:
mutation = """
mutation($threadId: ID!) {
resolveReviewThread(input: {threadId: $threadId}) {
thread {
isResolved
}
}
}
"""
else:
mutation = """
mutation($threadId: ID!) {
unresolveReviewThread(input: {threadId: $threadId}) {
thread {
isResolved
}
}
}
"""
mutation_variables = {"threadId": thread_id}
response = await api.post(
graphql_url, json={"query": mutation, "variables": mutation_variables}
)
result = response.json()
if "errors" in result:
raise Exception(f"GraphQL error: {result['errors']}")
return True
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
success = await self.resolve_discussion(
credentials,
input_data.repo,
input_data.pr_number,
input_data.comment_id,
input_data.resolve,
)
yield "success", success
except Exception as e:
yield "success", False
yield "error", str(e)
class GithubGetPRReviewCommentsBlock(Block):
class Input(BlockSchema):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo: str = SchemaField(
description="GitHub repository",
placeholder="owner/repo",
)
pr_number: int = SchemaField(
description="Pull request number",
placeholder="123",
)
review_id: Optional[int] = SchemaField(
description="ID of a specific review to get comments from (optional)",
placeholder="123456",
default=None,
advanced=True,
)
class Output(BlockSchema):
class CommentItem(TypedDict):
id: int
user: str
body: str
path: str
line: int
side: str
created_at: str
updated_at: str
in_reply_to_id: Optional[int]
html_url: str
comment: CommentItem = SchemaField(
title="Comment",
description="Individual review comment with details",
)
comments: list[CommentItem] = SchemaField(
description="List of all review comments on the pull request"
)
error: str = SchemaField(description="Error message if getting comments failed")
def __init__(self):
super().__init__(
id="1d34db7f-10c1-45c1-9d43-749f743c8bd4",
description="This block gets all review comments from a GitHub pull request or from a specific review.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubGetPRReviewCommentsBlock.Input,
output_schema=GithubGetPRReviewCommentsBlock.Output,
test_input={
"repo": "owner/repo",
"pr_number": 1,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"comments",
[
{
"id": 123456,
"user": "reviewer1",
"body": "This needs improvement",
"path": "src/main.py",
"line": 42,
"side": "RIGHT",
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z",
"in_reply_to_id": None,
"html_url": "https://github.com/owner/repo/pull/1#discussion_r123456",
}
],
),
(
"comment",
{
"id": 123456,
"user": "reviewer1",
"body": "This needs improvement",
"path": "src/main.py",
"line": 42,
"side": "RIGHT",
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z",
"in_reply_to_id": None,
"html_url": "https://github.com/owner/repo/pull/1#discussion_r123456",
},
),
],
test_mock={
"get_comments": lambda *args, **kwargs: [
{
"id": 123456,
"user": "reviewer1",
"body": "This needs improvement",
"path": "src/main.py",
"line": 42,
"side": "RIGHT",
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z",
"in_reply_to_id": None,
"html_url": "https://github.com/owner/repo/pull/1#discussion_r123456",
}
]
},
)
@staticmethod
async def get_comments(
credentials: GithubCredentials,
repo: str,
pr_number: int,
review_id: Optional[int] = None,
) -> list[Output.CommentItem]:
api = get_api(credentials, convert_urls=False)
# Determine the endpoint based on whether we want comments from a specific review
if review_id:
# Get comments from a specific review
comments_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/reviews/{review_id}/comments"
else:
# Get all review comments on the PR
comments_url = (
f"https://api.github.com/repos/{repo}/pulls/{pr_number}/comments"
)
response = await api.get(comments_url)
data = response.json()
comments: list[GithubGetPRReviewCommentsBlock.Output.CommentItem] = [
{
"id": comment["id"],
"user": comment["user"]["login"],
"body": comment["body"],
"path": comment.get("path", ""),
"line": comment.get("line", 0),
"side": comment.get("side", ""),
"created_at": comment["created_at"],
"updated_at": comment["updated_at"],
"in_reply_to_id": comment.get("in_reply_to_id"),
"html_url": comment["html_url"],
}
for comment in data
]
return comments
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
comments = await self.get_comments(
credentials,
input_data.repo,
input_data.pr_number,
input_data.review_id,
)
yield "comments", comments
for comment in comments:
yield "comment", comment
except Exception as e:
yield "error", str(e)
class GithubCreateCommentObjectBlock(Block):
class Input(BlockSchema):
path: str = SchemaField(
description="The file path to comment on",
placeholder="src/main.py",
)
body: str = SchemaField(
description="The comment text",
placeholder="Please fix this issue",
)
position: Optional[int] = SchemaField(
description="Position in the diff (line number from first @@ hunk). Use this OR line.",
placeholder="6",
default=None,
advanced=True,
)
line: Optional[int] = SchemaField(
description="Line number in the file (will be used as position if position not provided)",
placeholder="42",
default=None,
advanced=True,
)
side: Optional[str] = SchemaField(
description="Side of the diff to comment on (NOTE: Only for standalone comments, not review comments)",
default="RIGHT",
advanced=True,
)
start_line: Optional[int] = SchemaField(
description="Start line for multi-line comments (NOTE: Only for standalone comments, not review comments)",
default=None,
advanced=True,
)
start_side: Optional[str] = SchemaField(
description="Side for the start of multi-line comments (NOTE: Only for standalone comments, not review comments)",
default=None,
advanced=True,
)
class Output(BlockSchema):
comment_object: dict = SchemaField(
description="The comment object formatted for GitHub API"
)
def __init__(self):
super().__init__(
id="b7d5e4f2-8c3a-4e6b-9f1d-7a8b9c5e4d3f",
description="Creates a comment object for use with GitHub blocks. Note: For review comments, only path, body, and position are used. Side fields are only for standalone PR comments.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubCreateCommentObjectBlock.Input,
output_schema=GithubCreateCommentObjectBlock.Output,
test_input={
"path": "src/main.py",
"body": "Please fix this issue",
"position": 6,
},
test_output=[
(
"comment_object",
{
"path": "src/main.py",
"body": "Please fix this issue",
"position": 6,
},
),
],
)
async def run(
self,
input_data: Input,
**kwargs,
) -> BlockOutput:
# Build the comment object
comment_obj: dict = {
"path": input_data.path,
"body": input_data.body,
}
# Add position or line
if input_data.position is not None:
comment_obj["position"] = input_data.position
elif input_data.line is not None:
# Note: line will be used as position, which may not be accurate
# Position should be calculated from the diff
comment_obj["position"] = input_data.line
# Add optional fields only if they differ from defaults or are explicitly provided
if input_data.side and input_data.side != "RIGHT":
comment_obj["side"] = input_data.side
if input_data.start_line is not None:
comment_obj["start_line"] = input_data.start_line
if input_data.start_side:
comment_obj["start_side"] = input_data.start_side
yield "comment_object", comment_obj

View File

@@ -21,8 +21,6 @@ from ._auth import (
GoogleCredentialsInput,
)
settings = Settings()
class CalendarEvent(BaseModel):
"""Structured representation of a Google Calendar event."""
@@ -223,8 +221,8 @@ class GoogleCalendarReadEventsBlock(Block):
else None
),
token_uri="https://oauth2.googleapis.com/token",
client_id=settings.secrets.google_client_id,
client_secret=settings.secrets.google_client_secret,
client_id=Settings().secrets.google_client_id,
client_secret=Settings().secrets.google_client_secret,
scopes=credentials.scopes,
)
return build("calendar", "v3", credentials=creds)
@@ -571,8 +569,8 @@ class GoogleCalendarCreateEventBlock(Block):
else None
),
token_uri="https://oauth2.googleapis.com/token",
client_id=settings.secrets.google_client_id,
client_secret=settings.secrets.google_client_secret,
client_id=Settings().secrets.google_client_id,
client_secret=Settings().secrets.google_client_secret,
scopes=credentials.scopes,
)
return build("calendar", "v3", credentials=creds)

File diff suppressed because it is too large Load Diff

View File

@@ -37,7 +37,6 @@ LLMProviderName = Literal[
ProviderName.OPENAI,
ProviderName.OPEN_ROUTER,
ProviderName.LLAMA_API,
ProviderName.V0,
]
AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]]
@@ -82,11 +81,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
O3 = "o3-2025-04-16"
O1 = "o1"
O1_MINI = "o1-mini"
# GPT-5 models
GPT5 = "gpt-5-2025-08-07"
GPT5_MINI = "gpt-5-mini-2025-08-07"
GPT5_NANO = "gpt-5-nano-2025-08-07"
GPT5_CHAT = "gpt-5-chat-latest"
GPT41 = "gpt-4.1-2025-04-14"
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
GPT4O_MINI = "gpt-4o-mini"
@@ -94,7 +88,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
GPT4_TURBO = "gpt-4-turbo"
GPT3_5_TURBO = "gpt-3.5-turbo"
# Anthropic models
CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
CLAUDE_4_OPUS = "claude-opus-4-20250514"
CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
@@ -156,10 +149,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
LLAMA_API_LLAMA4_MAVERICK = "Llama-4-Maverick-17B-128E-Instruct-FP8"
LLAMA_API_LLAMA3_3_8B = "Llama-3.3-8B-Instruct"
LLAMA_API_LLAMA3_3_70B = "Llama-3.3-70B-Instruct"
# v0 by Vercel models
V0_1_5_MD = "v0-1.5-md"
V0_1_5_LG = "v0-1.5-lg"
V0_1_0_MD = "v0-1.0-md"
@property
def metadata(self) -> ModelMetadata:
@@ -184,11 +173,6 @@ MODEL_METADATA = {
LlmModel.O3_MINI: ModelMetadata("openai", 200000, 100000), # o3-mini-2025-01-31
LlmModel.O1: ModelMetadata("openai", 200000, 100000), # o1-2024-12-17
LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536), # o1-mini-2024-09-12
# GPT-5 models
LlmModel.GPT5: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_MINI: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_NANO: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_CHAT: ModelMetadata("openai", 400000, 16384),
LlmModel.GPT41: ModelMetadata("openai", 1047576, 32768),
LlmModel.GPT41_MINI: ModelMetadata("openai", 1047576, 32768),
LlmModel.GPT4O_MINI: ModelMetadata(
@@ -200,9 +184,6 @@ MODEL_METADATA = {
), # gpt-4-turbo-2024-04-09
LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, 4096), # gpt-3.5-turbo-0125
# https://docs.anthropic.com/en/docs/about-claude/models
LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
"anthropic", 200000, 32000
), # claude-opus-4-1-20250805
LlmModel.CLAUDE_4_OPUS: ModelMetadata(
"anthropic", 200000, 8192
), # claude-4-opus-20250514
@@ -285,10 +266,6 @@ MODEL_METADATA = {
LlmModel.LLAMA_API_LLAMA4_MAVERICK: ModelMetadata("llama_api", 128000, 4028),
LlmModel.LLAMA_API_LLAMA3_3_8B: ModelMetadata("llama_api", 128000, 4028),
LlmModel.LLAMA_API_LLAMA3_3_70B: ModelMetadata("llama_api", 128000, 4028),
# v0 by Vercel models
LlmModel.V0_1_5_MD: ModelMetadata("v0", 128000, 64000),
LlmModel.V0_1_5_LG: ModelMetadata("v0", 512000, 64000),
LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000),
}
for model in LlmModel:
@@ -502,7 +479,6 @@ async def llm_call(
messages=messages,
max_tokens=max_tokens,
tools=an_tools,
timeout=600,
)
if not resp.content:
@@ -685,11 +661,7 @@ async def llm_call(
client = openai.OpenAI(
base_url="https://api.aimlapi.com/v2",
api_key=credentials.api_key.get_secret_value(),
default_headers={
"X-Project": "AutoGPT",
"X-Title": "AutoGPT",
"HTTP-Referer": "https://github.com/Significant-Gravitas/AutoGPT",
},
default_headers={"X-Project": "AutoGPT"},
)
completion = client.chat.completions.create(
@@ -709,42 +681,6 @@ async def llm_call(
),
reasoning=None,
)
elif provider == "v0":
tools_param = tools if tools else openai.NOT_GIVEN
client = openai.AsyncOpenAI(
base_url="https://api.v0.dev/v1",
api_key=credentials.api_key.get_secret_value(),
)
response_format = None
if json_format:
response_format = {"type": "json_object"}
parallel_tool_calls_param = get_parallel_tool_calls_param(
llm_model, parallel_tool_calls
)
response = await client.chat.completions.create(
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore
parallel_tool_calls=parallel_tool_calls_param,
)
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
return LLMResponse(
raw_response=response.choices[0].message,
prompt=prompt,
response=response.choices[0].message.content or "",
tool_calls=tool_calls,
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
reasoning=reasoning,
)
else:
raise ValueError(f"Unsupported LLM provider: {provider}")

View File

@@ -3,7 +3,8 @@ from typing import List
from backend.data.block import BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, SchemaField
from backend.util.settings import BehaveAs, Settings
from backend.util import settings
from backend.util.settings import BehaveAs
from ._api import (
TEST_CREDENTIALS,
@@ -15,8 +16,6 @@ from ._api import (
)
from .base import Slant3DBlockBase
settings = Settings()
class Slant3DCreateOrderBlock(Slant3DBlockBase):
"""Block for creating new orders"""
@@ -281,7 +280,7 @@ class Slant3DGetOrdersBlock(Slant3DBlockBase):
input_schema=self.Input,
output_schema=self.Output,
# This block is disabled for cloud hosted because it allows access to all orders for the account
disabled=settings.config.behave_as == BehaveAs.CLOUD,
disabled=settings.Settings().config.behave_as == BehaveAs.CLOUD,
test_input={"credentials": TEST_CREDENTIALS_INPUT},
test_credentials=TEST_CREDENTIALS,
test_output=[

View File

@@ -9,7 +9,8 @@ from backend.data.block import (
)
from backend.data.model import SchemaField
from backend.integrations.providers import ProviderName
from backend.util.settings import AppEnvironment, BehaveAs, Settings
from backend.util import settings
from backend.util.settings import AppEnvironment, BehaveAs
from ._api import (
TEST_CREDENTIALS,
@@ -18,8 +19,6 @@ from ._api import (
Slant3DCredentialsInput,
)
settings = Settings()
class Slant3DTriggerBase:
"""Base class for Slant3D webhook triggers"""
@@ -77,8 +76,8 @@ class Slant3DOrderWebhookBlock(Slant3DTriggerBase, Block):
),
# All webhooks are currently subscribed to for all orders. This works for self hosted, but not for cloud hosted prod
disabled=(
settings.config.behave_as == BehaveAs.CLOUD
and settings.config.app_env != AppEnvironment.LOCAL
settings.Settings().config.behave_as == BehaveAs.CLOUD
and settings.Settings().config.app_env != AppEnvironment.LOCAL
),
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=self.Input,

View File

@@ -291,32 +291,9 @@ class SmartDecisionMakerBlock(Block):
for link in links:
sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)
# Handle dynamic fields (e.g., values_#_*, items_$_*, etc.)
# These are fields that get merged by the executor into their base field
if (
"_#_" in link.sink_name
or "_$_" in link.sink_name
or "_@_" in link.sink_name
):
# For dynamic fields, provide a generic string schema
# The executor will handle merging these into the appropriate structure
properties[sink_name] = {
"type": "string",
"description": f"Dynamic value for {link.sink_name}",
}
else:
# For regular fields, use the block's schema
try:
properties[sink_name] = sink_block_input_schema.get_field_schema(
link.sink_name
)
except (KeyError, AttributeError):
# If the field doesn't exist in the schema, provide a generic schema
properties[sink_name] = {
"type": "string",
"description": f"Value for {link.sink_name}",
}
properties[sink_name] = sink_block_input_schema.get_field_schema(
link.sink_name
)
tool_function["parameters"] = {
**block.input_schema.jsonschema(),
@@ -501,6 +478,10 @@ class SmartDecisionMakerBlock(Block):
}
)
prompt.extend(tool_output)
if input_data.multiple_tool_calls:
input_data.sys_prompt += "\nYou can call a tool (different tools) multiple times in a single response."
else:
input_data.sys_prompt += "\nOnly provide EXACTLY one function call, multiple tool calls is strictly prohibited."
values = input_data.prompt_values
if values:

View File

@@ -1,8 +1,9 @@
import logging
import pytest
from prisma.models import User
from backend.data.model import ProviderName, User
from backend.data.model import ProviderName
from backend.server.model import CreateGraph
from backend.server.rest_api import AgentServer
from backend.usecases.sample import create_test_graph, create_test_user

View File

@@ -1,130 +0,0 @@
from unittest.mock import Mock
import pytest
from backend.blocks.data_manipulation import AddToListBlock, CreateDictionaryBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
@pytest.mark.asyncio
async def test_smart_decision_maker_handles_dynamic_dict_fields():
"""Test Smart Decision Maker can handle dynamic dictionary fields (_#_) for any block"""
# Create a mock node for CreateDictionaryBlock
mock_node = Mock()
mock_node.block = CreateDictionaryBlock()
mock_node.block_id = CreateDictionaryBlock().id
mock_node.input_default = {}
# Create mock links with dynamic dictionary fields
mock_links = [
Mock(
source_name="tools_^_create_dict_~_name",
sink_name="values_#_name", # Dynamic dict field
sink_id="dict_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_create_dict_~_age",
sink_name="values_#_age", # Dynamic dict field
sink_id="dict_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_create_dict_~_city",
sink_name="values_#_city", # Dynamic dict field
sink_id="dict_node_id",
source_id="smart_decision_node_id",
),
]
# Generate function signature
signature = await SmartDecisionMakerBlock._create_block_function_signature(
mock_node, mock_links # type: ignore
)
# Verify the signature was created successfully
assert signature["type"] == "function"
assert "parameters" in signature["function"]
assert "properties" in signature["function"]["parameters"]
# Check that dynamic fields are handled
properties = signature["function"]["parameters"]["properties"]
assert len(properties) == 3 # Should have all three fields
# Each dynamic field should have proper schema
for prop_value in properties.values():
assert "type" in prop_value
assert prop_value["type"] == "string" # Dynamic fields get string type
assert "description" in prop_value
assert "Dynamic value for" in prop_value["description"]
@pytest.mark.asyncio
async def test_smart_decision_maker_handles_dynamic_list_fields():
"""Test Smart Decision Maker can handle dynamic list fields (_$_) for any block"""
# Create a mock node for AddToListBlock
mock_node = Mock()
mock_node.block = AddToListBlock()
mock_node.block_id = AddToListBlock().id
mock_node.input_default = {}
# Create mock links with dynamic list fields
mock_links = [
Mock(
source_name="tools_^_add_to_list_~_0",
sink_name="entries_$_0", # Dynamic list field
sink_id="list_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_add_to_list_~_1",
sink_name="entries_$_1", # Dynamic list field
sink_id="list_node_id",
source_id="smart_decision_node_id",
),
]
# Generate function signature
signature = await SmartDecisionMakerBlock._create_block_function_signature(
mock_node, mock_links # type: ignore
)
# Verify dynamic list fields are handled properly
assert signature["type"] == "function"
properties = signature["function"]["parameters"]["properties"]
assert len(properties) == 2 # Should have both list items
# Each dynamic field should have proper schema
for prop_value in properties.values():
assert prop_value["type"] == "string"
assert "Dynamic value for" in prop_value["description"]
@pytest.mark.asyncio
async def test_create_dict_block_with_dynamic_values():
"""Test CreateDictionaryBlock processes dynamic values correctly"""
block = CreateDictionaryBlock()
# Simulate what happens when executor merges dynamic fields
# The executor merges values_#_* fields into the values dict
input_data = block.input_schema(
values={
"existing": "value",
"name": "Alice", # This would come from values_#_name
"age": 25, # This would come from values_#_age
}
)
# Run the block
result = {}
async for output_name, output_value in block.run(input_data):
result[output_name] = output_value
# Check the result
assert "dictionary" in result
assert result["dictionary"]["existing"] == "value"
assert result["dictionary"]["name"] == "Alice"
assert result["dictionary"]["age"] == 25

View File

@@ -1,78 +1,19 @@
import asyncio
import time
from datetime import datetime, timedelta
from typing import Any, Literal, Union
from zoneinfo import ZoneInfo
from pydantic import BaseModel
from typing import Any, Union
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
# Shared timezone literal type for all time/date blocks
TimezoneLiteral = Literal[
"UTC", # UTC±00:00
"Pacific/Honolulu", # UTC-10:00
"America/Anchorage", # UTC-09:00 (Alaska)
"America/Los_Angeles", # UTC-08:00 (Pacific)
"America/Denver", # UTC-07:00 (Mountain)
"America/Chicago", # UTC-06:00 (Central)
"America/New_York", # UTC-05:00 (Eastern)
"America/Caracas", # UTC-04:00
"America/Sao_Paulo", # UTC-03:00
"America/St_Johns", # UTC-02:30 (Newfoundland)
"Atlantic/South_Georgia", # UTC-02:00
"Atlantic/Azores", # UTC-01:00
"Europe/London", # UTC+00:00 (GMT/BST)
"Europe/Paris", # UTC+01:00 (CET)
"Europe/Athens", # UTC+02:00 (EET)
"Europe/Moscow", # UTC+03:00
"Asia/Tehran", # UTC+03:30 (Iran)
"Asia/Dubai", # UTC+04:00
"Asia/Kabul", # UTC+04:30 (Afghanistan)
"Asia/Karachi", # UTC+05:00 (Pakistan)
"Asia/Kolkata", # UTC+05:30 (India)
"Asia/Kathmandu", # UTC+05:45 (Nepal)
"Asia/Dhaka", # UTC+06:00 (Bangladesh)
"Asia/Yangon", # UTC+06:30 (Myanmar)
"Asia/Bangkok", # UTC+07:00
"Asia/Shanghai", # UTC+08:00 (China)
"Australia/Eucla", # UTC+08:45
"Asia/Tokyo", # UTC+09:00 (Japan)
"Australia/Adelaide", # UTC+09:30
"Australia/Sydney", # UTC+10:00
"Australia/Lord_Howe", # UTC+10:30
"Pacific/Noumea", # UTC+11:00
"Pacific/Auckland", # UTC+12:00 (New Zealand)
"Pacific/Chatham", # UTC+12:45
"Pacific/Tongatapu", # UTC+13:00
"Pacific/Kiritimati", # UTC+14:00
"Etc/GMT-12", # UTC+12:00
"Etc/GMT+12", # UTC-12:00
]
class TimeStrftimeFormat(BaseModel):
discriminator: Literal["strftime"]
format: str = "%H:%M:%S"
timezone: TimezoneLiteral = "UTC"
class TimeISO8601Format(BaseModel):
discriminator: Literal["iso8601"]
timezone: TimezoneLiteral = "UTC"
include_microseconds: bool = False
class GetCurrentTimeBlock(Block):
class Input(BlockSchema):
trigger: str = SchemaField(
description="Trigger any data to output the current time"
)
format_type: Union[TimeStrftimeFormat, TimeISO8601Format] = SchemaField(
discriminator="discriminator",
description="Format type for time output (strftime with custom format or ISO 8601)",
default=TimeStrftimeFormat(discriminator="strftime"),
format: str = SchemaField(
description="Format of the time to output", default="%H:%M:%S"
)
class Output(BlockSchema):
@@ -89,65 +30,19 @@ class GetCurrentTimeBlock(Block):
output_schema=GetCurrentTimeBlock.Output,
test_input=[
{"trigger": "Hello"},
{
"trigger": "Hello",
"format_type": {
"discriminator": "strftime",
"format": "%H:%M",
},
},
{
"trigger": "Hello",
"format_type": {
"discriminator": "iso8601",
"timezone": "UTC",
"include_microseconds": False,
},
},
{"trigger": "Hello", "format": "%H:%M"},
],
test_output=[
("time", lambda _: time.strftime("%H:%M:%S")),
("time", lambda _: time.strftime("%H:%M")),
(
"time",
lambda t: "T" in t and ("+" in t or "Z" in t),
), # Check for ISO format with timezone
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
if isinstance(input_data.format_type, TimeISO8601Format):
# ISO 8601 format for time only (extract time portion from full ISO datetime)
tz = ZoneInfo(input_data.format_type.timezone)
dt = datetime.now(tz=tz)
# Get the full ISO format and extract just the time portion with timezone
if input_data.format_type.include_microseconds:
full_iso = dt.isoformat()
else:
full_iso = dt.isoformat(timespec="seconds")
# Extract time portion (everything after 'T')
current_time = full_iso.split("T")[1] if "T" in full_iso else full_iso
current_time = f"T{current_time}" # Add T prefix for ISO 8601 time format
else: # TimeStrftimeFormat
tz = ZoneInfo(input_data.format_type.timezone)
dt = datetime.now(tz=tz)
current_time = dt.strftime(input_data.format_type.format)
current_time = time.strftime(input_data.format)
yield "time", current_time
class DateStrftimeFormat(BaseModel):
discriminator: Literal["strftime"]
format: str = "%Y-%m-%d"
timezone: TimezoneLiteral = "UTC"
class DateISO8601Format(BaseModel):
discriminator: Literal["iso8601"]
timezone: TimezoneLiteral = "UTC"
class GetCurrentDateBlock(Block):
class Input(BlockSchema):
trigger: str = SchemaField(
@@ -158,10 +53,8 @@ class GetCurrentDateBlock(Block):
description="Offset in days from the current date",
default=0,
)
format_type: Union[DateStrftimeFormat, DateISO8601Format] = SchemaField(
discriminator="discriminator",
description="Format type for date output (strftime with custom format or ISO 8601)",
default=DateStrftimeFormat(discriminator="strftime"),
format: str = SchemaField(
description="Format of the date to output", default="%Y-%m-%d"
)
class Output(BlockSchema):
@@ -178,22 +71,7 @@ class GetCurrentDateBlock(Block):
output_schema=GetCurrentDateBlock.Output,
test_input=[
{"trigger": "Hello", "offset": "7"},
{
"trigger": "Hello",
"offset": "7",
"format_type": {
"discriminator": "strftime",
"format": "%m/%d/%Y",
},
},
{
"trigger": "Hello",
"offset": "0",
"format_type": {
"discriminator": "iso8601",
"timezone": "UTC",
},
},
{"trigger": "Hello", "offset": "7", "format": "%m/%d/%Y"},
],
test_output=[
(
@@ -207,12 +85,6 @@ class GetCurrentDateBlock(Block):
< timedelta(days=8),
# 7 days difference + 1 day error margin.
),
(
"date",
lambda t: len(t) == 10
and t[4] == "-"
and t[7] == "-", # ISO date format YYYY-MM-DD
),
],
)
@@ -221,31 +93,8 @@ class GetCurrentDateBlock(Block):
offset = int(input_data.offset)
except ValueError:
offset = 0
if isinstance(input_data.format_type, DateISO8601Format):
# ISO 8601 format for date only (YYYY-MM-DD)
tz = ZoneInfo(input_data.format_type.timezone)
current_date = datetime.now(tz=tz) - timedelta(days=offset)
# ISO 8601 date format is YYYY-MM-DD
date_str = current_date.date().isoformat()
else: # DateStrftimeFormat
tz = ZoneInfo(input_data.format_type.timezone)
current_date = datetime.now(tz=tz) - timedelta(days=offset)
date_str = current_date.strftime(input_data.format_type.format)
yield "date", date_str
class StrftimeFormat(BaseModel):
discriminator: Literal["strftime"]
format: str = "%Y-%m-%d %H:%M:%S"
timezone: TimezoneLiteral = "UTC"
class ISO8601Format(BaseModel):
discriminator: Literal["iso8601"]
timezone: TimezoneLiteral = "UTC"
include_microseconds: bool = False
current_date = datetime.now() - timedelta(days=offset)
yield "date", current_date.strftime(input_data.format)
class GetCurrentDateAndTimeBlock(Block):
@@ -253,10 +102,9 @@ class GetCurrentDateAndTimeBlock(Block):
trigger: str = SchemaField(
description="Trigger any data to output the current date and time"
)
format_type: Union[StrftimeFormat, ISO8601Format] = SchemaField(
discriminator="discriminator",
description="Format type for date and time output (strftime with custom format or ISO 8601/RFC 3339)",
default=StrftimeFormat(discriminator="strftime"),
format: str = SchemaField(
description="Format of the date and time to output",
default="%Y-%m-%d %H:%M:%S",
)
class Output(BlockSchema):
@@ -273,63 +121,20 @@ class GetCurrentDateAndTimeBlock(Block):
output_schema=GetCurrentDateAndTimeBlock.Output,
test_input=[
{"trigger": "Hello"},
{
"trigger": "Hello",
"format_type": {
"discriminator": "strftime",
"format": "%Y/%m/%d",
},
},
{
"trigger": "Hello",
"format_type": {
"discriminator": "iso8601",
"timezone": "UTC",
"include_microseconds": False,
},
},
],
test_output=[
(
"date_time",
lambda t: abs(
datetime.now(tz=ZoneInfo("UTC"))
- datetime.strptime(t + "+00:00", "%Y-%m-%d %H:%M:%S%z")
datetime.now() - datetime.strptime(t, "%Y-%m-%d %H:%M:%S")
)
< timedelta(seconds=10), # 10 seconds error margin.
),
(
"date_time",
lambda t: abs(
datetime.now().date() - datetime.strptime(t, "%Y/%m/%d").date()
)
< timedelta(days=1), # Date format only, no time component
),
(
"date_time",
lambda t: abs(
datetime.now(tz=ZoneInfo("UTC")) - datetime.fromisoformat(t)
)
< timedelta(seconds=10), # 10 seconds error margin for ISO format.
),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
if isinstance(input_data.format_type, ISO8601Format):
# ISO 8601 format with specified timezone (also RFC3339-compliant)
tz = ZoneInfo(input_data.format_type.timezone)
dt = datetime.now(tz=tz)
# Format with or without microseconds
if input_data.format_type.include_microseconds:
current_date_time = dt.isoformat()
else:
current_date_time = dt.isoformat(timespec="seconds")
else: # StrftimeFormat
tz = ZoneInfo(input_data.format_type.timezone)
dt = datetime.now(tz=tz)
current_date_time = dt.strftime(input_data.format_type.format)
current_date_time = time.strftime(input_data.format)
yield "date_time", current_date_time

View File

@@ -5,12 +5,6 @@ from backend.blocks.ai_shortform_video_block import AIShortformVideoCreatorBlock
from backend.blocks.apollo.organization import SearchOrganizationsBlock
from backend.blocks.apollo.people import SearchPeopleBlock
from backend.blocks.apollo.person import GetPersonDetailBlock
from backend.blocks.enrichlayer.linkedin import (
GetLinkedinProfileBlock,
GetLinkedinProfilePictureBlock,
LinkedinPersonLookupBlock,
LinkedinRoleLookupBlock,
)
from backend.blocks.flux_kontext import AIImageEditorBlock, FluxKontextModelName
from backend.blocks.ideogram import IdeogramModelBlock
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
@@ -36,7 +30,6 @@ from backend.integrations.credentials_store import (
anthropic_credentials,
apollo_credentials,
did_credentials,
enrichlayer_credentials,
groq_credentials,
ideogram_credentials,
jina_credentials,
@@ -46,7 +39,6 @@ from backend.integrations.credentials_store import (
replicate_credentials,
revid_credentials,
unreal_credentials,
v0_credentials,
)
# =============== Configure the cost for each LLM Model call =============== #
@@ -56,18 +48,12 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.O3_MINI: 2, # $1.10 / $4.40
LlmModel.O1: 16, # $15 / $60
LlmModel.O1_MINI: 4,
# GPT-5 models
LlmModel.GPT5: 2,
LlmModel.GPT5_MINI: 1,
LlmModel.GPT5_NANO: 1,
LlmModel.GPT5_CHAT: 2,
LlmModel.GPT41: 2,
LlmModel.GPT41_MINI: 1,
LlmModel.GPT4O_MINI: 1,
LlmModel.GPT4O: 3,
LlmModel.GPT4_TURBO: 10,
LlmModel.GPT3_5_TURBO: 1,
LlmModel.CLAUDE_4_1_OPUS: 21,
LlmModel.CLAUDE_4_OPUS: 21,
LlmModel.CLAUDE_4_SONNET: 5,
LlmModel.CLAUDE_3_7_SONNET: 5,
@@ -123,10 +109,6 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: 1,
LlmModel.GEMINI_2_0_FLASH_LITE: 1,
LlmModel.DEEPSEEK_R1_0528: 1,
# v0 by Vercel models
LlmModel.V0_1_5_MD: 1,
LlmModel.V0_1_5_LG: 2,
LlmModel.V0_1_0_MD: 1,
}
for model in LlmModel:
@@ -216,23 +198,6 @@ LLM_COST = (
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "llama_api"
]
# v0 by Vercel Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": v0_credentials.id,
"provider": v0_credentials.provider,
"type": v0_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "v0"
]
# AI/ML Api Models
+ [
BlockCost(
@@ -405,54 +370,6 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
},
)
],
GetLinkedinProfileBlock: [
BlockCost(
cost_amount=1,
cost_filter={
"credentials": {
"id": enrichlayer_credentials.id,
"provider": enrichlayer_credentials.provider,
"type": enrichlayer_credentials.type,
}
},
)
],
LinkedinPersonLookupBlock: [
BlockCost(
cost_amount=2,
cost_filter={
"credentials": {
"id": enrichlayer_credentials.id,
"provider": enrichlayer_credentials.provider,
"type": enrichlayer_credentials.type,
}
},
)
],
LinkedinRoleLookupBlock: [
BlockCost(
cost_amount=3,
cost_filter={
"credentials": {
"id": enrichlayer_credentials.id,
"provider": enrichlayer_credentials.provider,
"type": enrichlayer_credentials.type,
}
},
)
],
GetLinkedinProfilePictureBlock: [
BlockCost(
cost_amount=3,
cost_filter={
"credentials": {
"id": enrichlayer_credentials.id,
"provider": enrichlayer_credentials.provider,
"type": enrichlayer_credentials.type,
}
},
)
],
SmartDecisionMakerBlock: LLM_COST,
SearchOrganizationsBlock: [
BlockCost(

View File

@@ -34,10 +34,10 @@ from backend.data.model import (
from backend.data.notifications import NotificationEventModel, RefundRequestData
from backend.data.user import get_user_by_id, get_user_email_by_id
from backend.notifications.notifications import queue_notification_async
from backend.server.model import Pagination
from backend.server.v2.admin.model import UserHistoryResponse
from backend.util.exceptions import InsufficientBalanceError
from backend.util.json import SafeJson
from backend.util.models import Pagination
from backend.util.retry import func_retry
from backend.util.settings import Settings
@@ -286,17 +286,11 @@ class UserCreditBase(ABC):
transaction = await CreditTransaction.prisma().find_first_or_raise(
where={"transactionKey": transaction_key, "userId": user_id}
)
if transaction.isActive:
return
async with db.locked_transaction(f"usr_trx_{user_id}"):
transaction = await CreditTransaction.prisma().find_first_or_raise(
where={"transactionKey": transaction_key, "userId": user_id}
)
if transaction.isActive:
return
user_balance, _ = await self._get_credits(user_id)
await CreditTransaction.prisma().update(
where={
@@ -1004,8 +998,8 @@ def get_block_costs() -> dict[str, list[BlockCost]]:
async def get_stripe_customer_id(user_id: str) -> str:
user = await get_user_by_id(user_id)
if user.stripe_customer_id:
return user.stripe_customer_id
if user.stripeCustomerId:
return user.stripeCustomerId
customer = stripe.Customer.create(
name=user.name or "",
@@ -1028,10 +1022,10 @@ async def set_auto_top_up(user_id: str, config: AutoTopUpConfig):
async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
user = await get_user_by_id(user_id)
if not user.top_up_config:
if not user.topUpConfig:
return AutoTopUpConfig(threshold=0, amount=0)
return AutoTopUpConfig.model_validate(user.top_up_config)
return AutoTopUpConfig.model_validate(user.topUpConfig)
async def admin_get_user_history(

View File

@@ -33,7 +33,7 @@ from prisma.types import (
AgentNodeExecutionUpdateInput,
AgentNodeExecutionWhereInput,
)
from pydantic import BaseModel, ConfigDict, JsonValue, ValidationError
from pydantic import BaseModel, ConfigDict, JsonValue
from pydantic.fields import Field
from backend.server.v2.store.exceptions import DatabaseError
@@ -59,7 +59,7 @@ from .includes import (
GRAPH_EXECUTION_INCLUDE_WITH_NODES,
graph_execution_include,
)
from .model import GraphExecutionStats, NodeExecutionStats
from .model import GraphExecutionStats
T = TypeVar("T")
@@ -318,30 +318,18 @@ class NodeExecutionResult(BaseModel):
@staticmethod
def from_db(_node_exec: AgentNodeExecution, user_id: Optional[str] = None):
try:
stats = NodeExecutionStats.model_validate(_node_exec.stats or {})
except (ValueError, ValidationError):
stats = NodeExecutionStats()
if stats.cleared_inputs:
input_data: BlockInput = defaultdict()
for name, messages in stats.cleared_inputs.items():
input_data[name] = messages[-1] if messages else ""
elif _node_exec.executionData:
if _node_exec.executionData:
# Execution that has been queued for execution will persist its data.
input_data = type_utils.convert(_node_exec.executionData, dict[str, Any])
else:
# For incomplete execution, executionData will not be yet available.
input_data: BlockInput = defaultdict()
for data in _node_exec.Input or []:
input_data[data.name] = type_utils.convert(data.data, type[Any])
output_data: CompletedBlockOutput = defaultdict(list)
if stats.cleared_outputs:
for name, messages in stats.cleared_outputs.items():
output_data[name].extend(messages)
else:
for data in _node_exec.Output or []:
output_data[data.name].append(type_utils.convert(data.data, type[Any]))
for data in _node_exec.Output or []:
output_data[data.name].append(type_utils.convert(data.data, type[Any]))
graph_execution: AgentGraphExecution | None = _node_exec.GraphExecution
if graph_execution:

View File

@@ -1,109 +0,0 @@
import logging
from collections import defaultdict
from datetime import datetime
from prisma.enums import AgentExecutionStatus
from backend.data.execution import get_graph_executions
from backend.data.graph import get_graph_metadata
from backend.data.model import UserExecutionSummaryStats
from backend.server.v2.store.exceptions import DatabaseError
from backend.util.logging import TruncatedLogger
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[SummaryData]")
async def get_user_execution_summary_data(
user_id: str, start_time: datetime, end_time: datetime
) -> UserExecutionSummaryStats:
"""Gather all summary data for a user in a time range.
This function fetches graph executions once and aggregates all required
statistics in a single pass for efficiency.
"""
try:
# Fetch graph executions once
executions = await get_graph_executions(
user_id=user_id,
created_time_gte=start_time,
created_time_lte=end_time,
)
# Initialize aggregation variables
total_credits_used = 0.0
total_executions = len(executions)
successful_runs = 0
failed_runs = 0
terminated_runs = 0
execution_times = []
agent_usage = defaultdict(int)
cost_by_graph_id = defaultdict(float)
# Single pass through executions to aggregate all stats
for execution in executions:
# Count execution statuses (including TERMINATED as failed)
if execution.status == AgentExecutionStatus.COMPLETED:
successful_runs += 1
elif execution.status == AgentExecutionStatus.FAILED:
failed_runs += 1
elif execution.status == AgentExecutionStatus.TERMINATED:
terminated_runs += 1
# Aggregate costs from stats
if execution.stats and hasattr(execution.stats, "cost"):
cost_in_dollars = execution.stats.cost / 100
total_credits_used += cost_in_dollars
cost_by_graph_id[execution.graph_id] += cost_in_dollars
# Collect execution times
if execution.stats and hasattr(execution.stats, "duration"):
execution_times.append(execution.stats.duration)
# Count agent usage
agent_usage[execution.graph_id] += 1
# Calculate derived stats
total_execution_time = sum(execution_times)
average_execution_time = (
total_execution_time / len(execution_times) if execution_times else 0
)
# Find most used agent
most_used_agent = "No agents used"
if agent_usage:
most_used_agent_id = max(agent_usage, key=lambda k: agent_usage[k])
try:
graph_meta = await get_graph_metadata(graph_id=most_used_agent_id)
most_used_agent = (
graph_meta.name if graph_meta else f"Agent {most_used_agent_id[:8]}"
)
except Exception:
logger.warning(f"Could not get metadata for graph {most_used_agent_id}")
most_used_agent = f"Agent {most_used_agent_id[:8]}"
# Convert graph_ids to agent names for cost breakdown
cost_breakdown = {}
for graph_id, cost in cost_by_graph_id.items():
try:
graph_meta = await get_graph_metadata(graph_id=graph_id)
agent_name = graph_meta.name if graph_meta else f"Agent {graph_id[:8]}"
except Exception:
logger.warning(f"Could not get metadata for graph {graph_id}")
agent_name = f"Agent {graph_id[:8]}"
cost_breakdown[agent_name] = cost
# Build the summary stats object (include terminated runs as failed)
return UserExecutionSummaryStats(
total_credits_used=total_credits_used,
total_executions=total_executions,
successful_runs=successful_runs,
failed_runs=failed_runs + terminated_runs,
most_used_agent=most_used_agent,
total_execution_time=total_execution_time,
average_execution_time=average_execution_time,
cost_breakdown=cost_breakdown,
)
except Exception as e:
logger.error(f"Failed to get user summary data: {e}")
raise DatabaseError(f"Failed to get user summary data: {e}") from e

View File

@@ -5,7 +5,6 @@ import enum
import logging
from collections import defaultdict
from datetime import datetime, timezone
from json import JSONDecodeError
from typing import (
TYPE_CHECKING,
Annotated,
@@ -41,120 +40,12 @@ from pydantic_core import (
from typing_extensions import TypedDict
from backend.integrations.providers import ProviderName
from backend.util.json import loads as json_loads
from backend.util.settings import Secrets
# Type alias for any provider name (including custom ones)
AnyProviderName = str # Will be validated as ProviderName at runtime
class User(BaseModel):
"""Application-layer User model with snake_case convention."""
model_config = ConfigDict(
extra="forbid",
str_strip_whitespace=True,
)
id: str = Field(..., description="User ID")
email: str = Field(..., description="User email address")
email_verified: bool = Field(default=True, description="Whether email is verified")
name: Optional[str] = Field(None, description="User display name")
created_at: datetime = Field(..., description="When user was created")
updated_at: datetime = Field(..., description="When user was last updated")
metadata: dict[str, Any] = Field(
default_factory=dict, description="User metadata as dict"
)
integrations: str = Field(default="", description="Encrypted integrations data")
stripe_customer_id: Optional[str] = Field(None, description="Stripe customer ID")
top_up_config: Optional["AutoTopUpConfig"] = Field(
None, description="Top up configuration"
)
# Notification preferences
max_emails_per_day: int = Field(default=3, description="Maximum emails per day")
notify_on_agent_run: bool = Field(default=True, description="Notify on agent run")
notify_on_zero_balance: bool = Field(
default=True, description="Notify on zero balance"
)
notify_on_low_balance: bool = Field(
default=True, description="Notify on low balance"
)
notify_on_block_execution_failed: bool = Field(
default=True, description="Notify on block execution failure"
)
notify_on_continuous_agent_error: bool = Field(
default=True, description="Notify on continuous agent error"
)
notify_on_daily_summary: bool = Field(
default=True, description="Notify on daily summary"
)
notify_on_weekly_summary: bool = Field(
default=True, description="Notify on weekly summary"
)
notify_on_monthly_summary: bool = Field(
default=True, description="Notify on monthly summary"
)
@classmethod
def from_db(cls, prisma_user: "PrismaUser") -> "User":
"""Convert a database User object to application User model."""
# Handle metadata field - convert from JSON string or dict to dict
metadata = {}
if prisma_user.metadata:
if isinstance(prisma_user.metadata, str):
try:
metadata = json_loads(prisma_user.metadata)
except (JSONDecodeError, TypeError):
metadata = {}
elif isinstance(prisma_user.metadata, dict):
metadata = prisma_user.metadata
# Handle topUpConfig field
top_up_config = None
if prisma_user.topUpConfig:
if isinstance(prisma_user.topUpConfig, str):
try:
config_dict = json_loads(prisma_user.topUpConfig)
top_up_config = AutoTopUpConfig.model_validate(config_dict)
except (JSONDecodeError, TypeError, ValueError):
top_up_config = None
elif isinstance(prisma_user.topUpConfig, dict):
try:
top_up_config = AutoTopUpConfig.model_validate(
prisma_user.topUpConfig
)
except ValueError:
top_up_config = None
return cls(
id=prisma_user.id,
email=prisma_user.email,
email_verified=prisma_user.emailVerified or True,
name=prisma_user.name,
created_at=prisma_user.createdAt,
updated_at=prisma_user.updatedAt,
metadata=metadata,
integrations=prisma_user.integrations or "",
stripe_customer_id=prisma_user.stripeCustomerId,
top_up_config=top_up_config,
max_emails_per_day=prisma_user.maxEmailsPerDay or 3,
notify_on_agent_run=prisma_user.notifyOnAgentRun or True,
notify_on_zero_balance=prisma_user.notifyOnZeroBalance or True,
notify_on_low_balance=prisma_user.notifyOnLowBalance or True,
notify_on_block_execution_failed=prisma_user.notifyOnBlockExecutionFailed
or True,
notify_on_continuous_agent_error=prisma_user.notifyOnContinuousAgentError
or True,
notify_on_daily_summary=prisma_user.notifyOnDailySummary or True,
notify_on_weekly_summary=prisma_user.notifyOnWeeklySummary or True,
notify_on_monthly_summary=prisma_user.notifyOnMonthlySummary or True,
)
if TYPE_CHECKING:
from prisma.models import User as PrismaUser
from backend.data.block import BlockSchema
T = TypeVar("T")
@@ -764,9 +655,6 @@ class NodeExecutionStats(BaseModel):
output_token_count: int = 0
extra_cost: int = 0
extra_steps: int = 0
# Moderation fields
cleared_inputs: Optional[dict[str, list[str]]] = None
cleared_outputs: Optional[dict[str, list[str]]] = None
def __iadd__(self, other: "NodeExecutionStats") -> "NodeExecutionStats":
"""Mutate this instance by adding another NodeExecutionStats."""
@@ -821,21 +709,3 @@ class GraphExecutionStats(BaseModel):
activity_status: Optional[str] = Field(
default=None, description="AI-generated summary of what the agent did"
)
class UserExecutionSummaryStats(BaseModel):
"""Summary of user statistics for a specific user."""
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
)
total_credits_used: float = Field(default=0)
total_executions: int = Field(default=0)
successful_runs: int = Field(default=0)
failed_runs: int = Field(default=0)
most_used_agent: str = Field(default="")
total_execution_time: float = Field(default=0)
average_execution_time: float = Field(default=0)
cost_breakdown: dict[str, float] = Field(default_factory=dict)

View File

@@ -140,7 +140,6 @@ class SyncRabbitMQ(RabbitMQBase):
socket_timeout=SOCKET_TIMEOUT,
connection_attempts=CONNECTION_ATTEMPTS,
retry_delay=RETRY_DELAY,
heartbeat=300, # 5 minute timeout (heartbeats sent every 2.5 min)
)
self._connection = pika.BlockingConnection(parameters)
@@ -245,7 +244,6 @@ class AsyncRabbitMQ(RabbitMQBase):
password=self.password,
virtualhost=self.config.vhost.lstrip("/"),
blocked_connection_timeout=BLOCKED_CONNECTION_TIMEOUT,
heartbeat=300, # 5 minute timeout (heartbeats sent every 2.5 min)
)
self._channel = await self._connection.channel()
await self._channel.set_qos(prefetch_count=1)

View File

@@ -9,11 +9,11 @@ from urllib.parse import quote_plus
from autogpt_libs.auth.models import DEFAULT_USER_ID
from fastapi import HTTPException
from prisma.enums import NotificationType
from prisma.models import User as PrismaUser
from prisma.models import User
from prisma.types import JsonFilter, UserCreateInput, UserUpdateInput
from backend.data.db import prisma
from backend.data.model import User, UserIntegrations, UserMetadata
from backend.data.model import UserIntegrations, UserMetadata
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.server.v2.store.exceptions import DatabaseError
from backend.util.encryption import JSONCryptor
@@ -21,7 +21,6 @@ from backend.util.json import SafeJson
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
async def get_or_create_user(user_data: dict) -> User:
@@ -44,7 +43,7 @@ async def get_or_create_user(user_data: dict) -> User:
)
)
return User.from_db(user)
return User.model_validate(user)
except Exception as e:
raise DatabaseError(f"Failed to get or create user {user_data}: {e}") from e
@@ -53,7 +52,7 @@ async def get_user_by_id(user_id: str) -> User:
user = await prisma.user.find_unique(where={"id": user_id})
if not user:
raise ValueError(f"User not found with ID: {user_id}")
return User.from_db(user)
return User.model_validate(user)
async def get_user_email_by_id(user_id: str) -> Optional[str]:
@@ -67,7 +66,7 @@ async def get_user_email_by_id(user_id: str) -> Optional[str]:
async def get_user_by_email(email: str) -> Optional[User]:
try:
user = await prisma.user.find_unique(where={"email": email})
return User.from_db(user) if user else None
return User.model_validate(user) if user else None
except Exception as e:
raise DatabaseError(f"Failed to get user by email {email}: {e}") from e
@@ -91,11 +90,11 @@ async def create_default_user() -> Optional[User]:
name="Default User",
)
)
return User.from_db(user)
return User.model_validate(user)
async def get_user_integrations(user_id: str) -> UserIntegrations:
user = await PrismaUser.prisma().find_unique_or_raise(
user = await User.prisma().find_unique_or_raise(
where={"id": user_id},
)
@@ -110,7 +109,7 @@ async def get_user_integrations(user_id: str) -> UserIntegrations:
async def update_user_integrations(user_id: str, data: UserIntegrations):
encrypted_data = JSONCryptor().encrypt(data.model_dump(exclude_none=True))
await PrismaUser.prisma().update(
await User.prisma().update(
where={"id": user_id},
data={"integrations": encrypted_data},
)
@@ -118,7 +117,7 @@ async def update_user_integrations(user_id: str, data: UserIntegrations):
async def migrate_and_encrypt_user_integrations():
"""Migrate integration credentials and OAuth states from metadata to integrations column."""
users = await PrismaUser.prisma().find_many(
users = await User.prisma().find_many(
where={
"metadata": cast(
JsonFilter,
@@ -154,7 +153,7 @@ async def migrate_and_encrypt_user_integrations():
raw_metadata.pop("integration_oauth_states", None)
# Update metadata without integration data
await PrismaUser.prisma().update(
await User.prisma().update(
where={"id": user.id},
data={"metadata": SafeJson(raw_metadata)},
)
@@ -162,7 +161,7 @@ async def migrate_and_encrypt_user_integrations():
async def get_active_user_ids_in_timerange(start_time: str, end_time: str) -> list[str]:
try:
users = await PrismaUser.prisma().find_many(
users = await User.prisma().find_many(
where={
"AgentGraphExecutions": {
"some": {
@@ -192,7 +191,7 @@ async def get_active_users_ids() -> list[str]:
async def get_user_notification_preference(user_id: str) -> NotificationPreference:
try:
user = await PrismaUser.prisma().find_unique_or_raise(
user = await User.prisma().find_unique_or_raise(
where={"id": user_id},
)
@@ -269,7 +268,7 @@ async def update_user_notification_preference(
if data.daily_limit:
update_data["maxEmailsPerDay"] = data.daily_limit
user = await PrismaUser.prisma().update(
user = await User.prisma().update(
where={"id": user_id},
data=update_data,
)
@@ -307,7 +306,7 @@ async def update_user_notification_preference(
async def set_user_email_verification(user_id: str, verified: bool) -> None:
"""Set the email verification status for a user."""
try:
await PrismaUser.prisma().update(
await User.prisma().update(
where={"id": user_id},
data={"emailVerified": verified},
)
@@ -320,7 +319,7 @@ async def set_user_email_verification(user_id: str, verified: bool) -> None:
async def get_user_email_verification(user_id: str) -> bool:
"""Get the email verification status for a user."""
try:
user = await PrismaUser.prisma().find_unique_or_raise(
user = await User.prisma().find_unique_or_raise(
where={"id": user_id},
)
return user.emailVerified
@@ -333,7 +332,7 @@ async def get_user_email_verification(user_id: str) -> bool:
def generate_unsubscribe_link(user_id: str) -> str:
"""Generate a link to unsubscribe from all notifications"""
# Create an HMAC using a secret key
secret_key = settings.secrets.unsubscribe_secret_key
secret_key = Settings().secrets.unsubscribe_secret_key
signature = hmac.new(
secret_key.encode("utf-8"), user_id.encode("utf-8"), hashlib.sha256
).digest()
@@ -344,7 +343,7 @@ def generate_unsubscribe_link(user_id: str) -> str:
).decode("utf-8")
logger.info(f"Generating unsubscribe link for user {user_id}")
base_url = settings.config.platform_base_url
base_url = Settings().config.platform_base_url
return f"{base_url}/api/email/unsubscribe?token={quote_plus(token)}"
@@ -356,7 +355,7 @@ async def unsubscribe_user_by_token(token: str) -> None:
user_id, received_signature_hex = decoded.split(":", 1)
# Verify the signature
secret_key = settings.secrets.unsubscribe_secret_key
secret_key = Settings().secrets.unsubscribe_secret_key
expected_signature = hmac.new(
secret_key.encode("utf-8"), user_id.encode("utf-8"), hashlib.sha256
).digest()

View File

@@ -6,17 +6,20 @@ import json
import logging
from typing import TYPE_CHECKING, Any, NotRequired, TypedDict
from autogpt_libs.feature_flag.client import is_feature_enabled
from pydantic import SecretStr
from backend.blocks.llm import LlmModel, llm_call
from backend.data.block import get_block
from backend.data.execution import ExecutionStatus, NodeExecutionResult
from backend.data.model import APIKeyCredentials, GraphExecutionStats
from backend.util.feature_flag import Flag, is_feature_enabled
from backend.util.retry import func_retry
from backend.util.settings import Settings
from backend.util.truncate import truncate
# LaunchDarkly feature flag key for AI activity status generation
AI_ACTIVITY_STATUS_FLAG_KEY = "ai-agent-execution-summary"
if TYPE_CHECKING:
from backend.executor import DatabaseManagerAsyncClient
@@ -99,8 +102,8 @@ async def generate_activity_status_for_execution(
Returns:
AI-generated activity status string, or None if feature is disabled
"""
# Check LaunchDarkly feature flag for AI activity status generation with full context support
if not await is_feature_enabled(Flag.AI_ACTIVITY_STATUS, user_id):
# Check LaunchDarkly feature flag for AI activity status generation
if not is_feature_enabled(AI_ACTIVITY_STATUS_FLAG_KEY, user_id, default=False):
logger.debug("AI activity status generation is disabled via LaunchDarkly")
return None

View File

@@ -20,7 +20,6 @@ from backend.data.execution import (
upsert_execution_input,
upsert_execution_output,
)
from backend.data.generate_data import get_user_execution_summary_data
from backend.data.graph import (
get_connected_output_nodes,
get_graph,
@@ -42,13 +41,7 @@ from backend.data.user import (
get_user_notification_preference,
update_user_integrations,
)
from backend.util.service import (
AppService,
AppServiceClient,
UnhealthyServiceError,
endpoint_to_sync,
expose,
)
from backend.util.service import AppService, AppServiceClient, endpoint_to_sync, expose
from backend.util.settings import Config
config = Config()
@@ -80,10 +73,10 @@ class DatabaseManager(AppService):
logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
self.run_and_wait(db.disconnect())
async def health_check(self) -> str:
def health_check(self) -> str:
if not db.is_connected():
raise UnhealthyServiceError("Database is not connected")
return await super().health_check()
raise RuntimeError("Database is not connected")
return super().health_check()
@classmethod
def get_port(cls) -> int:
@@ -145,9 +138,6 @@ class DatabaseManager(AppService):
get_user_notification_oldest_message_in_batch
)
# Summary data - async
get_user_execution_summary_data = _(get_user_execution_summary_data)
class DatabaseManagerClient(AppServiceClient):
d = DatabaseManager
@@ -173,8 +163,22 @@ class DatabaseManagerClient(AppServiceClient):
spend_credits = _(d.spend_credits)
get_credits = _(d.get_credits)
# Summary data - async
get_user_execution_summary_data = _(d.get_user_execution_summary_data)
# User Comms - async
get_active_user_ids_in_timerange = _(d.get_active_user_ids_in_timerange)
get_user_email_by_id = _(d.get_user_email_by_id)
get_user_email_verification = _(d.get_user_email_verification)
get_user_notification_preference = _(d.get_user_notification_preference)
# Notifications - async
create_or_add_to_user_notification_batch = _(
d.create_or_add_to_user_notification_batch
)
empty_user_notification_batch = _(d.empty_user_notification_batch)
get_all_batches_by_type = _(d.get_all_batches_by_type)
get_user_notification_batch = _(d.get_user_notification_batch)
get_user_notification_oldest_message_in_batch = _(
d.get_user_notification_oldest_message_in_batch
)
# Block error monitoring
get_block_error_stats = _(d.get_block_error_stats)
@@ -205,23 +209,3 @@ class DatabaseManagerAsyncClient(AppServiceClient):
update_user_integrations = d.update_user_integrations
get_execution_kv_data = d.get_execution_kv_data
set_execution_kv_data = d.set_execution_kv_data
# User Comms
get_active_user_ids_in_timerange = d.get_active_user_ids_in_timerange
get_user_email_by_id = d.get_user_email_by_id
get_user_email_verification = d.get_user_email_verification
get_user_notification_preference = d.get_user_notification_preference
# Notifications
create_or_add_to_user_notification_batch = (
d.create_or_add_to_user_notification_batch
)
empty_user_notification_batch = d.empty_user_notification_batch
get_all_batches_by_type = d.get_all_batches_by_type
get_user_notification_batch = d.get_user_notification_batch
get_user_notification_oldest_message_in_batch = (
d.get_user_notification_oldest_message_in_batch
)
# Summary data
get_user_execution_summary_data = d.get_user_execution_summary_data

View File

@@ -27,7 +27,7 @@ from backend.executor.activity_status_generator import (
)
from backend.executor.utils import LogMetadata
from backend.notifications.notifications import queue_notification
from backend.util.exceptions import InsufficientBalanceError, ModerationError
from backend.util.exceptions import InsufficientBalanceError
if TYPE_CHECKING:
from backend.executor import DatabaseManagerClient, DatabaseManagerAsyncClient
@@ -67,7 +67,6 @@ from backend.executor.utils import (
validate_exec,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.server.v2.AutoMod.manager import automod_manager
from backend.util import json
from backend.util.clients import (
get_async_execution_event_bus,
@@ -733,6 +732,7 @@ class ExecutionProcessor:
log_metadata: LogMetadata,
execution_stats: GraphExecutionStats,
) -> ExecutionStatus:
"""
Returns:
dict: The execution statistics of the graph execution.
@@ -760,22 +760,6 @@ class ExecutionProcessor:
amount=1,
)
# Input moderation
try:
if moderation_error := asyncio.run_coroutine_threadsafe(
automod_manager.moderate_graph_execution_inputs(
db_client=get_db_async_client(),
graph_exec=graph_exec,
),
self.node_evaluation_loop,
).result(timeout=30.0):
raise moderation_error
except asyncio.TimeoutError:
log_metadata.warning(
f"Input moderation timed out for graph execution {graph_exec.graph_exec_id}, bypassing moderation and continuing execution"
)
# Continue execution without moderation
# ------------------------------------------------------------
# Prepopulate queue ---------------------------------------
# ------------------------------------------------------------
@@ -914,25 +898,6 @@ class ExecutionProcessor:
time.sleep(0.1)
# loop done --------------------------------------------------
# Output moderation
try:
if moderation_error := asyncio.run_coroutine_threadsafe(
automod_manager.moderate_graph_execution_outputs(
db_client=get_db_async_client(),
graph_exec_id=graph_exec.graph_exec_id,
user_id=graph_exec.user_id,
graph_id=graph_exec.graph_id,
),
self.node_evaluation_loop,
).result(timeout=30.0):
raise moderation_error
except asyncio.TimeoutError:
log_metadata.warning(
f"Output moderation timed out for graph execution {graph_exec.graph_exec_id}, bypassing moderation and continuing execution"
)
# Continue execution without moderation
# Determine final execution status based on whether there was an error or termination
if cancel.is_set():
execution_status = ExecutionStatus.TERMINATED
@@ -953,12 +918,11 @@ class ExecutionProcessor:
else Exception(f"{e.__class__.__name__}: {e}")
)
known_errors = (InsufficientBalanceError, ModerationError)
known_errors = (InsufficientBalanceError,)
if isinstance(error, known_errors):
execution_stats.error = str(error)
return ExecutionStatus.FAILED
execution_status = ExecutionStatus.FAILED
log_metadata.exception(
f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
)
@@ -1208,9 +1172,6 @@ class ExecutionManager(AppProcess):
)
return
# Check if channel is closed and force reconnection if needed
if not self.cancel_client.is_ready:
self.cancel_client.disconnect()
self.cancel_client.connect()
cancel_channel = self.cancel_client.get_channel()
cancel_channel.basic_consume(
@@ -1240,9 +1201,6 @@ class ExecutionManager(AppProcess):
)
return
# Check if channel is closed and force reconnection if needed
if not self.run_client.is_ready:
self.run_client.disconnect()
self.run_client.connect()
run_channel = self.run_client.get_channel()
run_channel.basic_qos(prefetch_count=self.pool_size)
@@ -1300,7 +1258,7 @@ class ExecutionManager(AppProcess):
def _handle_run_message(
self,
_channel: BlockingChannel,
channel: BlockingChannel,
method: Basic.Deliver,
_properties: BasicProperties,
body: bytes,
@@ -1310,9 +1268,6 @@ class ExecutionManager(AppProcess):
@func_retry
def _ack_message(reject: bool = False):
"""Acknowledge or reject the message based on execution status."""
# Connection can be lost, so always get a fresh channel
channel = self.run_client.get_channel()
if reject:
channel.connection.add_callback_threadsafe(
lambda: channel.basic_nack(delivery_tag, requeue=True)
@@ -1404,25 +1359,6 @@ class ExecutionManager(AppProcess):
else:
utilization_gauge.set(active_count / self.pool_size)
def _stop_message_consumers(
self, thread: threading.Thread, client: SyncRabbitMQ, prefix: str
):
try:
channel = client.get_channel()
channel.connection.add_callback_threadsafe(lambda: channel.stop_consuming())
try:
thread.join(timeout=300)
except TimeoutError:
logger.error(
f"{prefix} ⚠️ Run thread did not finish in time, forcing disconnect"
)
client.disconnect()
logger.info(f"{prefix} ✅ Run client disconnected")
except Exception as e:
logger.error(f"{prefix} ⚠️ Error disconnecting run client: {type(e)} {e}")
def cleanup(self):
"""Override cleanup to implement graceful shutdown with active execution waiting."""
prefix = f"[{self.service_name}][on_graph_executor_stop {os.getpid()}]"
@@ -1478,16 +1414,26 @@ class ExecutionManager(AppProcess):
logger.error(f"{prefix} ⚠️ Error during executor shutdown: {type(e)} {e}")
# Disconnect the run execution consumer
self._stop_message_consumers(
self.run_thread,
self.run_client,
prefix + " [run-consumer]",
)
self._stop_message_consumers(
self.cancel_thread,
self.cancel_client,
prefix + " [cancel-consumer]",
)
try:
run_channel = self.run_client.get_channel()
run_channel.connection.add_callback_threadsafe(
lambda: self.run_client.disconnect()
)
self.run_thread.join()
logger.info(f"{prefix} ✅ Run client disconnected")
except Exception as e:
logger.error(f"{prefix} ⚠️ Error disconnecting run client: {type(e)} {e}")
# Disconnect the cancel execution consumer
try:
cancel_channel = self.cancel_client.get_channel()
cancel_channel.connection.add_callback_threadsafe(
lambda: self.cancel_client.disconnect()
)
self.cancel_thread.join()
logger.info(f"{prefix} ✅ Cancel client disconnected")
except Exception as e:
logger.error(f"{prefix} ⚠️ Error disconnecting cancel client: {type(e)} {e}")
logger.info(f"{prefix} ✅ Finished GraphExec cleanup")

View File

@@ -3,6 +3,7 @@ import logging
import autogpt_libs.auth.models
import fastapi.responses
import pytest
from prisma.models import User
import backend.server.v2.library.model
import backend.server.v2.store.model
@@ -11,7 +12,6 @@ from backend.blocks.data_manipulation import FindInDictionaryBlock
from backend.blocks.io import AgentInputBlock
from backend.blocks.maths import CalculatorBlock, Operation
from backend.data import execution, graph
from backend.data.model import User
from backend.server.model import CreateGraph
from backend.server.rest_api import AgentServer
from backend.usecases.sample import create_test_graph, create_test_user

View File

@@ -1,22 +1,17 @@
import asyncio
import logging
import os
import threading
from enum import Enum
from typing import Optional
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from apscheduler.events import (
EVENT_JOB_ERROR,
EVENT_JOB_EXECUTED,
EVENT_JOB_MAX_INSTANCES,
EVENT_JOB_MISSED,
)
from apscheduler.events import EVENT_JOB_ERROR, EVENT_JOB_EXECUTED
from apscheduler.job import Job as JobObj
from apscheduler.jobstores.memory import MemoryJobStore
from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.schedulers.blocking import BlockingScheduler
from apscheduler.triggers.cron import CronTrigger
from autogpt_libs.utils.cache import thread_cached
from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import MetaData, create_engine
@@ -35,14 +30,7 @@ from backend.monitoring import (
from backend.util.cloud_storage import cleanup_expired_files_async
from backend.util.exceptions import NotAuthorizedError, NotFoundError
from backend.util.logging import PrefixFilter
from backend.util.retry import func_retry
from backend.util.service import (
AppService,
AppServiceClient,
UnhealthyServiceError,
endpoint_to_async,
expose,
)
from backend.util.service import AppService, AppServiceClient, endpoint_to_async, expose
from backend.util.settings import Config
@@ -72,69 +60,26 @@ apscheduler_logger.addFilter(PrefixFilter("[Scheduler] [APScheduler]"))
config = Config()
# Timeout constants
SCHEDULER_OPERATION_TIMEOUT_SECONDS = 300 # 5 minutes for scheduler operations
def job_listener(event):
"""Logs job execution outcomes for better monitoring."""
if event.exception:
logger.error(
f"Job {event.job_id} failed: {type(event.exception).__name__}: {event.exception}"
)
logger.error(f"Job {event.job_id} failed.")
else:
logger.info(f"Job {event.job_id} completed successfully.")
def job_missed_listener(event):
"""Logs when jobs are missed due to scheduling issues."""
logger.warning(
f"Job {event.job_id} was missed at scheduled time {event.scheduled_run_time}. "
f"This can happen if the scheduler is overloaded or if previous executions are still running."
)
def job_max_instances_listener(event):
"""Logs when jobs hit max instances limit."""
logger.warning(
f"Job {event.job_id} execution was SKIPPED - max instances limit reached. "
f"Previous execution(s) are still running. "
f"Consider increasing max_instances or check why previous executions are taking too long."
)
_event_loop: asyncio.AbstractEventLoop | None = None
_event_loop_thread: threading.Thread | None = None
@func_retry
@thread_cached
def get_event_loop():
"""Get the shared event loop."""
if _event_loop is None:
raise RuntimeError("Event loop not initialized. Scheduler not started.")
return _event_loop
def run_async(coro, timeout: float = SCHEDULER_OPERATION_TIMEOUT_SECONDS):
"""Run a coroutine in the shared event loop and wait for completion."""
loop = get_event_loop()
future = asyncio.run_coroutine_threadsafe(coro, loop)
try:
return future.result(timeout=timeout)
except Exception as e:
logger.error(f"Async operation failed: {type(e).__name__}: {e}")
raise
return asyncio.new_event_loop()
def execute_graph(**kwargs):
"""Execute graph in the shared event loop and wait for completion."""
# Wait for completion to ensure job doesn't exit prematurely
run_async(_execute_graph(**kwargs))
get_event_loop().run_until_complete(_execute_graph(**kwargs))
async def _execute_graph(**kwargs):
args = GraphExecutionJobArgs(**kwargs)
start_time = asyncio.get_event_loop().time()
try:
logger.info(f"Executing recurring job for graph #{args.graph_id}")
graph_exec: GraphExecutionWithNodes = await execution_utils.add_graph_execution(
@@ -144,28 +89,17 @@ async def _execute_graph(**kwargs):
inputs=args.input_data,
graph_credentials_inputs=args.input_credentials,
)
elapsed = asyncio.get_event_loop().time() - start_time
logger.info(
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
f"(took {elapsed:.2f}s to create and publish)"
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id}"
)
if elapsed > 10:
logger.warning(
f"Graph execution {graph_exec.id} took {elapsed:.2f}s to create/publish - "
f"this is unusually slow and may indicate resource contention"
)
except Exception as e:
elapsed = asyncio.get_event_loop().time() - start_time
logger.error(
f"Error executing graph {args.graph_id} after {elapsed:.2f}s: "
f"{type(e).__name__}: {e}"
)
# TODO: We need to communicate this error to the user somehow.
logger.error(f"Error executing graph {args.graph_id}: {e}")
def cleanup_expired_files():
"""Clean up expired files from cloud storage."""
# Wait for completion
run_async(cleanup_expired_files_async())
get_event_loop().run_until_complete(cleanup_expired_files_async())
# Monitoring functions are now imported from monitoring module
@@ -221,7 +155,7 @@ class NotificationJobInfo(NotificationJobArgs):
class Scheduler(AppService):
scheduler: BackgroundScheduler
scheduler: BlockingScheduler
def __init__(self, register_system_tasks: bool = True):
self.register_system_tasks = register_system_tasks
@@ -234,50 +168,15 @@ class Scheduler(AppService):
def db_pool_size(cls) -> int:
return config.scheduler_db_pool_size
async def health_check(self) -> str:
# Thread-safe health check with proper initialization handling
if not hasattr(self, "scheduler"):
raise UnhealthyServiceError("Scheduler is still initializing")
# Check if we're in the middle of cleanup
if self.cleaned_up:
return await super().health_check()
# Normal operation - check if scheduler is running
def health_check(self) -> str:
if not self.scheduler.running:
raise UnhealthyServiceError("Scheduler is not running")
return await super().health_check()
raise RuntimeError("Scheduler is not running")
return super().health_check()
def run_service(self):
load_dotenv()
# Initialize the event loop for async jobs
global _event_loop
_event_loop = asyncio.new_event_loop()
# Use daemon thread since it should die with the main service
global _event_loop_thread
_event_loop_thread = threading.Thread(
target=_event_loop.run_forever, daemon=True, name="SchedulerEventLoop"
)
_event_loop_thread.start()
db_schema, db_url = _extract_schema_from_url(os.getenv("DIRECT_URL"))
# Configure executors to limit concurrency without skipping jobs
from apscheduler.executors.pool import ThreadPoolExecutor
self.scheduler = BackgroundScheduler(
executors={
"default": ThreadPoolExecutor(
max_workers=self.db_pool_size()
), # Match DB pool size to prevent resource contention
},
job_defaults={
"coalesce": True, # Skip redundant missed jobs - just run the latest
"max_instances": 1000, # Effectively unlimited - never drop executions
"misfire_grace_time": None, # No time limit for missed jobs
},
self.scheduler = BlockingScheduler(
jobstores={
Jobstores.EXECUTION.value: SQLAlchemyJobStore(
engine=create_engine(
@@ -307,10 +206,9 @@ class Scheduler(AppService):
if self.register_system_tasks:
# Notification PROCESS WEEKLY SUMMARY
# Runs every Monday at 9 AM UTC
self.scheduler.add_job(
process_weekly_summary,
CronTrigger.from_crontab("0 9 * * 1"),
CronTrigger.from_crontab("0 * * * *"),
id="process_weekly_summary",
kwargs={},
replace_existing=True,
@@ -358,30 +256,13 @@ class Scheduler(AppService):
)
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
self.scheduler.add_listener(job_missed_listener, EVENT_JOB_MISSED)
self.scheduler.add_listener(job_max_instances_listener, EVENT_JOB_MAX_INSTANCES)
self.scheduler.start()
# Keep the service running since BackgroundScheduler doesn't block
super().run_service()
def cleanup(self):
super().cleanup()
if self.scheduler:
logger.info("⏳ Shutting down scheduler...")
self.scheduler.shutdown(wait=True)
global _event_loop
if _event_loop:
logger.info("⏳ Closing event loop...")
_event_loop.call_soon_threadsafe(_event_loop.stop)
global _event_loop_thread
if _event_loop_thread:
logger.info("⏳ Waiting for event loop thread to finish...")
_event_loop_thread.join(timeout=SCHEDULER_OPERATION_TIMEOUT_SECONDS)
logger.info("Scheduler cleanup complete.")
self.scheduler.shutdown(wait=False)
@expose
def add_graph_execution_schedule(
@@ -394,18 +275,6 @@ class Scheduler(AppService):
input_credentials: dict[str, CredentialsMetaInput],
name: Optional[str] = None,
) -> GraphExecutionJobInfo:
# Validate the graph before scheduling to prevent runtime failures
# We don't need the return value, just want the validation to run
run_async(
execution_utils.validate_and_construct_node_execution_input(
graph_id=graph_id,
user_id=user_id,
graph_inputs=input_data,
graph_version=graph_version,
graph_credentials_inputs=input_credentials,
)
)
job_args = GraphExecutionJobArgs(
user_id=user_id,
graph_id=graph_id,

View File

@@ -548,7 +548,7 @@ async def validate_graph_with_credentials(
return node_input_errors
async def _construct_starting_node_execution_input(
async def construct_node_execution_input(
graph: GraphModel,
user_id: str,
graph_inputs: BlockInput,
@@ -615,67 +615,6 @@ async def _construct_starting_node_execution_input(
return nodes_input
async def validate_and_construct_node_execution_input(
graph_id: str,
user_id: str,
graph_inputs: BlockInput,
graph_version: Optional[int] = None,
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
) -> tuple[GraphModel, list[tuple[str, BlockInput]], dict[str, dict[str, JsonValue]]]:
"""
Public wrapper that handles graph fetching, credential mapping, and validation+construction.
This centralizes the logic used by both scheduler validation and actual execution.
Args:
graph_id: The ID of the graph to validate/construct.
user_id: The ID of the user.
graph_inputs: The input data for the graph execution.
graph_version: The version of the graph to use.
graph_credentials_inputs: Credentials inputs to use.
nodes_input_masks: Node inputs to use.
Returns:
tuple[GraphModel, list[tuple[str, BlockInput]]]: Graph model and list of tuples for node execution input.
Raises:
NotFoundError: If the graph is not found.
GraphValidationError: If the graph has validation issues.
ValueError: If there are other validation errors.
"""
if prisma.is_connected():
gdb = graph_db
else:
gdb = get_database_manager_async_client()
graph: GraphModel | None = await gdb.get_graph(
graph_id=graph_id,
user_id=user_id,
version=graph_version,
include_subgraphs=True,
)
if not graph:
raise NotFoundError(f"Graph #{graph_id} not found.")
nodes_input_masks = _merge_nodes_input_masks(
(
make_node_credentials_input_map(graph, graph_credentials_inputs)
if graph_credentials_inputs
else {}
),
nodes_input_masks or {},
)
starting_nodes_input = await _construct_starting_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=graph_inputs,
nodes_input_masks=nodes_input_masks,
)
return graph, starting_nodes_input, nodes_input_masks
def _merge_nodes_input_masks(
overrides_map_1: dict[str, dict[str, JsonValue]],
overrides_map_2: dict[str, dict[str, JsonValue]],
@@ -852,19 +791,34 @@ async def add_graph_execution(
ValueError: If the graph is not found or if there are validation errors.
"""
if prisma.is_connected():
gdb = graph_db
edb = execution_db
else:
gdb = get_database_manager_async_client()
edb = get_database_manager_async_client()
graph, starting_nodes_input, nodes_input_masks = (
await validate_and_construct_node_execution_input(
graph_id=graph_id,
user_id=user_id,
graph_inputs=inputs or {},
graph_version=graph_version,
graph_credentials_inputs=graph_credentials_inputs,
nodes_input_masks=nodes_input_masks,
)
graph: GraphModel | None = await gdb.get_graph(
graph_id=graph_id,
user_id=user_id,
version=graph_version,
include_subgraphs=True,
)
if not graph:
raise NotFoundError(f"Graph #{graph_id} not found.")
nodes_input_masks = _merge_nodes_input_masks(
(
make_node_credentials_input_map(graph, graph_credentials_inputs)
if graph_credentials_inputs
else {}
),
nodes_input_masks or {},
)
starting_nodes_input = await construct_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=inputs or {},
nodes_input_masks=nodes_input_masks,
)
graph_exec = None
@@ -881,19 +835,11 @@ async def add_graph_execution(
graph_exec_entry = graph_exec.to_graph_execution_entry()
if nodes_input_masks:
graph_exec_entry.nodes_input_masks = nodes_input_masks
logger.info(
f"Created graph execution #{graph_exec.id} for graph "
f"#{graph_id} with {len(starting_nodes_input)} starting nodes. "
f"Now publishing to execution queue."
)
await queue.publish_message(
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
message=graph_exec_entry.model_dump_json(),
exchange=GRAPH_EXECUTION_EXCHANGE,
)
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
bus = get_async_execution_event_bus()
await bus.publish(graph_exec)

View File

@@ -182,15 +182,6 @@ zerobounce_credentials = APIKeyCredentials(
expires_at=None,
)
enrichlayer_credentials = APIKeyCredentials(
id="d9fce73a-6c1d-4e8b-ba2e-12a456789def",
provider="enrichlayer",
api_key=SecretStr(settings.secrets.enrichlayer_api_key),
title="Use Credits for Enrichlayer",
expires_at=None,
)
llama_api_credentials = APIKeyCredentials(
id="d44045af-1c33-4833-9e19-752313214de2",
provider="llama_api",
@@ -199,14 +190,6 @@ llama_api_credentials = APIKeyCredentials(
expires_at=None,
)
v0_credentials = APIKeyCredentials(
id="c4e6d1a0-3b5f-4789-a8e2-9b123456789f",
provider="v0",
api_key=SecretStr(settings.secrets.v0_api_key),
title="Use Credits for v0 by Vercel",
expires_at=None,
)
DEFAULT_CREDENTIALS = [
ollama_credentials,
revid_credentials,
@@ -220,7 +203,6 @@ DEFAULT_CREDENTIALS = [
jina_credentials,
unreal_credentials,
open_router_credentials,
enrichlayer_credentials,
fal_credentials,
exa_credentials,
e2b_credentials,
@@ -231,8 +213,6 @@ DEFAULT_CREDENTIALS = [
smartlead_credentials,
zerobounce_credentials,
google_maps_credentials,
llama_api_credentials,
v0_credentials,
]
@@ -299,8 +279,6 @@ class IntegrationCredentialsStore:
all_credentials.append(unreal_credentials)
if settings.secrets.open_router_api_key:
all_credentials.append(open_router_credentials)
if settings.secrets.enrichlayer_api_key:
all_credentials.append(enrichlayer_credentials)
if settings.secrets.fal_api_key:
all_credentials.append(fal_credentials)
if settings.secrets.exa_api_key:

View File

@@ -25,7 +25,6 @@ class ProviderName(str, Enum):
GROQ = "groq"
HTTP = "http"
HUBSPOT = "hubspot"
ENRICHLAYER = "enrichlayer"
IDEOGRAM = "ideogram"
JINA = "jina"
LLAMA_API = "llama_api"
@@ -48,7 +47,6 @@ class ProviderName(str, Enum):
TWITTER = "twitter"
TODOIST = "todoist"
UNREAL_SPEECH = "unreal_speech"
V0 = "v0"
ZEROBOUNCE = "zerobounce"
@classmethod

View File

@@ -1,15 +0,0 @@
from backend.app import run_processes
from backend.notifications.notifications import NotificationManager
def main():
"""
Run the AutoGPT-server Notification Service.
"""
run_processes(
NotificationManager(),
)
if __name__ == "__main__":
main()

View File

@@ -1,7 +1,8 @@
import asyncio
import logging
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime, timedelta, timezone
from typing import Awaitable, Callable
from typing import Callable
import aio_pika
from prisma.enums import NotificationType
@@ -27,17 +28,11 @@ from backend.data.notifications import (
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.data.user import generate_unsubscribe_link
from backend.notifications.email import EmailSender
from backend.util.clients import get_database_manager_async_client
from backend.util.clients import get_database_manager_client
from backend.util.logging import TruncatedLogger
from backend.util.metrics import discord_send_alert
from backend.util.retry import continuous_retry
from backend.util.service import (
AppService,
AppServiceClient,
UnhealthyServiceError,
endpoint_to_sync,
expose,
)
from backend.util.service import AppService, AppServiceClient, endpoint_to_sync, expose
from backend.util.settings import Settings
logger = TruncatedLogger(logging.getLogger(__name__), "[NotificationManager]")
@@ -48,6 +43,8 @@ NOTIFICATION_EXCHANGE = Exchange(name="notifications", type=ExchangeType.TOPIC)
DEAD_LETTER_EXCHANGE = Exchange(name="dead_letter", type=ExchangeType.TOPIC)
EXCHANGES = [NOTIFICATION_EXCHANGE, DEAD_LETTER_EXCHANGE]
background_executor = ProcessPoolExecutor(max_workers=2)
def create_notification_config() -> RabbitMQConfig:
"""Create RabbitMQ configuration for notifications"""
@@ -188,33 +185,24 @@ class NotificationManager(AppService):
@property
def rabbit(self) -> rabbitmq.AsyncRabbitMQ:
"""Access the RabbitMQ service. Will raise if not configured."""
if not hasattr(self, "rabbitmq_service") or not self.rabbitmq_service:
raise UnhealthyServiceError("RabbitMQ not configured for this service")
if not self.rabbitmq_service:
raise RuntimeError("RabbitMQ not configured for this service")
return self.rabbitmq_service
@property
def rabbit_config(self) -> rabbitmq.RabbitMQConfig:
"""Access the RabbitMQ config. Will raise if not configured."""
if not self.rabbitmq_config:
raise UnhealthyServiceError("RabbitMQ not configured for this service")
raise RuntimeError("RabbitMQ not configured for this service")
return self.rabbitmq_config
async def health_check(self) -> str:
# Service is unhealthy if RabbitMQ is not ready
if not hasattr(self, "rabbitmq_service") or not self.rabbitmq_service:
raise UnhealthyServiceError("RabbitMQ not configured for this service")
if not self.rabbitmq_service.is_ready:
raise UnhealthyServiceError("RabbitMQ channel is not ready")
return await super().health_check()
@classmethod
def get_port(cls) -> int:
return settings.config.notification_service_port
@expose
async def queue_weekly_summary(self):
# Use the existing event loop instead of creating a new one with asyncio.run()
asyncio.create_task(self._queue_weekly_summary())
def queue_weekly_summary(self):
background_executor.submit(lambda: asyncio.run(self._queue_weekly_summary()))
async def _queue_weekly_summary(self):
"""Process weekly summary for specified notification types"""
@@ -223,14 +211,10 @@ class NotificationManager(AppService):
processed_count = 0
current_time = datetime.now(tz=timezone.utc)
start_time = current_time - timedelta(days=7)
logger.info(
f"Querying for active users between {start_time} and {current_time}"
)
users = await get_database_manager_async_client().get_active_user_ids_in_timerange(
users = get_database_manager_client().get_active_user_ids_in_timerange(
end_time=current_time.isoformat(),
start_time=start_time.isoformat(),
)
logger.info(f"Found {len(users)} active users in the last 7 days")
for user in users:
await self._queue_scheduled_notification(
SummaryParamsEventModel(
@@ -250,15 +234,10 @@ class NotificationManager(AppService):
logger.exception(f"Error processing weekly summary: {e}")
@expose
async def process_existing_batches(
self, notification_types: list[NotificationType]
):
# Use the existing event loop instead of creating a new process
asyncio.create_task(self._process_existing_batches(notification_types))
def process_existing_batches(self, notification_types: list[NotificationType]):
background_executor.submit(self._process_existing_batches, notification_types)
async def _process_existing_batches(
self, notification_types: list[NotificationType]
):
def _process_existing_batches(self, notification_types: list[NotificationType]):
"""Process existing batches for specified notification types"""
try:
processed_count = 0
@@ -266,15 +245,13 @@ class NotificationManager(AppService):
for notification_type in notification_types:
# Get all batches for this notification type
batches = (
await get_database_manager_async_client().get_all_batches_by_type(
notification_type
)
batches = get_database_manager_client().get_all_batches_by_type(
notification_type
)
for batch in batches:
# Check if batch has aged out
oldest_message = await get_database_manager_async_client().get_user_notification_oldest_message_in_batch(
oldest_message = get_database_manager_client().get_user_notification_oldest_message_in_batch(
batch.user_id, notification_type
)
@@ -289,8 +266,10 @@ class NotificationManager(AppService):
# If batch has aged out, process it
if oldest_message.created_at + max_delay < current_time:
recipient_email = await get_database_manager_async_client().get_user_email_by_id(
batch.user_id
recipient_email = (
get_database_manager_client().get_user_email_by_id(
batch.user_id
)
)
if not recipient_email:
@@ -299,7 +278,7 @@ class NotificationManager(AppService):
)
continue
should_send = await self._should_email_user_based_on_preference(
should_send = self._should_email_user_based_on_preference(
batch.user_id, notification_type
)
@@ -308,13 +287,15 @@ class NotificationManager(AppService):
f"User {batch.user_id} does not want to receive {notification_type} notifications"
)
# Clear the batch
await get_database_manager_async_client().empty_user_notification_batch(
get_database_manager_client().empty_user_notification_batch(
batch.user_id, notification_type
)
continue
batch_data = await get_database_manager_async_client().get_user_notification_batch(
batch.user_id, notification_type
batch_data = (
get_database_manager_client().get_user_notification_batch(
batch.user_id, notification_type
)
)
if not batch_data or not batch_data.notifications:
@@ -322,7 +303,7 @@ class NotificationManager(AppService):
f"Batch data not found for user {batch.user_id}"
)
# Clear the batch
await get_database_manager_async_client().empty_user_notification_batch(
get_database_manager_client().empty_user_notification_batch(
batch.user_id, notification_type
)
continue
@@ -358,7 +339,7 @@ class NotificationManager(AppService):
)
# Clear the batch
await get_database_manager_async_client().empty_user_notification_batch(
get_database_manager_client().empty_user_notification_batch(
batch.user_id, notification_type
)
@@ -388,13 +369,10 @@ class NotificationManager(AppService):
async def _queue_scheduled_notification(self, event: SummaryParamsEventModel):
"""Queue a scheduled notification - exposed method for other services to call"""
try:
logger.info(
f"Queueing scheduled notification type={event.type} user_id={event.user_id}"
)
logger.debug(f"Received Request to queue scheduled notification {event=}")
exchange = "notifications"
routing_key = get_routing_key(event.type)
logger.info(f"Using routing key: {routing_key}")
# Publish to RabbitMQ
await self.rabbit.publish_message(
@@ -402,132 +380,117 @@ class NotificationManager(AppService):
message=event.model_dump_json(),
exchange=next(ex for ex in EXCHANGES if ex.name == exchange),
)
logger.info(f"Successfully queued notification for user {event.user_id}")
except Exception as e:
logger.exception(f"Error queueing notification: {e}")
async def _should_email_user_based_on_preference(
def _should_email_user_based_on_preference(
self, user_id: str, event_type: NotificationType
) -> bool:
"""Check if a user wants to receive a notification based on their preferences and email verification status"""
validated_email = (
await get_database_manager_async_client().get_user_email_verification(
user_id
)
validated_email = get_database_manager_client().get_user_email_verification(
user_id
)
preference = (
await get_database_manager_async_client().get_user_notification_preference(
user_id
)
).preferences.get(event_type, True)
get_database_manager_client()
.get_user_notification_preference(user_id)
.preferences.get(event_type, True)
)
# only if both are true, should we email this person
return validated_email and preference
async def _gather_summary_data(
def _gather_summary_data(
self, user_id: str, event_type: NotificationType, params: BaseSummaryParams
) -> BaseSummaryData:
"""Gathers the data to build a summary notification"""
logger.info(
f"Gathering summary data for {user_id} and {event_type} with {params=}"
f"Gathering summary data for {user_id} and {event_type} wiht {params=}"
)
try:
# Get summary data from the database
summary_data = await get_database_manager_async_client().get_user_execution_summary_data(
user_id=user_id,
start_time=params.start_date,
end_time=params.end_date,
# total_credits_used = self.run_and_wait(
# get_total_credits_used(user_id, start_time, end_time)
# )
# total_executions = self.run_and_wait(
# get_total_executions(user_id, start_time, end_time)
# )
# most_used_agent = self.run_and_wait(
# get_most_used_agent(user_id, start_time, end_time)
# )
# execution_times = self.run_and_wait(
# get_execution_time(user_id, start_time, end_time)
# )
# runs = self.run_and_wait(
# get_runs(user_id, start_time, end_time)
# )
total_credits_used = 3.0
total_executions = 2
most_used_agent = {"name": "Some"}
execution_times = [1, 2, 3]
runs = [{"status": "COMPLETED"}, {"status": "FAILED"}]
successful_runs = len([run for run in runs if run["status"] == "COMPLETED"])
failed_runs = len([run for run in runs if run["status"] != "COMPLETED"])
average_execution_time = (
sum(execution_times) / len(execution_times) if execution_times else 0
)
# cost_breakdown = self.run_and_wait(
# get_cost_breakdown(user_id, start_time, end_time)
# )
cost_breakdown = {
"agent1": 1.0,
"agent2": 2.0,
}
if event_type == NotificationType.DAILY_SUMMARY and isinstance(
params, DailySummaryParams
):
return DailySummaryData(
total_credits_used=total_credits_used,
total_executions=total_executions,
most_used_agent=most_used_agent["name"],
total_execution_time=sum(execution_times),
successful_runs=successful_runs,
failed_runs=failed_runs,
average_execution_time=average_execution_time,
cost_breakdown=cost_breakdown,
date=params.date,
)
elif event_type == NotificationType.WEEKLY_SUMMARY and isinstance(
params, WeeklySummaryParams
):
return WeeklySummaryData(
total_credits_used=total_credits_used,
total_executions=total_executions,
most_used_agent=most_used_agent["name"],
total_execution_time=sum(execution_times),
successful_runs=successful_runs,
failed_runs=failed_runs,
average_execution_time=average_execution_time,
cost_breakdown=cost_breakdown,
start_date=params.start_date,
end_date=params.end_date,
)
else:
raise ValueError("Invalid event type or params")
# Extract data from summary
total_credits_used = summary_data.total_credits_used
total_executions = summary_data.total_executions
most_used_agent = summary_data.most_used_agent
successful_runs = summary_data.successful_runs
failed_runs = summary_data.failed_runs
total_execution_time = summary_data.total_execution_time
average_execution_time = summary_data.average_execution_time
cost_breakdown = summary_data.cost_breakdown
if event_type == NotificationType.DAILY_SUMMARY and isinstance(
params, DailySummaryParams
):
return DailySummaryData(
total_credits_used=total_credits_used,
total_executions=total_executions,
most_used_agent=most_used_agent,
total_execution_time=total_execution_time,
successful_runs=successful_runs,
failed_runs=failed_runs,
average_execution_time=average_execution_time,
cost_breakdown=cost_breakdown,
date=params.date,
)
elif event_type == NotificationType.WEEKLY_SUMMARY and isinstance(
params, WeeklySummaryParams
):
return WeeklySummaryData(
total_credits_used=total_credits_used,
total_executions=total_executions,
most_used_agent=most_used_agent,
total_execution_time=total_execution_time,
successful_runs=successful_runs,
failed_runs=failed_runs,
average_execution_time=average_execution_time,
cost_breakdown=cost_breakdown,
start_date=params.start_date,
end_date=params.end_date,
)
else:
raise ValueError("Invalid event type or params")
except Exception as e:
logger.error(f"Failed to gather summary data: {e}")
# Return sensible defaults in case of error
if event_type == NotificationType.DAILY_SUMMARY and isinstance(
params, DailySummaryParams
):
return DailySummaryData(
total_credits_used=0.0,
total_executions=0,
most_used_agent="No data available",
total_execution_time=0.0,
successful_runs=0,
failed_runs=0,
average_execution_time=0.0,
cost_breakdown={},
date=params.date,
)
elif event_type == NotificationType.WEEKLY_SUMMARY and isinstance(
params, WeeklySummaryParams
):
return WeeklySummaryData(
total_credits_used=0.0,
total_executions=0,
most_used_agent="No data available",
total_execution_time=0.0,
successful_runs=0,
failed_runs=0,
average_execution_time=0.0,
cost_breakdown={},
start_date=params.start_date,
end_date=params.end_date,
)
else:
raise ValueError("Invalid event type or params") from e
async def _should_batch(
def _should_batch(
self, user_id: str, event_type: NotificationType, event: NotificationEventModel
) -> bool:
await get_database_manager_async_client().create_or_add_to_user_notification_batch(
get_database_manager_client().create_or_add_to_user_notification_batch(
user_id, event_type, event
)
oldest_message = await get_database_manager_async_client().get_user_notification_oldest_message_in_batch(
user_id, event_type
oldest_message = (
get_database_manager_client().get_user_notification_oldest_message_in_batch(
user_id, event_type
)
)
if not oldest_message:
logger.error(
@@ -556,7 +519,7 @@ class NotificationManager(AppService):
logger.error(f"Error parsing message due to non matching schema {e}")
return None
async def _process_admin_message(self, message: str) -> bool:
def _process_admin_message(self, message: str) -> bool:
"""Process a single notification, sending to an admin, returning whether to put into the failed queue"""
try:
event = self._parse_message(message)
@@ -570,7 +533,7 @@ class NotificationManager(AppService):
logger.exception(f"Error processing notification for admin queue: {e}")
return False
async def _process_immediate(self, message: str) -> bool:
def _process_immediate(self, message: str) -> bool:
"""Process a single notification immediately, returning whether to put into the failed queue"""
try:
event = self._parse_message(message)
@@ -578,16 +541,14 @@ class NotificationManager(AppService):
return False
logger.debug(f"Processing immediate notification: {event}")
recipient_email = (
await get_database_manager_async_client().get_user_email_by_id(
event.user_id
)
recipient_email = get_database_manager_client().get_user_email_by_id(
event.user_id
)
if not recipient_email:
logger.error(f"User email not found for user {event.user_id}")
return False
should_send = await self._should_email_user_based_on_preference(
should_send = self._should_email_user_based_on_preference(
event.user_id, event.type
)
if not should_send:
@@ -609,7 +570,7 @@ class NotificationManager(AppService):
logger.exception(f"Error processing notification for immediate queue: {e}")
return False
async def _process_batch(self, message: str) -> bool:
def _process_batch(self, message: str) -> bool:
"""Process a single notification with a batching strategy, returning whether to put into the failed queue"""
try:
event = self._parse_message(message)
@@ -617,16 +578,14 @@ class NotificationManager(AppService):
return False
logger.info(f"Processing batch notification: {event}")
recipient_email = (
await get_database_manager_async_client().get_user_email_by_id(
event.user_id
)
recipient_email = get_database_manager_client().get_user_email_by_id(
event.user_id
)
if not recipient_email:
logger.error(f"User email not found for user {event.user_id}")
return False
should_send = await self._should_email_user_based_on_preference(
should_send = self._should_email_user_based_on_preference(
event.user_id, event.type
)
if not should_send:
@@ -635,15 +594,13 @@ class NotificationManager(AppService):
)
return True
should_send = await self._should_batch(event.user_id, event.type, event)
should_send = self._should_batch(event.user_id, event.type, event)
if not should_send:
logger.info("Batch not old enough to send")
return False
batch = (
await get_database_manager_async_client().get_user_notification_batch(
event.user_id, event.type
)
batch = get_database_manager_client().get_user_notification_batch(
event.user_id, event.type
)
if not batch or not batch.notifications:
logger.error(f"Batch not found for user {event.user_id}")
@@ -745,7 +702,7 @@ class NotificationManager(AppService):
logger.info(
f"Successfully sent all {successfully_sent_count} notifications, clearing batch"
)
await get_database_manager_async_client().empty_user_notification_batch(
get_database_manager_client().empty_user_notification_batch(
event.user_id, event.type
)
else:
@@ -758,7 +715,7 @@ class NotificationManager(AppService):
logger.exception(f"Error processing notification for batch queue: {e}")
return False
async def _process_summary(self, message: str) -> bool:
def _process_summary(self, message: str) -> bool:
"""Process a single notification with a summary strategy, returning whether to put into the failed queue"""
try:
logger.info(f"Processing summary notification: {message}")
@@ -769,15 +726,13 @@ class NotificationManager(AppService):
logger.info(f"Processing summary notification: {model}")
recipient_email = (
await get_database_manager_async_client().get_user_email_by_id(
event.user_id
)
recipient_email = get_database_manager_client().get_user_email_by_id(
event.user_id
)
if not recipient_email:
logger.error(f"User email not found for user {event.user_id}")
return False
should_send = await self._should_email_user_based_on_preference(
should_send = self._should_email_user_based_on_preference(
event.user_id, event.type
)
if not should_send:
@@ -786,7 +741,7 @@ class NotificationManager(AppService):
)
return True
summary_data = await self._gather_summary_data(
summary_data = self._gather_summary_data(
event.user_id, event.type, model.data
)
@@ -812,7 +767,7 @@ class NotificationManager(AppService):
async def _consume_queue(
self,
queue: aio_pika.abc.AbstractQueue,
process_func: Callable[[str], Awaitable[bool]],
process_func: Callable[[str], bool],
queue_name: str,
):
"""Continuously consume messages from a queue using async iteration"""
@@ -826,7 +781,7 @@ class NotificationManager(AppService):
try:
async with message.process():
result = await process_func(message.body.decode())
result = process_func(message.body.decode())
if not result:
# Message will be rejected when exiting context without exception
raise aio_pika.exceptions.MessageProcessError(
@@ -925,8 +880,6 @@ class NotificationManagerClient(AppServiceClient):
def get_service_type(cls):
return NotificationManager
process_existing_batches = endpoint_to_sync(
NotificationManager.process_existing_batches
)
queue_weekly_summary = endpoint_to_sync(NotificationManager.queue_weekly_summary)
process_existing_batches = NotificationManager.process_existing_batches
queue_weekly_summary = NotificationManager.queue_weekly_summary
discord_system_alert = endpoint_to_sync(NotificationManager.discord_system_alert)

View File

@@ -5,64 +5,23 @@ data.start_date: the start date of the summary
data.end_date: the end date of the summary
data.total_credits_used: the total credits used during the summary
data.total_executions: the total number of executions during the summary
data.most_used_agent: the most used agent's name during the summary
data.most_used_agent: the most used agent's nameduring the summary
data.total_execution_time: the total execution time during the summary
data.successful_runs: the total number of successful runs during the summary
data.failed_runs: the total number of failed runs during the summary
data.average_execution_time: the average execution time during the summary
data.cost_breakdown: the cost breakdown during the summary (dict mapping agent names to credit amounts)
data.cost_breakdown: the cost breakdown during the summary
#}
<h1 style="color: #5D23BB; font-size: 32px; font-weight: 600; margin-bottom: 25px; margin-top: 0;">
Weekly Summary
</h1>
<h1>Weekly Summary</h1>
<h2 style="color: #070629; font-size: 24px; font-weight: 500; margin-bottom: 20px;">
Your Agent Activity: {{ data.start_date.strftime('%B %-d') }} {{ data.end_date.strftime('%B %-d') }}
</h2>
<div style="background-color: #ffffff; border-radius: 8px; padding: 20px; margin-bottom: 25px;">
<ul style="list-style-type: disc; padding-left: 20px; margin: 0;">
<li style="font-size: 16px; line-height: 1.8; margin-bottom: 8px;">
<strong>Total Executions:</strong> {{ data.total_executions }}
</li>
<li style="font-size: 16px; line-height: 1.8; margin-bottom: 8px;">
<strong>Total Credits Used:</strong> {{ data.total_credits_used|format("%.2f") }}
</li>
<li style="font-size: 16px; line-height: 1.8; margin-bottom: 8px;">
<strong>Total Execution Time:</strong> {{ data.total_execution_time|format("%.1f") }} seconds
</li>
<li style="font-size: 16px; line-height: 1.8; margin-bottom: 8px;">
<strong>Successful Runs:</strong> {{ data.successful_runs }}
</li>
<li style="font-size: 16px; line-height: 1.8; margin-bottom: 8px;">
<strong>Failed Runs:</strong> {{ data.failed_runs }}
</li>
<li style="font-size: 16px; line-height: 1.8; margin-bottom: 8px;">
<strong>Average Execution Time:</strong> {{ data.average_execution_time|format("%.1f") }} seconds
</li>
<li style="font-size: 16px; line-height: 1.8; margin-bottom: 8px;">
<strong>Most Used Agent:</strong> {{ data.most_used_agent }}
</li>
{% if data.cost_breakdown %}
<li style="font-size: 16px; line-height: 1.8; margin-bottom: 8px;">
<strong>Cost Breakdown:</strong>
<ul style="list-style-type: disc; padding-left: 40px; margin-top: 8px;">
{% for agent_name, credits in data.cost_breakdown.items() %}
<li style="font-size: 16px; line-height: 1.8; margin-bottom: 4px;">
{{ agent_name }}: {{ credits|format("%.2f") }} credits
</li>
{% endfor %}
</ul>
</li>
{% endif %}
</ul>
</div>
<p style="font-size: 16px; line-height: 165%; margin-top: 20px; margin-bottom: 10px;">
Thank you for being a part of the AutoGPT community! 🎉
</p>
<p style="font-size: 16px; line-height: 165%; margin-bottom: 0;">
Join the conversation on <a href="https://discord.gg/autogpt" style="color: #4285F4; text-decoration: underline;">Discord here</a>.
</p>
<p>Start Date: {{ data.start_date }}</p>
<p>End Date: {{ data.end_date }}</p>
<p>Total Credits Used: {{ data.total_credits_used }}</p>
<p>Total Executions: {{ data.total_executions }}</p>
<p>Most Used Agent: {{ data.most_used_agent }}</p>
<p>Total Execution Time: {{ data.total_execution_time }}</p>
<p>Successful Runs: {{ data.successful_runs }}</p>
<p>Failed Runs: {{ data.failed_runs }}</p>
<p>Average Execution Time: {{ data.average_execution_time }}</p>
<p>Cost Breakdown: {{ data.cost_breakdown }}</p>

View File

@@ -1,5 +1,6 @@
from backend.app import run_processes
from backend.executor.scheduler import Scheduler
from backend.notifications.notifications import NotificationManager
def main():
@@ -7,6 +8,7 @@ def main():
Run all the processes required for the AutoGPT-server Scheduling System.
"""
run_processes(
NotificationManager(),
Scheduler(),
)

View File

@@ -634,7 +634,7 @@ async def get_ayrshare_sso_url(
# SocialPlatform.TELEGRAM,
# SocialPlatform.GOOGLE_MY_BUSINESS,
# SocialPlatform.PINTEREST,
SocialPlatform.TIKTOK,
# SocialPlatform.TIKTOK,
# SocialPlatform.BLUESKY,
# SocialPlatform.SNAPCHAT,
# SocialPlatform.THREADS,

View File

@@ -0,0 +1,11 @@
from supabase import Client, create_client
from backend.util.settings import Settings
settings = Settings()
def get_supabase() -> Client:
return create_client(
settings.secrets.supabase_url, settings.secrets.supabase_service_role_key
)

View File

@@ -60,6 +60,21 @@ class UpdatePermissionsRequest(pydantic.BaseModel):
permissions: list[APIKeyPermission]
class Pagination(pydantic.BaseModel):
total_items: int = pydantic.Field(
description="Total number of items.", examples=[42]
)
total_pages: int = pydantic.Field(
description="Total number of pages.", examples=[2]
)
current_page: int = pydantic.Field(
description="Current_page page number.", examples=[1]
)
page_size: int = pydantic.Field(
description="Number of items per page.", examples=[25]
)
class RequestTopUp(pydantic.BaseModel):
credit_amount: int

View File

@@ -9,6 +9,11 @@ import fastapi.responses
import pydantic
import starlette.middleware.cors
import uvicorn
from autogpt_libs.feature_flag.client import (
initialize_launchdarkly,
shutdown_launchdarkly,
)
from autogpt_libs.logging.utils import generate_uvicorn_config
from fastapi.exceptions import RequestValidationError
from fastapi.routing import APIRoute
@@ -35,9 +40,6 @@ from backend.integrations.providers import ProviderName
from backend.server.external.api import external_app
from backend.server.middleware.security import SecurityHeadersMiddleware
from backend.util import json
from backend.util.cloud_storage import shutdown_cloud_storage_handler
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
from backend.util.service import UnhealthyServiceError
settings = backend.util.settings.Settings()
logger = logging.getLogger(__name__)
@@ -73,12 +75,6 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
with launch_darkly_context():
yield
try:
await shutdown_cloud_storage_handler()
except Exception as e:
logger.warning(f"Error shutting down cloud storage handler: {e}")
await backend.data.db.disconnect()
@@ -229,7 +225,7 @@ app.mount("/external-api", external_app)
@app.get(path="/health", tags=["health"], dependencies=[])
async def health():
if not backend.data.db.is_connected():
raise UnhealthyServiceError("Database is not connected")
raise RuntimeError("Database is not connected")
return {"status": "healthy"}
@@ -246,7 +242,7 @@ class AgentServer(backend.util.service.AppProcess):
server_app,
host=backend.util.settings.Config().agent_api_host,
port=backend.util.settings.Config().agent_api_port,
log_config=None,
log_config=generate_uvicorn_config(),
)
def cleanup(self):

View File

@@ -8,6 +8,7 @@ from typing import Annotated, Any, Sequence
import pydantic
import stripe
from autogpt_libs.auth.middleware import auth_middleware
from autogpt_libs.feature_flag.client import feature_flag
from fastapi import (
APIRouter,
Body,
@@ -84,7 +85,6 @@ from backend.server.utils import get_user_id
from backend.util.clients import get_scheduler_client
from backend.util.cloud_storage import get_cloud_storage_handler
from backend.util.exceptions import GraphValidationError, NotFoundError
from backend.util.feature_flag import feature_flag
from backend.util.settings import Settings
from backend.util.virus_scanner import scan_content_safe
@@ -458,16 +458,12 @@ async def stripe_webhook(request: Request):
event = stripe.Webhook.construct_event(
payload, sig_header, settings.secrets.stripe_webhook_secret
)
except ValueError as e:
except ValueError:
# Invalid payload
raise HTTPException(
status_code=400, detail=f"Invalid payload: {str(e) or type(e).__name__}"
)
except stripe.SignatureVerificationError as e:
raise HTTPException(status_code=400)
except stripe.SignatureVerificationError:
# Invalid signature
raise HTTPException(
status_code=400, detail=f"Invalid signature: {str(e) or type(e).__name__}"
)
raise HTTPException(status_code=400)
if (
event["type"] == "checkout.session.completed"
@@ -680,15 +676,7 @@ async def update_graph(
# Handle deactivation of the previously active version
await on_graph_deactivate(current_active_version, user_id=user_id)
# Fetch new graph version *with sub-graphs* (needed for credentials input schema)
new_graph_version_with_subgraphs = await graph_db.get_graph(
graph_id,
new_graph_version.version,
user_id=user_id,
include_subgraphs=True,
)
assert new_graph_version_with_subgraphs # make type checker happy
return new_graph_version_with_subgraphs
return new_graph_version
@v1_router.put(
@@ -1071,6 +1059,7 @@ async def get_api_key(
tags=["api-keys"],
dependencies=[Depends(auth_middleware)],
)
@feature_flag("api-keys-enabled")
async def delete_api_key(
key_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> Optional[APIKeyWithoutHash]:
@@ -1099,6 +1088,7 @@ async def delete_api_key(
tags=["api-keys"],
dependencies=[Depends(auth_middleware)],
)
@feature_flag("api-keys-enabled")
async def suspend_key(
key_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> Optional[APIKeyWithoutHash]:
@@ -1124,6 +1114,7 @@ async def suspend_key(
tags=["api-keys"],
dependencies=[Depends(auth_middleware)],
)
@feature_flag("api-keys-enabled")
async def update_permissions(
key_id: str,
request: UpdatePermissionsRequest,

View File

@@ -1 +0,0 @@
# AutoMod integration for content moderation

View File

@@ -1,353 +0,0 @@
import asyncio
import json
import logging
from typing import TYPE_CHECKING, Any, Literal
if TYPE_CHECKING:
from backend.executor import DatabaseManagerAsyncClient
from pydantic import ValidationError
from backend.data.execution import ExecutionStatus
from backend.server.v2.AutoMod.models import (
AutoModRequest,
AutoModResponse,
ModerationConfig,
)
from backend.util.exceptions import ModerationError
from backend.util.feature_flag import Flag, is_feature_enabled
from backend.util.request import Requests
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
class AutoModManager:
def __init__(self):
self.config = self._load_config()
def _load_config(self) -> ModerationConfig:
"""Load AutoMod configuration from settings"""
settings = Settings()
return ModerationConfig(
enabled=settings.config.automod_enabled,
api_url=settings.config.automod_api_url,
api_key=settings.secrets.automod_api_key,
timeout=settings.config.automod_timeout,
retry_attempts=settings.config.automod_retry_attempts,
retry_delay=settings.config.automod_retry_delay,
fail_open=settings.config.automod_fail_open,
)
async def moderate_graph_execution_inputs(
self, db_client: "DatabaseManagerAsyncClient", graph_exec, timeout: int = 10
) -> Exception | None:
"""
Complete input moderation flow for graph execution
Returns: error_if_failed (None means success)
"""
if not self.config.enabled:
return None
# Check if AutoMod feature is enabled for this user
if not await is_feature_enabled(Flag.AUTOMOD, graph_exec.user_id):
logger.debug(f"AutoMod feature not enabled for user {graph_exec.user_id}")
return None
# Get graph model and collect all inputs
graph_model = await db_client.get_graph(
graph_exec.graph_id,
user_id=graph_exec.user_id,
version=graph_exec.graph_version,
)
if not graph_model or not graph_model.nodes:
return None
all_inputs = []
for node in graph_model.nodes:
if node.input_default:
all_inputs.extend(str(v) for v in node.input_default.values() if v)
if (masks := graph_exec.nodes_input_masks) and (mask := masks.get(node.id)):
all_inputs.extend(str(v) for v in mask.values() if v)
if not all_inputs:
return None
# Combine all content and moderate directly
content = " ".join(all_inputs)
# Run moderation
logger.warning(
f"Moderating inputs for graph execution {graph_exec.graph_exec_id}"
)
try:
moderation_passed = await self._moderate_content(
content,
{
"user_id": graph_exec.user_id,
"graph_id": graph_exec.graph_id,
"graph_exec_id": graph_exec.graph_exec_id,
"moderation_type": "execution_input",
},
)
if not moderation_passed:
logger.warning(
f"Moderation failed for graph execution {graph_exec.graph_exec_id}"
)
# Update node statuses for frontend display before raising error
await self._update_failed_nodes_for_moderation(
db_client, graph_exec.graph_exec_id, "input"
)
return ModerationError(
message="Execution failed due to input content moderation",
user_id=graph_exec.user_id,
graph_exec_id=graph_exec.graph_exec_id,
moderation_type="input",
)
return None
except asyncio.TimeoutError:
logger.warning(
f"Input moderation timed out for graph execution {graph_exec.graph_exec_id}, bypassing moderation"
)
return None # Bypass moderation on timeout
except Exception as e:
logger.warning(f"Input moderation execution failed: {e}")
return ModerationError(
message="Execution failed due to input content moderation error",
user_id=graph_exec.user_id,
graph_exec_id=graph_exec.graph_exec_id,
moderation_type="input",
)
async def moderate_graph_execution_outputs(
self,
db_client: "DatabaseManagerAsyncClient",
graph_exec_id: str,
user_id: str,
graph_id: str,
timeout: int = 10,
) -> Exception | None:
"""
Complete output moderation flow for graph execution
Returns: error_if_failed (None means success)
"""
if not self.config.enabled:
return None
# Check if AutoMod feature is enabled for this user
if not await is_feature_enabled(Flag.AUTOMOD, user_id):
logger.debug(f"AutoMod feature not enabled for user {user_id}")
return None
# Get completed executions and collect outputs
completed_executions = await db_client.get_node_executions(
graph_exec_id, statuses=[ExecutionStatus.COMPLETED], include_exec_data=True
)
if not completed_executions:
return None
all_outputs = []
for exec_entry in completed_executions:
if exec_entry.output_data:
all_outputs.extend(str(v) for v in exec_entry.output_data.values() if v)
if not all_outputs:
return None
# Combine all content and moderate directly
content = " ".join(all_outputs)
# Run moderation
logger.warning(f"Moderating outputs for graph execution {graph_exec_id}")
try:
moderation_passed = await self._moderate_content(
content,
{
"user_id": user_id,
"graph_id": graph_id,
"graph_exec_id": graph_exec_id,
"moderation_type": "execution_output",
},
)
if not moderation_passed:
logger.warning(f"Moderation failed for graph execution {graph_exec_id}")
# Update node statuses for frontend display before raising error
await self._update_failed_nodes_for_moderation(
db_client, graph_exec_id, "output"
)
return ModerationError(
message="Execution failed due to output content moderation",
user_id=user_id,
graph_exec_id=graph_exec_id,
moderation_type="output",
)
return None
except asyncio.TimeoutError:
logger.warning(
f"Output moderation timed out for graph execution {graph_exec_id}, bypassing moderation"
)
return None # Bypass moderation on timeout
except Exception as e:
logger.warning(f"Output moderation execution failed: {e}")
return ModerationError(
message="Execution failed due to output content moderation error",
user_id=user_id,
graph_exec_id=graph_exec_id,
moderation_type="output",
)
async def _update_failed_nodes_for_moderation(
self,
db_client: "DatabaseManagerAsyncClient",
graph_exec_id: str,
moderation_type: Literal["input", "output"],
):
"""Update node execution statuses for frontend display when moderation fails"""
# Import here to avoid circular imports
from backend.executor.manager import send_async_execution_update
if moderation_type == "input":
# For input moderation, mark queued/running/incomplete nodes as failed
target_statuses = [
ExecutionStatus.QUEUED,
ExecutionStatus.RUNNING,
ExecutionStatus.INCOMPLETE,
]
else:
# For output moderation, mark completed nodes as failed
target_statuses = [ExecutionStatus.COMPLETED]
# Get the executions that need to be updated
executions_to_update = await db_client.get_node_executions(
graph_exec_id, statuses=target_statuses, include_exec_data=True
)
if not executions_to_update:
return
# Prepare database update tasks
exec_updates = []
for exec_entry in executions_to_update:
# Collect all input and output names to clear
cleared_inputs = {}
cleared_outputs = {}
if exec_entry.input_data:
for name in exec_entry.input_data.keys():
cleared_inputs[name] = ["Failed due to content moderation"]
if exec_entry.output_data:
for name in exec_entry.output_data.keys():
cleared_outputs[name] = ["Failed due to content moderation"]
# Add update task to list
exec_updates.append(
db_client.update_node_execution_status(
exec_entry.node_exec_id,
status=ExecutionStatus.FAILED,
stats={
"error": "Failed due to content moderation",
"cleared_inputs": cleared_inputs,
"cleared_outputs": cleared_outputs,
},
)
)
# Execute all database updates in parallel
updated_execs = await asyncio.gather(*exec_updates)
# Send all websocket updates in parallel
await asyncio.gather(
*[
send_async_execution_update(updated_exec)
for updated_exec in updated_execs
]
)
async def _moderate_content(self, content: str, metadata: dict[str, Any]) -> bool:
"""Moderate content using AutoMod API
Returns:
True: Content approved or timeout occurred
False: Content rejected by moderation
Raises:
asyncio.TimeoutError: When moderation times out (should be bypassed)
"""
try:
request_data = AutoModRequest(
type="text",
content=content,
metadata=metadata,
)
response = await self._make_request(request_data)
if response.success and response.status == "approved":
logger.debug(
f"Content approved for {metadata.get('graph_exec_id', 'unknown')}"
)
return True
else:
reasons = [r.reason for r in response.moderation_results if r.reason]
error_msg = f"Content rejected by AutoMod: {'; '.join(reasons)}"
logger.warning(f"Content rejected: {error_msg}")
return False
except asyncio.TimeoutError:
# Re-raise timeout to be handled by calling methods
logger.warning(
f"AutoMod API timeout for {metadata.get('graph_exec_id', 'unknown')}"
)
raise
except Exception as e:
logger.error(f"AutoMod moderation error: {e}")
return self.config.fail_open
async def _make_request(self, request_data: AutoModRequest) -> AutoModResponse:
"""Make HTTP request to AutoMod API using the standard request utility"""
url = f"{self.config.api_url}/moderate"
headers = {
"Content-Type": "application/json",
"X-API-Key": self.config.api_key.strip(),
}
# Create requests instance with timeout and retry configuration
requests = Requests(
extra_headers=headers,
retry_max_wait=float(self.config.timeout),
)
try:
response = await requests.post(
url, json=request_data.model_dump(), timeout=self.config.timeout
)
response_data = response.json()
return AutoModResponse.model_validate(response_data)
except asyncio.TimeoutError:
# Re-raise timeout error to be caught by _moderate_content
raise
except (json.JSONDecodeError, ValidationError) as e:
raise Exception(f"Invalid response from AutoMod API: {e}")
except Exception as e:
# Check if this is an aiohttp timeout that we should convert
if "timeout" in str(e).lower():
raise asyncio.TimeoutError(f"AutoMod API request timed out: {e}")
raise Exception(f"AutoMod API request failed: {e}")
# Global instance
automod_manager = AutoModManager()

View File

@@ -1,57 +0,0 @@
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
class AutoModRequest(BaseModel):
"""Request model for AutoMod API"""
type: str = Field(..., description="Content type - 'text', 'image', 'video'")
content: str = Field(..., description="The content to moderate")
metadata: Optional[Dict[str, Any]] = Field(
default=None, description="Additional context about the content"
)
class ModerationResult(BaseModel):
"""Individual moderation result"""
decision: str = Field(
..., description="Moderation decision: 'approved', 'rejected', 'flagged'"
)
reason: Optional[str] = Field(default=None, description="Reason for the decision")
class AutoModResponse(BaseModel):
"""Response model for AutoMod API"""
success: bool = Field(..., description="Whether the request was successful")
status: str = Field(
..., description="Overall status: 'approved', 'rejected', 'flagged', 'pending'"
)
moderation_results: List[ModerationResult] = Field(
default_factory=list, description="List of moderation results"
)
class ModerationConfig(BaseModel):
"""Configuration for AutoMod integration"""
enabled: bool = Field(default=True, description="Whether moderation is enabled")
api_url: str = Field(default="", description="AutoMod API base URL")
api_key: str = Field(..., description="AutoMod API key")
timeout: int = Field(default=30, description="Request timeout in seconds")
retry_attempts: int = Field(default=3, description="Number of retry attempts")
retry_delay: float = Field(
default=1.0, description="Delay between retries in seconds"
)
fail_open: bool = Field(
default=False,
description="If True, allow execution to continue if moderation fails",
)
moderate_inputs: bool = Field(
default=True, description="Whether to moderate block inputs"
)
moderate_outputs: bool = Field(
default=True, description="Whether to moderate block outputs"
)

View File

@@ -14,7 +14,7 @@ import backend.server.v2.admin.credit_admin_routes as credit_admin_routes
import backend.server.v2.admin.model as admin_model
from backend.data.model import UserTransaction
from backend.server.conftest import ADMIN_USER_ID, TARGET_USER_ID
from backend.util.models import Pagination
from backend.server.model import Pagination
app = fastapi.FastAPI()
app.include_router(credit_admin_routes.router)

View File

@@ -1,7 +1,7 @@
from pydantic import BaseModel
from backend.data.model import UserTransaction
from backend.util.models import Pagination
from backend.server.model import Pagination
class UserHistoryResponse(BaseModel):

View File

@@ -9,6 +9,7 @@ import prisma.models
import prisma.types
import backend.data.graph as graph_db
import backend.server.model
import backend.server.v2.library.model as library_model
import backend.server.v2.store.exceptions as store_exceptions
import backend.server.v2.store.image_gen as store_image_gen
@@ -22,7 +23,6 @@ from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
from backend.util.exceptions import NotFoundError
from backend.util.json import SafeJson
from backend.util.models import Pagination
from backend.util.settings import Config
logger = logging.getLogger(__name__)
@@ -131,7 +131,7 @@ async def list_library_agents(
# Return the response with only valid agents
return library_model.LibraryAgentResponse(
agents=valid_library_agents,
pagination=Pagination(
pagination=backend.server.model.Pagination(
total_items=agent_count,
total_pages=(agent_count + page_size - 1) // page_size,
current_page=page,
@@ -241,11 +241,7 @@ async def get_library_agent_by_graph_id(
)
if not agent:
return None
assert agent.AgentGraph # make type checker happy
# Include sub-graphs so we can make a full credentials input schema
sub_graphs = await graph_db.get_sub_graphs(agent.AgentGraph)
return library_model.LibraryAgent.from_db(agent, sub_graphs=sub_graphs)
return library_model.LibraryAgent.from_db(agent)
except prisma.errors.PrismaError as e:
logger.error(f"Database error fetching library agent by graph ID: {e}")
raise store_exceptions.DatabaseError("Failed to fetch library agent") from e
@@ -629,7 +625,7 @@ async def list_presets(
return library_model.LibraryAgentPresetResponse(
presets=presets,
pagination=Pagination(
pagination=backend.server.model.Pagination(
total_items=total_items,
total_pages=total_pages,
current_page=page,

View File

@@ -8,9 +8,9 @@ import pydantic
import backend.data.block as block_model
import backend.data.graph as graph_model
import backend.server.model as server_model
from backend.data.model import CredentialsMetaInput, is_credentials_field_name
from backend.integrations.providers import ProviderName
from backend.util.models import Pagination
class LibraryAgentStatus(str, Enum):
@@ -213,7 +213,7 @@ class LibraryAgentResponse(pydantic.BaseModel):
"""Response schema for a list of library agents and pagination info."""
agents: list[LibraryAgent]
pagination: Pagination
pagination: server_model.Pagination
class LibraryAgentPresetCreatable(pydantic.BaseModel):
@@ -317,7 +317,7 @@ class LibraryAgentPresetResponse(pydantic.BaseModel):
"""Response schema for a list of agent presets and pagination info."""
presets: list[LibraryAgentPreset]
pagination: Pagination
pagination: server_model.Pagination
class LibraryAgentFilter(str, Enum):

View File

@@ -7,9 +7,9 @@ import pytest
import pytest_mock
from pytest_snapshot.plugin import Snapshot
import backend.server.model as server_model
import backend.server.v2.library.model as library_model
from backend.server.v2.library.routes import router as library_router
from backend.util.models import Pagination
app = fastapi.FastAPI()
app.include_router(library_router)
@@ -77,7 +77,7 @@ async def test_get_library_agents_success(
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
),
],
pagination=Pagination(
pagination=server_model.Pagination(
total_items=2, total_pages=1, current_page=1, page_size=50
),
)

View File

@@ -466,8 +466,6 @@ async def get_store_submissions(
# internal_comments omitted for regular users
reviewed_at=sub.reviewed_at,
changes_summary=sub.changes_summary,
video_url=sub.video_url,
categories=sub.categories,
)
submission_models.append(submission_model)
@@ -548,7 +546,7 @@ async def create_store_submission(
description: str = "",
sub_heading: str = "",
categories: list[str] = [],
changes_summary: str | None = "Initial Submission",
changes_summary: str = "Initial Submission",
) -> backend.server.v2.store.model.StoreSubmission:
"""
Create the first (and only) store listing and thus submission as a normal user
@@ -687,160 +685,6 @@ async def create_store_submission(
) from e
async def edit_store_submission(
user_id: str,
store_listing_version_id: str,
name: str,
video_url: str | None = None,
image_urls: list[str] = [],
description: str = "",
sub_heading: str = "",
categories: list[str] = [],
changes_summary: str | None = "Update submission",
) -> backend.server.v2.store.model.StoreSubmission:
"""
Edit an existing store listing submission.
Args:
user_id: ID of the authenticated user editing the submission
store_listing_version_id: ID of the store listing version to edit
agent_id: ID of the agent being submitted
agent_version: Version of the agent being submitted
slug: URL slug for the listing (only changeable for PENDING submissions)
name: Name of the agent
video_url: Optional URL to video demo
image_urls: List of image URLs for the listing
description: Description of the agent
sub_heading: Optional sub-heading for the agent
categories: List of categories for the agent
changes_summary: Summary of changes made in this submission
Returns:
StoreSubmission: The updated store submission
Raises:
SubmissionNotFoundError: If the submission is not found
UnauthorizedError: If the user doesn't own the submission
InvalidOperationError: If trying to edit a submission that can't be edited
"""
try:
# Get the current version and verify ownership
current_version = await prisma.models.StoreListingVersion.prisma().find_first(
where=prisma.types.StoreListingVersionWhereInput(
id=store_listing_version_id
),
include={
"StoreListing": {
"include": {
"Versions": {"order_by": {"version": "desc"}, "take": 1}
}
}
},
)
if not current_version:
raise backend.server.v2.store.exceptions.SubmissionNotFoundError(
f"Store listing version not found: {store_listing_version_id}"
)
# Verify the user owns this submission
if (
not current_version.StoreListing
or current_version.StoreListing.owningUserId != user_id
):
raise backend.server.v2.store.exceptions.UnauthorizedError(
f"User {user_id} does not own submission {store_listing_version_id}"
)
# Currently we are not allowing user to update the agent associated with a submission
# If we allow it in future, then we need a check here to verify the agent belongs to this user.
# Check if we can edit this submission
if current_version.submissionStatus == prisma.enums.SubmissionStatus.REJECTED:
raise backend.server.v2.store.exceptions.InvalidOperationError(
"Cannot edit a rejected submission"
)
# For APPROVED submissions, we need to create a new version
if current_version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED:
# Create a new version for the existing listing
return await create_store_version(
user_id=user_id,
agent_id=current_version.agentGraphId,
agent_version=current_version.agentGraphVersion,
store_listing_id=current_version.storeListingId,
name=name,
video_url=video_url,
image_urls=image_urls,
description=description,
sub_heading=sub_heading,
categories=categories,
changes_summary=changes_summary,
)
# For PENDING submissions, we can update the existing version
elif current_version.submissionStatus == prisma.enums.SubmissionStatus.PENDING:
# Update the existing version
updated_version = await prisma.models.StoreListingVersion.prisma().update(
where={"id": store_listing_version_id},
data=prisma.types.StoreListingVersionUpdateInput(
name=name,
videoUrl=video_url,
imageUrls=image_urls,
description=description,
categories=categories,
subHeading=sub_heading,
changesSummary=changes_summary,
),
)
logger.debug(
f"Updated existing version {store_listing_version_id} for agent {current_version.agentGraphId}"
)
if not updated_version:
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to update store listing version"
)
return backend.server.v2.store.model.StoreSubmission(
agent_id=current_version.agentGraphId,
agent_version=current_version.agentGraphVersion,
name=name,
sub_heading=sub_heading,
slug=current_version.StoreListing.slug,
description=description,
image_urls=image_urls,
date_submitted=updated_version.submittedAt or updated_version.createdAt,
status=updated_version.submissionStatus,
runs=0,
rating=0.0,
store_listing_version_id=updated_version.id,
changes_summary=changes_summary,
video_url=video_url,
categories=categories,
version=updated_version.version,
)
else:
raise backend.server.v2.store.exceptions.InvalidOperationError(
f"Cannot edit submission with status: {current_version.submissionStatus}"
)
except (
backend.server.v2.store.exceptions.SubmissionNotFoundError,
backend.server.v2.store.exceptions.UnauthorizedError,
backend.server.v2.store.exceptions.AgentNotFoundError,
backend.server.v2.store.exceptions.ListingExistsError,
backend.server.v2.store.exceptions.InvalidOperationError,
):
raise
except prisma.errors.PrismaError as e:
logger.error(f"Database error editing store submission: {e}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to edit store submission"
) from e
async def create_store_version(
user_id: str,
agent_id: str,
@@ -852,7 +696,7 @@ async def create_store_version(
description: str = "",
sub_heading: str = "",
categories: list[str] = [],
changes_summary: str | None = "Initial submission",
changes_summary: str = "Update Submission",
) -> backend.server.v2.store.model.StoreSubmission:
"""
Create a new version for an existing store listing

View File

@@ -94,15 +94,3 @@ class SubmissionNotFoundError(StoreError):
"""Raised when a submission is not found"""
pass
class InvalidOperationError(StoreError):
"""Raised when an operation is not valid for the current state"""
pass
class UnauthorizedError(StoreError):
"""Raised when a user is not authorized to perform an action"""
pass

View File

@@ -33,30 +33,30 @@ async def check_media_exists(user_id: str, filename: str) -> str | None:
if not settings.config.media_gcs_bucket_name:
raise MissingConfigError("GCS media bucket is not configured")
async with async_storage.Storage() as async_client:
bucket_name = settings.config.media_gcs_bucket_name
async_client = async_storage.Storage()
bucket_name = settings.config.media_gcs_bucket_name
# Check images
image_path = f"users/{user_id}/images/{filename}"
try:
await async_client.download_metadata(bucket_name, image_path)
# If we get here, the file exists - construct public URL
return f"https://storage.googleapis.com/{bucket_name}/{image_path}"
except Exception:
# File doesn't exist, continue to check videos
pass
# Check images
image_path = f"users/{user_id}/images/{filename}"
try:
await async_client.download_metadata(bucket_name, image_path)
# If we get here, the file exists - construct public URL
return f"https://storage.googleapis.com/{bucket_name}/{image_path}"
except Exception:
# File doesn't exist, continue to check videos
pass
# Check videos
video_path = f"users/{user_id}/videos/{filename}"
try:
await async_client.download_metadata(bucket_name, video_path)
# If we get here, the file exists - construct public URL
return f"https://storage.googleapis.com/{bucket_name}/{video_path}"
except Exception:
# File doesn't exist
pass
# Check videos
video_path = f"users/{user_id}/videos/{filename}"
try:
await async_client.download_metadata(bucket_name, video_path)
# If we get here, the file exists - construct public URL
return f"https://storage.googleapis.com/{bucket_name}/{video_path}"
except Exception:
# File doesn't exist
pass
return None
return None
async def upload_media(
@@ -177,24 +177,22 @@ async def upload_media(
storage_path = f"users/{user_id}/{media_type}/{unique_filename}"
try:
async with async_storage.Storage() as async_client:
bucket_name = settings.config.media_gcs_bucket_name
async_client = async_storage.Storage()
bucket_name = settings.config.media_gcs_bucket_name
file_bytes = await file.read()
await scan_content_safe(file_bytes, filename=unique_filename)
file_bytes = await file.read()
await scan_content_safe(file_bytes, filename=unique_filename)
# Upload using pure async client
await async_client.upload(
bucket_name, storage_path, file_bytes, content_type=content_type
)
# Upload using pure async client
await async_client.upload(
bucket_name, storage_path, file_bytes, content_type=content_type
)
# Construct public URL
public_url = (
f"https://storage.googleapis.com/{bucket_name}/{storage_path}"
)
# Construct public URL
public_url = f"https://storage.googleapis.com/{bucket_name}/{storage_path}"
logger.info(f"Successfully uploaded file to: {storage_path}")
return public_url
logger.info(f"Successfully uploaded file to: {storage_path}")
return public_url
except Exception as e:
logger.error(f"GCS storage error: {str(e)}")

View File

@@ -26,10 +26,6 @@ def mock_storage_client(mocker):
mock_client = AsyncMock()
mock_client.upload = AsyncMock()
# Mock context manager methods
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
# Mock the constructor to return our mock client
mocker.patch(
"backend.server.v2.store.media.async_storage.Storage", return_value=mock_client

View File

@@ -4,7 +4,7 @@ from typing import List
import prisma.enums
import pydantic
from backend.util.models import Pagination
from backend.server.model import Pagination
class MyAgent(pydantic.BaseModel):
@@ -115,9 +115,11 @@ class StoreSubmission(pydantic.BaseModel):
reviewed_at: datetime.datetime | None = None
changes_summary: str | None = None
# Additional fields for editing
video_url: str | None = None
categories: list[str] = []
reviewer_id: str | None = None
review_comments: str | None = None # External comments visible to creator
internal_comments: str | None = None # Private notes for admin use only
reviewed_at: datetime.datetime | None = None
changes_summary: str | None = None
class StoreSubmissionsResponse(pydantic.BaseModel):
@@ -159,16 +161,6 @@ class StoreSubmissionRequest(pydantic.BaseModel):
changes_summary: str | None = None
class StoreSubmissionEditRequest(pydantic.BaseModel):
name: str
sub_heading: str
video_url: str | None = None
image_urls: list[str] = []
description: str = ""
categories: list[str] = []
changes_summary: str | None = None
class ProfileDetails(pydantic.BaseModel):
name: str
username: str

View File

@@ -564,47 +564,6 @@ async def create_submission(
)
@router.put(
"/submissions/{store_listing_version_id}",
summary="Edit store submission",
tags=["store", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
response_model=backend.server.v2.store.model.StoreSubmission,
)
async def edit_submission(
store_listing_version_id: str,
submission_request: backend.server.v2.store.model.StoreSubmissionEditRequest,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
):
"""
Edit an existing store listing submission.
Args:
store_listing_version_id (str): ID of the store listing version to edit
submission_request (StoreSubmissionRequest): The updated submission details
user_id (str): ID of the authenticated user editing the listing
Returns:
StoreSubmission: The updated store submission
Raises:
HTTPException: If there is an error editing the submission
"""
return await backend.server.v2.store.db.edit_store_submission(
user_id=user_id,
store_listing_version_id=store_listing_version_id,
name=submission_request.name,
video_url=submission_request.video_url,
image_urls=submission_request.image_urls,
description=submission_request.description,
sub_heading=submission_request.sub_heading,
categories=submission_request.categories,
changes_summary=submission_request.changes_summary,
)
@router.post(
"/submissions/media",
summary="Upload submission media",

View File

@@ -551,8 +551,6 @@ def test_get_submissions_success(
agent_version=1,
sub_heading="Test agent subheading",
slug="test-agent",
video_url="test.mp4",
categories=["test-category"],
)
],
pagination=backend.server.v2.store.model.Pagination(

View File

@@ -6,6 +6,7 @@ from typing import Protocol
import pydantic
import uvicorn
from autogpt_libs.auth import parse_jwt_token
from autogpt_libs.logging.utils import generate_uvicorn_config
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
from starlette.middleware.cors import CORSMiddleware
@@ -308,7 +309,7 @@ class WebsocketServer(AppProcess):
server_app,
host=Config().websocket_server_host,
port=Config().websocket_server_port,
log_config=None,
log_config=generate_uvicorn_config(),
)
def cleanup(self):

View File

@@ -1,12 +1,13 @@
from pathlib import Path
from prisma.models import User
from backend.blocks.basic import StoreValueBlock
from backend.blocks.block import BlockInstallationBlock
from backend.blocks.http import SendWebRequestBlock
from backend.blocks.llm import AITextGeneratorBlock
from backend.blocks.text import ExtractTextInformationBlock, FillTextTemplateBlock
from backend.data.graph import Graph, Link, Node, create_graph
from backend.data.model import User
from backend.data.user import get_or_create_user
from backend.util.test import SpinTestServer, wait_execution

View File

@@ -1,8 +1,9 @@
from prisma.models import User
from backend.blocks.llm import AIStructuredResponseGeneratorBlock
from backend.blocks.reddit import GetRedditPostsBlock, PostRedditCommentBlock
from backend.blocks.text import FillTextTemplateBlock, MatchTextPatternBlock
from backend.data.graph import Graph, Link, Node, create_graph
from backend.data.model import User
from backend.data.user import get_or_create_user
from backend.util.test import SpinTestServer, wait_execution

View File

@@ -1,9 +1,10 @@
from prisma.models import User
from backend.blocks.basic import StoreValueBlock
from backend.blocks.io import AgentInputBlock
from backend.blocks.text import FillTextTemplateBlock
from backend.data import graph
from backend.data.graph import create_graph
from backend.data.model import User
from backend.data.user import get_or_create_user
from backend.util.test import SpinTestServer, wait_execution

View File

@@ -2,18 +2,11 @@
Centralized service client helpers with thread caching.
"""
from functools import cache
from typing import TYPE_CHECKING
from autogpt_libs.utils.cache import async_cache, thread_cached
from backend.util.settings import Settings
settings = Settings()
from autogpt_libs.utils.cache import thread_cached
if TYPE_CHECKING:
from supabase import AClient, Client
from backend.data.execution import (
AsyncRedisExecutionEventBus,
RedisExecutionEventBus,
@@ -116,29 +109,6 @@ def get_integration_credentials_store() -> "IntegrationCredentialsStore":
return IntegrationCredentialsStore()
# ============ Supabase Clients ============ #
@cache
def get_supabase() -> "Client":
"""Get a process-cached synchronous Supabase client instance."""
from supabase import create_client
return create_client(
settings.secrets.supabase_url, settings.secrets.supabase_service_role_key
)
@async_cache
async def get_async_supabase() -> "AClient":
"""Get a process-cached asynchronous Supabase client instance."""
from supabase import create_async_client
return await create_async_client(
settings.secrets.supabase_url, settings.secrets.supabase_service_role_key
)
# ============ Notification Queue Helpers ============ #

View File

@@ -46,20 +46,6 @@ class CloudStorageHandler:
self._async_gcs_client = async_gcs_storage.Storage()
return self._async_gcs_client
async def close(self):
"""Close all client connections properly."""
if self._async_gcs_client is not None:
await self._async_gcs_client.close()
self._async_gcs_client = None
async def __aenter__(self):
"""Async context manager entry."""
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
await self.close()
def _get_sync_gcs_client(self):
"""Lazy initialization of sync GCS client (only for signed URLs)."""
if self._sync_gcs_client is None:
@@ -521,17 +507,6 @@ async def get_cloud_storage_handler() -> CloudStorageHandler:
return _cloud_storage_handler
async def shutdown_cloud_storage_handler():
"""Properly shutdown the global cloud storage handler."""
global _cloud_storage_handler
if _cloud_storage_handler is not None:
async with _handler_lock:
if _cloud_storage_handler is not None:
await _cloud_storage_handler.close()
_cloud_storage_handler = None
async def cleanup_expired_files_async() -> int:
"""
Clean up expired files from cloud storage.

View File

@@ -33,33 +33,6 @@ class InsufficientBalanceError(ValueError):
return self.message
class ModerationError(ValueError):
"""Content moderation failure during execution"""
user_id: str
message: str
graph_exec_id: str
moderation_type: str
def __init__(
self,
message: str,
user_id: str,
graph_exec_id: str,
moderation_type: str = "content",
):
super().__init__(message)
self.args = (message, user_id, graph_exec_id, moderation_type)
self.message = message
self.user_id = user_id
self.graph_exec_id = graph_exec_id
self.moderation_type = moderation_type
def __str__(self):
"""Used to display the error message in the frontend, because we str() the error when sending the execution update"""
return self.message
class GraphValidationError(ValueError):
"""Structured validation error for graph validation failures"""
@@ -71,10 +44,4 @@ class GraphValidationError(ValueError):
self.node_errors = node_errors or {}
def __str__(self):
return self.message + "".join(
[
f"\n {node_id}:"
+ "".join([f"\n {k}: {e}" for k, e in errors.items()])
for node_id, errors in self.node_errors.items()
]
)
return self.message

View File

@@ -1,257 +0,0 @@
import contextlib
import logging
from enum import Enum
from functools import wraps
from typing import Any, Awaitable, Callable, TypeVar
import ldclient
from autogpt_libs.utils.cache import async_ttl_cache
from fastapi import HTTPException
from ldclient import Context, LDClient
from ldclient.config import Config
from typing_extensions import ParamSpec
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
# Load settings at module level
settings = Settings()
P = ParamSpec("P")
T = TypeVar("T")
_is_initialized = False
class Flag(str, Enum):
"""
Centralized enum for all LaunchDarkly feature flags.
Add new flags here to ensure consistency across the codebase.
"""
AUTOMOD = "AutoMod"
AI_ACTIVITY_STATUS = "ai-agent-execution-summary"
BETA_BLOCKS = "beta-blocks"
AGENT_ACTIVITY = "agent-activity"
def get_client() -> LDClient:
"""Get the LaunchDarkly client singleton."""
if not _is_initialized:
initialize_launchdarkly()
return ldclient.get()
def initialize_launchdarkly() -> None:
sdk_key = settings.secrets.launch_darkly_sdk_key
logger.debug(
f"Initializing LaunchDarkly with SDK key: {'present' if sdk_key else 'missing'}"
)
if not sdk_key:
logger.warning("LaunchDarkly SDK key not configured")
return
config = Config(sdk_key)
ldclient.set_config(config)
if ldclient.get().is_initialized():
global _is_initialized
_is_initialized = True
logger.info("LaunchDarkly client initialized successfully")
else:
logger.error("LaunchDarkly client failed to initialize")
def shutdown_launchdarkly() -> None:
"""Shutdown the LaunchDarkly client."""
if ldclient.get().is_initialized():
ldclient.get().close()
logger.info("LaunchDarkly client closed successfully")
@async_ttl_cache(maxsize=1000, ttl_seconds=86400) # 1000 entries, 24 hours TTL
async def _fetch_user_context_data(user_id: str) -> Context:
"""
Fetch user context for LaunchDarkly from Supabase.
Args:
user_id: The user ID to fetch data for
Returns:
LaunchDarkly Context object
"""
builder = Context.builder(user_id).kind("user").anonymous(True)
try:
from backend.util.clients import get_supabase
# If we have user data, update context
response = get_supabase().auth.admin.get_user_by_id(user_id)
if response and response.user:
user = response.user
builder.anonymous(False)
if user.role:
builder.set("role", user.role)
# It's weird, I know, but it is what it is.
builder.set("custom", {"role": user.role})
if user.email:
builder.set("email", user.email)
builder.set("email_domain", user.email.split("@")[-1])
except Exception as e:
logger.warning(f"Failed to fetch user context for {user_id}: {e}")
return builder.build()
async def get_feature_flag_value(
flag_key: str,
user_id: str,
default: Any = None,
) -> Any:
"""
Get the raw value of a feature flag for a user.
This is the generic function that returns the actual flag value,
which could be a boolean, string, number, or JSON object.
Args:
flag_key: The LaunchDarkly feature flag key
user_id: The user ID to evaluate the flag for
default: Default value if LaunchDarkly is unavailable or flag evaluation fails
Returns:
The flag value from LaunchDarkly
"""
try:
client = get_client()
# Check if client is initialized
if not client.is_initialized():
logger.debug(
f"LaunchDarkly not initialized, using default={default} for {flag_key}"
)
return default
# Get user context from Supabase
context = await _fetch_user_context_data(user_id)
# Evaluate flag
result = client.variation(flag_key, context, default)
logger.debug(
f"Feature flag {flag_key} for user {user_id}: {result} (type: {type(result).__name__})"
)
return result
except Exception as e:
logger.warning(
f"LaunchDarkly flag evaluation failed for {flag_key}: {e}, using default={default}"
)
return default
async def is_feature_enabled(
flag_key: Flag,
user_id: str,
default: bool = False,
) -> bool:
"""
Check if a feature flag is enabled for a user.
Args:
flag_key: The Flag enum value
user_id: The user ID to evaluate the flag for
default: Default value if LaunchDarkly is unavailable or flag evaluation fails
Returns:
True if feature is enabled, False otherwise
"""
result = await get_feature_flag_value(flag_key.value, user_id, default)
# If the result is already a boolean, return it
if isinstance(result, bool):
return result
# Log a warning if the flag is not returning a boolean
logger.warning(
f"Feature flag {flag_key} returned non-boolean value: {result} (type: {type(result).__name__}). "
f"This flag should be configured as a boolean in LaunchDarkly. Using default={default}"
)
# Return the default if we get a non-boolean value
# This prevents objects from being incorrectly treated as True
return default
def feature_flag(
flag_key: str,
default: bool = False,
) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]:
"""
Decorator for async feature flag protected endpoints.
Args:
flag_key: The LaunchDarkly feature flag key
default: Default value if flag evaluation fails
Returns:
Decorator that only works with async functions
"""
def decorator(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
@wraps(func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
try:
user_id = kwargs.get("user_id")
if not user_id:
raise ValueError("user_id is required")
if not get_client().is_initialized():
logger.warning(
f"LaunchDarkly not initialized, using default={default}"
)
is_enabled = default
else:
# Use the internal function directly since we have a raw string flag_key
flag_value = await get_feature_flag_value(
flag_key, str(user_id), default
)
# Ensure we treat flag value as boolean
if isinstance(flag_value, bool):
is_enabled = flag_value
else:
# Log warning and use default for non-boolean values
logger.warning(
f"Feature flag {flag_key} returned non-boolean value: {flag_value} (type: {type(flag_value).__name__}). "
f"Using default={default}"
)
is_enabled = default
if not is_enabled:
raise HTTPException(status_code=404, detail="Feature not available")
return await func(*args, **kwargs)
except Exception as e:
logger.error(f"Error evaluating feature flag {flag_key}: {e}")
raise
return async_wrapper
return decorator
@contextlib.contextmanager
def mock_flag_variation(flag_key: str, return_value: Any):
"""Context manager for testing feature flags."""
original_variation = get_client().variation
get_client().variation = lambda key, context, default: (
return_value if key == flag_key else original_variation(key, context, default)
)
try:
yield
finally:
get_client().variation = original_variation

View File

@@ -1,113 +0,0 @@
import pytest
from fastapi import HTTPException
from ldclient import LDClient
from backend.util.feature_flag import (
Flag,
feature_flag,
is_feature_enabled,
mock_flag_variation,
)
@pytest.fixture
def ld_client(mocker):
client = mocker.Mock(spec=LDClient)
mocker.patch("ldclient.get", return_value=client)
client.is_initialized.return_value = True
return client
@pytest.mark.asyncio
async def test_feature_flag_enabled(ld_client):
ld_client.variation.return_value = True
@feature_flag("test-flag")
async def test_function(user_id: str):
return "success"
result = await test_function(user_id="test-user")
assert result == "success"
ld_client.variation.assert_called_once()
@pytest.mark.asyncio
async def test_feature_flag_unauthorized_response(ld_client):
ld_client.variation.return_value = False
@feature_flag("test-flag")
async def test_function(user_id: str):
return "success"
with pytest.raises(HTTPException) as exc_info:
await test_function(user_id="test-user")
assert exc_info.value.status_code == 404
def test_mock_flag_variation(ld_client):
with mock_flag_variation("test-flag", True):
assert ld_client.variation("test-flag", None, False) is True
with mock_flag_variation("test-flag", False):
assert ld_client.variation("test-flag", None, True) is False
@pytest.mark.asyncio
async def test_is_feature_enabled(ld_client):
"""Test the is_feature_enabled helper function."""
ld_client.is_initialized.return_value = True
ld_client.variation.return_value = True
result = await is_feature_enabled(Flag.AUTOMOD, "user123", default=False)
assert result is True
ld_client.variation.assert_called_once()
call_args = ld_client.variation.call_args
assert call_args[0][0] == "AutoMod" # flag_key
assert call_args[0][2] is False # default value
@pytest.mark.asyncio
async def test_is_feature_enabled_not_initialized(ld_client):
"""Test is_feature_enabled when LaunchDarkly is not initialized."""
ld_client.is_initialized.return_value = False
result = await is_feature_enabled(Flag.AGENT_ACTIVITY, "user123", default=True)
assert result is True # Should return default
ld_client.variation.assert_not_called()
@pytest.mark.asyncio
async def test_is_feature_enabled_exception(mocker):
"""Test is_feature_enabled when get_client() raises an exception."""
mocker.patch(
"backend.util.feature_flag.get_client",
side_effect=Exception("Client error"),
)
result = await is_feature_enabled(Flag.AGENT_ACTIVITY, "user123", default=True)
assert result is True # Should return default
def test_flag_enum_values():
"""Test that Flag enum has expected values."""
assert Flag.AUTOMOD == "AutoMod"
assert Flag.AI_ACTIVITY_STATUS == "ai-agent-execution-summary"
assert Flag.BETA_BLOCKS == "beta-blocks"
assert Flag.AGENT_ACTIVITY == "agent-activity"
@pytest.mark.asyncio
async def test_is_feature_enabled_with_flag_enum(mocker):
"""Test is_feature_enabled function with Flag enum."""
mock_get_feature_flag_value = mocker.patch(
"backend.util.feature_flag.get_feature_flag_value"
)
mock_get_feature_flag_value.return_value = True
result = await is_feature_enabled(Flag.AUTOMOD, "user123")
assert result is True
# Should call with the flag's string value
mock_get_feature_flag_value.assert_called_once()

View File

@@ -7,16 +7,14 @@ from sentry_sdk.integrations.logging import LoggingIntegration
from backend.util.settings import Settings
settings = Settings()
def sentry_init():
sentry_dsn = settings.secrets.sentry_dsn
sentry_dsn = Settings().secrets.sentry_dsn
sentry_sdk.init(
dsn=sentry_dsn,
traces_sample_rate=1.0,
profiles_sample_rate=1.0,
environment=f"app:{settings.config.app_env.value}-behave:{settings.config.behave_as.value}",
environment=f"app:{Settings().config.app_env.value}-behave:{Settings().config.behave_as.value}",
_experiments={"enable_logs": True},
integrations=[
LoggingIntegration(sentry_logs_level=logging.INFO),
@@ -35,7 +33,9 @@ def sentry_capture_error(error: Exception):
async def discord_send_alert(content: str):
from backend.blocks.discord import SendDiscordMessageBlock
from backend.data.model import APIKeyCredentials, CredentialsMetaInput, ProviderName
from backend.util.settings import Settings
settings = Settings()
creds = APIKeyCredentials(
provider="discord",
api_key=SecretStr(settings.secrets.discord_bot_token),

View File

@@ -1,20 +0,0 @@
"""
Shared models and types used across the backend to avoid circular imports.
"""
import pydantic
class Pagination(pydantic.BaseModel):
total_items: int = pydantic.Field(
description="Total number of items.", examples=[42]
)
total_pages: int = pydantic.Field(
description="Total number of pages.", examples=[2]
)
current_page: int = pydantic.Field(
description="Current_page page number.", examples=[1]
)
page_size: int = pydantic.Field(
description="Number of items per page.", examples=[25]
)

View File

@@ -24,6 +24,7 @@ from typing import (
import httpx
import uvicorn
from autogpt_libs.logging.utils import generate_uvicorn_config
from fastapi import FastAPI, Request, responses
from pydantic import BaseModel, TypeAdapter, create_model
@@ -44,34 +45,6 @@ api_comm_retry = config.pyro_client_comm_retry
api_comm_timeout = config.pyro_client_comm_timeout
api_call_timeout = config.rpc_client_call_timeout
def _validate_no_prisma_objects(obj: Any, path: str = "result") -> None:
"""
Recursively validate that no Prisma objects are being returned from service methods.
This enforces proper separation of layers - only application models should cross service boundaries.
"""
if obj is None:
return
# Check if it's a Prisma model object
if hasattr(obj, "__class__") and hasattr(obj.__class__, "__module__"):
module_name = obj.__class__.__module__
if module_name and "prisma.models" in module_name:
raise ValueError(
f"Prisma object {obj.__class__.__name__} found in {path}. "
"Service methods must return application models, not Prisma objects. "
f"Use {obj.__class__.__name__}.from_db() to convert to application model."
)
# Recursively check collections
if isinstance(obj, (list, tuple)):
for i, item in enumerate(obj):
_validate_no_prisma_objects(item, f"{path}[{i}]")
elif isinstance(obj, dict):
for key, value in obj.items():
_validate_no_prisma_objects(value, f"{path}['{key}']")
P = ParamSpec("P")
R = TypeVar("R")
EXPOSED_FLAG = "__exposed__"
@@ -124,36 +97,6 @@ class RemoteCallError(BaseModel):
args: Optional[Tuple[Any, ...]] = None
class UnhealthyServiceError(ValueError):
def __init__(
self, message: str = "Service is unhealthy or not ready", log: bool = True
):
msg = f"[{get_service_name()}] - {message}"
super().__init__(msg)
self.message = msg
if log:
logger.error(self.message)
def __str__(self):
return self.message
class HTTPClientError(Exception):
"""Exception for HTTP client errors (4xx status codes) that should not be retried."""
def __init__(self, status_code: int, message: str):
self.status_code = status_code
super().__init__(f"HTTP {status_code}: {message}")
class HTTPServerError(Exception):
"""Exception for HTTP server errors (5xx status codes) that can be retried."""
def __init__(self, status_code: int, message: str):
self.status_code = status_code
super().__init__(f"HTTP {status_code}: {message}")
EXCEPTION_MAPPING = {
e.__name__: e
for e in [
@@ -161,9 +104,6 @@ EXCEPTION_MAPPING = {
RuntimeError,
TimeoutError,
ConnectionError,
UnhealthyServiceError,
HTTPClientError,
HTTPServerError,
*[
ErrorType
for _, ErrorType in inspect.getmembers(exceptions)
@@ -236,21 +176,17 @@ class AppService(BaseAppService, ABC):
if asyncio.iscoroutinefunction(f):
async def async_endpoint(body: RequestBodyModel): # type: ignore #RequestBodyModel being variable
result = await f(
return await f(
**{name: getattr(body, name) for name in type(body).model_fields}
)
_validate_no_prisma_objects(result, f"{func.__name__} result")
return result
return async_endpoint
else:
def sync_endpoint(body: RequestBodyModel): # type: ignore #RequestBodyModel being variable
result = f(
return f(
**{name: getattr(body, name) for name in type(body).model_fields}
)
_validate_no_prisma_objects(result, f"{func.__name__} result")
return result
return sync_endpoint
@@ -265,13 +201,13 @@ class AppService(BaseAppService, ABC):
self.fastapi_app,
host=api_host,
port=self.get_port(),
log_config=None, # Explicitly None to avoid uvicorn replacing the logger.
log_config=generate_uvicorn_config(),
log_level=self.log_level,
)
)
self.shared_event_loop.run_until_complete(server.serve())
async def health_check(self) -> str:
def health_check(self) -> str:
"""
A method to check the health of the process.
"""
@@ -362,7 +298,6 @@ def get_service_client(
AttributeError, # Missing attributes
asyncio.CancelledError, # Task was cancelled
concurrent.futures.CancelledError, # Future was cancelled
HTTPClientError, # HTTP 4xx client errors - don't retry
),
)(fn)
@@ -440,31 +375,11 @@ def get_service_client(
self._connection_failure_count = 0
return response.json()
except httpx.HTTPStatusError as e:
status_code = e.response.status_code
# Try to parse the error response as RemoteCallError for mapped exceptions
error_response = None
try:
error_response = RemoteCallError.model_validate(e.response.json())
except Exception:
pass
# If we successfully parsed a mapped exception type, re-raise it
if error_response and error_response.type in EXCEPTION_MAPPING:
exception_class = EXCEPTION_MAPPING[error_response.type]
args = error_response.args or [str(e)]
raise exception_class(*args)
# Otherwise categorize by HTTP status code
if 400 <= status_code < 500:
# Client errors (4xx) - wrap to prevent retries
raise HTTPClientError(status_code, str(e))
elif 500 <= status_code < 600:
# Server errors (5xx) - wrap but allow retries
raise HTTPServerError(status_code, str(e))
else:
# Other status codes (1xx, 2xx, 3xx) - re-raise original error
raise e
error = RemoteCallError.model_validate(e.response.json())
# DEBUG HELP: if you made a custom exception, make sure you override self.args to be how to make your exception
raise EXCEPTION_MAPPING.get(error.type, Exception)(
*(error.args or [str(e)])
)
@_maybe_retry
def _call_method_sync(self, method_name: str, **kwargs: Any) -> Any:
@@ -491,43 +406,11 @@ def get_service_client(
raise
async def aclose(self) -> None:
if hasattr(self, "sync_client"):
self.sync_client.close()
if hasattr(self, "async_client"):
await self.async_client.aclose()
self.sync_client.close()
await self.async_client.aclose()
def close(self) -> None:
if hasattr(self, "sync_client"):
self.sync_client.close()
# Note: Cannot close async client synchronously
def __del__(self):
"""Cleanup HTTP clients on garbage collection to prevent resource leaks."""
try:
if hasattr(self, "sync_client"):
self.sync_client.close()
if hasattr(self, "async_client"):
# Note: Can't await in __del__, so we just close sync
# The async client will be cleaned up by garbage collection
import warnings
warnings.warn(
"DynamicClient async client not explicitly closed. "
"Call aclose() before destroying the client.",
ResourceWarning,
stacklevel=2,
)
except Exception:
# Silently ignore cleanup errors in __del__
pass
async def __aenter__(self):
"""Async context manager entry."""
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
await self.aclose()
self.sync_client.close()
def _get_params(
self, signature: inspect.Signature, *args: Any, **kwargs: Any

View File

@@ -8,8 +8,6 @@ import pytest
from backend.util.service import (
AppService,
AppServiceClient,
HTTPClientError,
HTTPServerError,
endpoint_to_async,
expose,
get_service_client,
@@ -368,125 +366,3 @@ def test_service_no_retry_when_disabled(server):
# This should fail immediately without retry
with pytest.raises(RuntimeError, match="Intended error for testing"):
client.always_failing_add(5, 3)
class TestHTTPErrorRetryBehavior:
"""Test that HTTP client errors (4xx) are not retried but server errors (5xx) can be."""
# Note: These tests access private methods for testing internal behavior
# Type ignore comments are used to suppress warnings about accessing private methods
def test_http_client_error_not_retried(self):
"""Test that 4xx errors are wrapped as HTTPClientError and not retried."""
# Create a mock response with 404 status
mock_response = Mock()
mock_response.status_code = 404
mock_response.json.return_value = {"message": "Not found"}
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
"404 Not Found", request=Mock(), response=mock_response
)
# Create client
client = get_service_client(ServiceTestClient)
dynamic_client = client
# Test the _handle_call_method_response directly
with pytest.raises(HTTPClientError) as exc_info:
dynamic_client._handle_call_method_response( # type: ignore[attr-defined]
response=mock_response, method_name="test_method"
)
assert exc_info.value.status_code == 404
assert "404" in str(exc_info.value)
def test_http_server_error_can_be_retried(self):
"""Test that 5xx errors are wrapped as HTTPServerError and can be retried."""
# Create a mock response with 500 status
mock_response = Mock()
mock_response.status_code = 500
mock_response.json.return_value = {"message": "Internal server error"}
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
"500 Internal Server Error", request=Mock(), response=mock_response
)
# Create client
client = get_service_client(ServiceTestClient)
dynamic_client = client
# Test the _handle_call_method_response directly
with pytest.raises(HTTPServerError) as exc_info:
dynamic_client._handle_call_method_response( # type: ignore[attr-defined]
response=mock_response, method_name="test_method"
)
assert exc_info.value.status_code == 500
assert "500" in str(exc_info.value)
def test_mapped_exception_preserves_original_type(self):
"""Test that mapped exceptions preserve their original type regardless of HTTP status."""
# Create a mock response with ValueError in the remote call error
mock_response = Mock()
mock_response.status_code = 400
mock_response.json.return_value = {
"type": "ValueError",
"args": ["Invalid parameter value"],
}
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
"400 Bad Request", request=Mock(), response=mock_response
)
# Create client
client = get_service_client(ServiceTestClient)
dynamic_client = client
# Test the _handle_call_method_response directly
with pytest.raises(ValueError) as exc_info:
dynamic_client._handle_call_method_response( # type: ignore[attr-defined]
response=mock_response, method_name="test_method"
)
assert "Invalid parameter value" in str(exc_info.value)
def test_client_error_status_codes_coverage(self):
"""Test that various 4xx status codes are all wrapped as HTTPClientError."""
client_error_codes = [400, 401, 403, 404, 405, 409, 422, 429]
for status_code in client_error_codes:
mock_response = Mock()
mock_response.status_code = status_code
mock_response.json.return_value = {"message": f"Error {status_code}"}
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
f"{status_code} Error", request=Mock(), response=mock_response
)
client = get_service_client(ServiceTestClient)
dynamic_client = client
with pytest.raises(HTTPClientError) as exc_info:
dynamic_client._handle_call_method_response( # type: ignore
response=mock_response, method_name="test_method"
)
assert exc_info.value.status_code == status_code
def test_server_error_status_codes_coverage(self):
"""Test that various 5xx status codes are all wrapped as HTTPServerError."""
server_error_codes = [500, 501, 502, 503, 504, 505]
for status_code in server_error_codes:
mock_response = Mock()
mock_response.status_code = status_code
mock_response.json.return_value = {"message": f"Error {status_code}"}
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
f"{status_code} Error", request=Mock(), response=mock_response
)
client = get_service_client(ServiceTestClient)
dynamic_client = client
with pytest.raises(HTTPServerError) as exc_info:
dynamic_client._handle_call_method_response( # type: ignore
response=mock_response, method_name="test_method"
)
assert exc_info.value.status_code == status_code

View File

@@ -295,32 +295,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="Maximum file size in MB for file uploads (1-1024 MB)",
)
# AutoMod configuration
automod_enabled: bool = Field(
default=False,
description="Whether AutoMod content moderation is enabled",
)
automod_api_url: str = Field(
default="",
description="AutoMod API base URL - Make sure it ends in /api",
)
automod_timeout: int = Field(
default=30,
description="Timeout in seconds for AutoMod API requests",
)
automod_retry_attempts: int = Field(
default=3,
description="Number of retry attempts for AutoMod API requests",
)
automod_retry_delay: float = Field(
default=1.0,
description="Delay between retries for AutoMod API requests in seconds",
)
automod_fail_open: bool = Field(
default=False,
description="If True, allow execution to continue if AutoMod fails",
)
@field_validator("platform_base_url", "frontend_base_url")
@classmethod
def validate_platform_base_url(cls, v: str, info: ValidationInfo) -> str:
@@ -360,7 +334,7 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="Maximum message size limit for communication with the message bus",
)
backend_cors_allow_origins: List[str] = Field(default=["http://localhost:3000"])
backend_cors_allow_origins: List[str] = Field(default_factory=list)
@field_validator("backend_cors_allow_origins")
@classmethod
@@ -472,7 +446,6 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
groq_api_key: str = Field(default="", description="Groq API key")
open_router_api_key: str = Field(default="", description="Open Router API Key")
llama_api_key: str = Field(default="", description="Llama API Key")
v0_api_key: str = Field(default="", description="v0 by Vercel API key")
reddit_client_id: str = Field(default="", description="Reddit client ID")
reddit_client_secret: str = Field(default="", description="Reddit client secret")
@@ -522,20 +495,10 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
apollo_api_key: str = Field(default="", description="Apollo API Key")
smartlead_api_key: str = Field(default="", description="SmartLead API Key")
zerobounce_api_key: str = Field(default="", description="ZeroBounce API Key")
enrichlayer_api_key: str = Field(default="", description="Enrichlayer API Key")
# AutoMod API credentials
automod_api_key: str = Field(default="", description="AutoMod API key")
# LaunchDarkly feature flags
launch_darkly_sdk_key: str = Field(
default="",
description="The LaunchDarkly SDK key for feature flag management",
)
ayrshare_api_key: str = Field(default="", description="Ayrshare API Key")
ayrshare_jwt_key: str = Field(default="", description="Ayrshare private Key")
# Add more secret fields as needed
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",

View File

@@ -1,41 +0,0 @@
-- Drop the existing view
DROP VIEW IF EXISTS "StoreSubmission";
-- Recreate the view with the new fields
CREATE VIEW "StoreSubmission" AS
SELECT
sl.id AS listing_id,
sl."owningUserId" AS user_id,
slv."agentGraphId" AS agent_id,
slv.version AS agent_version,
sl.slug,
COALESCE(slv.name, '') AS name,
slv."subHeading" AS sub_heading,
slv.description,
slv."imageUrls" AS image_urls,
slv."submittedAt" AS date_submitted,
slv."submissionStatus" AS status,
COALESCE(ar.run_count, 0::bigint) AS runs,
COALESCE(avg(sr.score::numeric), 0.0)::double precision AS rating,
slv.id AS store_listing_version_id,
slv."reviewerId" AS reviewer_id,
slv."reviewComments" AS review_comments,
slv."internalComments" AS internal_comments,
slv."reviewedAt" AS reviewed_at,
slv."changesSummary" AS changes_summary,
-- Add the two new fields:
slv."videoUrl" AS video_url,
slv.categories
FROM "StoreListing" sl
JOIN "StoreListingVersion" slv ON slv."storeListingId" = sl.id
LEFT JOIN "StoreListingReview" sr ON sr."storeListingVersionId" = slv.id
LEFT JOIN (
SELECT "AgentGraphExecution"."agentGraphId", count(*) AS run_count
FROM "AgentGraphExecution"
GROUP BY "AgentGraphExecution"."agentGraphId"
) ar ON ar."agentGraphId" = slv."agentGraphId"
WHERE sl."isDeleted" = false
GROUP BY sl.id, sl."owningUserId", slv.id, slv."agentGraphId", slv.version, sl.slug, slv.name,
slv."subHeading", slv.description, slv."imageUrls", slv."submittedAt",
slv."submissionStatus", slv."reviewerId", slv."reviewComments", slv."internalComments",
slv."reviewedAt", slv."changesSummary", slv."videoUrl", slv.categories, ar.run_count;

View File

@@ -1079,25 +1079,6 @@ files = [
dnspython = ">=2.0.0"
idna = ">=2.0.0"
[[package]]
name = "exa-py"
version = "1.14.20"
description = "Python SDK for Exa API."
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "exa_py-1.14.20-py3-none-any.whl", hash = "sha256:e0ed9d99c3c494a0e6903e11a0f6fb773b3b23d0cd802380cf58efc97d9d332d"},
{file = "exa_py-1.14.20.tar.gz", hash = "sha256:423789a0635b7a4ecd5f56d6b4a0dfb01126fa45ce1e04106c0bb96b7d551ebf"},
]
[package.dependencies]
httpx = ">=0.28.1"
openai = ">=1.48"
pydantic = ">=2.10.6"
requests = ">=2.32.3"
typing-extensions = ">=4.12.2"
[[package]]
name = "exceptiongroup"
version = "1.3.0"
@@ -6737,4 +6718,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.13"
content-hash = "795414d7ce8f288ea6c65893268b5c29a7c9a60ad75cde28ac7bcdb65f426dfe"
content-hash = "225ddae645d22cc57f46330e735c069fb52e708123aa642e74adbf077dda0796"

View File

@@ -10,7 +10,6 @@ packages = [{ include = "backend", format = "sdist" }]
[tool.poetry.dependencies]
python = ">=3.10,<3.13"
aio-pika = "^9.5.5"
aiohttp = "^3.10.0"
aiodns = "^3.5.0"
anthropic = "^0.59.0"
apscheduler = "^3.11.0"
@@ -76,7 +75,6 @@ setuptools = "^80.9.0"
gcloud-aio-storage = "^9.5.0"
pandas = "^2.3.1"
firecrawl-py = "^2.16.3"
exa-py = "^1.14.20"
[tool.poetry.group.dev.dependencies]
aiohappyeyeballs = "^2.6.1"
@@ -103,7 +101,6 @@ rest = "backend.rest:main"
db = "backend.db:main"
ws = "backend.ws:main"
scheduler = "backend.scheduler:main"
notification = "backend.notification:main"
executor = "backend.exec:main"
cli = "backend.cli:main"
format = "linter:format"

Some files were not shown because too many files have changed in this diff Show More