Merge branch 'dev' into aarushikansal/remove-title-input-blocks

This commit is contained in:
Bently
2024-12-18 16:00:39 -05:00
committed by GitHub
307 changed files with 26452 additions and 10073 deletions

View File

@@ -79,6 +79,17 @@ jobs:
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi
- name: Check poetry.lock
run: |
poetry lock --no-update
if ! git diff --quiet poetry.lock; then
echo "Error: poetry.lock not up to date."
echo
git diff poetry.lock
exit 1
fi
- name: Install Python dependencies
run: poetry install

View File

@@ -23,6 +23,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v4
with:
@@ -38,24 +39,12 @@ jobs:
test:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
browser: [chromium, webkit]
steps:
- name: Free Disk Space (Ubuntu)
uses: jlumbroso/free-disk-space@main
with:
# this might remove tools that are actually needed,
# if set to "true" but frees about 6 GB
tool-cache: false
# all of these default to true, but feel free to set to
# "false" if necessary for your workflow
android: false
dotnet: false
haskell: false
large-packages: true
docker-images: true
swap-storage: true
- name: Checkout repository
uses: actions/checkout@v4
with:
@@ -66,6 +55,12 @@ jobs:
with:
node-version: "21"
- name: Free Disk Space (Ubuntu)
uses: jlumbroso/free-disk-space@main
with:
large-packages: false # slow
docker-images: false # limited benefit
- name: Copy default supabase .env
run: |
cp ../supabase/docker/.env.example ../.env
@@ -86,16 +81,16 @@ jobs:
run: |
cp .env.example .env
- name: Install Playwright Browsers
run: yarn playwright install --with-deps
- name: Install Browser '${{ matrix.browser }}'
run: yarn playwright install --with-deps ${{ matrix.browser }}
- name: Run tests
run: |
yarn test
yarn test --project=${{ matrix.browser }}
- uses: actions/upload-artifact@v4
if: ${{ !cancelled() }}
with:
name: playwright-report
name: playwright-report-${{ matrix.browser }}
path: playwright-report/
retention-days: 30

3
.gitignore vendored
View File

@@ -173,3 +173,6 @@ LICENSE.rtf
autogpt_platform/backend/settings.py
/.auth
/autogpt_platform/frontend/.auth
*.ign.*
.test-contents

View File

@@ -35,3 +35,12 @@ def verify_user(payload: dict | None, admin_only: bool) -> User:
raise fastapi.HTTPException(status_code=403, detail="Admin access required")
return User.from_payload(payload)
def get_user_id(payload: dict = fastapi.Depends(auth_middleware)) -> str:
user_id = payload.get("sub")
if not user_id:
raise fastapi.HTTPException(
status_code=401, detail="User ID not found in token"
)
return user_id

View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.
[[package]]
name = "aiohappyeyeballs"
@@ -1091,13 +1091,13 @@ pyasn1 = ">=0.4.6,<0.7.0"
[[package]]
name = "pydantic"
version = "2.10.2"
version = "2.10.3"
description = "Data validation using Python type hints"
optional = false
python-versions = ">=3.8"
files = [
{file = "pydantic-2.10.2-py3-none-any.whl", hash = "sha256:cfb96e45951117c3024e6b67b25cdc33a3cb7b2fa62e239f7af1378358a1d99e"},
{file = "pydantic-2.10.2.tar.gz", hash = "sha256:2bc2d7f17232e0841cbba4641e65ba1eb6fafb3a08de3a091ff3ce14a197c4fa"},
{file = "pydantic-2.10.3-py3-none-any.whl", hash = "sha256:be04d85bbc7b65651c5f8e6b9976ed9c6f41782a55524cef079a34a0bb82144d"},
{file = "pydantic-2.10.3.tar.gz", hash = "sha256:cb5ac360ce894ceacd69c403187900a02c4b20b693a9dd1d643e1effab9eadf9"},
]
[package.dependencies]
@@ -1223,13 +1223,13 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
[[package]]
name = "pydantic-settings"
version = "2.6.1"
version = "2.7.0"
description = "Settings management using Pydantic"
optional = false
python-versions = ">=3.8"
files = [
{file = "pydantic_settings-2.6.1-py3-none-any.whl", hash = "sha256:7fb0637c786a558d3103436278a7c4f1cfd29ba8973238a50c5bb9a55387da87"},
{file = "pydantic_settings-2.6.1.tar.gz", hash = "sha256:e0f92546d8a9923cb8941689abf85d6601a8c19a23e97a34b2964a2e3f813ca0"},
{file = "pydantic_settings-2.7.0-py3-none-any.whl", hash = "sha256:e00c05d5fa6cbbb227c84bd7487c5c1065084119b750df7c8c1a554aed236eb5"},
{file = "pydantic_settings-2.7.0.tar.gz", hash = "sha256:ac4bfd4a36831a48dbf8b2d9325425b549a0a6f18cea118436d728eb4f1c4d66"},
]
[package.dependencies]
@@ -1243,13 +1243,13 @@ yaml = ["pyyaml (>=6.0.1)"]
[[package]]
name = "pyjwt"
version = "2.10.0"
version = "2.10.1"
description = "JSON Web Token implementation in Python"
optional = false
python-versions = ">=3.9"
files = [
{file = "PyJWT-2.10.0-py3-none-any.whl", hash = "sha256:543b77207db656de204372350926bed5a86201c4cbff159f623f79c7bb487a15"},
{file = "pyjwt-2.10.0.tar.gz", hash = "sha256:7628a7eb7938959ac1b26e819a1df0fd3259505627b575e4bad6d08f76db695c"},
{file = "PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb"},
{file = "pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953"},
]
[package.extras]
@@ -1282,20 +1282,20 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments
[[package]]
name = "pytest-asyncio"
version = "0.24.0"
version = "0.25.0"
description = "Pytest support for asyncio"
optional = false
python-versions = ">=3.8"
python-versions = ">=3.9"
files = [
{file = "pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b"},
{file = "pytest_asyncio-0.24.0.tar.gz", hash = "sha256:d081d828e576d85f875399194281e92bf8a68d60d72d1a2faf2feddb6c46b276"},
{file = "pytest_asyncio-0.25.0-py3-none-any.whl", hash = "sha256:db5432d18eac6b7e28b46dcd9b69921b55c3b1086e85febfe04e70b18d9e81b3"},
{file = "pytest_asyncio-0.25.0.tar.gz", hash = "sha256:8c0610303c9e0442a5db8604505fc0f545456ba1528824842b37b4a626cbf609"},
]
[package.dependencies]
pytest = ">=8.2,<9"
[package.extras]
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"]
testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
[[package]]
@@ -1362,13 +1362,13 @@ websockets = ">=11,<13"
[[package]]
name = "redis"
version = "5.2.0"
version = "5.2.1"
description = "Python client for Redis database and key-value store"
optional = false
python-versions = ">=3.8"
files = [
{file = "redis-5.2.0-py3-none-any.whl", hash = "sha256:ae174f2bb3b1bf2b09d54bf3e51fbc1469cf6c10aa03e21141f51969801a7897"},
{file = "redis-5.2.0.tar.gz", hash = "sha256:0b1087665a771b1ff2e003aa5bdd354f15a70c9e25d5a7dbf9c722c16528a7b0"},
{file = "redis-5.2.1-py3-none-any.whl", hash = "sha256:ee7e1056b9aea0f04c6c2ed59452947f34c4940ee025f5dd83e6a6418b6989e4"},
{file = "redis-5.2.1.tar.gz", hash = "sha256:16f2e22dff21d5125e8481515e386711a34cbec50f0e44413dd7d9c060a54e0f"},
]
[package.dependencies]
@@ -1415,29 +1415,29 @@ pyasn1 = ">=0.1.3"
[[package]]
name = "ruff"
version = "0.8.1"
version = "0.8.3"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
files = [
{file = "ruff-0.8.1-py3-none-linux_armv6l.whl", hash = "sha256:fae0805bd514066f20309f6742f6ee7904a773eb9e6c17c45d6b1600ca65c9b5"},
{file = "ruff-0.8.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b8a4f7385c2285c30f34b200ca5511fcc865f17578383db154e098150ce0a087"},
{file = "ruff-0.8.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:cd054486da0c53e41e0086e1730eb77d1f698154f910e0cd9e0d64274979a209"},
{file = "ruff-0.8.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2029b8c22da147c50ae577e621a5bfbc5d1fed75d86af53643d7a7aee1d23871"},
{file = "ruff-0.8.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2666520828dee7dfc7e47ee4ea0d928f40de72056d929a7c5292d95071d881d1"},
{file = "ruff-0.8.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:333c57013ef8c97a53892aa56042831c372e0bb1785ab7026187b7abd0135ad5"},
{file = "ruff-0.8.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:288326162804f34088ac007139488dcb43de590a5ccfec3166396530b58fb89d"},
{file = "ruff-0.8.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b12c39b9448632284561cbf4191aa1b005882acbc81900ffa9f9f471c8ff7e26"},
{file = "ruff-0.8.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:364e6674450cbac8e998f7b30639040c99d81dfb5bbc6dfad69bc7a8f916b3d1"},
{file = "ruff-0.8.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b22346f845fec132aa39cd29acb94451d030c10874408dbf776af3aaeb53284c"},
{file = "ruff-0.8.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:b2f2f7a7e7648a2bfe6ead4e0a16745db956da0e3a231ad443d2a66a105c04fa"},
{file = "ruff-0.8.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:adf314fc458374c25c5c4a4a9270c3e8a6a807b1bec018cfa2813d6546215540"},
{file = "ruff-0.8.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a885d68342a231b5ba4d30b8c6e1b1ee3a65cf37e3d29b3c74069cdf1ee1e3c9"},
{file = "ruff-0.8.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:d2c16e3508c8cc73e96aa5127d0df8913d2290098f776416a4b157657bee44c5"},
{file = "ruff-0.8.1-py3-none-win32.whl", hash = "sha256:93335cd7c0eaedb44882d75a7acb7df4b77cd7cd0d2255c93b28791716e81790"},
{file = "ruff-0.8.1-py3-none-win_amd64.whl", hash = "sha256:2954cdbe8dfd8ab359d4a30cd971b589d335a44d444b6ca2cb3d1da21b75e4b6"},
{file = "ruff-0.8.1-py3-none-win_arm64.whl", hash = "sha256:55873cc1a473e5ac129d15eccb3c008c096b94809d693fc7053f588b67822737"},
{file = "ruff-0.8.1.tar.gz", hash = "sha256:3583db9a6450364ed5ca3f3b4225958b24f78178908d5c4bc0f46251ccca898f"},
{file = "ruff-0.8.3-py3-none-linux_armv6l.whl", hash = "sha256:8d5d273ffffff0acd3db5bf626d4b131aa5a5ada1276126231c4174543ce20d6"},
{file = "ruff-0.8.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:e4d66a21de39f15c9757d00c50c8cdd20ac84f55684ca56def7891a025d7e939"},
{file = "ruff-0.8.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:c356e770811858bd20832af696ff6c7e884701115094f427b64b25093d6d932d"},
{file = "ruff-0.8.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c0a60a825e3e177116c84009d5ebaa90cf40dfab56e1358d1df4e29a9a14b13"},
{file = "ruff-0.8.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:75fb782f4db39501210ac093c79c3de581d306624575eddd7e4e13747e61ba18"},
{file = "ruff-0.8.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7f26bc76a133ecb09a38b7868737eded6941b70a6d34ef53a4027e83913b6502"},
{file = "ruff-0.8.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:01b14b2f72a37390c1b13477c1c02d53184f728be2f3ffc3ace5b44e9e87b90d"},
{file = "ruff-0.8.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:53babd6e63e31f4e96ec95ea0d962298f9f0d9cc5990a1bbb023a6baf2503a82"},
{file = "ruff-0.8.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1ae441ce4cf925b7f363d33cd6570c51435972d697e3e58928973994e56e1452"},
{file = "ruff-0.8.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7c65bc0cadce32255e93c57d57ecc2cca23149edd52714c0c5d6fa11ec328cd"},
{file = "ruff-0.8.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5be450bb18f23f0edc5a4e5585c17a56ba88920d598f04a06bd9fd76d324cb20"},
{file = "ruff-0.8.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8faeae3827eaa77f5721f09b9472a18c749139c891dbc17f45e72d8f2ca1f8fc"},
{file = "ruff-0.8.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:db503486e1cf074b9808403991663e4277f5c664d3fe237ee0d994d1305bb060"},
{file = "ruff-0.8.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:6567be9fb62fbd7a099209257fef4ad2c3153b60579818b31a23c886ed4147ea"},
{file = "ruff-0.8.3-py3-none-win32.whl", hash = "sha256:19048f2f878f3ee4583fc6cb23fb636e48c2635e30fb2022b3a1cd293402f964"},
{file = "ruff-0.8.3-py3-none-win_amd64.whl", hash = "sha256:f7df94f57d7418fa7c3ffb650757e0c2b96cf2501a0b192c18e4fb5571dfada9"},
{file = "ruff-0.8.3-py3-none-win_arm64.whl", hash = "sha256:fe2756edf68ea79707c8d68b78ca9a58ed9af22e430430491ee03e718b5e4936"},
{file = "ruff-0.8.3.tar.gz", hash = "sha256:5e7558304353b84279042fc584a4f4cb8a07ae79b2bf3da1a7551d960b5626d3"},
]
[[package]]
@@ -1852,4 +1852,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<4.0"
content-hash = "0aef772b321a1a00163abdc6a88efb21874b84bf25c68e84992c3ae3af026a17"
content-hash = "13a36d3be675cab4a3eb2e6a62a1b08df779bded4c7b9164d8be300dc08748d0"

View File

@@ -10,18 +10,18 @@ packages = [{ include = "autogpt_libs" }]
colorama = "^0.4.6"
expiringdict = "^1.2.2"
google-cloud-logging = "^3.11.3"
pydantic = "^2.10.2"
pydantic-settings = "^2.6.1"
pyjwt = "^2.10.0"
pytest-asyncio = "^0.24.0"
pydantic = "^2.10.3"
pydantic-settings = "^2.7.0"
pyjwt = "^2.10.1"
pytest-asyncio = "^0.25.0"
pytest-mock = "^3.14.0"
python = ">=3.10,<4.0"
python-dotenv = "^1.0.1"
supabase = "^2.10.0"
[tool.poetry.group.dev.dependencies]
redis = "^5.2.0"
ruff = "^0.8.1"
redis = "^5.2.1"
ruff = "^0.8.3"
[build-system]
requires = ["poetry-core"]

View File

@@ -6,18 +6,23 @@ ENV PYTHONUNBUFFERED 1
WORKDIR /app
RUN echo 'Acquire::http::Pipeline-Depth 0;\nAcquire::http::No-Cache true;\nAcquire::BrokenProxy true;\n' > /etc/apt/apt.conf.d/99fixbadproxy
RUN apt-get update --allow-releaseinfo-change --fix-missing
# Install build dependencies
RUN apt-get update \
&& apt-get install -y build-essential curl ffmpeg wget libcurl4-gnutls-dev libexpat1-dev libpq5 gettext libz-dev libssl-dev postgresql-client git \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
ENV POETRY_VERSION=1.8.3 \
POETRY_HOME="/opt/poetry" \
POETRY_NO_INTERACTION=1 \
POETRY_VIRTUALENVS_CREATE=false \
PATH="$POETRY_HOME/bin:$PATH"
RUN apt-get install -y build-essential
RUN apt-get install -y libpq5
RUN apt-get install -y libz-dev
RUN apt-get install -y libssl-dev
RUN apt-get install -y postgresql-client
ENV POETRY_VERSION=1.8.3
ENV POETRY_HOME=/opt/poetry
ENV POETRY_NO_INTERACTION=1
ENV POETRY_VIRTUALENVS_CREATE=false
ENV PATH=/opt/poetry/bin:$PATH
# Upgrade pip and setuptools to fix security vulnerabilities
RUN pip3 install --upgrade pip setuptools
@@ -39,11 +44,11 @@ FROM python:3.11.10-slim-bookworm AS server_dependencies
WORKDIR /app
ENV POETRY_VERSION=1.8.3 \
POETRY_HOME="/opt/poetry" \
POETRY_NO_INTERACTION=1 \
POETRY_VIRTUALENVS_CREATE=false \
PATH="$POETRY_HOME/bin:$PATH"
ENV POETRY_VERSION=1.8.3
ENV POETRY_HOME=/opt/poetry
ENV POETRY_NO_INTERACTION=1
ENV POETRY_VIRTUALENVS_CREATE=false
ENV PATH=/opt/poetry/bin:$PATH
# Upgrade pip and setuptools to fix security vulnerabilities

View File

@@ -200,4 +200,4 @@ To add a new agent block, you need to create a new class that inherits from `Blo
* `run` method: the main logic of the block.
* `test_input` & `test_output`: the sample input and output data for the block, which will be used to auto-test the block.
* You can mock the functions declared in the block using the `test_mock` field for your unit tests.
* Once you finish creating the block, you can test it by running `pytest -s test/block/test_block.py`.
* Once you finish creating the block, you can test it by running `poetry run pytest -s test/block/test_block.py`.

View File

@@ -12,6 +12,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
class ImageSize(str, Enum):
@@ -101,12 +102,10 @@ class ImageGenModel(str, Enum):
class AIImageGeneratorBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[Literal["replicate"], Literal["api_key"]] = (
CredentialsField(
provider="replicate",
supported_credential_types={"api_key"},
description="Enter your Replicate API key to access the image generation API. You can obtain an API key from https://replicate.com/account/api-tokens.",
)
credentials: CredentialsMetaInput[
Literal[ProviderName.REPLICATE], Literal["api_key"]
] = CredentialsField(
description="Enter your Replicate API key to access the image generation API. You can obtain an API key from https://replicate.com/account/api-tokens.",
)
prompt: str = SchemaField(
description="Text prompt for image generation",

View File

@@ -13,6 +13,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
logger = logging.getLogger(__name__)
@@ -54,13 +55,11 @@ class NormalizationStrategy(str, Enum):
class AIMusicGeneratorBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[Literal["replicate"], Literal["api_key"]] = (
CredentialsField(
provider="replicate",
supported_credential_types={"api_key"},
description="The Replicate integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
credentials: CredentialsMetaInput[
Literal[ProviderName.REPLICATE], Literal["api_key"]
] = CredentialsField(
description="The Replicate integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
prompt: str = SchemaField(
description="A description of the music you want to generate",

View File

@@ -12,6 +12,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests
TEST_CREDENTIALS = APIKeyCredentials(
@@ -140,13 +141,11 @@ logger = logging.getLogger(__name__)
class AIShortformVideoCreatorBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[Literal["revid"], Literal["api_key"]] = (
CredentialsField(
provider="revid",
supported_credential_types={"api_key"},
description="The revid.ai integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
credentials: CredentialsMetaInput[
Literal[ProviderName.REVID], Literal["api_key"]
] = CredentialsField(
description="The revid.ai integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
script: str = SchemaField(
description="""1. Use short and punctuated sentences\n\n2. Use linebreaks to create a new clip\n\n3. Text outside of brackets is spoken by the AI, and [text between brackets] will be used to guide the visual generation. For example, [close-up of a cat] will show a close-up of a cat.""",

View File

@@ -1,13 +1,11 @@
import re
from typing import Any, List
from jinja2 import BaseLoader, Environment
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
from backend.data.model import SchemaField
from backend.util.mock import MockObject
from backend.util.text import TextFormatter
jinja = Environment(loader=BaseLoader())
formatter = TextFormatter()
class StoreValueBlock(Block):
@@ -296,9 +294,9 @@ class AgentOutputBlock(Block):
"""
if input_data.format:
try:
fmt = re.sub(r"(?<!{){[ a-zA-Z0-9_]+}", r"{\g<0>}", input_data.format)
template = jinja.from_string(fmt)
yield "output", template.render({input_data.name: input_data.value})
yield "output", formatter.format_string(
input_data.format, {input_data.name: input_data.value}
)
except Exception as e:
yield "output", f"Error: {e}, {input_data.value}"
else:
@@ -486,3 +484,101 @@ class NoteBlock(Block):
def run(self, input_data: Input, **kwargs) -> BlockOutput:
yield "output", input_data.text
class CreateDictionaryBlock(Block):
class Input(BlockSchema):
values: dict[str, Any] = SchemaField(
description="Key-value pairs to create the dictionary with",
placeholder="e.g., {'name': 'Alice', 'age': 25}",
)
class Output(BlockSchema):
dictionary: dict[str, Any] = SchemaField(
description="The created dictionary containing the specified key-value pairs"
)
error: str = SchemaField(
description="Error message if dictionary creation failed"
)
def __init__(self):
super().__init__(
id="b924ddf4-de4f-4b56-9a85-358930dcbc91",
description="Creates a dictionary with the specified key-value pairs. Use this when you know all the values you want to add upfront.",
categories={BlockCategory.DATA},
input_schema=CreateDictionaryBlock.Input,
output_schema=CreateDictionaryBlock.Output,
test_input=[
{
"values": {"name": "Alice", "age": 25, "city": "New York"},
},
{
"values": {"numbers": [1, 2, 3], "active": True, "score": 95.5},
},
],
test_output=[
(
"dictionary",
{"name": "Alice", "age": 25, "city": "New York"},
),
(
"dictionary",
{"numbers": [1, 2, 3], "active": True, "score": 95.5},
),
],
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
# The values are already validated by Pydantic schema
yield "dictionary", input_data.values
except Exception as e:
yield "error", f"Failed to create dictionary: {str(e)}"
class CreateListBlock(Block):
class Input(BlockSchema):
values: List[Any] = SchemaField(
description="A list of values to be combined into a new list.",
placeholder="e.g., ['Alice', 25, True]",
)
class Output(BlockSchema):
list: List[Any] = SchemaField(
description="The created list containing the specified values."
)
error: str = SchemaField(description="Error message if list creation failed.")
def __init__(self):
super().__init__(
id="a912d5c7-6e00-4542-b2a9-8034136930e4",
description="Creates a list with the specified values. Use this when you know all the values you want to add upfront.",
categories={BlockCategory.DATA},
input_schema=CreateListBlock.Input,
output_schema=CreateListBlock.Output,
test_input=[
{
"values": ["Alice", 25, True],
},
{
"values": [1, 2, 3, "four", {"key": "value"}],
},
],
test_output=[
(
"list",
["Alice", 25, True],
),
(
"list",
[1, 2, 3, "four", {"key": "value"}],
),
],
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
# The values are already validated by Pydantic schema
yield "list", input_data.values
except Exception as e:
yield "error", f"Failed to create list: {str(e)}"

View File

@@ -0,0 +1,190 @@
from enum import Enum
from typing import Literal
from e2b_code_interpreter import Sandbox
from pydantic import SecretStr
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="e2b",
api_key=SecretStr("mock-e2b-api-key"),
title="Mock E2B API key",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.type,
}
class ProgrammingLanguage(Enum):
PYTHON = "python"
JAVASCRIPT = "js"
BASH = "bash"
R = "r"
JAVA = "java"
class CodeExecutionBlock(Block):
# TODO : Add support to upload and download files
# Currently, You can customized the CPU and Memory, only by creating a pre customized sandbox template
class Input(BlockSchema):
credentials: CredentialsMetaInput[
Literal[ProviderName.E2B], Literal["api_key"]
] = CredentialsField(
description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
)
# Todo : Option to run commond in background
setup_commands: list[str] = SchemaField(
description=(
"Shell commands to set up the sandbox before running the code. "
"You can use `curl` or `git` to install your desired Debian based "
"package manager. `pip` and `npm` are pre-installed.\n\n"
"These commands are executed with `sh`, in the foreground."
),
placeholder="pip install cowsay",
default=[],
advanced=False,
)
code: str = SchemaField(
description="Code to execute in the sandbox",
placeholder="print('Hello, World!')",
default="",
advanced=False,
)
language: ProgrammingLanguage = SchemaField(
description="Programming language to execute",
default=ProgrammingLanguage.PYTHON,
advanced=False,
)
timeout: int = SchemaField(
description="Execution timeout in seconds", default=300
)
template_id: str = SchemaField(
description=(
"You can use an E2B sandbox template by entering its ID here. "
"Check out the E2B docs for more details: "
"[E2B - Sandbox template](https://e2b.dev/docs/sandbox-template)"
),
default="",
advanced=True,
)
class Output(BlockSchema):
response: str = SchemaField(description="Response from code execution")
stdout_logs: str = SchemaField(
description="Standard output logs from execution"
)
stderr_logs: str = SchemaField(description="Standard error logs from execution")
error: str = SchemaField(description="Error message if execution failed")
def __init__(self):
super().__init__(
id="0b02b072-abe7-11ef-8372-fb5d162dd712",
description="Executes code in an isolated sandbox environment with internet access.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=CodeExecutionBlock.Input,
output_schema=CodeExecutionBlock.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"code": "print('Hello World')",
"language": ProgrammingLanguage.PYTHON.value,
"setup_commands": [],
"timeout": 300,
"template_id": "",
},
test_output=[
("response", "Hello World"),
("stdout_logs", "Hello World\n"),
],
test_mock={
"execute_code": lambda code, language, setup_commands, timeout, api_key, template_id: (
"Hello World",
"Hello World\n",
"",
),
},
)
def execute_code(
self,
code: str,
language: ProgrammingLanguage,
setup_commands: list[str],
timeout: int,
api_key: str,
template_id: str,
):
try:
sandbox = None
if template_id:
sandbox = Sandbox(
template=template_id, api_key=api_key, timeout=timeout
)
else:
sandbox = Sandbox(api_key=api_key, timeout=timeout)
if not sandbox:
raise Exception("Sandbox not created")
# Running setup commands
for cmd in setup_commands:
sandbox.commands.run(cmd)
# Executing the code
execution = sandbox.run_code(
code,
language=language.value,
on_error=lambda e: sandbox.kill(), # Kill the sandbox if there is an error
)
if execution.error:
raise Exception(execution.error)
response = execution.text
stdout_logs = "".join(execution.logs.stdout)
stderr_logs = "".join(execution.logs.stderr)
return response, stdout_logs, stderr_logs
except Exception as e:
raise e
def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
response, stdout_logs, stderr_logs = self.execute_code(
input_data.code,
input_data.language,
input_data.setup_commands,
input_data.timeout,
credentials.api_key.get_secret_value(),
input_data.template_id,
)
if response:
yield "response", response
if stdout_logs:
yield "stdout_logs", stdout_logs
if stderr_logs:
yield "stderr_logs", stderr_logs
except Exception as e:
yield "error", str(e)

View File

@@ -0,0 +1,110 @@
import re
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class CodeExtractionBlock(Block):
class Input(BlockSchema):
text: str = SchemaField(
description="Text containing code blocks to extract (e.g., AI response)",
placeholder="Enter text containing code blocks",
)
class Output(BlockSchema):
html: str = SchemaField(description="Extracted HTML code")
css: str = SchemaField(description="Extracted CSS code")
javascript: str = SchemaField(description="Extracted JavaScript code")
python: str = SchemaField(description="Extracted Python code")
sql: str = SchemaField(description="Extracted SQL code")
java: str = SchemaField(description="Extracted Java code")
cpp: str = SchemaField(description="Extracted C++ code")
csharp: str = SchemaField(description="Extracted C# code")
json_code: str = SchemaField(description="Extracted JSON code")
bash: str = SchemaField(description="Extracted Bash code")
php: str = SchemaField(description="Extracted PHP code")
ruby: str = SchemaField(description="Extracted Ruby code")
yaml: str = SchemaField(description="Extracted YAML code")
markdown: str = SchemaField(description="Extracted Markdown code")
typescript: str = SchemaField(description="Extracted TypeScript code")
xml: str = SchemaField(description="Extracted XML code")
remaining_text: str = SchemaField(
description="Remaining text after code extraction"
)
def __init__(self):
super().__init__(
id="d3a7d896-3b78-4f44-8b4b-48fbf4f0bcd8",
description="Extracts code blocks from text and identifies their programming languages",
categories={BlockCategory.TEXT},
input_schema=CodeExtractionBlock.Input,
output_schema=CodeExtractionBlock.Output,
test_input={
"text": "Here's a Python example:\n```python\nprint('Hello World')\n```\nAnd some HTML:\n```html\n<h1>Title</h1>\n```"
},
test_output=[
("html", "<h1>Title</h1>"),
("python", "print('Hello World')"),
("remaining_text", "Here's a Python example:\nAnd some HTML:"),
],
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
# List of supported programming languages with mapped aliases
language_aliases = {
"html": ["html", "htm"],
"css": ["css"],
"javascript": ["javascript", "js"],
"python": ["python", "py"],
"sql": ["sql"],
"java": ["java"],
"cpp": ["cpp", "c++"],
"csharp": ["csharp", "c#", "cs"],
"json_code": ["json"],
"bash": ["bash", "shell", "sh"],
"php": ["php"],
"ruby": ["ruby", "rb"],
"yaml": ["yaml", "yml"],
"markdown": ["markdown", "md"],
"typescript": ["typescript", "ts"],
"xml": ["xml"],
}
# Extract code for each language
for canonical_name, aliases in language_aliases.items():
code = ""
# Try each alias for the language
for alias in aliases:
code_for_alias = self.extract_code(input_data.text, alias)
if code_for_alias:
code = code + "\n\n" + code_for_alias if code else code_for_alias
if code: # Only yield if there's actual code content
yield canonical_name, code
# Remove all code blocks from the text to get remaining text
pattern = (
r"```(?:"
+ "|".join(
re.escape(alias)
for aliases in language_aliases.values()
for alias in aliases
)
+ r")\s+[\s\S]*?```"
)
remaining_text = re.sub(pattern, "", input_data.text).strip()
remaining_text = re.sub(r"\n\s*\n", "\n", remaining_text)
if remaining_text: # Only yield if there's remaining text
yield "remaining_text", remaining_text
def extract_code(self, text: str, language: str) -> str:
# Escape special regex characters in the language string
language = re.escape(language)
# Extract all code blocks enclosed in ```language``` blocks
pattern = re.compile(rf"```{language}\s+(.*?)```", re.DOTALL | re.IGNORECASE)
matches = pattern.finditer(text)
# Combine all code blocks for this language with newlines between them
code_blocks = [match.group(1).strip() for match in matches]
return "\n\n".join(code_blocks) if code_blocks else ""

View File

@@ -0,0 +1,59 @@
from pydantic import BaseModel
from backend.data.block import (
Block,
BlockCategory,
BlockManualWebhookConfig,
BlockOutput,
BlockSchema,
)
from backend.data.model import SchemaField
from backend.integrations.webhooks.compass import CompassWebhookType
class Transcription(BaseModel):
text: str
speaker: str
end: float
start: float
duration: float
class TranscriptionDataModel(BaseModel):
date: str
transcription: str
transcriptions: list[Transcription]
class CompassAITriggerBlock(Block):
class Input(BlockSchema):
payload: TranscriptionDataModel = SchemaField(hidden=True)
class Output(BlockSchema):
transcription: str = SchemaField(
description="The contents of the compass transcription."
)
def __init__(self):
super().__init__(
id="9464a020-ed1d-49e1-990f-7f2ac924a2b7",
description="This block will output the contents of the compass transcription.",
categories={BlockCategory.HARDWARE},
input_schema=CompassAITriggerBlock.Input,
output_schema=CompassAITriggerBlock.Output,
webhook_config=BlockManualWebhookConfig(
provider="compass",
webhook_type=CompassWebhookType.TRANSCRIPTION,
),
test_input=[
{"input": "Hello, World!"},
{"input": "Hello, World!", "data": "Existing Data"},
],
# test_output=[
# ("output", "Hello, World!"), # No data provided, so trigger is returned
# ("output", "Existing Data"), # Data is provided, so data is returned.
# ],
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
yield "transcription", input_data.payload.transcription

View File

@@ -12,16 +12,15 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
DiscordCredentials = CredentialsMetaInput[Literal["discord"], Literal["api_key"]]
DiscordCredentials = CredentialsMetaInput[
Literal[ProviderName.DISCORD], Literal["api_key"]
]
def DiscordCredentialsField() -> DiscordCredentials:
return CredentialsField(
description="Discord bot token",
provider="discord",
supported_credential_types={"api_key"},
)
return CredentialsField(description="Discord bot token")
TEST_CREDENTIALS = APIKeyCredentials(

View File

@@ -0,0 +1,32 @@
from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
ExaCredentials = APIKeyCredentials
ExaCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.EXA],
Literal["api_key"],
]
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="exa",
api_key=SecretStr("mock-exa-api-key"),
title="Mock Exa API key",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
def ExaCredentialsField() -> ExaCredentialsInput:
"""Creates an Exa credentials input on a block."""
return CredentialsField(description="The Exa integration requires an API Key.")

View File

