Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into zamilmajdy/fix-static-output-resolve
24 .github/dependabot.yml vendored
@@ -129,30 +129,6 @@ updates:
          - "minor"
          - "patch"

  # Submodules
  - package-ecosystem: "gitsubmodule"
    directory: "autogpt_platform/supabase"
    schedule:
      interval: "weekly"
    open-pull-requests-limit: 1
    target-branch: "dev"
    commit-message:
      prefix: "chore(platform/deps)"
      prefix-development: "chore(platform/deps-dev)"
    groups:
      production-dependencies:
        dependency-type: "production"
        update-types:
          - "minor"
          - "patch"
      development-dependencies:
        dependency-type: "development"
        update-types:
          - "minor"
          - "patch"

  # Docs
  - package-ecosystem: 'pip'
    directory: "docs/"
2 .github/workflows/platform-frontend-ci.yml vendored
@@ -82,7 +82,7 @@ jobs:

      - name: Copy default supabase .env
        run: |
          cp ../supabase/docker/.env.example ../.env
          cp ../.env.example ../.env

      - name: Copy backend .env
        run: |
3 .gitmodules vendored
@@ -1,6 +1,3 @@
[submodule "classic/forge/tests/vcr_cassettes"]
    path = classic/forge/tests/vcr_cassettes
    url = https://github.com/Significant-Gravitas/Auto-GPT-test-cassettes
[submodule "autogpt_platform/supabase"]
    path = autogpt_platform/supabase
    url = https://github.com/supabase/supabase.git
123 autogpt_platform/.env.example Normal file
@@ -0,0 +1,123 @@
|
||||
############
|
||||
# Secrets
|
||||
# YOU MUST CHANGE THESE BEFORE GOING INTO PRODUCTION
|
||||
############
|
||||
|
||||
POSTGRES_PASSWORD=your-super-secret-and-long-postgres-password
|
||||
JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
ANON_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJhbm9uIiwKICAgICJpc3MiOiAic3VwYWJhc2UtZGVtbyIsCiAgICAiaWF0IjogMTY0MTc2OTIwMCwKICAgICJleHAiOiAxNzk5NTM1NjAwCn0.dc_X5iR_VP_qT0zsiyj_I_OZ2T9FtRU2BBNWN8Bu4GE
|
||||
SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
|
||||
DASHBOARD_USERNAME=supabase
|
||||
DASHBOARD_PASSWORD=this_password_is_insecure_and_should_be_updated
|
||||
SECRET_KEY_BASE=UpNVntn3cDxHJpq99YMc1T1AQgQpc8kfYTuRgBiYa15BLrx8etQoXz3gZv1/u2oq
|
||||
VAULT_ENC_KEY=your-encryption-key-32-chars-min
|
||||
|
||||
|
||||
############
|
||||
# Database - You can change these to any PostgreSQL database that has logical replication enabled.
|
||||
############
|
||||
|
||||
POSTGRES_HOST=db
|
||||
POSTGRES_DB=postgres
|
||||
POSTGRES_PORT=5432
|
||||
# default user is postgres
|
||||
|
||||
|
||||
############
|
||||
# Supavisor -- Database pooler
|
||||
############
|
||||
POOLER_PROXY_PORT_TRANSACTION=6543
|
||||
POOLER_DEFAULT_POOL_SIZE=20
|
||||
POOLER_MAX_CLIENT_CONN=100
|
||||
POOLER_TENANT_ID=your-tenant-id
|
||||
|
||||
|
||||
############
|
||||
# API Proxy - Configuration for the Kong Reverse proxy.
|
||||
############
|
||||
|
||||
KONG_HTTP_PORT=8000
|
||||
KONG_HTTPS_PORT=8443
|
||||
|
||||
|
||||
############
|
||||
# API - Configuration for PostgREST.
|
||||
############
|
||||
|
||||
PGRST_DB_SCHEMAS=public,storage,graphql_public
|
||||
|
||||
|
||||
############
|
||||
# Auth - Configuration for the GoTrue authentication server.
|
||||
############
|
||||
|
||||
## General
|
||||
SITE_URL=http://localhost:3000
|
||||
ADDITIONAL_REDIRECT_URLS=
|
||||
JWT_EXPIRY=3600
|
||||
DISABLE_SIGNUP=false
|
||||
API_EXTERNAL_URL=http://localhost:8000
|
||||
|
||||
## Mailer Config
|
||||
MAILER_URLPATHS_CONFIRMATION="/auth/v1/verify"
|
||||
MAILER_URLPATHS_INVITE="/auth/v1/verify"
|
||||
MAILER_URLPATHS_RECOVERY="/auth/v1/verify"
|
||||
MAILER_URLPATHS_EMAIL_CHANGE="/auth/v1/verify"
|
||||
|
||||
## Email auth
|
||||
ENABLE_EMAIL_SIGNUP=true
|
||||
ENABLE_EMAIL_AUTOCONFIRM=false
|
||||
SMTP_ADMIN_EMAIL=admin@example.com
|
||||
SMTP_HOST=supabase-mail
|
||||
SMTP_PORT=2500
|
||||
SMTP_USER=fake_mail_user
|
||||
SMTP_PASS=fake_mail_password
|
||||
SMTP_SENDER_NAME=fake_sender
|
||||
ENABLE_ANONYMOUS_USERS=false
|
||||
|
||||
## Phone auth
|
||||
ENABLE_PHONE_SIGNUP=true
|
||||
ENABLE_PHONE_AUTOCONFIRM=true
|
||||
|
||||
|
||||
############
|
||||
# Studio - Configuration for the Dashboard
|
||||
############
|
||||
|
||||
STUDIO_DEFAULT_ORGANIZATION=Default Organization
|
||||
STUDIO_DEFAULT_PROJECT=Default Project
|
||||
|
||||
STUDIO_PORT=3000
|
||||
# replace if you intend to use Studio outside of localhost
|
||||
SUPABASE_PUBLIC_URL=http://localhost:8000
|
||||
|
||||
# Enable webp support
|
||||
IMGPROXY_ENABLE_WEBP_DETECTION=true
|
||||
|
||||
# Add your OpenAI API key to enable SQL Editor Assistant
|
||||
OPENAI_API_KEY=
|
||||
|
||||
|
||||
############
|
||||
# Functions - Configuration for Functions
|
||||
############
|
||||
# NOTE: VERIFY_JWT applies to all functions. Per-function VERIFY_JWT is not supported yet.
|
||||
FUNCTIONS_VERIFY_JWT=false
|
||||
|
||||
|
||||
############
|
||||
# Logs - Configuration for Logflare
|
||||
# Please refer to https://supabase.com/docs/reference/self-hosting-analytics/introduction
|
||||
############
|
||||
|
||||
LOGFLARE_LOGGER_BACKEND_API_KEY=your-super-secret-and-long-logflare-key
|
||||
|
||||
# Change vector.toml sinks to reflect this change
|
||||
LOGFLARE_API_KEY=your-super-secret-and-long-logflare-key
|
||||
|
||||
# Docker socket location - this value will differ depending on your OS
|
||||
DOCKER_SOCKET_LOCATION=/var/run/docker.sock
|
||||
|
||||
# Google Cloud Project details
|
||||
GOOGLE_PROJECT_ID=GOOGLE_PROJECT_ID
|
||||
GOOGLE_PROJECT_NUMBER=GOOGLE_PROJECT_NUMBER
|
||||
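The secrets above are the Supabase self-hosting demo defaults and must be replaced before production. As a rough sanity check, the copied `.env` can be inspected from Python; this is only a sketch, assuming `python-dotenv` and `PyJWT` are installed and that the demo `ANON_KEY` was signed with the demo `JWT_SECRET` (neither assumption is guaranteed by this file):
```
# Sketch: sanity-check a copied autogpt_platform/.env before starting the stack.
# Assumes python-dotenv and PyJWT are available; not part of this repository.
import os

import jwt  # PyJWT
from dotenv import load_dotenv

load_dotenv("autogpt_platform/.env")

jwt_secret = os.environ["JWT_SECRET"]
assert len(jwt_secret) >= 32, "JWT_SECRET must be at least 32 characters"

# ANON_KEY / SERVICE_ROLE_KEY are HS256 JWTs; if they were signed with
# JWT_SECRET, decoding succeeds and exposes the embedded role claim.
claims = jwt.decode(os.environ["ANON_KEY"], jwt_secret, algorithms=["HS256"])
print(claims.get("role"))  # expected: "anon"
```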
@@ -22,35 +22,29 @@ To run the AutoGPT Platform, follow these steps:

2. Run the following command:
   ```
   git submodule update --init --recursive --progress
   cp .env.example .env
   ```
   This command will initialize and update the submodules in the repository. The `supabase` folder will be cloned to the root directory.
   This command will copy the `.env.example` file to `.env`. You can modify the `.env` file to add your own environment variables.

3. Run the following command:
   ```
   cp supabase/docker/.env.example .env
   ```
   This command will copy the `.env.example` file to `.env` in the `supabase/docker` directory. You can modify the `.env` file to add your own environment variables.

4. Run the following command:
   ```
   docker compose up -d
   ```
   This command will start all the necessary backend services defined in the `docker-compose.yml` file in detached mode.

5. Navigate to `frontend` within the `autogpt_platform` directory:
4. Navigate to `frontend` within the `autogpt_platform` directory:
   ```
   cd frontend
   ```
   You will need to run your frontend application separately on your local machine.

6. Run the following command:
5. Run the following command:
   ```
   cp .env.example .env.local
   ```
   This command will copy the `.env.example` file to `.env.local` in the `frontend` directory. You can modify the `.env.local` within this folder to add your own environment variables for the frontend application.

7. Run the following command:
6. Run the following command:
   ```
   npm install
   npm run dev
@@ -61,7 +55,7 @@ To run the AutoGPT Platform, follow these steps:
   yarn install && yarn dev
   ```

8. Open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
7. Open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.

### Docker Compose Commands
40 autogpt_platform/autogpt_libs/poetry.lock generated
@@ -1476,30 +1476,30 @@ pyasn1 = ">=0.1.3"
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.9.6"
|
||||
version = "0.9.10"
|
||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "ruff-0.9.6-py3-none-linux_armv6l.whl", hash = "sha256:2f218f356dd2d995839f1941322ff021c72a492c470f0b26a34f844c29cdf5ba"},
|
||||
{file = "ruff-0.9.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b908ff4df65dad7b251c9968a2e4560836d8f5487c2f0cc238321ed951ea0504"},
|
||||
{file = "ruff-0.9.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:b109c0ad2ececf42e75fa99dc4043ff72a357436bb171900714a9ea581ddef83"},
|
||||
{file = "ruff-0.9.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1de4367cca3dac99bcbd15c161404e849bb0bfd543664db39232648dc00112dc"},
|
||||
{file = "ruff-0.9.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac3ee4d7c2c92ddfdaedf0bf31b2b176fa7aa8950efc454628d477394d35638b"},
|
||||
{file = "ruff-0.9.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5dc1edd1775270e6aa2386119aea692039781429f0be1e0949ea5884e011aa8e"},
|
||||
{file = "ruff-0.9.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:4a091729086dffa4bd070aa5dab7e39cc6b9d62eb2bef8f3d91172d30d599666"},
|
||||
{file = "ruff-0.9.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d1bbc6808bf7b15796cef0815e1dfb796fbd383e7dbd4334709642649625e7c5"},
|
||||
{file = "ruff-0.9.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:589d1d9f25b5754ff230dce914a174a7c951a85a4e9270613a2b74231fdac2f5"},
|
||||
{file = "ruff-0.9.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc61dd5131742e21103fbbdcad683a8813be0e3c204472d520d9a5021ca8b217"},
|
||||
{file = "ruff-0.9.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5e2d9126161d0357e5c8f30b0bd6168d2c3872372f14481136d13de9937f79b6"},
|
||||
{file = "ruff-0.9.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:68660eab1a8e65babb5229a1f97b46e3120923757a68b5413d8561f8a85d4897"},
|
||||
{file = "ruff-0.9.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c4cae6c4cc7b9b4017c71114115db0445b00a16de3bcde0946273e8392856f08"},
|
||||
{file = "ruff-0.9.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:19f505b643228b417c1111a2a536424ddde0db4ef9023b9e04a46ed8a1cb4656"},
|
||||
{file = "ruff-0.9.6-py3-none-win32.whl", hash = "sha256:194d8402bceef1b31164909540a597e0d913c0e4952015a5b40e28c146121b5d"},
|
||||
{file = "ruff-0.9.6-py3-none-win_amd64.whl", hash = "sha256:03482d5c09d90d4ee3f40d97578423698ad895c87314c4de39ed2af945633caa"},
|
||||
{file = "ruff-0.9.6-py3-none-win_arm64.whl", hash = "sha256:0e2bb706a2be7ddfea4a4af918562fdc1bcb16df255e5fa595bbd800ce322a5a"},
|
||||
{file = "ruff-0.9.6.tar.gz", hash = "sha256:81761592f72b620ec8fa1068a6fd00e98a5ebee342a3642efd84454f3031dca9"},
|
||||
{file = "ruff-0.9.10-py3-none-linux_armv6l.whl", hash = "sha256:eb4d25532cfd9fe461acc83498361ec2e2252795b4f40b17e80692814329e42d"},
|
||||
{file = "ruff-0.9.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:188a6638dab1aa9bb6228a7302387b2c9954e455fb25d6b4470cb0641d16759d"},
|
||||
{file = "ruff-0.9.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5284dcac6b9dbc2fcb71fdfc26a217b2ca4ede6ccd57476f52a587451ebe450d"},
|
||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:47678f39fa2a3da62724851107f438c8229a3470f533894b5568a39b40029c0c"},
|
||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:99713a6e2766b7a17147b309e8c915b32b07a25c9efd12ada79f217c9c778b3e"},
|
||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:524ee184d92f7c7304aa568e2db20f50c32d1d0caa235d8ddf10497566ea1a12"},
|
||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:df92aeac30af821f9acf819fc01b4afc3dfb829d2782884f8739fb52a8119a16"},
|
||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de42e4edc296f520bb84954eb992a07a0ec5a02fecb834498415908469854a52"},
|
||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d257f95b65806104b6b1ffca0ea53f4ef98454036df65b1eda3693534813ecd1"},
|
||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b60dec7201c0b10d6d11be00e8f2dbb6f40ef1828ee75ed739923799513db24c"},
|
||||
{file = "ruff-0.9.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:d838b60007da7a39c046fcdd317293d10b845001f38bcb55ba766c3875b01e43"},
|
||||
{file = "ruff-0.9.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:ccaf903108b899beb8e09a63ffae5869057ab649c1e9231c05ae354ebc62066c"},
|
||||
{file = "ruff-0.9.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f9567d135265d46e59d62dc60c0bfad10e9a6822e231f5b24032dba5a55be6b5"},
|
||||
{file = "ruff-0.9.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5f202f0d93738c28a89f8ed9eaba01b7be339e5d8d642c994347eaa81c6d75b8"},
|
||||
{file = "ruff-0.9.10-py3-none-win32.whl", hash = "sha256:bfb834e87c916521ce46b1788fbb8484966e5113c02df216680102e9eb960029"},
|
||||
{file = "ruff-0.9.10-py3-none-win_amd64.whl", hash = "sha256:f2160eeef3031bf4b17df74e307d4c5fb689a6f3a26a2de3f7ef4044e3c484f1"},
|
||||
{file = "ruff-0.9.10-py3-none-win_arm64.whl", hash = "sha256:5fd804c0327a5e5ea26615550e706942f348b197d5475ff34c19733aee4b2e69"},
|
||||
{file = "ruff-0.9.10.tar.gz", hash = "sha256:9bacb735d7bada9cfb0f2c227d3658fc443d90a727b47f206fb33f52f3c0eac7"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1929,4 +1929,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<4.0"
|
||||
content-hash = "f5cd0d1dafeb2b5c97d0ef27bef8a2235d4a1f54e3c60583d05ef582ac49c0e6"
|
||||
content-hash = "931772287f71c539575d601e6398423bf68e09ca87ae1a144057c7f5707cf978"
|
||||
|
||||
@@ -21,7 +21,7 @@ supabase = "^2.13.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
redis = "^5.2.1"
|
||||
ruff = "^0.9.6"
|
||||
ruff = "^0.9.10"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -2,88 +2,103 @@ import importlib
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Type, TypeVar
|
||||
|
||||
from backend.data.block import Block
|
||||
|
||||
# Dynamically load all modules under backend.blocks
|
||||
AVAILABLE_MODULES = []
|
||||
current_dir = Path(__file__).parent
|
||||
modules = [
|
||||
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
|
||||
for f in current_dir.rglob("*.py")
|
||||
if f.is_file() and f.name != "__init__.py"
|
||||
]
|
||||
for module in modules:
|
||||
if not re.match("^[a-z0-9_.]+$", module):
|
||||
raise ValueError(
|
||||
f"Block module {module} error: module name must be lowercase, "
|
||||
"and contain only alphanumeric characters and underscores."
|
||||
)
|
||||
|
||||
importlib.import_module(f".{module}", package=__name__)
|
||||
AVAILABLE_MODULES.append(module)
|
||||
|
||||
# Load all Block instances from the available modules
|
||||
AVAILABLE_BLOCKS: dict[str, Type[Block]] = {}
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.block import Block
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def all_subclasses(cls: Type[T]) -> list[Type[T]]:
|
||||
_AVAILABLE_BLOCKS: dict[str, type["Block"]] = {}
|
||||
|
||||
|
||||
def load_all_blocks() -> dict[str, type["Block"]]:
|
||||
from backend.data.block import Block
|
||||
|
||||
if _AVAILABLE_BLOCKS:
|
||||
return _AVAILABLE_BLOCKS
|
||||
|
||||
# Dynamically load all modules under backend.blocks
|
||||
AVAILABLE_MODULES = []
|
||||
current_dir = Path(__file__).parent
|
||||
modules = [
|
||||
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
|
||||
for f in current_dir.rglob("*.py")
|
||||
if f.is_file() and f.name != "__init__.py"
|
||||
]
|
||||
for module in modules:
|
||||
if not re.match("^[a-z0-9_.]+$", module):
|
||||
raise ValueError(
|
||||
f"Block module {module} error: module name must be lowercase, "
|
||||
"and contain only alphanumeric characters and underscores."
|
||||
)
|
||||
|
||||
importlib.import_module(f".{module}", package=__name__)
|
||||
AVAILABLE_MODULES.append(module)
|
||||
|
||||
# Load all Block instances from the available modules
|
||||
for block_cls in all_subclasses(Block):
|
||||
class_name = block_cls.__name__
|
||||
|
||||
if class_name.endswith("Base"):
|
||||
continue
|
||||
|
||||
if not class_name.endswith("Block"):
|
||||
raise ValueError(
|
||||
f"Block class {class_name} does not end with 'Block'. "
|
||||
"If you are creating an abstract class, "
|
||||
"please name the class with 'Base' at the end"
|
||||
)
|
||||
|
||||
block = block_cls.create()
|
||||
|
||||
if not isinstance(block.id, str) or len(block.id) != 36:
|
||||
raise ValueError(
|
||||
f"Block ID {block.name} error: {block.id} is not a valid UUID"
|
||||
)
|
||||
|
||||
if block.id in _AVAILABLE_BLOCKS:
|
||||
raise ValueError(
|
||||
f"Block ID {block.name} error: {block.id} is already in use"
|
||||
)
|
||||
|
||||
input_schema = block.input_schema.model_fields
|
||||
output_schema = block.output_schema.model_fields
|
||||
|
||||
# Make sure `error` field is a string in the output schema
|
||||
if "error" in output_schema and output_schema["error"].annotation is not str:
|
||||
raise ValueError(
|
||||
f"{block.name} `error` field in output_schema must be a string"
|
||||
)
|
||||
|
||||
# Ensure all fields in input_schema and output_schema are annotated SchemaFields
|
||||
for field_name, field in [*input_schema.items(), *output_schema.items()]:
|
||||
if field.annotation is None:
|
||||
raise ValueError(
|
||||
f"{block.name} has a field {field_name} that is not annotated"
|
||||
)
|
||||
if field.json_schema_extra is None:
|
||||
raise ValueError(
|
||||
f"{block.name} has a field {field_name} not defined as SchemaField"
|
||||
)
|
||||
|
||||
for field in block.input_schema.model_fields.values():
|
||||
if field.annotation is bool and field.default not in (True, False):
|
||||
raise ValueError(
|
||||
f"{block.name} has a boolean field with no default value"
|
||||
)
|
||||
|
||||
_AVAILABLE_BLOCKS[block.id] = block_cls
|
||||
|
||||
return _AVAILABLE_BLOCKS
|
||||
|
||||
|
||||
__all__ = ["load_all_blocks"]
|
||||
|
||||
|
||||
def all_subclasses(cls: type[T]) -> list[type[T]]:
|
||||
subclasses = cls.__subclasses__()
|
||||
for subclass in subclasses:
|
||||
subclasses += all_subclasses(subclass)
|
||||
return subclasses
|
||||
|
||||
|
||||
for block_cls in all_subclasses(Block):
|
||||
name = block_cls.__name__
|
||||
|
||||
if block_cls.__name__.endswith("Base"):
|
||||
continue
|
||||
|
||||
if not block_cls.__name__.endswith("Block"):
|
||||
raise ValueError(
|
||||
f"Block class {block_cls.__name__} does not end with 'Block', If you are creating an abstract class, please name the class with 'Base' at the end"
|
||||
)
|
||||
|
||||
block = block_cls.create()
|
||||
|
||||
if not isinstance(block.id, str) or len(block.id) != 36:
|
||||
raise ValueError(f"Block ID {block.name} error: {block.id} is not a valid UUID")
|
||||
|
||||
if block.id in AVAILABLE_BLOCKS:
|
||||
raise ValueError(f"Block ID {block.name} error: {block.id} is already in use")
|
||||
|
||||
input_schema = block.input_schema.model_fields
|
||||
output_schema = block.output_schema.model_fields
|
||||
|
||||
# Make sure `error` field is a string in the output schema
|
||||
if "error" in output_schema and output_schema["error"].annotation is not str:
|
||||
raise ValueError(
|
||||
f"{block.name} `error` field in output_schema must be a string"
|
||||
)
|
||||
|
||||
# Make sure all fields in input_schema and output_schema are annotated and has a value
|
||||
for field_name, field in [*input_schema.items(), *output_schema.items()]:
|
||||
if field.annotation is None:
|
||||
raise ValueError(
|
||||
f"{block.name} has a field {field_name} that is not annotated"
|
||||
)
|
||||
if field.json_schema_extra is None:
|
||||
raise ValueError(
|
||||
f"{block.name} has a field {field_name} not defined as SchemaField"
|
||||
)
|
||||
|
||||
for field in block.input_schema.model_fields.values():
|
||||
if field.annotation is bool and field.default not in (True, False):
|
||||
raise ValueError(f"{block.name} has a boolean field with no default value")
|
||||
|
||||
if block.disabled:
|
||||
continue
|
||||
|
||||
AVAILABLE_BLOCKS[block.id] = block_cls
|
||||
|
||||
__all__ = ["AVAILABLE_MODULES", "AVAILABLE_BLOCKS"]
|
||||
|
||||
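The refactor above replaces import-time block discovery with `load_all_blocks()`, which populates a module-level cache on first call and serves it afterwards. A stripped-down sketch of the same pattern, with hypothetical `BaseWidget` classes standing in for the real `Block` machinery:
```
# Sketch of the "discover subclasses once, then serve from a cache" pattern.
from typing import Type, TypeVar

T = TypeVar("T")


def all_subclasses(cls: Type[T]) -> list[Type[T]]:
    # Recursively collect transitive subclasses, like the helper above.
    subclasses = cls.__subclasses__()
    for subclass in subclasses:
        subclasses += all_subclasses(subclass)
    return subclasses


class BaseWidget:
    pass


class ButtonWidget(BaseWidget):
    pass


_REGISTRY: dict[str, Type[BaseWidget]] = {}


def load_all_widgets() -> dict[str, Type[BaseWidget]]:
    if _REGISTRY:  # later calls return the cache without re-scanning
        return _REGISTRY
    for cls in all_subclasses(BaseWidget):
        _REGISTRY[cls.__name__] = cls
    return _REGISTRY


print(load_all_widgets())  # {'ButtonWidget': <class '__main__.ButtonWidget'>}
```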
@@ -1,26 +1,13 @@
|
||||
import enum
|
||||
from typing import TYPE_CHECKING, Any, List
|
||||
from typing import Any, List
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockInput,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util import json
|
||||
from backend.util.file import MediaFile, store_media_file
|
||||
from backend.util.mock import MockObject
|
||||
from backend.util.text import TextFormatter
|
||||
from backend.util.type import convert
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.graph import Link
|
||||
|
||||
formatter = TextFormatter()
|
||||
|
||||
|
||||
class FileStoreBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
@@ -101,29 +88,6 @@ class StoreValueBlock(Block):
|
||||
yield "output", input_data.data or input_data.input
|
||||
|
||||
|
||||
class PrintToConsoleBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
text: str = SchemaField(description="The text to print to the console.")
|
||||
|
||||
class Output(BlockSchema):
|
||||
status: str = SchemaField(description="The status of the print operation.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c",
|
||||
description="Print the given text to the console, this is used for a debugging purpose.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=PrintToConsoleBlock.Input,
|
||||
output_schema=PrintToConsoleBlock.Output,
|
||||
test_input={"text": "Hello, World!"},
|
||||
test_output=("status", "printed"),
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
print(">>>>> Print: ", input_data.text)
|
||||
yield "status", "printed"
|
||||
|
||||
|
||||
class FindInDictionaryBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
input: Any = SchemaField(description="Dictionary to lookup from")
|
||||
@@ -184,188 +148,6 @@ class FindInDictionaryBlock(Block):
|
||||
yield "missing", input_data.input
|
||||
|
||||
|
||||
class AgentInputBlock(Block):
|
||||
"""
|
||||
This block is used to provide input to the graph.
|
||||
|
||||
It takes in a value, a name, a description, a list of default values, and a flag that limits selection to those default values.
|
||||
|
||||
It outputs the value passed as input.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
name: str = SchemaField(description="The name of the input.")
|
||||
value: Any = SchemaField(
|
||||
description="The value to be passed as input.",
|
||||
default=None,
|
||||
)
|
||||
title: str | None = SchemaField(
|
||||
description="The title of the input.", default=None, advanced=True
|
||||
)
|
||||
description: str | None = SchemaField(
|
||||
description="The description of the input.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
placeholder_values: List[Any] = SchemaField(
|
||||
description="The placeholder values to be passed as input.",
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
limit_to_placeholder_values: bool = SchemaField(
|
||||
description="Whether to limit the selection to placeholder values.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
advanced: bool = SchemaField(
|
||||
description="Whether to show the input in the advanced section, if the field is not required.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
secret: bool = SchemaField(
|
||||
description="Whether the input should be treated as a secret.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: Any = SchemaField(description="The value passed as input.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
|
||||
description="This block is used to provide input to the graph.",
|
||||
input_schema=AgentInputBlock.Input,
|
||||
output_schema=AgentInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"name": "input_1",
|
||||
"description": "This is a test input.",
|
||||
"placeholder_values": [],
|
||||
"limit_to_placeholder_values": False,
|
||||
},
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"name": "input_2",
|
||||
"description": "This is a test input.",
|
||||
"placeholder_values": ["Hello, World!"],
|
||||
"limit_to_placeholder_values": True,
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", "Hello, World!"),
|
||||
("result", "Hello, World!"),
|
||||
],
|
||||
categories={BlockCategory.INPUT, BlockCategory.BASIC},
|
||||
block_type=BlockType.INPUT,
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "result", input_data.value
|
||||
|
||||
|
||||
class AgentOutputBlock(Block):
|
||||
"""
|
||||
Records the output of the graph for users to see.
|
||||
|
||||
Behavior:
|
||||
If `format` is provided and the `value` is of a type that can be formatted,
|
||||
the block attempts to format the recorded_value using the `format`.
|
||||
If formatting fails or no `format` is provided, the raw `value` is output.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
value: Any = SchemaField(
|
||||
description="The value to be recorded as output.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
)
|
||||
name: str = SchemaField(description="The name of the output.")
|
||||
title: str | None = SchemaField(
|
||||
description="The title of the output.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
description: str | None = SchemaField(
|
||||
description="The description of the output.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
format: str = SchemaField(
|
||||
description="The format string to be used to format the recorded_value. Use Jinja2 syntax.",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
advanced: bool = SchemaField(
|
||||
description="Whether to treat the output as advanced.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
secret: bool = SchemaField(
|
||||
description="Whether the output should be treated as a secret.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: Any = SchemaField(description="The value recorded as output.")
|
||||
name: Any = SchemaField(description="The name of the value recorded as output.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="363ae599-353e-4804-937e-b2ee3cef3da4",
|
||||
description="Stores the output of the graph for users to see.",
|
||||
input_schema=AgentOutputBlock.Input,
|
||||
output_schema=AgentOutputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"name": "output_1",
|
||||
"description": "This is a test output.",
|
||||
"format": "{{ output_1 }}!!",
|
||||
},
|
||||
{
|
||||
"value": "42",
|
||||
"name": "output_2",
|
||||
"description": "This is another test output.",
|
||||
"format": "{{ output_2 }}",
|
||||
},
|
||||
{
|
||||
"value": MockObject(value="!!", key="key"),
|
||||
"name": "output_3",
|
||||
"description": "This is a test output with a mock object.",
|
||||
"format": "{{ output_3 }}",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("output", "Hello, World!!!"),
|
||||
("output", "42"),
|
||||
("output", MockObject(value="!!", key="key")),
|
||||
],
|
||||
categories={BlockCategory.OUTPUT, BlockCategory.BASIC},
|
||||
block_type=BlockType.OUTPUT,
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
"""
|
||||
Attempts to format the recorded_value using the fmt_string if provided.
|
||||
If formatting fails or no fmt_string is given, returns the original recorded_value.
|
||||
"""
|
||||
if input_data.format:
|
||||
try:
|
||||
yield "output", formatter.format_string(
|
||||
input_data.format, {input_data.name: input_data.value}
|
||||
)
|
||||
except Exception as e:
|
||||
yield "output", f"Error: {e}, {input_data.value}"
|
||||
else:
|
||||
yield "output", input_data.value
|
||||
yield "name", input_data.name
|
||||
|
||||
|
||||
class AddToDictionaryBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
dictionary: dict[Any, Any] = SchemaField(
|
||||
@@ -466,17 +248,6 @@ class AddToListBlock(Block):
|
||||
description="The position to insert the new entry. If not provided, the entry will be appended to the end of the list.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_missing_links(cls, data: BlockInput, links: List["Link"]) -> set[str]:
|
||||
return super().get_missing_links(
|
||||
data,
|
||||
[
|
||||
link
|
||||
for link in links
|
||||
if link.sink_name != "list" or link.sink_id != link.source_id
|
||||
],
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
updated_list: List[Any] = SchemaField(
|
||||
description="The list with the new entry added."
|
||||
|
||||
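Throughout these blocks, `run()` is a generator that yields `(output_name, value)` pairs. A minimal sketch of that contract, reproducing `PrintToConsoleBlock.run` as a plain function; the driver loop is hypothetical and not the platform's executor:
```
# Illustrative only: consuming a Block-style run() generator.
from typing import Any, Generator

BlockOutput = Generator[tuple[str, Any], None, None]


def run(text: str) -> BlockOutput:
    # Mirrors PrintToConsoleBlock.run: emit a value on a named output pin.
    print(">>>>> Print: ", text)
    yield "status", "printed"


for output_name, value in run("Hello, World!"):
    print(f"{output_name} -> {value}")  # status -> printed
```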
@@ -8,6 +8,7 @@ from backend.data.block import (
|
||||
BlockSchema,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks.compass import CompassWebhookType
|
||||
|
||||
|
||||
@@ -42,7 +43,7 @@ class CompassAITriggerBlock(Block):
|
||||
input_schema=CompassAITriggerBlock.Input,
|
||||
output_schema=CompassAITriggerBlock.Output,
|
||||
webhook_config=BlockManualWebhookConfig(
|
||||
provider="compass",
|
||||
provider=ProviderName.COMPASS,
|
||||
webhook_type=CompassWebhookType.TRANSCRIPTION,
|
||||
),
|
||||
test_input=[
|
||||
|
||||
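This hunk swaps the raw `provider="compass"` string for `ProviderName.COMPASS`. A sketch of why a string-valued enum helps, using a hypothetical mini enum rather than the real `backend.integrations.providers.ProviderName`:
```
# Hypothetical mini provider enum; the real ProviderName has many more members.
from enum import Enum


class ProviderName(str, Enum):
    COMPASS = "compass"
    GITHUB = "github"
    SLANT3D = "slant3d"


# A str-valued enum stays wire-compatible with the old literals...
assert ProviderName.COMPASS == "compass"

# ...while typos fail loudly instead of registering a webhook for a bogus provider.
try:
    ProviderName("cmpass")
except ValueError as e:
    print(e)  # 'cmpass' is not a valid ProviderName
```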
@@ -12,6 +12,7 @@ from backend.data.block import (
|
||||
BlockWebhookConfig,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
@@ -123,7 +124,7 @@ class GithubPullRequestTriggerBlock(GitHubTriggerBase, Block):
|
||||
output_schema=GithubPullRequestTriggerBlock.Output,
|
||||
# --8<-- [start:example-webhook_config]
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider="github",
|
||||
provider=ProviderName.GITHUB,
|
||||
webhook_type=GithubWebhookType.REPO,
|
||||
resource_format="{repo}",
|
||||
event_filter_input="events",
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from requests.exceptions import HTTPError, RequestException
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.request import requests
|
||||
|
||||
logger = logging.getLogger(name=__name__)
|
||||
|
||||
|
||||
class HttpMethod(Enum):
|
||||
GET = "GET"
|
||||
@@ -43,8 +48,9 @@ class SendWebRequestBlock(Block):
|
||||
|
||||
class Output(BlockSchema):
|
||||
response: object = SchemaField(description="The response from the server")
|
||||
client_error: object = SchemaField(description="The error on 4xx status codes")
|
||||
server_error: object = SchemaField(description="The error on 5xx status codes")
|
||||
client_error: object = SchemaField(description="Errors on 4xx status codes")
|
||||
server_error: object = SchemaField(description="Errors on 5xx status codes")
|
||||
error: str = SchemaField(description="Errors for all other exceptions")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -68,20 +74,40 @@ class SendWebRequestBlock(Block):
|
||||
# we should send it as plain text instead
|
||||
input_data.json_format = False
|
||||
|
||||
response = requests.request(
|
||||
input_data.method.value,
|
||||
input_data.url,
|
||||
headers=input_data.headers,
|
||||
json=body if input_data.json_format else None,
|
||||
data=body if not input_data.json_format else None,
|
||||
)
|
||||
result = response.json() if input_data.json_format else response.text
|
||||
|
||||
if response.status_code // 100 == 2:
|
||||
try:
|
||||
response = requests.request(
|
||||
input_data.method.value,
|
||||
input_data.url,
|
||||
headers=input_data.headers,
|
||||
json=body if input_data.json_format else None,
|
||||
data=body if not input_data.json_format else None,
|
||||
)
|
||||
result = response.json() if input_data.json_format else response.text
|
||||
yield "response", result
|
||||
elif response.status_code // 100 == 4:
|
||||
yield "client_error", result
|
||||
elif response.status_code // 100 == 5:
|
||||
yield "server_error", result
|
||||
else:
|
||||
raise ValueError(f"Unexpected status code: {response.status_code}")
|
||||
|
||||
except HTTPError as e:
|
||||
# Handle error responses
|
||||
try:
|
||||
result = e.response.json() if input_data.json_format else str(e)
|
||||
except json.JSONDecodeError:
|
||||
result = str(e)
|
||||
|
||||
if 400 <= e.response.status_code < 500:
|
||||
yield "client_error", result
|
||||
elif 500 <= e.response.status_code < 600:
|
||||
yield "server_error", result
|
||||
else:
|
||||
error_msg = (
|
||||
"Unexpected status code "
|
||||
f"{e.response.status_code} '{e.response.reason}'"
|
||||
)
|
||||
logger.warning(error_msg)
|
||||
yield "error", error_msg
|
||||
|
||||
except RequestException as e:
|
||||
# Handle other request-related exceptions
|
||||
yield "error", str(e)
|
||||
|
||||
except Exception as e:
|
||||
# Catch any other unexpected exceptions
|
||||
yield "error", str(e)
|
||||
|
||||
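The rewritten `SendWebRequestBlock.run` routes 2xx responses to `response`, 4xx to `client_error`, 5xx to `server_error`, and any other failure to `error`. A simplified, self-contained sketch of that routing with plain `requests`; it omits the `json_format` toggle and the platform's wrapped request helper, and the example URL is illustrative:
```
# Simplified sketch of the 2xx/4xx/5xx output routing added above.
import requests
from requests.exceptions import RequestException


def send_web_request(method: str, url: str, **kwargs):
    try:
        response = requests.request(method, url, **kwargs)
        result = response.text
        if response.status_code // 100 == 2:
            yield "response", result
        elif response.status_code // 100 == 4:
            yield "client_error", result
        elif response.status_code // 100 == 5:
            yield "server_error", result
        else:
            yield "error", f"Unexpected status code: {response.status_code}"
    except RequestException as e:
        yield "error", str(e)


for name, value in send_web_request("GET", "https://httpbin.org/status/404"):
    print(name)  # client_error
```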
552 autogpt_platform/backend/backend/blocks/io.py Normal file
@@ -0,0 +1,552 @@
|
||||
from datetime import date, time
|
||||
from typing import Any, Optional
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import MediaFile, store_media_file
|
||||
from backend.util.mock import MockObject
|
||||
from backend.util.settings import Config
|
||||
from backend.util.text import TextFormatter
|
||||
|
||||
formatter = TextFormatter()
|
||||
config = Config()
|
||||
|
||||
|
||||
class AgentInputBlock(Block):
|
||||
"""
|
||||
This block is used to provide input to the graph.
|
||||
|
||||
It takes in a value, a name, a description, a list of default values, and a flag that limits selection to those default values.
|
||||
|
||||
It outputs the value passed as input.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
name: str = SchemaField(description="The name of the input.")
|
||||
value: Any = SchemaField(
|
||||
description="The value to be passed as input.",
|
||||
default=None,
|
||||
)
|
||||
title: str | None = SchemaField(
|
||||
description="The title of the input.", default=None, advanced=True
|
||||
)
|
||||
description: str | None = SchemaField(
|
||||
description="The description of the input.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
placeholder_values: list = SchemaField(
|
||||
description="The placeholder values to be passed as input.",
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
limit_to_placeholder_values: bool = SchemaField(
|
||||
description="Whether to limit the selection to placeholder values.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
advanced: bool = SchemaField(
|
||||
description="Whether to show the input in the advanced section, if the field is not required.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
secret: bool = SchemaField(
|
||||
description="Whether the input should be treated as a secret.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: Any = SchemaField(description="The value passed as input.")
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(
|
||||
**{
|
||||
"id": "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
|
||||
"description": "Base block for user inputs.",
|
||||
"input_schema": AgentInputBlock.Input,
|
||||
"output_schema": AgentInputBlock.Output,
|
||||
"test_input": [
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"name": "input_1",
|
||||
"description": "Example test input.",
|
||||
"placeholder_values": [],
|
||||
"limit_to_placeholder_values": False,
|
||||
},
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"name": "input_2",
|
||||
"description": "Example test input with placeholders.",
|
||||
"placeholder_values": ["Hello, World!"],
|
||||
"limit_to_placeholder_values": True,
|
||||
},
|
||||
],
|
||||
"test_output": [
|
||||
("result", "Hello, World!"),
|
||||
("result", "Hello, World!"),
|
||||
],
|
||||
"categories": {BlockCategory.INPUT, BlockCategory.BASIC},
|
||||
"block_type": BlockType.INPUT,
|
||||
"static_output": True,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, *args, **kwargs) -> BlockOutput:
|
||||
if input_data.value is not None:
|
||||
yield "result", input_data.value
|
||||
|
||||
|
||||
class AgentOutputBlock(Block):
|
||||
"""
|
||||
Records the output of the graph for users to see.
|
||||
|
||||
Behavior:
|
||||
If `format` is provided and the `value` is of a type that can be formatted,
|
||||
the block attempts to format the recorded_value using the `format`.
|
||||
If formatting fails or no `format` is provided, the raw `value` is output.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
value: Any = SchemaField(
|
||||
description="The value to be recorded as output.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
)
|
||||
name: str = SchemaField(description="The name of the output.")
|
||||
title: str | None = SchemaField(
|
||||
description="The title of the output.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
description: str | None = SchemaField(
|
||||
description="The description of the output.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
format: str = SchemaField(
|
||||
description="The format string to be used to format the recorded_value. Use Jinja2 syntax.",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
advanced: bool = SchemaField(
|
||||
description="Whether to treat the output as advanced.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
secret: bool = SchemaField(
|
||||
description="Whether the output should be treated as a secret.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: Any = SchemaField(description="The value recorded as output.")
|
||||
name: Any = SchemaField(description="The name of the value recorded as output.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="363ae599-353e-4804-937e-b2ee3cef3da4",
|
||||
description="Stores the output of the graph for users to see.",
|
||||
input_schema=AgentOutputBlock.Input,
|
||||
output_schema=AgentOutputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"name": "output_1",
|
||||
"description": "This is a test output.",
|
||||
"format": "{{ output_1 }}!!",
|
||||
},
|
||||
{
|
||||
"value": "42",
|
||||
"name": "output_2",
|
||||
"description": "This is another test output.",
|
||||
"format": "{{ output_2 }}",
|
||||
},
|
||||
{
|
||||
"value": MockObject(value="!!", key="key"),
|
||||
"name": "output_3",
|
||||
"description": "This is a test output with a mock object.",
|
||||
"format": "{{ output_3 }}",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("output", "Hello, World!!!"),
|
||||
("output", "42"),
|
||||
("output", MockObject(value="!!", key="key")),
|
||||
],
|
||||
categories={BlockCategory.OUTPUT, BlockCategory.BASIC},
|
||||
block_type=BlockType.OUTPUT,
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, *args, **kwargs) -> BlockOutput:
|
||||
"""
|
||||
Attempts to format the recorded_value using the fmt_string if provided.
|
||||
If formatting fails or no fmt_string is given, returns the original recorded_value.
|
||||
"""
|
||||
if input_data.format:
|
||||
try:
|
||||
yield "output", formatter.format_string(
|
||||
input_data.format, {input_data.name: input_data.value}
|
||||
)
|
||||
except Exception as e:
|
||||
yield "output", f"Error: {e}, {input_data.value}"
|
||||
else:
|
||||
yield "output", input_data.value
|
||||
yield "name", input_data.name
|
||||
|
||||
|
||||
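The `format` field above is described as Jinja2 syntax rendered over `{name: value}`. A sketch of that behaviour with the `jinja2` package, assuming the platform's `TextFormatter` is Jinja2-backed as the field description suggests (it may differ in detail):
```
# Sketch of AgentOutputBlock's formatting: render `format` as a Jinja2 template
# over {name: value}, falling back to an error string on failure.
from jinja2 import Template


def format_output(name: str, value, fmt: str):
    if not fmt:
        return value
    try:
        return Template(fmt).render({name: value})
    except Exception as e:
        return f"Error: {e}, {value}"


print(format_output("output_1", "Hello, World!", "{{ output_1 }}!!"))  # Hello, World!!!
```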
class AgentShortTextInputBlock(AgentInputBlock):
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: Optional[str] = SchemaField(
|
||||
description="Short text input.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
json_schema_extra={"format": "short-text"},
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: str = SchemaField(description="Short text result.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7fcd3bcb-8e1b-4e69-903d-32d3d4a92158",
|
||||
description="Block for short text input (single-line).",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentShortTextInputBlock.Input,
|
||||
output_schema=AgentShortTextInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": "Hello",
|
||||
"name": "short_text_1",
|
||||
"description": "Short text example 1",
|
||||
"placeholder_values": [],
|
||||
"limit_to_placeholder_values": False,
|
||||
},
|
||||
{
|
||||
"value": "Quick test",
|
||||
"name": "short_text_2",
|
||||
"description": "Short text example 2",
|
||||
"placeholder_values": ["Quick test", "Another option"],
|
||||
"limit_to_placeholder_values": True,
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", "Hello"),
|
||||
("result", "Quick test"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class AgentLongTextInputBlock(AgentInputBlock):
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: Optional[str] = SchemaField(
|
||||
description="Long text input (potentially multi-line).",
|
||||
default=None,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
json_schema_extra={"format": "long-text"},
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: str = SchemaField(description="Long text result.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="90a56ffb-7024-4b2b-ab50-e26c5e5ab8ba",
|
||||
description="Block for long text input (multi-line).",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentLongTextInputBlock.Input,
|
||||
output_schema=AgentLongTextInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": "Lorem ipsum dolor sit amet...",
|
||||
"name": "long_text_1",
|
||||
"description": "Long text example 1",
|
||||
"placeholder_values": [],
|
||||
"limit_to_placeholder_values": False,
|
||||
},
|
||||
{
|
||||
"value": "Another multiline text input.",
|
||||
"name": "long_text_2",
|
||||
"description": "Long text example 2",
|
||||
"placeholder_values": ["Another multiline text input."],
|
||||
"limit_to_placeholder_values": True,
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", "Lorem ipsum dolor sit amet..."),
|
||||
("result", "Another multiline text input."),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class AgentNumberInputBlock(AgentInputBlock):
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: Optional[int] = SchemaField(
|
||||
description="Number input.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: int = SchemaField(description="Number result.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="96dae2bb-97a2-41c2-bd2f-13a3b5a8ea98",
|
||||
description="Block for number input.",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentNumberInputBlock.Input,
|
||||
output_schema=AgentNumberInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": 42,
|
||||
"name": "number_input_1",
|
||||
"description": "Number example 1",
|
||||
"placeholder_values": [],
|
||||
"limit_to_placeholder_values": False,
|
||||
},
|
||||
{
|
||||
"value": 314,
|
||||
"name": "number_input_2",
|
||||
"description": "Number example 2",
|
||||
"placeholder_values": [314, 2718],
|
||||
"limit_to_placeholder_values": True,
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", 42),
|
||||
("result", 314),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class AgentDateInputBlock(AgentInputBlock):
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: Optional[date] = SchemaField(
|
||||
description="Date input (YYYY-MM-DD).",
|
||||
default=None,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: date = SchemaField(description="Date result.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7e198b09-4994-47db-8b4d-952d98241817",
|
||||
description="Block for date input.",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentDateInputBlock.Input,
|
||||
output_schema=AgentDateInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
# If your system can parse JSON date strings to date objects
|
||||
"value": str(date(2025, 3, 19)),
|
||||
"name": "date_input_1",
|
||||
"description": "Example date input 1",
|
||||
},
|
||||
{
|
||||
"value": str(date(2023, 12, 31)),
|
||||
"name": "date_input_2",
|
||||
"description": "Example date input 2",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", date(2025, 3, 19)),
|
||||
("result", date(2023, 12, 31)),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class AgentTimeInputBlock(AgentInputBlock):
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: Optional[time] = SchemaField(
|
||||
description="Time input (HH:MM:SS).",
|
||||
default=None,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: time = SchemaField(description="Time result.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2a1c757e-86cf-4c7e-aacf-060dc382e434",
|
||||
description="Block for time input.",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentTimeInputBlock.Input,
|
||||
output_schema=AgentTimeInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": str(time(9, 30, 0)),
|
||||
"name": "time_input_1",
|
||||
"description": "Time example 1",
|
||||
},
|
||||
{
|
||||
"value": str(time(23, 59, 59)),
|
||||
"name": "time_input_2",
|
||||
"description": "Time example 2",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", time(9, 30, 0)),
|
||||
("result", time(23, 59, 59)),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class AgentFileInputBlock(AgentInputBlock):
|
||||
"""
|
||||
A simplified file-upload block. In real usage, you might have a custom
|
||||
file type or handle binary data. Here, we'll store a string path as the example.
|
||||
"""
|
||||
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: Optional[MediaFile] = SchemaField(
|
||||
description="Path or reference to an uploaded file.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: str = SchemaField(description="File reference/path result.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="95ead23f-8283-4654-aef3-10c053b74a31",
|
||||
description="Block for file upload input (string path for example).",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentFileInputBlock.Input,
|
||||
output_schema=AgentFileInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": "data:image/png;base64,MQ==",
|
||||
"name": "file_upload_1",
|
||||
"description": "Example file upload 1",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", str),
|
||||
],
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
graph_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
if not input_data.value:
|
||||
return
|
||||
|
||||
file_path = store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.value,
|
||||
return_content=False,
|
||||
)
|
||||
yield "result", file_path
|
||||
|
||||
|
||||
class AgentDropdownInputBlock(AgentInputBlock):
|
||||
"""
|
||||
A specialized text input block that relies on placeholder_values +
|
||||
limit_to_placeholder_values to present a dropdown.
|
||||
"""
|
||||
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: Optional[str] = SchemaField(
|
||||
description="Text selected from a dropdown.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
placeholder_values: list = SchemaField(
|
||||
description="Possible values for the dropdown.",
|
||||
default=[],
|
||||
advanced=False,
|
||||
title="Dropdown Options",
|
||||
)
|
||||
limit_to_placeholder_values: bool = SchemaField(
|
||||
description="Whether the selection is limited to placeholder values.",
|
||||
default=True,
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: str = SchemaField(description="Selected dropdown value.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="655d6fdf-a334-421c-b733-520549c07cd1",
|
||||
description="Block for dropdown text selection.",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentDropdownInputBlock.Input,
|
||||
output_schema=AgentDropdownInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": "Option A",
|
||||
"name": "dropdown_1",
|
||||
"placeholder_values": ["Option A", "Option B", "Option C"],
|
||||
"limit_to_placeholder_values": True,
|
||||
"description": "Dropdown example 1",
|
||||
},
|
||||
{
|
||||
"value": "Option C",
|
||||
"name": "dropdown_2",
|
||||
"placeholder_values": ["Option A", "Option B", "Option C"],
|
||||
"limit_to_placeholder_values": True,
|
||||
"description": "Dropdown example 2",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", "Option A"),
|
||||
("result", "Option C"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class AgentToggleInputBlock(AgentInputBlock):
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: bool = SchemaField(
|
||||
description="Boolean toggle input.",
|
||||
default=False,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: bool = SchemaField(description="Boolean toggle result.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="cbf36ab5-df4a-43b6-8a7f-f7ed8652116e",
|
||||
description="Block for boolean toggle input.",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentToggleInputBlock.Input,
|
||||
output_schema=AgentToggleInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": True,
|
||||
"name": "toggle_1",
|
||||
"description": "Toggle example 1",
|
||||
},
|
||||
{
|
||||
"value": False,
|
||||
"name": "toggle_2",
|
||||
"description": "Toggle example 2",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", True),
|
||||
("result", False),
|
||||
],
|
||||
)
|
||||
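Each subtype block above only narrows the `value` type of `AgentInputBlock`; the dropdown variant additionally leans on `placeholder_values` and `limit_to_placeholder_values`. A hypothetical helper showing the contract those two fields imply (the platform may enforce it elsewhere, e.g. in the UI):
```
# Hypothetical validator for the dropdown contract implied above.
from typing import Any, Optional


def validate_dropdown(
    value: Optional[str],
    placeholder_values: list[Any],
    limit_to_placeholder_values: bool,
) -> Optional[str]:
    if value is None:
        return None
    if limit_to_placeholder_values and value not in placeholder_values:
        raise ValueError(
            f"{value!r} is not one of the allowed options: {placeholder_values}"
        )
    return value


print(validate_dropdown("Option A", ["Option A", "Option B", "Option C"], True))
```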
@@ -142,7 +142,9 @@ class ScreenshotWebPageBlock(Block):
|
||||
return {
|
||||
"image": store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=f"data:image/{format.value};base64,{b64encode(response.content).decode('utf-8')}",
|
||||
file=MediaFile(
|
||||
f"data:image/{format.value};base64,{b64encode(response.content).decode('utf-8')}"
|
||||
),
|
||||
return_content=True,
|
||||
)
|
||||
}
|
||||
|
||||
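The changed call above wraps a `data:` URI, built from the raw screenshot bytes, in `MediaFile` before handing it to `store_media_file`. The URI construction in isolation, with placeholder PNG bytes standing in for `response.content`:
```
# Sketch: building the data URI passed to store_media_file above.
from base64 import b64encode

image_bytes = b"\x89PNG\r\n\x1a\n..."  # stand-in for response.content
image_format = "png"                   # stand-in for format.value

data_uri = f"data:image/{image_format};base64,{b64encode(image_bytes).decode('utf-8')}"
print(data_uri)  # data:image/png;base64,iVBORw0KGgouLi4=
```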
@@ -8,6 +8,7 @@ from backend.data.block import (
|
||||
BlockWebhookConfig,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import settings
|
||||
from backend.util.settings import AppEnvironment, BehaveAs
|
||||
|
||||
@@ -82,7 +83,7 @@ class Slant3DOrderWebhookBlock(Slant3DTriggerBase, Block):
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider="slant3d",
|
||||
provider=ProviderName.SLANT3D,
|
||||
webhook_type="orders", # Only one type for now
|
||||
resource_format="", # No resource format needed
|
||||
event_filter_input="events",
|
||||
|
||||
@@ -20,6 +20,7 @@ from prisma.models import AgentBlock
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import json
|
||||
from backend.util.settings import Config
|
||||
|
||||
@@ -225,7 +226,7 @@ class BlockManualWebhookConfig(BaseModel):
|
||||
the user has to manually set up the webhook at the provider.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
provider: ProviderName
|
||||
"""The service provider that the webhook connects to"""
|
||||
|
||||
webhook_type: str
|
||||
@@ -461,9 +462,9 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
|
||||
|
||||
def get_blocks() -> dict[str, Type[Block]]:
|
||||
from backend.blocks import AVAILABLE_BLOCKS # noqa: E402
|
||||
from backend.blocks import load_all_blocks
|
||||
|
||||
return AVAILABLE_BLOCKS
|
||||
return load_all_blocks()
|
||||
|
||||
|
||||
async def initialize_blocks() -> None:
|
||||
|
||||
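`get_blocks()` now resolves through `load_all_blocks()`, returning a mapping from block UUID to block class. A hypothetical lookup, assuming `get_blocks` is importable from `backend.data.block` as this hunk suggests; the UUID reuses `PrintToConsoleBlock`'s id from the basic.py hunk above:
```
# Hypothetical usage of get_blocks(): resolve a block class by its UUID string.
from backend.data.block import get_blocks

blocks = get_blocks()  # now backed by load_all_blocks()
block_cls = blocks.get("f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c")  # PrintToConsoleBlock
if block_cls is not None:
    block = block_cls.create()
    print(block.name)
```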
@@ -15,14 +15,11 @@ from prisma.enums import (
from prisma.errors import UniqueViolationError
from prisma.models import CreditRefundRequest, CreditTransaction, User
from prisma.types import CreditTransactionCreateInput, CreditTransactionWhereInput
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_exponential

from backend.data import db
from backend.data.block import Block, BlockInput, get_block
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.cost import BlockCost, BlockCostType
from backend.data.execution import NodeExecutionEntry
from backend.data.cost import BlockCost
from backend.data.model import (
AutoTopUpConfig,
RefundRequest,
@@ -31,6 +28,7 @@ from backend.data.model import (
)
from backend.data.notifications import NotificationEventDTO, RefundRequestData
from backend.data.user import get_user_by_id
from backend.executor.utils import UsageTransactionMetadata
from backend.notifications import NotificationManager
from backend.util.exceptions import InsufficientBalanceError
from backend.util.service import get_service_client
@@ -39,6 +37,7 @@ from backend.util.settings import Settings
settings = Settings()
stripe.api_key = settings.secrets.stripe_api_key
logger = logging.getLogger(__name__)
base_url = settings.config.frontend_base_url or settings.config.platform_base_url


class UserCreditBase(ABC):
@@ -90,20 +89,20 @@ class UserCreditBase(ABC):
@abstractmethod
async def spend_credits(
self,
entry: NodeExecutionEntry,
data_size: float,
run_time: float,
user_id: str,
cost: int,
metadata: UsageTransactionMetadata,
) -> int:
"""
Spend the credits for the user based on the block usage.
Spend the credits for the user based on the cost.

Args:
entry (NodeExecutionEntry): The node execution identifiers & data.
data_size (float): The size of the data being processed.
run_time (float): The time taken to run the block.
user_id (str): The user ID.
cost (int): The cost to spend.
metadata (UsageTransactionMetadata): The metadata of the transaction.

Returns:
int: amount of credit spent
int: The remaining balance.
"""
pass

@@ -185,6 +184,14 @@ class UserCreditBase(ABC):
"""
pass

@staticmethod
async def create_billing_portal_session(user_id: str) -> str:
session = stripe.billing_portal.Session.create(
customer=await get_stripe_customer_id(user_id),
return_url=base_url + "/profile/credits",
)
return session.url

@staticmethod
def time_now() -> datetime:
return datetime.now(timezone.utc)
@@ -339,16 +346,6 @@ class UserCreditBase(ABC):
return user_balance + amount, tx.transactionKey


class UsageTransactionMetadata(BaseModel):
graph_exec_id: str | None = None
graph_id: str | None = None
node_id: str | None = None
node_exec_id: str | None = None
block_id: str | None = None
block: str | None = None
input: BlockInput | None = None


class UserCredit(UserCreditBase):
@thread_cached
def notification_client(self) -> NotificationManager:
@@ -369,89 +366,21 @@ class UserCredit(UserCreditBase):
)
)

def _block_usage_cost(
self,
block: Block,
input_data: BlockInput,
data_size: float,
run_time: float,
) -> tuple[int, BlockInput]:
block_costs = BLOCK_COSTS.get(type(block))
if not block_costs:
return 0, {}

for block_cost in block_costs:
if not self._is_cost_filter_match(block_cost.cost_filter, input_data):
continue

if block_cost.cost_type == BlockCostType.RUN:
return block_cost.cost_amount, block_cost.cost_filter

if block_cost.cost_type == BlockCostType.SECOND:
return (
int(run_time * block_cost.cost_amount),
block_cost.cost_filter,
)

if block_cost.cost_type == BlockCostType.BYTE:
return (
int(data_size * block_cost.cost_amount),
block_cost.cost_filter,
)

return 0, {}

def _is_cost_filter_match(
self, cost_filter: BlockInput, input_data: BlockInput
) -> bool:
"""
Filter rules:
- If cost_filter is an object, then check if cost_filter is the subset of input_data
- Otherwise, check if cost_filter is equal to input_data.
- Undefined, null, and empty string are considered as equal.
"""
if not isinstance(cost_filter, dict) or not isinstance(input_data, dict):
return cost_filter == input_data

return all(
(not input_data.get(k) and not v)
or (input_data.get(k) and self._is_cost_filter_match(v, input_data[k]))
for k, v in cost_filter.items()
)

async def spend_credits(
self,
entry: NodeExecutionEntry,
data_size: float,
run_time: float,
user_id: str,
cost: int,
metadata: UsageTransactionMetadata,
) -> int:
block = get_block(entry.block_id)
if not block:
raise ValueError(f"Block not found: {entry.block_id}")

cost, matching_filter = self._block_usage_cost(
block=block, input_data=entry.data, data_size=data_size, run_time=run_time
)
if cost == 0:
return 0

balance, _ = await self._add_transaction(
user_id=entry.user_id,
user_id=user_id,
amount=-cost,
transaction_type=CreditTransactionType.USAGE,
metadata=Json(
UsageTransactionMetadata(
graph_exec_id=entry.graph_exec_id,
graph_id=entry.graph_id,
node_id=entry.node_id,
node_exec_id=entry.node_exec_id,
block_id=entry.block_id,
block=block.name,
input=matching_filter,
).model_dump()
),
metadata=Json(metadata.model_dump()),
)
user_id = entry.user_id

# Auto top-up if balance is below threshold.
auto_top_up = await get_auto_top_up(user_id)
@@ -461,7 +390,7 @@ class UserCredit(UserCreditBase):
user_id=user_id,
amount=auto_top_up.amount,
# Avoid multiple auto top-ups within the same graph execution.
key=f"AUTO-TOP-UP-{user_id}-{entry.graph_exec_id}",
key=f"AUTO-TOP-UP-{user_id}-{metadata.graph_exec_id}",
ceiling_balance=auto_top_up.threshold,
)
except Exception as e:
@@ -470,7 +399,7 @@ class UserCredit(UserCreditBase):
f"Auto top-up failed for user {user_id}, balance: {balance}, amount: {auto_top_up.amount}, error: {e}"
)

return cost
return balance

async def top_up_credits(self, user_id: str, amount: int):
await self._top_up_credits(user_id, amount)
@@ -765,10 +694,8 @@ class UserCredit(UserCreditBase):
ui_mode="hosted",
payment_intent_data={"setup_future_usage": "off_session"},
saved_payment_method_options={"payment_method_save": "enabled"},
success_url=settings.config.frontend_base_url
+ "/profile/credits?topup=success",
cancel_url=settings.config.frontend_base_url
+ "/profile/credits?topup=cancel",
success_url=base_url + "/profile/credits?topup=success",
cancel_url=base_url + "/profile/credits?topup=cancel",
allow_promotion_codes=True,
)

@@ -1,7 +1,16 @@
from collections import defaultdict
from datetime import datetime, timezone
from multiprocessing import Manager
from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator, Generic, Type, TypeVar
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Generator,
Generic,
Optional,
Type,
TypeVar,
)

from prisma import Json
from prisma.enums import AgentExecutionStatus
@@ -10,6 +19,7 @@ from prisma.models import (
AgentNodeExecution,
AgentNodeExecutionInputOutput,
)
from prisma.types import AgentNodeExecutionUpdateInput, AgentNodeExecutionWhereInput
from pydantic import BaseModel

from backend.data.block import BlockData, BlockInput, CompletedBlockOutput
@@ -68,6 +78,7 @@ class ExecutionQueue(Generic[T]):


class ExecutionResult(BaseModel):
user_id: str
graph_id: str
graph_version: int
graph_exec_id: str
@@ -83,27 +94,28 @@ class ExecutionResult(BaseModel):
end_time: datetime | None

@staticmethod
def from_graph(graph: AgentGraphExecution):
def from_graph(graph_exec: AgentGraphExecution):
return ExecutionResult(
graph_id=graph.agentGraphId,
graph_version=graph.agentGraphVersion,
graph_exec_id=graph.id,
user_id=graph_exec.userId,
graph_id=graph_exec.agentGraphId,
graph_version=graph_exec.agentGraphVersion,
graph_exec_id=graph_exec.id,
node_exec_id="",
node_id="",
block_id="",
status=graph.executionStatus,
status=graph_exec.executionStatus,
# TODO: Populate input_data & output_data from AgentNodeExecutions
# Input & Output comes AgentInputBlock & AgentOutputBlock.
input_data={},
output_data={},
add_time=graph.createdAt,
queue_time=graph.createdAt,
start_time=graph.startedAt,
end_time=graph.updatedAt,
add_time=graph_exec.createdAt,
queue_time=graph_exec.createdAt,
start_time=graph_exec.startedAt,
end_time=graph_exec.updatedAt,
)

@staticmethod
def from_db(execution: AgentNodeExecution):
def from_db(execution: AgentNodeExecution, user_id: Optional[str] = None):
if execution.executionData:
# Execution that has been queued for execution will persist its data.
input_data = type.convert(execution.executionData, dict[str, Any])
@@ -118,8 +130,15 @@ class ExecutionResult(BaseModel):
output_data[data.name].append(type.convert(data.data, Type[Any]))

graph_execution: AgentGraphExecution | None = execution.AgentGraphExecution
if graph_execution:
user_id = graph_execution.userId
elif not user_id:
raise ValueError(
"AgentGraphExecution must be included or user_id passed in"
)

return ExecutionResult(
user_id=user_id,
graph_id=graph_execution.agentGraphId if graph_execution else "",
graph_version=graph_execution.agentGraphVersion if graph_execution else 0,
graph_exec_id=execution.agentGraphExecutionId,
@@ -160,7 +179,8 @@ async def create_graph_execution(
"create": [ # type: ignore
{
"agentNodeId": node_id,
"executionStatus": ExecutionStatus.INCOMPLETE,
"executionStatus": ExecutionStatus.QUEUED,
"queuedTime": datetime.now(tz=timezone.utc),
"Input": {
"create": [
{"name": name, "data": Json(data)}
@@ -178,7 +198,7 @@ async def create_graph_execution(
)

return result.id, [
ExecutionResult.from_db(execution)
ExecutionResult.from_db(execution, result.userId)
for execution in result.AgentNodeExecutions or []
]

@@ -285,13 +305,19 @@ async def update_graph_execution_start_time(graph_exec_id: str) -> ExecutionResu
async def update_graph_execution_stats(
graph_exec_id: str,
status: ExecutionStatus,
stats: GraphExecutionStats,
stats: GraphExecutionStats | None = None,
) -> ExecutionResult:
data = stats.model_dump()
if isinstance(data["error"], Exception):
data = stats.model_dump() if stats else {}
if isinstance(data.get("error"), Exception):
data["error"] = str(data["error"])
res = await AgentGraphExecution.prisma().update(
where={"id": graph_exec_id},
where={
"id": graph_exec_id,
"OR": [
{"executionStatus": ExecutionStatus.RUNNING},
{"executionStatus": ExecutionStatus.QUEUED},
],
},
data={
"executionStatus": status,
"stats": Json(data),
@@ -313,6 +339,17 @@ async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionSta
)


async def update_execution_status_batch(
node_exec_ids: list[str],
status: ExecutionStatus,
stats: dict[str, Any] | None = None,
):
await AgentNodeExecution.prisma().update_many(
where={"id": {"in": node_exec_ids}},
data=_get_update_status_data(status, None, stats),
)


async def update_execution_status(
node_exec_id: str,
status: ExecutionStatus,
@@ -322,20 +359,9 @@ async def update_execution_status(
if status == ExecutionStatus.QUEUED and execution_data is None:
raise ValueError("Execution data must be provided when queuing an execution.")

now = datetime.now(tz=timezone.utc)
data = {
**({"executionStatus": status}),
**({"queuedTime": now} if status == ExecutionStatus.QUEUED else {}),
**({"startedTime": now} if status == ExecutionStatus.RUNNING else {}),
**({"endedTime": now} if status == ExecutionStatus.FAILED else {}),
**({"endedTime": now} if status == ExecutionStatus.COMPLETED else {}),
**({"executionData": Json(execution_data)} if execution_data else {}),
**({"stats": Json(stats)} if stats else {}),
}

res = await AgentNodeExecution.prisma().update(
where={"id": node_exec_id},
data=data, # type: ignore
data=_get_update_status_data(status, execution_data, stats),
include=EXECUTION_RESULT_INCLUDE,
)
if not res:
@@ -344,6 +370,29 @@ async def update_execution_status(
return ExecutionResult.from_db(res)


def _get_update_status_data(
status: ExecutionStatus,
execution_data: BlockInput | None = None,
stats: dict[str, Any] | None = None,
) -> AgentNodeExecutionUpdateInput:
now = datetime.now(tz=timezone.utc)
update_data: AgentNodeExecutionUpdateInput = {"executionStatus": status}

if status == ExecutionStatus.QUEUED:
update_data["queuedTime"] = now
elif status == ExecutionStatus.RUNNING:
update_data["startedTime"] = now
elif status in (ExecutionStatus.FAILED, ExecutionStatus.COMPLETED):
update_data["endedTime"] = now

if execution_data:
update_data["executionData"] = Json(execution_data)
if stats:
update_data["stats"] = Json(stats)

return update_data


async def delete_execution(
graph_exec_id: str, user_id: str, soft_delete: bool = True
) -> None:
@@ -361,41 +410,29 @@ async def delete_execution(
)


async def get_execution_results(graph_exec_id: str) -> list[ExecutionResult]:
async def get_execution_results(
graph_exec_id: str,
block_ids: list[str] | None = None,
statuses: list[ExecutionStatus] | None = None,
limit: int | None = None,
) -> list[ExecutionResult]:
where_clause: AgentNodeExecutionWhereInput = {
"agentGraphExecutionId": graph_exec_id,
}
if block_ids:
where_clause["AgentNode"] = {"is": {"agentBlockId": {"in": block_ids}}}
if statuses:
where_clause["OR"] = [{"executionStatus": status} for status in statuses]

executions = await AgentNodeExecution.prisma().find_many(
where={"agentGraphExecutionId": graph_exec_id},
where=where_clause,
include=EXECUTION_RESULT_INCLUDE,
order=[
{"queuedTime": "asc"},
{"addedTime": "asc"}, # Fallback: Incomplete execs has no queuedTime.
],
take=limit,
)
res = [ExecutionResult.from_db(execution) for execution in executions]
return res


async def get_executions_in_timerange(
user_id: str, start_time: str, end_time: str
) -> list[ExecutionResult]:
try:
executions = await AgentGraphExecution.prisma().find_many(
where={
"startedAt": {
"gte": datetime.fromisoformat(start_time),
"lte": datetime.fromisoformat(end_time),
},
"userId": user_id,
"isDeleted": False,
},
include=GRAPH_EXECUTION_INCLUDE,
)
return [ExecutionResult.from_graph(execution) for execution in executions]
except Exception as e:
raise DatabaseError(
f"Failed to get executions in timerange {start_time} to {end_time} for user {user_id}: {e}"
) from e


LIST_SPLIT = "_$_"
DICT_SPLIT = "_#_"
OBJC_SPLIT = "_@_"
@@ -550,7 +587,10 @@ async def get_output_from_links(
"agentGraphExecutionId": graph_eid,
"executionStatus": {"not": ExecutionStatus.INCOMPLETE}, # type: ignore
},
order={"queuedTime": "asc"},
order=[
{"queuedTime": "asc"},
{"addedTime": "desc"},
],
include=EXECUTION_RESULT_INCLUDE,
)

@@ -1,4 +1,3 @@
import asyncio
import logging
import uuid
from collections import defaultdict
@@ -7,6 +6,7 @@ from typing import Any, Literal, Optional, Type

import prisma
from prisma import Json
from prisma.enums import SubmissionStatus
from prisma.models import (
AgentGraph,
AgentGraphExecution,
@@ -14,17 +14,17 @@ from prisma.models import (
AgentNodeLink,
StoreListingVersion,
)
from prisma.types import AgentGraphWhereInput
from prisma.types import AgentGraphExecutionWhereInput, AgentGraphWhereInput
from pydantic.fields import Field, computed_field

from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.basic import AgentInputBlock, AgentOutputBlock
from backend.util import type
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.util import type as type_utils

from .block import BlockInput, BlockType, get_block, get_blocks
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
from .db import BaseDbModel, transaction
from .execution import ExecutionResult, ExecutionStatus
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE, GRAPH_EXECUTION_INCLUDE
from .integrations import Webhook

logger = logging.getLogger(__name__)
@@ -71,13 +71,20 @@ class NodeModel(Node):

webhook: Optional[Webhook] = None

@property
def block(self) -> Block[BlockSchema, BlockSchema]:
block = get_block(self.block_id)
if not block:
raise ValueError(f"Block #{self.block_id} does not exist")
return block

@staticmethod
def from_db(node: AgentNode) -> "NodeModel":
def from_db(node: AgentNode, for_export: bool = False) -> "NodeModel":
obj = NodeModel(
id=node.id,
block_id=node.agentBlockId,
input_default=type.convert(node.constantInput, dict[str, Any]),
metadata=type.convert(node.metadata, dict[str, Any]),
input_default=type_utils.convert(node.constantInput, dict[str, Any]),
metadata=type_utils.convert(node.metadata, dict[str, Any]),
graph_id=node.agentGraphId,
graph_version=node.agentGraphVersion,
webhook_id=node.webhookId,
@@ -85,6 +92,8 @@ class NodeModel(Node):
)
obj.input_links = [Link.from_db(link) for link in node.Input or []]
obj.output_links = [Link.from_db(link) for link in node.Output or []]
if for_export:
return obj.stripped_for_export()
return obj

def is_triggered_by_event_type(self, event_type: str) -> bool:
@@ -103,6 +112,51 @@ class NodeModel(Node):
if event_filter[k] is True
]

def stripped_for_export(self) -> "NodeModel":
"""
Returns a copy of the node model, stripped of any non-transferable properties
"""
stripped_node = self.model_copy(deep=True)
# Remove credentials from node input
if stripped_node.input_default:
stripped_node.input_default = NodeModel._filter_secrets_from_node_input(
stripped_node.input_default, self.block.input_schema.jsonschema()
)

if (
stripped_node.block.block_type == BlockType.INPUT
and "value" in stripped_node.input_default
):
stripped_node.input_default["value"] = ""

# Remove webhook info
stripped_node.webhook_id = None
stripped_node.webhook = None

return stripped_node

@staticmethod
def _filter_secrets_from_node_input(
input_data: dict[str, Any], schema: dict[str, Any] | None
) -> dict[str, Any]:
sensitive_keys = ["credentials", "api_key", "password", "token", "secret"]
field_schemas = schema.get("properties", {}) if schema else {}
result = {}
for key, value in input_data.items():
field_schema: dict | None = field_schemas.get(key)
if (field_schema and field_schema.get("secret", False)) or any(
sensitive_key in key.lower() for sensitive_key in sensitive_keys
):
# This is a secret value -> filter this key-value pair out
continue
elif isinstance(value, dict):
result[key] = NodeModel._filter_secrets_from_node_input(
value, field_schema
)
else:
result[key] = value
return result


# Fix 2-way reference Node <-> Webhook
Webhook.model_rebuild()
@@ -129,7 +183,7 @@ class GraphExecutionMeta(BaseDbModel):
total_run_time = duration

try:
stats = type.convert(_graph_exec.stats or {}, dict[str, Any])
stats = type_utils.convert(_graph_exec.stats or {}, dict[str, Any])
except ValueError:
stats = {}

@@ -163,29 +217,41 @@ class GraphExecution(GraphExecutionMeta):

graph_exec = GraphExecutionMeta.from_db(_graph_exec)

node_executions = [
ExecutionResult.from_db(ne) for ne in _graph_exec.AgentNodeExecutions
]
node_executions = sorted(
[
ExecutionResult.from_db(ne, _graph_exec.userId)
for ne in _graph_exec.AgentNodeExecutions
],
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
)

inputs = {
**{
# inputs from Agent Input Blocks
exec.input_data["name"]: exec.input_data["value"]
exec.input_data["name"]: exec.input_data.get("value")
for exec in node_executions
if exec.block_id == _INPUT_BLOCK_ID
if (
(block := get_block(exec.block_id))
and block.block_type == BlockType.INPUT
)
},
**{
# input from webhook-triggered block
"payload": exec.input_data["payload"]
for exec in node_executions
if (block := get_block(exec.block_id))
and block.block_type in [BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL]
if (
(block := get_block(exec.block_id))
and block.block_type
in [BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL]
)
},
}

outputs: dict[str, list] = defaultdict(list)
for exec in node_executions:
if exec.block_id == _OUTPUT_BLOCK_ID:
if (
block := get_block(exec.block_id)
) and block.block_type == BlockType.OUTPUT:
outputs[exec.input_data["name"]].append(
exec.input_data.get("value", None)
)
@@ -201,10 +267,9 @@ class GraphExecution(GraphExecutionMeta):
)


class Graph(BaseDbModel):
class BaseGraph(BaseDbModel):
version: int = 1
is_active: bool = True
is_template: bool = False
name: str
description: str
nodes: list[Node] = []
@@ -267,6 +332,10 @@ class Graph(BaseDbModel):
}


class Graph(BaseGraph):
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs, only used in export


class GraphModel(Graph):
user_id: str
nodes: list[NodeModel] = [] # type: ignore
@@ -290,31 +359,54 @@ class GraphModel(Graph):
Reassigns all IDs in the graph to new UUIDs.
This method can be used before storing a new graph to the database.
"""

# Reassign Graph ID
id_map = {node.id: str(uuid.uuid4()) for node in self.nodes}
if reassign_graph_id:
self.id = str(uuid.uuid4())
graph_id_map = {
self.id: str(uuid.uuid4()),
**{sub_graph.id: str(uuid.uuid4()) for sub_graph in self.sub_graphs},
}
else:
graph_id_map = {}

self._reassign_ids(self, user_id, graph_id_map)
for sub_graph in self.sub_graphs:
self._reassign_ids(sub_graph, user_id, graph_id_map)

@staticmethod
def _reassign_ids(
graph: BaseGraph,
user_id: str,
graph_id_map: dict[str, str],
):
# Reassign Graph ID
if graph.id in graph_id_map:
graph.id = graph_id_map[graph.id]

# Reassign Node IDs
for node in self.nodes:
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
for node in graph.nodes:
node.id = id_map[node.id]

# Reassign Link IDs
for link in self.links:
for link in graph.links:
link.source_id = id_map[link.source_id]
link.sink_id = id_map[link.sink_id]

# Reassign User IDs for agent blocks
for node in self.nodes:
for node in graph.nodes:
if node.block_id != AgentExecutorBlock().id:
continue
node.input_default["user_id"] = user_id
node.input_default.setdefault("data", {})

self.validate_graph()
if (graph_id := node.input_default.get("graph_id")) in graph_id_map:
node.input_default["graph_id"] = graph_id_map[graph_id]

def validate_graph(self, for_run: bool = False):
self._validate_graph(self, for_run)
for sub_graph in self.sub_graphs:
self._validate_graph(sub_graph, for_run)

@staticmethod
def _validate_graph(graph: BaseGraph, for_run: bool = False):
def sanitize(name):
sanitized_name = name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
if sanitized_name.startswith("tools_^_"):
@@ -326,11 +418,11 @@ class GraphModel(Graph):
agent_nodes = set()
nodes_block = {
node.id: block
for node in self.nodes
for node in graph.nodes
if (block := get_block(node.block_id)) is not None
}

for node in self.nodes:
for node in graph.nodes:
if (block := nodes_block.get(node.id)) is None:
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")

@@ -343,11 +435,11 @@ class GraphModel(Graph):

input_links = defaultdict(list)

for link in self.links:
for link in graph.links:
input_links[link.sink_id].append(link)

# Nodes: required fields are filled or connected and dependencies are satisfied
for node in self.nodes:
for node in graph.nodes:
if (block := nodes_block.get(node.id)) is None:
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")

@@ -408,7 +500,7 @@ class GraphModel(Graph):
f"Node {block.name} #{node.id}: Field `{field_name}` requires [{', '.join(missing_deps)}] to be set"
)

node_map = {v.id: v for v in self.nodes}
node_map = {v.id: v for v in graph.nodes}

def is_static_output_block(nid: str) -> bool:
bid = node_map[nid].block_id
@@ -416,7 +508,7 @@ class GraphModel(Graph):
return b.static_output if b else False

# Links: links are connected and the connected pin data type are compatible.
for link in self.links:
for link in graph.links:
source = (link.source_id, link.source_name)
sink = (link.sink_id, link.sink_name)
prefix = f"Link {source} <-> {sink}"
@@ -457,18 +549,20 @@ class GraphModel(Graph):
link.is_static = True # Each value block output should be static.

@staticmethod
def from_db(graph: AgentGraph, for_export: bool = False):
def from_db(
graph: AgentGraph,
for_export: bool = False,
sub_graphs: list[AgentGraph] | None = None,
):
return GraphModel(
id=graph.id,
user_id=graph.userId,
user_id=graph.userId if not for_export else "",
version=graph.version,
is_active=graph.isActive,
is_template=graph.isTemplate,
name=graph.name or "",
description=graph.description or "",
nodes=[
NodeModel.from_db(GraphModel._process_node(node, for_export))
for node in graph.AgentNodes or []
NodeModel.from_db(node, for_export) for node in graph.AgentNodes or []
],
links=list(
{
@@ -477,59 +571,12 @@ class GraphModel(Graph):
for link in (node.Input or []) + (node.Output or [])
}
),
sub_graphs=[
GraphModel.from_db(sub_graph, for_export)
for sub_graph in sub_graphs or []
],
)

@staticmethod
def _process_node(node: AgentNode, for_export: bool) -> AgentNode:
if for_export:
# Remove credentials from node input
if node.constantInput:
constant_input = type.convert(node.constantInput, dict[str, Any])
constant_input = GraphModel._hide_node_input_credentials(constant_input)
node.constantInput = Json(constant_input)

# Remove webhook info
node.webhookId = None
node.Webhook = None

return node

@staticmethod
def _hide_node_input_credentials(input_data: dict[str, Any]) -> dict[str, Any]:
sensitive_keys = ["credentials", "api_key", "password", "token", "secret"]
result = {}
for key, value in input_data.items():
if isinstance(value, dict):
result[key] = GraphModel._hide_node_input_credentials(value)
elif isinstance(value, str) and any(
sensitive_key in key.lower() for sensitive_key in sensitive_keys
):
# Skip this key-value pair in the result
continue
else:
result[key] = value
return result

def clean_graph(self):
blocks = [block() for block in get_blocks().values()]

input_blocks = [
node
for node in self.nodes
if next(
(
b
for b in blocks
if b.id == node.block_id and b.block_type == BlockType.INPUT
),
None,
)
]

for node in self.nodes:
if any(input_block.id == node.id for input_block in input_blocks):
node.input_default["value"] = ""


# --------------------- CRUD functions --------------------- #

@@ -559,14 +606,14 @@ async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:

async def get_graphs(
user_id: str,
filter_by: Literal["active", "template"] | None = "active",
filter_by: Literal["active"] | None = "active",
) -> list[GraphModel]:
"""
Retrieves graph metadata objects.
Default behaviour is to get all currently active graphs.

Args:
filter_by: An optional filter to either select templates or active graphs.
filter_by: An optional filter to either select graphs.
user_id: The ID of the user that owns the graph.

Returns:
@@ -576,8 +623,6 @@ async def get_graphs(

if filter_by == "active":
where_clause["isActive"] = True
elif filter_by == "template":
where_clause["isTemplate"] = True

graphs = await AgentGraph.prisma().find_many(
where=where_clause,
@@ -597,18 +642,20 @@ async def get_graphs(
return graph_models


# TODO: move execution stuff to .execution
async def get_graphs_executions(user_id: str) -> list[GraphExecutionMeta]:
executions = await AgentGraphExecution.prisma().find_many(
where={"isDeleted": False, "userId": user_id},
order={"createdAt": "desc"},
)
return [GraphExecutionMeta.from_db(execution) for execution in executions]
async def get_graph_executions(
graph_id: Optional[str] = None,
user_id: Optional[str] = None,
) -> list[GraphExecutionMeta]:
where_filter: AgentGraphExecutionWhereInput = {
"isDeleted": False,
}
if user_id:
where_filter["userId"] = user_id
if graph_id:
where_filter["agentGraphId"] = graph_id


async def get_graph_executions(graph_id: str, user_id: str) -> list[GraphExecutionMeta]:
executions = await AgentGraphExecution.prisma().find_many(
where={"agentGraphId": graph_id, "isDeleted": False, "userId": user_id},
where=where_filter,
order={"createdAt": "desc"},
)
return [GraphExecutionMeta.from_db(execution) for execution in executions]
@@ -623,20 +670,13 @@ async def get_execution_meta(
return GraphExecutionMeta.from_db(execution) if execution else None


async def get_execution(user_id: str, execution_id: str) -> GraphExecution | None:
async def get_execution(
user_id: str,
execution_id: str,
) -> GraphExecution | None:
execution = await AgentGraphExecution.prisma().find_first(
where={"id": execution_id, "isDeleted": False, "userId": user_id},
include={
"AgentNodeExecutions": {
"include": {"AgentNode": True, "Input": True, "Output": True},
"order_by": [
{"queuedTime": "asc"},
{ # Fallback: Incomplete execs has no queuedTime.
"addedTime": "asc"
},
],
},
},
include=GRAPH_EXECUTION_INCLUDE,
)
return GraphExecution.from_db(execution) if execution else None

@@ -664,21 +704,18 @@ async def get_graph_metadata(graph_id: str, version: int | None = None) -> Graph
description=graph.description or "",
version=graph.version,
is_active=graph.isActive,
is_template=graph.isTemplate,
)


async def get_graph(
graph_id: str,
version: int | None = None,
template: bool = False, # note: currently not in use; TODO: remove from DB entirely
user_id: str | None = None,
for_export: bool = False,
) -> GraphModel | None:
"""
Retrieves a graph from the DB.
Defaults to the version with `is_active` if `version` is not passed,
or the latest version with `is_template` if `template=True`.
Defaults to the version with `is_active` if `version` is not passed.

Returns `None` if the record is not found.
"""
@@ -688,8 +725,6 @@ async def get_graph(

if version is not None:
where_clause["version"] = version
elif not template:
where_clause["isActive"] = True

graph = await AgentGraph.prisma().find_first(
where=where_clause,
@@ -706,16 +741,69 @@ async def get_graph(
"agentId": graph_id,
"agentVersion": version or graph.version,
"isDeleted": False,
"StoreListing": {"is": {"isApproved": True}},
"submissionStatus": SubmissionStatus.APPROVED,
}
)
)
):
return None

if for_export:
sub_graphs = await get_sub_graphs(graph)
return GraphModel.from_db(
graph=graph,
sub_graphs=sub_graphs,
for_export=for_export,
)

return GraphModel.from_db(graph, for_export)


async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
"""
Iteratively fetches all sub-graphs of a given graph, and flattens them into a list.
This call involves a DB fetch in batch, breadth-first, per-level of graph depth.
On each DB fetch we will only fetch the sub-graphs that are not already in the list.
"""
sub_graphs = {graph.id: graph}
search_graphs = [graph]
agent_block_id = AgentExecutorBlock().id

while search_graphs:
sub_graph_ids = [
(graph_id, graph_version)
for graph in search_graphs
for node in graph.AgentNodes or []
if (
node.AgentBlock
and node.AgentBlock.id == agent_block_id
and (graph_id := dict(node.constantInput).get("graph_id"))
and (graph_version := dict(node.constantInput).get("graph_version"))
)
]
if not sub_graph_ids:
break

graphs = await AgentGraph.prisma().find_many(
where={
"OR": [
{
"id": graph_id,
"version": graph_version,
"userId": graph.userId, # Ensure the sub-graph is owned by the same user
}
for graph_id, graph_version in sub_graph_ids
] # type: ignore
},
include=AGENT_GRAPH_INCLUDE,
)

search_graphs = [graph for graph in graphs if graph.id not in sub_graphs]
sub_graphs.update({graph.id: graph for graph in search_graphs})

return [g for g in sub_graphs.values() if g.id != graph.id]


async def get_connected_output_nodes(node_id: str) -> list[tuple[Link, Node]]:
links = await AgentNodeLink.prisma().find_many(
where={"agentNodeSourceId": node_id},
@@ -779,50 +867,56 @@ async def create_graph(graph: Graph, user_id: str) -> GraphModel:
async with transaction() as tx:
await __create_graph(tx, graph, user_id)

if created_graph := await get_graph(
graph.id, graph.version, template=graph.is_template, user_id=user_id
):
if created_graph := await get_graph(graph.id, graph.version, user_id=user_id):
return created_graph

raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")


async def __create_graph(tx, graph: Graph, user_id: str):
await AgentGraph.prisma(tx).create(
data={
"id": graph.id,
"version": graph.version,
"name": graph.name,
"description": graph.description,
"isTemplate": graph.is_template,
"isActive": graph.is_active,
"userId": user_id,
"AgentNodes": {
"create": [
{
"id": node.id,
"agentBlockId": node.block_id,
"constantInput": Json(node.input_default),
"metadata": Json(node.metadata),
}
for node in graph.nodes
]
},
}
graphs = [graph] + graph.sub_graphs

await AgentGraph.prisma(tx).create_many(
data=[
{
"id": graph.id,
"version": graph.version,
"name": graph.name,
"description": graph.description,
"isActive": graph.is_active,
"userId": user_id,
}
for graph in graphs
]
)

await asyncio.gather(
*[
AgentNodeLink.prisma(tx).create(
{
"id": str(uuid.uuid4()),
"sourceName": link.source_name,
"sinkName": link.sink_name,
"agentNodeSourceId": link.source_id,
"agentNodeSinkId": link.sink_id,
"isStatic": link.is_static,
}
)
await AgentNode.prisma(tx).create_many(
data=[
{
"id": node.id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"agentBlockId": node.block_id,
"constantInput": Json(node.input_default),
"metadata": Json(node.metadata),
"webhookId": node.webhook_id,
}
for graph in graphs
for node in graph.nodes
]
)

await AgentNodeLink.prisma(tx).create_many(
data=[
{
"id": str(uuid.uuid4()),
"sourceName": link.source_name,
"sinkName": link.sink_name,
"agentNodeSourceId": link.source_id,
"agentNodeSinkId": link.sink_id,
"isStatic": link.is_static,
}
for graph in graphs
for link in graph.links
]
)

@@ -18,6 +18,8 @@ EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
"AgentGraphExecution": True,
}

MAX_NODE_EXECUTIONS_FETCH = 1000

GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
"AgentNodeExecutions": {
"include": {
@@ -25,10 +27,17 @@ GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
"Output": True,
"AgentNode": True,
"AgentGraphExecution": True,
}
},
"order_by": [
{"queuedTime": "desc"},
# Fallback: Incomplete execs has no queuedTime.
{"addedTime": "desc"},
],
"take": MAX_NODE_EXECUTIONS_FETCH, # Avoid loading excessive node executions.
}
}


INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
}

@@ -144,6 +144,7 @@ def SchemaField(
depends_on: list[str] | None = None,
image_upload: Optional[bool] = None,
image_output: Optional[bool] = None,
json_schema_extra: dict[str, Any] | None = None,
**kwargs,
) -> T:
if default is PydanticUndefined and default_factory is None:
@@ -151,7 +152,7 @@ def SchemaField(
elif advanced is None:
advanced = True

json_extra = {
json_schema_extra = {
k: v
for k, v in {
"placeholder": placeholder,
@@ -161,6 +162,7 @@ def SchemaField(
"depends_on": depends_on,
"image_upload": image_upload,
"image_output": image_output,
**(json_schema_extra or {}),
}.items()
if v is not None
}
@@ -172,7 +174,7 @@ def SchemaField(
title=title,
description=description,
exclude=exclude,
json_schema_extra=json_extra,
json_schema_extra=json_schema_extra,
**kwargs,
) # type: ignore

@@ -413,7 +415,6 @@ class NodeExecutionStats(BaseModel):
error: Optional[Exception | str] = None
walltime: float = 0
cputime: float = 0
cost: float = 0
input_size: int = 0
output_size: int = 0
llm_call_count: int = 0

@@ -1,5 +1,5 @@
import logging
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Annotated, Any, Generic, Optional, TypeVar, Union

@@ -18,7 +12 @@ from .db import transaction
logger = logging.getLogger(__name__)


T_co = TypeVar("T_co", bound="BaseNotificationData", covariant=True)
NotificationDataType_co = TypeVar(
"NotificationDataType_co", bound="BaseNotificationData", covariant=True
)
SummaryParamsType_co = TypeVar(
"SummaryParamsType_co", bound="BaseSummaryParams", covariant=True
)


class QueueType(Enum):
@@ -30,7 +35,8 @@ class QueueType(Enum):


class BaseNotificationData(BaseModel):
pass
class Config:
extra = "allow"


class AgentRunData(BaseNotificationData):
@@ -47,6 +53,13 @@ class ZeroBalanceData(BaseNotificationData):
last_transaction_time: datetime
top_up_link: str

@field_validator("last_transaction_time")
@classmethod
def validate_timezone(cls, value: datetime):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value


class LowBalanceData(BaseNotificationData):
agent_name: str = Field(..., description="Name of the agent")
@@ -75,6 +88,13 @@ class ContinuousAgentErrorData(BaseNotificationData):
error_time: datetime
attempts: int = Field(..., description="Number of retry attempts made")

@field_validator("start_time", "error_time")
@classmethod
def validate_timezone(cls, value: datetime):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value


class BaseSummaryData(BaseNotificationData):
total_credits_used: float
@@ -87,18 +107,53 @@ class BaseSummaryData(BaseNotificationData):
cost_breakdown: dict[str, float]


class BaseSummaryParams(BaseModel):
pass


class DailySummaryParams(BaseSummaryParams):
date: datetime

@field_validator("date")
def validate_timezone(cls, value):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value


class WeeklySummaryParams(BaseSummaryParams):
start_date: datetime
end_date: datetime

@field_validator("start_date", "end_date")
def validate_timezone(cls, value):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value


class DailySummaryData(BaseSummaryData):
date: datetime

@field_validator("date")
def validate_timezone(cls, value):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value


class WeeklySummaryData(BaseSummaryData):
start_date: datetime
end_date: datetime
week_number: int
year: int

@field_validator("start_date", "end_date")
def validate_timezone(cls, value):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value


class MonthlySummaryData(BaseSummaryData):
class MonthlySummaryData(BaseNotificationData):
month: int
year: int

@@ -125,6 +180,7 @@ NotificationData = Annotated[
WeeklySummaryData,
DailySummaryData,
RefundRequestData,
BaseSummaryData,
],
Field(discriminator="type"),
]
@@ -134,15 +190,22 @@ class NotificationEventDTO(BaseModel):
user_id: str
type: NotificationType
data: dict
created_at: datetime = Field(default_factory=datetime.now)
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
retry_count: int = 0


class NotificationEventModel(BaseModel, Generic[T_co]):
class SummaryParamsEventDTO(BaseModel):
user_id: str
type: NotificationType
data: T_co
created_at: datetime = Field(default_factory=datetime.now)
data: dict
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))


class NotificationEventModel(BaseModel, Generic[NotificationDataType_co]):
user_id: str
type: NotificationType
data: NotificationDataType_co
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))

@property
def strategy(self) -> QueueType:
@@ -159,7 +222,14 @@ class NotificationEventModel(BaseModel, Generic[T_co]):
return NotificationTypeOverride(self.type).template


def get_data_type(
class SummaryParamsEventModel(BaseModel, Generic[SummaryParamsType_co]):
user_id: str
type: NotificationType
data: SummaryParamsType_co
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))


def get_notif_data_type(
notification_type: NotificationType,
) -> type[BaseNotificationData]:
return {
@@ -176,11 +246,20 @@ def get_data_type(
}[notification_type]


def get_summary_params_type(
notification_type: NotificationType,
) -> type[BaseSummaryParams]:
return {
NotificationType.DAILY_SUMMARY: DailySummaryParams,
NotificationType.WEEKLY_SUMMARY: WeeklySummaryParams,
}[notification_type]


class NotificationBatch(BaseModel):
user_id: str
events: list[NotificationEvent]
strategy: QueueType
last_update: datetime = datetime.now()
last_update: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))


class NotificationResult(BaseModel):
@@ -258,12 +337,51 @@ class NotificationPreference(BaseModel):
)
daily_limit: int = 10 # Max emails per day
emails_sent_today: int = 0
last_reset_date: datetime = Field(default_factory=datetime.now)
last_reset_date: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc)
)


class UserNotificationEventDTO(BaseModel):
type: NotificationType
data: dict
created_at: datetime
updated_at: datetime

@staticmethod
def from_db(model: NotificationEvent) -> "UserNotificationEventDTO":
return UserNotificationEventDTO(
type=model.type,
data=dict(model.data),
created_at=model.createdAt,
updated_at=model.updatedAt,
)


class UserNotificationBatchDTO(BaseModel):
user_id: str
type: NotificationType
notifications: list[UserNotificationEventDTO]
created_at: datetime
updated_at: datetime

@staticmethod
def from_db(model: UserNotificationBatch) -> "UserNotificationBatchDTO":
return UserNotificationBatchDTO(
user_id=model.userId,
type=model.type,
notifications=[
UserNotificationEventDTO.from_db(notification)
for notification in model.Notifications or []
],
created_at=model.createdAt,
updated_at=model.updatedAt,
)


def get_batch_delay(notification_type: NotificationType) -> timedelta:
return {
NotificationType.AGENT_RUN: timedelta(minutes=1),
NotificationType.AGENT_RUN: timedelta(minutes=60),
NotificationType.ZERO_BALANCE: timedelta(minutes=60),
NotificationType.LOW_BALANCE: timedelta(minutes=60),
NotificationType.BLOCK_EXECUTION_FAILED: timedelta(minutes=60),
@@ -275,7 +393,7 @@ async def create_or_add_to_user_notification_batch(
user_id: str,
notification_type: NotificationType,
notification_data: NotificationEventModel,
) -> UserNotificationBatch:
) -> UserNotificationBatchDTO:
try:
logger.info(
f"Creating or adding to notification batch for {user_id} with type {notification_type} and data {notification_data}"
@@ -292,7 +410,7 @@ async def create_or_add_to_user_notification_batch(
"type": notification_type,
}
},
include={"notifications": True},
include={"Notifications": True},
)

if not existing_batch:
@@ -309,11 +427,11 @@ async def create_or_add_to_user_notification_batch(
data={
"userId": user_id,
"type": notification_type,
"notifications": {"connect": [{"id": notification_event.id}]},
"Notifications": {"connect": [{"id": notification_event.id}]},
},
include={"notifications": True},
include={"Notifications": True},
)
return resp
return UserNotificationBatchDTO.from_db(resp)
else:
async with transaction() as tx:
notification_event = await tx.notificationevent.create(
@@ -327,15 +445,15 @@ async def create_or_add_to_user_notification_batch(
resp = await tx.usernotificationbatch.update(
where={"id": existing_batch.id},
data={
"notifications": {"connect": [{"id": notification_event.id}]}
"Notifications": {"connect": [{"id": notification_event.id}]}
},
include={"notifications": True},
include={"Notifications": True},
)
if not resp:
raise DatabaseError(
f"Failed to add notification event {notification_event.id} to existing batch {existing_batch.id}"
)
return resp
return UserNotificationBatchDTO.from_db(resp)
except Exception as e:
raise DatabaseError(
f"Failed to create or add to notification batch for user {user_id} and type {notification_type}: {e}"
@@ -345,18 +463,23 @@ async def create_or_add_to_user_notification_batch(
async def get_user_notification_oldest_message_in_batch(
user_id: str,
notification_type: NotificationType,
) -> NotificationEvent | None:
) -> UserNotificationEventDTO | None:
try:
batch = await UserNotificationBatch.prisma().find_first(
where={"userId": user_id, "type": notification_type},
include={"notifications": True},
include={"Notifications": True},
)
if not batch:
return None
if not batch.notifications:
if not batch.Notifications:
return None
sorted_notifications = sorted(batch.notifications, key=lambda x: x.createdAt)
return sorted_notifications[0]
sorted_notifications = sorted(batch.Notifications, key=lambda x: x.createdAt)

return (
UserNotificationEventDTO.from_db(sorted_notifications[0])
if sorted_notifications
else None
)
except Exception as e:
raise DatabaseError(
f"Failed to get user notification last message in batch for user {user_id} and type {notification_type}: {e}"
@@ -391,12 +514,13 @@ async def empty_user_notification_batch(
async def get_user_notification_batch(
user_id: str,
notification_type: NotificationType,
) -> UserNotificationBatch | None:
) -> UserNotificationBatchDTO | None:
try:
return await UserNotificationBatch.prisma().find_first(
batch = await UserNotificationBatch.prisma().find_first(
where={"userId": user_id, "type": notification_type},
include={"notifications": True},
include={"Notifications": True},
)
return UserNotificationBatchDTO.from_db(batch) if batch else None
except Exception as e:
raise DatabaseError(
f"Failed to get user notification batch for user {user_id} and type {notification_type}: {e}"
@@ -405,17 +529,18 @@ async def get_user_notification_batch(

async def get_all_batches_by_type(
notification_type: NotificationType,
) -> list[UserNotificationBatch]:
) -> list[UserNotificationBatchDTO]:
try:
return await UserNotificationBatch.prisma().find_many(
batches = await UserNotificationBatch.prisma().find_many(
where={
"type": notification_type,
"notifications": {
"Notifications": {
"some": {} # Only return batches with at least one notification
},
},
include={"notifications": True},
include={"Notifications": True},
)
return [UserNotificationBatchDTO.from_db(batch) for batch in batches]
except Exception as e:
raise DatabaseError(
f"Failed to get all batches by type {notification_type}: {e}"

@@ -1,15 +1,12 @@
from backend.app import run_processes
from backend.executor import DatabaseManager, ExecutionManager
from backend.executor import ExecutionManager


def main():
"""
Run all the processes required for the AutoGPT-server REST API.
"""
run_processes(
DatabaseManager(),
ExecutionManager(),
)
run_processes(ExecutionManager())


if __name__ == "__main__":

@@ -1,13 +1,13 @@
from backend.data.credit import get_user_credit_model
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
ExecutionResult,
NodeExecutionEntry,
RedisExecutionEventBus,
create_graph_execution,
get_execution_results,
get_incomplete_executions,
get_output_from_links,
update_execution_status,
update_execution_status_batch,
update_graph_execution_start_time,
update_graph_execution_stats,
update_node_execution_stats,
@@ -20,9 +20,20 @@ from backend.data.graph import (
get_graph_metadata,
get_node,
)
from backend.data.notifications import (
create_or_add_to_user_notification_batch,
empty_user_notification_batch,
get_all_batches_by_type,
get_user_notification_batch,
get_user_notification_oldest_message_in_batch,
)
from backend.data.user import (
get_active_user_ids_in_timerange,
get_user_email_by_id,
get_user_email_verification,
get_user_integrations,
get_user_metadata,
get_user_notification_preference,
update_user_integrations,
update_user_metadata,
)
@@ -33,8 +44,10 @@ config = Config()
_user_credit_model = get_user_credit_model()


async def _spend_credits(entry: NodeExecutionEntry) -> int:
return await _user_credit_model.spend_credits(entry, 0, 0)
async def _spend_credits(
user_id: str, cost: int, metadata: UsageTransactionMetadata
) -> int:
return await _user_credit_model.spend_credits(user_id, cost, metadata)


class DatabaseManager(AppService):
@@ -58,6 +71,7 @@ class DatabaseManager(AppService):
get_incomplete_executions = exposed_run_and_wait(get_incomplete_executions)
get_output_from_links = exposed_run_and_wait(get_output_from_links)
update_execution_status = exposed_run_and_wait(update_execution_status)
update_execution_status_batch = exposed_run_and_wait(update_execution_status_batch)
update_graph_execution_start_time = exposed_run_and_wait(
update_graph_execution_start_time
)
@@ -80,3 +94,24 @@ class DatabaseManager(AppService):
update_user_metadata = exposed_run_and_wait(update_user_metadata)
get_user_integrations = exposed_run_and_wait(get_user_integrations)
update_user_integrations = exposed_run_and_wait(update_user_integrations)

# User Comms - async
get_active_user_ids_in_timerange = exposed_run_and_wait(
get_active_user_ids_in_timerange
)
get_user_email_by_id = exposed_run_and_wait(get_user_email_by_id)
get_user_email_verification = exposed_run_and_wait(get_user_email_verification)
get_user_notification_preference = exposed_run_and_wait(
get_user_notification_preference
)

# Notifications - async
create_or_add_to_user_notification_batch = exposed_run_and_wait(
create_or_add_to_user_notification_batch
)
empty_user_notification_batch = exposed_run_and_wait(empty_user_notification_batch)
get_all_batches_by_type = exposed_run_and_wait(get_all_batches_by_type)
get_user_notification_batch = exposed_run_and_wait(get_user_notification_batch)
get_user_notification_oldest_message_in_batch = exposed_run_and_wait(
get_user_notification_oldest_message_in_batch
)

@@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar, cast
|
||||
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from backend.blocks.basic import AgentOutputBlock
|
||||
from backend.blocks.io import AgentOutputBlock
|
||||
from backend.data.model import GraphExecutionStats, NodeExecutionStats
|
||||
from backend.data.notifications import (
|
||||
AgentRunData,
|
||||
@@ -48,6 +48,11 @@ from backend.data.execution import (
|
||||
parse_execution_output,
|
||||
)
|
||||
from backend.data.graph import GraphModel, Link, Node
|
||||
from backend.executor.utils import (
|
||||
UsageTransactionMetadata,
|
||||
block_usage_cost,
|
||||
execution_usage_cost,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util import json
|
||||
from backend.util.decorator import error_logged, time_measured
|
||||
@@ -206,11 +211,7 @@ def execute_node(
|
||||
extra_exec_kwargs[field_name] = credentials
|
||||
|
||||
output_size = 0
|
||||
cost = 0
|
||||
try:
|
||||
# Charge the user for the execution before running the block.
|
||||
cost = db_client.spend_credits(data)
|
||||
|
||||
outputs: dict[str, Any] = {}
|
||||
for output_name, output_data in node_block.execute(
|
||||
input_data, **extra_exec_kwargs
|
||||
@@ -266,7 +267,6 @@ def execute_node(
|
||||
)
|
||||
execution_stats.input_size = input_size
|
||||
execution_stats.output_size = output_size
|
||||
execution_stats.cost = cost
|
||||
|
||||
|
||||
def _enqueue_next_nodes(
|
||||
@@ -657,6 +657,53 @@ class Executor:
|
||||
|
||||
cls._handle_agent_run_notif(graph_exec, exec_stats)
|
||||
|
||||
@classmethod
|
||||
def _charge_usage(
|
||||
cls,
|
||||
node_exec: NodeExecutionEntry,
|
||||
execution_count: int,
|
||||
execution_stats: GraphExecutionStats,
|
||||
) -> int:
|
||||
block = get_block(node_exec.block_id)
|
||||
if not block:
|
||||
logger.error(f"Block {node_exec.block_id} not found.")
|
||||
return execution_count
|
||||
|
||||
cost, matching_filter = block_usage_cost(block=block, input_data=node_exec.data)
|
||||
if cost > 0:
|
||||
cls.db_client.spend_credits(
|
||||
user_id=node_exec.user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
graph_exec_id=node_exec.graph_exec_id,
|
||||
graph_id=node_exec.graph_id,
|
||||
node_exec_id=node_exec.node_exec_id,
|
||||
node_id=node_exec.node_id,
|
||||
block_id=node_exec.block_id,
|
||||
block=block.name,
|
||||
input=matching_filter,
|
||||
),
|
||||
)
|
||||
execution_stats.cost += cost
|
||||
|
||||
cost, execution_count = execution_usage_cost(execution_count)
|
||||
if cost > 0:
|
||||
cls.db_client.spend_credits(
|
||||
user_id=node_exec.user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
graph_exec_id=node_exec.graph_exec_id,
|
||||
graph_id=node_exec.graph_id,
|
||||
input={
|
||||
"execution_count": execution_count,
|
||||
"charge": "Execution Cost",
|
||||
},
|
||||
),
|
||||
)
|
||||
execution_stats.cost += cost
|
||||
|
||||
return execution_count
|
||||
|
||||
@classmethod
|
||||
@time_measured
|
||||
def _on_graph_execution(
|
||||
@@ -691,14 +738,10 @@ class Executor:
|
||||
try:
|
||||
queue = ExecutionQueue[NodeExecutionEntry]()
|
||||
for node_exec in graph_exec.start_node_execs:
|
||||
exec_update = cls.db_client.update_execution_status(
|
||||
node_exec.node_exec_id, ExecutionStatus.QUEUED, node_exec.data
|
||||
)
|
||||
cls.db_client.send_execution_update(exec_update)
|
||||
queue.add(node_exec)
|
||||
|
||||
exec_cost_counter = 0
|
||||
running_executions: dict[str, AsyncResult] = {}
|
||||
low_balance_error: Optional[InsufficientBalanceError] = None
|
||||
|
||||
def make_exec_callback(exec_data: NodeExecutionEntry):
|
||||
|
||||
@@ -708,17 +751,13 @@ class Executor:
|
||||
if not isinstance(result, NodeExecutionStats):
|
||||
return
|
||||
|
||||
nonlocal exec_stats, low_balance_error
|
||||
nonlocal exec_stats
|
||||
exec_stats.node_count += 1
|
||||
exec_stats.nodes_cputime += result.cputime
|
||||
exec_stats.nodes_walltime += result.walltime
|
||||
exec_stats.cost += result.cost
|
||||
if (err := result.error) and isinstance(err, Exception):
|
||||
exec_stats.node_error_count += 1
|
||||
|
||||
if isinstance(err, InsufficientBalanceError):
|
||||
low_balance_error = err
|
||||
|
||||
return callback
|
||||
|
||||
while not queue.empty():
|
||||
@@ -740,6 +779,30 @@ class Executor:
|
||||
f"Dispatching node execution {exec_data.node_exec_id} "
|
||||
f"for node {exec_data.node_id}",
|
||||
)
|
||||
|
||||
try:
|
||||
exec_cost_counter = cls._charge_usage(
|
||||
node_exec=exec_data,
|
||||
execution_count=exec_cost_counter + 1,
|
||||
execution_stats=exec_stats,
|
||||
)
|
||||
except InsufficientBalanceError as error:
|
||||
exec_id = exec_data.node_exec_id
|
||||
cls.db_client.upsert_execution_output(exec_id, "error", str(error))
|
||||
|
||||
exec_update = cls.db_client.update_execution_status(
|
||||
exec_id, ExecutionStatus.FAILED
|
||||
)
|
||||
cls.db_client.send_execution_update(exec_update)
|
||||
|
||||
cls._handle_low_balance_notif(
|
||||
graph_exec.user_id,
|
||||
graph_exec.graph_id,
|
||||
exec_stats,
|
||||
error,
|
||||
)
|
||||
raise
|
||||
|
||||
running_executions[exec_data.node_id] = cls.executor.apply_async(
|
||||
cls.on_node_execution,
|
||||
(queue, exec_data),
|
||||
@@ -763,32 +826,24 @@ class Executor:
|
||||
|
||||
log_metadata.info(f"Finished graph execution {graph_exec.graph_exec_id}")
|
||||
|
||||
if isinstance(low_balance_error, InsufficientBalanceError):
|
||||
cls._handle_low_balance_notif(
|
||||
graph_exec.user_id,
|
||||
graph_exec.graph_id,
|
||||
exec_stats,
|
||||
low_balance_error,
|
||||
)
|
||||
raise low_balance_error
|
||||
|
||||
except Exception as e:
|
||||
log_metadata.exception(
|
||||
f"Failed graph execution {graph_exec.graph_exec_id}: {e}"
|
||||
)
|
||||
error = e
|
||||
finally:
|
||||
if error:
|
||||
log_metadata.error(
|
||||
f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
|
||||
)
|
||||
execution_status = ExecutionStatus.FAILED
|
||||
else:
|
||||
execution_status = ExecutionStatus.COMPLETED
|
||||
|
||||
if not cancel.is_set():
|
||||
finished = True
|
||||
cancel.set()
|
||||
cancel_thread.join()
|
||||
clean_exec_files(graph_exec.graph_exec_id)
|
||||
|
||||
return (
|
||||
exec_stats,
|
||||
ExecutionStatus.FAILED if error else ExecutionStatus.COMPLETED,
|
||||
error,
|
||||
)
|
||||
return exec_stats, execution_status, error
|
||||
|
||||
@classmethod
|
||||
def _handle_agent_run_notif(
|
||||
@@ -799,7 +854,10 @@ class Executor:
|
||||
metadata = cls.db_client.get_graph_metadata(
|
||||
graph_exec.graph_id, graph_exec.graph_version
|
||||
)
|
||||
outputs = cls.db_client.get_execution_results(graph_exec.graph_exec_id)
|
||||
outputs = cls.db_client.get_execution_results(
|
||||
graph_exec.graph_exec_id,
|
||||
block_ids=[AgentOutputBlock().id],
|
||||
)
|
||||
|
||||
named_outputs = [
|
||||
{
|
||||
@@ -807,7 +865,6 @@ class Executor:
|
||||
for key, value in output.output_data.items()
|
||||
}
|
||||
for output in outputs
|
||||
if output.block_id == AgentOutputBlock().id
|
||||
]
|
||||
|
||||
event = NotificationEventDTO(
|
||||
@@ -1001,29 +1058,36 @@ class ExecutionManager(AppService):
|
||||
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
|
||||
"""
|
||||
if graph_exec_id not in self.active_graph_runs:
|
||||
raise Exception(
|
||||
logger.warning(
|
||||
f"Graph execution #{graph_exec_id} not active/running: "
|
||||
"possibly already completed/cancelled."
|
||||
)
|
||||
else:
|
||||
future, cancel_event = self.active_graph_runs[graph_exec_id]
|
||||
if not cancel_event.is_set():
|
||||
cancel_event.set()
|
||||
future.result()
|
||||
|
||||
future, cancel_event = self.active_graph_runs[graph_exec_id]
|
||||
if cancel_event.is_set():
|
||||
return
|
||||
|
||||
cancel_event.set()
|
||||
future.result()
|
||||
|
||||
# Update the status of the unfinished node executions
|
||||
node_execs = self.db_client.get_execution_results(graph_exec_id)
|
||||
# Update the status of the graph & node executions
|
||||
self.db_client.update_graph_execution_stats(
|
||||
graph_exec_id,
|
||||
ExecutionStatus.TERMINATED,
|
||||
)
|
||||
node_execs = self.db_client.get_execution_results(
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=[
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.RUNNING,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
],
|
||||
)
|
||||
self.db_client.update_execution_status_batch(
|
||||
[node_exec.node_exec_id for node_exec in node_execs],
|
||||
ExecutionStatus.TERMINATED,
|
||||
)
|
||||
for node_exec in node_execs:
|
||||
if node_exec.status not in (
|
||||
ExecutionStatus.COMPLETED,
|
||||
ExecutionStatus.FAILED,
|
||||
):
|
||||
exec_update = self.db_client.update_execution_status(
|
||||
node_exec.node_exec_id, ExecutionStatus.TERMINATED
|
||||
)
|
||||
self.db_client.send_execution_update(exec_update)
|
||||
node_exec.status = ExecutionStatus.TERMINATED
|
||||
self.db_client.send_execution_update(node_exec)
|
||||
|
||||
def _validate_node_input_credentials(self, graph: GraphModel, user_id: str):
|
||||
"""Checks all credentials for all nodes of the graph"""
|
||||
|
||||
@@ -5,6 +5,7 @@ from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
|
||||
from apscheduler.events import EVENT_JOB_ERROR, EVENT_JOB_EXECUTED
|
||||
from apscheduler.job import Job as JobObj
|
||||
from apscheduler.jobstores.memory import MemoryJobStore
|
||||
from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
|
||||
from apscheduler.schedulers.blocking import BlockingScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
@@ -93,9 +94,18 @@ def process_existing_batches(**kwargs):
|
||||
logger.exception(f"Error processing existing batches: {e}")
|
||||
|
||||
|
||||
def process_weekly_summary(**kwargs):
|
||||
try:
|
||||
log("Processing weekly summary")
|
||||
get_notification_client().queue_weekly_summary()
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing weekly summary: {e}")
|
||||
|
||||
|
||||
class Jobstores(Enum):
|
||||
EXECUTION = "execution"
|
||||
BATCHED_NOTIFICATIONS = "batched_notifications"
|
||||
WEEKLY_NOTIFICATIONS = "weekly_notifications"
|
||||
|
||||
|
||||
class ExecutionJobArgs(BaseModel):
|
||||
@@ -189,6 +199,8 @@ class Scheduler(AppService):
|
||||
metadata=MetaData(schema=db_schema),
|
||||
tablename="apscheduler_jobs_batched_notifications",
|
||||
),
|
||||
# These don't really need persistence
|
||||
Jobstores.WEEKLY_NOTIFICATIONS.value: MemoryJobStore(),
|
||||
}
|
||||
)
|
||||
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
|
||||
@@ -242,6 +254,9 @@ class Scheduler(AppService):
|
||||
) -> list[ExecutionJobInfo]:
|
||||
schedules = []
|
||||
for job in self.scheduler.get_jobs(jobstore=Jobstores.EXECUTION.value):
|
||||
logger.info(
|
||||
f"Found job {job.id} with cron schedule {job.trigger} and args {job.kwargs}"
|
||||
)
|
||||
job_args = ExecutionJobArgs(**job.kwargs)
|
||||
if (
|
||||
job.next_run_time is not None
|
||||
@@ -271,3 +286,21 @@ class Scheduler(AppService):
|
||||
)
|
||||
log(f"Added job {job.id} with cron schedule '{cron}' input data: {data}")
|
||||
return NotificationJobInfo.from_db(job_args, job)
|
||||
|
||||
@expose
|
||||
def add_weekly_notification_schedule(self, cron: str) -> NotificationJobInfo:
|
||||
|
||||
job = self.scheduler.add_job(
|
||||
process_weekly_summary,
|
||||
CronTrigger.from_crontab(cron),
|
||||
kwargs={},
|
||||
replace_existing=True,
|
||||
jobstore=Jobstores.WEEKLY_NOTIFICATIONS.value,
|
||||
)
|
||||
log(f"Added job {job.id} with cron schedule '{cron}'")
|
||||
return NotificationJobInfo.from_db(
|
||||
NotificationJobArgs(
|
||||
cron=cron, notification_types=[NotificationType.WEEKLY_SUMMARY]
|
||||
),
|
||||
job,
|
||||
)
|
||||
|
||||
97
autogpt_platform/backend/backend/executor/utils.py
Normal file
@@ -0,0 +1,97 @@
from pydantic import BaseModel

from backend.data.block import Block, BlockInput
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.cost import BlockCostType
from backend.util.settings import Config

config = Config()


class UsageTransactionMetadata(BaseModel):
    graph_exec_id: str | None = None
    graph_id: str | None = None
    node_id: str | None = None
    node_exec_id: str | None = None
    block_id: str | None = None
    block: str | None = None
    input: BlockInput | None = None


def execution_usage_cost(execution_count: int) -> tuple[int, int]:
    """
    Calculate the cost of executing a graph based on the number of executions.

    Args:
        execution_count: Number of executions

    Returns:
        Tuple of cost amount and remaining execution count
    """
    return (
        execution_count
        // config.execution_cost_count_threshold
        * config.execution_cost_per_threshold,
        execution_count % config.execution_cost_count_threshold,
    )


def block_usage_cost(
    block: Block,
    input_data: BlockInput,
    data_size: float = 0,
    run_time: float = 0,
) -> tuple[int, BlockInput]:
    """
    Calculate the cost of using a block based on the input data and the block type.

    Args:
        block: Block object
        input_data: Input data for the block
        data_size: Size of the input data in bytes
        run_time: Execution time of the block in seconds

    Returns:
        Tuple of cost amount and cost filter
    """
    block_costs = BLOCK_COSTS.get(type(block))
    if not block_costs:
        return 0, {}

    for block_cost in block_costs:
        if not _is_cost_filter_match(block_cost.cost_filter, input_data):
            continue

        if block_cost.cost_type == BlockCostType.RUN:
            return block_cost.cost_amount, block_cost.cost_filter

        if block_cost.cost_type == BlockCostType.SECOND:
            return (
                int(run_time * block_cost.cost_amount),
                block_cost.cost_filter,
            )

        if block_cost.cost_type == BlockCostType.BYTE:
            return (
                int(data_size * block_cost.cost_amount),
                block_cost.cost_filter,
            )

    return 0, {}


def _is_cost_filter_match(cost_filter: BlockInput, input_data: BlockInput) -> bool:
    """
    Filter rules:
    - If cost_filter is an object, then check if cost_filter is the subset of input_data
    - Otherwise, check if cost_filter is equal to input_data.
    - Undefined, null, and empty string are considered as equal.
    """
    if not isinstance(cost_filter, dict) or not isinstance(input_data, dict):
        return cost_filter == input_data

    return all(
        (not input_data.get(k) and not v)
        or (input_data.get(k) and _is_cost_filter_match(v, input_data[k]))
        for k, v in cost_filter.items()
    )
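
A minimal, self-contained sketch of how the new cost helpers above behave. The threshold values and the example cost filter are assumptions for illustration only, not values taken from the repo; the real code reads them from Config() and BLOCK_COSTS.

# Illustration only: standalone re-implementation of the two helpers above,
# with assumed config values (count threshold = 100, 1 credit per threshold).
EXECUTION_COST_COUNT_THRESHOLD = 100  # assumed value
EXECUTION_COST_PER_THRESHOLD = 1  # assumed value


def execution_usage_cost_sketch(execution_count: int) -> tuple[int, int]:
    # Charge once per full threshold crossed; carry the remainder to the next call.
    return (
        execution_count
        // EXECUTION_COST_COUNT_THRESHOLD
        * EXECUTION_COST_PER_THRESHOLD,
        execution_count % EXECUTION_COST_COUNT_THRESHOLD,
    )


def is_cost_filter_match_sketch(cost_filter, input_data) -> bool:
    # A filter matches when it is a recursive subset of the input data;
    # missing keys, None, and "" are all treated as "empty" and equal.
    if not isinstance(cost_filter, dict) or not isinstance(input_data, dict):
        return cost_filter == input_data
    return all(
        (not input_data.get(k) and not v)
        or (input_data.get(k) and is_cost_filter_match_sketch(v, input_data[k]))
        for k, v in cost_filter.items()
    )


print(execution_usage_cost_sketch(250))  # (2, 50): two thresholds charged, 50 carried over
print(is_cost_filter_match_sketch({"model": "gpt-4o"}, {"model": "gpt-4o", "prompt": "hi"}))  # True
print(is_cost_filter_match_sketch({"model": "gpt-4o"}, {"model": "other"}))  # False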
@@ -1,22 +1,43 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .compass import CompassWebhookManager
|
||||
from .github import GithubWebhooksManager
|
||||
from .slant3d import Slant3DWebhooksManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..providers import ProviderName
|
||||
from ._base import BaseWebhooksManager
|
||||
|
||||
# --8<-- [start:WEBHOOK_MANAGERS_BY_NAME]
|
||||
WEBHOOK_MANAGERS_BY_NAME: dict["ProviderName", type["BaseWebhooksManager"]] = {
|
||||
handler.PROVIDER_NAME: handler
|
||||
for handler in [
|
||||
CompassWebhookManager,
|
||||
GithubWebhooksManager,
|
||||
Slant3DWebhooksManager,
|
||||
]
|
||||
}
|
||||
# --8<-- [end:WEBHOOK_MANAGERS_BY_NAME]
|
||||
_WEBHOOK_MANAGERS: dict["ProviderName", type["BaseWebhooksManager"]] = {}
|
||||
|
||||
__all__ = ["WEBHOOK_MANAGERS_BY_NAME"]
|
||||
|
||||
# --8<-- [start:load_webhook_managers]
|
||||
def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]:
|
||||
if _WEBHOOK_MANAGERS:
|
||||
return _WEBHOOK_MANAGERS
|
||||
|
||||
from .compass import CompassWebhookManager
|
||||
from .github import GithubWebhooksManager
|
||||
from .slant3d import Slant3DWebhooksManager
|
||||
|
||||
_WEBHOOK_MANAGERS.update(
|
||||
{
|
||||
handler.PROVIDER_NAME: handler
|
||||
for handler in [
|
||||
CompassWebhookManager,
|
||||
GithubWebhooksManager,
|
||||
Slant3DWebhooksManager,
|
||||
]
|
||||
}
|
||||
)
|
||||
return _WEBHOOK_MANAGERS
|
||||
|
||||
|
||||
# --8<-- [end:load_webhook_managers]
|
||||
|
||||
|
||||
def get_webhook_manager(provider_name: "ProviderName") -> "BaseWebhooksManager":
|
||||
return load_webhook_managers()[provider_name]()
|
||||
|
||||
|
||||
def supports_webhooks(provider_name: "ProviderName") -> bool:
|
||||
return provider_name in load_webhook_managers()
|
||||
|
||||
|
||||
__all__ = ["get_webhook_manager", "supports_webhooks"]
|
||||
|
||||
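
A small, generic sketch of the lazy-loading registry pattern that load_webhook_managers() above introduces. The stub manager classes and provider names here are hypothetical, purely to show the memoization behaviour:

# Hypothetical, self-contained illustration of a memoized lazy registry.
_REGISTRY: dict[str, type] = {}


class GithubManagerStub:
    PROVIDER_NAME = "github"


class CompassManagerStub:
    PROVIDER_NAME = "compass"


def load_managers() -> dict[str, type]:
    # Populate the mapping only on first use; later calls return the cached dict,
    # so the (potentially heavy) manager imports stay deferred until needed.
    if _REGISTRY:
        return _REGISTRY
    _REGISTRY.update(
        {m.PROVIDER_NAME: m for m in (GithubManagerStub, CompassManagerStub)}
    )
    return _REGISTRY


def supports_webhooks_sketch(provider: str) -> bool:
    return provider in load_managers()


print(supports_webhooks_sketch("github"))  # True
print(supports_webhooks_sketch("slack"))   # False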
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Callable, Optional, cast
|
||||
|
||||
from backend.data.block import BlockSchema, BlockWebhookConfig, get_block
|
||||
from backend.data.graph import set_node_webhook
|
||||
from backend.integrations.webhooks import WEBHOOK_MANAGERS_BY_NAME
|
||||
from backend.integrations.webhooks import get_webhook_manager, supports_webhooks
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.graph import GraphModel, NodeModel
|
||||
@@ -123,7 +123,7 @@ async def on_node_activate(
|
||||
return node
|
||||
|
||||
provider = block.webhook_config.provider
|
||||
if provider not in WEBHOOK_MANAGERS_BY_NAME:
|
||||
if not supports_webhooks(provider):
|
||||
raise ValueError(
|
||||
f"Block #{block.id} has webhook_config for provider {provider} "
|
||||
"which does not support webhooks"
|
||||
@@ -133,7 +133,7 @@ async def on_node_activate(
|
||||
f"Activating webhook node #{node.id} with config {block.webhook_config}"
|
||||
)
|
||||
|
||||
webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
|
||||
webhooks_manager = get_webhook_manager(provider)
|
||||
|
||||
if auto_setup_webhook := isinstance(block.webhook_config, BlockWebhookConfig):
|
||||
try:
|
||||
@@ -234,13 +234,13 @@ async def on_node_deactivate(
|
||||
return node
|
||||
|
||||
provider = block.webhook_config.provider
|
||||
if provider not in WEBHOOK_MANAGERS_BY_NAME:
|
||||
if not supports_webhooks(provider):
|
||||
raise ValueError(
|
||||
f"Block #{block.id} has webhook_config for provider {provider} "
|
||||
"which does not support webhooks"
|
||||
)
|
||||
|
||||
webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
|
||||
webhooks_manager = get_webhook_manager(provider)
|
||||
|
||||
if node.webhook_id:
|
||||
logger.debug(f"Node #{node.id} has webhook_id {node.webhook_id}")
|
||||
|
||||
@@ -7,9 +7,9 @@ from prisma.enums import NotificationType
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.notifications import (
|
||||
NotificationDataType_co,
|
||||
NotificationEventModel,
|
||||
NotificationTypeOverride,
|
||||
T_co,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.text import TextFormatter
|
||||
@@ -48,7 +48,10 @@ class EmailSender:
|
||||
self,
|
||||
notification: NotificationType,
|
||||
user_email: str,
|
||||
data: NotificationEventModel[T_co] | list[NotificationEventModel[T_co]],
|
||||
data: (
|
||||
NotificationEventModel[NotificationDataType_co]
|
||||
| list[NotificationEventModel[NotificationDataType_co]]
|
||||
),
|
||||
user_unsub_link: str | None = None,
|
||||
):
|
||||
"""Send an email to a user using a template pulled from the notification type"""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Callable
|
||||
|
||||
import aio_pika
|
||||
@@ -10,25 +10,25 @@ from prisma.enums import NotificationType
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.notifications import (
|
||||
BaseSummaryData,
|
||||
BaseSummaryParams,
|
||||
DailySummaryData,
|
||||
DailySummaryParams,
|
||||
NotificationEventDTO,
|
||||
NotificationEventModel,
|
||||
NotificationResult,
|
||||
NotificationTypeOverride,
|
||||
QueueType,
|
||||
create_or_add_to_user_notification_batch,
|
||||
empty_user_notification_batch,
|
||||
get_all_batches_by_type,
|
||||
SummaryParamsEventDTO,
|
||||
SummaryParamsEventModel,
|
||||
WeeklySummaryData,
|
||||
WeeklySummaryParams,
|
||||
get_batch_delay,
|
||||
get_data_type,
|
||||
get_user_notification_batch,
|
||||
get_user_notification_oldest_message_in_batch,
|
||||
get_notif_data_type,
|
||||
get_summary_params_type,
|
||||
)
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.data.user import (
|
||||
generate_unsubscribe_link,
|
||||
get_user_email_by_id,
|
||||
get_user_email_verification,
|
||||
get_user_notification_preference,
|
||||
)
|
||||
from backend.data.user import generate_unsubscribe_link
|
||||
from backend.notifications.email import EmailSender
|
||||
from backend.util.service import AppService, expose, get_service_client
|
||||
from backend.util.settings import Settings
|
||||
@@ -68,6 +68,16 @@ def create_notification_config() -> RabbitMQConfig:
|
||||
"x-dead-letter-routing-key": "failed.admin",
|
||||
},
|
||||
),
|
||||
# Summary notification queues
|
||||
Queue(
|
||||
name="summary_notifications",
|
||||
exchange=notification_exchange,
|
||||
routing_key="notification.summary.#",
|
||||
arguments={
|
||||
"x-dead-letter-exchange": dead_letter_exchange.name,
|
||||
"x-dead-letter-routing-key": "failed.summary",
|
||||
},
|
||||
),
|
||||
# Batch Queue
|
||||
Queue(
|
||||
name="batch_notifications",
|
||||
@@ -102,12 +112,18 @@ def get_scheduler():
|
||||
return get_service_client(Scheduler)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_db():
|
||||
from backend.executor.database import DatabaseManager
|
||||
|
||||
return get_service_client(DatabaseManager)
|
||||
|
||||
|
||||
class NotificationManager(AppService):
|
||||
"""Service for handling notifications with batching support"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.use_db = True
|
||||
self.rabbitmq_config = create_notification_config()
|
||||
self.running = True
|
||||
self.email_sender = EmailSender()
|
||||
@@ -116,19 +132,51 @@ class NotificationManager(AppService):
|
||||
def get_port(cls) -> int:
|
||||
return settings.config.notification_service_port
|
||||
|
||||
def get_routing_key(self, event: NotificationEventModel) -> str:
|
||||
def get_routing_key(self, event_type: NotificationType) -> str:
|
||||
strategy = NotificationTypeOverride(event_type).strategy
|
||||
"""Get the appropriate routing key for an event"""
|
||||
if event.strategy == QueueType.IMMEDIATE:
|
||||
return f"notification.immediate.{event.type.value}"
|
||||
elif event.strategy == QueueType.BACKOFF:
|
||||
return f"notification.backoff.{event.type.value}"
|
||||
elif event.strategy == QueueType.ADMIN:
|
||||
return f"notification.admin.{event.type.value}"
|
||||
elif event.strategy == QueueType.BATCH:
|
||||
return f"notification.batch.{event.type.value}"
|
||||
elif event.strategy == QueueType.SUMMARY:
|
||||
return f"notification.summary.{event.type.value}"
|
||||
return f"notification.{event.type.value}"
|
||||
if strategy == QueueType.IMMEDIATE:
|
||||
return f"notification.immediate.{event_type.value}"
|
||||
elif strategy == QueueType.BACKOFF:
|
||||
return f"notification.backoff.{event_type.value}"
|
||||
elif strategy == QueueType.ADMIN:
|
||||
return f"notification.admin.{event_type.value}"
|
||||
elif strategy == QueueType.BATCH:
|
||||
return f"notification.batch.{event_type.value}"
|
||||
elif strategy == QueueType.SUMMARY:
|
||||
return f"notification.summary.{event_type.value}"
|
||||
return f"notification.{event_type.value}"
|
||||
|
||||
@expose
|
||||
def queue_weekly_summary(self):
|
||||
"""Process weekly summary for specified notification types"""
|
||||
try:
|
||||
logger.info("Processing weekly summary queuing operation")
|
||||
processed_count = 0
|
||||
current_time = datetime.now(tz=timezone.utc)
|
||||
start_time = current_time - timedelta(days=7)
|
||||
users = get_db().get_active_user_ids_in_timerange(
|
||||
end_time=current_time.isoformat(),
|
||||
start_time=start_time.isoformat(),
|
||||
)
|
||||
for user in users:
|
||||
|
||||
self._queue_scheduled_notification(
|
||||
SummaryParamsEventDTO(
|
||||
user_id=user,
|
||||
type=NotificationType.WEEKLY_SUMMARY,
|
||||
data=WeeklySummaryParams(
|
||||
start_date=start_time,
|
||||
end_date=current_time,
|
||||
).model_dump(),
|
||||
),
|
||||
)
|
||||
processed_count += 1
|
||||
|
||||
logger.info(f"Processed {processed_count} weekly summaries into queue")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing weekly summary: {e}")
|
||||
|
||||
@expose
|
||||
def process_existing_batches(self, notification_types: list[NotificationType]):
|
||||
@@ -139,80 +187,74 @@ class NotificationManager(AppService):
|
||||
|
||||
for notification_type in notification_types:
|
||||
# Get all batches for this notification type
|
||||
batches = self.run_and_wait(get_all_batches_by_type(notification_type))
|
||||
batches = get_db().get_all_batches_by_type(notification_type)
|
||||
|
||||
for batch in batches:
|
||||
# Check if batch has aged out
|
||||
oldest_message = self.run_and_wait(
|
||||
get_user_notification_oldest_message_in_batch(
|
||||
batch.userId, notification_type
|
||||
oldest_message = (
|
||||
get_db().get_user_notification_oldest_message_in_batch(
|
||||
batch.user_id, notification_type
|
||||
)
|
||||
)
|
||||
|
||||
if not oldest_message:
|
||||
# this should never happen
|
||||
logger.error(
|
||||
f"Batch for user {batch.userId} and type {notification_type} has no oldest message whichshould never happen!!!!!!!!!!!!!!!!"
|
||||
f"Batch for user {batch.user_id} and type {notification_type} has no oldest message whichshould never happen!!!!!!!!!!!!!!!!"
|
||||
)
|
||||
continue
|
||||
|
||||
max_delay = get_batch_delay(notification_type)
|
||||
|
||||
# If batch has aged out, process it
|
||||
if oldest_message.createdAt + max_delay < current_time:
|
||||
recipient_email = self.run_and_wait(
|
||||
get_user_email_by_id(batch.userId)
|
||||
)
|
||||
if oldest_message.created_at + max_delay < current_time:
|
||||
recipient_email = get_db().get_user_email_by_id(batch.user_id)
|
||||
|
||||
if not recipient_email:
|
||||
logger.error(
|
||||
f"User email not found for user {batch.userId}"
|
||||
f"User email not found for user {batch.user_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
should_send = self._should_email_user_based_on_preference(
|
||||
batch.userId, notification_type
|
||||
batch.user_id, notification_type
|
||||
)
|
||||
|
||||
if not should_send:
|
||||
logger.debug(
|
||||
f"User {batch.userId} does not want to receive {notification_type} notifications"
|
||||
f"User {batch.user_id} does not want to receive {notification_type} notifications"
|
||||
)
|
||||
# Clear the batch
|
||||
self.run_and_wait(
|
||||
empty_user_notification_batch(
|
||||
batch.userId, notification_type
|
||||
)
|
||||
get_db().empty_user_notification_batch(
|
||||
batch.user_id, notification_type
|
||||
)
|
||||
continue
|
||||
|
||||
batch_data = self.run_and_wait(
|
||||
get_user_notification_batch(batch.userId, notification_type)
|
||||
batch_data = get_db().get_user_notification_batch(
|
||||
batch.user_id, notification_type
|
||||
)
|
||||
|
||||
if not batch_data or not batch_data.notifications:
|
||||
logger.error(
|
||||
f"Batch data not found for user {batch.userId}"
|
||||
f"Batch data not found for user {batch.user_id}"
|
||||
)
|
||||
# Clear the batch
|
||||
self.run_and_wait(
|
||||
empty_user_notification_batch(
|
||||
batch.userId, notification_type
|
||||
)
|
||||
get_db().empty_user_notification_batch(
|
||||
batch.user_id, notification_type
|
||||
)
|
||||
continue
|
||||
|
||||
unsub_link = generate_unsubscribe_link(batch.userId)
|
||||
unsub_link = generate_unsubscribe_link(batch.user_id)
|
||||
|
||||
events = [
|
||||
NotificationEventModel[
|
||||
get_data_type(db_event.type)
|
||||
get_notif_data_type(db_event.type)
|
||||
].model_validate(
|
||||
{
|
||||
"user_id": batch.userId,
|
||||
"user_id": batch.user_id,
|
||||
"type": db_event.type,
|
||||
"data": db_event.data,
|
||||
"created_at": db_event.createdAt,
|
||||
"created_at": db_event.created_at,
|
||||
}
|
||||
)
|
||||
for db_event in batch_data.notifications
|
||||
@@ -227,10 +269,8 @@ class NotificationManager(AppService):
|
||||
)
|
||||
|
||||
# Clear the batch
|
||||
self.run_and_wait(
|
||||
empty_user_notification_batch(
|
||||
batch.userId, notification_type
|
||||
)
|
||||
get_db().empty_user_notification_batch(
|
||||
batch.user_id, notification_type
|
||||
)
|
||||
|
||||
processed_count += 1
|
||||
@@ -259,9 +299,9 @@ class NotificationManager(AppService):
|
||||
logger.info(f"Received Request to queue {event=}")
|
||||
# Workaround for not being able to serialize generics over the expose bus
|
||||
parsed_event = NotificationEventModel[
|
||||
get_data_type(event.type)
|
||||
get_notif_data_type(event.type)
|
||||
].model_validate(event.model_dump())
|
||||
routing_key = self.get_routing_key(parsed_event)
|
||||
routing_key = self.get_routing_key(parsed_event.type)
|
||||
message = parsed_event.model_dump_json()
|
||||
|
||||
logger.info(f"Received Request to queue {message=}")
|
||||
@@ -288,24 +328,136 @@ class NotificationManager(AppService):
|
||||
logger.exception(f"Error queueing notification: {e}")
|
||||
return NotificationResult(success=False, message=str(e))
|
||||
|
||||
def _queue_scheduled_notification(self, event: SummaryParamsEventDTO):
|
||||
"""Queue a scheduled notification - exposed method for other services to call"""
|
||||
try:
|
||||
logger.info(f"Received Request to queue scheduled notification {event=}")
|
||||
|
||||
parsed_event = SummaryParamsEventModel[
|
||||
get_summary_params_type(event.type)
|
||||
].model_validate(event.model_dump())
|
||||
|
||||
routing_key = self.get_routing_key(event.type)
|
||||
message = parsed_event.model_dump_json()
|
||||
|
||||
logger.info(f"Received Request to queue {message=}")
|
||||
|
||||
exchange = "notifications"
|
||||
|
||||
# Publish to RabbitMQ
|
||||
self.run_and_wait(
|
||||
self.rabbit.publish_message(
|
||||
routing_key=routing_key,
|
||||
message=message,
|
||||
exchange=next(
|
||||
ex for ex in self.rabbit_config.exchanges if ex.name == exchange
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error queueing notification: {e}")
|
||||
|
||||
def _should_email_user_based_on_preference(
|
||||
self, user_id: str, event_type: NotificationType
|
||||
) -> bool:
|
||||
"""Check if a user wants to receive a notification based on their preferences and email verification status"""
|
||||
validated_email = self.run_and_wait(get_user_email_verification(user_id))
|
||||
preference = self.run_and_wait(
|
||||
get_user_notification_preference(user_id)
|
||||
).preferences.get(event_type, True)
|
||||
validated_email = get_db().get_user_email_verification(user_id)
|
||||
preference = (
|
||||
get_db()
|
||||
.get_user_notification_preference(user_id)
|
||||
.preferences.get(event_type, True)
|
||||
)
|
||||
# only if both are true, should we email this person
|
||||
return validated_email and preference
|
||||
|
||||
async def _should_batch(
|
||||
def _gather_summary_data(
|
||||
self, user_id: str, event_type: NotificationType, params: BaseSummaryParams
|
||||
) -> BaseSummaryData:
|
||||
"""Gathers the data to build a summary notification"""
|
||||
|
||||
logger.info(
|
||||
f"Gathering summary data for {user_id} and {event_type} wiht {params=}"
|
||||
)
|
||||
|
||||
# total_credits_used = self.run_and_wait(
|
||||
# get_total_credits_used(user_id, start_time, end_time)
|
||||
# )
|
||||
|
||||
# total_executions = self.run_and_wait(
|
||||
# get_total_executions(user_id, start_time, end_time)
|
||||
# )
|
||||
|
||||
# most_used_agent = self.run_and_wait(
|
||||
# get_most_used_agent(user_id, start_time, end_time)
|
||||
# )
|
||||
|
||||
# execution_times = self.run_and_wait(
|
||||
# get_execution_time(user_id, start_time, end_time)
|
||||
# )
|
||||
|
||||
# runs = self.run_and_wait(
|
||||
# get_runs(user_id, start_time, end_time)
|
||||
# )
|
||||
total_credits_used = 3.0
|
||||
total_executions = 2
|
||||
most_used_agent = {"name": "Some"}
|
||||
execution_times = [1, 2, 3]
|
||||
runs = [{"status": "COMPLETED"}, {"status": "FAILED"}]
|
||||
|
||||
successful_runs = len([run for run in runs if run["status"] == "COMPLETED"])
|
||||
failed_runs = len([run for run in runs if run["status"] != "COMPLETED"])
|
||||
average_execution_time = (
|
||||
sum(execution_times) / len(execution_times) if execution_times else 0
|
||||
)
|
||||
# cost_breakdown = self.run_and_wait(
|
||||
# get_cost_breakdown(user_id, start_time, end_time)
|
||||
# )
|
||||
|
||||
cost_breakdown = {
|
||||
"agent1": 1.0,
|
||||
"agent2": 2.0,
|
||||
}
|
||||
|
||||
if event_type == NotificationType.DAILY_SUMMARY and isinstance(
|
||||
params, DailySummaryParams
|
||||
):
|
||||
return DailySummaryData(
|
||||
total_credits_used=total_credits_used,
|
||||
total_executions=total_executions,
|
||||
most_used_agent=most_used_agent["name"],
|
||||
total_execution_time=sum(execution_times),
|
||||
successful_runs=successful_runs,
|
||||
failed_runs=failed_runs,
|
||||
average_execution_time=average_execution_time,
|
||||
cost_breakdown=cost_breakdown,
|
||||
date=params.date,
|
||||
)
|
||||
elif event_type == NotificationType.WEEKLY_SUMMARY and isinstance(
|
||||
params, WeeklySummaryParams
|
||||
):
|
||||
return WeeklySummaryData(
|
||||
total_credits_used=total_credits_used,
|
||||
total_executions=total_executions,
|
||||
most_used_agent=most_used_agent["name"],
|
||||
total_execution_time=sum(execution_times),
|
||||
successful_runs=successful_runs,
|
||||
failed_runs=failed_runs,
|
||||
average_execution_time=average_execution_time,
|
||||
cost_breakdown=cost_breakdown,
|
||||
start_date=params.start_date,
|
||||
end_date=params.end_date,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid event type or params")
|
||||
|
||||
def _should_batch(
|
||||
self, user_id: str, event_type: NotificationType, event: NotificationEventModel
|
||||
) -> bool:
|
||||
|
||||
await create_or_add_to_user_notification_batch(user_id, event_type, event)
|
||||
get_db().create_or_add_to_user_notification_batch(user_id, event_type, event)
|
||||
|
||||
oldest_message = await get_user_notification_oldest_message_in_batch(
|
||||
oldest_message = get_db().get_user_notification_oldest_message_in_batch(
|
||||
user_id, event_type
|
||||
)
|
||||
if not oldest_message:
|
||||
@@ -313,7 +465,7 @@ class NotificationManager(AppService):
|
||||
f"Batch for user {user_id} and type {event_type} has no oldest message whichshould never happen!!!!!!!!!!!!!!!!"
|
||||
)
|
||||
return False
|
||||
oldest_age = oldest_message.createdAt
|
||||
oldest_age = oldest_message.created_at
|
||||
|
||||
max_delay = get_batch_delay(event_type)
|
||||
|
||||
@@ -329,7 +481,7 @@ class NotificationManager(AppService):
|
||||
try:
|
||||
event = NotificationEventDTO.model_validate_json(message)
|
||||
model = NotificationEventModel[
|
||||
get_data_type(event.type)
|
||||
get_notif_data_type(event.type)
|
||||
].model_validate_json(message)
|
||||
return NotificationEvent(event=event, model=model)
|
||||
except Exception as e:
|
||||
@@ -362,7 +514,7 @@ class NotificationManager(AppService):
|
||||
model = parsed.model
|
||||
logger.debug(f"Processing immediate notification: {model}")
|
||||
|
||||
recipient_email = self.run_and_wait(get_user_email_by_id(event.user_id))
|
||||
recipient_email = get_db().get_user_email_by_id(event.user_id)
|
||||
if not recipient_email:
|
||||
logger.error(f"User email not found for user {event.user_id}")
|
||||
return False
|
||||
@@ -399,7 +551,7 @@ class NotificationManager(AppService):
|
||||
model = parsed.model
|
||||
logger.info(f"Processing batch notification: {model}")
|
||||
|
||||
recipient_email = self.run_and_wait(get_user_email_by_id(event.user_id))
|
||||
recipient_email = get_db().get_user_email_by_id(event.user_id)
|
||||
if not recipient_email:
|
||||
logger.error(f"User email not found for user {event.user_id}")
|
||||
return False
|
||||
@@ -413,28 +565,26 @@ class NotificationManager(AppService):
|
||||
)
|
||||
return True
|
||||
|
||||
should_send = self.run_and_wait(
|
||||
self._should_batch(event.user_id, event.type, model)
|
||||
)
|
||||
should_send = self._should_batch(event.user_id, event.type, model)
|
||||
|
||||
if not should_send:
|
||||
logger.info("Batch not old enough to send")
|
||||
return False
|
||||
batch = self.run_and_wait(
|
||||
get_user_notification_batch(event.user_id, event.type)
|
||||
)
|
||||
batch = get_db().get_user_notification_batch(event.user_id, event.type)
|
||||
if not batch or not batch.notifications:
|
||||
logger.error(f"Batch not found for user {event.user_id}")
|
||||
return False
|
||||
unsub_link = generate_unsubscribe_link(event.user_id)
|
||||
|
||||
batch_messages = [
|
||||
NotificationEventModel[get_data_type(db_event.type)].model_validate(
|
||||
NotificationEventModel[
|
||||
get_notif_data_type(db_event.type)
|
||||
].model_validate(
|
||||
{
|
||||
"user_id": event.user_id,
|
||||
"type": db_event.type,
|
||||
"data": db_event.data,
|
||||
"created_at": db_event.createdAt,
|
||||
"created_at": db_event.created_at,
|
||||
}
|
||||
)
|
||||
for db_event in batch.notifications
|
||||
@@ -447,12 +597,59 @@ class NotificationManager(AppService):
|
||||
user_unsub_link=unsub_link,
|
||||
)
|
||||
# only empty the batch if we sent the email successfully
|
||||
self.run_and_wait(empty_user_notification_batch(event.user_id, event.type))
|
||||
get_db().empty_user_notification_batch(event.user_id, event.type)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing notification for batch queue: {e}")
|
||||
return False
|
||||
|
||||
def _process_summary(self, message: str) -> bool:
|
||||
"""Process a single notification with a summary strategy, returning whether to put into the failed queue"""
|
||||
try:
|
||||
logger.info(f"Processing summary notification: {message}")
|
||||
event = SummaryParamsEventDTO.model_validate_json(message)
|
||||
model = SummaryParamsEventModel[
|
||||
get_summary_params_type(event.type)
|
||||
].model_validate_json(message)
|
||||
|
||||
logger.info(f"Processing summary notification: {model}")
|
||||
|
||||
recipient_email = get_db().get_user_email_by_id(event.user_id)
|
||||
if not recipient_email:
|
||||
logger.error(f"User email not found for user {event.user_id}")
|
||||
return False
|
||||
should_send = self._should_email_user_based_on_preference(
|
||||
event.user_id, event.type
|
||||
)
|
||||
if not should_send:
|
||||
logger.info(
|
||||
f"User {event.user_id} does not want to receive {event.type} notifications"
|
||||
)
|
||||
return True
|
||||
|
||||
summary_data = self._gather_summary_data(
|
||||
event.user_id, event.type, model.data
|
||||
)
|
||||
|
||||
unsub_link = generate_unsubscribe_link(event.user_id)
|
||||
|
||||
data = NotificationEventModel(
|
||||
user_id=event.user_id,
|
||||
type=event.type,
|
||||
data=summary_data,
|
||||
)
|
||||
|
||||
self.email_sender.send_templated(
|
||||
notification=event.type,
|
||||
user_email=recipient_email,
|
||||
data=data,
|
||||
user_unsub_link=unsub_link,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing notification for summary queue: {e}")
|
||||
return False
|
||||
|
||||
def _run_queue(
|
||||
self,
|
||||
queue: aio_pika.abc.AbstractQueue,
|
||||
@@ -493,6 +690,10 @@ class NotificationManager(AppService):
|
||||
data={},
|
||||
cron="0 * * * *",
|
||||
)
|
||||
# get_scheduler().add_weekly_notification_schedule(
|
||||
# # weekly on Friday at 12pm
|
||||
# cron="0 12 * * 5",
|
||||
# )
|
||||
logger.info("Scheduled notification cleanup")
|
||||
except Exception as e:
|
||||
logger.error(f"Error scheduling notification cleanup: {e}")
|
||||
@@ -507,6 +708,8 @@ class NotificationManager(AppService):
|
||||
|
||||
admin_queue = self.run_and_wait(channel.get_queue("admin_notifications"))
|
||||
|
||||
summary_queue = self.run_and_wait(channel.get_queue("summary_notifications"))
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
self._run_queue(
|
||||
@@ -525,6 +728,12 @@ class NotificationManager(AppService):
|
||||
error_queue_name="batch_notifications",
|
||||
)
|
||||
|
||||
self._run_queue(
|
||||
queue=summary_queue,
|
||||
process_func=self._process_summary,
|
||||
error_queue_name="summary_notifications",
|
||||
)
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
except QueueEmpty as e:
|
||||
|
||||
@@ -0,0 +1,27 @@
{# Weekly Summary #}
{# Template variables:
data: the summary data object; fields listed below
data.start_date: the start date of the summary
data.end_date: the end date of the summary
data.total_credits_used: the total credits used during the summary
data.total_executions: the total number of executions during the summary
data.most_used_agent: the most used agent's name during the summary
data.total_execution_time: the total execution time during the summary
data.successful_runs: the total number of successful runs during the summary
data.failed_runs: the total number of failed runs during the summary
data.average_execution_time: the average execution time during the summary
data.cost_breakdown: the cost breakdown during the summary
#}

<h1>Weekly Summary</h1>

<p>Start Date: {{ data.start_date }}</p>
<p>End Date: {{ data.end_date }}</p>
<p>Total Credits Used: {{ data.total_credits_used }}</p>
<p>Total Executions: {{ data.total_executions }}</p>
<p>Most Used Agent: {{ data.most_used_agent }}</p>
<p>Total Execution Time: {{ data.total_execution_time }}</p>
<p>Successful Runs: {{ data.successful_runs }}</p>
<p>Failed Runs: {{ data.failed_runs }}</p>
<p>Average Execution Time: {{ data.average_execution_time }}</p>
<p>Cost Breakdown: {{ data.cost_breakdown }}</p>
@@ -20,23 +20,25 @@ class ConnectionManager:
|
||||
for subscribers in self.subscriptions.values():
|
||||
subscribers.discard(websocket)
|
||||
|
||||
async def subscribe(self, graph_id: str, graph_version: int, websocket: WebSocket):
|
||||
key = f"{graph_id}_{graph_version}"
|
||||
async def subscribe(
|
||||
self, *, user_id: str, graph_id: str, graph_version: int, websocket: WebSocket
|
||||
):
|
||||
key = f"{user_id}_{graph_id}_{graph_version}"
|
||||
if key not in self.subscriptions:
|
||||
self.subscriptions[key] = set()
|
||||
self.subscriptions[key].add(websocket)
|
||||
|
||||
async def unsubscribe(
|
||||
self, graph_id: str, graph_version: int, websocket: WebSocket
|
||||
self, *, user_id: str, graph_id: str, graph_version: int, websocket: WebSocket
|
||||
):
|
||||
key = f"{graph_id}_{graph_version}"
|
||||
key = f"{user_id}_{graph_id}_{graph_version}"
|
||||
if key in self.subscriptions:
|
||||
self.subscriptions[key].discard(websocket)
|
||||
if not self.subscriptions[key]:
|
||||
del self.subscriptions[key]
|
||||
|
||||
async def send_execution_result(self, result: execution.ExecutionResult):
|
||||
key = f"{result.graph_id}_{result.graph_version}"
|
||||
key = f"{result.user_id}_{result.graph_id}_{result.graph_version}"
|
||||
if key in self.subscriptions:
|
||||
message = WsMessage(
|
||||
method=Methods.EXECUTION_EVENT,
|
||||
|
||||
@@ -71,7 +71,7 @@ def get_outputs_with_names(results: List[ExecutionResult]) -> List[Dict[str, str
|
||||
)
|
||||
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
blocks = [block() for block in backend.data.block.get_blocks().values()]
|
||||
return [b.to_dict() for b in blocks]
|
||||
return [b.to_dict() for b in blocks if not b.disabled]
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
|
||||
@@ -17,7 +17,7 @@ from backend.executor.manager import ExecutionManager
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import WEBHOOK_MANAGERS_BY_NAME
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
from backend.util.exceptions import NeedConfirmation
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Settings
|
||||
@@ -281,7 +281,7 @@ async def webhook_ingress_generic(
|
||||
webhook_id: Annotated[str, Path(title="Our ID for the webhook")],
|
||||
):
|
||||
logger.debug(f"Received {provider.value} webhook ingress for ID {webhook_id}")
|
||||
webhook_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
|
||||
webhook_manager = get_webhook_manager(provider)
|
||||
webhook = await get_webhook(webhook_id)
|
||||
logger.debug(f"Webhook #{webhook_id}: {webhook}")
|
||||
payload, event_type = await webhook_manager.validate_payload(webhook, request)
|
||||
@@ -323,7 +323,7 @@ async def webhook_ping(
|
||||
user_id: Annotated[str, Depends(get_user_id)], # require auth
|
||||
):
|
||||
webhook = await get_webhook(webhook_id)
|
||||
webhook_manager = WEBHOOK_MANAGERS_BY_NAME[webhook.provider]()
|
||||
webhook_manager = get_webhook_manager(webhook.provider)
|
||||
|
||||
credentials = (
|
||||
creds_manager.get(user_id, webhook.credentials_id)
|
||||
@@ -358,14 +358,6 @@ async def remove_all_webhooks_for_credentials(
|
||||
NeedConfirmation: If any of the webhooks are still in use and `force` is `False`
|
||||
"""
|
||||
webhooks = await get_all_webhooks_by_creds(credentials.id)
|
||||
if credentials.provider not in WEBHOOK_MANAGERS_BY_NAME:
|
||||
if webhooks:
|
||||
logger.error(
|
||||
f"Credentials #{credentials.id} for provider {credentials.provider} "
|
||||
f"are attached to {len(webhooks)} webhooks, "
|
||||
f"but there is no available WebhooksHandler for {credentials.provider}"
|
||||
)
|
||||
return
|
||||
if any(w.attached_nodes for w in webhooks) and not force:
|
||||
raise NeedConfirmation(
|
||||
"Some webhooks linked to these credentials are still in use by an agent"
|
||||
@@ -376,7 +368,7 @@ async def remove_all_webhooks_for_credentials(
|
||||
await set_node_webhook(node.id, None)
|
||||
|
||||
# Prune the webhook
|
||||
webhook_manager = WEBHOOK_MANAGERS_BY_NAME[credentials.provider]()
|
||||
webhook_manager = get_webhook_manager(ProviderName(credentials.provider))
|
||||
success = await webhook_manager.prune_webhook_if_dangling(
|
||||
webhook.id, credentials
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ import backend.data.graph
|
||||
import backend.data.user
|
||||
import backend.server.integrations.router
|
||||
import backend.server.routers.v1
|
||||
import backend.server.v2.admin.store_admin_routes
|
||||
import backend.server.v2.library.db
|
||||
import backend.server.v2.library.model
|
||||
import backend.server.v2.library.routes
|
||||
@@ -99,6 +100,11 @@ app.include_router(backend.server.routers.v1.v1_router, tags=["v1"], prefix="/ap
|
||||
app.include_router(
|
||||
backend.server.v2.store.routes.router, tags=["v2"], prefix="/api/store"
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.admin.store_admin_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/store",
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.library.routes.router, tags=["v2"], prefix="/api/library"
|
||||
)
|
||||
@@ -154,9 +160,10 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
user_id: str,
|
||||
for_export: bool = False,
|
||||
):
|
||||
return await backend.server.routers.v1.get_graph(
|
||||
graph_id, user_id, graph_version
|
||||
graph_id, user_id, graph_version, for_export
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -249,12 +256,16 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
):
|
||||
return await backend.server.v2.store.routes.create_submission(request, user_id)
|
||||
|
||||
### ADMIN ###
|
||||
|
||||
@staticmethod
|
||||
async def test_review_store_listing(
|
||||
request: backend.server.v2.store.model.ReviewSubmissionRequest,
|
||||
user: autogpt_libs.auth.models.User,
|
||||
):
|
||||
return await backend.server.v2.store.routes.review_submission(request, user)
|
||||
return await backend.server.v2.admin.store_admin_routes.review_submission(
|
||||
request.store_listing_version_id, request, user
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def test_create_credentials(
|
||||
|
||||
@@ -38,7 +38,6 @@ from backend.data.credit import (
|
||||
TransactionHistory,
|
||||
get_auto_top_up,
|
||||
get_block_costs,
|
||||
get_stripe_customer_id,
|
||||
get_user_credit_model,
|
||||
set_auto_top_up,
|
||||
)
|
||||
@@ -199,7 +198,9 @@ async def get_onboarding_agents(
|
||||
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
blocks = [block() for block in backend.data.block.get_blocks().values()]
|
||||
costs = get_block_costs()
|
||||
return [{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks]
|
||||
return [
|
||||
{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks if not b.disabled
|
||||
]
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
@@ -341,15 +342,7 @@ async def stripe_webhook(request: Request):
|
||||
async def manage_payment_method(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> dict[str, str]:
|
||||
session = stripe.billing_portal.Session.create(
|
||||
customer=await get_stripe_customer_id(user_id),
|
||||
return_url=settings.config.frontend_base_url + "/profile/credits",
|
||||
)
|
||||
if not session:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Failed to create billing portal session"
|
||||
)
|
||||
return {"url": session.url}
|
||||
return {"url": await _user_credit_model.create_billing_portal_session(user_id)}
|
||||
|
||||
|
||||
@v1_router.get(path="/credits/transactions", dependencies=[Depends(auth_middleware)])
|
||||
@@ -405,10 +398,10 @@ async def get_graph(
|
||||
graph_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
version: int | None = None,
|
||||
hide_credentials: bool = False,
|
||||
for_export: bool = False,
|
||||
) -> graph_db.GraphModel:
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id, version, user_id=user_id, for_export=hide_credentials
|
||||
graph_id, version, user_id=user_id, for_export=for_export
|
||||
)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
@@ -438,6 +431,7 @@ async def create_new_graph(
|
||||
) -> graph_db.GraphModel:
|
||||
graph = graph_db.make_graph_model(create_graph.graph, user_id)
|
||||
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
|
||||
graph.validate_graph(for_run=False)
|
||||
|
||||
graph = await graph_db.create_graph(graph, user_id=user_id)
|
||||
|
||||
@@ -489,17 +483,10 @@ async def update_graph(
|
||||
latest_version_number = max(g.version for g in existing_versions)
|
||||
graph.version = latest_version_number + 1
|
||||
|
||||
latest_version_graph = next(
|
||||
v for v in existing_versions if v.version == latest_version_number
|
||||
)
|
||||
current_active_version = next((v for v in existing_versions if v.is_active), None)
|
||||
if latest_version_graph.is_template != graph.is_template:
|
||||
raise HTTPException(
|
||||
400, detail="Changing is_template on an existing graph is forbidden"
|
||||
)
|
||||
graph.is_active = not graph.is_template
|
||||
graph = graph_db.make_graph_model(graph, user_id)
|
||||
graph.reassign_ids(user_id=user_id)
|
||||
graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
|
||||
graph.validate_graph(for_run=False)
|
||||
|
||||
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
|
||||
|
||||
@@ -630,7 +617,7 @@ async def stop_graph_run(
|
||||
async def get_graphs_executions(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> list[graph_db.GraphExecutionMeta]:
|
||||
return await graph_db.get_graphs_executions(user_id=user_id)
|
||||
return await graph_db.get_graph_executions(user_id=user_id)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
@@ -655,12 +642,8 @@ async def get_graph_execution(
|
||||
graph_exec_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> graph_db.GraphExecution:
|
||||
graph = await graph_db.get_graph(graph_id, user_id=user_id)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
result = await graph_db.get_execution(execution_id=graph_exec_id, user_id=user_id)
|
||||
if not result:
|
||||
if not result or result.graph_id != graph_id:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Graph execution #{graph_exec_id} not found."
|
||||
)
|
||||
|
||||
@@ -0,0 +1,100 @@
import logging
import typing

import autogpt_libs.auth.depends
import autogpt_libs.auth.models
import fastapi
import fastapi.responses
import prisma.enums

import backend.server.v2.store.db
import backend.server.v2.store.exceptions
import backend.server.v2.store.model

logger = logging.getLogger(__name__)

router = fastapi.APIRouter(prefix="/admin", tags=["store", "admin"])


@router.get(
    "/listings",
    response_model=backend.server.v2.store.model.StoreListingsWithVersionsResponse,
    dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
)
async def get_admin_listings_with_versions(
    status: typing.Optional[prisma.enums.SubmissionStatus] = None,
    search: typing.Optional[str] = None,
    page: int = 1,
    page_size: int = 20,
):
    """
    Get store listings with their version history for admins.

    This provides a consolidated view of listings with their versions,
    allowing for an expandable UI in the admin dashboard.

    Args:
        status: Filter by submission status (PENDING, APPROVED, REJECTED)
        search: Search by name, description, or user email
        page: Page number for pagination
        page_size: Number of items per page

    Returns:
        StoreListingsWithVersionsResponse with listings and their versions
    """
    try:
        listings = await backend.server.v2.store.db.get_admin_listings_with_versions(
            status=status,
            search_query=search,
            page=page,
            page_size=page_size,
        )
        return listings
    except Exception as e:
        logger.exception("Error getting admin listings with versions: %s", e)
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={
                "detail": "An error occurred while retrieving listings with versions"
            },
        )


@router.post(
    "/submissions/{store_listing_version_id}/review",
    response_model=backend.server.v2.store.model.StoreSubmission,
    dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
)
async def review_submission(
    store_listing_version_id: str,
    request: backend.server.v2.store.model.ReviewSubmissionRequest,
    user: typing.Annotated[
        autogpt_libs.auth.models.User,
        fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user),
    ],
):
    """
    Review a store listing submission.

    Args:
        store_listing_version_id: ID of the submission to review
        request: Review details including approval status and comments
        user: Authenticated admin user performing the review

    Returns:
        StoreSubmission with updated review information
    """
    try:
        submission = await backend.server.v2.store.db.review_store_submission(
            store_listing_version_id=store_listing_version_id,
            is_approved=request.is_approved,
            external_comments=request.comments,
            internal_comments=request.internal_comments or "",
            reviewer_id=user.user_id,
        )
        return submission
    except Exception as e:
        logger.exception("Error reviewing submission: %s", e)
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={"detail": "An error occurred while reviewing the submission"},
        )
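A rough sketch of how the new listings endpoint could be exercised, in the same mocker/TestClient style used by the tests later in this changeset; the admin router's import path and the dependency override are assumptions, not part of the diff:

import fastapi
import fastapi.testclient
import pytest_mock

import autogpt_libs.auth.depends
import backend.server.v2.store.model

# Assumed import path for the admin router shown above.
from backend.server.v2.store import admin_routes

app = fastapi.FastAPI()
app.include_router(admin_routes.router)
app.dependency_overrides[autogpt_libs.auth.depends.requires_admin_user] = lambda: None
client = fastapi.testclient.TestClient(app)


def test_get_admin_listings(mocker: pytest_mock.MockFixture):
    mocked = backend.server.v2.store.model.StoreListingsWithVersionsResponse(
        listings=[],
        pagination=backend.server.v2.store.model.Pagination(
            total_items=0, total_pages=0, current_page=1, page_size=20
        ),
    )
    mock_db = mocker.patch(
        "backend.server.v2.store.db.get_admin_listings_with_versions",
        return_value=mocked,
    )

    response = client.get("/admin/listings", params={"status": "PENDING"})
    assert response.status_code == 200
    mock_db.assert_called_once()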
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
@@ -186,12 +185,7 @@ async def add_generated_agent_image(
|
||||
try:
|
||||
if not (image_url := await store_media.check_media_exists(user_id, filename)):
|
||||
# Generate agent image as JPEG
|
||||
if config.use_agent_image_generation_v2:
|
||||
image = await asyncio.to_thread(
|
||||
store_image_gen.generate_agent_image_v2, graph=graph
|
||||
)
|
||||
else:
|
||||
image = await store_image_gen.generate_agent_image(agent=graph)
|
||||
image = await store_image_gen.generate_agent_image(graph)
|
||||
|
||||
# Create UploadFile with the correct filename and content_type
|
||||
image_file = fastapi.UploadFile(file=image, filename=filename)
|
||||
|
||||
@@ -1,22 +1,14 @@
|
||||
from datetime import datetime
|
||||
|
||||
import prisma.enums
|
||||
import prisma.errors
|
||||
import prisma.models
|
||||
import pytest
|
||||
from prisma import Prisma
|
||||
|
||||
import backend.server.v2.library.db as db
|
||||
import backend.server.v2.store.exceptions
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def setup_prisma():
|
||||
# Don't register client if already registered
|
||||
try:
|
||||
Prisma()
|
||||
except prisma.errors.ClientAlreadyRegisteredError:
|
||||
pass
|
||||
yield
|
||||
from backend.data.db import connect
|
||||
from backend.data.includes import library_agent_include
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -31,7 +23,6 @@ async def test_get_library_agents(mocker):
|
||||
userId="test-user",
|
||||
isActive=True,
|
||||
createdAt=datetime.now(),
|
||||
isTemplate=False,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -56,7 +47,6 @@ async def test_get_library_agents(mocker):
|
||||
userId="other-user",
|
||||
isActive=True,
|
||||
createdAt=datetime.now(),
|
||||
isTemplate=False,
|
||||
),
|
||||
)
|
||||
]
|
||||
@@ -91,17 +81,17 @@ async def test_get_library_agents(mocker):
|
||||
assert result.pagination.page_size == 50
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_add_agent_to_library(mocker):
|
||||
await connect()
|
||||
# Mock data
|
||||
mock_store_listing = prisma.models.StoreListingVersion(
|
||||
mock_store_listing_data = prisma.models.StoreListingVersion(
|
||||
id="version123",
|
||||
version=1,
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
agentId="agent1",
|
||||
agentVersion=1,
|
||||
slug="test-agent",
|
||||
name="Test Agent",
|
||||
subHeading="Test Agent Subheading",
|
||||
imageUrls=["https://example.com/image.jpg"],
|
||||
@@ -110,7 +100,8 @@ async def test_add_agent_to_library(mocker):
|
||||
isFeatured=False,
|
||||
isDeleted=False,
|
||||
isAvailable=True,
|
||||
isApproved=True,
|
||||
storeListingId="listing123",
|
||||
submissionStatus=prisma.enums.SubmissionStatus.APPROVED,
|
||||
Agent=prisma.models.AgentGraph(
|
||||
id="agent1",
|
||||
version=1,
|
||||
@@ -119,21 +110,37 @@ async def test_add_agent_to_library(mocker):
|
||||
userId="creator",
|
||||
isActive=True,
|
||||
createdAt=datetime.now(),
|
||||
isTemplate=False,
|
||||
),
|
||||
)
|
||||
|
||||
mock_library_agent_data = prisma.models.LibraryAgent(
|
||||
id="ua1",
|
||||
userId="test-user",
|
||||
agentId=mock_store_listing_data.agentId,
|
||||
agentVersion=1,
|
||||
isCreatedByUser=False,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
isFavorite=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
Agent=mock_store_listing_data.Agent,
|
||||
)
|
||||
|
||||
# Mock prisma calls
|
||||
mock_store_listing_version = mocker.patch(
|
||||
"prisma.models.StoreListingVersion.prisma"
|
||||
)
|
||||
mock_store_listing_version.return_value.find_unique = mocker.AsyncMock(
|
||||
return_value=mock_store_listing
|
||||
return_value=mock_store_listing_data
|
||||
)
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.create = mocker.AsyncMock()
|
||||
mock_library_agent.return_value.create = mocker.AsyncMock(
|
||||
return_value=mock_library_agent_data
|
||||
)
|
||||
|
||||
# Call function
|
||||
await db.add_store_agent_to_library("version123", "test-user")
|
||||
@@ -147,17 +154,20 @@ async def test_add_agent_to_library(mocker):
|
||||
"userId": "test-user",
|
||||
"agentId": "agent1",
|
||||
"agentVersion": 1,
|
||||
}
|
||||
},
|
||||
include=library_agent_include("test-user"),
|
||||
)
|
||||
mock_library_agent.return_value.create.assert_called_once_with(
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
userId="test-user", agentId="agent1", agentVersion=1, isCreatedByUser=False
|
||||
)
|
||||
),
|
||||
include=library_agent_include("test-user"),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_add_agent_to_library_not_found(mocker):
|
||||
await connect()
|
||||
# Mock prisma calls
|
||||
mock_store_listing_version = mocker.patch(
|
||||
"prisma.models.StoreListingVersion.prisma"
|
||||
|
||||
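These library tests repeatedly use the same mocking pattern: patch the model's .prisma() accessor and give its methods AsyncMock return values. A condensed, self-contained version of that pattern (not a test from the repository):

import pytest
import pytest_mock


@pytest.mark.asyncio
async def test_find_unique_is_mocked(mocker: pytest_mock.MockFixture):
    # Patch the accessor so every call to StoreListingVersion.prisma() returns a mock
    # whose async methods can be controlled per test.
    mock_prisma = mocker.patch("prisma.models.StoreListingVersion.prisma")
    mock_prisma.return_value.find_unique = mocker.AsyncMock(return_value=None)

    import prisma.models

    result = await prisma.models.StoreListingVersion.prisma().find_unique(
        where={"id": "missing"}
    )

    assert result is None
    mock_prisma.return_value.find_unique.assert_awaited_once()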
@@ -2,11 +2,14 @@ import datetime
|
||||
|
||||
import prisma.fields
|
||||
import prisma.models
|
||||
import pytest
|
||||
|
||||
import backend.server.v2.library.model as library_model
|
||||
from backend.util import json
|
||||
|
||||
|
||||
def test_agent_preset_from_db():
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_preset_from_db():
|
||||
# Create mock DB agent
|
||||
db_agent = prisma.models.AgentPreset(
|
||||
id="test-agent-123",
|
||||
@@ -24,7 +27,7 @@ def test_agent_preset_from_db():
|
||||
id="input-123",
|
||||
time=datetime.datetime.now(),
|
||||
name="input1",
|
||||
data=prisma.fields.Json({"type": "string", "value": "test value"}),
|
||||
data=json.dumps({"type": "string", "value": "test value"}), # type: ignore
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import datetime
|
||||
|
||||
import autogpt_libs.auth as autogpt_auth_lib
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
@@ -30,49 +29,48 @@ app.dependency_overrides[autogpt_auth_lib.auth_middleware] = override_auth_middl
|
||||
app.dependency_overrides[autogpt_auth_lib.depends.get_user_id] = override_get_user_id
|
||||
|
||||
|
||||
def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
|
||||
mocked_value = [
|
||||
library_model.LibraryAgentResponse(
|
||||
agents=[
|
||||
library_model.LibraryAgent(
|
||||
id="test-agent-1",
|
||||
agent_id="test-agent-1",
|
||||
agent_version=1,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
creator_name="Test Creator",
|
||||
creator_image_url="",
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
new_output=False,
|
||||
can_access_graph=True,
|
||||
is_latest_version=True,
|
||||
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
|
||||
),
|
||||
library_model.LibraryAgent(
|
||||
id="test-agent-2",
|
||||
agent_id="test-agent-2",
|
||||
agent_version=1,
|
||||
name="Test Agent 2",
|
||||
description="Test Description 2",
|
||||
image_url=None,
|
||||
creator_name="Test Creator",
|
||||
creator_image_url="",
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
new_output=False,
|
||||
can_access_graph=False,
|
||||
is_latest_version=True,
|
||||
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
|
||||
),
|
||||
],
|
||||
pagination=server_model.Pagination(
|
||||
total_items=2, total_pages=1, current_page=1, page_size=50
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
|
||||
mocked_value = library_model.LibraryAgentResponse(
|
||||
agents=[
|
||||
library_model.LibraryAgent(
|
||||
id="test-agent-1",
|
||||
agent_id="test-agent-1",
|
||||
agent_version=1,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
creator_name="Test Creator",
|
||||
creator_image_url="",
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
new_output=False,
|
||||
can_access_graph=True,
|
||||
is_latest_version=True,
|
||||
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
|
||||
),
|
||||
library_model.LibraryAgent(
|
||||
id="test-agent-2",
|
||||
agent_id="test-agent-2",
|
||||
agent_version=1,
|
||||
name="Test Agent 2",
|
||||
description="Test Description 2",
|
||||
image_url=None,
|
||||
creator_name="Test Creator",
|
||||
creator_image_url="",
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
new_output=False,
|
||||
can_access_graph=False,
|
||||
is_latest_version=True,
|
||||
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
|
||||
),
|
||||
],
|
||||
pagination=server_model.Pagination(
|
||||
total_items=2, total_pages=1, current_page=1, page_size=50
|
||||
),
|
||||
]
|
||||
mock_db_call = mocker.patch("backend.server.v2.library.db.get_library_agents")
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.server.v2.library.db.list_library_agents")
|
||||
mock_db_call.return_value = mocked_value
|
||||
|
||||
response = client.get("/agents?search_term=test")
|
||||
@@ -94,7 +92,7 @@ def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
|
||||
|
||||
|
||||
def test_get_library_agents_error(mocker: pytest_mock.MockFixture):
|
||||
mock_db_call = mocker.patch("backend.server.v2.library.db.get_library_agents")
|
||||
mock_db_call = mocker.patch("backend.server.v2.library.db.list_library_agents")
|
||||
mock_db_call.side_effect = Exception("Test error")
|
||||
|
||||
response = client.get("/agents?search_term=test")
|
||||
|
||||
@@ -4,15 +4,11 @@ from autogpt_libs.auth.middleware import auth_middleware
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from backend.server.utils import get_user_id
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .models import ApiResponse, ChatRequest
|
||||
from .service import OttoService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
OTTO_API_URL = settings.config.otto_api_url
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
import asyncio
import logging
from typing import Optional

@@ -67,6 +68,13 @@ class OttoService:
        """
        Send request to Otto API and handle the response.
        """
        # Check if Otto API URL is configured
        if not OTTO_API_URL:
            logger.error("Otto API URL is not configured")
            raise HTTPException(
                status_code=503, detail="Otto service is not configured"
            )

        try:
            async with aiohttp.ClientSession() as session:
                headers = {
@@ -94,7 +102,10 @@ class OttoService:
                logger.debug(f"Request payload: {payload}")

                async with session.post(
                    OTTO_API_URL, json=payload, headers=headers
                    OTTO_API_URL,
                    json=payload,
                    headers=headers,
                    timeout=aiohttp.ClientTimeout(total=60),
                ) as response:
                    if response.status != 200:
                        error_text = await response.text()
@@ -115,6 +126,11 @@ class OttoService:
            raise HTTPException(
                status_code=503, detail="Failed to connect to Otto service"
            )
        except asyncio.TimeoutError:
            logger.error("Timeout error connecting to Otto API after 60 seconds")
            raise HTTPException(
                status_code=504, detail="Request to Otto service timed out"
            )
        except Exception as e:
            logger.error(f"Unexpected error in Otto API proxy: {str(e)}")
            raise HTTPException(

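The Otto changes above add a hard request timeout and a dedicated asyncio.TimeoutError branch. A standalone illustration of that aiohttp pattern, with a placeholder URL rather than the real Otto configuration:

import asyncio
import logging

import aiohttp

logger = logging.getLogger(__name__)


async def post_with_timeout(url: str, payload: dict, total_seconds: float = 60) -> dict:
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(
                url,
                json=payload,
                # Bound the whole request, not just the connect phase.
                timeout=aiohttp.ClientTimeout(total=total_seconds),
            ) as response:
                response.raise_for_status()
                return await response.json()
    except asyncio.TimeoutError:
        # aiohttp raises asyncio.TimeoutError subclasses when the budget is exceeded.
        logger.error("Request to %s timed out after %ss", url, total_seconds)
        raise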
@@ -1,6 +1,5 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import fastapi
|
||||
import prisma.enums
|
||||
@@ -11,7 +10,8 @@ import prisma.types
|
||||
import backend.data.graph
|
||||
import backend.server.v2.store.exceptions
|
||||
import backend.server.v2.store.model
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.graph import GraphModel, get_sub_graphs
|
||||
from backend.data.includes import AGENT_GRAPH_INCLUDE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -44,6 +44,9 @@ async def get_store_agents(
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> backend.server.v2.store.model.StoreAgentsResponse:
|
||||
"""
|
||||
Get PUBLIC store agents from the StoreAgent view
|
||||
"""
|
||||
logger.debug(
|
||||
f"Getting store agents. featured={featured}, creator={creator}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
|
||||
)
|
||||
@@ -129,6 +132,7 @@ async def get_store_agents(
|
||||
async def get_store_agent_details(
|
||||
username: str, agent_name: str
|
||||
) -> backend.server.v2.store.model.StoreAgentDetails:
|
||||
"""Get PUBLIC store agent details from the StoreAgent view"""
|
||||
logger.debug(f"Getting store agent details for {username}/{agent_name}")
|
||||
|
||||
try:
|
||||
@@ -142,6 +146,20 @@ async def get_store_agent_details(
|
||||
f"Agent {username}/{agent_name} not found"
|
||||
)
|
||||
|
||||
# Retrieve StoreListing to get active_version_id and has_approved_version
|
||||
store_listing = await prisma.models.StoreListing.prisma().find_first(
|
||||
where=prisma.types.StoreListingWhereInput(
|
||||
slug=agent_name,
|
||||
owningUserId=username, # Direct equality check instead of 'has'
|
||||
),
|
||||
include={"ActiveVersion": True},
|
||||
)
|
||||
|
||||
active_version_id = store_listing.activeVersionId if store_listing else None
|
||||
has_approved_version = (
|
||||
store_listing.hasApprovedVersion if store_listing else False
|
||||
)
|
||||
|
||||
logger.debug(f"Found agent details for {username}/{agent_name}")
|
||||
return backend.server.v2.store.model.StoreAgentDetails(
|
||||
store_listing_version_id=agent.storeListingVersionId,
|
||||
@@ -158,6 +176,8 @@ async def get_store_agent_details(
|
||||
rating=agent.rating,
|
||||
versions=agent.versions,
|
||||
last_updated=agent.updated_at,
|
||||
active_version_id=active_version_id,
|
||||
has_approved_version=has_approved_version,
|
||||
)
|
||||
except backend.server.v2.store.exceptions.AgentNotFoundError:
|
||||
raise
|
||||
@@ -175,6 +195,7 @@ async def get_store_creators(
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> backend.server.v2.store.model.CreatorsResponse:
|
||||
"""Get PUBLIC store creators from the Creator view"""
|
||||
logger.debug(
|
||||
f"Getting store creators. featured={featured}, search={search_query}, sorted_by={sorted_by}, page={page}"
|
||||
)
|
||||
@@ -322,6 +343,7 @@ async def get_store_creator_details(
|
||||
async def get_store_submissions(
|
||||
user_id: str, page: int = 1, page_size: int = 20
|
||||
) -> backend.server.v2.store.model.StoreSubmissionsResponse:
|
||||
"""Get store submissions for the authenticated user -- not an admin"""
|
||||
logger.debug(f"Getting store submissions for user {user_id}, page={page}")
|
||||
|
||||
try:
|
||||
@@ -343,8 +365,9 @@ async def get_store_submissions(
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
# Convert to response models
|
||||
submission_models = [
|
||||
backend.server.v2.store.model.StoreSubmission(
|
||||
submission_models = []
|
||||
for sub in submissions:
|
||||
submission_model = backend.server.v2.store.model.StoreSubmission(
|
||||
agent_id=sub.agent_id,
|
||||
agent_version=sub.agent_version,
|
||||
name=sub.name,
|
||||
@@ -352,13 +375,18 @@ async def get_store_submissions(
|
||||
slug=sub.slug,
|
||||
description=sub.description,
|
||||
image_urls=sub.image_urls or [],
|
||||
date_submitted=sub.date_submitted or datetime.now(),
|
||||
date_submitted=sub.date_submitted or datetime.now(tz=timezone.utc),
|
||||
status=sub.status,
|
||||
runs=sub.runs or 0,
|
||||
rating=sub.rating or 0.0,
|
||||
store_listing_version_id=sub.store_listing_version_id,
|
||||
reviewer_id=sub.reviewer_id,
|
||||
review_comments=sub.review_comments,
|
||||
# internal_comments omitted for regular users
|
||||
reviewed_at=sub.reviewed_at,
|
||||
changes_summary=sub.changes_summary,
|
||||
)
|
||||
for sub in submissions
|
||||
]
|
||||
submission_models.append(submission_model)
|
||||
|
||||
logger.debug(f"Found {len(submission_models)} submissions")
|
||||
return backend.server.v2.store.model.StoreSubmissionsResponse(
|
||||
@@ -390,7 +418,7 @@ async def delete_store_submission(
|
||||
submission_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Delete a store listing submission.
|
||||
Delete a store listing submission as the submitting user.
|
||||
|
||||
Args:
|
||||
user_id: ID of the authenticated user
|
||||
@@ -437,9 +465,10 @@ async def create_store_submission(
|
||||
description: str = "",
|
||||
sub_heading: str = "",
|
||||
categories: list[str] = [],
|
||||
changes_summary: str = "Initial Submission",
|
||||
) -> backend.server.v2.store.model.StoreSubmission:
|
||||
"""
|
||||
Create a new store listing submission.
|
||||
Create the first (and only) store listing and thus submission as a normal user
|
||||
|
||||
Args:
|
||||
user_id: ID of the authenticated user submitting the listing
|
||||
@@ -450,7 +479,9 @@ async def create_store_submission(
|
||||
video_url: Optional URL to video demo
|
||||
image_urls: List of image URLs for the listing
|
||||
description: Description of the agent
|
||||
sub_heading: Optional sub-heading for the agent
|
||||
categories: List of categories for the agent
|
||||
changes_summary: Summary of changes made in this submission
|
||||
|
||||
Returns:
|
||||
StoreSubmission: The created store submission
|
||||
@@ -480,45 +511,66 @@ async def create_store_submission(
|
||||
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
|
||||
)
|
||||
|
||||
listing = await prisma.models.StoreListing.prisma().find_first(
|
||||
# Check if listing already exists for this agent
|
||||
existing_listing = await prisma.models.StoreListing.prisma().find_first(
|
||||
where=prisma.types.StoreListingWhereInput(
|
||||
agentId=agent_id, owningUserId=user_id
|
||||
)
|
||||
)
|
||||
if listing is not None:
|
||||
logger.warning(f"Listing already exists for agent {agent_id}")
|
||||
raise backend.server.v2.store.exceptions.ListingExistsError(
|
||||
"Listing already exists for this agent"
|
||||
|
||||
if existing_listing is not None:
|
||||
logger.info(
|
||||
f"Listing already exists for agent {agent_id}, creating new version instead"
|
||||
)
|
||||
|
||||
# Create the store listing
|
||||
listing = await prisma.models.StoreListing.prisma().create(
|
||||
data={
|
||||
"agentId": agent_id,
|
||||
"agentVersion": agent_version,
|
||||
"owningUserId": user_id,
|
||||
"createdAt": datetime.now(),
|
||||
"StoreListingVersions": {
|
||||
"create": {
|
||||
"agentId": agent_id,
|
||||
"agentVersion": agent_version,
|
||||
"slug": slug,
|
||||
"name": name,
|
||||
"videoUrl": video_url,
|
||||
"imageUrls": image_urls,
|
||||
"description": description,
|
||||
"categories": categories,
|
||||
"subHeading": sub_heading,
|
||||
}
|
||||
},
|
||||
# Delegate to create_store_version which already handles this case correctly
|
||||
return await create_store_version(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
agent_version=agent_version,
|
||||
store_listing_id=existing_listing.id,
|
||||
name=name,
|
||||
video_url=video_url,
|
||||
image_urls=image_urls,
|
||||
description=description,
|
||||
sub_heading=sub_heading,
|
||||
categories=categories,
|
||||
changes_summary=changes_summary,
|
||||
)
|
||||
|
||||
# If no existing listing, create a new one
|
||||
data = prisma.types.StoreListingCreateInput(
|
||||
slug=slug,
|
||||
agentId=agent_id,
|
||||
agentVersion=agent_version,
|
||||
owningUserId=user_id,
|
||||
createdAt=datetime.now(tz=timezone.utc),
|
||||
Versions={
|
||||
"create": [
|
||||
prisma.types.StoreListingVersionCreateInput(
|
||||
agentId=agent_id,
|
||||
agentVersion=agent_version,
|
||||
name=name,
|
||||
videoUrl=video_url,
|
||||
imageUrls=image_urls,
|
||||
description=description,
|
||||
categories=categories,
|
||||
subHeading=sub_heading,
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
submittedAt=datetime.now(tz=timezone.utc),
|
||||
changesSummary=changes_summary,
|
||||
)
|
||||
]
|
||||
},
|
||||
include={"StoreListingVersions": True},
|
||||
)
|
||||
listing = await prisma.models.StoreListing.prisma().create(
|
||||
data=data,
|
||||
include=prisma.types.StoreListingInclude(Versions=True),
|
||||
)
|
||||
|
||||
store_listing_version_id = (
|
||||
listing.StoreListingVersions[0].id
|
||||
if listing.StoreListingVersions is not None
|
||||
and len(listing.StoreListingVersions) > 0
|
||||
listing.Versions[0].id
|
||||
if listing.Versions is not None and len(listing.Versions) > 0
|
||||
else None
|
||||
)
|
||||
|
||||
@@ -537,6 +589,7 @@ async def create_store_submission(
|
||||
runs=0,
|
||||
rating=0.0,
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
changes_summary=changes_summary,
|
||||
)
|
||||
|
||||
except (
|
||||
@@ -551,13 +604,137 @@ async def create_store_submission(
|
||||
) from e
|
||||
|
||||
|
||||
async def create_store_version(
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
agent_version: int,
|
||||
store_listing_id: str,
|
||||
name: str,
|
||||
video_url: str | None = None,
|
||||
image_urls: list[str] = [],
|
||||
description: str = "",
|
||||
sub_heading: str = "",
|
||||
categories: list[str] = [],
|
||||
changes_summary: str = "Update Submission",
|
||||
) -> backend.server.v2.store.model.StoreSubmission:
|
||||
"""
|
||||
Create a new version for an existing store listing
|
||||
|
||||
Args:
|
||||
user_id: ID of the authenticated user submitting the version
|
||||
agent_id: ID of the agent being submitted
|
||||
agent_version: Version of the agent being submitted
|
||||
store_listing_id: ID of the existing store listing
|
||||
name: Name of the agent
|
||||
video_url: Optional URL to video demo
|
||||
image_urls: List of image URLs for the listing
|
||||
description: Description of the agent
|
||||
categories: List of categories for the agent
|
||||
changes_summary: Summary of changes from the previous version
|
||||
|
||||
Returns:
|
||||
StoreSubmission: The created store submission
|
||||
"""
|
||||
logger.debug(
|
||||
f"Creating new version for store listing {store_listing_id} for user {user_id}, agent {agent_id} v{agent_version}"
|
||||
)
|
||||
|
||||
try:
|
||||
# First verify the listing belongs to this user
|
||||
listing = await prisma.models.StoreListing.prisma().find_first(
|
||||
where=prisma.types.StoreListingWhereInput(
|
||||
id=store_listing_id, owningUserId=user_id
|
||||
),
|
||||
include={"Versions": {"order_by": {"version": "desc"}, "take": 1}},
|
||||
)
|
||||
|
||||
if not listing:
|
||||
raise backend.server.v2.store.exceptions.ListingNotFoundError(
|
||||
f"Store listing not found. User ID: {user_id}, Listing ID: {store_listing_id}"
|
||||
)
|
||||
|
||||
# Verify the agent belongs to this user
|
||||
agent = await prisma.models.AgentGraph.prisma().find_first(
|
||||
where=prisma.types.AgentGraphWhereInput(
|
||||
id=agent_id, version=agent_version, userId=user_id
|
||||
)
|
||||
)
|
||||
|
||||
if not agent:
|
||||
raise backend.server.v2.store.exceptions.AgentNotFoundError(
|
||||
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
|
||||
)
|
||||
|
||||
# Get the latest version number
|
||||
latest_version = listing.Versions[0] if listing.Versions else None
|
||||
|
||||
next_version = (latest_version.version + 1) if latest_version else 1
|
||||
|
||||
# Create a new version for the existing listing
|
||||
new_version = await prisma.models.StoreListingVersion.prisma().create(
|
||||
data=prisma.types.StoreListingVersionCreateInput(
|
||||
version=next_version,
|
||||
agentId=agent_id,
|
||||
agentVersion=agent_version,
|
||||
name=name,
|
||||
videoUrl=video_url,
|
||||
imageUrls=image_urls,
|
||||
description=description,
|
||||
categories=categories,
|
||||
subHeading=sub_heading,
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
submittedAt=datetime.now(),
|
||||
changesSummary=changes_summary,
|
||||
storeListingId=store_listing_id,
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Created new version for listing {store_listing_id} of agent {agent_id}"
|
||||
)
|
||||
# Return submission details
|
||||
return backend.server.v2.store.model.StoreSubmission(
|
||||
agent_id=agent_id,
|
||||
agent_version=agent_version,
|
||||
name=name,
|
||||
slug=listing.slug,
|
||||
sub_heading=sub_heading,
|
||||
description=description,
|
||||
image_urls=image_urls,
|
||||
date_submitted=datetime.now(),
|
||||
status=prisma.enums.SubmissionStatus.PENDING,
|
||||
runs=0,
|
||||
rating=0.0,
|
||||
store_listing_version_id=new_version.id,
|
||||
changes_summary=changes_summary,
|
||||
version=next_version,
|
||||
)
|
||||
except prisma.errors.PrismaError as e:
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to create new store version"
|
||||
) from e
|
||||
|
||||
|
||||
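The version numbering above reduces to: take the highest existing version (fetched with an order_by on version desc and take=1) and add one, or start at 1. A minimal illustration of that rule:

def next_listing_version(existing_versions: list[int]) -> int:
    # Mirrors: (latest_version.version + 1) if latest_version else 1
    return (max(existing_versions) + 1) if existing_versions else 1


assert next_listing_version([]) == 1
assert next_listing_version([1, 2, 3]) == 4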
async def create_store_review(
|
||||
user_id: str,
|
||||
store_listing_version_id: str,
|
||||
score: int,
|
||||
comments: str | None = None,
|
||||
) -> backend.server.v2.store.model.StoreReview:
|
||||
"""Create a review for a store listing as a user to detail their experience"""
|
||||
try:
|
||||
data = prisma.types.StoreListingReviewUpsertInput(
|
||||
update=prisma.types.StoreListingReviewUpdateInput(
|
||||
score=score,
|
||||
comments=comments,
|
||||
),
|
||||
create=prisma.types.StoreListingReviewCreateInput(
|
||||
reviewByUserId=user_id,
|
||||
storeListingVersionId=store_listing_version_id,
|
||||
score=score,
|
||||
comments=comments,
|
||||
),
|
||||
)
|
||||
review = await prisma.models.StoreListingReview.prisma().upsert(
|
||||
where={
|
||||
"storeListingVersionId_reviewByUserId": {
|
||||
@@ -565,18 +742,7 @@ async def create_store_review(
|
||||
"reviewByUserId": user_id,
|
||||
}
|
||||
},
|
||||
data={
|
||||
"create": {
|
||||
"reviewByUserId": user_id,
|
||||
"storeListingVersionId": store_listing_version_id,
|
||||
"score": score,
|
||||
"comments": comments,
|
||||
},
|
||||
"update": {
|
||||
"score": score,
|
||||
"comments": comments,
|
||||
},
|
||||
},
|
||||
data=data,
|
||||
)
|
||||
|
||||
return backend.server.v2.store.model.StoreReview(
|
||||
@@ -598,7 +764,7 @@ async def get_user_profile(
|
||||
|
||||
try:
|
||||
profile = await prisma.models.Profile.prisma().find_first(
|
||||
where={"userId": user_id} # type: ignore
|
||||
where={"userId": user_id}
|
||||
)
|
||||
|
||||
if not profile:
|
||||
@@ -703,48 +869,39 @@ async def get_my_agents(
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> backend.server.v2.store.model.MyAgentsResponse:
|
||||
"""Get the agents for the authenticated user"""
|
||||
logger.debug(f"Getting my agents for user {user_id}, page={page}")
|
||||
|
||||
try:
|
||||
agents_with_max_version = await prisma.models.AgentGraph.prisma().find_many(
|
||||
where=prisma.types.AgentGraphWhereInput(
|
||||
userId=user_id, StoreListing={"none": {"isDeleted": False}}
|
||||
),
|
||||
order=[{"version": "desc"}],
|
||||
distinct=["id"],
|
||||
search_filter: prisma.types.LibraryAgentWhereInput = {
|
||||
"userId": user_id,
|
||||
"Agent": {"is": {"StoreListing": {"none": {"isDeleted": False}}}},
|
||||
"isArchived": False,
|
||||
"isDeleted": False,
|
||||
}
|
||||
|
||||
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
|
||||
where=search_filter,
|
||||
order=[{"agentVersion": "desc"}],
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
include={"Agent": True},
|
||||
)
|
||||
|
||||
# store_listings = await prisma.models.StoreListing.prisma().find_many(
|
||||
# where=prisma.types.StoreListingWhereInput(
|
||||
# isDeleted=False,
|
||||
# ),
|
||||
# )
|
||||
|
||||
total = len(
|
||||
await prisma.models.AgentGraph.prisma().find_many(
|
||||
where=prisma.types.AgentGraphWhereInput(
|
||||
userId=user_id, StoreListing={"none": {"isDeleted": False}}
|
||||
),
|
||||
order=[{"version": "desc"}],
|
||||
distinct=["id"],
|
||||
)
|
||||
)
|
||||
|
||||
total = await prisma.models.LibraryAgent.prisma().count(where=search_filter)
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
agents = agents_with_max_version
|
||||
|
||||
my_agents = [
|
||||
backend.server.v2.store.model.MyAgent(
|
||||
agent_id=agent.id,
|
||||
agent_version=agent.version,
|
||||
agent_name=agent.name or "",
|
||||
last_edited=agent.updatedAt or agent.createdAt,
|
||||
description=agent.description or "",
|
||||
agent_id=graph.id,
|
||||
agent_version=graph.version,
|
||||
agent_name=graph.name or "",
|
||||
last_edited=graph.updatedAt or graph.createdAt,
|
||||
description=graph.description or "",
|
||||
agent_image=library_agent.imageUrl,
|
||||
)
|
||||
for agent in agents
|
||||
for library_agent in library_agents
|
||||
if (graph := library_agent.Agent)
|
||||
]
|
||||
|
||||
return backend.server.v2.store.model.MyAgentsResponse(
|
||||
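The page count used here (and in the other paginated queries in this file) is plain ceiling division. A small worked example:

def total_pages(total: int, page_size: int) -> int:
    # (total + page_size - 1) // page_size rounds up without importing math.
    return (total + page_size - 1) // page_size


assert total_pages(0, 20) == 0    # no items, no pages
assert total_pages(20, 20) == 1   # exactly one full page
assert total_pages(41, 20) == 3   # two full pages plus a partial third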
@@ -764,58 +921,87 @@ async def get_my_agents(
|
||||
|
||||
|
||||
async def get_agent(
|
||||
store_listing_version_id: str, version_id: Optional[int]
|
||||
user_id: str,
|
||||
store_listing_version_id: str,
|
||||
) -> GraphModel:
|
||||
"""Get agent using the version ID and store listing version ID."""
|
||||
try:
|
||||
store_listing_version = (
|
||||
await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id}, include={"Agent": True}
|
||||
)
|
||||
store_listing_version = (
|
||||
await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id}
|
||||
)
|
||||
)
|
||||
|
||||
if not store_listing_version:
|
||||
raise ValueError(f"Store listing version {store_listing_version_id} not found")
|
||||
|
||||
graph = await backend.data.graph.get_graph(
|
||||
user_id=user_id,
|
||||
graph_id=store_listing_version.agentId,
|
||||
version=store_listing_version.agentVersion,
|
||||
for_export=True,
|
||||
)
|
||||
if not graph:
|
||||
raise ValueError(
|
||||
f"Agent {store_listing_version.agentId} v{store_listing_version.agentVersion} not found"
|
||||
)
|
||||
|
||||
if not store_listing_version or not store_listing_version.Agent:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Store listing version {store_listing_version_id} not found",
|
||||
)
|
||||
return graph
|
||||
|
||||
graph_id = store_listing_version.agentId
|
||||
graph_version = store_listing_version.agentVersion
|
||||
graph = await backend.data.graph.get_graph(graph_id, graph_version)
|
||||
|
||||
if not graph:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404,
|
||||
detail=(
|
||||
f"Agent #{graph_id} not found "
|
||||
f"for store listing version #{store_listing_version_id}"
|
||||
),
|
||||
)
|
||||
#####################################################
|
||||
################## ADMIN FUNCTIONS ##################
|
||||
#####################################################
|
||||
|
||||
graph.version = 1
|
||||
graph.is_template = False
|
||||
graph.is_active = True
|
||||
delattr(graph, "user_id")
|
||||
|
||||
return graph
|
||||
async def _get_missing_sub_store_listing(
|
||||
graph: prisma.models.AgentGraph,
|
||||
) -> list[prisma.models.AgentGraph]:
|
||||
"""
|
||||
Agent graph can have sub-graphs, and those sub-graphs also need to be store listed.
|
||||
This method fetches the sub-graphs, and returns the ones not listed in the store.
|
||||
"""
|
||||
sub_graphs = await get_sub_graphs(graph)
|
||||
if not sub_graphs:
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting agent: {e}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to fetch agent"
|
||||
) from e
|
||||
# Fetch all the sub-graphs that are listed, and return the ones missing.
|
||||
store_listed_sub_graphs = {
|
||||
(listing.agentId, listing.agentVersion)
|
||||
for listing in await prisma.models.StoreListingVersion.prisma().find_many(
|
||||
where={
|
||||
"OR": [
|
||||
{"agentId": sub_graph.id, "agentVersion": sub_graph.version}
|
||||
for sub_graph in sub_graphs
|
||||
],
|
||||
"submissionStatus": prisma.enums.SubmissionStatus.APPROVED,
|
||||
"isDeleted": False,
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
return [
|
||||
sub_graph
|
||||
for sub_graph in sub_graphs
|
||||
if (sub_graph.id, sub_graph.version) not in store_listed_sub_graphs
|
||||
]
|
||||
|
||||
|
||||
async def review_store_submission(
|
||||
store_listing_version_id: str, is_approved: bool, comments: str, reviewer_id: str
|
||||
) -> prisma.models.StoreListingSubmission:
|
||||
"""Review a store listing submission."""
|
||||
store_listing_version_id: str,
|
||||
is_approved: bool,
|
||||
external_comments: str,
|
||||
internal_comments: str,
|
||||
reviewer_id: str,
|
||||
) -> backend.server.v2.store.model.StoreSubmission:
|
||||
"""Review a store listing submission as an admin."""
|
||||
try:
|
||||
store_listing_version = (
|
||||
await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id},
|
||||
include={"StoreListing": True},
|
||||
include={
|
||||
"StoreListing": True,
|
||||
"Agent": {"include": AGENT_GRAPH_INCLUDE}, # type: ignore
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -825,10 +1011,34 @@ async def review_store_submission(
|
||||
detail=f"Store listing version {store_listing_version_id} not found",
|
||||
)
|
||||
|
||||
if is_approved:
|
||||
# If approving, update the listing to indicate it has an approved version
|
||||
if is_approved and store_listing_version.Agent:
|
||||
heading = f"Sub-graph of {store_listing_version.name}v{store_listing_version.agentVersion}"
|
||||
|
||||
sub_store_listing_versions = [
|
||||
prisma.types.StoreListingVersionCreateWithoutRelationsInput(
|
||||
agentId=sub_graph.id,
|
||||
agentVersion=sub_graph.version,
|
||||
name=sub_graph.name or heading,
|
||||
submissionStatus=prisma.enums.SubmissionStatus.APPROVED,
|
||||
subHeading=heading,
|
||||
description=f"{heading}: {sub_graph.description}",
|
||||
changesSummary=f"This listing is added as a {heading} / #{store_listing_version.agentId}.",
|
||||
isAvailable=False, # Hide sub-graphs from the store by default.
|
||||
submittedAt=datetime.now(tz=timezone.utc),
|
||||
)
|
||||
for sub_graph in await _get_missing_sub_store_listing(
|
||||
store_listing_version.Agent
|
||||
)
|
||||
]
|
||||
|
||||
await prisma.models.StoreListing.prisma().update(
|
||||
where={"id": store_listing_version.StoreListing.id},
|
||||
data={"isApproved": True},
|
||||
data={
|
||||
"hasApprovedVersion": True,
|
||||
"ActiveVersion": {"connect": {"id": store_listing_version_id}},
|
||||
"Versions": {"create": sub_store_listing_versions},
|
||||
},
|
||||
)
|
||||
|
||||
submission_status = (
|
||||
@@ -837,36 +1047,230 @@ async def review_store_submission(
|
||||
else prisma.enums.SubmissionStatus.REJECTED
|
||||
)
|
||||
|
||||
update_data: prisma.types.StoreListingSubmissionUpdateInput = {
|
||||
"Status": submission_status,
|
||||
"reviewComments": comments,
|
||||
# Update the version with review information
|
||||
update_data: prisma.types.StoreListingVersionUpdateInput = {
|
||||
"submissionStatus": submission_status,
|
||||
"reviewComments": external_comments,
|
||||
"internalComments": internal_comments,
|
||||
"Reviewer": {"connect": {"id": reviewer_id}},
|
||||
"StoreListing": {"connect": {"id": store_listing_version.StoreListing.id}},
|
||||
"reviewedAt": datetime.now(tz=timezone.utc),
|
||||
}
|
||||
|
||||
create_data: prisma.types.StoreListingSubmissionCreateInput = {
|
||||
**update_data,
|
||||
"StoreListingVersion": {"connect": {"id": store_listing_version_id}},
|
||||
}
|
||||
|
||||
submission = await prisma.models.StoreListingSubmission.prisma().upsert(
|
||||
where={"storeListingVersionId": store_listing_version_id},
|
||||
data={
|
||||
"create": create_data,
|
||||
"update": update_data,
|
||||
},
|
||||
# Update the version
|
||||
submission = await prisma.models.StoreListingVersion.prisma().update(
|
||||
where={"id": store_listing_version_id},
|
||||
data=update_data,
|
||||
include={"StoreListing": True},
|
||||
)
|
||||
|
||||
if not submission:
|
||||
raise fastapi.HTTPException( # FIXME: don't return HTTP exceptions here
|
||||
status_code=404,
|
||||
detail=f"Store listing submission {store_listing_version_id} not found",
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
f"Failed to update store listing version {store_listing_version_id}"
|
||||
)
|
||||
|
||||
return submission
|
||||
# Convert to Pydantic model for consistency
|
||||
return backend.server.v2.store.model.StoreSubmission(
|
||||
agent_id=submission.agentId,
|
||||
agent_version=submission.agentVersion,
|
||||
name=submission.name,
|
||||
sub_heading=submission.subHeading,
|
||||
slug=(
|
||||
submission.StoreListing.slug
|
||||
if hasattr(submission, "storeListing") and submission.StoreListing
|
||||
else ""
|
||||
),
|
||||
description=submission.description,
|
||||
image_urls=submission.imageUrls or [],
|
||||
date_submitted=submission.submittedAt or submission.createdAt,
|
||||
status=submission.submissionStatus,
|
||||
runs=0, # Default values since we don't have this data here
|
||||
rating=0.0,
|
||||
store_listing_version_id=submission.id,
|
||||
reviewer_id=submission.reviewerId,
|
||||
review_comments=submission.reviewComments,
|
||||
internal_comments=submission.internalComments,
|
||||
reviewed_at=submission.reviewedAt,
|
||||
changes_summary=submission.changesSummary,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Could not create store submission review: {e}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to create store submission review"
|
||||
) from e
|
||||
|
||||
|
||||
async def get_admin_listings_with_versions(
|
||||
status: prisma.enums.SubmissionStatus | None = None,
|
||||
search_query: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> backend.server.v2.store.model.StoreListingsWithVersionsResponse:
|
||||
"""
|
||||
Get store listings for admins with all their versions.
|
||||
|
||||
Args:
|
||||
status: Filter by submission status (PENDING, APPROVED, REJECTED)
|
||||
search_query: Search by name, description, or user email
|
||||
page: Page number for pagination
|
||||
page_size: Number of items per page
|
||||
|
||||
Returns:
|
||||
StoreListingsWithVersionsResponse with listings and their versions
|
||||
"""
|
||||
logger.debug(
|
||||
f"Getting admin store listings with status={status}, search={search_query}, page={page}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Build the where clause for StoreListing
|
||||
where_dict: prisma.types.StoreListingWhereInput = {
|
||||
"isDeleted": False,
|
||||
}
|
||||
if status:
|
||||
where_dict["Versions"] = {"some": {"submissionStatus": status}}
|
||||
|
||||
sanitized_query = sanitize_query(search_query)
|
||||
if sanitized_query:
|
||||
# Find users with matching email
|
||||
matching_users = await prisma.models.User.prisma().find_many(
|
||||
where={"email": {"contains": sanitized_query, "mode": "insensitive"}},
|
||||
)
|
||||
|
||||
user_ids = [user.id for user in matching_users]
|
||||
|
||||
# Set up OR conditions
|
||||
where_dict["OR"] = [
|
||||
{"slug": {"contains": sanitized_query, "mode": "insensitive"}},
|
||||
{
|
||||
"Versions": {
|
||||
"some": {
|
||||
"name": {"contains": sanitized_query, "mode": "insensitive"}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"Versions": {
|
||||
"some": {
|
||||
"description": {
|
||||
"contains": sanitized_query,
|
||||
"mode": "insensitive",
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"Versions": {
|
||||
"some": {
|
||||
"subHeading": {
|
||||
"contains": sanitized_query,
|
||||
"mode": "insensitive",
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
# Add user_id condition if any users matched
|
||||
if user_ids:
|
||||
where_dict["OR"].append({"owningUserId": {"in": user_ids}})
|
||||
|
||||
# Calculate pagination
|
||||
skip = (page - 1) * page_size
|
||||
|
||||
# Create proper Prisma types for the query
|
||||
where = prisma.types.StoreListingWhereInput(**where_dict)
|
||||
include = prisma.types.StoreListingInclude(
|
||||
Versions=prisma.types.FindManyStoreListingVersionArgsFromStoreListing(
|
||||
order_by=prisma.types._StoreListingVersion_version_OrderByInput(
|
||||
version="desc"
|
||||
)
|
||||
),
|
||||
OwningUser=True,
|
||||
)
|
||||
|
||||
# Query listings with their versions
|
||||
listings = await prisma.models.StoreListing.prisma().find_many(
|
||||
where=where,
|
||||
skip=skip,
|
||||
take=page_size,
|
||||
include=include,
|
||||
order=[{"createdAt": "desc"}],
|
||||
)
|
||||
|
||||
# Get total count for pagination
|
||||
total = await prisma.models.StoreListing.prisma().count(where=where)
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
# Convert to response models
|
||||
listings_with_versions = []
|
||||
for listing in listings:
|
||||
versions: list[backend.server.v2.store.model.StoreSubmission] = []
|
||||
# If we have versions, turn them into StoreSubmission models
|
||||
for version in listing.Versions or []:
|
||||
version_model = backend.server.v2.store.model.StoreSubmission(
|
||||
agent_id=version.agentId,
|
||||
agent_version=version.agentVersion,
|
||||
name=version.name,
|
||||
sub_heading=version.subHeading,
|
||||
slug=listing.slug,
|
||||
description=version.description,
|
||||
image_urls=version.imageUrls or [],
|
||||
date_submitted=version.submittedAt or version.createdAt,
|
||||
status=version.submissionStatus,
|
||||
runs=0, # Default values since we don't have this data here
|
||||
rating=0.0, # Default values since we don't have this data here
|
||||
store_listing_version_id=version.id,
|
||||
reviewer_id=version.reviewerId,
|
||||
review_comments=version.reviewComments,
|
||||
internal_comments=version.internalComments,
|
||||
reviewed_at=version.reviewedAt,
|
||||
changes_summary=version.changesSummary,
|
||||
version=version.version,
|
||||
)
|
||||
versions.append(version_model)
|
||||
|
||||
# Get the latest version (first in the sorted list)
|
||||
latest_version = versions[0] if versions else None
|
||||
|
||||
creator_email = listing.OwningUser.email if listing.OwningUser else None
|
||||
|
||||
listing_with_versions = (
|
||||
backend.server.v2.store.model.StoreListingWithVersions(
|
||||
listing_id=listing.id,
|
||||
slug=listing.slug,
|
||||
agent_id=listing.agentId,
|
||||
agent_version=listing.agentVersion,
|
||||
active_version_id=listing.activeVersionId,
|
||||
has_approved_version=listing.hasApprovedVersion,
|
||||
creator_email=creator_email,
|
||||
latest_version=latest_version,
|
||||
versions=versions,
|
||||
)
|
||||
)
|
||||
|
||||
listings_with_versions.append(listing_with_versions)
|
||||
|
||||
logger.debug(f"Found {len(listings_with_versions)} listings for admin")
|
||||
return backend.server.v2.store.model.StoreListingsWithVersionsResponse(
|
||||
listings=listings_with_versions,
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
current_page=page,
|
||||
total_items=total,
|
||||
total_pages=total_pages,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching admin store listings: {e}")
|
||||
# Return empty response rather than exposing internal errors
|
||||
return backend.server.v2.store.model.StoreListingsWithVersionsResponse(
|
||||
listings=[],
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
current_page=page,
|
||||
total_items=0,
|
||||
total_pages=0,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
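A sketch of how an admin-side caller might consume the helper above; it assumes a registered and connected Prisma client and the repository's modules on the import path:

import asyncio

import prisma.enums

import backend.server.v2.store.db as store_db


async def print_pending_listings() -> None:
    response = await store_db.get_admin_listings_with_versions(
        status=prisma.enums.SubmissionStatus.PENDING, page=1, page_size=10
    )
    for listing in response.listings:
        latest = listing.latest_version
        print(listing.slug, listing.creator_email, latest.status if latest else None)


# asyncio.run(print_pending_listings())  # requires a live database connection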
@@ -1,5 +1,6 @@
|
||||
from datetime import datetime
|
||||
|
||||
import prisma.enums
|
||||
import prisma.errors
|
||||
import prisma.models
|
||||
import pytest
|
||||
@@ -83,21 +84,35 @@ async def test_get_store_agent_details(mocker):
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
|
||||
# Mock prisma call
|
||||
# Create a mock StoreListing result
|
||||
mock_store_listing = mocker.MagicMock()
|
||||
mock_store_listing.activeVersionId = "active-version-id"
|
||||
mock_store_listing.hasApprovedVersion = True
|
||||
|
||||
# Mock StoreAgent prisma call
|
||||
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
|
||||
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
|
||||
|
||||
# Mock StoreListing prisma call - this is what was missing
|
||||
mock_store_listing_db = mocker.patch("prisma.models.StoreListing.prisma")
|
||||
mock_store_listing_db.return_value.find_first = mocker.AsyncMock(
|
||||
return_value=mock_store_listing
|
||||
)
|
||||
|
||||
# Call function
|
||||
result = await db.get_store_agent_details("creator", "test-agent")
|
||||
|
||||
# Verify results
|
||||
assert result.slug == "test-agent"
|
||||
assert result.agent_name == "Test Agent"
|
||||
assert result.active_version_id == "active-version-id"
|
||||
assert result.has_approved_version is True
|
||||
|
||||
# Verify mock called correctly
|
||||
# Verify mocks called correctly
|
||||
mock_store_agent.return_value.find_first.assert_called_once_with(
|
||||
where={"creator_username": "creator", "slug": "test-agent"}
|
||||
)
|
||||
mock_store_listing_db.return_value.find_first.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -146,7 +161,6 @@ async def test_create_store_submission(mocker):
|
||||
userId="user-id",
|
||||
createdAt=datetime.now(),
|
||||
isActive=True,
|
||||
isTemplate=False,
|
||||
)
|
||||
|
||||
mock_listing = prisma.models.StoreListing(
|
||||
@@ -154,16 +168,16 @@ async def test_create_store_submission(mocker):
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
isDeleted=False,
|
||||
isApproved=False,
|
||||
hasApprovedVersion=False,
|
||||
slug="test-agent",
|
||||
agentId="agent-id",
|
||||
agentVersion=1,
|
||||
owningUserId="user-id",
|
||||
StoreListingVersions=[
|
||||
Versions=[
|
||||
prisma.models.StoreListingVersion(
|
||||
id="version-id",
|
||||
agentId="agent-id",
|
||||
agentVersion=1,
|
||||
slug="test-agent",
|
||||
name="Test Agent",
|
||||
description="Test description",
|
||||
createdAt=datetime.now(),
|
||||
@@ -174,8 +188,9 @@ async def test_create_store_submission(mocker):
|
||||
isFeatured=False,
|
||||
isDeleted=False,
|
||||
version=1,
|
||||
storeListingId="listing-id",
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
isAvailable=True,
|
||||
isApproved=False,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@@ -70,6 +70,12 @@ class ProfileNotFoundError(StoreError):
    pass


class ListingNotFoundError(StoreError):
    """Raised when a store listing is not found"""

    pass


class SubmissionNotFoundError(StoreError):
    """Raised when a submission is not found"""

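One way a caller could surface the new ListingNotFoundError as a 404 rather than a generic 500; the route path and arguments are illustrative only:

import fastapi

import backend.server.v2.store.db as store_db
import backend.server.v2.store.exceptions as store_exceptions

router = fastapi.APIRouter()


@router.post("/listings/{store_listing_id}/versions")
async def resubmit_listing(store_listing_id: str, user_id: str):
    try:
        return await store_db.create_store_version(
            user_id=user_id,
            agent_id="agent-id",  # placeholder arguments for the sketch
            agent_version=1,
            store_listing_id=store_listing_id,
            name="Example Agent",
        )
    except store_exceptions.ListingNotFoundError as e:
        raise fastapi.HTTPException(status_code=404, detail=str(e))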
@@ -1,3 +1,4 @@
import asyncio
import io
import logging
from enum import Enum
@@ -34,6 +35,13 @@ class ImageStyle(str, Enum):
    DIGITAL_ART = "digital art"


async def generate_agent_image(agent: Graph | AgentGraph) -> io.BytesIO:
    if settings.config.use_agent_image_generation_v2:
        return await asyncio.to_thread(generate_agent_image_v2, graph=agent)
    else:
        return await generate_agent_image_v1(agent=agent)


def generate_agent_image_v2(graph: Graph | AgentGraph) -> io.BytesIO:
    """
    Generate an image for an agent using Ideogram model.
@@ -91,7 +99,7 @@ def generate_agent_image_v2(graph: Graph | AgentGraph) -> io.BytesIO:
    return io.BytesIO(requests.get(url).content)


async def generate_agent_image(agent: Graph | AgentGraph) -> io.BytesIO:
async def generate_agent_image_v1(agent: Graph | AgentGraph) -> io.BytesIO:
    """
    Generate an image for an agent using Flux model via Replicate API.

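The new wrapper above awaits the async v1 generator directly but pushes the blocking v2 generator onto a worker thread. A standalone illustration of that dispatch, with stand-in generators instead of the real Replicate/Ideogram calls:

import asyncio
import io


def blocking_generate(name: str) -> io.BytesIO:
    # Stand-in for a synchronous image-generation call.
    return io.BytesIO(f"v2 image for {name}".encode())


async def async_generate(name: str) -> io.BytesIO:
    # Stand-in for an async image-generation call.
    return io.BytesIO(f"v1 image for {name}".encode())


async def generate(name: str, use_v2: bool) -> io.BytesIO:
    if use_v2:
        # Run the blocking path on a thread so the event loop stays responsive.
        return await asyncio.to_thread(blocking_generate, name)
    return await async_generate(name)


# asyncio.run(generate("my-agent", use_v2=True))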
@@ -24,6 +24,7 @@ class MyAgent(pydantic.BaseModel):
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
agent_name: str
|
||||
agent_image: str | None = None
|
||||
description: str
|
||||
last_edited: datetime.datetime
|
||||
|
||||
@@ -66,6 +67,9 @@ class StoreAgentDetails(pydantic.BaseModel):
|
||||
versions: list[str]
|
||||
last_updated: datetime.datetime
|
||||
|
||||
active_version_id: str | None = None
|
||||
has_approved_version: bool = False
|
||||
|
||||
|
||||
class Creator(pydantic.BaseModel):
|
||||
name: str
|
||||
@@ -116,6 +120,19 @@ class StoreSubmission(pydantic.BaseModel):
|
||||
runs: int
|
||||
rating: float
|
||||
store_listing_version_id: str | None = None
|
||||
version: int | None = None # Actual version number from the database
|
||||
|
||||
reviewer_id: str | None = None
|
||||
review_comments: str | None = None # External comments visible to creator
|
||||
internal_comments: str | None = None # Private notes for admin use only
|
||||
reviewed_at: datetime.datetime | None = None
|
||||
changes_summary: str | None = None
|
||||
|
||||
reviewer_id: str | None = None
|
||||
review_comments: str | None = None # External comments visible to creator
|
||||
internal_comments: str | None = None # Private notes for admin use only
|
||||
reviewed_at: datetime.datetime | None = None
|
||||
changes_summary: str | None = None
|
||||
|
||||
|
||||
class StoreSubmissionsResponse(pydantic.BaseModel):
|
||||
@@ -123,6 +140,27 @@ class StoreSubmissionsResponse(pydantic.BaseModel):
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class StoreListingWithVersions(pydantic.BaseModel):
|
||||
"""A store listing with its version history"""
|
||||
|
||||
listing_id: str
|
||||
slug: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
active_version_id: str | None = None
|
||||
has_approved_version: bool = False
|
||||
creator_email: str | None = None
|
||||
latest_version: StoreSubmission | None = None
|
||||
versions: list[StoreSubmission] = []
|
||||
|
||||
|
||||
class StoreListingsWithVersionsResponse(pydantic.BaseModel):
|
||||
"""Response model for listings with version history"""
|
||||
|
||||
listings: list[StoreListingWithVersions]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class StoreSubmissionRequest(pydantic.BaseModel):
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
@@ -133,6 +171,7 @@ class StoreSubmissionRequest(pydantic.BaseModel):
|
||||
image_urls: list[str] = []
|
||||
description: str = ""
|
||||
categories: list[str] = []
|
||||
changes_summary: str | None = None
|
||||
|
||||
|
||||
class ProfileDetails(pydantic.BaseModel):
|
||||
@@ -157,4 +196,5 @@ class StoreReviewCreate(pydantic.BaseModel):
|
||||
class ReviewSubmissionRequest(pydantic.BaseModel):
|
||||
store_listing_version_id: str
|
||||
is_approved: bool
|
||||
comments: str
|
||||
comments: str # External comments visible to creator
|
||||
internal_comments: str | None = None # Private admin notes
|
||||
|
||||
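Example payload for the admin review endpoint, mirroring the fields added to ReviewSubmissionRequest above; the IDs and comments are placeholders:

import backend.server.v2.store.model as store_model

request = store_model.ReviewSubmissionRequest(
    store_listing_version_id="version-123",
    is_approved=True,
    comments="Approved: clear description and working demo",  # visible to the creator
    internal_comments="Graph reviewed on staging",  # admin-only notes
)
print(request)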
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
import typing
|
||||
@@ -8,7 +7,6 @@ import autogpt_libs.auth.depends
|
||||
import autogpt_libs.auth.middleware
|
||||
import fastapi
|
||||
import fastapi.responses
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
import backend.data.block
|
||||
import backend.data.graph
|
||||
@@ -16,6 +14,7 @@ import backend.server.v2.store.db
|
||||
import backend.server.v2.store.image_gen
|
||||
import backend.server.v2.store.media
|
||||
import backend.server.v2.store.model
|
||||
import backend.util.json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -35,7 +34,7 @@ router = fastapi.APIRouter()
|
||||
async def get_profile(
|
||||
user_id: typing.Annotated[
|
||||
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||
]
|
||||
],
|
||||
):
|
||||
"""
|
||||
Get the profile details for the authenticated user.
|
||||
@@ -339,7 +338,7 @@ async def get_creator(
|
||||
async def get_my_agents(
|
||||
user_id: typing.Annotated[
|
||||
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||
]
|
||||
],
|
||||
):
|
||||
try:
|
||||
agents = await backend.server.v2.store.db.get_my_agents(user_id)
|
||||
@@ -467,7 +466,7 @@ async def create_submission(
|
||||
HTTPException: If there is an error creating the submission
|
||||
"""
|
||||
try:
|
||||
submission = await backend.server.v2.store.db.create_store_submission(
|
||||
return await backend.server.v2.store.db.create_store_submission(
|
||||
user_id=user_id,
|
||||
agent_id=submission_request.agent_id,
|
||||
agent_version=submission_request.agent_version,
|
||||
@@ -478,8 +477,8 @@ async def create_submission(
|
||||
description=submission_request.description,
|
||||
sub_heading=submission_request.sub_heading,
|
||||
categories=submission_request.categories,
|
||||
changes_summary=submission_request.changes_summary or "Initial Submission",
|
||||
)
|
||||
return submission
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst creating store submission")
|
||||
return fastapi.responses.JSONResponse(
|
||||
@@ -591,19 +590,18 @@ async def generate_image(
|
||||
tags=["store", "public"],
|
||||
)
|
||||
async def download_agent_file(
|
||||
user_id: typing.Annotated[
|
||||
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||
],
|
||||
store_listing_version_id: str = fastapi.Path(
|
||||
..., description="The ID of the agent to download"
|
||||
),
|
||||
version: typing.Optional[int] = fastapi.Query(
|
||||
None, description="Specific version of the agent"
|
||||
),
|
||||
) -> fastapi.responses.FileResponse:
|
||||
"""
|
||||
Download the agent file by streaming its content.
|
||||
|
||||
Args:
|
||||
agent_id (str): The ID of the agent to download.
|
||||
version (Optional[int]): Specific version of the agent to download.
|
||||
store_listing_version_id (str): The ID of the agent to download
|
||||
|
||||
Returns:
|
||||
StreamingResponse: A streaming response containing the agent's graph data.
|
||||
@@ -613,65 +611,18 @@ async def download_agent_file(
|
||||
"""
|
||||
|
||||
graph_data = await backend.server.v2.store.db.get_agent(
|
||||
store_listing_version_id=store_listing_version_id, version_id=version
|
||||
user_id=user_id,
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
|
||||
graph_data.clean_graph()
|
||||
graph_date_dict = jsonable_encoder(graph_data)
|
||||
|
||||
def remove_credentials(obj):
|
||||
if obj and isinstance(obj, dict):
|
||||
if "credentials" in obj:
|
||||
del obj["credentials"]
|
||||
if "creds" in obj:
|
||||
del obj["creds"]
|
||||
|
||||
for value in obj.values():
|
||||
remove_credentials(value)
|
||||
elif isinstance(obj, list):
|
||||
for item in obj:
|
||||
remove_credentials(item)
|
||||
return obj
|
||||
|
||||
graph_date_dict = remove_credentials(graph_date_dict)
|
||||
|
||||
file_name = f"agent_{store_listing_version_id}_v{version or 'latest'}.json"
|
||||
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
|
||||
|
||||
# Sending graph as a stream (similar to marketplace v1)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False
|
||||
) as tmp_file:
|
||||
tmp_file.write(json.dumps(graph_date_dict))
|
||||
tmp_file.write(backend.util.json.dumps(graph_data))
|
||||
tmp_file.flush()
|
||||
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/submissions/review/{store_listing_version_id}",
|
||||
tags=["store", "private"],
|
||||
)
|
||||
async def review_submission(
|
||||
request: backend.server.v2.store.model.ReviewSubmissionRequest,
|
||||
user: typing.Annotated[
|
||||
autogpt_libs.auth.models.User,
|
||||
fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user),
|
||||
],
|
||||
):
|
||||
# Proceed with the review submission logic
|
||||
try:
|
||||
submission = await backend.server.v2.store.db.review_store_submission(
|
||||
store_listing_version_id=request.store_listing_version_id,
|
||||
is_approved=request.is_approved,
|
||||
comments=request.comments,
|
||||
reviewer_id=user.user_id,
|
||||
)
|
||||
return submission
|
||||
except Exception as e:
|
||||
logger.error(f"Could not create store submission review: {e}")
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="An error occurred while creating the store submission review",
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ from contextlib import asynccontextmanager
|
||||
|
||||
import uvicorn
|
||||
from autogpt_libs.auth import parse_jwt_token
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
@@ -12,7 +13,7 @@ from backend.data.execution import AsyncRedisExecutionEventBus
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.server.conn_manager import ConnectionManager
|
||||
from backend.server.model import ExecutionSubscription, Methods, WsMessage
|
||||
from backend.util.service import AppProcess
|
||||
from backend.util.service import AppProcess, get_service_client
|
||||
from backend.util.settings import AppEnvironment, Config, Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -39,6 +40,13 @@ def get_connection_manager():
|
||||
return _connection_manager
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_db_client():
|
||||
from backend.executor import DatabaseManager
|
||||
|
||||
return get_service_client(DatabaseManager)
|
||||
|
||||
|
||||
async def event_broadcaster(manager: ConnectionManager):
|
||||
try:
|
||||
redis.connect()
|
||||
@@ -74,7 +82,10 @@ async def authenticate_websocket(websocket: WebSocket) -> str:
|
||||
|
||||
|
||||
async def handle_subscribe(
|
||||
websocket: WebSocket, manager: ConnectionManager, message: WsMessage
|
||||
connection_manager: ConnectionManager,
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
message: WsMessage,
|
||||
):
|
||||
if not message.data:
|
||||
await websocket.send_text(
|
||||
@@ -85,20 +96,47 @@ async def handle_subscribe(
|
||||
).model_dump_json()
|
||||
)
|
||||
else:
|
||||
ex_sub = ExecutionSubscription.model_validate(message.data)
|
||||
await manager.subscribe(ex_sub.graph_id, ex_sub.graph_version, websocket)
|
||||
logger.debug(f"New execution subscription for graph {ex_sub.graph_id}")
|
||||
sub_req = ExecutionSubscription.model_validate(message.data)
|
||||
|
||||
# Verify that user has read access to graph
|
||||
# if not get_db_client().get_graph(
|
||||
# graph_id=sub_req.graph_id,
|
||||
# version=sub_req.graph_version,
|
||||
# user_id=user_id,
|
||||
# ):
|
||||
# await websocket.send_text(
|
||||
# WsMessage(
|
||||
# method=Methods.ERROR,
|
||||
# success=False,
|
||||
# error="Access denied",
|
||||
# ).model_dump_json()
|
||||
# )
|
||||
# return
|
||||
|
||||
await connection_manager.subscribe(
|
||||
user_id=user_id,
|
||||
graph_id=sub_req.graph_id,
|
||||
graph_version=sub_req.graph_version,
|
||||
websocket=websocket,
|
||||
)
|
||||
logger.debug(
|
||||
f"New execution subscription for user #{user_id} "
|
||||
f"graph #{sub_req.graph_id}v{sub_req.graph_version}"
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.SUBSCRIBE,
|
||||
success=True,
|
||||
channel=f"{ex_sub.graph_id}_{ex_sub.graph_version}",
|
||||
channel=f"{user_id}_{sub_req.graph_id}_{sub_req.graph_version}",
|
||||
).model_dump_json()
|
||||
)
|
||||
|
||||
|
||||
async def handle_unsubscribe(
|
||||
websocket: WebSocket, manager: ConnectionManager, message: WsMessage
|
||||
connection_manager: ConnectionManager,
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
message: WsMessage,
|
||||
):
|
||||
if not message.data:
|
||||
await websocket.send_text(
|
||||
@@ -109,14 +147,22 @@ async def handle_unsubscribe(
|
||||
).model_dump_json()
|
||||
)
|
||||
else:
|
||||
ex_sub = ExecutionSubscription.model_validate(message.data)
|
||||
await manager.unsubscribe(ex_sub.graph_id, ex_sub.graph_version, websocket)
|
||||
logger.debug(f"Removed execution subscription for graph {ex_sub.graph_id}")
|
||||
unsub_req = ExecutionSubscription.model_validate(message.data)
|
||||
await connection_manager.unsubscribe(
|
||||
user_id=user_id,
|
||||
graph_id=unsub_req.graph_id,
|
||||
graph_version=unsub_req.graph_version,
|
||||
websocket=websocket,
|
||||
)
|
||||
logger.debug(
|
||||
f"Removed execution subscription for user #{user_id} "
|
||||
f"graph #{unsub_req.graph_id}v{unsub_req.graph_version}"
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.UNSUBSCRIBE,
|
||||
success=True,
|
||||
channel=f"{ex_sub.graph_id}_{ex_sub.graph_version}",
|
||||
channel=f"{unsub_req.graph_id}_{unsub_req.graph_version}",
|
||||
).model_dump_json()
|
||||
)
|
||||
|
||||
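The reworked handlers above scope subscriptions to the requesting user: the channel key grows from "{graph_id}_{graph_version}" to "{user_id}_{graph_id}_{graph_version}", matching the channel echoed back in the SUBSCRIBE/UNSUBSCRIBE replies and the keys asserted in the conn_manager tests further down. A minimal sketch of that keying scheme (illustrative only; the real ConnectionManager internals are not part of this diff):

# Illustrative sketch of the user-scoped subscription keying used above;
# the actual ConnectionManager implementation is not shown in this diff.
from collections import defaultdict

def channel_key(user_id: str, graph_id: str, graph_version: int) -> str:
    return f"{user_id}_{graph_id}_{graph_version}"

class SubscriptionRegistry:
    def __init__(self) -> None:
        self.subscriptions: dict[str, set[object]] = defaultdict(set)

    def subscribe(self, *, user_id: str, graph_id: str, graph_version: int, websocket) -> str:
        key = channel_key(user_id, graph_id, graph_version)
        self.subscriptions[key].add(websocket)
        return key

    def unsubscribe(self, *, user_id: str, graph_id: str, graph_version: int, websocket) -> None:
        key = channel_key(user_id, graph_id, graph_version)
        self.subscriptions[key].discard(websocket)
        if not self.subscriptions[key]:
            del self.subscriptions[key]

registry = SubscriptionRegistry()
registry.subscribe(user_id="user-1", graph_id="test_graph", graph_version=1, websocket=object())
assert "user-1_test_graph_1" in registry.subscriptions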
@@ -145,13 +191,32 @@ async def websocket_router(
|
||||
)
|
||||
continue
|
||||
|
||||
if message.method == Methods.SUBSCRIBE:
|
||||
await handle_subscribe(websocket, manager, message)
|
||||
try:
|
||||
if message.method == Methods.SUBSCRIBE:
|
||||
await handle_subscribe(
|
||||
connection_manager=manager,
|
||||
websocket=websocket,
|
||||
user_id=user_id,
|
||||
message=message,
|
||||
)
|
||||
continue
|
||||
|
||||
elif message.method == Methods.UNSUBSCRIBE:
|
||||
await handle_unsubscribe(websocket, manager, message)
|
||||
elif message.method == Methods.UNSUBSCRIBE:
|
||||
await handle_unsubscribe(
|
||||
connection_manager=manager,
|
||||
websocket=websocket,
|
||||
user_id=user_id,
|
||||
message=message,
|
||||
)
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error while handling '{message.method}' message "
|
||||
f"for user #{user_id}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
elif message.method == Methods.ERROR:
|
||||
if message.method == Methods.ERROR:
|
||||
logger.error(f"WebSocket Error message received: {message.data}")
|
||||
|
||||
else:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from prisma.models import User
|
||||
|
||||
from backend.blocks.basic import AgentInputBlock, PrintToConsoleBlock
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.io import AgentInputBlock
|
||||
from backend.blocks.text import FillTextTemplateBlock
|
||||
from backend.data import graph
|
||||
from backend.data.graph import create_graph
|
||||
@@ -29,7 +30,7 @@ def create_test_graph() -> graph.Graph:
|
||||
"""
|
||||
InputBlock
|
||||
\
|
||||
---- FillTextTemplateBlock ---- PrintToConsoleBlock
|
||||
---- FillTextTemplateBlock ---- StoreValueBlock
|
||||
/
|
||||
InputBlock
|
||||
"""
|
||||
@@ -52,7 +53,7 @@ def create_test_graph() -> graph.Graph:
|
||||
"values_#_c": "!!!",
|
||||
},
|
||||
),
|
||||
graph.Node(block_id=PrintToConsoleBlock().id),
|
||||
graph.Node(block_id=StoreValueBlock().id),
|
||||
]
|
||||
links = [
|
||||
graph.Link(
|
||||
@@ -71,7 +72,7 @@ def create_test_graph() -> graph.Graph:
|
||||
source_id=nodes[2].id,
|
||||
sink_id=nodes[3].id,
|
||||
source_name="output",
|
||||
sink_name="text",
|
||||
sink_name="input",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -93,11 +94,7 @@ async def sample_agent():
|
||||
user_id=test_user.id,
|
||||
node_input=input_data,
|
||||
)
|
||||
print(response)
|
||||
result = await wait_execution(
|
||||
test_user.id, test_graph.id, response.graph_exec_id, 10
|
||||
)
|
||||
print(result)
|
||||
await wait_execution(test_user.id, test_graph.id, response.graph_exec_id, 10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -29,15 +29,25 @@ def clean_exec_files(graph_exec_id: str, file: str = "") -> None:
|
||||
shutil.rmtree(exec_path)
|
||||
|
||||
|
||||
"""
|
||||
MediaFile is a string that represents a file. It can be one of the following:
|
||||
- Data URI: base64 encoded media file. See https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data/
|
||||
- URL: Media file hosted on the internet, it starts with http:// or https://.
|
||||
- Local path (anything else): A temporary file path living within graph execution time.
|
||||
|
||||
Note: Replace this type alias into a proper class, when more information is needed.
|
||||
"""
|
||||
MediaFile = str
|
||||
class MediaFile(str):
|
||||
"""
|
||||
MediaFile is a string that represents a file. It can be one of the following:
|
||||
- Data URI: base64 encoded media file. See https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data/
|
||||
- URL: Media file hosted on the internet; it starts with http:// or https://.
|
||||
- Local path (anything else): A temporary file path that lives only for the duration of the graph execution.
|
||||
|
||||
Note: Replace this type alias with a proper class when more information is needed.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type, handler):
|
||||
return handler(str)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(cls, core_schema, handler):
|
||||
json_schema = handler(core_schema)
|
||||
json_schema["format"] = "file"
|
||||
return json_schema
|
||||
|
||||
|
||||
def store_media_file(
|
||||
|
||||
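Per the MediaFile docstring above, a value is just a string in one of three shapes: a data URI, an http(s) URL, or anything else treated as a local path scoped to the graph execution. A small, hedged sketch of telling those shapes apart (illustrative only; the real handling lives in store_media_file, which is not fully shown here):

# Illustrative classification of MediaFile values per the docstring above;
# not the actual store_media_file logic.
def classify_media_file(value: str) -> str:
    if value.startswith("data:"):
        return "data_uri"    # base64-encoded media embedded in the string
    if value.startswith(("http://", "https://")):
        return "url"         # media hosted on the internet
    return "local_path"      # temporary path valid only during the graph execution

assert classify_media_file("data:image/png;base64,iVBORw0KGgo=") == "data_uri"
assert classify_media_file("https://example.com/cat.png") == "url"
assert classify_media_file("exec_123/frame_01.png") == "local_path"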
@@ -44,7 +44,7 @@ from Pyro5 import config as pyro_config
|
||||
from backend.data import db, rabbitmq, redis
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.json import to_dict
|
||||
from backend.util.process import AppProcess
|
||||
from backend.util.process import AppProcess, get_service_name
|
||||
from backend.util.retry import conn_retry
|
||||
from backend.util.settings import Config, Secrets
|
||||
|
||||
@@ -190,7 +190,17 @@ class BaseAppService(AppProcess, ABC):
|
||||
|
||||
@classmethod
|
||||
def get_host(cls) -> str:
|
||||
return os.environ.get(f"{cls.service_name.upper()}_HOST", api_host)
|
||||
source_host = os.environ.get(f"{get_service_name().upper()}_HOST", api_host)
|
||||
target_host = os.environ.get(f"{cls.service_name.upper()}_HOST", api_host)
|
||||
|
||||
if source_host == target_host and source_host != api_host:
|
||||
logger.warning(
|
||||
f"Service {cls.service_name} is the same host as the source service."
|
||||
f"Use the localhost of {api_host} instead."
|
||||
)
|
||||
return api_host
|
||||
|
||||
return target_host
|
||||
|
||||
@property
|
||||
def rabbit(self) -> rabbitmq.AsyncRabbitMQ:
|
||||
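The get_host() change above resolves the target service's host from a <SERVICE_NAME>_HOST environment variable, but falls back to the local api_host when the calling service and the target resolve to the same non-default host. A minimal sketch of that rule, with api_host and the service names stubbed in as assumptions (get_service_name() and api_host are defined elsewhere in backend.util and are not shown in this hunk):

import os

def resolve_target_host(source_service: str, target_service: str, api_host: str = "localhost") -> str:
    # Mirrors the decision added to BaseAppService.get_host() above.
    source_host = os.environ.get(f"{source_service.upper()}_HOST", api_host)
    target_host = os.environ.get(f"{target_service.upper()}_HOST", api_host)
    if source_host == target_host and source_host != api_host:
        # Both services live on the same remote host: talk over the loopback instead.
        return api_host
    return target_host

# e.g. with EXECUTOR_HOST=10.0.0.5 and DATABASE_MANAGER_HOST=10.0.0.5 both set
# (hypothetical variable names), resolve_target_host("executor", "database_manager")
# would return "localhost".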
@@ -455,7 +465,7 @@ def fastapi_get_service_client(
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP error in {method_name}: {e.response.text}")
|
||||
error = RemoteCallError.model_validate(e.response.json(), strict=False)
|
||||
error = RemoteCallError.model_validate(e.response.json())
|
||||
# DEBUG HELP: if you made a custom exception, make sure you override self.args to be how to make your exception
|
||||
raise EXCEPTION_MAPPING.get(error.type, Exception)(
|
||||
*(error.args or [str(e)])
|
||||
|
||||
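The DEBUG HELP comment above depends on every exception in EXCEPTION_MAPPING being reconstructable from error.args. A minimal, hypothetical example of an exception that honours that contract (the class and its fields are invented for illustration, not taken from the codebase):

# Hypothetical exception whose .args round-trip through the RemoteCallError
# mechanism described above: SomeError(*err.args) rebuilds an equivalent error.
class QuotaExceededError(Exception):
    def __init__(self, user_id: str, limit: int):
        super().__init__(user_id, limit)  # keep constructor args in .args
        self.user_id = user_id
        self.limit = limit

err = QuotaExceededError("user-1", 100)
rebuilt = QuotaExceededError(*err.args)
assert (rebuilt.user_id, rebuilt.limit) == (err.user_id, err.limit)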
@@ -113,6 +113,14 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
default="%Y-%W", # This will allow for weekly refunds per user.
|
||||
description="Time key format for refund requests.",
|
||||
)
|
||||
execution_cost_count_threshold: int = Field(
|
||||
default=100,
|
||||
description="Number of executions after which the cost is calculated.",
|
||||
)
|
||||
execution_cost_per_threshold: int = Field(
|
||||
default=1,
|
||||
description="Cost per execution in cents after each threshold.",
|
||||
)
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
@@ -219,6 +227,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
default=True,
|
||||
description="Whether to use the new agent image generation service",
|
||||
)
|
||||
enable_agent_input_subtype_blocks: bool = Field(
|
||||
default=False,
|
||||
description="Whether to enable the agent input subtype blocks",
|
||||
)
|
||||
|
||||
@field_validator("platform_base_url", "frontend_base_url")
|
||||
@classmethod
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
/*
|
||||
Warnings:
|
||||
|
||||
- You are about to drop the column `isTemplate` on the `AgentGraph` table. All the data in the column will be lost.
|
||||
|
||||
*/
|
||||
-- AlterTable
|
||||
ALTER TABLE "AgentGraph" DROP COLUMN "isTemplate";
|
||||
@@ -0,0 +1,372 @@
|
||||
/*
|
||||
Warnings:
|
||||
|
||||
- The enum type "SubmissionStatus" will be replaced. The 'DAFT' value is removed, so any data using 'DAFT' will be updated to 'DRAFT'. If there are rows still expecting 'DAFT' after this change, it will fail.
|
||||
- You are about to drop the column "isApproved" on the "StoreListing" table. All the data in that column will be lost.
|
||||
- You are about to drop the column "slug" on the "StoreListingVersion" table. All the data in that column will be lost.
|
||||
- You are about to drop the "StoreListingSubmission" table. Data in that table (beyond what is copied over) will be permanently lost.
|
||||
- A unique constraint covering the column "activeVersionId" on the "StoreListing" table will be added. If duplicates already exist, this will fail.
|
||||
- A unique constraint covering the columns ("storeListingId","version") on "StoreListingVersion" will be added. If duplicates already exist, this will fail.
|
||||
- The "storeListingId" column on "StoreListingVersion" is set to NOT NULL. If any rows currently have a NULL value, this step will fail.
|
||||
- The views "StoreSubmission", "StoreAgent", and "Creator" are dropped and recreated. Any usage or references to them will be momentarily disrupted until the views are recreated.
|
||||
*/
|
||||
|
||||
BEGIN;
|
||||
|
||||
-- First, drop all views that depend on the columns and types we're modifying
|
||||
DROP VIEW IF EXISTS "StoreSubmission";
|
||||
DROP VIEW IF EXISTS "StoreAgent";
|
||||
DROP VIEW IF EXISTS "Creator";
|
||||
|
||||
-- Create the new enum type
|
||||
CREATE TYPE "SubmissionStatus_new" AS ENUM ('DRAFT', 'PENDING', 'APPROVED', 'REJECTED');
|
||||
|
||||
-- Modify the column with the correct casing (Status with capital S)
|
||||
ALTER TABLE "StoreListingSubmission" ALTER COLUMN "Status" DROP DEFAULT;
|
||||
ALTER TABLE "StoreListingSubmission"
|
||||
ALTER COLUMN "Status" TYPE "SubmissionStatus_new"
|
||||
USING (
|
||||
CASE WHEN "Status"::text = 'DAFT' THEN 'DRAFT'::text
|
||||
ELSE "Status"::text
|
||||
END
|
||||
)::"SubmissionStatus_new";
|
||||
|
||||
-- Rename the enum types
|
||||
ALTER TYPE "SubmissionStatus" RENAME TO "SubmissionStatus_old";
|
||||
ALTER TYPE "SubmissionStatus_new" RENAME TO "SubmissionStatus";
|
||||
DROP TYPE "SubmissionStatus_old";
|
||||
|
||||
-- Set default back
|
||||
ALTER TABLE "StoreListingSubmission" ALTER COLUMN "Status" SET DEFAULT 'PENDING';
|
||||
|
||||
-- Drop constraints
|
||||
ALTER TABLE "StoreListingSubmission" DROP CONSTRAINT IF EXISTS "StoreListingSubmission_reviewerId_fkey";
|
||||
|
||||
-- Drop indexes
|
||||
DROP INDEX IF EXISTS "StoreListing_isDeleted_isApproved_idx";
|
||||
DROP INDEX IF EXISTS "StoreListingSubmission_storeListingVersionId_key";
|
||||
|
||||
-- Modify StoreListing
|
||||
ALTER TABLE "StoreListing"
|
||||
DROP COLUMN IF EXISTS "isApproved",
|
||||
ADD COLUMN IF NOT EXISTS "activeVersionId" TEXT,
|
||||
ADD COLUMN IF NOT EXISTS "hasApprovedVersion" BOOLEAN NOT NULL DEFAULT false,
|
||||
ADD COLUMN IF NOT EXISTS "slug" TEXT;
|
||||
|
||||
-- First add ALL columns to StoreListingVersion (including the submissionStatus column)
|
||||
ALTER TABLE "StoreListingVersion"
|
||||
ADD COLUMN IF NOT EXISTS "reviewerId" TEXT,
|
||||
ADD COLUMN IF NOT EXISTS "reviewComments" TEXT,
|
||||
ADD COLUMN IF NOT EXISTS "internalComments" TEXT,
|
||||
ADD COLUMN IF NOT EXISTS "reviewedAt" TIMESTAMP(3),
|
||||
ADD COLUMN IF NOT EXISTS "changesSummary" TEXT,
|
||||
ADD COLUMN IF NOT EXISTS "submissionStatus" "SubmissionStatus" NOT NULL DEFAULT 'DRAFT',
|
||||
ADD COLUMN IF NOT EXISTS "submittedAt" TIMESTAMP(3),
|
||||
ALTER COLUMN "storeListingId" SET NOT NULL;
|
||||
|
||||
-- NOW copy data from StoreListingSubmission to StoreListingVersion
|
||||
DO $$
|
||||
BEGIN
|
||||
-- First, check what columns actually exist in the StoreListingSubmission table
|
||||
DECLARE
|
||||
has_reviewerId BOOLEAN := (
|
||||
SELECT EXISTS (
|
||||
SELECT FROM information_schema.columns
|
||||
WHERE table_name = 'StoreListingSubmission'
|
||||
AND column_name = 'reviewerId'
|
||||
)
|
||||
);
|
||||
|
||||
has_reviewComments BOOLEAN := (
|
||||
SELECT EXISTS (
|
||||
SELECT FROM information_schema.columns
|
||||
WHERE table_name = 'StoreListingSubmission'
|
||||
AND column_name = 'reviewComments'
|
||||
)
|
||||
);
|
||||
|
||||
has_changesSummary BOOLEAN := (
|
||||
SELECT EXISTS (
|
||||
SELECT FROM information_schema.columns
|
||||
WHERE table_name = 'StoreListingSubmission'
|
||||
AND column_name = 'changesSummary'
|
||||
)
|
||||
);
|
||||
BEGIN
|
||||
-- Only copy fields that we know exist
|
||||
IF has_reviewerId THEN
|
||||
UPDATE "StoreListingVersion" AS v
|
||||
SET "reviewerId" = s."reviewerId"
|
||||
FROM "StoreListingSubmission" AS s
|
||||
WHERE v."id" = s."storeListingVersionId";
|
||||
END IF;
|
||||
|
||||
IF has_reviewComments THEN
|
||||
UPDATE "StoreListingVersion" AS v
|
||||
SET "reviewComments" = s."reviewComments"
|
||||
FROM "StoreListingSubmission" AS s
|
||||
WHERE v."id" = s."storeListingVersionId";
|
||||
END IF;
|
||||
|
||||
IF has_changesSummary THEN
|
||||
UPDATE "StoreListingVersion" AS v
|
||||
SET "changesSummary" = s."changesSummary"
|
||||
FROM "StoreListingSubmission" AS s
|
||||
WHERE v."id" = s."storeListingVersionId";
|
||||
END IF;
|
||||
END;
|
||||
|
||||
-- Update submission status based on StoreListingSubmission status
|
||||
UPDATE "StoreListingVersion" AS v
|
||||
SET "submissionStatus" = s."Status"
|
||||
FROM "StoreListingSubmission" AS s
|
||||
WHERE v."id" = s."storeListingVersionId";
|
||||
|
||||
-- Update reviewedAt timestamps for versions with APPROVED or REJECTED status
|
||||
UPDATE "StoreListingVersion" AS v
|
||||
SET "reviewedAt" = s."updatedAt"
|
||||
FROM "StoreListingSubmission" AS s
|
||||
WHERE v."id" = s."storeListingVersionId"
|
||||
AND s."Status" IN ('APPROVED', 'REJECTED');
|
||||
END;
|
||||
$$;
|
||||
|
||||
-- Drop the StoreListingSubmission table
|
||||
DROP TABLE IF EXISTS "StoreListingSubmission";
|
||||
|
||||
-- Copy slugs from StoreListingVersion to StoreListing
|
||||
WITH latest_versions AS (
|
||||
SELECT
|
||||
"storeListingId",
|
||||
"slug",
|
||||
ROW_NUMBER() OVER (PARTITION BY "storeListingId" ORDER BY "version" DESC) as rn
|
||||
FROM "StoreListingVersion"
|
||||
)
|
||||
UPDATE "StoreListing" sl
|
||||
SET "slug" = lv."slug"
|
||||
FROM latest_versions lv
|
||||
WHERE sl."id" = lv."storeListingId"
|
||||
AND lv.rn = 1;
|
||||
|
||||
-- Make StoreListing.slug required and unique
|
||||
ALTER TABLE "StoreListing" ALTER COLUMN "slug" SET NOT NULL;
|
||||
CREATE UNIQUE INDEX "StoreListing_owningUserId_slug_key" ON "StoreListing"("owningUserId", "slug");
|
||||
DROP INDEX "StoreListing_owningUserId_idx";
|
||||
|
||||
-- Drop the slug column from StoreListingVersion since it's now on StoreListing
|
||||
ALTER TABLE "StoreListingVersion" DROP COLUMN "slug";
|
||||
|
||||
-- Update both sides of the relation from one-to-one to one-to-many
|
||||
-- The AgentGraph->StoreListingVersion relationship is now one-to-many
|
||||
|
||||
-- Drop the unique constraint but add a non-unique index for query performance
|
||||
ALTER TABLE "StoreListingVersion" DROP CONSTRAINT IF EXISTS "StoreListingVersion_agentId_agentVersion_key";
|
||||
CREATE INDEX IF NOT EXISTS "StoreListingVersion_agentId_agentVersion_idx"
|
||||
ON "StoreListingVersion"("agentId", "agentVersion");
|
||||
|
||||
-- Set submissionStatus based on isApproved before removing it
|
||||
UPDATE "StoreListingVersion"
|
||||
SET "submissionStatus" = 'APPROVED'
|
||||
WHERE "isApproved" = true;
|
||||
|
||||
-- Drop the isApproved column from StoreListingVersion since it's redundant with submissionStatus
|
||||
ALTER TABLE "StoreListingVersion" DROP COLUMN "isApproved";
|
||||
|
||||
-- Initialize hasApprovedVersion for existing StoreListing rows
|
||||
-- This sets "hasApprovedVersion" = TRUE for any StoreListing
|
||||
-- that has at least one corresponding version with "APPROVED" status.
|
||||
UPDATE "StoreListing" sl
|
||||
SET "hasApprovedVersion" = (
|
||||
SELECT COUNT(*) > 0
|
||||
FROM "StoreListingVersion" slv
|
||||
WHERE slv."storeListingId" = sl.id
|
||||
AND slv."submissionStatus" = 'APPROVED'
|
||||
AND sl."agentId" = slv."agentId"
|
||||
AND sl."agentVersion" = slv."agentVersion"
|
||||
);
|
||||
|
||||
-- Create new indexes
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS "StoreListing_activeVersionId_key"
|
||||
ON "StoreListing"("activeVersionId");
|
||||
|
||||
CREATE INDEX IF NOT EXISTS "StoreListing_isDeleted_hasApprovedVersion_idx"
|
||||
ON "StoreListing"("isDeleted", "hasApprovedVersion");
|
||||
|
||||
CREATE INDEX IF NOT EXISTS "StoreListingVersion_storeListingId_submissionStatus_isAvailable_idx"
|
||||
ON "StoreListingVersion"("storeListingId", "submissionStatus", "isAvailable");
|
||||
|
||||
CREATE INDEX IF NOT EXISTS "StoreListingVersion_submissionStatus_idx"
|
||||
ON "StoreListingVersion"("submissionStatus");
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS "StoreListingVersion_storeListingId_version_key"
|
||||
ON "StoreListingVersion"("storeListingId", "version");
|
||||
|
||||
-- Add foreign keys
|
||||
ALTER TABLE "StoreListing"
|
||||
ADD CONSTRAINT "StoreListing_activeVersionId_fkey"
|
||||
FOREIGN KEY ("activeVersionId") REFERENCES "StoreListingVersion"("id")
|
||||
ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
|
||||
-- Add reviewer foreign key
|
||||
ALTER TABLE "StoreListingVersion"
|
||||
ADD CONSTRAINT "StoreListingVersion_reviewerId_fkey"
|
||||
FOREIGN KEY ("reviewerId") REFERENCES "User"("id")
|
||||
ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
|
||||
-- Add index for reviewer
|
||||
CREATE INDEX IF NOT EXISTS "StoreListingVersion_reviewerId_idx"
|
||||
ON "StoreListingVersion"("reviewerId");
|
||||
|
||||
-- DropIndex
|
||||
DROP INDEX "StoreListingVersion_agentId_agentVersion_key";
|
||||
|
||||
-- RenameIndex
|
||||
ALTER INDEX "StoreListingVersion_storeListingId_submissionStatus_isAvailable_idx"
|
||||
RENAME TO "StoreListingVersion_storeListingId_submissionStatus_isAvail_idx";
|
||||
|
||||
-- Recreate the views with updated column references
|
||||
|
||||
-- 1. Recreate StoreSubmission view
|
||||
CREATE VIEW "StoreSubmission" AS
|
||||
SELECT
|
||||
sl.id AS listing_id,
|
||||
sl."owningUserId" AS user_id,
|
||||
slv."agentId" AS agent_id,
|
||||
slv.version AS agent_version,
|
||||
sl.slug,
|
||||
COALESCE(slv.name, '') AS name,
|
||||
slv."subHeading" AS sub_heading,
|
||||
slv.description,
|
||||
slv."imageUrls" AS image_urls,
|
||||
slv."submittedAt" AS date_submitted,
|
||||
slv."submissionStatus" AS status,
|
||||
COALESCE(ar.run_count, 0::bigint) AS runs,
|
||||
COALESCE(avg(sr.score::numeric), 0.0)::double precision AS rating,
|
||||
-- Add the additional fields needed by the Pydantic model
|
||||
slv.id AS store_listing_version_id,
|
||||
slv."reviewerId" AS reviewer_id,
|
||||
slv."reviewComments" AS review_comments,
|
||||
slv."internalComments" AS internal_comments,
|
||||
slv."reviewedAt" AS reviewed_at,
|
||||
slv."changesSummary" AS changes_summary
|
||||
FROM "StoreListing" sl
|
||||
JOIN "StoreListingVersion" slv ON slv."storeListingId" = sl.id
|
||||
LEFT JOIN "StoreListingReview" sr ON sr."storeListingVersionId" = slv.id
|
||||
LEFT JOIN (
|
||||
SELECT "AgentGraphExecution"."agentGraphId", count(*) AS run_count
|
||||
FROM "AgentGraphExecution"
|
||||
GROUP BY "AgentGraphExecution"."agentGraphId"
|
||||
) ar ON ar."agentGraphId" = slv."agentId"
|
||||
WHERE sl."isDeleted" = false
|
||||
GROUP BY sl.id, sl."owningUserId", slv.id, slv."agentId", slv.version, sl.slug, slv.name,
|
||||
slv."subHeading", slv.description, slv."imageUrls", slv."submittedAt",
|
||||
slv."submissionStatus", slv."reviewerId", slv."reviewComments", slv."internalComments",
|
||||
slv."reviewedAt", slv."changesSummary", ar.run_count;
|
||||
|
||||
-- 2. Recreate StoreAgent view
|
||||
CREATE VIEW "StoreAgent" AS
|
||||
WITH reviewstats AS (
|
||||
SELECT sl_1.id AS "storeListingId",
|
||||
count(sr.id) AS review_count,
|
||||
avg(sr.score::numeric) AS avg_rating
|
||||
FROM "StoreListing" sl_1
|
||||
JOIN "StoreListingVersion" slv_1
|
||||
ON slv_1."storeListingId" = sl_1.id
|
||||
JOIN "StoreListingReview" sr
|
||||
ON sr."storeListingVersionId" = slv_1.id
|
||||
WHERE sl_1."isDeleted" = false
|
||||
GROUP BY sl_1.id
|
||||
), agentruns AS (
|
||||
SELECT "AgentGraphExecution"."agentGraphId",
|
||||
count(*) AS run_count
|
||||
FROM "AgentGraphExecution"
|
||||
GROUP BY "AgentGraphExecution"."agentGraphId"
|
||||
)
|
||||
SELECT sl.id AS listing_id,
|
||||
slv.id AS "storeListingVersionId",
|
||||
slv."createdAt" AS updated_at,
|
||||
sl.slug,
|
||||
COALESCE(slv.name, '') AS agent_name,
|
||||
slv."videoUrl" AS agent_video,
|
||||
COALESCE(slv."imageUrls", ARRAY[]::text[]) AS agent_image,
|
||||
slv."isFeatured" AS featured,
|
||||
p.username AS creator_username,
|
||||
p."avatarUrl" AS creator_avatar,
|
||||
slv."subHeading" AS sub_heading,
|
||||
slv.description,
|
||||
slv.categories,
|
||||
COALESCE(ar.run_count, 0::bigint) AS runs,
|
||||
COALESCE(rs.avg_rating, 0.0)::double precision AS rating,
|
||||
array_agg(DISTINCT slv.version::text) AS versions
|
||||
FROM "StoreListing" sl
|
||||
JOIN "AgentGraph" a
|
||||
ON sl."agentId" = a.id
|
||||
AND sl."agentVersion" = a.version
|
||||
LEFT JOIN "Profile" p
|
||||
ON sl."owningUserId" = p."userId"
|
||||
LEFT JOIN "StoreListingVersion" slv
|
||||
ON slv."storeListingId" = sl.id
|
||||
LEFT JOIN reviewstats rs
|
||||
ON sl.id = rs."storeListingId"
|
||||
LEFT JOIN agentruns ar
|
||||
ON a.id = ar."agentGraphId"
|
||||
WHERE sl."isDeleted" = false
|
||||
AND sl."hasApprovedVersion" = true
|
||||
AND slv."submissionStatus" = 'APPROVED'
|
||||
GROUP BY sl.id, slv.id, sl.slug, slv."createdAt", slv.name, slv."videoUrl",
|
||||
slv."imageUrls", slv."isFeatured", p.username, p."avatarUrl",
|
||||
slv."subHeading", slv.description, slv.categories, ar.run_count,
|
||||
rs.avg_rating;
|
||||
|
||||
-- 3. Recreate Creator view
|
||||
CREATE VIEW "Creator" AS
|
||||
WITH agentstats AS (
|
||||
SELECT p_1.username,
|
||||
count(DISTINCT sl.id) AS num_agents,
|
||||
avg(COALESCE(sr.score, 0)::numeric) AS agent_rating,
|
||||
sum(COALESCE(age.run_count, 0::bigint)) AS agent_runs
|
||||
FROM "Profile" p_1
|
||||
LEFT JOIN "StoreListing" sl
|
||||
ON sl."owningUserId" = p_1."userId"
|
||||
LEFT JOIN "StoreListingVersion" slv
|
||||
ON slv."storeListingId" = sl.id
|
||||
LEFT JOIN "StoreListingReview" sr
|
||||
ON sr."storeListingVersionId" = slv.id
|
||||
LEFT JOIN (
|
||||
SELECT "AgentGraphExecution"."agentGraphId",
|
||||
count(*) AS run_count
|
||||
FROM "AgentGraphExecution"
|
||||
GROUP BY "AgentGraphExecution"."agentGraphId"
|
||||
) age ON age."agentGraphId" = sl."agentId"
|
||||
WHERE sl."isDeleted" = false
|
||||
AND sl."hasApprovedVersion" = true
|
||||
AND slv."submissionStatus" = 'APPROVED'
|
||||
GROUP BY p_1.username
|
||||
)
|
||||
SELECT p.username,
|
||||
p.name,
|
||||
p."avatarUrl" AS avatar_url,
|
||||
p.description,
|
||||
array_agg(DISTINCT cats.c) FILTER (WHERE cats.c IS NOT NULL) AS top_categories,
|
||||
p.links,
|
||||
p."isFeatured" AS is_featured,
|
||||
COALESCE(ast.num_agents, 0::bigint) AS num_agents,
|
||||
COALESCE(ast.agent_rating, 0.0) AS agent_rating,
|
||||
COALESCE(ast.agent_runs, 0::numeric) AS agent_runs
|
||||
FROM "Profile" p
|
||||
LEFT JOIN agentstats ast
|
||||
ON ast.username = p.username
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT unnest(slv.categories) AS c
|
||||
FROM "StoreListing" sl
|
||||
JOIN "StoreListingVersion" slv
|
||||
ON slv."storeListingId" = sl.id
|
||||
WHERE sl."owningUserId" = p."userId"
|
||||
AND sl."isDeleted" = false
|
||||
AND sl."hasApprovedVersion" = true
|
||||
AND slv."submissionStatus" = 'APPROVED'
|
||||
) cats ON true
|
||||
GROUP BY p.username, p.name, p."avatarUrl", p.description, p.links,
|
||||
p."isFeatured", ast.num_agents, ast.agent_rating, ast.agent_runs;
|
||||
|
||||
COMMIT;
|
||||
@@ -44,14 +44,14 @@ model User {
|
||||
AgentPreset AgentPreset[]
|
||||
LibraryAgent LibraryAgent[]
|
||||
|
||||
Profile Profile[]
|
||||
UserOnboarding UserOnboarding?
|
||||
StoreListing StoreListing[]
|
||||
StoreListingReview StoreListingReview[]
|
||||
StoreListingSubmission StoreListingSubmission[]
|
||||
APIKeys APIKey[]
|
||||
IntegrationWebhooks IntegrationWebhook[]
|
||||
UserNotificationBatch UserNotificationBatch[]
|
||||
Profile Profile[]
|
||||
UserOnboarding UserOnboarding?
|
||||
StoreListing StoreListing[]
|
||||
StoreListingReview StoreListingReview[]
|
||||
StoreVersionsReviewed StoreListingVersion[]
|
||||
APIKeys APIKey[]
|
||||
IntegrationWebhooks IntegrationWebhook[]
|
||||
UserNotificationBatch UserNotificationBatch[]
|
||||
|
||||
@@index([id])
|
||||
@@index([email])
|
||||
@@ -71,7 +71,7 @@ model UserOnboarding {
|
||||
isCompleted Boolean @default(false)
|
||||
|
||||
userId String @unique
|
||||
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
@@index([userId])
|
||||
}
|
||||
@@ -86,14 +86,13 @@ model AgentGraph {
|
||||
name String?
|
||||
description String?
|
||||
|
||||
isActive Boolean @default(true)
|
||||
isTemplate Boolean @default(false)
|
||||
isActive Boolean @default(true)
|
||||
|
||||
// Link to User model
|
||||
userId String
|
||||
// FIX: Do not cascade delete the agent when the user is deleted
|
||||
// This allows us to delete user data without deleting the agent, which may be in use by other users
|
||||
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
AgentNodes AgentNode[]
|
||||
AgentGraphExecution AgentGraphExecution[]
|
||||
@@ -101,7 +100,7 @@ model AgentGraph {
|
||||
AgentPreset AgentPreset[]
|
||||
LibraryAgent LibraryAgent[]
|
||||
StoreListing StoreListing[]
|
||||
StoreListingVersion StoreListingVersion?
|
||||
StoreListingVersion StoreListingVersion[]
|
||||
|
||||
@@id(name: "graphVersionId", [id, version])
|
||||
@@index([userId, isActive])
|
||||
@@ -176,11 +175,11 @@ model UserNotificationBatch {
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
|
||||
userId String
|
||||
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
type NotificationType
|
||||
|
||||
notifications NotificationEvent[]
|
||||
Notifications NotificationEvent[]
|
||||
|
||||
// Each user can only have one batch of a notification type at a time
|
||||
@@unique([userId, type])
|
||||
@@ -196,7 +195,7 @@ model LibraryAgent {
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
imageUrl String?
|
||||
imageUrl String?
|
||||
|
||||
agentId String
|
||||
agentVersion Int
|
||||
@@ -320,7 +319,7 @@ model AgentGraphExecution {
|
||||
|
||||
// Link to User model -- Executed by this user
|
||||
userId String
|
||||
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
stats Json?
|
||||
AgentPreset AgentPreset? @relation(fields: [agentPresetId], references: [id])
|
||||
@@ -385,7 +384,7 @@ model IntegrationWebhook {
|
||||
updatedAt DateTime? @updatedAt
|
||||
|
||||
userId String
|
||||
user User @relation(fields: [userId], references: [id], onDelete: Restrict) // Webhooks must be deregistered before deleting
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Restrict) // Webhooks must be deregistered before deleting
|
||||
|
||||
provider String // e.g. 'github'
|
||||
credentialsId String // relation to the credentials that the webhook was created with
|
||||
@@ -412,7 +411,7 @@ model AnalyticsDetails {
|
||||
|
||||
// Link to User model
|
||||
userId String
|
||||
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
// Analytics Categorical data used for filtering (indexable w and w/o userId)
|
||||
type String
|
||||
@@ -447,7 +446,7 @@ model AnalyticsMetrics {
|
||||
|
||||
// Link to User model
|
||||
userId String
|
||||
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
@@index([userId])
|
||||
}
|
||||
@@ -471,7 +470,7 @@ model CreditTransaction {
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
userId String
|
||||
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
amount Int
|
||||
type CreditTransactionType
|
||||
@@ -580,19 +579,25 @@ view StoreAgent {
|
||||
}
|
||||
|
||||
view StoreSubmission {
|
||||
listing_id String @id
|
||||
user_id String
|
||||
slug String
|
||||
name String
|
||||
sub_heading String
|
||||
description String
|
||||
image_urls String[]
|
||||
date_submitted DateTime
|
||||
status SubmissionStatus
|
||||
runs Int
|
||||
rating Float
|
||||
agent_id String
|
||||
agent_version Int
|
||||
listing_id String @id
|
||||
user_id String
|
||||
slug String
|
||||
name String
|
||||
sub_heading String
|
||||
description String
|
||||
image_urls String[]
|
||||
date_submitted DateTime
|
||||
status SubmissionStatus
|
||||
runs Int
|
||||
rating Float
|
||||
agent_id String
|
||||
agent_version Int
|
||||
store_listing_version_id String
|
||||
reviewer_id String?
|
||||
review_comments String?
|
||||
internal_comments String?
|
||||
reviewed_at DateTime?
|
||||
changes_summary String?
|
||||
|
||||
// Index or unique are not applied to views
|
||||
}
|
||||
@@ -602,11 +607,18 @@ model StoreListing {
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
|
||||
isDeleted Boolean @default(false)
|
||||
// Not needed but makes lookups faster
|
||||
isApproved Boolean @default(false)
|
||||
isDeleted Boolean @default(false)
|
||||
// Whether any version has been approved and is available for display
|
||||
hasApprovedVersion Boolean @default(false)
|
||||
|
||||
// The agent link here is only so we can do lookup on agentId, for the listing the StoreListingVersion is used.
|
||||
// URL-friendly identifier for this agent (moved from StoreListingVersion)
|
||||
slug String
|
||||
|
||||
// The currently active version that should be shown to users
|
||||
activeVersionId String? @unique
|
||||
ActiveVersion StoreListingVersion? @relation("ActiveVersion", fields: [activeVersionId], references: [id])
|
||||
|
||||
// The agent link here is only so we can do lookup on agentId
|
||||
agentId String
|
||||
agentVersion Int
|
||||
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)
|
||||
@@ -614,14 +626,14 @@ model StoreListing {
|
||||
owningUserId String
|
||||
OwningUser User @relation(fields: [owningUserId], references: [id])
|
||||
|
||||
StoreListingVersions StoreListingVersion[]
|
||||
StoreListingSubmission StoreListingSubmission[]
|
||||
// Relations
|
||||
Versions StoreListingVersion[] @relation("ListingVersions")
|
||||
|
||||
// Unique index on agentId to ensure only one listing per agent, regardless of number of versions the agent has.
|
||||
@@unique([agentId])
|
||||
@@index([owningUserId])
|
||||
@@unique([owningUserId, slug])
|
||||
// Used in the view query
|
||||
@@index([isDeleted, isApproved])
|
||||
@@index([isDeleted, hasApprovedVersion])
|
||||
}
|
||||
|
||||
model StoreListingVersion {
|
||||
@@ -635,10 +647,7 @@ model StoreListingVersion {
|
||||
agentVersion Int
|
||||
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version])
|
||||
|
||||
// The details for this version of the agent; this allows the author to update the details of the agent,
|
||||
// but still allow using old versions of the agent with their original details.
|
||||
// TODO: Create a database view that shows only the latest version of each store listing.
|
||||
slug String
|
||||
// Content fields
|
||||
name String
|
||||
subHeading String
|
||||
videoUrl String?
|
||||
@@ -648,20 +657,39 @@ model StoreListingVersion {
|
||||
|
||||
isFeatured Boolean @default(false)
|
||||
|
||||
isDeleted Boolean @default(false)
|
||||
isDeleted Boolean @default(false)
|
||||
// Old versions can be made unavailable by the author if desired
|
||||
isAvailable Boolean @default(true)
|
||||
// Not needed but makes lookups faster
|
||||
isApproved Boolean @default(false)
|
||||
StoreListing StoreListing? @relation(fields: [storeListingId], references: [id], onDelete: Cascade)
|
||||
storeListingId String?
|
||||
StoreListingSubmission StoreListingSubmission[]
|
||||
isAvailable Boolean @default(true)
|
||||
|
||||
// Reviews are on a specific version, but then aggregated up to the listing.
|
||||
// This allows us to provide a review filter to current version of the agent.
|
||||
StoreListingReview StoreListingReview[]
|
||||
// Version workflow state
|
||||
submissionStatus SubmissionStatus @default(DRAFT)
|
||||
submittedAt DateTime?
|
||||
|
||||
@@unique([agentId, agentVersion])
|
||||
// Relations
|
||||
storeListingId String
|
||||
StoreListing StoreListing @relation("ListingVersions", fields: [storeListingId], references: [id], onDelete: Cascade)
|
||||
|
||||
// This version might be the active version for a listing
|
||||
ActiveFor StoreListing? @relation("ActiveVersion")
|
||||
|
||||
// Submission history
|
||||
changesSummary String?
|
||||
|
||||
// Review information
|
||||
reviewerId String?
|
||||
Reviewer User? @relation(fields: [reviewerId], references: [id])
|
||||
internalComments String? // Private notes for admin use only
|
||||
reviewComments String? // Comments visible to creator
|
||||
reviewedAt DateTime?
|
||||
|
||||
// Reviews for this specific version
|
||||
Reviews StoreListingReview[]
|
||||
|
||||
@@unique([storeListingId, version])
|
||||
@@index([storeListingId, submissionStatus, isAvailable])
|
||||
@@index([submissionStatus])
|
||||
@@index([reviewerId])
|
||||
@@index([agentId, agentVersion]) // Non-unique index for efficient lookups
|
||||
}
|
||||
|
||||
model StoreListingReview {
|
||||
@@ -682,31 +710,10 @@ model StoreListingReview {
|
||||
}
|
||||
|
||||
enum SubmissionStatus {
|
||||
DAFT
|
||||
PENDING
|
||||
APPROVED
|
||||
REJECTED
|
||||
}
|
||||
|
||||
model StoreListingSubmission {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
|
||||
storeListingId String
|
||||
StoreListing StoreListing @relation(fields: [storeListingId], references: [id], onDelete: Cascade)
|
||||
|
||||
storeListingVersionId String
|
||||
StoreListingVersion StoreListingVersion @relation(fields: [storeListingVersionId], references: [id], onDelete: Cascade)
|
||||
|
||||
reviewerId String
|
||||
Reviewer User @relation(fields: [reviewerId], references: [id])
|
||||
|
||||
Status SubmissionStatus @default(PENDING)
|
||||
reviewComments String?
|
||||
|
||||
@@unique([storeListingVersionId])
|
||||
@@index([storeListingId])
|
||||
DRAFT // Being prepared, not yet submitted
|
||||
PENDING // Submitted, awaiting review
|
||||
APPROVED // Reviewed and approved
|
||||
REJECTED // Reviewed and rejected
|
||||
}
|
||||
|
||||
enum APIKeyPermission {
|
||||
@@ -733,7 +740,7 @@ model APIKey {
|
||||
|
||||
// Relation to user
|
||||
userId String
|
||||
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
@@index([key])
|
||||
@@index([prefix])
|
||||
|
||||
@@ -5,9 +5,11 @@ from prisma.enums import CreditTransactionType
|
||||
from prisma.models import CreditTransaction
|
||||
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
from backend.data.block import get_block
|
||||
from backend.data.credit import BetaUserCredit
|
||||
from backend.data.execution import NodeExecutionEntry
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.executor.utils import UsageTransactionMetadata, block_usage_cost
|
||||
from backend.integrations.credentials_store import openai_credentials
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
@@ -27,13 +29,36 @@ async def top_up(amount: int):
|
||||
)
|
||||
|
||||
|
||||
async def spend_credits(entry: NodeExecutionEntry) -> int:
|
||||
block = get_block(entry.block_id)
|
||||
if not block:
|
||||
raise RuntimeError(f"Block {entry.block_id} not found")
|
||||
|
||||
cost, matching_filter = block_usage_cost(block=block, input_data=entry.data)
|
||||
await user_credit.spend_credits(
|
||||
entry.user_id,
|
||||
cost,
|
||||
UsageTransactionMetadata(
|
||||
graph_exec_id=entry.graph_exec_id,
|
||||
graph_id=entry.graph_id,
|
||||
node_id=entry.node_id,
|
||||
node_exec_id=entry.node_exec_id,
|
||||
block_id=entry.block_id,
|
||||
block=entry.block_id,
|
||||
input=matching_filter,
|
||||
),
|
||||
)
|
||||
|
||||
return cost
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_block_credit_usage(server: SpinTestServer):
|
||||
await disable_test_user_transactions()
|
||||
await top_up(100)
|
||||
current_credit = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||
|
||||
spending_amount_1 = await user_credit.spend_credits(
|
||||
spending_amount_1 = await spend_credits(
|
||||
NodeExecutionEntry(
|
||||
user_id=DEFAULT_USER_ID,
|
||||
graph_id="test_graph",
|
||||
@@ -50,12 +75,10 @@ async def test_block_credit_usage(server: SpinTestServer):
|
||||
},
|
||||
},
|
||||
),
|
||||
0.0,
|
||||
0.0,
|
||||
)
|
||||
assert spending_amount_1 > 0
|
||||
|
||||
spending_amount_2 = await user_credit.spend_credits(
|
||||
spending_amount_2 = await spend_credits(
|
||||
NodeExecutionEntry(
|
||||
user_id=DEFAULT_USER_ID,
|
||||
graph_id="test_graph",
|
||||
@@ -65,8 +88,6 @@ async def test_block_credit_usage(server: SpinTestServer):
|
||||
block_id=AITextGeneratorBlock().id,
|
||||
data={"model": "gpt-4-turbo", "api_key": "owned_api_key"},
|
||||
),
|
||||
0.0,
|
||||
0.0,
|
||||
)
|
||||
assert spending_amount_2 == 0
|
||||
|
||||
|
||||
@@ -6,7 +6,8 @@ import fastapi.exceptions
|
||||
import pytest
|
||||
|
||||
import backend.server.v2.store.model as store
|
||||
from backend.blocks.basic import AgentInputBlock, AgentOutputBlock, StoreValueBlock
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.data.block import BlockSchema
|
||||
from backend.data.graph import Graph, Link, Node
|
||||
from backend.data.model import SchemaField
|
||||
@@ -199,7 +200,9 @@ async def test_clean_graph(server: SpinTestServer):
|
||||
)
|
||||
|
||||
# Clean the graph
|
||||
created_graph.clean_graph()
|
||||
created_graph = await server.agent_server.test_get_graph(
|
||||
created_graph.id, created_graph.version, DEFAULT_USER_ID, for_export=True
|
||||
)
|
||||
|
||||
# # Verify input block value is cleared
|
||||
input_node = next(
|
||||
@@ -240,7 +243,7 @@ async def test_access_store_listing_graph(server: SpinTestServer):
|
||||
store_submission_request = store.StoreSubmissionRequest(
|
||||
agent_id=created_graph.id,
|
||||
agent_version=created_graph.version,
|
||||
slug="test-slug",
|
||||
slug=created_graph.id,
|
||||
name="Test name",
|
||||
sub_heading="Test sub heading",
|
||||
video_url=None,
|
||||
|
||||
@@ -7,7 +7,8 @@ from prisma.models import User
|
||||
|
||||
import backend.server.v2.library.model
|
||||
import backend.server.v2.store.model
|
||||
from backend.blocks.basic import AgentInputBlock, FindInDictionaryBlock, StoreValueBlock
|
||||
from backend.blocks.basic import FindInDictionaryBlock, StoreValueBlock
|
||||
from backend.blocks.io import AgentInputBlock
|
||||
from backend.blocks.maths import CalculatorBlock, Operation
|
||||
from backend.data import execution, graph
|
||||
from backend.server.model import CreateGraph
|
||||
@@ -123,8 +124,8 @@ async def assert_sample_graph_executions(
|
||||
logger.info(f"Checking PrintToConsoleBlock execution: {exec}")
|
||||
assert exec.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec.graph_exec_id == graph_exec_id
|
||||
assert exec.output_data == {"status": ["printed"]}
|
||||
assert exec.input_data == {"text": "Hello, World!!!"}
|
||||
assert exec.output_data == {"output": ["Hello, World!!!"]}
|
||||
assert exec.input_data == {"input": "Hello, World!!!"}
|
||||
assert exec.node_id == test_graph.nodes[3].id
|
||||
|
||||
|
||||
@@ -494,7 +495,7 @@ async def test_store_listing_graph(server: SpinTestServer):
|
||||
store_submission_request = backend.server.v2.store.model.StoreSubmissionRequest(
|
||||
agent_id=test_graph.id,
|
||||
agent_version=test_graph.version,
|
||||
slug="test-slug",
|
||||
slug=test_graph.id,
|
||||
name="Test name",
|
||||
sub_heading="Test sub heading",
|
||||
video_url=None,
|
||||
|
||||
@@ -46,17 +46,27 @@ def test_disconnect(
|
||||
async def test_subscribe(
|
||||
connection_manager: ConnectionManager, mock_websocket: AsyncMock
|
||||
) -> None:
|
||||
await connection_manager.subscribe("test_graph", 1, mock_websocket)
|
||||
assert mock_websocket in connection_manager.subscriptions["test_graph_1"]
|
||||
await connection_manager.subscribe(
|
||||
user_id="user-1",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
websocket=mock_websocket,
|
||||
)
|
||||
assert mock_websocket in connection_manager.subscriptions["user-1_test_graph_1"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsubscribe(
|
||||
connection_manager: ConnectionManager, mock_websocket: AsyncMock
|
||||
) -> None:
|
||||
connection_manager.subscriptions["test_graph_1"] = {mock_websocket}
|
||||
connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket}
|
||||
|
||||
await connection_manager.unsubscribe("test_graph", 1, mock_websocket)
|
||||
await connection_manager.unsubscribe(
|
||||
user_id="user-1",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
websocket=mock_websocket,
|
||||
)
|
||||
|
||||
assert "test_graph" not in connection_manager.subscriptions
|
||||
|
||||
@@ -65,8 +75,9 @@ async def test_unsubscribe(
|
||||
async def test_send_execution_result(
|
||||
connection_manager: ConnectionManager, mock_websocket: AsyncMock
|
||||
) -> None:
|
||||
connection_manager.subscriptions["test_graph_1"] = {mock_websocket}
|
||||
connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket}
|
||||
result: ExecutionResult = ExecutionResult(
|
||||
user_id="user-1",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
graph_exec_id="test_exec_id",
|
||||
@@ -87,17 +98,45 @@ async def test_send_execution_result(
|
||||
mock_websocket.send_text.assert_called_once_with(
|
||||
WsMessage(
|
||||
method=Methods.EXECUTION_EVENT,
|
||||
channel="test_graph_1",
|
||||
channel="user-1_test_graph_1",
|
||||
data=result.model_dump(),
|
||||
).model_dump_json()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_execution_result_user_mismatch(
|
||||
connection_manager: ConnectionManager, mock_websocket: AsyncMock
|
||||
) -> None:
|
||||
connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket}
|
||||
result: ExecutionResult = ExecutionResult(
|
||||
user_id="user-2",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
graph_exec_id="test_exec_id",
|
||||
node_exec_id="test_node_exec_id",
|
||||
node_id="test_node_id",
|
||||
block_id="test_block_id",
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
input_data={"input1": "value1"},
|
||||
output_data={"output1": ["result1"]},
|
||||
add_time=datetime.now(tz=timezone.utc),
|
||||
queue_time=None,
|
||||
start_time=datetime.now(tz=timezone.utc),
|
||||
end_time=datetime.now(tz=timezone.utc),
|
||||
)
|
||||
|
||||
await connection_manager.send_execution_result(result)
|
||||
|
||||
mock_websocket.send_text.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_execution_result_no_subscribers(
|
||||
connection_manager: ConnectionManager, mock_websocket: AsyncMock
|
||||
) -> None:
|
||||
result: ExecutionResult = ExecutionResult(
|
||||
user_id="user-1",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
graph_exec_id="test_exec_id",
|
||||
|
||||
@@ -4,6 +4,7 @@ from unittest.mock import AsyncMock
|
||||
import pytest
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.server.conn_manager import ConnectionManager
|
||||
from backend.server.ws_api import (
|
||||
Methods,
|
||||
@@ -41,7 +42,12 @@ async def test_websocket_router_subscribe(
|
||||
)
|
||||
|
||||
mock_manager.connect.assert_called_once_with(mock_websocket)
|
||||
mock_manager.subscribe.assert_called_once_with("test_graph", 1, mock_websocket)
|
||||
mock_manager.subscribe.assert_called_once_with(
|
||||
user_id=DEFAULT_USER_ID,
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
websocket=mock_websocket,
|
||||
)
|
||||
mock_websocket.send_text.assert_called_once()
|
||||
assert '"method":"subscribe"' in mock_websocket.send_text.call_args[0][0]
|
||||
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
|
||||
@@ -65,7 +71,12 @@ async def test_websocket_router_unsubscribe(
|
||||
)
|
||||
|
||||
mock_manager.connect.assert_called_once_with(mock_websocket)
|
||||
mock_manager.unsubscribe.assert_called_once_with("test_graph", 1, mock_websocket)
|
||||
mock_manager.unsubscribe.assert_called_once_with(
|
||||
user_id=DEFAULT_USER_ID,
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
websocket=mock_websocket,
|
||||
)
|
||||
mock_websocket.send_text.assert_called_once()
|
||||
assert '"method":"unsubscribe"' in mock_websocket.send_text.call_args[0][0]
|
||||
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
|
||||
@@ -101,10 +112,18 @@ async def test_handle_subscribe_success(
|
||||
)
|
||||
|
||||
await handle_subscribe(
|
||||
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager), message
|
||||
connection_manager=cast(ConnectionManager, mock_manager),
|
||||
websocket=cast(WebSocket, mock_websocket),
|
||||
user_id="user-1",
|
||||
message=message,
|
||||
)
|
||||
|
||||
mock_manager.subscribe.assert_called_once_with("test_graph", 1, mock_websocket)
|
||||
mock_manager.subscribe.assert_called_once_with(
|
||||
user_id="user-1",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
websocket=mock_websocket,
|
||||
)
|
||||
mock_websocket.send_text.assert_called_once()
|
||||
assert '"method":"subscribe"' in mock_websocket.send_text.call_args[0][0]
|
||||
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
|
||||
@@ -117,7 +136,10 @@ async def test_handle_subscribe_missing_data(
|
||||
message = WsMessage(method=Methods.SUBSCRIBE)
|
||||
|
||||
await handle_subscribe(
|
||||
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager), message
|
||||
connection_manager=cast(ConnectionManager, mock_manager),
|
||||
websocket=cast(WebSocket, mock_websocket),
|
||||
user_id="user-1",
|
||||
message=message,
|
||||
)
|
||||
|
||||
mock_manager.subscribe.assert_not_called()
|
||||
@@ -135,10 +157,18 @@ async def test_handle_unsubscribe_success(
|
||||
)
|
||||
|
||||
await handle_unsubscribe(
|
||||
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager), message
|
||||
connection_manager=cast(ConnectionManager, mock_manager),
|
||||
websocket=cast(WebSocket, mock_websocket),
|
||||
user_id="user-1",
|
message=message,
)

mock_manager.unsubscribe.assert_called_once_with("test_graph", 1, mock_websocket)
mock_manager.unsubscribe.assert_called_once_with(
user_id="user-1",
graph_id="test_graph",
graph_version=1,
websocket=mock_websocket,
)
mock_websocket.send_text.assert_called_once()
assert '"method":"unsubscribe"' in mock_websocket.send_text.call_args[0][0]
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
@@ -151,7 +181,10 @@ async def test_handle_unsubscribe_missing_data(
message = WsMessage(method=Methods.UNSUBSCRIBE)

await handle_unsubscribe(
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager), message
connection_manager=cast(ConnectionManager, mock_manager),
websocket=cast(WebSocket, mock_websocket),
user_id="user-1",
message=message,
)

mock_manager.unsubscribe.assert_not_called()

@@ -91,7 +91,6 @@ async def main():
"description": faker.text(max_nb_chars=200),
"userId": user.id,
"isActive": True,
"isTemplate": False,
}
)
agent_graphs.append(graph)
@@ -329,12 +328,14 @@ async def main():
print(f"Inserting {NUM_USERS} store listings")
for graph in agent_graphs:
user = random.choice(users)
slug = faker.slug()
listing = await db.storelisting.create(
data={
"agentId": graph.id,
"agentVersion": graph.version,
"owningUserId": user.id,
"isApproved": random.choice([True, False]),
"hasApprovedVersion": random.choice([True, False]),
"slug": slug,
}
)
store_listings.append(listing)
@@ -348,7 +349,6 @@ async def main():
data={
"agentId": graph.id,
"agentVersion": graph.version,
"slug": faker.slug(),
"name": graph.name or faker.sentence(nb_words=3),
"subHeading": faker.sentence(),
"videoUrl": faker.url(),
@@ -357,8 +357,14 @@ async def main():
"categories": [faker.word() for _ in range(3)],
"isFeatured": random.choice([True, False]),
"isAvailable": True,
"isApproved": random.choice([True, False]),
"storeListingId": listing.id,
"submissionStatus": random.choice(
[
prisma.enums.SubmissionStatus.PENDING,
prisma.enums.SubmissionStatus.APPROVED,
prisma.enums.SubmissionStatus.REJECTED,
]
),
}
)
store_listing_versions.append(version)
@@ -387,10 +393,9 @@ async def main():
}
)

# Insert StoreListingSubmissions
print(f"Inserting {NUM_USERS} store listing submissions")
for listing in store_listings:
version = random.choice(store_listing_versions)
# Update StoreListingVersions with submission status (StoreListingSubmissions table no longer exists)
print(f"Updating {NUM_USERS} store listing versions with submission status")
for version in store_listing_versions:
reviewer = random.choice(users)
status: prisma.enums.SubmissionStatus = random.choice(
[
@@ -399,14 +404,14 @@ async def main():
prisma.enums.SubmissionStatus.REJECTED,
]
)
await db.storelistingsubmission.create(
await db.storelistingversion.update(
where={"id": version.id},
data={
"storeListingId": listing.id,
"storeListingVersionId": version.id,
"reviewerId": reviewer.id,
"Status": status,
"submissionStatus": status,
"Reviewer": {"connect": {"id": reviewer.id}},
"reviewComments": faker.text(),
}
"reviewedAt": datetime.now(),
},
)

# Insert APIKeys
123
autogpt_platform/db/docker/.env.example
Normal file
123
autogpt_platform/db/docker/.env.example
Normal file
@@ -0,0 +1,123 @@
|
||||
############
|
||||
# Secrets
|
||||
# YOU MUST CHANGE THESE BEFORE GOING INTO PRODUCTION
|
||||
############
|
||||
|
||||
POSTGRES_PASSWORD=your-super-secret-and-long-postgres-password
|
||||
JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
ANON_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJhbm9uIiwKICAgICJpc3MiOiAic3VwYWJhc2UtZGVtbyIsCiAgICAiaWF0IjogMTY0MTc2OTIwMCwKICAgICJleHAiOiAxNzk5NTM1NjAwCn0.dc_X5iR_VP_qT0zsiyj_I_OZ2T9FtRU2BBNWN8Bu4GE
|
||||
SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
|
||||
DASHBOARD_USERNAME=supabase
|
||||
DASHBOARD_PASSWORD=this_password_is_insecure_and_should_be_updated
|
||||
SECRET_KEY_BASE=UpNVntn3cDxHJpq99YMc1T1AQgQpc8kfYTuRgBiYa15BLrx8etQoXz3gZv1/u2oq
|
||||
VAULT_ENC_KEY=your-encryption-key-32-chars-min
|
||||
|
||||
|
||||
############
|
||||
# Database - You can change these to any PostgreSQL database that has logical replication enabled.
|
||||
############
|
||||
|
||||
POSTGRES_HOST=db
|
||||
POSTGRES_DB=postgres
|
||||
POSTGRES_PORT=5432
|
||||
# default user is postgres
|
||||
|
||||
|
||||
############
|
||||
# Supavisor -- Database pooler
|
||||
############
|
||||
POOLER_PROXY_PORT_TRANSACTION=6543
|
||||
POOLER_DEFAULT_POOL_SIZE=20
|
||||
POOLER_MAX_CLIENT_CONN=100
|
||||
POOLER_TENANT_ID=your-tenant-id
|
||||
|
||||
|
||||
############
|
||||
# API Proxy - Configuration for the Kong Reverse proxy.
|
||||
############
|
||||
|
||||
KONG_HTTP_PORT=8000
|
||||
KONG_HTTPS_PORT=8443
|
||||
|
||||
|
||||
############
|
||||
# API - Configuration for PostgREST.
|
||||
############
|
||||
|
||||
PGRST_DB_SCHEMAS=public,storage,graphql_public
|
||||
|
||||
|
||||
############
|
||||
# Auth - Configuration for the GoTrue authentication server.
|
||||
############
|
||||
|
||||
## General
|
||||
SITE_URL=http://localhost:3000
|
||||
ADDITIONAL_REDIRECT_URLS=
|
||||
JWT_EXPIRY=3600
|
||||
DISABLE_SIGNUP=false
|
||||
API_EXTERNAL_URL=http://localhost:8000
|
||||
|
||||
## Mailer Config
|
||||
MAILER_URLPATHS_CONFIRMATION="/auth/v1/verify"
|
||||
MAILER_URLPATHS_INVITE="/auth/v1/verify"
|
||||
MAILER_URLPATHS_RECOVERY="/auth/v1/verify"
|
||||
MAILER_URLPATHS_EMAIL_CHANGE="/auth/v1/verify"
|
||||
|
||||
## Email auth
|
||||
ENABLE_EMAIL_SIGNUP=true
|
||||
ENABLE_EMAIL_AUTOCONFIRM=false
|
||||
SMTP_ADMIN_EMAIL=admin@example.com
|
||||
SMTP_HOST=supabase-mail
|
||||
SMTP_PORT=2500
|
||||
SMTP_USER=fake_mail_user
|
||||
SMTP_PASS=fake_mail_password
|
||||
SMTP_SENDER_NAME=fake_sender
|
||||
ENABLE_ANONYMOUS_USERS=false
|
||||
|
||||
## Phone auth
|
||||
ENABLE_PHONE_SIGNUP=true
|
||||
ENABLE_PHONE_AUTOCONFIRM=true
|
||||
|
||||
|
||||
############
|
||||
# Studio - Configuration for the Dashboard
|
||||
############
|
||||
|
||||
STUDIO_DEFAULT_ORGANIZATION=Default Organization
|
||||
STUDIO_DEFAULT_PROJECT=Default Project
|
||||
|
||||
STUDIO_PORT=3000
|
||||
# replace if you intend to use Studio outside of localhost
|
||||
SUPABASE_PUBLIC_URL=http://localhost:8000
|
||||
|
||||
# Enable webp support
|
||||
IMGPROXY_ENABLE_WEBP_DETECTION=true
|
||||
|
||||
# Add your OpenAI API key to enable SQL Editor Assistant
|
||||
OPENAI_API_KEY=
|
||||
|
||||
|
||||
############
|
||||
# Functions - Configuration for Functions
|
||||
############
|
||||
# NOTE: VERIFY_JWT applies to all functions. Per-function VERIFY_JWT is not supported yet.
|
||||
FUNCTIONS_VERIFY_JWT=false
|
||||
|
||||
|
||||
############
|
||||
# Logs - Configuration for Logflare
|
||||
# Please refer to https://supabase.com/docs/reference/self-hosting-analytics/introduction
|
||||
############
|
||||
|
||||
LOGFLARE_LOGGER_BACKEND_API_KEY=your-super-secret-and-long-logflare-key
|
||||
|
||||
# Change vector.toml sinks to reflect this change
|
||||
LOGFLARE_API_KEY=your-super-secret-and-long-logflare-key
|
||||
|
||||
# Docker socket location - this value will differ depending on your OS
|
||||
DOCKER_SOCKET_LOCATION=/var/run/docker.sock
|
||||
|
||||
# Google Cloud Project details
|
||||
GOOGLE_PROJECT_ID=GOOGLE_PROJECT_ID
|
||||
GOOGLE_PROJECT_NUMBER=GOOGLE_PROJECT_NUMBER
|
||||
5
autogpt_platform/db/docker/.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
volumes/db/data
volumes/storage
.env
test.http
docker-compose.override.yml
3
autogpt_platform/db/docker/README.md
Normal file
@@ -0,0 +1,3 @@
# Supabase Docker

This is a minimal Docker Compose setup for self-hosting Supabase. Follow the steps [here](https://supabase.com/docs/guides/hosting/docker) to get started.
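For orientation, the new files below wire together roughly as follows. This is a sketch based only on the usage comments at the top of docker-compose.yml and on reset.sh in this diff; the working directory autogpt_platform/db/docker is an assumption taken from the file paths.

# assumed working directory: autogpt_platform/db/docker
cp .env.example .env                                                        # seed local settings from the template
docker compose up -d                                                        # start the core Supabase stack
docker compose -f docker-compose.yml -f ./dev/docker-compose.dev.yml up     # optional dev helpers (mail catcher, seeded DB)
docker compose down                                                         # stop the stack
./reset.sh                                                                  # wipe containers, volumes and the .env file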
48
autogpt_platform/db/docker/dev/data.sql
Normal file
@@ -0,0 +1,48 @@
create table profiles (
  id uuid references auth.users not null,
  updated_at timestamp with time zone,
  username text unique,
  avatar_url text,
  website text,

  primary key (id),
  unique(username),
  constraint username_length check (char_length(username) >= 3)
);

alter table profiles enable row level security;

create policy "Public profiles are viewable by the owner."
  on profiles for select
  using ( auth.uid() = id );

create policy "Users can insert their own profile."
  on profiles for insert
  with check ( auth.uid() = id );

create policy "Users can update own profile."
  on profiles for update
  using ( auth.uid() = id );

-- Set up Realtime
begin;
  drop publication if exists supabase_realtime;
  create publication supabase_realtime;
commit;
alter publication supabase_realtime add table profiles;

-- Set up Storage
insert into storage.buckets (id, name)
values ('avatars', 'avatars');

create policy "Avatar images are publicly accessible."
  on storage.objects for select
  using ( bucket_id = 'avatars' );

create policy "Anyone can upload an avatar."
  on storage.objects for insert
  with check ( bucket_id = 'avatars' );

create policy "Anyone can update an avatar."
  on storage.objects for update
  with check ( bucket_id = 'avatars' );
34
autogpt_platform/db/docker/dev/docker-compose.dev.yml
Normal file
@@ -0,0 +1,34 @@
version: "3.8"

services:
  studio:
    build:
      context: ..
      dockerfile: studio/Dockerfile
      target: dev
    ports:
      - 8082:8082
  mail:
    container_name: supabase-mail
    image: inbucket/inbucket:3.0.3
    ports:
      - '2500:2500' # SMTP
      - '9000:9000' # web interface
      - '1100:1100' # POP3
  auth:
    environment:
      - GOTRUE_SMTP_USER=
      - GOTRUE_SMTP_PASS=
  meta:
    ports:
      - 5555:8080
  db:
    restart: 'no'
    volumes:
      # Always use a fresh database when developing
      - /var/lib/postgresql/data
      # Seed data should be inserted last (alphabetical order)
      - ./dev/data.sql:/docker-entrypoint-initdb.d/seed.sql
  storage:
    volumes:
      - /var/lib/storage
94
autogpt_platform/db/docker/docker-compose.s3.yml
Normal file
94
autogpt_platform/db/docker/docker-compose.s3.yml
Normal file
@@ -0,0 +1,94 @@
|
||||
services:
|
||||
|
||||
minio:
|
||||
image: minio/minio
|
||||
ports:
|
||||
- '9000:9000'
|
||||
- '9001:9001'
|
||||
environment:
|
||||
MINIO_ROOT_USER: supa-storage
|
||||
MINIO_ROOT_PASSWORD: secret1234
|
||||
command: server --console-address ":9001" /data
|
||||
healthcheck:
|
||||
test: [ "CMD", "curl", "-f", "http://minio:9000/minio/health/live" ]
|
||||
interval: 2s
|
||||
timeout: 10s
|
||||
retries: 5
|
||||
volumes:
|
||||
- ./volumes/storage:/data:z
|
||||
|
||||
minio-createbucket:
|
||||
image: minio/mc
|
||||
depends_on:
|
||||
minio:
|
||||
condition: service_healthy
|
||||
entrypoint: >
|
||||
/bin/sh -c "
|
||||
/usr/bin/mc alias set supa-minio http://minio:9000 supa-storage secret1234;
|
||||
/usr/bin/mc mb supa-minio/stub;
|
||||
exit 0;
|
||||
"
|
||||
|
||||
storage:
|
||||
container_name: supabase-storage
|
||||
image: supabase/storage-api:v1.11.13
|
||||
depends_on:
|
||||
db:
|
||||
# Disable this if you are using an external Postgres database
|
||||
condition: service_healthy
|
||||
rest:
|
||||
condition: service_started
|
||||
imgproxy:
|
||||
condition: service_started
|
||||
minio:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test:
|
||||
[
|
||||
"CMD",
|
||||
"wget",
|
||||
"--no-verbose",
|
||||
"--tries=1",
|
||||
"--spider",
|
||||
"http://localhost:5000/status"
|
||||
]
|
||||
timeout: 5s
|
||||
interval: 5s
|
||||
retries: 3
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
ANON_KEY: ${ANON_KEY}
|
||||
SERVICE_KEY: ${SERVICE_ROLE_KEY}
|
||||
POSTGREST_URL: http://rest:3000
|
||||
PGRST_JWT_SECRET: ${JWT_SECRET}
|
||||
DATABASE_URL: postgres://supabase_storage_admin:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB}
|
||||
FILE_SIZE_LIMIT: 52428800
|
||||
STORAGE_BACKEND: s3
|
||||
GLOBAL_S3_BUCKET: stub
|
||||
GLOBAL_S3_ENDPOINT: http://minio:9000
|
||||
GLOBAL_S3_PROTOCOL: http
|
||||
GLOBAL_S3_FORCE_PATH_STYLE: true
|
||||
AWS_ACCESS_KEY_ID: supa-storage
|
||||
AWS_SECRET_ACCESS_KEY: secret1234
|
||||
AWS_DEFAULT_REGION: stub
|
||||
FILE_STORAGE_BACKEND_PATH: /var/lib/storage
|
||||
TENANT_ID: stub
|
||||
# TODO: https://github.com/supabase/storage-api/issues/55
|
||||
REGION: stub
|
||||
ENABLE_IMAGE_TRANSFORMATION: "true"
|
||||
IMGPROXY_URL: http://imgproxy:5001
|
||||
volumes:
|
||||
- ./volumes/storage:/var/lib/storage:z
|
||||
|
||||
imgproxy:
|
||||
container_name: supabase-imgproxy
|
||||
image: darthsim/imgproxy:v3.8.0
|
||||
healthcheck:
|
||||
test: [ "CMD", "imgproxy", "health" ]
|
||||
timeout: 5s
|
||||
interval: 5s
|
||||
retries: 3
|
||||
environment:
|
||||
IMGPROXY_BIND: ":5001"
|
||||
IMGPROXY_USE_ETAG: "true"
|
||||
IMGPROXY_ENABLE_WEBP_DETECTION: ${IMGPROXY_ENABLE_WEBP_DETECTION}
|
||||
526
autogpt_platform/db/docker/docker-compose.yml
Normal file
526
autogpt_platform/db/docker/docker-compose.yml
Normal file
@@ -0,0 +1,526 @@
|
||||
# Usage
|
||||
# Start: docker compose up
|
||||
# With helpers: docker compose -f docker-compose.yml -f ./dev/docker-compose.dev.yml up
|
||||
# Stop: docker compose down
|
||||
# Destroy: docker compose -f docker-compose.yml -f ./dev/docker-compose.dev.yml down -v --remove-orphans
|
||||
# Reset everything: ./reset.sh
|
||||
|
||||
name: supabase
|
||||
|
||||
services:
|
||||
|
||||
studio:
|
||||
container_name: supabase-studio
|
||||
image: supabase/studio:20250224-d10db0f
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test:
|
||||
[
|
||||
"CMD",
|
||||
"node",
|
||||
"-e",
|
||||
"fetch('http://studio:3000/api/platform/profile').then((r) => {if (r.status !== 200) throw new Error(r.status)})"
|
||||
]
|
||||
timeout: 10s
|
||||
interval: 5s
|
||||
retries: 3
|
||||
depends_on:
|
||||
analytics:
|
||||
condition: service_healthy
|
||||
environment:
|
||||
STUDIO_PG_META_URL: http://meta:8080
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
|
||||
|
||||
DEFAULT_ORGANIZATION_NAME: ${STUDIO_DEFAULT_ORGANIZATION}
|
||||
DEFAULT_PROJECT_NAME: ${STUDIO_DEFAULT_PROJECT}
|
||||
OPENAI_API_KEY: ${OPENAI_API_KEY:-}
|
||||
|
||||
SUPABASE_URL: http://kong:8000
|
||||
SUPABASE_PUBLIC_URL: ${SUPABASE_PUBLIC_URL}
|
||||
SUPABASE_ANON_KEY: ${ANON_KEY}
|
||||
SUPABASE_SERVICE_KEY: ${SERVICE_ROLE_KEY}
|
||||
AUTH_JWT_SECRET: ${JWT_SECRET}
|
||||
|
||||
LOGFLARE_API_KEY: ${LOGFLARE_API_KEY}
|
||||
LOGFLARE_URL: http://analytics:4000
|
||||
NEXT_PUBLIC_ENABLE_LOGS: true
|
||||
# Comment to use Big Query backend for analytics
|
||||
NEXT_ANALYTICS_BACKEND_PROVIDER: postgres
|
||||
# Uncomment to use Big Query backend for analytics
|
||||
# NEXT_ANALYTICS_BACKEND_PROVIDER: bigquery
|
||||
|
||||
kong:
|
||||
container_name: supabase-kong
|
||||
image: kong:2.8.1
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- ${KONG_HTTP_PORT}:8000/tcp
|
||||
- ${KONG_HTTPS_PORT}:8443/tcp
|
||||
volumes:
|
||||
# https://github.com/supabase/supabase/issues/12661
|
||||
- ./volumes/api/kong.yml:/home/kong/temp.yml:ro
|
||||
depends_on:
|
||||
analytics:
|
||||
condition: service_healthy
|
||||
environment:
|
||||
KONG_DATABASE: "off"
|
||||
KONG_DECLARATIVE_CONFIG: /home/kong/kong.yml
|
||||
# https://github.com/supabase/cli/issues/14
|
||||
KONG_DNS_ORDER: LAST,A,CNAME
|
||||
KONG_PLUGINS: request-transformer,cors,key-auth,acl,basic-auth
|
||||
KONG_NGINX_PROXY_PROXY_BUFFER_SIZE: 160k
|
||||
KONG_NGINX_PROXY_PROXY_BUFFERS: 64 160k
|
||||
SUPABASE_ANON_KEY: ${ANON_KEY}
|
||||
SUPABASE_SERVICE_KEY: ${SERVICE_ROLE_KEY}
|
||||
DASHBOARD_USERNAME: ${DASHBOARD_USERNAME}
|
||||
DASHBOARD_PASSWORD: ${DASHBOARD_PASSWORD}
|
||||
# https://unix.stackexchange.com/a/294837
|
||||
entrypoint: bash -c 'eval "echo \"$$(cat ~/temp.yml)\"" > ~/kong.yml && /docker-entrypoint.sh kong docker-start'
|
||||
|
||||
auth:
|
||||
container_name: supabase-auth
|
||||
image: supabase/gotrue:v2.170.0
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test:
|
||||
[
|
||||
"CMD",
|
||||
"wget",
|
||||
"--no-verbose",
|
||||
"--tries=1",
|
||||
"--spider",
|
||||
"http://localhost:9999/health"
|
||||
]
|
||||
timeout: 5s
|
||||
interval: 5s
|
||||
retries: 3
|
||||
depends_on:
|
||||
db:
|
||||
# Disable this if you are using an external Postgres database
|
||||
condition: service_healthy
|
||||
analytics:
|
||||
condition: service_healthy
|
||||
environment:
|
||||
GOTRUE_API_HOST: 0.0.0.0
|
||||
GOTRUE_API_PORT: 9999
|
||||
API_EXTERNAL_URL: ${API_EXTERNAL_URL}
|
||||
|
||||
GOTRUE_DB_DRIVER: postgres
|
||||
GOTRUE_DB_DATABASE_URL: postgres://supabase_auth_admin:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB}
|
||||
|
||||
GOTRUE_SITE_URL: ${SITE_URL}
|
||||
GOTRUE_URI_ALLOW_LIST: ${ADDITIONAL_REDIRECT_URLS}
|
||||
GOTRUE_DISABLE_SIGNUP: ${DISABLE_SIGNUP}
|
||||
|
||||
GOTRUE_JWT_ADMIN_ROLES: service_role
|
||||
GOTRUE_JWT_AUD: authenticated
|
||||
GOTRUE_JWT_DEFAULT_GROUP_NAME: authenticated
|
||||
GOTRUE_JWT_EXP: ${JWT_EXPIRY}
|
||||
GOTRUE_JWT_SECRET: ${JWT_SECRET}
|
||||
|
||||
GOTRUE_EXTERNAL_EMAIL_ENABLED: ${ENABLE_EMAIL_SIGNUP}
|
||||
GOTRUE_EXTERNAL_ANONYMOUS_USERS_ENABLED: ${ENABLE_ANONYMOUS_USERS}
|
||||
GOTRUE_MAILER_AUTOCONFIRM: ${ENABLE_EMAIL_AUTOCONFIRM}
|
||||
|
||||
# Uncomment to bypass nonce check in ID Token flow. Commonly set to true when using Google Sign In on mobile.
|
||||
# GOTRUE_EXTERNAL_SKIP_NONCE_CHECK: true
|
||||
|
||||
# GOTRUE_MAILER_SECURE_EMAIL_CHANGE_ENABLED: true
|
||||
# GOTRUE_SMTP_MAX_FREQUENCY: 1s
|
||||
GOTRUE_SMTP_ADMIN_EMAIL: ${SMTP_ADMIN_EMAIL}
|
||||
GOTRUE_SMTP_HOST: ${SMTP_HOST}
|
||||
GOTRUE_SMTP_PORT: ${SMTP_PORT}
|
||||
GOTRUE_SMTP_USER: ${SMTP_USER}
|
||||
GOTRUE_SMTP_PASS: ${SMTP_PASS}
|
||||
GOTRUE_SMTP_SENDER_NAME: ${SMTP_SENDER_NAME}
|
||||
GOTRUE_MAILER_URLPATHS_INVITE: ${MAILER_URLPATHS_INVITE}
|
||||
GOTRUE_MAILER_URLPATHS_CONFIRMATION: ${MAILER_URLPATHS_CONFIRMATION}
|
||||
GOTRUE_MAILER_URLPATHS_RECOVERY: ${MAILER_URLPATHS_RECOVERY}
|
||||
GOTRUE_MAILER_URLPATHS_EMAIL_CHANGE: ${MAILER_URLPATHS_EMAIL_CHANGE}
|
||||
|
||||
GOTRUE_EXTERNAL_PHONE_ENABLED: ${ENABLE_PHONE_SIGNUP}
|
||||
GOTRUE_SMS_AUTOCONFIRM: ${ENABLE_PHONE_AUTOCONFIRM}
|
||||
# Uncomment to enable custom access token hook. Please see: https://supabase.com/docs/guides/auth/auth-hooks for full list of hooks and additional details about custom_access_token_hook
|
||||
|
||||
# GOTRUE_HOOK_CUSTOM_ACCESS_TOKEN_ENABLED: "true"
|
||||
# GOTRUE_HOOK_CUSTOM_ACCESS_TOKEN_URI: "pg-functions://postgres/public/custom_access_token_hook"
|
||||
# GOTRUE_HOOK_CUSTOM_ACCESS_TOKEN_SECRETS: "<standard-base64-secret>"
|
||||
|
||||
# GOTRUE_HOOK_MFA_VERIFICATION_ATTEMPT_ENABLED: "true"
|
||||
# GOTRUE_HOOK_MFA_VERIFICATION_ATTEMPT_URI: "pg-functions://postgres/public/mfa_verification_attempt"
|
||||
|
||||
# GOTRUE_HOOK_PASSWORD_VERIFICATION_ATTEMPT_ENABLED: "true"
|
||||
# GOTRUE_HOOK_PASSWORD_VERIFICATION_ATTEMPT_URI: "pg-functions://postgres/public/password_verification_attempt"
|
||||
|
||||
# GOTRUE_HOOK_SEND_SMS_ENABLED: "false"
|
||||
# GOTRUE_HOOK_SEND_SMS_URI: "pg-functions://postgres/public/custom_access_token_hook"
|
||||
# GOTRUE_HOOK_SEND_SMS_SECRETS: "v1,whsec_VGhpcyBpcyBhbiBleGFtcGxlIG9mIGEgc2hvcnRlciBCYXNlNjQgc3RyaW5n"
|
||||
|
||||
# GOTRUE_HOOK_SEND_EMAIL_ENABLED: "false"
|
||||
# GOTRUE_HOOK_SEND_EMAIL_URI: "http://host.docker.internal:54321/functions/v1/email_sender"
|
||||
# GOTRUE_HOOK_SEND_EMAIL_SECRETS: "v1,whsec_VGhpcyBpcyBhbiBleGFtcGxlIG9mIGEgc2hvcnRlciBCYXNlNjQgc3RyaW5n"
|
||||
|
||||
rest:
|
||||
container_name: supabase-rest
|
||||
image: postgrest/postgrest:v12.2.8
|
||||
restart: unless-stopped
|
||||
depends_on:
|
||||
db:
|
||||
# Disable this if you are using an external Postgres database
|
||||
condition: service_healthy
|
||||
analytics:
|
||||
condition: service_healthy
|
||||
environment:
|
||||
PGRST_DB_URI: postgres://authenticator:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB}
|
||||
PGRST_DB_SCHEMAS: ${PGRST_DB_SCHEMAS}
|
||||
PGRST_DB_ANON_ROLE: anon
|
||||
PGRST_JWT_SECRET: ${JWT_SECRET}
|
||||
PGRST_DB_USE_LEGACY_GUCS: "false"
|
||||
PGRST_APP_SETTINGS_JWT_SECRET: ${JWT_SECRET}
|
||||
PGRST_APP_SETTINGS_JWT_EXP: ${JWT_EXPIRY}
|
||||
command:
|
||||
[
|
||||
"postgrest"
|
||||
]
|
||||
|
||||
realtime:
|
||||
# This container name looks inconsistent but is correct because realtime constructs tenant id by parsing the subdomain
|
||||
container_name: realtime-dev.supabase-realtime
|
||||
image: supabase/realtime:v2.34.40
|
||||
restart: unless-stopped
|
||||
depends_on:
|
||||
db:
|
||||
# Disable this if you are using an external Postgres database
|
||||
condition: service_healthy
|
||||
analytics:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test:
|
||||
[
|
||||
"CMD",
|
||||
"curl",
|
||||
"-sSfL",
|
||||
"--head",
|
||||
"-o",
|
||||
"/dev/null",
|
||||
"-H",
|
||||
"Authorization: Bearer ${ANON_KEY}",
|
||||
"http://localhost:4000/api/tenants/realtime-dev/health"
|
||||
]
|
||||
timeout: 5s
|
||||
interval: 5s
|
||||
retries: 3
|
||||
environment:
|
||||
PORT: 4000
|
||||
DB_HOST: ${POSTGRES_HOST}
|
||||
DB_PORT: ${POSTGRES_PORT}
|
||||
DB_USER: supabase_admin
|
||||
DB_PASSWORD: ${POSTGRES_PASSWORD}
|
||||
DB_NAME: ${POSTGRES_DB}
|
||||
DB_AFTER_CONNECT_QUERY: 'SET search_path TO _realtime'
|
||||
DB_ENC_KEY: supabaserealtime
|
||||
API_JWT_SECRET: ${JWT_SECRET}
|
||||
SECRET_KEY_BASE: ${SECRET_KEY_BASE}
|
||||
ERL_AFLAGS: -proto_dist inet_tcp
|
||||
DNS_NODES: "''"
|
||||
RLIMIT_NOFILE: "10000"
|
||||
APP_NAME: realtime
|
||||
SEED_SELF_HOST: true
|
||||
RUN_JANITOR: true
|
||||
|
||||
# To use S3 backed storage: docker compose -f docker-compose.yml -f docker-compose.s3.yml up
|
||||
storage:
|
||||
container_name: supabase-storage
|
||||
image: supabase/storage-api:v1.19.3
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- ./volumes/storage:/var/lib/storage:z
|
||||
healthcheck:
|
||||
test:
|
||||
[
|
||||
"CMD",
|
||||
"wget",
|
||||
"--no-verbose",
|
||||
"--tries=1",
|
||||
"--spider",
|
||||
"http://storage:5000/status"
|
||||
]
|
||||
timeout: 5s
|
||||
interval: 5s
|
||||
retries: 3
|
||||
depends_on:
|
||||
db:
|
||||
# Disable this if you are using an external Postgres database
|
||||
condition: service_healthy
|
||||
rest:
|
||||
condition: service_started
|
||||
imgproxy:
|
||||
condition: service_started
|
||||
environment:
|
||||
ANON_KEY: ${ANON_KEY}
|
||||
SERVICE_KEY: ${SERVICE_ROLE_KEY}
|
||||
POSTGREST_URL: http://rest:3000
|
||||
PGRST_JWT_SECRET: ${JWT_SECRET}
|
||||
DATABASE_URL: postgres://supabase_storage_admin:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB}
|
||||
FILE_SIZE_LIMIT: 52428800
|
||||
STORAGE_BACKEND: file
|
||||
FILE_STORAGE_BACKEND_PATH: /var/lib/storage
|
||||
TENANT_ID: stub
|
||||
# TODO: https://github.com/supabase/storage-api/issues/55
|
||||
REGION: stub
|
||||
GLOBAL_S3_BUCKET: stub
|
||||
ENABLE_IMAGE_TRANSFORMATION: "true"
|
||||
IMGPROXY_URL: http://imgproxy:5001
|
||||
|
||||
imgproxy:
|
||||
container_name: supabase-imgproxy
|
||||
image: darthsim/imgproxy:v3.8.0
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- ./volumes/storage:/var/lib/storage:z
|
||||
healthcheck:
|
||||
test:
|
||||
[
|
||||
"CMD",
|
||||
"imgproxy",
|
||||
"health"
|
||||
]
|
||||
timeout: 5s
|
||||
interval: 5s
|
||||
retries: 3
|
||||
environment:
|
||||
IMGPROXY_BIND: ":5001"
|
||||
IMGPROXY_LOCAL_FILESYSTEM_ROOT: /
|
||||
IMGPROXY_USE_ETAG: "true"
|
||||
IMGPROXY_ENABLE_WEBP_DETECTION: ${IMGPROXY_ENABLE_WEBP_DETECTION}
|
||||
|
||||
meta:
|
||||
container_name: supabase-meta
|
||||
image: supabase/postgres-meta:v0.86.1
|
||||
restart: unless-stopped
|
||||
depends_on:
|
||||
db:
|
||||
# Disable this if you are using an external Postgres database
|
||||
condition: service_healthy
|
||||
analytics:
|
||||
condition: service_healthy
|
||||
environment:
|
||||
PG_META_PORT: 8080
|
||||
PG_META_DB_HOST: ${POSTGRES_HOST}
|
||||
PG_META_DB_PORT: ${POSTGRES_PORT}
|
||||
PG_META_DB_NAME: ${POSTGRES_DB}
|
||||
PG_META_DB_USER: supabase_admin
|
||||
PG_META_DB_PASSWORD: ${POSTGRES_PASSWORD}
|
||||
|
||||
functions:
|
||||
container_name: supabase-edge-functions
|
||||
image: supabase/edge-runtime:v1.67.2
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- ./volumes/functions:/home/deno/functions:Z
|
||||
depends_on:
|
||||
analytics:
|
||||
condition: service_healthy
|
||||
environment:
|
||||
JWT_SECRET: ${JWT_SECRET}
|
||||
SUPABASE_URL: http://kong:8000
|
||||
SUPABASE_ANON_KEY: ${ANON_KEY}
|
||||
SUPABASE_SERVICE_ROLE_KEY: ${SERVICE_ROLE_KEY}
|
||||
SUPABASE_DB_URL: postgresql://postgres:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB}
|
||||
# TODO: Allow configuring VERIFY_JWT per function. This PR might help: https://github.com/supabase/cli/pull/786
|
||||
VERIFY_JWT: "${FUNCTIONS_VERIFY_JWT}"
|
||||
command:
|
||||
[
|
||||
"start",
|
||||
"--main-service",
|
||||
"/home/deno/functions/main"
|
||||
]
|
||||
|
||||
analytics:
|
||||
container_name: supabase-analytics
|
||||
image: supabase/logflare:1.12.5
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- 4000:4000
|
||||
# Uncomment to use Big Query backend for analytics
|
||||
# volumes:
|
||||
# - type: bind
|
||||
# source: ${PWD}/gcloud.json
|
||||
# target: /opt/app/rel/logflare/bin/gcloud.json
|
||||
# read_only: true
|
||||
healthcheck:
|
||||
test:
|
||||
[
|
||||
"CMD",
|
||||
"curl",
|
||||
"http://localhost:4000/health"
|
||||
]
|
||||
timeout: 5s
|
||||
interval: 5s
|
||||
retries: 10
|
||||
depends_on:
|
||||
db:
|
||||
# Disable this if you are using an external Postgres database
|
||||
condition: service_healthy
|
||||
environment:
|
||||
LOGFLARE_NODE_HOST: 127.0.0.1
|
||||
DB_USERNAME: supabase_admin
|
||||
DB_DATABASE: _supabase
|
||||
DB_HOSTNAME: ${POSTGRES_HOST}
|
||||
DB_PORT: ${POSTGRES_PORT}
|
||||
DB_PASSWORD: ${POSTGRES_PASSWORD}
|
||||
DB_SCHEMA: _analytics
|
||||
LOGFLARE_API_KEY: ${LOGFLARE_API_KEY}
|
||||
LOGFLARE_SINGLE_TENANT: true
|
||||
LOGFLARE_SUPABASE_MODE: true
|
||||
LOGFLARE_MIN_CLUSTER_SIZE: 1
|
||||
|
||||
# Comment variables to use Big Query backend for analytics
|
||||
POSTGRES_BACKEND_URL: postgresql://supabase_admin:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/_supabase
|
||||
POSTGRES_BACKEND_SCHEMA: _analytics
|
||||
LOGFLARE_FEATURE_FLAG_OVERRIDE: multibackend=true
|
||||
# Uncomment to use Big Query backend for analytics
|
||||
# GOOGLE_PROJECT_ID: ${GOOGLE_PROJECT_ID}
|
||||
# GOOGLE_PROJECT_NUMBER: ${GOOGLE_PROJECT_NUMBER}
|
||||
|
||||
# Comment out everything below this point if you are using an external Postgres database
|
||||
db:
|
||||
container_name: supabase-db
|
||||
image: supabase/postgres:15.8.1.049
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- ./volumes/db/realtime.sql:/docker-entrypoint-initdb.d/migrations/99-realtime.sql:Z
|
||||
# Must be superuser to create event trigger
|
||||
- ./volumes/db/webhooks.sql:/docker-entrypoint-initdb.d/init-scripts/98-webhooks.sql:Z
|
||||
# Must be superuser to alter reserved role
|
||||
- ./volumes/db/roles.sql:/docker-entrypoint-initdb.d/init-scripts/99-roles.sql:Z
|
||||
# Initialize the database settings with JWT_SECRET and JWT_EXP
|
||||
- ./volumes/db/jwt.sql:/docker-entrypoint-initdb.d/init-scripts/99-jwt.sql:Z
|
||||
# PGDATA directory is persisted between restarts
|
||||
- ./volumes/db/data:/var/lib/postgresql/data:Z
|
||||
# Changes required for internal supabase data such as _analytics
|
||||
- ./volumes/db/_supabase.sql:/docker-entrypoint-initdb.d/migrations/97-_supabase.sql:Z
|
||||
# Changes required for Analytics support
|
||||
- ./volumes/db/logs.sql:/docker-entrypoint-initdb.d/migrations/99-logs.sql:Z
|
||||
# Changes required for Pooler support
|
||||
- ./volumes/db/pooler.sql:/docker-entrypoint-initdb.d/migrations/99-pooler.sql:Z
|
||||
# Use named volume to persist pgsodium decryption key between restarts
|
||||
- supabase-config:/etc/postgresql-custom
|
||||
healthcheck:
|
||||
test:
|
||||
[
|
||||
"CMD",
|
||||
"pg_isready",
|
||||
"-U",
|
||||
"postgres",
|
||||
"-h",
|
||||
"localhost"
|
||||
]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
depends_on:
|
||||
vector:
|
||||
condition: service_healthy
|
||||
environment:
|
||||
POSTGRES_HOST: /var/run/postgresql
|
||||
PGPORT: ${POSTGRES_PORT}
|
||||
POSTGRES_PORT: ${POSTGRES_PORT}
|
||||
PGPASSWORD: ${POSTGRES_PASSWORD}
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
|
||||
PGDATABASE: ${POSTGRES_DB}
|
||||
POSTGRES_DB: ${POSTGRES_DB}
|
||||
JWT_SECRET: ${JWT_SECRET}
|
||||
JWT_EXP: ${JWT_EXPIRY}
|
||||
command:
|
||||
[
|
||||
"postgres",
|
||||
"-c",
|
||||
"config_file=/etc/postgresql/postgresql.conf",
|
||||
"-c",
|
||||
"log_min_messages=fatal" # prevents Realtime polling queries from appearing in logs
|
||||
]
|
||||
|
||||
vector:
|
||||
container_name: supabase-vector
|
||||
image: timberio/vector:0.28.1-alpine
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- ./volumes/logs/vector.yml:/etc/vector/vector.yml:ro
|
||||
- ${DOCKER_SOCKET_LOCATION}:/var/run/docker.sock:ro
|
||||
healthcheck:
|
||||
test:
|
||||
[
|
||||
"CMD",
|
||||
"wget",
|
||||
"--no-verbose",
|
||||
"--tries=1",
|
||||
"--spider",
|
||||
"http://vector:9001/health"
|
||||
]
|
||||
timeout: 5s
|
||||
interval: 5s
|
||||
retries: 3
|
||||
environment:
|
||||
LOGFLARE_API_KEY: ${LOGFLARE_API_KEY}
|
||||
command:
|
||||
[
|
||||
"--config",
|
||||
"/etc/vector/vector.yml"
|
||||
]
|
||||
|
||||
# Update the DATABASE_URL if you are using an external Postgres database
|
||||
supavisor:
|
||||
container_name: supabase-pooler
|
||||
image: supabase/supavisor:2.4.12
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- ${POSTGRES_PORT}:5432
|
||||
- ${POOLER_PROXY_PORT_TRANSACTION}:6543
|
||||
volumes:
|
||||
- ./volumes/pooler/pooler.exs:/etc/pooler/pooler.exs:ro
|
||||
healthcheck:
|
||||
test:
|
||||
[
|
||||
"CMD",
|
||||
"curl",
|
||||
"-sSfL",
|
||||
"--head",
|
||||
"-o",
|
||||
"/dev/null",
|
||||
"http://127.0.0.1:4000/api/health"
|
||||
]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
analytics:
|
||||
condition: service_healthy
|
||||
environment:
|
||||
PORT: 4000
|
||||
POSTGRES_PORT: ${POSTGRES_PORT}
|
||||
POSTGRES_DB: ${POSTGRES_DB}
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
|
||||
DATABASE_URL: ecto://supabase_admin:${POSTGRES_PASSWORD}@db:${POSTGRES_PORT}/_supabase
|
||||
CLUSTER_POSTGRES: true
|
||||
SECRET_KEY_BASE: ${SECRET_KEY_BASE}
|
||||
VAULT_ENC_KEY: ${VAULT_ENC_KEY}
|
||||
API_JWT_SECRET: ${JWT_SECRET}
|
||||
METRICS_JWT_SECRET: ${JWT_SECRET}
|
||||
REGION: local
|
||||
ERL_AFLAGS: -proto_dist inet_tcp
|
||||
POOLER_TENANT_ID: ${POOLER_TENANT_ID}
|
||||
POOLER_DEFAULT_POOL_SIZE: ${POOLER_DEFAULT_POOL_SIZE}
|
||||
POOLER_MAX_CLIENT_CONN: ${POOLER_MAX_CLIENT_CONN}
|
||||
POOLER_POOL_MODE: transaction
|
||||
command:
|
||||
[
|
||||
"/bin/sh",
|
||||
"-c",
|
||||
"/app/bin/migrate && /app/bin/supavisor eval \"$$(cat /etc/pooler/pooler.exs)\" && /app/bin/server"
|
||||
]
|
||||
|
||||
volumes:
|
||||
supabase-config:
|
||||
44
autogpt_platform/db/docker/reset.sh
Executable file
@@ -0,0 +1,44 @@
#!/bin/bash

echo "WARNING: This will remove all containers and container data, and will reset the .env file. This action cannot be undone!"
read -p "Are you sure you want to proceed? (y/N) " -n 1 -r
echo # Move to a new line
if [[ ! $REPLY =~ ^[Yy]$ ]]
then
    echo "Operation cancelled."
    exit 1
fi

echo "Stopping and removing all containers..."
docker compose -f docker-compose.yml -f ./dev/docker-compose.dev.yml down -v --remove-orphans

echo "Cleaning up bind-mounted directories..."
BIND_MOUNTS=(
    "./volumes/db/data"
)

for DIR in "${BIND_MOUNTS[@]}"; do
    if [ -d "$DIR" ]; then
        echo "Deleting $DIR..."
        rm -rf "$DIR"
    else
        echo "Directory $DIR does not exist. Skipping bind mount deletion step..."
    fi
done

echo "Resetting .env file..."
if [ -f ".env" ]; then
    echo "Removing existing .env file..."
    rm -f .env
else
    echo "No .env file found. Skipping .env removal step..."
fi

if [ -f ".env.example" ]; then
    echo "Copying .env.example to .env..."
    cp .env.example .env
else
    echo ".env.example file not found. Skipping .env reset step..."
fi

echo "Cleanup complete!"
241
autogpt_platform/db/docker/volumes/api/kong.yml
Normal file
241
autogpt_platform/db/docker/volumes/api/kong.yml
Normal file
@@ -0,0 +1,241 @@
|
||||
_format_version: '2.1'
|
||||
_transform: true
|
||||
|
||||
###
|
||||
### Consumers / Users
|
||||
###
|
||||
consumers:
|
||||
- username: DASHBOARD
|
||||
- username: anon
|
||||
keyauth_credentials:
|
||||
- key: $SUPABASE_ANON_KEY
|
||||
- username: service_role
|
||||
keyauth_credentials:
|
||||
- key: $SUPABASE_SERVICE_KEY
|
||||
|
||||
###
|
||||
### Access Control List
|
||||
###
|
||||
acls:
|
||||
- consumer: anon
|
||||
group: anon
|
||||
- consumer: service_role
|
||||
group: admin
|
||||
|
||||
###
|
||||
### Dashboard credentials
|
||||
###
|
||||
basicauth_credentials:
|
||||
- consumer: DASHBOARD
|
||||
username: $DASHBOARD_USERNAME
|
||||
password: $DASHBOARD_PASSWORD
|
||||
|
||||
###
|
||||
### API Routes
|
||||
###
|
||||
services:
|
||||
## Open Auth routes
|
||||
- name: auth-v1-open
|
||||
url: http://auth:9999/verify
|
||||
routes:
|
||||
- name: auth-v1-open
|
||||
strip_path: true
|
||||
paths:
|
||||
- /auth/v1/verify
|
||||
plugins:
|
||||
- name: cors
|
||||
- name: auth-v1-open-callback
|
||||
url: http://auth:9999/callback
|
||||
routes:
|
||||
- name: auth-v1-open-callback
|
||||
strip_path: true
|
||||
paths:
|
||||
- /auth/v1/callback
|
||||
plugins:
|
||||
- name: cors
|
||||
- name: auth-v1-open-authorize
|
||||
url: http://auth:9999/authorize
|
||||
routes:
|
||||
- name: auth-v1-open-authorize
|
||||
strip_path: true
|
||||
paths:
|
||||
- /auth/v1/authorize
|
||||
plugins:
|
||||
- name: cors
|
||||
|
||||
## Secure Auth routes
|
||||
- name: auth-v1
|
||||
_comment: 'GoTrue: /auth/v1/* -> http://auth:9999/*'
|
||||
url: http://auth:9999/
|
||||
routes:
|
||||
- name: auth-v1-all
|
||||
strip_path: true
|
||||
paths:
|
||||
- /auth/v1/
|
||||
plugins:
|
||||
- name: cors
|
||||
- name: key-auth
|
||||
config:
|
||||
hide_credentials: false
|
||||
- name: acl
|
||||
config:
|
||||
hide_groups_header: true
|
||||
allow:
|
||||
- admin
|
||||
- anon
|
||||
|
||||
## Secure REST routes
|
||||
- name: rest-v1
|
||||
_comment: 'PostgREST: /rest/v1/* -> http://rest:3000/*'
|
||||
url: http://rest:3000/
|
||||
routes:
|
||||
- name: rest-v1-all
|
||||
strip_path: true
|
||||
paths:
|
||||
- /rest/v1/
|
||||
plugins:
|
||||
- name: cors
|
||||
- name: key-auth
|
||||
config:
|
||||
hide_credentials: true
|
||||
- name: acl
|
||||
config:
|
||||
hide_groups_header: true
|
||||
allow:
|
||||
- admin
|
||||
- anon
|
||||
|
||||
## Secure GraphQL routes
|
||||
- name: graphql-v1
|
||||
_comment: 'PostgREST: /graphql/v1/* -> http://rest:3000/rpc/graphql'
|
||||
url: http://rest:3000/rpc/graphql
|
||||
routes:
|
||||
- name: graphql-v1-all
|
||||
strip_path: true
|
||||
paths:
|
||||
- /graphql/v1
|
||||
plugins:
|
||||
- name: cors
|
||||
- name: key-auth
|
||||
config:
|
||||
hide_credentials: true
|
||||
- name: request-transformer
|
||||
config:
|
||||
add:
|
||||
headers:
|
||||
- Content-Profile:graphql_public
|
||||
- name: acl
|
||||
config:
|
||||
hide_groups_header: true
|
||||
allow:
|
||||
- admin
|
||||
- anon
|
||||
|
||||
## Secure Realtime routes
|
||||
- name: realtime-v1-ws
|
||||
_comment: 'Realtime: /realtime/v1/* -> ws://realtime:4000/socket/*'
|
||||
url: http://realtime-dev.supabase-realtime:4000/socket
|
||||
protocol: ws
|
||||
routes:
|
||||
- name: realtime-v1-ws
|
||||
strip_path: true
|
||||
paths:
|
||||
- /realtime/v1/
|
||||
plugins:
|
||||
- name: cors
|
||||
- name: key-auth
|
||||
config:
|
||||
hide_credentials: false
|
||||
- name: acl
|
||||
config:
|
||||
hide_groups_header: true
|
||||
allow:
|
||||
- admin
|
||||
- anon
|
||||
- name: realtime-v1-rest
|
||||
_comment: 'Realtime: /realtime/v1/* -> ws://realtime:4000/socket/*'
|
||||
url: http://realtime-dev.supabase-realtime:4000/api
|
||||
protocol: http
|
||||
routes:
|
||||
- name: realtime-v1-rest
|
||||
strip_path: true
|
||||
paths:
|
||||
- /realtime/v1/api
|
||||
plugins:
|
||||
- name: cors
|
||||
- name: key-auth
|
||||
config:
|
||||
hide_credentials: false
|
||||
- name: acl
|
||||
config:
|
||||
hide_groups_header: true
|
||||
allow:
|
||||
- admin
|
||||
- anon
|
||||
## Storage routes: the storage server manages its own auth
|
||||
- name: storage-v1
|
||||
_comment: 'Storage: /storage/v1/* -> http://storage:5000/*'
|
||||
url: http://storage:5000/
|
||||
routes:
|
||||
- name: storage-v1-all
|
||||
strip_path: true
|
||||
paths:
|
||||
- /storage/v1/
|
||||
plugins:
|
||||
- name: cors
|
||||
|
||||
## Edge Functions routes
|
||||
- name: functions-v1
|
||||
_comment: 'Edge Functions: /functions/v1/* -> http://functions:9000/*'
|
||||
url: http://functions:9000/
|
||||
routes:
|
||||
- name: functions-v1-all
|
||||
strip_path: true
|
||||
paths:
|
||||
- /functions/v1/
|
||||
plugins:
|
||||
- name: cors
|
||||
|
||||
## Analytics routes
|
||||
- name: analytics-v1
|
||||
_comment: 'Analytics: /analytics/v1/* -> http://logflare:4000/*'
|
||||
url: http://analytics:4000/
|
||||
routes:
|
||||
- name: analytics-v1-all
|
||||
strip_path: true
|
||||
paths:
|
||||
- /analytics/v1/
|
||||
|
||||
## Secure Database routes
|
||||
- name: meta
|
||||
_comment: 'pg-meta: /pg/* -> http://pg-meta:8080/*'
|
||||
url: http://meta:8080/
|
||||
routes:
|
||||
- name: meta-all
|
||||
strip_path: true
|
||||
paths:
|
||||
- /pg/
|
||||
plugins:
|
||||
- name: key-auth
|
||||
config:
|
||||
hide_credentials: false
|
||||
- name: acl
|
||||
config:
|
||||
hide_groups_header: true
|
||||
allow:
|
||||
- admin
|
||||
|
||||
## Protected Dashboard - catch all remaining routes
|
||||
- name: dashboard
|
||||
_comment: 'Studio: /* -> http://studio:3000/*'
|
||||
url: http://studio:3000/
|
||||
routes:
|
||||
- name: dashboard-all
|
||||
strip_path: true
|
||||
paths:
|
||||
- /
|
||||
plugins:
|
||||
- name: cors
|
||||
- name: basic-auth
|
||||
config:
|
||||
hide_credentials: true
|
||||
3
autogpt_platform/db/docker/volumes/db/_supabase.sql
Normal file
@@ -0,0 +1,3 @@
\set pguser `echo "$POSTGRES_USER"`

CREATE DATABASE _supabase WITH OWNER :pguser;
0
autogpt_platform/db/docker/volumes/db/init/data.sql
Executable file
5
autogpt_platform/db/docker/volumes/db/jwt.sql
Normal file
@@ -0,0 +1,5 @@
\set jwt_secret `echo "$JWT_SECRET"`
\set jwt_exp `echo "$JWT_EXP"`

ALTER DATABASE postgres SET "app.settings.jwt_secret" TO :'jwt_secret';
ALTER DATABASE postgres SET "app.settings.jwt_exp" TO :'jwt_exp';
6
autogpt_platform/db/docker/volumes/db/logs.sql
Normal file
@@ -0,0 +1,6 @@
\set pguser `echo "$POSTGRES_USER"`

\c _supabase
create schema if not exists _analytics;
alter schema _analytics owner to :pguser;
\c postgres
6
autogpt_platform/db/docker/volumes/db/pooler.sql
Normal file
@@ -0,0 +1,6 @@
\set pguser `echo "$POSTGRES_USER"`

\c _supabase
create schema if not exists _supavisor;
alter schema _supavisor owner to :pguser;
\c postgres
4
autogpt_platform/db/docker/volumes/db/realtime.sql
Normal file
@@ -0,0 +1,4 @@
\set pguser `echo "$POSTGRES_USER"`

create schema if not exists _realtime;
alter schema _realtime owner to :pguser;
8
autogpt_platform/db/docker/volumes/db/roles.sql
Normal file
@@ -0,0 +1,8 @@
-- NOTE: change to your own passwords for production environments
\set pgpass `echo "$POSTGRES_PASSWORD"`

ALTER USER authenticator WITH PASSWORD :'pgpass';
ALTER USER pgbouncer WITH PASSWORD :'pgpass';
ALTER USER supabase_auth_admin WITH PASSWORD :'pgpass';
ALTER USER supabase_functions_admin WITH PASSWORD :'pgpass';
ALTER USER supabase_storage_admin WITH PASSWORD :'pgpass';
208
autogpt_platform/db/docker/volumes/db/webhooks.sql
Normal file
208
autogpt_platform/db/docker/volumes/db/webhooks.sql
Normal file
@@ -0,0 +1,208 @@
|
||||
BEGIN;
|
||||
-- Create pg_net extension
|
||||
CREATE EXTENSION IF NOT EXISTS pg_net SCHEMA extensions;
|
||||
-- Create supabase_functions schema
|
||||
CREATE SCHEMA supabase_functions AUTHORIZATION supabase_admin;
|
||||
GRANT USAGE ON SCHEMA supabase_functions TO postgres, anon, authenticated, service_role;
|
||||
ALTER DEFAULT PRIVILEGES IN SCHEMA supabase_functions GRANT ALL ON TABLES TO postgres, anon, authenticated, service_role;
|
||||
ALTER DEFAULT PRIVILEGES IN SCHEMA supabase_functions GRANT ALL ON FUNCTIONS TO postgres, anon, authenticated, service_role;
|
||||
ALTER DEFAULT PRIVILEGES IN SCHEMA supabase_functions GRANT ALL ON SEQUENCES TO postgres, anon, authenticated, service_role;
|
||||
-- supabase_functions.migrations definition
|
||||
CREATE TABLE supabase_functions.migrations (
|
||||
version text PRIMARY KEY,
|
||||
inserted_at timestamptz NOT NULL DEFAULT NOW()
|
||||
);
|
||||
-- Initial supabase_functions migration
|
||||
INSERT INTO supabase_functions.migrations (version) VALUES ('initial');
|
||||
-- supabase_functions.hooks definition
|
||||
CREATE TABLE supabase_functions.hooks (
|
||||
id bigserial PRIMARY KEY,
|
||||
hook_table_id integer NOT NULL,
|
||||
hook_name text NOT NULL,
|
||||
created_at timestamptz NOT NULL DEFAULT NOW(),
|
||||
request_id bigint
|
||||
);
|
||||
CREATE INDEX supabase_functions_hooks_request_id_idx ON supabase_functions.hooks USING btree (request_id);
|
||||
CREATE INDEX supabase_functions_hooks_h_table_id_h_name_idx ON supabase_functions.hooks USING btree (hook_table_id, hook_name);
|
||||
COMMENT ON TABLE supabase_functions.hooks IS 'Supabase Functions Hooks: Audit trail for triggered hooks.';
|
||||
CREATE FUNCTION supabase_functions.http_request()
|
||||
RETURNS trigger
|
||||
LANGUAGE plpgsql
|
||||
AS $function$
|
||||
DECLARE
|
||||
request_id bigint;
|
||||
payload jsonb;
|
||||
url text := TG_ARGV[0]::text;
|
||||
method text := TG_ARGV[1]::text;
|
||||
headers jsonb DEFAULT '{}'::jsonb;
|
||||
params jsonb DEFAULT '{}'::jsonb;
|
||||
timeout_ms integer DEFAULT 1000;
|
||||
BEGIN
|
||||
IF url IS NULL OR url = 'null' THEN
|
||||
RAISE EXCEPTION 'url argument is missing';
|
||||
END IF;
|
||||
|
||||
IF method IS NULL OR method = 'null' THEN
|
||||
RAISE EXCEPTION 'method argument is missing';
|
||||
END IF;
|
||||
|
||||
IF TG_ARGV[2] IS NULL OR TG_ARGV[2] = 'null' THEN
|
||||
headers = '{"Content-Type": "application/json"}'::jsonb;
|
||||
ELSE
|
||||
headers = TG_ARGV[2]::jsonb;
|
||||
END IF;
|
||||
|
||||
IF TG_ARGV[3] IS NULL OR TG_ARGV[3] = 'null' THEN
|
||||
params = '{}'::jsonb;
|
||||
ELSE
|
||||
params = TG_ARGV[3]::jsonb;
|
||||
END IF;
|
||||
|
||||
IF TG_ARGV[4] IS NULL OR TG_ARGV[4] = 'null' THEN
|
||||
timeout_ms = 1000;
|
||||
ELSE
|
||||
timeout_ms = TG_ARGV[4]::integer;
|
||||
END IF;
|
||||
|
||||
CASE
|
||||
WHEN method = 'GET' THEN
|
||||
SELECT http_get INTO request_id FROM net.http_get(
|
||||
url,
|
||||
params,
|
||||
headers,
|
||||
timeout_ms
|
||||
);
|
||||
WHEN method = 'POST' THEN
|
||||
payload = jsonb_build_object(
|
||||
'old_record', OLD,
|
||||
'record', NEW,
|
||||
'type', TG_OP,
|
||||
'table', TG_TABLE_NAME,
|
||||
'schema', TG_TABLE_SCHEMA
|
||||
);
|
||||
|
||||
SELECT http_post INTO request_id FROM net.http_post(
|
||||
url,
|
||||
payload,
|
||||
params,
|
||||
headers,
|
||||
timeout_ms
|
||||
);
|
||||
ELSE
|
||||
RAISE EXCEPTION 'method argument % is invalid', method;
|
||||
END CASE;
|
||||
|
||||
INSERT INTO supabase_functions.hooks
|
||||
(hook_table_id, hook_name, request_id)
|
||||
VALUES
|
||||
(TG_RELID, TG_NAME, request_id);
|
||||
|
||||
RETURN NEW;
|
||||
END
|
||||
$function$;
|
||||
-- Supabase super admin
|
||||
DO
|
||||
$$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_roles
|
||||
WHERE rolname = 'supabase_functions_admin'
|
||||
)
|
||||
THEN
|
||||
CREATE USER supabase_functions_admin NOINHERIT CREATEROLE LOGIN NOREPLICATION;
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
GRANT ALL PRIVILEGES ON SCHEMA supabase_functions TO supabase_functions_admin;
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA supabase_functions TO supabase_functions_admin;
|
||||
GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA supabase_functions TO supabase_functions_admin;
|
||||
ALTER USER supabase_functions_admin SET search_path = "supabase_functions";
|
||||
ALTER table "supabase_functions".migrations OWNER TO supabase_functions_admin;
|
||||
ALTER table "supabase_functions".hooks OWNER TO supabase_functions_admin;
|
||||
ALTER function "supabase_functions".http_request() OWNER TO supabase_functions_admin;
|
||||
GRANT supabase_functions_admin TO postgres;
|
||||
-- Remove unused supabase_pg_net_admin role
|
||||
DO
|
||||
$$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_roles
|
||||
WHERE rolname = 'supabase_pg_net_admin'
|
||||
)
|
||||
THEN
|
||||
REASSIGN OWNED BY supabase_pg_net_admin TO supabase_admin;
|
||||
DROP OWNED BY supabase_pg_net_admin;
|
||||
DROP ROLE supabase_pg_net_admin;
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
-- pg_net grants when extension is already enabled
|
||||
DO
|
||||
$$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_extension
|
||||
WHERE extname = 'pg_net'
|
||||
)
|
||||
THEN
|
||||
GRANT USAGE ON SCHEMA net TO supabase_functions_admin, postgres, anon, authenticated, service_role;
|
||||
ALTER function net.http_get(url text, params jsonb, headers jsonb, timeout_milliseconds integer) SECURITY DEFINER;
|
||||
ALTER function net.http_post(url text, body jsonb, params jsonb, headers jsonb, timeout_milliseconds integer) SECURITY DEFINER;
|
||||
ALTER function net.http_get(url text, params jsonb, headers jsonb, timeout_milliseconds integer) SET search_path = net;
|
||||
ALTER function net.http_post(url text, body jsonb, params jsonb, headers jsonb, timeout_milliseconds integer) SET search_path = net;
|
||||
REVOKE ALL ON FUNCTION net.http_get(url text, params jsonb, headers jsonb, timeout_milliseconds integer) FROM PUBLIC;
|
||||
REVOKE ALL ON FUNCTION net.http_post(url text, body jsonb, params jsonb, headers jsonb, timeout_milliseconds integer) FROM PUBLIC;
|
||||
GRANT EXECUTE ON FUNCTION net.http_get(url text, params jsonb, headers jsonb, timeout_milliseconds integer) TO supabase_functions_admin, postgres, anon, authenticated, service_role;
|
||||
GRANT EXECUTE ON FUNCTION net.http_post(url text, body jsonb, params jsonb, headers jsonb, timeout_milliseconds integer) TO supabase_functions_admin, postgres, anon, authenticated, service_role;
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
-- Event trigger for pg_net
|
||||
CREATE OR REPLACE FUNCTION extensions.grant_pg_net_access()
|
||||
RETURNS event_trigger
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_event_trigger_ddl_commands() AS ev
|
||||
JOIN pg_extension AS ext
|
||||
ON ev.objid = ext.oid
|
||||
WHERE ext.extname = 'pg_net'
|
||||
)
|
||||
THEN
|
||||
GRANT USAGE ON SCHEMA net TO supabase_functions_admin, postgres, anon, authenticated, service_role;
|
||||
ALTER function net.http_get(url text, params jsonb, headers jsonb, timeout_milliseconds integer) SECURITY DEFINER;
|
||||
ALTER function net.http_post(url text, body jsonb, params jsonb, headers jsonb, timeout_milliseconds integer) SECURITY DEFINER;
|
||||
ALTER function net.http_get(url text, params jsonb, headers jsonb, timeout_milliseconds integer) SET search_path = net;
|
||||
ALTER function net.http_post(url text, body jsonb, params jsonb, headers jsonb, timeout_milliseconds integer) SET search_path = net;
|
||||
REVOKE ALL ON FUNCTION net.http_get(url text, params jsonb, headers jsonb, timeout_milliseconds integer) FROM PUBLIC;
|
||||
REVOKE ALL ON FUNCTION net.http_post(url text, body jsonb, params jsonb, headers jsonb, timeout_milliseconds integer) FROM PUBLIC;
|
||||
GRANT EXECUTE ON FUNCTION net.http_get(url text, params jsonb, headers jsonb, timeout_milliseconds integer) TO supabase_functions_admin, postgres, anon, authenticated, service_role;
|
||||
GRANT EXECUTE ON FUNCTION net.http_post(url text, body jsonb, params jsonb, headers jsonb, timeout_milliseconds integer) TO supabase_functions_admin, postgres, anon, authenticated, service_role;
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
COMMENT ON FUNCTION extensions.grant_pg_net_access IS 'Grants access to pg_net';
|
||||
DO
|
||||
$$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_event_trigger
|
||||
WHERE evtname = 'issue_pg_net_access'
|
||||
) THEN
|
||||
CREATE EVENT TRIGGER issue_pg_net_access ON ddl_command_end WHEN TAG IN ('CREATE EXTENSION')
|
||||
EXECUTE PROCEDURE extensions.grant_pg_net_access();
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
INSERT INTO supabase_functions.migrations (version) VALUES ('20210809183423_update_grants');
|
||||
ALTER function supabase_functions.http_request() SECURITY DEFINER;
|
||||
ALTER function supabase_functions.http_request() SET search_path = supabase_functions;
|
||||
REVOKE ALL ON FUNCTION supabase_functions.http_request() FROM PUBLIC;
|
||||
GRANT EXECUTE ON FUNCTION supabase_functions.http_request() TO postgres, anon, authenticated, service_role;
|
||||
COMMIT;
|
||||
16
autogpt_platform/db/docker/volumes/functions/hello/index.ts
Normal file
@@ -0,0 +1,16 @@
// Follow this setup guide to integrate the Deno language server with your editor:
// https://deno.land/manual/getting_started/setup_your_environment
// This enables autocomplete, go to definition, etc.

import { serve } from "https://deno.land/std@0.177.1/http/server.ts"

serve(async () => {
  return new Response(
    `"Hello from Edge Functions!"`,
    { headers: { "Content-Type": "application/json" } },
  )
})

// To invoke:
// curl 'http://localhost:<KONG_HTTP_PORT>/functions/v1/hello' \
// --header 'Authorization: Bearer <anon/service_role API key>'
94
autogpt_platform/db/docker/volumes/functions/main/index.ts
Normal file
@@ -0,0 +1,94 @@
import { serve } from 'https://deno.land/std@0.131.0/http/server.ts'
import * as jose from 'https://deno.land/x/jose@v4.14.4/index.ts'

console.log('main function started')

const JWT_SECRET = Deno.env.get('JWT_SECRET')
const VERIFY_JWT = Deno.env.get('VERIFY_JWT') === 'true'

function getAuthToken(req: Request) {
  const authHeader = req.headers.get('authorization')
  if (!authHeader) {
    throw new Error('Missing authorization header')
  }
  const [bearer, token] = authHeader.split(' ')
  if (bearer !== 'Bearer') {
    throw new Error(`Auth header is not 'Bearer {token}'`)
  }
  return token
}

async function verifyJWT(jwt: string): Promise<boolean> {
  const encoder = new TextEncoder()
  const secretKey = encoder.encode(JWT_SECRET)
  try {
    await jose.jwtVerify(jwt, secretKey)
  } catch (err) {
    console.error(err)
    return false
  }
  return true
}

serve(async (req: Request) => {
  if (req.method !== 'OPTIONS' && VERIFY_JWT) {
    try {
      const token = getAuthToken(req)
      const isValidJWT = await verifyJWT(token)

      if (!isValidJWT) {
        return new Response(JSON.stringify({ msg: 'Invalid JWT' }), {
          status: 401,
          headers: { 'Content-Type': 'application/json' },
        })
      }
    } catch (e) {
      console.error(e)
      return new Response(JSON.stringify({ msg: e.toString() }), {
        status: 401,
        headers: { 'Content-Type': 'application/json' },
      })
    }
  }

  const url = new URL(req.url)
  const { pathname } = url
  const path_parts = pathname.split('/')
  const service_name = path_parts[1]

  if (!service_name || service_name === '') {
    const error = { msg: 'missing function name in request' }
    return new Response(JSON.stringify(error), {
      status: 400,
      headers: { 'Content-Type': 'application/json' },
    })
  }

  const servicePath = `/home/deno/functions/${service_name}`
  console.error(`serving the request with ${servicePath}`)

  const memoryLimitMb = 150
  const workerTimeoutMs = 1 * 60 * 1000
  const noModuleCache = false
  const importMapPath = null
  const envVarsObj = Deno.env.toObject()
  const envVars = Object.keys(envVarsObj).map((k) => [k, envVarsObj[k]])

  try {
    const worker = await EdgeRuntime.userWorkers.create({
      servicePath,
      memoryLimitMb,
      workerTimeoutMs,
      noModuleCache,
      importMapPath,
      envVars,
    })
    return await worker.fetch(req)
  } catch (e) {
    const error = { msg: e.toString() }
    return new Response(JSON.stringify(error), {
      status: 500,
      headers: { 'Content-Type': 'application/json' },
    })
  }
})
232
autogpt_platform/db/docker/volumes/logs/vector.yml
Normal file
232
autogpt_platform/db/docker/volumes/logs/vector.yml
Normal file
@@ -0,0 +1,232 @@
|
||||
api:
|
||||
enabled: true
|
||||
address: 0.0.0.0:9001
|
||||
|
||||
sources:
|
||||
docker_host:
|
||||
type: docker_logs
|
||||
exclude_containers:
|
||||
- supabase-vector
|
||||
|
||||
transforms:
|
||||
project_logs:
|
||||
type: remap
|
||||
inputs:
|
||||
- docker_host
|
||||
source: |-
|
||||
.project = "default"
|
||||
.event_message = del(.message)
|
||||
.appname = del(.container_name)
|
||||
del(.container_created_at)
|
||||
del(.container_id)
|
||||
del(.source_type)
|
||||
del(.stream)
|
||||
del(.label)
|
||||
del(.image)
|
||||
del(.host)
|
||||
del(.stream)
|
||||
router:
|
||||
type: route
|
||||
inputs:
|
||||
- project_logs
|
||||
route:
|
||||
kong: '.appname == "supabase-kong"'
|
||||
auth: '.appname == "supabase-auth"'
|
||||
rest: '.appname == "supabase-rest"'
|
||||
realtime: '.appname == "supabase-realtime"'
|
||||
      storage: '.appname == "supabase-storage"'
      functions: '.appname == "supabase-functions"'
      db: '.appname == "supabase-db"'
  # Ignores non nginx errors since they are related with kong booting up
  kong_logs:
    type: remap
    inputs:
      - router.kong
    source: |-
      req, err = parse_nginx_log(.event_message, "combined")
      if err == null {
          .timestamp = req.timestamp
          .metadata.request.headers.referer = req.referer
          .metadata.request.headers.user_agent = req.agent
          .metadata.request.headers.cf_connecting_ip = req.client
          .metadata.request.method = req.method
          .metadata.request.path = req.path
          .metadata.request.protocol = req.protocol
          .metadata.response.status_code = req.status
      }
      if err != null {
        abort
      }
  # Ignores non nginx errors since they are related with kong booting up
  kong_err:
    type: remap
    inputs:
      - router.kong
    source: |-
      .metadata.request.method = "GET"
      .metadata.response.status_code = 200
      parsed, err = parse_nginx_log(.event_message, "error")
      if err == null {
          .timestamp = parsed.timestamp
          .severity = parsed.severity
          .metadata.request.host = parsed.host
          .metadata.request.headers.cf_connecting_ip = parsed.client
          url, err = split(parsed.request, " ")
          if err == null {
            .metadata.request.method = url[0]
            .metadata.request.path = url[1]
            .metadata.request.protocol = url[2]
          }
      }
      if err != null {
        abort
      }
  # Gotrue logs are structured json strings which frontend parses directly. But we keep metadata for consistency.
  auth_logs:
    type: remap
    inputs:
      - router.auth
    source: |-
      parsed, err = parse_json(.event_message)
      if err == null {
          .metadata.timestamp = parsed.time
          .metadata = merge!(.metadata, parsed)
      }
  # PostgREST logs are structured so we separate timestamp from message using regex
  rest_logs:
    type: remap
    inputs:
      - router.rest
    source: |-
      parsed, err = parse_regex(.event_message, r'^(?P<time>.*): (?P<msg>.*)$')
      if err == null {
          .event_message = parsed.msg
          .timestamp = to_timestamp!(parsed.time)
          .metadata.host = .project
      }
  # Realtime logs are structured so we parse the severity level using regex (ignore time because it has no date)
  realtime_logs:
    type: remap
    inputs:
      - router.realtime
    source: |-
      .metadata.project = del(.project)
      .metadata.external_id = .metadata.project
      parsed, err = parse_regex(.event_message, r'^(?P<time>\d+:\d+:\d+\.\d+) \[(?P<level>\w+)\] (?P<msg>.*)$')
      if err == null {
          .event_message = parsed.msg
          .metadata.level = parsed.level
      }
  # Storage logs may contain json objects so we parse them for completeness
  storage_logs:
    type: remap
    inputs:
      - router.storage
    source: |-
      .metadata.project = del(.project)
      .metadata.tenantId = .metadata.project
      parsed, err = parse_json(.event_message)
      if err == null {
          .event_message = parsed.msg
          .metadata.level = parsed.level
          .metadata.timestamp = parsed.time
          .metadata.context[0].host = parsed.hostname
          .metadata.context[0].pid = parsed.pid
      }
  # Postgres logs some messages to stderr which we map to warning severity level
  db_logs:
    type: remap
    inputs:
      - router.db
    source: |-
      .metadata.host = "db-default"
      .metadata.parsed.timestamp = .timestamp

      parsed, err = parse_regex(.event_message, r'.*(?P<level>INFO|NOTICE|WARNING|ERROR|LOG|FATAL|PANIC?):.*', numeric_groups: true)

      if err != null || parsed == null {
        .metadata.parsed.error_severity = "info"
      }
      if parsed != null {
        .metadata.parsed.error_severity = parsed.level
      }
      if .metadata.parsed.error_severity == "info" {
        .metadata.parsed.error_severity = "log"
      }
      .metadata.parsed.error_severity = upcase!(.metadata.parsed.error_severity)

sinks:
  logflare_auth:
    type: 'http'
    inputs:
      - auth_logs
    encoding:
      codec: 'json'
    method: 'post'
    request:
      retry_max_duration_secs: 10
    uri: 'http://analytics:4000/api/logs?source_name=gotrue.logs.prod&api_key=${LOGFLARE_API_KEY?LOGFLARE_API_KEY is required}'
  logflare_realtime:
    type: 'http'
    inputs:
      - realtime_logs
    encoding:
      codec: 'json'
    method: 'post'
    request:
      retry_max_duration_secs: 10
    uri: 'http://analytics:4000/api/logs?source_name=realtime.logs.prod&api_key=${LOGFLARE_API_KEY?LOGFLARE_API_KEY is required}'
  logflare_rest:
    type: 'http'
    inputs:
      - rest_logs
    encoding:
      codec: 'json'
    method: 'post'
    request:
      retry_max_duration_secs: 10
    uri: 'http://analytics:4000/api/logs?source_name=postgREST.logs.prod&api_key=${LOGFLARE_API_KEY?LOGFLARE_API_KEY is required}'
  logflare_db:
    type: 'http'
    inputs:
      - db_logs
    encoding:
      codec: 'json'
    method: 'post'
    request:
      retry_max_duration_secs: 10
    # We must route the sink through kong because ingesting logs before logflare is fully initialised will
    # lead to broken queries from studio. This works by the assumption that containers are started in the
    # following order: vector > db > logflare > kong
    uri: 'http://kong:8000/analytics/v1/api/logs?source_name=postgres.logs&api_key=${LOGFLARE_API_KEY?LOGFLARE_API_KEY is required}'
  logflare_functions:
    type: 'http'
    inputs:
      - router.functions
    encoding:
      codec: 'json'
    method: 'post'
    request:
      retry_max_duration_secs: 10
    uri: 'http://analytics:4000/api/logs?source_name=deno-relay-logs&api_key=${LOGFLARE_API_KEY?LOGFLARE_API_KEY is required}'
  logflare_storage:
    type: 'http'
    inputs:
      - storage_logs
    encoding:
      codec: 'json'
    method: 'post'
    request:
      retry_max_duration_secs: 10
    uri: 'http://analytics:4000/api/logs?source_name=storage.logs.prod.2&api_key=${LOGFLARE_API_KEY?LOGFLARE_API_KEY is required}'
  logflare_kong:
    type: 'http'
    inputs:
      - kong_logs
      - kong_err
    encoding:
      codec: 'json'
    method: 'post'
    request:
      retry_max_duration_secs: 10
    uri: 'http://analytics:4000/api/logs?source_name=cloudflare.logs.prod&api_key=${LOGFLARE_API_KEY?LOGFLARE_API_KEY is required}'
30
autogpt_platform/db/docker/volumes/pooler/pooler.exs
Normal file
@@ -0,0 +1,30 @@
{:ok, _} = Application.ensure_all_started(:supavisor)

{:ok, version} =
  case Supavisor.Repo.query!("select version()") do
    %{rows: [[ver]]} -> Supavisor.Helpers.parse_pg_version(ver)
    _ -> nil
  end

params = %{
  "external_id" => System.get_env("POOLER_TENANT_ID"),
  "db_host" => "db",
  "db_port" => System.get_env("POSTGRES_PORT"),
  "db_database" => System.get_env("POSTGRES_DB"),
  "require_user" => false,
  "auth_query" => "SELECT * FROM pgbouncer.get_auth($1)",
  "default_max_clients" => System.get_env("POOLER_MAX_CLIENT_CONN"),
  "default_pool_size" => System.get_env("POOLER_DEFAULT_POOL_SIZE"),
  "default_parameter_status" => %{"server_version" => version},
  "users" => [%{
    "db_user" => "pgbouncer",
    "db_password" => System.get_env("POSTGRES_PASSWORD"),
    "mode_type" => System.get_env("POOLER_POOL_MODE"),
    "pool_size" => System.get_env("POOLER_DEFAULT_POOL_SIZE"),
    "is_manager" => true
  }]
}

if !Supavisor.Tenants.get_tenant_by_external_id(params["external_id"]) do
  {:ok, _} = Supavisor.Tenants.create_tenant(params)
end
@@ -121,6 +121,7 @@ services:
      migrate:
        condition: service_completed_successfully
    environment:
      - DATABASEMANAGER_HOST=rest_server
      - SUPABASE_URL=http://kong:8000
      - SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
      - SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
@@ -163,6 +164,7 @@ services:
      migrate:
        condition: service_completed_successfully
    environment:
      - DATABASEMANAGER_HOST=rest_server
      - SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
      - DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
      - REDIS_HOST=redis
@@ -5,7 +5,7 @@ networks:
    name: shared-network

volumes:
  db-config:
  supabase-config:

x-agpt-services:
  &agpt-services
@@ -67,19 +67,19 @@ services:
  studio:
    <<: *supabase-services
    extends:
      file: ./supabase/docker/docker-compose.yml
      file: ./db/docker/docker-compose.yml
      service: studio

  kong:
    <<: *supabase-services
    extends:
      file: ./supabase/docker/docker-compose.yml
      file: ./db/docker/docker-compose.yml
      service: kong

  auth:
    <<: *supabase-services
    extends:
      file: ./supabase/docker/docker-compose.yml
      file: ./db/docker/docker-compose.yml
      service: auth
    environment:
      GOTRUE_MAILER_AUTOCONFIRM: true
@@ -87,54 +87,57 @@ services:
  rest:
    <<: *supabase-services
    extends:
      file: ./supabase/docker/docker-compose.yml
      file: ./db/docker/docker-compose.yml
      service: rest

  realtime:
    <<: *supabase-services
    extends:
      file: ./supabase/docker/docker-compose.yml
      file: ./db/docker/docker-compose.yml
      service: realtime

  storage:
    <<: *supabase-services
    extends:
      file: ./supabase/docker/docker-compose.yml
      file: ./db/docker/docker-compose.yml
      service: storage

  imgproxy:
    <<: *supabase-services
    extends:
      file: ./supabase/docker/docker-compose.yml
      file: ./db/docker/docker-compose.yml
      service: imgproxy

  meta:
    <<: *supabase-services
    extends:
      file: ./supabase/docker/docker-compose.yml
      file: ./db/docker/docker-compose.yml
      service: meta

  functions:
    <<: *supabase-services
    extends:
      file: ./supabase/docker/docker-compose.yml
      file: ./db/docker/docker-compose.yml
      service: functions

  analytics:
    <<: *supabase-services
    extends:
      file: ./supabase/docker/docker-compose.yml
      file: ./db/docker/docker-compose.yml
      service: analytics

  db:
    <<: *supabase-services
    extends:
      file: ./supabase/docker/docker-compose.yml
      file: ./db/docker/docker-compose.yml
      service: db
    ports:
      - ${POSTGRES_PORT}:5432 # We don't use Supavisor locally, so we expose the db directly.

  vector:
    <<: *supabase-services
    extends:
      file: ./supabase/docker/docker-compose.yml
      file: ./db/docker/docker-compose.yml
      service: vector

  deps:
@@ -23,9 +23,9 @@
    "defaults"
  ],
  "dependencies": {
    "@faker-js/faker": "^9.4.0",
    "@faker-js/faker": "^9.6.0",
    "@hookform/resolvers": "^3.10.0",
    "@next/third-parties": "^15.1.6",
    "@next/third-parties": "^15.2.1",
    "@radix-ui/react-alert-dialog": "^1.1.5",
    "@radix-ui/react-avatar": "^1.1.1",
    "@radix-ui/react-checkbox": "^1.1.2",
@@ -46,9 +46,9 @@
    "@radix-ui/react-tooltip": "^1.1.7",
    "@sentry/nextjs": "^8",
    "@supabase/ssr": "^0.5.2",
    "@supabase/supabase-js": "^2.48.1",
    "@tanstack/react-table": "^8.20.6",
    "@xyflow/react": "^12.4.2",
    "@supabase/supabase-js": "^2.49.1",
    "@tanstack/react-table": "^8.21.2",
    "@xyflow/react": "12.4.2",
    "ajv": "^8.17.1",
    "boring-avatars": "^1.11.2",
    "canvas-confetti": "^1.9.3",
@@ -60,28 +60,28 @@
    "dotenv": "^16.4.7",
    "elliptic": "6.6.1",
    "embla-carousel-react": "^8.5.2",
    "framer-motion": "^12.0.11",
    "framer-motion": "^12.4.11",
    "geist": "^1.3.1",
    "launchdarkly-react-client-sdk": "^3.6.1",
    "lodash.debounce": "^4.0.8",
    "lucide-react": "^0.474.0",
    "lucide-react": "^0.479.0",
    "moment": "^2.30.1",
    "next": "^14.2.21",
    "next-themes": "^0.4.4",
    "next": "^14.2.25",
    "next-themes": "^0.4.5",
    "react": "^18",
    "react-day-picker": "^9.5.1",
    "react-day-picker": "^9.6.1",
    "react-dom": "^18",
    "react-drag-drop-files": "^2.4.0",
    "react-hook-form": "^7.54.0",
    "react-icons": "^5.4.0",
    "react-icons": "^5.5.0",
    "react-markdown": "^9.0.3",
    "react-modal": "^3.16.3",
    "react-shepherd": "^6.1.7",
    "react-shepherd": "^6.1.8",
    "recharts": "^2.15.1",
    "tailwind-merge": "^2.6.0",
    "tailwindcss-animate": "^1.0.7",
    "uuid": "^11.0.5",
    "zod": "^3.23.8"
    "uuid": "^11.1.0",
    "zod": "^3.24.2"
  },
  "devDependencies": {
    "@chromatic-com/storybook": "^3.2.4",
@@ -1,15 +1,29 @@
|
||||
"use client";
|
||||
import { ShoppingBag } from "lucide-react";
|
||||
import { Sidebar } from "@/components/agptui/Sidebar";
|
||||
import { Users, DollarSign, LogOut } from "lucide-react";
|
||||
|
||||
import { useState } from "react";
|
||||
import Link from "next/link";
|
||||
import { BinaryIcon, XIcon } from "lucide-react";
|
||||
import { usePathname } from "next/navigation"; // Add this import
|
||||
import { IconSliders } from "@/components/ui/icons";
|
||||
|
||||
const tabs = [
|
||||
{ name: "Dashboard", href: "/admin/dashboard" },
|
||||
{ name: "Marketplace", href: "/admin/marketplace" },
|
||||
{ name: "Users", href: "/admin/users" },
|
||||
{ name: "Settings", href: "/admin/settings" },
|
||||
const sidebarLinkGroups = [
|
||||
{
|
||||
links: [
|
||||
{
|
||||
text: "Marketplace Management",
|
||||
href: "/admin/marketplace",
|
||||
icon: <Users className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "User Spending",
|
||||
href: "/admin/spending",
|
||||
icon: <DollarSign className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "Admin User Management",
|
||||
href: "/admin/settings",
|
||||
icon: <IconSliders className="h-6 w-6" />,
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
export default function AdminLayout({
|
||||
@@ -17,84 +31,10 @@ export default function AdminLayout({
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
const pathname = usePathname(); // Get the current pathname
|
||||
const [activeTab, setActiveTab] = useState(() => {
|
||||
// Set active tab based on the current route
|
||||
return tabs.find((tab) => tab.href === pathname)?.name || tabs[0].name;
|
||||
});
|
||||
const [mobileMenuOpen, setMobileMenuOpen] = useState(false);
|
||||
|
||||
return (
|
||||
<div className="min-h-screen bg-gray-100">
|
||||
<nav className="bg-white shadow-sm">
|
||||
<div className="max-w-10xl mx-auto px-4 sm:px-6 lg:px-8">
|
||||
<div className="flex h-16 items-center justify-between">
|
||||
<div className="flex items-center">
|
||||
<div className="flex-shrink-0">
|
||||
<h1 className="text-xl font-bold">Admin Panel</h1>
|
||||
</div>
|
||||
<div className="hidden sm:ml-6 sm:flex sm:space-x-8">
|
||||
{tabs.map((tab) => (
|
||||
<Link
|
||||
key={tab.name}
|
||||
href={tab.href}
|
||||
className={`${
|
||||
activeTab === tab.name
|
||||
? "border-indigo-500 text-indigo-600"
|
||||
: "border-transparent text-gray-500 hover:border-gray-300 hover:text-gray-700"
|
||||
} inline-flex items-center border-b-2 px-1 pt-1 text-sm font-medium`}
|
||||
onClick={() => setActiveTab(tab.name)}
|
||||
>
|
||||
{tab.name}
|
||||
</Link>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
<div className="sm:hidden">
|
||||
<button
|
||||
type="button"
|
||||
className="inline-flex items-center justify-center rounded-md p-2 text-gray-400 hover:bg-gray-100 hover:text-gray-500 focus:outline-none focus:ring-2 focus:ring-inset focus:ring-indigo-500"
|
||||
onClick={() => setMobileMenuOpen(!mobileMenuOpen)}
|
||||
>
|
||||
<span className="sr-only">Open main menu</span>
|
||||
{mobileMenuOpen ? (
|
||||
<XIcon className="block h-6 w-6" aria-hidden="true" />
|
||||
) : (
|
||||
<BinaryIcon className="block h-6 w-6" aria-hidden="true" />
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{mobileMenuOpen && (
|
||||
<div className="sm:hidden">
|
||||
<div className="space-y-1 pb-3 pt-2">
|
||||
{tabs.map((tab) => (
|
||||
<Link
|
||||
key={tab.name}
|
||||
href={tab.href}
|
||||
className={`${
|
||||
activeTab === tab.name
|
||||
? "border-indigo-500 bg-indigo-50 text-indigo-700"
|
||||
: "border-transparent text-gray-600 hover:border-gray-300 hover:bg-gray-50 hover:text-gray-800"
|
||||
} block border-l-4 py-2 pl-3 pr-4 text-base font-medium`}
|
||||
onClick={() => {
|
||||
setActiveTab(tab.name);
|
||||
setMobileMenuOpen(false);
|
||||
}}
|
||||
>
|
||||
{tab.name}
|
||||
</Link>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</nav>
|
||||
|
||||
<main className="py-10">
|
||||
<div className="mx-auto max-w-7xl px-4 sm:px-6 lg:px-8">{children}</div>
|
||||
</main>
|
||||
<div className="flex min-h-screen w-screen flex-col lg:flex-row">
|
||||
<Sidebar linkGroups={sidebarLinkGroups} />
|
||||
<div className="flex-1 pl-4">{children}</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,58 @@
"use server";

import { revalidatePath } from "next/cache";
import BackendApi from "@/lib/autogpt-server-api";
import {
  NotificationPreferenceDTO,
  StoreListingsWithVersionsResponse,
  StoreSubmissionsResponse,
  SubmissionStatus,
} from "@/lib/autogpt-server-api/types";

export async function approveAgent(formData: FormData) {
  const data = {
    store_listing_version_id: formData.get("id") as string,
    is_approved: true,
    comments: formData.get("comments") as string,
  };
  const api = new BackendApi();
  await api.reviewSubmissionAdmin(data.store_listing_version_id, data);

  revalidatePath("/admin/marketplace");
}

export async function rejectAgent(formData: FormData) {
  const data = {
    store_listing_version_id: formData.get("id") as string,
    is_approved: false,
    comments: formData.get("comments") as string,
    internal_comments: formData.get("internal_comments") as string,
  };
  const api = new BackendApi();
  await api.reviewSubmissionAdmin(data.store_listing_version_id, data);

  revalidatePath("/admin/marketplace");
}

export async function getAdminListingsWithVersions(
  status?: SubmissionStatus,
  search?: string,
  page: number = 1,
  pageSize: number = 20,
): Promise<StoreListingsWithVersionsResponse> {
  const data: Record<string, any> = {
    page,
    page_size: pageSize,
  };

  if (status) {
    data.status = status;
  }

  if (search) {
    data.search = search;
  }
  const api = new BackendApi();
  const response = await api.getAdminListingsWithVersions(data);
  return response;
}
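A minimal sketch of how these server actions could be consumed from an admin page (the page component, the response fields, and the "PENDING" status value are illustrative assumptions, not part of this diff):

// Hypothetical admin page: lists submissions and approves one via a form action.
import { approveAgent, getAdminListingsWithVersions } from "./actions";

export default async function PendingSubmissionsPage() {
  // First page of listings, filtered to an assumed "PENDING" status value.
  const response = await getAdminListingsWithVersions(
    "PENDING" as any,
    undefined,
    1,
    20,
  );

  return (
    <ul>
      {(response as any).listings?.map((listing: any) => (
        <li key={listing.id}>
          {/* approveAgent reads "id" and "comments" from the submitted FormData */}
          <form action={approveAgent}>
            <input type="hidden" name="id" value={listing.active_version_id} />
            <input name="comments" placeholder="Review comments" />
            <button type="submit">Approve</button>
          </form>
        </li>
      ))}
    </ul>
  );
}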
@@ -1,25 +1,62 @@
|
||||
import { withRoleAccess } from "@/lib/withRoleAccess";
|
||||
import { Suspense } from "react";
|
||||
import type { SubmissionStatus } from "@/lib/autogpt-server-api/types";
|
||||
import { AdminAgentsDataTable } from "@/components/admin/marketplace/admin-agents-data-table";
|
||||
|
||||
import React from "react";
|
||||
// import { getReviewableAgents } from "@/components/admin/marketplace/actions";
|
||||
// import AdminMarketplaceAgentList from "@/components/admin/marketplace/AdminMarketplaceAgentList";
|
||||
// import AdminFeaturedAgentsControl from "@/components/admin/marketplace/AdminFeaturedAgentsControl";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
async function AdminMarketplace() {
|
||||
// const reviewableAgents = await getReviewableAgents();
|
||||
async function AdminMarketplaceDashboard({
|
||||
searchParams,
|
||||
}: {
|
||||
searchParams: {
|
||||
page?: string;
|
||||
status?: string;
|
||||
search?: string;
|
||||
};
|
||||
}) {
|
||||
const page = searchParams.page ? Number.parseInt(searchParams.page) : 1;
|
||||
const status = searchParams.status as SubmissionStatus | undefined;
|
||||
const search = searchParams.search;
|
||||
|
||||
return (
|
||||
<>
|
||||
{/* <AdminMarketplaceAgentList agents={reviewableAgents.items} />
|
||||
<Separator className="my-4" />
|
||||
<AdminFeaturedAgentsControl className="mt-4" /> */}
|
||||
</>
|
||||
<div className="mx-auto p-6">
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex items-center justify-between">
|
||||
<div>
|
||||
<h1 className="text-3xl font-bold">Marketplace Management</h1>
|
||||
<p className="text-gray-500">
|
||||
Unified view for marketplace management and approval history
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Suspense
|
||||
fallback={
|
||||
<div className="py-10 text-center">Loading submissions...</div>
|
||||
}
|
||||
>
|
||||
<AdminAgentsDataTable
|
||||
initialPage={page}
|
||||
initialStatus={status}
|
||||
initialSearch={search}
|
||||
/>
|
||||
</Suspense>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default async function AdminDashboardPage() {
|
||||
export default async function AdminMarketplacePage({
|
||||
searchParams,
|
||||
}: {
|
||||
searchParams: {
|
||||
page?: string;
|
||||
status?: string;
|
||||
search?: string;
|
||||
};
|
||||
}) {
|
||||
"use server";
|
||||
const withAdminAccess = await withRoleAccess(["admin"]);
|
||||
const ProtectedAdminMarketplace = await withAdminAccess(AdminMarketplace);
|
||||
return <ProtectedAdminMarketplace />;
|
||||
const ProtectedAdminMarketplace = await withAdminAccess(
|
||||
AdminMarketplaceDashboard,
|
||||
);
|
||||
return <ProtectedAdminMarketplace searchParams={searchParams} />;
|
||||
}
|
||||
|
||||
@@ -24,10 +24,14 @@ export async function askOtto(

  try {
    const response = await api.askOtto(ottoQuery);
    revalidatePath("/build");
    return response;
  } catch (error) {
    console.error("Error in askOtto server action:", error);
    throw error;
    return {
      answer: error instanceof Error ? error.message : "Unknown error occurred",
      documents: [],
      success: false,
      error: true,
    };
  }
}
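Because askOtto now resolves with an error payload instead of throwing, callers branch on the returned object rather than wrapping the call in try/catch. A rough caller-side sketch (arguments are abridged, and showAssistantMessage stands in for whatever UI update the caller performs):

// Hypothetical consumer of the new askOtto error contract.
const data = await askOtto(userMessage, conversationHistory);
if ("error" in data && data.error === true) {
  // On failure, data.answer carries the error description (e.g. "Authentication required").
  showAssistantMessage(`Something went wrong: ${data.answer}`);
} else {
  showAssistantMessage(data.answer);
}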
@@ -2,6 +2,7 @@
import React, { useCallback, useEffect, useMemo, useState } from "react";
import { useParams, useRouter } from "next/navigation";

import { exportAsJSONFile } from "@/lib/utils";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import {
  GraphExecution,
@@ -191,19 +192,40 @@ export default function AgentRunsPage(): React.ReactElement {
    [schedules, api],
  );

  const downloadGraph = useCallback(
    async () =>
      agent &&
      // Export sanitized graph from backend
      api
        .getGraph(agent.agent_id, agent.agent_version, true)
        .then((graph) =>
          exportAsJSONFile(graph, `${graph.name}_v${graph.version}.json`),
        ),
    [api, agent],
  );

  const agentActions: ButtonAction[] = useMemo(
    () => [
      {
        label: "Open in builder",
        callback: () => agent && router.push(`/build?flowID=${agent.agent_id}`),
      },
      ...(agent?.can_access_graph
        ? [
            {
              label: "Open in builder",
              callback: () =>
                agent &&
                router.push(
                  `/build?flowID=${agent.agent_id}&flowVersion=${agent.agent_version}`,
                ),
            },
            { label: "Export agent to file", callback: downloadGraph },
          ]
        : []),
      {
        label: "Delete agent",
        variant: "destructive",
        callback: () => setAgentDeleteDialogOpen(true),
      },
    ],
    [agent, router],
    [agent, router, downloadGraph],
  );

  if (!agent || !graph) {
@@ -29,7 +29,6 @@ export default function CreditsPage() {
    formatCredits,
    refundTopUp,
    refundRequests,
    fetchRefundRequests,
  } = useCredits({
    fetchInitialAutoTopUpConfig: true,
    fetchInitialRefundRequests: true,
@@ -1,17 +1,48 @@
import * as React from "react";
import { Sidebar } from "@/components/agptui/Sidebar";
import {
  IconDashboardLayout,
  IconIntegrations,
  IconProfile,
  IconSliders,
  IconCoin,
} from "@/components/ui/icons";
import { KeyIcon } from "lucide-react";

export default function Layout({ children }: { children: React.ReactNode }) {
  const sidebarLinkGroups = [
    {
      links: [
        { text: "Creator Dashboard", href: "/profile/dashboard" },
        { text: "Agent dashboard", href: "/profile/agent-dashboard" },
        { text: "Billing", href: "/profile/credits" },
        { text: "Integrations", href: "/profile/integrations" },
        { text: "API Keys", href: "/profile/api_keys" },
        { text: "Profile", href: "/profile" },
        { text: "Settings", href: "/profile/settings" },
        {
          text: "Creator Dashboard",
          href: "/profile/dashboard",
          icon: <IconDashboardLayout className="h-6 w-6" />,
        },
        {
          text: "Billing",
          href: "/profile/credits",
          icon: <IconCoin className="h-6 w-6" />,
        },
        {
          text: "Integrations",
          href: "/profile/integrations",
          icon: <IconIntegrations className="h-6 w-6" />,
        },
        {
          text: "API Keys",
          href: "/profile/api_keys",
          icon: <KeyIcon className="h-6 w-6" />,
        },
        {
          text: "Profile",
          href: "/profile",
          icon: <IconProfile className="h-6 w-6" />,
        },
        {
          text: "Settings",
          href: "/profile/settings",
          icon: <IconSliders className="h-6 w-6" />,
        },
      ],
    },
  ];
@@ -29,7 +29,6 @@ export async function sendResetEmail(email: string) {
      return error.message;
    }

    console.log("Reset email sent");
    redirect("/reset_password");
  },
);
@@ -461,6 +461,37 @@ const FlowEditor: React.FC<{
    });
  }, [nodes, setViewport, x, y]);

  const fillDefaults = useCallback((obj: any, schema: any) => {
    // Iterate over the schema properties
    for (const key in schema.properties) {
      if (schema.properties.hasOwnProperty(key)) {
        const propertySchema = schema.properties[key];

        // If the property is not in the object, initialize it with the default value
        if (!obj.hasOwnProperty(key)) {
          if (propertySchema.default !== undefined) {
            obj[key] = propertySchema.default;
          } else if (propertySchema.type === "object") {
            // Recursively fill defaults for nested objects
            obj[key] = fillDefaults({}, propertySchema);
          } else if (propertySchema.type === "array") {
            // Recursively fill defaults for arrays
            obj[key] = fillDefaults([], propertySchema);
          }
        } else {
          // If the property exists, recursively fill defaults for nested objects/arrays
          if (propertySchema.type === "object") {
            obj[key] = fillDefaults(obj[key], propertySchema);
          } else if (propertySchema.type === "array") {
            obj[key] = fillDefaults(obj[key], propertySchema);
          }
        }
      }
    }

    return obj;
  }, []);

  const addNode = useCallback(
    (blockId: string, nodeType: string, hardcodedValues: any = {}) => {
      const nodeSchema = availableNodes.find((node) => node.id === blockId);
@@ -507,7 +538,10 @@ const FlowEditor: React.FC<{
        categories: nodeSchema.categories,
        inputSchema: nodeSchema.inputSchema,
        outputSchema: nodeSchema.outputSchema,
        hardcodedValues: hardcodedValues,
        hardcodedValues: {
          ...fillDefaults({}, nodeSchema.inputSchema),
          ...hardcodedValues,
        },
        connections: [],
        isOutputOpen: false,
        block_id: blockId,
|
||||
// Add user message to chat
|
||||
setMessages((prev) => [...prev, { type: "user", content: userMessage }]);
|
||||
|
||||
// Add temporary processing message
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{ type: "assistant", content: "Processing your question..." },
|
||||
]);
|
||||
|
||||
const conversationHistory = messages.reduce<
|
||||
{ query: string; response: string }[]
|
||||
>((acc, msg, i, arr) => {
|
||||
if (
|
||||
msg.type === "user" &&
|
||||
i + 1 < arr.length &&
|
||||
arr[i + 1].type === "assistant" &&
|
||||
arr[i + 1].content !== "Processing your question..."
|
||||
) {
|
||||
acc.push({
|
||||
query: msg.content,
|
||||
response: arr[i + 1].content,
|
||||
});
|
||||
}
|
||||
return acc;
|
||||
}, []);
|
||||
|
||||
try {
|
||||
// Add temporary processing message
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{ type: "assistant", content: "Processing your question..." },
|
||||
]);
|
||||
|
||||
const conversationHistory = messages.reduce<
|
||||
{ query: string; response: string }[]
|
||||
>((acc, msg, i, arr) => {
|
||||
if (
|
||||
msg.type === "user" &&
|
||||
i + 1 < arr.length &&
|
||||
arr[i + 1].type === "assistant"
|
||||
) {
|
||||
acc.push({
|
||||
query: msg.content,
|
||||
response: arr[i + 1].content,
|
||||
});
|
||||
}
|
||||
return acc;
|
||||
}, []);
|
||||
|
||||
const data = await askOtto(
|
||||
userMessage,
|
||||
conversationHistory,
|
||||
@@ -86,34 +87,43 @@ const OttoChatWidget = () => {
|
||||
flowID || undefined,
|
||||
);
|
||||
|
||||
// Remove processing message and add actual response
|
||||
setMessages((prev) => [
|
||||
...prev.slice(0, -1),
|
||||
{ type: "assistant", content: data.answer },
|
||||
]);
|
||||
} catch (error) {
|
||||
console.error("Error calling API:", error);
|
||||
// Remove processing message and add error message
|
||||
const errorMessage =
|
||||
error instanceof Error && error.message === "Authentication required"
|
||||
? "Please sign in to use the chat feature."
|
||||
: "Sorry, there was an error processing your message. Please try again.";
|
||||
// Check if the response contains an error
|
||||
if ("error" in data && data.error === true) {
|
||||
// Handle different error types
|
||||
let errorMessage =
|
||||
"Sorry, there was an error processing your message. Please try again.";
|
||||
|
||||
setMessages((prev) => [
|
||||
...prev.slice(0, -1),
|
||||
{ type: "assistant", content: errorMessage },
|
||||
]);
|
||||
if (data.answer === "Authentication required") {
|
||||
errorMessage = "Please sign in to use the chat feature.";
|
||||
} else if (data.answer === "Failed to connect to Otto service") {
|
||||
errorMessage =
|
||||
"Otto service is currently unavailable. Please try again later.";
|
||||
} else if (data.answer.includes("timed out")) {
|
||||
errorMessage = "Request timed out. Please try again later.";
|
||||
}
|
||||
|
||||
if (
|
||||
error instanceof Error &&
|
||||
error.message === "Authentication required"
|
||||
) {
|
||||
toast({
|
||||
title: "Authentication Error",
|
||||
description: "Please sign in to use the chat feature.",
|
||||
variant: "destructive",
|
||||
});
|
||||
// Remove processing message and add error message
|
||||
setMessages((prev) => [
|
||||
...prev.slice(0, -1),
|
||||
{ type: "assistant", content: errorMessage },
|
||||
]);
|
||||
} else {
|
||||
// Remove processing message and add actual response
|
||||
setMessages((prev) => [
|
||||
...prev.slice(0, -1),
|
||||
{ type: "assistant", content: data.answer },
|
||||
]);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Unexpected error in chat widget:", error);
|
||||
setMessages((prev) => [
|
||||
...prev.slice(0, -1),
|
||||
{
|
||||
type: "assistant",
|
||||
content:
|
||||
"An unexpected error occurred. Please refresh the page and try again.",
|
||||
},
|
||||
]);
|
||||
} finally {
|
||||
setIsProcessing(false);
|
||||
setIncludeGraphData(false);
|
||||
|
||||
@@ -1,149 +0,0 @@
|
||||
// "use client";
|
||||
|
||||
// import {
|
||||
// Dialog,
|
||||
// DialogContent,
|
||||
// DialogClose,
|
||||
// DialogFooter,
|
||||
// DialogHeader,
|
||||
// DialogTitle,
|
||||
// DialogTrigger,
|
||||
// } from "@/components/ui/dialog";
|
||||
// import { Button } from "@/components/ui/button";
|
||||
// import {
|
||||
// MultiSelector,
|
||||
// MultiSelectorContent,
|
||||
// MultiSelectorInput,
|
||||
// MultiSelectorItem,
|
||||
// MultiSelectorList,
|
||||
// MultiSelectorTrigger,
|
||||
// } from "@/components/ui/multiselect";
|
||||
// import { Controller, useForm } from "react-hook-form";
|
||||
// import {
|
||||
// Select,
|
||||
// SelectContent,
|
||||
// SelectItem,
|
||||
// SelectTrigger,
|
||||
// SelectValue,
|
||||
// } from "@/components/ui/select";
|
||||
// import { useState } from "react";
|
||||
// import { addFeaturedAgent } from "./actions";
|
||||
// import { Agent } from "@/lib/marketplace-api/types";
|
||||
|
||||
// type FormData = {
|
||||
// agent: string;
|
||||
// categories: string[];
|
||||
// };
|
||||
|
||||
// export const AdminAddFeaturedAgentDialog = ({
|
||||
// categories,
|
||||
// agents,
|
||||
// }: {
|
||||
// categories: string[];
|
||||
// agents: Agent[];
|
||||
// }) => {
|
||||
// const [selectedAgent, setSelectedAgent] = useState<string>("");
|
||||
// const [selectedCategories, setSelectedCategories] = useState<string[]>([]);
|
||||
|
||||
// const {
|
||||
// control,
|
||||
// handleSubmit,
|
||||
// watch,
|
||||
// setValue,
|
||||
// formState: { errors },
|
||||
// } = useForm<FormData>({
|
||||
// defaultValues: {
|
||||
// agent: "",
|
||||
// categories: [],
|
||||
// },
|
||||
// });
|
||||
|
||||
// return (
|
||||
// <Dialog>
|
||||
// <DialogTrigger asChild>
|
||||
// <Button variant="outline" size="sm">
|
||||
// Add Featured Agent
|
||||
// </Button>
|
||||
// </DialogTrigger>
|
||||
// <DialogContent>
|
||||
// <DialogHeader>
|
||||
// <DialogTitle>Add Featured Agent</DialogTitle>
|
||||
// </DialogHeader>
|
||||
// <div className="flex flex-col gap-4">
|
||||
// <Controller
|
||||
// name="agent"
|
||||
// control={control}
|
||||
// rules={{ required: true }}
|
||||
// render={({ field }) => (
|
||||
// <div>
|
||||
// <label htmlFor={field.name}>Agent</label>
|
||||
// <Select
|
||||
// onValueChange={(value) => {
|
||||
// field.onChange(value);
|
||||
// setSelectedAgent(value);
|
||||
// }}
|
||||
// value={field.value || ""}
|
||||
// >
|
||||
// <SelectTrigger>
|
||||
// <SelectValue placeholder="Select an agent" />
|
||||
// </SelectTrigger>
|
||||
// <SelectContent>
|
||||
// {/* Populate with agents */}
|
||||
// {agents.map((agent) => (
|
||||
// <SelectItem key={agent.id} value={agent.id}>
|
||||
// {agent.name}
|
||||
// </SelectItem>
|
||||
// ))}
|
||||
// </SelectContent>
|
||||
// </Select>
|
||||
// </div>
|
||||
// )}
|
||||
// />
|
||||
// <Controller
|
||||
// name="categories"
|
||||
// control={control}
|
||||
// render={({ field }) => (
|
||||
// <MultiSelector
|
||||
// values={field.value || []}
|
||||
// onValuesChange={(values) => {
|
||||
// field.onChange(values);
|
||||
// setSelectedCategories(values);
|
||||
// }}
|
||||
// >
|
||||
// <MultiSelectorTrigger>
|
||||
// <MultiSelectorInput placeholder="Select categories" />
|
||||
// </MultiSelectorTrigger>
|
||||
// <MultiSelectorContent>
|
||||
// <MultiSelectorList>
|
||||
// {categories.map((category) => (
|
||||
// <MultiSelectorItem key={category} value={category}>
|
||||
// {category}
|
||||
// </MultiSelectorItem>
|
||||
// ))}
|
||||
// </MultiSelectorList>
|
||||
// </MultiSelectorContent>
|
||||
// </MultiSelector>
|
||||
// )}
|
||||
// />
|
||||
// </div>
|
||||
// <DialogFooter>
|
||||
// <DialogClose asChild>
|
||||
// <Button variant="outline">Cancel</Button>
|
||||
// </DialogClose>
|
||||
// <DialogClose asChild>
|
||||
// <Button
|
||||
// type="submit"
|
||||
// onClick={async () => {
|
||||
// // Handle adding the featured agent
|
||||
// await addFeaturedAgent(selectedAgent, selectedCategories);
|
||||
// // close the dialog
|
||||
// }}
|
||||
// >
|
||||
// Add
|
||||
// </Button>
|
||||
// </DialogClose>
|
||||
// </DialogFooter>
|
||||
// </DialogContent>
|
||||
// </Dialog>
|
||||
// );
|
||||
// };
|
||||
@@ -1,74 +0,0 @@
|
||||
// import { Button } from "@/components/ui/button";
|
||||
// import {
|
||||
// getFeaturedAgents,
|
||||
// removeFeaturedAgent,
|
||||
// getCategories,
|
||||
// getNotFeaturedAgents,
|
||||
// } from "./actions";
|
||||
|
||||
// import FeaturedAgentsTable from "./FeaturedAgentsTable";
|
||||
// import { AdminAddFeaturedAgentDialog } from "./AdminAddFeaturedAgentDialog";
|
||||
// import { revalidatePath } from "next/cache";
|
||||
// import * as Sentry from "@sentry/nextjs";
|
||||
|
||||
// export default async function AdminFeaturedAgentsControl({
|
||||
// className,
|
||||
// }: {
|
||||
// className?: string;
|
||||
// }) {
|
||||
// // add featured agent button
|
||||
// // modal to select agent?
|
||||
// // modal to select categories?
|
||||
// // table of featured agents
|
||||
// // in table
|
||||
// // remove featured agent button
|
||||
// // edit featured agent categories button
|
||||
// // table footer
|
||||
// // Next page button
|
||||
// // Previous page button
|
||||
// // Page number input
|
||||
// // Page size input
|
||||
// // Total pages input
|
||||
// // Go to page button
|
||||
|
||||
// const page = 1;
|
||||
// const pageSize = 10;
|
||||
|
||||
// const agents = await getFeaturedAgents(page, pageSize);
|
||||
|
||||
// const categories = await getCategories();
|
||||
|
||||
// const notFeaturedAgents = await getNotFeaturedAgents();
|
||||
|
||||
// return (
|
||||
// <div className={`flex flex-col gap-4 ${className}`}>
|
||||
// <div className="mb-4 flex justify-between">
|
||||
// <h3 className="text-lg font-semibold">Featured Agent Controls</h3>
|
||||
// <AdminAddFeaturedAgentDialog
|
||||
// categories={categories.unique_categories}
|
||||
// agents={notFeaturedAgents.items}
|
||||
// />
|
||||
// </div>
|
||||
// <FeaturedAgentsTable
|
||||
// agents={agents.items}
|
||||
// globalActions={[
|
||||
// {
|
||||
// component: <Button>Remove</Button>,
|
||||
// action: async (rows) => {
|
||||
// "use server";
|
||||
// return await Sentry.withServerActionInstrumentation(
|
||||
// "removeFeaturedAgent",
|
||||
// {},
|
||||
// async () => {
|
||||
// const all = rows.map((row) => removeFeaturedAgent(row.id));
|
||||
// await Promise.all(all);
|
||||
// revalidatePath("/marketplace");
|
||||
// },
|
||||
// );
|
||||
// },
|
||||
// },
|
||||
// ]}
|
||||
// />
|
||||
// </div>
|
||||
// );
|
||||
// }
|
||||
Some files were not shown because too many files have changed in this diff.