Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into zamilmajdy/fix-static-output-resolve

This commit is contained in:
Zamil Majdy
2025-03-25 09:03:17 +07:00
136 changed files with 7143 additions and 2790 deletions

View File

@@ -129,30 +129,6 @@ updates:
- "minor"
- "patch"
# Submodules
- package-ecosystem: "gitsubmodule"
directory: "autogpt_platform/supabase"
schedule:
interval: "weekly"
open-pull-requests-limit: 1
target-branch: "dev"
commit-message:
prefix: "chore(platform/deps)"
prefix-development: "chore(platform/deps-dev)"
groups:
production-dependencies:
dependency-type: "production"
update-types:
- "minor"
- "patch"
development-dependencies:
dependency-type: "development"
update-types:
- "minor"
- "patch"
# Docs
- package-ecosystem: 'pip'
directory: "docs/"

View File

@@ -82,7 +82,7 @@ jobs:
- name: Copy default supabase .env
run: |
cp ../supabase/docker/.env.example ../.env
cp ../.env.example ../.env
- name: Copy backend .env
run: |

3
.gitmodules vendored
View File

@@ -1,6 +1,3 @@
[submodule "classic/forge/tests/vcr_cassettes"]
path = classic/forge/tests/vcr_cassettes
url = https://github.com/Significant-Gravitas/Auto-GPT-test-cassettes
[submodule "autogpt_platform/supabase"]
path = autogpt_platform/supabase
url = https://github.com/supabase/supabase.git

View File

@@ -0,0 +1,123 @@
############
# Secrets
# YOU MUST CHANGE THESE BEFORE GOING INTO PRODUCTION
############
POSTGRES_PASSWORD=your-super-secret-and-long-postgres-password
JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
ANON_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJhbm9uIiwKICAgICJpc3MiOiAic3VwYWJhc2UtZGVtbyIsCiAgICAiaWF0IjogMTY0MTc2OTIwMCwKICAgICJleHAiOiAxNzk5NTM1NjAwCn0.dc_X5iR_VP_qT0zsiyj_I_OZ2T9FtRU2BBNWN8Bu4GE
SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
DASHBOARD_USERNAME=supabase
DASHBOARD_PASSWORD=this_password_is_insecure_and_should_be_updated
SECRET_KEY_BASE=UpNVntn3cDxHJpq99YMc1T1AQgQpc8kfYTuRgBiYa15BLrx8etQoXz3gZv1/u2oq
VAULT_ENC_KEY=your-encryption-key-32-chars-min
############
# Database - You can change these to any PostgreSQL database that has logical replication enabled.
############
POSTGRES_HOST=db
POSTGRES_DB=postgres
POSTGRES_PORT=5432
# default user is postgres
############
# Supavisor -- Database pooler
############
POOLER_PROXY_PORT_TRANSACTION=6543
POOLER_DEFAULT_POOL_SIZE=20
POOLER_MAX_CLIENT_CONN=100
POOLER_TENANT_ID=your-tenant-id
############
# API Proxy - Configuration for the Kong Reverse proxy.
############
KONG_HTTP_PORT=8000
KONG_HTTPS_PORT=8443
############
# API - Configuration for PostgREST.
############
PGRST_DB_SCHEMAS=public,storage,graphql_public
############
# Auth - Configuration for the GoTrue authentication server.
############
## General
SITE_URL=http://localhost:3000
ADDITIONAL_REDIRECT_URLS=
JWT_EXPIRY=3600
DISABLE_SIGNUP=false
API_EXTERNAL_URL=http://localhost:8000
## Mailer Config
MAILER_URLPATHS_CONFIRMATION="/auth/v1/verify"
MAILER_URLPATHS_INVITE="/auth/v1/verify"
MAILER_URLPATHS_RECOVERY="/auth/v1/verify"
MAILER_URLPATHS_EMAIL_CHANGE="/auth/v1/verify"
## Email auth
ENABLE_EMAIL_SIGNUP=true
ENABLE_EMAIL_AUTOCONFIRM=false
SMTP_ADMIN_EMAIL=admin@example.com
SMTP_HOST=supabase-mail
SMTP_PORT=2500
SMTP_USER=fake_mail_user
SMTP_PASS=fake_mail_password
SMTP_SENDER_NAME=fake_sender
ENABLE_ANONYMOUS_USERS=false
## Phone auth
ENABLE_PHONE_SIGNUP=true
ENABLE_PHONE_AUTOCONFIRM=true
############
# Studio - Configuration for the Dashboard
############
STUDIO_DEFAULT_ORGANIZATION=Default Organization
STUDIO_DEFAULT_PROJECT=Default Project
STUDIO_PORT=3000
# replace if you intend to use Studio outside of localhost
SUPABASE_PUBLIC_URL=http://localhost:8000
# Enable webp support
IMGPROXY_ENABLE_WEBP_DETECTION=true
# Add your OpenAI API key to enable SQL Editor Assistant
OPENAI_API_KEY=
############
# Functions - Configuration for Functions
############
# NOTE: VERIFY_JWT applies to all functions. Per-function VERIFY_JWT is not supported yet.
FUNCTIONS_VERIFY_JWT=false
############
# Logs - Configuration for Logflare
# Please refer to https://supabase.com/docs/reference/self-hosting-analytics/introduction
############
LOGFLARE_LOGGER_BACKEND_API_KEY=your-super-secret-and-long-logflare-key
# Change vector.toml sinks to reflect this change
LOGFLARE_API_KEY=your-super-secret-and-long-logflare-key
# Docker socket location - this value will differ depending on your OS
DOCKER_SOCKET_LOCATION=/var/run/docker.sock
# Google Cloud Project details
GOOGLE_PROJECT_ID=GOOGLE_PROJECT_ID
GOOGLE_PROJECT_NUMBER=GOOGLE_PROJECT_NUMBER

View File

@@ -22,35 +22,29 @@ To run the AutoGPT Platform, follow these steps:
2. Run the following command:
```
git submodule update --init --recursive --progress
cp .env.example .env
```
This command will initialize and update the submodules in the repository. The `supabase` folder will be cloned to the root directory.
This command will copy the `.env.example` file to `.env`. You can modify the `.env` file to add your own environment variables.
3. Run the following command:
```
cp supabase/docker/.env.example .env
```
This command will copy the `.env.example` file to `.env` in the `supabase/docker` directory. You can modify the `.env` file to add your own environment variables.
4. Run the following command:
```
docker compose up -d
```
This command will start all the necessary backend services defined in the `docker-compose.yml` file in detached mode.
5. Navigate to `frontend` within the `autogpt_platform` directory:
4. Navigate to `frontend` within the `autogpt_platform` directory:
```
cd frontend
```
You will need to run your frontend application separately on your local machine.
6. Run the following command:
5. Run the following command:
```
cp .env.example .env.local
```
This command will copy the `.env.example` file to `.env.local` in the `frontend` directory. You can modify the `.env.local` within this folder to add your own environment variables for the frontend application.
7. Run the following command:
6. Run the following command:
```
npm install
npm run dev
@@ -61,7 +55,7 @@ To run the AutoGPT Platform, follow these steps:
yarn install && yarn dev
```
8. Open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
7. Open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
### Docker Compose Commands

View File

@@ -1476,30 +1476,30 @@ pyasn1 = ">=0.1.3"
[[package]]
name = "ruff"
version = "0.9.6"
version = "0.9.10"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
groups = ["dev"]
files = [
{file = "ruff-0.9.6-py3-none-linux_armv6l.whl", hash = "sha256:2f218f356dd2d995839f1941322ff021c72a492c470f0b26a34f844c29cdf5ba"},
{file = "ruff-0.9.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b908ff4df65dad7b251c9968a2e4560836d8f5487c2f0cc238321ed951ea0504"},
{file = "ruff-0.9.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:b109c0ad2ececf42e75fa99dc4043ff72a357436bb171900714a9ea581ddef83"},
{file = "ruff-0.9.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1de4367cca3dac99bcbd15c161404e849bb0bfd543664db39232648dc00112dc"},
{file = "ruff-0.9.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac3ee4d7c2c92ddfdaedf0bf31b2b176fa7aa8950efc454628d477394d35638b"},
{file = "ruff-0.9.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5dc1edd1775270e6aa2386119aea692039781429f0be1e0949ea5884e011aa8e"},
{file = "ruff-0.9.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:4a091729086dffa4bd070aa5dab7e39cc6b9d62eb2bef8f3d91172d30d599666"},
{file = "ruff-0.9.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d1bbc6808bf7b15796cef0815e1dfb796fbd383e7dbd4334709642649625e7c5"},
{file = "ruff-0.9.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:589d1d9f25b5754ff230dce914a174a7c951a85a4e9270613a2b74231fdac2f5"},
{file = "ruff-0.9.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc61dd5131742e21103fbbdcad683a8813be0e3c204472d520d9a5021ca8b217"},
{file = "ruff-0.9.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5e2d9126161d0357e5c8f30b0bd6168d2c3872372f14481136d13de9937f79b6"},
{file = "ruff-0.9.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:68660eab1a8e65babb5229a1f97b46e3120923757a68b5413d8561f8a85d4897"},
{file = "ruff-0.9.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c4cae6c4cc7b9b4017c71114115db0445b00a16de3bcde0946273e8392856f08"},
{file = "ruff-0.9.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:19f505b643228b417c1111a2a536424ddde0db4ef9023b9e04a46ed8a1cb4656"},
{file = "ruff-0.9.6-py3-none-win32.whl", hash = "sha256:194d8402bceef1b31164909540a597e0d913c0e4952015a5b40e28c146121b5d"},
{file = "ruff-0.9.6-py3-none-win_amd64.whl", hash = "sha256:03482d5c09d90d4ee3f40d97578423698ad895c87314c4de39ed2af945633caa"},
{file = "ruff-0.9.6-py3-none-win_arm64.whl", hash = "sha256:0e2bb706a2be7ddfea4a4af918562fdc1bcb16df255e5fa595bbd800ce322a5a"},
{file = "ruff-0.9.6.tar.gz", hash = "sha256:81761592f72b620ec8fa1068a6fd00e98a5ebee342a3642efd84454f3031dca9"},
{file = "ruff-0.9.10-py3-none-linux_armv6l.whl", hash = "sha256:eb4d25532cfd9fe461acc83498361ec2e2252795b4f40b17e80692814329e42d"},
{file = "ruff-0.9.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:188a6638dab1aa9bb6228a7302387b2c9954e455fb25d6b4470cb0641d16759d"},
{file = "ruff-0.9.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5284dcac6b9dbc2fcb71fdfc26a217b2ca4ede6ccd57476f52a587451ebe450d"},
{file = "ruff-0.9.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:47678f39fa2a3da62724851107f438c8229a3470f533894b5568a39b40029c0c"},
{file = "ruff-0.9.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:99713a6e2766b7a17147b309e8c915b32b07a25c9efd12ada79f217c9c778b3e"},
{file = "ruff-0.9.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:524ee184d92f7c7304aa568e2db20f50c32d1d0caa235d8ddf10497566ea1a12"},
{file = "ruff-0.9.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:df92aeac30af821f9acf819fc01b4afc3dfb829d2782884f8739fb52a8119a16"},
{file = "ruff-0.9.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de42e4edc296f520bb84954eb992a07a0ec5a02fecb834498415908469854a52"},
{file = "ruff-0.9.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d257f95b65806104b6b1ffca0ea53f4ef98454036df65b1eda3693534813ecd1"},
{file = "ruff-0.9.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b60dec7201c0b10d6d11be00e8f2dbb6f40ef1828ee75ed739923799513db24c"},
{file = "ruff-0.9.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:d838b60007da7a39c046fcdd317293d10b845001f38bcb55ba766c3875b01e43"},
{file = "ruff-0.9.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:ccaf903108b899beb8e09a63ffae5869057ab649c1e9231c05ae354ebc62066c"},
{file = "ruff-0.9.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f9567d135265d46e59d62dc60c0bfad10e9a6822e231f5b24032dba5a55be6b5"},
{file = "ruff-0.9.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5f202f0d93738c28a89f8ed9eaba01b7be339e5d8d642c994347eaa81c6d75b8"},
{file = "ruff-0.9.10-py3-none-win32.whl", hash = "sha256:bfb834e87c916521ce46b1788fbb8484966e5113c02df216680102e9eb960029"},
{file = "ruff-0.9.10-py3-none-win_amd64.whl", hash = "sha256:f2160eeef3031bf4b17df74e307d4c5fb689a6f3a26a2de3f7ef4044e3c484f1"},
{file = "ruff-0.9.10-py3-none-win_arm64.whl", hash = "sha256:5fd804c0327a5e5ea26615550e706942f348b197d5475ff34c19733aee4b2e69"},
{file = "ruff-0.9.10.tar.gz", hash = "sha256:9bacb735d7bada9cfb0f2c227d3658fc443d90a727b47f206fb33f52f3c0eac7"},
]
[[package]]
@@ -1929,4 +1929,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<4.0"
content-hash = "f5cd0d1dafeb2b5c97d0ef27bef8a2235d4a1f54e3c60583d05ef582ac49c0e6"
content-hash = "931772287f71c539575d601e6398423bf68e09ca87ae1a144057c7f5707cf978"

View File

@@ -21,7 +21,7 @@ supabase = "^2.13.0"
[tool.poetry.group.dev.dependencies]
redis = "^5.2.1"
ruff = "^0.9.6"
ruff = "^0.9.10"
[build-system]
requires = ["poetry-core"]

View File

@@ -2,88 +2,103 @@ import importlib
import os
import re
from pathlib import Path
from typing import Type, TypeVar
from backend.data.block import Block
# Dynamically load all modules under backend.blocks
AVAILABLE_MODULES = []
current_dir = Path(__file__).parent
modules = [
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
for f in current_dir.rglob("*.py")
if f.is_file() and f.name != "__init__.py"
]
for module in modules:
if not re.match("^[a-z0-9_.]+$", module):
raise ValueError(
f"Block module {module} error: module name must be lowercase, "
"and contain only alphanumeric characters and underscores."
)
importlib.import_module(f".{module}", package=__name__)
AVAILABLE_MODULES.append(module)
# Load all Block instances from the available modules
AVAILABLE_BLOCKS: dict[str, Type[Block]] = {}
from typing import TYPE_CHECKING, TypeVar
if TYPE_CHECKING:
from backend.data.block import Block
T = TypeVar("T")
def all_subclasses(cls: Type[T]) -> list[Type[T]]:
_AVAILABLE_BLOCKS: dict[str, type["Block"]] = {}
def load_all_blocks() -> dict[str, type["Block"]]:
from backend.data.block import Block
if _AVAILABLE_BLOCKS:
return _AVAILABLE_BLOCKS
# Dynamically load all modules under backend.blocks
AVAILABLE_MODULES = []
current_dir = Path(__file__).parent
modules = [
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
for f in current_dir.rglob("*.py")
if f.is_file() and f.name != "__init__.py"
]
for module in modules:
if not re.match("^[a-z0-9_.]+$", module):
raise ValueError(
f"Block module {module} error: module name must be lowercase, "
"and contain only alphanumeric characters and underscores."
)
importlib.import_module(f".{module}", package=__name__)
AVAILABLE_MODULES.append(module)
# Load all Block instances from the available modules
for block_cls in all_subclasses(Block):
class_name = block_cls.__name__
if class_name.endswith("Base"):
continue
if not class_name.endswith("Block"):
raise ValueError(
f"Block class {class_name} does not end with 'Block'. "
"If you are creating an abstract class, "
"please name the class with 'Base' at the end"
)
block = block_cls.create()
if not isinstance(block.id, str) or len(block.id) != 36:
raise ValueError(
f"Block ID {block.name} error: {block.id} is not a valid UUID"
)
if block.id in _AVAILABLE_BLOCKS:
raise ValueError(
f"Block ID {block.name} error: {block.id} is already in use"
)
input_schema = block.input_schema.model_fields
output_schema = block.output_schema.model_fields
# Make sure `error` field is a string in the output schema
if "error" in output_schema and output_schema["error"].annotation is not str:
raise ValueError(
f"{block.name} `error` field in output_schema must be a string"
)
# Ensure all fields in input_schema and output_schema are annotated SchemaFields
for field_name, field in [*input_schema.items(), *output_schema.items()]:
if field.annotation is None:
raise ValueError(
f"{block.name} has a field {field_name} that is not annotated"
)
if field.json_schema_extra is None:
raise ValueError(
f"{block.name} has a field {field_name} not defined as SchemaField"
)
for field in block.input_schema.model_fields.values():
if field.annotation is bool and field.default not in (True, False):
raise ValueError(
f"{block.name} has a boolean field with no default value"
)
_AVAILABLE_BLOCKS[block.id] = block_cls
return _AVAILABLE_BLOCKS
__all__ = ["load_all_blocks"]
def all_subclasses(cls: type[T]) -> list[type[T]]:
subclasses = cls.__subclasses__()
for subclass in subclasses:
subclasses += all_subclasses(subclass)
return subclasses
for block_cls in all_subclasses(Block):
name = block_cls.__name__
if block_cls.__name__.endswith("Base"):
continue
if not block_cls.__name__.endswith("Block"):
raise ValueError(
f"Block class {block_cls.__name__} does not end with 'Block', If you are creating an abstract class, please name the class with 'Base' at the end"
)
block = block_cls.create()
if not isinstance(block.id, str) or len(block.id) != 36:
raise ValueError(f"Block ID {block.name} error: {block.id} is not a valid UUID")
if block.id in AVAILABLE_BLOCKS:
raise ValueError(f"Block ID {block.name} error: {block.id} is already in use")
input_schema = block.input_schema.model_fields
output_schema = block.output_schema.model_fields
# Make sure `error` field is a string in the output schema
if "error" in output_schema and output_schema["error"].annotation is not str:
raise ValueError(
f"{block.name} `error` field in output_schema must be a string"
)
# Make sure all fields in input_schema and output_schema are annotated and has a value
for field_name, field in [*input_schema.items(), *output_schema.items()]:
if field.annotation is None:
raise ValueError(
f"{block.name} has a field {field_name} that is not annotated"
)
if field.json_schema_extra is None:
raise ValueError(
f"{block.name} has a field {field_name} not defined as SchemaField"
)
for field in block.input_schema.model_fields.values():
if field.annotation is bool and field.default not in (True, False):
raise ValueError(f"{block.name} has a boolean field with no default value")
if block.disabled:
continue
AVAILABLE_BLOCKS[block.id] = block_cls
__all__ = ["AVAILABLE_MODULES", "AVAILABLE_BLOCKS"]

View File