@@ -0,0 +1,157 @@
from datetime import datetime
from typing import List
from pydantic import BaseModel
from backend.blocks.exa._auth import (
ExaCredentials,
ExaCredentialsField,
ExaCredentialsInput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import requests
class ContentSettings(BaseModel):
text: dict = SchemaField(
description="Text content settings",
default={"maxCharacters": 1000, "includeHtmlTags": False},
)
highlights: dict = SchemaField(
description="Highlight settings",
default={"numSentences": 3, "highlightsPerUrl": 3},
)
summary: dict = SchemaField(
description="Summary settings",
default={"query": ""},
)
class ExaSearchBlock(Block):
class Input(BlockSchema):
credentials: ExaCredentialsInput = ExaCredentialsField()
query: str = SchemaField(description="The search query")
useAutoprompt: bool = SchemaField(
description="Whether to use autoprompt",
default=True,
)
type: str = SchemaField(
description="Type of search",
default="",
)
category: str = SchemaField(
description="Category to search within",
default="",
)
numResults: int = SchemaField(
description="Number of results to return",
default=10,
)
includeDomains: List[str] = SchemaField(
description="Domains to include in search",
default=[],
)
excludeDomains: List[str] = SchemaField(
description="Domains to exclude from search",
default=[],
)
startCrawlDate: datetime = SchemaField(
description="Start date for crawled content",
)
endCrawlDate: datetime = SchemaField(
description="End date for crawled content",
)
startPublishedDate: datetime = SchemaField(
description="Start date for published content",
)
endPublishedDate: datetime = SchemaField(
description="End date for published content",
)
includeText: List[str] = SchemaField(
description="Text patterns to include",
default=[],
)
excludeText: List[str] = SchemaField(
description="Text patterns to exclude",
default=[],
)
contents: ContentSettings = SchemaField(
description="Content retrieval settings",
default=ContentSettings(),
)
class Output(BlockSchema):
results: list = SchemaField(
description="List of search results",
default=[],
)
def __init__(self):
super().__init__(
id="996cec64-ac40-4dde-982f-b0dc60a5824d",
description="Searches the web using Exa's advanced search API",
categories={BlockCategory.SEARCH},
input_schema=ExaSearchBlock.Input,
output_schema=ExaSearchBlock.Output,
)
def run(
self, input_data: Input, *, credentials: ExaCredentials, **kwargs
) -> BlockOutput:
url = "https://api.exa.ai/search"
headers = {
"Content-Type": "application/json",
"x-api-key": credentials.api_key.get_secret_value(),
}
payload = {
"query": input_data.query,
"useAutoprompt": input_data.useAutoprompt,
"numResults": input_data.numResults,
"contents": {
"text": {"maxCharacters": 1000, "includeHtmlTags": False},
"highlights": {
"numSentences": 3,
"highlightsPerUrl": 3,
},
"summary": {"query": ""},
},
}
# Add dates if they exist
date_fields = [
"startCrawlDate",
"endCrawlDate",
"startPublishedDate",
"endPublishedDate",
]
for field in date_fields:
value = getattr(input_data, field, None)
if value:
payload[field] = value.strftime("%Y-%m-%dT%H:%M:%S.000Z")
# Add other fields
optional_fields = [
"type",
"category",
"includeDomains",
"excludeDomains",
"includeText",
"excludeText",
]
for field in optional_fields:
value = getattr(input_data, field)
if value: # Only add non-empty values
payload[field] = value
try:
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
# Extract just the results array from the response
yield "results", data.get("results", [])
except Exception as e:
yield "error", str(e)
yield "results", []

View File

@@ -3,10 +3,11 @@ from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
FalCredentials = APIKeyCredentials
FalCredentialsInput = CredentialsMetaInput[
Literal["fal"],
Literal[ProviderName.FAL],
Literal["api_key"],
]
@@ -30,7 +31,5 @@ def FalCredentialsField() -> FalCredentialsInput:
Creates a FAL credentials input on a block.
"""
return CredentialsField(
provider="fal",
supported_credential_types={"api_key"},
description="The FAL integration can be used with an API Key.",
)

View File

@@ -8,6 +8,7 @@ from backend.data.model import (
CredentialsMetaInput,
OAuth2Credentials,
)
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets
secrets = Secrets()
@@ -17,7 +18,7 @@ GITHUB_OAUTH_IS_CONFIGURED = bool(
GithubCredentials = APIKeyCredentials | OAuth2Credentials
GithubCredentialsInput = CredentialsMetaInput[
Literal["github"],
Literal[ProviderName.GITHUB],
Literal["api_key", "oauth2"] if GITHUB_OAUTH_IS_CONFIGURED else Literal["api_key"],
]
@@ -30,10 +31,6 @@ def GithubCredentialsField(scope: str) -> GithubCredentialsInput:
scope: The authorization scope needed for the block to work. ([list of available scopes](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/scopes-for-oauth-apps#available-scopes))
""" # noqa
return CredentialsField(
provider="github",
supported_credential_types=(
{"api_key", "oauth2"} if GITHUB_OAUTH_IS_CONFIGURED else {"api_key"}
),
required_scopes={scope},
description="The GitHub integration can be used with OAuth, "
"or any API key with sufficient permissions for the blocks it is used on.",

View File

@@ -1,3 +1,5 @@
import re
from typing_extensions import TypedDict
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
@@ -253,7 +255,7 @@ class GithubReadPullRequestBlock(Block):
@staticmethod
def read_pr_changes(credentials: GithubCredentials, pr_url: str) -> str:
api = get_api(credentials)
files_url = pr_url + "/files"
files_url = prepare_pr_api_url(pr_url=pr_url, path="files")
response = api.get(files_url)
files = response.json()
changes = []
@@ -331,7 +333,7 @@ class GithubAssignPRReviewerBlock(Block):
credentials: GithubCredentials, pr_url: str, reviewer: str
) -> str:
api = get_api(credentials)
reviewers_url = pr_url + "/requested_reviewers"
reviewers_url = prepare_pr_api_url(pr_url=pr_url, path="requested_reviewers")
data = {"reviewers": [reviewer]}
api.post(reviewers_url, json=data)
return "Reviewer assigned successfully"
@@ -398,7 +400,7 @@ class GithubUnassignPRReviewerBlock(Block):
credentials: GithubCredentials, pr_url: str, reviewer: str
) -> str:
api = get_api(credentials)
reviewers_url = pr_url + "/requested_reviewers"
reviewers_url = prepare_pr_api_url(pr_url=pr_url, path="requested_reviewers")
data = {"reviewers": [reviewer]}
api.delete(reviewers_url, json=data)
return "Reviewer unassigned successfully"
@@ -478,7 +480,7 @@ class GithubListPRReviewersBlock(Block):
credentials: GithubCredentials, pr_url: str
) -> list[Output.ReviewerItem]:
api = get_api(credentials)
reviewers_url = pr_url + "/requested_reviewers"
reviewers_url = prepare_pr_api_url(pr_url=pr_url, path="requested_reviewers")
response = api.get(reviewers_url)
data = response.json()
reviewers: list[GithubListPRReviewersBlock.Output.ReviewerItem] = [
@@ -499,3 +501,14 @@ class GithubListPRReviewersBlock(Block):
input_data.pr_url,
)
yield from (("reviewer", reviewer) for reviewer in reviewers)
def prepare_pr_api_url(pr_url: str, path: str) -> str:
# Pattern to capture the base repository URL and the pull request number
pattern = r"^(?:https?://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
match = re.match(pattern, pr_url)
if not match:
return pr_url
base_url, pr_number = match.groups()
return f"{base_url}/pulls/{pr_number}/{path}"

View File

@@ -3,6 +3,7 @@ from typing import Literal
from pydantic import SecretStr
from backend.data.model import CredentialsField, CredentialsMetaInput, OAuth2Credentials
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets
# --8<-- [start:GoogleOAuthIsConfigured]
@@ -12,7 +13,9 @@ GOOGLE_OAUTH_IS_CONFIGURED = bool(
)
# --8<-- [end:GoogleOAuthIsConfigured]
GoogleCredentials = OAuth2Credentials
GoogleCredentialsInput = CredentialsMetaInput[Literal["google"], Literal["oauth2"]]
GoogleCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.GOOGLE], Literal["oauth2"]
]
def GoogleCredentialsField(scopes: list[str]) -> GoogleCredentialsInput:
@@ -23,8 +26,6 @@ def GoogleCredentialsField(scopes: list[str]) -> GoogleCredentialsInput:
scopes: The authorization scopes needed for the block to work.
"""
return CredentialsField(
provider="google",
supported_credential_types={"oauth2"},
required_scopes=set(scopes),
description="The Google integration requires OAuth2 authentication.",
)

View File

@@ -10,6 +10,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
@@ -38,12 +39,8 @@ class Place(BaseModel):
class GoogleMapsSearchBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[
Literal["google_maps"], Literal["api_key"]
] = CredentialsField(
provider="google_maps",
supported_credential_types={"api_key"},
description="Google Maps API Key",
)
Literal[ProviderName.GOOGLE_MAPS], Literal["api_key"]
] = CredentialsField(description="Google Maps API Key")
query: str = SchemaField(
description="Search query for local businesses",
placeholder="e.g., 'restaurants in New York'",

View File

@@ -3,10 +3,11 @@ from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
HubSpotCredentials = APIKeyCredentials
HubSpotCredentialsInput = CredentialsMetaInput[
Literal["hubspot"],
Literal[ProviderName.HUBSPOT],
Literal["api_key"],
]
@@ -14,8 +15,6 @@ HubSpotCredentialsInput = CredentialsMetaInput[
def HubSpotCredentialsField() -> HubSpotCredentialsInput:
"""Creates a HubSpot credentials input on a block."""
return CredentialsField(
provider="hubspot",
supported_credential_types={"api_key"},
description="The HubSpot integration requires an API Key.",
)

View File

@@ -11,6 +11,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests
TEST_CREDENTIALS = APIKeyCredentials(
@@ -83,13 +84,10 @@ class UpscaleOption(str, Enum):
class IdeogramModelBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[Literal["ideogram"], Literal["api_key"]] = (
CredentialsField(
provider="ideogram",
supported_credential_types={"api_key"},
description="The Ideogram integration can be used with any API key with sufficient permissions for the blocks it is used on.",
)
credentials: CredentialsMetaInput[
Literal[ProviderName.IDEOGRAM], Literal["api_key"]
] = CredentialsField(
description="The Ideogram integration can be used with any API key with sufficient permissions for the blocks it is used on.",
)
prompt: str = SchemaField(
description="Text prompt for image generation",

View File

@@ -3,27 +3,14 @@ from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
JinaCredentials = APIKeyCredentials
JinaCredentialsInput = CredentialsMetaInput[
Literal["jina"],
Literal[ProviderName.JINA],
Literal["api_key"],
]
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="jina",
api_key=SecretStr("mock-jina-api-key"),
title="Mock Jina API key",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.type,
}
def JinaCredentialsField() -> JinaCredentialsInput:
"""
@@ -31,8 +18,6 @@ def JinaCredentialsField() -> JinaCredentialsInput:
"""
return CredentialsField(
provider="jina",
supported_credential_types={"api_key"},
description="The Jina integration can be used with an API Key.",
)

View File

@@ -7,6 +7,8 @@ from typing import TYPE_CHECKING, Any, List, Literal, NamedTuple
from pydantic import SecretStr
from backend.integrations.providers import ProviderName
if TYPE_CHECKING:
from enum import _EnumMemberT
@@ -27,7 +29,13 @@ from backend.util.settings import BehaveAs, Settings
logger = logging.getLogger(__name__)
LLMProviderName = Literal["anthropic", "groq", "openai", "ollama", "open_router"]
LLMProviderName = Literal[
ProviderName.ANTHROPIC,
ProviderName.GROQ,
ProviderName.OLLAMA,
ProviderName.OPENAI,
ProviderName.OPEN_ROUTER,
]
AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]]
TEST_CREDENTIALS = APIKeyCredentials(
@@ -48,8 +56,6 @@ TEST_CREDENTIALS_INPUT = {
def AICredentialsField() -> AICredentials:
return CredentialsField(
description="API key for the LLM provider.",
provider=["anthropic", "groq", "openai", "ollama", "open_router"],
supported_credential_types={"api_key"},
discriminator="model",
discriminator_mapping={
model.value: model.metadata.provider for model in LlmModel
@@ -105,9 +111,9 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
# Ollama models
OLLAMA_LLAMA3_8B = "llama3"
OLLAMA_LLAMA3_405B = "llama3.1:405b"
OLLAMA_DOLPHIN = "dolphin-mistral:latest"
# OpenRouter models
GEMINI_FLASH_1_5_8B = "google/gemini-flash-1.5"
GEMINI_FLASH_1_5_EXP = "google/gemini-flash-1.5-exp"
GROK_BETA = "x-ai/grok-beta"
MISTRAL_NEMO = "mistralai/mistral-nemo"
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
@@ -117,6 +123,14 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
PERPLEXITY_LLAMA_3_1_SONAR_LARGE_128K_ONLINE = (
"perplexity/llama-3.1-sonar-large-128k-online"
)
QWEN_QWQ_32B_PREVIEW = "qwen/qwq-32b-preview"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
AMAZON_NOVA_LITE_V1 = "amazon/nova-lite-v1"
AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
@property
def metadata(self) -> ModelMetadata:
@@ -151,8 +165,8 @@ MODEL_METADATA = {
LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 131072),
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192),
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192),
LlmModel.OLLAMA_DOLPHIN: ModelMetadata("ollama", 32768),
LlmModel.GEMINI_FLASH_1_5_8B: ModelMetadata("open_router", 8192),
LlmModel.GEMINI_FLASH_1_5_EXP: ModelMetadata("open_router", 8192),
LlmModel.GROK_BETA: ModelMetadata("open_router", 8192),
LlmModel.MISTRAL_NEMO: ModelMetadata("open_router", 4000),
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata("open_router", 4000),
@@ -162,6 +176,14 @@ MODEL_METADATA = {
LlmModel.PERPLEXITY_LLAMA_3_1_SONAR_LARGE_128K_ONLINE: ModelMetadata(
"open_router", 8192
),
LlmModel.QWEN_QWQ_32B_PREVIEW: ModelMetadata("open_router", 4000),
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: ModelMetadata("open_router", 4000),
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: ModelMetadata("open_router", 4000),
LlmModel.AMAZON_NOVA_LITE_V1: ModelMetadata("open_router", 4000),
LlmModel.AMAZON_NOVA_MICRO_V1: ModelMetadata("open_router", 4000),
LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata("open_router", 4000),
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata("open_router", 4000),
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata("open_router", 4000),
}
for model in LlmModel:
@@ -220,6 +242,12 @@ class AIStructuredResponseGeneratorBlock(Block):
description="The maximum number of tokens to generate in the chat completion.",
)
ollama_host: str = SchemaField(
advanced=True,
default="localhost:11434",
description="Ollama host for local models",
)
class Output(BlockSchema):
response: dict[str, Any] = SchemaField(
description="The response object generated by the language model."
@@ -265,6 +293,7 @@ class AIStructuredResponseGeneratorBlock(Block):
prompt: list[dict],
json_format: bool,
max_tokens: int | None = None,
ollama_host: str = "localhost:11434",
) -> tuple[str, int, int]:
"""
Args:
@@ -273,6 +302,7 @@ class AIStructuredResponseGeneratorBlock(Block):
prompt: The prompt to send to the LLM.
json_format: Whether the response should be in JSON format.
max_tokens: The maximum number of tokens to generate in the chat completion.
ollama_host: The host for ollama to use
Returns:
The response from the LLM.
@@ -362,9 +392,10 @@ class AIStructuredResponseGeneratorBlock(Block):
response.usage.completion_tokens if response.usage else 0,
)
elif provider == "ollama":
client = ollama.Client(host=ollama_host)
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
response = ollama.generate(
response = client.generate(
model=llm_model.value,
prompt=f"{sys_messages}\n\n{usr_messages}",
stream=False,
@@ -464,6 +495,7 @@ class AIStructuredResponseGeneratorBlock(Block):
llm_model=llm_model,
prompt=prompt,
json_format=bool(input_data.expected_format),
ollama_host=input_data.ollama_host,
max_tokens=input_data.max_tokens,
)
self.merge_stats(
@@ -546,6 +578,11 @@ class AITextGeneratorBlock(Block):
prompt_values: dict[str, str] = SchemaField(
advanced=False, default={}, description="Values used to fill in the prompt."
)
ollama_host: str = SchemaField(
advanced=True,
default="localhost:11434",
description="Ollama host for local models",
)
max_tokens: int | None = SchemaField(
advanced=True,
default=None,
@@ -636,6 +673,11 @@ class AITextSummarizerBlock(Block):
description="The number of overlapping tokens between chunks to maintain context.",
ge=0,
)
ollama_host: str = SchemaField(
advanced=True,
default="localhost:11434",
description="Ollama host for local models",
)
class Output(BlockSchema):
summary: str = SchemaField(description="The final summary of the text.")
@@ -774,6 +816,11 @@ class AIConversationBlock(Block):
default=None,
description="The maximum number of tokens to generate in the chat completion.",
)
ollama_host: str = SchemaField(
advanced=True,
default="localhost:11434",
description="Ollama host for local models",
)
class Output(BlockSchema):
response: str = SchemaField(
@@ -871,6 +918,11 @@ class AIListGeneratorBlock(Block):
default=None,
description="The maximum number of tokens to generate in the chat completion.",
)
ollama_host: str = SchemaField(
advanced=True,
default="localhost:11434",
description="Ollama host for local models",
)
class Output(BlockSchema):
generated_list: List[str] = SchemaField(description="The generated list.")
@@ -1022,6 +1074,7 @@ class AIListGeneratorBlock(Block):
credentials=input_data.credentials,
model=input_data.model,
expected_format={}, # Do not use structured response
ollama_host=input_data.ollama_host,
),
credentials=credentials,
)

View File

