Mirror of https://github.com/Significant-Gravitas/AutoGPT.git, synced 2026-04-08 03:00:28 -04:00.

Merge branch 'dev' into aarushikansal/remove-title-input-blocks
.github/workflows/platform-backend-ci.yml (vendored, 11 changes)

@@ -79,6 +79,17 @@ jobs:
          echo "$HOME/.local/bin" >> $GITHUB_PATH
        fi

      - name: Check poetry.lock
        run: |
          poetry lock --no-update

          if ! git diff --quiet poetry.lock; then
            echo "Error: poetry.lock not up to date."
            echo
            git diff poetry.lock
            exit 1
          fi

      - name: Install Python dependencies
        run: poetry install
.github/workflows/platform-frontend-ci.yml (vendored, 35 changes)

@@ -23,6 +23,7 @@ jobs:

    steps:
      - uses: actions/checkout@v4

      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:

@@ -38,24 +39,12 @@ jobs:

  test:
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        browser: [chromium, webkit]

    steps:
      - name: Free Disk Space (Ubuntu)
        uses: jlumbroso/free-disk-space@main
        with:
          # this might remove tools that are actually needed,
          # if set to "true" but frees about 6 GB
          tool-cache: false

          # all of these default to true, but feel free to set to
          # "false" if necessary for your workflow
          android: false
          dotnet: false
          haskell: false
          large-packages: true
          docker-images: true
          swap-storage: true

      - name: Checkout repository
        uses: actions/checkout@v4
        with:

@@ -66,6 +55,12 @@ jobs:
        with:
          node-version: "21"

      - name: Free Disk Space (Ubuntu)
        uses: jlumbroso/free-disk-space@main
        with:
          large-packages: false # slow
          docker-images: false # limited benefit

      - name: Copy default supabase .env
        run: |
          cp ../supabase/docker/.env.example ../.env

@@ -86,16 +81,16 @@ jobs:
        run: |
          cp .env.example .env

      - name: Install Playwright Browsers
        run: yarn playwright install --with-deps
      - name: Install Browser '${{ matrix.browser }}'
        run: yarn playwright install --with-deps ${{ matrix.browser }}

      - name: Run tests
        run: |
          yarn test
          yarn test --project=${{ matrix.browser }}

      - uses: actions/upload-artifact@v4
        if: ${{ !cancelled() }}
        with:
          name: playwright-report
          name: playwright-report-${{ matrix.browser }}
          path: playwright-report/
          retention-days: 30
.gitignore (vendored, 3 changes)

@@ -173,3 +173,6 @@ LICENSE.rtf
autogpt_platform/backend/settings.py
/.auth
/autogpt_platform/frontend/.auth

*.ign.*
.test-contents
@@ -35,3 +35,12 @@ def verify_user(payload: dict | None, admin_only: bool) -> User:
        raise fastapi.HTTPException(status_code=403, detail="Admin access required")

    return User.from_payload(payload)


def get_user_id(payload: dict = fastapi.Depends(auth_middleware)) -> str:
    user_id = payload.get("sub")
    if not user_id:
        raise fastapi.HTTPException(
            status_code=401, detail="User ID not found in token"
        )
    return user_id
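For orientation, the new get_user_id function is a standard FastAPI dependency; a minimal hedged sketch of how a route would consume it (the app and route path are illustrative, not part of this diff):

import fastapi

app = fastapi.FastAPI()

# Hypothetical route: FastAPI resolves get_user_id first, so the handler
# either receives the token's "sub" claim or the request fails with 401.
@app.get("/me")
def read_own_profile(user_id: str = fastapi.Depends(get_user_id)):
    return {"user_id": user_id}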
autogpt_platform/autogpt_libs/poetry.lock (generated, 76 changes)

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.

[[package]]
name = "aiohappyeyeballs"

@@ -1091,13 +1091,13 @@ pyasn1 = ">=0.4.6,<0.7.0"

[[package]]
name = "pydantic"
version = "2.10.2"
version = "2.10.3"
description = "Data validation using Python type hints"
optional = false
python-versions = ">=3.8"
files = [
    {file = "pydantic-2.10.2-py3-none-any.whl", hash = "sha256:cfb96e45951117c3024e6b67b25cdc33a3cb7b2fa62e239f7af1378358a1d99e"},
    {file = "pydantic-2.10.2.tar.gz", hash = "sha256:2bc2d7f17232e0841cbba4641e65ba1eb6fafb3a08de3a091ff3ce14a197c4fa"},
    {file = "pydantic-2.10.3-py3-none-any.whl", hash = "sha256:be04d85bbc7b65651c5f8e6b9976ed9c6f41782a55524cef079a34a0bb82144d"},
    {file = "pydantic-2.10.3.tar.gz", hash = "sha256:cb5ac360ce894ceacd69c403187900a02c4b20b693a9dd1d643e1effab9eadf9"},
]

[package.dependencies]

@@ -1223,13 +1223,13 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"

[[package]]
name = "pydantic-settings"
version = "2.6.1"
version = "2.7.0"
description = "Settings management using Pydantic"
optional = false
python-versions = ">=3.8"
files = [
    {file = "pydantic_settings-2.6.1-py3-none-any.whl", hash = "sha256:7fb0637c786a558d3103436278a7c4f1cfd29ba8973238a50c5bb9a55387da87"},
    {file = "pydantic_settings-2.6.1.tar.gz", hash = "sha256:e0f92546d8a9923cb8941689abf85d6601a8c19a23e97a34b2964a2e3f813ca0"},
    {file = "pydantic_settings-2.7.0-py3-none-any.whl", hash = "sha256:e00c05d5fa6cbbb227c84bd7487c5c1065084119b750df7c8c1a554aed236eb5"},
    {file = "pydantic_settings-2.7.0.tar.gz", hash = "sha256:ac4bfd4a36831a48dbf8b2d9325425b549a0a6f18cea118436d728eb4f1c4d66"},
]

[package.dependencies]

@@ -1243,13 +1243,13 @@ yaml = ["pyyaml (>=6.0.1)"]

[[package]]
name = "pyjwt"
version = "2.10.0"
version = "2.10.1"
description = "JSON Web Token implementation in Python"
optional = false
python-versions = ">=3.9"
files = [
    {file = "PyJWT-2.10.0-py3-none-any.whl", hash = "sha256:543b77207db656de204372350926bed5a86201c4cbff159f623f79c7bb487a15"},
    {file = "pyjwt-2.10.0.tar.gz", hash = "sha256:7628a7eb7938959ac1b26e819a1df0fd3259505627b575e4bad6d08f76db695c"},
    {file = "PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb"},
    {file = "pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953"},
]

[package.extras]

@@ -1282,20 +1282,20 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments

[[package]]
name = "pytest-asyncio"
version = "0.24.0"
version = "0.25.0"
description = "Pytest support for asyncio"
optional = false
python-versions = ">=3.8"
python-versions = ">=3.9"
files = [
    {file = "pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b"},
    {file = "pytest_asyncio-0.24.0.tar.gz", hash = "sha256:d081d828e576d85f875399194281e92bf8a68d60d72d1a2faf2feddb6c46b276"},
    {file = "pytest_asyncio-0.25.0-py3-none-any.whl", hash = "sha256:db5432d18eac6b7e28b46dcd9b69921b55c3b1086e85febfe04e70b18d9e81b3"},
    {file = "pytest_asyncio-0.25.0.tar.gz", hash = "sha256:8c0610303c9e0442a5db8604505fc0f545456ba1528824842b37b4a626cbf609"},
]

[package.dependencies]
pytest = ">=8.2,<9"

[package.extras]
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"]
testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]

[[package]]

@@ -1362,13 +1362,13 @@ websockets = ">=11,<13"

[[package]]
name = "redis"
version = "5.2.0"
version = "5.2.1"
description = "Python client for Redis database and key-value store"
optional = false
python-versions = ">=3.8"
files = [
    {file = "redis-5.2.0-py3-none-any.whl", hash = "sha256:ae174f2bb3b1bf2b09d54bf3e51fbc1469cf6c10aa03e21141f51969801a7897"},
    {file = "redis-5.2.0.tar.gz", hash = "sha256:0b1087665a771b1ff2e003aa5bdd354f15a70c9e25d5a7dbf9c722c16528a7b0"},
    {file = "redis-5.2.1-py3-none-any.whl", hash = "sha256:ee7e1056b9aea0f04c6c2ed59452947f34c4940ee025f5dd83e6a6418b6989e4"},
    {file = "redis-5.2.1.tar.gz", hash = "sha256:16f2e22dff21d5125e8481515e386711a34cbec50f0e44413dd7d9c060a54e0f"},
]

[package.dependencies]

@@ -1415,29 +1415,29 @@ pyasn1 = ">=0.1.3"

[[package]]
name = "ruff"
version = "0.8.1"
version = "0.8.3"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
files = [
    {file = "ruff-0.8.1-py3-none-linux_armv6l.whl", hash = "sha256:fae0805bd514066f20309f6742f6ee7904a773eb9e6c17c45d6b1600ca65c9b5"},
    {file = "ruff-0.8.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b8a4f7385c2285c30f34b200ca5511fcc865f17578383db154e098150ce0a087"},
    {file = "ruff-0.8.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:cd054486da0c53e41e0086e1730eb77d1f698154f910e0cd9e0d64274979a209"},
    {file = "ruff-0.8.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2029b8c22da147c50ae577e621a5bfbc5d1fed75d86af53643d7a7aee1d23871"},
    {file = "ruff-0.8.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2666520828dee7dfc7e47ee4ea0d928f40de72056d929a7c5292d95071d881d1"},
    {file = "ruff-0.8.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:333c57013ef8c97a53892aa56042831c372e0bb1785ab7026187b7abd0135ad5"},
    {file = "ruff-0.8.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:288326162804f34088ac007139488dcb43de590a5ccfec3166396530b58fb89d"},
    {file = "ruff-0.8.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b12c39b9448632284561cbf4191aa1b005882acbc81900ffa9f9f471c8ff7e26"},
    {file = "ruff-0.8.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:364e6674450cbac8e998f7b30639040c99d81dfb5bbc6dfad69bc7a8f916b3d1"},
    {file = "ruff-0.8.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b22346f845fec132aa39cd29acb94451d030c10874408dbf776af3aaeb53284c"},
    {file = "ruff-0.8.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:b2f2f7a7e7648a2bfe6ead4e0a16745db956da0e3a231ad443d2a66a105c04fa"},
    {file = "ruff-0.8.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:adf314fc458374c25c5c4a4a9270c3e8a6a807b1bec018cfa2813d6546215540"},
    {file = "ruff-0.8.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a885d68342a231b5ba4d30b8c6e1b1ee3a65cf37e3d29b3c74069cdf1ee1e3c9"},
    {file = "ruff-0.8.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:d2c16e3508c8cc73e96aa5127d0df8913d2290098f776416a4b157657bee44c5"},
    {file = "ruff-0.8.1-py3-none-win32.whl", hash = "sha256:93335cd7c0eaedb44882d75a7acb7df4b77cd7cd0d2255c93b28791716e81790"},
    {file = "ruff-0.8.1-py3-none-win_amd64.whl", hash = "sha256:2954cdbe8dfd8ab359d4a30cd971b589d335a44d444b6ca2cb3d1da21b75e4b6"},
    {file = "ruff-0.8.1-py3-none-win_arm64.whl", hash = "sha256:55873cc1a473e5ac129d15eccb3c008c096b94809d693fc7053f588b67822737"},
    {file = "ruff-0.8.1.tar.gz", hash = "sha256:3583db9a6450364ed5ca3f3b4225958b24f78178908d5c4bc0f46251ccca898f"},
    {file = "ruff-0.8.3-py3-none-linux_armv6l.whl", hash = "sha256:8d5d273ffffff0acd3db5bf626d4b131aa5a5ada1276126231c4174543ce20d6"},
    {file = "ruff-0.8.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:e4d66a21de39f15c9757d00c50c8cdd20ac84f55684ca56def7891a025d7e939"},
    {file = "ruff-0.8.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:c356e770811858bd20832af696ff6c7e884701115094f427b64b25093d6d932d"},
    {file = "ruff-0.8.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c0a60a825e3e177116c84009d5ebaa90cf40dfab56e1358d1df4e29a9a14b13"},
    {file = "ruff-0.8.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:75fb782f4db39501210ac093c79c3de581d306624575eddd7e4e13747e61ba18"},
    {file = "ruff-0.8.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7f26bc76a133ecb09a38b7868737eded6941b70a6d34ef53a4027e83913b6502"},
    {file = "ruff-0.8.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:01b14b2f72a37390c1b13477c1c02d53184f728be2f3ffc3ace5b44e9e87b90d"},
    {file = "ruff-0.8.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:53babd6e63e31f4e96ec95ea0d962298f9f0d9cc5990a1bbb023a6baf2503a82"},
    {file = "ruff-0.8.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1ae441ce4cf925b7f363d33cd6570c51435972d697e3e58928973994e56e1452"},
    {file = "ruff-0.8.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7c65bc0cadce32255e93c57d57ecc2cca23149edd52714c0c5d6fa11ec328cd"},
    {file = "ruff-0.8.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5be450bb18f23f0edc5a4e5585c17a56ba88920d598f04a06bd9fd76d324cb20"},
    {file = "ruff-0.8.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8faeae3827eaa77f5721f09b9472a18c749139c891dbc17f45e72d8f2ca1f8fc"},
    {file = "ruff-0.8.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:db503486e1cf074b9808403991663e4277f5c664d3fe237ee0d994d1305bb060"},
    {file = "ruff-0.8.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:6567be9fb62fbd7a099209257fef4ad2c3153b60579818b31a23c886ed4147ea"},
    {file = "ruff-0.8.3-py3-none-win32.whl", hash = "sha256:19048f2f878f3ee4583fc6cb23fb636e48c2635e30fb2022b3a1cd293402f964"},
    {file = "ruff-0.8.3-py3-none-win_amd64.whl", hash = "sha256:f7df94f57d7418fa7c3ffb650757e0c2b96cf2501a0b192c18e4fb5571dfada9"},
    {file = "ruff-0.8.3-py3-none-win_arm64.whl", hash = "sha256:fe2756edf68ea79707c8d68b78ca9a58ed9af22e430430491ee03e718b5e4936"},
    {file = "ruff-0.8.3.tar.gz", hash = "sha256:5e7558304353b84279042fc584a4f4cb8a07ae79b2bf3da1a7551d960b5626d3"},
]

[[package]]

@@ -1852,4 +1852,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<4.0"
content-hash = "0aef772b321a1a00163abdc6a88efb21874b84bf25c68e84992c3ae3af026a17"
content-hash = "13a36d3be675cab4a3eb2e6a62a1b08df779bded4c7b9164d8be300dc08748d0"
@@ -10,18 +10,18 @@ packages = [{ include = "autogpt_libs" }]
colorama = "^0.4.6"
expiringdict = "^1.2.2"
google-cloud-logging = "^3.11.3"
pydantic = "^2.10.2"
pydantic-settings = "^2.6.1"
pyjwt = "^2.10.0"
pytest-asyncio = "^0.24.0"
pydantic = "^2.10.3"
pydantic-settings = "^2.7.0"
pyjwt = "^2.10.1"
pytest-asyncio = "^0.25.0"
pytest-mock = "^3.14.0"
python = ">=3.10,<4.0"
python-dotenv = "^1.0.1"
supabase = "^2.10.0"

[tool.poetry.group.dev.dependencies]
redis = "^5.2.0"
ruff = "^0.8.1"
redis = "^5.2.1"
ruff = "^0.8.3"

[build-system]
requires = ["poetry-core"]
@@ -6,18 +6,23 @@ ENV PYTHONUNBUFFERED 1

WORKDIR /app

RUN echo 'Acquire::http::Pipeline-Depth 0;\nAcquire::http::No-Cache true;\nAcquire::BrokenProxy true;\n' > /etc/apt/apt.conf.d/99fixbadproxy

RUN apt-get update --allow-releaseinfo-change --fix-missing

# Install build dependencies
RUN apt-get update \
    && apt-get install -y build-essential curl ffmpeg wget libcurl4-gnutls-dev libexpat1-dev libpq5 gettext libz-dev libssl-dev postgresql-client git \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

ENV POETRY_VERSION=1.8.3 \
    POETRY_HOME="/opt/poetry" \
    POETRY_NO_INTERACTION=1 \
    POETRY_VIRTUALENVS_CREATE=false \
    PATH="$POETRY_HOME/bin:$PATH"
RUN apt-get install -y build-essential
RUN apt-get install -y libpq5
RUN apt-get install -y libz-dev
RUN apt-get install -y libssl-dev
RUN apt-get install -y postgresql-client

ENV POETRY_VERSION=1.8.3
ENV POETRY_HOME=/opt/poetry
ENV POETRY_NO_INTERACTION=1
ENV POETRY_VIRTUALENVS_CREATE=false
ENV PATH=/opt/poetry/bin:$PATH

# Upgrade pip and setuptools to fix security vulnerabilities
RUN pip3 install --upgrade pip setuptools

@@ -39,11 +44,11 @@ FROM python:3.11.10-slim-bookworm AS server_dependencies

WORKDIR /app

ENV POETRY_VERSION=1.8.3 \
    POETRY_HOME="/opt/poetry" \
    POETRY_NO_INTERACTION=1 \
    POETRY_VIRTUALENVS_CREATE=false \
    PATH="$POETRY_HOME/bin:$PATH"
ENV POETRY_VERSION=1.8.3
ENV POETRY_HOME=/opt/poetry
ENV POETRY_NO_INTERACTION=1
ENV POETRY_VIRTUALENVS_CREATE=false
ENV PATH=/opt/poetry/bin:$PATH


# Upgrade pip and setuptools to fix security vulnerabilities
@@ -200,4 +200,4 @@ To add a new agent block, you need to create a new class that inherits from `Block`
* `run` method: the main logic of the block.
* `test_input` & `test_output`: the sample input and output data for the block, which will be used to auto-test the block.
* You can mock the functions declared in the block using the `test_mock` field for your unit tests.
* Once you finish creating the block, you can test it by running `pytest -s test/block/test_block.py`.
* Once you finish creating the block, you can test it by running `poetry run pytest -s test/block/test_block.py`.
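To make the block recipe above concrete, here is a minimal hedged sketch of a new block (the class name, id, and fields are illustrative placeholders, not part of this commit):

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class EchoBlock(Block):
    class Input(BlockSchema):
        text: str = SchemaField(description="Text to echo back")

    class Output(BlockSchema):
        output: str = SchemaField(description="The echoed text")

    def __init__(self):
        super().__init__(
            id="00000000-0000-0000-0000-000000000000",  # placeholder; real blocks use a fresh UUID
            description="Echoes its input; a minimal example of the block recipe.",
            categories={BlockCategory.TEXT},
            input_schema=EchoBlock.Input,
            output_schema=EchoBlock.Output,
            test_input={"text": "hi"},
            test_output=[("output", "hi")],
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        yield "output", input_data.text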
@@ -12,6 +12,7 @@ from backend.data.model import (
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName


class ImageSize(str, Enum):

@@ -101,12 +102,10 @@ class ImageGenModel(str, Enum):

class AIImageGeneratorBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[Literal["replicate"], Literal["api_key"]] = (
            CredentialsField(
                provider="replicate",
                supported_credential_types={"api_key"},
                description="Enter your Replicate API key to access the image generation API. You can obtain an API key from https://replicate.com/account/api-tokens.",
            )
        credentials: CredentialsMetaInput[
            Literal[ProviderName.REPLICATE], Literal["api_key"]
        ] = CredentialsField(
            description="Enter your Replicate API key to access the image generation API. You can obtain an API key from https://replicate.com/account/api-tokens.",
        )
        prompt: str = SchemaField(
            description="Text prompt for image generation",
@@ -13,6 +13,7 @@ from backend.data.model import (
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName

logger = logging.getLogger(__name__)

@@ -54,13 +55,11 @@ class NormalizationStrategy(str, Enum):

class AIMusicGeneratorBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[Literal["replicate"], Literal["api_key"]] = (
            CredentialsField(
                provider="replicate",
                supported_credential_types={"api_key"},
                description="The Replicate integration can be used with "
                "any API key with sufficient permissions for the blocks it is used on.",
            )
        credentials: CredentialsMetaInput[
            Literal[ProviderName.REPLICATE], Literal["api_key"]
        ] = CredentialsField(
            description="The Replicate integration can be used with "
            "any API key with sufficient permissions for the blocks it is used on.",
        )
        prompt: str = SchemaField(
            description="A description of the music you want to generate",
@@ -12,6 +12,7 @@ from backend.data.model import (
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests

TEST_CREDENTIALS = APIKeyCredentials(

@@ -140,13 +141,11 @@ logger = logging.getLogger(__name__)

class AIShortformVideoCreatorBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[Literal["revid"], Literal["api_key"]] = (
            CredentialsField(
                provider="revid",
                supported_credential_types={"api_key"},
                description="The revid.ai integration can be used with "
                "any API key with sufficient permissions for the blocks it is used on.",
            )
        credentials: CredentialsMetaInput[
            Literal[ProviderName.REVID], Literal["api_key"]
        ] = CredentialsField(
            description="The revid.ai integration can be used with "
            "any API key with sufficient permissions for the blocks it is used on.",
        )
        script: str = SchemaField(
            description="""1. Use short and punctuated sentences\n\n2. Use linebreaks to create a new clip\n\n3. Text outside of brackets is spoken by the AI, and [text between brackets] will be used to guide the visual generation. For example, [close-up of a cat] will show a close-up of a cat.""",
@@ -1,13 +1,11 @@
import re
from typing import Any, List

from jinja2 import BaseLoader, Environment

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
from backend.data.model import SchemaField
from backend.util.mock import MockObject
from backend.util.text import TextFormatter

jinja = Environment(loader=BaseLoader())
formatter = TextFormatter()


class StoreValueBlock(Block):

@@ -296,9 +294,9 @@ class AgentOutputBlock(Block):
        """
        if input_data.format:
            try:
                fmt = re.sub(r"(?<!{){[ a-zA-Z0-9_]+}", r"{\g<0>}", input_data.format)
                template = jinja.from_string(fmt)
                yield "output", template.render({input_data.name: input_data.value})
                yield "output", formatter.format_string(
                    input_data.format, {input_data.name: input_data.value}
                )
            except Exception as e:
                yield "output", f"Error: {e}, {input_data.value}"
        else:
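The removed jinja path first escapes single-brace placeholders so Jinja will treat them as variables; a self-contained sketch of that old behavior, using only the public jinja2 API (sample strings are illustrative):

import re

from jinja2 import BaseLoader, Environment

jinja = Environment(loader=BaseLoader())

# "{name}" is rewritten to "{{name}}", which Jinja renders as a variable
fmt = re.sub(r"(?<!{){[ a-zA-Z0-9_]+}", r"{\g<0>}", "Hello {name}")
assert fmt == "Hello {{name}}"
print(jinja.from_string(fmt).render({"name": "World"}))  # -> Hello World

The replacement delegates the same placeholder substitution to the repo's TextFormatter utility instead.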
@@ -486,3 +484,101 @@ class NoteBlock(Block):

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        yield "output", input_data.text


class CreateDictionaryBlock(Block):
    class Input(BlockSchema):
        values: dict[str, Any] = SchemaField(
            description="Key-value pairs to create the dictionary with",
            placeholder="e.g., {'name': 'Alice', 'age': 25}",
        )

    class Output(BlockSchema):
        dictionary: dict[str, Any] = SchemaField(
            description="The created dictionary containing the specified key-value pairs"
        )
        error: str = SchemaField(
            description="Error message if dictionary creation failed"
        )

    def __init__(self):
        super().__init__(
            id="b924ddf4-de4f-4b56-9a85-358930dcbc91",
            description="Creates a dictionary with the specified key-value pairs. Use this when you know all the values you want to add upfront.",
            categories={BlockCategory.DATA},
            input_schema=CreateDictionaryBlock.Input,
            output_schema=CreateDictionaryBlock.Output,
            test_input=[
                {
                    "values": {"name": "Alice", "age": 25, "city": "New York"},
                },
                {
                    "values": {"numbers": [1, 2, 3], "active": True, "score": 95.5},
                },
            ],
            test_output=[
                (
                    "dictionary",
                    {"name": "Alice", "age": 25, "city": "New York"},
                ),
                (
                    "dictionary",
                    {"numbers": [1, 2, 3], "active": True, "score": 95.5},
                ),
            ],
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        try:
            # The values are already validated by Pydantic schema
            yield "dictionary", input_data.values
        except Exception as e:
            yield "error", f"Failed to create dictionary: {str(e)}"


class CreateListBlock(Block):
    class Input(BlockSchema):
        values: List[Any] = SchemaField(
            description="A list of values to be combined into a new list.",
            placeholder="e.g., ['Alice', 25, True]",
        )

    class Output(BlockSchema):
        list: List[Any] = SchemaField(
            description="The created list containing the specified values."
        )
        error: str = SchemaField(description="Error message if list creation failed.")

    def __init__(self):
        super().__init__(
            id="a912d5c7-6e00-4542-b2a9-8034136930e4",
            description="Creates a list with the specified values. Use this when you know all the values you want to add upfront.",
            categories={BlockCategory.DATA},
            input_schema=CreateListBlock.Input,
            output_schema=CreateListBlock.Output,
            test_input=[
                {
                    "values": ["Alice", 25, True],
                },
                {
                    "values": [1, 2, 3, "four", {"key": "value"}],
                },
            ],
            test_output=[
                (
                    "list",
                    ["Alice", 25, True],
                ),
                (
                    "list",
                    [1, 2, 3, "four", {"key": "value"}],
                ),
            ],
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        try:
            # The values are already validated by Pydantic schema
            yield "list", input_data.values
        except Exception as e:
            yield "error", f"Failed to create list: {str(e)}"
autogpt_platform/backend/backend/blocks/code_executor.py (new file, 190 lines)

@@ -0,0 +1,190 @@
from enum import Enum
from typing import Literal

from e2b_code_interpreter import Sandbox
from pydantic import SecretStr

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName

TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="e2b",
    api_key=SecretStr("mock-e2b-api-key"),
    title="Mock E2B API key",
    expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.type,
}


class ProgrammingLanguage(Enum):
    PYTHON = "python"
    JAVASCRIPT = "js"
    BASH = "bash"
    R = "r"
    JAVA = "java"


class CodeExecutionBlock(Block):
    # TODO: Add support to upload and download files
    # Currently, you can customize the CPU and memory only by creating a pre-customized sandbox template
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[
            Literal[ProviderName.E2B], Literal["api_key"]
        ] = CredentialsField(
            description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
        )

        # TODO: Option to run command in background
        setup_commands: list[str] = SchemaField(
            description=(
                "Shell commands to set up the sandbox before running the code. "
                "You can use `curl` or `git` to install your desired Debian based "
                "package manager. `pip` and `npm` are pre-installed.\n\n"
                "These commands are executed with `sh`, in the foreground."
            ),
            placeholder="pip install cowsay",
            default=[],
            advanced=False,
        )

        code: str = SchemaField(
            description="Code to execute in the sandbox",
            placeholder="print('Hello, World!')",
            default="",
            advanced=False,
        )

        language: ProgrammingLanguage = SchemaField(
            description="Programming language to execute",
            default=ProgrammingLanguage.PYTHON,
            advanced=False,
        )

        timeout: int = SchemaField(
            description="Execution timeout in seconds", default=300
        )

        template_id: str = SchemaField(
            description=(
                "You can use an E2B sandbox template by entering its ID here. "
                "Check out the E2B docs for more details: "
                "[E2B - Sandbox template](https://e2b.dev/docs/sandbox-template)"
            ),
            default="",
            advanced=True,
        )

    class Output(BlockSchema):
        response: str = SchemaField(description="Response from code execution")
        stdout_logs: str = SchemaField(
            description="Standard output logs from execution"
        )
        stderr_logs: str = SchemaField(description="Standard error logs from execution")
        error: str = SchemaField(description="Error message if execution failed")

    def __init__(self):
        super().__init__(
            id="0b02b072-abe7-11ef-8372-fb5d162dd712",
            description="Executes code in an isolated sandbox environment with internet access.",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=CodeExecutionBlock.Input,
            output_schema=CodeExecutionBlock.Output,
            test_credentials=TEST_CREDENTIALS,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
                "code": "print('Hello World')",
                "language": ProgrammingLanguage.PYTHON.value,
                "setup_commands": [],
                "timeout": 300,
                "template_id": "",
            },
            test_output=[
                ("response", "Hello World"),
                ("stdout_logs", "Hello World\n"),
            ],
            test_mock={
                "execute_code": lambda code, language, setup_commands, timeout, api_key, template_id: (
                    "Hello World",
                    "Hello World\n",
                    "",
                ),
            },
        )

    def execute_code(
        self,
        code: str,
        language: ProgrammingLanguage,
        setup_commands: list[str],
        timeout: int,
        api_key: str,
        template_id: str,
    ):
        try:
            sandbox = None
            if template_id:
                sandbox = Sandbox(
                    template=template_id, api_key=api_key, timeout=timeout
                )
            else:
                sandbox = Sandbox(api_key=api_key, timeout=timeout)

            if not sandbox:
                raise Exception("Sandbox not created")

            # Running setup commands
            for cmd in setup_commands:
                sandbox.commands.run(cmd)

            # Executing the code
            execution = sandbox.run_code(
                code,
                language=language.value,
                on_error=lambda e: sandbox.kill(),  # Kill the sandbox if there is an error
            )

            if execution.error:
                raise Exception(execution.error)

            response = execution.text
            stdout_logs = "".join(execution.logs.stdout)
            stderr_logs = "".join(execution.logs.stderr)

            return response, stdout_logs, stderr_logs

        except Exception as e:
            raise e

    def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        try:
            response, stdout_logs, stderr_logs = self.execute_code(
                input_data.code,
                input_data.language,
                input_data.setup_commands,
                input_data.timeout,
                credentials.api_key.get_secret_value(),
                input_data.template_id,
            )

            if response:
                yield "response", response
            if stdout_logs:
                yield "stdout_logs", stdout_logs
            if stderr_logs:
                yield "stderr_logs", stderr_logs
        except Exception as e:
            yield "error", str(e)
autogpt_platform/backend/backend/blocks/code_extraction_block.py (new file, 110 lines)

@@ -0,0 +1,110 @@
import re

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class CodeExtractionBlock(Block):
    class Input(BlockSchema):
        text: str = SchemaField(
            description="Text containing code blocks to extract (e.g., AI response)",
            placeholder="Enter text containing code blocks",
        )

    class Output(BlockSchema):
        html: str = SchemaField(description="Extracted HTML code")
        css: str = SchemaField(description="Extracted CSS code")
        javascript: str = SchemaField(description="Extracted JavaScript code")
        python: str = SchemaField(description="Extracted Python code")
        sql: str = SchemaField(description="Extracted SQL code")
        java: str = SchemaField(description="Extracted Java code")
        cpp: str = SchemaField(description="Extracted C++ code")
        csharp: str = SchemaField(description="Extracted C# code")
        json_code: str = SchemaField(description="Extracted JSON code")
        bash: str = SchemaField(description="Extracted Bash code")
        php: str = SchemaField(description="Extracted PHP code")
        ruby: str = SchemaField(description="Extracted Ruby code")
        yaml: str = SchemaField(description="Extracted YAML code")
        markdown: str = SchemaField(description="Extracted Markdown code")
        typescript: str = SchemaField(description="Extracted TypeScript code")
        xml: str = SchemaField(description="Extracted XML code")
        remaining_text: str = SchemaField(
            description="Remaining text after code extraction"
        )

    def __init__(self):
        super().__init__(
            id="d3a7d896-3b78-4f44-8b4b-48fbf4f0bcd8",
            description="Extracts code blocks from text and identifies their programming languages",
            categories={BlockCategory.TEXT},
            input_schema=CodeExtractionBlock.Input,
            output_schema=CodeExtractionBlock.Output,
            test_input={
                "text": "Here's a Python example:\n```python\nprint('Hello World')\n```\nAnd some HTML:\n```html\n<h1>Title</h1>\n```"
            },
            test_output=[
                ("html", "<h1>Title</h1>"),
                ("python", "print('Hello World')"),
                ("remaining_text", "Here's a Python example:\nAnd some HTML:"),
            ],
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # List of supported programming languages with mapped aliases
        language_aliases = {
            "html": ["html", "htm"],
            "css": ["css"],
            "javascript": ["javascript", "js"],
            "python": ["python", "py"],
            "sql": ["sql"],
            "java": ["java"],
            "cpp": ["cpp", "c++"],
            "csharp": ["csharp", "c#", "cs"],
            "json_code": ["json"],
            "bash": ["bash", "shell", "sh"],
            "php": ["php"],
            "ruby": ["ruby", "rb"],
            "yaml": ["yaml", "yml"],
            "markdown": ["markdown", "md"],
            "typescript": ["typescript", "ts"],
            "xml": ["xml"],
        }

        # Extract code for each language
        for canonical_name, aliases in language_aliases.items():
            code = ""
            # Try each alias for the language
            for alias in aliases:
                code_for_alias = self.extract_code(input_data.text, alias)
                if code_for_alias:
                    code = code + "\n\n" + code_for_alias if code else code_for_alias

            if code:  # Only yield if there's actual code content
                yield canonical_name, code

        # Remove all code blocks from the text to get remaining text
        pattern = (
            r"```(?:"
            + "|".join(
                re.escape(alias)
                for aliases in language_aliases.values()
                for alias in aliases
            )
            + r")\s+[\s\S]*?```"
        )

        remaining_text = re.sub(pattern, "", input_data.text).strip()
        remaining_text = re.sub(r"\n\s*\n", "\n", remaining_text)

        if remaining_text:  # Only yield if there's remaining text
            yield "remaining_text", remaining_text

    def extract_code(self, text: str, language: str) -> str:
        # Escape special regex characters in the language string
        language = re.escape(language)
        # Extract all code blocks enclosed in ```language``` blocks
        pattern = re.compile(rf"```{language}\s+(.*?)```", re.DOTALL | re.IGNORECASE)
        matches = pattern.finditer(text)
        # Combine all code blocks for this language with newlines between them
        code_blocks = [match.group(1).strip() for match in matches]
        return "\n\n".join(code_blocks) if code_blocks else ""
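As a quick sanity check on the extraction regex, here is a standalone sketch that mirrors extract_code above (the sample text is illustrative):

import re


def extract_code(text: str, language: str) -> str:
    # Same pattern as CodeExtractionBlock.extract_code
    pattern = re.compile(
        rf"```{re.escape(language)}\s+(.*?)```", re.DOTALL | re.IGNORECASE
    )
    return "\n\n".join(m.group(1).strip() for m in pattern.finditer(text))


sample = "Intro\n```python\nprint('hi')\n```\nOutro"
print(extract_code(sample, "python"))  # -> print('hi')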
autogpt_platform/backend/backend/blocks/compass/triggers.py (new file, 59 lines)

@@ -0,0 +1,59 @@
from pydantic import BaseModel

from backend.data.block import (
    Block,
    BlockCategory,
    BlockManualWebhookConfig,
    BlockOutput,
    BlockSchema,
)
from backend.data.model import SchemaField
from backend.integrations.webhooks.compass import CompassWebhookType


class Transcription(BaseModel):
    text: str
    speaker: str
    end: float
    start: float
    duration: float


class TranscriptionDataModel(BaseModel):
    date: str
    transcription: str
    transcriptions: list[Transcription]


class CompassAITriggerBlock(Block):
    class Input(BlockSchema):
        payload: TranscriptionDataModel = SchemaField(hidden=True)

    class Output(BlockSchema):
        transcription: str = SchemaField(
            description="The contents of the compass transcription."
        )

    def __init__(self):
        super().__init__(
            id="9464a020-ed1d-49e1-990f-7f2ac924a2b7",
            description="This block will output the contents of the compass transcription.",
            categories={BlockCategory.HARDWARE},
            input_schema=CompassAITriggerBlock.Input,
            output_schema=CompassAITriggerBlock.Output,
            webhook_config=BlockManualWebhookConfig(
                provider="compass",
                webhook_type=CompassWebhookType.TRANSCRIPTION,
            ),
            test_input=[
                {"input": "Hello, World!"},
                {"input": "Hello, World!", "data": "Existing Data"},
            ],
            # test_output=[
            #     ("output", "Hello, World!"),  # No data provided, so trigger is returned
            #     ("output", "Existing Data"),  # Data is provided, so data is returned.
            # ],
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        yield "transcription", input_data.payload.transcription
@@ -12,16 +12,15 @@ from backend.data.model import (
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName

DiscordCredentials = CredentialsMetaInput[Literal["discord"], Literal["api_key"]]
DiscordCredentials = CredentialsMetaInput[
    Literal[ProviderName.DISCORD], Literal["api_key"]
]


def DiscordCredentialsField() -> DiscordCredentials:
    return CredentialsField(
        description="Discord bot token",
        provider="discord",
        supported_credential_types={"api_key"},
    )
    return CredentialsField(description="Discord bot token")


TEST_CREDENTIALS = APIKeyCredentials(
autogpt_platform/backend/backend/blocks/exa/_auth.py (new file, 32 lines)

@@ -0,0 +1,32 @@
from typing import Literal

from pydantic import SecretStr

from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName

ExaCredentials = APIKeyCredentials
ExaCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.EXA],
    Literal["api_key"],
]

TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="exa",
    api_key=SecretStr("mock-exa-api-key"),
    title="Mock Exa API key",
    expires_at=None,
)

TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.title,
}


def ExaCredentialsField() -> ExaCredentialsInput:
    """Creates an Exa credentials input on a block."""
    return CredentialsField(description="The Exa integration requires an API Key.")
autogpt_platform/backend/backend/blocks/exa/search.py (new file, 157 lines)

@@ -0,0 +1,157 @@
from datetime import datetime
from typing import List

from pydantic import BaseModel

from backend.blocks.exa._auth import (
    ExaCredentials,
    ExaCredentialsField,
    ExaCredentialsInput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import requests


class ContentSettings(BaseModel):
    text: dict = SchemaField(
        description="Text content settings",
        default={"maxCharacters": 1000, "includeHtmlTags": False},
    )
    highlights: dict = SchemaField(
        description="Highlight settings",
        default={"numSentences": 3, "highlightsPerUrl": 3},
    )
    summary: dict = SchemaField(
        description="Summary settings",
        default={"query": ""},
    )


class ExaSearchBlock(Block):
    class Input(BlockSchema):
        credentials: ExaCredentialsInput = ExaCredentialsField()
        query: str = SchemaField(description="The search query")
        useAutoprompt: bool = SchemaField(
            description="Whether to use autoprompt",
            default=True,
        )
        type: str = SchemaField(
            description="Type of search",
            default="",
        )
        category: str = SchemaField(
            description="Category to search within",
            default="",
        )
        numResults: int = SchemaField(
            description="Number of results to return",
            default=10,
        )
        includeDomains: List[str] = SchemaField(
            description="Domains to include in search",
            default=[],
        )
        excludeDomains: List[str] = SchemaField(
            description="Domains to exclude from search",
            default=[],
        )
        startCrawlDate: datetime = SchemaField(
            description="Start date for crawled content",
        )
        endCrawlDate: datetime = SchemaField(
            description="End date for crawled content",
        )
        startPublishedDate: datetime = SchemaField(
            description="Start date for published content",
        )
        endPublishedDate: datetime = SchemaField(
            description="End date for published content",
        )
        includeText: List[str] = SchemaField(
            description="Text patterns to include",
            default=[],
        )
        excludeText: List[str] = SchemaField(
            description="Text patterns to exclude",
            default=[],
        )
        contents: ContentSettings = SchemaField(
            description="Content retrieval settings",
            default=ContentSettings(),
        )

    class Output(BlockSchema):
        results: list = SchemaField(
            description="List of search results",
            default=[],
        )

    def __init__(self):
        super().__init__(
            id="996cec64-ac40-4dde-982f-b0dc60a5824d",
            description="Searches the web using Exa's advanced search API",
            categories={BlockCategory.SEARCH},
            input_schema=ExaSearchBlock.Input,
            output_schema=ExaSearchBlock.Output,
        )

    def run(
        self, input_data: Input, *, credentials: ExaCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://api.exa.ai/search"
        headers = {
            "Content-Type": "application/json",
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        payload = {
            "query": input_data.query,
            "useAutoprompt": input_data.useAutoprompt,
            "numResults": input_data.numResults,
            "contents": {
                "text": {"maxCharacters": 1000, "includeHtmlTags": False},
                "highlights": {
                    "numSentences": 3,
                    "highlightsPerUrl": 3,
                },
                "summary": {"query": ""},
            },
        }

        # Add dates if they exist
        date_fields = [
            "startCrawlDate",
            "endCrawlDate",
            "startPublishedDate",
            "endPublishedDate",
        ]
        for field in date_fields:
            value = getattr(input_data, field, None)
            if value:
                payload[field] = value.strftime("%Y-%m-%dT%H:%M:%S.000Z")

        # Add other fields
        optional_fields = [
            "type",
            "category",
            "includeDomains",
            "excludeDomains",
            "includeText",
            "excludeText",
        ]

        for field in optional_fields:
            value = getattr(input_data, field)
            if value:  # Only add non-empty values
                payload[field] = value

        try:
            response = requests.post(url, headers=headers, json=payload)
            response.raise_for_status()
            data = response.json()
            # Extract just the results array from the response
            yield "results", data.get("results", [])
        except Exception as e:
            yield "error", str(e)
            yield "results", []
@@ -3,10 +3,11 @@ from typing import Literal
from pydantic import SecretStr

from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName

FalCredentials = APIKeyCredentials
FalCredentialsInput = CredentialsMetaInput[
    Literal["fal"],
    Literal[ProviderName.FAL],
    Literal["api_key"],
]

@@ -30,7 +31,5 @@ def FalCredentialsField() -> FalCredentialsInput:
    Creates a FAL credentials input on a block.
    """
    return CredentialsField(
        provider="fal",
        supported_credential_types={"api_key"},
        description="The FAL integration can be used with an API Key.",
    )
@@ -8,6 +8,7 @@ from backend.data.model import (
    CredentialsMetaInput,
    OAuth2Credentials,
)
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets

secrets = Secrets()

@@ -17,7 +18,7 @@ GITHUB_OAUTH_IS_CONFIGURED = bool(

GithubCredentials = APIKeyCredentials | OAuth2Credentials
GithubCredentialsInput = CredentialsMetaInput[
    Literal["github"],
    Literal[ProviderName.GITHUB],
    Literal["api_key", "oauth2"] if GITHUB_OAUTH_IS_CONFIGURED else Literal["api_key"],
]

@@ -30,10 +31,6 @@ def GithubCredentialsField(scope: str) -> GithubCredentialsInput:
        scope: The authorization scope needed for the block to work. ([list of available scopes](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/scopes-for-oauth-apps#available-scopes))
    """  # noqa
    return CredentialsField(
        provider="github",
        supported_credential_types=(
            {"api_key", "oauth2"} if GITHUB_OAUTH_IS_CONFIGURED else {"api_key"}
        ),
        required_scopes={scope},
        description="The GitHub integration can be used with OAuth, "
        "or any API key with sufficient permissions for the blocks it is used on.",
@@ -1,3 +1,5 @@
import re

from typing_extensions import TypedDict

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema

@@ -253,7 +255,7 @@ class GithubReadPullRequestBlock(Block):
    @staticmethod
    def read_pr_changes(credentials: GithubCredentials, pr_url: str) -> str:
        api = get_api(credentials)
        files_url = pr_url + "/files"
        files_url = prepare_pr_api_url(pr_url=pr_url, path="files")
        response = api.get(files_url)
        files = response.json()
        changes = []

@@ -331,7 +333,7 @@ class GithubAssignPRReviewerBlock(Block):
        credentials: GithubCredentials, pr_url: str, reviewer: str
    ) -> str:
        api = get_api(credentials)
        reviewers_url = pr_url + "/requested_reviewers"
        reviewers_url = prepare_pr_api_url(pr_url=pr_url, path="requested_reviewers")
        data = {"reviewers": [reviewer]}
        api.post(reviewers_url, json=data)
        return "Reviewer assigned successfully"

@@ -398,7 +400,7 @@ class GithubUnassignPRReviewerBlock(Block):
        credentials: GithubCredentials, pr_url: str, reviewer: str
    ) -> str:
        api = get_api(credentials)
        reviewers_url = pr_url + "/requested_reviewers"
        reviewers_url = prepare_pr_api_url(pr_url=pr_url, path="requested_reviewers")
        data = {"reviewers": [reviewer]}
        api.delete(reviewers_url, json=data)
        return "Reviewer unassigned successfully"

@@ -478,7 +480,7 @@ class GithubListPRReviewersBlock(Block):
        credentials: GithubCredentials, pr_url: str
    ) -> list[Output.ReviewerItem]:
        api = get_api(credentials)
        reviewers_url = pr_url + "/requested_reviewers"
        reviewers_url = prepare_pr_api_url(pr_url=pr_url, path="requested_reviewers")
        response = api.get(reviewers_url)
        data = response.json()
        reviewers: list[GithubListPRReviewersBlock.Output.ReviewerItem] = [

@@ -499,3 +501,14 @@ class GithubListPRReviewersBlock(Block):
            input_data.pr_url,
        )
        yield from (("reviewer", reviewer) for reviewer in reviewers)


def prepare_pr_api_url(pr_url: str, path: str) -> str:
    # Pattern to capture the base repository URL and the pull request number
    pattern = r"^(?:https?://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
    match = re.match(pattern, pr_url)
    if not match:
        return pr_url

    base_url, pr_number = match.groups()
    return f"{base_url}/pulls/{pr_number}/{path}"
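For reference, the new helper rewrites a human-facing /pull/ page URL into the REST-style /pulls/ path; a self-contained run of the same logic (the example URL is illustrative):

import re


def prepare_pr_api_url(pr_url: str, path: str) -> str:
    # Identical logic to the helper added above
    pattern = r"^(?:https?://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
    match = re.match(pattern, pr_url)
    if not match:
        return pr_url
    base_url, pr_number = match.groups()
    return f"{base_url}/pulls/{pr_number}/{path}"


print(prepare_pr_api_url("https://github.com/octocat/hello-world/pull/42", "files"))
# -> github.com/octocat/hello-world/pulls/42/files (note: the scheme is dropped)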
@@ -3,6 +3,7 @@ from typing import Literal
from pydantic import SecretStr

from backend.data.model import CredentialsField, CredentialsMetaInput, OAuth2Credentials
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets

# --8<-- [start:GoogleOAuthIsConfigured]

@@ -12,7 +13,9 @@ GOOGLE_OAUTH_IS_CONFIGURED = bool(
)
# --8<-- [end:GoogleOAuthIsConfigured]
GoogleCredentials = OAuth2Credentials
GoogleCredentialsInput = CredentialsMetaInput[Literal["google"], Literal["oauth2"]]
GoogleCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.GOOGLE], Literal["oauth2"]
]


def GoogleCredentialsField(scopes: list[str]) -> GoogleCredentialsInput:

@@ -23,8 +26,6 @@ def GoogleCredentialsField(scopes: list[str]) -> GoogleCredentialsInput:
        scopes: The authorization scopes needed for the block to work.
    """
    return CredentialsField(
        provider="google",
        supported_credential_types={"oauth2"},
        required_scopes=set(scopes),
        description="The Google integration requires OAuth2 authentication.",
    )
@@ -10,6 +10,7 @@ from backend.data.model import (
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName

TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",

@@ -38,12 +39,8 @@ class Place(BaseModel):

class GoogleMapsSearchBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[
            Literal["google_maps"], Literal["api_key"]
        ] = CredentialsField(
            provider="google_maps",
            supported_credential_types={"api_key"},
            description="Google Maps API Key",
        )
            Literal[ProviderName.GOOGLE_MAPS], Literal["api_key"]
        ] = CredentialsField(description="Google Maps API Key")
        query: str = SchemaField(
            description="Search query for local businesses",
            placeholder="e.g., 'restaurants in New York'",
@@ -3,10 +3,11 @@ from typing import Literal
from pydantic import SecretStr

from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName

HubSpotCredentials = APIKeyCredentials
HubSpotCredentialsInput = CredentialsMetaInput[
    Literal["hubspot"],
    Literal[ProviderName.HUBSPOT],
    Literal["api_key"],
]

@@ -14,8 +15,6 @@ HubSpotCredentialsInput = CredentialsMetaInput[
def HubSpotCredentialsField() -> HubSpotCredentialsInput:
    """Creates a HubSpot credentials input on a block."""
    return CredentialsField(
        provider="hubspot",
        supported_credential_types={"api_key"},
        description="The HubSpot integration requires an API Key.",
    )
@@ -11,6 +11,7 @@ from backend.data.model import (
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests

TEST_CREDENTIALS = APIKeyCredentials(

@@ -83,13 +84,10 @@ class UpscaleOption(str, Enum):

class IdeogramModelBlock(Block):
    class Input(BlockSchema):

        credentials: CredentialsMetaInput[Literal["ideogram"], Literal["api_key"]] = (
            CredentialsField(
                provider="ideogram",
                supported_credential_types={"api_key"},
                description="The Ideogram integration can be used with any API key with sufficient permissions for the blocks it is used on.",
            )
        credentials: CredentialsMetaInput[
            Literal[ProviderName.IDEOGRAM], Literal["api_key"]
        ] = CredentialsField(
            description="The Ideogram integration can be used with any API key with sufficient permissions for the blocks it is used on.",
        )
        prompt: str = SchemaField(
            description="Text prompt for image generation",
@@ -3,27 +3,14 @@ from typing import Literal
from pydantic import SecretStr

from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName

JinaCredentials = APIKeyCredentials
JinaCredentialsInput = CredentialsMetaInput[
    Literal["jina"],
    Literal[ProviderName.JINA],
    Literal["api_key"],
]

TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="jina",
    api_key=SecretStr("mock-jina-api-key"),
    title="Mock Jina API key",
    expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.type,
}


def JinaCredentialsField() -> JinaCredentialsInput:
    """

@@ -31,8 +18,6 @@ def JinaCredentialsField() -> JinaCredentialsInput:

    """
    return CredentialsField(
        provider="jina",
        supported_credential_types={"api_key"},
        description="The Jina integration can be used with an API Key.",
    )
@@ -7,6 +7,8 @@ from typing import TYPE_CHECKING, Any, List, Literal, NamedTuple
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from enum import _EnumMemberT
|
||||
|
||||
@@ -27,7 +29,13 @@ from backend.util.settings import BehaveAs, Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LLMProviderName = Literal["anthropic", "groq", "openai", "ollama", "open_router"]
|
||||
LLMProviderName = Literal[
|
||||
ProviderName.ANTHROPIC,
|
||||
ProviderName.GROQ,
|
||||
ProviderName.OLLAMA,
|
||||
ProviderName.OPENAI,
|
||||
ProviderName.OPEN_ROUTER,
|
||||
]
|
||||
AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]]
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
@@ -48,8 +56,6 @@ TEST_CREDENTIALS_INPUT = {
|
||||
def AICredentialsField() -> AICredentials:
|
||||
return CredentialsField(
|
||||
description="API key for the LLM provider.",
|
||||
provider=["anthropic", "groq", "openai", "ollama", "open_router"],
|
||||
supported_credential_types={"api_key"},
|
||||
discriminator="model",
|
||||
discriminator_mapping={
|
||||
model.value: model.metadata.provider for model in LlmModel
|
||||
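
The discriminator mapping built above lets a single multi-provider `credentials`
field resolve its provider from the user's `model` selection. A minimal
standalone sketch of that lookup; the mapping entries here are assumptions for
illustration, while the real mapping is generated from `LlmModel` metadata:

# Sketch: provider resolution via a discriminator mapping (assumed values).
discriminator_mapping: dict[str, str] = {
    "llama3": "ollama",
    "x-ai/grok-beta": "open_router",
    "claude-3-5-sonnet-latest": "anthropic",
}

def resolve_provider(model_value: str) -> str:
    """Return the credentials provider that serves the selected model."""
    return discriminator_mapping[model_value]

assert resolve_provider("llama3") == "ollama"
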
@@ -105,9 +111,9 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
    # Ollama models
    OLLAMA_LLAMA3_8B = "llama3"
    OLLAMA_LLAMA3_405B = "llama3.1:405b"
    OLLAMA_DOLPHIN = "dolphin-mistral:latest"
    # OpenRouter models
    GEMINI_FLASH_1_5_8B = "google/gemini-flash-1.5"
    GEMINI_FLASH_1_5_EXP = "google/gemini-flash-1.5-exp"
    GROK_BETA = "x-ai/grok-beta"
    MISTRAL_NEMO = "mistralai/mistral-nemo"
    COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
@@ -117,6 +123,14 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
    PERPLEXITY_LLAMA_3_1_SONAR_LARGE_128K_ONLINE = (
        "perplexity/llama-3.1-sonar-large-128k-online"
    )
    QWEN_QWQ_32B_PREVIEW = "qwen/qwq-32b-preview"
    NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
    NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
    AMAZON_NOVA_LITE_V1 = "amazon/nova-lite-v1"
    AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
    AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
    MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
    GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"

    @property
    def metadata(self) -> ModelMetadata:
@@ -151,8 +165,8 @@ MODEL_METADATA = {
    LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 131072),
    LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192),
    LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192),
    LlmModel.OLLAMA_DOLPHIN: ModelMetadata("ollama", 32768),
    LlmModel.GEMINI_FLASH_1_5_8B: ModelMetadata("open_router", 8192),
    LlmModel.GEMINI_FLASH_1_5_EXP: ModelMetadata("open_router", 8192),
    LlmModel.GROK_BETA: ModelMetadata("open_router", 8192),
    LlmModel.MISTRAL_NEMO: ModelMetadata("open_router", 4000),
    LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata("open_router", 4000),
@@ -162,6 +176,14 @@ MODEL_METADATA = {
    LlmModel.PERPLEXITY_LLAMA_3_1_SONAR_LARGE_128K_ONLINE: ModelMetadata(
        "open_router", 8192
    ),
    LlmModel.QWEN_QWQ_32B_PREVIEW: ModelMetadata("open_router", 4000),
    LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: ModelMetadata("open_router", 4000),
    LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: ModelMetadata("open_router", 4000),
    LlmModel.AMAZON_NOVA_LITE_V1: ModelMetadata("open_router", 4000),
    LlmModel.AMAZON_NOVA_MICRO_V1: ModelMetadata("open_router", 4000),
    LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata("open_router", 4000),
    LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata("open_router", 4000),
    LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata("open_router", 4000),
}

for model in LlmModel:
@@ -220,6 +242,12 @@ class AIStructuredResponseGeneratorBlock(Block):
            description="The maximum number of tokens to generate in the chat completion.",
        )

        ollama_host: str = SchemaField(
            advanced=True,
            default="localhost:11434",
            description="Ollama host for local models",
        )

    class Output(BlockSchema):
        response: dict[str, Any] = SchemaField(
            description="The response object generated by the language model."
@@ -265,6 +293,7 @@ class AIStructuredResponseGeneratorBlock(Block):
        prompt: list[dict],
        json_format: bool,
        max_tokens: int | None = None,
        ollama_host: str = "localhost:11434",
    ) -> tuple[str, int, int]:
        """
        Args:
@@ -273,6 +302,7 @@ class AIStructuredResponseGeneratorBlock(Block):
            prompt: The prompt to send to the LLM.
            json_format: Whether the response should be in JSON format.
            max_tokens: The maximum number of tokens to generate in the chat completion.
            ollama_host: The Ollama host to connect to for local models.

        Returns:
            The response from the LLM.
@@ -362,9 +392,10 @@ class AIStructuredResponseGeneratorBlock(Block):
                response.usage.completion_tokens if response.usage else 0,
            )
        elif provider == "ollama":
            client = ollama.Client(host=ollama_host)
            sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
            usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
            response = ollama.generate(
            response = client.generate(
                model=llm_model.value,
                prompt=f"{sys_messages}\n\n{usr_messages}",
                stream=False,
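
The change above swaps the module-level `ollama.generate` call, which always
targets the default local daemon, for a per-request `ollama.Client` bound to
the user-supplied host. A rough standalone sketch of the new call shape,
assuming an Ollama server is reachable at the given host:

import ollama

def generate_via_host(host: str, model: str, prompt: str) -> str:
    # Bind the client to the configured host instead of the library default,
    # so remote or self-hosted Ollama instances work too.
    client = ollama.Client(host=host)
    response = client.generate(model=model, prompt=prompt, stream=False)
    return response["response"]

# e.g. generate_via_host("localhost:11434", "llama3", "Say hello.")
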
@@ -464,6 +495,7 @@ class AIStructuredResponseGeneratorBlock(Block):
            llm_model=llm_model,
            prompt=prompt,
            json_format=bool(input_data.expected_format),
            ollama_host=input_data.ollama_host,
            max_tokens=input_data.max_tokens,
        )
        self.merge_stats(
@@ -546,6 +578,11 @@ class AITextGeneratorBlock(Block):
        prompt_values: dict[str, str] = SchemaField(
            advanced=False, default={}, description="Values used to fill in the prompt."
        )
        ollama_host: str = SchemaField(
            advanced=True,
            default="localhost:11434",
            description="Ollama host for local models",
        )
        max_tokens: int | None = SchemaField(
            advanced=True,
            default=None,
@@ -636,6 +673,11 @@ class AITextSummarizerBlock(Block):
            description="The number of overlapping tokens between chunks to maintain context.",
            ge=0,
        )
        ollama_host: str = SchemaField(
            advanced=True,
            default="localhost:11434",
            description="Ollama host for local models",
        )

    class Output(BlockSchema):
        summary: str = SchemaField(description="The final summary of the text.")
@@ -774,6 +816,11 @@ class AIConversationBlock(Block):
            default=None,
            description="The maximum number of tokens to generate in the chat completion.",
        )
        ollama_host: str = SchemaField(
            advanced=True,
            default="localhost:11434",
            description="Ollama host for local models",
        )

    class Output(BlockSchema):
        response: str = SchemaField(
@@ -871,6 +918,11 @@ class AIListGeneratorBlock(Block):
            default=None,
            description="The maximum number of tokens to generate in the chat completion.",
        )
        ollama_host: str = SchemaField(
            advanced=True,
            default="localhost:11434",
            description="Ollama host for local models",
        )

    class Output(BlockSchema):
        generated_list: List[str] = SchemaField(description="The generated list.")
@@ -1022,6 +1074,7 @@ class AIListGeneratorBlock(Block):
                credentials=input_data.credentials,
                model=input_data.model,
                expected_format={},  # Do not use structured response
                ollama_host=input_data.ollama_host,
            ),
            credentials=credentials,
        )

@@ -12,6 +12,7 @@ from backend.data.model import (
    SchemaField,
    SecretField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests

TEST_CREDENTIALS = APIKeyCredentials(
@@ -77,12 +78,10 @@ class PublishToMediumBlock(Block):
            description="Whether to notify followers that the user has published",
            placeholder="False",
        )
        credentials: CredentialsMetaInput[Literal["medium"], Literal["api_key"]] = (
            CredentialsField(
                provider="medium",
                supported_credential_types={"api_key"},
                description="The Medium integration can be used with any API key with sufficient permissions for the blocks it is used on.",
            )
        credentials: CredentialsMetaInput[
            Literal[ProviderName.MEDIUM], Literal["api_key"]
        ] = CredentialsField(
            description="The Medium integration can be used with any API key with sufficient permissions for the blocks it is used on.",
        )

    class Output(BlockSchema):

@@ -10,22 +10,18 @@ from backend.data.model import (
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName

PineconeCredentials = APIKeyCredentials
PineconeCredentialsInput = CredentialsMetaInput[
    Literal["pinecone"],
    Literal[ProviderName.PINECONE],
    Literal["api_key"],
]


def PineconeCredentialsField() -> PineconeCredentialsInput:
    """
    Creates a Pinecone credentials input on a block.

    """
    """Creates a Pinecone credentials input on a block."""
    return CredentialsField(
        provider="pinecone",
        supported_credential_types={"api_key"},
        description="The Pinecone integration can be used with an API Key.",
    )

@@ -147,7 +143,7 @@ class PineconeQueryBlock(Block):
                top_k=input_data.top_k,
                include_values=input_data.include_values,
                include_metadata=input_data.include_metadata,
            ).to_dict()
            ).to_dict()  # type: ignore
            combined_text = ""
            if results["matches"]:
                texts = [

@@ -13,6 +13,7 @@ from backend.data.model import (
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName

TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
@@ -54,13 +55,11 @@ class ImageType(str, Enum):

class ReplicateFluxAdvancedModelBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[Literal["replicate"], Literal["api_key"]] = (
            CredentialsField(
                provider="replicate",
                supported_credential_types={"api_key"},
                description="The Replicate integration can be used with "
                "any API key with sufficient permissions for the blocks it is used on.",
            )
        credentials: CredentialsMetaInput[
            Literal[ProviderName.REPLICATE], Literal["api_key"]
        ] = CredentialsField(
            description="The Replicate integration can be used with "
            "any API key with sufficient permissions for the blocks it is used on.",
        )
        prompt: str = SchemaField(
            description="Text prompt for image generation",

@@ -11,6 +11,7 @@ from backend.data.model import (
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName


class GetWikipediaSummaryBlock(Block, GetRequest):
@@ -65,10 +66,8 @@ class GetWeatherInformationBlock(Block, GetRequest):
            description="Location to get weather information for"
        )
        credentials: CredentialsMetaInput[
            Literal["openweathermap"], Literal["api_key"]
            Literal[ProviderName.OPENWEATHERMAP], Literal["api_key"]
        ] = CredentialsField(
            provider="openweathermap",
            supported_credential_types={"api_key"},
            description="The OpenWeatherMap integration can be used with "
            "any API key with sufficient permissions for the blocks it is used on.",
        )

@@ -4,16 +4,15 @@ from typing import Literal

from pydantic import BaseModel, SecretStr

from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName

Slant3DCredentialsInput = CredentialsMetaInput[Literal["slant3d"], Literal["api_key"]]
Slant3DCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.SLANT3D], Literal["api_key"]
]


def Slant3DCredentialsField() -> Slant3DCredentialsInput:
    return CredentialsField(
        provider="slant3d",
        supported_credential_types={"api_key"},
        description="Slant3D API key for authentication",
    )
    return CredentialsField(description="Slant3D API key for authentication")


TEST_CREDENTIALS = APIKeyCredentials(

@@ -10,6 +10,7 @@ from backend.data.model import (
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests

TEST_CREDENTIALS = APIKeyCredentials(
@@ -29,13 +30,11 @@ TEST_CREDENTIALS_INPUT = {

class CreateTalkingAvatarVideoBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[Literal["d_id"], Literal["api_key"]] = (
            CredentialsField(
                provider="d_id",
                supported_credential_types={"api_key"},
                description="The D-ID integration can be used with "
                "any API key with sufficient permissions for the blocks it is used on.",
            )
        credentials: CredentialsMetaInput[
            Literal[ProviderName.D_ID], Literal["api_key"]
        ] = CredentialsField(
            description="The D-ID integration can be used with "
            "any API key with sufficient permissions for the blocks it is used on.",
        )
        script_input: str = SchemaField(
            description="The text input for the script",

@@ -1,13 +1,11 @@
import re
from typing import Any

from jinja2 import BaseLoader, Environment

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util import json
from backend.util import json, text

jinja = Environment(loader=BaseLoader())
formatter = text.TextFormatter()


class MatchTextPatternBlock(Block):
@@ -73,6 +71,7 @@ class ExtractTextInformationBlock(Block):
            description="Case sensitive match", default=True
        )
        dot_all: bool = SchemaField(description="Dot matches all", default=True)
        find_all: bool = SchemaField(description="Find all matches", default=False)

    class Output(BlockSchema):
        positive: str = SchemaField(description="Extracted text")
@@ -90,12 +89,27 @@ class ExtractTextInformationBlock(Block):
                {"text": "Hello, World!", "pattern": "Hello, (.+)", "group": 0},
                {"text": "Hello, World!", "pattern": "Hello, (.+)", "group": 2},
                {"text": "Hello, World!", "pattern": "hello,", "case_sensitive": False},
                {
                    "text": "Hello, World!! Hello, Earth!!",
                    "pattern": "Hello, (\\S+)",
                    "group": 1,
                    "find_all": False,
                },
                {
                    "text": "Hello, World!! Hello, Earth!!",
                    "pattern": "Hello, (\\S+)",
                    "group": 1,
                    "find_all": True,
                },
            ],
            test_output=[
                ("positive", "World!"),
                ("positive", "Hello, World!"),
                ("negative", "Hello, World!"),
                ("positive", "Hello,"),
                ("positive", "World!!"),
                ("positive", "World!!"),
                ("positive", "Earth!!"),
            ],
        )

@@ -107,15 +121,21 @@ class ExtractTextInformationBlock(Block):
            flags = flags | re.DOTALL

        if isinstance(input_data.text, str):
            text = input_data.text
            txt = input_data.text
        else:
            text = json.dumps(input_data.text)
            txt = json.dumps(input_data.text)

        match = re.search(input_data.pattern, text, flags)
        if match and input_data.group <= len(match.groups()):
            yield "positive", match.group(input_data.group)
        else:
            yield "negative", text
        matches = [
            match.group(input_data.group)
            for match in re.finditer(input_data.pattern, txt, flags)
            if input_data.group <= len(match.groups())
        ]
        for match in matches:
            yield "positive", match
            if not input_data.find_all:
                return
        if not matches:
            yield "negative", input_data.text
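
The reworked `run` above moves from a single `re.search` to `re.finditer`, so
the block can emit every match when `find_all` is set and still stop after the
first one otherwise. The core behavior, as a self-contained sketch:

import re

def extract(pattern: str, text: str, group: int = 0, find_all: bool = False):
    # Collect every match whose requested group exists, as the block does.
    found = [
        m.group(group)
        for m in re.finditer(pattern, text)
        if group <= len(m.groups())
    ]
    for value in found:
        yield "positive", value
        if not find_all:
            return
    if not found:
        yield "negative", text

assert list(extract(r"Hello, (\S+)", "Hello, World!! Hello, Earth!!", 1, True)) == [
    ("positive", "World!!"),
    ("positive", "Earth!!"),
]
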
class FillTextTemplateBlock(Block):
@@ -146,19 +166,20 @@ class FillTextTemplateBlock(Block):
                    "values": {"list": ["Hello", " World!"]},
                    "format": "{% for item in list %}{{ item }}{% endfor %}",
                },
                {
                    "values": {},
                    "format": "{% set name = 'Alice' %}Hello, World! {{ name }}",
                },
            ],
            test_output=[
                ("output", "Hello, World! Alice"),
                ("output", "Hello World!"),
                ("output", "Hello, World! Alice"),
            ],
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # For python.format compatibility: replace all {...} with {{..}}.
        # But avoid replacing {{...}} to {{{...}}}.
        fmt = re.sub(r"(?<!{){[ a-zA-Z0-9_]+}", r"{\g<0>}", input_data.format)
        template = jinja.from_string(fmt)
        yield "output", template.render(**input_data.values)
        yield "output", formatter.format_string(input_data.format, input_data.values)
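
The deleted `run` body above documents the trick this block relied on before
delegating to `text.TextFormatter`: python-style `{name}` placeholders are
escaped into Jinja's `{{ name }}` form without double-escaping existing Jinja
syntax. That escaping step in isolation (the new `format_string` helper
presumably encapsulates the same behavior):

import re
from jinja2 import BaseLoader, Environment

jinja = Environment(loader=BaseLoader())

def render(fmt: str, values: dict) -> str:
    # Promote {placeholders} to {{placeholders}}, leaving {{...}} untouched.
    fmt = re.sub(r"(?<!{){[ a-zA-Z0-9_]+}", r"{\g<0>}", fmt)
    return jinja.from_string(fmt).render(**values)

assert render("Hello, {name}!", {"name": "Alice"}) == "Hello, Alice!"
assert render("{{ greeting }}", {"greeting": "Hi"}) == "Hi"
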
class CombineTextsBlock(Block):

@@ -9,6 +9,7 @@ from backend.data.model import (
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests

TEST_CREDENTIALS = APIKeyCredentials(
@@ -38,10 +39,8 @@ class UnrealTextToSpeechBlock(Block):
            default="Scarlett",
        )
        credentials: CredentialsMetaInput[
            Literal["unreal_speech"], Literal["api_key"]
            Literal[ProviderName.UNREAL_SPEECH], Literal["api_key"]
        ] = CredentialsField(
            provider="unreal_speech",
            supported_credential_types={"api_key"},
            description="The Unreal Speech integration can be used with "
            "any API key with sufficient permissions for the blocks it is used on.",
        )

@@ -42,6 +42,7 @@ class BlockType(Enum):
    OUTPUT = "Output"
    NOTE = "Note"
    WEBHOOK = "Webhook"
    WEBHOOK_MANUAL = "Webhook (manual)"
    AGENT = "Agent"


@@ -57,6 +58,7 @@ class BlockCategory(Enum):
    COMMUNICATION = "Block that interacts with communication platforms."
    DEVELOPER_TOOLS = "Developer tools such as GitHub blocks."
    DATA = "Block that interacts with structured data."
    HARDWARE = "Block that interacts with hardware."
    AGENT = "Block that interacts with other agents."
    CRM = "Block that interacts with CRM services."

@@ -65,7 +67,7 @@ class BlockCategory(Enum):


class BlockSchema(BaseModel):
    cached_jsonschema: ClassVar[dict[str, Any]] = {}
    cached_jsonschema: ClassVar[dict[str, Any]]

    @classmethod
    def jsonschema(cls) -> dict[str, Any]:
@@ -90,6 +92,7 @@ class BlockSchema(BaseModel):
                }
            elif isinstance(obj, list):
                return [ref_to_dict(item) for item in obj]

            return obj

        cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
@@ -145,6 +148,10 @@ class BlockSchema(BaseModel):
        - A field that is called `credentials` MUST be a `CredentialsMetaInput`.
        """
        super().__pydantic_init_subclass__(**kwargs)

        # Reset cached JSON schema to prevent inheriting it from parent class
        cls.cached_jsonschema = {}

        credentials_fields = [
            field_name
            for field_name, info in cls.model_fields.items()
@@ -176,6 +183,11 @@ class BlockSchema(BaseModel):
                    f"Field 'credentials' on {cls.__qualname__} "
                    f"must be of type {CredentialsMetaInput.__name__}"
                )
        if credentials_field := cls.model_fields.get(CREDENTIALS_FIELD_NAME):
            credentials_input_type = cast(
                CredentialsMetaInput, credentials_field.annotation
            )
            credentials_input_type.validate_credentials_field_schema(cls)


BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchema)
@@ -187,7 +199,12 @@ class EmptySchema(BlockSchema):


# --8<-- [start:BlockWebhookConfig]
class BlockWebhookConfig(BaseModel):
class BlockManualWebhookConfig(BaseModel):
    """
    Configuration model for webhook-triggered blocks on which
    the user has to manually set up the webhook at the provider.
    """

    provider: str
    """The service provider that the webhook connects to"""

@@ -198,6 +215,27 @@ class BlockWebhookConfig(BaseModel):
    Only for use in the corresponding `WebhooksManager`.
    """

    event_filter_input: str = ""
    """
    Name of the block's event filter input.
    Leave empty if the corresponding webhook doesn't have distinct event/payload types.
    """

    event_format: str = "{event}"
    """
    Template string for the event(s) that a block instance subscribes to.
    Applied individually to each event selected in the event filter input.

    Example: `"pull_request.{event}"` -> `"pull_request.opened"`
    """


class BlockWebhookConfig(BlockManualWebhookConfig):
    """
    Configuration model for webhook-triggered blocks for which
    the webhook can be automatically set up through the provider's API.
    """

    resource_format: str
    """
    Template string for the resource that a block instance subscribes to.
@@ -207,17 +245,6 @@ class BlockWebhookConfig(BaseModel):

    Only for use in the corresponding `WebhooksManager`.
    """

    event_filter_input: str
    """Name of the block's event filter input."""

    event_format: str = "{event}"
    """
    Template string for the event(s) that a block instance subscribes to.
    Applied individually to each event selected in the event filter input.

    Example: `"pull_request.{event}"` -> `"pull_request.opened"`
    """
# --8<-- [end:BlockWebhookConfig]
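
With the split above, blocks whose webhooks can be provisioned through the
provider's API keep using `BlockWebhookConfig`, while blocks that rely on the
user pasting an ingress URL into the provider's UI only need the base model.
A hypothetical configuration for each case (provider, type, and event names
are made up for illustration):

from backend.data.block import BlockManualWebhookConfig, BlockWebhookConfig

manual_config = BlockManualWebhookConfig(
    provider="compass",      # assumed provider
    webhook_type="generic",  # assumed type
    # event_filter_input left empty: a single undifferentiated payload type
)

auto_config = BlockWebhookConfig(
    provider="github",
    webhook_type="repo",
    resource_format="{repo}",  # filled from the node's inputs
    event_filter_input="events",
    event_format="pull_request.{event}",
)
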
@@ -237,7 +264,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
        disabled: bool = False,
        static_output: bool = False,
        block_type: BlockType = BlockType.STANDARD,
        webhook_config: Optional[BlockWebhookConfig] = None,
        webhook_config: Optional[BlockWebhookConfig | BlockManualWebhookConfig] = None,
    ):
        """
        Initialize the block with the given schema.
@@ -268,27 +295,38 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
        self.contributors = contributors or set()
        self.disabled = disabled
        self.static_output = static_output
        self.block_type = block_type if not webhook_config else BlockType.WEBHOOK
        self.block_type = block_type
        self.webhook_config = webhook_config
        self.execution_stats = {}

        if self.webhook_config:
            # Enforce shape of webhook event filter
            event_filter_field = self.input_schema.model_fields[
                self.webhook_config.event_filter_input
            ]
            if not (
                isinstance(event_filter_field.annotation, type)
                and issubclass(event_filter_field.annotation, BaseModel)
                and all(
                    field.annotation is bool
                    for field in event_filter_field.annotation.model_fields.values()
                )
            ):
                raise NotImplementedError(
                    f"{self.name} has an invalid webhook event selector: "
                    "field must be a BaseModel and all its fields must be boolean"
                )
            if isinstance(self.webhook_config, BlockWebhookConfig):
                # Enforce presence of credentials field on auto-setup webhook blocks
                if CREDENTIALS_FIELD_NAME not in self.input_schema.model_fields:
                    raise TypeError(
                        "credentials field is required on auto-setup webhook blocks"
                    )
                self.block_type = BlockType.WEBHOOK
            else:
                self.block_type = BlockType.WEBHOOK_MANUAL

            # Enforce shape of webhook event filter, if present
            if self.webhook_config.event_filter_input:
                event_filter_field = self.input_schema.model_fields[
                    self.webhook_config.event_filter_input
                ]
                if not (
                    isinstance(event_filter_field.annotation, type)
                    and issubclass(event_filter_field.annotation, BaseModel)
                    and all(
                        field.annotation is bool
                        for field in event_filter_field.annotation.model_fields.values()
                    )
                ):
                    raise NotImplementedError(
                        f"{self.name} has an invalid webhook event selector: "
                        "field must be a BaseModel and all its fields must be boolean"
                    )

            # Enforce presence of 'payload' input
            if "payload" not in self.input_schema.model_fields:
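
The constructor above requires the event filter input, when one is declared,
to be a `BaseModel` whose fields are all booleans. A filter that would pass
that check might look like this (field names assumed):

from pydantic import BaseModel

class PullRequestEventFilter(BaseModel):
    # One boolean per subscribable event; the all-bool shape is exactly what
    # Block.__init__ enforces for webhook event selectors.
    opened: bool = False
    closed: bool = False
    synchronize: bool = False
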
@@ -53,8 +53,8 @@ MODEL_COST: dict[LlmModel, int] = {
    LlmModel.LLAMA3_1_8B: 1,
    LlmModel.OLLAMA_LLAMA3_8B: 1,
    LlmModel.OLLAMA_LLAMA3_405B: 1,
    LlmModel.OLLAMA_DOLPHIN: 1,
    LlmModel.GEMINI_FLASH_1_5_8B: 1,
    LlmModel.GEMINI_FLASH_1_5_EXP: 1,
    LlmModel.GROK_BETA: 5,
    LlmModel.MISTRAL_NEMO: 1,
    LlmModel.COHERE_COMMAND_R_08_2024: 1,
@@ -62,6 +62,14 @@ MODEL_COST: dict[LlmModel, int] = {
    LlmModel.EVA_QWEN_2_5_32B: 1,
    LlmModel.DEEPSEEK_CHAT: 2,
    LlmModel.PERPLEXITY_LLAMA_3_1_SONAR_LARGE_128K_ONLINE: 1,
    LlmModel.QWEN_QWQ_32B_PREVIEW: 2,
    LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: 1,
    LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: 1,
    LlmModel.AMAZON_NOVA_LITE_V1: 1,
    LlmModel.AMAZON_NOVA_MICRO_V1: 1,
    LlmModel.AMAZON_NOVA_PRO_V1: 1,
    LlmModel.MICROSOFT_WIZARDLM_2_8X22B: 1,
    LlmModel.GRYPHE_MYTHOMAX_L2_13B: 1,
}

for model in LlmModel:

@@ -2,9 +2,9 @@ from abc import ABC, abstractmethod
from datetime import datetime, timezone

from prisma import Json
from prisma.enums import UserBlockCreditType
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import UserBlockCredit
from prisma.models import CreditTransaction

from backend.data.block import Block, BlockInput, get_block
from backend.data.block_cost_config import BLOCK_COSTS
@@ -76,7 +76,7 @@ class UserCredit(UserCreditBase):
            else cur_month.replace(year=cur_month.year + 1, month=1)
        )

        user_credit = await UserBlockCredit.prisma().group_by(
        user_credit = await CreditTransaction.prisma().group_by(
            by=["userId"],
            sum={"amount": True},
            where={
@@ -93,10 +93,10 @@ class UserCredit(UserCreditBase):
        key = f"MONTHLY-CREDIT-TOP-UP-{cur_month}"

        try:
            await UserBlockCredit.prisma().create(
            await CreditTransaction.prisma().create(
                data={
                    "amount": self.num_user_credits_refill,
                    "type": UserBlockCreditType.TOP_UP,
                    "type": CreditTransactionType.TOP_UP,
                    "userId": user_id,
                    "transactionKey": key,
                    "createdAt": self.time_now(),
@@ -184,11 +184,11 @@ class UserCredit(UserCreditBase):
        if validate_balance and user_credit < cost:
            raise ValueError(f"Insufficient credit: {user_credit} < {cost}")

        await UserBlockCredit.prisma().create(
        await CreditTransaction.prisma().create(
            data={
                "userId": user_id,
                "amount": -cost,
                "type": UserBlockCreditType.USAGE,
                "type": CreditTransactionType.USAGE,
                "blockId": block.id,
                "metadata": Json(
                    {
@@ -202,11 +202,11 @@ class UserCredit(UserCreditBase):
        return cost

    async def top_up_credits(self, user_id: str, amount: int):
        await UserBlockCredit.prisma().create(
        await CreditTransaction.prisma().create(
            data={
                "userId": user_id,
                "amount": amount,
                "type": UserBlockCreditType.TOP_UP,
                "type": CreditTransactionType.TOP_UP,
                "createdAt": self.time_now(),
            }
        )

@@ -29,6 +29,13 @@ async def connect():
    if not prisma.is_connected():
        raise ConnectionError("Failed to connect to Prisma.")

    # A connection acquired from a pool (e.g. Supabase's pooler) can still be
    # handed out successfully and then reject queries afterwards, so probe it.
    try:
        await prisma.execute_raw("SELECT 1")
    except Exception as e:
        raise ConnectionError("Failed to connect to Prisma.") from e
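
Since `is_connected()` alone can report a healthy client whose pooled
connection rejects queries, the added `SELECT 1` acts as a liveness probe.
The same pattern, sketched in isolation against a `prisma.Prisma` client:

from prisma import Prisma

async def verified_connect(db: Prisma) -> None:
    await db.connect()
    if not db.is_connected():
        raise ConnectionError("Failed to connect to Prisma.")
    # Cheap round-trip to prove the pooled connection actually accepts queries.
    try:
        await db.execute_raw("SELECT 1")
    except Exception as e:
        raise ConnectionError("Failed to connect to Prisma.") from e
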
@conn_retry("Prisma", "Releasing connection")
async def disconnect():

@@ -9,7 +9,6 @@ from prisma.models import (
    AgentNodeExecution,
    AgentNodeExecutionInputOutput,
)
from prisma.types import AgentGraphExecutionWhereInput
from pydantic import BaseModel

from backend.data.block import BlockData, BlockInput, CompletedBlockOutput
@@ -19,14 +18,14 @@ from backend.util import json, mock
from backend.util.settings import Config


class GraphExecution(BaseModel):
class GraphExecutionEntry(BaseModel):
    user_id: str
    graph_exec_id: str
    graph_id: str
    start_node_execs: list["NodeExecution"]
    start_node_execs: list["NodeExecutionEntry"]


class NodeExecution(BaseModel):
class NodeExecutionEntry(BaseModel):
    user_id: str
    graph_exec_id: str
    graph_id: str
@@ -325,34 +324,6 @@ async def update_execution_status(
    return ExecutionResult.from_db(res)


async def get_graph_execution(
    graph_exec_id: str, user_id: str
) -> AgentGraphExecution | None:
    """
    Retrieve a specific graph execution by its ID.

    Args:
        graph_exec_id (str): The ID of the graph execution to retrieve.
        user_id (str): The ID of the user to whom the graph (execution) belongs.

    Returns:
        AgentGraphExecution | None: The graph execution if found, None otherwise.
    """
    execution = await AgentGraphExecution.prisma().find_first(
        where={"id": graph_exec_id, "userId": user_id},
        include=GRAPH_EXECUTION_INCLUDE,
    )
    return execution


async def list_executions(graph_id: str, graph_version: int | None = None) -> list[str]:
    where: AgentGraphExecutionWhereInput = {"agentGraphId": graph_id}
    if graph_version is not None:
        where["agentGraphVersion"] = graph_version
    executions = await AgentGraphExecution.prisma().find_many(where=where)
    return [execution.id for execution in executions]


async def get_execution_results(graph_exec_id: str) -> list[ExecutionResult]:
    executions = await AgentNodeExecution.prisma().find_many(
        where={"agentGraphExecutionId": graph_exec_id},

@@ -84,6 +84,8 @@ class NodeModel(Node):
            raise ValueError(f"Block #{self.block_id} not found for node #{self.id}")
        if not block.webhook_config:
            raise TypeError("This method can't be used on non-webhook blocks")
        if not block.webhook_config.event_filter_input:
            return True
        event_filter = self.input_default.get(block.webhook_config.event_filter_input)
        if not event_filter:
            raise ValueError(f"Event filter is not configured on node #{self.id}")
@@ -105,6 +107,8 @@ class GraphExecution(BaseDbModel):
    duration: float
    total_run_time: float
    status: ExecutionStatus
    graph_id: str
    graph_version: int

    @staticmethod
    def from_db(execution: AgentGraphExecution):
@@ -130,6 +134,8 @@ class GraphExecution(BaseDbModel):
            duration=duration,
            total_run_time=total_run_time,
            status=ExecutionStatus(execution.executionStatus),
            graph_id=execution.agentGraphId,
            graph_version=execution.agentGraphVersion,
        )


@@ -139,7 +145,6 @@ class Graph(BaseDbModel):
    is_template: bool = False
    name: str
    description: str
    executions: list[GraphExecution] = []
    nodes: list[Node] = []
    links: list[Link] = []

@@ -253,7 +258,7 @@ class GraphModel(Graph):
        for link in self.links:
            input_links[link.sink_id].append(link)

        # Nodes: required fields are filled or connected
        # Nodes: required fields are filled or connected and dependencies are satisfied
        for node in self.nodes:
            block = get_block(node.block_id)
            if block is None:
@@ -264,16 +269,55 @@ class GraphModel(Graph):
                + [sanitize(link.sink_name) for link in input_links.get(node.id, [])]
            )
            for name in block.input_schema.get_required_fields():
                if name not in provided_inputs and (
                    for_run  # Skip input completion validation, unless when executing.
                    or block.block_type == BlockType.INPUT
                    or block.block_type == BlockType.OUTPUT
                    or block.block_type == BlockType.AGENT
                if (
                    name not in provided_inputs
                    and not (
                        name == "payload"
                        and block.block_type
                        in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
                    )
                    and (
                        for_run  # Skip input completion validation, unless when executing.
                        or block.block_type == BlockType.INPUT
                        or block.block_type == BlockType.OUTPUT
                        or block.block_type == BlockType.AGENT
                    )
                ):
                    raise ValueError(
                        f"Node {block.name} #{node.id} required input missing: `{name}`"
                    )

            # Get input schema properties and check dependencies
            input_schema = block.input_schema.model_fields
            required_fields = block.input_schema.get_required_fields()

            def has_value(name):
                return (
                    node is not None
                    and name in node.input_default
                    and node.input_default[name] is not None
                    and str(node.input_default[name]).strip() != ""
                ) or (name in input_schema and input_schema[name].default is not None)

            # Validate dependencies between fields
            for field_name, field_info in input_schema.items():
                # Apply input dependency validation only on run & field with depends_on
                json_schema_extra = field_info.json_schema_extra or {}
                dependencies = json_schema_extra.get("depends_on", [])
                if not for_run or not dependencies:
                    continue

                # Check if dependent field has value in input_default
                field_has_value = has_value(field_name)
                field_is_required = field_name in required_fields

                # Check for missing dependencies when dependent field is present
                missing_deps = [dep for dep in dependencies if not has_value(dep)]
                if missing_deps and (field_has_value or field_is_required):
                    raise ValueError(
                        f"Node {block.name} #{node.id}: Field `{field_name}` requires [{', '.join(missing_deps)}] to be set"
                    )
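
The dependency validation above is driven by the new `depends_on` argument to
`SchemaField`, added later in this diff. Declaring the relationship on a block
input could look like this (the fields are hypothetical):

from backend.data.block import BlockSchema
from backend.data.model import SchemaField

class ExampleInput(BlockSchema):
    sort_by: str = SchemaField(default="", description="Column to sort by")
    # `sort_order` is only meaningful once `sort_by` is set; the graph
    # validation above rejects a run where the dependency is missing.
    sort_order: str = SchemaField(
        default="asc",
        description="Sort direction",
        depends_on=["sort_by"],
    )
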
        node_map = {v.id: v for v in self.nodes}

        def is_static_output_block(nid: str) -> bool:
@@ -323,12 +367,7 @@ class GraphModel(Graph):
                link.is_static = True  # Each value block output should be static.

    @staticmethod
    def from_db(graph: AgentGraph, hide_credentials: bool = False):
        executions = [
            GraphExecution.from_db(execution)
            for execution in graph.AgentGraphExecution or []
        ]

    def from_db(graph: AgentGraph, for_export: bool = False):
        return GraphModel(
            id=graph.id,
            user_id=graph.userId,
@@ -337,9 +376,8 @@ class GraphModel(Graph):
            is_template=graph.isTemplate,
            name=graph.name or "",
            description=graph.description or "",
            executions=executions,
            nodes=[
                GraphModel._process_node(node, hide_credentials)
                NodeModel.from_db(GraphModel._process_node(node, for_export))
                for node in graph.AgentNodes or []
            ],
            links=list(
@@ -352,23 +390,29 @@ class GraphModel(Graph):
        )

    @staticmethod
    def _process_node(node: AgentNode, hide_credentials: bool) -> NodeModel:
        node_dict = {field: getattr(node, field) for field in node.model_fields}
        if hide_credentials and "constantInput" in node_dict:
            constant_input = json.loads(
                node_dict["constantInput"], target_type=dict[str, Any]
            )
            constant_input = GraphModel._hide_credentials_in_input(constant_input)
            node_dict["constantInput"] = json.dumps(constant_input)
        return NodeModel.from_db(AgentNode(**node_dict))
    def _process_node(node: AgentNode, for_export: bool) -> AgentNode:
        if for_export:
            # Remove credentials from node input
            if node.constantInput:
                constant_input = json.loads(
                    node.constantInput, target_type=dict[str, Any]
                )
                constant_input = GraphModel._hide_node_input_credentials(constant_input)
                node.constantInput = json.dumps(constant_input)

            # Remove webhook info
            node.webhookId = None
            node.Webhook = None

        return node

    @staticmethod
    def _hide_credentials_in_input(input_data: dict[str, Any]) -> dict[str, Any]:
    def _hide_node_input_credentials(input_data: dict[str, Any]) -> dict[str, Any]:
        sensitive_keys = ["credentials", "api_key", "password", "token", "secret"]
        result = {}
        for key, value in input_data.items():
            if isinstance(value, dict):
                result[key] = GraphModel._hide_credentials_in_input(value)
                result[key] = GraphModel._hide_node_input_credentials(value)
            elif isinstance(value, str) and any(
                sensitive_key in key.lower() for sensitive_key in sensitive_keys
            ):
@@ -407,7 +451,6 @@ async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:

async def get_graphs(
    user_id: str,
    include_executions: bool = False,
    filter_by: Literal["active", "template"] | None = "active",
) -> list[GraphModel]:
    """
@@ -415,33 +458,50 @@ async def get_graphs(
    Default behaviour is to get all currently active graphs.

    Args:
        include_executions: Whether to include executions in the graph metadata.
        filter_by: An optional filter to either select templates or active graphs.
        user_id: The ID of the user that owns the graph.

    Returns:
        list[GraphModel]: A list of objects representing the retrieved graphs.
    """
    where_clause: AgentGraphWhereInput = {}
    where_clause: AgentGraphWhereInput = {"userId": user_id}

    if filter_by == "active":
        where_clause["isActive"] = True
    elif filter_by == "template":
        where_clause["isTemplate"] = True

    where_clause["userId"] = user_id

    graph_include = AGENT_GRAPH_INCLUDE
    graph_include["AgentGraphExecution"] = include_executions

    graphs = await AgentGraph.prisma().find_many(
        where=where_clause,
        distinct=["id"],
        order={"version": "desc"},
        include=graph_include,
        include=AGENT_GRAPH_INCLUDE,
    )

    return [GraphModel.from_db(graph) for graph in graphs]
    graph_models = []
    for graph in graphs:
        try:
            graph_models.append(GraphModel.from_db(graph))
        except Exception as e:
            logger.error(f"Error processing graph {graph.id}: {e}")
            continue

    return graph_models


async def get_executions(user_id: str) -> list[GraphExecution]:
    executions = await AgentGraphExecution.prisma().find_many(
        where={"userId": user_id},
        order={"createdAt": "desc"},
    )
    return [GraphExecution.from_db(execution) for execution in executions]


async def get_execution(user_id: str, execution_id: str) -> GraphExecution | None:
    execution = await AgentGraphExecution.prisma().find_first(
        where={"id": execution_id, "userId": user_id}
    )
    return GraphExecution.from_db(execution) if execution else None


async def get_graph(
@@ -449,7 +509,7 @@ async def get_graph(
    version: int | None = None,
    template: bool = False,
    user_id: str | None = None,
    hide_credentials: bool = False,
    for_export: bool = False,
) -> GraphModel | None:
    """
    Retrieves a graph from the DB.
@@ -475,7 +535,7 @@ async def get_graph(
        include=AGENT_GRAPH_INCLUDE,
        order={"version": "desc"},
    )
    return GraphModel.from_db(graph, hide_credentials) if graph else None
    return GraphModel.from_db(graph, for_export) if graph else None


async def set_graph_active_version(graph_id: str, version: int, user_id: str) -> None:

@@ -3,10 +3,12 @@ from typing import TYPE_CHECKING, AsyncGenerator, Optional

from prisma import Json
from prisma.models import IntegrationWebhook
from pydantic import Field
from pydantic import Field, computed_field

from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
from backend.data.queue import AsyncRedisEventBus
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.utils import webhook_ingress_url

from .db import BaseDbModel

@@ -18,7 +20,7 @@ logger = logging.getLogger(__name__)

class Webhook(BaseDbModel):
    user_id: str
    provider: str
    provider: ProviderName
    credentials_id: str
    webhook_type: str
    resource: str
@@ -30,6 +32,11 @@ class Webhook(BaseDbModel):

    attached_nodes: Optional[list["NodeModel"]] = None

    @computed_field
    @property
    def url(self) -> str:
        return webhook_ingress_url(self.provider.value, self.id)

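
Because `url` is a pydantic `computed_field`, it is derived on access yet still
included when the model is serialized, so API responses expose each webhook's
ingress URL without persisting it. A minimal illustration of the pattern:

from pydantic import BaseModel, computed_field

class Resource(BaseModel):
    id: str

    @computed_field
    @property
    def url(self) -> str:
        # Computed on the fly, but present in model_dump() and JSON output.
        return f"https://example.com/resources/{self.id}"

assert Resource(id="42").model_dump()["url"] == "https://example.com/resources/42"
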
    @staticmethod
    def from_db(webhook: IntegrationWebhook):
        from .graph import NodeModel
@@ -37,7 +44,7 @@ class Webhook(BaseDbModel):
        return Webhook(
            id=webhook.id,
            user_id=webhook.userId,
            provider=webhook.provider,
            provider=ProviderName(webhook.provider),
            credentials_id=webhook.credentialsId,
            webhook_type=webhook.webhookType,
            resource=webhook.resource,
@@ -61,7 +68,7 @@ async def create_webhook(webhook: Webhook) -> Webhook:
        data={
            "id": webhook.id,
            "userId": webhook.user_id,
            "provider": webhook.provider,
            "provider": webhook.provider.value,
            "credentialsId": webhook.credentials_id,
            "webhookType": webhook.webhook_type,
            "resource": webhook.resource,
@@ -83,8 +90,10 @@ async def get_webhook(webhook_id: str) -> Webhook:
    return Webhook.from_db(webhook)


async def get_all_webhooks(credentials_id: str) -> list[Webhook]:
async def get_all_webhooks_by_creds(credentials_id: str) -> list[Webhook]:
    """⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
    if not credentials_id:
        raise ValueError("credentials_id must not be empty")
    webhooks = await IntegrationWebhook.prisma().find_many(
        where={"credentialsId": credentials_id},
        include=INTEGRATION_WEBHOOK_INCLUDE,
@@ -92,7 +101,7 @@ async def get_all_webhooks(credentials_id: str) -> list[Webhook]:
    return [Webhook.from_db(webhook) for webhook in webhooks]


async def find_webhook(
async def find_webhook_by_credentials_and_props(
    credentials_id: str, webhook_type: str, resource: str, events: list[str]
) -> Webhook | None:
    """⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
@@ -108,6 +117,22 @@ async def find_webhook(
    return Webhook.from_db(webhook) if webhook else None


async def find_webhook_by_graph_and_props(
    graph_id: str, provider: str, webhook_type: str, events: list[str]
) -> Webhook | None:
    """⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
    webhook = await IntegrationWebhook.prisma().find_first(
        where={
            "provider": provider,
            "webhookType": webhook_type,
            "events": {"has_every": events},
            "AgentNodes": {"some": {"agentGraphId": graph_id}},
        },
        include=INTEGRATION_WEBHOOK_INCLUDE,
    )
    return Webhook.from_db(webhook) if webhook else None


async def update_webhook_config(webhook_id: str, updated_config: dict) -> Webhook:
    """⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
    _updated_webhook = await IntegrationWebhook.prisma().update(
@@ -144,25 +169,28 @@ class WebhookEventBus(AsyncRedisEventBus[WebhookEvent]):
    def event_bus_name(self) -> str:
        return "webhooks"

    async def publish(self, event: WebhookEvent):
        await self.publish_event(event, f"{event.webhook_id}/{event.event_type}")

    async def listen(
        self, webhook_id: str, event_type: Optional[str] = None
    ) -> AsyncGenerator[WebhookEvent, None]:
        async for event in self.listen_events(f"{webhook_id}/{event_type or '*'}"):
            yield event


event_bus = WebhookEventBus()
_webhook_event_bus = WebhookEventBus()


async def publish_webhook_event(event: WebhookEvent):
    await event_bus.publish(event)
    await _webhook_event_bus.publish_event(
        event, f"{event.webhook_id}/{event.event_type}"
    )


async def listen_for_webhook_event(
async def listen_for_webhook_events(
    webhook_id: str, event_type: Optional[str] = None
) -> AsyncGenerator[WebhookEvent, None]:
    async for event in _webhook_event_bus.listen_events(
        f"{webhook_id}/{event_type or '*'}"
    ):
        yield event


async def wait_for_webhook_event(
    webhook_id: str, event_type: Optional[str] = None, timeout: Optional[float] = None
) -> WebhookEvent | None:
    async for event in event_bus.listen(webhook_id, event_type):
        return event  # Only one event is expected
    return await _webhook_event_bus.wait_for_event(
        f"{webhook_id}/{event_type or '*'}", timeout
    )

@@ -2,6 +2,7 @@ from __future__ import annotations

import logging
from typing import (
    TYPE_CHECKING,
    Annotated,
    Any,
    Callable,
@@ -11,19 +12,32 @@ from typing import (
    Optional,
    TypedDict,
    TypeVar,
    get_args,
)
from uuid import uuid4

from pydantic import BaseModel, Field, GetCoreSchemaHandler, SecretStr, field_serializer
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    GetCoreSchemaHandler,
    SecretStr,
    field_serializer,
)
from pydantic_core import (
    CoreSchema,
    PydanticUndefined,
    PydanticUndefinedType,
    ValidationError,
    core_schema,
)

from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets

if TYPE_CHECKING:
    from backend.data.block import BlockSchema

T = TypeVar("T")
logger = logging.getLogger(__name__)

@@ -124,6 +138,7 @@ def SchemaField(
    secret: bool = False,
    exclude: bool = False,
    hidden: Optional[bool] = None,
    depends_on: list[str] | None = None,
    **kwargs,
) -> T:
    json_extra = {
@@ -133,6 +148,7 @@ def SchemaField(
            "secret": secret,
            "advanced": advanced,
            "hidden": hidden,
            "depends_on": depends_on,
        }.items()
        if v is not None
    }
@@ -146,7 +162,7 @@ def SchemaField(
        exclude=exclude,
        json_schema_extra=json_extra,
        **kwargs,
    )
    )  # type: ignore


class _BaseCredentials(BaseModel):
@@ -220,7 +236,7 @@ class UserIntegrations(BaseModel):
    oauth_states: list[OAuthState] = Field(default_factory=list)


CP = TypeVar("CP", bound=str)
CP = TypeVar("CP", bound=ProviderName)
CT = TypeVar("CT", bound=CredentialsType)


@@ -233,19 +249,51 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
    provider: CP
    type: CT

    @staticmethod
    def _add_json_schema_extra(schema, cls: CredentialsMetaInput):
        schema["credentials_provider"] = get_args(
            cls.model_fields["provider"].annotation
        )
        schema["credentials_types"] = get_args(cls.model_fields["type"].annotation)

class CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]):
    model_config = ConfigDict(
        json_schema_extra=_add_json_schema_extra,  # type: ignore
    )

    @classmethod
    def validate_credentials_field_schema(cls, model: type["BlockSchema"]):
        """Validates the schema of a `credentials` field"""
        field_schema = model.jsonschema()["properties"][CREDENTIALS_FIELD_NAME]
        try:
            schema_extra = _CredentialsFieldSchemaExtra[CP, CT].model_validate(
                field_schema
            )
        except ValidationError as e:
            if "Field required [type=missing" not in str(e):
                raise

            raise TypeError(
                "Field 'credentials' JSON schema lacks required extra items: "
                f"{field_schema}"
            ) from e

        if (
            len(schema_extra.credentials_provider) > 1
            and not schema_extra.discriminator
        ):
            raise TypeError("Multi-provider CredentialsField requires discriminator!")


class _CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]):
    # TODO: move discrimination mechanism out of CredentialsField (frontend + backend)
    credentials_provider: list[CP]
    credentials_scopes: Optional[list[str]]
    credentials_scopes: Optional[list[str]] = None
    credentials_types: list[CT]
    discriminator: Optional[str] = None
    discriminator_mapping: Optional[dict[str, CP]] = None
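
Taken together, `_add_json_schema_extra` and the schema-extra model above mean
that a block's `credentials` property advertises its provider and type
constraints inside its JSON schema. For a single-provider field the relevant
extra items would look roughly like this (values assumed for illustration):

# Rough shape of the extra items on a `credentials` field's JSON schema.
credentials_field_schema_extra = {
    "credentials_provider": ["jina"],
    "credentials_types": ["api_key"],
    # "credentials_scopes", "discriminator" and "discriminator_mapping" only
    # appear when set; multi-provider fields must define a discriminator.
}
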
def CredentialsField(
    provider: CP | list[CP],
    supported_credential_types: set[CT],
    required_scopes: set[str] = set(),
    *,
    discriminator: Optional[str] = None,
@@ -253,26 +301,26 @@ def CredentialsField(
    title: Optional[str] = None,
    description: Optional[str] = None,
    **kwargs,
) -> CredentialsMetaInput[CP, CT]:
) -> CredentialsMetaInput:
    """
    `CredentialsField` must and can only be used on fields named `credentials`.
    This is enforced by the `BlockSchema` base class.
    """
    if not isinstance(provider, str) and len(provider) > 1 and not discriminator:
        raise TypeError("Multi-provider CredentialsField requires discriminator!")

    field_schema_extra = CredentialsFieldSchemaExtra[CP, CT](
        credentials_provider=[provider] if isinstance(provider, str) else provider,
        credentials_scopes=list(required_scopes) or None,  # omit if empty
        credentials_types=list(supported_credential_types),
        discriminator=discriminator,
        discriminator_mapping=discriminator_mapping,
    )
    field_schema_extra = {
        k: v
        for k, v in {
            "credentials_scopes": list(required_scopes) or None,
            "discriminator": discriminator,
            "discriminator_mapping": discriminator_mapping,
        }.items()
        if v is not None
    }

    return Field(
        title=title,
        description=description,
        json_schema_extra=field_schema_extra.model_dump(exclude_none=True),
        json_schema_extra=field_schema_extra,  # validated on BlockSchema init
        **kwargs,
    )


@@ -1,8 +1,9 @@
import asyncio
import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, AsyncGenerator, Generator, Generic, TypeVar
from typing import Any, AsyncGenerator, Generator, Generic, Optional, TypeVar

from pydantic import BaseModel
from redis.asyncio.client import PubSub as AsyncPubSub
@@ -48,12 +49,12 @@ class BaseRedisEventBus(Generic[M], ABC):
        except Exception as e:
            logger.error(f"Failed to parse event result from Redis {msg} {e}")

    def _subscribe(
    def _get_pubsub_channel(
        self, connection: redis.Redis | redis.AsyncRedis, channel_key: str
    ) -> tuple[PubSub | AsyncPubSub, str]:
        channel_name = f"{self.event_bus_name}/{channel_key}"
        full_channel_name = f"{self.event_bus_name}/{channel_key}"
        pubsub = connection.pubsub()
        return pubsub, channel_name
        return pubsub, full_channel_name


class RedisEventBus(BaseRedisEventBus[M], ABC):
@@ -64,17 +65,19 @@ class RedisEventBus(BaseRedisEventBus[M], ABC):
        return redis.get_redis()

    def publish_event(self, event: M, channel_key: str):
        message, channel_name = self._serialize_message(event, channel_key)
        self.connection.publish(channel_name, message)
        message, full_channel_name = self._serialize_message(event, channel_key)
        self.connection.publish(full_channel_name, message)

    def listen_events(self, channel_key: str) -> Generator[M, None, None]:
        pubsub, channel_name = self._subscribe(self.connection, channel_key)
        pubsub, full_channel_name = self._get_pubsub_channel(
            self.connection, channel_key
        )
        assert isinstance(pubsub, PubSub)

        if "*" in channel_key:
            pubsub.psubscribe(channel_name)
            pubsub.psubscribe(full_channel_name)
        else:
            pubsub.subscribe(channel_name)
            pubsub.subscribe(full_channel_name)

        for message in pubsub.listen():
            if event := self._deserialize_message(message, channel_key):
@@ -89,19 +92,31 @@ class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
        return await redis.get_redis_async()

    async def publish_event(self, event: M, channel_key: str):
        message, channel_name = self._serialize_message(event, channel_key)
        message, full_channel_name = self._serialize_message(event, channel_key)
        connection = await self.connection
        await connection.publish(channel_name, message)
        await connection.publish(full_channel_name, message)

    async def listen_events(self, channel_key: str) -> AsyncGenerator[M, None]:
        pubsub, channel_name = self._subscribe(await self.connection, channel_key)
        pubsub, full_channel_name = self._get_pubsub_channel(
            await self.connection, channel_key
        )
        assert isinstance(pubsub, AsyncPubSub)

        if "*" in channel_key:
            await pubsub.psubscribe(channel_name)
            await pubsub.psubscribe(full_channel_name)
        else:
            await pubsub.subscribe(channel_name)
            await pubsub.subscribe(full_channel_name)

        async for message in pubsub.listen():
            if event := self._deserialize_message(message, channel_key):
                yield event

    async def wait_for_event(
        self, channel_key: str, timeout: Optional[float] = None
    ) -> M | None:
        try:
            return await asyncio.wait_for(
                anext(aiter(self.listen_events(channel_key))), timeout
            )
        except TimeoutError:
            return None
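
`wait_for_event` wraps the listener in `asyncio.wait_for` and converts a
timeout into `None` instead of an exception. Through the webhook helpers
above, a caller might use it like this (module path and values assumed):

from backend.data.integrations import wait_for_webhook_event

async def await_ping(webhook_id: str):
    # Block until the first matching event, or give up after 30 seconds.
    event = await wait_for_webhook_event(webhook_id, event_type="ping", timeout=30)
    if event is None:
        print("No ping received within 30 seconds")
    return event
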
@@ -25,8 +25,8 @@ from backend.data.execution (
ExecutionQueue,
ExecutionResult,
ExecutionStatus,
GraphExecution,
NodeExecution,
GraphExecutionEntry,
NodeExecutionEntry,
merge_execution_input,
parse_execution_output,
)
@@ -96,13 +96,13 @@ class LogMetadata:


T = TypeVar("T")
ExecutionStream = Generator[NodeExecution, None, None]
ExecutionStream = Generator[NodeExecutionEntry, None, None]


def execute_node(
db_client: "DatabaseManager",
creds_manager: IntegrationCredentialsManager,
data: NodeExecution,
data: NodeExecutionEntry,
execution_stats: dict[str, Any] | None = None,
) -> ExecutionStream:
"""
@@ -252,15 +252,15 @@ def _enqueue_next_nodes(
graph_exec_id: str,
graph_id: str,
log_metadata: LogMetadata,
) -> list[NodeExecution]:
) -> list[NodeExecutionEntry]:
def add_enqueued_execution(
node_exec_id: str, node_id: str, data: BlockInput
) -> NodeExecution:
) -> NodeExecutionEntry:
exec_update = db_client.update_execution_status(
node_exec_id, ExecutionStatus.QUEUED, data
)
db_client.send_execution_update(exec_update)
return NodeExecution(
return NodeExecutionEntry(
user_id=user_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
@@ -269,7 +269,7 @@ def _enqueue_next_nodes(
data=data,
)

def register_next_executions(node_link: Link) -> list[NodeExecution]:
def register_next_executions(node_link: Link) -> list[NodeExecutionEntry]:
enqueued_executions = []
next_output_name = node_link.source_name
next_input_name = node_link.sink_name
@@ -501,8 +501,8 @@ class Executor:
@error_logged
def on_node_execution(
cls,
q: ExecutionQueue[NodeExecution],
node_exec: NodeExecution,
q: ExecutionQueue[NodeExecutionEntry],
node_exec: NodeExecutionEntry,
) -> dict[str, Any]:
log_metadata = LogMetadata(
user_id=node_exec.user_id,
@@ -529,8 +529,8 @@ class Executor:
@time_measured
def _on_node_execution(
cls,
q: ExecutionQueue[NodeExecution],
node_exec: NodeExecution,
q: ExecutionQueue[NodeExecutionEntry],
node_exec: NodeExecutionEntry,
log_metadata: LogMetadata,
stats: dict[str, Any] | None = None,
):
@@ -580,7 +580,9 @@ class Executor:

@classmethod
@error_logged
def on_graph_execution(cls, graph_exec: GraphExecution, cancel: threading.Event):
def on_graph_execution(
cls, graph_exec: GraphExecutionEntry, cancel: threading.Event
):
log_metadata = LogMetadata(
user_id=graph_exec.user_id,
graph_eid=graph_exec.graph_exec_id,
@@ -605,7 +607,7 @@ class Executor:
@time_measured
def _on_graph_execution(
cls,
graph_exec: GraphExecution,
graph_exec: GraphExecutionEntry,
cancel: threading.Event,
log_metadata: LogMetadata,
) -> tuple[dict[str, Any], Exception | None]:
@@ -636,13 +638,13 @@ class Executor:
cancel_thread.start()

try:
queue = ExecutionQueue[NodeExecution]()
queue = ExecutionQueue[NodeExecutionEntry]()
for node_exec in graph_exec.start_node_execs:
queue.add(node_exec)

running_executions: dict[str, AsyncResult] = {}

def make_exec_callback(exec_data: NodeExecution):
def make_exec_callback(exec_data: NodeExecutionEntry):
node_id = exec_data.node_id

def callback(result: object):
@@ -717,7 +719,7 @@ class ExecutionManager(AppService):
self.use_redis = True
self.use_supabase = True
self.pool_size = settings.config.num_graph_workers
self.queue = ExecutionQueue[GraphExecution]()
self.queue = ExecutionQueue[GraphExecutionEntry]()
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}

@classmethod
@@ -768,7 +770,7 @@ class ExecutionManager(AppService):
data: BlockInput,
user_id: str,
graph_version: int | None = None,
) -> GraphExecution:
) -> GraphExecutionEntry:
graph: GraphModel | None = self.db_client.get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
)
@@ -796,10 +798,13 @@ class ExecutionManager(AppService):
# Extract webhook payload, and assign it to the input pin
webhook_payload_key = f"webhook_{node.webhook_id}_payload"
if (
block.block_type == BlockType.WEBHOOK
block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
and node.webhook_id
and webhook_payload_key in data
):
if webhook_payload_key not in data:
raise ValueError(
f"Node {block.name} #{node.id} webhook payload is missing"
)
input_data = {"payload": data[webhook_payload_key]}

input_data, error = validate_exec(node, input_data)
@@ -818,7 +823,7 @@ class ExecutionManager(AppService):
starting_node_execs = []
for node_exec in node_execs:
starting_node_execs.append(
NodeExecution(
NodeExecutionEntry(
user_id=user_id,
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
@@ -832,7 +837,7 @@ class ExecutionManager(AppService):
)
self.db_client.send_execution_update(exec_update)

graph_exec = GraphExecution(
graph_exec = GraphExecutionEntry(
user_id=user_id,
graph_id=graph_id,
graph_exec_id=graph_exec_id,

@@ -1,6 +1,7 @@
import logging
from contextlib import contextmanager
from datetime import datetime
from typing import TYPE_CHECKING

from autogpt_libs.utils.synchronize import RedisKeyedMutex
from redis.lock import Lock as RedisLock
@@ -8,10 +9,13 @@ from redis.lock import Lock as RedisLock
from backend.data import redis
from backend.data.model import Credentials
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.util.exceptions import MissingConfigError
from backend.util.settings import Settings

if TYPE_CHECKING:
from backend.integrations.oauth import BaseOAuthHandler

logger = logging.getLogger(__name__)
settings = Settings()

@@ -148,7 +152,7 @@ class IntegrationCredentialsManager:
self.store.locks.release_all_locks()


def _get_provider_oauth_handler(provider_name: str) -> BaseOAuthHandler:
def _get_provider_oauth_handler(provider_name: str) -> "BaseOAuthHandler":
if provider_name not in HANDLERS_BY_NAME:
raise KeyError(f"Unknown provider '{provider_name}'")

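The creds_manager hunk above moves BaseOAuthHandler behind a TYPE_CHECKING guard and quotes the return annotation. A condensed, self-contained sketch of that pattern (assuming the goal is keeping an annotation-only import out of the runtime import graph, which is what the guard implies; pick_handler is a hypothetical name):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to type checkers only; never executed at runtime, so it
    # cannot contribute to an import cycle.
    from backend.integrations.oauth import BaseOAuthHandler


def pick_handler(handlers: dict) -> "BaseOAuthHandler":
    # The quoted annotation is not evaluated when the module loads,
    # so the guarded import above is all the type checker needs.
    return next(iter(handlers.values()))()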
@@ -1,10 +1,15 @@
from .base import BaseOAuthHandler
from typing import TYPE_CHECKING

from .github import GitHubOAuthHandler
from .google import GoogleOAuthHandler
from .notion import NotionOAuthHandler

if TYPE_CHECKING:
from ..providers import ProviderName
from .base import BaseOAuthHandler

# --8<-- [start:HANDLERS_BY_NAMEExample]
HANDLERS_BY_NAME: dict[str, type[BaseOAuthHandler]] = {
HANDLERS_BY_NAME: dict["ProviderName", type["BaseOAuthHandler"]] = {
handler.PROVIDER_NAME: handler
for handler in [
GitHubOAuthHandler,

@@ -4,13 +4,14 @@ from abc import ABC, abstractmethod
from typing import ClassVar

from backend.data.model import OAuth2Credentials
from backend.integrations.providers import ProviderName

logger = logging.getLogger(__name__)


class BaseOAuthHandler(ABC):
# --8<-- [start:BaseOAuthHandler1]
PROVIDER_NAME: ClassVar[str]
PROVIDER_NAME: ClassVar[ProviderName]
DEFAULT_SCOPES: ClassVar[list[str]] = []
# --8<-- [end:BaseOAuthHandler1]

@@ -76,6 +77,8 @@ class BaseOAuthHandler(ABC):
"""Handles the default scopes for the provider"""
# If scopes are empty, use the default scopes for the provider
if not scopes:
logger.debug(f"Using default scopes for provider {self.PROVIDER_NAME}")
logger.debug(
f"Using default scopes for provider {self.PROVIDER_NAME.value}"
)
scopes = self.DEFAULT_SCOPES
return scopes

@@ -3,6 +3,7 @@ from typing import Optional
from urllib.parse import urlencode

from backend.data.model import OAuth2Credentials
from backend.integrations.providers import ProviderName
from backend.util.request import requests

from .base import BaseOAuthHandler
@@ -23,7 +24,7 @@ class GitHubOAuthHandler(BaseOAuthHandler):
access token *with no refresh token*.
""" # noqa

PROVIDER_NAME = "github"
PROVIDER_NAME = ProviderName.GITHUB

def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id

@@ -9,6 +9,7 @@ from google_auth_oauthlib.flow import Flow
from pydantic import SecretStr

from backend.data.model import OAuth2Credentials
from backend.integrations.providers import ProviderName

from .base import BaseOAuthHandler

@@ -21,7 +22,7 @@ class GoogleOAuthHandler(BaseOAuthHandler):
Based on the documentation at https://developers.google.com/identity/protocols/oauth2/web-server
""" # noqa

PROVIDER_NAME = "google"
PROVIDER_NAME = ProviderName.GOOGLE
EMAIL_ENDPOINT = "https://www.googleapis.com/oauth2/v2/userinfo"
DEFAULT_SCOPES = [
"https://www.googleapis.com/auth/userinfo.email",

@@ -2,6 +2,7 @@ from base64 import b64encode
from urllib.parse import urlencode

from backend.data.model import OAuth2Credentials
from backend.integrations.providers import ProviderName
from backend.util.request import requests

from .base import BaseOAuthHandler
@@ -16,7 +17,7 @@ class NotionOAuthHandler(BaseOAuthHandler):
- Notion doesn't use scopes
"""

PROVIDER_NAME = "notion"
PROVIDER_NAME = ProviderName.NOTION

def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id

@@ -1,7 +1,31 @@
from enum import Enum


# --8<-- [start:ProviderName]
class ProviderName(str, Enum):
ANTHROPIC = "anthropic"
COMPASS = "compass"
DISCORD = "discord"
D_ID = "d_id"
E2B = "e2b"
EXA = "exa"
FAL = "fal"
GITHUB = "github"
GOOGLE = "google"
GOOGLE_MAPS = "google_maps"
GROQ = "groq"
HUBSPOT = "hubspot"
IDEOGRAM = "ideogram"
JINA = "jina"
MEDIUM = "medium"
NOTION = "notion"
OLLAMA = "ollama"
OPENAI = "openai"
OPENWEATHERMAP = "openweathermap"
OPEN_ROUTER = "open_router"
PINECONE = "pinecone"
REPLICATE = "replicate"
REVID = "revid"
SLANT3D = "slant3d"
UNREAL_SPEECH = "unreal_speech"
# --8<-- [end:ProviderName]
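Because ProviderName mixes in str, members compare equal to their raw values, but bare f-string interpolation is version-dependent, which is presumably why call sites in this commit switch to explicit .value. A quick self-contained illustration:

from enum import Enum


class ProviderName(str, Enum):
    GITHUB = "github"


p = ProviderName.GITHUB
assert p == "github"                # str comparison works
assert ProviderName("github") is p  # round-trips from the raw value
assert f"{p.value}" == "github"     # unambiguous on every Python version
# On Python 3.11+, f"{p}" renders as "ProviderName.GITHUB" rather than
# "github", so the explicit .value form is the safe choice for logs and URLs.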
@@ -1,15 +1,18 @@
from typing import TYPE_CHECKING

from .compass import CompassWebhookManager
from .github import GithubWebhooksManager
from .slant3d import Slant3DWebhooksManager

if TYPE_CHECKING:
from .base import BaseWebhooksManager
from ..providers import ProviderName
from ._base import BaseWebhooksManager

# --8<-- [start:WEBHOOK_MANAGERS_BY_NAME]
WEBHOOK_MANAGERS_BY_NAME: dict[str, type["BaseWebhooksManager"]] = {
WEBHOOK_MANAGERS_BY_NAME: dict["ProviderName", type["BaseWebhooksManager"]] = {
handler.PROVIDER_NAME: handler
for handler in [
CompassWebhookManager,
GithubWebhooksManager,
Slant3DWebhooksManager,
]

@@ -1,7 +1,7 @@
import logging
import secrets
from abc import ABC, abstractmethod
from typing import ClassVar, Generic, TypeVar
from typing import ClassVar, Generic, Optional, TypeVar
from uuid import uuid4

from fastapi import Request
@@ -9,6 +9,8 @@ from strenum import StrEnum

from backend.data import integrations
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.utils import webhook_ingress_url
from backend.util.exceptions import MissingConfigError
from backend.util.settings import Config

@@ -20,12 +22,12 @@ WT = TypeVar("WT", bound=StrEnum)

class BaseWebhooksManager(ABC, Generic[WT]):
# --8<-- [start:BaseWebhooksManager1]
PROVIDER_NAME: ClassVar[str]
PROVIDER_NAME: ClassVar[ProviderName]
# --8<-- [end:BaseWebhooksManager1]

WebhookType: WT

async def get_suitable_webhook(
async def get_suitable_auto_webhook(
self,
user_id: str,
credentials: Credentials,
@@ -38,16 +40,34 @@ class BaseWebhooksManager(ABC, Generic[WT]):
"PLATFORM_BASE_URL must be set to use Webhook functionality"
)

if webhook := await integrations.find_webhook(
if webhook := await integrations.find_webhook_by_credentials_and_props(
credentials.id, webhook_type, resource, events
):
return webhook
return await self._create_webhook(
user_id, credentials, webhook_type, resource, events
user_id, webhook_type, events, resource, credentials
)

async def get_manual_webhook(
self,
user_id: str,
graph_id: str,
webhook_type: WT,
events: list[str],
):
if current_webhook := await integrations.find_webhook_by_graph_and_props(
graph_id, self.PROVIDER_NAME, webhook_type, events
):
return current_webhook
return await self._create_webhook(
user_id,
webhook_type,
events,
register=False,
)

async def prune_webhook_if_dangling(
self, webhook_id: str, credentials: Credentials
self, webhook_id: str, credentials: Optional[Credentials]
) -> bool:
webhook = await integrations.get_webhook(webhook_id)
if webhook.attached_nodes is None:
@@ -56,7 +76,8 @@ class BaseWebhooksManager(ABC, Generic[WT]):
# Don't prune webhook if in use
return False

await self._deregister_webhook(webhook, credentials)
if credentials:
await self._deregister_webhook(webhook, credentials)
await integrations.delete_webhook(webhook.id)
return True

@@ -81,7 +102,9 @@ class BaseWebhooksManager(ABC, Generic[WT]):
# --8<-- [end:BaseWebhooksManager3]

# --8<-- [start:BaseWebhooksManager5]
async def trigger_ping(self, webhook: integrations.Webhook) -> None:
async def trigger_ping(
self, webhook: integrations.Webhook, credentials: Credentials | None
) -> None:
"""
Triggers a ping to the given webhook.

@@ -132,27 +155,36 @@ class BaseWebhooksManager(ABC, Generic[WT]):
async def _create_webhook(
self,
user_id: str,
credentials: Credentials,
webhook_type: WT,
resource: str,
events: list[str],
resource: str = "",
credentials: Optional[Credentials] = None,
register: bool = True,
) -> integrations.Webhook:
if not app_config.platform_base_url:
raise MissingConfigError(
"PLATFORM_BASE_URL must be set to use Webhook functionality"
)

id = str(uuid4())
secret = secrets.token_hex(32)
provider_name = self.PROVIDER_NAME
ingress_url = (
f"{app_config.platform_base_url}/api/integrations/{provider_name}"
f"/webhooks/{id}/ingress"
)
provider_webhook_id, config = await self._register_webhook(
credentials, webhook_type, resource, events, ingress_url, secret
)
ingress_url = webhook_ingress_url(provider_name=provider_name, webhook_id=id)
if register:
if not credentials:
raise TypeError("credentials are required if register = True")
provider_webhook_id, config = await self._register_webhook(
credentials, webhook_type, resource, events, ingress_url, secret
)
else:
provider_webhook_id, config = "", {}

return await integrations.create_webhook(
integrations.Webhook(
id=id,
user_id=user_id,
provider=provider_name,
credentials_id=credentials.id,
credentials_id=credentials.id if credentials else "",
webhook_type=webhook_type,
resource=resource,
events=events,

@@ -0,0 +1,30 @@
import logging

from backend.data import integrations
from backend.data.model import APIKeyCredentials, Credentials, OAuth2Credentials

from ._base import WT, BaseWebhooksManager

logger = logging.getLogger(__name__)


class ManualWebhookManagerBase(BaseWebhooksManager[WT]):
async def _register_webhook(
self,
credentials: Credentials,
webhook_type: WT,
resource: str,
events: list[str],
ingress_url: str,
secret: str,
) -> tuple[str, dict]:
print(ingress_url) # FIXME: pass URL to user in front end

return "", {}

async def _deregister_webhook(
self,
webhook: integrations.Webhook,
credentials: OAuth2Credentials | APIKeyCredentials,
) -> None:
pass

@@ -0,0 +1,30 @@
import logging

from fastapi import Request
from strenum import StrEnum

from backend.data import integrations
from backend.integrations.providers import ProviderName

from ._manual_base import ManualWebhookManagerBase

logger = logging.getLogger(__name__)


class CompassWebhookType(StrEnum):
TRANSCRIPTION = "transcription"
TASK = "task"


class CompassWebhookManager(ManualWebhookManagerBase):
PROVIDER_NAME = ProviderName.COMPASS
WebhookType = CompassWebhookType

@classmethod
async def validate_payload(
cls, webhook: integrations.Webhook, request: Request
) -> tuple[dict, str]:
payload = await request.json()
event_type = CompassWebhookType.TRANSCRIPTION # currently the only type

return payload, event_type

@@ -8,8 +8,9 @@ from strenum import StrEnum

from backend.data import integrations
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName

from .base import BaseWebhooksManager
from ._base import BaseWebhooksManager

logger = logging.getLogger(__name__)

@@ -20,7 +21,7 @@ class GithubWebhookType(StrEnum):


class GithubWebhooksManager(BaseWebhooksManager):
PROVIDER_NAME = "github"
PROVIDER_NAME = ProviderName.GITHUB

WebhookType = GithubWebhookType

@@ -58,10 +59,15 @@ class GithubWebhooksManager(BaseWebhooksManager):

return payload, event_type

async def trigger_ping(self, webhook: integrations.Webhook) -> None:
async def trigger_ping(
self, webhook: integrations.Webhook, credentials: Credentials | None
) -> None:
if not credentials:
raise ValueError("Credentials are required but were not passed")

headers = {
**self.GITHUB_API_DEFAULT_HEADERS,
"Authorization": f"Bearer {webhook.config.get('access_token')}",
"Authorization": credentials.bearer(),
}

repo, github_hook_id = webhook.resource, webhook.provider_webhook_id

@@ -1,7 +1,7 @@
import logging
from typing import TYPE_CHECKING, Callable, Optional, cast

from backend.data.block import get_block
from backend.data.block import BlockWebhookConfig, get_block
from backend.data.graph import set_node_webhook
from backend.data.model import CREDENTIALS_FIELD_NAME
from backend.integrations.webhooks import WEBHOOK_MANAGERS_BY_NAME
@@ -10,7 +10,7 @@ if TYPE_CHECKING:
from backend.data.graph import GraphModel, NodeModel
from backend.data.model import Credentials

from .base import BaseWebhooksManager
from ._base import BaseWebhooksManager

logger = logging.getLogger(__name__)

@@ -95,56 +95,92 @@ async def on_node_activate(
if not block.webhook_config:
return node

provider = block.webhook_config.provider
if provider not in WEBHOOK_MANAGERS_BY_NAME:
raise ValueError(
f"Block #{block.id} has webhook_config for provider {provider} "
"which does not support webhooks"
)

logger.debug(
f"Activating webhook node #{node.id} with config {block.webhook_config}"
)

webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[block.webhook_config.provider]()
webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()

try:
resource = block.webhook_config.resource_format.format(**node.input_default)
except KeyError:
resource = None
logger.debug(
f"Constructed resource string {resource} from input {node.input_default}"
if auto_setup_webhook := isinstance(block.webhook_config, BlockWebhookConfig):
try:
resource = block.webhook_config.resource_format.format(**node.input_default)
except KeyError:
resource = None
logger.debug(
f"Constructed resource string {resource} from input {node.input_default}"
)
else:
resource = "" # not relevant for manual webhooks

needs_credentials = CREDENTIALS_FIELD_NAME in block.input_schema.model_fields
credentials_meta = (
node.input_default.get(CREDENTIALS_FIELD_NAME) if needs_credentials else None
)

event_filter_input_name = block.webhook_config.event_filter_input
has_everything_for_webhook = (
resource is not None
and CREDENTIALS_FIELD_NAME in node.input_default
and event_filter_input_name in node.input_default
and any(is_on for is_on in node.input_default[event_filter_input_name].values())
and (credentials_meta or not needs_credentials)
and (
not event_filter_input_name
or (
event_filter_input_name in node.input_default
and any(
is_on
for is_on in node.input_default[event_filter_input_name].values()
)
)
)
)

if has_everything_for_webhook and resource:
if has_everything_for_webhook and resource is not None:
logger.debug(f"Node #{node} has everything for a webhook!")
if not credentials:
credentials_meta = node.input_default[CREDENTIALS_FIELD_NAME]
if credentials_meta and not credentials:
raise ValueError(
f"Cannot set up webhook for node #{node.id}: "
f"credentials #{credentials_meta['id']} not available"
)

# Shape of the event filter is enforced in Block.__init__
event_filter = cast(dict, node.input_default[event_filter_input_name])
events = [
block.webhook_config.event_format.format(event=event)
for event, enabled in event_filter.items()
if enabled is True
]
logger.debug(f"Webhook events to subscribe to: {', '.join(events)}")
if event_filter_input_name:
# Shape of the event filter is enforced in Block.__init__
event_filter = cast(dict, node.input_default[event_filter_input_name])
events = [
block.webhook_config.event_format.format(event=event)
for event, enabled in event_filter.items()
if enabled is True
]
logger.debug(f"Webhook events to subscribe to: {', '.join(events)}")
else:
events = []

# Find/make and attach a suitable webhook to the node
new_webhook = await webhooks_manager.get_suitable_webhook(
user_id,
credentials,
block.webhook_config.webhook_type,
resource,
events,
)
if auto_setup_webhook:
assert credentials is not None
new_webhook = await webhooks_manager.get_suitable_auto_webhook(
user_id,
credentials,
block.webhook_config.webhook_type,
resource,
events,
)
else:
# Manual webhook -> no credentials -> don't register but do create
new_webhook = await webhooks_manager.get_manual_webhook(
user_id,
node.graph_id,
block.webhook_config.webhook_type,
events,
)
logger.debug(f"Acquired webhook: {new_webhook}")
return await set_node_webhook(node.id, new_webhook.id)
else:
logger.debug(f"Node #{node.id} does not have everything for a webhook")

return node

@@ -167,7 +203,14 @@ async def on_node_deactivate(
if not block.webhook_config:
return node

webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[block.webhook_config.provider]()
provider = block.webhook_config.provider
if provider not in WEBHOOK_MANAGERS_BY_NAME:
raise ValueError(
f"Block #{block.id} has webhook_config for provider {provider} "
"which does not support webhooks"
)

webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()

if node.webhook_id:
logger.debug(f"Node #{node.id} has webhook_id {node.webhook_id}")
@@ -180,16 +223,20 @@ async def on_node_deactivate(
updated_node = await set_node_webhook(node.id, None)

# Prune and deregister the webhook if it is no longer used anywhere
logger.debug("Pruning and deregistering webhook if dangling")
webhook = node.webhook
if credentials:
logger.debug(f"Pruning webhook #{webhook.id} with credentials")
await webhooks_manager.prune_webhook_if_dangling(webhook.id, credentials)
else:
logger.debug(
f"Pruning{' and deregistering' if credentials else ''} "
f"webhook #{webhook.id}"
)
await webhooks_manager.prune_webhook_if_dangling(webhook.id, credentials)
if (
CREDENTIALS_FIELD_NAME in block.input_schema.model_fields
and not credentials
):
logger.warning(
f"Cannot deregister webhook #{webhook.id}: credentials "
f"#{webhook.credentials_id} not available "
f"({webhook.provider} webhook ID: {webhook.provider_webhook_id})"
f"({webhook.provider.value} webhook ID: {webhook.provider_webhook_id})"
)
return updated_node


@@ -1,12 +1,12 @@
import logging
from typing import ClassVar

import requests
from fastapi import Request

from backend.data import integrations
from backend.data.model import APIKeyCredentials, Credentials
from backend.integrations.webhooks.base import BaseWebhooksManager
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks._base import BaseWebhooksManager

logger = logging.getLogger(__name__)

@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
class Slant3DWebhooksManager(BaseWebhooksManager):
"""Manager for Slant3D webhooks"""

PROVIDER_NAME: ClassVar[str] = "slant3d"
PROVIDER_NAME = ProviderName.SLANT3D
BASE_URL = "https://www.slant3dapi.com/api"

async def _register_webhook(

@@ -0,0 +1,11 @@
from backend.util.settings import Config

app_config = Config()


# TODO: add test to assert this matches the actual API route
def webhook_ingress_url(provider_name: str, webhook_id: str) -> str:
return (
f"{app_config.platform_base_url}/api/integrations/{provider_name}"
f"/webhooks/{webhook_id}/ingress"
)
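For concreteness, the helper above yields URLs of this shape; the base URL below is an invented example value, not from this commit:

# Assuming app_config.platform_base_url == "https://app.example.com":
webhook_ingress_url(provider_name="github", webhook_id="1234-abcd")
# -> "https://app.example.com/api/integrations/github/webhooks/1234-abcd/ingress"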
@@ -1,5 +1,5 @@
import logging
from typing import Annotated, Literal
from typing import TYPE_CHECKING, Annotated, Literal

from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
from pydantic import BaseModel, Field, SecretStr
@@ -7,10 +7,10 @@ from pydantic import BaseModel, Field, SecretStr
from backend.data.graph import set_node_webhook
from backend.data.integrations import (
WebhookEvent,
get_all_webhooks,
get_all_webhooks_by_creds,
get_webhook,
listen_for_webhook_event,
publish_webhook_event,
wait_for_webhook_event,
)
from backend.data.model import (
APIKeyCredentials,
@@ -20,12 +20,16 @@ from backend.data.model import (
)
from backend.executor.manager import ExecutionManager
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import WEBHOOK_MANAGERS_BY_NAME
from backend.util.exceptions import NeedConfirmation
from backend.util.service import get_service_client
from backend.util.settings import Settings

if TYPE_CHECKING:
from backend.integrations.oauth import BaseOAuthHandler

from ..utils import get_user_id

logger = logging.getLogger(__name__)
@@ -42,7 +46,9 @@ class LoginResponse(BaseModel):

@router.get("/{provider}/login")
def login(
provider: Annotated[str, Path(title="The provider to initiate an OAuth flow for")],
provider: Annotated[
ProviderName, Path(title="The provider to initiate an OAuth flow for")
],
user_id: Annotated[str, Depends(get_user_id)],
request: Request,
scopes: Annotated[
@@ -74,7 +80,9 @@ class CredentialsMetaResponse(BaseModel):

@router.post("/{provider}/callback")
def callback(
provider: Annotated[str, Path(title="The target provider for this OAuth exchange")],
provider: Annotated[
ProviderName, Path(title="The target provider for this OAuth exchange")
],
code: Annotated[str, Body(title="Authorization code acquired by user login")],
state_token: Annotated[str, Body(title="Anti-CSRF nonce")],
user_id: Annotated[str, Depends(get_user_id)],
@@ -103,11 +111,12 @@ def callback(
if not set(scopes).issubset(set(credentials.scopes)):
# For now, we'll just log the warning and continue
logger.warning(
f"Granted scopes {credentials.scopes} for {provider}do not include all requested scopes {scopes}"
f"Granted scopes {credentials.scopes} for provider {provider.value} "
f"do not include all requested scopes {scopes}"
)

except Exception as e:
logger.error(f"Code->Token exchange failed for provider {provider}: {e}")
logger.error(f"Code->Token exchange failed for provider {provider.value}: {e}")
raise HTTPException(
status_code=400, detail=f"Failed to exchange code for tokens: {str(e)}"
)
@@ -116,7 +125,8 @@ def callback(
creds_manager.create(user_id, credentials)

logger.debug(
f"Successfully processed OAuth callback for user {user_id} and provider {provider}"
f"Successfully processed OAuth callback for user {user_id} "
f"and provider {provider.value}"
)
return CredentialsMetaResponse(
id=credentials.id,
@@ -148,7 +158,9 @@ def list_credentials(

@router.get("/{provider}/credentials")
def list_credentials_by_provider(
provider: Annotated[str, Path(title="The provider to list credentials for")],
provider: Annotated[
ProviderName, Path(title="The provider to list credentials for")
],
user_id: Annotated[str, Depends(get_user_id)],
) -> list[CredentialsMetaResponse]:
credentials = creds_manager.store.get_creds_by_provider(user_id, provider)
@@ -167,7 +179,9 @@ def list_credentials_by_provider(

@router.get("/{provider}/credentials/{cred_id}")
def get_credential(
provider: Annotated[str, Path(title="The provider to retrieve credentials for")],
provider: Annotated[
ProviderName, Path(title="The provider to retrieve credentials for")
],
cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
user_id: Annotated[str, Depends(get_user_id)],
) -> Credentials:
@@ -184,7 +198,9 @@ def get_credential(
@router.post("/{provider}/credentials", status_code=201)
def create_api_key_credentials(
user_id: Annotated[str, Depends(get_user_id)],
provider: Annotated[str, Path(title="The provider to create credentials for")],
provider: Annotated[
ProviderName, Path(title="The provider to create credentials for")
],
api_key: Annotated[str, Body(title="The API key to store")],
title: Annotated[str, Body(title="Optional title for the credentials")],
expires_at: Annotated[
@@ -225,7 +241,9 @@ class CredentialsDeletionNeedsConfirmationResponse(BaseModel):
@router.delete("/{provider}/credentials/{cred_id}")
async def delete_credentials(
request: Request,
provider: Annotated[str, Path(title="The provider to delete credentials for")],
provider: Annotated[
ProviderName, Path(title="The provider to delete credentials for")
],
cred_id: Annotated[str, Path(title="The ID of the credentials to delete")],
user_id: Annotated[str, Depends(get_user_id)],
force: Annotated[
@@ -264,15 +282,20 @@ async def delete_credentials(
@router.post("/{provider}/webhooks/{webhook_id}/ingress")
async def webhook_ingress_generic(
request: Request,
provider: Annotated[str, Path(title="Provider where the webhook was registered")],
provider: Annotated[
ProviderName, Path(title="Provider where the webhook was registered")
],
webhook_id: Annotated[str, Path(title="Our ID for the webhook")],
):
logger.debug(f"Received {provider} webhook ingress for ID {webhook_id}")
logger.debug(f"Received {provider.value} webhook ingress for ID {webhook_id}")
webhook_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
webhook = await get_webhook(webhook_id)
logger.debug(f"Webhook #{webhook_id}: {webhook}")
payload, event_type = await webhook_manager.validate_payload(webhook, request)
logger.debug(f"Validated {provider} {event_type} event with payload {payload}")
logger.debug(
f"Validated {provider.value} {webhook.webhook_type} {event_type} event "
f"with payload {payload}"
)

webhook_event = WebhookEvent(
provider=provider,
@@ -300,18 +323,28 @@ async def webhook_ingress_generic(
)


@router.post("/{provider}/webhooks/{webhook_id}/ping")
@router.post("/webhooks/{webhook_id}/ping")
async def webhook_ping(
provider: Annotated[str, Path(title="Provider where the webhook was registered")],
webhook_id: Annotated[str, Path(title="Our ID for the webhook")],
user_id: Annotated[str, Depends(get_user_id)], # require auth
):
webhook_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
webhook = await get_webhook(webhook_id)
webhook_manager = WEBHOOK_MANAGERS_BY_NAME[webhook.provider]()

await webhook_manager.trigger_ping(webhook)
if not await listen_for_webhook_event(webhook_id, event_type="ping"):
raise HTTPException(status_code=500, detail="Webhook ping event not received")
credentials = (
creds_manager.get(user_id, webhook.credentials_id)
if webhook.credentials_id
else None
)
try:
await webhook_manager.trigger_ping(webhook, credentials)
except NotImplementedError:
return False

if not await wait_for_webhook_event(webhook_id, event_type="ping", timeout=10):
raise HTTPException(status_code=504, detail="Webhook ping timed out")

return True


# --------------------------- UTILITIES ---------------------------- #
@@ -330,7 +363,15 @@ async def remove_all_webhooks_for_credentials(
Raises:
NeedConfirmation: If any of the webhooks are still in use and `force` is `False`
"""
webhooks = await get_all_webhooks(credentials.id)
webhooks = await get_all_webhooks_by_creds(credentials.id)
if credentials.provider not in WEBHOOK_MANAGERS_BY_NAME:
if webhooks:
logger.error(
f"Credentials #{credentials.id} for provider {credentials.provider} "
f"are attached to {len(webhooks)} webhooks, "
f"but there is no available WebhooksHandler for {credentials.provider}"
)
return
if any(w.attached_nodes for w in webhooks) and not force:
raise NeedConfirmation(
"Some webhooks linked to these credentials are still in use by an agent"
@@ -349,18 +390,23 @@ async def remove_all_webhooks_for_credentials(
logger.warning(f"Webhook #{webhook.id} failed to prune")


def _get_provider_oauth_handler(req: Request, provider_name: str) -> BaseOAuthHandler:
def _get_provider_oauth_handler(
req: Request, provider_name: ProviderName
) -> "BaseOAuthHandler":
if provider_name not in HANDLERS_BY_NAME:
raise HTTPException(
status_code=404, detail=f"Unknown provider '{provider_name}'"
status_code=404,
detail=f"Provider '{provider_name.value}' does not support OAuth",
)

client_id = getattr(settings.secrets, f"{provider_name}_client_id")
client_secret = getattr(settings.secrets, f"{provider_name}_client_secret")
client_id = getattr(settings.secrets, f"{provider_name.value}_client_id")
client_secret = getattr(settings.secrets, f"{provider_name.value}_client_secret")
if not (client_id and client_secret):
raise HTTPException(
status_code=501,
detail=f"Integration with provider '{provider_name}' is not configured",
detail=(
f"Integration with provider '{provider_name.value}' is not configured"
),
)

handler_class = HANDLERS_BY_NAME[provider_name]
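The handler lookup at the end of this hunk derives settings attribute names from the enum's raw value; a small worked example (the attribute-name pattern is implied by the code, the rest is illustrative):

from backend.integrations.providers import ProviderName

provider_name = ProviderName.GITHUB
assert f"{provider_name.value}_client_id" == "github_client_id"
assert f"{provider_name.value}_client_secret" == "github_client_secret"
# getattr(settings.secrets, "github_client_id") must therefore resolve for
# every provider that supports OAuth, or the endpoint returns HTTP 501.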
@@ -16,6 +16,8 @@ import backend.data.db
import backend.data.graph
import backend.data.user
import backend.server.routers.v1
import backend.server.v2.library.routes
import backend.server.v2.store.routes
import backend.util.service
import backend.util.settings

@@ -25,15 +27,26 @@ logger = logging.getLogger(__name__)
logging.getLogger("autogpt_libs").setLevel(logging.INFO)


@contextlib.contextmanager
def launch_darkly_context():
if settings.config.app_env != backend.util.settings.AppEnvironment.LOCAL:
initialize_launchdarkly()
try:
yield
finally:
shutdown_launchdarkly()
else:
yield


@contextlib.asynccontextmanager
async def lifespan_context(app: fastapi.FastAPI):
await backend.data.db.connect()
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()
initialize_launchdarkly()
yield
shutdown_launchdarkly()
with launch_darkly_context():
yield
await backend.data.db.disconnect()


@@ -73,7 +86,13 @@ def handle_internal_http_error(status_code: int = 500, log_error: bool = True):

app.add_exception_handler(ValueError, handle_internal_http_error(400))
app.add_exception_handler(Exception, handle_internal_http_error(500))
app.include_router(backend.server.routers.v1.v1_router, tags=["v1"])
app.include_router(backend.server.routers.v1.v1_router, tags=["v1"], prefix="/api")
app.include_router(
backend.server.v2.store.routes.router, tags=["v2"], prefix="/api/store"
)
app.include_router(
backend.server.v2.library.routes.router, tags=["v2"], prefix="/api/library"
)


@app.get(path="/health", tags=["health"], dependencies=[])
@@ -106,17 +125,17 @@ class AgentServer(backend.util.service.AppProcess):
async def test_create_graph(
create_graph: backend.server.routers.v1.CreateGraph,
user_id: str,
is_template=False,
):
return await backend.server.routers.v1.create_new_graph(create_graph, user_id)

@staticmethod
async def test_get_graph_run_status(
graph_id: str, graph_exec_id: str, user_id: str
):
return await backend.server.routers.v1.get_graph_run_status(
graph_id, graph_exec_id, user_id
async def test_get_graph_run_status(graph_exec_id: str, user_id: str):
execution = await backend.data.graph.get_execution(
user_id=user_id, execution_id=graph_exec_id
)
if not execution:
raise ValueError(f"Execution {graph_exec_id} not found")
return execution.status

@staticmethod
async def test_get_graph_run_node_execution_results(

@@ -69,8 +69,7 @@ integration_creds_manager = IntegrationCredentialsManager()
_user_credit_model = get_user_credit_model()

# Define the API routes
v1_router = APIRouter(prefix="/api")

v1_router = APIRouter()

v1_router.include_router(
backend.server.integrations.router.router,
@@ -132,7 +131,7 @@ def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlockOutput

@v1_router.get(path="/credits", dependencies=[Depends(auth_middleware)])
async def get_user_credits(
user_id: Annotated[str, Depends(get_user_id)]
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[str, int]:
# Credits can go negative, so ensure it's at least 0 for user to see.
return {"credits": max(await _user_credit_model.get_or_refill_credit(user_id), 0)}
@@ -149,12 +148,9 @@ class DeleteGraphResponse(TypedDict):

@v1_router.get(path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)])
async def get_graphs(
user_id: Annotated[str, Depends(get_user_id)],
with_runs: bool = False,
) -> Sequence[graph_db.Graph]:
return await graph_db.get_graphs(
include_executions=with_runs, filter_by="active", user_id=user_id
)
user_id: Annotated[str, Depends(get_user_id)]
) -> Sequence[graph_db.GraphModel]:
return await graph_db.get_graphs(filter_by="active", user_id=user_id)


@v1_router.get(
@@ -170,9 +166,9 @@ async def get_graph(
user_id: Annotated[str, Depends(get_user_id)],
version: int | None = None,
hide_credentials: bool = False,
) -> graph_db.Graph:
) -> graph_db.GraphModel:
graph = await graph_db.get_graph(
graph_id, version, user_id=user_id, hide_credentials=hide_credentials
graph_id, version, user_id=user_id, for_export=hide_credentials
)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
@@ -191,7 +187,7 @@ async def get_graph(
)
async def get_graph_all_versions(
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> Sequence[graph_db.Graph]:
) -> Sequence[graph_db.GraphModel]:
graphs = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
if not graphs:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
@@ -203,7 +199,7 @@ async def get_graph_all_versions(
)
async def create_new_graph(
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.Graph:
) -> graph_db.GraphModel:
return await do_create_graph(create_graph, is_template=False, user_id=user_id)


@@ -213,7 +209,7 @@ async def do_create_graph(
# user_id doesn't have to be annotated like on other endpoints,
# because create_graph isn't used directly as an endpoint
user_id: str,
) -> graph_db.Graph:
) -> graph_db.GraphModel:
if create_graph.graph:
graph = graph_db.make_graph_model(create_graph.graph, user_id)
elif create_graph.template_id:
@@ -252,6 +248,13 @@ async def do_create_graph(
async def delete_graph(
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> DeleteGraphResponse:
if active_version := await graph_db.get_graph(graph_id, user_id=user_id):

def get_credentials(credentials_id: str) -> "Credentials | None":
return integration_creds_manager.get(user_id, credentials_id)

await on_graph_deactivate(active_version, get_credentials)

return {"version_counts": await graph_db.delete_graph(graph_id, user_id=user_id)}


@@ -267,7 +270,7 @@ async def update_graph(
graph_id: str,
graph: graph_db.Graph,
user_id: Annotated[str, Depends(get_user_id)],
) -> graph_db.Graph:
) -> graph_db.GraphModel:
# Sanity check
if graph.id and graph.id != graph_id:
raise HTTPException(400, detail="Graph ID does not match ID in URI")
@@ -386,7 +389,7 @@ def execute_graph(
async def stop_graph_run(
graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> Sequence[execution_db.ExecutionResult]:
if not await execution_db.get_graph_execution(graph_exec_id, user_id):
if not await graph_db.get_execution(user_id=user_id, execution_id=graph_exec_id):
raise HTTPException(404, detail=f"Agent execution #{graph_exec_id} not found")

await asyncio.to_thread(
@@ -398,23 +401,14 @@ async def stop_graph_run(


@v1_router.get(
path="/graphs/{graph_id}/executions",
path="/executions",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
async def list_graph_runs(
graph_id: str,
async def get_executions(
user_id: Annotated[str, Depends(get_user_id)],
graph_version: int | None = None,
) -> Sequence[str]:
graph = await graph_db.get_graph(graph_id, graph_version, user_id=user_id)
if not graph:
rev = "" if graph_version is None else f" v{graph_version}"
raise HTTPException(
status_code=404, detail=f"Agent #{graph_id}{rev} not found."
)

return await execution_db.list_executions(graph_id, graph_version)
) -> list[graph_db.GraphExecution]:
return await graph_db.get_executions(user_id=user_id)


@v1_router.get(
@@ -434,25 +428,6 @@ async def get_graph_run_node_execution_results(
return await execution_db.get_execution_results(graph_exec_id)


# NOTE: This is used for testing
async def get_graph_run_status(
graph_id: str,
graph_exec_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> execution_db.ExecutionStatus:
graph = await graph_db.get_graph(graph_id, user_id=user_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")

execution = await execution_db.get_graph_execution(graph_exec_id, user_id)
if not execution:
raise HTTPException(
status_code=404, detail=f"Execution #{graph_exec_id} not found."
)

return execution.executionStatus


########################################################
##################### Templates ########################
########################################################
@@ -465,7 +440,7 @@ async def get_graph_run_status(
)
async def get_templates(
user_id: Annotated[str, Depends(get_user_id)]
) -> Sequence[graph_db.Graph]:
) -> Sequence[graph_db.GraphModel]:
return await graph_db.get_graphs(filter_by="template", user_id=user_id)


@@ -474,7 +449,9 @@ async def get_templates(
tags=["templates", "graphs"],
dependencies=[Depends(auth_middleware)],
)
async def get_template(graph_id: str, version: int | None = None) -> graph_db.Graph:
async def get_template(
graph_id: str, version: int | None = None
) -> graph_db.GraphModel:
graph = await graph_db.get_graph(graph_id, version, template=True)
if not graph:
raise HTTPException(status_code=404, detail=f"Template #{graph_id} not found.")
@@ -488,7 +465,7 @@ async def get_template(graph_id: str, version: int | None = None) -> graph_db.Gr
)
async def create_new_template(
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.Graph:
) -> graph_db.GraphModel:
return await do_create_graph(create_graph, is_template=True, user_id=user_id)

165
autogpt_platform/backend/backend/server/v2/library/db.py
Normal file
@@ -0,0 +1,165 @@
import logging
from typing import List

import prisma.errors
import prisma.models
import prisma.types

import backend.data.graph
import backend.data.includes
import backend.server.v2.library.model
import backend.server.v2.store.exceptions

logger = logging.getLogger(__name__)


async def get_library_agents(
user_id: str,
) -> List[backend.server.v2.library.model.LibraryAgent]:
"""
Returns all agents (AgentGraph) that belong to the user and all agents in their library (UserAgent table)
"""
logger.debug(f"Getting library agents for user {user_id}")

try:
# Get agents created by user with nodes and links
user_created = await prisma.models.AgentGraph.prisma().find_many(
where=prisma.types.AgentGraphWhereInput(userId=user_id, isActive=True),
include=backend.data.includes.AGENT_GRAPH_INCLUDE,
)

# Get agents in user's library with nodes and links
library_agents = await prisma.models.UserAgent.prisma().find_many(
where=prisma.types.UserAgentWhereInput(
userId=user_id, isDeleted=False, isArchived=False
),
include={
"Agent": {
"include": {
"AgentNodes": {
"include": {
"Input": True,
"Output": True,
"Webhook": True,
"AgentBlock": True,
}
}
}
}
},
)

# Convert to Graph models first
graphs = []

# Add user created agents
for agent in user_created:
try:
graphs.append(backend.data.graph.GraphModel.from_db(agent))
except Exception as e:
logger.error(f"Error processing user created agent {agent.id}: {e}")
continue

# Add library agents
for agent in library_agents:
if agent.Agent:
try:
graphs.append(backend.data.graph.GraphModel.from_db(agent.Agent))
except Exception as e:
logger.error(f"Error processing library agent {agent.agentId}: {e}")
continue

# Convert Graph models to LibraryAgent models
result = []
for graph in graphs:
result.append(
backend.server.v2.library.model.LibraryAgent(
id=graph.id,
version=graph.version,
is_active=graph.is_active,
name=graph.name,
description=graph.description,
isCreatedByUser=any(a.id == graph.id for a in user_created),
input_schema=graph.input_schema,
output_schema=graph.output_schema,
)
)

logger.debug(f"Found {len(result)} library agents")
return result

except prisma.errors.PrismaError as e:
logger.error(f"Database error getting library agents: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch library agents"
) from e


async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> None:
"""
Finds the agent from the store listing version and adds it to the user's library (UserAgent table)
if they don't already have it
"""
logger.debug(
f"Adding agent from store listing version {store_listing_version_id} to library for user {user_id}"
)

try:
# Get store listing version to find agent
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id}, include={"Agent": True}
)
)

if not store_listing_version or not store_listing_version.Agent:
logger.warning(
f"Store listing version not found: {store_listing_version_id}"
)
raise backend.server.v2.store.exceptions.AgentNotFoundError(
f"Store listing version {store_listing_version_id} not found"
)

agent = store_listing_version.Agent

if agent.userId == user_id:
logger.warning(
f"User {user_id} cannot add their own agent to their library"
)
raise backend.server.v2.store.exceptions.DatabaseError(
"Cannot add own agent to library"
)

# Check if user already has this agent
existing_user_agent = await prisma.models.UserAgent.prisma().find_first(
where={
"userId": user_id,
"agentId": agent.id,
"agentVersion": agent.version,
}
)

if existing_user_agent:
logger.debug(
f"User {user_id} already has agent {agent.id} in their library"
)
return

# Create UserAgent entry
await prisma.models.UserAgent.prisma().create(
data=prisma.types.UserAgentCreateInput(
userId=user_id,
agentId=agent.id,
agentVersion=agent.version,
isCreatedByUser=False,
)
)
logger.debug(f"Added agent {agent.id} to library for user {user_id}")

except backend.server.v2.store.exceptions.AgentNotFoundError:
raise
except prisma.errors.PrismaError as e:
logger.error(f"Database error adding agent to library: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to add agent to library"
) from e
197
autogpt_platform/backend/backend/server/v2/library/db_test.py
Normal file
@@ -0,0 +1,197 @@
from datetime import datetime

import prisma.errors
import prisma.models
import pytest
from prisma import Prisma

import backend.data.includes
import backend.server.v2.library.db as db
import backend.server.v2.store.exceptions


@pytest.fixture(autouse=True)
async def setup_prisma():
    # Don't register client if already registered
    try:
        Prisma()
    except prisma.errors.ClientAlreadyRegisteredError:
        pass
    yield


@pytest.mark.asyncio
async def test_get_library_agents(mocker):
    # Mock data
    mock_user_created = [
        prisma.models.AgentGraph(
            id="agent1",
            version=1,
            name="Test Agent 1",
            description="Test Description 1",
            userId="test-user",
            isActive=True,
            createdAt=datetime.now(),
            isTemplate=False,
        )
    ]

    mock_library_agents = [
        prisma.models.UserAgent(
            id="ua1",
            userId="test-user",
            agentId="agent2",
            agentVersion=1,
            isCreatedByUser=False,
            isDeleted=False,
            isArchived=False,
            createdAt=datetime.now(),
            updatedAt=datetime.now(),
            isFavorite=False,
            Agent=prisma.models.AgentGraph(
                id="agent2",
                version=1,
                name="Test Agent 2",
                description="Test Description 2",
                userId="other-user",
                isActive=True,
                createdAt=datetime.now(),
                isTemplate=False,
            ),
        )
    ]

    # Mock prisma calls
    mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
    mock_agent_graph.return_value.find_many = mocker.AsyncMock(
        return_value=mock_user_created
    )

    mock_user_agent = mocker.patch("prisma.models.UserAgent.prisma")
    mock_user_agent.return_value.find_many = mocker.AsyncMock(
        return_value=mock_library_agents
    )

    # Call function
    result = await db.get_library_agents("test-user")

    # Verify results
    assert len(result) == 2
    assert result[0].id == "agent1"
    assert result[0].name == "Test Agent 1"
    assert result[0].description == "Test Description 1"
    assert result[0].isCreatedByUser is True
    assert result[1].id == "agent2"
    assert result[1].name == "Test Agent 2"
    assert result[1].description == "Test Description 2"
    assert result[1].isCreatedByUser is False

    # Verify mocks called correctly
    mock_agent_graph.return_value.find_many.assert_called_once_with(
        where=prisma.types.AgentGraphWhereInput(userId="test-user", isActive=True),
        include=backend.data.includes.AGENT_GRAPH_INCLUDE,
    )
    mock_user_agent.return_value.find_many.assert_called_once_with(
        where=prisma.types.UserAgentWhereInput(
            userId="test-user", isDeleted=False, isArchived=False
        ),
        include={
            "Agent": {
                "include": {
                    "AgentNodes": {
                        "include": {
                            "Input": True,
                            "Output": True,
                            "Webhook": True,
                            "AgentBlock": True,
                        }
                    }
                }
            }
        },
    )


@pytest.mark.asyncio
async def test_add_agent_to_library(mocker):
    # Mock data
    mock_store_listing = prisma.models.StoreListingVersion(
        id="version123",
        version=1,
        createdAt=datetime.now(),
        updatedAt=datetime.now(),
        agentId="agent1",
        agentVersion=1,
        slug="test-agent",
        name="Test Agent",
        subHeading="Test Agent Subheading",
        imageUrls=["https://example.com/image.jpg"],
        description="Test Description",
        categories=["test"],
        isFeatured=False,
        isDeleted=False,
        isAvailable=True,
        isApproved=True,
        Agent=prisma.models.AgentGraph(
            id="agent1",
            version=1,
            name="Test Agent",
            description="Test Description",
            userId="creator",
            isActive=True,
            createdAt=datetime.now(),
            isTemplate=False,
        ),
    )

    # Mock prisma calls
    mock_store_listing_version = mocker.patch(
        "prisma.models.StoreListingVersion.prisma"
    )
    mock_store_listing_version.return_value.find_unique = mocker.AsyncMock(
        return_value=mock_store_listing
    )

    mock_user_agent = mocker.patch("prisma.models.UserAgent.prisma")
    mock_user_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
    mock_user_agent.return_value.create = mocker.AsyncMock()

    # Call function
    await db.add_agent_to_library("version123", "test-user")

    # Verify mocks called correctly
    mock_store_listing_version.return_value.find_unique.assert_called_once_with(
        where={"id": "version123"}, include={"Agent": True}
    )
    mock_user_agent.return_value.find_first.assert_called_once_with(
        where={
            "userId": "test-user",
            "agentId": "agent1",
            "agentVersion": 1,
        }
    )
    mock_user_agent.return_value.create.assert_called_once_with(
        data=prisma.types.UserAgentCreateInput(
            userId="test-user", agentId="agent1", agentVersion=1, isCreatedByUser=False
        )
    )


@pytest.mark.asyncio
async def test_add_agent_to_library_not_found(mocker):
    # Mock prisma calls
    mock_store_listing_version = mocker.patch(
        "prisma.models.StoreListingVersion.prisma"
    )
    mock_store_listing_version.return_value.find_unique = mocker.AsyncMock(
        return_value=None
    )

    # Call function and verify exception
    with pytest.raises(backend.server.v2.store.exceptions.AgentNotFoundError):
        await db.add_agent_to_library("version123", "test-user")

    # Verify mock called correctly
    mock_store_listing_version.return_value.find_unique.assert_called_once_with(
        where={"id": "version123"}, include={"Agent": True}
    )
16
autogpt_platform/backend/backend/server/v2/library/model.py
Normal file
@@ -0,0 +1,16 @@
import typing

import pydantic


class LibraryAgent(pydantic.BaseModel):
    id: str  # Changed from agent_id to match GraphMeta
    version: int  # Changed from agent_version to match GraphMeta
    is_active: bool  # Added to match GraphMeta
    name: str
    description: str

    isCreatedByUser: bool
    # Made input_schema and output_schema match GraphMeta's type
    input_schema: dict[str, typing.Any]  # Should be BlockIOObjectSubSchema in frontend
    output_schema: dict[str, typing.Any]  # Should be BlockIOObjectSubSchema in frontend
@@ -0,0 +1,43 @@
import backend.server.v2.library.model


def test_library_agent():
    agent = backend.server.v2.library.model.LibraryAgent(
        id="test-agent-123",
        version=1,
        is_active=True,
        name="Test Agent",
        description="Test description",
        isCreatedByUser=False,
        input_schema={"type": "object", "properties": {}},
        output_schema={"type": "object", "properties": {}},
    )
    assert agent.id == "test-agent-123"
    assert agent.version == 1
    assert agent.is_active is True
    assert agent.name == "Test Agent"
    assert agent.description == "Test description"
    assert agent.isCreatedByUser is False
    assert agent.input_schema == {"type": "object", "properties": {}}
    assert agent.output_schema == {"type": "object", "properties": {}}


def test_library_agent_with_user_created():
    agent = backend.server.v2.library.model.LibraryAgent(
        id="user-agent-456",
        version=2,
        is_active=True,
        name="User Created Agent",
        description="An agent created by the user",
        isCreatedByUser=True,
        input_schema={"type": "object", "properties": {}},
        output_schema={"type": "object", "properties": {}},
    )
    assert agent.id == "user-agent-456"
    assert agent.version == 2
    assert agent.is_active is True
    assert agent.name == "User Created Agent"
    assert agent.description == "An agent created by the user"
    assert agent.isCreatedByUser is True
    assert agent.input_schema == {"type": "object", "properties": {}}
    assert agent.output_schema == {"type": "object", "properties": {}}
74
autogpt_platform/backend/backend/server/v2/library/routes.py
Normal file
@@ -0,0 +1,74 @@
import logging
import typing

import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import fastapi

import backend.data.graph
import backend.server.v2.library.db
import backend.server.v2.library.model

logger = logging.getLogger(__name__)

router = fastapi.APIRouter()


@router.get(
    "/agents",
    tags=["library", "private"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def get_library_agents(
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ]
) -> typing.Sequence[backend.server.v2.library.model.LibraryAgent]:
    """
    Get all agents in the user's library, including both created and saved agents.
    """
    try:
        agents = await backend.server.v2.library.db.get_library_agents(user_id)
        return agents
    except Exception:
        logger.exception("Exception occurred whilst getting library agents")
        raise fastapi.HTTPException(
            status_code=500, detail="Failed to get library agents"
        )


@router.post(
    "/agents/{store_listing_version_id}",
    tags=["library", "private"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
    status_code=201,
)
async def add_agent_to_library(
    store_listing_version_id: str,
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ],
) -> fastapi.Response:
    """
    Add an agent from the store to the user's library.

    Args:
        store_listing_version_id (str): ID of the store listing version to add
        user_id (str): ID of the authenticated user

    Returns:
        fastapi.Response: 201 status code on success

    Raises:
        HTTPException: If there is an error adding the agent to the library
    """
    try:
        await backend.server.v2.library.db.add_agent_to_library(
            store_listing_version_id=store_listing_version_id, user_id=user_id
        )
        return fastapi.Response(status_code=201)
    except Exception:
        logger.exception("Exception occurred whilst adding agent to library")
        raise fastapi.HTTPException(
            status_code=500, detail="Failed to add agent to library"
        )
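For context, mounting this router in an application takes one call; the prefix below is hypothetical, since the actual wiring is outside this diff (the test file that follows mounts it without a prefix):

import fastapi

import backend.server.v2.library.routes

app = fastapi.FastAPI()
# Hypothetical prefix; adjust to wherever the v2 API is actually mounted.
app.include_router(backend.server.v2.library.routes.router, prefix="/api/v2/library")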
@@ -0,0 +1,103 @@
import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import fastapi
import fastapi.testclient
import pytest_mock

import backend.server.v2.library.db
import backend.server.v2.library.model
import backend.server.v2.library.routes

app = fastapi.FastAPI()
app.include_router(backend.server.v2.library.routes.router)

client = fastapi.testclient.TestClient(app)


def override_auth_middleware():
    """Override auth middleware for testing"""
    return {"sub": "test-user-id"}


def override_get_user_id():
    """Override get_user_id for testing"""
    return "test-user-id"


app.dependency_overrides[autogpt_libs.auth.middleware.auth_middleware] = (
    override_auth_middleware
)
app.dependency_overrides[autogpt_libs.auth.depends.get_user_id] = override_get_user_id


def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
    mocked_value = [
        backend.server.v2.library.model.LibraryAgent(
            id="test-agent-1",
            version=1,
            is_active=True,
            name="Test Agent 1",
            description="Test Description 1",
            isCreatedByUser=True,
            input_schema={"type": "object", "properties": {}},
            output_schema={"type": "object", "properties": {}},
        ),
        backend.server.v2.library.model.LibraryAgent(
            id="test-agent-2",
            version=1,
            is_active=True,
            name="Test Agent 2",
            description="Test Description 2",
            isCreatedByUser=False,
            input_schema={"type": "object", "properties": {}},
            output_schema={"type": "object", "properties": {}},
        ),
    ]
    mock_db_call = mocker.patch("backend.server.v2.library.db.get_library_agents")
    mock_db_call.return_value = mocked_value

    response = client.get("/agents")
    assert response.status_code == 200

    data = [
        backend.server.v2.library.model.LibraryAgent.model_validate(agent)
        for agent in response.json()
    ]
    assert len(data) == 2
    assert data[0].id == "test-agent-1"
    assert data[0].isCreatedByUser is True
    assert data[1].id == "test-agent-2"
    assert data[1].isCreatedByUser is False
    mock_db_call.assert_called_once_with("test-user-id")


def test_get_library_agents_error(mocker: pytest_mock.MockFixture):
    mock_db_call = mocker.patch("backend.server.v2.library.db.get_library_agents")
    mock_db_call.side_effect = Exception("Test error")

    response = client.get("/agents")
    assert response.status_code == 500
    mock_db_call.assert_called_once_with("test-user-id")


def test_add_agent_to_library_success(mocker: pytest_mock.MockFixture):
    mock_db_call = mocker.patch("backend.server.v2.library.db.add_agent_to_library")
    mock_db_call.return_value = None

    response = client.post("/agents/test-version-id")
    assert response.status_code == 201
    mock_db_call.assert_called_once_with(
        store_listing_version_id="test-version-id", user_id="test-user-id"
    )


def test_add_agent_to_library_error(mocker: pytest_mock.MockFixture):
    mock_db_call = mocker.patch("backend.server.v2.library.db.add_agent_to_library")
    mock_db_call.side_effect = Exception("Test error")

    response = client.post("/agents/test-version-id")
    assert response.status_code == 500
    assert response.json()["detail"] == "Failed to add agent to library"
    mock_db_call.assert_called_once_with(
        store_listing_version_id="test-version-id", user_id="test-user-id"
    )
53
autogpt_platform/backend/backend/server/v2/store/README.md
Normal file
@@ -0,0 +1,53 @@
# Store Module

This module implements the backend API for the AutoGPT Store, handling agents, creators, profiles, submissions and media uploads.

## Files

### routes.py

Contains the FastAPI route handlers for the store API endpoints:

- Profile endpoints for managing user profiles
- Agent endpoints for browsing and retrieving store agents
- Creator endpoints for browsing and retrieving creator details
- Store submission endpoints for submitting agents to the store
- Media upload endpoints for submission images/videos

### model.py

Contains Pydantic models for request/response validation and serialization:

- Pagination model for paginated responses
- Models for agents, creators, profiles, submissions
- Request/response models for all API endpoints

### db.py

Contains database access functions using Prisma ORM:

- Functions to query and manipulate store data
- Handles database operations for all API endpoints
- Implements business logic and data validation

### media.py

Handles media file uploads to Google Cloud Storage:

- Validates file types and sizes
- Processes image and video uploads
- Stores files in GCS buckets
- Returns public URLs for uploaded media

## Key Features

- Paginated listings of store agents and creators
- Search and filtering of agents and creators
- Agent submission workflow
- Media file upload handling
- Profile management
- Reviews and ratings

## Authentication

Most endpoints require authentication via the AutoGPT auth middleware. Public endpoints are marked with the "public" tag.

## Error Handling

All database and storage operations include proper error handling and logging. Errors are mapped to appropriate HTTP status codes.
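To make the pagination convention concrete, here is the arithmetic every paginated query in db.py uses, as a standalone sketch (the helper name and sample numbers are illustrative, not part of the module):

def paginate(total_items: int, page: int, page_size: int) -> dict:
    # Ceiling division: a partially filled last page still counts as a page.
    total_pages = (total_items + page_size - 1) // page_size
    return {
        "current_page": page,
        "total_items": total_items,
        "total_pages": total_pages,
        "page_size": page_size,
    }


# 41 items at 20 per page -> 3 pages (20 + 20 + 1).
assert paginate(total_items=41, page=1, page_size=20)["total_pages"] == 3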
765
autogpt_platform/backend/backend/server/v2/store/db.py
Normal file
@@ -0,0 +1,765 @@
import logging
import random
from datetime import datetime

import prisma.enums
import prisma.errors
import prisma.models
import prisma.types

import backend.server.v2.store.exceptions
import backend.server.v2.store.model

logger = logging.getLogger(__name__)


async def get_store_agents(
    featured: bool = False,
    creator: str | None = None,
    sorted_by: str | None = None,
    search_query: str | None = None,
    category: str | None = None,
    page: int = 1,
    page_size: int = 20,
) -> backend.server.v2.store.model.StoreAgentsResponse:
    logger.debug(
        f"Getting store agents. featured={featured}, creator={creator}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
    )
    sanitized_query = None
    # Sanitize and validate search query by escaping special characters
    if search_query is not None:
        sanitized_query = search_query.strip()
        if not sanitized_query or len(sanitized_query) > 100:  # Reasonable length limit
            raise backend.server.v2.store.exceptions.DatabaseError(
                "Invalid search query"
            )

        # Escape special SQL characters
        sanitized_query = (
            sanitized_query.replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
            .replace("[", "\\[")
            .replace("]", "\\]")
            .replace("'", "\\'")
            .replace('"', '\\"')
            .replace(";", "\\;")
            .replace("--", "\\--")
            .replace("/*", "\\/*")
            .replace("*/", "\\*/")
        )

    where_clause = {}
    if featured:
        where_clause["featured"] = featured
    if creator:
        where_clause["creator_username"] = creator
    if category:
        where_clause["categories"] = {"has": category}

    if sanitized_query:
        where_clause["OR"] = [
            {"agent_name": {"contains": sanitized_query, "mode": "insensitive"}},
            {"description": {"contains": sanitized_query, "mode": "insensitive"}},
        ]

    order_by = []
    if sorted_by == "rating":
        order_by.append({"rating": "desc"})
    elif sorted_by == "runs":
        order_by.append({"runs": "desc"})
    elif sorted_by == "name":
        order_by.append({"agent_name": "asc"})

    try:
        agents = await prisma.models.StoreAgent.prisma().find_many(
            where=prisma.types.StoreAgentWhereInput(**where_clause),
            order=order_by,
            skip=(page - 1) * page_size,
            take=page_size,
        )

        total = await prisma.models.StoreAgent.prisma().count(
            where=prisma.types.StoreAgentWhereInput(**where_clause)
        )
        total_pages = (total + page_size - 1) // page_size

        store_agents = [
            backend.server.v2.store.model.StoreAgent(
                slug=agent.slug,
                agent_name=agent.agent_name,
                agent_image=agent.agent_image[0] if agent.agent_image else "",
                creator=agent.creator_username,
                creator_avatar=agent.creator_avatar,
                sub_heading=agent.sub_heading,
                description=agent.description,
                runs=agent.runs,
                rating=agent.rating,
            )
            for agent in agents
        ]

        logger.debug(f"Found {len(store_agents)} agents")
        return backend.server.v2.store.model.StoreAgentsResponse(
            agents=store_agents,
            pagination=backend.server.v2.store.model.Pagination(
                current_page=page,
                total_items=total,
                total_pages=total_pages,
                page_size=page_size,
            ),
        )
    except Exception as e:
        logger.error(f"Error getting store agents: {str(e)}")
        raise backend.server.v2.store.exceptions.DatabaseError(
            "Failed to fetch store agents"
        ) from e


async def get_store_agent_details(
    username: str, agent_name: str
) -> backend.server.v2.store.model.StoreAgentDetails:
    logger.debug(f"Getting store agent details for {username}/{agent_name}")

    try:
        agent = await prisma.models.StoreAgent.prisma().find_first(
            where={"creator_username": username, "slug": agent_name}
        )

        if not agent:
            logger.warning(f"Agent not found: {username}/{agent_name}")
            raise backend.server.v2.store.exceptions.AgentNotFoundError(
                f"Agent {username}/{agent_name} not found"
            )

        logger.debug(f"Found agent details for {username}/{agent_name}")
        return backend.server.v2.store.model.StoreAgentDetails(
            store_listing_version_id=agent.storeListingVersionId,
            slug=agent.slug,
            agent_name=agent.agent_name,
            agent_video=agent.agent_video or "",
            agent_image=agent.agent_image,
            creator=agent.creator_username,
            creator_avatar=agent.creator_avatar,
            sub_heading=agent.sub_heading,
            description=agent.description,
            categories=agent.categories,
            runs=agent.runs,
            rating=agent.rating,
            versions=agent.versions,
            last_updated=agent.updated_at,
        )
    except backend.server.v2.store.exceptions.AgentNotFoundError:
        raise
    except Exception as e:
        logger.error(f"Error getting store agent details: {str(e)}")
        raise backend.server.v2.store.exceptions.DatabaseError(
            "Failed to fetch agent details"
        ) from e


async def get_store_creators(
    featured: bool = False,
    search_query: str | None = None,
    sorted_by: str | None = None,
    page: int = 1,
    page_size: int = 20,
) -> backend.server.v2.store.model.CreatorsResponse:
    logger.debug(
        f"Getting store creators. featured={featured}, search={search_query}, sorted_by={sorted_by}, page={page}"
    )

    # Build where clause with sanitized inputs
    where = {}

    if featured:
        where["is_featured"] = featured

    # Add search filter if provided, using parameterized queries
    if search_query:
        # Sanitize and validate search query by escaping special characters
        sanitized_query = search_query.strip()
        if not sanitized_query or len(sanitized_query) > 100:  # Reasonable length limit
            raise backend.server.v2.store.exceptions.DatabaseError(
                "Invalid search query"
            )

        # Escape special SQL characters
        sanitized_query = (
            sanitized_query.replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
            .replace("[", "\\[")
            .replace("]", "\\]")
            .replace("'", "\\'")
            .replace('"', '\\"')
            .replace(";", "\\;")
            .replace("--", "\\--")
            .replace("/*", "\\/*")
            .replace("*/", "\\*/")
        )

        where["OR"] = [
            {"username": {"contains": sanitized_query, "mode": "insensitive"}},
            {"name": {"contains": sanitized_query, "mode": "insensitive"}},
            {"description": {"contains": sanitized_query, "mode": "insensitive"}},
        ]

    try:
        # Validate pagination parameters
        if not isinstance(page, int) or page < 1:
            raise backend.server.v2.store.exceptions.DatabaseError(
                "Invalid page number"
            )
        if not isinstance(page_size, int) or page_size < 1 or page_size > 100:
            raise backend.server.v2.store.exceptions.DatabaseError("Invalid page size")

        # Get total count for pagination using sanitized where clause
        total = await prisma.models.Creator.prisma().count(
            where=prisma.types.CreatorWhereInput(**where)
        )
        total_pages = (total + page_size - 1) // page_size

        # Add pagination with validated parameters
        skip = (page - 1) * page_size
        take = page_size

        # Add sorting with validated sort parameter
        order = []
        valid_sort_fields = {"agent_rating", "agent_runs", "num_agents"}
        if sorted_by in valid_sort_fields:
            order.append({sorted_by: "desc"})
        else:
            order.append({"username": "asc"})

        # Execute query with sanitized parameters
        creators = await prisma.models.Creator.prisma().find_many(
            where=prisma.types.CreatorWhereInput(**where),
            skip=skip,
            take=take,
            order=order,
        )

        # Convert to response model
        creator_models = [
            backend.server.v2.store.model.Creator(
                username=creator.username,
                name=creator.name,
                description=creator.description,
                avatar_url=creator.avatar_url,
                num_agents=creator.num_agents,
                agent_rating=creator.agent_rating,
                agent_runs=creator.agent_runs,
                is_featured=creator.is_featured,
            )
            for creator in creators
        ]

        logger.debug(f"Found {len(creator_models)} creators")
        return backend.server.v2.store.model.CreatorsResponse(
            creators=creator_models,
            pagination=backend.server.v2.store.model.Pagination(
                current_page=page,
                total_items=total,
                total_pages=total_pages,
                page_size=page_size,
            ),
        )
    except Exception as e:
        logger.error(f"Error getting store creators: {str(e)}")
        raise backend.server.v2.store.exceptions.DatabaseError(
            "Failed to fetch store creators"
        ) from e


async def get_store_creator_details(
    username: str,
) -> backend.server.v2.store.model.CreatorDetails:
    logger.debug(f"Getting store creator details for {username}")

    try:
        # Query creator details from database
        creator = await prisma.models.Creator.prisma().find_unique(
            where={"username": username}
        )

        if not creator:
            logger.warning(f"Creator not found: {username}")
            raise backend.server.v2.store.exceptions.CreatorNotFoundError(
                f"Creator {username} not found"
            )

        logger.debug(f"Found creator details for {username}")
        return backend.server.v2.store.model.CreatorDetails(
            name=creator.name,
            username=creator.username,
            description=creator.description,
            links=creator.links,
            avatar_url=creator.avatar_url,
            agent_rating=creator.agent_rating,
            agent_runs=creator.agent_runs,
            top_categories=creator.top_categories,
        )
    except backend.server.v2.store.exceptions.CreatorNotFoundError:
        raise
    except Exception as e:
        logger.error(f"Error getting store creator details: {str(e)}")
        raise backend.server.v2.store.exceptions.DatabaseError(
            "Failed to fetch creator details"
        ) from e


async def get_store_submissions(
    user_id: str, page: int = 1, page_size: int = 20
) -> backend.server.v2.store.model.StoreSubmissionsResponse:
    logger.debug(f"Getting store submissions for user {user_id}, page={page}")

    try:
        # Calculate pagination values
        skip = (page - 1) * page_size

        where = prisma.types.StoreSubmissionWhereInput(user_id=user_id)
        # Query submissions from database
        submissions = await prisma.models.StoreSubmission.prisma().find_many(
            where=where, skip=skip, take=page_size, order=[{"date_submitted": "desc"}]
        )

        # Get total count for pagination
        total = await prisma.models.StoreSubmission.prisma().count(where=where)

        total_pages = (total + page_size - 1) // page_size

        # Convert to response models
        submission_models = [
            backend.server.v2.store.model.StoreSubmission(
                agent_id=sub.agent_id,
                agent_version=sub.agent_version,
                name=sub.name,
                sub_heading=sub.sub_heading,
                slug=sub.slug,
                description=sub.description,
                image_urls=sub.image_urls or [],
                date_submitted=sub.date_submitted or datetime.now(),
                status=sub.status,
                runs=sub.runs or 0,
                rating=sub.rating or 0.0,
            )
            for sub in submissions
        ]

        logger.debug(f"Found {len(submission_models)} submissions")
        return backend.server.v2.store.model.StoreSubmissionsResponse(
            submissions=submission_models,
            pagination=backend.server.v2.store.model.Pagination(
                current_page=page,
                total_items=total,
                total_pages=total_pages,
                page_size=page_size,
            ),
        )

    except Exception as e:
        logger.error(f"Error fetching store submissions: {str(e)}")
        # Return empty response rather than exposing internal errors
        return backend.server.v2.store.model.StoreSubmissionsResponse(
            submissions=[],
            pagination=backend.server.v2.store.model.Pagination(
                current_page=page,
                total_items=0,
                total_pages=0,
                page_size=page_size,
            ),
        )


async def delete_store_submission(
    user_id: str,
    submission_id: str,
) -> bool:
    """
    Delete a store listing submission.

    Args:
        user_id: ID of the authenticated user
        submission_id: ID of the submission to be deleted

    Returns:
        bool: True if the submission was successfully deleted, False otherwise
    """
    logger.debug(f"Deleting store submission {submission_id} for user {user_id}")

    try:
        # Verify the submission belongs to this user
        submission = await prisma.models.StoreListing.prisma().find_first(
            where={"agentId": submission_id, "owningUserId": user_id}
        )

        if not submission:
            logger.warning(f"Submission not found for user {user_id}: {submission_id}")
            raise backend.server.v2.store.exceptions.SubmissionNotFoundError(
                f"Submission not found for this user. User ID: {user_id}, Submission ID: {submission_id}"
            )

        # Delete the submission
        await prisma.models.StoreListing.prisma().delete(
            where=prisma.types.StoreListingWhereUniqueInput(id=submission.id)
        )

        logger.debug(
            f"Successfully deleted submission {submission_id} for user {user_id}"
        )
        return True

    except Exception as e:
        logger.error(f"Error deleting store submission: {str(e)}")
        return False


async def create_store_submission(
    user_id: str,
    agent_id: str,
    agent_version: int,
    slug: str,
    name: str,
    video_url: str | None = None,
    image_urls: list[str] = [],
    description: str = "",
    sub_heading: str = "",
    categories: list[str] = [],
) -> backend.server.v2.store.model.StoreSubmission:
    """
    Create a new store listing submission.

    Args:
        user_id: ID of the authenticated user submitting the listing
        agent_id: ID of the agent being submitted
        agent_version: Version of the agent being submitted
        slug: URL slug for the listing
        name: Name of the agent
        video_url: Optional URL to video demo
        image_urls: List of image URLs for the listing
        description: Description of the agent
        categories: List of categories for the agent

    Returns:
        StoreSubmission: The created store submission
    """
    logger.debug(
        f"Creating store submission for user {user_id}, agent {agent_id} v{agent_version}"
    )

    try:
        # First verify the agent belongs to this user
        agent = await prisma.models.AgentGraph.prisma().find_first(
            where=prisma.types.AgentGraphWhereInput(
                id=agent_id, version=agent_version, userId=user_id
            )
        )

        if not agent:
            logger.warning(
                f"Agent not found for user {user_id}: {agent_id} v{agent_version}"
            )
            raise backend.server.v2.store.exceptions.AgentNotFoundError(
                f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
            )

        listing = await prisma.models.StoreListing.prisma().find_first(
            where=prisma.types.StoreListingWhereInput(
                agentId=agent_id, owningUserId=user_id
            )
        )
        if listing is not None:
            logger.warning(f"Listing already exists for agent {agent_id}")
            raise backend.server.v2.store.exceptions.ListingExistsError(
                "Listing already exists for this agent"
            )

        # Create the store listing
        listing = await prisma.models.StoreListing.prisma().create(
            data={
                "agentId": agent_id,
                "agentVersion": agent_version,
                "owningUserId": user_id,
                "createdAt": datetime.now(),
                "StoreListingVersions": {
                    "create": {
                        "agentId": agent_id,
                        "agentVersion": agent_version,
                        "slug": slug,
                        "name": name,
                        "videoUrl": video_url,
                        "imageUrls": image_urls,
                        "description": description,
                        "categories": categories,
                        "subHeading": sub_heading,
                    }
                },
            }
        )

        logger.debug(f"Created store listing for agent {agent_id}")
        # Return submission details
        return backend.server.v2.store.model.StoreSubmission(
            agent_id=agent_id,
            agent_version=agent_version,
            name=name,
            slug=slug,
            sub_heading=sub_heading,
            description=description,
            image_urls=image_urls,
            date_submitted=listing.createdAt,
            status=prisma.enums.SubmissionStatus.PENDING,
            runs=0,
            rating=0.0,
        )

    except (
        backend.server.v2.store.exceptions.AgentNotFoundError,
        backend.server.v2.store.exceptions.ListingExistsError,
    ):
        raise
    except prisma.errors.PrismaError as e:
        logger.error(f"Database error creating store submission: {str(e)}")
        raise backend.server.v2.store.exceptions.DatabaseError(
            "Failed to create store submission"
        ) from e


async def create_store_review(
    user_id: str,
    store_listing_version_id: str,
    score: int,
    comments: str | None = None,
) -> backend.server.v2.store.model.StoreReview:
    try:
        review = await prisma.models.StoreListingReview.prisma().upsert(
            where={
                "storeListingVersionId_reviewByUserId": {
                    "storeListingVersionId": store_listing_version_id,
                    "reviewByUserId": user_id,
                }
            },
            data={
                "create": {
                    "reviewByUserId": user_id,
                    "storeListingVersionId": store_listing_version_id,
                    "score": score,
                    "comments": comments,
                },
                "update": {
                    "score": score,
                    "comments": comments,
                },
            },
        )

        return backend.server.v2.store.model.StoreReview(
            score=review.score,
            comments=review.comments,
        )

    except prisma.errors.PrismaError as e:
        logger.error(f"Database error creating store review: {str(e)}")
        raise backend.server.v2.store.exceptions.DatabaseError(
            "Failed to create store review"
        ) from e


async def get_user_profile(
    user_id: str,
) -> backend.server.v2.store.model.ProfileDetails:
    logger.debug(f"Getting user profile for {user_id}")

    try:
        profile = await prisma.models.Profile.prisma().find_first(
            where={"userId": user_id}  # type: ignore
        )

        if not profile:
            logger.warning(f"Profile not found for user {user_id}")
            await prisma.models.Profile.prisma().create(
                data=prisma.types.ProfileCreateInput(
                    userId=user_id,
                    name="No Profile Data",
                    username=f"{random.choice(['happy', 'clever', 'swift', 'bright', 'wise'])}-{random.choice(['fox', 'wolf', 'bear', 'eagle', 'owl'])}_{random.randint(1000,9999)}",
                    description="No Profile Data",
                    links=[],
                    avatarUrl="",
                )
            )
            return backend.server.v2.store.model.ProfileDetails(
                name="No Profile Data",
                username="No Profile Data",
                description="No Profile Data",
                links=[],
                avatar_url="",
            )

        return backend.server.v2.store.model.ProfileDetails(
            name=profile.name,
            username=profile.username,
            description=profile.description,
            links=profile.links,
            avatar_url=profile.avatarUrl,
        )
    except Exception as e:
        logger.error(f"Error getting user profile: {str(e)}")
        return backend.server.v2.store.model.ProfileDetails(
            name="No Profile Data",
            username="No Profile Data",
            description="No Profile Data",
            links=[],
            avatar_url="",
        )


async def update_or_create_profile(
    user_id: str, profile: backend.server.v2.store.model.Profile
) -> backend.server.v2.store.model.CreatorDetails:
    """
    Update the store profile for a user. Creates a new profile if one doesn't exist.
    Only allows updating if the user_id matches the owning user.

    Args:
        user_id: ID of the authenticated user
        profile: Updated profile details

    Returns:
        CreatorDetails: The updated profile

    Raises:
        HTTPException: If user is not authorized to update this profile
    """
    logger.debug(f"Updating profile for user {user_id}")

    try:
        # Check if profile exists for user
        existing_profile = await prisma.models.Profile.prisma().find_first(
            where={"userId": user_id}
        )

        # If no profile exists, create a new one
        if not existing_profile:
            logger.debug(f"Creating new profile for user {user_id}")
            # Create new profile since one doesn't exist
            new_profile = await prisma.models.Profile.prisma().create(
                data={
                    "userId": user_id,
                    "name": profile.name,
                    "username": profile.username,
                    "description": profile.description,
                    "links": profile.links,
                    "avatarUrl": profile.avatar_url,
                }
            )

            return backend.server.v2.store.model.CreatorDetails(
                name=new_profile.name,
                username=new_profile.username,
                description=new_profile.description,
                links=new_profile.links,
                avatar_url=new_profile.avatarUrl or "",
                agent_rating=0.0,
                agent_runs=0,
                top_categories=[],
            )
        else:
            logger.debug(f"Updating existing profile for user {user_id}")
            # Update the existing profile
            updated_profile = await prisma.models.Profile.prisma().update(
                where={"id": existing_profile.id},
                data=prisma.types.ProfileUpdateInput(
                    name=profile.name,
                    username=profile.username,
                    description=profile.description,
                    links=profile.links,
                    avatarUrl=profile.avatar_url,
                ),
            )
            if updated_profile is None:
                logger.error(f"Failed to update profile for user {user_id}")
                raise backend.server.v2.store.exceptions.DatabaseError(
                    "Failed to update profile"
                )

            return backend.server.v2.store.model.CreatorDetails(
                name=updated_profile.name,
                username=updated_profile.username,
                description=updated_profile.description,
                links=updated_profile.links,
                avatar_url=updated_profile.avatarUrl or "",
                agent_rating=0.0,
                agent_runs=0,
                top_categories=[],
            )

    except prisma.errors.PrismaError as e:
        logger.error(f"Database error updating profile: {str(e)}")
        raise backend.server.v2.store.exceptions.DatabaseError(
            "Failed to update profile"
        ) from e


async def get_my_agents(
    user_id: str,
    page: int = 1,
    page_size: int = 20,
) -> backend.server.v2.store.model.MyAgentsResponse:
    logger.debug(f"Getting my agents for user {user_id}, page={page}")

    try:
        agents_with_max_version = await prisma.models.AgentGraph.prisma().find_many(
            where=prisma.types.AgentGraphWhereInput(
                userId=user_id, StoreListing={"none": {"isDeleted": False}}
            ),
            order=[{"version": "desc"}],
            distinct=["id"],
            skip=(page - 1) * page_size,
            take=page_size,
        )

        # store_listings = await prisma.models.StoreListing.prisma().find_many(
        #     where=prisma.types.StoreListingWhereInput(
        #         isDeleted=False,
        #     ),
        # )

        total = len(
            await prisma.models.AgentGraph.prisma().find_many(
                where=prisma.types.AgentGraphWhereInput(
                    userId=user_id, StoreListing={"none": {"isDeleted": False}}
                ),
                order=[{"version": "desc"}],
                distinct=["id"],
            )
        )

        total_pages = (total + page_size - 1) // page_size

        agents = agents_with_max_version

        my_agents = [
            backend.server.v2.store.model.MyAgent(
                agent_id=agent.id,
                agent_version=agent.version,
                agent_name=agent.name or "",
                last_edited=agent.updatedAt or agent.createdAt,
            )
            for agent in agents
        ]

        return backend.server.v2.store.model.MyAgentsResponse(
            agents=my_agents,
            pagination=backend.server.v2.store.model.Pagination(
                current_page=page,
                total_items=total,
                total_pages=total_pages,
                page_size=page_size,
            ),
        )
    except Exception as e:
        logger.error(f"Error getting my agents: {str(e)}")
        raise backend.server.v2.store.exceptions.DatabaseError(
            "Failed to fetch my agents"
        ) from e
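To illustrate the escaping applied by get_store_agents and get_store_creators above, here is the same replace chain factored into a standalone helper (an illustrative extraction; no such function exists in this module):

def escape_search_query(raw: str) -> str:
    # Mirrors the in-function sanitization: strip, bound the length,
    # then escape SQL-significant characters before a `contains` filter.
    query = raw.strip()
    if not query or len(query) > 100:
        raise ValueError("Invalid search query")
    return (
        query.replace("\\", "\\\\")
        .replace("%", "\\%")
        .replace("_", "\\_")
        .replace("[", "\\[")
        .replace("]", "\\]")
        .replace("'", "\\'")
        .replace('"', '\\"')
        .replace(";", "\\;")
        .replace("--", "\\--")
        .replace("/*", "\\/*")
        .replace("*/", "\\*/")
    )


# "%" and "_" are LIKE wildcards, so both get escaped.
assert escape_search_query("50%_off") == "50\\%\\_off"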
264
autogpt_platform/backend/backend/server/v2/store/db_test.py
Normal file
@@ -0,0 +1,264 @@
from datetime import datetime

import prisma.errors
import prisma.models
import pytest
from prisma import Prisma

import backend.server.v2.store.db as db
from backend.server.v2.store.model import Profile


@pytest.fixture(autouse=True)
async def setup_prisma():
    # Don't register client if already registered
    try:
        Prisma()
    except prisma.errors.ClientAlreadyRegisteredError:
        pass
    yield


@pytest.mark.asyncio
async def test_get_store_agents(mocker):
    # Mock data
    mock_agents = [
        prisma.models.StoreAgent(
            listing_id="test-id",
            storeListingVersionId="version123",
            slug="test-agent",
            agent_name="Test Agent",
            agent_video=None,
            agent_image=["image.jpg"],
            featured=False,
            creator_username="creator",
            creator_avatar="avatar.jpg",
            sub_heading="Test heading",
            description="Test description",
            categories=[],
            runs=10,
            rating=4.5,
            versions=["1.0"],
            updated_at=datetime.now(),
        )
    ]

    # Mock prisma calls
    mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
    mock_store_agent.return_value.find_many = mocker.AsyncMock(return_value=mock_agents)
    mock_store_agent.return_value.count = mocker.AsyncMock(return_value=1)

    # Call function
    result = await db.get_store_agents()

    # Verify results
    assert len(result.agents) == 1
    assert result.agents[0].slug == "test-agent"
    assert result.pagination.total_items == 1

    # Verify mocks called correctly
    mock_store_agent.return_value.find_many.assert_called_once()
    mock_store_agent.return_value.count.assert_called_once()


@pytest.mark.asyncio
async def test_get_store_agent_details(mocker):
    # Mock data
    mock_agent = prisma.models.StoreAgent(
        listing_id="test-id",
        storeListingVersionId="version123",
        slug="test-agent",
        agent_name="Test Agent",
        agent_video="video.mp4",
        agent_image=["image.jpg"],
        featured=False,
        creator_username="creator",
        creator_avatar="avatar.jpg",
        sub_heading="Test heading",
        description="Test description",
        categories=["test"],
        runs=10,
        rating=4.5,
        versions=["1.0"],
        updated_at=datetime.now(),
    )

    # Mock prisma call
    mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
    mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)

    # Call function
    result = await db.get_store_agent_details("creator", "test-agent")

    # Verify results
    assert result.slug == "test-agent"
    assert result.agent_name == "Test Agent"

    # Verify mock called correctly
    mock_store_agent.return_value.find_first.assert_called_once_with(
        where={"creator_username": "creator", "slug": "test-agent"}
    )


@pytest.mark.asyncio
async def test_get_store_creator_details(mocker):
    # Mock data
    mock_creator_data = prisma.models.Creator(
        name="Test Creator",
        username="creator",
        description="Test description",
        links=["link1"],
        avatar_url="avatar.jpg",
        num_agents=1,
        agent_rating=4.5,
        agent_runs=10,
        top_categories=["test"],
        is_featured=False,
    )

    # Mock prisma call
    mock_creator = mocker.patch("prisma.models.Creator.prisma")
    mock_creator.return_value.find_unique = mocker.AsyncMock()
    # Configure the mock to return values that will pass validation
    mock_creator.return_value.find_unique.return_value = mock_creator_data

    # Call function
    result = await db.get_store_creator_details("creator")

    # Verify results
    assert result.username == "creator"
    assert result.name == "Test Creator"
    assert result.description == "Test description"
    assert result.avatar_url == "avatar.jpg"

    # Verify mock called correctly
    mock_creator.return_value.find_unique.assert_called_once_with(
        where={"username": "creator"}
    )


@pytest.mark.asyncio
async def test_create_store_submission(mocker):
    # Mock data
    mock_agent = prisma.models.AgentGraph(
        id="agent-id",
        version=1,
        userId="user-id",
        createdAt=datetime.now(),
        isActive=True,
        isTemplate=False,
    )

    mock_listing = prisma.models.StoreListing(
        id="listing-id",
        createdAt=datetime.now(),
        updatedAt=datetime.now(),
        isDeleted=False,
        isApproved=False,
        agentId="agent-id",
        agentVersion=1,
        owningUserId="user-id",
    )

    # Mock prisma calls
    mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
    mock_agent_graph.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)

    mock_store_listing = mocker.patch("prisma.models.StoreListing.prisma")
    mock_store_listing.return_value.find_first = mocker.AsyncMock(return_value=None)
    mock_store_listing.return_value.create = mocker.AsyncMock(return_value=mock_listing)

    # Call function
    result = await db.create_store_submission(
        user_id="user-id",
        agent_id="agent-id",
        agent_version=1,
        slug="test-agent",
        name="Test Agent",
        description="Test description",
    )

    # Verify results
    assert result.name == "Test Agent"
    assert result.description == "Test description"

    # Verify mocks called correctly
    mock_agent_graph.return_value.find_first.assert_called_once()
    mock_store_listing.return_value.find_first.assert_called_once()
    mock_store_listing.return_value.create.assert_called_once()


@pytest.mark.asyncio
async def test_update_profile(mocker):
    # Mock data
    mock_profile = prisma.models.Profile(
        id="profile-id",
        name="Test Creator",
        username="creator",
        description="Test description",
        links=["link1"],
        avatarUrl="avatar.jpg",
        isFeatured=False,
        createdAt=datetime.now(),
        updatedAt=datetime.now(),
    )

    # Mock prisma calls
    mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
    mock_profile_db.return_value.find_first = mocker.AsyncMock(
        return_value=mock_profile
    )
    mock_profile_db.return_value.update = mocker.AsyncMock(return_value=mock_profile)

    # Test data
    profile = Profile(
        name="Test Creator",
        username="creator",
        description="Test description",
        links=["link1"],
        avatar_url="avatar.jpg",
        is_featured=False,
    )

    # Call function
    result = await db.update_or_create_profile("user-id", profile)

    # Verify results
    assert result.username == "creator"
    assert result.name == "Test Creator"

    # Verify mocks called correctly
    mock_profile_db.return_value.find_first.assert_called_once()
    mock_profile_db.return_value.update.assert_called_once()


@pytest.mark.asyncio
async def test_get_user_profile(mocker):
    # Mock data
    mock_profile = prisma.models.Profile(
        id="profile-id",
        name="No Profile Data",
        username="testuser",
        description="Test description",
        links=["link1", "link2"],
        avatarUrl="avatar.jpg",
        isFeatured=False,
        createdAt=datetime.now(),
        updatedAt=datetime.now(),
    )

    # Mock prisma calls
    mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
    mock_profile_db.return_value.find_unique = mocker.AsyncMock(
        return_value=mock_profile
    )

    # Call function
    result = await db.get_user_profile("user-id")

    # Verify results
    assert result.name == "No Profile Data"
    assert result.username == "No Profile Data"
    assert result.description == "No Profile Data"
    assert result.links == []
    assert result.avatar_url == ""
@@ -0,0 +1,76 @@
class MediaUploadError(Exception):
    """Base exception for media upload errors"""

    pass


class InvalidFileTypeError(MediaUploadError):
    """Raised when file type is not supported"""

    pass


class FileSizeTooLargeError(MediaUploadError):
    """Raised when file size exceeds maximum limit"""

    pass


class FileReadError(MediaUploadError):
    """Raised when there's an error reading the file"""

    pass


class StorageConfigError(MediaUploadError):
    """Raised when storage configuration is invalid"""

    pass


class StorageUploadError(MediaUploadError):
    """Raised when upload to storage fails"""

    pass


class StoreError(Exception):
    """Base exception for store-related errors"""

    pass


class AgentNotFoundError(StoreError):
    """Raised when an agent is not found"""

    pass


class CreatorNotFoundError(StoreError):
    """Raised when a creator is not found"""

    pass


class ListingExistsError(StoreError):
    """Raised when trying to create a listing that already exists"""

    pass


class DatabaseError(StoreError):
    """Raised when there is an error interacting with the database"""

    pass


class ProfileNotFoundError(StoreError):
    """Raised when a profile is not found"""

    pass


class SubmissionNotFoundError(StoreError):
    """Raised when a submission is not found"""

    pass
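One way a route layer can fold this hierarchy into HTTP responses; the status-code choices below are illustrative, not the mapping the store routes actually use:

import fastapi

import backend.server.v2.store.exceptions as store_exceptions

_NOT_FOUND = (
    store_exceptions.AgentNotFoundError,
    store_exceptions.CreatorNotFoundError,
    store_exceptions.ProfileNotFoundError,
    store_exceptions.SubmissionNotFoundError,
)


def to_http_exception(err: Exception) -> fastapi.HTTPException:
    # Not-found errors surface as 404s; any other store/media error
    # is treated as an internal failure.
    if isinstance(err, _NOT_FOUND):
        return fastapi.HTTPException(status_code=404, detail=str(err))
    return fastapi.HTTPException(status_code=500, detail="Internal server error")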
154
autogpt_platform/backend/backend/server/v2/store/media.py
Normal file
@@ -0,0 +1,154 @@
import logging
import os
import uuid

import fastapi
from google.cloud import storage

import backend.server.v2.store.exceptions
from backend.util.settings import Settings

logger = logging.getLogger(__name__)

ALLOWED_IMAGE_TYPES = {"image/jpeg", "image/png", "image/gif", "image/webp"}
ALLOWED_VIDEO_TYPES = {"video/mp4", "video/webm"}
MAX_FILE_SIZE = 50 * 1024 * 1024  # 50MB


async def upload_media(user_id: str, file: fastapi.UploadFile) -> str:

    # Get file content for deeper validation
    try:
        content = await file.read(1024)  # Read first 1KB for validation
        await file.seek(0)  # Reset file pointer
    except Exception as e:
        logger.error(f"Error reading file content: {str(e)}")
        raise backend.server.v2.store.exceptions.FileReadError(
            "Failed to read file content"
        ) from e

    # Validate file signature/magic bytes
    if file.content_type in ALLOWED_IMAGE_TYPES:
        # Check image file signatures
        if content.startswith(b"\xFF\xD8\xFF"):  # JPEG
            if file.content_type != "image/jpeg":
                raise backend.server.v2.store.exceptions.InvalidFileTypeError(
                    "File signature does not match content type"
                )
        elif content.startswith(b"\x89PNG\r\n\x1a\n"):  # PNG
            if file.content_type != "image/png":
                raise backend.server.v2.store.exceptions.InvalidFileTypeError(
                    "File signature does not match content type"
                )
        elif content.startswith(b"GIF87a") or content.startswith(b"GIF89a"):  # GIF
            if file.content_type != "image/gif":
                raise backend.server.v2.store.exceptions.InvalidFileTypeError(
                    "File signature does not match content type"
                )
        elif content.startswith(b"RIFF") and content[8:12] == b"WEBP":  # WebP
            if file.content_type != "image/webp":
                raise backend.server.v2.store.exceptions.InvalidFileTypeError(
                    "File signature does not match content type"
                )
        else:
            raise backend.server.v2.store.exceptions.InvalidFileTypeError(
                "Invalid image file signature"
            )

    elif file.content_type in ALLOWED_VIDEO_TYPES:
        # Check video file signatures
        if content.startswith(b"\x00\x00\x00") and (content[4:8] == b"ftyp"):  # MP4
            if file.content_type != "video/mp4":
                raise backend.server.v2.store.exceptions.InvalidFileTypeError(
                    "File signature does not match content type"
                )
        elif content.startswith(b"\x1a\x45\xdf\xa3"):  # WebM
            if file.content_type != "video/webm":
                raise backend.server.v2.store.exceptions.InvalidFileTypeError(
                    "File signature does not match content type"
                )
        else:
            raise backend.server.v2.store.exceptions.InvalidFileTypeError(
                "Invalid video file signature"
            )

    settings = Settings()

    # Check required settings first before doing any file processing
    if not settings.config.media_gcs_bucket_name:
        logger.error("Missing GCS bucket name setting")
        raise backend.server.v2.store.exceptions.StorageConfigError(
            "Missing storage bucket configuration"
        )

    try:
        # Validate file type
        content_type = file.content_type
        if (
            content_type not in ALLOWED_IMAGE_TYPES
            and content_type not in ALLOWED_VIDEO_TYPES
        ):
            logger.warning(f"Invalid file type attempted: {content_type}")
            raise backend.server.v2.store.exceptions.InvalidFileTypeError(
                f"File type not supported. Must be jpeg, png, gif, webp, mp4 or webm. Content type: {content_type}"
            )

        # Validate file size
        file_size = 0
        chunk_size = 8192  # 8KB chunks

        try:
            while chunk := await file.read(chunk_size):
                file_size += len(chunk)
                if file_size > MAX_FILE_SIZE:
                    logger.warning(f"File size too large: {file_size} bytes")
                    raise backend.server.v2.store.exceptions.FileSizeTooLargeError(
                        "File too large. Maximum size is 50MB"
                    )
        except backend.server.v2.store.exceptions.FileSizeTooLargeError:
            raise
        except Exception as e:
            logger.error(f"Error reading file chunks: {str(e)}")
            raise backend.server.v2.store.exceptions.FileReadError(
                "Failed to read uploaded file"
            ) from e

        # Reset file pointer
        await file.seek(0)

        # Generate unique filename
        filename = file.filename or ""
        file_ext = os.path.splitext(filename)[1].lower()
        unique_filename = f"{uuid.uuid4()}{file_ext}"

        # Construct storage path
        media_type = "images" if content_type in ALLOWED_IMAGE_TYPES else "videos"
        storage_path = f"users/{user_id}/{media_type}/{unique_filename}"

        try:
            storage_client = storage.Client()
            bucket = storage_client.bucket(settings.config.media_gcs_bucket_name)
            blob = bucket.blob(storage_path)
            blob.content_type = content_type

            file_bytes = await file.read()
            blob.upload_from_string(file_bytes, content_type=content_type)

            public_url = blob.public_url

            logger.info(f"Successfully uploaded file to: {storage_path}")
            return public_url

        except Exception as e:
            logger.error(f"GCS storage error: {str(e)}")
            raise backend.server.v2.store.exceptions.StorageUploadError(
                "Failed to upload file to storage"
            ) from e

    except backend.server.v2.store.exceptions.MediaUploadError:
        raise
    except Exception as e:
        logger.exception("Unexpected error in upload_media")
        raise backend.server.v2.store.exceptions.MediaUploadError(
            "Unexpected error during media upload"
        ) from e
190
autogpt_platform/backend/backend/server/v2/store/media_test.py
Normal file
@@ -0,0 +1,190 @@
import io
import unittest.mock

import fastapi
import pytest
import starlette.datastructures

import backend.server.v2.store.exceptions
import backend.server.v2.store.media
from backend.util.settings import Settings


@pytest.fixture
def mock_settings(monkeypatch):
    settings = Settings()
    settings.config.media_gcs_bucket_name = "test-bucket"
    settings.config.google_application_credentials = "test-credentials"
    monkeypatch.setattr("backend.server.v2.store.media.Settings", lambda: settings)
    return settings


@pytest.fixture
def mock_storage_client(mocker):
    mock_client = unittest.mock.MagicMock()
    mock_bucket = unittest.mock.MagicMock()
    mock_blob = unittest.mock.MagicMock()

    mock_client.bucket.return_value = mock_bucket
    mock_bucket.blob.return_value = mock_blob
    mock_blob.public_url = "http://test-url/media/laptop.jpeg"

    mocker.patch("google.cloud.storage.Client", return_value=mock_client)

    return mock_client


async def test_upload_media_success(mock_settings, mock_storage_client):
    # Create test JPEG data with valid signature
    test_data = b"\xFF\xD8\xFF" + b"test data"

    test_file = fastapi.UploadFile(
        filename="laptop.jpeg",
        file=io.BytesIO(test_data),
        headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
    )

    result = await backend.server.v2.store.media.upload_media("test-user", test_file)

    assert result == "http://test-url/media/laptop.jpeg"
    mock_bucket = mock_storage_client.bucket.return_value
    mock_blob = mock_bucket.blob.return_value
    mock_blob.upload_from_string.assert_called_once()


async def test_upload_media_invalid_type(mock_settings, mock_storage_client):
    test_file = fastapi.UploadFile(
        filename="test.txt",
        file=io.BytesIO(b"test data"),
        headers=starlette.datastructures.Headers({"content-type": "text/plain"}),
    )

    with pytest.raises(backend.server.v2.store.exceptions.InvalidFileTypeError):
        await backend.server.v2.store.media.upload_media("test-user", test_file)

    mock_bucket = mock_storage_client.bucket.return_value
    mock_blob = mock_bucket.blob.return_value
    mock_blob.upload_from_string.assert_not_called()


async def test_upload_media_missing_credentials(monkeypatch):
    settings = Settings()
    settings.config.media_gcs_bucket_name = ""
    settings.config.google_application_credentials = ""
    monkeypatch.setattr("backend.server.v2.store.media.Settings", lambda: settings)

    test_file = fastapi.UploadFile(
        filename="laptop.jpeg",
        file=io.BytesIO(b"\xFF\xD8\xFF" + b"test data"),  # Valid JPEG signature
        headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
    )

    with pytest.raises(backend.server.v2.store.exceptions.StorageConfigError):
        await backend.server.v2.store.media.upload_media("test-user", test_file)


async def test_upload_media_video_type(mock_settings, mock_storage_client):
    test_file = fastapi.UploadFile(
        filename="test.mp4",
        file=io.BytesIO(b"\x00\x00\x00\x18ftypmp42"),  # Valid MP4 signature
        headers=starlette.datastructures.Headers({"content-type": "video/mp4"}),
    )

    result = await backend.server.v2.store.media.upload_media("test-user", test_file)

    assert result == "http://test-url/media/laptop.jpeg"
    mock_bucket = mock_storage_client.bucket.return_value
    mock_blob = mock_bucket.blob.return_value
    mock_blob.upload_from_string.assert_called_once()


async def test_upload_media_file_too_large(mock_settings, mock_storage_client):
    large_data = b"\xFF\xD8\xFF" + b"x" * (
        50 * 1024 * 1024 + 1
    )  # 50MB + 1 byte with valid JPEG signature
    test_file = fastapi.UploadFile(
        filename="laptop.jpeg",
        file=io.BytesIO(large_data),
        headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
    )

    with pytest.raises(backend.server.v2.store.exceptions.FileSizeTooLargeError):
        await backend.server.v2.store.media.upload_media("test-user", test_file)


async def test_upload_media_file_read_error(mock_settings, mock_storage_client):
    test_file = fastapi.UploadFile(
        filename="laptop.jpeg",
        file=io.BytesIO(b""),  # Empty file that will raise error on read
        headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
    )
    test_file.read = unittest.mock.AsyncMock(side_effect=Exception("Read error"))

    with pytest.raises(backend.server.v2.store.exceptions.FileReadError):
        await backend.server.v2.store.media.upload_media("test-user", test_file)


async def test_upload_media_png_success(mock_settings, mock_storage_client):
    test_file = fastapi.UploadFile(
        filename="test.png",
        file=io.BytesIO(b"\x89PNG\r\n\x1a\n"),  # Valid PNG signature
        headers=starlette.datastructures.Headers({"content-type": "image/png"}),
    )

    result = await backend.server.v2.store.media.upload_media("test-user", test_file)
    assert result == "http://test-url/media/laptop.jpeg"


async def test_upload_media_gif_success(mock_settings, mock_storage_client):
    test_file = fastapi.UploadFile(
        filename="test.gif",
        file=io.BytesIO(b"GIF89a"),  # Valid GIF signature
        headers=starlette.datastructures.Headers({"content-type": "image/gif"}),
    )

    result = await backend.server.v2.store.media.upload_media("test-user", test_file)
    assert result == "http://test-url/media/laptop.jpeg"


async def test_upload_media_webp_success(mock_settings, mock_storage_client):
    test_file = fastapi.UploadFile(
        filename="test.webp",
        file=io.BytesIO(b"RIFF\x00\x00\x00\x00WEBP"),  # Valid WebP signature
        headers=starlette.datastructures.Headers({"content-type": "image/webp"}),
    )

    result = await backend.server.v2.store.media.upload_media("test-user", test_file)
    assert result == "http://test-url/media/laptop.jpeg"


async def test_upload_media_webm_success(mock_settings, mock_storage_client):
    test_file = fastapi.UploadFile(
        filename="test.webm",
        file=io.BytesIO(b"\x1a\x45\xdf\xa3"),  # Valid WebM signature
        headers=starlette.datastructures.Headers({"content-type": "video/webm"}),
    )

    result = await backend.server.v2.store.media.upload_media("test-user", test_file)
    assert result == "http://test-url/media/laptop.jpeg"


async def test_upload_media_mismatched_signature(mock_settings, mock_storage_client):
    test_file = fastapi.UploadFile(
        filename="test.jpeg",
        file=io.BytesIO(b"\x89PNG\r\n\x1a\n"),  # PNG signature with JPEG content type
        headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
    )

    with pytest.raises(backend.server.v2.store.exceptions.InvalidFileTypeError):
        await backend.server.v2.store.media.upload_media("test-user", test_file)


async def test_upload_media_invalid_signature(mock_settings, mock_storage_client):
    test_file = fastapi.UploadFile(
        filename="test.jpeg",
        file=io.BytesIO(b"invalid signature"),
        headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
    )

    with pytest.raises(backend.server.v2.store.exceptions.InvalidFileTypeError):
        await backend.server.v2.store.media.upload_media("test-user", test_file)
152
autogpt_platform/backend/backend/server/v2/store/model.py
Normal file
@@ -0,0 +1,152 @@
import datetime
from typing import List

import prisma.enums
import pydantic


class Pagination(pydantic.BaseModel):
    total_items: int = pydantic.Field(
        description="Total number of items.", examples=[42]
    )
    total_pages: int = pydantic.Field(
        description="Total number of pages.", examples=[97]
    )
    current_page: int = pydantic.Field(
        description="Current page number.", examples=[1]
    )
    page_size: int = pydantic.Field(
        description="Number of items per page.", examples=[25]
    )


class MyAgent(pydantic.BaseModel):
    agent_id: str
    agent_version: int
    agent_name: str
    last_edited: datetime.datetime


class MyAgentsResponse(pydantic.BaseModel):
    agents: list[MyAgent]
    pagination: Pagination


class StoreAgent(pydantic.BaseModel):
    slug: str
    agent_name: str
    agent_image: str
    creator: str
    creator_avatar: str
    sub_heading: str
    description: str
    runs: int
    rating: float


class StoreAgentsResponse(pydantic.BaseModel):
    agents: list[StoreAgent]
    pagination: Pagination


class StoreAgentDetails(pydantic.BaseModel):
    store_listing_version_id: str
    slug: str
    agent_name: str
    agent_video: str
    agent_image: list[str]
    creator: str
    creator_avatar: str
    sub_heading: str
    description: str
    categories: list[str]
    runs: int
    rating: float
    versions: list[str]
    last_updated: datetime.datetime


class Creator(pydantic.BaseModel):
    name: str
    username: str
    description: str
    avatar_url: str
    num_agents: int
    agent_rating: float
    agent_runs: int
    is_featured: bool


class CreatorsResponse(pydantic.BaseModel):
    creators: List[Creator]
    pagination: Pagination


class CreatorDetails(pydantic.BaseModel):
    name: str
    username: str
    description: str
    links: list[str]
    avatar_url: str
    agent_rating: float
    agent_runs: int
    top_categories: list[str]


class Profile(pydantic.BaseModel):
    name: str
    username: str
    description: str
    links: list[str]
    avatar_url: str
    is_featured: bool = False


class StoreSubmission(pydantic.BaseModel):
    agent_id: str
    agent_version: int
    name: str
    sub_heading: str
    slug: str
    description: str
    image_urls: list[str]
    date_submitted: datetime.datetime
    status: prisma.enums.SubmissionStatus
    runs: int
    rating: float


class StoreSubmissionsResponse(pydantic.BaseModel):
    submissions: list[StoreSubmission]
    pagination: Pagination


class StoreSubmissionRequest(pydantic.BaseModel):
    agent_id: str
    agent_version: int
    slug: str
    name: str
    sub_heading: str
    video_url: str | None = None
    image_urls: list[str] = []
    description: str = ""
    categories: list[str] = []


class ProfileDetails(pydantic.BaseModel):
    name: str
    username: str
    description: str
    links: list[str]
    avatar_url: str | None = None


class StoreReview(pydantic.BaseModel):
    score: int
    comments: str | None = None


class StoreReviewCreate(pydantic.BaseModel):
    store_listing_version_id: str
    score: int
    comments: str | None = None
195
autogpt_platform/backend/backend/server/v2/store/model_test.py
Normal file
@@ -0,0 +1,195 @@
import datetime

import prisma.enums

import backend.server.v2.store.model


def test_pagination():
    pagination = backend.server.v2.store.model.Pagination(
        total_items=100, total_pages=5, current_page=2, page_size=20
    )
    assert pagination.total_items == 100
    assert pagination.total_pages == 5
    assert pagination.current_page == 2
    assert pagination.page_size == 20


def test_store_agent():
    agent = backend.server.v2.store.model.StoreAgent(
        slug="test-agent",
        agent_name="Test Agent",
        agent_image="test.jpg",
        creator="creator1",
        creator_avatar="avatar.jpg",
        sub_heading="Test subheading",
        description="Test description",
        runs=50,
        rating=4.5,
    )
    assert agent.slug == "test-agent"
    assert agent.agent_name == "Test Agent"
    assert agent.runs == 50
    assert agent.rating == 4.5


def test_store_agents_response():
    response = backend.server.v2.store.model.StoreAgentsResponse(
        agents=[
            backend.server.v2.store.model.StoreAgent(
                slug="test-agent",
                agent_name="Test Agent",
                agent_image="test.jpg",
                creator="creator1",
                creator_avatar="avatar.jpg",
                sub_heading="Test subheading",
                description="Test description",
                runs=50,
                rating=4.5,
            )
        ],
        pagination=backend.server.v2.store.model.Pagination(
            total_items=1, total_pages=1, current_page=1, page_size=20
        ),
    )
    assert len(response.agents) == 1
    assert response.pagination.total_items == 1


def test_store_agent_details():
    details = backend.server.v2.store.model.StoreAgentDetails(
        store_listing_version_id="version123",
        slug="test-agent",
        agent_name="Test Agent",
        agent_video="video.mp4",
        agent_image=["image1.jpg", "image2.jpg"],
        creator="creator1",
        creator_avatar="avatar.jpg",
        sub_heading="Test subheading",
        description="Test description",
        categories=["cat1", "cat2"],
        runs=50,
        rating=4.5,
        versions=["1.0", "2.0"],
        last_updated=datetime.datetime.now(),
    )
    assert details.slug == "test-agent"
    assert len(details.agent_image) == 2
    assert len(details.categories) == 2
    assert len(details.versions) == 2


def test_creator():
    creator = backend.server.v2.store.model.Creator(
        agent_rating=4.8,
        agent_runs=1000,
        name="Test Creator",
        username="creator1",
        description="Test description",
        avatar_url="avatar.jpg",
        num_agents=5,
        is_featured=False,
    )
    assert creator.name == "Test Creator"
    assert creator.num_agents == 5


def test_creators_response():
    response = backend.server.v2.store.model.CreatorsResponse(
        creators=[
            backend.server.v2.store.model.Creator(
                agent_rating=4.8,
                agent_runs=1000,
                name="Test Creator",
                username="creator1",
                description="Test description",
                avatar_url="avatar.jpg",
                num_agents=5,
                is_featured=False,
            )
        ],
        pagination=backend.server.v2.store.model.Pagination(
            total_items=1, total_pages=1, current_page=1, page_size=20
        ),
    )
    assert len(response.creators) == 1
    assert response.pagination.total_items == 1


def test_creator_details():
    details = backend.server.v2.store.model.CreatorDetails(
        name="Test Creator",
        username="creator1",
        description="Test description",
        links=["link1.com", "link2.com"],
        avatar_url="avatar.jpg",
        agent_rating=4.8,
        agent_runs=1000,
        top_categories=["cat1", "cat2"],
    )
    assert details.name == "Test Creator"
    assert len(details.links) == 2
    assert details.agent_rating == 4.8
    assert len(details.top_categories) == 2


def test_store_submission():
    submission = backend.server.v2.store.model.StoreSubmission(
        agent_id="agent123",
        agent_version=1,
        sub_heading="Test subheading",
        name="Test Agent",
        slug="test-agent",
        description="Test description",
        image_urls=["image1.jpg", "image2.jpg"],
        date_submitted=datetime.datetime(2023, 1, 1),
        status=prisma.enums.SubmissionStatus.PENDING,
        runs=50,
        rating=4.5,
    )
    assert submission.name == "Test Agent"
    assert len(submission.image_urls) == 2
    assert submission.status == prisma.enums.SubmissionStatus.PENDING


def test_store_submissions_response():
    response = backend.server.v2.store.model.StoreSubmissionsResponse(
        submissions=[
            backend.server.v2.store.model.StoreSubmission(
                agent_id="agent123",
                agent_version=1,
                sub_heading="Test subheading",
                name="Test Agent",
                slug="test-agent",
                description="Test description",
                image_urls=["image1.jpg"],
                date_submitted=datetime.datetime(2023, 1, 1),
                status=prisma.enums.SubmissionStatus.PENDING,
                runs=50,
                rating=4.5,
            )
        ],
        pagination=backend.server.v2.store.model.Pagination(
            total_items=1, total_pages=1, current_page=1, page_size=20
        ),
    )
    assert len(response.submissions) == 1
    assert response.pagination.total_items == 1


def test_store_submission_request():
    request = backend.server.v2.store.model.StoreSubmissionRequest(
        agent_id="agent123",
        agent_version=1,
        slug="test-agent",
        name="Test Agent",
        sub_heading="Test subheading",
        video_url="video.mp4",
        image_urls=["image1.jpg", "image2.jpg"],
        description="Test description",
        categories=["cat1", "cat2"],
    )
    assert request.agent_id == "agent123"
    assert request.agent_version == 1
    assert len(request.image_urls) == 2
    assert len(request.categories) == 2
441
autogpt_platform/backend/backend/server/v2/store/routes.py
Normal file
@@ -0,0 +1,441 @@
import logging
import typing

import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import fastapi
import fastapi.responses

import backend.server.v2.store.db
import backend.server.v2.store.media
import backend.server.v2.store.model

logger = logging.getLogger(__name__)

router = fastapi.APIRouter()


##############################################
############### Profile Endpoints ############
##############################################


@router.get("/profile", tags=["store", "private"])
async def get_profile(
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ]
) -> backend.server.v2.store.model.ProfileDetails:
    """
    Get the profile details for the authenticated user.
    """
    try:
        profile = await backend.server.v2.store.db.get_user_profile(user_id)
        return profile
    except Exception:
        logger.exception("Exception occurred whilst getting user profile")
        raise


@router.post(
    "/profile",
    tags=["store", "private"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def update_or_create_profile(
    profile: backend.server.v2.store.model.Profile,
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ],
) -> backend.server.v2.store.model.CreatorDetails:
    """
    Update the store profile for the authenticated user.

    Args:
        profile (Profile): The updated profile details
        user_id (str): ID of the authenticated user

    Returns:
        CreatorDetails: The updated profile

    Raises:
        HTTPException: If there is an error updating the profile
    """
    try:
        updated_profile = await backend.server.v2.store.db.update_or_create_profile(
            user_id=user_id, profile=profile
        )
        return updated_profile
    except Exception:
        logger.exception("Exception occurred whilst updating profile")
        raise


##############################################
############### Agent Endpoints ##############
##############################################


@router.get("/agents", tags=["store", "public"])
async def get_agents(
    featured: bool = False,
    creator: str | None = None,
    sorted_by: str | None = None,
    search_query: str | None = None,
    category: str | None = None,
    page: int = 1,
    page_size: int = 20,
) -> backend.server.v2.store.model.StoreAgentsResponse:
    """
    Get a paginated list of agents from the store with optional filtering and sorting.

    Args:
        featured (bool, optional): Filter to only show featured agents. Defaults to False.
        creator (str | None, optional): Filter agents by creator username. Defaults to None.
        sorted_by (str | None, optional): Sort agents by "runs" or "rating". Defaults to None.
        search_query (str | None, optional): Search agents by name, subheading and description. Defaults to None.
        category (str | None, optional): Filter agents by category. Defaults to None.
        page (int, optional): Page number for pagination. Defaults to 1.
        page_size (int, optional): Number of agents per page. Defaults to 20.

    Returns:
        StoreAgentsResponse: Paginated list of agents matching the filters

    Raises:
        HTTPException: If page or page_size are less than 1

    Used for:
        - Home Page Featured Agents
        - Home Page Top Agents
        - Search Results
        - Agent Details - Other Agents By Creator
        - Agent Details - Similar Agents
        - Creator Details - Agents By Creator
    """
    if page < 1:
        raise fastapi.HTTPException(
            status_code=422, detail="Page must be greater than 0"
        )

    if page_size < 1:
        raise fastapi.HTTPException(
            status_code=422, detail="Page size must be greater than 0"
        )

    try:
        agents = await backend.server.v2.store.db.get_store_agents(
            featured=featured,
            creator=creator,
            sorted_by=sorted_by,
            search_query=search_query,
            category=category,
            page=page,
            page_size=page_size,
        )
        return agents
    except Exception:
        logger.exception("Exception occurred whilst getting store agents")
        raise


@router.get("/agents/{username}/{agent_name}", tags=["store", "public"])
async def get_agent(
    username: str, agent_name: str
) -> backend.server.v2.store.model.StoreAgentDetails:
    """
    This is only used on the AgentDetails Page

    It returns the store listing agents details.
    """
    try:
        agent = await backend.server.v2.store.db.get_store_agent_details(
            username=username, agent_name=agent_name
        )
        return agent
    except Exception:
        logger.exception("Exception occurred whilst getting store agent details")
        raise


@router.post(
    "/agents/{username}/{agent_name}/review",
    tags=["store"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def create_review(
    username: str,
    agent_name: str,
    review: backend.server.v2.store.model.StoreReviewCreate,
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ],
) -> backend.server.v2.store.model.StoreReview:
    """
    Create a review for a store agent.

    Args:
        username: Creator's username
        agent_name: Name/slug of the agent
        review: Review details including score and optional comments
        user_id: ID of authenticated user creating the review

    Returns:
        The created review
    """
    try:
        # Create the review
        created_review = await backend.server.v2.store.db.create_store_review(
            user_id=user_id,
            store_listing_version_id=review.store_listing_version_id,
            score=review.score,
            comments=review.comments,
        )

        return created_review
    except Exception:
        logger.exception("Exception occurred whilst creating store review")
        raise


##############################################
############# Creator Endpoints #############
##############################################


@router.get("/creators", tags=["store", "public"])
async def get_creators(
    featured: bool = False,
    search_query: str | None = None,
    sorted_by: str | None = None,
    page: int = 1,
    page_size: int = 20,
) -> backend.server.v2.store.model.CreatorsResponse:
    """
    This is needed for:
    - Home Page Featured Creators
    - Search Results Page

    ---

    To support this functionality we need:
    - featured: bool - to limit the list to just featured agents
    - search_query: str - vector search based on the creators profile description.
    - sorted_by: [agent_rating, agent_runs]
    """
    if page < 1:
        raise fastapi.HTTPException(
            status_code=422, detail="Page must be greater than 0"
        )

    if page_size < 1:
        raise fastapi.HTTPException(
            status_code=422, detail="Page size must be greater than 0"
        )

    try:
        creators = await backend.server.v2.store.db.get_store_creators(
            featured=featured,
            search_query=search_query,
            sorted_by=sorted_by,
            page=page,
            page_size=page_size,
        )
        return creators
    except Exception:
        logger.exception("Exception occurred whilst getting store creators")
        raise


@router.get("/creator/{username}", tags=["store", "public"])
async def get_creator(username: str) -> backend.server.v2.store.model.CreatorDetails:
    """
    Get the details of a creator
    - Creator Details Page
    """
    try:
        creator = await backend.server.v2.store.db.get_store_creator_details(
            username=username
        )
        return creator
    except Exception:
        logger.exception("Exception occurred whilst getting creator details")
        raise


############################################
############ Store Submissions #############
############################################


@router.get(
    "/myagents",
    tags=["store", "private"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def get_my_agents(
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ]
) -> backend.server.v2.store.model.MyAgentsResponse:
    try:
        agents = await backend.server.v2.store.db.get_my_agents(user_id)
        return agents
    except Exception:
        logger.exception("Exception occurred whilst getting my agents")
        raise


@router.delete(
    "/submissions/{submission_id}",
    tags=["store", "private"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def delete_submission(
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ],
    submission_id: str,
) -> bool:
    """
    Delete a store listing submission.

    Args:
        user_id (str): ID of the authenticated user
        submission_id (str): ID of the submission to be deleted

    Returns:
        bool: True if the submission was successfully deleted, False otherwise
    """
    try:
        result = await backend.server.v2.store.db.delete_store_submission(
            user_id=user_id,
            submission_id=submission_id,
        )
        return result
    except Exception:
        logger.exception("Exception occurred whilst deleting store submission")
        raise


@router.get(
    "/submissions",
    tags=["store", "private"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def get_submissions(
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ],
    page: int = 1,
    page_size: int = 20,
) -> backend.server.v2.store.model.StoreSubmissionsResponse:
    """
    Get a paginated list of store submissions for the authenticated user.

    Args:
        user_id (str): ID of the authenticated user
        page (int, optional): Page number for pagination. Defaults to 1.
        page_size (int, optional): Number of submissions per page. Defaults to 20.

    Returns:
        StoreSubmissionsResponse: Paginated list of store submissions

    Raises:
        HTTPException: If page or page_size are less than 1
    """
    if page < 1:
        raise fastapi.HTTPException(
            status_code=422, detail="Page must be greater than 0"
        )

    if page_size < 1:
        raise fastapi.HTTPException(
            status_code=422, detail="Page size must be greater than 0"
        )
    try:
        listings = await backend.server.v2.store.db.get_store_submissions(
            user_id=user_id,
            page=page,
            page_size=page_size,
        )
        return listings
    except Exception:
        logger.exception("Exception occurred whilst getting store submissions")
        raise


@router.post(
    "/submissions",
    tags=["store", "private"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def create_submission(
    submission_request: backend.server.v2.store.model.StoreSubmissionRequest,
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ],
) -> backend.server.v2.store.model.StoreSubmission:
    """
    Create a new store listing submission.

    Args:
        submission_request (StoreSubmissionRequest): The submission details
        user_id (str): ID of the authenticated user submitting the listing

    Returns:
        StoreSubmission: The created store submission

    Raises:
        HTTPException: If there is an error creating the submission
    """
    try:
        submission = await backend.server.v2.store.db.create_store_submission(
            user_id=user_id,
            agent_id=submission_request.agent_id,
            agent_version=submission_request.agent_version,
            slug=submission_request.slug,
            name=submission_request.name,
            video_url=submission_request.video_url,
            image_urls=submission_request.image_urls,
            description=submission_request.description,
            sub_heading=submission_request.sub_heading,
            categories=submission_request.categories,
        )
        return submission
    except Exception:
        logger.exception("Exception occurred whilst creating store submission")
        raise


@router.post(
    "/submissions/media",
    tags=["store", "private"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def upload_submission_media(
    file: fastapi.UploadFile,
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ],
) -> str:
    """
    Upload media (images/videos) for a store listing submission.

    Args:
        file (UploadFile): The media file to upload
        user_id (str): ID of the authenticated user uploading the media

    Returns:
        str: URL of the uploaded media file

    Raises:
        HTTPException: If there is an error uploading the media
    """
    try:
        media_url = await backend.server.v2.store.media.upload_media(
            user_id=user_id, file=file
        )
        return media_url
    except Exception as e:
        logger.exception("Exception occurred whilst uploading submission media")
        raise fastapi.HTTPException(
            status_code=500, detail=f"Failed to upload media file: {str(e)}"
        )
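As an illustration of how these routes can be exercised, here is a minimal, hypothetical sketch (not part of this diff) that calls the media upload endpoint with FastAPI's TestClient, using the same dependency-override pattern as routes_test.py below; the patched return URL and filenames are made up.

import io
import unittest.mock

import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import fastapi
import fastapi.testclient

import backend.server.v2.store.routes

app = fastapi.FastAPI()
app.include_router(backend.server.v2.store.routes.router)
# Bypass real authentication, mirroring the overrides used in routes_test.py
app.dependency_overrides[autogpt_libs.auth.middleware.auth_middleware] = lambda: {
    "sub": "test-user-id"
}
app.dependency_overrides[autogpt_libs.auth.depends.get_user_id] = lambda: "test-user-id"
client = fastapi.testclient.TestClient(app)

# Patch out GCS so no real upload happens; the URL below is a placeholder
with unittest.mock.patch(
    "backend.server.v2.store.media.upload_media",
    new=unittest.mock.AsyncMock(return_value="http://test-url/media/example.png"),
):
    response = client.post(
        "/submissions/media",
        files={"file": ("example.png", io.BytesIO(b"\x89PNG\r\n\x1a\n"), "image/png")},
    )
    assert response.status_code == 200
    assert response.json() == "http://test-url/media/example.png"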
552
autogpt_platform/backend/backend/server/v2/store/routes_test.py
Normal file
@@ -0,0 +1,552 @@
import datetime

import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import fastapi
import fastapi.testclient
import prisma.enums
import pytest_mock

import backend.server.v2.store.model
import backend.server.v2.store.routes

app = fastapi.FastAPI()
app.include_router(backend.server.v2.store.routes.router)

client = fastapi.testclient.TestClient(app)


def override_auth_middleware():
    """Override auth middleware for testing"""
    return {"sub": "test-user-id"}


def override_get_user_id():
    """Override get_user_id for testing"""
    return "test-user-id"


app.dependency_overrides[autogpt_libs.auth.middleware.auth_middleware] = (
    override_auth_middleware
)
app.dependency_overrides[autogpt_libs.auth.depends.get_user_id] = override_get_user_id


def test_get_agents_defaults(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
        agents=[],
        pagination=backend.server.v2.store.model.Pagination(
            current_page=0,
            total_items=0,
            total_pages=0,
            page_size=10,
        ),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
    mock_db_call.return_value = mocked_value
    response = client.get("/agents")
    assert response.status_code == 200

    data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
        response.json()
    )
    assert data.pagination.total_pages == 0
    assert data.agents == []
    mock_db_call.assert_called_once_with(
        featured=False,
        creator=None,
        sorted_by=None,
        search_query=None,
        category=None,
        page=1,
        page_size=20,
    )


def test_get_agents_featured(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
        agents=[
            backend.server.v2.store.model.StoreAgent(
                slug="featured-agent",
                agent_name="Featured Agent",
                agent_image="featured.jpg",
                creator="creator1",
                creator_avatar="avatar1.jpg",
                sub_heading="Featured agent subheading",
                description="Featured agent description",
                runs=100,
                rating=4.5,
            )
        ],
        pagination=backend.server.v2.store.model.Pagination(
            current_page=1,
            total_items=1,
            total_pages=1,
            page_size=20,
        ),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
    mock_db_call.return_value = mocked_value
    response = client.get("/agents?featured=true")
    assert response.status_code == 200
    data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
        response.json()
    )
    assert len(data.agents) == 1
    assert data.agents[0].slug == "featured-agent"
    mock_db_call.assert_called_once_with(
        featured=True,
        creator=None,
        sorted_by=None,
        search_query=None,
        category=None,
        page=1,
        page_size=20,
    )


def test_get_agents_by_creator(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
        agents=[
            backend.server.v2.store.model.StoreAgent(
                slug="creator-agent",
                agent_name="Creator Agent",
                agent_image="agent.jpg",
                creator="specific-creator",
                creator_avatar="avatar.jpg",
                sub_heading="Creator agent subheading",
                description="Creator agent description",
                runs=50,
                rating=4.0,
            )
        ],
        pagination=backend.server.v2.store.model.Pagination(
            current_page=1,
            total_items=1,
            total_pages=1,
            page_size=20,
        ),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
    mock_db_call.return_value = mocked_value
    response = client.get("/agents?creator=specific-creator")
    assert response.status_code == 200
    data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
        response.json()
    )
    assert len(data.agents) == 1
    assert data.agents[0].creator == "specific-creator"
    mock_db_call.assert_called_once_with(
        featured=False,
        creator="specific-creator",
        sorted_by=None,
        search_query=None,
        category=None,
        page=1,
        page_size=20,
    )


def test_get_agents_sorted(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
        agents=[
            backend.server.v2.store.model.StoreAgent(
                slug="top-agent",
                agent_name="Top Agent",
                agent_image="top.jpg",
                creator="creator1",
                creator_avatar="avatar1.jpg",
                sub_heading="Top agent subheading",
                description="Top agent description",
                runs=1000,
                rating=5.0,
            )
        ],
        pagination=backend.server.v2.store.model.Pagination(
            current_page=1,
            total_items=1,
            total_pages=1,
            page_size=20,
        ),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
    mock_db_call.return_value = mocked_value
    response = client.get("/agents?sorted_by=runs")
    assert response.status_code == 200
    data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
        response.json()
    )
    assert len(data.agents) == 1
    assert data.agents[0].runs == 1000
    mock_db_call.assert_called_once_with(
        featured=False,
        creator=None,
        sorted_by="runs",
        search_query=None,
        category=None,
        page=1,
        page_size=20,
    )


def test_get_agents_search(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
        agents=[
            backend.server.v2.store.model.StoreAgent(
                slug="search-agent",
                agent_name="Search Agent",
                agent_image="search.jpg",
                creator="creator1",
                creator_avatar="avatar1.jpg",
                sub_heading="Search agent subheading",
                description="Specific search term description",
                runs=75,
                rating=4.2,
            )
        ],
        pagination=backend.server.v2.store.model.Pagination(
            current_page=1,
            total_items=1,
            total_pages=1,
            page_size=20,
        ),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
    mock_db_call.return_value = mocked_value
    response = client.get("/agents?search_query=specific")
    assert response.status_code == 200
    data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
        response.json()
    )
    assert len(data.agents) == 1
    assert "specific" in data.agents[0].description.lower()
    mock_db_call.assert_called_once_with(
        featured=False,
        creator=None,
        sorted_by=None,
        search_query="specific",
        category=None,
        page=1,
        page_size=20,
    )


def test_get_agents_category(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
        agents=[
            backend.server.v2.store.model.StoreAgent(
                slug="category-agent",
                agent_name="Category Agent",
                agent_image="category.jpg",
                creator="creator1",
                creator_avatar="avatar1.jpg",
                sub_heading="Category agent subheading",
                description="Category agent description",
                runs=60,
                rating=4.1,
            )
        ],
        pagination=backend.server.v2.store.model.Pagination(
            current_page=1,
            total_items=1,
            total_pages=1,
            page_size=20,
        ),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
    mock_db_call.return_value = mocked_value
    response = client.get("/agents?category=test-category")
    assert response.status_code == 200
    data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
        response.json()
    )
    assert len(data.agents) == 1
    mock_db_call.assert_called_once_with(
        featured=False,
        creator=None,
        sorted_by=None,
        search_query=None,
        category="test-category",
        page=1,
        page_size=20,
    )


def test_get_agents_pagination(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
        agents=[
            backend.server.v2.store.model.StoreAgent(
                slug=f"agent-{i}",
                agent_name=f"Agent {i}",
                agent_image=f"agent{i}.jpg",
                creator="creator1",
                creator_avatar="avatar1.jpg",
                sub_heading=f"Agent {i} subheading",
                description=f"Agent {i} description",
                runs=i * 10,
                rating=4.0,
            )
            for i in range(5)
        ],
        pagination=backend.server.v2.store.model.Pagination(
            current_page=2,
            total_items=15,
            total_pages=3,
            page_size=5,
        ),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
    mock_db_call.return_value = mocked_value
    response = client.get("/agents?page=2&page_size=5")
    assert response.status_code == 200
    data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
        response.json()
    )
    assert len(data.agents) == 5
    assert data.pagination.current_page == 2
    assert data.pagination.page_size == 5
    mock_db_call.assert_called_once_with(
        featured=False,
        creator=None,
        sorted_by=None,
        search_query=None,
        category=None,
        page=2,
        page_size=5,
    )


def test_get_agents_malformed_request(mocker: pytest_mock.MockFixture):
    # Test with invalid page number
    response = client.get("/agents?page=-1")
    assert response.status_code == 422

    # Test with invalid page size
    response = client.get("/agents?page_size=0")
    assert response.status_code == 422

    # Test with non-numeric values
    response = client.get("/agents?page=abc&page_size=def")
    assert response.status_code == 422

    # Verify no DB calls were made
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
    mock_db_call.assert_not_called()


def test_get_agent_details(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.StoreAgentDetails(
        store_listing_version_id="test-version-id",
        slug="test-agent",
        agent_name="Test Agent",
        agent_video="video.mp4",
        agent_image=["image1.jpg", "image2.jpg"],
        creator="creator1",
        creator_avatar="avatar1.jpg",
        sub_heading="Test agent subheading",
        description="Test agent description",
        categories=["category1", "category2"],
        runs=100,
        rating=4.5,
        versions=["1.0.0", "1.1.0"],
        last_updated=datetime.datetime.now(),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agent_details")
    mock_db_call.return_value = mocked_value

    response = client.get("/agents/creator1/test-agent")
    assert response.status_code == 200

    data = backend.server.v2.store.model.StoreAgentDetails.model_validate(
        response.json()
    )
    assert data.agent_name == "Test Agent"
    assert data.creator == "creator1"
    mock_db_call.assert_called_once_with(username="creator1", agent_name="test-agent")


def test_get_creators_defaults(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.CreatorsResponse(
        creators=[],
        pagination=backend.server.v2.store.model.Pagination(
            current_page=0,
            total_items=0,
            total_pages=0,
            page_size=10,
        ),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_creators")
    mock_db_call.return_value = mocked_value

    response = client.get("/creators")
    assert response.status_code == 200

    data = backend.server.v2.store.model.CreatorsResponse.model_validate(
        response.json()
    )
    assert data.pagination.total_pages == 0
    assert data.creators == []
    mock_db_call.assert_called_once_with(
        featured=False, search_query=None, sorted_by=None, page=1, page_size=20
    )


def test_get_creators_pagination(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.CreatorsResponse(
        creators=[
            backend.server.v2.store.model.Creator(
                name=f"Creator {i}",
                username=f"creator{i}",
                description=f"Creator {i} description",
                avatar_url=f"avatar{i}.jpg",
                num_agents=1,
                agent_rating=4.5,
                agent_runs=100,
                is_featured=False,
            )
            for i in range(5)
        ],
        pagination=backend.server.v2.store.model.Pagination(
            current_page=2,
            total_items=15,
            total_pages=3,
            page_size=5,
        ),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_creators")
    mock_db_call.return_value = mocked_value

    response = client.get("/creators?page=2&page_size=5")
    assert response.status_code == 200

    data = backend.server.v2.store.model.CreatorsResponse.model_validate(
        response.json()
    )
    assert len(data.creators) == 5
    assert data.pagination.current_page == 2
    assert data.pagination.page_size == 5
    mock_db_call.assert_called_once_with(
        featured=False, search_query=None, sorted_by=None, page=2, page_size=5
    )


def test_get_creators_malformed_request(mocker: pytest_mock.MockFixture):
    # Test with invalid page number
    response = client.get("/creators?page=-1")
    assert response.status_code == 422

    # Test with invalid page size
    response = client.get("/creators?page_size=0")
    assert response.status_code == 422

    # Test with non-numeric values
    response = client.get("/creators?page=abc&page_size=def")
    assert response.status_code == 422

    # Verify no DB calls were made
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_creators")
    mock_db_call.assert_not_called()


def test_get_creator_details(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.CreatorDetails(
        name="Test User",
        username="creator1",
        description="Test creator description",
        links=["link1.com", "link2.com"],
        avatar_url="avatar.jpg",
        agent_rating=4.8,
        agent_runs=1000,
        top_categories=["category1", "category2"],
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_creator_details")
    mock_db_call.return_value = mocked_value

    response = client.get("/creator/creator1")
    assert response.status_code == 200

    data = backend.server.v2.store.model.CreatorDetails.model_validate(response.json())
    assert data.username == "creator1"
    assert data.name == "Test User"
    mock_db_call.assert_called_once_with(username="creator1")


def test_get_submissions_success(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.StoreSubmissionsResponse(
        submissions=[
            backend.server.v2.store.model.StoreSubmission(
                name="Test Agent",
                description="Test agent description",
                image_urls=["test.jpg"],
                date_submitted=datetime.datetime.now(),
                status=prisma.enums.SubmissionStatus.APPROVED,
                runs=50,
                rating=4.2,
                agent_id="test-agent-id",
                agent_version=1,
                sub_heading="Test agent subheading",
                slug="test-agent",
            )
        ],
        pagination=backend.server.v2.store.model.Pagination(
            current_page=1,
            total_items=1,
            total_pages=1,
            page_size=20,
        ),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_submissions")
    mock_db_call.return_value = mocked_value

    response = client.get("/submissions")
    assert response.status_code == 200

    data = backend.server.v2.store.model.StoreSubmissionsResponse.model_validate(
        response.json()
    )
    assert len(data.submissions) == 1
    assert data.submissions[0].name == "Test Agent"
    assert data.pagination.current_page == 1
    mock_db_call.assert_called_once_with(user_id="test-user-id", page=1, page_size=20)


def test_get_submissions_pagination(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.StoreSubmissionsResponse(
        submissions=[],
        pagination=backend.server.v2.store.model.Pagination(
            current_page=2,
            total_items=10,
            total_pages=2,
            page_size=5,
        ),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_submissions")
    mock_db_call.return_value = mocked_value

    response = client.get("/submissions?page=2&page_size=5")
    assert response.status_code == 200

    data = backend.server.v2.store.model.StoreSubmissionsResponse.model_validate(
        response.json()
    )
    assert data.pagination.current_page == 2
    assert data.pagination.page_size == 5
    mock_db_call.assert_called_once_with(user_id="test-user-id", page=2, page_size=5)


def test_get_submissions_malformed_request(mocker: pytest_mock.MockFixture):
    # Test with invalid page number
    response = client.get("/submissions?page=-1")
    assert response.status_code == 422

    # Test with invalid page size
    response = client.get("/submissions?page_size=0")
    assert response.status_code == 422

    # Test with non-numeric values
    response = client.get("/submissions?page=abc&page_size=def")
    assert response.status_code == 422

    # Verify no DB calls were made
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_submissions")
    mock_db_call.assert_not_called()
@@ -1,8 +1,10 @@
|
||||
import ipaddress
|
||||
import re
|
||||
import socket
|
||||
from typing import Callable
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
import idna
|
||||
import requests as req
|
||||
|
||||
from backend.util.settings import Config

@@ -21,8 +23,23 @@ BLOCKED_IP_NETWORKS = [
    # --8<-- [end:BLOCKED_IP_NETWORKS]
]

ALLOWED_SCHEMES = ["http", "https"]
HOSTNAME_REGEX = re.compile(r"^[A-Za-z0-9.-]+$")  # Basic DNS-safe hostname pattern


def is_ip_blocked(ip: str) -> bool:


def _canonicalize_url(url: str) -> str:
    # Strip spaces and trailing slashes
    url = url.strip().strip("/")
    # Ensure the URL starts with http:// or https://
    if not url.startswith(("http://", "https://")):
        url = "http://" + url

    # Replace backslashes with forward slashes to avoid parsing ambiguities
    url = url.replace("\\", "/")
    return url


def _is_ip_blocked(ip: str) -> bool:
    """
    Checks if the IP address is in a blocked network.
    """

@@ -35,29 +52,51 @@ def validate_url(url: str, trusted_origins: list[str]) -> str:
    Validates the URL to prevent SSRF attacks by ensuring it does not point to a private
    or untrusted IP address, unless whitelisted.
    """
    url = url.strip().strip("/")
    if not url.startswith(("http://", "https://")):
        url = "http://" + url
    url = _canonicalize_url(url)
    parsed = urlparse(url)

    parsed_url = urlparse(url)
    hostname = parsed_url.hostname
    # Check scheme
    if parsed.scheme not in ALLOWED_SCHEMES:
        raise ValueError(
            f"Scheme '{parsed.scheme}' is not allowed. Only HTTP/HTTPS are supported."
        )

    if not hostname:
        raise ValueError(f"Invalid URL: Unable to determine hostname from {url}")
    # Validate and IDNA encode the hostname
    if not parsed.hostname:
        raise ValueError("Invalid URL: No hostname found.")

    if any(hostname == origin for origin in trusted_origins):
    # IDNA encode to prevent Unicode domain attacks
    try:
        ascii_hostname = idna.encode(parsed.hostname).decode("ascii")
    except idna.IDNAError:
        raise ValueError("Invalid hostname with unsupported characters.")

    # Check hostname characters
    if not HOSTNAME_REGEX.match(ascii_hostname):
        raise ValueError("Hostname contains invalid characters.")

    # Rebuild the URL with the normalized, IDNA-encoded hostname
    parsed = parsed._replace(netloc=ascii_hostname)
    url = str(urlunparse(parsed))

    # Check if hostname is a trusted origin (exact match)
    if ascii_hostname in trusted_origins:
        return url

    # Resolve all IP addresses for the hostname
    ip_addresses = {result[4][0] for result in socket.getaddrinfo(hostname, None)}
    if not ip_addresses:
        raise ValueError(f"Unable to resolve IP address for {hostname}")
    try:
        ip_addresses = {res[4][0] for res in socket.getaddrinfo(ascii_hostname, None)}
    except socket.gaierror:
        raise ValueError(f"Unable to resolve IP address for hostname {ascii_hostname}")

    # Check if all IP addresses are global
    if not ip_addresses:
        raise ValueError(f"No IP addresses found for {ascii_hostname}")

    # Check if any resolved IP address falls into blocked ranges
    for ip in ip_addresses:
        if is_ip_blocked(ip):
        if _is_ip_blocked(ip):
            raise ValueError(
                f"Access to private IP address at {hostname}: {ip} is not allowed."
                f"Access to private IP address {ip} for hostname {ascii_hostname} is not allowed."
            )

    return url

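For orientation, a minimal usage sketch of the hardened validator above. The import path matches the file being diffed, but the origin list and URLs are illustrative assumptions, and exactly which addresses are rejected depends on the configured BLOCKED_IP_NETWORKS:

from backend.util.request import validate_url

trusted_origins = ["internal-api.example.com"]  # hypothetical whitelist entry

# Bare hostnames are canonicalized to http://, IDNA-encoded, and resolved before use
print(validate_url("example.com/webhook", trusted_origins))

try:
    # Loopback/private addresses fall inside the blocked ranges and are rejected
    validate_url("http://127.0.0.1:8080/admin", trusted_origins)
except ValueError as err:
    print(err)
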
@@ -148,6 +148,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
        "This value is then used to generate redirect URLs for OAuth flows.",
    )

    media_gcs_bucket_name: str = Field(
        default="",
        description="The name of the Google Cloud Storage bucket for media files",
    )

    @field_validator("platform_base_url", "frontend_base_url")
    @classmethod
    def validate_platform_base_url(cls, v: str, info: ValidationInfo) -> str:

@@ -60,9 +60,7 @@ async def wait_execution(
    timeout: int = 20,
) -> Sequence[ExecutionResult]:
    async def is_execution_completed():
        status = await AgentServer().test_get_graph_run_status(
            graph_id, graph_exec_id, user_id
        )
        status = await AgentServer().test_get_graph_run_status(graph_exec_id, user_id)
        log.info(f"Execution status: {status}")
        if status == ExecutionStatus.FAILED:
            log.info("Execution failed")

22
autogpt_platform/backend/backend/util/text.py
Normal file
@@ -0,0 +1,22 @@
import re

from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment


class TextFormatter:
    def __init__(self):
        # Create a sandboxed environment
        self.env = SandboxedEnvironment(loader=BaseLoader(), autoescape=True)

        # Clear any registered filters, tests, and globals to minimize attack surface
        self.env.filters.clear()
        self.env.tests.clear()
        self.env.globals.clear()

    def format_string(self, template_str: str, values=None, **kwargs) -> str:
        # For python.format compatibility: replace all {...} with {{...}},
        # but avoid turning {{...}} into {{{...}}}.
        template_str = re.sub(r"(?<!{){[ a-zA-Z0-9_]+}", r"{\g<0>}", template_str)
        template = self.env.from_string(template_str)
        return template.render(values or {}, **kwargs)
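A rough usage sketch of the brace-doubling shim above; the template strings and values are made up for illustration:

formatter = TextFormatter()

# Plain {name} placeholders are doubled into Jinja2 syntax before rendering;
# autoescape=True HTML-escapes the substituted values.
print(formatter.format_string("Hello {name}!", {"name": "<world>"}))
# -> Hello &lt;world&gt;!

# Existing {{ ... }} placeholders are left alone by the negative lookbehind
print(formatter.format_string("Hello {{ name }}!", name="world"))
# -> Hello world!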
@@ -1,20 +1,31 @@
version: "3"
services:
  postgres-test:
    image: ankane/pgvector:latest
    environment:
      - POSTGRES_USER=agpt_user
      - POSTGRES_PASSWORD=pass123
      - POSTGRES_DB=agpt_local
      - POSTGRES_USER=${DB_USER}
      - POSTGRES_PASSWORD=${DB_PASS}
      - POSTGRES_DB=${DB_NAME}
    healthcheck:
      test: pg_isready -U $$POSTGRES_USER -d $$POSTGRES_DB
      interval: 10s
      timeout: 5s
      retries: 5
    ports:
      - "5433:5432"
      - "${DB_PORT}:5432"
    networks:
      - app-network-test
  redis-test:
    image: redis:latest
    command: redis-server --requirepass password
    ports:
      - "6379:6379"
    networks:
      - app-network-test
    healthcheck:
      test: ["CMD", "redis-cli", "ping"]
      interval: 10s
      timeout: 5s
      retries: 5

networks:
  app-network-test:

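A sketch of how the ${DB_USER}/${DB_PASS}/${DB_NAME}/${DB_PORT} substitutions above might be resolved from the same .env file on the Python side; the variable names come from the compose file, python-dotenv is in the backend's dependency list, and the localhost DSN shape is an assumption for illustration:

import os

from dotenv import load_dotenv  # python-dotenv

load_dotenv()  # docker compose reads the same .env when run from this directory

dsn = (
    f"postgresql://{os.environ['DB_USER']}:{os.environ['DB_PASS']}"
    f"@localhost:{os.environ.get('DB_PORT', '5432')}/{os.environ['DB_NAME']}"
)
print(dsn)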
@@ -0,0 +1,41 @@
-- CreateIndex
CREATE INDEX "AgentGraph_userId_isActive_idx" ON "AgentGraph"("userId", "isActive");

-- CreateIndex
CREATE INDEX "AgentGraphExecution_agentGraphId_agentGraphVersion_idx" ON "AgentGraphExecution"("agentGraphId", "agentGraphVersion");

-- CreateIndex
CREATE INDEX "AgentGraphExecution_userId_idx" ON "AgentGraphExecution"("userId");

-- CreateIndex
CREATE INDEX "AgentNode_agentGraphId_agentGraphVersion_idx" ON "AgentNode"("agentGraphId", "agentGraphVersion");

-- CreateIndex
CREATE INDEX "AgentNode_agentBlockId_idx" ON "AgentNode"("agentBlockId");

-- CreateIndex
CREATE INDEX "AgentNode_webhookId_idx" ON "AgentNode"("webhookId");

-- CreateIndex
CREATE INDEX "AgentNodeExecution_agentGraphExecutionId_idx" ON "AgentNodeExecution"("agentGraphExecutionId");

-- CreateIndex
CREATE INDEX "AgentNodeExecution_agentNodeId_idx" ON "AgentNodeExecution"("agentNodeId");

-- CreateIndex
CREATE INDEX "AgentNodeExecutionInputOutput_referencedByOutputExecId_idx" ON "AgentNodeExecutionInputOutput"("referencedByOutputExecId");

-- CreateIndex
CREATE INDEX "AgentNodeLink_agentNodeSourceId_idx" ON "AgentNodeLink"("agentNodeSourceId");

-- CreateIndex
CREATE INDEX "AgentNodeLink_agentNodeSinkId_idx" ON "AgentNodeLink"("agentNodeSinkId");

-- CreateIndex
CREATE INDEX "AnalyticsMetrics_userId_idx" ON "AnalyticsMetrics"("userId");

-- CreateIndex
CREATE INDEX "IntegrationWebhook_userId_idx" ON "IntegrationWebhook"("userId");

-- CreateIndex
CREATE INDEX "UserBlockCredit_userId_createdAt_idx" ON "UserBlockCredit"("userId", "createdAt");
@@ -0,0 +1,8 @@
-- AlterTable
ALTER TABLE "User" ADD COLUMN "stripeCustomerId" TEXT;

-- AlterEnum
ALTER TYPE "UserBlockCreditType" RENAME TO "CreditTransactionType";

-- AlterTable
ALTER TABLE "UserBlockCredit" RENAME TO "CreditTransaction";
@@ -0,0 +1,228 @@
-- CreateEnum
CREATE TYPE "SubmissionStatus" AS ENUM ('DAFT', 'PENDING', 'APPROVED', 'REJECTED');

-- AlterTable
ALTER TABLE "AgentGraphExecution" ADD COLUMN "agentPresetId" TEXT;

-- AlterTable
ALTER TABLE "AgentNodeExecutionInputOutput" ADD COLUMN "agentPresetId" TEXT;

-- AlterTable
ALTER TABLE "AnalyticsMetrics" ALTER COLUMN "id" DROP DEFAULT;

-- AlterTable
ALTER TABLE "CreditTransaction" RENAME CONSTRAINT "UserBlockCredit_pkey" TO "CreditTransaction_pkey";

-- CreateTable
CREATE TABLE "AgentPreset" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "name" TEXT NOT NULL,
    "description" TEXT NOT NULL,
    "isActive" BOOLEAN NOT NULL DEFAULT true,
    "userId" TEXT NOT NULL,
    "agentId" TEXT NOT NULL,
    "agentVersion" INTEGER NOT NULL,

    CONSTRAINT "AgentPreset_pkey" PRIMARY KEY ("id")
);

-- CreateTable
CREATE TABLE "UserAgent" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "userId" TEXT NOT NULL,
    "agentId" TEXT NOT NULL,
    "agentVersion" INTEGER NOT NULL,
    "agentPresetId" TEXT,
    "isFavorite" BOOLEAN NOT NULL DEFAULT false,
    "isCreatedByUser" BOOLEAN NOT NULL DEFAULT false,
    "isArchived" BOOLEAN NOT NULL DEFAULT false,
    "isDeleted" BOOLEAN NOT NULL DEFAULT false,

    CONSTRAINT "UserAgent_pkey" PRIMARY KEY ("id")
);

-- CreateTable
CREATE TABLE "Profile" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "userId" TEXT,
    "name" TEXT NOT NULL,
    "username" TEXT NOT NULL,
    "description" TEXT NOT NULL,
    "links" TEXT[],
    "avatarUrl" TEXT,

    CONSTRAINT "Profile_pkey" PRIMARY KEY ("id")
);

-- CreateTable
CREATE TABLE "StoreListing" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "isDeleted" BOOLEAN NOT NULL DEFAULT false,
    "isApproved" BOOLEAN NOT NULL DEFAULT false,
    "agentId" TEXT NOT NULL,
    "agentVersion" INTEGER NOT NULL,
    "owningUserId" TEXT NOT NULL,

    CONSTRAINT "StoreListing_pkey" PRIMARY KEY ("id")
);

-- CreateTable
CREATE TABLE "StoreListingVersion" (
    "id" TEXT NOT NULL,
    "version" INTEGER NOT NULL DEFAULT 1,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "agentId" TEXT NOT NULL,
    "agentVersion" INTEGER NOT NULL,
    "slug" TEXT NOT NULL,
    "name" TEXT NOT NULL,
    "subHeading" TEXT NOT NULL,
    "videoUrl" TEXT,
    "imageUrls" TEXT[],
    "description" TEXT NOT NULL,
    "categories" TEXT[],
    "isFeatured" BOOLEAN NOT NULL DEFAULT false,
    "isDeleted" BOOLEAN NOT NULL DEFAULT false,
    "isAvailable" BOOLEAN NOT NULL DEFAULT true,
    "isApproved" BOOLEAN NOT NULL DEFAULT false,
    "storeListingId" TEXT,

    CONSTRAINT "StoreListingVersion_pkey" PRIMARY KEY ("id")
);

-- CreateTable
CREATE TABLE "StoreListingReview" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "storeListingVersionId" TEXT NOT NULL,
    "reviewByUserId" TEXT NOT NULL,
    "score" INTEGER NOT NULL,
    "comments" TEXT,

    CONSTRAINT "StoreListingReview_pkey" PRIMARY KEY ("id")
);

-- CreateTable
CREATE TABLE "StoreListingSubmission" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "storeListingId" TEXT NOT NULL,
    "storeListingVersionId" TEXT NOT NULL,
    "reviewerId" TEXT NOT NULL,
    "Status" "SubmissionStatus" NOT NULL DEFAULT 'PENDING',
    "reviewComments" TEXT,

    CONSTRAINT "StoreListingSubmission_pkey" PRIMARY KEY ("id")
);

-- CreateIndex
CREATE INDEX "AgentPreset_userId_idx" ON "AgentPreset"("userId");

-- CreateIndex
CREATE INDEX "UserAgent_userId_idx" ON "UserAgent"("userId");

-- CreateIndex
CREATE UNIQUE INDEX "Profile_username_key" ON "Profile"("username");

-- CreateIndex
CREATE INDEX "Profile_username_idx" ON "Profile"("username");

-- CreateIndex
CREATE INDEX "Profile_userId_idx" ON "Profile"("userId");

-- CreateIndex
CREATE INDEX "StoreListing_isApproved_idx" ON "StoreListing"("isApproved");

-- CreateIndex
CREATE INDEX "StoreListing_agentId_idx" ON "StoreListing"("agentId");

-- CreateIndex
CREATE INDEX "StoreListing_owningUserId_idx" ON "StoreListing"("owningUserId");

-- CreateIndex
CREATE INDEX "StoreListingVersion_agentId_agentVersion_isApproved_idx" ON "StoreListingVersion"("agentId", "agentVersion", "isApproved");

-- CreateIndex
CREATE UNIQUE INDEX "StoreListingVersion_agentId_agentVersion_key" ON "StoreListingVersion"("agentId", "agentVersion");

-- CreateIndex
CREATE INDEX "StoreListingReview_storeListingVersionId_idx" ON "StoreListingReview"("storeListingVersionId");

-- CreateIndex
CREATE UNIQUE INDEX "StoreListingReview_storeListingVersionId_reviewByUserId_key" ON "StoreListingReview"("storeListingVersionId", "reviewByUserId");

-- CreateIndex
CREATE INDEX "StoreListingSubmission_storeListingId_idx" ON "StoreListingSubmission"("storeListingId");

-- CreateIndex
CREATE INDEX "StoreListingSubmission_Status_idx" ON "StoreListingSubmission"("Status");

-- RenameForeignKey
ALTER TABLE "CreditTransaction" RENAME CONSTRAINT "UserBlockCredit_blockId_fkey" TO "CreditTransaction_blockId_fkey";

-- RenameForeignKey
ALTER TABLE "CreditTransaction" RENAME CONSTRAINT "UserBlockCredit_userId_fkey" TO "CreditTransaction_userId_fkey";

-- AddForeignKey
ALTER TABLE "AgentPreset" ADD CONSTRAINT "AgentPreset_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "AgentPreset" ADD CONSTRAINT "AgentPreset_agentId_agentVersion_fkey" FOREIGN KEY ("agentId", "agentVersion") REFERENCES "AgentGraph"("id", "version") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "UserAgent" ADD CONSTRAINT "UserAgent_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "UserAgent" ADD CONSTRAINT "UserAgent_agentId_agentVersion_fkey" FOREIGN KEY ("agentId", "agentVersion") REFERENCES "AgentGraph"("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "UserAgent" ADD CONSTRAINT "UserAgent_agentPresetId_fkey" FOREIGN KEY ("agentPresetId") REFERENCES "AgentPreset"("id") ON DELETE SET NULL ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "AgentGraphExecution" ADD CONSTRAINT "AgentGraphExecution_agentPresetId_fkey" FOREIGN KEY ("agentPresetId") REFERENCES "AgentPreset"("id") ON DELETE SET NULL ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "AgentNodeExecutionInputOutput" ADD CONSTRAINT "AgentNodeExecutionInputOutput_agentPresetId_fkey" FOREIGN KEY ("agentPresetId") REFERENCES "AgentPreset"("id") ON DELETE SET NULL ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "Profile" ADD CONSTRAINT "Profile_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "StoreListing" ADD CONSTRAINT "StoreListing_agentId_agentVersion_fkey" FOREIGN KEY ("agentId", "agentVersion") REFERENCES "AgentGraph"("id", "version") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "StoreListing" ADD CONSTRAINT "StoreListing_owningUserId_fkey" FOREIGN KEY ("owningUserId") REFERENCES "User"("id") ON DELETE RESTRICT ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "StoreListingVersion" ADD CONSTRAINT "StoreListingVersion_agentId_agentVersion_fkey" FOREIGN KEY ("agentId", "agentVersion") REFERENCES "AgentGraph"("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "StoreListingVersion" ADD CONSTRAINT "StoreListingVersion_storeListingId_fkey" FOREIGN KEY ("storeListingId") REFERENCES "StoreListing"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "StoreListingReview" ADD CONSTRAINT "StoreListingReview_storeListingVersionId_fkey" FOREIGN KEY ("storeListingVersionId") REFERENCES "StoreListingVersion"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "StoreListingReview" ADD CONSTRAINT "StoreListingReview_reviewByUserId_fkey" FOREIGN KEY ("reviewByUserId") REFERENCES "User"("id") ON DELETE RESTRICT ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "StoreListingSubmission" ADD CONSTRAINT "StoreListingSubmission_storeListingId_fkey" FOREIGN KEY ("storeListingId") REFERENCES "StoreListing"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "StoreListingSubmission" ADD CONSTRAINT "StoreListingSubmission_storeListingVersionId_fkey" FOREIGN KEY ("storeListingVersionId") REFERENCES "StoreListingVersion"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "StoreListingSubmission" ADD CONSTRAINT "StoreListingSubmission_reviewerId_fkey" FOREIGN KEY ("reviewerId") REFERENCES "User"("id") ON DELETE RESTRICT ON UPDATE CASCADE;

-- RenameIndex
ALTER INDEX "UserBlockCredit_userId_createdAt_idx" RENAME TO "CreditTransaction_userId_createdAt_idx";
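As a rough sketch of how the new store tables might be touched from the backend's prisma-client-py client: the model and field names mirror the migration above, but the setup and values are illustrative only, not code from this commit.

import asyncio

from prisma import Prisma


async def main() -> None:
    db = Prisma()
    await db.connect()

    # Create a creator profile; only userId/avatarUrl are optional per the migration
    profile = await db.profile.create(
        data={
            "name": "Jane Doe",          # hypothetical values
            "username": "janedoe",
            "description": "Agent builder",
            "links": [],
        }
    )
    print(profile.id)

    await db.disconnect()


asyncio.run(main())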
@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "Profile" ADD COLUMN "isFeatured" BOOLEAN NOT NULL DEFAULT false;
@@ -0,0 +1,119 @@
BEGIN;

CREATE VIEW "StoreAgent" AS
WITH ReviewStats AS (
    SELECT sl."id" AS "storeListingId",
           COUNT(sr.id) AS review_count,
           AVG(CAST(sr.score AS DECIMAL)) AS avg_rating
    FROM "StoreListing" sl
    JOIN "StoreListingVersion" slv ON slv."storeListingId" = sl."id"
    JOIN "StoreListingReview" sr ON sr."storeListingVersionId" = slv.id
    WHERE sl."isDeleted" = FALSE
    GROUP BY sl."id"
),
AgentRuns AS (
    SELECT "agentGraphId", COUNT(*) AS run_count
    FROM "AgentGraphExecution"
    GROUP BY "agentGraphId"
)
SELECT
    sl.id AS listing_id,
    slv.id AS "storeListingVersionId",
    slv."createdAt" AS updated_at,
    slv.slug,
    a.name AS agent_name,
    slv."videoUrl" AS agent_video,
    COALESCE(slv."imageUrls", ARRAY[]::TEXT[]) AS agent_image,
    slv."isFeatured" AS featured,
    p.username AS creator_username,
    p."avatarUrl" AS creator_avatar,
    slv."subHeading" AS sub_heading,
    slv.description,
    slv.categories,
    COALESCE(ar.run_count, 0) AS runs,
    CAST(COALESCE(rs.avg_rating, 0.0) AS DOUBLE PRECISION) AS rating,
    ARRAY_AGG(DISTINCT CAST(slv.version AS TEXT)) AS versions
FROM "StoreListing" sl
JOIN "AgentGraph" a ON sl."agentId" = a.id AND sl."agentVersion" = a."version"
LEFT JOIN "Profile" p ON sl."owningUserId" = p."userId"
LEFT JOIN "StoreListingVersion" slv ON slv."storeListingId" = sl.id
LEFT JOIN ReviewStats rs ON sl.id = rs."storeListingId"
LEFT JOIN AgentRuns ar ON a.id = ar."agentGraphId"
WHERE sl."isDeleted" = FALSE
  AND sl."isApproved" = TRUE
GROUP BY sl.id, slv.id, slv.slug, slv."createdAt", a.name, slv."videoUrl", slv."imageUrls", slv."isFeatured",
         p.username, p."avatarUrl", slv."subHeading", slv.description, slv.categories,
         ar.run_count, rs.avg_rating;

CREATE VIEW "Creator" AS
WITH AgentStats AS (
    SELECT
        p.username,
        COUNT(DISTINCT sl.id) as num_agents,
        AVG(CAST(COALESCE(sr.score, 0) AS DECIMAL)) as agent_rating,
        SUM(COALESCE(age.run_count, 0)) as agent_runs
    FROM "Profile" p
    LEFT JOIN "StoreListing" sl ON sl."owningUserId" = p."userId"
    LEFT JOIN "StoreListingVersion" slv ON slv."storeListingId" = sl.id
    LEFT JOIN "StoreListingReview" sr ON sr."storeListingVersionId" = slv.id
    LEFT JOIN (
        SELECT "agentGraphId", COUNT(*) as run_count
        FROM "AgentGraphExecution"
        GROUP BY "agentGraphId"
    ) age ON age."agentGraphId" = sl."agentId"
    WHERE sl."isDeleted" = FALSE AND sl."isApproved" = TRUE
    GROUP BY p.username
)
SELECT
    p.username,
    p.name,
    p."avatarUrl" as avatar_url,
    p.description,
    ARRAY_AGG(DISTINCT c) FILTER (WHERE c IS NOT NULL) as top_categories,
    p.links,
    p."isFeatured" as is_featured,
    COALESCE(ast.num_agents, 0) as num_agents,
    COALESCE(ast.agent_rating, 0.0) as agent_rating,
    COALESCE(ast.agent_runs, 0) as agent_runs
FROM "Profile" p
LEFT JOIN AgentStats ast ON ast.username = p.username
LEFT JOIN LATERAL (
    SELECT UNNEST(slv.categories) as c
    FROM "StoreListing" sl
    JOIN "StoreListingVersion" slv ON slv."storeListingId" = sl.id
    WHERE sl."owningUserId" = p."userId"
      AND sl."isDeleted" = FALSE
      AND sl."isApproved" = TRUE
) cats ON true
GROUP BY p.username, p.name, p."avatarUrl", p.description, p.links, p."isFeatured",
         ast.num_agents, ast.agent_rating, ast.agent_runs;

CREATE VIEW "StoreSubmission" AS
SELECT
    sl.id as listing_id,
    sl."owningUserId" as user_id,
    slv."agentId" as agent_id,
    slv."version" as agent_version,
    slv.slug,
    slv.name,
    slv."subHeading" as sub_heading,
    slv.description,
    slv."imageUrls" as image_urls,
    slv."createdAt" as date_submitted,
    COALESCE(sls."Status", 'PENDING') as status,
    COALESCE(ar.run_count, 0) as runs,
    CAST(COALESCE(AVG(CAST(sr.score AS DECIMAL)), 0.0) AS DOUBLE PRECISION) as rating
FROM "StoreListing" sl
JOIN "StoreListingVersion" slv ON slv."storeListingId" = sl.id
LEFT JOIN "StoreListingSubmission" sls ON sls."storeListingId" = sl.id
LEFT JOIN "StoreListingReview" sr ON sr."storeListingVersionId" = slv.id
LEFT JOIN (
    SELECT "agentGraphId", COUNT(*) as run_count
    FROM "AgentGraphExecution"
    GROUP BY "agentGraphId"
) ar ON ar."agentGraphId" = slv."agentId"
WHERE sl."isDeleted" = FALSE
GROUP BY sl.id, sl."owningUserId", slv."agentId", slv."version", slv.slug, slv.name, slv."subHeading",
         slv.description, slv."imageUrls", slv."createdAt", sls."Status", ar.run_count;

COMMIT;
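Since these views are read-only aggregates, a plausible way to consume them from the backend is prisma-client-py's raw-query escape hatch. A minimal sketch, assuming a connected client; the query and limit are illustrative, not code from this commit:

import asyncio

from prisma import Prisma


async def top_agents() -> None:
    db = Prisma()
    await db.connect()

    # query_raw returns plain dicts unless a model type is supplied
    rows = await db.query_raw(
        'SELECT agent_name, runs, rating FROM "StoreAgent" ORDER BY runs DESC LIMIT 5'
    )
    for row in rows:
        print(row["agent_name"], row["runs"], row["rating"])

    await db.disconnect()


asyncio.run(top_agents())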
2592
autogpt_platform/backend/poetry.lock
generated
File diff suppressed because it is too large
@@ -8,25 +8,26 @@ packages = [{ include = "backend" }]


[tool.poetry.dependencies]
python = "^3.10"
aio-pika = "^9.5.0"
anthropic = "^0.39.0"
python = ">=3.10,<3.13"
aio-pika = "^9.5.4"
anthropic = "^0.40.0"
apscheduler = "^3.11.0"
autogpt-libs = { path = "../autogpt_libs", develop = true }
click = "^8.1.7"
croniter = "^5.0.1"
discord-py = "^2.4.0"
e2b-code-interpreter = "^1.0.1"
fastapi = "^0.115.5"
feedparser = "^6.0.11"
flake8 = "^7.0.0"
google-api-python-client = "^2.154.0"
google-auth-oauthlib = "^1.2.1"
groq = "^0.12.0"
groq = "^0.13.1"
jinja2 = "^3.1.4"
jsonref = "^1.1.0"
jsonschema = "^4.22.0"
ollama = "^0.4.1"
openai = "^1.55.1"
openai = "^1.57.4"
praw = "~7.8.1"
prisma = "^0.15.0"
psutil = "^6.1.0"
@@ -34,23 +35,26 @@ pydantic = "^2.9.2"
pydantic-settings = "^2.3.4"
pyro5 = "^5.15"
pytest = "^8.2.1"
pytest-asyncio = "^0.24.0"
pytest-asyncio = "^0.25.0"
python-dotenv = "^1.0.1"
redis = "^5.2.0"
sentry-sdk = "2.19.0"
sentry-sdk = "2.19.2"
strenum = "^0.4.9"
supabase = "^2.10.0"
tenacity = "^9.0.0"
uvicorn = { extras = ["standard"], version = "^0.32.1" }
uvicorn = { extras = ["standard"], version = "^0.34.0" }
websockets = "^13.1"
youtube-transcript-api = "^0.6.2"
googlemaps = "^4.10.0"
replicate = "^1.0.4"
pinecone = "^5.3.1"
cryptography = "^43.0.3"
cryptography = "^43.0"
python-multipart = "^0.0.20"
sqlalchemy = "^2.0.36"
psycopg2-binary = "^2.9.10"
google-cloud-storage = "^2.18.2"
launchdarkly-server-sdk = "^9.8.0"

[tool.poetry.group.dev.dependencies]
poethepoet = "^0.31.0"
httpx = "^0.27.0"
@@ -61,6 +65,8 @@ pyright = "^1.1.389"
isort = "^5.13.2"
black = "^24.10.0"
aiohappyeyeballs = "^2.4.3"
pytest-mock = "^3.14.0"
faker = "^33.1.0"

[build-system]
requires = ["poetry-core"]
@@ -90,3 +96,6 @@ ignore_patterns = []

[tool.pytest.ini_options]
asyncio_mode = "auto"

[tool.ruff]
target-version = "py310"

@@ -16,9 +16,9 @@ def wait_for_postgres(max_retries=5, delay=5):
                "postgres-test",
                "pg_isready",
                "-U",
                "agpt_user",
                "postgres",
                "-d",
                "agpt_local",
                "postgres",
            ],
            check=True,
            capture_output=True,

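For context, a sketch of the retry loop implied by the wait_for_postgres hunk above. The docker compose invocation and the user/database arguments are assumptions based on the service name and flags visible in the hunk, not the file's actual surrounding code:

import subprocess
import time


def wait_for_postgres(max_retries: int = 5, delay: int = 5) -> bool:
    for _ in range(max_retries):
        try:
            subprocess.run(
                ["docker", "compose", "exec", "postgres-test",
                 "pg_isready", "-U", "postgres", "-d", "postgres"],
                check=True,
                capture_output=True,
            )
            return True  # pg_isready exited 0: the server accepts connections
        except subprocess.CalledProcessError:
            time.sleep(delay)  # not ready yet; back off and retry
    return False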
@@ -8,26 +8,36 @@ generator client {
  provider             = "prisma-client-py"
  recursive_type_depth = 5
  interface            = "asyncio"
  previewFeatures      = ["views"]
}

// User model to mirror Auth provider users
model User {
  id           String   @id // This should match the Supabase user ID
  email        String   @unique
  name         String?
  createdAt    DateTime @default(now())
  updatedAt    DateTime @updatedAt
  metadata     Json     @default("{}")
  integrations String   @default("")
  id               String   @id // This should match the Supabase user ID
  email            String   @unique
  name             String?
  createdAt        DateTime @default(now())
  updatedAt        DateTime @updatedAt
  metadata         Json     @default("{}")
  integrations     String   @default("")
  stripeCustomerId String?

  // Relations
  AgentGraphs          AgentGraph[]
  AgentGraphExecutions AgentGraphExecution[]
  IntegrationWebhooks  IntegrationWebhook[]
  AnalyticsDetails     AnalyticsDetails[]
  AnalyticsMetrics     AnalyticsMetrics[]
  UserBlockCredit      UserBlockCredit[]
  APIKeys              APIKey[]
  AgentGraphs          AgentGraph[]
  AgentGraphExecutions AgentGraphExecution[]
  AnalyticsDetails     AnalyticsDetails[]
  AnalyticsMetrics     AnalyticsMetrics[]
  CreditTransaction    CreditTransaction[]

  AgentPreset AgentPreset[]
  UserAgent   UserAgent[]

  Profile                Profile[]
  StoreListing           StoreListing[]
  StoreListingReview     StoreListingReview[]
  StoreListingSubmission StoreListingSubmission[]
  APIKeys                APIKey[]
  IntegrationWebhooks    IntegrationWebhook[]

  @@index([id])
  @@index([email])
@@ -47,14 +57,89 @@ model AgentGraph {

  // Link to User model
  userId String
  // FIX: Do not cascade delete the agent when the user is deleted.
  // This allows us to delete user data without deleting the agent, which may be in use by other users.
  user   User   @relation(fields: [userId], references: [id], onDelete: Cascade)

  AgentNodes          AgentNode[]
  AgentGraphExecution AgentGraphExecution[]
  AgentNodes          AgentNode[]
  AgentGraphExecution AgentGraphExecution[]

  AgentPreset         AgentPreset[]
  UserAgent           UserAgent[]
  StoreListing        StoreListing[]
  StoreListingVersion StoreListingVersion?

  @@id(name: "graphVersionId", [id, version])
  @@index([userId, isActive])
}

////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
//////////////// USER SPECIFIC DATA ////////////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////

// An AgentPreset is an Agent + User Configuration of that agent.
// For example, if someone has created a weather agent and they want to set it up to
// inform them of extreme weather warnings in Texas, the agent with the configuration
// to monitor Texas, along with the cron setup or webhook triggers, is an AgentPreset.
model AgentPreset {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @default(now()) @updatedAt

  name        String
  description String

  // For agents that can be triggered by webhooks or cronjob.
  // This bool allows us to disable a configured agent without deleting it.
  isActive Boolean @default(true)

  userId String
  User   User   @relation(fields: [userId], references: [id], onDelete: Cascade)

  agentId      String
  agentVersion Int
  Agent        AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)

  InputPresets   AgentNodeExecutionInputOutput[] @relation("AgentPresetsInputData")
  UserAgents     UserAgent[]
  AgentExecution AgentGraphExecution[]

  @@index([userId])
}

// For the library page.
// It is a user-controlled list of agents that they will see in their library.
model UserAgent {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @default(now()) @updatedAt

  userId String
  User   User   @relation(fields: [userId], references: [id], onDelete: Cascade)

  agentId      String
  agentVersion Int
  Agent        AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version])

  agentPresetId String?
  AgentPreset   AgentPreset? @relation(fields: [agentPresetId], references: [id])

  isFavorite      Boolean @default(false)
  isCreatedByUser Boolean @default(false)
  isArchived      Boolean @default(false)
  isDeleted       Boolean @default(false)

  @@index([userId])
}

////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
//////// AGENT DEFINITION AND EXECUTION TABLES ////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////

// This model describes a single node in the Agent Graph/Flow (Multi Agent System).
model AgentNode {
  id String @id @default(uuid())
@@ -83,6 +168,10 @@ model AgentNode {
  metadata String @default("{}")

  ExecutionHistory AgentNodeExecution[]

  @@index([agentGraphId, agentGraphVersion])
  @@index([agentBlockId])
  @@index([webhookId])
}

// This model describes the link between two AgentNodes.
@@ -101,6 +190,9 @@ model AgentNodeLink {

  // Default: the data coming from the source can only be consumed by the sink once, Static: input data will be reused.
  isStatic Boolean @default(false)

  @@index([agentNodeSourceId])
  @@index([agentNodeSinkId])
}

// This model describes a component that will be executed by the AgentNode.
@@ -115,7 +207,7 @@ model AgentBlock {

  // Prisma requires explicit back-references.
  ReferencedByAgentNode AgentNode[]
  UserBlockCredit       UserBlockCredit[]
  CreditTransaction     CreditTransaction[]
}

// This model describes the status of an AgentGraphExecution or AgentNodeExecution.
@@ -146,7 +238,12 @@ model AgentGraphExecution {
  userId String
  user   User   @relation(fields: [userId], references: [id], onDelete: Cascade)

  stats String? // JSON serialized object
  stats         String?      // JSON serialized object
  AgentPreset   AgentPreset? @relation(fields: [agentPresetId], references: [id])
  agentPresetId String?

  @@index([agentGraphId, agentGraphVersion])
  @@index([userId])
}

// This model describes the execution of an AgentNode.
@@ -171,6 +268,9 @@ model AgentNodeExecution {
  endedTime DateTime?

  stats String? // JSON serialized object

  @@index([agentGraphExecutionId])
  @@index([agentNodeId])
}

// This model describes the output of an AgentNodeExecution.
@@ -187,8 +287,12 @@ model AgentNodeExecutionInputOutput {
  referencedByOutputExecId String?
  ReferencedByOutputExec   AgentNodeExecution? @relation("AgentNodeExecutionOutput", fields: [referencedByOutputExecId], references: [id], onDelete: Cascade)

  agentPresetId String?
  AgentPreset   AgentPreset? @relation("AgentPresetsInputData", fields: [agentPresetId], references: [id])

  // Input and Output pin names are unique for each AgentNodeExecution.
  @@unique([referencedByInputExecId, referencedByOutputExecId, name])
  @@index([referencedByOutputExecId])
}

// Webhook that is registered with a provider and propagates to one or more nodes
@@ -211,6 +315,8 @@ model IntegrationWebhook {
  providerWebhookId String // Webhook ID assigned by the provider

  AgentNodes AgentNode[]

  @@index([userId])
}

model AnalyticsDetails {
@@ -238,8 +344,13 @@ model AnalyticsDetails {
  @@index([type])
}

////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
////////////// METRICS TRACKING TABLES ////////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
model AnalyticsMetrics {
  id String @id @default(dbgenerated("gen_random_uuid()"))
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

@@ -254,14 +365,21 @@ model AnalyticsMetrics {
  // Link to User model
  userId String
  user   User   @relation(fields: [userId], references: [id], onDelete: Cascade)

  @@index([userId])
}

enum UserBlockCreditType {
enum CreditTransactionType {
  TOP_UP
  USAGE
}

model UserBlockCredit {
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
//////// ACCOUNTING AND CREDIT SYSTEM TABLES //////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
model CreditTransaction {
  transactionKey String   @default(uuid())
  createdAt      DateTime @default(now())

@@ -272,12 +390,215 @@ model UserBlockCredit {
  block AgentBlock? @relation(fields: [blockId], references: [id])

  amount Int
  type   UserBlockCreditType
  type   CreditTransactionType

  isActive Boolean @default(true)
  metadata Json?

  @@id(name: "creditTransactionIdentifier", [transactionKey, userId])
  @@index([userId, createdAt])
}

////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
////////////// Store TABLES ///////////////////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////

model Profile {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @default(now()) @updatedAt

  // Only 1 of user or group can be set.
  // The user this profile belongs to, if any.
  userId String?
  User   User?   @relation(fields: [userId], references: [id], onDelete: Cascade)

  name        String
  username    String  @unique
  description String

  links String[]

  avatarUrl String?

  isFeatured Boolean @default(false)

  @@index([username])
  @@index([userId])
}

view Creator {
  username    String @unique
  name        String
  avatar_url  String
  description String

  top_categories String[]
  links          String[]

  num_agents   Int
  agent_rating Float
  agent_runs   Int
  is_featured  Boolean
}

view StoreAgent {
  listing_id            String   @id
  storeListingVersionId String
  updated_at            DateTime

  slug        String
  agent_name  String
  agent_video String?
  agent_image String[]

  featured         Boolean  @default(false)
  creator_username String
  creator_avatar   String
  sub_heading      String
  description      String
  categories       String[]
  runs             Int
  rating           Float
  versions         String[]

  @@unique([creator_username, slug])
  @@index([creator_username])
  @@index([featured])
  @@index([categories])
  @@index([storeListingVersionId])
}

view StoreSubmission {
  listing_id     String           @id
  user_id        String
  slug           String
  name           String
  sub_heading    String
  description    String
  image_urls     String[]
  date_submitted DateTime
  status         SubmissionStatus
  runs           Int
  rating         Float
  agent_id       String
  agent_version  Int

  @@index([user_id])
}

model StoreListing {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @default(now()) @updatedAt

  isDeleted  Boolean @default(false)
  // Not needed but makes lookups faster
  isApproved Boolean @default(false)

  // The agent link here is only so we can do lookup on agentId; for the listing the StoreListingVersion is used.
  agentId      String
  agentVersion Int
  Agent        AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)

  owningUserId String
  OwningUser   User   @relation(fields: [owningUserId], references: [id])

  StoreListingVersions   StoreListingVersion[]
  StoreListingSubmission StoreListingSubmission[]

  @@index([isApproved])
  @@index([agentId])
  @@index([owningUserId])
}

model StoreListingVersion {
  id        String   @id @default(uuid())
  version   Int      @default(1)
  createdAt DateTime @default(now())
  updatedAt DateTime @default(now()) @updatedAt

  // The agent and version to be listed on the store
  agentId      String
  agentVersion Int
  Agent        AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version])

  // The details for this version of the agent. This allows the author to update the details of the agent,
  // but still allows using old versions of the agent with their original details.
  // TODO: Create a database view that shows only the latest version of each store listing.
  slug        String
  name        String
  subHeading  String
  videoUrl    String?
  imageUrls   String[]
  description String
  categories  String[]

  isFeatured Boolean @default(false)

  isDeleted   Boolean @default(false)
  // Old versions can be made unavailable by the author if desired
  isAvailable Boolean @default(true)
  // Not needed but makes lookups faster
  isApproved  Boolean @default(false)

  StoreListing           StoreListing? @relation(fields: [storeListingId], references: [id], onDelete: Cascade)
  storeListingId         String?
  StoreListingSubmission StoreListingSubmission[]

  // Reviews are on a specific version, but then aggregated up to the listing.
  // This allows us to provide a review filter on the current version of the agent.
  StoreListingReview StoreListingReview[]

  @@unique([agentId, agentVersion])
  @@index([agentId, agentVersion, isApproved])
}

model StoreListingReview {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @default(now()) @updatedAt

  storeListingVersionId String
  StoreListingVersion   StoreListingVersion @relation(fields: [storeListingVersionId], references: [id], onDelete: Cascade)

  reviewByUserId String
  ReviewByUser   User   @relation(fields: [reviewByUserId], references: [id])

  score    Int
  comments String?

  @@unique([storeListingVersionId, reviewByUserId])
  @@index([storeListingVersionId])
}

enum SubmissionStatus {
  DAFT
  PENDING
  APPROVED
  REJECTED
}

model StoreListingSubmission {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @default(now()) @updatedAt

  storeListingId String
  StoreListing   StoreListing @relation(fields: [storeListingId], references: [id], onDelete: Cascade)

  storeListingVersionId String
  StoreListingVersion   StoreListingVersion @relation(fields: [storeListingVersionId], references: [id], onDelete: Cascade)

  reviewerId String
  Reviewer   User   @relation(fields: [reviewerId], references: [id])

  Status         SubmissionStatus @default(PENDING)
  reviewComments String?

  @@index([storeListingId])
  @@index([Status])
}

enum APIKeyPermission {
@@ -290,7 +611,7 @@ enum APIKeyPermission {
model APIKey {
  id      String @id @default(uuid())
  name    String
  prefix  String // First 8 chars for identification
  prefix  String // First 8 chars for identification
  postfix String
  key     String       @unique // Hashed key
  status  APIKeyStatus @default(ACTIVE)
@@ -317,4 +638,4 @@ enum APIKeyStatus {
  ACTIVE
  REVOKED
  SUSPENDED
}
}

@@ -1,628 +0,0 @@
// We need to migrate our database schema to support the domain as we understand it now.
// Doing so requires adding a bunch of new tables, but also modifying old ones and how
// they relate to each other. This is a large change, so instead of doing it in one go,
// we have created the target schema and will migrate to it incrementally.

datasource db {
  provider = "postgresql"
  url      = env("DATABASE_URL")
}

generator client {
  provider             = "prisma-client-py"
  recursive_type_depth = 5
  interface            = "asyncio"
}

// User model to mirror Auth provider users
model User {
  id        String   @id @db.Uuid // This should match the Supabase user ID
  email     String   @unique
  name      String?
  createdAt DateTime @default(now())
  updatedAt DateTime @default(now()) @updatedAt
  metadata  String   @default("")

  // Relations
  Agents                  Agent[]
  AgentExecutions         AgentExecution[]
  AgentExecutionSchedules AgentExecutionSchedule[]
  AnalyticsDetails        AnalyticsDetails[]
  AnalyticsMetrics        AnalyticsMetrics[]
  UserBlockCredit         UserBlockCredit[]
  AgentPresets            AgentPreset[]
  UserAgents              UserAgent[]

  // User Group relations
  UserGroupMemberships   UserGroupMembership[]
  Profile                Profile[]
  StoreListing           StoreListing[]
  StoreListingSubmission StoreListingSubmission[]
  StoreListingReview     StoreListingReview[]
}

model UserGroup {
  id        String   @id @default(uuid()) @db.Uuid
  createdAt DateTime @default(now())
  updatedAt DateTime @default(now()) @updatedAt

  name         String
  description  String
  groupIconUrl String?

  UserGroupMemberships UserGroupMembership[]

  Agents       Agent[]
  Profile      Profile[]
  StoreListing StoreListing[]

  @@index([name])
}

enum UserGroupRole {
  MEMBER
  OWNER
}

model UserGroupMembership {
  id        String   @id @default(uuid()) @db.Uuid
  createdAt DateTime @default(now())
  updatedAt DateTime @default(now()) @updatedAt

  userId      String        @db.Uuid
  User        User          @relation(fields: [userId], references: [id], onDelete: Cascade)
  userGroupId String        @db.Uuid
  UserGroup   UserGroup     @relation(fields: [userGroupId], references: [id], onDelete: Cascade)
  Role        UserGroupRole @default(MEMBER)

  @@unique([userId, userGroupId])
  @@index([userId])
  @@index([userGroupId])
}

// This model describes the Agent Graph/Flow (Multi Agent System).
model Agent {
  id        String   @default(uuid()) @db.Uuid
  version   Int      @default(1)
  createdAt DateTime @default(now())
  updatedAt DateTime @default(now()) @updatedAt

  name        String?
  description String?

  // Link to User model
  createdByUserId String? @db.Uuid
  // Do not cascade delete the agent when the user is deleted.
  // This allows us to delete user data without deleting the agent, which may be in use by other users.
  CreatedByUser   User?   @relation(fields: [createdByUserId], references: [id], onDelete: SetNull)

  groupId String?    @db.Uuid
  // Do not cascade delete the agent when the group is deleted.
  // This allows us to delete user group data without deleting the agent, which may be in use by other users.
  Group   UserGroup? @relation(fields: [groupId], references: [id], onDelete: SetNull)

  AgentNodes     AgentNode[]
  AgentExecution AgentExecution[]

  // All sub-graphs are defined within this 1-level depth list (even if it's a nested graph).
  SubAgents          Agent[] @relation("SubAgents")
  agentParentId      String? @db.Uuid
  agentParentVersion Int?
  AgentParent        Agent?  @relation("SubAgents", fields: [agentParentId, agentParentVersion], references: [id, version])

  AgentPresets           AgentPreset[]
  WebhookTrigger         WebhookTrigger[]
  AgentExecutionSchedule AgentExecutionSchedule[]
  UserAgents             UserAgent[]
  UserBlockCredit        UserBlockCredit[]
  StoreListing           StoreListing[]
  StoreListingVersion    StoreListingVersion[]

  @@id(name: "agentVersionId", [id, version])
}

////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
//////////////// USER SPECIFIC DATA ////////////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////

// An AgentPreset is an Agent + User Configuration of that agent.
// For example, if someone has created a weather agent and they want to set it up to
// inform them of extreme weather warnings in Texas, the agent with the configuration
// to monitor Texas, along with the cron setup or webhook triggers, is an AgentPreset.
model AgentPreset {
  id        String   @id @default(uuid()) @db.Uuid
  createdAt DateTime @default(now())
  updatedAt DateTime @default(now()) @updatedAt

  name        String
  description String

  // For agents that can be triggered by webhooks or cronjob.
  // This bool allows us to disable a configured agent without deleting it.
  isActive Boolean @default(true)

  userId String @db.Uuid
  User   User   @relation(fields: [userId], references: [id], onDelete: Cascade)

  agentId      String @db.Uuid
  agentVersion Int
  Agent        Agent  @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)

  InputPresets           AgentNodeExecutionInputOutput[] @relation("AgentPresetsInputData")
  UserAgents             UserAgent[]
  WebhookTrigger         WebhookTrigger[]
  AgentExecutionSchedule AgentExecutionSchedule[]
  AgentExecution         AgentExecution[]

  @@index([userId])
}

// For the library page.
// It is a user-controlled list of agents that they will see in their library.
model UserAgent {
  id        String   @id @default(uuid()) @db.Uuid
  createdAt DateTime @default(now())
  updatedAt DateTime @default(now()) @updatedAt

  userId String @db.Uuid
  User   User   @relation(fields: [userId], references: [id], onDelete: Cascade)

  agentId      String @db.Uuid
  agentVersion Int
  Agent        Agent  @relation(fields: [agentId, agentVersion], references: [id, version])

  agentPresetId String?      @db.Uuid
  AgentPreset   AgentPreset? @relation(fields: [agentPresetId], references: [id])

  isFavorite      Boolean @default(false)
  isCreatedByUser Boolean @default(false)
  isArchived      Boolean @default(false)
  isDeleted       Boolean @default(false)

  @@index([userId])
}

////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
//////// AGENT DEFINITION AND EXECUTION TABLES ////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
// This model describes a single node in the Agent Graph/Flow (Multi Agent System).
model AgentNode {
  id String @id @default(uuid()) @db.Uuid

  agentBlockId String     @db.Uuid
  AgentBlock   AgentBlock @relation(fields: [agentBlockId], references: [id], onUpdate: Cascade)

  agentId      String @db.Uuid
  agentVersion Int    @default(1)
  Agent        Agent  @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)

  // List of consumed input, that the parent node should provide.
  Input AgentNodeLink[] @relation("AgentNodeSink")

  // List of produced output, that the child node should be executed.
  Output AgentNodeLink[] @relation("AgentNodeSource")

  // JSON serialized dict[str, str] containing predefined input values.
  constantInput Json @default("{}")

  // JSON serialized dict[str, str] containing the node metadata.
  metadata Json @default("{}")

  ExecutionHistory AgentNodeExecution[]
}

// This model describes the link between two AgentNodes.
model AgentNodeLink {
  id String @id @default(uuid()) @db.Uuid

  // Output of a node is connected to the source of the link.
  agentNodeSourceId String    @db.Uuid
  AgentNodeSource   AgentNode @relation("AgentNodeSource", fields: [agentNodeSourceId], references: [id], onDelete: Cascade)
  sourceName        String

  // Input of a node is connected to the sink of the link.
  agentNodeSinkId String    @db.Uuid
  AgentNodeSink   AgentNode @relation("AgentNodeSink", fields: [agentNodeSinkId], references: [id], onDelete: Cascade)
  sinkName        String

  // Default: the data coming from the source can only be consumed by the sink once, Static: input data will be reused.
  isStatic Boolean @default(false)
}

// This model describes a component that will be executed by the AgentNode.
model AgentBlock {
  id   String @id @default(uuid()) @db.Uuid
  name String @unique

  // We allow a block to have multiple types of input & output.
  // Serialized object-typed `jsonschema` with top-level properties as input/output name.
  inputSchema  Json @default("{}")
  outputSchema Json @default("{}")

  // Prisma requires explicit back-references.
  ReferencedByAgentNode AgentNode[]
  UserBlockCredit       UserBlockCredit[]
}

// This model describes the status of an AgentExecution or AgentNodeExecution.
enum AgentExecutionStatus {
  INCOMPLETE
  QUEUED
  RUNNING
  COMPLETED
  FAILED
}

// Enum for execution trigger types
enum ExecutionTriggerType {
  MANUAL
  SCHEDULE
  WEBHOOK
}

// This model describes the execution of an AgentGraph.
model AgentExecution {
  id                   String               @id @default(uuid()) @db.Uuid
  createdAt            DateTime             @default(now())
  updatedAt            DateTime             @default(now()) @updatedAt
  startedAt            DateTime?
  executionTriggerType ExecutionTriggerType @default(MANUAL)

  executionStatus AgentExecutionStatus @default(COMPLETED)

  agentId      String @db.Uuid
  agentVersion Int    @default(1)
  Agent        Agent  @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)

  // We need to be able to associate an agent execution with an agent preset.
  agentPresetId String?      @db.Uuid
  AgentPreset   AgentPreset? @relation(fields: [agentPresetId], references: [id])

  AgentNodeExecutions AgentNodeExecution[]

  // This is so we can track which user executed the agent.
  executedByUserId String @db.Uuid
  ExecutedByUser   User   @relation(fields: [executedByUserId], references: [id], onDelete: Cascade)

  stats Json @default("{}") // JSON serialized object
}

// This model describes the execution of an AgentNode.
model AgentNodeExecution {
  id String @id @default(uuid()) @db.Uuid

  agentExecutionId String         @db.Uuid
  AgentExecution   AgentExecution @relation(fields: [agentExecutionId], references: [id], onDelete: Cascade)

  agentNodeId String    @db.Uuid
  AgentNode   AgentNode @relation(fields: [agentNodeId], references: [id], onDelete: Cascade)

  Input  AgentNodeExecutionInputOutput[] @relation("AgentNodeExecutionInput")
  Output AgentNodeExecutionInputOutput[] @relation("AgentNodeExecutionOutput")

  executionStatus AgentExecutionStatus @default(COMPLETED)
  // Final JSON serialized input data for the node execution.
  executionData   String?
  addedTime       DateTime             @default(now())
  queuedTime      DateTime?
  startedTime     DateTime?
  endedTime       DateTime?

  stats           Json              @default("{}") // JSON serialized object
  UserBlockCredit UserBlockCredit[]
}

// This model describes the output of an AgentNodeExecution.
model AgentNodeExecutionInputOutput {
  id String @id @default(uuid()) @db.Uuid

  name String
  data String
  time DateTime @default(now())

  // Prisma requires explicit back-references.
  referencedByInputExecId  String?             @db.Uuid
  ReferencedByInputExec    AgentNodeExecution? @relation("AgentNodeExecutionInput", fields: [referencedByInputExecId], references: [id], onDelete: Cascade)
  referencedByOutputExecId String?             @db.Uuid
  ReferencedByOutputExec   AgentNodeExecution? @relation("AgentNodeExecutionOutput", fields: [referencedByOutputExecId], references: [id], onDelete: Cascade)

  agentPresetId String?      @db.Uuid
  AgentPreset   AgentPreset? @relation("AgentPresetsInputData", fields: [agentPresetId], references: [id])

  // Input and Output pin names are unique for each AgentNodeExecution.
  @@unique([referencedByInputExecId, referencedByOutputExecId, name])
}

// This model describes the recurring execution schedule of an Agent.
model AgentExecutionSchedule {
  id        String   @id @default(uuid()) @db.Uuid
  createdAt DateTime @default(now())
  updatedAt DateTime @default(now()) @updatedAt

  agentPresetId String      @db.Uuid
  AgentPreset   AgentPreset @relation(fields: [agentPresetId], references: [id], onDelete: Cascade)

  schedule  String // cron expression
  isEnabled Boolean @default(true)

  // Allows triggers to be routed down different execution paths in an agent graph.
  triggerIdentifier String

  // Default and set the value on each update; the lastUpdated field has no time zone.
  lastUpdated DateTime @default(now()) @updatedAt

  // Link to User model
  userId       String  @db.Uuid
  User         User    @relation(fields: [userId], references: [id], onDelete: Cascade)
  Agent        Agent?  @relation(fields: [agentId, agentVersion], references: [id, version])
  agentId      String? @db.Uuid
  agentVersion Int?

  @@index([isEnabled])
}

enum HttpMethod {
  GET
  POST
  PUT
  DELETE
  PATCH
}

model WebhookTrigger {
  id        String   @id @default(uuid()) @db.Uuid
  createdAt DateTime @default(now())
  updatedAt DateTime @default(now()) @updatedAt

  agentPresetId String      @db.Uuid
  AgentPreset   AgentPreset @relation(fields: [agentPresetId], references: [id])

  method  HttpMethod
  urlSlug String

  // Allows triggers to be routed down different execution paths in an agent graph.
  triggerIdentifier String

  isActive           Boolean   @default(true)
  lastReceivedDataAt DateTime?
  isDeleted          Boolean   @default(false)
  Agent              Agent?    @relation(fields: [agentId, agentVersion], references: [id, version])
  agentId            String?   @db.Uuid
  agentVersion       Int?

  @@index([agentPresetId])
}

////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
////////////// METRICS TRACKING TABLES ////////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
model AnalyticsDetails {
  // PK uses gen_random_uuid() to allow the db inserts to happen outside of prisma;
  // typical uuid() inserts are handled by prisma.
  id String @id @default(dbgenerated("gen_random_uuid()")) @db.Uuid

  createdAt DateTime @default(now())
  updatedAt DateTime @default(now()) @updatedAt

  // Link to User model
  userId String @db.Uuid
  User   User   @relation(fields: [userId], references: [id], onDelete: Cascade)

  // Analytics categorical data used for filtering (indexable with and without userId)
  type String

  // Analytic-specific data. We should use a union type here, but prisma doesn't support it.
  data Json @default("{}")

  // Indexable field for any count-based analytical measures like page order clicking, tutorial step completion, etc.
  dataIndex String?

  @@index([userId, type], name: "analyticsDetails")
  @@index([type])
}

model AnalyticsMetrics {
  id        String   @id @default(dbgenerated("gen_random_uuid()")) @db.Uuid
  createdAt DateTime @default(now())
  updatedAt DateTime @default(now()) @updatedAt

  // Analytics categorical data used for filtering (indexable with and without userId)
  analyticMetric String
  // Any numeric data that should be counted upon, summed, or otherwise aggregated.
  value          Float
  // Any string data that should be used to identify the metric as distinct.
  // ex: '/build' vs '/market'
  dataString     String?

  // Link to User model
  userId String @db.Uuid
  User   User   @relation(fields: [userId], references: [id], onDelete: Cascade)
}

////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
//////// ACCOUNTING AND CREDIT SYSTEM TABLES //////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////

enum UserBlockCreditType {
  TOP_UP
  USAGE
}

model UserBlockCredit {
  transactionKey String   @default(uuid())
  createdAt      DateTime @default(now())

  userId String @db.Uuid
  User   User   @relation(fields: [userId], references: [id], onDelete: Cascade)

  blockId String?     @db.Uuid
  Block   AgentBlock? @relation(fields: [blockId], references: [id])

  // We need to be able to associate a credit transaction with an agent.
  executedAgentId      String? @db.Uuid
  executedAgentVersion Int?
  ExecutedAgent        Agent?  @relation(fields: [executedAgentId, executedAgentVersion], references: [id, version])

  // We need to be able to associate a cost with a specific agent execution.
  agentNodeExecutionId String?             @db.Uuid
  AgentNodeExecution   AgentNodeExecution? @relation(fields: [agentNodeExecutionId], references: [id])

  amount Int
  type   UserBlockCreditType

  isActive Boolean @default(true)
  metadata Json    @default("{}")

  @@id(name: "creditTransactionIdentifier", [transactionKey, userId])
}

////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
////////////// Store TABLES ///////////////////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////

model Profile {
  id        String   @id @default(uuid()) @db.Uuid
  createdAt DateTime @default(now())
  updatedAt DateTime @default(now()) @updatedAt

  // Only 1 of user or group can be set.
|
||||
// The user this profile belongs to, if any.
|
||||
userId String? @db.Uuid
|
||||
User User? @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
// The group this profile belongs to, if any.
|
||||
groupId String? @db.Uuid
|
||||
Group UserGroup? @relation(fields: [groupId], references: [id])
|
||||
|
||||
username String @unique
|
||||
description String
|
||||
|
||||
links String[]
|
||||
|
||||
avatarUrl String?
|
||||
|
||||
@@index([username])
|
||||
}
|
||||
|
||||
model StoreListing {
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
|
||||
isDeleted Boolean @default(false)
|
||||
// Not needed but makes lookups faster
|
||||
isApproved Boolean @default(false)
|
||||
|
||||
// The agent link here is only so we can do lookup on agentId, for the listing the StoreListingVersion is used.
|
||||
agentId String @db.Uuid
|
||||
agentVersion Int
|
||||
Agent Agent @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)
|
||||
|
||||
owningUserId String @db.Uuid
|
||||
OwningUser User @relation(fields: [owningUserId], references: [id])
|
||||
|
||||
isGroupListing Boolean @default(false)
|
||||
owningGroupId String? @db.Uuid
|
||||
OwningGroup UserGroup? @relation(fields: [owningGroupId], references: [id])
|
||||
|
||||
StoreListingVersions StoreListingVersion[]
|
||||
StoreListingSubmission StoreListingSubmission[]
|
||||
|
||||
@@index([isApproved])
|
||||
@@index([agentId])
|
||||
@@index([owningUserId])
|
||||
@@index([owningGroupId])
|
||||
}
|
||||
|
||||
model StoreListingVersion {
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
version Int @default(1)
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
|
||||
// The agent and version to be listed on the store
|
||||
agentId String @db.Uuid
|
||||
agentVersion Int
|
||||
Agent Agent @relation(fields: [agentId, agentVersion], references: [id, version])
|
||||
|
||||
// The detials for this version of the agent, this allows the author to update the details of the agent,
|
||||
// But still allow using old versions of the agent with there original details.
|
||||
// TODO: Create a database view that shows only the latest version of each store listing.
|
||||
slug String
|
||||
name String
|
||||
videoUrl String?
|
||||
imageUrls String[]
|
||||
description String
|
||||
categories String[]
|
||||
|
||||
isFeatured Boolean @default(false)
|
||||
|
||||
isDeleted Boolean @default(false)
|
||||
// Old versions can be made unavailable by the author if desired
|
||||
isAvailable Boolean @default(true)
|
||||
// Not needed but makes lookups faster
|
||||
isApproved Boolean @default(false)
|
||||
StoreListing StoreListing? @relation(fields: [storeListingId], references: [id], onDelete: Cascade)
|
||||
storeListingId String? @db.Uuid
|
||||
StoreListingSubmission StoreListingSubmission[]
|
||||
|
||||
// Reviews are on a specific version, but then aggregated up to the listing.
|
||||
// This allows us to provide a review filter to current version of the agent.
|
||||
StoreListingReview StoreListingReview[]
|
||||
|
||||
@@unique([agentId, agentVersion])
|
||||
@@index([agentId, agentVersion, isApproved])
|
||||
}
|
||||
|
||||
model StoreListingReview {
|
||||
id String @id @default(dbgenerated("gen_random_uuid()")) @db.Uuid
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
|
||||
storeListingVersionId String @db.Uuid
|
||||
StoreListingVersion StoreListingVersion @relation(fields: [storeListingVersionId], references: [id], onDelete: Cascade)
|
||||
|
||||
reviewByUserId String @db.Uuid
|
||||
ReviewByUser User @relation(fields: [reviewByUserId], references: [id])
|
||||
|
||||
score Int
|
||||
comments String?
|
||||
}
|
||||
|
||||
enum SubmissionStatus {
|
||||
DAFT
|
||||
PENDING
|
||||
APPROVED
|
||||
REJECTED
|
||||
}
|
||||
|
||||
model StoreListingSubmission {
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
|
||||
storeListingId String @db.Uuid
|
||||
StoreListing StoreListing @relation(fields: [storeListingId], references: [id], onDelete: Cascade)
|
||||
|
||||
storeListingVersionId String @db.Uuid
|
||||
StoreListingVersion StoreListingVersion @relation(fields: [storeListingVersionId], references: [id], onDelete: Cascade)
|
||||
|
||||
reviewerId String @db.Uuid
|
||||
Reviewer User @relation(fields: [reviewerId], references: [id])
|
||||
|
||||
Status SubmissionStatus @default(PENDING)
|
||||
reviewComments String?
|
||||
|
||||
@@index([storeListingId])
|
||||
@@index([Status])
|
||||
}
|
||||
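For reference, a minimal sketch of how these credit rows might be consumed. It assumes the Python Prisma client generated from the schema above, a hypothetical user id, and the unconfirmed sign convention that TOP_UP rows store positive amounts while USAGE rows store negative ones; none of this is part of the diff itself:

# Sketch only: derive a user's remaining block credit from UserBlockCredit rows.
# Assumes the Prisma Python client generated from the schema above; the sign
# convention for USAGE amounts is an assumption, not confirmed by this diff.
import asyncio

from prisma import Prisma


async def get_block_credit_balance(user_id: str) -> int:
    """Sum a user's active UserBlockCredit rows into a single balance."""
    db = Prisma()
    await db.connect()
    try:
        transactions = await db.userblockcredit.find_many(
            where={"userId": user_id, "isActive": True}
        )
        # TOP_UP rows add credit; USAGE rows are assumed to carry negative amounts.
        return sum(tx.amount for tx in transactions)
    finally:
        await db.disconnect()


if __name__ == "__main__":
    # Hypothetical user id, for illustration only.
    print(asyncio.run(get_block_credit_balance("00000000-0000-0000-0000-000000000000")))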
@@ -1,7 +1,7 @@
from datetime import datetime

import pytest
from prisma.models import UserBlockCredit
from prisma.models import CreditTransaction

from backend.blocks.llm import AITextGeneratorBlock
from backend.data.credit import UserCredit
@@ -82,7 +82,7 @@ async def test_block_credit_reset(server: SpinTestServer):
@pytest.mark.asyncio(scope="session")
async def test_credit_refill(server: SpinTestServer):
    # Clear all transactions within the month
    await UserBlockCredit.prisma().update_many(
    await CreditTransaction.prisma().update_many(
        where={
            "userId": DEFAULT_USER_ID,
            "createdAt": {
@@ -14,7 +14,6 @@ async def test_agent_schedule(server: SpinTestServer):
    test_user = await create_test_user()
    test_graph = await server.agent_server.test_create_graph(
        create_graph=CreateGraph(graph=create_test_graph()),
        is_template=False,
        user_id=test_user.id,
    )
436
autogpt_platform/backend/test/test_data_creator.py
Normal file
@@ -0,0 +1,436 @@
import asyncio
import random
from datetime import datetime

import prisma.enums
from faker import Faker
from prisma import Prisma

faker = Faker()

# Constants for data generation limits

# Base entities
NUM_USERS = 100  # Creates 100 user records
NUM_AGENT_BLOCKS = 100  # Creates 100 agent block templates

# Per-user entities
MIN_GRAPHS_PER_USER = 1  # Each user will have between 1-5 graphs
MAX_GRAPHS_PER_USER = 5  # Total graphs: 100-500 (NUM_USERS * MIN/MAX_GRAPHS)

# Per-graph entities
MIN_NODES_PER_GRAPH = 2  # Each graph will have between 2-5 nodes
MAX_NODES_PER_GRAPH = (
    5  # Total nodes: 200-2500 (TOTAL_GRAPHS * MIN/MAX_NODES)
)

# Additional per-user entities
MIN_PRESETS_PER_USER = 1  # Each user will have between 1-5 presets
MAX_PRESETS_PER_USER = 5  # Total presets: 100-500 (NUM_USERS * MIN/MAX_PRESETS)
MIN_AGENTS_PER_USER = 1  # Each user will have between 1-10 agents
MAX_AGENTS_PER_USER = 10  # Total agents: 100-1000 (NUM_USERS * MIN/MAX_AGENTS)

# Execution and review records
MIN_EXECUTIONS_PER_GRAPH = 1  # Each graph will have between 1-20 execution records
MAX_EXECUTIONS_PER_GRAPH = (
    20  # Total executions: 100-10000 (TOTAL_GRAPHS * MIN/MAX_EXECUTIONS)
)
MIN_REVIEWS_PER_VERSION = 1  # Each version will have between 1-5 reviews
MAX_REVIEWS_PER_VERSION = 5  # Total reviews depends on number of versions created


def get_image():
    # faker occasionally returns placekitten.com URLs, which are unreliable; skip them.
    url = faker.image_url()
    while "placekitten.com" in url:
        url = faker.image_url()
    return url


async def main():
    db = Prisma()
    await db.connect()

    # Insert Users
    print(f"Inserting {NUM_USERS} users")
    users = []
    for _ in range(NUM_USERS):
        user = await db.user.create(
            data={
                "id": str(faker.uuid4()),
                "email": faker.unique.email(),
                "name": faker.name(),
                "metadata": prisma.Json({}),
                "integrations": "",
            }
        )
        users.append(user)

    # Insert AgentBlocks
    agent_blocks = []
    print(f"Inserting {NUM_AGENT_BLOCKS} agent blocks")
    for _ in range(NUM_AGENT_BLOCKS):
        block = await db.agentblock.create(
            data={
                "name": f"{faker.word()}_{str(faker.uuid4())[:8]}",
                "inputSchema": "{}",
                "outputSchema": "{}",
            }
        )
        agent_blocks.append(block)

    # Insert AgentGraphs
    agent_graphs = []
    print(f"Inserting up to {NUM_USERS * MAX_GRAPHS_PER_USER} agent graphs")
    for user in users:
        for _ in range(random.randint(MIN_GRAPHS_PER_USER, MAX_GRAPHS_PER_USER)):
            graph = await db.agentgraph.create(
                data={
                    "name": faker.sentence(nb_words=3),
                    "description": faker.text(max_nb_chars=200),
                    "userId": user.id,
                    "isActive": True,
                    "isTemplate": False,
                }
            )
            agent_graphs.append(graph)

    # Insert AgentNodes
    agent_nodes = []
    print(
        f"Inserting up to {NUM_USERS * MAX_GRAPHS_PER_USER * MAX_NODES_PER_GRAPH} agent nodes"
    )
    for graph in agent_graphs:
        num_nodes = random.randint(MIN_NODES_PER_GRAPH, MAX_NODES_PER_GRAPH)
        for _ in range(num_nodes):  # Create 2-5 AgentNodes per graph
            block = random.choice(agent_blocks)
            node = await db.agentnode.create(
                data={
                    "agentBlockId": block.id,
                    "agentGraphId": graph.id,
                    "agentGraphVersion": graph.version,
                    "constantInput": "{}",
                    "metadata": "{}",
                }
            )
            agent_nodes.append(node)

    # Insert AgentPresets
    agent_presets = []
    print(f"Inserting up to {NUM_USERS * MAX_PRESETS_PER_USER} agent presets")
    for user in users:
        num_presets = random.randint(MIN_PRESETS_PER_USER, MAX_PRESETS_PER_USER)
        for _ in range(num_presets):  # Create 1-5 AgentPresets per user
            graph = random.choice(agent_graphs)
            preset = await db.agentpreset.create(
                data={
                    "name": faker.sentence(nb_words=3),
                    "description": faker.text(max_nb_chars=200),
                    "userId": user.id,
                    "agentId": graph.id,
                    "agentVersion": graph.version,
                    "isActive": True,
                }
            )
            agent_presets.append(preset)

    # Insert UserAgents
    user_agents = []
    print(f"Inserting up to {NUM_USERS * MAX_AGENTS_PER_USER} user agents")
    for user in users:
        num_agents = random.randint(MIN_AGENTS_PER_USER, MAX_AGENTS_PER_USER)
        for _ in range(num_agents):  # Create 1-10 UserAgents per user
            graph = random.choice(agent_graphs)
            preset = random.choice(agent_presets)
            user_agent = await db.useragent.create(
                data={
                    "userId": user.id,
                    "agentId": graph.id,
                    "agentVersion": graph.version,
                    "agentPresetId": preset.id,
                    "isFavorite": random.choice([True, False]),
                    "isCreatedByUser": random.choice([True, False]),
                    "isArchived": random.choice([True, False]),
                    "isDeleted": random.choice([True, False]),
                }
            )
            user_agents.append(user_agent)

    # Insert AgentGraphExecutions
    agent_graph_executions = []
    print(
        f"Inserting up to {NUM_USERS * MAX_GRAPHS_PER_USER * MAX_EXECUTIONS_PER_GRAPH} agent graph executions"
    )
    graph_execution_data = []
    for graph in agent_graphs:
        user = random.choice(users)
        num_executions = random.randint(
            MIN_EXECUTIONS_PER_GRAPH, MAX_EXECUTIONS_PER_GRAPH
        )
        for _ in range(num_executions):
            matching_presets = [p for p in agent_presets if p.agentId == graph.id]
            preset = (
                random.choice(matching_presets)
                if matching_presets and random.random() < 0.5
                else None
            )

            graph_execution_data.append(
                {
                    "agentGraphId": graph.id,
                    "agentGraphVersion": graph.version,
                    "userId": user.id,
                    "executionStatus": prisma.enums.AgentExecutionStatus.COMPLETED,
                    "startedAt": faker.date_time_this_year(),
                    "agentPresetId": preset.id if preset else None,
                }
            )

    agent_graph_executions = await db.agentgraphexecution.create_many(
        data=graph_execution_data
    )
    # Need to fetch the created records since create_many doesn't return them
    agent_graph_executions = await db.agentgraphexecution.find_many()

    # Insert AgentNodeExecutions
    print(
        f"Inserting up to {NUM_USERS * MAX_GRAPHS_PER_USER * MAX_EXECUTIONS_PER_GRAPH} agent node executions"
    )
    node_execution_data = []
    for execution in agent_graph_executions:
        nodes = [
            node for node in agent_nodes if node.agentGraphId == execution.agentGraphId
        ]
        for node in nodes:
            node_execution_data.append(
                {
                    "agentGraphExecutionId": execution.id,
                    "agentNodeId": node.id,
                    "executionStatus": prisma.enums.AgentExecutionStatus.COMPLETED,
                    "addedTime": datetime.now(),
                }
            )

    agent_node_executions = await db.agentnodeexecution.create_many(
        data=node_execution_data
    )
    # Need to fetch the created records since create_many doesn't return them
    agent_node_executions = await db.agentnodeexecution.find_many()

    # Insert AgentNodeExecutionInputOutput
    print(
        f"Inserting up to {NUM_USERS * MAX_GRAPHS_PER_USER * MAX_EXECUTIONS_PER_GRAPH} agent node execution input/outputs"
    )
    input_output_data = []
    for node_execution in agent_node_executions:
        # Input data
        input_output_data.append(
            {
                "name": "input1",
                "data": "{}",
                "time": datetime.now(),
                "referencedByInputExecId": node_execution.id,
            }
        )
        # Output data
        input_output_data.append(
            {
                "name": "output1",
                "data": "{}",
                "time": datetime.now(),
                "referencedByOutputExecId": node_execution.id,
            }
        )

    await db.agentnodeexecutioninputoutput.create_many(data=input_output_data)

    # Insert AgentNodeLinks
    print(f"Inserting up to {NUM_USERS * MAX_GRAPHS_PER_USER} agent node links")
    for graph in agent_graphs:
        nodes = [node for node in agent_nodes if node.agentGraphId == graph.id]
        if len(nodes) >= 2:
            source_node = nodes[0]
            sink_node = nodes[1]
            await db.agentnodelink.create(
                data={
                    "agentNodeSourceId": source_node.id,
                    "sourceName": "output1",
                    "agentNodeSinkId": sink_node.id,
                    "sinkName": "input1",
                    "isStatic": False,
                }
            )

    # Insert AnalyticsDetails
    print(f"Inserting {NUM_USERS} analytics details")
    for user in users:
        await db.analyticsdetails.create(
            data={
                "userId": user.id,
                "type": faker.word(),
                "data": prisma.Json({}),
                "dataIndex": faker.word(),
            }
        )

    # Insert AnalyticsMetrics
    print(f"Inserting {NUM_USERS} analytics metrics")
    for user in users:
        await db.analyticsmetrics.create(
            data={
                "userId": user.id,
                "analyticMetric": faker.word(),
                "value": random.uniform(0, 100),
                "dataString": faker.word(),
            }
        )

    # Insert CreditTransaction (formerly UserBlockCredit)
    print(f"Inserting {NUM_USERS} credit transactions")
    for user in users:
        block = random.choice(agent_blocks)
        await db.credittransaction.create(
            data={
                "transactionKey": str(faker.uuid4()),
                "userId": user.id,
                "blockId": block.id,
                "amount": random.randint(1, 100),
                "type": (
                    prisma.enums.CreditTransactionType.TOP_UP
                    if random.random() < 0.5
                    else prisma.enums.CreditTransactionType.USAGE
                ),
                "metadata": prisma.Json({}),
            }
        )

    # Insert Profiles
    profiles = []
    print(f"Inserting {NUM_USERS} profiles")
    for user in users:
        profile = await db.profile.create(
            data={
                "userId": user.id,
                "name": user.name or faker.name(),
                "username": faker.unique.user_name(),
                "description": faker.text(),
                "links": [faker.url() for _ in range(3)],
                "avatarUrl": get_image(),
            }
        )
        profiles.append(profile)

    # Insert StoreListings
    store_listings = []
    print(f"Inserting {len(agent_graphs)} store listings")
    for graph in agent_graphs:
        user = random.choice(users)
        listing = await db.storelisting.create(
            data={
                "agentId": graph.id,
                "agentVersion": graph.version,
                "owningUserId": user.id,
                "isApproved": random.choice([True, False]),
            }
        )
        store_listings.append(listing)

    # Insert StoreListingVersions
    store_listing_versions = []
    print(f"Inserting {len(store_listings)} store listing versions")
    for listing in store_listings:
        graph = [g for g in agent_graphs if g.id == listing.agentId][0]
        version = await db.storelistingversion.create(
            data={
                "agentId": graph.id,
                "agentVersion": graph.version,
                "slug": faker.slug(),
                "name": graph.name or faker.sentence(nb_words=3),
                "subHeading": faker.sentence(),
                "videoUrl": faker.url(),
                "imageUrls": [get_image() for _ in range(3)],
                "description": faker.text(),
                "categories": [faker.word() for _ in range(3)],
                "isFeatured": random.choice([True, False]),
                "isAvailable": True,
                "isApproved": random.choice([True, False]),
                "storeListingId": listing.id,
            }
        )
        store_listing_versions.append(version)

    # Insert StoreListingReviews
    print(
        f"Inserting up to {len(store_listing_versions) * MAX_REVIEWS_PER_VERSION} store listing reviews"
    )
    for version in store_listing_versions:
        # Create a copy of the users list and shuffle it to avoid duplicate reviewers
        available_reviewers = users.copy()
        random.shuffle(available_reviewers)

        # Limit the number of reviews to the available unique reviewers
        num_reviews = min(
            random.randint(MIN_REVIEWS_PER_VERSION, MAX_REVIEWS_PER_VERSION),
            len(available_reviewers),
        )

        # Take only the first num_reviews reviewers
        for reviewer in available_reviewers[:num_reviews]:
            await db.storelistingreview.create(
                data={
                    "storeListingVersionId": version.id,
                    "reviewByUserId": reviewer.id,
                    "score": random.randint(1, 5),
                    "comments": faker.text(),
                }
            )

    # Insert StoreListingSubmissions
    print(f"Inserting {len(store_listings)} store listing submissions")
    for listing in store_listings:
        version = random.choice(store_listing_versions)
        reviewer = random.choice(users)
        status: prisma.enums.SubmissionStatus = random.choice(
            [
                prisma.enums.SubmissionStatus.PENDING,
                prisma.enums.SubmissionStatus.APPROVED,
                prisma.enums.SubmissionStatus.REJECTED,
            ]
        )
        await db.storelistingsubmission.create(
            data={
                "storeListingId": listing.id,
                "storeListingVersionId": version.id,
                "reviewerId": reviewer.id,
                "Status": status,
                "reviewComments": faker.text(),
            }
        )

    # Insert APIKeys
    print(f"Inserting {NUM_USERS} api keys")
    for user in users:
        await db.apikey.create(
            data={
                "name": faker.word(),
                "prefix": str(faker.uuid4())[:8],
                "postfix": str(faker.uuid4())[-8:],
                "key": str(faker.sha256()),
                "status": prisma.enums.APIKeyStatus.ACTIVE,
                "permissions": [
                    prisma.enums.APIKeyPermission.EXECUTE_GRAPH,
                    prisma.enums.APIKeyPermission.READ_GRAPH,
                ],
                "description": faker.text(),
                "userId": user.id,
            }
        )

    await db.disconnect()


if __name__ == "__main__":
    asyncio.run(main())
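A usage note, hedged since this diff documents neither command: assuming the backend's Prisma client has already been generated (e.g. via `poetry run prisma generate`) and a database is reachable through the configured connection string, the seeder above would presumably be run from the backend directory with something like `poetry run python test/test_data_creator.py`. Both invocations are inferred from the repo's Poetry setup rather than stated here.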
@@ -4,6 +4,7 @@ from backend.util.request import validate_url


def test_validate_url():
    # Rejected IP ranges
    with pytest.raises(ValueError):
        validate_url("localhost", [])

@@ -16,6 +17,63 @@ def test_validate_url():
    with pytest.raises(ValueError):
        validate_url("0.0.0.0", [])

    validate_url("google.com", [])
    validate_url("github.com", [])
    validate_url("http://github.com", [])
    # Normal URLs
    assert validate_url("google.com/a?b=c", []) == "http://google.com/a?b=c"
    assert validate_url("github.com?key=!@!@", []) == "http://github.com?key=!@!@"

    # Scheme Enforcement
    with pytest.raises(ValueError):
        validate_url("ftp://example.com", [])
    with pytest.raises(ValueError):
        validate_url("file://example.com", [])

    # International domain that converts to punycode - should be allowed if public
    assert validate_url("http://xn--exmple-cua.com", []) == "http://xn--exmple-cua.com"
    # If the domain fails IDNA encoding or is invalid, it should raise an error
    with pytest.raises(ValueError):
        validate_url("http://exa◌mple.com", [])

    # IPv6 Addresses
    with pytest.raises(ValueError):
        validate_url("::1", [])  # IPv6 loopback should be blocked
    with pytest.raises(ValueError):
        validate_url("http://[::1]", [])  # IPv6 loopback in URL form

    # Suspicious Characters in Hostname
    with pytest.raises(ValueError):
        validate_url("http://example_underscore.com", [])
    with pytest.raises(ValueError):
        validate_url("http://exa mple.com", [])  # Space in hostname

    # Malformed URLs
    with pytest.raises(ValueError):
        validate_url("http://", [])  # No hostname
    with pytest.raises(ValueError):
        validate_url("://missing-scheme", [])  # Missing proper scheme

    # Trusted Origins
    trusted = ["internal-api.company.com", "10.0.0.5"]
    assert (
        validate_url("internal-api.company.com", trusted)
        == "http://internal-api.company.com"
    )
    assert validate_url("10.0.0.5", ["10.0.0.5"]) == "http://10.0.0.5"

    # Special Characters in Path or Query
    assert (
        validate_url("example.com/path%20with%20spaces", [])
        == "http://example.com/path%20with%20spaces"
    )

    # Backslashes should be replaced with forward slashes
    assert (
        validate_url("http://example.com\\backslash", [])
        == "http://example.com/backslash"
    )

    # Check defaulting scheme behavior for valid domains
    assert validate_url("example.com", []) == "http://example.com"
    assert validate_url("https://secure.com", []) == "https://secure.com"

    # Non-ASCII Characters in Query/Fragment
    assert validate_url("example.com?param=äöü", []) == "http://example.com?param=äöü"
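For context, here is a minimal sketch of a `validate_url` that would satisfy the assertions above. The real implementation lives in `backend.util.request` and is not part of this diff, so the structure, error messages, and exact blocking rules below are all assumptions (including the use of the third-party `idna` package; the real code may also resolve DNS to catch hostnames pointing at private addresses, which this sketch deliberately skips):

# Illustrative sketch only -- NOT the actual backend.util.request.validate_url.
# It mirrors the behavior the tests above assume.
import ipaddress
import re
from urllib.parse import urlparse

import idna


def validate_url(url: str, trusted_origins: list[str]) -> str:
    # Backslashes become forward slashes, and a missing scheme defaults to http://.
    url = url.strip().replace("\\", "/")
    if "://" not in url:
        url = "http://" + url

    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https"):
        raise ValueError(f"Scheme '{parsed.scheme}' is not allowed")

    hostname = parsed.hostname
    if not hostname:
        raise ValueError("Invalid URL: no hostname")

    try:
        # IP literals (IPv4 or IPv6) are checked directly below.
        ips = [ipaddress.ip_address(hostname)]
        ascii_host = hostname
    except ValueError:
        # Reject whitespace and underscores, then IDNA-encode the hostname so
        # international domains become punycode (or fail loudly if invalid).
        if re.search(r"[\s_]", hostname):
            raise ValueError(f"Invalid characters in hostname: {hostname!r}")
        try:
            ascii_host = idna.encode(hostname).decode()
        except idna.IDNAError:
            raise ValueError(f"Hostname failed IDNA encoding: {hostname!r}")
        ips = []

    # Explicitly trusted origins bypass the blocklist.
    if ascii_host in trusted_origins:
        return url

    # Named loopback hosts; a production implementation would likely resolve
    # DNS and check every resolved address instead of this shortcut.
    if ascii_host == "localhost":
        raise ValueError("Access to localhost is not allowed")

    for ip in ips:
        if ip.is_loopback or ip.is_private or ip.is_link_local or ip.is_unspecified:
            raise ValueError(f"Access to IP {ip} is not allowed")

    return url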
Some files were not shown because too many files have changed in this diff.