Compare commits

..

4 Commits

Author SHA1 Message Date
Andy Hooker
6fca2352bb feat(build): Add undo/redo functionality and integrate custom nodes
Introduce a `useUndoRedo` hook to manage state history with undo/redo, persistence using localStorage, and state validation. Updated the build page to display a flow with initial custom nodes and edges, replacing the previous flow editor implementation.
2025-02-22 17:17:46 -06:00
Andy Hooker
84759053a7 feat(build): Add undo/redo functionality and integrate custom nodes
Introduce a `useUndoRedo` hook to manage state history with undo/redo, persistence using localStorage, and state validation. Updated the build page to display a flow with initial custom nodes and edges, replacing the previous flow editor implementation.
2025-02-22 17:17:35 -06:00
Andy Hooker
4afe724628 feat(build): Add BuildFlow component for editable flowchart canvas
Introduces the `BuildFlow` component with undo/redo, drag-and-drop, and connection functionality using the ReactFlow library. This implementation supports state management for nodes and edges and integrates a control panel for user actions like undo, redo, and reset. It also includes a read-only mode for non-editable use cases.
2025-02-22 17:17:15 -06:00
Andy Hooker
c178a537b7 feat(build): Add custom node, edge components, and canvas mapping hook
Introduce `BuildCustomNode` and `BuildCustomEdge` components for enhanced React Flow visualizations, enabling node focus and styled edges. Implement `useCanvasMapping` hook to map and enhance nodes and edges dynamically with custom labels and styles.
2025-02-22 17:16:16 -06:00
257 changed files with 4458 additions and 15126 deletions

View File

@@ -129,6 +129,30 @@ updates:
- "minor"
- "patch"
# Submodules
- package-ecosystem: "gitsubmodule"
directory: "autogpt_platform/supabase"
schedule:
interval: "weekly"
open-pull-requests-limit: 1
target-branch: "dev"
commit-message:
prefix: "chore(platform/deps)"
prefix-development: "chore(platform/deps-dev)"
groups:
production-dependencies:
dependency-type: "production"
update-types:
- "minor"
- "patch"
development-dependencies:
dependency-type: "development"
update-types:
- "minor"
- "patch"
# Docs
- package-ecosystem: 'pip'
directory: "docs/"

View File

@@ -115,7 +115,6 @@ jobs:
poetry run pytest -vv \
--cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
--numprocesses=logical --durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
tests/unit tests/integration
env:
CI: true
@@ -125,14 +124,8 @@ jobs:
AWS_ACCESS_KEY_ID: minioadmin
AWS_SECRET_ACCESS_KEY: minioadmin
- name: Upload test results to Codecov
if: ${{ !cancelled() }} # Run even if tests fail
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }}
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: autogpt-agent,${{ runner.os }}

View File

@@ -87,20 +87,13 @@ jobs:
poetry run pytest -vv \
--cov=agbenchmark --cov-branch --cov-report term-missing --cov-report xml \
--durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
tests
env:
CI: true
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- name: Upload test results to Codecov
if: ${{ !cancelled() }} # Run even if tests fail
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }}
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: agbenchmark,${{ runner.os }}

View File

@@ -139,7 +139,6 @@ jobs:
poetry run pytest -vv \
--cov=forge --cov-branch --cov-report term-missing --cov-report xml \
--durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
forge
env:
CI: true
@@ -149,14 +148,8 @@ jobs:
AWS_ACCESS_KEY_ID: minioadmin
AWS_SECRET_ACCESS_KEY: minioadmin
- name: Upload test results to Codecov
if: ${{ !cancelled() }} # Run even if tests fail
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }}
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: forge,${{ runner.os }}

View File

@@ -66,7 +66,7 @@ jobs:
- name: Setup Supabase
uses: supabase/setup-cli@v1
with:
version: 1.178.1
version: latest
- id: get_date
name: Get date

View File

@@ -82,7 +82,7 @@ jobs:
- name: Copy default supabase .env
run: |
cp ../.env.example ../.env
cp ../supabase/docker/.env.example ../.env
- name: Copy backend .env
run: |

3
.gitmodules vendored
View File

@@ -1,3 +1,6 @@
[submodule "classic/forge/tests/vcr_cassettes"]
path = classic/forge/tests/vcr_cassettes
url = https://github.com/Significant-Gravitas/Auto-GPT-test-cassettes
[submodule "autogpt_platform/supabase"]
path = autogpt_platform/supabase
url = https://github.com/supabase/supabase.git

View File

@@ -20,7 +20,6 @@ Instead, please report them via:
- Please provide detailed reports with reproducible steps
- Include the version/commit hash where you discovered the vulnerability
- Allow us a 90-day security fix window before any public disclosure
- After patch is released, allow 30 days for users to update before public disclosure (for a total of 120 days max between update time and fix time)
- Share any potential mitigations or workarounds if known
## Supported Versions

View File

@@ -1,123 +0,0 @@
############
# Secrets
# YOU MUST CHANGE THESE BEFORE GOING INTO PRODUCTION
############
POSTGRES_PASSWORD=your-super-secret-and-long-postgres-password
JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
ANON_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJhbm9uIiwKICAgICJpc3MiOiAic3VwYWJhc2UtZGVtbyIsCiAgICAiaWF0IjogMTY0MTc2OTIwMCwKICAgICJleHAiOiAxNzk5NTM1NjAwCn0.dc_X5iR_VP_qT0zsiyj_I_OZ2T9FtRU2BBNWN8Bu4GE
SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
DASHBOARD_USERNAME=supabase
DASHBOARD_PASSWORD=this_password_is_insecure_and_should_be_updated
SECRET_KEY_BASE=UpNVntn3cDxHJpq99YMc1T1AQgQpc8kfYTuRgBiYa15BLrx8etQoXz3gZv1/u2oq
VAULT_ENC_KEY=your-encryption-key-32-chars-min
############
# Database - You can change these to any PostgreSQL database that has logical replication enabled.
############
POSTGRES_HOST=db
POSTGRES_DB=postgres
POSTGRES_PORT=5432
# default user is postgres
############
# Supavisor -- Database pooler
############
POOLER_PROXY_PORT_TRANSACTION=6543
POOLER_DEFAULT_POOL_SIZE=20
POOLER_MAX_CLIENT_CONN=100
POOLER_TENANT_ID=your-tenant-id
############
# API Proxy - Configuration for the Kong Reverse proxy.
############
KONG_HTTP_PORT=8000
KONG_HTTPS_PORT=8443
############
# API - Configuration for PostgREST.
############
PGRST_DB_SCHEMAS=public,storage,graphql_public
############
# Auth - Configuration for the GoTrue authentication server.
############
## General
SITE_URL=http://localhost:3000
ADDITIONAL_REDIRECT_URLS=
JWT_EXPIRY=3600
DISABLE_SIGNUP=false
API_EXTERNAL_URL=http://localhost:8000
## Mailer Config
MAILER_URLPATHS_CONFIRMATION="/auth/v1/verify"
MAILER_URLPATHS_INVITE="/auth/v1/verify"
MAILER_URLPATHS_RECOVERY="/auth/v1/verify"
MAILER_URLPATHS_EMAIL_CHANGE="/auth/v1/verify"
## Email auth
ENABLE_EMAIL_SIGNUP=true
ENABLE_EMAIL_AUTOCONFIRM=false
SMTP_ADMIN_EMAIL=admin@example.com
SMTP_HOST=supabase-mail
SMTP_PORT=2500
SMTP_USER=fake_mail_user
SMTP_PASS=fake_mail_password
SMTP_SENDER_NAME=fake_sender
ENABLE_ANONYMOUS_USERS=false
## Phone auth
ENABLE_PHONE_SIGNUP=true
ENABLE_PHONE_AUTOCONFIRM=true
############
# Studio - Configuration for the Dashboard
############
STUDIO_DEFAULT_ORGANIZATION=Default Organization
STUDIO_DEFAULT_PROJECT=Default Project
STUDIO_PORT=3000
# replace if you intend to use Studio outside of localhost
SUPABASE_PUBLIC_URL=http://localhost:8000
# Enable webp support
IMGPROXY_ENABLE_WEBP_DETECTION=true
# Add your OpenAI API key to enable SQL Editor Assistant
OPENAI_API_KEY=
############
# Functions - Configuration for Functions
############
# NOTE: VERIFY_JWT applies to all functions. Per-function VERIFY_JWT is not supported yet.
FUNCTIONS_VERIFY_JWT=false
############
# Logs - Configuration for Logflare
# Please refer to https://supabase.com/docs/reference/self-hosting-analytics/introduction
############
LOGFLARE_LOGGER_BACKEND_API_KEY=your-super-secret-and-long-logflare-key
# Change vector.toml sinks to reflect this change
LOGFLARE_API_KEY=your-super-secret-and-long-logflare-key
# Docker socket location - this value will differ depending on your OS
DOCKER_SOCKET_LOCATION=/var/run/docker.sock
# Google Cloud Project details
GOOGLE_PROJECT_ID=GOOGLE_PROJECT_ID
GOOGLE_PROJECT_NUMBER=GOOGLE_PROJECT_NUMBER

View File

@@ -22,29 +22,35 @@ To run the AutoGPT Platform, follow these steps:
2. Run the following command:
```
cp .env.example .env
git submodule update --init --recursive --progress
```
This command will copy the `.env.example` file to `.env`. You can modify the `.env` file to add your own environment variables.
This command will initialize and update the submodules in the repository. The `supabase` folder will be cloned to the root directory.
3. Run the following command:
```
cp supabase/docker/.env.example .env
```
This command will copy the `.env.example` file to `.env` in the `supabase/docker` directory. You can modify the `.env` file to add your own environment variables.
4. Run the following command:
```
docker compose up -d
```
This command will start all the necessary backend services defined in the `docker-compose.yml` file in detached mode.
4. Navigate to `frontend` within the `autogpt_platform` directory:
5. Navigate to `frontend` within the `autogpt_platform` directory:
```
cd frontend
```
You will need to run your frontend application separately on your local machine.
5. Run the following command:
6. Run the following command:
```
cp .env.example .env.local
```
This command will copy the `.env.example` file to `.env.local` in the `frontend` directory. You can modify the `.env.local` within this folder to add your own environment variables for the frontend application.
6. Run the following command:
7. Run the following command:
```
npm install
npm run dev
@@ -55,7 +61,7 @@ To run the AutoGPT Platform, follow these steps:
yarn install && yarn dev
```
7. Open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
8. Open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
### Docker Compose Commands

View File

@@ -1,7 +1,7 @@
from .config import Settings
from .depends import requires_admin_user, requires_user
from .jwt_utils import parse_jwt_token
from .middleware import APIKeyValidator, auth_middleware
from .middleware import auth_middleware
from .models import User
__all__ = [
@@ -9,7 +9,6 @@ __all__ = [
"parse_jwt_token",
"requires_user",
"requires_admin_user",
"APIKeyValidator",
"auth_middleware",
"User",
]

View File

@@ -1,10 +1,7 @@
import inspect
import logging
from typing import Any, Callable, Optional
from fastapi import HTTPException, Request, Security
from fastapi.security import APIKeyHeader, HTTPBearer
from starlette.status import HTTP_401_UNAUTHORIZED
from fastapi import HTTPException, Request
from fastapi.security import HTTPBearer
from .config import settings
from .jwt_utils import parse_jwt_token
@@ -32,104 +29,3 @@ async def auth_middleware(request: Request):
except ValueError as e:
raise HTTPException(status_code=401, detail=str(e))
return payload
class APIKeyValidator:
"""
Configurable API key validator that supports custom validation functions
for FastAPI applications.
This class provides a flexible way to implement API key authentication with optional
custom validation logic. It can be used for simple token matching
or more complex validation scenarios like database lookups.
Examples:
Simple token validation:
```python
validator = APIKeyValidator(
header_name="X-API-Key",
expected_token="your-secret-token"
)
@app.get("/protected", dependencies=[Depends(validator.get_dependency())])
def protected_endpoint():
return {"message": "Access granted"}
```
Custom validation with database lookup:
```python
async def validate_with_db(api_key: str):
api_key_obj = await db.get_api_key(api_key)
return api_key_obj if api_key_obj and api_key_obj.is_active else None
validator = APIKeyValidator(
header_name="X-API-Key",
validate_fn=validate_with_db
)
```
Args:
header_name (str): The name of the header containing the API key
expected_token (Optional[str]): The expected API key value for simple token matching
validate_fn (Optional[Callable]): Custom validation function that takes an API key
string and returns a boolean or object. Can be async.
error_status (int): HTTP status code to use for validation errors
error_message (str): Error message to return when validation fails
"""
def __init__(
self,
header_name: str,
expected_token: Optional[str] = None,
validate_fn: Optional[Callable[[str], bool]] = None,
error_status: int = HTTP_401_UNAUTHORIZED,
error_message: str = "Invalid API key",
):
# Create the APIKeyHeader as a class property
self.security_scheme = APIKeyHeader(name=header_name)
self.expected_token = expected_token
self.custom_validate_fn = validate_fn
self.error_status = error_status
self.error_message = error_message
async def default_validator(self, api_key: str) -> bool:
return api_key == self.expected_token
async def __call__(
self, request: Request, api_key: str = Security(APIKeyHeader)
) -> Any:
if api_key is None:
raise HTTPException(status_code=self.error_status, detail="Missing API key")
# Use custom validation if provided, otherwise use default equality check
validator = self.custom_validate_fn or self.default_validator
result = (
await validator(api_key)
if inspect.iscoroutinefunction(validator)
else validator(api_key)
)
if not result:
raise HTTPException(
status_code=self.error_status, detail=self.error_message
)
# Store validation result in request state if it's not just a boolean
if result is not True:
request.state.api_key = result
return result
def get_dependency(self):
"""
Returns a callable dependency that FastAPI will recognize as a security scheme
"""
async def validate_api_key(
request: Request, api_key: str = Security(self.security_scheme)
) -> Any:
return await self(request, api_key)
# This helps FastAPI recognize it as a security dependency
validate_api_key.__name__ = f"validate_{self.security_scheme.model.name}"
return validate_api_key

View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 2.0.1 and should not be changed by hand.
[[package]]
name = "aiohappyeyeballs"
@@ -123,7 +123,7 @@ multidict = ">=4.5,<7.0"
yarl = ">=1.0,<2.0"
[package.extras]
speedups = ["Brotli ; platform_python_implementation == \"CPython\"", "aiodns (>=3.2.0) ; sys_platform == \"linux\" or sys_platform == \"darwin\"", "brotlicffi ; platform_python_implementation != \"CPython\""]
speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"]
[[package]]
name = "aiosignal"
@@ -172,7 +172,7 @@ typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""}
[package.extras]
doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"]
test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17) ; platform_python_implementation == \"CPython\" and platform_system != \"Windows\""]
test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"]
trio = ["trio (>=0.23)"]
[[package]]
@@ -201,12 +201,12 @@ files = [
]
[package.extras]
benchmark = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.9\"", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.9\" and python_version < \"3.13\"", "pytest-xdist[psutil]"]
cov = ["cloudpickle ; platform_python_implementation == \"CPython\"", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.9\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.9\" and python_version < \"3.13\"", "pytest-xdist[psutil]"]
dev = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.9\"", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.9\" and python_version < \"3.13\"", "pytest-xdist[psutil]"]
benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
tests = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.9\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.9\" and python_version < \"3.13\"", "pytest-xdist[psutil]"]
tests-mypy = ["mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.9\"", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.9\" and python_version < \"3.13\""]
tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"]
[[package]]
name = "cachetools"
@@ -512,18 +512,18 @@ google-auth = ">=2.14.1,<3.0.dev0"
googleapis-common-protos = ">=1.56.2,<2.0.dev0"
grpcio = [
{version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
{version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""},
{version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
]
grpcio-status = [
{version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
{version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""},
{version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
]
proto-plus = ">=1.22.3,<2.0.0dev"
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0"
requests = ">=2.18.0,<3.0.0.dev0"
[package.extras]
grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev) ; python_version >= \"3.11\"", "grpcio-status (>=1.33.2,<2.0.dev0)", "grpcio-status (>=1.49.1,<2.0.dev0) ; python_version >= \"3.11\""]
grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status (>=1.33.2,<2.0.dev0)", "grpcio-status (>=1.49.1,<2.0.dev0)"]
grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
@@ -842,7 +842,7 @@ idna = "*"
sniffio = "*"
[package.extras]
brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""]
brotli = ["brotli", "brotlicffi"]
cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
@@ -890,7 +890,7 @@ zipp = ">=0.5"
[package.extras]
doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
perf = ["ipython"]
test = ["flufl.flake8", "importlib-resources (>=1.3) ; python_version < \"3.9\"", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""]
test = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"]
[[package]]
name = "iniconfig"
@@ -1156,7 +1156,7 @@ typing-extensions = ">=4.12.2"
[package.extras]
email = ["email-validator (>=2.0.0)"]
timezone = ["tzdata ; python_version >= \"3.9\" and platform_system == \"Windows\""]
timezone = ["tzdata"]
[[package]]
name = "pydantic-core"
@@ -1476,30 +1476,30 @@ pyasn1 = ">=0.1.3"
[[package]]
name = "ruff"
version = "0.9.10"
version = "0.9.3"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
groups = ["dev"]
files = [
{file = "ruff-0.9.10-py3-none-linux_armv6l.whl", hash = "sha256:eb4d25532cfd9fe461acc83498361ec2e2252795b4f40b17e80692814329e42d"},
{file = "ruff-0.9.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:188a6638dab1aa9bb6228a7302387b2c9954e455fb25d6b4470cb0641d16759d"},
{file = "ruff-0.9.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5284dcac6b9dbc2fcb71fdfc26a217b2ca4ede6ccd57476f52a587451ebe450d"},
{file = "ruff-0.9.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:47678f39fa2a3da62724851107f438c8229a3470f533894b5568a39b40029c0c"},
{file = "ruff-0.9.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:99713a6e2766b7a17147b309e8c915b32b07a25c9efd12ada79f217c9c778b3e"},
{file = "ruff-0.9.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:524ee184d92f7c7304aa568e2db20f50c32d1d0caa235d8ddf10497566ea1a12"},
{file = "ruff-0.9.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:df92aeac30af821f9acf819fc01b4afc3dfb829d2782884f8739fb52a8119a16"},
{file = "ruff-0.9.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de42e4edc296f520bb84954eb992a07a0ec5a02fecb834498415908469854a52"},
{file = "ruff-0.9.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d257f95b65806104b6b1ffca0ea53f4ef98454036df65b1eda3693534813ecd1"},
{file = "ruff-0.9.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b60dec7201c0b10d6d11be00e8f2dbb6f40ef1828ee75ed739923799513db24c"},
{file = "ruff-0.9.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:d838b60007da7a39c046fcdd317293d10b845001f38bcb55ba766c3875b01e43"},
{file = "ruff-0.9.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:ccaf903108b899beb8e09a63ffae5869057ab649c1e9231c05ae354ebc62066c"},
{file = "ruff-0.9.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f9567d135265d46e59d62dc60c0bfad10e9a6822e231f5b24032dba5a55be6b5"},
{file = "ruff-0.9.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5f202f0d93738c28a89f8ed9eaba01b7be339e5d8d642c994347eaa81c6d75b8"},
{file = "ruff-0.9.10-py3-none-win32.whl", hash = "sha256:bfb834e87c916521ce46b1788fbb8484966e5113c02df216680102e9eb960029"},
{file = "ruff-0.9.10-py3-none-win_amd64.whl", hash = "sha256:f2160eeef3031bf4b17df74e307d4c5fb689a6f3a26a2de3f7ef4044e3c484f1"},
{file = "ruff-0.9.10-py3-none-win_arm64.whl", hash = "sha256:5fd804c0327a5e5ea26615550e706942f348b197d5475ff34c19733aee4b2e69"},
{file = "ruff-0.9.10.tar.gz", hash = "sha256:9bacb735d7bada9cfb0f2c227d3658fc443d90a727b47f206fb33f52f3c0eac7"},
{file = "ruff-0.9.3-py3-none-linux_armv6l.whl", hash = "sha256:7f39b879064c7d9670197d91124a75d118d00b0990586549949aae80cdc16624"},
{file = "ruff-0.9.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:a187171e7c09efa4b4cc30ee5d0d55a8d6c5311b3e1b74ac5cb96cc89bafc43c"},
{file = "ruff-0.9.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:c59ab92f8e92d6725b7ded9d4a31be3ef42688a115c6d3da9457a5bda140e2b4"},
{file = "ruff-0.9.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2dc153c25e715be41bb228bc651c1e9b1a88d5c6e5ed0194fa0dfea02b026439"},
{file = "ruff-0.9.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:646909a1e25e0dc28fbc529eab8eb7bb583079628e8cbe738192853dbbe43af5"},
{file = "ruff-0.9.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5a5a46e09355695fbdbb30ed9889d6cf1c61b77b700a9fafc21b41f097bfbba4"},
{file = "ruff-0.9.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:c4bb09d2bbb394e3730d0918c00276e79b2de70ec2a5231cd4ebb51a57df9ba1"},
{file = "ruff-0.9.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:96a87ec31dc1044d8c2da2ebbed1c456d9b561e7d087734336518181b26b3aa5"},
{file = "ruff-0.9.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9bb7554aca6f842645022fe2d301c264e6925baa708b392867b7a62645304df4"},
{file = "ruff-0.9.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cabc332b7075a914ecea912cd1f3d4370489c8018f2c945a30bcc934e3bc06a6"},
{file = "ruff-0.9.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:33866c3cc2a575cbd546f2cd02bdd466fed65118e4365ee538a3deffd6fcb730"},
{file = "ruff-0.9.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:006e5de2621304c8810bcd2ee101587712fa93b4f955ed0985907a36c427e0c2"},
{file = "ruff-0.9.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:ba6eea4459dbd6b1be4e6bfc766079fb9b8dd2e5a35aff6baee4d9b1514ea519"},
{file = "ruff-0.9.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:90230a6b8055ad47d3325e9ee8f8a9ae7e273078a66401ac66df68943ced029b"},
{file = "ruff-0.9.3-py3-none-win32.whl", hash = "sha256:eabe5eb2c19a42f4808c03b82bd313fc84d4e395133fb3fc1b1516170a31213c"},
{file = "ruff-0.9.3-py3-none-win_amd64.whl", hash = "sha256:040ceb7f20791dfa0e78b4230ee9dce23da3b64dd5848e40e3bf3ab76468dcf4"},
{file = "ruff-0.9.3-py3-none-win_arm64.whl", hash = "sha256:800d773f6d4d33b0a3c60e2c6ae8f4c202ea2de056365acfa519aa48acf28e0b"},
{file = "ruff-0.9.3.tar.gz", hash = "sha256:8293f89985a090ebc3ed1064df31f3b4b56320cdfcec8b60d3295bddb955c22a"},
]
[[package]]
@@ -1633,7 +1633,7 @@ files = [
]
[package.extras]
brotli = ["brotli (>=1.0.9) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\""]
brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
h2 = ["h2 (>=4,<5)"]
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
zstd = ["zstandard (>=0.18.0)"]
@@ -1919,14 +1919,14 @@ files = [
]
[package.extras]
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""]
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"]
cover = ["pytest-cov"]
doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
enabler = ["pytest-enabler (>=2.2)"]
test = ["big-O", "importlib-resources ; python_version < \"3.9\"", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"]
test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"]
type = ["pytest-mypy"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<4.0"
content-hash = "931772287f71c539575d601e6398423bf68e09ca87ae1a144057c7f5707cf978"
content-hash = "a4d81b3b55a67036ca7a441793e13e8fbe20af973fcf1623f36cdee7bc82999f"

View File

@@ -21,7 +21,7 @@ supabase = "^2.13.0"
[tool.poetry.group.dev.dependencies]
redis = "^5.2.1"
ruff = "^0.9.10"
ruff = "^0.9.3"
[build-system]
requires = ["poetry-core"]

View File

@@ -2,23 +2,13 @@ DB_USER=postgres
DB_PASS=your-super-secret-and-long-postgres-password
DB_NAME=postgres
DB_PORT=5432
DB_HOST=localhost
DB_CONNECTION_LIMIT=12
DB_CONNECT_TIMEOUT=60
DB_POOL_TIMEOUT=300
DB_SCHEMA=platform
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@localhost:${DB_PORT}/${DB_NAME}?connect_timeout=60&schema=platform"
PRISMA_SCHEMA="postgres/schema.prisma"
# EXECUTOR
NUM_GRAPH_WORKERS=10
NUM_NODE_WORKERS=3
BACKEND_CORS_ALLOW_ORIGINS=["http://localhost:3000"]
# generate using `from cryptography.fernet import Fernet;Fernet.generate_key().decode()`
ENCRYPTION_KEY='dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw='
UNSUBSCRIBE_SECRET_KEY = 'HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio='
REDIS_HOST=localhost
REDIS_PORT=6379
@@ -38,7 +28,6 @@ SENTRY_DSN=
# Email For Postmark so we can send emails
POSTMARK_SERVER_API_TOKEN=
POSTMARK_SENDER_EMAIL=invalid@invalid.com
POSTMARK_WEBHOOK_TOKEN=
## User auth with Supabase is required for any of the 3rd party integrations with auth to work.
ENABLE_AUTH=true
@@ -52,9 +41,6 @@ RABBITMQ_PORT=5672
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
## GCS bucket is required for marketplace and library functionality
MEDIA_GCS_BUCKET_NAME=
## For local development, you may need to set FRONTEND_BASE_URL for the OAuth flow
## for integrations to work. Defaults to the value of PLATFORM_BASE_URL if not set.
# FRONTEND_BASE_URL=http://localhost:3000
@@ -173,9 +159,6 @@ EXA_API_KEY=
# E2B
E2B_API_KEY=
# Example API Key
EXAMPLE_API_KEY=
# Mem0
MEM0_API_KEY=

View File

@@ -1 +1,75 @@
[Advanced Setup (Dev Branch)](https://dev-docs.agpt.co/platform/advanced_setup/#autogpt_agent_server_advanced_set_up)
# AutoGPT Agent Server Advanced set up
This guide walks you through a dockerized set up, with an external DB (postgres)
## Setup
We use the Poetry to manage the dependencies. To set up the project, follow these steps inside this directory:
0. Install Poetry
```sh
pip install poetry
```
1. Configure Poetry to use .venv in your project directory
```sh
poetry config virtualenvs.in-project true
```
2. Enter the poetry shell
```sh
poetry shell
```
3. Install dependencies
```sh
poetry install
```
4. Copy .env.example to .env
```sh
cp .env.example .env
```
5. Generate the Prisma client
```sh
poetry run prisma generate
```
> In case Prisma generates the client for the global Python installation instead of the virtual environment, the current mitigation is to just uninstall the global Prisma package:
>
> ```sh
> pip uninstall prisma
> ```
>
> Then run the generation again. The path *should* look something like this:
> `<some path>/pypoetry/virtualenvs/backend-TQIRSwR6-py3.12/bin/prisma`
6. Run the postgres database from the /rnd folder
```sh
cd autogpt_platform/
docker compose up -d
```
7. Run the migrations (from the backend folder)
```sh
cd ../backend
prisma migrate deploy
```
## Running The Server
### Starting the server directly
Run the following command:
```sh
poetry run app
```

View File

@@ -1 +1,210 @@
[Getting Started (Released)](https://docs.agpt.co/platform/getting-started/#autogpt_agent_server)
# AutoGPT Agent Server
This is an initial project for creating the next generation of agent execution, which is an AutoGPT agent server.
The agent server will enable the creation of composite multi-agent systems that utilize AutoGPT agents and other non-agent components as its primitives.
## Docs
You can access the docs for the [AutoGPT Agent Server here](https://docs.agpt.co/server/setup).
## Setup
We use the Poetry to manage the dependencies. To set up the project, follow these steps inside this directory:
0. Install Poetry
```sh
pip install poetry
```
1. Configure Poetry to use .venv in your project directory
```sh
poetry config virtualenvs.in-project true
```
2. Enter the poetry shell
```sh
poetry shell
```
3. Install dependencies
```sh
poetry install
```
4. Copy .env.example to .env
```sh
cp .env.example .env
```
5. Generate the Prisma client
```sh
poetry run prisma generate
```
> In case Prisma generates the client for the global Python installation instead of the virtual environment, the current mitigation is to just uninstall the global Prisma package:
>
> ```sh
> pip uninstall prisma
> ```
>
> Then run the generation again. The path *should* look something like this:
> `<some path>/pypoetry/virtualenvs/backend-TQIRSwR6-py3.12/bin/prisma`
6. Migrate the database. Be careful because this deletes current data in the database.
```sh
docker compose up db -d
poetry run prisma migrate deploy
```
## Running The Server
### Starting the server without Docker
To run the server locally, start in the autogpt_platform folder:
```sh
cd ..
```
Run the following command to run database in docker but the application locally:
```sh
docker compose --profile local up deps --build --detach
cd backend
poetry run app
```
### Starting the server with Docker
Run the following command to build the dockerfiles:
```sh
docker compose build
```
Run the following command to run the app:
```sh
docker compose up
```
Run the following to automatically rebuild when code changes, in another terminal:
```sh
docker compose watch
```
Run the following command to shut down:
```sh
docker compose down
```
If you run into issues with dangling orphans, try:
```sh
docker compose down --volumes --remove-orphans && docker-compose up --force-recreate --renew-anon-volumes --remove-orphans
```
## Testing
To run the tests:
```sh
poetry run test
```
## Development
### Formatting & Linting
Auto formatter and linter are set up in the project. To run them:
Install:
```sh
poetry install --with dev
```
Format the code:
```sh
poetry run format
```
Lint the code:
```sh
poetry run lint
```
## Project Outline
The current project has the following main modules:
### **blocks**
This module stores all the Agent Blocks, which are reusable components to build a graph that represents the agent's behavior.
### **data**
This module stores the logical model that is persisted in the database.
It abstracts the database operations into functions that can be called by the service layer.
Any code that interacts with Prisma objects or the database should reside in this module.
The main models are:
* `block`: anything related to the block used in the graph
* `execution`: anything related to the execution graph execution
* `graph`: anything related to the graph, node, and its relations
### **execution**
This module stores the business logic of executing the graph.
It currently has the following main modules:
* `manager`: A service that consumes the queue of the graph execution and executes the graph. It contains both pieces of logic.
* `scheduler`: A service that triggers scheduled graph execution based on a cron expression. It pushes an execution request to the manager.
### **server**
This module stores the logic for the server API.
It contains all the logic used for the API that allows the client to create, execute, and monitor the graph and its execution.
This API service interacts with other services like those defined in `manager` and `scheduler`.
### **utils**
This module stores utility functions that are used across the project.
Currently, it has two main modules:
* `process`: A module that contains the logic to spawn a new process.
* `service`: A module that serves as a parent class for all the services in the project.
## Service Communication
Currently, there are only 3 active services:
- AgentServer (the API, defined in `server.py`)
- ExecutionManager (the executor, defined in `manager.py`)
- ExecutionScheduler (the scheduler, defined in `scheduler.py`)
The services run in independent Python processes and communicate through an IPC.
A communication layer (`service.py`) is created to decouple the communication library from the implementation.
Currently, the IPC is done using Pyro5 and abstracted in a way that allows a function decorated with `@expose` to be called from a different process.
By default the daemons run on the following ports:
Execution Manager Daemon: 8002
Execution Scheduler Daemon: 8003
Rest Server Daemon: 8004
## Adding a New Agent Block
To add a new agent block, you need to create a new class that inherits from `Block` and provides the following information:
* All the block code should live in the `blocks` (`backend.blocks`) module.
* `input_schema`: the schema of the input data, represented by a Pydantic object.
* `output_schema`: the schema of the output data, represented by a Pydantic object.
* `run` method: the main logic of the block.
* `test_input` & `test_output`: the sample input and output data for the block, which will be used to auto-test the block.
* You can mock the functions declared in the block using the `test_mock` field for your unit tests.
* Once you finish creating the block, you can test it by running `poetry run pytest -s test/block/test_block.py`.

View File

@@ -1,30 +1,22 @@
import logging
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from backend.util.process import AppProcess
logger = logging.getLogger(__name__)
def run_processes(*processes: "AppProcess", **kwargs):
"""
Execute all processes in the app. The last process is run in the foreground.
Includes enhanced error handling and process lifecycle management.
"""
try:
# Run all processes except the last one in the background.
for process in processes[:-1]:
process.start(background=True, **kwargs)
# Run the last process in the foreground.
# Run the last process in the foreground
processes[-1].start(background=False, **kwargs)
finally:
for process in processes:
try:
process.stop()
except Exception as e:
logger.exception(f"[{process.service_name}] unable to stop: {e}")
process.stop()
def main(**kwargs):
@@ -32,7 +24,7 @@ def main(**kwargs):
Run all the processes required for the AutoGPT-server (REST and WebSocket APIs).
"""
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
from backend.executor import DatabaseManager, ExecutionManager, ExecutionScheduler
from backend.notifications import NotificationManager
from backend.server.rest_api import AgentServer
from backend.server.ws_api import WebsocketServer
@@ -40,7 +32,7 @@ def main(**kwargs):
run_processes(
DatabaseManager(),
ExecutionManager(),
Scheduler(),
ExecutionScheduler(),
NotificationManager(),
WebsocketServer(),
AgentServer(),

View File

@@ -1,5 +1,4 @@
import logging
from typing import Any
from autogpt_libs.utils.cache import thread_cached
@@ -14,7 +13,6 @@ from backend.data.block import (
)
from backend.data.execution import ExecutionStatus
from backend.data.model import SchemaField
from backend.util import json
logger = logging.getLogger(__name__)
@@ -44,23 +42,6 @@ class AgentExecutorBlock(Block):
input_schema: dict = SchemaField(description="Input schema for the graph")
output_schema: dict = SchemaField(description="Output schema for the graph")
@classmethod
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
return data.get("input_schema", {})
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
return data.get("data", {})
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
required_fields = cls.get_input_schema(data).get("required", [])
return set(required_fields) - set(data)
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
return json.validate_with_jsonschema(cls.get_input_schema(data), data)
class Output(BlockSchema):
pass

View File

@@ -3,7 +3,6 @@ from typing import Any, List
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
from backend.data.model import SchemaField
from backend.util import json
from backend.util.file import MediaFile, store_media_file
from backend.util.mock import MockObject
from backend.util.text import TextFormatter
@@ -154,9 +153,6 @@ class FindInDictionaryBlock(Block):
obj = input_data.input
key = input_data.key
if isinstance(obj, str):
obj = json.loads(obj)
if isinstance(obj, dict) and key in obj:
yield "output", obj[key]
elif isinstance(obj, list) and isinstance(key, int) and 0 <= key < len(obj):

View File

@@ -51,7 +51,6 @@ class ExaContentsBlock(Block):
description="List of document contents",
default=[],
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(

View File

@@ -1,137 +0,0 @@
"""
API module for Example API integration.
This module provides a example of how to create a client for an API.
"""
# We also have a Json Wrapper library available in backend.util.json
from json import JSONDecodeError
from typing import Any, Optional
from pydantic import BaseModel
from backend.data.model import APIKeyCredentials
# This is a wrapper around the requests library that is used to make API requests.
from backend.util.request import Requests
class ExampleAPIException(Exception):
def __init__(self, message: str, status_code: int):
super().__init__(message)
self.status_code = status_code
class CreateResourceResponse(BaseModel):
message: str
is_funny: bool
class GetResourceResponse(BaseModel):
message: str
is_funny: bool
class ExampleClient:
"""Client for the Example API"""
API_BASE_URL = "https://api.example.com/v1"
def __init__(
self,
credentials: Optional[APIKeyCredentials] = None,
custom_requests: Optional[Requests] = None,
):
if custom_requests:
self._requests = custom_requests
else:
headers: dict[str, str] = {
"Content-Type": "application/json",
}
if credentials:
headers["Authorization"] = credentials.auth_header()
self._requests = Requests(
extra_headers=headers,
raise_for_status=False,
)
@staticmethod
def _handle_response(response) -> Any:
"""
Handles API response and checks for errors.
Args:
response: The response object from the request.
Returns:
The parsed JSON response data.
Raises:
ExampleAPIException: If the API request fails.
"""
if not response.ok:
try:
error_data = response.json()
error_message = error_data.get("error", {}).get("message", "")
except JSONDecodeError:
error_message = response.text
raise ExampleAPIException(
f"Example API request failed ({response.status_code}): {error_message}",
response.status_code,
)
response_data = response.json()
if "errors" in response_data:
# This is an example error and needs to be
# replaced with how the real API returns errors
error_messages = [
error.get("message", "") for error in response_data["errors"]
]
raise ExampleAPIException(
f"Example API returned errors: {', '.join(error_messages)}",
response.status_code,
)
return response_data
def get_resource(self, resource_id: str) -> GetResourceResponse:
"""
Fetches a resource from the Example API.
Args:
resource_id: The ID of the resource to fetch.
Returns:
The resource data as a GetResourceResponse object.
Raises:
ExampleAPIException: If the API request fails.
"""
try:
response = self._requests.get(
f"{self.API_BASE_URL}/resources/{resource_id}"
)
return GetResourceResponse(**self._handle_response(response))
except Exception as e:
raise ExampleAPIException(f"Failed to get resource: {str(e)}", 500)
def create_resource(self, data: dict) -> CreateResourceResponse:
"""
Creates a new resource via the Example API.
Args:
data: The resource data to create.
Returns:
The created resource data as a CreateResourceResponse object.
Raises:
ExampleAPIException: If the API request fails.
"""
try:
response = self._requests.post(f"{self.API_BASE_URL}/resources", json=data)
return CreateResourceResponse(**self._handle_response(response))
except Exception as e:
raise ExampleAPIException(f"Failed to create resource: {str(e)}", 500)

View File

@@ -1,37 +0,0 @@
"""
Authentication module for Example API integration.
This module provides credential types and test credentials for the Example API integration.
It defines the structure for API key credentials used to authenticate with the Example API
and provides mock credentials for testing purposes.
"""
from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsMetaInput
from backend.integrations.providers import ProviderName
# Define the type of credentials input expected for Example API
ExampleCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.EXAMPLE_PROVIDER], Literal["api_key"]
]
# Mock credentials for testing Example API integration
TEST_CREDENTIALS = APIKeyCredentials(
id="9191c4f0-498f-4235-a79c-59c0e37454d4",
provider="example-provider",
api_key=SecretStr("mock-example-api-key"),
title="Mock Example API key",
expires_at=None,
)
# Dictionary representation of test credentials for input fields
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}

View File

@@ -1,154 +0,0 @@
import logging
from pydantic import BaseModel
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
from ._api import ExampleClient
from ._auth import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, ExampleCredentialsInput
logger = logging.getLogger(__name__)
class GreetingMessage(BaseModel):
message: str
is_funny: bool
class ExampleBlock(Block):
class Input(BlockSchema):
name: str = SchemaField(
description="The name of the example block", placeholder="Enter a name"
)
greetings: list[str] = SchemaField(
description="The greetings to display", default=["Hello", "Hi", "Hey"]
)
is_funny: bool = SchemaField(
description="Whether the block is funny",
placeholder="True",
default=True,
# Advanced fields are moved to the "Advanced" dropdown in the UI
advanced=True,
)
greeting_context: str = SchemaField(
description="The context of the greeting",
placeholder="Enter a context",
default="The user is looking for an inspirational greeting",
# Hidden fields are not shown in the UI at all
hidden=True,
)
# Only if the block needs credentials
credentials: ExampleCredentialsInput = CredentialsField(
description="The credentials for the example block"
)
class Output(BlockSchema):
response: GreetingMessage = SchemaField(
description="The response object generated by the example block."
)
all_responses: list[GreetingMessage] = SchemaField(
description="All the responses from the example block."
)
greeting_count: int = SchemaField(
description="The number of greetings in the input."
)
error: str = SchemaField(description="The error from the example block")
def __init__(self):
super().__init__(
# The unique identifier for the block, this value will be persisted in the DB.
# It should be unique and constant across the application run.
# Use the UUID format for the ID.
id="380694d5-3b2e-4130-bced-b43752b70de9",
# The description of the block, explaining what the block does.
description="The example block",
# The set of categories that the block belongs to.
# Each category is an instance of BlockCategory Enum.
categories={BlockCategory.BASIC},
# The schema, defined as a Pydantic model, for the input data.
input_schema=ExampleBlock.Input,
# The schema, defined as a Pydantic model, for the output data.
output_schema=ExampleBlock.Output,
# The list or single sample input data for the block, for testing.
# This is an instance of the Input schema with sample values.
test_input={
"name": "Craig",
"greetings": ["Hello", "Hi", "Hey"],
"is_funny": True,
"credentials": TEST_CREDENTIALS_INPUT,
},
# The list or single expected output if the test_input is run.
# Each output is a tuple of (output_name, output_data).
test_output=[
("response", GreetingMessage(message="Hello, world!", is_funny=True)),
(
"response",
GreetingMessage(message="Hello, world!", is_funny=True),
), # We mock the function
(
"response",
GreetingMessage(message="Hello, world!", is_funny=True),
), # We mock the function
(
"all_responses",
[
GreetingMessage(message="Hello, world!", is_funny=True),
GreetingMessage(message="Hello, world!", is_funny=True),
GreetingMessage(message="Hello, world!", is_funny=True),
],
),
("greeting_count", 3),
],
# Function names on the block implementation to mock on test run.
# Each mock is a dictionary with function names as keys and mock implementations as values.
test_mock={
"my_function_that_can_be_mocked": lambda *args, **kwargs: GreetingMessage(
message="Hello, world!", is_funny=True
)
},
# The credentials required for testing the block.
# This is an instance of APIKeyCredentials with sample values.
test_credentials=TEST_CREDENTIALS,
)
def my_function_that_can_be_mocked(
self, name: str, credentials: APIKeyCredentials
) -> GreetingMessage:
logger.info("my_function_that_can_be_mocked called with input: %s", name)
# Use the ExampleClient from _api.py to make an API call
client = ExampleClient(credentials=credentials)
# Create a sample resource using the client
resource_data = {"name": name, "type": "greeting"}
# If your API response object matches the return type of the function,
# there is no need to convert the object. In this case we have a different
# object type for the response and the return type of the function.
return GreetingMessage(**client.create_resource(resource_data).model_dump())
def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
"""
The run function implements the block's core logic. It processes the input_data
and yields the block's output.
In addition to credentials, the following parameters can be specified:
graph_id: The ID of the graph containing this block.
node_id: The ID of this block's node in the graph.
graph_exec_id: The ID of the current graph execution.
node_exec_id: The ID of the current node execution.
user_id: The ID of the user executing the block.
"""
rtn_all_responses: list[GreetingMessage] = []
# Here we deomonstrate best practice for blocks that need to yield multiple items.
# We yield each item from the list to allow for operations on each element.
# We also yield the complete list for situations when the full list is needed.
for greeting in input_data.greetings:
message = self.my_function_that_can_be_mocked(greeting, credentials)
rtn_all_responses.append(message)
yield "response", message
yield "all_responses", rtn_all_responses
yield "greeting_count", len(input_data.greetings)

View File

@@ -1,65 +0,0 @@
import logging
from backend.data.block import (
Block,
BlockCategory,
BlockManualWebhookConfig,
BlockOutput,
BlockSchema,
)
from backend.data.model import SchemaField
from backend.integrations.webhooks.example import ExampleWebhookEventType
logger = logging.getLogger(__name__)
class ExampleTriggerBlock(Block):
"""
A trigger block that is activated by an external webhook event.
Unlike standard blocks that are manually executed, trigger blocks are automatically
activated when a webhook event is received from the specified provider.
"""
class Input(BlockSchema):
# The payload field is hidden because it's automatically populated by the webhook
# system rather than being manually entered by the user
payload: dict = SchemaField(hidden=True)
class Output(BlockSchema):
event_data: dict = SchemaField(
description="The contents of the example webhook event."
)
def __init__(self):
super().__init__(
id="7c5933ce-d60c-42dd-9c4e-db82496474a3",
description="This block will output the contents of an example webhook event.",
categories={BlockCategory.BASIC},
input_schema=ExampleTriggerBlock.Input,
output_schema=ExampleTriggerBlock.Output,
# The webhook_config is a key difference from standard blocks
# It defines which external service can trigger this block and what type of events it responds to
webhook_config=BlockManualWebhookConfig(
provider="example_provider", # The external service that will send webhook events
webhook_type=ExampleWebhookEventType.EXAMPLE_EVENT, # The specific event type this block responds to
),
# Test input for trigger blocks should mimic the payload structure that would be received from the webhook
test_input=[
{
"payload": {
"event_type": "example",
"data": "Sample webhook data",
}
}
],
test_output=[
("event_data", {"event_type": "example", "data": "Sample webhook data"})
],
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
# For trigger blocks, the run method is called automatically when a webhook event is received
# The payload from the webhook is passed in as input_data.payload
logger.info("Example trigger block run with payload: %s", input_data.payload)
yield "event_data", input_data.payload

View File

@@ -38,59 +38,6 @@ def _get_headers(credentials: GithubCredentials) -> dict[str, str]:
}
def convert_comment_url_to_api_endpoint(comment_url: str) -> str:
"""
Converts a GitHub comment URL (web interface) to the appropriate API endpoint URL.
Handles:
1. Issue/PR comments: #issuecomment-{id}
2. PR review comments: #discussion_r{id}
Returns the appropriate API endpoint path for the comment.
"""
# First, check if this is already an API URL
parsed_url = urlparse(comment_url)
if parsed_url.hostname == "api.github.com":
return comment_url
# Replace pull with issues for comment endpoints
if "/pull/" in comment_url:
comment_url = comment_url.replace("/pull/", "/issues/")
# Handle issue/PR comments (#issuecomment-xxx)
if "#issuecomment-" in comment_url:
base_url, comment_part = comment_url.split("#issuecomment-")
comment_id = comment_part
# Extract repo information from base URL
parsed_url = urlparse(base_url)
path_parts = parsed_url.path.strip("/").split("/")
owner, repo = path_parts[0], path_parts[1]
# Construct API URL for issue comments
return (
f"https://api.github.com/repos/{owner}/{repo}/issues/comments/{comment_id}"
)
# Handle PR review comments (#discussion_r)
elif "#discussion_r" in comment_url:
base_url, comment_part = comment_url.split("#discussion_r")
comment_id = comment_part
# Extract repo information from base URL
parsed_url = urlparse(base_url)
path_parts = parsed_url.path.strip("/").split("/")
owner, repo = path_parts[0], path_parts[1]
# Construct API URL for PR review comments
return (
f"https://api.github.com/repos/{owner}/{repo}/pulls/comments/{comment_id}"
)
# If no specific comment identifiers are found, use the general URL conversion
return _convert_to_api_url(comment_url)
def get_api(
credentials: GithubCredentials | GithubFineGrainedAPICredentials,
convert_urls: bool = True,

View File

@@ -172,9 +172,7 @@ class GithubCreateCheckRunBlock(Block):
data.output = output_data
check_runs_url = f"{repo_url}/check-runs"
response = api.post(
check_runs_url, data=data.model_dump_json(exclude_none=True)
)
response = api.post(check_runs_url)
result = response.json()
return {
@@ -325,9 +323,7 @@ class GithubUpdateCheckRunBlock(Block):
data.output = output_data
check_run_url = f"{repo_url}/check-runs/{check_run_id}"
response = api.patch(
check_run_url, data=data.model_dump_json(exclude_none=True)
)
response = api.patch(check_run_url)
result = response.json()
return {

View File

@@ -1,4 +1,3 @@
import logging
from urllib.parse import urlparse
from typing_extensions import TypedDict
@@ -6,7 +5,7 @@ from typing_extensions import TypedDict
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from ._api import convert_comment_url_to_api_endpoint, get_api
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
@@ -15,8 +14,6 @@ from ._auth import (
GithubCredentialsInput,
)
logger = logging.getLogger(__name__)
def is_github_url(url: str) -> bool:
return urlparse(url).netloc == "github.com"
@@ -111,228 +108,6 @@ class GithubCommentBlock(Block):
# --8<-- [end:GithubCommentBlockExample]
class GithubUpdateCommentBlock(Block):
class Input(BlockSchema):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
comment_url: str = SchemaField(
description="URL of the GitHub comment",
placeholder="https://github.com/owner/repo/issues/1#issuecomment-123456789",
default="",
advanced=False,
)
issue_url: str = SchemaField(
description="URL of the GitHub issue or pull request",
placeholder="https://github.com/owner/repo/issues/1",
default="",
)
comment_id: str = SchemaField(
description="ID of the GitHub comment",
placeholder="123456789",
default="",
)
comment: str = SchemaField(
description="Comment to update",
placeholder="Enter your comment",
)
class Output(BlockSchema):
id: int = SchemaField(description="ID of the updated comment")
url: str = SchemaField(description="URL to the comment on GitHub")
error: str = SchemaField(
description="Error message if the comment update failed"
)
def __init__(self):
super().__init__(
id="b3f4d747-10e3-4e69-8c51-f2be1d99c9a7",
description="This block updates a comment on a specified GitHub issue or pull request.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubUpdateCommentBlock.Input,
output_schema=GithubUpdateCommentBlock.Output,
test_input={
"comment_url": "https://github.com/owner/repo/issues/1#issuecomment-123456789",
"comment": "This is an updated comment.",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("id", 123456789),
(
"url",
"https://github.com/owner/repo/issues/1#issuecomment-123456789",
),
],
test_mock={
"update_comment": lambda *args, **kwargs: (
123456789,
"https://github.com/owner/repo/issues/1#issuecomment-123456789",
)
},
)
@staticmethod
def update_comment(
credentials: GithubCredentials, comment_url: str, body_text: str
) -> tuple[int, str]:
api = get_api(credentials, convert_urls=False)
data = {"body": body_text}
url = convert_comment_url_to_api_endpoint(comment_url)
logger.info(url)
response = api.patch(url, json=data)
comment = response.json()
return comment["id"], comment["html_url"]
def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
if (
not input_data.comment_url
and input_data.comment_id
and input_data.issue_url
):
parsed_url = urlparse(input_data.issue_url)
path_parts = parsed_url.path.strip("/").split("/")
owner, repo = path_parts[0], path_parts[1]
input_data.comment_url = f"https://api.github.com/repos/{owner}/{repo}/issues/comments/{input_data.comment_id}"
elif (
not input_data.comment_url
and not input_data.comment_id
and input_data.issue_url
):
raise ValueError(
"Must provide either comment_url or comment_id and issue_url"
)
id, url = self.update_comment(
credentials,
input_data.comment_url,
input_data.comment,
)
yield "id", id
yield "url", url
class GithubListCommentsBlock(Block):
class Input(BlockSchema):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
issue_url: str = SchemaField(
description="URL of the GitHub issue or pull request",
placeholder="https://github.com/owner/repo/issues/1",
)
class Output(BlockSchema):
class CommentItem(TypedDict):
id: int
body: str
user: str
url: str
comment: CommentItem = SchemaField(
title="Comment", description="Comments with their ID, body, user, and URL"
)
comments: list[CommentItem] = SchemaField(
description="List of comments with their ID, body, user, and URL"
)
error: str = SchemaField(description="Error message if listing comments failed")
def __init__(self):
super().__init__(
id="c4b5fb63-0005-4a11-b35a-0c2467bd6b59",
description="This block lists all comments for a specified GitHub issue or pull request.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListCommentsBlock.Input,
output_schema=GithubListCommentsBlock.Output,
test_input={
"issue_url": "https://github.com/owner/repo/issues/1",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"comment",
{
"id": 123456789,
"body": "This is a test comment.",
"user": "test_user",
"url": "https://github.com/owner/repo/issues/1#issuecomment-123456789",
},
),
(
"comments",
[
{
"id": 123456789,
"body": "This is a test comment.",
"user": "test_user",
"url": "https://github.com/owner/repo/issues/1#issuecomment-123456789",
}
],
),
],
test_mock={
"list_comments": lambda *args, **kwargs: [
{
"id": 123456789,
"body": "This is a test comment.",
"user": "test_user",
"url": "https://github.com/owner/repo/issues/1#issuecomment-123456789",
}
]
},
)
@staticmethod
def list_comments(
credentials: GithubCredentials, issue_url: str
) -> list[Output.CommentItem]:
parsed_url = urlparse(issue_url)
path_parts = parsed_url.path.strip("/").split("/")
owner = path_parts[0]
repo = path_parts[1]
# GitHub API uses 'issues' for both issues and pull requests when it comes to comments
issue_number = path_parts[3] # Whether 'issues/123' or 'pull/123'
# Construct the proper API URL directly
api_url = f"https://api.github.com/repos/{owner}/{repo}/issues/{issue_number}/comments"
# Set convert_urls=False since we're already providing an API URL
api = get_api(credentials, convert_urls=False)
response = api.get(api_url)
comments = response.json()
parsed_comments: list[GithubListCommentsBlock.Output.CommentItem] = [
{
"id": comment["id"],
"body": comment["body"],
"user": comment["user"]["login"],
"url": comment["html_url"],
}
for comment in comments
]
return parsed_comments
def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
comments = self.list_comments(
credentials,
input_data.issue_url,
)
yield from (("comment", comment) for comment in comments)
yield "comments", comments
class GithubMakeIssueBlock(Block):
class Input(BlockSchema):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")

View File

@@ -144,7 +144,7 @@ class GithubCreateStatusBlock(Block):
data.description = description
status_url = f"{repo_url}/statuses/{sha}"
response = api.post(status_url, data=data.model_dump_json(exclude_none=True))
response = api.post(status_url, json=data)
result = response.json()
return {

View File

@@ -8,7 +8,6 @@ from pydantic import BaseModel
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.settings import Settings
from ._auth import (
GOOGLE_OAUTH_IS_CONFIGURED,
@@ -151,8 +150,8 @@ class GmailReadBlock(Block):
else None
),
token_uri="https://oauth2.googleapis.com/token",
client_id=Settings().secrets.google_client_id,
client_secret=Settings().secrets.google_client_secret,
client_id=kwargs.get("client_id"),
client_secret=kwargs.get("client_secret"),
scopes=credentials.scopes,
)
return build("gmail", "v1", credentials=creds)

View File

@@ -3,7 +3,6 @@ from googleapiclient.discovery import build
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.settings import Settings
from ._auth import (
GOOGLE_OAUTH_IS_CONFIGURED,
@@ -87,8 +86,8 @@ class GoogleSheetsReadBlock(Block):
else None
),
token_uri="https://oauth2.googleapis.com/token",
client_id=Settings().secrets.google_client_id,
client_secret=Settings().secrets.google_client_secret,
client_id=kwargs.get("client_id"),
client_secret=kwargs.get("client_secret"),
scopes=credentials.scopes,
)
return build("sheets", "v4", credentials=creds)

View File

@@ -142,16 +142,6 @@ class IdeogramModelBlock(Block):
title="Color Palette Preset",
advanced=True,
)
custom_color_palette: Optional[list[str]] = SchemaField(
description=(
"Only available for model version V_2 or V_2_TURBO. Provide one or more color hex codes "
"(e.g., ['#000030', '#1C0C47', '#9900FF', '#4285F4', '#FFFFFF']) to define a custom color "
"palette. Only used if 'color_palette_name' is 'NONE'."
),
default=None,
title="Custom Color Palette",
advanced=True,
)
class Output(BlockSchema):
result: str = SchemaField(description="Generated image URL")
@@ -174,13 +164,6 @@ class IdeogramModelBlock(Block):
"style_type": StyleType.AUTO,
"negative_prompt": None,
"color_palette_name": ColorPalettePreset.NONE,
"custom_color_palette": [
"#000030",
"#1C0C47",
"#9900FF",
"#4285F4",
"#FFFFFF",
],
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
@@ -190,7 +173,7 @@ class IdeogramModelBlock(Block):
),
],
test_mock={
"run_model": lambda api_key, model_name, prompt, seed, aspect_ratio, magic_prompt_option, style_type, negative_prompt, color_palette_name, custom_colors: "https://ideogram.ai/api/images/test-generated-image-url.png",
"run_model": lambda api_key, model_name, prompt, seed, aspect_ratio, magic_prompt_option, style_type, negative_prompt, color_palette_name: "https://ideogram.ai/api/images/test-generated-image-url.png",
"upscale_image": lambda api_key, image_url: "https://ideogram.ai/api/images/test-upscaled-image-url.png",
},
test_credentials=TEST_CREDENTIALS,
@@ -212,7 +195,6 @@ class IdeogramModelBlock(Block):
style_type=input_data.style_type.value,
negative_prompt=input_data.negative_prompt,
color_palette_name=input_data.color_palette_name.value,
custom_colors=input_data.custom_color_palette,
)
# Step 2: Upscale the image if requested
@@ -235,7 +217,6 @@ class IdeogramModelBlock(Block):
style_type: str,
negative_prompt: Optional[str],
color_palette_name: str,
custom_colors: Optional[list[str]],
):
url = "https://api.ideogram.ai/generate"
headers = {
@@ -260,11 +241,7 @@ class IdeogramModelBlock(Block):
data["image_request"]["negative_prompt"] = negative_prompt
if color_palette_name != "NONE":
data["color_palette"] = {"name": color_palette_name}
elif custom_colors:
data["color_palette"] = {
"members": [{"color_hex": color} for color in custom_colors]
}
data["image_request"]["color_palette"] = {"name": color_palette_name}
try:
response = requests.post(url, json=data, headers=headers)
@@ -290,7 +267,9 @@ class IdeogramModelBlock(Block):
response = requests.post(
url,
headers=headers,
data={"image_request": "{}"},
data={
"image_request": "{}", # Empty JSON object
},
files=files,
)

View File

@@ -4,11 +4,10 @@ from abc import ABC
from enum import Enum, EnumMeta
from json import JSONDecodeError
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Iterable, List, Literal, NamedTuple, Optional
from typing import TYPE_CHECKING, Any, List, Literal, NamedTuple
from pydantic import BaseModel, SecretStr
from pydantic import SecretStr
from backend.data.model import NodeExecutionStats
from backend.integrations.providers import ProviderName
if TYPE_CHECKING:
@@ -17,8 +16,6 @@ if TYPE_CHECKING:
import anthropic
import ollama
import openai
from anthropic._types import NotGiven
from anthropic.types import ToolParam
from groq import Groq
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
@@ -230,299 +227,15 @@ for model in LlmModel:
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
class ToolCall(BaseModel):
name: str
arguments: str
class MessageRole(str, Enum):
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
class ToolContentBlock(BaseModel):
id: str
type: str
function: ToolCall
class LLMResponse(BaseModel):
raw_response: Any
prompt: List[Any]
response: str
tool_calls: Optional[List[ToolContentBlock]] | None
prompt_tokens: int
completion_tokens: int
def convert_openai_tool_fmt_to_anthropic(
openai_tools: list[dict] | None = None,
) -> Iterable[ToolParam] | NotGiven:
"""
Convert OpenAI tool format to Anthropic tool format.
"""
if not openai_tools or len(openai_tools) == 0:
return anthropic.NOT_GIVEN
anthropic_tools = []
for tool in openai_tools:
if "function" in tool:
# Handle case where tool is already in OpenAI format with "type" and "function"
function_data = tool["function"]
else:
# Handle case where tool is just the function definition
function_data = tool
anthropic_tool: anthropic.types.ToolParam = {
"name": function_data["name"],
"description": function_data.get("description", ""),
"input_schema": {
"type": "object",
"properties": function_data.get("parameters", {}).get("properties", {}),
"required": function_data.get("parameters", {}).get("required", []),
},
}
anthropic_tools.append(anthropic_tool)
return anthropic_tools
def llm_call(
credentials: APIKeyCredentials,
llm_model: LlmModel,
prompt: list[dict],
json_format: bool,
max_tokens: int | None,
tools: list[dict] | None = None,
ollama_host: str = "localhost:11434",
) -> LLMResponse:
"""
Make a call to a language model.
Args:
credentials: The API key credentials to use.
llm_model: The LLM model to use.
prompt: The prompt to send to the LLM.
json_format: Whether the response should be in JSON format.
max_tokens: The maximum number of tokens to generate in the chat completion.
tools: The tools to use in the chat completion.
ollama_host: The host for ollama to use.
Returns:
LLMResponse object containing:
- prompt: The prompt sent to the LLM.
- response: The text response from the LLM.
- tool_calls: Any tool calls the model made, if applicable.
- prompt_tokens: The number of tokens used in the prompt.
- completion_tokens: The number of tokens used in the completion.
"""
provider = llm_model.metadata.provider
max_tokens = max_tokens or llm_model.max_output_tokens or 4096
if provider == "openai":
tools_param = tools if tools else openai.NOT_GIVEN
oai_client = openai.OpenAI(api_key=credentials.api_key.get_secret_value())
response_format = None
if llm_model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
prompt = [
{"role": "user", "content": "\n".join(sys_messages)},
{"role": "user", "content": "\n".join(usr_messages)},
]
elif json_format:
response_format = {"type": "json_object"}
response = oai_client.chat.completions.create(
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_completion_tokens=max_tokens,
tools=tools_param, # type: ignore
)
if response.choices[0].message.tool_calls:
tool_calls = [
ToolContentBlock(
id=tool.id,
type=tool.type,
function=ToolCall(
name=tool.function.name,
arguments=tool.function.arguments,
),
)
for tool in response.choices[0].message.tool_calls
]
else:
tool_calls = None
return LLMResponse(
raw_response=response.choices[0].message,
prompt=prompt,
response=response.choices[0].message.content or "",
tool_calls=tool_calls,
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
)
elif provider == "anthropic":
an_tools = convert_openai_tool_fmt_to_anthropic(tools)
system_messages = [p["content"] for p in prompt if p["role"] == "system"]
sysprompt = " ".join(system_messages)
messages = []
last_role = None
for p in prompt:
if p["role"] in ["user", "assistant"]:
if (
p["role"] == last_role
and isinstance(messages[-1]["content"], str)
and isinstance(p["content"], str)
):
# If the role is the same as the last one, combine the content
messages[-1]["content"] += p["content"]
else:
messages.append({"role": p["role"], "content": p["content"]})
last_role = p["role"]
client = anthropic.Anthropic(api_key=credentials.api_key.get_secret_value())
try:
resp = client.messages.create(
model=llm_model.value,
system=sysprompt,
messages=messages,
max_tokens=max_tokens,
tools=an_tools,
)
if not resp.content:
raise ValueError("No content returned from Anthropic.")
tool_calls = None
for content_block in resp.content:
# Antropic is different to openai, need to iterate through
# the content blocks to find the tool calls
if content_block.type == "tool_use":
if tool_calls is None:
tool_calls = []
tool_calls.append(
ToolContentBlock(
id=content_block.id,
type=content_block.type,
function=ToolCall(
name=content_block.name,
arguments=json.dumps(content_block.input),
),
)
)
if not tool_calls and resp.stop_reason == "tool_use":
logger.warning(
"Tool use stop reason but no tool calls found in content. %s", resp
)
return LLMResponse(
raw_response=resp,
prompt=prompt,
response=(
resp.content[0].name
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
else resp.content[0].text
),
tool_calls=tool_calls,
prompt_tokens=resp.usage.input_tokens,
completion_tokens=resp.usage.output_tokens,
)
except anthropic.APIError as e:
error_message = f"Anthropic API error: {str(e)}"
logger.error(error_message)
raise ValueError(error_message)
elif provider == "groq":
if tools:
raise ValueError("Groq does not support tools.")
client = Groq(api_key=credentials.api_key.get_secret_value())
response_format = {"type": "json_object"} if json_format else None
response = client.chat.completions.create(
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_tokens=max_tokens,
)
return LLMResponse(
raw_response=response.choices[0].message,
prompt=prompt,
response=response.choices[0].message.content or "",
tool_calls=None,
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
)
elif provider == "ollama":
if tools:
raise ValueError("Ollama does not support tools.")
client = ollama.Client(host=ollama_host)
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
response = client.generate(
model=llm_model.value,
prompt=f"{sys_messages}\n\n{usr_messages}",
stream=False,
)
return LLMResponse(
raw_response=response.get("response") or "",
prompt=prompt,
response=response.get("response") or "",
tool_calls=None,
prompt_tokens=response.get("prompt_eval_count") or 0,
completion_tokens=response.get("eval_count") or 0,
)
elif provider == "open_router":
tools_param = tools if tools else openai.NOT_GIVEN
client = openai.OpenAI(
base_url="https://openrouter.ai/api/v1",
api_key=credentials.api_key.get_secret_value(),
)
response = client.chat.completions.create(
extra_headers={
"HTTP-Referer": "https://agpt.co",
"X-Title": "AutoGPT",
},
model=llm_model.value,
messages=prompt, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore
)
# If there's no response, raise an error
if not response.choices:
if response:
raise ValueError(f"OpenRouter error: {response}")
else:
raise ValueError("No response from OpenRouter.")
if response.choices[0].message.tool_calls:
tool_calls = [
ToolContentBlock(
id=tool.id,
type=tool.type,
function=ToolCall(
name=tool.function.name, arguments=tool.function.arguments
),
)
for tool in response.choices[0].message.tool_calls
]
else:
tool_calls = None
return LLMResponse(
raw_response=response.choices[0].message,
prompt=prompt,
response=response.choices[0].message.content or "",
tool_calls=tool_calls,
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
)
else:
raise ValueError(f"Unsupported LLM provider: {provider}")
class Message(BlockSchema):
role: MessageRole
content: str
class AIBlockBase(Block, ABC):
@@ -547,7 +260,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4O,
default=LlmModel.GPT4_TURBO,
description="The language model to use for answering the prompt.",
advanced=False,
)
@@ -557,7 +270,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
default="",
description="The system prompt to provide additional context to the model.",
)
conversation_history: list[dict] = SchemaField(
conversation_history: list[Message] = SchemaField(
default=[],
description="The conversation history to provide context for the prompt.",
)
@@ -598,7 +311,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
input_schema=AIStructuredResponseGeneratorBlock.Input,
output_schema=AIStructuredResponseGeneratorBlock.Output,
test_input={
"model": LlmModel.GPT4O,
"model": LlmModel.GPT4_TURBO,
"credentials": TEST_CREDENTIALS_INPUT,
"expected_format": {
"key1": "value1",
@@ -612,21 +325,19 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
("prompt", str),
],
test_mock={
"llm_call": lambda *args, **kwargs: LLMResponse(
raw_response="",
prompt=[""],
response=json.dumps(
"llm_call": lambda *args, **kwargs: (
json.dumps(
{
"key1": "key1Value",
"key2": "key2Value",
}
),
tool_calls=None,
prompt_tokens=0,
completion_tokens=0,
0,
0,
)
},
)
self.prompt = ""
def llm_call(
self,
@@ -635,28 +346,160 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
prompt: list[dict],
json_format: bool,
max_tokens: int | None,
tools: list[dict] | None = None,
ollama_host: str = "localhost:11434",
) -> LLMResponse:
) -> tuple[str, int, int]:
"""
Test mocks work only on class functions, this wraps the llm_call function
so that it can be mocked withing the block testing framework.
Args:
credentials: The API key credentials to use.
llm_model: The LLM model to use.
prompt: The prompt to send to the LLM.
json_format: Whether the response should be in JSON format.
max_tokens: The maximum number of tokens to generate in the chat completion.
ollama_host: The host for ollama to use
Returns:
The response from the LLM.
The number of tokens used in the prompt.
The number of tokens used in the completion.
"""
return llm_call(
credentials=credentials,
llm_model=llm_model,
prompt=prompt,
json_format=json_format,
max_tokens=max_tokens,
tools=tools,
ollama_host=ollama_host,
)
provider = llm_model.metadata.provider
max_tokens = max_tokens or llm_model.max_output_tokens or 4096
if provider == "openai":
oai_client = openai.OpenAI(api_key=credentials.api_key.get_secret_value())
response_format = None
if llm_model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
prompt = [
{"role": "user", "content": "\n".join(sys_messages)},
{"role": "user", "content": "\n".join(usr_messages)},
]
elif json_format:
response_format = {"type": "json_object"}
response = oai_client.chat.completions.create(
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_completion_tokens=max_tokens,
)
self.prompt = json.dumps(prompt)
return (
response.choices[0].message.content or "",
response.usage.prompt_tokens if response.usage else 0,
response.usage.completion_tokens if response.usage else 0,
)
elif provider == "anthropic":
system_messages = [p["content"] for p in prompt if p["role"] == "system"]
sysprompt = " ".join(system_messages)
messages = []
last_role = None
for p in prompt:
if p["role"] in ["user", "assistant"]:
if p["role"] != last_role:
messages.append({"role": p["role"], "content": p["content"]})
last_role = p["role"]
else:
# If the role is the same as the last one, combine the content
messages[-1]["content"] += "\n" + p["content"]
client = anthropic.Anthropic(api_key=credentials.api_key.get_secret_value())
try:
resp = client.messages.create(
model=llm_model.value,
system=sysprompt,
messages=messages,
max_tokens=max_tokens,
)
self.prompt = json.dumps(prompt)
if not resp.content:
raise ValueError("No content returned from Anthropic.")
return (
(
resp.content[0].name
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
else resp.content[0].text
),
resp.usage.input_tokens,
resp.usage.output_tokens,
)
except anthropic.APIError as e:
error_message = f"Anthropic API error: {str(e)}"
logger.error(error_message)
raise ValueError(error_message)
elif provider == "groq":
client = Groq(api_key=credentials.api_key.get_secret_value())
response_format = {"type": "json_object"} if json_format else None
response = client.chat.completions.create(
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_tokens=max_tokens,
)
self.prompt = json.dumps(prompt)
return (
response.choices[0].message.content or "",
response.usage.prompt_tokens if response.usage else 0,
response.usage.completion_tokens if response.usage else 0,
)
elif provider == "ollama":
client = ollama.Client(host=ollama_host)
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
response = client.generate(
model=llm_model.value,
prompt=f"{sys_messages}\n\n{usr_messages}",
stream=False,
)
self.prompt = json.dumps(prompt)
return (
response.get("response") or "",
response.get("prompt_eval_count") or 0,
response.get("eval_count") or 0,
)
elif provider == "open_router":
client = openai.OpenAI(
base_url="https://openrouter.ai/api/v1",
api_key=credentials.api_key.get_secret_value(),
)
response = client.chat.completions.create(
extra_headers={
"HTTP-Referer": "https://agpt.co",
"X-Title": "AutoGPT",
},
model=llm_model.value,
messages=prompt, # type: ignore
max_tokens=max_tokens,
)
self.prompt = json.dumps(prompt)
# If there's no response, raise an error
if not response.choices:
if response:
raise ValueError(f"OpenRouter error: {response}")
else:
raise ValueError("No response from OpenRouter.")
return (
response.choices[0].message.content or "",
response.usage.prompt_tokens if response.usage else 0,
response.usage.completion_tokens if response.usage else 0,
)
else:
raise ValueError(f"Unsupported LLM provider: {provider}")
def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
logger.debug(f"Calling LLM with input data: {input_data}")
prompt = [json.to_dict(p) for p in input_data.conversation_history]
prompt = [p.model_dump() for p in input_data.conversation_history]
def trim_prompt(s: str) -> str:
lines = s.strip().split("\n")
@@ -706,7 +549,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
for retry_count in range(input_data.retry):
try:
llm_response = self.llm_call(
response_text, input_token, output_token = self.llm_call(
credentials=credentials,
llm_model=llm_model,
prompt=prompt,
@@ -714,12 +557,11 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
ollama_host=input_data.ollama_host,
max_tokens=input_data.max_tokens,
)
response_text = llm_response.response
self.merge_stats(
NodeExecutionStats(
input_token_count=llm_response.prompt_tokens,
output_token_count=llm_response.completion_tokens,
)
{
"input_token_count": input_token,
"output_token_count": output_token,
}
)
logger.info(f"LLM attempt-{retry_count} response: {response_text}")
@@ -762,10 +604,10 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
retry_prompt = f"Error calling LLM: {e}"
finally:
self.merge_stats(
NodeExecutionStats(
llm_call_count=retry_count + 1,
llm_retry_count=retry_count,
)
{
"llm_call_count": retry_count + 1,
"llm_retry_count": retry_count,
}
)
raise RuntimeError(retry_prompt)
@@ -779,7 +621,7 @@ class AITextGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4O,
default=LlmModel.GPT4_TURBO,
description="The language model to use for answering the prompt.",
advanced=False,
)
@@ -872,7 +714,7 @@ class AITextSummarizerBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4O,
default=LlmModel.GPT4_TURBO,
description="The language model to use for summarizing the text.",
)
focus: str = SchemaField(
@@ -1033,12 +875,12 @@ class AITextSummarizerBlock(AIBlockBase):
class AIConversationBlock(AIBlockBase):
class Input(BlockSchema):
messages: List[Any] = SchemaField(
messages: List[Message] = SchemaField(
description="List of messages in the conversation.", min_length=1
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4O,
default=LlmModel.GPT4_TURBO,
description="The language model to use for the conversation.",
)
credentials: AICredentials = AICredentialsField()
@@ -1077,7 +919,7 @@ class AIConversationBlock(AIBlockBase):
},
{"role": "user", "content": "Where was it played?"},
],
"model": LlmModel.GPT4O,
"model": LlmModel.GPT4_TURBO,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
@@ -1139,7 +981,7 @@ class AIListGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4O,
default=LlmModel.GPT4_TURBO,
description="The language model to use for generating the list.",
advanced=True,
)
@@ -1188,7 +1030,7 @@ class AIListGeneratorBlock(AIBlockBase):
"drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
"fictional worlds."
),
"model": LlmModel.GPT4O,
"model": LlmModel.GPT4_TURBO,
"credentials": TEST_CREDENTIALS_INPUT,
"max_retries": 3,
},

View File

@@ -1,511 +0,0 @@
import logging
import re
from collections import Counter
from typing import TYPE_CHECKING, Any
from autogpt_libs.utils.cache import thread_cached
import backend.blocks.llm as llm
from backend.blocks.agent import AgentExecutorBlock
from backend.data.block import (
Block,
BlockCategory,
BlockInput,
BlockOutput,
BlockSchema,
BlockType,
get_block,
)
from backend.data.model import SchemaField
from backend.util import json
if TYPE_CHECKING:
from backend.data.graph import Link, Node
logger = logging.getLogger(__name__)
@thread_cached
def get_database_manager_client():
from backend.executor import DatabaseManager
from backend.util.service import get_service_client
return get_service_client(DatabaseManager)
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
"""
Return a list of tool_call_ids if the entry is a tool request.
Supports both OpenAI and Anthropics formats.
"""
tool_call_ids = []
if entry.get("role") != "assistant":
return tool_call_ids
# OpenAI: check for tool_calls in the entry.
calls = entry.get("tool_calls")
if isinstance(calls, list):
for call in calls:
if tool_id := call.get("id"):
tool_call_ids.append(tool_id)
# Anthropics: check content items for tool_use type.
content = entry.get("content")
if isinstance(content, list):
for item in content:
if item.get("type") != "tool_use":
continue
if tool_id := item.get("id"):
tool_call_ids.append(tool_id)
return tool_call_ids
def _get_tool_responses(entry: dict[str, Any]) -> list[str]:
"""
Return a list of tool_call_ids if the entry is a tool response.
Supports both OpenAI and Anthropics formats.
"""
tool_call_ids: list[str] = []
# OpenAI: a tool response message with role "tool" and key "tool_call_id".
if entry.get("role") == "tool":
if tool_call_id := entry.get("tool_call_id"):
tool_call_ids.append(str(tool_call_id))
# Anthropics: check content items for tool_result type.
if entry.get("role") == "user":
content = entry.get("content")
if isinstance(content, list):
for item in content:
if item.get("type") != "tool_result":
continue
if tool_call_id := item.get("tool_use_id"):
tool_call_ids.append(tool_call_id)
return tool_call_ids
def _create_tool_response(call_id: str, output: dict[str, Any]) -> dict[str, Any]:
"""
Create a tool response message for either OpenAI or Anthropics,
based on the tool_id format.
"""
content = output if isinstance(output, str) else json.dumps(output)
# Anthropics format: tool IDs typically start with "toolu_"
if call_id.startswith("toolu_"):
return {
"role": "user",
"type": "message",
"content": [
{"tool_use_id": call_id, "type": "tool_result", "content": content}
],
}
# OpenAI format: tool IDs typically start with "call_".
# Or default fallback (if the tool_id doesn't match any known prefix)
return {"role": "tool", "tool_call_id": call_id, "content": content}
def get_pending_tool_calls(conversation_history: list[Any]) -> dict[str, int]:
"""
All the tool calls entry in the conversation history requires a response.
This function returns the pending tool calls that has not generated an output yet.
Return: dict[str, int] - A dictionary of pending tool call IDs with their count.
"""
pending_calls = Counter()
for history in conversation_history:
for call_id in _get_tool_requests(history):
pending_calls[call_id] += 1
for call_id in _get_tool_responses(history):
pending_calls[call_id] -= 1
return {call_id: count for call_id, count in pending_calls.items() if count > 0}
class SmartDecisionMakerBlock(Block):
"""
A block that uses a language model to make smart decisions based on a given prompt.
"""
class Input(BlockSchema):
prompt: str = SchemaField(
description="The prompt to send to the language model.",
placeholder="Enter your prompt here...",
)
model: llm.LlmModel = SchemaField(
title="LLM Model",
default=llm.LlmModel.GPT4O,
description="The language model to use for answering the prompt.",
advanced=False,
)
credentials: llm.AICredentials = llm.AICredentialsField()
sys_prompt: str = SchemaField(
title="System Prompt",
default="Thinking carefully step by step decide which function to call. "
"Always choose a function call from the list of function signatures, "
"and always provide the complete argument provided with the type "
"matching the required jsonschema signature, no missing argument is allowed. "
"If you have already completed the task objective, you can end the task "
"by providing the end result of your work as a finish message. "
"Only provide EXACTLY one function call, multiple tool calls is strictly prohibited.",
description="The system prompt to provide additional context to the model.",
)
conversation_history: list[dict] = SchemaField(
default=[],
description="The conversation history to provide context for the prompt.",
)
last_tool_output: Any = SchemaField(
default=None,
description="The output of the last tool that was called.",
)
retry: int = SchemaField(
title="Retry Count",
default=3,
description="Number of times to retry the LLM call if the response does not match the expected format.",
)
prompt_values: dict[str, str] = SchemaField(
advanced=False,
default={},
description="Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}.",
)
max_tokens: int | None = SchemaField(
advanced=True,
default=None,
description="The maximum number of tokens to generate in the chat completion.",
)
ollama_host: str = SchemaField(
advanced=True,
default="localhost:11434",
description="Ollama host for local models",
)
@classmethod
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
# conversation_history & last_tool_output validation is handled differently
missing_links = super().get_missing_links(
data,
[
link
for link in links
if link.sink_name
not in ["conversation_history", "last_tool_output"]
],
)
# Avoid executing the block if the last_tool_output is connected to a static
# link, like StoreValueBlock or AgentInputBlock.
if any(link.sink_name == "conversation_history" for link in links) and any(
link.sink_name == "last_tool_output" and link.is_static
for link in links
):
raise ValueError(
"Last Tool Output can't be connected to a static (dashed line) "
"link like the output of `StoreValue` or `AgentInput` block"
)
return missing_links
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
if missing_input := super().get_missing_input(data):
return missing_input
conversation_history = data.get("conversation_history", [])
pending_tool_calls = get_pending_tool_calls(conversation_history)
last_tool_output = data.get("last_tool_output")
if not last_tool_output and pending_tool_calls:
return {"last_tool_output"}
return set()
class Output(BlockSchema):
error: str = SchemaField(description="Error message if the API call failed.")
tools: Any = SchemaField(description="The tools that are available to use.")
finished: str = SchemaField(
description="The finished message to display to the user."
)
conversations: list[Any] = SchemaField(
description="The conversation history to provide context for the prompt."
)
def __init__(self):
super().__init__(
id="3b191d9f-356f-482d-8238-ba04b6d18381",
description="Uses AI to intelligently decide what tool to use.",
categories={BlockCategory.AI},
block_type=BlockType.AI,
input_schema=SmartDecisionMakerBlock.Input,
output_schema=SmartDecisionMakerBlock.Output,
test_input={
"prompt": "Hello, World!",
"credentials": llm.TEST_CREDENTIALS_INPUT,
},
test_output=[],
test_credentials=llm.TEST_CREDENTIALS,
)
@staticmethod
def _create_block_function_signature(
sink_node: "Node", links: list["Link"]
) -> dict[str, Any]:
"""
Creates a function signature for a block node.
Args:
sink_node: The node for which to create a function signature.
links: The list of links connected to the sink node.
Returns:
A dictionary representing the function signature in the format expected by LLM tools.
Raises:
ValueError: If the block specified by sink_node.block_id is not found.
"""
block = get_block(sink_node.block_id)
if not block:
raise ValueError(f"Block not found: {sink_node.block_id}")
tool_function: dict[str, Any] = {
"name": re.sub(r"[^a-zA-Z0-9_-]", "_", block.name).lower(),
"description": block.description,
}
properties = {}
required = []
for link in links:
sink_block_input_schema = block.input_schema
description = (
sink_block_input_schema.model_fields[link.sink_name].description
if link.sink_name in sink_block_input_schema.model_fields
and sink_block_input_schema.model_fields[link.sink_name].description
else f"The {link.sink_name} of the tool"
)
properties[link.sink_name.lower()] = {
"type": "string",
"description": description,
}
tool_function["parameters"] = {
"type": "object",
"properties": properties,
"required": required,
"additionalProperties": False,
"strict": True,
}
return {"type": "function", "function": tool_function}
@staticmethod
def _create_agent_function_signature(
sink_node: "Node", links: list["Link"]
) -> dict[str, Any]:
"""
Creates a function signature for an agent node.
Args:
sink_node: The agent node for which to create a function signature.
links: The list of links connected to the sink node.
Returns:
A dictionary representing the function signature in the format expected by LLM tools.
Raises:
ValueError: If the graph metadata for the specified graph_id and graph_version is not found.
"""
graph_id = sink_node.input_default.get("graph_id")
graph_version = sink_node.input_default.get("graph_version")
if not graph_id or not graph_version:
raise ValueError("Graph ID or Graph Version not found in sink node.")
db_client = get_database_manager_client()
sink_graph_meta = db_client.get_graph_metadata(graph_id, graph_version)
if not sink_graph_meta:
raise ValueError(
f"Sink graph metadata not found: {graph_id} {graph_version}"
)
tool_function: dict[str, Any] = {
"name": re.sub(r"[^a-zA-Z0-9_-]", "_", sink_graph_meta.name).lower(),
"description": sink_graph_meta.description,
}
properties = {}
required = []
for link in links:
sink_block_input_schema = sink_node.input_default["input_schema"]
description = (
sink_block_input_schema["properties"][link.sink_name]["description"]
if "description"
in sink_block_input_schema["properties"][link.sink_name]
else f"The {link.sink_name} of the tool"
)
properties[link.sink_name.lower()] = {
"type": "string",
"description": description,
}
tool_function["parameters"] = {
"type": "object",
"properties": properties,
"required": required,
"additionalProperties": False,
"strict": True,
}
return {"type": "function", "function": tool_function}
@staticmethod
def _create_function_signature(node_id: str) -> list[dict[str, Any]]:
"""
Creates function signatures for tools linked to a specified node within a graph.
This method filters the graph links to identify those that are tools and are
connected to the given node_id. It then constructs function signatures for each
tool based on the metadata and input schema of the linked nodes.
Args:
node_id: The node_id for which to create function signatures.
Returns:
list[dict[str, Any]]: A list of dictionaries, each representing a function signature
for a tool, including its name, description, and parameters.
Raises:
ValueError: If no tool links are found for the specified node_id, or if a sink node
or its metadata cannot be found.
"""
db_client = get_database_manager_client()
tools = [
(link, node)
for link, node in db_client.get_connected_output_nodes(node_id)
if link.source_name.startswith("tools_^_") and link.source_id == node_id
]
if not tools:
raise ValueError("There is no next node to execute.")
return_tool_functions = []
grouped_tool_links: dict[str, tuple["Node", list["Link"]]] = {}
for link, node in tools:
if link.sink_id not in grouped_tool_links:
grouped_tool_links[link.sink_id] = (node, [link])
else:
grouped_tool_links[link.sink_id][1].append(link)
for sink_node, links in grouped_tool_links.values():
if not sink_node:
raise ValueError(f"Sink node not found: {links[0].sink_id}")
if sink_node.block_id == AgentExecutorBlock().id:
return_tool_functions.append(
SmartDecisionMakerBlock._create_agent_function_signature(
sink_node, links
)
)
else:
return_tool_functions.append(
SmartDecisionMakerBlock._create_block_function_signature(
sink_node, links
)
)
return return_tool_functions
def run(
self,
input_data: Input,
*,
credentials: llm.APIKeyCredentials,
graph_id: str,
node_id: str,
graph_exec_id: str,
node_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
tool_functions = self._create_function_signature(node_id)
input_data.conversation_history = input_data.conversation_history or []
prompt = [json.to_dict(p) for p in input_data.conversation_history if p]
pending_tool_calls = get_pending_tool_calls(input_data.conversation_history)
if pending_tool_calls and not input_data.last_tool_output:
raise ValueError(f"Tool call requires an output for {pending_tool_calls}")
# Prefill all missing tool calls with the last tool output/
# TODO: we need a better way to handle this.
tool_output = [
_create_tool_response(pending_call_id, input_data.last_tool_output)
for pending_call_id, count in pending_tool_calls.items()
for _ in range(count)
]
# If the SDM block only calls 1 tool at a time, this should not happen.
if len(tool_output) > 1:
logger.warning(
f"[SmartDecisionMakerBlock-node_exec_id={node_exec_id}] "
f"Multiple pending tool calls are prefilled using a single output. "
f"Execution may not be accurate."
)
# Fallback on adding tool output in the conversation history as user prompt.
if len(tool_output) == 0 and input_data.last_tool_output:
logger.warning(
f"[SmartDecisionMakerBlock-node_exec_id={node_exec_id}] "
f"No pending tool calls found. This may indicate an issue with the "
f"conversation history, or an LLM calling two tools at the same time."
)
tool_output.append(
{
"role": "user",
"content": f"Last tool output: {json.dumps(input_data.last_tool_output)}",
}
)
prompt.extend(tool_output)
values = input_data.prompt_values
if values:
input_data.prompt = llm.fmt.format_string(input_data.prompt, values)
input_data.sys_prompt = llm.fmt.format_string(input_data.sys_prompt, values)
prefix = "[Main Objective Prompt]: "
if input_data.sys_prompt and not any(
p["role"] == "system" and p["content"].startswith(prefix) for p in prompt
):
prompt.append({"role": "system", "content": prefix + input_data.sys_prompt})
if input_data.prompt and not any(
p["role"] == "user" and p["content"].startswith(prefix) for p in prompt
):
prompt.append({"role": "user", "content": prefix + input_data.prompt})
response = llm.llm_call(
credentials=credentials,
llm_model=input_data.model,
prompt=prompt,
json_format=False,
max_tokens=input_data.max_tokens,
tools=tool_functions,
ollama_host=input_data.ollama_host,
)
if not response.tool_calls:
yield "finished", response.response
return
for tool_call in response.tool_calls:
tool_name = tool_call.function.name
tool_args = json.loads(tool_call.function.arguments)
for arg_name, arg_value in tool_args.items():
yield f"tools_^_{tool_name}_{arg_name}".lower(), arg_value
response.prompt.append(response.raw_response)
yield "conversations", response.prompt

View File

@@ -156,10 +156,6 @@ class CountdownTimerBlock(Block):
days: Union[int, str] = SchemaField(
advanced=False, description="Duration in days", default=0
)
repeat: int = SchemaField(
description="Number of times to repeat the timer",
default=1,
)
class Output(BlockSchema):
output_message: Any = SchemaField(
@@ -191,6 +187,5 @@ class CountdownTimerBlock(Block):
total_seconds = seconds + minutes * 60 + hours * 3600 + days * 86400
for _ in range(input_data.repeat):
time.sleep(total_seconds)
yield "output_message", input_data.input_message
time.sleep(total_seconds)
yield "output_message", input_data.input_message

View File

@@ -2,7 +2,6 @@ import inspect
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Generator,
@@ -19,7 +18,6 @@ import jsonschema
from prisma.models import AgentBlock
from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
from backend.util import json
from backend.util.settings import Config
@@ -30,9 +28,6 @@ from .model import (
is_credentials_field_name,
)
if TYPE_CHECKING:
from .graph import Link
app_config = Config()
BlockData = tuple[str, Any] # Input & Output data should be a tuple of (name, data).
@@ -49,7 +44,6 @@ class BlockType(Enum):
WEBHOOK = "Webhook"
WEBHOOK_MANUAL = "Webhook (manual)"
AGENT = "Agent"
AI = "AI"
class BlockCategory(Enum):
@@ -115,10 +109,6 @@ class BlockSchema(BaseModel):
def validate_data(cls, data: BlockInput) -> str | None:
return json.validate_with_jsonschema(schema=cls.jsonschema(), data=data)
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
return cls.validate_data(data)
@classmethod
def validate_field(cls, field_name: str, data: BlockInput) -> str | None:
"""
@@ -196,19 +186,6 @@ class BlockSchema(BaseModel):
)
}
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
return data # Return as is, by default.
@classmethod
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
input_fields_from_nodes = {link.sink_name for link in links}
return input_fields_from_nodes - set(data)
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
return cls.get_required_fields() - set(data)
BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchema)
BlockSchemaOutputType = TypeVar("BlockSchemaOutputType", bound=BlockSchema)
@@ -317,7 +294,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
self.static_output = static_output
self.block_type = block_type
self.webhook_config = webhook_config
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
self.execution_stats = {}
if self.webhook_config:
if isinstance(self.webhook_config, BlockWebhookConfig):
@@ -374,14 +351,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
Run the block with the given input data.
Args:
input_data: The input data with the structure of input_schema.
Kwargs: Currently 14/02/2025 these include
graph_id: The ID of the graph.
node_id: The ID of the node.
graph_exec_id: The ID of the graph execution.
node_exec_id: The ID of the node execution.
user_id: The ID of the user.
Returns:
A Generator that yields (output_name, output_data).
output_name: One of the output name defined in Block's output_schema.
@@ -395,29 +364,18 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
return data
raise ValueError(f"{self.name} did not produce any output for {output}")
def merge_stats(self, stats: NodeExecutionStats) -> NodeExecutionStats:
stats_dict = stats.model_dump()
current_stats = self.execution_stats.model_dump()
for key, value in stats_dict.items():
if key not in current_stats:
# Field doesn't exist yet, just set it, but this will probably
# not happen, just in case though so we throw for invalid when
# converting back in
current_stats[key] = value
elif isinstance(value, dict) and isinstance(current_stats[key], dict):
current_stats[key].update(value)
elif isinstance(value, (int, float)) and isinstance(
current_stats[key], (int, float)
):
current_stats[key] += value
elif isinstance(value, list) and isinstance(current_stats[key], list):
current_stats[key].extend(value)
def merge_stats(self, stats: dict[str, Any]) -> dict[str, Any]:
for key, value in stats.items():
if isinstance(value, dict):
self.execution_stats.setdefault(key, {}).update(value)
elif isinstance(value, (int, float)):
self.execution_stats.setdefault(key, 0)
self.execution_stats[key] += value
elif isinstance(value, list):
self.execution_stats.setdefault(key, [])
self.execution_stats[key].extend(value)
else:
current_stats[key] = value
self.execution_stats = NodeExecutionStats(**current_stats)
self.execution_stats[key] = value
return self.execution_stats
@property
@@ -440,6 +398,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
}
def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
# Merge the input data with the extra execution arguments, preferring the args for security
if error := self.input_schema.validate_data(input_data):
raise ValueError(
f"Unable to execute block with invalid input data: {error}"

View File

@@ -2,7 +2,6 @@ from typing import Type
from backend.blocks.ai_music_generator import AIMusicGeneratorBlock
from backend.blocks.ai_shortform_video_block import AIShortformVideoCreatorBlock
from backend.blocks.example.example import ExampleBlock
from backend.blocks.ideogram import IdeogramModelBlock
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
from backend.blocks.jina.search import ExtractWebsiteContentBlock, SearchTheWebBlock
@@ -16,7 +15,6 @@ from backend.blocks.llm import (
LlmModel,
)
from backend.blocks.replicate_flux_advanced import ReplicateFluxAdvancedModelBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
from backend.data.block import Block
@@ -24,7 +22,6 @@ from backend.data.cost import BlockCost, BlockCostType
from backend.integrations.credentials_store import (
anthropic_credentials,
did_credentials,
example_credentials,
groq_credentials,
ideogram_credentials,
jina_credentials,
@@ -268,17 +265,4 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
},
)
],
SmartDecisionMakerBlock: LLM_COST,
ExampleBlock: [
BlockCost(
cost_amount=1,
cost_filter={
"credentials": {
"id": example_credentials.id,
"provider": example_credentials.provider,
"type": example_credentials.type,
}
},
)
],
}

View File

@@ -32,14 +32,12 @@ from backend.data.model import (
from backend.data.notifications import NotificationEventDTO, RefundRequestData
from backend.data.user import get_user_by_id
from backend.notifications import NotificationManager
from backend.util.exceptions import InsufficientBalanceError
from backend.util.service import get_service_client
from backend.util.settings import Settings
settings = Settings()
stripe.api_key = settings.secrets.stripe_api_key
logger = logging.getLogger(__name__)
base_url = settings.config.frontend_base_url or settings.config.platform_base_url
class UserCreditBase(ABC):
@@ -186,14 +184,6 @@ class UserCreditBase(ABC):
"""
pass
@staticmethod
async def create_billing_portal_session(user_id: str) -> str:
session = stripe.billing_portal.Session.create(
customer=await get_stripe_customer_id(user_id),
return_url=base_url + "/profile/credits",
)
return session.url
@staticmethod
def time_now() -> datetime:
return datetime.now(timezone.utc)
@@ -259,6 +249,7 @@ class UserCreditBase(ABC):
metadata: Json,
new_transaction_key: str | None = None,
):
transaction = await CreditTransaction.prisma().find_first_or_raise(
where={"transactionKey": transaction_key, "userId": user_id}
)
@@ -323,13 +314,9 @@ class UserCreditBase(ABC):
if amount < 0 and user_balance + amount < 0:
if fail_insufficient_credits:
raise InsufficientBalanceError(
message=f"Insufficient balance of ${user_balance/100}, where this will cost ${abs(amount)/100}",
user_id=user_id,
balance=user_balance,
amount=amount,
raise ValueError(
f"Insufficient balance of ${user_balance/100}, where this will cost ${abs(amount)/100}"
)
amount = min(-user_balance, 0)
# Create the transaction
@@ -359,6 +346,7 @@ class UsageTransactionMetadata(BaseModel):
class UserCredit(UserCreditBase):
@thread_cached
def notification_client(self) -> NotificationManager:
return get_service_client(NotificationManager)
@@ -371,6 +359,7 @@ class UserCredit(UserCreditBase):
await asyncio.to_thread(
lambda: self.notification_client().queue_notification(
NotificationEventDTO(
recipient_email=settings.config.refund_notification_email,
user_id=notification_request.user_id,
type=notification_type,
data=notification_request.model_dump(),
@@ -774,8 +763,10 @@ class UserCredit(UserCreditBase):
ui_mode="hosted",
payment_intent_data={"setup_future_usage": "off_session"},
saved_payment_method_options={"payment_method_save": "enabled"},
success_url=base_url + "/profile/credits?topup=success",
cancel_url=base_url + "/profile/credits?topup=cancel",
success_url=settings.config.frontend_base_url
+ "/profile/credits?topup=success",
cancel_url=settings.config.frontend_base_url
+ "/profile/credits?topup=cancel",
allow_promotion_codes=True,
)
@@ -850,6 +841,7 @@ class UserCredit(UserCreditBase):
transaction_time_ceiling: datetime | None = None,
transaction_type: str | None = None,
) -> TransactionHistory:
transactions_filter: CreditTransactionWhereInput = {
"userId": user_id,
"isActive": True,

View File

@@ -2,7 +2,6 @@ import logging
import os
import zlib
from contextlib import asynccontextmanager
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
from uuid import uuid4
from dotenv import load_dotenv
@@ -16,36 +15,7 @@ load_dotenv()
PRISMA_SCHEMA = os.getenv("PRISMA_SCHEMA", "schema.prisma")
os.environ["PRISMA_SCHEMA_PATH"] = PRISMA_SCHEMA
def add_param(url: str, key: str, value: str) -> str:
p = urlparse(url)
qs = dict(parse_qsl(p.query))
qs[key] = value
return urlunparse(p._replace(query=urlencode(qs)))
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
CONN_LIMIT = os.getenv("DB_CONNECTION_LIMIT")
if CONN_LIMIT:
DATABASE_URL = add_param(DATABASE_URL, "connection_limit", CONN_LIMIT)
CONN_TIMEOUT = os.getenv("DB_CONNECT_TIMEOUT")
if CONN_TIMEOUT:
DATABASE_URL = add_param(DATABASE_URL, "connect_timeout", CONN_TIMEOUT)
POOL_TIMEOUT = os.getenv("DB_POOL_TIMEOUT")
if POOL_TIMEOUT:
DATABASE_URL = add_param(DATABASE_URL, "pool_timeout", POOL_TIMEOUT)
HTTP_TIMEOUT = int(POOL_TIMEOUT) if POOL_TIMEOUT else None
prisma = Prisma(
auto_register=True,
http={"timeout": HTTP_TIMEOUT},
datasource={"url": DATABASE_URL},
)
prisma = Prisma(auto_register=True)
logger = logging.getLogger(__name__)

View File

@@ -1,10 +1,11 @@
from collections import defaultdict
from datetime import datetime, timezone
from multiprocessing import Manager
from typing import Any, AsyncGenerator, Generator, Generic, Type, TypeVar
from typing import Any, AsyncGenerator, Generator, Generic, Optional, Type, TypeVar
from prisma import Json
from prisma.enums import AgentExecutionStatus
from prisma.errors import PrismaError
from prisma.models import (
AgentGraphExecution,
AgentNodeExecution,
@@ -14,7 +15,6 @@ from pydantic import BaseModel
from backend.data.block import BlockData, BlockInput, CompletedBlockOutput
from backend.data.includes import EXECUTION_RESULT_INCLUDE, GRAPH_EXECUTION_INCLUDE
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.queue import AsyncRedisEventBus, RedisEventBus
from backend.server.v2.store.exceptions import DatabaseError
from backend.util import mock, type
@@ -265,33 +265,26 @@ async def upsert_execution_output(
)
async def update_graph_execution_start_time(graph_exec_id: str) -> ExecutionResult:
res = await AgentGraphExecution.prisma().update(
async def update_graph_execution_start_time(graph_exec_id: str):
await AgentGraphExecution.prisma().update(
where={"id": graph_exec_id},
data={
"executionStatus": ExecutionStatus.RUNNING,
"startedAt": datetime.now(tz=timezone.utc),
},
)
if not res:
raise ValueError(f"Execution {graph_exec_id} not found.")
return ExecutionResult.from_graph(res)
async def update_graph_execution_stats(
graph_exec_id: str,
status: ExecutionStatus,
stats: GraphExecutionStats,
stats: dict[str, Any],
) -> ExecutionResult:
data = stats.model_dump()
if isinstance(data["error"], Exception):
data["error"] = str(data["error"])
res = await AgentGraphExecution.prisma().update(
where={"id": graph_exec_id},
data={
"executionStatus": status,
"stats": Json(data),
"stats": Json(stats),
},
)
if not res:
@@ -300,13 +293,10 @@ async def update_graph_execution_stats(
return ExecutionResult.from_graph(res)
async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionStats):
data = stats.model_dump()
if isinstance(data["error"], Exception):
data["error"] = str(data["error"])
async def update_node_execution_stats(node_exec_id: str, stats: dict[str, Any]):
await AgentNodeExecution.prisma().update(
where={"id": node_exec_id},
data={"stats": Json(data)},
data={"stats": Json(stats)},
)
@@ -341,21 +331,28 @@ async def update_execution_status(
return ExecutionResult.from_db(res)
async def delete_execution(
graph_exec_id: str, user_id: str, soft_delete: bool = True
) -> None:
if soft_delete:
deleted_count = await AgentGraphExecution.prisma().update_many(
where={"id": graph_exec_id, "userId": user_id}, data={"isDeleted": True}
)
else:
deleted_count = await AgentGraphExecution.prisma().delete_many(
where={"id": graph_exec_id, "userId": user_id}
)
if deleted_count < 1:
raise DatabaseError(
f"Could not delete graph execution #{graph_exec_id}: not found"
async def get_execution(
execution_id: str, user_id: str
) -> Optional[AgentNodeExecution]:
"""
Get an execution by ID. Returns None if not found.
Args:
execution_id: The ID of the execution to retrieve
Returns:
The execution if found, None otherwise
"""
try:
execution = await AgentNodeExecution.prisma().find_unique(
where={
"id": execution_id,
"userId": user_id,
}
)
return execution
except PrismaError:
return None
async def get_execution_results(graph_exec_id: str) -> list[ExecutionResult]:
@@ -377,12 +374,15 @@ async def get_executions_in_timerange(
try:
executions = await AgentGraphExecution.prisma().find_many(
where={
"startedAt": {
"gte": datetime.fromisoformat(start_time),
"lte": datetime.fromisoformat(end_time),
},
"userId": user_id,
"isDeleted": False,
"AND": [
{
"startedAt": {
"gte": datetime.fromisoformat(start_time),
"lte": datetime.fromisoformat(end_time),
}
},
{"userId": user_id},
]
},
include=GRAPH_EXECUTION_INCLUDE,
)
@@ -399,38 +399,7 @@ OBJC_SPLIT = "_@_"
def parse_execution_output(output: BlockData, name: str) -> Any | None:
"""
Extracts partial output data by name from a given BlockData.
The function supports extracting data from lists, dictionaries, and objects
using specific naming conventions:
- For lists: <output_name>_$_<index>
- For dictionaries: <output_name>_#_<key>
- For objects: <output_name>_@_<attribute>
Args:
output (BlockData): A tuple containing the output name and data.
name (str): The name used to extract specific data from the output.
Returns:
Any | None: The extracted data if found, otherwise None.
Examples:
>>> output = ("result", [10, 20, 30])
>>> parse_execution_output(output, "result_$_1")
20
>>> output = ("config", {"key1": "value1", "key2": "value2"})
>>> parse_execution_output(output, "config_#_key1")
'value1'
>>> class Sample:
... attr1 = "value1"
... attr2 = "value2"
>>> output = ("object", Sample())
>>> parse_execution_output(output, "object_@_attr1")
'value1'
"""
# Allow extracting partial output data by name.
output_name, output_data = output
if name == output_name:
@@ -459,37 +428,11 @@ def parse_execution_output(output: BlockData, name: str) -> Any | None:
def merge_execution_input(data: BlockInput) -> BlockInput:
"""
Merges dynamic input pins into a single list, dictionary, or object based on naming patterns.
This function processes input keys that follow specific patterns to merge them into a unified structure:
- `<input_name>_$_<index>` for list inputs.
- `<input_name>_#_<index>` for dictionary inputs.
- `<input_name>_@_<index>` for object inputs.
Args:
data (BlockInput): A dictionary containing input keys and their corresponding values.
Returns:
BlockInput: A dictionary with merged inputs.
Raises:
ValueError: If a list index is not an integer.
Examples:
>>> data = {
... "list_$_0": "a",
... "list_$_1": "b",
... "dict_#_key1": "value1",
... "dict_#_key2": "value2",
... "object_@_attr1": "value1",
... "object_@_attr2": "value2"
... }
>>> merge_execution_input(data)
{
"list": ["a", "b"],
"dict": {"key1": "value1", "key2": "value2"},
"object": <MockObject attr1="value1" attr2="value2">
}
Merge all dynamic input pins which described by the following pattern:
- <input_name>_$_<index> for list input.
- <input_name>_#_<index> for dict input.
- <input_name>_@_<index> for object input.
This function will construct pins with the same name into a single list/dict/object.
"""
# Merge all input with <input_name>_$_<index> into a single list.

View File

@@ -1,3 +1,4 @@
import asyncio
import logging
import uuid
from collections import defaultdict
@@ -13,14 +14,14 @@ from prisma.models import (
AgentNodeLink,
StoreListingVersion,
)
from prisma.types import AgentGraphExecutionWhereInput, AgentGraphWhereInput
from pydantic.fields import Field, computed_field
from prisma.types import AgentGraphWhereInput
from pydantic.fields import computed_field
from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.basic import AgentInputBlock, AgentOutputBlock
from backend.util import type as type_utils
from backend.util import type
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
from .block import BlockInput, BlockType, get_block, get_blocks
from .db import BaseDbModel, transaction
from .execution import ExecutionResult, ExecutionStatus
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
@@ -70,20 +71,13 @@ class NodeModel(Node):
webhook: Optional[Webhook] = None
@property
def block(self) -> Block[BlockSchema, BlockSchema]:
block = get_block(self.block_id)
if not block:
raise ValueError(f"Block #{self.block_id} does not exist")
return block
@staticmethod
def from_db(node: AgentNode, for_export: bool = False) -> "NodeModel":
def from_db(node: AgentNode):
obj = NodeModel(
id=node.id,
block_id=node.agentBlockId,
input_default=type_utils.convert(node.constantInput, dict[str, Any]),
metadata=type_utils.convert(node.metadata, dict[str, Any]),
input_default=type.convert(node.constantInput, dict[str, Any]),
metadata=type.convert(node.metadata, dict[str, Any]),
graph_id=node.agentGraphId,
graph_version=node.agentGraphVersion,
webhook_id=node.webhookId,
@@ -91,8 +85,6 @@ class NodeModel(Node):
)
obj.input_links = [Link.from_db(link) for link in node.Input or []]
obj.output_links = [Link.from_db(link) for link in node.Output or []]
if for_export:
return obj.stripped_for_export()
return obj
def is_triggered_by_event_type(self, event_type: str) -> bool:
@@ -111,51 +103,6 @@ class NodeModel(Node):
if event_filter[k] is True
]
def stripped_for_export(self) -> "NodeModel":
"""
Returns a copy of the node model, stripped of any non-transferable properties
"""
stripped_node = self.model_copy(deep=True)
# Remove credentials from node input
if stripped_node.input_default:
stripped_node.input_default = NodeModel._filter_secrets_from_node_input(
stripped_node.input_default, self.block.input_schema.jsonschema()
)
if (
stripped_node.block.block_type == BlockType.INPUT
and "value" in stripped_node.input_default
):
stripped_node.input_default["value"] = ""
# Remove webhook info
stripped_node.webhook_id = None
stripped_node.webhook = None
return stripped_node
@staticmethod
def _filter_secrets_from_node_input(
input_data: dict[str, Any], schema: dict[str, Any] | None
) -> dict[str, Any]:
sensitive_keys = ["credentials", "api_key", "password", "token", "secret"]
field_schemas = schema.get("properties", {}) if schema else {}
result = {}
for key, value in input_data.items():
field_schema: dict | None = field_schemas.get(key)
if (field_schema and field_schema.get("secret", False)) or any(
sensitive_key in key.lower() for sensitive_key in sensitive_keys
):
# This is a secret value -> filter this key-value pair out
continue
elif isinstance(value, dict):
result[key] = NodeModel._filter_secrets_from_node_input(
value, field_schema
)
else:
result[key] = value
return result
# Fix 2-way reference Node <-> Webhook
Webhook.model_rebuild()
@@ -165,7 +112,6 @@ class GraphExecutionMeta(BaseDbModel):
execution_id: str
started_at: datetime
ended_at: datetime
cost: Optional[int] = Field(..., description="Execution cost in credits")
duration: float
total_run_time: float
status: ExecutionStatus
@@ -182,7 +128,7 @@ class GraphExecutionMeta(BaseDbModel):
total_run_time = duration
try:
stats = type_utils.convert(_graph_exec.stats or {}, dict[str, Any])
stats = type.convert(_graph_exec.stats or {}, dict[str, Any])
except ValueError:
stats = {}
@@ -194,7 +140,6 @@ class GraphExecutionMeta(BaseDbModel):
execution_id=_graph_exec.id,
started_at=start_time,
ended_at=end_time,
cost=stats.get("cost", None),
duration=duration,
total_run_time=total_run_time,
status=ExecutionStatus(_graph_exec.executionStatus),
@@ -239,9 +184,7 @@ class GraphExecution(GraphExecutionMeta):
outputs: dict[str, list] = defaultdict(list)
for exec in node_executions:
if exec.block_id == _OUTPUT_BLOCK_ID:
outputs[exec.input_data["name"]].append(
exec.input_data.get("value", None)
)
outputs[exec.input_data["name"]].append(exec.input_data["value"])
return GraphExecution(
**{
@@ -254,9 +197,10 @@ class GraphExecution(GraphExecutionMeta):
)
class BaseGraph(BaseDbModel):
class Graph(BaseDbModel):
version: int = 1
is_active: bool = True
is_template: bool = False
name: str
description: str
nodes: list[Node] = []
@@ -319,10 +263,6 @@ class BaseGraph(BaseDbModel):
}
class Graph(BaseGraph):
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs, only used in export
class GraphModel(Graph):
user_id: str
nodes: list[NodeModel] = [] # type: ignore
@@ -346,89 +286,42 @@ class GraphModel(Graph):
Reassigns all IDs in the graph to new UUIDs.
This method can be used before storing a new graph to the database.
"""
if reassign_graph_id:
graph_id_map = {
self.id: str(uuid.uuid4()),
**{sub_graph.id: str(uuid.uuid4()) for sub_graph in self.sub_graphs},
}
else:
graph_id_map = {}
self._reassign_ids(self, user_id, graph_id_map)
for sub_graph in self.sub_graphs:
self._reassign_ids(sub_graph, user_id, graph_id_map)
@staticmethod
def _reassign_ids(
graph: BaseGraph,
user_id: str,
graph_id_map: dict[str, str],
):
# Reassign Graph ID
if graph.id in graph_id_map:
graph.id = graph_id_map[graph.id]
id_map = {node.id: str(uuid.uuid4()) for node in self.nodes}
if reassign_graph_id:
self.id = str(uuid.uuid4())
# Reassign Node IDs
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
for node in graph.nodes:
for node in self.nodes:
node.id = id_map[node.id]
# Reassign Link IDs
for link in graph.links:
for link in self.links:
link.source_id = id_map[link.source_id]
link.sink_id = id_map[link.sink_id]
# Reassign User IDs for agent blocks
for node in graph.nodes:
for node in self.nodes:
if node.block_id != AgentExecutorBlock().id:
continue
node.input_default["user_id"] = user_id
node.input_default.setdefault("data", {})
if (graph_id := node.input_default.get("graph_id")) in graph_id_map:
node.input_default["graph_id"] = graph_id_map[graph_id]
self.validate_graph()
def validate_graph(self, for_run: bool = False):
self._validate_graph(self, for_run)
for sub_graph in self.sub_graphs:
self._validate_graph(sub_graph, for_run)
@staticmethod
def _validate_graph(graph: BaseGraph, for_run: bool = False):
def sanitize(name):
sanitized_name = name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
if sanitized_name.startswith("tools_^_"):
return sanitized_name.split("_^_")[0]
return sanitized_name
# Validate smart decision maker nodes
smart_decision_maker_nodes = set()
agent_nodes = set()
nodes_block = {
node.id: block
for node in graph.nodes
if (block := get_block(node.block_id)) is not None
}
for node in graph.nodes:
if (block := nodes_block.get(node.id)) is None:
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
# Smart decision maker nodes
if block.block_type == BlockType.AI:
smart_decision_maker_nodes.add(node.id)
# Agent nodes
elif block.block_type == BlockType.AGENT:
agent_nodes.add(node.id)
return name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
input_links = defaultdict(list)
for link in graph.links:
for link in self.links:
input_links[link.sink_id].append(link)
# Nodes: required fields are filled or connected and dependencies are satisfied
for node in graph.nodes:
if (block := nodes_block.get(node.id)) is None:
for node in self.nodes:
block = get_block(node.block_id)
if block is None:
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
provided_inputs = set(
@@ -445,12 +338,9 @@ class GraphModel(Graph):
)
and (
for_run # Skip input completion validation, unless when executing.
or block.block_type
in [
BlockType.INPUT,
BlockType.OUTPUT,
BlockType.AGENT,
]
or block.block_type == BlockType.INPUT
or block.block_type == BlockType.OUTPUT
or block.block_type == BlockType.AGENT
)
):
raise ValueError(
@@ -488,7 +378,7 @@ class GraphModel(Graph):
f"Node {block.name} #{node.id}: Field `{field_name}` requires [{', '.join(missing_deps)}] to be set"
)
node_map = {v.id: v for v in graph.nodes}
node_map = {v.id: v for v in self.nodes}
def is_static_output_block(nid: str) -> bool:
bid = node_map[nid].block_id
@@ -496,23 +386,23 @@ class GraphModel(Graph):
return b.static_output if b else False
# Links: links are connected and the connected pin data type are compatible.
for link in graph.links:
for link in self.links:
source = (link.source_id, link.source_name)
sink = (link.sink_id, link.sink_name)
prefix = f"Link {source} <-> {sink}"
suffix = f"Link {source} <-> {sink}"
for i, (node_id, name) in enumerate([source, sink]):
node = node_map.get(node_id)
if not node:
raise ValueError(
f"{prefix}, {node_id} is invalid node id, available nodes: {node_map.keys()}"
f"{suffix}, {node_id} is invalid node id, available nodes: {node_map.keys()}"
)
block = get_block(node.block_id)
if not block:
blocks = {v().id: v().name for v in get_blocks().values()}
raise ValueError(
f"{prefix}, {node.block_id} is invalid block id, available blocks: {blocks}"
f"{suffix}, {node.block_id} is invalid block id, available blocks: {blocks}"
)
sanitized_name = sanitize(name)
@@ -520,37 +410,35 @@ class GraphModel(Graph):
if i == 0:
fields = (
block.output_schema.get_fields()
if block.block_type not in [BlockType.AGENT]
if block.block_type != BlockType.AGENT
else vals.get("output_schema", {}).get("properties", {}).keys()
)
else:
fields = (
block.input_schema.get_fields()
if block.block_type not in [BlockType.AGENT]
if block.block_type != BlockType.AGENT
else vals.get("input_schema", {}).get("properties", {}).keys()
)
if sanitized_name not in fields and not name.startswith("tools_^_"):
if sanitized_name not in fields:
fields_msg = f"Allowed fields: {fields}"
raise ValueError(f"{prefix}, `{name}` invalid, {fields_msg}")
raise ValueError(f"{suffix}, `{name}` invalid, {fields_msg}")
if is_static_output_block(link.source_id):
link.is_static = True # Each value block output should be static.
@staticmethod
def from_db(
graph: AgentGraph,
for_export: bool = False,
sub_graphs: list[AgentGraph] | None = None,
):
def from_db(graph: AgentGraph, for_export: bool = False):
return GraphModel(
id=graph.id,
user_id=graph.userId if not for_export else "",
user_id=graph.userId,
version=graph.version,
is_active=graph.isActive,
is_template=graph.isTemplate,
name=graph.name or "",
description=graph.description or "",
nodes=[
NodeModel.from_db(node, for_export) for node in graph.AgentNodes or []
NodeModel.from_db(GraphModel._process_node(node, for_export))
for node in graph.AgentNodes or []
],
links=list(
{
@@ -559,12 +447,59 @@ class GraphModel(Graph):
for link in (node.Input or []) + (node.Output or [])
}
),
sub_graphs=[
GraphModel.from_db(sub_graph, for_export)
for sub_graph in sub_graphs or []
],
)
@staticmethod
def _process_node(node: AgentNode, for_export: bool) -> AgentNode:
if for_export:
# Remove credentials from node input
if node.constantInput:
constant_input = type.convert(node.constantInput, dict[str, Any])
constant_input = GraphModel._hide_node_input_credentials(constant_input)
node.constantInput = Json(constant_input)
# Remove webhook info
node.webhookId = None
node.Webhook = None
return node
@staticmethod
def _hide_node_input_credentials(input_data: dict[str, Any]) -> dict[str, Any]:
sensitive_keys = ["credentials", "api_key", "password", "token", "secret"]
result = {}
for key, value in input_data.items():
if isinstance(value, dict):
result[key] = GraphModel._hide_node_input_credentials(value)
elif isinstance(value, str) and any(
sensitive_key in key.lower() for sensitive_key in sensitive_keys
):
# Skip this key-value pair in the result
continue
else:
result[key] = value
return result
def clean_graph(self):
blocks = [block() for block in get_blocks().values()]
input_blocks = [
node
for node in self.nodes
if next(
(
b
for b in blocks
if b.id == node.block_id and b.block_type == BlockType.INPUT
),
None,
)
]
for node in self.nodes:
if any(input_block.id == node.id for input_block in input_blocks):
node.input_default["value"] = ""
# --------------------- CRUD functions --------------------- #
@@ -594,14 +529,14 @@ async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:
async def get_graphs(
user_id: str,
filter_by: Literal["active"] | None = "active",
filter_by: Literal["active", "template"] | None = "active",
) -> list[GraphModel]:
"""
Retrieves graph metadata objects.
Default behaviour is to get all currently active graphs.
Args:
filter_by: An optional filter to either select graphs.
filter_by: An optional filter to either select templates or active graphs.
user_id: The ID of the user that owns the graph.
Returns:
@@ -611,6 +546,8 @@ async def get_graphs(
if filter_by == "active":
where_clause["isActive"] = True
elif filter_by == "template":
where_clause["isTemplate"] = True
graphs = await AgentGraph.prisma().find_many(
where=where_clause,
@@ -630,20 +567,17 @@ async def get_graphs(
return graph_models
async def get_graph_executions(
graph_id: Optional[str] = None,
user_id: Optional[str] = None,
) -> list[GraphExecutionMeta]:
where_filter: AgentGraphExecutionWhereInput = {
"isDeleted": False,
}
if user_id:
where_filter["userId"] = user_id
if graph_id:
where_filter["agentGraphId"] = graph_id
async def get_graphs_executions(user_id: str) -> list[GraphExecutionMeta]:
executions = await AgentGraphExecution.prisma().find_many(
where=where_filter,
where={"userId": user_id},
order={"createdAt": "desc"},
)
return [GraphExecutionMeta.from_db(execution) for execution in executions]
async def get_graph_executions(graph_id: str, user_id: str) -> list[GraphExecutionMeta]:
executions = await AgentGraphExecution.prisma().find_many(
where={"agentGraphId": graph_id, "userId": user_id},
order={"createdAt": "desc"},
)
return [GraphExecutionMeta.from_db(execution) for execution in executions]
@@ -653,14 +587,14 @@ async def get_execution_meta(
user_id: str, execution_id: str
) -> GraphExecutionMeta | None:
execution = await AgentGraphExecution.prisma().find_first(
where={"id": execution_id, "isDeleted": False, "userId": user_id}
where={"id": execution_id, "userId": user_id}
)
return GraphExecutionMeta.from_db(execution) if execution else None
async def get_execution(user_id: str, execution_id: str) -> GraphExecution | None:
execution = await AgentGraphExecution.prisma().find_first(
where={"id": execution_id, "isDeleted": False, "userId": user_id},
where={"id": execution_id, "userId": user_id},
include={
"AgentNodeExecutions": {
"include": {"AgentNode": True, "Input": True, "Output": True},
@@ -676,41 +610,17 @@ async def get_execution(user_id: str, execution_id: str) -> GraphExecution | Non
return GraphExecution.from_db(execution) if execution else None
async def get_graph_metadata(graph_id: str, version: int | None = None) -> Graph | None:
where_clause: AgentGraphWhereInput = {
"id": graph_id,
}
if version is not None:
where_clause["version"] = version
graph = await AgentGraph.prisma().find_first(
where=where_clause,
include=AGENT_GRAPH_INCLUDE,
order={"version": "desc"},
)
if not graph:
return None
return Graph(
id=graph.id,
name=graph.name or "",
description=graph.description or "",
version=graph.version,
is_active=graph.isActive,
)
async def get_graph(
graph_id: str,
version: int | None = None,
template: bool = False, # note: currently not in use; TODO: remove from DB entirely
user_id: str | None = None,
for_export: bool = False,
) -> GraphModel | None:
"""
Retrieves a graph from the DB.
Defaults to the version with `is_active` if `version` is not passed.
Defaults to the version with `is_active` if `version` is not passed,
or the latest version with `is_template` if `template=True`.
Returns `None` if the record is not found.
"""
@@ -720,6 +630,8 @@ async def get_graph(
if version is not None:
where_clause["version"] = version
elif not template:
where_clause["isActive"] = True
graph = await AgentGraph.prisma().find_first(
where=where_clause,
@@ -743,74 +655,9 @@ async def get_graph(
):
return None
if for_export:
sub_graphs = await _get_sub_graphs(graph)
return GraphModel.from_db(
graph=graph,
sub_graphs=sub_graphs,
for_export=for_export,
)
return GraphModel.from_db(graph, for_export)
async def _get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
"""
Iteratively fetches all sub-graphs of a given graph, and flattens them into a list.
This call involves a DB fetch in batch, breadth-first, per-level of graph depth.
On each DB fetch we will only fetch the sub-graphs that are not already in the list.
"""
sub_graphs = {graph.id: graph}
search_graphs = [graph]
agent_block_id = AgentExecutorBlock().id
while search_graphs:
sub_graph_ids = [
(graph_id, graph_version)
for graph in search_graphs
for node in graph.AgentNodes or []
if (
node.AgentBlock
and node.AgentBlock.id == agent_block_id
and (graph_id := dict(node.constantInput).get("graph_id"))
and (graph_version := dict(node.constantInput).get("graph_version"))
)
]
if not sub_graph_ids:
break
graphs = await AgentGraph.prisma().find_many(
where={
"OR": [
{
"id": graph_id,
"version": graph_version,
"userId": graph.userId, # Ensure the sub-graph is owned by the same user
}
for graph_id, graph_version in sub_graph_ids
] # type: ignore
},
include=AGENT_GRAPH_INCLUDE,
)
search_graphs = [graph for graph in graphs if graph.id not in sub_graphs]
sub_graphs.update({graph.id: graph for graph in search_graphs})
return [g for g in sub_graphs.values() if g.id != graph.id]
async def get_connected_output_nodes(node_id: str) -> list[tuple[Link, Node]]:
links = await AgentNodeLink.prisma().find_many(
where={"agentNodeSourceId": node_id},
include={"AgentNodeSink": {"include": AGENT_NODE_INCLUDE}}, # type: ignore
)
return [
(Link.from_db(link), NodeModel.from_db(link.AgentNodeSink))
for link in links
if link.AgentNodeSink
]
async def set_graph_active_version(graph_id: str, version: int, user_id: str) -> None:
# Activate the requested version if it exists and is owned by the user.
updated_count = await AgentGraph.prisma().update_many(
@@ -862,56 +709,50 @@ async def create_graph(graph: Graph, user_id: str) -> GraphModel:
async with transaction() as tx:
await __create_graph(tx, graph, user_id)
if created_graph := await get_graph(graph.id, graph.version, user_id=user_id):
if created_graph := await get_graph(
graph.id, graph.version, template=graph.is_template, user_id=user_id
):
return created_graph
raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
async def __create_graph(tx, graph: Graph, user_id: str):
graphs = [graph] + graph.sub_graphs
await AgentGraph.prisma(tx).create_many(
data=[
{
"id": graph.id,
"version": graph.version,
"name": graph.name,
"description": graph.description,
"isActive": graph.is_active,
"userId": user_id,
}
for graph in graphs
]
await AgentGraph.prisma(tx).create(
data={
"id": graph.id,
"version": graph.version,
"name": graph.name,
"description": graph.description,
"isTemplate": graph.is_template,
"isActive": graph.is_active,
"userId": user_id,
"AgentNodes": {
"create": [
{
"id": node.id,
"agentBlockId": node.block_id,
"constantInput": Json(node.input_default),
"metadata": Json(node.metadata),
}
for node in graph.nodes
]
},
}
)
await AgentNode.prisma(tx).create_many(
data=[
{
"id": node.id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"agentBlockId": node.block_id,
"constantInput": Json(node.input_default),
"metadata": Json(node.metadata),
"webhookId": node.webhook_id,
}
for graph in graphs
for node in graph.nodes
]
)
await AgentNodeLink.prisma(tx).create_many(
data=[
{
"id": str(uuid.uuid4()),
"sourceName": link.source_name,
"sinkName": link.sink_name,
"agentNodeSourceId": link.source_id,
"agentNodeSinkId": link.sink_id,
"isStatic": link.is_static,
}
for graph in graphs
await asyncio.gather(
*[
AgentNodeLink.prisma(tx).create(
{
"id": str(uuid.uuid4()),
"sourceName": link.source_name,
"sinkName": link.sink_name,
"agentNodeSourceId": link.source_id,
"agentNodeSinkId": link.sink_id,
"isStatic": link.is_static,
}
)
for link in graph.links
]
)

View File

@@ -32,15 +32,3 @@ GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
}
def library_agent_include(user_id: str) -> prisma.types.LibraryAgentInclude:
return {
"Agent": {
"include": {
**AGENT_GRAPH_INCLUDE,
"AgentGraphExecution": {"where": {"userId": user_id}},
}
},
"Creator": True,
}

View File

@@ -402,37 +402,3 @@ class RefundRequest(BaseModel):
status: str
created_at: datetime
updated_at: datetime
class NodeExecutionStats(BaseModel):
"""Execution statistics for a node execution."""
class Config:
arbitrary_types_allowed = True
error: Optional[Exception | str] = None
walltime: float = 0
cputime: float = 0
cost: float = 0
input_size: int = 0
output_size: int = 0
llm_call_count: int = 0
llm_retry_count: int = 0
input_token_count: int = 0
output_token_count: int = 0
class GraphExecutionStats(BaseModel):
"""Execution statistics for a graph execution."""
class Config:
arbitrary_types_allowed = True
error: Optional[Exception | str] = None
walltime: float = 0
cputime: float = 0
nodes_walltime: float = 0
nodes_cputime: float = 0
node_count: int = 0
node_error_count: int = 0
cost: float = 0

View File

@@ -1,5 +1,5 @@
import logging
from datetime import datetime, timedelta, timezone
from datetime import datetime, timedelta
from enum import Enum
from typing import Annotated, Any, Generic, Optional, TypeVar, Union
@@ -18,25 +18,18 @@ from .db import transaction
logger = logging.getLogger(__name__)
NotificationDataType_co = TypeVar(
"NotificationDataType_co", bound="BaseNotificationData", covariant=True
)
SummaryParamsType_co = TypeVar(
"SummaryParamsType_co", bound="BaseSummaryParams", covariant=True
)
T_co = TypeVar("T_co", bound="BaseNotificationData", covariant=True)
class QueueType(Enum):
class BatchingStrategy(Enum):
IMMEDIATE = "immediate" # Send right away (errors, critical notifications)
BATCH = "batch" # Batch for up to an hour (usage reports)
SUMMARY = "summary" # Daily digest (summary notifications)
HOURLY = "hourly" # Batch for up to an hour (usage reports)
DAILY = "daily" # Daily digest (summary notifications)
BACKOFF = "backoff" # Backoff strategy (exponential backoff)
ADMIN = "admin" # Admin notifications (errors, critical notifications)
class BaseNotificationData(BaseModel):
class Config:
extra = "allow"
pass
class AgentRunData(BaseNotificationData):
@@ -45,7 +38,7 @@ class AgentRunData(BaseNotificationData):
execution_time: float
node_count: int = Field(..., description="Number of nodes executed")
graph_id: str
outputs: list[dict[str, Any]] = Field(..., description="Outputs of the agent")
outputs: dict[str, Any] = Field(..., description="Outputs of the agent")
class ZeroBalanceData(BaseNotificationData):
@@ -53,21 +46,12 @@ class ZeroBalanceData(BaseNotificationData):
last_transaction_time: datetime
top_up_link: str
@field_validator("last_transaction_time")
@classmethod
def validate_timezone(cls, value: datetime):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
class LowBalanceData(BaseNotificationData):
agent_name: str = Field(..., description="Name of the agent")
current_balance: float = Field(
..., description="Current balance in credits (100 = $1)"
)
billing_page_link: str = Field(..., description="Link to billing page")
shortfall: float = Field(..., description="Amount of credits needed to continue")
current_balance: float
threshold_amount: float
top_up_link: str
recent_usage: float = Field(..., description="Usage in the last 24 hours")
class BlockExecutionFailedData(BaseNotificationData):
@@ -88,13 +72,6 @@ class ContinuousAgentErrorData(BaseNotificationData):
error_time: datetime
attempts: int = Field(..., description="Number of retry attempts made")
@field_validator("start_time", "error_time")
@classmethod
def validate_timezone(cls, value: datetime):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
class BaseSummaryData(BaseNotificationData):
total_credits_used: float
@@ -107,53 +84,18 @@ class BaseSummaryData(BaseNotificationData):
cost_breakdown: dict[str, float]
class BaseSummaryParams(BaseModel):
pass
class DailySummaryParams(BaseSummaryParams):
date: datetime
@field_validator("date")
def validate_timezone(cls, value):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
class WeeklySummaryParams(BaseSummaryParams):
start_date: datetime
end_date: datetime
@field_validator("start_date", "end_date")
def validate_timezone(cls, value):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
class DailySummaryData(BaseSummaryData):
date: datetime
@field_validator("date")
def validate_timezone(cls, value):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
class WeeklySummaryData(BaseSummaryData):
start_date: datetime
end_date: datetime
@field_validator("start_date", "end_date")
def validate_timezone(cls, value):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
week_number: int
year: int
class MonthlySummaryData(BaseNotificationData):
class MonthlySummaryData(BaseSummaryData):
month: int
year: int
@@ -177,10 +119,6 @@ NotificationData = Annotated[
BlockExecutionFailedData,
ContinuousAgentErrorData,
MonthlySummaryData,
WeeklySummaryData,
DailySummaryData,
RefundRequestData,
BaseSummaryData,
],
Field(discriminator="type"),
]
@@ -190,25 +128,19 @@ class NotificationEventDTO(BaseModel):
user_id: str
type: NotificationType
data: dict
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
created_at: datetime = Field(default_factory=datetime.now)
recipient_email: Optional[str] = None
retry_count: int = 0
class SummaryParamsEventDTO(BaseModel):
class NotificationEventModel(BaseModel, Generic[T_co]):
user_id: str
type: NotificationType
data: dict
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
class NotificationEventModel(BaseModel, Generic[NotificationDataType_co]):
user_id: str
type: NotificationType
data: NotificationDataType_co
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
data: T_co
created_at: datetime = Field(default_factory=datetime.now)
@property
def strategy(self) -> QueueType:
def strategy(self) -> BatchingStrategy:
return NotificationTypeOverride(self.type).strategy
@field_validator("type", mode="before")
@@ -222,14 +154,7 @@ class NotificationEventModel(BaseModel, Generic[NotificationDataType_co]):
return NotificationTypeOverride(self.type).template
class SummaryParamsEventModel(BaseModel, Generic[SummaryParamsType_co]):
user_id: str
type: NotificationType
data: SummaryParamsType_co
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
def get_notif_data_type(
def get_data_type(
notification_type: NotificationType,
) -> type[BaseNotificationData]:
return {
@@ -246,20 +171,11 @@ def get_notif_data_type(
}[notification_type]
def get_summary_params_type(
notification_type: NotificationType,
) -> type[BaseSummaryParams]:
return {
NotificationType.DAILY_SUMMARY: DailySummaryParams,
NotificationType.WEEKLY_SUMMARY: WeeklySummaryParams,
}[notification_type]
class NotificationBatch(BaseModel):
user_id: str
events: list[NotificationEvent]
strategy: QueueType
last_update: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
strategy: BatchingStrategy
last_update: datetime = datetime.now()
class NotificationResult(BaseModel):
@@ -272,22 +188,23 @@ class NotificationTypeOverride:
self.notification_type = notification_type
@property
def strategy(self) -> QueueType:
def strategy(self) -> BatchingStrategy:
BATCHING_RULES = {
# These are batched by the notification service
NotificationType.AGENT_RUN: QueueType.BATCH,
NotificationType.AGENT_RUN: BatchingStrategy.IMMEDIATE,
# These are batched by the notification service, but with a backoff strategy
NotificationType.ZERO_BALANCE: QueueType.BACKOFF,
NotificationType.LOW_BALANCE: QueueType.IMMEDIATE,
NotificationType.BLOCK_EXECUTION_FAILED: QueueType.BACKOFF,
NotificationType.CONTINUOUS_AGENT_ERROR: QueueType.BACKOFF,
NotificationType.DAILY_SUMMARY: QueueType.SUMMARY,
NotificationType.WEEKLY_SUMMARY: QueueType.SUMMARY,
NotificationType.MONTHLY_SUMMARY: QueueType.SUMMARY,
NotificationType.REFUND_REQUEST: QueueType.ADMIN,
NotificationType.REFUND_PROCESSED: QueueType.ADMIN,
NotificationType.ZERO_BALANCE: BatchingStrategy.BACKOFF,
NotificationType.LOW_BALANCE: BatchingStrategy.BACKOFF,
NotificationType.BLOCK_EXECUTION_FAILED: BatchingStrategy.BACKOFF,
NotificationType.CONTINUOUS_AGENT_ERROR: BatchingStrategy.BACKOFF,
# These aren't batched by the notification service, so we send them right away
NotificationType.DAILY_SUMMARY: BatchingStrategy.IMMEDIATE,
NotificationType.WEEKLY_SUMMARY: BatchingStrategy.IMMEDIATE,
NotificationType.MONTHLY_SUMMARY: BatchingStrategy.IMMEDIATE,
NotificationType.REFUND_REQUEST: BatchingStrategy.IMMEDIATE,
NotificationType.REFUND_PROCESSED: BatchingStrategy.IMMEDIATE,
}
return BATCHING_RULES.get(self.notification_type, QueueType.IMMEDIATE)
return BATCHING_RULES.get(self.notification_type, BatchingStrategy.HOURLY)
@property
def template(self) -> str:
@@ -337,51 +254,12 @@ class NotificationPreference(BaseModel):
)
daily_limit: int = 10 # Max emails per day
emails_sent_today: int = 0
last_reset_date: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc)
)
class UserNotificationEventDTO(BaseModel):
type: NotificationType
data: dict
created_at: datetime
updated_at: datetime
@staticmethod
def from_db(model: NotificationEvent) -> "UserNotificationEventDTO":
return UserNotificationEventDTO(
type=model.type,
data=dict(model.data),
created_at=model.createdAt,
updated_at=model.updatedAt,
)
class UserNotificationBatchDTO(BaseModel):
user_id: str
type: NotificationType
notifications: list[UserNotificationEventDTO]
created_at: datetime
updated_at: datetime
@staticmethod
def from_db(model: UserNotificationBatch) -> "UserNotificationBatchDTO":
return UserNotificationBatchDTO(
user_id=model.userId,
type=model.type,
notifications=[
UserNotificationEventDTO.from_db(notification)
for notification in model.notifications or []
],
created_at=model.createdAt,
updated_at=model.updatedAt,
)
last_reset_date: datetime = Field(default_factory=datetime.now)
def get_batch_delay(notification_type: NotificationType) -> timedelta:
return {
NotificationType.AGENT_RUN: timedelta(minutes=60),
NotificationType.AGENT_RUN: timedelta(seconds=1),
NotificationType.ZERO_BALANCE: timedelta(minutes=60),
NotificationType.LOW_BALANCE: timedelta(minutes=60),
NotificationType.BLOCK_EXECUTION_FAILED: timedelta(minutes=60),
@@ -392,15 +270,19 @@ def get_batch_delay(notification_type: NotificationType) -> timedelta:
async def create_or_add_to_user_notification_batch(
user_id: str,
notification_type: NotificationType,
notification_data: NotificationEventModel,
) -> UserNotificationBatchDTO:
data: str, # type: 'NotificationEventModel'
) -> dict:
try:
logger.info(
f"Creating or adding to notification batch for {user_id} with type {notification_type} and data {notification_data}"
f"Creating or adding to notification batch for {user_id} with type {notification_type} and data {data}"
)
notification_data = NotificationEventModel[
get_data_type(notification_type)
].model_validate_json(data)
# Serialize the data
json_data: Json = Json(notification_data.data.model_dump())
json_data: Json = Json(notification_data.data.model_dump_json())
# First try to find existing batch
existing_batch = await UserNotificationBatch.prisma().find_unique(
@@ -431,7 +313,7 @@ async def create_or_add_to_user_notification_batch(
},
include={"notifications": True},
)
return UserNotificationBatchDTO.from_db(resp)
return resp.model_dump()
else:
async with transaction() as tx:
notification_event = await tx.notificationevent.create(
@@ -453,33 +335,27 @@ async def create_or_add_to_user_notification_batch(
raise DatabaseError(
f"Failed to add notification event {notification_event.id} to existing batch {existing_batch.id}"
)
return UserNotificationBatchDTO.from_db(resp)
return resp.model_dump()
except Exception as e:
raise DatabaseError(
f"Failed to create or add to notification batch for user {user_id} and type {notification_type}: {e}"
) from e
async def get_user_notification_oldest_message_in_batch(
async def get_user_notification_last_message_in_batch(
user_id: str,
notification_type: NotificationType,
) -> UserNotificationEventDTO | None:
) -> NotificationEvent | None:
try:
batch = await UserNotificationBatch.prisma().find_first(
where={"userId": user_id, "type": notification_type},
include={"notifications": True},
order={"createdAt": "desc"},
)
if not batch:
return None
if not batch.notifications:
return None
sorted_notifications = sorted(batch.notifications, key=lambda x: x.createdAt)
return (
UserNotificationEventDTO.from_db(sorted_notifications[0])
if sorted_notifications
else None
)
return batch.notifications[-1]
except Exception as e:
raise DatabaseError(
f"Failed to get user notification last message in batch for user {user_id} and type {notification_type}: {e}"
@@ -514,34 +390,13 @@ async def empty_user_notification_batch(
async def get_user_notification_batch(
user_id: str,
notification_type: NotificationType,
) -> UserNotificationBatchDTO | None:
) -> UserNotificationBatch | None:
try:
batch = await UserNotificationBatch.prisma().find_first(
return await UserNotificationBatch.prisma().find_first(
where={"userId": user_id, "type": notification_type},
include={"notifications": True},
)
return UserNotificationBatchDTO.from_db(batch) if batch else None
except Exception as e:
raise DatabaseError(
f"Failed to get user notification batch for user {user_id} and type {notification_type}: {e}"
) from e
async def get_all_batches_by_type(
notification_type: NotificationType,
) -> list[UserNotificationBatchDTO]:
try:
batches = await UserNotificationBatch.prisma().find_many(
where={
"type": notification_type,
"notifications": {
"some": {} # Only return batches with at least one notification
},
},
include={"notifications": True},
)
return [UserNotificationBatchDTO.from_db(batch) for batch in batches]
except Exception as e:
raise DatabaseError(
f"Failed to get all batches by type {notification_type}: {e}"
) from e

View File

@@ -1,247 +0,0 @@
import re
from typing import Any, Optional
import prisma
import pydantic
from prisma import Json
from prisma.models import (
AgentGraph,
AgentGraphExecution,
StoreListingVersion,
UserOnboarding,
)
from prisma.types import UserOnboardingUpdateInput
from backend.server.v2.library.db import set_is_deleted_for_library_agent
from backend.server.v2.store.db import get_store_agent_details
from backend.server.v2.store.model import StoreAgentDetails
# Mapping from user reason id to categories to search for when choosing agent to show
REASON_MAPPING: dict[str, list[str]] = {
"content_marketing": ["writing", "marketing", "creative"],
"business_workflow_automation": ["business", "productivity"],
"data_research": ["data", "research"],
"ai_innovation": ["development", "research"],
"personal_productivity": ["personal", "productivity"],
}
class UserOnboardingUpdate(pydantic.BaseModel):
step: int
usageReason: Optional[str] = None
integrations: list[str] = pydantic.Field(default_factory=list)
otherIntegrations: Optional[str] = None
selectedAgentCreator: Optional[str] = None
selectedAgentSlug: Optional[str] = None
agentInput: Optional[dict[str, Any]] = None
isCompleted: bool = False
async def get_user_onboarding(user_id: str):
return await UserOnboarding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id}, # type: ignore
"update": {},
},
)
async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
# Get the user onboarding data
user_onboarding = await get_user_onboarding(user_id)
update: UserOnboardingUpdateInput = {
"step": data.step,
"isCompleted": data.isCompleted,
}
if data.usageReason:
update["usageReason"] = data.usageReason
if data.integrations:
update["integrations"] = data.integrations
if data.otherIntegrations:
update["otherIntegrations"] = data.otherIntegrations
if data.selectedAgentSlug and data.selectedAgentCreator:
update["selectedAgentSlug"] = data.selectedAgentSlug
update["selectedAgentCreator"] = data.selectedAgentCreator
# Check if slug changes
if (
user_onboarding.selectedAgentCreator
and user_onboarding.selectedAgentSlug
and user_onboarding.selectedAgentSlug != data.selectedAgentSlug
):
store_agent = await get_store_agent_details(
user_onboarding.selectedAgentCreator, user_onboarding.selectedAgentSlug
)
store_listing = await StoreListingVersion.prisma().find_unique_or_raise(
where={"id": store_agent.store_listing_version_id}
)
agent_graph = await AgentGraph.prisma().find_first(
where={"id": store_listing.agentId, "version": store_listing.version}
)
execution_count = await AgentGraphExecution.prisma().count(
where={
"userId": user_id,
"agentGraphId": store_listing.agentId,
"agentGraphVersion": store_listing.version,
}
)
# If there was no execution and graph doesn't belong to the user,
# mark the agent as deleted
if execution_count == 0 and agent_graph and agent_graph.userId != user_id:
await set_is_deleted_for_library_agent(
user_id, store_listing.agentId, store_listing.agentVersion, True
)
if data.agentInput:
update["agentInput"] = Json(data.agentInput)
return await UserOnboarding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, **update}, # type: ignore
"update": update,
},
)
def clean_and_split(text: str) -> list[str]:
"""
Removes all special characters from a string, truncates it to 100 characters,
and splits it by whitespace and commas.
Args:
text (str): The input string.
Returns:
list[str]: A list of cleaned words.
"""
# Remove all special characters (keep only alphanumeric and whitespace)
cleaned_text = re.sub(r"[^a-zA-Z0-9\s,]", "", text.strip()[:100])
# Split by whitespace and commas
words = re.split(r"[\s,]+", cleaned_text)
# Remove empty strings from the list
words = [word.lower() for word in words if word]
return words
def calculate_points(
agent, categories: list[str], custom: list[str], integrations: list[str]
) -> int:
"""
Calculates the total points for an agent based on the specified criteria.
Args:
agent: The agent object.
categories (list[str]): List of categories to match.
words (list[str]): List of words to match in the description.
Returns:
int: Total points for the agent.
"""
points = 0
# 1. Category Matches
matched_categories = sum(
1 for category in categories if category in agent.categories
)
points += matched_categories * 100
# 2. Description Word Matches
description_words = agent.description.split() # Split description into words
matched_words = sum(1 for word in custom if word in description_words)
points += matched_words * 100
matched_words = sum(1 for word in integrations if word in description_words)
points += matched_words * 50
# 3. Featured Bonus
if agent.featured:
points += 50
# 4. Rating Bonus
points += agent.rating * 10
# 5. Runs Bonus
runs_points = min(agent.runs / 1000 * 100, 100) # Cap at 100 points
points += runs_points
return int(points)
async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
user_onboarding = await get_user_onboarding(user_id)
categories = REASON_MAPPING.get(user_onboarding.usageReason or "", [])
where_clause: dict[str, Any] = {}
custom = clean_and_split((user_onboarding.usageReason or "").lower())
if categories:
where_clause["OR"] = [
{"categories": {"has": category}} for category in categories
]
else:
where_clause["OR"] = [
{"description": {"contains": word, "mode": "insensitive"}}
for word in custom
]
where_clause["OR"] += [
{"description": {"contains": word, "mode": "insensitive"}}
for word in user_onboarding.integrations
]
agents = await prisma.models.StoreAgent.prisma().find_many(
where=prisma.types.StoreAgentWhereInput(**where_clause),
order=[
{"featured": "desc"},
{"runs": "desc"},
{"rating": "desc"},
],
)
if len(agents) < 2:
agents += await prisma.models.StoreAgent.prisma().find_many(
where={
"listing_id": {"not_in": [agent.listing_id for agent in agents]},
},
order=[
{"featured": "desc"},
{"runs": "desc"},
{"rating": "desc"},
],
take=2 - len(agents),
)
# Calculate points for the first 30 agents and choose the top 2
agent_points = []
for agent in agents[:50]:
points = calculate_points(
agent, categories, custom, user_onboarding.integrations
)
agent_points.append((agent, points))
agent_points.sort(key=lambda x: x[1], reverse=True)
recommended_agents = [agent for agent, _ in agent_points[:2]]
return [
StoreAgentDetails(
store_listing_version_id=agent.storeListingVersionId,
slug=agent.slug,
agent_name=agent.agent_name,
agent_video=agent.agent_video or "",
agent_image=agent.agent_image,
creator=agent.creator_username,
creator_avatar=agent.creator_avatar,
sub_heading=agent.sub_heading,
description=agent.description,
categories=agent.categories,
runs=agent.runs,
rating=agent.rating,
versions=agent.versions,
last_updated=agent.updated_at,
)
for agent in recommended_agents
]

View File

@@ -35,7 +35,7 @@ class BaseRedisEventBus(Generic[M], ABC):
def _serialize_message(self, item: M, channel_key: str) -> tuple[str, str]:
message = json.dumps(item.model_dump(), cls=DateTimeEncoder)
channel_name = f"{self.event_bus_name}/{channel_key}"
logger.debug(f"[{channel_name}] Publishing an event to Redis {message}")
logger.info(f"[{channel_name}] Publishing an event to Redis {message}")
return message, channel_name
def _deserialize_message(self, msg: Any, channel_key: str) -> M | None:
@@ -44,7 +44,7 @@ class BaseRedisEventBus(Generic[M], ABC):
return None
try:
data = json.loads(msg["data"])
logger.debug(f"Consuming an event from Redis {data}")
logger.info(f"Consuming an event from Redis {data}")
return self.Model(**data)
except Exception as e:
logger.error(f"Failed to parse event result from Redis {msg} {e}")

View File

@@ -1,10 +1,6 @@
import base64
import hashlib
import hmac
import logging
from datetime import datetime, timedelta
from typing import Optional, cast
from urllib.parse import quote_plus
from autogpt_libs.auth.models import DEFAULT_USER_ID
from fastapi import HTTPException
@@ -18,7 +14,6 @@ from backend.data.model import UserIntegrations, UserMetadata, UserMetadataRaw
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.server.v2.store.exceptions import DatabaseError
from backend.util.encryption import JSONCryptor
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -63,14 +58,6 @@ async def get_user_email_by_id(user_id: str) -> Optional[str]:
raise DatabaseError(f"Failed to get user email for user {user_id}: {e}") from e
async def get_user_by_email(email: str) -> Optional[User]:
try:
user = await prisma.user.find_unique(where={"email": email})
return User.model_validate(user) if user else None
except Exception as e:
raise DatabaseError(f"Failed to get user by email {email}: {e}") from e
async def update_user_email(user_id: str, email: str):
try:
await prisma.user.update(where={"id": user_id}, data={"email": email})
@@ -313,85 +300,3 @@ async def update_user_notification_preference(
raise DatabaseError(
f"Failed to update user notification preference for user {user_id}: {e}"
) from e
async def set_user_email_verification(user_id: str, verified: bool) -> None:
"""Set the email verification status for a user."""
try:
await User.prisma().update(
where={"id": user_id},
data={"emailVerified": verified},
)
except Exception as e:
raise DatabaseError(
f"Failed to set email verification status for user {user_id}: {e}"
) from e
async def get_user_email_verification(user_id: str) -> bool:
"""Get the email verification status for a user."""
try:
user = await User.prisma().find_unique_or_raise(
where={"id": user_id},
)
return user.emailVerified
except Exception as e:
raise DatabaseError(
f"Failed to get email verification status for user {user_id}: {e}"
) from e
def generate_unsubscribe_link(user_id: str) -> str:
"""Generate a link to unsubscribe from all notifications"""
# Create an HMAC using a secret key
secret_key = Settings().secrets.unsubscribe_secret_key
signature = hmac.new(
secret_key.encode("utf-8"), user_id.encode("utf-8"), hashlib.sha256
).digest()
# Create a token that combines the user_id and signature
token = base64.urlsafe_b64encode(
f"{user_id}:{signature.hex()}".encode("utf-8")
).decode("utf-8")
logger.info(f"Generating unsubscribe link for user {user_id}")
base_url = Settings().config.platform_base_url
return f"{base_url}/api/email/unsubscribe?token={quote_plus(token)}"
async def unsubscribe_user_by_token(token: str) -> None:
"""Unsubscribe a user from all notifications using the token"""
try:
# Decode the token
decoded = base64.urlsafe_b64decode(token).decode("utf-8")
user_id, received_signature_hex = decoded.split(":", 1)
# Verify the signature
secret_key = Settings().secrets.unsubscribe_secret_key
expected_signature = hmac.new(
secret_key.encode("utf-8"), user_id.encode("utf-8"), hashlib.sha256
).digest()
if not hmac.compare_digest(expected_signature.hex(), received_signature_hex):
raise ValueError("Invalid token signature")
user = await get_user_by_id(user_id)
await update_user_notification_preference(
user.id,
NotificationPreferenceDTO(
email=user.email,
daily_limit=0,
preferences={
NotificationType.AGENT_RUN: False,
NotificationType.ZERO_BALANCE: False,
NotificationType.LOW_BALANCE: False,
NotificationType.BLOCK_EXECUTION_FAILED: False,
NotificationType.CONTINUOUS_AGENT_ERROR: False,
NotificationType.DAILY_SUMMARY: False,
NotificationType.WEEKLY_SUMMARY: False,
NotificationType.MONTHLY_SUMMARY: False,
},
),
)
except Exception as e:
raise DatabaseError(f"Failed to unsubscribe user by token {token}: {e}") from e

View File

@@ -1,9 +1,9 @@
from .database import DatabaseManager
from .manager import ExecutionManager
from .scheduler import Scheduler
from .scheduler import ExecutionScheduler
__all__ = [
"DatabaseManager",
"ExecutionManager",
"Scheduler",
"ExecutionScheduler",
]

View File

@@ -1,3 +1,6 @@
from functools import wraps
from typing import Any, Callable, Concatenate, Coroutine, ParamSpec, TypeVar, cast
from backend.data.credit import get_user_credit_model
from backend.data.execution import (
ExecutionResult,
@@ -8,44 +11,24 @@ from backend.data.execution import (
get_incomplete_executions,
get_latest_execution,
update_execution_status,
update_graph_execution_start_time,
update_graph_execution_stats,
update_node_execution_stats,
upsert_execution_input,
upsert_execution_output,
)
from backend.data.graph import (
get_connected_output_nodes,
get_graph,
get_graph_metadata,
get_node,
)
from backend.data.notifications import (
create_or_add_to_user_notification_batch,
empty_user_notification_batch,
get_all_batches_by_type,
get_user_notification_batch,
get_user_notification_oldest_message_in_batch,
)
from backend.data.graph import get_graph, get_node
from backend.data.user import (
get_active_user_ids_in_timerange,
get_user_email_by_id,
get_user_email_verification,
get_user_integrations,
get_user_metadata,
get_user_notification_preference,
update_user_integrations,
update_user_metadata,
)
from backend.util.service import AppService, expose, exposed_run_and_wait
from backend.util.service import AppService, expose, register_pydantic_serializers
from backend.util.settings import Config
P = ParamSpec("P")
R = TypeVar("R")
config = Config()
_user_credit_model = get_user_credit_model()
async def _spend_credits(entry: NodeExecutionEntry) -> int:
return await _user_credit_model.spend_credits(entry, 0, 0)
class DatabaseManager(AppService):
@@ -63,15 +46,28 @@ class DatabaseManager(AppService):
def send_execution_update(self, execution_result: ExecutionResult):
self.event_queue.publish(execution_result)
@staticmethod
def exposed_run_and_wait(
f: Callable[P, Coroutine[None, None, R]]
) -> Callable[Concatenate[object, P], R]:
@expose
@wraps(f)
def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
coroutine = f(*args, **kwargs)
res = self.run_and_wait(coroutine)
return res
# Register serializers for annotations on bare function
register_pydantic_serializers(f)
return wrapper
# Executions
create_graph_execution = exposed_run_and_wait(create_graph_execution)
get_execution_results = exposed_run_and_wait(get_execution_results)
get_incomplete_executions = exposed_run_and_wait(get_incomplete_executions)
get_latest_execution = exposed_run_and_wait(get_latest_execution)
update_execution_status = exposed_run_and_wait(update_execution_status)
update_graph_execution_start_time = exposed_run_and_wait(
update_graph_execution_start_time
)
update_graph_execution_stats = exposed_run_and_wait(update_graph_execution_stats)
update_node_execution_stats = exposed_run_and_wait(update_node_execution_stats)
upsert_execution_input = exposed_run_and_wait(upsert_execution_input)
@@ -80,35 +76,16 @@ class DatabaseManager(AppService):
# Graphs
get_node = exposed_run_and_wait(get_node)
get_graph = exposed_run_and_wait(get_graph)
get_connected_output_nodes = exposed_run_and_wait(get_connected_output_nodes)
get_graph_metadata = exposed_run_and_wait(get_graph_metadata)
# Credits
spend_credits = exposed_run_and_wait(_spend_credits)
user_credit_model = get_user_credit_model()
spend_credits = cast(
Callable[[Any, NodeExecutionEntry, float, float], int],
exposed_run_and_wait(user_credit_model.spend_credits),
)
# User + User Metadata + User Integrations
get_user_metadata = exposed_run_and_wait(get_user_metadata)
update_user_metadata = exposed_run_and_wait(update_user_metadata)
get_user_integrations = exposed_run_and_wait(get_user_integrations)
update_user_integrations = exposed_run_and_wait(update_user_integrations)
# User Comms - async
get_active_user_ids_in_timerange = exposed_run_and_wait(
get_active_user_ids_in_timerange
)
get_user_email_by_id = exposed_run_and_wait(get_user_email_by_id)
get_user_email_verification = exposed_run_and_wait(get_user_email_verification)
get_user_notification_preference = exposed_run_and_wait(
get_user_notification_preference
)
# Notifications - async
create_or_add_to_user_notification_batch = exposed_run_and_wait(
create_or_add_to_user_notification_batch
)
empty_user_notification_batch = exposed_run_and_wait(empty_user_notification_batch)
get_all_batches_by_type = exposed_run_and_wait(get_all_batches_by_type)
get_user_notification_batch = exposed_run_and_wait(get_user_notification_batch)
get_user_notification_oldest_message_in_batch = exposed_run_and_wait(
get_user_notification_oldest_message_in_batch
)

View File

@@ -12,15 +12,11 @@ from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar, cast
from redis.lock import Lock as RedisLock
from backend.blocks.basic import AgentOutputBlock
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
AgentRunData,
LowBalanceData,
NotificationEventDTO,
NotificationType,
)
from backend.util.exceptions import InsufficientBalanceError
if TYPE_CHECKING:
from backend.executor import DatabaseManager
@@ -109,10 +105,7 @@ class LogMetadata:
logger.exception(msg, extra={"json_fields": {**self.metadata, **extra}})
def _wrap(self, msg: str, **extra):
extra_msg = str(extra or "")
if len(extra_msg) > 1000:
extra_msg = extra_msg[:1000] + "..."
return f"{self.prefix} {msg} {extra_msg}"
return f"{self.prefix} {msg} {extra}"
T = TypeVar("T")
@@ -122,8 +115,9 @@ ExecutionStream = Generator[NodeExecutionEntry, None, None]
def execute_node(
db_client: "DatabaseManager",
creds_manager: IntegrationCredentialsManager,
notification_service: "NotificationManager",
data: NodeExecutionEntry,
execution_stats: NodeExecutionStats | None = None,
execution_stats: dict[str, Any] | None = None,
) -> ExecutionStream:
"""
Execute a node in the graph. This will trigger a block execution on a node,
@@ -206,10 +200,11 @@ def execute_node(
extra_exec_kwargs[field_name] = credentials
output_size = 0
cost = 0
try:
# Charge the user for the execution before running the block.
cost = db_client.spend_credits(data)
# TODO: We assume the block is executed within 0 seconds.
# This is fine because for now, there is no block that is charged by time.
cost = db_client.spend_credits(data, input_size + output_size, 0)
outputs: dict[str, Any] = {}
for output_name, output_data in node_block.execute(
@@ -233,6 +228,21 @@ def execute_node(
# Update execution status and spend credits
update_execution(ExecutionStatus.COMPLETED)
event = NotificationEventDTO(
user_id=user_id,
type=NotificationType.AGENT_RUN,
data=AgentRunData(
outputs=outputs,
agent_name=node_block.name,
credits_used=cost,
execution_time=0,
graph_id=graph_id,
node_count=1,
).model_dump(),
)
logger.info(f"Sending notification for {event}")
notification_service.queue_notification(event)
except Exception as e:
error_msg = str(e)
@@ -253,7 +263,7 @@ def execute_node(
raise e
finally:
# Ensure credentials are released even if execution fails
if creds_lock and creds_lock.locked():
if creds_lock:
try:
creds_lock.release()
except Exception as e:
@@ -261,12 +271,9 @@ def execute_node(
# Update execution stats
if execution_stats is not None:
execution_stats = execution_stats.model_copy(
update=node_block.execution_stats.model_dump()
)
execution_stats.input_size = input_size
execution_stats.output_size = output_size
execution_stats.cost = cost
execution_stats.update(node_block.execution_stats)
execution_stats["input_size"] = input_size
execution_stats["output_size"] = output_size
def _enqueue_next_nodes(
@@ -419,30 +426,46 @@ def validate_exec(
node_block: Block | None = get_block(node.block_id)
if not node_block:
return None, f"Block for {node.block_id} not found."
schema = node_block.input_schema
# Convert non-matching data types to the expected input schema.
for name, data_type in schema.__annotations__.items():
if (value := data.get(name)) and (type(value) is not data_type):
data[name] = convert(value, data_type)
if isinstance(node_block, AgentExecutorBlock):
# Validate the execution metadata for the agent executor block.
try:
exec_data = AgentExecutorBlock.Input(**node.input_default)
except Exception as e:
return None, f"Input data doesn't match {node_block.name}: {str(e)}"
# Validation input
input_schema = exec_data.input_schema
required_fields = set(input_schema["required"])
input_default = exec_data.data
else:
# Convert non-matching data types to the expected input schema.
for name, data_type in node_block.input_schema.__annotations__.items():
if (value := data.get(name)) and (type(value) is not data_type):
data[name] = convert(value, data_type)
# Validation input
input_schema = node_block.input_schema.jsonschema()
required_fields = node_block.input_schema.get_required_fields()
input_default = node.input_default
# Input data (without default values) should contain all required fields.
error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
if missing_links := schema.get_missing_links(data, node.input_links):
return None, f"{error_prefix} unpopulated links {missing_links}"
input_fields_from_nodes = {link.sink_name for link in node.input_links}
if not input_fields_from_nodes.issubset(data):
return None, f"{error_prefix} {input_fields_from_nodes - set(data)}"
# Merge input data with default values and resolve dynamic dict/list/object pins.
input_default = schema.get_input_defaults(node.input_default)
data = {**input_default, **data}
if resolve_input:
data = merge_execution_input(data)
# Input data post-merge should contain all required fields from the schema.
if missing_input := schema.get_missing_input(data):
return None, f"{error_prefix} missing input {missing_input}"
if not required_fields.issubset(data):
return None, f"{error_prefix} {required_fields - set(data)}"
# Last validation: Validate the input values against the schema.
if error := schema.get_mismatch_error(data):
if error := json.validate_with_jsonschema(schema=input_schema, data=data):
error_message = f"{error_prefix} {error}"
logger.error(error_message)
return None, error_message
@@ -483,6 +506,7 @@ class Executor:
cls.pid = os.getpid()
cls.db_client = get_db_client()
cls.creds_manager = IntegrationCredentialsManager()
cls.notification_service = get_notification_service()
# Set up shutdown handlers
cls.shutdown_lock = threading.Lock()
@@ -523,7 +547,7 @@ class Executor:
cls,
q: ExecutionQueue[NodeExecutionEntry],
node_exec: NodeExecutionEntry,
) -> NodeExecutionStats:
) -> dict[str, Any]:
log_metadata = LogMetadata(
user_id=node_exec.user_id,
graph_eid=node_exec.graph_exec_id,
@@ -533,15 +557,13 @@ class Executor:
block_name="-",
)
execution_stats = NodeExecutionStats()
execution_stats = {}
timing_info, _ = cls._on_node_execution(
q, node_exec, log_metadata, execution_stats
)
execution_stats.walltime = timing_info.wall_time
execution_stats.cputime = timing_info.cpu_time
execution_stats["walltime"] = timing_info.wall_time
execution_stats["cputime"] = timing_info.cpu_time
if isinstance(execution_stats.error, Exception):
execution_stats.error = str(execution_stats.error)
cls.db_client.update_node_execution_stats(
node_exec.node_exec_id, execution_stats
)
@@ -554,13 +576,14 @@ class Executor:
q: ExecutionQueue[NodeExecutionEntry],
node_exec: NodeExecutionEntry,
log_metadata: LogMetadata,
stats: NodeExecutionStats | None = None,
stats: dict[str, Any] | None = None,
):
try:
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
for execution in execute_node(
db_client=cls.db_client,
creds_manager=cls.creds_manager,
notification_service=cls.notification_service,
data=node_exec,
execution_stats=stats,
):
@@ -577,9 +600,6 @@ class Executor:
f"Failed node execution {node_exec.node_exec_id}: {e}"
)
if stats is not None:
stats.error = e
@classmethod
def on_graph_executor_start(cls):
configure_logging()
@@ -588,7 +608,6 @@ class Executor:
cls.db_client = get_db_client()
cls.pool_size = settings.config.num_node_workers
cls.pid = os.getpid()
cls.notification_service = get_notification_service()
cls._init_node_executor_pool()
logger.info(
f"Graph executor {cls.pid} started with {cls.pool_size} node workers"
@@ -626,16 +645,12 @@ class Executor:
node_eid="*",
block_name="-",
)
cls.db_client.update_graph_execution_start_time(graph_exec.graph_exec_id)
timing_info, (exec_stats, status, error) = cls._on_graph_execution(
graph_exec, cancel, log_metadata
)
exec_stats.walltime = timing_info.wall_time
exec_stats.cputime = timing_info.cpu_time
exec_stats.error = error
if isinstance(exec_stats.error, Exception):
exec_stats.error = str(exec_stats.error)
exec_stats["walltime"] = timing_info.wall_time
exec_stats["cputime"] = timing_info.cpu_time
exec_stats["error"] = str(error) if error else None
result = cls.db_client.update_graph_execution_stats(
graph_exec_id=graph_exec.graph_exec_id,
status=status,
@@ -643,8 +658,6 @@ class Executor:
)
cls.db_client.send_execution_update(result)
cls._handle_agent_run_notif(graph_exec, exec_stats)
@classmethod
@time_measured
def _on_graph_execution(
@@ -652,7 +665,7 @@ class Executor:
graph_exec: GraphExecutionEntry,
cancel: threading.Event,
log_metadata: LogMetadata,
) -> tuple[GraphExecutionStats, ExecutionStatus, Exception | None]:
) -> tuple[dict[str, Any], ExecutionStatus, Exception | None]:
"""
Returns:
dict: The execution statistics of the graph execution.
@@ -660,7 +673,11 @@ class Executor:
Exception | None: The error that occurred during the execution, if any.
"""
log_metadata.info(f"Start graph execution {graph_exec.graph_exec_id}")
exec_stats = GraphExecutionStats()
exec_stats = {
"nodes_walltime": 0,
"nodes_cputime": 0,
"node_count": 0,
}
error = None
finished = False
@@ -686,26 +703,17 @@ class Executor:
queue.add(node_exec)
running_executions: dict[str, AsyncResult] = {}
low_balance_error: Optional[InsufficientBalanceError] = None
def make_exec_callback(exec_data: NodeExecutionEntry):
node_id = exec_data.node_id
def callback(result: object):
running_executions.pop(exec_data.node_id)
if not isinstance(result, NodeExecutionStats):
return
nonlocal exec_stats, low_balance_error
exec_stats.node_count += 1
exec_stats.nodes_cputime += result.cputime
exec_stats.nodes_walltime += result.walltime
exec_stats.cost += result.cost
if (err := result.error) and isinstance(err, Exception):
exec_stats.node_error_count += 1
if isinstance(err, InsufficientBalanceError):
low_balance_error = err
running_executions.pop(node_id)
nonlocal exec_stats
if isinstance(result, dict):
exec_stats["node_count"] += 1
exec_stats["nodes_cputime"] += result.get("cputime", 0)
exec_stats["nodes_walltime"] += result.get("walltime", 0)
return callback
@@ -750,16 +758,6 @@ class Executor:
execution.wait(3)
log_metadata.info(f"Finished graph execution {graph_exec.graph_exec_id}")
if isinstance(low_balance_error, InsufficientBalanceError):
cls._handle_low_balance_notif(
graph_exec.user_id,
graph_exec.graph_id,
exec_stats,
low_balance_error,
)
raise low_balance_error
except Exception as e:
log_metadata.exception(
f"Failed graph execution {graph_exec.graph_exec_id}: {e}"
@@ -778,67 +776,6 @@ class Executor:
error,
)
@classmethod
def _handle_agent_run_notif(
cls,
graph_exec: GraphExecutionEntry,
exec_stats: GraphExecutionStats,
):
metadata = cls.db_client.get_graph_metadata(
graph_exec.graph_id, graph_exec.graph_version
)
outputs = cls.db_client.get_execution_results(graph_exec.graph_exec_id)
named_outputs = [
{
key: value[0] if key == "name" else value
for key, value in output.output_data.items()
}
for output in outputs
if output.block_id == AgentOutputBlock().id
]
event = NotificationEventDTO(
user_id=graph_exec.user_id,
type=NotificationType.AGENT_RUN,
data=AgentRunData(
outputs=named_outputs,
agent_name=metadata.name if metadata else "Unknown Agent",
credits_used=exec_stats.cost,
execution_time=exec_stats.walltime,
graph_id=graph_exec.graph_id,
node_count=exec_stats.node_count,
).model_dump(),
)
cls.notification_service.queue_notification(event)
@classmethod
def _handle_low_balance_notif(
cls,
user_id: str,
graph_id: str,
exec_stats: GraphExecutionStats,
e: InsufficientBalanceError,
):
shortfall = e.balance - e.amount
metadata = cls.db_client.get_graph_metadata(graph_id)
base_url = (
settings.config.frontend_base_url or settings.config.platform_base_url
)
cls.notification_service.queue_notification(
NotificationEventDTO(
user_id=user_id,
type=NotificationType.LOW_BALANCE,
data=LowBalanceData(
current_balance=exec_stats.cost,
billing_page_link=f"{base_url}/profile/credits",
shortfall=shortfall,
agent_name=metadata.name if metadata else "Unknown Agent",
).model_dump(),
)
)
class ExecutionManager(AppService):
def __init__(self):
@@ -941,11 +878,6 @@ class ExecutionManager(AppService):
else:
nodes_input.append((node.id, input_data))
if not nodes_input:
raise ValueError(
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
)
graph_exec_id, node_execs = self.db_client.create_graph_execution(
graph_id=graph_id,
graph_version=graph.version,

View File

@@ -1,23 +1,19 @@
import logging
import os
from enum import Enum
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from apscheduler.events import EVENT_JOB_ERROR, EVENT_JOB_EXECUTED
from apscheduler.job import Job as JobObj
from apscheduler.jobstores.memory import MemoryJobStore
from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
from apscheduler.schedulers.blocking import BlockingScheduler
from apscheduler.triggers.cron import CronTrigger
from autogpt_libs.utils.cache import thread_cached
from dotenv import load_dotenv
from prisma.enums import NotificationType
from pydantic import BaseModel
from sqlalchemy import MetaData, create_engine
from backend.data.block import BlockInput
from backend.executor.manager import ExecutionManager
from backend.notifications.notifications import NotificationManager
from backend.util.service import AppService, expose, get_service_client
from backend.util.settings import Config
@@ -46,7 +42,7 @@ config = Config()
def log(msg, **kwargs):
logger.info("[Scheduler] " + msg, **kwargs)
logger.info("[ExecutionScheduler] " + msg, **kwargs)
def job_listener(event):
@@ -62,15 +58,8 @@ def get_execution_client() -> ExecutionManager:
return get_service_client(ExecutionManager)
@thread_cached
def get_notification_client():
from backend.notifications import NotificationManager
return get_service_client(NotificationManager)
def execute_graph(**kwargs):
args = ExecutionJobArgs(**kwargs)
args = JobArgs(**kwargs)
try:
log(f"Executing recurring job for graph #{args.graph_id}")
get_execution_client().add_execution(
@@ -83,32 +72,7 @@ def execute_graph(**kwargs):
logger.exception(f"Error executing graph {args.graph_id}: {e}")
def process_existing_batches(**kwargs):
args = NotificationJobArgs(**kwargs)
try:
log(
f"Processing existing batches for notification type {args.notification_types}"
)
get_notification_client().process_existing_batches(args.notification_types)
except Exception as e:
logger.exception(f"Error processing existing batches: {e}")
def process_weekly_summary(**kwargs):
try:
log("Processing weekly summary")
get_notification_client().queue_weekly_summary()
except Exception as e:
logger.exception(f"Error processing weekly summary: {e}")
class Jobstores(Enum):
EXECUTION = "execution"
BATCHED_NOTIFICATIONS = "batched_notifications"
WEEKLY_NOTIFICATIONS = "weekly_notifications"
class ExecutionJobArgs(BaseModel):
class JobArgs(BaseModel):
graph_id: str
input_data: BlockInput
user_id: str
@@ -116,14 +80,14 @@ class ExecutionJobArgs(BaseModel):
cron: str
class ExecutionJobInfo(ExecutionJobArgs):
class JobInfo(JobArgs):
id: str
name: str
next_run_time: str
@staticmethod
def from_db(job_args: ExecutionJobArgs, job_obj: JobObj) -> "ExecutionJobInfo":
return ExecutionJobInfo(
def from_db(job_args: JobArgs, job_obj: JobObj) -> "JobInfo":
return JobInfo(
id=job_obj.id,
name=job_obj.name,
next_run_time=job_obj.next_run_time.isoformat(),
@@ -131,29 +95,7 @@ class ExecutionJobInfo(ExecutionJobArgs):
)
class NotificationJobArgs(BaseModel):
notification_types: list[NotificationType]
cron: str
class NotificationJobInfo(NotificationJobArgs):
id: str
name: str
next_run_time: str
@staticmethod
def from_db(
job_args: NotificationJobArgs, job_obj: JobObj
) -> "NotificationJobInfo":
return NotificationJobInfo(
id=job_obj.id,
name=job_obj.name,
next_run_time=job_obj.next_run_time.isoformat(),
**job_args.model_dump(),
)
class Scheduler(AppService):
class ExecutionScheduler(AppService):
scheduler: BlockingScheduler
@classmethod
@@ -169,38 +111,19 @@ class Scheduler(AppService):
def execution_client(self) -> ExecutionManager:
return get_service_client(ExecutionManager)
@property
@thread_cached
def notification_client(self) -> NotificationManager:
return get_service_client(NotificationManager)
def run_service(self):
load_dotenv()
db_schema, db_url = _extract_schema_from_url(os.getenv("DATABASE_URL"))
self.scheduler = BlockingScheduler(
jobstores={
Jobstores.EXECUTION.value: SQLAlchemyJobStore(
"default": SQLAlchemyJobStore(
engine=create_engine(
url=db_url,
pool_size=self.db_pool_size(),
max_overflow=0,
),
metadata=MetaData(schema=db_schema),
# this one is pre-existing so it keeps the
# default table name.
tablename="apscheduler_jobs",
),
Jobstores.BATCHED_NOTIFICATIONS.value: SQLAlchemyJobStore(
engine=create_engine(
url=db_url,
pool_size=self.db_pool_size(),
max_overflow=0,
),
metadata=MetaData(schema=db_schema),
tablename="apscheduler_jobs_batched_notifications",
),
# These don't really need persistence
Jobstores.WEEKLY_NOTIFICATIONS.value: MemoryJobStore(),
)
}
)
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
@@ -214,8 +137,8 @@ class Scheduler(AppService):
cron: str,
input_data: BlockInput,
user_id: str,
) -> ExecutionJobInfo:
job_args = ExecutionJobArgs(
) -> JobInfo:
job_args = JobArgs(
graph_id=graph_id,
input_data=input_data,
user_id=user_id,
@@ -227,80 +150,37 @@ class Scheduler(AppService):
CronTrigger.from_crontab(cron),
kwargs=job_args.model_dump(),
replace_existing=True,
jobstore=Jobstores.EXECUTION.value,
)
log(f"Added job {job.id} with cron schedule '{cron}' input data: {input_data}")
return ExecutionJobInfo.from_db(job_args, job)
return JobInfo.from_db(job_args, job)
@expose
def delete_schedule(self, schedule_id: str, user_id: str) -> ExecutionJobInfo:
job = self.scheduler.get_job(schedule_id, jobstore=Jobstores.EXECUTION.value)
def delete_schedule(self, schedule_id: str, user_id: str) -> JobInfo:
job = self.scheduler.get_job(schedule_id)
if not job:
log(f"Job {schedule_id} not found.")
raise ValueError(f"Job #{schedule_id} not found.")
job_args = ExecutionJobArgs(**job.kwargs)
job_args = JobArgs(**job.kwargs)
if job_args.user_id != user_id:
raise ValueError("User ID does not match the job's user ID.")
log(f"Deleting job {schedule_id}")
job.remove()
return ExecutionJobInfo.from_db(job_args, job)
return JobInfo.from_db(job_args, job)
@expose
def get_execution_schedules(
self, graph_id: str | None = None, user_id: str | None = None
) -> list[ExecutionJobInfo]:
) -> list[JobInfo]:
schedules = []
for job in self.scheduler.get_jobs(jobstore=Jobstores.EXECUTION.value):
logger.info(
f"Found job {job.id} with cron schedule {job.trigger} and args {job.kwargs}"
)
job_args = ExecutionJobArgs(**job.kwargs)
for job in self.scheduler.get_jobs():
job_args = JobArgs(**job.kwargs)
if (
job.next_run_time is not None
and (graph_id is None or job_args.graph_id == graph_id)
and (user_id is None or job_args.user_id == user_id)
):
schedules.append(ExecutionJobInfo.from_db(job_args, job))
schedules.append(JobInfo.from_db(job_args, job))
return schedules
@expose
def add_batched_notification_schedule(
self,
notification_types: list[NotificationType],
data: dict,
cron: str,
) -> NotificationJobInfo:
job_args = NotificationJobArgs(
notification_types=notification_types,
cron=cron,
)
job = self.scheduler.add_job(
process_existing_batches,
CronTrigger.from_crontab(cron),
kwargs=job_args.model_dump(),
replace_existing=True,
jobstore=Jobstores.BATCHED_NOTIFICATIONS.value,
)
log(f"Added job {job.id} with cron schedule '{cron}' input data: {data}")
return NotificationJobInfo.from_db(job_args, job)
@expose
def add_weekly_notification_schedule(self, cron: str) -> NotificationJobInfo:
job = self.scheduler.add_job(
process_weekly_summary,
CronTrigger.from_crontab(cron),
kwargs={},
replace_existing=True,
jobstore=Jobstores.WEEKLY_NOTIFICATIONS.value,
)
log(f"Added job {job.id} with cron schedule '{cron}'")
return NotificationJobInfo.from_db(
NotificationJobArgs(
cron=cron, notification_types=[NotificationType.WEEKLY_SUMMARY]
),
job,
)

View File

@@ -169,16 +169,7 @@ zerobounce_credentials = APIKeyCredentials(
expires_at=None,
)
example_credentials = APIKeyCredentials(
id="a2b7f68f-aa6a-4995-99ec-b45b40d33498",
provider="example-provider",
api_key=SecretStr(settings.secrets.example_api_key),
title="Use Credits for Example",
expires_at=None,
)
DEFAULT_CREDENTIALS = [
example_credentials,
ollama_credentials,
revid_credentials,
ideogram_credentials,
@@ -234,8 +225,6 @@ class IntegrationCredentialsStore:
all_credentials.append(ollama_credentials)
# These will only be added if the API key is set
if settings.secrets.example_api_key:
all_credentials.append(example_credentials)
if settings.secrets.revid_api_key:
all_credentials.append(revid_credentials)
if settings.secrets.ideogram_api_key:

View File

@@ -10,7 +10,6 @@ class ProviderName(str, Enum):
D_ID = "d_id"
E2B = "e2b"
EXA = "exa"
EXAMPLE_PROVIDER = "example-provider"
FAL = "fal"
GITHUB = "github"
GOOGLE = "google"

View File

@@ -1,7 +1,6 @@
from typing import TYPE_CHECKING
from .compass import CompassWebhookManager
from .example import ExampleWebhookManager
from .github import GithubWebhooksManager
from .slant3d import Slant3DWebhooksManager
@@ -16,7 +15,6 @@ WEBHOOK_MANAGERS_BY_NAME: dict["ProviderName", type["BaseWebhooksManager"]] = {
CompassWebhookManager,
GithubWebhooksManager,
Slant3DWebhooksManager,
ExampleWebhookManager,
]
}
# --8<-- [end:WEBHOOK_MANAGERS_BY_NAME]

View File

@@ -1,147 +0,0 @@
import logging
import requests
from fastapi import Request
from strenum import StrEnum
from backend.data import integrations
from backend.data.model import APIKeyCredentials, Credentials
from backend.integrations.providers import ProviderName
from ._manual_base import ManualWebhookManagerBase
logger = logging.getLogger(__name__)
class ExampleWebhookEventType(StrEnum):
EXAMPLE_EVENT = "example_event"
ANOTHER_EXAMPLE_EVENT = "another_example_event"
# ExampleWebhookManager is a class that manages webhooks for a hypothetical provider.
# It extends ManualWebhookManagerBase, which provides base functionality for manual webhook management.
class ExampleWebhookManager(ManualWebhookManagerBase):
# Define the provider name for this webhook manager.
PROVIDER_NAME = ProviderName.EXAMPLE_PROVIDER
# Define the types of webhooks this manager can handle.
WebhookEventType = ExampleWebhookEventType
BASE_URL = "https://api.example.com"
@classmethod
async def validate_payload(
cls, webhook: integrations.Webhook, request: Request
) -> tuple[dict, str]:
"""
Validate the incoming webhook payload.
Args:
webhook (integrations.Webhook): The webhook object.
request (Request): The incoming request object.
Returns:
tuple: A tuple containing the payload as a dictionary and the event type as a string.
"""
# Extract the JSON payload from the request.
payload = await request.json()
# Set the event type based on the webhook type in the payload.
event_type = payload.get("webhook_type", ExampleWebhookEventType.EXAMPLE_EVENT)
# For the payload its better to return a pydantic model
# rather than a weakly typed dict here
return payload, event_type
async def _register_webhook(
self,
credentials: Credentials,
webhook_type: str,
resource: str,
events: list[str],
ingress_url: str,
secret: str,
) -> tuple[str, dict]:
"""
Register a new webhook with the provider.
Args:
credentials (Credentials): The credentials required for authentication.
webhook_type (str): The type of webhook to register.
resource (str): The resource associated with the webhook.
events (list[str]): The list of events to subscribe to.
ingress_url (str): The URL where the webhook will send data.
secret (str): A secret for securing the webhook.
Returns:
tuple: A tuple containing the provider's webhook ID, if any, and the webhook configuration as a dictionary.
"""
# Ensure the credentials are of the correct type.
if not isinstance(credentials, APIKeyCredentials):
raise ValueError("API key is required to register a webhook")
# Prepare the headers for the request, including the API key.
headers = {
"api-key": credentials.api_key.get_secret_value(),
"Content-Type": "application/json",
}
# Prepare the payload for the request. Note that the events list is not used.
# This is just a fake example
payload = {"endPoint": ingress_url}
# Send a POST request to register the webhook.
response = requests.post(
f"{self.BASE_URL}/example/webhookSubscribe", headers=headers, json=payload
)
# Check if the response indicates a failure.
if not response.ok:
error = response.json().get("error", "Unknown error")
raise RuntimeError(f"Failed to register webhook: {error}")
# Prepare the webhook configuration to return.
webhook_config = {
"endpoint": ingress_url,
"provider": self.PROVIDER_NAME,
"events": ["example_event"],
"type": webhook_type,
}
return "", webhook_config
async def _deregister_webhook(
self, webhook: integrations.Webhook, credentials: Credentials
) -> None:
"""
Deregister a webhook with the provider.
Args:
webhook (integrations.Webhook): The webhook object to deregister.
credentials (Credentials): The credentials associated with the webhook.
Raises:
ValueError: If the webhook doesn't belong to the credentials or if deregistration fails.
"""
if webhook.credentials_id != credentials.id:
raise ValueError(
f"Webhook #{webhook.id} does not belong to credentials {credentials.id}"
)
if not isinstance(credentials, APIKeyCredentials):
raise ValueError("API key is required to deregister a webhook")
headers = {
"api-key": credentials.api_key.get_secret_value(),
"Content-Type": "application/json",
}
# Construct the delete URL based on the webhook information
delete_url = f"{self.BASE_URL}/example/webhooks/{webhook.provider_webhook_id}"
response = requests.delete(delete_url, headers=headers)
if response.status_code not in [204, 404]:
# 204 means successful deletion, 404 means the webhook was already deleted
error = response.json().get("error", "Unknown error")
raise ValueError(f"Failed to delete webhook: {error}")
# If we reach here, the webhook was successfully deleted or didn't exist

View File

@@ -7,9 +7,9 @@ from prisma.enums import NotificationType
from pydantic import BaseModel
from backend.data.notifications import (
NotificationDataType_co,
NotificationEventModel,
NotificationTypeOverride,
T_co,
)
from backend.util.settings import Settings
from backend.util.text import TextFormatter
@@ -48,11 +48,7 @@ class EmailSender:
self,
notification: NotificationType,
user_email: str,
data: (
NotificationEventModel[NotificationDataType_co]
| list[NotificationEventModel[NotificationDataType_co]]
),
user_unsub_link: str | None = None,
data: NotificationEventModel[T_co] | list[NotificationEventModel[T_co]],
):
"""Send an email to a user using a template pulled from the notification type"""
if not self.postmark:
@@ -60,34 +56,20 @@ class EmailSender:
return
template = self._get_template(notification)
base_url = (
settings.config.frontend_base_url or settings.config.platform_base_url
)
# Handle the case when data is a list
template_data = data
if isinstance(data, list):
# Create a dictionary with a 'notifications' key containing the list
template_data = {"notifications": data}
try:
subject, full_message = self.formatter.format_email(
base_template=template.base_template,
subject_template=template.subject_template,
content_template=template.body_template,
data=template_data,
unsubscribe_link=f"{base_url}/profile/settings",
data=data,
unsubscribe_link="https://autogpt.com/unsubscribe",
)
except Exception as e:
logger.error(f"Error formatting full message: {e}")
raise e
self._send_email(
user_email=user_email,
user_unsubscribe_link=user_unsub_link,
subject=subject,
body=full_message,
)
self._send_email(user_email, subject, full_message)
def _get_template(self, notification: NotificationType):
# convert the notification type to a notification type override
@@ -108,13 +90,7 @@ class EmailSender:
base_template=base_template,
)
def _send_email(
self,
user_email: str,
subject: str,
body: str,
user_unsubscribe_link: str | None = None,
):
def _send_email(self, user_email: str, subject: str, body: str):
if not self.postmark:
logger.warning("Email tried to send without postmark configured")
return
@@ -124,13 +100,4 @@ class EmailSender:
To=user_email,
Subject=subject,
HtmlBody=body,
# Headers default to None internally so this is fine
Headers=(
{
"List-Unsubscribe-Post": "List-Unsubscribe=One-Click",
"List-Unsubscribe": f"<{user_unsubscribe_link}>",
}
if user_unsubscribe_link
else None
),
)

View File

@@ -1,36 +1,23 @@
import logging
import time
from datetime import datetime, timedelta, timezone
from typing import Callable
import aio_pika
from aio_pika.exceptions import QueueEmpty
from autogpt_libs.utils.cache import thread_cached
from prisma.enums import NotificationType
from pydantic import BaseModel
from backend.data.notifications import (
BaseSummaryData,
BaseSummaryParams,
DailySummaryData,
DailySummaryParams,
BatchingStrategy,
NotificationEventDTO,
NotificationEventModel,
NotificationResult,
NotificationTypeOverride,
QueueType,
SummaryParamsEventDTO,
SummaryParamsEventModel,
WeeklySummaryData,
WeeklySummaryParams,
get_batch_delay,
get_notif_data_type,
get_summary_params_type,
get_data_type,
)
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.data.user import generate_unsubscribe_link
from backend.data.user import get_user_email_by_id, get_user_notification_preference
from backend.notifications.email import EmailSender
from backend.util.service import AppService, expose, get_service_client
from backend.util.service import AppService, expose
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -59,35 +46,6 @@ def create_notification_config() -> RabbitMQConfig:
"x-dead-letter-routing-key": "failed.immediate",
},
),
Queue(
name="admin_notifications",
exchange=notification_exchange,
routing_key="notification.admin.#",
arguments={
"x-dead-letter-exchange": dead_letter_exchange.name,
"x-dead-letter-routing-key": "failed.admin",
},
),
# Summary notification queues
Queue(
name="summary_notifications",
exchange=notification_exchange,
routing_key="notification.summary.#",
arguments={
"x-dead-letter-exchange": dead_letter_exchange.name,
"x-dead-letter-routing-key": "failed.summary",
},
),
# Batch Queue
Queue(
name="batch_notifications",
exchange=notification_exchange,
routing_key="notification.batch.#",
arguments={
"x-dead-letter-exchange": dead_letter_exchange.name,
"x-dead-letter-routing-key": "failed.batch",
},
),
# Failed notifications queue
Queue(
name="failed_notifications",
@@ -105,25 +63,12 @@ def create_notification_config() -> RabbitMQConfig:
)
@thread_cached
def get_scheduler():
from backend.executor import Scheduler
return get_service_client(Scheduler)
@thread_cached
def get_db():
from backend.executor.database import DatabaseManager
return get_service_client(DatabaseManager)
class NotificationManager(AppService):
"""Service for handling notifications with batching support"""
def __init__(self):
super().__init__()
self.use_db = True
self.rabbitmq_config = create_notification_config()
self.running = True
self.email_sender = EmailSender()
@@ -132,165 +77,13 @@ class NotificationManager(AppService):
def get_port(cls) -> int:
return settings.config.notification_service_port
def get_routing_key(self, event_type: NotificationType) -> str:
strategy = NotificationTypeOverride(event_type).strategy
def get_routing_key(self, event: NotificationEventModel) -> str:
"""Get the appropriate routing key for an event"""
if strategy == QueueType.IMMEDIATE:
return f"notification.immediate.{event_type.value}"
elif strategy == QueueType.BACKOFF:
return f"notification.backoff.{event_type.value}"
elif strategy == QueueType.ADMIN:
return f"notification.admin.{event_type.value}"
elif strategy == QueueType.BATCH:
return f"notification.batch.{event_type.value}"
elif strategy == QueueType.SUMMARY:
return f"notification.summary.{event_type.value}"
return f"notification.{event_type.value}"
@expose
def queue_weekly_summary(self):
"""Process weekly summary for specified notification types"""
try:
logger.info("Processing weekly summary queuing operation")
processed_count = 0
current_time = datetime.now(tz=timezone.utc)
start_time = current_time - timedelta(days=7)
users = get_db().get_active_user_ids_in_timerange(
end_time=current_time.isoformat(),
start_time=start_time.isoformat(),
)
for user in users:
self._queue_scheduled_notification(
SummaryParamsEventDTO(
user_id=user,
type=NotificationType.WEEKLY_SUMMARY,
data=WeeklySummaryParams(
start_date=start_time,
end_date=current_time,
).model_dump(),
),
)
processed_count += 1
logger.info(f"Processed {processed_count} weekly summaries into queue")
except Exception as e:
logger.exception(f"Error processing weekly summary: {e}")
@expose
def process_existing_batches(self, notification_types: list[NotificationType]):
"""Process existing batches for specified notification types"""
try:
processed_count = 0
current_time = datetime.now(tz=timezone.utc)
for notification_type in notification_types:
# Get all batches for this notification type
batches = get_db().get_all_batches_by_type(notification_type)
for batch in batches:
# Check if batch has aged out
oldest_message = (
get_db().get_user_notification_oldest_message_in_batch(
batch.user_id, notification_type
)
)
if not oldest_message:
# this should never happen
logger.error(
f"Batch for user {batch.user_id} and type {notification_type} has no oldest message whichshould never happen!!!!!!!!!!!!!!!!"
)
continue
max_delay = get_batch_delay(notification_type)
# If batch has aged out, process it
if oldest_message.created_at + max_delay < current_time:
recipient_email = get_db().get_user_email_by_id(batch.user_id)
if not recipient_email:
logger.error(
f"User email not found for user {batch.user_id}"
)
continue
should_send = self._should_email_user_based_on_preference(
batch.user_id, notification_type
)
if not should_send:
logger.debug(
f"User {batch.user_id} does not want to receive {notification_type} notifications"
)
# Clear the batch
get_db().empty_user_notification_batch(
batch.user_id, notification_type
)
continue
batch_data = get_db().get_user_notification_batch(
batch.user_id, notification_type
)
if not batch_data or not batch_data.notifications:
logger.error(
f"Batch data not found for user {batch.user_id}"
)
# Clear the batch
get_db().empty_user_notification_batch(
batch.user_id, notification_type
)
continue
unsub_link = generate_unsubscribe_link(batch.user_id)
events = [
NotificationEventModel[
get_notif_data_type(db_event.type)
].model_validate(
{
"user_id": batch.user_id,
"type": db_event.type,
"data": db_event.data,
"created_at": db_event.created_at,
}
)
for db_event in batch_data.notifications
]
logger.info(f"{events=}")
self.email_sender.send_templated(
notification=notification_type,
user_email=recipient_email,
data=events,
user_unsub_link=unsub_link,
)
# Clear the batch
get_db().empty_user_notification_batch(
batch.user_id, notification_type
)
processed_count += 1
logger.info(f"Processed {processed_count} aged batches")
return {
"success": True,
"processed_count": processed_count,
"notification_types": [nt.value for nt in notification_types],
"timestamp": current_time.isoformat(),
}
except Exception as e:
logger.exception(f"Error processing batches: {e}")
return {
"success": False,
"error": str(e),
"notification_types": [nt.value for nt in notification_types],
"timestamp": datetime.now(tz=timezone.utc).isoformat(),
}
if event.strategy == BatchingStrategy.IMMEDIATE:
return f"notification.immediate.{event.type.value}"
elif event.strategy == BatchingStrategy.BACKOFF:
return f"notification.backoff.{event.type.value}"
return f"notification.{event.type.value}"
@expose
def queue_notification(self, event: NotificationEventDTO) -> NotificationResult:
@@ -299,9 +92,9 @@ class NotificationManager(AppService):
logger.info(f"Received Request to queue {event=}")
# Workaround for not being able to serialize generics over the expose bus
parsed_event = NotificationEventModel[
get_notif_data_type(event.type)
get_data_type(event.type)
].model_validate(event.model_dump())
routing_key = self.get_routing_key(parsed_event.type)
routing_key = self.get_routing_key(parsed_event)
message = parsed_event.model_dump_json()
logger.info(f"Received Request to queue {message=}")
@@ -328,182 +121,24 @@ class NotificationManager(AppService):
logger.exception(f"Error queueing notification: {e}")
return NotificationResult(success=False, message=str(e))
def _queue_scheduled_notification(self, event: SummaryParamsEventDTO):
"""Queue a scheduled notification - exposed method for other services to call"""
try:
logger.info(f"Received Request to queue scheduled notification {event=}")
parsed_event = SummaryParamsEventModel[
get_summary_params_type(event.type)
].model_validate(event.model_dump())
routing_key = self.get_routing_key(event.type)
message = parsed_event.model_dump_json()
logger.info(f"Received Request to queue {message=}")
exchange = "notifications"
# Publish to RabbitMQ
self.run_and_wait(
self.rabbit.publish_message(
routing_key=routing_key,
message=message,
exchange=next(
ex for ex in self.rabbit_config.exchanges if ex.name == exchange
),
)
)
except Exception as e:
logger.exception(f"Error queueing notification: {e}")
def _should_email_user_based_on_preference(
self, user_id: str, event_type: NotificationType
) -> bool:
"""Check if a user wants to receive a notification based on their preferences and email verification status"""
validated_email = get_db().get_user_email_verification(user_id)
preference = (
get_db()
.get_user_notification_preference(user_id)
.preferences.get(event_type, True)
)
# only if both are true, should we email this person
return validated_email and preference
def _gather_summary_data(
self, user_id: str, event_type: NotificationType, params: BaseSummaryParams
) -> BaseSummaryData:
"""Gathers the data to build a summary notification"""
logger.info(
f"Gathering summary data for {user_id} and {event_type} wiht {params=}"
)
# total_credits_used = self.run_and_wait(
# get_total_credits_used(user_id, start_time, end_time)
# )
# total_executions = self.run_and_wait(
# get_total_executions(user_id, start_time, end_time)
# )
# most_used_agent = self.run_and_wait(
# get_most_used_agent(user_id, start_time, end_time)
# )
# execution_times = self.run_and_wait(
# get_execution_time(user_id, start_time, end_time)
# )
# runs = self.run_and_wait(
# get_runs(user_id, start_time, end_time)
# )
total_credits_used = 3.0
total_executions = 2
most_used_agent = {"name": "Some"}
execution_times = [1, 2, 3]
runs = [{"status": "COMPLETED"}, {"status": "FAILED"}]
successful_runs = len([run for run in runs if run["status"] == "COMPLETED"])
failed_runs = len([run for run in runs if run["status"] != "COMPLETED"])
average_execution_time = (
sum(execution_times) / len(execution_times) if execution_times else 0
)
# cost_breakdown = self.run_and_wait(
# get_cost_breakdown(user_id, start_time, end_time)
# )
cost_breakdown = {
"agent1": 1.0,
"agent2": 2.0,
}
if event_type == NotificationType.DAILY_SUMMARY and isinstance(
params, DailySummaryParams
):
return DailySummaryData(
total_credits_used=total_credits_used,
total_executions=total_executions,
most_used_agent=most_used_agent["name"],
total_execution_time=sum(execution_times),
successful_runs=successful_runs,
failed_runs=failed_runs,
average_execution_time=average_execution_time,
cost_breakdown=cost_breakdown,
date=params.date,
)
elif event_type == NotificationType.WEEKLY_SUMMARY and isinstance(
params, WeeklySummaryParams
):
return WeeklySummaryData(
total_credits_used=total_credits_used,
total_executions=total_executions,
most_used_agent=most_used_agent["name"],
total_execution_time=sum(execution_times),
successful_runs=successful_runs,
failed_runs=failed_runs,
average_execution_time=average_execution_time,
cost_breakdown=cost_breakdown,
start_date=params.start_date,
end_date=params.end_date,
)
else:
raise ValueError("Invalid event type or params")
def _should_batch(
self, user_id: str, event_type: NotificationType, event: NotificationEventModel
) -> bool:
get_db().create_or_add_to_user_notification_batch(user_id, event_type, event)
oldest_message = get_db().get_user_notification_oldest_message_in_batch(
user_id, event_type
)
if not oldest_message:
logger.error(
f"Batch for user {user_id} and type {event_type} has no oldest message whichshould never happen!!!!!!!!!!!!!!!!"
)
return False
oldest_age = oldest_message.created_at
max_delay = get_batch_delay(event_type)
if oldest_age + max_delay < datetime.now(tz=timezone.utc):
logger.info(f"Batch for user {user_id} and type {event_type} is old enough")
return True
logger.info(
f"Batch for user {user_id} and type {event_type} is not old enough: {oldest_age + max_delay} < {datetime.now(tz=timezone.utc)} max_delay={max_delay}"
)
return False
return self.run_and_wait(
get_user_notification_preference(user_id)
).preferences.get(event_type, True)
def _parse_message(self, message: str) -> NotificationEvent | None:
try:
event = NotificationEventDTO.model_validate_json(message)
model = NotificationEventModel[
get_notif_data_type(event.type)
get_data_type(event.type)
].model_validate_json(message)
return NotificationEvent(event=event, model=model)
except Exception as e:
logger.error(f"Error parsing message due to non matching schema {e}")
return None
def _process_admin_message(self, message: str) -> bool:
"""Process a single notification, sending to an admin, returning whether to put into the failed queue"""
try:
parsed = self._parse_message(message)
if not parsed:
return False
event = parsed.event
model = parsed.model
logger.debug(f"Processing notification for admin: {model}")
recipient_email = settings.config.refund_notification_email
self.email_sender.send_templated(event.type, recipient_email, model)
return True
except Exception as e:
logger.exception(f"Error processing notification for admin queue: {e}")
return False
def _process_immediate(self, message: str) -> bool:
"""Process a single notification immediately, returning whether to put into the failed queue"""
try:
@@ -512,9 +147,11 @@ class NotificationManager(AppService):
return False
event = parsed.event
model = parsed.model
logger.debug(f"Processing immediate notification: {model}")
recipient_email = get_db().get_user_email_by_id(event.user_id)
if event.recipient_email:
recipient_email = event.recipient_email
else:
recipient_email = self.run_and_wait(get_user_email_by_id(event.user_id))
if not recipient_email:
logger.error(f"User email not found for user {event.user_id}")
return False
@@ -528,126 +165,11 @@ class NotificationManager(AppService):
)
return True
unsub_link = generate_unsubscribe_link(event.user_id)
self.email_sender.send_templated(
notification=event.type,
user_email=recipient_email,
data=model,
user_unsub_link=unsub_link,
)
self.email_sender.send_templated(event.type, recipient_email, model)
logger.info(f"Processing notification: {model}")
return True
except Exception as e:
logger.exception(f"Error processing notification for immediate queue: {e}")
return False
def _process_batch(self, message: str) -> bool:
"""Process a single notification with a batching strategy, returning whether to put into the failed queue"""
try:
parsed = self._parse_message(message)
if not parsed:
return False
event = parsed.event
model = parsed.model
logger.info(f"Processing batch notification: {model}")
recipient_email = get_db().get_user_email_by_id(event.user_id)
if not recipient_email:
logger.error(f"User email not found for user {event.user_id}")
return False
should_send = self._should_email_user_based_on_preference(
event.user_id, event.type
)
if not should_send:
logger.info(
f"User {event.user_id} does not want to receive {event.type} notifications"
)
return True
should_send = self._should_batch(event.user_id, event.type, model)
if not should_send:
logger.info("Batch not old enough to send")
return False
batch = get_db().get_user_notification_batch(event.user_id, event.type)
if not batch or not batch.notifications:
logger.error(f"Batch not found for user {event.user_id}")
return False
unsub_link = generate_unsubscribe_link(event.user_id)
batch_messages = [
NotificationEventModel[
get_notif_data_type(db_event.type)
].model_validate(
{
"user_id": event.user_id,
"type": db_event.type,
"data": db_event.data,
"created_at": db_event.created_at,
}
)
for db_event in batch.notifications
]
self.email_sender.send_templated(
notification=event.type,
user_email=recipient_email,
data=batch_messages,
user_unsub_link=unsub_link,
)
# only empty the batch if we sent the email successfully
get_db().empty_user_notification_batch(event.user_id, event.type)
return True
except Exception as e:
logger.exception(f"Error processing notification for batch queue: {e}")
return False
def _process_summary(self, message: str) -> bool:
"""Process a single notification with a summary strategy, returning whether to put into the failed queue"""
try:
logger.info(f"Processing summary notification: {message}")
event = SummaryParamsEventDTO.model_validate_json(message)
model = SummaryParamsEventModel[
get_summary_params_type(event.type)
].model_validate_json(message)
logger.info(f"Processing summary notification: {model}")
recipient_email = get_db().get_user_email_by_id(event.user_id)
if not recipient_email:
logger.error(f"User email not found for user {event.user_id}")
return False
should_send = self._should_email_user_based_on_preference(
event.user_id, event.type
)
if not should_send:
logger.info(
f"User {event.user_id} does not want to receive {event.type} notifications"
)
return True
summary_data = self._gather_summary_data(
event.user_id, event.type, model.data
)
unsub_link = generate_unsubscribe_link(event.user_id)
data = NotificationEventModel(
user_id=event.user_id,
type=event.type,
data=summary_data,
)
self.email_sender.send_templated(
notification=event.type,
user_email=recipient_email,
data=data,
user_unsub_link=unsub_link,
)
return True
except Exception as e:
logger.exception(f"Error processing notification for summary queue: {e}")
logger.exception(f"Error processing notification: {e}")
return False
def _run_queue(
@@ -682,33 +204,12 @@ class NotificationManager(AppService):
def run_service(self):
logger.info(f"[{self.service_name}] Started notification service")
# Set up scheduler for batch processing of all notification types
# this can be changed later to spawn differnt cleanups on different schedules
try:
get_scheduler().add_batched_notification_schedule(
notification_types=list(NotificationType),
data={},
cron="0 * * * *",
)
# get_scheduler().add_weekly_notification_schedule(
# # weekly on Friday at 12pm
# cron="0 12 * * 5",
# )
logger.info("Scheduled notification cleanup")
except Exception as e:
logger.error(f"Error scheduling notification cleanup: {e}")
# Set up queue consumers
channel = self.run_and_wait(self.rabbit.get_channel())
immediate_queue = self.run_and_wait(
channel.get_queue("immediate_notifications")
)
batch_queue = self.run_and_wait(channel.get_queue("batch_notifications"))
admin_queue = self.run_and_wait(channel.get_queue("admin_notifications"))
summary_queue = self.run_and_wait(channel.get_queue("summary_notifications"))
while self.running:
try:
@@ -717,22 +218,6 @@ class NotificationManager(AppService):
process_func=self._process_immediate,
error_queue_name="immediate_notifications",
)
self._run_queue(
queue=admin_queue,
process_func=self._process_admin_message,
error_queue_name="admin_notifications",
)
self._run_queue(
queue=batch_queue,
process_func=self._process_batch,
error_queue_name="batch_notifications",
)
self._run_queue(
queue=summary_queue,
process_func=self._process_summary,
error_queue_name="summary_notifications",
)
time.sleep(0.1)

View File

@@ -1,142 +1,75 @@
{# Agent Run #}
{# Template variables:
notification.data: the stuff below but a list of them
data.agent_name: the name of the agent
data.name: the name of the agent
data.credits_used: the number of credits used by the agent
data.node_count: the number of nodes the agent ran on
data.execution_time: the time it took to run the agent
data.graph_id: the id of the graph the agent ran on
data.outputs: the list of outputs of the agent
data.outputs: the dict[str, Any] of outputs of the agent
#}
{% if notifications is defined %}
{# BATCH MODE #}
<div style="font-family: 'Poppins', sans-serif; color: #070629;">
<h2 style="color: #5D23BB; margin-bottom: 15px;">Agent Run Summary</h2>
<p style="font-size: 16px; line-height: 165%; margin-top: 0; margin-bottom: 15px;">
<strong>{{ notifications|length }}</strong> agent runs have completed!
</p>
{# Calculate summary stats #}
{% set total_time = 0 %}
{% set total_nodes = 0 %}
{% set total_credits = 0 %}
{% set agent_names = [] %}
{% for notification in notifications %}
{% set total_time = total_time + notification.data.execution_time %}
{% set total_nodes = total_nodes + notification.data.node_count %}
{% set total_credits = total_credits + notification.data.credits_used %}
{% if notification.data.agent_name not in agent_names %}
{% set agent_names = agent_names + [notification.data.agent_name] %}
{% endif %}
{% endfor %}
<div style="background-color: #f8f7ff; border-radius: 8px; padding: 15px; margin-bottom: 25px;">
<h3 style="margin-top: 0; margin-bottom: 10px; color: #5D23BB;">Summary</h3>
<p style="margin: 5px 0;"><strong>Agents:</strong> {{ agent_names|join(", ") }}</p>
<p style="margin: 5px 0;"><strong>Total Time:</strong> {{ total_time | int }} seconds</p>
<p style="margin: 5px 0;"><strong>Total Nodes:</strong> {{ total_nodes }}</p>
<p style="margin: 5px 0;"><strong>Total Cost:</strong> ${{ "{:.2f}".format((total_credits|float)/100) }}</p>
</div>
<h3 style="margin-top: 25px; margin-bottom: 15px; color: #5D23BB;">Individual Runs</h3>
{% for notification in notifications %}
<div style="margin-bottom: 30px; border-left: 3px solid #5D23BB; padding-left: 15px;">
<p style="font-size: 16px; font-weight: 600; margin-top: 0; margin-bottom: 10px;">
Agent: <strong>{{ notification.data.agent_name }}</strong>
</p>
<div style="margin-left: 10px;">
<p style="margin: 5px 0;"><strong>Time:</strong> {{ notification.data.execution_time | int }} seconds</p>
<p style="margin: 5px 0;"><strong>Nodes:</strong> {{ notification.data.node_count }}</p>
<p style="margin: 5px 0;"><strong>Cost:</strong> ${{ "{:.2f}".format((notification.data.credits_used|float)/100) }}</p>
</div>
{% if notification.data.outputs and notification.data.outputs|length > 0 %}
<div style="margin-left: 10px; margin-top: 15px;">
<p style="font-weight: 600; margin-bottom: 10px;">Results:</p>
{% for output in notification.data.outputs %}
<div style="margin-left: 10px; margin-bottom: 12px;">
<p style="color: #5D23BB; font-weight: 500; margin-top: 0; margin-bottom: 5px;">
{{ output.name }}
</p>
{% for key, value in output.items() %}
{% if key != 'name' %}
<div style="margin-left: 10px; background-color: #f5f5ff; padding: 8px 12px; border-radius: 4px;
font-family: 'Roboto Mono', monospace; white-space: pre-wrap; word-break: break-word;
overflow-wrap: break-word; max-width: 100%; overflow-x: auto; margin-top: 3px;
margin-bottom: 8px; line-height: 1.4;">
{% if value is iterable and value is not string %}
{% if value|length == 1 %}
{{ value[0] }}
{% else %}
[{% for item in value %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}]
{% endif %}
{% else %}
{{ value }}
{% endif %}
</div>
{% endif %}
{% endfor %}
</div>
{% endfor %}
</div>
{% endif %}
</div>
{% endfor %}
</div>
{% else %}
{# SINGLE NOTIFICATION MODE - Original template #}
<p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%;
margin-top: 0; margin-bottom: 10px;">
Your agent, <strong>{{ data.agent_name }}</strong>, has completed its run!
</p>
<p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%;
margin-top: 0; margin-bottom: 20px; padding-left: 20px;">
<p style="margin-bottom: 10px;"><strong>Time Taken:</strong> {{ data.execution_time | int }} seconds</p>
<p style="margin-bottom: 10px;"><strong>Nodes Used:</strong> {{ data.node_count }}</p>
<p style="margin-bottom: 10px;"><strong>Cost:</strong> ${{ "{:.2f}".format((data.credits_used|float)/100) }}</p>
</p>
{% if data.outputs and data.outputs|length > 0 %}
<div style="margin-left: 15px; margin-bottom: 20px;">
<p style="font-family: 'Poppins', sans-serif; color: #070629; font-weight: 600;
font-size: 16px; margin-bottom: 10px;">
Results:
</p>
{% for output in data.outputs %}
<div style="margin-left: 15px; margin-bottom: 15px;">
<p style="font-family: 'Poppins', sans-serif; color: #5D23BB; font-weight: 500;
font-size: 16px; margin-top: 0; margin-bottom: 8px;">
{{ output.name }}
</p>
{% for key, value in output.items() %}
{% if key != 'name' %}
<div style="margin-left: 15px; background-color: #f5f5ff; padding: 8px 12px; border-radius: 4px;
font-family: 'Roboto Mono', monospace; white-space: pre-wrap; word-break: break-word;
overflow-wrap: break-word; max-width: 100%; overflow-x: auto; margin-top: 5px;
margin-bottom: 10px; line-height: 1.4;">
{% if value is iterable and value is not string %}
{% if value|length == 1 %}
{{ value[0] }}
{% else %}
[{% for item in value %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}]
{% endif %}
{% else %}
{{ value }}
{% endif %}
</div>
{% endif %}
{% endfor %}
</div>
{% endfor %}
</div>
{% endif %}
{% endif %}
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
line-height: 165%;
margin-top: 0;
margin-bottom: 10px;
">
Hi,
</p>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
line-height: 165%;
margin-top: 0;
margin-bottom: 10px;
">
We've run your agent {{ data.name }} and it took {{ data.execution_time }} seconds to complete.
</p>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
line-height: 165%;
margin-top: 0;
margin-bottom: 10px;
">
It ran on {{ data.node_count }} nodes and used {{ data.credits_used }} credits.
</p>
<ul style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
line-height: 165%;
margin-top: 0;
margin-bottom: 10px;
">
It output the following:
{# jinja2 list iteration thorugh data.outputs #}
{% for key, value in data.outputs.items() %}
<li>{{ key }}: {{ value }}</li>
{% endfor %}
</ul>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
line-height: 165%;
margin-top: 0;
margin-bottom: 10px;
">
Your feedback has been instrumental in shaping AutoGPT, and we couldn't have
done it without you. We look forward to continuing this journey together as we
bring AI-powered automation to the world.
</p>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
line-height: 165%;
margin-top: 0;
margin-bottom: 0;
">
Thank you again for your time and support.
</p>

View File

@@ -227,10 +227,19 @@
<table class="ml-8 wrapper" border="0" cellspacing="0" cellpadding="0"
style="color: #070629; text-align: left;">
<tr>
<td class="col mobile-center" align="center" width="80">
<img
src="https://storage.mlcdn.com/account_image/597379/68W8w94Zwl52yQyrKdFERRquu2CivAcn17ST22HF.jpg"
border="0" alt="" width="80" class="avatar"
style="display: inline-block; max-width: 80px; border-radius: 80px;">
</td>
<td class="col" width="30" height="30" style="line-height: 30px;"></td>
<td class="col center mobile-center" align>
<p
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-top: 0; margin-bottom: 0;">
Thank you for being a part of the AutoGPT community! Join the conversation on our Discord <a href="https://discord.gg/autogpt" style="color: #4285F4; text-decoration: underline;">here</a> and share your thoughts with us anytime.
John Ababseh<br>Product Manager<br>
<a href="mailto:john.ababseh@agpt.co" target="_blank"
style="color: #4285F4; font-weight: normal; font-style: normal; text-decoration: underline;">john.ababseh@agpt.co</a>
</p>
</td>
</tr>

View File

@@ -1,114 +0,0 @@
{# Low Balance Notification Email Template #}
{# Template variables:
data.agent_name: the name of the agent
data.current_balance: the current balance of the user
data.billing_page_link: the link to the billing page
data.shortfall: the shortfall amount
#}
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
line-height: 165%;
margin-top: 0;
margin-bottom: 10px;
">
<strong>Low Balance Warning</strong>
</p>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
line-height: 165%;
margin-top: 0;
margin-bottom: 20px;
">
Your agent "<strong>{{ data.agent_name }}</strong>" has been stopped due to low balance.
</p>
<div style="
margin-left: 15px;
margin-bottom: 20px;
padding: 15px;
border-left: 4px solid #5D23BB;
background-color: #f8f8ff;
">
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
margin-top: 0;
margin-bottom: 10px;
">
<strong>Current Balance:</strong> ${{ "{:.2f}".format((data.current_balance|float)/100) }}
</p>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
margin-top: 0;
margin-bottom: 10px;
">
<strong>Shortfall:</strong> ${{ "{:.2f}".format((data.shortfall|float)/100) }}
</p>
</div>
<div style="
margin-left: 15px;
margin-bottom: 20px;
padding: 15px;
border-left: 4px solid #FF6B6B;
background-color: #FFF0F0;
">
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
margin-top: 0;
margin-bottom: 10px;
">
<strong>Low Balance:</strong>
</p>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
margin-top: 0;
margin-bottom: 5px;
">
Your agent "<strong>{{ data.agent_name }}</strong>" requires additional credits to continue running. The current operation has been canceled until your balance is replenished.
</p>
</div>
<div style="
text-align: center;
margin: 30px 0;
">
<a href="{{ data.billing_page_link }}" style="
font-family: 'Poppins', sans-serif;
background-color: #5D23BB;
color: white;
padding: 12px 24px;
text-decoration: none;
border-radius: 4px;
font-weight: 500;
display: inline-block;
">
Manage Billing
</a>
</div>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
line-height: 150%;
margin-top: 30px;
margin-bottom: 10px;
font-style: italic;
">
This is an automated notification. Your agent is stopped and will need manually restarted unless set to trigger automatically.
</p>

View File

@@ -1,27 +0,0 @@
{# Weekly Summary #}
{# Template variables:
data: the stuff below
data.start_date: the start date of the summary
data.end_date: the end date of the summary
data.total_credits_used: the total credits used during the summary
data.total_executions: the total number of executions during the summary
data.most_used_agent: the most used agent's nameduring the summary
data.total_execution_time: the total execution time during the summary
data.successful_runs: the total number of successful runs during the summary
data.failed_runs: the total number of failed runs during the summary
data.average_execution_time: the average execution time during the summary
data.cost_breakdown: the cost breakdown during the summary
#}
<h1>Weekly Summary</h1>
<p>Start Date: {{ data.start_date }}</p>
<p>End Date: {{ data.end_date }}</p>
<p>Total Credits Used: {{ data.total_credits_used }}</p>
<p>Total Executions: {{ data.total_executions }}</p>
<p>Most Used Agent: {{ data.most_used_agent }}</p>
<p>Total Execution Time: {{ data.total_execution_time }}</p>
<p>Successful Runs: {{ data.successful_runs }}</p>
<p>Failed Runs: {{ data.failed_runs }}</p>
<p>Average Execution Time: {{ data.average_execution_time }}</p>
<p>Cost Breakdown: {{ data.cost_breakdown }}</p>

View File

@@ -1,5 +1,5 @@
from backend.app import run_processes
from backend.executor import DatabaseManager, Scheduler
from backend.executor import DatabaseManager, ExecutionScheduler
from backend.notifications.notifications import NotificationManager
from backend.server.rest_api import AgentServer
@@ -11,7 +11,7 @@ def main():
run_processes(
NotificationManager(),
DatabaseManager(),
Scheduler(),
ExecutionScheduler(),
AgentServer(),
)

View File

@@ -16,19 +16,14 @@ import backend.data.block
import backend.data.db
import backend.data.graph
import backend.data.user
import backend.server.integrations.router
import backend.server.routers.v1
import backend.server.v2.library.db
import backend.server.v2.library.model
import backend.server.v2.library.routes
import backend.server.v2.otto.routes
import backend.server.v2.postmark.postmark
import backend.server.v2.store.model
import backend.server.v2.store.routes
import backend.util.service
import backend.util.settings
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.server.external.api import external_app
settings = backend.util.settings.Settings()
@@ -69,7 +64,8 @@ docs_url = (
app = fastapi.FastAPI(
title="AutoGPT Agent Server",
description=(
"This server is used to execute agents that are created by the AutoGPT system."
"This server is used to execute agents that are created by the "
"AutoGPT system."
),
summary="AutoGPT Agent Server",
version="0.1",
@@ -102,15 +98,6 @@ app.include_router(
app.include_router(
backend.server.v2.library.routes.router, tags=["v2"], prefix="/api/library"
)
app.include_router(
backend.server.v2.otto.routes.router, tags=["v2"], prefix="/api/otto"
)
app.include_router(
backend.server.v2.postmark.postmark.router,
tags=["v2", "email"],
prefix="/api/email",
)
app.mount("/external-api", external_app)
@@ -154,10 +141,9 @@ class AgentServer(backend.util.service.AppProcess):
graph_id: str,
graph_version: int,
user_id: str,
for_export: bool = False,
):
return await backend.server.routers.v1.get_graph(
graph_id, user_id, graph_version, for_export
graph_id, user_id, graph_version
)
@staticmethod
@@ -257,15 +243,5 @@ class AgentServer(backend.util.service.AppProcess):
):
return await backend.server.v2.store.routes.review_submission(request, user)
@staticmethod
def test_create_credentials(
user_id: str,
provider: ProviderName,
credentials: Credentials,
) -> Credentials:
return backend.server.integrations.router.create_credentials(
user_id=user_id, provider=provider, credentials=credentials
)
def set_test_dependency_overrides(self, overrides: dict):
app.dependency_overrides.update(overrides)

View File

@@ -10,14 +10,12 @@ from autogpt_libs.auth.middleware import auth_middleware
from autogpt_libs.feature_flag.client import feature_flag
from autogpt_libs.utils.cache import thread_cached
from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response
from starlette.status import HTTP_204_NO_CONTENT
from typing_extensions import Optional, TypedDict
import backend.data.block
import backend.server.integrations.router
import backend.server.routers.analytics
import backend.server.v2.library.db as library_db
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.api_key import (
APIKeyError,
@@ -38,23 +36,18 @@ from backend.data.credit import (
TransactionHistory,
get_auto_top_up,
get_block_costs,
get_stripe_customer_id,
get_user_credit_model,
set_auto_top_up,
)
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.data.onboarding import (
UserOnboardingUpdate,
get_recommended_agents,
get_user_onboarding,
update_user_onboarding,
)
from backend.data.user import (
get_or_create_user,
get_user_notification_preference,
update_user_email,
update_user_notification_preference,
)
from backend.executor import ExecutionManager, Scheduler, scheduler
from backend.executor import ExecutionManager, ExecutionScheduler, scheduler
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import (
on_graph_activate,
@@ -83,8 +76,8 @@ def execution_manager_client() -> ExecutionManager:
@thread_cached
def execution_scheduler_client() -> Scheduler:
return get_service_client(Scheduler)
def execution_scheduler_client() -> ExecutionScheduler:
return get_service_client(ExecutionScheduler)
settings = Settings()
@@ -157,38 +150,6 @@ async def update_preferences(
return output
########################################################
##################### Onboarding #######################
########################################################
@v1_router.get(
"/onboarding", tags=["onboarding"], dependencies=[Depends(auth_middleware)]
)
async def get_onboarding(user_id: Annotated[str, Depends(get_user_id)]):
return await get_user_onboarding(user_id)
@v1_router.patch(
"/onboarding", tags=["onboarding"], dependencies=[Depends(auth_middleware)]
)
async def update_onboarding(
user_id: Annotated[str, Depends(get_user_id)], data: UserOnboardingUpdate
):
return await update_user_onboarding(user_id, data)
@v1_router.get(
"/onboarding/agents",
tags=["onboarding"],
dependencies=[Depends(auth_middleware)],
)
async def get_onboarding_agents(
user_id: Annotated[str, Depends(get_user_id)],
):
return await get_recommended_agents(user_id)
########################################################
##################### Blocks ###########################
########################################################
@@ -340,7 +301,15 @@ async def stripe_webhook(request: Request):
async def manage_payment_method(
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[str, str]:
return {"url": await _user_credit_model.create_billing_portal_session(user_id)}
session = stripe.billing_portal.Session.create(
customer=await get_stripe_customer_id(user_id),
return_url=settings.config.frontend_base_url + "/profile/credits",
)
if not session:
raise HTTPException(
status_code=400, detail="Failed to create billing portal session"
)
return {"url": session.url}
@v1_router.get(path="/credits/transactions", dependencies=[Depends(auth_middleware)])
@@ -396,10 +365,10 @@ async def get_graph(
graph_id: str,
user_id: Annotated[str, Depends(get_user_id)],
version: int | None = None,
for_export: bool = False,
hide_credentials: bool = False,
) -> graph_db.GraphModel:
graph = await graph_db.get_graph(
graph_id, version, user_id=user_id, for_export=for_export
graph_id, version, user_id=user_id, for_export=hide_credentials
)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
@@ -424,19 +393,18 @@ async def get_graph_all_versions(
path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)]
)
async def create_new_graph(
create_graph: CreateGraph,
user_id: Annotated[str, Depends(get_user_id)],
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.GraphModel:
graph = graph_db.make_graph_model(create_graph.graph, user_id)
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
graph.validate_graph(for_run=False)
graph = await graph_db.create_graph(graph, user_id=user_id)
# Create a library agent for the new graph
library_agent = await library_db.create_library_agent(graph, user_id)
_ = asyncio.create_task(
library_db.add_generated_agent_image(graph, library_agent.id)
await library_db.create_library_agent(
graph.id,
graph.version,
user_id,
)
graph = await on_graph_activate(
@@ -481,10 +449,17 @@ async def update_graph(
latest_version_number = max(g.version for g in existing_versions)
graph.version = latest_version_number + 1
latest_version_graph = next(
v for v in existing_versions if v.version == latest_version_number
)
current_active_version = next((v for v in existing_versions if v.is_active), None)
if latest_version_graph.is_template != graph.is_template:
raise HTTPException(
400, detail="Changing is_template on an existing graph is forbidden"
)
graph.is_active = not graph.is_template
graph = graph_db.make_graph_model(graph, user_id)
graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
graph.validate_graph(for_run=False)
graph.reassign_ids(user_id=user_id)
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
@@ -574,10 +549,14 @@ def execute_graph(
user_id: Annotated[str, Depends(get_user_id)],
graph_version: Optional[int] = None,
) -> ExecuteGraphResponse:
graph_exec = execution_manager_client().add_execution(
graph_id, node_input, user_id=user_id, graph_version=graph_version
)
return ExecuteGraphResponse(graph_exec_id=graph_exec.graph_exec_id)
try:
graph_exec = execution_manager_client().add_execution(
graph_id, node_input, user_id=user_id, graph_version=graph_version
)
return ExecuteGraphResponse(graph_exec_id=graph_exec.graph_exec_id)
except Exception as e:
msg = str(e).encode().decode("unicode_escape")
raise HTTPException(status_code=400, detail=msg)
@v1_router.post(
@@ -615,7 +594,7 @@ async def stop_graph_run(
async def get_graphs_executions(
user_id: Annotated[str, Depends(get_user_id)],
) -> list[graph_db.GraphExecutionMeta]:
return await graph_db.get_graph_executions(user_id=user_id)
return await graph_db.get_graphs_executions(user_id=user_id)
@v1_router.get(
@@ -646,26 +625,11 @@ async def get_graph_execution(
result = await graph_db.get_execution(execution_id=graph_exec_id, user_id=user_id)
if not result:
raise HTTPException(
status_code=404, detail=f"Graph execution #{graph_exec_id} not found."
)
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return result
@v1_router.delete(
path="/executions/{graph_exec_id}",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
status_code=HTTP_204_NO_CONTENT,
)
async def delete_graph_execution(
graph_exec_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> None:
await execution_db.delete_execution(graph_exec_id=graph_exec_id, user_id=user_id)
########################################################
##################### Schedules ########################
########################################################
@@ -686,7 +650,7 @@ class ScheduleCreationRequest(pydantic.BaseModel):
async def create_schedule(
user_id: Annotated[str, Depends(get_user_id)],
schedule: ScheduleCreationRequest,
) -> scheduler.ExecutionJobInfo:
) -> scheduler.JobInfo:
graph = await graph_db.get_graph(
schedule.graph_id, schedule.graph_version, user_id=user_id
)
@@ -728,7 +692,7 @@ def delete_schedule(
def get_execution_schedules(
user_id: Annotated[str, Depends(get_user_id)],
graph_id: str | None = None,
) -> list[scheduler.ExecutionJobInfo]:
) -> list[scheduler.JobInfo]:
return execution_scheduler_client().get_execution_schedules(
user_id=user_id,
graph_id=graph_id,

View File

@@ -1,62 +1,27 @@
import logging
from typing import Optional
import fastapi
import prisma.errors
import prisma.fields
import prisma.models
import prisma.types
import backend.data.graph
import backend.server.model
import backend.server.v2.library.model as library_model
import backend.server.v2.store.exceptions as store_exceptions
import backend.server.v2.store.image_gen as store_image_gen
import backend.server.v2.store.media as store_media
from backend.data.db import locked_transaction
from backend.data.includes import library_agent_include
from backend.util.settings import Config
logger = logging.getLogger(__name__)
config = Config()
async def list_library_agents(
user_id: str,
search_term: Optional[str] = None,
sort_by: library_model.LibraryAgentSort = library_model.LibraryAgentSort.UPDATED_AT,
page: int = 1,
page_size: int = 50,
) -> library_model.LibraryAgentResponse:
"""
Retrieves a paginated list of LibraryAgent records for a given user.
Args:
user_id: The ID of the user whose LibraryAgents we want to retrieve.
search_term: Optional string to filter agents by name/description.
sort_by: Sorting field (createdAt, updatedAt, isFavorite, isCreatedByUser).
page: Current page (1-indexed).
page_size: Number of items per page.
Returns:
A LibraryAgentResponse containing the list of agents and pagination details.
Raises:
DatabaseError: If there is an issue fetching from Prisma.
"""
async def get_library_agents(
user_id: str, search_query: str | None = None
) -> list[library_model.LibraryAgent]:
logger.debug(
f"Fetching library agents for user_id={user_id}, "
f"search_term={repr(search_term)}, "
f"sort_by={sort_by}, page={page}, page_size={page_size}"
f"Fetching library agents for user_id={user_id} search_query={search_query}"
)
if page < 1 or page_size < 1:
logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
raise store_exceptions.DatabaseError("Invalid pagination input")
if search_term and len(search_term.strip()) > 100:
logger.warning(f"Search term too long: {repr(search_term)}")
raise store_exceptions.DatabaseError("Search term is too long")
if search_query and len(search_query.strip()) > 100:
logger.warning(f"Search query too long: {search_query}")
raise store_exceptions.DatabaseError("Search query is too long.")
where_clause: prisma.types.LibraryAgentWhereInput = {
"userId": user_id,
@@ -64,206 +29,69 @@ async def list_library_agents(
"isArchived": False,
}
# Build search filter if applicable
if search_term:
if search_query:
where_clause["OR"] = [
{
"Agent": {
"is": {"name": {"contains": search_term, "mode": "insensitive"}}
"is": {"name": {"contains": search_query, "mode": "insensitive"}}
}
},
{
"Agent": {
"is": {
"description": {"contains": search_term, "mode": "insensitive"}
"description": {"contains": search_query, "mode": "insensitive"}
}
}
},
]
# Determine sorting
order_by: prisma.types.LibraryAgentOrderByInput | None = None
if sort_by == library_model.LibraryAgentSort.CREATED_AT:
order_by = {"createdAt": "asc"}
elif sort_by == library_model.LibraryAgentSort.UPDATED_AT:
order_by = {"updatedAt": "desc"}
try:
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
where=where_clause,
include=library_agent_include(user_id),
order=order_by,
skip=(page - 1) * page_size,
take=page_size,
include={
"Agent": {
"include": {
"AgentNodes": {"include": {"Input": True, "Output": True}}
}
}
},
order=[{"updatedAt": "desc"}],
)
agent_count = await prisma.models.LibraryAgent.prisma().count(
where=where_clause
)
logger.debug(
f"Retrieved {len(library_agents)} library agents for user #{user_id}"
)
# Only pass valid agents to the response
valid_library_agents: list[library_model.LibraryAgent] = []
for agent in library_agents:
try:
library_agent = library_model.LibraryAgent.from_db(agent)
valid_library_agents.append(library_agent)
except Exception as e:
# Skip this agent if there was an error
logger.error(
f"Error parsing LibraryAgent when getting library agents from db: {e}"
)
continue
# Return the response with only valid agents
return library_model.LibraryAgentResponse(
agents=valid_library_agents,
pagination=backend.server.model.Pagination(
total_items=agent_count,
total_pages=(agent_count + page_size - 1) // page_size,
current_page=page,
page_size=page_size,
),
)
logger.debug(f"Retrieved {len(library_agents)} agents for user_id={user_id}.")
return [library_model.LibraryAgent.from_db(agent) for agent in library_agents]
except prisma.errors.PrismaError as e:
logger.error(f"Database error fetching library agents: {e}")
raise store_exceptions.DatabaseError("Failed to fetch library agents") from e
async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent:
"""
Get a specific agent from the user's library.
Args:
library_agent_id: ID of the library agent to retrieve.
user_id: ID of the authenticated user.
Returns:
The requested LibraryAgent.
Raises:
AgentNotFoundError: If the specified agent does not exist.
DatabaseError: If there's an error during retrieval.
"""
try:
library_agent = await prisma.models.LibraryAgent.prisma().find_first(
where={
"id": id,
"userId": user_id,
"isDeleted": False,
},
include=library_agent_include(user_id),
)
if not library_agent:
raise store_exceptions.AgentNotFoundError(f"Library agent #{id} not found")
return library_model.LibraryAgent.from_db(library_agent)
except prisma.errors.PrismaError as e:
logger.error(f"Database error fetching library agent: {e}")
raise store_exceptions.DatabaseError("Failed to fetch library agent") from e
async def add_generated_agent_image(
graph: backend.data.graph.GraphModel,
library_agent_id: str,
) -> Optional[prisma.models.LibraryAgent]:
"""
Generates an image for the specified LibraryAgent and updates its record.
"""
user_id = graph.user_id
graph_id = graph.id
# Use .jpeg here since we are generating JPEG images
filename = f"agent_{graph_id}.jpeg"
try:
if not (image_url := await store_media.check_media_exists(user_id, filename)):
# Generate agent image as JPEG
image = await store_image_gen.generate_agent_image(graph)
# Create UploadFile with the correct filename and content_type
image_file = fastapi.UploadFile(file=image, filename=filename)
image_url = await store_media.upload_media(
user_id=user_id, file=image_file, use_file_name=True
)
except Exception as e:
logger.warning(f"Error generating and uploading agent image: {e}")
return None
return await prisma.models.LibraryAgent.prisma().update(
where={"id": library_agent_id},
data={"imageUrl": image_url},
)
raise store_exceptions.DatabaseError("Unable to fetch library agents.")
async def create_library_agent(
graph: backend.data.graph.GraphModel,
user_id: str,
agent_id: str, agent_version: int, user_id: str
) -> prisma.models.LibraryAgent:
"""
Adds an agent to the user's library (LibraryAgent table).
Args:
agent: The agent/Graph to add to the library.
user_id: The user to whom the agent will be added.
Returns:
The newly created LibraryAgent record.
Raises:
AgentNotFoundError: If the specified agent does not exist.
DatabaseError: If there's an error during creation or if image generation fails.
Adds an agent to the user's library (LibraryAgent table)
"""
logger.info(
f"Creating library agent for graph #{graph.id} v{graph.version}; "
f"user #{user_id}"
)
try:
return await prisma.models.LibraryAgent.prisma().create(
data={
"isCreatedByUser": (user_id == graph.user_id),
"userId": user_id,
"agentId": agent_id,
"agentVersion": agent_version,
"isCreatedByUser": False,
"useGraphIsActiveVersion": True,
"User": {"connect": {"id": user_id}},
"Agent": {
"connect": {
"graphVersionId": {"id": graph.id, "version": graph.version}
}
},
}
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error creating agent in library: {e}")
raise store_exceptions.DatabaseError("Failed to create agent in library") from e
logger.error(f"Database error creating agent to library: {str(e)}")
raise store_exceptions.DatabaseError("Failed to create agent to library") from e
async def update_agent_version_in_library(
user_id: str,
agent_id: str,
agent_version: int,
user_id: str, agent_id: str, agent_version: int
) -> None:
"""
Updates the agent version in the library if useGraphIsActiveVersion is True.
Args:
user_id: Owner of the LibraryAgent.
agent_id: The agent's ID to update.
agent_version: The new version of the agent.
Raises:
DatabaseError: If there's an error with the update.
Updates the agent version in the library
"""
logger.debug(
f"Updating agent version in library for user #{user_id}, "
f"agent #{agent_id} v{agent_version}"
)
try:
library_agent = await prisma.models.LibraryAgent.prisma().find_first_or_raise(
where={
@@ -283,7 +111,7 @@ async def update_agent_version_in_library(
},
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error updating agent version in library: {e}")
logger.error(f"Database error updating agent version in library: {str(e)}")
raise store_exceptions.DatabaseError(
"Failed to update agent version in library"
) from e
@@ -292,43 +120,23 @@ async def update_agent_version_in_library(
async def update_library_agent(
library_agent_id: str,
user_id: str,
auto_update_version: Optional[bool] = None,
is_favorite: Optional[bool] = None,
is_archived: Optional[bool] = None,
is_deleted: Optional[bool] = None,
auto_update_version: bool = False,
is_favorite: bool = False,
is_archived: bool = False,
is_deleted: bool = False,
) -> None:
"""
Updates the specified LibraryAgent record.
Args:
library_agent_id: The ID of the LibraryAgent to update.
user_id: The owner of this LibraryAgent.
auto_update_version: Whether the agent should auto-update to active version.
is_favorite: Whether this agent is marked as a favorite.
is_archived: Whether this agent is archived.
is_deleted: Whether this agent is deleted.
Raises:
DatabaseError: If there's an error in the update operation.
Updates the library agent with the given fields
"""
logger.debug(
f"Updating library agent {library_agent_id} for user {user_id} with "
f"auto_update_version={auto_update_version}, is_favorite={is_favorite}, "
f"is_archived={is_archived}, is_deleted={is_deleted}"
)
update_fields: prisma.types.LibraryAgentUpdateManyMutationInput = {}
if auto_update_version is not None:
update_fields["useGraphIsActiveVersion"] = auto_update_version
if is_favorite is not None:
update_fields["isFavorite"] = is_favorite
if is_archived is not None:
update_fields["isArchived"] = is_archived
if is_deleted is not None:
update_fields["isDeleted"] = is_deleted
try:
await prisma.models.LibraryAgent.prisma().update_many(
where={"id": library_agent_id, "userId": user_id}, data=update_fields
where={"id": library_agent_id, "userId": user_id},
data={
"useGraphIsActiveVersion": auto_update_version,
"isFavorite": is_favorite,
"isArchived": is_archived,
"isDeleted": is_deleted,
},
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error updating library agent: {str(e)}")
@@ -344,140 +152,76 @@ async def delete_library_agent_by_graph_id(graph_id: str, user_id: str) -> None:
where={"agentId": graph_id, "userId": user_id}
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error deleting library agent: {e}")
logger.error(f"Database error deleting library agent: {str(e)}")
raise store_exceptions.DatabaseError("Failed to delete library agent") from e
async def add_store_agent_to_library(
store_listing_version_id: str, user_id: str
) -> library_model.LibraryAgent:
"""
Adds an agent from a store listing version to the user's library if they don't already have it.
Args:
store_listing_version_id: The ID of the store listing version containing the agent.
user_id: The users library to which the agent is being added.
Returns:
The newly created LibraryAgent if successfully added, the existing corresponding one if any.
Raises:
AgentNotFoundError: If the store listing or associated agent is not found.
DatabaseError: If there's an issue creating the LibraryAgent record.
"""
logger.debug(
f"Adding agent from store listing version #{store_listing_version_id} "
f"to library for user #{user_id}"
)
try:
async with locked_transaction(f"user_trx_{user_id}"):
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id}, include={"Agent": True}
)
)
if not store_listing_version or not store_listing_version.Agent:
logger.warning(
f"Store listing version not found: {store_listing_version_id}"
)
raise store_exceptions.AgentNotFoundError(
f"Store listing version {store_listing_version_id} not found or invalid"
)
graph = store_listing_version.Agent
if graph.userId == user_id:
logger.warning(
f"User #{user_id} attempted to add their own agent to their library"
)
raise store_exceptions.DatabaseError("Cannot add own agent to library")
# Check if user already has this agent
existing_library_agent = (
await prisma.models.LibraryAgent.prisma().find_first(
where={
"userId": user_id,
"agentId": graph.id,
"agentVersion": graph.version,
},
include=library_agent_include(user_id),
)
)
if existing_library_agent:
if existing_library_agent.isDeleted:
# Even if agent exists it needs to be marked as not deleted
await set_is_deleted_for_library_agent(
user_id, graph.id, graph.version, False
)
else:
logger.debug(
f"User #{user_id} already has graph #{graph.id} "
"in their library"
)
return library_model.LibraryAgent.from_db(existing_library_agent)
# Create LibraryAgent entry
added_agent = await prisma.models.LibraryAgent.prisma().create(
data={
"userId": user_id,
"agentId": graph.id,
"agentVersion": graph.version,
"isCreatedByUser": False,
},
include=library_agent_include(user_id),
)
logger.debug(
f"Added graph #{graph.id} "
f"for store listing #{store_listing_version.id} "
f"to library for user #{user_id}"
)
return library_model.LibraryAgent.from_db(added_agent)
except store_exceptions.AgentNotFoundError:
# Reraise for external handling.
raise
except prisma.errors.PrismaError as e:
logger.error(f"Database error adding agent to library: {e}")
raise store_exceptions.DatabaseError("Failed to add agent to library") from e
async def set_is_deleted_for_library_agent(
user_id: str, agent_id: str, agent_version: int, is_deleted: bool
) -> None:
"""
Changes the isDeleted flag for a library agent.
Args:
user_id: The user's library from which the agent is being removed.
agent_id: The ID of the agent to remove.
agent_version: The version of the agent to remove.
is_deleted: Whether the agent is being marked as deleted.
Raises:
DatabaseError: If there's an issue updating the Library
Finds the agent from the store listing version and adds it to the user's library (LibraryAgent table)
if they don't already have it
"""
logger.debug(
f"Setting isDeleted={is_deleted} for agent {agent_id} v{agent_version} "
f"in library for user {user_id}"
f"Adding agent from store listing version {store_listing_version_id} to library for user {user_id}"
)
try:
logger.warning(
f"Setting isDeleted={is_deleted} for agent {agent_id} v{agent_version} in library for user {user_id}"
# Get store listing version to find agent
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id}, include={"Agent": True}
)
)
count = await prisma.models.LibraryAgent.prisma().update_many(
if not store_listing_version or not store_listing_version.Agent:
logger.warning(
f"Store listing version not found: {store_listing_version_id}"
)
raise store_exceptions.AgentNotFoundError(
f"Store listing version {store_listing_version_id} not found"
)
agent = store_listing_version.Agent
if agent.userId == user_id:
logger.warning(
f"User {user_id} cannot add their own agent to their library"
)
raise store_exceptions.DatabaseError("Cannot add own agent to library")
# Check if user already has this agent
existing_user_agent = await prisma.models.LibraryAgent.prisma().find_first(
where={
"userId": user_id,
"agentId": agent_id,
"agentVersion": agent_version,
},
data={"isDeleted": is_deleted},
"agentId": agent.id,
"agentVersion": agent.version,
}
)
logger.warning(f"Updated {count} isDeleted library agents")
if existing_user_agent:
logger.debug(
f"User {user_id} already has agent {agent.id} in their library"
)
return
# Create LibraryAgent entry
await prisma.models.LibraryAgent.prisma().create(
data={
"userId": user_id,
"agentId": agent.id,
"agentVersion": agent.version,
"isCreatedByUser": False,
}
)
logger.debug(f"Added agent {agent.id} to library for user {user_id}")
except store_exceptions.AgentNotFoundError:
raise
except prisma.errors.PrismaError as e:
logger.error(f"Database error setting agent isDeleted: {e}")
raise store_exceptions.DatabaseError(
"Failed to set agent isDeleted in library"
) from e
logger.error(f"Database error adding agent to library: {str(e)}")
raise store_exceptions.DatabaseError("Failed to add agent to library") from e
##############################################
@@ -488,44 +232,20 @@ async def set_is_deleted_for_library_agent(
async def get_presets(
user_id: str, page: int, page_size: int
) -> library_model.LibraryAgentPresetResponse:
"""
Retrieves a paginated list of AgentPresets for the specified user.
Args:
user_id: The user ID whose presets are being retrieved.
page: The current page index (0-based or 1-based, clarify in your domain).
page_size: Number of items to retrieve per page.
Returns:
A LibraryAgentPresetResponse containing a list of presets and pagination info.
Raises:
DatabaseError: If there's a database error during the operation.
"""
logger.debug(
f"Fetching presets for user #{user_id}, page={page}, page_size={page_size}"
)
if page < 0 or page_size < 1:
logger.warning(
"Invalid pagination input: page=%d, page_size=%d", page, page_size
)
raise store_exceptions.DatabaseError("Invalid pagination parameters")
try:
presets_records = await prisma.models.AgentPreset.prisma().find_many(
presets = await prisma.models.AgentPreset.prisma().find_many(
where={"userId": user_id},
skip=page * page_size,
take=page_size,
)
total_items = await prisma.models.AgentPreset.prisma().count(
where={"userId": user_id}
where={"userId": user_id},
)
total_pages = (total_items + page_size - 1) // page_size
presets = [
library_model.LibraryAgentPreset.from_db(preset)
for preset in presets_records
library_model.LibraryAgentPreset.from_db(preset) for preset in presets
]
return library_model.LibraryAgentPresetResponse(
@@ -539,67 +259,34 @@ async def get_presets(
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error getting presets: {e}")
logger.error(f"Database error getting presets: {str(e)}")
raise store_exceptions.DatabaseError("Failed to fetch presets") from e
async def get_preset(
user_id: str, preset_id: str
) -> library_model.LibraryAgentPreset | None:
"""
Retrieves a single AgentPreset by its ID for a given user.
Args:
user_id: The user that owns the preset.
preset_id: The ID of the preset.
Returns:
A LibraryAgentPreset if it exists and matches the user, otherwise None.
Raises:
DatabaseError: If there's a database error during the fetch.
"""
logger.debug(f"Fetching preset #{preset_id} for user #{user_id}")
try:
preset = await prisma.models.AgentPreset.prisma().find_unique(
where={"id": preset_id},
include={"InputPresets": True},
where={"id": preset_id}, include={"InputPresets": True}
)
if not preset or preset.userId != user_id:
return None
return library_model.LibraryAgentPreset.from_db(preset)
except prisma.errors.PrismaError as e:
logger.error(f"Database error getting preset: {e}")
logger.error(f"Database error getting preset: {str(e)}")
raise store_exceptions.DatabaseError("Failed to fetch preset") from e
async def upsert_preset(
user_id: str,
preset: library_model.CreateLibraryAgentPresetRequest,
preset_id: Optional[str] = None,
preset_id: str | None = None,
) -> library_model.LibraryAgentPreset:
"""
Creates or updates an AgentPreset for a user.
Args:
user_id: The ID of the user creating/updating the preset.
preset: The preset data used for creation or update.
preset_id: An optional preset ID to update; if None, a new preset is created.
Returns:
The newly created or updated LibraryAgentPreset.
Raises:
DatabaseError: If there's a database error in creating or updating the preset.
ValueError: If attempting to update a non-existent preset.
"""
logger.debug(
f"Upserting preset #{preset_id} ({repr(preset.name)}) for user #{user_id}",
)
try:
if preset_id:
# Update existing preset
updated = await prisma.models.AgentPreset.prisma().update(
new_preset = await prisma.models.AgentPreset.prisma().update(
where={"id": preset_id},
data={
"name": preset.name,
@@ -614,9 +301,8 @@ async def upsert_preset(
},
include={"InputPresets": True},
)
if not updated:
if not new_preset:
raise ValueError(f"AgentPreset #{preset_id} not found")
return library_model.LibraryAgentPreset.from_db(updated)
else:
# Create new preset
new_preset = await prisma.models.AgentPreset.prisma().create(
@@ -638,27 +324,16 @@ async def upsert_preset(
)
return library_model.LibraryAgentPreset.from_db(new_preset)
except prisma.errors.PrismaError as e:
logger.error(f"Database error upserting preset: {e}")
logger.error(f"Database error creating preset: {str(e)}")
raise store_exceptions.DatabaseError("Failed to create preset") from e
async def delete_preset(user_id: str, preset_id: str) -> None:
"""
Soft-deletes a preset by marking it as isDeleted = True.
Args:
user_id: The user that owns the preset.
preset_id: The ID of the preset to delete.
Raises:
DatabaseError: If there's a database error during deletion.
"""
logger.info(f"Deleting preset {preset_id} for user {user_id}")
try:
await prisma.models.AgentPreset.prisma().update_many(
where={"id": preset_id, "userId": user_id},
data={"isDeleted": True},
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error deleting preset: {e}")
logger.error(f"Database error deleting preset: {str(e)}")
raise store_exceptions.DatabaseError("Failed to delete preset") from e

View File

@@ -3,11 +3,21 @@ from datetime import datetime
import prisma.errors
import prisma.models
import pytest
from prisma import Prisma
import backend.data.includes
import backend.server.v2.library.db as db
import backend.server.v2.store.exceptions
from backend.data.db import connect
from backend.data.includes import library_agent_include
@pytest.fixture(autouse=True)
async def setup_prisma():
# Don't register client if already registered
try:
Prisma()
except prisma.errors.ClientAlreadyRegisteredError:
pass
yield
@pytest.mark.asyncio
@@ -22,6 +32,7 @@ async def test_get_library_agents(mocker):
userId="test-user",
isActive=True,
createdAt=datetime.now(),
isTemplate=False,
)
]
@@ -46,6 +57,7 @@ async def test_get_library_agents(mocker):
userId="other-user",
isActive=True,
createdAt=datetime.now(),
isTemplate=False,
),
)
]
@@ -60,31 +72,27 @@ async def test_get_library_agents(mocker):
mock_library_agent.return_value.find_many = mocker.AsyncMock(
return_value=mock_library_agents
)
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
# Call function
result = await db.list_library_agents("test-user")
result = await db.get_library_agents("test-user")
# Verify results
assert len(result.agents) == 1
assert result.agents[0].id == "ua1"
assert result.agents[0].name == "Test Agent 2"
assert result.agents[0].description == "Test Description 2"
assert result.agents[0].agent_id == "agent2"
assert result.agents[0].agent_version == 1
assert result.agents[0].can_access_graph is False
assert result.agents[0].is_latest_version is True
assert result.pagination.total_items == 1
assert result.pagination.total_pages == 1
assert result.pagination.current_page == 1
assert result.pagination.page_size == 50
assert len(result) == 1
assert result[0].id == "ua1"
assert result[0].name == "Test Agent 2"
assert result[0].description == "Test Description 2"
assert result[0].is_created_by_user is False
assert result[0].is_latest_version is True
assert result[0].is_favorite is False
assert result[0].agent_id == "agent2"
assert result[0].agent_version == 1
assert result[0].preset_id is None
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio
async def test_add_agent_to_library(mocker):
await connect()
# Mock data
mock_store_listing_data = prisma.models.StoreListingVersion(
mock_store_listing = prisma.models.StoreListingVersion(
id="version123",
version=1,
createdAt=datetime.now(),
@@ -109,37 +117,21 @@ async def test_add_agent_to_library(mocker):
userId="creator",
isActive=True,
createdAt=datetime.now(),
isTemplate=False,
),
)
mock_library_agent_data = prisma.models.LibraryAgent(
id="ua1",
userId="test-user",
agentId=mock_store_listing_data.agentId,
agentVersion=1,
isCreatedByUser=False,
isDeleted=False,
isArchived=False,
createdAt=datetime.now(),
updatedAt=datetime.now(),
isFavorite=False,
useGraphIsActiveVersion=True,
Agent=mock_store_listing_data.Agent,
)
# Mock prisma calls
mock_store_listing_version = mocker.patch(
"prisma.models.StoreListingVersion.prisma"
)
mock_store_listing_version.return_value.find_unique = mocker.AsyncMock(
return_value=mock_store_listing_data
return_value=mock_store_listing
)
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.create = mocker.AsyncMock(
return_value=mock_library_agent_data
)
mock_library_agent.return_value.create = mocker.AsyncMock()
# Call function
await db.add_store_agent_to_library("version123", "test-user")
@@ -153,20 +145,17 @@ async def test_add_agent_to_library(mocker):
"userId": "test-user",
"agentId": "agent1",
"agentVersion": 1,
},
include=library_agent_include("test-user"),
}
)
mock_library_agent.return_value.create.assert_called_once_with(
data=prisma.types.LibraryAgentCreateInput(
userId="test-user", agentId="agent1", agentVersion=1, isCreatedByUser=False
),
include=library_agent_include("test-user"),
)
)
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio
async def test_add_agent_to_library_not_found(mocker):
await connect()
# Mock prisma calls
mock_store_listing_version = mocker.patch(
"prisma.models.StoreListingVersion.prisma"

View File

@@ -1,8 +1,6 @@
import datetime
from enum import Enum
from typing import Any, Optional
from typing import Any
import prisma.enums
import prisma.models
import pydantic
@@ -11,29 +9,13 @@ import backend.data.graph as graph_model
import backend.server.model as server_model
class LibraryAgentStatus(str, Enum):
COMPLETED = "COMPLETED" # All runs completed
HEALTHY = "HEALTHY" # Agent is running (not all runs have completed)
WAITING = "WAITING" # Agent is queued or waiting to start
ERROR = "ERROR" # Agent is in an error state
class LibraryAgent(pydantic.BaseModel):
"""
Represents an agent in the library, including metadata for display and
user interaction within the system.
"""
id: str # Changed from agent_id to match GraphMeta
id: str
agent_id: str
agent_version: int
agent_version: int # Changed from agent_version to match GraphMeta
image_url: str | None
creator_name: str
creator_image_url: str
status: LibraryAgentStatus
preset_id: str | None
updated_at: datetime.datetime
@@ -42,135 +24,47 @@ class LibraryAgent(pydantic.BaseModel):
# Made input_schema and output_schema match GraphMeta's type
input_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend
output_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend
# Indicates whether there's a new output (based on recent runs)
new_output: bool
is_favorite: bool
is_created_by_user: bool
# Whether the user can access the underlying graph
can_access_graph: bool
# Indicates if this agent is the latest version
is_latest_version: bool
@staticmethod
def from_db(agent: prisma.models.LibraryAgent) -> "LibraryAgent":
"""
Factory method that constructs a LibraryAgent from a Prisma LibraryAgent
model instance.
"""
def from_db(agent: prisma.models.LibraryAgent):
if not agent.Agent:
raise ValueError("Associated Agent record is required.")
raise ValueError("AgentGraph is required")
graph = graph_model.GraphModel.from_db(agent.Agent)
agent_updated_at = agent.Agent.updatedAt
lib_agent_updated_at = agent.updatedAt
# Compute updated_at as the latest between library agent and graph
# Take the latest updated_at timestamp either when the graph was updated or the library agent was updated
updated_at = (
max(agent_updated_at, lib_agent_updated_at)
if agent_updated_at
else lib_agent_updated_at
)
creator_name = "Unknown"
creator_image_url = ""
if agent.Creator:
creator_name = agent.Creator.name or "Unknown"
creator_image_url = agent.Creator.avatarUrl or ""
# Logic to calculate status and new_output
week_ago = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
days=7
)
executions = agent.Agent.AgentGraphExecution or []
status_result = _calculate_agent_status(executions, week_ago)
status = status_result.status
new_output = status_result.new_output
# Check if user can access the graph
can_access_graph = agent.Agent.userId == agent.userId
# Hard-coded to True until a method to check is implemented
is_latest_version = True
return LibraryAgent(
id=agent.id,
agent_id=agent.agentId,
agent_version=agent.agentVersion,
image_url=agent.imageUrl,
creator_name=creator_name,
creator_image_url=creator_image_url,
status=status,
updated_at=updated_at,
name=graph.name,
description=graph.description,
input_schema=graph.input_schema,
new_output=new_output,
can_access_graph=can_access_graph,
is_latest_version=is_latest_version,
output_schema=graph.output_schema,
is_favorite=agent.isFavorite,
is_created_by_user=agent.isCreatedByUser,
is_latest_version=graph.is_active,
preset_id=agent.AgentPreset.id if agent.AgentPreset else None,
)
class AgentStatusResult(pydantic.BaseModel):
status: LibraryAgentStatus
new_output: bool
def _calculate_agent_status(
executions: list[prisma.models.AgentGraphExecution],
recent_threshold: datetime.datetime,
) -> AgentStatusResult:
"""
Helper function to determine the overall agent status and whether there
is new output (i.e., completed runs within the recent threshold).
:param executions: A list of AgentGraphExecution objects.
:param recent_threshold: A datetime; any execution after this indicates new output.
:return: (AgentStatus, new_output_flag)
"""
if not executions:
return AgentStatusResult(status=LibraryAgentStatus.COMPLETED, new_output=False)
# Track how many times each execution status appears
status_counts = {status: 0 for status in prisma.enums.AgentExecutionStatus}
new_output = False
for execution in executions:
# Check if there's a completed run more recent than `recent_threshold`
if execution.createdAt >= recent_threshold:
if execution.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED:
new_output = True
status_counts[execution.executionStatus] += 1
# Determine the final status based on counts
if status_counts[prisma.enums.AgentExecutionStatus.FAILED] > 0:
return AgentStatusResult(status=LibraryAgentStatus.ERROR, new_output=new_output)
elif status_counts[prisma.enums.AgentExecutionStatus.QUEUED] > 0:
return AgentStatusResult(
status=LibraryAgentStatus.WAITING, new_output=new_output
)
elif status_counts[prisma.enums.AgentExecutionStatus.RUNNING] > 0:
return AgentStatusResult(
status=LibraryAgentStatus.HEALTHY, new_output=new_output
)
else:
return AgentStatusResult(
status=LibraryAgentStatus.COMPLETED, new_output=new_output
)
class LibraryAgentResponse(pydantic.BaseModel):
"""Response schema for a list of library agents and pagination info."""
agents: list[LibraryAgent]
pagination: server_model.Pagination
class LibraryAgentPreset(pydantic.BaseModel):
"""Represents a preset configuration for a library agent."""
id: str
updated_at: datetime.datetime
@@ -184,14 +78,14 @@ class LibraryAgentPreset(pydantic.BaseModel):
inputs: block_model.BlockInput
@classmethod
def from_db(cls, preset: prisma.models.AgentPreset) -> "LibraryAgentPreset":
@staticmethod
def from_db(preset: prisma.models.AgentPreset):
input_data: block_model.BlockInput = {}
for preset_input in preset.InputPresets or []:
input_data[preset_input.name] = preset_input.data
return cls(
return LibraryAgentPreset(
id=preset.id,
updated_at=preset.updatedAt,
agent_id=preset.agentId,
@@ -204,56 +98,14 @@ class LibraryAgentPreset(pydantic.BaseModel):
class LibraryAgentPresetResponse(pydantic.BaseModel):
"""Response schema for a list of agent presets and pagination info."""
presets: list[LibraryAgentPreset]
pagination: server_model.Pagination
class CreateLibraryAgentPresetRequest(pydantic.BaseModel):
"""
Request model used when creating a new preset for a library agent.
"""
name: str
description: str
inputs: block_model.BlockInput
agent_id: str
agent_version: int
is_active: bool
class LibraryAgentFilter(str, Enum):
"""Possible filters for searching library agents."""
IS_FAVOURITE = "isFavourite"
IS_CREATED_BY_USER = "isCreatedByUser"
class LibraryAgentSort(str, Enum):
"""Possible sort options for sorting library agents."""
CREATED_AT = "createdAt"
UPDATED_AT = "updatedAt"
class LibraryAgentUpdateRequest(pydantic.BaseModel):
"""
Schema for updating a library agent via PUT.
Includes flags for auto-updating version, marking as favorite,
archiving, or deleting.
"""
auto_update_version: Optional[bool] = pydantic.Field(
default=None, description="Auto-update the agent version"
)
is_favorite: Optional[bool] = pydantic.Field(
default=None, description="Mark the agent as a favorite"
)
is_archived: Optional[bool] = pydantic.Field(
default=None, description="Archive the agent"
)
is_deleted: Optional[bool] = pydantic.Field(
default=None, description="Delete the agent"
)

View File

@@ -2,14 +2,148 @@ import datetime
import prisma.fields
import prisma.models
import pytest
import backend.server.v2.library.model as library_model
from backend.util import json
import backend.data.block
import backend.server.model
import backend.server.v2.library.model
@pytest.mark.asyncio
async def test_agent_preset_from_db():
def test_library_agent():
agent = backend.server.v2.library.model.LibraryAgent(
id="test-agent-123",
agent_id="agent-123",
agent_version=1,
preset_id=None,
updated_at=datetime.datetime.now(),
name="Test Agent",
description="Test description",
input_schema={"type": "object", "properties": {}},
output_schema={"type": "object", "properties": {}},
is_favorite=False,
is_created_by_user=False,
is_latest_version=True,
)
assert agent.id == "test-agent-123"
assert agent.agent_id == "agent-123"
assert agent.agent_version == 1
assert agent.name == "Test Agent"
assert agent.description == "Test description"
assert agent.is_favorite is False
assert agent.is_created_by_user is False
assert agent.is_latest_version is True
assert agent.input_schema == {"type": "object", "properties": {}}
assert agent.output_schema == {"type": "object", "properties": {}}
def test_library_agent_with_user_created():
agent = backend.server.v2.library.model.LibraryAgent(
id="user-agent-456",
agent_id="agent-456",
agent_version=2,
preset_id=None,
updated_at=datetime.datetime.now(),
name="User Created Agent",
description="An agent created by the user",
input_schema={"type": "object", "properties": {}},
output_schema={"type": "object", "properties": {}},
is_favorite=False,
is_created_by_user=True,
is_latest_version=True,
)
assert agent.id == "user-agent-456"
assert agent.agent_id == "agent-456"
assert agent.agent_version == 2
assert agent.name == "User Created Agent"
assert agent.description == "An agent created by the user"
assert agent.is_favorite is False
assert agent.is_created_by_user is True
assert agent.is_latest_version is True
assert agent.input_schema == {"type": "object", "properties": {}}
assert agent.output_schema == {"type": "object", "properties": {}}
def test_library_agent_preset():
preset = backend.server.v2.library.model.LibraryAgentPreset(
id="preset-123",
name="Test Preset",
description="Test preset description",
agent_id="test-agent-123",
agent_version=1,
is_active=True,
inputs={
"dictionary": {"key1": "Hello", "key2": "World"},
"selected_value": "key2",
},
updated_at=datetime.datetime.now(),
)
assert preset.id == "preset-123"
assert preset.name == "Test Preset"
assert preset.description == "Test preset description"
assert preset.agent_id == "test-agent-123"
assert preset.agent_version == 1
assert preset.is_active is True
assert preset.inputs == {
"dictionary": {"key1": "Hello", "key2": "World"},
"selected_value": "key2",
}
def test_library_agent_preset_response():
preset = backend.server.v2.library.model.LibraryAgentPreset(
id="preset-123",
name="Test Preset",
description="Test preset description",
agent_id="test-agent-123",
agent_version=1,
is_active=True,
inputs={
"dictionary": {"key1": "Hello", "key2": "World"},
"selected_value": "key2",
},
updated_at=datetime.datetime.now(),
)
pagination = backend.server.model.Pagination(
total_items=1, total_pages=1, current_page=1, page_size=10
)
response = backend.server.v2.library.model.LibraryAgentPresetResponse(
presets=[preset], pagination=pagination
)
assert len(response.presets) == 1
assert response.presets[0].id == "preset-123"
assert response.pagination.total_items == 1
assert response.pagination.total_pages == 1
assert response.pagination.current_page == 1
assert response.pagination.page_size == 10
def test_create_library_agent_preset_request():
request = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
name="New Preset",
description="New preset description",
agent_id="agent-123",
agent_version=1,
is_active=True,
inputs={
"dictionary": {"key1": "Hello", "key2": "World"},
"selected_value": "key2",
},
)
assert request.name == "New Preset"
assert request.description == "New preset description"
assert request.agent_id == "agent-123"
assert request.agent_version == 1
assert request.is_active is True
assert request.inputs == {
"dictionary": {"key1": "Hello", "key2": "World"},
"selected_value": "key2",
}
def test_library_agent_from_db():
# Create mock DB agent
db_agent = prisma.models.AgentPreset(
id="test-agent-123",
@@ -27,13 +161,13 @@ async def test_agent_preset_from_db():
id="input-123",
time=datetime.datetime.now(),
name="input1",
data=json.dumps({"type": "string", "value": "test value"}), # type: ignore
data=prisma.fields.Json({"type": "string", "value": "test value"}),
)
],
)
# Convert to LibraryAgentPreset
agent = library_model.LibraryAgentPreset.from_db(db_agent)
agent = backend.server.v2.library.model.LibraryAgentPreset.from_db(db_agent)
assert agent.id == "test-agent-123"
assert agent.agent_version == 1

View File

@@ -1,9 +1,8 @@
import logging
from typing import Optional
from typing import Annotated, Sequence
import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
from fastapi.responses import JSONResponse
import fastapi
import backend.server.v2.library.db as library_db
import backend.server.v2.library.model as library_model
@@ -11,182 +10,129 @@ import backend.server.v2.store.exceptions as store_exceptions
logger = logging.getLogger(__name__)
router = APIRouter(
prefix="/agents",
tags=["library", "private"],
dependencies=[Depends(autogpt_auth_lib.auth_middleware)],
)
router = fastapi.APIRouter()
@router.get(
"",
responses={
500: {"description": "Server error", "content": {"application/json": {}}},
},
"/agents",
tags=["library", "private"],
dependencies=[fastapi.Depends(autogpt_auth_lib.auth_middleware)],
)
async def list_library_agents(
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
search_term: Optional[str] = Query(
None, description="Search term to filter agents"
),
sort_by: library_model.LibraryAgentSort = Query(
library_model.LibraryAgentSort.UPDATED_AT,
description="Criteria to sort results by",
),
page: int = Query(
1,
ge=1,
description="Page number to retrieve (must be >= 1)",
),
page_size: int = Query(
15,
ge=1,
description="Number of agents per page (must be >= 1)",
),
) -> library_model.LibraryAgentResponse:
async def get_library_agents(
user_id: Annotated[str, fastapi.Depends(autogpt_auth_lib.depends.get_user_id)]
) -> Sequence[library_model.LibraryAgent]:
"""
Get all agents in the user's library (both created and saved).
Args:
user_id: ID of the authenticated user.
search_term: Optional search term to filter agents by name/description.
filter_by: List of filters to apply (favorites, created by user).
sort_by: List of sorting criteria (created date, updated date).
page: Page number to retrieve.
page_size: Number of agents per page.
Returns:
A LibraryAgentResponse containing agents and pagination metadata.
Raises:
HTTPException: If a server/database error occurs.
Get all agents in the user's library, including both created and saved agents.
"""
try:
return await library_db.list_library_agents(
user_id=user_id,
search_term=search_term,
sort_by=sort_by,
page=page,
page_size=page_size,
)
agents = await library_db.get_library_agents(user_id)
return agents
except Exception as e:
logger.error(f"Could not fetch library agents: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get library agents",
) from e
@router.get("/{library_agent_id}")
async def get_library_agent(
library_agent_id: str,
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> library_model.LibraryAgent:
return await library_db.get_library_agent(id=library_agent_id, user_id=user_id)
logger.exception(f"Exception occurred whilst getting library agents: {e}")
raise fastapi.HTTPException(
status_code=500, detail="Failed to get library agents"
)
@router.post(
"",
status_code=status.HTTP_201_CREATED,
responses={
201: {"description": "Agent added successfully"},
404: {"description": "Store listing version not found"},
500: {"description": "Server error"},
},
"/agents/{store_listing_version_id}",
tags=["library", "private"],
dependencies=[fastapi.Depends(autogpt_auth_lib.auth_middleware)],
status_code=201,
)
async def add_marketplace_agent_to_library(
store_listing_version_id: str = Body(embed=True),
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> library_model.LibraryAgent:
async def add_agent_to_library(
store_listing_version_id: str,
user_id: Annotated[str, fastapi.Depends(autogpt_auth_lib.depends.get_user_id)],
) -> fastapi.Response:
"""
Add an agent from the marketplace to the user's library.
Add an agent from the store to the user's library.
Args:
store_listing_version_id: ID of the store listing version to add.
user_id: ID of the authenticated user.
store_listing_version_id (str): ID of the store listing version to add
user_id (str): ID of the authenticated user
Returns:
library_model.LibraryAgent: Agent added to the library
fastapi.Response: 201 status code on success
Raises:
HTTPException(404): If the listing version is not found.
HTTPException(500): If a server/database error occurs.
HTTPException: If there is an error adding the agent to the library
"""
try:
return await library_db.add_store_agent_to_library(
store_listing_version_id=store_listing_version_id,
user_id=user_id,
)
# Use the database function to add the agent to the library
await library_db.add_store_agent_to_library(store_listing_version_id, user_id)
return fastapi.Response(status_code=201)
except store_exceptions.AgentNotFoundError:
logger.warning(f"Agent not found: {store_listing_version_id}")
raise HTTPException(
raise fastapi.HTTPException(
status_code=404,
detail=f"Store listing version {store_listing_version_id} not found",
)
except store_exceptions.DatabaseError as e:
logger.error(f"Database error occurred whilst adding agent to library: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to add agent to library",
) from e
logger.exception(f"Database error occurred whilst adding agent to library: {e}")
raise fastapi.HTTPException(
status_code=500, detail="Failed to add agent to library"
)
except Exception as e:
logger.error(f"Unexpected error while adding agent: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to add agent to library",
) from e
logger.exception(
f"Unexpected exception occurred whilst adding agent to library: {e}"
)
raise fastapi.HTTPException(
status_code=500, detail="Failed to add agent to library"
)
@router.put(
"/{library_agent_id}",
status_code=status.HTTP_204_NO_CONTENT,
responses={
204: {"description": "Agent updated successfully"},
500: {"description": "Server error"},
},
"/agents/{library_agent_id}",
tags=["library", "private"],
dependencies=[fastapi.Depends(autogpt_auth_lib.auth_middleware)],
status_code=204,
)
async def update_library_agent(
library_agent_id: str,
payload: library_model.LibraryAgentUpdateRequest,
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> JSONResponse:
user_id: Annotated[str, fastapi.Depends(autogpt_auth_lib.depends.get_user_id)],
auto_update_version: bool = False,
is_favorite: bool = False,
is_archived: bool = False,
is_deleted: bool = False,
) -> fastapi.Response:
"""
Update the library agent with the given fields.
Args:
library_agent_id: ID of the library agent to update.
payload: Fields to update (auto_update_version, is_favorite, etc.).
user_id: ID of the authenticated user.
library_agent_id (str): ID of the library agent to update
user_id (str): ID of the authenticated user
auto_update_version (bool): Whether to auto-update the agent version
is_favorite (bool): Whether the agent is marked as favorite
is_archived (bool): Whether the agent is archived
is_deleted (bool): Whether the agent is deleted
Returns:
204 (No Content) on success.
fastapi.Response: 204 status code on success
Raises:
HTTPException(500): If a server/database error occurs.
HTTPException: If there is an error updating the library agent
"""
try:
# Use the database function to update the library agent
await library_db.update_library_agent(
library_agent_id=library_agent_id,
user_id=user_id,
auto_update_version=payload.auto_update_version,
is_favorite=payload.is_favorite,
is_archived=payload.is_archived,
is_deleted=payload.is_deleted,
)
return JSONResponse(
status_code=status.HTTP_204_NO_CONTENT,
content={"message": "Agent updated successfully"},
library_agent_id,
user_id,
auto_update_version,
is_favorite,
is_archived,
is_deleted,
)
return fastapi.Response(status_code=204)
except store_exceptions.DatabaseError as e:
logger.exception(f"Database error while updating library agent: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update library agent",
) from e
logger.exception(f"Database error occurred whilst updating library agent: {e}")
raise fastapi.HTTPException(
status_code=500, detail="Failed to update library agent"
)
except Exception as e:
logger.exception(f"Unexpected error while updating library agent: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update library agent",
) from e
logger.exception(
f"Unexpected exception occurred whilst updating library agent: {e}"
)
raise fastapi.HTTPException(
status_code=500, detail="Failed to update library agent"
)

View File

@@ -3,225 +3,113 @@ from typing import Annotated, Any
import autogpt_libs.auth as autogpt_auth_lib
import autogpt_libs.utils.cache
from fastapi import APIRouter, Body, Depends, HTTPException, status
import fastapi
import backend.executor
import backend.server.v2.library.db as db
import backend.server.v2.library.model as models
import backend.server.v2.library.db as library_db
import backend.server.v2.library.model as library_model
import backend.util.service
logger = logging.getLogger(__name__)
router = APIRouter()
router = fastapi.APIRouter()
@autogpt_libs.utils.cache.thread_cached
def execution_manager_client() -> backend.executor.ExecutionManager:
"""Return a cached instance of ExecutionManager client."""
return backend.util.service.get_service_client(backend.executor.ExecutionManager)
@router.get(
"/presets",
summary="List presets",
description="Retrieve a paginated list of presets for the current user.",
)
@router.get("/presets")
async def get_presets(
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
user_id: Annotated[str, fastapi.Depends(autogpt_auth_lib.depends.get_user_id)],
page: int = 1,
page_size: int = 10,
) -> models.LibraryAgentPresetResponse:
"""
Retrieve a paginated list of presets for the current user.
Args:
user_id (str): ID of the authenticated user.
page (int): Page number for pagination.
page_size (int): Number of items per page.
Returns:
models.LibraryAgentPresetResponse: A response containing the list of presets.
"""
) -> library_model.LibraryAgentPresetResponse:
try:
return await db.get_presets(user_id, page, page_size)
presets = await library_db.get_presets(user_id, page, page_size)
return presets
except Exception as e:
logger.exception(f"Exception occurred while getting presets: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get presets",
)
logger.exception(f"Exception occurred whilst getting presets: {e}")
raise fastapi.HTTPException(status_code=500, detail="Failed to get presets")
@router.get(
"/presets/{preset_id}",
summary="Get a specific preset",
description="Retrieve details for a specific preset by its ID.",
)
@router.get("/presets/{preset_id}")
async def get_preset(
preset_id: str,
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> models.LibraryAgentPreset:
"""
Retrieve details for a specific preset by its ID.
Args:
preset_id (str): ID of the preset to retrieve.
user_id (str): ID of the authenticated user.
Returns:
models.LibraryAgentPreset: The preset details.
Raises:
HTTPException: If the preset is not found or an error occurs.
"""
user_id: Annotated[str, fastapi.Depends(autogpt_auth_lib.depends.get_user_id)],
) -> library_model.LibraryAgentPreset:
try:
preset = await db.get_preset(user_id, preset_id)
preset = await library_db.get_preset(user_id, preset_id)
if not preset:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
raise fastapi.HTTPException(
status_code=404,
detail=f"Preset {preset_id} not found",
)
return preset
except Exception as e:
logger.exception(f"Exception occurred whilst getting preset: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get preset",
)
raise fastapi.HTTPException(status_code=500, detail="Failed to get preset")
@router.post(
"/presets",
summary="Create a new preset",
description="Create a new preset for the current user.",
)
@router.post("/presets")
async def create_preset(
preset: models.CreateLibraryAgentPresetRequest,
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> models.LibraryAgentPreset:
"""
Create a new library agent preset. Automatically corrects node_input format if needed.
Args:
preset (models.CreateLibraryAgentPresetRequest): The preset data to create.
user_id (str): ID of the authenticated user.
Returns:
models.LibraryAgentPreset: The created preset.
Raises:
HTTPException: If an error occurs while creating the preset.
"""
preset: library_model.CreateLibraryAgentPresetRequest,
user_id: Annotated[str, fastapi.Depends(autogpt_auth_lib.depends.get_user_id)],
) -> library_model.LibraryAgentPreset:
try:
return await db.upsert_preset(user_id, preset)
return await library_db.upsert_preset(user_id, preset)
except Exception as e:
logger.exception(f"Exception occurred while creating preset: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create preset",
)
logger.exception(f"Exception occurred whilst creating preset: {e}")
raise fastapi.HTTPException(status_code=500, detail="Failed to create preset")
@router.put(
"/presets/{preset_id}",
summary="Update an existing preset",
description="Update an existing preset by its ID.",
)
@router.put("/presets/{preset_id}")
async def update_preset(
preset_id: str,
preset: models.CreateLibraryAgentPresetRequest,
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> models.LibraryAgentPreset:
"""
Update an existing library agent preset. If the preset doesn't exist, it may be created.
Args:
preset_id (str): ID of the preset to update.
preset (models.CreateLibraryAgentPresetRequest): The preset data to update.
user_id (str): ID of the authenticated user.
Returns:
models.LibraryAgentPreset: The updated preset.
Raises:
HTTPException: If an error occurs while updating the preset.
"""
preset: library_model.CreateLibraryAgentPresetRequest,
user_id: Annotated[str, fastapi.Depends(autogpt_auth_lib.depends.get_user_id)],
) -> library_model.LibraryAgentPreset:
try:
return await db.upsert_preset(user_id, preset, preset_id)
return await library_db.upsert_preset(user_id, preset, preset_id)
except Exception as e:
logger.exception(f"Exception occurred whilst updating preset: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update preset",
)
raise fastapi.HTTPException(status_code=500, detail="Failed to update preset")
@router.delete(
"/presets/{preset_id}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Delete a preset",
description="Delete an existing preset by its ID.",
)
@router.delete("/presets/{preset_id}")
async def delete_preset(
preset_id: str,
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> None:
"""
Delete a preset by its ID. Returns 204 No Content on success.
Args:
preset_id (str): ID of the preset to delete.
user_id (str): ID of the authenticated user.
Raises:
HTTPException: If an error occurs while deleting the preset.
"""
user_id: Annotated[str, fastapi.Depends(autogpt_auth_lib.depends.get_user_id)],
):
try:
await db.delete_preset(user_id, preset_id)
await library_db.delete_preset(user_id, preset_id)
return fastapi.Response(status_code=204)
except Exception as e:
logger.exception(f"Exception occurred whilst deleting preset: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to delete preset",
)
raise fastapi.HTTPException(status_code=500, detail="Failed to delete preset")
@router.post(
"/presets/{preset_id}/execute",
path="/presets/{preset_id}/execute",
tags=["presets"],
summary="Execute a preset",
description="Execute a preset with the given graph and node input for the current user.",
dependencies=[fastapi.Depends(autogpt_auth_lib.auth_middleware)],
)
async def execute_preset(
graph_id: str,
graph_version: int,
preset_id: str,
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
node_input: Annotated[
dict[str, Any], fastapi.Body(..., embed=True, default_factory=dict)
],
user_id: Annotated[str, fastapi.Depends(autogpt_auth_lib.depends.get_user_id)],
) -> dict[str, Any]: # FIXME: add proper return type
"""
Execute a preset given graph parameters, returning the execution ID on success.
Args:
graph_id (str): ID of the graph to execute.
graph_version (int): Version of the graph to execute.
preset_id (str): ID of the preset to execute.
node_input (Dict[Any, Any]): Input data for the node.
user_id (str): ID of the authenticated user.
Returns:
Dict[str, Any]: A response containing the execution ID.
Raises:
HTTPException: If the preset is not found or an error occurs while executing the preset.
"""
try:
preset = await db.get_preset(user_id, preset_id)
preset = await library_db.get_preset(user_id, preset_id)
if not preset:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Preset not found",
)
raise fastapi.HTTPException(status_code=404, detail="Preset not found")
logger.debug(f"Preset inputs: {preset.inputs}")
# Merge input overrides with preset inputs
merged_node_input = preset.inputs | node_input
@@ -237,11 +125,6 @@ async def execute_preset(
logger.debug(f"Execution added: {execution} with input: {merged_node_input}")
return {"id": execution.graph_exec_id}
except HTTPException:
raise
except Exception as e:
logger.exception(f"Exception occurred while executing preset: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
msg = str(e).encode().decode("unicode_escape")
raise fastapi.HTTPException(status_code=400, detail=msg)

View File

@@ -1,11 +1,11 @@
import datetime
import autogpt_libs.auth as autogpt_auth_lib
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
import backend.server.model as server_model
import backend.server.v2.library.model as library_model
from backend.server.v2.library.routes import router as library_router
@@ -29,81 +29,61 @@ app.dependency_overrides[autogpt_auth_lib.auth_middleware] = override_auth_middl
app.dependency_overrides[autogpt_auth_lib.depends.get_user_id] = override_get_user_id
@pytest.mark.asyncio
async def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
mocked_value = library_model.LibraryAgentResponse(
agents=[
library_model.LibraryAgent(
id="test-agent-1",
agent_id="test-agent-1",
agent_version=1,
name="Test Agent 1",
description="Test Description 1",
image_url=None,
creator_name="Test Creator",
creator_image_url="",
input_schema={"type": "object", "properties": {}},
status=library_model.LibraryAgentStatus.COMPLETED,
new_output=False,
can_access_graph=True,
is_latest_version=True,
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
),
library_model.LibraryAgent(
id="test-agent-2",
agent_id="test-agent-2",
agent_version=1,
name="Test Agent 2",
description="Test Description 2",
image_url=None,
creator_name="Test Creator",
creator_image_url="",
input_schema={"type": "object", "properties": {}},
status=library_model.LibraryAgentStatus.COMPLETED,
new_output=False,
can_access_graph=False,
is_latest_version=True,
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
),
],
pagination=server_model.Pagination(
total_items=2, total_pages=1, current_page=1, page_size=50
def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
mocked_value = [
library_model.LibraryAgent(
id="test-agent-1",
agent_id="test-agent-1",
agent_version=1,
preset_id="preset-1",
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
is_favorite=False,
is_created_by_user=True,
is_latest_version=True,
name="Test Agent 1",
description="Test Description 1",
input_schema={"type": "object", "properties": {}},
output_schema={"type": "object", "properties": {}},
),
)
mock_db_call = mocker.patch("backend.server.v2.library.db.list_library_agents")
library_model.LibraryAgent(
id="test-agent-2",
agent_id="test-agent-2",
agent_version=1,
preset_id="preset-2",
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
is_favorite=False,
is_created_by_user=False,
is_latest_version=True,
name="Test Agent 2",
description="Test Description 2",
input_schema={"type": "object", "properties": {}},
output_schema={"type": "object", "properties": {}},
),
]
mock_db_call = mocker.patch("backend.server.v2.library.db.get_library_agents")
mock_db_call.return_value = mocked_value
response = client.get("/agents?search_term=test")
response = client.get("/agents")
assert response.status_code == 200
data = library_model.LibraryAgentResponse.model_validate(response.json())
assert len(data.agents) == 2
assert data.agents[0].agent_id == "test-agent-1"
assert data.agents[0].can_access_graph is True
assert data.agents[1].agent_id == "test-agent-2"
assert data.agents[1].can_access_graph is False
mock_db_call.assert_called_once_with(
user_id="test-user-id",
search_term="test",
sort_by=library_model.LibraryAgentSort.UPDATED_AT,
page=1,
page_size=15,
)
data = [
library_model.LibraryAgent.model_validate(agent) for agent in response.json()
]
assert len(data) == 2
assert data[0].agent_id == "test-agent-1"
assert data[0].is_created_by_user is True
assert data[1].agent_id == "test-agent-2"
assert data[1].is_created_by_user is False
mock_db_call.assert_called_once_with("test-user-id")
def test_get_library_agents_error(mocker: pytest_mock.MockFixture):
mock_db_call = mocker.patch("backend.server.v2.library.db.list_library_agents")
mock_db_call = mocker.patch("backend.server.v2.library.db.get_library_agents")
mock_db_call.side_effect = Exception("Test error")
response = client.get("/agents?search_term=test")
response = client.get("/agents")
assert response.status_code == 500
mock_db_call.assert_called_once_with(
user_id="test-user-id",
search_term="test",
sort_by=library_model.LibraryAgentSort.UPDATED_AT,
page=1,
page_size=15,
)
mock_db_call.assert_called_once_with("test-user-id")
@pytest.mark.skip(reason="Mocker Not implemented")

View File

@@ -1,34 +0,0 @@
from typing import Any, Dict, Optional
from pydantic import BaseModel
class Document(BaseModel):
url: str
relevance_score: float
class ApiResponse(BaseModel):
answer: str
documents: list[Document]
success: bool
class GraphData(BaseModel):
nodes: list[Dict[str, Any]]
edges: list[Dict[str, Any]]
graph_name: Optional[str] = None
graph_description: Optional[str] = None
class Message(BaseModel):
query: str
response: str
class ChatRequest(BaseModel):
query: str
conversation_history: list[Message]
message_id: str
include_graph_data: bool = False
graph_id: Optional[str] = None

View File

@@ -1,26 +0,0 @@
import logging
from autogpt_libs.auth.middleware import auth_middleware
from fastapi import APIRouter, Depends
from backend.server.utils import get_user_id
from .models import ApiResponse, ChatRequest
from .service import OttoService
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post(
"/ask", response_model=ApiResponse, dependencies=[Depends(auth_middleware)]
)
async def proxy_otto_request(
request: ChatRequest, user_id: str = Depends(get_user_id)
) -> ApiResponse:
"""
Proxy requests to Otto API while adding necessary security headers and logging.
Requires an authenticated user.
"""
return await OttoService.ask(request, user_id)

View File

@@ -1,138 +0,0 @@
import asyncio
import logging
from typing import Optional
import aiohttp
from fastapi import HTTPException
from backend.data import graph as graph_db
from backend.data.block import get_block
from backend.util.settings import Settings
from .models import ApiResponse, ChatRequest, GraphData
logger = logging.getLogger(__name__)
settings = Settings()
OTTO_API_URL = settings.config.otto_api_url
class OttoService:
@staticmethod
async def _fetch_graph_data(
request: ChatRequest, user_id: str
) -> Optional[GraphData]:
"""Fetch graph data if requested and available."""
if not (request.include_graph_data and request.graph_id):
return None
try:
graph = await graph_db.get_graph(request.graph_id, user_id=user_id)
if not graph:
return None
nodes_data = []
for node in graph.nodes:
block = get_block(node.block_id)
if not block:
continue
node_data = {
"id": node.id,
"block_id": node.block_id,
"block_name": block.name,
"block_type": (
block.block_type.value if hasattr(block, "block_type") else None
),
"data": {
k: v
for k, v in (node.input_default or {}).items()
if k not in ["credentials"] # Exclude sensitive data
},
}
nodes_data.append(node_data)
# Create a GraphData object with the required fields
return GraphData(
nodes=nodes_data,
edges=[],
graph_name=graph.name,
graph_description=graph.description,
)
except Exception as e:
logger.error(f"Failed to fetch graph data: {str(e)}")
return None
@staticmethod
async def ask(request: ChatRequest, user_id: str) -> ApiResponse:
"""
Send request to Otto API and handle the response.
"""
# Check if Otto API URL is configured
if not OTTO_API_URL:
logger.error("Otto API URL is not configured")
raise HTTPException(
status_code=503, detail="Otto service is not configured"
)
try:
async with aiohttp.ClientSession() as session:
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
}
# If graph data is requested, fetch it
graph_data = await OttoService._fetch_graph_data(request, user_id)
# Prepare the payload with optional graph data
payload = {
"query": request.query,
"conversation_history": [
msg.model_dump() for msg in request.conversation_history
],
"user_id": user_id,
"message_id": request.message_id,
}
if graph_data:
payload["graph_data"] = graph_data.model_dump()
logger.info(f"Sending request to Otto API for user {user_id}")
logger.debug(f"Request payload: {payload}")
async with session.post(
OTTO_API_URL,
json=payload,
headers=headers,
timeout=aiohttp.ClientTimeout(total=60),
) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"Otto API error: {error_text}")
raise HTTPException(
status_code=response.status,
detail=f"Otto API request failed: {error_text}",
)
data = await response.json()
logger.info(
f"Successfully received response from Otto API for user {user_id}"
)
return ApiResponse(**data)
except aiohttp.ClientError as e:
logger.error(f"Connection error to Otto API: {str(e)}")
raise HTTPException(
status_code=503, detail="Failed to connect to Otto service"
)
except asyncio.TimeoutError:
logger.error("Timeout error connecting to Otto API after 60 seconds")
raise HTTPException(
status_code=504, detail="Request to Otto service timed out"
)
except Exception as e:
logger.error(f"Unexpected error in Otto API proxy: {str(e)}")
raise HTTPException(
status_code=500, detail="Internal server error in Otto proxy"
)

View File

@@ -1,212 +0,0 @@
from enum import Enum
from typing import Literal
from pydantic import BaseModel
# Models from https://account.postmarkapp.com/servers/<id>/streams/outbound/webhooks/new
class PostmarkDeliveryWebhook(BaseModel):
RecordType: Literal["Delivery"] = "Delivery"
ServerID: int
MessageStream: str
MessageID: str
Recipient: str
Tag: str
DeliveredAt: str
Details: str
Metadata: dict[str, str]
class PostmarkBounceEnum(Enum):
HardBounce = 1
"""
The server was unable to deliver your message (ex: unknown user, mailbox not found).
"""
Transient = 2
"""
The server could not temporarily deliver your message (ex: Message is delayed due to network troubles).
"""
Unsubscribe = 16
"""
Unsubscribe or Remove request.
"""
Subscribe = 32
"""
Subscribe request from someone wanting to get added to the mailing list.
"""
AutoResponder = 64
"""
"Autoresponder" is an automatic email responder including nondescript NDRs and some "out of office" replies.
"""
AddressChange = 128
"""
The recipient has requested an address change.
"""
DnsError = 256
"""
A temporary DNS error.
"""
SpamNotification = 512
"""
The message was delivered, but was either blocked by the user, or classified as spam, bulk mail, or had rejected content.
"""
OpenRelayTest = 1024
"""
The NDR is actually a test email message to see if the mail server is an open relay.
"""
Unknown = 2048
"""
Unable to classify the NDR.
"""
SoftBounce = 4096
"""
Unable to temporarily deliver message (i.e. mailbox full, account disabled, exceeds quota, out of disk space).
"""
VirusNotification = 8192
"""
The bounce is actually a virus notification warning about a virus/code infected message.
"""
ChallengeVerification = 16384
"""
The bounce is a challenge asking for verification you actually sent the email. Typcial challenges are made by Spam Arrest, or MailFrontier Matador.
"""
BadEmailAddress = 100000
"""
The address is not a valid email address.
"""
SpamComplaint = 100001
"""
The subscriber explicitly marked this message as spam.
"""
ManuallyDeactivated = 100002
"""
The email was manually deactivated.
"""
Unconfirmed = 100003
"""
Registration not confirmed — The subscriber has not clicked on the confirmation link upon registration or import.
"""
Blocked = 100006
"""
Blocked from this ISP due to content or blacklisting.
"""
SMTPApiError = 100007
"""
An error occurred while accepting an email through the SMTP API.
"""
InboundError = 100008
"""
Processing failed — Unable to deliver inbound message to destination inbound hook.
"""
DMARCPolicy = 100009
"""
Email rejected due DMARC Policy.
"""
TemplateRenderingFailed = 100010
"""
Template rendering failed — An error occurred while attempting to render your template.
"""
class PostmarkBounceWebhook(BaseModel):
RecordType: Literal["Bounce"] = "Bounce"
ID: int
Type: str
TypeCode: PostmarkBounceEnum
Tag: str
MessageID: str
Details: str
Email: str
From: str
BouncedAt: str
Inactive: bool
DumpAvailable: bool
CanActivate: bool
Subject: str
ServerID: int
MessageStream: str
Content: str
Name: str
Description: str
Metadata: dict[str, str]
class PostmarkSpamComplaintWebhook(BaseModel):
RecordType: Literal["SpamComplaint"] = "SpamComplaint"
ID: int
Type: str
TypeCode: int
Tag: str
MessageID: str
Details: str
Email: str
From: str
BouncedAt: str
Inactive: bool
DumpAvailable: bool
CanActivate: bool
Subject: str
ServerID: int
MessageStream: str
Content: str
Name: str
Description: str
Metadata: dict[str, str]
class PostmarkOpenWebhook(BaseModel):
RecordType: Literal["Open"] = "Open"
MessageStream: str
Metadata: dict[str, str]
FirstOpen: bool
Recipient: str
MessageID: str
ReceivedAt: str
Platform: str
ReadSeconds: int
Tag: str
UserAgent: str
OS: dict[str, str]
Client: dict[str, str]
Geo: dict[str, str]
class PostmarkClickWebhook(BaseModel):
RecordType: Literal["Click"] = "Click"
MessageStream: str
Metadata: dict[str, str]
Recipient: str
MessageID: str
ReceivedAt: str
Platform: str
ClickLocation: str
OriginalLink: str
Tag: str
UserAgent: str
OS: dict[str, str]
Client: dict[str, str]
Geo: dict[str, str]
class PostmarkSubscriptionChangeWebhook(BaseModel):
RecordType: Literal["SubscriptionChange"] = "SubscriptionChange"
MessageID: str
ServerID: int
MessageStream: str
ChangedAt: str
Recipient: str
Origin: str
SuppressSending: bool
SuppressionReason: str
Tag: str
Metadata: dict[str, str]
PostmarkWebhook = (
PostmarkDeliveryWebhook
| PostmarkBounceWebhook
| PostmarkSpamComplaintWebhook
| PostmarkOpenWebhook
| PostmarkClickWebhook
| PostmarkSubscriptionChangeWebhook
)

View File

@@ -1,116 +0,0 @@
import logging
from typing import Annotated
from autogpt_libs.auth.middleware import APIKeyValidator
from fastapi import APIRouter, Body, Depends, Query
from fastapi.responses import JSONResponse
from backend.data.user import (
get_user_by_email,
set_user_email_verification,
unsubscribe_user_by_token,
)
from backend.server.v2.postmark.models import (
PostmarkBounceEnum,
PostmarkBounceWebhook,
PostmarkClickWebhook,
PostmarkDeliveryWebhook,
PostmarkOpenWebhook,
PostmarkSpamComplaintWebhook,
PostmarkSubscriptionChangeWebhook,
PostmarkWebhook,
)
from backend.util.settings import Settings
settings = Settings()
postmark_validator = APIKeyValidator(
"X-Postmark-Webhook-Token",
settings.secrets.postmark_webhook_token,
)
router = APIRouter()
logger = logging.getLogger(__name__)
@router.post("/unsubscribe")
async def unsubscribe_via_one_click(token: Annotated[str, Query()]):
logger.info(f"Received unsubscribe request from One Click Unsubscribe: {token}")
try:
await unsubscribe_user_by_token(token)
except Exception as e:
logger.error(f"Failed to unsubscribe user by token {token}: {e}")
raise e
return JSONResponse(status_code=200, content={"status": "ok"})
@router.post("/", dependencies=[Depends(postmark_validator.get_dependency())])
async def postmark_webhook_handler(
webhook: Annotated[
PostmarkWebhook,
Body(discriminator="RecordType"),
]
):
logger.info(f"Received webhook from Postmark: {webhook}")
match webhook:
case PostmarkDeliveryWebhook():
delivery_handler(webhook)
case PostmarkBounceWebhook():
await bounce_handler(webhook)
case PostmarkSpamComplaintWebhook():
spam_handler(webhook)
case PostmarkOpenWebhook():
open_handler(webhook)
case PostmarkClickWebhook():
click_handler(webhook)
case PostmarkSubscriptionChangeWebhook():
subscription_handler(webhook)
case _:
logger.warning(f"Unknown webhook type: {type(webhook)}")
return
async def bounce_handler(event: PostmarkBounceWebhook):
logger.info(f"Bounce handler {event=}")
if event.TypeCode in [
PostmarkBounceEnum.Transient,
PostmarkBounceEnum.SoftBounce,
PostmarkBounceEnum.DnsError,
]:
logger.info(
f"Softish bounce: {event.TypeCode} for {event.Email}, not setting email verification to false"
)
return
logger.info(f"{event.Email=}")
user = await get_user_by_email(event.Email)
if not user:
logger.error(f"User not found for email: {event.Email}")
return
await set_user_email_verification(user.id, False)
logger.debug(f"Setting email verification to false for user: {user.id}")
def spam_handler(event: PostmarkSpamComplaintWebhook):
logger.info("Spam handler")
pass
def delivery_handler(event: PostmarkDeliveryWebhook):
logger.info("Delivery handler")
pass
def open_handler(event: PostmarkOpenWebhook):
logger.info("Open handler")
pass
def click_handler(event: PostmarkClickWebhook):
logger.info("Click handler")
pass
def subscription_handler(event: PostmarkSubscriptionChangeWebhook):
logger.info("Subscription handler")
pass

View File

@@ -1,5 +1,6 @@
import logging
from datetime import datetime
from typing import Optional
import fastapi
import prisma.enums
@@ -83,30 +84,20 @@ async def get_store_agents(
)
total_pages = (total + page_size - 1) // page_size
store_agents: list[backend.server.v2.store.model.StoreAgent] = []
for agent in agents:
try:
# Create the StoreAgent object safely
store_agent = backend.server.v2.store.model.StoreAgent(
slug=agent.slug,
agent_name=agent.agent_name,
agent_image=agent.agent_image[0] if agent.agent_image else "",
creator=agent.creator_username or "Needs Profile",
creator_avatar=agent.creator_avatar or "",
sub_heading=agent.sub_heading,
description=agent.description,
runs=agent.runs,
rating=agent.rating,
)
# Add to the list only if creation was successful
store_agents.append(store_agent)
except Exception as e:
# Skip this agent if there was an error
# You could log the error here if needed
logger.error(
f"Error parsing Store agent when getting store agents from db: {e}"
)
continue
store_agents = [
backend.server.v2.store.model.StoreAgent(
slug=agent.slug,
agent_name=agent.agent_name,
agent_image=agent.agent_image[0] if agent.agent_image else "",
creator=agent.creator_username or "Needs Profile",
creator_avatar=agent.creator_avatar or "",
sub_heading=agent.sub_heading,
description=agent.description,
runs=agent.runs,
rating=agent.rating,
)
for agent in agents
]
logger.debug(f"Found {len(store_agents)} agents")
return backend.server.v2.store.model.StoreAgentsResponse(
@@ -610,7 +601,7 @@ async def get_user_profile(
avatar_url=profile.avatarUrl,
)
except Exception as e:
logger.error(f"Error getting user profile: {e}")
logger.error("Error getting user profile: %s", e)
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to get user profile"
) from e
@@ -629,7 +620,7 @@ async def update_profile(
Raises:
DatabaseError: If there's an issue updating or creating the profile
"""
logger.info(f"Updating profile for user {user_id} with data: {profile}")
logger.info("Updating profile for user %s with data: %s", user_id, profile)
try:
# Sanitize username to allow only letters, numbers, and hyphens
username = "".join(
@@ -648,13 +639,15 @@ async def update_profile(
# Verify that the user is authorized to update this profile
if existing_profile.userId != user_id:
logger.error(
f"Unauthorized update attempt for profile {existing_profile.id} by user {user_id}"
"Unauthorized update attempt for profile %s by user %s",
existing_profile.userId,
user_id,
)
raise backend.server.v2.store.exceptions.DatabaseError(
f"Unauthorized update attempt for profile {existing_profile.id} by user {user_id}"
)
logger.debug(f"Updating existing profile for user {user_id}")
logger.debug("Updating existing profile for user %s", user_id)
# Prepare update data, only including non-None values
update_data = {}
if profile.name is not None:
@@ -674,7 +667,7 @@ async def update_profile(
data=prisma.types.ProfileUpdateInput(**update_data),
)
if updated_profile is None:
logger.error(f"Failed to update profile for user {user_id}")
logger.error("Failed to update profile for user %s", user_id)
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to update profile"
)
@@ -691,7 +684,7 @@ async def update_profile(
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error updating profile: {e}")
logger.error("Database error updating profile: %s", e)
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to update profile"
) from e
@@ -705,35 +698,45 @@ async def get_my_agents(
logger.debug(f"Getting my agents for user {user_id}, page={page}")
try:
search_filter: prisma.types.LibraryAgentWhereInput = {
"userId": user_id,
"Agent": {"is": {"StoreListing": {"none": {"isDeleted": False}}}},
"isArchived": False,
"isDeleted": False,
}
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
where=search_filter,
order=[{"agentVersion": "desc"}],
agents_with_max_version = await prisma.models.AgentGraph.prisma().find_many(
where=prisma.types.AgentGraphWhereInput(
userId=user_id, StoreListing={"none": {"isDeleted": False}}
),
order=[{"version": "desc"}],
distinct=["id"],
skip=(page - 1) * page_size,
take=page_size,
include={"Agent": True},
)
total = await prisma.models.LibraryAgent.prisma().count(where=search_filter)
# store_listings = await prisma.models.StoreListing.prisma().find_many(
# where=prisma.types.StoreListingWhereInput(
# isDeleted=False,
# ),
# )
total = len(
await prisma.models.AgentGraph.prisma().find_many(
where=prisma.types.AgentGraphWhereInput(
userId=user_id, StoreListing={"none": {"isDeleted": False}}
),
order=[{"version": "desc"}],
distinct=["id"],
)
)
total_pages = (total + page_size - 1) // page_size
agents = agents_with_max_version
my_agents = [
backend.server.v2.store.model.MyAgent(
agent_id=graph.id,
agent_version=graph.version,
agent_name=graph.name or "",
last_edited=graph.updatedAt or graph.createdAt,
description=graph.description or "",
agent_image=library_agent.imageUrl,
agent_id=agent.id,
agent_version=agent.version,
agent_name=agent.name or "",
last_edited=agent.updatedAt or agent.createdAt,
description=agent.description or "",
)
for library_agent in library_agents
if (graph := library_agent.Agent)
for agent in agents
]
return backend.server.v2.store.model.MyAgentsResponse(
@@ -753,31 +756,47 @@ async def get_my_agents(
async def get_agent(
user_id: str,
store_listing_version_id: str,
store_listing_version_id: str, version_id: Optional[int]
) -> GraphModel:
"""Get agent using the version ID and store listing version ID."""
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id}
)
)
if not store_listing_version:
raise ValueError(f"Store listing version {store_listing_version_id} not found")
graph = await backend.data.graph.get_graph(
user_id=user_id,
graph_id=store_listing_version.agentId,
version=store_listing_version.agentVersion,
for_export=True,
)
if not graph:
raise ValueError(
f"Agent {store_listing_version.agentId} v{store_listing_version.agentVersion} not found"
try:
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id}, include={"Agent": True}
)
)
return graph
if not store_listing_version or not store_listing_version.Agent:
raise fastapi.HTTPException(
status_code=404,
detail=f"Store listing version {store_listing_version_id} not found",
)
graph_id = store_listing_version.agentId
graph_version = store_listing_version.agentVersion
graph = await backend.data.graph.get_graph(graph_id, graph_version)
if not graph:
raise fastapi.HTTPException(
status_code=404,
detail=(
f"Agent #{graph_id} not found "
f"for store listing version #{store_listing_version_id}"
),
)
graph.version = 1
graph.is_template = False
graph.is_active = True
delattr(graph, "user_id")
return graph
except Exception as e:
logger.error(f"Error getting agent: {e}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch agent"
) from e
async def review_store_submission(

View File

@@ -146,6 +146,7 @@ async def test_create_store_submission(mocker):
userId="user-id",
createdAt=datetime.now(),
isActive=True,
isTemplate=False,
)
mock_listing = prisma.models.StoreListing(

View File

@@ -1,30 +1,16 @@
import asyncio
import io
import logging
from enum import Enum
import replicate
import replicate.exceptions
from prisma.models import AgentGraph
import requests
from replicate.helpers import FileOutput
from backend.blocks.ideogram import (
AspectRatio,
ColorPalettePreset,
IdeogramModelBlock,
IdeogramModelName,
MagicPromptOption,
StyleType,
UpscaleOption,
)
from backend.data.graph import Graph
from backend.data.model import CredentialsMetaInput, ProviderName
from backend.integrations.credentials_store import ideogram_credentials
from backend.util.request import requests
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
class ImageSize(str, Enum):
@@ -35,71 +21,7 @@ class ImageStyle(str, Enum):
DIGITAL_ART = "digital art"
async def generate_agent_image(agent: Graph | AgentGraph) -> io.BytesIO:
if settings.config.use_agent_image_generation_v2:
return await asyncio.to_thread(generate_agent_image_v2, graph=agent)
else:
return await generate_agent_image_v1(agent=agent)
def generate_agent_image_v2(graph: Graph | AgentGraph) -> io.BytesIO:
"""
Generate an image for an agent using Ideogram model.
Returns:
str: The URL of the generated image
"""
if not ideogram_credentials.api_key:
raise ValueError("Missing Ideogram API key")
name = graph.name
description = f"{name} ({graph.description})" if graph.description else name
prompt = (
f"Create a visually striking retro-futuristic vector pop art illustration prominently featuring "
f'"{name}" in bold typography. The image clearly and literally depicts a {description}, '
f"along with recognizable objects directly associated with the primary function of a {name}. "
f"Ensure the imagery is concrete, intuitive, and immediately understandable, clearly conveying the "
f"purpose of a {name}. Maintain vibrant, limited-palette colors, sharp vector lines, geometric "
f"shapes, flat illustration techniques, and solid colors without gradients or shading. Preserve a "
f"retro-futuristic aesthetic influenced by mid-century futurism and 1960s psychedelia, "
f"prioritizing clear visual storytelling and thematic clarity above all else."
)
custom_colors = [
"#000030",
"#1C0C47",
"#9900FF",
"#4285F4",
"#FFFFFF",
]
# Run the Ideogram model block with the specified parameters
url = IdeogramModelBlock().run_once(
IdeogramModelBlock.Input(
credentials=CredentialsMetaInput(
id=ideogram_credentials.id,
provider=ProviderName.IDEOGRAM,
title=ideogram_credentials.title,
type=ideogram_credentials.type,
),
prompt=prompt,
ideogram_model_name=IdeogramModelName.V2,
aspect_ratio=AspectRatio.ASPECT_16_9,
magic_prompt_option=MagicPromptOption.OFF,
style_type=StyleType.AUTO,
upscale=UpscaleOption.NO_UPSCALE,
color_palette_name=ColorPalettePreset.NONE,
custom_color_palette=custom_colors,
seed=None,
negative_prompt=None,
),
"result",
credentials=ideogram_credentials,
)
return io.BytesIO(requests.get(url).content)
async def generate_agent_image_v1(agent: Graph | AgentGraph) -> io.BytesIO:
async def generate_agent_image(agent: Graph) -> io.BytesIO:
"""
Generate an image for an agent using Flux model via Replicate API.
@@ -110,6 +32,8 @@ async def generate_agent_image_v1(agent: Graph | AgentGraph) -> io.BytesIO:
io.BytesIO: The generated image as bytes
"""
try:
settings = Settings()
if not settings.secrets.replicate_api_key:
raise ValueError("Missing Replicate API key in settings")
@@ -146,12 +70,14 @@ async def generate_agent_image_v1(agent: Graph | AgentGraph) -> io.BytesIO:
# If it's a URL string, fetch the image bytes
result_url = output[0]
response = requests.get(result_url)
response.raise_for_status()
image_bytes = response.content
elif isinstance(output, FileOutput):
image_bytes = output.read()
elif isinstance(output, str):
# Output is a URL
response = requests.get(output)
response.raise_for_status()
image_bytes = response.content
else:
raise RuntimeError("Unexpected output format from the model.")

View File

@@ -6,7 +6,6 @@ import fastapi
from google.cloud import storage
import backend.server.v2.store.exceptions
from backend.util.exceptions import MissingConfigError
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -28,32 +27,34 @@ async def check_media_exists(user_id: str, filename: str) -> str | None:
Returns:
str | None: URL of the blob if it exists, None otherwise
"""
settings = Settings()
if not settings.config.media_gcs_bucket_name:
raise MissingConfigError("GCS media bucket is not configured")
try:
settings = Settings()
storage_client = storage.Client()
bucket = storage_client.bucket(settings.config.media_gcs_bucket_name)
storage_client = storage.Client()
bucket = storage_client.bucket(settings.config.media_gcs_bucket_name)
# Check images
image_path = f"users/{user_id}/images/{filename}"
image_blob = bucket.blob(image_path)
if image_blob.exists():
return image_blob.public_url
# Check images
image_path = f"users/{user_id}/images/{filename}"
image_blob = bucket.blob(image_path)
if image_blob.exists():
return image_blob.public_url
# Check videos
video_path = f"users/{user_id}/videos/{filename}"
# Check videos
video_path = f"users/{user_id}/videos/{filename}"
video_blob = bucket.blob(video_path)
if video_blob.exists():
return video_blob.public_url
video_blob = bucket.blob(video_path)
if video_blob.exists():
return video_blob.public_url
return None
return None
except Exception as e:
logger.error(f"Error checking if media file exists: {str(e)}")
return None
async def upload_media(
user_id: str, file: fastapi.UploadFile, use_file_name: bool = False
) -> str:
# Get file content for deeper validation
try:
content = await file.read(1024) # Read first 1KB for validation

View File

@@ -24,7 +24,6 @@ class MyAgent(pydantic.BaseModel):
agent_id: str
agent_version: int
agent_name: str
agent_image: str | None = None
description: str
last_edited: datetime.datetime

View File

@@ -1,3 +1,4 @@
import json
import logging
import tempfile
import typing
@@ -7,6 +8,7 @@ import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import fastapi
import fastapi.responses
from fastapi.encoders import jsonable_encoder
import backend.data.block
import backend.data.graph
@@ -14,7 +16,6 @@ import backend.server.v2.store.db
import backend.server.v2.store.image_gen
import backend.server.v2.store.media
import backend.server.v2.store.model
import backend.util.json
logger = logging.getLogger(__name__)
@@ -590,18 +591,19 @@ async def generate_image(
tags=["store", "public"],
)
async def download_agent_file(
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
store_listing_version_id: str = fastapi.Path(
..., description="The ID of the agent to download"
),
version: typing.Optional[int] = fastapi.Query(
None, description="Specific version of the agent"
),
) -> fastapi.responses.FileResponse:
"""
Download the agent file by streaming its content.
Args:
store_listing_version_id (str): The ID of the agent to download
agent_id (str): The ID of the agent to download.
version (Optional[int]): Specific version of the agent to download.
Returns:
StreamingResponse: A streaming response containing the agent's graph data.
@@ -611,16 +613,35 @@ async def download_agent_file(
"""
graph_data = await backend.server.v2.store.db.get_agent(
user_id=user_id,
store_listing_version_id=store_listing_version_id,
store_listing_version_id=store_listing_version_id, version_id=version
)
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
graph_data.clean_graph()
graph_date_dict = jsonable_encoder(graph_data)
def remove_credentials(obj):
if obj and isinstance(obj, dict):
if "credentials" in obj:
del obj["credentials"]
if "creds" in obj:
del obj["creds"]
for value in obj.values():
remove_credentials(value)
elif isinstance(obj, list):
for item in obj:
remove_credentials(item)
return obj
graph_date_dict = remove_credentials(graph_date_dict)
file_name = f"agent_{store_listing_version_id}_v{version or 'latest'}.json"
# Sending graph as a stream (similar to marketplace v1)
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as tmp_file:
tmp_file.write(backend.util.json.dumps(graph_data))
tmp_file.write(json.dumps(graph_date_dict))
tmp_file.flush()
return fastapi.responses.FileResponse(

View File

@@ -40,10 +40,7 @@ def create_test_graph() -> graph.Graph:
),
graph.Node(
block_id=AgentInputBlock().id,
input_default={
"name": "input_2",
"description": "This is my description of this parameter",
},
input_default={"name": "input_2"},
),
graph.Node(
block_id=FillTextTemplateBlock().id,
@@ -77,7 +74,7 @@ def create_test_graph() -> graph.Graph:
return graph.Graph(
name="TestGraph",
description="Test graph description",
description="Test graph",
nodes=nodes,
links=links,
)

View File

@@ -4,22 +4,3 @@ class MissingConfigError(Exception):
class NeedConfirmation(Exception):
"""The user must explicitly confirm that they want to proceed"""
class InsufficientBalanceError(ValueError):
user_id: str
message: str
balance: float
amount: float
def __init__(self, message: str, user_id: str, balance: float, amount: float):
super().__init__(message)
self.args = (message, user_id, balance, amount)
self.message = message
self.user_id = user_id
self.balance = balance
self.amount = amount
def __str__(self):
"""Used to display the error message in the frontend, because we str() the error when sending the execution update"""
return self.message

View File

@@ -9,8 +9,6 @@ from .type import type_match
def to_dict(data) -> dict:
if isinstance(data, BaseModel):
data = data.model_dump()
return jsonable_encoder(data)

View File

@@ -54,11 +54,11 @@ class AppProcess(ABC):
"""
pass
def health_check(self) -> str:
def health_check(self):
"""
A method to check the health of the process.
"""
return "OK"
pass
def execute_run_command(self, silent):
signal.signal(signal.SIGTERM, self._self_terminate)
@@ -109,8 +109,6 @@ class AppProcess(ABC):
)
self.process.start()
self.health_check()
logger.info(f"[{self.service_name}] started with PID {self.process.pid}")
return self.process.pid or 0
def stop(self):
@@ -122,6 +120,4 @@ class AppProcess(ABC):
self.process.terminate()
self.process.join()
logger.info(f"[{self.service_name}] with PID {self.process.pid} stopped")
self.process = None

View File

@@ -2,7 +2,7 @@ import ipaddress
import re
import socket
from typing import Callable
from urllib.parse import urljoin, urlparse, urlunparse
from urllib.parse import urlparse, urlunparse
import idna
import requests as req
@@ -128,14 +128,7 @@ class Requests:
self.extra_headers = extra_headers
def request(
self,
method,
url,
headers=None,
allow_redirects=True,
max_redirects=10,
*args,
**kwargs,
self, method, url, headers=None, allow_redirects=False, *args, **kwargs
) -> req.Response:
# Merge any extra headers
if self.extra_headers is not None:
@@ -146,41 +139,18 @@ class Requests:
if self.extra_url_validator is not None:
url = self.extra_url_validator(url)
# Perform the request with redirects disabled for manual handling
# Perform the request
response = req.request(
method,
url,
headers=headers,
allow_redirects=False,
allow_redirects=allow_redirects,
*args,
**kwargs,
)
if self.raise_for_status:
response.raise_for_status()
# If allowed and a redirect is received, follow the redirect
if allow_redirects and response.is_redirect:
if max_redirects <= 0:
raise Exception("Too many redirects.")
location = response.headers.get("Location")
if not location:
return response
new_url = validate_url(urljoin(url, location), self.trusted_origins)
if self.extra_url_validator is not None:
new_url = self.extra_url_validator(new_url)
return self.request(
method,
new_url,
headers=headers,
allow_redirects=allow_redirects,
max_redirects=max_redirects - 1,
*args,
**kwargs,
)
return response
def get(self, url, *args, **kwargs) -> req.Response:

View File

@@ -1,6 +1,5 @@
import asyncio
import builtins
import inspect
import logging
import os
import threading
@@ -8,21 +7,18 @@ import time
import typing
from abc import ABC, abstractmethod
from enum import Enum
from functools import wraps
from types import NoneType, UnionType
from typing import (
Annotated,
Any,
Awaitable,
Callable,
Concatenate,
Coroutine,
Dict,
FrozenSet,
Iterator,
List,
Optional,
ParamSpec,
Set,
Tuple,
Type,
@@ -33,17 +29,12 @@ from typing import (
get_origin,
)
import httpx
import Pyro5.api
import uvicorn
from fastapi import FastAPI, Request, responses
from pydantic import BaseModel, TypeAdapter, create_model
from pydantic import BaseModel
from Pyro5 import api as pyro
from Pyro5 import config as pyro_config
from backend.data import db, rabbitmq, redis
from backend.util.exceptions import InsufficientBalanceError
from backend.util.json import to_dict
from backend.util.process import AppProcess
from backend.util.retry import conn_retry
from backend.util.settings import Config, Secrets
@@ -53,36 +44,12 @@ T = TypeVar("T")
C = TypeVar("C", bound=Callable)
config = Config()
api_host = config.pyro_host
api_comm_retry = config.pyro_client_comm_retry
api_comm_timeout = config.pyro_client_comm_timeout
api_call_timeout = config.rpc_client_call_timeout
pyro_config.MAX_RETRIES = api_comm_retry # type: ignore
pyro_config.COMMTIMEOUT = api_comm_timeout # type: ignore
pyro_host = config.pyro_host
pyro_config.MAX_RETRIES = config.pyro_client_comm_retry # type: ignore
pyro_config.COMMTIMEOUT = config.pyro_client_comm_timeout # type: ignore
P = ParamSpec("P")
R = TypeVar("R")
def fastapi_expose(func: C) -> C:
func = getattr(func, "__func__", func)
setattr(func, "__exposed__", True)
return func
def fastapi_exposed_run_and_wait(
f: Callable[P, Coroutine[None, None, R]]
) -> Callable[Concatenate[object, P], R]:
# TODO:
# This function lies about its return type to make the DynamicClient
# call the function synchronously, fix this when DynamicClient can choose
# to call a function synchronously or asynchronously.
return expose(f) # type: ignore
# ----- Begin Pyro Expose Block ---- #
def pyro_expose(func: C) -> C:
def expose(func: C) -> C:
"""
Decorator to mark a method or class to be exposed for remote calls.
@@ -146,36 +113,7 @@ def _make_custom_deserializer(model: Type[BaseModel]):
return custom_dict_to_class
def pyro_exposed_run_and_wait(
f: Callable[P, Coroutine[None, None, R]]
) -> Callable[Concatenate[object, P], R]:
@expose
@wraps(f)
def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
coroutine = f(*args, **kwargs)
res = self.run_and_wait(coroutine)
return res
# Register serializers for annotations on bare function
register_pydantic_serializers(f)
return wrapper
if config.use_http_based_rpc:
expose = fastapi_expose
exposed_run_and_wait = fastapi_exposed_run_and_wait
else:
expose = pyro_expose
exposed_run_and_wait = pyro_exposed_run_and_wait
# ----- End Pyro Expose Block ---- #
# --------------------------------------------------
# AppService for IPC service based on HTTP request through FastAPI
# --------------------------------------------------
class BaseAppService(AppProcess, ABC):
class AppService(AppProcess, ABC):
shared_event_loop: asyncio.AbstractEventLoop
use_db: bool = False
use_redis: bool = False
@@ -183,6 +121,9 @@ class BaseAppService(AppProcess, ABC):
rabbitmq_service: Optional[rabbitmq.AsyncRabbitMQ] = None
use_supabase: bool = False
def __init__(self):
self.uri = None
@classmethod
@abstractmethod
def get_port(cls) -> int:
@@ -190,7 +131,7 @@ class BaseAppService(AppProcess, ABC):
@classmethod
def get_host(cls) -> str:
return os.environ.get(f"{cls.service_name.upper()}_HOST", api_host)
return os.environ.get(f"{cls.service_name.upper()}_HOST", config.pyro_host)
@property
def rabbit(self) -> rabbitmq.AsyncRabbitMQ:
@@ -210,8 +151,12 @@ class BaseAppService(AppProcess, ABC):
while True:
time.sleep(10)
def __run_async(self, coro: Coroutine[Any, Any, T]):
return asyncio.run_coroutine_threadsafe(coro, self.shared_event_loop)
def run_and_wait(self, coro: Coroutine[Any, Any, T]) -> T:
return asyncio.run_coroutine_threadsafe(coro, self.shared_event_loop).result()
future = self.__run_async(coro)
return future.result()
def run(self):
self.shared_event_loop = asyncio.get_event_loop()
@@ -221,8 +166,12 @@ class BaseAppService(AppProcess, ABC):
redis.connect()
if self.rabbitmq_config:
logger.info(f"[{self.__class__.__name__}] ⏳ Configuring RabbitMQ...")
# if self.use_async:
self.rabbitmq_service = rabbitmq.AsyncRabbitMQ(self.rabbitmq_config)
self.shared_event_loop.run_until_complete(self.rabbitmq_service.connect())
# else:
# self.rabbitmq_service = rabbitmq.SyncRabbitMQ(self.rabbitmq_config)
# self.rabbitmq_service.connect()
if self.use_supabase:
from supabase import create_client
@@ -231,6 +180,19 @@ class BaseAppService(AppProcess, ABC):
secrets.supabase_url, secrets.supabase_service_role_key
)
# Initialize the async loop.
async_thread = threading.Thread(target=self.__start_async_loop)
async_thread.daemon = True
async_thread.start()
# Initialize pyro service
daemon_thread = threading.Thread(target=self.__start_pyro)
daemon_thread.daemon = True
daemon_thread.start()
# Run the main service (if it's not implemented, just sleep).
self.run_service()
def cleanup(self):
if self.use_db:
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting DB...")
@@ -241,141 +203,6 @@ class BaseAppService(AppProcess, ABC):
if self.rabbitmq_config:
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting RabbitMQ...")
class RemoteCallError(BaseModel):
type: str = "RemoteCallError"
args: Optional[Tuple[Any, ...]] = None
EXCEPTION_MAPPING = {
e.__name__: e
for e in [
ValueError,
TimeoutError,
ConnectionError,
InsufficientBalanceError,
]
}
class FastApiAppService(BaseAppService, ABC):
fastapi_app: FastAPI
@staticmethod
def _handle_internal_http_error(status_code: int = 500, log_error: bool = True):
def handler(request: Request, exc: Exception):
if log_error:
if status_code == 500:
log = logger.exception
else:
log = logger.error
log(f"{request.method} {request.url.path} failed: {exc}")
return responses.JSONResponse(
status_code=status_code,
content=RemoteCallError(
type=str(exc.__class__.__name__),
args=exc.args or (str(exc),),
).model_dump(),
)
return handler
def _create_fastapi_endpoint(self, func: Callable) -> Callable:
"""
Generates a FastAPI endpoint for the given function, handling default and optional parameters.
:param func: The original function (sync/async, bound or unbound)
:return: A FastAPI endpoint function.
"""
sig = inspect.signature(func)
fields = {}
is_bound_method = False
for name, param in sig.parameters.items():
if name in ("self", "cls"):
is_bound_method = True
continue
# Use the provided annotation or fallback to str if not specified
annotation = (
param.annotation if param.annotation != inspect.Parameter.empty else str
)
# If a default value is provided, use it; otherwise, mark the field as required with '...'
default = param.default if param.default != inspect.Parameter.empty else ...
fields[name] = (annotation, default)
# Dynamically create a Pydantic model for the request body
RequestBodyModel = create_model("RequestBodyModel", **fields)
f = func.__get__(self) if is_bound_method else func
if asyncio.iscoroutinefunction(f):
async def async_endpoint(body: RequestBodyModel): # type: ignore #RequestBodyModel being variable
return await f(
**{name: getattr(body, name) for name in body.model_fields}
)
return async_endpoint
else:
def sync_endpoint(body: RequestBodyModel): # type: ignore #RequestBodyModel being variable
return f(**{name: getattr(body, name) for name in body.model_fields})
return sync_endpoint
@conn_retry("FastAPI server", "Starting FastAPI server")
def __start_fastapi(self):
logger.info(
f"[{self.service_name}] Starting RPC server at http://{api_host}:{self.get_port()}"
)
server = uvicorn.Server(
uvicorn.Config(
self.fastapi_app,
host=api_host,
port=self.get_port(),
log_level="warning",
)
)
self.shared_event_loop.run_until_complete(server.serve())
def run(self):
super().run()
self.fastapi_app = FastAPI()
# Register the exposed API routes.
for attr_name, attr in vars(type(self)).items():
if getattr(attr, "__exposed__", False):
route_path = f"/{attr_name}"
self.fastapi_app.add_api_route(
route_path,
self._create_fastapi_endpoint(attr),
methods=["POST"],
)
self.fastapi_app.add_api_route(
"/health_check", self.health_check, methods=["POST"]
)
self.fastapi_app.add_exception_handler(
ValueError, self._handle_internal_http_error(400)
)
self.fastapi_app.add_exception_handler(
Exception, self._handle_internal_http_error(500)
)
# Start the FastAPI server in a separate thread.
api_thread = threading.Thread(target=self.__start_fastapi, daemon=True)
api_thread.start()
# Run the main service loop (blocking).
self.run_service()
# ----- Begin Pyro AppService Block ---- #
class PyroAppService(BaseAppService, ABC):
@conn_retry("Pyro", "Starting Pyro Service")
def __start_pyro(self):
maximum_connection_thread_count = max(
@@ -384,137 +211,40 @@ class PyroAppService(BaseAppService, ABC):
)
Pyro5.config.THREADPOOL_SIZE = maximum_connection_thread_count # type: ignore
daemon = Pyro5.api.Daemon(host=api_host, port=self.get_port())
daemon = Pyro5.api.Daemon(host=config.pyro_host, port=self.get_port())
self.uri = daemon.register(self, objectId=self.service_name)
logger.info(f"[{self.service_name}] Connected to Pyro; URI = {self.uri}")
daemon.requestLoop()
def run(self):
super().run()
# Initialize the async loop.
async_thread = threading.Thread(target=self.shared_event_loop.run_forever)
async_thread.daemon = True
async_thread.start()
# Initialize pyro service
daemon_thread = threading.Thread(target=self.__start_pyro)
daemon_thread.daemon = True
daemon_thread.start()
# Run the main service loop (blocking).
self.run_service()
def __start_async_loop(self):
self.shared_event_loop.run_forever()
if config.use_http_based_rpc:
class AppService(FastApiAppService, ABC): # type: ignore #AppService defined twice
pass
else:
class AppService(PyroAppService, ABC):
pass
# --------- UTILITIES --------- #
# ----- End Pyro AppService Block ---- #
# --------------------------------------------------
# HTTP Client utilities for dynamic service client abstraction
# --------------------------------------------------
AS = TypeVar("AS", bound=AppService)
def fastapi_close_service_client(client: Any) -> None:
if hasattr(client, "close"):
client.close()
else:
logger.warning(f"Client {client} is not closable")
@conn_retry("FastAPI client", "Creating service client", max_retry=api_comm_retry)
def fastapi_get_service_client(
service_type: Type[AS],
call_timeout: int | None = api_call_timeout,
) -> AS:
class DynamicClient:
def __init__(self):
host = service_type.get_host()
port = service_type.get_port()
self.base_url = f"http://{host}:{port}".rstrip("/")
self.client = httpx.Client(
base_url=self.base_url,
timeout=call_timeout,
)
def _call_method(self, method_name: str, **kwargs) -> Any:
try:
response = self.client.post(method_name, json=to_dict(kwargs))
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error in {method_name}: {e.response.text}")
error = RemoteCallError.model_validate(e.response.json())
# DEBUG HELP: if you made a custom exception, make sure you override self.args to be how to make your exception
raise EXCEPTION_MAPPING.get(error.type, Exception)(
*(error.args or [str(e)])
)
def close(self):
self.client.close()
def __getattr__(self, name: str) -> Callable[..., Any]:
# Try to get the original function from the service type.
orig_func = getattr(service_type, name, None)
if orig_func is None:
raise AttributeError(f"Method {name} not found in {service_type}")
sig = inspect.signature(orig_func)
ret_ann = sig.return_annotation
if ret_ann != inspect.Signature.empty:
expected_return = TypeAdapter(ret_ann)
else:
expected_return = None
def method(*args, **kwargs) -> Any:
if args:
arg_names = list(sig.parameters.keys())
if arg_names[0] in ("self", "cls"):
arg_names = arg_names[1:]
kwargs.update(dict(zip(arg_names, args)))
result = self._call_method(name, **kwargs)
if expected_return:
return expected_return.validate_python(result)
return result
return method
client = cast(AS, DynamicClient())
client.health_check()
return cast(AS, client)
# ----- Begin Pyro Client Block ---- #
class PyroClient:
proxy: Pyro5.api.Proxy
def pyro_close_service_client(client: BaseAppService) -> None:
def close_service_client(client: AppService) -> None:
if isinstance(client, PyroClient):
client.proxy._pyroRelease()
else:
raise RuntimeError(f"Client {client.__class__} is not a Pyro client.")
def pyro_get_service_client(service_type: Type[AS]) -> AS:
def get_service_client(service_type: Type[AS]) -> AS:
service_name = service_type.service_name
class DynamicClient(PyroClient):
@conn_retry("Pyro", f"Connecting to [{service_name}]")
def __init__(self):
uri = f"PYRO:{service_type.service_name}@{service_type.get_host()}:{service_type.get_port()}"
host = os.environ.get(f"{service_name.upper()}_HOST", pyro_host)
uri = f"PYRO:{service_type.service_name}@{host}:{service_type.get_port()}"
logger.debug(f"Connecting to service [{service_name}]. URI = {uri}")
self.proxy = Pyro5.api.Proxy(uri)
# Attempt to bind to ensure the connection is established
@@ -574,13 +304,3 @@ def _pydantic_models_from_type_annotation(annotation) -> Iterator[type[BaseModel
yield annotype
elif annotype not in builtin_types and not issubclass(annotype, Enum):
raise TypeError(f"Unsupported type encountered: {annotype}")
if config.use_http_based_rpc:
close_service_client = fastapi_close_service_client
get_service_client = fastapi_get_service_client
else:
close_service_client = pyro_close_service_client
get_service_client = pyro_get_service_client
# ----- End Pyro Client Block ---- #

View File

@@ -65,10 +65,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
le=1000,
description="Maximum number of workers to use for node execution within a single graph.",
)
use_http_based_rpc: bool = Field(
default=True,
description="Whether to use HTTP-based RPC for communication between services.",
)
pyro_host: str = Field(
default="localhost",
description="The default hostname of the Pyro server.",
@@ -81,10 +77,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default=3,
description="The default number of retries for Pyro client connections.",
)
rpc_client_call_timeout: int = Field(
default=300,
description="The default timeout in seconds, for RPC client calls.",
)
enable_auth: bool = Field(
default=True,
description="If authentication is enabled or not",
@@ -164,11 +156,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="The port for notification service daemon to run on",
)
otto_api_url: str = Field(
default="",
description="The URL for the Otto API service",
)
platform_base_url: str = Field(
default="",
description="Must be set so the application knows where it's hosted at. "
@@ -215,11 +202,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="The email address to use for sending emails",
)
use_agent_image_generation_v2: bool = Field(
default=True,
description="Whether to use the new agent image generation service",
)
@field_validator("platform_base_url", "frontend_base_url")
@classmethod
def validate_platform_base_url(cls, v: str, info: ValidationInfo) -> str:
@@ -320,16 +302,6 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
default="", description="Postmark server API token used for sending emails"
)
postmark_webhook_token: str = Field(
default="",
description="The token to use for the Postmark webhook",
)
unsubscribe_secret_key: str = Field(
default="",
description="The secret key to use for the unsubscribe user by token",
)
# OAuth server credentials for integrations
# --8<-- [start:OAuthServerCredentialsExample]
github_client_id: str = Field(default="", description="GitHub OAuth client ID")
@@ -404,7 +376,6 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
smartlead_api_key: str = Field(default="", description="SmartLead API Key")
zerobounce_api_key: str = Field(default="", description="ZeroBounce API Key")
example_api_key: str = Field(default="", description="Example API Key")
# Add more secret fields as needed
model_config = SettingsConfigDict(

View File

@@ -8,7 +8,7 @@ from backend.data.block import Block, BlockSchema, initialize_blocks
from backend.data.execution import ExecutionResult, ExecutionStatus
from backend.data.model import _BaseCredentials
from backend.data.user import create_default_user
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
from backend.executor import DatabaseManager, ExecutionManager, ExecutionScheduler
from backend.notifications.notifications import NotificationManager
from backend.server.rest_api import AgentServer
from backend.server.utils import get_user_id
@@ -21,7 +21,7 @@ class SpinTestServer:
self.db_api = DatabaseManager()
self.exec_manager = ExecutionManager()
self.agent_server = AgentServer()
self.scheduler = Scheduler()
self.scheduler = ExecutionScheduler()
self.notif_manager = NotificationManager()
@staticmethod

View File

@@ -1,7 +1,6 @@
import logging
import bleach
from bleach.css_sanitizer import CSSSanitizer
from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment
from markupsafe import Markup
@@ -9,95 +8,15 @@ from markupsafe import Markup
logger = logging.getLogger(__name__)
def format_filter_for_jinja2(value, format_string=None):
if format_string:
return format_string % float(value)
return value
class TextFormatter:
def __init__(self):
self.env = SandboxedEnvironment(loader=BaseLoader(), autoescape=True)
self.env.filters.clear()
self.env.tests.clear()
self.env.globals.clear()
# Instead of clearing all filters, just remove potentially unsafe ones
unsafe_filters = ["pprint", "tojson", "urlize", "xmlattr"]
for f in unsafe_filters:
if f in self.env.filters:
del self.env.filters[f]
self.env.filters["format"] = format_filter_for_jinja2
# Define allowed CSS properties (sorted alphabetically, if you add more)
allowed_css_properties = [
"background-color",
"border",
"border-bottom",
"border-color",
"border-left",
"border-radius",
"border-right",
"border-style",
"border-top",
"border-width",
"bottom",
"box-shadow",
"clear",
"color",
"display",
"float",
"font-family",
"font-size",
"font-weight",
"height",
"left",
"letter-spacing",
"line-height",
"margin-bottom",
"margin-left",
"margin-right",
"margin-top",
"padding",
"position",
"right",
"text-align",
"text-decoration",
"text-shadow",
"text-transform",
"top",
"width",
]
self.css_sanitizer = CSSSanitizer(allowed_css_properties=allowed_css_properties)
# Define allowed tags (sorted alphabetically, if you add more)
self.allowed_tags = [
"a",
"b",
"br",
"div",
"em",
"h1",
"h2",
"h3",
"h4",
"h5",
"i",
"img",
"li",
"p",
"span",
"strong",
"u",
"ul",
]
# Define allowed attributes to be used on specific tags
self.allowed_attributes = {
"*": ["class", "style"],
"a": ["href"],
"img": ["src"],
}
self.allowed_tags = ["p", "b", "i", "u", "ul", "li", "br", "strong", "em"]
self.allowed_attributes = {"*": ["style", "class"]}
def format_string(self, template_str: str, values=None, **kwargs) -> str:
"""Regular template rendering with escaping"""
@@ -118,19 +37,17 @@ class TextFormatter:
# First render the content template
content = self.format_string(content_template, data, **kwargs)
# Clean the HTML + CSS but don't escape it
# Clean the HTML but don't escape it
clean_content = bleach.clean(
content,
tags=self.allowed_tags,
attributes=self.allowed_attributes,
css_sanitizer=self.css_sanitizer,
strip=True,
)
# Mark the cleaned HTML as safe using Markup
safe_content = Markup(clean_content)
# Render subject
rendered_subject_template = self.format_string(subject_template, data, **kwargs)
# Create new env just for HTML template

View File

@@ -1,8 +0,0 @@
-- Add imageUrl column
ALTER TABLE "LibraryAgent"
ADD COLUMN "creatorId" TEXT,
ADD COLUMN "imageUrl" TEXT;
-- Add foreign key constraint for creatorId -> Profile
ALTER TABLE "LibraryAgent"
ADD CONSTRAINT "LibraryAgent_creatorId_fkey" FOREIGN KEY ("creatorId") REFERENCES "Profile"("id") ON DELETE SET NULL ON UPDATE CASCADE;

View File

@@ -1,10 +0,0 @@
-- First, add the column as nullable to avoid issues with existing rows
ALTER TABLE "User" ADD COLUMN "emailVerified" BOOLEAN;
-- Set default values for existing rows
UPDATE "User" SET "emailVerified" = true;
-- Now make it NOT NULL and set the default
ALTER TABLE "User" ALTER COLUMN "emailVerified" SET NOT NULL;
ALTER TABLE "User" ALTER COLUMN "emailVerified" SET DEFAULT true;

View File

@@ -1,26 +0,0 @@
-- Create UserOnboarding table
CREATE TABLE "UserOnboarding" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"step" INTEGER NOT NULL DEFAULT 0,
"usageReason" TEXT,
"integrations" TEXT[] DEFAULT ARRAY[]::TEXT[],
"otherIntegrations" TEXT,
"selectedAgentCreator" TEXT,
"selectedAgentSlug" TEXT,
"agentInput" JSONB,
"isCompleted" BOOLEAN NOT NULL DEFAULT false,
"userId" TEXT NOT NULL,
CONSTRAINT "UserOnboarding_pkey" PRIMARY KEY ("id")
);
-- Create unique constraint on userId
ALTER TABLE "UserOnboarding" ADD CONSTRAINT "UserOnboarding_userId_key" UNIQUE ("userId");
-- Create index on userId
CREATE INDEX "UserOnboarding_userId_idx" ON "UserOnboarding"("userId");
-- Add foreign key constraint
ALTER TABLE "UserOnboarding" ADD CONSTRAINT "UserOnboarding_userId_fkey"
FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -1,6 +0,0 @@
-- Add isDeleted column to AgentGraphExecution
ALTER TABLE "AgentGraphExecution"
ADD COLUMN "isDeleted"
BOOLEAN
NOT NULL
DEFAULT false;

View File

@@ -1,11 +0,0 @@
-- DropIndex
DROP INDEX "APIKey_userId_idx";
-- DropIndex
DROP INDEX "StoreListing_agentId_owningUserId_idx";
-- DropIndex
DROP INDEX "StoreListing_isDeleted_idx";
-- DropIndex
DROP INDEX "StoreListingVersion_agentId_agentVersion_isDeleted_idx";

View File

@@ -1,8 +0,0 @@
/*
Warnings:
- You are about to drop the column `isTemplate` on the `AgentGraph` table. All the data in the column will be lost.
*/
-- AlterTable
ALTER TABLE "AgentGraph" DROP COLUMN "isTemplate";

View File

@@ -395,7 +395,6 @@ files = [
]
[package.dependencies]
tinycss2 = {version = ">=1.1.0,<1.5", optional = true, markers = "extra == \"css\""}
webencodings = "*"
[package.extras]
@@ -4412,21 +4411,21 @@ all = ["numpy"]
[[package]]
name = "realtime"
version = "2.3.0"
version = "2.2.0"
description = ""
optional = false
python-versions = "<4.0,>=3.9"
groups = ["main"]
files = [
{file = "realtime-2.3.0-py3-none-any.whl", hash = "sha256:6c241681d0517a3bc5e0132842bffd8b592286131b01a68b41cf7e0be94828fc"},
{file = "realtime-2.3.0.tar.gz", hash = "sha256:4071b095d7f750fcd68ec322e05045fce067b5cd5309a7ca809fcc87e50f56a1"},
{file = "realtime-2.2.0-py3-none-any.whl", hash = "sha256:26dbaa58d143345318344bd7a7d4dc67154d6e0e9c98524327053a78bb3cc6b6"},
{file = "realtime-2.2.0.tar.gz", hash = "sha256:f87a51b6b8dd8c72c30af6c841e0161132dcb32bf8b96178f3fe3866d575ef33"},
]
[package.dependencies]
aiohttp = ">=3.11.11,<4.0.0"
python-dateutil = ">=2.8.1,<3.0.0"
typing-extensions = ">=4.12.2,<5.0.0"
websockets = ">=11,<15"
websockets = ">=11,<14"
[[package]]
name = "redis"
@@ -5074,25 +5073,6 @@ files = [
doc = ["reno", "sphinx"]
test = ["pytest", "tornado (>=4.5)", "typeguard"]
[[package]]
name = "tinycss2"
version = "1.4.0"
description = "A tiny CSS parser"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "tinycss2-1.4.0-py3-none-any.whl", hash = "sha256:3a49cf47b7675da0b15d0c6e1df8df4ebd96e9394bb905a5775adb0d884c5289"},
{file = "tinycss2-1.4.0.tar.gz", hash = "sha256:10c0972f6fc0fbee87c3edb76549357415e94548c1ae10ebccdea16fb404a9b7"},
]
[package.dependencies]
webencodings = ">=0.4"
[package.extras]
doc = ["sphinx", "sphinx_rtd_theme"]
test = ["pytest", "ruff"]
[[package]]
name = "todoist-api-python"
version = "2.1.7"
@@ -5573,81 +5553,98 @@ test = ["websockets"]
[[package]]
name = "websockets"
version = "14.2"
version = "13.1"
description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)"
optional = false
python-versions = ">=3.9"
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "websockets-14.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e8179f95323b9ab1c11723e5d91a89403903f7b001828161b480a7810b334885"},
{file = "websockets-14.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0d8c3e2cdb38f31d8bd7d9d28908005f6fa9def3324edb9bf336d7e4266fd397"},
{file = "websockets-14.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:714a9b682deb4339d39ffa674f7b674230227d981a37d5d174a4a83e3978a610"},
{file = "websockets-14.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f2e53c72052f2596fb792a7acd9704cbc549bf70fcde8a99e899311455974ca3"},
{file = "websockets-14.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e3fbd68850c837e57373d95c8fe352203a512b6e49eaae4c2f4088ef8cf21980"},
{file = "websockets-14.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b27ece32f63150c268593d5fdb82819584831a83a3f5809b7521df0685cd5d8"},
{file = "websockets-14.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4daa0faea5424d8713142b33825fff03c736f781690d90652d2c8b053345b0e7"},
{file = "websockets-14.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:bc63cee8596a6ec84d9753fd0fcfa0452ee12f317afe4beae6b157f0070c6c7f"},
{file = "websockets-14.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7a570862c325af2111343cc9b0257b7119b904823c675b22d4ac547163088d0d"},
{file = "websockets-14.2-cp310-cp310-win32.whl", hash = "sha256:75862126b3d2d505e895893e3deac0a9339ce750bd27b4ba515f008b5acf832d"},
{file = "websockets-14.2-cp310-cp310-win_amd64.whl", hash = "sha256:cc45afb9c9b2dc0852d5c8b5321759cf825f82a31bfaf506b65bf4668c96f8b2"},
{file = "websockets-14.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:3bdc8c692c866ce5fefcaf07d2b55c91d6922ac397e031ef9b774e5b9ea42166"},
{file = "websockets-14.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c93215fac5dadc63e51bcc6dceca72e72267c11def401d6668622b47675b097f"},
{file = "websockets-14.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1c9b6535c0e2cf8a6bf938064fb754aaceb1e6a4a51a80d884cd5db569886910"},
{file = "websockets-14.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a52a6d7cf6938e04e9dceb949d35fbdf58ac14deea26e685ab6368e73744e4c"},
{file = "websockets-14.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9f05702e93203a6ff5226e21d9b40c037761b2cfb637187c9802c10f58e40473"},
{file = "websockets-14.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22441c81a6748a53bfcb98951d58d1af0661ab47a536af08920d129b4d1c3473"},
{file = "websockets-14.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:efd9b868d78b194790e6236d9cbc46d68aba4b75b22497eb4ab64fa640c3af56"},
{file = "websockets-14.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1a5a20d5843886d34ff8c57424cc65a1deda4375729cbca4cb6b3353f3ce4142"},
{file = "websockets-14.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:34277a29f5303d54ec6468fb525d99c99938607bc96b8d72d675dee2b9f5bf1d"},
{file = "websockets-14.2-cp311-cp311-win32.whl", hash = "sha256:02687db35dbc7d25fd541a602b5f8e451a238ffa033030b172ff86a93cb5dc2a"},
{file = "websockets-14.2-cp311-cp311-win_amd64.whl", hash = "sha256:862e9967b46c07d4dcd2532e9e8e3c2825e004ffbf91a5ef9dde519ee2effb0b"},
{file = "websockets-14.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:1f20522e624d7ffbdbe259c6b6a65d73c895045f76a93719aa10cd93b3de100c"},
{file = "websockets-14.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:647b573f7d3ada919fd60e64d533409a79dcf1ea21daeb4542d1d996519ca967"},
{file = "websockets-14.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6af99a38e49f66be5a64b1e890208ad026cda49355661549c507152113049990"},
{file = "websockets-14.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:091ab63dfc8cea748cc22c1db2814eadb77ccbf82829bac6b2fbe3401d548eda"},
{file = "websockets-14.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b374e8953ad477d17e4851cdc66d83fdc2db88d9e73abf755c94510ebddceb95"},
{file = "websockets-14.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a39d7eceeea35db85b85e1169011bb4321c32e673920ae9c1b6e0978590012a3"},
{file = "websockets-14.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0a6f3efd47ffd0d12080594f434faf1cd2549b31e54870b8470b28cc1d3817d9"},
{file = "websockets-14.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:065ce275e7c4ffb42cb738dd6b20726ac26ac9ad0a2a48e33ca632351a737267"},
{file = "websockets-14.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e9d0e53530ba7b8b5e389c02282f9d2aa47581514bd6049d3a7cffe1385cf5fe"},
{file = "websockets-14.2-cp312-cp312-win32.whl", hash = "sha256:20e6dd0984d7ca3037afcb4494e48c74ffb51e8013cac71cf607fffe11df7205"},
{file = "websockets-14.2-cp312-cp312-win_amd64.whl", hash = "sha256:44bba1a956c2c9d268bdcdf234d5e5ff4c9b6dc3e300545cbe99af59dda9dcce"},
{file = "websockets-14.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:6f1372e511c7409a542291bce92d6c83320e02c9cf392223272287ce55bc224e"},
{file = "websockets-14.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4da98b72009836179bb596a92297b1a61bb5a830c0e483a7d0766d45070a08ad"},
{file = "websockets-14.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f8a86a269759026d2bde227652b87be79f8a734e582debf64c9d302faa1e9f03"},
{file = "websockets-14.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:86cf1aaeca909bf6815ea714d5c5736c8d6dd3a13770e885aafe062ecbd04f1f"},
{file = "websockets-14.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a9b0f6c3ba3b1240f602ebb3971d45b02cc12bd1845466dd783496b3b05783a5"},
{file = "websockets-14.2-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:669c3e101c246aa85bc8534e495952e2ca208bd87994650b90a23d745902db9a"},
{file = "websockets-14.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:eabdb28b972f3729348e632ab08f2a7b616c7e53d5414c12108c29972e655b20"},
{file = "websockets-14.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2066dc4cbcc19f32c12a5a0e8cc1b7ac734e5b64ac0a325ff8353451c4b15ef2"},
{file = "websockets-14.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ab95d357cd471df61873dadf66dd05dd4709cae001dd6342edafc8dc6382f307"},
{file = "websockets-14.2-cp313-cp313-win32.whl", hash = "sha256:a9e72fb63e5f3feacdcf5b4ff53199ec8c18d66e325c34ee4c551ca748623bbc"},
{file = "websockets-14.2-cp313-cp313-win_amd64.whl", hash = "sha256:b439ea828c4ba99bb3176dc8d9b933392a2413c0f6b149fdcba48393f573377f"},
{file = "websockets-14.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7cd5706caec1686c5d233bc76243ff64b1c0dc445339bd538f30547e787c11fe"},
{file = "websockets-14.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ec607328ce95a2f12b595f7ae4c5d71bf502212bddcea528290b35c286932b12"},
{file = "websockets-14.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:da85651270c6bfb630136423037dd4975199e5d4114cae6d3066641adcc9d1c7"},
{file = "websockets-14.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3ecadc7ce90accf39903815697917643f5b7cfb73c96702318a096c00aa71f5"},
{file = "websockets-14.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1979bee04af6a78608024bad6dfcc0cc930ce819f9e10342a29a05b5320355d0"},
{file = "websockets-14.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2dddacad58e2614a24938a50b85969d56f88e620e3f897b7d80ac0d8a5800258"},
{file = "websockets-14.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:89a71173caaf75fa71a09a5f614f450ba3ec84ad9fca47cb2422a860676716f0"},
{file = "websockets-14.2-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:6af6a4b26eea4fc06c6818a6b962a952441e0e39548b44773502761ded8cc1d4"},
{file = "websockets-14.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:80c8efa38957f20bba0117b48737993643204645e9ec45512579132508477cfc"},
{file = "websockets-14.2-cp39-cp39-win32.whl", hash = "sha256:2e20c5f517e2163d76e2729104abc42639c41cf91f7b1839295be43302713661"},
{file = "websockets-14.2-cp39-cp39-win_amd64.whl", hash = "sha256:b4c8cef610e8d7c70dea92e62b6814a8cd24fbd01d7103cc89308d2bfe1659ef"},
{file = "websockets-14.2-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:d7d9cafbccba46e768be8a8ad4635fa3eae1ffac4c6e7cb4eb276ba41297ed29"},
{file = "websockets-14.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:c76193c1c044bd1e9b3316dcc34b174bbf9664598791e6fb606d8d29000e070c"},
{file = "websockets-14.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd475a974d5352390baf865309fe37dec6831aafc3014ffac1eea99e84e83fc2"},
{file = "websockets-14.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c6c0097a41968b2e2b54ed3424739aab0b762ca92af2379f152c1aef0187e1c"},
{file = "websockets-14.2-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d7ff794c8b36bc402f2e07c0b2ceb4a2424147ed4785ff03e2a7af03711d60a"},
{file = "websockets-14.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:dec254fcabc7bd488dab64846f588fc5b6fe0d78f641180030f8ea27b76d72c3"},
{file = "websockets-14.2-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:bbe03eb853e17fd5b15448328b4ec7fb2407d45fb0245036d06a3af251f8e48f"},
{file = "websockets-14.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:a3c4aa3428b904d5404a0ed85f3644d37e2cb25996b7f096d77caeb0e96a3b42"},
{file = "websockets-14.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:577a4cebf1ceaf0b65ffc42c54856214165fb8ceeba3935852fc33f6b0c55e7f"},
{file = "websockets-14.2-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ad1c1d02357b7665e700eca43a31d52814ad9ad9b89b58118bdabc365454b574"},
{file = "websockets-14.2-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f390024a47d904613577df83ba700bd189eedc09c57af0a904e5c39624621270"},
{file = "websockets-14.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:3c1426c021c38cf92b453cdf371228d3430acd775edee6bac5a4d577efc72365"},
{file = "websockets-14.2-py3-none-any.whl", hash = "sha256:7a6ceec4ea84469f15cf15807a747e9efe57e369c384fa86e022b3bea679b79b"},
{file = "websockets-14.2.tar.gz", hash = "sha256:5059ed9c54945efb321f097084b4c7e52c246f2c869815876a69d1efc4ad6eb5"},
{file = "websockets-13.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f48c749857f8fb598fb890a75f540e3221d0976ed0bf879cf3c7eef34151acee"},
{file = "websockets-13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c7e72ce6bda6fb9409cc1e8164dd41d7c91466fb599eb047cfda72fe758a34a7"},
{file = "websockets-13.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f779498eeec470295a2b1a5d97aa1bc9814ecd25e1eb637bd9d1c73a327387f6"},
{file = "websockets-13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4676df3fe46956fbb0437d8800cd5f2b6d41143b6e7e842e60554398432cf29b"},
{file = "websockets-13.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a7affedeb43a70351bb811dadf49493c9cfd1ed94c9c70095fd177e9cc1541fa"},
{file = "websockets-13.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1971e62d2caa443e57588e1d82d15f663b29ff9dfe7446d9964a4b6f12c1e700"},
{file = "websockets-13.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5f2e75431f8dc4a47f31565a6e1355fb4f2ecaa99d6b89737527ea917066e26c"},
{file = "websockets-13.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:58cf7e75dbf7e566088b07e36ea2e3e2bd5676e22216e4cad108d4df4a7402a0"},
{file = "websockets-13.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c90d6dec6be2c7d03378a574de87af9b1efea77d0c52a8301dd831ece938452f"},
{file = "websockets-13.1-cp310-cp310-win32.whl", hash = "sha256:730f42125ccb14602f455155084f978bd9e8e57e89b569b4d7f0f0c17a448ffe"},
{file = "websockets-13.1-cp310-cp310-win_amd64.whl", hash = "sha256:5993260f483d05a9737073be197371940c01b257cc45ae3f1d5d7adb371b266a"},
{file = "websockets-13.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:61fc0dfcda609cda0fc9fe7977694c0c59cf9d749fbb17f4e9483929e3c48a19"},
{file = "websockets-13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ceec59f59d092c5007e815def4ebb80c2de330e9588e101cf8bd94c143ec78a5"},
{file = "websockets-13.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c1dca61c6db1166c48b95198c0b7d9c990b30c756fc2923cc66f68d17dc558fd"},
{file = "websockets-13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:308e20f22c2c77f3f39caca508e765f8725020b84aa963474e18c59accbf4c02"},
{file = "websockets-13.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62d516c325e6540e8a57b94abefc3459d7dab8ce52ac75c96cad5549e187e3a7"},
{file = "websockets-13.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87c6e35319b46b99e168eb98472d6c7d8634ee37750d7693656dc766395df096"},
{file = "websockets-13.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:5f9fee94ebafbc3117c30be1844ed01a3b177bb6e39088bc6b2fa1dc15572084"},
{file = "websockets-13.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:7c1e90228c2f5cdde263253fa5db63e6653f1c00e7ec64108065a0b9713fa1b3"},
{file = "websockets-13.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:6548f29b0e401eea2b967b2fdc1c7c7b5ebb3eeb470ed23a54cd45ef078a0db9"},
{file = "websockets-13.1-cp311-cp311-win32.whl", hash = "sha256:c11d4d16e133f6df8916cc5b7e3e96ee4c44c936717d684a94f48f82edb7c92f"},
{file = "websockets-13.1-cp311-cp311-win_amd64.whl", hash = "sha256:d04f13a1d75cb2b8382bdc16ae6fa58c97337253826dfe136195b7f89f661557"},
{file = "websockets-13.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9d75baf00138f80b48f1eac72ad1535aac0b6461265a0bcad391fc5aba875cfc"},
{file = "websockets-13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9b6f347deb3dcfbfde1c20baa21c2ac0751afaa73e64e5b693bb2b848efeaa49"},
{file = "websockets-13.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:de58647e3f9c42f13f90ac7e5f58900c80a39019848c5547bc691693098ae1bd"},
{file = "websockets-13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1b54689e38d1279a51d11e3467dd2f3a50f5f2e879012ce8f2d6943f00e83f0"},
{file = "websockets-13.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf1781ef73c073e6b0f90af841aaf98501f975d306bbf6221683dd594ccc52b6"},
{file = "websockets-13.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d23b88b9388ed85c6faf0e74d8dec4f4d3baf3ecf20a65a47b836d56260d4b9"},
{file = "websockets-13.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3c78383585f47ccb0fcf186dcb8a43f5438bd7d8f47d69e0b56f71bf431a0a68"},
{file = "websockets-13.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:d6d300f8ec35c24025ceb9b9019ae9040c1ab2f01cddc2bcc0b518af31c75c14"},
{file = "websockets-13.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a9dcaf8b0cc72a392760bb8755922c03e17a5a54e08cca58e8b74f6902b433cf"},
{file = "websockets-13.1-cp312-cp312-win32.whl", hash = "sha256:2f85cf4f2a1ba8f602298a853cec8526c2ca42a9a4b947ec236eaedb8f2dc80c"},
{file = "websockets-13.1-cp312-cp312-win_amd64.whl", hash = "sha256:38377f8b0cdeee97c552d20cf1865695fcd56aba155ad1b4ca8779a5b6ef4ac3"},
{file = "websockets-13.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a9ab1e71d3d2e54a0aa646ab6d4eebfaa5f416fe78dfe4da2839525dc5d765c6"},
{file = "websockets-13.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b9d7439d7fab4dce00570bb906875734df13d9faa4b48e261c440a5fec6d9708"},
{file = "websockets-13.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:327b74e915cf13c5931334c61e1a41040e365d380f812513a255aa804b183418"},
{file = "websockets-13.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:325b1ccdbf5e5725fdcb1b0e9ad4d2545056479d0eee392c291c1bf76206435a"},
{file = "websockets-13.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:346bee67a65f189e0e33f520f253d5147ab76ae42493804319b5716e46dddf0f"},
{file = "websockets-13.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91a0fa841646320ec0d3accdff5b757b06e2e5c86ba32af2e0815c96c7a603c5"},
{file = "websockets-13.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:18503d2c5f3943e93819238bf20df71982d193f73dcecd26c94514f417f6b135"},
{file = "websockets-13.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:a9cd1af7e18e5221d2878378fbc287a14cd527fdd5939ed56a18df8a31136bb2"},
{file = "websockets-13.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:70c5be9f416aa72aab7a2a76c90ae0a4fe2755c1816c153c1a2bcc3333ce4ce6"},
{file = "websockets-13.1-cp313-cp313-win32.whl", hash = "sha256:624459daabeb310d3815b276c1adef475b3e6804abaf2d9d2c061c319f7f187d"},
{file = "websockets-13.1-cp313-cp313-win_amd64.whl", hash = "sha256:c518e84bb59c2baae725accd355c8dc517b4a3ed8db88b4bc93c78dae2974bf2"},
{file = "websockets-13.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c7934fd0e920e70468e676fe7f1b7261c1efa0d6c037c6722278ca0228ad9d0d"},
{file = "websockets-13.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:149e622dc48c10ccc3d2760e5f36753db9cacf3ad7bc7bbbfd7d9c819e286f23"},
{file = "websockets-13.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a569eb1b05d72f9bce2ebd28a1ce2054311b66677fcd46cf36204ad23acead8c"},
{file = "websockets-13.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:95df24ca1e1bd93bbca51d94dd049a984609687cb2fb08a7f2c56ac84e9816ea"},
{file = "websockets-13.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8dbb1bf0c0a4ae8b40bdc9be7f644e2f3fb4e8a9aca7145bfa510d4a374eeb7"},
{file = "websockets-13.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:035233b7531fb92a76beefcbf479504db8c72eb3bff41da55aecce3a0f729e54"},
{file = "websockets-13.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:e4450fc83a3df53dec45922b576e91e94f5578d06436871dce3a6be38e40f5db"},
{file = "websockets-13.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:463e1c6ec853202dd3657f156123d6b4dad0c546ea2e2e38be2b3f7c5b8e7295"},
{file = "websockets-13.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:6d6855bbe70119872c05107e38fbc7f96b1d8cb047d95c2c50869a46c65a8e96"},
{file = "websockets-13.1-cp38-cp38-win32.whl", hash = "sha256:204e5107f43095012b00f1451374693267adbb832d29966a01ecc4ce1db26faf"},
{file = "websockets-13.1-cp38-cp38-win_amd64.whl", hash = "sha256:485307243237328c022bc908b90e4457d0daa8b5cf4b3723fd3c4a8012fce4c6"},
{file = "websockets-13.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:9b37c184f8b976f0c0a231a5f3d6efe10807d41ccbe4488df8c74174805eea7d"},
{file = "websockets-13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:163e7277e1a0bd9fb3c8842a71661ad19c6aa7bb3d6678dc7f89b17fbcc4aeb7"},
{file = "websockets-13.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4b889dbd1342820cc210ba44307cf75ae5f2f96226c0038094455a96e64fb07a"},
{file = "websockets-13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:586a356928692c1fed0eca68b4d1c2cbbd1ca2acf2ac7e7ebd3b9052582deefa"},
{file = "websockets-13.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7bd6abf1e070a6b72bfeb71049d6ad286852e285f146682bf30d0296f5fbadfa"},
{file = "websockets-13.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2aad13a200e5934f5a6767492fb07151e1de1d6079c003ab31e1823733ae79"},
{file = "websockets-13.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:df01aea34b6e9e33572c35cd16bae5a47785e7d5c8cb2b54b2acdb9678315a17"},
{file = "websockets-13.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:e54affdeb21026329fb0744ad187cf812f7d3c2aa702a5edb562b325191fcab6"},
{file = "websockets-13.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:9ef8aa8bdbac47f4968a5d66462a2a0935d044bf35c0e5a8af152d58516dbeb5"},
{file = "websockets-13.1-cp39-cp39-win32.whl", hash = "sha256:deeb929efe52bed518f6eb2ddc00cc496366a14c726005726ad62c2dd9017a3c"},
{file = "websockets-13.1-cp39-cp39-win_amd64.whl", hash = "sha256:7c65ffa900e7cc958cd088b9a9157a8141c991f8c53d11087e6fb7277a03f81d"},
{file = "websockets-13.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:5dd6da9bec02735931fccec99d97c29f47cc61f644264eb995ad6c0c27667238"},
{file = "websockets-13.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:2510c09d8e8df777177ee3d40cd35450dc169a81e747455cc4197e63f7e7bfe5"},
{file = "websockets-13.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1c3cf67185543730888b20682fb186fc8d0fa6f07ccc3ef4390831ab4b388d9"},
{file = "websockets-13.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bcc03c8b72267e97b49149e4863d57c2d77f13fae12066622dc78fe322490fe6"},
{file = "websockets-13.1-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:004280a140f220c812e65f36944a9ca92d766b6cc4560be652a0a3883a79ed8a"},
{file = "websockets-13.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:e2620453c075abeb0daa949a292e19f56de518988e079c36478bacf9546ced23"},
{file = "websockets-13.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:9156c45750b37337f7b0b00e6248991a047be4aa44554c9886fe6bdd605aab3b"},
{file = "websockets-13.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:80c421e07973a89fbdd93e6f2003c17d20b69010458d3a8e37fb47874bd67d51"},
{file = "websockets-13.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82d0ba76371769d6a4e56f7e83bb8e81846d17a6190971e38b5de108bde9b0d7"},
{file = "websockets-13.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e9875a0143f07d74dc5e1ded1c4581f0d9f7ab86c78994e2ed9e95050073c94d"},
{file = "websockets-13.1-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a11e38ad8922c7961447f35c7b17bffa15de4d17c70abd07bfbe12d6faa3e027"},
{file = "websockets-13.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:4059f790b6ae8768471cddb65d3c4fe4792b0ab48e154c9f0a04cefaabcd5978"},
{file = "websockets-13.1-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:25c35bf84bf7c7369d247f0b8cfa157f989862c49104c5cf85cb5436a641d93e"},
{file = "websockets-13.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:83f91d8a9bb404b8c2c41a707ac7f7f75b9442a0a876df295de27251a856ad09"},
{file = "websockets-13.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7a43cfdcddd07f4ca2b1afb459824dd3c6d53a51410636a2c7fc97b9a8cf4842"},
{file = "websockets-13.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48a2ef1381632a2f0cb4efeff34efa97901c9fbc118e01951ad7cfc10601a9bb"},
{file = "websockets-13.1-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:459bf774c754c35dbb487360b12c5727adab887f1622b8aed5755880a21c4a20"},
{file = "websockets-13.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:95858ca14a9f6fa8413d29e0a585b31b278388aa775b8a81fa24830123874678"},
{file = "websockets-13.1-py3-none-any.whl", hash = "sha256:a9a396a6ad26130cdae92ae10c36af09d9bfe6cafe69670fd3b6da9b07b4044f"},
{file = "websockets-13.1.tar.gz", hash = "sha256:a3b3366087c1bc0a2795111edcadddb8b3b59509d5db5d7ea3fdd69f954a8878"},
]
[[package]]
@@ -6090,4 +6087,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.13"
content-hash = "7d52ef1c6567900f7f1e079b2d317b861330bef797573221277f04b1981d0e05"
content-hash = "ba36ce74308bd37e19ca790e63dae387e5ad6173b2945dd63b58b3e918e85b46"

View File

@@ -13,7 +13,7 @@ aio-pika = "^9.5.4"
anthropic = "^0.45.2"
apscheduler = "^3.11.0"
autogpt-libs = { path = "../autogpt_libs", develop = true }
bleach = {extras = ["css"], version = "^6.2.0"}
bleach = "^6.2.0"
click = "^8.1.7"
cryptography = "^43.0"
discord-py = "^2.4.0"
@@ -61,7 +61,7 @@ tenacity = "^9.0.0"
todoist-api-python = "^2.1.7"
tweepy = "^4.14.0"
uvicorn = { extras = ["standard"], version = "^0.34.0" }
websockets = "^14.2"
websockets = "^13.1"
youtube-transcript-api = "^0.6.2"
zerobouncesdk = "^1.1.1"
# NOTE: please insert new dependencies in their alphabetical location

Some files were not shown because too many files have changed in this diff Show More