@@ -12,6 +12,7 @@ from backend.data.model import (
SchemaField,
SecretField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests
TEST_CREDENTIALS = APIKeyCredentials(
@@ -77,12 +78,10 @@ class PublishToMediumBlock(Block):
description="Whether to notify followers that the user has published",
placeholder="False",
)
credentials: CredentialsMetaInput[Literal["medium"], Literal["api_key"]] = (
CredentialsField(
provider="medium",
supported_credential_types={"api_key"},
description="The Medium integration can be used with any API key with sufficient permissions for the blocks it is used on.",
)
credentials: CredentialsMetaInput[
Literal[ProviderName.MEDIUM], Literal["api_key"]
] = CredentialsField(
description="The Medium integration can be used with any API key with sufficient permissions for the blocks it is used on.",
)
class Output(BlockSchema):

View File

@@ -10,22 +10,18 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
PineconeCredentials = APIKeyCredentials
PineconeCredentialsInput = CredentialsMetaInput[
Literal["pinecone"],
Literal[ProviderName.PINECONE],
Literal["api_key"],
]
def PineconeCredentialsField() -> PineconeCredentialsInput:
"""
Creates a Pinecone credentials input on a block.
"""
"""Creates a Pinecone credentials input on a block."""
return CredentialsField(
provider="pinecone",
supported_credential_types={"api_key"},
description="The Pinecone integration can be used with an API Key.",
)
@@ -147,7 +143,7 @@ class PineconeQueryBlock(Block):
top_k=input_data.top_k,
include_values=input_data.include_values,
include_metadata=input_data.include_metadata,
).to_dict()
).to_dict() # type: ignore
combined_text = ""
if results["matches"]:
texts = [

View File

@@ -13,6 +13,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
@@ -54,13 +55,11 @@ class ImageType(str, Enum):
class ReplicateFluxAdvancedModelBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[Literal["replicate"], Literal["api_key"]] = (
CredentialsField(
provider="replicate",
supported_credential_types={"api_key"},
description="The Replicate integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
credentials: CredentialsMetaInput[
Literal[ProviderName.REPLICATE], Literal["api_key"]
] = CredentialsField(
description="The Replicate integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
prompt: str = SchemaField(
description="Text prompt for image generation",

View File

@@ -11,6 +11,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
class GetWikipediaSummaryBlock(Block, GetRequest):
@@ -65,10 +66,8 @@ class GetWeatherInformationBlock(Block, GetRequest):
description="Location to get weather information for"
)
credentials: CredentialsMetaInput[
Literal["openweathermap"], Literal["api_key"]
Literal[ProviderName.OPENWEATHERMAP], Literal["api_key"]
] = CredentialsField(
provider="openweathermap",
supported_credential_types={"api_key"},
description="The OpenWeatherMap integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)

View File

@@ -4,16 +4,15 @@ from typing import Literal
from pydantic import BaseModel, SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
Slant3DCredentialsInput = CredentialsMetaInput[Literal["slant3d"], Literal["api_key"]]
Slant3DCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.SLANT3D], Literal["api_key"]
]
def Slant3DCredentialsField() -> Slant3DCredentialsInput:
return CredentialsField(
provider="slant3d",
supported_credential_types={"api_key"},
description="Slant3D API key for authentication",
)
return CredentialsField(description="Slant3D API key for authentication")
TEST_CREDENTIALS = APIKeyCredentials(

View File

@@ -10,6 +10,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests
TEST_CREDENTIALS = APIKeyCredentials(
@@ -29,13 +30,11 @@ TEST_CREDENTIALS_INPUT = {
class CreateTalkingAvatarVideoBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[Literal["d_id"], Literal["api_key"]] = (
CredentialsField(
provider="d_id",
supported_credential_types={"api_key"},
description="The D-ID integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
credentials: CredentialsMetaInput[
Literal[ProviderName.D_ID], Literal["api_key"]
] = CredentialsField(
description="The D-ID integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
script_input: str = SchemaField(
description="The text input for the script",

View File

@@ -1,13 +1,11 @@
import re
from typing import Any
from jinja2 import BaseLoader, Environment
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util import json
from backend.util import json, text
jinja = Environment(loader=BaseLoader())
formatter = text.TextFormatter()
class MatchTextPatternBlock(Block):
@@ -73,6 +71,7 @@ class ExtractTextInformationBlock(Block):
description="Case sensitive match", default=True
)
dot_all: bool = SchemaField(description="Dot matches all", default=True)
find_all: bool = SchemaField(description="Find all matches", default=False)
class Output(BlockSchema):
positive: str = SchemaField(description="Extracted text")
@@ -90,12 +89,27 @@ class ExtractTextInformationBlock(Block):
{"text": "Hello, World!", "pattern": "Hello, (.+)", "group": 0},
{"text": "Hello, World!", "pattern": "Hello, (.+)", "group": 2},
{"text": "Hello, World!", "pattern": "hello,", "case_sensitive": False},
{
"text": "Hello, World!! Hello, Earth!!",
"pattern": "Hello, (\\S+)",
"group": 1,
"find_all": False,
},
{
"text": "Hello, World!! Hello, Earth!!",
"pattern": "Hello, (\\S+)",
"group": 1,
"find_all": True,
},
],
test_output=[
("positive", "World!"),
("positive", "Hello, World!"),
("negative", "Hello, World!"),
("positive", "Hello,"),
("positive", "World!!"),
("positive", "World!!"),
("positive", "Earth!!"),
],
)
@@ -107,15 +121,21 @@ class ExtractTextInformationBlock(Block):
flags = flags | re.DOTALL
if isinstance(input_data.text, str):
text = input_data.text
txt = input_data.text
else:
text = json.dumps(input_data.text)
txt = json.dumps(input_data.text)
match = re.search(input_data.pattern, text, flags)
if match and input_data.group <= len(match.groups()):
yield "positive", match.group(input_data.group)
else:
yield "negative", text
matches = [
match.group(input_data.group)
for match in re.finditer(input_data.pattern, txt, flags)
if input_data.group <= len(match.groups())
]
for match in matches:
yield "positive", match
if not input_data.find_all:
return
if not matches:
yield "negative", input_data.text
class FillTextTemplateBlock(Block):
@@ -146,19 +166,20 @@ class FillTextTemplateBlock(Block):
"values": {"list": ["Hello", " World!"]},
"format": "{% for item in list %}{{ item }}{% endfor %}",
},
{
"values": {},
"format": "{% set name = 'Alice' %}Hello, World! {{ name }}",
},
],
test_output=[
("output", "Hello, World! Alice"),
("output", "Hello World!"),
("output", "Hello, World! Alice"),
],
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
# For python.format compatibility: replace all {...} with {{..}}.
# But avoid replacing {{...}} to {{{...}}}.
fmt = re.sub(r"(?<!{){[ a-zA-Z0-9_]+}", r"{\g<0>}", input_data.format)
template = jinja.from_string(fmt)
yield "output", template.render(**input_data.values)
yield "output", formatter.format_string(input_data.format, input_data.values)
class CombineTextsBlock(Block):

View File

@@ -9,6 +9,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests
TEST_CREDENTIALS = APIKeyCredentials(
@@ -38,10 +39,8 @@ class UnrealTextToSpeechBlock(Block):
default="Scarlett",
)
credentials: CredentialsMetaInput[
Literal["unreal_speech"], Literal["api_key"]
Literal[ProviderName.UNREAL_SPEECH], Literal["api_key"]
] = CredentialsField(
provider="unreal_speech",
supported_credential_types={"api_key"},
description="The Unreal Speech integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)

View File

@@ -42,6 +42,7 @@ class BlockType(Enum):
OUTPUT = "Output"
NOTE = "Note"
WEBHOOK = "Webhook"
WEBHOOK_MANUAL = "Webhook (manual)"
AGENT = "Agent"
@@ -57,6 +58,7 @@ class BlockCategory(Enum):
COMMUNICATION = "Block that interacts with communication platforms."
DEVELOPER_TOOLS = "Developer tools such as GitHub blocks."
DATA = "Block that interacts with structured data."
HARDWARE = "Block that interacts with hardware."
AGENT = "Block that interacts with other agents."
CRM = "Block that interacts with CRM services."
@@ -65,7 +67,7 @@ class BlockCategory(Enum):
class BlockSchema(BaseModel):
cached_jsonschema: ClassVar[dict[str, Any]] = {}
cached_jsonschema: ClassVar[dict[str, Any]]
@classmethod
def jsonschema(cls) -> dict[str, Any]:
@@ -90,6 +92,7 @@ class BlockSchema(BaseModel):
}
elif isinstance(obj, list):
return [ref_to_dict(item) for item in obj]
return obj
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
@@ -145,6 +148,10 @@ class BlockSchema(BaseModel):
- A field that is called `credentials` MUST be a `CredentialsMetaInput`.
"""
super().__pydantic_init_subclass__(**kwargs)
# Reset cached JSON schema to prevent inheriting it from parent class
cls.cached_jsonschema = {}
credentials_fields = [
field_name
for field_name, info in cls.model_fields.items()
@@ -176,6 +183,11 @@ class BlockSchema(BaseModel):
f"Field 'credentials' on {cls.__qualname__} "
f"must be of type {CredentialsMetaInput.__name__}"
)
if credentials_field := cls.model_fields.get(CREDENTIALS_FIELD_NAME):
credentials_input_type = cast(
CredentialsMetaInput, credentials_field.annotation
)
credentials_input_type.validate_credentials_field_schema(cls)
BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchema)
@@ -187,7 +199,12 @@ class EmptySchema(BlockSchema):
# --8<-- [start:BlockWebhookConfig]
class BlockWebhookConfig(BaseModel):
class BlockManualWebhookConfig(BaseModel):
"""
Configuration model for webhook-triggered blocks on which
the user has to manually set up the webhook at the provider.
"""
provider: str
"""The service provider that the webhook connects to"""
@@ -198,6 +215,27 @@ class BlockWebhookConfig(BaseModel):
Only for use in the corresponding `WebhooksManager`.
"""
event_filter_input: str = ""
"""
Name of the block's event filter input.
Leave empty if the corresponding webhook doesn't have distinct event/payload types.
"""
event_format: str = "{event}"
"""
Template string for the event(s) that a block instance subscribes to.
Applied individually to each event selected in the event filter input.
Example: `"pull_request.{event}"` -> `"pull_request.opened"`
"""
class BlockWebhookConfig(BlockManualWebhookConfig):
"""
Configuration model for webhook-triggered blocks for which
the webhook can be automatically set up through the provider's API.
"""
resource_format: str
"""
Template string for the resource that a block instance subscribes to.
@@ -207,17 +245,6 @@ class BlockWebhookConfig(BaseModel):
Only for use in the corresponding `WebhooksManager`.
"""
event_filter_input: str
"""Name of the block's event filter input."""
event_format: str = "{event}"
"""
Template string for the event(s) that a block instance subscribes to.
Applied individually to each event selected in the event filter input.
Example: `"pull_request.{event}"` -> `"pull_request.opened"`
"""
# --8<-- [end:BlockWebhookConfig]
@@ -237,7 +264,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
disabled: bool = False,
static_output: bool = False,
block_type: BlockType = BlockType.STANDARD,
webhook_config: Optional[BlockWebhookConfig] = None,
webhook_config: Optional[BlockWebhookConfig | BlockManualWebhookConfig] = None,
):
"""
Initialize the block with the given schema.
@@ -268,27 +295,38 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
self.contributors = contributors or set()
self.disabled = disabled
self.static_output = static_output
self.block_type = block_type if not webhook_config else BlockType.WEBHOOK
self.block_type = block_type
self.webhook_config = webhook_config
self.execution_stats = {}
if self.webhook_config:
# Enforce shape of webhook event filter
event_filter_field = self.input_schema.model_fields[
self.webhook_config.event_filter_input
]
if not (
isinstance(event_filter_field.annotation, type)
and issubclass(event_filter_field.annotation, BaseModel)
and all(
field.annotation is bool
for field in event_filter_field.annotation.model_fields.values()
)
):
raise NotImplementedError(
f"{self.name} has an invalid webhook event selector: "
"field must be a BaseModel and all its fields must be boolean"
)
if isinstance(self.webhook_config, BlockWebhookConfig):
# Enforce presence of credentials field on auto-setup webhook blocks
if CREDENTIALS_FIELD_NAME not in self.input_schema.model_fields:
raise TypeError(
"credentials field is required on auto-setup webhook blocks"
)
self.block_type = BlockType.WEBHOOK
else:
self.block_type = BlockType.WEBHOOK_MANUAL
# Enforce shape of webhook event filter, if present
if self.webhook_config.event_filter_input:
event_filter_field = self.input_schema.model_fields[
self.webhook_config.event_filter_input
]
if not (
isinstance(event_filter_field.annotation, type)
and issubclass(event_filter_field.annotation, BaseModel)
and all(
field.annotation is bool
for field in event_filter_field.annotation.model_fields.values()
)
):
raise NotImplementedError(
f"{self.name} has an invalid webhook event selector: "
"field must be a BaseModel and all its fields must be boolean"
)
# Enforce presence of 'payload' input
if "payload" not in self.input_schema.model_fields:

View File

@@ -53,8 +53,8 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.LLAMA3_1_8B: 1,
LlmModel.OLLAMA_LLAMA3_8B: 1,
LlmModel.OLLAMA_LLAMA3_405B: 1,
LlmModel.OLLAMA_DOLPHIN: 1,
LlmModel.GEMINI_FLASH_1_5_8B: 1,
LlmModel.GEMINI_FLASH_1_5_EXP: 1,
LlmModel.GROK_BETA: 5,
LlmModel.MISTRAL_NEMO: 1,
LlmModel.COHERE_COMMAND_R_08_2024: 1,
@@ -62,6 +62,14 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.EVA_QWEN_2_5_32B: 1,
LlmModel.DEEPSEEK_CHAT: 2,
LlmModel.PERPLEXITY_LLAMA_3_1_SONAR_LARGE_128K_ONLINE: 1,
LlmModel.QWEN_QWQ_32B_PREVIEW: 2,
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: 1,
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: 1,
LlmModel.AMAZON_NOVA_LITE_V1: 1,
LlmModel.AMAZON_NOVA_MICRO_V1: 1,
LlmModel.AMAZON_NOVA_PRO_V1: 1,
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: 1,
LlmModel.GRYPHE_MYTHOMAX_L2_13B: 1,
}
for model in LlmModel:

View File

@@ -2,9 +2,9 @@ from abc import ABC, abstractmethod
from datetime import datetime, timezone
from prisma import Json
from prisma.enums import UserBlockCreditType
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import UserBlockCredit
from prisma.models import CreditTransaction
from backend.data.block import Block, BlockInput, get_block
from backend.data.block_cost_config import BLOCK_COSTS
@@ -76,7 +76,7 @@ class UserCredit(UserCreditBase):
else cur_month.replace(year=cur_month.year + 1, month=1)
)
user_credit = await UserBlockCredit.prisma().group_by(
user_credit = await CreditTransaction.prisma().group_by(
by=["userId"],
sum={"amount": True},
where={
@@ -93,10 +93,10 @@ class UserCredit(UserCreditBase):
key = f"MONTHLY-CREDIT-TOP-UP-{cur_month}"
try:
await UserBlockCredit.prisma().create(
await CreditTransaction.prisma().create(
data={
"amount": self.num_user_credits_refill,
"type": UserBlockCreditType.TOP_UP,
"type": CreditTransactionType.TOP_UP,
"userId": user_id,
"transactionKey": key,
"createdAt": self.time_now(),
@@ -184,11 +184,11 @@ class UserCredit(UserCreditBase):
if validate_balance and user_credit < cost:
raise ValueError(f"Insufficient credit: {user_credit} < {cost}")
await UserBlockCredit.prisma().create(
await CreditTransaction.prisma().create(
data={
"userId": user_id,
"amount": -cost,
"type": UserBlockCreditType.USAGE,
"type": CreditTransactionType.USAGE,
"blockId": block.id,
"metadata": Json(
{
@@ -202,11 +202,11 @@ class UserCredit(UserCreditBase):
return cost
async def top_up_credits(self, user_id: str, amount: int):
await UserBlockCredit.prisma().create(
await CreditTransaction.prisma().create(
data={
"userId": user_id,
"amount": amount,
"type": UserBlockCreditType.TOP_UP,
"type": CreditTransactionType.TOP_UP,
"createdAt": self.time_now(),
}
)

View File

@@ -29,6 +29,13 @@ async def connect():
if not prisma.is_connected():
raise ConnectionError("Failed to connect to Prisma.")
# Connection acquired from a pool like Supabase somehow still possibly allows
# the db client obtains a connection but still reject query connection afterward.
try:
await prisma.execute_raw("SELECT 1")
except Exception as e:
raise ConnectionError("Failed to connect to Prisma.") from e
@conn_retry("Prisma", "Releasing connection")
async def disconnect():

View File

@@ -9,7 +9,6 @@ from prisma.models import (
AgentNodeExecution,
AgentNodeExecutionInputOutput,
)
from prisma.types import AgentGraphExecutionWhereInput
from pydantic import BaseModel
from backend.data.block import BlockData, BlockInput, CompletedBlockOutput
@@ -19,14 +18,14 @@ from backend.util import json, mock
from backend.util.settings import Config
class GraphExecution(BaseModel):
class GraphExecutionEntry(BaseModel):
user_id: str
graph_exec_id: str
graph_id: str
start_node_execs: list["NodeExecution"]
start_node_execs: list["NodeExecutionEntry"]
class NodeExecution(BaseModel):
class NodeExecutionEntry(BaseModel):
user_id: str
graph_exec_id: str
graph_id: str
@@ -325,34 +324,6 @@ async def update_execution_status(
return ExecutionResult.from_db(res)
async def get_graph_execution(
graph_exec_id: str, user_id: str
) -> AgentGraphExecution | None:
"""
Retrieve a specific graph execution by its ID.
Args:
graph_exec_id (str): The ID of the graph execution to retrieve.
user_id (str): The ID of the user to whom the graph (execution) belongs.
Returns:
AgentGraphExecution | None: The graph execution if found, None otherwise.
"""
execution = await AgentGraphExecution.prisma().find_first(
where={"id": graph_exec_id, "userId": user_id},
include=GRAPH_EXECUTION_INCLUDE,
)
return execution
async def list_executions(graph_id: str, graph_version: int | None = None) -> list[str]:
where: AgentGraphExecutionWhereInput = {"agentGraphId": graph_id}
if graph_version is not None:
where["agentGraphVersion"] = graph_version
executions = await AgentGraphExecution.prisma().find_many(where=where)
return [execution.id for execution in executions]
async def get_execution_results(graph_exec_id: str) -> list[ExecutionResult]:
executions = await AgentNodeExecution.prisma().find_many(
where={"agentGraphExecutionId": graph_exec_id},

View File

@@ -84,6 +84,8 @@ class NodeModel(Node):
raise ValueError(f"Block #{self.block_id} not found for node #{self.id}")
if not block.webhook_config:
raise TypeError("This method can't be used on non-webhook blocks")
if not block.webhook_config.event_filter_input:
return True
event_filter = self.input_default.get(block.webhook_config.event_filter_input)
if not event_filter:
raise ValueError(f"Event filter is not configured on node #{self.id}")
@@ -105,6 +107,8 @@ class GraphExecution(BaseDbModel):
duration: float
total_run_time: float
status: ExecutionStatus
graph_id: str
graph_version: int
@staticmethod
def from_db(execution: AgentGraphExecution):
@@ -130,6 +134,8 @@ class GraphExecution(BaseDbModel):
duration=duration,
total_run_time=total_run_time,
status=ExecutionStatus(execution.executionStatus),
graph_id=execution.agentGraphId,
graph_version=execution.agentGraphVersion,
)
@@ -139,7 +145,6 @@ class Graph(BaseDbModel):
is_template: bool = False
name: str
description: str
executions: list[GraphExecution] = []
nodes: list[Node] = []
links: list[Link] = []
@@ -253,7 +258,7 @@ class GraphModel(Graph):
for link in self.links:
input_links[link.sink_id].append(link)
# Nodes: required fields are filled or connected
# Nodes: required fields are filled or connected and dependencies are satisfied
for node in self.nodes:
block = get_block(node.block_id)
if block is None:
@@ -264,16 +269,55 @@ class GraphModel(Graph):
+ [sanitize(link.sink_name) for link in input_links.get(node.id, [])]
)
for name in block.input_schema.get_required_fields():
if name not in provided_inputs and (
for_run # Skip input completion validation, unless when executing.
or block.block_type == BlockType.INPUT
or block.block_type == BlockType.OUTPUT
or block.block_type == BlockType.AGENT
if (
name not in provided_inputs
and not (
name == "payload"
and block.block_type
in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
)
and (
for_run # Skip input completion validation, unless when executing.
or block.block_type == BlockType.INPUT
or block.block_type == BlockType.OUTPUT
or block.block_type == BlockType.AGENT
)
):
raise ValueError(
f"Node {block.name} #{node.id} required input missing: `{name}`"
)
# Get input schema properties and check dependencies
input_schema = block.input_schema.model_fields
required_fields = block.input_schema.get_required_fields()
def has_value(name):
return (
node is not None
and name in node.input_default
and node.input_default[name] is not None
and str(node.input_default[name]).strip() != ""
) or (name in input_schema and input_schema[name].default is not None)
# Validate dependencies between fields
for field_name, field_info in input_schema.items():
# Apply input dependency validation only on run & field with depends_on
json_schema_extra = field_info.json_schema_extra or {}
dependencies = json_schema_extra.get("depends_on", [])
if not for_run or not dependencies:
continue
# Check if dependent field has value in input_default
field_has_value = has_value(field_name)
field_is_required = field_name in required_fields
# Check for missing dependencies when dependent field is present
missing_deps = [dep for dep in dependencies if not has_value(dep)]
if missing_deps and (field_has_value or field_is_required):
raise ValueError(
f"Node {block.name} #{node.id}: Field `{field_name}` requires [{', '.join(missing_deps)}] to be set"
)
node_map = {v.id: v for v in self.nodes}
def is_static_output_block(nid: str) -> bool:
@@ -323,12 +367,7 @@ class GraphModel(Graph):
link.is_static = True # Each value block output should be static.
@staticmethod
def from_db(graph: AgentGraph, hide_credentials: bool = False):
executions = [
GraphExecution.from_db(execution)
for execution in graph.AgentGraphExecution or []
]
def from_db(graph: AgentGraph, for_export: bool = False):
return GraphModel(
id=graph.id,
user_id=graph.userId,
@@ -337,9 +376,8 @@ class GraphModel(Graph):
is_template=graph.isTemplate,
name=graph.name or "",
description=graph.description or "",
executions=executions,
nodes=[
GraphModel._process_node(node, hide_credentials)
NodeModel.from_db(GraphModel._process_node(node, for_export))
for node in graph.AgentNodes or []
],
links=list(
@@ -352,23 +390,29 @@ class GraphModel(Graph):
)
@staticmethod
def _process_node(node: AgentNode, hide_credentials: bool) -> NodeModel:
node_dict = {field: getattr(node, field) for field in node.model_fields}
if hide_credentials and "constantInput" in node_dict:
constant_input = json.loads(
node_dict["constantInput"], target_type=dict[str, Any]
)
constant_input = GraphModel._hide_credentials_in_input(constant_input)
node_dict["constantInput"] = json.dumps(constant_input)
return NodeModel.from_db(AgentNode(**node_dict))
def _process_node(node: AgentNode, for_export: bool) -> AgentNode:
if for_export:
# Remove credentials from node input
if node.constantInput:
constant_input = json.loads(
node.constantInput, target_type=dict[str, Any]
)
constant_input = GraphModel._hide_node_input_credentials(constant_input)
node.constantInput = json.dumps(constant_input)
# Remove webhook info
node.webhookId = None
node.Webhook = None
return node
@staticmethod
def _hide_credentials_in_input(input_data: dict[str, Any]) -> dict[str, Any]:
def _hide_node_input_credentials(input_data: dict[str, Any]) -> dict[str, Any]:
sensitive_keys = ["credentials", "api_key", "password", "token", "secret"]
result = {}
for key, value in input_data.items():
if isinstance(value, dict):
result[key] = GraphModel._hide_credentials_in_input(value)
result[key] = GraphModel._hide_node_input_credentials(value)
elif isinstance(value, str) and any(
sensitive_key in key.lower() for sensitive_key in sensitive_keys
):
@@ -407,7 +451,6 @@ async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:
async def get_graphs(
user_id: str,
include_executions: bool = False,
filter_by: Literal["active", "template"] | None = "active",
) -> list[GraphModel]:
"""
@@ -415,33 +458,50 @@ async def get_graphs(
Default behaviour is to get all currently active graphs.
Args:
include_executions: Whether to include executions in the graph metadata.
filter_by: An optional filter to either select templates or active graphs.
user_id: The ID of the user that owns the graph.
Returns:
list[GraphModel]: A list of objects representing the retrieved graphs.
"""
where_clause: AgentGraphWhereInput = {}
where_clause: AgentGraphWhereInput = {"userId": user_id}
if filter_by == "active":
where_clause["isActive"] = True
elif filter_by == "template":
where_clause["isTemplate"] = True
where_clause["userId"] = user_id
graph_include = AGENT_GRAPH_INCLUDE
graph_include["AgentGraphExecution"] = include_executions
graphs = await AgentGraph.prisma().find_many(
where=where_clause,
distinct=["id"],
order={"version": "desc"},
include=graph_include,
include=AGENT_GRAPH_INCLUDE,
)
return [GraphModel.from_db(graph) for graph in graphs]
graph_models = []
for graph in graphs:
try:
graph_models.append(GraphModel.from_db(graph))
except Exception as e:
logger.error(f"Error processing graph {graph.id}: {e}")
continue
return graph_models
async def get_executions(user_id: str) -> list[GraphExecution]:
executions = await AgentGraphExecution.prisma().find_many(
where={"userId": user_id},
order={"createdAt": "desc"},
)
return [GraphExecution.from_db(execution) for execution in executions]
async def get_execution(user_id: str, execution_id: str) -> GraphExecution | None:
execution = await AgentGraphExecution.prisma().find_first(
where={"id": execution_id, "userId": user_id}
)
return GraphExecution.from_db(execution) if execution else None
async def get_graph(
@@ -449,7 +509,7 @@ async def get_graph(
version: int | None = None,
template: bool = False,
user_id: str | None = None,
hide_credentials: bool = False,
for_export: bool = False,
) -> GraphModel | None:
"""
Retrieves a graph from the DB.
@@ -475,7 +535,7 @@ async def get_graph(
include=AGENT_GRAPH_INCLUDE,
order={"version": "desc"},
)
return GraphModel.from_db(graph, hide_credentials) if graph else None
return GraphModel.from_db(graph, for_export) if graph else None
async def set_graph_active_version(graph_id: str, version: int, user_id: str) -> None:

View File

@@ -3,10 +3,12 @@ from typing import TYPE_CHECKING, AsyncGenerator, Optional
from prisma import Json
from prisma.models import IntegrationWebhook
from pydantic import Field
from pydantic import Field, computed_field
from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
from backend.data.queue import AsyncRedisEventBus
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.utils import webhook_ingress_url
from .db import BaseDbModel
@@ -18,7 +20,7 @@ logger = logging.getLogger(__name__)
class Webhook(BaseDbModel):
user_id: str
provider: str
provider: ProviderName
credentials_id: str
webhook_type: str
resource: str
@@ -30,6 +32,11 @@ class Webhook(BaseDbModel):
attached_nodes: Optional[list["NodeModel"]] = None
@computed_field
@property
def url(self) -> str:
return webhook_ingress_url(self.provider.value, self.id)
@staticmethod
def from_db(webhook: IntegrationWebhook):
from .graph import NodeModel
@@ -37,7 +44,7 @@ class Webhook(BaseDbModel):
return Webhook(
id=webhook.id,
user_id=webhook.userId,
provider=webhook.provider,
provider=ProviderName(webhook.provider),
credentials_id=webhook.credentialsId,
webhook_type=webhook.webhookType,
resource=webhook.resource,
@@ -61,7 +68,7 @@ async def create_webhook(webhook: Webhook) -> Webhook:
data={
"id": webhook.id,
"userId": webhook.user_id,
"provider": webhook.provider,
"provider": webhook.provider.value,
"credentialsId": webhook.credentials_id,
"webhookType": webhook.webhook_type,
"resource": webhook.resource,
@@ -83,8 +90,10 @@ async def get_webhook(webhook_id: str) -> Webhook:
return Webhook.from_db(webhook)
async def get_all_webhooks(credentials_id: str) -> list[Webhook]:
async def get_all_webhooks_by_creds(credentials_id: str) -> list[Webhook]:
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
if not credentials_id:
raise ValueError("credentials_id must not be empty")
webhooks = await IntegrationWebhook.prisma().find_many(
where={"credentialsId": credentials_id},
include=INTEGRATION_WEBHOOK_INCLUDE,
@@ -92,7 +101,7 @@ async def get_all_webhooks(credentials_id: str) -> list[Webhook]:
return [Webhook.from_db(webhook) for webhook in webhooks]
async def find_webhook(
async def find_webhook_by_credentials_and_props(
credentials_id: str, webhook_type: str, resource: str, events: list[str]
) -> Webhook | None:
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
@@ -108,6 +117,22 @@ async def find_webhook(
return Webhook.from_db(webhook) if webhook else None
async def find_webhook_by_graph_and_props(
graph_id: str, provider: str, webhook_type: str, events: list[str]
) -> Webhook | None:
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
webhook = await IntegrationWebhook.prisma().find_first(
where={
"provider": provider,
"webhookType": webhook_type,
"events": {"has_every": events},
"AgentNodes": {"some": {"agentGraphId": graph_id}},
},
include=INTEGRATION_WEBHOOK_INCLUDE,
)
return Webhook.from_db(webhook) if webhook else None
async def update_webhook_config(webhook_id: str, updated_config: dict) -> Webhook:
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
_updated_webhook = await IntegrationWebhook.prisma().update(
@@ -144,25 +169,28 @@ class WebhookEventBus(AsyncRedisEventBus[WebhookEvent]):
def event_bus_name(self) -> str:
return "webhooks"
async def publish(self, event: WebhookEvent):
await self.publish_event(event, f"{event.webhook_id}/{event.event_type}")
async def listen(
self, webhook_id: str, event_type: Optional[str] = None
) -> AsyncGenerator[WebhookEvent, None]:
async for event in self.listen_events(f"{webhook_id}/{event_type or '*'}"):
yield event
event_bus = WebhookEventBus()
_webhook_event_bus = WebhookEventBus()
async def publish_webhook_event(event: WebhookEvent):
await event_bus.publish(event)
await _webhook_event_bus.publish_event(
event, f"{event.webhook_id}/{event.event_type}"
)
async def listen_for_webhook_event(
async def listen_for_webhook_events(
webhook_id: str, event_type: Optional[str] = None
) -> AsyncGenerator[WebhookEvent, None]:
async for event in _webhook_event_bus.listen_events(
f"{webhook_id}/{event_type or '*'}"
):
yield event
async def wait_for_webhook_event(
webhook_id: str, event_type: Optional[str] = None, timeout: Optional[float] = None
) -> WebhookEvent | None:
async for event in event_bus.listen(webhook_id, event_type):
return event # Only one event is expected
return await _webhook_event_bus.wait_for_event(
f"{webhook_id}/{event_type or '*'}", timeout
)

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import logging
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Callable,
@@ -11,19 +12,32 @@ from typing import (
Optional,
TypedDict,
TypeVar,
get_args,
)
from uuid import uuid4
from pydantic import BaseModel, Field, GetCoreSchemaHandler, SecretStr, field_serializer
from pydantic import (
BaseModel,
ConfigDict,
Field,
GetCoreSchemaHandler,
SecretStr,
field_serializer,
)
from pydantic_core import (
CoreSchema,
PydanticUndefined,
PydanticUndefinedType,
ValidationError,
core_schema,
)
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets
if TYPE_CHECKING:
from backend.data.block import BlockSchema
T = TypeVar("T")
logger = logging.getLogger(__name__)
@@ -124,6 +138,7 @@ def SchemaField(
secret: bool = False,
exclude: bool = False,
hidden: Optional[bool] = None,
depends_on: list[str] | None = None,
**kwargs,
) -> T:
json_extra = {
@@ -133,6 +148,7 @@ def SchemaField(
"secret": secret,
"advanced": advanced,
"hidden": hidden,
"depends_on": depends_on,
}.items()
if v is not None
}
@@ -146,7 +162,7 @@ def SchemaField(
exclude=exclude,
json_schema_extra=json_extra,
**kwargs,
)
) # type: ignore
class _BaseCredentials(BaseModel):
@@ -220,7 +236,7 @@ class UserIntegrations(BaseModel):
oauth_states: list[OAuthState] = Field(default_factory=list)
CP = TypeVar("CP", bound=str)
CP = TypeVar("CP", bound=ProviderName)
CT = TypeVar("CT", bound=CredentialsType)
@@ -233,19 +249,51 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
provider: CP
type: CT
@staticmethod
def _add_json_schema_extra(schema, cls: CredentialsMetaInput):
schema["credentials_provider"] = get_args(
cls.model_fields["provider"].annotation
)
schema["credentials_types"] = get_args(cls.model_fields["type"].annotation)
class CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]):
model_config = ConfigDict(
json_schema_extra=_add_json_schema_extra, # type: ignore
)
@classmethod
def validate_credentials_field_schema(cls, model: type["BlockSchema"]):
"""Validates the schema of a `credentials` field"""
field_schema = model.jsonschema()["properties"][CREDENTIALS_FIELD_NAME]
try:
schema_extra = _CredentialsFieldSchemaExtra[CP, CT].model_validate(
field_schema
)
except ValidationError as e:
if "Field required [type=missing" not in str(e):
raise
raise TypeError(
"Field 'credentials' JSON schema lacks required extra items: "
f"{field_schema}"
) from e
if (
len(schema_extra.credentials_provider) > 1
and not schema_extra.discriminator
):
raise TypeError("Multi-provider CredentialsField requires discriminator!")
class _CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]):
# TODO: move discrimination mechanism out of CredentialsField (frontend + backend)
credentials_provider: list[CP]
credentials_scopes: Optional[list[str]]
credentials_scopes: Optional[list[str]] = None
credentials_types: list[CT]
discriminator: Optional[str] = None
discriminator_mapping: Optional[dict[str, CP]] = None
def CredentialsField(
provider: CP | list[CP],
supported_credential_types: set[CT],
required_scopes: set[str] = set(),
*,
discriminator: Optional[str] = None,
@@ -253,26 +301,26 @@ def CredentialsField(
title: Optional[str] = None,
description: Optional[str] = None,
**kwargs,
) -> CredentialsMetaInput[CP, CT]:
) -> CredentialsMetaInput:
"""
`CredentialsField` must and can only be used on fields named `credentials`.
This is enforced by the `BlockSchema` base class.
"""
if not isinstance(provider, str) and len(provider) > 1 and not discriminator:
raise TypeError("Multi-provider CredentialsField requires discriminator!")
field_schema_extra = CredentialsFieldSchemaExtra[CP, CT](
credentials_provider=[provider] if isinstance(provider, str) else provider,
credentials_scopes=list(required_scopes) or None, # omit if empty
credentials_types=list(supported_credential_types),
discriminator=discriminator,
discriminator_mapping=discriminator_mapping,
)
field_schema_extra = {
k: v
for k, v in {
"credentials_scopes": list(required_scopes) or None,
"discriminator": discriminator,
"discriminator_mapping": discriminator_mapping,
}.items()
if v is not None
}
return Field(
title=title,
description=description,
json_schema_extra=field_schema_extra.model_dump(exclude_none=True),
json_schema_extra=field_schema_extra, # validated on BlockSchema init
**kwargs,
)

View File

@@ -1,8 +1,9 @@
import asyncio
import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, AsyncGenerator, Generator, Generic, TypeVar
from typing import Any, AsyncGenerator, Generator, Generic, Optional, TypeVar
from pydantic import BaseModel
from redis.asyncio.client import PubSub as AsyncPubSub
@@ -48,12 +49,12 @@ class BaseRedisEventBus(Generic[M], ABC):
except Exception as e:
logger.error(f"Failed to parse event result from Redis {msg} {e}")
def _subscribe(
def _get_pubsub_channel(
self, connection: redis.Redis | redis.AsyncRedis, channel_key: str
) -> tuple[PubSub | AsyncPubSub, str]:
channel_name = f"{self.event_bus_name}/{channel_key}"
full_channel_name = f"{self.event_bus_name}/{channel_key}"
pubsub = connection.pubsub()
return pubsub, channel_name
return pubsub, full_channel_name
class RedisEventBus(BaseRedisEventBus[M], ABC):
@@ -64,17 +65,19 @@ class RedisEventBus(BaseRedisEventBus[M], ABC):
return redis.get_redis()
def publish_event(self, event: M, channel_key: str):
message, channel_name = self._serialize_message(event, channel_key)
self.connection.publish(channel_name, message)
message, full_channel_name = self._serialize_message(event, channel_key)
self.connection.publish(full_channel_name, message)
def listen_events(self, channel_key: str) -> Generator[M, None, None]:
pubsub, channel_name = self._subscribe(self.connection, channel_key)
pubsub, full_channel_name = self._get_pubsub_channel(
self.connection, channel_key
)
assert isinstance(pubsub, PubSub)
if "*" in channel_key:
pubsub.psubscribe(channel_name)
pubsub.psubscribe(full_channel_name)
else:
pubsub.subscribe(channel_name)
pubsub.subscribe(full_channel_name)
for message in pubsub.listen():
if event := self._deserialize_message(message, channel_key):
@@ -89,19 +92,31 @@ class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
return await redis.get_redis_async()
async def publish_event(self, event: M, channel_key: str):
message, channel_name = self._serialize_message(event, channel_key)
message, full_channel_name = self._serialize_message(event, channel_key)
connection = await self.connection
await connection.publish(channel_name, message)
await connection.publish(full_channel_name, message)
async def listen_events(self, channel_key: str) -> AsyncGenerator[M, None]:
pubsub, channel_name = self._subscribe(await self.connection, channel_key)
pubsub, full_channel_name = self._get_pubsub_channel(
await self.connection, channel_key
)
assert isinstance(pubsub, AsyncPubSub)
if "*" in channel_key:
await pubsub.psubscribe(channel_name)
await pubsub.psubscribe(full_channel_name)
else:
await pubsub.subscribe(channel_name)
await pubsub.subscribe(full_channel_name)
async for message in pubsub.listen():
if event := self._deserialize_message(message, channel_key):
yield event
async def wait_for_event(
self, channel_key: str, timeout: Optional[float] = None
) -> M | None:
try:
return await asyncio.wait_for(
anext(aiter(self.listen_events(channel_key))), timeout
)
except TimeoutError:
return None

View File

@@ -25,8 +25,8 @@ from backend.data.execution import (
ExecutionQueue,
ExecutionResult,
ExecutionStatus,
GraphExecution,
NodeExecution,
GraphExecutionEntry,
NodeExecutionEntry,
merge_execution_input,
parse_execution_output,
)
@@ -96,13 +96,13 @@ class LogMetadata:
T = TypeVar("T")
ExecutionStream = Generator[NodeExecution, None, None]
ExecutionStream = Generator[NodeExecutionEntry, None, None]
def execute_node(
db_client: "DatabaseManager",
creds_manager: IntegrationCredentialsManager,
data: NodeExecution,
data: NodeExecutionEntry,
execution_stats: dict[str, Any] | None = None,
) -> ExecutionStream:
"""
@@ -252,15 +252,15 @@ def _enqueue_next_nodes(
graph_exec_id: str,
graph_id: str,
log_metadata: LogMetadata,
) -> list[NodeExecution]:
) -> list[NodeExecutionEntry]:
def add_enqueued_execution(
node_exec_id: str, node_id: str, data: BlockInput
) -> NodeExecution:
) -> NodeExecutionEntry:
exec_update = db_client.update_execution_status(
node_exec_id, ExecutionStatus.QUEUED, data
)
db_client.send_execution_update(exec_update)
return NodeExecution(
return NodeExecutionEntry(
user_id=user_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
@@ -269,7 +269,7 @@ def _enqueue_next_nodes(
data=data,
)
def register_next_executions(node_link: Link) -> list[NodeExecution]:
def register_next_executions(node_link: Link) -> list[NodeExecutionEntry]:
enqueued_executions = []
next_output_name = node_link.source_name
next_input_name = node_link.sink_name
@@ -501,8 +501,8 @@ class Executor:
@error_logged
def on_node_execution(
cls,
q: ExecutionQueue[NodeExecution],
node_exec: NodeExecution,
q: ExecutionQueue[NodeExecutionEntry],
node_exec: NodeExecutionEntry,
) -> dict[str, Any]:
log_metadata = LogMetadata(
user_id=node_exec.user_id,
@@ -529,8 +529,8 @@ class Executor:
@time_measured
def _on_node_execution(
cls,
q: ExecutionQueue[NodeExecution],
node_exec: NodeExecution,
q: ExecutionQueue[NodeExecutionEntry],
node_exec: NodeExecutionEntry,
log_metadata: LogMetadata,
stats: dict[str, Any] | None = None,
):
@@ -580,7 +580,9 @@ class Executor:
@classmethod
@error_logged
def on_graph_execution(cls, graph_exec: GraphExecution, cancel: threading.Event):
def on_graph_execution(
cls, graph_exec: GraphExecutionEntry, cancel: threading.Event
):
log_metadata = LogMetadata(
user_id=graph_exec.user_id,
graph_eid=graph_exec.graph_exec_id,
@@ -605,7 +607,7 @@ class Executor:
@time_measured
def _on_graph_execution(
cls,
graph_exec: GraphExecution,
graph_exec: GraphExecutionEntry,
cancel: threading.Event,
log_metadata: LogMetadata,
) -> tuple[dict[str, Any], Exception | None]:
@@ -636,13 +638,13 @@ class Executor:
cancel_thread.start()
try:
queue = ExecutionQueue[NodeExecution]()
queue = ExecutionQueue[NodeExecutionEntry]()
for node_exec in graph_exec.start_node_execs:
queue.add(node_exec)
running_executions: dict[str, AsyncResult] = {}
def make_exec_callback(exec_data: NodeExecution):
def make_exec_callback(exec_data: NodeExecutionEntry):
node_id = exec_data.node_id
def callback(result: object):
@@ -717,7 +719,7 @@ class ExecutionManager(AppService):
self.use_redis = True
self.use_supabase = True
self.pool_size = settings.config.num_graph_workers
self.queue = ExecutionQueue[GraphExecution]()
self.queue = ExecutionQueue[GraphExecutionEntry]()
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
@classmethod
@@ -768,7 +770,7 @@ class ExecutionManager(AppService):
data: BlockInput,
user_id: str,
graph_version: int | None = None,
) -> GraphExecution:
) -> GraphExecutionEntry:
graph: GraphModel | None = self.db_client.get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
)
@@ -796,10 +798,13 @@ class ExecutionManager(AppService):
# Extract webhook payload, and assign it to the input pin
webhook_payload_key = f"webhook_{node.webhook_id}_payload"
if (
block.block_type == BlockType.WEBHOOK
block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
and node.webhook_id
and webhook_payload_key in data
):
if webhook_payload_key not in data:
raise ValueError(
f"Node {block.name} #{node.id} webhook payload is missing"
)
input_data = {"payload": data[webhook_payload_key]}
input_data, error = validate_exec(node, input_data)
@@ -818,7 +823,7 @@ class ExecutionManager(AppService):
starting_node_execs = []
for node_exec in node_execs:
starting_node_execs.append(
NodeExecution(
NodeExecutionEntry(
user_id=user_id,
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
@@ -832,7 +837,7 @@ class ExecutionManager(AppService):
)
self.db_client.send_execution_update(exec_update)
graph_exec = GraphExecution(
graph_exec = GraphExecutionEntry(
user_id=user_id,
graph_id=graph_id,
graph_exec_id=graph_exec_id,

View File

@@ -1,6 +1,7 @@
import logging
from contextlib import contextmanager
from datetime import datetime
from typing import TYPE_CHECKING
from autogpt_libs.utils.synchronize import RedisKeyedMutex
from redis.lock import Lock as RedisLock
@@ -8,10 +9,13 @@ from redis.lock import Lock as RedisLock
from backend.data import redis
from backend.data.model import Credentials
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.util.exceptions import MissingConfigError
from backend.util.settings import Settings
if TYPE_CHECKING:
from backend.integrations.oauth import BaseOAuthHandler
logger = logging.getLogger(__name__)
settings = Settings()
@@ -148,7 +152,7 @@ class IntegrationCredentialsManager:
self.store.locks.release_all_locks()
def _get_provider_oauth_handler(provider_name: str) -> BaseOAuthHandler:
def _get_provider_oauth_handler(provider_name: str) -> "BaseOAuthHandler":
if provider_name not in HANDLERS_BY_NAME:
raise KeyError(f"Unknown provider '{provider_name}'")

View File

@@ -1,10 +1,15 @@
from .base import BaseOAuthHandler
from typing import TYPE_CHECKING
from .github import GitHubOAuthHandler
from .google import GoogleOAuthHandler
from .notion import NotionOAuthHandler
if TYPE_CHECKING:
from ..providers import ProviderName
from .base import BaseOAuthHandler
# --8<-- [start:HANDLERS_BY_NAMEExample]
HANDLERS_BY_NAME: dict[str, type[BaseOAuthHandler]] = {
HANDLERS_BY_NAME: dict["ProviderName", type["BaseOAuthHandler"]] = {
handler.PROVIDER_NAME: handler
for handler in [
GitHubOAuthHandler,

View File

@@ -4,13 +4,14 @@ from abc import ABC, abstractmethod
from typing import ClassVar
from backend.data.model import OAuth2Credentials
from backend.integrations.providers import ProviderName
logger = logging.getLogger(__name__)
class BaseOAuthHandler(ABC):
# --8<-- [start:BaseOAuthHandler1]
PROVIDER_NAME: ClassVar[str]
PROVIDER_NAME: ClassVar[ProviderName]
DEFAULT_SCOPES: ClassVar[list[str]] = []
# --8<-- [end:BaseOAuthHandler1]
@@ -76,6 +77,8 @@ class BaseOAuthHandler(ABC):
"""Handles the default scopes for the provider"""
# If scopes are empty, use the default scopes for the provider
if not scopes:
logger.debug(f"Using default scopes for provider {self.PROVIDER_NAME}")
logger.debug(
f"Using default scopes for provider {self.PROVIDER_NAME.value}"
)
scopes = self.DEFAULT_SCOPES
return scopes

View File

@@ -3,6 +3,7 @@ from typing import Optional
from urllib.parse import urlencode
from backend.data.model import OAuth2Credentials
from backend.integrations.providers import ProviderName
from backend.util.request import requests
from .base import BaseOAuthHandler
@@ -23,7 +24,7 @@ class GitHubOAuthHandler(BaseOAuthHandler):
access token *with no refresh token*.
""" # noqa
PROVIDER_NAME = "github"
PROVIDER_NAME = ProviderName.GITHUB
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id

View File

@@ -9,6 +9,7 @@ from google_auth_oauthlib.flow import Flow
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
from backend.integrations.providers import ProviderName
from .base import BaseOAuthHandler
@@ -21,7 +22,7 @@ class GoogleOAuthHandler(BaseOAuthHandler):
Based on the documentation at https://developers.google.com/identity/protocols/oauth2/web-server
""" # noqa
PROVIDER_NAME = "google"
PROVIDER_NAME = ProviderName.GOOGLE
EMAIL_ENDPOINT = "https://www.googleapis.com/oauth2/v2/userinfo"
DEFAULT_SCOPES = [
"https://www.googleapis.com/auth/userinfo.email",

View File

@@ -2,6 +2,7 @@ from base64 import b64encode
from urllib.parse import urlencode
from backend.data.model import OAuth2Credentials
from backend.integrations.providers import ProviderName
from backend.util.request import requests
from .base import BaseOAuthHandler
@@ -16,7 +17,7 @@ class NotionOAuthHandler(BaseOAuthHandler):
- Notion doesn't use scopes
"""
PROVIDER_NAME = "notion"
PROVIDER_NAME = ProviderName.NOTION
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id

View File

@@ -1,7 +1,31 @@
from enum import Enum
# --8<-- [start:ProviderName]
class ProviderName(str, Enum):
ANTHROPIC = "anthropic"
COMPASS = "compass"
DISCORD = "discord"
D_ID = "d_id"
E2B = "e2b"
EXA = "exa"
FAL = "fal"
GITHUB = "github"
GOOGLE = "google"
GOOGLE_MAPS = "google_maps"
GROQ = "groq"
HUBSPOT = "hubspot"
IDEOGRAM = "ideogram"
JINA = "jina"
MEDIUM = "medium"
NOTION = "notion"
OLLAMA = "ollama"
OPENAI = "openai"
OPENWEATHERMAP = "openweathermap"
OPEN_ROUTER = "open_router"
PINECONE = "pinecone"
REPLICATE = "replicate"
REVID = "revid"
SLANT3D = "slant3d"
UNREAL_SPEECH = "unreal_speech"
# --8<-- [end:ProviderName]

View File

@@ -1,15 +1,18 @@
from typing import TYPE_CHECKING
from .compass import CompassWebhookManager
from .github import GithubWebhooksManager
from .slant3d import Slant3DWebhooksManager
if TYPE_CHECKING:
from .base import BaseWebhooksManager
from ..providers import ProviderName
from ._base import BaseWebhooksManager
# --8<-- [start:WEBHOOK_MANAGERS_BY_NAME]
WEBHOOK_MANAGERS_BY_NAME: dict[str, type["BaseWebhooksManager"]] = {
WEBHOOK_MANAGERS_BY_NAME: dict["ProviderName", type["BaseWebhooksManager"]] = {
handler.PROVIDER_NAME: handler
for handler in [
CompassWebhookManager,
GithubWebhooksManager,
Slant3DWebhooksManager,
]

View File

@@ -1,7 +1,7 @@
import logging
import secrets
from abc import ABC, abstractmethod
from typing import ClassVar, Generic, TypeVar
from typing import ClassVar, Generic, Optional, TypeVar
from uuid import uuid4
from fastapi import Request
@@ -9,6 +9,8 @@ from strenum import StrEnum
from backend.data import integrations
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.utils import webhook_ingress_url
from backend.util.exceptions import MissingConfigError
from backend.util.settings import Config
@@ -20,12 +22,12 @@ WT = TypeVar("WT", bound=StrEnum)
class BaseWebhooksManager(ABC, Generic[WT]):
# --8<-- [start:BaseWebhooksManager1]
PROVIDER_NAME: ClassVar[str]
PROVIDER_NAME: ClassVar[ProviderName]
# --8<-- [end:BaseWebhooksManager1]
WebhookType: WT
async def get_suitable_webhook(
async def get_suitable_auto_webhook(
self,
user_id: str,
credentials: Credentials,
@@ -38,16 +40,34 @@ class BaseWebhooksManager(ABC, Generic[WT]):
"PLATFORM_BASE_URL must be set to use Webhook functionality"
)
if webhook := await integrations.find_webhook(
if webhook := await integrations.find_webhook_by_credentials_and_props(
credentials.id, webhook_type, resource, events
):
return webhook
return await self._create_webhook(
user_id, credentials, webhook_type, resource, events
user_id, webhook_type, events, resource, credentials
)
async def get_manual_webhook(
self,
user_id: str,
graph_id: str,
webhook_type: WT,
events: list[str],
):
if current_webhook := await integrations.find_webhook_by_graph_and_props(
graph_id, self.PROVIDER_NAME, webhook_type, events
):
return current_webhook
return await self._create_webhook(
user_id,
webhook_type,
events,
register=False,
)
async def prune_webhook_if_dangling(
self, webhook_id: str, credentials: Credentials
self, webhook_id: str, credentials: Optional[Credentials]
) -> bool:
webhook = await integrations.get_webhook(webhook_id)
if webhook.attached_nodes is None:
@@ -56,7 +76,8 @@ class BaseWebhooksManager(ABC, Generic[WT]):
# Don't prune webhook if in use
return False
await self._deregister_webhook(webhook, credentials)
if credentials:
await self._deregister_webhook(webhook, credentials)
await integrations.delete_webhook(webhook.id)
return True
@@ -81,7 +102,9 @@ class BaseWebhooksManager(ABC, Generic[WT]):
# --8<-- [end:BaseWebhooksManager3]
# --8<-- [start:BaseWebhooksManager5]
async def trigger_ping(self, webhook: integrations.Webhook) -> None:
async def trigger_ping(
self, webhook: integrations.Webhook, credentials: Credentials | None
) -> None:
"""
Triggers a ping to the given webhook.
@@ -132,27 +155,36 @@ class BaseWebhooksManager(ABC, Generic[WT]):
async def _create_webhook(
self,
user_id: str,
credentials: Credentials,
webhook_type: WT,
resource: str,
events: list[str],
resource: str = "",
credentials: Optional[Credentials] = None,
register: bool = True,
) -> integrations.Webhook:
if not app_config.platform_base_url:
raise MissingConfigError(
"PLATFORM_BASE_URL must be set to use Webhook functionality"
)
id = str(uuid4())
secret = secrets.token_hex(32)
provider_name = self.PROVIDER_NAME
ingress_url = (
f"{app_config.platform_base_url}/api/integrations/{provider_name}"
f"/webhooks/{id}/ingress"
)
provider_webhook_id, config = await self._register_webhook(
credentials, webhook_type, resource, events, ingress_url, secret
)
ingress_url = webhook_ingress_url(provider_name=provider_name, webhook_id=id)
if register:
if not credentials:
raise TypeError("credentials are required if register = True")
provider_webhook_id, config = await self._register_webhook(
credentials, webhook_type, resource, events, ingress_url, secret
)
else:
provider_webhook_id, config = "", {}
return await integrations.create_webhook(
integrations.Webhook(
id=id,
user_id=user_id,
provider=provider_name,
credentials_id=credentials.id,
credentials_id=credentials.id if credentials else "",
webhook_type=webhook_type,
resource=resource,
events=events,

View File

@@ -0,0 +1,30 @@
import logging
from backend.data import integrations
from backend.data.model import APIKeyCredentials, Credentials, OAuth2Credentials
from ._base import WT, BaseWebhooksManager
logger = logging.getLogger(__name__)
class ManualWebhookManagerBase(BaseWebhooksManager[WT]):
async def _register_webhook(
self,
credentials: Credentials,
webhook_type: WT,
resource: str,
events: list[str],
ingress_url: str,
secret: str,
) -> tuple[str, dict]:
print(ingress_url) # FIXME: pass URL to user in front end
return "", {}
async def _deregister_webhook(
self,
webhook: integrations.Webhook,
credentials: OAuth2Credentials | APIKeyCredentials,
) -> None:
pass

View File

@@ -0,0 +1,30 @@
import logging
from fastapi import Request
from strenum import StrEnum
from backend.data import integrations
from backend.integrations.providers import ProviderName
from ._manual_base import ManualWebhookManagerBase
logger = logging.getLogger(__name__)
class CompassWebhookType(StrEnum):
TRANSCRIPTION = "transcription"
TASK = "task"
class CompassWebhookManager(ManualWebhookManagerBase):
PROVIDER_NAME = ProviderName.COMPASS
WebhookType = CompassWebhookType
@classmethod
async def validate_payload(
cls, webhook: integrations.Webhook, request: Request
) -> tuple[dict, str]:
payload = await request.json()
event_type = CompassWebhookType.TRANSCRIPTION # currently the only type
return payload, event_type

View File

@@ -8,8 +8,9 @@ from strenum import StrEnum
from backend.data import integrations
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from .base import BaseWebhooksManager
from ._base import BaseWebhooksManager
logger = logging.getLogger(__name__)
@@ -20,7 +21,7 @@ class GithubWebhookType(StrEnum):
class GithubWebhooksManager(BaseWebhooksManager):
PROVIDER_NAME = "github"
PROVIDER_NAME = ProviderName.GITHUB
WebhookType = GithubWebhookType
@@ -58,10 +59,15 @@ class GithubWebhooksManager(BaseWebhooksManager):
return payload, event_type
async def trigger_ping(self, webhook: integrations.Webhook) -> None:
async def trigger_ping(
self, webhook: integrations.Webhook, credentials: Credentials | None
) -> None:
if not credentials:
raise ValueError("Credentials are required but were not passed")
headers = {
**self.GITHUB_API_DEFAULT_HEADERS,
"Authorization": f"Bearer {webhook.config.get('access_token')}",
"Authorization": credentials.bearer(),
}
repo, github_hook_id = webhook.resource, webhook.provider_webhook_id

View File

@@ -1,7 +1,7 @@
import logging
from typing import TYPE_CHECKING, Callable, Optional, cast
from backend.data.block import get_block
from backend.data.block import BlockWebhookConfig, get_block
from backend.data.graph import set_node_webhook
from backend.data.model import CREDENTIALS_FIELD_NAME
from backend.integrations.webhooks import WEBHOOK_MANAGERS_BY_NAME
@@ -10,7 +10,7 @@ if TYPE_CHECKING:
from backend.data.graph import GraphModel, NodeModel
from backend.data.model import Credentials
from .base import BaseWebhooksManager
from ._base import BaseWebhooksManager
logger = logging.getLogger(__name__)
@@ -95,56 +95,92 @@ async def on_node_activate(
if not block.webhook_config:
return node
provider = block.webhook_config.provider
if provider not in WEBHOOK_MANAGERS_BY_NAME:
raise ValueError(
f"Block #{block.id} has webhook_config for provider {provider} "
"which does not support webhooks"
)
logger.debug(
f"Activating webhook node #{node.id} with config {block.webhook_config}"
)
webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[block.webhook_config.provider]()
webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
try:
resource = block.webhook_config.resource_format.format(**node.input_default)
except KeyError:
resource = None
logger.debug(
f"Constructed resource string {resource} from input {node.input_default}"
if auto_setup_webhook := isinstance(block.webhook_config, BlockWebhookConfig):
try:
resource = block.webhook_config.resource_format.format(**node.input_default)
except KeyError:
resource = None
logger.debug(
f"Constructed resource string {resource} from input {node.input_default}"
)
else:
resource = "" # not relevant for manual webhooks
needs_credentials = CREDENTIALS_FIELD_NAME in block.input_schema.model_fields
credentials_meta = (
node.input_default.get(CREDENTIALS_FIELD_NAME) if needs_credentials else None
)
event_filter_input_name = block.webhook_config.event_filter_input
has_everything_for_webhook = (
resource is not None
and CREDENTIALS_FIELD_NAME in node.input_default
and event_filter_input_name in node.input_default
and any(is_on for is_on in node.input_default[event_filter_input_name].values())
and (credentials_meta or not needs_credentials)
and (
not event_filter_input_name
or (
event_filter_input_name in node.input_default
and any(
is_on
for is_on in node.input_default[event_filter_input_name].values()
)
)
)
)
if has_everything_for_webhook and resource:
if has_everything_for_webhook and resource is not None:
logger.debug(f"Node #{node} has everything for a webhook!")
if not credentials:
credentials_meta = node.input_default[CREDENTIALS_FIELD_NAME]
if credentials_meta and not credentials:
raise ValueError(
f"Cannot set up webhook for node #{node.id}: "
f"credentials #{credentials_meta['id']} not available"
)
# Shape of the event filter is enforced in Block.__init__
event_filter = cast(dict, node.input_default[event_filter_input_name])
events = [
block.webhook_config.event_format.format(event=event)
for event, enabled in event_filter.items()
if enabled is True
]
logger.debug(f"Webhook events to subscribe to: {', '.join(events)}")
if event_filter_input_name:
# Shape of the event filter is enforced in Block.__init__
event_filter = cast(dict, node.input_default[event_filter_input_name])
events = [
block.webhook_config.event_format.format(event=event)
for event, enabled in event_filter.items()
if enabled is True
]
logger.debug(f"Webhook events to subscribe to: {', '.join(events)}")
else:
events = []
# Find/make and attach a suitable webhook to the node
new_webhook = await webhooks_manager.get_suitable_webhook(
user_id,
credentials,
block.webhook_config.webhook_type,
resource,
events,
)
if auto_setup_webhook:
assert credentials is not None
new_webhook = await webhooks_manager.get_suitable_auto_webhook(
user_id,
credentials,
block.webhook_config.webhook_type,
resource,
events,
)
else:
# Manual webhook -> no credentials -> don't register but do create
new_webhook = await webhooks_manager.get_manual_webhook(
user_id,
node.graph_id,
block.webhook_config.webhook_type,
events,
)
logger.debug(f"Acquired webhook: {new_webhook}")
return await set_node_webhook(node.id, new_webhook.id)
else:
logger.debug(f"Node #{node.id} does not have everything for a webhook")
return node
@@ -167,7 +203,14 @@ async def on_node_deactivate(
if not block.webhook_config:
return node
webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[block.webhook_config.provider]()
provider = block.webhook_config.provider
if provider not in WEBHOOK_MANAGERS_BY_NAME:
raise ValueError(
f"Block #{block.id} has webhook_config for provider {provider} "
"which does not support webhooks"
)
webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
if node.webhook_id:
logger.debug(f"Node #{node.id} has webhook_id {node.webhook_id}")
@@ -180,16 +223,20 @@ async def on_node_deactivate(
updated_node = await set_node_webhook(node.id, None)
# Prune and deregister the webhook if it is no longer used anywhere
logger.debug("Pruning and deregistering webhook if dangling")
webhook = node.webhook
if credentials:
logger.debug(f"Pruning webhook #{webhook.id} with credentials")
await webhooks_manager.prune_webhook_if_dangling(webhook.id, credentials)
else:
logger.debug(
f"Pruning{' and deregistering' if credentials else ''} "
f"webhook #{webhook.id}"
)
await webhooks_manager.prune_webhook_if_dangling(webhook.id, credentials)
if (
CREDENTIALS_FIELD_NAME in block.input_schema.model_fields
and not credentials
):
logger.warning(
f"Cannot deregister webhook #{webhook.id}: credentials "
f"#{webhook.credentials_id} not available "
f"({webhook.provider} webhook ID: {webhook.provider_webhook_id})"
f"({webhook.provider.value} webhook ID: {webhook.provider_webhook_id})"
)
return updated_node

View File

@@ -1,12 +1,12 @@
import logging
from typing import ClassVar
import requests
from fastapi import Request
from backend.data import integrations
from backend.data.model import APIKeyCredentials, Credentials
from backend.integrations.webhooks.base import BaseWebhooksManager
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks._base import BaseWebhooksManager
logger = logging.getLogger(__name__)
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
class Slant3DWebhooksManager(BaseWebhooksManager):
"""Manager for Slant3D webhooks"""
PROVIDER_NAME: ClassVar[str] = "slant3d"
PROVIDER_NAME = ProviderName.SLANT3D
BASE_URL = "https://www.slant3dapi.com/api"
async def _register_webhook(

View File

@@ -0,0 +1,11 @@
from backend.util.settings import Config
app_config = Config()
# TODO: add test to assert this matches the actual API route
def webhook_ingress_url(provider_name: str, webhook_id: str) -> str:
return (
f"{app_config.platform_base_url}/api/integrations/{provider_name}"
f"/webhooks/{webhook_id}/ingress"
)

View File

@@ -1,5 +1,5 @@
import logging
from typing import Annotated, Literal
from typing import TYPE_CHECKING, Annotated, Literal
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
from pydantic import BaseModel, Field, SecretStr
@@ -7,10 +7,10 @@ from pydantic import BaseModel, Field, SecretStr
from backend.data.graph import set_node_webhook
from backend.data.integrations import (
WebhookEvent,
get_all_webhooks,
get_all_webhooks_by_creds,
get_webhook,
listen_for_webhook_event,
publish_webhook_event,
wait_for_webhook_event,
)
from backend.data.model import (
APIKeyCredentials,
@@ -20,12 +20,16 @@ from backend.data.model import (
)
from backend.executor.manager import ExecutionManager
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import WEBHOOK_MANAGERS_BY_NAME
from backend.util.exceptions import NeedConfirmation
from backend.util.service import get_service_client
from backend.util.settings import Settings
if TYPE_CHECKING:
from backend.integrations.oauth import BaseOAuthHandler
from ..utils import get_user_id
logger = logging.getLogger(__name__)
@@ -42,7 +46,9 @@ class LoginResponse(BaseModel):
@router.get("/{provider}/login")
def login(
provider: Annotated[str, Path(title="The provider to initiate an OAuth flow for")],
provider: Annotated[
ProviderName, Path(title="The provider to initiate an OAuth flow for")
],
user_id: Annotated[str, Depends(get_user_id)],
request: Request,
scopes: Annotated[
@@ -74,7 +80,9 @@ class CredentialsMetaResponse(BaseModel):
@router.post("/{provider}/callback")
def callback(
provider: Annotated[str, Path(title="The target provider for this OAuth exchange")],
provider: Annotated[
ProviderName, Path(title="The target provider for this OAuth exchange")
],
code: Annotated[str, Body(title="Authorization code acquired by user login")],
state_token: Annotated[str, Body(title="Anti-CSRF nonce")],
user_id: Annotated[str, Depends(get_user_id)],
@@ -103,11 +111,12 @@ def callback(
if not set(scopes).issubset(set(credentials.scopes)):
# For now, we'll just log the warning and continue
logger.warning(
f"Granted scopes {credentials.scopes} for {provider}do not include all requested scopes {scopes}"
f"Granted scopes {credentials.scopes} for provider {provider.value} "
f"do not include all requested scopes {scopes}"
)
except Exception as e:
logger.error(f"Code->Token exchange failed for provider {provider}: {e}")
logger.error(f"Code->Token exchange failed for provider {provider.value}: {e}")
raise HTTPException(
status_code=400, detail=f"Failed to exchange code for tokens: {str(e)}"
)
@@ -116,7 +125,8 @@ def callback(
creds_manager.create(user_id, credentials)
logger.debug(
f"Successfully processed OAuth callback for user {user_id} and provider {provider}"
f"Successfully processed OAuth callback for user {user_id} "
f"and provider {provider.value}"
)
return CredentialsMetaResponse(
id=credentials.id,
@@ -148,7 +158,9 @@ def list_credentials(
@router.get("/{provider}/credentials")
def list_credentials_by_provider(
provider: Annotated[str, Path(title="The provider to list credentials for")],
provider: Annotated[
ProviderName, Path(title="The provider to list credentials for")
],
user_id: Annotated[str, Depends(get_user_id)],
) -> list[CredentialsMetaResponse]:
credentials = creds_manager.store.get_creds_by_provider(user_id, provider)
@@ -167,7 +179,9 @@ def list_credentials_by_provider(
@router.get("/{provider}/credentials/{cred_id}")
def get_credential(
provider: Annotated[str, Path(title="The provider to retrieve credentials for")],
provider: Annotated[
ProviderName, Path(title="The provider to retrieve credentials for")
],
cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
user_id: Annotated[str, Depends(get_user_id)],
) -> Credentials:
@@ -184,7 +198,9 @@ def get_credential(
@router.post("/{provider}/credentials", status_code=201)
def create_api_key_credentials(
user_id: Annotated[str, Depends(get_user_id)],
provider: Annotated[str, Path(title="The provider to create credentials for")],
provider: Annotated[
ProviderName, Path(title="The provider to create credentials for")
],
api_key: Annotated[str, Body(title="The API key to store")],
title: Annotated[str, Body(title="Optional title for the credentials")],
expires_at: Annotated[
@@ -225,7 +241,9 @@ class CredentialsDeletionNeedsConfirmationResponse(BaseModel):
@router.delete("/{provider}/credentials/{cred_id}")
async def delete_credentials(
request: Request,
provider: Annotated[str, Path(title="The provider to delete credentials for")],
provider: Annotated[
ProviderName, Path(title="The provider to delete credentials for")
],
cred_id: Annotated[str, Path(title="The ID of the credentials to delete")],
user_id: Annotated[str, Depends(get_user_id)],
force: Annotated[
@@ -264,15 +282,20 @@ async def delete_credentials(
@router.post("/{provider}/webhooks/{webhook_id}/ingress")
async def webhook_ingress_generic(
request: Request,
provider: Annotated[str, Path(title="Provider where the webhook was registered")],
provider: Annotated[
ProviderName, Path(title="Provider where the webhook was registered")
],
webhook_id: Annotated[str, Path(title="Our ID for the webhook")],
):
logger.debug(f"Received {provider} webhook ingress for ID {webhook_id}")
logger.debug(f"Received {provider.value} webhook ingress for ID {webhook_id}")
webhook_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
webhook = await get_webhook(webhook_id)
logger.debug(f"Webhook #{webhook_id}: {webhook}")
payload, event_type = await webhook_manager.validate_payload(webhook, request)
logger.debug(f"Validated {provider} {event_type} event with payload {payload}")
logger.debug(
f"Validated {provider.value} {webhook.webhook_type} {event_type} event "
f"with payload {payload}"
)
webhook_event = WebhookEvent(
provider=provider,
@@ -300,18 +323,28 @@ async def webhook_ingress_generic(
)
@router.post("/{provider}/webhooks/{webhook_id}/ping")
@router.post("/webhooks/{webhook_id}/ping")
async def webhook_ping(
provider: Annotated[str, Path(title="Provider where the webhook was registered")],
webhook_id: Annotated[str, Path(title="Our ID for the webhook")],
user_id: Annotated[str, Depends(get_user_id)], # require auth
):
webhook_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
webhook = await get_webhook(webhook_id)
webhook_manager = WEBHOOK_MANAGERS_BY_NAME[webhook.provider]()
await webhook_manager.trigger_ping(webhook)
if not await listen_for_webhook_event(webhook_id, event_type="ping"):
raise HTTPException(status_code=500, detail="Webhook ping event not received")
credentials = (
creds_manager.get(user_id, webhook.credentials_id)
if webhook.credentials_id
else None
)
try:
await webhook_manager.trigger_ping(webhook, credentials)
except NotImplementedError:
return False
if not await wait_for_webhook_event(webhook_id, event_type="ping", timeout=10):
raise HTTPException(status_code=504, detail="Webhook ping timed out")
return True
# --------------------------- UTILITIES ---------------------------- #
@@ -330,7 +363,15 @@ async def remove_all_webhooks_for_credentials(
Raises:
NeedConfirmation: If any of the webhooks are still in use and `force` is `False`
"""
webhooks = await get_all_webhooks(credentials.id)
webhooks = await get_all_webhooks_by_creds(credentials.id)
if credentials.provider not in WEBHOOK_MANAGERS_BY_NAME:
if webhooks:
logger.error(
f"Credentials #{credentials.id} for provider {credentials.provider} "
f"are attached to {len(webhooks)} webhooks, "
f"but there is no available WebhooksHandler for {credentials.provider}"
)
return
if any(w.attached_nodes for w in webhooks) and not force:
raise NeedConfirmation(
"Some webhooks linked to these credentials are still in use by an agent"
@@ -349,18 +390,23 @@ async def remove_all_webhooks_for_credentials(
logger.warning(f"Webhook #{webhook.id} failed to prune")
def _get_provider_oauth_handler(req: Request, provider_name: str) -> BaseOAuthHandler:
def _get_provider_oauth_handler(
req: Request, provider_name: ProviderName
) -> "BaseOAuthHandler":
if provider_name not in HANDLERS_BY_NAME:
raise HTTPException(
status_code=404, detail=f"Unknown provider '{provider_name}'"
status_code=404,
detail=f"Provider '{provider_name.value}' does not support OAuth",
)
client_id = getattr(settings.secrets, f"{provider_name}_client_id")
client_secret = getattr(settings.secrets, f"{provider_name}_client_secret")
client_id = getattr(settings.secrets, f"{provider_name.value}_client_id")
client_secret = getattr(settings.secrets, f"{provider_name.value}_client_secret")
if not (client_id and client_secret):
raise HTTPException(
status_code=501,
detail=f"Integration with provider '{provider_name}' is not configured",
detail=(
f"Integration with provider '{provider_name.value}' is not configured"
),
)
handler_class = HANDLERS_BY_NAME[provider_name]

View File

@@ -16,6 +16,8 @@ import backend.data.db
import backend.data.graph
import backend.data.user
import backend.server.routers.v1
import backend.server.v2.library.routes
import backend.server.v2.store.routes
import backend.util.service
import backend.util.settings
@@ -25,15 +27,26 @@ logger = logging.getLogger(__name__)
logging.getLogger("autogpt_libs").setLevel(logging.INFO)
@contextlib.contextmanager
def launch_darkly_context():
if settings.config.app_env != backend.util.settings.AppEnvironment.LOCAL:
initialize_launchdarkly()
try:
yield
finally:
shutdown_launchdarkly()
else:
yield
@contextlib.asynccontextmanager
async def lifespan_context(app: fastapi.FastAPI):
await backend.data.db.connect()
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()
initialize_launchdarkly()
yield
shutdown_launchdarkly()
with launch_darkly_context():
yield
await backend.data.db.disconnect()
@@ -73,7 +86,13 @@ def handle_internal_http_error(status_code: int = 500, log_error: bool = True):
app.add_exception_handler(ValueError, handle_internal_http_error(400))
app.add_exception_handler(Exception, handle_internal_http_error(500))
app.include_router(backend.server.routers.v1.v1_router, tags=["v1"])
app.include_router(backend.server.routers.v1.v1_router, tags=["v1"], prefix="/api")
app.include_router(
backend.server.v2.store.routes.router, tags=["v2"], prefix="/api/store"
)
app.include_router(
backend.server.v2.library.routes.router, tags=["v2"], prefix="/api/library"
)
@app.get(path="/health", tags=["health"], dependencies=[])
@@ -106,17 +125,17 @@ class AgentServer(backend.util.service.AppProcess):
async def test_create_graph(
create_graph: backend.server.routers.v1.CreateGraph,
user_id: str,
is_template=False,
):
return await backend.server.routers.v1.create_new_graph(create_graph, user_id)
@staticmethod
async def test_get_graph_run_status(
graph_id: str, graph_exec_id: str, user_id: str
):
return await backend.server.routers.v1.get_graph_run_status(
graph_id, graph_exec_id, user_id
async def test_get_graph_run_status(graph_exec_id: str, user_id: str):
execution = await backend.data.graph.get_execution(
user_id=user_id, execution_id=graph_exec_id
)
if not execution:
raise ValueError(f"Execution {graph_exec_id} not found")
return execution.status
@staticmethod
async def test_get_graph_run_node_execution_results(

View File

@@ -69,8 +69,7 @@ integration_creds_manager = IntegrationCredentialsManager()
_user_credit_model = get_user_credit_model()
# Define the API routes
v1_router = APIRouter(prefix="/api")
v1_router = APIRouter()
v1_router.include_router(
backend.server.integrations.router.router,
@@ -132,7 +131,7 @@ def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlockOutput
@v1_router.get(path="/credits", dependencies=[Depends(auth_middleware)])
async def get_user_credits(
user_id: Annotated[str, Depends(get_user_id)]
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[str, int]:
# Credits can go negative, so ensure it's at least 0 for user to see.
return {"credits": max(await _user_credit_model.get_or_refill_credit(user_id), 0)}
@@ -149,12 +148,9 @@ class DeleteGraphResponse(TypedDict):
@v1_router.get(path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)])
async def get_graphs(
user_id: Annotated[str, Depends(get_user_id)],
with_runs: bool = False,
) -> Sequence[graph_db.Graph]:
return await graph_db.get_graphs(
include_executions=with_runs, filter_by="active", user_id=user_id
)
user_id: Annotated[str, Depends(get_user_id)]
) -> Sequence[graph_db.GraphModel]:
return await graph_db.get_graphs(filter_by="active", user_id=user_id)
@v1_router.get(
@@ -170,9 +166,9 @@ async def get_graph(
user_id: Annotated[str, Depends(get_user_id)],
version: int | None = None,
hide_credentials: bool = False,
) -> graph_db.Graph:
) -> graph_db.GraphModel:
graph = await graph_db.get_graph(
graph_id, version, user_id=user_id, hide_credentials=hide_credentials
graph_id, version, user_id=user_id, for_export=hide_credentials
)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
@@ -191,7 +187,7 @@ async def get_graph(
)
async def get_graph_all_versions(
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> Sequence[graph_db.Graph]:
) -> Sequence[graph_db.GraphModel]:
graphs = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
if not graphs:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
@@ -203,7 +199,7 @@ async def get_graph_all_versions(
)
async def create_new_graph(
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.Graph:
) -> graph_db.GraphModel:
return await do_create_graph(create_graph, is_template=False, user_id=user_id)
@@ -213,7 +209,7 @@ async def do_create_graph(
# user_id doesn't have to be annotated like on other endpoints,
# because create_graph isn't used directly as an endpoint
user_id: str,
) -> graph_db.Graph:
) -> graph_db.GraphModel:
if create_graph.graph:
graph = graph_db.make_graph_model(create_graph.graph, user_id)
elif create_graph.template_id:
@@ -252,6 +248,13 @@ async def do_create_graph(
async def delete_graph(
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> DeleteGraphResponse:
if active_version := await graph_db.get_graph(graph_id, user_id=user_id):
def get_credentials(credentials_id: str) -> "Credentials | None":
return integration_creds_manager.get(user_id, credentials_id)
await on_graph_deactivate(active_version, get_credentials)
return {"version_counts": await graph_db.delete_graph(graph_id, user_id=user_id)}
@@ -267,7 +270,7 @@ async def update_graph(
graph_id: str,
graph: graph_db.Graph,
user_id: Annotated[str, Depends(get_user_id)],
) -> graph_db.Graph:
) -> graph_db.GraphModel:
# Sanity check
if graph.id and graph.id != graph_id:
raise HTTPException(400, detail="Graph ID does not match ID in URI")
@@ -386,7 +389,7 @@ def execute_graph(
async def stop_graph_run(
graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> Sequence[execution_db.ExecutionResult]:
if not await execution_db.get_graph_execution(graph_exec_id, user_id):
if not await graph_db.get_execution(user_id=user_id, execution_id=graph_exec_id):
raise HTTPException(404, detail=f"Agent execution #{graph_exec_id} not found")
await asyncio.to_thread(
@@ -398,23 +401,14 @@ async def stop_graph_run(
@v1_router.get(
path="/graphs/{graph_id}/executions",
path="/executions",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
async def list_graph_runs(
graph_id: str,
async def get_executions(
user_id: Annotated[str, Depends(get_user_id)],
graph_version: int | None = None,
) -> Sequence[str]:
graph = await graph_db.get_graph(graph_id, graph_version, user_id=user_id)
if not graph:
rev = "" if graph_version is None else f" v{graph_version}"
raise HTTPException(
status_code=404, detail=f"Agent #{graph_id}{rev} not found."
)
return await execution_db.list_executions(graph_id, graph_version)
) -> list[graph_db.GraphExecution]:
return await graph_db.get_executions(user_id=user_id)
@v1_router.get(
@@ -434,25 +428,6 @@ async def get_graph_run_node_execution_results(
return await execution_db.get_execution_results(graph_exec_id)
# NOTE: This is used for testing
async def get_graph_run_status(
graph_id: str,
graph_exec_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> execution_db.ExecutionStatus:
graph = await graph_db.get_graph(graph_id, user_id=user_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
execution = await execution_db.get_graph_execution(graph_exec_id, user_id)
if not execution:
raise HTTPException(
status_code=404, detail=f"Execution #{graph_exec_id} not found."
)
return execution.executionStatus
########################################################
##################### Templates ########################
########################################################
@@ -465,7 +440,7 @@ async def get_graph_run_status(
)
async def get_templates(
user_id: Annotated[str, Depends(get_user_id)]
) -> Sequence[graph_db.Graph]:
) -> Sequence[graph_db.GraphModel]:
return await graph_db.get_graphs(filter_by="template", user_id=user_id)
@@ -474,7 +449,9 @@ async def get_templates(
tags=["templates", "graphs"],
dependencies=[Depends(auth_middleware)],
)
async def get_template(graph_id: str, version: int | None = None) -> graph_db.Graph:
async def get_template(
graph_id: str, version: int | None = None
) -> graph_db.GraphModel:
graph = await graph_db.get_graph(graph_id, version, template=True)
if not graph:
raise HTTPException(status_code=404, detail=f"Template #{graph_id} not found.")
@@ -488,7 +465,7 @@ async def get_template(graph_id: str, version: int | None = None) -> graph_db.Gr
)
async def create_new_template(
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.Graph:
) -> graph_db.GraphModel:
return await do_create_graph(create_graph, is_template=True, user_id=user_id)

View File

@@ -0,0 +1,165 @@
import logging
from typing import List
import prisma.errors
import prisma.models
import prisma.types
import backend.data.graph
import backend.data.includes
import backend.server.v2.library.model
import backend.server.v2.store.exceptions
logger = logging.getLogger(__name__)
async def get_library_agents(
user_id: str,
) -> List[backend.server.v2.library.model.LibraryAgent]:
"""
Returns all agents (AgentGraph) that belong to the user and all agents in their library (UserAgent table)
"""
logger.debug(f"Getting library agents for user {user_id}")
try:
# Get agents created by user with nodes and links
user_created = await prisma.models.AgentGraph.prisma().find_many(
where=prisma.types.AgentGraphWhereInput(userId=user_id, isActive=True),
include=backend.data.includes.AGENT_GRAPH_INCLUDE,
)
# Get agents in user's library with nodes and links
library_agents = await prisma.models.UserAgent.prisma().find_many(
where=prisma.types.UserAgentWhereInput(
userId=user_id, isDeleted=False, isArchived=False
),
include={
"Agent": {
"include": {
"AgentNodes": {
"include": {
"Input": True,
"Output": True,
"Webhook": True,
"AgentBlock": True,
}
}
}
}
},
)
# Convert to Graph models first
graphs = []
# Add user created agents
for agent in user_created:
try:
graphs.append(backend.data.graph.GraphModel.from_db(agent))
except Exception as e:
logger.error(f"Error processing user created agent {agent.id}: {e}")
continue
# Add library agents
for agent in library_agents:
if agent.Agent:
try:
graphs.append(backend.data.graph.GraphModel.from_db(agent.Agent))
except Exception as e:
logger.error(f"Error processing library agent {agent.agentId}: {e}")
continue
# Convert Graph models to LibraryAgent models
result = []
for graph in graphs:
result.append(
backend.server.v2.library.model.LibraryAgent(
id=graph.id,
version=graph.version,
is_active=graph.is_active,
name=graph.name,
description=graph.description,
isCreatedByUser=any(a.id == graph.id for a in user_created),
input_schema=graph.input_schema,
output_schema=graph.output_schema,
)
)
logger.debug(f"Found {len(result)} library agents")
return result
except prisma.errors.PrismaError as e:
logger.error(f"Database error getting library agents: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch library agents"
) from e
async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> None:
"""
Finds the agent from the store listing version and adds it to the user's library (UserAgent table)
if they don't already have it
"""
logger.debug(
f"Adding agent from store listing version {store_listing_version_id} to library for user {user_id}"
)
try:
# Get store listing version to find agent
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id}, include={"Agent": True}
)
)
if not store_listing_version or not store_listing_version.Agent:
logger.warning(
f"Store listing version not found: {store_listing_version_id}"
)
raise backend.server.v2.store.exceptions.AgentNotFoundError(
f"Store listing version {store_listing_version_id} not found"
)
agent = store_listing_version.Agent
if agent.userId == user_id:
logger.warning(
f"User {user_id} cannot add their own agent to their library"
)
raise backend.server.v2.store.exceptions.DatabaseError(
"Cannot add own agent to library"
)
# Check if user already has this agent
existing_user_agent = await prisma.models.UserAgent.prisma().find_first(
where={
"userId": user_id,
"agentId": agent.id,
"agentVersion": agent.version,
}
)
if existing_user_agent:
logger.debug(
f"User {user_id} already has agent {agent.id} in their library"
)
return
# Create UserAgent entry
await prisma.models.UserAgent.prisma().create(
data=prisma.types.UserAgentCreateInput(
userId=user_id,
agentId=agent.id,
agentVersion=agent.version,
isCreatedByUser=False,
)
)
logger.debug(f"Added agent {agent.id} to library for user {user_id}")
except backend.server.v2.store.exceptions.AgentNotFoundError:
raise
except prisma.errors.PrismaError as e:
logger.error(f"Database error adding agent to library: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to add agent to library"
) from e

View File

@@ -0,0 +1,197 @@
from datetime import datetime
import prisma.errors
import prisma.models
import pytest
from prisma import Prisma
import backend.data.includes
import backend.server.v2.library.db as db
import backend.server.v2.store.exceptions
@pytest.fixture(autouse=True)
async def setup_prisma():
# Don't register client if already registered
try:
Prisma()
except prisma.errors.ClientAlreadyRegisteredError:
pass
yield
@pytest.mark.asyncio
async def test_get_library_agents(mocker):
# Mock data
mock_user_created = [
prisma.models.AgentGraph(
id="agent1",
version=1,
name="Test Agent 1",
description="Test Description 1",
userId="test-user",
isActive=True,
createdAt=datetime.now(),
isTemplate=False,
)
]
mock_library_agents = [
prisma.models.UserAgent(
id="ua1",
userId="test-user",
agentId="agent2",
agentVersion=1,
isCreatedByUser=False,
isDeleted=False,
isArchived=False,
createdAt=datetime.now(),
updatedAt=datetime.now(),
isFavorite=False,
Agent=prisma.models.AgentGraph(
id="agent2",
version=1,
name="Test Agent 2",
description="Test Description 2",
userId="other-user",
isActive=True,
createdAt=datetime.now(),
isTemplate=False,
),
)
]
# Mock prisma calls
mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
mock_agent_graph.return_value.find_many = mocker.AsyncMock(
return_value=mock_user_created
)
mock_user_agent = mocker.patch("prisma.models.UserAgent.prisma")
mock_user_agent.return_value.find_many = mocker.AsyncMock(
return_value=mock_library_agents
)
# Call function
result = await db.get_library_agents("test-user")
# Verify results
assert len(result) == 2
assert result[0].id == "agent1"
assert result[0].name == "Test Agent 1"
assert result[0].description == "Test Description 1"
assert result[0].isCreatedByUser is True
assert result[1].id == "agent2"
assert result[1].name == "Test Agent 2"
assert result[1].description == "Test Description 2"
assert result[1].isCreatedByUser is False
# Verify mocks called correctly
mock_agent_graph.return_value.find_many.assert_called_once_with(
where=prisma.types.AgentGraphWhereInput(userId="test-user", isActive=True),
include=backend.data.includes.AGENT_GRAPH_INCLUDE,
)
mock_user_agent.return_value.find_many.assert_called_once_with(
where=prisma.types.UserAgentWhereInput(
userId="test-user", isDeleted=False, isArchived=False
),
include={
"Agent": {
"include": {
"AgentNodes": {
"include": {
"Input": True,
"Output": True,
"Webhook": True,
"AgentBlock": True,
}
}
}
}
},
)
@pytest.mark.asyncio
async def test_add_agent_to_library(mocker):
# Mock data
mock_store_listing = prisma.models.StoreListingVersion(
id="version123",
version=1,
createdAt=datetime.now(),
updatedAt=datetime.now(),
agentId="agent1",
agentVersion=1,
slug="test-agent",
name="Test Agent",
subHeading="Test Agent Subheading",
imageUrls=["https://example.com/image.jpg"],
description="Test Description",
categories=["test"],
isFeatured=False,
isDeleted=False,
isAvailable=True,
isApproved=True,
Agent=prisma.models.AgentGraph(
id="agent1",
version=1,
name="Test Agent",
description="Test Description",
userId="creator",
isActive=True,
createdAt=datetime.now(),
isTemplate=False,
),
)
# Mock prisma calls
mock_store_listing_version = mocker.patch(
"prisma.models.StoreListingVersion.prisma"
)
mock_store_listing_version.return_value.find_unique = mocker.AsyncMock(
return_value=mock_store_listing
)
mock_user_agent = mocker.patch("prisma.models.UserAgent.prisma")
mock_user_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_user_agent.return_value.create = mocker.AsyncMock()
# Call function
await db.add_agent_to_library("version123", "test-user")
# Verify mocks called correctly
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
where={"id": "version123"}, include={"Agent": True}
)
mock_user_agent.return_value.find_first.assert_called_once_with(
where={
"userId": "test-user",
"agentId": "agent1",
"agentVersion": 1,
}
)
mock_user_agent.return_value.create.assert_called_once_with(
data=prisma.types.UserAgentCreateInput(
userId="test-user", agentId="agent1", agentVersion=1, isCreatedByUser=False
)
)
@pytest.mark.asyncio
async def test_add_agent_to_library_not_found(mocker):
# Mock prisma calls
mock_store_listing_version = mocker.patch(
"prisma.models.StoreListingVersion.prisma"
)
mock_store_listing_version.return_value.find_unique = mocker.AsyncMock(
return_value=None
)
# Call function and verify exception
with pytest.raises(backend.server.v2.store.exceptions.AgentNotFoundError):
await db.add_agent_to_library("version123", "test-user")
# Verify mock called correctly
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
where={"id": "version123"}, include={"Agent": True}
)

View File

@@ -0,0 +1,16 @@
import typing
import pydantic
class LibraryAgent(pydantic.BaseModel):
id: str # Changed from agent_id to match GraphMeta
version: int # Changed from agent_version to match GraphMeta
is_active: bool # Added to match GraphMeta
name: str
description: str
isCreatedByUser: bool
# Made input_schema and output_schema match GraphMeta's type
input_schema: dict[str, typing.Any] # Should be BlockIOObjectSubSchema in frontend
output_schema: dict[str, typing.Any] # Should be BlockIOObjectSubSchema in frontend

View File

@@ -0,0 +1,43 @@
import backend.server.v2.library.model
def test_library_agent():
agent = backend.server.v2.library.model.LibraryAgent(
id="test-agent-123",
version=1,
is_active=True,
name="Test Agent",
description="Test description",
isCreatedByUser=False,
input_schema={"type": "object", "properties": {}},
output_schema={"type": "object", "properties": {}},
)
assert agent.id == "test-agent-123"
assert agent.version == 1
assert agent.is_active is True
assert agent.name == "Test Agent"
assert agent.description == "Test description"
assert agent.isCreatedByUser is False
assert agent.input_schema == {"type": "object", "properties": {}}
assert agent.output_schema == {"type": "object", "properties": {}}
def test_library_agent_with_user_created():
agent = backend.server.v2.library.model.LibraryAgent(
id="user-agent-456",
version=2,
is_active=True,
name="User Created Agent",
description="An agent created by the user",
isCreatedByUser=True,
input_schema={"type": "object", "properties": {}},
output_schema={"type": "object", "properties": {}},
)
assert agent.id == "user-agent-456"
assert agent.version == 2
assert agent.is_active is True
assert agent.name == "User Created Agent"
assert agent.description == "An agent created by the user"
assert agent.isCreatedByUser is True
assert agent.input_schema == {"type": "object", "properties": {}}
assert agent.output_schema == {"type": "object", "properties": {}}

View File

@@ -0,0 +1,74 @@
import logging
import typing
import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import fastapi
import backend.data.graph
import backend.server.v2.library.db
import backend.server.v2.library.model
logger = logging.getLogger(__name__)
router = fastapi.APIRouter()
@router.get(
"/agents",
tags=["library", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def get_library_agents(
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
]
) -> typing.Sequence[backend.server.v2.library.model.LibraryAgent]:
"""
Get all agents in the user's library, including both created and saved agents.
"""
try:
agents = await backend.server.v2.library.db.get_library_agents(user_id)
return agents
except Exception:
logger.exception("Exception occurred whilst getting library agents")
raise fastapi.HTTPException(
status_code=500, detail="Failed to get library agents"
)
@router.post(
"/agents/{store_listing_version_id}",
tags=["library", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
status_code=201,
)
async def add_agent_to_library(
store_listing_version_id: str,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
) -> fastapi.Response:
"""
Add an agent from the store to the user's library.
Args:
store_listing_version_id (str): ID of the store listing version to add
user_id (str): ID of the authenticated user
Returns:
fastapi.Response: 201 status code on success
Raises:
HTTPException: If there is an error adding the agent to the library
"""
try:
await backend.server.v2.library.db.add_agent_to_library(
store_listing_version_id=store_listing_version_id, user_id=user_id
)
return fastapi.Response(status_code=201)
except Exception:
logger.exception("Exception occurred whilst adding agent to library")
raise fastapi.HTTPException(
status_code=500, detail="Failed to add agent to library"
)

View File

@@ -0,0 +1,103 @@
import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import fastapi
import fastapi.testclient
import pytest_mock
import backend.server.v2.library.db
import backend.server.v2.library.model
import backend.server.v2.library.routes
app = fastapi.FastAPI()
app.include_router(backend.server.v2.library.routes.router)
client = fastapi.testclient.TestClient(app)
def override_auth_middleware():
"""Override auth middleware for testing"""
return {"sub": "test-user-id"}
def override_get_user_id():
"""Override get_user_id for testing"""
return "test-user-id"
app.dependency_overrides[autogpt_libs.auth.middleware.auth_middleware] = (
override_auth_middleware
)
app.dependency_overrides[autogpt_libs.auth.depends.get_user_id] = override_get_user_id
def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
mocked_value = [
backend.server.v2.library.model.LibraryAgent(
id="test-agent-1",
version=1,
is_active=True,
name="Test Agent 1",
description="Test Description 1",
isCreatedByUser=True,
input_schema={"type": "object", "properties": {}},
output_schema={"type": "object", "properties": {}},
),
backend.server.v2.library.model.LibraryAgent(
id="test-agent-2",
version=1,
is_active=True,
name="Test Agent 2",
description="Test Description 2",
isCreatedByUser=False,
input_schema={"type": "object", "properties": {}},
output_schema={"type": "object", "properties": {}},
),
]
mock_db_call = mocker.patch("backend.server.v2.library.db.get_library_agents")
mock_db_call.return_value = mocked_value
response = client.get("/agents")
assert response.status_code == 200
data = [
backend.server.v2.library.model.LibraryAgent.model_validate(agent)
for agent in response.json()
]
assert len(data) == 2
assert data[0].id == "test-agent-1"
assert data[0].isCreatedByUser is True
assert data[1].id == "test-agent-2"
assert data[1].isCreatedByUser is False
mock_db_call.assert_called_once_with("test-user-id")
def test_get_library_agents_error(mocker: pytest_mock.MockFixture):
mock_db_call = mocker.patch("backend.server.v2.library.db.get_library_agents")
mock_db_call.side_effect = Exception("Test error")
response = client.get("/agents")
assert response.status_code == 500
mock_db_call.assert_called_once_with("test-user-id")
def test_add_agent_to_library_success(mocker: pytest_mock.MockFixture):
mock_db_call = mocker.patch("backend.server.v2.library.db.add_agent_to_library")
mock_db_call.return_value = None
response = client.post("/agents/test-version-id")
assert response.status_code == 201
mock_db_call.assert_called_once_with(
store_listing_version_id="test-version-id", user_id="test-user-id"
)
def test_add_agent_to_library_error(mocker: pytest_mock.MockFixture):
mock_db_call = mocker.patch("backend.server.v2.library.db.add_agent_to_library")
mock_db_call.side_effect = Exception("Test error")
response = client.post("/agents/test-version-id")
assert response.status_code == 500
assert response.json()["detail"] == "Failed to add agent to library"
mock_db_call.assert_called_once_with(
store_listing_version_id="test-version-id", user_id="test-user-id"
)

View File

@@ -0,0 +1,53 @@
# Store Module
This module implements the backend API for the AutoGPT Store, handling agents, creators, profiles, submissions and media uploads.
## Files
### routes.py
Contains the FastAPI route handlers for the store API endpoints:
- Profile endpoints for managing user profiles
- Agent endpoints for browsing and retrieving store agents
- Creator endpoints for browsing and retrieving creator details
- Store submission endpoints for submitting agents to the store
- Media upload endpoints for submission images/videos
### model.py
Contains Pydantic models for request/response validation and serialization:
- Pagination model for paginated responses
- Models for agents, creators, profiles, submissions
- Request/response models for all API endpoints
### db.py
Contains database access functions using Prisma ORM:
- Functions to query and manipulate store data
- Handles database operations for all API endpoints
- Implements business logic and data validation
### media.py
Handles media file uploads to Google Cloud Storage:
- Validates file types and sizes
- Processes image and video uploads
- Stores files in GCS buckets
- Returns public URLs for uploaded media
## Key Features
- Paginated listings of store agents and creators
- Search and filtering of agents and creators
- Agent submission workflow
- Media file upload handling
- Profile management
- Reviews and ratings
## Authentication
Most endpoints require authentication via the AutoGPT auth middleware. Public endpoints are marked with the "public" tag.
## Error Handling
All database and storage operations include proper error handling and logging. Errors are mapped to appropriate HTTP status codes.

View File

@@ -0,0 +1,765 @@
import logging
import random
from datetime import datetime
import prisma.enums
import prisma.errors
import prisma.models
import prisma.types
import backend.server.v2.store.exceptions
import backend.server.v2.store.model
logger = logging.getLogger(__name__)
async def get_store_agents(
featured: bool = False,
creator: str | None = None,
sorted_by: str | None = None,
search_query: str | None = None,
category: str | None = None,
page: int = 1,
page_size: int = 20,
) -> backend.server.v2.store.model.StoreAgentsResponse:
logger.debug(
f"Getting store agents. featured={featured}, creator={creator}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
)
sanitized_query = None
# Sanitize and validate search query by escaping special characters
if search_query is not None:
sanitized_query = search_query.strip()
if not sanitized_query or len(sanitized_query) > 100: # Reasonable length limit
raise backend.server.v2.store.exceptions.DatabaseError(
"Invalid search query"
)
# Escape special SQL characters
sanitized_query = (
sanitized_query.replace("\\", "\\\\")
.replace("%", "\\%")
.replace("_", "\\_")
.replace("[", "\\[")
.replace("]", "\\]")
.replace("'", "\\'")
.replace('"', '\\"')
.replace(";", "\\;")
.replace("--", "\\--")
.replace("/*", "\\/*")
.replace("*/", "\\*/")
)
where_clause = {}
if featured:
where_clause["featured"] = featured
if creator:
where_clause["creator_username"] = creator
if category:
where_clause["categories"] = {"has": category}
if sanitized_query:
where_clause["OR"] = [
{"agent_name": {"contains": sanitized_query, "mode": "insensitive"}},
{"description": {"contains": sanitized_query, "mode": "insensitive"}},
]
order_by = []
if sorted_by == "rating":
order_by.append({"rating": "desc"})
elif sorted_by == "runs":
order_by.append({"runs": "desc"})
elif sorted_by == "name":
order_by.append({"agent_name": "asc"})
try:
agents = await prisma.models.StoreAgent.prisma().find_many(
where=prisma.types.StoreAgentWhereInput(**where_clause),
order=order_by,
skip=(page - 1) * page_size,
take=page_size,
)
total = await prisma.models.StoreAgent.prisma().count(
where=prisma.types.StoreAgentWhereInput(**where_clause)
)
total_pages = (total + page_size - 1) // page_size
store_agents = [
backend.server.v2.store.model.StoreAgent(
slug=agent.slug,
agent_name=agent.agent_name,
agent_image=agent.agent_image[0] if agent.agent_image else "",
creator=agent.creator_username,
creator_avatar=agent.creator_avatar,
sub_heading=agent.sub_heading,
description=agent.description,
runs=agent.runs,
rating=agent.rating,
)
for agent in agents
]
logger.debug(f"Found {len(store_agents)} agents")
return backend.server.v2.store.model.StoreAgentsResponse(
agents=store_agents,
pagination=backend.server.v2.store.model.Pagination(
current_page=page,
total_items=total,
total_pages=total_pages,
page_size=page_size,
),
)
except Exception as e:
logger.error(f"Error getting store agents: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch store agents"
) from e
async def get_store_agent_details(
username: str, agent_name: str
) -> backend.server.v2.store.model.StoreAgentDetails:
logger.debug(f"Getting store agent details for {username}/{agent_name}")
try:
agent = await prisma.models.StoreAgent.prisma().find_first(
where={"creator_username": username, "slug": agent_name}
)
if not agent:
logger.warning(f"Agent not found: {username}/{agent_name}")
raise backend.server.v2.store.exceptions.AgentNotFoundError(
f"Agent {username}/{agent_name} not found"
)
logger.debug(f"Found agent details for {username}/{agent_name}")
return backend.server.v2.store.model.StoreAgentDetails(
store_listing_version_id=agent.storeListingVersionId,
slug=agent.slug,
agent_name=agent.agent_name,
agent_video=agent.agent_video or "",
agent_image=agent.agent_image,
creator=agent.creator_username,
creator_avatar=agent.creator_avatar,
sub_heading=agent.sub_heading,
description=agent.description,
categories=agent.categories,
runs=agent.runs,
rating=agent.rating,
versions=agent.versions,
last_updated=agent.updated_at,
)
except backend.server.v2.store.exceptions.AgentNotFoundError:
raise
except Exception as e:
logger.error(f"Error getting store agent details: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch agent details"
) from e
async def get_store_creators(
featured: bool = False,
search_query: str | None = None,
sorted_by: str | None = None,
page: int = 1,
page_size: int = 20,
) -> backend.server.v2.store.model.CreatorsResponse:
logger.debug(
f"Getting store creators. featured={featured}, search={search_query}, sorted_by={sorted_by}, page={page}"
)
# Build where clause with sanitized inputs
where = {}
if featured:
where["is_featured"] = featured
# Add search filter if provided, using parameterized queries
if search_query:
# Sanitize and validate search query by escaping special characters
sanitized_query = search_query.strip()
if not sanitized_query or len(sanitized_query) > 100: # Reasonable length limit
raise backend.server.v2.store.exceptions.DatabaseError(
"Invalid search query"
)
# Escape special SQL characters
sanitized_query = (
sanitized_query.replace("\\", "\\\\")
.replace("%", "\\%")
.replace("_", "\\_")
.replace("[", "\\[")
.replace("]", "\\]")
.replace("'", "\\'")
.replace('"', '\\"')
.replace(";", "\\;")
.replace("--", "\\--")
.replace("/*", "\\/*")
.replace("*/", "\\*/")
)
where["OR"] = [
{"username": {"contains": sanitized_query, "mode": "insensitive"}},
{"name": {"contains": sanitized_query, "mode": "insensitive"}},
{"description": {"contains": sanitized_query, "mode": "insensitive"}},
]
try:
# Validate pagination parameters
if not isinstance(page, int) or page < 1:
raise backend.server.v2.store.exceptions.DatabaseError(
"Invalid page number"
)
if not isinstance(page_size, int) or page_size < 1 or page_size > 100:
raise backend.server.v2.store.exceptions.DatabaseError("Invalid page size")
# Get total count for pagination using sanitized where clause
total = await prisma.models.Creator.prisma().count(
where=prisma.types.CreatorWhereInput(**where)
)
total_pages = (total + page_size - 1) // page_size
# Add pagination with validated parameters
skip = (page - 1) * page_size
take = page_size
# Add sorting with validated sort parameter
order = []
valid_sort_fields = {"agent_rating", "agent_runs", "num_agents"}
if sorted_by in valid_sort_fields:
order.append({sorted_by: "desc"})
else:
order.append({"username": "asc"})
# Execute query with sanitized parameters
creators = await prisma.models.Creator.prisma().find_many(
where=prisma.types.CreatorWhereInput(**where),
skip=skip,
take=take,
order=order,
)
# Convert to response model
creator_models = [
backend.server.v2.store.model.Creator(
username=creator.username,
name=creator.name,
description=creator.description,
avatar_url=creator.avatar_url,
num_agents=creator.num_agents,
agent_rating=creator.agent_rating,
agent_runs=creator.agent_runs,
is_featured=creator.is_featured,
)
for creator in creators
]
logger.debug(f"Found {len(creator_models)} creators")
return backend.server.v2.store.model.CreatorsResponse(
creators=creator_models,
pagination=backend.server.v2.store.model.Pagination(
current_page=page,
total_items=total,
total_pages=total_pages,
page_size=page_size,
),
)
except Exception as e:
logger.error(f"Error getting store creators: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch store creators"
) from e
async def get_store_creator_details(
username: str,
) -> backend.server.v2.store.model.CreatorDetails:
logger.debug(f"Getting store creator details for {username}")
try:
# Query creator details from database
creator = await prisma.models.Creator.prisma().find_unique(
where={"username": username}
)
if not creator:
logger.warning(f"Creator not found: {username}")
raise backend.server.v2.store.exceptions.CreatorNotFoundError(
f"Creator {username} not found"
)
logger.debug(f"Found creator details for {username}")
return backend.server.v2.store.model.CreatorDetails(
name=creator.name,
username=creator.username,
description=creator.description,
links=creator.links,
avatar_url=creator.avatar_url,
agent_rating=creator.agent_rating,
agent_runs=creator.agent_runs,
top_categories=creator.top_categories,
)
except backend.server.v2.store.exceptions.CreatorNotFoundError:
raise
except Exception as e:
logger.error(f"Error getting store creator details: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch creator details"
) from e
async def get_store_submissions(
user_id: str, page: int = 1, page_size: int = 20
) -> backend.server.v2.store.model.StoreSubmissionsResponse:
logger.debug(f"Getting store submissions for user {user_id}, page={page}")
try:
# Calculate pagination values
skip = (page - 1) * page_size
where = prisma.types.StoreSubmissionWhereInput(user_id=user_id)
# Query submissions from database
submissions = await prisma.models.StoreSubmission.prisma().find_many(
where=where, skip=skip, take=page_size, order=[{"date_submitted": "desc"}]
)
# Get total count for pagination
total = await prisma.models.StoreSubmission.prisma().count(where=where)
total_pages = (total + page_size - 1) // page_size
# Convert to response models
submission_models = [
backend.server.v2.store.model.StoreSubmission(
agent_id=sub.agent_id,
agent_version=sub.agent_version,
name=sub.name,
sub_heading=sub.sub_heading,
slug=sub.slug,
description=sub.description,
image_urls=sub.image_urls or [],
date_submitted=sub.date_submitted or datetime.now(),
status=sub.status,
runs=sub.runs or 0,
rating=sub.rating or 0.0,
)
for sub in submissions
]
logger.debug(f"Found {len(submission_models)} submissions")
return backend.server.v2.store.model.StoreSubmissionsResponse(
submissions=submission_models,
pagination=backend.server.v2.store.model.Pagination(
current_page=page,
total_items=total,
total_pages=total_pages,
page_size=page_size,
),
)
except Exception as e:
logger.error(f"Error fetching store submissions: {str(e)}")
# Return empty response rather than exposing internal errors
return backend.server.v2.store.model.StoreSubmissionsResponse(
submissions=[],
pagination=backend.server.v2.store.model.Pagination(
current_page=page,
total_items=0,
total_pages=0,
page_size=page_size,
),
)
async def delete_store_submission(
user_id: str,
submission_id: str,
) -> bool:
"""
Delete a store listing submission.
Args:
user_id: ID of the authenticated user
submission_id: ID of the submission to be deleted
Returns:
bool: True if the submission was successfully deleted, False otherwise
"""
logger.debug(f"Deleting store submission {submission_id} for user {user_id}")
try:
# Verify the submission belongs to this user
submission = await prisma.models.StoreListing.prisma().find_first(
where={"agentId": submission_id, "owningUserId": user_id}
)
if not submission:
logger.warning(f"Submission not found for user {user_id}: {submission_id}")
raise backend.server.v2.store.exceptions.SubmissionNotFoundError(
f"Submission not found for this user. User ID: {user_id}, Submission ID: {submission_id}"
)
# Delete the submission
await prisma.models.StoreListing.prisma().delete(
where=prisma.types.StoreListingWhereUniqueInput(id=submission.id)
)
logger.debug(
f"Successfully deleted submission {submission_id} for user {user_id}"
)
return True
except Exception as e:
logger.error(f"Error deleting store submission: {str(e)}")
return False
async def create_store_submission(
user_id: str,
agent_id: str,
agent_version: int,
slug: str,
name: str,
video_url: str | None = None,
image_urls: list[str] = [],
description: str = "",
sub_heading: str = "",
categories: list[str] = [],
) -> backend.server.v2.store.model.StoreSubmission:
"""
Create a new store listing submission.
Args:
user_id: ID of the authenticated user submitting the listing
agent_id: ID of the agent being submitted
agent_version: Version of the agent being submitted
slug: URL slug for the listing
name: Name of the agent
video_url: Optional URL to video demo
image_urls: List of image URLs for the listing
description: Description of the agent
categories: List of categories for the agent
Returns:
StoreSubmission: The created store submission
"""
logger.debug(
f"Creating store submission for user {user_id}, agent {agent_id} v{agent_version}"
)
try:
# First verify the agent belongs to this user
agent = await prisma.models.AgentGraph.prisma().find_first(
where=prisma.types.AgentGraphWhereInput(
id=agent_id, version=agent_version, userId=user_id
)
)
if not agent:
logger.warning(
f"Agent not found for user {user_id}: {agent_id} v{agent_version}"
)
raise backend.server.v2.store.exceptions.AgentNotFoundError(
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
)
listing = await prisma.models.StoreListing.prisma().find_first(
where=prisma.types.StoreListingWhereInput(
agentId=agent_id, owningUserId=user_id
)
)
if listing is not None:
logger.warning(f"Listing already exists for agent {agent_id}")
raise backend.server.v2.store.exceptions.ListingExistsError(
"Listing already exists for this agent"
)
# Create the store listing
listing = await prisma.models.StoreListing.prisma().create(
data={
"agentId": agent_id,
"agentVersion": agent_version,
"owningUserId": user_id,
"createdAt": datetime.now(),
"StoreListingVersions": {
"create": {
"agentId": agent_id,
"agentVersion": agent_version,
"slug": slug,
"name": name,
"videoUrl": video_url,
"imageUrls": image_urls,
"description": description,
"categories": categories,
"subHeading": sub_heading,
}
},
}
)
logger.debug(f"Created store listing for agent {agent_id}")
# Return submission details
return backend.server.v2.store.model.StoreSubmission(
agent_id=agent_id,
agent_version=agent_version,
name=name,
slug=slug,
sub_heading=sub_heading,
description=description,
image_urls=image_urls,
date_submitted=listing.createdAt,
status=prisma.enums.SubmissionStatus.PENDING,
runs=0,
rating=0.0,
)
except (
backend.server.v2.store.exceptions.AgentNotFoundError,
backend.server.v2.store.exceptions.ListingExistsError,
):
raise
except prisma.errors.PrismaError as e:
logger.error(f"Database error creating store submission: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to create store submission"
) from e
async def create_store_review(
user_id: str,
store_listing_version_id: str,
score: int,
comments: str | None = None,
) -> backend.server.v2.store.model.StoreReview:
try:
review = await prisma.models.StoreListingReview.prisma().upsert(
where={
"storeListingVersionId_reviewByUserId": {
"storeListingVersionId": store_listing_version_id,
"reviewByUserId": user_id,
}
},
data={
"create": {
"reviewByUserId": user_id,
"storeListingVersionId": store_listing_version_id,
"score": score,
"comments": comments,
},
"update": {
"score": score,
"comments": comments,
},
},
)
return backend.server.v2.store.model.StoreReview(
score=review.score,
comments=review.comments,
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error creating store review: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to create store review"
) from e
async def get_user_profile(
user_id: str,
) -> backend.server.v2.store.model.ProfileDetails:
logger.debug(f"Getting user profile for {user_id}")
try:
profile = await prisma.models.Profile.prisma().find_first(
where={"userId": user_id} # type: ignore
)
if not profile:
logger.warning(f"Profile not found for user {user_id}")
await prisma.models.Profile.prisma().create(
data=prisma.types.ProfileCreateInput(
userId=user_id,
name="No Profile Data",
username=f"{random.choice(['happy', 'clever', 'swift', 'bright', 'wise'])}-{random.choice(['fox', 'wolf', 'bear', 'eagle', 'owl'])}_{random.randint(1000,9999)}",
description="No Profile Data",
links=[],
avatarUrl="",
)
)
return backend.server.v2.store.model.ProfileDetails(
name="No Profile Data",
username="No Profile Data",
description="No Profile Data",
links=[],
avatar_url="",
)
return backend.server.v2.store.model.ProfileDetails(
name=profile.name,
username=profile.username,
description=profile.description,
links=profile.links,
avatar_url=profile.avatarUrl,
)
except Exception as e:
logger.error(f"Error getting user profile: {str(e)}")
return backend.server.v2.store.model.ProfileDetails(
name="No Profile Data",
username="No Profile Data",
description="No Profile Data",
links=[],
avatar_url="",
)
async def update_or_create_profile(
user_id: str, profile: backend.server.v2.store.model.Profile
) -> backend.server.v2.store.model.CreatorDetails:
"""
Update the store profile for a user. Creates a new profile if one doesn't exist.
Only allows updating if the user_id matches the owning user.
Args:
user_id: ID of the authenticated user
profile: Updated profile details
Returns:
CreatorDetails: The updated profile
Raises:
HTTPException: If user is not authorized to update this profile
"""
logger.debug(f"Updating profile for user {user_id}")
try:
# Check if profile exists for user
existing_profile = await prisma.models.Profile.prisma().find_first(
where={"userId": user_id}
)
# If no profile exists, create a new one
if not existing_profile:
logger.debug(f"Creating new profile for user {user_id}")
# Create new profile since one doesn't exist
new_profile = await prisma.models.Profile.prisma().create(
data={
"userId": user_id,
"name": profile.name,
"username": profile.username,
"description": profile.description,
"links": profile.links,
"avatarUrl": profile.avatar_url,
}
)
return backend.server.v2.store.model.CreatorDetails(
name=new_profile.name,
username=new_profile.username,
description=new_profile.description,
links=new_profile.links,
avatar_url=new_profile.avatarUrl or "",
agent_rating=0.0,
agent_runs=0,
top_categories=[],
)
else:
logger.debug(f"Updating existing profile for user {user_id}")
# Update the existing profile
updated_profile = await prisma.models.Profile.prisma().update(
where={"id": existing_profile.id},
data=prisma.types.ProfileUpdateInput(
name=profile.name,
username=profile.username,
description=profile.description,
links=profile.links,
avatarUrl=profile.avatar_url,
),
)
if updated_profile is None:
logger.error(f"Failed to update profile for user {user_id}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to update profile"
)
return backend.server.v2.store.model.CreatorDetails(
name=updated_profile.name,
username=updated_profile.username,
description=updated_profile.description,
links=updated_profile.links,
avatar_url=updated_profile.avatarUrl or "",
agent_rating=0.0,
agent_runs=0,
top_categories=[],
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error updating profile: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to update profile"
) from e
async def get_my_agents(
user_id: str,
page: int = 1,
page_size: int = 20,
) -> backend.server.v2.store.model.MyAgentsResponse:
logger.debug(f"Getting my agents for user {user_id}, page={page}")
try:
agents_with_max_version = await prisma.models.AgentGraph.prisma().find_many(
where=prisma.types.AgentGraphWhereInput(
userId=user_id, StoreListing={"none": {"isDeleted": False}}
),
order=[{"version": "desc"}],
distinct=["id"],
skip=(page - 1) * page_size,
take=page_size,
)
# store_listings = await prisma.models.StoreListing.prisma().find_many(
# where=prisma.types.StoreListingWhereInput(
# isDeleted=False,
# ),
# )
total = len(
await prisma.models.AgentGraph.prisma().find_many(
where=prisma.types.AgentGraphWhereInput(
userId=user_id, StoreListing={"none": {"isDeleted": False}}
),
order=[{"version": "desc"}],
distinct=["id"],
)
)
total_pages = (total + page_size - 1) // page_size
agents = agents_with_max_version
my_agents = [
backend.server.v2.store.model.MyAgent(
agent_id=agent.id,
agent_version=agent.version,
agent_name=agent.name or "",
last_edited=agent.updatedAt or agent.createdAt,
)
for agent in agents
]
return backend.server.v2.store.model.MyAgentsResponse(
agents=my_agents,
pagination=backend.server.v2.store.model.Pagination(
current_page=page,
total_items=total,
total_pages=total_pages,
page_size=page_size,
),
)
except Exception as e:
logger.error(f"Error getting my agents: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch my agents"
) from e

View File

@@ -0,0 +1,264 @@
from datetime import datetime
import prisma.errors
import prisma.models
import pytest
from prisma import Prisma
import backend.server.v2.store.db as db
from backend.server.v2.store.model import Profile
@pytest.fixture(autouse=True)
async def setup_prisma():
# Don't register client if already registered
try:
Prisma()
except prisma.errors.ClientAlreadyRegisteredError:
pass
yield
@pytest.mark.asyncio
async def test_get_store_agents(mocker):
# Mock data
mock_agents = [
prisma.models.StoreAgent(
listing_id="test-id",
storeListingVersionId="version123",
slug="test-agent",
agent_name="Test Agent",
agent_video=None,
agent_image=["image.jpg"],
featured=False,
creator_username="creator",
creator_avatar="avatar.jpg",
sub_heading="Test heading",
description="Test description",
categories=[],
runs=10,
rating=4.5,
versions=["1.0"],
updated_at=datetime.now(),
)
]
# Mock prisma calls
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
mock_store_agent.return_value.find_many = mocker.AsyncMock(return_value=mock_agents)
mock_store_agent.return_value.count = mocker.AsyncMock(return_value=1)
# Call function
result = await db.get_store_agents()
# Verify results
assert len(result.agents) == 1
assert result.agents[0].slug == "test-agent"
assert result.pagination.total_items == 1
# Verify mocks called correctly
mock_store_agent.return_value.find_many.assert_called_once()
mock_store_agent.return_value.count.assert_called_once()
@pytest.mark.asyncio
async def test_get_store_agent_details(mocker):
# Mock data
mock_agent = prisma.models.StoreAgent(
listing_id="test-id",
storeListingVersionId="version123",
slug="test-agent",
agent_name="Test Agent",
agent_video="video.mp4",
agent_image=["image.jpg"],
featured=False,
creator_username="creator",
creator_avatar="avatar.jpg",
sub_heading="Test heading",
description="Test description",
categories=["test"],
runs=10,
rating=4.5,
versions=["1.0"],
updated_at=datetime.now(),
)
# Mock prisma call
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
# Call function
result = await db.get_store_agent_details("creator", "test-agent")
# Verify results
assert result.slug == "test-agent"
assert result.agent_name == "Test Agent"
# Verify mock called correctly
mock_store_agent.return_value.find_first.assert_called_once_with(
where={"creator_username": "creator", "slug": "test-agent"}
)
@pytest.mark.asyncio
async def test_get_store_creator_details(mocker):
# Mock data
mock_creator_data = prisma.models.Creator(
name="Test Creator",
username="creator",
description="Test description",
links=["link1"],
avatar_url="avatar.jpg",
num_agents=1,
agent_rating=4.5,
agent_runs=10,
top_categories=["test"],
is_featured=False,
)
# Mock prisma call
mock_creator = mocker.patch("prisma.models.Creator.prisma")
mock_creator.return_value.find_unique = mocker.AsyncMock()
# Configure the mock to return values that will pass validation
mock_creator.return_value.find_unique.return_value = mock_creator_data
# Call function
result = await db.get_store_creator_details("creator")
# Verify results
assert result.username == "creator"
assert result.name == "Test Creator"
assert result.description == "Test description"
assert result.avatar_url == "avatar.jpg"
# Verify mock called correctly
mock_creator.return_value.find_unique.assert_called_once_with(
where={"username": "creator"}
)
@pytest.mark.asyncio
async def test_create_store_submission(mocker):
# Mock data
mock_agent = prisma.models.AgentGraph(
id="agent-id",
version=1,
userId="user-id",
createdAt=datetime.now(),
isActive=True,
isTemplate=False,
)
mock_listing = prisma.models.StoreListing(
id="listing-id",
createdAt=datetime.now(),
updatedAt=datetime.now(),
isDeleted=False,
isApproved=False,
agentId="agent-id",
agentVersion=1,
owningUserId="user-id",
)
# Mock prisma calls
mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
mock_agent_graph.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
mock_store_listing = mocker.patch("prisma.models.StoreListing.prisma")
mock_store_listing.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_store_listing.return_value.create = mocker.AsyncMock(return_value=mock_listing)
# Call function
result = await db.create_store_submission(
user_id="user-id",
agent_id="agent-id",
agent_version=1,
slug="test-agent",
name="Test Agent",
description="Test description",
)
# Verify results
assert result.name == "Test Agent"
assert result.description == "Test description"
# Verify mocks called correctly
mock_agent_graph.return_value.find_first.assert_called_once()
mock_store_listing.return_value.find_first.assert_called_once()
mock_store_listing.return_value.create.assert_called_once()
@pytest.mark.asyncio
async def test_update_profile(mocker):
# Mock data
mock_profile = prisma.models.Profile(
id="profile-id",
name="Test Creator",
username="creator",
description="Test description",
links=["link1"],
avatarUrl="avatar.jpg",
isFeatured=False,
createdAt=datetime.now(),
updatedAt=datetime.now(),
)
# Mock prisma calls
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
mock_profile_db.return_value.find_first = mocker.AsyncMock(
return_value=mock_profile
)
mock_profile_db.return_value.update = mocker.AsyncMock(return_value=mock_profile)
# Test data
profile = Profile(
name="Test Creator",
username="creator",
description="Test description",
links=["link1"],
avatar_url="avatar.jpg",
is_featured=False,
)
# Call function
result = await db.update_or_create_profile("user-id", profile)
# Verify results
assert result.username == "creator"
assert result.name == "Test Creator"
# Verify mocks called correctly
mock_profile_db.return_value.find_first.assert_called_once()
mock_profile_db.return_value.update.assert_called_once()
@pytest.mark.asyncio
async def test_get_user_profile(mocker):
# Mock data
mock_profile = prisma.models.Profile(
id="profile-id",
name="No Profile Data",
username="testuser",
description="Test description",
links=["link1", "link2"],
avatarUrl="avatar.jpg",
isFeatured=False,
createdAt=datetime.now(),
updatedAt=datetime.now(),
)
# Mock prisma calls
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
mock_profile_db.return_value.find_unique = mocker.AsyncMock(
return_value=mock_profile
)
# Call function
result = await db.get_user_profile("user-id")
# Verify results
assert result.name == "No Profile Data"
assert result.username == "No Profile Data"
assert result.description == "No Profile Data"
assert result.links == []
assert result.avatar_url == ""

View File

@@ -0,0 +1,76 @@
class MediaUploadError(Exception):
"""Base exception for media upload errors"""
pass
class InvalidFileTypeError(MediaUploadError):
"""Raised when file type is not supported"""
pass
class FileSizeTooLargeError(MediaUploadError):
"""Raised when file size exceeds maximum limit"""
pass
class FileReadError(MediaUploadError):
"""Raised when there's an error reading the file"""
pass
class StorageConfigError(MediaUploadError):
"""Raised when storage configuration is invalid"""
pass
class StorageUploadError(MediaUploadError):
"""Raised when upload to storage fails"""
pass
class StoreError(Exception):
"""Base exception for store-related errors"""
pass
class AgentNotFoundError(StoreError):
"""Raised when an agent is not found"""
pass
class CreatorNotFoundError(StoreError):
"""Raised when a creator is not found"""
pass
class ListingExistsError(StoreError):
"""Raised when trying to create a listing that already exists"""
pass
class DatabaseError(StoreError):
"""Raised when there is an error interacting with the database"""
pass
class ProfileNotFoundError(StoreError):
"""Raised when a profile is not found"""
pass
class SubmissionNotFoundError(StoreError):
"""Raised when a submission is not found"""
pass

View File

@@ -0,0 +1,154 @@
import logging
import os
import uuid
import fastapi
from google.cloud import storage
import backend.server.v2.store.exceptions
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
ALLOWED_IMAGE_TYPES = {"image/jpeg", "image/png", "image/gif", "image/webp"}
ALLOWED_VIDEO_TYPES = {"video/mp4", "video/webm"}
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
async def upload_media(user_id: str, file: fastapi.UploadFile) -> str:
# Get file content for deeper validation
try:
content = await file.read(1024) # Read first 1KB for validation
await file.seek(0) # Reset file pointer
except Exception as e:
logger.error(f"Error reading file content: {str(e)}")
raise backend.server.v2.store.exceptions.FileReadError(
"Failed to read file content"
) from e
# Validate file signature/magic bytes
if file.content_type in ALLOWED_IMAGE_TYPES:
# Check image file signatures
if content.startswith(b"\xFF\xD8\xFF"): # JPEG
if file.content_type != "image/jpeg":
raise backend.server.v2.store.exceptions.InvalidFileTypeError(
"File signature does not match content type"
)
elif content.startswith(b"\x89PNG\r\n\x1a\n"): # PNG
if file.content_type != "image/png":
raise backend.server.v2.store.exceptions.InvalidFileTypeError(
"File signature does not match content type"
)
elif content.startswith(b"GIF87a") or content.startswith(b"GIF89a"): # GIF
if file.content_type != "image/gif":
raise backend.server.v2.store.exceptions.InvalidFileTypeError(
"File signature does not match content type"
)
elif content.startswith(b"RIFF") and content[8:12] == b"WEBP": # WebP
if file.content_type != "image/webp":
raise backend.server.v2.store.exceptions.InvalidFileTypeError(
"File signature does not match content type"
)
else:
raise backend.server.v2.store.exceptions.InvalidFileTypeError(
"Invalid image file signature"
)
elif file.content_type in ALLOWED_VIDEO_TYPES:
# Check video file signatures
if content.startswith(b"\x00\x00\x00") and (content[4:8] == b"ftyp"): # MP4
if file.content_type != "video/mp4":
raise backend.server.v2.store.exceptions.InvalidFileTypeError(
"File signature does not match content type"
)
elif content.startswith(b"\x1a\x45\xdf\xa3"): # WebM
if file.content_type != "video/webm":
raise backend.server.v2.store.exceptions.InvalidFileTypeError(
"File signature does not match content type"
)
else:
raise backend.server.v2.store.exceptions.InvalidFileTypeError(
"Invalid video file signature"
)
settings = Settings()
# Check required settings first before doing any file processing
if not settings.config.media_gcs_bucket_name:
logger.error("Missing GCS bucket name setting")
raise backend.server.v2.store.exceptions.StorageConfigError(
"Missing storage bucket configuration"
)
try:
# Validate file type
content_type = file.content_type
if (
content_type not in ALLOWED_IMAGE_TYPES
and content_type not in ALLOWED_VIDEO_TYPES
):
logger.warning(f"Invalid file type attempted: {content_type}")
raise backend.server.v2.store.exceptions.InvalidFileTypeError(
f"File type not supported. Must be jpeg, png, gif, webp, mp4 or webm. Content type: {content_type}"
)
# Validate file size
file_size = 0
chunk_size = 8192 # 8KB chunks
try:
while chunk := await file.read(chunk_size):
file_size += len(chunk)
if file_size > MAX_FILE_SIZE:
logger.warning(f"File size too large: {file_size} bytes")
raise backend.server.v2.store.exceptions.FileSizeTooLargeError(
"File too large. Maximum size is 50MB"
)
except backend.server.v2.store.exceptions.FileSizeTooLargeError:
raise
except Exception as e:
logger.error(f"Error reading file chunks: {str(e)}")
raise backend.server.v2.store.exceptions.FileReadError(
"Failed to read uploaded file"
) from e
# Reset file pointer
await file.seek(0)
# Generate unique filename
filename = file.filename or ""
file_ext = os.path.splitext(filename)[1].lower()
unique_filename = f"{uuid.uuid4()}{file_ext}"
# Construct storage path
media_type = "images" if content_type in ALLOWED_IMAGE_TYPES else "videos"
storage_path = f"users/{user_id}/{media_type}/{unique_filename}"
try:
storage_client = storage.Client()
bucket = storage_client.bucket(settings.config.media_gcs_bucket_name)
blob = bucket.blob(storage_path)
blob.content_type = content_type
file_bytes = await file.read()
blob.upload_from_string(file_bytes, content_type=content_type)
public_url = blob.public_url
logger.info(f"Successfully uploaded file to: {storage_path}")
return public_url
except Exception as e:
logger.error(f"GCS storage error: {str(e)}")
raise backend.server.v2.store.exceptions.StorageUploadError(
"Failed to upload file to storage"
) from e
except backend.server.v2.store.exceptions.MediaUploadError:
raise
except Exception as e:
logger.exception("Unexpected error in upload_media")
raise backend.server.v2.store.exceptions.MediaUploadError(
"Unexpected error during media upload"
) from e

View File

@@ -0,0 +1,190 @@
import io
import unittest.mock
import fastapi
import pytest
import starlette.datastructures
import backend.server.v2.store.exceptions
import backend.server.v2.store.media
from backend.util.settings import Settings
@pytest.fixture
def mock_settings(monkeypatch):
settings = Settings()
settings.config.media_gcs_bucket_name = "test-bucket"
settings.config.google_application_credentials = "test-credentials"
monkeypatch.setattr("backend.server.v2.store.media.Settings", lambda: settings)
return settings
@pytest.fixture
def mock_storage_client(mocker):
mock_client = unittest.mock.MagicMock()
mock_bucket = unittest.mock.MagicMock()
mock_blob = unittest.mock.MagicMock()
mock_client.bucket.return_value = mock_bucket
mock_bucket.blob.return_value = mock_blob
mock_blob.public_url = "http://test-url/media/laptop.jpeg"
mocker.patch("google.cloud.storage.Client", return_value=mock_client)
return mock_client
async def test_upload_media_success(mock_settings, mock_storage_client):
# Create test JPEG data with valid signature
test_data = b"\xFF\xD8\xFF" + b"test data"
test_file = fastapi.UploadFile(
filename="laptop.jpeg",
file=io.BytesIO(test_data),
headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
)
result = await backend.server.v2.store.media.upload_media("test-user", test_file)
assert result == "http://test-url/media/laptop.jpeg"
mock_bucket = mock_storage_client.bucket.return_value
mock_blob = mock_bucket.blob.return_value
mock_blob.upload_from_string.assert_called_once()
async def test_upload_media_invalid_type(mock_settings, mock_storage_client):
test_file = fastapi.UploadFile(
filename="test.txt",
file=io.BytesIO(b"test data"),
headers=starlette.datastructures.Headers({"content-type": "text/plain"}),
)
with pytest.raises(backend.server.v2.store.exceptions.InvalidFileTypeError):
await backend.server.v2.store.media.upload_media("test-user", test_file)
mock_bucket = mock_storage_client.bucket.return_value
mock_blob = mock_bucket.blob.return_value
mock_blob.upload_from_string.assert_not_called()
async def test_upload_media_missing_credentials(monkeypatch):
settings = Settings()
settings.config.media_gcs_bucket_name = ""
settings.config.google_application_credentials = ""
monkeypatch.setattr("backend.server.v2.store.media.Settings", lambda: settings)
test_file = fastapi.UploadFile(
filename="laptop.jpeg",
file=io.BytesIO(b"\xFF\xD8\xFF" + b"test data"), # Valid JPEG signature
headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
)
with pytest.raises(backend.server.v2.store.exceptions.StorageConfigError):
await backend.server.v2.store.media.upload_media("test-user", test_file)
async def test_upload_media_video_type(mock_settings, mock_storage_client):
test_file = fastapi.UploadFile(
filename="test.mp4",
file=io.BytesIO(b"\x00\x00\x00\x18ftypmp42"), # Valid MP4 signature
headers=starlette.datastructures.Headers({"content-type": "video/mp4"}),
)
result = await backend.server.v2.store.media.upload_media("test-user", test_file)
assert result == "http://test-url/media/laptop.jpeg"
mock_bucket = mock_storage_client.bucket.return_value
mock_blob = mock_bucket.blob.return_value
mock_blob.upload_from_string.assert_called_once()
async def test_upload_media_file_too_large(mock_settings, mock_storage_client):
large_data = b"\xFF\xD8\xFF" + b"x" * (
50 * 1024 * 1024 + 1
) # 50MB + 1 byte with valid JPEG signature
test_file = fastapi.UploadFile(
filename="laptop.jpeg",
file=io.BytesIO(large_data),
headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
)
with pytest.raises(backend.server.v2.store.exceptions.FileSizeTooLargeError):
await backend.server.v2.store.media.upload_media("test-user", test_file)
async def test_upload_media_file_read_error(mock_settings, mock_storage_client):
test_file = fastapi.UploadFile(
filename="laptop.jpeg",
file=io.BytesIO(b""), # Empty file that will raise error on read
headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
)
test_file.read = unittest.mock.AsyncMock(side_effect=Exception("Read error"))
with pytest.raises(backend.server.v2.store.exceptions.FileReadError):
await backend.server.v2.store.media.upload_media("test-user", test_file)
async def test_upload_media_png_success(mock_settings, mock_storage_client):
test_file = fastapi.UploadFile(
filename="test.png",
file=io.BytesIO(b"\x89PNG\r\n\x1a\n"), # Valid PNG signature
headers=starlette.datastructures.Headers({"content-type": "image/png"}),
)
result = await backend.server.v2.store.media.upload_media("test-user", test_file)
assert result == "http://test-url/media/laptop.jpeg"
async def test_upload_media_gif_success(mock_settings, mock_storage_client):
test_file = fastapi.UploadFile(
filename="test.gif",
file=io.BytesIO(b"GIF89a"), # Valid GIF signature
headers=starlette.datastructures.Headers({"content-type": "image/gif"}),
)
result = await backend.server.v2.store.media.upload_media("test-user", test_file)
assert result == "http://test-url/media/laptop.jpeg"
async def test_upload_media_webp_success(mock_settings, mock_storage_client):
test_file = fastapi.UploadFile(
filename="test.webp",
file=io.BytesIO(b"RIFF\x00\x00\x00\x00WEBP"), # Valid WebP signature
headers=starlette.datastructures.Headers({"content-type": "image/webp"}),
)
result = await backend.server.v2.store.media.upload_media("test-user", test_file)
assert result == "http://test-url/media/laptop.jpeg"
async def test_upload_media_webm_success(mock_settings, mock_storage_client):
test_file = fastapi.UploadFile(
filename="test.webm",
file=io.BytesIO(b"\x1a\x45\xdf\xa3"), # Valid WebM signature
headers=starlette.datastructures.Headers({"content-type": "video/webm"}),
)
result = await backend.server.v2.store.media.upload_media("test-user", test_file)
assert result == "http://test-url/media/laptop.jpeg"
async def test_upload_media_mismatched_signature(mock_settings, mock_storage_client):
test_file = fastapi.UploadFile(
filename="test.jpeg",
file=io.BytesIO(b"\x89PNG\r\n\x1a\n"), # PNG signature with JPEG content type
headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
)
with pytest.raises(backend.server.v2.store.exceptions.InvalidFileTypeError):
await backend.server.v2.store.media.upload_media("test-user", test_file)
async def test_upload_media_invalid_signature(mock_settings, mock_storage_client):
test_file = fastapi.UploadFile(
filename="test.jpeg",
file=io.BytesIO(b"invalid signature"),
headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
)
with pytest.raises(backend.server.v2.store.exceptions.InvalidFileTypeError):
await backend.server.v2.store.media.upload_media("test-user", test_file)

View File

@@ -0,0 +1,152 @@
import datetime
from typing import List
import prisma.enums
import pydantic
class Pagination(pydantic.BaseModel):
total_items: int = pydantic.Field(
description="Total number of items.", examples=[42]
)
total_pages: int = pydantic.Field(
description="Total number of pages.", examples=[97]
)
current_page: int = pydantic.Field(
description="Current_page page number.", examples=[1]
)
page_size: int = pydantic.Field(
description="Number of items per page.", examples=[25]
)
class MyAgent(pydantic.BaseModel):
agent_id: str
agent_version: int
agent_name: str
last_edited: datetime.datetime
class MyAgentsResponse(pydantic.BaseModel):
agents: list[MyAgent]
pagination: Pagination
class StoreAgent(pydantic.BaseModel):
slug: str
agent_name: str
agent_image: str
creator: str
creator_avatar: str
sub_heading: str
description: str
runs: int
rating: float
class StoreAgentsResponse(pydantic.BaseModel):
agents: list[StoreAgent]
pagination: Pagination
class StoreAgentDetails(pydantic.BaseModel):
store_listing_version_id: str
slug: str
agent_name: str
agent_video: str
agent_image: list[str]
creator: str
creator_avatar: str
sub_heading: str
description: str
categories: list[str]
runs: int
rating: float
versions: list[str]
last_updated: datetime.datetime
class Creator(pydantic.BaseModel):
name: str
username: str
description: str
avatar_url: str
num_agents: int
agent_rating: float
agent_runs: int
is_featured: bool
class CreatorsResponse(pydantic.BaseModel):
creators: List[Creator]
pagination: Pagination
class CreatorDetails(pydantic.BaseModel):
name: str
username: str
description: str
links: list[str]
avatar_url: str
agent_rating: float
agent_runs: int
top_categories: list[str]
class Profile(pydantic.BaseModel):
name: str
username: str
description: str
links: list[str]
avatar_url: str
is_featured: bool = False
class StoreSubmission(pydantic.BaseModel):
agent_id: str
agent_version: int
name: str
sub_heading: str
slug: str
description: str
image_urls: list[str]
date_submitted: datetime.datetime
status: prisma.enums.SubmissionStatus
runs: int
rating: float
class StoreSubmissionsResponse(pydantic.BaseModel):
submissions: list[StoreSubmission]
pagination: Pagination
class StoreSubmissionRequest(pydantic.BaseModel):
agent_id: str
agent_version: int
slug: str
name: str
sub_heading: str
video_url: str | None = None
image_urls: list[str] = []
description: str = ""
categories: list[str] = []
class ProfileDetails(pydantic.BaseModel):
name: str
username: str
description: str
links: list[str]
avatar_url: str | None = None
class StoreReview(pydantic.BaseModel):
score: int
comments: str | None = None
class StoreReviewCreate(pydantic.BaseModel):
store_listing_version_id: str
score: int
comments: str | None = None

View File

@@ -0,0 +1,195 @@
import datetime
import prisma.enums
import backend.server.v2.store.model
def test_pagination():
pagination = backend.server.v2.store.model.Pagination(
total_items=100, total_pages=5, current_page=2, page_size=20
)
assert pagination.total_items == 100
assert pagination.total_pages == 5
assert pagination.current_page == 2
assert pagination.page_size == 20
def test_store_agent():
agent = backend.server.v2.store.model.StoreAgent(
slug="test-agent",
agent_name="Test Agent",
agent_image="test.jpg",
creator="creator1",
creator_avatar="avatar.jpg",
sub_heading="Test subheading",
description="Test description",
runs=50,
rating=4.5,
)
assert agent.slug == "test-agent"
assert agent.agent_name == "Test Agent"
assert agent.runs == 50
assert agent.rating == 4.5
def test_store_agents_response():
response = backend.server.v2.store.model.StoreAgentsResponse(
agents=[
backend.server.v2.store.model.StoreAgent(
slug="test-agent",
agent_name="Test Agent",
agent_image="test.jpg",
creator="creator1",
creator_avatar="avatar.jpg",
sub_heading="Test subheading",
description="Test description",
runs=50,
rating=4.5,
)
],
pagination=backend.server.v2.store.model.Pagination(
total_items=1, total_pages=1, current_page=1, page_size=20
),
)
assert len(response.agents) == 1
assert response.pagination.total_items == 1
def test_store_agent_details():
details = backend.server.v2.store.model.StoreAgentDetails(
store_listing_version_id="version123",
slug="test-agent",
agent_name="Test Agent",
agent_video="video.mp4",
agent_image=["image1.jpg", "image2.jpg"],
creator="creator1",
creator_avatar="avatar.jpg",
sub_heading="Test subheading",
description="Test description",
categories=["cat1", "cat2"],
runs=50,
rating=4.5,
versions=["1.0", "2.0"],
last_updated=datetime.datetime.now(),
)
assert details.slug == "test-agent"
assert len(details.agent_image) == 2
assert len(details.categories) == 2
assert len(details.versions) == 2
def test_creator():
creator = backend.server.v2.store.model.Creator(
agent_rating=4.8,
agent_runs=1000,
name="Test Creator",
username="creator1",
description="Test description",
avatar_url="avatar.jpg",
num_agents=5,
is_featured=False,
)
assert creator.name == "Test Creator"
assert creator.num_agents == 5
def test_creators_response():
response = backend.server.v2.store.model.CreatorsResponse(
creators=[
backend.server.v2.store.model.Creator(
agent_rating=4.8,
agent_runs=1000,
name="Test Creator",
username="creator1",
description="Test description",
avatar_url="avatar.jpg",
num_agents=5,
is_featured=False,
)
],
pagination=backend.server.v2.store.model.Pagination(
total_items=1, total_pages=1, current_page=1, page_size=20
),
)
assert len(response.creators) == 1
assert response.pagination.total_items == 1
def test_creator_details():
details = backend.server.v2.store.model.CreatorDetails(
name="Test Creator",
username="creator1",
description="Test description",
links=["link1.com", "link2.com"],
avatar_url="avatar.jpg",
agent_rating=4.8,
agent_runs=1000,
top_categories=["cat1", "cat2"],
)
assert details.name == "Test Creator"
assert len(details.links) == 2
assert details.agent_rating == 4.8
assert len(details.top_categories) == 2
def test_store_submission():
submission = backend.server.v2.store.model.StoreSubmission(
agent_id="agent123",
agent_version=1,
sub_heading="Test subheading",
name="Test Agent",
slug="test-agent",
description="Test description",
image_urls=["image1.jpg", "image2.jpg"],
date_submitted=datetime.datetime(2023, 1, 1),
status=prisma.enums.SubmissionStatus.PENDING,
runs=50,
rating=4.5,
)
assert submission.name == "Test Agent"
assert len(submission.image_urls) == 2
assert submission.status == prisma.enums.SubmissionStatus.PENDING
def test_store_submissions_response():
response = backend.server.v2.store.model.StoreSubmissionsResponse(
submissions=[
backend.server.v2.store.model.StoreSubmission(
agent_id="agent123",
agent_version=1,
sub_heading="Test subheading",
name="Test Agent",
slug="test-agent",
description="Test description",
image_urls=["image1.jpg"],
date_submitted=datetime.datetime(2023, 1, 1),
status=prisma.enums.SubmissionStatus.PENDING,
runs=50,
rating=4.5,
)
],
pagination=backend.server.v2.store.model.Pagination(
total_items=1, total_pages=1, current_page=1, page_size=20
),
)
assert len(response.submissions) == 1
assert response.pagination.total_items == 1
def test_store_submission_request():
request = backend.server.v2.store.model.StoreSubmissionRequest(
agent_id="agent123",
agent_version=1,
slug="test-agent",
name="Test Agent",
sub_heading="Test subheading",
video_url="video.mp4",
image_urls=["image1.jpg", "image2.jpg"],
description="Test description",
categories=["cat1", "cat2"],
)
assert request.agent_id == "agent123"
assert request.agent_version == 1
assert len(request.image_urls) == 2
assert len(request.categories) == 2

View File

@@ -0,0 +1,441 @@
import logging
import typing
import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import fastapi
import fastapi.responses
import backend.server.v2.store.db
import backend.server.v2.store.media
import backend.server.v2.store.model
logger = logging.getLogger(__name__)
router = fastapi.APIRouter()
##############################################
############### Profile Endpoints ############
##############################################
@router.get("/profile", tags=["store", "private"])
async def get_profile(
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
]
) -> backend.server.v2.store.model.ProfileDetails:
"""
Get the profile details for the authenticated user.
"""
try:
profile = await backend.server.v2.store.db.get_user_profile(user_id)
return profile
except Exception:
logger.exception("Exception occurred whilst getting user profile")
raise
@router.post(
"/profile",
tags=["store", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def update_or_create_profile(
profile: backend.server.v2.store.model.Profile,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
) -> backend.server.v2.store.model.CreatorDetails:
"""
Update the store profile for the authenticated user.
Args:
profile (Profile): The updated profile details
user_id (str): ID of the authenticated user
Returns:
CreatorDetails: The updated profile
Raises:
HTTPException: If there is an error updating the profile
"""
try:
updated_profile = await backend.server.v2.store.db.update_or_create_profile(
user_id=user_id, profile=profile
)
return updated_profile
except Exception:
logger.exception("Exception occurred whilst updating profile")
raise
##############################################
############### Agent Endpoints ##############
##############################################
@router.get("/agents", tags=["store", "public"])
async def get_agents(
featured: bool = False,
creator: str | None = None,
sorted_by: str | None = None,
search_query: str | None = None,
category: str | None = None,
page: int = 1,
page_size: int = 20,
) -> backend.server.v2.store.model.StoreAgentsResponse:
"""
Get a paginated list of agents from the store with optional filtering and sorting.
Args:
featured (bool, optional): Filter to only show featured agents. Defaults to False.
creator (str | None, optional): Filter agents by creator username. Defaults to None.
sorted_by (str | None, optional): Sort agents by "runs" or "rating". Defaults to None.
search_query (str | None, optional): Search agents by name, subheading and description. Defaults to None.
category (str | None, optional): Filter agents by category. Defaults to None.
page (int, optional): Page number for pagination. Defaults to 1.
page_size (int, optional): Number of agents per page. Defaults to 20.
Returns:
StoreAgentsResponse: Paginated list of agents matching the filters
Raises:
HTTPException: If page or page_size are less than 1
Used for:
- Home Page Featured Agents
- Home Page Top Agents
- Search Results
- Agent Details - Other Agents By Creator
- Agent Details - Similar Agents
- Creator Details - Agents By Creator
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
try:
agents = await backend.server.v2.store.db.get_store_agents(
featured=featured,
creator=creator,
sorted_by=sorted_by,
search_query=search_query,
category=category,
page=page,
page_size=page_size,
)
return agents
except Exception:
logger.exception("Exception occured whilst getting store agents")
raise
@router.get("/agents/{username}/{agent_name}", tags=["store", "public"])
async def get_agent(
username: str, agent_name: str
) -> backend.server.v2.store.model.StoreAgentDetails:
"""
This is only used on the AgentDetails Page
It returns the store listing agents details.
"""
try:
agent = await backend.server.v2.store.db.get_store_agent_details(
username=username, agent_name=agent_name
)
return agent
except Exception:
logger.exception("Exception occurred whilst getting store agent details")
raise
@router.post(
"/agents/{username}/{agent_name}/review",
tags=["store"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def create_review(
username: str,
agent_name: str,
review: backend.server.v2.store.model.StoreReviewCreate,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
) -> backend.server.v2.store.model.StoreReview:
"""
Create a review for a store agent.
Args:
username: Creator's username
agent_name: Name/slug of the agent
review: Review details including score and optional comments
user_id: ID of authenticated user creating the review
Returns:
The created review
"""
try:
# Create the review
created_review = await backend.server.v2.store.db.create_store_review(
user_id=user_id,
store_listing_version_id=review.store_listing_version_id,
score=review.score,
comments=review.comments,
)
return created_review
except Exception:
logger.exception("Exception occurred whilst creating store review")
raise
##############################################
############# Creator Endpoints #############
##############################################
@router.get("/creators", tags=["store", "public"])
async def get_creators(
featured: bool = False,
search_query: str | None = None,
sorted_by: str | None = None,
page: int = 1,
page_size: int = 20,
) -> backend.server.v2.store.model.CreatorsResponse:
"""
This is needed for:
- Home Page Featured Creators
- Search Results Page
---
To support this functionality we need:
- featured: bool - to limit the list to just featured agents
- search_query: str - vector search based on the creators profile description.
- sorted_by: [agent_rating, agent_runs] -
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
try:
creators = await backend.server.v2.store.db.get_store_creators(
featured=featured,
search_query=search_query,
sorted_by=sorted_by,
page=page,
page_size=page_size,
)
return creators
except Exception:
logger.exception("Exception occurred whilst getting store creators")
raise
@router.get("/creator/{username}", tags=["store", "public"])
async def get_creator(username: str) -> backend.server.v2.store.model.CreatorDetails:
"""
Get the details of a creator
- Creator Details Page
"""
try:
creator = await backend.server.v2.store.db.get_store_creator_details(
username=username
)
return creator
except Exception:
logger.exception("Exception occurred whilst getting creator details")
raise
############################################
############# Store Submissions ###############
############################################
@router.get(
"/myagents",
tags=["store", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def get_my_agents(
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
]
) -> backend.server.v2.store.model.MyAgentsResponse:
try:
agents = await backend.server.v2.store.db.get_my_agents(user_id)
return agents
except Exception:
logger.exception("Exception occurred whilst getting my agents")
raise
@router.delete(
"/submissions/{submission_id}",
tags=["store", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def delete_submission(
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
submission_id: str,
) -> bool:
"""
Delete a store listing submission.
Args:
user_id (str): ID of the authenticated user
submission_id (str): ID of the submission to be deleted
Returns:
bool: True if the submission was successfully deleted, False otherwise
"""
try:
result = await backend.server.v2.store.db.delete_store_submission(
user_id=user_id,
submission_id=submission_id,
)
return result
except Exception:
logger.exception("Exception occurred whilst deleting store submission")
raise
@router.get(
"/submissions",
tags=["store", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def get_submissions(
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
page: int = 1,
page_size: int = 20,
) -> backend.server.v2.store.model.StoreSubmissionsResponse:
"""
Get a paginated list of store submissions for the authenticated user.
Args:
user_id (str): ID of the authenticated user
page (int, optional): Page number for pagination. Defaults to 1.
page_size (int, optional): Number of submissions per page. Defaults to 20.
Returns:
StoreListingsResponse: Paginated list of store submissions
Raises:
HTTPException: If page or page_size are less than 1
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
try:
listings = await backend.server.v2.store.db.get_store_submissions(
user_id=user_id,
page=page,
page_size=page_size,
)
return listings
except Exception:
logger.exception("Exception occurred whilst getting store submissions")
raise
@router.post(
"/submissions",
tags=["store", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def create_submission(
submission_request: backend.server.v2.store.model.StoreSubmissionRequest,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
) -> backend.server.v2.store.model.StoreSubmission:
"""
Create a new store listing submission.
Args:
submission_request (StoreSubmissionRequest): The submission details
user_id (str): ID of the authenticated user submitting the listing
Returns:
StoreSubmission: The created store submission
Raises:
HTTPException: If there is an error creating the submission
"""
try:
submission = await backend.server.v2.store.db.create_store_submission(
user_id=user_id,
agent_id=submission_request.agent_id,
agent_version=submission_request.agent_version,
slug=submission_request.slug,
name=submission_request.name,
video_url=submission_request.video_url,
image_urls=submission_request.image_urls,
description=submission_request.description,
sub_heading=submission_request.sub_heading,
categories=submission_request.categories,
)
return submission
except Exception:
logger.exception("Exception occurred whilst creating store submission")
raise
@router.post(
"/submissions/media",
tags=["store", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def upload_submission_media(
file: fastapi.UploadFile,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
) -> str:
"""
Upload media (images/videos) for a store listing submission.
Args:
file (UploadFile): The media file to upload
user_id (str): ID of the authenticated user uploading the media
Returns:
str: URL of the uploaded media file
Raises:
HTTPException: If there is an error uploading the media
"""
try:
media_url = await backend.server.v2.store.media.upload_media(
user_id=user_id, file=file
)
return media_url
except Exception as e:
logger.exception("Exception occurred whilst uploading submission media")
raise fastapi.HTTPException(
status_code=500, detail=f"Failed to upload media file: {str(e)}"
)

View File

@@ -0,0 +1,552 @@
import datetime
import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import fastapi
import fastapi.testclient
import prisma.enums
import pytest_mock
import backend.server.v2.store.model
import backend.server.v2.store.routes
app = fastapi.FastAPI()
app.include_router(backend.server.v2.store.routes.router)
client = fastapi.testclient.TestClient(app)
def override_auth_middleware():
"""Override auth middleware for testing"""
return {"sub": "test-user-id"}
def override_get_user_id():
"""Override get_user_id for testing"""
return "test-user-id"
app.dependency_overrides[autogpt_libs.auth.middleware.auth_middleware] = (
override_auth_middleware
)
app.dependency_overrides[autogpt_libs.auth.depends.get_user_id] = override_get_user_id
def test_get_agents_defaults(mocker: pytest_mock.MockFixture):
mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
agents=[],
pagination=backend.server.v2.store.model.Pagination(
current_page=0,
total_items=0,
total_pages=0,
page_size=10,
),
)
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
mock_db_call.return_value = mocked_value
response = client.get("/agents")
assert response.status_code == 200
data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
response.json()
)
assert data.pagination.total_pages == 0
assert data.agents == []
mock_db_call.assert_called_once_with(
featured=False,
creator=None,
sorted_by=None,
search_query=None,
category=None,
page=1,
page_size=20,
)
def test_get_agents_featured(mocker: pytest_mock.MockFixture):
mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
agents=[
backend.server.v2.store.model.StoreAgent(
slug="featured-agent",
agent_name="Featured Agent",
agent_image="featured.jpg",
creator="creator1",
creator_avatar="avatar1.jpg",
sub_heading="Featured agent subheading",
description="Featured agent description",
runs=100,
rating=4.5,
)
],
pagination=backend.server.v2.store.model.Pagination(
current_page=1,
total_items=1,
total_pages=1,
page_size=20,
),
)
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
mock_db_call.return_value = mocked_value
response = client.get("/agents?featured=true")
assert response.status_code == 200
data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
response.json()
)
assert len(data.agents) == 1
assert data.agents[0].slug == "featured-agent"
mock_db_call.assert_called_once_with(
featured=True,
creator=None,
sorted_by=None,
search_query=None,
category=None,
page=1,
page_size=20,
)
def test_get_agents_by_creator(mocker: pytest_mock.MockFixture):
mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
agents=[
backend.server.v2.store.model.StoreAgent(
slug="creator-agent",
agent_name="Creator Agent",
agent_image="agent.jpg",
creator="specific-creator",
creator_avatar="avatar.jpg",
sub_heading="Creator agent subheading",
description="Creator agent description",
runs=50,
rating=4.0,
)
],
pagination=backend.server.v2.store.model.Pagination(
current_page=1,
total_items=1,
total_pages=1,
page_size=20,
),
)
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
mock_db_call.return_value = mocked_value
response = client.get("/agents?creator=specific-creator")
assert response.status_code == 200
data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
response.json()
)
assert len(data.agents) == 1
assert data.agents[0].creator == "specific-creator"
mock_db_call.assert_called_once_with(
featured=False,
creator="specific-creator",
sorted_by=None,
search_query=None,
category=None,
page=1,
page_size=20,
)
def test_get_agents_sorted(mocker: pytest_mock.MockFixture):
mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
agents=[
backend.server.v2.store.model.StoreAgent(
slug="top-agent",
agent_name="Top Agent",
agent_image="top.jpg",
creator="creator1",
creator_avatar="avatar1.jpg",
sub_heading="Top agent subheading",
description="Top agent description",
runs=1000,
rating=5.0,
)
],
pagination=backend.server.v2.store.model.Pagination(
current_page=1,
total_items=1,
total_pages=1,
page_size=20,
),
)
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
mock_db_call.return_value = mocked_value
response = client.get("/agents?sorted_by=runs")
assert response.status_code == 200
data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
response.json()
)
assert len(data.agents) == 1
assert data.agents[0].runs == 1000
mock_db_call.assert_called_once_with(
featured=False,
creator=None,
sorted_by="runs",
search_query=None,
category=None,
page=1,
page_size=20,
)
def test_get_agents_search(mocker: pytest_mock.MockFixture):
mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
agents=[
backend.server.v2.store.model.StoreAgent(
slug="search-agent",
agent_name="Search Agent",
agent_image="search.jpg",
creator="creator1",
creator_avatar="avatar1.jpg",
sub_heading="Search agent subheading",
description="Specific search term description",
runs=75,
rating=4.2,
)
],
pagination=backend.server.v2.store.model.Pagination(
current_page=1,
total_items=1,
total_pages=1,
page_size=20,
),
)
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
mock_db_call.return_value = mocked_value
response = client.get("/agents?search_query=specific")
assert response.status_code == 200
data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
response.json()
)
assert len(data.agents) == 1
assert "specific" in data.agents[0].description.lower()
mock_db_call.assert_called_once_with(
featured=False,
creator=None,
sorted_by=None,
search_query="specific",
category=None,
page=1,
page_size=20,
)
def test_get_agents_category(mocker: pytest_mock.MockFixture):
mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
agents=[
backend.server.v2.store.model.StoreAgent(
slug="category-agent",
agent_name="Category Agent",
agent_image="category.jpg",
creator="creator1",
creator_avatar="avatar1.jpg",
sub_heading="Category agent subheading",
description="Category agent description",
runs=60,
rating=4.1,
)
],
pagination=backend.server.v2.store.model.Pagination(
current_page=1,
total_items=1,
total_pages=1,
page_size=20,
),
)
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
mock_db_call.return_value = mocked_value
response = client.get("/agents?category=test-category")
assert response.status_code == 200
data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
response.json()
)
assert len(data.agents) == 1
mock_db_call.assert_called_once_with(
featured=False,
creator=None,
sorted_by=None,
search_query=None,
category="test-category",
page=1,
page_size=20,
)
def test_get_agents_pagination(mocker: pytest_mock.MockFixture):
mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
agents=[
backend.server.v2.store.model.StoreAgent(
slug=f"agent-{i}",
agent_name=f"Agent {i}",
agent_image=f"agent{i}.jpg",
creator="creator1",
creator_avatar="avatar1.jpg",
sub_heading=f"Agent {i} subheading",
description=f"Agent {i} description",
runs=i * 10,
rating=4.0,
)
for i in range(5)
],
pagination=backend.server.v2.store.model.Pagination(
current_page=2,
total_items=15,
total_pages=3,
page_size=5,
),
)
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
mock_db_call.return_value = mocked_value
response = client.get("/agents?page=2&page_size=5")
assert response.status_code == 200
data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
response.json()
)
assert len(data.agents) == 5
assert data.pagination.current_page == 2
assert data.pagination.page_size == 5
mock_db_call.assert_called_once_with(
featured=False,
creator=None,
sorted_by=None,
search_query=None,
category=None,
page=2,
page_size=5,
)
def test_get_agents_malformed_request(mocker: pytest_mock.MockFixture):
# Test with invalid page number
response = client.get("/agents?page=-1")
assert response.status_code == 422
# Test with invalid page size
response = client.get("/agents?page_size=0")
assert response.status_code == 422
# Test with non-numeric values
response = client.get("/agents?page=abc&page_size=def")
assert response.status_code == 422
# Verify no DB calls were made
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
mock_db_call.assert_not_called()
def test_get_agent_details(mocker: pytest_mock.MockFixture):
mocked_value = backend.server.v2.store.model.StoreAgentDetails(
store_listing_version_id="test-version-id",
slug="test-agent",
agent_name="Test Agent",
agent_video="video.mp4",
agent_image=["image1.jpg", "image2.jpg"],
creator="creator1",
creator_avatar="avatar1.jpg",
sub_heading="Test agent subheading",
description="Test agent description",
categories=["category1", "category2"],
runs=100,
rating=4.5,
versions=["1.0.0", "1.1.0"],
last_updated=datetime.datetime.now(),
)
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agent_details")
mock_db_call.return_value = mocked_value
response = client.get("/agents/creator1/test-agent")
assert response.status_code == 200
data = backend.server.v2.store.model.StoreAgentDetails.model_validate(
response.json()
)
assert data.agent_name == "Test Agent"
assert data.creator == "creator1"
mock_db_call.assert_called_once_with(username="creator1", agent_name="test-agent")
def test_get_creators_defaults(mocker: pytest_mock.MockFixture):
mocked_value = backend.server.v2.store.model.CreatorsResponse(
creators=[],
pagination=backend.server.v2.store.model.Pagination(
current_page=0,
total_items=0,
total_pages=0,
page_size=10,
),
)
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_creators")
mock_db_call.return_value = mocked_value
response = client.get("/creators")
assert response.status_code == 200
data = backend.server.v2.store.model.CreatorsResponse.model_validate(
response.json()
)
assert data.pagination.total_pages == 0
assert data.creators == []
mock_db_call.assert_called_once_with(
featured=False, search_query=None, sorted_by=None, page=1, page_size=20
)
def test_get_creators_pagination(mocker: pytest_mock.MockFixture):
mocked_value = backend.server.v2.store.model.CreatorsResponse(
creators=[
backend.server.v2.store.model.Creator(
name=f"Creator {i}",
username=f"creator{i}",
description=f"Creator {i} description",
avatar_url=f"avatar{i}.jpg",
num_agents=1,
agent_rating=4.5,
agent_runs=100,
is_featured=False,
)
for i in range(5)
],
pagination=backend.server.v2.store.model.Pagination(
current_page=2,
total_items=15,
total_pages=3,
page_size=5,
),
)
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_creators")
mock_db_call.return_value = mocked_value
response = client.get("/creators?page=2&page_size=5")
assert response.status_code == 200
data = backend.server.v2.store.model.CreatorsResponse.model_validate(
response.json()
)
assert len(data.creators) == 5
assert data.pagination.current_page == 2
assert data.pagination.page_size == 5
mock_db_call.assert_called_once_with(
featured=False, search_query=None, sorted_by=None, page=2, page_size=5
)
def test_get_creators_malformed_request(mocker: pytest_mock.MockFixture):
# Test with invalid page number
response = client.get("/creators?page=-1")
assert response.status_code == 422
# Test with invalid page size
response = client.get("/creators?page_size=0")
assert response.status_code == 422
# Test with non-numeric values
response = client.get("/creators?page=abc&page_size=def")
assert response.status_code == 422
# Verify no DB calls were made
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_creators")
mock_db_call.assert_not_called()
def test_get_creator_details(mocker: pytest_mock.MockFixture):
mocked_value = backend.server.v2.store.model.CreatorDetails(
name="Test User",
username="creator1",
description="Test creator description",
links=["link1.com", "link2.com"],
avatar_url="avatar.jpg",
agent_rating=4.8,
agent_runs=1000,
top_categories=["category1", "category2"],
)
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_creator_details")
mock_db_call.return_value = mocked_value
response = client.get("/creator/creator1")
assert response.status_code == 200
data = backend.server.v2.store.model.CreatorDetails.model_validate(response.json())
assert data.username == "creator1"
assert data.name == "Test User"
mock_db_call.assert_called_once_with(username="creator1")
def test_get_submissions_success(mocker: pytest_mock.MockFixture):
mocked_value = backend.server.v2.store.model.StoreSubmissionsResponse(
submissions=[
backend.server.v2.store.model.StoreSubmission(
name="Test Agent",
description="Test agent description",
image_urls=["test.jpg"],
date_submitted=datetime.datetime.now(),
status=prisma.enums.SubmissionStatus.APPROVED,
runs=50,
rating=4.2,
agent_id="test-agent-id",
agent_version=1,
sub_heading="Test agent subheading",
slug="test-agent",
)
],
pagination=backend.server.v2.store.model.Pagination(
current_page=1,
total_items=1,
total_pages=1,
page_size=20,
),
)
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_submissions")
mock_db_call.return_value = mocked_value
response = client.get("/submissions")
assert response.status_code == 200
data = backend.server.v2.store.model.StoreSubmissionsResponse.model_validate(
response.json()
)
assert len(data.submissions) == 1
assert data.submissions[0].name == "Test Agent"
assert data.pagination.current_page == 1
mock_db_call.assert_called_once_with(user_id="test-user-id", page=1, page_size=20)
def test_get_submissions_pagination(mocker: pytest_mock.MockFixture):
mocked_value = backend.server.v2.store.model.StoreSubmissionsResponse(
submissions=[],
pagination=backend.server.v2.store.model.Pagination(
current_page=2,
total_items=10,
total_pages=2,
page_size=5,
),
)
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_submissions")
mock_db_call.return_value = mocked_value
response = client.get("/submissions?page=2&page_size=5")
assert response.status_code == 200
data = backend.server.v2.store.model.StoreSubmissionsResponse.model_validate(
response.json()
)
assert data.pagination.current_page == 2
assert data.pagination.page_size == 5
mock_db_call.assert_called_once_with(user_id="test-user-id", page=2, page_size=5)
def test_get_submissions_malformed_request(mocker: pytest_mock.MockFixture):
# Test with invalid page number
response = client.get("/submissions?page=-1")
assert response.status_code == 422
# Test with invalid page size
response = client.get("/submissions?page_size=0")
assert response.status_code == 422
# Test with non-numeric values
response = client.get("/submissions?page=abc&page_size=def")
assert response.status_code == 422
# Verify no DB calls were made
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_submissions")
mock_db_call.assert_not_called()

View File

@@ -1,8 +1,10 @@
import ipaddress
import re
import socket
from typing import Callable
from urllib.parse import urlparse
from urllib.parse import urlparse, urlunparse
import idna
import requests as req
from backend.util.settings import Config
@@ -21,8 +23,23 @@ BLOCKED_IP_NETWORKS = [
# --8<-- [end:BLOCKED_IP_NETWORKS]
]
ALLOWED_SCHEMES = ["http", "https"]
HOSTNAME_REGEX = re.compile(r"^[A-Za-z0-9.-]+$") # Basic DNS-safe hostname pattern
def is_ip_blocked(ip: str) -> bool:
def _canonicalize_url(url: str) -> str:
# Strip spaces and trailing slashes
url = url.strip().strip("/")
# Ensure the URL starts with http:// or https://
if not url.startswith(("http://", "https://")):
url = "http://" + url
# Replace backslashes with forward slashes to avoid parsing ambiguities
url = url.replace("\\", "/")
return url
def _is_ip_blocked(ip: str) -> bool:
"""
Checks if the IP address is in a blocked network.
"""
@@ -35,29 +52,51 @@ def validate_url(url: str, trusted_origins: list[str]) -> str:
Validates the URL to prevent SSRF attacks by ensuring it does not point to a private
or untrusted IP address, unless whitelisted.
"""
url = url.strip().strip("/")
if not url.startswith(("http://", "https://")):
url = "http://" + url
url = _canonicalize_url(url)
parsed = urlparse(url)
parsed_url = urlparse(url)
hostname = parsed_url.hostname
# Check scheme
if parsed.scheme not in ALLOWED_SCHEMES:
raise ValueError(
f"Scheme '{parsed.scheme}' is not allowed. Only HTTP/HTTPS are supported."
)
if not hostname:
raise ValueError(f"Invalid URL: Unable to determine hostname from {url}")
# Validate and IDNA encode the hostname
if not parsed.hostname:
raise ValueError("Invalid URL: No hostname found.")
if any(hostname == origin for origin in trusted_origins):
# IDNA encode to prevent Unicode domain attacks
try:
ascii_hostname = idna.encode(parsed.hostname).decode("ascii")
except idna.IDNAError:
raise ValueError("Invalid hostname with unsupported characters.")
# Check hostname characters
if not HOSTNAME_REGEX.match(ascii_hostname):
raise ValueError("Hostname contains invalid characters.")
# Rebuild the URL with the normalized, IDNA-encoded hostname
parsed = parsed._replace(netloc=ascii_hostname)
url = str(urlunparse(parsed))
# Check if hostname is a trusted origin (exact match)
if ascii_hostname in trusted_origins:
return url
# Resolve all IP addresses for the hostname
ip_addresses = {result[4][0] for result in socket.getaddrinfo(hostname, None)}
if not ip_addresses:
raise ValueError(f"Unable to resolve IP address for {hostname}")
try:
ip_addresses = {res[4][0] for res in socket.getaddrinfo(ascii_hostname, None)}
except socket.gaierror:
raise ValueError(f"Unable to resolve IP address for hostname {ascii_hostname}")
# Check if all IP addresses are global
if not ip_addresses:
raise ValueError(f"No IP addresses found for {ascii_hostname}")
# Check if any resolved IP address falls into blocked ranges
for ip in ip_addresses:
if is_ip_blocked(ip):
if _is_ip_blocked(ip):
raise ValueError(
f"Access to private IP address at {hostname}: {ip} is not allowed."
f"Access to private IP address {ip} for hostname {ascii_hostname} is not allowed."
)
return url

View File

@@ -148,6 +148,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
"This value is then used to generate redirect URLs for OAuth flows.",
)
media_gcs_bucket_name: str = Field(
default="",
description="The name of the Google Cloud Storage bucket for media files",
)
@field_validator("platform_base_url", "frontend_base_url")
@classmethod
def validate_platform_base_url(cls, v: str, info: ValidationInfo) -> str:

View File

@@ -60,9 +60,7 @@ async def wait_execution(
timeout: int = 20,
) -> Sequence[ExecutionResult]:
async def is_execution_completed():
status = await AgentServer().test_get_graph_run_status(
graph_id, graph_exec_id, user_id
)
status = await AgentServer().test_get_graph_run_status(graph_exec_id, user_id)
log.info(f"Execution status: {status}")
if status == ExecutionStatus.FAILED:
log.info("Execution failed")

View File

@@ -0,0 +1,22 @@
import re
from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment
class TextFormatter:
def __init__(self):
# Create a sandboxed environment
self.env = SandboxedEnvironment(loader=BaseLoader(), autoescape=True)
# Clear any registered filters, tests, and globals to minimize attack surface
self.env.filters.clear()
self.env.tests.clear()
self.env.globals.clear()
def format_string(self, template_str: str, values=None, **kwargs) -> str:
# For python.format compatibility: replace all {...} with {{..}}.
# But avoid replacing {{...}} to {{{...}}}.
template_str = re.sub(r"(?<!{){[ a-zA-Z0-9_]+}", r"{\g<0>}", template_str)
template = self.env.from_string(template_str)
return template.render(values or {}, **kwargs)

View File

@@ -1,20 +1,31 @@
version: "3"
services:
postgres-test:
image: ankane/pgvector:latest
environment:
- POSTGRES_USER=agpt_user
- POSTGRES_PASSWORD=pass123
- POSTGRES_DB=agpt_local
- POSTGRES_USER=${DB_USER}
- POSTGRES_PASSWORD=${DB_PASS}
- POSTGRES_DB=${DB_NAME}
healthcheck:
test: pg_isready -U $$POSTGRES_USER -d $$POSTGRES_DB
interval: 10s
timeout: 5s
retries: 5
ports:
- "5433:5432"
- "${DB_PORT}:5432"
networks:
- app-network-test
redis-test:
image: redis:latest
command: redis-server --requirepass password
ports:
- "6379:6379"
networks:
- app-network-test
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 10s
timeout: 5s
retries: 5
networks:
app-network-test:

View File

@@ -0,0 +1,41 @@
-- CreateIndex
CREATE INDEX "AgentGraph_userId_isActive_idx" ON "AgentGraph"("userId", "isActive");
-- CreateIndex
CREATE INDEX "AgentGraphExecution_agentGraphId_agentGraphVersion_idx" ON "AgentGraphExecution"("agentGraphId", "agentGraphVersion");
-- CreateIndex
CREATE INDEX "AgentGraphExecution_userId_idx" ON "AgentGraphExecution"("userId");
-- CreateIndex
CREATE INDEX "AgentNode_agentGraphId_agentGraphVersion_idx" ON "AgentNode"("agentGraphId", "agentGraphVersion");
-- CreateIndex
CREATE INDEX "AgentNode_agentBlockId_idx" ON "AgentNode"("agentBlockId");
-- CreateIndex
CREATE INDEX "AgentNode_webhookId_idx" ON "AgentNode"("webhookId");
-- CreateIndex
CREATE INDEX "AgentNodeExecution_agentGraphExecutionId_idx" ON "AgentNodeExecution"("agentGraphExecutionId");
-- CreateIndex
CREATE INDEX "AgentNodeExecution_agentNodeId_idx" ON "AgentNodeExecution"("agentNodeId");
-- CreateIndex
CREATE INDEX "AgentNodeExecutionInputOutput_referencedByOutputExecId_idx" ON "AgentNodeExecutionInputOutput"("referencedByOutputExecId");
-- CreateIndex
CREATE INDEX "AgentNodeLink_agentNodeSourceId_idx" ON "AgentNodeLink"("agentNodeSourceId");
-- CreateIndex
CREATE INDEX "AgentNodeLink_agentNodeSinkId_idx" ON "AgentNodeLink"("agentNodeSinkId");
-- CreateIndex
CREATE INDEX "AnalyticsMetrics_userId_idx" ON "AnalyticsMetrics"("userId");
-- CreateIndex
CREATE INDEX "IntegrationWebhook_userId_idx" ON "IntegrationWebhook"("userId");
-- CreateIndex
CREATE INDEX "UserBlockCredit_userId_createdAt_idx" ON "UserBlockCredit"("userId", "createdAt");

View File

@@ -0,0 +1,8 @@
-- AlterTable
ALTER TABLE "User" ADD COLUMN "stripeCustomerId" TEXT;
-- AlterEnum
ALTER TYPE "UserBlockCreditType" RENAME TO "CreditTransactionType";
-- AlterTable
ALTER TABLE "UserBlockCredit" RENAME TO "CreditTransaction";

View File

@@ -0,0 +1,228 @@
-- CreateEnum
CREATE TYPE "SubmissionStatus" AS ENUM ('DAFT', 'PENDING', 'APPROVED', 'REJECTED');
-- AlterTable
ALTER TABLE "AgentGraphExecution" ADD COLUMN "agentPresetId" TEXT;
-- AlterTable
ALTER TABLE "AgentNodeExecutionInputOutput" ADD COLUMN "agentPresetId" TEXT;
-- AlterTable
ALTER TABLE "AnalyticsMetrics" ALTER COLUMN "id" DROP DEFAULT;
-- AlterTable
ALTER TABLE "CreditTransaction" RENAME CONSTRAINT "UserBlockCredit_pkey" TO "CreditTransaction_pkey";
-- CreateTable
CREATE TABLE "AgentPreset" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"name" TEXT NOT NULL,
"description" TEXT NOT NULL,
"isActive" BOOLEAN NOT NULL DEFAULT true,
"userId" TEXT NOT NULL,
"agentId" TEXT NOT NULL,
"agentVersion" INTEGER NOT NULL,
CONSTRAINT "AgentPreset_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "UserAgent" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"userId" TEXT NOT NULL,
"agentId" TEXT NOT NULL,
"agentVersion" INTEGER NOT NULL,
"agentPresetId" TEXT,
"isFavorite" BOOLEAN NOT NULL DEFAULT false,
"isCreatedByUser" BOOLEAN NOT NULL DEFAULT false,
"isArchived" BOOLEAN NOT NULL DEFAULT false,
"isDeleted" BOOLEAN NOT NULL DEFAULT false,
CONSTRAINT "UserAgent_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "Profile" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"userId" TEXT,
"name" TEXT NOT NULL,
"username" TEXT NOT NULL,
"description" TEXT NOT NULL,
"links" TEXT[],
"avatarUrl" TEXT,
CONSTRAINT "Profile_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "StoreListing" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"isDeleted" BOOLEAN NOT NULL DEFAULT false,
"isApproved" BOOLEAN NOT NULL DEFAULT false,
"agentId" TEXT NOT NULL,
"agentVersion" INTEGER NOT NULL,
"owningUserId" TEXT NOT NULL,
CONSTRAINT "StoreListing_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "StoreListingVersion" (
"id" TEXT NOT NULL,
"version" INTEGER NOT NULL DEFAULT 1,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"agentId" TEXT NOT NULL,
"agentVersion" INTEGER NOT NULL,
"slug" TEXT NOT NULL,
"name" TEXT NOT NULL,
"subHeading" TEXT NOT NULL,
"videoUrl" TEXT,
"imageUrls" TEXT[],
"description" TEXT NOT NULL,
"categories" TEXT[],
"isFeatured" BOOLEAN NOT NULL DEFAULT false,
"isDeleted" BOOLEAN NOT NULL DEFAULT false,
"isAvailable" BOOLEAN NOT NULL DEFAULT true,
"isApproved" BOOLEAN NOT NULL DEFAULT false,
"storeListingId" TEXT,
CONSTRAINT "StoreListingVersion_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "StoreListingReview" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"storeListingVersionId" TEXT NOT NULL,
"reviewByUserId" TEXT NOT NULL,
"score" INTEGER NOT NULL,
"comments" TEXT,
CONSTRAINT "StoreListingReview_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "StoreListingSubmission" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"storeListingId" TEXT NOT NULL,
"storeListingVersionId" TEXT NOT NULL,
"reviewerId" TEXT NOT NULL,
"Status" "SubmissionStatus" NOT NULL DEFAULT 'PENDING',
"reviewComments" TEXT,
CONSTRAINT "StoreListingSubmission_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE INDEX "AgentPreset_userId_idx" ON "AgentPreset"("userId");
-- CreateIndex
CREATE INDEX "UserAgent_userId_idx" ON "UserAgent"("userId");
-- CreateIndex
CREATE UNIQUE INDEX "Profile_username_key" ON "Profile"("username");
-- CreateIndex
CREATE INDEX "Profile_username_idx" ON "Profile"("username");
-- CreateIndex
CREATE INDEX "Profile_userId_idx" ON "Profile"("userId");
-- CreateIndex
CREATE INDEX "StoreListing_isApproved_idx" ON "StoreListing"("isApproved");
-- CreateIndex
CREATE INDEX "StoreListing_agentId_idx" ON "StoreListing"("agentId");
-- CreateIndex
CREATE INDEX "StoreListing_owningUserId_idx" ON "StoreListing"("owningUserId");
-- CreateIndex
CREATE INDEX "StoreListingVersion_agentId_agentVersion_isApproved_idx" ON "StoreListingVersion"("agentId", "agentVersion", "isApproved");
-- CreateIndex
CREATE UNIQUE INDEX "StoreListingVersion_agentId_agentVersion_key" ON "StoreListingVersion"("agentId", "agentVersion");
-- CreateIndex
CREATE INDEX "StoreListingReview_storeListingVersionId_idx" ON "StoreListingReview"("storeListingVersionId");
-- CreateIndex
CREATE UNIQUE INDEX "StoreListingReview_storeListingVersionId_reviewByUserId_key" ON "StoreListingReview"("storeListingVersionId", "reviewByUserId");
-- CreateIndex
CREATE INDEX "StoreListingSubmission_storeListingId_idx" ON "StoreListingSubmission"("storeListingId");
-- CreateIndex
CREATE INDEX "StoreListingSubmission_Status_idx" ON "StoreListingSubmission"("Status");
-- RenameForeignKey
ALTER TABLE "CreditTransaction" RENAME CONSTRAINT "UserBlockCredit_blockId_fkey" TO "CreditTransaction_blockId_fkey";
-- RenameForeignKey
ALTER TABLE "CreditTransaction" RENAME CONSTRAINT "UserBlockCredit_userId_fkey" TO "CreditTransaction_userId_fkey";
-- AddForeignKey
ALTER TABLE "AgentPreset" ADD CONSTRAINT "AgentPreset_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "AgentPreset" ADD CONSTRAINT "AgentPreset_agentId_agentVersion_fkey" FOREIGN KEY ("agentId", "agentVersion") REFERENCES "AgentGraph"("id", "version") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "UserAgent" ADD CONSTRAINT "UserAgent_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "UserAgent" ADD CONSTRAINT "UserAgent_agentId_agentVersion_fkey" FOREIGN KEY ("agentId", "agentVersion") REFERENCES "AgentGraph"("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "UserAgent" ADD CONSTRAINT "UserAgent_agentPresetId_fkey" FOREIGN KEY ("agentPresetId") REFERENCES "AgentPreset"("id") ON DELETE SET NULL ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "AgentGraphExecution" ADD CONSTRAINT "AgentGraphExecution_agentPresetId_fkey" FOREIGN KEY ("agentPresetId") REFERENCES "AgentPreset"("id") ON DELETE SET NULL ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "AgentNodeExecutionInputOutput" ADD CONSTRAINT "AgentNodeExecutionInputOutput_agentPresetId_fkey" FOREIGN KEY ("agentPresetId") REFERENCES "AgentPreset"("id") ON DELETE SET NULL ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "Profile" ADD CONSTRAINT "Profile_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "StoreListing" ADD CONSTRAINT "StoreListing_agentId_agentVersion_fkey" FOREIGN KEY ("agentId", "agentVersion") REFERENCES "AgentGraph"("id", "version") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "StoreListing" ADD CONSTRAINT "StoreListing_owningUserId_fkey" FOREIGN KEY ("owningUserId") REFERENCES "User"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "StoreListingVersion" ADD CONSTRAINT "StoreListingVersion_agentId_agentVersion_fkey" FOREIGN KEY ("agentId", "agentVersion") REFERENCES "AgentGraph"("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "StoreListingVersion" ADD CONSTRAINT "StoreListingVersion_storeListingId_fkey" FOREIGN KEY ("storeListingId") REFERENCES "StoreListing"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "StoreListingReview" ADD CONSTRAINT "StoreListingReview_storeListingVersionId_fkey" FOREIGN KEY ("storeListingVersionId") REFERENCES "StoreListingVersion"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "StoreListingReview" ADD CONSTRAINT "StoreListingReview_reviewByUserId_fkey" FOREIGN KEY ("reviewByUserId") REFERENCES "User"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "StoreListingSubmission" ADD CONSTRAINT "StoreListingSubmission_storeListingId_fkey" FOREIGN KEY ("storeListingId") REFERENCES "StoreListing"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "StoreListingSubmission" ADD CONSTRAINT "StoreListingSubmission_storeListingVersionId_fkey" FOREIGN KEY ("storeListingVersionId") REFERENCES "StoreListingVersion"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "StoreListingSubmission" ADD CONSTRAINT "StoreListingSubmission_reviewerId_fkey" FOREIGN KEY ("reviewerId") REFERENCES "User"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
-- RenameIndex
ALTER INDEX "UserBlockCredit_userId_createdAt_idx" RENAME TO "CreditTransaction_userId_createdAt_idx";

View File

@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "Profile" ADD COLUMN "isFeatured" BOOLEAN NOT NULL DEFAULT false;

View File

@@ -0,0 +1,119 @@
BEGIN;
CREATE VIEW "StoreAgent" AS
WITH ReviewStats AS (
SELECT sl."id" AS "storeListingId",
COUNT(sr.id) AS review_count,
AVG(CAST(sr.score AS DECIMAL)) AS avg_rating
FROM "StoreListing" sl
JOIN "StoreListingVersion" slv ON slv."storeListingId" = sl."id"
JOIN "StoreListingReview" sr ON sr."storeListingVersionId" = slv.id
WHERE sl."isDeleted" = FALSE
GROUP BY sl."id"
),
AgentRuns AS (
SELECT "agentGraphId", COUNT(*) AS run_count
FROM "AgentGraphExecution"
GROUP BY "agentGraphId"
)
SELECT
sl.id AS listing_id,
slv.id AS "storeListingVersionId",
slv."createdAt" AS updated_at,
slv.slug,
a.name AS agent_name,
slv."videoUrl" AS agent_video,
COALESCE(slv."imageUrls", ARRAY[]::TEXT[]) AS agent_image,
slv."isFeatured" AS featured,
p.username AS creator_username,
p."avatarUrl" AS creator_avatar,
slv."subHeading" AS sub_heading,
slv.description,
slv.categories,
COALESCE(ar.run_count, 0) AS runs,
CAST(COALESCE(rs.avg_rating, 0.0) AS DOUBLE PRECISION) AS rating,
ARRAY_AGG(DISTINCT CAST(slv.version AS TEXT)) AS versions
FROM "StoreListing" sl
JOIN "AgentGraph" a ON sl."agentId" = a.id AND sl."agentVersion" = a."version"
LEFT JOIN "Profile" p ON sl."owningUserId" = p."userId"
LEFT JOIN "StoreListingVersion" slv ON slv."storeListingId" = sl.id
LEFT JOIN ReviewStats rs ON sl.id = rs."storeListingId"
LEFT JOIN AgentRuns ar ON a.id = ar."agentGraphId"
WHERE sl."isDeleted" = FALSE
AND sl."isApproved" = TRUE
GROUP BY sl.id, slv.id, slv.slug, slv."createdAt", a.name, slv."videoUrl", slv."imageUrls", slv."isFeatured",
p.username, p."avatarUrl", slv."subHeading", slv.description, slv.categories,
ar.run_count, rs.avg_rating;
CREATE VIEW "Creator" AS
WITH AgentStats AS (
SELECT
p.username,
COUNT(DISTINCT sl.id) as num_agents,
AVG(CAST(COALESCE(sr.score, 0) AS DECIMAL)) as agent_rating,
SUM(COALESCE(age.run_count, 0)) as agent_runs
FROM "Profile" p
LEFT JOIN "StoreListing" sl ON sl."owningUserId" = p."userId"
LEFT JOIN "StoreListingVersion" slv ON slv."storeListingId" = sl.id
LEFT JOIN "StoreListingReview" sr ON sr."storeListingVersionId" = slv.id
LEFT JOIN (
SELECT "agentGraphId", COUNT(*) as run_count
FROM "AgentGraphExecution"
GROUP BY "agentGraphId"
) age ON age."agentGraphId" = sl."agentId"
WHERE sl."isDeleted" = FALSE AND sl."isApproved" = TRUE
GROUP BY p.username
)
SELECT
p.username,
p.name,
p."avatarUrl" as avatar_url,
p.description,
ARRAY_AGG(DISTINCT c) FILTER (WHERE c IS NOT NULL) as top_categories,
p.links,
p."isFeatured" as is_featured,
COALESCE(ast.num_agents, 0) as num_agents,
COALESCE(ast.agent_rating, 0.0) as agent_rating,
COALESCE(ast.agent_runs, 0) as agent_runs
FROM "Profile" p
LEFT JOIN AgentStats ast ON ast.username = p.username
LEFT JOIN LATERAL (
SELECT UNNEST(slv.categories) as c
FROM "StoreListing" sl
JOIN "StoreListingVersion" slv ON slv."storeListingId" = sl.id
WHERE sl."owningUserId" = p."userId"
AND sl."isDeleted" = FALSE
AND sl."isApproved" = TRUE
) cats ON true
GROUP BY p.username, p.name, p."avatarUrl", p.description, p.links, p."isFeatured",
ast.num_agents, ast.agent_rating, ast.agent_runs;
CREATE VIEW "StoreSubmission" AS
SELECT
sl.id as listing_id,
sl."owningUserId" as user_id,
slv."agentId" as agent_id,
slv."version" as agent_version,
slv.slug,
slv.name,
slv."subHeading" as sub_heading,
slv.description,
slv."imageUrls" as image_urls,
slv."createdAt" as date_submitted,
COALESCE(sls."Status", 'PENDING') as status,
COALESCE(ar.run_count, 0) as runs,
CAST(COALESCE(AVG(CAST(sr.score AS DECIMAL)), 0.0) AS DOUBLE PRECISION) as rating
FROM "StoreListing" sl
JOIN "StoreListingVersion" slv ON slv."storeListingId" = sl.id
LEFT JOIN "StoreListingSubmission" sls ON sls."storeListingId" = sl.id
LEFT JOIN "StoreListingReview" sr ON sr."storeListingVersionId" = slv.id
LEFT JOIN (
SELECT "agentGraphId", COUNT(*) as run_count
FROM "AgentGraphExecution"
GROUP BY "agentGraphId"
) ar ON ar."agentGraphId" = slv."agentId"
WHERE sl."isDeleted" = FALSE
GROUP BY sl.id, sl."owningUserId", slv."agentId", slv."version", slv.slug, slv.name, slv."subHeading",
slv.description, slv."imageUrls", slv."createdAt", sls."Status", ar.run_count;
COMMIT;

File diff suppressed because it is too large Load Diff

View File

@@ -8,25 +8,26 @@ packages = [{ include = "backend" }]
[tool.poetry.dependencies]
python = "^3.10"
aio-pika = "^9.5.0"
anthropic = "^0.39.0"
python = ">=3.10,<3.13"
aio-pika = "^9.5.4"
anthropic = "^0.40.0"
apscheduler = "^3.11.0"
autogpt-libs = { path = "../autogpt_libs", develop = true }
click = "^8.1.7"
croniter = "^5.0.1"
discord-py = "^2.4.0"
e2b-code-interpreter = "^1.0.1"
fastapi = "^0.115.5"
feedparser = "^6.0.11"
flake8 = "^7.0.0"
google-api-python-client = "^2.154.0"
google-auth-oauthlib = "^1.2.1"
groq = "^0.12.0"
groq = "^0.13.1"
jinja2 = "^3.1.4"
jsonref = "^1.1.0"
jsonschema = "^4.22.0"
ollama = "^0.4.1"
openai = "^1.55.1"
openai = "^1.57.4"
praw = "~7.8.1"
prisma = "^0.15.0"
psutil = "^6.1.0"
@@ -34,23 +35,26 @@ pydantic = "^2.9.2"
pydantic-settings = "^2.3.4"
pyro5 = "^5.15"
pytest = "^8.2.1"
pytest-asyncio = "^0.24.0"
pytest-asyncio = "^0.25.0"
python-dotenv = "^1.0.1"
redis = "^5.2.0"
sentry-sdk = "2.19.0"
sentry-sdk = "2.19.2"
strenum = "^0.4.9"
supabase = "^2.10.0"
tenacity = "^9.0.0"
uvicorn = { extras = ["standard"], version = "^0.32.1" }
uvicorn = { extras = ["standard"], version = "^0.34.0" }
websockets = "^13.1"
youtube-transcript-api = "^0.6.2"
googlemaps = "^4.10.0"
replicate = "^1.0.4"
pinecone = "^5.3.1"
cryptography = "^43.0.3"
cryptography = "^43.0"
python-multipart = "^0.0.20"
sqlalchemy = "^2.0.36"
psycopg2-binary = "^2.9.10"
google-cloud-storage = "^2.18.2"
launchdarkly-server-sdk = "^9.8.0"
[tool.poetry.group.dev.dependencies]
poethepoet = "^0.31.0"
httpx = "^0.27.0"
@@ -61,6 +65,8 @@ pyright = "^1.1.389"
isort = "^5.13.2"
black = "^24.10.0"
aiohappyeyeballs = "^2.4.3"
pytest-mock = "^3.14.0"
faker = "^33.1.0"
[build-system]
requires = ["poetry-core"]
@@ -90,3 +96,6 @@ ignore_patterns = []
[tool.pytest.ini_options]
asyncio_mode = "auto"
[tool.ruff]
target-version = "py310"

View File

@@ -16,9 +16,9 @@ def wait_for_postgres(max_retries=5, delay=5):
"postgres-test",
"pg_isready",
"-U",
"agpt_user",
"postgres",
"-d",
"agpt_local",
"postgres",
],
check=True,
capture_output=True,

View File

@@ -8,26 +8,36 @@ generator client {
provider = "prisma-client-py"
recursive_type_depth = 5
interface = "asyncio"
previewFeatures = ["views"]
}
// User model to mirror Auth provider users
model User {
id String @id // This should match the Supabase user ID
email String @unique
name String?
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
metadata Json @default("{}")
integrations String @default("")
id String @id // This should match the Supabase user ID
email String @unique
name String?
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
metadata Json @default("{}")
integrations String @default("")
stripeCustomerId String?
// Relations
AgentGraphs AgentGraph[]
AgentGraphExecutions AgentGraphExecution[]
IntegrationWebhooks IntegrationWebhook[]
AnalyticsDetails AnalyticsDetails[]
AnalyticsMetrics AnalyticsMetrics[]
UserBlockCredit UserBlockCredit[]
APIKeys APIKey[]
AgentGraphs AgentGraph[]
AgentGraphExecutions AgentGraphExecution[]
AnalyticsDetails AnalyticsDetails[]
AnalyticsMetrics AnalyticsMetrics[]
CreditTransaction CreditTransaction[]
AgentPreset AgentPreset[]
UserAgent UserAgent[]
Profile Profile[]
StoreListing StoreListing[]
StoreListingReview StoreListingReview[]
StoreListingSubmission StoreListingSubmission[]
APIKeys APIKey[]
IntegrationWebhooks IntegrationWebhook[]
@@index([id])
@@index([email])
@@ -47,14 +57,89 @@ model AgentGraph {
// Link to User model
userId String
// FIX: Do not cascade delete the agent when the user is deleted
// This allows us to delete user data with deleting the agent which maybe in use by other users
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
AgentNodes AgentNode[]
AgentGraphExecution AgentGraphExecution[]
AgentNodes AgentNode[]
AgentGraphExecution AgentGraphExecution[]
AgentPreset AgentPreset[]
UserAgent UserAgent[]
StoreListing StoreListing[]
StoreListingVersion StoreListingVersion?
@@id(name: "graphVersionId", [id, version])
@@index([userId, isActive])
}
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
//////////////// USER SPECIFIC DATA ////////////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
// An AgentPrest is an Agent + User Configuration of that agent.
// For example, if someone has created a weather agent and they want to set it up to
// Inform them of extreme weather warnings in Texas, the agent with the configuration to set it to
// monitor texas, along with the cron setup or webhook tiggers, is an AgentPreset
model AgentPreset {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
name String
description String
// For agents that can be triggered by webhooks or cronjob
// This bool allows us to disable a configured agent without deleting it
isActive Boolean @default(true)
userId String
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
agentId String
agentVersion Int
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)
InputPresets AgentNodeExecutionInputOutput[] @relation("AgentPresetsInputData")
UserAgents UserAgent[]
AgentExecution AgentGraphExecution[]
@@index([userId])
}
// For the library page
// It is a user controlled list of agents, that they will see in there library
model UserAgent {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
userId String
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
agentId String
agentVersion Int
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version])
agentPresetId String?
AgentPreset AgentPreset? @relation(fields: [agentPresetId], references: [id])
isFavorite Boolean @default(false)
isCreatedByUser Boolean @default(false)
isArchived Boolean @default(false)
isDeleted Boolean @default(false)
@@index([userId])
}
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
//////// AGENT DEFINITION AND EXECUTION TABLES ////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
// This model describes a single node in the Agent Graph/Flow (Multi Agent System).
model AgentNode {
id String @id @default(uuid())
@@ -83,6 +168,10 @@ model AgentNode {
metadata String @default("{}")
ExecutionHistory AgentNodeExecution[]
@@index([agentGraphId, agentGraphVersion])
@@index([agentBlockId])
@@index([webhookId])
}
// This model describes the link between two AgentNodes.
@@ -101,6 +190,9 @@ model AgentNodeLink {
// Default: the data coming from the source can only be consumed by the sink once, Static: input data will be reused.
isStatic Boolean @default(false)
@@index([agentNodeSourceId])
@@index([agentNodeSinkId])
}
// This model describes a component that will be executed by the AgentNode.
@@ -115,7 +207,7 @@ model AgentBlock {
// Prisma requires explicit back-references.
ReferencedByAgentNode AgentNode[]
UserBlockCredit UserBlockCredit[]
CreditTransaction CreditTransaction[]
}
// This model describes the status of an AgentGraphExecution or AgentNodeExecution.
@@ -146,7 +238,12 @@ model AgentGraphExecution {
userId String
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
stats String? // JSON serialized object
stats String? // JSON serialized object
AgentPreset AgentPreset? @relation(fields: [agentPresetId], references: [id])
agentPresetId String?
@@index([agentGraphId, agentGraphVersion])
@@index([userId])
}
// This model describes the execution of an AgentNode.
@@ -171,6 +268,9 @@ model AgentNodeExecution {
endedTime DateTime?
stats String? // JSON serialized object
@@index([agentGraphExecutionId])
@@index([agentNodeId])
}
// This model describes the output of an AgentNodeExecution.
@@ -187,8 +287,12 @@ model AgentNodeExecutionInputOutput {
referencedByOutputExecId String?
ReferencedByOutputExec AgentNodeExecution? @relation("AgentNodeExecutionOutput", fields: [referencedByOutputExecId], references: [id], onDelete: Cascade)
agentPresetId String?
AgentPreset AgentPreset? @relation("AgentPresetsInputData", fields: [agentPresetId], references: [id])
// Input and Output pin names are unique for each AgentNodeExecution.
@@unique([referencedByInputExecId, referencedByOutputExecId, name])
@@index([referencedByOutputExecId])
}
// Webhook that is registered with a provider and propagates to one or more nodes
@@ -211,6 +315,8 @@ model IntegrationWebhook {
providerWebhookId String // Webhook ID assigned by the provider
AgentNodes AgentNode[]
@@index([userId])
}
model AnalyticsDetails {
@@ -238,8 +344,13 @@ model AnalyticsDetails {
@@index([type])
}
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
////////////// METRICS TRACKING TABLES ////////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
model AnalyticsMetrics {
id String @id @default(dbgenerated("gen_random_uuid()"))
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
@@ -254,14 +365,21 @@ model AnalyticsMetrics {
// Link to User model
userId String
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@index([userId])
}
enum UserBlockCreditType {
enum CreditTransactionType {
TOP_UP
USAGE
}
model UserBlockCredit {
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
//////// ACCOUNTING AND CREDIT SYSTEM TABLES //////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
model CreditTransaction {
transactionKey String @default(uuid())
createdAt DateTime @default(now())
@@ -272,12 +390,215 @@ model UserBlockCredit {
block AgentBlock? @relation(fields: [blockId], references: [id])
amount Int
type UserBlockCreditType
type CreditTransactionType
isActive Boolean @default(true)
metadata Json?
@@id(name: "creditTransactionIdentifier", [transactionKey, userId])
@@index([userId, createdAt])
}
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
////////////// Store TABLES ///////////////////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
model Profile {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
// Only 1 of user or group can be set.
// The user this profile belongs to, if any.
userId String?
User User? @relation(fields: [userId], references: [id], onDelete: Cascade)
name String
username String @unique
description String
links String[]
avatarUrl String?
isFeatured Boolean @default(false)
@@index([username])
@@index([userId])
}
view Creator {
username String @unique
name String
avatar_url String
description String
top_categories String[]
links String[]
num_agents Int
agent_rating Float
agent_runs Int
is_featured Boolean
}
view StoreAgent {
listing_id String @id
storeListingVersionId String
updated_at DateTime
slug String
agent_name String
agent_video String?
agent_image String[]
featured Boolean @default(false)
creator_username String
creator_avatar String
sub_heading String
description String
categories String[]
runs Int
rating Float
versions String[]
@@unique([creator_username, slug])
@@index([creator_username])
@@index([featured])
@@index([categories])
@@index([storeListingVersionId])
}
view StoreSubmission {
listing_id String @id
user_id String
slug String
name String
sub_heading String
description String
image_urls String[]
date_submitted DateTime
status SubmissionStatus
runs Int
rating Float
agent_id String
agent_version Int
@@index([user_id])
}
model StoreListing {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
isDeleted Boolean @default(false)
// Not needed but makes lookups faster
isApproved Boolean @default(false)
// The agent link here is only so we can do lookup on agentId, for the listing the StoreListingVersion is used.
agentId String
agentVersion Int
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)
owningUserId String
OwningUser User @relation(fields: [owningUserId], references: [id])
StoreListingVersions StoreListingVersion[]
StoreListingSubmission StoreListingSubmission[]
@@index([isApproved])
@@index([agentId])
@@index([owningUserId])
}
model StoreListingVersion {
id String @id @default(uuid())
version Int @default(1)
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
// The agent and version to be listed on the store
agentId String
agentVersion Int
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version])
// The detials for this version of the agent, this allows the author to update the details of the agent,
// But still allow using old versions of the agent with there original details.
// TODO: Create a database view that shows only the latest version of each store listing.
slug String
name String
subHeading String
videoUrl String?
imageUrls String[]
description String
categories String[]
isFeatured Boolean @default(false)
isDeleted Boolean @default(false)
// Old versions can be made unavailable by the author if desired
isAvailable Boolean @default(true)
// Not needed but makes lookups faster
isApproved Boolean @default(false)
StoreListing StoreListing? @relation(fields: [storeListingId], references: [id], onDelete: Cascade)
storeListingId String?
StoreListingSubmission StoreListingSubmission[]
// Reviews are on a specific version, but then aggregated up to the listing.
// This allows us to provide a review filter to current version of the agent.
StoreListingReview StoreListingReview[]
@@unique([agentId, agentVersion])
@@index([agentId, agentVersion, isApproved])
}
model StoreListingReview {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
storeListingVersionId String
StoreListingVersion StoreListingVersion @relation(fields: [storeListingVersionId], references: [id], onDelete: Cascade)
reviewByUserId String
ReviewByUser User @relation(fields: [reviewByUserId], references: [id])
score Int
comments String?
@@unique([storeListingVersionId, reviewByUserId])
@@index([storeListingVersionId])
}
enum SubmissionStatus {
DAFT
PENDING
APPROVED
REJECTED
}
model StoreListingSubmission {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
storeListingId String
StoreListing StoreListing @relation(fields: [storeListingId], references: [id], onDelete: Cascade)
storeListingVersionId String
StoreListingVersion StoreListingVersion @relation(fields: [storeListingVersionId], references: [id], onDelete: Cascade)
reviewerId String
Reviewer User @relation(fields: [reviewerId], references: [id])
Status SubmissionStatus @default(PENDING)
reviewComments String?
@@index([storeListingId])
@@index([Status])
}
enum APIKeyPermission {
@@ -290,7 +611,7 @@ enum APIKeyPermission {
model APIKey {
id String @id @default(uuid())
name String
prefix String // First 8 chars for identification
prefix String // First 8 chars for identification
postfix String
key String @unique // Hashed key
status APIKeyStatus @default(ACTIVE)
@@ -317,4 +638,4 @@ enum APIKeyStatus {
ACTIVE
REVOKED
SUSPENDED
}
}

View File

@@ -1,628 +0,0 @@
// We need to migrate our database schema to support the domain as we understand it now
// To do so requires adding a bunch of new tables, but also modiftying old ones and how
// they relate to each other. This is a large change, so instead of doing in in one go,
// We have created the target schema, and will migrate to it incrementally.
datasource db {
provider = "postgresql"
url = env("DATABASE_URL")
}
generator client {
provider = "prisma-client-py"
recursive_type_depth = 5
interface = "asyncio"
}
// User model to mirror Auth provider users
model User {
id String @id @db.Uuid // This should match the Supabase user ID
email String @unique
name String?
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
metadata String @default("")
// Relations
Agents Agent[]
AgentExecutions AgentExecution[]
AgentExecutionSchedules AgentExecutionSchedule[]
AnalyticsDetails AnalyticsDetails[]
AnalyticsMetrics AnalyticsMetrics[]
UserBlockCredit UserBlockCredit[]
AgentPresets AgentPreset[]
UserAgents UserAgent[]
// User Group relations
UserGroupMemberships UserGroupMembership[]
Profile Profile[]
StoreListing StoreListing[]
StoreListingSubmission StoreListingSubmission[]
StoreListingReview StoreListingReview[]
}
model UserGroup {
id String @id @default(uuid()) @db.Uuid
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
name String
description String
groupIconUrl String?
UserGroupMemberships UserGroupMembership[]
Agents Agent[]
Profile Profile[]
StoreListing StoreListing[]
@@index([name])
}
enum UserGroupRole {
MEMBER
OWNER
}
model UserGroupMembership {
id String @id @default(uuid()) @db.Uuid
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
userId String @db.Uuid
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
userGroupId String @db.Uuid
UserGroup UserGroup @relation(fields: [userGroupId], references: [id], onDelete: Cascade)
Role UserGroupRole @default(MEMBER)
@@unique([userId, userGroupId])
@@index([userId])
@@index([userGroupId])
}
// This model describes the Agent Graph/Flow (Multi Agent System).
model Agent {
id String @default(uuid()) @db.Uuid
version Int @default(1)
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
name String?
description String?
// Link to User model
createdByUserId String? @db.Uuid
// Do not cascade delete the agent when the user is deleted
// This allows us to delete user data with deleting the agent which maybe in use by other users
CreatedByUser User? @relation(fields: [createdByUserId], references: [id], onDelete: SetNull)
groupId String? @db.Uuid
// Do not cascade delete the agent when the group is deleted
// This allows us to delete user group data with deleting the agent which maybe in use by other users
Group UserGroup? @relation(fields: [groupId], references: [id], onDelete: SetNull)
AgentNodes AgentNode[]
AgentExecution AgentExecution[]
// All sub-graphs are defined within this 1-level depth list (even if it's a nested graph).
SubAgents Agent[] @relation("SubAgents")
agentParentId String? @db.Uuid
agentParentVersion Int?
AgentParent Agent? @relation("SubAgents", fields: [agentParentId, agentParentVersion], references: [id, version])
AgentPresets AgentPreset[]
WebhookTrigger WebhookTrigger[]
AgentExecutionSchedule AgentExecutionSchedule[]
UserAgents UserAgent[]
UserBlockCredit UserBlockCredit[]
StoreListing StoreListing[]
StoreListingVersion StoreListingVersion[]
@@id(name: "agentVersionId", [id, version])
}
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
//////////////// USER SPECIFIC DATA ////////////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
// An AgentPrest is an Agent + User Configuration of that agent.
// For example, if someone has created a weather agent and they want to set it up to
// Inform them of extreme weather warnings in Texas, the agent with the configuration to set it to
// monitor texas, along with the cron setup or webhook tiggers, is an AgentPreset
model AgentPreset {
id String @id @default(uuid()) @db.Uuid
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
name String
description String
// For agents that can be triggered by webhooks or cronjob
// This bool allows us to disable a configured agent without deleting it
isActive Boolean @default(true)
userId String @db.Uuid
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
agentId String @db.Uuid
agentVersion Int
Agent Agent @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)
InputPresets AgentNodeExecutionInputOutput[] @relation("AgentPresetsInputData")
UserAgents UserAgent[]
WebhookTrigger WebhookTrigger[]
AgentExecutionSchedule AgentExecutionSchedule[]
AgentExecution AgentExecution[]
@@index([userId])
}
// For the library page
// It is a user controlled list of agents, that they will see in there library
model UserAgent {
id String @id @default(uuid()) @db.Uuid
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
userId String @db.Uuid
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
agentId String @db.Uuid
agentVersion Int
Agent Agent @relation(fields: [agentId, agentVersion], references: [id, version])
agentPresetId String? @db.Uuid
AgentPreset AgentPreset? @relation(fields: [agentPresetId], references: [id])
isFavorite Boolean @default(false)
isCreatedByUser Boolean @default(false)
isArchived Boolean @default(false)
isDeleted Boolean @default(false)
@@index([userId])
}
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
//////// AGENT DEFINITION AND EXECUTION TABLES ////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
// This model describes a single node in the Agent Graph/Flow (Multi Agent System).
model AgentNode {
id String @id @default(uuid()) @db.Uuid
agentBlockId String @db.Uuid
AgentBlock AgentBlock @relation(fields: [agentBlockId], references: [id], onUpdate: Cascade)
agentId String @db.Uuid
agentVersion Int @default(1)
Agent Agent @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)
// List of consumed input, that the parent node should provide.
Input AgentNodeLink[] @relation("AgentNodeSink")
// List of produced output, that the child node should be executed.
Output AgentNodeLink[] @relation("AgentNodeSource")
// JSON serialized dict[str, str] containing predefined input values.
constantInput Json @default("{}")
// JSON serialized dict[str, str] containing the node metadata.
metadata Json @default("{}")
ExecutionHistory AgentNodeExecution[]
}
// This model describes the link between two AgentNodes.
model AgentNodeLink {
id String @id @default(uuid()) @db.Uuid
// Output of a node is connected to the source of the link.
agentNodeSourceId String @db.Uuid
AgentNodeSource AgentNode @relation("AgentNodeSource", fields: [agentNodeSourceId], references: [id], onDelete: Cascade)
sourceName String
// Input of a node is connected to the sink of the link.
agentNodeSinkId String @db.Uuid
AgentNodeSink AgentNode @relation("AgentNodeSink", fields: [agentNodeSinkId], references: [id], onDelete: Cascade)
sinkName String
// Default: the data coming from the source can only be consumed by the sink once, Static: input data will be reused.
isStatic Boolean @default(false)
}
// This model describes a component that will be executed by the AgentNode.
model AgentBlock {
id String @id @default(uuid()) @db.Uuid
name String @unique
// We allow a block to have multiple types of input & output.
// Serialized object-typed `jsonschema` with top-level properties as input/output name.
inputSchema Json @default("{}")
outputSchema Json @default("{}")
// Prisma requires explicit back-references.
ReferencedByAgentNode AgentNode[]
UserBlockCredit UserBlockCredit[]
}
// This model describes the status of an AgentExecution or AgentNodeExecution.
enum AgentExecutionStatus {
INCOMPLETE
QUEUED
RUNNING
COMPLETED
FAILED
}
// Enum for execution trigger types
enum ExecutionTriggerType {
MANUAL
SCHEDULE
WEBHOOK
}
// This model describes the execution of an AgentGraph.
model AgentExecution {
id String @id @default(uuid()) @db.Uuid
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
startedAt DateTime?
executionTriggerType ExecutionTriggerType @default(MANUAL)
executionStatus AgentExecutionStatus @default(COMPLETED)
agentId String @db.Uuid
agentVersion Int @default(1)
Agent Agent @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)
// we need to be able to associate an agent execution with an agent preset
agentPresetId String? @db.Uuid
AgentPreset AgentPreset? @relation(fields: [agentPresetId], references: [id])
AgentNodeExecutions AgentNodeExecution[]
// This is so we can track which user executed the agent.
executedByUserId String @db.Uuid
ExecutedByUser User @relation(fields: [executedByUserId], references: [id], onDelete: Cascade)
stats Json @default("{}") // JSON serialized object
}
// This model describes the execution of an AgentNode.
model AgentNodeExecution {
id String @id @default(uuid()) @db.Uuid
agentExecutionId String @db.Uuid
AgentExecution AgentExecution @relation(fields: [agentExecutionId], references: [id], onDelete: Cascade)
agentNodeId String @db.Uuid
AgentNode AgentNode @relation(fields: [agentNodeId], references: [id], onDelete: Cascade)
Input AgentNodeExecutionInputOutput[] @relation("AgentNodeExecutionInput")
Output AgentNodeExecutionInputOutput[] @relation("AgentNodeExecutionOutput")
executionStatus AgentExecutionStatus @default(COMPLETED)
// Final JSON serialized input data for the node execution.
executionData String?
addedTime DateTime @default(now())
queuedTime DateTime?
startedTime DateTime?
endedTime DateTime?
stats Json @default("{}") // JSON serialized object
UserBlockCredit UserBlockCredit[]
}
// This model describes the output of an AgentNodeExecution.
model AgentNodeExecutionInputOutput {
id String @id @default(uuid()) @db.Uuid
name String
data String
time DateTime @default(now())
// Prisma requires explicit back-references.
referencedByInputExecId String? @db.Uuid
ReferencedByInputExec AgentNodeExecution? @relation("AgentNodeExecutionInput", fields: [referencedByInputExecId], references: [id], onDelete: Cascade)
referencedByOutputExecId String? @db.Uuid
ReferencedByOutputExec AgentNodeExecution? @relation("AgentNodeExecutionOutput", fields: [referencedByOutputExecId], references: [id], onDelete: Cascade)
agentPresetId String? @db.Uuid
AgentPreset AgentPreset? @relation("AgentPresetsInputData", fields: [agentPresetId], references: [id])
// Input and Output pin names are unique for each AgentNodeExecution.
@@unique([referencedByInputExecId, referencedByOutputExecId, name])
}
// This model describes the recurring execution schedule of an Agent.
model AgentExecutionSchedule {
id String @id @default(uuid()) @db.Uuid
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
agentPresetId String @db.Uuid
AgentPreset AgentPreset @relation(fields: [agentPresetId], references: [id], onDelete: Cascade)
schedule String // cron expression
isEnabled Boolean @default(true)
// Allows triggers to be routed down different execution paths in an agent graph
triggerIdentifier String
// default and set the value on each update, lastUpdated field has no time zone.
lastUpdated DateTime @default(now()) @updatedAt
// Link to User model
userId String @db.Uuid
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
Agent Agent? @relation(fields: [agentId, agentVersion], references: [id, version])
agentId String? @db.Uuid
agentVersion Int?
@@index([isEnabled])
}
enum HttpMethod {
GET
POST
PUT
DELETE
PATCH
}
model WebhookTrigger {
id String @id @default(uuid()) @db.Uuid
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
agentPresetId String @db.Uuid
AgentPreset AgentPreset @relation(fields: [agentPresetId], references: [id])
method HttpMethod
urlSlug String
// Allows triggers to be routed down different execution paths in an agent graph
triggerIdentifier String
isActive Boolean @default(true)
lastReceivedDataAt DateTime?
isDeleted Boolean @default(false)
Agent Agent? @relation(fields: [agentId, agentVersion], references: [id, version])
agentId String? @db.Uuid
agentVersion Int?
@@index([agentPresetId])
}
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
////////////// METRICS TRACKING TABLES ////////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
model AnalyticsDetails {
// PK uses gen_random_uuid() to allow the db inserts to happen outside of prisma
// typical uuid() inserts are handled by prisma
id String @id @default(dbgenerated("gen_random_uuid()")) @db.Uuid
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
// Link to User model
userId String @db.Uuid
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
// Analytics Categorical data used for filtering (indexable w and w/o userId)
type String
// Analytic Specific Data. We should use a union type here, but prisma doesn't support it.
data Json @default("{}")
// Indexable field for any count based analytical measures like page order clicking, tutorial step completion, etc.
dataIndex String?
@@index([userId, type], name: "analyticsDetails")
@@index([type])
}
model AnalyticsMetrics {
id String @id @default(dbgenerated("gen_random_uuid()")) @db.Uuid
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
// Analytics Categorical data used for filtering (indexable w and w/o userId)
analyticMetric String
// Any numeric data that should be counted upon, summed, or otherwise aggregated.
value Float
// Any string data that should be used to identify the metric as distinct.
// ex: '/build' vs '/market'
dataString String?
// Link to User model
userId String @db.Uuid
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
}
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
//////// ACCOUNTING AND CREDIT SYSTEM TABLES //////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
enum UserBlockCreditType {
TOP_UP
USAGE
}
model UserBlockCredit {
transactionKey String @default(uuid())
createdAt DateTime @default(now())
userId String @db.Uuid
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
blockId String? @db.Uuid
Block AgentBlock? @relation(fields: [blockId], references: [id])
// We need to be able to associate a credit transaction with an agent
executedAgentId String? @db.Uuid
executedAgentVersion Int?
ExecutedAgent Agent? @relation(fields: [executedAgentId, executedAgentVersion], references: [id, version])
// We need to be able to associate a cost with a specific agent execution
agentNodeExecutionId String? @db.Uuid
AgentNodeExecution AgentNodeExecution? @relation(fields: [agentNodeExecutionId], references: [id])
amount Int
type UserBlockCreditType
isActive Boolean @default(true)
metadata Json @default("{}")
@@id(name: "creditTransactionIdentifier", [transactionKey, userId])
}
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
////////////// Store TABLES ///////////////////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
model Profile {
id String @id @default(uuid()) @db.Uuid
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
// Only 1 of user or group can be set.
// The user this profile belongs to, if any.
userId String? @db.Uuid
User User? @relation(fields: [userId], references: [id], onDelete: Cascade)
// The group this profile belongs to, if any.
groupId String? @db.Uuid
Group UserGroup? @relation(fields: [groupId], references: [id])
username String @unique
description String
links String[]
avatarUrl String?
@@index([username])
}
model StoreListing {
id String @id @default(uuid()) @db.Uuid
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
isDeleted Boolean @default(false)
// Not needed but makes lookups faster
isApproved Boolean @default(false)
// The agent link here is only so we can do lookup on agentId, for the listing the StoreListingVersion is used.
agentId String @db.Uuid
agentVersion Int
Agent Agent @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)
owningUserId String @db.Uuid
OwningUser User @relation(fields: [owningUserId], references: [id])
isGroupListing Boolean @default(false)
owningGroupId String? @db.Uuid
OwningGroup UserGroup? @relation(fields: [owningGroupId], references: [id])
StoreListingVersions StoreListingVersion[]
StoreListingSubmission StoreListingSubmission[]
@@index([isApproved])
@@index([agentId])
@@index([owningUserId])
@@index([owningGroupId])
}
model StoreListingVersion {
id String @id @default(uuid()) @db.Uuid
version Int @default(1)
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
// The agent and version to be listed on the store
agentId String @db.Uuid
agentVersion Int
Agent Agent @relation(fields: [agentId, agentVersion], references: [id, version])
// The detials for this version of the agent, this allows the author to update the details of the agent,
// But still allow using old versions of the agent with there original details.
// TODO: Create a database view that shows only the latest version of each store listing.
slug String
name String
videoUrl String?
imageUrls String[]
description String
categories String[]
isFeatured Boolean @default(false)
isDeleted Boolean @default(false)
// Old versions can be made unavailable by the author if desired
isAvailable Boolean @default(true)
// Not needed but makes lookups faster
isApproved Boolean @default(false)
StoreListing StoreListing? @relation(fields: [storeListingId], references: [id], onDelete: Cascade)
storeListingId String? @db.Uuid
StoreListingSubmission StoreListingSubmission[]
// Reviews are on a specific version, but then aggregated up to the listing.
// This allows us to provide a review filter to current version of the agent.
StoreListingReview StoreListingReview[]
@@unique([agentId, agentVersion])
@@index([agentId, agentVersion, isApproved])
}
model StoreListingReview {
id String @id @default(dbgenerated("gen_random_uuid()")) @db.Uuid
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
storeListingVersionId String @db.Uuid
StoreListingVersion StoreListingVersion @relation(fields: [storeListingVersionId], references: [id], onDelete: Cascade)
reviewByUserId String @db.Uuid
ReviewByUser User @relation(fields: [reviewByUserId], references: [id])
score Int
comments String?
}
enum SubmissionStatus {
DAFT
PENDING
APPROVED
REJECTED
}
model StoreListingSubmission {
id String @id @default(uuid()) @db.Uuid
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
storeListingId String @db.Uuid
StoreListing StoreListing @relation(fields: [storeListingId], references: [id], onDelete: Cascade)
storeListingVersionId String @db.Uuid
StoreListingVersion StoreListingVersion @relation(fields: [storeListingVersionId], references: [id], onDelete: Cascade)
reviewerId String @db.Uuid
Reviewer User @relation(fields: [reviewerId], references: [id])
Status SubmissionStatus @default(PENDING)
reviewComments String?
@@index([storeListingId])
@@index([Status])
}

View File

@@ -1,7 +1,7 @@
from datetime import datetime
import pytest
from prisma.models import UserBlockCredit
from prisma.models import CreditTransaction
from backend.blocks.llm import AITextGeneratorBlock
from backend.data.credit import UserCredit
@@ -82,7 +82,7 @@ async def test_block_credit_reset(server: SpinTestServer):
@pytest.mark.asyncio(scope="session")
async def test_credit_refill(server: SpinTestServer):
# Clear all transactions within the month
await UserBlockCredit.prisma().update_many(
await CreditTransaction.prisma().update_many(
where={
"userId": DEFAULT_USER_ID,
"createdAt": {

View File

@@ -14,7 +14,6 @@ async def test_agent_schedule(server: SpinTestServer):
test_user = await create_test_user()
test_graph = await server.agent_server.test_create_graph(
create_graph=CreateGraph(graph=create_test_graph()),
is_template=False,
user_id=test_user.id,
)

View File

@@ -0,0 +1,436 @@
import asyncio
import random
from datetime import datetime
import prisma.enums
from faker import Faker
from prisma import Prisma
faker = Faker()
# Constants for data generation limits
# Base entities
NUM_USERS = 100 # Creates 100 user records
NUM_AGENT_BLOCKS = 100 # Creates 100 agent block templates
# Per-user entities
MIN_GRAPHS_PER_USER = 1 # Each user will have between 1-5 graphs
MAX_GRAPHS_PER_USER = 5 # Total graphs: 500-2500 (NUM_USERS * MIN/MAX_GRAPHS)
# Per-graph entities
MIN_NODES_PER_GRAPH = 2 # Each graph will have between 2-5 nodes
MAX_NODES_PER_GRAPH = (
5 # Total nodes: 1000-2500 (GRAPHS_PER_USER * NUM_USERS * MIN/MAX_NODES)
)
# Additional per-user entities
MIN_PRESETS_PER_USER = 1 # Each user will have between 1-2 presets
MAX_PRESETS_PER_USER = 5 # Total presets: 500-2500 (NUM_USERS * MIN/MAX_PRESETS)
MIN_AGENTS_PER_USER = 1 # Each user will have between 1-2 agents
MAX_AGENTS_PER_USER = 10 # Total agents: 500-5000 (NUM_USERS * MIN/MAX_AGENTS)
# Execution and review records
MIN_EXECUTIONS_PER_GRAPH = 1 # Each graph will have between 1-5 execution records
MAX_EXECUTIONS_PER_GRAPH = (
20 # Total executions: 1000-5000 (TOTAL_GRAPHS * MIN/MAX_EXECUTIONS)
)
MIN_REVIEWS_PER_VERSION = 1 # Each version will have between 1-3 reviews
MAX_REVIEWS_PER_VERSION = 5 # Total reviews depends on number of versions created
def get_image():
url = faker.image_url()
while "placekitten.com" in url:
url = faker.image_url()
return url
async def main():
db = Prisma()
await db.connect()
# Insert Users
print(f"Inserting {NUM_USERS} users")
users = []
for _ in range(NUM_USERS):
user = await db.user.create(
data={
"id": str(faker.uuid4()),
"email": faker.unique.email(),
"name": faker.name(),
"metadata": prisma.Json({}),
"integrations": "",
}
)
users.append(user)
# Insert AgentBlocks
agent_blocks = []
print(f"Inserting {NUM_AGENT_BLOCKS} agent blocks")
for _ in range(NUM_AGENT_BLOCKS):
block = await db.agentblock.create(
data={
"name": f"{faker.word()}_{str(faker.uuid4())[:8]}",
"inputSchema": "{}",
"outputSchema": "{}",
}
)
agent_blocks.append(block)
# Insert AgentGraphs
agent_graphs = []
print(f"Inserting {NUM_USERS * MAX_GRAPHS_PER_USER} agent graphs")
for user in users:
for _ in range(
random.randint(MIN_GRAPHS_PER_USER, MAX_GRAPHS_PER_USER)
): # Adjust the range to create more graphs per user if desired
graph = await db.agentgraph.create(
data={
"name": faker.sentence(nb_words=3),
"description": faker.text(max_nb_chars=200),
"userId": user.id,
"isActive": True,
"isTemplate": False,
}
)
agent_graphs.append(graph)
# Insert AgentNodes
agent_nodes = []
print(
f"Inserting {NUM_USERS * MAX_GRAPHS_PER_USER * MAX_NODES_PER_GRAPH} agent nodes"
)
for graph in agent_graphs:
num_nodes = random.randint(MIN_NODES_PER_GRAPH, MAX_NODES_PER_GRAPH)
for _ in range(num_nodes): # Create 5 AgentNodes per graph
block = random.choice(agent_blocks)
node = await db.agentnode.create(
data={
"agentBlockId": block.id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"constantInput": "{}",
"metadata": "{}",
}
)
agent_nodes.append(node)
# Insert AgentPresets
agent_presets = []
print(f"Inserting {NUM_USERS * MAX_PRESETS_PER_USER} agent presets")
for user in users:
num_presets = random.randint(MIN_PRESETS_PER_USER, MAX_PRESETS_PER_USER)
for _ in range(num_presets): # Create 1 AgentPreset per user
graph = random.choice(agent_graphs)
preset = await db.agentpreset.create(
data={
"name": faker.sentence(nb_words=3),
"description": faker.text(max_nb_chars=200),
"userId": user.id,
"agentId": graph.id,
"agentVersion": graph.version,
"isActive": True,
}
)
agent_presets.append(preset)
# Insert UserAgents
user_agents = []
print(f"Inserting {NUM_USERS * MAX_AGENTS_PER_USER} user agents")
for user in users:
num_agents = random.randint(MIN_AGENTS_PER_USER, MAX_AGENTS_PER_USER)
for _ in range(num_agents): # Create 1 UserAgent per user
graph = random.choice(agent_graphs)
preset = random.choice(agent_presets)
user_agent = await db.useragent.create(
data={
"userId": user.id,
"agentId": graph.id,
"agentVersion": graph.version,
"agentPresetId": preset.id,
"isFavorite": random.choice([True, False]),
"isCreatedByUser": random.choice([True, False]),
"isArchived": random.choice([True, False]),
"isDeleted": random.choice([True, False]),
}
)
user_agents.append(user_agent)
# Insert AgentGraphExecutions
# Insert AgentGraphExecutions
agent_graph_executions = []
print(
f"Inserting {NUM_USERS * MAX_GRAPHS_PER_USER * MAX_EXECUTIONS_PER_GRAPH} agent graph executions"
)
graph_execution_data = []
for graph in agent_graphs:
user = random.choice(users)
num_executions = random.randint(
MIN_EXECUTIONS_PER_GRAPH, MAX_EXECUTIONS_PER_GRAPH
)
for _ in range(num_executions):
matching_presets = [p for p in agent_presets if p.agentId == graph.id]
preset = (
random.choice(matching_presets)
if matching_presets and random.random() < 0.5
else None
)
graph_execution_data.append(
{
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"userId": user.id,
"executionStatus": prisma.enums.AgentExecutionStatus.COMPLETED,
"startedAt": faker.date_time_this_year(),
"agentPresetId": preset.id if preset else None,
}
)
agent_graph_executions = await db.agentgraphexecution.create_many(
data=graph_execution_data
)
# Need to fetch the created records since create_many doesn't return them
agent_graph_executions = await db.agentgraphexecution.find_many()
# Insert AgentNodeExecutions
print(
f"Inserting {NUM_USERS * MAX_GRAPHS_PER_USER * MAX_EXECUTIONS_PER_GRAPH} agent node executions"
)
node_execution_data = []
for execution in agent_graph_executions:
nodes = [
node for node in agent_nodes if node.agentGraphId == execution.agentGraphId
]
for node in nodes:
node_execution_data.append(
{
"agentGraphExecutionId": execution.id,
"agentNodeId": node.id,
"executionStatus": prisma.enums.AgentExecutionStatus.COMPLETED,
"addedTime": datetime.now(),
}
)
agent_node_executions = await db.agentnodeexecution.create_many(
data=node_execution_data
)
# Need to fetch the created records since create_many doesn't return them
agent_node_executions = await db.agentnodeexecution.find_many()
# Insert AgentNodeExecutionInputOutput
print(
f"Inserting {NUM_USERS * MAX_GRAPHS_PER_USER * MAX_EXECUTIONS_PER_GRAPH} agent node execution input/outputs"
)
input_output_data = []
for node_execution in agent_node_executions:
# Input data
input_output_data.append(
{
"name": "input1",
"data": "{}",
"time": datetime.now(),
"referencedByInputExecId": node_execution.id,
}
)
# Output data
input_output_data.append(
{
"name": "output1",
"data": "{}",
"time": datetime.now(),
"referencedByOutputExecId": node_execution.id,
}
)
await db.agentnodeexecutioninputoutput.create_many(data=input_output_data)
# Insert AgentNodeLinks
print(f"Inserting {NUM_USERS * MAX_GRAPHS_PER_USER} agent node links")
for graph in agent_graphs:
nodes = [node for node in agent_nodes if node.agentGraphId == graph.id]
if len(nodes) >= 2:
source_node = nodes[0]
sink_node = nodes[1]
await db.agentnodelink.create(
data={
"agentNodeSourceId": source_node.id,
"sourceName": "output1",
"agentNodeSinkId": sink_node.id,
"sinkName": "input1",
"isStatic": False,
}
)
# Insert AnalyticsDetails
print(f"Inserting {NUM_USERS} analytics details")
for user in users:
for _ in range(1):
await db.analyticsdetails.create(
data={
"userId": user.id,
"type": faker.word(),
"data": prisma.Json({}),
"dataIndex": faker.word(),
}
)
# Insert AnalyticsMetrics
print(f"Inserting {NUM_USERS} analytics metrics")
for user in users:
for _ in range(1):
await db.analyticsmetrics.create(
data={
"userId": user.id,
"analyticMetric": faker.word(),
"value": random.uniform(0, 100),
"dataString": faker.word(),
}
)
# Insert CreditTransaction (formerly UserBlockCredit)
print(f"Inserting {NUM_USERS} credit transactions")
for user in users:
for _ in range(1):
block = random.choice(agent_blocks)
await db.credittransaction.create(
data={
"transactionKey": str(faker.uuid4()),
"userId": user.id,
"blockId": block.id,
"amount": random.randint(1, 100),
"type": (
prisma.enums.CreditTransactionType.TOP_UP
if random.random() < 0.5
else prisma.enums.CreditTransactionType.USAGE
),
"metadata": prisma.Json({}),
}
)
# Insert Profiles
profiles = []
print(f"Inserting {NUM_USERS} profiles")
for user in users:
profile = await db.profile.create(
data={
"userId": user.id,
"name": user.name or faker.name(),
"username": faker.unique.user_name(),
"description": faker.text(),
"links": [faker.url() for _ in range(3)],
"avatarUrl": get_image(),
}
)
profiles.append(profile)
# Insert StoreListings
store_listings = []
print(f"Inserting {NUM_USERS} store listings")
for graph in agent_graphs:
user = random.choice(users)
listing = await db.storelisting.create(
data={
"agentId": graph.id,
"agentVersion": graph.version,
"owningUserId": user.id,
"isApproved": random.choice([True, False]),
}
)
store_listings.append(listing)
# Insert StoreListingVersions
store_listing_versions = []
print(f"Inserting {NUM_USERS} store listing versions")
for listing in store_listings:
graph = [g for g in agent_graphs if g.id == listing.agentId][0]
version = await db.storelistingversion.create(
data={
"agentId": graph.id,
"agentVersion": graph.version,
"slug": faker.slug(),
"name": graph.name or faker.sentence(nb_words=3),
"subHeading": faker.sentence(),
"videoUrl": faker.url(),
"imageUrls": [get_image() for _ in range(3)],
"description": faker.text(),
"categories": [faker.word() for _ in range(3)],
"isFeatured": random.choice([True, False]),
"isAvailable": True,
"isApproved": random.choice([True, False]),
"storeListingId": listing.id,
}
)
store_listing_versions.append(version)
# Insert StoreListingReviews
print(f"Inserting {NUM_USERS * MAX_REVIEWS_PER_VERSION} store listing reviews")
for version in store_listing_versions:
# Create a copy of users list and shuffle it to avoid duplicates
available_reviewers = users.copy()
random.shuffle(available_reviewers)
# Limit number of reviews to available unique reviewers
num_reviews = min(
random.randint(MIN_REVIEWS_PER_VERSION, MAX_REVIEWS_PER_VERSION),
len(available_reviewers),
)
# Take only the first num_reviews reviewers
for reviewer in available_reviewers[:num_reviews]:
await db.storelistingreview.create(
data={
"storeListingVersionId": version.id,
"reviewByUserId": reviewer.id,
"score": random.randint(1, 5),
"comments": faker.text(),
}
)
# Insert StoreListingSubmissions
print(f"Inserting {NUM_USERS} store listing submissions")
for listing in store_listings:
version = random.choice(store_listing_versions)
reviewer = random.choice(users)
status: prisma.enums.SubmissionStatus = random.choice(
[
prisma.enums.SubmissionStatus.PENDING,
prisma.enums.SubmissionStatus.APPROVED,
prisma.enums.SubmissionStatus.REJECTED,
]
)
await db.storelistingsubmission.create(
data={
"storeListingId": listing.id,
"storeListingVersionId": version.id,
"reviewerId": reviewer.id,
"Status": status,
"reviewComments": faker.text(),
}
)
# Insert APIKeys
print(f"Inserting {NUM_USERS} api keys")
for user in users:
await db.apikey.create(
data={
"name": faker.word(),
"prefix": str(faker.uuid4())[:8],
"postfix": str(faker.uuid4())[-8:],
"key": str(faker.sha256()),
"status": prisma.enums.APIKeyStatus.ACTIVE,
"permissions": [
prisma.enums.APIKeyPermission.EXECUTE_GRAPH,
prisma.enums.APIKeyPermission.READ_GRAPH,
],
"description": faker.text(),
"userId": user.id,
}
)
await db.disconnect()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -4,6 +4,7 @@ from backend.util.request import validate_url
def test_validate_url():
# Rejected IP ranges
with pytest.raises(ValueError):
validate_url("localhost", [])
@@ -16,6 +17,63 @@ def test_validate_url():
with pytest.raises(ValueError):
validate_url("0.0.0.0", [])
validate_url("google.com", [])
validate_url("github.com", [])
validate_url("http://github.com", [])
# Normal URLs
assert validate_url("google.com/a?b=c", []) == "http://google.com/a?b=c"
assert validate_url("github.com?key=!@!@", []) == "http://github.com?key=!@!@"
# Scheme Enforcement
with pytest.raises(ValueError):
validate_url("ftp://example.com", [])
with pytest.raises(ValueError):
validate_url("file://example.com", [])
# International domain that converts to punycode - should be allowed if public
assert validate_url("http://xn--exmple-cua.com", []) == "http://xn--exmple-cua.com"
# If the domain fails IDNA encoding or is invalid, it should raise an error
with pytest.raises(ValueError):
validate_url("http://exa◌mple.com", [])
# IPv6 Addresses
with pytest.raises(ValueError):
validate_url("::1", []) # IPv6 loopback should be blocked
with pytest.raises(ValueError):
validate_url("http://[::1]", []) # IPv6 loopback in URL form
# Suspicious Characters in Hostname
with pytest.raises(ValueError):
validate_url("http://example_underscore.com", [])
with pytest.raises(ValueError):
validate_url("http://exa mple.com", []) # Space in hostname
# Malformed URLs
with pytest.raises(ValueError):
validate_url("http://", []) # No hostname
with pytest.raises(ValueError):
validate_url("://missing-scheme", []) # Missing proper scheme
# Trusted Origins
trusted = ["internal-api.company.com", "10.0.0.5"]
assert (
validate_url("internal-api.company.com", trusted)
== "http://internal-api.company.com"
)
assert validate_url("10.0.0.5", ["10.0.0.5"]) == "http://10.0.0.5"
# Special Characters in Path or Query
assert (
validate_url("example.com/path%20with%20spaces", [])
== "http://example.com/path%20with%20spaces"
)
# Backslashes should be replaced with forward slashes
assert (
validate_url("http://example.com\\backslash", [])
== "http://example.com/backslash"
)
# Check defaulting scheme behavior for valid domains
assert validate_url("example.com", []) == "http://example.com"
assert validate_url("https://secure.com", []) == "https://secure.com"
# Non-ASCII Characters in Query/Fragment
assert validate_url("example.com?param=äöü", []) == "http://example.com?param=äöü"

Some files were not shown because too many files have changed in this diff Show More