@@ -1,26 +1,13 @@
import enum
from typing import TYPE_CHECKING, Any, List
from typing import Any, List
from backend.data.block import (
Block,
BlockCategory,
BlockInput,
BlockOutput,
BlockSchema,
BlockType,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
from backend.data.model import SchemaField
from backend.util import json
from backend.util.file import MediaFile, store_media_file
from backend.util.mock import MockObject
from backend.util.text import TextFormatter
from backend.util.type import convert
if TYPE_CHECKING:
from backend.data.graph import Link
formatter = TextFormatter()
class FileStoreBlock(Block):
class Input(BlockSchema):
@@ -101,29 +88,6 @@ class StoreValueBlock(Block):
yield "output", input_data.data or input_data.input
class PrintToConsoleBlock(Block):
class Input(BlockSchema):
text: str = SchemaField(description="The text to print to the console.")
class Output(BlockSchema):
status: str = SchemaField(description="The status of the print operation.")
def __init__(self):
super().__init__(
id="f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c",
description="Print the given text to the console, this is used for a debugging purpose.",
categories={BlockCategory.BASIC},
input_schema=PrintToConsoleBlock.Input,
output_schema=PrintToConsoleBlock.Output,
test_input={"text": "Hello, World!"},
test_output=("status", "printed"),
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
print(">>>>> Print: ", input_data.text)
yield "status", "printed"
class FindInDictionaryBlock(Block):
class Input(BlockSchema):
input: Any = SchemaField(description="Dictionary to lookup from")
@@ -184,188 +148,6 @@ class FindInDictionaryBlock(Block):
yield "missing", input_data.input
class AgentInputBlock(Block):
"""
This block is used to provide input to the graph.
It takes in a value, name, description, default values list and bool to limit selection to default values.
It Outputs the value passed as input.
"""
class Input(BlockSchema):
name: str = SchemaField(description="The name of the input.")
value: Any = SchemaField(
description="The value to be passed as input.",
default=None,
)
title: str | None = SchemaField(
description="The title of the input.", default=None, advanced=True
)
description: str | None = SchemaField(
description="The description of the input.",
default=None,
advanced=True,
)
placeholder_values: List[Any] = SchemaField(
description="The placeholder values to be passed as input.",
default=[],
advanced=True,
)
limit_to_placeholder_values: bool = SchemaField(
description="Whether to limit the selection to placeholder values.",
default=False,
advanced=True,
)
advanced: bool = SchemaField(
description="Whether to show the input in the advanced section, if the field is not required.",
default=False,
advanced=True,
)
secret: bool = SchemaField(
description="Whether the input should be treated as a secret.",
default=False,
advanced=True,
)
class Output(BlockSchema):
result: Any = SchemaField(description="The value passed as input.")
def __init__(self):
super().__init__(
id="c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
description="This block is used to provide input to the graph.",
input_schema=AgentInputBlock.Input,
output_schema=AgentInputBlock.Output,
test_input=[
{
"value": "Hello, World!",
"name": "input_1",
"description": "This is a test input.",
"placeholder_values": [],
"limit_to_placeholder_values": False,
},
{
"value": "Hello, World!",
"name": "input_2",
"description": "This is a test input.",
"placeholder_values": ["Hello, World!"],
"limit_to_placeholder_values": True,
},
],
test_output=[
("result", "Hello, World!"),
("result", "Hello, World!"),
],
categories={BlockCategory.INPUT, BlockCategory.BASIC},
block_type=BlockType.INPUT,
static_output=True,
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
yield "result", input_data.value
class AgentOutputBlock(Block):
"""
Records the output of the graph for users to see.
Behavior:
If `format` is provided and the `value` is of a type that can be formatted,
the block attempts to format the recorded_value using the `format`.
If formatting fails or no `format` is provided, the raw `value` is output.
"""
class Input(BlockSchema):
value: Any = SchemaField(
description="The value to be recorded as output.",
default=None,
advanced=False,
)
name: str = SchemaField(description="The name of the output.")
title: str | None = SchemaField(
description="The title of the output.",
default=None,
advanced=True,
)
description: str | None = SchemaField(
description="The description of the output.",
default=None,
advanced=True,
)
format: str = SchemaField(
description="The format string to be used to format the recorded_value. Use Jinja2 syntax.",
default="",
advanced=True,
)
advanced: bool = SchemaField(
description="Whether to treat the output as advanced.",
default=False,
advanced=True,
)
secret: bool = SchemaField(
description="Whether the output should be treated as a secret.",
default=False,
advanced=True,
)
class Output(BlockSchema):
output: Any = SchemaField(description="The value recorded as output.")
name: Any = SchemaField(description="The name of the value recorded as output.")
def __init__(self):
super().__init__(
id="363ae599-353e-4804-937e-b2ee3cef3da4",
description="Stores the output of the graph for users to see.",
input_schema=AgentOutputBlock.Input,
output_schema=AgentOutputBlock.Output,
test_input=[
{
"value": "Hello, World!",
"name": "output_1",
"description": "This is a test output.",
"format": "{{ output_1 }}!!",
},
{
"value": "42",
"name": "output_2",
"description": "This is another test output.",
"format": "{{ output_2 }}",
},
{
"value": MockObject(value="!!", key="key"),
"name": "output_3",
"description": "This is a test output with a mock object.",
"format": "{{ output_3 }}",
},
],
test_output=[
("output", "Hello, World!!!"),
("output", "42"),
("output", MockObject(value="!!", key="key")),
],
categories={BlockCategory.OUTPUT, BlockCategory.BASIC},
block_type=BlockType.OUTPUT,
static_output=True,
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
"""
Attempts to format the recorded_value using the fmt_string if provided.
If formatting fails or no fmt_string is given, returns the original recorded_value.
"""
if input_data.format:
try:
yield "output", formatter.format_string(
input_data.format, {input_data.name: input_data.value}
)
except Exception as e:
yield "output", f"Error: {e}, {input_data.value}"
else:
yield "output", input_data.value
yield "name", input_data.name
class AddToDictionaryBlock(Block):
class Input(BlockSchema):
dictionary: dict[Any, Any] = SchemaField(
@@ -466,17 +248,6 @@ class AddToListBlock(Block):
description="The position to insert the new entry. If not provided, the entry will be appended to the end of the list.",
)
@classmethod
def get_missing_links(cls, data: BlockInput, links: List["Link"]) -> set[str]:
return super().get_missing_links(
data,
[
link
for link in links
if link.sink_name != "list" or link.sink_id != link.source_id
],
)
class Output(BlockSchema):
updated_list: List[Any] = SchemaField(
description="The list with the new entry added."

View File

@@ -8,6 +8,7 @@ from backend.data.block import (
BlockSchema,
)
from backend.data.model import SchemaField
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.compass import CompassWebhookType
@@ -42,7 +43,7 @@ class CompassAITriggerBlock(Block):
input_schema=CompassAITriggerBlock.Input,
output_schema=CompassAITriggerBlock.Output,
webhook_config=BlockManualWebhookConfig(
provider="compass",
provider=ProviderName.COMPASS,
webhook_type=CompassWebhookType.TRANSCRIPTION,
),
test_input=[

View File

@@ -12,6 +12,7 @@ from backend.data.block import (
BlockWebhookConfig,
)
from backend.data.model import SchemaField
from backend.integrations.providers import ProviderName
from ._auth import (
TEST_CREDENTIALS,
@@ -123,7 +124,7 @@ class GithubPullRequestTriggerBlock(GitHubTriggerBase, Block):
output_schema=GithubPullRequestTriggerBlock.Output,
# --8<-- [start:example-webhook_config]
webhook_config=BlockWebhookConfig(
provider="github",
provider=ProviderName.GITHUB,
webhook_type=GithubWebhookType.REPO,
resource_format="{repo}",
event_filter_input="events",

View File

@@ -1,11 +1,16 @@
import json
import logging
from enum import Enum
from typing import Any
from requests.exceptions import HTTPError, RequestException
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import requests
logger = logging.getLogger(name=__name__)
class HttpMethod(Enum):
GET = "GET"
@@ -43,8 +48,9 @@ class SendWebRequestBlock(Block):
class Output(BlockSchema):
response: object = SchemaField(description="The response from the server")
client_error: object = SchemaField(description="The error on 4xx status codes")
server_error: object = SchemaField(description="The error on 5xx status codes")
client_error: object = SchemaField(description="Errors on 4xx status codes")
server_error: object = SchemaField(description="Errors on 5xx status codes")
error: str = SchemaField(description="Errors for all other exceptions")
def __init__(self):
super().__init__(
@@ -68,20 +74,40 @@ class SendWebRequestBlock(Block):
# we should send it as plain text instead
input_data.json_format = False
response = requests.request(
input_data.method.value,
input_data.url,
headers=input_data.headers,
json=body if input_data.json_format else None,
data=body if not input_data.json_format else None,
)
result = response.json() if input_data.json_format else response.text
if response.status_code // 100 == 2:
try:
response = requests.request(
input_data.method.value,
input_data.url,
headers=input_data.headers,
json=body if input_data.json_format else None,
data=body if not input_data.json_format else None,
)
result = response.json() if input_data.json_format else response.text
yield "response", result
elif response.status_code // 100 == 4:
yield "client_error", result
elif response.status_code // 100 == 5:
yield "server_error", result
else:
raise ValueError(f"Unexpected status code: {response.status_code}")
except HTTPError as e:
# Handle error responses
try:
result = e.response.json() if input_data.json_format else str(e)
except json.JSONDecodeError:
result = str(e)
if 400 <= e.response.status_code < 500:
yield "client_error", result
elif 500 <= e.response.status_code < 600:
yield "server_error", result
else:
error_msg = (
"Unexpected status code "
f"{e.response.status_code} '{e.response.reason}'"
)
logger.warning(error_msg)
yield "error", error_msg
except RequestException as e:
# Handle other request-related exceptions
yield "error", str(e)
except Exception as e:
# Catch any other unexpected exceptions
yield "error", str(e)

View File

@@ -0,0 +1,552 @@
from datetime import date, time
from typing import Any, Optional
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
from backend.data.model import SchemaField
from backend.util.file import MediaFile, store_media_file
from backend.util.mock import MockObject
from backend.util.settings import Config
from backend.util.text import TextFormatter
formatter = TextFormatter()
config = Config()
class AgentInputBlock(Block):
"""
This block is used to provide input to the graph.
It takes in a value, name, description, default values list and bool to limit selection to default values.
It Outputs the value passed as input.
"""
class Input(BlockSchema):
name: str = SchemaField(description="The name of the input.")
value: Any = SchemaField(
description="The value to be passed as input.",
default=None,
)
title: str | None = SchemaField(
description="The title of the input.", default=None, advanced=True
)
description: str | None = SchemaField(
description="The description of the input.",
default=None,
advanced=True,
)
placeholder_values: list = SchemaField(
description="The placeholder values to be passed as input.",
default=[],
advanced=True,
)
limit_to_placeholder_values: bool = SchemaField(
description="Whether to limit the selection to placeholder values.",
default=False,
advanced=True,
)
advanced: bool = SchemaField(
description="Whether to show the input in the advanced section, if the field is not required.",
default=False,
advanced=True,
)
secret: bool = SchemaField(
description="Whether the input should be treated as a secret.",
default=False,
advanced=True,
)
class Output(BlockSchema):
result: Any = SchemaField(description="The value passed as input.")
def __init__(self, **kwargs):
super().__init__(
**{
"id": "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
"description": "Base block for user inputs.",
"input_schema": AgentInputBlock.Input,
"output_schema": AgentInputBlock.Output,
"test_input": [
{
"value": "Hello, World!",
"name": "input_1",
"description": "Example test input.",
"placeholder_values": [],
"limit_to_placeholder_values": False,
},
{
"value": "Hello, World!",
"name": "input_2",
"description": "Example test input with placeholders.",
"placeholder_values": ["Hello, World!"],
"limit_to_placeholder_values": True,
},
],
"test_output": [
("result", "Hello, World!"),
("result", "Hello, World!"),
],
"categories": {BlockCategory.INPUT, BlockCategory.BASIC},
"block_type": BlockType.INPUT,
"static_output": True,
**kwargs,
}
)
def run(self, input_data: Input, *args, **kwargs) -> BlockOutput:
if input_data.value is not None:
yield "result", input_data.value
class AgentOutputBlock(Block):
"""
Records the output of the graph for users to see.
Behavior:
If `format` is provided and the `value` is of a type that can be formatted,
the block attempts to format the recorded_value using the `format`.
If formatting fails or no `format` is provided, the raw `value` is output.
"""
class Input(BlockSchema):
value: Any = SchemaField(
description="The value to be recorded as output.",
default=None,
advanced=False,
)
name: str = SchemaField(description="The name of the output.")
title: str | None = SchemaField(
description="The title of the output.",
default=None,
advanced=True,
)
description: str | None = SchemaField(
description="The description of the output.",
default=None,
advanced=True,
)
format: str = SchemaField(
description="The format string to be used to format the recorded_value. Use Jinja2 syntax.",
default="",
advanced=True,
)
advanced: bool = SchemaField(
description="Whether to treat the output as advanced.",
default=False,
advanced=True,
)
secret: bool = SchemaField(
description="Whether the output should be treated as a secret.",
default=False,
advanced=True,
)
class Output(BlockSchema):
output: Any = SchemaField(description="The value recorded as output.")
name: Any = SchemaField(description="The name of the value recorded as output.")
def __init__(self):
super().__init__(
id="363ae599-353e-4804-937e-b2ee3cef3da4",
description="Stores the output of the graph for users to see.",
input_schema=AgentOutputBlock.Input,
output_schema=AgentOutputBlock.Output,
test_input=[
{
"value": "Hello, World!",
"name": "output_1",
"description": "This is a test output.",
"format": "{{ output_1 }}!!",
},
{
"value": "42",
"name": "output_2",
"description": "This is another test output.",
"format": "{{ output_2 }}",
},
{
"value": MockObject(value="!!", key="key"),
"name": "output_3",
"description": "This is a test output with a mock object.",
"format": "{{ output_3 }}",
},
],
test_output=[
("output", "Hello, World!!!"),
("output", "42"),
("output", MockObject(value="!!", key="key")),
],
categories={BlockCategory.OUTPUT, BlockCategory.BASIC},
block_type=BlockType.OUTPUT,
static_output=True,
)
def run(self, input_data: Input, *args, **kwargs) -> BlockOutput:
"""
Attempts to format the recorded_value using the fmt_string if provided.
If formatting fails or no fmt_string is given, returns the original recorded_value.
"""
if input_data.format:
try:
yield "output", formatter.format_string(
input_data.format, {input_data.name: input_data.value}
)
except Exception as e:
yield "output", f"Error: {e}, {input_data.value}"
else:
yield "output", input_data.value
yield "name", input_data.name
class AgentShortTextInputBlock(AgentInputBlock):
class Input(AgentInputBlock.Input):
value: Optional[str] = SchemaField(
description="Short text input.",
default=None,
advanced=False,
title="Default Value",
json_schema_extra={"format": "short-text"},
)
class Output(AgentInputBlock.Output):
result: str = SchemaField(description="Short text result.")
def __init__(self):
super().__init__(
id="7fcd3bcb-8e1b-4e69-903d-32d3d4a92158",
description="Block for short text input (single-line).",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentShortTextInputBlock.Input,
output_schema=AgentShortTextInputBlock.Output,
test_input=[
{
"value": "Hello",
"name": "short_text_1",
"description": "Short text example 1",
"placeholder_values": [],
"limit_to_placeholder_values": False,
},
{
"value": "Quick test",
"name": "short_text_2",
"description": "Short text example 2",
"placeholder_values": ["Quick test", "Another option"],
"limit_to_placeholder_values": True,
},
],
test_output=[
("result", "Hello"),
("result", "Quick test"),
],
)
class AgentLongTextInputBlock(AgentInputBlock):
class Input(AgentInputBlock.Input):
value: Optional[str] = SchemaField(
description="Long text input (potentially multi-line).",
default=None,
advanced=False,
title="Default Value",
json_schema_extra={"format": "long-text"},
)
class Output(AgentInputBlock.Output):
result: str = SchemaField(description="Long text result.")
def __init__(self):
super().__init__(
id="90a56ffb-7024-4b2b-ab50-e26c5e5ab8ba",
description="Block for long text input (multi-line).",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentLongTextInputBlock.Input,
output_schema=AgentLongTextInputBlock.Output,
test_input=[
{
"value": "Lorem ipsum dolor sit amet...",
"name": "long_text_1",
"description": "Long text example 1",
"placeholder_values": [],
"limit_to_placeholder_values": False,
},
{
"value": "Another multiline text input.",
"name": "long_text_2",
"description": "Long text example 2",
"placeholder_values": ["Another multiline text input."],
"limit_to_placeholder_values": True,
},
],
test_output=[
("result", "Lorem ipsum dolor sit amet..."),
("result", "Another multiline text input."),
],
)
class AgentNumberInputBlock(AgentInputBlock):
class Input(AgentInputBlock.Input):
value: Optional[int] = SchemaField(
description="Number input.",
default=None,
advanced=False,
title="Default Value",
)
class Output(AgentInputBlock.Output):
result: int = SchemaField(description="Number result.")
def __init__(self):
super().__init__(
id="96dae2bb-97a2-41c2-bd2f-13a3b5a8ea98",
description="Block for number input.",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentNumberInputBlock.Input,
output_schema=AgentNumberInputBlock.Output,
test_input=[
{
"value": 42,
"name": "number_input_1",
"description": "Number example 1",
"placeholder_values": [],
"limit_to_placeholder_values": False,
},
{
"value": 314,
"name": "number_input_2",
"description": "Number example 2",
"placeholder_values": [314, 2718],
"limit_to_placeholder_values": True,
},
],
test_output=[
("result", 42),
("result", 314),
],
)
class AgentDateInputBlock(AgentInputBlock):
class Input(AgentInputBlock.Input):
value: Optional[date] = SchemaField(
description="Date input (YYYY-MM-DD).",
default=None,
advanced=False,
title="Default Value",
)
class Output(AgentInputBlock.Output):
result: date = SchemaField(description="Date result.")
def __init__(self):
super().__init__(
id="7e198b09-4994-47db-8b4d-952d98241817",
description="Block for date input.",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentDateInputBlock.Input,
output_schema=AgentDateInputBlock.Output,
test_input=[
{
# If your system can parse JSON date strings to date objects
"value": str(date(2025, 3, 19)),
"name": "date_input_1",
"description": "Example date input 1",
},
{
"value": str(date(2023, 12, 31)),
"name": "date_input_2",
"description": "Example date input 2",
},
],
test_output=[
("result", date(2025, 3, 19)),
("result", date(2023, 12, 31)),
],
)
class AgentTimeInputBlock(AgentInputBlock):
class Input(AgentInputBlock.Input):
value: Optional[time] = SchemaField(
description="Time input (HH:MM:SS).",
default=None,
advanced=False,
title="Default Value",
)
class Output(AgentInputBlock.Output):
result: time = SchemaField(description="Time result.")
def __init__(self):
super().__init__(
id="2a1c757e-86cf-4c7e-aacf-060dc382e434",
description="Block for time input.",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentTimeInputBlock.Input,
output_schema=AgentTimeInputBlock.Output,
test_input=[
{
"value": str(time(9, 30, 0)),
"name": "time_input_1",
"description": "Time example 1",
},
{
"value": str(time(23, 59, 59)),
"name": "time_input_2",
"description": "Time example 2",
},
],
test_output=[
("result", time(9, 30, 0)),
("result", time(23, 59, 59)),
],
)
class AgentFileInputBlock(AgentInputBlock):
"""
A simplified file-upload block. In real usage, you might have a custom
file type or handle binary data. Here, we'll store a string path as the example.
"""
class Input(AgentInputBlock.Input):
value: Optional[MediaFile] = SchemaField(
description="Path or reference to an uploaded file.",
default=None,
advanced=False,
title="Default Value",
)
class Output(AgentInputBlock.Output):
result: str = SchemaField(description="File reference/path result.")
def __init__(self):
super().__init__(
id="95ead23f-8283-4654-aef3-10c053b74a31",
description="Block for file upload input (string path for example).",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentFileInputBlock.Input,
output_schema=AgentFileInputBlock.Output,
test_input=[
{
"value": "data:image/png;base64,MQ==",
"name": "file_upload_1",
"description": "Example file upload 1",
},
],
test_output=[
("result", str),
],
)
def run(
self,
input_data: Input,
*,
graph_exec_id: str,
**kwargs,
) -> BlockOutput:
if not input_data.value:
return
file_path = store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.value,
return_content=False,
)
yield "result", file_path
class AgentDropdownInputBlock(AgentInputBlock):
"""
A specialized text input block that relies on placeholder_values +
limit_to_placeholder_values to present a dropdown.
"""
class Input(AgentInputBlock.Input):
value: Optional[str] = SchemaField(
description="Text selected from a dropdown.",
default=None,
advanced=False,
title="Default Value",
)
placeholder_values: list = SchemaField(
description="Possible values for the dropdown.",
default=[],
advanced=False,
title="Dropdown Options",
)
limit_to_placeholder_values: bool = SchemaField(
description="Whether the selection is limited to placeholder values.",
default=True,
)
class Output(AgentInputBlock.Output):
result: str = SchemaField(description="Selected dropdown value.")
def __init__(self):
super().__init__(
id="655d6fdf-a334-421c-b733-520549c07cd1",
description="Block for dropdown text selection.",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentDropdownInputBlock.Input,
output_schema=AgentDropdownInputBlock.Output,
test_input=[
{
"value": "Option A",
"name": "dropdown_1",
"placeholder_values": ["Option A", "Option B", "Option C"],
"limit_to_placeholder_values": True,
"description": "Dropdown example 1",
},
{
"value": "Option C",
"name": "dropdown_2",
"placeholder_values": ["Option A", "Option B", "Option C"],
"limit_to_placeholder_values": True,
"description": "Dropdown example 2",
},
],
test_output=[
("result", "Option A"),
("result", "Option C"),
],
)
class AgentToggleInputBlock(AgentInputBlock):
class Input(AgentInputBlock.Input):
value: bool = SchemaField(
description="Boolean toggle input.",
default=False,
advanced=False,
title="Default Value",
)
class Output(AgentInputBlock.Output):
result: bool = SchemaField(description="Boolean toggle result.")
def __init__(self):
super().__init__(
id="cbf36ab5-df4a-43b6-8a7f-f7ed8652116e",
description="Block for boolean toggle input.",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentToggleInputBlock.Input,
output_schema=AgentToggleInputBlock.Output,
test_input=[
{
"value": True,
"name": "toggle_1",
"description": "Toggle example 1",
},
{
"value": False,
"name": "toggle_2",
"description": "Toggle example 2",
},
],
test_output=[
("result", True),
("result", False),
],
)

View File

@@ -142,7 +142,9 @@ class ScreenshotWebPageBlock(Block):
return {
"image": store_media_file(
graph_exec_id=graph_exec_id,
file=f"data:image/{format.value};base64,{b64encode(response.content).decode('utf-8')}",
file=MediaFile(
f"data:image/{format.value};base64,{b64encode(response.content).decode('utf-8')}"
),
return_content=True,
)
}

View File

@@ -8,6 +8,7 @@ from backend.data.block import (
BlockWebhookConfig,
)
from backend.data.model import SchemaField
from backend.integrations.providers import ProviderName
from backend.util import settings
from backend.util.settings import AppEnvironment, BehaveAs
@@ -82,7 +83,7 @@ class Slant3DOrderWebhookBlock(Slant3DTriggerBase, Block):
input_schema=self.Input,
output_schema=self.Output,
webhook_config=BlockWebhookConfig(
provider="slant3d",
provider=ProviderName.SLANT3D,
webhook_type="orders", # Only one type for now
resource_format="", # No resource format needed
event_filter_input="events",

View File

@@ -20,6 +20,7 @@ from prisma.models import AgentBlock
from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.settings import Config
@@ -225,7 +226,7 @@ class BlockManualWebhookConfig(BaseModel):
the user has to manually set up the webhook at the provider.
"""
provider: str
provider: ProviderName
"""The service provider that the webhook connects to"""
webhook_type: str
@@ -461,9 +462,9 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
def get_blocks() -> dict[str, Type[Block]]:
from backend.blocks import AVAILABLE_BLOCKS # noqa: E402
from backend.blocks import load_all_blocks
return AVAILABLE_BLOCKS
return load_all_blocks()
async def initialize_blocks() -> None:

View File

@@ -15,14 +15,11 @@ from prisma.enums import (
from prisma.errors import UniqueViolationError
from prisma.models import CreditRefundRequest, CreditTransaction, User
from prisma.types import CreditTransactionCreateInput, CreditTransactionWhereInput
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_exponential
from backend.data import db
from backend.data.block import Block, BlockInput, get_block
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.cost import BlockCost, BlockCostType
from backend.data.execution import NodeExecutionEntry
from backend.data.cost import BlockCost
from backend.data.model import (
AutoTopUpConfig,
RefundRequest,
@@ -31,6 +28,7 @@ from backend.data.model import (
)
from backend.data.notifications import NotificationEventDTO, RefundRequestData
from backend.data.user import get_user_by_id
from backend.executor.utils import UsageTransactionMetadata
from backend.notifications import NotificationManager
from backend.util.exceptions import InsufficientBalanceError
from backend.util.service import get_service_client
@@ -39,6 +37,7 @@ from backend.util.settings import Settings
settings = Settings()
stripe.api_key = settings.secrets.stripe_api_key
logger = logging.getLogger(__name__)
base_url = settings.config.frontend_base_url or settings.config.platform_base_url
class UserCreditBase(ABC):
@@ -90,20 +89,20 @@ class UserCreditBase(ABC):
@abstractmethod
async def spend_credits(
self,
entry: NodeExecutionEntry,
data_size: float,
run_time: float,
user_id: str,
cost: int,
metadata: UsageTransactionMetadata,
) -> int:
"""
Spend the credits for the user based on the block usage.
Spend the credits for the user based on the cost.
Args:
entry (NodeExecutionEntry): The node execution identifiers & data.
data_size (float): The size of the data being processed.
run_time (float): The time taken to run the block.
user_id (str): The user ID.
cost (int): The cost to spend.
metadata (UsageTransactionMetadata): The metadata of the transaction.
Returns:
int: amount of credit spent
int: The remaining balance.
"""
pass
@@ -185,6 +184,14 @@ class UserCreditBase(ABC):
"""
pass
@staticmethod
async def create_billing_portal_session(user_id: str) -> str:
session = stripe.billing_portal.Session.create(
customer=await get_stripe_customer_id(user_id),
return_url=base_url + "/profile/credits",
)
return session.url
@staticmethod
def time_now() -> datetime:
return datetime.now(timezone.utc)
@@ -339,16 +346,6 @@ class UserCreditBase(ABC):
return user_balance + amount, tx.transactionKey
class UsageTransactionMetadata(BaseModel):
graph_exec_id: str | None = None
graph_id: str | None = None
node_id: str | None = None
node_exec_id: str | None = None
block_id: str | None = None
block: str | None = None
input: BlockInput | None = None
class UserCredit(UserCreditBase):
@thread_cached
def notification_client(self) -> NotificationManager:
@@ -369,89 +366,21 @@ class UserCredit(UserCreditBase):
)
)
def _block_usage_cost(
self,
block: Block,
input_data: BlockInput,
data_size: float,
run_time: float,
) -> tuple[int, BlockInput]:
block_costs = BLOCK_COSTS.get(type(block))
if not block_costs:
return 0, {}
for block_cost in block_costs:
if not self._is_cost_filter_match(block_cost.cost_filter, input_data):
continue
if block_cost.cost_type == BlockCostType.RUN:
return block_cost.cost_amount, block_cost.cost_filter
if block_cost.cost_type == BlockCostType.SECOND:
return (
int(run_time * block_cost.cost_amount),
block_cost.cost_filter,
)
if block_cost.cost_type == BlockCostType.BYTE:
return (
int(data_size * block_cost.cost_amount),
block_cost.cost_filter,
)
return 0, {}
def _is_cost_filter_match(
self, cost_filter: BlockInput, input_data: BlockInput
) -> bool:
"""
Filter rules:
- If cost_filter is an object, then check if cost_filter is the subset of input_data
- Otherwise, check if cost_filter is equal to input_data.
- Undefined, null, and empty string are considered as equal.
"""
if not isinstance(cost_filter, dict) or not isinstance(input_data, dict):
return cost_filter == input_data
return all(
(not input_data.get(k) and not v)
or (input_data.get(k) and self._is_cost_filter_match(v, input_data[k]))
for k, v in cost_filter.items()
)
async def spend_credits(
self,
entry: NodeExecutionEntry,
data_size: float,
run_time: float,
user_id: str,
cost: int,
metadata: UsageTransactionMetadata,
) -> int:
block = get_block(entry.block_id)
if not block:
raise ValueError(f"Block not found: {entry.block_id}")
cost, matching_filter = self._block_usage_cost(
block=block, input_data=entry.data, data_size=data_size, run_time=run_time
)
if cost == 0:
return 0
balance, _ = await self._add_transaction(
user_id=entry.user_id,
user_id=user_id,
amount=-cost,
transaction_type=CreditTransactionType.USAGE,
metadata=Json(
UsageTransactionMetadata(
graph_exec_id=entry.graph_exec_id,
graph_id=entry.graph_id,
node_id=entry.node_id,
node_exec_id=entry.node_exec_id,
block_id=entry.block_id,
block=block.name,
input=matching_filter,
).model_dump()
),
metadata=Json(metadata.model_dump()),
)
user_id = entry.user_id
# Auto top-up if balance is below threshold.
auto_top_up = await get_auto_top_up(user_id)
@@ -461,7 +390,7 @@ class UserCredit(UserCreditBase):
user_id=user_id,
amount=auto_top_up.amount,
# Avoid multiple auto top-ups within the same graph execution.
key=f"AUTO-TOP-UP-{user_id}-{entry.graph_exec_id}",
key=f"AUTO-TOP-UP-{user_id}-{metadata.graph_exec_id}",
ceiling_balance=auto_top_up.threshold,
)
except Exception as e:
@@ -470,7 +399,7 @@ class UserCredit(UserCreditBase):
f"Auto top-up failed for user {user_id}, balance: {balance}, amount: {auto_top_up.amount}, error: {e}"
)
return cost
return balance
async def top_up_credits(self, user_id: str, amount: int):
await self._top_up_credits(user_id, amount)
@@ -765,10 +694,8 @@ class UserCredit(UserCreditBase):
ui_mode="hosted",
payment_intent_data={"setup_future_usage": "off_session"},
saved_payment_method_options={"payment_method_save": "enabled"},
success_url=settings.config.frontend_base_url
+ "/profile/credits?topup=success",
cancel_url=settings.config.frontend_base_url
+ "/profile/credits?topup=cancel",
success_url=base_url + "/profile/credits?topup=success",
cancel_url=base_url + "/profile/credits?topup=cancel",
allow_promotion_codes=True,
)

View File

@@ -1,7 +1,16 @@
from collections import defaultdict
from datetime import datetime, timezone
from multiprocessing import Manager
from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator, Generic, Type, TypeVar
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Generator,
Generic,
Optional,
Type,
TypeVar,
)
from prisma import Json
from prisma.enums import AgentExecutionStatus
@@ -10,6 +19,7 @@ from prisma.models import (
AgentNodeExecution,
AgentNodeExecutionInputOutput,
)
from prisma.types import AgentNodeExecutionUpdateInput, AgentNodeExecutionWhereInput
from pydantic import BaseModel
from backend.data.block import BlockData, BlockInput, CompletedBlockOutput
@@ -68,6 +78,7 @@ class ExecutionQueue(Generic[T]):
class ExecutionResult(BaseModel):
user_id: str
graph_id: str
graph_version: int
graph_exec_id: str
@@ -83,27 +94,28 @@ class ExecutionResult(BaseModel):
end_time: datetime | None
@staticmethod
def from_graph(graph: AgentGraphExecution):
def from_graph(graph_exec: AgentGraphExecution):
return ExecutionResult(
graph_id=graph.agentGraphId,
graph_version=graph.agentGraphVersion,
graph_exec_id=graph.id,
user_id=graph_exec.userId,
graph_id=graph_exec.agentGraphId,
graph_version=graph_exec.agentGraphVersion,
graph_exec_id=graph_exec.id,
node_exec_id="",
node_id="",
block_id="",
status=graph.executionStatus,
status=graph_exec.executionStatus,
# TODO: Populate input_data & output_data from AgentNodeExecutions
# Input & Output comes AgentInputBlock & AgentOutputBlock.
input_data={},
output_data={},
add_time=graph.createdAt,
queue_time=graph.createdAt,
start_time=graph.startedAt,
end_time=graph.updatedAt,
add_time=graph_exec.createdAt,
queue_time=graph_exec.createdAt,
start_time=graph_exec.startedAt,
end_time=graph_exec.updatedAt,
)
@staticmethod
def from_db(execution: AgentNodeExecution):
def from_db(execution: AgentNodeExecution, user_id: Optional[str] = None):
if execution.executionData:
# Execution that has been queued for execution will persist its data.
input_data = type.convert(execution.executionData, dict[str, Any])
@@ -118,8 +130,15 @@ class ExecutionResult(BaseModel):
output_data[data.name].append(type.convert(data.data, Type[Any]))
graph_execution: AgentGraphExecution | None = execution.AgentGraphExecution
if graph_execution:
user_id = graph_execution.userId
elif not user_id:
raise ValueError(
"AgentGraphExecution must be included or user_id passed in"
)
return ExecutionResult(
user_id=user_id,
graph_id=graph_execution.agentGraphId if graph_execution else "",
graph_version=graph_execution.agentGraphVersion if graph_execution else 0,
graph_exec_id=execution.agentGraphExecutionId,
@@ -160,7 +179,8 @@ async def create_graph_execution(
"create": [ # type: ignore
{
"agentNodeId": node_id,
"executionStatus": ExecutionStatus.INCOMPLETE,
"executionStatus": ExecutionStatus.QUEUED,
"queuedTime": datetime.now(tz=timezone.utc),
"Input": {
"create": [
{"name": name, "data": Json(data)}
@@ -178,7 +198,7 @@ async def create_graph_execution(
)
return result.id, [
ExecutionResult.from_db(execution)
ExecutionResult.from_db(execution, result.userId)
for execution in result.AgentNodeExecutions or []
]
@@ -285,13 +305,19 @@ async def update_graph_execution_start_time(graph_exec_id: str) -> ExecutionResu
async def update_graph_execution_stats(
graph_exec_id: str,
status: ExecutionStatus,
stats: GraphExecutionStats,
stats: GraphExecutionStats | None = None,
) -> ExecutionResult:
data = stats.model_dump()
if isinstance(data["error"], Exception):
data = stats.model_dump() if stats else {}
if isinstance(data.get("error"), Exception):
data["error"] = str(data["error"])
res = await AgentGraphExecution.prisma().update(
where={"id": graph_exec_id},
where={
"id": graph_exec_id,
"OR": [
{"executionStatus": ExecutionStatus.RUNNING},
{"executionStatus": ExecutionStatus.QUEUED},
],
},
data={
"executionStatus": status,
"stats": Json(data),
@@ -313,6 +339,17 @@ async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionSta
)
async def update_execution_status_batch(
node_exec_ids: list[str],
status: ExecutionStatus,
stats: dict[str, Any] | None = None,
):
await AgentNodeExecution.prisma().update_many(
where={"id": {"in": node_exec_ids}},
data=_get_update_status_data(status, None, stats),
)
async def update_execution_status(
node_exec_id: str,
status: ExecutionStatus,
@@ -322,20 +359,9 @@ async def update_execution_status(
if status == ExecutionStatus.QUEUED and execution_data is None:
raise ValueError("Execution data must be provided when queuing an execution.")
now = datetime.now(tz=timezone.utc)
data = {
**({"executionStatus": status}),
**({"queuedTime": now} if status == ExecutionStatus.QUEUED else {}),
**({"startedTime": now} if status == ExecutionStatus.RUNNING else {}),
**({"endedTime": now} if status == ExecutionStatus.FAILED else {}),
**({"endedTime": now} if status == ExecutionStatus.COMPLETED else {}),
**({"executionData": Json(execution_data)} if execution_data else {}),
**({"stats": Json(stats)} if stats else {}),
}
res = await AgentNodeExecution.prisma().update(
where={"id": node_exec_id},
data=data, # type: ignore
data=_get_update_status_data(status, execution_data, stats),
include=EXECUTION_RESULT_INCLUDE,
)
if not res:
@@ -344,6 +370,29 @@ async def update_execution_status(
return ExecutionResult.from_db(res)
def _get_update_status_data(
status: ExecutionStatus,
execution_data: BlockInput | None = None,
stats: dict[str, Any] | None = None,
) -> AgentNodeExecutionUpdateInput:
now = datetime.now(tz=timezone.utc)
update_data: AgentNodeExecutionUpdateInput = {"executionStatus": status}
if status == ExecutionStatus.QUEUED:
update_data["queuedTime"] = now
elif status == ExecutionStatus.RUNNING:
update_data["startedTime"] = now
elif status in (ExecutionStatus.FAILED, ExecutionStatus.COMPLETED):
update_data["endedTime"] = now
if execution_data:
update_data["executionData"] = Json(execution_data)
if stats:
update_data["stats"] = Json(stats)
return update_data
async def delete_execution(
graph_exec_id: str, user_id: str, soft_delete: bool = True
) -> None:
@@ -361,41 +410,29 @@ async def delete_execution(
)
async def get_execution_results(graph_exec_id: str) -> list[ExecutionResult]:
async def get_execution_results(
graph_exec_id: str,
block_ids: list[str] | None = None,
statuses: list[ExecutionStatus] | None = None,
limit: int | None = None,
) -> list[ExecutionResult]:
where_clause: AgentNodeExecutionWhereInput = {
"agentGraphExecutionId": graph_exec_id,
}
if block_ids:
where_clause["AgentNode"] = {"is": {"agentBlockId": {"in": block_ids}}}
if statuses:
where_clause["OR"] = [{"executionStatus": status} for status in statuses]
executions = await AgentNodeExecution.prisma().find_many(
where={"agentGraphExecutionId": graph_exec_id},
where=where_clause,
include=EXECUTION_RESULT_INCLUDE,
order=[
{"queuedTime": "asc"},
{"addedTime": "asc"}, # Fallback: Incomplete execs has no queuedTime.
],
take=limit,
)
res = [ExecutionResult.from_db(execution) for execution in executions]
return res
async def get_executions_in_timerange(
user_id: str, start_time: str, end_time: str
) -> list[ExecutionResult]:
try:
executions = await AgentGraphExecution.prisma().find_many(
where={
"startedAt": {
"gte": datetime.fromisoformat(start_time),
"lte": datetime.fromisoformat(end_time),
},
"userId": user_id,
"isDeleted": False,
},
include=GRAPH_EXECUTION_INCLUDE,
)
return [ExecutionResult.from_graph(execution) for execution in executions]
except Exception as e:
raise DatabaseError(
f"Failed to get executions in timerange {start_time} to {end_time} for user {user_id}: {e}"
) from e
LIST_SPLIT = "_$_"
DICT_SPLIT = "_#_"
OBJC_SPLIT = "_@_"
@@ -550,7 +587,10 @@ async def get_output_from_links(
"agentGraphExecutionId": graph_eid,
"executionStatus": {"not": ExecutionStatus.INCOMPLETE}, # type: ignore
},
order={"queuedTime": "asc"},
order=[
{"queuedTime": "asc"},
{"addedTime": "desc"},
],
include=EXECUTION_RESULT_INCLUDE,
)

View File

@@ -1,4 +1,3 @@
import asyncio
import logging
import uuid
from collections import defaultdict
@@ -7,6 +6,7 @@ from typing import Any, Literal, Optional, Type
import prisma
from prisma import Json
from prisma.enums import SubmissionStatus
from prisma.models import (
AgentGraph,
AgentGraphExecution,
@@ -14,17 +14,17 @@ from prisma.models import (
AgentNodeLink,
StoreListingVersion,
)
from prisma.types import AgentGraphWhereInput
from prisma.types import AgentGraphExecutionWhereInput, AgentGraphWhereInput
from pydantic.fields import Field, computed_field
from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.basic import AgentInputBlock, AgentOutputBlock
from backend.util import type
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.util import type as type_utils
from .block import BlockInput, BlockType, get_block, get_blocks
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
from .db import BaseDbModel, transaction
from .execution import ExecutionResult, ExecutionStatus
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE, GRAPH_EXECUTION_INCLUDE
from .integrations import Webhook
logger = logging.getLogger(__name__)
@@ -71,13 +71,20 @@ class NodeModel(Node):
webhook: Optional[Webhook] = None
@property
def block(self) -> Block[BlockSchema, BlockSchema]:
block = get_block(self.block_id)
if not block:
raise ValueError(f"Block #{self.block_id} does not exist")
return block
@staticmethod
def from_db(node: AgentNode) -> "NodeModel":
def from_db(node: AgentNode, for_export: bool = False) -> "NodeModel":
obj = NodeModel(
id=node.id,
block_id=node.agentBlockId,
input_default=type.convert(node.constantInput, dict[str, Any]),
metadata=type.convert(node.metadata, dict[str, Any]),
input_default=type_utils.convert(node.constantInput, dict[str, Any]),
metadata=type_utils.convert(node.metadata, dict[str, Any]),
graph_id=node.agentGraphId,
graph_version=node.agentGraphVersion,
webhook_id=node.webhookId,
@@ -85,6 +92,8 @@ class NodeModel(Node):
)
obj.input_links = [Link.from_db(link) for link in node.Input or []]
obj.output_links = [Link.from_db(link) for link in node.Output or []]
if for_export:
return obj.stripped_for_export()
return obj
def is_triggered_by_event_type(self, event_type: str) -> bool:
@@ -103,6 +112,51 @@ class NodeModel(Node):
if event_filter[k] is True
]
def stripped_for_export(self) -> "NodeModel":
"""
Returns a copy of the node model, stripped of any non-transferable properties
"""
stripped_node = self.model_copy(deep=True)
# Remove credentials from node input
if stripped_node.input_default:
stripped_node.input_default = NodeModel._filter_secrets_from_node_input(
stripped_node.input_default, self.block.input_schema.jsonschema()
)
if (
stripped_node.block.block_type == BlockType.INPUT
and "value" in stripped_node.input_default
):
stripped_node.input_default["value"] = ""
# Remove webhook info
stripped_node.webhook_id = None
stripped_node.webhook = None
return stripped_node
@staticmethod
def _filter_secrets_from_node_input(
input_data: dict[str, Any], schema: dict[str, Any] | None
) -> dict[str, Any]:
sensitive_keys = ["credentials", "api_key", "password", "token", "secret"]
field_schemas = schema.get("properties", {}) if schema else {}
result = {}
for key, value in input_data.items():
field_schema: dict | None = field_schemas.get(key)
if (field_schema and field_schema.get("secret", False)) or any(
sensitive_key in key.lower() for sensitive_key in sensitive_keys
):
# This is a secret value -> filter this key-value pair out
continue
elif isinstance(value, dict):
result[key] = NodeModel._filter_secrets_from_node_input(
value, field_schema
)
else:
result[key] = value
return result
# Fix 2-way reference Node <-> Webhook
Webhook.model_rebuild()
@@ -129,7 +183,7 @@ class GraphExecutionMeta(BaseDbModel):
total_run_time = duration
try:
stats = type.convert(_graph_exec.stats or {}, dict[str, Any])
stats = type_utils.convert(_graph_exec.stats or {}, dict[str, Any])
except ValueError:
stats = {}
@@ -163,29 +217,41 @@ class GraphExecution(GraphExecutionMeta):
graph_exec = GraphExecutionMeta.from_db(_graph_exec)
node_executions = [
ExecutionResult.from_db(ne) for ne in _graph_exec.AgentNodeExecutions
]
node_executions = sorted(
[
ExecutionResult.from_db(ne, _graph_exec.userId)
for ne in _graph_exec.AgentNodeExecutions
],
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
)
inputs = {
**{
# inputs from Agent Input Blocks
exec.input_data["name"]: exec.input_data["value"]
exec.input_data["name"]: exec.input_data.get("value")
for exec in node_executions
if exec.block_id == _INPUT_BLOCK_ID
if (
(block := get_block(exec.block_id))
and block.block_type == BlockType.INPUT
)
},
**{
# input from webhook-triggered block
"payload": exec.input_data["payload"]
for exec in node_executions
if (block := get_block(exec.block_id))
and block.block_type in [BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL]
if (
(block := get_block(exec.block_id))
and block.block_type
in [BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL]
)
},
}
outputs: dict[str, list] = defaultdict(list)
for exec in node_executions:
if exec.block_id == _OUTPUT_BLOCK_ID:
if (
block := get_block(exec.block_id)
) and block.block_type == BlockType.OUTPUT:
outputs[exec.input_data["name"]].append(
exec.input_data.get("value", None)
)
@@ -201,10 +267,9 @@ class GraphExecution(GraphExecutionMeta):
)
class Graph(BaseDbModel):
class BaseGraph(BaseDbModel):
version: int = 1
is_active: bool = True
is_template: bool = False
name: str
description: str
nodes: list[Node] = []
@@ -267,6 +332,10 @@ class Graph(BaseDbModel):
}
class Graph(BaseGraph):
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs, only used in export
class GraphModel(Graph):
user_id: str
nodes: list[NodeModel] = [] # type: ignore
@@ -290,31 +359,54 @@ class GraphModel(Graph):
Reassigns all IDs in the graph to new UUIDs.
This method can be used before storing a new graph to the database.
"""
# Reassign Graph ID
id_map = {node.id: str(uuid.uuid4()) for node in self.nodes}
if reassign_graph_id:
self.id = str(uuid.uuid4())
graph_id_map = {
self.id: str(uuid.uuid4()),
**{sub_graph.id: str(uuid.uuid4()) for sub_graph in self.sub_graphs},
}
else:
graph_id_map = {}
self._reassign_ids(self, user_id, graph_id_map)
for sub_graph in self.sub_graphs:
self._reassign_ids(sub_graph, user_id, graph_id_map)
@staticmethod
def _reassign_ids(
graph: BaseGraph,
user_id: str,
graph_id_map: dict[str, str],
):
# Reassign Graph ID
if graph.id in graph_id_map:
graph.id = graph_id_map[graph.id]
# Reassign Node IDs
for node in self.nodes:
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
for node in graph.nodes:
node.id = id_map[node.id]
# Reassign Link IDs
for link in self.links:
for link in graph.links:
link.source_id = id_map[link.source_id]
link.sink_id = id_map[link.sink_id]
# Reassign User IDs for agent blocks
for node in self.nodes:
for node in graph.nodes:
if node.block_id != AgentExecutorBlock().id:
continue
node.input_default["user_id"] = user_id
node.input_default.setdefault("data", {})
self.validate_graph()
if (graph_id := node.input_default.get("graph_id")) in graph_id_map:
node.input_default["graph_id"] = graph_id_map[graph_id]
def validate_graph(self, for_run: bool = False):
self._validate_graph(self, for_run)
for sub_graph in self.sub_graphs:
self._validate_graph(sub_graph, for_run)
@staticmethod
def _validate_graph(graph: BaseGraph, for_run: bool = False):
def sanitize(name):
sanitized_name = name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
if sanitized_name.startswith("tools_^_"):
@@ -326,11 +418,11 @@ class GraphModel(Graph):
agent_nodes = set()
nodes_block = {
node.id: block
for node in self.nodes
for node in graph.nodes
if (block := get_block(node.block_id)) is not None
}
for node in self.nodes:
for node in graph.nodes:
if (block := nodes_block.get(node.id)) is None:
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
@@ -343,11 +435,11 @@ class GraphModel(Graph):
input_links = defaultdict(list)
for link in self.links:
for link in graph.links:
input_links[link.sink_id].append(link)
# Nodes: required fields are filled or connected and dependencies are satisfied
for node in self.nodes:
for node in graph.nodes:
if (block := nodes_block.get(node.id)) is None:
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
@@ -408,7 +500,7 @@ class GraphModel(Graph):
f"Node {block.name} #{node.id}: Field `{field_name}` requires [{', '.join(missing_deps)}] to be set"
)
node_map = {v.id: v for v in self.nodes}
node_map = {v.id: v for v in graph.nodes}
def is_static_output_block(nid: str) -> bool:
bid = node_map[nid].block_id
@@ -416,7 +508,7 @@ class GraphModel(Graph):
return b.static_output if b else False
# Links: links are connected and the connected pin data type are compatible.
for link in self.links:
for link in graph.links:
source = (link.source_id, link.source_name)
sink = (link.sink_id, link.sink_name)
prefix = f"Link {source} <-> {sink}"
@@ -457,18 +549,20 @@ class GraphModel(Graph):
link.is_static = True # Each value block output should be static.
@staticmethod
def from_db(graph: AgentGraph, for_export: bool = False):
def from_db(
graph: AgentGraph,
for_export: bool = False,
sub_graphs: list[AgentGraph] | None = None,
):
return GraphModel(
id=graph.id,
user_id=graph.userId,
user_id=graph.userId if not for_export else "",
version=graph.version,
is_active=graph.isActive,
is_template=graph.isTemplate,
name=graph.name or "",
description=graph.description or "",
nodes=[
NodeModel.from_db(GraphModel._process_node(node, for_export))
for node in graph.AgentNodes or []
NodeModel.from_db(node, for_export) for node in graph.AgentNodes or []
],
links=list(
{
@@ -477,59 +571,12 @@ class GraphModel(Graph):
for link in (node.Input or []) + (node.Output or [])
}
),
sub_graphs=[
GraphModel.from_db(sub_graph, for_export)
for sub_graph in sub_graphs or []
],
)
@staticmethod
def _process_node(node: AgentNode, for_export: bool) -> AgentNode:
if for_export:
# Remove credentials from node input
if node.constantInput:
constant_input = type.convert(node.constantInput, dict[str, Any])
constant_input = GraphModel._hide_node_input_credentials(constant_input)
node.constantInput = Json(constant_input)
# Remove webhook info
node.webhookId = None
node.Webhook = None
return node
@staticmethod
def _hide_node_input_credentials(input_data: dict[str, Any]) -> dict[str, Any]:
sensitive_keys = ["credentials", "api_key", "password", "token", "secret"]
result = {}
for key, value in input_data.items():
if isinstance(value, dict):
result[key] = GraphModel._hide_node_input_credentials(value)
elif isinstance(value, str) and any(
sensitive_key in key.lower() for sensitive_key in sensitive_keys
):
# Skip this key-value pair in the result
continue
else:
result[key] = value
return result
def clean_graph(self):
blocks = [block() for block in get_blocks().values()]
input_blocks = [
node
for node in self.nodes
if next(
(
b
for b in blocks
if b.id == node.block_id and b.block_type == BlockType.INPUT
),
None,
)
]
for node in self.nodes:
if any(input_block.id == node.id for input_block in input_blocks):
node.input_default["value"] = ""
# --------------------- CRUD functions --------------------- #
@@ -559,14 +606,14 @@ async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:
async def get_graphs(
user_id: str,
filter_by: Literal["active", "template"] | None = "active",
filter_by: Literal["active"] | None = "active",
) -> list[GraphModel]:
"""
Retrieves graph metadata objects.
Default behaviour is to get all currently active graphs.
Args:
filter_by: An optional filter to either select templates or active graphs.
filter_by: An optional filter to either select graphs.
user_id: The ID of the user that owns the graph.
Returns:
@@ -576,8 +623,6 @@ async def get_graphs(
if filter_by == "active":
where_clause["isActive"] = True
elif filter_by == "template":
where_clause["isTemplate"] = True
graphs = await AgentGraph.prisma().find_many(
where=where_clause,
@@ -597,18 +642,20 @@ async def get_graphs(
return graph_models
# TODO: move execution stuff to .execution
async def get_graphs_executions(user_id: str) -> list[GraphExecutionMeta]:
executions = await AgentGraphExecution.prisma().find_many(
where={"isDeleted": False, "userId": user_id},
order={"createdAt": "desc"},
)
return [GraphExecutionMeta.from_db(execution) for execution in executions]
async def get_graph_executions(
graph_id: Optional[str] = None,
user_id: Optional[str] = None,
) -> list[GraphExecutionMeta]:
where_filter: AgentGraphExecutionWhereInput = {
"isDeleted": False,
}
if user_id:
where_filter["userId"] = user_id
if graph_id:
where_filter["agentGraphId"] = graph_id
async def get_graph_executions(graph_id: str, user_id: str) -> list[GraphExecutionMeta]:
executions = await AgentGraphExecution.prisma().find_many(
where={"agentGraphId": graph_id, "isDeleted": False, "userId": user_id},
where=where_filter,
order={"createdAt": "desc"},
)
return [GraphExecutionMeta.from_db(execution) for execution in executions]
@@ -623,20 +670,13 @@ async def get_execution_meta(
return GraphExecutionMeta.from_db(execution) if execution else None
async def get_execution(user_id: str, execution_id: str) -> GraphExecution | None:
async def get_execution(
user_id: str,
execution_id: str,
) -> GraphExecution | None:
execution = await AgentGraphExecution.prisma().find_first(
where={"id": execution_id, "isDeleted": False, "userId": user_id},
include={
"AgentNodeExecutions": {
"include": {"AgentNode": True, "Input": True, "Output": True},
"order_by": [
{"queuedTime": "asc"},
{ # Fallback: Incomplete execs has no queuedTime.
"addedTime": "asc"
},
],
},
},
include=GRAPH_EXECUTION_INCLUDE,
)
return GraphExecution.from_db(execution) if execution else None
@@ -664,21 +704,18 @@ async def get_graph_metadata(graph_id: str, version: int | None = None) -> Graph
description=graph.description or "",
version=graph.version,
is_active=graph.isActive,
is_template=graph.isTemplate,
)
async def get_graph(
graph_id: str,
version: int | None = None,
template: bool = False, # note: currently not in use; TODO: remove from DB entirely
user_id: str | None = None,
for_export: bool = False,
) -> GraphModel | None:
"""
Retrieves a graph from the DB.
Defaults to the version with `is_active` if `version` is not passed,
or the latest version with `is_template` if `template=True`.
Defaults to the version with `is_active` if `version` is not passed.
Returns `None` if the record is not found.
"""
@@ -688,8 +725,6 @@ async def get_graph(
if version is not None:
where_clause["version"] = version
elif not template:
where_clause["isActive"] = True
graph = await AgentGraph.prisma().find_first(
where=where_clause,
@@ -706,16 +741,69 @@ async def get_graph(
"agentId": graph_id,
"agentVersion": version or graph.version,
"isDeleted": False,
"StoreListing": {"is": {"isApproved": True}},
"submissionStatus": SubmissionStatus.APPROVED,
}
)
)
):
return None
if for_export:
sub_graphs = await get_sub_graphs(graph)
return GraphModel.from_db(
graph=graph,
sub_graphs=sub_graphs,
for_export=for_export,
)
return GraphModel.from_db(graph, for_export)
async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
"""
Iteratively fetches all sub-graphs of a given graph, and flattens them into a list.
This call involves a DB fetch in batch, breadth-first, per-level of graph depth.
On each DB fetch we will only fetch the sub-graphs that are not already in the list.
"""
sub_graphs = {graph.id: graph}
search_graphs = [graph]
agent_block_id = AgentExecutorBlock().id
while search_graphs:
sub_graph_ids = [
(graph_id, graph_version)
for graph in search_graphs
for node in graph.AgentNodes or []
if (
node.AgentBlock
and node.AgentBlock.id == agent_block_id
and (graph_id := dict(node.constantInput).get("graph_id"))
and (graph_version := dict(node.constantInput).get("graph_version"))
)
]
if not sub_graph_ids:
break
graphs = await AgentGraph.prisma().find_many(
where={
"OR": [
{
"id": graph_id,
"version": graph_version,
"userId": graph.userId, # Ensure the sub-graph is owned by the same user
}
for graph_id, graph_version in sub_graph_ids
] # type: ignore
},
include=AGENT_GRAPH_INCLUDE,
)
search_graphs = [graph for graph in graphs if graph.id not in sub_graphs]
sub_graphs.update({graph.id: graph for graph in search_graphs})
return [g for g in sub_graphs.values() if g.id != graph.id]
async def get_connected_output_nodes(node_id: str) -> list[tuple[Link, Node]]:
links = await AgentNodeLink.prisma().find_many(
where={"agentNodeSourceId": node_id},
@@ -779,50 +867,56 @@ async def create_graph(graph: Graph, user_id: str) -> GraphModel:
async with transaction() as tx:
await __create_graph(tx, graph, user_id)
if created_graph := await get_graph(
graph.id, graph.version, template=graph.is_template, user_id=user_id
):
if created_graph := await get_graph(graph.id, graph.version, user_id=user_id):
return created_graph
raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
async def __create_graph(tx, graph: Graph, user_id: str):
await AgentGraph.prisma(tx).create(
data={
"id": graph.id,
"version": graph.version,
"name": graph.name,
"description": graph.description,
"isTemplate": graph.is_template,
"isActive": graph.is_active,
"userId": user_id,
"AgentNodes": {
"create": [
{
"id": node.id,
"agentBlockId": node.block_id,
"constantInput": Json(node.input_default),
"metadata": Json(node.metadata),
}
for node in graph.nodes
]
},
}
graphs = [graph] + graph.sub_graphs
await AgentGraph.prisma(tx).create_many(
data=[
{
"id": graph.id,
"version": graph.version,
"name": graph.name,
"description": graph.description,
"isActive": graph.is_active,
"userId": user_id,
}
for graph in graphs
]
)
await asyncio.gather(
*[
AgentNodeLink.prisma(tx).create(
{
"id": str(uuid.uuid4()),
"sourceName": link.source_name,
"sinkName": link.sink_name,
"agentNodeSourceId": link.source_id,
"agentNodeSinkId": link.sink_id,
"isStatic": link.is_static,
}
)
await AgentNode.prisma(tx).create_many(
data=[
{
"id": node.id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"agentBlockId": node.block_id,
"constantInput": Json(node.input_default),
"metadata": Json(node.metadata),
"webhookId": node.webhook_id,
}
for graph in graphs
for node in graph.nodes
]
)
await AgentNodeLink.prisma(tx).create_many(
data=[
{
"id": str(uuid.uuid4()),
"sourceName": link.source_name,
"sinkName": link.sink_name,
"agentNodeSourceId": link.source_id,
"agentNodeSinkId": link.sink_id,
"isStatic": link.is_static,
}
for graph in graphs
for link in graph.links
]
)

View File

@@ -18,6 +18,8 @@ EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
"AgentGraphExecution": True,
}
MAX_NODE_EXECUTIONS_FETCH = 1000
GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
"AgentNodeExecutions": {
"include": {
@@ -25,10 +27,17 @@ GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
"Output": True,
"AgentNode": True,
"AgentGraphExecution": True,
}
},
"order_by": [
{"queuedTime": "desc"},
# Fallback: Incomplete execs has no queuedTime.
{"addedTime": "desc"},
],
"take": MAX_NODE_EXECUTIONS_FETCH, # Avoid loading excessive node executions.
}
}
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
}

View File

@@ -144,6 +144,7 @@ def SchemaField(
depends_on: list[str] | None = None,
image_upload: Optional[bool] = None,
image_output: Optional[bool] = None,
json_schema_extra: dict[str, Any] | None = None,
**kwargs,
) -> T:
if default is PydanticUndefined and default_factory is None:
@@ -151,7 +152,7 @@ def SchemaField(
elif advanced is None:
advanced = True
json_extra = {
json_schema_extra = {
k: v
for k, v in {
"placeholder": placeholder,
@@ -161,6 +162,7 @@ def SchemaField(
"depends_on": depends_on,
"image_upload": image_upload,
"image_output": image_output,
**(json_schema_extra or {}),
}.items()
if v is not None
}
@@ -172,7 +174,7 @@ def SchemaField(
title=title,
description=description,
exclude=exclude,
json_schema_extra=json_extra,
json_schema_extra=json_schema_extra,
**kwargs,
) # type: ignore
@@ -413,7 +415,6 @@ class NodeExecutionStats(BaseModel):
error: Optional[Exception | str] = None
walltime: float = 0
cputime: float = 0
cost: float = 0
input_size: int = 0
output_size: int = 0
llm_call_count: int = 0

View File

@@ -1,5 +1,5 @@
import logging
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Annotated, Any, Generic, Optional, TypeVar, Union
@@ -18,7 +18,12 @@ from .db import transaction
logger = logging.getLogger(__name__)
T_co = TypeVar("T_co", bound="BaseNotificationData", covariant=True)
NotificationDataType_co = TypeVar(
"NotificationDataType_co", bound="BaseNotificationData", covariant=True
)
SummaryParamsType_co = TypeVar(
"SummaryParamsType_co", bound="BaseSummaryParams", covariant=True
)
class QueueType(Enum):
@@ -30,7 +35,8 @@ class QueueType(Enum):
class BaseNotificationData(BaseModel):
pass
class Config:
extra = "allow"
class AgentRunData(BaseNotificationData):
@@ -47,6 +53,13 @@ class ZeroBalanceData(BaseNotificationData):
last_transaction_time: datetime
top_up_link: str
@field_validator("last_transaction_time")
@classmethod
def validate_timezone(cls, value: datetime):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
class LowBalanceData(BaseNotificationData):
agent_name: str = Field(..., description="Name of the agent")
@@ -75,6 +88,13 @@ class ContinuousAgentErrorData(BaseNotificationData):
error_time: datetime
attempts: int = Field(..., description="Number of retry attempts made")
@field_validator("start_time", "error_time")
@classmethod
def validate_timezone(cls, value: datetime):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
class BaseSummaryData(BaseNotificationData):
total_credits_used: float
@@ -87,18 +107,53 @@ class BaseSummaryData(BaseNotificationData):
cost_breakdown: dict[str, float]
class BaseSummaryParams(BaseModel):
pass
class DailySummaryParams(BaseSummaryParams):
date: datetime
@field_validator("date")
def validate_timezone(cls, value):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
class WeeklySummaryParams(BaseSummaryParams):
start_date: datetime
end_date: datetime
@field_validator("start_date", "end_date")
def validate_timezone(cls, value):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
class DailySummaryData(BaseSummaryData):
date: datetime
@field_validator("date")
def validate_timezone(cls, value):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
class WeeklySummaryData(BaseSummaryData):
start_date: datetime
end_date: datetime
week_number: int
year: int
@field_validator("start_date", "end_date")
def validate_timezone(cls, value):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
class MonthlySummaryData(BaseSummaryData):
class MonthlySummaryData(BaseNotificationData):
month: int
year: int
@@ -125,6 +180,7 @@ NotificationData = Annotated[
WeeklySummaryData,
DailySummaryData,
RefundRequestData,
BaseSummaryData,
],
Field(discriminator="type"),
]
@@ -134,15 +190,22 @@ class NotificationEventDTO(BaseModel):
user_id: str
type: NotificationType
data: dict
created_at: datetime = Field(default_factory=datetime.now)
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
retry_count: int = 0
class NotificationEventModel(BaseModel, Generic[T_co]):
class SummaryParamsEventDTO(BaseModel):
user_id: str
type: NotificationType
data: T_co
created_at: datetime = Field(default_factory=datetime.now)
data: dict
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
class NotificationEventModel(BaseModel, Generic[NotificationDataType_co]):
user_id: str
type: NotificationType
data: NotificationDataType_co
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
@property
def strategy(self) -> QueueType:
@@ -159,7 +222,14 @@ class NotificationEventModel(BaseModel, Generic[T_co]):
return NotificationTypeOverride(self.type).template
def get_data_type(
class SummaryParamsEventModel(BaseModel, Generic[SummaryParamsType_co]):
user_id: str
type: NotificationType
data: SummaryParamsType_co
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
def get_notif_data_type(
notification_type: NotificationType,
) -> type[BaseNotificationData]:
return {
@@ -176,11 +246,20 @@ def get_data_type(
}[notification_type]
def get_summary_params_type(
notification_type: NotificationType,
) -> type[BaseSummaryParams]:
return {
NotificationType.DAILY_SUMMARY: DailySummaryParams,
NotificationType.WEEKLY_SUMMARY: WeeklySummaryParams,
}[notification_type]
class NotificationBatch(BaseModel):
user_id: str
events: list[NotificationEvent]
strategy: QueueType
last_update: datetime = datetime.now()
last_update: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
class NotificationResult(BaseModel):
@@ -258,12 +337,51 @@ class NotificationPreference(BaseModel):
)
daily_limit: int = 10 # Max emails per day
emails_sent_today: int = 0
last_reset_date: datetime = Field(default_factory=datetime.now)
last_reset_date: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc)
)
class UserNotificationEventDTO(BaseModel):
type: NotificationType
data: dict
created_at: datetime
updated_at: datetime
@staticmethod
def from_db(model: NotificationEvent) -> "UserNotificationEventDTO":
return UserNotificationEventDTO(
type=model.type,
data=dict(model.data),
created_at=model.createdAt,
updated_at=model.updatedAt,
)
class UserNotificationBatchDTO(BaseModel):
user_id: str
type: NotificationType
notifications: list[UserNotificationEventDTO]
created_at: datetime
updated_at: datetime
@staticmethod
def from_db(model: UserNotificationBatch) -> "UserNotificationBatchDTO":
return UserNotificationBatchDTO(
user_id=model.userId,
type=model.type,
notifications=[
UserNotificationEventDTO.from_db(notification)
for notification in model.Notifications or []
],
created_at=model.createdAt,
updated_at=model.updatedAt,
)
def get_batch_delay(notification_type: NotificationType) -> timedelta:
return {
NotificationType.AGENT_RUN: timedelta(minutes=1),
NotificationType.AGENT_RUN: timedelta(minutes=60),
NotificationType.ZERO_BALANCE: timedelta(minutes=60),
NotificationType.LOW_BALANCE: timedelta(minutes=60),
NotificationType.BLOCK_EXECUTION_FAILED: timedelta(minutes=60),
@@ -275,7 +393,7 @@ async def create_or_add_to_user_notification_batch(
user_id: str,
notification_type: NotificationType,
notification_data: NotificationEventModel,
) -> UserNotificationBatch:
) -> UserNotificationBatchDTO:
try:
logger.info(
f"Creating or adding to notification batch for {user_id} with type {notification_type} and data {notification_data}"
@@ -292,7 +410,7 @@ async def create_or_add_to_user_notification_batch(
"type": notification_type,
}
},
include={"notifications": True},
include={"Notifications": True},
)
if not existing_batch:
@@ -309,11 +427,11 @@ async def create_or_add_to_user_notification_batch(
data={
"userId": user_id,
"type": notification_type,
"notifications": {"connect": [{"id": notification_event.id}]},
"Notifications": {"connect": [{"id": notification_event.id}]},
},
include={"notifications": True},
include={"Notifications": True},
)
return resp
return UserNotificationBatchDTO.from_db(resp)
else:
async with transaction() as tx:
notification_event = await tx.notificationevent.create(
@@ -327,15 +445,15 @@ async def create_or_add_to_user_notification_batch(
resp = await tx.usernotificationbatch.update(
where={"id": existing_batch.id},
data={
"notifications": {"connect": [{"id": notification_event.id}]}
"Notifications": {"connect": [{"id": notification_event.id}]}
},
include={"notifications": True},
include={"Notifications": True},
)
if not resp:
raise DatabaseError(
f"Failed to add notification event {notification_event.id} to existing batch {existing_batch.id}"
)
return resp
return UserNotificationBatchDTO.from_db(resp)
except Exception as e:
raise DatabaseError(
f"Failed to create or add to notification batch for user {user_id} and type {notification_type}: {e}"
@@ -345,18 +463,23 @@ async def create_or_add_to_user_notification_batch(
async def get_user_notification_oldest_message_in_batch(
user_id: str,
notification_type: NotificationType,
) -> NotificationEvent | None:
) -> UserNotificationEventDTO | None:
try:
batch = await UserNotificationBatch.prisma().find_first(
where={"userId": user_id, "type": notification_type},
include={"notifications": True},
include={"Notifications": True},
)
if not batch:
return None
if not batch.notifications:
if not batch.Notifications:
return None
sorted_notifications = sorted(batch.notifications, key=lambda x: x.createdAt)
return sorted_notifications[0]
sorted_notifications = sorted(batch.Notifications, key=lambda x: x.createdAt)
return (
UserNotificationEventDTO.from_db(sorted_notifications[0])
if sorted_notifications
else None
)
except Exception as e:
raise DatabaseError(
f"Failed to get user notification last message in batch for user {user_id} and type {notification_type}: {e}"
@@ -391,12 +514,13 @@ async def empty_user_notification_batch(
async def get_user_notification_batch(
user_id: str,
notification_type: NotificationType,
) -> UserNotificationBatch | None:
) -> UserNotificationBatchDTO | None:
try:
return await UserNotificationBatch.prisma().find_first(
batch = await UserNotificationBatch.prisma().find_first(
where={"userId": user_id, "type": notification_type},
include={"notifications": True},
include={"Notifications": True},
)
return UserNotificationBatchDTO.from_db(batch) if batch else None
except Exception as e:
raise DatabaseError(
f"Failed to get user notification batch for user {user_id} and type {notification_type}: {e}"
@@ -405,17 +529,18 @@ async def get_user_notification_batch(
async def get_all_batches_by_type(
notification_type: NotificationType,
) -> list[UserNotificationBatch]:
) -> list[UserNotificationBatchDTO]:
try:
return await UserNotificationBatch.prisma().find_many(
batches = await UserNotificationBatch.prisma().find_many(
where={
"type": notification_type,
"notifications": {
"Notifications": {
"some": {} # Only return batches with at least one notification
},
},
include={"notifications": True},
include={"Notifications": True},
)
return [UserNotificationBatchDTO.from_db(batch) for batch in batches]
except Exception as e:
raise DatabaseError(
f"Failed to get all batches by type {notification_type}: {e}"

View File

@@ -1,15 +1,12 @@
from backend.app import run_processes
from backend.executor import DatabaseManager, ExecutionManager
from backend.executor import ExecutionManager
def main():
"""
Run all the processes required for the AutoGPT-server REST API.
"""
run_processes(
DatabaseManager(),
ExecutionManager(),
)
run_processes(ExecutionManager())
if __name__ == "__main__":

View File

@@ -1,13 +1,13 @@
from backend.data.credit import get_user_credit_model
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
ExecutionResult,
NodeExecutionEntry,
RedisExecutionEventBus,
create_graph_execution,
get_execution_results,
get_incomplete_executions,
get_output_from_links,
update_execution_status,
update_execution_status_batch,
update_graph_execution_start_time,
update_graph_execution_stats,
update_node_execution_stats,
@@ -20,9 +20,20 @@ from backend.data.graph import (
get_graph_metadata,
get_node,
)
from backend.data.notifications import (
create_or_add_to_user_notification_batch,
empty_user_notification_batch,
get_all_batches_by_type,
get_user_notification_batch,
get_user_notification_oldest_message_in_batch,
)
from backend.data.user import (
get_active_user_ids_in_timerange,
get_user_email_by_id,
get_user_email_verification,
get_user_integrations,
get_user_metadata,
get_user_notification_preference,
update_user_integrations,
update_user_metadata,
)
@@ -33,8 +44,10 @@ config = Config()
_user_credit_model = get_user_credit_model()
async def _spend_credits(entry: NodeExecutionEntry) -> int:
return await _user_credit_model.spend_credits(entry, 0, 0)
async def _spend_credits(
user_id: str, cost: int, metadata: UsageTransactionMetadata
) -> int:
return await _user_credit_model.spend_credits(user_id, cost, metadata)
class DatabaseManager(AppService):
@@ -58,6 +71,7 @@ class DatabaseManager(AppService):
get_incomplete_executions = exposed_run_and_wait(get_incomplete_executions)
get_output_from_links = exposed_run_and_wait(get_output_from_links)
update_execution_status = exposed_run_and_wait(update_execution_status)
update_execution_status_batch = exposed_run_and_wait(update_execution_status_batch)
update_graph_execution_start_time = exposed_run_and_wait(
update_graph_execution_start_time
)
@@ -80,3 +94,24 @@ class DatabaseManager(AppService):
update_user_metadata = exposed_run_and_wait(update_user_metadata)
get_user_integrations = exposed_run_and_wait(get_user_integrations)
update_user_integrations = exposed_run_and_wait(update_user_integrations)
# User Comms - async
get_active_user_ids_in_timerange = exposed_run_and_wait(
get_active_user_ids_in_timerange
)
get_user_email_by_id = exposed_run_and_wait(get_user_email_by_id)
get_user_email_verification = exposed_run_and_wait(get_user_email_verification)
get_user_notification_preference = exposed_run_and_wait(
get_user_notification_preference
)
# Notifications - async
create_or_add_to_user_notification_batch = exposed_run_and_wait(
create_or_add_to_user_notification_batch
)
empty_user_notification_batch = exposed_run_and_wait(empty_user_notification_batch)
get_all_batches_by_type = exposed_run_and_wait(get_all_batches_by_type)
get_user_notification_batch = exposed_run_and_wait(get_user_notification_batch)
get_user_notification_oldest_message_in_batch = exposed_run_and_wait(
get_user_notification_oldest_message_in_batch
)

View File

@@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar, cast
from redis.lock import Lock as RedisLock
from backend.blocks.basic import AgentOutputBlock
from backend.blocks.io import AgentOutputBlock
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
AgentRunData,
@@ -48,6 +48,11 @@ from backend.data.execution import (
parse_execution_output,
)
from backend.data.graph import GraphModel, Link, Node
from backend.executor.utils import (
UsageTransactionMetadata,
block_usage_cost,
execution_usage_cost,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util import json
from backend.util.decorator import error_logged, time_measured
@@ -206,11 +211,7 @@ def execute_node(
extra_exec_kwargs[field_name] = credentials
output_size = 0
cost = 0
try:
# Charge the user for the execution before running the block.
cost = db_client.spend_credits(data)
outputs: dict[str, Any] = {}
for output_name, output_data in node_block.execute(
input_data, **extra_exec_kwargs
@@ -266,7 +267,6 @@ def execute_node(
)
execution_stats.input_size = input_size
execution_stats.output_size = output_size
execution_stats.cost = cost
def _enqueue_next_nodes(
@@ -657,6 +657,53 @@ class Executor:
cls._handle_agent_run_notif(graph_exec, exec_stats)
@classmethod
def _charge_usage(
cls,
node_exec: NodeExecutionEntry,
execution_count: int,
execution_stats: GraphExecutionStats,
) -> int:
block = get_block(node_exec.block_id)
if not block:
logger.error(f"Block {node_exec.block_id} not found.")
return execution_count
cost, matching_filter = block_usage_cost(block=block, input_data=node_exec.data)
if cost > 0:
cls.db_client.spend_credits(
user_id=node_exec.user_id,
cost=cost,
metadata=UsageTransactionMetadata(
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
node_exec_id=node_exec.node_exec_id,
node_id=node_exec.node_id,
block_id=node_exec.block_id,
block=block.name,
input=matching_filter,
),
)
execution_stats.cost += cost
cost, execution_count = execution_usage_cost(execution_count)
if cost > 0:
cls.db_client.spend_credits(
user_id=node_exec.user_id,
cost=cost,
metadata=UsageTransactionMetadata(
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
input={
"execution_count": execution_count,
"charge": "Execution Cost",
},
),
)
execution_stats.cost += cost
return execution_count
@classmethod
@time_measured
def _on_graph_execution(
@@ -691,14 +738,10 @@ class Executor:
try:
queue = ExecutionQueue[NodeExecutionEntry]()
for node_exec in graph_exec.start_node_execs:
exec_update = cls.db_client.update_execution_status(
node_exec.node_exec_id, ExecutionStatus.QUEUED, node_exec.data
)
cls.db_client.send_execution_update(exec_update)
queue.add(node_exec)
exec_cost_counter = 0
running_executions: dict[str, AsyncResult] = {}
low_balance_error: Optional[InsufficientBalanceError] = None
def make_exec_callback(exec_data: NodeExecutionEntry):
@@ -708,17 +751,13 @@ class Executor:
if not isinstance(result, NodeExecutionStats):
return
nonlocal exec_stats, low_balance_error
nonlocal exec_stats
exec_stats.node_count += 1
exec_stats.nodes_cputime += result.cputime
exec_stats.nodes_walltime += result.walltime
exec_stats.cost += result.cost
if (err := result.error) and isinstance(err, Exception):
exec_stats.node_error_count += 1
if isinstance(err, InsufficientBalanceError):
low_balance_error = err
return callback
while not queue.empty():
@@ -740,6 +779,30 @@ class Executor:
f"Dispatching node execution {exec_data.node_exec_id} "
f"for node {exec_data.node_id}",
)
try:
exec_cost_counter = cls._charge_usage(
node_exec=exec_data,
execution_count=exec_cost_counter + 1,
execution_stats=exec_stats,
)
except InsufficientBalanceError as error:
exec_id = exec_data.node_exec_id
cls.db_client.upsert_execution_output(exec_id, "error", str(error))
exec_update = cls.db_client.update_execution_status(
exec_id, ExecutionStatus.FAILED
)
cls.db_client.send_execution_update(exec_update)
cls._handle_low_balance_notif(
graph_exec.user_id,
graph_exec.graph_id,
exec_stats,
error,
)
raise
running_executions[exec_data.node_id] = cls.executor.apply_async(
cls.on_node_execution,
(queue, exec_data),
@@ -763,32 +826,24 @@ class Executor:
log_metadata.info(f"Finished graph execution {graph_exec.graph_exec_id}")
if isinstance(low_balance_error, InsufficientBalanceError):
cls._handle_low_balance_notif(
graph_exec.user_id,
graph_exec.graph_id,
exec_stats,
low_balance_error,
)
raise low_balance_error
except Exception as e:
log_metadata.exception(
f"Failed graph execution {graph_exec.graph_exec_id}: {e}"
)
error = e
finally:
if error:
log_metadata.error(
f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
)
execution_status = ExecutionStatus.FAILED
else:
execution_status = ExecutionStatus.COMPLETED
if not cancel.is_set():
finished = True
cancel.set()
cancel_thread.join()
clean_exec_files(graph_exec.graph_exec_id)
return (
exec_stats,
ExecutionStatus.FAILED if error else ExecutionStatus.COMPLETED,
error,
)
return exec_stats, execution_status, error
@classmethod
def _handle_agent_run_notif(
@@ -799,7 +854,10 @@ class Executor:
metadata = cls.db_client.get_graph_metadata(
graph_exec.graph_id, graph_exec.graph_version
)
outputs = cls.db_client.get_execution_results(graph_exec.graph_exec_id)
outputs = cls.db_client.get_execution_results(
graph_exec.graph_exec_id,
block_ids=[AgentOutputBlock().id],
)
named_outputs = [
{
@@ -807,7 +865,6 @@ class Executor:
for key, value in output.output_data.items()
}
for output in outputs
if output.block_id == AgentOutputBlock().id
]
event = NotificationEventDTO(
@@ -1001,29 +1058,36 @@ class ExecutionManager(AppService):
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
"""
if graph_exec_id not in self.active_graph_runs:
raise Exception(
logger.warning(
f"Graph execution #{graph_exec_id} not active/running: "
"possibly already completed/cancelled."
)
else:
future, cancel_event = self.active_graph_runs[graph_exec_id]
if not cancel_event.is_set():
cancel_event.set()
future.result()
future, cancel_event = self.active_graph_runs[graph_exec_id]
if cancel_event.is_set():
return
cancel_event.set()
future.result()
# Update the status of the unfinished node executions
node_execs = self.db_client.get_execution_results(graph_exec_id)
# Update the status of the graph & node executions
self.db_client.update_graph_execution_stats(
graph_exec_id,
ExecutionStatus.TERMINATED,
)
node_execs = self.db_client.get_execution_results(
graph_exec_id=graph_exec_id,
statuses=[
ExecutionStatus.QUEUED,
ExecutionStatus.RUNNING,
ExecutionStatus.INCOMPLETE,
],
)
self.db_client.update_execution_status_batch(
[node_exec.node_exec_id for node_exec in node_execs],
ExecutionStatus.TERMINATED,
)
for node_exec in node_execs:
if node_exec.status not in (
ExecutionStatus.COMPLETED,
ExecutionStatus.FAILED,
):
exec_update = self.db_client.update_execution_status(
node_exec.node_exec_id, ExecutionStatus.TERMINATED
)
self.db_client.send_execution_update(exec_update)
node_exec.status = ExecutionStatus.TERMINATED
self.db_client.send_execution_update(node_exec)
def _validate_node_input_credentials(self, graph: GraphModel, user_id: str):
"""Checks all credentials for all nodes of the graph"""

View File

@@ -5,6 +5,7 @@ from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from apscheduler.events import EVENT_JOB_ERROR, EVENT_JOB_EXECUTED
from apscheduler.job import Job as JobObj
from apscheduler.jobstores.memory import MemoryJobStore
from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
from apscheduler.schedulers.blocking import BlockingScheduler
from apscheduler.triggers.cron import CronTrigger
@@ -93,9 +94,18 @@ def process_existing_batches(**kwargs):
logger.exception(f"Error processing existing batches: {e}")
def process_weekly_summary(**kwargs):
try:
log("Processing weekly summary")
get_notification_client().queue_weekly_summary()
except Exception as e:
logger.exception(f"Error processing weekly summary: {e}")
class Jobstores(Enum):
EXECUTION = "execution"
BATCHED_NOTIFICATIONS = "batched_notifications"
WEEKLY_NOTIFICATIONS = "weekly_notifications"
class ExecutionJobArgs(BaseModel):
@@ -189,6 +199,8 @@ class Scheduler(AppService):
metadata=MetaData(schema=db_schema),
tablename="apscheduler_jobs_batched_notifications",
),
# These don't really need persistence
Jobstores.WEEKLY_NOTIFICATIONS.value: MemoryJobStore(),
}
)
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
@@ -242,6 +254,9 @@ class Scheduler(AppService):
) -> list[ExecutionJobInfo]:
schedules = []
for job in self.scheduler.get_jobs(jobstore=Jobstores.EXECUTION.value):
logger.info(
f"Found job {job.id} with cron schedule {job.trigger} and args {job.kwargs}"
)
job_args = ExecutionJobArgs(**job.kwargs)
if (
job.next_run_time is not None
@@ -271,3 +286,21 @@ class Scheduler(AppService):
)
log(f"Added job {job.id} with cron schedule '{cron}' input data: {data}")
return NotificationJobInfo.from_db(job_args, job)
@expose
def add_weekly_notification_schedule(self, cron: str) -> NotificationJobInfo:
job = self.scheduler.add_job(
process_weekly_summary,
CronTrigger.from_crontab(cron),
kwargs={},
replace_existing=True,
jobstore=Jobstores.WEEKLY_NOTIFICATIONS.value,
)
log(f"Added job {job.id} with cron schedule '{cron}'")
return NotificationJobInfo.from_db(
NotificationJobArgs(
cron=cron, notification_types=[NotificationType.WEEKLY_SUMMARY]
),
job,
)

View File

@@ -0,0 +1,97 @@
from pydantic import BaseModel
from backend.data.block import Block, BlockInput
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.cost import BlockCostType
from backend.util.settings import Config
config = Config()
class UsageTransactionMetadata(BaseModel):
graph_exec_id: str | None = None
graph_id: str | None = None
node_id: str | None = None
node_exec_id: str | None = None
block_id: str | None = None
block: str | None = None
input: BlockInput | None = None
def execution_usage_cost(execution_count: int) -> tuple[int, int]:
"""
Calculate the cost of executing a graph based on the number of executions.
Args:
execution_count: Number of executions
Returns:
Tuple of cost amount and remaining execution count
"""
return (
execution_count
// config.execution_cost_count_threshold
* config.execution_cost_per_threshold,
execution_count % config.execution_cost_count_threshold,
)
def block_usage_cost(
block: Block,
input_data: BlockInput,
data_size: float = 0,
run_time: float = 0,
) -> tuple[int, BlockInput]:
"""
Calculate the cost of using a block based on the input data and the block type.
Args:
block: Block object
input_data: Input data for the block
data_size: Size of the input data in bytes
run_time: Execution time of the block in seconds
Returns:
Tuple of cost amount and cost filter
"""
block_costs = BLOCK_COSTS.get(type(block))
if not block_costs:
return 0, {}
for block_cost in block_costs:
if not _is_cost_filter_match(block_cost.cost_filter, input_data):
continue
if block_cost.cost_type == BlockCostType.RUN:
return block_cost.cost_amount, block_cost.cost_filter
if block_cost.cost_type == BlockCostType.SECOND:
return (
int(run_time * block_cost.cost_amount),
block_cost.cost_filter,
)
if block_cost.cost_type == BlockCostType.BYTE:
return (
int(data_size * block_cost.cost_amount),
block_cost.cost_filter,
)
return 0, {}
def _is_cost_filter_match(cost_filter: BlockInput, input_data: BlockInput) -> bool:
"""
Filter rules:
- If cost_filter is an object, then check if cost_filter is the subset of input_data
- Otherwise, check if cost_filter is equal to input_data.
- Undefined, null, and empty string are considered as equal.
"""
if not isinstance(cost_filter, dict) or not isinstance(input_data, dict):
return cost_filter == input_data
return all(
(not input_data.get(k) and not v)
or (input_data.get(k) and _is_cost_filter_match(v, input_data[k]))
for k, v in cost_filter.items()
)

View File

@@ -1,22 +1,43 @@
from typing import TYPE_CHECKING
from .compass import CompassWebhookManager
from .github import GithubWebhooksManager
from .slant3d import Slant3DWebhooksManager
if TYPE_CHECKING:
from ..providers import ProviderName
from ._base import BaseWebhooksManager
# --8<-- [start:WEBHOOK_MANAGERS_BY_NAME]
WEBHOOK_MANAGERS_BY_NAME: dict["ProviderName", type["BaseWebhooksManager"]] = {
handler.PROVIDER_NAME: handler
for handler in [
CompassWebhookManager,
GithubWebhooksManager,
Slant3DWebhooksManager,
]
}
# --8<-- [end:WEBHOOK_MANAGERS_BY_NAME]
_WEBHOOK_MANAGERS: dict["ProviderName", type["BaseWebhooksManager"]] = {}
__all__ = ["WEBHOOK_MANAGERS_BY_NAME"]
# --8<-- [start:load_webhook_managers]
def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]:
if _WEBHOOK_MANAGERS:
return _WEBHOOK_MANAGERS
from .compass import CompassWebhookManager
from .github import GithubWebhooksManager
from .slant3d import Slant3DWebhooksManager
_WEBHOOK_MANAGERS.update(
{
handler.PROVIDER_NAME: handler
for handler in [
CompassWebhookManager,
GithubWebhooksManager,
Slant3DWebhooksManager,
]
}
)
return _WEBHOOK_MANAGERS
# --8<-- [end:load_webhook_managers]
def get_webhook_manager(provider_name: "ProviderName") -> "BaseWebhooksManager":
return load_webhook_managers()[provider_name]()
def supports_webhooks(provider_name: "ProviderName") -> bool:
return provider_name in load_webhook_managers()
__all__ = ["get_webhook_manager", "supports_webhooks"]

View File

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Callable, Optional, cast
from backend.data.block import BlockSchema, BlockWebhookConfig, get_block
from backend.data.graph import set_node_webhook
from backend.integrations.webhooks import WEBHOOK_MANAGERS_BY_NAME
from backend.integrations.webhooks import get_webhook_manager, supports_webhooks
if TYPE_CHECKING:
from backend.data.graph import GraphModel, NodeModel
@@ -123,7 +123,7 @@ async def on_node_activate(
return node
provider = block.webhook_config.provider
if provider not in WEBHOOK_MANAGERS_BY_NAME:
if not supports_webhooks(provider):
raise ValueError(
f"Block #{block.id} has webhook_config for provider {provider} "
"which does not support webhooks"
@@ -133,7 +133,7 @@ async def on_node_activate(
f"Activating webhook node #{node.id} with config {block.webhook_config}"
)
webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
webhooks_manager = get_webhook_manager(provider)
if auto_setup_webhook := isinstance(block.webhook_config, BlockWebhookConfig):
try:
@@ -234,13 +234,13 @@ async def on_node_deactivate(
return node
provider = block.webhook_config.provider
if provider not in WEBHOOK_MANAGERS_BY_NAME:
if not supports_webhooks(provider):
raise ValueError(
f"Block #{block.id} has webhook_config for provider {provider} "
"which does not support webhooks"
)
webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
webhooks_manager = get_webhook_manager(provider)
if node.webhook_id:
logger.debug(f"Node #{node.id} has webhook_id {node.webhook_id}")

View File

@@ -7,9 +7,9 @@ from prisma.enums import NotificationType
from pydantic import BaseModel
from backend.data.notifications import (
NotificationDataType_co,
NotificationEventModel,
NotificationTypeOverride,
T_co,
)
from backend.util.settings import Settings
from backend.util.text import TextFormatter
@@ -48,7 +48,10 @@ class EmailSender:
self,
notification: NotificationType,
user_email: str,
data: NotificationEventModel[T_co] | list[NotificationEventModel[T_co]],
data: (
NotificationEventModel[NotificationDataType_co]
| list[NotificationEventModel[NotificationDataType_co]]
),
user_unsub_link: str | None = None,
):
"""Send an email to a user using a template pulled from the notification type"""

View File

@@ -1,6 +1,6 @@
import logging
import time
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from typing import Callable
import aio_pika
@@ -10,25 +10,25 @@ from prisma.enums import NotificationType
from pydantic import BaseModel
from backend.data.notifications import (
BaseSummaryData,
BaseSummaryParams,
DailySummaryData,
DailySummaryParams,
NotificationEventDTO,
NotificationEventModel,
NotificationResult,
NotificationTypeOverride,
QueueType,
create_or_add_to_user_notification_batch,
empty_user_notification_batch,
get_all_batches_by_type,
SummaryParamsEventDTO,
SummaryParamsEventModel,
WeeklySummaryData,
WeeklySummaryParams,
get_batch_delay,
get_data_type,
get_user_notification_batch,
get_user_notification_oldest_message_in_batch,
get_notif_data_type,
get_summary_params_type,
)
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.data.user import (
generate_unsubscribe_link,
get_user_email_by_id,
get_user_email_verification,
get_user_notification_preference,
)
from backend.data.user import generate_unsubscribe_link
from backend.notifications.email import EmailSender
from backend.util.service import AppService, expose, get_service_client
from backend.util.settings import Settings
@@ -68,6 +68,16 @@ def create_notification_config() -> RabbitMQConfig:
"x-dead-letter-routing-key": "failed.admin",
},
),
# Summary notification queues
Queue(
name="summary_notifications",
exchange=notification_exchange,
routing_key="notification.summary.#",
arguments={
"x-dead-letter-exchange": dead_letter_exchange.name,
"x-dead-letter-routing-key": "failed.summary",
},
),
# Batch Queue
Queue(
name="batch_notifications",
@@ -102,12 +112,18 @@ def get_scheduler():
return get_service_client(Scheduler)
@thread_cached
def get_db():
from backend.executor.database import DatabaseManager
return get_service_client(DatabaseManager)
class NotificationManager(AppService):
"""Service for handling notifications with batching support"""
def __init__(self):
super().__init__()
self.use_db = True
self.rabbitmq_config = create_notification_config()
self.running = True
self.email_sender = EmailSender()
@@ -116,19 +132,51 @@ class NotificationManager(AppService):
def get_port(cls) -> int:
return settings.config.notification_service_port
def get_routing_key(self, event: NotificationEventModel) -> str:
def get_routing_key(self, event_type: NotificationType) -> str:
strategy = NotificationTypeOverride(event_type).strategy
"""Get the appropriate routing key for an event"""
if event.strategy == QueueType.IMMEDIATE:
return f"notification.immediate.{event.type.value}"
elif event.strategy == QueueType.BACKOFF:
return f"notification.backoff.{event.type.value}"
elif event.strategy == QueueType.ADMIN:
return f"notification.admin.{event.type.value}"
elif event.strategy == QueueType.BATCH:
return f"notification.batch.{event.type.value}"
elif event.strategy == QueueType.SUMMARY:
return f"notification.summary.{event.type.value}"
return f"notification.{event.type.value}"
if strategy == QueueType.IMMEDIATE:
return f"notification.immediate.{event_type.value}"
elif strategy == QueueType.BACKOFF:
return f"notification.backoff.{event_type.value}"
elif strategy == QueueType.ADMIN:
return f"notification.admin.{event_type.value}"
elif strategy == QueueType.BATCH:
return f"notification.batch.{event_type.value}"
elif strategy == QueueType.SUMMARY:
return f"notification.summary.{event_type.value}"
return f"notification.{event_type.value}"
@expose
def queue_weekly_summary(self):
"""Process weekly summary for specified notification types"""
try:
logger.info("Processing weekly summary queuing operation")
processed_count = 0
current_time = datetime.now(tz=timezone.utc)
start_time = current_time - timedelta(days=7)
users = get_db().get_active_user_ids_in_timerange(
end_time=current_time.isoformat(),
start_time=start_time.isoformat(),
)
for user in users:
self._queue_scheduled_notification(
SummaryParamsEventDTO(
user_id=user,
type=NotificationType.WEEKLY_SUMMARY,
data=WeeklySummaryParams(
start_date=start_time,
end_date=current_time,
).model_dump(),
),
)
processed_count += 1
logger.info(f"Processed {processed_count} weekly summaries into queue")
except Exception as e:
logger.exception(f"Error processing weekly summary: {e}")
@expose
def process_existing_batches(self, notification_types: list[NotificationType]):
@@ -139,80 +187,74 @@ class NotificationManager(AppService):
for notification_type in notification_types:
# Get all batches for this notification type
batches = self.run_and_wait(get_all_batches_by_type(notification_type))
batches = get_db().get_all_batches_by_type(notification_type)
for batch in batches:
# Check if batch has aged out
oldest_message = self.run_and_wait(
get_user_notification_oldest_message_in_batch(
batch.userId, notification_type
oldest_message = (
get_db().get_user_notification_oldest_message_in_batch(
batch.user_id, notification_type
)
)
if not oldest_message:
# this should never happen
logger.error(
f"Batch for user {batch.userId} and type {notification_type} has no oldest message whichshould never happen!!!!!!!!!!!!!!!!"
f"Batch for user {batch.user_id} and type {notification_type} has no oldest message whichshould never happen!!!!!!!!!!!!!!!!"
)
continue
max_delay = get_batch_delay(notification_type)
# If batch has aged out, process it
if oldest_message.createdAt + max_delay < current_time:
recipient_email = self.run_and_wait(
get_user_email_by_id(batch.userId)
)
if oldest_message.created_at + max_delay < current_time:
recipient_email = get_db().get_user_email_by_id(batch.user_id)
if not recipient_email:
logger.error(
f"User email not found for user {batch.userId}"
f"User email not found for user {batch.user_id}"
)
continue
should_send = self._should_email_user_based_on_preference(
batch.userId, notification_type
batch.user_id, notification_type
)
if not should_send:
logger.debug(
f"User {batch.userId} does not want to receive {notification_type} notifications"
f"User {batch.user_id} does not want to receive {notification_type} notifications"
)
# Clear the batch
self.run_and_wait(
empty_user_notification_batch(
batch.userId, notification_type
)
get_db().empty_user_notification_batch(
batch.user_id, notification_type
)
continue
batch_data = self.run_and_wait(
get_user_notification_batch(batch.userId, notification_type)
batch_data = get_db().get_user_notification_batch(
batch.user_id, notification_type
)
if not batch_data or not batch_data.notifications:
logger.error(
f"Batch data not found for user {batch.userId}"
f"Batch data not found for user {batch.user_id}"
)
# Clear the batch
self.run_and_wait(
empty_user_notification_batch(
batch.userId, notification_type
)
get_db().empty_user_notification_batch(
batch.user_id, notification_type
)
continue
unsub_link = generate_unsubscribe_link(batch.userId)
unsub_link = generate_unsubscribe_link(batch.user_id)
events = [
NotificationEventModel[
get_data_type(db_event.type)
get_notif_data_type(db_event.type)
].model_validate(
{
"user_id": batch.userId,
"user_id": batch.user_id,
"type": db_event.type,
"data": db_event.data,
"created_at": db_event.createdAt,
"created_at": db_event.created_at,
}
)
for db_event in batch_data.notifications
@@ -227,10 +269,8 @@ class NotificationManager(AppService):
)
# Clear the batch
self.run_and_wait(
empty_user_notification_batch(
batch.userId, notification_type
)
get_db().empty_user_notification_batch(
batch.user_id, notification_type
)
processed_count += 1
@@ -259,9 +299,9 @@ class NotificationManager(AppService):
logger.info(f"Received Request to queue {event=}")
# Workaround for not being able to serialize generics over the expose bus
parsed_event = NotificationEventModel[
get_data_type(event.type)
get_notif_data_type(event.type)
].model_validate(event.model_dump())
routing_key = self.get_routing_key(parsed_event)
routing_key = self.get_routing_key(parsed_event.type)
message = parsed_event.model_dump_json()
logger.info(f"Received Request to queue {message=}")
@@ -288,24 +328,136 @@ class NotificationManager(AppService):
logger.exception(f"Error queueing notification: {e}")
return NotificationResult(success=False, message=str(e))
def _queue_scheduled_notification(self, event: SummaryParamsEventDTO):
"""Queue a scheduled notification - exposed method for other services to call"""
try:
logger.info(f"Received Request to queue scheduled notification {event=}")
parsed_event = SummaryParamsEventModel[
get_summary_params_type(event.type)
].model_validate(event.model_dump())
routing_key = self.get_routing_key(event.type)
message = parsed_event.model_dump_json()
logger.info(f"Received Request to queue {message=}")
exchange = "notifications"
# Publish to RabbitMQ
self.run_and_wait(
self.rabbit.publish_message(
routing_key=routing_key,
message=message,
exchange=next(
ex for ex in self.rabbit_config.exchanges if ex.name == exchange
),
)
)
except Exception as e:
logger.exception(f"Error queueing notification: {e}")
def _should_email_user_based_on_preference(
self, user_id: str, event_type: NotificationType
) -> bool:
"""Check if a user wants to receive a notification based on their preferences and email verification status"""
validated_email = self.run_and_wait(get_user_email_verification(user_id))
preference = self.run_and_wait(
get_user_notification_preference(user_id)
).preferences.get(event_type, True)
validated_email = get_db().get_user_email_verification(user_id)
preference = (
get_db()
.get_user_notification_preference(user_id)
.preferences.get(event_type, True)
)
# only if both are true, should we email this person
return validated_email and preference
async def _should_batch(
def _gather_summary_data(
self, user_id: str, event_type: NotificationType, params: BaseSummaryParams
) -> BaseSummaryData:
"""Gathers the data to build a summary notification"""
logger.info(
f"Gathering summary data for {user_id} and {event_type} wiht {params=}"
)
# total_credits_used = self.run_and_wait(
# get_total_credits_used(user_id, start_time, end_time)
# )
# total_executions = self.run_and_wait(
# get_total_executions(user_id, start_time, end_time)
# )
# most_used_agent = self.run_and_wait(
# get_most_used_agent(user_id, start_time, end_time)
# )
# execution_times = self.run_and_wait(
# get_execution_time(user_id, start_time, end_time)
# )
# runs = self.run_and_wait(
# get_runs(user_id, start_time, end_time)
# )
total_credits_used = 3.0
total_executions = 2
most_used_agent = {"name": "Some"}
execution_times = [1, 2, 3]
runs = [{"status": "COMPLETED"}, {"status": "FAILED"}]
successful_runs = len([run for run in runs if run["status"] == "COMPLETED"])
failed_runs = len([run for run in runs if run["status"] != "COMPLETED"])
average_execution_time = (
sum(execution_times) / len(execution_times) if execution_times else 0
)
# cost_breakdown = self.run_and_wait(
# get_cost_breakdown(user_id, start_time, end_time)
# )
cost_breakdown = {
"agent1": 1.0,
"agent2": 2.0,
}
if event_type == NotificationType.DAILY_SUMMARY and isinstance(
params, DailySummaryParams
):
return DailySummaryData(
total_credits_used=total_credits_used,
total_executions=total_executions,
most_used_agent=most_used_agent["name"],
total_execution_time=sum(execution_times),
successful_runs=successful_runs,
failed_runs=failed_runs,
average_execution_time=average_execution_time,
cost_breakdown=cost_breakdown,
date=params.date,
)
elif event_type == NotificationType.WEEKLY_SUMMARY and isinstance(
params, WeeklySummaryParams
):
return WeeklySummaryData(
total_credits_used=total_credits_used,
total_executions=total_executions,
most_used_agent=most_used_agent["name"],
total_execution_time=sum(execution_times),
successful_runs=successful_runs,
failed_runs=failed_runs,
average_execution_time=average_execution_time,
cost_breakdown=cost_breakdown,
start_date=params.start_date,
end_date=params.end_date,
)
else:
raise ValueError("Invalid event type or params")
def _should_batch(
self, user_id: str, event_type: NotificationType, event: NotificationEventModel
) -> bool:
await create_or_add_to_user_notification_batch(user_id, event_type, event)
get_db().create_or_add_to_user_notification_batch(user_id, event_type, event)
oldest_message = await get_user_notification_oldest_message_in_batch(
oldest_message = get_db().get_user_notification_oldest_message_in_batch(
user_id, event_type
)
if not oldest_message:
@@ -313,7 +465,7 @@ class NotificationManager(AppService):
f"Batch for user {user_id} and type {event_type} has no oldest message whichshould never happen!!!!!!!!!!!!!!!!"
)
return False
oldest_age = oldest_message.createdAt
oldest_age = oldest_message.created_at
max_delay = get_batch_delay(event_type)
@@ -329,7 +481,7 @@ class NotificationManager(AppService):
try:
event = NotificationEventDTO.model_validate_json(message)
model = NotificationEventModel[
get_data_type(event.type)
get_notif_data_type(event.type)
].model_validate_json(message)
return NotificationEvent(event=event, model=model)
except Exception as e:
@@ -362,7 +514,7 @@ class NotificationManager(AppService):
model = parsed.model
logger.debug(f"Processing immediate notification: {model}")
recipient_email = self.run_and_wait(get_user_email_by_id(event.user_id))
recipient_email = get_db().get_user_email_by_id(event.user_id)
if not recipient_email:
logger.error(f"User email not found for user {event.user_id}")
return False
@@ -399,7 +551,7 @@ class NotificationManager(AppService):
model = parsed.model
logger.info(f"Processing batch notification: {model}")
recipient_email = self.run_and_wait(get_user_email_by_id(event.user_id))
recipient_email = get_db().get_user_email_by_id(event.user_id)
if not recipient_email:
logger.error(f"User email not found for user {event.user_id}")
return False
@@ -413,28 +565,26 @@ class NotificationManager(AppService):
)
return True
should_send = self.run_and_wait(
self._should_batch(event.user_id, event.type, model)
)
should_send = self._should_batch(event.user_id, event.type, model)
if not should_send:
logger.info("Batch not old enough to send")
return False
batch = self.run_and_wait(
get_user_notification_batch(event.user_id, event.type)
)
batch = get_db().get_user_notification_batch(event.user_id, event.type)
if not batch or not batch.notifications:
logger.error(f"Batch not found for user {event.user_id}")
return False
unsub_link = generate_unsubscribe_link(event.user_id)
batch_messages = [
NotificationEventModel[get_data_type(db_event.type)].model_validate(
NotificationEventModel[
get_notif_data_type(db_event.type)
].model_validate(
{
"user_id": event.user_id,
"type": db_event.type,
"data": db_event.data,
"created_at": db_event.createdAt,
"created_at": db_event.created_at,
}
)
for db_event in batch.notifications
@@ -447,12 +597,59 @@ class NotificationManager(AppService):
user_unsub_link=unsub_link,
)
# only empty the batch if we sent the email successfully
self.run_and_wait(empty_user_notification_batch(event.user_id, event.type))
get_db().empty_user_notification_batch(event.user_id, event.type)
return True
except Exception as e:
logger.exception(f"Error processing notification for batch queue: {e}")
return False
def _process_summary(self, message: str) -> bool:
"""Process a single notification with a summary strategy, returning whether to put into the failed queue"""
try:
logger.info(f"Processing summary notification: {message}")
event = SummaryParamsEventDTO.model_validate_json(message)
model = SummaryParamsEventModel[
get_summary_params_type(event.type)
].model_validate_json(message)
logger.info(f"Processing summary notification: {model}")
recipient_email = get_db().get_user_email_by_id(event.user_id)
if not recipient_email:
logger.error(f"User email not found for user {event.user_id}")
return False
should_send = self._should_email_user_based_on_preference(
event.user_id, event.type
)
if not should_send:
logger.info(
f"User {event.user_id} does not want to receive {event.type} notifications"
)
return True
summary_data = self._gather_summary_data(
event.user_id, event.type, model.data
)
unsub_link = generate_unsubscribe_link(event.user_id)
data = NotificationEventModel(
user_id=event.user_id,
type=event.type,
data=summary_data,
)
self.email_sender.send_templated(
notification=event.type,
user_email=recipient_email,
data=data,
user_unsub_link=unsub_link,
)
return True
except Exception as e:
logger.exception(f"Error processing notification for summary queue: {e}")
return False
def _run_queue(
self,
queue: aio_pika.abc.AbstractQueue,
@@ -493,6 +690,10 @@ class NotificationManager(AppService):
data={},
cron="0 * * * *",
)
# get_scheduler().add_weekly_notification_schedule(
# # weekly on Friday at 12pm
# cron="0 12 * * 5",
# )
logger.info("Scheduled notification cleanup")
except Exception as e:
logger.error(f"Error scheduling notification cleanup: {e}")
@@ -507,6 +708,8 @@ class NotificationManager(AppService):
admin_queue = self.run_and_wait(channel.get_queue("admin_notifications"))
summary_queue = self.run_and_wait(channel.get_queue("summary_notifications"))
while self.running:
try:
self._run_queue(
@@ -525,6 +728,12 @@ class NotificationManager(AppService):
error_queue_name="batch_notifications",
)
self._run_queue(
queue=summary_queue,
process_func=self._process_summary,
error_queue_name="summary_notifications",
)
time.sleep(0.1)
except QueueEmpty as e:

View File

@@ -0,0 +1,27 @@
{# Weekly Summary #}
{# Template variables:
data: the stuff below
data.start_date: the start date of the summary
data.end_date: the end date of the summary
data.total_credits_used: the total credits used during the summary
data.total_executions: the total number of executions during the summary
data.most_used_agent: the most used agent's nameduring the summary
data.total_execution_time: the total execution time during the summary
data.successful_runs: the total number of successful runs during the summary
data.failed_runs: the total number of failed runs during the summary
data.average_execution_time: the average execution time during the summary
data.cost_breakdown: the cost breakdown during the summary
#}
<h1>Weekly Summary</h1>
<p>Start Date: {{ data.start_date }}</p>
<p>End Date: {{ data.end_date }}</p>
<p>Total Credits Used: {{ data.total_credits_used }}</p>
<p>Total Executions: {{ data.total_executions }}</p>
<p>Most Used Agent: {{ data.most_used_agent }}</p>
<p>Total Execution Time: {{ data.total_execution_time }}</p>
<p>Successful Runs: {{ data.successful_runs }}</p>
<p>Failed Runs: {{ data.failed_runs }}</p>
<p>Average Execution Time: {{ data.average_execution_time }}</p>
<p>Cost Breakdown: {{ data.cost_breakdown }}</p>

View File

@@ -20,23 +20,25 @@ class ConnectionManager:
for subscribers in self.subscriptions.values():
subscribers.discard(websocket)
async def subscribe(self, graph_id: str, graph_version: int, websocket: WebSocket):
key = f"{graph_id}_{graph_version}"
async def subscribe(
self, *, user_id: str, graph_id: str, graph_version: int, websocket: WebSocket
):
key = f"{user_id}_{graph_id}_{graph_version}"
if key not in self.subscriptions:
self.subscriptions[key] = set()
self.subscriptions[key].add(websocket)
async def unsubscribe(
self, graph_id: str, graph_version: int, websocket: WebSocket
self, *, user_id: str, graph_id: str, graph_version: int, websocket: WebSocket
):
key = f"{graph_id}_{graph_version}"
key = f"{user_id}_{graph_id}_{graph_version}"
if key in self.subscriptions:
self.subscriptions[key].discard(websocket)
if not self.subscriptions[key]:
del self.subscriptions[key]
async def send_execution_result(self, result: execution.ExecutionResult):
key = f"{result.graph_id}_{result.graph_version}"
key = f"{result.user_id}_{result.graph_id}_{result.graph_version}"
if key in self.subscriptions:
message = WsMessage(
method=Methods.EXECUTION_EVENT,

View File

@@ -71,7 +71,7 @@ def get_outputs_with_names(results: List[ExecutionResult]) -> List[Dict[str, str
)
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
blocks = [block() for block in backend.data.block.get_blocks().values()]
return [b.to_dict() for b in blocks]
return [b.to_dict() for b in blocks if not b.disabled]
@v1_router.post(

View File

@@ -17,7 +17,7 @@ from backend.executor.manager import ExecutionManager
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import WEBHOOK_MANAGERS_BY_NAME
from backend.integrations.webhooks import get_webhook_manager
from backend.util.exceptions import NeedConfirmation
from backend.util.service import get_service_client
from backend.util.settings import Settings
@@ -281,7 +281,7 @@ async def webhook_ingress_generic(
webhook_id: Annotated[str, Path(title="Our ID for the webhook")],
):
logger.debug(f"Received {provider.value} webhook ingress for ID {webhook_id}")
webhook_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
webhook_manager = get_webhook_manager(provider)
webhook = await get_webhook(webhook_id)
logger.debug(f"Webhook #{webhook_id}: {webhook}")
payload, event_type = await webhook_manager.validate_payload(webhook, request)
@@ -323,7 +323,7 @@ async def webhook_ping(
user_id: Annotated[str, Depends(get_user_id)], # require auth
):
webhook = await get_webhook(webhook_id)
webhook_manager = WEBHOOK_MANAGERS_BY_NAME[webhook.provider]()
webhook_manager = get_webhook_manager(webhook.provider)
credentials = (
creds_manager.get(user_id, webhook.credentials_id)
@@ -358,14 +358,6 @@ async def remove_all_webhooks_for_credentials(
NeedConfirmation: If any of the webhooks are still in use and `force` is `False`
"""
webhooks = await get_all_webhooks_by_creds(credentials.id)
if credentials.provider not in WEBHOOK_MANAGERS_BY_NAME:
if webhooks:
logger.error(
f"Credentials #{credentials.id} for provider {credentials.provider} "
f"are attached to {len(webhooks)} webhooks, "
f"but there is no available WebhooksHandler for {credentials.provider}"
)
return
if any(w.attached_nodes for w in webhooks) and not force:
raise NeedConfirmation(
"Some webhooks linked to these credentials are still in use by an agent"
@@ -376,7 +368,7 @@ async def remove_all_webhooks_for_credentials(
await set_node_webhook(node.id, None)
# Prune the webhook
webhook_manager = WEBHOOK_MANAGERS_BY_NAME[credentials.provider]()
webhook_manager = get_webhook_manager(ProviderName(credentials.provider))
success = await webhook_manager.prune_webhook_if_dangling(
webhook.id, credentials
)

View File

@@ -18,6 +18,7 @@ import backend.data.graph
import backend.data.user
import backend.server.integrations.router
import backend.server.routers.v1
import backend.server.v2.admin.store_admin_routes
import backend.server.v2.library.db
import backend.server.v2.library.model
import backend.server.v2.library.routes
@@ -99,6 +100,11 @@ app.include_router(backend.server.routers.v1.v1_router, tags=["v1"], prefix="/ap
app.include_router(
backend.server.v2.store.routes.router, tags=["v2"], prefix="/api/store"
)
app.include_router(
backend.server.v2.admin.store_admin_routes.router,
tags=["v2", "admin"],
prefix="/api/store",
)
app.include_router(
backend.server.v2.library.routes.router, tags=["v2"], prefix="/api/library"
)
@@ -154,9 +160,10 @@ class AgentServer(backend.util.service.AppProcess):
graph_id: str,
graph_version: int,
user_id: str,
for_export: bool = False,
):
return await backend.server.routers.v1.get_graph(
graph_id, user_id, graph_version
graph_id, user_id, graph_version, for_export
)
@staticmethod
@@ -249,12 +256,16 @@ class AgentServer(backend.util.service.AppProcess):
):
return await backend.server.v2.store.routes.create_submission(request, user_id)
### ADMIN ###
@staticmethod
async def test_review_store_listing(
request: backend.server.v2.store.model.ReviewSubmissionRequest,
user: autogpt_libs.auth.models.User,
):
return await backend.server.v2.store.routes.review_submission(request, user)
return await backend.server.v2.admin.store_admin_routes.review_submission(
request.store_listing_version_id, request, user
)
@staticmethod
def test_create_credentials(

View File

@@ -38,7 +38,6 @@ from backend.data.credit import (
TransactionHistory,
get_auto_top_up,
get_block_costs,
get_stripe_customer_id,
get_user_credit_model,
set_auto_top_up,
)
@@ -199,7 +198,9 @@ async def get_onboarding_agents(
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
blocks = [block() for block in backend.data.block.get_blocks().values()]
costs = get_block_costs()
return [{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks]
return [
{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks if not b.disabled
]
@v1_router.post(
@@ -341,15 +342,7 @@ async def stripe_webhook(request: Request):
async def manage_payment_method(
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[str, str]:
session = stripe.billing_portal.Session.create(
customer=await get_stripe_customer_id(user_id),
return_url=settings.config.frontend_base_url + "/profile/credits",
)
if not session:
raise HTTPException(
status_code=400, detail="Failed to create billing portal session"
)
return {"url": session.url}
return {"url": await _user_credit_model.create_billing_portal_session(user_id)}
@v1_router.get(path="/credits/transactions", dependencies=[Depends(auth_middleware)])
@@ -405,10 +398,10 @@ async def get_graph(
graph_id: str,
user_id: Annotated[str, Depends(get_user_id)],
version: int | None = None,
hide_credentials: bool = False,
for_export: bool = False,
) -> graph_db.GraphModel:
graph = await graph_db.get_graph(
graph_id, version, user_id=user_id, for_export=hide_credentials
graph_id, version, user_id=user_id, for_export=for_export
)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
@@ -438,6 +431,7 @@ async def create_new_graph(
) -> graph_db.GraphModel:
graph = graph_db.make_graph_model(create_graph.graph, user_id)
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
graph.validate_graph(for_run=False)
graph = await graph_db.create_graph(graph, user_id=user_id)
@@ -489,17 +483,10 @@ async def update_graph(
latest_version_number = max(g.version for g in existing_versions)
graph.version = latest_version_number + 1
latest_version_graph = next(
v for v in existing_versions if v.version == latest_version_number
)
current_active_version = next((v for v in existing_versions if v.is_active), None)
if latest_version_graph.is_template != graph.is_template:
raise HTTPException(
400, detail="Changing is_template on an existing graph is forbidden"
)
graph.is_active = not graph.is_template
graph = graph_db.make_graph_model(graph, user_id)
graph.reassign_ids(user_id=user_id)
graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
graph.validate_graph(for_run=False)
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
@@ -630,7 +617,7 @@ async def stop_graph_run(
async def get_graphs_executions(
user_id: Annotated[str, Depends(get_user_id)],
) -> list[graph_db.GraphExecutionMeta]:
return await graph_db.get_graphs_executions(user_id=user_id)
return await graph_db.get_graph_executions(user_id=user_id)
@v1_router.get(
@@ -655,12 +642,8 @@ async def get_graph_execution(
graph_exec_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> graph_db.GraphExecution:
graph = await graph_db.get_graph(graph_id, user_id=user_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
result = await graph_db.get_execution(execution_id=graph_exec_id, user_id=user_id)
if not result:
if not result or result.graph_id != graph_id:
raise HTTPException(
status_code=404, detail=f"Graph execution #{graph_exec_id} not found."
)

View File

@@ -0,0 +1,100 @@
import logging
import typing
import autogpt_libs.auth.depends
import fastapi
import fastapi.responses
import prisma.enums
import backend.server.v2.store.db
import backend.server.v2.store.exceptions
import backend.server.v2.store.model
logger = logging.getLogger(__name__)
router = fastapi.APIRouter(prefix="/admin", tags=["store", "admin"])
@router.get(
"/listings",
response_model=backend.server.v2.store.model.StoreListingsWithVersionsResponse,
dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
)
async def get_admin_listings_with_versions(
status: typing.Optional[prisma.enums.SubmissionStatus] = None,
search: typing.Optional[str] = None,
page: int = 1,
page_size: int = 20,
):
"""
Get store listings with their version history for admins.
This provides a consolidated view of listings with their versions,
allowing for an expandable UI in the admin dashboard.
Args:
status: Filter by submission status (PENDING, APPROVED, REJECTED)
search: Search by name, description, or user email
page: Page number for pagination
page_size: Number of items per page
Returns:
StoreListingsWithVersionsResponse with listings and their versions
"""
try:
listings = await backend.server.v2.store.db.get_admin_listings_with_versions(
status=status,
search_query=search,
page=page,
page_size=page_size,
)
return listings
except Exception as e:
logger.exception("Error getting admin listings with versions: %s", e)
return fastapi.responses.JSONResponse(
status_code=500,
content={
"detail": "An error occurred while retrieving listings with versions"
},
)
@router.post(
"/submissions/{store_listing_version_id}/review",
response_model=backend.server.v2.store.model.StoreSubmission,
dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
)
async def review_submission(
store_listing_version_id: str,
request: backend.server.v2.store.model.ReviewSubmissionRequest,
user: typing.Annotated[
autogpt_libs.auth.models.User,
fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user),
],
):
"""
Review a store listing submission.
Args:
store_listing_version_id: ID of the submission to review
request: Review details including approval status and comments
user: Authenticated admin user performing the review
Returns:
StoreSubmission with updated review information
"""
try:
submission = await backend.server.v2.store.db.review_store_submission(
store_listing_version_id=store_listing_version_id,
is_approved=request.is_approved,
external_comments=request.comments,
internal_comments=request.internal_comments or "",
reviewer_id=user.user_id,
)
return submission
except Exception as e:
logger.exception("Error reviewing submission: %s", e)
return fastapi.responses.JSONResponse(
status_code=500,
content={"detail": "An error occurred while reviewing the submission"},
)

View File

@@ -1,4 +1,3 @@
import asyncio
import logging
from typing import Optional
@@ -186,12 +185,7 @@ async def add_generated_agent_image(
try:
if not (image_url := await store_media.check_media_exists(user_id, filename)):
# Generate agent image as JPEG
if config.use_agent_image_generation_v2:
image = await asyncio.to_thread(
store_image_gen.generate_agent_image_v2, graph=graph
)
else:
image = await store_image_gen.generate_agent_image(agent=graph)
image = await store_image_gen.generate_agent_image(graph)
# Create UploadFile with the correct filename and content_type
image_file = fastapi.UploadFile(file=image, filename=filename)

View File

@@ -1,22 +1,14 @@
from datetime import datetime
import prisma.enums
import prisma.errors
import prisma.models
import pytest
from prisma import Prisma
import backend.server.v2.library.db as db
import backend.server.v2.store.exceptions
@pytest.fixture(autouse=True)
async def setup_prisma():
# Don't register client if already registered
try:
Prisma()
except prisma.errors.ClientAlreadyRegisteredError:
pass
yield
from backend.data.db import connect
from backend.data.includes import library_agent_include
@pytest.mark.asyncio
@@ -31,7 +23,6 @@ async def test_get_library_agents(mocker):
userId="test-user",
isActive=True,
createdAt=datetime.now(),
isTemplate=False,
)
]
@@ -56,7 +47,6 @@ async def test_get_library_agents(mocker):
userId="other-user",
isActive=True,
createdAt=datetime.now(),
isTemplate=False,
),
)
]
@@ -91,17 +81,17 @@ async def test_get_library_agents(mocker):
assert result.pagination.page_size == 50
@pytest.mark.asyncio
@pytest.mark.asyncio(scope="session")
async def test_add_agent_to_library(mocker):
await connect()
# Mock data
mock_store_listing = prisma.models.StoreListingVersion(
mock_store_listing_data = prisma.models.StoreListingVersion(
id="version123",
version=1,
createdAt=datetime.now(),
updatedAt=datetime.now(),
agentId="agent1",
agentVersion=1,
slug="test-agent",
name="Test Agent",
subHeading="Test Agent Subheading",
imageUrls=["https://example.com/image.jpg"],
@@ -110,7 +100,8 @@ async def test_add_agent_to_library(mocker):
isFeatured=False,
isDeleted=False,
isAvailable=True,
isApproved=True,
storeListingId="listing123",
submissionStatus=prisma.enums.SubmissionStatus.APPROVED,
Agent=prisma.models.AgentGraph(
id="agent1",
version=1,
@@ -119,21 +110,37 @@ async def test_add_agent_to_library(mocker):
userId="creator",
isActive=True,
createdAt=datetime.now(),
isTemplate=False,
),
)
mock_library_agent_data = prisma.models.LibraryAgent(
id="ua1",
userId="test-user",
agentId=mock_store_listing_data.agentId,
agentVersion=1,
isCreatedByUser=False,
isDeleted=False,
isArchived=False,
createdAt=datetime.now(),
updatedAt=datetime.now(),
isFavorite=False,
useGraphIsActiveVersion=True,
Agent=mock_store_listing_data.Agent,
)
# Mock prisma calls
mock_store_listing_version = mocker.patch(
"prisma.models.StoreListingVersion.prisma"
)
mock_store_listing_version.return_value.find_unique = mocker.AsyncMock(
return_value=mock_store_listing
return_value=mock_store_listing_data
)
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.create = mocker.AsyncMock()
mock_library_agent.return_value.create = mocker.AsyncMock(
return_value=mock_library_agent_data
)
# Call function
await db.add_store_agent_to_library("version123", "test-user")
@@ -147,17 +154,20 @@ async def test_add_agent_to_library(mocker):
"userId": "test-user",
"agentId": "agent1",
"agentVersion": 1,
}
},
include=library_agent_include("test-user"),
)
mock_library_agent.return_value.create.assert_called_once_with(
data=prisma.types.LibraryAgentCreateInput(
userId="test-user", agentId="agent1", agentVersion=1, isCreatedByUser=False
)
),
include=library_agent_include("test-user"),
)
@pytest.mark.asyncio
@pytest.mark.asyncio(scope="session")
async def test_add_agent_to_library_not_found(mocker):
await connect()
# Mock prisma calls
mock_store_listing_version = mocker.patch(
"prisma.models.StoreListingVersion.prisma"

View File

@@ -2,11 +2,14 @@ import datetime
import prisma.fields
import prisma.models
import pytest
import backend.server.v2.library.model as library_model
from backend.util import json
def test_agent_preset_from_db():
@pytest.mark.asyncio
async def test_agent_preset_from_db():
# Create mock DB agent
db_agent = prisma.models.AgentPreset(
id="test-agent-123",
@@ -24,7 +27,7 @@ def test_agent_preset_from_db():
id="input-123",
time=datetime.datetime.now(),
name="input1",
data=prisma.fields.Json({"type": "string", "value": "test value"}),
data=json.dumps({"type": "string", "value": "test value"}), # type: ignore
)
],
)

View File

@@ -1,7 +1,6 @@
import datetime
import autogpt_libs.auth as autogpt_auth_lib
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
@@ -30,49 +29,48 @@ app.dependency_overrides[autogpt_auth_lib.auth_middleware] = override_auth_middl
app.dependency_overrides[autogpt_auth_lib.depends.get_user_id] = override_get_user_id
def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
mocked_value = [
library_model.LibraryAgentResponse(
agents=[
library_model.LibraryAgent(
id="test-agent-1",
agent_id="test-agent-1",
agent_version=1,
name="Test Agent 1",
description="Test Description 1",
image_url=None,
creator_name="Test Creator",
creator_image_url="",
input_schema={"type": "object", "properties": {}},
status=library_model.LibraryAgentStatus.COMPLETED,
new_output=False,
can_access_graph=True,
is_latest_version=True,
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
),
library_model.LibraryAgent(
id="test-agent-2",
agent_id="test-agent-2",
agent_version=1,
name="Test Agent 2",
description="Test Description 2",
image_url=None,
creator_name="Test Creator",
creator_image_url="",
input_schema={"type": "object", "properties": {}},
status=library_model.LibraryAgentStatus.COMPLETED,
new_output=False,
can_access_graph=False,
is_latest_version=True,
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
),
],
pagination=server_model.Pagination(
total_items=2, total_pages=1, current_page=1, page_size=50
@pytest.mark.asyncio
async def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
mocked_value = library_model.LibraryAgentResponse(
agents=[
library_model.LibraryAgent(
id="test-agent-1",
agent_id="test-agent-1",
agent_version=1,
name="Test Agent 1",
description="Test Description 1",
image_url=None,
creator_name="Test Creator",
creator_image_url="",
input_schema={"type": "object", "properties": {}},
status=library_model.LibraryAgentStatus.COMPLETED,
new_output=False,
can_access_graph=True,
is_latest_version=True,
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
),
library_model.LibraryAgent(
id="test-agent-2",
agent_id="test-agent-2",
agent_version=1,
name="Test Agent 2",
description="Test Description 2",
image_url=None,
creator_name="Test Creator",
creator_image_url="",
input_schema={"type": "object", "properties": {}},
status=library_model.LibraryAgentStatus.COMPLETED,
new_output=False,
can_access_graph=False,
is_latest_version=True,
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
),
],
pagination=server_model.Pagination(
total_items=2, total_pages=1, current_page=1, page_size=50
),
]
mock_db_call = mocker.patch("backend.server.v2.library.db.get_library_agents")
)
mock_db_call = mocker.patch("backend.server.v2.library.db.list_library_agents")
mock_db_call.return_value = mocked_value
response = client.get("/agents?search_term=test")
@@ -94,7 +92,7 @@ def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
def test_get_library_agents_error(mocker: pytest_mock.MockFixture):
mock_db_call = mocker.patch("backend.server.v2.library.db.get_library_agents")
mock_db_call = mocker.patch("backend.server.v2.library.db.list_library_agents")
mock_db_call.side_effect = Exception("Test error")
response = client.get("/agents?search_term=test")

View File

@@ -4,15 +4,11 @@ from autogpt_libs.auth.middleware import auth_middleware
from fastapi import APIRouter, Depends
from backend.server.utils import get_user_id
from backend.util.settings import Settings
from .models import ApiResponse, ChatRequest
from .service import OttoService
logger = logging.getLogger(__name__)
settings = Settings()
OTTO_API_URL = settings.config.otto_api_url
router = APIRouter()

View File

@@ -1,3 +1,4 @@
import asyncio
import logging
from typing import Optional
@@ -67,6 +68,13 @@ class OttoService:
"""
Send request to Otto API and handle the response.
"""
# Check if Otto API URL is configured
if not OTTO_API_URL:
logger.error("Otto API URL is not configured")
raise HTTPException(
status_code=503, detail="Otto service is not configured"
)
try:
async with aiohttp.ClientSession() as session:
headers = {
@@ -94,7 +102,10 @@ class OttoService:
logger.debug(f"Request payload: {payload}")
async with session.post(
OTTO_API_URL, json=payload, headers=headers
OTTO_API_URL,
json=payload,
headers=headers,
timeout=aiohttp.ClientTimeout(total=60),
) as response:
if response.status != 200:
error_text = await response.text()
@@ -115,6 +126,11 @@ class OttoService:
raise HTTPException(
status_code=503, detail="Failed to connect to Otto service"
)
except asyncio.TimeoutError:
logger.error("Timeout error connecting to Otto API after 60 seconds")
raise HTTPException(
status_code=504, detail="Request to Otto service timed out"
)
except Exception as e:
logger.error(f"Unexpected error in Otto API proxy: {str(e)}")
raise HTTPException(

View File

@@ -1,6 +1,5 @@
import logging
from datetime import datetime
from typing import Optional
from datetime import datetime, timezone
import fastapi
import prisma.enums
@@ -11,7 +10,8 @@ import prisma.types
import backend.data.graph
import backend.server.v2.store.exceptions
import backend.server.v2.store.model
from backend.data.graph import GraphModel
from backend.data.graph import GraphModel, get_sub_graphs
from backend.data.includes import AGENT_GRAPH_INCLUDE
logger = logging.getLogger(__name__)
@@ -44,6 +44,9 @@ async def get_store_agents(
page: int = 1,
page_size: int = 20,
) -> backend.server.v2.store.model.StoreAgentsResponse:
"""
Get PUBLIC store agents from the StoreAgent view
"""
logger.debug(
f"Getting store agents. featured={featured}, creator={creator}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
)
@@ -129,6 +132,7 @@ async def get_store_agents(
async def get_store_agent_details(
username: str, agent_name: str
) -> backend.server.v2.store.model.StoreAgentDetails:
"""Get PUBLIC store agent details from the StoreAgent view"""
logger.debug(f"Getting store agent details for {username}/{agent_name}")
try:
@@ -142,6 +146,20 @@ async def get_store_agent_details(
f"Agent {username}/{agent_name} not found"
)
# Retrieve StoreListing to get active_version_id and has_approved_version
store_listing = await prisma.models.StoreListing.prisma().find_first(
where=prisma.types.StoreListingWhereInput(
slug=agent_name,
owningUserId=username, # Direct equality check instead of 'has'
),
include={"ActiveVersion": True},
)
active_version_id = store_listing.activeVersionId if store_listing else None
has_approved_version = (
store_listing.hasApprovedVersion if store_listing else False
)
logger.debug(f"Found agent details for {username}/{agent_name}")
return backend.server.v2.store.model.StoreAgentDetails(
store_listing_version_id=agent.storeListingVersionId,
@@ -158,6 +176,8 @@ async def get_store_agent_details(
rating=agent.rating,
versions=agent.versions,
last_updated=agent.updated_at,
active_version_id=active_version_id,
has_approved_version=has_approved_version,
)
except backend.server.v2.store.exceptions.AgentNotFoundError:
raise
@@ -175,6 +195,7 @@ async def get_store_creators(
page: int = 1,
page_size: int = 20,
) -> backend.server.v2.store.model.CreatorsResponse:
"""Get PUBLIC store creators from the Creator view"""
logger.debug(
f"Getting store creators. featured={featured}, search={search_query}, sorted_by={sorted_by}, page={page}"
)
@@ -322,6 +343,7 @@ async def get_store_creator_details(
async def get_store_submissions(
user_id: str, page: int = 1, page_size: int = 20
) -> backend.server.v2.store.model.StoreSubmissionsResponse:
"""Get store submissions for the authenticated user -- not an admin"""
logger.debug(f"Getting store submissions for user {user_id}, page={page}")
try:
@@ -343,8 +365,9 @@ async def get_store_submissions(
total_pages = (total + page_size - 1) // page_size
# Convert to response models
submission_models = [
backend.server.v2.store.model.StoreSubmission(
submission_models = []
for sub in submissions:
submission_model = backend.server.v2.store.model.StoreSubmission(
agent_id=sub.agent_id,
agent_version=sub.agent_version,
name=sub.name,
@@ -352,13 +375,18 @@ async def get_store_submissions(
slug=sub.slug,
description=sub.description,
image_urls=sub.image_urls or [],
date_submitted=sub.date_submitted or datetime.now(),
date_submitted=sub.date_submitted or datetime.now(tz=timezone.utc),
status=sub.status,
runs=sub.runs or 0,
rating=sub.rating or 0.0,
store_listing_version_id=sub.store_listing_version_id,
reviewer_id=sub.reviewer_id,
review_comments=sub.review_comments,
# internal_comments omitted for regular users
reviewed_at=sub.reviewed_at,
changes_summary=sub.changes_summary,
)
for sub in submissions
]
submission_models.append(submission_model)
logger.debug(f"Found {len(submission_models)} submissions")
return backend.server.v2.store.model.StoreSubmissionsResponse(
@@ -390,7 +418,7 @@ async def delete_store_submission(
submission_id: str,
) -> bool:
"""
Delete a store listing submission.
Delete a store listing submission as the submitting user.
Args:
user_id: ID of the authenticated user
@@ -437,9 +465,10 @@ async def create_store_submission(
description: str = "",
sub_heading: str = "",
categories: list[str] = [],
changes_summary: str = "Initial Submission",
) -> backend.server.v2.store.model.StoreSubmission:
"""
Create a new store listing submission.
Create the first (and only) store listing and thus submission as a normal user
Args:
user_id: ID of the authenticated user submitting the listing
@@ -450,7 +479,9 @@ async def create_store_submission(
video_url: Optional URL to video demo
image_urls: List of image URLs for the listing
description: Description of the agent
sub_heading: Optional sub-heading for the agent
categories: List of categories for the agent
changes_summary: Summary of changes made in this submission
Returns:
StoreSubmission: The created store submission
@@ -480,45 +511,66 @@ async def create_store_submission(
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
)
listing = await prisma.models.StoreListing.prisma().find_first(
# Check if listing already exists for this agent
existing_listing = await prisma.models.StoreListing.prisma().find_first(
where=prisma.types.StoreListingWhereInput(
agentId=agent_id, owningUserId=user_id
)
)
if listing is not None:
logger.warning(f"Listing already exists for agent {agent_id}")
raise backend.server.v2.store.exceptions.ListingExistsError(
"Listing already exists for this agent"
if existing_listing is not None:
logger.info(
f"Listing already exists for agent {agent_id}, creating new version instead"
)
# Create the store listing
listing = await prisma.models.StoreListing.prisma().create(
data={
"agentId": agent_id,
"agentVersion": agent_version,
"owningUserId": user_id,
"createdAt": datetime.now(),
"StoreListingVersions": {
"create": {
"agentId": agent_id,
"agentVersion": agent_version,
"slug": slug,
"name": name,
"videoUrl": video_url,
"imageUrls": image_urls,
"description": description,
"categories": categories,
"subHeading": sub_heading,
}
},
# Delegate to create_store_version which already handles this case correctly
return await create_store_version(
user_id=user_id,
agent_id=agent_id,
agent_version=agent_version,
store_listing_id=existing_listing.id,
name=name,
video_url=video_url,
image_urls=image_urls,
description=description,
sub_heading=sub_heading,
categories=categories,
changes_summary=changes_summary,
)
# If no existing listing, create a new one
data = prisma.types.StoreListingCreateInput(
slug=slug,
agentId=agent_id,
agentVersion=agent_version,
owningUserId=user_id,
createdAt=datetime.now(tz=timezone.utc),
Versions={
"create": [
prisma.types.StoreListingVersionCreateInput(
agentId=agent_id,
agentVersion=agent_version,
name=name,
videoUrl=video_url,
imageUrls=image_urls,
description=description,
categories=categories,
subHeading=sub_heading,
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
submittedAt=datetime.now(tz=timezone.utc),
changesSummary=changes_summary,
)
]
},
include={"StoreListingVersions": True},
)
listing = await prisma.models.StoreListing.prisma().create(
data=data,
include=prisma.types.StoreListingInclude(Versions=True),
)
store_listing_version_id = (
listing.StoreListingVersions[0].id
if listing.StoreListingVersions is not None
and len(listing.StoreListingVersions) > 0
listing.Versions[0].id
if listing.Versions is not None and len(listing.Versions) > 0
else None
)
@@ -537,6 +589,7 @@ async def create_store_submission(
runs=0,
rating=0.0,
store_listing_version_id=store_listing_version_id,
changes_summary=changes_summary,
)
except (
@@ -551,13 +604,137 @@ async def create_store_submission(
) from e
async def create_store_version(
user_id: str,
agent_id: str,
agent_version: int,
store_listing_id: str,
name: str,
video_url: str | None = None,
image_urls: list[str] = [],
description: str = "",
sub_heading: str = "",
categories: list[str] = [],
changes_summary: str = "Update Submission",
) -> backend.server.v2.store.model.StoreSubmission:
"""
Create a new version for an existing store listing
Args:
user_id: ID of the authenticated user submitting the version
agent_id: ID of the agent being submitted
agent_version: Version of the agent being submitted
store_listing_id: ID of the existing store listing
name: Name of the agent
video_url: Optional URL to video demo
image_urls: List of image URLs for the listing
description: Description of the agent
categories: List of categories for the agent
changes_summary: Summary of changes from the previous version
Returns:
StoreSubmission: The created store submission
"""
logger.debug(
f"Creating new version for store listing {store_listing_id} for user {user_id}, agent {agent_id} v{agent_version}"
)
try:
# First verify the listing belongs to this user
listing = await prisma.models.StoreListing.prisma().find_first(
where=prisma.types.StoreListingWhereInput(
id=store_listing_id, owningUserId=user_id
),
include={"Versions": {"order_by": {"version": "desc"}, "take": 1}},
)
if not listing:
raise backend.server.v2.store.exceptions.ListingNotFoundError(
f"Store listing not found. User ID: {user_id}, Listing ID: {store_listing_id}"
)
# Verify the agent belongs to this user
agent = await prisma.models.AgentGraph.prisma().find_first(
where=prisma.types.AgentGraphWhereInput(
id=agent_id, version=agent_version, userId=user_id
)
)
if not agent:
raise backend.server.v2.store.exceptions.AgentNotFoundError(
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
)
# Get the latest version number
latest_version = listing.Versions[0] if listing.Versions else None
next_version = (latest_version.version + 1) if latest_version else 1
# Create a new version for the existing listing
new_version = await prisma.models.StoreListingVersion.prisma().create(
data=prisma.types.StoreListingVersionCreateInput(
version=next_version,
agentId=agent_id,
agentVersion=agent_version,
name=name,
videoUrl=video_url,
imageUrls=image_urls,
description=description,
categories=categories,
subHeading=sub_heading,
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
submittedAt=datetime.now(),
changesSummary=changes_summary,
storeListingId=store_listing_id,
)
)
logger.debug(
f"Created new version for listing {store_listing_id} of agent {agent_id}"
)
# Return submission details
return backend.server.v2.store.model.StoreSubmission(
agent_id=agent_id,
agent_version=agent_version,
name=name,
slug=listing.slug,
sub_heading=sub_heading,
description=description,
image_urls=image_urls,
date_submitted=datetime.now(),
status=prisma.enums.SubmissionStatus.PENDING,
runs=0,
rating=0.0,
store_listing_version_id=new_version.id,
changes_summary=changes_summary,
version=next_version,
)
except prisma.errors.PrismaError as e:
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to create new store version"
) from e
async def create_store_review(
user_id: str,
store_listing_version_id: str,
score: int,
comments: str | None = None,
) -> backend.server.v2.store.model.StoreReview:
"""Create a review for a store listing as a user to detail their experience"""
try:
data = prisma.types.StoreListingReviewUpsertInput(
update=prisma.types.StoreListingReviewUpdateInput(
score=score,
comments=comments,
),
create=prisma.types.StoreListingReviewCreateInput(
reviewByUserId=user_id,
storeListingVersionId=store_listing_version_id,
score=score,
comments=comments,
),
)
review = await prisma.models.StoreListingReview.prisma().upsert(
where={
"storeListingVersionId_reviewByUserId": {
@@ -565,18 +742,7 @@ async def create_store_review(
"reviewByUserId": user_id,
}
},
data={
"create": {
"reviewByUserId": user_id,
"storeListingVersionId": store_listing_version_id,
"score": score,
"comments": comments,
},
"update": {
"score": score,
"comments": comments,
},
},
data=data,
)
return backend.server.v2.store.model.StoreReview(
@@ -598,7 +764,7 @@ async def get_user_profile(
try:
profile = await prisma.models.Profile.prisma().find_first(
where={"userId": user_id} # type: ignore
where={"userId": user_id}
)
if not profile:
@@ -703,48 +869,39 @@ async def get_my_agents(
page: int = 1,
page_size: int = 20,
) -> backend.server.v2.store.model.MyAgentsResponse:
"""Get the agents for the authenticated user"""
logger.debug(f"Getting my agents for user {user_id}, page={page}")
try:
agents_with_max_version = await prisma.models.AgentGraph.prisma().find_many(
where=prisma.types.AgentGraphWhereInput(
userId=user_id, StoreListing={"none": {"isDeleted": False}}
),
order=[{"version": "desc"}],
distinct=["id"],
search_filter: prisma.types.LibraryAgentWhereInput = {
"userId": user_id,
"Agent": {"is": {"StoreListing": {"none": {"isDeleted": False}}}},
"isArchived": False,
"isDeleted": False,
}
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
where=search_filter,
order=[{"agentVersion": "desc"}],
skip=(page - 1) * page_size,
take=page_size,
include={"Agent": True},
)
# store_listings = await prisma.models.StoreListing.prisma().find_many(
# where=prisma.types.StoreListingWhereInput(
# isDeleted=False,
# ),
# )
total = len(
await prisma.models.AgentGraph.prisma().find_many(
where=prisma.types.AgentGraphWhereInput(
userId=user_id, StoreListing={"none": {"isDeleted": False}}
),
order=[{"version": "desc"}],
distinct=["id"],
)
)
total = await prisma.models.LibraryAgent.prisma().count(where=search_filter)
total_pages = (total + page_size - 1) // page_size
agents = agents_with_max_version
my_agents = [
backend.server.v2.store.model.MyAgent(
agent_id=agent.id,
agent_version=agent.version,
agent_name=agent.name or "",
last_edited=agent.updatedAt or agent.createdAt,
description=agent.description or "",
agent_id=graph.id,
agent_version=graph.version,
agent_name=graph.name or "",
last_edited=graph.updatedAt or graph.createdAt,
description=graph.description or "",
agent_image=library_agent.imageUrl,
)
for agent in agents
for library_agent in library_agents
if (graph := library_agent.Agent)
]
return backend.server.v2.store.model.MyAgentsResponse(
@@ -764,58 +921,87 @@ async def get_my_agents(
async def get_agent(
store_listing_version_id: str, version_id: Optional[int]
user_id: str,
store_listing_version_id: str,
) -> GraphModel:
"""Get agent using the version ID and store listing version ID."""
try:
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id}, include={"Agent": True}
)
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id}
)
)
if not store_listing_version:
raise ValueError(f"Store listing version {store_listing_version_id} not found")
graph = await backend.data.graph.get_graph(
user_id=user_id,
graph_id=store_listing_version.agentId,
version=store_listing_version.agentVersion,
for_export=True,
)
if not graph:
raise ValueError(
f"Agent {store_listing_version.agentId} v{store_listing_version.agentVersion} not found"
)
if not store_listing_version or not store_listing_version.Agent:
raise fastapi.HTTPException(
status_code=404,
detail=f"Store listing version {store_listing_version_id} not found",
)
return graph
graph_id = store_listing_version.agentId
graph_version = store_listing_version.agentVersion
graph = await backend.data.graph.get_graph(graph_id, graph_version)
if not graph:
raise fastapi.HTTPException(
status_code=404,
detail=(
f"Agent #{graph_id} not found "
f"for store listing version #{store_listing_version_id}"
),
)
#####################################################
################## ADMIN FUNCTIONS ##################
#####################################################
graph.version = 1
graph.is_template = False
graph.is_active = True
delattr(graph, "user_id")
return graph
async def _get_missing_sub_store_listing(
graph: prisma.models.AgentGraph,
) -> list[prisma.models.AgentGraph]:
"""
Agent graph can have sub-graphs, and those sub-graphs also need to be store listed.
This method fetches the sub-graphs, and returns the ones not listed in the store.
"""
sub_graphs = await get_sub_graphs(graph)
if not sub_graphs:
return []
except Exception as e:
logger.error(f"Error getting agent: {e}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch agent"
) from e
# Fetch all the sub-graphs that are listed, and return the ones missing.
store_listed_sub_graphs = {
(listing.agentId, listing.agentVersion)
for listing in await prisma.models.StoreListingVersion.prisma().find_many(
where={
"OR": [
{"agentId": sub_graph.id, "agentVersion": sub_graph.version}
for sub_graph in sub_graphs
],
"submissionStatus": prisma.enums.SubmissionStatus.APPROVED,
"isDeleted": False,
}
)
}
return [
sub_graph
for sub_graph in sub_graphs
if (sub_graph.id, sub_graph.version) not in store_listed_sub_graphs
]
async def review_store_submission(
store_listing_version_id: str, is_approved: bool, comments: str, reviewer_id: str
) -> prisma.models.StoreListingSubmission:
"""Review a store listing submission."""
store_listing_version_id: str,
is_approved: bool,
external_comments: str,
internal_comments: str,
reviewer_id: str,
) -> backend.server.v2.store.model.StoreSubmission:
"""Review a store listing submission as an admin."""
try:
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id},
include={"StoreListing": True},
include={
"StoreListing": True,
"Agent": {"include": AGENT_GRAPH_INCLUDE}, # type: ignore
},
)
)
@@ -825,10 +1011,34 @@ async def review_store_submission(
detail=f"Store listing version {store_listing_version_id} not found",
)
if is_approved:
# If approving, update the listing to indicate it has an approved version
if is_approved and store_listing_version.Agent:
heading = f"Sub-graph of {store_listing_version.name}v{store_listing_version.agentVersion}"
sub_store_listing_versions = [
prisma.types.StoreListingVersionCreateWithoutRelationsInput(
agentId=sub_graph.id,
agentVersion=sub_graph.version,
name=sub_graph.name or heading,
submissionStatus=prisma.enums.SubmissionStatus.APPROVED,
subHeading=heading,
description=f"{heading}: {sub_graph.description}",
changesSummary=f"This listing is added as a {heading} / #{store_listing_version.agentId}.",
isAvailable=False, # Hide sub-graphs from the store by default.
submittedAt=datetime.now(tz=timezone.utc),
)
for sub_graph in await _get_missing_sub_store_listing(
store_listing_version.Agent
)
]
await prisma.models.StoreListing.prisma().update(
where={"id": store_listing_version.StoreListing.id},
data={"isApproved": True},
data={
"hasApprovedVersion": True,
"ActiveVersion": {"connect": {"id": store_listing_version_id}},
"Versions": {"create": sub_store_listing_versions},
},
)
submission_status = (
@@ -837,36 +1047,230 @@ async def review_store_submission(
else prisma.enums.SubmissionStatus.REJECTED
)
update_data: prisma.types.StoreListingSubmissionUpdateInput = {
"Status": submission_status,
"reviewComments": comments,
# Update the version with review information
update_data: prisma.types.StoreListingVersionUpdateInput = {
"submissionStatus": submission_status,
"reviewComments": external_comments,
"internalComments": internal_comments,
"Reviewer": {"connect": {"id": reviewer_id}},
"StoreListing": {"connect": {"id": store_listing_version.StoreListing.id}},
"reviewedAt": datetime.now(tz=timezone.utc),
}
create_data: prisma.types.StoreListingSubmissionCreateInput = {
**update_data,
"StoreListingVersion": {"connect": {"id": store_listing_version_id}},
}
submission = await prisma.models.StoreListingSubmission.prisma().upsert(
where={"storeListingVersionId": store_listing_version_id},
data={
"create": create_data,
"update": update_data,
},
# Update the version
submission = await prisma.models.StoreListingVersion.prisma().update(
where={"id": store_listing_version_id},
data=update_data,
include={"StoreListing": True},
)
if not submission:
raise fastapi.HTTPException( # FIXME: don't return HTTP exceptions here
status_code=404,
detail=f"Store listing submission {store_listing_version_id} not found",
raise backend.server.v2.store.exceptions.DatabaseError(
f"Failed to update store listing version {store_listing_version_id}"
)
return submission
# Convert to Pydantic model for consistency
return backend.server.v2.store.model.StoreSubmission(
agent_id=submission.agentId,
agent_version=submission.agentVersion,
name=submission.name,
sub_heading=submission.subHeading,
slug=(
submission.StoreListing.slug
if hasattr(submission, "storeListing") and submission.StoreListing
else ""
),
description=submission.description,
image_urls=submission.imageUrls or [],
date_submitted=submission.submittedAt or submission.createdAt,
status=submission.submissionStatus,
runs=0, # Default values since we don't have this data here
rating=0.0,
store_listing_version_id=submission.id,
reviewer_id=submission.reviewerId,
review_comments=submission.reviewComments,
internal_comments=submission.internalComments,
reviewed_at=submission.reviewedAt,
changes_summary=submission.changesSummary,
)
except Exception as e:
logger.error(f"Could not create store submission review: {e}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to create store submission review"
) from e
async def get_admin_listings_with_versions(
status: prisma.enums.SubmissionStatus | None = None,
search_query: str | None = None,
page: int = 1,
page_size: int = 20,
) -> backend.server.v2.store.model.StoreListingsWithVersionsResponse:
"""
Get store listings for admins with all their versions.
Args:
status: Filter by submission status (PENDING, APPROVED, REJECTED)
search_query: Search by name, description, or user email
page: Page number for pagination
page_size: Number of items per page
Returns:
StoreListingsWithVersionsResponse with listings and their versions
"""
logger.debug(
f"Getting admin store listings with status={status}, search={search_query}, page={page}"
)
try:
# Build the where clause for StoreListing
where_dict: prisma.types.StoreListingWhereInput = {
"isDeleted": False,
}
if status:
where_dict["Versions"] = {"some": {"submissionStatus": status}}
sanitized_query = sanitize_query(search_query)
if sanitized_query:
# Find users with matching email
matching_users = await prisma.models.User.prisma().find_many(
where={"email": {"contains": sanitized_query, "mode": "insensitive"}},
)
user_ids = [user.id for user in matching_users]
# Set up OR conditions
where_dict["OR"] = [
{"slug": {"contains": sanitized_query, "mode": "insensitive"}},
{
"Versions": {
"some": {
"name": {"contains": sanitized_query, "mode": "insensitive"}
}
}
},
{
"Versions": {
"some": {
"description": {
"contains": sanitized_query,
"mode": "insensitive",
}
}
}
},
{
"Versions": {
"some": {
"subHeading": {
"contains": sanitized_query,
"mode": "insensitive",
}
}
}
},
]
# Add user_id condition if any users matched
if user_ids:
where_dict["OR"].append({"owningUserId": {"in": user_ids}})
# Calculate pagination
skip = (page - 1) * page_size
# Create proper Prisma types for the query
where = prisma.types.StoreListingWhereInput(**where_dict)
include = prisma.types.StoreListingInclude(
Versions=prisma.types.FindManyStoreListingVersionArgsFromStoreListing(
order_by=prisma.types._StoreListingVersion_version_OrderByInput(
version="desc"
)
),
OwningUser=True,
)
# Query listings with their versions
listings = await prisma.models.StoreListing.prisma().find_many(
where=where,
skip=skip,
take=page_size,
include=include,
order=[{"createdAt": "desc"}],
)
# Get total count for pagination
total = await prisma.models.StoreListing.prisma().count(where=where)
total_pages = (total + page_size - 1) // page_size
# Convert to response models
listings_with_versions = []
for listing in listings:
versions: list[backend.server.v2.store.model.StoreSubmission] = []
# If we have versions, turn them into StoreSubmission models
for version in listing.Versions or []:
version_model = backend.server.v2.store.model.StoreSubmission(
agent_id=version.agentId,
agent_version=version.agentVersion,
name=version.name,
sub_heading=version.subHeading,
slug=listing.slug,
description=version.description,
image_urls=version.imageUrls or [],
date_submitted=version.submittedAt or version.createdAt,
status=version.submissionStatus,
runs=0, # Default values since we don't have this data here
rating=0.0, # Default values since we don't have this data here
store_listing_version_id=version.id,
reviewer_id=version.reviewerId,
review_comments=version.reviewComments,
internal_comments=version.internalComments,
reviewed_at=version.reviewedAt,
changes_summary=version.changesSummary,
version=version.version,
)
versions.append(version_model)
# Get the latest version (first in the sorted list)
latest_version = versions[0] if versions else None
creator_email = listing.OwningUser.email if listing.OwningUser else None
listing_with_versions = (
backend.server.v2.store.model.StoreListingWithVersions(
listing_id=listing.id,
slug=listing.slug,
agent_id=listing.agentId,
agent_version=listing.agentVersion,
active_version_id=listing.activeVersionId,
has_approved_version=listing.hasApprovedVersion,
creator_email=creator_email,
latest_version=latest_version,
versions=versions,
)
)
listings_with_versions.append(listing_with_versions)
logger.debug(f"Found {len(listings_with_versions)} listings for admin")
return backend.server.v2.store.model.StoreListingsWithVersionsResponse(
listings=listings_with_versions,
pagination=backend.server.v2.store.model.Pagination(
current_page=page,
total_items=total,
total_pages=total_pages,
page_size=page_size,
),
)
except Exception as e:
logger.error(f"Error fetching admin store listings: {e}")
# Return empty response rather than exposing internal errors
return backend.server.v2.store.model.StoreListingsWithVersionsResponse(
listings=[],
pagination=backend.server.v2.store.model.Pagination(
current_page=page,
total_items=0,
total_pages=0,
page_size=page_size,
),
)

View File

@@ -1,5 +1,6 @@
from datetime import datetime
import prisma.enums
import prisma.errors
import prisma.models
import pytest
@@ -83,21 +84,35 @@ async def test_get_store_agent_details(mocker):
updated_at=datetime.now(),
)
# Mock prisma call
# Create a mock StoreListing result
mock_store_listing = mocker.MagicMock()
mock_store_listing.activeVersionId = "active-version-id"
mock_store_listing.hasApprovedVersion = True
# Mock StoreAgent prisma call
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
# Mock StoreListing prisma call - this is what was missing
mock_store_listing_db = mocker.patch("prisma.models.StoreListing.prisma")
mock_store_listing_db.return_value.find_first = mocker.AsyncMock(
return_value=mock_store_listing
)
# Call function
result = await db.get_store_agent_details("creator", "test-agent")
# Verify results
assert result.slug == "test-agent"
assert result.agent_name == "Test Agent"
assert result.active_version_id == "active-version-id"
assert result.has_approved_version is True
# Verify mock called correctly
# Verify mocks called correctly
mock_store_agent.return_value.find_first.assert_called_once_with(
where={"creator_username": "creator", "slug": "test-agent"}
)
mock_store_listing_db.return_value.find_first.assert_called_once()
@pytest.mark.asyncio
@@ -146,7 +161,6 @@ async def test_create_store_submission(mocker):
userId="user-id",
createdAt=datetime.now(),
isActive=True,
isTemplate=False,
)
mock_listing = prisma.models.StoreListing(
@@ -154,16 +168,16 @@ async def test_create_store_submission(mocker):
createdAt=datetime.now(),
updatedAt=datetime.now(),
isDeleted=False,
isApproved=False,
hasApprovedVersion=False,
slug="test-agent",
agentId="agent-id",
agentVersion=1,
owningUserId="user-id",
StoreListingVersions=[
Versions=[
prisma.models.StoreListingVersion(
id="version-id",
agentId="agent-id",
agentVersion=1,
slug="test-agent",
name="Test Agent",
description="Test description",
createdAt=datetime.now(),
@@ -174,8 +188,9 @@ async def test_create_store_submission(mocker):
isFeatured=False,
isDeleted=False,
version=1,
storeListingId="listing-id",
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
isAvailable=True,
isApproved=False,
)
],
)

View File

@@ -70,6 +70,12 @@ class ProfileNotFoundError(StoreError):
pass
class ListingNotFoundError(StoreError):
"""Raised when a store listing is not found"""
pass
class SubmissionNotFoundError(StoreError):
"""Raised when a submission is not found"""

View File

@@ -1,3 +1,4 @@
import asyncio
import io
import logging
from enum import Enum
@@ -34,6 +35,13 @@ class ImageStyle(str, Enum):
DIGITAL_ART = "digital art"
async def generate_agent_image(agent: Graph | AgentGraph) -> io.BytesIO:
if settings.config.use_agent_image_generation_v2:
return await asyncio.to_thread(generate_agent_image_v2, graph=agent)
else:
return await generate_agent_image_v1(agent=agent)
def generate_agent_image_v2(graph: Graph | AgentGraph) -> io.BytesIO:
"""
Generate an image for an agent using Ideogram model.
@@ -91,7 +99,7 @@ def generate_agent_image_v2(graph: Graph | AgentGraph) -> io.BytesIO:
return io.BytesIO(requests.get(url).content)
async def generate_agent_image(agent: Graph | AgentGraph) -> io.BytesIO:
async def generate_agent_image_v1(agent: Graph | AgentGraph) -> io.BytesIO:
"""
Generate an image for an agent using Flux model via Replicate API.

View File

@@ -24,6 +24,7 @@ class MyAgent(pydantic.BaseModel):
agent_id: str
agent_version: int
agent_name: str
agent_image: str | None = None
description: str
last_edited: datetime.datetime
@@ -66,6 +67,9 @@ class StoreAgentDetails(pydantic.BaseModel):
versions: list[str]
last_updated: datetime.datetime
active_version_id: str | None = None
has_approved_version: bool = False
class Creator(pydantic.BaseModel):
name: str
@@ -116,6 +120,19 @@ class StoreSubmission(pydantic.BaseModel):
runs: int
rating: float
store_listing_version_id: str | None = None
version: int | None = None # Actual version number from the database
reviewer_id: str | None = None
review_comments: str | None = None # External comments visible to creator
internal_comments: str | None = None # Private notes for admin use only
reviewed_at: datetime.datetime | None = None
changes_summary: str | None = None
reviewer_id: str | None = None
review_comments: str | None = None # External comments visible to creator
internal_comments: str | None = None # Private notes for admin use only
reviewed_at: datetime.datetime | None = None
changes_summary: str | None = None
class StoreSubmissionsResponse(pydantic.BaseModel):
@@ -123,6 +140,27 @@ class StoreSubmissionsResponse(pydantic.BaseModel):
pagination: Pagination
class StoreListingWithVersions(pydantic.BaseModel):
"""A store listing with its version history"""
listing_id: str
slug: str
agent_id: str
agent_version: int
active_version_id: str | None = None
has_approved_version: bool = False
creator_email: str | None = None
latest_version: StoreSubmission | None = None
versions: list[StoreSubmission] = []
class StoreListingsWithVersionsResponse(pydantic.BaseModel):
"""Response model for listings with version history"""
listings: list[StoreListingWithVersions]
pagination: Pagination
class StoreSubmissionRequest(pydantic.BaseModel):
agent_id: str
agent_version: int
@@ -133,6 +171,7 @@ class StoreSubmissionRequest(pydantic.BaseModel):
image_urls: list[str] = []
description: str = ""
categories: list[str] = []
changes_summary: str | None = None
class ProfileDetails(pydantic.BaseModel):
@@ -157,4 +196,5 @@ class StoreReviewCreate(pydantic.BaseModel):
class ReviewSubmissionRequest(pydantic.BaseModel):
store_listing_version_id: str
is_approved: bool
comments: str
comments: str # External comments visible to creator
internal_comments: str | None = None # Private admin notes

View File

@@ -1,4 +1,3 @@
import json
import logging
import tempfile
import typing
@@ -8,7 +7,6 @@ import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import fastapi
import fastapi.responses
from fastapi.encoders import jsonable_encoder
import backend.data.block
import backend.data.graph
@@ -16,6 +14,7 @@ import backend.server.v2.store.db
import backend.server.v2.store.image_gen
import backend.server.v2.store.media
import backend.server.v2.store.model
import backend.util.json
logger = logging.getLogger(__name__)
@@ -35,7 +34,7 @@ router = fastapi.APIRouter()
async def get_profile(
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
]
],
):
"""
Get the profile details for the authenticated user.
@@ -339,7 +338,7 @@ async def get_creator(
async def get_my_agents(
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
]
],
):
try:
agents = await backend.server.v2.store.db.get_my_agents(user_id)
@@ -467,7 +466,7 @@ async def create_submission(
HTTPException: If there is an error creating the submission
"""
try:
submission = await backend.server.v2.store.db.create_store_submission(
return await backend.server.v2.store.db.create_store_submission(
user_id=user_id,
agent_id=submission_request.agent_id,
agent_version=submission_request.agent_version,
@@ -478,8 +477,8 @@ async def create_submission(
description=submission_request.description,
sub_heading=submission_request.sub_heading,
categories=submission_request.categories,
changes_summary=submission_request.changes_summary or "Initial Submission",
)
return submission
except Exception:
logger.exception("Exception occurred whilst creating store submission")
return fastapi.responses.JSONResponse(
@@ -591,19 +590,18 @@ async def generate_image(
tags=["store", "public"],
)
async def download_agent_file(
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
store_listing_version_id: str = fastapi.Path(
..., description="The ID of the agent to download"
),
version: typing.Optional[int] = fastapi.Query(
None, description="Specific version of the agent"
),
) -> fastapi.responses.FileResponse:
"""
Download the agent file by streaming its content.
Args:
agent_id (str): The ID of the agent to download.
version (Optional[int]): Specific version of the agent to download.
store_listing_version_id (str): The ID of the agent to download
Returns:
StreamingResponse: A streaming response containing the agent's graph data.
@@ -613,65 +611,18 @@ async def download_agent_file(
"""
graph_data = await backend.server.v2.store.db.get_agent(
store_listing_version_id=store_listing_version_id, version_id=version
user_id=user_id,
store_listing_version_id=store_listing_version_id,
)
graph_data.clean_graph()
graph_date_dict = jsonable_encoder(graph_data)
def remove_credentials(obj):
if obj and isinstance(obj, dict):
if "credentials" in obj:
del obj["credentials"]
if "creds" in obj:
del obj["creds"]
for value in obj.values():
remove_credentials(value)
elif isinstance(obj, list):
for item in obj:
remove_credentials(item)
return obj
graph_date_dict = remove_credentials(graph_date_dict)
file_name = f"agent_{store_listing_version_id}_v{version or 'latest'}.json"
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
# Sending graph as a stream (similar to marketplace v1)
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as tmp_file:
tmp_file.write(json.dumps(graph_date_dict))
tmp_file.write(backend.util.json.dumps(graph_data))
tmp_file.flush()
return fastapi.responses.FileResponse(
tmp_file.name, filename=file_name, media_type="application/json"
)
@router.post(
"/submissions/review/{store_listing_version_id}",
tags=["store", "private"],
)
async def review_submission(
request: backend.server.v2.store.model.ReviewSubmissionRequest,
user: typing.Annotated[
autogpt_libs.auth.models.User,
fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user),
],
):
# Proceed with the review submission logic
try:
submission = await backend.server.v2.store.db.review_store_submission(
store_listing_version_id=request.store_listing_version_id,
is_approved=request.is_approved,
comments=request.comments,
reviewer_id=user.user_id,
)
return submission
except Exception as e:
logger.error(f"Could not create store submission review: {e}")
raise fastapi.HTTPException(
status_code=500,
detail="An error occurred while creating the store submission review",
)

View File

@@ -4,6 +4,7 @@ from contextlib import asynccontextmanager
import uvicorn
from autogpt_libs.auth import parse_jwt_token
from autogpt_libs.utils.cache import thread_cached
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
from starlette.middleware.cors import CORSMiddleware
@@ -12,7 +13,7 @@ from backend.data.execution import AsyncRedisExecutionEventBus
from backend.data.user import DEFAULT_USER_ID
from backend.server.conn_manager import ConnectionManager
from backend.server.model import ExecutionSubscription, Methods, WsMessage
from backend.util.service import AppProcess
from backend.util.service import AppProcess, get_service_client
from backend.util.settings import AppEnvironment, Config, Settings
logger = logging.getLogger(__name__)
@@ -39,6 +40,13 @@ def get_connection_manager():
return _connection_manager
@thread_cached
def get_db_client():
from backend.executor import DatabaseManager
return get_service_client(DatabaseManager)
async def event_broadcaster(manager: ConnectionManager):
try:
redis.connect()
@@ -74,7 +82,10 @@ async def authenticate_websocket(websocket: WebSocket) -> str:
async def handle_subscribe(
websocket: WebSocket, manager: ConnectionManager, message: WsMessage
connection_manager: ConnectionManager,
websocket: WebSocket,
user_id: str,
message: WsMessage,
):
if not message.data:
await websocket.send_text(
@@ -85,20 +96,47 @@ async def handle_subscribe(
).model_dump_json()
)
else:
ex_sub = ExecutionSubscription.model_validate(message.data)
await manager.subscribe(ex_sub.graph_id, ex_sub.graph_version, websocket)
logger.debug(f"New execution subscription for graph {ex_sub.graph_id}")
sub_req = ExecutionSubscription.model_validate(message.data)
# Verify that user has read access to graph
# if not get_db_client().get_graph(
# graph_id=sub_req.graph_id,
# version=sub_req.graph_version,
# user_id=user_id,
# ):
# await websocket.send_text(
# WsMessage(
# method=Methods.ERROR,
# success=False,
# error="Access denied",
# ).model_dump_json()
# )
# return
await connection_manager.subscribe(
user_id=user_id,
graph_id=sub_req.graph_id,
graph_version=sub_req.graph_version,
websocket=websocket,
)
logger.debug(
f"New execution subscription for user #{user_id} "
f"graph #{sub_req.graph_id}v{sub_req.graph_version}"
)
await websocket.send_text(
WsMessage(
method=Methods.SUBSCRIBE,
success=True,
channel=f"{ex_sub.graph_id}_{ex_sub.graph_version}",
channel=f"{user_id}_{sub_req.graph_id}_{sub_req.graph_version}",
).model_dump_json()
)
async def handle_unsubscribe(
websocket: WebSocket, manager: ConnectionManager, message: WsMessage
connection_manager: ConnectionManager,
websocket: WebSocket,
user_id: str,
message: WsMessage,
):
if not message.data:
await websocket.send_text(
@@ -109,14 +147,22 @@ async def handle_unsubscribe(
).model_dump_json()
)
else:
ex_sub = ExecutionSubscription.model_validate(message.data)
await manager.unsubscribe(ex_sub.graph_id, ex_sub.graph_version, websocket)
logger.debug(f"Removed execution subscription for graph {ex_sub.graph_id}")
unsub_req = ExecutionSubscription.model_validate(message.data)
await connection_manager.unsubscribe(
user_id=user_id,
graph_id=unsub_req.graph_id,
graph_version=unsub_req.graph_version,
websocket=websocket,
)
logger.debug(
f"Removed execution subscription for user #{user_id} "
f"graph #{unsub_req.graph_id}v{unsub_req.graph_version}"
)
await websocket.send_text(
WsMessage(
method=Methods.UNSUBSCRIBE,
success=True,
channel=f"{ex_sub.graph_id}_{ex_sub.graph_version}",
channel=f"{unsub_req.graph_id}_{unsub_req.graph_version}",
).model_dump_json()
)
@@ -145,13 +191,32 @@ async def websocket_router(
)
continue
if message.method == Methods.SUBSCRIBE:
await handle_subscribe(websocket, manager, message)
try:
if message.method == Methods.SUBSCRIBE:
await handle_subscribe(
connection_manager=manager,
websocket=websocket,
user_id=user_id,
message=message,
)
continue
elif message.method == Methods.UNSUBSCRIBE:
await handle_unsubscribe(websocket, manager, message)
elif message.method == Methods.UNSUBSCRIBE:
await handle_unsubscribe(
connection_manager=manager,
websocket=websocket,
user_id=user_id,
message=message,
)
continue
except Exception as e:
logger.error(
f"Error while handling '{message.method}' message "
f"for user #{user_id}: {e}"
)
continue
elif message.method == Methods.ERROR:
if message.method == Methods.ERROR:
logger.error(f"WebSocket Error message received: {message.data}")
else:

View File

@@ -1,6 +1,7 @@
from prisma.models import User
from backend.blocks.basic import AgentInputBlock, PrintToConsoleBlock
from backend.blocks.basic import StoreValueBlock
from backend.blocks.io import AgentInputBlock
from backend.blocks.text import FillTextTemplateBlock
from backend.data import graph
from backend.data.graph import create_graph
@@ -29,7 +30,7 @@ def create_test_graph() -> graph.Graph:
"""
InputBlock
\
---- FillTextTemplateBlock ---- PrintToConsoleBlock
---- FillTextTemplateBlock ---- StoreValueBlock
/
InputBlock
"""
@@ -52,7 +53,7 @@ def create_test_graph() -> graph.Graph:
"values_#_c": "!!!",
},
),
graph.Node(block_id=PrintToConsoleBlock().id),
graph.Node(block_id=StoreValueBlock().id),
]
links = [
graph.Link(
@@ -71,7 +72,7 @@ def create_test_graph() -> graph.Graph:
source_id=nodes[2].id,
sink_id=nodes[3].id,
source_name="output",
sink_name="text",
sink_name="input",
),
]
@@ -93,11 +94,7 @@ async def sample_agent():
user_id=test_user.id,
node_input=input_data,
)
print(response)
result = await wait_execution(
test_user.id, test_graph.id, response.graph_exec_id, 10
)
print(result)
await wait_execution(test_user.id, test_graph.id, response.graph_exec_id, 10)
if __name__ == "__main__":

View File

@@ -29,15 +29,25 @@ def clean_exec_files(graph_exec_id: str, file: str = "") -> None:
shutil.rmtree(exec_path)
"""
MediaFile is a string that represents a file. It can be one of the following:
- Data URI: base64 encoded media file. See https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data/
- URL: Media file hosted on the internet, it starts with http:// or https://.
- Local path (anything else): A temporary file path living within graph execution time.
Note: Replace this type alias into a proper class, when more information is needed.
"""
MediaFile = str
class MediaFile(str):
"""
MediaFile is a string that represents a file. It can be one of the following:
- Data URI: base64 encoded media file. See https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data/
- URL: Media file hosted on the internet, it starts with http:// or https://.
- Local path (anything else): A temporary file path living within graph execution time.
Note: Replace this type alias into a proper class, when more information is needed.
"""
@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler):
return handler(str)
@classmethod
def __get_pydantic_json_schema__(cls, core_schema, handler):
json_schema = handler(core_schema)
json_schema["format"] = "file"
return json_schema
def store_media_file(

View File

@@ -44,7 +44,7 @@ from Pyro5 import config as pyro_config
from backend.data import db, rabbitmq, redis
from backend.util.exceptions import InsufficientBalanceError
from backend.util.json import to_dict
from backend.util.process import AppProcess
from backend.util.process import AppProcess, get_service_name
from backend.util.retry import conn_retry
from backend.util.settings import Config, Secrets
@@ -190,7 +190,17 @@ class BaseAppService(AppProcess, ABC):
@classmethod
def get_host(cls) -> str:
return os.environ.get(f"{cls.service_name.upper()}_HOST", api_host)
source_host = os.environ.get(f"{get_service_name().upper()}_HOST", api_host)
target_host = os.environ.get(f"{cls.service_name.upper()}_HOST", api_host)
if source_host == target_host and source_host != api_host:
logger.warning(
f"Service {cls.service_name} is the same host as the source service."
f"Use the localhost of {api_host} instead."
)
return api_host
return target_host
@property
def rabbit(self) -> rabbitmq.AsyncRabbitMQ:
@@ -455,7 +465,7 @@ def fastapi_get_service_client(
return response.json()
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error in {method_name}: {e.response.text}")
error = RemoteCallError.model_validate(e.response.json(), strict=False)
error = RemoteCallError.model_validate(e.response.json())
# DEBUG HELP: if you made a custom exception, make sure you override self.args to be how to make your exception
raise EXCEPTION_MAPPING.get(error.type, Exception)(
*(error.args or [str(e)])

View File

@@ -113,6 +113,14 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default="%Y-%W", # This will allow for weekly refunds per user.
description="Time key format for refund requests.",
)
execution_cost_count_threshold: int = Field(
default=100,
description="Number of executions after which the cost is calculated.",
)
execution_cost_per_threshold: int = Field(
default=1,
description="Cost per execution in cents after each threshold.",
)
model_config = SettingsConfigDict(
env_file=".env",
@@ -219,6 +227,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default=True,
description="Whether to use the new agent image generation service",
)
enable_agent_input_subtype_blocks: bool = Field(
default=False,
description="Whether to enable the agent input subtype blocks",
)
@field_validator("platform_base_url", "frontend_base_url")
@classmethod

View File

@@ -0,0 +1,8 @@
/*
Warnings:
- You are about to drop the column `isTemplate` on the `AgentGraph` table. All the data in the column will be lost.
*/
-- AlterTable
ALTER TABLE "AgentGraph" DROP COLUMN "isTemplate";

View File

@@ -0,0 +1,372 @@
/*
Warnings:
- The enum type "SubmissionStatus" will be replaced. The 'DAFT' value is removed, so any data using 'DAFT' will be updated to 'DRAFT'. If there are rows still expecting 'DAFT' after this change, it will fail.
- You are about to drop the column "isApproved" on the "StoreListing" table. All the data in that column will be lost.
- You are about to drop the column "slug" on the "StoreListingVersion" table. All the data in that column will be lost.
- You are about to drop the "StoreListingSubmission" table. Data in that table (beyond what is copied over) will be permanently lost.
- A unique constraint covering the column "activeVersionId" on the "StoreListing" table will be added. If duplicates already exist, this will fail.
- A unique constraint covering the columns ("storeListingId","version") on "StoreListingVersion" will be added. If duplicates already exist, this will fail.
- The "storeListingId" column on "StoreListingVersion" is set to NOT NULL. If any rows currently have a NULL value, this step will fail.
- The views "StoreSubmission", "StoreAgent", and "Creator" are dropped and recreated. Any usage or references to them will be momentarily disrupted until the views are recreated.
*/
BEGIN;
-- First, drop all views that depend on the columns and types we're modifying
DROP VIEW IF EXISTS "StoreSubmission";
DROP VIEW IF EXISTS "StoreAgent";
DROP VIEW IF EXISTS "Creator";
-- Create the new enum type
CREATE TYPE "SubmissionStatus_new" AS ENUM ('DRAFT', 'PENDING', 'APPROVED', 'REJECTED');
-- Modify the column with the correct casing (Status with capital S)
ALTER TABLE "StoreListingSubmission" ALTER COLUMN "Status" DROP DEFAULT;
ALTER TABLE "StoreListingSubmission"
ALTER COLUMN "Status" TYPE "SubmissionStatus_new"
USING (
CASE WHEN "Status"::text = 'DAFT' THEN 'DRAFT'::text
ELSE "Status"::text
END
)::"SubmissionStatus_new";
-- Rename the enum types
ALTER TYPE "SubmissionStatus" RENAME TO "SubmissionStatus_old";
ALTER TYPE "SubmissionStatus_new" RENAME TO "SubmissionStatus";
DROP TYPE "SubmissionStatus_old";
-- Set default back
ALTER TABLE "StoreListingSubmission" ALTER COLUMN "Status" SET DEFAULT 'PENDING';
-- Drop constraints
ALTER TABLE "StoreListingSubmission" DROP CONSTRAINT IF EXISTS "StoreListingSubmission_reviewerId_fkey";
-- Drop indexes
DROP INDEX IF EXISTS "StoreListing_isDeleted_isApproved_idx";
DROP INDEX IF EXISTS "StoreListingSubmission_storeListingVersionId_key";
-- Modify StoreListing
ALTER TABLE "StoreListing"
DROP COLUMN IF EXISTS "isApproved",
ADD COLUMN IF NOT EXISTS "activeVersionId" TEXT,
ADD COLUMN IF NOT EXISTS "hasApprovedVersion" BOOLEAN NOT NULL DEFAULT false,
ADD COLUMN IF NOT EXISTS "slug" TEXT;
-- First add ALL columns to StoreListingVersion (including the submissionStatus column)
ALTER TABLE "StoreListingVersion"
ADD COLUMN IF NOT EXISTS "reviewerId" TEXT,
ADD COLUMN IF NOT EXISTS "reviewComments" TEXT,
ADD COLUMN IF NOT EXISTS "internalComments" TEXT,
ADD COLUMN IF NOT EXISTS "reviewedAt" TIMESTAMP(3),
ADD COLUMN IF NOT EXISTS "changesSummary" TEXT,
ADD COLUMN IF NOT EXISTS "submissionStatus" "SubmissionStatus" NOT NULL DEFAULT 'DRAFT',
ADD COLUMN IF NOT EXISTS "submittedAt" TIMESTAMP(3),
ALTER COLUMN "storeListingId" SET NOT NULL;
-- NOW copy data from StoreListingSubmission to StoreListingVersion
DO $$
BEGIN
-- First, check what columns actually exist in the StoreListingSubmission table
DECLARE
has_reviewerId BOOLEAN := (
SELECT EXISTS (
SELECT FROM information_schema.columns
WHERE table_name = 'StoreListingSubmission'
AND column_name = 'reviewerId'
)
);
has_reviewComments BOOLEAN := (
SELECT EXISTS (
SELECT FROM information_schema.columns
WHERE table_name = 'StoreListingSubmission'
AND column_name = 'reviewComments'
)
);
has_changesSummary BOOLEAN := (
SELECT EXISTS (
SELECT FROM information_schema.columns
WHERE table_name = 'StoreListingSubmission'
AND column_name = 'changesSummary'
)
);
BEGIN
-- Only copy fields that we know exist
IF has_reviewerId THEN
UPDATE "StoreListingVersion" AS v
SET "reviewerId" = s."reviewerId"
FROM "StoreListingSubmission" AS s
WHERE v."id" = s."storeListingVersionId";
END IF;
IF has_reviewComments THEN
UPDATE "StoreListingVersion" AS v
SET "reviewComments" = s."reviewComments"
FROM "StoreListingSubmission" AS s
WHERE v."id" = s."storeListingVersionId";
END IF;
IF has_changesSummary THEN
UPDATE "StoreListingVersion" AS v
SET "changesSummary" = s."changesSummary"
FROM "StoreListingSubmission" AS s
WHERE v."id" = s."storeListingVersionId";
END IF;
END;
-- Update submission status based on StoreListingSubmission status
UPDATE "StoreListingVersion" AS v
SET "submissionStatus" = s."Status"
FROM "StoreListingSubmission" AS s
WHERE v."id" = s."storeListingVersionId";
-- Update reviewedAt timestamps for versions with APPROVED or REJECTED status
UPDATE "StoreListingVersion" AS v
SET "reviewedAt" = s."updatedAt"
FROM "StoreListingSubmission" AS s
WHERE v."id" = s."storeListingVersionId"
AND s."Status" IN ('APPROVED', 'REJECTED');
END;
$$;
-- Drop the StoreListingSubmission table
DROP TABLE IF EXISTS "StoreListingSubmission";
-- Copy slugs from StoreListingVersion to StoreListing
WITH latest_versions AS (
SELECT
"storeListingId",
"slug",
ROW_NUMBER() OVER (PARTITION BY "storeListingId" ORDER BY "version" DESC) as rn
FROM "StoreListingVersion"
)
UPDATE "StoreListing" sl
SET "slug" = lv."slug"
FROM latest_versions lv
WHERE sl."id" = lv."storeListingId"
AND lv.rn = 1;
-- Make StoreListing.slug required and unique
ALTER TABLE "StoreListing" ALTER COLUMN "slug" SET NOT NULL;
CREATE UNIQUE INDEX "StoreListing_owningUserId_slug_key" ON "StoreListing"("owningUserId", "slug");
DROP INDEX "StoreListing_owningUserId_idx";
-- Drop the slug column from StoreListingVersion since it's now on StoreListing
ALTER TABLE "StoreListingVersion" DROP COLUMN "slug";
-- Update both sides of the relation from one-to-one to one-to-many
-- The AgentGraph->StoreListingVersion relationship is now one-to-many
-- Drop the unique constraint but add a non-unique index for query performance
ALTER TABLE "StoreListingVersion" DROP CONSTRAINT IF EXISTS "StoreListingVersion_agentId_agentVersion_key";
CREATE INDEX IF NOT EXISTS "StoreListingVersion_agentId_agentVersion_idx"
ON "StoreListingVersion"("agentId", "agentVersion");
-- Set isApproved based on submissionStatus before removing it
UPDATE "StoreListingVersion"
SET "submissionStatus" = 'APPROVED'
WHERE "isApproved" = true;
-- Drop the isApproved column from StoreListingVersion since it's redundant with submissionStatus
ALTER TABLE "StoreListingVersion" DROP COLUMN "isApproved";
-- Initialize hasApprovedVersion for existing StoreListing rows ***
-- This sets "hasApprovedVersion" = TRUE for any StoreListing
-- that has at least one corresponding version with "APPROVED" status.
UPDATE "StoreListing" sl
SET "hasApprovedVersion" = (
SELECT COUNT(*) > 0
FROM "StoreListingVersion" slv
WHERE slv."storeListingId" = sl.id
AND slv."submissionStatus" = 'APPROVED'
AND sl."agentId" = slv."agentId"
AND sl."agentVersion" = slv."agentVersion"
);
-- Create new indexes
CREATE UNIQUE INDEX IF NOT EXISTS "StoreListing_activeVersionId_key"
ON "StoreListing"("activeVersionId");
CREATE INDEX IF NOT EXISTS "StoreListing_isDeleted_hasApprovedVersion_idx"
ON "StoreListing"("isDeleted", "hasApprovedVersion");
CREATE INDEX IF NOT EXISTS "StoreListingVersion_storeListingId_submissionStatus_isAvailable_idx"
ON "StoreListingVersion"("storeListingId", "submissionStatus", "isAvailable");
CREATE INDEX IF NOT EXISTS "StoreListingVersion_submissionStatus_idx"
ON "StoreListingVersion"("submissionStatus");
CREATE UNIQUE INDEX IF NOT EXISTS "StoreListingVersion_storeListingId_version_key"
ON "StoreListingVersion"("storeListingId", "version");
-- Add foreign keys
ALTER TABLE "StoreListing"
ADD CONSTRAINT "StoreListing_activeVersionId_fkey"
FOREIGN KEY ("activeVersionId") REFERENCES "StoreListingVersion"("id")
ON DELETE SET NULL ON UPDATE CASCADE;
-- Add reviewer foreign key
ALTER TABLE "StoreListingVersion"
ADD CONSTRAINT "StoreListingVersion_reviewerId_fkey"
FOREIGN KEY ("reviewerId") REFERENCES "User"("id")
ON DELETE SET NULL ON UPDATE CASCADE;
-- Add index for reviewer
CREATE INDEX IF NOT EXISTS "StoreListingVersion_reviewerId_idx"
ON "StoreListingVersion"("reviewerId");
-- DropIndex
DROP INDEX "StoreListingVersion_agentId_agentVersion_key";
-- RenameIndex
ALTER INDEX "StoreListingVersion_storeListingId_submissionStatus_isAvailable_idx"
RENAME TO "StoreListingVersion_storeListingId_submissionStatus_isAvail_idx";
-- Recreate the views with updated column references
-- 1. Recreate StoreSubmission view
CREATE VIEW "StoreSubmission" AS
SELECT
sl.id AS listing_id,
sl."owningUserId" AS user_id,
slv."agentId" AS agent_id,
slv.version AS agent_version,
sl.slug,
COALESCE(slv.name, '') AS name,
slv."subHeading" AS sub_heading,
slv.description,
slv."imageUrls" AS image_urls,
slv."submittedAt" AS date_submitted,
slv."submissionStatus" AS status,
COALESCE(ar.run_count, 0::bigint) AS runs,
COALESCE(avg(sr.score::numeric), 0.0)::double precision AS rating,
-- Add the additional fields needed by the Pydantic model
slv.id AS store_listing_version_id,
slv."reviewerId" AS reviewer_id,
slv."reviewComments" AS review_comments,
slv."internalComments" AS internal_comments,
slv."reviewedAt" AS reviewed_at,
slv."changesSummary" AS changes_summary
FROM "StoreListing" sl
JOIN "StoreListingVersion" slv ON slv."storeListingId" = sl.id
LEFT JOIN "StoreListingReview" sr ON sr."storeListingVersionId" = slv.id
LEFT JOIN (
SELECT "AgentGraphExecution"."agentGraphId", count(*) AS run_count
FROM "AgentGraphExecution"
GROUP BY "AgentGraphExecution"."agentGraphId"
) ar ON ar."agentGraphId" = slv."agentId"
WHERE sl."isDeleted" = false
GROUP BY sl.id, sl."owningUserId", slv.id, slv."agentId", slv.version, sl.slug, slv.name,
slv."subHeading", slv.description, slv."imageUrls", slv."submittedAt",
slv."submissionStatus", slv."reviewerId", slv."reviewComments", slv."internalComments",
slv."reviewedAt", slv."changesSummary", ar.run_count;
-- 2. Recreate StoreAgent view
CREATE VIEW "StoreAgent" AS
WITH reviewstats AS (
SELECT sl_1.id AS "storeListingId",
count(sr.id) AS review_count,
avg(sr.score::numeric) AS avg_rating
FROM "StoreListing" sl_1
JOIN "StoreListingVersion" slv_1
ON slv_1."storeListingId" = sl_1.id
JOIN "StoreListingReview" sr
ON sr."storeListingVersionId" = slv_1.id
WHERE sl_1."isDeleted" = false
GROUP BY sl_1.id
), agentruns AS (
SELECT "AgentGraphExecution"."agentGraphId",
count(*) AS run_count
FROM "AgentGraphExecution"
GROUP BY "AgentGraphExecution"."agentGraphId"
)
SELECT sl.id AS listing_id,
slv.id AS "storeListingVersionId",
slv."createdAt" AS updated_at,
sl.slug,
COALESCE(slv.name, '') AS agent_name,
slv."videoUrl" AS agent_video,
COALESCE(slv."imageUrls", ARRAY[]::text[]) AS agent_image,
slv."isFeatured" AS featured,
p.username AS creator_username,
p."avatarUrl" AS creator_avatar,
slv."subHeading" AS sub_heading,
slv.description,
slv.categories,
COALESCE(ar.run_count, 0::bigint) AS runs,
COALESCE(rs.avg_rating, 0.0)::double precision AS rating,
array_agg(DISTINCT slv.version::text) AS versions
FROM "StoreListing" sl
JOIN "AgentGraph" a
ON sl."agentId" = a.id
AND sl."agentVersion" = a.version
LEFT JOIN "Profile" p
ON sl."owningUserId" = p."userId"
LEFT JOIN "StoreListingVersion" slv
ON slv."storeListingId" = sl.id
LEFT JOIN reviewstats rs
ON sl.id = rs."storeListingId"
LEFT JOIN agentruns ar
ON a.id = ar."agentGraphId"
WHERE sl."isDeleted" = false
AND sl."hasApprovedVersion" = true
AND slv."submissionStatus" = 'APPROVED'
GROUP BY sl.id, slv.id, sl.slug, slv."createdAt", slv.name, slv."videoUrl",
slv."imageUrls", slv."isFeatured", p.username, p."avatarUrl",
slv."subHeading", slv.description, slv.categories, ar.run_count,
rs.avg_rating;
-- 3. Recreate Creator view
CREATE VIEW "Creator" AS
WITH agentstats AS (
SELECT p_1.username,
count(DISTINCT sl.id) AS num_agents,
avg(COALESCE(sr.score, 0)::numeric) AS agent_rating,
sum(COALESCE(age.run_count, 0::bigint)) AS agent_runs
FROM "Profile" p_1
LEFT JOIN "StoreListing" sl
ON sl."owningUserId" = p_1."userId"
LEFT JOIN "StoreListingVersion" slv
ON slv."storeListingId" = sl.id
LEFT JOIN "StoreListingReview" sr
ON sr."storeListingVersionId" = slv.id
LEFT JOIN (
SELECT "AgentGraphExecution"."agentGraphId",
count(*) AS run_count
FROM "AgentGraphExecution"
GROUP BY "AgentGraphExecution"."agentGraphId"
) age ON age."agentGraphId" = sl."agentId"
WHERE sl."isDeleted" = false
AND sl."hasApprovedVersion" = true
AND slv."submissionStatus" = 'APPROVED'
GROUP BY p_1.username
)
SELECT p.username,
p.name,
p."avatarUrl" AS avatar_url,
p.description,
array_agg(DISTINCT cats.c) FILTER (WHERE cats.c IS NOT NULL) AS top_categories,
p.links,
p."isFeatured" AS is_featured,
COALESCE(ast.num_agents, 0::bigint) AS num_agents,
COALESCE(ast.agent_rating, 0.0) AS agent_rating,
COALESCE(ast.agent_runs, 0::numeric) AS agent_runs
FROM "Profile" p
LEFT JOIN agentstats ast
ON ast.username = p.username
LEFT JOIN LATERAL (
SELECT unnest(slv.categories) AS c
FROM "StoreListing" sl
JOIN "StoreListingVersion" slv
ON slv."storeListingId" = sl.id
WHERE sl."owningUserId" = p."userId"
AND sl."isDeleted" = false
AND sl."hasApprovedVersion" = true
AND slv."submissionStatus" = 'APPROVED'
) cats ON true
GROUP BY p.username, p.name, p."avatarUrl", p.description, p.links,
p."isFeatured", ast.num_agents, ast.agent_rating, ast.agent_runs;
COMMIT;

View File

@@ -44,14 +44,14 @@ model User {
AgentPreset AgentPreset[]
LibraryAgent LibraryAgent[]
Profile Profile[]
UserOnboarding UserOnboarding?
StoreListing StoreListing[]
StoreListingReview StoreListingReview[]
StoreListingSubmission StoreListingSubmission[]
APIKeys APIKey[]
IntegrationWebhooks IntegrationWebhook[]
UserNotificationBatch UserNotificationBatch[]
Profile Profile[]
UserOnboarding UserOnboarding?
StoreListing StoreListing[]
StoreListingReview StoreListingReview[]
StoreVersionsReviewed StoreListingVersion[]
APIKeys APIKey[]
IntegrationWebhooks IntegrationWebhook[]
UserNotificationBatch UserNotificationBatch[]
@@index([id])
@@index([email])
@@ -71,7 +71,7 @@ model UserOnboarding {
isCompleted Boolean @default(false)
userId String @unique
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@index([userId])
}
@@ -86,14 +86,13 @@ model AgentGraph {
name String?
description String?
isActive Boolean @default(true)
isTemplate Boolean @default(false)
isActive Boolean @default(true)
// Link to User model
userId String
// FIX: Do not cascade delete the agent when the user is deleted
// This allows us to delete user data with deleting the agent which maybe in use by other users
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
AgentNodes AgentNode[]
AgentGraphExecution AgentGraphExecution[]
@@ -101,7 +100,7 @@ model AgentGraph {
AgentPreset AgentPreset[]
LibraryAgent LibraryAgent[]
StoreListing StoreListing[]
StoreListingVersion StoreListingVersion?
StoreListingVersion StoreListingVersion[]
@@id(name: "graphVersionId", [id, version])
@@index([userId, isActive])
@@ -176,11 +175,11 @@ model UserNotificationBatch {
updatedAt DateTime @default(now()) @updatedAt
userId String
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
type NotificationType
notifications NotificationEvent[]
Notifications NotificationEvent[]
// Each user can only have one batch of a notification type at a time
@@unique([userId, type])
@@ -196,7 +195,7 @@ model LibraryAgent {
userId String
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
imageUrl String?
imageUrl String?
agentId String
agentVersion Int
@@ -320,7 +319,7 @@ model AgentGraphExecution {
// Link to User model -- Executed by this user
userId String
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
stats Json?
AgentPreset AgentPreset? @relation(fields: [agentPresetId], references: [id])
@@ -385,7 +384,7 @@ model IntegrationWebhook {
updatedAt DateTime? @updatedAt
userId String
user User @relation(fields: [userId], references: [id], onDelete: Restrict) // Webhooks must be deregistered before deleting
User User @relation(fields: [userId], references: [id], onDelete: Restrict) // Webhooks must be deregistered before deleting
provider String // e.g. 'github'
credentialsId String // relation to the credentials that the webhook was created with
@@ -412,7 +411,7 @@ model AnalyticsDetails {
// Link to User model
userId String
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
// Analytics Categorical data used for filtering (indexable w and w/o userId)
type String
@@ -447,7 +446,7 @@ model AnalyticsMetrics {
// Link to User model
userId String
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@index([userId])
}
@@ -471,7 +470,7 @@ model CreditTransaction {
createdAt DateTime @default(now())
userId String
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
amount Int
type CreditTransactionType
@@ -580,19 +579,25 @@ view StoreAgent {
}
view StoreSubmission {
listing_id String @id
user_id String
slug String
name String
sub_heading String
description String
image_urls String[]
date_submitted DateTime
status SubmissionStatus
runs Int
rating Float
agent_id String
agent_version Int
listing_id String @id
user_id String
slug String
name String
sub_heading String
description String
image_urls String[]
date_submitted DateTime
status SubmissionStatus
runs Int
rating Float
agent_id String
agent_version Int
store_listing_version_id String
reviewer_id String?
review_comments String?
internal_comments String?
reviewed_at DateTime?
changes_summary String?
// Index or unique are not applied to views
}
@@ -602,11 +607,18 @@ model StoreListing {
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
isDeleted Boolean @default(false)
// Not needed but makes lookups faster
isApproved Boolean @default(false)
isDeleted Boolean @default(false)
// Whether any version has been approved and is available for display
hasApprovedVersion Boolean @default(false)
// The agent link here is only so we can do lookup on agentId, for the listing the StoreListingVersion is used.
// URL-friendly identifier for this agent (moved from StoreListingVersion)
slug String
// The currently active version that should be shown to users
activeVersionId String? @unique
ActiveVersion StoreListingVersion? @relation("ActiveVersion", fields: [activeVersionId], references: [id])
// The agent link here is only so we can do lookup on agentId
agentId String
agentVersion Int
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)
@@ -614,14 +626,14 @@ model StoreListing {
owningUserId String
OwningUser User @relation(fields: [owningUserId], references: [id])
StoreListingVersions StoreListingVersion[]
StoreListingSubmission StoreListingSubmission[]
// Relations
Versions StoreListingVersion[] @relation("ListingVersions")
// Unique index on agentId to ensure only one listing per agent, regardless of number of versions the agent has.
@@unique([agentId])
@@index([owningUserId])
@@unique([owningUserId, slug])
// Used in the view query
@@index([isDeleted, isApproved])
@@index([isDeleted, hasApprovedVersion])
}
model StoreListingVersion {
@@ -635,10 +647,7 @@ model StoreListingVersion {
agentVersion Int
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version])
// The details for this version of the agent, this allows the author to update the details of the agent,
// But still allow using old versions of the agent with there original details.
// TODO: Create a database view that shows only the latest version of each store listing.
slug String
// Content fields
name String
subHeading String
videoUrl String?
@@ -648,20 +657,39 @@ model StoreListingVersion {
isFeatured Boolean @default(false)
isDeleted Boolean @default(false)
isDeleted Boolean @default(false)
// Old versions can be made unavailable by the author if desired
isAvailable Boolean @default(true)
// Not needed but makes lookups faster
isApproved Boolean @default(false)
StoreListing StoreListing? @relation(fields: [storeListingId], references: [id], onDelete: Cascade)
storeListingId String?
StoreListingSubmission StoreListingSubmission[]
isAvailable Boolean @default(true)
// Reviews are on a specific version, but then aggregated up to the listing.
// This allows us to provide a review filter to current version of the agent.
StoreListingReview StoreListingReview[]
// Version workflow state
submissionStatus SubmissionStatus @default(DRAFT)
submittedAt DateTime?
@@unique([agentId, agentVersion])
// Relations
storeListingId String
StoreListing StoreListing @relation("ListingVersions", fields: [storeListingId], references: [id], onDelete: Cascade)
// This version might be the active version for a listing
ActiveFor StoreListing? @relation("ActiveVersion")
// Submission history
changesSummary String?
// Review information
reviewerId String?
Reviewer User? @relation(fields: [reviewerId], references: [id])
internalComments String? // Private notes for admin use only
reviewComments String? // Comments visible to creator
reviewedAt DateTime?
// Reviews for this specific version
Reviews StoreListingReview[]
@@unique([storeListingId, version])
@@index([storeListingId, submissionStatus, isAvailable])
@@index([submissionStatus])
@@index([reviewerId])
@@index([agentId, agentVersion]) // Non-unique index for efficient lookups
}
model StoreListingReview {
@@ -682,31 +710,10 @@ model StoreListingReview {
}
enum SubmissionStatus {
DAFT
PENDING
APPROVED
REJECTED
}
model StoreListingSubmission {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
storeListingId String
StoreListing StoreListing @relation(fields: [storeListingId], references: [id], onDelete: Cascade)
storeListingVersionId String
StoreListingVersion StoreListingVersion @relation(fields: [storeListingVersionId], references: [id], onDelete: Cascade)
reviewerId String
Reviewer User @relation(fields: [reviewerId], references: [id])
Status SubmissionStatus @default(PENDING)
reviewComments String?
@@unique([storeListingVersionId])
@@index([storeListingId])
DRAFT // Being prepared, not yet submitted
PENDING // Submitted, awaiting review
APPROVED // Reviewed and approved
REJECTED // Reviewed and rejected
}
enum APIKeyPermission {
@@ -733,7 +740,7 @@ model APIKey {
// Relation to user
userId String
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@index([key])
@@index([prefix])

View File

@@ -5,9 +5,11 @@ from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction
from backend.blocks.llm import AITextGeneratorBlock
from backend.data.block import get_block
from backend.data.credit import BetaUserCredit
from backend.data.execution import NodeExecutionEntry
from backend.data.user import DEFAULT_USER_ID
from backend.executor.utils import UsageTransactionMetadata, block_usage_cost
from backend.integrations.credentials_store import openai_credentials
from backend.util.test import SpinTestServer
@@ -27,13 +29,36 @@ async def top_up(amount: int):
)
async def spend_credits(entry: NodeExecutionEntry) -> int:
block = get_block(entry.block_id)
if not block:
raise RuntimeError(f"Block {entry.block_id} not found")
cost, matching_filter = block_usage_cost(block=block, input_data=entry.data)
await user_credit.spend_credits(
entry.user_id,
cost,
UsageTransactionMetadata(
graph_exec_id=entry.graph_exec_id,
graph_id=entry.graph_id,
node_id=entry.node_id,
node_exec_id=entry.node_exec_id,
block_id=entry.block_id,
block=entry.block_id,
input=matching_filter,
),
)
return cost
@pytest.mark.asyncio(scope="session")
async def test_block_credit_usage(server: SpinTestServer):
await disable_test_user_transactions()
await top_up(100)
current_credit = await user_credit.get_credits(DEFAULT_USER_ID)
spending_amount_1 = await user_credit.spend_credits(
spending_amount_1 = await spend_credits(
NodeExecutionEntry(
user_id=DEFAULT_USER_ID,
graph_id="test_graph",
@@ -50,12 +75,10 @@ async def test_block_credit_usage(server: SpinTestServer):
},
},
),
0.0,
0.0,
)
assert spending_amount_1 > 0
spending_amount_2 = await user_credit.spend_credits(
spending_amount_2 = await spend_credits(
NodeExecutionEntry(
user_id=DEFAULT_USER_ID,
graph_id="test_graph",
@@ -65,8 +88,6 @@ async def test_block_credit_usage(server: SpinTestServer):
block_id=AITextGeneratorBlock().id,
data={"model": "gpt-4-turbo", "api_key": "owned_api_key"},
),
0.0,
0.0,
)
assert spending_amount_2 == 0

View File

@@ -6,7 +6,8 @@ import fastapi.exceptions
import pytest
import backend.server.v2.store.model as store
from backend.blocks.basic import AgentInputBlock, AgentOutputBlock, StoreValueBlock
from backend.blocks.basic import StoreValueBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.data.block import BlockSchema
from backend.data.graph import Graph, Link, Node
from backend.data.model import SchemaField
@@ -199,7 +200,9 @@ async def test_clean_graph(server: SpinTestServer):
)
# Clean the graph
created_graph.clean_graph()
created_graph = await server.agent_server.test_get_graph(
created_graph.id, created_graph.version, DEFAULT_USER_ID, for_export=True
)
# # Verify input block value is cleared
input_node = next(
@@ -240,7 +243,7 @@ async def test_access_store_listing_graph(server: SpinTestServer):
store_submission_request = store.StoreSubmissionRequest(
agent_id=created_graph.id,
agent_version=created_graph.version,
slug="test-slug",
slug=created_graph.id,
name="Test name",
sub_heading="Test sub heading",
video_url=None,

View File

@@ -7,7 +7,8 @@ from prisma.models import User
import backend.server.v2.library.model
import backend.server.v2.store.model
from backend.blocks.basic import AgentInputBlock, FindInDictionaryBlock, StoreValueBlock
from backend.blocks.basic import FindInDictionaryBlock, StoreValueBlock
from backend.blocks.io import AgentInputBlock
from backend.blocks.maths import CalculatorBlock, Operation
from backend.data import execution, graph
from backend.server.model import CreateGraph
@@ -123,8 +124,8 @@ async def assert_sample_graph_executions(
logger.info(f"Checking PrintToConsoleBlock execution: {exec}")
assert exec.status == execution.ExecutionStatus.COMPLETED
assert exec.graph_exec_id == graph_exec_id
assert exec.output_data == {"status": ["printed"]}
assert exec.input_data == {"text": "Hello, World!!!"}
assert exec.output_data == {"output": ["Hello, World!!!"]}
assert exec.input_data == {"input": "Hello, World!!!"}
assert exec.node_id == test_graph.nodes[3].id
@@ -494,7 +495,7 @@ async def test_store_listing_graph(server: SpinTestServer):
store_submission_request = backend.server.v2.store.model.StoreSubmissionRequest(
agent_id=test_graph.id,
agent_version=test_graph.version,
slug="test-slug",
slug=test_graph.id,
name="Test name",
sub_heading="Test sub heading",
video_url=None,

View File

@@ -46,17 +46,27 @@ def test_disconnect(
async def test_subscribe(
connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
await connection_manager.subscribe("test_graph", 1, mock_websocket)
assert mock_websocket in connection_manager.subscriptions["test_graph_1"]
await connection_manager.subscribe(
user_id="user-1",
graph_id="test_graph",
graph_version=1,
websocket=mock_websocket,
)
assert mock_websocket in connection_manager.subscriptions["user-1_test_graph_1"]
@pytest.mark.asyncio
async def test_unsubscribe(
connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
connection_manager.subscriptions["test_graph_1"] = {mock_websocket}
connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket}
await connection_manager.unsubscribe("test_graph", 1, mock_websocket)
await connection_manager.unsubscribe(
user_id="user-1",
graph_id="test_graph",
graph_version=1,
websocket=mock_websocket,
)
assert "test_graph" not in connection_manager.subscriptions
@@ -65,8 +75,9 @@ async def test_unsubscribe(
async def test_send_execution_result(
connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
connection_manager.subscriptions["test_graph_1"] = {mock_websocket}
connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket}
result: ExecutionResult = ExecutionResult(
user_id="user-1",
graph_id="test_graph",
graph_version=1,
graph_exec_id="test_exec_id",
@@ -87,17 +98,45 @@ async def test_send_execution_result(
mock_websocket.send_text.assert_called_once_with(
WsMessage(
method=Methods.EXECUTION_EVENT,
channel="test_graph_1",
channel="user-1_test_graph_1",
data=result.model_dump(),
).model_dump_json()
)
@pytest.mark.asyncio
async def test_send_execution_result_user_mismatch(
connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket}
result: ExecutionResult = ExecutionResult(
user_id="user-2",
graph_id="test_graph",
graph_version=1,
graph_exec_id="test_exec_id",
node_exec_id="test_node_exec_id",
node_id="test_node_id",
block_id="test_block_id",
status=ExecutionStatus.COMPLETED,
input_data={"input1": "value1"},
output_data={"output1": ["result1"]},
add_time=datetime.now(tz=timezone.utc),
queue_time=None,
start_time=datetime.now(tz=timezone.utc),
end_time=datetime.now(tz=timezone.utc),
)
await connection_manager.send_execution_result(result)
mock_websocket.send_text.assert_not_called()
@pytest.mark.asyncio
async def test_send_execution_result_no_subscribers(
connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
result: ExecutionResult = ExecutionResult(
user_id="user-1",
graph_id="test_graph",
graph_version=1,
graph_exec_id="test_exec_id",

View File

@@ -4,6 +4,7 @@ from unittest.mock import AsyncMock
import pytest
from fastapi import WebSocket, WebSocketDisconnect
from backend.data.user import DEFAULT_USER_ID
from backend.server.conn_manager import ConnectionManager
from backend.server.ws_api import (
Methods,
@@ -41,7 +42,12 @@ async def test_websocket_router_subscribe(
)
mock_manager.connect.assert_called_once_with(mock_websocket)
mock_manager.subscribe.assert_called_once_with("test_graph", 1, mock_websocket)
mock_manager.subscribe.assert_called_once_with(
user_id=DEFAULT_USER_ID,
graph_id="test_graph",
graph_version=1,
websocket=mock_websocket,
)
mock_websocket.send_text.assert_called_once()
assert '"method":"subscribe"' in mock_websocket.send_text.call_args[0][0]
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
@@ -65,7 +71,12 @@ async def test_websocket_router_unsubscribe(
)
mock_manager.connect.assert_called_once_with(mock_websocket)
mock_manager.unsubscribe.assert_called_once_with("test_graph", 1, mock_websocket)
mock_manager.unsubscribe.assert_called_once_with(
user_id=DEFAULT_USER_ID,
graph_id="test_graph",
graph_version=1,
websocket=mock_websocket,
)
mock_websocket.send_text.assert_called_once()
assert '"method":"unsubscribe"' in mock_websocket.send_text.call_args[0][0]
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
@@ -101,10 +112,18 @@ async def test_handle_subscribe_success(
)
await handle_subscribe(
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager), message
connection_manager=cast(ConnectionManager, mock_manager),
websocket=cast(WebSocket, mock_websocket),
user_id="user-1",
message=message,
)
mock_manager.subscribe.assert_called_once_with("test_graph", 1, mock_websocket)
mock_manager.subscribe.assert_called_once_with(
user_id="user-1",
graph_id="test_graph",
graph_version=1,
websocket=mock_websocket,
)
mock_websocket.send_text.assert_called_once()
assert '"method":"subscribe"' in mock_websocket.send_text.call_args[0][0]
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
@@ -117,7 +136,10 @@ async def test_handle_subscribe_missing_data(
message = WsMessage(method=Methods.SUBSCRIBE)
await handle_subscribe(
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager), message
connection_manager=cast(ConnectionManager, mock_manager),
websocket=cast(WebSocket, mock_websocket),
user_id="user-1",
message=message,
)
mock_manager.subscribe.assert_not_called()
@@ -135,10 +157,18 @@ async def test_handle_unsubscribe_success(
)
await handle_unsubscribe(
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager), message
connection_manager=cast(ConnectionManager, mock_manager),
websocket=cast(WebSocket, mock_websocket),
user_id="user-1",
message=message,
)
mock_manager.unsubscribe.assert_called_once_with("test_graph", 1, mock_websocket)
mock_manager.unsubscribe.assert_called_once_with(
user_id="user-1",
graph_id="test_graph",
graph_version=1,
websocket=mock_websocket,
)
mock_websocket.send_text.assert_called_once()
assert '"method":"unsubscribe"' in mock_websocket.send_text.call_args[0][0]
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
@@ -151,7 +181,10 @@ async def test_handle_unsubscribe_missing_data(
message = WsMessage(method=Methods.UNSUBSCRIBE)
await handle_unsubscribe(
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager), message
connection_manager=cast(ConnectionManager, mock_manager),
websocket=cast(WebSocket, mock_websocket),
user_id="user-1",
message=message,
)
mock_manager.unsubscribe.assert_not_called()

View File

@@ -91,7 +91,6 @@ async def main():
"description": faker.text(max_nb_chars=200),
"userId": user.id,
"isActive": True,
"isTemplate": False,
}
)
agent_graphs.append(graph)
@@ -329,12 +328,14 @@ async def main():
print(f"Inserting {NUM_USERS} store listings")
for graph in agent_graphs:
user = random.choice(users)
slug = faker.slug()
listing = await db.storelisting.create(
data={
"agentId": graph.id,
"agentVersion": graph.version,
"owningUserId": user.id,
"isApproved": random.choice([True, False]),
"hasApprovedVersion": random.choice([True, False]),
"slug": slug,
}
)
store_listings.append(listing)
@@ -348,7 +349,6 @@ async def main():
data={
"agentId": graph.id,
"agentVersion": graph.version,
"slug": faker.slug(),
"name": graph.name or faker.sentence(nb_words=3),
"subHeading": faker.sentence(),
"videoUrl": faker.url(),
@@ -357,8 +357,14 @@ async def main():
"categories": [faker.word() for _ in range(3)],
"isFeatured": random.choice([True, False]),
"isAvailable": True,
"isApproved": random.choice([True, False]),
"storeListingId": listing.id,
"submissionStatus": random.choice(
[
prisma.enums.SubmissionStatus.PENDING,
prisma.enums.SubmissionStatus.APPROVED,
prisma.enums.SubmissionStatus.REJECTED,
]
),
}
)
store_listing_versions.append(version)
@@ -387,10 +393,9 @@ async def main():
}
)
# Insert StoreListingSubmissions
print(f"Inserting {NUM_USERS} store listing submissions")
for listing in store_listings:
version = random.choice(store_listing_versions)
# Update StoreListingVersions with submission status (StoreListingSubmissions table no longer exists)
print(f"Updating {NUM_USERS} store listing versions with submission status")
for version in store_listing_versions:
reviewer = random.choice(users)
status: prisma.enums.SubmissionStatus = random.choice(
[
@@ -399,14 +404,14 @@ async def main():
prisma.enums.SubmissionStatus.REJECTED,
]
)
await db.storelistingsubmission.create(
await db.storelistingversion.update(
where={"id": version.id},
data={
"storeListingId": listing.id,
"storeListingVersionId": version.id,
"reviewerId": reviewer.id,
"Status": status,
"submissionStatus": status,
"Reviewer": {"connect": {"id": reviewer.id}},
"reviewComments": faker.text(),
}
"reviewedAt": datetime.now(),
},
)
# Insert APIKeys

View File

@@ -0,0 +1,123 @@
############
# Secrets
# YOU MUST CHANGE THESE BEFORE GOING INTO PRODUCTION
############
POSTGRES_PASSWORD=your-super-secret-and-long-postgres-password
JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
ANON_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJhbm9uIiwKICAgICJpc3MiOiAic3VwYWJhc2UtZGVtbyIsCiAgICAiaWF0IjogMTY0MTc2OTIwMCwKICAgICJleHAiOiAxNzk5NTM1NjAwCn0.dc_X5iR_VP_qT0zsiyj_I_OZ2T9FtRU2BBNWN8Bu4GE
SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
DASHBOARD_USERNAME=supabase
DASHBOARD_PASSWORD=this_password_is_insecure_and_should_be_updated
SECRET_KEY_BASE=UpNVntn3cDxHJpq99YMc1T1AQgQpc8kfYTuRgBiYa15BLrx8etQoXz3gZv1/u2oq
VAULT_ENC_KEY=your-encryption-key-32-chars-min
############
# Database - You can change these to any PostgreSQL database that has logical replication enabled.
############
POSTGRES_HOST=db
POSTGRES_DB=postgres
POSTGRES_PORT=5432
# default user is postgres
############
# Supavisor -- Database pooler
############
POOLER_PROXY_PORT_TRANSACTION=6543
POOLER_DEFAULT_POOL_SIZE=20
POOLER_MAX_CLIENT_CONN=100
POOLER_TENANT_ID=your-tenant-id
############
# API Proxy - Configuration for the Kong Reverse proxy.
############
KONG_HTTP_PORT=8000
KONG_HTTPS_PORT=8443
############
# API - Configuration for PostgREST.
############
PGRST_DB_SCHEMAS=public,storage,graphql_public
############
# Auth - Configuration for the GoTrue authentication server.
############
## General
SITE_URL=http://localhost:3000
ADDITIONAL_REDIRECT_URLS=
JWT_EXPIRY=3600
DISABLE_SIGNUP=false
API_EXTERNAL_URL=http://localhost:8000
## Mailer Config
MAILER_URLPATHS_CONFIRMATION="/auth/v1/verify"
MAILER_URLPATHS_INVITE="/auth/v1/verify"
MAILER_URLPATHS_RECOVERY="/auth/v1/verify"
MAILER_URLPATHS_EMAIL_CHANGE="/auth/v1/verify"
## Email auth
ENABLE_EMAIL_SIGNUP=true
ENABLE_EMAIL_AUTOCONFIRM=false
SMTP_ADMIN_EMAIL=admin@example.com
SMTP_HOST=supabase-mail
SMTP_PORT=2500
SMTP_USER=fake_mail_user
SMTP_PASS=fake_mail_password
SMTP_SENDER_NAME=fake_sender
ENABLE_ANONYMOUS_USERS=false
## Phone auth
ENABLE_PHONE_SIGNUP=true
ENABLE_PHONE_AUTOCONFIRM=true
############
# Studio - Configuration for the Dashboard
############
STUDIO_DEFAULT_ORGANIZATION=Default Organization
STUDIO_DEFAULT_PROJECT=Default Project
STUDIO_PORT=3000
# replace if you intend to use Studio outside of localhost
SUPABASE_PUBLIC_URL=http://localhost:8000
# Enable webp support
IMGPROXY_ENABLE_WEBP_DETECTION=true
# Add your OpenAI API key to enable SQL Editor Assistant
OPENAI_API_KEY=
############
# Functions - Configuration for Functions
############
# NOTE: VERIFY_JWT applies to all functions. Per-function VERIFY_JWT is not supported yet.
FUNCTIONS_VERIFY_JWT=false
############
# Logs - Configuration for Logflare
# Please refer to https://supabase.com/docs/reference/self-hosting-analytics/introduction
############
LOGFLARE_LOGGER_BACKEND_API_KEY=your-super-secret-and-long-logflare-key
# Change vector.toml sinks to reflect this change
LOGFLARE_API_KEY=your-super-secret-and-long-logflare-key
# Docker socket location - this value will differ depending on your OS
DOCKER_SOCKET_LOCATION=/var/run/docker.sock
# Google Cloud Project details
GOOGLE_PROJECT_ID=GOOGLE_PROJECT_ID
GOOGLE_PROJECT_NUMBER=GOOGLE_PROJECT_NUMBER

5
autogpt_platform/db/docker/.gitignore vendored Normal file
View File

@@ -0,0 +1,5 @@
volumes/db/data
volumes/storage
.env
test.http
docker-compose.override.yml

View File

@@ -0,0 +1,3 @@
# Supabase Docker
This is a minimal Docker Compose setup for self-hosting Supabase. Follow the steps [here](https://supabase.com/docs/guides/hosting/docker) to get started.

View File

@@ -0,0 +1,48 @@
create table profiles (
id uuid references auth.users not null,
updated_at timestamp with time zone,
username text unique,
avatar_url text,
website text,
primary key (id),
unique(username),
constraint username_length check (char_length(username) >= 3)
);
alter table profiles enable row level security;
create policy "Public profiles are viewable by the owner."
on profiles for select
using ( auth.uid() = id );
create policy "Users can insert their own profile."
on profiles for insert
with check ( auth.uid() = id );
create policy "Users can update own profile."
on profiles for update
using ( auth.uid() = id );
-- Set up Realtime
begin;
drop publication if exists supabase_realtime;
create publication supabase_realtime;
commit;
alter publication supabase_realtime add table profiles;
-- Set up Storage
insert into storage.buckets (id, name)
values ('avatars', 'avatars');
create policy "Avatar images are publicly accessible."
on storage.objects for select
using ( bucket_id = 'avatars' );
create policy "Anyone can upload an avatar."
on storage.objects for insert
with check ( bucket_id = 'avatars' );
create policy "Anyone can update an avatar."
on storage.objects for update
with check ( bucket_id = 'avatars' );

View File

@@ -0,0 +1,34 @@
version: "3.8"
services:
studio:
build:
context: ..
dockerfile: studio/Dockerfile
target: dev
ports:
- 8082:8082
mail:
container_name: supabase-mail
image: inbucket/inbucket:3.0.3
ports:
- '2500:2500' # SMTP
- '9000:9000' # web interface
- '1100:1100' # POP3
auth:
environment:
- GOTRUE_SMTP_USER=
- GOTRUE_SMTP_PASS=
meta:
ports:
- 5555:8080
db:
restart: 'no'
volumes:
# Always use a fresh database when developing
- /var/lib/postgresql/data
# Seed data should be inserted last (alphabetical order)
- ./dev/data.sql:/docker-entrypoint-initdb.d/seed.sql
storage:
volumes:
- /var/lib/storage

View File

@@ -0,0 +1,94 @@
services:
minio:
image: minio/minio
ports:
- '9000:9000'
- '9001:9001'
environment:
MINIO_ROOT_USER: supa-storage
MINIO_ROOT_PASSWORD: secret1234
command: server --console-address ":9001" /data
healthcheck:
test: [ "CMD", "curl", "-f", "http://minio:9000/minio/health/live" ]
interval: 2s
timeout: 10s
retries: 5
volumes:
- ./volumes/storage:/data:z
minio-createbucket:
image: minio/mc
depends_on:
minio:
condition: service_healthy
entrypoint: >
/bin/sh -c "
/usr/bin/mc alias set supa-minio http://minio:9000 supa-storage secret1234;
/usr/bin/mc mb supa-minio/stub;
exit 0;
"
storage:
container_name: supabase-storage
image: supabase/storage-api:v1.11.13
depends_on:
db:
# Disable this if you are using an external Postgres database
condition: service_healthy
rest:
condition: service_started
imgproxy:
condition: service_started
minio:
condition: service_healthy
healthcheck:
test:
[
"CMD",
"wget",
"--no-verbose",
"--tries=1",
"--spider",
"http://localhost:5000/status"
]
timeout: 5s
interval: 5s
retries: 3
restart: unless-stopped
environment:
ANON_KEY: ${ANON_KEY}
SERVICE_KEY: ${SERVICE_ROLE_KEY}
POSTGREST_URL: http://rest:3000
PGRST_JWT_SECRET: ${JWT_SECRET}
DATABASE_URL: postgres://supabase_storage_admin:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB}
FILE_SIZE_LIMIT: 52428800
STORAGE_BACKEND: s3
GLOBAL_S3_BUCKET: stub
GLOBAL_S3_ENDPOINT: http://minio:9000
GLOBAL_S3_PROTOCOL: http
GLOBAL_S3_FORCE_PATH_STYLE: true
AWS_ACCESS_KEY_ID: supa-storage
AWS_SECRET_ACCESS_KEY: secret1234
AWS_DEFAULT_REGION: stub
FILE_STORAGE_BACKEND_PATH: /var/lib/storage
TENANT_ID: stub
# TODO: https://github.com/supabase/storage-api/issues/55
REGION: stub
ENABLE_IMAGE_TRANSFORMATION: "true"
IMGPROXY_URL: http://imgproxy:5001
volumes:
- ./volumes/storage:/var/lib/storage:z
imgproxy:
container_name: supabase-imgproxy
image: darthsim/imgproxy:v3.8.0
healthcheck:
test: [ "CMD", "imgproxy", "health" ]
timeout: 5s
interval: 5s
retries: 3
environment:
IMGPROXY_BIND: ":5001"
IMGPROXY_USE_ETAG: "true"
IMGPROXY_ENABLE_WEBP_DETECTION: ${IMGPROXY_ENABLE_WEBP_DETECTION}

View File

@@ -0,0 +1,526 @@
# Usage
# Start: docker compose up
# With helpers: docker compose -f docker-compose.yml -f ./dev/docker-compose.dev.yml up
# Stop: docker compose down
# Destroy: docker compose -f docker-compose.yml -f ./dev/docker-compose.dev.yml down -v --remove-orphans
# Reset everything: ./reset.sh
name: supabase
services:
studio:
container_name: supabase-studio
image: supabase/studio:20250224-d10db0f
restart: unless-stopped
healthcheck:
test:
[
"CMD",
"node",
"-e",
"fetch('http://studio:3000/api/platform/profile').then((r) => {if (r.status !== 200) throw new Error(r.status)})"
]
timeout: 10s
interval: 5s
retries: 3
depends_on:
analytics:
condition: service_healthy
environment:
STUDIO_PG_META_URL: http://meta:8080
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
DEFAULT_ORGANIZATION_NAME: ${STUDIO_DEFAULT_ORGANIZATION}
DEFAULT_PROJECT_NAME: ${STUDIO_DEFAULT_PROJECT}
OPENAI_API_KEY: ${OPENAI_API_KEY:-}
SUPABASE_URL: http://kong:8000
SUPABASE_PUBLIC_URL: ${SUPABASE_PUBLIC_URL}
SUPABASE_ANON_KEY: ${ANON_KEY}
SUPABASE_SERVICE_KEY: ${SERVICE_ROLE_KEY}
AUTH_JWT_SECRET: ${JWT_SECRET}
LOGFLARE_API_KEY: ${LOGFLARE_API_KEY}
LOGFLARE_URL: http://analytics:4000
NEXT_PUBLIC_ENABLE_LOGS: true
# Comment to use Big Query backend for analytics
NEXT_ANALYTICS_BACKEND_PROVIDER: postgres
# Uncomment to use Big Query backend for analytics
# NEXT_ANALYTICS_BACKEND_PROVIDER: bigquery
kong:
container_name: supabase-kong
image: kong:2.8.1
restart: unless-stopped
ports:
- ${KONG_HTTP_PORT}:8000/tcp
- ${KONG_HTTPS_PORT}:8443/tcp
volumes:
# https://github.com/supabase/supabase/issues/12661
- ./volumes/api/kong.yml:/home/kong/temp.yml:ro
depends_on:
analytics:
condition: service_healthy
environment:
KONG_DATABASE: "off"
KONG_DECLARATIVE_CONFIG: /home/kong/kong.yml
# https://github.com/supabase/cli/issues/14
KONG_DNS_ORDER: LAST,A,CNAME
KONG_PLUGINS: request-transformer,cors,key-auth,acl,basic-auth
KONG_NGINX_PROXY_PROXY_BUFFER_SIZE: 160k
KONG_NGINX_PROXY_PROXY_BUFFERS: 64 160k
SUPABASE_ANON_KEY: ${ANON_KEY}
SUPABASE_SERVICE_KEY: ${SERVICE_ROLE_KEY}
DASHBOARD_USERNAME: ${DASHBOARD_USERNAME}
DASHBOARD_PASSWORD: ${DASHBOARD_PASSWORD}
# https://unix.stackexchange.com/a/294837
entrypoint: bash -c 'eval "echo \"$$(cat ~/temp.yml)\"" > ~/kong.yml && /docker-entrypoint.sh kong docker-start'
auth:
container_name: supabase-auth
image: supabase/gotrue:v2.170.0
restart: unless-stopped
healthcheck:
test:
[
"CMD",
"wget",
"--no-verbose",
"--tries=1",
"--spider",
"http://localhost:9999/health"
]
timeout: 5s
interval: 5s
retries: 3
depends_on:
db:
# Disable this if you are using an external Postgres database
condition: service_healthy
analytics:
condition: service_healthy
environment:
GOTRUE_API_HOST: 0.0.0.0
GOTRUE_API_PORT: 9999
API_EXTERNAL_URL: ${API_EXTERNAL_URL}
GOTRUE_DB_DRIVER: postgres
GOTRUE_DB_DATABASE_URL: postgres://supabase_auth_admin:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB}
GOTRUE_SITE_URL: ${SITE_URL}
GOTRUE_URI_ALLOW_LIST: ${ADDITIONAL_REDIRECT_URLS}
GOTRUE_DISABLE_SIGNUP: ${DISABLE_SIGNUP}
GOTRUE_JWT_ADMIN_ROLES: service_role
GOTRUE_JWT_AUD: authenticated
GOTRUE_JWT_DEFAULT_GROUP_NAME: authenticated
GOTRUE_JWT_EXP: ${JWT_EXPIRY}
GOTRUE_JWT_SECRET: ${JWT_SECRET}
GOTRUE_EXTERNAL_EMAIL_ENABLED: ${ENABLE_EMAIL_SIGNUP}
GOTRUE_EXTERNAL_ANONYMOUS_USERS_ENABLED: ${ENABLE_ANONYMOUS_USERS}
GOTRUE_MAILER_AUTOCONFIRM: ${ENABLE_EMAIL_AUTOCONFIRM}
# Uncomment to bypass nonce check in ID Token flow. Commonly set to true when using Google Sign In on mobile.
# GOTRUE_EXTERNAL_SKIP_NONCE_CHECK: true
# GOTRUE_MAILER_SECURE_EMAIL_CHANGE_ENABLED: true
# GOTRUE_SMTP_MAX_FREQUENCY: 1s
GOTRUE_SMTP_ADMIN_EMAIL: ${SMTP_ADMIN_EMAIL}
GOTRUE_SMTP_HOST: ${SMTP_HOST}
GOTRUE_SMTP_PORT: ${SMTP_PORT}
GOTRUE_SMTP_USER: ${SMTP_USER}
GOTRUE_SMTP_PASS: ${SMTP_PASS}
GOTRUE_SMTP_SENDER_NAME: ${SMTP_SENDER_NAME}
GOTRUE_MAILER_URLPATHS_INVITE: ${MAILER_URLPATHS_INVITE}
GOTRUE_MAILER_URLPATHS_CONFIRMATION: ${MAILER_URLPATHS_CONFIRMATION}
GOTRUE_MAILER_URLPATHS_RECOVERY: ${MAILER_URLPATHS_RECOVERY}
GOTRUE_MAILER_URLPATHS_EMAIL_CHANGE: ${MAILER_URLPATHS_EMAIL_CHANGE}
GOTRUE_EXTERNAL_PHONE_ENABLED: ${ENABLE_PHONE_SIGNUP}
GOTRUE_SMS_AUTOCONFIRM: ${ENABLE_PHONE_AUTOCONFIRM}
# Uncomment to enable custom access token hook. Please see: https://supabase.com/docs/guides/auth/auth-hooks for full list of hooks and additional details about custom_access_token_hook
# GOTRUE_HOOK_CUSTOM_ACCESS_TOKEN_ENABLED: "true"
# GOTRUE_HOOK_CUSTOM_ACCESS_TOKEN_URI: "pg-functions://postgres/public/custom_access_token_hook"
# GOTRUE_HOOK_CUSTOM_ACCESS_TOKEN_SECRETS: "<standard-base64-secret>"
# GOTRUE_HOOK_MFA_VERIFICATION_ATTEMPT_ENABLED: "true"
# GOTRUE_HOOK_MFA_VERIFICATION_ATTEMPT_URI: "pg-functions://postgres/public/mfa_verification_attempt"
# GOTRUE_HOOK_PASSWORD_VERIFICATION_ATTEMPT_ENABLED: "true"
# GOTRUE_HOOK_PASSWORD_VERIFICATION_ATTEMPT_URI: "pg-functions://postgres/public/password_verification_attempt"
# GOTRUE_HOOK_SEND_SMS_ENABLED: "false"
# GOTRUE_HOOK_SEND_SMS_URI: "pg-functions://postgres/public/custom_access_token_hook"
# GOTRUE_HOOK_SEND_SMS_SECRETS: "v1,whsec_VGhpcyBpcyBhbiBleGFtcGxlIG9mIGEgc2hvcnRlciBCYXNlNjQgc3RyaW5n"
# GOTRUE_HOOK_SEND_EMAIL_ENABLED: "false"
# GOTRUE_HOOK_SEND_EMAIL_URI: "http://host.docker.internal:54321/functions/v1/email_sender"
# GOTRUE_HOOK_SEND_EMAIL_SECRETS: "v1,whsec_VGhpcyBpcyBhbiBleGFtcGxlIG9mIGEgc2hvcnRlciBCYXNlNjQgc3RyaW5n"
rest:
container_name: supabase-rest
image: postgrest/postgrest:v12.2.8
restart: unless-stopped
depends_on:
db:
# Disable this if you are using an external Postgres database
condition: service_healthy
analytics:
condition: service_healthy
environment:
PGRST_DB_URI: postgres://authenticator:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB}
PGRST_DB_SCHEMAS: ${PGRST_DB_SCHEMAS}
PGRST_DB_ANON_ROLE: anon
PGRST_JWT_SECRET: ${JWT_SECRET}
PGRST_DB_USE_LEGACY_GUCS: "false"
PGRST_APP_SETTINGS_JWT_SECRET: ${JWT_SECRET}
PGRST_APP_SETTINGS_JWT_EXP: ${JWT_EXPIRY}
command:
[
"postgrest"
]
realtime:
# This container name looks inconsistent but is correct because realtime constructs tenant id by parsing the subdomain
container_name: realtime-dev.supabase-realtime
image: supabase/realtime:v2.34.40
restart: unless-stopped
depends_on:
db:
# Disable this if you are using an external Postgres database
condition: service_healthy
analytics:
condition: service_healthy
healthcheck:
test:
[
"CMD",
"curl",
"-sSfL",
"--head",
"-o",
"/dev/null",
"-H",
"Authorization: Bearer ${ANON_KEY}",
"http://localhost:4000/api/tenants/realtime-dev/health"
]
timeout: 5s
interval: 5s
retries: 3
environment:
PORT: 4000
DB_HOST: ${POSTGRES_HOST}
DB_PORT: ${POSTGRES_PORT}
DB_USER: supabase_admin
DB_PASSWORD: ${POSTGRES_PASSWORD}
DB_NAME: ${POSTGRES_DB}
DB_AFTER_CONNECT_QUERY: 'SET search_path TO _realtime'
DB_ENC_KEY: supabaserealtime
API_JWT_SECRET: ${JWT_SECRET}
SECRET_KEY_BASE: ${SECRET_KEY_BASE}
ERL_AFLAGS: -proto_dist inet_tcp
DNS_NODES: "''"
RLIMIT_NOFILE: "10000"
APP_NAME: realtime
SEED_SELF_HOST: true
RUN_JANITOR: true
# To use S3 backed storage: docker compose -f docker-compose.yml -f docker-compose.s3.yml up
storage:
container_name: supabase-storage
image: supabase/storage-api:v1.19.3
restart: unless-stopped
volumes:
- ./volumes/storage:/var/lib/storage:z
healthcheck:
test:
[
"CMD",
"wget",
"--no-verbose",
"--tries=1",
"--spider",
"http://storage:5000/status"
]
timeout: 5s
interval: 5s
retries: 3
depends_on:
db:
# Disable this if you are using an external Postgres database
condition: service_healthy
rest:
condition: service_started
imgproxy:
condition: service_started
environment:
ANON_KEY: ${ANON_KEY}
SERVICE_KEY: ${SERVICE_ROLE_KEY}
POSTGREST_URL: http://rest:3000
PGRST_JWT_SECRET: ${JWT_SECRET}
DATABASE_URL: postgres://supabase_storage_admin:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB}
FILE_SIZE_LIMIT: 52428800
STORAGE_BACKEND: file
FILE_STORAGE_BACKEND_PATH: /var/lib/storage
TENANT_ID: stub
# TODO: https://github.com/supabase/storage-api/issues/55
REGION: stub
GLOBAL_S3_BUCKET: stub
ENABLE_IMAGE_TRANSFORMATION: "true"
IMGPROXY_URL: http://imgproxy:5001
imgproxy:
container_name: supabase-imgproxy
image: darthsim/imgproxy:v3.8.0
restart: unless-stopped
volumes:
- ./volumes/storage:/var/lib/storage:z
healthcheck:
test:
[
"CMD",
"imgproxy",
"health"
]
timeout: 5s
interval: 5s
retries: 3
environment:
IMGPROXY_BIND: ":5001"
IMGPROXY_LOCAL_FILESYSTEM_ROOT: /
IMGPROXY_USE_ETAG: "true"
IMGPROXY_ENABLE_WEBP_DETECTION: ${IMGPROXY_ENABLE_WEBP_DETECTION}
meta:
container_name: supabase-meta
image: supabase/postgres-meta:v0.86.1
restart: unless-stopped
depends_on:
db:
# Disable this if you are using an external Postgres database
condition: service_healthy
analytics:
condition: service_healthy
environment:
PG_META_PORT: 8080
PG_META_DB_HOST: ${POSTGRES_HOST}
PG_META_DB_PORT: ${POSTGRES_PORT}
PG_META_DB_NAME: ${POSTGRES_DB}
PG_META_DB_USER: supabase_admin
PG_META_DB_PASSWORD: ${POSTGRES_PASSWORD}
functions:
container_name: supabase-edge-functions
image: supabase/edge-runtime:v1.67.2
restart: unless-stopped
volumes:
- ./volumes/functions:/home/deno/functions:Z
depends_on:
analytics:
condition: service_healthy
environment:
JWT_SECRET: ${JWT_SECRET}
SUPABASE_URL: http://kong:8000
SUPABASE_ANON_KEY: ${ANON_KEY}
SUPABASE_SERVICE_ROLE_KEY: ${SERVICE_ROLE_KEY}
SUPABASE_DB_URL: postgresql://postgres:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB}
# TODO: Allow configuring VERIFY_JWT per function. This PR might help: https://github.com/supabase/cli/pull/786
VERIFY_JWT: "${FUNCTIONS_VERIFY_JWT}"
command:
[
"start",
"--main-service",
"/home/deno/functions/main"
]
analytics:
container_name: supabase-analytics
image: supabase/logflare:1.12.5
restart: unless-stopped
ports:
- 4000:4000
# Uncomment to use Big Query backend for analytics
# volumes:
# - type: bind
# source: ${PWD}/gcloud.json
# target: /opt/app/rel/logflare/bin/gcloud.json
# read_only: true
healthcheck:
test:
[
"CMD",
"curl",
"http://localhost:4000/health"
]
timeout: 5s
interval: 5s
retries: 10
depends_on:
db:
# Disable this if you are using an external Postgres database
condition: service_healthy
environment:
LOGFLARE_NODE_HOST: 127.0.0.1
DB_USERNAME: supabase_admin
DB_DATABASE: _supabase
DB_HOSTNAME: ${POSTGRES_HOST}
DB_PORT: ${POSTGRES_PORT}
DB_PASSWORD: ${POSTGRES_PASSWORD}
DB_SCHEMA: _analytics
LOGFLARE_API_KEY: ${LOGFLARE_API_KEY}
LOGFLARE_SINGLE_TENANT: true
LOGFLARE_SUPABASE_MODE: true
LOGFLARE_MIN_CLUSTER_SIZE: 1
# Comment variables to use Big Query backend for analytics
POSTGRES_BACKEND_URL: postgresql://supabase_admin:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/_supabase
POSTGRES_BACKEND_SCHEMA: _analytics
LOGFLARE_FEATURE_FLAG_OVERRIDE: multibackend=true
# Uncomment to use Big Query backend for analytics
# GOOGLE_PROJECT_ID: ${GOOGLE_PROJECT_ID}
# GOOGLE_PROJECT_NUMBER: ${GOOGLE_PROJECT_NUMBER}
# Comment out everything below this point if you are using an external Postgres database
db:
container_name: supabase-db
image: supabase/postgres:15.8.1.049
restart: unless-stopped
volumes:
- ./volumes/db/realtime.sql:/docker-entrypoint-initdb.d/migrations/99-realtime.sql:Z
# Must be superuser to create event trigger
- ./volumes/db/webhooks.sql:/docker-entrypoint-initdb.d/init-scripts/98-webhooks.sql:Z
# Must be superuser to alter reserved role
- ./volumes/db/roles.sql:/docker-entrypoint-initdb.d/init-scripts/99-roles.sql:Z
# Initialize the database settings with JWT_SECRET and JWT_EXP
- ./volumes/db/jwt.sql:/docker-entrypoint-initdb.d/init-scripts/99-jwt.sql:Z
# PGDATA directory is persisted between restarts
- ./volumes/db/data:/var/lib/postgresql/data:Z
# Changes required for internal supabase data such as _analytics
- ./volumes/db/_supabase.sql:/docker-entrypoint-initdb.d/migrations/97-_supabase.sql:Z
# Changes required for Analytics support
- ./volumes/db/logs.sql:/docker-entrypoint-initdb.d/migrations/99-logs.sql:Z
# Changes required for Pooler support
- ./volumes/db/pooler.sql:/docker-entrypoint-initdb.d/migrations/99-pooler.sql:Z
# Use named volume to persist pgsodium decryption key between restarts
- supabase-config:/etc/postgresql-custom
healthcheck:
test:
[
"CMD",
"pg_isready",
"-U",
"postgres",
"-h",
"localhost"
]
interval: 5s
timeout: 5s
retries: 10
depends_on:
vector:
condition: service_healthy
environment:
POSTGRES_HOST: /var/run/postgresql
PGPORT: ${POSTGRES_PORT}
POSTGRES_PORT: ${POSTGRES_PORT}
PGPASSWORD: ${POSTGRES_PASSWORD}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
PGDATABASE: ${POSTGRES_DB}
POSTGRES_DB: ${POSTGRES_DB}
JWT_SECRET: ${JWT_SECRET}
JWT_EXP: ${JWT_EXPIRY}
command:
[
"postgres",
"-c",
"config_file=/etc/postgresql/postgresql.conf",
"-c",
"log_min_messages=fatal" # prevents Realtime polling queries from appearing in logs
]
vector:
container_name: supabase-vector
image: timberio/vector:0.28.1-alpine
restart: unless-stopped
volumes:
- ./volumes/logs/vector.yml:/etc/vector/vector.yml:ro
- ${DOCKER_SOCKET_LOCATION}:/var/run/docker.sock:ro
healthcheck:
test:
[
"CMD",
"wget",
"--no-verbose",
"--tries=1",
"--spider",
"http://vector:9001/health"
]
timeout: 5s
interval: 5s
retries: 3
environment:
LOGFLARE_API_KEY: ${LOGFLARE_API_KEY}
command:
[
"--config",
"/etc/vector/vector.yml"
]
# Update the DATABASE_URL if you are using an external Postgres database
supavisor:
container_name: supabase-pooler
image: supabase/supavisor:2.4.12
restart: unless-stopped
ports:
- ${POSTGRES_PORT}:5432
- ${POOLER_PROXY_PORT_TRANSACTION}:6543
volumes:
- ./volumes/pooler/pooler.exs:/etc/pooler/pooler.exs:ro
healthcheck:
test:
[
"CMD",
"curl",
"-sSfL",
"--head",
"-o",
"/dev/null",
"http://127.0.0.1:4000/api/health"
]
interval: 10s
timeout: 5s
retries: 5
depends_on:
db:
condition: service_healthy
analytics:
condition: service_healthy
environment:
PORT: 4000
POSTGRES_PORT: ${POSTGRES_PORT}
POSTGRES_DB: ${POSTGRES_DB}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
DATABASE_URL: ecto://supabase_admin:${POSTGRES_PASSWORD}@db:${POSTGRES_PORT}/_supabase
CLUSTER_POSTGRES: true
SECRET_KEY_BASE: ${SECRET_KEY_BASE}
VAULT_ENC_KEY: ${VAULT_ENC_KEY}
API_JWT_SECRET: ${JWT_SECRET}
METRICS_JWT_SECRET: ${JWT_SECRET}
REGION: local
ERL_AFLAGS: -proto_dist inet_tcp
POOLER_TENANT_ID: ${POOLER_TENANT_ID}
POOLER_DEFAULT_POOL_SIZE: ${POOLER_DEFAULT_POOL_SIZE}
POOLER_MAX_CLIENT_CONN: ${POOLER_MAX_CLIENT_CONN}
POOLER_POOL_MODE: transaction
command:
[
"/bin/sh",
"-c",
"/app/bin/migrate && /app/bin/supavisor eval \"$$(cat /etc/pooler/pooler.exs)\" && /app/bin/server"
]
volumes:
supabase-config:

View File

@@ -0,0 +1,44 @@
#!/bin/bash
echo "WARNING: This will remove all containers and container data, and will reset the .env file. This action cannot be undone!"
read -p "Are you sure you want to proceed? (y/N) " -n 1 -r
echo # Move to a new line
if [[ ! $REPLY =~ ^[Yy]$ ]]
then
echo "Operation cancelled."
exit 1
fi
echo "Stopping and removing all containers..."
docker compose -f docker-compose.yml -f ./dev/docker-compose.dev.yml down -v --remove-orphans
echo "Cleaning up bind-mounted directories..."
BIND_MOUNTS=(
"./volumes/db/data"
)
for DIR in "${BIND_MOUNTS[@]}"; do
if [ -d "$DIR" ]; then
echo "Deleting $DIR..."
rm -rf "$DIR"
else
echo "Directory $DIR does not exist. Skipping bind mount deletion step..."
fi
done
echo "Resetting .env file..."
if [ -f ".env" ]; then
echo "Removing existing .env file..."
rm -f .env
else
echo "No .env file found. Skipping .env removal step..."
fi
if [ -f ".env.example" ]; then
echo "Copying .env.example to .env..."
cp .env.example .env
else
echo ".env.example file not found. Skipping .env reset step..."
fi
echo "Cleanup complete!"

View File

@@ -0,0 +1,241 @@
_format_version: '2.1'
_transform: true
###
### Consumers / Users
###
consumers:
- username: DASHBOARD
- username: anon
keyauth_credentials:
- key: $SUPABASE_ANON_KEY
- username: service_role
keyauth_credentials:
- key: $SUPABASE_SERVICE_KEY
###
### Access Control List
###
acls:
- consumer: anon
group: anon
- consumer: service_role
group: admin
###
### Dashboard credentials
###
basicauth_credentials:
- consumer: DASHBOARD
username: $DASHBOARD_USERNAME
password: $DASHBOARD_PASSWORD
###
### API Routes
###
services:
## Open Auth routes
- name: auth-v1-open
url: http://auth:9999/verify
routes:
- name: auth-v1-open
strip_path: true
paths:
- /auth/v1/verify
plugins:
- name: cors
- name: auth-v1-open-callback
url: http://auth:9999/callback
routes:
- name: auth-v1-open-callback
strip_path: true
paths:
- /auth/v1/callback
plugins:
- name: cors
- name: auth-v1-open-authorize
url: http://auth:9999/authorize
routes:
- name: auth-v1-open-authorize
strip_path: true
paths:
- /auth/v1/authorize
plugins:
- name: cors
## Secure Auth routes
- name: auth-v1
_comment: 'GoTrue: /auth/v1/* -> http://auth:9999/*'
url: http://auth:9999/
routes:
- name: auth-v1-all
strip_path: true
paths:
- /auth/v1/
plugins:
- name: cors
- name: key-auth
config:
hide_credentials: false
- name: acl
config:
hide_groups_header: true
allow:
- admin
- anon
## Secure REST routes
- name: rest-v1
_comment: 'PostgREST: /rest/v1/* -> http://rest:3000/*'
url: http://rest:3000/
routes:
- name: rest-v1-all
strip_path: true
paths:
- /rest/v1/
plugins:
- name: cors
- name: key-auth
config:
hide_credentials: true
- name: acl
config:
hide_groups_header: true
allow:
- admin
- anon
## Secure GraphQL routes
- name: graphql-v1
_comment: 'PostgREST: /graphql/v1/* -> http://rest:3000/rpc/graphql'
url: http://rest:3000/rpc/graphql
routes:
- name: graphql-v1-all
strip_path: true
paths:
- /graphql/v1
plugins:
- name: cors
- name: key-auth
config:
hide_credentials: true
- name: request-transformer
config:
add:
headers:
- Content-Profile:graphql_public
- name: acl
config:
hide_groups_header: true
allow:
- admin
- anon
## Secure Realtime routes
- name: realtime-v1-ws
_comment: 'Realtime: /realtime/v1/* -> ws://realtime:4000/socket/*'
url: http://realtime-dev.supabase-realtime:4000/socket
protocol: ws
routes:
- name: realtime-v1-ws
strip_path: true
paths:
- /realtime/v1/
plugins:
- name: cors
- name: key-auth
config:
hide_credentials: false
- name: acl
config:
hide_groups_header: true
allow:
- admin
- anon
- name: realtime-v1-rest
_comment: 'Realtime: /realtime/v1/* -> ws://realtime:4000/socket/*'
url: http://realtime-dev.supabase-realtime:4000/api
protocol: http
routes:
- name: realtime-v1-rest
strip_path: true
paths:
- /realtime/v1/api
plugins:
- name: cors
- name: key-auth
config:
hide_credentials: false
- name: acl
config:
hide_groups_header: true
allow:
- admin
- anon
## Storage routes: the storage server manages its own auth
- name: storage-v1
_comment: 'Storage: /storage/v1/* -> http://storage:5000/*'
url: http://storage:5000/
routes:
- name: storage-v1-all
strip_path: true
paths:
- /storage/v1/
plugins:
- name: cors
## Edge Functions routes
- name: functions-v1
_comment: 'Edge Functions: /functions/v1/* -> http://functions:9000/*'
url: http://functions:9000/
routes:
- name: functions-v1-all
strip_path: true
paths:
- /functions/v1/
plugins:
- name: cors
## Analytics routes
- name: analytics-v1
_comment: 'Analytics: /analytics/v1/* -> http://logflare:4000/*'
url: http://analytics:4000/
routes:
- name: analytics-v1-all
strip_path: true
paths:
- /analytics/v1/
## Secure Database routes
- name: meta
_comment: 'pg-meta: /pg/* -> http://pg-meta:8080/*'
url: http://meta:8080/
routes:
- name: meta-all
strip_path: true
paths:
- /pg/
plugins:
- name: key-auth
config:
hide_credentials: false
- name: acl
config:
hide_groups_header: true
allow:
- admin
## Protected Dashboard - catch all remaining routes
- name: dashboard
_comment: 'Studio: /* -> http://studio:3000/*'
url: http://studio:3000/
routes:
- name: dashboard-all
strip_path: true
paths:
- /
plugins:
- name: cors
- name: basic-auth
config:
hide_credentials: true

View File

@@ -0,0 +1,3 @@
\set pguser `echo "$POSTGRES_USER"`
CREATE DATABASE _supabase WITH OWNER :pguser;

View File

@@ -0,0 +1,5 @@
\set jwt_secret `echo "$JWT_SECRET"`
\set jwt_exp `echo "$JWT_EXP"`
ALTER DATABASE postgres SET "app.settings.jwt_secret" TO :'jwt_secret';
ALTER DATABASE postgres SET "app.settings.jwt_exp" TO :'jwt_exp';

View File

@@ -0,0 +1,6 @@
\set pguser `echo "$POSTGRES_USER"`
\c _supabase
create schema if not exists _analytics;
alter schema _analytics owner to :pguser;
\c postgres

View File

@@ -0,0 +1,6 @@
\set pguser `echo "$POSTGRES_USER"`
\c _supabase
create schema if not exists _supavisor;
alter schema _supavisor owner to :pguser;
\c postgres

View File

@@ -0,0 +1,4 @@
\set pguser `echo "$POSTGRES_USER"`
create schema if not exists _realtime;
alter schema _realtime owner to :pguser;

View File

@@ -0,0 +1,8 @@
-- NOTE: change to your own passwords for production environments
\set pgpass `echo "$POSTGRES_PASSWORD"`
ALTER USER authenticator WITH PASSWORD :'pgpass';
ALTER USER pgbouncer WITH PASSWORD :'pgpass';
ALTER USER supabase_auth_admin WITH PASSWORD :'pgpass';
ALTER USER supabase_functions_admin WITH PASSWORD :'pgpass';
ALTER USER supabase_storage_admin WITH PASSWORD :'pgpass';

View File

@@ -0,0 +1,208 @@
BEGIN;
-- Create pg_net extension
CREATE EXTENSION IF NOT EXISTS pg_net SCHEMA extensions;
-- Create supabase_functions schema
CREATE SCHEMA supabase_functions AUTHORIZATION supabase_admin;
GRANT USAGE ON SCHEMA supabase_functions TO postgres, anon, authenticated, service_role;
ALTER DEFAULT PRIVILEGES IN SCHEMA supabase_functions GRANT ALL ON TABLES TO postgres, anon, authenticated, service_role;
ALTER DEFAULT PRIVILEGES IN SCHEMA supabase_functions GRANT ALL ON FUNCTIONS TO postgres, anon, authenticated, service_role;
ALTER DEFAULT PRIVILEGES IN SCHEMA supabase_functions GRANT ALL ON SEQUENCES TO postgres, anon, authenticated, service_role;
-- supabase_functions.migrations definition
CREATE TABLE supabase_functions.migrations (
version text PRIMARY KEY,
inserted_at timestamptz NOT NULL DEFAULT NOW()
);
-- Initial supabase_functions migration
INSERT INTO supabase_functions.migrations (version) VALUES ('initial');
-- supabase_functions.hooks definition
CREATE TABLE supabase_functions.hooks (
id bigserial PRIMARY KEY,
hook_table_id integer NOT NULL,
hook_name text NOT NULL,
created_at timestamptz NOT NULL DEFAULT NOW(),
request_id bigint
);
CREATE INDEX supabase_functions_hooks_request_id_idx ON supabase_functions.hooks USING btree (request_id);
CREATE INDEX supabase_functions_hooks_h_table_id_h_name_idx ON supabase_functions.hooks USING btree (hook_table_id, hook_name);
COMMENT ON TABLE supabase_functions.hooks IS 'Supabase Functions Hooks: Audit trail for triggered hooks.';
CREATE FUNCTION supabase_functions.http_request()
RETURNS trigger
LANGUAGE plpgsql
AS $function$
DECLARE
request_id bigint;
payload jsonb;
url text := TG_ARGV[0]::text;
method text := TG_ARGV[1]::text;
headers jsonb DEFAULT '{}'::jsonb;
params jsonb DEFAULT '{}'::jsonb;
timeout_ms integer DEFAULT 1000;
BEGIN
IF url IS NULL OR url = 'null' THEN
RAISE EXCEPTION 'url argument is missing';
END IF;
IF method IS NULL OR method = 'null' THEN
RAISE EXCEPTION 'method argument is missing';
END IF;
IF TG_ARGV[2] IS NULL OR TG_ARGV[2] = 'null' THEN
headers = '{"Content-Type": "application/json"}'::jsonb;
ELSE
headers = TG_ARGV[2]::jsonb;
END IF;
IF TG_ARGV[3] IS NULL OR TG_ARGV[3] = 'null' THEN
params = '{}'::jsonb;
ELSE
params = TG_ARGV[3]::jsonb;
END IF;
IF TG_ARGV[4] IS NULL OR TG_ARGV[4] = 'null' THEN
timeout_ms = 1000;
ELSE
timeout_ms = TG_ARGV[4]::integer;
END IF;
CASE
WHEN method = 'GET' THEN
SELECT http_get INTO request_id FROM net.http_get(
url,
params,
headers,
timeout_ms
);
WHEN method = 'POST' THEN
payload = jsonb_build_object(
'old_record', OLD,
'record', NEW,
'type', TG_OP,
'table', TG_TABLE_NAME,
'schema', TG_TABLE_SCHEMA
);
SELECT http_post INTO request_id FROM net.http_post(
url,
payload,
params,
headers,
timeout_ms
);
ELSE
RAISE EXCEPTION 'method argument % is invalid', method;
END CASE;
INSERT INTO supabase_functions.hooks
(hook_table_id, hook_name, request_id)
VALUES
(TG_RELID, TG_NAME, request_id);
RETURN NEW;
END
$function$;
-- Supabase super admin
DO
$$
BEGIN
IF NOT EXISTS (
SELECT 1
FROM pg_roles
WHERE rolname = 'supabase_functions_admin'
)
THEN
CREATE USER supabase_functions_admin NOINHERIT CREATEROLE LOGIN NOREPLICATION;
END IF;
END
$$;
GRANT ALL PRIVILEGES ON SCHEMA supabase_functions TO supabase_functions_admin;
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA supabase_functions TO supabase_functions_admin;
GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA supabase_functions TO supabase_functions_admin;
ALTER USER supabase_functions_admin SET search_path = "supabase_functions";
ALTER table "supabase_functions".migrations OWNER TO supabase_functions_admin;
ALTER table "supabase_functions".hooks OWNER TO supabase_functions_admin;
ALTER function "supabase_functions".http_request() OWNER TO supabase_functions_admin;
GRANT supabase_functions_admin TO postgres;
-- Remove unused supabase_pg_net_admin role
DO
$$
BEGIN
IF EXISTS (
SELECT 1
FROM pg_roles
WHERE rolname = 'supabase_pg_net_admin'
)
THEN
REASSIGN OWNED BY supabase_pg_net_admin TO supabase_admin;
DROP OWNED BY supabase_pg_net_admin;
DROP ROLE supabase_pg_net_admin;
END IF;
END
$$;
-- pg_net grants when extension is already enabled
DO
$$
BEGIN
IF EXISTS (
SELECT 1
FROM pg_extension
WHERE extname = 'pg_net'
)
THEN
GRANT USAGE ON SCHEMA net TO supabase_functions_admin, postgres, anon, authenticated, service_role;
ALTER function net.http_get(url text, params jsonb, headers jsonb, timeout_milliseconds integer) SECURITY DEFINER;
ALTER function net.http_post(url text, body jsonb, params jsonb, headers jsonb, timeout_milliseconds integer) SECURITY DEFINER;
ALTER function net.http_get(url text, params jsonb, headers jsonb, timeout_milliseconds integer) SET search_path = net;
ALTER function net.http_post(url text, body jsonb, params jsonb, headers jsonb, timeout_milliseconds integer) SET search_path = net;
REVOKE ALL ON FUNCTION net.http_get(url text, params jsonb, headers jsonb, timeout_milliseconds integer) FROM PUBLIC;
REVOKE ALL ON FUNCTION net.http_post(url text, body jsonb, params jsonb, headers jsonb, timeout_milliseconds integer) FROM PUBLIC;
GRANT EXECUTE ON FUNCTION net.http_get(url text, params jsonb, headers jsonb, timeout_milliseconds integer) TO supabase_functions_admin, postgres, anon, authenticated, service_role;
GRANT EXECUTE ON FUNCTION net.http_post(url text, body jsonb, params jsonb, headers jsonb, timeout_milliseconds integer) TO supabase_functions_admin, postgres, anon, authenticated, service_role;
END IF;
END
$$;
-- Event trigger for pg_net
CREATE OR REPLACE FUNCTION extensions.grant_pg_net_access()
RETURNS event_trigger
LANGUAGE plpgsql
AS $$
BEGIN
IF EXISTS (
SELECT 1
FROM pg_event_trigger_ddl_commands() AS ev
JOIN pg_extension AS ext
ON ev.objid = ext.oid
WHERE ext.extname = 'pg_net'
)
THEN
GRANT USAGE ON SCHEMA net TO supabase_functions_admin, postgres, anon, authenticated, service_role;
ALTER function net.http_get(url text, params jsonb, headers jsonb, timeout_milliseconds integer) SECURITY DEFINER;
ALTER function net.http_post(url text, body jsonb, params jsonb, headers jsonb, timeout_milliseconds integer) SECURITY DEFINER;
ALTER function net.http_get(url text, params jsonb, headers jsonb, timeout_milliseconds integer) SET search_path = net;
ALTER function net.http_post(url text, body jsonb, params jsonb, headers jsonb, timeout_milliseconds integer) SET search_path = net;
REVOKE ALL ON FUNCTION net.http_get(url text, params jsonb, headers jsonb, timeout_milliseconds integer) FROM PUBLIC;
REVOKE ALL ON FUNCTION net.http_post(url text, body jsonb, params jsonb, headers jsonb, timeout_milliseconds integer) FROM PUBLIC;
GRANT EXECUTE ON FUNCTION net.http_get(url text, params jsonb, headers jsonb, timeout_milliseconds integer) TO supabase_functions_admin, postgres, anon, authenticated, service_role;
GRANT EXECUTE ON FUNCTION net.http_post(url text, body jsonb, params jsonb, headers jsonb, timeout_milliseconds integer) TO supabase_functions_admin, postgres, anon, authenticated, service_role;
END IF;
END;
$$;
COMMENT ON FUNCTION extensions.grant_pg_net_access IS 'Grants access to pg_net';
DO
$$
BEGIN
IF NOT EXISTS (
SELECT 1
FROM pg_event_trigger
WHERE evtname = 'issue_pg_net_access'
) THEN
CREATE EVENT TRIGGER issue_pg_net_access ON ddl_command_end WHEN TAG IN ('CREATE EXTENSION')
EXECUTE PROCEDURE extensions.grant_pg_net_access();
END IF;
END
$$;
INSERT INTO supabase_functions.migrations (version) VALUES ('20210809183423_update_grants');
ALTER function supabase_functions.http_request() SECURITY DEFINER;
ALTER function supabase_functions.http_request() SET search_path = supabase_functions;
REVOKE ALL ON FUNCTION supabase_functions.http_request() FROM PUBLIC;
GRANT EXECUTE ON FUNCTION supabase_functions.http_request() TO postgres, anon, authenticated, service_role;
COMMIT;

View File

@@ -0,0 +1,16 @@
// Follow this setup guide to integrate the Deno language server with your editor:
// https://deno.land/manual/getting_started/setup_your_environment
// This enables autocomplete, go to definition, etc.
import { serve } from "https://deno.land/std@0.177.1/http/server.ts"
serve(async () => {
return new Response(
`"Hello from Edge Functions!"`,
{ headers: { "Content-Type": "application/json" } },
)
})
// To invoke:
// curl 'http://localhost:<KONG_HTTP_PORT>/functions/v1/hello' \
// --header 'Authorization: Bearer <anon/service_role API key>'

View File

@@ -0,0 +1,94 @@
import { serve } from 'https://deno.land/std@0.131.0/http/server.ts'
import * as jose from 'https://deno.land/x/jose@v4.14.4/index.ts'
console.log('main function started')
const JWT_SECRET = Deno.env.get('JWT_SECRET')
const VERIFY_JWT = Deno.env.get('VERIFY_JWT') === 'true'
function getAuthToken(req: Request) {
const authHeader = req.headers.get('authorization')
if (!authHeader) {
throw new Error('Missing authorization header')
}
const [bearer, token] = authHeader.split(' ')
if (bearer !== 'Bearer') {
throw new Error(`Auth header is not 'Bearer {token}'`)
}
return token
}
async function verifyJWT(jwt: string): Promise<boolean> {
const encoder = new TextEncoder()
const secretKey = encoder.encode(JWT_SECRET)
try {
await jose.jwtVerify(jwt, secretKey)
} catch (err) {
console.error(err)
return false
}
return true
}
serve(async (req: Request) => {
if (req.method !== 'OPTIONS' && VERIFY_JWT) {
try {
const token = getAuthToken(req)
const isValidJWT = await verifyJWT(token)
if (!isValidJWT) {
return new Response(JSON.stringify({ msg: 'Invalid JWT' }), {
status: 401,
headers: { 'Content-Type': 'application/json' },
})
}
} catch (e) {
console.error(e)
return new Response(JSON.stringify({ msg: e.toString() }), {
status: 401,
headers: { 'Content-Type': 'application/json' },
})
}
}
const url = new URL(req.url)
const { pathname } = url
const path_parts = pathname.split('/')
const service_name = path_parts[1]
if (!service_name || service_name === '') {
const error = { msg: 'missing function name in request' }
return new Response(JSON.stringify(error), {
status: 400,
headers: { 'Content-Type': 'application/json' },
})
}
const servicePath = `/home/deno/functions/${service_name}`
console.error(`serving the request with ${servicePath}`)
const memoryLimitMb = 150
const workerTimeoutMs = 1 * 60 * 1000
const noModuleCache = false
const importMapPath = null
const envVarsObj = Deno.env.toObject()
const envVars = Object.keys(envVarsObj).map((k) => [k, envVarsObj[k]])
try {
const worker = await EdgeRuntime.userWorkers.create({
servicePath,
memoryLimitMb,
workerTimeoutMs,
noModuleCache,
importMapPath,
envVars,
})
return await worker.fetch(req)
} catch (e) {
const error = { msg: e.toString() }
return new Response(JSON.stringify(error), {
status: 500,
headers: { 'Content-Type': 'application/json' },
})
}
})

View File

@@ -0,0 +1,232 @@
api:
enabled: true
address: 0.0.0.0:9001
sources:
docker_host:
type: docker_logs
exclude_containers:
- supabase-vector
transforms:
project_logs:
type: remap
inputs:
- docker_host
source: |-
.project = "default"
.event_message = del(.message)
.appname = del(.container_name)
del(.container_created_at)
del(.container_id)
del(.source_type)
del(.stream)
del(.label)
del(.image)
del(.host)
del(.stream)
router:
type: route
inputs:
- project_logs
route:
kong: '.appname == "supabase-kong"'
auth: '.appname == "supabase-auth"'
rest: '.appname == "supabase-rest"'
realtime: '.appname == "supabase-realtime"'
storage: '.appname == "supabase-storage"'
functions: '.appname == "supabase-functions"'
db: '.appname == "supabase-db"'
# Ignores non nginx errors since they are related with kong booting up
kong_logs:
type: remap
inputs:
- router.kong
source: |-
req, err = parse_nginx_log(.event_message, "combined")
if err == null {
.timestamp = req.timestamp
.metadata.request.headers.referer = req.referer
.metadata.request.headers.user_agent = req.agent
.metadata.request.headers.cf_connecting_ip = req.client
.metadata.request.method = req.method
.metadata.request.path = req.path
.metadata.request.protocol = req.protocol
.metadata.response.status_code = req.status
}
if err != null {
abort
}
# Ignores non nginx errors since they are related with kong booting up
kong_err:
type: remap
inputs:
- router.kong
source: |-
.metadata.request.method = "GET"
.metadata.response.status_code = 200
parsed, err = parse_nginx_log(.event_message, "error")
if err == null {
.timestamp = parsed.timestamp
.severity = parsed.severity
.metadata.request.host = parsed.host
.metadata.request.headers.cf_connecting_ip = parsed.client
url, err = split(parsed.request, " ")
if err == null {
.metadata.request.method = url[0]
.metadata.request.path = url[1]
.metadata.request.protocol = url[2]
}
}
if err != null {
abort
}
# Gotrue logs are structured json strings which frontend parses directly. But we keep metadata for consistency.
auth_logs:
type: remap
inputs:
- router.auth
source: |-
parsed, err = parse_json(.event_message)
if err == null {
.metadata.timestamp = parsed.time
.metadata = merge!(.metadata, parsed)
}
# PostgREST logs are structured so we separate timestamp from message using regex
rest_logs:
type: remap
inputs:
- router.rest
source: |-
parsed, err = parse_regex(.event_message, r'^(?P<time>.*): (?P<msg>.*)$')
if err == null {
.event_message = parsed.msg
.timestamp = to_timestamp!(parsed.time)
.metadata.host = .project
}
# Realtime logs are structured so we parse the severity level using regex (ignore time because it has no date)
realtime_logs:
type: remap
inputs:
- router.realtime
source: |-
.metadata.project = del(.project)
.metadata.external_id = .metadata.project
parsed, err = parse_regex(.event_message, r'^(?P<time>\d+:\d+:\d+\.\d+) \[(?P<level>\w+)\] (?P<msg>.*)$')
if err == null {
.event_message = parsed.msg
.metadata.level = parsed.level
}
# Storage logs may contain json objects so we parse them for completeness
storage_logs:
type: remap
inputs:
- router.storage
source: |-
.metadata.project = del(.project)
.metadata.tenantId = .metadata.project
parsed, err = parse_json(.event_message)
if err == null {
.event_message = parsed.msg
.metadata.level = parsed.level
.metadata.timestamp = parsed.time
.metadata.context[0].host = parsed.hostname
.metadata.context[0].pid = parsed.pid
}
# Postgres logs some messages to stderr which we map to warning severity level
db_logs:
type: remap
inputs:
- router.db
source: |-
.metadata.host = "db-default"
.metadata.parsed.timestamp = .timestamp
parsed, err = parse_regex(.event_message, r'.*(?P<level>INFO|NOTICE|WARNING|ERROR|LOG|FATAL|PANIC?):.*', numeric_groups: true)
if err != null || parsed == null {
.metadata.parsed.error_severity = "info"
}
if parsed != null {
.metadata.parsed.error_severity = parsed.level
}
if .metadata.parsed.error_severity == "info" {
.metadata.parsed.error_severity = "log"
}
.metadata.parsed.error_severity = upcase!(.metadata.parsed.error_severity)
sinks:
logflare_auth:
type: 'http'
inputs:
- auth_logs
encoding:
codec: 'json'
method: 'post'
request:
retry_max_duration_secs: 10
uri: 'http://analytics:4000/api/logs?source_name=gotrue.logs.prod&api_key=${LOGFLARE_API_KEY?LOGFLARE_API_KEY is required}'
logflare_realtime:
type: 'http'
inputs:
- realtime_logs
encoding:
codec: 'json'
method: 'post'
request:
retry_max_duration_secs: 10
uri: 'http://analytics:4000/api/logs?source_name=realtime.logs.prod&api_key=${LOGFLARE_API_KEY?LOGFLARE_API_KEY is required}'
logflare_rest:
type: 'http'
inputs:
- rest_logs
encoding:
codec: 'json'
method: 'post'
request:
retry_max_duration_secs: 10
uri: 'http://analytics:4000/api/logs?source_name=postgREST.logs.prod&api_key=${LOGFLARE_API_KEY?LOGFLARE_API_KEY is required}'
logflare_db:
type: 'http'
inputs:
- db_logs
encoding:
codec: 'json'
method: 'post'
request:
retry_max_duration_secs: 10
# We must route the sink through kong because ingesting logs before logflare is fully initialised will
# lead to broken queries from studio. This works by the assumption that containers are started in the
# following order: vector > db > logflare > kong
uri: 'http://kong:8000/analytics/v1/api/logs?source_name=postgres.logs&api_key=${LOGFLARE_API_KEY?LOGFLARE_API_KEY is required}'
logflare_functions:
type: 'http'
inputs:
- router.functions
encoding:
codec: 'json'
method: 'post'
request:
retry_max_duration_secs: 10
uri: 'http://analytics:4000/api/logs?source_name=deno-relay-logs&api_key=${LOGFLARE_API_KEY?LOGFLARE_API_KEY is required}'
logflare_storage:
type: 'http'
inputs:
- storage_logs
encoding:
codec: 'json'
method: 'post'
request:
retry_max_duration_secs: 10
uri: 'http://analytics:4000/api/logs?source_name=storage.logs.prod.2&api_key=${LOGFLARE_API_KEY?LOGFLARE_API_KEY is required}'
logflare_kong:
type: 'http'
inputs:
- kong_logs
- kong_err
encoding:
codec: 'json'
method: 'post'
request:
retry_max_duration_secs: 10
uri: 'http://analytics:4000/api/logs?source_name=cloudflare.logs.prod&api_key=${LOGFLARE_API_KEY?LOGFLARE_API_KEY is required}'

View File

@@ -0,0 +1,30 @@
{:ok, _} = Application.ensure_all_started(:supavisor)
{:ok, version} =
case Supavisor.Repo.query!("select version()") do
%{rows: [[ver]]} -> Supavisor.Helpers.parse_pg_version(ver)
_ -> nil
end
params = %{
"external_id" => System.get_env("POOLER_TENANT_ID"),
"db_host" => "db",
"db_port" => System.get_env("POSTGRES_PORT"),
"db_database" => System.get_env("POSTGRES_DB"),
"require_user" => false,
"auth_query" => "SELECT * FROM pgbouncer.get_auth($1)",
"default_max_clients" => System.get_env("POOLER_MAX_CLIENT_CONN"),
"default_pool_size" => System.get_env("POOLER_DEFAULT_POOL_SIZE"),
"default_parameter_status" => %{"server_version" => version},
"users" => [%{
"db_user" => "pgbouncer",
"db_password" => System.get_env("POSTGRES_PASSWORD"),
"mode_type" => System.get_env("POOLER_POOL_MODE"),
"pool_size" => System.get_env("POOLER_DEFAULT_POOL_SIZE"),
"is_manager" => true
}]
}
if !Supavisor.Tenants.get_tenant_by_external_id(params["external_id"]) do
{:ok, _} = Supavisor.Tenants.create_tenant(params)
end

View File

@@ -121,6 +121,7 @@ services:
migrate:
condition: service_completed_successfully
environment:
- DATABASEMANAGER_HOST=rest_server
- SUPABASE_URL=http://kong:8000
- SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
- SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
@@ -163,6 +164,7 @@ services:
migrate:
condition: service_completed_successfully
environment:
- DATABASEMANAGER_HOST=rest_server
- SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
- DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
- REDIS_HOST=redis

View File

@@ -5,7 +5,7 @@ networks:
name: shared-network
volumes:
db-config:
supabase-config:
x-agpt-services:
&agpt-services
@@ -67,19 +67,19 @@ services:
studio:
<<: *supabase-services
extends:
file: ./supabase/docker/docker-compose.yml
file: ./db/docker/docker-compose.yml
service: studio
kong:
<<: *supabase-services
extends:
file: ./supabase/docker/docker-compose.yml
file: ./db/docker/docker-compose.yml
service: kong
auth:
<<: *supabase-services
extends:
file: ./supabase/docker/docker-compose.yml
file: ./db/docker/docker-compose.yml
service: auth
environment:
GOTRUE_MAILER_AUTOCONFIRM: true
@@ -87,54 +87,57 @@ services:
rest:
<<: *supabase-services
extends:
file: ./supabase/docker/docker-compose.yml
file: ./db/docker/docker-compose.yml
service: rest
realtime:
<<: *supabase-services
extends:
file: ./supabase/docker/docker-compose.yml
file: ./db/docker/docker-compose.yml
service: realtime
storage:
<<: *supabase-services
extends:
file: ./supabase/docker/docker-compose.yml
file: ./db/docker/docker-compose.yml
service: storage
imgproxy:
<<: *supabase-services
extends:
file: ./supabase/docker/docker-compose.yml
file: ./db/docker/docker-compose.yml
service: imgproxy
meta:
<<: *supabase-services
extends:
file: ./supabase/docker/docker-compose.yml
file: ./db/docker/docker-compose.yml
service: meta
functions:
<<: *supabase-services
extends:
file: ./supabase/docker/docker-compose.yml
file: ./db/docker/docker-compose.yml
service: functions
analytics:
<<: *supabase-services
extends:
file: ./supabase/docker/docker-compose.yml
file: ./db/docker/docker-compose.yml
service: analytics
db:
<<: *supabase-services
extends:
file: ./supabase/docker/docker-compose.yml
file: ./db/docker/docker-compose.yml
service: db
ports:
- ${POSTGRES_PORT}:5432 # We don't use Supavisor locally, so we expose the db directly.
vector:
<<: *supabase-services
extends:
file: ./supabase/docker/docker-compose.yml
file: ./db/docker/docker-compose.yml
service: vector
deps:

View File

@@ -23,9 +23,9 @@
"defaults"
],
"dependencies": {
"@faker-js/faker": "^9.4.0",
"@faker-js/faker": "^9.6.0",
"@hookform/resolvers": "^3.10.0",
"@next/third-parties": "^15.1.6",
"@next/third-parties": "^15.2.1",
"@radix-ui/react-alert-dialog": "^1.1.5",
"@radix-ui/react-avatar": "^1.1.1",
"@radix-ui/react-checkbox": "^1.1.2",
@@ -46,9 +46,9 @@
"@radix-ui/react-tooltip": "^1.1.7",
"@sentry/nextjs": "^8",
"@supabase/ssr": "^0.5.2",
"@supabase/supabase-js": "^2.48.1",
"@tanstack/react-table": "^8.20.6",
"@xyflow/react": "^12.4.2",
"@supabase/supabase-js": "^2.49.1",
"@tanstack/react-table": "^8.21.2",
"@xyflow/react": "12.4.2",
"ajv": "^8.17.1",
"boring-avatars": "^1.11.2",
"canvas-confetti": "^1.9.3",
@@ -60,28 +60,28 @@
"dotenv": "^16.4.7",
"elliptic": "6.6.1",
"embla-carousel-react": "^8.5.2",
"framer-motion": "^12.0.11",
"framer-motion": "^12.4.11",
"geist": "^1.3.1",
"launchdarkly-react-client-sdk": "^3.6.1",
"lodash.debounce": "^4.0.8",
"lucide-react": "^0.474.0",
"lucide-react": "^0.479.0",
"moment": "^2.30.1",
"next": "^14.2.21",
"next-themes": "^0.4.4",
"next": "^14.2.25",
"next-themes": "^0.4.5",
"react": "^18",
"react-day-picker": "^9.5.1",
"react-day-picker": "^9.6.1",
"react-dom": "^18",
"react-drag-drop-files": "^2.4.0",
"react-hook-form": "^7.54.0",
"react-icons": "^5.4.0",
"react-icons": "^5.5.0",
"react-markdown": "^9.0.3",
"react-modal": "^3.16.3",
"react-shepherd": "^6.1.7",
"react-shepherd": "^6.1.8",
"recharts": "^2.15.1",
"tailwind-merge": "^2.6.0",
"tailwindcss-animate": "^1.0.7",
"uuid": "^11.0.5",
"zod": "^3.23.8"
"uuid": "^11.1.0",
"zod": "^3.24.2"
},
"devDependencies": {
"@chromatic-com/storybook": "^3.2.4",

View File

@@ -1,15 +1,29 @@
"use client";
import { ShoppingBag } from "lucide-react";
import { Sidebar } from "@/components/agptui/Sidebar";
import { Users, DollarSign, LogOut } from "lucide-react";
import { useState } from "react";
import Link from "next/link";
import { BinaryIcon, XIcon } from "lucide-react";
import { usePathname } from "next/navigation"; // Add this import
import { IconSliders } from "@/components/ui/icons";
const tabs = [
{ name: "Dashboard", href: "/admin/dashboard" },
{ name: "Marketplace", href: "/admin/marketplace" },
{ name: "Users", href: "/admin/users" },
{ name: "Settings", href: "/admin/settings" },
const sidebarLinkGroups = [
{
links: [
{
text: "Marketplace Management",
href: "/admin/marketplace",
icon: <Users className="h-6 w-6" />,
},
{
text: "User Spending",
href: "/admin/spending",
icon: <DollarSign className="h-6 w-6" />,
},
{
text: "Admin User Management",
href: "/admin/settings",
icon: <IconSliders className="h-6 w-6" />,
},
],
},
];
export default function AdminLayout({
@@ -17,84 +31,10 @@ export default function AdminLayout({
}: {
children: React.ReactNode;
}) {
const pathname = usePathname(); // Get the current pathname
const [activeTab, setActiveTab] = useState(() => {
// Set active tab based on the current route
return tabs.find((tab) => tab.href === pathname)?.name || tabs[0].name;
});
const [mobileMenuOpen, setMobileMenuOpen] = useState(false);
return (
<div className="min-h-screen bg-gray-100">
<nav className="bg-white shadow-sm">
<div className="max-w-10xl mx-auto px-4 sm:px-6 lg:px-8">
<div className="flex h-16 items-center justify-between">
<div className="flex items-center">
<div className="flex-shrink-0">
<h1 className="text-xl font-bold">Admin Panel</h1>
</div>
<div className="hidden sm:ml-6 sm:flex sm:space-x-8">
{tabs.map((tab) => (
<Link
key={tab.name}
href={tab.href}
className={`${
activeTab === tab.name
? "border-indigo-500 text-indigo-600"
: "border-transparent text-gray-500 hover:border-gray-300 hover:text-gray-700"
} inline-flex items-center border-b-2 px-1 pt-1 text-sm font-medium`}
onClick={() => setActiveTab(tab.name)}
>
{tab.name}
</Link>
))}
</div>
</div>
<div className="sm:hidden">
<button
type="button"
className="inline-flex items-center justify-center rounded-md p-2 text-gray-400 hover:bg-gray-100 hover:text-gray-500 focus:outline-none focus:ring-2 focus:ring-inset focus:ring-indigo-500"
onClick={() => setMobileMenuOpen(!mobileMenuOpen)}
>
<span className="sr-only">Open main menu</span>
{mobileMenuOpen ? (
<XIcon className="block h-6 w-6" aria-hidden="true" />
) : (
<BinaryIcon className="block h-6 w-6" aria-hidden="true" />
)}
</button>
</div>
</div>
</div>
{mobileMenuOpen && (
<div className="sm:hidden">
<div className="space-y-1 pb-3 pt-2">
{tabs.map((tab) => (
<Link
key={tab.name}
href={tab.href}
className={`${
activeTab === tab.name
? "border-indigo-500 bg-indigo-50 text-indigo-700"
: "border-transparent text-gray-600 hover:border-gray-300 hover:bg-gray-50 hover:text-gray-800"
} block border-l-4 py-2 pl-3 pr-4 text-base font-medium`}
onClick={() => {
setActiveTab(tab.name);
setMobileMenuOpen(false);
}}
>
{tab.name}
</Link>
))}
</div>
</div>
)}
</nav>
<main className="py-10">
<div className="mx-auto max-w-7xl px-4 sm:px-6 lg:px-8">{children}</div>
</main>
<div className="flex min-h-screen w-screen flex-col lg:flex-row">
<Sidebar linkGroups={sidebarLinkGroups} />
<div className="flex-1 pl-4">{children}</div>
</div>
);
}

View File

@@ -0,0 +1,58 @@
"use server";
import { revalidatePath } from "next/cache";
import BackendApi from "@/lib/autogpt-server-api";
import {
NotificationPreferenceDTO,
StoreListingsWithVersionsResponse,
StoreSubmissionsResponse,
SubmissionStatus,
} from "@/lib/autogpt-server-api/types";
export async function approveAgent(formData: FormData) {
const data = {
store_listing_version_id: formData.get("id") as string,
is_approved: true,
comments: formData.get("comments") as string,
};
const api = new BackendApi();
await api.reviewSubmissionAdmin(data.store_listing_version_id, data);
revalidatePath("/admin/marketplace");
}
export async function rejectAgent(formData: FormData) {
const data = {
store_listing_version_id: formData.get("id") as string,
is_approved: false,
comments: formData.get("comments") as string,
internal_comments: formData.get("internal_comments") as string,
};
const api = new BackendApi();
await api.reviewSubmissionAdmin(data.store_listing_version_id, data);
revalidatePath("/admin/marketplace");
}
export async function getAdminListingsWithVersions(
status?: SubmissionStatus,
search?: string,
page: number = 1,
pageSize: number = 20,
): Promise<StoreListingsWithVersionsResponse> {
const data: Record<string, any> = {
page,
page_size: pageSize,
};
if (status) {
data.status = status;
}
if (search) {
data.search = search;
}
const api = new BackendApi();
const response = await api.getAdminListingsWithVersions(data);
return response;
}

View File

@@ -1,25 +1,62 @@
import { withRoleAccess } from "@/lib/withRoleAccess";
import { Suspense } from "react";
import type { SubmissionStatus } from "@/lib/autogpt-server-api/types";
import { AdminAgentsDataTable } from "@/components/admin/marketplace/admin-agents-data-table";
import React from "react";
// import { getReviewableAgents } from "@/components/admin/marketplace/actions";
// import AdminMarketplaceAgentList from "@/components/admin/marketplace/AdminMarketplaceAgentList";
// import AdminFeaturedAgentsControl from "@/components/admin/marketplace/AdminFeaturedAgentsControl";
import { Separator } from "@/components/ui/separator";
async function AdminMarketplace() {
// const reviewableAgents = await getReviewableAgents();
async function AdminMarketplaceDashboard({
searchParams,
}: {
searchParams: {
page?: string;
status?: string;
search?: string;
};
}) {
const page = searchParams.page ? Number.parseInt(searchParams.page) : 1;
const status = searchParams.status as SubmissionStatus | undefined;
const search = searchParams.search;
return (
<>
{/* <AdminMarketplaceAgentList agents={reviewableAgents.items} />
<Separator className="my-4" />
<AdminFeaturedAgentsControl className="mt-4" /> */}
</>
<div className="mx-auto p-6">
<div className="flex flex-col gap-4">
<div className="flex items-center justify-between">
<div>
<h1 className="text-3xl font-bold">Marketplace Management</h1>
<p className="text-gray-500">
Unified view for marketplace management and approval history
</p>
</div>
</div>
<Suspense
fallback={
<div className="py-10 text-center">Loading submissions...</div>
}
>
<AdminAgentsDataTable
initialPage={page}
initialStatus={status}
initialSearch={search}
/>
</Suspense>
</div>
</div>
);
}
export default async function AdminDashboardPage() {
export default async function AdminMarketplacePage({
searchParams,
}: {
searchParams: {
page?: string;
status?: string;
search?: string;
};
}) {
"use server";
const withAdminAccess = await withRoleAccess(["admin"]);
const ProtectedAdminMarketplace = await withAdminAccess(AdminMarketplace);
return <ProtectedAdminMarketplace />;
const ProtectedAdminMarketplace = await withAdminAccess(
AdminMarketplaceDashboard,
);
return <ProtectedAdminMarketplace searchParams={searchParams} />;
}

View File

@@ -24,10 +24,14 @@ export async function askOtto(
try {
const response = await api.askOtto(ottoQuery);
revalidatePath("/build");
return response;
} catch (error) {
console.error("Error in askOtto server action:", error);
throw error;
return {
answer: error instanceof Error ? error.message : "Unknown error occurred",
documents: [],
success: false,
error: true,
};
}
}

View File

@@ -2,6 +2,7 @@
import React, { useCallback, useEffect, useMemo, useState } from "react";
import { useParams, useRouter } from "next/navigation";
import { exportAsJSONFile } from "@/lib/utils";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import {
GraphExecution,
@@ -191,19 +192,40 @@ export default function AgentRunsPage(): React.ReactElement {
[schedules, api],
);
const downloadGraph = useCallback(
async () =>
agent &&
// Export sanitized graph from backend
api
.getGraph(agent.agent_id, agent.agent_version, true)
.then((graph) =>
exportAsJSONFile(graph, `${graph.name}_v${graph.version}.json`),
),
[api, agent],
);
const agentActions: ButtonAction[] = useMemo(
() => [
{
label: "Open in builder",
callback: () => agent && router.push(`/build?flowID=${agent.agent_id}`),
},
...(agent?.can_access_graph
? [
{
label: "Open in builder",
callback: () =>
agent &&
router.push(
`/build?flowID=${agent.agent_id}&flowVersion=${agent.agent_version}`,
),
},
{ label: "Export agent to file", callback: downloadGraph },
]
: []),
{
label: "Delete agent",
variant: "destructive",
callback: () => setAgentDeleteDialogOpen(true),
},
],
[agent, router],
[agent, router, downloadGraph],
);
if (!agent || !graph) {

View File

@@ -29,7 +29,6 @@ export default function CreditsPage() {
formatCredits,
refundTopUp,
refundRequests,
fetchRefundRequests,
} = useCredits({
fetchInitialAutoTopUpConfig: true,
fetchInitialRefundRequests: true,

View File

@@ -1,17 +1,48 @@
import * as React from "react";
import { Sidebar } from "@/components/agptui/Sidebar";
import {
IconDashboardLayout,
IconIntegrations,
IconProfile,
IconSliders,
IconCoin,
} from "@/components/ui/icons";
import { KeyIcon } from "lucide-react";
export default function Layout({ children }: { children: React.ReactNode }) {
const sidebarLinkGroups = [
{
links: [
{ text: "Creator Dashboard", href: "/profile/dashboard" },
{ text: "Agent dashboard", href: "/profile/agent-dashboard" },
{ text: "Billing", href: "/profile/credits" },
{ text: "Integrations", href: "/profile/integrations" },
{ text: "API Keys", href: "/profile/api_keys" },
{ text: "Profile", href: "/profile" },
{ text: "Settings", href: "/profile/settings" },
{
text: "Creator Dashboard",
href: "/profile/dashboard",
icon: <IconDashboardLayout className="h-6 w-6" />,
},
{
text: "Billing",
href: "/profile/credits",
icon: <IconCoin className="h-6 w-6" />,
},
{
text: "Integrations",
href: "/profile/integrations",
icon: <IconIntegrations className="h-6 w-6" />,
},
{
text: "API Keys",
href: "/profile/api_keys",
icon: <KeyIcon className="h-6 w-6" />,
},
{
text: "Profile",
href: "/profile",
icon: <IconProfile className="h-6 w-6" />,
},
{
text: "Settings",
href: "/profile/settings",
icon: <IconSliders className="h-6 w-6" />,
},
],
},
];

View File

@@ -29,7 +29,6 @@ export async function sendResetEmail(email: string) {
return error.message;
}
console.log("Reset email sent");
redirect("/reset_password");
},
);

View File

@@ -461,6 +461,37 @@ const FlowEditor: React.FC<{
});
}, [nodes, setViewport, x, y]);
const fillDefaults = useCallback((obj: any, schema: any) => {
// Iterate over the schema properties
for (const key in schema.properties) {
if (schema.properties.hasOwnProperty(key)) {
const propertySchema = schema.properties[key];
// If the property is not in the object, initialize it with the default value
if (!obj.hasOwnProperty(key)) {
if (propertySchema.default !== undefined) {
obj[key] = propertySchema.default;
} else if (propertySchema.type === "object") {
// Recursively fill defaults for nested objects
obj[key] = fillDefaults({}, propertySchema);
} else if (propertySchema.type === "array") {
// Recursively fill defaults for arrays
obj[key] = fillDefaults([], propertySchema);
}
} else {
// If the property exists, recursively fill defaults for nested objects/arrays
if (propertySchema.type === "object") {
obj[key] = fillDefaults(obj[key], propertySchema);
} else if (propertySchema.type === "array") {
obj[key] = fillDefaults(obj[key], propertySchema);
}
}
}
}
return obj;
}, []);
const addNode = useCallback(
(blockId: string, nodeType: string, hardcodedValues: any = {}) => {
const nodeSchema = availableNodes.find((node) => node.id === blockId);
@@ -507,7 +538,10 @@ const FlowEditor: React.FC<{
categories: nodeSchema.categories,
inputSchema: nodeSchema.inputSchema,
outputSchema: nodeSchema.outputSchema,
hardcodedValues: hardcodedValues,
hardcodedValues: {
...fillDefaults({}, nodeSchema.inputSchema),
...hardcodedValues,
},
connections: [],
isOutputOpen: false,
block_id: blockId,

View File

@@ -56,29 +56,30 @@ const OttoChatWidget = () => {
// Add user message to chat
setMessages((prev) => [...prev, { type: "user", content: userMessage }]);
// Add temporary processing message
setMessages((prev) => [
...prev,
{ type: "assistant", content: "Processing your question..." },
]);
const conversationHistory = messages.reduce<
{ query: string; response: string }[]
>((acc, msg, i, arr) => {
if (
msg.type === "user" &&
i + 1 < arr.length &&
arr[i + 1].type === "assistant" &&
arr[i + 1].content !== "Processing your question..."
) {
acc.push({
query: msg.content,
response: arr[i + 1].content,
});
}
return acc;
}, []);
try {
// Add temporary processing message
setMessages((prev) => [
...prev,
{ type: "assistant", content: "Processing your question..." },
]);
const conversationHistory = messages.reduce<
{ query: string; response: string }[]
>((acc, msg, i, arr) => {
if (
msg.type === "user" &&
i + 1 < arr.length &&
arr[i + 1].type === "assistant"
) {
acc.push({
query: msg.content,
response: arr[i + 1].content,
});
}
return acc;
}, []);
const data = await askOtto(
userMessage,
conversationHistory,
@@ -86,34 +87,43 @@ const OttoChatWidget = () => {
flowID || undefined,
);
// Remove processing message and add actual response
setMessages((prev) => [
...prev.slice(0, -1),
{ type: "assistant", content: data.answer },
]);
} catch (error) {
console.error("Error calling API:", error);
// Remove processing message and add error message
const errorMessage =
error instanceof Error && error.message === "Authentication required"
? "Please sign in to use the chat feature."
: "Sorry, there was an error processing your message. Please try again.";
// Check if the response contains an error
if ("error" in data && data.error === true) {
// Handle different error types
let errorMessage =
"Sorry, there was an error processing your message. Please try again.";
setMessages((prev) => [
...prev.slice(0, -1),
{ type: "assistant", content: errorMessage },
]);
if (data.answer === "Authentication required") {
errorMessage = "Please sign in to use the chat feature.";
} else if (data.answer === "Failed to connect to Otto service") {
errorMessage =
"Otto service is currently unavailable. Please try again later.";
} else if (data.answer.includes("timed out")) {
errorMessage = "Request timed out. Please try again later.";
}
if (
error instanceof Error &&
error.message === "Authentication required"
) {
toast({
title: "Authentication Error",
description: "Please sign in to use the chat feature.",
variant: "destructive",
});
// Remove processing message and add error message
setMessages((prev) => [
...prev.slice(0, -1),
{ type: "assistant", content: errorMessage },
]);
} else {
// Remove processing message and add actual response
setMessages((prev) => [
...prev.slice(0, -1),
{ type: "assistant", content: data.answer },
]);
}
} catch (error) {
console.error("Unexpected error in chat widget:", error);
setMessages((prev) => [
...prev.slice(0, -1),
{
type: "assistant",
content:
"An unexpected error occurred. Please refresh the page and try again.",
},
]);
} finally {
setIsProcessing(false);
setIncludeGraphData(false);

View File

@@ -1,149 +0,0 @@
// "use client";
// import {
// Dialog,
// DialogContent,
// DialogClose,
// DialogFooter,
// DialogHeader,
// DialogTitle,
// DialogTrigger,
// } from "@/components/ui/dialog";
// import { Button } from "@/components/ui/button";
// import {
// MultiSelector,
// MultiSelectorContent,
// MultiSelectorInput,
// MultiSelectorItem,
// MultiSelectorList,
// MultiSelectorTrigger,
// } from "@/components/ui/multiselect";
// import { Controller, useForm } from "react-hook-form";
// import {
// Select,
// SelectContent,
// SelectItem,
// SelectTrigger,
// SelectValue,
// } from "@/components/ui/select";
// import { useState } from "react";
// import { addFeaturedAgent } from "./actions";
// import { Agent } from "@/lib/marketplace-api/types";
// type FormData = {
// agent: string;
// categories: string[];
// };
// export const AdminAddFeaturedAgentDialog = ({
// categories,
// agents,
// }: {
// categories: string[];
// agents: Agent[];
// }) => {
// const [selectedAgent, setSelectedAgent] = useState<string>("");
// const [selectedCategories, setSelectedCategories] = useState<string[]>([]);
// const {
// control,
// handleSubmit,
// watch,
// setValue,
// formState: { errors },
// } = useForm<FormData>({
// defaultValues: {
// agent: "",
// categories: [],
// },
// });
// return (
// <Dialog>
// <DialogTrigger asChild>
// <Button variant="outline" size="sm">
// Add Featured Agent
// </Button>
// </DialogTrigger>
// <DialogContent>
// <DialogHeader>
// <DialogTitle>Add Featured Agent</DialogTitle>
// </DialogHeader>
// <div className="flex flex-col gap-4">
// <Controller
// name="agent"
// control={control}
// rules={{ required: true }}
// render={({ field }) => (
// <div>
// <label htmlFor={field.name}>Agent</label>
// <Select
// onValueChange={(value) => {
// field.onChange(value);
// setSelectedAgent(value);
// }}
// value={field.value || ""}
// >
// <SelectTrigger>
// <SelectValue placeholder="Select an agent" />
// </SelectTrigger>
// <SelectContent>
// {/* Populate with agents */}
// {agents.map((agent) => (
// <SelectItem key={agent.id} value={agent.id}>
// {agent.name}
// </SelectItem>
// ))}
// </SelectContent>
// </Select>
// </div>
// )}
// />
// <Controller
// name="categories"
// control={control}
// render={({ field }) => (
// <MultiSelector
// values={field.value || []}
// onValuesChange={(values) => {
// field.onChange(values);
// setSelectedCategories(values);
// }}
// >
// <MultiSelectorTrigger>
// <MultiSelectorInput placeholder="Select categories" />
// </MultiSelectorTrigger>
// <MultiSelectorContent>
// <MultiSelectorList>
// {categories.map((category) => (
// <MultiSelectorItem key={category} value={category}>
// {category}
// </MultiSelectorItem>
// ))}
// </MultiSelectorList>
// </MultiSelectorContent>
// </MultiSelector>
// )}
// />
// </div>
// <DialogFooter>
// <DialogClose asChild>
// <Button variant="outline">Cancel</Button>
// </DialogClose>
// <DialogClose asChild>
// <Button
// type="submit"
// onClick={async () => {
// // Handle adding the featured agent
// await addFeaturedAgent(selectedAgent, selectedCategories);
// // close the dialog
// }}
// >
// Add
// </Button>
// </DialogClose>
// </DialogFooter>
// </DialogContent>
// </Dialog>
// );
// };

View File

@@ -1,74 +0,0 @@
// import { Button } from "@/components/ui/button";
// import {
// getFeaturedAgents,
// removeFeaturedAgent,
// getCategories,
// getNotFeaturedAgents,
// } from "./actions";
// import FeaturedAgentsTable from "./FeaturedAgentsTable";
// import { AdminAddFeaturedAgentDialog } from "./AdminAddFeaturedAgentDialog";
// import { revalidatePath } from "next/cache";
// import * as Sentry from "@sentry/nextjs";
// export default async function AdminFeaturedAgentsControl({
// className,
// }: {
// className?: string;
// }) {
// // add featured agent button
// // modal to select agent?
// // modal to select categories?
// // table of featured agents
// // in table
// // remove featured agent button
// // edit featured agent categories button
// // table footer
// // Next page button
// // Previous page button
// // Page number input
// // Page size input
// // Total pages input
// // Go to page button
// const page = 1;
// const pageSize = 10;
// const agents = await getFeaturedAgents(page, pageSize);
// const categories = await getCategories();
// const notFeaturedAgents = await getNotFeaturedAgents();
// return (
// <div className={`flex flex-col gap-4 ${className}`}>
// <div className="mb-4 flex justify-between">
// <h3 className="text-lg font-semibold">Featured Agent Controls</h3>
// <AdminAddFeaturedAgentDialog
// categories={categories.unique_categories}
// agents={notFeaturedAgents.items}
// />
// </div>
// <FeaturedAgentsTable
// agents={agents.items}
// globalActions={[
// {
// component: <Button>Remove</Button>,
// action: async (rows) => {
// "use server";
// return await Sentry.withServerActionInstrumentation(
// "removeFeaturedAgent",
// {},
// async () => {
// const all = rows.map((row) => removeFeaturedAgent(row.id));
// await Promise.all(all);
// revalidatePath("/marketplace");
// },
// );
// },
// },
// ]}
// />
// </div>
// );
// }

Some files were not shown because too many files have changed in this diff Show More