Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-12 16:48:06 -05:00)

Compare commits: 4 commits, ci-chromat...builder-re
| Author | SHA1 | Date |
|---|---|---|
| | 6fca2352bb | |
| | 84759053a7 | |
| | 4afe724628 | |
| | c178a537b7 | |

.github/dependabot.yml (24 changed lines, vendored)
@@ -129,6 +129,30 @@ updates:
          - "minor"
          - "patch"

  # Submodules
  - package-ecosystem: "gitsubmodule"
    directory: "autogpt_platform/supabase"
    schedule:
      interval: "weekly"
    open-pull-requests-limit: 1
    target-branch: "dev"
    commit-message:
      prefix: "chore(platform/deps)"
      prefix-development: "chore(platform/deps-dev)"
    groups:
      production-dependencies:
        dependency-type: "production"
        update-types:
          - "minor"
          - "patch"
      development-dependencies:
        dependency-type: "development"
        update-types:
          - "minor"
          - "patch"

  # Docs
  - package-ecosystem: 'pip'
    directory: "docs/"

.github/workflows/classic-autogpt-ci.yml (9 changed lines, vendored)
@@ -115,7 +115,6 @@ jobs:
          poetry run pytest -vv \
            --cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
            --numprocesses=logical --durations=10 \
            --junitxml=junit.xml -o junit_family=legacy \
            tests/unit tests/integration
        env:
          CI: true

@@ -125,14 +124,8 @@ jobs:
          AWS_ACCESS_KEY_ID: minioadmin
          AWS_SECRET_ACCESS_KEY: minioadmin

      - name: Upload test results to Codecov
        if: ${{ !cancelled() }} # Run even if tests fail
        uses: codecov/test-results-action@v1
        with:
          token: ${{ secrets.CODECOV_TOKEN }}

      - name: Upload coverage reports to Codecov
        uses: codecov/codecov-action@v5
        uses: codecov/codecov-action@v4
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
          flags: autogpt-agent,${{ runner.os }}

.github/workflows/classic-benchmark-ci.yml (9 changed lines, vendored)
@@ -87,20 +87,13 @@ jobs:
          poetry run pytest -vv \
            --cov=agbenchmark --cov-branch --cov-report term-missing --cov-report xml \
            --durations=10 \
            --junitxml=junit.xml -o junit_family=legacy \
            tests
        env:
          CI: true
          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

      - name: Upload test results to Codecov
        if: ${{ !cancelled() }} # Run even if tests fail
        uses: codecov/test-results-action@v1
        with:
          token: ${{ secrets.CODECOV_TOKEN }}

      - name: Upload coverage reports to Codecov
        uses: codecov/codecov-action@v5
        uses: codecov/codecov-action@v4
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
          flags: agbenchmark,${{ runner.os }}

.github/workflows/classic-forge-ci.yml (9 changed lines, vendored)
@@ -139,7 +139,6 @@ jobs:
          poetry run pytest -vv \
            --cov=forge --cov-branch --cov-report term-missing --cov-report xml \
            --durations=10 \
            --junitxml=junit.xml -o junit_family=legacy \
            forge
        env:
          CI: true

@@ -149,14 +148,8 @@ jobs:
          AWS_ACCESS_KEY_ID: minioadmin
          AWS_SECRET_ACCESS_KEY: minioadmin

      - name: Upload test results to Codecov
        if: ${{ !cancelled() }} # Run even if tests fail
        uses: codecov/test-results-action@v1
        with:
          token: ${{ secrets.CODECOV_TOKEN }}

      - name: Upload coverage reports to Codecov
        uses: codecov/codecov-action@v5
        uses: codecov/codecov-action@v4
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
          flags: forge,${{ runner.os }}
@@ -34,7 +34,6 @@ jobs:
          python -m prisma migrate deploy
        env:
          DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
          DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}

  trigger:

@@ -36,7 +36,6 @@ jobs:
          python -m prisma migrate deploy
        env:
          DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
          DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}

  trigger:
    needs: migrate

.github/workflows/platform-backend-ci.yml (35 changed lines, vendored)
@@ -66,7 +66,7 @@ jobs:
      - name: Setup Supabase
        uses: supabase/setup-cli@v1
        with:
          version: 1.178.1
          version: latest

      - id: get_date
        name: Get date

@@ -80,35 +80,18 @@ jobs:

      - name: Install Poetry (Unix)
        run: |
          # Extract Poetry version from backend/poetry.lock
          HEAD_POETRY_VERSION=$(head -n 1 poetry.lock | grep -oP '(?<=Poetry )[0-9]+\.[0-9]+\.[0-9]+')
          echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"

          if [ -n "$BASE_REF" ]; then
            BASE_BRANCH=${BASE_REF/refs\/heads\//}
            BASE_POETRY_VERSION=$((git show "origin/$BASE_BRANCH":./poetry.lock; true) | head -n 1 | grep -oP '(?<=Poetry )[0-9]+\.[0-9]+\.[0-9]+')
            echo "Found Poetry version ${BASE_POETRY_VERSION} in backend/poetry.lock on ${BASE_REF}"
            POETRY_VERSION=$(printf '%s\n' "$HEAD_POETRY_VERSION" "$BASE_POETRY_VERSION" | sort -V | tail -n1)
          else
            POETRY_VERSION=$HEAD_POETRY_VERSION
          fi
          echo "Using Poetry version ${POETRY_VERSION}"

          # Install Poetry
          curl -sSL https://install.python-poetry.org | POETRY_VERSION=$POETRY_VERSION python3 -
          curl -sSL https://install.python-poetry.org | python3 -

          if [ "${{ runner.os }}" = "macOS" ]; then
            PATH="$HOME/.local/bin:$PATH"
            echo "$HOME/.local/bin" >> $GITHUB_PATH
          fi
        env:
          BASE_REF: ${{ github.base_ref || github.event.merge_group.base_ref }}

      - name: Check poetry.lock
        run: |
          poetry lock

          if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
          if ! git diff --quiet poetry.lock; then
            echo "Error: poetry.lock not up to date."
            echo
            git diff poetry.lock

@@ -135,7 +118,6 @@ jobs:
        run: poetry run prisma migrate dev --name updates
        env:
          DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
          DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}

      - id: lint
        name: Run Linter

@@ -152,13 +134,12 @@ jobs:
        env:
          LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
          DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
          DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
          SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
          SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
          SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
          REDIS_HOST: "localhost"
          REDIS_PORT: "6379"
          REDIS_PASSWORD: "testpassword"
          REDIS_HOST: 'localhost'
          REDIS_PORT: '6379'
          REDIS_PASSWORD: 'testpassword'

        env:
          CI: true

@@ -171,8 +152,8 @@ jobs:
          # If you want to replace this, you can do so by making our entire system generate
          # new credentials for each local user and update the environment variables in
          # the backend service, docker composes, and examples
          RABBITMQ_DEFAULT_USER: "rabbitmq_user_default"
          RABBITMQ_DEFAULT_PASS: "k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7"
          RABBITMQ_DEFAULT_USER: 'rabbitmq_user_default'
          RABBITMQ_DEFAULT_PASS: 'k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7'

      # - name: Upload coverage reports to Codecov
      #   uses: codecov/codecov-action@v4

.github/workflows/platform-frontend-ci.yml (26 changed lines, vendored)
@@ -56,30 +56,6 @@ jobs:
        run: |
          yarn type-check

  design:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: "21"

      - name: Install dependencies
        run: |
          yarn install --frozen-lockfile

      - name: Run Chromatic
        uses: chromaui/action@latest
        with:
          # ⚠️ Make sure to configure a `CHROMATIC_PROJECT_TOKEN` repository secret
          projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
          workingDir: autogpt_platform/frontend

  test:
    runs-on: ubuntu-latest
    strategy:

@@ -106,7 +82,7 @@ jobs:

      - name: Copy default supabase .env
        run: |
          cp ../.env.example ../.env
          cp ../supabase/docker/.env.example ../.env

      - name: Copy backend .env
        run: |

.gitmodules (3 changed lines, vendored)
@@ -1,3 +1,6 @@
[submodule "classic/forge/tests/vcr_cassettes"]
	path = classic/forge/tests/vcr_cassettes
	url = https://github.com/Significant-Gravitas/Auto-GPT-test-cassettes
[submodule "autogpt_platform/supabase"]
	path = autogpt_platform/supabase
	url = https://github.com/supabase/supabase.git

@@ -140,7 +140,7 @@ repos:
        language: system

  - repo: https://github.com/psf/black
    rev: 24.10.0
    rev: 23.12.1
    # Black has sensible defaults, doesn't need package context, and ignores
    # everything in .gitignore, so it works fine without any config or arguments.
    hooks:

@@ -2,6 +2,9 @@
If you are reading this, you are probably looking for the full **[contribution guide]**,
which is part of our [wiki].

Also check out our [🚀 Roadmap][roadmap] for information about our priorities and associated tasks.
<!-- You can find our immediate priorities and their progress on our public [kanban board]. -->

[contribution guide]: https://github.com/Significant-Gravitas/AutoGPT/wiki/Contributing
[wiki]: https://github.com/Significant-Gravitas/AutoGPT/wiki
[roadmap]: https://github.com/Significant-Gravitas/AutoGPT/discussions/6971

@@ -15,11 +15,7 @@
> Setting up and hosting the AutoGPT Platform yourself is a technical process.
> If you'd rather something that just works, we recommend [joining the waitlist](https://bit.ly/3ZDijAI) for the cloud-hosted beta.

### Updated Setup Instructions:
We’ve moved to a fully maintained and regularly updated documentation site.

👉 [Follow the official self-hosting guide here](https://docs.agpt.co/platform/getting-started/)

https://github.com/user-attachments/assets/d04273a5-b36a-4a37-818e-f631ce72d603

This tutorial assumes you have Docker, VSCode, git and npm installed.

@@ -20,7 +20,6 @@ Instead, please report them via:
- Please provide detailed reports with reproducible steps
- Include the version/commit hash where you discovered the vulnerability
- Allow us a 90-day security fix window before any public disclosure
- After a patch is released, allow 30 days for users to update before public disclosure (for a total of at most 120 days between the initial report and public disclosure)
- Share any potential mitigations or workarounds if known

## Supported Versions

@@ -1,123 +0,0 @@
############
# Secrets
# YOU MUST CHANGE THESE BEFORE GOING INTO PRODUCTION
############

POSTGRES_PASSWORD=your-super-secret-and-long-postgres-password
JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
ANON_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJhbm9uIiwKICAgICJpc3MiOiAic3VwYWJhc2UtZGVtbyIsCiAgICAiaWF0IjogMTY0MTc2OTIwMCwKICAgICJleHAiOiAxNzk5NTM1NjAwCn0.dc_X5iR_VP_qT0zsiyj_I_OZ2T9FtRU2BBNWN8Bu4GE
SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
DASHBOARD_USERNAME=supabase
DASHBOARD_PASSWORD=this_password_is_insecure_and_should_be_updated
SECRET_KEY_BASE=UpNVntn3cDxHJpq99YMc1T1AQgQpc8kfYTuRgBiYa15BLrx8etQoXz3gZv1/u2oq
VAULT_ENC_KEY=your-encryption-key-32-chars-min

############
# Database - You can change these to any PostgreSQL database that has logical replication enabled.
############

POSTGRES_HOST=db
POSTGRES_DB=postgres
POSTGRES_PORT=5432
# default user is postgres

############
# Supavisor -- Database pooler
############
POOLER_PROXY_PORT_TRANSACTION=6543
POOLER_DEFAULT_POOL_SIZE=20
POOLER_MAX_CLIENT_CONN=100
POOLER_TENANT_ID=your-tenant-id

############
# API Proxy - Configuration for the Kong Reverse proxy.
############

KONG_HTTP_PORT=8000
KONG_HTTPS_PORT=8443

############
# API - Configuration for PostgREST.
############

PGRST_DB_SCHEMAS=public,storage,graphql_public

############
# Auth - Configuration for the GoTrue authentication server.
############

## General
SITE_URL=http://localhost:3000
ADDITIONAL_REDIRECT_URLS=
JWT_EXPIRY=3600
DISABLE_SIGNUP=false
API_EXTERNAL_URL=http://localhost:8000

## Mailer Config
MAILER_URLPATHS_CONFIRMATION="/auth/v1/verify"
MAILER_URLPATHS_INVITE="/auth/v1/verify"
MAILER_URLPATHS_RECOVERY="/auth/v1/verify"
MAILER_URLPATHS_EMAIL_CHANGE="/auth/v1/verify"

## Email auth
ENABLE_EMAIL_SIGNUP=true
ENABLE_EMAIL_AUTOCONFIRM=false
SMTP_ADMIN_EMAIL=admin@example.com
SMTP_HOST=supabase-mail
SMTP_PORT=2500
SMTP_USER=fake_mail_user
SMTP_PASS=fake_mail_password
SMTP_SENDER_NAME=fake_sender
ENABLE_ANONYMOUS_USERS=false

## Phone auth
ENABLE_PHONE_SIGNUP=true
ENABLE_PHONE_AUTOCONFIRM=true

############
# Studio - Configuration for the Dashboard
############

STUDIO_DEFAULT_ORGANIZATION=Default Organization
STUDIO_DEFAULT_PROJECT=Default Project

STUDIO_PORT=3000
# replace if you intend to use Studio outside of localhost
SUPABASE_PUBLIC_URL=http://localhost:8000

# Enable webp support
IMGPROXY_ENABLE_WEBP_DETECTION=true

# Add your OpenAI API key to enable SQL Editor Assistant
OPENAI_API_KEY=

############
# Functions - Configuration for Functions
############
# NOTE: VERIFY_JWT applies to all functions. Per-function VERIFY_JWT is not supported yet.
FUNCTIONS_VERIFY_JWT=false

############
# Logs - Configuration for Logflare
# Please refer to https://supabase.com/docs/reference/self-hosting-analytics/introduction
############

LOGFLARE_LOGGER_BACKEND_API_KEY=your-super-secret-and-long-logflare-key

# Change vector.toml sinks to reflect this change
LOGFLARE_API_KEY=your-super-secret-and-long-logflare-key

# Docker socket location - this value will differ depending on your OS
DOCKER_SOCKET_LOCATION=/var/run/docker.sock

# Google Cloud Project details
GOOGLE_PROJECT_ID=GOOGLE_PROJECT_ID
GOOGLE_PROJECT_NUMBER=GOOGLE_PROJECT_NUMBER

@@ -22,29 +22,35 @@ To run the AutoGPT Platform, follow these steps:

2. Run the following command:
   ```
   cp .env.example .env
   git submodule update --init --recursive --progress
   ```
   This command will copy the `.env.example` file to `.env`. You can modify the `.env` file to add your own environment variables.
   This command will initialize and update the submodules in the repository. The `supabase` folder will be cloned to the root directory.

3. Run the following command:
   ```
   cp supabase/docker/.env.example .env
   ```
   This command copies `supabase/docker/.env.example` to `.env`. You can modify the `.env` file to add your own environment variables.

4. Run the following command:
   ```
   docker compose up -d
   ```
   This command will start all the necessary backend services defined in the `docker-compose.yml` file in detached mode.

4. Navigate to `frontend` within the `autogpt_platform` directory:
5. Navigate to `frontend` within the `autogpt_platform` directory:
   ```
   cd frontend
   ```
   You will need to run your frontend application separately on your local machine.

5. Run the following command:
6. Run the following command:
   ```
   cp .env.example .env.local
   ```
   This command will copy the `.env.example` file to `.env.local` in the `frontend` directory. You can modify the `.env.local` within this folder to add your own environment variables for the frontend application.

6. Run the following command:
7. Run the following command:
   ```
   npm install
   npm run dev

@@ -55,7 +61,7 @@ To run the AutoGPT Platform, follow these steps:
   yarn install && yarn dev
   ```

7. Open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
8. Open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.

### Docker Compose Commands

@@ -1,13 +1,14 @@
from .config import Settings
from .depends import requires_admin_user, requires_user
from .jwt_utils import parse_jwt_token
from .middleware import APIKeyValidator, auth_middleware
from .middleware import auth_middleware
from .models import User

__all__ = [
    "Settings",
    "parse_jwt_token",
    "requires_user",
    "requires_admin_user",
    "APIKeyValidator",
    "auth_middleware",
    "User",
]

@@ -1,11 +1,14 @@
import os

from dotenv import load_dotenv

load_dotenv()


class Settings:
    def __init__(self):
        self.JWT_SECRET_KEY: str = os.getenv("SUPABASE_JWT_SECRET", "")
        self.ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
        self.JWT_ALGORITHM: str = "HS256"
    JWT_SECRET_KEY: str = os.getenv("SUPABASE_JWT_SECRET", "")
    ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
    JWT_ALGORITHM: str = "HS256"

    @property
    def is_configured(self) -> bool:

@@ -1,6 +1,6 @@
import fastapi

from .config import settings
from .config import Settings
from .middleware import auth_middleware
from .models import DEFAULT_USER_ID, User

@@ -17,7 +17,7 @@ def requires_admin_user(

def verify_user(payload: dict | None, admin_only: bool) -> User:
    if not payload:
        if settings.ENABLE_AUTH:
        if Settings.ENABLE_AUTH:
            raise fastapi.HTTPException(
                status_code=401, detail="Authorization header is missing"
            )

@@ -1,10 +1,7 @@
import inspect
import logging
from typing import Any, Callable, Optional

from fastapi import HTTPException, Request, Security
from fastapi.security import APIKeyHeader, HTTPBearer
from starlette.status import HTTP_401_UNAUTHORIZED
from fastapi import HTTPException, Request
from fastapi.security import HTTPBearer

from .config import settings
from .jwt_utils import parse_jwt_token

@@ -32,104 +29,3 @@ async def auth_middleware(request: Request):
    except ValueError as e:
        raise HTTPException(status_code=401, detail=str(e))
    return payload


class APIKeyValidator:
    """
    Configurable API key validator that supports custom validation functions
    for FastAPI applications.

    This class provides a flexible way to implement API key authentication with optional
    custom validation logic. It can be used for simple token matching
    or more complex validation scenarios like database lookups.

    Examples:
        Simple token validation:
        ```python
        validator = APIKeyValidator(
            header_name="X-API-Key",
            expected_token="your-secret-token"
        )

        @app.get("/protected", dependencies=[Depends(validator.get_dependency())])
        def protected_endpoint():
            return {"message": "Access granted"}
        ```

        Custom validation with database lookup:
        ```python
        async def validate_with_db(api_key: str):
            api_key_obj = await db.get_api_key(api_key)
            return api_key_obj if api_key_obj and api_key_obj.is_active else None

        validator = APIKeyValidator(
            header_name="X-API-Key",
            validate_fn=validate_with_db
        )
        ```

    Args:
        header_name (str): The name of the header containing the API key
        expected_token (Optional[str]): The expected API key value for simple token matching
        validate_fn (Optional[Callable]): Custom validation function that takes an API key
            string and returns a boolean or object. Can be async.
        error_status (int): HTTP status code to use for validation errors
        error_message (str): Error message to return when validation fails
    """

    def __init__(
        self,
        header_name: str,
        expected_token: Optional[str] = None,
        validate_fn: Optional[Callable[[str], bool]] = None,
        error_status: int = HTTP_401_UNAUTHORIZED,
        error_message: str = "Invalid API key",
    ):
        # Create the APIKeyHeader as a class property
        self.security_scheme = APIKeyHeader(name=header_name)
        self.expected_token = expected_token
        self.custom_validate_fn = validate_fn
        self.error_status = error_status
        self.error_message = error_message

    async def default_validator(self, api_key: str) -> bool:
        return api_key == self.expected_token

    async def __call__(
        self, request: Request, api_key: str = Security(APIKeyHeader)
    ) -> Any:
        if api_key is None:
            raise HTTPException(status_code=self.error_status, detail="Missing API key")

        # Use custom validation if provided, otherwise use default equality check
        validator = self.custom_validate_fn or self.default_validator
        result = (
            await validator(api_key)
            if inspect.iscoroutinefunction(validator)
            else validator(api_key)
        )

        if not result:
            raise HTTPException(
                status_code=self.error_status, detail=self.error_message
            )

        # Store validation result in request state if it's not just a boolean
        if result is not True:
            request.state.api_key = result

        return result

    def get_dependency(self):
        """
        Returns a callable dependency that FastAPI will recognize as a security scheme
        """

        async def validate_api_key(
            request: Request, api_key: str = Security(self.security_scheme)
        ) -> Any:
            return await self(request, api_key)

        # This helps FastAPI recognize it as a security dependency
        validate_api_key.__name__ = f"validate_{self.security_scheme.model.name}"
        return validate_api_key

@@ -8,7 +8,7 @@ from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

from .filters import BelowLevelFilter
from .formatters import AGPTFormatter
from .formatters import AGPTFormatter, StructuredLoggingFormatter

LOG_DIR = Path(__file__).parent.parent.parent.parent / "logs"
LOG_FILE = "activity.log"

@@ -81,26 +81,9 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
    """

    config = LoggingConfig()

    log_handlers: list[logging.Handler] = []

    # Console output handlers
    stdout = logging.StreamHandler(stream=sys.stdout)
    stdout.setLevel(config.level)
    stdout.addFilter(BelowLevelFilter(logging.WARNING))
    if config.level == logging.DEBUG:
        stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
    else:
        stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))

    stderr = logging.StreamHandler()
    stderr.setLevel(logging.WARNING)
    if config.level == logging.DEBUG:
        stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
    else:
        stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))

    log_handlers += [stdout, stderr]

    # Cloud logging setup
    if config.enable_cloud_logging or force_cloud_logging:
        import google.cloud.logging

@@ -114,7 +97,28 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
            transport=SyncTransport,
        )
        cloud_handler.setLevel(config.level)
        cloud_handler.setFormatter(StructuredLoggingFormatter())
        log_handlers.append(cloud_handler)
        print("Cloud logging enabled")
    else:
        # Console output handlers
        stdout = logging.StreamHandler(stream=sys.stdout)
        stdout.setLevel(config.level)
        stdout.addFilter(BelowLevelFilter(logging.WARNING))
        if config.level == logging.DEBUG:
            stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
        else:
            stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))

        stderr = logging.StreamHandler()
        stderr.setLevel(logging.WARNING)
        if config.level == logging.DEBUG:
            stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
        else:
            stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))

        log_handlers += [stdout, stderr]
        print("Console logging enabled")

    # File logging setup
    if config.enable_file_logging:

@@ -152,6 +156,7 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
        error_log_handler.setLevel(logging.ERROR)
        error_log_handler.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT, no_color=True))
        log_handlers.append(error_log_handler)
        print("File logging enabled")

    # Configure the root logger
    logging.basicConfig(

@@ -1,6 +1,7 @@
import logging

from colorama import Fore, Style
from google.cloud.logging_v2.handlers import CloudLoggingFilter, StructuredLogHandler

from .utils import remove_color_codes

@@ -79,3 +80,16 @@ class AGPTFormatter(FancyConsoleFormatter):
            return remove_color_codes(super().format(record))
        else:
            return super().format(record)


class StructuredLoggingFormatter(StructuredLogHandler, logging.Formatter):
    def __init__(self):
        # Set up CloudLoggingFilter to add diagnostic info to the log records
        self.cloud_logging_filter = CloudLoggingFilter()

        # Init StructuredLogHandler
        super().__init__()

    def format(self, record: logging.LogRecord) -> str:
        self.cloud_logging_filter.filter(record)
        return super().format(record)

@@ -2,7 +2,6 @@ import logging
import re
from typing import Any

import uvicorn.config
from colorama import Fore

@@ -26,14 +25,3 @@ def print_attribute(
            "color": value_color,
        },
    )


def generate_uvicorn_config():
    """
    Generates a uvicorn logging config that silences uvicorn's default logging and tells it to use the native logging module.
    """
    log_config = dict(uvicorn.config.LOGGING_CONFIG)
    log_config["loggers"]["uvicorn"] = {"handlers": []}
    log_config["loggers"]["uvicorn.error"] = {"handlers": []}
    log_config["loggers"]["uvicorn.access"] = {"handlers": []}
    return log_config
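As a rough illustration only: a dict like the one `generate_uvicorn_config()` returns can be handed to uvicorn through its `log_config` parameter so that uvicorn's own handlers stay empty and records flow through the root logger. The import path and app reference below are assumptions, not taken from this diff.

```python
# Hypothetical usage sketch; module path and app string are illustrative.
import uvicorn

from autogpt_libs.logging.utils import generate_uvicorn_config  # assumed import path

uvicorn.run(
    "backend.server.rest_api:app",  # hypothetical app reference
    host="0.0.0.0",
    port=8000,
    log_config=generate_uvicorn_config(),  # silence uvicorn's default handlers
)
```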

@@ -1,59 +1,20 @@
import inspect
import threading
from typing import Awaitable, Callable, ParamSpec, TypeVar, cast, overload
from typing import Callable, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


@overload
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...


@overload
def thread_cached(func: Callable[P, R]) -> Callable[P, R]: ...


def thread_cached(
    func: Callable[P, R] | Callable[P, Awaitable[R]],
) -> Callable[P, R] | Callable[P, Awaitable[R]]:
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
    thread_local = threading.local()

    def _clear():
        if hasattr(thread_local, "cache"):
            del thread_local.cache
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        cache = getattr(thread_local, "cache", None)
        if cache is None:
            cache = thread_local.cache = {}
        key = (args, tuple(sorted(kwargs.items())))
        if key not in cache:
            cache[key] = func(*args, **kwargs)
        return cache[key]

    if inspect.iscoroutinefunction(func):

        async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            cache = getattr(thread_local, "cache", None)
            if cache is None:
                cache = thread_local.cache = {}
            key = (args, tuple(sorted(kwargs.items())))
            if key not in cache:
                cache[key] = await cast(Callable[P, Awaitable[R]], func)(
                    *args, **kwargs
                )
            return cache[key]

        setattr(async_wrapper, "clear_cache", _clear)
        return async_wrapper

    else:

        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            cache = getattr(thread_local, "cache", None)
            if cache is None:
                cache = thread_local.cache = {}
            key = (args, tuple(sorted(kwargs.items())))
            if key not in cache:
                cache[key] = func(*args, **kwargs)
            return cache[key]

        setattr(sync_wrapper, "clear_cache", _clear)
        return sync_wrapper


def clear_thread_cache(func: Callable) -> None:
    if clear := getattr(func, "clear_cache", None):
        clear()
    return wrapper
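The hunk above shows `thread_cached` in two forms: a sync-only decorator and an overloaded variant that also wraps coroutine functions and exposes a `clear_cache` hook via `clear_thread_cache`. A minimal usage sketch, assuming the overloaded variant (the decorated functions below are invented for illustration):

```python
import threading

from autogpt_libs.utils.cache import clear_thread_cache, thread_cached


@thread_cached
def get_connection(url: str) -> str:
    # Expensive setup that should run once per thread per argument set.
    print(f"connecting to {url} on thread {threading.get_ident()}")
    return f"conn:{url}"


@thread_cached
async def get_async_client(name: str) -> str:
    # With the overloaded variant, coroutine results are awaited once and cached too.
    return f"client:{name}"


get_connection("redis://localhost:6379")  # runs the body, then caches the result
get_connection("redis://localhost:6379")  # served from this thread's cache
clear_thread_cache(get_connection)        # drops the per-thread cache via clear_cache()
```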

autogpt_platform/autogpt_libs/poetry.lock (809 changed lines, generated) — file diff suppressed because it is too large.

@@ -10,17 +10,18 @@ packages = [{ include = "autogpt_libs" }]
colorama = "^0.4.6"
expiringdict = "^1.2.2"
google-cloud-logging = "^3.11.4"
pydantic = "^2.11.1"
pydantic-settings = "^2.8.1"
pydantic = "^2.10.6"
pydantic-settings = "^2.7.1"
pyjwt = "^2.10.1"
pytest-asyncio = "^0.26.0"
pytest-asyncio = "^0.25.3"
pytest-mock = "^3.14.0"
python = ">=3.10,<4.0"
supabase = "^2.15.0"
python-dotenv = "^1.0.1"
supabase = "^2.13.0"

[tool.poetry.group.dev.dependencies]
redis = "^5.2.1"
ruff = "^0.11.0"
ruff = "^0.9.3"

[build-system]
requires = ["poetry-core"]

@@ -2,24 +2,13 @@ DB_USER=postgres
DB_PASS=your-super-secret-and-long-postgres-password
DB_NAME=postgres
DB_PORT=5432
DB_HOST=localhost
DB_CONNECTION_LIMIT=12
DB_CONNECT_TIMEOUT=60
DB_POOL_TIMEOUT=300
DB_SCHEMA=platform
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@localhost:${DB_PORT}/${DB_NAME}?connect_timeout=60&schema=platform"
PRISMA_SCHEMA="postgres/schema.prisma"

# EXECUTOR
NUM_GRAPH_WORKERS=10
NUM_NODE_WORKERS=3

BACKEND_CORS_ALLOW_ORIGINS=["http://localhost:3000"]

# generate using `from cryptography.fernet import Fernet;Fernet.generate_key().decode()`
ENCRYPTION_KEY='dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw='
UNSUBSCRIBE_SECRET_KEY = 'HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio='

REDIS_HOST=localhost
REDIS_PORT=6379

@@ -39,7 +28,6 @@ SENTRY_DSN=
# Email For Postmark so we can send emails
POSTMARK_SERVER_API_TOKEN=
POSTMARK_SENDER_EMAIL=invalid@invalid.com
POSTMARK_WEBHOOK_TOKEN=

## User auth with Supabase is required for any of the 3rd party integrations with auth to work.
ENABLE_AUTH=true

@@ -53,9 +41,6 @@ RABBITMQ_PORT=5672
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7

## GCS bucket is required for marketplace and library functionality
MEDIA_GCS_BUCKET_NAME=

## For local development, you may need to set FRONTEND_BASE_URL for the OAuth flow
## for integrations to work. Defaults to the value of PLATFORM_BASE_URL if not set.
# FRONTEND_BASE_URL=http://localhost:3000

@@ -189,8 +174,6 @@ SMARTLEAD_API_KEY=
# ZeroBounce
ZEROBOUNCE_API_KEY=

## ===== OPTIONAL API KEYS END ===== ##

# Logging Configuration
LOG_LEVEL=INFO
ENABLE_CLOUD_LOGGING=false

@@ -73,6 +73,7 @@ FROM server_dependencies AS server
COPY autogpt_platform/backend /app/autogpt_platform/backend
RUN poetry install --no-ansi --only-root

ENV DATABASE_URL=""
ENV PORT=8000

CMD ["poetry", "run", "rest"]

@@ -1 +1,75 @@
[Advanced Setup (Dev Branch)](https://dev-docs.agpt.co/platform/advanced_setup/#autogpt_agent_server_advanced_set_up)
# AutoGPT Agent Server Advanced Set Up

This guide walks you through a dockerized setup with an external DB (Postgres).

## Setup

We use Poetry to manage the dependencies. To set up the project, follow these steps inside this directory:

0. Install Poetry
   ```sh
   pip install poetry
   ```

1. Configure Poetry to use .venv in your project directory
   ```sh
   poetry config virtualenvs.in-project true
   ```

2. Enter the poetry shell

   ```sh
   poetry shell
   ```

3. Install dependencies

   ```sh
   poetry install
   ```

4. Copy .env.example to .env

   ```sh
   cp .env.example .env
   ```

5. Generate the Prisma client

   ```sh
   poetry run prisma generate
   ```

   > In case Prisma generates the client for the global Python installation instead of the virtual environment, the current mitigation is to just uninstall the global Prisma package:
   >
   > ```sh
   > pip uninstall prisma
   > ```
   >
   > Then run the generation again. The path *should* look something like this:
   > `<some path>/pypoetry/virtualenvs/backend-TQIRSwR6-py3.12/bin/prisma`

6. Run the Postgres database from the `autogpt_platform` folder

   ```sh
   cd autogpt_platform/
   docker compose up -d
   ```

7. Run the migrations (from the backend folder)

   ```sh
   cd ../backend
   prisma migrate deploy
   ```

## Running The Server

### Starting the server directly

Run the following command:

```sh
poetry run app
```

@@ -1 +1,210 @@
[Getting Started (Released)](https://docs.agpt.co/platform/getting-started/#autogpt_agent_server)
# AutoGPT Agent Server

This is an initial project for creating the next generation of agent execution: an AutoGPT agent server.
The agent server will enable the creation of composite multi-agent systems that use AutoGPT agents and other non-agent components as their primitives.

## Docs

You can access the docs for the [AutoGPT Agent Server here](https://docs.agpt.co/server/setup).

## Setup

We use Poetry to manage the dependencies. To set up the project, follow these steps inside this directory:

0. Install Poetry
   ```sh
   pip install poetry
   ```

1. Configure Poetry to use .venv in your project directory
   ```sh
   poetry config virtualenvs.in-project true
   ```

2. Enter the poetry shell

   ```sh
   poetry shell
   ```

3. Install dependencies

   ```sh
   poetry install
   ```

4. Copy .env.example to .env

   ```sh
   cp .env.example .env
   ```

5. Generate the Prisma client

   ```sh
   poetry run prisma generate
   ```

   > In case Prisma generates the client for the global Python installation instead of the virtual environment, the current mitigation is to just uninstall the global Prisma package:
   >
   > ```sh
   > pip uninstall prisma
   > ```
   >
   > Then run the generation again. The path *should* look something like this:
   > `<some path>/pypoetry/virtualenvs/backend-TQIRSwR6-py3.12/bin/prisma`

6. Migrate the database. Be careful: this deletes current data in the database.

   ```sh
   docker compose up db -d
   poetry run prisma migrate deploy
   ```

## Running The Server

### Starting the server without Docker

To run the server locally, start in the autogpt_platform folder:

```sh
cd ..
```

Run the following command to run the database in Docker but the application locally:

```sh
docker compose --profile local up deps --build --detach
cd backend
poetry run app
```

### Starting the server with Docker

Run the following command to build the dockerfiles:

```sh
docker compose build
```

Run the following command to run the app:

```sh
docker compose up
```

Run the following to automatically rebuild when code changes, in another terminal:

```sh
docker compose watch
```

Run the following command to shut down:

```sh
docker compose down
```

If you run into issues with dangling orphans, try:

```sh
docker compose down --volumes --remove-orphans && docker-compose up --force-recreate --renew-anon-volumes --remove-orphans
```

## Testing

To run the tests:

```sh
poetry run test
```

## Development

### Formatting & Linting
An auto-formatter and a linter are set up in the project. To run them:

Install:
```sh
poetry install --with dev
```

Format the code:
```sh
poetry run format
```

Lint the code:
```sh
poetry run lint
```

## Project Outline

The current project has the following main modules:

### **blocks**

This module stores all the Agent Blocks, which are reusable components to build a graph that represents the agent's behavior.

### **data**

This module stores the logical model that is persisted in the database.
It abstracts the database operations into functions that can be called by the service layer.
Any code that interacts with Prisma objects or the database should reside in this module.
The main models are:
* `block`: anything related to the blocks used in the graph
* `execution`: anything related to graph execution
* `graph`: anything related to the graph, its nodes, and their relations

### **execution**

This module stores the business logic of executing the graph.
It currently has the following main modules:
* `manager`: A service that consumes the graph-execution queue and executes the graph; it contains both the queue-consumption and the execution logic.
* `scheduler`: A service that triggers scheduled graph execution based on a cron expression. It pushes an execution request to the manager.

### **server**

This module stores the logic for the server API.
It contains all the logic used for the API that allows the client to create, execute, and monitor the graph and its executions.
This API service interacts with other services like those defined in `manager` and `scheduler`.

### **utils**

This module stores utility functions that are used across the project.
Currently, it has two main modules:
* `process`: A module that contains the logic to spawn a new process.
* `service`: A module that serves as a parent class for all the services in the project.

## Service Communication

Currently, there are only 3 active services:

- AgentServer (the API, defined in `server.py`)
- ExecutionManager (the executor, defined in `manager.py`)
- ExecutionScheduler (the scheduler, defined in `scheduler.py`)

The services run in independent Python processes and communicate through IPC.
A communication layer (`service.py`) is created to decouple the communication library from the implementation.

Currently, the IPC is done using Pyro5 and abstracted in a way that allows a function decorated with `@expose` to be called from a different process; a rough sketch of this mechanism follows the port list below.

By default the daemons run on the following ports:

Execution Manager Daemon: 8002
Execution Scheduler Daemon: 8003
Rest Server Daemon: 8004
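The following is only an illustration of the underlying Pyro5 expose/daemon/proxy pattern the README refers to; the class, method, and object id are invented, and the project's real abstraction lives in `service.py`, not shown here.

```python
# Minimal Pyro5 sketch (assumptions: example class/method names; port 8003 is the
# scheduler daemon port listed above, reused here purely for illustration).
import Pyro5.api


@Pyro5.api.expose
class ExampleScheduler:
    def add_execution_schedule(self, graph_id: str, cron: str) -> str:
        # In the real service this would enqueue a scheduled graph execution.
        return f"scheduled {graph_id} with '{cron}'"


def serve() -> None:
    # Host the exposed object on a fixed port so other processes can find it.
    daemon = Pyro5.api.Daemon(port=8003)
    uri = daemon.register(ExampleScheduler(), objectId="ExampleScheduler")
    print(f"Serving at {uri}")
    daemon.requestLoop()


def call() -> None:
    # A client in another process calls the exposed method through a proxy.
    proxy = Pyro5.api.Proxy("PYRO:ExampleScheduler@localhost:8003")
    print(proxy.add_execution_schedule("graph-123", "0 * * * *"))
```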

## Adding a New Agent Block

To add a new agent block, you need to create a new class that inherits from `Block` and provides the following information (a minimal sketch follows the list below):
* All the block code should live in the `blocks` (`backend.blocks`) module.
* `input_schema`: the schema of the input data, represented by a Pydantic object.
* `output_schema`: the schema of the output data, represented by a Pydantic object.
* `run` method: the main logic of the block.
* `test_input` & `test_output`: the sample input and output data for the block, which will be used to auto-test the block.
* You can mock the functions declared in the block using the `test_mock` field for your unit tests.
* Once you finish creating the block, you can test it by running `poetry run pytest -s test/block/test_block.py`.
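A minimal sketch under the requirements above. The block name, ID, fields, and behavior are invented, and the exact `Block.__init__` arguments are assumed from the list rather than quoted from the codebase:

```python
from backend.data.block import Block, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class ReverseTextBlock(Block):
    """Illustrative block: reverses the input text."""

    class Input(BlockSchema):
        text: str = SchemaField(description="Text to reverse")

    class Output(BlockSchema):
        reversed_text: str = SchemaField(description="The input text, reversed")

    def __init__(self):
        # Assumed constructor arguments, mirroring the requirements listed above.
        super().__init__(
            id="00000000-0000-0000-0000-000000000001",  # must be a unique 36-char UUID string
            input_schema=ReverseTextBlock.Input,
            output_schema=ReverseTextBlock.Output,
            test_input={"text": "hello"},
            test_output=("reversed_text", "olleh"),
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # Blocks yield (output_name, value) pairs.
        yield "reversed_text", input_data.text[::-1]
```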

@@ -1,30 +1,22 @@
import logging
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from backend.util.process import AppProcess

logger = logging.getLogger(__name__)


def run_processes(*processes: "AppProcess", **kwargs):
    """
    Execute all processes in the app. The last process is run in the foreground.
    Includes enhanced error handling and process lifecycle management.
    """
    try:
        # Run all processes except the last one in the background.
        for process in processes[:-1]:
            process.start(background=True, **kwargs)

        # Run the last process in the foreground.
        # Run the last process in the foreground
        processes[-1].start(background=False, **kwargs)
    finally:
        for process in processes:
            try:
                process.stop()
            except Exception as e:
                logger.exception(f"[{process.service_name}] unable to stop: {e}")
            process.stop()


def main(**kwargs):

@@ -32,7 +24,7 @@ def main(**kwargs):
    Run all the processes required for the AutoGPT-server (REST and WebSocket APIs).
    """

    from backend.executor import DatabaseManager, ExecutionManager, Scheduler
    from backend.executor import DatabaseManager, ExecutionManager, ExecutionScheduler
    from backend.notifications import NotificationManager
    from backend.server.rest_api import AgentServer
    from backend.server.ws_api import WebsocketServer

@@ -40,7 +32,7 @@ def main(**kwargs):
    run_processes(
        DatabaseManager(),
        ExecutionManager(),
        Scheduler(),
        ExecutionScheduler(),
        NotificationManager(),
        WebsocketServer(),
        AgentServer(),

@@ -2,103 +2,88 @@ import importlib
import os
import re
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar
from typing import Type, TypeVar

from backend.data.block import Block

# Dynamically load all modules under backend.blocks
AVAILABLE_MODULES = []
current_dir = Path(__file__).parent
modules = [
    str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
    for f in current_dir.rglob("*.py")
    if f.is_file() and f.name != "__init__.py"
]
for module in modules:
    if not re.match("^[a-z0-9_.]+$", module):
        raise ValueError(
            f"Block module {module} error: module name must be lowercase, "
            "and contain only alphanumeric characters and underscores."
        )

    importlib.import_module(f".{module}", package=__name__)
    AVAILABLE_MODULES.append(module)

# Load all Block instances from the available modules
AVAILABLE_BLOCKS: dict[str, Type[Block]] = {}

if TYPE_CHECKING:
    from backend.data.block import Block

T = TypeVar("T")


_AVAILABLE_BLOCKS: dict[str, type["Block"]] = {}


def load_all_blocks() -> dict[str, type["Block"]]:
    from backend.data.block import Block

    if _AVAILABLE_BLOCKS:
        return _AVAILABLE_BLOCKS

    # Dynamically load all modules under backend.blocks
    AVAILABLE_MODULES = []
    current_dir = Path(__file__).parent
    modules = [
        str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
        for f in current_dir.rglob("*.py")
        if f.is_file() and f.name != "__init__.py"
    ]
    for module in modules:
        if not re.match("^[a-z0-9_.]+$", module):
            raise ValueError(
                f"Block module {module} error: module name must be lowercase, "
                "and contain only alphanumeric characters and underscores."
            )

        importlib.import_module(f".{module}", package=__name__)
        AVAILABLE_MODULES.append(module)

    # Load all Block instances from the available modules
    for block_cls in all_subclasses(Block):
        class_name = block_cls.__name__

        if class_name.endswith("Base"):
            continue

        if not class_name.endswith("Block"):
            raise ValueError(
                f"Block class {class_name} does not end with 'Block'. "
                "If you are creating an abstract class, "
                "please name the class with 'Base' at the end"
            )

        block = block_cls.create()

        if not isinstance(block.id, str) or len(block.id) != 36:
            raise ValueError(
                f"Block ID {block.name} error: {block.id} is not a valid UUID"
            )

        if block.id in _AVAILABLE_BLOCKS:
            raise ValueError(
                f"Block ID {block.name} error: {block.id} is already in use"
            )

        input_schema = block.input_schema.model_fields
        output_schema = block.output_schema.model_fields

        # Make sure `error` field is a string in the output schema
        if "error" in output_schema and output_schema["error"].annotation is not str:
            raise ValueError(
                f"{block.name} `error` field in output_schema must be a string"
            )

        # Ensure all fields in input_schema and output_schema are annotated SchemaFields
        for field_name, field in [*input_schema.items(), *output_schema.items()]:
            if field.annotation is None:
                raise ValueError(
                    f"{block.name} has a field {field_name} that is not annotated"
                )
            if field.json_schema_extra is None:
                raise ValueError(
                    f"{block.name} has a field {field_name} not defined as SchemaField"
                )

        for field in block.input_schema.model_fields.values():
            if field.annotation is bool and field.default not in (True, False):
                raise ValueError(
                    f"{block.name} has a boolean field with no default value"
                )

        _AVAILABLE_BLOCKS[block.id] = block_cls

    return _AVAILABLE_BLOCKS


__all__ = ["load_all_blocks"]


def all_subclasses(cls: type[T]) -> list[type[T]]:
def all_subclasses(cls: Type[T]) -> list[Type[T]]:
    subclasses = cls.__subclasses__()
    for subclass in subclasses:
        subclasses += all_subclasses(subclass)
    return subclasses


for block_cls in all_subclasses(Block):
    name = block_cls.__name__

    if block_cls.__name__.endswith("Base"):
        continue

    if not block_cls.__name__.endswith("Block"):
        raise ValueError(
            f"Block class {block_cls.__name__} does not end with 'Block', If you are creating an abstract class, please name the class with 'Base' at the end"
        )

    block = block_cls.create()

    if not isinstance(block.id, str) or len(block.id) != 36:
        raise ValueError(f"Block ID {block.name} error: {block.id} is not a valid UUID")

    if block.id in AVAILABLE_BLOCKS:
        raise ValueError(f"Block ID {block.name} error: {block.id} is already in use")

    input_schema = block.input_schema.model_fields
    output_schema = block.output_schema.model_fields

    # Make sure `error` field is a string in the output schema
    if "error" in output_schema and output_schema["error"].annotation is not str:
        raise ValueError(
            f"{block.name} `error` field in output_schema must be a string"
        )

    # Make sure all fields in input_schema and output_schema are annotated and has a value
    for field_name, field in [*input_schema.items(), *output_schema.items()]:
        if field.annotation is None:
            raise ValueError(
                f"{block.name} has a field {field_name} that is not annotated"
            )
        if field.json_schema_extra is None:
            raise ValueError(
                f"{block.name} has a field {field_name} not defined as SchemaField"
            )

    for field in block.input_schema.model_fields.values():
        if field.annotation is bool and field.default not in (True, False):
            raise ValueError(f"{block.name} has a boolean field with no default value")

    if block.disabled:
        continue

    AVAILABLE_BLOCKS[block.id] = block_cls

__all__ = ["AVAILABLE_MODULES", "AVAILABLE_BLOCKS"]

@@ -1,5 +1,6 @@
import logging
from typing import Any

from autogpt_libs.utils.cache import thread_cached

from backend.data.block import (
    Block,

@@ -12,11 +13,25 @@ from backend.data.block import (
)
from backend.data.execution import ExecutionStatus
from backend.data.model import SchemaField
from backend.util import json

logger = logging.getLogger(__name__)


@thread_cached
def get_executor_manager_client():
    from backend.executor import ExecutionManager
    from backend.util.service import get_service_client

    return get_service_client(ExecutionManager)


@thread_cached
def get_event_bus():
    from backend.data.execution import RedisExecutionEventBus

    return RedisExecutionEventBus()


class AgentExecutorBlock(Block):
    class Input(BlockSchema):
        user_id: str = SchemaField(description="User ID")

@@ -27,23 +42,6 @@ class AgentExecutorBlock(Block):
        input_schema: dict = SchemaField(description="Input schema for the graph")
        output_schema: dict = SchemaField(description="Output schema for the graph")

        @classmethod
        def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
            return data.get("input_schema", {})

        @classmethod
        def get_input_defaults(cls, data: BlockInput) -> BlockInput:
            return data.get("data", {})

        @classmethod
        def get_missing_input(cls, data: BlockInput) -> set[str]:
            required_fields = cls.get_input_schema(data).get("required", [])
            return set(required_fields) - set(data)

        @classmethod
        def get_mismatch_error(cls, data: BlockInput) -> str | None:
            return json.validate_with_jsonschema(cls.get_input_schema(data), data)

    class Output(BlockSchema):
        pass

@@ -58,26 +56,26 @@ class AgentExecutorBlock(Block):
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        from backend.data.execution import ExecutionEventType
        from backend.executor import utils as execution_utils
        executor_manager = get_executor_manager_client()
        event_bus = get_event_bus()

        event_bus = execution_utils.get_execution_event_bus()

        graph_exec = execution_utils.add_graph_execution(
        graph_exec = executor_manager.add_execution(
            graph_id=input_data.graph_id,
            graph_version=input_data.graph_version,
            user_id=input_data.user_id,
            inputs=input_data.data,
            data=input_data.data,
        )
        log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.id}"
        log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.graph_exec_id}"
        logger.info(f"Starting execution of {log_id}")

        for event in event_bus.listen(
            user_id=graph_exec.user_id,
            graph_id=graph_exec.graph_id,
            graph_exec_id=graph_exec.id,
            graph_id=graph_exec.graph_id, graph_exec_id=graph_exec.graph_exec_id
        ):
            if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
                logger.info(
                    f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
                )

            if not event.node_id:
                if event.status in [
                    ExecutionStatus.COMPLETED,
                    ExecutionStatus.TERMINATED,

@@ -88,10 +86,6 @@ class AgentExecutorBlock(Block):
                else:
                    continue

            logger.debug(
                f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
            )

            if not event.block_id:
                logger.warning(f"{log_id} received event without block_id {event}")
                continue

@@ -106,7 +100,5 @@ class AgentExecutorBlock(Block):
                continue

            for output_data in event.output_data.get("output", []):
                logger.debug(
                    f"Execution {log_id} produced {output_name}: {output_data}"
                )
                logger.info(f"Execution {log_id} produced {output_name}: {output_data}")
                yield output_name, output_data
@@ -1,7 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
@@ -143,12 +143,11 @@ class ContactEmail(BaseModel):
class EmploymentHistory(BaseModel):
    """An employment history in Apollo"""

    model_config = ConfigDict(
        extra="allow",
        arbitrary_types_allowed=True,
        from_attributes=True,
        populate_by_name=True,
    )
    class Config:
        extra = "allow"
        arbitrary_types_allowed = True
        from_attributes = True
        populate_by_name = True

    _id: Optional[str] = None
    created_at: Optional[str] = None
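This hunk (and the matching ones below for Pagination, Organization, Contact, and SearchPeopleResponse) is the Pydantic v1 inner `class Config` versus the v2 `model_config = ConfigDict(...)` spelling of the same settings. A small self-contained Pydantic v2 example of what those flags do:

# Pydantic v2 spelling of the model settings shown above; ConfigDict replaces
# the v1-style inner `class Config`.
from typing import Optional

from pydantic import BaseModel, ConfigDict


class EmploymentHistoryExample(BaseModel):
    model_config = ConfigDict(
        extra="allow",               # keep unknown keys from the API response
        arbitrary_types_allowed=True,
        from_attributes=True,        # replaces v1 orm_mode / from_orm
        populate_by_name=True,       # accept both field names and aliases
    )

    created_at: Optional[str] = None


eh = EmploymentHistoryExample(created_at="2024-01-01", organization_name="Acme")
print(eh.model_dump())  # extra="allow" keeps organization_name in the dump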
@@ -189,12 +188,11 @@ class TypedCustomField(BaseModel):
|
||||
class Pagination(BaseModel):
|
||||
"""Pagination in Apollo"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="allow",
|
||||
arbitrary_types_allowed=True,
|
||||
from_attributes=True,
|
||||
populate_by_name=True,
|
||||
)
|
||||
class Config:
|
||||
extra = "allow" # Allow extra fields
|
||||
arbitrary_types_allowed = True # Allow any type
|
||||
from_attributes = True # Allow from_orm
|
||||
populate_by_name = True # Allow field aliases to work both ways
|
||||
|
||||
page: int = 0
|
||||
per_page: int = 0
|
||||
@@ -232,12 +230,11 @@ class PhoneNumber(BaseModel):
|
||||
class Organization(BaseModel):
|
||||
"""An organization in Apollo"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="allow",
|
||||
arbitrary_types_allowed=True,
|
||||
from_attributes=True,
|
||||
populate_by_name=True,
|
||||
)
|
||||
class Config:
|
||||
extra = "allow"
|
||||
arbitrary_types_allowed = True
|
||||
from_attributes = True
|
||||
populate_by_name = True
|
||||
|
||||
id: Optional[str] = "N/A"
|
||||
name: Optional[str] = "N/A"
|
||||
@@ -271,12 +268,11 @@ class Organization(BaseModel):
|
||||
class Contact(BaseModel):
|
||||
"""A contact in Apollo"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="allow",
|
||||
arbitrary_types_allowed=True,
|
||||
from_attributes=True,
|
||||
populate_by_name=True,
|
||||
)
|
||||
class Config:
|
||||
extra = "allow"
|
||||
arbitrary_types_allowed = True
|
||||
from_attributes = True
|
||||
populate_by_name = True
|
||||
|
||||
contact_roles: list[Any] = []
|
||||
id: Optional[str] = None
|
||||
@@ -373,14 +369,14 @@ If a company has several office locations, results are still based on the headqu
|
||||
|
||||
To exclude companies based on location, use the organization_not_locations parameter.
|
||||
""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
organizations_not_locations: list[str] = SchemaField(
|
||||
description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.
|
||||
|
||||
This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
|
||||
""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
q_organization_keyword_tags: list[str] = SchemaField(
|
||||
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry."""
|
||||
@@ -394,7 +390,7 @@ If the value you enter for this parameter does not match with a company's name,
|
||||
description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
|
||||
|
||||
To find IDs, identify the values for organization_id when you call this endpoint.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
max_results: int = SchemaField(
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
|
||||
@@ -447,14 +443,14 @@ Results also include job titles with the same terms, even if they are not exact
|
||||
|
||||
Use this parameter in combination with the person_seniorities[] parameter to find people based on specific job functions and seniority levels.
|
||||
""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
placeholder="marketing manager",
|
||||
)
|
||||
person_locations: list[str] = SchemaField(
|
||||
description="""The location where people live. You can search across cities, US states, and countries.
|
||||
|
||||
To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
person_seniorities: list[SenorityLevels] = SchemaField(
|
||||
description="""The job seniority that people hold within their current employer. This enables you to find people that currently hold positions at certain reporting levels, such as Director level or senior IC level.
|
||||
@@ -464,7 +460,7 @@ For a person to be included in search results, they only need to match 1 of the
|
||||
Searches only return results based on their current job title, so searching for Director-level employees only returns people that currently hold a Director-level title. If someone was previously a Director, but is currently a VP, they would not be included in your search results.
|
||||
|
||||
Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
organization_locations: list[str] = SchemaField(
|
||||
description="""The location of the company headquarters for a person's current employer. You can search across cities, US states, and countries.
|
||||
@@ -472,7 +468,7 @@ Use this parameter in combination with the person_titles[] parameter to find peo
|
||||
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.
|
||||
|
||||
To find people based on their personal location, use the person_locations parameter.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
q_organization_domains: list[str] = SchemaField(
|
||||
description="""The domain name for the person's employer. This can be the current employer or a previous employer. Do not include www., the @ symbol, or similar.
|
||||
@@ -480,23 +476,23 @@ To find people based on their personal location, use the person_locations parame
|
||||
You can add multiple domains to search across companies.
|
||||
|
||||
Examples: apollo.io and microsoft.com""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
|
||||
description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
organization_ids: list[str] = SchemaField(
|
||||
description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
|
||||
|
||||
To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
q_keywords: str = SchemaField(
|
||||
description="""A string of words over which we want to filter the results""",
|
||||
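Throughout these schema hunks, one side declares list fields with `default=[]` and the other with `default_factory=list`. The usual Python motivation for a factory is to avoid handing every instance the same mutable object; whether SchemaField itself copies plain defaults is not visible in this diff, so the dataclass below only illustrates the general pattern:

# Why mutable defaults are usually expressed as factories: with a plain
# dataclass, a shared `default=[]` would be one list reused by every instance,
# so the field is declared with default_factory instead.
from dataclasses import dataclass, field


@dataclass
class Query:
    # field(default_factory=list) gives every Query its own empty list.
    organization_ids: list[str] = field(default_factory=list)


a, b = Query(), Query()
a.organization_ids.append("123")
print(b.organization_ids)  # [] -- not shared with `a`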
@@ -526,12 +522,11 @@ Use the page parameter to search the different pages of data.""",
|
||||
class SearchPeopleResponse(BaseModel):
|
||||
"""Response from Apollo's search people API"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="allow",
|
||||
arbitrary_types_allowed=True,
|
||||
from_attributes=True,
|
||||
populate_by_name=True,
|
||||
)
|
||||
class Config:
|
||||
extra = "allow" # Allow extra fields
|
||||
arbitrary_types_allowed = True # Allow any type
|
||||
from_attributes = True # Allow from_orm
|
||||
populate_by_name = True # Allow field aliases to work both ways
|
||||
|
||||
breadcrumbs: list[Breadcrumb] = []
|
||||
partial_results_only: bool = True
|
||||
|
||||
@@ -32,18 +32,18 @@ If a company has several office locations, results are still based on the headqu
|
||||
|
||||
To exclude companies based on location, use the organization_not_locations parameter.
|
||||
""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
organizations_not_locations: list[str] = SchemaField(
|
||||
description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.
|
||||
|
||||
This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
|
||||
""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
q_organization_keyword_tags: list[str] = SchemaField(
|
||||
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
q_organization_name: str = SchemaField(
|
||||
description="""Filter search results to include a specific company name.
|
||||
@@ -56,7 +56,7 @@ If the value you enter for this parameter does not match with a company's name,
|
||||
description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
|
||||
|
||||
To find IDs, identify the values for organization_id when you call this endpoint.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
max_results: int = SchemaField(
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
|
||||
@@ -72,7 +72,7 @@ To find IDs, identify the values for organization_id when you call this endpoint
|
||||
class Output(BlockSchema):
|
||||
organizations: list[Organization] = SchemaField(
|
||||
description="List of organizations found",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
organization: Organization = SchemaField(
|
||||
description="Each found organization, one at a time",
|
||||
|
||||
@@ -26,14 +26,14 @@ class SearchPeopleBlock(Block):
|
||||
|
||||
Use this parameter in combination with the person_seniorities[] parameter to find people based on specific job functions and seniority levels.
|
||||
""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
person_locations: list[str] = SchemaField(
|
||||
description="""The location where people live. You can search across cities, US states, and countries.
|
||||
|
||||
To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
person_seniorities: list[SenorityLevels] = SchemaField(
|
||||
@@ -44,7 +44,7 @@ class SearchPeopleBlock(Block):
|
||||
Searches only return results based on their current job title, so searching for Director-level employees only returns people that currently hold a Director-level title. If someone was previously a Director, but is currently a VP, they would not be included in your search results.
|
||||
|
||||
Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
organization_locations: list[str] = SchemaField(
|
||||
@@ -53,7 +53,7 @@ class SearchPeopleBlock(Block):
|
||||
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.
|
||||
|
||||
To find people based on their personal location, use the person_locations parameter.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
q_organization_domains: list[str] = SchemaField(
|
||||
@@ -62,26 +62,26 @@ class SearchPeopleBlock(Block):
|
||||
You can add multiple domains to search across companies.
|
||||
|
||||
Examples: apollo.io and microsoft.com""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
|
||||
description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
organization_ids: list[str] = SchemaField(
|
||||
description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
|
||||
|
||||
To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
q_keywords: str = SchemaField(
|
||||
@@ -104,7 +104,7 @@ class SearchPeopleBlock(Block):
|
||||
class Output(BlockSchema):
|
||||
people: list[Contact] = SchemaField(
|
||||
description="List of people found",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
person: Contact = SchemaField(
|
||||
description="Each found person, one at a time",
|
||||
|
||||
@@ -3,20 +3,22 @@ from typing import Any, List

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
from backend.data.model import SchemaField
from backend.util import json
from backend.util.file import store_media_file
from backend.util.file import MediaFile, store_media_file
from backend.util.mock import MockObject
from backend.util.type import MediaFileType, convert
from backend.util.text import TextFormatter
from backend.util.type import convert

formatter = TextFormatter()


class FileStoreBlock(Block):
    class Input(BlockSchema):
        file_in: MediaFileType = SchemaField(
        file_in: MediaFile = SchemaField(
            description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
        )

    class Output(BlockSchema):
        file_out: MediaFileType = SchemaField(
        file_out: MediaFile = SchemaField(
            description="The relative path to the stored file in the temporary directory."
        )

@@ -88,6 +90,29 @@ class StoreValueBlock(Block):
        yield "output", input_data.data or input_data.input


class PrintToConsoleBlock(Block):
    class Input(BlockSchema):
        text: str = SchemaField(description="The text to print to the console.")

    class Output(BlockSchema):
        status: str = SchemaField(description="The status of the print operation.")

    def __init__(self):
        super().__init__(
            id="f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c",
            description="Print the given text to the console, this is used for a debugging purpose.",
            categories={BlockCategory.BASIC},
            input_schema=PrintToConsoleBlock.Input,
            output_schema=PrintToConsoleBlock.Output,
            test_input={"text": "Hello, World!"},
            test_output=("status", "printed"),
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        print(">>>>> Print: ", input_data.text)
        yield "status", "printed"

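PrintToConsoleBlock.run() is a generator of (output_name, value) pairs, like the other blocks in this file. A rough usage sketch of the calling convention only (constructing blocks outside the platform may need additional setup, so treat this as illustrative):

# Illustrative only: exercise a single block by consuming its run() generator.
block = PrintToConsoleBlock()
for name, value in block.run(PrintToConsoleBlock.Input(text="Hello, World!")):
    print(name, value)  # prints "status printed" after the block prints the text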
class FindInDictionaryBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
input: Any = SchemaField(description="Dictionary to lookup from")
|
||||
@@ -128,9 +153,6 @@ class FindInDictionaryBlock(Block):
|
||||
obj = input_data.input
|
||||
key = input_data.key
|
||||
|
||||
if isinstance(obj, str):
|
||||
obj = json.loads(obj)
|
||||
|
||||
if isinstance(obj, dict) and key in obj:
|
||||
yield "output", obj[key]
|
||||
elif isinstance(obj, list) and isinstance(key, int) and 0 <= key < len(obj):
|
||||
@@ -148,10 +170,192 @@ class FindInDictionaryBlock(Block):
|
||||
yield "missing", input_data.input
|
||||
|
||||
|
||||
class AgentInputBlock(Block):
|
||||
"""
|
||||
This block is used to provide input to the graph.
|
||||
|
||||
It takes in a value, name, description, default values list and bool to limit selection to default values.
|
||||
|
||||
It Outputs the value passed as input.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
name: str = SchemaField(description="The name of the input.")
|
||||
value: Any = SchemaField(
|
||||
description="The value to be passed as input.",
|
||||
default=None,
|
||||
)
|
||||
title: str | None = SchemaField(
|
||||
description="The title of the input.", default=None, advanced=True
|
||||
)
|
||||
description: str | None = SchemaField(
|
||||
description="The description of the input.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
placeholder_values: List[Any] = SchemaField(
|
||||
description="The placeholder values to be passed as input.",
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
limit_to_placeholder_values: bool = SchemaField(
|
||||
description="Whether to limit the selection to placeholder values.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
advanced: bool = SchemaField(
|
||||
description="Whether to show the input in the advanced section, if the field is not required.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
secret: bool = SchemaField(
|
||||
description="Whether the input should be treated as a secret.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: Any = SchemaField(description="The value passed as input.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
|
||||
description="This block is used to provide input to the graph.",
|
||||
input_schema=AgentInputBlock.Input,
|
||||
output_schema=AgentInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"name": "input_1",
|
||||
"description": "This is a test input.",
|
||||
"placeholder_values": [],
|
||||
"limit_to_placeholder_values": False,
|
||||
},
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"name": "input_2",
|
||||
"description": "This is a test input.",
|
||||
"placeholder_values": ["Hello, World!"],
|
||||
"limit_to_placeholder_values": True,
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", "Hello, World!"),
|
||||
("result", "Hello, World!"),
|
||||
],
|
||||
categories={BlockCategory.INPUT, BlockCategory.BASIC},
|
||||
block_type=BlockType.INPUT,
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "result", input_data.value
|
||||
|
||||
|
||||
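AgentInputBlock carries placeholder_values and limit_to_placeholder_values; in the removed basic.py further down, the same idea surfaces as an `enum` injected into the value's JSON schema via generate_schema(). A standalone sketch of that transformation:

# Sketch of turning placeholder values into a JSON-schema enum, mirroring the
# generate_schema() helper visible in the removed basic.py later in this diff.
import copy


def with_placeholders(value_schema: dict, placeholder_values: list) -> dict:
    schema = copy.deepcopy(value_schema)
    if placeholder_values:
        schema["enum"] = placeholder_values  # a UI can render this as a dropdown
    return schema


print(with_placeholders({"type": "string"}, ["Hello, World!"]))
# {'type': 'string', 'enum': ['Hello, World!']}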
class AgentOutputBlock(Block):
|
||||
"""
|
||||
Records the output of the graph for users to see.
|
||||
|
||||
Behavior:
|
||||
If `format` is provided and the `value` is of a type that can be formatted,
|
||||
the block attempts to format the recorded_value using the `format`.
|
||||
If formatting fails or no `format` is provided, the raw `value` is output.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
value: Any = SchemaField(
|
||||
description="The value to be recorded as output.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
)
|
||||
name: str = SchemaField(description="The name of the output.")
|
||||
title: str | None = SchemaField(
|
||||
description="The title of the output.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
description: str | None = SchemaField(
|
||||
description="The description of the output.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
format: str = SchemaField(
|
||||
description="The format string to be used to format the recorded_value. Use Jinja2 syntax.",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
advanced: bool = SchemaField(
|
||||
description="Whether to treat the output as advanced.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
secret: bool = SchemaField(
|
||||
description="Whether the output should be treated as a secret.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: Any = SchemaField(description="The value recorded as output.")
|
||||
name: Any = SchemaField(description="The name of the value recorded as output.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="363ae599-353e-4804-937e-b2ee3cef3da4",
|
||||
description="Stores the output of the graph for users to see.",
|
||||
input_schema=AgentOutputBlock.Input,
|
||||
output_schema=AgentOutputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"name": "output_1",
|
||||
"description": "This is a test output.",
|
||||
"format": "{{ output_1 }}!!",
|
||||
},
|
||||
{
|
||||
"value": "42",
|
||||
"name": "output_2",
|
||||
"description": "This is another test output.",
|
||||
"format": "{{ output_2 }}",
|
||||
},
|
||||
{
|
||||
"value": MockObject(value="!!", key="key"),
|
||||
"name": "output_3",
|
||||
"description": "This is a test output with a mock object.",
|
||||
"format": "{{ output_3 }}",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("output", "Hello, World!!!"),
|
||||
("output", "42"),
|
||||
("output", MockObject(value="!!", key="key")),
|
||||
],
|
||||
categories={BlockCategory.OUTPUT, BlockCategory.BASIC},
|
||||
block_type=BlockType.OUTPUT,
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
"""
|
||||
Attempts to format the recorded_value using the fmt_string if provided.
|
||||
If formatting fails or no fmt_string is given, returns the original recorded_value.
|
||||
"""
|
||||
if input_data.format:
|
||||
try:
|
||||
yield "output", formatter.format_string(
|
||||
input_data.format, {input_data.name: input_data.value}
|
||||
)
|
||||
except Exception as e:
|
||||
yield "output", f"Error: {e}, {input_data.value}"
|
||||
else:
|
||||
yield "output", input_data.value
|
||||
yield "name", input_data.name
|
||||
|
||||
|
||||
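AgentOutputBlock.run() renders `format` with formatter.format_string(format, {name: value}), and the field description says Jinja2 syntax. Assuming TextFormatter is essentially a Jinja2 wrapper (an assumption; its definition is not in this diff), the formatting step reduces to:

# Equivalent of the format step above using jinja2 directly; assumes the
# platform's TextFormatter wraps a Jinja2 environment.
from jinja2 import Environment

env = Environment()


def format_string(fmt: str, values: dict) -> str:
    return env.from_string(fmt).render(**values)


print(format_string("{{ output_1 }}!!", {"output_1": "Hello, World!"}))
# -> Hello, World!!!   (matches the block's first test case)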
class AddToDictionaryBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
dictionary: dict[Any, Any] = SchemaField(
|
||||
default_factory=dict,
|
||||
default={},
|
||||
description="The dictionary to add the entry to. If not provided, a new dictionary will be created.",
|
||||
)
|
||||
key: str = SchemaField(
|
||||
@@ -167,7 +371,7 @@ class AddToDictionaryBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
entries: dict[Any, Any] = SchemaField(
|
||||
default_factory=dict,
|
||||
default={},
|
||||
description="The entries to add to the dictionary. This is the batch version of the `key` and `value` fields.",
|
||||
advanced=True,
|
||||
)
|
||||
@@ -229,7 +433,7 @@ class AddToDictionaryBlock(Block):
|
||||
class AddToListBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
list: List[Any] = SchemaField(
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
description="The list to add the entry to. If not provided, a new list will be created.",
|
||||
)
|
||||
@@ -239,7 +443,7 @@ class AddToListBlock(Block):
|
||||
default=None,
|
||||
)
|
||||
entries: List[Any] = SchemaField(
|
||||
default_factory=lambda: list(),
|
||||
default=[],
|
||||
description="The entries to add to the list. This is the batch version of the `entry` field.",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
@@ -55,7 +55,7 @@ class CodeExecutionBlock(Block):
|
||||
"These commands are executed with `sh`, in the foreground."
|
||||
),
|
||||
placeholder="pip install cowsay",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
@@ -207,7 +207,7 @@ class InstantiationBlock(Block):
|
||||
"These commands are executed with `sh`, in the foreground."
|
||||
),
|
||||
placeholder="pip install cowsay",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ from backend.data.block import (
|
||||
BlockSchema,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks.compass import CompassWebhookType
|
||||
|
||||
|
||||
@@ -43,7 +42,7 @@ class CompassAITriggerBlock(Block):
|
||||
input_schema=CompassAITriggerBlock.Input,
|
||||
output_schema=CompassAITriggerBlock.Output,
|
||||
webhook_config=BlockManualWebhookConfig(
|
||||
provider=ProviderName.COMPASS,
|
||||
provider="compass",
|
||||
webhook_type=CompassWebhookType.TRANSCRIPTION,
|
||||
),
|
||||
test_input=[
|
||||
|
||||
@@ -34,7 +34,7 @@ class ReadCsvBlock(Block):
|
||||
)
|
||||
skip_columns: list[str] = SchemaField(
|
||||
description="The columns to skip from the start of the row",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
|
||||
@@ -49,9 +49,8 @@ class ExaContentsBlock(Block):
|
||||
class Output(BlockSchema):
|
||||
results: list = SchemaField(
|
||||
description="List of document contents",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
|
||||
@@ -38,11 +38,11 @@ class ExaSearchBlock(Block):
|
||||
)
|
||||
include_domains: List[str] = SchemaField(
|
||||
description="Domains to include in search",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
exclude_domains: List[str] = SchemaField(
|
||||
description="Domains to exclude from search",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
start_crawl_date: datetime = SchemaField(
|
||||
@@ -59,12 +59,12 @@ class ExaSearchBlock(Block):
|
||||
)
|
||||
include_text: List[str] = SchemaField(
|
||||
description="Text patterns to include",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
exclude_text: List[str] = SchemaField(
|
||||
description="Text patterns to exclude",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
contents: ContentSettings = SchemaField(
|
||||
@@ -76,7 +76,7 @@ class ExaSearchBlock(Block):
|
||||
class Output(BlockSchema):
|
||||
results: list = SchemaField(
|
||||
description="List of search results",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -26,12 +26,12 @@ class ExaFindSimilarBlock(Block):
|
||||
)
|
||||
include_domains: List[str] = SchemaField(
|
||||
description="Domains to include in search",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
exclude_domains: List[str] = SchemaField(
|
||||
description="Domains to exclude from search",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
start_crawl_date: datetime = SchemaField(
|
||||
@@ -48,12 +48,12 @@ class ExaFindSimilarBlock(Block):
|
||||
)
|
||||
include_text: List[str] = SchemaField(
|
||||
description="Text patterns to include (max 1 string, up to 5 words)",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
exclude_text: List[str] = SchemaField(
|
||||
description="Text patterns to exclude (max 1 string, up to 5 words)",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
contents: ContentSettings = SchemaField(
|
||||
@@ -65,7 +65,7 @@ class ExaFindSimilarBlock(Block):
|
||||
class Output(BlockSchema):
|
||||
results: List[Any] = SchemaField(
|
||||
description="List of similar documents with title, URL, published date, author, and score",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -42,7 +42,7 @@ class AIVideoGeneratorBlock(Block):
|
||||
description="Error message if video generation failed."
|
||||
)
|
||||
logs: list[str] = SchemaField(
|
||||
description="Generation progress logs.",
|
||||
description="Generation progress logs.", optional=True
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockManualWebhookConfig,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks.generic import GenericWebhookType
|
||||
|
||||
|
||||
class GenericWebhookTriggerBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
payload: dict = SchemaField(hidden=True, default_factory=dict)
|
||||
constants: dict = SchemaField(
|
||||
description="The constants to be set when the block is put on the graph",
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
payload: dict = SchemaField(
|
||||
description="The complete webhook payload that was received from the generic webhook."
|
||||
)
|
||||
constants: dict = SchemaField(
|
||||
description="The constants to be set when the block is put on the graph"
|
||||
)
|
||||
|
||||
example_payload = {"message": "Hello, World!"}
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8fa8c167-2002-47ce-aba8-97572fc5d387",
|
||||
description="This block will output the contents of the generic input for the webhook.",
|
||||
categories={BlockCategory.INPUT},
|
||||
input_schema=GenericWebhookTriggerBlock.Input,
|
||||
output_schema=GenericWebhookTriggerBlock.Output,
|
||||
webhook_config=BlockManualWebhookConfig(
|
||||
provider=ProviderName.GENERIC_WEBHOOK,
|
||||
webhook_type=GenericWebhookType.PLAIN,
|
||||
),
|
||||
test_input={"constants": {"key": "value"}, "payload": self.example_payload},
|
||||
test_output=[
|
||||
("constants", {"key": "value"}),
|
||||
("payload", self.example_payload),
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "constants", input_data.constants
|
||||
yield "payload", input_data.payload
|
||||
@@ -38,59 +38,6 @@ def _get_headers(credentials: GithubCredentials) -> dict[str, str]:
|
||||
}
|
||||
|
||||
|
||||
def convert_comment_url_to_api_endpoint(comment_url: str) -> str:
|
||||
"""
|
||||
Converts a GitHub comment URL (web interface) to the appropriate API endpoint URL.
|
||||
|
||||
Handles:
|
||||
1. Issue/PR comments: #issuecomment-{id}
|
||||
2. PR review comments: #discussion_r{id}
|
||||
|
||||
Returns the appropriate API endpoint path for the comment.
|
||||
"""
|
||||
# First, check if this is already an API URL
|
||||
parsed_url = urlparse(comment_url)
|
||||
if parsed_url.hostname == "api.github.com":
|
||||
return comment_url
|
||||
|
||||
# Replace pull with issues for comment endpoints
|
||||
if "/pull/" in comment_url:
|
||||
comment_url = comment_url.replace("/pull/", "/issues/")
|
||||
|
||||
# Handle issue/PR comments (#issuecomment-xxx)
|
||||
if "#issuecomment-" in comment_url:
|
||||
base_url, comment_part = comment_url.split("#issuecomment-")
|
||||
comment_id = comment_part
|
||||
|
||||
# Extract repo information from base URL
|
||||
parsed_url = urlparse(base_url)
|
||||
path_parts = parsed_url.path.strip("/").split("/")
|
||||
owner, repo = path_parts[0], path_parts[1]
|
||||
|
||||
# Construct API URL for issue comments
|
||||
return (
|
||||
f"https://api.github.com/repos/{owner}/{repo}/issues/comments/{comment_id}"
|
||||
)
|
||||
|
||||
# Handle PR review comments (#discussion_r)
|
||||
elif "#discussion_r" in comment_url:
|
||||
base_url, comment_part = comment_url.split("#discussion_r")
|
||||
comment_id = comment_part
|
||||
|
||||
# Extract repo information from base URL
|
||||
parsed_url = urlparse(base_url)
|
||||
path_parts = parsed_url.path.strip("/").split("/")
|
||||
owner, repo = path_parts[0], path_parts[1]
|
||||
|
||||
# Construct API URL for PR review comments
|
||||
return (
|
||||
f"https://api.github.com/repos/{owner}/{repo}/pulls/comments/{comment_id}"
|
||||
)
|
||||
|
||||
# If no specific comment identifiers are found, use the general URL conversion
|
||||
return _convert_to_api_url(comment_url)
|
||||
|
||||
|
||||
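For reference, the two comment-URL shapes the helper above handles map as follows (the IDs are made-up examples), shown as a quick usage snippet of the function defined above:

# Expected mappings for the two URL shapes handled by
# convert_comment_url_to_api_endpoint (IDs are made up):
#   https://github.com/owner/repo/pull/1#issuecomment-123
#       -> https://api.github.com/repos/owner/repo/issues/comments/123
#   https://github.com/owner/repo/pull/1#discussion_r456
#       -> https://api.github.com/repos/owner/repo/pulls/comments/456
api_url = convert_comment_url_to_api_endpoint(
    "https://github.com/owner/repo/pull/1#issuecomment-123"
)
print(api_url)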
def get_api(
|
||||
credentials: GithubCredentials | GithubFineGrainedAPICredentials,
|
||||
convert_urls: bool = True,
|
||||
|
||||
@@ -172,9 +172,7 @@ class GithubCreateCheckRunBlock(Block):
|
||||
data.output = output_data
|
||||
|
||||
check_runs_url = f"{repo_url}/check-runs"
|
||||
response = api.post(
|
||||
check_runs_url, data=data.model_dump_json(exclude_none=True)
|
||||
)
|
||||
response = api.post(check_runs_url)
|
||||
result = response.json()
|
||||
|
||||
return {
|
||||
@@ -325,9 +323,7 @@ class GithubUpdateCheckRunBlock(Block):
|
||||
data.output = output_data
|
||||
|
||||
check_run_url = f"{repo_url}/check-runs/{check_run_id}"
|
||||
response = api.patch(
|
||||
check_run_url, data=data.model_dump_json(exclude_none=True)
|
||||
)
|
||||
response = api.patch(check_run_url)
|
||||
result = response.json()
|
||||
|
||||
return {
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
@@ -6,7 +5,7 @@ from typing_extensions import TypedDict
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import convert_comment_url_to_api_endpoint, get_api
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
@@ -15,8 +14,6 @@ from ._auth import (
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_github_url(url: str) -> bool:
|
||||
return urlparse(url).netloc == "github.com"
|
||||
@@ -111,228 +108,6 @@ class GithubCommentBlock(Block):
|
||||
# --8<-- [end:GithubCommentBlockExample]
|
||||
|
||||
|
||||
class GithubUpdateCommentBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
comment_url: str = SchemaField(
|
||||
description="URL of the GitHub comment",
|
||||
placeholder="https://github.com/owner/repo/issues/1#issuecomment-123456789",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
issue_url: str = SchemaField(
|
||||
description="URL of the GitHub issue or pull request",
|
||||
placeholder="https://github.com/owner/repo/issues/1",
|
||||
default="",
|
||||
)
|
||||
comment_id: str = SchemaField(
|
||||
description="ID of the GitHub comment",
|
||||
placeholder="123456789",
|
||||
default="",
|
||||
)
|
||||
comment: str = SchemaField(
|
||||
description="Comment to update",
|
||||
placeholder="Enter your comment",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
id: int = SchemaField(description="ID of the updated comment")
|
||||
url: str = SchemaField(description="URL to the comment on GitHub")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the comment update failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b3f4d747-10e3-4e69-8c51-f2be1d99c9a7",
|
||||
description="This block updates a comment on a specified GitHub issue or pull request.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubUpdateCommentBlock.Input,
|
||||
output_schema=GithubUpdateCommentBlock.Output,
|
||||
test_input={
|
||||
"comment_url": "https://github.com/owner/repo/issues/1#issuecomment-123456789",
|
||||
"comment": "This is an updated comment.",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("id", 123456789),
|
||||
(
|
||||
"url",
|
||||
"https://github.com/owner/repo/issues/1#issuecomment-123456789",
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"update_comment": lambda *args, **kwargs: (
|
||||
123456789,
|
||||
"https://github.com/owner/repo/issues/1#issuecomment-123456789",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_comment(
|
||||
credentials: GithubCredentials, comment_url: str, body_text: str
|
||||
) -> tuple[int, str]:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
data = {"body": body_text}
|
||||
url = convert_comment_url_to_api_endpoint(comment_url)
|
||||
|
||||
logger.info(url)
|
||||
response = api.patch(url, json=data)
|
||||
comment = response.json()
|
||||
return comment["id"], comment["html_url"]
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
if (
|
||||
not input_data.comment_url
|
||||
and input_data.comment_id
|
||||
and input_data.issue_url
|
||||
):
|
||||
parsed_url = urlparse(input_data.issue_url)
|
||||
path_parts = parsed_url.path.strip("/").split("/")
|
||||
owner, repo = path_parts[0], path_parts[1]
|
||||
|
||||
input_data.comment_url = f"https://api.github.com/repos/{owner}/{repo}/issues/comments/{input_data.comment_id}"
|
||||
|
||||
elif (
|
||||
not input_data.comment_url
|
||||
and not input_data.comment_id
|
||||
and input_data.issue_url
|
||||
):
|
||||
raise ValueError(
|
||||
"Must provide either comment_url or comment_id and issue_url"
|
||||
)
|
||||
id, url = self.update_comment(
|
||||
credentials,
|
||||
input_data.comment_url,
|
||||
input_data.comment,
|
||||
)
|
||||
yield "id", id
|
||||
yield "url", url
|
||||
|
||||
|
||||
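update_comment() above reduces to a single PATCH against the GitHub comments endpoint with a JSON body of {"body": ...}. A bare-requests sketch with placeholder token, repo, and comment id:

# Bare-requests equivalent of update_comment(); TOKEN, owner/repo, and the
# comment id are placeholders.
import requests

url = "https://api.github.com/repos/owner/repo/issues/comments/123456789"
resp = requests.patch(
    url,
    headers={
        "Authorization": "Bearer TOKEN",
        "Accept": "application/vnd.github+json",
    },
    json={"body": "This is an updated comment."},
)
comment = resp.json()
print(comment["id"], comment["html_url"])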
class GithubListCommentsBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
issue_url: str = SchemaField(
|
||||
description="URL of the GitHub issue or pull request",
|
||||
placeholder="https://github.com/owner/repo/issues/1",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class CommentItem(TypedDict):
|
||||
id: int
|
||||
body: str
|
||||
user: str
|
||||
url: str
|
||||
|
||||
comment: CommentItem = SchemaField(
|
||||
title="Comment", description="Comments with their ID, body, user, and URL"
|
||||
)
|
||||
comments: list[CommentItem] = SchemaField(
|
||||
description="List of comments with their ID, body, user, and URL"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing comments failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c4b5fb63-0005-4a11-b35a-0c2467bd6b59",
|
||||
description="This block lists all comments for a specified GitHub issue or pull request.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListCommentsBlock.Input,
|
||||
output_schema=GithubListCommentsBlock.Output,
|
||||
test_input={
|
||||
"issue_url": "https://github.com/owner/repo/issues/1",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"comment",
|
||||
{
|
||||
"id": 123456789,
|
||||
"body": "This is a test comment.",
|
||||
"user": "test_user",
|
||||
"url": "https://github.com/owner/repo/issues/1#issuecomment-123456789",
|
||||
},
|
||||
),
|
||||
(
|
||||
"comments",
|
||||
[
|
||||
{
|
||||
"id": 123456789,
|
||||
"body": "This is a test comment.",
|
||||
"user": "test_user",
|
||||
"url": "https://github.com/owner/repo/issues/1#issuecomment-123456789",
|
||||
}
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"list_comments": lambda *args, **kwargs: [
|
||||
{
|
||||
"id": 123456789,
|
||||
"body": "This is a test comment.",
|
||||
"user": "test_user",
|
||||
"url": "https://github.com/owner/repo/issues/1#issuecomment-123456789",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def list_comments(
|
||||
credentials: GithubCredentials, issue_url: str
|
||||
) -> list[Output.CommentItem]:
|
||||
parsed_url = urlparse(issue_url)
|
||||
path_parts = parsed_url.path.strip("/").split("/")
|
||||
|
||||
owner = path_parts[0]
|
||||
repo = path_parts[1]
|
||||
|
||||
# GitHub API uses 'issues' for both issues and pull requests when it comes to comments
|
||||
issue_number = path_parts[3] # Whether 'issues/123' or 'pull/123'
|
||||
|
||||
# Construct the proper API URL directly
|
||||
api_url = f"https://api.github.com/repos/{owner}/{repo}/issues/{issue_number}/comments"
|
||||
|
||||
# Set convert_urls=False since we're already providing an API URL
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
response = api.get(api_url)
|
||||
comments = response.json()
|
||||
parsed_comments: list[GithubListCommentsBlock.Output.CommentItem] = [
|
||||
{
|
||||
"id": comment["id"],
|
||||
"body": comment["body"],
|
||||
"user": comment["user"]["login"],
|
||||
"url": comment["html_url"],
|
||||
}
|
||||
for comment in comments
|
||||
]
|
||||
return parsed_comments
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
comments = self.list_comments(
|
||||
credentials,
|
||||
input_data.issue_url,
|
||||
)
|
||||
yield from (("comment", comment) for comment in comments)
|
||||
yield "comments", comments
|
||||
|
||||
|
||||
class GithubMakeIssueBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
|
||||
@@ -144,7 +144,7 @@ class GithubCreateStatusBlock(Block):
|
||||
data.description = description
|
||||
|
||||
status_url = f"{repo_url}/statuses/{sha}"
|
||||
response = api.post(status_url, data=data.model_dump_json(exclude_none=True))
|
||||
response = api.post(status_url, json=data)
|
||||
result = response.json()
|
||||
|
||||
return {
|
||||
|
||||
@@ -12,7 +12,6 @@ from backend.data.block import (
|
||||
BlockWebhookConfig,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
@@ -37,7 +36,7 @@ class GitHubTriggerBase:
|
||||
placeholder="{owner}/{repo}",
|
||||
)
|
||||
# --8<-- [start:example-payload-field]
|
||||
payload: dict = SchemaField(hidden=True, default_factory=dict)
|
||||
payload: dict = SchemaField(hidden=True, default={})
|
||||
# --8<-- [end:example-payload-field]
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -124,7 +123,7 @@ class GithubPullRequestTriggerBlock(GitHubTriggerBase, Block):
|
||||
output_schema=GithubPullRequestTriggerBlock.Output,
|
||||
# --8<-- [start:example-webhook_config]
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider=ProviderName.GITHUB,
|
||||
provider="github",
|
||||
webhook_type=GithubWebhookType.REPO,
|
||||
resource_format="{repo}",
|
||||
event_filter_input="events",
|
||||
|
||||
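One side of this hunk passes ProviderName.GITHUB, the other the raw string "github". If ProviderName is a str-based Enum (its definition is not part of this diff, so that shape is an assumption), the two spellings compare and serialize the same way:

# Illustration only: with a str-based Enum, the member and the raw string are
# interchangeable for comparison and serialization.
from enum import Enum


class ProviderName(str, Enum):  # assumed shape; not taken from this diff
    GITHUB = "github"
    COMPASS = "compass"


print(ProviderName.GITHUB == "github")  # True
print(ProviderName.GITHUB.value)        # "github"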
@@ -8,7 +8,6 @@ from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from ._auth import (
|
||||
GOOGLE_OAUTH_IS_CONFIGURED,
|
||||
@@ -151,8 +150,8 @@ class GmailReadBlock(Block):
|
||||
else None
|
||||
),
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=Settings().secrets.google_client_id,
|
||||
client_secret=Settings().secrets.google_client_secret,
|
||||
client_id=kwargs.get("client_id"),
|
||||
client_secret=kwargs.get("client_secret"),
|
||||
scopes=credentials.scopes,
|
||||
)
|
||||
return build("gmail", "v1", credentials=creds)
|
||||
|
||||
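Both the Gmail and the Sheets hunks build a google-auth Credentials object and hand it to googleapiclient's build(); the difference between the two sides is only where client_id and client_secret come from (Settings().secrets versus kwargs). A standalone sketch with placeholder values:

# Standalone sketch of the service construction above; all values are
# placeholders and would normally come from stored OAuth credentials.
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build

creds = Credentials(
    token="ACCESS_TOKEN",
    refresh_token="REFRESH_TOKEN",
    token_uri="https://oauth2.googleapis.com/token",
    client_id="CLIENT_ID",
    client_secret="CLIENT_SECRET",
    scopes=["https://www.googleapis.com/auth/gmail.readonly"],
)
service = build("gmail", "v1", credentials=creds)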
@@ -3,7 +3,6 @@ from googleapiclient.discovery import build
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from ._auth import (
|
||||
GOOGLE_OAUTH_IS_CONFIGURED,
|
||||
@@ -87,8 +86,8 @@ class GoogleSheetsReadBlock(Block):
|
||||
else None
|
||||
),
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=Settings().secrets.google_client_id,
|
||||
client_secret=Settings().secrets.google_client_secret,
|
||||
client_id=kwargs.get("client_id"),
|
||||
client_secret=kwargs.get("client_secret"),
|
||||
scopes=credentials.scopes,
|
||||
)
|
||||
return build("sheets", "v4", credentials=creds)
|
||||
|
||||
@@ -1,16 +1,11 @@
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from requests.exceptions import HTTPError, RequestException
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.request import requests
|
||||
|
||||
logger = logging.getLogger(name=__name__)
|
||||
|
||||
|
||||
class HttpMethod(Enum):
|
||||
GET = "GET"
|
||||
@@ -34,7 +29,7 @@ class SendWebRequestBlock(Block):
|
||||
)
|
||||
headers: dict[str, str] = SchemaField(
|
||||
description="The headers to include in the request",
|
||||
default_factory=dict,
|
||||
default={},
|
||||
)
|
||||
json_format: bool = SchemaField(
|
||||
title="JSON format",
|
||||
@@ -48,9 +43,8 @@ class SendWebRequestBlock(Block):
|
||||
|
||||
class Output(BlockSchema):
|
||||
response: object = SchemaField(description="The response from the server")
|
||||
client_error: object = SchemaField(description="Errors on 4xx status codes")
|
||||
server_error: object = SchemaField(description="Errors on 5xx status codes")
|
||||
error: str = SchemaField(description="Errors for all other exceptions")
|
||||
client_error: object = SchemaField(description="The error on 4xx status codes")
|
||||
server_error: object = SchemaField(description="The error on 5xx status codes")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -74,40 +68,20 @@ class SendWebRequestBlock(Block):
|
||||
# we should send it as plain text instead
|
||||
input_data.json_format = False
|
||||
|
||||
try:
|
||||
response = requests.request(
|
||||
input_data.method.value,
|
||||
input_data.url,
|
||||
headers=input_data.headers,
|
||||
json=body if input_data.json_format else None,
|
||||
data=body if not input_data.json_format else None,
|
||||
)
|
||||
result = response.json() if input_data.json_format else response.text
|
||||
response = requests.request(
|
||||
input_data.method.value,
|
||||
input_data.url,
|
||||
headers=input_data.headers,
|
||||
json=body if input_data.json_format else None,
|
||||
data=body if not input_data.json_format else None,
|
||||
)
|
||||
result = response.json() if input_data.json_format else response.text
|
||||
|
||||
if response.status_code // 100 == 2:
|
||||
yield "response", result
|
||||
|
||||
except HTTPError as e:
|
||||
# Handle error responses
|
||||
try:
|
||||
result = e.response.json() if input_data.json_format else str(e)
|
||||
except json.JSONDecodeError:
|
||||
result = str(e)
|
||||
|
||||
if 400 <= e.response.status_code < 500:
|
||||
yield "client_error", result
|
||||
elif 500 <= e.response.status_code < 600:
|
||||
yield "server_error", result
|
||||
else:
|
||||
error_msg = (
|
||||
"Unexpected status code "
|
||||
f"{e.response.status_code} '{e.response.reason}'"
|
||||
)
|
||||
logger.warning(error_msg)
|
||||
yield "error", error_msg
|
||||
|
||||
except RequestException as e:
|
||||
# Handle other request-related exceptions
|
||||
yield "error", str(e)
|
||||
|
||||
except Exception as e:
|
||||
# Catch any other unexpected exceptions
|
||||
yield "error", str(e)
|
||||
elif response.status_code // 100 == 4:
|
||||
yield "client_error", result
|
||||
elif response.status_code // 100 == 5:
|
||||
yield "server_error", result
|
||||
else:
|
||||
raise ValueError(f"Unexpected status code: {response.status_code}")
|
||||
|
||||
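The two variants of SendWebRequestBlock.run() route errors differently: one relies on requests exceptions (HTTPError raised by raise_for_status or an equivalent wrapper, plus RequestException), the other branches on status_code // 100. A sketch of both styles with plain requests (the platform wraps requests in backend.util.request, which is not shown here):

# The two error-routing styles side by side, with plain `requests`.
import requests
from requests.exceptions import HTTPError, RequestException


def by_status(url: str):
    resp = requests.get(url)
    bucket = resp.status_code // 100          # 2 -> ok, 4 -> client, 5 -> server
    if bucket == 2:
        return "response", resp.text
    if bucket == 4:
        return "client_error", resp.text
    if bucket == 5:
        return "server_error", resp.text
    raise ValueError(f"Unexpected status code: {resp.status_code}")


def by_exception(url: str):
    try:
        resp = requests.get(url)
        resp.raise_for_status()                # raises HTTPError on 4xx/5xx
        return "response", resp.text
    except HTTPError as e:
        code = e.response.status_code
        return ("client_error" if code < 500 else "server_error"), str(e)
    except RequestException as e:              # timeouts, DNS failures, ...
        return "error", str(e)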
@@ -15,8 +15,7 @@ class HubSpotCompanyBlock(Block):
|
||||
description="Operation to perform (create, update, get)", default="get"
|
||||
)
|
||||
company_data: dict = SchemaField(
|
||||
description="Company data for create/update operations",
|
||||
default_factory=dict,
|
||||
description="Company data for create/update operations", default={}
|
||||
)
|
||||
domain: str = SchemaField(
|
||||
description="Company domain for get/update operations", default=""
|
||||
|
||||
@@ -15,8 +15,7 @@ class HubSpotContactBlock(Block):
|
||||
description="Operation to perform (create, update, get)", default="get"
|
||||
)
|
||||
contact_data: dict = SchemaField(
|
||||
description="Contact data for create/update operations",
|
||||
default_factory=dict,
|
||||
description="Contact data for create/update operations", default={}
|
||||
)
|
||||
email: str = SchemaField(
|
||||
description="Email address for get/update operations", default=""
|
||||
|
||||
@@ -19,7 +19,7 @@ class HubSpotEngagementBlock(Block):
|
||||
)
|
||||
email_data: dict = SchemaField(
|
||||
description="Email data including recipient, subject, content",
|
||||
default_factory=dict,
|
||||
default={},
|
||||
)
|
||||
contact_id: str = SchemaField(
|
||||
description="Contact ID for engagement tracking", default=""
|
||||
@@ -27,6 +27,7 @@ class HubSpotEngagementBlock(Block):
|
||||
timeframe_days: int = SchemaField(
|
||||
description="Number of days to look back for engagement",
|
||||
default=30,
|
||||
optional=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
|
||||
@@ -142,16 +142,6 @@ class IdeogramModelBlock(Block):
|
||||
title="Color Palette Preset",
|
||||
advanced=True,
|
||||
)
|
||||
custom_color_palette: Optional[list[str]] = SchemaField(
|
||||
description=(
|
||||
"Only available for model version V_2 or V_2_TURBO. Provide one or more color hex codes "
|
||||
"(e.g., ['#000030', '#1C0C47', '#9900FF', '#4285F4', '#FFFFFF']) to define a custom color "
|
||||
"palette. Only used if 'color_palette_name' is 'NONE'."
|
||||
),
|
||||
default=None,
|
||||
title="Custom Color Palette",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: str = SchemaField(description="Generated image URL")
|
||||
@@ -174,13 +164,6 @@ class IdeogramModelBlock(Block):
|
||||
"style_type": StyleType.AUTO,
|
||||
"negative_prompt": None,
|
||||
"color_palette_name": ColorPalettePreset.NONE,
|
||||
"custom_color_palette": [
|
||||
"#000030",
|
||||
"#1C0C47",
|
||||
"#9900FF",
|
||||
"#4285F4",
|
||||
"#FFFFFF",
|
||||
],
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
@@ -190,7 +173,7 @@ class IdeogramModelBlock(Block):
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"run_model": lambda api_key, model_name, prompt, seed, aspect_ratio, magic_prompt_option, style_type, negative_prompt, color_palette_name, custom_colors: "https://ideogram.ai/api/images/test-generated-image-url.png",
|
||||
"run_model": lambda api_key, model_name, prompt, seed, aspect_ratio, magic_prompt_option, style_type, negative_prompt, color_palette_name: "https://ideogram.ai/api/images/test-generated-image-url.png",
|
||||
"upscale_image": lambda api_key, image_url: "https://ideogram.ai/api/images/test-upscaled-image-url.png",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
@@ -212,7 +195,6 @@ class IdeogramModelBlock(Block):
|
||||
style_type=input_data.style_type.value,
|
||||
negative_prompt=input_data.negative_prompt,
|
||||
color_palette_name=input_data.color_palette_name.value,
|
||||
custom_colors=input_data.custom_color_palette,
|
||||
)
|
||||
|
||||
# Step 2: Upscale the image if requested
|
||||
@@ -235,7 +217,6 @@ class IdeogramModelBlock(Block):
|
||||
style_type: str,
|
||||
negative_prompt: Optional[str],
|
||||
color_palette_name: str,
|
||||
custom_colors: Optional[list[str]],
|
||||
):
|
||||
url = "https://api.ideogram.ai/generate"
|
||||
headers = {
|
||||
@@ -260,11 +241,7 @@ class IdeogramModelBlock(Block):
|
||||
data["image_request"]["negative_prompt"] = negative_prompt
|
||||
|
||||
if color_palette_name != "NONE":
|
||||
data["color_palette"] = {"name": color_palette_name}
|
||||
elif custom_colors:
|
||||
data["color_palette"] = {
|
||||
"members": [{"color_hex": color} for color in custom_colors]
|
||||
}
|
||||
data["image_request"]["color_palette"] = {"name": color_palette_name}
|
||||
|
||||
try:
|
||||
response = requests.post(url, json=data, headers=headers)
|
||||
@@ -290,7 +267,9 @@ class IdeogramModelBlock(Block):
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=headers,
|
||||
data={"image_request": "{}"},
|
||||
data={
|
||||
"image_request": "{}", # Empty JSON object
|
||||
},
|
||||
files=files,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,556 +0,0 @@
|
||||
import copy
|
||||
from datetime import date, time
|
||||
from typing import Any, Optional
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.mock import MockObject
|
||||
from backend.util.settings import Config
|
||||
from backend.util.text import TextFormatter
|
||||
from backend.util.type import LongTextType, MediaFileType, ShortTextType
|
||||
|
||||
formatter = TextFormatter()
|
||||
config = Config()
|
||||
|
||||
|
||||
class AgentInputBlock(Block):
|
||||
"""
|
||||
This block is used to provide input to the graph.
|
||||
|
||||
It takes in a value, name, description, default values list and bool to limit selection to default values.
|
||||
|
||||
It Outputs the value passed as input.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
name: str = SchemaField(description="The name of the input.")
|
||||
value: Any = SchemaField(
|
||||
description="The value to be passed as input.",
|
||||
default=None,
|
||||
)
|
||||
title: str | None = SchemaField(
|
||||
description="The title of the input.", default=None, advanced=True
|
||||
)
|
||||
description: str | None = SchemaField(
|
||||
description="The description of the input.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
placeholder_values: list = SchemaField(
|
||||
description="The placeholder values to be passed as input.",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
hidden=True,
|
||||
)
|
||||
advanced: bool = SchemaField(
|
||||
description="Whether to show the input in the advanced section, if the field is not required.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
secret: bool = SchemaField(
|
||||
description="Whether the input should be treated as a secret.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
def generate_schema(self):
|
||||
schema = copy.deepcopy(self.get_field_schema("value"))
|
||||
if possible_values := self.placeholder_values:
|
||||
schema["enum"] = possible_values
|
||||
return schema
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: Any = SchemaField(description="The value passed as input.")
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(
|
||||
**{
|
||||
"id": "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
|
||||
"description": "Base block for user inputs.",
|
||||
"input_schema": AgentInputBlock.Input,
|
||||
"output_schema": AgentInputBlock.Output,
|
||||
"test_input": [
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"name": "input_1",
|
||||
"description": "Example test input.",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"name": "input_2",
|
||||
"description": "Example test input with placeholders.",
|
||||
"placeholder_values": ["Hello, World!"],
|
||||
},
|
||||
],
|
||||
"test_output": [
|
||||
("result", "Hello, World!"),
|
||||
("result", "Hello, World!"),
|
||||
],
|
||||
"categories": {BlockCategory.INPUT, BlockCategory.BASIC},
|
||||
"block_type": BlockType.INPUT,
|
||||
"static_output": True,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, *args, **kwargs) -> BlockOutput:
|
||||
if input_data.value is not None:
|
||||
yield "result", input_data.value
|
||||
|
||||
|
||||
class AgentOutputBlock(Block):
    """
    Records the output of the graph for users to see.

    Behavior:
        If `format` is provided and the `value` is of a type that can be formatted,
        the block attempts to format the recorded_value using the `format`.
        If formatting fails or no `format` is provided, the raw `value` is output.
    """
|
||||
|
||||
class Input(BlockSchema):
|
||||
value: Any = SchemaField(
|
||||
description="The value to be recorded as output.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
)
|
||||
name: str = SchemaField(description="The name of the output.")
|
||||
title: str | None = SchemaField(
|
||||
description="The title of the output.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
description: str | None = SchemaField(
|
||||
description="The description of the output.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
format: str = SchemaField(
|
||||
description="The format string to be used to format the recorded_value. Use Jinja2 syntax.",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
advanced: bool = SchemaField(
|
||||
description="Whether to treat the output as advanced.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
secret: bool = SchemaField(
|
||||
description="Whether the output should be treated as a secret.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
def generate_schema(self):
|
||||
return self.get_field_schema("value")
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: Any = SchemaField(description="The value recorded as output.")
|
||||
name: Any = SchemaField(description="The name of the value recorded as output.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="363ae599-353e-4804-937e-b2ee3cef3da4",
|
||||
description="Stores the output of the graph for users to see.",
|
||||
input_schema=AgentOutputBlock.Input,
|
||||
output_schema=AgentOutputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"name": "output_1",
|
||||
"description": "This is a test output.",
|
||||
"format": "{{ output_1 }}!!",
|
||||
},
|
||||
{
|
||||
"value": "42",
|
||||
"name": "output_2",
|
||||
"description": "This is another test output.",
|
||||
"format": "{{ output_2 }}",
|
||||
},
|
||||
{
|
||||
"value": MockObject(value="!!", key="key"),
|
||||
"name": "output_3",
|
||||
"description": "This is a test output with a mock object.",
|
||||
"format": "{{ output_3 }}",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("output", "Hello, World!!!"),
|
||||
("output", "42"),
|
||||
("output", MockObject(value="!!", key="key")),
|
||||
],
|
||||
categories={BlockCategory.OUTPUT, BlockCategory.BASIC},
|
||||
block_type=BlockType.OUTPUT,
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
    def run(self, input_data: Input, *args, **kwargs) -> BlockOutput:
        """
        Attempts to format the recorded_value using the fmt_string if provided.
        If formatting fails or no fmt_string is given, returns the original recorded_value.
        """
        if input_data.format:
            try:
                yield "output", formatter.format_string(
                    input_data.format, {input_data.name: input_data.value}
                )
            except Exception as e:
                yield "output", f"Error: {e}, {input_data.value}"
        else:
            yield "output", input_data.value
        yield "name", input_data.name
class AgentShortTextInputBlock(AgentInputBlock):
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: Optional[ShortTextType] = SchemaField(
|
||||
description="Short text input.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: str = SchemaField(description="Short text result.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7fcd3bcb-8e1b-4e69-903d-32d3d4a92158",
|
||||
description="Block for short text input (single-line).",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentShortTextInputBlock.Input,
|
||||
output_schema=AgentShortTextInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": "Hello",
|
||||
"name": "short_text_1",
|
||||
"description": "Short text example 1",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
{
|
||||
"value": "Quick test",
|
||||
"name": "short_text_2",
|
||||
"description": "Short text example 2",
|
||||
"placeholder_values": ["Quick test", "Another option"],
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", "Hello"),
|
||||
("result", "Quick test"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class AgentLongTextInputBlock(AgentInputBlock):
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: Optional[LongTextType] = SchemaField(
|
||||
description="Long text input (potentially multi-line).",
|
||||
default=None,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: str = SchemaField(description="Long text result.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="90a56ffb-7024-4b2b-ab50-e26c5e5ab8ba",
|
||||
description="Block for long text input (multi-line).",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentLongTextInputBlock.Input,
|
||||
output_schema=AgentLongTextInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": "Lorem ipsum dolor sit amet...",
|
||||
"name": "long_text_1",
|
||||
"description": "Long text example 1",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
{
|
||||
"value": "Another multiline text input.",
|
||||
"name": "long_text_2",
|
||||
"description": "Long text example 2",
|
||||
"placeholder_values": ["Another multiline text input."],
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", "Lorem ipsum dolor sit amet..."),
|
||||
("result", "Another multiline text input."),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class AgentNumberInputBlock(AgentInputBlock):
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: Optional[int] = SchemaField(
|
||||
description="Number input.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: int = SchemaField(description="Number result.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="96dae2bb-97a2-41c2-bd2f-13a3b5a8ea98",
|
||||
description="Block for number input.",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentNumberInputBlock.Input,
|
||||
output_schema=AgentNumberInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": 42,
|
||||
"name": "number_input_1",
|
||||
"description": "Number example 1",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
{
|
||||
"value": 314,
|
||||
"name": "number_input_2",
|
||||
"description": "Number example 2",
|
||||
"placeholder_values": [314, 2718],
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", 42),
|
||||
("result", 314),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class AgentDateInputBlock(AgentInputBlock):
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: Optional[date] = SchemaField(
|
||||
description="Date input (YYYY-MM-DD).",
|
||||
default=None,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: date = SchemaField(description="Date result.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7e198b09-4994-47db-8b4d-952d98241817",
|
||||
description="Block for date input.",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentDateInputBlock.Input,
|
||||
output_schema=AgentDateInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
# If your system can parse JSON date strings to date objects
|
||||
"value": str(date(2025, 3, 19)),
|
||||
"name": "date_input_1",
|
||||
"description": "Example date input 1",
|
||||
},
|
||||
{
|
||||
"value": str(date(2023, 12, 31)),
|
||||
"name": "date_input_2",
|
||||
"description": "Example date input 2",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", date(2025, 3, 19)),
|
||||
("result", date(2023, 12, 31)),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class AgentTimeInputBlock(AgentInputBlock):
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: Optional[time] = SchemaField(
|
||||
description="Time input (HH:MM:SS).",
|
||||
default=None,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: time = SchemaField(description="Time result.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2a1c757e-86cf-4c7e-aacf-060dc382e434",
|
||||
description="Block for time input.",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentTimeInputBlock.Input,
|
||||
output_schema=AgentTimeInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": str(time(9, 30, 0)),
|
||||
"name": "time_input_1",
|
||||
"description": "Time example 1",
|
||||
},
|
||||
{
|
||||
"value": str(time(23, 59, 59)),
|
||||
"name": "time_input_2",
|
||||
"description": "Time example 2",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", time(9, 30, 0)),
|
||||
("result", time(23, 59, 59)),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class AgentFileInputBlock(AgentInputBlock):
|
||||
"""
|
||||
A simplified file-upload block. In real usage, you might have a custom
|
||||
file type or handle binary data. Here, we'll store a string path as the example.
|
||||
"""
|
||||
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: Optional[MediaFileType] = SchemaField(
|
||||
description="Path or reference to an uploaded file.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: str = SchemaField(description="File reference/path result.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="95ead23f-8283-4654-aef3-10c053b74a31",
|
||||
description="Block for file upload input (string path for example).",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentFileInputBlock.Input,
|
||||
output_schema=AgentFileInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": "data:image/png;base64,MQ==",
|
||||
"name": "file_upload_1",
|
||||
"description": "Example file upload 1",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", str),
|
||||
],
|
||||
)
|
||||
|
||||
    def run(
        self,
        input_data: Input,
        *,
        graph_exec_id: str,
        **kwargs,
    ) -> BlockOutput:
        if not input_data.value:
            return

        file_path = store_media_file(
            graph_exec_id=graph_exec_id,
            file=input_data.value,
            return_content=False,
        )
        yield "result", file_path
class AgentDropdownInputBlock(AgentInputBlock):
|
||||
"""
|
||||
A specialized text input block that relies on placeholder_values to present a dropdown.
|
||||
"""
|
||||
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: Optional[str] = SchemaField(
|
||||
description="Text selected from a dropdown.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
placeholder_values: list = SchemaField(
|
||||
description="Possible values for the dropdown.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
title="Dropdown Options",
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: str = SchemaField(description="Selected dropdown value.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="655d6fdf-a334-421c-b733-520549c07cd1",
|
||||
description="Block for dropdown text selection.",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentDropdownInputBlock.Input,
|
||||
output_schema=AgentDropdownInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": "Option A",
|
||||
"name": "dropdown_1",
|
||||
"placeholder_values": ["Option A", "Option B", "Option C"],
|
||||
"description": "Dropdown example 1",
|
||||
},
|
||||
{
|
||||
"value": "Option C",
|
||||
"name": "dropdown_2",
|
||||
"placeholder_values": ["Option A", "Option B", "Option C"],
|
||||
"description": "Dropdown example 2",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", "Option A"),
|
||||
("result", "Option C"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class AgentToggleInputBlock(AgentInputBlock):
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: bool = SchemaField(
|
||||
description="Boolean toggle input.",
|
||||
default=False,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: bool = SchemaField(description="Boolean toggle result.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="cbf36ab5-df4a-43b6-8a7f-f7ed8652116e",
|
||||
description="Block for boolean toggle input.",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentToggleInputBlock.Input,
|
||||
output_schema=AgentToggleInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": True,
|
||||
"name": "toggle_1",
|
||||
"description": "Toggle example 1",
|
||||
},
|
||||
{
|
||||
"value": False,
|
||||
"name": "toggle_2",
|
||||
"description": "Toggle example 2",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", True),
|
||||
("result", False),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
IO_BLOCK_IDs = [
|
||||
AgentInputBlock().id,
|
||||
AgentOutputBlock().id,
|
||||
AgentShortTextInputBlock().id,
|
||||
AgentLongTextInputBlock().id,
|
||||
AgentNumberInputBlock().id,
|
||||
AgentDateInputBlock().id,
|
||||
AgentTimeInputBlock().id,
|
||||
AgentFileInputBlock().id,
|
||||
AgentDropdownInputBlock().id,
|
||||
AgentToggleInputBlock().id,
|
||||
]
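

def _is_io_block(block: Block) -> bool:
    # Editor's sketch, not part of the original file: a membership check
    # against the id list above is how callers can tell graph I/O blocks apart.
    return block.id in IO_BLOCK_IDs
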
|
||||
@@ -11,13 +11,13 @@ class StepThroughItemsBlock(Block):
|
||||
advanced=False,
|
||||
description="The list or dictionary of items to iterate over",
|
||||
placeholder="[1, 2, 3, 4, 5] or {'key1': 'value1', 'key2': 'value2'}",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
items_object: dict = SchemaField(
|
||||
advanced=False,
|
||||
description="The list or dictionary of items to iterate over",
|
||||
placeholder="[1, 2, 3, 4, 5] or {'key1': 'value1', 'key2': 'value2'}",
|
||||
default_factory=dict,
|
||||
default={},
|
||||
)
|
||||
items_str: str = SchemaField(
|
||||
advanced=False,
|
||||
|
||||
@@ -23,7 +23,7 @@ class JinaChunkingBlock(Block):
|
||||
class Output(BlockSchema):
|
||||
chunks: list = SchemaField(description="List of chunked texts")
|
||||
tokens: list = SchemaField(
|
||||
description="List of token information for each chunk",
|
||||
description="List of token information for each chunk", optional=True
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from urllib.parse import quote
|
||||
from groq._utils._utils import quote
|
||||
|
||||
from backend.blocks.jina._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
|
||||
@@ -28,8 +28,8 @@ class LinearCreateIssueBlock(Block):
|
||||
priority: int | None = SchemaField(
|
||||
description="Priority of the issue",
|
||||
default=None,
|
||||
ge=0,
|
||||
le=4,
|
||||
minimum=0,
|
||||
maximum=4,
|
||||
)
|
||||
project_name: str | None = SchemaField(
|
||||
description="Name of the project to create the issue on",
|
||||
|
||||
@@ -4,24 +4,27 @@ from abc import ABC
|
||||
from enum import Enum, EnumMeta
|
||||
from json import JSONDecodeError
|
||||
from types import MappingProxyType
|
||||
from typing import Any, Iterable, List, Literal, NamedTuple, Optional
|
||||
from typing import TYPE_CHECKING, Any, List, Literal, NamedTuple
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from enum import _EnumMemberT
|
||||
|
||||
import anthropic
|
||||
import ollama
|
||||
import openai
|
||||
from anthropic.types import ToolParam
|
||||
from groq import Groq
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
NodeExecutionStats,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import json
|
||||
from backend.util.settings import BehaveAs, Settings
|
||||
from backend.util.text import TextFormatter
|
||||
@@ -71,10 +74,12 @@ class ModelMetadata(NamedTuple):
|
||||
|
||||
class LlmModelMeta(EnumMeta):
|
||||
@property
|
||||
def __members__(self) -> MappingProxyType:
|
||||
def __members__(
|
||||
self: type["_EnumMemberT"],
|
||||
) -> MappingProxyType[str, "_EnumMemberT"]:
|
||||
if Settings().config.behave_as == BehaveAs.LOCAL:
|
||||
members = super().__members__
|
||||
return MappingProxyType(members)
|
||||
return members
|
||||
else:
|
||||
removed_providers = ["ollama"]
|
||||
existing_members = super().__members__
|
||||
@@ -89,17 +94,14 @@ class LlmModelMeta(EnumMeta):
|
||||
class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
# OpenAI models
|
||||
O3_MINI = "o3-mini"
|
||||
O3 = "o3-2025-04-16"
|
||||
O1 = "o1"
|
||||
O1_PREVIEW = "o1-preview"
|
||||
O1_MINI = "o1-mini"
|
||||
GPT41 = "gpt-4.1-2025-04-14"
|
||||
GPT4O_MINI = "gpt-4o-mini"
|
||||
GPT4O = "gpt-4o"
|
||||
GPT4_TURBO = "gpt-4-turbo"
|
||||
GPT3_5_TURBO = "gpt-3.5-turbo"
|
||||
# Anthropic models
|
||||
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
|
||||
CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest"
|
||||
CLAUDE_3_5_HAIKU = "claude-3-5-haiku-latest"
|
||||
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||
@@ -120,7 +122,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
OLLAMA_DOLPHIN = "dolphin-mistral:latest"
|
||||
# OpenRouter models
|
||||
GEMINI_FLASH_1_5 = "google/gemini-flash-1.5"
|
||||
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
|
||||
GROK_BETA = "x-ai/grok-beta"
|
||||
MISTRAL_NEMO = "mistralai/mistral-nemo"
|
||||
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
|
||||
@@ -138,8 +139,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
|
||||
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
|
||||
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
|
||||
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
|
||||
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
|
||||
|
||||
@property
|
||||
def metadata(self) -> ModelMetadata:
|
||||
@@ -160,14 +159,12 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
|
||||
MODEL_METADATA = {
|
||||
# https://platform.openai.com/docs/models
|
||||
LlmModel.O3: ModelMetadata("openai", 200000, 100000),
|
||||
LlmModel.O3_MINI: ModelMetadata("openai", 200000, 100000), # o3-mini-2025-01-31
|
||||
LlmModel.O1: ModelMetadata("openai", 200000, 100000), # o1-2024-12-17
|
||||
LlmModel.O1_PREVIEW: ModelMetadata(
|
||||
"openai", 128000, 32768
|
||||
), # o1-preview-2024-09-12
|
||||
LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536), # o1-mini-2024-09-12
|
||||
LlmModel.GPT41: ModelMetadata("openai", 1047576, 32768),
|
||||
LlmModel.GPT4O_MINI: ModelMetadata(
|
||||
"openai", 128000, 16384
|
||||
), # gpt-4o-mini-2024-07-18
|
||||
@@ -177,9 +174,6 @@ MODEL_METADATA = {
|
||||
), # gpt-4-turbo-2024-04-09
|
||||
LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, 4096), # gpt-3.5-turbo-0125
|
||||
# https://docs.anthropic.com/en/docs/about-claude/models
|
||||
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 8192
|
||||
), # claude-3-7-sonnet-20250219
|
||||
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 8192
|
||||
), # claude-3-5-sonnet-20241022
|
||||
@@ -205,7 +199,6 @@ MODEL_METADATA = {
|
||||
LlmModel.OLLAMA_DOLPHIN: ModelMetadata("ollama", 32768, None),
|
||||
# https://openrouter.ai/models
|
||||
LlmModel.GEMINI_FLASH_1_5: ModelMetadata("open_router", 1000000, 8192),
|
||||
LlmModel.GEMINI_2_5_PRO: ModelMetadata("open_router", 1050000, 8192),
|
||||
LlmModel.GROK_BETA: ModelMetadata("open_router", 131072, 131072),
|
||||
LlmModel.MISTRAL_NEMO: ModelMetadata("open_router", 128000, 4096),
|
||||
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata("open_router", 128000, 4096),
|
||||
@@ -227,8 +220,6 @@ MODEL_METADATA = {
|
||||
LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata("open_router", 300000, 5120),
|
||||
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata("open_router", 65536, 4096),
|
||||
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata("open_router", 4096, 4096),
|
||||
LlmModel.META_LLAMA_4_SCOUT: ModelMetadata("open_router", 131072, 131072),
|
||||
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata("open_router", 1048576, 1000000),
|
||||
}
|
||||
|
||||
for model in LlmModel:
|
||||
@@ -236,312 +227,21 @@ for model in LlmModel:
|
||||
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
name: str
|
||||
arguments: str
|
||||
class MessageRole(str, Enum):
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
|
||||
|
||||
class ToolContentBlock(BaseModel):
|
||||
id: str
|
||||
type: str
|
||||
function: ToolCall
|
||||
|
||||
|
||||
class LLMResponse(BaseModel):
|
||||
raw_response: Any
|
||||
prompt: List[Any]
|
||||
response: str
|
||||
tool_calls: Optional[List[ToolContentBlock]] | None
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
|
||||
|
||||
def convert_openai_tool_fmt_to_anthropic(
    openai_tools: list[dict] | None = None,
) -> Iterable[ToolParam] | anthropic.NotGiven:
    """
    Convert OpenAI tool format to Anthropic tool format.
    """
    if not openai_tools or len(openai_tools) == 0:
        return anthropic.NOT_GIVEN

    anthropic_tools = []
    for tool in openai_tools:
        if "function" in tool:
            # Handle case where tool is already in OpenAI format with "type" and "function"
            function_data = tool["function"]
        else:
            # Handle case where tool is just the function definition
            function_data = tool

        anthropic_tool: anthropic.types.ToolParam = {
            "name": function_data["name"],
            "description": function_data.get("description", ""),
            "input_schema": {
                "type": "object",
                "properties": function_data.get("parameters", {}).get("properties", {}),
                "required": function_data.get("parameters", {}).get("required", []),
            },
        }
        anthropic_tools.append(anthropic_tool)

    return anthropic_tools
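

# Worked example of the conversion above (editor's illustration; the
# "get_weather" tool is made up). An OpenAI-style tool definition becomes an
# Anthropic ToolParam dict with its parameters mapped to "input_schema".
_openai_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}
assert convert_openai_tool_fmt_to_anthropic([_openai_tool]) == [
    {
        "name": "get_weather",
        "description": "Look up the weather for a city.",
        "input_schema": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    }
]
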
def llm_call(
|
||||
credentials: APIKeyCredentials,
|
||||
llm_model: LlmModel,
|
||||
prompt: list[dict],
|
||||
json_format: bool,
|
||||
max_tokens: int | None,
|
||||
tools: list[dict] | None = None,
|
||||
ollama_host: str = "localhost:11434",
|
||||
parallel_tool_calls: bool | None = None,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Make a call to a language model.
|
||||
|
||||
Args:
|
||||
credentials: The API key credentials to use.
|
||||
llm_model: The LLM model to use.
|
||||
prompt: The prompt to send to the LLM.
|
||||
json_format: Whether the response should be in JSON format.
|
||||
max_tokens: The maximum number of tokens to generate in the chat completion.
|
||||
tools: The tools to use in the chat completion.
|
||||
ollama_host: The host for ollama to use.
|
||||
|
||||
Returns:
|
||||
LLMResponse object containing:
|
||||
- prompt: The prompt sent to the LLM.
|
||||
- response: The text response from the LLM.
|
||||
- tool_calls: Any tool calls the model made, if applicable.
|
||||
- prompt_tokens: The number of tokens used in the prompt.
|
||||
- completion_tokens: The number of tokens used in the completion.
|
||||
"""
|
||||
provider = llm_model.metadata.provider
|
||||
max_tokens = max_tokens or llm_model.max_output_tokens or 4096
|
||||
|
||||
if provider == "openai":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
oai_client = openai.OpenAI(api_key=credentials.api_key.get_secret_value())
|
||||
response_format = None
|
||||
|
||||
if llm_model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
|
||||
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
||||
prompt = [
|
||||
{"role": "user", "content": "\n".join(sys_messages)},
|
||||
{"role": "user", "content": "\n".join(usr_messages)},
|
||||
]
|
||||
elif json_format:
|
||||
response_format = {"type": "json_object"}
|
||||
|
||||
response = oai_client.chat.completions.create(
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_completion_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
parallel_tool_calls=(
|
||||
openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
|
||||
),
|
||||
)
|
||||
|
||||
if response.choices[0].message.tool_calls:
|
||||
tool_calls = [
|
||||
ToolContentBlock(
|
||||
id=tool.id,
|
||||
type=tool.type,
|
||||
function=ToolCall(
|
||||
name=tool.function.name,
|
||||
arguments=tool.function.arguments,
|
||||
),
|
||||
)
|
||||
for tool in response.choices[0].message.tool_calls
|
||||
]
|
||||
else:
|
||||
tool_calls = None
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=response.choices[0].message,
|
||||
prompt=prompt,
|
||||
response=response.choices[0].message.content or "",
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
)
|
||||
elif provider == "anthropic":
|
||||
|
||||
an_tools = convert_openai_tool_fmt_to_anthropic(tools)
|
||||
|
||||
system_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
sysprompt = " ".join(system_messages)
|
||||
|
||||
messages = []
|
||||
last_role = None
|
||||
for p in prompt:
|
||||
if p["role"] in ["user", "assistant"]:
|
||||
if (
|
||||
p["role"] == last_role
|
||||
and isinstance(messages[-1]["content"], str)
|
||||
and isinstance(p["content"], str)
|
||||
):
|
||||
# If the role is the same as the last one, combine the content
|
||||
messages[-1]["content"] += p["content"]
|
||||
else:
|
||||
messages.append({"role": p["role"], "content": p["content"]})
|
||||
last_role = p["role"]
|
||||
|
||||
client = anthropic.Anthropic(api_key=credentials.api_key.get_secret_value())
|
||||
try:
|
||||
resp = client.messages.create(
|
||||
model=llm_model.value,
|
||||
system=sysprompt,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=an_tools,
|
||||
)
|
||||
|
||||
if not resp.content:
|
||||
raise ValueError("No content returned from Anthropic.")
|
||||
|
||||
tool_calls = None
|
||||
for content_block in resp.content:
|
||||
                # Anthropic differs from OpenAI: we need to iterate through
                # the content blocks to find the tool calls.
|
||||
if content_block.type == "tool_use":
|
||||
if tool_calls is None:
|
||||
tool_calls = []
|
||||
tool_calls.append(
|
||||
ToolContentBlock(
|
||||
id=content_block.id,
|
||||
type=content_block.type,
|
||||
function=ToolCall(
|
||||
name=content_block.name,
|
||||
arguments=json.dumps(content_block.input),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls and resp.stop_reason == "tool_use":
|
||||
logger.warning(
|
||||
"Tool use stop reason but no tool calls found in content. %s", resp
|
||||
)
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=resp,
|
||||
prompt=prompt,
|
||||
response=(
|
||||
resp.content[0].name
|
||||
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
|
||||
else getattr(resp.content[0], "text", "")
|
||||
),
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=resp.usage.input_tokens,
|
||||
completion_tokens=resp.usage.output_tokens,
|
||||
)
|
||||
except anthropic.APIError as e:
|
||||
error_message = f"Anthropic API error: {str(e)}"
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
elif provider == "groq":
|
||||
if tools:
|
||||
raise ValueError("Groq does not support tools.")
|
||||
|
||||
client = Groq(api_key=credentials.api_key.get_secret_value())
|
||||
response_format = {"type": "json_object"} if json_format else None
|
||||
response = client.chat.completions.create(
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
return LLMResponse(
|
||||
raw_response=response.choices[0].message,
|
||||
prompt=prompt,
|
||||
response=response.choices[0].message.content or "",
|
||||
tool_calls=None,
|
||||
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
)
|
||||
elif provider == "ollama":
|
||||
if tools:
|
||||
raise ValueError("Ollama does not support tools.")
|
||||
|
||||
client = ollama.Client(host=ollama_host)
|
||||
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
||||
response = client.generate(
|
||||
model=llm_model.value,
|
||||
prompt=f"{sys_messages}\n\n{usr_messages}",
|
||||
stream=False,
|
||||
)
|
||||
return LLMResponse(
|
||||
raw_response=response.get("response") or "",
|
||||
prompt=prompt,
|
||||
response=response.get("response") or "",
|
||||
tool_calls=None,
|
||||
prompt_tokens=response.get("prompt_eval_count") or 0,
|
||||
completion_tokens=response.get("eval_count") or 0,
|
||||
)
|
||||
elif provider == "open_router":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
client = openai.OpenAI(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
extra_headers={
|
||||
"HTTP-Referer": "https://agpt.co",
|
||||
"X-Title": "AutoGPT",
|
||||
},
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
parallel_tool_calls=(
|
||||
openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
|
||||
),
|
||||
)
|
||||
|
||||
# If there's no response, raise an error
|
||||
if not response.choices:
|
||||
if response:
|
||||
raise ValueError(f"OpenRouter error: {response}")
|
||||
else:
|
||||
raise ValueError("No response from OpenRouter.")
|
||||
|
||||
if response.choices[0].message.tool_calls:
|
||||
tool_calls = [
|
||||
ToolContentBlock(
|
||||
id=tool.id,
|
||||
type=tool.type,
|
||||
function=ToolCall(
|
||||
name=tool.function.name, arguments=tool.function.arguments
|
||||
),
|
||||
)
|
||||
for tool in response.choices[0].message.tool_calls
|
||||
]
|
||||
else:
|
||||
tool_calls = None
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=response.choices[0].message,
|
||||
prompt=prompt,
|
||||
response=response.choices[0].message.content or "",
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
class Message(BlockSchema):
|
||||
role: MessageRole
|
||||
content: str
|
||||
|
||||
|
||||
class AIBlockBase(Block, ABC):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.prompt = []
|
||||
self.prompt = ""
|
||||
|
||||
def merge_llm_stats(self, block: "AIBlockBase"):
|
||||
self.merge_stats(block.execution_stats)
|
||||
@@ -560,7 +260,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=LlmModel.GPT4O,
|
||||
default=LlmModel.GPT4_TURBO,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
)
|
||||
@@ -570,8 +270,8 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
default="",
|
||||
description="The system prompt to provide additional context to the model.",
|
||||
)
|
||||
conversation_history: list[dict] = SchemaField(
|
||||
default_factory=list,
|
||||
conversation_history: list[Message] = SchemaField(
|
||||
default=[],
|
||||
description="The conversation history to provide context for the prompt.",
|
||||
)
|
||||
retry: int = SchemaField(
|
||||
@@ -581,7 +281,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
prompt_values: dict[str, str] = SchemaField(
|
||||
advanced=False,
|
||||
default_factory=dict,
|
||||
default={},
|
||||
description="Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}.",
|
||||
)
|
||||
max_tokens: int | None = SchemaField(
|
||||
@@ -600,7 +300,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
response: dict[str, Any] = SchemaField(
|
||||
description="The response object generated by the language model."
|
||||
)
|
||||
prompt: list = SchemaField(description="The prompt sent to the language model.")
|
||||
prompt: str = SchemaField(description="The prompt sent to the language model.")
|
||||
error: str = SchemaField(description="Error message if the API call failed.")
|
||||
|
||||
def __init__(self):
|
||||
@@ -611,7 +311,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
input_schema=AIStructuredResponseGeneratorBlock.Input,
|
||||
output_schema=AIStructuredResponseGeneratorBlock.Output,
|
||||
test_input={
|
||||
"model": LlmModel.GPT4O,
|
||||
"model": LlmModel.GPT4_TURBO,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expected_format": {
|
||||
"key1": "value1",
|
||||
@@ -622,24 +322,22 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("response", {"key1": "key1Value", "key2": "key2Value"}),
|
||||
("prompt", list),
|
||||
("prompt", str),
|
||||
],
|
||||
test_mock={
|
||||
"llm_call": lambda *args, **kwargs: LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[""],
|
||||
response=json.dumps(
|
||||
"llm_call": lambda *args, **kwargs: (
|
||||
json.dumps(
|
||||
{
|
||||
"key1": "key1Value",
|
||||
"key2": "key2Value",
|
||||
}
|
||||
),
|
||||
tool_calls=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
},
|
||||
)
|
||||
self.prompt = ""
|
||||
|
||||
def llm_call(
|
||||
self,
|
||||
@@ -648,29 +346,160 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
prompt: list[dict],
|
||||
json_format: bool,
|
||||
max_tokens: int | None,
|
||||
tools: list[dict] | None = None,
|
||||
ollama_host: str = "localhost:11434",
|
||||
) -> LLMResponse:
|
||||
) -> tuple[str, int, int]:
|
||||
"""
|
||||
        Test mocks work only on class functions; this wraps the llm_call
        function so that it can be mocked within the block testing framework.
|
||||
Args:
|
||||
credentials: The API key credentials to use.
|
||||
llm_model: The LLM model to use.
|
||||
prompt: The prompt to send to the LLM.
|
||||
json_format: Whether the response should be in JSON format.
|
||||
max_tokens: The maximum number of tokens to generate in the chat completion.
|
||||
ollama_host: The host for ollama to use
|
||||
|
||||
Returns:
|
||||
The response from the LLM.
|
||||
The number of tokens used in the prompt.
|
||||
The number of tokens used in the completion.
|
||||
"""
|
||||
self.prompt = prompt
|
||||
return llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=llm_model,
|
||||
prompt=prompt,
|
||||
json_format=json_format,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
ollama_host=ollama_host,
|
||||
)
|
||||
provider = llm_model.metadata.provider
|
||||
max_tokens = max_tokens or llm_model.max_output_tokens or 4096
|
||||
|
||||
if provider == "openai":
|
||||
oai_client = openai.OpenAI(api_key=credentials.api_key.get_secret_value())
|
||||
response_format = None
|
||||
|
||||
if llm_model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
|
||||
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
||||
prompt = [
|
||||
{"role": "user", "content": "\n".join(sys_messages)},
|
||||
{"role": "user", "content": "\n".join(usr_messages)},
|
||||
]
|
||||
elif json_format:
|
||||
response_format = {"type": "json_object"}
|
||||
|
||||
response = oai_client.chat.completions.create(
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_completion_tokens=max_tokens,
|
||||
)
|
||||
self.prompt = json.dumps(prompt)
|
||||
|
||||
return (
|
||||
response.choices[0].message.content or "",
|
||||
response.usage.prompt_tokens if response.usage else 0,
|
||||
response.usage.completion_tokens if response.usage else 0,
|
||||
)
|
||||
elif provider == "anthropic":
|
||||
system_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
sysprompt = " ".join(system_messages)
|
||||
|
||||
messages = []
|
||||
last_role = None
|
||||
for p in prompt:
|
||||
if p["role"] in ["user", "assistant"]:
|
||||
if p["role"] != last_role:
|
||||
messages.append({"role": p["role"], "content": p["content"]})
|
||||
last_role = p["role"]
|
||||
else:
|
||||
# If the role is the same as the last one, combine the content
|
||||
messages[-1]["content"] += "\n" + p["content"]
|
||||
|
||||
client = anthropic.Anthropic(api_key=credentials.api_key.get_secret_value())
|
||||
try:
|
||||
resp = client.messages.create(
|
||||
model=llm_model.value,
|
||||
system=sysprompt,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
self.prompt = json.dumps(prompt)
|
||||
|
||||
if not resp.content:
|
||||
raise ValueError("No content returned from Anthropic.")
|
||||
|
||||
return (
|
||||
(
|
||||
resp.content[0].name
|
||||
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
|
||||
else resp.content[0].text
|
||||
),
|
||||
resp.usage.input_tokens,
|
||||
resp.usage.output_tokens,
|
||||
)
|
||||
except anthropic.APIError as e:
|
||||
error_message = f"Anthropic API error: {str(e)}"
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
elif provider == "groq":
|
||||
client = Groq(api_key=credentials.api_key.get_secret_value())
|
||||
response_format = {"type": "json_object"} if json_format else None
|
||||
response = client.chat.completions.create(
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
self.prompt = json.dumps(prompt)
|
||||
return (
|
||||
response.choices[0].message.content or "",
|
||||
response.usage.prompt_tokens if response.usage else 0,
|
||||
response.usage.completion_tokens if response.usage else 0,
|
||||
)
|
||||
elif provider == "ollama":
|
||||
client = ollama.Client(host=ollama_host)
|
||||
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
||||
response = client.generate(
|
||||
model=llm_model.value,
|
||||
prompt=f"{sys_messages}\n\n{usr_messages}",
|
||||
stream=False,
|
||||
)
|
||||
self.prompt = json.dumps(prompt)
|
||||
return (
|
||||
response.get("response") or "",
|
||||
response.get("prompt_eval_count") or 0,
|
||||
response.get("eval_count") or 0,
|
||||
)
|
||||
elif provider == "open_router":
|
||||
client = openai.OpenAI(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
extra_headers={
|
||||
"HTTP-Referer": "https://agpt.co",
|
||||
"X-Title": "AutoGPT",
|
||||
},
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
self.prompt = json.dumps(prompt)
|
||||
|
||||
# If there's no response, raise an error
|
||||
if not response.choices:
|
||||
if response:
|
||||
raise ValueError(f"OpenRouter error: {response}")
|
||||
else:
|
||||
raise ValueError("No response from OpenRouter.")
|
||||
|
||||
return (
|
||||
response.choices[0].message.content or "",
|
||||
response.usage.prompt_tokens if response.usage else 0,
|
||||
response.usage.completion_tokens if response.usage else 0,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
logger.debug(f"Calling LLM with input data: {input_data}")
|
||||
prompt = [json.to_dict(p) for p in input_data.conversation_history]
|
||||
prompt = [p.model_dump() for p in input_data.conversation_history]
|
||||
|
||||
def trim_prompt(s: str) -> str:
|
||||
lines = s.strip().split("\n")
|
||||
@@ -720,7 +549,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
|
||||
for retry_count in range(input_data.retry):
|
||||
try:
|
||||
llm_response = self.llm_call(
|
||||
response_text, input_token, output_token = self.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=llm_model,
|
||||
prompt=prompt,
|
||||
@@ -728,12 +557,11 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
ollama_host=input_data.ollama_host,
|
||||
max_tokens=input_data.max_tokens,
|
||||
)
|
||||
response_text = llm_response.response
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=llm_response.prompt_tokens,
|
||||
output_token_count=llm_response.completion_tokens,
|
||||
)
|
||||
{
|
||||
"input_token_count": input_token,
|
||||
"output_token_count": output_token,
|
||||
}
|
||||
)
|
||||
logger.info(f"LLM attempt-{retry_count} response: {response_text}")
|
||||
|
||||
@@ -776,10 +604,10 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
retry_prompt = f"Error calling LLM: {e}"
|
||||
finally:
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
llm_call_count=retry_count + 1,
|
||||
llm_retry_count=retry_count,
|
||||
)
|
||||
{
|
||||
"llm_call_count": retry_count + 1,
|
||||
"llm_retry_count": retry_count,
|
||||
}
|
||||
)
|
||||
|
||||
raise RuntimeError(retry_prompt)
|
||||
@@ -793,7 +621,7 @@ class AITextGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=LlmModel.GPT4O,
|
||||
default=LlmModel.GPT4_TURBO,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
)
|
||||
@@ -810,7 +638,7 @@ class AITextGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
prompt_values: dict[str, str] = SchemaField(
|
||||
advanced=False,
|
||||
default_factory=dict,
|
||||
default={},
|
||||
description="Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}.",
|
||||
)
|
||||
ollama_host: str = SchemaField(
|
||||
@@ -828,7 +656,7 @@ class AITextGeneratorBlock(AIBlockBase):
|
||||
response: str = SchemaField(
|
||||
description="The response generated by the language model."
|
||||
)
|
||||
prompt: list = SchemaField(description="The prompt sent to the language model.")
|
||||
prompt: str = SchemaField(description="The prompt sent to the language model.")
|
||||
error: str = SchemaField(description="Error message if the API call failed.")
|
||||
|
||||
def __init__(self):
|
||||
@@ -845,7 +673,7 @@ class AITextGeneratorBlock(AIBlockBase):
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("response", "Response text"),
|
||||
("prompt", list),
|
||||
("prompt", str),
|
||||
],
|
||||
test_mock={"llm_call": lambda *args, **kwargs: "Response text"},
|
||||
)
|
||||
@@ -864,10 +692,7 @@ class AITextGeneratorBlock(AIBlockBase):
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
object_input_data = AIStructuredResponseGeneratorBlock.Input(
|
||||
**{
|
||||
attr: getattr(input_data, attr)
|
||||
for attr in AITextGeneratorBlock.Input.model_fields
|
||||
},
|
||||
**{attr: getattr(input_data, attr) for attr in input_data.model_fields},
|
||||
expected_format={},
|
||||
)
|
||||
yield "response", self.llm_call(object_input_data, credentials)
|
||||
@@ -889,7 +714,7 @@ class AITextSummarizerBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=LlmModel.GPT4O,
|
||||
default=LlmModel.GPT4_TURBO,
|
||||
description="The language model to use for summarizing the text.",
|
||||
)
|
||||
focus: str = SchemaField(
|
||||
@@ -924,7 +749,7 @@ class AITextSummarizerBlock(AIBlockBase):
|
||||
|
||||
class Output(BlockSchema):
|
||||
summary: str = SchemaField(description="The final summary of the text.")
|
||||
prompt: list = SchemaField(description="The prompt sent to the language model.")
|
||||
prompt: str = SchemaField(description="The prompt sent to the language model.")
|
||||
error: str = SchemaField(description="Error message if the API call failed.")
|
||||
|
||||
def __init__(self):
|
||||
@@ -941,7 +766,7 @@ class AITextSummarizerBlock(AIBlockBase):
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("summary", "Final summary of a long text"),
|
||||
("prompt", list),
|
||||
("prompt", str),
|
||||
],
|
||||
test_mock={
|
||||
"llm_call": lambda input_data, credentials: (
|
||||
@@ -1050,18 +875,12 @@ class AITextSummarizerBlock(AIBlockBase):
|
||||
|
||||
class AIConversationBlock(AIBlockBase):
|
||||
class Input(BlockSchema):
|
||||
prompt: str = SchemaField(
|
||||
description="The prompt to send to the language model.",
|
||||
placeholder="Enter your prompt here...",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
messages: List[Any] = SchemaField(
|
||||
description="List of messages in the conversation.",
|
||||
messages: List[Message] = SchemaField(
|
||||
description="List of messages in the conversation.", min_length=1
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=LlmModel.GPT4O,
|
||||
default=LlmModel.GPT4_TURBO,
|
||||
description="The language model to use for the conversation.",
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
@@ -1080,7 +899,7 @@ class AIConversationBlock(AIBlockBase):
|
||||
response: str = SchemaField(
|
||||
description="The model's response to the conversation."
|
||||
)
|
||||
prompt: list = SchemaField(description="The prompt sent to the language model.")
|
||||
prompt: str = SchemaField(description="The prompt sent to the language model.")
|
||||
error: str = SchemaField(description="Error message if the API call failed.")
|
||||
|
||||
def __init__(self):
|
||||
@@ -1100,7 +919,7 @@ class AIConversationBlock(AIBlockBase):
|
||||
},
|
||||
{"role": "user", "content": "Where was it played?"},
|
||||
],
|
||||
"model": LlmModel.GPT4O,
|
||||
"model": LlmModel.GPT4_TURBO,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
@@ -1109,7 +928,7 @@ class AIConversationBlock(AIBlockBase):
|
||||
"response",
|
||||
"The 2020 World Series was played at Globe Life Field in Arlington, Texas.",
|
||||
),
|
||||
("prompt", list),
|
||||
("prompt", str),
|
||||
],
|
||||
test_mock={
|
||||
"llm_call": lambda *args, **kwargs: "The 2020 World Series was played at Globe Life Field in Arlington, Texas."
|
||||
@@ -1131,7 +950,7 @@ class AIConversationBlock(AIBlockBase):
|
||||
) -> BlockOutput:
|
||||
response = self.llm_call(
|
||||
AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt=input_data.prompt,
|
||||
prompt="",
|
||||
credentials=input_data.credentials,
|
||||
model=input_data.model,
|
||||
conversation_history=input_data.messages,
|
||||
@@ -1162,7 +981,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=LlmModel.GPT4O,
|
||||
default=LlmModel.GPT4_TURBO,
|
||||
description="The language model to use for generating the list.",
|
||||
advanced=True,
|
||||
)
|
||||
@@ -1189,7 +1008,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
list_item: str = SchemaField(
|
||||
description="Each individual item in the list.",
|
||||
)
|
||||
prompt: list = SchemaField(description="The prompt sent to the language model.")
|
||||
prompt: str = SchemaField(description="The prompt sent to the language model.")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the list generation failed."
|
||||
)
|
||||
@@ -1211,7 +1030,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
"drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
|
||||
"fictional worlds."
|
||||
),
|
||||
"model": LlmModel.GPT4O,
|
||||
"model": LlmModel.GPT4_TURBO,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"max_retries": 3,
|
||||
},
|
||||
@@ -1221,7 +1040,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
"generated_list",
|
||||
["Zylora Prime", "Kharon-9", "Vortexia", "Oceara", "Draknos"],
|
||||
),
|
||||
("prompt", list),
|
||||
("prompt", str),
|
||||
("list_item", "Zylora Prime"),
|
||||
("list_item", "Kharon-9"),
|
||||
("list_item", "Vortexia"),
|
||||
|
||||
@@ -8,13 +8,13 @@ from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
from backend.util.file import MediaFile, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class MediaDurationBlock(Block):
|
||||
|
||||
class Input(BlockSchema):
|
||||
media_in: MediaFileType = SchemaField(
|
||||
media_in: MediaFile = SchemaField(
|
||||
description="Media input (URL, data URI, or local path)."
|
||||
)
|
||||
is_video: bool = SchemaField(
|
||||
@@ -69,7 +69,7 @@ class LoopVideoBlock(Block):
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
video_in: MediaFileType = SchemaField(
|
||||
video_in: MediaFile = SchemaField(
|
||||
description="The input video (can be a URL, data URI, or local path)."
|
||||
)
|
||||
# Provide EITHER a `duration` or `n_loops` or both. We'll demonstrate `duration`.
|
||||
@@ -137,7 +137,7 @@ class LoopVideoBlock(Block):
|
||||
assert isinstance(looped_clip, VideoFileClip)
|
||||
|
||||
# 4) Save the looped output
|
||||
output_filename = MediaFileType(
|
||||
output_filename = MediaFile(
|
||||
f"{node_exec_id}_looped_{os.path.basename(local_video_path)}"
|
||||
)
|
||||
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
|
||||
@@ -162,10 +162,10 @@ class AddAudioToVideoBlock(Block):
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
video_in: MediaFileType = SchemaField(
|
||||
video_in: MediaFile = SchemaField(
|
||||
description="Video input (URL, data URI, or local path)."
|
||||
)
|
||||
audio_in: MediaFileType = SchemaField(
|
||||
audio_in: MediaFile = SchemaField(
|
||||
description="Audio input (URL, data URI, or local path)."
|
||||
)
|
||||
volume: float = SchemaField(
|
||||
@@ -178,7 +178,7 @@ class AddAudioToVideoBlock(Block):
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
video_out: MediaFileType = SchemaField(
|
||||
video_out: MediaFile = SchemaField(
|
||||
description="Final video (with attached audio), as a path or data URI."
|
||||
)
|
||||
error: str = SchemaField(
|
||||
@@ -229,7 +229,7 @@ class AddAudioToVideoBlock(Block):
|
||||
final_clip = video_clip.with_audio(audio_clip)
|
||||
|
||||
# 4) Write to output file
|
||||
output_filename = MediaFileType(
|
||||
output_filename = MediaFile(
|
||||
f"{node_exec_id}_audio_attached_{os.path.basename(local_video_path)}"
|
||||
)
|
||||
output_abspath = os.path.join(abs_temp_dir, output_filename)
|
||||
|
||||
@@ -65,7 +65,7 @@ class AddMemoryBlock(Block, Mem0Base):
|
||||
default=Content(discriminator="content", content="I'm a vegetarian"),
|
||||
)
|
||||
metadata: dict[str, Any] = SchemaField(
|
||||
description="Optional metadata for the memory", default_factory=dict
|
||||
description="Optional metadata for the memory", default={}
|
||||
)
|
||||
|
||||
limit_memory_to_run: bool = SchemaField(
|
||||
@@ -173,7 +173,7 @@ class SearchMemoryBlock(Block, Mem0Base):
|
||||
)
|
||||
categories_filter: list[str] = SchemaField(
|
||||
description="Categories to filter by",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
limit_memory_to_run: bool = SchemaField(
|
||||
|
||||
@@ -6,14 +6,13 @@ from backend.blocks.nvidia._auth import (
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.request import requests
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
|
||||
class NvidiaDeepfakeDetectBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: NvidiaCredentialsInput = NvidiaCredentialsField()
|
||||
image_base64: MediaFileType = SchemaField(
|
||||
description="Image to analyze for deepfakes",
|
||||
image_base64: str = SchemaField(
|
||||
description="Image to analyze for deepfakes", image_upload=True
|
||||
)
|
||||
return_image: bool = SchemaField(
|
||||
description="Whether to return the processed image with markings",
|
||||
@@ -23,12 +22,16 @@ class NvidiaDeepfakeDetectBlock(Block):
|
||||
class Output(BlockSchema):
|
||||
status: str = SchemaField(
|
||||
description="Detection status (SUCCESS, ERROR, CONTENT_FILTERED)",
|
||||
default="",
|
||||
)
|
||||
image: MediaFileType = SchemaField(
|
||||
image: str = SchemaField(
|
||||
description="Processed image with detection markings (if return_image=True)",
|
||||
default="",
|
||||
image_output=True,
|
||||
)
|
||||
is_deepfake: float = SchemaField(
|
||||
description="Probability that the image is a deepfake (0-1)",
|
||||
default=0.0,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -177,8 +177,7 @@ class PineconeInsertBlock(Block):
|
||||
description="Namespace to use in Pinecone", default=""
|
||||
)
|
||||
metadata: dict = SchemaField(
|
||||
description="Additional metadata to store with each vector",
|
||||
default_factory=dict,
|
||||
description="Additional metadata to store with each vector", default={}
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
|
||||
@@ -12,7 +12,7 @@ from backend.data.model import (
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.file import MediaFileType, store_media_file
|
||||
from backend.util.file import MediaFile, store_media_file
|
||||
from backend.util.request import Requests
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ class ScreenshotWebPageBlock(Block):
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
image: MediaFileType = SchemaField(description="The screenshot image data")
|
||||
image: MediaFile = SchemaField(description="The screenshot image data")
|
||||
error: str = SchemaField(description="Error message if the screenshot failed")
|
||||
|
||||
def __init__(self):
|
||||
@@ -142,9 +142,7 @@ class ScreenshotWebPageBlock(Block):
|
||||
return {
|
||||
"image": store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=MediaFileType(
|
||||
f"data:image/{format.value};base64,{b64encode(response.content).decode('utf-8')}"
|
||||
),
|
||||
file=f"data:image/{format.value};base64,{b64encode(response.content).decode('utf-8')}",
|
||||
return_content=True,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ from backend.data.block import (
|
||||
BlockWebhookConfig,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import settings
|
||||
from backend.util.settings import AppEnvironment, BehaveAs
|
||||
|
||||
@@ -26,7 +25,7 @@ class Slant3DTriggerBase:
|
||||
class Input(BlockSchema):
|
||||
credentials: Slant3DCredentialsInput = Slant3DCredentialsField()
|
||||
# Webhook URL is handled by the webhook system
|
||||
payload: dict = SchemaField(hidden=True, default_factory=dict)
|
||||
payload: dict = SchemaField(hidden=True, default={})
|
||||
|
||||
class Output(BlockSchema):
|
||||
payload: dict = SchemaField(
|
||||
@@ -83,7 +82,7 @@ class Slant3DOrderWebhookBlock(Slant3DTriggerBase, Block):
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider=ProviderName.SLANT3D,
|
||||
provider="slant3d",
|
||||
webhook_type="orders", # Only one type for now
|
||||
resource_format="", # No resource format needed
|
||||
event_filter_input="events",
|
||||
|
||||
@@ -1,509 +0,0 @@
|
||||
import logging
|
||||
import re
|
||||
from collections import Counter
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
|
||||
import backend.blocks.llm as llm
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockInput,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util import json
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_database_manager_client():
|
||||
from backend.executor import DatabaseManager
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
return get_service_client(DatabaseManager)
|
||||
|
||||
|
||||
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
|
||||
"""
|
||||
Return a list of tool_call_ids if the entry is a tool request.
|
||||
Supports both OpenAI and Anthropic formats.
|
||||
"""
|
||||
tool_call_ids = []
|
||||
if entry.get("role") != "assistant":
|
||||
return tool_call_ids
|
||||
|
||||
# OpenAI: check for tool_calls in the entry.
|
||||
calls = entry.get("tool_calls")
|
||||
if isinstance(calls, list):
|
||||
for call in calls:
|
||||
if tool_id := call.get("id"):
|
||||
tool_call_ids.append(tool_id)
|
||||
|
||||
# Anthropic: check content items for tool_use type.
|
||||
content = entry.get("content")
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if item.get("type") != "tool_use":
|
||||
continue
|
||||
if tool_id := item.get("id"):
|
||||
tool_call_ids.append(tool_id)
|
||||
|
||||
return tool_call_ids
|
||||
|
||||
|
||||
def _get_tool_responses(entry: dict[str, Any]) -> list[str]:
|
||||
"""
|
||||
Return a list of tool_call_ids if the entry is a tool response.
|
||||
Supports both OpenAI and Anthropic formats.
|
||||
"""
|
||||
tool_call_ids: list[str] = []
|
||||
|
||||
# OpenAI: a tool response message with role "tool" and key "tool_call_id".
|
||||
if entry.get("role") == "tool":
|
||||
if tool_call_id := entry.get("tool_call_id"):
|
||||
tool_call_ids.append(str(tool_call_id))
|
||||
|
||||
# Anthropic: check content items for tool_result type.
|
||||
if entry.get("role") == "user":
|
||||
content = entry.get("content")
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if item.get("type") != "tool_result":
|
||||
continue
|
||||
if tool_call_id := item.get("tool_use_id"):
|
||||
tool_call_ids.append(tool_call_id)
|
||||
|
||||
return tool_call_ids
|
||||
|
||||
|
||||
def _create_tool_response(call_id: str, output: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Create a tool response message for either OpenAI or Anthropic,
based on the call_id format.
|
||||
"""
|
||||
content = output if isinstance(output, str) else json.dumps(output)
|
||||
|
||||
# Anthropic format: tool IDs typically start with "toolu_"
|
||||
if call_id.startswith("toolu_"):
|
||||
return {
|
||||
"role": "user",
|
||||
"type": "message",
|
||||
"content": [
|
||||
{"tool_use_id": call_id, "type": "tool_result", "content": content}
|
||||
],
|
||||
}
|
||||
|
||||
# OpenAI format: tool IDs typically start with "call_".
|
||||
# Or default fallback (if the tool_id doesn't match any known prefix)
|
||||
return {"role": "tool", "tool_call_id": call_id, "content": content}
def get_pending_tool_calls(conversation_history: list[Any]) -> dict[str, int]:
|
||||
"""
|
||||
Every tool call entry in the conversation history requires a response.
This function returns the pending tool calls that have not yet produced an output.
|
||||
|
||||
Return: dict[str, int] - A dictionary of pending tool call IDs with their count.
|
||||
"""
|
||||
pending_calls = Counter()
|
||||
for history in conversation_history:
|
||||
for call_id in _get_tool_requests(history):
|
||||
pending_calls[call_id] += 1
|
||||
|
||||
for call_id in _get_tool_responses(history):
|
||||
pending_calls[call_id] -= 1
|
||||
|
||||
return {call_id: count for call_id, count in pending_calls.items() if count > 0}
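# Illustrative sketch (not from the diff above): a made-up OpenAI-format history
# showing how requests and responses are paired by tool_call_id:
#   history = [
#       {"role": "assistant", "tool_calls": [{"id": "call_a"}, {"id": "call_b"}]},
#       {"role": "tool", "tool_call_id": "call_a", "content": "42"},
#   ]
#   get_pending_tool_calls(history)  # -> {"call_b": 1}
# "call_a" already has a response, so only "call_b" is still pending.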
class SmartDecisionMakerBlock(Block):
|
||||
"""
|
||||
A block that uses a language model to make smart decisions based on a given prompt.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
prompt: str = SchemaField(
|
||||
description="The prompt to send to the language model.",
|
||||
placeholder="Enter your prompt here...",
|
||||
)
|
||||
model: llm.LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=llm.LlmModel.GPT4O,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
)
|
||||
credentials: llm.AICredentials = llm.AICredentialsField()
|
||||
sys_prompt: str = SchemaField(
|
||||
title="System Prompt",
|
||||
default="Thinking carefully step by step decide which function to call. "
|
||||
"Always choose a function call from the list of function signatures, "
|
||||
"and always provide the complete argument provided with the type "
|
||||
"matching the required jsonschema signature, no missing argument is allowed. "
|
||||
"If you have already completed the task objective, you can end the task "
|
||||
"by providing the end result of your work as a finish message. "
|
||||
"Only provide EXACTLY one function call, multiple tool calls is strictly prohibited.",
|
||||
description="The system prompt to provide additional context to the model.",
|
||||
)
|
||||
conversation_history: list[dict] = SchemaField(
|
||||
default_factory=list,
|
||||
description="The conversation history to provide context for the prompt.",
|
||||
)
|
||||
last_tool_output: Any = SchemaField(
|
||||
default=None,
|
||||
description="The output of the last tool that was called.",
|
||||
)
|
||||
retry: int = SchemaField(
|
||||
title="Retry Count",
|
||||
default=3,
|
||||
description="Number of times to retry the LLM call if the response does not match the expected format.",
|
||||
)
|
||||
prompt_values: dict[str, str] = SchemaField(
|
||||
advanced=False,
|
||||
default_factory=dict,
|
||||
description="Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}.",
|
||||
)
|
||||
max_tokens: int | None = SchemaField(
|
||||
advanced=True,
|
||||
default=None,
|
||||
description="The maximum number of tokens to generate in the chat completion.",
|
||||
)
|
||||
ollama_host: str = SchemaField(
|
||||
advanced=True,
|
||||
default="localhost:11434",
|
||||
description="Ollama host for local models",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
|
||||
# conversation_history & last_tool_output validation is handled differently
|
||||
missing_links = super().get_missing_links(
|
||||
data,
|
||||
[
|
||||
link
|
||||
for link in links
|
||||
if link.sink_name
|
||||
not in ["conversation_history", "last_tool_output"]
|
||||
],
|
||||
)
|
||||
|
||||
# Avoid executing the block if the last_tool_output is connected to a static
|
||||
# link, like StoreValueBlock or AgentInputBlock.
|
||||
if any(link.sink_name == "conversation_history" for link in links) and any(
|
||||
link.sink_name == "last_tool_output" and link.is_static
|
||||
for link in links
|
||||
):
|
||||
raise ValueError(
|
||||
"Last Tool Output can't be connected to a static (dashed line) "
|
||||
"link like the output of `StoreValue` or `AgentInput` block"
|
||||
)
|
||||
|
||||
return missing_links
|
||||
|
||||
@classmethod
|
||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||
if missing_input := super().get_missing_input(data):
|
||||
return missing_input
|
||||
|
||||
conversation_history = data.get("conversation_history", [])
|
||||
pending_tool_calls = get_pending_tool_calls(conversation_history)
|
||||
last_tool_output = data.get("last_tool_output")
|
||||
if not last_tool_output and pending_tool_calls:
|
||||
return {"last_tool_output"}
|
||||
return set()
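# Illustrative sketch (not from the diff above): even with every required field
# present, the block is held back until the pending tool call has an output
# (made-up data):
#   data = {"prompt": "...", "credentials": {...},
#           "conversation_history": [{"role": "assistant",
#                                     "tool_calls": [{"id": "call_a"}]}]}
#   get_missing_input(data)  # -> {"last_tool_output"}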
class Output(BlockSchema):
|
||||
error: str = SchemaField(description="Error message if the API call failed.")
|
||||
tools: Any = SchemaField(description="The tools that are available to use.")
|
||||
finished: str = SchemaField(
|
||||
description="The finished message to display to the user."
|
||||
)
|
||||
conversations: list[Any] = SchemaField(
|
||||
description="The conversation history to provide context for the prompt."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3b191d9f-356f-482d-8238-ba04b6d18381",
|
||||
description="Uses AI to intelligently decide what tool to use.",
|
||||
categories={BlockCategory.AI},
|
||||
block_type=BlockType.AI,
|
||||
input_schema=SmartDecisionMakerBlock.Input,
|
||||
output_schema=SmartDecisionMakerBlock.Output,
|
||||
test_input={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": llm.TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[],
|
||||
test_credentials=llm.TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_block_function_signature(
|
||||
sink_node: "Node", links: list["Link"]
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Creates a function signature for a block node.
|
||||
|
||||
Args:
|
||||
sink_node: The node for which to create a function signature.
|
||||
links: The list of links connected to the sink node.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the function signature in the format expected by LLM tools.
|
||||
|
||||
Raises:
|
||||
ValueError: If the block specified by sink_node.block_id is not found.
|
||||
"""
|
||||
block = sink_node.block
|
||||
|
||||
tool_function: dict[str, Any] = {
|
||||
"name": re.sub(r"[^a-zA-Z0-9_-]", "_", block.name).lower(),
|
||||
"description": block.description,
|
||||
}
|
||||
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for link in links:
|
||||
sink_block_input_schema = block.input_schema
|
||||
description = (
|
||||
sink_block_input_schema.model_fields[link.sink_name].description
|
||||
if link.sink_name in sink_block_input_schema.model_fields
|
||||
and sink_block_input_schema.model_fields[link.sink_name].description
|
||||
else f"The {link.sink_name} of the tool"
|
||||
)
|
||||
properties[link.sink_name.lower()] = {
|
||||
"type": "string",
|
||||
"description": description,
|
||||
}
|
||||
|
||||
tool_function["parameters"] = {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
"additionalProperties": False,
|
||||
"strict": True,
|
||||
}
|
||||
|
||||
return {"type": "function", "function": tool_function}
@staticmethod
|
||||
def _create_agent_function_signature(
|
||||
sink_node: "Node", links: list["Link"]
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Creates a function signature for an agent node.
|
||||
|
||||
Args:
|
||||
sink_node: The agent node for which to create a function signature.
|
||||
links: The list of links connected to the sink node.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the function signature in the format expected by LLM tools.
|
||||
|
||||
Raises:
|
||||
ValueError: If the graph metadata for the specified graph_id and graph_version is not found.
|
||||
"""
|
||||
graph_id = sink_node.input_default.get("graph_id")
|
||||
graph_version = sink_node.input_default.get("graph_version")
|
||||
if not graph_id or not graph_version:
|
||||
raise ValueError("Graph ID or Graph Version not found in sink node.")
|
||||
|
||||
db_client = get_database_manager_client()
|
||||
sink_graph_meta = db_client.get_graph_metadata(graph_id, graph_version)
|
||||
if not sink_graph_meta:
|
||||
raise ValueError(
|
||||
f"Sink graph metadata not found: {graph_id} {graph_version}"
|
||||
)
|
||||
|
||||
tool_function: dict[str, Any] = {
|
||||
"name": re.sub(r"[^a-zA-Z0-9_-]", "_", sink_graph_meta.name).lower(),
|
||||
"description": sink_graph_meta.description,
|
||||
}
|
||||
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for link in links:
|
||||
sink_block_input_schema = sink_node.input_default["input_schema"]
|
||||
description = (
|
||||
sink_block_input_schema["properties"][link.sink_name]["description"]
|
||||
if "description"
|
||||
in sink_block_input_schema["properties"][link.sink_name]
|
||||
else f"The {link.sink_name} of the tool"
|
||||
)
|
||||
properties[link.sink_name.lower()] = {
|
||||
"type": "string",
|
||||
"description": description,
|
||||
}
|
||||
|
||||
tool_function["parameters"] = {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
"additionalProperties": False,
|
||||
"strict": True,
|
||||
}
|
||||
|
||||
return {"type": "function", "function": tool_function}
|
||||
|
||||
@staticmethod
|
||||
def _create_function_signature(node_id: str) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Creates function signatures for tools linked to a specified node within a graph.
|
||||
|
||||
This method filters the graph links to identify those that are tools and are
|
||||
connected to the given node_id. It then constructs function signatures for each
|
||||
tool based on the metadata and input schema of the linked nodes.
|
||||
|
||||
Args:
|
||||
node_id: The node_id for which to create function signatures.
|
||||
|
||||
Returns:
|
||||
list[dict[str, Any]]: A list of dictionaries, each representing a function signature
|
||||
for a tool, including its name, description, and parameters.
|
||||
|
||||
Raises:
|
||||
ValueError: If no tool links are found for the specified node_id, or if a sink node
|
||||
or its metadata cannot be found.
|
||||
"""
|
||||
db_client = get_database_manager_client()
|
||||
tools = [
|
||||
(link, node)
|
||||
for link, node in db_client.get_connected_output_nodes(node_id)
|
||||
if link.source_name.startswith("tools_^_") and link.source_id == node_id
|
||||
]
|
||||
if not tools:
|
||||
raise ValueError("There is no next node to execute.")
|
||||
|
||||
return_tool_functions = []
|
||||
|
||||
grouped_tool_links: dict[str, tuple["Node", list["Link"]]] = {}
|
||||
for link, node in tools:
|
||||
if link.sink_id not in grouped_tool_links:
|
||||
grouped_tool_links[link.sink_id] = (node, [link])
|
||||
else:
|
||||
grouped_tool_links[link.sink_id][1].append(link)
|
||||
|
||||
for sink_node, links in grouped_tool_links.values():
|
||||
if not sink_node:
|
||||
raise ValueError(f"Sink node not found: {links[0].sink_id}")
|
||||
|
||||
if sink_node.block_id == AgentExecutorBlock().id:
|
||||
return_tool_functions.append(
|
||||
SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
sink_node, links
|
||||
)
|
||||
)
|
||||
else:
|
||||
return_tool_functions.append(
|
||||
SmartDecisionMakerBlock._create_block_function_signature(
|
||||
sink_node, links
|
||||
)
|
||||
)
|
||||
|
||||
return return_tool_functions
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: llm.APIKeyCredentials,
|
||||
graph_id: str,
|
||||
node_id: str,
|
||||
graph_exec_id: str,
|
||||
node_exec_id: str,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
tool_functions = self._create_function_signature(node_id)
|
||||
|
||||
input_data.conversation_history = input_data.conversation_history or []
|
||||
prompt = [json.to_dict(p) for p in input_data.conversation_history if p]
|
||||
|
||||
pending_tool_calls = get_pending_tool_calls(input_data.conversation_history)
|
||||
if pending_tool_calls and not input_data.last_tool_output:
|
||||
raise ValueError(f"Tool call requires an output for {pending_tool_calls}")
|
||||
|
||||
# Prefill all missing tool calls with the last tool output.
|
||||
# TODO: we need a better way to handle this.
|
||||
tool_output = [
|
||||
_create_tool_response(pending_call_id, input_data.last_tool_output)
|
||||
for pending_call_id, count in pending_tool_calls.items()
|
||||
for _ in range(count)
|
||||
]
|
||||
|
||||
# If the SDM block only calls 1 tool at a time, this should not happen.
|
||||
if len(tool_output) > 1:
|
||||
logger.warning(
|
||||
f"[SmartDecisionMakerBlock-node_exec_id={node_exec_id}] "
|
||||
f"Multiple pending tool calls are prefilled using a single output. "
|
||||
f"Execution may not be accurate."
|
||||
)
|
||||
|
||||
# Fallback on adding tool output in the conversation history as user prompt.
|
||||
if len(tool_output) == 0 and input_data.last_tool_output:
|
||||
logger.warning(
|
||||
f"[SmartDecisionMakerBlock-node_exec_id={node_exec_id}] "
|
||||
f"No pending tool calls found. This may indicate an issue with the "
|
||||
f"conversation history, or an LLM calling two tools at the same time."
|
||||
)
|
||||
tool_output.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Last tool output: {json.dumps(input_data.last_tool_output)}",
|
||||
}
|
||||
)
|
||||
|
||||
prompt.extend(tool_output)
|
||||
|
||||
values = input_data.prompt_values
|
||||
if values:
|
||||
input_data.prompt = llm.fmt.format_string(input_data.prompt, values)
|
||||
input_data.sys_prompt = llm.fmt.format_string(input_data.sys_prompt, values)
|
||||
|
||||
prefix = "[Main Objective Prompt]: "
|
||||
|
||||
if input_data.sys_prompt and not any(
|
||||
p["role"] == "system" and p["content"].startswith(prefix) for p in prompt
|
||||
):
|
||||
prompt.append({"role": "system", "content": prefix + input_data.sys_prompt})
|
||||
|
||||
if input_data.prompt and not any(
|
||||
p["role"] == "user" and p["content"].startswith(prefix) for p in prompt
|
||||
):
|
||||
prompt.append({"role": "user", "content": prefix + input_data.prompt})
|
||||
|
||||
response = llm.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
json_format=False,
|
||||
max_tokens=input_data.max_tokens,
|
||||
tools=tool_functions,
|
||||
ollama_host=input_data.ollama_host,
|
||||
parallel_tool_calls=False,
|
||||
)
|
||||
|
||||
if not response.tool_calls:
|
||||
yield "finished", response.response
|
||||
return
|
||||
|
||||
for tool_call in response.tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
|
||||
for arg_name, arg_value in tool_args.items():
|
||||
yield f"tools_^_{tool_name}_{arg_name}".lower(), arg_value
|
||||
|
||||
response.prompt.append(response.raw_response)
|
||||
yield "conversations", response.prompt
|
||||
@@ -112,7 +112,7 @@ class AddLeadToCampaignBlock(Block):
|
||||
lead_list: list[LeadInput] = SchemaField(
|
||||
description="An array of JSON objects, each representing a lead's details. Can hold max 100 leads.",
|
||||
max_length=100,
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
settings: LeadUploadSettings = SchemaField(
|
||||
@@ -248,7 +248,7 @@ class SaveCampaignSequencesBlock(Block):
|
||||
)
|
||||
sequences: list[Sequence] = SchemaField(
|
||||
description="The sequences to save",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
credentials: SmartLeadCredentialsInput = SchemaField(
|
||||
|
||||
@@ -39,7 +39,7 @@ class LeadCustomFields(BaseModel):
|
||||
fields: dict[str, str] = SchemaField(
|
||||
description="Custom fields for a lead (max 20 fields)",
|
||||
max_length=20,
|
||||
default_factory=dict,
|
||||
default={},
|
||||
)
|
||||
|
||||
|
||||
@@ -85,7 +85,7 @@ class AddLeadsRequest(BaseModel):
|
||||
lead_list: list[LeadInput] = SchemaField(
|
||||
description="List of leads to add to the campaign",
|
||||
max_length=100,
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
settings: LeadUploadSettings
|
||||
campaign_id: int
|
||||
|
||||
@@ -156,10 +156,6 @@ class CountdownTimerBlock(Block):
|
||||
days: Union[int, str] = SchemaField(
|
||||
advanced=False, description="Duration in days", default=0
|
||||
)
|
||||
repeat: int = SchemaField(
|
||||
description="Number of times to repeat the timer",
|
||||
default=1,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
output_message: Any = SchemaField(
|
||||
@@ -191,6 +187,5 @@ class CountdownTimerBlock(Block):
|
||||
|
||||
total_seconds = seconds + minutes * 60 + hours * 3600 + days * 86400
|
||||
|
||||
for _ in range(input_data.repeat):
|
||||
time.sleep(total_seconds)
|
||||
yield "output_message", input_data.input_message
|
||||
time.sleep(total_seconds)
|
||||
yield "output_message", input_data.input_message
@@ -156,7 +156,7 @@
|
||||
# participant_ids: list[str] = SchemaField(
|
||||
# description="Array of User IDs to create conversation with (max 50)",
|
||||
# placeholder="Enter participant user IDs",
|
||||
# default_factory=list,
|
||||
# default=[],
|
||||
# advanced=False
|
||||
# )
|
||||
|
||||
|
||||
@@ -39,6 +39,7 @@ class TwitterGetListBlock(Block):
|
||||
list_id: str = SchemaField(
|
||||
description="The ID of the List to lookup",
|
||||
placeholder="Enter list ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -183,6 +184,7 @@ class TwitterGetOwnedListsBlock(Block):
|
||||
user_id: str = SchemaField(
|
||||
description="The user ID whose owned Lists to retrieve",
|
||||
placeholder="Enter user ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
max_results: int | None = SchemaField(
|
||||
|
||||
@@ -45,11 +45,13 @@ class TwitterRemoveListMemberBlock(Block):
|
||||
list_id: str = SchemaField(
|
||||
description="The ID of the List to remove the member from",
|
||||
placeholder="Enter list ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
user_id: str = SchemaField(
|
||||
description="The ID of the user to remove from the List",
|
||||
placeholder="Enter user ID to remove",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -118,11 +120,13 @@ class TwitterAddListMemberBlock(Block):
|
||||
list_id: str = SchemaField(
|
||||
description="The ID of the List to add the member to",
|
||||
placeholder="Enter list ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
user_id: str = SchemaField(
|
||||
description="The ID of the user to add to the List",
|
||||
placeholder="Enter user ID to add",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -191,6 +195,7 @@ class TwitterGetListMembersBlock(Block):
|
||||
list_id: str = SchemaField(
|
||||
description="The ID of the List to get members from",
|
||||
placeholder="Enter list ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
max_results: int | None = SchemaField(
|
||||
@@ -371,6 +376,7 @@ class TwitterGetListMembershipsBlock(Block):
|
||||
user_id: str = SchemaField(
|
||||
description="The ID of the user whose List memberships to retrieve",
|
||||
placeholder="Enter user ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
max_results: int | None = SchemaField(
|
||||
|
||||
@@ -42,6 +42,7 @@ class TwitterGetListTweetsBlock(Block):
|
||||
list_id: str = SchemaField(
|
||||
description="The ID of the List whose Tweets you would like to retrieve",
|
||||
placeholder="Enter list ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
max_results: int | None = SchemaField(
|
||||
|
||||
@@ -28,6 +28,7 @@ class TwitterDeleteListBlock(Block):
|
||||
list_id: str = SchemaField(
|
||||
description="The ID of the List to be deleted",
|
||||
placeholder="Enter list ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
|
||||
@@ -39,6 +39,7 @@ class TwitterUnpinListBlock(Block):
|
||||
list_id: str = SchemaField(
|
||||
description="The ID of the List to unpin",
|
||||
placeholder="Enter list ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -102,6 +103,7 @@ class TwitterPinListBlock(Block):
|
||||
list_id: str = SchemaField(
|
||||
description="The ID of the List to pin",
|
||||
placeholder="Enter list ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
|
||||
@@ -44,7 +44,7 @@ class SpaceList(BaseModel):
|
||||
space_ids: list[str] = SchemaField(
|
||||
description="List of Space IDs to lookup (up to 100)",
|
||||
placeholder="Enter Space IDs",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
@@ -54,7 +54,7 @@ class UserList(BaseModel):
|
||||
user_ids: list[str] = SchemaField(
|
||||
description="List of user IDs to lookup their Spaces (up to 100)",
|
||||
placeholder="Enter user IDs",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
@@ -227,6 +227,7 @@ class TwitterGetSpaceByIdBlock(Block):
|
||||
space_id: str = SchemaField(
|
||||
description="Space ID to lookup",
|
||||
placeholder="Enter Space ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -388,6 +389,7 @@ class TwitterGetSpaceBuyersBlock(Block):
|
||||
space_id: str = SchemaField(
|
||||
description="Space ID to lookup buyers for",
|
||||
placeholder="Enter Space ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -515,6 +517,7 @@ class TwitterGetSpaceTweetsBlock(Block):
|
||||
space_id: str = SchemaField(
|
||||
description="Space ID to lookup tweets for",
|
||||
placeholder="Enter Space ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
|
||||
@@ -200,7 +200,7 @@ class UserIdList(BaseModel):
|
||||
user_ids: list[str] = SchemaField(
|
||||
description="List of user IDs to lookup (max 100)",
|
||||
placeholder="Enter user IDs",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
@@ -210,7 +210,7 @@ class UsernameList(BaseModel):
|
||||
usernames: list[str] = SchemaField(
|
||||
description="List of Twitter usernames/handles to lookup (max 100)",
|
||||
placeholder="Enter usernames",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import pathlib
|
||||
import click
|
||||
import psutil
|
||||
|
||||
from backend import app
|
||||
from backend.util.process import AppProcess
|
||||
|
||||
|
||||
@@ -41,13 +42,8 @@ def write_pid(pid: int):
|
||||
|
||||
class MainApp(AppProcess):
|
||||
def run(self):
|
||||
from backend import app
|
||||
|
||||
app.main(silent=True)
|
||||
|
||||
def cleanup(self):
|
||||
pass
|
||||
|
||||
|
||||
@click.group()
|
||||
def main():
|
||||
@@ -224,8 +220,9 @@ def event():
|
||||
|
||||
@test.command()
|
||||
@click.argument("server_address")
|
||||
@click.argument("graph_exec_id")
|
||||
def websocket(server_address: str, graph_exec_id: str):
|
||||
@click.argument("graph_id")
|
||||
@click.argument("graph_version")
|
||||
def websocket(server_address: str, graph_id: str, graph_version: int):
|
||||
"""
|
||||
Tests the websocket connection.
|
||||
"""
|
||||
@@ -233,20 +230,16 @@ def websocket(server_address: str, graph_exec_id: str):
|
||||
|
||||
import websockets.asyncio.client
|
||||
|
||||
from backend.server.ws_api import (
|
||||
WSMessage,
|
||||
WSMethod,
|
||||
WSSubscribeGraphExecutionRequest,
|
||||
)
|
||||
from backend.server.ws_api import ExecutionSubscription, Methods, WsMessage
|
||||
|
||||
async def send_message(server_address: str):
|
||||
uri = f"ws://{server_address}"
|
||||
async with websockets.asyncio.client.connect(uri) as websocket:
|
||||
try:
|
||||
msg = WSMessage(
|
||||
method=WSMethod.SUBSCRIBE_GRAPH_EXEC,
|
||||
data=WSSubscribeGraphExecutionRequest(
|
||||
graph_exec_id=graph_exec_id,
|
||||
msg = WsMessage(
|
||||
method=Methods.SUBSCRIBE,
|
||||
data=ExecutionSubscription(
|
||||
graph_id=graph_id, graph_version=graph_version
|
||||
).model_dump(),
|
||||
).model_dump_json()
|
||||
await websocket.send(msg)
|
||||
|
||||
@@ -12,12 +12,12 @@ async def log_raw_analytics(
|
||||
data_index: str,
|
||||
):
|
||||
details = await prisma.models.AnalyticsDetails.prisma().create(
|
||||
data=prisma.types.AnalyticsDetailsCreateInput(
|
||||
userId=user_id,
|
||||
type=type,
|
||||
data=prisma.Json(data),
|
||||
dataIndex=data_index,
|
||||
)
|
||||
data={
|
||||
"userId": user_id,
|
||||
"type": type,
|
||||
"data": prisma.Json(data),
|
||||
"dataIndex": data_index,
|
||||
}
|
||||
)
|
||||
return details
|
||||
|
||||
@@ -32,12 +32,12 @@ async def log_raw_metric(
|
||||
raise ValueError("metric_value must be non-negative")
|
||||
|
||||
result = await prisma.models.AnalyticsMetrics.prisma().create(
|
||||
data=prisma.types.AnalyticsMetricsCreateInput(
|
||||
value=metric_value,
|
||||
analyticMetric=metric_name,
|
||||
userId=user_id,
|
||||
dataString=data_string,
|
||||
)
|
||||
data={
|
||||
"value": metric_value,
|
||||
"analyticMetric": metric_name,
|
||||
"userId": user_id,
|
||||
"dataString": data_string,
|
||||
},
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -2,7 +2,6 @@ import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
ClassVar,
|
||||
Generator,
|
||||
@@ -17,25 +16,18 @@ from typing import (
|
||||
import jsonref
|
||||
import jsonschema
|
||||
from prisma.models import AgentBlock
|
||||
from prisma.types import AgentBlockCreateInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import json
|
||||
from backend.util.settings import Config
|
||||
|
||||
from .model import (
|
||||
ContributorDetails,
|
||||
Credentials,
|
||||
CredentialsFieldInfo,
|
||||
CredentialsMetaInput,
|
||||
is_credentials_field_name,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .graph import Link
|
||||
|
||||
app_config = Config()
|
||||
|
||||
BlockData = tuple[str, Any] # Input & Output data should be a tuple of (name, data).
|
||||
@@ -52,7 +44,6 @@ class BlockType(Enum):
|
||||
WEBHOOK = "Webhook"
|
||||
WEBHOOK_MANUAL = "Webhook (manual)"
|
||||
AGENT = "Agent"
|
||||
AI = "AI"
|
||||
|
||||
|
||||
class BlockCategory(Enum):
|
||||
@@ -118,30 +109,21 @@ class BlockSchema(BaseModel):
|
||||
def validate_data(cls, data: BlockInput) -> str | None:
|
||||
return json.validate_with_jsonschema(schema=cls.jsonschema(), data=data)
|
||||
|
||||
@classmethod
|
||||
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
||||
return cls.validate_data(data)
|
||||
|
||||
@classmethod
|
||||
def get_field_schema(cls, field_name: str) -> dict[str, Any]:
|
||||
model_schema = cls.jsonschema().get("properties", {})
|
||||
if not model_schema:
|
||||
raise ValueError(f"Invalid model schema {cls}")
|
||||
|
||||
property_schema = model_schema.get(field_name)
|
||||
if not property_schema:
|
||||
raise ValueError(f"Invalid property name {field_name}")
|
||||
|
||||
return property_schema
|
||||
|
||||
@classmethod
|
||||
def validate_field(cls, field_name: str, data: BlockInput) -> str | None:
|
||||
"""
|
||||
Validate the data against a specific property (one of the input/output name).
|
||||
Returns the validation error message if the data does not match the schema.
|
||||
"""
|
||||
model_schema = cls.jsonschema().get("properties", {})
|
||||
if not model_schema:
|
||||
return f"Invalid model schema {cls}"
|
||||
|
||||
property_schema = model_schema.get(field_name)
|
||||
if not property_schema:
|
||||
return f"Invalid property name {field_name}"
|
||||
|
||||
try:
|
||||
property_schema = cls.get_field_schema(field_name)
|
||||
jsonschema.validate(json.to_dict(data), property_schema)
|
||||
return None
|
||||
except jsonschema.ValidationError as e:
|
||||
@@ -204,28 +186,6 @@ class BlockSchema(BaseModel):
|
||||
)
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_credentials_fields_info(cls) -> dict[str, CredentialsFieldInfo]:
|
||||
return {
|
||||
field_name: CredentialsFieldInfo.model_validate(
|
||||
cls.get_field_schema(field_name), by_alias=True
|
||||
)
|
||||
for field_name in cls.get_credentials_fields().keys()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
|
||||
return data # Return as is, by default.
|
||||
|
||||
@classmethod
|
||||
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
|
||||
input_fields_from_nodes = {link.sink_name for link in links}
|
||||
return input_fields_from_nodes - set(data)
|
||||
|
||||
@classmethod
|
||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||
return cls.get_required_fields() - set(data)
|
||||
|
||||
|
||||
BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchema)
|
||||
BlockSchemaOutputType = TypeVar("BlockSchemaOutputType", bound=BlockSchema)
|
||||
@@ -242,7 +202,7 @@ class BlockManualWebhookConfig(BaseModel):
|
||||
the user has to manually set up the webhook at the provider.
|
||||
"""
|
||||
|
||||
provider: ProviderName
|
||||
provider: str
|
||||
"""The service provider that the webhook connects to"""
|
||||
|
||||
webhook_type: str
|
||||
@@ -334,7 +294,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
self.static_output = static_output
|
||||
self.block_type = block_type
|
||||
self.webhook_config = webhook_config
|
||||
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
|
||||
self.execution_stats = {}
|
||||
|
||||
if self.webhook_config:
|
||||
if isinstance(self.webhook_config, BlockWebhookConfig):
|
||||
@@ -391,14 +351,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
Run the block with the given input data.
|
||||
Args:
|
||||
input_data: The input data with the structure of input_schema.
|
||||
|
||||
Kwargs: As of 14/02/2025 these include
|
||||
graph_id: The ID of the graph.
|
||||
node_id: The ID of the node.
|
||||
graph_exec_id: The ID of the graph execution.
|
||||
node_exec_id: The ID of the node execution.
|
||||
user_id: The ID of the user.
|
||||
|
||||
Returns:
|
||||
A Generator that yields (output_name, output_data).
|
||||
output_name: One of the output name defined in Block's output_schema.
|
||||
@@ -412,29 +364,18 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
return data
|
||||
raise ValueError(f"{self.name} did not produce any output for {output}")
|
||||
|
||||
def merge_stats(self, stats: NodeExecutionStats) -> NodeExecutionStats:
|
||||
stats_dict = stats.model_dump()
|
||||
current_stats = self.execution_stats.model_dump()
|
||||
|
||||
for key, value in stats_dict.items():
|
||||
if key not in current_stats:
|
||||
# Field doesn't exist yet: just set it. This should not normally happen,
# but is kept so that invalid values surface when converting back into
# NodeExecutionStats below.
|
||||
current_stats[key] = value
|
||||
elif isinstance(value, dict) and isinstance(current_stats[key], dict):
|
||||
current_stats[key].update(value)
|
||||
elif isinstance(value, (int, float)) and isinstance(
|
||||
current_stats[key], (int, float)
|
||||
):
|
||||
current_stats[key] += value
|
||||
elif isinstance(value, list) and isinstance(current_stats[key], list):
|
||||
current_stats[key].extend(value)
|
||||
def merge_stats(self, stats: dict[str, Any]) -> dict[str, Any]:
|
||||
for key, value in stats.items():
|
||||
if isinstance(value, dict):
|
||||
self.execution_stats.setdefault(key, {}).update(value)
|
||||
elif isinstance(value, (int, float)):
|
||||
self.execution_stats.setdefault(key, 0)
|
||||
self.execution_stats[key] += value
|
||||
elif isinstance(value, list):
|
||||
self.execution_stats.setdefault(key, [])
|
||||
self.execution_stats[key].extend(value)
|
||||
else:
|
||||
current_stats[key] = value
|
||||
|
||||
self.execution_stats = NodeExecutionStats(**current_stats)
|
||||
|
||||
self.execution_stats[key] = value
|
||||
return self.execution_stats
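# Illustrative sketch (not from the diff above): merge rules shared by both
# variants above: numbers add, dicts and lists merge, anything else is
# overwritten (field names and values are made up):
#   existing = {"cost": 3, "llm_call_count": 1, "errors": ["timeout"]}
#   incoming = {"cost": 2, "errors": ["rate_limit"], "extras": {"model": "gpt-4o"}}
#   merged   = {"cost": 5, "llm_call_count": 1,
#               "errors": ["timeout", "rate_limit"], "extras": {"model": "gpt-4o"}}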
@property
|
||||
@@ -457,6 +398,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
}
|
||||
|
||||
def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
|
||||
# Merge the input data with the extra execution arguments, preferring the args for security
|
||||
if error := self.input_schema.validate_data(input_data):
|
||||
raise ValueError(
|
||||
f"Unable to execute block with invalid input data: {error}"
|
||||
@@ -478,9 +420,9 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
|
||||
|
||||
def get_blocks() -> dict[str, Type[Block]]:
|
||||
from backend.blocks import load_all_blocks
|
||||
from backend.blocks import AVAILABLE_BLOCKS # noqa: E402
|
||||
|
||||
return load_all_blocks()
|
||||
return AVAILABLE_BLOCKS
|
||||
|
||||
|
||||
async def initialize_blocks() -> None:
|
||||
@@ -491,12 +433,12 @@ async def initialize_blocks() -> None:
|
||||
)
|
||||
if not existing_block:
|
||||
await AgentBlock.prisma().create(
|
||||
data=AgentBlockCreateInput(
|
||||
id=block.id,
|
||||
name=block.name,
|
||||
inputSchema=json.dumps(block.input_schema.jsonschema()),
|
||||
outputSchema=json.dumps(block.output_schema.jsonschema()),
|
||||
)
|
||||
data={
|
||||
"id": block.id,
|
||||
"name": block.name,
|
||||
"inputSchema": json.dumps(block.input_schema.jsonschema()),
|
||||
"outputSchema": json.dumps(block.output_schema.jsonschema()),
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -519,7 +461,6 @@ async def initialize_blocks() -> None:
|
||||
)
|
||||
|
||||
|
||||
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
|
||||
def get_block(block_id: str) -> Block[BlockSchema, BlockSchema] | None:
|
||||
def get_block(block_id: str) -> Block | None:
|
||||
cls = get_blocks().get(block_id)
|
||||
return cls() if cls else None
|
||||
|
||||
@@ -15,7 +15,6 @@ from backend.blocks.llm import (
|
||||
LlmModel,
|
||||
)
|
||||
from backend.blocks.replicate_flux_advanced import ReplicateFluxAdvancedModelBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||
from backend.data.block import Block
|
||||
@@ -36,17 +35,14 @@ from backend.integrations.credentials_store import (
|
||||
# =============== Configure the cost for each LLM Model call =============== #
|
||||
|
||||
MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.O3: 7,
|
||||
LlmModel.O3_MINI: 2, # $1.10 / $4.40
|
||||
LlmModel.O1: 16, # $15 / $60
|
||||
LlmModel.O1_PREVIEW: 16,
|
||||
LlmModel.O1_MINI: 4,
|
||||
LlmModel.GPT41: 2,
|
||||
LlmModel.GPT4O_MINI: 1,
|
||||
LlmModel.GPT4O: 3,
|
||||
LlmModel.GPT4_TURBO: 10,
|
||||
LlmModel.GPT3_5_TURBO: 1,
|
||||
LlmModel.CLAUDE_3_7_SONNET: 5,
|
||||
LlmModel.CLAUDE_3_5_SONNET: 4,
|
||||
LlmModel.CLAUDE_3_5_HAIKU: 1, # $0.80 / $4.00
|
||||
LlmModel.CLAUDE_3_HAIKU: 1,
|
||||
@@ -63,7 +59,6 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.DEEPSEEK_LLAMA_70B: 1, # ? / ?
|
||||
LlmModel.OLLAMA_DOLPHIN: 1,
|
||||
LlmModel.GEMINI_FLASH_1_5: 1,
|
||||
LlmModel.GEMINI_2_5_PRO: 4,
|
||||
LlmModel.GROK_BETA: 5,
|
||||
LlmModel.MISTRAL_NEMO: 1,
|
||||
LlmModel.COHERE_COMMAND_R_08_2024: 1,
|
||||
@@ -79,8 +74,6 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.AMAZON_NOVA_PRO_V1: 1,
|
||||
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: 1,
|
||||
LlmModel.GRYPHE_MYTHOMAX_L2_13B: 1,
|
||||
LlmModel.META_LLAMA_4_SCOUT: 1,
|
||||
LlmModel.META_LLAMA_4_MAVERICK: 1,
|
||||
}
|
||||
|
||||
for model in LlmModel:
|
||||
@@ -272,5 +265,4 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
||||
},
|
||||
)
|
||||
],
|
||||
SmartDecisionMakerBlock: LLM_COST,
|
||||
}
|
||||
|
||||
@@ -11,20 +11,18 @@ from prisma.enums import (
|
||||
CreditRefundRequestStatus,
|
||||
CreditTransactionType,
|
||||
NotificationType,
|
||||
OnboardingStep,
|
||||
)
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditRefundRequest, CreditTransaction, User
|
||||
from prisma.types import (
|
||||
CreditRefundRequestCreateInput,
|
||||
CreditTransactionCreateInput,
|
||||
CreditTransactionWhereInput,
|
||||
)
|
||||
from prisma.types import CreditTransactionCreateInput, CreditTransactionWhereInput
|
||||
from pydantic import BaseModel
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
from backend.data import db
|
||||
from backend.data.block import Block, BlockInput, get_block
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.cost import BlockCost
|
||||
from backend.data.cost import BlockCost, BlockCostType
|
||||
from backend.data.execution import NodeExecutionEntry
|
||||
from backend.data.model import (
|
||||
AutoTopUpConfig,
|
||||
RefundRequest,
|
||||
@@ -33,16 +31,13 @@ from backend.data.model import (
|
||||
)
|
||||
from backend.data.notifications import NotificationEventDTO, RefundRequestData
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.executor.utils import UsageTransactionMetadata
|
||||
from backend.notifications import NotificationManager
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
stripe.api_key = settings.secrets.stripe_api_key
|
||||
logger = logging.getLogger(__name__)
|
||||
base_url = settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
|
||||
|
||||
class UserCreditBase(ABC):
|
||||
@@ -94,20 +89,20 @@ class UserCreditBase(ABC):
|
||||
@abstractmethod
|
||||
async def spend_credits(
|
||||
self,
|
||||
user_id: str,
|
||||
cost: int,
|
||||
metadata: UsageTransactionMetadata,
|
||||
entry: NodeExecutionEntry,
|
||||
data_size: float,
|
||||
run_time: float,
|
||||
) -> int:
|
||||
"""
|
||||
Spend the credits for the user based on the cost.
|
||||
Spend the credits for the user based on the block usage.
|
||||
|
||||
Args:
|
||||
user_id (str): The user ID.
|
||||
cost (int): The cost to spend.
|
||||
metadata (UsageTransactionMetadata): The metadata of the transaction.
|
||||
entry (NodeExecutionEntry): The node execution identifiers & data.
|
||||
data_size (float): The size of the data being processed.
|
||||
run_time (float): The time taken to run the block.
|
||||
|
||||
Returns:
|
||||
int: The remaining balance.
|
||||
int: amount of credit spent
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -122,18 +117,6 @@ class UserCreditBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
|
||||
"""
|
||||
Reward the user with credits for completing an onboarding step.
|
||||
Won't reward if the user has already received credits for the step.
|
||||
|
||||
Args:
|
||||
user_id (str): The user ID.
|
||||
step (OnboardingStep): The onboarding step.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def top_up_intent(self, user_id: str, amount: int) -> str:
|
||||
"""
|
||||
@@ -201,14 +184,6 @@ class UserCreditBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
async def create_billing_portal_session(user_id: str) -> str:
|
||||
session = stripe.billing_portal.Session.create(
|
||||
customer=await get_stripe_customer_id(user_id),
|
||||
return_url=base_url + "/profile/credits",
|
||||
)
|
||||
return session.url
|
||||
|
||||
@staticmethod
|
||||
def time_now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
@@ -226,7 +201,7 @@ class UserCreditBase(ABC):
|
||||
"userId": user_id,
|
||||
"createdAt": {"lte": top_time},
|
||||
"isActive": True,
|
||||
"NOT": [{"runningBalance": None}],
|
||||
"runningBalance": {"not": None}, # type: ignore
|
||||
},
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
@@ -274,6 +249,7 @@ class UserCreditBase(ABC):
|
||||
metadata: Json,
|
||||
new_transaction_key: str | None = None,
|
||||
):
|
||||
|
||||
transaction = await CreditTransaction.prisma().find_first_or_raise(
|
||||
where={"transactionKey": transaction_key, "userId": user_id}
|
||||
)
|
||||
@@ -338,32 +314,39 @@ class UserCreditBase(ABC):
|
||||
|
||||
if amount < 0 and user_balance + amount < 0:
|
||||
if fail_insufficient_credits:
|
||||
raise InsufficientBalanceError(
|
||||
message=f"Insufficient balance of ${user_balance/100}, where this will cost ${abs(amount)/100}",
|
||||
user_id=user_id,
|
||||
balance=user_balance,
|
||||
amount=amount,
|
||||
raise ValueError(
|
||||
f"Insufficient balance of ${user_balance/100}, where this will cost ${abs(amount)/100}"
|
||||
)
|
||||
|
||||
amount = min(-user_balance, 0)
|
||||
|
||||
# Create the transaction
|
||||
transaction_data = CreditTransactionCreateInput(
|
||||
userId=user_id,
|
||||
amount=amount,
|
||||
runningBalance=user_balance + amount,
|
||||
type=transaction_type,
|
||||
metadata=metadata,
|
||||
isActive=is_active,
|
||||
createdAt=self.time_now(),
|
||||
)
|
||||
transaction_data: CreditTransactionCreateInput = {
|
||||
"userId": user_id,
|
||||
"amount": amount,
|
||||
"runningBalance": user_balance + amount,
|
||||
"type": transaction_type,
|
||||
"metadata": metadata,
|
||||
"isActive": is_active,
|
||||
"createdAt": self.time_now(),
|
||||
}
|
||||
if transaction_key:
|
||||
transaction_data["transactionKey"] = transaction_key
|
||||
tx = await CreditTransaction.prisma().create(data=transaction_data)
|
||||
return user_balance + amount, tx.transactionKey
|
||||
|
||||
|
||||
class UsageTransactionMetadata(BaseModel):
|
||||
graph_exec_id: str | None = None
|
||||
graph_id: str | None = None
|
||||
node_id: str | None = None
|
||||
node_exec_id: str | None = None
|
||||
block_id: str | None = None
|
||||
block: str | None = None
|
||||
input: BlockInput | None = None
|
||||
|
||||
|
||||
class UserCredit(UserCreditBase):
|
||||
|
||||
@thread_cached
|
||||
def notification_client(self) -> NotificationManager:
|
||||
return get_service_client(NotificationManager)
|
||||
@@ -376,6 +359,7 @@ class UserCredit(UserCreditBase):
|
||||
await asyncio.to_thread(
|
||||
lambda: self.notification_client().queue_notification(
|
||||
NotificationEventDTO(
|
||||
recipient_email=settings.config.refund_notification_email,
|
||||
user_id=notification_request.user_id,
|
||||
type=notification_type,
|
||||
data=notification_request.model_dump(),
|
||||
@@ -383,21 +367,89 @@ class UserCredit(UserCreditBase):
|
||||
)
|
||||
)
|
||||
|
||||
def _block_usage_cost(
|
||||
self,
|
||||
block: Block,
|
||||
input_data: BlockInput,
|
||||
data_size: float,
|
||||
run_time: float,
|
||||
) -> tuple[int, BlockInput]:
|
||||
block_costs = BLOCK_COSTS.get(type(block))
|
||||
if not block_costs:
|
||||
return 0, {}
|
||||
|
||||
for block_cost in block_costs:
|
||||
if not self._is_cost_filter_match(block_cost.cost_filter, input_data):
|
||||
continue
|
||||
|
||||
if block_cost.cost_type == BlockCostType.RUN:
|
||||
return block_cost.cost_amount, block_cost.cost_filter
|
||||
|
||||
if block_cost.cost_type == BlockCostType.SECOND:
|
||||
return (
|
||||
int(run_time * block_cost.cost_amount),
|
||||
block_cost.cost_filter,
|
||||
)
|
||||
|
||||
if block_cost.cost_type == BlockCostType.BYTE:
|
||||
return (
|
||||
int(data_size * block_cost.cost_amount),
|
||||
block_cost.cost_filter,
|
||||
)
|
||||
|
||||
return 0, {}
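# Illustrative sketch (not from the diff above): how the three cost types above
# resolve for a matching BlockCost with cost_amount=5 (made-up numbers):
#   BlockCostType.RUN    -> 5                      # flat fee per execution
#   BlockCostType.SECOND -> int(run_time * 5)      # e.g. run_time=12.0 -> 60
#   BlockCostType.BYTE   -> int(data_size * 5)     # e.g. data_size=300 -> 1500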
def _is_cost_filter_match(
|
||||
self, cost_filter: BlockInput, input_data: BlockInput
|
||||
) -> bool:
|
||||
"""
|
||||
Filter rules:
|
||||
- If cost_filter is an object, then check if cost_filter is the subset of input_data
|
||||
- Otherwise, check if cost_filter is equal to input_data.
|
||||
- Undefined, null, and empty string are considered as equal.
|
||||
"""
|
||||
if not isinstance(cost_filter, dict) or not isinstance(input_data, dict):
|
||||
return cost_filter == input_data
|
||||
|
||||
return all(
|
||||
(not input_data.get(k) and not v)
|
||||
or (input_data.get(k) and self._is_cost_filter_match(v, input_data[k]))
|
||||
for k, v in cost_filter.items()
|
||||
)
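# Illustrative sketch (not from the diff above): filter-matching examples with
# made-up inputs:
#   _is_cost_filter_match({"model": "gpt-4o"}, {"model": "gpt-4o", "prompt": "hi"})  # True: subset match
#   _is_cost_filter_match({"model": "gpt-4o"}, {"model": "o3-mini"})                 # False: values differ
#   _is_cost_filter_match({"api_key": None}, {"api_key": ""})                        # True: None, "" and missing count as equal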
async def spend_credits(
|
||||
self,
|
||||
user_id: str,
|
||||
cost: int,
|
||||
metadata: UsageTransactionMetadata,
|
||||
entry: NodeExecutionEntry,
|
||||
data_size: float,
|
||||
run_time: float,
|
||||
) -> int:
|
||||
block = get_block(entry.block_id)
|
||||
if not block:
|
||||
raise ValueError(f"Block not found: {entry.block_id}")
|
||||
|
||||
cost, matching_filter = self._block_usage_cost(
|
||||
block=block, input_data=entry.data, data_size=data_size, run_time=run_time
|
||||
)
|
||||
if cost == 0:
|
||||
return 0
|
||||
|
||||
balance, _ = await self._add_transaction(
|
||||
user_id=user_id,
|
||||
user_id=entry.user_id,
|
||||
amount=-cost,
|
||||
transaction_type=CreditTransactionType.USAGE,
|
||||
metadata=Json(metadata.model_dump()),
|
||||
metadata=Json(
|
||||
UsageTransactionMetadata(
|
||||
graph_exec_id=entry.graph_exec_id,
|
||||
graph_id=entry.graph_id,
|
||||
node_id=entry.node_id,
|
||||
node_exec_id=entry.node_exec_id,
|
||||
block_id=entry.block_id,
|
||||
block=block.name,
|
||||
input=matching_filter,
|
||||
).model_dump()
|
||||
),
|
||||
)
|
||||
user_id = entry.user_id
|
||||
|
||||
# Auto top-up if balance is below threshold.
|
||||
auto_top_up = await get_auto_top_up(user_id)
|
||||
@@ -407,7 +459,7 @@ class UserCredit(UserCreditBase):
|
||||
user_id=user_id,
|
||||
amount=auto_top_up.amount,
|
||||
# Avoid multiple auto top-ups within the same graph execution.
|
||||
key=f"AUTO-TOP-UP-{user_id}-{metadata.graph_exec_id}",
|
||||
key=f"AUTO-TOP-UP-{user_id}-{entry.graph_exec_id}",
|
||||
ceiling_balance=auto_top_up.threshold,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -416,29 +468,11 @@ class UserCredit(UserCreditBase):
|
||||
f"Auto top-up failed for user {user_id}, balance: {balance}, amount: {auto_top_up.amount}, error: {e}"
|
||||
)
|
||||
|
||||
return balance
|
||||
return cost
|
||||
|
||||
async def top_up_credits(self, user_id: str, amount: int):
|
||||
await self._top_up_credits(user_id, amount)
|
||||
|
||||
async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
|
||||
key = f"REWARD-{user_id}-{step.value}"
|
||||
if not await CreditTransaction.prisma().find_first(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"transactionKey": key,
|
||||
}
|
||||
):
|
||||
await self._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=credits,
|
||||
transaction_type=CreditTransactionType.GRANT,
|
||||
transaction_key=key,
|
||||
metadata=Json(
|
||||
{"reason": f"Reward for completing {step.value} onboarding step."}
|
||||
),
|
||||
)
|
||||
|
||||
async def top_up_refund(
|
||||
self, user_id: str, transaction_key: str, metadata: dict[str, str]
|
||||
) -> int:
|
||||
@@ -457,15 +491,15 @@ class UserCredit(UserCreditBase):
|
||||
|
||||
try:
|
||||
refund_request = await CreditRefundRequest.prisma().create(
|
||||
data=CreditRefundRequestCreateInput(
|
||||
id=refund_key,
|
||||
transactionKey=transaction_key,
|
||||
userId=user_id,
|
||||
amount=amount,
|
||||
reason=metadata.get("reason", ""),
|
||||
status=CreditRefundRequestStatus.PENDING,
|
||||
result="The refund request is under review.",
|
||||
)
|
||||
data={
|
||||
"id": refund_key,
|
||||
"transactionKey": transaction_key,
|
||||
"userId": user_id,
|
||||
"amount": amount,
|
||||
"reason": metadata.get("reason", ""),
|
||||
"status": CreditRefundRequestStatus.PENDING,
|
||||
"result": "The refund request is under review.",
|
||||
}
|
||||
)
|
||||
except UniqueViolationError:
|
||||
raise ValueError(
|
||||
@@ -729,8 +763,10 @@ class UserCredit(UserCreditBase):
|
||||
ui_mode="hosted",
|
||||
payment_intent_data={"setup_future_usage": "off_session"},
|
||||
saved_payment_method_options={"payment_method_save": "enabled"},
|
||||
success_url=base_url + "/profile/credits?topup=success",
|
||||
cancel_url=base_url + "/profile/credits?topup=cancel",
|
||||
success_url=settings.config.frontend_base_url
|
||||
+ "/profile/credits?topup=success",
|
||||
cancel_url=settings.config.frontend_base_url
|
||||
+ "/profile/credits?topup=cancel",
|
||||
allow_promotion_codes=True,
|
||||
)
|
||||
|
||||
@@ -805,6 +841,7 @@ class UserCredit(UserCreditBase):
|
||||
transaction_time_ceiling: datetime | None = None,
|
||||
transaction_type: str | None = None,
|
||||
) -> TransactionHistory:
|
||||
|
||||
transactions_filter: CreditTransactionWhereInput = {
|
||||
"userId": user_id,
|
||||
"isActive": True,
|
||||
@@ -926,9 +963,6 @@ class DisabledUserCredit(UserCreditBase):
|
||||
async def top_up_credits(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def onboarding_reward(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def top_up_intent(self, *args, **kwargs) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ import logging
|
||||
import os
|
||||
import zlib
|
||||
from contextlib import asynccontextmanager
|
||||
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
|
||||
from uuid import uuid4
|
||||
|
||||
from dotenv import load_dotenv
|
||||
@@ -16,36 +15,7 @@ load_dotenv()
|
||||
PRISMA_SCHEMA = os.getenv("PRISMA_SCHEMA", "schema.prisma")
|
||||
os.environ["PRISMA_SCHEMA_PATH"] = PRISMA_SCHEMA
|
||||
|
||||
|
||||
def add_param(url: str, key: str, value: str) -> str:
|
||||
p = urlparse(url)
|
||||
qs = dict(parse_qsl(p.query))
|
||||
qs[key] = value
|
||||
return urlunparse(p._replace(query=urlencode(qs)))
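# Illustrative sketch (not from the diff above): usage with a made-up URL;
# existing query params are preserved and a repeated key is overwritten:
#   add_param("postgresql://localhost:5432/db", "connection_limit", "10")
#   # -> "postgresql://localhost:5432/db?connection_limit=10"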
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
|
||||
|
||||
CONN_LIMIT = os.getenv("DB_CONNECTION_LIMIT")
|
||||
if CONN_LIMIT:
|
||||
DATABASE_URL = add_param(DATABASE_URL, "connection_limit", CONN_LIMIT)
|
||||
|
||||
CONN_TIMEOUT = os.getenv("DB_CONNECT_TIMEOUT")
|
||||
if CONN_TIMEOUT:
|
||||
DATABASE_URL = add_param(DATABASE_URL, "connect_timeout", CONN_TIMEOUT)
|
||||
|
||||
POOL_TIMEOUT = os.getenv("DB_POOL_TIMEOUT")
|
||||
if POOL_TIMEOUT:
|
||||
DATABASE_URL = add_param(DATABASE_URL, "pool_timeout", POOL_TIMEOUT)
|
||||
|
||||
HTTP_TIMEOUT = int(POOL_TIMEOUT) if POOL_TIMEOUT else None
|
||||
|
||||
prisma = Prisma(
|
||||
auto_register=True,
|
||||
http={"timeout": HTTP_TIMEOUT},
|
||||
datasource={"url": DATABASE_URL},
|
||||
)
|
||||
|
||||
prisma = Prisma(auto_register=True)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -62,10 +32,10 @@ async def connect():
|
||||
|
||||
# A connection acquired from a pooler like Supabase can still be handed to the
# db client yet have subsequent queries on it rejected.
|
||||
# try:
|
||||
# await prisma.execute_raw("SELECT 1")
|
||||
# except Exception as e:
|
||||
# raise ConnectionError("Failed to connect to Prisma.") from e
|
||||
try:
|
||||
await prisma.execute_raw("SELECT 1")
|
||||
except Exception as e:
|
||||
raise ConnectionError("Failed to connect to Prisma.") from e
|
||||
|
||||
|
||||
@conn_retry("Prisma", "Releasing connection")
|
||||
@@ -89,7 +59,7 @@ async def transaction():
|
||||
async def locked_transaction(key: str):
|
||||
lock_key = zlib.crc32(key.encode("utf-8"))
|
||||
async with transaction() as tx:
|
||||
await tx.execute_raw("SELECT pg_advisory_xact_lock($1)", lock_key)
|
||||
await tx.execute_raw(f"SELECT pg_advisory_xact_lock({lock_key})")
|
||||
yield tx
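# Illustrative sketch (not from the diff above): both variants take the same
# transaction-scoped advisory lock; one binds the crc32 key as a query
# parameter, the other interpolates the integer key into the SQL string.
#   zlib.crc32("usr-credit-123".encode("utf-8"))   # deterministic 32-bit int key
#   SELECT pg_advisory_xact_lock(<key>)            -- held until the transaction ends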
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -1,9 +1,4 @@
from typing import cast

import prisma.enums
import prisma.types

from backend.blocks.io import IO_BLOCK_IDs
import prisma

AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
|
||||
"Input": True,
|
||||
@@ -13,61 +8,27 @@ AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
|
||||
}
|
||||
|
||||
AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
|
||||
"Nodes": {"include": AGENT_NODE_INCLUDE}
|
||||
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
|
||||
}
|
||||
|
||||
EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
|
||||
"Input": True,
|
||||
"Output": True,
|
||||
"Node": True,
|
||||
"GraphExecution": True,
|
||||
}
|
||||
|
||||
MAX_NODE_EXECUTIONS_FETCH = 1000
|
||||
|
||||
GRAPH_EXECUTION_INCLUDE_WITH_NODES: prisma.types.AgentGraphExecutionInclude = {
|
||||
"NodeExecutions": {
|
||||
"include": {
|
||||
"Input": True,
|
||||
"Output": True,
|
||||
"Node": True,
|
||||
"GraphExecution": True,
|
||||
},
|
||||
"order_by": [
|
||||
{"queuedTime": "desc"},
|
||||
# Fallback: incomplete executions have no queuedTime.
|
||||
{"addedTime": "desc"},
|
||||
],
|
||||
"take": MAX_NODE_EXECUTIONS_FETCH, # Avoid loading excessive node executions.
|
||||
}
|
||||
"AgentNode": True,
|
||||
"AgentGraphExecution": True,
|
||||
}
|
||||
|
||||
GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
|
||||
"NodeExecutions": {
|
||||
**cast(
|
||||
prisma.types.FindManyAgentNodeExecutionArgsFromAgentGraphExecution,
|
||||
GRAPH_EXECUTION_INCLUDE_WITH_NODES["NodeExecutions"],
|
||||
),
|
||||
"where": {
|
||||
"Node": {"is": {"AgentBlock": {"is": {"id": {"in": IO_BLOCK_IDs}}}}},
|
||||
"NOT": [{"executionStatus": prisma.enums.AgentExecutionStatus.INCOMPLETE}],
|
||||
},
|
||||
"AgentNodeExecutions": {
|
||||
"include": {
|
||||
"Input": True,
|
||||
"Output": True,
|
||||
"AgentNode": True,
|
||||
"AgentGraphExecution": True,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
|
||||
"AgentNodes": {"include": AGENT_NODE_INCLUDE}
|
||||
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
|
||||
}
|
||||
|
||||
|
||||
def library_agent_include(user_id: str) -> prisma.types.LibraryAgentInclude:
|
||||
return {
|
||||
"AgentGraph": {
|
||||
"include": {
|
||||
**AGENT_GRAPH_INCLUDE,
|
||||
"Executions": {"where": {"userId": user_id}},
|
||||
}
|
||||
},
|
||||
"Creator": True,
|
||||
}
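A hedged sketch of how one of these include maps is typically consumed; the function, the graph_id parameter, and the assumption that an AgentGraph Prisma model is queryable this way are illustrative, not taken from this diff:

import prisma.models

async def fetch_graph_with_nodes(graph_id: str):
    # Eagerly load nodes (and their Input/Output links) in a single query.
    return await prisma.models.AgentGraph.prisma().find_first(
        where={"id": graph_id},
        include=AGENT_GRAPH_INCLUDE,
    )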
|
||||
|
||||
@@ -3,14 +3,12 @@ from typing import TYPE_CHECKING, AsyncGenerator, Optional
|
||||
|
||||
from prisma import Json
|
||||
from prisma.models import IntegrationWebhook
|
||||
from prisma.types import IntegrationWebhookCreateInput
|
||||
from pydantic import Field, computed_field
|
||||
|
||||
from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
|
||||
from backend.data.queue import AsyncRedisEventBus
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks.utils import webhook_ingress_url
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from .db import BaseDbModel
|
||||
|
||||
@@ -67,35 +65,28 @@ class Webhook(BaseDbModel):
|
||||
|
||||
async def create_webhook(webhook: Webhook) -> Webhook:
|
||||
created_webhook = await IntegrationWebhook.prisma().create(
|
||||
data=IntegrationWebhookCreateInput(
|
||||
id=webhook.id,
|
||||
userId=webhook.user_id,
|
||||
provider=webhook.provider.value,
|
||||
credentialsId=webhook.credentials_id,
|
||||
webhookType=webhook.webhook_type,
|
||||
resource=webhook.resource,
|
||||
events=webhook.events,
|
||||
config=Json(webhook.config),
|
||||
secret=webhook.secret,
|
||||
providerWebhookId=webhook.provider_webhook_id,
|
||||
)
|
||||
data={
|
||||
"id": webhook.id,
|
||||
"userId": webhook.user_id,
|
||||
"provider": webhook.provider.value,
|
||||
"credentialsId": webhook.credentials_id,
|
||||
"webhookType": webhook.webhook_type,
|
||||
"resource": webhook.resource,
|
||||
"events": webhook.events,
|
||||
"config": Json(webhook.config),
|
||||
"secret": webhook.secret,
|
||||
"providerWebhookId": webhook.provider_webhook_id,
|
||||
}
|
||||
)
|
||||
return Webhook.from_db(created_webhook)
|
||||
|
||||
|
||||
async def get_webhook(webhook_id: str) -> Webhook:
|
||||
"""
|
||||
⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints.
|
||||
|
||||
Raises:
|
||||
NotFoundError: if no record with the given ID exists
|
||||
"""
|
||||
webhook = await IntegrationWebhook.prisma().find_unique(
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
webhook = await IntegrationWebhook.prisma().find_unique_or_raise(
|
||||
where={"id": webhook_id},
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE,
|
||||
)
|
||||
if not webhook:
|
||||
raise NotFoundError(f"Webhook #{webhook_id} not found")
|
||||
return Webhook.from_db(webhook)
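A small usage sketch against the variant above that raises NotFoundError (webhook_id is hypothetical; the find_unique_or_raise variant lets Prisma raise its own error instead):

from backend.util.exceptions import NotFoundError

async def print_webhook_events(webhook_id: str) -> None:
    try:
        webhook = await get_webhook(webhook_id)
    except NotFoundError:
        print(f"No webhook with id {webhook_id}")
        return
    print(webhook.events)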
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -13,7 +12,6 @@ from typing import (
|
||||
Generic,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
get_args,
|
||||
@@ -143,20 +141,17 @@ def SchemaField(
|
||||
secret: bool = False,
|
||||
exclude: bool = False,
|
||||
hidden: Optional[bool] = None,
|
||||
depends_on: Optional[list[str]] = None,
|
||||
ge: Optional[float] = None,
|
||||
le: Optional[float] = None,
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
discriminator: Optional[str] = None,
|
||||
json_schema_extra: Optional[dict[str, Any]] = None,
|
||||
depends_on: list[str] | None = None,
|
||||
image_upload: Optional[bool] = None,
|
||||
image_output: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
if default is PydanticUndefined and default_factory is None:
|
||||
advanced = False
|
||||
elif advanced is None:
|
||||
advanced = True
|
||||
|
||||
json_schema_extra = {
|
||||
json_extra = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"placeholder": placeholder,
|
||||
@@ -164,7 +159,8 @@ def SchemaField(
|
||||
"advanced": advanced,
|
||||
"hidden": hidden,
|
||||
"depends_on": depends_on,
|
||||
**(json_schema_extra or {}),
|
||||
"image_upload": image_upload,
|
||||
"image_output": image_output,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
@@ -176,12 +172,8 @@ def SchemaField(
|
||||
title=title,
|
||||
description=description,
|
||||
exclude=exclude,
|
||||
ge=ge,
|
||||
le=le,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
discriminator=discriminator,
|
||||
json_schema_extra=json_schema_extra,
|
||||
json_schema_extra=json_extra,
|
||||
**kwargs,
|
||||
) # type: ignore
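The dict comprehension above is what keeps unset options out of the generated JSON schema: anything left as None is dropped before being handed to Field as json_schema_extra. The same pattern in isolation:

def build_extras(**candidates):
    # Keep only the options that were explicitly provided (i.e. not None).
    return {k: v for k, v in candidates.items() if v is not None}

print(build_extras(placeholder="Enter a name", advanced=None, hidden=False))
# {'placeholder': 'Enter a name', 'hidden': False}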
|
||||
|
||||
|
||||
@@ -302,7 +294,9 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
)
|
||||
field_schema = model.jsonschema()["properties"][field_name]
|
||||
try:
|
||||
schema_extra = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
|
||||
schema_extra = _CredentialsFieldSchemaExtra[CP, CT].model_validate(
|
||||
field_schema
|
||||
)
|
||||
except ValidationError as e:
|
||||
if "Field required [type=missing" not in str(e):
|
||||
raise
|
||||
@@ -328,90 +322,14 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
)
|
||||
|
||||
|
||||
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
class _CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]):
|
||||
# TODO: move discrimination mechanism out of CredentialsField (frontend + backend)
|
||||
provider: frozenset[CP] = Field(..., alias="credentials_provider")
|
||||
supported_types: frozenset[CT] = Field(..., alias="credentials_types")
|
||||
required_scopes: Optional[frozenset[str]] = Field(None, alias="credentials_scopes")
|
||||
credentials_provider: list[CP]
|
||||
credentials_scopes: Optional[list[str]] = None
|
||||
credentials_types: list[CT]
|
||||
discriminator: Optional[str] = None
|
||||
discriminator_mapping: Optional[dict[str, CP]] = None
|
||||
|
||||
@classmethod
|
||||
def combine(
|
||||
cls, *fields: tuple[CredentialsFieldInfo[CP, CT], T]
|
||||
) -> Sequence[tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
|
||||
"""
|
||||
Combines multiple CredentialsFieldInfo objects into as few as possible.
|
||||
|
||||
Rules:
|
||||
- Items can only be combined if they have the same supported credentials types
|
||||
and the same supported providers.
|
||||
- When combining items, the `required_scopes` of the result is a join
|
||||
of the `required_scopes` of the original items.
|
||||
|
||||
Params:
|
||||
*fields: (CredentialsFieldInfo, key) objects to group and combine
|
||||
|
||||
Returns:
|
||||
A sequence of tuples containing combined CredentialsFieldInfo objects and
|
||||
the set of keys of the respective original items that were grouped together.
|
||||
"""
|
||||
if not fields:
|
||||
return []
|
||||
|
||||
# Group fields by their provider and supported_types
|
||||
grouped_fields: defaultdict[
|
||||
tuple[frozenset[CP], frozenset[CT]],
|
||||
list[tuple[T, CredentialsFieldInfo[CP, CT]]],
|
||||
] = defaultdict(list)
|
||||
|
||||
for field, key in fields:
|
||||
group_key = (frozenset(field.provider), frozenset(field.supported_types))
|
||||
grouped_fields[group_key].append((key, field))
|
||||
|
||||
# Combine fields within each group
|
||||
result: list[tuple[CredentialsFieldInfo[CP, CT], set[T]]] = []
|
||||
|
||||
for group in grouped_fields.values():
|
||||
# Start with the first field in the group
|
||||
_, combined = group[0]
|
||||
|
||||
# Track the keys that were combined
|
||||
combined_keys = {key for key, _ in group}
|
||||
|
||||
# Combine required_scopes from all fields in the group
|
||||
all_scopes = set()
|
||||
for _, field in group:
|
||||
if field.required_scopes:
|
||||
all_scopes.update(field.required_scopes)
|
||||
|
||||
# Create a new combined field
|
||||
result.append(
|
||||
(
|
||||
CredentialsFieldInfo[CP, CT](
|
||||
credentials_provider=combined.provider,
|
||||
credentials_types=combined.supported_types,
|
||||
credentials_scopes=frozenset(all_scopes) or None,
|
||||
discriminator=combined.discriminator,
|
||||
discriminator_mapping=combined.discriminator_mapping,
|
||||
),
|
||||
combined_keys,
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def discriminate(self, discriminator_value: Any) -> CredentialsFieldInfo:
|
||||
if not (self.discriminator and self.discriminator_mapping):
|
||||
return self
|
||||
|
||||
discriminator_value = self.discriminator_mapping[discriminator_value]
|
||||
return CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([discriminator_value]),
|
||||
credentials_types=self.supported_types,
|
||||
credentials_scopes=self.required_scopes,
|
||||
)
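To make the grouping rule in combine() concrete, here is a self-contained toy version of the same idea; the provider/type/scope tuples stand in for CredentialsFieldInfo objects and are invented for illustration:

from collections import defaultdict

fields = [
    (("github",), ("api_key",), {"repo"}, "field_a"),
    (("github",), ("api_key",), {"workflow"}, "field_b"),
    (("google",), ("oauth2",), set(), "field_c"),
]

# Group by (providers, supported types); only identical groups may be merged.
groups = defaultdict(list)
for providers, types, scopes, key in fields:
    groups[(frozenset(providers), frozenset(types))].append((scopes, key))

# Merging a group unions the required scopes and remembers which keys were merged.
for (providers, types), members in groups.items():
    all_scopes = set().union(*(scopes for scopes, _ in members))
    keys = {key for _, key in members}
    print(sorted(providers), sorted(types), sorted(all_scopes), sorted(keys))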
|
||||
|
||||
|
||||
def CredentialsField(
|
||||
required_scopes: set[str] = set(),
|
||||
@@ -484,46 +402,3 @@ class RefundRequest(BaseModel):
|
||||
status: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class NodeExecutionStats(BaseModel):
|
||||
"""Execution statistics for a node execution."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="allow",
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
error: Optional[Exception | str] = None
|
||||
walltime: float = 0
|
||||
cputime: float = 0
|
||||
input_size: int = 0
|
||||
output_size: int = 0
|
||||
llm_call_count: int = 0
|
||||
llm_retry_count: int = 0
|
||||
input_token_count: int = 0
|
||||
output_token_count: int = 0
|
||||
|
||||
|
||||
class GraphExecutionStats(BaseModel):
|
||||
"""Execution statistics for a graph execution."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="allow",
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
error: Optional[Exception | str] = None
|
||||
walltime: float = Field(
|
||||
default=0, description="Time between start and end of run (seconds)"
|
||||
)
|
||||
cputime: float = 0
|
||||
nodes_walltime: float = Field(
|
||||
default=0, description="Total node execution time (seconds)"
|
||||
)
|
||||
nodes_cputime: float = 0
|
||||
node_count: int = Field(default=0, description="Total number of node executions")
|
||||
node_error_count: int = Field(
|
||||
default=0, description="Total number of errors generated"
|
||||
)
|
||||
cost: int = Field(default=0, description="Total execution cost (cents)")
|
||||
|
||||
@@ -1,19 +1,15 @@
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Generic, Optional, TypeVar, Union
|
||||
|
||||
from prisma import Json
|
||||
from prisma.enums import NotificationType
|
||||
from prisma.models import NotificationEvent, UserNotificationBatch
|
||||
from prisma.types import (
|
||||
NotificationEventCreateInput,
|
||||
UserNotificationBatchCreateInput,
|
||||
UserNotificationBatchWhereInput,
|
||||
)
|
||||
from prisma.types import UserNotificationBatchWhereInput
|
||||
|
||||
# from backend.notifications.models import NotificationEvent
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
|
||||
from pydantic import BaseModel, EmailStr, Field, field_validator
|
||||
|
||||
from backend.server.v2.store.exceptions import DatabaseError
|
||||
|
||||
@@ -22,24 +18,18 @@ from .db import transaction
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
NotificationDataType_co = TypeVar(
|
||||
"NotificationDataType_co", bound="BaseNotificationData", covariant=True
|
||||
)
|
||||
SummaryParamsType_co = TypeVar(
|
||||
"SummaryParamsType_co", bound="BaseSummaryParams", covariant=True
|
||||
)
|
||||
T_co = TypeVar("T_co", bound="BaseNotificationData", covariant=True)
|
||||
|
||||
|
||||
class QueueType(Enum):
|
||||
class BatchingStrategy(Enum):
|
||||
IMMEDIATE = "immediate" # Send right away (errors, critical notifications)
|
||||
BATCH = "batch" # Batch for up to an hour (usage reports)
|
||||
SUMMARY = "summary" # Daily digest (summary notifications)
|
||||
HOURLY = "hourly" # Batch for up to an hour (usage reports)
|
||||
DAILY = "daily" # Daily digest (summary notifications)
|
||||
BACKOFF = "backoff" # Backoff strategy (exponential backoff)
|
||||
ADMIN = "admin" # Admin notifications (errors, critical notifications)
|
||||
|
||||
|
||||
class BaseNotificationData(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
pass
|
||||
|
||||
|
||||
class AgentRunData(BaseNotificationData):
|
||||
@@ -48,7 +38,7 @@ class AgentRunData(BaseNotificationData):
|
||||
execution_time: float
|
||||
node_count: int = Field(..., description="Number of nodes executed")
|
||||
graph_id: str
|
||||
outputs: list[dict[str, Any]] = Field(..., description="Outputs of the agent")
|
||||
outputs: dict[str, Any] = Field(..., description="Outputs of the agent")
|
||||
|
||||
|
||||
class ZeroBalanceData(BaseNotificationData):
|
||||
@@ -56,21 +46,12 @@ class ZeroBalanceData(BaseNotificationData):
|
||||
last_transaction_time: datetime
|
||||
top_up_link: str
|
||||
|
||||
@field_validator("last_transaction_time")
|
||||
@classmethod
|
||||
def validate_timezone(cls, value: datetime):
|
||||
if value.tzinfo is None:
|
||||
raise ValueError("datetime must have timezone information")
|
||||
return value
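The validator above is a recurring pattern in these models: reject naive datetimes so every timestamp carries timezone information. A standalone sketch of the same pattern (model and field names are illustrative):

from datetime import datetime, timezone
from pydantic import BaseModel, field_validator

class Event(BaseModel):
    occurred_at: datetime

    @field_validator("occurred_at")
    @classmethod
    def require_timezone(cls, value: datetime) -> datetime:
        if value.tzinfo is None:
            raise ValueError("datetime must have timezone information")
        return value

Event(occurred_at=datetime.now(tz=timezone.utc))  # ok
# Event(occurred_at=datetime.now())               # raises ValidationError: naive datetime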
|
||||
|
||||
|
||||
class LowBalanceData(BaseNotificationData):
|
||||
agent_name: str = Field(..., description="Name of the agent")
|
||||
current_balance: float = Field(
|
||||
..., description="Current balance in credits (100 = $1)"
|
||||
)
|
||||
billing_page_link: str = Field(..., description="Link to billing page")
|
||||
shortfall: float = Field(..., description="Amount of credits needed to continue")
|
||||
current_balance: float
|
||||
threshold_amount: float
|
||||
top_up_link: str
|
||||
recent_usage: float = Field(..., description="Usage in the last 24 hours")
|
||||
|
||||
|
||||
class BlockExecutionFailedData(BaseNotificationData):
|
||||
@@ -91,13 +72,6 @@ class ContinuousAgentErrorData(BaseNotificationData):
|
||||
error_time: datetime
|
||||
attempts: int = Field(..., description="Number of retry attempts made")
|
||||
|
||||
@field_validator("start_time", "error_time")
|
||||
@classmethod
|
||||
def validate_timezone(cls, value: datetime):
|
||||
if value.tzinfo is None:
|
||||
raise ValueError("datetime must have timezone information")
|
||||
return value
|
||||
|
||||
|
||||
class BaseSummaryData(BaseNotificationData):
|
||||
total_credits_used: float
|
||||
@@ -110,53 +84,18 @@ class BaseSummaryData(BaseNotificationData):
|
||||
cost_breakdown: dict[str, float]
|
||||
|
||||
|
||||
class BaseSummaryParams(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class DailySummaryParams(BaseSummaryParams):
|
||||
date: datetime
|
||||
|
||||
@field_validator("date")
|
||||
def validate_timezone(cls, value):
|
||||
if value.tzinfo is None:
|
||||
raise ValueError("datetime must have timezone information")
|
||||
return value
|
||||
|
||||
|
||||
class WeeklySummaryParams(BaseSummaryParams):
|
||||
start_date: datetime
|
||||
end_date: datetime
|
||||
|
||||
@field_validator("start_date", "end_date")
|
||||
def validate_timezone(cls, value):
|
||||
if value.tzinfo is None:
|
||||
raise ValueError("datetime must have timezone information")
|
||||
return value
|
||||
|
||||
|
||||
class DailySummaryData(BaseSummaryData):
|
||||
date: datetime
|
||||
|
||||
@field_validator("date")
|
||||
def validate_timezone(cls, value):
|
||||
if value.tzinfo is None:
|
||||
raise ValueError("datetime must have timezone information")
|
||||
return value
|
||||
|
||||
|
||||
class WeeklySummaryData(BaseSummaryData):
|
||||
start_date: datetime
|
||||
end_date: datetime
|
||||
|
||||
@field_validator("start_date", "end_date")
|
||||
def validate_timezone(cls, value):
|
||||
if value.tzinfo is None:
|
||||
raise ValueError("datetime must have timezone information")
|
||||
return value
|
||||
week_number: int
|
||||
year: int
|
||||
|
||||
|
||||
class MonthlySummaryData(BaseNotificationData):
|
||||
class MonthlySummaryData(BaseSummaryData):
|
||||
month: int
|
||||
year: int
|
||||
|
||||
@@ -180,10 +119,6 @@ NotificationData = Annotated[
|
||||
BlockExecutionFailedData,
|
||||
ContinuousAgentErrorData,
|
||||
MonthlySummaryData,
|
||||
WeeklySummaryData,
|
||||
DailySummaryData,
|
||||
RefundRequestData,
|
||||
BaseSummaryData,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
@@ -193,25 +128,19 @@ class NotificationEventDTO(BaseModel):
|
||||
user_id: str
|
||||
type: NotificationType
|
||||
data: dict
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
recipient_email: Optional[str] = None
|
||||
retry_count: int = 0
|
||||
|
||||
|
||||
class SummaryParamsEventDTO(BaseModel):
|
||||
class NotificationEventModel(BaseModel, Generic[T_co]):
|
||||
user_id: str
|
||||
type: NotificationType
|
||||
data: dict
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
|
||||
|
||||
|
||||
class NotificationEventModel(BaseModel, Generic[NotificationDataType_co]):
|
||||
user_id: str
|
||||
type: NotificationType
|
||||
data: NotificationDataType_co
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
|
||||
data: T_co
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
|
||||
@property
|
||||
def strategy(self) -> QueueType:
|
||||
def strategy(self) -> BatchingStrategy:
|
||||
return NotificationTypeOverride(self.type).strategy
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@@ -225,14 +154,7 @@ class NotificationEventModel(BaseModel, Generic[NotificationDataType_co]):
|
||||
return NotificationTypeOverride(self.type).template
|
||||
|
||||
|
||||
class SummaryParamsEventModel(BaseModel, Generic[SummaryParamsType_co]):
|
||||
user_id: str
|
||||
type: NotificationType
|
||||
data: SummaryParamsType_co
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
|
||||
|
||||
|
||||
def get_notif_data_type(
|
||||
def get_data_type(
|
||||
notification_type: NotificationType,
|
||||
) -> type[BaseNotificationData]:
|
||||
return {
|
||||
@@ -249,20 +171,11 @@ def get_notif_data_type(
|
||||
}[notification_type]
|
||||
|
||||
|
||||
def get_summary_params_type(
|
||||
notification_type: NotificationType,
|
||||
) -> type[BaseSummaryParams]:
|
||||
return {
|
||||
NotificationType.DAILY_SUMMARY: DailySummaryParams,
|
||||
NotificationType.WEEKLY_SUMMARY: WeeklySummaryParams,
|
||||
}[notification_type]
|
||||
|
||||
|
||||
class NotificationBatch(BaseModel):
|
||||
user_id: str
|
||||
events: list[NotificationEvent]
|
||||
strategy: QueueType
|
||||
last_update: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
|
||||
strategy: BatchingStrategy
|
||||
last_update: datetime = datetime.now()
|
||||
|
||||
|
||||
class NotificationResult(BaseModel):
|
||||
@@ -275,22 +188,23 @@ class NotificationTypeOverride:
|
||||
self.notification_type = notification_type
|
||||
|
||||
@property
|
||||
def strategy(self) -> QueueType:
|
||||
def strategy(self) -> BatchingStrategy:
|
||||
BATCHING_RULES = {
|
||||
# These are batched by the notification service
|
||||
NotificationType.AGENT_RUN: QueueType.BATCH,
|
||||
NotificationType.AGENT_RUN: BatchingStrategy.IMMEDIATE,
|
||||
# These are batched by the notification service, but with a backoff strategy
|
||||
NotificationType.ZERO_BALANCE: QueueType.BACKOFF,
|
||||
NotificationType.LOW_BALANCE: QueueType.IMMEDIATE,
|
||||
NotificationType.BLOCK_EXECUTION_FAILED: QueueType.BACKOFF,
|
||||
NotificationType.CONTINUOUS_AGENT_ERROR: QueueType.BACKOFF,
|
||||
NotificationType.DAILY_SUMMARY: QueueType.SUMMARY,
|
||||
NotificationType.WEEKLY_SUMMARY: QueueType.SUMMARY,
|
||||
NotificationType.MONTHLY_SUMMARY: QueueType.SUMMARY,
|
||||
NotificationType.REFUND_REQUEST: QueueType.ADMIN,
|
||||
NotificationType.REFUND_PROCESSED: QueueType.ADMIN,
|
||||
NotificationType.ZERO_BALANCE: BatchingStrategy.BACKOFF,
|
||||
NotificationType.LOW_BALANCE: BatchingStrategy.BACKOFF,
|
||||
NotificationType.BLOCK_EXECUTION_FAILED: BatchingStrategy.BACKOFF,
|
||||
NotificationType.CONTINUOUS_AGENT_ERROR: BatchingStrategy.BACKOFF,
|
||||
# These aren't batched by the notification service, so we send them right away
|
||||
NotificationType.DAILY_SUMMARY: BatchingStrategy.IMMEDIATE,
|
||||
NotificationType.WEEKLY_SUMMARY: BatchingStrategy.IMMEDIATE,
|
||||
NotificationType.MONTHLY_SUMMARY: BatchingStrategy.IMMEDIATE,
|
||||
NotificationType.REFUND_REQUEST: BatchingStrategy.IMMEDIATE,
|
||||
NotificationType.REFUND_PROCESSED: BatchingStrategy.IMMEDIATE,
|
||||
}
|
||||
return BATCHING_RULES.get(self.notification_type, QueueType.IMMEDIATE)
|
||||
return BATCHING_RULES.get(self.notification_type, BatchingStrategy.HOURLY)
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
@@ -340,51 +254,12 @@ class NotificationPreference(BaseModel):
|
||||
)
|
||||
daily_limit: int = 10 # Max emails per day
|
||||
emails_sent_today: int = 0
|
||||
last_reset_date: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
|
||||
class UserNotificationEventDTO(BaseModel):
|
||||
type: NotificationType
|
||||
data: dict
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@staticmethod
|
||||
def from_db(model: NotificationEvent) -> "UserNotificationEventDTO":
|
||||
return UserNotificationEventDTO(
|
||||
type=model.type,
|
||||
data=dict(model.data),
|
||||
created_at=model.createdAt,
|
||||
updated_at=model.updatedAt,
|
||||
)
|
||||
|
||||
|
||||
class UserNotificationBatchDTO(BaseModel):
|
||||
user_id: str
|
||||
type: NotificationType
|
||||
notifications: list[UserNotificationEventDTO]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@staticmethod
|
||||
def from_db(model: UserNotificationBatch) -> "UserNotificationBatchDTO":
|
||||
return UserNotificationBatchDTO(
|
||||
user_id=model.userId,
|
||||
type=model.type,
|
||||
notifications=[
|
||||
UserNotificationEventDTO.from_db(notification)
|
||||
for notification in model.Notifications or []
|
||||
],
|
||||
created_at=model.createdAt,
|
||||
updated_at=model.updatedAt,
|
||||
)
|
||||
last_reset_date: datetime = Field(default_factory=datetime.now)
|
||||
|
||||
|
||||
def get_batch_delay(notification_type: NotificationType) -> timedelta:
|
||||
return {
|
||||
NotificationType.AGENT_RUN: timedelta(minutes=60),
|
||||
NotificationType.AGENT_RUN: timedelta(seconds=1),
|
||||
NotificationType.ZERO_BALANCE: timedelta(minutes=60),
|
||||
NotificationType.LOW_BALANCE: timedelta(minutes=60),
|
||||
NotificationType.BLOCK_EXECUTION_FAILED: timedelta(minutes=60),
|
||||
@@ -395,17 +270,19 @@ def get_batch_delay(notification_type: NotificationType) -> timedelta:
|
||||
async def create_or_add_to_user_notification_batch(
|
||||
user_id: str,
|
||||
notification_type: NotificationType,
|
||||
notification_data: NotificationEventModel,
|
||||
) -> UserNotificationBatchDTO:
|
||||
data: str, # type: 'NotificationEventModel'
|
||||
) -> dict:
|
||||
try:
|
||||
logger.info(
|
||||
f"Creating or adding to notification batch for {user_id} with type {notification_type} and data {notification_data}"
|
||||
f"Creating or adding to notification batch for {user_id} with type {notification_type} and data {data}"
|
||||
)
|
||||
if not notification_data.data:
|
||||
raise ValueError("Notification data must be provided")
|
||||
|
||||
notification_data = NotificationEventModel[
|
||||
get_data_type(notification_type)
|
||||
].model_validate_json(data)
|
||||
|
||||
# Serialize the data
|
||||
json_data: Json = Json(notification_data.data.model_dump())
|
||||
json_data: Json = Json(notification_data.data.model_dump_json())
|
||||
|
||||
# First try to find existing batch
|
||||
existing_batch = await UserNotificationBatch.prisma().find_unique(
|
||||
@@ -415,76 +292,70 @@ async def create_or_add_to_user_notification_batch(
|
||||
"type": notification_type,
|
||||
}
|
||||
},
|
||||
include={"Notifications": True},
|
||||
include={"notifications": True},
|
||||
)
|
||||
|
||||
if not existing_batch:
|
||||
async with transaction() as tx:
|
||||
notification_event = await tx.notificationevent.create(
|
||||
data=NotificationEventCreateInput(
|
||||
type=notification_type,
|
||||
data=json_data,
|
||||
)
|
||||
data={
|
||||
"type": notification_type,
|
||||
"data": json_data,
|
||||
}
|
||||
)
|
||||
|
||||
# Create new batch
|
||||
resp = await tx.usernotificationbatch.create(
|
||||
data=UserNotificationBatchCreateInput(
|
||||
userId=user_id,
|
||||
type=notification_type,
|
||||
Notifications={"connect": [{"id": notification_event.id}]},
|
||||
),
|
||||
include={"Notifications": True},
|
||||
data={
|
||||
"userId": user_id,
|
||||
"type": notification_type,
|
||||
"notifications": {"connect": [{"id": notification_event.id}]},
|
||||
},
|
||||
include={"notifications": True},
|
||||
)
|
||||
return UserNotificationBatchDTO.from_db(resp)
|
||||
return resp.model_dump()
|
||||
else:
|
||||
async with transaction() as tx:
|
||||
notification_event = await tx.notificationevent.create(
|
||||
data=NotificationEventCreateInput(
|
||||
type=notification_type,
|
||||
data=json_data,
|
||||
UserNotificationBatch={"connect": {"id": existing_batch.id}},
|
||||
)
|
||||
data={
|
||||
"type": notification_type,
|
||||
"data": json_data,
|
||||
"UserNotificationBatch": {"connect": {"id": existing_batch.id}},
|
||||
}
|
||||
)
|
||||
# Add to existing batch
|
||||
resp = await tx.usernotificationbatch.update(
|
||||
where={"id": existing_batch.id},
|
||||
data={
|
||||
"Notifications": {"connect": [{"id": notification_event.id}]}
|
||||
"notifications": {"connect": [{"id": notification_event.id}]}
|
||||
},
|
||||
include={"Notifications": True},
|
||||
include={"notifications": True},
|
||||
)
|
||||
if not resp:
|
||||
raise DatabaseError(
|
||||
f"Failed to add notification event {notification_event.id} to existing batch {existing_batch.id}"
|
||||
)
|
||||
return UserNotificationBatchDTO.from_db(resp)
|
||||
return resp.model_dump()
|
||||
except Exception as e:
|
||||
raise DatabaseError(
|
||||
f"Failed to create or add to notification batch for user {user_id} and type {notification_type}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
async def get_user_notification_oldest_message_in_batch(
|
||||
async def get_user_notification_last_message_in_batch(
|
||||
user_id: str,
|
||||
notification_type: NotificationType,
|
||||
) -> UserNotificationEventDTO | None:
|
||||
) -> NotificationEvent | None:
|
||||
try:
|
||||
batch = await UserNotificationBatch.prisma().find_first(
|
||||
where={"userId": user_id, "type": notification_type},
|
||||
include={"Notifications": True},
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
if not batch:
|
||||
return None
|
||||
if not batch.Notifications:
|
||||
if not batch.notifications:
|
||||
return None
|
||||
sorted_notifications = sorted(batch.Notifications, key=lambda x: x.createdAt)
|
||||
|
||||
return (
|
||||
UserNotificationEventDTO.from_db(sorted_notifications[0])
|
||||
if sorted_notifications
|
||||
else None
|
||||
)
|
||||
return batch.notifications[-1]
|
||||
except Exception as e:
|
||||
raise DatabaseError(
|
||||
f"Failed to get user notification last message in batch for user {user_id} and type {notification_type}: {e}"
|
||||
@@ -519,34 +390,13 @@ async def empty_user_notification_batch(
|
||||
async def get_user_notification_batch(
|
||||
user_id: str,
|
||||
notification_type: NotificationType,
|
||||
) -> UserNotificationBatchDTO | None:
|
||||
) -> UserNotificationBatch | None:
|
||||
try:
|
||||
batch = await UserNotificationBatch.prisma().find_first(
|
||||
return await UserNotificationBatch.prisma().find_first(
|
||||
where={"userId": user_id, "type": notification_type},
|
||||
include={"Notifications": True},
|
||||
include={"notifications": True},
|
||||
)
|
||||
return UserNotificationBatchDTO.from_db(batch) if batch else None
|
||||
except Exception as e:
|
||||
raise DatabaseError(
|
||||
f"Failed to get user notification batch for user {user_id} and type {notification_type}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
async def get_all_batches_by_type(
|
||||
notification_type: NotificationType,
|
||||
) -> list[UserNotificationBatchDTO]:
|
||||
try:
|
||||
batches = await UserNotificationBatch.prisma().find_many(
|
||||
where={
|
||||
"type": notification_type,
|
||||
"Notifications": {
|
||||
"some": {} # Only return batches with at least one notification
|
||||
},
|
||||
},
|
||||
include={"Notifications": True},
|
||||
)
|
||||
return [UserNotificationBatchDTO.from_db(batch) for batch in batches]
|
||||
except Exception as e:
|
||||
raise DatabaseError(
|
||||
f"Failed to get all batches by type {notification_type}: {e}"
|
||||
) from e
|
||||
|
||||
@@ -1,338 +0,0 @@
|
||||
import re
|
||||
from typing import Any, Optional
|
||||
|
||||
import prisma
|
||||
import pydantic
|
||||
from prisma import Json
|
||||
from prisma.enums import OnboardingStep
|
||||
from prisma.models import UserOnboarding
|
||||
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
|
||||
|
||||
from backend.data import db
|
||||
from backend.data.block import get_blocks
|
||||
from backend.data.credit import get_user_credit_model
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.server.v2.store.model import StoreAgentDetails
|
||||
|
||||
# Maps a user's usage-reason id to the categories to search when choosing which agents to show
|
||||
REASON_MAPPING: dict[str, list[str]] = {
|
||||
"content_marketing": ["writing", "marketing", "creative"],
|
||||
"business_workflow_automation": ["business", "productivity"],
|
||||
"data_research": ["data", "research"],
|
||||
"ai_innovation": ["development", "research"],
|
||||
"personal_productivity": ["personal", "productivity"],
|
||||
}
|
||||
POINTS_AGENT_COUNT = 50 # Number of agents to calculate points for
|
||||
MIN_AGENT_COUNT = 2 # Minimum number of marketplace agents to enable onboarding
|
||||
|
||||
user_credit = get_user_credit_model()
|
||||
|
||||
|
||||
class UserOnboardingUpdate(pydantic.BaseModel):
|
||||
completedSteps: Optional[list[OnboardingStep]] = None
|
||||
notificationDot: Optional[bool] = None
|
||||
notified: Optional[list[OnboardingStep]] = None
|
||||
usageReason: Optional[str] = None
|
||||
integrations: Optional[list[str]] = None
|
||||
otherIntegrations: Optional[str] = None
|
||||
selectedStoreListingVersionId: Optional[str] = None
|
||||
agentInput: Optional[dict[str, Any]] = None
|
||||
onboardingAgentExecutionId: Optional[str] = None
|
||||
|
||||
|
||||
async def get_user_onboarding(user_id: str):
|
||||
return await UserOnboarding.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": UserOnboardingCreateInput(userId=user_id),
|
||||
"update": {},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
|
||||
update: UserOnboardingUpdateInput = {}
|
||||
if data.completedSteps is not None:
|
||||
update["completedSteps"] = list(set(data.completedSteps))
|
||||
for step in (
|
||||
OnboardingStep.AGENT_NEW_RUN,
|
||||
OnboardingStep.GET_RESULTS,
|
||||
OnboardingStep.MARKETPLACE_ADD_AGENT,
|
||||
OnboardingStep.MARKETPLACE_RUN_AGENT,
|
||||
OnboardingStep.BUILDER_SAVE_AGENT,
|
||||
OnboardingStep.BUILDER_RUN_AGENT,
|
||||
):
|
||||
if step in data.completedSteps:
|
||||
await reward_user(user_id, step)
|
||||
if data.notificationDot is not None:
|
||||
update["notificationDot"] = data.notificationDot
|
||||
if data.notified is not None:
|
||||
update["notified"] = list(set(data.notified))
|
||||
if data.usageReason is not None:
|
||||
update["usageReason"] = data.usageReason
|
||||
if data.integrations is not None:
|
||||
update["integrations"] = data.integrations
|
||||
if data.otherIntegrations is not None:
|
||||
update["otherIntegrations"] = data.otherIntegrations
|
||||
if data.selectedStoreListingVersionId is not None:
|
||||
update["selectedStoreListingVersionId"] = data.selectedStoreListingVersionId
|
||||
if data.agentInput is not None:
|
||||
update["agentInput"] = Json(data.agentInput)
|
||||
if data.onboardingAgentExecutionId is not None:
|
||||
update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId
|
||||
|
||||
return await UserOnboarding.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, **update},
|
||||
"update": update,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def reward_user(user_id: str, step: OnboardingStep):
|
||||
async with db.locked_transaction(f"usr_trx_{user_id}-reward"):
|
||||
reward = 0
|
||||
match step:
|
||||
# Reward the user when they click New Run during onboarding,
# because they need credits before scheduling a run (the next step).
|
||||
case OnboardingStep.AGENT_NEW_RUN:
|
||||
reward = 300
|
||||
case OnboardingStep.GET_RESULTS:
|
||||
reward = 300
|
||||
case OnboardingStep.MARKETPLACE_ADD_AGENT:
|
||||
reward = 100
|
||||
case OnboardingStep.MARKETPLACE_RUN_AGENT:
|
||||
reward = 100
|
||||
case OnboardingStep.BUILDER_SAVE_AGENT:
|
||||
reward = 100
|
||||
case OnboardingStep.BUILDER_RUN_AGENT:
|
||||
reward = 100
|
||||
|
||||
if reward == 0:
|
||||
return
|
||||
|
||||
onboarding = await get_user_onboarding(user_id)
|
||||
|
||||
# Skip if already rewarded
|
||||
if step in onboarding.rewardedFor:
|
||||
return
|
||||
|
||||
onboarding.rewardedFor.append(step)
|
||||
await user_credit.onboarding_reward(user_id, reward, step)
|
||||
await UserOnboarding.prisma().update(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"completedSteps": list(set(onboarding.completedSteps + [step])),
|
||||
"rewardedFor": onboarding.rewardedFor,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def clean_and_split(text: str) -> list[str]:
|
||||
"""
|
||||
Removes all special characters from a string, truncates it to 100 characters,
|
||||
and splits it by whitespace and commas.
|
||||
|
||||
Args:
|
||||
text (str): The input string.
|
||||
|
||||
Returns:
|
||||
list[str]: A list of cleaned words.
|
||||
"""
|
||||
# Remove all special characters (keep only alphanumeric and whitespace)
|
||||
cleaned_text = re.sub(r"[^a-zA-Z0-9\s,]", "", text.strip()[:100])
|
||||
|
||||
# Split by whitespace and commas
|
||||
words = re.split(r"[\s,]+", cleaned_text)
|
||||
|
||||
# Remove empty strings from the list
|
||||
words = [word.lower() for word in words if word]
|
||||
|
||||
return words
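Example output of clean_and_split as defined above (the input string is made up):

print(clean_and_split("Content marketing, SEO & e-mail!"))
# ['content', 'marketing', 'seo', 'email']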
|
||||
|
||||
|
||||
def calculate_points(
|
||||
agent, categories: list[str], custom: list[str], integrations: list[str]
|
||||
) -> int:
|
||||
"""
|
||||
Calculates the total points for an agent based on the specified criteria.
|
||||
|
||||
Args:
|
||||
agent: The agent object.
|
||||
categories (list[str]): List of categories to match.
|
||||
custom (list[str]): Custom words (from the user's usage reason) to match in the description.
integrations (list[str]): Integration names to match in the description.
|
||||
|
||||
Returns:
|
||||
int: Total points for the agent.
|
||||
"""
|
||||
points = 0
|
||||
|
||||
# 1. Category Matches
|
||||
matched_categories = sum(
|
||||
1 for category in categories if category in agent.categories
|
||||
)
|
||||
points += matched_categories * 100
|
||||
|
||||
# 2. Description Word Matches
|
||||
description_words = agent.description.split() # Split description into words
|
||||
matched_words = sum(1 for word in custom if word in description_words)
|
||||
points += matched_words * 100
|
||||
|
||||
matched_words = sum(1 for word in integrations if word in description_words)
|
||||
points += matched_words * 50
|
||||
|
||||
# 3. Featured Bonus
|
||||
if agent.featured:
|
||||
points += 50
|
||||
|
||||
# 4. Rating Bonus
|
||||
points += agent.rating * 10
|
||||
|
||||
# 5. Runs Bonus
|
||||
runs_points = min(agent.runs / 1000 * 100, 100) # Cap at 100 points
|
||||
points += runs_points
|
||||
|
||||
return int(points)
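A hedged usage sketch of calculate_points; the SimpleNamespace below only mimics the attributes the function reads (categories, description, featured, rating, runs), and the numbers are invented:

from types import SimpleNamespace

agent = SimpleNamespace(
    categories=["marketing", "writing"],
    description="Writes marketing copy and blog posts",
    featured=True,
    rating=4.5,
    runs=2500,
)
# 1 category match (100) + 1 custom word match (100) + featured (50)
# + rating bonus (45) + capped runs bonus (100) = 395
print(calculate_points(agent, ["marketing"], ["blog"], ["notion"]))  # 395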
|
||||
|
||||
|
||||
def get_credentials_blocks() -> dict[str, str]:
|
||||
# Returns a dictionary of block id to credentials field name
|
||||
creds: dict[str, str] = {}
|
||||
blocks = get_blocks()
|
||||
for id, block in blocks.items():
|
||||
for field_name, field_info in block().input_schema.model_fields.items():
|
||||
if field_info.annotation == CredentialsMetaInput:
|
||||
creds[id] = field_name
|
||||
return creds
|
||||
|
||||
|
||||
CREDENTIALS_FIELDS: dict[str, str] = get_credentials_blocks()
|
||||
|
||||
|
||||
async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
|
||||
user_onboarding = await get_user_onboarding(user_id)
|
||||
categories = REASON_MAPPING.get(user_onboarding.usageReason or "", [])
|
||||
|
||||
where_clause: dict[str, Any] = {}
|
||||
|
||||
custom = clean_and_split((user_onboarding.usageReason or "").lower())
|
||||
|
||||
if categories:
|
||||
where_clause["OR"] = [
|
||||
{"categories": {"has": category}} for category in categories
|
||||
]
|
||||
else:
|
||||
where_clause["OR"] = [
|
||||
{"description": {"contains": word, "mode": "insensitive"}}
|
||||
for word in custom
|
||||
]
|
||||
|
||||
where_clause["OR"] += [
|
||||
{"description": {"contains": word, "mode": "insensitive"}}
|
||||
for word in user_onboarding.integrations
|
||||
]
|
||||
|
||||
storeAgents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=prisma.types.StoreAgentWhereInput(**where_clause),
|
||||
order=[
|
||||
{"featured": "desc"},
|
||||
{"runs": "desc"},
|
||||
{"rating": "desc"},
|
||||
],
|
||||
take=100,
|
||||
)
|
||||
|
||||
agentListings = await prisma.models.StoreListingVersion.prisma().find_many(
|
||||
where={
|
||||
"id": {"in": [agent.storeListingVersionId for agent in storeAgents]},
|
||||
},
|
||||
include={"AgentGraph": True},
|
||||
)
|
||||
|
||||
for listing in agentListings:
|
||||
agent = listing.AgentGraph
|
||||
if agent is None:
|
||||
continue
|
||||
graph = GraphModel.from_db(agent)
|
||||
# Remove agents with empty input schema
|
||||
if not graph.input_schema:
|
||||
storeAgents = [
|
||||
a for a in storeAgents if a.storeListingVersionId != listing.id
|
||||
]
|
||||
continue
|
||||
|
||||
# Remove agents with empty credentials
|
||||
# Get nodes from this agent that have credentials
|
||||
nodes = await prisma.models.AgentNode.prisma().find_many(
|
||||
where={
|
||||
"agentGraphId": agent.id,
|
||||
"agentBlockId": {"in": list(CREDENTIALS_FIELDS.keys())},
|
||||
},
|
||||
)
|
||||
for node in nodes:
|
||||
block_id = node.agentBlockId
|
||||
field_name = CREDENTIALS_FIELDS[block_id]
|
||||
# If there are no credentials or they are empty, remove the agent
|
||||
# FIXME ignores default values
|
||||
if (
|
||||
field_name not in node.constantInput
|
||||
or node.constantInput[field_name] is None
|
||||
):
|
||||
storeAgents = [
|
||||
a for a in storeAgents if a.storeListingVersionId != listing.id
|
||||
]
|
||||
break
|
||||
|
||||
# If there are fewer than 2 agents, add more agents to the list
|
||||
if len(storeAgents) < 2:
|
||||
storeAgents += await prisma.models.StoreAgent.prisma().find_many(
|
||||
where={
|
||||
"listing_id": {"not_in": [agent.listing_id for agent in storeAgents]},
|
||||
},
|
||||
order=[
|
||||
{"featured": "desc"},
|
||||
{"runs": "desc"},
|
||||
{"rating": "desc"},
|
||||
],
|
||||
take=2 - len(storeAgents),
|
||||
)
|
||||
|
||||
# Calculate points for the first X agents and choose the top 2
|
||||
agent_points = []
|
||||
for agent in storeAgents[:POINTS_AGENT_COUNT]:
|
||||
points = calculate_points(
|
||||
agent, categories, custom, user_onboarding.integrations
|
||||
)
|
||||
agent_points.append((agent, points))
|
||||
|
||||
agent_points.sort(key=lambda x: x[1], reverse=True)
|
||||
recommended_agents = [agent for agent, _ in agent_points[:2]]
|
||||
|
||||
return [
|
||||
StoreAgentDetails(
|
||||
store_listing_version_id=agent.storeListingVersionId,
|
||||
slug=agent.slug,
|
||||
agent_name=agent.agent_name,
|
||||
agent_video=agent.agent_video or "",
|
||||
agent_image=agent.agent_image,
|
||||
creator=agent.creator_username,
|
||||
creator_avatar=agent.creator_avatar,
|
||||
sub_heading=agent.sub_heading,
|
||||
description=agent.description,
|
||||
categories=agent.categories,
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
versions=agent.versions,
|
||||
last_updated=agent.updated_at,
|
||||
)
|
||||
for agent in recommended_agents
|
||||
]
|
||||
|
||||
|
||||
async def onboarding_enabled() -> bool:
|
||||
count = await prisma.models.StoreAgent.prisma().count(take=MIN_AGENT_COUNT + 1)
|
||||
|
||||
# Onboarding is enabled if there are at least 2 agents in the store
|
||||
return count >= MIN_AGENT_COUNT
|
||||
@@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Any, AsyncGenerator, Generator, Generic, Optional, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -12,6 +14,13 @@ from backend.data import redis
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DateTimeEncoder(json.JSONEncoder):
|
||||
def default(self, o):
|
||||
if isinstance(o, datetime):
|
||||
return o.isoformat()
|
||||
return super().default(o)
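Quick usage sketch of the DateTimeEncoder defined above:

import json
from datetime import datetime, timezone

payload = {"ts": datetime(2025, 1, 1, tzinfo=timezone.utc), "ok": True}
print(json.dumps(payload, cls=DateTimeEncoder))
# {"ts": "2025-01-01T00:00:00+00:00", "ok": true}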
|
||||
|
||||
|
||||
M = TypeVar("M", bound=BaseModel)
|
||||
|
||||
|
||||
@@ -23,14 +32,10 @@ class BaseRedisEventBus(Generic[M], ABC):
|
||||
def event_bus_name(self) -> str:
|
||||
pass
|
||||
|
||||
@property
|
||||
def Message(self) -> type["_EventPayloadWrapper[M]"]:
|
||||
return _EventPayloadWrapper[self.Model]
|
||||
|
||||
def _serialize_message(self, item: M, channel_key: str) -> tuple[str, str]:
|
||||
message = self.Message(payload=item).model_dump_json()
|
||||
message = json.dumps(item.model_dump(), cls=DateTimeEncoder)
|
||||
channel_name = f"{self.event_bus_name}/{channel_key}"
|
||||
logger.debug(f"[{channel_name}] Publishing an event to Redis {message}")
|
||||
logger.info(f"[{channel_name}] Publishing an event to Redis {message}")
|
||||
return message, channel_name
|
||||
|
||||
def _deserialize_message(self, msg: Any, channel_key: str) -> M | None:
|
||||
@@ -38,8 +43,9 @@ class BaseRedisEventBus(Generic[M], ABC):
|
||||
if msg["type"] != message_type:
|
||||
return None
|
||||
try:
|
||||
logger.debug(f"[{channel_key}] Consuming an event from Redis {msg['data']}")
|
||||
return self.Message.model_validate_json(msg["data"]).payload
|
||||
data = json.loads(msg["data"])
|
||||
logger.info(f"Consuming an event from Redis {data}")
|
||||
return self.Model(**data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse event result from Redis {msg} {e}")
|
||||
|
||||
@@ -51,16 +57,9 @@ class BaseRedisEventBus(Generic[M], ABC):
|
||||
return pubsub, full_channel_name
|
||||
|
||||
|
||||
class _EventPayloadWrapper(BaseModel, Generic[M]):
|
||||
"""
|
||||
Wrapper model to allow `RedisEventBus.Model` to be a discriminated union
|
||||
of multiple event types.
|
||||
"""
|
||||
|
||||
payload: M
|
||||
|
||||
|
||||
class RedisEventBus(BaseRedisEventBus[M], ABC):
|
||||
Model: type[M]
|
||||
|
||||
@property
|
||||
def connection(self) -> redis.Redis:
|
||||
return redis.get_redis()
|
||||
@@ -86,6 +85,8 @@ class RedisEventBus(BaseRedisEventBus[M], ABC):
|
||||
|
||||
|
||||
class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
|
||||
Model: type[M]
|
||||
|
||||
@property
|
||||
async def connection(self) -> redis.AsyncRedis:
|
||||
return await redis.get_redis_async()
|
||||
|
||||
@@ -4,18 +4,10 @@ from enum import Enum
|
||||
from typing import Awaitable, Optional
|
||||
|
||||
import aio_pika
|
||||
import aio_pika.exceptions as aio_ex
|
||||
import pika
|
||||
import pika.adapters.blocking_connection
|
||||
from pika.exceptions import AMQPError
|
||||
from pika.spec import BasicProperties
|
||||
from pydantic import BaseModel
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from backend.util.retry import conn_retry
|
||||
from backend.util.settings import Settings
|
||||
@@ -169,12 +161,6 @@ class SyncRabbitMQ(RabbitMQBase):
|
||||
routing_key=queue.routing_key or queue.name,
|
||||
)
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type((AMQPError, ConnectionError)),
|
||||
wait=wait_random_exponential(multiplier=1, max=5),
|
||||
stop=stop_after_attempt(5),
|
||||
reraise=True,
|
||||
)
|
||||
def publish_message(
|
||||
self,
|
||||
routing_key: str,
|
||||
@@ -272,12 +258,6 @@ class AsyncRabbitMQ(RabbitMQBase):
|
||||
exchange, routing_key=queue.routing_key or queue.name
|
||||
)
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type((aio_ex.AMQPError, ConnectionError)),
|
||||
wait=wait_random_exponential(multiplier=1, max=5),
|
||||
stop=stop_after_attempt(5),
|
||||
reraise=True,
|
||||
)
|
||||
async def publish_message(
|
||||
self,
|
||||
routing_key: str,
|
||||
|
||||
@@ -1,24 +1,19 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, cast
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from autogpt_libs.auth.models import DEFAULT_USER_ID
|
||||
from fastapi import HTTPException
|
||||
from prisma import Json
|
||||
from prisma.enums import NotificationType
|
||||
from prisma.models import User
|
||||
from prisma.types import JsonFilter, UserCreateInput, UserUpdateInput
|
||||
from prisma.types import UserUpdateInput
|
||||
|
||||
from backend.data.db import prisma
|
||||
from backend.data.model import UserIntegrations, UserMetadata, UserMetadataRaw
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
from backend.server.v2.store.exceptions import DatabaseError
|
||||
from backend.util.encryption import JSONCryptor
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -36,11 +31,11 @@ async def get_or_create_user(user_data: dict) -> User:
|
||||
user = await prisma.user.find_unique(where={"id": user_id})
|
||||
if not user:
|
||||
user = await prisma.user.create(
|
||||
data=UserCreateInput(
|
||||
id=user_id,
|
||||
email=user_email,
|
||||
name=user_data.get("user_metadata", {}).get("name"),
|
||||
)
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": user_email,
|
||||
"name": user_data.get("user_metadata", {}).get("name"),
|
||||
}
|
||||
)
|
||||
|
||||
return User.model_validate(user)
|
||||
@@ -63,14 +58,6 @@ async def get_user_email_by_id(user_id: str) -> Optional[str]:
|
||||
raise DatabaseError(f"Failed to get user email for user {user_id}: {e}") from e
|
||||
|
||||
|
||||
async def get_user_by_email(email: str) -> Optional[User]:
|
||||
try:
|
||||
user = await prisma.user.find_unique(where={"email": email})
|
||||
return User.model_validate(user) if user else None
|
||||
except Exception as e:
|
||||
raise DatabaseError(f"Failed to get user by email {email}: {e}") from e
|
||||
|
||||
|
||||
async def update_user_email(user_id: str, email: str):
|
||||
try:
|
||||
await prisma.user.update(where={"id": user_id}, data={"email": email})
|
||||
@@ -84,11 +71,11 @@ async def create_default_user() -> Optional[User]:
|
||||
user = await prisma.user.find_unique(where={"id": DEFAULT_USER_ID})
|
||||
if not user:
|
||||
user = await prisma.user.create(
|
||||
data=UserCreateInput(
|
||||
id=DEFAULT_USER_ID,
|
||||
email="default@example.com",
|
||||
name="Default User",
|
||||
)
|
||||
data={
|
||||
"id": DEFAULT_USER_ID,
|
||||
"email": "default@example.com",
|
||||
"name": "Default User",
|
||||
}
|
||||
)
|
||||
return User.model_validate(user)
|
||||
|
||||
@@ -135,21 +122,16 @@ async def migrate_and_encrypt_user_integrations():
|
||||
"""Migrate integration credentials and OAuth states from metadata to integrations column."""
|
||||
users = await User.prisma().find_many(
|
||||
where={
|
||||
"metadata": cast(
|
||||
JsonFilter,
|
||||
{
|
||||
"path": ["integration_credentials"],
|
||||
"not": Json(
|
||||
{"a": "yolo"}
|
||||
), # bogus value works to check if key exists
|
||||
},
|
||||
)
|
||||
"metadata": {
|
||||
"path": ["integration_credentials"],
|
||||
"not": Json({"a": "yolo"}), # bogus value works to check if key exists
|
||||
} # type: ignore
|
||||
}
|
||||
)
|
||||
logger.info(f"Migrating integration credentials for {len(users)} users")
|
||||
|
||||
for user in users:
|
||||
raw_metadata = cast(dict, user.metadata)
|
||||
raw_metadata = cast(UserMetadataRaw, user.metadata)
|
||||
metadata = UserMetadata.model_validate(raw_metadata)
|
||||
|
||||
# Get existing integrations data
|
||||
@@ -165,6 +147,7 @@ async def migrate_and_encrypt_user_integrations():
|
||||
await update_user_integrations(user_id=user.id, data=integrations)
|
||||
|
||||
# Remove from metadata
|
||||
raw_metadata = dict(raw_metadata)
|
||||
raw_metadata.pop("integration_credentials", None)
|
||||
raw_metadata.pop("integration_oauth_states", None)
|
||||
|
||||
@@ -317,85 +300,3 @@ async def update_user_notification_preference(
|
||||
raise DatabaseError(
|
||||
f"Failed to update user notification preference for user {user_id}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
async def set_user_email_verification(user_id: str, verified: bool) -> None:
|
||||
"""Set the email verification status for a user."""
|
||||
try:
|
||||
await User.prisma().update(
|
||||
where={"id": user_id},
|
||||
data={"emailVerified": verified},
|
||||
)
|
||||
except Exception as e:
|
||||
raise DatabaseError(
|
||||
f"Failed to set email verification status for user {user_id}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
async def get_user_email_verification(user_id: str) -> bool:
|
||||
"""Get the email verification status for a user."""
|
||||
try:
|
||||
user = await User.prisma().find_unique_or_raise(
|
||||
where={"id": user_id},
|
||||
)
|
||||
return user.emailVerified
|
||||
except Exception as e:
|
||||
raise DatabaseError(
|
||||
f"Failed to get email verification status for user {user_id}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def generate_unsubscribe_link(user_id: str) -> str:
|
||||
"""Generate a link to unsubscribe from all notifications"""
|
||||
# Create an HMAC using a secret key
|
||||
secret_key = Settings().secrets.unsubscribe_secret_key
|
||||
signature = hmac.new(
|
||||
secret_key.encode("utf-8"), user_id.encode("utf-8"), hashlib.sha256
|
||||
).digest()
|
||||
|
||||
# Create a token that combines the user_id and signature
|
||||
token = base64.urlsafe_b64encode(
|
||||
f"{user_id}:{signature.hex()}".encode("utf-8")
|
||||
).decode("utf-8")
|
||||
logger.info(f"Generating unsubscribe link for user {user_id}")
|
||||
|
||||
base_url = Settings().config.platform_base_url
|
||||
return f"{base_url}/api/email/unsubscribe?token={quote_plus(token)}"
|
||||
|
||||
|
||||
async def unsubscribe_user_by_token(token: str) -> None:
|
||||
"""Unsubscribe a user from all notifications using the token"""
|
||||
try:
|
||||
# Decode the token
|
||||
decoded = base64.urlsafe_b64decode(token).decode("utf-8")
|
||||
user_id, received_signature_hex = decoded.split(":", 1)
|
||||
|
||||
# Verify the signature
|
||||
secret_key = Settings().secrets.unsubscribe_secret_key
|
||||
expected_signature = hmac.new(
|
||||
secret_key.encode("utf-8"), user_id.encode("utf-8"), hashlib.sha256
|
||||
).digest()
|
||||
|
||||
if not hmac.compare_digest(expected_signature.hex(), received_signature_hex):
|
||||
raise ValueError("Invalid token signature")
|
||||
|
||||
user = await get_user_by_id(user_id)
|
||||
await update_user_notification_preference(
|
||||
user.id,
|
||||
NotificationPreferenceDTO(
|
||||
email=user.email,
|
||||
daily_limit=0,
|
||||
preferences={
|
||||
NotificationType.AGENT_RUN: False,
|
||||
NotificationType.ZERO_BALANCE: False,
|
||||
NotificationType.LOW_BALANCE: False,
|
||||
NotificationType.BLOCK_EXECUTION_FAILED: False,
|
||||
NotificationType.CONTINUOUS_AGENT_ERROR: False,
|
||||
NotificationType.DAILY_SUMMARY: False,
|
||||
NotificationType.WEEKLY_SUMMARY: False,
|
||||
NotificationType.MONTHLY_SUMMARY: False,
|
||||
},
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
raise DatabaseError(f"Failed to unsubscribe user by token {token}: {e}") from e
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
from backend.app import run_processes
|
||||
from backend.executor import ExecutionManager
|
||||
from backend.executor import DatabaseManager, ExecutionManager
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Run all the processes required for the AutoGPT-server REST API.
|
||||
"""
|
||||
run_processes(ExecutionManager())
|
||||
run_processes(
|
||||
DatabaseManager(),
|
||||
ExecutionManager(),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from .database import DatabaseManager
|
||||
from .manager import ExecutionManager
|
||||
from .scheduler import Scheduler
|
||||
from .scheduler import ExecutionScheduler
|
||||
|
||||
__all__ = [
|
||||
"DatabaseManager",
|
||||
"ExecutionManager",
|
||||
"Scheduler",
|
||||
"ExecutionScheduler",
|
||||
]
|
||||
|
||||
@@ -1,89 +1,73 @@
import logging
from functools import wraps
from typing import Any, Callable, Concatenate, Coroutine, ParamSpec, TypeVar, cast

from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.credit import get_user_credit_model
from backend.data.execution import (
    ExecutionResult,
    NodeExecutionEntry,
    RedisExecutionEventBus,
    create_graph_execution,
    get_graph_execution,
    get_incomplete_node_executions,
    get_latest_node_execution,
    get_node_execution_results,
    update_graph_execution_start_time,
    get_execution_results,
    get_incomplete_executions,
    get_latest_execution,
    update_execution_status,
    update_graph_execution_stats,
    update_node_execution_stats,
    update_node_execution_status,
    update_node_execution_status_batch,
    upsert_execution_input,
    upsert_execution_output,
)
from backend.data.graph import (
    get_connected_output_nodes,
    get_graph,
    get_graph_metadata,
    get_node,
)
from backend.data.notifications import (
    create_or_add_to_user_notification_batch,
    empty_user_notification_batch,
    get_all_batches_by_type,
    get_user_notification_batch,
    get_user_notification_oldest_message_in_batch,
)
from backend.data.graph import get_graph, get_node
from backend.data.user import (
    get_active_user_ids_in_timerange,
    get_user_email_by_id,
    get_user_email_verification,
    get_user_integrations,
    get_user_metadata,
    get_user_notification_preference,
    update_user_integrations,
    update_user_metadata,
)
from backend.util.service import AppService, exposed_run_and_wait
from backend.util.service import AppService, expose, register_pydantic_serializers
from backend.util.settings import Config

P = ParamSpec("P")
R = TypeVar("R")
config = Config()
_user_credit_model = get_user_credit_model()
logger = logging.getLogger(__name__)


async def _spend_credits(
    user_id: str, cost: int, metadata: UsageTransactionMetadata
) -> int:
    return await _user_credit_model.spend_credits(user_id, cost, metadata)


class DatabaseManager(AppService):

    def run_service(self) -> None:
        logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
        self.run_and_wait(db.connect())
        super().run_service()

    def cleanup(self):
        super().cleanup()
        logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
        self.run_and_wait(db.disconnect())
    def __init__(self):
        super().__init__()
        self.use_db = True
        self.use_redis = True
        self.event_queue = RedisExecutionEventBus()

    @classmethod
    def get_port(cls) -> int:
        return config.database_api_port

    @expose
    def send_execution_update(self, execution_result: ExecutionResult):
        self.event_queue.publish(execution_result)

    @staticmethod
    def exposed_run_and_wait(
        f: Callable[P, Coroutine[None, None, R]]
    ) -> Callable[Concatenate[object, P], R]:
        @expose
        @wraps(f)
        def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
            coroutine = f(*args, **kwargs)
            res = self.run_and_wait(coroutine)
            return res

        # Register serializers for annotations on bare function
        register_pydantic_serializers(f)

        return wrapper

    # Executions
    get_graph_execution = exposed_run_and_wait(get_graph_execution)
    create_graph_execution = exposed_run_and_wait(create_graph_execution)
    get_node_execution_results = exposed_run_and_wait(get_node_execution_results)
    get_incomplete_node_executions = exposed_run_and_wait(
        get_incomplete_node_executions
    )
    get_latest_node_execution = exposed_run_and_wait(get_latest_node_execution)
    update_node_execution_status = exposed_run_and_wait(update_node_execution_status)
    update_node_execution_status_batch = exposed_run_and_wait(
        update_node_execution_status_batch
    )
    update_graph_execution_start_time = exposed_run_and_wait(
        update_graph_execution_start_time
    )
    get_execution_results = exposed_run_and_wait(get_execution_results)
    get_incomplete_executions = exposed_run_and_wait(get_incomplete_executions)
    get_latest_execution = exposed_run_and_wait(get_latest_execution)
    update_execution_status = exposed_run_and_wait(update_execution_status)
    update_graph_execution_stats = exposed_run_and_wait(update_graph_execution_stats)
    update_node_execution_stats = exposed_run_and_wait(update_node_execution_stats)
    upsert_execution_input = exposed_run_and_wait(upsert_execution_input)
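Since most of this hunk is mechanical wiring, the `exposed_run_and_wait` staticmethod shown above may be easier to follow in isolation. The block below is a stripped-down, self-contained sketch of that pattern, not the project's actual service plumbing: `expose` is reduced to a tagging no-op, `run_and_wait` simply drives an event loop with `asyncio.run`, and `register_pydantic_serializers` is omitted.

import asyncio
from functools import wraps
from typing import Callable, Concatenate, Coroutine, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


def expose(func):
    # Stand-in for backend.util.service.expose: just tags the callable.
    func.__exposed__ = True
    return func


class MiniService:
    # Stand-in for AppService: runs a coroutine to completion on demand.
    def run_and_wait(self, coro):
        return asyncio.run(coro)

    @staticmethod
    def exposed_run_and_wait(
        f: Callable[P, Coroutine[None, None, R]]
    ) -> Callable[Concatenate[object, P], R]:
        @expose
        @wraps(f)
        def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
            # The async data-layer function becomes a blocking, exposed method.
            return self.run_and_wait(f(*args, **kwargs))

        return wrapper


async def fetch_value(key: str) -> str:
    # Hypothetical async helper standing in for the imported data-layer coroutines.
    return f"value-for-{key}"


class MiniDatabaseManager(MiniService):
    # Same assignment style as the "# Executions" block above.
    fetch_value = MiniService.exposed_run_and_wait(fetch_value)


if __name__ == "__main__":
    print(MiniDatabaseManager().fetch_value("graph-123"))  # -> value-for-graph-123

In the real module the same wrapping is applied to every imported coroutine, which is why the rest of the class body reduces to a flat list of assignments.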
@@ -92,35 +76,16 @@ class DatabaseManager(AppService):
    # Graphs
    get_node = exposed_run_and_wait(get_node)
    get_graph = exposed_run_and_wait(get_graph)
    get_connected_output_nodes = exposed_run_and_wait(get_connected_output_nodes)
    get_graph_metadata = exposed_run_and_wait(get_graph_metadata)

    # Credits
    spend_credits = exposed_run_and_wait(_spend_credits)
    user_credit_model = get_user_credit_model()
    spend_credits = cast(
        Callable[[Any, NodeExecutionEntry, float, float], int],
        exposed_run_and_wait(user_credit_model.spend_credits),
    )

    # User + User Metadata + User Integrations
    get_user_metadata = exposed_run_and_wait(get_user_metadata)
    update_user_metadata = exposed_run_and_wait(update_user_metadata)
    get_user_integrations = exposed_run_and_wait(get_user_integrations)
    update_user_integrations = exposed_run_and_wait(update_user_integrations)

    # User Comms - async
    get_active_user_ids_in_timerange = exposed_run_and_wait(
        get_active_user_ids_in_timerange
    )
    get_user_email_by_id = exposed_run_and_wait(get_user_email_by_id)
    get_user_email_verification = exposed_run_and_wait(get_user_email_verification)
    get_user_notification_preference = exposed_run_and_wait(
        get_user_notification_preference
    )

    # Notifications - async
    create_or_add_to_user_notification_batch = exposed_run_and_wait(
        create_or_add_to_user_notification_batch
    )
    empty_user_notification_batch = exposed_run_and_wait(empty_user_notification_batch)
    get_all_batches_by_type = exposed_run_and_wait(get_all_batches_by_type)
    get_user_notification_batch = exposed_run_and_wait(get_user_notification_batch)
    get_user_notification_oldest_message_in_batch = exposed_run_and_wait(
        get_user_notification_oldest_message_in_batch
    )
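Finally, a hedged sketch of how another process would presumably consume this facade. Only the exposed attribute names come from the diff; `get_service_client`, the argument names, and the comments are assumptions about the surrounding codebase, included purely for orientation.

# Hypothetical caller side (for example, inside the execution manager process).
# get_service_client and the argument names are assumed, not shown in this diff.
from backend.executor import DatabaseManager
from backend.util.service import get_service_client  # assumed helper

db = get_service_client(DatabaseManager)

# Each exposed_run_and_wait attribute behaves like a plain blocking method over RPC:
graph = db.get_graph(graph_id="...", user_id="...")  # hypothetical keyword arguments
results = db.get_node_execution_results("...")       # hypothetical graph execution id argument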
Some files were not shown because too many files have changed in this diff.