Compare commits


4 Commits

Author SHA1 Message Date
Andy Hooker
6fca2352bb feat(build): Add undo/redo functionality and integrate custom nodes
Introduce a `useUndoRedo` hook to manage state history with undo/redo, persistence using localStorage, and state validation. Updated the build page to display a flow with initial custom nodes and edges, replacing the previous flow editor implementation.
2025-02-22 17:17:46 -06:00
Andy Hooker
84759053a7 feat(build): Add undo/redo functionality and integrate custom nodes
Introduce a `useUndoRedo` hook to manage state history with undo/redo, persistence using localStorage, and state validation. Updated the build page to display a flow with initial custom nodes and edges, replacing the previous flow editor implementation.
2025-02-22 17:17:35 -06:00
Andy Hooker
4afe724628 feat(build): Add BuildFlow component for editable flowchart canvas
Introduces the `BuildFlow` component with undo/redo, drag-and-drop, and connection functionality using the ReactFlow library. This implementation supports state management for nodes and edges and integrates a control panel for user actions like undo, redo, and reset. It also includes a read-only mode for non-editable use cases.
2025-02-22 17:17:15 -06:00
Andy Hooker
c178a537b7 feat(build): Add custom node, edge components, and canvas mapping hook
Introduce `BuildCustomNode` and `BuildCustomEdge` components for enhanced React Flow visualizations, enabling node focus and styled edges. Implement `useCanvasMapping` hook to map and enhance nodes and edges dynamically with custom labels and styles.
2025-02-22 17:16:16 -06:00
439 changed files with 11043 additions and 27141 deletions
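The recurring theme in these commits is an undo/redo history with localStorage persistence. Below is a rough, language-agnostic sketch of that pattern, shown in Python purely for illustration; the actual `useUndoRedo` hook is a React/TypeScript implementation and none of the names below are taken from it.

```python
import json
from pathlib import Path
from typing import Any


class HistoryStore:
    """Past/present/future snapshots for undo/redo, persisted to a JSON file
    (a stand-in for the hook's localStorage persistence). States must be
    JSON-serializable."""

    def __init__(self, initial: Any, path: str = "flow_history.json"):
        self.past: list[Any] = []
        self.present: Any = initial
        self.future: list[Any] = []
        self.path = Path(path)

    def set(self, state: Any) -> None:
        # Recording a new state invalidates the redo stack.
        self.past.append(self.present)
        self.present = state
        self.future.clear()
        self._persist()

    def undo(self) -> Any:
        if self.past:
            self.future.append(self.present)
            self.present = self.past.pop()
            self._persist()
        return self.present

    def redo(self) -> Any:
        if self.future:
            self.past.append(self.present)
            self.present = self.future.pop()
            self._persist()
        return self.present

    def _persist(self) -> None:
        self.path.write_text(json.dumps(
            {"past": self.past, "present": self.present, "future": self.future}
        ))


# Example usage with a flow-like state:
history = HistoryStore(initial={"nodes": [], "edges": []})
history.set({"nodes": ["a"], "edges": []})
history.undo()  # back to the initial state
history.redo()  # forward again
```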

View File

@@ -129,6 +129,30 @@ updates:
- "minor"
- "patch"
# Submodules
- package-ecosystem: "gitsubmodule"
directory: "autogpt_platform/supabase"
schedule:
interval: "weekly"
open-pull-requests-limit: 1
target-branch: "dev"
commit-message:
prefix: "chore(platform/deps)"
prefix-development: "chore(platform/deps-dev)"
groups:
production-dependencies:
dependency-type: "production"
update-types:
- "minor"
- "patch"
development-dependencies:
dependency-type: "development"
update-types:
- "minor"
- "patch"
# Docs
- package-ecosystem: 'pip'
directory: "docs/"

View File

@@ -115,7 +115,6 @@ jobs:
poetry run pytest -vv \
--cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
--numprocesses=logical --durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
tests/unit tests/integration
env:
CI: true
@@ -125,14 +124,8 @@ jobs:
AWS_ACCESS_KEY_ID: minioadmin
AWS_SECRET_ACCESS_KEY: minioadmin
- name: Upload test results to Codecov
if: ${{ !cancelled() }} # Run even if tests fail
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }}
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: autogpt-agent,${{ runner.os }}

View File

@@ -87,20 +87,13 @@ jobs:
poetry run pytest -vv \
--cov=agbenchmark --cov-branch --cov-report term-missing --cov-report xml \
--durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
tests
env:
CI: true
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- name: Upload test results to Codecov
if: ${{ !cancelled() }} # Run even if tests fail
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }}
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: agbenchmark,${{ runner.os }}

View File

@@ -139,7 +139,6 @@ jobs:
poetry run pytest -vv \
--cov=forge --cov-branch --cov-report term-missing --cov-report xml \
--durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
forge
env:
CI: true
@@ -149,14 +148,8 @@ jobs:
AWS_ACCESS_KEY_ID: minioadmin
AWS_SECRET_ACCESS_KEY: minioadmin
- name: Upload test results to Codecov
if: ${{ !cancelled() }} # Run even if tests fail
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }}
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: forge,${{ runner.os }}

View File

@@ -34,7 +34,6 @@ jobs:
python -m prisma migrate deploy
env:
DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}
trigger:

View File

@@ -36,7 +36,6 @@ jobs:
python -m prisma migrate deploy
env:
DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}
trigger:
needs: migrate

View File

@@ -66,7 +66,7 @@ jobs:
- name: Setup Supabase
uses: supabase/setup-cli@v1
with:
version: 1.178.1
version: latest
- id: get_date
name: Get date
@@ -80,35 +80,18 @@ jobs:
- name: Install Poetry (Unix)
run: |
# Extract Poetry version from backend/poetry.lock
HEAD_POETRY_VERSION=$(head -n 1 poetry.lock | grep -oP '(?<=Poetry )[0-9]+\.[0-9]+\.[0-9]+')
echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"
if [ -n "$BASE_REF" ]; then
BASE_BRANCH=${BASE_REF/refs\/heads\//}
BASE_POETRY_VERSION=$((git show "origin/$BASE_BRANCH":./poetry.lock; true) | head -n 1 | grep -oP '(?<=Poetry )[0-9]+\.[0-9]+\.[0-9]+')
echo "Found Poetry version ${BASE_POETRY_VERSION} in backend/poetry.lock on ${BASE_REF}"
POETRY_VERSION=$(printf '%s\n' "$HEAD_POETRY_VERSION" "$BASE_POETRY_VERSION" | sort -V | tail -n1)
else
POETRY_VERSION=$HEAD_POETRY_VERSION
fi
echo "Using Poetry version ${POETRY_VERSION}"
# Install Poetry
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$POETRY_VERSION python3 -
curl -sSL https://install.python-poetry.org | python3 -
if [ "${{ runner.os }}" = "macOS" ]; then
PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi
env:
BASE_REF: ${{ github.base_ref || github.event.merge_group.base_ref }}
- name: Check poetry.lock
run: |
poetry lock
if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
if ! git diff --quiet poetry.lock; then
echo "Error: poetry.lock not up to date."
echo
git diff poetry.lock
@@ -135,7 +118,6 @@ jobs:
run: poetry run prisma migrate dev --name updates
env:
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
- id: lint
name: Run Linter
@@ -152,13 +134,12 @@ jobs:
env:
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
REDIS_HOST: "localhost"
REDIS_PORT: "6379"
REDIS_PASSWORD: "testpassword"
REDIS_HOST: 'localhost'
REDIS_PORT: '6379'
REDIS_PASSWORD: 'testpassword'
env:
CI: true
@@ -171,8 +152,8 @@ jobs:
# If you want to replace this, you can do so by making our entire system generate
# new credentials for each local user and update the environment variables in
# the backend service, docker composes, and examples
RABBITMQ_DEFAULT_USER: "rabbitmq_user_default"
RABBITMQ_DEFAULT_PASS: "k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7"
RABBITMQ_DEFAULT_USER: 'rabbitmq_user_default'
RABBITMQ_DEFAULT_PASS: 'k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7'
# - name: Upload coverage reports to Codecov
# uses: codecov/codecov-action@v4

View File

@@ -56,30 +56,6 @@ jobs:
run: |
yarn type-check
design:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "21"
- name: Install dependencies
run: |
yarn install --frozen-lockfile
- name: Run Chromatic
uses: chromaui/action@latest
with:
# ⚠️ Make sure to configure a `CHROMATIC_PROJECT_TOKEN` repository secret
projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
workingDir: autogpt_platform/frontend
test:
runs-on: ubuntu-latest
strategy:
@@ -106,7 +82,7 @@ jobs:
- name: Copy default supabase .env
run: |
cp ../.env.example ../.env
cp ../supabase/docker/.env.example ../.env
- name: Copy backend .env
run: |

.gitmodules
View File

@@ -1,3 +1,6 @@
[submodule "classic/forge/tests/vcr_cassettes"]
path = classic/forge/tests/vcr_cassettes
url = https://github.com/Significant-Gravitas/Auto-GPT-test-cassettes
[submodule "autogpt_platform/supabase"]
path = autogpt_platform/supabase
url = https://github.com/supabase/supabase.git

View File

@@ -140,7 +140,7 @@ repos:
language: system
- repo: https://github.com/psf/black
rev: 24.10.0
rev: 23.12.1
# Black has sensible defaults, doesn't need package context, and ignores
# everything in .gitignore, so it works fine without any config or arguments.
hooks:

View File

@@ -2,6 +2,9 @@
If you are reading this, you are probably looking for the full **[contribution guide]**,
which is part of our [wiki].
Also check out our [🚀 Roadmap][roadmap] for information about our priorities and associated tasks.
<!-- You can find our immediate priorities and their progress on our public [kanban board]. -->
[contribution guide]: https://github.com/Significant-Gravitas/AutoGPT/wiki/Contributing
[wiki]: https://github.com/Significant-Gravitas/AutoGPT/wiki
[roadmap]: https://github.com/Significant-Gravitas/AutoGPT/discussions/6971

View File

@@ -15,11 +15,7 @@
> Setting up and hosting the AutoGPT Platform yourself is a technical process.
> If you'd rather something that just works, we recommend [joining the waitlist](https://bit.ly/3ZDijAI) for the cloud-hosted beta.
### Updated Setup Instructions:
We've moved to a fully maintained and regularly updated documentation site.
👉 [Follow the official self-hosting guide here](https://docs.agpt.co/platform/getting-started/)
https://github.com/user-attachments/assets/d04273a5-b36a-4a37-818e-f631ce72d603
This tutorial assumes you have Docker, VSCode, git and npm installed.

View File

@@ -20,7 +20,6 @@ Instead, please report them via:
- Please provide detailed reports with reproducible steps
- Include the version/commit hash where you discovered the vulnerability
- Allow us a 90-day security fix window before any public disclosure
- After a patch is released, allow 30 days for users to update before public disclosure (at most 120 days in total between the initial report and public disclosure)
- Share any potential mitigations or workarounds if known
## Supported Versions

View File

@@ -1,123 +0,0 @@
############
# Secrets
# YOU MUST CHANGE THESE BEFORE GOING INTO PRODUCTION
############
POSTGRES_PASSWORD=your-super-secret-and-long-postgres-password
JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
ANON_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJhbm9uIiwKICAgICJpc3MiOiAic3VwYWJhc2UtZGVtbyIsCiAgICAiaWF0IjogMTY0MTc2OTIwMCwKICAgICJleHAiOiAxNzk5NTM1NjAwCn0.dc_X5iR_VP_qT0zsiyj_I_OZ2T9FtRU2BBNWN8Bu4GE
SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
DASHBOARD_USERNAME=supabase
DASHBOARD_PASSWORD=this_password_is_insecure_and_should_be_updated
SECRET_KEY_BASE=UpNVntn3cDxHJpq99YMc1T1AQgQpc8kfYTuRgBiYa15BLrx8etQoXz3gZv1/u2oq
VAULT_ENC_KEY=your-encryption-key-32-chars-min
############
# Database - You can change these to any PostgreSQL database that has logical replication enabled.
############
POSTGRES_HOST=db
POSTGRES_DB=postgres
POSTGRES_PORT=5432
# default user is postgres
############
# Supavisor -- Database pooler
############
POOLER_PROXY_PORT_TRANSACTION=6543
POOLER_DEFAULT_POOL_SIZE=20
POOLER_MAX_CLIENT_CONN=100
POOLER_TENANT_ID=your-tenant-id
############
# API Proxy - Configuration for the Kong Reverse proxy.
############
KONG_HTTP_PORT=8000
KONG_HTTPS_PORT=8443
############
# API - Configuration for PostgREST.
############
PGRST_DB_SCHEMAS=public,storage,graphql_public
############
# Auth - Configuration for the GoTrue authentication server.
############
## General
SITE_URL=http://localhost:3000
ADDITIONAL_REDIRECT_URLS=
JWT_EXPIRY=3600
DISABLE_SIGNUP=false
API_EXTERNAL_URL=http://localhost:8000
## Mailer Config
MAILER_URLPATHS_CONFIRMATION="/auth/v1/verify"
MAILER_URLPATHS_INVITE="/auth/v1/verify"
MAILER_URLPATHS_RECOVERY="/auth/v1/verify"
MAILER_URLPATHS_EMAIL_CHANGE="/auth/v1/verify"
## Email auth
ENABLE_EMAIL_SIGNUP=true
ENABLE_EMAIL_AUTOCONFIRM=false
SMTP_ADMIN_EMAIL=admin@example.com
SMTP_HOST=supabase-mail
SMTP_PORT=2500
SMTP_USER=fake_mail_user
SMTP_PASS=fake_mail_password
SMTP_SENDER_NAME=fake_sender
ENABLE_ANONYMOUS_USERS=false
## Phone auth
ENABLE_PHONE_SIGNUP=true
ENABLE_PHONE_AUTOCONFIRM=true
############
# Studio - Configuration for the Dashboard
############
STUDIO_DEFAULT_ORGANIZATION=Default Organization
STUDIO_DEFAULT_PROJECT=Default Project
STUDIO_PORT=3000
# replace if you intend to use Studio outside of localhost
SUPABASE_PUBLIC_URL=http://localhost:8000
# Enable webp support
IMGPROXY_ENABLE_WEBP_DETECTION=true
# Add your OpenAI API key to enable SQL Editor Assistant
OPENAI_API_KEY=
############
# Functions - Configuration for Functions
############
# NOTE: VERIFY_JWT applies to all functions. Per-function VERIFY_JWT is not supported yet.
FUNCTIONS_VERIFY_JWT=false
############
# Logs - Configuration for Logflare
# Please refer to https://supabase.com/docs/reference/self-hosting-analytics/introduction
############
LOGFLARE_LOGGER_BACKEND_API_KEY=your-super-secret-and-long-logflare-key
# Change vector.toml sinks to reflect this change
LOGFLARE_API_KEY=your-super-secret-and-long-logflare-key
# Docker socket location - this value will differ depending on your OS
DOCKER_SOCKET_LOCATION=/var/run/docker.sock
# Google Cloud Project details
GOOGLE_PROJECT_ID=GOOGLE_PROJECT_ID
GOOGLE_PROJECT_NUMBER=GOOGLE_PROJECT_NUMBER

View File

@@ -22,29 +22,35 @@ To run the AutoGPT Platform, follow these steps:
2. Run the following command:
```
cp .env.example .env
git submodule update --init --recursive --progress
```
This command will copy the `.env.example` file to `.env`. You can modify the `.env` file to add your own environment variables.
This command will initialize and update the submodules in the repository. The `supabase` folder will be cloned to the root directory.
3. Run the following command:
```
cp supabase/docker/.env.example .env
```
This command will copy the `.env.example` file to `.env` in the `supabase/docker` directory. You can modify the `.env` file to add your own environment variables.
4. Run the following command:
```
docker compose up -d
```
This command will start all the necessary backend services defined in the `docker-compose.yml` file in detached mode.
4. Navigate to `frontend` within the `autogpt_platform` directory:
5. Navigate to `frontend` within the `autogpt_platform` directory:
```
cd frontend
```
You will need to run your frontend application separately on your local machine.
5. Run the following command:
6. Run the following command:
```
cp .env.example .env.local
```
This command will copy the `.env.example` file to `.env.local` in the `frontend` directory. You can modify the `.env.local` within this folder to add your own environment variables for the frontend application.
6. Run the following command:
7. Run the following command:
```
npm install
npm run dev
@@ -55,7 +61,7 @@ To run the AutoGPT Platform, follow these steps:
yarn install && yarn dev
```
7. Open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
8. Open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
### Docker Compose Commands

View File

@@ -1,13 +1,14 @@
from .config import Settings
from .depends import requires_admin_user, requires_user
from .jwt_utils import parse_jwt_token
from .middleware import APIKeyValidator, auth_middleware
from .middleware import auth_middleware
from .models import User
__all__ = [
"Settings",
"parse_jwt_token",
"requires_user",
"requires_admin_user",
"APIKeyValidator",
"auth_middleware",
"User",
]

View File

@@ -1,11 +1,14 @@
import os
from dotenv import load_dotenv
load_dotenv()
class Settings:
def __init__(self):
self.JWT_SECRET_KEY: str = os.getenv("SUPABASE_JWT_SECRET", "")
self.ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
self.JWT_ALGORITHM: str = "HS256"
JWT_SECRET_KEY: str = os.getenv("SUPABASE_JWT_SECRET", "")
ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
JWT_ALGORITHM: str = "HS256"
@property
def is_configured(self) -> bool:

View File

@@ -1,6 +1,6 @@
import fastapi
from .config import settings
from .config import Settings
from .middleware import auth_middleware
from .models import DEFAULT_USER_ID, User
@@ -17,7 +17,7 @@ def requires_admin_user(
def verify_user(payload: dict | None, admin_only: bool) -> User:
if not payload:
if settings.ENABLE_AUTH:
if Settings.ENABLE_AUTH:
raise fastapi.HTTPException(
status_code=401, detail="Authorization header is missing"
)

View File

@@ -1,10 +1,7 @@
import inspect
import logging
from typing import Any, Callable, Optional
from fastapi import HTTPException, Request, Security
from fastapi.security import APIKeyHeader, HTTPBearer
from starlette.status import HTTP_401_UNAUTHORIZED
from fastapi import HTTPException, Request
from fastapi.security import HTTPBearer
from .config import settings
from .jwt_utils import parse_jwt_token
@@ -32,104 +29,3 @@ async def auth_middleware(request: Request):
except ValueError as e:
raise HTTPException(status_code=401, detail=str(e))
return payload
class APIKeyValidator:
"""
Configurable API key validator that supports custom validation functions
for FastAPI applications.
This class provides a flexible way to implement API key authentication with optional
custom validation logic. It can be used for simple token matching
or more complex validation scenarios like database lookups.
Examples:
Simple token validation:
```python
validator = APIKeyValidator(
header_name="X-API-Key",
expected_token="your-secret-token"
)
@app.get("/protected", dependencies=[Depends(validator.get_dependency())])
def protected_endpoint():
return {"message": "Access granted"}
```
Custom validation with database lookup:
```python
async def validate_with_db(api_key: str):
api_key_obj = await db.get_api_key(api_key)
return api_key_obj if api_key_obj and api_key_obj.is_active else None
validator = APIKeyValidator(
header_name="X-API-Key",
validate_fn=validate_with_db
)
```
Args:
header_name (str): The name of the header containing the API key
expected_token (Optional[str]): The expected API key value for simple token matching
validate_fn (Optional[Callable]): Custom validation function that takes an API key
string and returns a boolean or object. Can be async.
error_status (int): HTTP status code to use for validation errors
error_message (str): Error message to return when validation fails
"""
def __init__(
self,
header_name: str,
expected_token: Optional[str] = None,
validate_fn: Optional[Callable[[str], bool]] = None,
error_status: int = HTTP_401_UNAUTHORIZED,
error_message: str = "Invalid API key",
):
# Create the APIKeyHeader as a class property
self.security_scheme = APIKeyHeader(name=header_name)
self.expected_token = expected_token
self.custom_validate_fn = validate_fn
self.error_status = error_status
self.error_message = error_message
async def default_validator(self, api_key: str) -> bool:
return api_key == self.expected_token
async def __call__(
self, request: Request, api_key: str = Security(APIKeyHeader)
) -> Any:
if api_key is None:
raise HTTPException(status_code=self.error_status, detail="Missing API key")
# Use custom validation if provided, otherwise use default equality check
validator = self.custom_validate_fn or self.default_validator
result = (
await validator(api_key)
if inspect.iscoroutinefunction(validator)
else validator(api_key)
)
if not result:
raise HTTPException(
status_code=self.error_status, detail=self.error_message
)
# Store validation result in request state if it's not just a boolean
if result is not True:
request.state.api_key = result
return result
def get_dependency(self):
"""
Returns a callable dependency that FastAPI will recognize as a security scheme
"""
async def validate_api_key(
request: Request, api_key: str = Security(self.security_scheme)
) -> Any:
return await self(request, api_key)
# This helps FastAPI recognize it as a security dependency
validate_api_key.__name__ = f"validate_{self.security_scheme.model.name}"
return validate_api_key

View File

@@ -8,7 +8,7 @@ from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from .filters import BelowLevelFilter
from .formatters import AGPTFormatter
from .formatters import AGPTFormatter, StructuredLoggingFormatter
LOG_DIR = Path(__file__).parent.parent.parent.parent / "logs"
LOG_FILE = "activity.log"
@@ -81,26 +81,9 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
"""
config = LoggingConfig()
log_handlers: list[logging.Handler] = []
# Console output handlers
stdout = logging.StreamHandler(stream=sys.stdout)
stdout.setLevel(config.level)
stdout.addFilter(BelowLevelFilter(logging.WARNING))
if config.level == logging.DEBUG:
stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
else:
stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
stderr = logging.StreamHandler()
stderr.setLevel(logging.WARNING)
if config.level == logging.DEBUG:
stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
else:
stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
log_handlers += [stdout, stderr]
# Cloud logging setup
if config.enable_cloud_logging or force_cloud_logging:
import google.cloud.logging
@@ -114,7 +97,28 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
transport=SyncTransport,
)
cloud_handler.setLevel(config.level)
cloud_handler.setFormatter(StructuredLoggingFormatter())
log_handlers.append(cloud_handler)
print("Cloud logging enabled")
else:
# Console output handlers
stdout = logging.StreamHandler(stream=sys.stdout)
stdout.setLevel(config.level)
stdout.addFilter(BelowLevelFilter(logging.WARNING))
if config.level == logging.DEBUG:
stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
else:
stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
stderr = logging.StreamHandler()
stderr.setLevel(logging.WARNING)
if config.level == logging.DEBUG:
stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
else:
stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
log_handlers += [stdout, stderr]
print("Console logging enabled")
# File logging setup
if config.enable_file_logging:
@@ -152,6 +156,7 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
error_log_handler.setLevel(logging.ERROR)
error_log_handler.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT, no_color=True))
log_handlers.append(error_log_handler)
print("File logging enabled")
# Configure the root logger
logging.basicConfig(

View File

@@ -1,6 +1,7 @@
import logging
from colorama import Fore, Style
from google.cloud.logging_v2.handlers import CloudLoggingFilter, StructuredLogHandler
from .utils import remove_color_codes
@@ -79,3 +80,16 @@ class AGPTFormatter(FancyConsoleFormatter):
return remove_color_codes(super().format(record))
else:
return super().format(record)
class StructuredLoggingFormatter(StructuredLogHandler, logging.Formatter):
def __init__(self):
# Set up CloudLoggingFilter to add diagnostic info to the log records
self.cloud_logging_filter = CloudLoggingFilter()
# Init StructuredLogHandler
super().__init__()
def format(self, record: logging.LogRecord) -> str:
self.cloud_logging_filter.filter(record)
return super().format(record)

View File

@@ -2,7 +2,6 @@ import logging
import re
from typing import Any
import uvicorn.config
from colorama import Fore
@@ -26,14 +25,3 @@ def print_attribute(
"color": value_color,
},
)
def generate_uvicorn_config():
"""
Generates a uvicorn logging config that silences uvicorn's default logging and tells it to use the native logging module.
"""
log_config = dict(uvicorn.config.LOGGING_CONFIG)
log_config["loggers"]["uvicorn"] = {"handlers": []}
log_config["loggers"]["uvicorn.error"] = {"handlers": []}
log_config["loggers"]["uvicorn.access"] = {"handlers": []}
return log_config

View File

@@ -1,59 +1,20 @@
import inspect
import threading
from typing import Awaitable, Callable, ParamSpec, TypeVar, cast, overload
from typing import Callable, ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
@overload
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...
@overload
def thread_cached(func: Callable[P, R]) -> Callable[P, R]: ...
def thread_cached(
func: Callable[P, R] | Callable[P, Awaitable[R]],
) -> Callable[P, R] | Callable[P, Awaitable[R]]:
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
thread_local = threading.local()
def _clear():
if hasattr(thread_local, "cache"):
del thread_local.cache
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = (args, tuple(sorted(kwargs.items())))
if key not in cache:
cache[key] = func(*args, **kwargs)
return cache[key]
if inspect.iscoroutinefunction(func):
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = (args, tuple(sorted(kwargs.items())))
if key not in cache:
cache[key] = await cast(Callable[P, Awaitable[R]], func)(
*args, **kwargs
)
return cache[key]
setattr(async_wrapper, "clear_cache", _clear)
return async_wrapper
else:
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = (args, tuple(sorted(kwargs.items())))
if key not in cache:
cache[key] = func(*args, **kwargs)
return cache[key]
setattr(sync_wrapper, "clear_cache", _clear)
return sync_wrapper
def clear_thread_cache(func: Callable) -> None:
if clear := getattr(func, "clear_cache", None):
clear()
return wrapper

File diff suppressed because it is too large

View File

@@ -10,17 +10,18 @@ packages = [{ include = "autogpt_libs" }]
colorama = "^0.4.6"
expiringdict = "^1.2.2"
google-cloud-logging = "^3.11.4"
pydantic = "^2.11.1"
pydantic-settings = "^2.8.1"
pydantic = "^2.10.6"
pydantic-settings = "^2.7.1"
pyjwt = "^2.10.1"
pytest-asyncio = "^0.26.0"
pytest-asyncio = "^0.25.3"
pytest-mock = "^3.14.0"
python = ">=3.10,<4.0"
supabase = "^2.15.0"
python-dotenv = "^1.0.1"
supabase = "^2.13.0"
[tool.poetry.group.dev.dependencies]
redis = "^5.2.1"
ruff = "^0.11.0"
ruff = "^0.9.3"
[build-system]
requires = ["poetry-core"]

View File

@@ -2,24 +2,13 @@ DB_USER=postgres
DB_PASS=your-super-secret-and-long-postgres-password
DB_NAME=postgres
DB_PORT=5432
DB_HOST=localhost
DB_CONNECTION_LIMIT=12
DB_CONNECT_TIMEOUT=60
DB_POOL_TIMEOUT=300
DB_SCHEMA=platform
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@localhost:${DB_PORT}/${DB_NAME}?connect_timeout=60&schema=platform"
PRISMA_SCHEMA="postgres/schema.prisma"
# EXECUTOR
NUM_GRAPH_WORKERS=10
NUM_NODE_WORKERS=3
BACKEND_CORS_ALLOW_ORIGINS=["http://localhost:3000"]
# generate using `from cryptography.fernet import Fernet;Fernet.generate_key().decode()`
ENCRYPTION_KEY='dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw='
UNSUBSCRIBE_SECRET_KEY = 'HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio='
REDIS_HOST=localhost
REDIS_PORT=6379
@@ -39,7 +28,6 @@ SENTRY_DSN=
# Email For Postmark so we can send emails
POSTMARK_SERVER_API_TOKEN=
POSTMARK_SENDER_EMAIL=invalid@invalid.com
POSTMARK_WEBHOOK_TOKEN=
## User auth with Supabase is required for any of the 3rd party integrations with auth to work.
ENABLE_AUTH=true
@@ -53,9 +41,6 @@ RABBITMQ_PORT=5672
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
## GCS bucket is required for marketplace and library functionality
MEDIA_GCS_BUCKET_NAME=
## For local development, you may need to set FRONTEND_BASE_URL for the OAuth flow
## for integrations to work. Defaults to the value of PLATFORM_BASE_URL if not set.
# FRONTEND_BASE_URL=http://localhost:3000
@@ -189,8 +174,6 @@ SMARTLEAD_API_KEY=
# ZeroBounce
ZEROBOUNCE_API_KEY=
## ===== OPTIONAL API KEYS END ===== ##
# Logging Configuration
LOG_LEVEL=INFO
ENABLE_CLOUD_LOGGING=false

View File

@@ -73,6 +73,7 @@ FROM server_dependencies AS server
COPY autogpt_platform/backend /app/autogpt_platform/backend
RUN poetry install --no-ansi --only-root
ENV DATABASE_URL=""
ENV PORT=8000
CMD ["poetry", "run", "rest"]

View File

@@ -1 +1,75 @@
[Advanced Setup (Dev Branch)](https://dev-docs.agpt.co/platform/advanced_setup/#autogpt_agent_server_advanced_set_up)
# AutoGPT Agent Server Advanced set up
This guide walks you through a dockerized setup with an external DB (Postgres).
## Setup
We use Poetry to manage the dependencies. To set up the project, follow these steps inside this directory:
0. Install Poetry
```sh
pip install poetry
```
1. Configure Poetry to use .venv in your project directory
```sh
poetry config virtualenvs.in-project true
```
2. Enter the poetry shell
```sh
poetry shell
```
3. Install dependencies
```sh
poetry install
```
4. Copy .env.example to .env
```sh
cp .env.example .env
```
5. Generate the Prisma client
```sh
poetry run prisma generate
```
> In case Prisma generates the client for the global Python installation instead of the virtual environment, the current mitigation is to just uninstall the global Prisma package:
>
> ```sh
> pip uninstall prisma
> ```
>
> Then run the generation again. The path *should* look something like this:
> `<some path>/pypoetry/virtualenvs/backend-TQIRSwR6-py3.12/bin/prisma`
6. Run the Postgres database from the `autogpt_platform` folder
```sh
cd autogpt_platform/
docker compose up -d
```
7. Run the migrations (from the backend folder)
```sh
cd ../backend
prisma migrate deploy
```
## Running The Server
### Starting the server directly
Run the following command:
```sh
poetry run app
```

View File

@@ -1 +1,210 @@
[Getting Started (Released)](https://docs.agpt.co/platform/getting-started/#autogpt_agent_server)
# AutoGPT Agent Server
This is an initial project for creating the next generation of agent execution: an AutoGPT agent server.
The agent server will enable the creation of composite multi-agent systems that use AutoGPT agents and other non-agent components as their primitives.
## Docs
You can access the docs for the [AutoGPT Agent Server here](https://docs.agpt.co/server/setup).
## Setup
We use Poetry to manage the dependencies. To set up the project, follow these steps inside this directory:
0. Install Poetry
```sh
pip install poetry
```
1. Configure Poetry to use .venv in your project directory
```sh
poetry config virtualenvs.in-project true
```
2. Enter the poetry shell
```sh
poetry shell
```
3. Install dependencies
```sh
poetry install
```
4. Copy .env.example to .env
```sh
cp .env.example .env
```
5. Generate the Prisma client
```sh
poetry run prisma generate
```
> In case Prisma generates the client for the global Python installation instead of the virtual environment, the current mitigation is to just uninstall the global Prisma package:
>
> ```sh
> pip uninstall prisma
> ```
>
> Then run the generation again. The path *should* look something like this:
> `<some path>/pypoetry/virtualenvs/backend-TQIRSwR6-py3.12/bin/prisma`
6. Migrate the database. Be careful because this deletes current data in the database.
```sh
docker compose up db -d
poetry run prisma migrate deploy
```
## Running The Server
### Starting the server without Docker
To run the server locally, start in the autogpt_platform folder:
```sh
cd ..
```
Run the following command to run the database in Docker but the application locally:
```sh
docker compose --profile local up deps --build --detach
cd backend
poetry run app
```
### Starting the server with Docker
Run the following command to build the dockerfiles:
```sh
docker compose build
```
Run the following command to run the app:
```sh
docker compose up
```
In another terminal, run the following to automatically rebuild when code changes:
```sh
docker compose watch
```
Run the following command to shut down:
```sh
docker compose down
```
If you run into issues with dangling orphans, try:
```sh
docker compose down --volumes --remove-orphans && docker-compose up --force-recreate --renew-anon-volumes --remove-orphans
```
## Testing
To run the tests:
```sh
poetry run test
```
## Development
### Formatting & Linting
An auto-formatter and a linter are set up in the project. To run them:
Install:
```sh
poetry install --with dev
```
Format the code:
```sh
poetry run format
```
Lint the code:
```sh
poetry run lint
```
## Project Outline
The current project has the following main modules:
### **blocks**
This module stores all the Agent Blocks, which are reusable components to build a graph that represents the agent's behavior.
### **data**
This module stores the logical model that is persisted in the database.
It abstracts the database operations into functions that can be called by the service layer.
Any code that interacts with Prisma objects or the database should reside in this module.
The main models are:
* `block`: anything related to the block used in the graph
* `execution`: anything related to graph execution
* `graph`: anything related to the graph, node, and its relations
### **execution**
This module stores the business logic of executing the graph.
It currently has the following main modules:
* `manager`: A service that consumes the graph-execution queue and executes the graph; it contains both the queue-consumption and the graph-execution logic.
* `scheduler`: A service that triggers scheduled graph execution based on a cron expression. It pushes an execution request to the manager.
### **server**
This module stores the logic for the server API.
It contains all the logic used for the API that allows the client to create, execute, and monitor the graph and its execution.
This API service interacts with other services like those defined in `manager` and `scheduler`.
### **utils**
This module stores utility functions that are used across the project.
Currently, it has two main modules:
* `process`: A module that contains the logic to spawn a new process.
* `service`: A module that serves as a parent class for all the services in the project.
## Service Communication
Currently, there are only 3 active services:
- AgentServer (the API, defined in `server.py`)
- ExecutionManager (the executor, defined in `manager.py`)
- ExecutionScheduler (the scheduler, defined in `scheduler.py`)
The services run in independent Python processes and communicate through IPC.
A communication layer (`service.py`) is created to decouple the communication library from the implementation.
Currently, the IPC is done using Pyro5 and abstracted in a way that allows a function decorated with `@expose` to be called from a different process.
By default the daemons run on the following ports:
- Execution Manager Daemon: 8002
- Execution Scheduler Daemon: 8003
- Rest Server Daemon: 8004
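For illustration, here is a minimal sketch of that `@expose` pattern using plain Pyro5. This is not the project's actual `service.py` abstraction; the class name, object id, and port below are made up.
```python
import Pyro5.api


@Pyro5.api.expose
class EchoService:
    """Methods on an exposed class are callable from other processes via a proxy."""

    def echo(self, message: str) -> str:
        return f"echo: {message}"


def serve() -> None:
    # Roughly analogous to one of the daemons listening on a fixed port.
    daemon = Pyro5.api.Daemon(port=8002)
    uri = daemon.register(EchoService, objectId="echo")
    print(f"EchoService available at {uri}")
    daemon.requestLoop()


def call(uri: str) -> None:
    # From a different process, obtain a proxy and call the exposed method.
    with Pyro5.api.Proxy(uri) as service:
        print(service.echo("hello"))
```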
## Adding a New Agent Block
To add a new agent block, you need to create a new class that inherits from `Block` and provides the following information:
* All the block code should live in the `blocks` (`backend.blocks`) module.
* `input_schema`: the schema of the input data, represented by a Pydantic object.
* `output_schema`: the schema of the output data, represented by a Pydantic object.
* `run` method: the main logic of the block.
* `test_input` & `test_output`: the sample input and output data for the block, which will be used to auto-test the block.
* You can mock the functions declared in the block using the `test_mock` field for your unit tests.
* Once you finish creating the block, you can test it by running `poetry run pytest -s test/block/test_block.py`.
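Putting those pieces together, here is a hedged sketch of a minimal block, modeled on the block definitions that appear elsewhere in this comparison. The exact `Block.__init__` keyword arguments and `SchemaField` options are assumptions, so treat every name below as illustrative rather than the canonical API.
```python
from backend.data.block import Block, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class WordCountBlock(Block):  # the class name must end with "Block"
    class Input(BlockSchema):
        text: str = SchemaField(description="Text to count words in")

    class Output(BlockSchema):
        count: int = SchemaField(description="Number of words found")

    def __init__(self):
        super().__init__(
            # The id must be a unique 36-character UUID string.
            id="11111111-2222-3333-4444-555555555555",
            description="Counts the words in the input text.",
            input_schema=WordCountBlock.Input,
            output_schema=WordCountBlock.Output,
            test_input={"text": "hello agent world"},
            test_output=("count", 3),
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # Blocks yield (output_name, value) pairs rather than returning a dict.
        yield "count", len(input_data.text.split())
```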

View File

@@ -1,30 +1,22 @@
import logging
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from backend.util.process import AppProcess
logger = logging.getLogger(__name__)
def run_processes(*processes: "AppProcess", **kwargs):
"""
Execute all processes in the app. The last process is run in the foreground.
Includes enhanced error handling and process lifecycle management.
"""
try:
# Run all processes except the last one in the background.
for process in processes[:-1]:
process.start(background=True, **kwargs)
# Run the last process in the foreground.
# Run the last process in the foreground
processes[-1].start(background=False, **kwargs)
finally:
for process in processes:
try:
process.stop()
except Exception as e:
logger.exception(f"[{process.service_name}] unable to stop: {e}")
process.stop()
def main(**kwargs):
@@ -32,7 +24,7 @@ def main(**kwargs):
Run all the processes required for the AutoGPT-server (REST and WebSocket APIs).
"""
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
from backend.executor import DatabaseManager, ExecutionManager, ExecutionScheduler
from backend.notifications import NotificationManager
from backend.server.rest_api import AgentServer
from backend.server.ws_api import WebsocketServer
@@ -40,7 +32,7 @@ def main(**kwargs):
run_processes(
DatabaseManager(),
ExecutionManager(),
Scheduler(),
ExecutionScheduler(),
NotificationManager(),
WebsocketServer(),
AgentServer(),

View File

@@ -2,103 +2,88 @@ import importlib
import os
import re
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar
from typing import Type, TypeVar
from backend.data.block import Block
# Dynamically load all modules under backend.blocks
AVAILABLE_MODULES = []
current_dir = Path(__file__).parent
modules = [
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
for f in current_dir.rglob("*.py")
if f.is_file() and f.name != "__init__.py"
]
for module in modules:
if not re.match("^[a-z0-9_.]+$", module):
raise ValueError(
f"Block module {module} error: module name must be lowercase, "
"and contain only alphanumeric characters and underscores."
)
importlib.import_module(f".{module}", package=__name__)
AVAILABLE_MODULES.append(module)
# Load all Block instances from the available modules
AVAILABLE_BLOCKS: dict[str, Type[Block]] = {}
if TYPE_CHECKING:
from backend.data.block import Block
T = TypeVar("T")
_AVAILABLE_BLOCKS: dict[str, type["Block"]] = {}
def load_all_blocks() -> dict[str, type["Block"]]:
from backend.data.block import Block
if _AVAILABLE_BLOCKS:
return _AVAILABLE_BLOCKS
# Dynamically load all modules under backend.blocks
AVAILABLE_MODULES = []
current_dir = Path(__file__).parent
modules = [
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
for f in current_dir.rglob("*.py")
if f.is_file() and f.name != "__init__.py"
]
for module in modules:
if not re.match("^[a-z0-9_.]+$", module):
raise ValueError(
f"Block module {module} error: module name must be lowercase, "
"and contain only alphanumeric characters and underscores."
)
importlib.import_module(f".{module}", package=__name__)
AVAILABLE_MODULES.append(module)
# Load all Block instances from the available modules
for block_cls in all_subclasses(Block):
class_name = block_cls.__name__
if class_name.endswith("Base"):
continue
if not class_name.endswith("Block"):
raise ValueError(
f"Block class {class_name} does not end with 'Block'. "
"If you are creating an abstract class, "
"please name the class with 'Base' at the end"
)
block = block_cls.create()
if not isinstance(block.id, str) or len(block.id) != 36:
raise ValueError(
f"Block ID {block.name} error: {block.id} is not a valid UUID"
)
if block.id in _AVAILABLE_BLOCKS:
raise ValueError(
f"Block ID {block.name} error: {block.id} is already in use"
)
input_schema = block.input_schema.model_fields
output_schema = block.output_schema.model_fields
# Make sure `error` field is a string in the output schema
if "error" in output_schema and output_schema["error"].annotation is not str:
raise ValueError(
f"{block.name} `error` field in output_schema must be a string"
)
# Ensure all fields in input_schema and output_schema are annotated SchemaFields
for field_name, field in [*input_schema.items(), *output_schema.items()]:
if field.annotation is None:
raise ValueError(
f"{block.name} has a field {field_name} that is not annotated"
)
if field.json_schema_extra is None:
raise ValueError(
f"{block.name} has a field {field_name} not defined as SchemaField"
)
for field in block.input_schema.model_fields.values():
if field.annotation is bool and field.default not in (True, False):
raise ValueError(
f"{block.name} has a boolean field with no default value"
)
_AVAILABLE_BLOCKS[block.id] = block_cls
return _AVAILABLE_BLOCKS
__all__ = ["load_all_blocks"]
def all_subclasses(cls: type[T]) -> list[type[T]]:
def all_subclasses(cls: Type[T]) -> list[Type[T]]:
subclasses = cls.__subclasses__()
for subclass in subclasses:
subclasses += all_subclasses(subclass)
return subclasses
for block_cls in all_subclasses(Block):
name = block_cls.__name__
if block_cls.__name__.endswith("Base"):
continue
if not block_cls.__name__.endswith("Block"):
raise ValueError(
f"Block class {block_cls.__name__} does not end with 'Block', If you are creating an abstract class, please name the class with 'Base' at the end"
)
block = block_cls.create()
if not isinstance(block.id, str) or len(block.id) != 36:
raise ValueError(f"Block ID {block.name} error: {block.id} is not a valid UUID")
if block.id in AVAILABLE_BLOCKS:
raise ValueError(f"Block ID {block.name} error: {block.id} is already in use")
input_schema = block.input_schema.model_fields
output_schema = block.output_schema.model_fields
# Make sure `error` field is a string in the output schema
if "error" in output_schema and output_schema["error"].annotation is not str:
raise ValueError(
f"{block.name} `error` field in output_schema must be a string"
)
# Make sure all fields in input_schema and output_schema are annotated and has a value
for field_name, field in [*input_schema.items(), *output_schema.items()]:
if field.annotation is None:
raise ValueError(
f"{block.name} has a field {field_name} that is not annotated"
)
if field.json_schema_extra is None:
raise ValueError(
f"{block.name} has a field {field_name} not defined as SchemaField"
)
for field in block.input_schema.model_fields.values():
if field.annotation is bool and field.default not in (True, False):
raise ValueError(f"{block.name} has a boolean field with no default value")
if block.disabled:
continue
AVAILABLE_BLOCKS[block.id] = block_cls
__all__ = ["AVAILABLE_MODULES", "AVAILABLE_BLOCKS"]

View File

@@ -1,5 +1,6 @@
import logging
from typing import Any
from autogpt_libs.utils.cache import thread_cached
from backend.data.block import (
Block,
@@ -12,11 +13,25 @@ from backend.data.block import (
)
from backend.data.execution import ExecutionStatus
from backend.data.model import SchemaField
from backend.util import json
logger = logging.getLogger(__name__)
@thread_cached
def get_executor_manager_client():
from backend.executor import ExecutionManager
from backend.util.service import get_service_client
return get_service_client(ExecutionManager)
@thread_cached
def get_event_bus():
from backend.data.execution import RedisExecutionEventBus
return RedisExecutionEventBus()
class AgentExecutorBlock(Block):
class Input(BlockSchema):
user_id: str = SchemaField(description="User ID")
@@ -27,23 +42,6 @@ class AgentExecutorBlock(Block):
input_schema: dict = SchemaField(description="Input schema for the graph")
output_schema: dict = SchemaField(description="Output schema for the graph")
@classmethod
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
return data.get("input_schema", {})
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
return data.get("data", {})
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
required_fields = cls.get_input_schema(data).get("required", [])
return set(required_fields) - set(data)
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
return json.validate_with_jsonschema(cls.get_input_schema(data), data)
class Output(BlockSchema):
pass
@@ -58,26 +56,26 @@ class AgentExecutorBlock(Block):
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
from backend.data.execution import ExecutionEventType
from backend.executor import utils as execution_utils
executor_manager = get_executor_manager_client()
event_bus = get_event_bus()
event_bus = execution_utils.get_execution_event_bus()
graph_exec = execution_utils.add_graph_execution(
graph_exec = executor_manager.add_execution(
graph_id=input_data.graph_id,
graph_version=input_data.graph_version,
user_id=input_data.user_id,
inputs=input_data.data,
data=input_data.data,
)
log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.id}"
log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.graph_exec_id}"
logger.info(f"Starting execution of {log_id}")
for event in event_bus.listen(
user_id=graph_exec.user_id,
graph_id=graph_exec.graph_id,
graph_exec_id=graph_exec.id,
graph_id=graph_exec.graph_id, graph_exec_id=graph_exec.graph_exec_id
):
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
logger.info(
f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
)
if not event.node_id:
if event.status in [
ExecutionStatus.COMPLETED,
ExecutionStatus.TERMINATED,
@@ -88,10 +86,6 @@ class AgentExecutorBlock(Block):
else:
continue
logger.debug(
f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
)
if not event.block_id:
logger.warning(f"{log_id} received event without block_id {event}")
continue
@@ -106,7 +100,5 @@ class AgentExecutorBlock(Block):
continue
for output_data in event.output_data.get("output", []):
logger.debug(
f"Execution {log_id} produced {output_name}: {output_data}"
)
logger.info(f"Execution {log_id} produced {output_name}: {output_data}")
yield output_name, output_data

View File

@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel
from backend.data.model import SchemaField
@@ -143,12 +143,11 @@ class ContactEmail(BaseModel):
class EmploymentHistory(BaseModel):
"""An employment history in Apollo"""
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
from_attributes=True,
populate_by_name=True,
)
class Config:
extra = "allow"
arbitrary_types_allowed = True
from_attributes = True
populate_by_name = True
_id: Optional[str] = None
created_at: Optional[str] = None
@@ -189,12 +188,11 @@ class TypedCustomField(BaseModel):
class Pagination(BaseModel):
"""Pagination in Apollo"""
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
from_attributes=True,
populate_by_name=True,
)
class Config:
extra = "allow" # Allow extra fields
arbitrary_types_allowed = True # Allow any type
from_attributes = True # Allow from_orm
populate_by_name = True # Allow field aliases to work both ways
page: int = 0
per_page: int = 0
@@ -232,12 +230,11 @@ class PhoneNumber(BaseModel):
class Organization(BaseModel):
"""An organization in Apollo"""
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
from_attributes=True,
populate_by_name=True,
)
class Config:
extra = "allow"
arbitrary_types_allowed = True
from_attributes = True
populate_by_name = True
id: Optional[str] = "N/A"
name: Optional[str] = "N/A"
@@ -271,12 +268,11 @@ class Organization(BaseModel):
class Contact(BaseModel):
"""A contact in Apollo"""
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
from_attributes=True,
populate_by_name=True,
)
class Config:
extra = "allow"
arbitrary_types_allowed = True
from_attributes = True
populate_by_name = True
contact_roles: list[Any] = []
id: Optional[str] = None
@@ -373,14 +369,14 @@ If a company has several office locations, results are still based on the headqu
To exclude companies based on location, use the organization_not_locations parameter.
""",
default_factory=list,
default=[],
)
organizations_not_locations: list[str] = SchemaField(
description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.
This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
""",
default_factory=list,
default=[],
)
q_organization_keyword_tags: list[str] = SchemaField(
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry."""
@@ -394,7 +390,7 @@ If the value you enter for this parameter does not match with a company's name,
description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
To find IDs, identify the values for organization_id when you call this endpoint.""",
default_factory=list,
default=[],
)
max_results: int = SchemaField(
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
@@ -447,14 +443,14 @@ Results also include job titles with the same terms, even if they are not exact
Use this parameter in combination with the person_seniorities[] parameter to find people based on specific job functions and seniority levels.
""",
default_factory=list,
default=[],
placeholder="marketing manager",
)
person_locations: list[str] = SchemaField(
description="""The location where people live. You can search across cities, US states, and countries.
To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
default_factory=list,
default=[],
)
person_seniorities: list[SenorityLevels] = SchemaField(
description="""The job seniority that people hold within their current employer. This enables you to find people that currently hold positions at certain reporting levels, such as Director level or senior IC level.
@@ -464,7 +460,7 @@ For a person to be included in search results, they only need to match 1 of the
Searches only return results based on their current job title, so searching for Director-level employees only returns people that currently hold a Director-level title. If someone was previously a Director, but is currently a VP, they would not be included in your search results.
Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
default_factory=list,
default=[],
)
organization_locations: list[str] = SchemaField(
description="""The location of the company headquarters for a person's current employer. You can search across cities, US states, and countries.
@@ -472,7 +468,7 @@ Use this parameter in combination with the person_titles[] parameter to find peo
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.
To find people based on their personal location, use the person_locations parameter.""",
default_factory=list,
default=[],
)
q_organization_domains: list[str] = SchemaField(
description="""The domain name for the person's employer. This can be the current employer or a previous employer. Do not include www., the @ symbol, or similar.
@@ -480,23 +476,23 @@ To find people based on their personal location, use the person_locations parame
You can add multiple domains to search across companies.
Examples: apollo.io and microsoft.com""",
default_factory=list,
default=[],
)
contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
default_factory=list,
default=[],
)
organization_ids: list[str] = SchemaField(
description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
default_factory=list,
default=[],
)
organization_num_empoloyees_range: list[int] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
default_factory=list,
default=[],
)
q_keywords: str = SchemaField(
description="""A string of words over which we want to filter the results""",
@@ -526,12 +522,11 @@ Use the page parameter to search the different pages of data.""",
class SearchPeopleResponse(BaseModel):
"""Response from Apollo's search people API"""
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
from_attributes=True,
populate_by_name=True,
)
class Config:
extra = "allow" # Allow extra fields
arbitrary_types_allowed = True # Allow any type
from_attributes = True # Allow from_orm
populate_by_name = True # Allow field aliases to work both ways
breadcrumbs: list[Breadcrumb] = []
partial_results_only: bool = True

View File

@@ -32,18 +32,18 @@ If a company has several office locations, results are still based on the headqu
To exclude companies based on location, use the organization_not_locations parameter.
""",
default_factory=list,
default=[],
)
organizations_not_locations: list[str] = SchemaField(
description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.
This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
""",
default_factory=list,
default=[],
)
q_organization_keyword_tags: list[str] = SchemaField(
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry.""",
default_factory=list,
default=[],
)
q_organization_name: str = SchemaField(
description="""Filter search results to include a specific company name.
@@ -56,7 +56,7 @@ If the value you enter for this parameter does not match with a company's name,
description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
To find IDs, identify the values for organization_id when you call this endpoint.""",
default_factory=list,
default=[],
)
max_results: int = SchemaField(
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
@@ -72,7 +72,7 @@ To find IDs, identify the values for organization_id when you call this endpoint
class Output(BlockSchema):
organizations: list[Organization] = SchemaField(
description="List of organizations found",
default_factory=list,
default=[],
)
organization: Organization = SchemaField(
description="Each found organization, one at a time",

View File

@@ -26,14 +26,14 @@ class SearchPeopleBlock(Block):
Use this parameter in combination with the person_seniorities[] parameter to find people based on specific job functions and seniority levels.
""",
default_factory=list,
default=[],
advanced=False,
)
person_locations: list[str] = SchemaField(
description="""The location where people live. You can search across cities, US states, and countries.
To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
default_factory=list,
default=[],
advanced=False,
)
person_seniorities: list[SenorityLevels] = SchemaField(
@@ -44,7 +44,7 @@ class SearchPeopleBlock(Block):
Searches only return results based on their current job title, so searching for Director-level employees only returns people that currently hold a Director-level title. If someone was previously a Director, but is currently a VP, they would not be included in your search results.
Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
default_factory=list,
default=[],
advanced=False,
)
organization_locations: list[str] = SchemaField(
@@ -53,7 +53,7 @@ class SearchPeopleBlock(Block):
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.
To find people based on their personal location, use the person_locations parameter.""",
default_factory=list,
default=[],
advanced=False,
)
q_organization_domains: list[str] = SchemaField(
@@ -62,26 +62,26 @@ class SearchPeopleBlock(Block):
You can add multiple domains to search across companies.
Examples: apollo.io and microsoft.com""",
default_factory=list,
default=[],
advanced=False,
)
contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
default_factory=list,
default=[],
advanced=False,
)
organization_ids: list[str] = SchemaField(
description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
default_factory=list,
default=[],
advanced=False,
)
organization_num_empoloyees_range: list[int] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
default_factory=list,
default=[],
advanced=False,
)
q_keywords: str = SchemaField(
@@ -104,7 +104,7 @@ class SearchPeopleBlock(Block):
class Output(BlockSchema):
people: list[Contact] = SchemaField(
description="List of people found",
default_factory=list,
default=[],
)
person: Contact = SchemaField(
description="Each found person, one at a time",

View File

@@ -3,20 +3,22 @@ from typing import Any, List
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
from backend.data.model import SchemaField
from backend.util import json
from backend.util.file import store_media_file
from backend.util.file import MediaFile, store_media_file
from backend.util.mock import MockObject
from backend.util.type import MediaFileType, convert
from backend.util.text import TextFormatter
from backend.util.type import convert
formatter = TextFormatter()
class FileStoreBlock(Block):
class Input(BlockSchema):
file_in: MediaFileType = SchemaField(
file_in: MediaFile = SchemaField(
description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
)
class Output(BlockSchema):
file_out: MediaFileType = SchemaField(
file_out: MediaFile = SchemaField(
description="The relative path to the stored file in the temporary directory."
)
@@ -88,6 +90,29 @@ class StoreValueBlock(Block):
yield "output", input_data.data or input_data.input
class PrintToConsoleBlock(Block):
class Input(BlockSchema):
text: str = SchemaField(description="The text to print to the console.")
class Output(BlockSchema):
status: str = SchemaField(description="The status of the print operation.")
def __init__(self):
super().__init__(
id="f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c",
description="Print the given text to the console, this is used for a debugging purpose.",
categories={BlockCategory.BASIC},
input_schema=PrintToConsoleBlock.Input,
output_schema=PrintToConsoleBlock.Output,
test_input={"text": "Hello, World!"},
test_output=("status", "printed"),
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
print(">>>>> Print: ", input_data.text)
yield "status", "printed"
class FindInDictionaryBlock(Block):
class Input(BlockSchema):
input: Any = SchemaField(description="Dictionary to lookup from")
@@ -128,9 +153,6 @@ class FindInDictionaryBlock(Block):
obj = input_data.input
key = input_data.key
if isinstance(obj, str):
obj = json.loads(obj)
if isinstance(obj, dict) and key in obj:
yield "output", obj[key]
elif isinstance(obj, list) and isinstance(key, int) and 0 <= key < len(obj):
@@ -148,10 +170,192 @@ class FindInDictionaryBlock(Block):
yield "missing", input_data.input
class AgentInputBlock(Block):
"""
This block is used to provide input to the graph.
It takes in a value, name, description, a list of default values, and a flag to limit selection to those defaults.
It outputs the value passed as input.

"""
class Input(BlockSchema):
name: str = SchemaField(description="The name of the input.")
value: Any = SchemaField(
description="The value to be passed as input.",
default=None,
)
title: str | None = SchemaField(
description="The title of the input.", default=None, advanced=True
)
description: str | None = SchemaField(
description="The description of the input.",
default=None,
advanced=True,
)
placeholder_values: List[Any] = SchemaField(
description="The placeholder values to be passed as input.",
default=[],
advanced=True,
)
limit_to_placeholder_values: bool = SchemaField(
description="Whether to limit the selection to placeholder values.",
default=False,
advanced=True,
)
advanced: bool = SchemaField(
description="Whether to show the input in the advanced section, if the field is not required.",
default=False,
advanced=True,
)
secret: bool = SchemaField(
description="Whether the input should be treated as a secret.",
default=False,
advanced=True,
)
class Output(BlockSchema):
result: Any = SchemaField(description="The value passed as input.")
def __init__(self):
super().__init__(
id="c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
description="This block is used to provide input to the graph.",
input_schema=AgentInputBlock.Input,
output_schema=AgentInputBlock.Output,
test_input=[
{
"value": "Hello, World!",
"name": "input_1",
"description": "This is a test input.",
"placeholder_values": [],
"limit_to_placeholder_values": False,
},
{
"value": "Hello, World!",
"name": "input_2",
"description": "This is a test input.",
"placeholder_values": ["Hello, World!"],
"limit_to_placeholder_values": True,
},
],
test_output=[
("result", "Hello, World!"),
("result", "Hello, World!"),
],
categories={BlockCategory.INPUT, BlockCategory.BASIC},
block_type=BlockType.INPUT,
static_output=True,
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
yield "result", input_data.value
class AgentOutputBlock(Block):
"""
Records the output of the graph for users to see.
Behavior:
If `format` is provided and the `value` is of a type that can be formatted,
the block attempts to format the recorded_value using the `format`.
If formatting fails or no `format` is provided, the raw `value` is output.
"""
class Input(BlockSchema):
value: Any = SchemaField(
description="The value to be recorded as output.",
default=None,
advanced=False,
)
name: str = SchemaField(description="The name of the output.")
title: str | None = SchemaField(
description="The title of the output.",
default=None,
advanced=True,
)
description: str | None = SchemaField(
description="The description of the output.",
default=None,
advanced=True,
)
format: str = SchemaField(
description="The format string to be used to format the recorded_value. Use Jinja2 syntax.",
default="",
advanced=True,
)
advanced: bool = SchemaField(
description="Whether to treat the output as advanced.",
default=False,
advanced=True,
)
secret: bool = SchemaField(
description="Whether the output should be treated as a secret.",
default=False,
advanced=True,
)
class Output(BlockSchema):
output: Any = SchemaField(description="The value recorded as output.")
name: Any = SchemaField(description="The name of the value recorded as output.")
def __init__(self):
super().__init__(
id="363ae599-353e-4804-937e-b2ee3cef3da4",
description="Stores the output of the graph for users to see.",
input_schema=AgentOutputBlock.Input,
output_schema=AgentOutputBlock.Output,
test_input=[
{
"value": "Hello, World!",
"name": "output_1",
"description": "This is a test output.",
"format": "{{ output_1 }}!!",
},
{
"value": "42",
"name": "output_2",
"description": "This is another test output.",
"format": "{{ output_2 }}",
},
{
"value": MockObject(value="!!", key="key"),
"name": "output_3",
"description": "This is a test output with a mock object.",
"format": "{{ output_3 }}",
},
],
test_output=[
("output", "Hello, World!!!"),
("output", "42"),
("output", MockObject(value="!!", key="key")),
],
categories={BlockCategory.OUTPUT, BlockCategory.BASIC},
block_type=BlockType.OUTPUT,
static_output=True,
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
"""
Attempts to format the recorded_value using the fmt_string if provided.
If formatting fails or no fmt_string is given, returns the original recorded_value.
"""
if input_data.format:
try:
yield "output", formatter.format_string(
input_data.format, {input_data.name: input_data.value}
)
except Exception as e:
yield "output", f"Error: {e}, {input_data.value}"
else:
yield "output", input_data.value
yield "name", input_data.name
class AddToDictionaryBlock(Block):
class Input(BlockSchema):
dictionary: dict[Any, Any] = SchemaField(
default_factory=dict,
default={},
description="The dictionary to add the entry to. If not provided, a new dictionary will be created.",
)
key: str = SchemaField(
@@ -167,7 +371,7 @@ class AddToDictionaryBlock(Block):
advanced=False,
)
entries: dict[Any, Any] = SchemaField(
default_factory=dict,
default={},
description="The entries to add to the dictionary. This is the batch version of the `key` and `value` fields.",
advanced=True,
)
@@ -229,7 +433,7 @@ class AddToDictionaryBlock(Block):
class AddToListBlock(Block):
class Input(BlockSchema):
list: List[Any] = SchemaField(
default_factory=list,
default=[],
advanced=False,
description="The list to add the entry to. If not provided, a new list will be created.",
)
@@ -239,7 +443,7 @@ class AddToListBlock(Block):
default=None,
)
entries: List[Any] = SchemaField(
default_factory=lambda: list(),
default=[],
description="The entries to add to the list. This is the batch version of the `entry` field.",
advanced=True,
)
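
The AgentOutputBlock earlier in this file formats the recorded value with a Jinja2 template keyed by the output's name. A minimal sketch of that step using Jinja2 directly; the real block goes through backend.util.text.TextFormatter, which is assumed here to wrap an equivalent render call:

from jinja2 import Template

name, value, fmt = "output_1", "Hello, World!", "{{ output_1 }}!!"

try:
    rendered = Template(fmt).render({name: value})   # -> "Hello, World!!!"
except Exception as e:
    rendered = f"Error: {e}, {value}"                 # fall back to the raw value, as the block does

print(rendered)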

View File

@@ -55,7 +55,7 @@ class CodeExecutionBlock(Block):
"These commands are executed with `sh`, in the foreground."
),
placeholder="pip install cowsay",
default_factory=list,
default=[],
advanced=False,
)
@@ -207,7 +207,7 @@ class InstantiationBlock(Block):
"These commands are executed with `sh`, in the foreground."
),
placeholder="pip install cowsay",
default_factory=list,
default=[],
advanced=False,
)

View File

@@ -8,7 +8,6 @@ from backend.data.block import (
BlockSchema,
)
from backend.data.model import SchemaField
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.compass import CompassWebhookType
@@ -43,7 +42,7 @@ class CompassAITriggerBlock(Block):
input_schema=CompassAITriggerBlock.Input,
output_schema=CompassAITriggerBlock.Output,
webhook_config=BlockManualWebhookConfig(
provider=ProviderName.COMPASS,
provider="compass",
webhook_type=CompassWebhookType.TRANSCRIPTION,
),
test_input=[

View File

@@ -34,7 +34,7 @@ class ReadCsvBlock(Block):
)
skip_columns: list[str] = SchemaField(
description="The columns to skip from the start of the row",
default_factory=list,
default=[],
)
class Output(BlockSchema):

View File

@@ -49,9 +49,8 @@ class ExaContentsBlock(Block):
class Output(BlockSchema):
results: list = SchemaField(
description="List of document contents",
default_factory=list,
default=[],
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(

View File

@@ -38,11 +38,11 @@ class ExaSearchBlock(Block):
)
include_domains: List[str] = SchemaField(
description="Domains to include in search",
default_factory=list,
default=[],
)
exclude_domains: List[str] = SchemaField(
description="Domains to exclude from search",
default_factory=list,
default=[],
advanced=True,
)
start_crawl_date: datetime = SchemaField(
@@ -59,12 +59,12 @@ class ExaSearchBlock(Block):
)
include_text: List[str] = SchemaField(
description="Text patterns to include",
default_factory=list,
default=[],
advanced=True,
)
exclude_text: List[str] = SchemaField(
description="Text patterns to exclude",
default_factory=list,
default=[],
advanced=True,
)
contents: ContentSettings = SchemaField(
@@ -76,7 +76,7 @@ class ExaSearchBlock(Block):
class Output(BlockSchema):
results: list = SchemaField(
description="List of search results",
default_factory=list,
default=[],
)
def __init__(self):

View File

@@ -26,12 +26,12 @@ class ExaFindSimilarBlock(Block):
)
include_domains: List[str] = SchemaField(
description="Domains to include in search",
default_factory=list,
default=[],
advanced=True,
)
exclude_domains: List[str] = SchemaField(
description="Domains to exclude from search",
default_factory=list,
default=[],
advanced=True,
)
start_crawl_date: datetime = SchemaField(
@@ -48,12 +48,12 @@ class ExaFindSimilarBlock(Block):
)
include_text: List[str] = SchemaField(
description="Text patterns to include (max 1 string, up to 5 words)",
default_factory=list,
default=[],
advanced=True,
)
exclude_text: List[str] = SchemaField(
description="Text patterns to exclude (max 1 string, up to 5 words)",
default_factory=list,
default=[],
advanced=True,
)
contents: ContentSettings = SchemaField(
@@ -65,7 +65,7 @@ class ExaFindSimilarBlock(Block):
class Output(BlockSchema):
results: List[Any] = SchemaField(
description="List of similar documents with title, URL, published date, author, and score",
default_factory=list,
default=[],
)
def __init__(self):

View File

@@ -42,7 +42,7 @@ class AIVideoGeneratorBlock(Block):
description="Error message if video generation failed."
)
logs: list[str] = SchemaField(
description="Generation progress logs.",
description="Generation progress logs.", optional=True
)
def __init__(self):

View File

@@ -1,51 +0,0 @@
from backend.data.block import (
Block,
BlockCategory,
BlockManualWebhookConfig,
BlockOutput,
BlockSchema,
)
from backend.data.model import SchemaField
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.generic import GenericWebhookType
class GenericWebhookTriggerBlock(Block):
class Input(BlockSchema):
payload: dict = SchemaField(hidden=True, default_factory=dict)
constants: dict = SchemaField(
description="The constants to be set when the block is put on the graph",
default_factory=dict,
)
class Output(BlockSchema):
payload: dict = SchemaField(
description="The complete webhook payload that was received from the generic webhook."
)
constants: dict = SchemaField(
description="The constants to be set when the block is put on the graph"
)
example_payload = {"message": "Hello, World!"}
def __init__(self):
super().__init__(
id="8fa8c167-2002-47ce-aba8-97572fc5d387",
description="This block will output the contents of the generic input for the webhook.",
categories={BlockCategory.INPUT},
input_schema=GenericWebhookTriggerBlock.Input,
output_schema=GenericWebhookTriggerBlock.Output,
webhook_config=BlockManualWebhookConfig(
provider=ProviderName.GENERIC_WEBHOOK,
webhook_type=GenericWebhookType.PLAIN,
),
test_input={"constants": {"key": "value"}, "payload": self.example_payload},
test_output=[
("constants", {"key": "value"}),
("payload", self.example_payload),
],
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
yield "constants", input_data.constants
yield "payload", input_data.payload

View File

@@ -38,59 +38,6 @@ def _get_headers(credentials: GithubCredentials) -> dict[str, str]:
}
def convert_comment_url_to_api_endpoint(comment_url: str) -> str:
"""
Converts a GitHub comment URL (web interface) to the appropriate API endpoint URL.
Handles:
1. Issue/PR comments: #issuecomment-{id}
2. PR review comments: #discussion_r{id}
Returns the appropriate API endpoint path for the comment.
"""
# First, check if this is already an API URL
parsed_url = urlparse(comment_url)
if parsed_url.hostname == "api.github.com":
return comment_url
# Replace pull with issues for comment endpoints
if "/pull/" in comment_url:
comment_url = comment_url.replace("/pull/", "/issues/")
# Handle issue/PR comments (#issuecomment-xxx)
if "#issuecomment-" in comment_url:
base_url, comment_part = comment_url.split("#issuecomment-")
comment_id = comment_part
# Extract repo information from base URL
parsed_url = urlparse(base_url)
path_parts = parsed_url.path.strip("/").split("/")
owner, repo = path_parts[0], path_parts[1]
# Construct API URL for issue comments
return (
f"https://api.github.com/repos/{owner}/{repo}/issues/comments/{comment_id}"
)
# Handle PR review comments (#discussion_r)
elif "#discussion_r" in comment_url:
base_url, comment_part = comment_url.split("#discussion_r")
comment_id = comment_part
# Extract repo information from base URL
parsed_url = urlparse(base_url)
path_parts = parsed_url.path.strip("/").split("/")
owner, repo = path_parts[0], path_parts[1]
# Construct API URL for PR review comments
return (
f"https://api.github.com/repos/{owner}/{repo}/pulls/comments/{comment_id}"
)
# If no specific comment identifiers are found, use the general URL conversion
return _convert_to_api_url(comment_url)
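
A couple of concrete mappings the converter above is meant to produce, using a made-up owner/repo:

# Illustrative inputs/outputs for convert_comment_url_to_api_endpoint (URLs are invented).
examples = {
    # Issue or PR conversation comment -> issues/comments endpoint
    "https://github.com/octo/demo/issues/1#issuecomment-123456789":
        "https://api.github.com/repos/octo/demo/issues/comments/123456789",
    # PR review (diff) comment -> pulls/comments endpoint
    "https://github.com/octo/demo/pull/7#discussion_r987654321":
        "https://api.github.com/repos/octo/demo/pulls/comments/987654321",
}
# An api.github.com URL is returned unchanged, and anything without a comment
# fragment falls through to the generic _convert_to_api_url helper.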
def get_api(
credentials: GithubCredentials | GithubFineGrainedAPICredentials,
convert_urls: bool = True,

View File

@@ -172,9 +172,7 @@ class GithubCreateCheckRunBlock(Block):
data.output = output_data
check_runs_url = f"{repo_url}/check-runs"
response = api.post(
check_runs_url, data=data.model_dump_json(exclude_none=True)
)
response = api.post(check_runs_url)
result = response.json()
return {
@@ -325,9 +323,7 @@ class GithubUpdateCheckRunBlock(Block):
data.output = output_data
check_run_url = f"{repo_url}/check-runs/{check_run_id}"
response = api.patch(
check_run_url, data=data.model_dump_json(exclude_none=True)
)
response = api.patch(check_run_url)
result = response.json()
return {

View File

@@ -1,4 +1,3 @@
import logging
from urllib.parse import urlparse
from typing_extensions import TypedDict
@@ -6,7 +5,7 @@ from typing_extensions import TypedDict
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from ._api import convert_comment_url_to_api_endpoint, get_api
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
@@ -15,8 +14,6 @@ from ._auth import (
GithubCredentialsInput,
)
logger = logging.getLogger(__name__)
def is_github_url(url: str) -> bool:
return urlparse(url).netloc == "github.com"
@@ -111,228 +108,6 @@ class GithubCommentBlock(Block):
# --8<-- [end:GithubCommentBlockExample]
class GithubUpdateCommentBlock(Block):
class Input(BlockSchema):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
comment_url: str = SchemaField(
description="URL of the GitHub comment",
placeholder="https://github.com/owner/repo/issues/1#issuecomment-123456789",
default="",
advanced=False,
)
issue_url: str = SchemaField(
description="URL of the GitHub issue or pull request",
placeholder="https://github.com/owner/repo/issues/1",
default="",
)
comment_id: str = SchemaField(
description="ID of the GitHub comment",
placeholder="123456789",
default="",
)
comment: str = SchemaField(
description="Comment to update",
placeholder="Enter your comment",
)
class Output(BlockSchema):
id: int = SchemaField(description="ID of the updated comment")
url: str = SchemaField(description="URL to the comment on GitHub")
error: str = SchemaField(
description="Error message if the comment update failed"
)
def __init__(self):
super().__init__(
id="b3f4d747-10e3-4e69-8c51-f2be1d99c9a7",
description="This block updates a comment on a specified GitHub issue or pull request.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubUpdateCommentBlock.Input,
output_schema=GithubUpdateCommentBlock.Output,
test_input={
"comment_url": "https://github.com/owner/repo/issues/1#issuecomment-123456789",
"comment": "This is an updated comment.",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("id", 123456789),
(
"url",
"https://github.com/owner/repo/issues/1#issuecomment-123456789",
),
],
test_mock={
"update_comment": lambda *args, **kwargs: (
123456789,
"https://github.com/owner/repo/issues/1#issuecomment-123456789",
)
},
)
@staticmethod
def update_comment(
credentials: GithubCredentials, comment_url: str, body_text: str
) -> tuple[int, str]:
api = get_api(credentials, convert_urls=False)
data = {"body": body_text}
url = convert_comment_url_to_api_endpoint(comment_url)
logger.info(url)
response = api.patch(url, json=data)
comment = response.json()
return comment["id"], comment["html_url"]
def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
if (
not input_data.comment_url
and input_data.comment_id
and input_data.issue_url
):
parsed_url = urlparse(input_data.issue_url)
path_parts = parsed_url.path.strip("/").split("/")
owner, repo = path_parts[0], path_parts[1]
input_data.comment_url = f"https://api.github.com/repos/{owner}/{repo}/issues/comments/{input_data.comment_id}"
elif (
not input_data.comment_url
and not input_data.comment_id
and input_data.issue_url
):
raise ValueError(
"Must provide either comment_url or comment_id and issue_url"
)
id, url = self.update_comment(
credentials,
input_data.comment_url,
input_data.comment,
)
yield "id", id
yield "url", url
class GithubListCommentsBlock(Block):
class Input(BlockSchema):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
issue_url: str = SchemaField(
description="URL of the GitHub issue or pull request",
placeholder="https://github.com/owner/repo/issues/1",
)
class Output(BlockSchema):
class CommentItem(TypedDict):
id: int
body: str
user: str
url: str
comment: CommentItem = SchemaField(
title="Comment", description="Comments with their ID, body, user, and URL"
)
comments: list[CommentItem] = SchemaField(
description="List of comments with their ID, body, user, and URL"
)
error: str = SchemaField(description="Error message if listing comments failed")
def __init__(self):
super().__init__(
id="c4b5fb63-0005-4a11-b35a-0c2467bd6b59",
description="This block lists all comments for a specified GitHub issue or pull request.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListCommentsBlock.Input,
output_schema=GithubListCommentsBlock.Output,
test_input={
"issue_url": "https://github.com/owner/repo/issues/1",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"comment",
{
"id": 123456789,
"body": "This is a test comment.",
"user": "test_user",
"url": "https://github.com/owner/repo/issues/1#issuecomment-123456789",
},
),
(
"comments",
[
{
"id": 123456789,
"body": "This is a test comment.",
"user": "test_user",
"url": "https://github.com/owner/repo/issues/1#issuecomment-123456789",
}
],
),
],
test_mock={
"list_comments": lambda *args, **kwargs: [
{
"id": 123456789,
"body": "This is a test comment.",
"user": "test_user",
"url": "https://github.com/owner/repo/issues/1#issuecomment-123456789",
}
]
},
)
@staticmethod
def list_comments(
credentials: GithubCredentials, issue_url: str
) -> list[Output.CommentItem]:
parsed_url = urlparse(issue_url)
path_parts = parsed_url.path.strip("/").split("/")
owner = path_parts[0]
repo = path_parts[1]
# GitHub API uses 'issues' for both issues and pull requests when it comes to comments
issue_number = path_parts[3] # Whether 'issues/123' or 'pull/123'
# Construct the proper API URL directly
api_url = f"https://api.github.com/repos/{owner}/{repo}/issues/{issue_number}/comments"
# Set convert_urls=False since we're already providing an API URL
api = get_api(credentials, convert_urls=False)
response = api.get(api_url)
comments = response.json()
parsed_comments: list[GithubListCommentsBlock.Output.CommentItem] = [
{
"id": comment["id"],
"body": comment["body"],
"user": comment["user"]["login"],
"url": comment["html_url"],
}
for comment in comments
]
return parsed_comments
def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
comments = self.list_comments(
credentials,
input_data.issue_url,
)
yield from (("comment", comment) for comment in comments)
yield "comments", comments
class GithubMakeIssueBlock(Block):
class Input(BlockSchema):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")

View File

@@ -144,7 +144,7 @@ class GithubCreateStatusBlock(Block):
data.description = description
status_url = f"{repo_url}/statuses/{sha}"
response = api.post(status_url, data=data.model_dump_json(exclude_none=True))
response = api.post(status_url, json=data)
result = response.json()
return {

View File

@@ -12,7 +12,6 @@ from backend.data.block import (
BlockWebhookConfig,
)
from backend.data.model import SchemaField
from backend.integrations.providers import ProviderName
from ._auth import (
TEST_CREDENTIALS,
@@ -37,7 +36,7 @@ class GitHubTriggerBase:
placeholder="{owner}/{repo}",
)
# --8<-- [start:example-payload-field]
payload: dict = SchemaField(hidden=True, default_factory=dict)
payload: dict = SchemaField(hidden=True, default={})
# --8<-- [end:example-payload-field]
class Output(BlockSchema):
@@ -124,7 +123,7 @@ class GithubPullRequestTriggerBlock(GitHubTriggerBase, Block):
output_schema=GithubPullRequestTriggerBlock.Output,
# --8<-- [start:example-webhook_config]
webhook_config=BlockWebhookConfig(
provider=ProviderName.GITHUB,
provider="github",
webhook_type=GithubWebhookType.REPO,
resource_format="{repo}",
event_filter_input="events",

View File

@@ -8,7 +8,6 @@ from pydantic import BaseModel
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.settings import Settings
from ._auth import (
GOOGLE_OAUTH_IS_CONFIGURED,
@@ -151,8 +150,8 @@ class GmailReadBlock(Block):
else None
),
token_uri="https://oauth2.googleapis.com/token",
client_id=Settings().secrets.google_client_id,
client_secret=Settings().secrets.google_client_secret,
client_id=kwargs.get("client_id"),
client_secret=kwargs.get("client_secret"),
scopes=credentials.scopes,
)
return build("gmail", "v1", credentials=creds)

View File

@@ -3,7 +3,6 @@ from googleapiclient.discovery import build
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.settings import Settings
from ._auth import (
GOOGLE_OAUTH_IS_CONFIGURED,
@@ -87,8 +86,8 @@ class GoogleSheetsReadBlock(Block):
else None
),
token_uri="https://oauth2.googleapis.com/token",
client_id=Settings().secrets.google_client_id,
client_secret=Settings().secrets.google_client_secret,
client_id=kwargs.get("client_id"),
client_secret=kwargs.get("client_secret"),
scopes=credentials.scopes,
)
return build("sheets", "v4", credentials=creds)

View File

@@ -1,16 +1,11 @@
import json
import logging
from enum import Enum
from typing import Any
from requests.exceptions import HTTPError, RequestException
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import requests
logger = logging.getLogger(name=__name__)
class HttpMethod(Enum):
GET = "GET"
@@ -34,7 +29,7 @@ class SendWebRequestBlock(Block):
)
headers: dict[str, str] = SchemaField(
description="The headers to include in the request",
default_factory=dict,
default={},
)
json_format: bool = SchemaField(
title="JSON format",
@@ -48,9 +43,8 @@ class SendWebRequestBlock(Block):
class Output(BlockSchema):
response: object = SchemaField(description="The response from the server")
client_error: object = SchemaField(description="Errors on 4xx status codes")
server_error: object = SchemaField(description="Errors on 5xx status codes")
error: str = SchemaField(description="Errors for all other exceptions")
client_error: object = SchemaField(description="The error on 4xx status codes")
server_error: object = SchemaField(description="The error on 5xx status codes")
def __init__(self):
super().__init__(
@@ -74,40 +68,20 @@ class SendWebRequestBlock(Block):
# we should send it as plain text instead
input_data.json_format = False
try:
response = requests.request(
input_data.method.value,
input_data.url,
headers=input_data.headers,
json=body if input_data.json_format else None,
data=body if not input_data.json_format else None,
)
result = response.json() if input_data.json_format else response.text
response = requests.request(
input_data.method.value,
input_data.url,
headers=input_data.headers,
json=body if input_data.json_format else None,
data=body if not input_data.json_format else None,
)
result = response.json() if input_data.json_format else response.text
if response.status_code // 100 == 2:
yield "response", result
except HTTPError as e:
# Handle error responses
try:
result = e.response.json() if input_data.json_format else str(e)
except json.JSONDecodeError:
result = str(e)
if 400 <= e.response.status_code < 500:
yield "client_error", result
elif 500 <= e.response.status_code < 600:
yield "server_error", result
else:
error_msg = (
"Unexpected status code "
f"{e.response.status_code} '{e.response.reason}'"
)
logger.warning(error_msg)
yield "error", error_msg
except RequestException as e:
# Handle other request-related exceptions
yield "error", str(e)
except Exception as e:
# Catch any other unexpected exceptions
yield "error", str(e)
elif response.status_code // 100 == 4:
yield "client_error", result
elif response.status_code // 100 == 5:
yield "server_error", result
else:
raise ValueError(f"Unexpected status code: {response.status_code}")

View File

@@ -15,8 +15,7 @@ class HubSpotCompanyBlock(Block):
description="Operation to perform (create, update, get)", default="get"
)
company_data: dict = SchemaField(
description="Company data for create/update operations",
default_factory=dict,
description="Company data for create/update operations", default={}
)
domain: str = SchemaField(
description="Company domain for get/update operations", default=""

View File

@@ -15,8 +15,7 @@ class HubSpotContactBlock(Block):
description="Operation to perform (create, update, get)", default="get"
)
contact_data: dict = SchemaField(
description="Contact data for create/update operations",
default_factory=dict,
description="Contact data for create/update operations", default={}
)
email: str = SchemaField(
description="Email address for get/update operations", default=""

View File

@@ -19,7 +19,7 @@ class HubSpotEngagementBlock(Block):
)
email_data: dict = SchemaField(
description="Email data including recipient, subject, content",
default_factory=dict,
default={},
)
contact_id: str = SchemaField(
description="Contact ID for engagement tracking", default=""
@@ -27,6 +27,7 @@ class HubSpotEngagementBlock(Block):
timeframe_days: int = SchemaField(
description="Number of days to look back for engagement",
default=30,
optional=True,
)
class Output(BlockSchema):

View File

@@ -142,16 +142,6 @@ class IdeogramModelBlock(Block):
title="Color Palette Preset",
advanced=True,
)
custom_color_palette: Optional[list[str]] = SchemaField(
description=(
"Only available for model version V_2 or V_2_TURBO. Provide one or more color hex codes "
"(e.g., ['#000030', '#1C0C47', '#9900FF', '#4285F4', '#FFFFFF']) to define a custom color "
"palette. Only used if 'color_palette_name' is 'NONE'."
),
default=None,
title="Custom Color Palette",
advanced=True,
)
class Output(BlockSchema):
result: str = SchemaField(description="Generated image URL")
@@ -174,13 +164,6 @@ class IdeogramModelBlock(Block):
"style_type": StyleType.AUTO,
"negative_prompt": None,
"color_palette_name": ColorPalettePreset.NONE,
"custom_color_palette": [
"#000030",
"#1C0C47",
"#9900FF",
"#4285F4",
"#FFFFFF",
],
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
@@ -190,7 +173,7 @@ class IdeogramModelBlock(Block):
),
],
test_mock={
"run_model": lambda api_key, model_name, prompt, seed, aspect_ratio, magic_prompt_option, style_type, negative_prompt, color_palette_name, custom_colors: "https://ideogram.ai/api/images/test-generated-image-url.png",
"run_model": lambda api_key, model_name, prompt, seed, aspect_ratio, magic_prompt_option, style_type, negative_prompt, color_palette_name: "https://ideogram.ai/api/images/test-generated-image-url.png",
"upscale_image": lambda api_key, image_url: "https://ideogram.ai/api/images/test-upscaled-image-url.png",
},
test_credentials=TEST_CREDENTIALS,
@@ -212,7 +195,6 @@ class IdeogramModelBlock(Block):
style_type=input_data.style_type.value,
negative_prompt=input_data.negative_prompt,
color_palette_name=input_data.color_palette_name.value,
custom_colors=input_data.custom_color_palette,
)
# Step 2: Upscale the image if requested
@@ -235,7 +217,6 @@ class IdeogramModelBlock(Block):
style_type: str,
negative_prompt: Optional[str],
color_palette_name: str,
custom_colors: Optional[list[str]],
):
url = "https://api.ideogram.ai/generate"
headers = {
@@ -260,11 +241,7 @@ class IdeogramModelBlock(Block):
data["image_request"]["negative_prompt"] = negative_prompt
if color_palette_name != "NONE":
data["color_palette"] = {"name": color_palette_name}
elif custom_colors:
data["color_palette"] = {
"members": [{"color_hex": color} for color in custom_colors]
}
data["image_request"]["color_palette"] = {"name": color_palette_name}
try:
response = requests.post(url, json=data, headers=headers)
@@ -290,7 +267,9 @@ class IdeogramModelBlock(Block):
response = requests.post(
url,
headers=headers,
data={"image_request": "{}"},
data={
"image_request": "{}", # Empty JSON object
},
files=files,
)
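
The removed custom_color_palette branch builds the request body slightly differently from the preset branch. A sketch of both shapes, based only on the assignments visible in this diff (prompt and colors are illustrative):

data = {"image_request": {"prompt": "a lighthouse at dusk"}}

color_palette_name = "NONE"
custom_colors = ["#000030", "#1C0C47", "#9900FF"]

if color_palette_name != "NONE":
    # preset palette, keyed by name
    data["color_palette"] = {"name": color_palette_name}
elif custom_colors:
    # custom palette, one member per hex code (the branch removed in this diff)
    data["color_palette"] = {"members": [{"color_hex": c} for c in custom_colors]}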

View File

@@ -1,556 +0,0 @@
import copy
from datetime import date, time
from typing import Any, Optional
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
from backend.data.model import SchemaField
from backend.util.file import store_media_file
from backend.util.mock import MockObject
from backend.util.settings import Config
from backend.util.text import TextFormatter
from backend.util.type import LongTextType, MediaFileType, ShortTextType
formatter = TextFormatter()
config = Config()
class AgentInputBlock(Block):
"""
This block is used to provide input to the graph.
It takes in a value, name, description, a list of default values, and a flag to limit selection to those defaults.
It outputs the value passed as input.
"""
class Input(BlockSchema):
name: str = SchemaField(description="The name of the input.")
value: Any = SchemaField(
description="The value to be passed as input.",
default=None,
)
title: str | None = SchemaField(
description="The title of the input.", default=None, advanced=True
)
description: str | None = SchemaField(
description="The description of the input.",
default=None,
advanced=True,
)
placeholder_values: list = SchemaField(
description="The placeholder values to be passed as input.",
default_factory=list,
advanced=True,
hidden=True,
)
advanced: bool = SchemaField(
description="Whether to show the input in the advanced section, if the field is not required.",
default=False,
advanced=True,
)
secret: bool = SchemaField(
description="Whether the input should be treated as a secret.",
default=False,
advanced=True,
)
def generate_schema(self):
schema = copy.deepcopy(self.get_field_schema("value"))
if possible_values := self.placeholder_values:
schema["enum"] = possible_values
return schema
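
generate_schema copies the JSON schema of the `value` field and, when placeholder_values are present, constrains it with an enum. Roughly, with an assumed base schema (the exact output of get_field_schema is not shown here):

# Rough illustration of the enum injection above; the base schema shape is assumed.
import copy

base_value_schema = {"type": "string", "title": "Default Value"}
placeholder_values = ["Option A", "Option B", "Option C"]

schema = copy.deepcopy(base_value_schema)
if placeholder_values:
    schema["enum"] = placeholder_values
# -> {"type": "string", "title": "Default Value", "enum": ["Option A", "Option B", "Option C"]}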
class Output(BlockSchema):
result: Any = SchemaField(description="The value passed as input.")
def __init__(self, **kwargs):
super().__init__(
**{
"id": "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
"description": "Base block for user inputs.",
"input_schema": AgentInputBlock.Input,
"output_schema": AgentInputBlock.Output,
"test_input": [
{
"value": "Hello, World!",
"name": "input_1",
"description": "Example test input.",
"placeholder_values": [],
},
{
"value": "Hello, World!",
"name": "input_2",
"description": "Example test input with placeholders.",
"placeholder_values": ["Hello, World!"],
},
],
"test_output": [
("result", "Hello, World!"),
("result", "Hello, World!"),
],
"categories": {BlockCategory.INPUT, BlockCategory.BASIC},
"block_type": BlockType.INPUT,
"static_output": True,
**kwargs,
}
)
def run(self, input_data: Input, *args, **kwargs) -> BlockOutput:
if input_data.value is not None:
yield "result", input_data.value
class AgentOutputBlock(Block):
"""
Records the output of the graph for users to see.
Behavior:
If `format` is provided and the `value` is of a type that can be formatted,
the block attempts to format the recorded_value using the `format`.
If formatting fails or no `format` is provided, the raw `value` is output.
"""
class Input(BlockSchema):
value: Any = SchemaField(
description="The value to be recorded as output.",
default=None,
advanced=False,
)
name: str = SchemaField(description="The name of the output.")
title: str | None = SchemaField(
description="The title of the output.",
default=None,
advanced=True,
)
description: str | None = SchemaField(
description="The description of the output.",
default=None,
advanced=True,
)
format: str = SchemaField(
description="The format string to be used to format the recorded_value. Use Jinja2 syntax.",
default="",
advanced=True,
)
advanced: bool = SchemaField(
description="Whether to treat the output as advanced.",
default=False,
advanced=True,
)
secret: bool = SchemaField(
description="Whether the output should be treated as a secret.",
default=False,
advanced=True,
)
def generate_schema(self):
return self.get_field_schema("value")
class Output(BlockSchema):
output: Any = SchemaField(description="The value recorded as output.")
name: Any = SchemaField(description="The name of the value recorded as output.")
def __init__(self):
super().__init__(
id="363ae599-353e-4804-937e-b2ee3cef3da4",
description="Stores the output of the graph for users to see.",
input_schema=AgentOutputBlock.Input,
output_schema=AgentOutputBlock.Output,
test_input=[
{
"value": "Hello, World!",
"name": "output_1",
"description": "This is a test output.",
"format": "{{ output_1 }}!!",
},
{
"value": "42",
"name": "output_2",
"description": "This is another test output.",
"format": "{{ output_2 }}",
},
{
"value": MockObject(value="!!", key="key"),
"name": "output_3",
"description": "This is a test output with a mock object.",
"format": "{{ output_3 }}",
},
],
test_output=[
("output", "Hello, World!!!"),
("output", "42"),
("output", MockObject(value="!!", key="key")),
],
categories={BlockCategory.OUTPUT, BlockCategory.BASIC},
block_type=BlockType.OUTPUT,
static_output=True,
)
def run(self, input_data: Input, *args, **kwargs) -> BlockOutput:
"""
Attempts to format the recorded_value using the fmt_string if provided.
If formatting fails or no fmt_string is given, returns the original recorded_value.
"""
if input_data.format:
try:
yield "output", formatter.format_string(
input_data.format, {input_data.name: input_data.value}
)
except Exception as e:
yield "output", f"Error: {e}, {input_data.value}"
else:
yield "output", input_data.value
yield "name", input_data.name
class AgentShortTextInputBlock(AgentInputBlock):
class Input(AgentInputBlock.Input):
value: Optional[ShortTextType] = SchemaField(
description="Short text input.",
default=None,
advanced=False,
title="Default Value",
)
class Output(AgentInputBlock.Output):
result: str = SchemaField(description="Short text result.")
def __init__(self):
super().__init__(
id="7fcd3bcb-8e1b-4e69-903d-32d3d4a92158",
description="Block for short text input (single-line).",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentShortTextInputBlock.Input,
output_schema=AgentShortTextInputBlock.Output,
test_input=[
{
"value": "Hello",
"name": "short_text_1",
"description": "Short text example 1",
"placeholder_values": [],
},
{
"value": "Quick test",
"name": "short_text_2",
"description": "Short text example 2",
"placeholder_values": ["Quick test", "Another option"],
},
],
test_output=[
("result", "Hello"),
("result", "Quick test"),
],
)
class AgentLongTextInputBlock(AgentInputBlock):
class Input(AgentInputBlock.Input):
value: Optional[LongTextType] = SchemaField(
description="Long text input (potentially multi-line).",
default=None,
advanced=False,
title="Default Value",
)
class Output(AgentInputBlock.Output):
result: str = SchemaField(description="Long text result.")
def __init__(self):
super().__init__(
id="90a56ffb-7024-4b2b-ab50-e26c5e5ab8ba",
description="Block for long text input (multi-line).",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentLongTextInputBlock.Input,
output_schema=AgentLongTextInputBlock.Output,
test_input=[
{
"value": "Lorem ipsum dolor sit amet...",
"name": "long_text_1",
"description": "Long text example 1",
"placeholder_values": [],
},
{
"value": "Another multiline text input.",
"name": "long_text_2",
"description": "Long text example 2",
"placeholder_values": ["Another multiline text input."],
},
],
test_output=[
("result", "Lorem ipsum dolor sit amet..."),
("result", "Another multiline text input."),
],
)
class AgentNumberInputBlock(AgentInputBlock):
class Input(AgentInputBlock.Input):
value: Optional[int] = SchemaField(
description="Number input.",
default=None,
advanced=False,
title="Default Value",
)
class Output(AgentInputBlock.Output):
result: int = SchemaField(description="Number result.")
def __init__(self):
super().__init__(
id="96dae2bb-97a2-41c2-bd2f-13a3b5a8ea98",
description="Block for number input.",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentNumberInputBlock.Input,
output_schema=AgentNumberInputBlock.Output,
test_input=[
{
"value": 42,
"name": "number_input_1",
"description": "Number example 1",
"placeholder_values": [],
},
{
"value": 314,
"name": "number_input_2",
"description": "Number example 2",
"placeholder_values": [314, 2718],
},
],
test_output=[
("result", 42),
("result", 314),
],
)
class AgentDateInputBlock(AgentInputBlock):
class Input(AgentInputBlock.Input):
value: Optional[date] = SchemaField(
description="Date input (YYYY-MM-DD).",
default=None,
advanced=False,
title="Default Value",
)
class Output(AgentInputBlock.Output):
result: date = SchemaField(description="Date result.")
def __init__(self):
super().__init__(
id="7e198b09-4994-47db-8b4d-952d98241817",
description="Block for date input.",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentDateInputBlock.Input,
output_schema=AgentDateInputBlock.Output,
test_input=[
{
# If your system can parse JSON date strings to date objects
"value": str(date(2025, 3, 19)),
"name": "date_input_1",
"description": "Example date input 1",
},
{
"value": str(date(2023, 12, 31)),
"name": "date_input_2",
"description": "Example date input 2",
},
],
test_output=[
("result", date(2025, 3, 19)),
("result", date(2023, 12, 31)),
],
)
class AgentTimeInputBlock(AgentInputBlock):
class Input(AgentInputBlock.Input):
value: Optional[time] = SchemaField(
description="Time input (HH:MM:SS).",
default=None,
advanced=False,
title="Default Value",
)
class Output(AgentInputBlock.Output):
result: time = SchemaField(description="Time result.")
def __init__(self):
super().__init__(
id="2a1c757e-86cf-4c7e-aacf-060dc382e434",
description="Block for time input.",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentTimeInputBlock.Input,
output_schema=AgentTimeInputBlock.Output,
test_input=[
{
"value": str(time(9, 30, 0)),
"name": "time_input_1",
"description": "Time example 1",
},
{
"value": str(time(23, 59, 59)),
"name": "time_input_2",
"description": "Time example 2",
},
],
test_output=[
("result", time(9, 30, 0)),
("result", time(23, 59, 59)),
],
)
class AgentFileInputBlock(AgentInputBlock):
"""
A simplified file-upload block. In real usage, you might have a custom
file type or handle binary data. Here, we'll store a string path as the example.
"""
class Input(AgentInputBlock.Input):
value: Optional[MediaFileType] = SchemaField(
description="Path or reference to an uploaded file.",
default=None,
advanced=False,
title="Default Value",
)
class Output(AgentInputBlock.Output):
result: str = SchemaField(description="File reference/path result.")
def __init__(self):
super().__init__(
id="95ead23f-8283-4654-aef3-10c053b74a31",
description="Block for file upload input (string path for example).",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentFileInputBlock.Input,
output_schema=AgentFileInputBlock.Output,
test_input=[
{
"value": "data:image/png;base64,MQ==",
"name": "file_upload_1",
"description": "Example file upload 1",
},
],
test_output=[
("result", str),
],
)
def run(
self,
input_data: Input,
*,
graph_exec_id: str,
**kwargs,
) -> BlockOutput:
if not input_data.value:
return
file_path = store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.value,
return_content=False,
)
yield "result", file_path
class AgentDropdownInputBlock(AgentInputBlock):
"""
A specialized text input block that relies on placeholder_values to present a dropdown.
"""
class Input(AgentInputBlock.Input):
value: Optional[str] = SchemaField(
description="Text selected from a dropdown.",
default=None,
advanced=False,
title="Default Value",
)
placeholder_values: list = SchemaField(
description="Possible values for the dropdown.",
default_factory=list,
advanced=False,
title="Dropdown Options",
)
class Output(AgentInputBlock.Output):
result: str = SchemaField(description="Selected dropdown value.")
def __init__(self):
super().__init__(
id="655d6fdf-a334-421c-b733-520549c07cd1",
description="Block for dropdown text selection.",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentDropdownInputBlock.Input,
output_schema=AgentDropdownInputBlock.Output,
test_input=[
{
"value": "Option A",
"name": "dropdown_1",
"placeholder_values": ["Option A", "Option B", "Option C"],
"description": "Dropdown example 1",
},
{
"value": "Option C",
"name": "dropdown_2",
"placeholder_values": ["Option A", "Option B", "Option C"],
"description": "Dropdown example 2",
},
],
test_output=[
("result", "Option A"),
("result", "Option C"),
],
)
class AgentToggleInputBlock(AgentInputBlock):
class Input(AgentInputBlock.Input):
value: bool = SchemaField(
description="Boolean toggle input.",
default=False,
advanced=False,
title="Default Value",
)
class Output(AgentInputBlock.Output):
result: bool = SchemaField(description="Boolean toggle result.")
def __init__(self):
super().__init__(
id="cbf36ab5-df4a-43b6-8a7f-f7ed8652116e",
description="Block for boolean toggle input.",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentToggleInputBlock.Input,
output_schema=AgentToggleInputBlock.Output,
test_input=[
{
"value": True,
"name": "toggle_1",
"description": "Toggle example 1",
},
{
"value": False,
"name": "toggle_2",
"description": "Toggle example 2",
},
],
test_output=[
("result", True),
("result", False),
],
)
IO_BLOCK_IDs = [
AgentInputBlock().id,
AgentOutputBlock().id,
AgentShortTextInputBlock().id,
AgentLongTextInputBlock().id,
AgentNumberInputBlock().id,
AgentDateInputBlock().id,
AgentTimeInputBlock().id,
AgentFileInputBlock().id,
AgentDropdownInputBlock().id,
AgentToggleInputBlock().id,
]

View File

@@ -11,13 +11,13 @@ class StepThroughItemsBlock(Block):
advanced=False,
description="The list or dictionary of items to iterate over",
placeholder="[1, 2, 3, 4, 5] or {'key1': 'value1', 'key2': 'value2'}",
default_factory=list,
default=[],
)
items_object: dict = SchemaField(
advanced=False,
description="The list or dictionary of items to iterate over",
placeholder="[1, 2, 3, 4, 5] or {'key1': 'value1', 'key2': 'value2'}",
default_factory=dict,
default={},
)
items_str: str = SchemaField(
advanced=False,

View File

@@ -23,7 +23,7 @@ class JinaChunkingBlock(Block):
class Output(BlockSchema):
chunks: list = SchemaField(description="List of chunked texts")
tokens: list = SchemaField(
description="List of token information for each chunk",
description="List of token information for each chunk", optional=True
)
def __init__(self):

View File

@@ -1,4 +1,4 @@
from urllib.parse import quote
from groq._utils._utils import quote
from backend.blocks.jina._auth import (
TEST_CREDENTIALS,

View File

@@ -28,8 +28,8 @@ class LinearCreateIssueBlock(Block):
priority: int | None = SchemaField(
description="Priority of the issue",
default=None,
ge=0,
le=4,
minimum=0,
maximum=4,
)
project_name: str | None = SchemaField(
description="Name of the project to create the issue on",

View File

@@ -4,24 +4,27 @@ from abc import ABC
from enum import Enum, EnumMeta
from json import JSONDecodeError
from types import MappingProxyType
from typing import Any, Iterable, List, Literal, NamedTuple, Optional
from typing import TYPE_CHECKING, Any, List, Literal, NamedTuple
from pydantic import SecretStr
from backend.integrations.providers import ProviderName
if TYPE_CHECKING:
from enum import _EnumMemberT
import anthropic
import ollama
import openai
from anthropic.types import ToolParam
from groq import Groq
from pydantic import BaseModel, SecretStr
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.settings import BehaveAs, Settings
from backend.util.text import TextFormatter
@@ -71,10 +74,12 @@ class ModelMetadata(NamedTuple):
class LlmModelMeta(EnumMeta):
@property
def __members__(self) -> MappingProxyType:
def __members__(
self: type["_EnumMemberT"],
) -> MappingProxyType[str, "_EnumMemberT"]:
if Settings().config.behave_as == BehaveAs.LOCAL:
members = super().__members__
return MappingProxyType(members)
return members
else:
removed_providers = ["ollama"]
existing_members = super().__members__
@@ -89,17 +94,14 @@ class LlmModelMeta(EnumMeta):
class LlmModel(str, Enum, metaclass=LlmModelMeta):
# OpenAI models
O3_MINI = "o3-mini"
O3 = "o3-2025-04-16"
O1 = "o1"
O1_PREVIEW = "o1-preview"
O1_MINI = "o1-mini"
GPT41 = "gpt-4.1-2025-04-14"
GPT4O_MINI = "gpt-4o-mini"
GPT4O = "gpt-4o"
GPT4_TURBO = "gpt-4-turbo"
GPT3_5_TURBO = "gpt-3.5-turbo"
# Anthropic models
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest"
CLAUDE_3_5_HAIKU = "claude-3-5-haiku-latest"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
@@ -120,7 +122,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
OLLAMA_DOLPHIN = "dolphin-mistral:latest"
# OpenRouter models
GEMINI_FLASH_1_5 = "google/gemini-flash-1.5"
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
GROK_BETA = "x-ai/grok-beta"
MISTRAL_NEMO = "mistralai/mistral-nemo"
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
@@ -138,8 +139,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
@property
def metadata(self) -> ModelMetadata:
@@ -160,14 +159,12 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
MODEL_METADATA = {
# https://platform.openai.com/docs/models
LlmModel.O3: ModelMetadata("openai", 200000, 100000),
LlmModel.O3_MINI: ModelMetadata("openai", 200000, 100000), # o3-mini-2025-01-31
LlmModel.O1: ModelMetadata("openai", 200000, 100000), # o1-2024-12-17
LlmModel.O1_PREVIEW: ModelMetadata(
"openai", 128000, 32768
), # o1-preview-2024-09-12
LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536), # o1-mini-2024-09-12
LlmModel.GPT41: ModelMetadata("openai", 1047576, 32768),
LlmModel.GPT4O_MINI: ModelMetadata(
"openai", 128000, 16384
), # gpt-4o-mini-2024-07-18
@@ -177,9 +174,6 @@ MODEL_METADATA = {
), # gpt-4-turbo-2024-04-09
LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, 4096), # gpt-3.5-turbo-0125
# https://docs.anthropic.com/en/docs/about-claude/models
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
"anthropic", 200000, 8192
), # claude-3-7-sonnet-20250219
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata(
"anthropic", 200000, 8192
), # claude-3-5-sonnet-20241022
@@ -205,7 +199,6 @@ MODEL_METADATA = {
LlmModel.OLLAMA_DOLPHIN: ModelMetadata("ollama", 32768, None),
# https://openrouter.ai/models
LlmModel.GEMINI_FLASH_1_5: ModelMetadata("open_router", 1000000, 8192),
LlmModel.GEMINI_2_5_PRO: ModelMetadata("open_router", 1050000, 8192),
LlmModel.GROK_BETA: ModelMetadata("open_router", 131072, 131072),
LlmModel.MISTRAL_NEMO: ModelMetadata("open_router", 128000, 4096),
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata("open_router", 128000, 4096),
@@ -227,8 +220,6 @@ MODEL_METADATA = {
LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata("open_router", 300000, 5120),
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata("open_router", 65536, 4096),
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata("open_router", 4096, 4096),
LlmModel.META_LLAMA_4_SCOUT: ModelMetadata("open_router", 131072, 131072),
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata("open_router", 1048576, 1000000),
}
for model in LlmModel:
@@ -236,312 +227,21 @@ for model in LlmModel:
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
class ToolCall(BaseModel):
name: str
arguments: str
class MessageRole(str, Enum):
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
class ToolContentBlock(BaseModel):
id: str
type: str
function: ToolCall
class LLMResponse(BaseModel):
raw_response: Any
prompt: List[Any]
response: str
tool_calls: Optional[List[ToolContentBlock]] | None
prompt_tokens: int
completion_tokens: int
def convert_openai_tool_fmt_to_anthropic(
openai_tools: list[dict] | None = None,
) -> Iterable[ToolParam] | anthropic.NotGiven:
"""
Convert OpenAI tool format to Anthropic tool format.
"""
if not openai_tools or len(openai_tools) == 0:
return anthropic.NOT_GIVEN
anthropic_tools = []
for tool in openai_tools:
if "function" in tool:
# Handle case where tool is already in OpenAI format with "type" and "function"
function_data = tool["function"]
else:
# Handle case where tool is just the function definition
function_data = tool
anthropic_tool: anthropic.types.ToolParam = {
"name": function_data["name"],
"description": function_data.get("description", ""),
"input_schema": {
"type": "object",
"properties": function_data.get("parameters", {}).get("properties", {}),
"required": function_data.get("parameters", {}).get("required", []),
},
}
anthropic_tools.append(anthropic_tool)
return anthropic_tools
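
A small example of the mapping performed above, using a made-up tool definition:

# Illustrative input/output for convert_openai_tool_fmt_to_anthropic (the tool is invented).
openai_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

# Expected Anthropic-style ToolParam after conversion:
anthropic_tool = {
    "name": "get_weather",
    "description": "Look up the weather for a city.",
    "input_schema": {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
}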
def llm_call(
credentials: APIKeyCredentials,
llm_model: LlmModel,
prompt: list[dict],
json_format: bool,
max_tokens: int | None,
tools: list[dict] | None = None,
ollama_host: str = "localhost:11434",
parallel_tool_calls: bool | None = None,
) -> LLMResponse:
"""
Make a call to a language model.
Args:
credentials: The API key credentials to use.
llm_model: The LLM model to use.
prompt: The prompt to send to the LLM.
json_format: Whether the response should be in JSON format.
max_tokens: The maximum number of tokens to generate in the chat completion.
tools: The tools to use in the chat completion.
ollama_host: The host for ollama to use.
Returns:
LLMResponse object containing:
- prompt: The prompt sent to the LLM.
- response: The text response from the LLM.
- tool_calls: Any tool calls the model made, if applicable.
- prompt_tokens: The number of tokens used in the prompt.
- completion_tokens: The number of tokens used in the completion.
"""
provider = llm_model.metadata.provider
max_tokens = max_tokens or llm_model.max_output_tokens or 4096
if provider == "openai":
tools_param = tools if tools else openai.NOT_GIVEN
oai_client = openai.OpenAI(api_key=credentials.api_key.get_secret_value())
response_format = None
if llm_model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
prompt = [
{"role": "user", "content": "\n".join(sys_messages)},
{"role": "user", "content": "\n".join(usr_messages)},
]
elif json_format:
response_format = {"type": "json_object"}
response = oai_client.chat.completions.create(
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_completion_tokens=max_tokens,
tools=tools_param, # type: ignore
parallel_tool_calls=(
openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
),
)
if response.choices[0].message.tool_calls:
tool_calls = [
ToolContentBlock(
id=tool.id,
type=tool.type,
function=ToolCall(
name=tool.function.name,
arguments=tool.function.arguments,
),
)
for tool in response.choices[0].message.tool_calls
]
else:
tool_calls = None
return LLMResponse(
raw_response=response.choices[0].message,
prompt=prompt,
response=response.choices[0].message.content or "",
tool_calls=tool_calls,
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
)
elif provider == "anthropic":
an_tools = convert_openai_tool_fmt_to_anthropic(tools)
system_messages = [p["content"] for p in prompt if p["role"] == "system"]
sysprompt = " ".join(system_messages)
messages = []
last_role = None
for p in prompt:
if p["role"] in ["user", "assistant"]:
if (
p["role"] == last_role
and isinstance(messages[-1]["content"], str)
and isinstance(p["content"], str)
):
# If the role is the same as the last one, combine the content
messages[-1]["content"] += p["content"]
else:
messages.append({"role": p["role"], "content": p["content"]})
last_role = p["role"]
client = anthropic.Anthropic(api_key=credentials.api_key.get_secret_value())
try:
resp = client.messages.create(
model=llm_model.value,
system=sysprompt,
messages=messages,
max_tokens=max_tokens,
tools=an_tools,
)
if not resp.content:
raise ValueError("No content returned from Anthropic.")
tool_calls = None
for content_block in resp.content:
# Anthropic differs from OpenAI: we need to iterate through
# the content blocks to find the tool calls
if content_block.type == "tool_use":
if tool_calls is None:
tool_calls = []
tool_calls.append(
ToolContentBlock(
id=content_block.id,
type=content_block.type,
function=ToolCall(
name=content_block.name,
arguments=json.dumps(content_block.input),
),
)
)
if not tool_calls and resp.stop_reason == "tool_use":
logger.warning(
"Tool use stop reason but no tool calls found in content. %s", resp
)
return LLMResponse(
raw_response=resp,
prompt=prompt,
response=(
resp.content[0].name
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
else getattr(resp.content[0], "text", "")
),
tool_calls=tool_calls,
prompt_tokens=resp.usage.input_tokens,
completion_tokens=resp.usage.output_tokens,
)
except anthropic.APIError as e:
error_message = f"Anthropic API error: {str(e)}"
logger.error(error_message)
raise ValueError(error_message)
elif provider == "groq":
if tools:
raise ValueError("Groq does not support tools.")
client = Groq(api_key=credentials.api_key.get_secret_value())
response_format = {"type": "json_object"} if json_format else None
response = client.chat.completions.create(
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_tokens=max_tokens,
)
return LLMResponse(
raw_response=response.choices[0].message,
prompt=prompt,
response=response.choices[0].message.content or "",
tool_calls=None,
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
)
elif provider == "ollama":
if tools:
raise ValueError("Ollama does not support tools.")
client = ollama.Client(host=ollama_host)
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
response = client.generate(
model=llm_model.value,
prompt=f"{sys_messages}\n\n{usr_messages}",
stream=False,
)
return LLMResponse(
raw_response=response.get("response") or "",
prompt=prompt,
response=response.get("response") or "",
tool_calls=None,
prompt_tokens=response.get("prompt_eval_count") or 0,
completion_tokens=response.get("eval_count") or 0,
)
elif provider == "open_router":
tools_param = tools if tools else openai.NOT_GIVEN
client = openai.OpenAI(
base_url="https://openrouter.ai/api/v1",
api_key=credentials.api_key.get_secret_value(),
)
response = client.chat.completions.create(
extra_headers={
"HTTP-Referer": "https://agpt.co",
"X-Title": "AutoGPT",
},
model=llm_model.value,
messages=prompt, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore
parallel_tool_calls=(
openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
),
)
# If there's no response, raise an error
if not response.choices:
if response:
raise ValueError(f"OpenRouter error: {response}")
else:
raise ValueError("No response from OpenRouter.")
if response.choices[0].message.tool_calls:
tool_calls = [
ToolContentBlock(
id=tool.id,
type=tool.type,
function=ToolCall(
name=tool.function.name, arguments=tool.function.arguments
),
)
for tool in response.choices[0].message.tool_calls
]
else:
tool_calls = None
return LLMResponse(
raw_response=response.choices[0].message,
prompt=prompt,
response=response.choices[0].message.content or "",
tool_calls=tool_calls,
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
)
else:
raise ValueError(f"Unsupported LLM provider: {provider}")
class Message(BlockSchema):
role: MessageRole
content: str
class AIBlockBase(Block, ABC):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.prompt = []
self.prompt = ""
def merge_llm_stats(self, block: "AIBlockBase"):
self.merge_stats(block.execution_stats)
@@ -560,7 +260,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4O,
default=LlmModel.GPT4_TURBO,
description="The language model to use for answering the prompt.",
advanced=False,
)
@@ -570,8 +270,8 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
default="",
description="The system prompt to provide additional context to the model.",
)
conversation_history: list[dict] = SchemaField(
default_factory=list,
conversation_history: list[Message] = SchemaField(
default=[],
description="The conversation history to provide context for the prompt.",
)
retry: int = SchemaField(
@@ -581,7 +281,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
)
prompt_values: dict[str, str] = SchemaField(
advanced=False,
default_factory=dict,
default={},
description="Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}.",
)
max_tokens: int | None = SchemaField(
@@ -600,7 +300,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
response: dict[str, Any] = SchemaField(
description="The response object generated by the language model."
)
prompt: list = SchemaField(description="The prompt sent to the language model.")
prompt: str = SchemaField(description="The prompt sent to the language model.")
error: str = SchemaField(description="Error message if the API call failed.")
def __init__(self):
@@ -611,7 +311,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
input_schema=AIStructuredResponseGeneratorBlock.Input,
output_schema=AIStructuredResponseGeneratorBlock.Output,
test_input={
"model": LlmModel.GPT4O,
"model": LlmModel.GPT4_TURBO,
"credentials": TEST_CREDENTIALS_INPUT,
"expected_format": {
"key1": "value1",
@@ -622,24 +322,22 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
test_credentials=TEST_CREDENTIALS,
test_output=[
("response", {"key1": "key1Value", "key2": "key2Value"}),
("prompt", list),
("prompt", str),
],
test_mock={
"llm_call": lambda *args, **kwargs: LLMResponse(
raw_response="",
prompt=[""],
response=json.dumps(
"llm_call": lambda *args, **kwargs: (
json.dumps(
{
"key1": "key1Value",
"key2": "key2Value",
}
),
tool_calls=None,
prompt_tokens=0,
completion_tokens=0,
0,
0,
)
},
)
self.prompt = ""
def llm_call(
self,
@@ -648,29 +346,160 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
prompt: list[dict],
json_format: bool,
max_tokens: int | None,
tools: list[dict] | None = None,
ollama_host: str = "localhost:11434",
) -> LLMResponse:
) -> tuple[str, int, int]:
"""
Test mocks work only on class methods, so this wraps the module-level llm_call
function so that it can be mocked within the block testing framework.
Args:
credentials: The API key credentials to use.
llm_model: The LLM model to use.
prompt: The prompt to send to the LLM.
json_format: Whether the response should be in JSON format.
max_tokens: The maximum number of tokens to generate in the chat completion.
ollama_host: The host for ollama to use
Returns:
The response from the LLM.
The number of tokens used in the prompt.
The number of tokens used in the completion.
"""
self.prompt = prompt
return llm_call(
credentials=credentials,
llm_model=llm_model,
prompt=prompt,
json_format=json_format,
max_tokens=max_tokens,
tools=tools,
ollama_host=ollama_host,
)
provider = llm_model.metadata.provider
max_tokens = max_tokens or llm_model.max_output_tokens or 4096
if provider == "openai":
oai_client = openai.OpenAI(api_key=credentials.api_key.get_secret_value())
response_format = None
if llm_model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
prompt = [
{"role": "user", "content": "\n".join(sys_messages)},
{"role": "user", "content": "\n".join(usr_messages)},
]
elif json_format:
response_format = {"type": "json_object"}
response = oai_client.chat.completions.create(
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_completion_tokens=max_tokens,
)
self.prompt = json.dumps(prompt)
return (
response.choices[0].message.content or "",
response.usage.prompt_tokens if response.usage else 0,
response.usage.completion_tokens if response.usage else 0,
)
elif provider == "anthropic":
system_messages = [p["content"] for p in prompt if p["role"] == "system"]
sysprompt = " ".join(system_messages)
messages = []
last_role = None
for p in prompt:
if p["role"] in ["user", "assistant"]:
if p["role"] != last_role:
messages.append({"role": p["role"], "content": p["content"]})
last_role = p["role"]
else:
# If the role is the same as the last one, combine the content
messages[-1]["content"] += "\n" + p["content"]
client = anthropic.Anthropic(api_key=credentials.api_key.get_secret_value())
try:
resp = client.messages.create(
model=llm_model.value,
system=sysprompt,
messages=messages,
max_tokens=max_tokens,
)
self.prompt = json.dumps(prompt)
if not resp.content:
raise ValueError("No content returned from Anthropic.")
return (
(
resp.content[0].name
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
else resp.content[0].text
),
resp.usage.input_tokens,
resp.usage.output_tokens,
)
except anthropic.APIError as e:
error_message = f"Anthropic API error: {str(e)}"
logger.error(error_message)
raise ValueError(error_message)
elif provider == "groq":
client = Groq(api_key=credentials.api_key.get_secret_value())
response_format = {"type": "json_object"} if json_format else None
response = client.chat.completions.create(
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_tokens=max_tokens,
)
self.prompt = json.dumps(prompt)
return (
response.choices[0].message.content or "",
response.usage.prompt_tokens if response.usage else 0,
response.usage.completion_tokens if response.usage else 0,
)
elif provider == "ollama":
client = ollama.Client(host=ollama_host)
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
response = client.generate(
model=llm_model.value,
prompt=f"{sys_messages}\n\n{usr_messages}",
stream=False,
)
self.prompt = json.dumps(prompt)
return (
response.get("response") or "",
response.get("prompt_eval_count") or 0,
response.get("eval_count") or 0,
)
elif provider == "open_router":
client = openai.OpenAI(
base_url="https://openrouter.ai/api/v1",
api_key=credentials.api_key.get_secret_value(),
)
response = client.chat.completions.create(
extra_headers={
"HTTP-Referer": "https://agpt.co",
"X-Title": "AutoGPT",
},
model=llm_model.value,
messages=prompt, # type: ignore
max_tokens=max_tokens,
)
self.prompt = json.dumps(prompt)
# If there's no response, raise an error
if not response.choices:
if response:
raise ValueError(f"OpenRouter error: {response}")
else:
raise ValueError("No response from OpenRouter.")
return (
response.choices[0].message.content or "",
response.usage.prompt_tokens if response.usage else 0,
response.usage.completion_tokens if response.usage else 0,
)
else:
raise ValueError(f"Unsupported LLM provider: {provider}")
def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
logger.debug(f"Calling LLM with input data: {input_data}")
prompt = [json.to_dict(p) for p in input_data.conversation_history]
prompt = [p.model_dump() for p in input_data.conversation_history]
def trim_prompt(s: str) -> str:
lines = s.strip().split("\n")
@@ -720,7 +549,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
for retry_count in range(input_data.retry):
try:
llm_response = self.llm_call(
response_text, input_token, output_token = self.llm_call(
credentials=credentials,
llm_model=llm_model,
prompt=prompt,
@@ -728,12 +557,11 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
ollama_host=input_data.ollama_host,
max_tokens=input_data.max_tokens,
)
response_text = llm_response.response
self.merge_stats(
NodeExecutionStats(
input_token_count=llm_response.prompt_tokens,
output_token_count=llm_response.completion_tokens,
)
{
"input_token_count": input_token,
"output_token_count": output_token,
}
)
logger.info(f"LLM attempt-{retry_count} response: {response_text}")
@@ -776,10 +604,10 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
retry_prompt = f"Error calling LLM: {e}"
finally:
self.merge_stats(
NodeExecutionStats(
llm_call_count=retry_count + 1,
llm_retry_count=retry_count,
)
{
"llm_call_count": retry_count + 1,
"llm_retry_count": retry_count,
}
)
raise RuntimeError(retry_prompt)
@@ -793,7 +621,7 @@ class AITextGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4O,
default=LlmModel.GPT4_TURBO,
description="The language model to use for answering the prompt.",
advanced=False,
)
@@ -810,7 +638,7 @@ class AITextGeneratorBlock(AIBlockBase):
)
prompt_values: dict[str, str] = SchemaField(
advanced=False,
default_factory=dict,
default={},
description="Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}.",
)
ollama_host: str = SchemaField(
@@ -828,7 +656,7 @@ class AITextGeneratorBlock(AIBlockBase):
response: str = SchemaField(
description="The response generated by the language model."
)
prompt: list = SchemaField(description="The prompt sent to the language model.")
prompt: str = SchemaField(description="The prompt sent to the language model.")
error: str = SchemaField(description="Error message if the API call failed.")
def __init__(self):
@@ -845,7 +673,7 @@ class AITextGeneratorBlock(AIBlockBase):
test_credentials=TEST_CREDENTIALS,
test_output=[
("response", "Response text"),
("prompt", list),
("prompt", str),
],
test_mock={"llm_call": lambda *args, **kwargs: "Response text"},
)
@@ -864,10 +692,7 @@ class AITextGeneratorBlock(AIBlockBase):
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
object_input_data = AIStructuredResponseGeneratorBlock.Input(
**{
attr: getattr(input_data, attr)
for attr in AITextGeneratorBlock.Input.model_fields
},
**{attr: getattr(input_data, attr) for attr in input_data.model_fields},
expected_format={},
)
yield "response", self.llm_call(object_input_data, credentials)
@@ -889,7 +714,7 @@ class AITextSummarizerBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4O,
default=LlmModel.GPT4_TURBO,
description="The language model to use for summarizing the text.",
)
focus: str = SchemaField(
@@ -924,7 +749,7 @@ class AITextSummarizerBlock(AIBlockBase):
class Output(BlockSchema):
summary: str = SchemaField(description="The final summary of the text.")
prompt: list = SchemaField(description="The prompt sent to the language model.")
prompt: str = SchemaField(description="The prompt sent to the language model.")
error: str = SchemaField(description="Error message if the API call failed.")
def __init__(self):
@@ -941,7 +766,7 @@ class AITextSummarizerBlock(AIBlockBase):
test_credentials=TEST_CREDENTIALS,
test_output=[
("summary", "Final summary of a long text"),
("prompt", list),
("prompt", str),
],
test_mock={
"llm_call": lambda input_data, credentials: (
@@ -1050,18 +875,12 @@ class AITextSummarizerBlock(AIBlockBase):
class AIConversationBlock(AIBlockBase):
class Input(BlockSchema):
prompt: str = SchemaField(
description="The prompt to send to the language model.",
placeholder="Enter your prompt here...",
default="",
advanced=False,
)
messages: List[Any] = SchemaField(
description="List of messages in the conversation.",
messages: List[Message] = SchemaField(
description="List of messages in the conversation.", min_length=1
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4O,
default=LlmModel.GPT4_TURBO,
description="The language model to use for the conversation.",
)
credentials: AICredentials = AICredentialsField()
@@ -1080,7 +899,7 @@ class AIConversationBlock(AIBlockBase):
response: str = SchemaField(
description="The model's response to the conversation."
)
prompt: list = SchemaField(description="The prompt sent to the language model.")
prompt: str = SchemaField(description="The prompt sent to the language model.")
error: str = SchemaField(description="Error message if the API call failed.")
def __init__(self):
@@ -1100,7 +919,7 @@ class AIConversationBlock(AIBlockBase):
},
{"role": "user", "content": "Where was it played?"},
],
"model": LlmModel.GPT4O,
"model": LlmModel.GPT4_TURBO,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
@@ -1109,7 +928,7 @@ class AIConversationBlock(AIBlockBase):
"response",
"The 2020 World Series was played at Globe Life Field in Arlington, Texas.",
),
("prompt", list),
("prompt", str),
],
test_mock={
"llm_call": lambda *args, **kwargs: "The 2020 World Series was played at Globe Life Field in Arlington, Texas."
@@ -1131,7 +950,7 @@ class AIConversationBlock(AIBlockBase):
) -> BlockOutput:
response = self.llm_call(
AIStructuredResponseGeneratorBlock.Input(
prompt=input_data.prompt,
prompt="",
credentials=input_data.credentials,
model=input_data.model,
conversation_history=input_data.messages,
@@ -1162,7 +981,7 @@ class AIListGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4O,
default=LlmModel.GPT4_TURBO,
description="The language model to use for generating the list.",
advanced=True,
)
@@ -1189,7 +1008,7 @@ class AIListGeneratorBlock(AIBlockBase):
list_item: str = SchemaField(
description="Each individual item in the list.",
)
prompt: list = SchemaField(description="The prompt sent to the language model.")
prompt: str = SchemaField(description="The prompt sent to the language model.")
error: str = SchemaField(
description="Error message if the list generation failed."
)
@@ -1211,7 +1030,7 @@ class AIListGeneratorBlock(AIBlockBase):
"drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
"fictional worlds."
),
"model": LlmModel.GPT4O,
"model": LlmModel.GPT4_TURBO,
"credentials": TEST_CREDENTIALS_INPUT,
"max_retries": 3,
},
@@ -1221,7 +1040,7 @@ class AIListGeneratorBlock(AIBlockBase):
"generated_list",
["Zylora Prime", "Kharon-9", "Vortexia", "Oceara", "Draknos"],
),
("prompt", list),
("prompt", str),
("list_item", "Zylora Prime"),
("list_item", "Kharon-9"),
("list_item", "Vortexia"),

View File

@@ -8,13 +8,13 @@ from moviepy.video.io.VideoFileClip import VideoFileClip
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
from backend.util.file import MediaFile, get_exec_file_path, store_media_file
class MediaDurationBlock(Block):
class Input(BlockSchema):
media_in: MediaFileType = SchemaField(
media_in: MediaFile = SchemaField(
description="Media input (URL, data URI, or local path)."
)
is_video: bool = SchemaField(
@@ -69,7 +69,7 @@ class LoopVideoBlock(Block):
"""
class Input(BlockSchema):
video_in: MediaFileType = SchemaField(
video_in: MediaFile = SchemaField(
description="The input video (can be a URL, data URI, or local path)."
)
# Provide EITHER a `duration` or `n_loops` or both. We'll demonstrate `duration`.
@@ -137,7 +137,7 @@ class LoopVideoBlock(Block):
assert isinstance(looped_clip, VideoFileClip)
# 4) Save the looped output
output_filename = MediaFileType(
output_filename = MediaFile(
f"{node_exec_id}_looped_{os.path.basename(local_video_path)}"
)
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
@@ -162,10 +162,10 @@ class AddAudioToVideoBlock(Block):
"""
class Input(BlockSchema):
video_in: MediaFileType = SchemaField(
video_in: MediaFile = SchemaField(
description="Video input (URL, data URI, or local path)."
)
audio_in: MediaFileType = SchemaField(
audio_in: MediaFile = SchemaField(
description="Audio input (URL, data URI, or local path)."
)
volume: float = SchemaField(
@@ -178,7 +178,7 @@ class AddAudioToVideoBlock(Block):
)
class Output(BlockSchema):
video_out: MediaFileType = SchemaField(
video_out: MediaFile = SchemaField(
description="Final video (with attached audio), as a path or data URI."
)
error: str = SchemaField(
@@ -229,7 +229,7 @@ class AddAudioToVideoBlock(Block):
final_clip = video_clip.with_audio(audio_clip)
# 4) Write to output file
output_filename = MediaFileType(
output_filename = MediaFile(
f"{node_exec_id}_audio_attached_{os.path.basename(local_video_path)}"
)
output_abspath = os.path.join(abs_temp_dir, output_filename)

View File

@@ -65,7 +65,7 @@ class AddMemoryBlock(Block, Mem0Base):
default=Content(discriminator="content", content="I'm a vegetarian"),
)
metadata: dict[str, Any] = SchemaField(
description="Optional metadata for the memory", default_factory=dict
description="Optional metadata for the memory", default={}
)
limit_memory_to_run: bool = SchemaField(
@@ -173,7 +173,7 @@ class SearchMemoryBlock(Block, Mem0Base):
)
categories_filter: list[str] = SchemaField(
description="Categories to filter by",
default_factory=list,
default=[],
advanced=True,
)
limit_memory_to_run: bool = SchemaField(

View File

@@ -6,14 +6,13 @@ from backend.blocks.nvidia._auth import (
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import requests
from backend.util.type import MediaFileType
class NvidiaDeepfakeDetectBlock(Block):
class Input(BlockSchema):
credentials: NvidiaCredentialsInput = NvidiaCredentialsField()
image_base64: MediaFileType = SchemaField(
description="Image to analyze for deepfakes",
image_base64: str = SchemaField(
description="Image to analyze for deepfakes", image_upload=True
)
return_image: bool = SchemaField(
description="Whether to return the processed image with markings",
@@ -23,12 +22,16 @@ class NvidiaDeepfakeDetectBlock(Block):
class Output(BlockSchema):
status: str = SchemaField(
description="Detection status (SUCCESS, ERROR, CONTENT_FILTERED)",
default="",
)
image: MediaFileType = SchemaField(
image: str = SchemaField(
description="Processed image with detection markings (if return_image=True)",
default="",
image_output=True,
)
is_deepfake: float = SchemaField(
description="Probability that the image is a deepfake (0-1)",
default=0.0,
)
def __init__(self):

View File

@@ -177,8 +177,7 @@ class PineconeInsertBlock(Block):
description="Namespace to use in Pinecone", default=""
)
metadata: dict = SchemaField(
description="Additional metadata to store with each vector",
default_factory=dict,
description="Additional metadata to store with each vector", default={}
)
class Output(BlockSchema):

View File

@@ -12,7 +12,7 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import MediaFileType, store_media_file
from backend.util.file import MediaFile, store_media_file
from backend.util.request import Requests
@@ -57,7 +57,7 @@ class ScreenshotWebPageBlock(Block):
)
class Output(BlockSchema):
image: MediaFileType = SchemaField(description="The screenshot image data")
image: MediaFile = SchemaField(description="The screenshot image data")
error: str = SchemaField(description="Error message if the screenshot failed")
def __init__(self):
@@ -142,9 +142,7 @@ class ScreenshotWebPageBlock(Block):
return {
"image": store_media_file(
graph_exec_id=graph_exec_id,
file=MediaFileType(
f"data:image/{format.value};base64,{b64encode(response.content).decode('utf-8')}"
),
file=f"data:image/{format.value};base64,{b64encode(response.content).decode('utf-8')}",
return_content=True,
)
}

View File

@@ -8,7 +8,6 @@ from backend.data.block import (
BlockWebhookConfig,
)
from backend.data.model import SchemaField
from backend.integrations.providers import ProviderName
from backend.util import settings
from backend.util.settings import AppEnvironment, BehaveAs
@@ -26,7 +25,7 @@ class Slant3DTriggerBase:
class Input(BlockSchema):
credentials: Slant3DCredentialsInput = Slant3DCredentialsField()
# Webhook URL is handled by the webhook system
payload: dict = SchemaField(hidden=True, default_factory=dict)
payload: dict = SchemaField(hidden=True, default={})
class Output(BlockSchema):
payload: dict = SchemaField(
@@ -83,7 +82,7 @@ class Slant3DOrderWebhookBlock(Slant3DTriggerBase, Block):
input_schema=self.Input,
output_schema=self.Output,
webhook_config=BlockWebhookConfig(
provider=ProviderName.SLANT3D,
provider="slant3d",
webhook_type="orders", # Only one type for now
resource_format="", # No resource format needed
event_filter_input="events",

View File

@@ -1,509 +0,0 @@
import logging
import re
from collections import Counter
from typing import TYPE_CHECKING, Any
from autogpt_libs.utils.cache import thread_cached
import backend.blocks.llm as llm
from backend.blocks.agent import AgentExecutorBlock
from backend.data.block import (
Block,
BlockCategory,
BlockInput,
BlockOutput,
BlockSchema,
BlockType,
)
from backend.data.model import SchemaField
from backend.util import json
if TYPE_CHECKING:
from backend.data.graph import Link, Node
logger = logging.getLogger(__name__)
@thread_cached
def get_database_manager_client():
from backend.executor import DatabaseManager
from backend.util.service import get_service_client
return get_service_client(DatabaseManager)
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
"""
Return a list of tool_call_ids if the entry is a tool request.
Supports both OpenAI and Anthropic formats.
"""
tool_call_ids = []
if entry.get("role") != "assistant":
return tool_call_ids
# OpenAI: check for tool_calls in the entry.
calls = entry.get("tool_calls")
if isinstance(calls, list):
for call in calls:
if tool_id := call.get("id"):
tool_call_ids.append(tool_id)
# Anthropic: check content items for tool_use type.
content = entry.get("content")
if isinstance(content, list):
for item in content:
if item.get("type") != "tool_use":
continue
if tool_id := item.get("id"):
tool_call_ids.append(tool_id)
return tool_call_ids
def _get_tool_responses(entry: dict[str, Any]) -> list[str]:
"""
Return a list of tool_call_ids if the entry is a tool response.
Supports both OpenAI and Anthropics formats.
"""
tool_call_ids: list[str] = []
# OpenAI: a tool response message with role "tool" and key "tool_call_id".
if entry.get("role") == "tool":
if tool_call_id := entry.get("tool_call_id"):
tool_call_ids.append(str(tool_call_id))
# Anthropic: check content items for tool_result type.
if entry.get("role") == "user":
content = entry.get("content")
if isinstance(content, list):
for item in content:
if item.get("type") != "tool_result":
continue
if tool_call_id := item.get("tool_use_id"):
tool_call_ids.append(tool_call_id)
return tool_call_ids
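An illustrative sketch (with made-up IDs) of the two wire formats these helpers recognize:
# Hypothetical entries: the same logical tool call in OpenAI and Anthropic shapes.
openai_request = {"role": "assistant",
                  "tool_calls": [{"id": "call_1", "type": "function",
                                  "function": {"name": "search", "arguments": "{}"}}]}
anthropic_request = {"role": "assistant",
                     "content": [{"type": "tool_use", "id": "toolu_1",
                                  "name": "search", "input": {}}]}
assert _get_tool_requests(openai_request) == ["call_1"]
assert _get_tool_requests(anthropic_request) == ["toolu_1"]

openai_response = {"role": "tool", "tool_call_id": "call_1", "content": "ok"}
anthropic_response = {"role": "user",
                      "content": [{"type": "tool_result",
                                   "tool_use_id": "toolu_1", "content": "ok"}]}
assert _get_tool_responses(openai_response) == ["call_1"]
assert _get_tool_responses(anthropic_response) == ["toolu_1"]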
def _create_tool_response(call_id: str, output: dict[str, Any]) -> dict[str, Any]:
"""
Create a tool response message for either OpenAI or Anthropic,
based on the tool call ID format.
"""
content = output if isinstance(output, str) else json.dumps(output)
# Anthropic format: tool IDs typically start with "toolu_"
if call_id.startswith("toolu_"):
return {
"role": "user",
"type": "message",
"content": [
{"tool_use_id": call_id, "type": "tool_result", "content": content}
],
}
# OpenAI format: tool IDs typically start with "call_".
# Or default fallback (if the tool_id doesn't match any known prefix)
return {"role": "tool", "tool_call_id": call_id, "content": content}
def get_pending_tool_calls(conversation_history: list[Any]) -> dict[str, int]:
"""
Every tool call entry in the conversation history requires a response.
This function returns the pending tool calls that have not generated an output yet.
Returns: dict[str, int] - A dictionary of pending tool call IDs with their counts.
"""
pending_calls = Counter()
for history in conversation_history:
for call_id in _get_tool_requests(history):
pending_calls[call_id] += 1
for call_id in _get_tool_responses(history):
pending_calls[call_id] -= 1
return {call_id: count for call_id, count in pending_calls.items() if count > 0}
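For example, a history containing one unanswered OpenAI-style tool call (hypothetical data) is reported as pending:
history = [
    {"role": "assistant",
     "tool_calls": [{"id": "call_1", "type": "function",
                     "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}}]},
    # no {"role": "tool", "tool_call_id": "call_1", ...} entry yet
]
assert get_pending_tool_calls(history) == {"call_1": 1}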
class SmartDecisionMakerBlock(Block):
"""
A block that uses a language model to make smart decisions based on a given prompt.
"""
class Input(BlockSchema):
prompt: str = SchemaField(
description="The prompt to send to the language model.",
placeholder="Enter your prompt here...",
)
model: llm.LlmModel = SchemaField(
title="LLM Model",
default=llm.LlmModel.GPT4O,
description="The language model to use for answering the prompt.",
advanced=False,
)
credentials: llm.AICredentials = llm.AICredentialsField()
sys_prompt: str = SchemaField(
title="System Prompt",
default="Thinking carefully step by step decide which function to call. "
"Always choose a function call from the list of function signatures, "
"and always provide the complete argument provided with the type "
"matching the required jsonschema signature, no missing argument is allowed. "
"If you have already completed the task objective, you can end the task "
"by providing the end result of your work as a finish message. "
"Only provide EXACTLY one function call, multiple tool calls is strictly prohibited.",
description="The system prompt to provide additional context to the model.",
)
conversation_history: list[dict] = SchemaField(
default_factory=list,
description="The conversation history to provide context for the prompt.",
)
last_tool_output: Any = SchemaField(
default=None,
description="The output of the last tool that was called.",
)
retry: int = SchemaField(
title="Retry Count",
default=3,
description="Number of times to retry the LLM call if the response does not match the expected format.",
)
prompt_values: dict[str, str] = SchemaField(
advanced=False,
default_factory=dict,
description="Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}.",
)
max_tokens: int | None = SchemaField(
advanced=True,
default=None,
description="The maximum number of tokens to generate in the chat completion.",
)
ollama_host: str = SchemaField(
advanced=True,
default="localhost:11434",
description="Ollama host for local models",
)
@classmethod
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
# conversation_history & last_tool_output validation is handled differently
missing_links = super().get_missing_links(
data,
[
link
for link in links
if link.sink_name
not in ["conversation_history", "last_tool_output"]
],
)
# Avoid executing the block if the last_tool_output is connected to a static
# link, like StoreValueBlock or AgentInputBlock.
if any(link.sink_name == "conversation_history" for link in links) and any(
link.sink_name == "last_tool_output" and link.is_static
for link in links
):
raise ValueError(
"Last Tool Output can't be connected to a static (dashed line) "
"link like the output of `StoreValue` or `AgentInput` block"
)
return missing_links
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
if missing_input := super().get_missing_input(data):
return missing_input
conversation_history = data.get("conversation_history", [])
pending_tool_calls = get_pending_tool_calls(conversation_history)
last_tool_output = data.get("last_tool_output")
if not last_tool_output and pending_tool_calls:
return {"last_tool_output"}
return set()
class Output(BlockSchema):
error: str = SchemaField(description="Error message if the API call failed.")
tools: Any = SchemaField(description="The tools that are available to use.")
finished: str = SchemaField(
description="The finished message to display to the user."
)
conversations: list[Any] = SchemaField(
description="The conversation history to provide context for the prompt."
)
def __init__(self):
super().__init__(
id="3b191d9f-356f-482d-8238-ba04b6d18381",
description="Uses AI to intelligently decide what tool to use.",
categories={BlockCategory.AI},
block_type=BlockType.AI,
input_schema=SmartDecisionMakerBlock.Input,
output_schema=SmartDecisionMakerBlock.Output,
test_input={
"prompt": "Hello, World!",
"credentials": llm.TEST_CREDENTIALS_INPUT,
},
test_output=[],
test_credentials=llm.TEST_CREDENTIALS,
)
@staticmethod
def _create_block_function_signature(
sink_node: "Node", links: list["Link"]
) -> dict[str, Any]:
"""
Creates a function signature for a block node.
Args:
sink_node: The node for which to create a function signature.
links: The list of links connected to the sink node.
Returns:
A dictionary representing the function signature in the format expected by LLM tools.
Raises:
ValueError: If the block specified by sink_node.block_id is not found.
"""
block = sink_node.block
tool_function: dict[str, Any] = {
"name": re.sub(r"[^a-zA-Z0-9_-]", "_", block.name).lower(),
"description": block.description,
}
properties = {}
required = []
for link in links:
sink_block_input_schema = block.input_schema
description = (
sink_block_input_schema.model_fields[link.sink_name].description
if link.sink_name in sink_block_input_schema.model_fields
and sink_block_input_schema.model_fields[link.sink_name].description
else f"The {link.sink_name} of the tool"
)
properties[link.sink_name.lower()] = {
"type": "string",
"description": description,
}
tool_function["parameters"] = {
"type": "object",
"properties": properties,
"required": required,
"additionalProperties": False,
"strict": True,
}
return {"type": "function", "function": tool_function}
@staticmethod
def _create_agent_function_signature(
sink_node: "Node", links: list["Link"]
) -> dict[str, Any]:
"""
Creates a function signature for an agent node.
Args:
sink_node: The agent node for which to create a function signature.
links: The list of links connected to the sink node.
Returns:
A dictionary representing the function signature in the format expected by LLM tools.
Raises:
ValueError: If the graph metadata for the specified graph_id and graph_version is not found.
"""
graph_id = sink_node.input_default.get("graph_id")
graph_version = sink_node.input_default.get("graph_version")
if not graph_id or not graph_version:
raise ValueError("Graph ID or Graph Version not found in sink node.")
db_client = get_database_manager_client()
sink_graph_meta = db_client.get_graph_metadata(graph_id, graph_version)
if not sink_graph_meta:
raise ValueError(
f"Sink graph metadata not found: {graph_id} {graph_version}"
)
tool_function: dict[str, Any] = {
"name": re.sub(r"[^a-zA-Z0-9_-]", "_", sink_graph_meta.name).lower(),
"description": sink_graph_meta.description,
}
properties = {}
required = []
for link in links:
sink_block_input_schema = sink_node.input_default["input_schema"]
description = (
sink_block_input_schema["properties"][link.sink_name]["description"]
if "description"
in sink_block_input_schema["properties"][link.sink_name]
else f"The {link.sink_name} of the tool"
)
properties[link.sink_name.lower()] = {
"type": "string",
"description": description,
}
tool_function["parameters"] = {
"type": "object",
"properties": properties,
"required": required,
"additionalProperties": False,
"strict": True,
}
return {"type": "function", "function": tool_function}
@staticmethod
def _create_function_signature(node_id: str) -> list[dict[str, Any]]:
"""
Creates function signatures for tools linked to a specified node within a graph.
This method filters the graph links to identify those that are tools and are
connected to the given node_id. It then constructs function signatures for each
tool based on the metadata and input schema of the linked nodes.
Args:
node_id: The node_id for which to create function signatures.
Returns:
list[dict[str, Any]]: A list of dictionaries, each representing a function signature
for a tool, including its name, description, and parameters.
Raises:
ValueError: If no tool links are found for the specified node_id, or if a sink node
or its metadata cannot be found.
"""
db_client = get_database_manager_client()
tools = [
(link, node)
for link, node in db_client.get_connected_output_nodes(node_id)
if link.source_name.startswith("tools_^_") and link.source_id == node_id
]
if not tools:
raise ValueError("There is no next node to execute.")
return_tool_functions = []
grouped_tool_links: dict[str, tuple["Node", list["Link"]]] = {}
for link, node in tools:
if link.sink_id not in grouped_tool_links:
grouped_tool_links[link.sink_id] = (node, [link])
else:
grouped_tool_links[link.sink_id][1].append(link)
for sink_node, links in grouped_tool_links.values():
if not sink_node:
raise ValueError(f"Sink node not found: {links[0].sink_id}")
if sink_node.block_id == AgentExecutorBlock().id:
return_tool_functions.append(
SmartDecisionMakerBlock._create_agent_function_signature(
sink_node, links
)
)
else:
return_tool_functions.append(
SmartDecisionMakerBlock._create_block_function_signature(
sink_node, links
)
)
return return_tool_functions
def run(
self,
input_data: Input,
*,
credentials: llm.APIKeyCredentials,
graph_id: str,
node_id: str,
graph_exec_id: str,
node_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
tool_functions = self._create_function_signature(node_id)
input_data.conversation_history = input_data.conversation_history or []
prompt = [json.to_dict(p) for p in input_data.conversation_history if p]
pending_tool_calls = get_pending_tool_calls(input_data.conversation_history)
if pending_tool_calls and not input_data.last_tool_output:
raise ValueError(f"Tool call requires an output for {pending_tool_calls}")
# Prefill all missing tool calls with the last tool output.
# TODO: we need a better way to handle this.
tool_output = [
_create_tool_response(pending_call_id, input_data.last_tool_output)
for pending_call_id, count in pending_tool_calls.items()
for _ in range(count)
]
# If the SDM block only calls 1 tool at a time, this should not happen.
if len(tool_output) > 1:
logger.warning(
f"[SmartDecisionMakerBlock-node_exec_id={node_exec_id}] "
f"Multiple pending tool calls are prefilled using a single output. "
f"Execution may not be accurate."
)
# Fall back to adding the tool output to the conversation history as a user prompt.
if len(tool_output) == 0 and input_data.last_tool_output:
logger.warning(
f"[SmartDecisionMakerBlock-node_exec_id={node_exec_id}] "
f"No pending tool calls found. This may indicate an issue with the "
f"conversation history, or an LLM calling two tools at the same time."
)
tool_output.append(
{
"role": "user",
"content": f"Last tool output: {json.dumps(input_data.last_tool_output)}",
}
)
prompt.extend(tool_output)
values = input_data.prompt_values
if values:
input_data.prompt = llm.fmt.format_string(input_data.prompt, values)
input_data.sys_prompt = llm.fmt.format_string(input_data.sys_prompt, values)
prefix = "[Main Objective Prompt]: "
if input_data.sys_prompt and not any(
p["role"] == "system" and p["content"].startswith(prefix) for p in prompt
):
prompt.append({"role": "system", "content": prefix + input_data.sys_prompt})
if input_data.prompt and not any(
p["role"] == "user" and p["content"].startswith(prefix) for p in prompt
):
prompt.append({"role": "user", "content": prefix + input_data.prompt})
response = llm.llm_call(
credentials=credentials,
llm_model=input_data.model,
prompt=prompt,
json_format=False,
max_tokens=input_data.max_tokens,
tools=tool_functions,
ollama_host=input_data.ollama_host,
parallel_tool_calls=False,
)
if not response.tool_calls:
yield "finished", response.response
return
for tool_call in response.tool_calls:
tool_name = tool_call.function.name
tool_args = json.loads(tool_call.function.arguments)
for arg_name, arg_value in tool_args.items():
yield f"tools_^_{tool_name}_{arg_name}".lower(), arg_value
response.prompt.append(response.raw_response)
yield "conversations", response.prompt

View File

@@ -112,7 +112,7 @@ class AddLeadToCampaignBlock(Block):
lead_list: list[LeadInput] = SchemaField(
description="An array of JSON objects, each representing a lead's details. Can hold max 100 leads.",
max_length=100,
default_factory=list,
default=[],
advanced=False,
)
settings: LeadUploadSettings = SchemaField(
@@ -248,7 +248,7 @@ class SaveCampaignSequencesBlock(Block):
)
sequences: list[Sequence] = SchemaField(
description="The sequences to save",
default_factory=list,
default=[],
advanced=False,
)
credentials: SmartLeadCredentialsInput = SchemaField(

View File

@@ -39,7 +39,7 @@ class LeadCustomFields(BaseModel):
fields: dict[str, str] = SchemaField(
description="Custom fields for a lead (max 20 fields)",
max_length=20,
default_factory=dict,
default={},
)
@@ -85,7 +85,7 @@ class AddLeadsRequest(BaseModel):
lead_list: list[LeadInput] = SchemaField(
description="List of leads to add to the campaign",
max_length=100,
default_factory=list,
default=[],
)
settings: LeadUploadSettings
campaign_id: int

View File

@@ -156,10 +156,6 @@ class CountdownTimerBlock(Block):
days: Union[int, str] = SchemaField(
advanced=False, description="Duration in days", default=0
)
repeat: int = SchemaField(
description="Number of times to repeat the timer",
default=1,
)
class Output(BlockSchema):
output_message: Any = SchemaField(
@@ -191,6 +187,5 @@ class CountdownTimerBlock(Block):
total_seconds = seconds + minutes * 60 + hours * 3600 + days * 86400
for _ in range(input_data.repeat):
time.sleep(total_seconds)
yield "output_message", input_data.input_message
time.sleep(total_seconds)
yield "output_message", input_data.input_message

View File

@@ -156,7 +156,7 @@
# participant_ids: list[str] = SchemaField(
# description="Array of User IDs to create conversation with (max 50)",
# placeholder="Enter participant user IDs",
# default_factory=list,
# default=[],
# advanced=False
# )

View File

@@ -39,6 +39,7 @@ class TwitterGetListBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to lookup",
placeholder="Enter list ID",
required=True,
)
class Output(BlockSchema):
@@ -183,6 +184,7 @@ class TwitterGetOwnedListsBlock(Block):
user_id: str = SchemaField(
description="The user ID whose owned Lists to retrieve",
placeholder="Enter user ID",
required=True,
)
max_results: int | None = SchemaField(

View File

@@ -45,11 +45,13 @@ class TwitterRemoveListMemberBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to remove the member from",
placeholder="Enter list ID",
required=True,
)
user_id: str = SchemaField(
description="The ID of the user to remove from the List",
placeholder="Enter user ID to remove",
required=True,
)
class Output(BlockSchema):
@@ -118,11 +120,13 @@ class TwitterAddListMemberBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to add the member to",
placeholder="Enter list ID",
required=True,
)
user_id: str = SchemaField(
description="The ID of the user to add to the List",
placeholder="Enter user ID to add",
required=True,
)
class Output(BlockSchema):
@@ -191,6 +195,7 @@ class TwitterGetListMembersBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to get members from",
placeholder="Enter list ID",
required=True,
)
max_results: int | None = SchemaField(
@@ -371,6 +376,7 @@ class TwitterGetListMembershipsBlock(Block):
user_id: str = SchemaField(
description="The ID of the user whose List memberships to retrieve",
placeholder="Enter user ID",
required=True,
)
max_results: int | None = SchemaField(

View File

@@ -42,6 +42,7 @@ class TwitterGetListTweetsBlock(Block):
list_id: str = SchemaField(
description="The ID of the List whose Tweets you would like to retrieve",
placeholder="Enter list ID",
required=True,
)
max_results: int | None = SchemaField(

View File

@@ -28,6 +28,7 @@ class TwitterDeleteListBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to be deleted",
placeholder="Enter list ID",
required=True,
)
class Output(BlockSchema):

View File

@@ -39,6 +39,7 @@ class TwitterUnpinListBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to unpin",
placeholder="Enter list ID",
required=True,
)
class Output(BlockSchema):
@@ -102,6 +103,7 @@ class TwitterPinListBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to pin",
placeholder="Enter list ID",
required=True,
)
class Output(BlockSchema):

View File

@@ -44,7 +44,7 @@ class SpaceList(BaseModel):
space_ids: list[str] = SchemaField(
description="List of Space IDs to lookup (up to 100)",
placeholder="Enter Space IDs",
default_factory=list,
default=[],
advanced=False,
)
@@ -54,7 +54,7 @@ class UserList(BaseModel):
user_ids: list[str] = SchemaField(
description="List of user IDs to lookup their Spaces (up to 100)",
placeholder="Enter user IDs",
default_factory=list,
default=[],
advanced=False,
)
@@ -227,6 +227,7 @@ class TwitterGetSpaceByIdBlock(Block):
space_id: str = SchemaField(
description="Space ID to lookup",
placeholder="Enter Space ID",
required=True,
)
class Output(BlockSchema):
@@ -388,6 +389,7 @@ class TwitterGetSpaceBuyersBlock(Block):
space_id: str = SchemaField(
description="Space ID to lookup buyers for",
placeholder="Enter Space ID",
required=True,
)
class Output(BlockSchema):
@@ -515,6 +517,7 @@ class TwitterGetSpaceTweetsBlock(Block):
space_id: str = SchemaField(
description="Space ID to lookup tweets for",
placeholder="Enter Space ID",
required=True,
)
class Output(BlockSchema):

View File

@@ -200,7 +200,7 @@ class UserIdList(BaseModel):
user_ids: list[str] = SchemaField(
description="List of user IDs to lookup (max 100)",
placeholder="Enter user IDs",
default_factory=list,
default=[],
advanced=False,
)
@@ -210,7 +210,7 @@ class UsernameList(BaseModel):
usernames: list[str] = SchemaField(
description="List of Twitter usernames/handles to lookup (max 100)",
placeholder="Enter usernames",
default_factory=list,
default=[],
advanced=False,
)

View File

@@ -8,6 +8,7 @@ import pathlib
import click
import psutil
from backend import app
from backend.util.process import AppProcess
@@ -41,13 +42,8 @@ def write_pid(pid: int):
class MainApp(AppProcess):
def run(self):
from backend import app
app.main(silent=True)
def cleanup(self):
pass
@click.group()
def main():
@@ -224,8 +220,9 @@ def event():
@test.command()
@click.argument("server_address")
@click.argument("graph_exec_id")
def websocket(server_address: str, graph_exec_id: str):
@click.argument("graph_id")
@click.argument("graph_version")
def websocket(server_address: str, graph_id: str, graph_version: int):
"""
Tests the websocket connection.
"""
@@ -233,20 +230,16 @@ def websocket(server_address: str, graph_exec_id: str):
import websockets.asyncio.client
from backend.server.ws_api import (
WSMessage,
WSMethod,
WSSubscribeGraphExecutionRequest,
)
from backend.server.ws_api import ExecutionSubscription, Methods, WsMessage
async def send_message(server_address: str):
uri = f"ws://{server_address}"
async with websockets.asyncio.client.connect(uri) as websocket:
try:
msg = WSMessage(
method=WSMethod.SUBSCRIBE_GRAPH_EXEC,
data=WSSubscribeGraphExecutionRequest(
graph_exec_id=graph_exec_id,
msg = WsMessage(
method=Methods.SUBSCRIBE,
data=ExecutionSubscription(
graph_id=graph_id, graph_version=graph_version
).model_dump(),
).model_dump_json()
await websocket.send(msg)

View File

@@ -12,12 +12,12 @@ async def log_raw_analytics(
data_index: str,
):
details = await prisma.models.AnalyticsDetails.prisma().create(
data=prisma.types.AnalyticsDetailsCreateInput(
userId=user_id,
type=type,
data=prisma.Json(data),
dataIndex=data_index,
)
data={
"userId": user_id,
"type": type,
"data": prisma.Json(data),
"dataIndex": data_index,
}
)
return details
@@ -32,12 +32,12 @@ async def log_raw_metric(
raise ValueError("metric_value must be non-negative")
result = await prisma.models.AnalyticsMetrics.prisma().create(
data=prisma.types.AnalyticsMetricsCreateInput(
value=metric_value,
analyticMetric=metric_name,
userId=user_id,
dataString=data_string,
)
data={
"value": metric_value,
"analyticMetric": metric_name,
"userId": user_id,
"dataString": data_string,
},
)
return result

View File

@@ -2,7 +2,6 @@ import inspect
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Generator,
@@ -17,25 +16,18 @@ from typing import (
import jsonref
import jsonschema
from prisma.models import AgentBlock
from prisma.types import AgentBlockCreateInput
from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.settings import Config
from .model import (
ContributorDetails,
Credentials,
CredentialsFieldInfo,
CredentialsMetaInput,
is_credentials_field_name,
)
if TYPE_CHECKING:
from .graph import Link
app_config = Config()
BlockData = tuple[str, Any] # Input & Output data should be a tuple of (name, data).
@@ -52,7 +44,6 @@ class BlockType(Enum):
WEBHOOK = "Webhook"
WEBHOOK_MANUAL = "Webhook (manual)"
AGENT = "Agent"
AI = "AI"
class BlockCategory(Enum):
@@ -118,30 +109,21 @@ class BlockSchema(BaseModel):
def validate_data(cls, data: BlockInput) -> str | None:
return json.validate_with_jsonschema(schema=cls.jsonschema(), data=data)
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
return cls.validate_data(data)
@classmethod
def get_field_schema(cls, field_name: str) -> dict[str, Any]:
model_schema = cls.jsonschema().get("properties", {})
if not model_schema:
raise ValueError(f"Invalid model schema {cls}")
property_schema = model_schema.get(field_name)
if not property_schema:
raise ValueError(f"Invalid property name {field_name}")
return property_schema
@classmethod
def validate_field(cls, field_name: str, data: BlockInput) -> str | None:
"""
Validate the data against a specific property (one of the input/output name).
Returns the validation error message if the data does not match the schema.
"""
model_schema = cls.jsonschema().get("properties", {})
if not model_schema:
return f"Invalid model schema {cls}"
property_schema = model_schema.get(field_name)
if not property_schema:
return f"Invalid property name {field_name}"
try:
property_schema = cls.get_field_schema(field_name)
jsonschema.validate(json.to_dict(data), property_schema)
return None
except jsonschema.ValidationError as e:
@@ -204,28 +186,6 @@ class BlockSchema(BaseModel):
)
}
@classmethod
def get_credentials_fields_info(cls) -> dict[str, CredentialsFieldInfo]:
return {
field_name: CredentialsFieldInfo.model_validate(
cls.get_field_schema(field_name), by_alias=True
)
for field_name in cls.get_credentials_fields().keys()
}
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
return data # Return as is, by default.
@classmethod
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
input_fields_from_nodes = {link.sink_name for link in links}
return input_fields_from_nodes - set(data)
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
return cls.get_required_fields() - set(data)
BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchema)
BlockSchemaOutputType = TypeVar("BlockSchemaOutputType", bound=BlockSchema)
@@ -242,7 +202,7 @@ class BlockManualWebhookConfig(BaseModel):
the user has to manually set up the webhook at the provider.
"""
provider: ProviderName
provider: str
"""The service provider that the webhook connects to"""
webhook_type: str
@@ -334,7 +294,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
self.static_output = static_output
self.block_type = block_type
self.webhook_config = webhook_config
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
self.execution_stats = {}
if self.webhook_config:
if isinstance(self.webhook_config, BlockWebhookConfig):
@@ -391,14 +351,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
Run the block with the given input data.
Args:
input_data: The input data with the structure of input_schema.
Kwargs: Currently (as of 2025-02-14) these include
graph_id: The ID of the graph.
node_id: The ID of the node.
graph_exec_id: The ID of the graph execution.
node_exec_id: The ID of the node execution.
user_id: The ID of the user.
Returns:
A Generator that yields (output_name, output_data).
output_name: One of the output name defined in Block's output_schema.
@@ -412,29 +364,18 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
return data
raise ValueError(f"{self.name} did not produce any output for {output}")
def merge_stats(self, stats: NodeExecutionStats) -> NodeExecutionStats:
stats_dict = stats.model_dump()
current_stats = self.execution_stats.model_dump()
for key, value in stats_dict.items():
if key not in current_stats:
# Field doesn't exist yet, so just set it. This should rarely happen,
# but we handle it just in case; invalid values will raise when
# converting back into NodeExecutionStats below.
current_stats[key] = value
elif isinstance(value, dict) and isinstance(current_stats[key], dict):
current_stats[key].update(value)
elif isinstance(value, (int, float)) and isinstance(
current_stats[key], (int, float)
):
current_stats[key] += value
elif isinstance(value, list) and isinstance(current_stats[key], list):
current_stats[key].extend(value)
def merge_stats(self, stats: dict[str, Any]) -> dict[str, Any]:
for key, value in stats.items():
if isinstance(value, dict):
self.execution_stats.setdefault(key, {}).update(value)
elif isinstance(value, (int, float)):
self.execution_stats.setdefault(key, 0)
self.execution_stats[key] += value
elif isinstance(value, list):
self.execution_stats.setdefault(key, [])
self.execution_stats[key].extend(value)
else:
current_stats[key] = value
self.execution_stats = NodeExecutionStats(**current_stats)
self.execution_stats[key] = value
return self.execution_stats
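A small sketch of how the dict-based variant directly above accumulates stats (assuming `block` is any Block instance):
block.execution_stats = {}
block.merge_stats({"input_token_count": 10, "llm_call_count": 1})
block.merge_stats({"input_token_count": 7, "llm_call_count": 1, "errors": ["timeout"]})
# block.execution_stats == {"input_token_count": 17, "llm_call_count": 2, "errors": ["timeout"]}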
@property
@@ -457,6 +398,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
}
def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
# Merge the input data with the extra execution arguments, preferring the args for security
if error := self.input_schema.validate_data(input_data):
raise ValueError(
f"Unable to execute block with invalid input data: {error}"
@@ -478,9 +420,9 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
def get_blocks() -> dict[str, Type[Block]]:
from backend.blocks import load_all_blocks
from backend.blocks import AVAILABLE_BLOCKS # noqa: E402
return load_all_blocks()
return AVAILABLE_BLOCKS
async def initialize_blocks() -> None:
@@ -491,12 +433,12 @@ async def initialize_blocks() -> None:
)
if not existing_block:
await AgentBlock.prisma().create(
data=AgentBlockCreateInput(
id=block.id,
name=block.name,
inputSchema=json.dumps(block.input_schema.jsonschema()),
outputSchema=json.dumps(block.output_schema.jsonschema()),
)
data={
"id": block.id,
"name": block.name,
"inputSchema": json.dumps(block.input_schema.jsonschema()),
"outputSchema": json.dumps(block.output_schema.jsonschema()),
}
)
continue
@@ -519,7 +461,6 @@ async def initialize_blocks() -> None:
)
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
def get_block(block_id: str) -> Block[BlockSchema, BlockSchema] | None:
def get_block(block_id: str) -> Block | None:
cls = get_blocks().get(block_id)
return cls() if cls else None

View File

@@ -15,7 +15,6 @@ from backend.blocks.llm import (
LlmModel,
)
from backend.blocks.replicate_flux_advanced import ReplicateFluxAdvancedModelBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
from backend.data.block import Block
@@ -36,17 +35,14 @@ from backend.integrations.credentials_store import (
# =============== Configure the cost for each LLM Model call =============== #
MODEL_COST: dict[LlmModel, int] = {
LlmModel.O3: 7,
LlmModel.O3_MINI: 2, # $1.10 / $4.40
LlmModel.O1: 16, # $15 / $60
LlmModel.O1_PREVIEW: 16,
LlmModel.O1_MINI: 4,
LlmModel.GPT41: 2,
LlmModel.GPT4O_MINI: 1,
LlmModel.GPT4O: 3,
LlmModel.GPT4_TURBO: 10,
LlmModel.GPT3_5_TURBO: 1,
LlmModel.CLAUDE_3_7_SONNET: 5,
LlmModel.CLAUDE_3_5_SONNET: 4,
LlmModel.CLAUDE_3_5_HAIKU: 1, # $0.80 / $4.00
LlmModel.CLAUDE_3_HAIKU: 1,
@@ -63,7 +59,6 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.DEEPSEEK_LLAMA_70B: 1, # ? / ?
LlmModel.OLLAMA_DOLPHIN: 1,
LlmModel.GEMINI_FLASH_1_5: 1,
LlmModel.GEMINI_2_5_PRO: 4,
LlmModel.GROK_BETA: 5,
LlmModel.MISTRAL_NEMO: 1,
LlmModel.COHERE_COMMAND_R_08_2024: 1,
@@ -79,8 +74,6 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.AMAZON_NOVA_PRO_V1: 1,
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: 1,
LlmModel.GRYPHE_MYTHOMAX_L2_13B: 1,
LlmModel.META_LLAMA_4_SCOUT: 1,
LlmModel.META_LLAMA_4_MAVERICK: 1,
}
for model in LlmModel:
@@ -272,5 +265,4 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
},
)
],
SmartDecisionMakerBlock: LLM_COST,
}

View File

@@ -11,20 +11,18 @@ from prisma.enums import (
CreditRefundRequestStatus,
CreditTransactionType,
NotificationType,
OnboardingStep,
)
from prisma.errors import UniqueViolationError
from prisma.models import CreditRefundRequest, CreditTransaction, User
from prisma.types import (
CreditRefundRequestCreateInput,
CreditTransactionCreateInput,
CreditTransactionWhereInput,
)
from prisma.types import CreditTransactionCreateInput, CreditTransactionWhereInput
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_exponential
from backend.data import db
from backend.data.block import Block, BlockInput, get_block
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.cost import BlockCost
from backend.data.cost import BlockCost, BlockCostType
from backend.data.execution import NodeExecutionEntry
from backend.data.model import (
AutoTopUpConfig,
RefundRequest,
@@ -33,16 +31,13 @@ from backend.data.model import (
)
from backend.data.notifications import NotificationEventDTO, RefundRequestData
from backend.data.user import get_user_by_id
from backend.executor.utils import UsageTransactionMetadata
from backend.notifications import NotificationManager
from backend.util.exceptions import InsufficientBalanceError
from backend.util.service import get_service_client
from backend.util.settings import Settings
settings = Settings()
stripe.api_key = settings.secrets.stripe_api_key
logger = logging.getLogger(__name__)
base_url = settings.config.frontend_base_url or settings.config.platform_base_url
class UserCreditBase(ABC):
@@ -94,20 +89,20 @@ class UserCreditBase(ABC):
@abstractmethod
async def spend_credits(
self,
user_id: str,
cost: int,
metadata: UsageTransactionMetadata,
entry: NodeExecutionEntry,
data_size: float,
run_time: float,
) -> int:
"""
Spend the credits for the user based on the cost.
Spend the credits for the user based on the block usage.
Args:
user_id (str): The user ID.
cost (int): The cost to spend.
metadata (UsageTransactionMetadata): The metadata of the transaction.
entry (NodeExecutionEntry): The node execution identifiers & data.
data_size (float): The size of the data being processed.
run_time (float): The time taken to run the block.
Returns:
int: The remaining balance.
int: amount of credit spent
"""
pass
@@ -122,18 +117,6 @@ class UserCreditBase(ABC):
"""
pass
@abstractmethod
async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
"""
Reward the user with credits for completing an onboarding step.
Won't reward if the user has already received credits for the step.
Args:
user_id (str): The user ID.
step (OnboardingStep): The onboarding step.
"""
pass
@abstractmethod
async def top_up_intent(self, user_id: str, amount: int) -> str:
"""
@@ -201,14 +184,6 @@ class UserCreditBase(ABC):
"""
pass
@staticmethod
async def create_billing_portal_session(user_id: str) -> str:
session = stripe.billing_portal.Session.create(
customer=await get_stripe_customer_id(user_id),
return_url=base_url + "/profile/credits",
)
return session.url
@staticmethod
def time_now() -> datetime:
return datetime.now(timezone.utc)
@@ -226,7 +201,7 @@ class UserCreditBase(ABC):
"userId": user_id,
"createdAt": {"lte": top_time},
"isActive": True,
"NOT": [{"runningBalance": None}],
"runningBalance": {"not": None}, # type: ignore
},
order={"createdAt": "desc"},
)
@@ -274,6 +249,7 @@ class UserCreditBase(ABC):
metadata: Json,
new_transaction_key: str | None = None,
):
transaction = await CreditTransaction.prisma().find_first_or_raise(
where={"transactionKey": transaction_key, "userId": user_id}
)
@@ -338,32 +314,39 @@ class UserCreditBase(ABC):
if amount < 0 and user_balance + amount < 0:
if fail_insufficient_credits:
raise InsufficientBalanceError(
message=f"Insufficient balance of ${user_balance/100}, where this will cost ${abs(amount)/100}",
user_id=user_id,
balance=user_balance,
amount=amount,
raise ValueError(
f"Insufficient balance of ${user_balance/100}, where this will cost ${abs(amount)/100}"
)
amount = min(-user_balance, 0)
# Create the transaction
transaction_data = CreditTransactionCreateInput(
userId=user_id,
amount=amount,
runningBalance=user_balance + amount,
type=transaction_type,
metadata=metadata,
isActive=is_active,
createdAt=self.time_now(),
)
transaction_data: CreditTransactionCreateInput = {
"userId": user_id,
"amount": amount,
"runningBalance": user_balance + amount,
"type": transaction_type,
"metadata": metadata,
"isActive": is_active,
"createdAt": self.time_now(),
}
if transaction_key:
transaction_data["transactionKey"] = transaction_key
tx = await CreditTransaction.prisma().create(data=transaction_data)
return user_balance + amount, tx.transactionKey
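In other words (a hypothetical walk-through of the clamping branch above): with a balance of 30 and a usage charge of -50, fail_insufficient_credits=False clamps the charge to min(-30, 0) = -30, so the transaction leaves a running balance of 0 instead of raising.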
class UsageTransactionMetadata(BaseModel):
graph_exec_id: str | None = None
graph_id: str | None = None
node_id: str | None = None
node_exec_id: str | None = None
block_id: str | None = None
block: str | None = None
input: BlockInput | None = None
class UserCredit(UserCreditBase):
@thread_cached
def notification_client(self) -> NotificationManager:
return get_service_client(NotificationManager)
@@ -376,6 +359,7 @@ class UserCredit(UserCreditBase):
await asyncio.to_thread(
lambda: self.notification_client().queue_notification(
NotificationEventDTO(
recipient_email=settings.config.refund_notification_email,
user_id=notification_request.user_id,
type=notification_type,
data=notification_request.model_dump(),
@@ -383,21 +367,89 @@ class UserCredit(UserCreditBase):
)
)
def _block_usage_cost(
self,
block: Block,
input_data: BlockInput,
data_size: float,
run_time: float,
) -> tuple[int, BlockInput]:
block_costs = BLOCK_COSTS.get(type(block))
if not block_costs:
return 0, {}
for block_cost in block_costs:
if not self._is_cost_filter_match(block_cost.cost_filter, input_data):
continue
if block_cost.cost_type == BlockCostType.RUN:
return block_cost.cost_amount, block_cost.cost_filter
if block_cost.cost_type == BlockCostType.SECOND:
return (
int(run_time * block_cost.cost_amount),
block_cost.cost_filter,
)
if block_cost.cost_type == BlockCostType.BYTE:
return (
int(data_size * block_cost.cost_amount),
block_cost.cost_filter,
)
return 0, {}
def _is_cost_filter_match(
self, cost_filter: BlockInput, input_data: BlockInput
) -> bool:
"""
Filter rules:
- If cost_filter is an object, check whether cost_filter is a subset of input_data.
- Otherwise, check whether cost_filter equals input_data.
- Undefined, null, and the empty string are considered equal.
"""
if not isinstance(cost_filter, dict) or not isinstance(input_data, dict):
return cost_filter == input_data
return all(
(not input_data.get(k) and not v)
or (input_data.get(k) and self._is_cost_filter_match(v, input_data[k]))
for k, v in cost_filter.items()
)
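The subset-matching rule in the docstring can be illustrated with a standalone copy of the same logic (hypothetical model names, not actual block inputs):

# Hypothetical standalone version of the subset-matching rule.
def is_cost_filter_match(cost_filter, input_data) -> bool:
    if not isinstance(cost_filter, dict) or not isinstance(input_data, dict):
        return cost_filter == input_data
    return all(
        (not input_data.get(k) and not v)
        or (input_data.get(k) and is_cost_filter_match(v, input_data[k]))
        for k, v in cost_filter.items()
    )

assert is_cost_filter_match({"model": "gpt-4o"}, {"model": "gpt-4o", "prompt": "hi"})
assert not is_cost_filter_match({"model": "gpt-4o"}, {"model": "claude-3-haiku"})
assert is_cost_filter_match({"api_key": ""}, {"prompt": "hi"})  # empty/undefined treated as equal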
async def spend_credits(
self,
user_id: str,
cost: int,
metadata: UsageTransactionMetadata,
entry: NodeExecutionEntry,
data_size: float,
run_time: float,
) -> int:
block = get_block(entry.block_id)
if not block:
raise ValueError(f"Block not found: {entry.block_id}")
cost, matching_filter = self._block_usage_cost(
block=block, input_data=entry.data, data_size=data_size, run_time=run_time
)
if cost == 0:
return 0
balance, _ = await self._add_transaction(
user_id=user_id,
user_id=entry.user_id,
amount=-cost,
transaction_type=CreditTransactionType.USAGE,
metadata=Json(metadata.model_dump()),
metadata=Json(
UsageTransactionMetadata(
graph_exec_id=entry.graph_exec_id,
graph_id=entry.graph_id,
node_id=entry.node_id,
node_exec_id=entry.node_exec_id,
block_id=entry.block_id,
block=block.name,
input=matching_filter,
).model_dump()
),
)
user_id = entry.user_id
# Auto top-up if balance is below threshold.
auto_top_up = await get_auto_top_up(user_id)
@@ -407,7 +459,7 @@ class UserCredit(UserCreditBase):
user_id=user_id,
amount=auto_top_up.amount,
# Avoid multiple auto top-ups within the same graph execution.
key=f"AUTO-TOP-UP-{user_id}-{metadata.graph_exec_id}",
key=f"AUTO-TOP-UP-{user_id}-{entry.graph_exec_id}",
ceiling_balance=auto_top_up.threshold,
)
except Exception as e:
@@ -416,29 +468,11 @@ class UserCredit(UserCreditBase):
f"Auto top-up failed for user {user_id}, balance: {balance}, amount: {auto_top_up.amount}, error: {e}"
)
return balance
return cost
async def top_up_credits(self, user_id: str, amount: int):
await self._top_up_credits(user_id, amount)
async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
key = f"REWARD-{user_id}-{step.value}"
if not await CreditTransaction.prisma().find_first(
where={
"userId": user_id,
"transactionKey": key,
}
):
await self._add_transaction(
user_id=user_id,
amount=credits,
transaction_type=CreditTransactionType.GRANT,
transaction_key=key,
metadata=Json(
{"reason": f"Reward for completing {step.value} onboarding step."}
),
)
async def top_up_refund(
self, user_id: str, transaction_key: str, metadata: dict[str, str]
) -> int:
@@ -457,15 +491,15 @@ class UserCredit(UserCreditBase):
try:
refund_request = await CreditRefundRequest.prisma().create(
data=CreditRefundRequestCreateInput(
id=refund_key,
transactionKey=transaction_key,
userId=user_id,
amount=amount,
reason=metadata.get("reason", ""),
status=CreditRefundRequestStatus.PENDING,
result="The refund request is under review.",
)
data={
"id": refund_key,
"transactionKey": transaction_key,
"userId": user_id,
"amount": amount,
"reason": metadata.get("reason", ""),
"status": CreditRefundRequestStatus.PENDING,
"result": "The refund request is under review.",
}
)
except UniqueViolationError:
raise ValueError(
@@ -729,8 +763,10 @@ class UserCredit(UserCreditBase):
ui_mode="hosted",
payment_intent_data={"setup_future_usage": "off_session"},
saved_payment_method_options={"payment_method_save": "enabled"},
success_url=base_url + "/profile/credits?topup=success",
cancel_url=base_url + "/profile/credits?topup=cancel",
success_url=settings.config.frontend_base_url
+ "/profile/credits?topup=success",
cancel_url=settings.config.frontend_base_url
+ "/profile/credits?topup=cancel",
allow_promotion_codes=True,
)
@@ -805,6 +841,7 @@ class UserCredit(UserCreditBase):
transaction_time_ceiling: datetime | None = None,
transaction_type: str | None = None,
) -> TransactionHistory:
transactions_filter: CreditTransactionWhereInput = {
"userId": user_id,
"isActive": True,
@@ -926,9 +963,6 @@ class DisabledUserCredit(UserCreditBase):
async def top_up_credits(self, *args, **kwargs):
pass
async def onboarding_reward(self, *args, **kwargs):
pass
async def top_up_intent(self, *args, **kwargs) -> str:
return ""

View File

@@ -2,7 +2,6 @@ import logging
import os
import zlib
from contextlib import asynccontextmanager
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
from uuid import uuid4
from dotenv import load_dotenv
@@ -16,36 +15,7 @@ load_dotenv()
PRISMA_SCHEMA = os.getenv("PRISMA_SCHEMA", "schema.prisma")
os.environ["PRISMA_SCHEMA_PATH"] = PRISMA_SCHEMA
def add_param(url: str, key: str, value: str) -> str:
p = urlparse(url)
qs = dict(parse_qsl(p.query))
qs[key] = value
return urlunparse(p._replace(query=urlencode(qs)))
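For example, add_param appends or overrides a single query parameter while leaving the rest of the URL untouched (an illustrative call, not taken from the source):

# Hypothetical usage of add_param:
url = "postgresql://user:pass@localhost:5432/db?sslmode=require"
print(add_param(url, "connection_limit", "10"))
# -> postgresql://user:pass@localhost:5432/db?sslmode=require&connection_limit=10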
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
CONN_LIMIT = os.getenv("DB_CONNECTION_LIMIT")
if CONN_LIMIT:
DATABASE_URL = add_param(DATABASE_URL, "connection_limit", CONN_LIMIT)
CONN_TIMEOUT = os.getenv("DB_CONNECT_TIMEOUT")
if CONN_TIMEOUT:
DATABASE_URL = add_param(DATABASE_URL, "connect_timeout", CONN_TIMEOUT)
POOL_TIMEOUT = os.getenv("DB_POOL_TIMEOUT")
if POOL_TIMEOUT:
DATABASE_URL = add_param(DATABASE_URL, "pool_timeout", POOL_TIMEOUT)
HTTP_TIMEOUT = int(POOL_TIMEOUT) if POOL_TIMEOUT else None
prisma = Prisma(
auto_register=True,
http={"timeout": HTTP_TIMEOUT},
datasource={"url": DATABASE_URL},
)
prisma = Prisma(auto_register=True)
logger = logging.getLogger(__name__)
@@ -62,10 +32,10 @@ async def connect():
# A connection acquired from a pool like Supabase may still let the db client
# obtain a connection yet have subsequent queries rejected.
# try:
# await prisma.execute_raw("SELECT 1")
# except Exception as e:
# raise ConnectionError("Failed to connect to Prisma.") from e
try:
await prisma.execute_raw("SELECT 1")
except Exception as e:
raise ConnectionError("Failed to connect to Prisma.") from e
@conn_retry("Prisma", "Releasing connection")
@@ -89,7 +59,7 @@ async def transaction():
async def locked_transaction(key: str):
lock_key = zlib.crc32(key.encode("utf-8"))
async with transaction() as tx:
await tx.execute_raw("SELECT pg_advisory_xact_lock($1)", lock_key)
await tx.execute_raw(f"SELECT pg_advisory_xact_lock({lock_key})")
yield tx
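A minimal sketch of how such an advisory-lock transaction is typically used (hypothetical function and key, mirroring the "usr_trx_" key style seen elsewhere in this diff): concurrent callers using the same key serialize on pg_advisory_xact_lock, and the lock is released automatically when the transaction ends.

# Hypothetical usage of locked_transaction:
async def adjust_balance_safely(user_id: str, delta: int):
    async with locked_transaction(f"usr_trx_{user_id}"):
        # Everything in this block runs while holding the per-user advisory lock,
        # so concurrent balance updates for the same user cannot interleave.
        ...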

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -1,9 +1,4 @@
from typing import cast
import prisma.enums
import prisma.types
from backend.blocks.io import IO_BLOCK_IDs
import prisma
AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
"Input": True,
@@ -13,61 +8,27 @@ AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
}
AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
"Nodes": {"include": AGENT_NODE_INCLUDE}
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
}
EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
"Input": True,
"Output": True,
"Node": True,
"GraphExecution": True,
}
MAX_NODE_EXECUTIONS_FETCH = 1000
GRAPH_EXECUTION_INCLUDE_WITH_NODES: prisma.types.AgentGraphExecutionInclude = {
"NodeExecutions": {
"include": {
"Input": True,
"Output": True,
"Node": True,
"GraphExecution": True,
},
"order_by": [
{"queuedTime": "desc"},
# Fallback: Incomplete execs has no queuedTime.
{"addedTime": "desc"},
],
"take": MAX_NODE_EXECUTIONS_FETCH, # Avoid loading excessive node executions.
}
"AgentNode": True,
"AgentGraphExecution": True,
}
GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
"NodeExecutions": {
**cast(
prisma.types.FindManyAgentNodeExecutionArgsFromAgentGraphExecution,
GRAPH_EXECUTION_INCLUDE_WITH_NODES["NodeExecutions"],
),
"where": {
"Node": {"is": {"AgentBlock": {"is": {"id": {"in": IO_BLOCK_IDs}}}}},
"NOT": [{"executionStatus": prisma.enums.AgentExecutionStatus.INCOMPLETE}],
},
"AgentNodeExecutions": {
"include": {
"Input": True,
"Output": True,
"AgentNode": True,
"AgentGraphExecution": True,
}
}
}
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
"AgentNodes": {"include": AGENT_NODE_INCLUDE}
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
}
def library_agent_include(user_id: str) -> prisma.types.LibraryAgentInclude:
return {
"AgentGraph": {
"include": {
**AGENT_GRAPH_INCLUDE,
"Executions": {"where": {"userId": user_id}},
}
},
"Creator": True,
}

View File

@@ -3,14 +3,12 @@ from typing import TYPE_CHECKING, AsyncGenerator, Optional
from prisma import Json
from prisma.models import IntegrationWebhook
from prisma.types import IntegrationWebhookCreateInput
from pydantic import Field, computed_field
from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
from backend.data.queue import AsyncRedisEventBus
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.utils import webhook_ingress_url
from backend.util.exceptions import NotFoundError
from .db import BaseDbModel
@@ -67,35 +65,28 @@ class Webhook(BaseDbModel):
async def create_webhook(webhook: Webhook) -> Webhook:
created_webhook = await IntegrationWebhook.prisma().create(
data=IntegrationWebhookCreateInput(
id=webhook.id,
userId=webhook.user_id,
provider=webhook.provider.value,
credentialsId=webhook.credentials_id,
webhookType=webhook.webhook_type,
resource=webhook.resource,
events=webhook.events,
config=Json(webhook.config),
secret=webhook.secret,
providerWebhookId=webhook.provider_webhook_id,
)
data={
"id": webhook.id,
"userId": webhook.user_id,
"provider": webhook.provider.value,
"credentialsId": webhook.credentials_id,
"webhookType": webhook.webhook_type,
"resource": webhook.resource,
"events": webhook.events,
"config": Json(webhook.config),
"secret": webhook.secret,
"providerWebhookId": webhook.provider_webhook_id,
}
)
return Webhook.from_db(created_webhook)
async def get_webhook(webhook_id: str) -> Webhook:
"""
⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints.
Raises:
NotFoundError: if no record with the given ID exists
"""
webhook = await IntegrationWebhook.prisma().find_unique(
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
webhook = await IntegrationWebhook.prisma().find_unique_or_raise(
where={"id": webhook_id},
include=INTEGRATION_WEBHOOK_INCLUDE,
)
if not webhook:
raise NotFoundError(f"Webhook #{webhook_id} not found")
return Webhook.from_db(webhook)

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import base64
import logging
from collections import defaultdict
from datetime import datetime, timezone
from typing import (
TYPE_CHECKING,
@@ -13,7 +12,6 @@ from typing import (
Generic,
Literal,
Optional,
Sequence,
TypedDict,
TypeVar,
get_args,
@@ -143,20 +141,17 @@ def SchemaField(
secret: bool = False,
exclude: bool = False,
hidden: Optional[bool] = None,
depends_on: Optional[list[str]] = None,
ge: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
discriminator: Optional[str] = None,
json_schema_extra: Optional[dict[str, Any]] = None,
depends_on: list[str] | None = None,
image_upload: Optional[bool] = None,
image_output: Optional[bool] = None,
**kwargs,
) -> T:
if default is PydanticUndefined and default_factory is None:
advanced = False
elif advanced is None:
advanced = True
json_schema_extra = {
json_extra = {
k: v
for k, v in {
"placeholder": placeholder,
@@ -164,7 +159,8 @@ def SchemaField(
"advanced": advanced,
"hidden": hidden,
"depends_on": depends_on,
**(json_schema_extra or {}),
"image_upload": image_upload,
"image_output": image_output,
}.items()
if v is not None
}
@@ -176,12 +172,8 @@ def SchemaField(
title=title,
description=description,
exclude=exclude,
ge=ge,
le=le,
min_length=min_length,
max_length=max_length,
discriminator=discriminator,
json_schema_extra=json_schema_extra,
json_schema_extra=json_extra,
**kwargs,
) # type: ignore
@@ -302,7 +294,9 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
)
field_schema = model.jsonschema()["properties"][field_name]
try:
schema_extra = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
schema_extra = _CredentialsFieldSchemaExtra[CP, CT].model_validate(
field_schema
)
except ValidationError as e:
if "Field required [type=missing" not in str(e):
raise
@@ -328,90 +322,14 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
)
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
class _CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]):
# TODO: move discrimination mechanism out of CredentialsField (frontend + backend)
provider: frozenset[CP] = Field(..., alias="credentials_provider")
supported_types: frozenset[CT] = Field(..., alias="credentials_types")
required_scopes: Optional[frozenset[str]] = Field(None, alias="credentials_scopes")
credentials_provider: list[CP]
credentials_scopes: Optional[list[str]] = None
credentials_types: list[CT]
discriminator: Optional[str] = None
discriminator_mapping: Optional[dict[str, CP]] = None
@classmethod
def combine(
cls, *fields: tuple[CredentialsFieldInfo[CP, CT], T]
) -> Sequence[tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
"""
Combines multiple CredentialsFieldInfo objects into as few as possible.
Rules:
- Items can only be combined if they have the same supported credentials types
and the same supported providers.
- When combining items, the `required_scopes` of the result is the union
of the `required_scopes` of the original items.
Params:
*fields: (CredentialsFieldInfo, key) objects to group and combine
Returns:
A sequence of tuples containing combined CredentialsFieldInfo objects and
the set of keys of the respective original items that were grouped together.
"""
if not fields:
return []
# Group fields by their provider and supported_types
grouped_fields: defaultdict[
tuple[frozenset[CP], frozenset[CT]],
list[tuple[T, CredentialsFieldInfo[CP, CT]]],
] = defaultdict(list)
for field, key in fields:
group_key = (frozenset(field.provider), frozenset(field.supported_types))
grouped_fields[group_key].append((key, field))
# Combine fields within each group
result: list[tuple[CredentialsFieldInfo[CP, CT], set[T]]] = []
for group in grouped_fields.values():
# Start with the first field in the group
_, combined = group[0]
# Track the keys that were combined
combined_keys = {key for key, _ in group}
# Combine required_scopes from all fields in the group
all_scopes = set()
for _, field in group:
if field.required_scopes:
all_scopes.update(field.required_scopes)
# Create a new combined field
result.append(
(
CredentialsFieldInfo[CP, CT](
credentials_provider=combined.provider,
credentials_types=combined.supported_types,
credentials_scopes=frozenset(all_scopes) or None,
discriminator=combined.discriminator,
discriminator_mapping=combined.discriminator_mapping,
),
combined_keys,
)
)
return result
def discriminate(self, discriminator_value: Any) -> CredentialsFieldInfo:
if not (self.discriminator and self.discriminator_mapping):
return self
discriminator_value = self.discriminator_mapping[discriminator_value]
return CredentialsFieldInfo(
credentials_provider=frozenset([discriminator_value]),
credentials_types=self.supported_types,
credentials_scopes=self.required_scopes,
)
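The grouping rule described in combine's docstring can be sketched in isolation (plain tuples instead of the real CredentialsFieldInfo model, with invented providers and scopes): fields are bucketed by (providers, supported types) and the scopes inside each bucket are unioned.

from collections import defaultdict

# Hypothetical stand-in: each field is (providers, types, scopes), paired with a key.
fields = [
    ((frozenset({"github"}), frozenset({"oauth2"}), {"repo"}), "field_a"),
    ((frozenset({"github"}), frozenset({"oauth2"}), {"read:org"}), "field_b"),
    ((frozenset({"google"}), frozenset({"oauth2"}), set()), "field_c"),
]

groups = defaultdict(list)
for (providers, types, scopes), key in fields:
    groups[(providers, types)].append((key, scopes))

combined = []
for (providers, types), members in groups.items():
    all_scopes = set().union(*(scopes for _, scopes in members))
    combined.append(((providers, types, all_scopes), {key for key, _ in members}))

# combined[0] -> (({"github"}, {"oauth2"}, {"repo", "read:org"}), {"field_a", "field_b"})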
def CredentialsField(
required_scopes: set[str] = set(),
@@ -484,46 +402,3 @@ class RefundRequest(BaseModel):
status: str
created_at: datetime
updated_at: datetime
class NodeExecutionStats(BaseModel):
"""Execution statistics for a node execution."""
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
)
error: Optional[Exception | str] = None
walltime: float = 0
cputime: float = 0
input_size: int = 0
output_size: int = 0
llm_call_count: int = 0
llm_retry_count: int = 0
input_token_count: int = 0
output_token_count: int = 0
class GraphExecutionStats(BaseModel):
"""Execution statistics for a graph execution."""
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
)
error: Optional[Exception | str] = None
walltime: float = Field(
default=0, description="Time between start and end of run (seconds)"
)
cputime: float = 0
nodes_walltime: float = Field(
default=0, description="Total node execution time (seconds)"
)
nodes_cputime: float = 0
node_count: int = Field(default=0, description="Total number of node executions")
node_error_count: int = Field(
default=0, description="Total number of errors generated"
)
cost: int = Field(default=0, description="Total execution cost (cents)")

View File

@@ -1,19 +1,15 @@
import logging
from datetime import datetime, timedelta, timezone
from datetime import datetime, timedelta
from enum import Enum
from typing import Annotated, Any, Generic, Optional, TypeVar, Union
from prisma import Json
from prisma.enums import NotificationType
from prisma.models import NotificationEvent, UserNotificationBatch
from prisma.types import (
NotificationEventCreateInput,
UserNotificationBatchCreateInput,
UserNotificationBatchWhereInput,
)
from prisma.types import UserNotificationBatchWhereInput
# from backend.notifications.models import NotificationEvent
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
from pydantic import BaseModel, EmailStr, Field, field_validator
from backend.server.v2.store.exceptions import DatabaseError
@@ -22,24 +18,18 @@ from .db import transaction
logger = logging.getLogger(__name__)
NotificationDataType_co = TypeVar(
"NotificationDataType_co", bound="BaseNotificationData", covariant=True
)
SummaryParamsType_co = TypeVar(
"SummaryParamsType_co", bound="BaseSummaryParams", covariant=True
)
T_co = TypeVar("T_co", bound="BaseNotificationData", covariant=True)
class QueueType(Enum):
class BatchingStrategy(Enum):
IMMEDIATE = "immediate" # Send right away (errors, critical notifications)
BATCH = "batch" # Batch for up to an hour (usage reports)
SUMMARY = "summary" # Daily digest (summary notifications)
HOURLY = "hourly" # Batch for up to an hour (usage reports)
DAILY = "daily" # Daily digest (summary notifications)
BACKOFF = "backoff" # Backoff strategy (exponential backoff)
ADMIN = "admin" # Admin notifications (errors, critical notifications)
class BaseNotificationData(BaseModel):
model_config = ConfigDict(extra="allow")
pass
class AgentRunData(BaseNotificationData):
@@ -48,7 +38,7 @@ class AgentRunData(BaseNotificationData):
execution_time: float
node_count: int = Field(..., description="Number of nodes executed")
graph_id: str
outputs: list[dict[str, Any]] = Field(..., description="Outputs of the agent")
outputs: dict[str, Any] = Field(..., description="Outputs of the agent")
class ZeroBalanceData(BaseNotificationData):
@@ -56,21 +46,12 @@ class ZeroBalanceData(BaseNotificationData):
last_transaction_time: datetime
top_up_link: str
@field_validator("last_transaction_time")
@classmethod
def validate_timezone(cls, value: datetime):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
class LowBalanceData(BaseNotificationData):
agent_name: str = Field(..., description="Name of the agent")
current_balance: float = Field(
..., description="Current balance in credits (100 = $1)"
)
billing_page_link: str = Field(..., description="Link to billing page")
shortfall: float = Field(..., description="Amount of credits needed to continue")
current_balance: float
threshold_amount: float
top_up_link: str
recent_usage: float = Field(..., description="Usage in the last 24 hours")
class BlockExecutionFailedData(BaseNotificationData):
@@ -91,13 +72,6 @@ class ContinuousAgentErrorData(BaseNotificationData):
error_time: datetime
attempts: int = Field(..., description="Number of retry attempts made")
@field_validator("start_time", "error_time")
@classmethod
def validate_timezone(cls, value: datetime):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
class BaseSummaryData(BaseNotificationData):
total_credits_used: float
@@ -110,53 +84,18 @@ class BaseSummaryData(BaseNotificationData):
cost_breakdown: dict[str, float]
class BaseSummaryParams(BaseModel):
pass
class DailySummaryParams(BaseSummaryParams):
date: datetime
@field_validator("date")
def validate_timezone(cls, value):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
class WeeklySummaryParams(BaseSummaryParams):
start_date: datetime
end_date: datetime
@field_validator("start_date", "end_date")
def validate_timezone(cls, value):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
class DailySummaryData(BaseSummaryData):
date: datetime
@field_validator("date")
def validate_timezone(cls, value):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
class WeeklySummaryData(BaseSummaryData):
start_date: datetime
end_date: datetime
@field_validator("start_date", "end_date")
def validate_timezone(cls, value):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
week_number: int
year: int
class MonthlySummaryData(BaseNotificationData):
class MonthlySummaryData(BaseSummaryData):
month: int
year: int
@@ -180,10 +119,6 @@ NotificationData = Annotated[
BlockExecutionFailedData,
ContinuousAgentErrorData,
MonthlySummaryData,
WeeklySummaryData,
DailySummaryData,
RefundRequestData,
BaseSummaryData,
],
Field(discriminator="type"),
]
@@ -193,25 +128,19 @@ class NotificationEventDTO(BaseModel):
user_id: str
type: NotificationType
data: dict
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
created_at: datetime = Field(default_factory=datetime.now)
recipient_email: Optional[str] = None
retry_count: int = 0
class SummaryParamsEventDTO(BaseModel):
class NotificationEventModel(BaseModel, Generic[T_co]):
user_id: str
type: NotificationType
data: dict
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
class NotificationEventModel(BaseModel, Generic[NotificationDataType_co]):
user_id: str
type: NotificationType
data: NotificationDataType_co
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
data: T_co
created_at: datetime = Field(default_factory=datetime.now)
@property
def strategy(self) -> QueueType:
def strategy(self) -> BatchingStrategy:
return NotificationTypeOverride(self.type).strategy
@field_validator("type", mode="before")
@@ -225,14 +154,7 @@ class NotificationEventModel(BaseModel, Generic[NotificationDataType_co]):
return NotificationTypeOverride(self.type).template
class SummaryParamsEventModel(BaseModel, Generic[SummaryParamsType_co]):
user_id: str
type: NotificationType
data: SummaryParamsType_co
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
def get_notif_data_type(
def get_data_type(
notification_type: NotificationType,
) -> type[BaseNotificationData]:
return {
@@ -249,20 +171,11 @@ def get_notif_data_type(
}[notification_type]
def get_summary_params_type(
notification_type: NotificationType,
) -> type[BaseSummaryParams]:
return {
NotificationType.DAILY_SUMMARY: DailySummaryParams,
NotificationType.WEEKLY_SUMMARY: WeeklySummaryParams,
}[notification_type]
class NotificationBatch(BaseModel):
user_id: str
events: list[NotificationEvent]
strategy: QueueType
last_update: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
strategy: BatchingStrategy
last_update: datetime = datetime.now()
class NotificationResult(BaseModel):
@@ -275,22 +188,23 @@ class NotificationTypeOverride:
self.notification_type = notification_type
@property
def strategy(self) -> QueueType:
def strategy(self) -> BatchingStrategy:
BATCHING_RULES = {
# These are batched by the notification service
NotificationType.AGENT_RUN: QueueType.BATCH,
NotificationType.AGENT_RUN: BatchingStrategy.IMMEDIATE,
# These are batched by the notification service, but with a backoff strategy
NotificationType.ZERO_BALANCE: QueueType.BACKOFF,
NotificationType.LOW_BALANCE: QueueType.IMMEDIATE,
NotificationType.BLOCK_EXECUTION_FAILED: QueueType.BACKOFF,
NotificationType.CONTINUOUS_AGENT_ERROR: QueueType.BACKOFF,
NotificationType.DAILY_SUMMARY: QueueType.SUMMARY,
NotificationType.WEEKLY_SUMMARY: QueueType.SUMMARY,
NotificationType.MONTHLY_SUMMARY: QueueType.SUMMARY,
NotificationType.REFUND_REQUEST: QueueType.ADMIN,
NotificationType.REFUND_PROCESSED: QueueType.ADMIN,
NotificationType.ZERO_BALANCE: BatchingStrategy.BACKOFF,
NotificationType.LOW_BALANCE: BatchingStrategy.BACKOFF,
NotificationType.BLOCK_EXECUTION_FAILED: BatchingStrategy.BACKOFF,
NotificationType.CONTINUOUS_AGENT_ERROR: BatchingStrategy.BACKOFF,
# These aren't batched by the notification service, so we send them right away
NotificationType.DAILY_SUMMARY: BatchingStrategy.IMMEDIATE,
NotificationType.WEEKLY_SUMMARY: BatchingStrategy.IMMEDIATE,
NotificationType.MONTHLY_SUMMARY: BatchingStrategy.IMMEDIATE,
NotificationType.REFUND_REQUEST: BatchingStrategy.IMMEDIATE,
NotificationType.REFUND_PROCESSED: BatchingStrategy.IMMEDIATE,
}
return BATCHING_RULES.get(self.notification_type, QueueType.IMMEDIATE)
return BATCHING_RULES.get(self.notification_type, BatchingStrategy.HOURLY)
@property
def template(self) -> str:
@@ -340,51 +254,12 @@ class NotificationPreference(BaseModel):
)
daily_limit: int = 10 # Max emails per day
emails_sent_today: int = 0
last_reset_date: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc)
)
class UserNotificationEventDTO(BaseModel):
type: NotificationType
data: dict
created_at: datetime
updated_at: datetime
@staticmethod
def from_db(model: NotificationEvent) -> "UserNotificationEventDTO":
return UserNotificationEventDTO(
type=model.type,
data=dict(model.data),
created_at=model.createdAt,
updated_at=model.updatedAt,
)
class UserNotificationBatchDTO(BaseModel):
user_id: str
type: NotificationType
notifications: list[UserNotificationEventDTO]
created_at: datetime
updated_at: datetime
@staticmethod
def from_db(model: UserNotificationBatch) -> "UserNotificationBatchDTO":
return UserNotificationBatchDTO(
user_id=model.userId,
type=model.type,
notifications=[
UserNotificationEventDTO.from_db(notification)
for notification in model.Notifications or []
],
created_at=model.createdAt,
updated_at=model.updatedAt,
)
last_reset_date: datetime = Field(default_factory=datetime.now)
def get_batch_delay(notification_type: NotificationType) -> timedelta:
return {
NotificationType.AGENT_RUN: timedelta(minutes=60),
NotificationType.AGENT_RUN: timedelta(seconds=1),
NotificationType.ZERO_BALANCE: timedelta(minutes=60),
NotificationType.LOW_BALANCE: timedelta(minutes=60),
NotificationType.BLOCK_EXECUTION_FAILED: timedelta(minutes=60),
@@ -395,17 +270,19 @@ def get_batch_delay(notification_type: NotificationType) -> timedelta:
async def create_or_add_to_user_notification_batch(
user_id: str,
notification_type: NotificationType,
notification_data: NotificationEventModel,
) -> UserNotificationBatchDTO:
data: str, # type: 'NotificationEventModel'
) -> dict:
try:
logger.info(
f"Creating or adding to notification batch for {user_id} with type {notification_type} and data {notification_data}"
f"Creating or adding to notification batch for {user_id} with type {notification_type} and data {data}"
)
if not notification_data.data:
raise ValueError("Notification data must be provided")
notification_data = NotificationEventModel[
get_data_type(notification_type)
].model_validate_json(data)
# Serialize the data
json_data: Json = Json(notification_data.data.model_dump())
json_data: Json = Json(notification_data.data.model_dump_json())
# First try to find existing batch
existing_batch = await UserNotificationBatch.prisma().find_unique(
@@ -415,76 +292,70 @@ async def create_or_add_to_user_notification_batch(
"type": notification_type,
}
},
include={"Notifications": True},
include={"notifications": True},
)
if not existing_batch:
async with transaction() as tx:
notification_event = await tx.notificationevent.create(
data=NotificationEventCreateInput(
type=notification_type,
data=json_data,
)
data={
"type": notification_type,
"data": json_data,
}
)
# Create new batch
resp = await tx.usernotificationbatch.create(
data=UserNotificationBatchCreateInput(
userId=user_id,
type=notification_type,
Notifications={"connect": [{"id": notification_event.id}]},
),
include={"Notifications": True},
data={
"userId": user_id,
"type": notification_type,
"notifications": {"connect": [{"id": notification_event.id}]},
},
include={"notifications": True},
)
return UserNotificationBatchDTO.from_db(resp)
return resp.model_dump()
else:
async with transaction() as tx:
notification_event = await tx.notificationevent.create(
data=NotificationEventCreateInput(
type=notification_type,
data=json_data,
UserNotificationBatch={"connect": {"id": existing_batch.id}},
)
data={
"type": notification_type,
"data": json_data,
"UserNotificationBatch": {"connect": {"id": existing_batch.id}},
}
)
# Add to existing batch
resp = await tx.usernotificationbatch.update(
where={"id": existing_batch.id},
data={
"Notifications": {"connect": [{"id": notification_event.id}]}
"notifications": {"connect": [{"id": notification_event.id}]}
},
include={"Notifications": True},
include={"notifications": True},
)
if not resp:
raise DatabaseError(
f"Failed to add notification event {notification_event.id} to existing batch {existing_batch.id}"
)
return UserNotificationBatchDTO.from_db(resp)
return resp.model_dump()
except Exception as e:
raise DatabaseError(
f"Failed to create or add to notification batch for user {user_id} and type {notification_type}: {e}"
) from e
async def get_user_notification_oldest_message_in_batch(
async def get_user_notification_last_message_in_batch(
user_id: str,
notification_type: NotificationType,
) -> UserNotificationEventDTO | None:
) -> NotificationEvent | None:
try:
batch = await UserNotificationBatch.prisma().find_first(
where={"userId": user_id, "type": notification_type},
include={"Notifications": True},
order={"createdAt": "desc"},
)
if not batch:
return None
if not batch.Notifications:
if not batch.notifications:
return None
sorted_notifications = sorted(batch.Notifications, key=lambda x: x.createdAt)
return (
UserNotificationEventDTO.from_db(sorted_notifications[0])
if sorted_notifications
else None
)
return batch.notifications[-1]
except Exception as e:
raise DatabaseError(
f"Failed to get user notification last message in batch for user {user_id} and type {notification_type}: {e}"
@@ -519,34 +390,13 @@ async def empty_user_notification_batch(
async def get_user_notification_batch(
user_id: str,
notification_type: NotificationType,
) -> UserNotificationBatchDTO | None:
) -> UserNotificationBatch | None:
try:
batch = await UserNotificationBatch.prisma().find_first(
return await UserNotificationBatch.prisma().find_first(
where={"userId": user_id, "type": notification_type},
include={"Notifications": True},
include={"notifications": True},
)
return UserNotificationBatchDTO.from_db(batch) if batch else None
except Exception as e:
raise DatabaseError(
f"Failed to get user notification batch for user {user_id} and type {notification_type}: {e}"
) from e
async def get_all_batches_by_type(
notification_type: NotificationType,
) -> list[UserNotificationBatchDTO]:
try:
batches = await UserNotificationBatch.prisma().find_many(
where={
"type": notification_type,
"Notifications": {
"some": {} # Only return batches with at least one notification
},
},
include={"Notifications": True},
)
return [UserNotificationBatchDTO.from_db(batch) for batch in batches]
except Exception as e:
raise DatabaseError(
f"Failed to get all batches by type {notification_type}: {e}"
) from e

View File

@@ -1,338 +0,0 @@
import re
from typing import Any, Optional
import prisma
import pydantic
from prisma import Json
from prisma.enums import OnboardingStep
from prisma.models import UserOnboarding
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
from backend.data import db
from backend.data.block import get_blocks
from backend.data.credit import get_user_credit_model
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.server.v2.store.model import StoreAgentDetails
# Mapping from user reason id to categories to search for when choosing agent to show
REASON_MAPPING: dict[str, list[str]] = {
"content_marketing": ["writing", "marketing", "creative"],
"business_workflow_automation": ["business", "productivity"],
"data_research": ["data", "research"],
"ai_innovation": ["development", "research"],
"personal_productivity": ["personal", "productivity"],
}
POINTS_AGENT_COUNT = 50 # Number of agents to calculate points for
MIN_AGENT_COUNT = 2 # Minimum number of marketplace agents to enable onboarding
user_credit = get_user_credit_model()
class UserOnboardingUpdate(pydantic.BaseModel):
completedSteps: Optional[list[OnboardingStep]] = None
notificationDot: Optional[bool] = None
notified: Optional[list[OnboardingStep]] = None
usageReason: Optional[str] = None
integrations: Optional[list[str]] = None
otherIntegrations: Optional[str] = None
selectedStoreListingVersionId: Optional[str] = None
agentInput: Optional[dict[str, Any]] = None
onboardingAgentExecutionId: Optional[str] = None
async def get_user_onboarding(user_id: str):
return await UserOnboarding.prisma().upsert(
where={"userId": user_id},
data={
"create": UserOnboardingCreateInput(userId=user_id),
"update": {},
},
)
async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
update: UserOnboardingUpdateInput = {}
if data.completedSteps is not None:
update["completedSteps"] = list(set(data.completedSteps))
for step in (
OnboardingStep.AGENT_NEW_RUN,
OnboardingStep.GET_RESULTS,
OnboardingStep.MARKETPLACE_ADD_AGENT,
OnboardingStep.MARKETPLACE_RUN_AGENT,
OnboardingStep.BUILDER_SAVE_AGENT,
OnboardingStep.BUILDER_RUN_AGENT,
):
if step in data.completedSteps:
await reward_user(user_id, step)
if data.notificationDot is not None:
update["notificationDot"] = data.notificationDot
if data.notified is not None:
update["notified"] = list(set(data.notified))
if data.usageReason is not None:
update["usageReason"] = data.usageReason
if data.integrations is not None:
update["integrations"] = data.integrations
if data.otherIntegrations is not None:
update["otherIntegrations"] = data.otherIntegrations
if data.selectedStoreListingVersionId is not None:
update["selectedStoreListingVersionId"] = data.selectedStoreListingVersionId
if data.agentInput is not None:
update["agentInput"] = Json(data.agentInput)
if data.onboardingAgentExecutionId is not None:
update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId
return await UserOnboarding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, **update},
"update": update,
},
)
async def reward_user(user_id: str, step: OnboardingStep):
async with db.locked_transaction(f"usr_trx_{user_id}-reward"):
reward = 0
match step:
# Reward user when they clicked New Run during onboarding
# This is because they need credits before scheduling a run (next step)
case OnboardingStep.AGENT_NEW_RUN:
reward = 300
case OnboardingStep.GET_RESULTS:
reward = 300
case OnboardingStep.MARKETPLACE_ADD_AGENT:
reward = 100
case OnboardingStep.MARKETPLACE_RUN_AGENT:
reward = 100
case OnboardingStep.BUILDER_SAVE_AGENT:
reward = 100
case OnboardingStep.BUILDER_RUN_AGENT:
reward = 100
if reward == 0:
return
onboarding = await get_user_onboarding(user_id)
# Skip if already rewarded
if step in onboarding.rewardedFor:
return
onboarding.rewardedFor.append(step)
await user_credit.onboarding_reward(user_id, reward, step)
await UserOnboarding.prisma().update(
where={"userId": user_id},
data={
"completedSteps": list(set(onboarding.completedSteps + [step])),
"rewardedFor": onboarding.rewardedFor,
},
)
def clean_and_split(text: str) -> list[str]:
"""
Truncates a string to 100 characters, removes all special characters,
and splits it by whitespace and commas.
Args:
text (str): The input string.
Returns:
list[str]: A list of cleaned words.
"""
# Remove all special characters (keep only alphanumeric and whitespace)
cleaned_text = re.sub(r"[^a-zA-Z0-9\s,]", "", text.strip()[:100])
# Split by whitespace and commas
words = re.split(r"[\s,]+", cleaned_text)
# Remove empty strings from the list
words = [word.lower() for word in words if word]
return words
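For instance (invented input, result shown for illustration):

# Hypothetical call:
clean_and_split("Automate my e-mail follow-ups, CRM updates & reporting!")
# -> ['automate', 'my', 'email', 'followups', 'crm', 'updates', 'reporting']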
def calculate_points(
agent, categories: list[str], custom: list[str], integrations: list[str]
) -> int:
"""
Calculates the total points for an agent based on the specified criteria.
Args:
agent: The agent object.
categories (list[str]): List of categories to match.
custom (list[str]): Custom usage-reason words to match in the description.
integrations (list[str]): Integration names to match in the description.
Returns:
int: Total points for the agent.
"""
points = 0
# 1. Category Matches
matched_categories = sum(
1 for category in categories if category in agent.categories
)
points += matched_categories * 100
# 2. Description Word Matches
description_words = agent.description.split() # Split description into words
matched_words = sum(1 for word in custom if word in description_words)
points += matched_words * 100
matched_words = sum(1 for word in integrations if word in description_words)
points += matched_words * 50
# 3. Featured Bonus
if agent.featured:
points += 50
# 4. Rating Bonus
points += agent.rating * 10
# 5. Runs Bonus
runs_points = min(agent.runs / 1000 * 100, 100) # Cap at 100 points
points += runs_points
return int(points)
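A quick worked example of the scoring (all numbers invented): an agent matching 2 of the requested categories, 1 custom description word, and 0 integration words, featured, rated 4.5, with 600 runs scores
2 * 100 (categories) + 1 * 100 (custom words) + 0 * 50 (integrations) + 50 (featured) + 4.5 * 10 (rating) + min(600 / 1000 * 100, 100) (runs) = 200 + 100 + 0 + 50 + 45 + 60 = 455 points.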
def get_credentials_blocks() -> dict[str, str]:
# Returns a dictionary of block id to credentials field name
creds: dict[str, str] = {}
blocks = get_blocks()
for id, block in blocks.items():
for field_name, field_info in block().input_schema.model_fields.items():
if field_info.annotation == CredentialsMetaInput:
creds[id] = field_name
return creds
CREDENTIALS_FIELDS: dict[str, str] = get_credentials_blocks()
async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
user_onboarding = await get_user_onboarding(user_id)
categories = REASON_MAPPING.get(user_onboarding.usageReason or "", [])
where_clause: dict[str, Any] = {}
custom = clean_and_split((user_onboarding.usageReason or "").lower())
if categories:
where_clause["OR"] = [
{"categories": {"has": category}} for category in categories
]
else:
where_clause["OR"] = [
{"description": {"contains": word, "mode": "insensitive"}}
for word in custom
]
where_clause["OR"] += [
{"description": {"contains": word, "mode": "insensitive"}}
for word in user_onboarding.integrations
]
storeAgents = await prisma.models.StoreAgent.prisma().find_many(
where=prisma.types.StoreAgentWhereInput(**where_clause),
order=[
{"featured": "desc"},
{"runs": "desc"},
{"rating": "desc"},
],
take=100,
)
agentListings = await prisma.models.StoreListingVersion.prisma().find_many(
where={
"id": {"in": [agent.storeListingVersionId for agent in storeAgents]},
},
include={"AgentGraph": True},
)
for listing in agentListings:
agent = listing.AgentGraph
if agent is None:
continue
graph = GraphModel.from_db(agent)
# Remove agents with empty input schema
if not graph.input_schema:
storeAgents = [
a for a in storeAgents if a.storeListingVersionId != listing.id
]
continue
# Remove agents with empty credentials
# Get nodes from this agent that have credentials
nodes = await prisma.models.AgentNode.prisma().find_many(
where={
"agentGraphId": agent.id,
"agentBlockId": {"in": list(CREDENTIALS_FIELDS.keys())},
},
)
for node in nodes:
block_id = node.agentBlockId
field_name = CREDENTIALS_FIELDS[block_id]
# If there are no credentials or they are empty, remove the agent
# FIXME ignores default values
if (
field_name not in node.constantInput
or node.constantInput[field_name] is None
):
storeAgents = [
a for a in storeAgents if a.storeListingVersionId != listing.id
]
break
# If there are fewer than 2 agents, add more agents to the list
if len(storeAgents) < 2:
storeAgents += await prisma.models.StoreAgent.prisma().find_many(
where={
"listing_id": {"not_in": [agent.listing_id for agent in storeAgents]},
},
order=[
{"featured": "desc"},
{"runs": "desc"},
{"rating": "desc"},
],
take=2 - len(storeAgents),
)
# Calculate points for the first POINTS_AGENT_COUNT agents and choose the top 2
agent_points = []
for agent in storeAgents[:POINTS_AGENT_COUNT]:
points = calculate_points(
agent, categories, custom, user_onboarding.integrations
)
agent_points.append((agent, points))
agent_points.sort(key=lambda x: x[1], reverse=True)
recommended_agents = [agent for agent, _ in agent_points[:2]]
return [
StoreAgentDetails(
store_listing_version_id=agent.storeListingVersionId,
slug=agent.slug,
agent_name=agent.agent_name,
agent_video=agent.agent_video or "",
agent_image=agent.agent_image,
creator=agent.creator_username,
creator_avatar=agent.creator_avatar,
sub_heading=agent.sub_heading,
description=agent.description,
categories=agent.categories,
runs=agent.runs,
rating=agent.rating,
versions=agent.versions,
last_updated=agent.updated_at,
)
for agent in recommended_agents
]
async def onboarding_enabled() -> bool:
count = await prisma.models.StoreAgent.prisma().count(take=MIN_AGENT_COUNT + 1)
# Onboarding is enabled if there are at least 2 agents in the store
return count >= MIN_AGENT_COUNT

View File

@@ -1,6 +1,8 @@
import asyncio
import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, AsyncGenerator, Generator, Generic, Optional, TypeVar
from pydantic import BaseModel
@@ -12,6 +14,13 @@ from backend.data import redis
logger = logging.getLogger(__name__)
class DateTimeEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, datetime):
return o.isoformat()
return super().default(o)
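For example, the encoder lets datetime values pass through json.dumps as ISO-8601 strings (illustrative values only):

# Hypothetical usage:
json.dumps({"queued_at": datetime(2025, 2, 22, 17, 17)}, cls=DateTimeEncoder)
# -> '{"queued_at": "2025-02-22T17:17:00"}'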
M = TypeVar("M", bound=BaseModel)
@@ -23,14 +32,10 @@ class BaseRedisEventBus(Generic[M], ABC):
def event_bus_name(self) -> str:
pass
@property
def Message(self) -> type["_EventPayloadWrapper[M]"]:
return _EventPayloadWrapper[self.Model]
def _serialize_message(self, item: M, channel_key: str) -> tuple[str, str]:
message = self.Message(payload=item).model_dump_json()
message = json.dumps(item.model_dump(), cls=DateTimeEncoder)
channel_name = f"{self.event_bus_name}/{channel_key}"
logger.debug(f"[{channel_name}] Publishing an event to Redis {message}")
logger.info(f"[{channel_name}] Publishing an event to Redis {message}")
return message, channel_name
def _deserialize_message(self, msg: Any, channel_key: str) -> M | None:
@@ -38,8 +43,9 @@ class BaseRedisEventBus(Generic[M], ABC):
if msg["type"] != message_type:
return None
try:
logger.debug(f"[{channel_key}] Consuming an event from Redis {msg['data']}")
return self.Message.model_validate_json(msg["data"]).payload
data = json.loads(msg["data"])
logger.info(f"Consuming an event from Redis {data}")
return self.Model(**data)
except Exception as e:
logger.error(f"Failed to parse event result from Redis {msg} {e}")
@@ -51,16 +57,9 @@ class BaseRedisEventBus(Generic[M], ABC):
return pubsub, full_channel_name
class _EventPayloadWrapper(BaseModel, Generic[M]):
"""
Wrapper model to allow `RedisEventBus.Model` to be a discriminated union
of multiple event types.
"""
payload: M
class RedisEventBus(BaseRedisEventBus[M], ABC):
Model: type[M]
@property
def connection(self) -> redis.Redis:
return redis.get_redis()
@@ -86,6 +85,8 @@ class RedisEventBus(BaseRedisEventBus[M], ABC):
class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
Model: type[M]
@property
async def connection(self) -> redis.AsyncRedis:
return await redis.get_redis_async()

View File

@@ -4,18 +4,10 @@ from enum import Enum
from typing import Awaitable, Optional
import aio_pika
import aio_pika.exceptions as aio_ex
import pika
import pika.adapters.blocking_connection
from pika.exceptions import AMQPError
from pika.spec import BasicProperties
from pydantic import BaseModel
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from backend.util.retry import conn_retry
from backend.util.settings import Settings
@@ -169,12 +161,6 @@ class SyncRabbitMQ(RabbitMQBase):
routing_key=queue.routing_key or queue.name,
)
@retry(
retry=retry_if_exception_type((AMQPError, ConnectionError)),
wait=wait_random_exponential(multiplier=1, max=5),
stop=stop_after_attempt(5),
reraise=True,
)
def publish_message(
self,
routing_key: str,
@@ -272,12 +258,6 @@ class AsyncRabbitMQ(RabbitMQBase):
exchange, routing_key=queue.routing_key or queue.name
)
@retry(
retry=retry_if_exception_type((aio_ex.AMQPError, ConnectionError)),
wait=wait_random_exponential(multiplier=1, max=5),
stop=stop_after_attempt(5),
reraise=True,
)
async def publish_message(
self,
routing_key: str,

View File

@@ -1,24 +1,19 @@
import base64
import hashlib
import hmac
import logging
from datetime import datetime, timedelta
from typing import Optional, cast
from urllib.parse import quote_plus
from autogpt_libs.auth.models import DEFAULT_USER_ID
from fastapi import HTTPException
from prisma import Json
from prisma.enums import NotificationType
from prisma.models import User
from prisma.types import JsonFilter, UserCreateInput, UserUpdateInput
from prisma.types import UserUpdateInput
from backend.data.db import prisma
from backend.data.model import UserIntegrations, UserMetadata, UserMetadataRaw
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.server.v2.store.exceptions import DatabaseError
from backend.util.encryption import JSONCryptor
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -36,11 +31,11 @@ async def get_or_create_user(user_data: dict) -> User:
user = await prisma.user.find_unique(where={"id": user_id})
if not user:
user = await prisma.user.create(
data=UserCreateInput(
id=user_id,
email=user_email,
name=user_data.get("user_metadata", {}).get("name"),
)
data={
"id": user_id,
"email": user_email,
"name": user_data.get("user_metadata", {}).get("name"),
}
)
return User.model_validate(user)
@@ -63,14 +58,6 @@ async def get_user_email_by_id(user_id: str) -> Optional[str]:
raise DatabaseError(f"Failed to get user email for user {user_id}: {e}") from e
async def get_user_by_email(email: str) -> Optional[User]:
try:
user = await prisma.user.find_unique(where={"email": email})
return User.model_validate(user) if user else None
except Exception as e:
raise DatabaseError(f"Failed to get user by email {email}: {e}") from e
async def update_user_email(user_id: str, email: str):
try:
await prisma.user.update(where={"id": user_id}, data={"email": email})
@@ -84,11 +71,11 @@ async def create_default_user() -> Optional[User]:
user = await prisma.user.find_unique(where={"id": DEFAULT_USER_ID})
if not user:
user = await prisma.user.create(
data=UserCreateInput(
id=DEFAULT_USER_ID,
email="default@example.com",
name="Default User",
)
data={
"id": DEFAULT_USER_ID,
"email": "default@example.com",
"name": "Default User",
}
)
return User.model_validate(user)
@@ -135,21 +122,16 @@ async def migrate_and_encrypt_user_integrations():
"""Migrate integration credentials and OAuth states from metadata to integrations column."""
users = await User.prisma().find_many(
where={
"metadata": cast(
JsonFilter,
{
"path": ["integration_credentials"],
"not": Json(
{"a": "yolo"}
), # bogus value works to check if key exists
},
)
"metadata": {
"path": ["integration_credentials"],
"not": Json({"a": "yolo"}), # bogus value works to check if key exists
} # type: ignore
}
)
logger.info(f"Migrating integration credentials for {len(users)} users")
for user in users:
raw_metadata = cast(dict, user.metadata)
raw_metadata = cast(UserMetadataRaw, user.metadata)
metadata = UserMetadata.model_validate(raw_metadata)
# Get existing integrations data
@@ -165,6 +147,7 @@ async def migrate_and_encrypt_user_integrations():
await update_user_integrations(user_id=user.id, data=integrations)
# Remove from metadata
raw_metadata = dict(raw_metadata)
raw_metadata.pop("integration_credentials", None)
raw_metadata.pop("integration_oauth_states", None)
@@ -317,85 +300,3 @@ async def update_user_notification_preference(
raise DatabaseError(
f"Failed to update user notification preference for user {user_id}: {e}"
) from e
async def set_user_email_verification(user_id: str, verified: bool) -> None:
"""Set the email verification status for a user."""
try:
await User.prisma().update(
where={"id": user_id},
data={"emailVerified": verified},
)
except Exception as e:
raise DatabaseError(
f"Failed to set email verification status for user {user_id}: {e}"
) from e
async def get_user_email_verification(user_id: str) -> bool:
"""Get the email verification status for a user."""
try:
user = await User.prisma().find_unique_or_raise(
where={"id": user_id},
)
return user.emailVerified
except Exception as e:
raise DatabaseError(
f"Failed to get email verification status for user {user_id}: {e}"
) from e
def generate_unsubscribe_link(user_id: str) -> str:
"""Generate a link to unsubscribe from all notifications"""
# Create an HMAC using a secret key
secret_key = Settings().secrets.unsubscribe_secret_key
signature = hmac.new(
secret_key.encode("utf-8"), user_id.encode("utf-8"), hashlib.sha256
).digest()
# Create a token that combines the user_id and signature
token = base64.urlsafe_b64encode(
f"{user_id}:{signature.hex()}".encode("utf-8")
).decode("utf-8")
logger.info(f"Generating unsubscribe link for user {user_id}")
base_url = Settings().config.platform_base_url
return f"{base_url}/api/email/unsubscribe?token={quote_plus(token)}"
async def unsubscribe_user_by_token(token: str) -> None:
"""Unsubscribe a user from all notifications using the token"""
try:
# Decode the token
decoded = base64.urlsafe_b64decode(token).decode("utf-8")
user_id, received_signature_hex = decoded.split(":", 1)
# Verify the signature
secret_key = Settings().secrets.unsubscribe_secret_key
expected_signature = hmac.new(
secret_key.encode("utf-8"), user_id.encode("utf-8"), hashlib.sha256
).digest()
if not hmac.compare_digest(expected_signature.hex(), received_signature_hex):
raise ValueError("Invalid token signature")
user = await get_user_by_id(user_id)
await update_user_notification_preference(
user.id,
NotificationPreferenceDTO(
email=user.email,
daily_limit=0,
preferences={
NotificationType.AGENT_RUN: False,
NotificationType.ZERO_BALANCE: False,
NotificationType.LOW_BALANCE: False,
NotificationType.BLOCK_EXECUTION_FAILED: False,
NotificationType.CONTINUOUS_AGENT_ERROR: False,
NotificationType.DAILY_SUMMARY: False,
NotificationType.WEEKLY_SUMMARY: False,
NotificationType.MONTHLY_SUMMARY: False,
},
),
)
except Exception as e:
raise DatabaseError(f"Failed to unsubscribe user by token {token}: {e}") from e

View File

@@ -1,12 +1,15 @@
from backend.app import run_processes
from backend.executor import ExecutionManager
from backend.executor import DatabaseManager, ExecutionManager
def main():
"""
Run all the processes required for the AutoGPT-server REST API.
"""
run_processes(ExecutionManager())
run_processes(
DatabaseManager(),
ExecutionManager(),
)
if __name__ == "__main__":

View File

@@ -1,9 +1,9 @@
from .database import DatabaseManager
from .manager import ExecutionManager
from .scheduler import Scheduler
from .scheduler import ExecutionScheduler
__all__ = [
"DatabaseManager",
"ExecutionManager",
"Scheduler",
"ExecutionScheduler",
]

View File

@@ -1,89 +1,73 @@
import logging
from functools import wraps
from typing import Any, Callable, Concatenate, Coroutine, ParamSpec, TypeVar, cast
from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.credit import get_user_credit_model
from backend.data.execution import (
ExecutionResult,
NodeExecutionEntry,
RedisExecutionEventBus,
create_graph_execution,
get_graph_execution,
get_incomplete_node_executions,
get_latest_node_execution,
get_node_execution_results,
update_graph_execution_start_time,
get_execution_results,
get_incomplete_executions,
get_latest_execution,
update_execution_status,
update_graph_execution_stats,
update_node_execution_stats,
update_node_execution_status,
update_node_execution_status_batch,
upsert_execution_input,
upsert_execution_output,
)
from backend.data.graph import (
get_connected_output_nodes,
get_graph,
get_graph_metadata,
get_node,
)
from backend.data.notifications import (
create_or_add_to_user_notification_batch,
empty_user_notification_batch,
get_all_batches_by_type,
get_user_notification_batch,
get_user_notification_oldest_message_in_batch,
)
from backend.data.graph import get_graph, get_node
from backend.data.user import (
get_active_user_ids_in_timerange,
get_user_email_by_id,
get_user_email_verification,
get_user_integrations,
get_user_metadata,
get_user_notification_preference,
update_user_integrations,
update_user_metadata,
)
from backend.util.service import AppService, exposed_run_and_wait
from backend.util.service import AppService, expose, register_pydantic_serializers
from backend.util.settings import Config
P = ParamSpec("P")
R = TypeVar("R")
config = Config()
_user_credit_model = get_user_credit_model()
logger = logging.getLogger(__name__)
async def _spend_credits(
user_id: str, cost: int, metadata: UsageTransactionMetadata
) -> int:
return await _user_credit_model.spend_credits(user_id, cost, metadata)
class DatabaseManager(AppService):
def run_service(self) -> None:
logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
self.run_and_wait(db.connect())
super().run_service()
def cleanup(self):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
self.run_and_wait(db.disconnect())
def __init__(self):
super().__init__()
self.use_db = True
self.use_redis = True
self.event_queue = RedisExecutionEventBus()
@classmethod
def get_port(cls) -> int:
return config.database_api_port
@expose
def send_execution_update(self, execution_result: ExecutionResult):
self.event_queue.publish(execution_result)
@staticmethod
def exposed_run_and_wait(
f: Callable[P, Coroutine[None, None, R]]
) -> Callable[Concatenate[object, P], R]:
@expose
@wraps(f)
def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
coroutine = f(*args, **kwargs)
res = self.run_and_wait(coroutine)
return res
# Register serializers for annotations on bare function
register_pydantic_serializers(f)
return wrapper
# Executions
get_graph_execution = exposed_run_and_wait(get_graph_execution)
create_graph_execution = exposed_run_and_wait(create_graph_execution)
get_node_execution_results = exposed_run_and_wait(get_node_execution_results)
get_incomplete_node_executions = exposed_run_and_wait(
get_incomplete_node_executions
)
get_latest_node_execution = exposed_run_and_wait(get_latest_node_execution)
update_node_execution_status = exposed_run_and_wait(update_node_execution_status)
update_node_execution_status_batch = exposed_run_and_wait(
update_node_execution_status_batch
)
update_graph_execution_start_time = exposed_run_and_wait(
update_graph_execution_start_time
)
get_execution_results = exposed_run_and_wait(get_execution_results)
get_incomplete_executions = exposed_run_and_wait(get_incomplete_executions)
get_latest_execution = exposed_run_and_wait(get_latest_execution)
update_execution_status = exposed_run_and_wait(update_execution_status)
update_graph_execution_stats = exposed_run_and_wait(update_graph_execution_stats)
update_node_execution_stats = exposed_run_and_wait(update_node_execution_stats)
upsert_execution_input = exposed_run_and_wait(upsert_execution_input)
@@ -92,35 +76,16 @@ class DatabaseManager(AppService):
# Graphs
get_node = exposed_run_and_wait(get_node)
get_graph = exposed_run_and_wait(get_graph)
get_connected_output_nodes = exposed_run_and_wait(get_connected_output_nodes)
get_graph_metadata = exposed_run_and_wait(get_graph_metadata)
# Credits
spend_credits = exposed_run_and_wait(_spend_credits)
user_credit_model = get_user_credit_model()
spend_credits = cast(
Callable[[Any, NodeExecutionEntry, float, float], int],
exposed_run_and_wait(user_credit_model.spend_credits),
)
# User + User Metadata + User Integrations
get_user_metadata = exposed_run_and_wait(get_user_metadata)
update_user_metadata = exposed_run_and_wait(update_user_metadata)
get_user_integrations = exposed_run_and_wait(get_user_integrations)
update_user_integrations = exposed_run_and_wait(update_user_integrations)
# User Comms - async
get_active_user_ids_in_timerange = exposed_run_and_wait(
get_active_user_ids_in_timerange
)
get_user_email_by_id = exposed_run_and_wait(get_user_email_by_id)
get_user_email_verification = exposed_run_and_wait(get_user_email_verification)
get_user_notification_preference = exposed_run_and_wait(
get_user_notification_preference
)
# Notifications - async
create_or_add_to_user_notification_batch = exposed_run_and_wait(
create_or_add_to_user_notification_batch
)
empty_user_notification_batch = exposed_run_and_wait(empty_user_notification_batch)
get_all_batches_by_type = exposed_run_and_wait(get_all_batches_by_type)
get_user_notification_batch = exposed_run_and_wait(get_user_notification_batch)
get_user_notification_oldest_message_in_batch = exposed_run_and_wait(
get_user_notification_oldest_message_in_batch
)
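`exposed_run_and_wait` is what lets the async data-layer functions above sit on the class body as plain synchronous RPC methods: each wrapper builds the coroutine, blocks on the service's event loop through `run_and_wait`, and is decorated with `@expose` so other processes can call it. A stripped-down illustration of that wrap-and-block pattern, using a plain background event loop in place of `AppService` (whose internals are not shown in this diff):

import asyncio
import threading
from functools import wraps
from typing import Callable, Concatenate, Coroutine, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

class MiniService:
    """Runs an asyncio loop on a background thread, standing in for AppService."""

    def __init__(self) -> None:
        self._loop = asyncio.new_event_loop()
        threading.Thread(target=self._loop.run_forever, daemon=True).start()

    def run_and_wait(self, coro: Coroutine[None, None, R]) -> R:
        # Block the caller until the coroutine finishes on the service loop.
        return asyncio.run_coroutine_threadsafe(coro, self._loop).result()

def run_and_wait_method(f: Callable[P, Coroutine[None, None, R]]) -> Callable[Concatenate[MiniService, P], R]:
    # Turn a bare async function into a synchronous method on the service,
    # mirroring how DatabaseManager attaches get_graph, get_node, etc.
    @wraps(f)
    def wrapper(self: MiniService, *args: P.args, **kwargs: P.kwargs) -> R:
        return self.run_and_wait(f(*args, **kwargs))
    return wrapper

async def add(a: int, b: int) -> int:
    return a + b

class Demo(MiniService):
    add = run_and_wait_method(add)

print(Demo().add(2, 3))  # -> 5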

Some files were not shown because too many files have changed in this diff.