Compare commits

..

1 Commits

Author SHA1 Message Date
Nicholas Tindle
8b0688d962 Update client.ts 2025-02-13 16:20:24 -06:00
375 changed files with 7910 additions and 28857 deletions

View File

@@ -129,6 +129,30 @@ updates:
- "minor"
- "patch"
# Submodules
- package-ecosystem: "gitsubmodule"
directory: "autogpt_platform/supabase"
schedule:
interval: "weekly"
open-pull-requests-limit: 1
target-branch: "dev"
commit-message:
prefix: "chore(platform/deps)"
prefix-development: "chore(platform/deps-dev)"
groups:
production-dependencies:
dependency-type: "production"
update-types:
- "minor"
- "patch"
development-dependencies:
dependency-type: "development"
update-types:
- "minor"
- "patch"
# Docs
- package-ecosystem: 'pip'
directory: "docs/"

View File

@@ -115,7 +115,6 @@ jobs:
poetry run pytest -vv \
--cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
--numprocesses=logical --durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
tests/unit tests/integration
env:
CI: true
@@ -125,14 +124,8 @@ jobs:
AWS_ACCESS_KEY_ID: minioadmin
AWS_SECRET_ACCESS_KEY: minioadmin
- name: Upload test results to Codecov
if: ${{ !cancelled() }} # Run even if tests fail
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }}
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: autogpt-agent,${{ runner.os }}

View File

@@ -87,20 +87,13 @@ jobs:
poetry run pytest -vv \
--cov=agbenchmark --cov-branch --cov-report term-missing --cov-report xml \
--durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
tests
env:
CI: true
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- name: Upload test results to Codecov
if: ${{ !cancelled() }} # Run even if tests fail
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }}
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: agbenchmark,${{ runner.os }}

View File

@@ -139,7 +139,6 @@ jobs:
poetry run pytest -vv \
--cov=forge --cov-branch --cov-report term-missing --cov-report xml \
--durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
forge
env:
CI: true
@@ -149,14 +148,8 @@ jobs:
AWS_ACCESS_KEY_ID: minioadmin
AWS_SECRET_ACCESS_KEY: minioadmin
- name: Upload test results to Codecov
if: ${{ !cancelled() }} # Run even if tests fail
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }}
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: forge,${{ runner.os }}

View File

@@ -42,14 +42,6 @@ jobs:
REDIS_PASSWORD: testpassword
ports:
- 6379:6379
rabbitmq:
image: rabbitmq:3.12-management
ports:
- 5672:5672
- 15672:15672
env:
RABBITMQ_DEFAULT_USER: ${{ env.RABBITMQ_DEFAULT_USER }}
RABBITMQ_DEFAULT_PASS: ${{ env.RABBITMQ_DEFAULT_PASS }}
steps:
- name: Checkout repository
@@ -66,7 +58,7 @@ jobs:
- name: Setup Supabase
uses: supabase/setup-cli@v1
with:
version: 1.178.1
version: latest
- id: get_date
name: Get date
@@ -147,13 +139,6 @@ jobs:
RUN_ENV: local
PORT: 8080
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
# We know these are here, don't report this as a security vulnerability
# This is used as the default credential for the entire system's RabbitMQ instance
# If you want to replace this, you can do so by making our entire system generate
# new credentials for each local user and update the environment variables in
# the backend service, docker composes, and examples
RABBITMQ_DEFAULT_USER: 'rabbitmq_user_default'
RABBITMQ_DEFAULT_PASS: 'k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7'
# - name: Upload coverage reports to Codecov
# uses: codecov/codecov-action@v4

View File

@@ -77,12 +77,12 @@ jobs:
- name: Free Disk Space (Ubuntu)
uses: jlumbroso/free-disk-space@main
with:
large-packages: false # slow
docker-images: false # limited benefit
large-packages: false # slow
docker-images: false # limited benefit
- name: Copy default supabase .env
run: |
cp ../.env.example ../.env
cp ../supabase/docker/.env.example ../.env
- name: Copy backend .env
run: |
@@ -104,12 +104,11 @@ jobs:
run: yarn playwright install --with-deps ${{ matrix.browser }}
- name: Run tests
timeout-minutes: 20
run: |
yarn test --project=${{ matrix.browser }}
- name: Print Final Docker Compose logs
if: always()
- name: Print Docker Compose logs in debug mode
if: runner.debug
run: |
docker compose -f ../docker-compose.yml logs

3
.gitmodules vendored
View File

@@ -1,3 +1,6 @@
[submodule "classic/forge/tests/vcr_cassettes"]
path = classic/forge/tests/vcr_cassettes
url = https://github.com/Significant-Gravitas/Auto-GPT-test-cassettes
[submodule "autogpt_platform/supabase"]
path = autogpt_platform/supabase
url = https://github.com/supabase/supabase.git

View File

@@ -140,7 +140,7 @@ repos:
language: system
- repo: https://github.com/psf/black
rev: 24.10.0
rev: 23.12.1
# Black has sensible defaults, doesn't need package context, and ignores
# everything in .gitignore, so it works fine without any config or arguments.
hooks:
@@ -170,16 +170,6 @@ repos:
files: ^classic/benchmark/(agbenchmark|tests)/((?!reports).)*[/.]
args: [--config=classic/benchmark/.flake8]
- repo: local
hooks:
- id: prettier
name: Format (Prettier) - AutoGPT Platform - Frontend
alias: format-platform-frontend
entry: bash -c 'cd autogpt_platform/frontend && npx prettier --write $(echo "$@" | sed "s|autogpt_platform/frontend/||g")' --
files: ^autogpt_platform/frontend/
types: [file]
language: system
- repo: local
# To have watertight type checking, we check *all* the files in an affected
# project. To trigger on poetry.lock we also reset the file `types` filter.
@@ -231,16 +221,6 @@ repos:
language: system
pass_filenames: false
- repo: local
hooks:
- id: tsc
name: Typecheck - AutoGPT Platform - Frontend
entry: bash -c 'cd autogpt_platform/frontend && npm run type-check'
files: ^autogpt_platform/frontend/
types: [file]
language: system
pass_filenames: false
- repo: local
hooks:
- id: pytest

View File

@@ -2,6 +2,9 @@
If you are reading this, you are probably looking for the full **[contribution guide]**,
which is part of our [wiki].
Also check out our [🚀 Roadmap][roadmap] for information about our priorities and associated tasks.
<!-- You can find our immediate priorities and their progress on our public [kanban board]. -->
[contribution guide]: https://github.com/Significant-Gravitas/AutoGPT/wiki/Contributing
[wiki]: https://github.com/Significant-Gravitas/AutoGPT/wiki
[roadmap]: https://github.com/Significant-Gravitas/AutoGPT/discussions/6971

173
LICENSE
View File

@@ -1,8 +1,5 @@
All portions of this repository are under one of two licenses.
The all files outside of the autogpt_platform folder are under the MIT License below.
The autogpt_platform folder is under the Polyform Shield License below.
All portions of this repository are under one of two licenses. The majority of the AutoGPT repository is under the MIT License below. The autogpt_platform folder is under the
Polyform Shield License.
MIT License
@@ -30,169 +27,3 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# PolyForm Shield License 1.0.0
<https://polyformproject.org/licenses/shield/1.0.0>
## Acceptance
In order to get any license under these terms, you must agree
to them as both strict obligations and conditions to all
your licenses.
## Copyright License
The licensor grants you a copyright license for the
software to do everything you might do with the software
that would otherwise infringe the licensor's copyright
in it for any permitted purpose. However, you may
only distribute the software according to [Distribution
License](#distribution-license) and make changes or new works
based on the software according to [Changes and New Works
License](#changes-and-new-works-license).
## Distribution License
The licensor grants you an additional copyright license
to distribute copies of the software. Your license
to distribute covers distributing the software with
changes and new works permitted by [Changes and New Works
License](#changes-and-new-works-license).
## Notices
You must ensure that anyone who gets a copy of any part of
the software from you also gets a copy of these terms or the
URL for them above, as well as copies of any plain-text lines
beginning with `Required Notice:` that the licensor provided
with the software. For example:
> Required Notice: Copyright Yoyodyne, Inc. (http://example.com)
## Changes and New Works License
The licensor grants you an additional copyright license to
make changes and new works based on the software for any
permitted purpose.
## Patent License
The licensor grants you a patent license for the software that
covers patent claims the licensor can license, or becomes able
to license, that you would infringe by using the software.
## Noncompete
Any purpose is a permitted purpose, except for providing any
product that competes with the software or any product the
licensor or any of its affiliates provides using the software.
## Competition
Goods and services compete even when they provide functionality
through different kinds of interfaces or for different technical
platforms. Applications can compete with services, libraries
with plugins, frameworks with development tools, and so on,
even if they're written in different programming languages
or for different computer architectures. Goods and services
compete even when provided free of charge. If you market a
product as a practical substitute for the software or another
product, it definitely competes.
## New Products
If you are using the software to provide a product that does
not compete, but the licensor or any of its affiliates brings
your product into competition by providing a new version of
the software or another product using the software, you may
continue using versions of the software available under these
terms beforehand to provide your competing product, but not
any later versions.
## Discontinued Products
You may begin using the software to compete with a product
or service that the licensor or any of its affiliates has
stopped providing, unless the licensor includes a plain-text
line beginning with `Licensor Line of Business:` with the
software that mentions that line of business. For example:
> Licensor Line of Business: YoyodyneCMS Content Management
System (http://example.com/cms)
## Sales of Business
If the licensor or any of its affiliates sells a line of
business developing the software or using the software
to provide a product, the buyer can also enforce
[Noncompete](#noncompete) for that product.
## Fair Use
You may have "fair use" rights for the software under the
law. These terms do not limit them.
## No Other Rights
These terms do not allow you to sublicense or transfer any of
your licenses to anyone else, or prevent the licensor from
granting licenses to anyone else. These terms do not imply
any other licenses.
## Patent Defense
If you make any written claim that the software infringes or
contributes to infringement of any patent, your patent license
for the software granted under these terms ends immediately. If
your company makes such a claim, your patent license ends
immediately for work on behalf of your company.
## Violations
The first time you are notified in writing that you have
violated any of these terms, or done anything with the software
not covered by your licenses, your licenses can nonetheless
continue if you come into full compliance with these terms,
and take practical steps to correct past violations, within
32 days of receiving notice. Otherwise, all your licenses
end immediately.
## No Liability
***As far as the law allows, the software comes as is, without
any warranty or condition, and the licensor will not be liable
to you for any damages arising out of these terms or the use
or nature of the software, under any kind of legal claim.***
## Definitions
The **licensor** is the individual or entity offering these
terms, and the **software** is the software the licensor makes
available under these terms.
A **product** can be a good or service, or a combination
of them.
**You** refers to the individual or entity agreeing to these
terms.
**Your company** is any legal entity, sole proprietorship,
or other kind of organization that you work for, plus all
its affiliates.
**Affiliates** means the other organizations than an
organization has control over, is under the control of, or is
under common control with.
**Control** means ownership of substantially all the assets of
an entity, or the power to direct its management and policies
by vote, contract, or otherwise. Control can be direct or
indirect.
**Your licenses** are all the licenses granted to you for the
software under these terms.
**Use** means anything you do with the software requiring one
of your licenses.

View File

@@ -2,6 +2,7 @@
[![Discord Follow](https://dcbadge.vercel.app/api/server/autogpt?style=flat)](https://discord.gg/autogpt) &ensp;
[![Twitter Follow](https://img.shields.io/twitter/follow/Auto_GPT?style=social)](https://twitter.com/Auto_GPT) &ensp;
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
**AutoGPT** is a powerful platform that allows you to create, deploy, and manage continuous AI agents that automate complex workflows.
@@ -79,7 +80,7 @@ Be part of the revolution! **AutoGPT** is here to stay, at the forefront of AI i
**Licensing:**
MIT License: All files outside of autogpt_platform folder are under the MIT License.
MIT License: The majority of the AutoGPT repository is under the MIT License.
Polyform Shield License: This license applies to the autogpt_platform folder.

View File

@@ -20,7 +20,6 @@ Instead, please report them via:
- Please provide detailed reports with reproducible steps
- Include the version/commit hash where you discovered the vulnerability
- Allow us a 90-day security fix window before any public disclosure
- After patch is released, allow 30 days for users to update before public disclosure (for a total of 120 days max between update time and fix time)
- Share any potential mitigations or workarounds if known
## Supported Versions

View File

@@ -1,123 +0,0 @@
############
# Secrets
# YOU MUST CHANGE THESE BEFORE GOING INTO PRODUCTION
############
POSTGRES_PASSWORD=your-super-secret-and-long-postgres-password
JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
ANON_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJhbm9uIiwKICAgICJpc3MiOiAic3VwYWJhc2UtZGVtbyIsCiAgICAiaWF0IjogMTY0MTc2OTIwMCwKICAgICJleHAiOiAxNzk5NTM1NjAwCn0.dc_X5iR_VP_qT0zsiyj_I_OZ2T9FtRU2BBNWN8Bu4GE
SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
DASHBOARD_USERNAME=supabase
DASHBOARD_PASSWORD=this_password_is_insecure_and_should_be_updated
SECRET_KEY_BASE=UpNVntn3cDxHJpq99YMc1T1AQgQpc8kfYTuRgBiYa15BLrx8etQoXz3gZv1/u2oq
VAULT_ENC_KEY=your-encryption-key-32-chars-min
############
# Database - You can change these to any PostgreSQL database that has logical replication enabled.
############
POSTGRES_HOST=db
POSTGRES_DB=postgres
POSTGRES_PORT=5432
# default user is postgres
############
# Supavisor -- Database pooler
############
POOLER_PROXY_PORT_TRANSACTION=6543
POOLER_DEFAULT_POOL_SIZE=20
POOLER_MAX_CLIENT_CONN=100
POOLER_TENANT_ID=your-tenant-id
############
# API Proxy - Configuration for the Kong Reverse proxy.
############
KONG_HTTP_PORT=8000
KONG_HTTPS_PORT=8443
############
# API - Configuration for PostgREST.
############
PGRST_DB_SCHEMAS=public,storage,graphql_public
############
# Auth - Configuration for the GoTrue authentication server.
############
## General
SITE_URL=http://localhost:3000
ADDITIONAL_REDIRECT_URLS=
JWT_EXPIRY=3600
DISABLE_SIGNUP=false
API_EXTERNAL_URL=http://localhost:8000
## Mailer Config
MAILER_URLPATHS_CONFIRMATION="/auth/v1/verify"
MAILER_URLPATHS_INVITE="/auth/v1/verify"
MAILER_URLPATHS_RECOVERY="/auth/v1/verify"
MAILER_URLPATHS_EMAIL_CHANGE="/auth/v1/verify"
## Email auth
ENABLE_EMAIL_SIGNUP=true
ENABLE_EMAIL_AUTOCONFIRM=false
SMTP_ADMIN_EMAIL=admin@example.com
SMTP_HOST=supabase-mail
SMTP_PORT=2500
SMTP_USER=fake_mail_user
SMTP_PASS=fake_mail_password
SMTP_SENDER_NAME=fake_sender
ENABLE_ANONYMOUS_USERS=false
## Phone auth
ENABLE_PHONE_SIGNUP=true
ENABLE_PHONE_AUTOCONFIRM=true
############
# Studio - Configuration for the Dashboard
############
STUDIO_DEFAULT_ORGANIZATION=Default Organization
STUDIO_DEFAULT_PROJECT=Default Project
STUDIO_PORT=3000
# replace if you intend to use Studio outside of localhost
SUPABASE_PUBLIC_URL=http://localhost:8000
# Enable webp support
IMGPROXY_ENABLE_WEBP_DETECTION=true
# Add your OpenAI API key to enable SQL Editor Assistant
OPENAI_API_KEY=
############
# Functions - Configuration for Functions
############
# NOTE: VERIFY_JWT applies to all functions. Per-function VERIFY_JWT is not supported yet.
FUNCTIONS_VERIFY_JWT=false
############
# Logs - Configuration for Logflare
# Please refer to https://supabase.com/docs/reference/self-hosting-analytics/introduction
############
LOGFLARE_LOGGER_BACKEND_API_KEY=your-super-secret-and-long-logflare-key
# Change vector.toml sinks to reflect this change
LOGFLARE_API_KEY=your-super-secret-and-long-logflare-key
# Docker socket location - this value will differ depending on your OS
DOCKER_SOCKET_LOCATION=/var/run/docker.sock
# Google Cloud Project details
GOOGLE_PROJECT_ID=GOOGLE_PROJECT_ID
GOOGLE_PROJECT_NUMBER=GOOGLE_PROJECT_NUMBER

View File

@@ -2,7 +2,7 @@
**Contributor License Agreement (“Agreement”)**
Thank you for your interest in the AutoGPT project at [https://github.com/Significant-Gravitas/AutoGPT](https://github.com/Significant-Gravitas/AutoGPT) stewarded by Determinist Ltd (“**Determinist**”), with offices at 3rd Floor 1 Ashley Road, Altrincham, Cheshire, WA14 2DT, United Kingdom. The form of license below is a document that clarifies the terms under which You, the person listed below, may contribute software code described below (the “**Contribution**”) to the project. We appreciate your participation in our project, and your help in improving our products, so we want you to understand what will be done with the Contributions. This license is for your protection as well as the protection of Determinist and its licensees; it does not change your rights to use your own Contributions for any other purpose.
Thank you for your interest in the AutoGPT open source project at [https://github.com/Significant-Gravitas/AutoGPT](https://github.com/Significant-Gravitas/AutoGPT) stewarded by Determinist Ltd (“**Determinist**”), with offices at 3rd Floor 1 Ashley Road, Altrincham, Cheshire, WA14 2DT, United Kingdom. The form of license below is a document that clarifies the terms under which You, the person listed below, may contribute software code described below (the “**Contribution**”) to the project. We appreciate your participation in our project, and your help in improving our products, so we want you to understand what will be done with the Contributions. This license is for your protection as well as the protection of Determinist and its licensees; it does not change your rights to use your own Contributions for any other purpose.
By submitting a Pull Request which modifies the content of the “autogpt\_platform” folder at [https://github.com/Significant-Gravitas/AutoGPT/tree/master/autogpt\_platform](https://github.com/Significant-Gravitas/AutoGPT/tree/master/autogpt_platform), You hereby agree:

View File

@@ -22,29 +22,35 @@ To run the AutoGPT Platform, follow these steps:
2. Run the following command:
```
cp .env.example .env
git submodule update --init --recursive --progress
```
This command will copy the `.env.example` file to `.env`. You can modify the `.env` file to add your own environment variables.
This command will initialize and update the submodules in the repository. The `supabase` folder will be cloned to the root directory.
3. Run the following command:
```
cp supabase/docker/.env.example .env
```
This command will copy the `.env.example` file to `.env` in the `supabase/docker` directory. You can modify the `.env` file to add your own environment variables.
4. Run the following command:
```
docker compose up -d
```
This command will start all the necessary backend services defined in the `docker-compose.yml` file in detached mode.
4. Navigate to `frontend` within the `autogpt_platform` directory:
5. Navigate to `frontend` within the `autogpt_platform` directory:
```
cd frontend
```
You will need to run your frontend application separately on your local machine.
5. Run the following command:
6. Run the following command:
```
cp .env.example .env.local
```
This command will copy the `.env.example` file to `.env.local` in the `frontend` directory. You can modify the `.env.local` within this folder to add your own environment variables for the frontend application.
6. Run the following command:
7. Run the following command:
```
npm install
npm run dev
@@ -55,7 +61,7 @@ To run the AutoGPT Platform, follow these steps:
yarn install && yarn dev
```
7. Open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
8. Open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
### Docker Compose Commands

View File

@@ -1,13 +1,14 @@
from .config import Settings
from .depends import requires_admin_user, requires_user
from .jwt_utils import parse_jwt_token
from .middleware import APIKeyValidator, auth_middleware
from .middleware import auth_middleware
from .models import User
__all__ = [
"Settings",
"parse_jwt_token",
"requires_user",
"requires_admin_user",
"APIKeyValidator",
"auth_middleware",
"User",
]

View File

@@ -1,11 +1,14 @@
import os
from dotenv import load_dotenv
load_dotenv()
class Settings:
def __init__(self):
self.JWT_SECRET_KEY: str = os.getenv("SUPABASE_JWT_SECRET", "")
self.ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
self.JWT_ALGORITHM: str = "HS256"
JWT_SECRET_KEY: str = os.getenv("SUPABASE_JWT_SECRET", "")
ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
JWT_ALGORITHM: str = "HS256"
@property
def is_configured(self) -> bool:

View File

@@ -1,6 +1,6 @@
import fastapi
from .config import settings
from .config import Settings
from .middleware import auth_middleware
from .models import DEFAULT_USER_ID, User
@@ -17,7 +17,7 @@ def requires_admin_user(
def verify_user(payload: dict | None, admin_only: bool) -> User:
if not payload:
if settings.ENABLE_AUTH:
if Settings.ENABLE_AUTH:
raise fastapi.HTTPException(
status_code=401, detail="Authorization header is missing"
)

View File

@@ -1,10 +1,7 @@
import inspect
import logging
from typing import Any, Callable, Optional
from fastapi import HTTPException, Request, Security
from fastapi.security import APIKeyHeader, HTTPBearer
from starlette.status import HTTP_401_UNAUTHORIZED
from fastapi import HTTPException, Request
from fastapi.security import HTTPBearer
from .config import settings
from .jwt_utils import parse_jwt_token
@@ -32,104 +29,3 @@ async def auth_middleware(request: Request):
except ValueError as e:
raise HTTPException(status_code=401, detail=str(e))
return payload
class APIKeyValidator:
"""
Configurable API key validator that supports custom validation functions
for FastAPI applications.
This class provides a flexible way to implement API key authentication with optional
custom validation logic. It can be used for simple token matching
or more complex validation scenarios like database lookups.
Examples:
Simple token validation:
```python
validator = APIKeyValidator(
header_name="X-API-Key",
expected_token="your-secret-token"
)
@app.get("/protected", dependencies=[Depends(validator.get_dependency())])
def protected_endpoint():
return {"message": "Access granted"}
```
Custom validation with database lookup:
```python
async def validate_with_db(api_key: str):
api_key_obj = await db.get_api_key(api_key)
return api_key_obj if api_key_obj and api_key_obj.is_active else None
validator = APIKeyValidator(
header_name="X-API-Key",
validate_fn=validate_with_db
)
```
Args:
header_name (str): The name of the header containing the API key
expected_token (Optional[str]): The expected API key value for simple token matching
validate_fn (Optional[Callable]): Custom validation function that takes an API key
string and returns a boolean or object. Can be async.
error_status (int): HTTP status code to use for validation errors
error_message (str): Error message to return when validation fails
"""
def __init__(
self,
header_name: str,
expected_token: Optional[str] = None,
validate_fn: Optional[Callable[[str], bool]] = None,
error_status: int = HTTP_401_UNAUTHORIZED,
error_message: str = "Invalid API key",
):
# Create the APIKeyHeader as a class property
self.security_scheme = APIKeyHeader(name=header_name)
self.expected_token = expected_token
self.custom_validate_fn = validate_fn
self.error_status = error_status
self.error_message = error_message
async def default_validator(self, api_key: str) -> bool:
return api_key == self.expected_token
async def __call__(
self, request: Request, api_key: str = Security(APIKeyHeader)
) -> Any:
if api_key is None:
raise HTTPException(status_code=self.error_status, detail="Missing API key")
# Use custom validation if provided, otherwise use default equality check
validator = self.custom_validate_fn or self.default_validator
result = (
await validator(api_key)
if inspect.iscoroutinefunction(validator)
else validator(api_key)
)
if not result:
raise HTTPException(
status_code=self.error_status, detail=self.error_message
)
# Store validation result in request state if it's not just a boolean
if result is not True:
request.state.api_key = result
return result
def get_dependency(self):
"""
Returns a callable dependency that FastAPI will recognize as a security scheme
"""
async def validate_api_key(
request: Request, api_key: str = Security(self.security_scheme)
) -> Any:
return await self(request, api_key)
# This helps FastAPI recognize it as a security dependency
validate_api_key.__name__ = f"validate_{self.security_scheme.model.name}"
return validate_api_key

View File

@@ -13,6 +13,7 @@ from typing_extensions import ParamSpec
from .config import SETTINGS
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)
P = ParamSpec("P")
T = TypeVar("T")

View File

@@ -99,6 +99,7 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
cloud_handler.setLevel(config.level)
cloud_handler.setFormatter(StructuredLoggingFormatter())
log_handlers.append(cloud_handler)
print("Cloud logging enabled")
else:
# Console output handlers
stdout = logging.StreamHandler(stream=sys.stdout)
@@ -117,6 +118,7 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
log_handlers += [stdout, stderr]
print("Console logging enabled")
# File logging setup
if config.enable_file_logging:
@@ -154,6 +156,7 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
error_log_handler.setLevel(logging.ERROR)
error_log_handler.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT, no_color=True))
log_handlers.append(error_log_handler)
print("File logging enabled")
# Configure the root logger
logging.basicConfig(

File diff suppressed because it is too large Load Diff

View File

@@ -10,17 +10,18 @@ packages = [{ include = "autogpt_libs" }]
colorama = "^0.4.6"
expiringdict = "^1.2.2"
google-cloud-logging = "^3.11.4"
pydantic = "^2.11.1"
pydantic-settings = "^2.8.1"
pydantic = "^2.10.6"
pydantic-settings = "^2.7.1"
pyjwt = "^2.10.1"
pytest-asyncio = "^0.26.0"
pytest-asyncio = "^0.25.3"
pytest-mock = "^3.14.0"
python = ">=3.10,<4.0"
supabase = "^2.15.0"
python-dotenv = "^1.0.1"
supabase = "^2.13.0"
[tool.poetry.group.dev.dependencies]
redis = "^5.2.1"
ruff = "^0.11.0"
ruff = "^0.9.3"
[build-system]
requires = ["poetry-core"]

View File

@@ -2,23 +2,13 @@ DB_USER=postgres
DB_PASS=your-super-secret-and-long-postgres-password
DB_NAME=postgres
DB_PORT=5432
DB_HOST=localhost
DB_CONNECTION_LIMIT=12
DB_CONNECT_TIMEOUT=60
DB_POOL_TIMEOUT=300
DB_SCHEMA=platform
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@localhost:${DB_PORT}/${DB_NAME}?connect_timeout=60&schema=platform"
PRISMA_SCHEMA="postgres/schema.prisma"
# EXECUTOR
NUM_GRAPH_WORKERS=10
NUM_NODE_WORKERS=3
BACKEND_CORS_ALLOW_ORIGINS=["http://localhost:3000"]
# generate using `from cryptography.fernet import Fernet;Fernet.generate_key().decode()`
ENCRYPTION_KEY='dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw='
UNSUBSCRIBE_SECRET_KEY = 'HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio='
REDIS_HOST=localhost
REDIS_PORT=6379
@@ -35,11 +25,6 @@ BEHAVE_AS=local
PYRO_HOST=localhost
SENTRY_DSN=
# Email For Postmark so we can send emails
POSTMARK_SERVER_API_TOKEN=
POSTMARK_SENDER_EMAIL=invalid@invalid.com
POSTMARK_WEBHOOK_TOKEN=
## User auth with Supabase is required for any of the 3rd party integrations with auth to work.
ENABLE_AUTH=true
SUPABASE_URL=http://localhost:8000
@@ -52,9 +37,6 @@ RABBITMQ_PORT=5672
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
## GCS bucket is required for marketplace and library functionality
MEDIA_GCS_BUCKET_NAME=
## For local development, you may need to set FRONTEND_BASE_URL for the OAuth flow
## for integrations to work. Defaults to the value of PLATFORM_BASE_URL if not set.
# FRONTEND_BASE_URL=http://localhost:3000
@@ -179,17 +161,6 @@ MEM0_API_KEY=
# Nvidia
NVIDIA_API_KEY=
# Apollo
APOLLO_API_KEY=
# SmartLead
SMARTLEAD_API_KEY=
# ZeroBounce
ZEROBOUNCE_API_KEY=
## ===== OPTIONAL API KEYS END ===== ##
# Logging Configuration
LOG_LEVEL=INFO
ENABLE_CLOUD_LOGGING=false

View File

@@ -1 +1,75 @@
[Advanced Setup (Dev Branch)](https://dev-docs.agpt.co/platform/advanced_setup/#autogpt_agent_server_advanced_set_up)
# AutoGPT Agent Server Advanced set up
This guide walks you through a dockerized set up, with an external DB (postgres)
## Setup
We use the Poetry to manage the dependencies. To set up the project, follow these steps inside this directory:
0. Install Poetry
```sh
pip install poetry
```
1. Configure Poetry to use .venv in your project directory
```sh
poetry config virtualenvs.in-project true
```
2. Enter the poetry shell
```sh
poetry shell
```
3. Install dependencies
```sh
poetry install
```
4. Copy .env.example to .env
```sh
cp .env.example .env
```
5. Generate the Prisma client
```sh
poetry run prisma generate
```
> In case Prisma generates the client for the global Python installation instead of the virtual environment, the current mitigation is to just uninstall the global Prisma package:
>
> ```sh
> pip uninstall prisma
> ```
>
> Then run the generation again. The path *should* look something like this:
> `<some path>/pypoetry/virtualenvs/backend-TQIRSwR6-py3.12/bin/prisma`
6. Run the postgres database from the /rnd folder
```sh
cd autogpt_platform/
docker compose up -d
```
7. Run the migrations (from the backend folder)
```sh
cd ../backend
prisma migrate deploy
```
## Running The Server
### Starting the server directly
Run the following command:
```sh
poetry run app
```

View File

@@ -1 +1,210 @@
[Getting Started (Released)](https://docs.agpt.co/platform/getting-started/#autogpt_agent_server)
# AutoGPT Agent Server
This is an initial project for creating the next generation of agent execution, which is an AutoGPT agent server.
The agent server will enable the creation of composite multi-agent systems that utilize AutoGPT agents and other non-agent components as its primitives.
## Docs
You can access the docs for the [AutoGPT Agent Server here](https://docs.agpt.co/server/setup).
## Setup
We use the Poetry to manage the dependencies. To set up the project, follow these steps inside this directory:
0. Install Poetry
```sh
pip install poetry
```
1. Configure Poetry to use .venv in your project directory
```sh
poetry config virtualenvs.in-project true
```
2. Enter the poetry shell
```sh
poetry shell
```
3. Install dependencies
```sh
poetry install
```
4. Copy .env.example to .env
```sh
cp .env.example .env
```
5. Generate the Prisma client
```sh
poetry run prisma generate
```
> In case Prisma generates the client for the global Python installation instead of the virtual environment, the current mitigation is to just uninstall the global Prisma package:
>
> ```sh
> pip uninstall prisma
> ```
>
> Then run the generation again. The path *should* look something like this:
> `<some path>/pypoetry/virtualenvs/backend-TQIRSwR6-py3.12/bin/prisma`
6. Migrate the database. Be careful because this deletes current data in the database.
```sh
docker compose up db -d
poetry run prisma migrate deploy
```
## Running The Server
### Starting the server without Docker
To run the server locally, start in the autogpt_platform folder:
```sh
cd ..
```
Run the following command to run database in docker but the application locally:
```sh
docker compose --profile local up deps --build --detach
cd backend
poetry run app
```
### Starting the server with Docker
Run the following command to build the dockerfiles:
```sh
docker compose build
```
Run the following command to run the app:
```sh
docker compose up
```
Run the following to automatically rebuild when code changes, in another terminal:
```sh
docker compose watch
```
Run the following command to shut down:
```sh
docker compose down
```
If you run into issues with dangling orphans, try:
```sh
docker compose down --volumes --remove-orphans && docker-compose up --force-recreate --renew-anon-volumes --remove-orphans
```
## Testing
To run the tests:
```sh
poetry run test
```
## Development
### Formatting & Linting
Auto formatter and linter are set up in the project. To run them:
Install:
```sh
poetry install --with dev
```
Format the code:
```sh
poetry run format
```
Lint the code:
```sh
poetry run lint
```
## Project Outline
The current project has the following main modules:
### **blocks**
This module stores all the Agent Blocks, which are reusable components to build a graph that represents the agent's behavior.
### **data**
This module stores the logical model that is persisted in the database.
It abstracts the database operations into functions that can be called by the service layer.
Any code that interacts with Prisma objects or the database should reside in this module.
The main models are:
* `block`: anything related to the block used in the graph
* `execution`: anything related to the execution graph execution
* `graph`: anything related to the graph, node, and its relations
### **execution**
This module stores the business logic of executing the graph.
It currently has the following main modules:
* `manager`: A service that consumes the queue of the graph execution and executes the graph. It contains both pieces of logic.
* `scheduler`: A service that triggers scheduled graph execution based on a cron expression. It pushes an execution request to the manager.
### **server**
This module stores the logic for the server API.
It contains all the logic used for the API that allows the client to create, execute, and monitor the graph and its execution.
This API service interacts with other services like those defined in `manager` and `scheduler`.
### **utils**
This module stores utility functions that are used across the project.
Currently, it has two main modules:
* `process`: A module that contains the logic to spawn a new process.
* `service`: A module that serves as a parent class for all the services in the project.
## Service Communication
Currently, there are only 3 active services:
- AgentServer (the API, defined in `server.py`)
- ExecutionManager (the executor, defined in `manager.py`)
- ExecutionScheduler (the scheduler, defined in `scheduler.py`)
The services run in independent Python processes and communicate through an IPC.
A communication layer (`service.py`) is created to decouple the communication library from the implementation.
Currently, the IPC is done using Pyro5 and abstracted in a way that allows a function decorated with `@expose` to be called from a different process.
By default the daemons run on the following ports:
Execution Manager Daemon: 8002
Execution Scheduler Daemon: 8003
Rest Server Daemon: 8004
## Adding a New Agent Block
To add a new agent block, you need to create a new class that inherits from `Block` and provides the following information:
* All the block code should live in the `blocks` (`backend.blocks`) module.
* `input_schema`: the schema of the input data, represented by a Pydantic object.
* `output_schema`: the schema of the output data, represented by a Pydantic object.
* `run` method: the main logic of the block.
* `test_input` & `test_output`: the sample input and output data for the block, which will be used to auto-test the block.
* You can mock the functions declared in the block using the `test_mock` field for your unit tests.
* Once you finish creating the block, you can test it by running `poetry run pytest -s test/block/test_block.py`.

View File

@@ -1,30 +1,22 @@
import logging
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from backend.util.process import AppProcess
logger = logging.getLogger(__name__)
def run_processes(*processes: "AppProcess", **kwargs):
"""
Execute all processes in the app. The last process is run in the foreground.
Includes enhanced error handling and process lifecycle management.
"""
try:
# Run all processes except the last one in the background.
for process in processes[:-1]:
process.start(background=True, **kwargs)
# Run the last process in the foreground.
# Run the last process in the foreground
processes[-1].start(background=False, **kwargs)
finally:
for process in processes:
try:
process.stop()
except Exception as e:
logger.exception(f"[{process.service_name}] unable to stop: {e}")
process.stop()
def main(**kwargs):
@@ -32,16 +24,14 @@ def main(**kwargs):
Run all the processes required for the AutoGPT-server (REST and WebSocket APIs).
"""
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
from backend.notifications import NotificationManager
from backend.executor import DatabaseManager, ExecutionManager, ExecutionScheduler
from backend.server.rest_api import AgentServer
from backend.server.ws_api import WebsocketServer
run_processes(
DatabaseManager(),
ExecutionManager(),
Scheduler(),
NotificationManager(),
ExecutionScheduler(),
WebsocketServer(),
AgentServer(),
**kwargs,

View File

@@ -2,103 +2,88 @@ import importlib
import os
import re
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar
from typing import Type, TypeVar
from backend.data.block import Block
# Dynamically load all modules under backend.blocks
AVAILABLE_MODULES = []
current_dir = Path(__file__).parent
modules = [
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
for f in current_dir.rglob("*.py")
if f.is_file() and f.name != "__init__.py"
]
for module in modules:
if not re.match("^[a-z0-9_.]+$", module):
raise ValueError(
f"Block module {module} error: module name must be lowercase, "
"and contain only alphanumeric characters and underscores."
)
importlib.import_module(f".{module}", package=__name__)
AVAILABLE_MODULES.append(module)
# Load all Block instances from the available modules
AVAILABLE_BLOCKS: dict[str, Type[Block]] = {}
if TYPE_CHECKING:
from backend.data.block import Block
T = TypeVar("T")
_AVAILABLE_BLOCKS: dict[str, type["Block"]] = {}
def load_all_blocks() -> dict[str, type["Block"]]:
from backend.data.block import Block
if _AVAILABLE_BLOCKS:
return _AVAILABLE_BLOCKS
# Dynamically load all modules under backend.blocks
AVAILABLE_MODULES = []
current_dir = Path(__file__).parent
modules = [
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
for f in current_dir.rglob("*.py")
if f.is_file() and f.name != "__init__.py"
]
for module in modules:
if not re.match("^[a-z0-9_.]+$", module):
raise ValueError(
f"Block module {module} error: module name must be lowercase, "
"and contain only alphanumeric characters and underscores."
)
importlib.import_module(f".{module}", package=__name__)
AVAILABLE_MODULES.append(module)
# Load all Block instances from the available modules
for block_cls in all_subclasses(Block):
class_name = block_cls.__name__
if class_name.endswith("Base"):
continue
if not class_name.endswith("Block"):
raise ValueError(
f"Block class {class_name} does not end with 'Block'. "
"If you are creating an abstract class, "
"please name the class with 'Base' at the end"
)
block = block_cls.create()
if not isinstance(block.id, str) or len(block.id) != 36:
raise ValueError(
f"Block ID {block.name} error: {block.id} is not a valid UUID"
)
if block.id in _AVAILABLE_BLOCKS:
raise ValueError(
f"Block ID {block.name} error: {block.id} is already in use"
)
input_schema = block.input_schema.model_fields
output_schema = block.output_schema.model_fields
# Make sure `error` field is a string in the output schema
if "error" in output_schema and output_schema["error"].annotation is not str:
raise ValueError(
f"{block.name} `error` field in output_schema must be a string"
)
# Ensure all fields in input_schema and output_schema are annotated SchemaFields
for field_name, field in [*input_schema.items(), *output_schema.items()]:
if field.annotation is None:
raise ValueError(
f"{block.name} has a field {field_name} that is not annotated"
)
if field.json_schema_extra is None:
raise ValueError(
f"{block.name} has a field {field_name} not defined as SchemaField"
)
for field in block.input_schema.model_fields.values():
if field.annotation is bool and field.default not in (True, False):
raise ValueError(
f"{block.name} has a boolean field with no default value"
)
_AVAILABLE_BLOCKS[block.id] = block_cls
return _AVAILABLE_BLOCKS
__all__ = ["load_all_blocks"]
def all_subclasses(cls: type[T]) -> list[type[T]]:
def all_subclasses(cls: Type[T]) -> list[Type[T]]:
subclasses = cls.__subclasses__()
for subclass in subclasses:
subclasses += all_subclasses(subclass)
return subclasses
for block_cls in all_subclasses(Block):
name = block_cls.__name__
if block_cls.__name__.endswith("Base"):
continue
if not block_cls.__name__.endswith("Block"):
raise ValueError(
f"Block class {block_cls.__name__} does not end with 'Block', If you are creating an abstract class, please name the class with 'Base' at the end"
)
block = block_cls.create()
if not isinstance(block.id, str) or len(block.id) != 36:
raise ValueError(f"Block ID {block.name} error: {block.id} is not a valid UUID")
if block.id in AVAILABLE_BLOCKS:
raise ValueError(f"Block ID {block.name} error: {block.id} is already in use")
input_schema = block.input_schema.model_fields
output_schema = block.output_schema.model_fields
# Make sure `error` field is a string in the output schema
if "error" in output_schema and output_schema["error"].annotation is not str:
raise ValueError(
f"{block.name} `error` field in output_schema must be a string"
)
# Make sure all fields in input_schema and output_schema are annotated and has a value
for field_name, field in [*input_schema.items(), *output_schema.items()]:
if field.annotation is None:
raise ValueError(
f"{block.name} has a field {field_name} that is not annotated"
)
if field.json_schema_extra is None:
raise ValueError(
f"{block.name} has a field {field_name} not defined as SchemaField"
)
for field in block.input_schema.model_fields.values():
if field.annotation is bool and field.default not in (True, False):
raise ValueError(f"{block.name} has a boolean field with no default value")
if block.disabled:
continue
AVAILABLE_BLOCKS[block.id] = block_cls
__all__ = ["AVAILABLE_MODULES", "AVAILABLE_BLOCKS"]

View File

@@ -1,5 +1,4 @@
import logging
from typing import Any
from autogpt_libs.utils.cache import thread_cached
@@ -14,7 +13,6 @@ from backend.data.block import (
)
from backend.data.execution import ExecutionStatus
from backend.data.model import SchemaField
from backend.util import json
logger = logging.getLogger(__name__)
@@ -44,23 +42,6 @@ class AgentExecutorBlock(Block):
input_schema: dict = SchemaField(description="Input schema for the graph")
output_schema: dict = SchemaField(description="Output schema for the graph")
@classmethod
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
return data.get("input_schema", {})
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
return data.get("data", {})
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
required_fields = cls.get_input_schema(data).get("required", [])
return set(required_fields) - set(data)
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
return json.validate_with_jsonschema(cls.get_input_schema(data), data)
class Output(BlockSchema):
pass
@@ -75,8 +56,6 @@ class AgentExecutorBlock(Block):
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
from backend.data.execution import ExecutionEventType
executor_manager = get_executor_manager_client()
event_bus = get_event_bus()
@@ -90,11 +69,13 @@ class AgentExecutorBlock(Block):
logger.info(f"Starting execution of {log_id}")
for event in event_bus.listen(
user_id=graph_exec.user_id,
graph_id=graph_exec.graph_id,
graph_exec_id=graph_exec.graph_exec_id,
graph_id=graph_exec.graph_id, graph_exec_id=graph_exec.graph_exec_id
):
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
logger.info(
f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
)
if not event.node_id:
if event.status in [
ExecutionStatus.COMPLETED,
ExecutionStatus.TERMINATED,
@@ -105,10 +86,6 @@ class AgentExecutorBlock(Block):
else:
continue
logger.info(
f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
)
if not event.block_id:
logger.warning(f"{log_id} received event without block_id {event}")
continue

View File

@@ -1,108 +0,0 @@
import logging
from typing import List
from backend.blocks.apollo._auth import ApolloCredentials
from backend.blocks.apollo.models import (
Contact,
Organization,
SearchOrganizationsRequest,
SearchOrganizationsResponse,
SearchPeopleRequest,
SearchPeopleResponse,
)
from backend.util.request import Requests
logger = logging.getLogger(name=__name__)
class ApolloClient:
"""Client for the Apollo API"""
API_URL = "https://api.apollo.io/api/v1"
def __init__(self, credentials: ApolloCredentials):
self.credentials = credentials
self.requests = Requests()
def _get_headers(self) -> dict[str, str]:
return {"x-api-key": self.credentials.api_key.get_secret_value()}
def search_people(self, query: SearchPeopleRequest) -> List[Contact]:
"""Search for people in Apollo"""
response = self.requests.get(
f"{self.API_URL}/mixed_people/search",
headers=self._get_headers(),
params=query.model_dump(exclude={"credentials", "max_results"}),
)
parsed_response = SearchPeopleResponse(**response.json())
if parsed_response.pagination.total_entries == 0:
return []
people = parsed_response.people
# handle pagination
if (
query.max_results is not None
and query.max_results < parsed_response.pagination.total_entries
and len(people) < query.max_results
):
while (
len(people) < query.max_results
and query.page < parsed_response.pagination.total_pages
and len(parsed_response.people) > 0
):
query.page += 1
response = self.requests.get(
f"{self.API_URL}/mixed_people/search",
headers=self._get_headers(),
params=query.model_dump(exclude={"credentials", "max_results"}),
)
parsed_response = SearchPeopleResponse(**response.json())
people.extend(parsed_response.people[: query.max_results - len(people)])
logger.info(f"Found {len(people)} people")
return people[: query.max_results] if query.max_results else people
def search_organizations(
self, query: SearchOrganizationsRequest
) -> List[Organization]:
"""Search for organizations in Apollo"""
response = self.requests.get(
f"{self.API_URL}/mixed_companies/search",
headers=self._get_headers(),
params=query.model_dump(exclude={"credentials", "max_results"}),
)
parsed_response = SearchOrganizationsResponse(**response.json())
if parsed_response.pagination.total_entries == 0:
return []
organizations = parsed_response.organizations
# handle pagination
if (
query.max_results is not None
and query.max_results < parsed_response.pagination.total_entries
and len(organizations) < query.max_results
):
while (
len(organizations) < query.max_results
and query.page < parsed_response.pagination.total_pages
and len(parsed_response.organizations) > 0
):
query.page += 1
response = self.requests.get(
f"{self.API_URL}/mixed_companies/search",
headers=self._get_headers(),
params=query.model_dump(exclude={"credentials", "max_results"}),
)
parsed_response = SearchOrganizationsResponse(**response.json())
organizations.extend(
parsed_response.organizations[
: query.max_results - len(organizations)
]
)
logger.info(f"Found {len(organizations)} organizations")
return (
organizations[: query.max_results] if query.max_results else organizations
)

View File

@@ -1,35 +0,0 @@
from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
ApolloCredentials = APIKeyCredentials
ApolloCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.APOLLO],
Literal["api_key"],
]
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="apollo",
api_key=SecretStr("mock-apollo-api-key"),
title="Mock Apollo API key",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
def ApolloCredentialsField() -> ApolloCredentialsInput:
"""
Creates a Apollo credentials input on a block.
"""
return CredentialsField(
description="The Apollo integration can be used with an API Key.",
)

View File

@@ -1,543 +0,0 @@
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel
from backend.data.model import SchemaField
class PrimaryPhone(BaseModel):
"""A primary phone in Apollo"""
number: str
source: str
sanitized_number: str
class SenorityLevels(str, Enum):
"""Seniority levels in Apollo"""
OWNER = "owner"
FOUNDER = "founder"
C_SUITE = "c_suite"
PARTNER = "partner"
VP = "vp"
HEAD = "head"
DIRECTOR = "director"
MANAGER = "manager"
SENIOR = "senior"
ENTRY = "entry"
INTERN = "intern"
class ContactEmailStatuses(str, Enum):
"""Contact email statuses in Apollo"""
VERIFIED = "verified"
UNVERIFIED = "unverified"
LIKELY_TO_ENGAGE = "likely_to_engage"
UNAVAILABLE = "unavailable"
class RuleConfigStatus(BaseModel):
"""A rule config status in Apollo"""
_id: str
created_at: str
rule_action_config_id: str
rule_config_id: str
status_cd: str
updated_at: str
id: str
key: str
class ContactCampaignStatus(BaseModel):
"""A contact campaign status in Apollo"""
id: str
emailer_campaign_id: str
send_email_from_user_id: str
inactive_reason: str
status: str
added_at: str
added_by_user_id: str
finished_at: str
paused_at: str
auto_unpause_at: str
send_email_from_email_address: str
send_email_from_email_account_id: str
manually_set_unpause: str
failure_reason: str
current_step_id: str
in_response_to_emailer_message_id: str
cc_emails: str
bcc_emails: str
to_emails: str
class Account(BaseModel):
"""An account in Apollo"""
id: str
name: str
website_url: str
blog_url: str
angellist_url: str
linkedin_url: str
twitter_url: str
facebook_url: str
primary_phone: PrimaryPhone
languages: list[str]
alexa_ranking: int
phone: str
linkedin_uid: str
founded_year: int
publicly_traded_symbol: str
publicly_traded_exchange: str
logo_url: str
chrunchbase_url: str
primary_domain: str
domain: str
team_id: str
organization_id: str
account_stage_id: str
source: str
original_source: str
creator_id: str
owner_id: str
created_at: str
phone_status: str
hubspot_id: str
salesforce_id: str
crm_owner_id: str
parent_account_id: str
sanitized_phone: str
# no listed type on the API docs
account_playbook_statues: list[Any]
account_rule_config_statuses: list[RuleConfigStatus]
existence_level: str
label_ids: list[str]
typed_custom_fields: Any
custom_field_errors: Any
modality: str
source_display_name: str
salesforce_record_id: str
crm_record_url: str
class ContactEmail(BaseModel):
"""A contact email in Apollo"""
email: str = ""
email_md5: str = ""
email_sha256: str = ""
email_status: str = ""
email_source: str = ""
extrapolated_email_confidence: str = ""
position: int = 0
email_from_customer: str = ""
free_domain: bool = True
class EmploymentHistory(BaseModel):
"""An employment history in Apollo"""
class Config:
extra = "allow"
arbitrary_types_allowed = True
from_attributes = True
populate_by_name = True
_id: Optional[str] = None
created_at: Optional[str] = None
current: Optional[bool] = None
degree: Optional[str] = None
description: Optional[str] = None
emails: Optional[str] = None
end_date: Optional[str] = None
grade_level: Optional[str] = None
kind: Optional[str] = None
major: Optional[str] = None
organization_id: Optional[str] = None
organization_name: Optional[str] = None
raw_address: Optional[str] = None
start_date: Optional[str] = None
title: Optional[str] = None
updated_at: Optional[str] = None
id: Optional[str] = None
key: Optional[str] = None
class Breadcrumb(BaseModel):
"""A breadcrumb in Apollo"""
label: Optional[str] = "N/A"
signal_field_name: Optional[str] = "N/A"
value: str | list | None = "N/A"
display_name: Optional[str] = "N/A"
class TypedCustomField(BaseModel):
"""A typed custom field in Apollo"""
id: Optional[str] = "N/A"
value: Optional[str] = "N/A"
class Pagination(BaseModel):
"""Pagination in Apollo"""
class Config:
extra = "allow" # Allow extra fields
arbitrary_types_allowed = True # Allow any type
from_attributes = True # Allow from_orm
populate_by_name = True # Allow field aliases to work both ways
page: int = 0
per_page: int = 0
total_entries: int = 0
total_pages: int = 0
class DialerFlags(BaseModel):
"""A dialer flags in Apollo"""
country_name: str
country_enabled: bool
high_risk_calling_enabled: bool
potential_high_risk_number: bool
class PhoneNumber(BaseModel):
"""A phone number in Apollo"""
raw_number: str = ""
sanitized_number: str = ""
type: str = ""
position: int = 0
status: str = ""
dnc_status: str = ""
dnc_other_info: str = ""
dailer_flags: DialerFlags = DialerFlags(
country_name="",
country_enabled=True,
high_risk_calling_enabled=True,
potential_high_risk_number=True,
)
class Organization(BaseModel):
"""An organization in Apollo"""
class Config:
extra = "allow"
arbitrary_types_allowed = True
from_attributes = True
populate_by_name = True
id: Optional[str] = "N/A"
name: Optional[str] = "N/A"
website_url: Optional[str] = "N/A"
blog_url: Optional[str] = "N/A"
angellist_url: Optional[str] = "N/A"
linkedin_url: Optional[str] = "N/A"
twitter_url: Optional[str] = "N/A"
facebook_url: Optional[str] = "N/A"
primary_phone: Optional[PrimaryPhone] = PrimaryPhone(
number="N/A", source="N/A", sanitized_number="N/A"
)
languages: list[str] = []
alexa_ranking: Optional[int] = 0
phone: Optional[str] = "N/A"
linkedin_uid: Optional[str] = "N/A"
founded_year: Optional[int] = 0
publicly_traded_symbol: Optional[str] = "N/A"
publicly_traded_exchange: Optional[str] = "N/A"
logo_url: Optional[str] = "N/A"
chrunchbase_url: Optional[str] = "N/A"
primary_domain: Optional[str] = "N/A"
sanitized_phone: Optional[str] = "N/A"
owned_by_organization_id: Optional[str] = "N/A"
intent_strength: Optional[str] = "N/A"
show_intent: bool = True
has_intent_signal_account: Optional[bool] = True
intent_signal_account: Optional[str] = "N/A"
class Contact(BaseModel):
"""A contact in Apollo"""
class Config:
extra = "allow"
arbitrary_types_allowed = True
from_attributes = True
populate_by_name = True
contact_roles: list[Any] = []
id: Optional[str] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
name: Optional[str] = None
linkedin_url: Optional[str] = None
title: Optional[str] = None
contact_stage_id: Optional[str] = None
owner_id: Optional[str] = None
creator_id: Optional[str] = None
person_id: Optional[str] = None
email_needs_tickling: bool = True
organization_name: Optional[str] = None
source: Optional[str] = None
original_source: Optional[str] = None
organization_id: Optional[str] = None
headline: Optional[str] = None
photo_url: Optional[str] = None
present_raw_address: Optional[str] = None
linkededin_uid: Optional[str] = None
extrapolated_email_confidence: Optional[float] = None
salesforce_id: Optional[str] = None
salesforce_lead_id: Optional[str] = None
salesforce_contact_id: Optional[str] = None
saleforce_account_id: Optional[str] = None
crm_owner_id: Optional[str] = None
created_at: Optional[str] = None
emailer_campaign_ids: list[str] = []
direct_dial_status: Optional[str] = None
direct_dial_enrichment_failed_at: Optional[str] = None
email_status: Optional[str] = None
email_source: Optional[str] = None
account_id: Optional[str] = None
last_activity_date: Optional[str] = None
hubspot_vid: Optional[str] = None
hubspot_company_id: Optional[str] = None
crm_id: Optional[str] = None
sanitized_phone: Optional[str] = None
merged_crm_ids: Optional[str] = None
updated_at: Optional[str] = None
queued_for_crm_push: bool = True
suggested_from_rule_engine_config_id: Optional[str] = None
email_unsubscribed: Optional[str] = None
label_ids: list[Any] = []
has_pending_email_arcgate_request: bool = True
has_email_arcgate_request: bool = True
existence_level: Optional[str] = None
email: Optional[str] = None
email_from_customer: Optional[str] = None
typed_custom_fields: list[TypedCustomField] = []
custom_field_errors: Any = None
salesforce_record_id: Optional[str] = None
crm_record_url: Optional[str] = None
email_status_unavailable_reason: Optional[str] = None
email_true_status: Optional[str] = None
updated_email_true_status: bool = True
contact_rule_config_statuses: list[RuleConfigStatus] = []
source_display_name: Optional[str] = None
twitter_url: Optional[str] = None
contact_campaign_statuses: list[ContactCampaignStatus] = []
state: Optional[str] = None
city: Optional[str] = None
country: Optional[str] = None
account: Optional[Account] = None
contact_emails: list[ContactEmail] = []
organization: Optional[Organization] = None
employment_history: list[EmploymentHistory] = []
time_zone: Optional[str] = None
intent_strength: Optional[str] = None
show_intent: bool = True
phone_numbers: list[PhoneNumber] = []
account_phone_note: Optional[str] = None
free_domain: bool = True
is_likely_to_engage: bool = True
email_domain_catchall: bool = True
contact_job_change_event: Optional[str] = None
class SearchOrganizationsRequest(BaseModel):
"""Request for Apollo's search organizations API"""
organization_num_empoloyees_range: list[int] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
default=[0, 1000000],
)
organization_locations: list[str] = SchemaField(
description="""The location of the company headquarters. You can search across cities, US states, and countries.
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, any Boston-based companies will not appearch in your search results, even if they match other parameters.
To exclude companies based on location, use the organization_not_locations parameter.
""",
default=[],
)
organizations_not_locations: list[str] = SchemaField(
description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.
This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
""",
default=[],
)
q_organization_keyword_tags: list[str] = SchemaField(
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry."""
)
q_organization_name: str = SchemaField(
description="""Filter search results to include a specific company name.
If the value you enter for this parameter does not match with a company's name, the company will not appear in search results, even if it matches other parameters. Partial matches are accepted. For example, if you filter by the value marketing, a company called NY Marketing Unlimited would still be eligible as a search result, but NY Market Analysis would not be eligible."""
)
organization_ids: list[str] = SchemaField(
description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
To find IDs, identify the values for organization_id when you call this endpoint.""",
default=[],
)
max_results: int = SchemaField(
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
default=100,
ge=1,
le=50000,
advanced=True,
)
page: int = SchemaField(
description="""The page number of the Apollo data that you want to retrieve.
Use this parameter in combination with the per_page parameter to make search results for navigable and improve the performance of the endpoint.""",
default=1,
)
per_page: int = SchemaField(
description="""The number of search results that should be returned for each page. Limited the number of results per page improves the endpoint's performance.
Use the page parameter to search the different pages of data.""",
default=100,
)
class SearchOrganizationsResponse(BaseModel):
"""Response from Apollo's search organizations API"""
breadcrumbs: list[Breadcrumb] = []
partial_results_only: bool = True
has_join: bool = True
disable_eu_prospecting: bool = True
partial_results_limit: int = 0
pagination: Pagination = Pagination(
page=0, per_page=0, total_entries=0, total_pages=0
)
# no listed type on the API docs
accounts: list[Any] = []
organizations: list[Organization] = []
models_ids: list[str] = []
num_fetch_result: Optional[str] = "N/A"
derived_params: Optional[str] = "N/A"
class SearchPeopleRequest(BaseModel):
"""Request for Apollo's search people API"""
person_titles: list[str] = SchemaField(
description="""Job titles held by the people you want to find. For a person to be included in search results, they only need to match 1 of the job titles you add. Adding more job titles expands your search results.
Results also include job titles with the same terms, even if they are not exact matches. For example, searching for marketing manager might return people with the job title content marketing manager.
Use this parameter in combination with the person_seniorities[] parameter to find people based on specific job functions and seniority levels.
""",
default=[],
placeholder="marketing manager",
)
person_locations: list[str] = SchemaField(
description="""The location where people live. You can search across cities, US states, and countries.
To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
default=[],
)
person_seniorities: list[SenorityLevels] = SchemaField(
description="""The job seniority that people hold within their current employer. This enables you to find people that currently hold positions at certain reporting levels, such as Director level or senior IC level.
For a person to be included in search results, they only need to match 1 of the seniorities you add. Adding more seniorities expands your search results.
Searches only return results based on their current job title, so searching for Director-level employees only returns people that currently hold a Director-level title. If someone was previously a Director, but is currently a VP, they would not be included in your search results.
Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
default=[],
)
organization_locations: list[str] = SchemaField(
description="""The location of the company headquarters for a person's current employer. You can search across cities, US states, and countries.
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.
To find people based on their personal location, use the person_locations parameter.""",
default=[],
)
q_organization_domains: list[str] = SchemaField(
description="""The domain name for the person's employer. This can be the current employer or a previous employer. Do not include www., the @ symbol, or similar.
You can add multiple domains to search across companies.
Examples: apollo.io and microsoft.com""",
default=[],
)
contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
default=[],
)
organization_ids: list[str] = SchemaField(
description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
default=[],
)
organization_num_empoloyees_range: list[int] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
default=[],
)
q_keywords: str = SchemaField(
description="""A string of words over which we want to filter the results""",
default="",
)
page: int = SchemaField(
description="""The page number of the Apollo data that you want to retrieve.
Use this parameter in combination with the per_page parameter to make search results for navigable and improve the performance of the endpoint.""",
default=1,
)
per_page: int = SchemaField(
description="""The number of search results that should be returned for each page. Limited the number of results per page improves the endpoint's performance.
Use the page parameter to search the different pages of data.""",
default=100,
)
max_results: int = SchemaField(
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
default=100,
ge=1,
le=50000,
advanced=True,
)
class SearchPeopleResponse(BaseModel):
"""Response from Apollo's search people API"""
class Config:
extra = "allow" # Allow extra fields
arbitrary_types_allowed = True # Allow any type
from_attributes = True # Allow from_orm
populate_by_name = True # Allow field aliases to work both ways
breadcrumbs: list[Breadcrumb] = []
partial_results_only: bool = True
has_join: bool = True
disable_eu_prospecting: bool = True
partial_results_limit: int = 0
pagination: Pagination = Pagination(
page=0, per_page=0, total_entries=0, total_pages=0
)
contacts: list[Contact] = []
people: list[Contact] = []
model_ids: list[str] = []
num_fetch_result: Optional[str] = "N/A"
derived_params: Optional[str] = "N/A"

View File

@@ -1,219 +0,0 @@
from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
ApolloCredentials,
ApolloCredentialsInput,
)
from backend.blocks.apollo.models import (
Organization,
PrimaryPhone,
SearchOrganizationsRequest,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class SearchOrganizationsBlock(Block):
"""Search for organizations in Apollo"""
class Input(BlockSchema):
organization_num_empoloyees_range: list[int] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
default=[0, 1000000],
)
organization_locations: list[str] = SchemaField(
description="""The location of the company headquarters. You can search across cities, US states, and countries.
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, any Boston-based companies will not appearch in your search results, even if they match other parameters.
To exclude companies based on location, use the organization_not_locations parameter.
""",
default=[],
)
organizations_not_locations: list[str] = SchemaField(
description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.
This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
""",
default=[],
)
q_organization_keyword_tags: list[str] = SchemaField(
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry.""",
default=[],
)
q_organization_name: str = SchemaField(
description="""Filter search results to include a specific company name.
If the value you enter for this parameter does not match with a company's name, the company will not appear in search results, even if it matches other parameters. Partial matches are accepted. For example, if you filter by the value marketing, a company called NY Marketing Unlimited would still be eligible as a search result, but NY Market Analysis would not be eligible.""",
default="",
advanced=False,
)
organization_ids: list[str] = SchemaField(
description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
To find IDs, identify the values for organization_id when you call this endpoint.""",
default=[],
)
max_results: int = SchemaField(
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
default=100,
ge=1,
le=50000,
advanced=True,
)
credentials: ApolloCredentialsInput = SchemaField(
description="Apollo credentials",
)
class Output(BlockSchema):
organizations: list[Organization] = SchemaField(
description="List of organizations found",
default=[],
)
organization: Organization = SchemaField(
description="Each found organization, one at a time",
)
error: str = SchemaField(
description="Error message if the search failed",
default="",
)
def __init__(self):
super().__init__(
id="3d71270d-599e-4148-9b95-71b35d2f44f0",
description="Search for organizations in Apollo",
categories={BlockCategory.SEARCH},
input_schema=SearchOrganizationsBlock.Input,
output_schema=SearchOrganizationsBlock.Output,
test_credentials=TEST_CREDENTIALS,
test_input={"query": "Google", "credentials": TEST_CREDENTIALS_INPUT},
test_output=[
(
"organization",
Organization(
id="1",
name="Google",
website_url="https://google.com",
blog_url="https://google.com/blog",
angellist_url="https://angel.co/google",
linkedin_url="https://linkedin.com/company/google",
twitter_url="https://twitter.com/google",
facebook_url="https://facebook.com/google",
primary_phone=PrimaryPhone(
source="google",
number="1234567890",
sanitized_number="1234567890",
),
languages=["en"],
alexa_ranking=1000,
phone="1234567890",
linkedin_uid="1234567890",
founded_year=2000,
publicly_traded_symbol="GOOGL",
publicly_traded_exchange="NASDAQ",
logo_url="https://google.com/logo.png",
chrunchbase_url="https://chrunchbase.com/google",
primary_domain="google.com",
sanitized_phone="1234567890",
owned_by_organization_id="1",
intent_strength="strong",
show_intent=True,
has_intent_signal_account=True,
intent_signal_account="1",
),
),
(
"organizations",
[
Organization(
id="1",
name="Google",
website_url="https://google.com",
blog_url="https://google.com/blog",
angellist_url="https://angel.co/google",
linkedin_url="https://linkedin.com/company/google",
twitter_url="https://twitter.com/google",
facebook_url="https://facebook.com/google",
primary_phone=PrimaryPhone(
source="google",
number="1234567890",
sanitized_number="1234567890",
),
languages=["en"],
alexa_ranking=1000,
phone="1234567890",
linkedin_uid="1234567890",
founded_year=2000,
publicly_traded_symbol="GOOGL",
publicly_traded_exchange="NASDAQ",
logo_url="https://google.com/logo.png",
chrunchbase_url="https://chrunchbase.com/google",
primary_domain="google.com",
sanitized_phone="1234567890",
owned_by_organization_id="1",
intent_strength="strong",
show_intent=True,
has_intent_signal_account=True,
intent_signal_account="1",
),
],
),
],
test_mock={
"search_organizations": lambda *args, **kwargs: [
Organization(
id="1",
name="Google",
website_url="https://google.com",
blog_url="https://google.com/blog",
angellist_url="https://angel.co/google",
linkedin_url="https://linkedin.com/company/google",
twitter_url="https://twitter.com/google",
facebook_url="https://facebook.com/google",
primary_phone=PrimaryPhone(
source="google",
number="1234567890",
sanitized_number="1234567890",
),
languages=["en"],
alexa_ranking=1000,
phone="1234567890",
linkedin_uid="1234567890",
founded_year=2000,
publicly_traded_symbol="GOOGL",
publicly_traded_exchange="NASDAQ",
logo_url="https://google.com/logo.png",
chrunchbase_url="https://chrunchbase.com/google",
primary_domain="google.com",
sanitized_phone="1234567890",
owned_by_organization_id="1",
intent_strength="strong",
show_intent=True,
has_intent_signal_account=True,
intent_signal_account="1",
)
]
},
)
@staticmethod
def search_organizations(
query: SearchOrganizationsRequest, credentials: ApolloCredentials
) -> list[Organization]:
client = ApolloClient(credentials)
return client.search_organizations(query)
def run(
self, input_data: Input, *, credentials: ApolloCredentials, **kwargs
) -> BlockOutput:
query = SearchOrganizationsRequest(
**input_data.model_dump(exclude={"credentials"})
)
organizations = self.search_organizations(query, credentials)
for organization in organizations:
yield "organization", organization
yield "organizations", organizations

View File

@@ -1,394 +0,0 @@
from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
ApolloCredentials,
ApolloCredentialsInput,
)
from backend.blocks.apollo.models import (
Contact,
ContactEmailStatuses,
SearchPeopleRequest,
SenorityLevels,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class SearchPeopleBlock(Block):
"""Search for people in Apollo"""
class Input(BlockSchema):
person_titles: list[str] = SchemaField(
description="""Job titles held by the people you want to find. For a person to be included in search results, they only need to match 1 of the job titles you add. Adding more job titles expands your search results.
Results also include job titles with the same terms, even if they are not exact matches. For example, searching for marketing manager might return people with the job title content marketing manager.
Use this parameter in combination with the person_seniorities[] parameter to find people based on specific job functions and seniority levels.
""",
default=[],
advanced=False,
)
person_locations: list[str] = SchemaField(
description="""The location where people live. You can search across cities, US states, and countries.
To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
default=[],
advanced=False,
)
person_seniorities: list[SenorityLevels] = SchemaField(
description="""The job seniority that people hold within their current employer. This enables you to find people that currently hold positions at certain reporting levels, such as Director level or senior IC level.
For a person to be included in search results, they only need to match 1 of the seniorities you add. Adding more seniorities expands your search results.
Searches only return results based on their current job title, so searching for Director-level employees only returns people that currently hold a Director-level title. If someone was previously a Director, but is currently a VP, they would not be included in your search results.
Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
default=[],
advanced=False,
)
organization_locations: list[str] = SchemaField(
description="""The location of the company headquarters for a person's current employer. You can search across cities, US states, and countries.
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.
To find people based on their personal location, use the person_locations parameter.""",
default=[],
advanced=False,
)
q_organization_domains: list[str] = SchemaField(
description="""The domain name for the person's employer. This can be the current employer or a previous employer. Do not include www., the @ symbol, or similar.
You can add multiple domains to search across companies.
Examples: apollo.io and microsoft.com""",
default=[],
advanced=False,
)
contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
default=[],
advanced=False,
)
organization_ids: list[str] = SchemaField(
description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
default=[],
advanced=False,
)
organization_num_empoloyees_range: list[int] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
default=[],
advanced=False,
)
q_keywords: str = SchemaField(
description="""A string of words over which we want to filter the results""",
default="",
advanced=False,
)
max_results: int = SchemaField(
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
default=100,
ge=1,
le=50000,
advanced=True,
)
credentials: ApolloCredentialsInput = SchemaField(
description="Apollo credentials",
)
class Output(BlockSchema):
people: list[Contact] = SchemaField(
description="List of people found",
default=[],
)
person: Contact = SchemaField(
description="Each found person, one at a time",
)
error: str = SchemaField(
description="Error message if the search failed",
default="",
)
def __init__(self):
super().__init__(
id="c2adb3aa-5aae-488d-8a6e-4eb8c23e2ed6",
description="Search for people in Apollo",
categories={BlockCategory.SEARCH},
input_schema=SearchPeopleBlock.Input,
output_schema=SearchPeopleBlock.Output,
test_credentials=TEST_CREDENTIALS,
test_input={"credentials": TEST_CREDENTIALS_INPUT},
test_output=[
(
"person",
Contact(
contact_roles=[],
id="1",
name="John Doe",
first_name="John",
last_name="Doe",
linkedin_url="https://www.linkedin.com/in/johndoe",
title="Software Engineer",
organization_name="Google",
organization_id="123456",
contact_stage_id="1",
owner_id="1",
creator_id="1",
person_id="1",
email_needs_tickling=True,
source="apollo",
original_source="apollo",
headline="Software Engineer",
photo_url="https://www.linkedin.com/in/johndoe",
present_raw_address="123 Main St, Anytown, USA",
linkededin_uid="123456",
extrapolated_email_confidence=0.8,
salesforce_id="123456",
salesforce_lead_id="123456",
salesforce_contact_id="123456",
saleforce_account_id="123456",
crm_owner_id="123456",
created_at="2021-01-01",
emailer_campaign_ids=[],
direct_dial_status="active",
direct_dial_enrichment_failed_at="2021-01-01",
email_status="active",
email_source="apollo",
account_id="123456",
last_activity_date="2021-01-01",
hubspot_vid="123456",
hubspot_company_id="123456",
crm_id="123456",
sanitized_phone="123456",
merged_crm_ids="123456",
updated_at="2021-01-01",
queued_for_crm_push=True,
suggested_from_rule_engine_config_id="123456",
email_unsubscribed=None,
label_ids=[],
has_pending_email_arcgate_request=True,
has_email_arcgate_request=True,
existence_level=None,
email=None,
email_from_customer=None,
typed_custom_fields=[],
custom_field_errors=None,
salesforce_record_id=None,
crm_record_url=None,
email_status_unavailable_reason=None,
email_true_status=None,
updated_email_true_status=True,
contact_rule_config_statuses=[],
source_display_name=None,
twitter_url=None,
contact_campaign_statuses=[],
state=None,
city=None,
country=None,
account=None,
contact_emails=[],
organization=None,
employment_history=[],
time_zone=None,
intent_strength=None,
show_intent=True,
phone_numbers=[],
account_phone_note=None,
free_domain=True,
is_likely_to_engage=True,
email_domain_catchall=True,
contact_job_change_event=None,
),
),
(
"people",
[
Contact(
contact_roles=[],
id="1",
name="John Doe",
first_name="John",
last_name="Doe",
linkedin_url="https://www.linkedin.com/in/johndoe",
title="Software Engineer",
organization_name="Google",
organization_id="123456",
contact_stage_id="1",
owner_id="1",
creator_id="1",
person_id="1",
email_needs_tickling=True,
source="apollo",
original_source="apollo",
headline="Software Engineer",
photo_url="https://www.linkedin.com/in/johndoe",
present_raw_address="123 Main St, Anytown, USA",
linkededin_uid="123456",
extrapolated_email_confidence=0.8,
salesforce_id="123456",
salesforce_lead_id="123456",
salesforce_contact_id="123456",
saleforce_account_id="123456",
crm_owner_id="123456",
created_at="2021-01-01",
emailer_campaign_ids=[],
direct_dial_status="active",
direct_dial_enrichment_failed_at="2021-01-01",
email_status="active",
email_source="apollo",
account_id="123456",
last_activity_date="2021-01-01",
hubspot_vid="123456",
hubspot_company_id="123456",
crm_id="123456",
sanitized_phone="123456",
merged_crm_ids="123456",
updated_at="2021-01-01",
queued_for_crm_push=True,
suggested_from_rule_engine_config_id="123456",
email_unsubscribed=None,
label_ids=[],
has_pending_email_arcgate_request=True,
has_email_arcgate_request=True,
existence_level=None,
email=None,
email_from_customer=None,
typed_custom_fields=[],
custom_field_errors=None,
salesforce_record_id=None,
crm_record_url=None,
email_status_unavailable_reason=None,
email_true_status=None,
updated_email_true_status=True,
contact_rule_config_statuses=[],
source_display_name=None,
twitter_url=None,
contact_campaign_statuses=[],
state=None,
city=None,
country=None,
account=None,
contact_emails=[],
organization=None,
employment_history=[],
time_zone=None,
intent_strength=None,
show_intent=True,
phone_numbers=[],
account_phone_note=None,
free_domain=True,
is_likely_to_engage=True,
email_domain_catchall=True,
contact_job_change_event=None,
),
],
),
],
test_mock={
"search_people": lambda query, credentials: [
Contact(
id="1",
name="John Doe",
first_name="John",
last_name="Doe",
linkedin_url="https://www.linkedin.com/in/johndoe",
title="Software Engineer",
organization_name="Google",
organization_id="123456",
contact_stage_id="1",
owner_id="1",
creator_id="1",
person_id="1",
email_needs_tickling=True,
source="apollo",
original_source="apollo",
headline="Software Engineer",
photo_url="https://www.linkedin.com/in/johndoe",
present_raw_address="123 Main St, Anytown, USA",
linkededin_uid="123456",
extrapolated_email_confidence=0.8,
salesforce_id="123456",
salesforce_lead_id="123456",
salesforce_contact_id="123456",
saleforce_account_id="123456",
crm_owner_id="123456",
created_at="2021-01-01",
emailer_campaign_ids=[],
direct_dial_status="active",
direct_dial_enrichment_failed_at="2021-01-01",
email_status="active",
email_source="apollo",
account_id="123456",
last_activity_date="2021-01-01",
hubspot_vid="123456",
hubspot_company_id="123456",
crm_id="123456",
sanitized_phone="123456",
merged_crm_ids="123456",
updated_at="2021-01-01",
queued_for_crm_push=True,
suggested_from_rule_engine_config_id="123456",
email_unsubscribed=None,
label_ids=[],
has_pending_email_arcgate_request=True,
has_email_arcgate_request=True,
existence_level=None,
email=None,
email_from_customer=None,
typed_custom_fields=[],
custom_field_errors=None,
salesforce_record_id=None,
crm_record_url=None,
email_status_unavailable_reason=None,
email_true_status=None,
updated_email_true_status=True,
contact_rule_config_statuses=[],
source_display_name=None,
twitter_url=None,
contact_campaign_statuses=[],
state=None,
city=None,
country=None,
account=None,
contact_emails=[],
organization=None,
employment_history=[],
time_zone=None,
intent_strength=None,
show_intent=True,
phone_numbers=[],
account_phone_note=None,
free_domain=True,
is_likely_to_engage=True,
email_domain_catchall=True,
contact_job_change_event=None,
),
]
},
)
@staticmethod
def search_people(
query: SearchPeopleRequest, credentials: ApolloCredentials
) -> list[Contact]:
client = ApolloClient(credentials)
return client.search_people(query)
def run(
self,
input_data: Input,
*,
credentials: ApolloCredentials,
**kwargs,
) -> BlockOutput:
query = SearchPeopleRequest(**input_data.model_dump(exclude={"credentials"}))
people = self.search_people(query, credentials)
for person in people:
yield "person", person
yield "people", people

View File

@@ -3,20 +3,22 @@ from typing import Any, List
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
from backend.data.model import SchemaField
from backend.util import json
from backend.util.file import store_media_file
from backend.util.file import MediaFile, store_media_file
from backend.util.mock import MockObject
from backend.util.type import MediaFileType, convert
from backend.util.text import TextFormatter
from backend.util.type import convert
formatter = TextFormatter()
class FileStoreBlock(Block):
class Input(BlockSchema):
file_in: MediaFileType = SchemaField(
file_in: MediaFile = SchemaField(
description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
)
class Output(BlockSchema):
file_out: MediaFileType = SchemaField(
file_out: MediaFile = SchemaField(
description="The relative path to the stored file in the temporary directory."
)
@@ -88,6 +90,29 @@ class StoreValueBlock(Block):
yield "output", input_data.data or input_data.input
class PrintToConsoleBlock(Block):
class Input(BlockSchema):
text: str = SchemaField(description="The text to print to the console.")
class Output(BlockSchema):
status: str = SchemaField(description="The status of the print operation.")
def __init__(self):
super().__init__(
id="f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c",
description="Print the given text to the console, this is used for a debugging purpose.",
categories={BlockCategory.BASIC},
input_schema=PrintToConsoleBlock.Input,
output_schema=PrintToConsoleBlock.Output,
test_input={"text": "Hello, World!"},
test_output=("status", "printed"),
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
print(">>>>> Print: ", input_data.text)
yield "status", "printed"
class FindInDictionaryBlock(Block):
class Input(BlockSchema):
input: Any = SchemaField(description="Dictionary to lookup from")
@@ -128,9 +153,6 @@ class FindInDictionaryBlock(Block):
obj = input_data.input
key = input_data.key
if isinstance(obj, str):
obj = json.loads(obj)
if isinstance(obj, dict) and key in obj:
yield "output", obj[key]
elif isinstance(obj, list) and isinstance(key, int) and 0 <= key < len(obj):
@@ -148,6 +170,188 @@ class FindInDictionaryBlock(Block):
yield "missing", input_data.input
class AgentInputBlock(Block):
"""
This block is used to provide input to the graph.
It takes in a value, name, description, default values list and bool to limit selection to default values.
It Outputs the value passed as input.
"""
class Input(BlockSchema):
name: str = SchemaField(description="The name of the input.")
value: Any = SchemaField(
description="The value to be passed as input.",
default=None,
)
title: str | None = SchemaField(
description="The title of the input.", default=None, advanced=True
)
description: str | None = SchemaField(
description="The description of the input.",
default=None,
advanced=True,
)
placeholder_values: List[Any] = SchemaField(
description="The placeholder values to be passed as input.",
default=[],
advanced=True,
)
limit_to_placeholder_values: bool = SchemaField(
description="Whether to limit the selection to placeholder values.",
default=False,
advanced=True,
)
advanced: bool = SchemaField(
description="Whether to show the input in the advanced section, if the field is not required.",
default=False,
advanced=True,
)
secret: bool = SchemaField(
description="Whether the input should be treated as a secret.",
default=False,
advanced=True,
)
class Output(BlockSchema):
result: Any = SchemaField(description="The value passed as input.")
def __init__(self):
super().__init__(
id="c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
description="This block is used to provide input to the graph.",
input_schema=AgentInputBlock.Input,
output_schema=AgentInputBlock.Output,
test_input=[
{
"value": "Hello, World!",
"name": "input_1",
"description": "This is a test input.",
"placeholder_values": [],
"limit_to_placeholder_values": False,
},
{
"value": "Hello, World!",
"name": "input_2",
"description": "This is a test input.",
"placeholder_values": ["Hello, World!"],
"limit_to_placeholder_values": True,
},
],
test_output=[
("result", "Hello, World!"),
("result", "Hello, World!"),
],
categories={BlockCategory.INPUT, BlockCategory.BASIC},
block_type=BlockType.INPUT,
static_output=True,
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
yield "result", input_data.value
class AgentOutputBlock(Block):
"""
Records the output of the graph for users to see.
Behavior:
If `format` is provided and the `value` is of a type that can be formatted,
the block attempts to format the recorded_value using the `format`.
If formatting fails or no `format` is provided, the raw `value` is output.
"""
class Input(BlockSchema):
value: Any = SchemaField(
description="The value to be recorded as output.",
default=None,
advanced=False,
)
name: str = SchemaField(description="The name of the output.")
title: str | None = SchemaField(
description="The title of the output.",
default=None,
advanced=True,
)
description: str | None = SchemaField(
description="The description of the output.",
default=None,
advanced=True,
)
format: str = SchemaField(
description="The format string to be used to format the recorded_value. Use Jinja2 syntax.",
default="",
advanced=True,
)
advanced: bool = SchemaField(
description="Whether to treat the output as advanced.",
default=False,
advanced=True,
)
secret: bool = SchemaField(
description="Whether the output should be treated as a secret.",
default=False,
advanced=True,
)
class Output(BlockSchema):
output: Any = SchemaField(description="The value recorded as output.")
name: Any = SchemaField(description="The name of the value recorded as output.")
def __init__(self):
super().__init__(
id="363ae599-353e-4804-937e-b2ee3cef3da4",
description="Stores the output of the graph for users to see.",
input_schema=AgentOutputBlock.Input,
output_schema=AgentOutputBlock.Output,
test_input=[
{
"value": "Hello, World!",
"name": "output_1",
"description": "This is a test output.",
"format": "{{ output_1 }}!!",
},
{
"value": "42",
"name": "output_2",
"description": "This is another test output.",
"format": "{{ output_2 }}",
},
{
"value": MockObject(value="!!", key="key"),
"name": "output_3",
"description": "This is a test output with a mock object.",
"format": "{{ output_3 }}",
},
],
test_output=[
("output", "Hello, World!!!"),
("output", "42"),
("output", MockObject(value="!!", key="key")),
],
categories={BlockCategory.OUTPUT, BlockCategory.BASIC},
block_type=BlockType.OUTPUT,
static_output=True,
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
"""
Attempts to format the recorded_value using the fmt_string if provided.
If formatting fails or no fmt_string is given, returns the original recorded_value.
"""
if input_data.format:
try:
yield "output", formatter.format_string(
input_data.format, {input_data.name: input_data.value}
)
except Exception as e:
yield "output", f"Error: {e}, {input_data.value}"
else:
yield "output", input_data.value
yield "name", input_data.name
class AddToDictionaryBlock(Block):
class Input(BlockSchema):
dictionary: dict[Any, Any] = SchemaField(

View File

@@ -8,7 +8,6 @@ from backend.data.block import (
BlockSchema,
)
from backend.data.model import SchemaField
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.compass import CompassWebhookType
@@ -43,7 +42,7 @@ class CompassAITriggerBlock(Block):
input_schema=CompassAITriggerBlock.Input,
output_schema=CompassAITriggerBlock.Output,
webhook_config=BlockManualWebhookConfig(
provider=ProviderName.COMPASS,
provider="compass",
webhook_type=CompassWebhookType.TRANSCRIPTION,
),
test_input=[

View File

@@ -51,7 +51,6 @@ class ExaContentsBlock(Block):
description="List of document contents",
default=[],
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(

View File

@@ -38,59 +38,6 @@ def _get_headers(credentials: GithubCredentials) -> dict[str, str]:
}
def convert_comment_url_to_api_endpoint(comment_url: str) -> str:
"""
Converts a GitHub comment URL (web interface) to the appropriate API endpoint URL.
Handles:
1. Issue/PR comments: #issuecomment-{id}
2. PR review comments: #discussion_r{id}
Returns the appropriate API endpoint path for the comment.
"""
# First, check if this is already an API URL
parsed_url = urlparse(comment_url)
if parsed_url.hostname == "api.github.com":
return comment_url
# Replace pull with issues for comment endpoints
if "/pull/" in comment_url:
comment_url = comment_url.replace("/pull/", "/issues/")
# Handle issue/PR comments (#issuecomment-xxx)
if "#issuecomment-" in comment_url:
base_url, comment_part = comment_url.split("#issuecomment-")
comment_id = comment_part
# Extract repo information from base URL
parsed_url = urlparse(base_url)
path_parts = parsed_url.path.strip("/").split("/")
owner, repo = path_parts[0], path_parts[1]
# Construct API URL for issue comments
return (
f"https://api.github.com/repos/{owner}/{repo}/issues/comments/{comment_id}"
)
# Handle PR review comments (#discussion_r)
elif "#discussion_r" in comment_url:
base_url, comment_part = comment_url.split("#discussion_r")
comment_id = comment_part
# Extract repo information from base URL
parsed_url = urlparse(base_url)
path_parts = parsed_url.path.strip("/").split("/")
owner, repo = path_parts[0], path_parts[1]
# Construct API URL for PR review comments
return (
f"https://api.github.com/repos/{owner}/{repo}/pulls/comments/{comment_id}"
)
# If no specific comment identifiers are found, use the general URL conversion
return _convert_to_api_url(comment_url)
def get_api(
credentials: GithubCredentials | GithubFineGrainedAPICredentials,
convert_urls: bool = True,

View File

@@ -172,9 +172,7 @@ class GithubCreateCheckRunBlock(Block):
data.output = output_data
check_runs_url = f"{repo_url}/check-runs"
response = api.post(
check_runs_url, data=data.model_dump_json(exclude_none=True)
)
response = api.post(check_runs_url)
result = response.json()
return {
@@ -325,9 +323,7 @@ class GithubUpdateCheckRunBlock(Block):
data.output = output_data
check_run_url = f"{repo_url}/check-runs/{check_run_id}"
response = api.patch(
check_run_url, data=data.model_dump_json(exclude_none=True)
)
response = api.patch(check_run_url)
result = response.json()
return {

View File

@@ -1,4 +1,3 @@
import logging
from urllib.parse import urlparse
from typing_extensions import TypedDict
@@ -6,7 +5,7 @@ from typing_extensions import TypedDict
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from ._api import convert_comment_url_to_api_endpoint, get_api
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
@@ -15,8 +14,6 @@ from ._auth import (
GithubCredentialsInput,
)
logger = logging.getLogger(__name__)
def is_github_url(url: str) -> bool:
return urlparse(url).netloc == "github.com"
@@ -111,228 +108,6 @@ class GithubCommentBlock(Block):
# --8<-- [end:GithubCommentBlockExample]
class GithubUpdateCommentBlock(Block):
class Input(BlockSchema):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
comment_url: str = SchemaField(
description="URL of the GitHub comment",
placeholder="https://github.com/owner/repo/issues/1#issuecomment-123456789",
default="",
advanced=False,
)
issue_url: str = SchemaField(
description="URL of the GitHub issue or pull request",
placeholder="https://github.com/owner/repo/issues/1",
default="",
)
comment_id: str = SchemaField(
description="ID of the GitHub comment",
placeholder="123456789",
default="",
)
comment: str = SchemaField(
description="Comment to update",
placeholder="Enter your comment",
)
class Output(BlockSchema):
id: int = SchemaField(description="ID of the updated comment")
url: str = SchemaField(description="URL to the comment on GitHub")
error: str = SchemaField(
description="Error message if the comment update failed"
)
def __init__(self):
super().__init__(
id="b3f4d747-10e3-4e69-8c51-f2be1d99c9a7",
description="This block updates a comment on a specified GitHub issue or pull request.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubUpdateCommentBlock.Input,
output_schema=GithubUpdateCommentBlock.Output,
test_input={
"comment_url": "https://github.com/owner/repo/issues/1#issuecomment-123456789",
"comment": "This is an updated comment.",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("id", 123456789),
(
"url",
"https://github.com/owner/repo/issues/1#issuecomment-123456789",
),
],
test_mock={
"update_comment": lambda *args, **kwargs: (
123456789,
"https://github.com/owner/repo/issues/1#issuecomment-123456789",
)
},
)
@staticmethod
def update_comment(
credentials: GithubCredentials, comment_url: str, body_text: str
) -> tuple[int, str]:
api = get_api(credentials, convert_urls=False)
data = {"body": body_text}
url = convert_comment_url_to_api_endpoint(comment_url)
logger.info(url)
response = api.patch(url, json=data)
comment = response.json()
return comment["id"], comment["html_url"]
def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
if (
not input_data.comment_url
and input_data.comment_id
and input_data.issue_url
):
parsed_url = urlparse(input_data.issue_url)
path_parts = parsed_url.path.strip("/").split("/")
owner, repo = path_parts[0], path_parts[1]
input_data.comment_url = f"https://api.github.com/repos/{owner}/{repo}/issues/comments/{input_data.comment_id}"
elif (
not input_data.comment_url
and not input_data.comment_id
and input_data.issue_url
):
raise ValueError(
"Must provide either comment_url or comment_id and issue_url"
)
id, url = self.update_comment(
credentials,
input_data.comment_url,
input_data.comment,
)
yield "id", id
yield "url", url
class GithubListCommentsBlock(Block):
class Input(BlockSchema):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
issue_url: str = SchemaField(
description="URL of the GitHub issue or pull request",
placeholder="https://github.com/owner/repo/issues/1",
)
class Output(BlockSchema):
class CommentItem(TypedDict):
id: int
body: str
user: str
url: str
comment: CommentItem = SchemaField(
title="Comment", description="Comments with their ID, body, user, and URL"
)
comments: list[CommentItem] = SchemaField(
description="List of comments with their ID, body, user, and URL"
)
error: str = SchemaField(description="Error message if listing comments failed")
def __init__(self):
super().__init__(
id="c4b5fb63-0005-4a11-b35a-0c2467bd6b59",
description="This block lists all comments for a specified GitHub issue or pull request.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListCommentsBlock.Input,
output_schema=GithubListCommentsBlock.Output,
test_input={
"issue_url": "https://github.com/owner/repo/issues/1",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"comment",
{
"id": 123456789,
"body": "This is a test comment.",
"user": "test_user",
"url": "https://github.com/owner/repo/issues/1#issuecomment-123456789",
},
),
(
"comments",
[
{
"id": 123456789,
"body": "This is a test comment.",
"user": "test_user",
"url": "https://github.com/owner/repo/issues/1#issuecomment-123456789",
}
],
),
],
test_mock={
"list_comments": lambda *args, **kwargs: [
{
"id": 123456789,
"body": "This is a test comment.",
"user": "test_user",
"url": "https://github.com/owner/repo/issues/1#issuecomment-123456789",
}
]
},
)
@staticmethod
def list_comments(
credentials: GithubCredentials, issue_url: str
) -> list[Output.CommentItem]:
parsed_url = urlparse(issue_url)
path_parts = parsed_url.path.strip("/").split("/")
owner = path_parts[0]
repo = path_parts[1]
# GitHub API uses 'issues' for both issues and pull requests when it comes to comments
issue_number = path_parts[3] # Whether 'issues/123' or 'pull/123'
# Construct the proper API URL directly
api_url = f"https://api.github.com/repos/{owner}/{repo}/issues/{issue_number}/comments"
# Set convert_urls=False since we're already providing an API URL
api = get_api(credentials, convert_urls=False)
response = api.get(api_url)
comments = response.json()
parsed_comments: list[GithubListCommentsBlock.Output.CommentItem] = [
{
"id": comment["id"],
"body": comment["body"],
"user": comment["user"]["login"],
"url": comment["html_url"],
}
for comment in comments
]
return parsed_comments
def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
comments = self.list_comments(
credentials,
input_data.issue_url,
)
yield from (("comment", comment) for comment in comments)
yield "comments", comments
class GithubMakeIssueBlock(Block):
class Input(BlockSchema):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")

View File

@@ -144,7 +144,7 @@ class GithubCreateStatusBlock(Block):
data.description = description
status_url = f"{repo_url}/statuses/{sha}"
response = api.post(status_url, data=data.model_dump_json(exclude_none=True))
response = api.post(status_url, json=data)
result = response.json()
return {

View File

@@ -12,7 +12,6 @@ from backend.data.block import (
BlockWebhookConfig,
)
from backend.data.model import SchemaField
from backend.integrations.providers import ProviderName
from ._auth import (
TEST_CREDENTIALS,
@@ -124,7 +123,7 @@ class GithubPullRequestTriggerBlock(GitHubTriggerBase, Block):
output_schema=GithubPullRequestTriggerBlock.Output,
# --8<-- [start:example-webhook_config]
webhook_config=BlockWebhookConfig(
provider=ProviderName.GITHUB,
provider="github",
webhook_type=GithubWebhookType.REPO,
resource_format="{repo}",
event_filter_input="events",

View File

@@ -8,7 +8,6 @@ from pydantic import BaseModel
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.settings import Settings
from ._auth import (
GOOGLE_OAUTH_IS_CONFIGURED,
@@ -151,8 +150,8 @@ class GmailReadBlock(Block):
else None
),
token_uri="https://oauth2.googleapis.com/token",
client_id=Settings().secrets.google_client_id,
client_secret=Settings().secrets.google_client_secret,
client_id=kwargs.get("client_id"),
client_secret=kwargs.get("client_secret"),
scopes=credentials.scopes,
)
return build("gmail", "v1", credentials=creds)

View File

@@ -3,7 +3,6 @@ from googleapiclient.discovery import build
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.settings import Settings
from ._auth import (
GOOGLE_OAUTH_IS_CONFIGURED,
@@ -87,8 +86,8 @@ class GoogleSheetsReadBlock(Block):
else None
),
token_uri="https://oauth2.googleapis.com/token",
client_id=Settings().secrets.google_client_id,
client_secret=Settings().secrets.google_client_secret,
client_id=kwargs.get("client_id"),
client_secret=kwargs.get("client_secret"),
scopes=credentials.scopes,
)
return build("sheets", "v4", credentials=creds)

View File

@@ -1,16 +1,11 @@
import json
import logging
from enum import Enum
from typing import Any
from requests.exceptions import HTTPError, RequestException
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import requests
logger = logging.getLogger(name=__name__)
class HttpMethod(Enum):
GET = "GET"
@@ -48,9 +43,8 @@ class SendWebRequestBlock(Block):
class Output(BlockSchema):
response: object = SchemaField(description="The response from the server")
client_error: object = SchemaField(description="Errors on 4xx status codes")
server_error: object = SchemaField(description="Errors on 5xx status codes")
error: str = SchemaField(description="Errors for all other exceptions")
client_error: object = SchemaField(description="The error on 4xx status codes")
server_error: object = SchemaField(description="The error on 5xx status codes")
def __init__(self):
super().__init__(
@@ -74,40 +68,20 @@ class SendWebRequestBlock(Block):
# we should send it as plain text instead
input_data.json_format = False
try:
response = requests.request(
input_data.method.value,
input_data.url,
headers=input_data.headers,
json=body if input_data.json_format else None,
data=body if not input_data.json_format else None,
)
result = response.json() if input_data.json_format else response.text
response = requests.request(
input_data.method.value,
input_data.url,
headers=input_data.headers,
json=body if input_data.json_format else None,
data=body if not input_data.json_format else None,
)
result = response.json() if input_data.json_format else response.text
if response.status_code // 100 == 2:
yield "response", result
except HTTPError as e:
# Handle error responses
try:
result = e.response.json() if input_data.json_format else str(e)
except json.JSONDecodeError:
result = str(e)
if 400 <= e.response.status_code < 500:
yield "client_error", result
elif 500 <= e.response.status_code < 600:
yield "server_error", result
else:
error_msg = (
"Unexpected status code "
f"{e.response.status_code} '{e.response.reason}'"
)
logger.warning(error_msg)
yield "error", error_msg
except RequestException as e:
# Handle other request-related exceptions
yield "error", str(e)
except Exception as e:
# Catch any other unexpected exceptions
yield "error", str(e)
elif response.status_code // 100 == 4:
yield "client_error", result
elif response.status_code // 100 == 5:
yield "server_error", result
else:
raise ValueError(f"Unexpected status code: {response.status_code}")

View File

@@ -142,16 +142,6 @@ class IdeogramModelBlock(Block):
title="Color Palette Preset",
advanced=True,
)
custom_color_palette: Optional[list[str]] = SchemaField(
description=(
"Only available for model version V_2 or V_2_TURBO. Provide one or more color hex codes "
"(e.g., ['#000030', '#1C0C47', '#9900FF', '#4285F4', '#FFFFFF']) to define a custom color "
"palette. Only used if 'color_palette_name' is 'NONE'."
),
default=None,
title="Custom Color Palette",
advanced=True,
)
class Output(BlockSchema):
result: str = SchemaField(description="Generated image URL")
@@ -174,13 +164,6 @@ class IdeogramModelBlock(Block):
"style_type": StyleType.AUTO,
"negative_prompt": None,
"color_palette_name": ColorPalettePreset.NONE,
"custom_color_palette": [
"#000030",
"#1C0C47",
"#9900FF",
"#4285F4",
"#FFFFFF",
],
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
@@ -190,7 +173,7 @@ class IdeogramModelBlock(Block):
),
],
test_mock={
"run_model": lambda api_key, model_name, prompt, seed, aspect_ratio, magic_prompt_option, style_type, negative_prompt, color_palette_name, custom_colors: "https://ideogram.ai/api/images/test-generated-image-url.png",
"run_model": lambda api_key, model_name, prompt, seed, aspect_ratio, magic_prompt_option, style_type, negative_prompt, color_palette_name: "https://ideogram.ai/api/images/test-generated-image-url.png",
"upscale_image": lambda api_key, image_url: "https://ideogram.ai/api/images/test-upscaled-image-url.png",
},
test_credentials=TEST_CREDENTIALS,
@@ -212,7 +195,6 @@ class IdeogramModelBlock(Block):
style_type=input_data.style_type.value,
negative_prompt=input_data.negative_prompt,
color_palette_name=input_data.color_palette_name.value,
custom_colors=input_data.custom_color_palette,
)
# Step 2: Upscale the image if requested
@@ -235,7 +217,6 @@ class IdeogramModelBlock(Block):
style_type: str,
negative_prompt: Optional[str],
color_palette_name: str,
custom_colors: Optional[list[str]],
):
url = "https://api.ideogram.ai/generate"
headers = {
@@ -260,11 +241,7 @@ class IdeogramModelBlock(Block):
data["image_request"]["negative_prompt"] = negative_prompt
if color_palette_name != "NONE":
data["color_palette"] = {"name": color_palette_name}
elif custom_colors:
data["color_palette"] = {
"members": [{"color_hex": color} for color in custom_colors]
}
data["image_request"]["color_palette"] = {"name": color_palette_name}
try:
response = requests.post(url, json=data, headers=headers)
@@ -290,7 +267,9 @@ class IdeogramModelBlock(Block):
response = requests.post(
url,
headers=headers,
data={"image_request": "{}"},
data={
"image_request": "{}", # Empty JSON object
},
files=files,
)

View File

@@ -1,555 +0,0 @@
from datetime import date, time
from typing import Any, Optional
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
from backend.data.model import SchemaField
from backend.util.file import store_media_file
from backend.util.mock import MockObject
from backend.util.settings import Config
from backend.util.text import TextFormatter
from backend.util.type import LongTextType, MediaFileType, ShortTextType
formatter = TextFormatter()
config = Config()
class AgentInputBlock(Block):
"""
This block is used to provide input to the graph.
It takes in a value, name, description, default values list and bool to limit selection to default values.
It Outputs the value passed as input.
"""
class Input(BlockSchema):
name: str = SchemaField(description="The name of the input.")
value: Any = SchemaField(
description="The value to be passed as input.",
default=None,
)
title: str | None = SchemaField(
description="The title of the input.", default=None, advanced=True
)
description: str | None = SchemaField(
description="The description of the input.",
default=None,
advanced=True,
)
placeholder_values: list = SchemaField(
description="The placeholder values to be passed as input.",
default=[],
advanced=True,
hidden=True,
)
advanced: bool = SchemaField(
description="Whether to show the input in the advanced section, if the field is not required.",
default=False,
advanced=True,
)
secret: bool = SchemaField(
description="Whether the input should be treated as a secret.",
default=False,
advanced=True,
)
def generate_schema(self):
schema = self.get_field_schema("value")
if possible_values := self.placeholder_values:
schema["enum"] = possible_values
return schema
class Output(BlockSchema):
result: Any = SchemaField(description="The value passed as input.")
def __init__(self, **kwargs):
super().__init__(
**{
"id": "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
"description": "Base block for user inputs.",
"input_schema": AgentInputBlock.Input,
"output_schema": AgentInputBlock.Output,
"test_input": [
{
"value": "Hello, World!",
"name": "input_1",
"description": "Example test input.",
"placeholder_values": [],
},
{
"value": "Hello, World!",
"name": "input_2",
"description": "Example test input with placeholders.",
"placeholder_values": ["Hello, World!"],
},
],
"test_output": [
("result", "Hello, World!"),
("result", "Hello, World!"),
],
"categories": {BlockCategory.INPUT, BlockCategory.BASIC},
"block_type": BlockType.INPUT,
"static_output": True,
**kwargs,
}
)
def run(self, input_data: Input, *args, **kwargs) -> BlockOutput:
if input_data.value is not None:
yield "result", input_data.value
class AgentOutputBlock(Block):
"""
Records the output of the graph for users to see.
Behavior:
If `format` is provided and the `value` is of a type that can be formatted,
the block attempts to format the recorded_value using the `format`.
If formatting fails or no `format` is provided, the raw `value` is output.
"""
class Input(BlockSchema):
value: Any = SchemaField(
description="The value to be recorded as output.",
default=None,
advanced=False,
)
name: str = SchemaField(description="The name of the output.")
title: str | None = SchemaField(
description="The title of the output.",
default=None,
advanced=True,
)
description: str | None = SchemaField(
description="The description of the output.",
default=None,
advanced=True,
)
format: str = SchemaField(
description="The format string to be used to format the recorded_value. Use Jinja2 syntax.",
default="",
advanced=True,
)
advanced: bool = SchemaField(
description="Whether to treat the output as advanced.",
default=False,
advanced=True,
)
secret: bool = SchemaField(
description="Whether the output should be treated as a secret.",
default=False,
advanced=True,
)
def generate_schema(self):
return self.get_field_schema("value")
class Output(BlockSchema):
output: Any = SchemaField(description="The value recorded as output.")
name: Any = SchemaField(description="The name of the value recorded as output.")
def __init__(self):
super().__init__(
id="363ae599-353e-4804-937e-b2ee3cef3da4",
description="Stores the output of the graph for users to see.",
input_schema=AgentOutputBlock.Input,
output_schema=AgentOutputBlock.Output,
test_input=[
{
"value": "Hello, World!",
"name": "output_1",
"description": "This is a test output.",
"format": "{{ output_1 }}!!",
},
{
"value": "42",
"name": "output_2",
"description": "This is another test output.",
"format": "{{ output_2 }}",
},
{
"value": MockObject(value="!!", key="key"),
"name": "output_3",
"description": "This is a test output with a mock object.",
"format": "{{ output_3 }}",
},
],
test_output=[
("output", "Hello, World!!!"),
("output", "42"),
("output", MockObject(value="!!", key="key")),
],
categories={BlockCategory.OUTPUT, BlockCategory.BASIC},
block_type=BlockType.OUTPUT,
static_output=True,
)
def run(self, input_data: Input, *args, **kwargs) -> BlockOutput:
"""
Attempts to format the recorded_value using the fmt_string if provided.
If formatting fails or no fmt_string is given, returns the original recorded_value.
"""
if input_data.format:
try:
yield "output", formatter.format_string(
input_data.format, {input_data.name: input_data.value}
)
except Exception as e:
yield "output", f"Error: {e}, {input_data.value}"
else:
yield "output", input_data.value
yield "name", input_data.name
class AgentShortTextInputBlock(AgentInputBlock):
class Input(AgentInputBlock.Input):
value: Optional[ShortTextType] = SchemaField(
description="Short text input.",
default=None,
advanced=False,
title="Default Value",
)
class Output(AgentInputBlock.Output):
result: str = SchemaField(description="Short text result.")
def __init__(self):
super().__init__(
id="7fcd3bcb-8e1b-4e69-903d-32d3d4a92158",
description="Block for short text input (single-line).",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentShortTextInputBlock.Input,
output_schema=AgentShortTextInputBlock.Output,
test_input=[
{
"value": "Hello",
"name": "short_text_1",
"description": "Short text example 1",
"placeholder_values": [],
},
{
"value": "Quick test",
"name": "short_text_2",
"description": "Short text example 2",
"placeholder_values": ["Quick test", "Another option"],
},
],
test_output=[
("result", "Hello"),
("result", "Quick test"),
],
)
class AgentLongTextInputBlock(AgentInputBlock):
class Input(AgentInputBlock.Input):
value: Optional[LongTextType] = SchemaField(
description="Long text input (potentially multi-line).",
default=None,
advanced=False,
title="Default Value",
)
class Output(AgentInputBlock.Output):
result: str = SchemaField(description="Long text result.")
def __init__(self):
super().__init__(
id="90a56ffb-7024-4b2b-ab50-e26c5e5ab8ba",
description="Block for long text input (multi-line).",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentLongTextInputBlock.Input,
output_schema=AgentLongTextInputBlock.Output,
test_input=[
{
"value": "Lorem ipsum dolor sit amet...",
"name": "long_text_1",
"description": "Long text example 1",
"placeholder_values": [],
},
{
"value": "Another multiline text input.",
"name": "long_text_2",
"description": "Long text example 2",
"placeholder_values": ["Another multiline text input."],
},
],
test_output=[
("result", "Lorem ipsum dolor sit amet..."),
("result", "Another multiline text input."),
],
)
class AgentNumberInputBlock(AgentInputBlock):
class Input(AgentInputBlock.Input):
value: Optional[int] = SchemaField(
description="Number input.",
default=None,
advanced=False,
title="Default Value",
)
class Output(AgentInputBlock.Output):
result: int = SchemaField(description="Number result.")
def __init__(self):
super().__init__(
id="96dae2bb-97a2-41c2-bd2f-13a3b5a8ea98",
description="Block for number input.",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentNumberInputBlock.Input,
output_schema=AgentNumberInputBlock.Output,
test_input=[
{
"value": 42,
"name": "number_input_1",
"description": "Number example 1",
"placeholder_values": [],
},
{
"value": 314,
"name": "number_input_2",
"description": "Number example 2",
"placeholder_values": [314, 2718],
},
],
test_output=[
("result", 42),
("result", 314),
],
)
class AgentDateInputBlock(AgentInputBlock):
class Input(AgentInputBlock.Input):
value: Optional[date] = SchemaField(
description="Date input (YYYY-MM-DD).",
default=None,
advanced=False,
title="Default Value",
)
class Output(AgentInputBlock.Output):
result: date = SchemaField(description="Date result.")
def __init__(self):
super().__init__(
id="7e198b09-4994-47db-8b4d-952d98241817",
description="Block for date input.",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentDateInputBlock.Input,
output_schema=AgentDateInputBlock.Output,
test_input=[
{
# If your system can parse JSON date strings to date objects
"value": str(date(2025, 3, 19)),
"name": "date_input_1",
"description": "Example date input 1",
},
{
"value": str(date(2023, 12, 31)),
"name": "date_input_2",
"description": "Example date input 2",
},
],
test_output=[
("result", date(2025, 3, 19)),
("result", date(2023, 12, 31)),
],
)
class AgentTimeInputBlock(AgentInputBlock):
class Input(AgentInputBlock.Input):
value: Optional[time] = SchemaField(
description="Time input (HH:MM:SS).",
default=None,
advanced=False,
title="Default Value",
)
class Output(AgentInputBlock.Output):
result: time = SchemaField(description="Time result.")
def __init__(self):
super().__init__(
id="2a1c757e-86cf-4c7e-aacf-060dc382e434",
description="Block for time input.",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentTimeInputBlock.Input,
output_schema=AgentTimeInputBlock.Output,
test_input=[
{
"value": str(time(9, 30, 0)),
"name": "time_input_1",
"description": "Time example 1",
},
{
"value": str(time(23, 59, 59)),
"name": "time_input_2",
"description": "Time example 2",
},
],
test_output=[
("result", time(9, 30, 0)),
("result", time(23, 59, 59)),
],
)
class AgentFileInputBlock(AgentInputBlock):
"""
A simplified file-upload block. In real usage, you might have a custom
file type or handle binary data. Here, we'll store a string path as the example.
"""
class Input(AgentInputBlock.Input):
value: Optional[MediaFileType] = SchemaField(
description="Path or reference to an uploaded file.",
default=None,
advanced=False,
title="Default Value",
)
class Output(AgentInputBlock.Output):
result: str = SchemaField(description="File reference/path result.")
def __init__(self):
super().__init__(
id="95ead23f-8283-4654-aef3-10c053b74a31",
description="Block for file upload input (string path for example).",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentFileInputBlock.Input,
output_schema=AgentFileInputBlock.Output,
test_input=[
{
"value": "data:image/png;base64,MQ==",
"name": "file_upload_1",
"description": "Example file upload 1",
},
],
test_output=[
("result", str),
],
)
def run(
self,
input_data: Input,
*,
graph_exec_id: str,
**kwargs,
) -> BlockOutput:
if not input_data.value:
return
file_path = store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.value,
return_content=False,
)
yield "result", file_path
class AgentDropdownInputBlock(AgentInputBlock):
"""
A specialized text input block that relies on placeholder_values to present a dropdown.
"""
class Input(AgentInputBlock.Input):
value: Optional[str] = SchemaField(
description="Text selected from a dropdown.",
default=None,
advanced=False,
title="Default Value",
)
placeholder_values: list = SchemaField(
description="Possible values for the dropdown.",
default=[],
advanced=False,
title="Dropdown Options",
)
class Output(AgentInputBlock.Output):
result: str = SchemaField(description="Selected dropdown value.")
def __init__(self):
super().__init__(
id="655d6fdf-a334-421c-b733-520549c07cd1",
description="Block for dropdown text selection.",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentDropdownInputBlock.Input,
output_schema=AgentDropdownInputBlock.Output,
test_input=[
{
"value": "Option A",
"name": "dropdown_1",
"placeholder_values": ["Option A", "Option B", "Option C"],
"description": "Dropdown example 1",
},
{
"value": "Option C",
"name": "dropdown_2",
"placeholder_values": ["Option A", "Option B", "Option C"],
"description": "Dropdown example 2",
},
],
test_output=[
("result", "Option A"),
("result", "Option C"),
],
)
class AgentToggleInputBlock(AgentInputBlock):
class Input(AgentInputBlock.Input):
value: bool = SchemaField(
description="Boolean toggle input.",
default=False,
advanced=False,
title="Default Value",
)
class Output(AgentInputBlock.Output):
result: bool = SchemaField(description="Boolean toggle result.")
def __init__(self):
super().__init__(
id="cbf36ab5-df4a-43b6-8a7f-f7ed8652116e",
description="Block for boolean toggle input.",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentToggleInputBlock.Input,
output_schema=AgentToggleInputBlock.Output,
test_input=[
{
"value": True,
"name": "toggle_1",
"description": "Toggle example 1",
},
{
"value": False,
"name": "toggle_2",
"description": "Toggle example 2",
},
],
test_output=[
("result", True),
("result", False),
],
)
IO_BLOCK_IDs = [
AgentInputBlock().id,
AgentOutputBlock().id,
AgentShortTextInputBlock().id,
AgentLongTextInputBlock().id,
AgentNumberInputBlock().id,
AgentDateInputBlock().id,
AgentTimeInputBlock().id,
AgentFileInputBlock().id,
AgentDropdownInputBlock().id,
AgentToggleInputBlock().id,
]

View File

@@ -4,11 +4,10 @@ from abc import ABC
from enum import Enum, EnumMeta
from json import JSONDecodeError
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Iterable, List, Literal, NamedTuple, Optional
from typing import TYPE_CHECKING, Any, List, Literal, NamedTuple
from pydantic import BaseModel, SecretStr
from pydantic import SecretStr
from backend.data.model import NodeExecutionStats
from backend.integrations.providers import ProviderName
if TYPE_CHECKING:
@@ -17,8 +16,6 @@ if TYPE_CHECKING:
import anthropic
import ollama
import openai
from anthropic._types import NotGiven
from anthropic.types import ToolParam
from groq import Groq
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
@@ -230,299 +227,15 @@ for model in LlmModel:
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
class ToolCall(BaseModel):
name: str
arguments: str
class MessageRole(str, Enum):
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
class ToolContentBlock(BaseModel):
id: str
type: str
function: ToolCall
class LLMResponse(BaseModel):
raw_response: Any
prompt: List[Any]
response: str
tool_calls: Optional[List[ToolContentBlock]] | None
prompt_tokens: int
completion_tokens: int
def convert_openai_tool_fmt_to_anthropic(
openai_tools: list[dict] | None = None,
) -> Iterable[ToolParam] | NotGiven:
"""
Convert OpenAI tool format to Anthropic tool format.
"""
if not openai_tools or len(openai_tools) == 0:
return anthropic.NOT_GIVEN
anthropic_tools = []
for tool in openai_tools:
if "function" in tool:
# Handle case where tool is already in OpenAI format with "type" and "function"
function_data = tool["function"]
else:
# Handle case where tool is just the function definition
function_data = tool
anthropic_tool: anthropic.types.ToolParam = {
"name": function_data["name"],
"description": function_data.get("description", ""),
"input_schema": {
"type": "object",
"properties": function_data.get("parameters", {}).get("properties", {}),
"required": function_data.get("parameters", {}).get("required", []),
},
}
anthropic_tools.append(anthropic_tool)
return anthropic_tools
def llm_call(
credentials: APIKeyCredentials,
llm_model: LlmModel,
prompt: list[dict],
json_format: bool,
max_tokens: int | None,
tools: list[dict] | None = None,
ollama_host: str = "localhost:11434",
) -> LLMResponse:
"""
Make a call to a language model.
Args:
credentials: The API key credentials to use.
llm_model: The LLM model to use.
prompt: The prompt to send to the LLM.
json_format: Whether the response should be in JSON format.
max_tokens: The maximum number of tokens to generate in the chat completion.
tools: The tools to use in the chat completion.
ollama_host: The host for ollama to use.
Returns:
LLMResponse object containing:
- prompt: The prompt sent to the LLM.
- response: The text response from the LLM.
- tool_calls: Any tool calls the model made, if applicable.
- prompt_tokens: The number of tokens used in the prompt.
- completion_tokens: The number of tokens used in the completion.
"""
provider = llm_model.metadata.provider
max_tokens = max_tokens or llm_model.max_output_tokens or 4096
if provider == "openai":
tools_param = tools if tools else openai.NOT_GIVEN
oai_client = openai.OpenAI(api_key=credentials.api_key.get_secret_value())
response_format = None
if llm_model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
prompt = [
{"role": "user", "content": "\n".join(sys_messages)},
{"role": "user", "content": "\n".join(usr_messages)},
]
elif json_format:
response_format = {"type": "json_object"}
response = oai_client.chat.completions.create(
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_completion_tokens=max_tokens,
tools=tools_param, # type: ignore
)
if response.choices[0].message.tool_calls:
tool_calls = [
ToolContentBlock(
id=tool.id,
type=tool.type,
function=ToolCall(
name=tool.function.name,
arguments=tool.function.arguments,
),
)
for tool in response.choices[0].message.tool_calls
]
else:
tool_calls = None
return LLMResponse(
raw_response=response.choices[0].message,
prompt=prompt,
response=response.choices[0].message.content or "",
tool_calls=tool_calls,
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
)
elif provider == "anthropic":
an_tools = convert_openai_tool_fmt_to_anthropic(tools)
system_messages = [p["content"] for p in prompt if p["role"] == "system"]
sysprompt = " ".join(system_messages)
messages = []
last_role = None
for p in prompt:
if p["role"] in ["user", "assistant"]:
if (
p["role"] == last_role
and isinstance(messages[-1]["content"], str)
and isinstance(p["content"], str)
):
# If the role is the same as the last one, combine the content
messages[-1]["content"] += p["content"]
else:
messages.append({"role": p["role"], "content": p["content"]})
last_role = p["role"]
client = anthropic.Anthropic(api_key=credentials.api_key.get_secret_value())
try:
resp = client.messages.create(
model=llm_model.value,
system=sysprompt,
messages=messages,
max_tokens=max_tokens,
tools=an_tools,
)
if not resp.content:
raise ValueError("No content returned from Anthropic.")
tool_calls = None
for content_block in resp.content:
# Antropic is different to openai, need to iterate through
# the content blocks to find the tool calls
if content_block.type == "tool_use":
if tool_calls is None:
tool_calls = []
tool_calls.append(
ToolContentBlock(
id=content_block.id,
type=content_block.type,
function=ToolCall(
name=content_block.name,
arguments=json.dumps(content_block.input),
),
)
)
if not tool_calls and resp.stop_reason == "tool_use":
logger.warning(
"Tool use stop reason but no tool calls found in content. %s", resp
)
return LLMResponse(
raw_response=resp,
prompt=prompt,
response=(
resp.content[0].name
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
else resp.content[0].text
),
tool_calls=tool_calls,
prompt_tokens=resp.usage.input_tokens,
completion_tokens=resp.usage.output_tokens,
)
except anthropic.APIError as e:
error_message = f"Anthropic API error: {str(e)}"
logger.error(error_message)
raise ValueError(error_message)
elif provider == "groq":
if tools:
raise ValueError("Groq does not support tools.")
client = Groq(api_key=credentials.api_key.get_secret_value())
response_format = {"type": "json_object"} if json_format else None
response = client.chat.completions.create(
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_tokens=max_tokens,
)
return LLMResponse(
raw_response=response.choices[0].message,
prompt=prompt,
response=response.choices[0].message.content or "",
tool_calls=None,
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
)
elif provider == "ollama":
if tools:
raise ValueError("Ollama does not support tools.")
client = ollama.Client(host=ollama_host)
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
response = client.generate(
model=llm_model.value,
prompt=f"{sys_messages}\n\n{usr_messages}",
stream=False,
)
return LLMResponse(
raw_response=response.get("response") or "",
prompt=prompt,
response=response.get("response") or "",
tool_calls=None,
prompt_tokens=response.get("prompt_eval_count") or 0,
completion_tokens=response.get("eval_count") or 0,
)
elif provider == "open_router":
tools_param = tools if tools else openai.NOT_GIVEN
client = openai.OpenAI(
base_url="https://openrouter.ai/api/v1",
api_key=credentials.api_key.get_secret_value(),
)
response = client.chat.completions.create(
extra_headers={
"HTTP-Referer": "https://agpt.co",
"X-Title": "AutoGPT",
},
model=llm_model.value,
messages=prompt, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore
)
# If there's no response, raise an error
if not response.choices:
if response:
raise ValueError(f"OpenRouter error: {response}")
else:
raise ValueError("No response from OpenRouter.")
if response.choices[0].message.tool_calls:
tool_calls = [
ToolContentBlock(
id=tool.id,
type=tool.type,
function=ToolCall(
name=tool.function.name, arguments=tool.function.arguments
),
)
for tool in response.choices[0].message.tool_calls
]
else:
tool_calls = None
return LLMResponse(
raw_response=response.choices[0].message,
prompt=prompt,
response=response.choices[0].message.content or "",
tool_calls=tool_calls,
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
)
else:
raise ValueError(f"Unsupported LLM provider: {provider}")
class Message(BlockSchema):
role: MessageRole
content: str
class AIBlockBase(Block, ABC):
@@ -547,7 +260,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4O,
default=LlmModel.GPT4_TURBO,
description="The language model to use for answering the prompt.",
advanced=False,
)
@@ -557,7 +270,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
default="",
description="The system prompt to provide additional context to the model.",
)
conversation_history: list[dict] = SchemaField(
conversation_history: list[Message] = SchemaField(
default=[],
description="The conversation history to provide context for the prompt.",
)
@@ -598,7 +311,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
input_schema=AIStructuredResponseGeneratorBlock.Input,
output_schema=AIStructuredResponseGeneratorBlock.Output,
test_input={
"model": LlmModel.GPT4O,
"model": LlmModel.GPT4_TURBO,
"credentials": TEST_CREDENTIALS_INPUT,
"expected_format": {
"key1": "value1",
@@ -612,21 +325,19 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
("prompt", str),
],
test_mock={
"llm_call": lambda *args, **kwargs: LLMResponse(
raw_response="",
prompt=[""],
response=json.dumps(
"llm_call": lambda *args, **kwargs: (
json.dumps(
{
"key1": "key1Value",
"key2": "key2Value",
}
),
tool_calls=None,
prompt_tokens=0,
completion_tokens=0,
0,
0,
)
},
)
self.prompt = ""
def llm_call(
self,
@@ -635,28 +346,160 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
prompt: list[dict],
json_format: bool,
max_tokens: int | None,
tools: list[dict] | None = None,
ollama_host: str = "localhost:11434",
) -> LLMResponse:
) -> tuple[str, int, int]:
"""
Test mocks work only on class functions, this wraps the llm_call function
so that it can be mocked withing the block testing framework.
Args:
credentials: The API key credentials to use.
llm_model: The LLM model to use.
prompt: The prompt to send to the LLM.
json_format: Whether the response should be in JSON format.
max_tokens: The maximum number of tokens to generate in the chat completion.
ollama_host: The host for ollama to use
Returns:
The response from the LLM.
The number of tokens used in the prompt.
The number of tokens used in the completion.
"""
return llm_call(
credentials=credentials,
llm_model=llm_model,
prompt=prompt,
json_format=json_format,
max_tokens=max_tokens,
tools=tools,
ollama_host=ollama_host,
)
provider = llm_model.metadata.provider
max_tokens = max_tokens or llm_model.max_output_tokens or 4096
if provider == "openai":
oai_client = openai.OpenAI(api_key=credentials.api_key.get_secret_value())
response_format = None
if llm_model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
prompt = [
{"role": "user", "content": "\n".join(sys_messages)},
{"role": "user", "content": "\n".join(usr_messages)},
]
elif json_format:
response_format = {"type": "json_object"}
response = oai_client.chat.completions.create(
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_completion_tokens=max_tokens,
)
self.prompt = json.dumps(prompt)
return (
response.choices[0].message.content or "",
response.usage.prompt_tokens if response.usage else 0,
response.usage.completion_tokens if response.usage else 0,
)
elif provider == "anthropic":
system_messages = [p["content"] for p in prompt if p["role"] == "system"]
sysprompt = " ".join(system_messages)
messages = []
last_role = None
for p in prompt:
if p["role"] in ["user", "assistant"]:
if p["role"] != last_role:
messages.append({"role": p["role"], "content": p["content"]})
last_role = p["role"]
else:
# If the role is the same as the last one, combine the content
messages[-1]["content"] += "\n" + p["content"]
client = anthropic.Anthropic(api_key=credentials.api_key.get_secret_value())
try:
resp = client.messages.create(
model=llm_model.value,
system=sysprompt,
messages=messages,
max_tokens=max_tokens,
)
self.prompt = json.dumps(prompt)
if not resp.content:
raise ValueError("No content returned from Anthropic.")
return (
(
resp.content[0].name
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
else resp.content[0].text
),
resp.usage.input_tokens,
resp.usage.output_tokens,
)
except anthropic.APIError as e:
error_message = f"Anthropic API error: {str(e)}"
logger.error(error_message)
raise ValueError(error_message)
elif provider == "groq":
client = Groq(api_key=credentials.api_key.get_secret_value())
response_format = {"type": "json_object"} if json_format else None
response = client.chat.completions.create(
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_tokens=max_tokens,
)
self.prompt = json.dumps(prompt)
return (
response.choices[0].message.content or "",
response.usage.prompt_tokens if response.usage else 0,
response.usage.completion_tokens if response.usage else 0,
)
elif provider == "ollama":
client = ollama.Client(host=ollama_host)
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
response = client.generate(
model=llm_model.value,
prompt=f"{sys_messages}\n\n{usr_messages}",
stream=False,
)
self.prompt = json.dumps(prompt)
return (
response.get("response") or "",
response.get("prompt_eval_count") or 0,
response.get("eval_count") or 0,
)
elif provider == "open_router":
client = openai.OpenAI(
base_url="https://openrouter.ai/api/v1",
api_key=credentials.api_key.get_secret_value(),
)
response = client.chat.completions.create(
extra_headers={
"HTTP-Referer": "https://agpt.co",
"X-Title": "AutoGPT",
},
model=llm_model.value,
messages=prompt, # type: ignore
max_tokens=max_tokens,
)
self.prompt = json.dumps(prompt)
# If there's no response, raise an error
if not response.choices:
if response:
raise ValueError(f"OpenRouter error: {response}")
else:
raise ValueError("No response from OpenRouter.")
return (
response.choices[0].message.content or "",
response.usage.prompt_tokens if response.usage else 0,
response.usage.completion_tokens if response.usage else 0,
)
else:
raise ValueError(f"Unsupported LLM provider: {provider}")
def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
logger.debug(f"Calling LLM with input data: {input_data}")
prompt = [json.to_dict(p) for p in input_data.conversation_history]
prompt = [p.model_dump() for p in input_data.conversation_history]
def trim_prompt(s: str) -> str:
lines = s.strip().split("\n")
@@ -706,7 +549,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
for retry_count in range(input_data.retry):
try:
llm_response = self.llm_call(
response_text, input_token, output_token = self.llm_call(
credentials=credentials,
llm_model=llm_model,
prompt=prompt,
@@ -714,12 +557,11 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
ollama_host=input_data.ollama_host,
max_tokens=input_data.max_tokens,
)
response_text = llm_response.response
self.merge_stats(
NodeExecutionStats(
input_token_count=llm_response.prompt_tokens,
output_token_count=llm_response.completion_tokens,
)
{
"input_token_count": input_token,
"output_token_count": output_token,
}
)
logger.info(f"LLM attempt-{retry_count} response: {response_text}")
@@ -762,10 +604,10 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
retry_prompt = f"Error calling LLM: {e}"
finally:
self.merge_stats(
NodeExecutionStats(
llm_call_count=retry_count + 1,
llm_retry_count=retry_count,
)
{
"llm_call_count": retry_count + 1,
"llm_retry_count": retry_count,
}
)
raise RuntimeError(retry_prompt)
@@ -779,7 +621,7 @@ class AITextGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4O,
default=LlmModel.GPT4_TURBO,
description="The language model to use for answering the prompt.",
advanced=False,
)
@@ -872,7 +714,7 @@ class AITextSummarizerBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4O,
default=LlmModel.GPT4_TURBO,
description="The language model to use for summarizing the text.",
)
focus: str = SchemaField(
@@ -1033,12 +875,12 @@ class AITextSummarizerBlock(AIBlockBase):
class AIConversationBlock(AIBlockBase):
class Input(BlockSchema):
messages: List[Any] = SchemaField(
messages: List[Message] = SchemaField(
description="List of messages in the conversation.", min_length=1
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4O,
default=LlmModel.GPT4_TURBO,
description="The language model to use for the conversation.",
)
credentials: AICredentials = AICredentialsField()
@@ -1077,7 +919,7 @@ class AIConversationBlock(AIBlockBase):
},
{"role": "user", "content": "Where was it played?"},
],
"model": LlmModel.GPT4O,
"model": LlmModel.GPT4_TURBO,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
@@ -1139,7 +981,7 @@ class AIListGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4O,
default=LlmModel.GPT4_TURBO,
description="The language model to use for generating the list.",
advanced=True,
)
@@ -1188,7 +1030,7 @@ class AIListGeneratorBlock(AIBlockBase):
"drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
"fictional worlds."
),
"model": LlmModel.GPT4O,
"model": LlmModel.GPT4_TURBO,
"credentials": TEST_CREDENTIALS_INPUT,
"max_retries": 3,
},

View File

@@ -8,13 +8,13 @@ from moviepy.video.io.VideoFileClip import VideoFileClip
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
from backend.util.file import MediaFile, get_exec_file_path, store_media_file
class MediaDurationBlock(Block):
class Input(BlockSchema):
media_in: MediaFileType = SchemaField(
media_in: MediaFile = SchemaField(
description="Media input (URL, data URI, or local path)."
)
is_video: bool = SchemaField(
@@ -69,7 +69,7 @@ class LoopVideoBlock(Block):
"""
class Input(BlockSchema):
video_in: MediaFileType = SchemaField(
video_in: MediaFile = SchemaField(
description="The input video (can be a URL, data URI, or local path)."
)
# Provide EITHER a `duration` or `n_loops` or both. We'll demonstrate `duration`.
@@ -137,7 +137,7 @@ class LoopVideoBlock(Block):
assert isinstance(looped_clip, VideoFileClip)
# 4) Save the looped output
output_filename = MediaFileType(
output_filename = MediaFile(
f"{node_exec_id}_looped_{os.path.basename(local_video_path)}"
)
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
@@ -162,10 +162,10 @@ class AddAudioToVideoBlock(Block):
"""
class Input(BlockSchema):
video_in: MediaFileType = SchemaField(
video_in: MediaFile = SchemaField(
description="Video input (URL, data URI, or local path)."
)
audio_in: MediaFileType = SchemaField(
audio_in: MediaFile = SchemaField(
description="Audio input (URL, data URI, or local path)."
)
volume: float = SchemaField(
@@ -178,7 +178,7 @@ class AddAudioToVideoBlock(Block):
)
class Output(BlockSchema):
video_out: MediaFileType = SchemaField(
video_out: MediaFile = SchemaField(
description="Final video (with attached audio), as a path or data URI."
)
error: str = SchemaField(
@@ -229,7 +229,7 @@ class AddAudioToVideoBlock(Block):
final_clip = video_clip.with_audio(audio_clip)
# 4) Write to output file
output_filename = MediaFileType(
output_filename = MediaFile(
f"{node_exec_id}_audio_attached_{os.path.basename(local_video_path)}"
)
output_abspath = os.path.join(abs_temp_dir, output_filename)

View File

@@ -6,14 +6,13 @@ from backend.blocks.nvidia._auth import (
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import requests
from backend.util.type import MediaFileType
class NvidiaDeepfakeDetectBlock(Block):
class Input(BlockSchema):
credentials: NvidiaCredentialsInput = NvidiaCredentialsField()
image_base64: MediaFileType = SchemaField(
description="Image to analyze for deepfakes",
image_base64: str = SchemaField(
description="Image to analyze for deepfakes", image_upload=True
)
return_image: bool = SchemaField(
description="Whether to return the processed image with markings",
@@ -23,12 +22,16 @@ class NvidiaDeepfakeDetectBlock(Block):
class Output(BlockSchema):
status: str = SchemaField(
description="Detection status (SUCCESS, ERROR, CONTENT_FILTERED)",
default="",
)
image: MediaFileType = SchemaField(
image: str = SchemaField(
description="Processed image with detection markings (if return_image=True)",
default="",
image_output=True,
)
is_deepfake: float = SchemaField(
description="Probability that the image is a deepfake (0-1)",
default=0.0,
)
def __init__(self):

View File

@@ -12,7 +12,7 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import MediaFileType, store_media_file
from backend.util.file import MediaFile, store_media_file
from backend.util.request import Requests
@@ -57,7 +57,7 @@ class ScreenshotWebPageBlock(Block):
)
class Output(BlockSchema):
image: MediaFileType = SchemaField(description="The screenshot image data")
image: MediaFile = SchemaField(description="The screenshot image data")
error: str = SchemaField(description="Error message if the screenshot failed")
def __init__(self):
@@ -142,9 +142,7 @@ class ScreenshotWebPageBlock(Block):
return {
"image": store_media_file(
graph_exec_id=graph_exec_id,
file=MediaFileType(
f"data:image/{format.value};base64,{b64encode(response.content).decode('utf-8')}"
),
file=f"data:image/{format.value};base64,{b64encode(response.content).decode('utf-8')}",
return_content=True,
)
}

View File

@@ -8,7 +8,6 @@ from backend.data.block import (
BlockWebhookConfig,
)
from backend.data.model import SchemaField
from backend.integrations.providers import ProviderName
from backend.util import settings
from backend.util.settings import AppEnvironment, BehaveAs
@@ -83,7 +82,7 @@ class Slant3DOrderWebhookBlock(Slant3DTriggerBase, Block):
input_schema=self.Input,
output_schema=self.Output,
webhook_config=BlockWebhookConfig(
provider=ProviderName.SLANT3D,
provider="slant3d",
webhook_type="orders", # Only one type for now
resource_format="", # No resource format needed
event_filter_input="events",

View File

@@ -1,511 +0,0 @@
import logging
import re
from collections import Counter
from typing import TYPE_CHECKING, Any
from autogpt_libs.utils.cache import thread_cached
import backend.blocks.llm as llm
from backend.blocks.agent import AgentExecutorBlock
from backend.data.block import (
Block,
BlockCategory,
BlockInput,
BlockOutput,
BlockSchema,
BlockType,
get_block,
)
from backend.data.model import SchemaField
from backend.util import json
if TYPE_CHECKING:
from backend.data.graph import Link, Node
logger = logging.getLogger(__name__)
@thread_cached
def get_database_manager_client():
from backend.executor import DatabaseManager
from backend.util.service import get_service_client
return get_service_client(DatabaseManager)
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
"""
Return a list of tool_call_ids if the entry is a tool request.
Supports both OpenAI and Anthropics formats.
"""
tool_call_ids = []
if entry.get("role") != "assistant":
return tool_call_ids
# OpenAI: check for tool_calls in the entry.
calls = entry.get("tool_calls")
if isinstance(calls, list):
for call in calls:
if tool_id := call.get("id"):
tool_call_ids.append(tool_id)
# Anthropics: check content items for tool_use type.
content = entry.get("content")
if isinstance(content, list):
for item in content:
if item.get("type") != "tool_use":
continue
if tool_id := item.get("id"):
tool_call_ids.append(tool_id)
return tool_call_ids
def _get_tool_responses(entry: dict[str, Any]) -> list[str]:
"""
Return a list of tool_call_ids if the entry is a tool response.
Supports both OpenAI and Anthropics formats.
"""
tool_call_ids: list[str] = []
# OpenAI: a tool response message with role "tool" and key "tool_call_id".
if entry.get("role") == "tool":
if tool_call_id := entry.get("tool_call_id"):
tool_call_ids.append(str(tool_call_id))
# Anthropics: check content items for tool_result type.
if entry.get("role") == "user":
content = entry.get("content")
if isinstance(content, list):
for item in content:
if item.get("type") != "tool_result":
continue
if tool_call_id := item.get("tool_use_id"):
tool_call_ids.append(tool_call_id)
return tool_call_ids
def _create_tool_response(call_id: str, output: dict[str, Any]) -> dict[str, Any]:
"""
Create a tool response message for either OpenAI or Anthropics,
based on the tool_id format.
"""
content = output if isinstance(output, str) else json.dumps(output)
# Anthropics format: tool IDs typically start with "toolu_"
if call_id.startswith("toolu_"):
return {
"role": "user",
"type": "message",
"content": [
{"tool_use_id": call_id, "type": "tool_result", "content": content}
],
}
# OpenAI format: tool IDs typically start with "call_".
# Or default fallback (if the tool_id doesn't match any known prefix)
return {"role": "tool", "tool_call_id": call_id, "content": content}
def get_pending_tool_calls(conversation_history: list[Any]) -> dict[str, int]:
"""
All the tool calls entry in the conversation history requires a response.
This function returns the pending tool calls that has not generated an output yet.
Return: dict[str, int] - A dictionary of pending tool call IDs with their count.
"""
pending_calls = Counter()
for history in conversation_history:
for call_id in _get_tool_requests(history):
pending_calls[call_id] += 1
for call_id in _get_tool_responses(history):
pending_calls[call_id] -= 1
return {call_id: count for call_id, count in pending_calls.items() if count > 0}
class SmartDecisionMakerBlock(Block):
"""
A block that uses a language model to make smart decisions based on a given prompt.
"""
class Input(BlockSchema):
prompt: str = SchemaField(
description="The prompt to send to the language model.",
placeholder="Enter your prompt here...",
)
model: llm.LlmModel = SchemaField(
title="LLM Model",
default=llm.LlmModel.GPT4O,
description="The language model to use for answering the prompt.",
advanced=False,
)
credentials: llm.AICredentials = llm.AICredentialsField()
sys_prompt: str = SchemaField(
title="System Prompt",
default="Thinking carefully step by step decide which function to call. "
"Always choose a function call from the list of function signatures, "
"and always provide the complete argument provided with the type "
"matching the required jsonschema signature, no missing argument is allowed. "
"If you have already completed the task objective, you can end the task "
"by providing the end result of your work as a finish message. "
"Only provide EXACTLY one function call, multiple tool calls is strictly prohibited.",
description="The system prompt to provide additional context to the model.",
)
conversation_history: list[dict] = SchemaField(
default=[],
description="The conversation history to provide context for the prompt.",
)
last_tool_output: Any = SchemaField(
default=None,
description="The output of the last tool that was called.",
)
retry: int = SchemaField(
title="Retry Count",
default=3,
description="Number of times to retry the LLM call if the response does not match the expected format.",
)
prompt_values: dict[str, str] = SchemaField(
advanced=False,
default={},
description="Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}.",
)
max_tokens: int | None = SchemaField(
advanced=True,
default=None,
description="The maximum number of tokens to generate in the chat completion.",
)
ollama_host: str = SchemaField(
advanced=True,
default="localhost:11434",
description="Ollama host for local models",
)
@classmethod
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
# conversation_history & last_tool_output validation is handled differently
missing_links = super().get_missing_links(
data,
[
link
for link in links
if link.sink_name
not in ["conversation_history", "last_tool_output"]
],
)
# Avoid executing the block if the last_tool_output is connected to a static
# link, like StoreValueBlock or AgentInputBlock.
if any(link.sink_name == "conversation_history" for link in links) and any(
link.sink_name == "last_tool_output" and link.is_static
for link in links
):
raise ValueError(
"Last Tool Output can't be connected to a static (dashed line) "
"link like the output of `StoreValue` or `AgentInput` block"
)
return missing_links
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
if missing_input := super().get_missing_input(data):
return missing_input
conversation_history = data.get("conversation_history", [])
pending_tool_calls = get_pending_tool_calls(conversation_history)
last_tool_output = data.get("last_tool_output")
if not last_tool_output and pending_tool_calls:
return {"last_tool_output"}
return set()
class Output(BlockSchema):
error: str = SchemaField(description="Error message if the API call failed.")
tools: Any = SchemaField(description="The tools that are available to use.")
finished: str = SchemaField(
description="The finished message to display to the user."
)
conversations: list[Any] = SchemaField(
description="The conversation history to provide context for the prompt."
)
def __init__(self):
super().__init__(
id="3b191d9f-356f-482d-8238-ba04b6d18381",
description="Uses AI to intelligently decide what tool to use.",
categories={BlockCategory.AI},
block_type=BlockType.AI,
input_schema=SmartDecisionMakerBlock.Input,
output_schema=SmartDecisionMakerBlock.Output,
test_input={
"prompt": "Hello, World!",
"credentials": llm.TEST_CREDENTIALS_INPUT,
},
test_output=[],
test_credentials=llm.TEST_CREDENTIALS,
)
@staticmethod
def _create_block_function_signature(
sink_node: "Node", links: list["Link"]
) -> dict[str, Any]:
"""
Creates a function signature for a block node.
Args:
sink_node: The node for which to create a function signature.
links: The list of links connected to the sink node.
Returns:
A dictionary representing the function signature in the format expected by LLM tools.
Raises:
ValueError: If the block specified by sink_node.block_id is not found.
"""
block = get_block(sink_node.block_id)
if not block:
raise ValueError(f"Block not found: {sink_node.block_id}")
tool_function: dict[str, Any] = {
"name": re.sub(r"[^a-zA-Z0-9_-]", "_", block.name).lower(),
"description": block.description,
}
properties = {}
required = []
for link in links:
sink_block_input_schema = block.input_schema
description = (
sink_block_input_schema.model_fields[link.sink_name].description
if link.sink_name in sink_block_input_schema.model_fields
and sink_block_input_schema.model_fields[link.sink_name].description
else f"The {link.sink_name} of the tool"
)
properties[link.sink_name.lower()] = {
"type": "string",
"description": description,
}
tool_function["parameters"] = {
"type": "object",
"properties": properties,
"required": required,
"additionalProperties": False,
"strict": True,
}
return {"type": "function", "function": tool_function}
@staticmethod
def _create_agent_function_signature(
sink_node: "Node", links: list["Link"]
) -> dict[str, Any]:
"""
Creates a function signature for an agent node.
Args:
sink_node: The agent node for which to create a function signature.
links: The list of links connected to the sink node.
Returns:
A dictionary representing the function signature in the format expected by LLM tools.
Raises:
ValueError: If the graph metadata for the specified graph_id and graph_version is not found.
"""
graph_id = sink_node.input_default.get("graph_id")
graph_version = sink_node.input_default.get("graph_version")
if not graph_id or not graph_version:
raise ValueError("Graph ID or Graph Version not found in sink node.")
db_client = get_database_manager_client()
sink_graph_meta = db_client.get_graph_metadata(graph_id, graph_version)
if not sink_graph_meta:
raise ValueError(
f"Sink graph metadata not found: {graph_id} {graph_version}"
)
tool_function: dict[str, Any] = {
"name": re.sub(r"[^a-zA-Z0-9_-]", "_", sink_graph_meta.name).lower(),
"description": sink_graph_meta.description,
}
properties = {}
required = []
for link in links:
sink_block_input_schema = sink_node.input_default["input_schema"]
description = (
sink_block_input_schema["properties"][link.sink_name]["description"]
if "description"
in sink_block_input_schema["properties"][link.sink_name]
else f"The {link.sink_name} of the tool"
)
properties[link.sink_name.lower()] = {
"type": "string",
"description": description,
}
tool_function["parameters"] = {
"type": "object",
"properties": properties,
"required": required,
"additionalProperties": False,
"strict": True,
}
return {"type": "function", "function": tool_function}
@staticmethod
def _create_function_signature(node_id: str) -> list[dict[str, Any]]:
"""
Creates function signatures for tools linked to a specified node within a graph.
This method filters the graph links to identify those that are tools and are
connected to the given node_id. It then constructs function signatures for each
tool based on the metadata and input schema of the linked nodes.
Args:
node_id: The node_id for which to create function signatures.
Returns:
list[dict[str, Any]]: A list of dictionaries, each representing a function signature
for a tool, including its name, description, and parameters.
Raises:
ValueError: If no tool links are found for the specified node_id, or if a sink node
or its metadata cannot be found.
"""
db_client = get_database_manager_client()
tools = [
(link, node)
for link, node in db_client.get_connected_output_nodes(node_id)
if link.source_name.startswith("tools_^_") and link.source_id == node_id
]
if not tools:
raise ValueError("There is no next node to execute.")
return_tool_functions = []
grouped_tool_links: dict[str, tuple["Node", list["Link"]]] = {}
for link, node in tools:
if link.sink_id not in grouped_tool_links:
grouped_tool_links[link.sink_id] = (node, [link])
else:
grouped_tool_links[link.sink_id][1].append(link)
for sink_node, links in grouped_tool_links.values():
if not sink_node:
raise ValueError(f"Sink node not found: {links[0].sink_id}")
if sink_node.block_id == AgentExecutorBlock().id:
return_tool_functions.append(
SmartDecisionMakerBlock._create_agent_function_signature(
sink_node, links
)
)
else:
return_tool_functions.append(
SmartDecisionMakerBlock._create_block_function_signature(
sink_node, links
)
)
return return_tool_functions
def run(
self,
input_data: Input,
*,
credentials: llm.APIKeyCredentials,
graph_id: str,
node_id: str,
graph_exec_id: str,
node_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
tool_functions = self._create_function_signature(node_id)
input_data.conversation_history = input_data.conversation_history or []
prompt = [json.to_dict(p) for p in input_data.conversation_history if p]
pending_tool_calls = get_pending_tool_calls(input_data.conversation_history)
if pending_tool_calls and not input_data.last_tool_output:
raise ValueError(f"Tool call requires an output for {pending_tool_calls}")
# Prefill all missing tool calls with the last tool output/
# TODO: we need a better way to handle this.
tool_output = [
_create_tool_response(pending_call_id, input_data.last_tool_output)
for pending_call_id, count in pending_tool_calls.items()
for _ in range(count)
]
# If the SDM block only calls 1 tool at a time, this should not happen.
if len(tool_output) > 1:
logger.warning(
f"[SmartDecisionMakerBlock-node_exec_id={node_exec_id}] "
f"Multiple pending tool calls are prefilled using a single output. "
f"Execution may not be accurate."
)
# Fallback on adding tool output in the conversation history as user prompt.
if len(tool_output) == 0 and input_data.last_tool_output:
logger.warning(
f"[SmartDecisionMakerBlock-node_exec_id={node_exec_id}] "
f"No pending tool calls found. This may indicate an issue with the "
f"conversation history, or an LLM calling two tools at the same time."
)
tool_output.append(
{
"role": "user",
"content": f"Last tool output: {json.dumps(input_data.last_tool_output)}",
}
)
prompt.extend(tool_output)
values = input_data.prompt_values
if values:
input_data.prompt = llm.fmt.format_string(input_data.prompt, values)
input_data.sys_prompt = llm.fmt.format_string(input_data.sys_prompt, values)
prefix = "[Main Objective Prompt]: "
if input_data.sys_prompt and not any(
p["role"] == "system" and p["content"].startswith(prefix) for p in prompt
):
prompt.append({"role": "system", "content": prefix + input_data.sys_prompt})
if input_data.prompt and not any(
p["role"] == "user" and p["content"].startswith(prefix) for p in prompt
):
prompt.append({"role": "user", "content": prefix + input_data.prompt})
response = llm.llm_call(
credentials=credentials,
llm_model=input_data.model,
prompt=prompt,
json_format=False,
max_tokens=input_data.max_tokens,
tools=tool_functions,
ollama_host=input_data.ollama_host,
)
if not response.tool_calls:
yield "finished", response.response
return
for tool_call in response.tool_calls:
tool_name = tool_call.function.name
tool_args = json.loads(tool_call.function.arguments)
for arg_name, arg_value in tool_args.items():
yield f"tools_^_{tool_name}_{arg_name}".lower(), arg_value
response.prompt.append(response.raw_response)
yield "conversations", response.prompt

View File

@@ -1,97 +0,0 @@
from backend.blocks.smartlead.models import (
AddLeadsRequest,
AddLeadsToCampaignResponse,
CreateCampaignRequest,
CreateCampaignResponse,
SaveSequencesRequest,
SaveSequencesResponse,
)
from backend.util.request import Requests
class SmartLeadClient:
"""Client for the SmartLead API"""
# This api is stupid and requires your api key in the url. DO NOT RAISE ERRORS FOR BAD REQUESTS.
# FILTER OUT THE API KEY FROM THE ERROR MESSAGE.
API_URL = "https://server.smartlead.ai/api/v1"
def __init__(self, api_key: str):
self.api_key = api_key
self.requests = Requests()
def _add_auth_to_url(self, url: str) -> str:
return f"{url}?api_key={self.api_key}"
def _handle_error(self, e: Exception) -> str:
return e.__str__().replace(self.api_key, "API KEY")
def create_campaign(self, request: CreateCampaignRequest) -> CreateCampaignResponse:
try:
response = self.requests.post(
self._add_auth_to_url(f"{self.API_URL}/campaigns/create"),
json=request.model_dump(),
)
response_data = response.json()
return CreateCampaignResponse(**response_data)
except ValueError as e:
raise ValueError(f"Invalid response format: {str(e)}")
except Exception as e:
raise ValueError(f"Failed to create campaign: {self._handle_error(e)}")
def add_leads_to_campaign(
self, request: AddLeadsRequest
) -> AddLeadsToCampaignResponse:
try:
response = self.requests.post(
self._add_auth_to_url(
f"{self.API_URL}/campaigns/{request.campaign_id}/leads"
),
json=request.model_dump(exclude={"campaign_id"}),
)
response_data = response.json()
response_parsed = AddLeadsToCampaignResponse(**response_data)
if not response_parsed.ok:
raise ValueError(
f"Failed to add leads to campaign: {response_parsed.error}"
)
return response_parsed
except ValueError as e:
raise ValueError(f"Invalid response format: {str(e)}")
except Exception as e:
raise ValueError(
f"Failed to add leads to campaign: {self._handle_error(e)}"
)
def save_campaign_sequences(
self, campaign_id: int, request: SaveSequencesRequest
) -> SaveSequencesResponse:
"""
Save sequences within a campaign.
Args:
campaign_id: ID of the campaign to save sequences for
request: SaveSequencesRequest containing the sequences configuration
Returns:
SaveSequencesResponse with the result of the operation
Note:
For variant_distribution_type:
- MANUAL_EQUAL: Equally distributes variants across leads
- AI_EQUAL: Requires winning_metric_property and lead_distribution_percentage
- MANUAL_PERCENTAGE: Requires variant_distribution_percentage in seq_variants
"""
try:
response = self.requests.post(
self._add_auth_to_url(
f"{self.API_URL}/campaigns/{campaign_id}/sequences"
),
json=request.model_dump(exclude_none=True),
)
return SaveSequencesResponse(**response.json())
except Exception as e:
raise ValueError(
f"Failed to save campaign sequences: {e.__str__().replace(self.api_key, 'API KEY')}"
)

View File

@@ -1,35 +0,0 @@
from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
SmartLeadCredentials = APIKeyCredentials
SmartLeadCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.SMARTLEAD],
Literal["api_key"],
]
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="smartlead",
api_key=SecretStr("mock-smartlead-api-key"),
title="Mock SmartLead API key",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
def SmartLeadCredentialsField() -> SmartLeadCredentialsInput:
"""
Creates a SmartLead credentials input on a block.
"""
return CredentialsField(
description="The SmartLead integration can be used with an API Key.",
)

View File

@@ -1,326 +0,0 @@
from backend.blocks.smartlead._api import SmartLeadClient
from backend.blocks.smartlead._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
SmartLeadCredentials,
SmartLeadCredentialsInput,
)
from backend.blocks.smartlead.models import (
AddLeadsRequest,
AddLeadsToCampaignResponse,
CreateCampaignRequest,
CreateCampaignResponse,
LeadInput,
LeadUploadSettings,
SaveSequencesRequest,
SaveSequencesResponse,
Sequence,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class CreateCampaignBlock(Block):
"""Create a campaign in SmartLead"""
class Input(BlockSchema):
name: str = SchemaField(
description="The name of the campaign",
)
credentials: SmartLeadCredentialsInput = SchemaField(
description="SmartLead credentials",
)
class Output(BlockSchema):
id: int = SchemaField(
description="The ID of the created campaign",
)
name: str = SchemaField(
description="The name of the created campaign",
)
created_at: str = SchemaField(
description="The date and time the campaign was created",
)
error: str = SchemaField(
description="Error message if the search failed",
default="",
)
def __init__(self):
super().__init__(
id="8865699f-9188-43c4-89b0-79c84cfaa03e",
description="Create a campaign in SmartLead",
categories={BlockCategory.CRM},
input_schema=CreateCampaignBlock.Input,
output_schema=CreateCampaignBlock.Output,
test_credentials=TEST_CREDENTIALS,
test_input={"name": "Test Campaign", "credentials": TEST_CREDENTIALS_INPUT},
test_output=[
(
"id",
1,
),
(
"name",
"Test Campaign",
),
(
"created_at",
"2024-01-01T00:00:00Z",
),
],
test_mock={
"create_campaign": lambda name, credentials: CreateCampaignResponse(
ok=True,
id=1,
name=name,
created_at="2024-01-01T00:00:00Z",
)
},
)
@staticmethod
def create_campaign(
name: str, credentials: SmartLeadCredentials
) -> CreateCampaignResponse:
client = SmartLeadClient(credentials.api_key.get_secret_value())
return client.create_campaign(CreateCampaignRequest(name=name))
def run(
self,
input_data: Input,
*,
credentials: SmartLeadCredentials,
**kwargs,
) -> BlockOutput:
response = self.create_campaign(input_data.name, credentials)
yield "id", response.id
yield "name", response.name
yield "created_at", response.created_at
if not response.ok:
yield "error", "Failed to create campaign"
class AddLeadToCampaignBlock(Block):
"""Add a lead to a campaign in SmartLead"""
class Input(BlockSchema):
campaign_id: int = SchemaField(
description="The ID of the campaign to add the lead to",
)
lead_list: list[LeadInput] = SchemaField(
description="An array of JSON objects, each representing a lead's details. Can hold max 100 leads.",
max_length=100,
default=[],
advanced=False,
)
settings: LeadUploadSettings = SchemaField(
description="Settings for lead upload",
default=LeadUploadSettings(),
)
credentials: SmartLeadCredentialsInput = SchemaField(
description="SmartLead credentials",
)
class Output(BlockSchema):
campaign_id: int = SchemaField(
description="The ID of the campaign the lead was added to (passed through)",
)
upload_count: int = SchemaField(
description="The number of leads added to the campaign",
)
already_added_to_campaign: int = SchemaField(
description="The number of leads that were already added to the campaign",
)
duplicate_count: int = SchemaField(
description="The number of emails that were duplicates",
)
invalid_email_count: int = SchemaField(
description="The number of emails that were invalidly formatted",
)
is_lead_limit_exhausted: bool = SchemaField(
description="Whether the lead limit was exhausted",
)
lead_import_stopped_count: int = SchemaField(
description="The number of leads that were not added to the campaign because the lead import was stopped",
)
error: str = SchemaField(
description="Error message if the lead was not added to the campaign",
default="",
)
def __init__(self):
super().__init__(
id="fb8106a4-1a8f-42f9-a502-f6d07e6fe0ec",
description="Add a lead to a campaign in SmartLead",
categories={BlockCategory.CRM},
input_schema=AddLeadToCampaignBlock.Input,
output_schema=AddLeadToCampaignBlock.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"campaign_id": 1,
"lead_list": [],
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
(
"campaign_id",
1,
),
(
"upload_count",
1,
),
],
test_mock={
"add_leads_to_campaign": lambda campaign_id, lead_list, credentials: AddLeadsToCampaignResponse(
ok=True,
upload_count=1,
already_added_to_campaign=0,
duplicate_count=0,
invalid_email_count=0,
is_lead_limit_exhausted=False,
lead_import_stopped_count=0,
error="",
total_leads=1,
block_count=0,
invalid_emails=[],
unsubscribed_leads=[],
bounce_count=0,
)
},
)
@staticmethod
def add_leads_to_campaign(
campaign_id: int, lead_list: list[LeadInput], credentials: SmartLeadCredentials
) -> AddLeadsToCampaignResponse:
client = SmartLeadClient(credentials.api_key.get_secret_value())
return client.add_leads_to_campaign(
AddLeadsRequest(
campaign_id=campaign_id,
lead_list=lead_list,
settings=LeadUploadSettings(
ignore_global_block_list=False,
ignore_unsubscribe_list=False,
ignore_community_bounce_list=False,
ignore_duplicate_leads_in_other_campaign=False,
),
),
)
def run(
self,
input_data: Input,
*,
credentials: SmartLeadCredentials,
**kwargs,
) -> BlockOutput:
response = self.add_leads_to_campaign(
input_data.campaign_id, input_data.lead_list, credentials
)
yield "campaign_id", input_data.campaign_id
yield "upload_count", response.upload_count
if response.already_added_to_campaign:
yield "already_added_to_campaign", response.already_added_to_campaign
if response.duplicate_count:
yield "duplicate_count", response.duplicate_count
if response.invalid_email_count:
yield "invalid_email_count", response.invalid_email_count
if response.is_lead_limit_exhausted:
yield "is_lead_limit_exhausted", response.is_lead_limit_exhausted
if response.lead_import_stopped_count:
yield "lead_import_stopped_count", response.lead_import_stopped_count
if response.error:
yield "error", response.error
if not response.ok:
yield "error", "Failed to add leads to campaign"
class SaveCampaignSequencesBlock(Block):
"""Save sequences within a campaign"""
class Input(BlockSchema):
campaign_id: int = SchemaField(
description="The ID of the campaign to save sequences for",
)
sequences: list[Sequence] = SchemaField(
description="The sequences to save",
default=[],
advanced=False,
)
credentials: SmartLeadCredentialsInput = SchemaField(
description="SmartLead credentials",
)
class Output(BlockSchema):
data: dict | str | None = SchemaField(
description="Data from the API",
default=None,
)
message: str = SchemaField(
description="Message from the API",
default="",
)
error: str = SchemaField(
description="Error message if the sequences were not saved",
default="",
)
def __init__(self):
super().__init__(
id="e7d9f41c-dc10-4f39-98ba-a432abd128c0",
description="Save sequences within a campaign",
categories={BlockCategory.CRM},
input_schema=SaveCampaignSequencesBlock.Input,
output_schema=SaveCampaignSequencesBlock.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"campaign_id": 1,
"sequences": [],
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
(
"message",
"Sequences saved successfully",
),
],
test_mock={
"save_campaign_sequences": lambda campaign_id, sequences, credentials: SaveSequencesResponse(
ok=True,
message="Sequences saved successfully",
)
},
)
@staticmethod
def save_campaign_sequences(
campaign_id: int, sequences: list[Sequence], credentials: SmartLeadCredentials
) -> SaveSequencesResponse:
client = SmartLeadClient(credentials.api_key.get_secret_value())
return client.save_campaign_sequences(
campaign_id=campaign_id, request=SaveSequencesRequest(sequences=sequences)
)
def run(
self,
input_data: Input,
*,
credentials: SmartLeadCredentials,
**kwargs,
) -> BlockOutput:
response = self.save_campaign_sequences(
input_data.campaign_id, input_data.sequences, credentials
)
if response.data:
yield "data", response.data
if response.message:
yield "message", response.message
if response.error:
yield "error", response.error
if not response.ok:
yield "error", "Failed to save sequences"

View File

@@ -1,147 +0,0 @@
from enum import Enum
from pydantic import BaseModel
from backend.data.model import SchemaField
class CreateCampaignResponse(BaseModel):
ok: bool
id: int
name: str
created_at: str
class CreateCampaignRequest(BaseModel):
name: str
client_id: str | None = None
class AddLeadsToCampaignResponse(BaseModel):
ok: bool
upload_count: int
total_leads: int
block_count: int
duplicate_count: int
invalid_email_count: int
invalid_emails: list[str]
already_added_to_campaign: int
unsubscribed_leads: list[str]
is_lead_limit_exhausted: bool
lead_import_stopped_count: int
bounce_count: int
error: str | None = None
class LeadCustomFields(BaseModel):
"""Custom fields for a lead (max 20 fields)"""
fields: dict[str, str] = SchemaField(
description="Custom fields for a lead (max 20 fields)",
max_length=20,
default={},
)
class LeadInput(BaseModel):
"""Single lead input data"""
first_name: str
last_name: str
email: str
phone_number: str | None = None # Changed from int to str for phone numbers
company_name: str | None = None
website: str | None = None
location: str | None = None
custom_fields: LeadCustomFields | None = None
linkedin_profile: str | None = None
company_url: str | None = None
class LeadUploadSettings(BaseModel):
"""Settings for lead upload"""
ignore_global_block_list: bool = SchemaField(
description="Ignore the global block list",
default=False,
)
ignore_unsubscribe_list: bool = SchemaField(
description="Ignore the unsubscribe list",
default=False,
)
ignore_community_bounce_list: bool = SchemaField(
description="Ignore the community bounce list",
default=False,
)
ignore_duplicate_leads_in_other_campaign: bool = SchemaField(
description="Ignore duplicate leads in other campaigns",
default=False,
)
class AddLeadsRequest(BaseModel):
"""Request body for adding leads to a campaign"""
lead_list: list[LeadInput] = SchemaField(
description="List of leads to add to the campaign",
max_length=100,
default=[],
)
settings: LeadUploadSettings
campaign_id: int
class VariantDistributionType(str, Enum):
MANUAL_EQUAL = "MANUAL_EQUAL"
MANUAL_PERCENTAGE = "MANUAL_PERCENTAGE"
AI_EQUAL = "AI_EQUAL"
class WinningMetricProperty(str, Enum):
OPEN_RATE = "OPEN_RATE"
CLICK_RATE = "CLICK_RATE"
REPLY_RATE = "REPLY_RATE"
POSITIVE_REPLY_RATE = "POSITIVE_REPLY_RATE"
class SequenceDelayDetails(BaseModel):
delay_in_days: int
class SequenceVariant(BaseModel):
subject: str
email_body: str
variant_label: str
id: int | None = None # Optional for creation, required for updates
variant_distribution_percentage: int | None = None
class Sequence(BaseModel):
seq_number: int = SchemaField(
description="The sequence number",
default=1,
)
seq_delay_details: SequenceDelayDetails
id: int | None = None
variant_distribution_type: VariantDistributionType | None = None
lead_distribution_percentage: int | None = SchemaField(
None, ge=20, le=100
) # >= 20% for fair calculation
winning_metric_property: WinningMetricProperty | None = None
seq_variants: list[SequenceVariant] | None = None
subject: str = "" # blank makes the follow up in the same thread
email_body: str | None = None
class SaveSequencesRequest(BaseModel):
sequences: list[Sequence]
class SaveSequencesResponse(BaseModel):
ok: bool
message: str = SchemaField(
description="Message from the API",
default="",
)
data: dict | str | None = None
error: str | None = None

View File

@@ -156,10 +156,6 @@ class CountdownTimerBlock(Block):
days: Union[int, str] = SchemaField(
advanced=False, description="Duration in days", default=0
)
repeat: int = SchemaField(
description="Number of times to repeat the timer",
default=1,
)
class Output(BlockSchema):
output_message: Any = SchemaField(
@@ -191,6 +187,5 @@ class CountdownTimerBlock(Block):
total_seconds = seconds + minutes * 60 + hours * 3600 + days * 86400
for _ in range(input_data.repeat):
time.sleep(total_seconds)
yield "output_message", input_data.input_message
time.sleep(total_seconds)
yield "output_message", input_data.input_message

View File

@@ -1,10 +0,0 @@
from zerobouncesdk import ZBValidateResponse, ZeroBounce
class ZeroBounceClient:
def __init__(self, api_key: str):
self.api_key = api_key
self.client = ZeroBounce(api_key)
def validate_email(self, email: str, ip_address: str) -> ZBValidateResponse:
return self.client.validate(email, ip_address)

View File

@@ -1,35 +0,0 @@
from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
ZeroBounceCredentials = APIKeyCredentials
ZeroBounceCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.ZEROBOUNCE],
Literal["api_key"],
]
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="zerobounce",
api_key=SecretStr("mock-zerobounce-api-key"),
title="Mock ZeroBounce API key",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
def ZeroBounceCredentialsField() -> ZeroBounceCredentialsInput:
"""
Creates a ZeroBounce credentials input on a block.
"""
return CredentialsField(
description="The ZeroBounce integration can be used with an API Key.",
)

View File

@@ -1,175 +0,0 @@
from typing import Optional
from pydantic import BaseModel
from zerobouncesdk.zb_validate_response import (
ZBValidateResponse,
ZBValidateStatus,
ZBValidateSubStatus,
)
from backend.blocks.zerobounce._api import ZeroBounceClient
from backend.blocks.zerobounce._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
ZeroBounceCredentials,
ZeroBounceCredentialsInput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class Response(BaseModel):
address: str = SchemaField(
description="The email address you are validating.", default="N/A"
)
status: ZBValidateStatus = SchemaField(
description="The status of the email address.", default=ZBValidateStatus.unknown
)
sub_status: ZBValidateSubStatus = SchemaField(
description="The sub-status of the email address.",
default=ZBValidateSubStatus.none,
)
account: Optional[str] = SchemaField(
description="The portion of the email address before the '@' symbol.",
default="N/A",
)
domain: Optional[str] = SchemaField(
description="The portion of the email address after the '@' symbol."
)
did_you_mean: Optional[str] = SchemaField(
description="Suggestive Fix for an email typo",
default=None,
)
domain_age_days: Optional[str] = SchemaField(
description="Age of the email domain in days or [null].",
default=None,
)
free_email: Optional[bool] = SchemaField(
description="Whether the email address is a free email provider.", default=False
)
mx_found: Optional[bool] = SchemaField(
description="Whether the MX record was found.", default=False
)
mx_record: Optional[str] = SchemaField(
description="The MX record of the email address.", default=None
)
smtp_provider: Optional[str] = SchemaField(
description="The SMTP provider of the email address.", default=None
)
firstname: Optional[str] = SchemaField(
description="The first name of the email address.", default=None
)
lastname: Optional[str] = SchemaField(
description="The last name of the email address.", default=None
)
gender: Optional[str] = SchemaField(
description="The gender of the email address.", default=None
)
city: Optional[str] = SchemaField(
description="The city of the email address.", default=None
)
region: Optional[str] = SchemaField(
description="The region of the email address.", default=None
)
zipcode: Optional[str] = SchemaField(
description="The zipcode of the email address.", default=None
)
country: Optional[str] = SchemaField(
description="The country of the email address.", default=None
)
class ValidateEmailsBlock(Block):
"""Search for people in Apollo"""
class Input(BlockSchema):
email: str = SchemaField(
description="Email to validate",
)
ip_address: str = SchemaField(
description="IP address to validate",
default="",
)
credentials: ZeroBounceCredentialsInput = SchemaField(
description="ZeroBounce credentials",
)
class Output(BlockSchema):
response: Response = SchemaField(
description="Response from ZeroBounce",
)
error: str = SchemaField(
description="Error message if the search failed",
default="",
)
def __init__(self):
super().__init__(
id="e3950439-fa0b-40e8-b19f-e0dca0bf5853",
description="Validate emails",
categories={BlockCategory.SEARCH},
input_schema=ValidateEmailsBlock.Input,
output_schema=ValidateEmailsBlock.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"email": "test@test.com",
},
test_output=[
(
"response",
Response(
address="test@test.com",
status=ZBValidateStatus.valid,
sub_status=ZBValidateSubStatus.allowed,
account="test",
domain="test.com",
did_you_mean=None,
domain_age_days=None,
free_email=False,
mx_found=False,
mx_record=None,
smtp_provider=None,
),
)
],
test_mock={
"validate_email": lambda email, ip_address, credentials: ZBValidateResponse(
data={
"address": email,
"status": ZBValidateStatus.valid,
"sub_status": ZBValidateSubStatus.allowed,
"account": "test",
"domain": "test.com",
"did_you_mean": None,
"domain_age_days": None,
"free_email": False,
"mx_found": False,
"mx_record": None,
"smtp_provider": None,
}
)
},
)
@staticmethod
def validate_email(
email: str, ip_address: str, credentials: ZeroBounceCredentials
) -> ZBValidateResponse:
client = ZeroBounceClient(credentials.api_key.get_secret_value())
return client.validate_email(email, ip_address)
def run(
self,
input_data: Input,
*,
credentials: ZeroBounceCredentials,
**kwargs,
) -> BlockOutput:
response: ZBValidateResponse = self.validate_email(
input_data.email, input_data.ip_address, credentials
)
response_model = Response(**response.__dict__)
yield "response", response_model

View File

@@ -220,8 +220,9 @@ def event():
@test.command()
@click.argument("server_address")
@click.argument("graph_exec_id")
def websocket(server_address: str, graph_exec_id: str):
@click.argument("graph_id")
@click.argument("graph_version")
def websocket(server_address: str, graph_id: str, graph_version: int):
"""
Tests the websocket connection.
"""
@@ -229,20 +230,16 @@ def websocket(server_address: str, graph_exec_id: str):
import websockets.asyncio.client
from backend.server.ws_api import (
WSMessage,
WSMethod,
WSSubscribeGraphExecutionRequest,
)
from backend.server.ws_api import ExecutionSubscription, Methods, WsMessage
async def send_message(server_address: str):
uri = f"ws://{server_address}"
async with websockets.asyncio.client.connect(uri) as websocket:
try:
msg = WSMessage(
method=WSMethod.SUBSCRIBE_GRAPH_EXEC,
data=WSSubscribeGraphExecutionRequest(
graph_exec_id=graph_exec_id,
msg = WsMessage(
method=Methods.SUBSCRIBE,
data=ExecutionSubscription(
graph_id=graph_id, graph_version=graph_version
).model_dump(),
).model_dump_json()
await websocket.send(msg)

View File

@@ -2,7 +2,6 @@ import inspect
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Generator,
@@ -19,8 +18,6 @@ import jsonschema
from prisma.models import AgentBlock
from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.settings import Config
@@ -31,9 +28,6 @@ from .model import (
is_credentials_field_name,
)
if TYPE_CHECKING:
from .graph import Link
app_config = Config()
BlockData = tuple[str, Any] # Input & Output data should be a tuple of (name, data).
@@ -50,7 +44,6 @@ class BlockType(Enum):
WEBHOOK = "Webhook"
WEBHOOK_MANUAL = "Webhook (manual)"
AGENT = "Agent"
AI = "AI"
class BlockCategory(Enum):
@@ -116,30 +109,21 @@ class BlockSchema(BaseModel):
def validate_data(cls, data: BlockInput) -> str | None:
return json.validate_with_jsonschema(schema=cls.jsonschema(), data=data)
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
return cls.validate_data(data)
@classmethod
def get_field_schema(cls, field_name: str) -> dict[str, Any]:
model_schema = cls.jsonschema().get("properties", {})
if not model_schema:
raise ValueError(f"Invalid model schema {cls}")
property_schema = model_schema.get(field_name)
if not property_schema:
raise ValueError(f"Invalid property name {field_name}")
return property_schema
@classmethod
def validate_field(cls, field_name: str, data: BlockInput) -> str | None:
"""
Validate the data against a specific property (one of the input/output name).
Returns the validation error message if the data does not match the schema.
"""
model_schema = cls.jsonschema().get("properties", {})
if not model_schema:
return f"Invalid model schema {cls}"
property_schema = model_schema.get(field_name)
if not property_schema:
return f"Invalid property name {field_name}"
try:
property_schema = cls.get_field_schema(field_name)
jsonschema.validate(json.to_dict(data), property_schema)
return None
except jsonschema.ValidationError as e:
@@ -202,19 +186,6 @@ class BlockSchema(BaseModel):
)
}
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
return data # Return as is, by default.
@classmethod
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
input_fields_from_nodes = {link.sink_name for link in links}
return input_fields_from_nodes - set(data)
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
return cls.get_required_fields() - set(data)
BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchema)
BlockSchemaOutputType = TypeVar("BlockSchemaOutputType", bound=BlockSchema)
@@ -231,7 +202,7 @@ class BlockManualWebhookConfig(BaseModel):
the user has to manually set up the webhook at the provider.
"""
provider: ProviderName
provider: str
"""The service provider that the webhook connects to"""
webhook_type: str
@@ -323,7 +294,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
self.static_output = static_output
self.block_type = block_type
self.webhook_config = webhook_config
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
self.execution_stats = {}
if self.webhook_config:
if isinstance(self.webhook_config, BlockWebhookConfig):
@@ -380,14 +351,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
Run the block with the given input data.
Args:
input_data: The input data with the structure of input_schema.
Kwargs: Currently 14/02/2025 these include
graph_id: The ID of the graph.
node_id: The ID of the node.
graph_exec_id: The ID of the graph execution.
node_exec_id: The ID of the node execution.
user_id: The ID of the user.
Returns:
A Generator that yields (output_name, output_data).
output_name: One of the output name defined in Block's output_schema.
@@ -401,29 +364,18 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
return data
raise ValueError(f"{self.name} did not produce any output for {output}")
def merge_stats(self, stats: NodeExecutionStats) -> NodeExecutionStats:
stats_dict = stats.model_dump()
current_stats = self.execution_stats.model_dump()
for key, value in stats_dict.items():
if key not in current_stats:
# Field doesn't exist yet, just set it, but this will probably
# not happen, just in case though so we throw for invalid when
# converting back in
current_stats[key] = value
elif isinstance(value, dict) and isinstance(current_stats[key], dict):
current_stats[key].update(value)
elif isinstance(value, (int, float)) and isinstance(
current_stats[key], (int, float)
):
current_stats[key] += value
elif isinstance(value, list) and isinstance(current_stats[key], list):
current_stats[key].extend(value)
def merge_stats(self, stats: dict[str, Any]) -> dict[str, Any]:
for key, value in stats.items():
if isinstance(value, dict):
self.execution_stats.setdefault(key, {}).update(value)
elif isinstance(value, (int, float)):
self.execution_stats.setdefault(key, 0)
self.execution_stats[key] += value
elif isinstance(value, list):
self.execution_stats.setdefault(key, [])
self.execution_stats[key].extend(value)
else:
current_stats[key] = value
self.execution_stats = NodeExecutionStats(**current_stats)
self.execution_stats[key] = value
return self.execution_stats
@property
@@ -446,6 +398,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
}
def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
# Merge the input data with the extra execution arguments, preferring the args for security
if error := self.input_schema.validate_data(input_data):
raise ValueError(
f"Unable to execute block with invalid input data: {error}"
@@ -467,9 +420,9 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
def get_blocks() -> dict[str, Type[Block]]:
from backend.blocks import load_all_blocks
from backend.blocks import AVAILABLE_BLOCKS # noqa: E402
return load_all_blocks()
return AVAILABLE_BLOCKS
async def initialize_blocks() -> None:

View File

@@ -15,7 +15,6 @@ from backend.blocks.llm import (
LlmModel,
)
from backend.blocks.replicate_flux_advanced import ReplicateFluxAdvancedModelBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
from backend.data.block import Block
@@ -266,5 +265,4 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
},
)
],
SmartDecisionMakerBlock: LLM_COST,
}

View File

@@ -1,43 +1,28 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timezone
import stripe
from autogpt_libs.utils.cache import thread_cached
from prisma import Json
from prisma.enums import (
CreditRefundRequestStatus,
CreditTransactionType,
NotificationType,
)
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditRefundRequest, CreditTransaction, User
from prisma.models import CreditTransaction, User
from prisma.types import CreditTransactionCreateInput, CreditTransactionWhereInput
from tenacity import retry, stop_after_attempt, wait_exponential
from pydantic import BaseModel
from backend.data import db
from backend.data.block import Block, BlockInput, get_block
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.cost import BlockCost
from backend.data.model import (
AutoTopUpConfig,
RefundRequest,
TransactionHistory,
UserTransaction,
)
from backend.data.notifications import NotificationEventDTO, RefundRequestData
from backend.data.cost import BlockCost, BlockCostType
from backend.data.execution import NodeExecutionEntry
from backend.data.model import AutoTopUpConfig, TransactionHistory, UserTransaction
from backend.data.user import get_user_by_id
from backend.executor.utils import UsageTransactionMetadata
from backend.notifications import NotificationManager
from backend.util.exceptions import InsufficientBalanceError
from backend.util.service import get_service_client
from backend.util.settings import Settings
settings = Settings()
stripe.api_key = settings.secrets.stripe_api_key
logger = logging.getLogger(__name__)
base_url = settings.config.frontend_base_url or settings.config.platform_base_url
class UserCreditBase(ABC):
@@ -55,61 +40,46 @@ class UserCreditBase(ABC):
async def get_transaction_history(
self,
user_id: str,
transaction_time: datetime,
transaction_count_limit: int,
transaction_time_ceiling: datetime | None = None,
transaction_type: str | None = None,
) -> TransactionHistory:
"""
Get the credit transactions for the user.
Args:
user_id (str): The user ID.
transaction_time (datetime): The upper bound of the transaction time.
transaction_count_limit (int): The transaction count limit.
transaction_time_ceiling (datetime): The upper bound of the transaction time.
transaction_type (str): The transaction type filter.
Returns:
TransactionHistory: The credit transactions for the user.
"""
pass
@abstractmethod
async def get_refund_requests(self, user_id: str) -> list[RefundRequest]:
"""
Get the refund requests for the user.
Args:
user_id (str): The user ID.
Returns:
list[RefundRequest]: The refund requests for the user.
"""
pass
@abstractmethod
async def spend_credits(
self,
user_id: str,
cost: int,
metadata: UsageTransactionMetadata,
entry: NodeExecutionEntry,
data_size: float,
run_time: float,
) -> int:
"""
Spend the credits for the user based on the cost.
Spend the credits for the user based on the block usage.
Args:
user_id (str): The user ID.
cost (int): The cost to spend.
metadata (UsageTransactionMetadata): The metadata of the transaction.
entry (NodeExecutionEntry): The node execution identifiers & data.
data_size (float): The size of the data being processed.
run_time (float): The time taken to run the block.
Returns:
int: The remaining balance.
int: amount of credit spent
"""
pass
@abstractmethod
async def top_up_credits(self, user_id: str, amount: int):
"""
Top up the credits for the user.
Top up the credits for the user immediately.
Args:
user_id (str): The user ID.
@@ -131,46 +101,6 @@ class UserCreditBase(ABC):
"""
pass
@abstractmethod
async def top_up_refund(
self, user_id: str, transaction_key: str, metadata: dict[str, str]
) -> int:
"""
Refund the top-up transaction for the user.
Args:
user_id (str): The user ID.
transaction_key (str): The top-up transaction key to refund.
metadata (dict[str, str]): The metadata of the refund.
Returns:
int: The amount refunded.
"""
pass
@abstractmethod
async def deduct_credits(
self,
request: stripe.Refund | stripe.Dispute,
):
"""
Deduct the credits for the user based on the dispute or refund of the top-up.
Args:
request (stripe.Refund | stripe.Dispute): The refund or dispute request.
"""
pass
@abstractmethod
async def handle_dispute(self, dispute: stripe.Dispute):
"""
Handle the dispute for the user based on the dispute request.
Args:
dispute (stripe.Dispute): The dispute request.
"""
pass
@abstractmethod
async def fulfill_checkout(
self, *, session_id: str | None = None, user_id: str | None = None
@@ -184,14 +114,6 @@ class UserCreditBase(ABC):
"""
pass
@staticmethod
async def create_billing_portal_session(user_id: str) -> str:
session = stripe.billing_portal.Session.create(
customer=await get_stripe_customer_id(user_id),
return_url=base_url + "/profile/credits",
)
return session.url
@staticmethod
def time_now() -> datetime:
return datetime.now(timezone.utc)
@@ -245,18 +167,10 @@ class UserCreditBase(ABC):
)
return transaction_balance, transaction_time
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=1, max=10),
reraise=True,
)
async def _enable_transaction(
self,
transaction_key: str,
user_id: str,
metadata: Json,
new_transaction_key: str | None = None,
self, transaction_key: str, user_id: str, metadata: Json
):
transaction = await CreditTransaction.prisma().find_first_or_raise(
where={"transactionKey": transaction_key, "userId": user_id}
)
@@ -274,7 +188,6 @@ class UserCreditBase(ABC):
}
},
data={
"transactionKey": new_transaction_key or transaction_key,
"isActive": True,
"runningBalance": user_balance + transaction.amount,
"createdAt": self.time_now(),
@@ -290,7 +203,6 @@ class UserCreditBase(ABC):
is_active: bool = True,
transaction_key: str | None = None,
ceiling_balance: int | None = None,
fail_insufficient_credits: bool = True,
metadata: Json = Json({}),
) -> tuple[int, str]:
"""
@@ -304,7 +216,6 @@ class UserCreditBase(ABC):
is_active (bool): Whether the transaction is active or needs to be manually activated through _enable_transaction.
transaction_key (str | None): The transaction key. Avoids adding transaction if the key already exists.
ceiling_balance (int | None): The ceiling balance. Avoids adding more credits if the balance is already above the ceiling.
fail_insufficient_credits (bool): Whether to fail if the user has insufficient credits.
metadata (Json): The metadata of the transaction.
Returns:
@@ -314,21 +225,15 @@ class UserCreditBase(ABC):
# Get latest balance snapshot
user_balance, _ = await self._get_credits(user_id)
if ceiling_balance and amount > 0 and user_balance >= ceiling_balance:
if ceiling_balance and user_balance >= ceiling_balance:
raise ValueError(
f"You already have enough balance of ${user_balance/100}, top-up is not required when you already have at least ${ceiling_balance/100}"
f"You already have enough balance for user {user_id}, balance: {user_balance}, ceiling: {ceiling_balance}"
)
if amount < 0 and user_balance + amount < 0:
if fail_insufficient_credits:
raise InsufficientBalanceError(
message=f"Insufficient balance of ${user_balance/100}, where this will cost ${abs(amount)/100}",
user_id=user_id,
balance=user_balance,
amount=amount,
)
amount = min(-user_balance, 0)
if amount < 0 and user_balance < abs(amount):
raise ValueError(
f"Insufficient balance of ${user_balance/100} to run the block that costs ${abs(amount)/100}"
)
# Create the transaction
transaction_data: CreditTransactionCreateInput = {
@@ -346,41 +251,101 @@ class UserCreditBase(ABC):
return user_balance + amount, tx.transactionKey
class UserCredit(UserCreditBase):
@thread_cached
def notification_client(self) -> NotificationManager:
return get_service_client(NotificationManager)
class UsageTransactionMetadata(BaseModel):
graph_exec_id: str | None = None
graph_id: str | None = None
node_id: str | None = None
node_exec_id: str | None = None
block_id: str | None = None
block: str | None = None
input: BlockInput | None = None
async def _send_refund_notification(
class UserCredit(UserCreditBase):
def _block_usage_cost(
self,
notification_request: RefundRequestData,
notification_type: NotificationType,
):
await asyncio.to_thread(
lambda: self.notification_client().queue_notification(
NotificationEventDTO(
user_id=notification_request.user_id,
type=notification_type,
data=notification_request.model_dump(),
block: Block,
input_data: BlockInput,
data_size: float,
run_time: float,
) -> tuple[int, BlockInput]:
block_costs = BLOCK_COSTS.get(type(block))
if not block_costs:
return 0, {}
for block_cost in block_costs:
if not self._is_cost_filter_match(block_cost.cost_filter, input_data):
continue
if block_cost.cost_type == BlockCostType.RUN:
return block_cost.cost_amount, block_cost.cost_filter
if block_cost.cost_type == BlockCostType.SECOND:
return (
int(run_time * block_cost.cost_amount),
block_cost.cost_filter,
)
)
if block_cost.cost_type == BlockCostType.BYTE:
return (
int(data_size * block_cost.cost_amount),
block_cost.cost_filter,
)
return 0, {}
def _is_cost_filter_match(
self, cost_filter: BlockInput, input_data: BlockInput
) -> bool:
"""
Filter rules:
- If cost_filter is an object, then check if cost_filter is the subset of input_data
- Otherwise, check if cost_filter is equal to input_data.
- Undefined, null, and empty string are considered as equal.
"""
if not isinstance(cost_filter, dict) or not isinstance(input_data, dict):
return cost_filter == input_data
return all(
(not input_data.get(k) and not v)
or (input_data.get(k) and self._is_cost_filter_match(v, input_data[k]))
for k, v in cost_filter.items()
)
async def spend_credits(
self,
user_id: str,
cost: int,
metadata: UsageTransactionMetadata,
entry: NodeExecutionEntry,
data_size: float,
run_time: float,
) -> int:
block = get_block(entry.block_id)
if not block:
raise ValueError(f"Block not found: {entry.block_id}")
cost, matching_filter = self._block_usage_cost(
block=block, input_data=entry.data, data_size=data_size, run_time=run_time
)
if cost == 0:
return 0
balance, _ = await self._add_transaction(
user_id=user_id,
user_id=entry.user_id,
amount=-cost,
transaction_type=CreditTransactionType.USAGE,
metadata=Json(metadata.model_dump()),
metadata=Json(
UsageTransactionMetadata(
graph_exec_id=entry.graph_exec_id,
graph_id=entry.graph_id,
node_id=entry.node_id,
node_exec_id=entry.node_exec_id,
block_id=entry.block_id,
block=block.name,
input=matching_filter,
).model_dump()
),
)
user_id = entry.user_id
# Auto top-up if balance is below threshold.
auto_top_up = await get_auto_top_up(user_id)
@@ -390,7 +355,7 @@ class UserCredit(UserCreditBase):
user_id=user_id,
amount=auto_top_up.amount,
# Avoid multiple auto top-ups within the same graph execution.
key=f"AUTO-TOP-UP-{user_id}-{metadata.graph_exec_id}",
key=f"AUTO-TOP-UP-{user_id}-{entry.graph_exec_id}",
ceiling_balance=auto_top_up.threshold,
)
except Exception as e:
@@ -399,187 +364,11 @@ class UserCredit(UserCreditBase):
f"Auto top-up failed for user {user_id}, balance: {balance}, amount: {auto_top_up.amount}, error: {e}"
)
return balance
return cost
async def top_up_credits(self, user_id: str, amount: int):
await self._top_up_credits(user_id, amount)
async def top_up_refund(
self, user_id: str, transaction_key: str, metadata: dict[str, str]
) -> int:
transaction = await CreditTransaction.prisma().find_first_or_raise(
where={
"transactionKey": transaction_key,
"userId": user_id,
"isActive": True,
"type": CreditTransactionType.TOP_UP,
}
)
balance = await self.get_credits(user_id)
amount = transaction.amount
refund_key_format = settings.config.refund_request_time_key_format
refund_key = f"{transaction.createdAt.strftime(refund_key_format)}-{user_id}"
try:
refund_request = await CreditRefundRequest.prisma().create(
data={
"id": refund_key,
"transactionKey": transaction_key,
"userId": user_id,
"amount": amount,
"reason": metadata.get("reason", ""),
"status": CreditRefundRequestStatus.PENDING,
"result": "The refund request is under review.",
}
)
except UniqueViolationError:
raise ValueError(
"Unable to request a refund for this transaction, the request of the top-up transaction within the same week has already been made."
)
if amount - balance > settings.config.refund_credit_tolerance_threshold:
user_data = await get_user_by_id(user_id)
await self._send_refund_notification(
RefundRequestData(
user_id=user_id,
user_name=user_data.name or "AutoGPT Platform User",
user_email=user_data.email,
transaction_id=transaction_key,
refund_request_id=refund_request.id,
reason=refund_request.reason,
amount=amount,
balance=balance,
),
NotificationType.REFUND_REQUEST,
)
return 0 # Register the refund request for manual approval.
# Auto refund the top-up.
refund = stripe.Refund.create(payment_intent=transaction_key, metadata=metadata)
return refund.amount
async def deduct_credits(self, request: stripe.Refund | stripe.Dispute):
if isinstance(request, stripe.Refund) and request.status != "succeeded":
logger.warning(
f"Skip processing refund #{request.id} with status {request.status}"
)
return
if isinstance(request, stripe.Dispute) and request.status != "lost":
logger.warning(
f"Skip processing dispute #{request.id} with status {request.status}"
)
return
transaction = await CreditTransaction.prisma().find_first_or_raise(
where={
"transactionKey": str(request.payment_intent),
"isActive": True,
"type": CreditTransactionType.TOP_UP,
}
)
if request.amount <= 0 or request.amount > transaction.amount:
raise AssertionError(
f"Invalid amount to deduct ${request.amount/100} from ${transaction.amount/100} top-up"
)
balance, _ = await self._add_transaction(
user_id=transaction.userId,
amount=-request.amount,
transaction_type=CreditTransactionType.REFUND,
transaction_key=request.id,
metadata=Json(request),
fail_insufficient_credits=False,
)
# Update the result of the refund request if it exists.
await CreditRefundRequest.prisma().update_many(
where={
"userId": transaction.userId,
"transactionKey": transaction.transactionKey,
},
data={
"amount": request.amount,
"status": CreditRefundRequestStatus.APPROVED,
"result": "The refund request has been approved, the amount will be credited back to your account.",
},
)
user_data = await get_user_by_id(transaction.userId)
await self._send_refund_notification(
RefundRequestData(
user_id=user_data.id,
user_name=user_data.name or "AutoGPT Platform User",
user_email=user_data.email,
transaction_id=transaction.transactionKey,
refund_request_id=request.id,
reason=str(request.reason or "-"),
amount=transaction.amount,
balance=balance,
),
NotificationType.REFUND_PROCESSED,
)
async def handle_dispute(self, dispute: stripe.Dispute):
transaction = await CreditTransaction.prisma().find_first_or_raise(
where={
"transactionKey": str(dispute.payment_intent),
"isActive": True,
"type": CreditTransactionType.TOP_UP,
}
)
user_id = transaction.userId
amount = dispute.amount
balance = await self.get_credits(user_id)
# If the user has enough balance, just let them win the dispute.
if balance - amount >= settings.config.refund_credit_tolerance_threshold:
logger.warning(f"Accepting dispute from {user_id} for ${amount/100}")
dispute.close()
return
logger.warning(
f"Adding extra info for dispute from {user_id} for ${amount/100}"
)
# Retrieve recent transaction history to support our evidence.
# This provides a concise timeline that shows service usage and proper credit application.
transaction_history = await self.get_transaction_history(
user_id, transaction_count_limit=None
)
user = await get_user_by_id(user_id)
# Build a comprehensive explanation message that includes:
# - Confirmation that the top-up transaction was processed and credits were applied.
# - A summary of recent transaction history.
# - An explanation that the funds were used to render the agreed service.
evidence_text = (
f"The top-up transaction of ${transaction.amount / 100:.2f} was processed successfully, and the corresponding credits "
"were applied to the users account. Our records confirm that the funds were utilized for the intended services. "
"Below is a summary of recent transaction activity:\n"
)
for tx in transaction_history.transactions:
if tx.transaction_key == transaction.transactionKey:
additional_comment = (
" [This top-up transaction is the subject of the dispute]."
)
else:
additional_comment = ""
evidence_text += (
f"- {tx.description}: Amount ${tx.amount / 100:.2f} on {tx.transaction_time.isoformat()}, "
f"resulting balance ${tx.balance / 100:.2f} {additional_comment}\n"
)
evidence_text += (
"\nThis evidence demonstrates that the transaction was authorized and that the charged amount was used to render the service as agreed."
"\nAdditionally, we provide an automated refund functionality, so the user could have used it if they were not satisfied with the service. "
)
evidence: stripe.Dispute.ModifyParamsEvidence = {
"product_description": "AutoGPT Platform Credits",
"customer_email_address": user.email,
"uncategorized_text": evidence_text[:20000],
}
stripe.Dispute.modify(dispute.id, evidence=evidence)
async def _top_up_credits(
self,
user_id: str,
@@ -597,15 +386,10 @@ class UserCredit(UserCreditBase):
):
raise ValueError(f"Transaction key {key} already exists for user {user_id}")
if amount == 0:
transaction_type = CreditTransactionType.CARD_CHECK
else:
transaction_type = CreditTransactionType.TOP_UP
_, transaction_key = await self._add_transaction(
user_id=user_id,
amount=amount,
transaction_type=transaction_type,
transaction_type=CreditTransactionType.TOP_UP,
is_active=False,
transaction_key=key,
ceiling_balance=ceiling_balance,
@@ -617,10 +401,8 @@ class UserCredit(UserCreditBase):
if not payment_methods:
raise ValueError("No payment method found, please add it on the platform.")
successful_transaction = None
new_transaction_key = None
for payment_method in payment_methods:
if transaction_type == CreditTransactionType.CARD_CHECK:
if amount == 0:
setup_intent = stripe.SetupIntent.create(
customer=customer_id,
usage="off_session",
@@ -632,9 +414,8 @@ class UserCredit(UserCreditBase):
},
)
if setup_intent.status == "succeeded":
successful_transaction = Json({"setup_intent": setup_intent})
new_transaction_key = setup_intent.id
break
return
else:
payment_intent = stripe.PaymentIntent.create(
amount=amount,
@@ -650,20 +431,15 @@ class UserCredit(UserCreditBase):
},
)
if payment_intent.status == "succeeded":
successful_transaction = Json({"payment_intent": payment_intent})
new_transaction_key = payment_intent.id
break
await self._enable_transaction(
transaction_key=transaction_key,
user_id=user_id,
metadata=Json({"payment_intent": payment_intent}),
)
return
if not successful_transaction:
raise ValueError(
f"Out of {len(payment_methods)} payment methods tried, none is supported"
)
await self._enable_transaction(
transaction_key=transaction_key,
new_transaction_key=new_transaction_key,
user_id=user_id,
metadata=successful_transaction,
raise ValueError(
f"Out of {len(payment_methods)} payment methods tried, none is supported"
)
async def top_up_intent(self, user_id: str, amount: int) -> str:
@@ -694,8 +470,10 @@ class UserCredit(UserCreditBase):
ui_mode="hosted",
payment_intent_data={"setup_future_usage": "off_session"},
saved_payment_method_options={"payment_method_save": "enabled"},
success_url=base_url + "/profile/credits?topup=success",
cancel_url=base_url + "/profile/credits?topup=cancel",
success_url=settings.config.frontend_base_url
+ "/marketplace/credits?topup=success",
cancel_url=settings.config.frontend_base_url
+ "/marketplace/credits?topup=cancel",
allow_promotion_codes=True,
)
@@ -705,7 +483,7 @@ class UserCredit(UserCreditBase):
transaction_type=CreditTransactionType.TOP_UP,
transaction_key=checkout_session.id,
is_active=False,
metadata=Json(checkout_session),
metadata=Json({"checkout_session": checkout_session}),
)
return checkout_session.url or ""
@@ -721,7 +499,6 @@ class UserCredit(UserCreditBase):
find_filter: CreditTransactionWhereInput = {
"type": CreditTransactionType.TOP_UP,
"isActive": False,
"amount": {"gt": 0},
}
if session_id:
find_filter["transactionKey"] = session_id
@@ -738,25 +515,18 @@ class UserCredit(UserCreditBase):
if not credit_transaction:
return
# If the transaction is not a checkout session, then skip the fulfillment
if not credit_transaction.transactionKey.startswith("cs_"):
return
# Retrieve the Checkout Session from the API
checkout_session = stripe.checkout.Session.retrieve(
credit_transaction.transactionKey,
expand=["payment_intent"],
credit_transaction.transactionKey
)
# Check the Checkout Session's payment_status property
# to determine if fulfillment should be performed
if checkout_session.payment_status in ["paid", "no_payment_required"]:
assert isinstance(checkout_session.payment_intent, stripe.PaymentIntent)
await self._enable_transaction(
transaction_key=credit_transaction.transactionKey,
new_transaction_key=checkout_session.payment_intent.id,
user_id=credit_transaction.userId,
metadata=Json(checkout_session),
metadata=Json({"checkout_session": checkout_session}),
)
async def get_credits(self, user_id: str) -> int:
@@ -766,23 +536,15 @@ class UserCredit(UserCreditBase):
async def get_transaction_history(
self,
user_id: str,
transaction_count_limit: int | None = 100,
transaction_time_ceiling: datetime | None = None,
transaction_type: str | None = None,
transaction_time: datetime,
transaction_count_limit: int,
) -> TransactionHistory:
transactions_filter: CreditTransactionWhereInput = {
"userId": user_id,
"isActive": True,
}
if transaction_time_ceiling:
transaction_time_ceiling = transaction_time_ceiling.replace(
tzinfo=timezone.utc
)
transactions_filter["createdAt"] = {"lt": transaction_time_ceiling}
if transaction_type:
transactions_filter["type"] = CreditTransactionType[transaction_type]
transactions = await CreditTransaction.prisma().find_many(
where=transactions_filter,
where={
"userId": user_id,
"createdAt": {"lt": transaction_time},
"isActive": True,
},
order={"createdAt": "desc"},
take=transaction_count_limit,
)
@@ -797,7 +559,7 @@ class UserCredit(UserCreditBase):
if t.metadata
else UsageTransactionMetadata()
)
tx_time = t.createdAt.replace(tzinfo=timezone.utc)
tx_time = t.createdAt.replace(tzinfo=None)
if t.type == CreditTransactionType.USAGE and metadata.graph_exec_id:
gt = grouped_transactions[metadata.graph_exec_id]
@@ -811,7 +573,6 @@ class UserCredit(UserCreditBase):
else:
gt = grouped_transactions[t.transactionKey]
gt.description = f"{t.type} Transaction"
gt.transaction_key = t.transactionKey
gt.amount += t.amount
gt.transaction_type = t.type
@@ -827,25 +588,6 @@ class UserCredit(UserCreditBase):
),
)
async def get_refund_requests(self, user_id: str) -> list[RefundRequest]:
return [
RefundRequest(
id=r.id,
user_id=r.userId,
transaction_key=r.transactionKey,
amount=r.amount,
reason=r.reason,
result=r.result,
status=r.status,
created_at=r.createdAt,
updated_at=r.updatedAt,
)
for r in await CreditRefundRequest.prisma().find_many(
where={"userId": user_id},
order={"createdAt": "desc"},
)
]
class BetaUserCredit(UserCredit):
"""
@@ -866,7 +608,7 @@ class BetaUserCredit(UserCredit):
balance, _ = await self._add_transaction(
user_id=user_id,
amount=max(self.num_user_credits_refill - balance, 0),
transaction_type=CreditTransactionType.GRANT,
transaction_type=CreditTransactionType.TOP_UP,
transaction_key=f"MONTHLY-CREDIT-TOP-UP-{cur_time}",
)
return balance
@@ -882,9 +624,6 @@ class DisabledUserCredit(UserCreditBase):
async def get_transaction_history(self, *args, **kwargs) -> TransactionHistory:
return TransactionHistory(transactions=[], next_transaction_time=None)
async def get_refund_requests(self, *args, **kwargs) -> list[RefundRequest]:
return []
async def spend_credits(self, *args, **kwargs) -> int:
return 0
@@ -894,15 +633,6 @@ class DisabledUserCredit(UserCreditBase):
async def top_up_intent(self, *args, **kwargs) -> str:
return ""
async def top_up_refund(self, *args, **kwargs) -> int:
return 0
async def deduct_credits(self, *args, **kwargs):
pass
async def handle_dispute(self, *args, **kwargs):
pass
async def fulfill_checkout(self, *args, **kwargs):
pass
@@ -927,11 +657,7 @@ async def get_stripe_customer_id(user_id: str) -> str:
if user.stripeCustomerId:
return user.stripeCustomerId
customer = stripe.Customer.create(
name=user.name or "",
email=user.email,
metadata={"user_id": user_id},
)
customer = stripe.Customer.create(name=user.name or "", email=user.email)
await User.prisma().update(
where={"id": user_id}, data={"stripeCustomerId": customer.id}
)

View File

@@ -2,7 +2,6 @@ import logging
import os
import zlib
from contextlib import asynccontextmanager
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
from uuid import uuid4
from dotenv import load_dotenv
@@ -16,36 +15,7 @@ load_dotenv()
PRISMA_SCHEMA = os.getenv("PRISMA_SCHEMA", "schema.prisma")
os.environ["PRISMA_SCHEMA_PATH"] = PRISMA_SCHEMA
def add_param(url: str, key: str, value: str) -> str:
p = urlparse(url)
qs = dict(parse_qsl(p.query))
qs[key] = value
return urlunparse(p._replace(query=urlencode(qs)))
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
CONN_LIMIT = os.getenv("DB_CONNECTION_LIMIT")
if CONN_LIMIT:
DATABASE_URL = add_param(DATABASE_URL, "connection_limit", CONN_LIMIT)
CONN_TIMEOUT = os.getenv("DB_CONNECT_TIMEOUT")
if CONN_TIMEOUT:
DATABASE_URL = add_param(DATABASE_URL, "connect_timeout", CONN_TIMEOUT)
POOL_TIMEOUT = os.getenv("DB_POOL_TIMEOUT")
if POOL_TIMEOUT:
DATABASE_URL = add_param(DATABASE_URL, "pool_timeout", POOL_TIMEOUT)
HTTP_TIMEOUT = int(POOL_TIMEOUT) if POOL_TIMEOUT else None
prisma = Prisma(
auto_register=True,
http={"timeout": HTTP_TIMEOUT},
datasource={"url": DATABASE_URL},
)
prisma = Prisma(auto_register=True)
logger = logging.getLogger(__name__)

View File

@@ -1,199 +1,70 @@
import logging
from collections import defaultdict
from datetime import datetime, timezone
from enum import Enum
from multiprocessing import Manager
from typing import (
Annotated,
Any,
AsyncGenerator,
Generator,
Generic,
Literal,
Optional,
TypeVar,
overload,
)
from typing import Any, AsyncGenerator, Generator, Generic, Optional, Type, TypeVar
from prisma import Json
from prisma.enums import AgentExecutionStatus
from prisma.errors import PrismaError
from prisma.models import (
AgentGraphExecution,
AgentNodeExecution,
AgentNodeExecutionInputOutput,
)
from prisma.types import (
AgentGraphExecutionWhereInput,
AgentNodeExecutionUpdateInput,
AgentNodeExecutionWhereInput,
)
from pydantic import BaseModel
from pydantic.fields import Field
from backend.data.block import BlockData, BlockInput, CompletedBlockOutput
from backend.data.includes import EXECUTION_RESULT_INCLUDE, GRAPH_EXECUTION_INCLUDE
from backend.data.queue import AsyncRedisEventBus, RedisEventBus
from backend.server.v2.store.exceptions import DatabaseError
from backend.util import mock
from backend.util import type as type_utils
from backend.util import mock, type
from backend.util.settings import Config
from .block import BlockData, BlockInput, BlockType, CompletedBlockOutput, get_block
from .db import BaseDbModel
from .includes import (
EXECUTION_RESULT_INCLUDE,
GRAPH_EXECUTION_INCLUDE,
GRAPH_EXECUTION_INCLUDE_WITH_NODES,
)
from .model import GraphExecutionStats, NodeExecutionStats
from .queue import AsyncRedisEventBus, RedisEventBus
T = TypeVar("T")
logger = logging.getLogger(__name__)
config = Config()
class GraphExecutionEntry(BaseModel):
user_id: str
graph_exec_id: str
graph_id: str
graph_version: int
start_node_execs: list["NodeExecutionEntry"]
# -------------------------- Models -------------------------- #
class NodeExecutionEntry(BaseModel):
user_id: str
graph_exec_id: str
graph_id: str
node_exec_id: str
node_id: str
block_id: str
data: BlockInput
ExecutionStatus = AgentExecutionStatus
class GraphExecutionMeta(BaseDbModel):
user_id: str
started_at: datetime
ended_at: datetime
cost: Optional[int] = Field(..., description="Execution cost in credits")
duration: float = Field(..., description="Seconds from start to end of run")
total_run_time: float = Field(..., description="Seconds of node runtime")
status: ExecutionStatus
graph_id: str
graph_version: int
preset_id: Optional[str] = None
@staticmethod
def from_db(_graph_exec: AgentGraphExecution):
now = datetime.now(timezone.utc)
start_time = _graph_exec.startedAt or _graph_exec.createdAt
end_time = _graph_exec.updatedAt or now
duration = (end_time - start_time).total_seconds()
total_run_time = duration
try:
stats = GraphExecutionStats.model_validate(_graph_exec.stats)
except ValueError as e:
if _graph_exec.stats is not None:
logger.warning(
"Failed to parse invalid graph execution stats "
f"{_graph_exec.stats}: {e}"
)
stats = None
duration = stats.walltime if stats else duration
total_run_time = stats.nodes_walltime if stats else total_run_time
return GraphExecutionMeta(
id=_graph_exec.id,
user_id=_graph_exec.userId,
started_at=start_time,
ended_at=end_time,
cost=stats.cost if stats else None,
duration=duration,
total_run_time=total_run_time,
status=ExecutionStatus(_graph_exec.executionStatus),
graph_id=_graph_exec.agentGraphId,
graph_version=_graph_exec.agentGraphVersion,
preset_id=_graph_exec.agentPresetId,
)
T = TypeVar("T")
class GraphExecution(GraphExecutionMeta):
inputs: BlockInput
outputs: CompletedBlockOutput
class ExecutionQueue(Generic[T]):
"""
Queue for managing the execution of agents.
This will be shared between different processes
"""
@staticmethod
def from_db(_graph_exec: AgentGraphExecution):
if _graph_exec.AgentNodeExecutions is None:
raise ValueError("Node executions must be included in query")
def __init__(self):
self.queue = Manager().Queue()
graph_exec = GraphExecutionMeta.from_db(_graph_exec)
def add(self, execution: T) -> T:
self.queue.put(execution)
return execution
node_executions = sorted(
[
NodeExecutionResult.from_db(ne, _graph_exec.userId)
for ne in _graph_exec.AgentNodeExecutions
],
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
)
def get(self) -> T:
return self.queue.get()
inputs = {
**{
# inputs from Agent Input Blocks
exec.input_data["name"]: exec.input_data.get("value")
for exec in node_executions
if (
(block := get_block(exec.block_id))
and block.block_type == BlockType.INPUT
)
},
**{
# input from webhook-triggered block
"payload": exec.input_data["payload"]
for exec in node_executions
if (
(block := get_block(exec.block_id))
and block.block_type
in [BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL]
)
},
}
outputs: CompletedBlockOutput = defaultdict(list)
for exec in node_executions:
if (
block := get_block(exec.block_id)
) and block.block_type == BlockType.OUTPUT:
outputs[exec.input_data["name"]].append(
exec.input_data.get("value", None)
)
return GraphExecution(
**{
field_name: getattr(graph_exec, field_name)
for field_name in graph_exec.model_fields
},
inputs=inputs,
outputs=outputs,
)
def empty(self) -> bool:
return self.queue.empty()
class GraphExecutionWithNodes(GraphExecution):
node_executions: list["NodeExecutionResult"]
@staticmethod
def from_db(_graph_exec: AgentGraphExecution):
if _graph_exec.AgentNodeExecutions is None:
raise ValueError("Node executions must be included in query")
graph_exec_with_io = GraphExecution.from_db(_graph_exec)
node_executions = sorted(
[
NodeExecutionResult.from_db(ne, _graph_exec.userId)
for ne in _graph_exec.AgentNodeExecutions
],
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
)
return GraphExecutionWithNodes(
**{
field_name: getattr(graph_exec_with_io, field_name)
for field_name in graph_exec_with_io.model_fields
},
node_executions=node_executions,
)
class NodeExecutionResult(BaseModel):
user_id: str
class ExecutionResult(BaseModel):
graph_id: str
graph_version: int
graph_exec_id: str
@@ -209,30 +80,43 @@ class NodeExecutionResult(BaseModel):
end_time: datetime | None
@staticmethod
def from_db(execution: AgentNodeExecution, user_id: Optional[str] = None):
def from_graph(graph: AgentGraphExecution):
return ExecutionResult(
graph_id=graph.agentGraphId,
graph_version=graph.agentGraphVersion,
graph_exec_id=graph.id,
node_exec_id="",
node_id="",
block_id="",
status=graph.executionStatus,
# TODO: Populate input_data & output_data from AgentNodeExecutions
# Input & Output comes AgentInputBlock & AgentOutputBlock.
input_data={},
output_data={},
add_time=graph.createdAt,
queue_time=graph.createdAt,
start_time=graph.startedAt,
end_time=graph.updatedAt,
)
@staticmethod
def from_db(execution: AgentNodeExecution):
if execution.executionData:
# Execution that has been queued for execution will persist its data.
input_data = type_utils.convert(execution.executionData, dict[str, Any])
input_data = type.convert(execution.executionData, dict[str, Any])
else:
# For incomplete execution, executionData will not be yet available.
input_data: BlockInput = defaultdict()
for data in execution.Input or []:
input_data[data.name] = type_utils.convert(data.data, type[Any])
input_data[data.name] = type.convert(data.data, Type[Any])
output_data: CompletedBlockOutput = defaultdict(list)
for data in execution.Output or []:
output_data[data.name].append(type_utils.convert(data.data, type[Any]))
output_data[data.name].append(type.convert(data.data, Type[Any]))
graph_execution: AgentGraphExecution | None = execution.AgentGraphExecution
if graph_execution:
user_id = graph_execution.userId
elif not user_id:
raise ValueError(
"AgentGraphExecution must be included or user_id passed in"
)
return NodeExecutionResult(
user_id=user_id,
return ExecutionResult(
graph_id=graph_execution.agentGraphId if graph_execution else "",
graph_version=graph_execution.agentGraphVersion if graph_execution else 0,
graph_exec_id=execution.agentGraphExecutionId,
@@ -252,88 +136,13 @@ class NodeExecutionResult(BaseModel):
# --------------------- Model functions --------------------- #
async def get_graph_executions(
graph_id: Optional[str] = None,
user_id: Optional[str] = None,
) -> list[GraphExecutionMeta]:
where_filter: AgentGraphExecutionWhereInput = {
"isDeleted": False,
}
if user_id:
where_filter["userId"] = user_id
if graph_id:
where_filter["agentGraphId"] = graph_id
executions = await AgentGraphExecution.prisma().find_many(
where=where_filter,
order={"createdAt": "desc"},
)
return [GraphExecutionMeta.from_db(execution) for execution in executions]
async def get_graph_execution_meta(
user_id: str, execution_id: str
) -> GraphExecutionMeta | None:
execution = await AgentGraphExecution.prisma().find_first(
where={"id": execution_id, "isDeleted": False, "userId": user_id}
)
return GraphExecutionMeta.from_db(execution) if execution else None
@overload
async def get_graph_execution(
user_id: str,
execution_id: str,
include_node_executions: Literal[True],
) -> GraphExecutionWithNodes | None: ...
@overload
async def get_graph_execution(
user_id: str,
execution_id: str,
include_node_executions: Literal[False] = False,
) -> GraphExecution | None: ...
@overload
async def get_graph_execution(
user_id: str,
execution_id: str,
include_node_executions: bool = False,
) -> GraphExecution | GraphExecutionWithNodes | None: ...
async def get_graph_execution(
user_id: str,
execution_id: str,
include_node_executions: bool = False,
) -> GraphExecution | GraphExecutionWithNodes | None:
execution = await AgentGraphExecution.prisma().find_first(
where={"id": execution_id, "isDeleted": False, "userId": user_id},
include=(
GRAPH_EXECUTION_INCLUDE_WITH_NODES
if include_node_executions
else GRAPH_EXECUTION_INCLUDE
),
)
if not execution:
return None
return (
GraphExecutionWithNodes.from_db(execution)
if include_node_executions
else GraphExecution.from_db(execution)
)
async def create_graph_execution(
graph_id: str,
graph_version: int,
nodes_input: list[tuple[str, BlockInput]],
user_id: str,
preset_id: str | None = None,
) -> GraphExecutionWithNodes:
) -> tuple[str, list[ExecutionResult]]:
"""
Create a new AgentGraphExecution record.
Returns:
@@ -348,8 +157,7 @@ async def create_graph_execution(
"create": [ # type: ignore
{
"agentNodeId": node_id,
"executionStatus": ExecutionStatus.QUEUED,
"queuedTime": datetime.now(tz=timezone.utc),
"executionStatus": ExecutionStatus.INCOMPLETE,
"Input": {
"create": [
{"name": name, "data": Json(data)}
@@ -363,10 +171,13 @@ async def create_graph_execution(
"userId": user_id,
"agentPresetId": preset_id,
},
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
include=GRAPH_EXECUTION_INCLUDE,
)
return GraphExecutionWithNodes.from_db(result)
return result.id, [
ExecutionResult.from_db(execution)
for execution in result.AgentNodeExecutions or []
]
async def upsert_execution_input(
@@ -388,20 +199,17 @@ async def upsert_execution_input(
node_exec_id: [Optional] The id of the AgentNodeExecution that has no `input_name` as input. If not provided, it will find the eligible incomplete AgentNodeExecution or create a new one.
Returns:
str: The id of the created or existing AgentNodeExecution.
dict[str, Any]: Node input data; key is the input name, value is the input data.
* The id of the created or existing AgentNodeExecution.
* Dict of node input data, key is the input name, value is the input data.
"""
existing_exec_query_filter: AgentNodeExecutionWhereInput = {
"agentNodeId": node_id,
"agentGraphExecutionId": graph_exec_id,
"executionStatus": ExecutionStatus.INCOMPLETE,
"Input": {"every": {"name": {"not": input_name}}},
}
if node_exec_id:
existing_exec_query_filter["id"] = node_exec_id
existing_execution = await AgentNodeExecution.prisma().find_first(
where=existing_exec_query_filter,
where={ # type: ignore
**({"id": node_exec_id} if node_exec_id else {}),
"agentNodeId": node_id,
"agentGraphExecutionId": graph_exec_id,
"executionStatus": ExecutionStatus.INCOMPLETE,
"Input": {"every": {"name": {"not": input_name}}},
},
order={"addedTime": "asc"},
include={"Input": True},
)
@@ -417,7 +225,7 @@ async def upsert_execution_input(
)
return existing_execution.id, {
**{
input_data.name: type_utils.convert(input_data.data, type[Any])
input_data.name: type.convert(input_data.data, Type[Any])
for input_data in existing_execution.Input or []
},
input_name: input_data,
@@ -457,289 +265,141 @@ async def upsert_execution_output(
)
async def update_graph_execution_start_time(graph_exec_id: str) -> GraphExecution:
res = await AgentGraphExecution.prisma().update(
async def update_graph_execution_start_time(graph_exec_id: str):
await AgentGraphExecution.prisma().update(
where={"id": graph_exec_id},
data={
"executionStatus": ExecutionStatus.RUNNING,
"startedAt": datetime.now(tz=timezone.utc),
},
include=GRAPH_EXECUTION_INCLUDE,
)
if not res:
raise ValueError(f"Graph execution #{graph_exec_id} not found")
return GraphExecution.from_db(res)
async def update_graph_execution_stats(
graph_exec_id: str,
status: ExecutionStatus,
stats: GraphExecutionStats | None = None,
) -> GraphExecution | None:
data = stats.model_dump() if stats else {}
if isinstance(data.get("error"), Exception):
data["error"] = str(data["error"])
stats: dict[str, Any],
) -> ExecutionResult:
res = await AgentGraphExecution.prisma().update(
where={
"id": graph_exec_id,
"OR": [
{"executionStatus": ExecutionStatus.RUNNING},
{"executionStatus": ExecutionStatus.QUEUED},
],
},
where={"id": graph_exec_id},
data={
"executionStatus": status,
"stats": Json(data),
"stats": Json(stats),
},
include=GRAPH_EXECUTION_INCLUDE,
)
if not res:
raise ValueError(f"Execution {graph_exec_id} not found.")
return GraphExecution.from_db(res) if res else None
return ExecutionResult.from_graph(res)
async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionStats):
data = stats.model_dump()
if isinstance(data["error"], Exception):
data["error"] = str(data["error"])
async def update_node_execution_stats(node_exec_id: str, stats: dict[str, Any]):
await AgentNodeExecution.prisma().update(
where={"id": node_exec_id},
data={"stats": Json(data)},
data={"stats": Json(stats)},
)
async def update_node_execution_status_batch(
node_exec_ids: list[str],
status: ExecutionStatus,
stats: dict[str, Any] | None = None,
):
await AgentNodeExecution.prisma().update_many(
where={"id": {"in": node_exec_ids}},
data=_get_update_status_data(status, None, stats),
)
async def update_node_execution_status(
async def update_execution_status(
node_exec_id: str,
status: ExecutionStatus,
execution_data: BlockInput | None = None,
stats: dict[str, Any] | None = None,
) -> NodeExecutionResult:
) -> ExecutionResult:
if status == ExecutionStatus.QUEUED and execution_data is None:
raise ValueError("Execution data must be provided when queuing an execution.")
now = datetime.now(tz=timezone.utc)
data = {
**({"executionStatus": status}),
**({"queuedTime": now} if status == ExecutionStatus.QUEUED else {}),
**({"startedTime": now} if status == ExecutionStatus.RUNNING else {}),
**({"endedTime": now} if status == ExecutionStatus.FAILED else {}),
**({"endedTime": now} if status == ExecutionStatus.COMPLETED else {}),
**({"executionData": Json(execution_data)} if execution_data else {}),
**({"stats": Json(stats)} if stats else {}),
}
res = await AgentNodeExecution.prisma().update(
where={"id": node_exec_id},
data=_get_update_status_data(status, execution_data, stats),
data=data, # type: ignore
include=EXECUTION_RESULT_INCLUDE,
)
if not res:
raise ValueError(f"Execution {node_exec_id} not found.")
return NodeExecutionResult.from_db(res)
return ExecutionResult.from_db(res)
def _get_update_status_data(
status: ExecutionStatus,
execution_data: BlockInput | None = None,
stats: dict[str, Any] | None = None,
) -> AgentNodeExecutionUpdateInput:
now = datetime.now(tz=timezone.utc)
update_data: AgentNodeExecutionUpdateInput = {"executionStatus": status}
async def get_execution(
execution_id: str, user_id: str
) -> Optional[AgentNodeExecution]:
"""
Get an execution by ID. Returns None if not found.
if status == ExecutionStatus.QUEUED:
update_data["queuedTime"] = now
elif status == ExecutionStatus.RUNNING:
update_data["startedTime"] = now
elif status in (ExecutionStatus.FAILED, ExecutionStatus.COMPLETED):
update_data["endedTime"] = now
Args:
execution_id: The ID of the execution to retrieve
if execution_data:
update_data["executionData"] = Json(execution_data)
if stats:
update_data["stats"] = Json(stats)
return update_data
async def delete_graph_execution(
graph_exec_id: str, user_id: str, soft_delete: bool = True
) -> None:
if soft_delete:
deleted_count = await AgentGraphExecution.prisma().update_many(
where={"id": graph_exec_id, "userId": user_id}, data={"isDeleted": True}
)
else:
deleted_count = await AgentGraphExecution.prisma().delete_many(
where={"id": graph_exec_id, "userId": user_id}
)
if deleted_count < 1:
raise DatabaseError(
f"Could not delete graph execution #{graph_exec_id}: not found"
Returns:
The execution if found, None otherwise
"""
try:
execution = await AgentNodeExecution.prisma().find_unique(
where={
"id": execution_id,
"userId": user_id,
}
)
return execution
except PrismaError:
return None
async def get_node_execution_results(
graph_exec_id: str,
block_ids: list[str] | None = None,
statuses: list[ExecutionStatus] | None = None,
limit: int | None = None,
) -> list[NodeExecutionResult]:
where_clause: AgentNodeExecutionWhereInput = {
"agentGraphExecutionId": graph_exec_id,
}
if block_ids:
where_clause["AgentNode"] = {"is": {"agentBlockId": {"in": block_ids}}}
if statuses:
where_clause["OR"] = [{"executionStatus": status} for status in statuses]
async def get_execution_results(graph_exec_id: str) -> list[ExecutionResult]:
executions = await AgentNodeExecution.prisma().find_many(
where=where_clause,
where={"agentGraphExecutionId": graph_exec_id},
include=EXECUTION_RESULT_INCLUDE,
take=limit,
order=[
{"queuedTime": "asc"},
{"addedTime": "asc"}, # Fallback: Incomplete execs has no queuedTime.
],
)
res = [NodeExecutionResult.from_db(execution) for execution in executions]
res = [ExecutionResult.from_db(execution) for execution in executions]
return res
async def get_graph_executions_in_timerange(
async def get_executions_in_timerange(
user_id: str, start_time: str, end_time: str
) -> list[GraphExecution]:
) -> list[ExecutionResult]:
try:
executions = await AgentGraphExecution.prisma().find_many(
where={
"startedAt": {
"gte": datetime.fromisoformat(start_time),
"lte": datetime.fromisoformat(end_time),
},
"userId": user_id,
"isDeleted": False,
"AND": [
{
"startedAt": {
"gte": datetime.fromisoformat(start_time),
"lte": datetime.fromisoformat(end_time),
}
},
{"userId": user_id},
]
},
include=GRAPH_EXECUTION_INCLUDE,
)
return [GraphExecution.from_db(execution) for execution in executions]
return [ExecutionResult.from_graph(execution) for execution in executions]
except Exception as e:
raise DatabaseError(
f"Failed to get executions in timerange {start_time} to {end_time} for user {user_id}: {e}"
) from e
async def get_latest_node_execution(
node_id: str, graph_eid: str
) -> NodeExecutionResult | None:
execution = await AgentNodeExecution.prisma().find_first(
where={
"agentNodeId": node_id,
"agentGraphExecutionId": graph_eid,
"executionStatus": {"not": ExecutionStatus.INCOMPLETE}, # type: ignore
},
order=[
{"queuedTime": "desc"},
{"addedTime": "desc"},
],
include=EXECUTION_RESULT_INCLUDE,
)
if not execution:
return None
return NodeExecutionResult.from_db(execution)
async def get_incomplete_node_executions(
node_id: str, graph_eid: str
) -> list[NodeExecutionResult]:
executions = await AgentNodeExecution.prisma().find_many(
where={
"agentNodeId": node_id,
"agentGraphExecutionId": graph_eid,
"executionStatus": ExecutionStatus.INCOMPLETE,
},
include=EXECUTION_RESULT_INCLUDE,
)
return [NodeExecutionResult.from_db(execution) for execution in executions]
# ----------------- Execution Infrastructure ----------------- #
class GraphExecutionEntry(BaseModel):
user_id: str
graph_exec_id: str
graph_id: str
graph_version: int
start_node_execs: list["NodeExecutionEntry"]
class NodeExecutionEntry(BaseModel):
user_id: str
graph_exec_id: str
graph_id: str
node_exec_id: str
node_id: str
block_id: str
data: BlockInput
class ExecutionQueue(Generic[T]):
"""
Queue for managing the execution of agents.
This will be shared between different processes
"""
def __init__(self):
self.queue = Manager().Queue()
def add(self, execution: T) -> T:
self.queue.put(execution)
return execution
def get(self) -> T:
return self.queue.get()
def empty(self) -> bool:
return self.queue.empty()
# ------------------- Execution Utilities -------------------- #
LIST_SPLIT = "_$_"
DICT_SPLIT = "_#_"
OBJC_SPLIT = "_@_"
def parse_execution_output(output: BlockData, name: str) -> Any | None:
"""
Extracts partial output data by name from a given BlockData.
The function supports extracting data from lists, dictionaries, and objects
using specific naming conventions:
- For lists: <output_name>_$_<index>
- For dictionaries: <output_name>_#_<key>
- For objects: <output_name>_@_<attribute>
Args:
output (BlockData): A tuple containing the output name and data.
name (str): The name used to extract specific data from the output.
Returns:
Any | None: The extracted data if found, otherwise None.
Examples:
>>> output = ("result", [10, 20, 30])
>>> parse_execution_output(output, "result_$_1")
20
>>> output = ("config", {"key1": "value1", "key2": "value2"})
>>> parse_execution_output(output, "config_#_key1")
'value1'
>>> class Sample:
... attr1 = "value1"
... attr2 = "value2"
>>> output = ("object", Sample())
>>> parse_execution_output(output, "object_@_attr1")
'value1'
"""
# Allow extracting partial output data by name.
output_name, output_data = output
if name == output_name:
@@ -768,37 +428,11 @@ def parse_execution_output(output: BlockData, name: str) -> Any | None:
def merge_execution_input(data: BlockInput) -> BlockInput:
"""
Merges dynamic input pins into a single list, dictionary, or object based on naming patterns.
This function processes input keys that follow specific patterns to merge them into a unified structure:
- `<input_name>_$_<index>` for list inputs.
- `<input_name>_#_<index>` for dictionary inputs.
- `<input_name>_@_<index>` for object inputs.
Args:
data (BlockInput): A dictionary containing input keys and their corresponding values.
Returns:
BlockInput: A dictionary with merged inputs.
Raises:
ValueError: If a list index is not an integer.
Examples:
>>> data = {
... "list_$_0": "a",
... "list_$_1": "b",
... "dict_#_key1": "value1",
... "dict_#_key2": "value2",
... "object_@_attr1": "value1",
... "object_@_attr2": "value2"
... }
>>> merge_execution_input(data)
{
"list": ["a", "b"],
"dict": {"key1": "value1", "key2": "value2"},
"object": <MockObject attr1="value1" attr2="value2">
}
Merge all dynamic input pins which described by the following pattern:
- <input_name>_$_<index> for list input.
- <input_name>_#_<index> for dict input.
- <input_name>_@_<index> for object input.
This function will construct pins with the same name into a single list/dict/object.
"""
# Merge all input with <input_name>_$_<index> into a single list.
@@ -837,84 +471,69 @@ def merge_execution_input(data: BlockInput) -> BlockInput:
return data
async def get_latest_execution(node_id: str, graph_eid: str) -> ExecutionResult | None:
execution = await AgentNodeExecution.prisma().find_first(
where={
"agentNodeId": node_id,
"agentGraphExecutionId": graph_eid,
"executionStatus": {"not": ExecutionStatus.INCOMPLETE}, # type: ignore
},
order={"queuedTime": "desc"},
include=EXECUTION_RESULT_INCLUDE,
)
if not execution:
return None
return ExecutionResult.from_db(execution)
async def get_incomplete_executions(
node_id: str, graph_eid: str
) -> list[ExecutionResult]:
executions = await AgentNodeExecution.prisma().find_many(
where={
"agentNodeId": node_id,
"agentGraphExecutionId": graph_eid,
"executionStatus": ExecutionStatus.INCOMPLETE,
},
include=EXECUTION_RESULT_INCLUDE,
)
return [ExecutionResult.from_db(execution) for execution in executions]
# --------------------- Event Bus --------------------- #
class ExecutionEventType(str, Enum):
GRAPH_EXEC_UPDATE = "graph_execution_update"
NODE_EXEC_UPDATE = "node_execution_update"
config = Config()
class GraphExecutionEvent(GraphExecution):
event_type: Literal[ExecutionEventType.GRAPH_EXEC_UPDATE] = (
ExecutionEventType.GRAPH_EXEC_UPDATE
)
class NodeExecutionEvent(NodeExecutionResult):
event_type: Literal[ExecutionEventType.NODE_EXEC_UPDATE] = (
ExecutionEventType.NODE_EXEC_UPDATE
)
ExecutionEvent = Annotated[
GraphExecutionEvent | NodeExecutionEvent, Field(discriminator="event_type")
]
class RedisExecutionEventBus(RedisEventBus[ExecutionEvent]):
Model = ExecutionEvent # type: ignore
class RedisExecutionEventBus(RedisEventBus[ExecutionResult]):
Model = ExecutionResult
@property
def event_bus_name(self) -> str:
return config.execution_event_bus_name
def publish(self, res: GraphExecution | NodeExecutionResult):
if isinstance(res, GraphExecution):
self.publish_graph_exec_update(res)
else:
self.publish_node_exec_update(res)
def publish_node_exec_update(self, res: NodeExecutionResult):
event = NodeExecutionEvent.model_validate(res.model_dump())
self.publish_event(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")
def publish_graph_exec_update(self, res: GraphExecution):
event = GraphExecutionEvent.model_validate(res.model_dump())
self.publish_event(event, f"{res.user_id}/{res.graph_id}/{res.id}")
def publish(self, res: ExecutionResult):
self.publish_event(res, f"{res.graph_id}/{res.graph_exec_id}")
def listen(
self, user_id: str, graph_id: str = "*", graph_exec_id: str = "*"
) -> Generator[ExecutionEvent, None, None]:
for event in self.listen_events(f"{user_id}/{graph_id}/{graph_exec_id}"):
yield event
self, graph_id: str = "*", graph_exec_id: str = "*"
) -> Generator[ExecutionResult, None, None]:
for execution_result in self.listen_events(f"{graph_id}/{graph_exec_id}"):
yield execution_result
class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionEvent]):
Model = ExecutionEvent # type: ignore
class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionResult]):
Model = ExecutionResult
@property
def event_bus_name(self) -> str:
return config.execution_event_bus_name
async def publish(self, res: GraphExecutionMeta | NodeExecutionResult):
if isinstance(res, GraphExecutionMeta):
await self.publish_graph_exec_update(res)
else:
await self.publish_node_exec_update(res)
async def publish_node_exec_update(self, res: NodeExecutionResult):
event = NodeExecutionEvent.model_validate(res.model_dump())
await self.publish_event(
event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}"
)
async def publish_graph_exec_update(self, res: GraphExecutionMeta):
event = GraphExecutionEvent.model_validate(res.model_dump())
await self.publish_event(event, f"{res.user_id}/{res.graph_id}/{res.id}")
async def publish(self, res: ExecutionResult):
await self.publish_event(res, f"{res.graph_id}/{res.graph_exec_id}")
async def listen(
self, user_id: str, graph_id: str = "*", graph_exec_id: str = "*"
) -> AsyncGenerator[ExecutionEvent, None]:
async for event in self.listen_events(f"{user_id}/{graph_id}/{graph_exec_id}"):
yield event
self, graph_id: str = "*", graph_exec_id: str = "*"
) -> AsyncGenerator[ExecutionResult, None]:
async for execution_result in self.listen_events(f"{graph_id}/{graph_exec_id}"):
yield execution_result

View File

@@ -1,23 +1,29 @@
import asyncio
import logging
import uuid
from collections import defaultdict
from datetime import datetime, timezone
from typing import Any, Literal, Optional, Type
import prisma
from prisma import Json
from prisma.enums import SubmissionStatus
from prisma.models import AgentGraph, AgentNode, AgentNodeLink, StoreListingVersion
from prisma.models import (
AgentGraph,
AgentGraphExecution,
AgentNode,
AgentNodeLink,
StoreListingVersion,
)
from prisma.types import AgentGraphWhereInput
from pydantic.fields import computed_field
from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.blocks.llm import LlmModel
from backend.data.db import prisma as db
from backend.util import type as type_utils
from backend.blocks.basic import AgentInputBlock, AgentOutputBlock
from backend.util import type
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
from .block import BlockInput, BlockType, get_block, get_blocks
from .db import BaseDbModel, transaction
from .execution import ExecutionStatus
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
from .integrations import Webhook
@@ -62,20 +68,15 @@ class NodeModel(Node):
webhook: Optional[Webhook] = None
@property
def block(self) -> Block[BlockSchema, BlockSchema]:
block = get_block(self.block_id)
if not block:
raise ValueError(f"Block #{self.block_id} does not exist")
return block
@staticmethod
def from_db(node: AgentNode, for_export: bool = False) -> "NodeModel":
def from_db(node: AgentNode):
if not node.AgentBlock:
raise ValueError(f"Invalid node {node.id}, invalid AgentBlock.")
obj = NodeModel(
id=node.id,
block_id=node.agentBlockId,
input_default=type_utils.convert(node.constantInput, dict[str, Any]),
metadata=type_utils.convert(node.metadata, dict[str, Any]),
block_id=node.AgentBlock.id,
input_default=type.convert(node.constantInput, dict[str, Any]),
metadata=type.convert(node.metadata, dict[str, Any]),
graph_id=node.agentGraphId,
graph_version=node.agentGraphVersion,
webhook_id=node.webhookId,
@@ -83,8 +84,6 @@ class NodeModel(Node):
)
obj.input_links = [Link.from_db(link) for link in node.Input or []]
obj.output_links = [Link.from_db(link) for link in node.Output or []]
if for_export:
return obj.stripped_for_export()
return obj
def is_triggered_by_event_type(self, event_type: str) -> bool:
@@ -103,59 +102,54 @@ class NodeModel(Node):
if event_filter[k] is True
]
def stripped_for_export(self) -> "NodeModel":
"""
Returns a copy of the node model, stripped of any non-transferable properties
"""
stripped_node = self.model_copy(deep=True)
# Remove credentials from node input
if stripped_node.input_default:
stripped_node.input_default = NodeModel._filter_secrets_from_node_input(
stripped_node.input_default, self.block.input_schema.jsonschema()
)
if (
stripped_node.block.block_type == BlockType.INPUT
and "value" in stripped_node.input_default
):
stripped_node.input_default["value"] = ""
# Remove webhook info
stripped_node.webhook_id = None
stripped_node.webhook = None
return stripped_node
@staticmethod
def _filter_secrets_from_node_input(
input_data: dict[str, Any], schema: dict[str, Any] | None
) -> dict[str, Any]:
sensitive_keys = ["credentials", "api_key", "password", "token", "secret"]
field_schemas = schema.get("properties", {}) if schema else {}
result = {}
for key, value in input_data.items():
field_schema: dict | None = field_schemas.get(key)
if (field_schema and field_schema.get("secret", False)) or any(
sensitive_key in key.lower() for sensitive_key in sensitive_keys
):
# This is a secret value -> filter this key-value pair out
continue
elif isinstance(value, dict):
result[key] = NodeModel._filter_secrets_from_node_input(
value, field_schema
)
else:
result[key] = value
return result
# Fix 2-way reference Node <-> Webhook
Webhook.model_rebuild()
class BaseGraph(BaseDbModel):
class GraphExecution(BaseDbModel):
execution_id: str
started_at: datetime
ended_at: datetime
duration: float
total_run_time: float
status: ExecutionStatus
graph_id: str
graph_version: int
@staticmethod
def from_db(execution: AgentGraphExecution):
now = datetime.now(timezone.utc)
start_time = execution.startedAt or execution.createdAt
end_time = execution.updatedAt or now
duration = (end_time - start_time).total_seconds()
total_run_time = duration
try:
stats = type.convert(execution.stats or {}, dict[str, Any])
except ValueError:
stats = {}
duration = stats.get("walltime", duration)
total_run_time = stats.get("nodes_walltime", total_run_time)
return GraphExecution(
id=execution.id,
execution_id=execution.id,
started_at=start_time,
ended_at=end_time,
duration=duration,
total_run_time=total_run_time,
status=ExecutionStatus(execution.executionStatus),
graph_id=execution.agentGraphId,
graph_version=execution.agentGraphVersion,
)
class Graph(BaseDbModel):
version: int = 1
is_active: bool = True
is_template: bool = False
name: str
description: str
nodes: list[Node] = []
@@ -165,48 +159,46 @@ class BaseGraph(BaseDbModel):
@property
def input_schema(self) -> dict[str, Any]:
return self._generate_schema(
*(
(b.input_schema, node.input_default)
AgentInputBlock.Input,
[
node.input_default
for node in self.nodes
if (b := get_block(node.block_id))
and b.block_type == BlockType.INPUT
and issubclass(b.input_schema, AgentInputBlock.Input)
)
and "name" in node.input_default
],
)
@computed_field
@property
def output_schema(self) -> dict[str, Any]:
return self._generate_schema(
*(
(b.input_schema, node.input_default)
AgentOutputBlock.Input,
[
node.input_default
for node in self.nodes
if (b := get_block(node.block_id))
and b.block_type == BlockType.OUTPUT
and issubclass(b.input_schema, AgentOutputBlock.Input)
)
and "name" in node.input_default
],
)
@staticmethod
def _generate_schema(
*props: tuple[Type[AgentInputBlock.Input] | Type[AgentOutputBlock.Input], dict],
type_class: Type[AgentInputBlock.Input] | Type[AgentOutputBlock.Input],
data: list[dict],
) -> dict[str, Any]:
schema = []
for type_class, input_default in props:
props = []
for p in data:
try:
schema.append(type_class(**input_default))
props.append(type_class(**p))
except Exception as e:
logger.warning(f"Invalid {type_class}: {input_default}, {e}")
logger.warning(f"Invalid {type_class}: {p}, {e}")
return {
"type": "object",
"properties": {
p.name: {
**{
k: v
for k, v in p.generate_schema().items()
if k not in ["description", "default"]
},
"secret": p.secret,
# Default value has to be set for advanced fields.
"advanced": p.advanced and p.value is not None,
@@ -214,16 +206,12 @@ class BaseGraph(BaseDbModel):
**({"description": p.description} if p.description else {}),
**({"default": p.value} if p.value is not None else {}),
}
for p in schema
for p in props
},
"required": [p.name for p in schema if p.value is None],
"required": [p.name for p in props if p.value is None],
}
class Graph(BaseGraph):
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs, only used in export
class GraphModel(Graph):
user_id: str
nodes: list[NodeModel] = [] # type: ignore
@@ -247,88 +235,42 @@ class GraphModel(Graph):
Reassigns all IDs in the graph to new UUIDs.
This method can be used before storing a new graph to the database.
"""
if reassign_graph_id:
graph_id_map = {
self.id: str(uuid.uuid4()),
**{sub_graph.id: str(uuid.uuid4()) for sub_graph in self.sub_graphs},
}
else:
graph_id_map = {}
self._reassign_ids(self, user_id, graph_id_map)
for sub_graph in self.sub_graphs:
self._reassign_ids(sub_graph, user_id, graph_id_map)
@staticmethod
def _reassign_ids(
graph: BaseGraph,
user_id: str,
graph_id_map: dict[str, str],
):
# Reassign Graph ID
if graph.id in graph_id_map:
graph.id = graph_id_map[graph.id]
id_map = {node.id: str(uuid.uuid4()) for node in self.nodes}
if reassign_graph_id:
self.id = str(uuid.uuid4())
# Reassign Node IDs
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
for node in graph.nodes:
for node in self.nodes:
node.id = id_map[node.id]
# Reassign Link IDs
for link in graph.links:
for link in self.links:
link.source_id = id_map[link.source_id]
link.sink_id = id_map[link.sink_id]
# Reassign User IDs for agent blocks
for node in graph.nodes:
for node in self.nodes:
if node.block_id != AgentExecutorBlock().id:
continue
node.input_default["user_id"] = user_id
node.input_default.setdefault("data", {})
if (graph_id := node.input_default.get("graph_id")) in graph_id_map:
node.input_default["graph_id"] = graph_id_map[graph_id]
self.validate_graph()
def validate_graph(self, for_run: bool = False):
self._validate_graph(self, for_run)
for sub_graph in self.sub_graphs:
self._validate_graph(sub_graph, for_run)
@staticmethod
def _validate_graph(graph: BaseGraph, for_run: bool = False):
def sanitize(name):
sanitized_name = name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
if sanitized_name.startswith("tools_^_"):
return sanitized_name.split("_^_")[0]
return sanitized_name
# Validate smart decision maker nodes
smart_decision_maker_nodes = set()
agent_nodes = set()
nodes_block = {
node.id: block
for node in graph.nodes
if (block := get_block(node.block_id)) is not None
}
for node in graph.nodes:
if (block := nodes_block.get(node.id)) is None:
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
# Smart decision maker nodes
if block.block_type == BlockType.AI:
smart_decision_maker_nodes.add(node.id)
# Agent nodes
elif block.block_type == BlockType.AGENT:
agent_nodes.add(node.id)
return name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
input_links = defaultdict(list)
for link in graph.links:
for link in self.links:
input_links[link.sink_id].append(link)
# Nodes: required fields are filled or connected and dependencies are satisfied
for node in graph.nodes:
if (block := nodes_block.get(node.id)) is None:
for node in self.nodes:
block = get_block(node.block_id)
if block is None:
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
provided_inputs = set(
@@ -345,12 +287,9 @@ class GraphModel(Graph):
)
and (
for_run # Skip input completion validation, unless when executing.
or block.block_type
in [
BlockType.INPUT,
BlockType.OUTPUT,
BlockType.AGENT,
]
or block.block_type == BlockType.INPUT
or block.block_type == BlockType.OUTPUT
or block.block_type == BlockType.AGENT
)
):
raise ValueError(
@@ -388,7 +327,7 @@ class GraphModel(Graph):
f"Node {block.name} #{node.id}: Field `{field_name}` requires [{', '.join(missing_deps)}] to be set"
)
node_map = {v.id: v for v in graph.nodes}
node_map = {v.id: v for v in self.nodes}
def is_static_output_block(nid: str) -> bool:
bid = node_map[nid].block_id
@@ -396,23 +335,23 @@ class GraphModel(Graph):
return b.static_output if b else False
# Links: links are connected and the connected pin data type are compatible.
for link in graph.links:
for link in self.links:
source = (link.source_id, link.source_name)
sink = (link.sink_id, link.sink_name)
prefix = f"Link {source} <-> {sink}"
suffix = f"Link {source} <-> {sink}"
for i, (node_id, name) in enumerate([source, sink]):
node = node_map.get(node_id)
if not node:
raise ValueError(
f"{prefix}, {node_id} is invalid node id, available nodes: {node_map.keys()}"
f"{suffix}, {node_id} is invalid node id, available nodes: {node_map.keys()}"
)
block = get_block(node.block_id)
if not block:
blocks = {v().id: v().name for v in get_blocks().values()}
raise ValueError(
f"{prefix}, {node.block_id} is invalid block id, available blocks: {blocks}"
f"{suffix}, {node.block_id} is invalid block id, available blocks: {blocks}"
)
sanitized_name = sanitize(name)
@@ -420,37 +359,35 @@ class GraphModel(Graph):
if i == 0:
fields = (
block.output_schema.get_fields()
if block.block_type not in [BlockType.AGENT]
if block.block_type != BlockType.AGENT
else vals.get("output_schema", {}).get("properties", {}).keys()
)
else:
fields = (
block.input_schema.get_fields()
if block.block_type not in [BlockType.AGENT]
if block.block_type != BlockType.AGENT
else vals.get("input_schema", {}).get("properties", {}).keys()
)
if sanitized_name not in fields and not name.startswith("tools_^_"):
if sanitized_name not in fields:
fields_msg = f"Allowed fields: {fields}"
raise ValueError(f"{prefix}, `{name}` invalid, {fields_msg}")
raise ValueError(f"{suffix}, `{name}` invalid, {fields_msg}")
if is_static_output_block(link.source_id):
link.is_static = True # Each value block output should be static.
@staticmethod
def from_db(
graph: AgentGraph,
for_export: bool = False,
sub_graphs: list[AgentGraph] | None = None,
):
def from_db(graph: AgentGraph, for_export: bool = False):
return GraphModel(
id=graph.id,
user_id=graph.userId if not for_export else "",
user_id=graph.userId,
version=graph.version,
is_active=graph.isActive,
is_template=graph.isTemplate,
name=graph.name or "",
description=graph.description or "",
nodes=[
NodeModel.from_db(node, for_export) for node in graph.AgentNodes or []
NodeModel.from_db(GraphModel._process_node(node, for_export))
for node in graph.AgentNodes or []
],
links=list(
{
@@ -459,12 +396,59 @@ class GraphModel(Graph):
for link in (node.Input or []) + (node.Output or [])
}
),
sub_graphs=[
GraphModel.from_db(sub_graph, for_export)
for sub_graph in sub_graphs or []
],
)
@staticmethod
def _process_node(node: AgentNode, for_export: bool) -> AgentNode:
if for_export:
# Remove credentials from node input
if node.constantInput:
constant_input = type.convert(node.constantInput, dict[str, Any])
constant_input = GraphModel._hide_node_input_credentials(constant_input)
node.constantInput = Json(constant_input)
# Remove webhook info
node.webhookId = None
node.Webhook = None
return node
@staticmethod
def _hide_node_input_credentials(input_data: dict[str, Any]) -> dict[str, Any]:
sensitive_keys = ["credentials", "api_key", "password", "token", "secret"]
result = {}
for key, value in input_data.items():
if isinstance(value, dict):
result[key] = GraphModel._hide_node_input_credentials(value)
elif isinstance(value, str) and any(
sensitive_key in key.lower() for sensitive_key in sensitive_keys
):
# Skip this key-value pair in the result
continue
else:
result[key] = value
return result
def clean_graph(self):
blocks = [block() for block in get_blocks().values()]
input_blocks = [
node
for node in self.nodes
if next(
(
b
for b in blocks
if b.id == node.block_id and b.block_type == BlockType.INPUT
),
None,
)
]
for node in self.nodes:
if any(input_block.id == node.id for input_block in input_blocks):
node.input_default["value"] = ""
# --------------------- CRUD functions --------------------- #
@@ -494,14 +478,14 @@ async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:
async def get_graphs(
user_id: str,
filter_by: Literal["active"] | None = "active",
filter_by: Literal["active", "template"] | None = "active",
) -> list[GraphModel]:
"""
Retrieves graph metadata objects.
Default behaviour is to get all currently active graphs.
Args:
filter_by: An optional filter to either select graphs.
filter_by: An optional filter to either select templates or active graphs.
user_id: The ID of the user that owns the graph.
Returns:
@@ -511,6 +495,8 @@ async def get_graphs(
if filter_by == "active":
where_clause["isActive"] = True
elif filter_by == "template":
where_clause["isTemplate"] = True
graphs = await AgentGraph.prisma().find_many(
where=where_clause,
@@ -530,40 +516,32 @@ async def get_graphs(
return graph_models
async def get_graph_metadata(graph_id: str, version: int | None = None) -> Graph | None:
where_clause: AgentGraphWhereInput = {
"id": graph_id,
}
if version is not None:
where_clause["version"] = version
graph = await AgentGraph.prisma().find_first(
where=where_clause,
order={"version": "desc"},
async def get_executions(user_id: str) -> list[GraphExecution]:
executions = await AgentGraphExecution.prisma().find_many(
where={"userId": user_id},
order={"createdAt": "desc"},
)
return [GraphExecution.from_db(execution) for execution in executions]
if not graph:
return None
return Graph(
id=graph.id,
name=graph.name or "",
description=graph.description or "",
version=graph.version,
is_active=graph.isActive,
async def get_execution(user_id: str, execution_id: str) -> GraphExecution | None:
execution = await AgentGraphExecution.prisma().find_first(
where={"id": execution_id, "userId": user_id}
)
return GraphExecution.from_db(execution) if execution else None
async def get_graph(
graph_id: str,
version: int | None = None,
template: bool = False, # note: currently not in use; TODO: remove from DB entirely
user_id: str | None = None,
for_export: bool = False,
) -> GraphModel | None:
"""
Retrieves a graph from the DB.
Defaults to the version with `is_active` if `version` is not passed.
Defaults to the version with `is_active` if `version` is not passed,
or the latest version with `is_template` if `template=True`.
Returns `None` if the record is not found.
"""
@@ -573,6 +551,8 @@ async def get_graph(
if version is not None:
where_clause["version"] = version
elif not template:
where_clause["isActive"] = True
graph = await AgentGraph.prisma().find_first(
where=where_clause,
@@ -589,81 +569,16 @@ async def get_graph(
"agentId": graph_id,
"agentVersion": version or graph.version,
"isDeleted": False,
"submissionStatus": SubmissionStatus.APPROVED,
"StoreListing": {"is": {"isApproved": True}},
}
)
)
):
return None
if for_export:
sub_graphs = await get_sub_graphs(graph)
return GraphModel.from_db(
graph=graph,
sub_graphs=sub_graphs,
for_export=for_export,
)
return GraphModel.from_db(graph, for_export)
async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
"""
Iteratively fetches all sub-graphs of a given graph, and flattens them into a list.
This call involves a DB fetch in batch, breadth-first, per-level of graph depth.
On each DB fetch we will only fetch the sub-graphs that are not already in the list.
"""
sub_graphs = {graph.id: graph}
search_graphs = [graph]
agent_block_id = AgentExecutorBlock().id
while search_graphs:
sub_graph_ids = [
(graph_id, graph_version)
for graph in search_graphs
for node in graph.AgentNodes or []
if (
node.AgentBlock
and node.AgentBlock.id == agent_block_id
and (graph_id := dict(node.constantInput).get("graph_id"))
and (graph_version := dict(node.constantInput).get("graph_version"))
)
]
if not sub_graph_ids:
break
graphs = await AgentGraph.prisma().find_many(
where={
"OR": [
{
"id": graph_id,
"version": graph_version,
"userId": graph.userId, # Ensure the sub-graph is owned by the same user
}
for graph_id, graph_version in sub_graph_ids
] # type: ignore
},
include=AGENT_GRAPH_INCLUDE,
)
search_graphs = [graph for graph in graphs if graph.id not in sub_graphs]
sub_graphs.update({graph.id: graph for graph in search_graphs})
return [g for g in sub_graphs.values() if g.id != graph.id]
async def get_connected_output_nodes(node_id: str) -> list[tuple[Link, Node]]:
links = await AgentNodeLink.prisma().find_many(
where={"agentNodeSourceId": node_id},
include={"AgentNodeSink": {"include": AGENT_NODE_INCLUDE}}, # type: ignore
)
return [
(Link.from_db(link), NodeModel.from_db(link.AgentNodeSink))
for link in links
if link.AgentNodeSink
]
async def set_graph_active_version(graph_id: str, version: int, user_id: str) -> None:
# Activate the requested version if it exists and is owned by the user.
updated_count = await AgentGraph.prisma().update_many(
@@ -715,56 +630,50 @@ async def create_graph(graph: Graph, user_id: str) -> GraphModel:
async with transaction() as tx:
await __create_graph(tx, graph, user_id)
if created_graph := await get_graph(graph.id, graph.version, user_id=user_id):
if created_graph := await get_graph(
graph.id, graph.version, graph.is_template, user_id=user_id
):
return created_graph
raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
async def __create_graph(tx, graph: Graph, user_id: str):
graphs = [graph] + graph.sub_graphs
await AgentGraph.prisma(tx).create_many(
data=[
{
"id": graph.id,
"version": graph.version,
"name": graph.name,
"description": graph.description,
"isActive": graph.is_active,
"userId": user_id,
}
for graph in graphs
]
await AgentGraph.prisma(tx).create(
data={
"id": graph.id,
"version": graph.version,
"name": graph.name,
"description": graph.description,
"isTemplate": graph.is_template,
"isActive": graph.is_active,
"userId": user_id,
"AgentNodes": {
"create": [
{
"id": node.id,
"agentBlockId": node.block_id,
"constantInput": Json(node.input_default),
"metadata": Json(node.metadata),
}
for node in graph.nodes
]
},
}
)
await AgentNode.prisma(tx).create_many(
data=[
{
"id": node.id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"agentBlockId": node.block_id,
"constantInput": Json(node.input_default),
"metadata": Json(node.metadata),
"webhookId": node.webhook_id,
}
for graph in graphs
for node in graph.nodes
]
)
await AgentNodeLink.prisma(tx).create_many(
data=[
{
"id": str(uuid.uuid4()),
"sourceName": link.source_name,
"sinkName": link.sink_name,
"agentNodeSourceId": link.source_id,
"agentNodeSinkId": link.sink_id,
"isStatic": link.is_static,
}
for graph in graphs
await asyncio.gather(
*[
AgentNodeLink.prisma(tx).create(
{
"id": str(uuid.uuid4()),
"sourceName": link.source_name,
"sinkName": link.sink_name,
"agentNodeSourceId": link.source_id,
"agentNodeSinkId": link.sink_id,
"isStatic": link.is_static,
}
)
for link in graph.links
]
)
@@ -807,11 +716,9 @@ async def fix_llm_provider_credentials():
store = IntegrationCredentialsStore()
broken_nodes = []
try:
broken_nodes = await prisma.get_client().query_raw(
"""
SELECT graph."userId" user_id,
broken_nodes = await prisma.get_client().query_raw(
"""
SELECT graph."userId" user_id,
node.id node_id,
node."constantInput" node_preset_input
FROM platform."AgentNode" node
@@ -820,10 +727,8 @@ async def fix_llm_provider_credentials():
WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm'
ORDER BY graph."userId";
"""
)
logger.info(f"Fixing LLM credential inputs on {len(broken_nodes)} nodes")
except Exception as e:
logger.error(f"Error fixing LLM credential inputs: {e}")
)
logger.info(f"Fixing LLM credential inputs on {len(broken_nodes)} nodes")
user_id: str = ""
user_integrations = None
@@ -874,40 +779,3 @@ async def fix_llm_provider_credentials():
where={"id": node_id},
data={"constantInput": Json(node_preset_input)},
)
async def migrate_llm_models(migrate_to: LlmModel):
"""
Update all LLM models in all AI blocks that don't exist in the enum.
Note: Only updates top level LlmModel SchemaFields of blocks (won't update nested fields).
"""
logger.info("Migrating LLM models")
# Scan all blocks and search for LlmModel fields
llm_model_fields: dict[str, str] = {} # {block_id: field_name}
# Search for all LlmModel fields
for block_type in get_blocks().values():
block = block_type()
from pydantic.fields import FieldInfo
fields: dict[str, FieldInfo] = block.input_schema.model_fields
# Collect top-level LlmModel fields
for field_name, field in fields.items():
if field.annotation == LlmModel:
llm_model_fields[block.id] = field_name
# Update each block
for id, path in llm_model_fields.items():
# Convert enum values to a list of strings for the SQL query
enum_values = [v.value for v in LlmModel.__members__.values()]
query = f"""
UPDATE "AgentNode"
SET "constantInput" = jsonb_set("constantInput", '{{{path}}}', '"{migrate_to.value}"', true)
WHERE "agentBlockId" = '{id}'
AND "constantInput" ? '{path}'
AND "constantInput"->>'{path}' NOT IN ({','.join(f"'{value}'" for value in enum_values)})
"""
await db.execute_raw(query)

View File

@@ -1,7 +1,5 @@
import prisma
from backend.blocks.io import IO_BLOCK_IDs
AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
"Input": True,
"Output": True,
@@ -20,49 +18,17 @@ EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
"AgentGraphExecution": True,
}
MAX_NODE_EXECUTIONS_FETCH = 1000
GRAPH_EXECUTION_INCLUDE_WITH_NODES: prisma.types.AgentGraphExecutionInclude = {
GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
"AgentNodeExecutions": {
"include": {
"Input": True,
"Output": True,
"AgentNode": True,
"AgentGraphExecution": True,
},
"order_by": [
{"queuedTime": "desc"},
# Fallback: Incomplete execs has no queuedTime.
{"addedTime": "desc"},
],
"take": MAX_NODE_EXECUTIONS_FETCH, # Avoid loading excessive node executions.
}
}
}
GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
"AgentNodeExecutions": {
**GRAPH_EXECUTION_INCLUDE_WITH_NODES["AgentNodeExecutions"], # type: ignore
"where": {
"AgentNode": {
"AgentBlock": {"id": {"in": IO_BLOCK_IDs}}, # type: ignore
},
},
}
}
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
}
def library_agent_include(user_id: str) -> prisma.types.LibraryAgentInclude:
return {
"Agent": {
"include": {
**AGENT_GRAPH_INCLUDE,
"AgentGraphExecution": {"where": {"userId": user_id}},
}
},
"Creator": True,
}

View File

@@ -9,7 +9,6 @@ from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
from backend.data.queue import AsyncRedisEventBus
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.utils import webhook_ingress_url
from backend.util.exceptions import NotFoundError
from .db import BaseDbModel
@@ -83,18 +82,11 @@ async def create_webhook(webhook: Webhook) -> Webhook:
async def get_webhook(webhook_id: str) -> Webhook:
"""
⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints.
Raises:
NotFoundError: if no record with the given ID exists
"""
webhook = await IntegrationWebhook.prisma().find_unique(
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
webhook = await IntegrationWebhook.prisma().find_unique_or_raise(
where={"id": webhook_id},
include=INTEGRATION_WEBHOOK_INCLUDE,
)
if not webhook:
raise NotFoundError(f"Webhook #{webhook_id} not found")
return Webhook.from_db(webhook)

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import base64
import logging
from datetime import datetime, timezone
from datetime import datetime
from typing import (
TYPE_CHECKING,
Annotated,
@@ -141,8 +141,9 @@ def SchemaField(
secret: bool = False,
exclude: bool = False,
hidden: Optional[bool] = None,
depends_on: Optional[list[str]] = None,
json_schema_extra: Optional[dict[str, Any]] = None,
depends_on: list[str] | None = None,
image_upload: Optional[bool] = None,
image_output: Optional[bool] = None,
**kwargs,
) -> T:
if default is PydanticUndefined and default_factory is None:
@@ -150,7 +151,7 @@ def SchemaField(
elif advanced is None:
advanced = True
json_schema_extra = {
json_extra = {
k: v
for k, v in {
"placeholder": placeholder,
@@ -158,7 +159,8 @@ def SchemaField(
"advanced": advanced,
"hidden": hidden,
"depends_on": depends_on,
**(json_schema_extra or {}),
"image_upload": image_upload,
"image_output": image_output,
}.items()
if v is not None
}
@@ -170,7 +172,7 @@ def SchemaField(
title=title,
description=description,
exclude=exclude,
json_schema_extra=json_schema_extra,
json_schema_extra=json_extra,
**kwargs,
) # type: ignore
@@ -373,8 +375,7 @@ class AutoTopUpConfig(BaseModel):
class UserTransaction(BaseModel):
transaction_key: str = ""
transaction_time: datetime = datetime.min.replace(tzinfo=timezone.utc)
transaction_time: datetime = datetime.min
transaction_type: CreditTransactionType = CreditTransactionType.USAGE
amount: int = 0
balance: int = 0
@@ -382,62 +383,9 @@ class UserTransaction(BaseModel):
usage_graph_id: str | None = None
usage_execution_id: str | None = None
usage_node_count: int = 0
usage_start_time: datetime = datetime.max.replace(tzinfo=timezone.utc)
usage_start_time: datetime = datetime.max
class TransactionHistory(BaseModel):
transactions: list[UserTransaction]
next_transaction_time: datetime | None
class RefundRequest(BaseModel):
id: str
user_id: str
transaction_key: str
amount: int
reason: str
result: str | None = None
status: str
created_at: datetime
updated_at: datetime
class NodeExecutionStats(BaseModel):
"""Execution statistics for a node execution."""
class Config:
arbitrary_types_allowed = True
extra = "allow"
error: Optional[Exception | str] = None
walltime: float = 0
cputime: float = 0
input_size: int = 0
output_size: int = 0
llm_call_count: int = 0
llm_retry_count: int = 0
input_token_count: int = 0
output_token_count: int = 0
class GraphExecutionStats(BaseModel):
"""Execution statistics for a graph execution."""
class Config:
arbitrary_types_allowed = True
extra = "allow"
error: Optional[Exception | str] = None
walltime: float = Field(
default=0, description="Time between start and end of run (seconds)"
)
cputime: float = 0
nodes_walltime: float = Field(
default=0, description="Total node execution time (seconds)"
)
nodes_cputime: float = 0
node_count: int = Field(default=0, description="Total number of node executions")
node_error_count: int = Field(
default=0, description="Total number of errors generated"
)
cost: int = Field(default=0, description="Total execution cost (cents)")

View File

@@ -1,7 +1,7 @@
import logging
from datetime import datetime, timedelta, timezone
from datetime import datetime, timedelta
from enum import Enum
from typing import Annotated, Any, Generic, Optional, TypeVar, Union
from typing import Annotated, Generic, Optional, TypeVar, Union
from prisma import Json
from prisma.enums import NotificationType
@@ -18,34 +18,27 @@ from .db import transaction
logger = logging.getLogger(__name__)
NotificationDataType_co = TypeVar(
"NotificationDataType_co", bound="BaseNotificationData", covariant=True
)
SummaryParamsType_co = TypeVar(
"SummaryParamsType_co", bound="BaseSummaryParams", covariant=True
)
T_co = TypeVar("T_co", bound="BaseNotificationData", covariant=True)
class QueueType(Enum):
class BatchingStrategy(Enum):
IMMEDIATE = "immediate" # Send right away (errors, critical notifications)
BATCH = "batch" # Batch for up to an hour (usage reports)
SUMMARY = "summary" # Daily digest (summary notifications)
HOURLY = "hourly" # Batch for up to an hour (usage reports)
DAILY = "daily" # Daily digest (summary notifications)
BACKOFF = "backoff" # Backoff strategy (exponential backoff)
ADMIN = "admin" # Admin notifications (errors, critical notifications)
class BaseNotificationData(BaseModel):
class Config:
extra = "allow"
pass
class AgentRunData(BaseNotificationData):
agent_name: str
credits_used: float
# remaining_balance: float
execution_time: float
node_count: int = Field(..., description="Number of nodes executed")
graph_id: str
outputs: list[dict[str, Any]] = Field(..., description="Outputs of the agent")
node_count: int = Field(..., description="Number of nodes executed")
class ZeroBalanceData(BaseNotificationData):
@@ -53,21 +46,12 @@ class ZeroBalanceData(BaseNotificationData):
last_transaction_time: datetime
top_up_link: str
@field_validator("last_transaction_time")
@classmethod
def validate_timezone(cls, value: datetime):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
class LowBalanceData(BaseNotificationData):
agent_name: str = Field(..., description="Name of the agent")
current_balance: float = Field(
..., description="Current balance in credits (100 = $1)"
)
billing_page_link: str = Field(..., description="Link to billing page")
shortfall: float = Field(..., description="Amount of credits needed to continue")
current_balance: float
threshold_amount: float
top_up_link: str
recent_usage: float = Field(..., description="Usage in the last 24 hours")
class BlockExecutionFailedData(BaseNotificationData):
@@ -88,13 +72,6 @@ class ContinuousAgentErrorData(BaseNotificationData):
error_time: datetime
attempts: int = Field(..., description="Number of retry attempts made")
@field_validator("start_time", "error_time")
@classmethod
def validate_timezone(cls, value: datetime):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
class BaseSummaryData(BaseNotificationData):
total_credits_used: float
@@ -107,66 +84,20 @@ class BaseSummaryData(BaseNotificationData):
cost_breakdown: dict[str, float]
class BaseSummaryParams(BaseModel):
pass
class DailySummaryParams(BaseSummaryParams):
date: datetime
@field_validator("date")
def validate_timezone(cls, value):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
class WeeklySummaryParams(BaseSummaryParams):
start_date: datetime
end_date: datetime
@field_validator("start_date", "end_date")
def validate_timezone(cls, value):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
class DailySummaryData(BaseSummaryData):
date: datetime
@field_validator("date")
def validate_timezone(cls, value):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
class WeeklySummaryData(BaseSummaryData):
start_date: datetime
end_date: datetime
@field_validator("start_date", "end_date")
def validate_timezone(cls, value):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
class MonthlySummaryData(BaseNotificationData):
month: int
week_number: int
year: int
class RefundRequestData(BaseNotificationData):
user_id: str
user_name: str
user_email: str
transaction_id: str
refund_request_id: str
reason: str
amount: float
balance: int
class MonthlySummaryData(BaseSummaryData):
month: int
year: int
NotificationData = Annotated[
@@ -177,10 +108,6 @@ NotificationData = Annotated[
BlockExecutionFailedData,
ContinuousAgentErrorData,
MonthlySummaryData,
WeeklySummaryData,
DailySummaryData,
RefundRequestData,
BaseSummaryData,
],
Field(discriminator="type"),
]
@@ -190,25 +117,17 @@ class NotificationEventDTO(BaseModel):
user_id: str
type: NotificationType
data: dict
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
retry_count: int = 0
created_at: datetime = Field(default_factory=datetime.now)
class SummaryParamsEventDTO(BaseModel):
class NotificationEventModel(BaseModel, Generic[T_co]):
user_id: str
type: NotificationType
data: dict
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
class NotificationEventModel(BaseModel, Generic[NotificationDataType_co]):
user_id: str
type: NotificationType
data: NotificationDataType_co
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
data: T_co
created_at: datetime = Field(default_factory=datetime.now)
@property
def strategy(self) -> QueueType:
def strategy(self) -> BatchingStrategy:
return NotificationTypeOverride(self.type).strategy
@field_validator("type", mode="before")
@@ -222,14 +141,7 @@ class NotificationEventModel(BaseModel, Generic[NotificationDataType_co]):
return NotificationTypeOverride(self.type).template
class SummaryParamsEventModel(BaseModel, Generic[SummaryParamsType_co]):
user_id: str
type: NotificationType
data: SummaryParamsType_co
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
def get_notif_data_type(
def get_data_type(
notification_type: NotificationType,
) -> type[BaseNotificationData]:
return {
@@ -241,25 +153,14 @@ def get_notif_data_type(
NotificationType.DAILY_SUMMARY: DailySummaryData,
NotificationType.WEEKLY_SUMMARY: WeeklySummaryData,
NotificationType.MONTHLY_SUMMARY: MonthlySummaryData,
NotificationType.REFUND_REQUEST: RefundRequestData,
NotificationType.REFUND_PROCESSED: RefundRequestData,
}[notification_type]
def get_summary_params_type(
notification_type: NotificationType,
) -> type[BaseSummaryParams]:
return {
NotificationType.DAILY_SUMMARY: DailySummaryParams,
NotificationType.WEEKLY_SUMMARY: WeeklySummaryParams,
}[notification_type]
class NotificationBatch(BaseModel):
user_id: str
events: list[NotificationEvent]
strategy: QueueType
last_update: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
strategy: BatchingStrategy
last_update: datetime = datetime.now()
class NotificationResult(BaseModel):
@@ -272,22 +173,21 @@ class NotificationTypeOverride:
self.notification_type = notification_type
@property
def strategy(self) -> QueueType:
def strategy(self) -> BatchingStrategy:
BATCHING_RULES = {
# These are batched by the notification service
NotificationType.AGENT_RUN: QueueType.BATCH,
NotificationType.AGENT_RUN: BatchingStrategy.IMMEDIATE,
# These are batched by the notification service, but with a backoff strategy
NotificationType.ZERO_BALANCE: QueueType.BACKOFF,
NotificationType.LOW_BALANCE: QueueType.IMMEDIATE,
NotificationType.BLOCK_EXECUTION_FAILED: QueueType.BACKOFF,
NotificationType.CONTINUOUS_AGENT_ERROR: QueueType.BACKOFF,
NotificationType.DAILY_SUMMARY: QueueType.SUMMARY,
NotificationType.WEEKLY_SUMMARY: QueueType.SUMMARY,
NotificationType.MONTHLY_SUMMARY: QueueType.SUMMARY,
NotificationType.REFUND_REQUEST: QueueType.ADMIN,
NotificationType.REFUND_PROCESSED: QueueType.ADMIN,
NotificationType.ZERO_BALANCE: BatchingStrategy.BACKOFF,
NotificationType.LOW_BALANCE: BatchingStrategy.BACKOFF,
NotificationType.BLOCK_EXECUTION_FAILED: BatchingStrategy.BACKOFF,
NotificationType.CONTINUOUS_AGENT_ERROR: BatchingStrategy.BACKOFF,
# These aren't batched by the notification service, so we send them right away
NotificationType.DAILY_SUMMARY: BatchingStrategy.IMMEDIATE,
NotificationType.WEEKLY_SUMMARY: BatchingStrategy.IMMEDIATE,
NotificationType.MONTHLY_SUMMARY: BatchingStrategy.IMMEDIATE,
}
return BATCHING_RULES.get(self.notification_type, QueueType.IMMEDIATE)
return BATCHING_RULES.get(self.notification_type, BatchingStrategy.HOURLY)
@property
def template(self) -> str:
@@ -301,33 +201,8 @@ class NotificationTypeOverride:
NotificationType.DAILY_SUMMARY: "daily_summary.html",
NotificationType.WEEKLY_SUMMARY: "weekly_summary.html",
NotificationType.MONTHLY_SUMMARY: "monthly_summary.html",
NotificationType.REFUND_REQUEST: "refund_request.html",
NotificationType.REFUND_PROCESSED: "refund_processed.html",
}[self.notification_type]
@property
def subject(self) -> str:
return {
NotificationType.AGENT_RUN: "Agent Run Report",
NotificationType.ZERO_BALANCE: "You're out of credits!",
NotificationType.LOW_BALANCE: "Low Balance Warning!",
NotificationType.BLOCK_EXECUTION_FAILED: "Uh oh! Block Execution Failed",
NotificationType.CONTINUOUS_AGENT_ERROR: "Shoot! Continuous Agent Error",
NotificationType.DAILY_SUMMARY: "Here's your daily summary!",
NotificationType.WEEKLY_SUMMARY: "Look at all the cool stuff you did last week!",
NotificationType.MONTHLY_SUMMARY: "We did a lot this month!",
NotificationType.REFUND_REQUEST: "[ACTION REQUIRED] You got a ${{data.amount / 100}} refund request from {{data.user_name}}",
NotificationType.REFUND_PROCESSED: "Refund for ${{data.amount / 100}} to {{data.user_name}} has been processed",
}[self.notification_type]
class NotificationPreferenceDTO(BaseModel):
email: EmailStr = Field(..., description="User's email address")
preferences: dict[NotificationType, bool] = Field(
..., description="Which notifications the user wants"
)
daily_limit: int = Field(..., description="Max emails per day")
class NotificationPreference(BaseModel):
user_id: str
@@ -337,51 +212,12 @@ class NotificationPreference(BaseModel):
)
daily_limit: int = 10 # Max emails per day
emails_sent_today: int = 0
last_reset_date: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc)
)
class UserNotificationEventDTO(BaseModel):
type: NotificationType
data: dict
created_at: datetime
updated_at: datetime
@staticmethod
def from_db(model: NotificationEvent) -> "UserNotificationEventDTO":
return UserNotificationEventDTO(
type=model.type,
data=dict(model.data),
created_at=model.createdAt,
updated_at=model.updatedAt,
)
class UserNotificationBatchDTO(BaseModel):
user_id: str
type: NotificationType
notifications: list[UserNotificationEventDTO]
created_at: datetime
updated_at: datetime
@staticmethod
def from_db(model: UserNotificationBatch) -> "UserNotificationBatchDTO":
return UserNotificationBatchDTO(
user_id=model.userId,
type=model.type,
notifications=[
UserNotificationEventDTO.from_db(notification)
for notification in model.Notifications or []
],
created_at=model.createdAt,
updated_at=model.updatedAt,
)
last_reset_date: datetime = Field(default_factory=datetime.now)
def get_batch_delay(notification_type: NotificationType) -> timedelta:
return {
NotificationType.AGENT_RUN: timedelta(minutes=60),
NotificationType.AGENT_RUN: timedelta(seconds=1),
NotificationType.ZERO_BALANCE: timedelta(minutes=60),
NotificationType.LOW_BALANCE: timedelta(minutes=60),
NotificationType.BLOCK_EXECUTION_FAILED: timedelta(minutes=60),
@@ -392,15 +228,19 @@ def get_batch_delay(notification_type: NotificationType) -> timedelta:
async def create_or_add_to_user_notification_batch(
user_id: str,
notification_type: NotificationType,
notification_data: NotificationEventModel,
) -> UserNotificationBatchDTO:
data: str, # type: 'NotificationEventModel'
) -> dict:
try:
logger.info(
f"Creating or adding to notification batch for {user_id} with type {notification_type} and data {notification_data}"
f"Creating or adding to notification batch for {user_id} with type {notification_type} and data {data}"
)
notification_data = NotificationEventModel[
get_data_type(notification_type)
].model_validate_json(data)
# Serialize the data
json_data: Json = Json(notification_data.data.model_dump())
json_data: Json = Json(notification_data.data.model_dump_json())
# First try to find existing batch
existing_batch = await UserNotificationBatch.prisma().find_unique(
@@ -410,7 +250,7 @@ async def create_or_add_to_user_notification_batch(
"type": notification_type,
}
},
include={"Notifications": True},
include={"notifications": True},
)
if not existing_batch:
@@ -427,11 +267,11 @@ async def create_or_add_to_user_notification_batch(
data={
"userId": user_id,
"type": notification_type,
"Notifications": {"connect": [{"id": notification_event.id}]},
"notifications": {"connect": [{"id": notification_event.id}]},
},
include={"Notifications": True},
include={"notifications": True},
)
return UserNotificationBatchDTO.from_db(resp)
return resp.model_dump()
else:
async with transaction() as tx:
notification_event = await tx.notificationevent.create(
@@ -445,41 +285,35 @@ async def create_or_add_to_user_notification_batch(
resp = await tx.usernotificationbatch.update(
where={"id": existing_batch.id},
data={
"Notifications": {"connect": [{"id": notification_event.id}]}
"notifications": {"connect": [{"id": notification_event.id}]}
},
include={"Notifications": True},
include={"notifications": True},
)
if not resp:
raise DatabaseError(
f"Failed to add notification event {notification_event.id} to existing batch {existing_batch.id}"
)
return UserNotificationBatchDTO.from_db(resp)
return resp.model_dump()
except Exception as e:
raise DatabaseError(
f"Failed to create or add to notification batch for user {user_id} and type {notification_type}: {e}"
) from e
async def get_user_notification_oldest_message_in_batch(
async def get_user_notification_last_message_in_batch(
user_id: str,
notification_type: NotificationType,
) -> UserNotificationEventDTO | None:
) -> NotificationEvent | None:
try:
batch = await UserNotificationBatch.prisma().find_first(
where={"userId": user_id, "type": notification_type},
include={"Notifications": True},
order={"createdAt": "desc"},
)
if not batch:
return None
if not batch.Notifications:
if not batch.notifications:
return None
sorted_notifications = sorted(batch.Notifications, key=lambda x: x.createdAt)
return (
UserNotificationEventDTO.from_db(sorted_notifications[0])
if sorted_notifications
else None
)
return batch.notifications[-1]
except Exception as e:
raise DatabaseError(
f"Failed to get user notification last message in batch for user {user_id} and type {notification_type}: {e}"
@@ -514,34 +348,13 @@ async def empty_user_notification_batch(
async def get_user_notification_batch(
user_id: str,
notification_type: NotificationType,
) -> UserNotificationBatchDTO | None:
) -> UserNotificationBatch | None:
try:
batch = await UserNotificationBatch.prisma().find_first(
return await UserNotificationBatch.prisma().find_first(
where={"userId": user_id, "type": notification_type},
include={"Notifications": True},
include={"notifications": True},
)
return UserNotificationBatchDTO.from_db(batch) if batch else None
except Exception as e:
raise DatabaseError(
f"Failed to get user notification batch for user {user_id} and type {notification_type}: {e}"
) from e
async def get_all_batches_by_type(
notification_type: NotificationType,
) -> list[UserNotificationBatchDTO]:
try:
batches = await UserNotificationBatch.prisma().find_many(
where={
"type": notification_type,
"Notifications": {
"some": {} # Only return batches with at least one notification
},
},
include={"Notifications": True},
)
return [UserNotificationBatchDTO.from_db(batch) for batch in batches]
except Exception as e:
raise DatabaseError(
f"Failed to get all batches by type {notification_type}: {e}"
) from e

View File

@@ -1,276 +0,0 @@
import re
from typing import Any, Optional
import prisma
import pydantic
from prisma import Json
from prisma.enums import OnboardingStep
from prisma.models import UserOnboarding
from prisma.types import UserOnboardingUpdateInput
from backend.data.block import get_blocks
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.server.v2.store.model import StoreAgentDetails
# Mapping from user reason id to categories to search for when choosing agent to show
REASON_MAPPING: dict[str, list[str]] = {
"content_marketing": ["writing", "marketing", "creative"],
"business_workflow_automation": ["business", "productivity"],
"data_research": ["data", "research"],
"ai_innovation": ["development", "research"],
"personal_productivity": ["personal", "productivity"],
}
POINTS_AGENT_COUNT = 50 # Number of agents to calculate points for
MIN_AGENT_COUNT = 2 # Minimum number of marketplace agents to enable onboarding
class UserOnboardingUpdate(pydantic.BaseModel):
completedSteps: Optional[list[OnboardingStep]] = None
usageReason: Optional[str] = None
integrations: Optional[list[str]] = None
otherIntegrations: Optional[str] = None
selectedStoreListingVersionId: Optional[str] = None
agentInput: Optional[dict[str, Any]] = None
async def get_user_onboarding(user_id: str):
return await UserOnboarding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id}, # type: ignore
"update": {},
},
)
async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
update: UserOnboardingUpdateInput = {}
if data.completedSteps is not None:
update["completedSteps"] = list(set(data.completedSteps))
if data.usageReason is not None:
update["usageReason"] = data.usageReason
if data.integrations is not None:
update["integrations"] = data.integrations
if data.otherIntegrations is not None:
update["otherIntegrations"] = data.otherIntegrations
if data.selectedStoreListingVersionId is not None:
update["selectedStoreListingVersionId"] = data.selectedStoreListingVersionId
if data.agentInput is not None:
update["agentInput"] = Json(data.agentInput)
return await UserOnboarding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, **update}, # type: ignore
"update": update,
},
)
def clean_and_split(text: str) -> list[str]:
"""
Removes all special characters from a string, truncates it to 100 characters,
and splits it by whitespace and commas.
Args:
text (str): The input string.
Returns:
list[str]: A list of cleaned words.
"""
# Remove all special characters (keep only alphanumeric and whitespace)
cleaned_text = re.sub(r"[^a-zA-Z0-9\s,]", "", text.strip()[:100])
# Split by whitespace and commas
words = re.split(r"[\s,]+", cleaned_text)
# Remove empty strings from the list
words = [word.lower() for word in words if word]
return words
def calculate_points(
agent, categories: list[str], custom: list[str], integrations: list[str]
) -> int:
"""
Calculates the total points for an agent based on the specified criteria.
Args:
agent: The agent object.
categories (list[str]): List of categories to match.
words (list[str]): List of words to match in the description.
Returns:
int: Total points for the agent.
"""
points = 0
# 1. Category Matches
matched_categories = sum(
1 for category in categories if category in agent.categories
)
points += matched_categories * 100
# 2. Description Word Matches
description_words = agent.description.split() # Split description into words
matched_words = sum(1 for word in custom if word in description_words)
points += matched_words * 100
matched_words = sum(1 for word in integrations if word in description_words)
points += matched_words * 50
# 3. Featured Bonus
if agent.featured:
points += 50
# 4. Rating Bonus
points += agent.rating * 10
# 5. Runs Bonus
runs_points = min(agent.runs / 1000 * 100, 100) # Cap at 100 points
points += runs_points
return int(points)
def get_credentials_blocks() -> dict[str, str]:
# Returns a dictionary of block id to credentials field name
creds: dict[str, str] = {}
blocks = get_blocks()
for id, block in blocks.items():
for field_name, field_info in block().input_schema.model_fields.items():
if field_info.annotation == CredentialsMetaInput:
creds[id] = field_name
return creds
CREDENTIALS_FIELDS: dict[str, str] = get_credentials_blocks()
async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
user_onboarding = await get_user_onboarding(user_id)
categories = REASON_MAPPING.get(user_onboarding.usageReason or "", [])
where_clause: dict[str, Any] = {}
custom = clean_and_split((user_onboarding.usageReason or "").lower())
if categories:
where_clause["OR"] = [
{"categories": {"has": category}} for category in categories
]
else:
where_clause["OR"] = [
{"description": {"contains": word, "mode": "insensitive"}}
for word in custom
]
where_clause["OR"] += [
{"description": {"contains": word, "mode": "insensitive"}}
for word in user_onboarding.integrations
]
storeAgents = await prisma.models.StoreAgent.prisma().find_many(
where=prisma.types.StoreAgentWhereInput(**where_clause),
order=[
{"featured": "desc"},
{"runs": "desc"},
{"rating": "desc"},
],
take=100,
)
agentListings = await prisma.models.StoreListingVersion.prisma().find_many(
where={
"id": {"in": [agent.storeListingVersionId for agent in storeAgents]},
},
include={"Agent": True},
)
for listing in agentListings:
agent = listing.Agent
if agent is None:
continue
graph = GraphModel.from_db(agent)
# Remove agents with empty input schema
if not graph.input_schema:
storeAgents = [
a for a in storeAgents if a.storeListingVersionId != listing.id
]
continue
# Remove agents with empty credentials
# Get nodes from this agent that have credentials
nodes = await prisma.models.AgentNode.prisma().find_many(
where={
"agentGraphId": agent.id,
"agentBlockId": {"in": list(CREDENTIALS_FIELDS.keys())},
},
)
for node in nodes:
block_id = node.agentBlockId
field_name = CREDENTIALS_FIELDS[block_id]
# If there are no credentials or they are empty, remove the agent
# FIXME ignores default values
if (
field_name not in node.constantInput
or node.constantInput[field_name] is None
):
storeAgents = [
a for a in storeAgents if a.storeListingVersionId != listing.id
]
break
# If there are less than 2 agents, add more agents to the list
if len(storeAgents) < 2:
storeAgents += await prisma.models.StoreAgent.prisma().find_many(
where={
"listing_id": {"not_in": [agent.listing_id for agent in storeAgents]},
},
order=[
{"featured": "desc"},
{"runs": "desc"},
{"rating": "desc"},
],
take=2 - len(storeAgents),
)
# Calculate points for the first X agents and choose the top 2
agent_points = []
for agent in storeAgents[:POINTS_AGENT_COUNT]:
points = calculate_points(
agent, categories, custom, user_onboarding.integrations
)
agent_points.append((agent, points))
agent_points.sort(key=lambda x: x[1], reverse=True)
recommended_agents = [agent for agent, _ in agent_points[:2]]
return [
StoreAgentDetails(
store_listing_version_id=agent.storeListingVersionId,
slug=agent.slug,
agent_name=agent.agent_name,
agent_video=agent.agent_video or "",
agent_image=agent.agent_image,
creator=agent.creator_username,
creator_avatar=agent.creator_avatar,
sub_heading=agent.sub_heading,
description=agent.description,
categories=agent.categories,
runs=agent.runs,
rating=agent.rating,
versions=agent.versions,
last_updated=agent.updated_at,
)
for agent in recommended_agents
]
async def onboarding_enabled() -> bool:
count = await prisma.models.StoreAgent.prisma().count(take=MIN_AGENT_COUNT + 1)
# Onboading is enabled if there are at least 2 agents in the store
return count >= MIN_AGENT_COUNT

View File

@@ -1,6 +1,8 @@
import asyncio
import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, AsyncGenerator, Generator, Generic, Optional, TypeVar
from pydantic import BaseModel
@@ -12,6 +14,13 @@ from backend.data import redis
logger = logging.getLogger(__name__)
class DateTimeEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, datetime):
return o.isoformat()
return super().default(o)
M = TypeVar("M", bound=BaseModel)
@@ -23,14 +32,10 @@ class BaseRedisEventBus(Generic[M], ABC):
def event_bus_name(self) -> str:
pass
@property
def Message(self) -> type["_EventPayloadWrapper[M]"]:
return _EventPayloadWrapper[self.Model]
def _serialize_message(self, item: M, channel_key: str) -> tuple[str, str]:
message = self.Message(payload=item).model_dump_json()
message = json.dumps(item.model_dump(), cls=DateTimeEncoder)
channel_name = f"{self.event_bus_name}/{channel_key}"
logger.debug(f"[{channel_name}] Publishing an event to Redis {message}")
logger.info(f"[{channel_name}] Publishing an event to Redis {message}")
return message, channel_name
def _deserialize_message(self, msg: Any, channel_key: str) -> M | None:
@@ -38,8 +43,9 @@ class BaseRedisEventBus(Generic[M], ABC):
if msg["type"] != message_type:
return None
try:
logger.debug(f"[{channel_key}] Consuming an event from Redis {msg['data']}")
return self.Message.model_validate_json(msg["data"]).payload
data = json.loads(msg["data"])
logger.info(f"Consuming an event from Redis {data}")
return self.Model(**data)
except Exception as e:
logger.error(f"Failed to parse event result from Redis {msg} {e}")
@@ -51,16 +57,9 @@ class BaseRedisEventBus(Generic[M], ABC):
return pubsub, full_channel_name
class _EventPayloadWrapper(BaseModel, Generic[M]):
"""
Wrapper model to allow `RedisEventBus.Model` to be a discriminated union
of multiple event types.
"""
payload: M
class RedisEventBus(BaseRedisEventBus[M], ABC):
Model: type[M]
@property
def connection(self) -> redis.Redis:
return redis.get_redis()
@@ -86,6 +85,8 @@ class RedisEventBus(BaseRedisEventBus[M], ABC):
class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
Model: type[M]
@property
async def connection(self) -> redis.AsyncRedis:
return await redis.get_redis_async()

View File

@@ -1,24 +1,18 @@
import base64
import hashlib
import hmac
import logging
from datetime import datetime, timedelta
from typing import Optional, cast
from urllib.parse import quote_plus
from autogpt_libs.auth.models import DEFAULT_USER_ID
from fastapi import HTTPException
from prisma import Json
from prisma.enums import NotificationType
from prisma.models import User
from prisma.types import UserUpdateInput
from backend.data.db import prisma
from backend.data.model import UserIntegrations, UserMetadata, UserMetadataRaw
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.data.notifications import NotificationPreference
from backend.server.v2.store.exceptions import DatabaseError
from backend.util.encryption import JSONCryptor
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -55,31 +49,6 @@ async def get_user_by_id(user_id: str) -> User:
return User.model_validate(user)
async def get_user_email_by_id(user_id: str) -> Optional[str]:
try:
user = await prisma.user.find_unique(where={"id": user_id})
return user.email if user else None
except Exception as e:
raise DatabaseError(f"Failed to get user email for user {user_id}: {e}") from e
async def get_user_by_email(email: str) -> Optional[User]:
try:
user = await prisma.user.find_unique(where={"email": email})
return User.model_validate(user) if user else None
except Exception as e:
raise DatabaseError(f"Failed to get user by email {email}: {e}") from e
async def update_user_email(user_id: str, email: str):
try:
await prisma.user.update(where={"id": user_id}, data={"email": email})
except Exception as e:
raise DatabaseError(
f"Failed to update user email for user {user_id}: {e}"
) from e
async def create_default_user() -> Optional[User]:
user = await prisma.user.find_unique(where={"id": DEFAULT_USER_ID})
if not user:
@@ -209,16 +178,16 @@ async def get_user_notification_preference(user_id: str) -> NotificationPreferen
# enable notifications by default if user has no notification preference (shouldn't ever happen though)
preferences: dict[NotificationType, bool] = {
NotificationType.AGENT_RUN: user.notifyOnAgentRun or False,
NotificationType.ZERO_BALANCE: user.notifyOnZeroBalance or False,
NotificationType.LOW_BALANCE: user.notifyOnLowBalance or False,
NotificationType.AGENT_RUN: user.notifyOnAgentRun or True,
NotificationType.ZERO_BALANCE: user.notifyOnZeroBalance or True,
NotificationType.LOW_BALANCE: user.notifyOnLowBalance or True,
NotificationType.BLOCK_EXECUTION_FAILED: user.notifyOnBlockExecutionFailed
or False,
or True,
NotificationType.CONTINUOUS_AGENT_ERROR: user.notifyOnContinuousAgentError
or False,
NotificationType.DAILY_SUMMARY: user.notifyOnDailySummary or False,
NotificationType.WEEKLY_SUMMARY: user.notifyOnWeeklySummary or False,
NotificationType.MONTHLY_SUMMARY: user.notifyOnMonthlySummary or False,
or True,
NotificationType.DAILY_SUMMARY: user.notifyOnDailySummary or True,
NotificationType.WEEKLY_SUMMARY: user.notifyOnWeeklySummary or True,
NotificationType.MONTHLY_SUMMARY: user.notifyOnMonthlySummary or True,
}
daily_limit = user.maxEmailsPerDay or 3
notification_preference = NotificationPreference(
@@ -236,162 +205,3 @@ async def get_user_notification_preference(user_id: str) -> NotificationPreferen
raise DatabaseError(
f"Failed to upsert user notification preference for user {user_id}: {e}"
) from e
async def update_user_notification_preference(
user_id: str, data: NotificationPreferenceDTO
) -> NotificationPreference:
try:
update_data: UserUpdateInput = {}
if data.email:
update_data["email"] = data.email
if NotificationType.AGENT_RUN in data.preferences:
update_data["notifyOnAgentRun"] = data.preferences[
NotificationType.AGENT_RUN
]
if NotificationType.ZERO_BALANCE in data.preferences:
update_data["notifyOnZeroBalance"] = data.preferences[
NotificationType.ZERO_BALANCE
]
if NotificationType.LOW_BALANCE in data.preferences:
update_data["notifyOnLowBalance"] = data.preferences[
NotificationType.LOW_BALANCE
]
if NotificationType.BLOCK_EXECUTION_FAILED in data.preferences:
update_data["notifyOnBlockExecutionFailed"] = data.preferences[
NotificationType.BLOCK_EXECUTION_FAILED
]
if NotificationType.CONTINUOUS_AGENT_ERROR in data.preferences:
update_data["notifyOnContinuousAgentError"] = data.preferences[
NotificationType.CONTINUOUS_AGENT_ERROR
]
if NotificationType.DAILY_SUMMARY in data.preferences:
update_data["notifyOnDailySummary"] = data.preferences[
NotificationType.DAILY_SUMMARY
]
if NotificationType.WEEKLY_SUMMARY in data.preferences:
update_data["notifyOnWeeklySummary"] = data.preferences[
NotificationType.WEEKLY_SUMMARY
]
if NotificationType.MONTHLY_SUMMARY in data.preferences:
update_data["notifyOnMonthlySummary"] = data.preferences[
NotificationType.MONTHLY_SUMMARY
]
if data.daily_limit:
update_data["maxEmailsPerDay"] = data.daily_limit
user = await User.prisma().update(
where={"id": user_id},
data=update_data,
)
if not user:
raise ValueError(f"User not found with ID: {user_id}")
preferences: dict[NotificationType, bool] = {
NotificationType.AGENT_RUN: user.notifyOnAgentRun or True,
NotificationType.ZERO_BALANCE: user.notifyOnZeroBalance or True,
NotificationType.LOW_BALANCE: user.notifyOnLowBalance or True,
NotificationType.BLOCK_EXECUTION_FAILED: user.notifyOnBlockExecutionFailed
or True,
NotificationType.CONTINUOUS_AGENT_ERROR: user.notifyOnContinuousAgentError
or True,
NotificationType.DAILY_SUMMARY: user.notifyOnDailySummary or True,
NotificationType.WEEKLY_SUMMARY: user.notifyOnWeeklySummary or True,
NotificationType.MONTHLY_SUMMARY: user.notifyOnMonthlySummary or True,
}
notification_preference = NotificationPreference(
user_id=user.id,
email=user.email,
preferences=preferences,
daily_limit=user.maxEmailsPerDay or 3,
# TODO with other changes later, for now we just will email them
emails_sent_today=0,
last_reset_date=datetime.now(),
)
return NotificationPreference.model_validate(notification_preference)
except Exception as e:
raise DatabaseError(
f"Failed to update user notification preference for user {user_id}: {e}"
) from e
async def set_user_email_verification(user_id: str, verified: bool) -> None:
"""Set the email verification status for a user."""
try:
await User.prisma().update(
where={"id": user_id},
data={"emailVerified": verified},
)
except Exception as e:
raise DatabaseError(
f"Failed to set email verification status for user {user_id}: {e}"
) from e
async def get_user_email_verification(user_id: str) -> bool:
"""Get the email verification status for a user."""
try:
user = await User.prisma().find_unique_or_raise(
where={"id": user_id},
)
return user.emailVerified
except Exception as e:
raise DatabaseError(
f"Failed to get email verification status for user {user_id}: {e}"
) from e
def generate_unsubscribe_link(user_id: str) -> str:
"""Generate a link to unsubscribe from all notifications"""
# Create an HMAC using a secret key
secret_key = Settings().secrets.unsubscribe_secret_key
signature = hmac.new(
secret_key.encode("utf-8"), user_id.encode("utf-8"), hashlib.sha256
).digest()
# Create a token that combines the user_id and signature
token = base64.urlsafe_b64encode(
f"{user_id}:{signature.hex()}".encode("utf-8")
).decode("utf-8")
logger.info(f"Generating unsubscribe link for user {user_id}")
base_url = Settings().config.platform_base_url
return f"{base_url}/api/email/unsubscribe?token={quote_plus(token)}"
async def unsubscribe_user_by_token(token: str) -> None:
"""Unsubscribe a user from all notifications using the token"""
try:
# Decode the token
decoded = base64.urlsafe_b64decode(token).decode("utf-8")
user_id, received_signature_hex = decoded.split(":", 1)
# Verify the signature
secret_key = Settings().secrets.unsubscribe_secret_key
expected_signature = hmac.new(
secret_key.encode("utf-8"), user_id.encode("utf-8"), hashlib.sha256
).digest()
if not hmac.compare_digest(expected_signature.hex(), received_signature_hex):
raise ValueError("Invalid token signature")
user = await get_user_by_id(user_id)
await update_user_notification_preference(
user.id,
NotificationPreferenceDTO(
email=user.email,
daily_limit=0,
preferences={
NotificationType.AGENT_RUN: False,
NotificationType.ZERO_BALANCE: False,
NotificationType.LOW_BALANCE: False,
NotificationType.BLOCK_EXECUTION_FAILED: False,
NotificationType.CONTINUOUS_AGENT_ERROR: False,
NotificationType.DAILY_SUMMARY: False,
NotificationType.WEEKLY_SUMMARY: False,
NotificationType.MONTHLY_SUMMARY: False,
},
),
)
except Exception as e:
raise DatabaseError(f"Failed to unsubscribe user by token {token}: {e}") from e

View File

@@ -1,12 +1,15 @@
from backend.app import run_processes
from backend.executor import ExecutionManager
from backend.executor import DatabaseManager, ExecutionManager
def main():
"""
Run all the processes required for the AutoGPT-server REST API.
"""
run_processes(ExecutionManager())
run_processes(
DatabaseManager(),
ExecutionManager(),
)
if __name__ == "__main__":

View File

@@ -1,9 +1,9 @@
from .database import DatabaseManager
from .manager import ExecutionManager
from .scheduler import Scheduler
from .scheduler import ExecutionScheduler
__all__ = [
"DatabaseManager",
"ExecutionManager",
"Scheduler",
"ExecutionScheduler",
]

View File

@@ -1,55 +1,45 @@
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from functools import wraps
from typing import Any, Callable, Concatenate, Coroutine, ParamSpec, TypeVar, cast
from backend.data.credit import get_user_credit_model
from backend.data.execution import (
GraphExecution,
NodeExecutionResult,
ExecutionResult,
NodeExecutionEntry,
RedisExecutionEventBus,
create_graph_execution,
get_graph_execution,
get_incomplete_node_executions,
get_latest_node_execution,
get_node_execution_results,
update_graph_execution_start_time,
get_execution_results,
get_executions_in_timerange,
get_incomplete_executions,
get_latest_execution,
update_execution_status,
update_graph_execution_stats,
update_node_execution_stats,
update_node_execution_status,
update_node_execution_status_batch,
upsert_execution_input,
upsert_execution_output,
)
from backend.data.graph import (
get_connected_output_nodes,
get_graph,
get_graph_metadata,
get_node,
)
from backend.data.graph import get_graph, get_node
from backend.data.notifications import (
create_or_add_to_user_notification_batch,
empty_user_notification_batch,
get_all_batches_by_type,
get_user_notification_batch,
get_user_notification_oldest_message_in_batch,
get_user_notification_last_message_in_batch,
)
from backend.data.user import (
get_active_user_ids_in_timerange,
get_user_email_by_id,
get_user_email_verification,
get_active_users_ids,
get_user_by_id,
get_user_integrations,
get_user_metadata,
get_user_notification_preference,
update_user_integrations,
update_user_metadata,
)
from backend.util.service import AppService, expose, exposed_run_and_wait
from backend.util.service import AppService, expose, register_pydantic_serializers
from backend.util.settings import Config
P = ParamSpec("P")
R = TypeVar("R")
config = Config()
_user_credit_model = get_user_credit_model()
async def _spend_credits(
user_id: str, cost: int, metadata: UsageTransactionMetadata
) -> int:
return await _user_credit_model.spend_credits(user_id, cost, metadata)
class DatabaseManager(AppService):
@@ -57,70 +47,75 @@ class DatabaseManager(AppService):
super().__init__()
self.use_db = True
self.use_redis = True
self.execution_event_bus = RedisExecutionEventBus()
self.event_queue = RedisExecutionEventBus()
@classmethod
def get_port(cls) -> int:
return config.database_api_port
@expose
def send_execution_update(
self, execution_result: GraphExecution | NodeExecutionResult
):
self.execution_event_bus.publish(execution_result)
def send_execution_update(self, execution_result: ExecutionResult):
self.event_queue.publish(execution_result)
@staticmethod
def exposed_run_and_wait(
f: Callable[P, Coroutine[None, None, R]]
) -> Callable[Concatenate[object, P], R]:
@expose
@wraps(f)
def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
coroutine = f(*args, **kwargs)
res = self.run_and_wait(coroutine)
return res
# Register serializers for annotations on bare function
register_pydantic_serializers(f)
return wrapper
# Executions
get_graph_execution = exposed_run_and_wait(get_graph_execution)
create_graph_execution = exposed_run_and_wait(create_graph_execution)
get_node_execution_results = exposed_run_and_wait(get_node_execution_results)
get_incomplete_node_executions = exposed_run_and_wait(
get_incomplete_node_executions
)
get_latest_node_execution = exposed_run_and_wait(get_latest_node_execution)
update_node_execution_status = exposed_run_and_wait(update_node_execution_status)
update_node_execution_status_batch = exposed_run_and_wait(
update_node_execution_status_batch
)
update_graph_execution_start_time = exposed_run_and_wait(
update_graph_execution_start_time
)
get_execution_results = exposed_run_and_wait(get_execution_results)
get_incomplete_executions = exposed_run_and_wait(get_incomplete_executions)
get_latest_execution = exposed_run_and_wait(get_latest_execution)
update_execution_status = exposed_run_and_wait(update_execution_status)
update_graph_execution_stats = exposed_run_and_wait(update_graph_execution_stats)
update_node_execution_stats = exposed_run_and_wait(update_node_execution_stats)
upsert_execution_input = exposed_run_and_wait(upsert_execution_input)
upsert_execution_output = exposed_run_and_wait(upsert_execution_output)
get_executions_in_timerange = exposed_run_and_wait(get_executions_in_timerange)
# Graphs
get_node = exposed_run_and_wait(get_node)
get_graph = exposed_run_and_wait(get_graph)
get_connected_output_nodes = exposed_run_and_wait(get_connected_output_nodes)
get_graph_metadata = exposed_run_and_wait(get_graph_metadata)
# Credits
spend_credits = exposed_run_and_wait(_spend_credits)
user_credit_model = get_user_credit_model()
spend_credits = cast(
Callable[[Any, NodeExecutionEntry, float, float], int],
exposed_run_and_wait(user_credit_model.spend_credits),
)
# User + User Metadata + User Integrations
# User + User Metadata + User Integrations + User Notification Preferences
get_user_metadata = exposed_run_and_wait(get_user_metadata)
update_user_metadata = exposed_run_and_wait(update_user_metadata)
get_user_integrations = exposed_run_and_wait(get_user_integrations)
update_user_integrations = exposed_run_and_wait(update_user_integrations)
# User Comms - async
get_active_user_ids_in_timerange = exposed_run_and_wait(
get_active_user_ids_in_timerange
)
get_user_email_by_id = exposed_run_and_wait(get_user_email_by_id)
get_user_email_verification = exposed_run_and_wait(get_user_email_verification)
get_user_by_id = exposed_run_and_wait(get_user_by_id)
get_user_notification_preference = exposed_run_and_wait(
get_user_notification_preference
)
get_active_users_ids = exposed_run_and_wait(get_active_users_ids)
# Notifications - async
# Notifications
create_or_add_to_user_notification_batch = exposed_run_and_wait(
create_or_add_to_user_notification_batch
)
empty_user_notification_batch = exposed_run_and_wait(empty_user_notification_batch)
get_all_batches_by_type = exposed_run_and_wait(get_all_batches_by_type)
get_user_notification_batch = exposed_run_and_wait(get_user_notification_batch)
get_user_notification_oldest_message_in_batch = exposed_run_and_wait(
get_user_notification_oldest_message_in_batch
get_user_notification_last_message_in_batch = exposed_run_and_wait(
get_user_notification_last_message_in_batch
)
empty_user_notification_batch = exposed_run_and_wait(empty_user_notification_batch)
get_user_notification_batch = exposed_run_and_wait(get_user_notification_batch)

View File

@@ -12,19 +12,8 @@ from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar, cast
from redis.lock import Lock as RedisLock
from backend.blocks.io import AgentOutputBlock
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
AgentRunData,
LowBalanceData,
NotificationEventDTO,
NotificationType,
)
from backend.util.exceptions import InsufficientBalanceError
if TYPE_CHECKING:
from backend.executor import DatabaseManager
from backend.notifications.notifications import NotificationManager
from autogpt_libs.utils.cache import thread_cached
@@ -40,19 +29,14 @@ from backend.data.block import (
)
from backend.data.execution import (
ExecutionQueue,
ExecutionResult,
ExecutionStatus,
GraphExecutionEntry,
NodeExecutionEntry,
NodeExecutionResult,
merge_execution_input,
parse_execution_output,
)
from backend.data.graph import GraphModel, Link, Node
from backend.executor.utils import (
UsageTransactionMetadata,
block_usage_cost,
execution_usage_cost,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util import json
from backend.util.decorator import error_logged, time_measured
@@ -114,10 +98,7 @@ class LogMetadata:
logger.exception(msg, extra={"json_fields": {**self.metadata, **extra}})
def _wrap(self, msg: str, **extra):
extra_msg = str(extra or "")
if len(extra_msg) > 1000:
extra_msg = extra_msg[:1000] + "..."
return f"{self.prefix} {msg} {extra_msg}"
return f"{self.prefix} {msg} {extra}"
T = TypeVar("T")
@@ -128,7 +109,7 @@ def execute_node(
db_client: "DatabaseManager",
creds_manager: IntegrationCredentialsManager,
data: NodeExecutionEntry,
execution_stats: NodeExecutionStats | None = None,
execution_stats: dict[str, Any] | None = None,
) -> ExecutionStream:
"""
Execute a node in the graph. This will trigger a block execution on a node,
@@ -149,9 +130,8 @@ def execute_node(
node_exec_id = data.node_exec_id
node_id = data.node_id
def update_execution_status(status: ExecutionStatus) -> NodeExecutionResult:
"""Sets status and fetches+broadcasts the latest state of the node execution"""
exec_update = db_client.update_node_execution_status(node_exec_id, status)
def update_execution(status: ExecutionStatus) -> ExecutionResult:
exec_update = db_client.update_execution_status(node_exec_id, status)
db_client.send_execution_update(exec_update)
return exec_update
@@ -162,17 +142,6 @@ def execute_node(
logger.error(f"Block {node.block_id} not found.")
return
def push_output(output_name: str, output_data: Any) -> None:
_push_node_execution_output(
db_client=db_client,
user_id=user_id,
graph_exec_id=graph_exec_id,
node_exec_id=node_exec_id,
block_id=node_block.id,
output_name=output_name,
output_data=output_data,
)
log_metadata = LogMetadata(
user_id=user_id,
graph_eid=graph_exec_id,
@@ -186,8 +155,8 @@ def execute_node(
input_data, error = validate_exec(node, data.data, resolve_input=False)
if input_data is None:
log_metadata.error(f"Skip execution, input validation error: {error}")
push_output("error", error)
update_execution_status(ExecutionStatus.FAILED)
db_client.upsert_execution_output(node_exec_id, "error", error)
update_execution(ExecutionStatus.FAILED)
return
# Re-shape the input data for agent block.
@@ -200,7 +169,7 @@ def execute_node(
input_data_str = json.dumps(input_data)
input_size = len(input_data_str)
log_metadata.info("Executed node with input", input=input_data_str)
update_execution_status(ExecutionStatus.RUNNING)
update_execution(ExecutionStatus.RUNNING)
# Inject extra execution arguments for the blocks via kwargs
extra_exec_kwargs: dict = {
@@ -224,15 +193,18 @@ def execute_node(
output_size = 0
try:
outputs: dict[str, Any] = {}
# Charge the user for the execution before running the block.
# TODO: We assume the block is executed within 0 seconds.
# This is fine because for now, there is no block that is charged by time.
db_client.spend_credits(data, input_size + output_size, 0)
for output_name, output_data in node_block.execute(
input_data, **extra_exec_kwargs
):
output_data = json.convert_pydantic_to_json(output_data)
output_size += len(json.dumps(output_data))
log_metadata.info("Node produced output", **{output_name: output_data})
push_output(output_name, output_data)
outputs[output_name] = output_data
db_client.upsert_execution_output(node_exec_id, output_name, output_data)
for execution in _enqueue_next_nodes(
db_client=db_client,
node=node,
@@ -244,12 +216,12 @@ def execute_node(
):
yield execution
update_execution_status(ExecutionStatus.COMPLETED)
update_execution(ExecutionStatus.COMPLETED)
except Exception as e:
error_msg = str(e)
push_output("error", error_msg)
update_execution_status(ExecutionStatus.FAILED)
db_client.upsert_execution_output(node_exec_id, "error", error_msg)
update_execution(ExecutionStatus.FAILED)
for execution in _enqueue_next_nodes(
db_client=db_client,
@@ -265,7 +237,7 @@ def execute_node(
raise e
finally:
# Ensure credentials are released even if execution fails
if creds_lock and creds_lock.locked():
if creds_lock:
try:
creds_lock.release()
except Exception as e:
@@ -273,40 +245,9 @@ def execute_node(
# Update execution stats
if execution_stats is not None:
execution_stats = execution_stats.model_copy(
update=node_block.execution_stats.model_dump()
)
execution_stats.input_size = input_size
execution_stats.output_size = output_size
def _push_node_execution_output(
db_client: "DatabaseManager",
user_id: str,
graph_exec_id: str,
node_exec_id: str,
block_id: str,
output_name: str,
output_data: Any,
):
from backend.blocks.io import IO_BLOCK_IDs
db_client.upsert_execution_output(
node_exec_id=node_exec_id,
output_name=output_name,
output_data=output_data,
)
# Automatically push execution updates for all agent I/O
if block_id in IO_BLOCK_IDs:
graph_exec = db_client.get_graph_execution(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec:
raise ValueError(
f"Graph execution #{graph_exec_id} for user #{user_id} not found"
)
db_client.send_execution_update(graph_exec)
execution_stats.update(node_block.execution_stats)
execution_stats["input_size"] = input_size
execution_stats["output_size"] = output_size
def _enqueue_next_nodes(
@@ -321,7 +262,7 @@ def _enqueue_next_nodes(
def add_enqueued_execution(
node_exec_id: str, node_id: str, block_id: str, data: BlockInput
) -> NodeExecutionEntry:
exec_update = db_client.update_node_execution_status(
exec_update = db_client.update_execution_status(
node_exec_id, ExecutionStatus.QUEUED, data
)
db_client.send_execution_update(exec_update)
@@ -366,7 +307,7 @@ def _enqueue_next_nodes(
if link.is_static and link.sink_name not in next_node_input
}
if static_link_names and (
latest_execution := db_client.get_latest_node_execution(
latest_execution := db_client.get_latest_execution(
next_node_id, graph_exec_id
)
):
@@ -399,7 +340,7 @@ def _enqueue_next_nodes(
# If link is static, there could be some incomplete executions waiting for it.
# Load and complete the input missing input data, and try to re-enqueue them.
for iexec in db_client.get_incomplete_node_executions(
for iexec in db_client.get_incomplete_executions(
next_node_id, graph_exec_id
):
idata = iexec.input_data
@@ -459,30 +400,46 @@ def validate_exec(
node_block: Block | None = get_block(node.block_id)
if not node_block:
return None, f"Block for {node.block_id} not found."
schema = node_block.input_schema
# Convert non-matching data types to the expected input schema.
for name, data_type in schema.__annotations__.items():
if (value := data.get(name)) and (type(value) is not data_type):
data[name] = convert(value, data_type)
if isinstance(node_block, AgentExecutorBlock):
# Validate the execution metadata for the agent executor block.
try:
exec_data = AgentExecutorBlock.Input(**node.input_default)
except Exception as e:
return None, f"Input data doesn't match {node_block.name}: {str(e)}"
# Validation input
input_schema = exec_data.input_schema
required_fields = set(input_schema["required"])
input_default = exec_data.data
else:
# Convert non-matching data types to the expected input schema.
for name, data_type in node_block.input_schema.__annotations__.items():
if (value := data.get(name)) and (type(value) is not data_type):
data[name] = convert(value, data_type)
# Validation input
input_schema = node_block.input_schema.jsonschema()
required_fields = node_block.input_schema.get_required_fields()
input_default = node.input_default
# Input data (without default values) should contain all required fields.
error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
if missing_links := schema.get_missing_links(data, node.input_links):
return None, f"{error_prefix} unpopulated links {missing_links}"
input_fields_from_nodes = {link.sink_name for link in node.input_links}
if not input_fields_from_nodes.issubset(data):
return None, f"{error_prefix} {input_fields_from_nodes - set(data)}"
# Merge input data with default values and resolve dynamic dict/list/object pins.
input_default = schema.get_input_defaults(node.input_default)
data = {**input_default, **data}
if resolve_input:
data = merge_execution_input(data)
# Input data post-merge should contain all required fields from the schema.
if missing_input := schema.get_missing_input(data):
return None, f"{error_prefix} missing input {missing_input}"
if not required_fields.issubset(data):
return None, f"{error_prefix} {required_fields - set(data)}"
# Last validation: Validate the input values against the schema.
if error := schema.get_mismatch_error(data):
if error := json.validate_with_jsonschema(schema=input_schema, data=data):
error_message = f"{error_prefix} {error}"
logger.error(error_message)
return None, error_message
@@ -563,7 +520,7 @@ class Executor:
cls,
q: ExecutionQueue[NodeExecutionEntry],
node_exec: NodeExecutionEntry,
) -> NodeExecutionStats:
) -> dict[str, Any]:
log_metadata = LogMetadata(
user_id=node_exec.user_id,
graph_eid=node_exec.graph_exec_id,
@@ -573,15 +530,13 @@ class Executor:
block_name="-",
)
execution_stats = NodeExecutionStats()
execution_stats = {}
timing_info, _ = cls._on_node_execution(
q, node_exec, log_metadata, execution_stats
)
execution_stats.walltime = timing_info.wall_time
execution_stats.cputime = timing_info.cpu_time
execution_stats["walltime"] = timing_info.wall_time
execution_stats["cputime"] = timing_info.cpu_time
if isinstance(execution_stats.error, Exception):
execution_stats.error = str(execution_stats.error)
cls.db_client.update_node_execution_stats(
node_exec.node_exec_id, execution_stats
)
@@ -594,15 +549,12 @@ class Executor:
q: ExecutionQueue[NodeExecutionEntry],
node_exec: NodeExecutionEntry,
log_metadata: LogMetadata,
stats: NodeExecutionStats | None = None,
stats: dict[str, Any] | None = None,
):
try:
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
for execution in execute_node(
db_client=cls.db_client,
creds_manager=cls.creds_manager,
data=node_exec,
execution_stats=stats,
cls.db_client, cls.creds_manager, node_exec, stats
):
q.add(execution)
log_metadata.info(f"Finished node execution {node_exec.node_exec_id}")
@@ -617,9 +569,6 @@ class Executor:
f"Failed node execution {node_exec.node_exec_id}: {e}"
)
if stats is not None:
stats.error = e
@classmethod
def on_graph_executor_start(cls):
configure_logging()
@@ -628,7 +577,6 @@ class Executor:
cls.db_client = get_db_client()
cls.pool_size = settings.config.num_node_workers
cls.pid = os.getpid()
cls.notification_service = get_notification_service()
cls._init_node_executor_pool()
logger.info(
f"Graph executor {cls.pid} started with {cls.pool_size} node workers"
@@ -666,72 +614,18 @@ class Executor:
node_eid="*",
block_name="-",
)
exec_meta = cls.db_client.update_graph_execution_start_time(
graph_exec.graph_exec_id
)
cls.db_client.send_execution_update(exec_meta)
timing_info, (exec_stats, status, error) = cls._on_graph_execution(
graph_exec, cancel, log_metadata
)
exec_stats.walltime = timing_info.wall_time
exec_stats.cputime = timing_info.cpu_time
exec_stats.error = str(error)
if graph_exec_result := cls.db_client.update_graph_execution_stats(
exec_stats["walltime"] = timing_info.wall_time
exec_stats["cputime"] = timing_info.cpu_time
exec_stats["error"] = str(error) if error else None
result = cls.db_client.update_graph_execution_stats(
graph_exec_id=graph_exec.graph_exec_id,
status=status,
stats=exec_stats,
):
cls.db_client.send_execution_update(graph_exec_result)
cls._handle_agent_run_notif(graph_exec, exec_stats)
@classmethod
def _charge_usage(
cls,
node_exec: NodeExecutionEntry,
execution_count: int,
execution_stats: GraphExecutionStats,
) -> int:
block = get_block(node_exec.block_id)
if not block:
logger.error(f"Block {node_exec.block_id} not found.")
return execution_count
cost, matching_filter = block_usage_cost(block=block, input_data=node_exec.data)
if cost > 0:
cls.db_client.spend_credits(
user_id=node_exec.user_id,
cost=cost,
metadata=UsageTransactionMetadata(
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
node_exec_id=node_exec.node_exec_id,
node_id=node_exec.node_id,
block_id=node_exec.block_id,
block=block.name,
input=matching_filter,
),
)
execution_stats.cost += cost
cost, execution_count = execution_usage_cost(execution_count)
if cost > 0:
cls.db_client.spend_credits(
user_id=node_exec.user_id,
cost=cost,
metadata=UsageTransactionMetadata(
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
input={
"execution_count": execution_count,
"charge": "Execution Cost",
},
),
)
execution_stats.cost += cost
return execution_count
)
cls.db_client.send_execution_update(result)
@classmethod
@time_measured
@@ -740,7 +634,7 @@ class Executor:
graph_exec: GraphExecutionEntry,
cancel: threading.Event,
log_metadata: LogMetadata,
) -> tuple[GraphExecutionStats, ExecutionStatus, Exception | None]:
) -> tuple[dict[str, Any], ExecutionStatus, Exception | None]:
"""
Returns:
dict: The execution statistics of the graph execution.
@@ -748,7 +642,11 @@ class Executor:
Exception | None: The error that occurred during the execution, if any.
"""
log_metadata.info(f"Start graph execution {graph_exec.graph_exec_id}")
exec_stats = GraphExecutionStats()
exec_stats = {
"nodes_walltime": 0,
"nodes_cputime": 0,
"node_count": 0,
}
error = None
finished = False
@@ -767,24 +665,24 @@ class Executor:
try:
queue = ExecutionQueue[NodeExecutionEntry]()
for node_exec in graph_exec.start_node_execs:
exec_update = cls.db_client.update_execution_status(
node_exec.node_exec_id, ExecutionStatus.QUEUED, node_exec.data
)
cls.db_client.send_execution_update(exec_update)
queue.add(node_exec)
exec_cost_counter = 0
running_executions: dict[str, AsyncResult] = {}
def make_exec_callback(exec_data: NodeExecutionEntry):
node_id = exec_data.node_id
def callback(result: object):
running_executions.pop(exec_data.node_id)
if not isinstance(result, NodeExecutionStats):
return
running_executions.pop(node_id)
nonlocal exec_stats
exec_stats.node_count += 1
exec_stats.nodes_cputime += result.cputime
exec_stats.nodes_walltime += result.walltime
if (err := result.error) and isinstance(err, Exception):
exec_stats.node_error_count += 1
if isinstance(result, dict):
exec_stats["node_count"] += 1
exec_stats["nodes_cputime"] += result.get("cputime", 0)
exec_stats["nodes_walltime"] += result.get("walltime", 0)
return callback
@@ -807,38 +705,6 @@ class Executor:
f"Dispatching node execution {exec_data.node_exec_id} "
f"for node {exec_data.node_id}",
)
try:
exec_cost_counter = cls._charge_usage(
node_exec=exec_data,
execution_count=exec_cost_counter + 1,
execution_stats=exec_stats,
)
except InsufficientBalanceError as error:
node_exec_id = exec_data.node_exec_id
_push_node_execution_output(
db_client=cls.db_client,
user_id=graph_exec.user_id,
graph_exec_id=graph_exec.graph_exec_id,
node_exec_id=node_exec_id,
block_id=exec_data.block_id,
output_name="error",
output_data=str(error),
)
exec_update = cls.db_client.update_node_execution_status(
node_exec_id, ExecutionStatus.FAILED
)
cls.db_client.send_execution_update(exec_update)
cls._handle_low_balance_notif(
graph_exec.user_id,
graph_exec.graph_id,
exec_stats,
error,
)
raise
running_executions[exec_data.node_id] = cls.executor.apply_async(
cls.on_node_execution,
(queue, exec_data),
@@ -861,87 +727,22 @@ class Executor:
execution.wait(3)
log_metadata.info(f"Finished graph execution {graph_exec.graph_exec_id}")
except Exception as e:
log_metadata.exception(
f"Failed graph execution {graph_exec.graph_exec_id}: {e}"
)
error = e
finally:
if error:
log_metadata.error(
f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
)
execution_status = ExecutionStatus.FAILED
else:
execution_status = ExecutionStatus.COMPLETED
if not cancel.is_set():
finished = True
cancel.set()
cancel_thread.join()
clean_exec_files(graph_exec.graph_exec_id)
return exec_stats, execution_status, error
@classmethod
def _handle_agent_run_notif(
cls,
graph_exec: GraphExecutionEntry,
exec_stats: GraphExecutionStats,
):
metadata = cls.db_client.get_graph_metadata(
graph_exec.graph_id, graph_exec.graph_version
)
outputs = cls.db_client.get_node_execution_results(
graph_exec.graph_exec_id,
block_ids=[AgentOutputBlock().id],
)
named_outputs = [
{
key: value[0] if key == "name" else value
for key, value in output.output_data.items()
}
for output in outputs
]
event = NotificationEventDTO(
user_id=graph_exec.user_id,
type=NotificationType.AGENT_RUN,
data=AgentRunData(
outputs=named_outputs,
agent_name=metadata.name if metadata else "Unknown Agent",
credits_used=exec_stats.cost,
execution_time=exec_stats.walltime,
graph_id=graph_exec.graph_id,
node_count=exec_stats.node_count,
).model_dump(),
)
cls.notification_service.queue_notification(event)
@classmethod
def _handle_low_balance_notif(
cls,
user_id: str,
graph_id: str,
exec_stats: GraphExecutionStats,
e: InsufficientBalanceError,
):
shortfall = e.balance - e.amount
metadata = cls.db_client.get_graph_metadata(graph_id)
base_url = (
settings.config.frontend_base_url or settings.config.platform_base_url
)
cls.notification_service.queue_notification(
NotificationEventDTO(
user_id=user_id,
type=NotificationType.LOW_BALANCE,
data=LowBalanceData(
current_balance=exec_stats.cost,
billing_page_link=f"{base_url}/profile/credits",
shortfall=shortfall,
agent_name=metadata.name if metadata else "Unknown Agent",
).model_dump(),
)
return (
exec_stats,
ExecutionStatus.FAILED if error else ExecutionStatus.COMPLETED,
error,
)
@@ -1046,26 +847,17 @@ class ExecutionManager(AppService):
else:
nodes_input.append((node.id, input_data))
if not nodes_input:
raise ValueError(
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
)
graph_exec = self.db_client.create_graph_execution(
graph_exec_id, node_execs = self.db_client.create_graph_execution(
graph_id=graph_id,
graph_version=graph.version,
nodes_input=nodes_input,
user_id=user_id,
preset_id=preset_id,
)
self.db_client.send_execution_update(graph_exec)
graph_exec_entry = GraphExecutionEntry(
user_id=user_id,
graph_id=graph_id,
graph_version=graph_version or 0,
graph_exec_id=graph_exec.id,
start_node_execs=[
starting_node_execs = []
for node_exec in node_execs:
starting_node_execs.append(
NodeExecutionEntry(
user_id=user_id,
graph_exec_id=node_exec.graph_exec_id,
@@ -1075,12 +867,18 @@ class ExecutionManager(AppService):
block_id=node_exec.block_id,
data=node_exec.input_data,
)
for node_exec in graph_exec.node_executions
],
)
self.queue.add(graph_exec_entry)
)
return graph_exec_entry
graph_exec = GraphExecutionEntry(
user_id=user_id,
graph_id=graph_id,
graph_version=graph_version or 0,
graph_exec_id=graph_exec_id,
start_node_execs=starting_node_execs,
)
self.queue.add(graph_exec)
return graph_exec
@expose
def cancel_execution(self, graph_exec_id: str) -> None:
@@ -1092,36 +890,29 @@ class ExecutionManager(AppService):
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
"""
if graph_exec_id not in self.active_graph_runs:
logger.warning(
raise Exception(
f"Graph execution #{graph_exec_id} not active/running: "
"possibly already completed/cancelled."
)
else:
future, cancel_event = self.active_graph_runs[graph_exec_id]
if not cancel_event.is_set():
cancel_event.set()
future.result()
# Update the status of the graph & node executions
self.db_client.update_graph_execution_stats(
graph_exec_id,
ExecutionStatus.TERMINATED,
)
node_execs = self.db_client.get_node_execution_results(
graph_exec_id=graph_exec_id,
statuses=[
ExecutionStatus.QUEUED,
ExecutionStatus.RUNNING,
ExecutionStatus.INCOMPLETE,
],
)
self.db_client.update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in node_execs],
ExecutionStatus.TERMINATED,
)
future, cancel_event = self.active_graph_runs[graph_exec_id]
if cancel_event.is_set():
return
cancel_event.set()
future.result()
# Update the status of the unfinished node executions
node_execs = self.db_client.get_execution_results(graph_exec_id)
for node_exec in node_execs:
node_exec.status = ExecutionStatus.TERMINATED
self.db_client.send_execution_update(node_exec)
if node_exec.status not in (
ExecutionStatus.COMPLETED,
ExecutionStatus.FAILED,
):
exec_update = self.db_client.update_execution_status(
node_exec.node_exec_id, ExecutionStatus.TERMINATED
)
self.db_client.send_execution_update(exec_update)
def _validate_node_input_credentials(self, graph: GraphModel, user_id: str):
"""Checks all credentials for all nodes of the graph"""
@@ -1177,13 +968,6 @@ def get_db_client() -> "DatabaseManager":
return get_service_client(DatabaseManager)
@thread_cached
def get_notification_service() -> "NotificationManager":
from backend.notifications import NotificationManager
return get_service_client(NotificationManager)
@contextmanager
def synchronized(key: str, timeout: int = 60):
lock: RedisLock = redis.get_redis().lock(f"lock:{key}", timeout=timeout)

View File

@@ -1,23 +1,19 @@
import logging
import os
from enum import Enum
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from apscheduler.events import EVENT_JOB_ERROR, EVENT_JOB_EXECUTED
from apscheduler.job import Job as JobObj
from apscheduler.jobstores.memory import MemoryJobStore
from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
from apscheduler.schedulers.blocking import BlockingScheduler
from apscheduler.triggers.cron import CronTrigger
from autogpt_libs.utils.cache import thread_cached
from dotenv import load_dotenv
from prisma.enums import NotificationType
from pydantic import BaseModel
from sqlalchemy import MetaData, create_engine
from backend.data.block import BlockInput
from backend.executor.manager import ExecutionManager
from backend.notifications.notifications import NotificationManager
from backend.util.service import AppService, expose, get_service_client
from backend.util.settings import Config
@@ -46,7 +42,7 @@ config = Config()
def log(msg, **kwargs):
logger.info("[Scheduler] " + msg, **kwargs)
logger.info("[ExecutionScheduler] " + msg, **kwargs)
def job_listener(event):
@@ -62,15 +58,8 @@ def get_execution_client() -> ExecutionManager:
return get_service_client(ExecutionManager)
@thread_cached
def get_notification_client():
from backend.notifications import NotificationManager
return get_service_client(NotificationManager)
def execute_graph(**kwargs):
args = ExecutionJobArgs(**kwargs)
args = JobArgs(**kwargs)
try:
log(f"Executing recurring job for graph #{args.graph_id}")
get_execution_client().add_execution(
@@ -83,32 +72,7 @@ def execute_graph(**kwargs):
logger.exception(f"Error executing graph {args.graph_id}: {e}")
def process_existing_batches(**kwargs):
args = NotificationJobArgs(**kwargs)
try:
log(
f"Processing existing batches for notification type {args.notification_types}"
)
get_notification_client().process_existing_batches(args.notification_types)
except Exception as e:
logger.exception(f"Error processing existing batches: {e}")
def process_weekly_summary(**kwargs):
try:
log("Processing weekly summary")
get_notification_client().queue_weekly_summary()
except Exception as e:
logger.exception(f"Error processing weekly summary: {e}")
class Jobstores(Enum):
EXECUTION = "execution"
BATCHED_NOTIFICATIONS = "batched_notifications"
WEEKLY_NOTIFICATIONS = "weekly_notifications"
class ExecutionJobArgs(BaseModel):
class JobArgs(BaseModel):
graph_id: str
input_data: BlockInput
user_id: str
@@ -116,14 +80,14 @@ class ExecutionJobArgs(BaseModel):
cron: str
class ExecutionJobInfo(ExecutionJobArgs):
class JobInfo(JobArgs):
id: str
name: str
next_run_time: str
@staticmethod
def from_db(job_args: ExecutionJobArgs, job_obj: JobObj) -> "ExecutionJobInfo":
return ExecutionJobInfo(
def from_db(job_args: JobArgs, job_obj: JobObj) -> "JobInfo":
return JobInfo(
id=job_obj.id,
name=job_obj.name,
next_run_time=job_obj.next_run_time.isoformat(),
@@ -131,29 +95,7 @@ class ExecutionJobInfo(ExecutionJobArgs):
)
class NotificationJobArgs(BaseModel):
notification_types: list[NotificationType]
cron: str
class NotificationJobInfo(NotificationJobArgs):
id: str
name: str
next_run_time: str
@staticmethod
def from_db(
job_args: NotificationJobArgs, job_obj: JobObj
) -> "NotificationJobInfo":
return NotificationJobInfo(
id=job_obj.id,
name=job_obj.name,
next_run_time=job_obj.next_run_time.isoformat(),
**job_args.model_dump(),
)
class Scheduler(AppService):
class ExecutionScheduler(AppService):
scheduler: BlockingScheduler
@classmethod
@@ -169,38 +111,19 @@ class Scheduler(AppService):
def execution_client(self) -> ExecutionManager:
return get_service_client(ExecutionManager)
@property
@thread_cached
def notification_client(self) -> NotificationManager:
return get_service_client(NotificationManager)
def run_service(self):
load_dotenv()
db_schema, db_url = _extract_schema_from_url(os.getenv("DATABASE_URL"))
self.scheduler = BlockingScheduler(
jobstores={
Jobstores.EXECUTION.value: SQLAlchemyJobStore(
"default": SQLAlchemyJobStore(
engine=create_engine(
url=db_url,
pool_size=self.db_pool_size(),
max_overflow=0,
),
metadata=MetaData(schema=db_schema),
# this one is pre-existing so it keeps the
# default table name.
tablename="apscheduler_jobs",
),
Jobstores.BATCHED_NOTIFICATIONS.value: SQLAlchemyJobStore(
engine=create_engine(
url=db_url,
pool_size=self.db_pool_size(),
max_overflow=0,
),
metadata=MetaData(schema=db_schema),
tablename="apscheduler_jobs_batched_notifications",
),
# These don't really need persistence
Jobstores.WEEKLY_NOTIFICATIONS.value: MemoryJobStore(),
)
}
)
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
@@ -214,8 +137,8 @@ class Scheduler(AppService):
cron: str,
input_data: BlockInput,
user_id: str,
) -> ExecutionJobInfo:
job_args = ExecutionJobArgs(
) -> JobInfo:
job_args = JobArgs(
graph_id=graph_id,
input_data=input_data,
user_id=user_id,
@@ -227,80 +150,37 @@ class Scheduler(AppService):
CronTrigger.from_crontab(cron),
kwargs=job_args.model_dump(),
replace_existing=True,
jobstore=Jobstores.EXECUTION.value,
)
log(f"Added job {job.id} with cron schedule '{cron}' input data: {input_data}")
return ExecutionJobInfo.from_db(job_args, job)
return JobInfo.from_db(job_args, job)
@expose
def delete_schedule(self, schedule_id: str, user_id: str) -> ExecutionJobInfo:
job = self.scheduler.get_job(schedule_id, jobstore=Jobstores.EXECUTION.value)
def delete_schedule(self, schedule_id: str, user_id: str) -> JobInfo:
job = self.scheduler.get_job(schedule_id)
if not job:
log(f"Job {schedule_id} not found.")
raise ValueError(f"Job #{schedule_id} not found.")
job_args = ExecutionJobArgs(**job.kwargs)
job_args = JobArgs(**job.kwargs)
if job_args.user_id != user_id:
raise ValueError("User ID does not match the job's user ID.")
log(f"Deleting job {schedule_id}")
job.remove()
return ExecutionJobInfo.from_db(job_args, job)
return JobInfo.from_db(job_args, job)
@expose
def get_execution_schedules(
self, graph_id: str | None = None, user_id: str | None = None
) -> list[ExecutionJobInfo]:
) -> list[JobInfo]:
schedules = []
for job in self.scheduler.get_jobs(jobstore=Jobstores.EXECUTION.value):
logger.info(
f"Found job {job.id} with cron schedule {job.trigger} and args {job.kwargs}"
)
job_args = ExecutionJobArgs(**job.kwargs)
for job in self.scheduler.get_jobs():
job_args = JobArgs(**job.kwargs)
if (
job.next_run_time is not None
and (graph_id is None or job_args.graph_id == graph_id)
and (user_id is None or job_args.user_id == user_id)
):
schedules.append(ExecutionJobInfo.from_db(job_args, job))
schedules.append(JobInfo.from_db(job_args, job))
return schedules
@expose
def add_batched_notification_schedule(
self,
notification_types: list[NotificationType],
data: dict,
cron: str,
) -> NotificationJobInfo:
job_args = NotificationJobArgs(
notification_types=notification_types,
cron=cron,
)
job = self.scheduler.add_job(
process_existing_batches,
CronTrigger.from_crontab(cron),
kwargs=job_args.model_dump(),
replace_existing=True,
jobstore=Jobstores.BATCHED_NOTIFICATIONS.value,
)
log(f"Added job {job.id} with cron schedule '{cron}' input data: {data}")
return NotificationJobInfo.from_db(job_args, job)
@expose
def add_weekly_notification_schedule(self, cron: str) -> NotificationJobInfo:
job = self.scheduler.add_job(
process_weekly_summary,
CronTrigger.from_crontab(cron),
kwargs={},
replace_existing=True,
jobstore=Jobstores.WEEKLY_NOTIFICATIONS.value,
)
log(f"Added job {job.id} with cron schedule '{cron}'")
return NotificationJobInfo.from_db(
NotificationJobArgs(
cron=cron, notification_types=[NotificationType.WEEKLY_SUMMARY]
),
job,
)

View File

@@ -1,97 +0,0 @@
from pydantic import BaseModel
from backend.data.block import Block, BlockInput
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.cost import BlockCostType
from backend.util.settings import Config
config = Config()
class UsageTransactionMetadata(BaseModel):
graph_exec_id: str | None = None
graph_id: str | None = None
node_id: str | None = None
node_exec_id: str | None = None
block_id: str | None = None
block: str | None = None
input: BlockInput | None = None
def execution_usage_cost(execution_count: int) -> tuple[int, int]:
"""
Calculate the cost of executing a graph based on the number of executions.
Args:
execution_count: Number of executions
Returns:
Tuple of cost amount and remaining execution count
"""
return (
execution_count
// config.execution_cost_count_threshold
* config.execution_cost_per_threshold,
execution_count % config.execution_cost_count_threshold,
)
def block_usage_cost(
block: Block,
input_data: BlockInput,
data_size: float = 0,
run_time: float = 0,
) -> tuple[int, BlockInput]:
"""
Calculate the cost of using a block based on the input data and the block type.
Args:
block: Block object
input_data: Input data for the block
data_size: Size of the input data in bytes
run_time: Execution time of the block in seconds
Returns:
Tuple of cost amount and cost filter
"""
block_costs = BLOCK_COSTS.get(type(block))
if not block_costs:
return 0, {}
for block_cost in block_costs:
if not _is_cost_filter_match(block_cost.cost_filter, input_data):
continue
if block_cost.cost_type == BlockCostType.RUN:
return block_cost.cost_amount, block_cost.cost_filter
if block_cost.cost_type == BlockCostType.SECOND:
return (
int(run_time * block_cost.cost_amount),
block_cost.cost_filter,
)
if block_cost.cost_type == BlockCostType.BYTE:
return (
int(data_size * block_cost.cost_amount),
block_cost.cost_filter,
)
return 0, {}
def _is_cost_filter_match(cost_filter: BlockInput, input_data: BlockInput) -> bool:
"""
Filter rules:
- If cost_filter is an object, then check if cost_filter is the subset of input_data
- Otherwise, check if cost_filter is equal to input_data.
- Undefined, null, and empty string are considered as equal.
"""
if not isinstance(cost_filter, dict) or not isinstance(input_data, dict):
return cost_filter == input_data
return all(
(not input_data.get(k) and not v)
or (input_data.get(k) and _is_cost_filter_match(v, input_data[k]))
for k, v in cost_filter.items()
)

View File

@@ -145,29 +145,6 @@ mem0_credentials = APIKeyCredentials(
expires_at=None,
)
apollo_credentials = APIKeyCredentials(
id="544c62b5-1d0f-4156-8fb4-9525f11656eb",
provider="apollo",
api_key=SecretStr(settings.secrets.apollo_api_key),
title="Use Credits for Apollo",
expires_at=None,
)
smartlead_credentials = APIKeyCredentials(
id="3bcdbda3-84a3-46af-8fdb-bfd2472298b8",
provider="smartlead",
api_key=SecretStr(settings.secrets.smartlead_api_key),
title="Use Credits for SmartLead",
expires_at=None,
)
zerobounce_credentials = APIKeyCredentials(
id="63a6e279-2dc2-448e-bf57-85776f7176dc",
provider="zerobounce",
api_key=SecretStr(settings.secrets.zerobounce_api_key),
title="Use Credits for ZeroBounce",
expires_at=None,
)
DEFAULT_CREDENTIALS = [
ollama_credentials,
@@ -187,9 +164,6 @@ DEFAULT_CREDENTIALS = [
mem0_credentials,
nvidia_credentials,
screenshotone_credentials,
apollo_credentials,
smartlead_credentials,
zerobounce_credentials,
]
@@ -257,12 +231,6 @@ class IntegrationCredentialsStore:
all_credentials.append(screenshotone_credentials)
if settings.secrets.mem0_api_key:
all_credentials.append(mem0_credentials)
if settings.secrets.apollo_api_key:
all_credentials.append(apollo_credentials)
if settings.secrets.smartlead_api_key:
all_credentials.append(smartlead_credentials)
if settings.secrets.zerobounce_api_key:
all_credentials.append(zerobounce_credentials)
return all_credentials
def get_creds_by_id(self, user_id: str, credentials_id: str) -> Credentials | None:

View File

@@ -4,7 +4,6 @@ from enum import Enum
# --8<-- [start:ProviderName]
class ProviderName(str, Enum):
ANTHROPIC = "anthropic"
APOLLO = "apollo"
COMPASS = "compass"
DISCORD = "discord"
D_ID = "d_id"
@@ -33,10 +32,8 @@ class ProviderName(str, Enum):
REVID = "revid"
SCREENSHOTONE = "screenshotone"
SLANT3D = "slant3d"
SMARTLEAD = "smartlead"
SMTP = "smtp"
TWITTER = "twitter"
TODOIST = "todoist"
UNREAL_SPEECH = "unreal_speech"
ZEROBOUNCE = "zerobounce"
# --8<-- [end:ProviderName]

View File

@@ -1,43 +1,22 @@
from typing import TYPE_CHECKING
from .compass import CompassWebhookManager
from .github import GithubWebhooksManager
from .slant3d import Slant3DWebhooksManager
if TYPE_CHECKING:
from ..providers import ProviderName
from ._base import BaseWebhooksManager
_WEBHOOK_MANAGERS: dict["ProviderName", type["BaseWebhooksManager"]] = {}
# --8<-- [start:WEBHOOK_MANAGERS_BY_NAME]
WEBHOOK_MANAGERS_BY_NAME: dict["ProviderName", type["BaseWebhooksManager"]] = {
handler.PROVIDER_NAME: handler
for handler in [
CompassWebhookManager,
GithubWebhooksManager,
Slant3DWebhooksManager,
]
}
# --8<-- [end:WEBHOOK_MANAGERS_BY_NAME]
# --8<-- [start:load_webhook_managers]
def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]:
if _WEBHOOK_MANAGERS:
return _WEBHOOK_MANAGERS
from .compass import CompassWebhookManager
from .github import GithubWebhooksManager
from .slant3d import Slant3DWebhooksManager
_WEBHOOK_MANAGERS.update(
{
handler.PROVIDER_NAME: handler
for handler in [
CompassWebhookManager,
GithubWebhooksManager,
Slant3DWebhooksManager,
]
}
)
return _WEBHOOK_MANAGERS
# --8<-- [end:load_webhook_managers]
def get_webhook_manager(provider_name: "ProviderName") -> "BaseWebhooksManager":
return load_webhook_managers()[provider_name]()
def supports_webhooks(provider_name: "ProviderName") -> bool:
return provider_name in load_webhook_managers()
__all__ = ["get_webhook_manager", "supports_webhooks"]
__all__ = ["WEBHOOK_MANAGERS_BY_NAME"]

View File

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Callable, Optional, cast
from backend.data.block import BlockSchema, BlockWebhookConfig, get_block
from backend.data.graph import set_node_webhook
from backend.integrations.webhooks import get_webhook_manager, supports_webhooks
from backend.integrations.webhooks import WEBHOOK_MANAGERS_BY_NAME
if TYPE_CHECKING:
from backend.data.graph import GraphModel, NodeModel
@@ -123,7 +123,7 @@ async def on_node_activate(
return node
provider = block.webhook_config.provider
if not supports_webhooks(provider):
if provider not in WEBHOOK_MANAGERS_BY_NAME:
raise ValueError(
f"Block #{block.id} has webhook_config for provider {provider} "
"which does not support webhooks"
@@ -133,7 +133,7 @@ async def on_node_activate(
f"Activating webhook node #{node.id} with config {block.webhook_config}"
)
webhooks_manager = get_webhook_manager(provider)
webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
if auto_setup_webhook := isinstance(block.webhook_config, BlockWebhookConfig):
try:
@@ -234,13 +234,13 @@ async def on_node_deactivate(
return node
provider = block.webhook_config.provider
if not supports_webhooks(provider):
if provider not in WEBHOOK_MANAGERS_BY_NAME:
raise ValueError(
f"Block #{block.id} has webhook_config for provider {provider} "
"which does not support webhooks"
)
webhooks_manager = get_webhook_manager(provider)
webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
if node.webhook_id:
logger.debug(f"Node #{node.id} has webhook_id {node.webhook_id}")

View File

@@ -1,5 +0,0 @@
from .notifications import NotificationManager
__all__ = [
"NotificationManager",
]

View File

@@ -1,136 +0,0 @@
import logging
import pathlib
from postmarker.core import PostmarkClient
from postmarker.models.emails import EmailManager
from prisma.enums import NotificationType
from pydantic import BaseModel
from backend.data.notifications import (
NotificationDataType_co,
NotificationEventModel,
NotificationTypeOverride,
)
from backend.util.settings import Settings
from backend.util.text import TextFormatter
logger = logging.getLogger(__name__)
settings = Settings()
# The following is a workaround to get the type checker to recognize the EmailManager type
# This is a temporary solution and should be removed once the Postmark library is updated
# to support type annotations.
class TypedPostmarkClient(PostmarkClient):
emails: EmailManager
class Template(BaseModel):
subject_template: str
body_template: str
base_template: str
class EmailSender:
def __init__(self):
if settings.secrets.postmark_server_api_token:
self.postmark = TypedPostmarkClient(
server_token=settings.secrets.postmark_server_api_token
)
else:
logger.warning(
"Postmark server API token not found, email sending disabled"
)
self.postmark = None
self.formatter = TextFormatter()
def send_templated(
self,
notification: NotificationType,
user_email: str,
data: (
NotificationEventModel[NotificationDataType_co]
| list[NotificationEventModel[NotificationDataType_co]]
),
user_unsub_link: str | None = None,
):
"""Send an email to a user using a template pulled from the notification type"""
if not self.postmark:
logger.warning("Postmark client not initialized, email not sent")
return
template = self._get_template(notification)
base_url = (
settings.config.frontend_base_url or settings.config.platform_base_url
)
# Handle the case when data is a list
template_data = data
if isinstance(data, list):
# Create a dictionary with a 'notifications' key containing the list
template_data = {"notifications": data}
try:
subject, full_message = self.formatter.format_email(
base_template=template.base_template,
subject_template=template.subject_template,
content_template=template.body_template,
data=template_data,
unsubscribe_link=f"{base_url}/profile/settings",
)
except Exception as e:
logger.error(f"Error formatting full message: {e}")
raise e
self._send_email(
user_email=user_email,
user_unsubscribe_link=user_unsub_link,
subject=subject,
body=full_message,
)
def _get_template(self, notification: NotificationType):
# convert the notification type to a notification type override
notification_type_override = NotificationTypeOverride(notification)
# find the template in templates/name.html (the .template returns with the .html)
template_path = f"templates/{notification_type_override.template}.jinja2"
logger.debug(
f"Template full path: {pathlib.Path(__file__).parent / template_path}"
)
base_template_path = "templates/base.html.jinja2"
with open(pathlib.Path(__file__).parent / base_template_path, "r") as file:
base_template = file.read()
with open(pathlib.Path(__file__).parent / template_path, "r") as file:
template = file.read()
return Template(
subject_template=notification_type_override.subject,
body_template=template,
base_template=base_template,
)
def _send_email(
self,
user_email: str,
subject: str,
body: str,
user_unsubscribe_link: str | None = None,
):
if not self.postmark:
logger.warning("Email tried to send without postmark configured")
return
logger.debug(f"Sending email to {user_email} with subject {subject}")
self.postmark.emails.send(
From=settings.config.postmark_sender_email,
To=user_email,
Subject=subject,
HtmlBody=body,
# Headers default to None internally so this is fine
Headers=(
{
"List-Unsubscribe-Post": "List-Unsubscribe=One-Click",
"List-Unsubscribe": f"<{user_unsubscribe_link}>",
}
if user_unsubscribe_link
else None
),
)

View File

@@ -1,747 +0,0 @@
import logging
import time
from datetime import datetime, timedelta, timezone
from typing import Callable
import aio_pika
from aio_pika.exceptions import QueueEmpty
from autogpt_libs.utils.cache import thread_cached
from prisma.enums import NotificationType
from pydantic import BaseModel
from backend.data.notifications import (
BaseSummaryData,
BaseSummaryParams,
DailySummaryData,
DailySummaryParams,
NotificationEventDTO,
NotificationEventModel,
NotificationResult,
NotificationTypeOverride,
QueueType,
SummaryParamsEventDTO,
SummaryParamsEventModel,
WeeklySummaryData,
WeeklySummaryParams,
get_batch_delay,
get_notif_data_type,
get_summary_params_type,
)
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.data.user import generate_unsubscribe_link
from backend.notifications.email import EmailSender
from backend.util.service import AppService, expose, get_service_client
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
class NotificationEvent(BaseModel):
event: NotificationEventDTO
model: NotificationEventModel
def create_notification_config() -> RabbitMQConfig:
"""Create RabbitMQ configuration for notifications"""
notification_exchange = Exchange(name="notifications", type=ExchangeType.TOPIC)
dead_letter_exchange = Exchange(name="dead_letter", type=ExchangeType.TOPIC)
queues = [
# Main notification queues
Queue(
name="immediate_notifications",
exchange=notification_exchange,
routing_key="notification.immediate.#",
arguments={
"x-dead-letter-exchange": dead_letter_exchange.name,
"x-dead-letter-routing-key": "failed.immediate",
},
),
Queue(
name="admin_notifications",
exchange=notification_exchange,
routing_key="notification.admin.#",
arguments={
"x-dead-letter-exchange": dead_letter_exchange.name,
"x-dead-letter-routing-key": "failed.admin",
},
),
# Summary notification queues
Queue(
name="summary_notifications",
exchange=notification_exchange,
routing_key="notification.summary.#",
arguments={
"x-dead-letter-exchange": dead_letter_exchange.name,
"x-dead-letter-routing-key": "failed.summary",
},
),
# Batch Queue
Queue(
name="batch_notifications",
exchange=notification_exchange,
routing_key="notification.batch.#",
arguments={
"x-dead-letter-exchange": dead_letter_exchange.name,
"x-dead-letter-routing-key": "failed.batch",
},
),
# Failed notifications queue
Queue(
name="failed_notifications",
exchange=dead_letter_exchange,
routing_key="failed.#",
),
]
return RabbitMQConfig(
exchanges=[
notification_exchange,
dead_letter_exchange,
],
queues=queues,
)
@thread_cached
def get_scheduler():
from backend.executor import Scheduler
return get_service_client(Scheduler)
@thread_cached
def get_db():
from backend.executor.database import DatabaseManager
return get_service_client(DatabaseManager)
class NotificationManager(AppService):
"""Service for handling notifications with batching support"""
def __init__(self):
super().__init__()
self.rabbitmq_config = create_notification_config()
self.running = True
self.email_sender = EmailSender()
@classmethod
def get_port(cls) -> int:
return settings.config.notification_service_port
def get_routing_key(self, event_type: NotificationType) -> str:
strategy = NotificationTypeOverride(event_type).strategy
"""Get the appropriate routing key for an event"""
if strategy == QueueType.IMMEDIATE:
return f"notification.immediate.{event_type.value}"
elif strategy == QueueType.BACKOFF:
return f"notification.backoff.{event_type.value}"
elif strategy == QueueType.ADMIN:
return f"notification.admin.{event_type.value}"
elif strategy == QueueType.BATCH:
return f"notification.batch.{event_type.value}"
elif strategy == QueueType.SUMMARY:
return f"notification.summary.{event_type.value}"
return f"notification.{event_type.value}"
@expose
def queue_weekly_summary(self):
"""Process weekly summary for specified notification types"""
try:
logger.info("Processing weekly summary queuing operation")
processed_count = 0
current_time = datetime.now(tz=timezone.utc)
start_time = current_time - timedelta(days=7)
users = get_db().get_active_user_ids_in_timerange(
end_time=current_time.isoformat(),
start_time=start_time.isoformat(),
)
for user in users:
self._queue_scheduled_notification(
SummaryParamsEventDTO(
user_id=user,
type=NotificationType.WEEKLY_SUMMARY,
data=WeeklySummaryParams(
start_date=start_time,
end_date=current_time,
).model_dump(),
),
)
processed_count += 1
logger.info(f"Processed {processed_count} weekly summaries into queue")
except Exception as e:
logger.exception(f"Error processing weekly summary: {e}")
@expose
def process_existing_batches(self, notification_types: list[NotificationType]):
"""Process existing batches for specified notification types"""
try:
processed_count = 0
current_time = datetime.now(tz=timezone.utc)
for notification_type in notification_types:
# Get all batches for this notification type
batches = get_db().get_all_batches_by_type(notification_type)
for batch in batches:
# Check if batch has aged out
oldest_message = (
get_db().get_user_notification_oldest_message_in_batch(
batch.user_id, notification_type
)
)
if not oldest_message:
# this should never happen
logger.error(
f"Batch for user {batch.user_id} and type {notification_type} has no oldest message whichshould never happen!!!!!!!!!!!!!!!!"
)
continue
max_delay = get_batch_delay(notification_type)
# If batch has aged out, process it
if oldest_message.created_at + max_delay < current_time:
recipient_email = get_db().get_user_email_by_id(batch.user_id)
if not recipient_email:
logger.error(
f"User email not found for user {batch.user_id}"
)
continue
should_send = self._should_email_user_based_on_preference(
batch.user_id, notification_type
)
if not should_send:
logger.debug(
f"User {batch.user_id} does not want to receive {notification_type} notifications"
)
# Clear the batch
get_db().empty_user_notification_batch(
batch.user_id, notification_type
)
continue
batch_data = get_db().get_user_notification_batch(
batch.user_id, notification_type
)
if not batch_data or not batch_data.notifications:
logger.error(
f"Batch data not found for user {batch.user_id}"
)
# Clear the batch
get_db().empty_user_notification_batch(
batch.user_id, notification_type
)
continue
unsub_link = generate_unsubscribe_link(batch.user_id)
events = [
NotificationEventModel[
get_notif_data_type(db_event.type)
].model_validate(
{
"user_id": batch.user_id,
"type": db_event.type,
"data": db_event.data,
"created_at": db_event.created_at,
}
)
for db_event in batch_data.notifications
]
logger.info(f"{events=}")
self.email_sender.send_templated(
notification=notification_type,
user_email=recipient_email,
data=events,
user_unsub_link=unsub_link,
)
# Clear the batch
get_db().empty_user_notification_batch(
batch.user_id, notification_type
)
processed_count += 1
logger.info(f"Processed {processed_count} aged batches")
return {
"success": True,
"processed_count": processed_count,
"notification_types": [nt.value for nt in notification_types],
"timestamp": current_time.isoformat(),
}
except Exception as e:
logger.exception(f"Error processing batches: {e}")
return {
"success": False,
"error": str(e),
"notification_types": [nt.value for nt in notification_types],
"timestamp": datetime.now(tz=timezone.utc).isoformat(),
}
@expose
def queue_notification(self, event: NotificationEventDTO) -> NotificationResult:
"""Queue a notification - exposed method for other services to call"""
try:
logger.info(f"Received Request to queue {event=}")
# Workaround for not being able to serialize generics over the expose bus
parsed_event = NotificationEventModel[
get_notif_data_type(event.type)
].model_validate(event.model_dump())
routing_key = self.get_routing_key(parsed_event.type)
message = parsed_event.model_dump_json()
logger.info(f"Received Request to queue {message=}")
exchange = "notifications"
# Publish to RabbitMQ
self.run_and_wait(
self.rabbit.publish_message(
routing_key=routing_key,
message=message,
exchange=next(
ex for ex in self.rabbit_config.exchanges if ex.name == exchange
),
)
)
return NotificationResult(
success=True,
message=f"Notification queued with routing key: {routing_key}",
)
except Exception as e:
logger.exception(f"Error queueing notification: {e}")
return NotificationResult(success=False, message=str(e))
def _queue_scheduled_notification(self, event: SummaryParamsEventDTO):
"""Queue a scheduled notification - exposed method for other services to call"""
try:
logger.info(f"Received Request to queue scheduled notification {event=}")
parsed_event = SummaryParamsEventModel[
get_summary_params_type(event.type)
].model_validate(event.model_dump())
routing_key = self.get_routing_key(event.type)
message = parsed_event.model_dump_json()
logger.info(f"Received Request to queue {message=}")
exchange = "notifications"
# Publish to RabbitMQ
self.run_and_wait(
self.rabbit.publish_message(
routing_key=routing_key,
message=message,
exchange=next(
ex for ex in self.rabbit_config.exchanges if ex.name == exchange
),
)
)
except Exception as e:
logger.exception(f"Error queueing notification: {e}")
def _should_email_user_based_on_preference(
self, user_id: str, event_type: NotificationType
) -> bool:
"""Check if a user wants to receive a notification based on their preferences and email verification status"""
validated_email = get_db().get_user_email_verification(user_id)
preference = (
get_db()
.get_user_notification_preference(user_id)
.preferences.get(event_type, True)
)
# only if both are true, should we email this person
return validated_email and preference
def _gather_summary_data(
self, user_id: str, event_type: NotificationType, params: BaseSummaryParams
) -> BaseSummaryData:
"""Gathers the data to build a summary notification"""
logger.info(
f"Gathering summary data for {user_id} and {event_type} wiht {params=}"
)
# total_credits_used = self.run_and_wait(
# get_total_credits_used(user_id, start_time, end_time)
# )
# total_executions = self.run_and_wait(
# get_total_executions(user_id, start_time, end_time)
# )
# most_used_agent = self.run_and_wait(
# get_most_used_agent(user_id, start_time, end_time)
# )
# execution_times = self.run_and_wait(
# get_execution_time(user_id, start_time, end_time)
# )
# runs = self.run_and_wait(
# get_runs(user_id, start_time, end_time)
# )
total_credits_used = 3.0
total_executions = 2
most_used_agent = {"name": "Some"}
execution_times = [1, 2, 3]
runs = [{"status": "COMPLETED"}, {"status": "FAILED"}]
successful_runs = len([run for run in runs if run["status"] == "COMPLETED"])
failed_runs = len([run for run in runs if run["status"] != "COMPLETED"])
average_execution_time = (
sum(execution_times) / len(execution_times) if execution_times else 0
)
# cost_breakdown = self.run_and_wait(
# get_cost_breakdown(user_id, start_time, end_time)
# )
cost_breakdown = {
"agent1": 1.0,
"agent2": 2.0,
}
if event_type == NotificationType.DAILY_SUMMARY and isinstance(
params, DailySummaryParams
):
return DailySummaryData(
total_credits_used=total_credits_used,
total_executions=total_executions,
most_used_agent=most_used_agent["name"],
total_execution_time=sum(execution_times),
successful_runs=successful_runs,
failed_runs=failed_runs,
average_execution_time=average_execution_time,
cost_breakdown=cost_breakdown,
date=params.date,
)
elif event_type == NotificationType.WEEKLY_SUMMARY and isinstance(
params, WeeklySummaryParams
):
return WeeklySummaryData(
total_credits_used=total_credits_used,
total_executions=total_executions,
most_used_agent=most_used_agent["name"],
total_execution_time=sum(execution_times),
successful_runs=successful_runs,
failed_runs=failed_runs,
average_execution_time=average_execution_time,
cost_breakdown=cost_breakdown,
start_date=params.start_date,
end_date=params.end_date,
)
else:
raise ValueError("Invalid event type or params")
def _should_batch(
self, user_id: str, event_type: NotificationType, event: NotificationEventModel
) -> bool:
get_db().create_or_add_to_user_notification_batch(user_id, event_type, event)
oldest_message = get_db().get_user_notification_oldest_message_in_batch(
user_id, event_type
)
if not oldest_message:
logger.error(
f"Batch for user {user_id} and type {event_type} has no oldest message whichshould never happen!!!!!!!!!!!!!!!!"
)
return False
oldest_age = oldest_message.created_at
max_delay = get_batch_delay(event_type)
if oldest_age + max_delay < datetime.now(tz=timezone.utc):
logger.info(f"Batch for user {user_id} and type {event_type} is old enough")
return True
logger.info(
f"Batch for user {user_id} and type {event_type} is not old enough: {oldest_age + max_delay} < {datetime.now(tz=timezone.utc)} max_delay={max_delay}"
)
return False
def _parse_message(self, message: str) -> NotificationEvent | None:
try:
event = NotificationEventDTO.model_validate_json(message)
model = NotificationEventModel[
get_notif_data_type(event.type)
].model_validate_json(message)
return NotificationEvent(event=event, model=model)
except Exception as e:
logger.error(f"Error parsing message due to non matching schema {e}")
return None
def _process_admin_message(self, message: str) -> bool:
"""Process a single notification, sending to an admin, returning whether to put into the failed queue"""
try:
parsed = self._parse_message(message)
if not parsed:
return False
event = parsed.event
model = parsed.model
logger.debug(f"Processing notification for admin: {model}")
recipient_email = settings.config.refund_notification_email
self.email_sender.send_templated(event.type, recipient_email, model)
return True
except Exception as e:
logger.exception(f"Error processing notification for admin queue: {e}")
return False
def _process_immediate(self, message: str) -> bool:
"""Process a single notification immediately, returning whether to put into the failed queue"""
try:
parsed = self._parse_message(message)
if not parsed:
return False
event = parsed.event
model = parsed.model
logger.debug(f"Processing immediate notification: {model}")
recipient_email = get_db().get_user_email_by_id(event.user_id)
if not recipient_email:
logger.error(f"User email not found for user {event.user_id}")
return False
should_send = self._should_email_user_based_on_preference(
event.user_id, event.type
)
if not should_send:
logger.debug(
f"User {event.user_id} does not want to receive {event.type} notifications"
)
return True
unsub_link = generate_unsubscribe_link(event.user_id)
self.email_sender.send_templated(
notification=event.type,
user_email=recipient_email,
data=model,
user_unsub_link=unsub_link,
)
return True
except Exception as e:
logger.exception(f"Error processing notification for immediate queue: {e}")
return False
def _process_batch(self, message: str) -> bool:
"""Process a single notification with a batching strategy, returning whether to put into the failed queue"""
try:
parsed = self._parse_message(message)
if not parsed:
return False
event = parsed.event
model = parsed.model
logger.info(f"Processing batch notification: {model}")
recipient_email = get_db().get_user_email_by_id(event.user_id)
if not recipient_email:
logger.error(f"User email not found for user {event.user_id}")
return False
should_send = self._should_email_user_based_on_preference(
event.user_id, event.type
)
if not should_send:
logger.info(
f"User {event.user_id} does not want to receive {event.type} notifications"
)
return True
should_send = self._should_batch(event.user_id, event.type, model)
if not should_send:
logger.info("Batch not old enough to send")
return False
batch = get_db().get_user_notification_batch(event.user_id, event.type)
if not batch or not batch.notifications:
logger.error(f"Batch not found for user {event.user_id}")
return False
unsub_link = generate_unsubscribe_link(event.user_id)
batch_messages = [
NotificationEventModel[
get_notif_data_type(db_event.type)
].model_validate(
{
"user_id": event.user_id,
"type": db_event.type,
"data": db_event.data,
"created_at": db_event.created_at,
}
)
for db_event in batch.notifications
]
self.email_sender.send_templated(
notification=event.type,
user_email=recipient_email,
data=batch_messages,
user_unsub_link=unsub_link,
)
# only empty the batch if we sent the email successfully
get_db().empty_user_notification_batch(event.user_id, event.type)
return True
except Exception as e:
logger.exception(f"Error processing notification for batch queue: {e}")
return False
def _process_summary(self, message: str) -> bool:
"""Process a single notification with a summary strategy, returning whether to put into the failed queue"""
try:
logger.info(f"Processing summary notification: {message}")
event = SummaryParamsEventDTO.model_validate_json(message)
model = SummaryParamsEventModel[
get_summary_params_type(event.type)
].model_validate_json(message)
logger.info(f"Processing summary notification: {model}")
recipient_email = get_db().get_user_email_by_id(event.user_id)
if not recipient_email:
logger.error(f"User email not found for user {event.user_id}")
return False
should_send = self._should_email_user_based_on_preference(
event.user_id, event.type
)
if not should_send:
logger.info(
f"User {event.user_id} does not want to receive {event.type} notifications"
)
return True
summary_data = self._gather_summary_data(
event.user_id, event.type, model.data
)
unsub_link = generate_unsubscribe_link(event.user_id)
data = NotificationEventModel(
user_id=event.user_id,
type=event.type,
data=summary_data,
)
self.email_sender.send_templated(
notification=event.type,
user_email=recipient_email,
data=data,
user_unsub_link=unsub_link,
)
return True
except Exception as e:
logger.exception(f"Error processing notification for summary queue: {e}")
return False
def _run_queue(
self,
queue: aio_pika.abc.AbstractQueue,
process_func: Callable[[str], bool],
error_queue_name: str,
):
message: aio_pika.abc.AbstractMessage | None = None
try:
# This parameter "no_ack" is named like shit, think of it as "auto_ack"
message = self.run_and_wait(queue.get(timeout=1.0, no_ack=False))
result = process_func(message.body.decode())
if result:
self.run_and_wait(message.ack())
else:
self.run_and_wait(message.reject(requeue=False))
except QueueEmpty:
logger.debug(f"Queue {error_queue_name} empty")
except Exception as e:
if message:
logger.error(
f"Error in notification service loop, message rejected {e}"
)
self.run_and_wait(message.reject(requeue=False))
else:
logger.error(
f"Error in notification service loop, message unable to be rejected, and will have to be manually removed to free space in the queue: {e}"
)
def run_service(self):
logger.info(f"[{self.service_name}] Started notification service")
# Set up scheduler for batch processing of all notification types
# this can be changed later to spawn differnt cleanups on different schedules
try:
get_scheduler().add_batched_notification_schedule(
notification_types=list(NotificationType),
data={},
cron="0 * * * *",
)
# get_scheduler().add_weekly_notification_schedule(
# # weekly on Friday at 12pm
# cron="0 12 * * 5",
# )
logger.info("Scheduled notification cleanup")
except Exception as e:
logger.error(f"Error scheduling notification cleanup: {e}")
# Set up queue consumers
channel = self.run_and_wait(self.rabbit.get_channel())
immediate_queue = self.run_and_wait(
channel.get_queue("immediate_notifications")
)
batch_queue = self.run_and_wait(channel.get_queue("batch_notifications"))
admin_queue = self.run_and_wait(channel.get_queue("admin_notifications"))
summary_queue = self.run_and_wait(channel.get_queue("summary_notifications"))
while self.running:
try:
self._run_queue(
queue=immediate_queue,
process_func=self._process_immediate,
error_queue_name="immediate_notifications",
)
self._run_queue(
queue=admin_queue,
process_func=self._process_admin_message,
error_queue_name="admin_notifications",
)
self._run_queue(
queue=batch_queue,
process_func=self._process_batch,
error_queue_name="batch_notifications",
)
self._run_queue(
queue=summary_queue,
process_func=self._process_summary,
error_queue_name="summary_notifications",
)
time.sleep(0.1)
except QueueEmpty as e:
logger.debug(f"Queue empty: {e}")
except Exception as e:
logger.error(f"Error in notification service loop: {e}")
def cleanup(self):
"""Cleanup service resources"""
self.running = False
super().cleanup()

View File

@@ -1,142 +0,0 @@
{# Agent Run #}
{# Template variables:
notification.data: the stuff below but a list of them
data.agent_name: the name of the agent
data.credits_used: the number of credits used by the agent
data.node_count: the number of nodes the agent ran on
data.execution_time: the time it took to run the agent
data.graph_id: the id of the graph the agent ran on
data.outputs: the list of outputs of the agent
#}
{% if notifications is defined %}
{# BATCH MODE #}
<div style="font-family: 'Poppins', sans-serif; color: #070629;">
<h2 style="color: #5D23BB; margin-bottom: 15px;">Agent Run Summary</h2>
<p style="font-size: 16px; line-height: 165%; margin-top: 0; margin-bottom: 15px;">
<strong>{{ notifications|length }}</strong> agent runs have completed!
</p>
{# Calculate summary stats #}
{% set total_time = 0 %}
{% set total_nodes = 0 %}
{% set total_credits = 0 %}
{% set agent_names = [] %}
{% for notification in notifications %}
{% set total_time = total_time + notification.data.execution_time %}
{% set total_nodes = total_nodes + notification.data.node_count %}
{% set total_credits = total_credits + notification.data.credits_used %}
{% if notification.data.agent_name not in agent_names %}
{% set agent_names = agent_names + [notification.data.agent_name] %}
{% endif %}
{% endfor %}
<div style="background-color: #f8f7ff; border-radius: 8px; padding: 15px; margin-bottom: 25px;">
<h3 style="margin-top: 0; margin-bottom: 10px; color: #5D23BB;">Summary</h3>
<p style="margin: 5px 0;"><strong>Agents:</strong> {{ agent_names|join(", ") }}</p>
<p style="margin: 5px 0;"><strong>Total Time:</strong> {{ total_time | int }} seconds</p>
<p style="margin: 5px 0;"><strong>Total Nodes:</strong> {{ total_nodes }}</p>
<p style="margin: 5px 0;"><strong>Total Cost:</strong> ${{ "{:.2f}".format((total_credits|float)/100) }}</p>
</div>
<h3 style="margin-top: 25px; margin-bottom: 15px; color: #5D23BB;">Individual Runs</h3>
{% for notification in notifications %}
<div style="margin-bottom: 30px; border-left: 3px solid #5D23BB; padding-left: 15px;">
<p style="font-size: 16px; font-weight: 600; margin-top: 0; margin-bottom: 10px;">
Agent: <strong>{{ notification.data.agent_name }}</strong>
</p>
<div style="margin-left: 10px;">
<p style="margin: 5px 0;"><strong>Time:</strong> {{ notification.data.execution_time | int }} seconds</p>
<p style="margin: 5px 0;"><strong>Nodes:</strong> {{ notification.data.node_count }}</p>
<p style="margin: 5px 0;"><strong>Cost:</strong> ${{ "{:.2f}".format((notification.data.credits_used|float)/100) }}</p>
</div>
{% if notification.data.outputs and notification.data.outputs|length > 0 %}
<div style="margin-left: 10px; margin-top: 15px;">
<p style="font-weight: 600; margin-bottom: 10px;">Results:</p>
{% for output in notification.data.outputs %}
<div style="margin-left: 10px; margin-bottom: 12px;">
<p style="color: #5D23BB; font-weight: 500; margin-top: 0; margin-bottom: 5px;">
{{ output.name }}
</p>
{% for key, value in output.items() %}
{% if key != 'name' %}
<div style="margin-left: 10px; background-color: #f5f5ff; padding: 8px 12px; border-radius: 4px;
font-family: 'Roboto Mono', monospace; white-space: pre-wrap; word-break: break-word;
overflow-wrap: break-word; max-width: 100%; overflow-x: auto; margin-top: 3px;
margin-bottom: 8px; line-height: 1.4;">
{% if value is iterable and value is not string %}
{% if value|length == 1 %}
{{ value[0] }}
{% else %}
[{% for item in value %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}]
{% endif %}
{% else %}
{{ value }}
{% endif %}
</div>
{% endif %}
{% endfor %}
</div>
{% endfor %}
</div>
{% endif %}
</div>
{% endfor %}
</div>
{% else %}
{# SINGLE NOTIFICATION MODE - Original template #}
<p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%;
margin-top: 0; margin-bottom: 10px;">
Your agent, <strong>{{ data.agent_name }}</strong>, has completed its run!
</p>
<p style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%;
margin-top: 0; margin-bottom: 20px; padding-left: 20px;">
<p style="margin-bottom: 10px;"><strong>Time Taken:</strong> {{ data.execution_time | int }} seconds</p>
<p style="margin-bottom: 10px;"><strong>Nodes Used:</strong> {{ data.node_count }}</p>
<p style="margin-bottom: 10px;"><strong>Cost:</strong> ${{ "{:.2f}".format((data.credits_used|float)/100) }}</p>
</p>
{% if data.outputs and data.outputs|length > 0 %}
<div style="margin-left: 15px; margin-bottom: 20px;">
<p style="font-family: 'Poppins', sans-serif; color: #070629; font-weight: 600;
font-size: 16px; margin-bottom: 10px;">
Results:
</p>
{% for output in data.outputs %}
<div style="margin-left: 15px; margin-bottom: 15px;">
<p style="font-family: 'Poppins', sans-serif; color: #5D23BB; font-weight: 500;
font-size: 16px; margin-top: 0; margin-bottom: 8px;">
{{ output.name }}
</p>
{% for key, value in output.items() %}
{% if key != 'name' %}
<div style="margin-left: 15px; background-color: #f5f5ff; padding: 8px 12px; border-radius: 4px;
font-family: 'Roboto Mono', monospace; white-space: pre-wrap; word-break: break-word;
overflow-wrap: break-word; max-width: 100%; overflow-x: auto; margin-top: 5px;
margin-bottom: 10px; line-height: 1.4;">
{% if value is iterable and value is not string %}
{% if value|length == 1 %}
{{ value[0] }}
{% else %}
[{% for item in value %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}]
{% endif %}
{% else %}
{{ value }}
{% endif %}
</div>
{% endif %}
{% endfor %}
</div>
{% endfor %}
</div>
{% endif %}
{% endif %}

View File

@@ -1,343 +0,0 @@
{# Base Template #}
{# Template variables:
data.message: the message to display in the email
data.title: the title of the email
data.unsubscribe_link: the link to unsubscribe from the email
#}
<!doctype html>
<html lang="ltr" xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office">
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=yes">
<meta name="format-detection" content="telephone=no, date=no, address=no, email=no, url=no">
<meta name="x-apple-disable-message-reformatting">
<!--[if !mso]>
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<![endif]-->
<!--[if mso]>
<style>
* { font-family: sans-serif !important; }
</style>
<noscript>
<xml>
<o:OfficeDocumentSettings>
<o:PixelsPerInch>96</o:PixelsPerInch>
</o:OfficeDocumentSettings>
</xml>
</noscript>
<![endif]-->
<style type="text/css">
/* RESET STYLES */
html,
body {
margin: 0 !important;
padding: 0 !important;
width: 100% !important;
height: 100% !important;
}
body {
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
text-rendering: optimizeLegibility;
}
.document {
margin: 0 !important;
padding: 0 !important;
width: 100% !important;
}
img {
border: 0;
outline: none;
text-decoration: none;
-ms-interpolation-mode: bicubic;
}
table {
border-collapse: collapse;
}
table,
td {
mso-table-lspace: 0pt;
mso-table-rspace: 0pt;
}
body,
table,
td,
a {
-webkit-text-size-adjust: 100%;
-ms-text-size-adjust: 100%;
}
h1,
h2,
h3,
h4,
h5,
p {
margin: 0;
word-break: break-word;
}
/* iOS BLUE LINKS */
a[x-apple-data-detectors] {
color: inherit !important;
text-decoration: none !important;
font-size: inherit !important;
font-family: inherit !important;
font-weight: inherit !important;
line-height: inherit !important;
}
/* ANDROID CENTER FIX */
div[style*="margin: 16px 0;"] {
margin: 0 !important;
}
/* MEDIA QUERIES */
@media all and (max-width:639px) {
.wrapper {
width: 100% !important;
}
.container {
width: 100% !important;
min-width: 100% !important;
padding: 0 !important;
}
.row {
padding-left: 20px !important;
padding-right: 20px !important;
}
.col-mobile {
width: 20px !important;
}
.col {
display: block !important;
width: 100% !important;
}
.mobile-center {
text-align: center !important;
float: none !important;
}
.mobile-mx-auto {
margin: 0 auto !important;
float: none !important;
}
.mobile-left {
text-align: center !important;
float: left !important;
}
.mobile-hide {
display: none !important;
}
.img {
width: 100% !important;
height: auto !important;
}
.ml-btn {
width: 100% !important;
max-width: 100% !important;
}
.ml-btn-container {
width: 100% !important;
max-width: 100% !important;
}
}
</style>
<style type="text/css">
@import url("https://assets.mlcdn.com/fonts-v2.css?version=1729862");
</style>
<style type="text/css">
@media screen {
body {
font-family: 'Poppins', sans-serif;
}
}
</style>
<title>{{data.title}}</title>
</head>
<body style="margin: 0 !important; padding: 0 !important; background-color:#070629;">
<div class="document" role="article" aria-roledescription="email" aria-label lang dir="ltr"
style="background-color:#070629; line-height: 100%; font-size:medium; font-size:max(16px, 1rem);">
<!-- Main Content -->
<table width="100%" align="center" cellspacing="0" cellpadding="0" border="0">
<tr>
<td class="background" bgcolor="#070629" align="center" valign="top" style="padding: 0 8px;">
<!-- Email Content -->
<table class="container" align="center" width="640" cellpadding="0" cellspacing="0" border="0"
style="max-width: 640px;">
<tr>
<td align="center">
<!-- Logo Section -->
<table class="container ml-4 ml-default-border" width="640" bgcolor="#E2ECFD" align="center" border="0"
cellspacing="0" cellpadding="0" style="width: 640px; min-width: 640px;">
<tr>
<td class="ml-default-border container" height="40" style="line-height: 40px; min-width: 640px;">
</td>
</tr>
<tr>
<td>
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
<tr>
<td class="row" align="center" style="padding: 0 50px;">
<img
src="https://storage.mlcdn.com/account_image/597379/8QJ8kOjXakVvfe1kJLY2wWCObU1mp5EiDLfBlbQa.png"
border="0" alt="" width="120" class="logo"
style="max-width: 120px; display: inline-block;">
</td>
</tr>
</table>
</td>
</tr>
</table>
<!-- Main Content Section -->
<table class="container ml-6 ml-default-border" width="640" bgcolor="#E2ECFD" align="center" border="0"
cellspacing="0" cellpadding="0" style="color: #070629; width: 640px; min-width: 640px;">
<tr>
<td class="row" style="padding: 0 50px;">
{{data.message|safe}}
</td>
</tr>
</table>
<!-- Signature Section -->
<table class="container ml-8 ml-default-border" width="640" bgcolor="#E2ECFD" align="center" border="0"
cellspacing="0" cellpadding="0" style="color: #070629; width: 640px; min-width: 640px;">
<tr>
<td class="row mobile-center" align="left" style="padding: 0 50px;">
<table class="ml-8 wrapper" border="0" cellspacing="0" cellpadding="0"
style="color: #070629; text-align: left;">
<tr>
<td class="col center mobile-center" align>
<p
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-top: 0; margin-bottom: 0;">
Thank you for being a part of the AutoGPT community! Join the conversation on our Discord <a href="https://discord.gg/autogpt" style="color: #4285F4; text-decoration: underline;">here</a> and share your thoughts with us anytime.
</p>
</td>
</tr>
</table>
</td>
</tr>
</table>
<!-- Footer Section -->
<table class="container ml-10 ml-default-border" width="640" bgcolor="#ffffff" align="center" border="0"
cellspacing="0" cellpadding="0" style="width: 640px; min-width: 640px;">
<tr>
<td class="row" style="padding: 0 50px;">
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
<tr>
<td>
<!-- Footer Content -->
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
<tr>
<td class="col" align="left" valign="middle" width="120">
<img
src="https://storage.mlcdn.com/account_image/597379/8QJ8kOjXakVvfe1kJLY2wWCObU1mp5EiDLfBlbQa.png"
border="0" alt="" width="120" class="logo"
style="max-width: 120px; display: inline-block;">
</td>
<td class="col" width="40" height="30" style="line-height: 30px;"></td>
<td class="col mobile-left" align="right" valign="middle" width="250">
<table role="presentation" cellpadding="0" cellspacing="0" border="0">
<tr>
<td align="center" valign="middle" width="18" style="padding: 0 5px 0 0;">
<a href="https://x.com/auto_gpt" target="blank" style="text-decoration: none;">
<img
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/black/x.png"
width="18" alt="x">
</a>
</td>
<td align="center" valign="middle" width="18" style="padding: 0 5px;">
<a href="https://discord.gg/autogpt" target="blank"
style="text-decoration: none;">
<img
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/black/discord.png"
width="18" alt="discord">
</a>
</td>
<td align="center" valign="middle" width="18" style="padding: 0 0 0 5px;">
<a href="https://agpt.co/" target="blank" style="text-decoration: none;">
<img
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/black/website.png"
width="18" alt="website">
</a>
</td>
</tr>
</table>
</td>
</tr>
</table>
</td>
</tr>
<tr>
<td align="center" style="text-align: left!important;">
<h5
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 15px; line-height: 125%; font-weight: bold; font-style: normal; text-decoration: none; margin-bottom: 6px;">
AutoGPT
</h5>
</td>
</tr>
<tr>
<td align="center" style="text-align: left!important;">
<p
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 14px; line-height: 150%; display: inline-block; margin-bottom: 0;">
3rd Floor 1 Ashley Road, Cheshire, United Kingdom, WA14 2DT, Altrincham<br>United Kingdom
</p>
</td>
</tr>
<tr>
<td height="8" style="line-height: 8px;"></td>
</tr>
<tr>
<td align="left" style="text-align: left!important;">
<p
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 14px; line-height: 150%; display: inline-block; margin-bottom: 0;">
You received this email because you signed up on our website.</p>
</td>
</tr>
<tr>
<td height="1" style="line-height: 12px;"></td>
</tr>
<tr>
<td align="left">
<p
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 14px; line-height: 150%; display: inline-block; margin-bottom: 0;">
<a href="{{data.unsubscribe_link}}"
style="color: #4285F4; font-weight: normal; font-style: normal; text-decoration: underline;">Unsubscribe</a>
</p>
</td>
</tr>
</table>
</td>
</tr>
</table>
</td>
</tr>
</table>
</td>
</tr>
</table>
</div>
</body>
</html>

View File

@@ -1,114 +0,0 @@
{# Low Balance Notification Email Template #}
{# Template variables:
data.agent_name: the name of the agent
data.current_balance: the current balance of the user
data.billing_page_link: the link to the billing page
data.shortfall: the shortfall amount
#}
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
line-height: 165%;
margin-top: 0;
margin-bottom: 10px;
">
<strong>Low Balance Warning</strong>
</p>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
line-height: 165%;
margin-top: 0;
margin-bottom: 20px;
">
Your agent "<strong>{{ data.agent_name }}</strong>" has been stopped due to low balance.
</p>
<div style="
margin-left: 15px;
margin-bottom: 20px;
padding: 15px;
border-left: 4px solid #5D23BB;
background-color: #f8f8ff;
">
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
margin-top: 0;
margin-bottom: 10px;
">
<strong>Current Balance:</strong> ${{ "{:.2f}".format((data.current_balance|float)/100) }}
</p>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
margin-top: 0;
margin-bottom: 10px;
">
<strong>Shortfall:</strong> ${{ "{:.2f}".format((data.shortfall|float)/100) }}
</p>
</div>
<div style="
margin-left: 15px;
margin-bottom: 20px;
padding: 15px;
border-left: 4px solid #FF6B6B;
background-color: #FFF0F0;
">
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
margin-top: 0;
margin-bottom: 10px;
">
<strong>Low Balance:</strong>
</p>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
margin-top: 0;
margin-bottom: 5px;
">
Your agent "<strong>{{ data.agent_name }}</strong>" requires additional credits to continue running. The current operation has been canceled until your balance is replenished.
</p>
</div>
<div style="
text-align: center;
margin: 30px 0;
">
<a href="{{ data.billing_page_link }}" style="
font-family: 'Poppins', sans-serif;
background-color: #5D23BB;
color: white;
padding: 12px 24px;
text-decoration: none;
border-radius: 4px;
font-weight: 500;
display: inline-block;
">
Manage Billing
</a>
</div>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
line-height: 150%;
margin-top: 30px;
margin-bottom: 10px;
font-style: italic;
">
This is an automated notification. Your agent is stopped and will need manually restarted unless set to trigger automatically.
</p>

View File

@@ -1,51 +0,0 @@
{# Refund Processed Notification Email Template #}
{#
Template variables:
data.user_id: the ID of the user
data.user_name: the user's name
data.user_email: the user's email address
data.transaction_id: the transaction ID for the refund request
data.refund_request_id: the refund request ID
data.reason: the reason for the refund request
data.amount: the refund amount in cents (divide by 100 for dollars)
data.balance: the user's latest balance in cents (after the refund deduction)
Subject: Refund for ${{ data.amount / 100 }} to {{ data.user_name }} has been processed
#}
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Refund Processed Notification</title>
</head>
<body style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 1.65; margin: 0; padding: 20px;">
<p style="margin-bottom: 10px;">Hello Administrator,</p>
<p style="margin-bottom: 10px;">
This is to notify you that the refund for <strong>${{ data.amount / 100 }}</strong> to <strong>{{ data.user_name }}</strong> has been processed successfully.
</p>
<h2 style="margin-bottom: 10px;">Refund Details</h2>
<ul style="margin-bottom: 10px;">
<li><strong>User ID:</strong> {{ data.user_id }}</li>
<li><strong>User Name:</strong> {{ data.user_name }}</li>
<li><strong>User Email:</strong> {{ data.user_email }}</li>
<li><strong>Transaction ID:</strong> {{ data.transaction_id }}</li>
<li><strong>Refund Request ID:</strong> {{ data.refund_request_id }}</li>
<li><strong>Refund Amount:</strong> ${{ data.amount / 100 }}</li>
<li><strong>Reason for Refund:</strong> {{ data.reason }}</li>
<li><strong>Latest User Balance:</strong> ${{ data.balance / 100 }}</li>
</ul>
<p style="margin-bottom: 10px;">
The user's balance has been updated accordingly after the deduction.
</p>
<p style="margin-bottom: 10px;">
Please contact the support team if you have any questions or need further assistance regarding this refund.
</p>
<p style="margin-bottom: 0;">Best regards,<br>Your Notification System</p>
</body>
</html>

View File

@@ -1,72 +0,0 @@
{# Refund Request Email Template #}
{#
Template variables:
data.user_id: the ID of the user
data.user_name: the user's name
data.user_email: the user's email address
data.transaction_id: the transaction ID for the refund request
data.refund_request_id: the refund request ID
data.reason: the reason for the refund request
data.amount: the refund amount in cents (divide by 100 for dollars)
data.balance: the user's balance in cents (divide by 100 for dollars)
Subject: [ACTION REQUIRED] You got a ${{ data.amount / 100 }} refund request from {{ data.user_name }}
#}
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Refund Request Approval Needed</title>
</head>
<body style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 1.65; margin: 0; padding: 20px;">
<p style="margin-bottom: 10px;">Hello Administrator,</p>
<p style="margin-bottom: 10px;">
A refund request has been submitted by a user and requires your approval.
</p>
<h2 style="margin-bottom: 10px;">Refund Request Details</h2>
<ul style="margin-bottom: 10px;">
<li><strong>User ID:</strong> {{ data.user_id }}</li>
<li><strong>User Name:</strong> {{ data.user_name }}</li>
<li><strong>User Email:</strong> {{ data.user_email }}</li>
<li><strong>Transaction ID:</strong> {{ data.transaction_id }}</li>
<li><strong>Refund Request ID:</strong> {{ data.refund_request_id }}</li>
<li><strong>Refund Amount:</strong> ${{ data.amount / 100 }}</li>
<li><strong>User Balance:</strong> ${{ data.balance / 100 }}</li>
<li><strong>Reason for Refund:</strong> {{ data.reason }}</li>
</ul>
<p style="margin-bottom: 10px;">
To approve this refund, please click on the following Stripe link:
https://dashboard.stripe.com/test/payments/{{data.transaction_id}}
<br/>
And then click on the "Refund" button.
</p>
<p style="margin-bottom: 10px;">
To reject this refund, please follow these steps:
</p>
<ol style="margin-bottom: 10px;">
<li>
Visit the Supabase Dashboard:
https://supabase.com/dashboard/project/bgwpwdsxblryihinutbx/editor
</li>
<li>
Navigate to the <strong>RefundRequest</strong> table.
</li>
<li>
Filter the <code>transactionKey</code> column with the Transaction ID: <strong>{{ data.transaction_id }}</strong>.
</li>
<li>
Update the <code>status</code> field to <strong>REJECTED</strong> and enter the rejection reason in the <code>result</code> column.
</li>
</ol>
<p style="margin-bottom: 10px;">
Please take the necessary action at your earliest convenience.
</p>
<p style="margin-bottom: 10px;">Thank you for your prompt attention.</p>
<p style="margin-bottom: 0;">Best regards,<br>Your Notification System</p>
</body>
</html>

View File

@@ -1,27 +0,0 @@
{# Weekly Summary #}
{# Template variables:
data: the stuff below
data.start_date: the start date of the summary
data.end_date: the end date of the summary
data.total_credits_used: the total credits used during the summary
data.total_executions: the total number of executions during the summary
data.most_used_agent: the most used agent's nameduring the summary
data.total_execution_time: the total execution time during the summary
data.successful_runs: the total number of successful runs during the summary
data.failed_runs: the total number of failed runs during the summary
data.average_execution_time: the average execution time during the summary
data.cost_breakdown: the cost breakdown during the summary
#}
<h1>Weekly Summary</h1>
<p>Start Date: {{ data.start_date }}</p>
<p>End Date: {{ data.end_date }}</p>
<p>Total Credits Used: {{ data.total_credits_used }}</p>
<p>Total Executions: {{ data.total_executions }}</p>
<p>Most Used Agent: {{ data.most_used_agent }}</p>
<p>Total Execution Time: {{ data.total_execution_time }}</p>
<p>Successful Runs: {{ data.successful_runs }}</p>
<p>Failed Runs: {{ data.failed_runs }}</p>
<p>Average Execution Time: {{ data.average_execution_time }}</p>
<p>Cost Breakdown: {{ data.cost_breakdown }}</p>

View File

@@ -1,6 +1,5 @@
from backend.app import run_processes
from backend.executor import DatabaseManager, Scheduler
from backend.notifications.notifications import NotificationManager
from backend.executor import DatabaseManager, ExecutionScheduler
from backend.server.rest_api import AgentServer
@@ -9,9 +8,8 @@ def main():
Run all the processes required for the AutoGPT-server REST API.
"""
run_processes(
NotificationManager(),
DatabaseManager(),
Scheduler(),
ExecutionScheduler(),
AgentServer(),
)

View File

@@ -2,17 +2,8 @@ from typing import Dict, Set
from fastapi import WebSocket
from backend.data.execution import (
ExecutionEventType,
GraphExecutionEvent,
NodeExecutionEvent,
)
from backend.server.model import WSMessage, WSMethod
_EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
ExecutionEventType.GRAPH_EXEC_UPDATE: WSMethod.GRAPH_EXECUTION_EVENT,
ExecutionEventType.NODE_EXEC_UPDATE: WSMethod.NODE_EXECUTION_EVENT,
}
from backend.data import execution
from backend.server.model import Methods, WsMessage
class ConnectionManager:
@@ -20,96 +11,37 @@ class ConnectionManager:
self.active_connections: Set[WebSocket] = set()
self.subscriptions: Dict[str, Set[WebSocket]] = {}
async def connect_socket(self, websocket: WebSocket):
async def connect(self, websocket: WebSocket):
await websocket.accept()
self.active_connections.add(websocket)
def disconnect_socket(self, websocket: WebSocket):
def disconnect(self, websocket: WebSocket):
self.active_connections.remove(websocket)
for subscribers in self.subscriptions.values():
subscribers.discard(websocket)
async def subscribe_graph_exec(
self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
) -> str:
return await self._subscribe(
_graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id), websocket
)
async def subscribe(self, graph_id: str, graph_version: int, websocket: WebSocket):
key = f"{graph_id}_{graph_version}"
if key not in self.subscriptions:
self.subscriptions[key] = set()
self.subscriptions[key].add(websocket)
async def subscribe_graph_execs(
self, *, user_id: str, graph_id: str, websocket: WebSocket
) -> str:
return await self._subscribe(
_graph_execs_channel_key(user_id, graph_id=graph_id), websocket
)
async def unsubscribe(
self, graph_id: str, graph_version: int, websocket: WebSocket
):
key = f"{graph_id}_{graph_version}"
if key in self.subscriptions:
self.subscriptions[key].discard(websocket)
if not self.subscriptions[key]:
del self.subscriptions[key]
async def unsubscribe_graph_exec(
self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
) -> str | None:
return await self._unsubscribe(
_graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id), websocket
)
async def unsubscribe_graph_execs(
self, *, user_id: str, graph_id: str, websocket: WebSocket
) -> str | None:
return await self._unsubscribe(
_graph_execs_channel_key(user_id, graph_id=graph_id), websocket
)
async def send_execution_update(
self, exec_event: GraphExecutionEvent | NodeExecutionEvent
) -> int:
graph_exec_id = (
exec_event.id
if isinstance(exec_event, GraphExecutionEvent)
else exec_event.graph_exec_id
)
n_sent = 0
channels: set[str] = {
# Send update to listeners for this graph execution
_graph_exec_channel_key(exec_event.user_id, graph_exec_id=graph_exec_id)
}
if isinstance(exec_event, GraphExecutionEvent):
# Send update to listeners for all executions of this graph
channels.add(
_graph_execs_channel_key(
exec_event.user_id, graph_id=exec_event.graph_id
)
)
for channel in channels.intersection(self.subscriptions.keys()):
message = WSMessage(
method=_EVENT_TYPE_TO_METHOD_MAP[exec_event.event_type],
channel=channel,
data=exec_event.model_dump(),
async def send_execution_result(self, result: execution.ExecutionResult):
key = f"{result.graph_id}_{result.graph_version}"
if key in self.subscriptions:
message = WsMessage(
method=Methods.EXECUTION_EVENT,
channel=key,
data=result.model_dump(),
).model_dump_json()
for connection in self.subscriptions[channel]:
for connection in self.subscriptions[key]:
await connection.send_text(message)
n_sent += 1
return n_sent
async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str:
if channel_key not in self.subscriptions:
self.subscriptions[channel_key] = set()
self.subscriptions[channel_key].add(websocket)
return channel_key
async def _unsubscribe(self, channel_key: str, websocket: WebSocket) -> str | None:
if channel_key in self.subscriptions:
self.subscriptions[channel_key].discard(websocket)
if not self.subscriptions[channel_key]:
del self.subscriptions[channel_key]
return channel_key
return None
def _graph_exec_channel_key(user_id: str, *, graph_exec_id: str) -> str:
return f"{user_id}|graph_exec#{graph_exec_id}"
def _graph_execs_channel_key(user_id: str, *, graph_id: str) -> str:
return f"{user_id}|graph#{graph_id}|executions"

View File

@@ -12,7 +12,7 @@ from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.api_key import APIKey
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.execution import NodeExecutionResult
from backend.data.execution import ExecutionResult
from backend.executor import ExecutionManager
from backend.server.external.middleware import require_permission
from backend.util.service import get_service_client
@@ -53,7 +53,7 @@ class GraphExecutionResult(TypedDict):
output: Optional[List[Dict[str, str]]]
def get_outputs_with_names(results: list[NodeExecutionResult]) -> list[dict[str, str]]:
def get_outputs_with_names(results: List[ExecutionResult]) -> List[Dict[str, str]]:
outputs = []
for result in results:
if "output" in result.output_data:
@@ -71,7 +71,7 @@ def get_outputs_with_names(results: list[NodeExecutionResult]) -> list[dict[str,
)
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
blocks = [block() for block in backend.data.block.get_blocks().values()]
return [b.to_dict() for b in blocks if not b.disabled]
return [b.to_dict() for b in blocks]
@v1_router.post(
@@ -130,7 +130,7 @@ async def get_graph_execution_results(
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
results = await execution_db.get_node_execution_results(graph_exec_id)
results = await execution_db.get_execution_results(graph_exec_id)
last_result = results[-1] if results else None
execution_status = (
last_result.status if last_result else AgentExecutionStatus.INCOMPLETE

View File

@@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Annotated, Literal
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
from pydantic import BaseModel, Field
from starlette.status import HTTP_404_NOT_FOUND
from backend.data.graph import set_node_webhook
from backend.data.integrations import (
@@ -18,8 +17,8 @@ from backend.executor.manager import ExecutionManager
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
from backend.util.exceptions import NeedConfirmation, NotFoundError
from backend.integrations.webhooks import WEBHOOK_MANAGERS_BY_NAME
from backend.util.exceptions import NeedConfirmation
from backend.util.service import get_service_client
from backend.util.settings import Settings
@@ -282,14 +281,8 @@ async def webhook_ingress_generic(
webhook_id: Annotated[str, Path(title="Our ID for the webhook")],
):
logger.debug(f"Received {provider.value} webhook ingress for ID {webhook_id}")
webhook_manager = get_webhook_manager(provider)
try:
webhook = await get_webhook(webhook_id)
except NotFoundError as e:
logger.warning(f"Webhook payload received for unknown webhook: {e}")
raise HTTPException(
status_code=HTTP_404_NOT_FOUND, detail=f"Webhook #{webhook_id} not found"
) from e
webhook_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
webhook = await get_webhook(webhook_id)
logger.debug(f"Webhook #{webhook_id}: {webhook}")
payload, event_type = await webhook_manager.validate_payload(webhook, request)
logger.debug(
@@ -330,7 +323,7 @@ async def webhook_ping(
user_id: Annotated[str, Depends(get_user_id)], # require auth
):
webhook = await get_webhook(webhook_id)
webhook_manager = get_webhook_manager(webhook.provider)
webhook_manager = WEBHOOK_MANAGERS_BY_NAME[webhook.provider]()
credentials = (
creds_manager.get(user_id, webhook.credentials_id)
@@ -365,6 +358,14 @@ async def remove_all_webhooks_for_credentials(
NeedConfirmation: If any of the webhooks are still in use and `force` is `False`
"""
webhooks = await get_all_webhooks_by_creds(credentials.id)
if credentials.provider not in WEBHOOK_MANAGERS_BY_NAME:
if webhooks:
logger.error(
f"Credentials #{credentials.id} for provider {credentials.provider} "
f"are attached to {len(webhooks)} webhooks, "
f"but there is no available WebhooksHandler for {credentials.provider}"
)
return
if any(w.attached_nodes for w in webhooks) and not force:
raise NeedConfirmation(
"Some webhooks linked to these credentials are still in use by an agent"
@@ -375,7 +376,7 @@ async def remove_all_webhooks_for_credentials(
await set_node_webhook(node.id, None)
# Prune the webhook
webhook_manager = get_webhook_manager(ProviderName(credentials.provider))
webhook_manager = WEBHOOK_MANAGERS_BY_NAME[credentials.provider]()
success = await webhook_manager.prune_webhook_if_dangling(
webhook.id, credentials
)

View File

@@ -1,36 +1,31 @@
import enum
from typing import Any, Optional
from typing import Any, List, Optional, Union
import pydantic
import backend.data.graph
from backend.data.api_key import APIKeyPermission, APIKeyWithoutHash
from backend.data.graph import Graph
class WSMethod(enum.Enum):
SUBSCRIBE_GRAPH_EXEC = "subscribe_graph_execution"
SUBSCRIBE_GRAPH_EXECS = "subscribe_graph_executions"
class Methods(enum.Enum):
SUBSCRIBE = "subscribe"
UNSUBSCRIBE = "unsubscribe"
GRAPH_EXECUTION_EVENT = "graph_execution_event"
NODE_EXECUTION_EVENT = "node_execution_event"
EXECUTION_EVENT = "execution_event"
ERROR = "error"
HEARTBEAT = "heartbeat"
class WSMessage(pydantic.BaseModel):
method: WSMethod
data: Optional[dict[str, Any] | list[Any] | str] = None
class WsMessage(pydantic.BaseModel):
method: Methods
data: Optional[Union[dict[str, Any], list[Any], str]] = None
success: bool | None = None
channel: str | None = None
error: str | None = None
class WSSubscribeGraphExecutionRequest(pydantic.BaseModel):
graph_exec_id: str
class WSSubscribeGraphExecutionsRequest(pydantic.BaseModel):
class ExecutionSubscription(pydantic.BaseModel):
graph_id: str
graph_version: int
class ExecuteGraphResponse(pydantic.BaseModel):
@@ -38,12 +33,12 @@ class ExecuteGraphResponse(pydantic.BaseModel):
class CreateGraph(pydantic.BaseModel):
graph: Graph
graph: backend.data.graph.Graph
class CreateAPIKeyRequest(pydantic.BaseModel):
name: str
permissions: list[APIKeyPermission]
permissions: List[APIKeyPermission]
description: Optional[str] = None
@@ -57,7 +52,7 @@ class SetGraphActiveVersion(pydantic.BaseModel):
class UpdatePermissionsRequest(pydantic.BaseModel):
permissions: list[APIKeyPermission]
permissions: List[APIKeyPermission]
class Pagination(pydantic.BaseModel):

View File

@@ -16,20 +16,14 @@ import backend.data.block
import backend.data.db
import backend.data.graph
import backend.data.user
import backend.server.integrations.router
import backend.server.routers.v1
import backend.server.v2.admin.store_admin_routes
import backend.server.v2.library.db
import backend.server.v2.library.model
import backend.server.v2.library.routes
import backend.server.v2.otto.routes
import backend.server.v2.postmark.postmark
import backend.server.v2.store.model
import backend.server.v2.store.routes
import backend.util.service
import backend.util.settings
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.server.external.api import external_app
settings = backend.util.settings.Settings()
@@ -56,8 +50,6 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()
# FIXME ERROR: operator does not exist: text ? unknown
# await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
with launch_darkly_context():
yield
await backend.data.db.disconnect()
@@ -72,7 +64,8 @@ docs_url = (
app = fastapi.FastAPI(
title="AutoGPT Agent Server",
description=(
"This server is used to execute agents that are created by the AutoGPT system."
"This server is used to execute agents that are created by the "
"AutoGPT system."
),
summary="AutoGPT Agent Server",
version="0.1",
@@ -102,23 +95,9 @@ app.include_router(backend.server.routers.v1.v1_router, tags=["v1"], prefix="/ap
app.include_router(
backend.server.v2.store.routes.router, tags=["v2"], prefix="/api/store"
)
app.include_router(
backend.server.v2.admin.store_admin_routes.router,
tags=["v2", "admin"],
prefix="/api/store",
)
app.include_router(
backend.server.v2.library.routes.router, tags=["v2"], prefix="/api/library"
)
app.include_router(
backend.server.v2.otto.routes.router, tags=["v2"], prefix="/api/otto"
)
app.include_router(
backend.server.v2.postmark.postmark.router,
tags=["v2", "email"],
prefix="/api/email",
)
app.mount("/external-api", external_app)
@@ -162,10 +141,9 @@ class AgentServer(backend.util.service.AppProcess):
graph_id: str,
graph_version: int,
user_id: str,
for_export: bool = False,
):
return await backend.server.routers.v1.get_graph(
graph_id, user_id, graph_version, for_export
graph_id, user_id, graph_version
)
@staticmethod
@@ -177,15 +155,21 @@ class AgentServer(backend.util.service.AppProcess):
@staticmethod
async def test_get_graph_run_status(graph_exec_id: str, user_id: str):
from backend.data.execution import get_graph_execution_meta
execution = await get_graph_execution_meta(
execution = await backend.data.graph.get_execution(
user_id=user_id, execution_id=graph_exec_id
)
if not execution:
raise ValueError(f"Execution {graph_exec_id} not found")
return execution.status
@staticmethod
async def test_get_graph_run_node_execution_results(
graph_id: str, graph_exec_id: str, user_id: str
):
return await backend.server.routers.v1.get_graph_run_node_execution_results(
graph_id, graph_exec_id, user_id
)
@staticmethod
async def test_delete_graph(graph_id: str, user_id: str):
await backend.server.v2.library.db.delete_library_agent_by_graph_id(
@@ -252,26 +236,12 @@ class AgentServer(backend.util.service.AppProcess):
):
return await backend.server.v2.store.routes.create_submission(request, user_id)
### ADMIN ###
@staticmethod
async def test_review_store_listing(
request: backend.server.v2.store.model.ReviewSubmissionRequest,
user: autogpt_libs.auth.models.User,
):
return await backend.server.v2.admin.store_admin_routes.review_submission(
request.store_listing_version_id, request, user
)
@staticmethod
def test_create_credentials(
user_id: str,
provider: ProviderName,
credentials: Credentials,
) -> Credentials:
return backend.server.integrations.router.create_credentials(
user_id=user_id, provider=provider, credentials=credentials
)
return await backend.server.v2.store.routes.review_submission(request, user)
def set_test_dependency_overrides(self, overrides: dict):
app.dependency_overrides.update(overrides)

Some files were not shown because too many files have changed in this diff Show More