Mirror of https://github.com/Significant-Gravitas/AutoGPT.git
Synced 2026-01-13 09:08:02 -05:00

Compare commits: change-log...clarify-li
5 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 8155a11d53 | |
| | 490dbde2cc | |
| | f29e8c514d | |
| | 79914655ae | |
| | 6720ba895c | |
@@ -34,7 +34,6 @@ jobs:
python -m prisma migrate deploy
env:
DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}

trigger:
@@ -36,7 +36,6 @@ jobs:
python -m prisma migrate deploy
env:
DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}

trigger:
needs: migrate
33 .github/workflows/platform-backend-ci.yml vendored
@@ -80,35 +80,18 @@ jobs:

- name: Install Poetry (Unix)
run: |
# Extract Poetry version from backend/poetry.lock
HEAD_POETRY_VERSION=$(head -n 1 poetry.lock | grep -oP '(?<=Poetry )[0-9]+\.[0-9]+\.[0-9]+')
echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"

if [ -n "$BASE_REF" ]; then
BASE_BRANCH=${BASE_REF/refs\/heads\//}
BASE_POETRY_VERSION=$((git show "origin/$BASE_BRANCH":./poetry.lock; true) | head -n 1 | grep -oP '(?<=Poetry )[0-9]+\.[0-9]+\.[0-9]+')
echo "Found Poetry version ${BASE_POETRY_VERSION} in backend/poetry.lock on ${BASE_REF}"
POETRY_VERSION=$(printf '%s\n' "$HEAD_POETRY_VERSION" "$BASE_POETRY_VERSION" | sort -V | tail -n1)
else
POETRY_VERSION=$HEAD_POETRY_VERSION
fi
echo "Using Poetry version ${POETRY_VERSION}"

# Install Poetry
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$POETRY_VERSION python3 -
curl -sSL https://install.python-poetry.org | python3 -

if [ "${{ runner.os }}" = "macOS" ]; then
PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi
env:
BASE_REF: ${{ github.base_ref || github.event.merge_group.base_ref }}

- name: Check poetry.lock
run: |
poetry lock

if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
if ! git diff --quiet poetry.lock; then
echo "Error: poetry.lock not up to date."
echo
git diff poetry.lock
@@ -135,7 +118,6 @@ jobs:
run: poetry run prisma migrate dev --name updates
env:
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}

- id: lint
name: Run Linter
@@ -152,13 +134,12 @@ jobs:
env:
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
REDIS_HOST: "localhost"
REDIS_PORT: "6379"
REDIS_PASSWORD: "testpassword"
REDIS_HOST: 'localhost'
REDIS_PORT: '6379'
REDIS_PASSWORD: 'testpassword'

env:
CI: true
@@ -171,8 +152,8 @@ jobs:
# If you want to replace this, you can do so by making our entire system generate
# new credentials for each local user and update the environment variables in
# the backend service, docker composes, and examples
RABBITMQ_DEFAULT_USER: "rabbitmq_user_default"
RABBITMQ_DEFAULT_PASS: "k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7"
RABBITMQ_DEFAULT_USER: 'rabbitmq_user_default'
RABBITMQ_DEFAULT_PASS: 'k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7'

# - name: Upload coverage reports to Codecov
# uses: codecov/codecov-action@v4
173 LICENSE
@@ -1,5 +1,8 @@
All portions of this repository are under one of two licenses. The majority of the AutoGPT repository is under the MIT License below. The autogpt_platform folder is under the
Polyform Shield License.
All portions of this repository are under one of two licenses.

The all files outside of the autogpt_platform folder are under the MIT License below.

The autogpt_platform folder is under the Polyform Shield License below.


MIT License
@@ -27,3 +30,169 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# PolyForm Shield License 1.0.0

<https://polyformproject.org/licenses/shield/1.0.0>

## Acceptance

In order to get any license under these terms, you must agree
to them as both strict obligations and conditions to all
your licenses.

## Copyright License

The licensor grants you a copyright license for the
software to do everything you might do with the software
that would otherwise infringe the licensor's copyright
in it for any permitted purpose. However, you may
only distribute the software according to [Distribution
License](#distribution-license) and make changes or new works
based on the software according to [Changes and New Works
License](#changes-and-new-works-license).

## Distribution License

The licensor grants you an additional copyright license
to distribute copies of the software. Your license
to distribute covers distributing the software with
changes and new works permitted by [Changes and New Works
License](#changes-and-new-works-license).

## Notices

You must ensure that anyone who gets a copy of any part of
the software from you also gets a copy of these terms or the
URL for them above, as well as copies of any plain-text lines
beginning with `Required Notice:` that the licensor provided
with the software. For example:

> Required Notice: Copyright Yoyodyne, Inc. (http://example.com)

## Changes and New Works License

The licensor grants you an additional copyright license to
make changes and new works based on the software for any
permitted purpose.

## Patent License

The licensor grants you a patent license for the software that
covers patent claims the licensor can license, or becomes able
to license, that you would infringe by using the software.

## Noncompete

Any purpose is a permitted purpose, except for providing any
product that competes with the software or any product the
licensor or any of its affiliates provides using the software.
## Competition

Goods and services compete even when they provide functionality
through different kinds of interfaces or for different technical
platforms. Applications can compete with services, libraries
with plugins, frameworks with development tools, and so on,
even if they're written in different programming languages
or for different computer architectures. Goods and services
compete even when provided free of charge. If you market a
product as a practical substitute for the software or another
product, it definitely competes.

## New Products

If you are using the software to provide a product that does
not compete, but the licensor or any of its affiliates brings
your product into competition by providing a new version of
the software or another product using the software, you may
continue using versions of the software available under these
terms beforehand to provide your competing product, but not
any later versions.

## Discontinued Products

You may begin using the software to compete with a product
or service that the licensor or any of its affiliates has
stopped providing, unless the licensor includes a plain-text
line beginning with `Licensor Line of Business:` with the
software that mentions that line of business. For example:

> Licensor Line of Business: YoyodyneCMS Content Management
System (http://example.com/cms)

## Sales of Business

If the licensor or any of its affiliates sells a line of
business developing the software or using the software
to provide a product, the buyer can also enforce
[Noncompete](#noncompete) for that product.
## Fair Use

You may have "fair use" rights for the software under the
law. These terms do not limit them.

## No Other Rights

These terms do not allow you to sublicense or transfer any of
your licenses to anyone else, or prevent the licensor from
granting licenses to anyone else. These terms do not imply
any other licenses.

## Patent Defense

If you make any written claim that the software infringes or
contributes to infringement of any patent, your patent license
for the software granted under these terms ends immediately. If
your company makes such a claim, your patent license ends
immediately for work on behalf of your company.

## Violations

The first time you are notified in writing that you have
violated any of these terms, or done anything with the software
not covered by your licenses, your licenses can nonetheless
continue if you come into full compliance with these terms,
and take practical steps to correct past violations, within
32 days of receiving notice. Otherwise, all your licenses
end immediately.

## No Liability

***As far as the law allows, the software comes as is, without
any warranty or condition, and the licensor will not be liable
to you for any damages arising out of these terms or the use
or nature of the software, under any kind of legal claim.***
## Definitions

The **licensor** is the individual or entity offering these
terms, and the **software** is the software the licensor makes
available under these terms.

A **product** can be a good or service, or a combination
of them.

**You** refers to the individual or entity agreeing to these
terms.

**Your company** is any legal entity, sole proprietorship,
or other kind of organization that you work for, plus all
its affiliates.

**Affiliates** means the other organizations than an
organization has control over, is under the control of, or is
under common control with.

**Control** means ownership of substantially all the assets of
an entity, or the power to direct its management and policies
by vote, contract, or otherwise. Control can be direct or
indirect.

**Your licenses** are all the licenses granted to you for the
software under these terms.

**Use** means anything you do with the software requiring one
of your licenses.
@@ -2,7 +2,6 @@

[](https://discord.gg/autogpt)
[](https://twitter.com/Auto_GPT)
[](https://opensource.org/licenses/MIT)

**AutoGPT** is a powerful platform that allows you to create, deploy, and manage continuous AI agents that automate complex workflows.

@@ -15,11 +14,7 @@
> Setting up and hosting the AutoGPT Platform yourself is a technical process.
> If you'd rather something that just works, we recommend [joining the waitlist](https://bit.ly/3ZDijAI) for the cloud-hosted beta.

### Updated Setup Instructions:
We’ve moved to a fully maintained and regularly updated documentation site.

👉 [Follow the official self-hosting guide here](https://docs.agpt.co/platform/getting-started/)

https://github.com/user-attachments/assets/d04273a5-b36a-4a37-818e-f631ce72d603

This tutorial assumes you have Docker, VSCode, git and npm installed.

@@ -84,7 +79,7 @@ Be part of the revolution! **AutoGPT** is here to stay, at the forefront of AI i

**Licensing:**

MIT License: The majority of the AutoGPT repository is under the MIT License.
MIT License: All files outside of autogpt_platform folder are under the MIT License.

Polyform Shield License: This license applies to the autogpt_platform folder.
@@ -2,7 +2,7 @@

**Contributor License Agreement (“Agreement”)**

Thank you for your interest in the AutoGPT open source project at [https://github.com/Significant-Gravitas/AutoGPT](https://github.com/Significant-Gravitas/AutoGPT) stewarded by Determinist Ltd (“**Determinist**”), with offices at 3rd Floor 1 Ashley Road, Altrincham, Cheshire, WA14 2DT, United Kingdom. The form of license below is a document that clarifies the terms under which You, the person listed below, may contribute software code described below (the “**Contribution**”) to the project. We appreciate your participation in our project, and your help in improving our products, so we want you to understand what will be done with the Contributions. This license is for your protection as well as the protection of Determinist and its licensees; it does not change your rights to use your own Contributions for any other purpose.
Thank you for your interest in the AutoGPT project at [https://github.com/Significant-Gravitas/AutoGPT](https://github.com/Significant-Gravitas/AutoGPT) stewarded by Determinist Ltd (“**Determinist**”), with offices at 3rd Floor 1 Ashley Road, Altrincham, Cheshire, WA14 2DT, United Kingdom. The form of license below is a document that clarifies the terms under which You, the person listed below, may contribute software code described below (the “**Contribution**”) to the project. We appreciate your participation in our project, and your help in improving our products, so we want you to understand what will be done with the Contributions. This license is for your protection as well as the protection of Determinist and its licensees; it does not change your rights to use your own Contributions for any other purpose.

By submitting a Pull Request which modifies the content of the “autogpt\_platform” folder at [https://github.com/Significant-Gravitas/AutoGPT/tree/master/autogpt\_platform](https://github.com/Significant-Gravitas/AutoGPT/tree/master/autogpt_platform), You hereby agree:
@@ -8,7 +8,7 @@ from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

from .filters import BelowLevelFilter
from .formatters import AGPTFormatter
from .formatters import AGPTFormatter, StructuredLoggingFormatter

LOG_DIR = Path(__file__).parent.parent.parent.parent / "logs"
LOG_FILE = "activity.log"
@@ -81,26 +81,9 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
"""

config = LoggingConfig()

log_handlers: list[logging.Handler] = []

# Console output handlers
stdout = logging.StreamHandler(stream=sys.stdout)
stdout.setLevel(config.level)
stdout.addFilter(BelowLevelFilter(logging.WARNING))
if config.level == logging.DEBUG:
stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
else:
stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))

stderr = logging.StreamHandler()
stderr.setLevel(logging.WARNING)
if config.level == logging.DEBUG:
stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
else:
stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))

log_handlers += [stdout, stderr]

# Cloud logging setup
if config.enable_cloud_logging or force_cloud_logging:
import google.cloud.logging
@@ -114,7 +97,26 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
transport=SyncTransport,
)
cloud_handler.setLevel(config.level)
cloud_handler.setFormatter(StructuredLoggingFormatter())
log_handlers.append(cloud_handler)
else:
# Console output handlers
stdout = logging.StreamHandler(stream=sys.stdout)
stdout.setLevel(config.level)
stdout.addFilter(BelowLevelFilter(logging.WARNING))
if config.level == logging.DEBUG:
stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
else:
stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))

stderr = logging.StreamHandler()
stderr.setLevel(logging.WARNING)
if config.level == logging.DEBUG:
stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
else:
stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))

log_handlers += [stdout, stderr]

# File logging setup
if config.enable_file_logging:
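For context on the handler wiring in the hunk above: records below WARNING go to stdout and WARNING and above go to stderr. A minimal, self-contained sketch of that split using only the standard library (the `BelowLevelFilter` below is written for this example and is only a stand-in for the project's class of the same name):

```python
import logging
import sys


class BelowLevelFilter(logging.Filter):
    """Only pass records strictly below the given level."""

    def __init__(self, below_level: int):
        super().__init__()
        self.below_level = below_level

    def filter(self, record: logging.LogRecord) -> bool:
        return record.levelno < self.below_level


def configure_console_logging(level: int = logging.INFO) -> None:
    # INFO and DEBUG go to stdout ...
    stdout = logging.StreamHandler(stream=sys.stdout)
    stdout.setLevel(level)
    stdout.addFilter(BelowLevelFilter(logging.WARNING))

    # ... WARNING and above go to stderr.
    stderr = logging.StreamHandler(stream=sys.stderr)
    stderr.setLevel(logging.WARNING)

    logging.basicConfig(level=level, handlers=[stdout, stderr], force=True)


if __name__ == "__main__":
    configure_console_logging()
    logging.getLogger(__name__).info("goes to stdout")
    logging.getLogger(__name__).warning("goes to stderr")
```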
@@ -1,6 +1,7 @@
import logging

from colorama import Fore, Style
from google.cloud.logging_v2.handlers import CloudLoggingFilter, StructuredLogHandler

from .utils import remove_color_codes

@@ -79,3 +80,16 @@ class AGPTFormatter(FancyConsoleFormatter):
return remove_color_codes(super().format(record))
else:
return super().format(record)


class StructuredLoggingFormatter(StructuredLogHandler, logging.Formatter):
def __init__(self):
# Set up CloudLoggingFilter to add diagnostic info to the log records
self.cloud_logging_filter = CloudLoggingFilter()

# Init StructuredLogHandler
super().__init__()

def format(self, record: logging.LogRecord) -> str:
self.cloud_logging_filter.filter(record)
return super().format(record)
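One plausible way to exercise the `StructuredLoggingFormatter` added above is to attach it to an ordinary stream handler so each record is emitted as a structured payload. This is a hedged sketch only: the import path is assumed, and it relies on the `google-cloud-logging` package the class builds on.

```python
import logging
import sys

# Assumed import path for the formatter defined in the diff above.
from autogpt_libs.logging.formatters import StructuredLoggingFormatter

handler = logging.StreamHandler(stream=sys.stdout)
# CloudLoggingFilter enriches each record before StructuredLogHandler formats it.
handler.setFormatter(StructuredLoggingFormatter())

logger = logging.getLogger("structured_demo")
logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.info("hello from structured logging")
```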
@@ -2,7 +2,6 @@ import logging
import re
from typing import Any

import uvicorn.config
from colorama import Fore


@@ -26,14 +25,3 @@ def print_attribute(
"color": value_color,
},
)


def generate_uvicorn_config():
"""
Generates a uvicorn logging config that silences uvicorn's default logging and tells it to use the native logging module.
"""
log_config = dict(uvicorn.config.LOGGING_CONFIG)
log_config["loggers"]["uvicorn"] = {"handlers": []}
log_config["loggers"]["uvicorn.error"] = {"handlers": []}
log_config["loggers"]["uvicorn.access"] = {"handlers": []}
return log_config
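One plausible way to apply `generate_uvicorn_config()` (shown above) is to hand it to `uvicorn.run`, so uvicorn's default handlers stay silent and its records propagate to the application's own logging setup. The import path and app reference below are assumptions for illustration, not the project's actual entry point.

```python
import uvicorn

# Assumed import path; point this at wherever generate_uvicorn_config lives.
from autogpt_libs.logging.utils import generate_uvicorn_config

if __name__ == "__main__":
    # "backend.app:app" is a placeholder ASGI application reference.
    uvicorn.run(
        "backend.app:app",
        host="0.0.0.0",
        port=8000,
        log_config=generate_uvicorn_config(),
    )
```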
@@ -1,59 +1,20 @@
import inspect
import threading
from typing import Any, Awaitable, Callable, ParamSpec, TypeVar, cast, overload
from typing import Callable, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


@overload
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...


@overload
def thread_cached(func: Callable[P, R]) -> Callable[P, R]: ...


def thread_cached(
func: Callable[P, R] | Callable[P, Awaitable[R]],
) -> Callable[P, R] | Callable[P, Awaitable[R]]:
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
thread_local = threading.local()

if inspect.iscoroutinefunction(func):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = (args, tuple(sorted(kwargs.items())))
if key not in cache:
cache[key] = func(*args, **kwargs)
return cache[key]

async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = (func, args, tuple(sorted(kwargs.items())))
if key not in cache:
cache[key] = await cast(Callable[P, Awaitable[R]], func)(
*args, **kwargs
)
return cache[key]

return async_wrapper
else:

def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
# Include function in the key to prevent collisions between different functions
key = (func, args, tuple(sorted(kwargs.items())))
if key not in cache:
cache[key] = func(*args, **kwargs)
return cache[key]

return sync_wrapper


def clear_thread_cache(func: Callable[..., Any]) -> None:
"""Clear the cache for a thread-cached function."""
thread_local = threading.local()
cache = getattr(thread_local, "cache", None)
if cache is not None:
# Clear all entries that match the function
for key in list(cache.keys()):
if key and len(key) > 0 and key[0] == func:
del cache[key]
return wrapper
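To make the behaviour of `thread_cached` above concrete, here is a small self-contained sketch of the same per-thread memoization idea (a simplified, sync-only variant written for this example, not the project's exact implementation):

```python
import threading


def thread_cached(func):
    """Memoize per thread: each thread keeps its own result cache."""
    local = threading.local()

    def wrapper(*args, **kwargs):
        cache = getattr(local, "cache", None)
        if cache is None:
            cache = local.cache = {}
        key = (func, args, tuple(sorted(kwargs.items())))
        if key not in cache:
            cache[key] = func(*args, **kwargs)
        return cache[key]

    return wrapper


calls = []


@thread_cached
def expensive(x):
    calls.append(threading.current_thread().name)
    return x * x


def worker():
    expensive(3)  # computed once per thread ...
    expensive(3)  # ... then served from that thread's cache


threads = [threading.Thread(target=worker) for _ in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(len(calls))  # 2: one real call per thread, not per invocation
```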
@@ -8,7 +8,6 @@ DB_CONNECT_TIMEOUT=60
DB_POOL_TIMEOUT=300
DB_SCHEMA=platform
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
PRISMA_SCHEMA="postgres/schema.prisma"

# EXECUTOR
@@ -73,6 +73,7 @@ FROM server_dependencies AS server
COPY autogpt_platform/backend /app/autogpt_platform/backend
RUN poetry install --no-ansi --only-root

ENV DATABASE_URL=""
ENV PORT=8000

CMD ["poetry", "run", "rest"]
@@ -1,6 +1,8 @@
import logging
from typing import Any

from autogpt_libs.utils.cache import thread_cached

from backend.data.block import (
Block,
BlockCategory,
@@ -17,6 +19,21 @@ from backend.util import json
logger = logging.getLogger(__name__)


@thread_cached
def get_executor_manager_client():
from backend.executor import ExecutionManager
from backend.util.service import get_service_client

return get_service_client(ExecutionManager)


@thread_cached
def get_event_bus():
from backend.data.execution import RedisExecutionEventBus

return RedisExecutionEventBus()


class AgentExecutorBlock(Block):
class Input(BlockSchema):
user_id: str = SchemaField(description="User ID")
@@ -59,11 +76,11 @@ class AgentExecutorBlock(Block):

def run(self, input_data: Input, **kwargs) -> BlockOutput:
from backend.data.execution import ExecutionEventType
from backend.executor import utils as execution_utils

event_bus = execution_utils.get_execution_event_bus()
executor_manager = get_executor_manager_client()
event_bus = get_event_bus()

graph_exec = execution_utils.add_graph_execution(
graph_exec = executor_manager.add_execution(
graph_id=input_data.graph_id,
graph_version=input_data.graph_version,
user_id=input_data.user_id,
@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel

from backend.data.model import SchemaField

@@ -143,12 +143,11 @@ class ContactEmail(BaseModel):
class EmploymentHistory(BaseModel):
"""An employment history in Apollo"""

model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
from_attributes=True,
populate_by_name=True,
)
class Config:
extra = "allow"
arbitrary_types_allowed = True
from_attributes = True
populate_by_name = True

_id: Optional[str] = None
created_at: Optional[str] = None
@@ -189,12 +188,11 @@ class TypedCustomField(BaseModel):
class Pagination(BaseModel):
"""Pagination in Apollo"""

model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
from_attributes=True,
populate_by_name=True,
)
class Config:
extra = "allow"  # Allow extra fields
arbitrary_types_allowed = True  # Allow any type
from_attributes = True  # Allow from_orm
populate_by_name = True  # Allow field aliases to work both ways

page: int = 0
per_page: int = 0
@@ -232,12 +230,11 @@ class PhoneNumber(BaseModel):
class Organization(BaseModel):
"""An organization in Apollo"""

model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
from_attributes=True,
populate_by_name=True,
)
class Config:
extra = "allow"
arbitrary_types_allowed = True
from_attributes = True
populate_by_name = True

id: Optional[str] = "N/A"
name: Optional[str] = "N/A"
@@ -271,12 +268,11 @@ class Organization(BaseModel):
class Contact(BaseModel):
"""A contact in Apollo"""

model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
from_attributes=True,
populate_by_name=True,
)
class Config:
extra = "allow"
arbitrary_types_allowed = True
from_attributes = True
populate_by_name = True

contact_roles: list[Any] = []
id: Optional[str] = None
@@ -373,14 +369,14 @@ If a company has several office locations, results are still based on the headqu

To exclude companies based on location, use the organization_not_locations parameter.
""",
default_factory=list,
default=[],
)
organizations_not_locations: list[str] = SchemaField(
description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.

This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
""",
default_factory=list,
default=[],
)
q_organization_keyword_tags: list[str] = SchemaField(
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry."""
@@ -394,7 +390,7 @@ If the value you enter for this parameter does not match with a company's name,
description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.

To find IDs, identify the values for organization_id when you call this endpoint.""",
default_factory=list,
default=[],
)
max_results: int = SchemaField(
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
@@ -447,14 +443,14 @@ Results also include job titles with the same terms, even if they are not exact

Use this parameter in combination with the person_seniorities[] parameter to find people based on specific job functions and seniority levels.
""",
default_factory=list,
default=[],
placeholder="marketing manager",
)
person_locations: list[str] = SchemaField(
description="""The location where people live. You can search across cities, US states, and countries.

To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
default_factory=list,
default=[],
)
person_seniorities: list[SenorityLevels] = SchemaField(
description="""The job seniority that people hold within their current employer. This enables you to find people that currently hold positions at certain reporting levels, such as Director level or senior IC level.
@@ -464,7 +460,7 @@ For a person to be included in search results, they only need to match 1 of the
Searches only return results based on their current job title, so searching for Director-level employees only returns people that currently hold a Director-level title. If someone was previously a Director, but is currently a VP, they would not be included in your search results.

Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
default_factory=list,
default=[],
)
organization_locations: list[str] = SchemaField(
description="""The location of the company headquarters for a person's current employer. You can search across cities, US states, and countries.
@@ -472,7 +468,7 @@ Use this parameter in combination with the person_titles[] parameter to find peo
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.

To find people based on their personal location, use the person_locations parameter.""",
default_factory=list,
default=[],
)
q_organization_domains: list[str] = SchemaField(
description="""The domain name for the person's employer. This can be the current employer or a previous employer. Do not include www., the @ symbol, or similar.
@@ -480,23 +476,23 @@ To find people based on their personal location, use the person_locations parame
You can add multiple domains to search across companies.

Examples: apollo.io and microsoft.com""",
default_factory=list,
default=[],
)
contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
default_factory=list,
default=[],
)
organization_ids: list[str] = SchemaField(
description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.

To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
default_factory=list,
default=[],
)
organization_num_empoloyees_range: list[int] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.

Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
default_factory=list,
default=[],
)
q_keywords: str = SchemaField(
description="""A string of words over which we want to filter the results""",
@@ -526,12 +522,11 @@ Use the page parameter to search the different pages of data.""",
class SearchPeopleResponse(BaseModel):
"""Response from Apollo's search people API"""

model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
from_attributes=True,
populate_by_name=True,
)
class Config:
extra = "allow"  # Allow extra fields
arbitrary_types_allowed = True  # Allow any type
from_attributes = True  # Allow from_orm
populate_by_name = True  # Allow field aliases to work both ways

breadcrumbs: list[Breadcrumb] = []
partial_results_only: bool = True
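The Apollo model hunks above toggle between two pairs of idioms: Pydantic v2's `model_config = ConfigDict(...)` versus the legacy `class Config:` inner class, and `default_factory=list` versus a shared mutable `default=[]`. A minimal illustrative sketch in plain Pydantic (the model below is made up for this example, not one of the Apollo models):

```python
from pydantic import BaseModel, ConfigDict, Field


class ExampleModel(BaseModel):
    # Pydantic v2 style configuration: accept unknown fields, build from attributes.
    model_config = ConfigDict(
        extra="allow",
        from_attributes=True,
        populate_by_name=True,
    )

    # default_factory builds a fresh list per instance, so instances never
    # share (and accidentally mutate) the same default object.
    tags: list[str] = Field(default_factory=list)


a = ExampleModel()
b = ExampleModel()
a.tags.append("x")
print(b.tags)  # [] - b got its own list
```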
@@ -32,18 +32,18 @@ If a company has several office locations, results are still based on the headqu

To exclude companies based on location, use the organization_not_locations parameter.
""",
default_factory=list,
default=[],
)
organizations_not_locations: list[str] = SchemaField(
description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.

This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
""",
default_factory=list,
default=[],
)
q_organization_keyword_tags: list[str] = SchemaField(
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry.""",
default_factory=list,
default=[],
)
q_organization_name: str = SchemaField(
description="""Filter search results to include a specific company name.
@@ -56,7 +56,7 @@ If the value you enter for this parameter does not match with a company's name,
description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.

To find IDs, identify the values for organization_id when you call this endpoint.""",
default_factory=list,
default=[],
)
max_results: int = SchemaField(
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
@@ -72,7 +72,7 @@ To find IDs, identify the values for organization_id when you call this endpoint
class Output(BlockSchema):
organizations: list[Organization] = SchemaField(
description="List of organizations found",
default_factory=list,
default=[],
)
organization: Organization = SchemaField(
description="Each found organization, one at a time",
@@ -26,14 +26,14 @@ class SearchPeopleBlock(Block):

Use this parameter in combination with the person_seniorities[] parameter to find people based on specific job functions and seniority levels.
""",
default_factory=list,
default=[],
advanced=False,
)
person_locations: list[str] = SchemaField(
description="""The location where people live. You can search across cities, US states, and countries.

To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
default_factory=list,
default=[],
advanced=False,
)
person_seniorities: list[SenorityLevels] = SchemaField(
@@ -44,7 +44,7 @@ class SearchPeopleBlock(Block):
Searches only return results based on their current job title, so searching for Director-level employees only returns people that currently hold a Director-level title. If someone was previously a Director, but is currently a VP, they would not be included in your search results.

Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
default_factory=list,
default=[],
advanced=False,
)
organization_locations: list[str] = SchemaField(
@@ -53,7 +53,7 @@ class SearchPeopleBlock(Block):
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.

To find people based on their personal location, use the person_locations parameter.""",
default_factory=list,
default=[],
advanced=False,
)
q_organization_domains: list[str] = SchemaField(
@@ -62,26 +62,26 @@ class SearchPeopleBlock(Block):
You can add multiple domains to search across companies.

Examples: apollo.io and microsoft.com""",
default_factory=list,
default=[],
advanced=False,
)
contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
default_factory=list,
default=[],
advanced=False,
)
organization_ids: list[str] = SchemaField(
description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.

To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
default_factory=list,
default=[],
advanced=False,
)
organization_num_empoloyees_range: list[int] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.

Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
default_factory=list,
default=[],
advanced=False,
)
q_keywords: str = SchemaField(
@@ -104,7 +104,7 @@ class SearchPeopleBlock(Block):
class Output(BlockSchema):
people: list[Contact] = SchemaField(
description="List of people found",
default_factory=list,
default=[],
)
person: Contact = SchemaField(
description="Each found person, one at a time",
@@ -151,7 +151,7 @@ class FindInDictionaryBlock(Block):
class AddToDictionaryBlock(Block):
class Input(BlockSchema):
dictionary: dict[Any, Any] = SchemaField(
default_factory=dict,
default={},
description="The dictionary to add the entry to. If not provided, a new dictionary will be created.",
)
key: str = SchemaField(
@@ -167,7 +167,7 @@ class AddToDictionaryBlock(Block):
advanced=False,
)
entries: dict[Any, Any] = SchemaField(
default_factory=dict,
default={},
description="The entries to add to the dictionary. This is the batch version of the `key` and `value` fields.",
advanced=True,
)
@@ -229,7 +229,7 @@ class AddToDictionaryBlock(Block):
class AddToListBlock(Block):
class Input(BlockSchema):
list: List[Any] = SchemaField(
default_factory=list,
default=[],
advanced=False,
description="The list to add the entry to. If not provided, a new list will be created.",
)
@@ -239,7 +239,7 @@ class AddToListBlock(Block):
default=None,
)
entries: List[Any] = SchemaField(
default_factory=lambda: list(),
default=[],
description="The entries to add to the list. This is the batch version of the `entry` field.",
advanced=True,
)
@@ -55,7 +55,7 @@ class CodeExecutionBlock(Block):
"These commands are executed with `sh`, in the foreground."
),
placeholder="pip install cowsay",
default_factory=list,
default=[],
advanced=False,
)

@@ -207,7 +207,7 @@ class InstantiationBlock(Block):
"These commands are executed with `sh`, in the foreground."
),
placeholder="pip install cowsay",
default_factory=list,
default=[],
advanced=False,
)
@@ -34,7 +34,7 @@ class ReadCsvBlock(Block):
)
skip_columns: list[str] = SchemaField(
description="The columns to skip from the start of the row",
default_factory=list,
default=[],
)

class Output(BlockSchema):
@@ -49,7 +49,7 @@ class ExaContentsBlock(Block):
class Output(BlockSchema):
results: list = SchemaField(
description="List of document contents",
default_factory=list,
default=[],
)
error: str = SchemaField(description="Error message if the request failed")
@@ -38,11 +38,11 @@ class ExaSearchBlock(Block):
)
include_domains: List[str] = SchemaField(
description="Domains to include in search",
default_factory=list,
default=[],
)
exclude_domains: List[str] = SchemaField(
description="Domains to exclude from search",
default_factory=list,
default=[],
advanced=True,
)
start_crawl_date: datetime = SchemaField(
@@ -59,12 +59,12 @@ class ExaSearchBlock(Block):
)
include_text: List[str] = SchemaField(
description="Text patterns to include",
default_factory=list,
default=[],
advanced=True,
)
exclude_text: List[str] = SchemaField(
description="Text patterns to exclude",
default_factory=list,
default=[],
advanced=True,
)
contents: ContentSettings = SchemaField(
@@ -76,7 +76,7 @@ class ExaSearchBlock(Block):
class Output(BlockSchema):
results: list = SchemaField(
description="List of search results",
default_factory=list,
default=[],
)

def __init__(self):
@@ -26,12 +26,12 @@ class ExaFindSimilarBlock(Block):
)
include_domains: List[str] = SchemaField(
description="Domains to include in search",
default_factory=list,
default=[],
advanced=True,
)
exclude_domains: List[str] = SchemaField(
description="Domains to exclude from search",
default_factory=list,
default=[],
advanced=True,
)
start_crawl_date: datetime = SchemaField(
@@ -48,12 +48,12 @@ class ExaFindSimilarBlock(Block):
)
include_text: List[str] = SchemaField(
description="Text patterns to include (max 1 string, up to 5 words)",
default_factory=list,
default=[],
advanced=True,
)
exclude_text: List[str] = SchemaField(
description="Text patterns to exclude (max 1 string, up to 5 words)",
default_factory=list,
default=[],
advanced=True,
)
contents: ContentSettings = SchemaField(
@@ -65,7 +65,7 @@ class ExaFindSimilarBlock(Block):
class Output(BlockSchema):
results: List[Any] = SchemaField(
description="List of similar documents with title, URL, published date, author, and score",
default_factory=list,
default=[],
)

def __init__(self):
@@ -42,7 +42,7 @@ class AIVideoGeneratorBlock(Block):
description="Error message if video generation failed."
)
logs: list[str] = SchemaField(
description="Generation progress logs.",
description="Generation progress logs.", optional=True
)

def __init__(self):
@@ -1,51 +0,0 @@
from backend.data.block import (
Block,
BlockCategory,
BlockManualWebhookConfig,
BlockOutput,
BlockSchema,
)
from backend.data.model import SchemaField
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.generic import GenericWebhookType


class GenericWebhookTriggerBlock(Block):
class Input(BlockSchema):
payload: dict = SchemaField(hidden=True, default_factory=dict)
constants: dict = SchemaField(
description="The constants to be set when the block is put on the graph",
default_factory=dict,
)

class Output(BlockSchema):
payload: dict = SchemaField(
description="The complete webhook payload that was received from the generic webhook."
)
constants: dict = SchemaField(
description="The constants to be set when the block is put on the graph"
)

example_payload = {"message": "Hello, World!"}

def __init__(self):
super().__init__(
id="8fa8c167-2002-47ce-aba8-97572fc5d387",
description="This block will output the contents of the generic input for the webhook.",
categories={BlockCategory.INPUT},
input_schema=GenericWebhookTriggerBlock.Input,
output_schema=GenericWebhookTriggerBlock.Output,
webhook_config=BlockManualWebhookConfig(
provider=ProviderName.GENERIC_WEBHOOK,
webhook_type=GenericWebhookType.PLAIN,
),
test_input={"constants": {"key": "value"}, "payload": self.example_payload},
test_output=[
("constants", {"key": "value"}),
("payload", self.example_payload),
],
)

def run(self, input_data: Input, **kwargs) -> BlockOutput:
yield "constants", input_data.constants
yield "payload", input_data.payload
@@ -37,7 +37,7 @@ class GitHubTriggerBase:
placeholder="{owner}/{repo}",
)
# --8<-- [start:example-payload-field]
payload: dict = SchemaField(hidden=True, default_factory=dict)
payload: dict = SchemaField(hidden=True, default={})
# --8<-- [end:example-payload-field]

class Output(BlockSchema):
@@ -34,7 +34,7 @@ class SendWebRequestBlock(Block):
)
headers: dict[str, str] = SchemaField(
description="The headers to include in the request",
default_factory=dict,
default={},
)
json_format: bool = SchemaField(
title="JSON format",
@@ -15,8 +15,7 @@ class HubSpotCompanyBlock(Block):
description="Operation to perform (create, update, get)", default="get"
)
company_data: dict = SchemaField(
description="Company data for create/update operations",
default_factory=dict,
description="Company data for create/update operations", default={}
)
domain: str = SchemaField(
description="Company domain for get/update operations", default=""
@@ -15,8 +15,7 @@ class HubSpotContactBlock(Block):
description="Operation to perform (create, update, get)", default="get"
)
contact_data: dict = SchemaField(
description="Contact data for create/update operations",
default_factory=dict,
description="Contact data for create/update operations", default={}
)
email: str = SchemaField(
description="Email address for get/update operations", default=""
@@ -19,7 +19,7 @@ class HubSpotEngagementBlock(Block):
)
email_data: dict = SchemaField(
description="Email data including recipient, subject, content",
default_factory=dict,
default={},
)
contact_id: str = SchemaField(
description="Contact ID for engagement tracking", default=""
@@ -27,6 +27,7 @@ class HubSpotEngagementBlock(Block):
timeframe_days: int = SchemaField(
description="Number of days to look back for engagement",
default=30,
optional=True,
)

class Output(BlockSchema):
@@ -1,4 +1,3 @@
import copy
from datetime import date, time
from typing import Any, Optional

@@ -39,7 +38,7 @@ class AgentInputBlock(Block):
)
placeholder_values: list = SchemaField(
description="The placeholder values to be passed as input.",
default_factory=list,
default=[],
advanced=True,
hidden=True,
)
@@ -55,7 +54,7 @@ class AgentInputBlock(Block):
)

def generate_schema(self):
schema = copy.deepcopy(self.get_field_schema("value"))
schema = self.get_field_schema("value")
if possible_values := self.placeholder_values:
schema["enum"] = possible_values
return schema
@@ -468,7 +467,7 @@ class AgentDropdownInputBlock(AgentInputBlock):
)
placeholder_values: list = SchemaField(
description="Possible values for the dropdown.",
default_factory=list,
default=[],
advanced=False,
title="Dropdown Options",
)
@@ -11,13 +11,13 @@ class StepThroughItemsBlock(Block):
advanced=False,
description="The list or dictionary of items to iterate over",
placeholder="[1, 2, 3, 4, 5] or {'key1': 'value1', 'key2': 'value2'}",
default_factory=list,
default=[],
)
items_object: dict = SchemaField(
advanced=False,
description="The list or dictionary of items to iterate over",
placeholder="[1, 2, 3, 4, 5] or {'key1': 'value1', 'key2': 'value2'}",
default_factory=dict,
default={},
)
items_str: str = SchemaField(
advanced=False,
@@ -23,7 +23,7 @@ class JinaChunkingBlock(Block):
class Output(BlockSchema):
chunks: list = SchemaField(description="List of chunked texts")
tokens: list = SchemaField(
description="List of token information for each chunk",
description="List of token information for each chunk", optional=True
)

def __init__(self):
@@ -1,4 +1,4 @@
from urllib.parse import quote
from groq._utils._utils import quote

from backend.blocks.jina._auth import (
TEST_CREDENTIALS,
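The hunk above swaps a `quote` re-exported from groq's private `_utils` module for the standard-library `urllib.parse.quote`. A quick illustration of the standard-library function:

```python
from urllib.parse import quote

# Percent-encode a value for safe inclusion in a URL path or query string.
print(quote("hello world/jina", safe=""))  # hello%20world%2Fjina
```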
@@ -28,8 +28,8 @@ class LinearCreateIssueBlock(Block):
priority: int | None = SchemaField(
description="Priority of the issue",
default=None,
ge=0,
le=4,
minimum=0,
maximum=4,
)
project_name: str | None = SchemaField(
description="Name of the project to create the issue on",
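The Linear hunk above switches the priority bounds between `minimum=`/`maximum=` and `ge=`/`le=`. In plain Pydantic, `ge`/`le` are the constraint keywords that are actually validated; the sketch below assumes the project's `SchemaField` wrapper forwards them to `Field`:

```python
from pydantic import BaseModel, Field, ValidationError


class Issue(BaseModel):
    # 0 <= priority <= 4, or unset.
    priority: int | None = Field(default=None, ge=0, le=4)


print(Issue(priority=2).priority)  # 2
try:
    Issue(priority=9)
except ValidationError as exc:
    print("rejected:", exc.errors()[0]["type"])  # the less-than-or-equal check fires
```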
@@ -4,25 +4,30 @@ from abc import ABC
from enum import Enum, EnumMeta
from json import JSONDecodeError
from types import MappingProxyType
from typing import Any, Iterable, List, Literal, NamedTuple, Optional
from typing import TYPE_CHECKING, Any, Iterable, List, Literal, NamedTuple, Optional

from pydantic import BaseModel, SecretStr

from backend.data.model import NodeExecutionStats
from backend.integrations.providers import ProviderName

if TYPE_CHECKING:
from enum import _EnumMemberT

import anthropic
import ollama
import openai
from anthropic import NotGiven
from anthropic._types import NotGiven
from anthropic.types import ToolParam
from groq import Groq
from pydantic import BaseModel, SecretStr

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.settings import BehaveAs, Settings
from backend.util.text import TextFormatter
@@ -72,10 +77,12 @@ class ModelMetadata(NamedTuple):

class LlmModelMeta(EnumMeta):
@property
def __members__(self) -> MappingProxyType:
def __members__(
self: type["_EnumMemberT"],
) -> MappingProxyType[str, "_EnumMemberT"]:
if Settings().config.behave_as == BehaveAs.LOCAL:
members = super().__members__
return MappingProxyType(members)
return members
else:
removed_providers = ["ollama"]
existing_members = super().__members__
@@ -135,8 +142,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"

@property
def metadata(self) -> ModelMetadata:
@@ -218,8 +223,6 @@ MODEL_METADATA = {
LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata("open_router", 300000, 5120),
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata("open_router", 65536, 4096),
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata("open_router", 4096, 4096),
LlmModel.META_LLAMA_4_SCOUT: ModelMetadata("open_router", 131072, 131072),
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata("open_router", 1048576, 1000000),
}

for model in LlmModel:
@@ -421,7 +424,7 @@ def llm_call(
response=(
resp.content[0].name
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
else getattr(resp.content[0], "text", "")
else resp.content[0].text
),
tool_calls=tool_calls,
prompt_tokens=resp.usage.input_tokens,
@@ -525,7 +528,7 @@ def llm_call(
class AIBlockBase(Block, ABC):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.prompt = []
self.prompt = ""

def merge_llm_stats(self, block: "AIBlockBase"):
self.merge_stats(block.execution_stats)
@@ -555,7 +558,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
description="The system prompt to provide additional context to the model.",
)
conversation_history: list[dict] = SchemaField(
default_factory=list,
default=[],
description="The conversation history to provide context for the prompt.",
)
retry: int = SchemaField(
@@ -565,7 +568,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
)
prompt_values: dict[str, str] = SchemaField(
advanced=False,
default_factory=dict,
default={},
description="Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}.",
)
max_tokens: int | None = SchemaField(
@@ -584,7 +587,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
response: dict[str, Any] = SchemaField(
description="The response object generated by the language model."
)
prompt: list = SchemaField(description="The prompt sent to the language model.")
prompt: str = SchemaField(description="The prompt sent to the language model.")
error: str = SchemaField(description="Error message if the API call failed.")

def __init__(self):
@@ -606,7 +609,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
test_credentials=TEST_CREDENTIALS,
test_output=[
("response", {"key1": "key1Value", "key2": "key2Value"}),
("prompt", list),
("prompt", str),
],
test_mock={
"llm_call": lambda *args, **kwargs: LLMResponse(
@@ -639,7 +642,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
Test mocks work only on class functions, this wraps the llm_call function
so that it can be mocked withing the block testing framework.
"""
self.prompt = prompt
return llm_call(
credentials=credentials,
llm_model=llm_model,
@@ -794,7 +796,7 @@ class AITextGeneratorBlock(AIBlockBase):
)
prompt_values: dict[str, str] = SchemaField(
advanced=False,
default_factory=dict,
default={},
description="Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}.",
)
ollama_host: str = SchemaField(
@@ -812,7 +814,7 @@ class AITextGeneratorBlock(AIBlockBase):
response: str = SchemaField(
description="The response generated by the language model."
)
prompt: list = SchemaField(description="The prompt sent to the language model.")
prompt: str = SchemaField(description="The prompt sent to the language model.")
error: str = SchemaField(description="Error message if the API call failed.")

def __init__(self):
@@ -829,7 +831,7 @@ class AITextGeneratorBlock(AIBlockBase):
test_credentials=TEST_CREDENTIALS,
test_output=[
("response", "Response text"),
("prompt", list),
("prompt", str),
],
test_mock={"llm_call": lambda *args, **kwargs: "Response text"},
)
@@ -848,10 +850,7 @@ class AITextGeneratorBlock(AIBlockBase):
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
object_input_data = AIStructuredResponseGeneratorBlock.Input(
**{
attr: getattr(input_data, attr)
for attr in AITextGeneratorBlock.Input.model_fields
},
**{attr: getattr(input_data, attr) for attr in input_data.model_fields},
expected_format={},
)
yield "response", self.llm_call(object_input_data, credentials)
@@ -908,7 +907,7 @@ class AITextSummarizerBlock(AIBlockBase):
|
||||
|
||||
class Output(BlockSchema):
|
||||
summary: str = SchemaField(description="The final summary of the text.")
|
||||
prompt: list = SchemaField(description="The prompt sent to the language model.")
|
||||
prompt: str = SchemaField(description="The prompt sent to the language model.")
|
||||
error: str = SchemaField(description="Error message if the API call failed.")
|
||||
|
||||
def __init__(self):
|
||||
@@ -925,7 +924,7 @@ class AITextSummarizerBlock(AIBlockBase):
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("summary", "Final summary of a long text"),
|
||||
("prompt", list),
|
||||
("prompt", str),
|
||||
],
|
||||
test_mock={
|
||||
"llm_call": lambda input_data, credentials: (
|
||||
@@ -1034,14 +1033,8 @@ class AITextSummarizerBlock(AIBlockBase):
|
||||
|
||||
class AIConversationBlock(AIBlockBase):
|
||||
class Input(BlockSchema):
|
||||
prompt: str = SchemaField(
|
||||
description="The prompt to send to the language model.",
|
||||
placeholder="Enter your prompt here...",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
messages: List[Any] = SchemaField(
|
||||
description="List of messages in the conversation.",
|
||||
description="List of messages in the conversation.", min_length=1
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
@@ -1064,7 +1057,7 @@ class AIConversationBlock(AIBlockBase):
|
||||
response: str = SchemaField(
|
||||
description="The model's response to the conversation."
|
||||
)
|
||||
prompt: list = SchemaField(description="The prompt sent to the language model.")
|
||||
prompt: str = SchemaField(description="The prompt sent to the language model.")
|
||||
error: str = SchemaField(description="Error message if the API call failed.")
|
||||
|
||||
def __init__(self):
|
||||
@@ -1093,7 +1086,7 @@ class AIConversationBlock(AIBlockBase):
|
||||
"response",
|
||||
"The 2020 World Series was played at Globe Life Field in Arlington, Texas.",
|
||||
),
|
||||
("prompt", list),
|
||||
("prompt", str),
|
||||
],
|
||||
test_mock={
|
||||
"llm_call": lambda *args, **kwargs: "The 2020 World Series was played at Globe Life Field in Arlington, Texas."
|
||||
@@ -1115,7 +1108,7 @@ class AIConversationBlock(AIBlockBase):
|
||||
) -> BlockOutput:
|
||||
response = self.llm_call(
|
||||
AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt=input_data.prompt,
|
||||
prompt="",
|
||||
credentials=input_data.credentials,
|
||||
model=input_data.model,
|
||||
conversation_history=input_data.messages,
|
||||
@@ -1173,7 +1166,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
list_item: str = SchemaField(
|
||||
description="Each individual item in the list.",
|
||||
)
|
||||
prompt: list = SchemaField(description="The prompt sent to the language model.")
|
||||
prompt: str = SchemaField(description="The prompt sent to the language model.")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the list generation failed."
|
||||
)
|
||||
@@ -1205,7 +1198,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
"generated_list",
|
||||
["Zylora Prime", "Kharon-9", "Vortexia", "Oceara", "Draknos"],
|
||||
),
|
||||
("prompt", list),
|
||||
("prompt", str),
|
||||
("list_item", "Zylora Prime"),
|
||||
("list_item", "Kharon-9"),
|
||||
("list_item", "Vortexia"),
|
||||
|
||||
@@ -65,7 +65,7 @@ class AddMemoryBlock(Block, Mem0Base):
|
||||
default=Content(discriminator="content", content="I'm a vegetarian"),
|
||||
)
|
||||
metadata: dict[str, Any] = SchemaField(
|
||||
description="Optional metadata for the memory", default_factory=dict
|
||||
description="Optional metadata for the memory", default={}
|
||||
)
|
||||
|
||||
limit_memory_to_run: bool = SchemaField(
|
||||
@@ -173,7 +173,7 @@ class SearchMemoryBlock(Block, Mem0Base):
|
||||
)
|
||||
categories_filter: list[str] = SchemaField(
|
||||
description="Categories to filter by",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
limit_memory_to_run: bool = SchemaField(
|
||||
|
||||
@@ -177,8 +177,7 @@ class PineconeInsertBlock(Block):
|
||||
description="Namespace to use in Pinecone", default=""
|
||||
)
|
||||
metadata: dict = SchemaField(
|
||||
description="Additional metadata to store with each vector",
|
||||
default_factory=dict,
|
||||
description="Additional metadata to store with each vector", default={}
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
|
||||
@@ -26,7 +26,7 @@ class Slant3DTriggerBase:
|
||||
class Input(BlockSchema):
|
||||
credentials: Slant3DCredentialsInput = Slant3DCredentialsField()
|
||||
# Webhook URL is handled by the webhook system
|
||||
payload: dict = SchemaField(hidden=True, default_factory=dict)
|
||||
payload: dict = SchemaField(hidden=True, default={})
|
||||
|
||||
class Output(BlockSchema):
|
||||
payload: dict = SchemaField(
|
||||
|
||||
@@ -14,6 +14,7 @@ from backend.data.block import (
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util import json
|
||||
@@ -154,7 +155,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
description="The system prompt to provide additional context to the model.",
|
||||
)
|
||||
conversation_history: list[dict] = SchemaField(
|
||||
default_factory=list,
|
||||
default=[],
|
||||
description="The conversation history to provide context for the prompt.",
|
||||
)
|
||||
last_tool_output: Any = SchemaField(
|
||||
@@ -168,7 +169,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
)
|
||||
prompt_values: dict[str, str] = SchemaField(
|
||||
advanced=False,
|
||||
default_factory=dict,
|
||||
default={},
|
||||
description="Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}.",
|
||||
)
|
||||
max_tokens: int | None = SchemaField(
|
||||
@@ -263,7 +264,9 @@ class SmartDecisionMakerBlock(Block):
|
||||
Raises:
|
||||
ValueError: If the block specified by sink_node.block_id is not found.
|
||||
"""
|
||||
block = sink_node.block
|
||||
block = get_block(sink_node.block_id)
|
||||
if not block:
|
||||
raise ValueError(f"Block not found: {sink_node.block_id}")
|
||||
|
||||
tool_function: dict[str, Any] = {
|
||||
"name": re.sub(r"[^a-zA-Z0-9_-]", "_", block.name).lower(),
|
||||
|
||||
@@ -112,7 +112,7 @@ class AddLeadToCampaignBlock(Block):
|
||||
lead_list: list[LeadInput] = SchemaField(
|
||||
description="An array of JSON objects, each representing a lead's details. Can hold max 100 leads.",
|
||||
max_length=100,
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
settings: LeadUploadSettings = SchemaField(
|
||||
@@ -248,7 +248,7 @@ class SaveCampaignSequencesBlock(Block):
|
||||
)
|
||||
sequences: list[Sequence] = SchemaField(
|
||||
description="The sequences to save",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
credentials: SmartLeadCredentialsInput = SchemaField(
|
||||
|
||||
@@ -39,7 +39,7 @@ class LeadCustomFields(BaseModel):
|
||||
fields: dict[str, str] = SchemaField(
|
||||
description="Custom fields for a lead (max 20 fields)",
|
||||
max_length=20,
|
||||
default_factory=dict,
|
||||
default={},
|
||||
)
|
||||
|
||||
|
||||
@@ -85,7 +85,7 @@ class AddLeadsRequest(BaseModel):
|
||||
lead_list: list[LeadInput] = SchemaField(
|
||||
description="List of leads to add to the campaign",
|
||||
max_length=100,
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
settings: LeadUploadSettings
|
||||
campaign_id: int
|
||||
|
||||
@@ -156,7 +156,7 @@
|
||||
# participant_ids: list[str] = SchemaField(
|
||||
# description="Array of User IDs to create conversation with (max 50)",
|
||||
# placeholder="Enter participant user IDs",
|
||||
# default_factory=list,
|
||||
# default=[],
|
||||
# advanced=False
|
||||
# )
|
||||
|
||||
|
||||
@@ -39,6 +39,7 @@ class TwitterGetListBlock(Block):
|
||||
list_id: str = SchemaField(
|
||||
description="The ID of the List to lookup",
|
||||
placeholder="Enter list ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -183,6 +184,7 @@ class TwitterGetOwnedListsBlock(Block):
|
||||
user_id: str = SchemaField(
|
||||
description="The user ID whose owned Lists to retrieve",
|
||||
placeholder="Enter user ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
max_results: int | None = SchemaField(
|
||||
|
||||
@@ -45,11 +45,13 @@ class TwitterRemoveListMemberBlock(Block):
|
||||
list_id: str = SchemaField(
|
||||
description="The ID of the List to remove the member from",
|
||||
placeholder="Enter list ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
user_id: str = SchemaField(
|
||||
description="The ID of the user to remove from the List",
|
||||
placeholder="Enter user ID to remove",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -118,11 +120,13 @@ class TwitterAddListMemberBlock(Block):
|
||||
list_id: str = SchemaField(
|
||||
description="The ID of the List to add the member to",
|
||||
placeholder="Enter list ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
user_id: str = SchemaField(
|
||||
description="The ID of the user to add to the List",
|
||||
placeholder="Enter user ID to add",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -191,6 +195,7 @@ class TwitterGetListMembersBlock(Block):
|
||||
list_id: str = SchemaField(
|
||||
description="The ID of the List to get members from",
|
||||
placeholder="Enter list ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
max_results: int | None = SchemaField(
|
||||
@@ -371,6 +376,7 @@ class TwitterGetListMembershipsBlock(Block):
|
||||
user_id: str = SchemaField(
|
||||
description="The ID of the user whose List memberships to retrieve",
|
||||
placeholder="Enter user ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
max_results: int | None = SchemaField(
|
||||
|
||||
@@ -42,6 +42,7 @@ class TwitterGetListTweetsBlock(Block):
|
||||
list_id: str = SchemaField(
|
||||
description="The ID of the List whose Tweets you would like to retrieve",
|
||||
placeholder="Enter list ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
max_results: int | None = SchemaField(
|
||||
|
||||
@@ -28,6 +28,7 @@ class TwitterDeleteListBlock(Block):
|
||||
list_id: str = SchemaField(
|
||||
description="The ID of the List to be deleted",
|
||||
placeholder="Enter list ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
|
||||
@@ -39,6 +39,7 @@ class TwitterUnpinListBlock(Block):
|
||||
list_id: str = SchemaField(
|
||||
description="The ID of the List to unpin",
|
||||
placeholder="Enter list ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -102,6 +103,7 @@ class TwitterPinListBlock(Block):
|
||||
list_id: str = SchemaField(
|
||||
description="The ID of the List to pin",
|
||||
placeholder="Enter list ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
|
||||
@@ -44,7 +44,7 @@ class SpaceList(BaseModel):
|
||||
space_ids: list[str] = SchemaField(
|
||||
description="List of Space IDs to lookup (up to 100)",
|
||||
placeholder="Enter Space IDs",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
@@ -54,7 +54,7 @@ class UserList(BaseModel):
|
||||
user_ids: list[str] = SchemaField(
|
||||
description="List of user IDs to lookup their Spaces (up to 100)",
|
||||
placeholder="Enter user IDs",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
@@ -227,6 +227,7 @@ class TwitterGetSpaceByIdBlock(Block):
|
||||
space_id: str = SchemaField(
|
||||
description="Space ID to lookup",
|
||||
placeholder="Enter Space ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -388,6 +389,7 @@ class TwitterGetSpaceBuyersBlock(Block):
|
||||
space_id: str = SchemaField(
|
||||
description="Space ID to lookup buyers for",
|
||||
placeholder="Enter Space ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -515,6 +517,7 @@ class TwitterGetSpaceTweetsBlock(Block):
|
||||
space_id: str = SchemaField(
|
||||
description="Space ID to lookup tweets for",
|
||||
placeholder="Enter Space ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
|
||||
@@ -200,7 +200,7 @@ class UserIdList(BaseModel):
|
||||
user_ids: list[str] = SchemaField(
|
||||
description="List of user IDs to lookup (max 100)",
|
||||
placeholder="Enter user IDs",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
@@ -210,7 +210,7 @@ class UsernameList(BaseModel):
|
||||
usernames: list[str] = SchemaField(
|
||||
description="List of Twitter usernames/handles to lookup (max 100)",
|
||||
placeholder="Enter usernames",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import pathlib
|
||||
import click
|
||||
import psutil
|
||||
|
||||
from backend import app
|
||||
from backend.util.process import AppProcess
|
||||
|
||||
|
||||
@@ -41,13 +42,8 @@ def write_pid(pid: int):
|
||||
|
||||
class MainApp(AppProcess):
|
||||
def run(self):
|
||||
from backend import app
|
||||
|
||||
app.main(silent=True)
|
||||
|
||||
def cleanup(self):
|
||||
pass
|
||||
|
||||
|
||||
@click.group()
|
||||
def main():
|
||||
|
||||
@@ -12,12 +12,12 @@ async def log_raw_analytics(
|
||||
data_index: str,
|
||||
):
|
||||
details = await prisma.models.AnalyticsDetails.prisma().create(
|
||||
data=prisma.types.AnalyticsDetailsCreateInput(
|
||||
userId=user_id,
|
||||
type=type,
|
||||
data=prisma.Json(data),
|
||||
dataIndex=data_index,
|
||||
)
|
||||
data={
|
||||
"userId": user_id,
|
||||
"type": type,
|
||||
"data": prisma.Json(data),
|
||||
"dataIndex": data_index,
|
||||
}
|
||||
)
|
||||
return details
|
||||
|
||||
@@ -32,12 +32,12 @@ async def log_raw_metric(
|
||||
raise ValueError("metric_value must be non-negative")
|
||||
|
||||
result = await prisma.models.AnalyticsMetrics.prisma().create(
|
||||
data=prisma.types.AnalyticsMetricsCreateInput(
|
||||
value=metric_value,
|
||||
analyticMetric=metric_name,
|
||||
userId=user_id,
|
||||
dataString=data_string,
|
||||
)
|
||||
data={
|
||||
"value": metric_value,
|
||||
"analyticMetric": metric_name,
|
||||
"userId": user_id,
|
||||
"dataString": data_string,
|
||||
},
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -17,7 +17,6 @@ from typing import (
|
||||
import jsonref
|
||||
import jsonschema
|
||||
from prisma.models import AgentBlock
|
||||
from prisma.types import AgentBlockCreateInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
@@ -481,12 +480,12 @@ async def initialize_blocks() -> None:
|
||||
)
|
||||
if not existing_block:
|
||||
await AgentBlock.prisma().create(
|
||||
data=AgentBlockCreateInput(
|
||||
id=block.id,
|
||||
name=block.name,
|
||||
inputSchema=json.dumps(block.input_schema.jsonschema()),
|
||||
outputSchema=json.dumps(block.output_schema.jsonschema()),
|
||||
)
|
||||
data={
|
||||
"id": block.id,
|
||||
"name": block.name,
|
||||
"inputSchema": json.dumps(block.input_schema.jsonschema()),
|
||||
"outputSchema": json.dumps(block.output_schema.jsonschema()),
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
|
||||
@@ -75,8 +75,6 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.AMAZON_NOVA_PRO_V1: 1,
|
||||
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: 1,
|
||||
LlmModel.GRYPHE_MYTHOMAX_L2_13B: 1,
|
||||
LlmModel.META_LLAMA_4_SCOUT: 1,
|
||||
LlmModel.META_LLAMA_4_MAVERICK: 1,
|
||||
}
|
||||
|
||||
for model in LlmModel:
|
||||
|
||||
@@ -11,15 +11,10 @@ from prisma.enums import (
|
||||
CreditRefundRequestStatus,
|
||||
CreditTransactionType,
|
||||
NotificationType,
|
||||
OnboardingStep,
|
||||
)
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditRefundRequest, CreditTransaction, User
|
||||
from prisma.types import (
|
||||
CreditRefundRequestCreateInput,
|
||||
CreditTransactionCreateInput,
|
||||
CreditTransactionWhereInput,
|
||||
)
|
||||
from prisma.types import CreditTransactionCreateInput, CreditTransactionWhereInput
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
from backend.data import db
|
||||
@@ -122,18 +117,6 @@ class UserCreditBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
|
||||
"""
|
||||
Reward the user with credits for completing an onboarding step.
|
||||
Won't reward if the user has already received credits for the step.
|
||||
|
||||
Args:
|
||||
user_id (str): The user ID.
|
||||
step (OnboardingStep): The onboarding step.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def top_up_intent(self, user_id: str, amount: int) -> str:
|
||||
"""
|
||||
@@ -226,7 +209,7 @@ class UserCreditBase(ABC):
|
||||
"userId": user_id,
|
||||
"createdAt": {"lte": top_time},
|
||||
"isActive": True,
|
||||
"NOT": [{"runningBalance": None}],
|
||||
"runningBalance": {"not": None}, # type: ignore
|
||||
},
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
@@ -348,15 +331,15 @@ class UserCreditBase(ABC):
|
||||
amount = min(-user_balance, 0)
|
||||
|
||||
# Create the transaction
|
||||
transaction_data = CreditTransactionCreateInput(
|
||||
userId=user_id,
|
||||
amount=amount,
|
||||
runningBalance=user_balance + amount,
|
||||
type=transaction_type,
|
||||
metadata=metadata,
|
||||
isActive=is_active,
|
||||
createdAt=self.time_now(),
|
||||
)
|
||||
transaction_data: CreditTransactionCreateInput = {
|
||||
"userId": user_id,
|
||||
"amount": amount,
|
||||
"runningBalance": user_balance + amount,
|
||||
"type": transaction_type,
|
||||
"metadata": metadata,
|
||||
"isActive": is_active,
|
||||
"createdAt": self.time_now(),
|
||||
}
|
||||
if transaction_key:
|
||||
transaction_data["transactionKey"] = transaction_key
|
||||
tx = await CreditTransaction.prisma().create(data=transaction_data)
|
||||
@@ -421,24 +404,6 @@ class UserCredit(UserCreditBase):
|
||||
async def top_up_credits(self, user_id: str, amount: int):
|
||||
await self._top_up_credits(user_id, amount)
|
||||
|
||||
async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
|
||||
key = f"REWARD-{user_id}-{step.value}"
|
||||
if not await CreditTransaction.prisma().find_first(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"transactionKey": key,
|
||||
}
|
||||
):
|
||||
await self._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=credits,
|
||||
transaction_type=CreditTransactionType.GRANT,
|
||||
transaction_key=key,
|
||||
metadata=Json(
|
||||
{"reason": f"Reward for completing {step.value} onboarding step."}
|
||||
),
|
||||
)
|
||||
|
||||
async def top_up_refund(
|
||||
self, user_id: str, transaction_key: str, metadata: dict[str, str]
|
||||
) -> int:
|
||||
@@ -457,15 +422,15 @@ class UserCredit(UserCreditBase):
|
||||
|
||||
try:
|
||||
refund_request = await CreditRefundRequest.prisma().create(
|
||||
data=CreditRefundRequestCreateInput(
|
||||
id=refund_key,
|
||||
transactionKey=transaction_key,
|
||||
userId=user_id,
|
||||
amount=amount,
|
||||
reason=metadata.get("reason", ""),
|
||||
status=CreditRefundRequestStatus.PENDING,
|
||||
result="The refund request is under review.",
|
||||
)
|
||||
data={
|
||||
"id": refund_key,
|
||||
"transactionKey": transaction_key,
|
||||
"userId": user_id,
|
||||
"amount": amount,
|
||||
"reason": metadata.get("reason", ""),
|
||||
"status": CreditRefundRequestStatus.PENDING,
|
||||
"result": "The refund request is under review.",
|
||||
}
|
||||
)
|
||||
except UniqueViolationError:
|
||||
raise ValueError(
|
||||
@@ -926,9 +891,6 @@ class DisabledUserCredit(UserCreditBase):
|
||||
async def top_up_credits(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def onboarding_reward(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def top_up_intent(self, *args, **kwargs) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
@@ -62,10 +62,10 @@ async def connect():

# Connection acquired from a pool like Supabase somehow still possibly allows
# the db client obtains a connection but still reject query connection afterward.
# try:
# await prisma.execute_raw("SELECT 1")
# except Exception as e:
# raise ConnectionError("Failed to connect to Prisma.") from e
try:
await prisma.execute_raw("SELECT 1")
except Exception as e:
raise ConnectionError("Failed to connect to Prisma.") from e


@conn_retry("Prisma", "Releasing connection")
@@ -89,7 +89,7 @@ async def transaction():
async def locked_transaction(key: str):
lock_key = zlib.crc32(key.encode("utf-8"))
async with transaction() as tx:
await tx.execute_raw("SELECT pg_advisory_xact_lock($1)", lock_key)
await tx.execute_raw(f"SELECT pg_advisory_xact_lock({lock_key})")
yield tx

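As a side note on the locked_transaction hunk: zlib.crc32 is what turns the free-form string key into the integer that pg_advisory_xact_lock expects. A tiny standalone illustration (the key string below is made up):

    # crc32 deterministically maps a string key to an unsigned 32-bit integer,
    # which fits within the bigint lock id taken by pg_advisory_xact_lock.
    import zlib

    lock_key = zlib.crc32("usr_trx_1234".encode("utf-8"))
    print(lock_key)  # the same input always yields the same lock id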
@@ -23,10 +23,7 @@ from prisma.models import (
|
||||
AgentNodeExecutionInputOutput,
|
||||
)
|
||||
from prisma.types import (
|
||||
AgentGraphExecutionCreateInput,
|
||||
AgentGraphExecutionWhereInput,
|
||||
AgentNodeExecutionCreateInput,
|
||||
AgentNodeExecutionInputOutputCreateInput,
|
||||
AgentNodeExecutionUpdateInput,
|
||||
AgentNodeExecutionWhereInput,
|
||||
)
|
||||
@@ -34,10 +31,11 @@ from pydantic import BaseModel
|
||||
from pydantic.fields import Field
|
||||
|
||||
from backend.server.v2.store.exceptions import DatabaseError
|
||||
from backend.util import mock
|
||||
from backend.util import type as type_utils
|
||||
from backend.util.settings import Config
|
||||
|
||||
from .block import BlockInput, BlockType, CompletedBlockOutput, get_block
|
||||
from .block import BlockData, BlockInput, BlockType, CompletedBlockOutput, get_block
|
||||
from .db import BaseDbModel
|
||||
from .includes import (
|
||||
EXECUTION_RESULT_INCLUDE,
|
||||
@@ -61,27 +59,23 @@ ExecutionStatus = AgentExecutionStatus
|
||||
|
||||
class GraphExecutionMeta(BaseDbModel):
|
||||
user_id: str
|
||||
started_at: datetime
|
||||
ended_at: datetime
|
||||
cost: Optional[int] = Field(..., description="Execution cost in credits")
|
||||
duration: float = Field(..., description="Seconds from start to end of run")
|
||||
total_run_time: float = Field(..., description="Seconds of node runtime")
|
||||
status: ExecutionStatus
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
preset_id: Optional[str] = None
|
||||
status: ExecutionStatus
|
||||
started_at: datetime
|
||||
ended_at: datetime
|
||||
|
||||
class Stats(BaseModel):
|
||||
cost: int = Field(..., description="Execution cost (cents)")
|
||||
duration: float = Field(..., description="Seconds from start to end of run")
|
||||
node_exec_time: float = Field(..., description="Seconds of total node runtime")
|
||||
node_exec_count: int = Field(..., description="Number of node executions")
|
||||
|
||||
stats: Stats | None
|
||||
|
||||
@staticmethod
|
||||
def from_db(_graph_exec: AgentGraphExecution):
|
||||
now = datetime.now(timezone.utc)
|
||||
# TODO: make started_at and ended_at optional
|
||||
start_time = _graph_exec.startedAt or _graph_exec.createdAt
|
||||
end_time = _graph_exec.updatedAt or now
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
total_run_time = duration
|
||||
|
||||
try:
|
||||
stats = GraphExecutionStats.model_validate(_graph_exec.stats)
|
||||
@@ -93,25 +87,21 @@ class GraphExecutionMeta(BaseDbModel):
|
||||
)
|
||||
stats = None
|
||||
|
||||
duration = stats.walltime if stats else duration
|
||||
total_run_time = stats.nodes_walltime if stats else total_run_time
|
||||
|
||||
return GraphExecutionMeta(
|
||||
id=_graph_exec.id,
|
||||
user_id=_graph_exec.userId,
|
||||
started_at=start_time,
|
||||
ended_at=end_time,
|
||||
cost=stats.cost if stats else None,
|
||||
duration=duration,
|
||||
total_run_time=total_run_time,
|
||||
status=ExecutionStatus(_graph_exec.executionStatus),
|
||||
graph_id=_graph_exec.agentGraphId,
|
||||
graph_version=_graph_exec.agentGraphVersion,
|
||||
preset_id=_graph_exec.agentPresetId,
|
||||
status=ExecutionStatus(_graph_exec.executionStatus),
|
||||
started_at=start_time,
|
||||
ended_at=end_time,
|
||||
stats=(
|
||||
GraphExecutionMeta.Stats(
|
||||
cost=stats.cost,
|
||||
duration=stats.walltime,
|
||||
node_exec_time=stats.nodes_walltime,
|
||||
node_exec_count=stats.node_count,
|
||||
)
|
||||
if stats
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -121,16 +111,15 @@ class GraphExecution(GraphExecutionMeta):
|
||||
|
||||
@staticmethod
|
||||
def from_db(_graph_exec: AgentGraphExecution):
|
||||
if _graph_exec.NodeExecutions is None:
|
||||
if _graph_exec.AgentNodeExecutions is None:
|
||||
raise ValueError("Node executions must be included in query")
|
||||
|
||||
graph_exec = GraphExecutionMeta.from_db(_graph_exec)
|
||||
|
||||
complete_node_executions = sorted(
|
||||
node_executions = sorted(
|
||||
[
|
||||
NodeExecutionResult.from_db(ne, _graph_exec.userId)
|
||||
for ne in _graph_exec.NodeExecutions
|
||||
if ne.executionStatus != ExecutionStatus.INCOMPLETE
|
||||
for ne in _graph_exec.AgentNodeExecutions
|
||||
],
|
||||
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
|
||||
)
|
||||
@@ -139,7 +128,7 @@ class GraphExecution(GraphExecutionMeta):
|
||||
**{
|
||||
# inputs from Agent Input Blocks
|
||||
exec.input_data["name"]: exec.input_data.get("value")
|
||||
for exec in complete_node_executions
|
||||
for exec in node_executions
|
||||
if (
|
||||
(block := get_block(exec.block_id))
|
||||
and block.block_type == BlockType.INPUT
|
||||
@@ -148,7 +137,7 @@ class GraphExecution(GraphExecutionMeta):
|
||||
**{
|
||||
# input from webhook-triggered block
|
||||
"payload": exec.input_data["payload"]
|
||||
for exec in complete_node_executions
|
||||
for exec in node_executions
|
||||
if (
|
||||
(block := get_block(exec.block_id))
|
||||
and block.block_type
|
||||
@@ -158,7 +147,7 @@ class GraphExecution(GraphExecutionMeta):
|
||||
}
|
||||
|
||||
outputs: CompletedBlockOutput = defaultdict(list)
|
||||
for exec in complete_node_executions:
|
||||
for exec in node_executions:
|
||||
if (
|
||||
block := get_block(exec.block_id)
|
||||
) and block.block_type == BlockType.OUTPUT:
|
||||
@@ -169,7 +158,7 @@ class GraphExecution(GraphExecutionMeta):
|
||||
return GraphExecution(
|
||||
**{
|
||||
field_name: getattr(graph_exec, field_name)
|
||||
for field_name in GraphExecutionMeta.model_fields
|
||||
for field_name in graph_exec.model_fields
|
||||
},
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
@@ -181,7 +170,7 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
|
||||
@staticmethod
|
||||
def from_db(_graph_exec: AgentGraphExecution):
|
||||
if _graph_exec.NodeExecutions is None:
|
||||
if _graph_exec.AgentNodeExecutions is None:
|
||||
raise ValueError("Node executions must be included in query")
|
||||
|
||||
graph_exec_with_io = GraphExecution.from_db(_graph_exec)
|
||||
@@ -189,7 +178,7 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
node_executions = sorted(
|
||||
[
|
||||
NodeExecutionResult.from_db(ne, _graph_exec.userId)
|
||||
for ne in _graph_exec.NodeExecutions
|
||||
for ne in _graph_exec.AgentNodeExecutions
|
||||
],
|
||||
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
|
||||
)
|
||||
@@ -197,31 +186,11 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
return GraphExecutionWithNodes(
|
||||
**{
|
||||
field_name: getattr(graph_exec_with_io, field_name)
|
||||
for field_name in GraphExecution.model_fields
|
||||
for field_name in graph_exec_with_io.model_fields
|
||||
},
|
||||
node_executions=node_executions,
|
||||
)
|
||||
|
||||
def to_graph_execution_entry(self):
|
||||
return GraphExecutionEntry(
|
||||
user_id=self.user_id,
|
||||
graph_id=self.graph_id,
|
||||
graph_version=self.graph_version or 0,
|
||||
graph_exec_id=self.id,
|
||||
start_node_execs=[
|
||||
NodeExecutionEntry(
|
||||
user_id=self.user_id,
|
||||
graph_exec_id=node_exec.graph_exec_id,
|
||||
graph_id=node_exec.graph_id,
|
||||
node_exec_id=node_exec.node_exec_id,
|
||||
node_id=node_exec.node_id,
|
||||
block_id=node_exec.block_id,
|
||||
data=node_exec.input_data,
|
||||
)
|
||||
for node_exec in self.node_executions
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class NodeExecutionResult(BaseModel):
|
||||
user_id: str
|
||||
@@ -240,21 +209,21 @@ class NodeExecutionResult(BaseModel):
|
||||
end_time: datetime | None
|
||||
|
||||
@staticmethod
|
||||
def from_db(_node_exec: AgentNodeExecution, user_id: Optional[str] = None):
|
||||
if _node_exec.executionData:
|
||||
def from_db(execution: AgentNodeExecution, user_id: Optional[str] = None):
|
||||
if execution.executionData:
|
||||
# Execution that has been queued for execution will persist its data.
|
||||
input_data = type_utils.convert(_node_exec.executionData, dict[str, Any])
|
||||
input_data = type_utils.convert(execution.executionData, dict[str, Any])
|
||||
else:
|
||||
# For incomplete execution, executionData will not be yet available.
|
||||
input_data: BlockInput = defaultdict()
|
||||
for data in _node_exec.Input or []:
|
||||
for data in execution.Input or []:
|
||||
input_data[data.name] = type_utils.convert(data.data, type[Any])
|
||||
|
||||
output_data: CompletedBlockOutput = defaultdict(list)
|
||||
for data in _node_exec.Output or []:
|
||||
for data in execution.Output or []:
|
||||
output_data[data.name].append(type_utils.convert(data.data, type[Any]))
|
||||
|
||||
graph_execution: AgentGraphExecution | None = _node_exec.GraphExecution
|
||||
graph_execution: AgentGraphExecution | None = execution.AgentGraphExecution
|
||||
if graph_execution:
|
||||
user_id = graph_execution.userId
|
||||
elif not user_id:
|
||||
@@ -266,17 +235,17 @@ class NodeExecutionResult(BaseModel):
|
||||
user_id=user_id,
|
||||
graph_id=graph_execution.agentGraphId if graph_execution else "",
|
||||
graph_version=graph_execution.agentGraphVersion if graph_execution else 0,
|
||||
graph_exec_id=_node_exec.agentGraphExecutionId,
|
||||
block_id=_node_exec.Node.agentBlockId if _node_exec.Node else "",
|
||||
node_exec_id=_node_exec.id,
|
||||
node_id=_node_exec.agentNodeId,
|
||||
status=_node_exec.executionStatus,
|
||||
graph_exec_id=execution.agentGraphExecutionId,
|
||||
block_id=execution.AgentNode.agentBlockId if execution.AgentNode else "",
|
||||
node_exec_id=execution.id,
|
||||
node_id=execution.agentNodeId,
|
||||
status=execution.executionStatus,
|
||||
input_data=input_data,
|
||||
output_data=output_data,
|
||||
add_time=_node_exec.addedTime,
|
||||
queue_time=_node_exec.queuedTime,
|
||||
start_time=_node_exec.startedTime,
|
||||
end_time=_node_exec.endedTime,
|
||||
add_time=execution.addedTime,
|
||||
queue_time=execution.queuedTime,
|
||||
start_time=execution.startedTime,
|
||||
end_time=execution.endedTime,
|
||||
)
|
||||
|
||||
|
||||
@@ -371,29 +340,29 @@ async def create_graph_execution(
|
||||
The id of the AgentGraphExecution and the list of ExecutionResult for each node.
|
||||
"""
|
||||
result = await AgentGraphExecution.prisma().create(
|
||||
data=AgentGraphExecutionCreateInput(
|
||||
agentGraphId=graph_id,
|
||||
agentGraphVersion=graph_version,
|
||||
executionStatus=ExecutionStatus.QUEUED,
|
||||
NodeExecutions={
|
||||
"create": [
|
||||
AgentNodeExecutionCreateInput(
|
||||
agentNodeId=node_id,
|
||||
executionStatus=ExecutionStatus.QUEUED,
|
||||
queuedTime=datetime.now(tz=timezone.utc),
|
||||
Input={
|
||||
data={
|
||||
"agentGraphId": graph_id,
|
||||
"agentGraphVersion": graph_version,
|
||||
"executionStatus": ExecutionStatus.QUEUED,
|
||||
"AgentNodeExecutions": {
|
||||
"create": [ # type: ignore
|
||||
{
|
||||
"agentNodeId": node_id,
|
||||
"executionStatus": ExecutionStatus.QUEUED,
|
||||
"queuedTime": datetime.now(tz=timezone.utc),
|
||||
"Input": {
|
||||
"create": [
|
||||
{"name": name, "data": Json(data)}
|
||||
for name, data in node_input.items()
|
||||
]
|
||||
},
|
||||
)
|
||||
}
|
||||
for node_id, node_input in nodes_input
|
||||
]
|
||||
},
|
||||
userId=user_id,
|
||||
agentPresetId=preset_id,
|
||||
),
|
||||
"userId": user_id,
|
||||
"agentPresetId": preset_id,
|
||||
},
|
||||
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
)
|
||||
|
||||
@@ -440,11 +409,11 @@ async def upsert_execution_input(
|
||||
|
||||
if existing_execution:
|
||||
await AgentNodeExecutionInputOutput.prisma().create(
|
||||
data=AgentNodeExecutionInputOutputCreateInput(
|
||||
name=input_name,
|
||||
data=json_input_data,
|
||||
referencedByInputExecId=existing_execution.id,
|
||||
)
|
||||
data={
|
||||
"name": input_name,
|
||||
"data": json_input_data,
|
||||
"referencedByInputExecId": existing_execution.id,
|
||||
}
|
||||
)
|
||||
return existing_execution.id, {
|
||||
**{
|
||||
@@ -456,12 +425,12 @@ async def upsert_execution_input(
|
||||
|
||||
elif not node_exec_id:
|
||||
result = await AgentNodeExecution.prisma().create(
|
||||
data=AgentNodeExecutionCreateInput(
|
||||
agentNodeId=node_id,
|
||||
agentGraphExecutionId=graph_exec_id,
|
||||
executionStatus=ExecutionStatus.INCOMPLETE,
|
||||
Input={"create": {"name": input_name, "data": json_input_data}},
|
||||
)
|
||||
data={
|
||||
"agentNodeId": node_id,
|
||||
"agentGraphExecutionId": graph_exec_id,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
"Input": {"create": {"name": input_name, "data": json_input_data}},
|
||||
}
|
||||
)
|
||||
return result.id, {input_name: input_data}
|
||||
|
||||
@@ -480,35 +449,27 @@ async def upsert_execution_output(
|
||||
Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output.
|
||||
"""
|
||||
await AgentNodeExecutionInputOutput.prisma().create(
|
||||
data=AgentNodeExecutionInputOutputCreateInput(
|
||||
name=output_name,
|
||||
data=Json(output_data),
|
||||
referencedByOutputExecId=node_exec_id,
|
||||
)
|
||||
data={
|
||||
"name": output_name,
|
||||
"data": Json(output_data),
|
||||
"referencedByOutputExecId": node_exec_id,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def update_graph_execution_start_time(
|
||||
graph_exec_id: str,
|
||||
) -> GraphExecution | None:
|
||||
count = await AgentGraphExecution.prisma().update_many(
|
||||
where={
|
||||
"id": graph_exec_id,
|
||||
"executionStatus": ExecutionStatus.QUEUED,
|
||||
},
|
||||
async def update_graph_execution_start_time(graph_exec_id: str) -> GraphExecution:
|
||||
res = await AgentGraphExecution.prisma().update(
|
||||
where={"id": graph_exec_id},
|
||||
data={
|
||||
"executionStatus": ExecutionStatus.RUNNING,
|
||||
"startedAt": datetime.now(tz=timezone.utc),
|
||||
},
|
||||
)
|
||||
if count == 0:
|
||||
return None
|
||||
|
||||
res = await AgentGraphExecution.prisma().find_unique(
|
||||
where={"id": graph_exec_id},
|
||||
include=GRAPH_EXECUTION_INCLUDE,
|
||||
)
|
||||
return GraphExecution.from_db(res) if res else None
|
||||
if not res:
|
||||
raise ValueError(f"Graph execution #{graph_exec_id} not found")
|
||||
|
||||
return GraphExecution.from_db(res)
|
||||
|
||||
|
||||
async def update_graph_execution_stats(
|
||||
@@ -519,8 +480,7 @@ async def update_graph_execution_stats(
|
||||
data = stats.model_dump() if stats else {}
|
||||
if isinstance(data.get("error"), Exception):
|
||||
data["error"] = str(data["error"])
|
||||
|
||||
updated_count = await AgentGraphExecution.prisma().update_many(
|
||||
res = await AgentGraphExecution.prisma().update(
|
||||
where={
|
||||
"id": graph_exec_id,
|
||||
"OR": [
|
||||
@@ -532,15 +492,10 @@ async def update_graph_execution_stats(
|
||||
"executionStatus": status,
|
||||
"stats": Json(data),
|
||||
},
|
||||
)
|
||||
if updated_count == 0:
|
||||
return None
|
||||
|
||||
graph_exec = await AgentGraphExecution.prisma().find_unique_or_raise(
|
||||
where={"id": graph_exec_id},
|
||||
include=GRAPH_EXECUTION_INCLUDE,
|
||||
)
|
||||
return GraphExecution.from_db(graph_exec)
|
||||
|
||||
return GraphExecution.from_db(res) if res else None
|
||||
|
||||
|
||||
async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionStats):
|
||||
@@ -634,7 +589,7 @@ async def get_node_execution_results(
|
||||
"agentGraphExecutionId": graph_exec_id,
|
||||
}
|
||||
if block_ids:
|
||||
where_clause["Node"] = {"is": {"agentBlockId": {"in": block_ids}}}
|
||||
where_clause["AgentNode"] = {"is": {"agentBlockId": {"in": block_ids}}}
|
||||
if statuses:
|
||||
where_clause["OR"] = [{"executionStatus": status} for status in statuses]
|
||||
|
||||
@@ -676,7 +631,7 @@ async def get_latest_node_execution(
|
||||
where={
|
||||
"agentNodeId": node_id,
|
||||
"agentGraphExecutionId": graph_eid,
|
||||
"NOT": [{"executionStatus": ExecutionStatus.INCOMPLETE}],
|
||||
"executionStatus": {"not": ExecutionStatus.INCOMPLETE}, # type: ignore
|
||||
},
|
||||
order=[
|
||||
{"queuedTime": "desc"},
|
||||
@@ -744,6 +699,144 @@ class ExecutionQueue(Generic[T]):
return self.queue.empty()


# ------------------- Execution Utilities -------------------- #


LIST_SPLIT = "_$_"
DICT_SPLIT = "_#_"
OBJC_SPLIT = "_@_"


def parse_execution_output(output: BlockData, name: str) -> Any | None:
"""
Extracts partial output data by name from a given BlockData.

The function supports extracting data from lists, dictionaries, and objects
using specific naming conventions:
- For lists: <output_name>_$_<index>
- For dictionaries: <output_name>_#_<key>
- For objects: <output_name>_@_<attribute>

Args:
output (BlockData): A tuple containing the output name and data.
name (str): The name used to extract specific data from the output.

Returns:
Any | None: The extracted data if found, otherwise None.

Examples:
>>> output = ("result", [10, 20, 30])
>>> parse_execution_output(output, "result_$_1")
20

>>> output = ("config", {"key1": "value1", "key2": "value2"})
>>> parse_execution_output(output, "config_#_key1")
'value1'

>>> class Sample:
... attr1 = "value1"
... attr2 = "value2"
>>> output = ("object", Sample())
>>> parse_execution_output(output, "object_@_attr1")
'value1'
"""
output_name, output_data = output

if name == output_name:
return output_data

if name.startswith(f"{output_name}{LIST_SPLIT}"):
index = int(name.split(LIST_SPLIT)[1])
if not isinstance(output_data, list) or len(output_data) <= index:
return None
return output_data[int(name.split(LIST_SPLIT)[1])]

if name.startswith(f"{output_name}{DICT_SPLIT}"):
index = name.split(DICT_SPLIT)[1]
if not isinstance(output_data, dict) or index not in output_data:
return None
return output_data[index]

if name.startswith(f"{output_name}{OBJC_SPLIT}"):
index = name.split(OBJC_SPLIT)[1]
if isinstance(output_data, object) and hasattr(output_data, index):
return getattr(output_data, index)
return None

return None

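Beyond the doctest examples in the docstring, a quick standalone sanity check of the naming convention reads as follows (values are illustrative only; it assumes parse_execution_output and DICT_SPLIT are importable from this module):

    # The plain name returns the whole value; "<name>_#_<key>" drills into a dict;
    # unknown names fall through to None.
    output = ("result", {"items": [1, 2, 3]})
    assert parse_execution_output(output, "result") == {"items": [1, 2, 3]}
    assert parse_execution_output(output, f"result{DICT_SPLIT}items") == [1, 2, 3]
    assert parse_execution_output(output, "other") is None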
def merge_execution_input(data: BlockInput) -> BlockInput:
"""
Merges dynamic input pins into a single list, dictionary, or object based on naming patterns.

This function processes input keys that follow specific patterns to merge them into a unified structure:
- `<input_name>_$_<index>` for list inputs.
- `<input_name>_#_<index>` for dictionary inputs.
- `<input_name>_@_<index>` for object inputs.

Args:
data (BlockInput): A dictionary containing input keys and their corresponding values.

Returns:
BlockInput: A dictionary with merged inputs.

Raises:
ValueError: If a list index is not an integer.

Examples:
>>> data = {
... "list_$_0": "a",
... "list_$_1": "b",
... "dict_#_key1": "value1",
... "dict_#_key2": "value2",
... "object_@_attr1": "value1",
... "object_@_attr2": "value2"
... }
>>> merge_execution_input(data)
{
"list": ["a", "b"],
"dict": {"key1": "value1", "key2": "value2"},
"object": <MockObject attr1="value1" attr2="value2">
}
"""

# Merge all input with <input_name>_$_<index> into a single list.
items = list(data.items())

for key, value in items:
if LIST_SPLIT not in key:
continue
name, index = key.split(LIST_SPLIT)
if not index.isdigit():
raise ValueError(f"Invalid key: {key}, #{index} index must be an integer.")

data[name] = data.get(name, [])
if int(index) >= len(data[name]):
# Pad list with empty string on missing indices.
data[name].extend([""] * (int(index) - len(data[name]) + 1))
data[name][int(index)] = value

# Merge all input with <input_name>_#_<index> into a single dict.
for key, value in items:
if DICT_SPLIT not in key:
continue
name, index = key.split(DICT_SPLIT)
data[name] = data.get(name, {})
data[name][index] = value

# Merge all input with <input_name>_@_<index> into a single object.
for key, value in items:
if OBJC_SPLIT not in key:
continue
name, index = key.split(OBJC_SPLIT)
if name not in data or not isinstance(data[name], object):
data[name] = mock.MockObject()
setattr(data[name], index, value)

return data

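The same split tokens drive the merge in the opposite direction; a short illustrative run (keys and values invented for the example) would be:

    # Dynamic pins named with the list/dict tokens collapse back into a single
    # "texts" list and "params" dict; the original dynamic keys stay in the dict.
    data = {
        "texts_$_0": "a",
        "texts_$_1": "b",
        "params_#_temperature": 0.7,
    }
    merged = merge_execution_input(data)
    assert merged["texts"] == ["a", "b"]
    assert merged["params"] == {"temperature": 0.7}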
# --------------------- Event Bus --------------------- #
|
||||
|
||||
|
||||
|
||||
@@ -1,18 +1,13 @@
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Any, Literal, Optional, Type, cast
|
||||
from typing import Any, Literal, Optional, Type
|
||||
|
||||
import prisma
|
||||
from prisma import Json
|
||||
from prisma.enums import SubmissionStatus
|
||||
from prisma.models import AgentGraph, AgentNode, AgentNodeLink, StoreListingVersion
|
||||
from prisma.types import (
|
||||
AgentGraphCreateInput,
|
||||
AgentGraphWhereInput,
|
||||
AgentNodeCreateInput,
|
||||
AgentNodeLinkCreateInput,
|
||||
)
|
||||
from prisma.types import AgentGraphWhereInput
|
||||
from pydantic.fields import computed_field
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
@@ -58,23 +53,22 @@ class Node(BaseDbModel):
|
||||
input_links: list[Link] = []
|
||||
output_links: list[Link] = []
|
||||
|
||||
@property
|
||||
def block(self) -> Block[BlockSchema, BlockSchema]:
|
||||
block = get_block(self.block_id)
|
||||
if not block:
|
||||
raise ValueError(
|
||||
f"Block #{self.block_id} does not exist -> Node #{self.id} is invalid"
|
||||
)
|
||||
return block
|
||||
webhook_id: Optional[str] = None
|
||||
|
||||
|
||||
class NodeModel(Node):
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
|
||||
webhook_id: Optional[str] = None
|
||||
webhook: Optional[Webhook] = None
|
||||
|
||||
@property
|
||||
def block(self) -> Block[BlockSchema, BlockSchema]:
|
||||
block = get_block(self.block_id)
|
||||
if not block:
|
||||
raise ValueError(f"Block #{self.block_id} does not exist")
|
||||
return block
|
||||
|
||||
@staticmethod
|
||||
def from_db(node: AgentNode, for_export: bool = False) -> "NodeModel":
|
||||
obj = NodeModel(
|
||||
@@ -94,7 +88,8 @@ class NodeModel(Node):
|
||||
return obj
|
||||
|
||||
def is_triggered_by_event_type(self, event_type: str) -> bool:
|
||||
block = self.block
|
||||
if not (block := get_block(self.block_id)):
|
||||
raise ValueError(f"Block #{self.block_id} not found for node #{self.id}")
|
||||
if not block.webhook_config:
|
||||
raise TypeError("This method can't be used on non-webhook blocks")
|
||||
if not block.webhook_config.event_filter_input:
|
||||
@@ -171,10 +166,11 @@ class BaseGraph(BaseDbModel):
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
return self._generate_schema(
|
||||
*(
|
||||
(block.input_schema, node.input_default)
|
||||
(b.input_schema, node.input_default)
|
||||
for node in self.nodes
|
||||
if (block := node.block).block_type == BlockType.INPUT
|
||||
and issubclass(block.input_schema, AgentInputBlock.Input)
|
||||
if (b := get_block(node.block_id))
|
||||
and b.block_type == BlockType.INPUT
|
||||
and issubclass(b.input_schema, AgentInputBlock.Input)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -183,10 +179,11 @@ class BaseGraph(BaseDbModel):
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
return self._generate_schema(
|
||||
*(
|
||||
(block.input_schema, node.input_default)
|
||||
(b.input_schema, node.input_default)
|
||||
for node in self.nodes
|
||||
if (block := node.block).block_type == BlockType.OUTPUT
|
||||
and issubclass(block.input_schema, AgentOutputBlock.Input)
|
||||
if (b := get_block(node.block_id))
|
||||
and b.block_type == BlockType.OUTPUT
|
||||
and issubclass(b.input_schema, AgentOutputBlock.Input)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -231,16 +228,13 @@ class GraphModel(Graph):
|
||||
user_id: str
|
||||
nodes: list[NodeModel] = [] # type: ignore
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def has_webhook_trigger(self) -> bool:
|
||||
return self.webhook_input_node is not None
|
||||
|
||||
@property
|
||||
def starting_nodes(self) -> list[NodeModel]:
|
||||
def starting_nodes(self) -> list[Node]:
|
||||
outbound_nodes = {link.sink_id for link in self.links}
|
||||
input_nodes = {
|
||||
node.id for node in self.nodes if node.block.block_type == BlockType.INPUT
|
||||
v.id
|
||||
for v in self.nodes
|
||||
if (b := get_block(v.block_id)) and b.block_type == BlockType.INPUT
|
||||
}
|
||||
return [
|
||||
node
|
||||
@@ -248,18 +242,6 @@ class GraphModel(Graph):
|
||||
if node.id not in outbound_nodes or node.id in input_nodes
|
||||
]
|
||||
|
||||
@property
|
||||
def webhook_input_node(self) -> NodeModel | None:
|
||||
return next(
|
||||
(
|
||||
node
|
||||
for node in self.nodes
|
||||
if node.block.block_type
|
||||
in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
|
||||
"""
|
||||
Reassigns all IDs in the graph to new UUIDs.
|
||||
@@ -409,7 +391,9 @@ class GraphModel(Graph):
|
||||
node_map = {v.id: v for v in graph.nodes}
|
||||
|
||||
def is_static_output_block(nid: str) -> bool:
|
||||
return node_map[nid].block.static_output
|
||||
bid = node_map[nid].block_id
|
||||
b = get_block(bid)
|
||||
return b.static_output if b else False
|
||||
|
||||
# Links: links are connected and the connected pin data type are compatible.
|
||||
for link in graph.links:
|
||||
@@ -465,11 +449,13 @@ class GraphModel(Graph):
|
||||
is_active=graph.isActive,
|
||||
name=graph.name or "",
|
||||
description=graph.description or "",
|
||||
nodes=[NodeModel.from_db(node, for_export) for node in graph.Nodes or []],
|
||||
nodes=[
|
||||
NodeModel.from_db(node, for_export) for node in graph.AgentNodes or []
|
||||
],
|
||||
links=list(
|
||||
{
|
||||
Link.from_db(link)
|
||||
for node in graph.Nodes or []
|
||||
for node in graph.AgentNodes or []
|
||||
for link in (node.Input or []) + (node.Output or [])
|
||||
}
|
||||
),
|
||||
@@ -600,8 +586,8 @@ async def get_graph(
|
||||
and not (
|
||||
await StoreListingVersion.prisma().find_first(
|
||||
where={
|
||||
"agentGraphId": graph_id,
|
||||
"agentGraphVersion": version or graph.version,
|
||||
"agentId": graph_id,
|
||||
"agentVersion": version or graph.version,
|
||||
"isDeleted": False,
|
||||
"submissionStatus": SubmissionStatus.APPROVED,
|
||||
}
|
||||
@@ -635,16 +621,12 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
|
||||
sub_graph_ids = [
|
||||
(graph_id, graph_version)
|
||||
for graph in search_graphs
|
||||
for node in graph.Nodes or []
|
||||
for node in graph.AgentNodes or []
|
||||
if (
|
||||
node.AgentBlock
|
||||
and node.AgentBlock.id == agent_block_id
|
||||
and (graph_id := cast(str, dict(node.constantInput).get("graph_id")))
|
||||
and (
|
||||
graph_version := cast(
|
||||
int, dict(node.constantInput).get("graph_version")
|
||||
)
|
||||
)
|
||||
and (graph_id := dict(node.constantInput).get("graph_id"))
|
||||
and (graph_version := dict(node.constantInput).get("graph_version"))
|
||||
)
|
||||
]
|
||||
if not sub_graph_ids:
|
||||
@@ -659,7 +641,7 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
|
||||
"userId": graph.userId, # Ensure the sub-graph is owned by the same user
|
||||
}
|
||||
for graph_id, graph_version in sub_graph_ids
|
||||
]
|
||||
] # type: ignore
|
||||
},
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
)
|
||||
@@ -673,7 +655,7 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
|
||||
async def get_connected_output_nodes(node_id: str) -> list[tuple[Link, Node]]:
|
||||
links = await AgentNodeLink.prisma().find_many(
|
||||
where={"agentNodeSourceId": node_id},
|
||||
include={"AgentNodeSink": {"include": AGENT_NODE_INCLUDE}},
|
||||
include={"AgentNodeSink": {"include": AGENT_NODE_INCLUDE}}, # type: ignore
|
||||
)
|
||||
return [
|
||||
(Link.from_db(link), NodeModel.from_db(link.AgentNodeSink))
|
||||
@@ -744,28 +726,29 @@ async def __create_graph(tx, graph: Graph, user_id: str):
|
||||
|
||||
await AgentGraph.prisma(tx).create_many(
|
||||
data=[
|
||||
AgentGraphCreateInput(
|
||||
id=graph.id,
|
||||
version=graph.version,
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
isActive=graph.is_active,
|
||||
userId=user_id,
|
||||
)
|
||||
{
|
||||
"id": graph.id,
|
||||
"version": graph.version,
|
||||
"name": graph.name,
|
||||
"description": graph.description,
|
||||
"isActive": graph.is_active,
|
||||
"userId": user_id,
|
||||
}
|
||||
for graph in graphs
|
||||
]
|
||||
)
|
||||
|
||||
await AgentNode.prisma(tx).create_many(
|
||||
data=[
|
||||
AgentNodeCreateInput(
|
||||
id=node.id,
|
||||
agentGraphId=graph.id,
|
||||
agentGraphVersion=graph.version,
|
||||
agentBlockId=node.block_id,
|
||||
constantInput=Json(node.input_default),
|
||||
metadata=Json(node.metadata),
|
||||
)
|
||||
{
|
||||
"id": node.id,
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
"agentBlockId": node.block_id,
|
||||
"constantInput": Json(node.input_default),
|
||||
"metadata": Json(node.metadata),
|
||||
"webhookId": node.webhook_id,
|
||||
}
|
||||
for graph in graphs
|
||||
for node in graph.nodes
|
||||
]
|
||||
@@ -773,14 +756,14 @@ async def __create_graph(tx, graph: Graph, user_id: str):
|
||||
|
||||
await AgentNodeLink.prisma(tx).create_many(
|
||||
data=[
|
||||
AgentNodeLinkCreateInput(
|
||||
id=str(uuid.uuid4()),
|
||||
sourceName=link.source_name,
|
||||
sinkName=link.sink_name,
|
||||
agentNodeSourceId=link.source_id,
|
||||
agentNodeSinkId=link.sink_id,
|
||||
isStatic=link.is_static,
|
||||
)
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"sourceName": link.source_name,
|
||||
"sinkName": link.sink_name,
|
||||
"agentNodeSourceId": link.source_id,
|
||||
"agentNodeSinkId": link.sink_id,
|
||||
"isStatic": link.is_static,
|
||||
}
|
||||
for graph in graphs
|
||||
for link in graph.links
|
||||
]
|
||||
@@ -831,12 +814,12 @@ async def fix_llm_provider_credentials():
|
||||
SELECT graph."userId" user_id,
|
||||
node.id node_id,
|
||||
node."constantInput" node_preset_input
|
||||
FROM platform."AgentNode" node
|
||||
LEFT JOIN platform."AgentGraph" graph
|
||||
ON node."agentGraphId" = graph.id
|
||||
WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm'
|
||||
ORDER BY graph."userId";
|
||||
"""
|
||||
FROM platform."AgentNode" node
|
||||
LEFT JOIN platform."AgentGraph" graph
|
||||
ON node."agentGraphId" = graph.id
|
||||
WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm'
|
||||
ORDER BY graph."userId";
|
||||
"""
|
||||
)
|
||||
logger.info(f"Fixing LLM credential inputs on {len(broken_nodes)} nodes")
|
||||
except Exception as e:
|
||||
@@ -919,19 +902,12 @@ async def migrate_llm_models(migrate_to: LlmModel):
# Convert enum values to a list of strings for the SQL query
enum_values = [v.value for v in LlmModel.__members__.values()]

escaped_enum_values = repr(tuple(enum_values)) # hack but works
query = f"""
UPDATE "AgentNode"
SET "constantInput" = jsonb_set("constantInput", $1, $2, true)
WHERE "agentBlockId" = $3
AND "constantInput" ? $4
AND "constantInput"->>$4 NOT IN {escaped_enum_values}
SET "constantInput" = jsonb_set("constantInput", '{{{path}}}', '"{migrate_to.value}"', true)
WHERE "agentBlockId" = '{id}'
AND "constantInput" ? '{path}'
AND "constantInput"->>'{path}' NOT IN ({','.join(f"'{value}'" for value in enum_values)})
"""

await db.execute_raw(
query, # type: ignore - is supposed to be LiteralString
"{" + path + "}",
f'"{migrate_to.value}"',
id,
path,
)
await db.execute_raw(query)

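For what it's worth, the two ways this hunk renders the NOT IN (...) list produce an equivalent SQL tuple literal; a standalone comparison with made-up values:

    # repr(tuple(...)) and the explicit join differ only in whitespace; both
    # yield a parenthesised, single-quoted list usable after NOT IN.
    enum_values = ["gpt-4o", "claude-3-5-sonnet"]
    print(repr(tuple(enum_values)))                             # ('gpt-4o', 'claude-3-5-sonnet')
    print("(" + ",".join(f"'{v}'" for v in enum_values) + ")")  # ('gpt-4o','claude-3-5-sonnet')

One difference worth noting: with a single element, repr(tuple(...)) renders a trailing comma, as in ('x',), which is presumably why the original line is flagged "hack but works".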
@@ -1,7 +1,4 @@
from typing import cast

import prisma.enums
import prisma.types
import prisma

from backend.blocks.io import IO_BLOCK_IDs

@@ -13,25 +10,25 @@ AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
}

AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
"Nodes": {"include": AGENT_NODE_INCLUDE}
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
}

EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
"Input": True,
"Output": True,
"Node": True,
"GraphExecution": True,
"AgentNode": True,
"AgentGraphExecution": True,
}

MAX_NODE_EXECUTIONS_FETCH = 1000

GRAPH_EXECUTION_INCLUDE_WITH_NODES: prisma.types.AgentGraphExecutionInclude = {
"NodeExecutions": {
"AgentNodeExecutions": {
"include": {
"Input": True,
"Output": True,
"Node": True,
"GraphExecution": True,
"AgentNode": True,
"AgentGraphExecution": True,
},
"order_by": [
{"queuedTime": "desc"},
@@ -43,30 +40,28 @@ GRAPH_EXECUTION_INCLUDE_WITH_NODES: prisma.types.AgentGraphExecutionInclude = {
}

GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
"NodeExecutions": {
**cast(
prisma.types.FindManyAgentNodeExecutionArgsFromAgentGraphExecution,
GRAPH_EXECUTION_INCLUDE_WITH_NODES["NodeExecutions"],
),
"AgentNodeExecutions": {
**GRAPH_EXECUTION_INCLUDE_WITH_NODES["AgentNodeExecutions"], # type: ignore
"where": {
"Node": {"is": {"AgentBlock": {"is": {"id": {"in": IO_BLOCK_IDs}}}}},
"NOT": [{"executionStatus": prisma.enums.AgentExecutionStatus.INCOMPLETE}],
"AgentNode": {
"AgentBlock": {"id": {"in": IO_BLOCK_IDs}}, # type: ignore
},
},
}
}


INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
"AgentNodes": {"include": AGENT_NODE_INCLUDE}
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
}


def library_agent_include(user_id: str) -> prisma.types.LibraryAgentInclude:
return {
"AgentGraph": {
"Agent": {
"include": {
**AGENT_GRAPH_INCLUDE,
"Executions": {"where": {"userId": user_id}},
"AgentGraphExecution": {"where": {"userId": user_id}},
}
},
"Creator": True,

@@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, AsyncGenerator, Optional

from prisma import Json
from prisma.models import IntegrationWebhook
from prisma.types import IntegrationWebhookCreateInput
from pydantic import Field, computed_field

from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
@@ -67,18 +66,18 @@ class Webhook(BaseDbModel):

async def create_webhook(webhook: Webhook) -> Webhook:
created_webhook = await IntegrationWebhook.prisma().create(
data=IntegrationWebhookCreateInput(
id=webhook.id,
userId=webhook.user_id,
provider=webhook.provider.value,
credentialsId=webhook.credentials_id,
webhookType=webhook.webhook_type,
resource=webhook.resource,
events=webhook.events,
config=Json(webhook.config),
secret=webhook.secret,
providerWebhookId=webhook.provider_webhook_id,
)
data={
"id": webhook.id,
"userId": webhook.user_id,
"provider": webhook.provider.value,
"credentialsId": webhook.credentials_id,
"webhookType": webhook.webhook_type,
"resource": webhook.resource,
"events": webhook.events,
"config": Json(webhook.config),
"secret": webhook.secret,
"providerWebhookId": webhook.provider_webhook_id,
}
)
return Webhook.from_db(created_webhook)


@@ -142,12 +142,8 @@ def SchemaField(
exclude: bool = False,
hidden: Optional[bool] = None,
depends_on: Optional[list[str]] = None,
ge: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
discriminator: Optional[str] = None,
json_schema_extra: Optional[dict[str, Any]] = None,
**kwargs,
) -> T:
if default is PydanticUndefined and default_factory is None:
advanced = False
@@ -174,12 +170,8 @@ def SchemaField(
title=title,
description=description,
exclude=exclude,
ge=ge,
le=le,
min_length=min_length,
max_length=max_length,
discriminator=discriminator,
json_schema_extra=json_schema_extra,
**kwargs,
) # type: ignore


@@ -413,10 +405,9 @@ class RefundRequest(BaseModel):
class NodeExecutionStats(BaseModel):
"""Execution statistics for a node execution."""

model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
)
class Config:
arbitrary_types_allowed = True
extra = "allow"

error: Optional[Exception | str] = None
walltime: float = 0
@@ -432,10 +423,9 @@ class NodeExecutionStats(BaseModel):
class GraphExecutionStats(BaseModel):
"""Execution statistics for a graph execution."""

model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
)
class Config:
arbitrary_types_allowed = True
extra = "allow"

error: Optional[Exception | str] = None
walltime: float = Field(

@@ -6,14 +6,10 @@ from typing import Annotated, Any, Generic, Optional, TypeVar, Union
from prisma import Json
from prisma.enums import NotificationType
from prisma.models import NotificationEvent, UserNotificationBatch
from prisma.types import (
NotificationEventCreateInput,
UserNotificationBatchCreateInput,
UserNotificationBatchWhereInput,
)
from prisma.types import UserNotificationBatchWhereInput

# from backend.notifications.models import NotificationEvent
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
from pydantic import BaseModel, EmailStr, Field, field_validator

from backend.server.v2.store.exceptions import DatabaseError

@@ -39,7 +35,8 @@ class QueueType(Enum):


class BaseNotificationData(BaseModel):
model_config = ConfigDict(extra="allow")
class Config:
extra = "allow"


class AgentRunData(BaseNotificationData):
@@ -401,8 +398,6 @@ async def create_or_add_to_user_notification_batch(
logger.info(
f"Creating or adding to notification batch for {user_id} with type {notification_type} and data {notification_data}"
)
if not notification_data.data:
raise ValueError("Notification data must be provided")

# Serialize the data
json_data: Json = Json(notification_data.data.model_dump())
@@ -421,30 +416,30 @@ async def create_or_add_to_user_notification_batch(
if not existing_batch:
async with transaction() as tx:
notification_event = await tx.notificationevent.create(
data=NotificationEventCreateInput(
type=notification_type,
data=json_data,
)
data={
"type": notification_type,
"data": json_data,
}
)

# Create new batch
resp = await tx.usernotificationbatch.create(
data=UserNotificationBatchCreateInput(
userId=user_id,
type=notification_type,
Notifications={"connect": [{"id": notification_event.id}]},
),
data={
"userId": user_id,
"type": notification_type,
"Notifications": {"connect": [{"id": notification_event.id}]},
},
include={"Notifications": True},
)
return UserNotificationBatchDTO.from_db(resp)
else:
async with transaction() as tx:
notification_event = await tx.notificationevent.create(
data=NotificationEventCreateInput(
type=notification_type,
data=json_data,
UserNotificationBatch={"connect": {"id": existing_batch.id}},
)
data={
"type": notification_type,
"data": json_data,
"UserNotificationBatch": {"connect": {"id": existing_batch.id}},
}
)
# Add to existing batch
resp = await tx.usernotificationbatch.update(

@@ -6,11 +6,9 @@ import pydantic
from prisma import Json
from prisma.enums import OnboardingStep
from prisma.models import UserOnboarding
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
from prisma.types import UserOnboardingUpdateInput

from backend.data import db
from backend.data.block import get_blocks
from backend.data.credit import get_user_credit_model
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.server.v2.store.model import StoreAgentDetails
@@ -26,26 +24,21 @@ REASON_MAPPING: dict[str, list[str]] = {
POINTS_AGENT_COUNT = 50 # Number of agents to calculate points for
MIN_AGENT_COUNT = 2 # Minimum number of marketplace agents to enable onboarding

user_credit = get_user_credit_model()


class UserOnboardingUpdate(pydantic.BaseModel):
completedSteps: Optional[list[OnboardingStep]] = None
notificationDot: Optional[bool] = None
notified: Optional[list[OnboardingStep]] = None
usageReason: Optional[str] = None
integrations: Optional[list[str]] = None
otherIntegrations: Optional[str] = None
selectedStoreListingVersionId: Optional[str] = None
agentInput: Optional[dict[str, Any]] = None
onboardingAgentExecutionId: Optional[str] = None


async def get_user_onboarding(user_id: str):
return await UserOnboarding.prisma().upsert(
where={"userId": user_id},
data={
"create": UserOnboardingCreateInput(userId=user_id),
"create": {"userId": user_id}, # type: ignore
"update": {},
},
)
@@ -55,20 +48,6 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
update: UserOnboardingUpdateInput = {}
if data.completedSteps is not None:
update["completedSteps"] = list(set(data.completedSteps))
for step in (
OnboardingStep.AGENT_NEW_RUN,
OnboardingStep.GET_RESULTS,
OnboardingStep.MARKETPLACE_ADD_AGENT,
OnboardingStep.MARKETPLACE_RUN_AGENT,
OnboardingStep.BUILDER_SAVE_AGENT,
OnboardingStep.BUILDER_RUN_AGENT,
):
if step in data.completedSteps:
await reward_user(user_id, step)
if data.notificationDot is not None:
update["notificationDot"] = data.notificationDot
if data.notified is not None:
update["notified"] = list(set(data.notified))
if data.usageReason is not None:
update["usageReason"] = data.usageReason
if data.integrations is not None:
@@ -79,57 +58,16 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
update["selectedStoreListingVersionId"] = data.selectedStoreListingVersionId
if data.agentInput is not None:
update["agentInput"] = Json(data.agentInput)
if data.onboardingAgentExecutionId is not None:
update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId

return await UserOnboarding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, **update},
"create": {"userId": user_id, **update}, # type: ignore
"update": update,
},
)


async def reward_user(user_id: str, step: OnboardingStep):
async with db.locked_transaction(f"usr_trx_{user_id}-reward"):
reward = 0
match step:
# Reward user when they clicked New Run during onboarding
# This is because they need credits before scheduling a run (next step)
case OnboardingStep.AGENT_NEW_RUN:
reward = 300
case OnboardingStep.GET_RESULTS:
reward = 300
case OnboardingStep.MARKETPLACE_ADD_AGENT:
reward = 100
case OnboardingStep.MARKETPLACE_RUN_AGENT:
reward = 100
case OnboardingStep.BUILDER_SAVE_AGENT:
reward = 100
case OnboardingStep.BUILDER_RUN_AGENT:
reward = 100

if reward == 0:
return

onboarding = await get_user_onboarding(user_id)

# Skip if already rewarded
if step in onboarding.rewardedFor:
return

onboarding.rewardedFor.append(step)
await user_credit.onboarding_reward(user_id, reward, step)
await UserOnboarding.prisma().update(
where={"userId": user_id},
data={
"completedSteps": list(set(onboarding.completedSteps + [step])),
"rewardedFor": onboarding.rewardedFor,
},
)


def clean_and_split(text: str) -> list[str]:
"""
Removes all special characters from a string, truncates it to 100 characters,
@@ -248,11 +186,11 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
where={
"id": {"in": [agent.storeListingVersionId for agent in storeAgents]},
},
include={"AgentGraph": True},
include={"Agent": True},
)

for listing in agentListings:
agent = listing.AgentGraph
agent = listing.Agent
if agent is None:
continue
graph = GraphModel.from_db(agent)

@@ -11,7 +11,7 @@ from fastapi import HTTPException
from prisma import Json
from prisma.enums import NotificationType
from prisma.models import User
from prisma.types import JsonFilter, UserCreateInput, UserUpdateInput
from prisma.types import UserUpdateInput

from backend.data.db import prisma
from backend.data.model import UserIntegrations, UserMetadata, UserMetadataRaw
@@ -36,11 +36,11 @@ async def get_or_create_user(user_data: dict) -> User:
user = await prisma.user.find_unique(where={"id": user_id})
if not user:
user = await prisma.user.create(
data=UserCreateInput(
id=user_id,
email=user_email,
name=user_data.get("user_metadata", {}).get("name"),
)
data={
"id": user_id,
"email": user_email,
"name": user_data.get("user_metadata", {}).get("name"),
}
)

return User.model_validate(user)
@@ -84,11 +84,11 @@ async def create_default_user() -> Optional[User]:
user = await prisma.user.find_unique(where={"id": DEFAULT_USER_ID})
if not user:
user = await prisma.user.create(
data=UserCreateInput(
id=DEFAULT_USER_ID,
email="default@example.com",
name="Default User",
)
data={
"id": DEFAULT_USER_ID,
"email": "default@example.com",
"name": "Default User",
}
)
return User.model_validate(user)

@@ -135,21 +135,16 @@ async def migrate_and_encrypt_user_integrations():
"""Migrate integration credentials and OAuth states from metadata to integrations column."""
users = await User.prisma().find_many(
where={
"metadata": cast(
JsonFilter,
{
"path": ["integration_credentials"],
"not": Json(
{"a": "yolo"}
), # bogus value works to check if key exists
},
)
"metadata": {
"path": ["integration_credentials"],
"not": Json({"a": "yolo"}), # bogus value works to check if key exists
} # type: ignore
}
)
logger.info(f"Migrating integration credentials for {len(users)} users")

for user in users:
raw_metadata = cast(dict, user.metadata)
raw_metadata = cast(UserMetadataRaw, user.metadata)
metadata = UserMetadata.model_validate(raw_metadata)

# Get existing integrations data
@@ -165,6 +160,7 @@ async def migrate_and_encrypt_user_integrations():
await update_user_integrations(user_id=user.id, data=integrations)

# Remove from metadata
raw_metadata = dict(raw_metadata)
raw_metadata.pop("integration_credentials", None)
raw_metadata.pop("integration_oauth_states", None)


@@ -1,8 +1,8 @@
import logging

from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
GraphExecution,
NodeExecutionResult,
RedisExecutionEventBus,
create_graph_execution,
get_graph_execution,
get_incomplete_node_executions,
@@ -39,12 +39,11 @@ from backend.data.user import (
update_user_integrations,
update_user_metadata,
)
from backend.util.service import AppService, exposed_run_and_wait
from backend.util.service import AppService, expose, exposed_run_and_wait
from backend.util.settings import Config

config = Config()
_user_credit_model = get_user_credit_model()
logger = logging.getLogger(__name__)


async def _spend_credits(
@@ -54,21 +53,22 @@ async def _spend_credits(


class DatabaseManager(AppService):

def run_service(self) -> None:
logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
self.run_and_wait(db.connect())
super().run_service()

def cleanup(self):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
self.run_and_wait(db.disconnect())
def __init__(self):
super().__init__()
self.use_db = True
self.use_redis = True
self.execution_event_bus = RedisExecutionEventBus()

@classmethod
def get_port(cls) -> int:
return config.database_api_port

@expose
def send_execution_update(
self, execution_result: GraphExecution | NodeExecutionResult
):
self.execution_event_bus.publish(execution_result)

# Executions
get_graph_execution = exposed_run_and_wait(get_graph_execution)
create_graph_execution = exposed_run_and_wait(create_graph_execution)

@@ -5,14 +5,11 @@ import os
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import Future, ProcessPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from multiprocessing.pool import AsyncResult, Pool
|
||||
from typing import TYPE_CHECKING, Any, Generator, TypeVar, cast
|
||||
from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar, cast
|
||||
|
||||
from pika.adapters.blocking_connection import BlockingChannel
|
||||
from pika.spec import Basic, BasicProperties
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from backend.blocks.io import AgentOutputBlock
|
||||
@@ -33,36 +30,43 @@ from autogpt_libs.utils.cache import thread_cached
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.data import redis
|
||||
from backend.data.block import BlockData, BlockInput, BlockSchema, get_block
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockData,
|
||||
BlockInput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.execution import (
|
||||
ExecutionQueue,
|
||||
ExecutionStatus,
|
||||
GraphExecution,
|
||||
GraphExecutionEntry,
|
||||
NodeExecutionEntry,
|
||||
NodeExecutionResult,
|
||||
merge_execution_input,
|
||||
parse_execution_output,
|
||||
)
|
||||
from backend.data.graph import Link, Node
|
||||
from backend.data.graph import GraphModel, Link, Node
|
||||
from backend.executor.utils import (
|
||||
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_QUEUE_NAME,
|
||||
CancelExecutionEvent,
|
||||
UsageTransactionMetadata,
|
||||
block_usage_cost,
|
||||
execution_usage_cost,
|
||||
get_execution_event_bus,
|
||||
get_execution_queue,
|
||||
parse_execution_output,
|
||||
validate_exec,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util import json
|
||||
from backend.util.decorator import error_logged, time_measured
|
||||
from backend.util.file import clean_exec_files
|
||||
from backend.util.logging import configure_logging
|
||||
from backend.util.process import AppProcess, set_service_name
|
||||
from backend.util.service import close_service_client, get_service_client
|
||||
from backend.util.process import set_service_name
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
close_service_client,
|
||||
expose,
|
||||
get_service_client,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.type import convert
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
@@ -148,16 +152,23 @@ def execute_node(
|
||||
def update_execution_status(status: ExecutionStatus) -> NodeExecutionResult:
|
||||
"""Sets status and fetches+broadcasts the latest state of the node execution"""
|
||||
exec_update = db_client.update_node_execution_status(node_exec_id, status)
|
||||
send_execution_update(exec_update)
|
||||
db_client.send_execution_update(exec_update)
|
||||
return exec_update
|
||||
|
||||
node = db_client.get_node(node_id)
|
||||
|
||||
node_block = node.block
|
||||
node_block = get_block(node.block_id)
|
||||
if not node_block:
|
||||
logger.error(f"Block {node.block_id} not found.")
|
||||
return
|
||||
|
||||
def push_output(output_name: str, output_data: Any) -> None:
|
||||
db_client.upsert_execution_output(
|
||||
_push_node_execution_output(
|
||||
db_client=db_client,
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
node_exec_id=node_exec_id,
|
||||
block_id=node_block.id,
|
||||
output_name=output_name,
|
||||
output_data=output_data,
|
||||
)
|
||||
@@ -188,7 +199,7 @@ def execute_node(
|
||||
# Execute the node
|
||||
input_data_str = json.dumps(input_data)
|
||||
input_size = len(input_data_str)
|
||||
log_metadata.debug("Executed node with input", input=input_data_str)
|
||||
log_metadata.info("Executed node with input", input=input_data_str)
|
||||
update_execution_status(ExecutionStatus.RUNNING)
|
||||
|
||||
# Inject extra execution arguments for the blocks via kwargs
|
||||
@@ -219,7 +230,7 @@ def execute_node(
|
||||
):
|
||||
output_data = json.convert_pydantic_to_json(output_data)
|
||||
output_size += len(json.dumps(output_data))
|
||||
log_metadata.debug("Node produced output", **{output_name: output_data})
|
||||
log_metadata.info("Node produced output", **{output_name: output_data})
|
||||
push_output(output_name, output_data)
|
||||
outputs[output_name] = output_data
|
||||
for execution in _enqueue_next_nodes(
|
||||
@@ -269,6 +280,35 @@ def execute_node(
|
||||
execution_stats.output_size = output_size
|
||||
|
||||
|
||||
def _push_node_execution_output(
|
||||
db_client: "DatabaseManager",
|
||||
user_id: str,
|
||||
graph_exec_id: str,
|
||||
node_exec_id: str,
|
||||
block_id: str,
|
||||
output_name: str,
|
||||
output_data: Any,
|
||||
):
|
||||
from backend.blocks.io import IO_BLOCK_IDs
|
||||
|
||||
db_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name=output_name,
|
||||
output_data=output_data,
|
||||
)
|
||||
|
||||
# Automatically push execution updates for all agent I/O
|
||||
if block_id in IO_BLOCK_IDs:
|
||||
graph_exec = db_client.get_graph_execution(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
if not graph_exec:
|
||||
raise ValueError(
|
||||
f"Graph execution #{graph_exec_id} for user #{user_id} not found"
|
||||
)
|
||||
db_client.send_execution_update(graph_exec)
|
||||
|
||||
|
||||
def _enqueue_next_nodes(
|
||||
db_client: "DatabaseManager",
|
||||
node: Node,
|
||||
@@ -284,7 +324,7 @@ def _enqueue_next_nodes(
|
||||
exec_update = db_client.update_node_execution_status(
|
||||
node_exec_id, ExecutionStatus.QUEUED, data
|
||||
)
|
||||
send_execution_update(exec_update)
|
||||
db_client.send_execution_update(exec_update)
|
||||
return NodeExecutionEntry(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
@@ -396,6 +436,60 @@ def _enqueue_next_nodes(
|
||||
]
|
||||
|
||||
|
||||
def validate_exec(
|
||||
node: Node,
|
||||
data: BlockInput,
|
||||
resolve_input: bool = True,
|
||||
) -> tuple[BlockInput | None, str]:
|
||||
"""
|
||||
Validate the input data for a node execution.
|
||||
|
||||
Args:
|
||||
node: The node to execute.
|
||||
data: The input data for the node execution.
|
||||
resolve_input: Whether to resolve dynamic pins into dict/list/object.
|
||||
|
||||
Returns:
|
||||
A tuple of the validated data and the block name.
|
||||
If the data is invalid, the first element will be None, and the second element
|
||||
will be an error message.
|
||||
If the data is valid, the first element will be the resolved input data, and
|
||||
the second element will be the block name.
|
||||
"""
|
||||
node_block: Block | None = get_block(node.block_id)
|
||||
if not node_block:
|
||||
return None, f"Block for {node.block_id} not found."
|
||||
schema = node_block.input_schema
|
||||
|
||||
# Convert non-matching data types to the expected input schema.
|
||||
for name, data_type in schema.__annotations__.items():
|
||||
if (value := data.get(name)) and (type(value) is not data_type):
|
||||
data[name] = convert(value, data_type)
|
||||
|
||||
# Input data (without default values) should contain all required fields.
|
||||
error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
|
||||
if missing_links := schema.get_missing_links(data, node.input_links):
|
||||
return None, f"{error_prefix} unpopulated links {missing_links}"
|
||||
|
||||
# Merge input data with default values and resolve dynamic dict/list/object pins.
|
||||
input_default = schema.get_input_defaults(node.input_default)
|
||||
data = {**input_default, **data}
|
||||
if resolve_input:
|
||||
data = merge_execution_input(data)
|
||||
|
||||
# Input data post-merge should contain all required fields from the schema.
|
||||
if missing_input := schema.get_missing_input(data):
|
||||
return None, f"{error_prefix} missing input {missing_input}"
|
||||
|
||||
# Last validation: Validate the input values against the schema.
|
||||
if error := schema.get_mismatch_error(data):
|
||||
error_message = f"{error_prefix} {error}"
|
||||
logger.error(error_message)
|
||||
return None, error_message
|
||||
|
||||
return data, node_block.name
|
||||
|
||||
|
||||
class Executor:
|
||||
"""
|
||||
This class contains event handlers for the process pool executor events.
|
||||
@@ -575,13 +669,7 @@ class Executor:
|
||||
exec_meta = cls.db_client.update_graph_execution_start_time(
|
||||
graph_exec.graph_exec_id
|
||||
)
|
||||
if exec_meta is None:
|
||||
logger.warning(
|
||||
f"Skipped graph execution {graph_exec.graph_exec_id}, the graph execution is not found or not currently in the QUEUED state."
|
||||
)
|
||||
return
|
||||
|
||||
send_execution_update(exec_meta)
|
||||
cls.db_client.send_execution_update(exec_meta)
|
||||
timing_info, (exec_stats, status, error) = cls._on_graph_execution(
|
||||
graph_exec, cancel, log_metadata
|
||||
)
|
||||
@@ -594,7 +682,7 @@ class Executor:
|
||||
status=status,
|
||||
stats=exec_stats,
|
||||
):
|
||||
send_execution_update(graph_exec_result)
|
||||
cls.db_client.send_execution_update(graph_exec_result)
|
||||
|
||||
cls._handle_agent_run_notif(graph_exec, exec_stats)
|
||||
|
||||
@@ -660,19 +748,15 @@ class Executor:
|
||||
Exception | None: The error that occurred during the execution, if any.
|
||||
"""
|
||||
log_metadata.info(f"Start graph execution {graph_exec.graph_exec_id}")
|
||||
execution_stats = GraphExecutionStats()
|
||||
execution_status = ExecutionStatus.RUNNING
|
||||
exec_stats = GraphExecutionStats()
|
||||
error = None
|
||||
finished = False
|
||||
|
||||
def cancel_handler():
|
||||
nonlocal execution_status
|
||||
|
||||
while not cancel.is_set():
|
||||
cancel.wait(1)
|
||||
if finished:
|
||||
return
|
||||
execution_status = ExecutionStatus.TERMINATED
|
||||
cls.executor.terminate()
|
||||
log_metadata.info(f"Terminated graph execution {graph_exec.graph_exec_id}")
|
||||
cls._init_node_executor_pool()
|
||||
@@ -695,34 +779,18 @@ class Executor:
|
||||
if not isinstance(result, NodeExecutionStats):
|
||||
return
|
||||
|
||||
nonlocal execution_stats
|
||||
execution_stats.node_count += 1
|
||||
execution_stats.nodes_cputime += result.cputime
|
||||
execution_stats.nodes_walltime += result.walltime
|
||||
nonlocal exec_stats
|
||||
exec_stats.node_count += 1
|
||||
exec_stats.nodes_cputime += result.cputime
|
||||
exec_stats.nodes_walltime += result.walltime
|
||||
if (err := result.error) and isinstance(err, Exception):
|
||||
execution_stats.node_error_count += 1
|
||||
|
||||
if _graph_exec := cls.db_client.update_graph_execution_stats(
|
||||
graph_exec_id=exec_data.graph_exec_id,
|
||||
status=execution_status,
|
||||
stats=execution_stats,
|
||||
):
|
||||
send_execution_update(_graph_exec)
|
||||
else:
|
||||
logger.error(
|
||||
"Callback for "
|
||||
f"finished node execution #{exec_data.node_exec_id} "
|
||||
"could not update execution stats "
|
||||
f"for graph execution #{exec_data.graph_exec_id}; "
|
||||
f"triggered while graph exec status = {execution_status}"
|
||||
)
|
||||
exec_stats.node_error_count += 1
|
||||
|
||||
return callback
|
||||
|
||||
while not queue.empty():
|
||||
if cancel.is_set():
|
||||
execution_status = ExecutionStatus.TERMINATED
|
||||
return execution_stats, execution_status, error
|
||||
return exec_stats, ExecutionStatus.TERMINATED, error
|
||||
|
||||
exec_data = queue.get()
|
||||
|
||||
@@ -744,26 +812,29 @@ class Executor:
|
||||
exec_cost_counter = cls._charge_usage(
|
||||
node_exec=exec_data,
|
||||
execution_count=exec_cost_counter + 1,
|
||||
execution_stats=execution_stats,
|
||||
execution_stats=exec_stats,
|
||||
)
|
||||
except InsufficientBalanceError as error:
|
||||
node_exec_id = exec_data.node_exec_id
|
||||
cls.db_client.upsert_execution_output(
|
||||
_push_node_execution_output(
|
||||
db_client=cls.db_client,
|
||||
user_id=graph_exec.user_id,
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
node_exec_id=node_exec_id,
|
||||
block_id=exec_data.block_id,
|
||||
output_name="error",
|
||||
output_data=str(error),
|
||||
)
|
||||
|
||||
execution_status = ExecutionStatus.FAILED
|
||||
exec_update = cls.db_client.update_node_execution_status(
|
||||
node_exec_id, execution_status
|
||||
node_exec_id, ExecutionStatus.FAILED
|
||||
)
|
||||
send_execution_update(exec_update)
|
||||
cls.db_client.send_execution_update(exec_update)
|
||||
|
||||
cls._handle_low_balance_notif(
|
||||
graph_exec.user_id,
|
||||
graph_exec.graph_id,
|
||||
execution_stats,
|
||||
exec_stats,
|
||||
error,
|
||||
)
|
||||
raise
|
||||
@@ -781,8 +852,7 @@ class Executor:
|
||||
)
|
||||
for node_id, execution in list(running_executions.items()):
|
||||
if cancel.is_set():
|
||||
execution_status = ExecutionStatus.TERMINATED
|
||||
return execution_stats, execution_status, error
|
||||
return exec_stats, ExecutionStatus.TERMINATED, error
|
||||
|
||||
if not queue.empty():
|
||||
break # yield to parent loop to execute new queue items
|
||||
@@ -809,7 +879,7 @@ class Executor:
|
||||
cancel_thread.join()
|
||||
clean_exec_files(graph_exec.graph_exec_id)
|
||||
|
||||
return execution_stats, execution_status, error
|
||||
return exec_stats, execution_status, error
|
||||
|
||||
@classmethod
|
||||
def _handle_agent_run_notif(
|
||||
@@ -875,170 +945,227 @@ class Executor:
|
||||
)
|
||||
|
||||
|
||||
class ExecutionManager(AppProcess):
|
||||
class ExecutionManager(AppService):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.use_redis = True
|
||||
self.use_supabase = True
|
||||
self.pool_size = settings.config.num_graph_workers
|
||||
self.running = True
|
||||
self.queue = ExecutionQueue[GraphExecutionEntry]()
|
||||
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
|
||||
|
||||
@classmethod
|
||||
def get_port(cls) -> int:
|
||||
return settings.config.execution_manager_port
|
||||
|
||||
def run(self):
|
||||
retry_count_max = settings.config.execution_manager_loop_max_retry
|
||||
retry_count = 0
|
||||
def run_service(self):
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
for retry_count in range(retry_count_max):
|
||||
try:
|
||||
self._run()
|
||||
except Exception as e:
|
||||
if not self.running:
|
||||
break
|
||||
logger.exception(
|
||||
f"[{self.service_name}] Error in execution manager: {e}"
|
||||
)
|
||||
|
||||
if retry_count >= retry_count_max:
|
||||
logger.error(
|
||||
f"[{self.service_name}] Max retries reached ({retry_count_max}), exiting..."
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.info(
|
||||
f"[{self.service_name}] Retrying execution loop in {retry_count} seconds..."
|
||||
)
|
||||
time.sleep(retry_count)
|
||||
|
||||
def _run(self):
|
||||
logger.info(f"[{self.service_name}] ⏳ Spawn max-{self.pool_size} workers...")
|
||||
self.credentials_store = IntegrationCredentialsStore()
|
||||
self.executor = ProcessPoolExecutor(
|
||||
max_workers=self.pool_size,
|
||||
initializer=Executor.on_graph_executor_start,
|
||||
)
|
||||
|
||||
logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...")
|
||||
redis.connect()
|
||||
|
||||
# Consume Cancel & Run execution requests.
|
||||
channel = get_execution_queue().get_channel()
|
||||
channel.basic_qos(prefetch_count=self.pool_size)
|
||||
channel.basic_consume(
|
||||
queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
on_message_callback=self._handle_cancel_message,
|
||||
auto_ack=True,
|
||||
)
|
||||
channel.basic_consume(
|
||||
queue=GRAPH_EXECUTION_QUEUE_NAME,
|
||||
on_message_callback=self._handle_run_message,
|
||||
auto_ack=False,
|
||||
)
|
||||
|
||||
logger.info(f"[{self.service_name}] Ready to consume messages...")
|
||||
channel.start_consuming()
|
||||
|
||||
def _handle_cancel_message(
|
||||
self,
|
||||
channel: BlockingChannel,
|
||||
method: Basic.Deliver,
|
||||
properties: BasicProperties,
|
||||
body: bytes,
|
||||
):
|
||||
"""
|
||||
Called whenever we receive a CANCEL message from the queue.
|
||||
(With auto_ack=True, message is considered 'acked' automatically.)
|
||||
"""
|
||||
try:
|
||||
request = CancelExecutionEvent.model_validate_json(body)
|
||||
graph_exec_id = request.graph_exec_id
|
||||
if not graph_exec_id:
|
||||
logger.warning(
|
||||
f"[{self.service_name}] Cancel message missing 'graph_exec_id'"
|
||||
)
|
||||
return
|
||||
if graph_exec_id not in self.active_graph_runs:
|
||||
logger.debug(
|
||||
f"[{self.service_name}] Cancel received for {graph_exec_id} but not active."
|
||||
)
|
||||
return
|
||||
|
||||
_, cancel_event = self.active_graph_runs[graph_exec_id]
|
||||
logger.info(f"[{self.service_name}] Received cancel for {graph_exec_id}")
|
||||
if not cancel_event.is_set():
|
||||
cancel_event.set()
|
||||
else:
|
||||
logger.debug(
|
||||
f"[{self.service_name}] Cancel already set for {graph_exec_id}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error handling cancel message: {e}")
|
||||
|
||||
def _handle_run_message(
|
||||
self,
|
||||
channel: BlockingChannel,
|
||||
method: Basic.Deliver,
|
||||
properties: BasicProperties,
|
||||
body: bytes,
|
||||
):
|
||||
delivery_tag = method.delivery_tag
|
||||
try:
|
||||
graph_exec_entry = GraphExecutionEntry.model_validate_json(body)
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.service_name}] Could not parse run message: {e}")
|
||||
channel.basic_nack(delivery_tag, requeue=False)
|
||||
return
|
||||
|
||||
graph_exec_id = graph_exec_entry.graph_exec_id
|
||||
sync_manager = multiprocessing.Manager()
|
||||
logger.info(
|
||||
f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}"
|
||||
f"[{self.service_name}] Started with max-{self.pool_size} graph workers"
|
||||
)
|
||||
if graph_exec_id in self.active_graph_runs:
|
||||
logger.warning(
|
||||
f"[{self.service_name}] Graph {graph_exec_id} already running; rejecting duplicate run."
|
||||
while True:
|
||||
graph_exec_data = self.queue.get()
|
||||
graph_exec_id = graph_exec_data.graph_exec_id
|
||||
logger.debug(
|
||||
f"[ExecutionManager] Dispatching graph execution {graph_exec_id}"
|
||||
)
|
||||
cancel_event = sync_manager.Event()
|
||||
future = self.executor.submit(
|
||||
Executor.on_graph_execution, graph_exec_data, cancel_event
|
||||
)
|
||||
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
|
||||
future.add_done_callback(
|
||||
lambda _: self.active_graph_runs.pop(graph_exec_id, None)
|
||||
)
|
||||
channel.basic_nack(delivery_tag, requeue=False)
|
||||
return
|
||||
|
||||
cancel_event = multiprocessing.Manager().Event()
|
||||
future = self.executor.submit(
|
||||
Executor.on_graph_execution, graph_exec_entry, cancel_event
|
||||
)
|
||||
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
|
||||
|
||||
def _on_run_done(f: Future):
|
||||
logger.info(f"[{self.service_name}] Run completed for {graph_exec_id}")
|
||||
try:
|
||||
self.active_graph_runs.pop(graph_exec_id, None)
|
||||
if f.exception():
|
||||
logger.error(
|
||||
f"[{self.service_name}] Execution for {graph_exec_id} failed: {f.exception()}"
|
||||
)
|
||||
channel.basic_nack(delivery_tag, requeue=False)
|
||||
else:
|
||||
channel.basic_ack(delivery_tag)
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.service_name}] Error acknowledging message: {e}")
|
||||
|
||||
future.add_done_callback(_on_run_done)
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
|
||||
logger.info(f"[{self.service_name}] ⏳ Shutting down service loop...")
|
||||
self.running = False
|
||||
|
||||
logger.info(f"[{self.service_name}] ⏳ Shutting down graph executor pool...")
|
||||
logger.info(f"[{__class__.__name__}] ⏳ Shutting down graph executor pool...")
|
||||
self.executor.shutdown(cancel_futures=True)
|
||||
|
||||
logger.info(f"[{self.service_name}] ⏳ Disconnecting Redis...")
|
||||
redis.disconnect()
|
||||
super().cleanup()
|
||||
|
||||
@property
|
||||
def db_client(self) -> "DatabaseManager":
|
||||
return get_db_client()
|
||||
|
||||
@expose
|
||||
def add_execution(
|
||||
self,
|
||||
graph_id: str,
|
||||
data: BlockInput,
|
||||
user_id: str,
|
||||
graph_version: Optional[int] = None,
|
||||
preset_id: str | None = None,
|
||||
) -> GraphExecutionEntry:
|
||||
graph: GraphModel | None = self.db_client.get_graph(
|
||||
graph_id=graph_id, user_id=user_id, version=graph_version
|
||||
)
|
||||
if not graph:
|
||||
raise ValueError(f"Graph #{graph_id} not found.")
|
||||
|
||||
graph.validate_graph(for_run=True)
|
||||
self._validate_node_input_credentials(graph, user_id)
|
||||
|
||||
nodes_input = []
|
||||
for node in graph.starting_nodes:
|
||||
input_data = {}
|
||||
block = get_block(node.block_id)
|
||||
|
||||
# Invalid block & Note block should never be executed.
|
||||
if not block or block.block_type == BlockType.NOTE:
|
||||
continue
|
||||
|
||||
# Extract request input data, and assign it to the input pin.
|
||||
if block.block_type == BlockType.INPUT:
|
||||
input_name = node.input_default.get("name")
|
||||
if input_name and input_name in data:
|
||||
input_data = {"value": data[input_name]}
|
||||
|
||||
# Extract webhook payload, and assign it to the input pin
|
||||
webhook_payload_key = f"webhook_{node.webhook_id}_payload"
|
||||
if (
|
||||
block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
|
||||
and node.webhook_id
|
||||
):
|
||||
if webhook_payload_key not in data:
|
||||
raise ValueError(
|
||||
f"Node {block.name} #{node.id} webhook payload is missing"
|
||||
)
|
||||
input_data = {"payload": data[webhook_payload_key]}
|
||||
|
||||
input_data, error = validate_exec(node, input_data)
|
||||
if input_data is None:
|
||||
raise ValueError(error)
|
||||
else:
|
||||
nodes_input.append((node.id, input_data))
|
||||
|
||||
if not nodes_input:
|
||||
raise ValueError(
|
||||
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
|
||||
)
|
||||
|
||||
graph_exec = self.db_client.create_graph_execution(
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
nodes_input=nodes_input,
|
||||
user_id=user_id,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
self.db_client.send_execution_update(graph_exec)
|
||||
|
||||
graph_exec_entry = GraphExecutionEntry(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version or 0,
|
||||
graph_exec_id=graph_exec.id,
|
||||
start_node_execs=[
|
||||
NodeExecutionEntry(
|
||||
user_id=user_id,
|
||||
graph_exec_id=node_exec.graph_exec_id,
|
||||
graph_id=node_exec.graph_id,
|
||||
node_exec_id=node_exec.node_exec_id,
|
||||
node_id=node_exec.node_id,
|
||||
block_id=node_exec.block_id,
|
||||
data=node_exec.input_data,
|
||||
)
|
||||
for node_exec in graph_exec.node_executions
|
||||
],
|
||||
)
|
||||
self.queue.add(graph_exec_entry)
|
||||
|
||||
return graph_exec_entry
|
||||
|
||||
@expose
|
||||
def cancel_execution(self, graph_exec_id: str) -> None:
|
||||
"""
|
||||
Mechanism:
|
||||
1. Set the cancel event
|
||||
2. Graph executor's cancel handler thread detects the event, terminates workers,
|
||||
reinitializes worker pool, and returns.
|
||||
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
|
||||
"""
|
||||
if graph_exec_id not in self.active_graph_runs:
|
||||
logger.warning(
|
||||
f"Graph execution #{graph_exec_id} not active/running: "
|
||||
"possibly already completed/cancelled."
|
||||
)
|
||||
else:
|
||||
future, cancel_event = self.active_graph_runs[graph_exec_id]
|
||||
if not cancel_event.is_set():
|
||||
cancel_event.set()
|
||||
future.result()
|
||||
|
||||
# Update the status of the graph & node executions
|
||||
self.db_client.update_graph_execution_stats(
|
||||
graph_exec_id,
|
||||
ExecutionStatus.TERMINATED,
|
||||
)
|
||||
node_execs = self.db_client.get_node_execution_results(
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=[
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.RUNNING,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
],
|
||||
)
|
||||
self.db_client.update_node_execution_status_batch(
|
||||
[node_exec.node_exec_id for node_exec in node_execs],
|
||||
ExecutionStatus.TERMINATED,
|
||||
)
|
||||
for node_exec in node_execs:
|
||||
node_exec.status = ExecutionStatus.TERMINATED
|
||||
self.db_client.send_execution_update(node_exec)
|
||||
|
||||
def _validate_node_input_credentials(self, graph: GraphModel, user_id: str):
|
||||
"""Checks all credentials for all nodes of the graph"""
|
||||
|
||||
for node in graph.nodes:
|
||||
block = get_block(node.block_id)
|
||||
if not block:
|
||||
raise ValueError(f"Unknown block {node.block_id} for node #{node.id}")
|
||||
|
||||
# Find any fields of type CredentialsMetaInput
|
||||
credentials_fields = cast(
|
||||
type[BlockSchema], block.input_schema
|
||||
).get_credentials_fields()
|
||||
if not credentials_fields:
|
||||
continue
|
||||
|
||||
for field_name, credentials_meta_type in credentials_fields.items():
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node.input_default[field_name]
|
||||
)
|
||||
# Fetch the corresponding Credentials and perform sanity checks
|
||||
credentials = self.credentials_store.get_creds_by_id(
|
||||
user_id, credentials_meta.id
|
||||
)
|
||||
if not credentials:
|
||||
raise ValueError(
|
||||
f"Unknown credentials #{credentials_meta.id} "
|
||||
f"for node #{node.id} input '{field_name}'"
|
||||
)
|
||||
if (
|
||||
credentials.provider != credentials_meta.provider
|
||||
or credentials.type != credentials_meta.type
|
||||
):
|
||||
logger.warning(
|
||||
f"Invalid credentials #{credentials.id} for node #{node.id}: "
|
||||
"type/provider mismatch: "
|
||||
f"{credentials_meta.type}<>{credentials.type};"
|
||||
f"{credentials_meta.provider}<>{credentials.provider}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Invalid credentials #{credentials.id} for node #{node.id}: "
|
||||
"type/provider mismatch"
|
||||
)
|
||||
|
||||
|
||||
# ------- UTILITIES ------- #
|
||||
|
||||
@@ -1057,10 +1184,6 @@ def get_notification_service() -> "NotificationManager":
|
||||
return get_service_client(NotificationManager)
|
||||
|
||||
|
||||
def send_execution_update(entry: GraphExecution | NodeExecutionResult):
|
||||
return get_execution_event_bus().publish(entry)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def synchronized(key: str, timeout: int = 60):
|
||||
lock: RedisLock = redis.get_redis().lock(f"lock:{key}", timeout=timeout)
|
||||
|
||||
@@ -16,7 +16,7 @@ from pydantic import BaseModel
|
||||
from sqlalchemy import MetaData, create_engine
|
||||
|
||||
from backend.data.block import BlockInput
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.executor.manager import ExecutionManager
|
||||
from backend.notifications.notifications import NotificationManager
|
||||
from backend.util.service import AppService, expose, get_service_client
|
||||
from backend.util.settings import Config
|
||||
@@ -57,6 +57,11 @@ def job_listener(event):
|
||||
log(f"Job {event.job_id} completed successfully.")
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_execution_client() -> ExecutionManager:
|
||||
return get_service_client(ExecutionManager)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_notification_client():
|
||||
from backend.notifications import NotificationManager
|
||||
@@ -68,7 +73,7 @@ def execute_graph(**kwargs):
|
||||
args = ExecutionJobArgs(**kwargs)
|
||||
try:
|
||||
log(f"Executing recurring job for graph #{args.graph_id}")
|
||||
execution_utils.add_graph_execution(
|
||||
get_execution_client().add_execution(
|
||||
graph_id=args.graph_id,
|
||||
data=args.input_data,
|
||||
user_id=args.user_id,
|
||||
@@ -159,6 +164,11 @@ class Scheduler(AppService):
|
||||
def db_pool_size(cls) -> int:
|
||||
return config.scheduler_db_pool_size
|
||||
|
||||
@property
|
||||
@thread_cached
|
||||
def execution_client(self) -> ExecutionManager:
|
||||
return get_service_client(ExecutionManager)
|
||||
|
||||
@property
|
||||
@thread_cached
|
||||
def notification_client(self) -> NotificationManager:
|
||||
@@ -166,7 +176,7 @@ class Scheduler(AppService):
|
||||
|
||||
def run_service(self):
|
||||
load_dotenv()
|
||||
db_schema, db_url = _extract_schema_from_url(os.getenv("DIRECT_URL"))
|
||||
db_schema, db_url = _extract_schema_from_url(os.getenv("DATABASE_URL"))
|
||||
self.scheduler = BlockingScheduler(
|
||||
jobstores={
|
||||
Jobstores.EXECUTION.value: SQLAlchemyJobStore(
|
||||
@@ -196,12 +206,6 @@ class Scheduler(AppService):
|
||||
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
|
||||
self.scheduler.start()
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
logger.info(f"[{self.service_name}] ⏳ Shutting down scheduler...")
|
||||
if self.scheduler:
|
||||
self.scheduler.shutdown(wait=False)
|
||||
|
||||
@expose
|
||||
def add_execution_schedule(
|
||||
self,
|
||||
|
||||
@@ -1,70 +1,11 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockData,
|
||||
BlockInput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.block import Block, BlockInput
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.cost import BlockCostType
|
||||
from backend.data.execution import GraphExecutionEntry, RedisExecutionEventBus
|
||||
from backend.data.graph import GraphModel, Node
|
||||
from backend.data.rabbitmq import (
|
||||
Exchange,
|
||||
ExchangeType,
|
||||
Queue,
|
||||
RabbitMQConfig,
|
||||
SyncRabbitMQ,
|
||||
)
|
||||
from backend.util.mock import MockObject
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Config
|
||||
from backend.util.type import convert
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManager
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
config = Config()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ============ Resource Helpers ============ #
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_execution_event_bus() -> RedisExecutionEventBus:
|
||||
return RedisExecutionEventBus()
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_execution_queue() -> SyncRabbitMQ:
|
||||
client = SyncRabbitMQ(create_execution_queue_config())
|
||||
client.connect()
|
||||
return client
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_integration_credentials_store() -> "IntegrationCredentialsStore":
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
return IntegrationCredentialsStore()
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_db_client() -> "DatabaseManager":
|
||||
from backend.executor import DatabaseManager
|
||||
|
||||
return get_service_client(DatabaseManager)
|
||||
|
||||
|
||||
# ============ Execution Cost Helpers ============ #
|
||||
|
||||
|
||||
class UsageTransactionMetadata(BaseModel):
|
||||
@@ -154,398 +95,3 @@ def _is_cost_filter_match(cost_filter: BlockInput, input_data: BlockInput) -> bo
|
||||
or (input_data.get(k) and _is_cost_filter_match(v, input_data[k]))
|
||||
for k, v in cost_filter.items()
|
||||
)
|
||||
|
||||
|
||||
# ============ Execution Input Helpers ============ #
|
||||
|
||||
LIST_SPLIT = "_$_"
|
||||
DICT_SPLIT = "_#_"
|
||||
OBJC_SPLIT = "_@_"
|
||||
|
||||
|
||||
def parse_execution_output(output: BlockData, name: str) -> Any | None:
|
||||
"""
|
||||
Extracts partial output data by name from a given BlockData.
|
||||
|
||||
The function supports extracting data from lists, dictionaries, and objects
|
||||
using specific naming conventions:
|
||||
- For lists: <output_name>_$_<index>
|
||||
- For dictionaries: <output_name>_#_<key>
|
||||
- For objects: <output_name>_@_<attribute>
|
||||
|
||||
Args:
|
||||
output (BlockData): A tuple containing the output name and data.
|
||||
name (str): The name used to extract specific data from the output.
|
||||
|
||||
Returns:
|
||||
Any | None: The extracted data if found, otherwise None.
|
||||
|
||||
Examples:
|
||||
>>> output = ("result", [10, 20, 30])
|
||||
>>> parse_execution_output(output, "result_$_1")
|
||||
20
|
||||
|
||||
>>> output = ("config", {"key1": "value1", "key2": "value2"})
|
||||
>>> parse_execution_output(output, "config_#_key1")
|
||||
'value1'
|
||||
|
||||
>>> class Sample:
|
||||
... attr1 = "value1"
|
||||
... attr2 = "value2"
|
||||
>>> output = ("object", Sample())
|
||||
>>> parse_execution_output(output, "object_@_attr1")
|
||||
'value1'
|
||||
"""
|
||||
output_name, output_data = output
|
||||
|
||||
if name == output_name:
|
||||
return output_data
|
||||
|
||||
if name.startswith(f"{output_name}{LIST_SPLIT}"):
|
||||
index = int(name.split(LIST_SPLIT)[1])
|
||||
if not isinstance(output_data, list) or len(output_data) <= index:
|
||||
return None
|
||||
return output_data[int(name.split(LIST_SPLIT)[1])]
|
||||
|
||||
if name.startswith(f"{output_name}{DICT_SPLIT}"):
|
||||
index = name.split(DICT_SPLIT)[1]
|
||||
if not isinstance(output_data, dict) or index not in output_data:
|
||||
return None
|
||||
return output_data[index]
|
||||
|
||||
if name.startswith(f"{output_name}{OBJC_SPLIT}"):
|
||||
index = name.split(OBJC_SPLIT)[1]
|
||||
if isinstance(output_data, object) and hasattr(output_data, index):
|
||||
return getattr(output_data, index)
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def validate_exec(
|
||||
node: Node,
|
||||
data: BlockInput,
|
||||
resolve_input: bool = True,
|
||||
) -> tuple[BlockInput | None, str]:
|
||||
"""
|
||||
Validate the input data for a node execution.
|
||||
|
||||
Args:
|
||||
node: The node to execute.
|
||||
data: The input data for the node execution.
|
||||
resolve_input: Whether to resolve dynamic pins into dict/list/object.
|
||||
|
||||
Returns:
|
||||
A tuple of the validated data and the block name.
|
||||
If the data is invalid, the first element will be None, and the second element
|
||||
will be an error message.
|
||||
If the data is valid, the first element will be the resolved input data, and
|
||||
the second element will be the block name.
|
||||
"""
|
||||
node_block: Block | None = get_block(node.block_id)
|
||||
if not node_block:
|
||||
return None, f"Block for {node.block_id} not found."
|
||||
schema = node_block.input_schema
|
||||
|
||||
# Convert non-matching data types to the expected input schema.
|
||||
for name, data_type in schema.__annotations__.items():
|
||||
if (value := data.get(name)) and (type(value) is not data_type):
|
||||
data[name] = convert(value, data_type)
|
||||
|
||||
# Input data (without default values) should contain all required fields.
|
||||
error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
|
||||
if missing_links := schema.get_missing_links(data, node.input_links):
|
||||
return None, f"{error_prefix} unpopulated links {missing_links}"
|
||||
|
||||
# Merge input data with default values and resolve dynamic dict/list/object pins.
|
||||
input_default = schema.get_input_defaults(node.input_default)
|
||||
data = {**input_default, **data}
|
||||
if resolve_input:
|
||||
data = merge_execution_input(data)
|
||||
|
||||
# Input data post-merge should contain all required fields from the schema.
|
||||
if missing_input := schema.get_missing_input(data):
|
||||
return None, f"{error_prefix} missing input {missing_input}"
|
||||
|
||||
# Last validation: Validate the input values against the schema.
|
||||
if error := schema.get_mismatch_error(data):
|
||||
error_message = f"{error_prefix} {error}"
|
||||
logger.error(error_message)
|
||||
return None, error_message
|
||||
|
||||
return data, node_block.name
|
||||
|
||||
|
||||
def merge_execution_input(data: BlockInput) -> BlockInput:
|
||||
"""
|
||||
Merges dynamic input pins into a single list, dictionary, or object based on naming patterns.
|
||||
|
||||
This function processes input keys that follow specific patterns to merge them into a unified structure:
|
||||
- `<input_name>_$_<index>` for list inputs.
|
||||
- `<input_name>_#_<index>` for dictionary inputs.
|
||||
- `<input_name>_@_<index>` for object inputs.
|
||||
|
||||
Args:
|
||||
data (BlockInput): A dictionary containing input keys and their corresponding values.
|
||||
|
||||
Returns:
|
||||
BlockInput: A dictionary with merged inputs.
|
||||
|
||||
Raises:
|
||||
ValueError: If a list index is not an integer.
|
||||
|
||||
Examples:
|
||||
>>> data = {
|
||||
... "list_$_0": "a",
|
||||
... "list_$_1": "b",
|
||||
... "dict_#_key1": "value1",
|
||||
... "dict_#_key2": "value2",
|
||||
... "object_@_attr1": "value1",
|
||||
... "object_@_attr2": "value2"
|
||||
... }
|
||||
>>> merge_execution_input(data)
|
||||
{
|
||||
"list": ["a", "b"],
|
||||
"dict": {"key1": "value1", "key2": "value2"},
|
||||
"object": <MockObject attr1="value1" attr2="value2">
|
||||
}
|
||||
"""
|
||||
|
||||
# Merge all input with <input_name>_$_<index> into a single list.
|
||||
items = list(data.items())
|
||||
|
||||
for key, value in items:
|
||||
if LIST_SPLIT not in key:
|
||||
continue
|
||||
name, index = key.split(LIST_SPLIT)
|
||||
if not index.isdigit():
|
||||
raise ValueError(f"Invalid key: {key}, #{index} index must be an integer.")
|
||||
|
||||
data[name] = data.get(name, [])
|
||||
if int(index) >= len(data[name]):
|
||||
# Pad list with empty string on missing indices.
|
||||
data[name].extend([""] * (int(index) - len(data[name]) + 1))
|
||||
data[name][int(index)] = value
|
||||
|
||||
# Merge all input with <input_name>_#_<index> into a single dict.
|
||||
for key, value in items:
|
||||
if DICT_SPLIT not in key:
|
||||
continue
|
||||
name, index = key.split(DICT_SPLIT)
|
||||
data[name] = data.get(name, {})
|
||||
data[name][index] = value
|
||||
|
||||
# Merge all input with <input_name>_@_<index> into a single object.
|
||||
for key, value in items:
|
||||
if OBJC_SPLIT not in key:
|
||||
continue
|
||||
name, index = key.split(OBJC_SPLIT)
|
||||
if name not in data or not isinstance(data[name], object):
|
||||
data[name] = MockObject()
|
||||
setattr(data[name], index, value)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _validate_node_input_credentials(graph: GraphModel, user_id: str):
|
||||
"""Checks all credentials for all nodes of the graph"""
|
||||
|
||||
for node in graph.nodes:
|
||||
block = node.block
|
||||
|
||||
# Find any fields of type CredentialsMetaInput
|
||||
credentials_fields = cast(
|
||||
type[BlockSchema], block.input_schema
|
||||
).get_credentials_fields()
|
||||
if not credentials_fields:
|
||||
continue
|
||||
|
||||
for field_name, credentials_meta_type in credentials_fields.items():
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node.input_default[field_name]
|
||||
)
|
||||
# Fetch the corresponding Credentials and perform sanity checks
|
||||
credentials = get_integration_credentials_store().get_creds_by_id(
|
||||
user_id, credentials_meta.id
|
||||
)
|
||||
if not credentials:
|
||||
raise ValueError(
|
||||
f"Unknown credentials #{credentials_meta.id} "
|
||||
f"for node #{node.id} input '{field_name}'"
|
||||
)
|
||||
if (
|
||||
credentials.provider != credentials_meta.provider
|
||||
or credentials.type != credentials_meta.type
|
||||
):
|
||||
logger.warning(
|
||||
f"Invalid credentials #{credentials.id} for node #{node.id}: "
|
||||
"type/provider mismatch: "
|
||||
f"{credentials_meta.type}<>{credentials.type};"
|
||||
f"{credentials_meta.provider}<>{credentials.provider}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Invalid credentials #{credentials.id} for node #{node.id}: "
|
||||
"type/provider mismatch"
|
||||
)
|
||||
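
# --- Illustrative sketch (added for clarity; not part of the original module).
# Shape of the credentials meta stored in a node's input_default, as consumed
# by _validate_node_input_credentials above: at minimum an id, provider, and
# type, which are cross-checked against the stored Credentials record. The
# concrete values below are hypothetical.
_EXAMPLE_CREDENTIALS_INPUT_DEFAULT = {
    "credentials": {
        "id": "01234567-89ab-cdef-0123-456789abcdef",  # hypothetical credentials id
        "provider": "github",
        "type": "oauth2",
    }
}
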
|
||||
|
||||
def construct_node_execution_input(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
data: BlockInput,
|
||||
) -> list[tuple[str, BlockInput]]:
|
||||
"""
|
||||
Validates and prepares the input data for executing a graph.
|
||||
This function checks the graph for starting nodes, validates the input data
|
||||
against the schema, and resolves dynamic input pins into a single list,
|
||||
dictionary, or object.
|
||||
|
||||
Args:
|
||||
graph (GraphModel): The graph model to execute.
|
||||
user_id (str): The ID of the user executing the graph.
|
||||
data (BlockInput): The input data for the graph execution.
|
||||
|
||||
Returns:
|
||||
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
|
||||
the corresponding input data for that node.
|
||||
"""
|
||||
graph.validate_graph(for_run=True)
|
||||
_validate_node_input_credentials(graph, user_id)
|
||||
|
||||
nodes_input = []
|
||||
for node in graph.starting_nodes:
|
||||
input_data = {}
|
||||
block = node.block
|
||||
|
||||
# Note blocks should never be executed.
|
||||
if block.block_type == BlockType.NOTE:
|
||||
continue
|
||||
|
||||
# Extract request input data, and assign it to the input pin.
|
||||
if block.block_type == BlockType.INPUT:
|
||||
input_name = node.input_default.get("name")
|
||||
if input_name and input_name in data:
|
||||
input_data = {"value": data[input_name]}
|
||||
|
||||
# Extract webhook payload, and assign it to the input pin
|
||||
webhook_payload_key = f"webhook_{node.webhook_id}_payload"
|
||||
if (
|
||||
block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
|
||||
and node.webhook_id
|
||||
):
|
||||
if webhook_payload_key not in data:
|
||||
raise ValueError(
|
||||
f"Node {block.name} #{node.id} webhook payload is missing"
|
||||
)
|
||||
input_data = {"payload": data[webhook_payload_key]}
|
||||
|
||||
input_data, error = validate_exec(node, input_data)
|
||||
if input_data is None:
|
||||
raise ValueError(error)
|
||||
else:
|
||||
nodes_input.append((node.id, input_data))
|
||||
|
||||
if not nodes_input:
|
||||
raise ValueError(
|
||||
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
|
||||
)
|
||||
|
||||
return nodes_input
|
||||
|
||||
|
||||
# ============ Execution Queue Helpers ============ #
|
||||
|
||||
|
||||
class CancelExecutionEvent(BaseModel):
|
||||
graph_exec_id: str
|
||||
|
||||
|
||||
GRAPH_EXECUTION_EXCHANGE = Exchange(
|
||||
name="graph_execution",
|
||||
type=ExchangeType.DIRECT,
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
)
|
||||
GRAPH_EXECUTION_QUEUE_NAME = "graph_execution_queue"
|
||||
GRAPH_EXECUTION_ROUTING_KEY = "graph_execution.run"
|
||||
|
||||
GRAPH_EXECUTION_CANCEL_EXCHANGE = Exchange(
|
||||
name="graph_execution_cancel",
|
||||
type=ExchangeType.FANOUT,
|
||||
durable=True,
|
||||
auto_delete=True,
|
||||
)
|
||||
GRAPH_EXECUTION_CANCEL_QUEUE_NAME = "graph_execution_cancel_queue"
|
||||
|
||||
|
||||
def create_execution_queue_config() -> RabbitMQConfig:
|
||||
"""
|
||||
Define two exchanges and queues:
|
||||
- 'graph_execution' (DIRECT) for run tasks.
|
||||
- 'graph_execution_cancel' (FANOUT) for cancel requests.
|
||||
"""
|
||||
run_queue = Queue(
|
||||
name=GRAPH_EXECUTION_QUEUE_NAME,
|
||||
exchange=GRAPH_EXECUTION_EXCHANGE,
|
||||
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
)
|
||||
cancel_queue = Queue(
|
||||
name=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
exchange=GRAPH_EXECUTION_CANCEL_EXCHANGE,
|
||||
routing_key="", # not used for FANOUT
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
)
|
||||
return RabbitMQConfig(
|
||||
vhost="/",
|
||||
exchanges=[GRAPH_EXECUTION_EXCHANGE, GRAPH_EXECUTION_CANCEL_EXCHANGE],
|
||||
queues=[run_queue, cancel_queue],
|
||||
)
|
||||
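
# --- Illustrative usage sketch (added for clarity; not part of the original
# module). Building a queue client from this config mirrors the thread-cached
# execution_queue_client() helper in the v1 routes further below, which is
# where AsyncRabbitMQ (backend.data.rabbitmq) and its connect() call come from.
async def _make_execution_queue_client():
    from backend.data.rabbitmq import AsyncRabbitMQ

    client = AsyncRabbitMQ(create_execution_queue_config())
    await client.connect()
    return client
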
|
||||
|
||||
def add_graph_execution(
|
||||
graph_id: str,
|
||||
data: BlockInput,
|
||||
user_id: str,
|
||||
graph_version: int | None = None,
|
||||
preset_id: str | None = None,
|
||||
) -> GraphExecutionEntry:
|
||||
"""
|
||||
Adds a graph execution to the queue and returns the execution entry.
|
||||
|
||||
Args:
|
||||
graph_id (str): The ID of the graph to execute.
|
||||
data (BlockInput): The input data for the graph execution.
|
||||
user_id (str): The ID of the user executing the graph.
|
||||
graph_version (int | None): The version of the graph to execute. Defaults to None.
|
||||
preset_id (str | None): The ID of the preset to use. Defaults to None.
|
||||
Returns:
|
||||
GraphExecutionEntry: The entry for the graph execution.
|
||||
Raises:
|
||||
ValueError: If the graph is not found or if there are validation errors.
|
||||
"""
|
||||
graph: GraphModel | None = get_db_client().get_graph(
|
||||
graph_id=graph_id, user_id=user_id, version=graph_version
|
||||
)
|
||||
if not graph:
|
||||
raise ValueError(f"Graph #{graph_id} not found.")
|
||||
|
||||
graph_exec = get_db_client().create_graph_execution(
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
nodes_input=construct_node_execution_input(graph, user_id, data),
|
||||
user_id=user_id,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
get_execution_event_bus().publish(graph_exec)
|
||||
|
||||
graph_exec_entry = graph_exec.to_graph_execution_entry()
|
||||
get_execution_queue().publish_message(
|
||||
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
||||
message=graph_exec_entry.model_dump_json(),
|
||||
exchange=GRAPH_EXECUTION_EXCHANGE,
|
||||
)
|
||||
|
||||
return graph_exec_entry
|
||||
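
# --- Illustrative sketch (added for clarity; not part of the original module).
# Cancelling a run is the mirror image of add_graph_execution() above: a
# CancelExecutionEvent is broadcast on the FANOUT cancel exchange, exactly as
# the API layer does further below; get_execution_queue() is the same helper
# used by add_graph_execution().
def _broadcast_cancel(graph_exec_id: str) -> None:
    get_execution_queue().publish_message(
        routing_key="",  # routing keys are ignored by FANOUT exchanges
        message=CancelExecutionEvent(graph_exec_id=graph_exec_id).model_dump_json(),
        exchange=GRAPH_EXECUTION_CANCEL_EXCHANGE,
    )
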
|
||||
@@ -10,7 +10,6 @@ from backend.data import redis
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
from backend.integrations.oauth import HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.exceptions import MissingConfigError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
@@ -154,13 +153,12 @@ class IntegrationCredentialsManager:
|
||||
self.store.locks.release_all_locks()
|
||||
|
||||
|
||||
def _get_provider_oauth_handler(provider_name_str: str) -> "BaseOAuthHandler":
|
||||
provider_name = ProviderName(provider_name_str)
|
||||
def _get_provider_oauth_handler(provider_name: str) -> "BaseOAuthHandler":
|
||||
if provider_name not in HANDLERS_BY_NAME:
|
||||
raise KeyError(f"Unknown provider '{provider_name}'")
|
||||
|
||||
client_id = getattr(settings.secrets, f"{provider_name.value}_client_id")
|
||||
client_secret = getattr(settings.secrets, f"{provider_name.value}_client_secret")
|
||||
client_id = getattr(settings.secrets, f"{provider_name}_client_id")
|
||||
client_secret = getattr(settings.secrets, f"{provider_name}_client_secret")
|
||||
if not (client_id and client_secret):
|
||||
raise MissingConfigError(
|
||||
f"Integration with provider '{provider_name}' is not configured",
|
||||
|
||||
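
# --- Illustrative sketch (added for clarity; not part of the diff above).
# The hunk relies on secrets being named "<provider>_client_id" /
# "<provider>_client_secret" on settings.secrets (backend.util.settings).
# The stand-in namespace below only illustrates that naming convention.
from types import SimpleNamespace

_example_secrets = SimpleNamespace(github_client_id="Iv1.example", github_client_secret="***")

def _lookup_oauth_client(provider_name: str) -> tuple[str, str]:
    client_id = getattr(_example_secrets, f"{provider_name}_client_id", "")
    client_secret = getattr(_example_secrets, f"{provider_name}_client_secret", "")
    if not (client_id and client_secret):
        raise RuntimeError(f"Integration with provider '{provider_name}' is not configured")
    return client_id, client_secret
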
@@ -11,7 +11,6 @@ class ProviderName(str, Enum):
|
||||
E2B = "e2b"
|
||||
EXA = "exa"
|
||||
FAL = "fal"
|
||||
GENERIC_WEBHOOK = "generic_webhook"
|
||||
GITHUB = "github"
|
||||
GOOGLE = "google"
|
||||
GOOGLE_MAPS = "google_maps"
|
||||
|
||||
@@ -13,7 +13,6 @@ def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]
|
||||
return _WEBHOOK_MANAGERS
|
||||
|
||||
from .compass import CompassWebhookManager
|
||||
from .generic import GenericWebhooksManager
|
||||
from .github import GithubWebhooksManager
|
||||
from .slant3d import Slant3DWebhooksManager
|
||||
|
||||
@@ -24,7 +23,6 @@ def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]
|
||||
CompassWebhookManager,
|
||||
GithubWebhooksManager,
|
||||
Slant3DWebhooksManager,
|
||||
GenericWebhooksManager,
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
import logging
|
||||
|
||||
from fastapi import Request
|
||||
from strenum import StrEnum
|
||||
|
||||
from backend.data import integrations
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
from ._manual_base import ManualWebhookManagerBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GenericWebhookType(StrEnum):
|
||||
PLAIN = "plain"
|
||||
|
||||
|
||||
class GenericWebhooksManager(ManualWebhookManagerBase):
|
||||
PROVIDER_NAME = ProviderName.GENERIC_WEBHOOK
|
||||
WebhookType = GenericWebhookType
|
||||
|
||||
@classmethod
|
||||
async def validate_payload(
|
||||
cls, webhook: integrations.Webhook, request: Request
|
||||
) -> tuple[dict, str]:
|
||||
payload = await request.json()
|
||||
event_type = GenericWebhookType.PLAIN
|
||||
|
||||
return payload, event_type
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Callable, Optional, cast
|
||||
|
||||
from backend.data.block import BlockSchema, BlockWebhookConfig
|
||||
from backend.data.block import BlockSchema, BlockWebhookConfig, get_block
|
||||
from backend.data.graph import set_node_webhook
|
||||
from backend.integrations.webhooks import get_webhook_manager, supports_webhooks
|
||||
|
||||
@@ -29,7 +29,12 @@ async def on_graph_activate(
|
||||
# Compare nodes in new_graph_version with previous_graph_version
|
||||
updated_nodes = []
|
||||
for new_node in graph.nodes:
|
||||
block_input_schema = cast(BlockSchema, new_node.block.input_schema)
|
||||
block = get_block(new_node.block_id)
|
||||
if not block:
|
||||
raise ValueError(
|
||||
f"Node #{new_node.id} is instance of unknown block #{new_node.block_id}"
|
||||
)
|
||||
block_input_schema = cast(BlockSchema, block.input_schema)
|
||||
|
||||
node_credentials = None
|
||||
if (
|
||||
@@ -70,7 +75,12 @@ async def on_graph_deactivate(
|
||||
"""
|
||||
updated_nodes = []
|
||||
for node in graph.nodes:
|
||||
block_input_schema = cast(BlockSchema, node.block.input_schema)
|
||||
block = get_block(node.block_id)
|
||||
if not block:
|
||||
raise ValueError(
|
||||
f"Node #{node.id} is instance of unknown block #{node.block_id}"
|
||||
)
|
||||
block_input_schema = cast(BlockSchema, block.input_schema)
|
||||
|
||||
node_credentials = None
|
||||
if (
|
||||
@@ -103,7 +113,11 @@ async def on_node_activate(
|
||||
) -> "NodeModel":
|
||||
"""Hook to be called when the node is activated/created"""
|
||||
|
||||
block = node.block
|
||||
block = get_block(node.block_id)
|
||||
if not block:
|
||||
raise ValueError(
|
||||
f"Node #{node.id} is instance of unknown block #{node.block_id}"
|
||||
)
|
||||
|
||||
if not block.webhook_config:
|
||||
return node
|
||||
@@ -210,7 +224,11 @@ async def on_node_deactivate(
|
||||
"""Hook to be called when node is deactivated/deleted"""
|
||||
|
||||
logger.debug(f"Deactivating node #{node.id}")
|
||||
block = node.block
|
||||
block = get_block(node.block_id)
|
||||
if not block:
|
||||
raise ValueError(
|
||||
f"Node #{node.id} is instance of unknown block #{node.block_id}"
|
||||
)
|
||||
|
||||
if not block.webhook_config:
|
||||
return node
|
||||
|
||||
@@ -9,7 +9,6 @@ from autogpt_libs.utils.cache import thread_cached
|
||||
from prisma.enums import NotificationType
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data import rabbitmq
|
||||
from backend.data.notifications import (
|
||||
BaseSummaryData,
|
||||
BaseSummaryParams,
|
||||
@@ -129,20 +128,6 @@ class NotificationManager(AppService):
|
||||
self.running = True
|
||||
self.email_sender = EmailSender()
|
||||
|
||||
@property
|
||||
def rabbit(self) -> rabbitmq.AsyncRabbitMQ:
|
||||
"""Access the RabbitMQ service. Will raise if not configured."""
|
||||
if not self.rabbitmq_service:
|
||||
raise RuntimeError("RabbitMQ not configured for this service")
|
||||
return self.rabbitmq_service
|
||||
|
||||
@property
|
||||
def rabbit_config(self) -> rabbitmq.RabbitMQConfig:
|
||||
"""Access the RabbitMQ config. Will raise if not configured."""
|
||||
if not self.rabbitmq_config:
|
||||
raise RuntimeError("RabbitMQ not configured for this service")
|
||||
return self.rabbitmq_config
|
||||
|
||||
@classmethod
|
||||
def get_port(cls) -> int:
|
||||
return settings.config.notification_service_port
|
||||
@@ -260,26 +245,20 @@ class NotificationManager(AppService):
|
||||
continue
|
||||
|
||||
unsub_link = generate_unsubscribe_link(batch.user_id)
|
||||
events = []
|
||||
for db_event in batch_data.notifications:
|
||||
try:
|
||||
events.append(
|
||||
NotificationEventModel[
|
||||
get_notif_data_type(db_event.type)
|
||||
].model_validate(
|
||||
{
|
||||
"user_id": batch.user_id,
|
||||
"type": db_event.type,
|
||||
"data": db_event.data,
|
||||
"created_at": db_event.created_at,
|
||||
}
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error parsing notification event: {e=}, {db_event=}"
|
||||
)
|
||||
continue
|
||||
|
||||
events = [
|
||||
NotificationEventModel[
|
||||
get_notif_data_type(db_event.type)
|
||||
].model_validate(
|
||||
{
|
||||
"user_id": batch.user_id,
|
||||
"type": db_event.type,
|
||||
"data": db_event.data,
|
||||
"created_at": db_event.created_at,
|
||||
}
|
||||
)
|
||||
for db_event in batch_data.notifications
|
||||
]
|
||||
logger.info(f"{events=}")
|
||||
|
||||
self.email_sender.send_templated(
|
||||
@@ -689,8 +668,6 @@ class NotificationManager(AppService):
|
||||
|
||||
except QueueEmpty:
|
||||
logger.debug(f"Queue {error_queue_name} empty")
|
||||
except TimeoutError:
|
||||
logger.debug(f"Queue {error_queue_name} timed out")
|
||||
except Exception as e:
|
||||
if message:
|
||||
logger.error(
|
||||
@@ -698,19 +675,15 @@ class NotificationManager(AppService):
|
||||
)
|
||||
self.run_and_wait(message.reject(requeue=False))
|
||||
else:
|
||||
logger.exception(
|
||||
f"Error in notification service loop, message unable to be rejected, and will have to be manually removed to free space in the queue: {e=}"
|
||||
logger.error(
|
||||
f"Error in notification service loop, message unable to be rejected, and will have to be manually removed to free space in the queue: {e}"
|
||||
)
|
||||
|
||||
def run_service(self):
|
||||
logger.info(f"[{self.service_name}] ⏳ Configuring RabbitMQ...")
|
||||
self.rabbitmq_service = rabbitmq.AsyncRabbitMQ(self.rabbitmq_config)
|
||||
self.run_and_wait(self.rabbitmq_service.connect())
|
||||
|
||||
logger.info(f"[{self.service_name}] Started notification service")
|
||||
|
||||
# Set up scheduler for batch processing of all notification types
|
||||
# this can be changed later to spawn different cleanups on different schedules
|
||||
# this can be changed later to spawn differnt cleanups on different schedules
|
||||
try:
|
||||
get_scheduler().add_batched_notification_schedule(
|
||||
notification_types=list(NotificationType),
|
||||
@@ -772,5 +745,3 @@ class NotificationManager(AppService):
|
||||
"""Cleanup service resources"""
|
||||
self.running = False
|
||||
super().cleanup()
|
||||
logger.info(f"[{self.service_name}] ⏳ Disconnecting RabbitMQ...")
|
||||
self.run_and_wait(self.rabbitmq_service.disconnect())
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
from collections import defaultdict
|
||||
from typing import Annotated, Any, Dict, List, Optional, Sequence
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
from prisma.enums import AgentExecutionStatus, APIKeyPermission
|
||||
from typing_extensions import TypedDict
|
||||
@@ -12,10 +13,17 @@ from backend.data import graph as graph_db
|
||||
from backend.data.api_key import APIKey
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.data.execution import NodeExecutionResult
|
||||
from backend.executor import ExecutionManager
|
||||
from backend.server.external.middleware import require_permission
|
||||
from backend.server.routers import v1 as internal_api_routes
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Settings
|
||||
|
||||
|
||||
@thread_cached
|
||||
def execution_manager_client() -> ExecutionManager:
|
||||
return get_service_client(ExecutionManager)
|
||||
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -90,18 +98,18 @@ def execute_graph_block(
|
||||
path="/graphs/{graph_id}/execute/{graph_version}",
|
||||
tags=["graphs"],
|
||||
)
|
||||
async def execute_graph(
|
||||
def execute_graph(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
|
||||
api_key: APIKey = Depends(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
graph_exec = await internal_api_routes.execute_graph(
|
||||
graph_id=graph_id,
|
||||
node_input=node_input,
|
||||
user_id=api_key.user_id,
|
||||
graph_exec = execution_manager_client().add_execution(
|
||||
graph_id,
|
||||
graph_version=graph_version,
|
||||
data=node_input,
|
||||
user_id=api_key.user_id,
|
||||
)
|
||||
return {"id": graph_exec.graph_exec_id}
|
||||
except Exception as e:
|
||||
|
||||
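
# --- Illustrative client-side sketch (added for clarity; not part of the diff
# above). It calls the external execute endpoint defined in this hunk; the
# path and the embedded "node_input" body follow the route definition, while
# the base URL and API-key header name are assumptions that vary per deployment.
import httpx

def run_graph_via_external_api(graph_id: str, graph_version: int, node_input: dict) -> str:
    resp = httpx.post(
        f"http://localhost:8006/api/graphs/{graph_id}/execute/{graph_version}",  # assumed base URL
        headers={"X-API-Key": "agpt_..."},  # assumed auth header / key format
        json={"node_input": node_input},  # Body(..., embed=True) expects this wrapper
        timeout=30,
    )
    resp.raise_for_status()
    return resp.json()["id"]  # graph execution id, per the handler above
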
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Annotated, Literal
|
||||
|
||||
@@ -15,12 +14,13 @@ from backend.data.integrations import (
|
||||
wait_for_webhook_event,
|
||||
)
|
||||
from backend.data.model import Credentials, CredentialsType, OAuth2Credentials
|
||||
from backend.executor.manager import ExecutionManager
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
from backend.server.routers import v1 as internal_api_routes
|
||||
from backend.util.exceptions import NeedConfirmation, NotFoundError
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -309,22 +309,19 @@ async def webhook_ingress_generic(
|
||||
if not webhook.attached_nodes:
|
||||
return
|
||||
|
||||
executions = []
|
||||
executor = get_service_client(ExecutionManager)
|
||||
for node in webhook.attached_nodes:
|
||||
logger.debug(f"Webhook-attached node: {node}")
|
||||
if not node.is_triggered_by_event_type(event_type):
|
||||
logger.debug(f"Node #{node.id} doesn't trigger on event {event_type}")
|
||||
continue
|
||||
logger.debug(f"Executing graph #{node.graph_id} node #{node.id}")
|
||||
executions.append(
|
||||
internal_api_routes.execute_graph(
|
||||
graph_id=node.graph_id,
|
||||
graph_version=node.graph_version,
|
||||
node_input={f"webhook_{webhook_id}_payload": payload},
|
||||
user_id=webhook.user_id,
|
||||
)
|
||||
executor.add_execution(
|
||||
graph_id=node.graph_id,
|
||||
graph_version=node.graph_version,
|
||||
data={f"webhook_{webhook_id}_payload": payload},
|
||||
user_id=webhook.user_id,
|
||||
)
|
||||
asyncio.gather(*executions)
|
||||
|
||||
|
||||
@router.post("/webhooks/{webhook_id}/ping")
|
||||
|
||||
@@ -11,19 +11,19 @@ from autogpt_libs.feature_flag.client import (
|
||||
initialize_launchdarkly,
|
||||
shutdown_launchdarkly,
|
||||
)
|
||||
from autogpt_libs.logging.utils import generate_uvicorn_config
|
||||
|
||||
import backend.data.block
|
||||
import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.user
|
||||
import backend.server.routers.postmark.postmark
|
||||
import backend.server.integrations.router
|
||||
import backend.server.routers.v1
|
||||
import backend.server.v2.admin.store_admin_routes
|
||||
import backend.server.v2.library.db
|
||||
import backend.server.v2.library.model
|
||||
import backend.server.v2.library.routes
|
||||
import backend.server.v2.otto.routes
|
||||
import backend.server.v2.postmark.postmark
|
||||
import backend.server.v2.store.model
|
||||
import backend.server.v2.store.routes
|
||||
import backend.util.service
|
||||
@@ -115,8 +115,8 @@ app.include_router(
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
backend.server.routers.postmark.postmark.router,
|
||||
tags=["v1", "email"],
|
||||
backend.server.v2.postmark.postmark.router,
|
||||
tags=["v2", "email"],
|
||||
prefix="/api/email",
|
||||
)
|
||||
|
||||
@@ -141,13 +141,8 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
server_app,
|
||||
host=backend.util.settings.Config().agent_api_host,
|
||||
port=backend.util.settings.Config().agent_api_port,
|
||||
log_config=generate_uvicorn_config(),
|
||||
)
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
logger.info(f"[{self.service_name}] ⏳ Shutting down Agent Server...")
|
||||
|
||||
@staticmethod
|
||||
async def test_execute_graph(
|
||||
graph_id: str,
|
||||
@@ -155,7 +150,7 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
graph_version: Optional[int] = None,
|
||||
node_input: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
return await backend.server.routers.v1.execute_graph(
|
||||
return backend.server.routers.v1.execute_graph(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
@@ -274,9 +269,7 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
provider: ProviderName,
|
||||
credentials: Credentials,
|
||||
) -> Credentials:
|
||||
from backend.server.integrations.router import create_credentials
|
||||
|
||||
return create_credentials(
|
||||
return backend.server.integrations.router.create_credentials(
|
||||
user_id=user_id, provider=provider, credentials=credentials
|
||||
)
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
|
||||
from typing_extensions import Optional, TypedDict
|
||||
|
||||
import backend.data.block
|
||||
import backend.server.integrations.router
|
||||
import backend.server.routers.analytics
|
||||
import backend.server.v2.library.db as library_db
|
||||
@@ -30,7 +31,7 @@ from backend.data.api_key import (
|
||||
suspend_api_key,
|
||||
update_api_key_permissions,
|
||||
)
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.data.credit import (
|
||||
AutoTopUpConfig,
|
||||
RefundRequest,
|
||||
@@ -40,7 +41,6 @@ from backend.data.credit import (
|
||||
get_user_credit_model,
|
||||
set_auto_top_up,
|
||||
)
|
||||
from backend.data.execution import AsyncRedisExecutionEventBus
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
from backend.data.onboarding import (
|
||||
UserOnboardingUpdate,
|
||||
@@ -49,16 +49,13 @@ from backend.data.onboarding import (
|
||||
onboarding_enabled,
|
||||
update_user_onboarding,
|
||||
)
|
||||
from backend.data.rabbitmq import AsyncRabbitMQ
|
||||
from backend.data.user import (
|
||||
get_or_create_user,
|
||||
get_user_notification_preference,
|
||||
update_user_email,
|
||||
update_user_notification_preference,
|
||||
)
|
||||
from backend.executor import Scheduler, scheduler
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.executor.utils import create_execution_queue_config
|
||||
from backend.executor import ExecutionManager, Scheduler, scheduler
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks.graph_lifecycle_hooks import (
|
||||
on_graph_activate,
|
||||
@@ -81,23 +78,16 @@ if TYPE_CHECKING:
|
||||
from backend.data.model import Credentials
|
||||
|
||||
|
||||
@thread_cached
|
||||
def execution_manager_client() -> ExecutionManager:
|
||||
return get_service_client(ExecutionManager)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def execution_scheduler_client() -> Scheduler:
|
||||
return get_service_client(Scheduler)
|
||||
|
||||
|
||||
@thread_cached
|
||||
async def execution_queue_client() -> AsyncRabbitMQ:
|
||||
client = AsyncRabbitMQ(create_execution_queue_config())
|
||||
await client.connect()
|
||||
return client
|
||||
|
||||
|
||||
@thread_cached
|
||||
def execution_event_bus() -> AsyncRedisExecutionEventBus:
|
||||
return AsyncRedisExecutionEventBus()
|
||||
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
integration_creds_manager = IntegrationCredentialsManager()
|
||||
@@ -216,7 +206,7 @@ async def is_onboarding_enabled():
|
||||
|
||||
@v1_router.get(path="/blocks", tags=["blocks"], dependencies=[Depends(auth_middleware)])
|
||||
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
blocks = [block() for block in get_blocks().values()]
|
||||
blocks = [block() for block in backend.data.block.get_blocks().values()]
|
||||
costs = get_block_costs()
|
||||
return [
|
||||
{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks if not b.disabled
|
||||
@@ -229,7 +219,7 @@ def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlockOutput:
|
||||
obj = get_block(block_id)
|
||||
obj = backend.data.block.get_block(block_id)
|
||||
if not obj:
|
||||
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
|
||||
|
||||
@@ -318,7 +308,7 @@ async def configure_user_auto_top_up(
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_user_auto_top_up(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> AutoTopUpConfig:
|
||||
return await get_auto_top_up(user_id)
|
||||
|
||||
@@ -385,7 +375,7 @@ async def get_credit_history(
|
||||
|
||||
@v1_router.get(path="/credits/refunds", dependencies=[Depends(auth_middleware)])
|
||||
async def get_refund_requests(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> list[RefundRequest]:
|
||||
return await _user_credit_model.get_refund_requests(user_id)
|
||||
|
||||
@@ -401,7 +391,7 @@ class DeleteGraphResponse(TypedDict):
|
||||
|
||||
@v1_router.get(path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)])
|
||||
async def get_graphs(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> Sequence[graph_db.GraphModel]:
|
||||
return await graph_db.get_graphs(filter_by="active", user_id=user_id)
|
||||
|
||||
@@ -590,35 +580,16 @@ async def set_graph_active_version(
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def execute_graph(
|
||||
def execute_graph(
|
||||
graph_id: str,
|
||||
node_input: Annotated[dict[str, Any], Body(..., default_factory=dict)],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
graph_version: Optional[int] = None,
|
||||
preset_id: Optional[str] = None,
|
||||
) -> ExecuteGraphResponse:
|
||||
graph: graph_db.GraphModel | None = await graph_db.get_graph(
|
||||
graph_id=graph_id, user_id=user_id, version=graph_version
|
||||
graph_exec = execution_manager_client().add_execution(
|
||||
graph_id, node_input, user_id=user_id, graph_version=graph_version
|
||||
)
|
||||
if not graph:
|
||||
raise ValueError(f"Graph #{graph_id} not found.")
|
||||
|
||||
graph_exec = await execution_db.create_graph_execution(
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
nodes_input=execution_utils.construct_node_execution_input(
|
||||
graph, user_id, node_input
|
||||
),
|
||||
user_id=user_id,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
execution_utils.get_execution_event_bus().publish(graph_exec)
|
||||
execution_utils.get_execution_queue().publish_message(
|
||||
routing_key=execution_utils.GRAPH_EXECUTION_ROUTING_KEY,
|
||||
message=graph_exec.to_graph_execution_entry().model_dump_json(),
|
||||
exchange=execution_utils.GRAPH_EXECUTION_EXCHANGE,
|
||||
)
|
||||
return ExecuteGraphResponse(graph_exec_id=graph_exec.id)
|
||||
return ExecuteGraphResponse(graph_exec_id=graph_exec.graph_exec_id)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
@@ -634,7 +605,9 @@ async def stop_graph_run(
|
||||
):
|
||||
raise HTTPException(404, detail=f"Agent execution #{graph_exec_id} not found")
|
||||
|
||||
await _cancel_execution(graph_exec_id)
|
||||
await asyncio.to_thread(
|
||||
lambda: execution_manager_client().cancel_execution(graph_exec_id)
|
||||
)
|
||||
|
||||
# Retrieve & return canceled graph execution in its final state
|
||||
result = await execution_db.get_graph_execution(
|
||||
@@ -648,49 +621,6 @@ async def stop_graph_run(
|
||||
return result
|
||||
|
||||
|
||||
async def _cancel_execution(graph_exec_id: str):
|
||||
"""
|
||||
Mechanism:
|
||||
1. Set the cancel event
|
||||
2. Graph executor's cancel handler thread detects the event, terminates workers,
|
||||
reinitializes worker pool, and returns.
|
||||
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
|
||||
"""
|
||||
queue_client = await execution_queue_client()
|
||||
await queue_client.publish_message(
|
||||
routing_key="",
|
||||
message=execution_utils.CancelExecutionEvent(
|
||||
graph_exec_id=graph_exec_id
|
||||
).model_dump_json(),
|
||||
exchange=execution_utils.GRAPH_EXECUTION_CANCEL_EXCHANGE,
|
||||
)
|
||||
|
||||
# Update the status of the graph & node executions
|
||||
await execution_db.update_graph_execution_stats(
|
||||
graph_exec_id,
|
||||
execution_db.ExecutionStatus.TERMINATED,
|
||||
)
|
||||
node_execs = [
|
||||
node_exec.model_copy(update={"status": execution_db.ExecutionStatus.TERMINATED})
|
||||
for node_exec in await execution_db.get_node_execution_results(
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=[
|
||||
execution_db.ExecutionStatus.QUEUED,
|
||||
execution_db.ExecutionStatus.RUNNING,
|
||||
execution_db.ExecutionStatus.INCOMPLETE,
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
await execution_db.update_node_execution_status_batch(
|
||||
[node_exec.node_exec_id for node_exec in node_execs],
|
||||
execution_db.ExecutionStatus.TERMINATED,
|
||||
)
|
||||
await asyncio.gather(
|
||||
*[execution_event_bus().publish(node_exec) for node_exec in node_execs]
|
||||
)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/executions",
|
||||
tags=["graphs"],
|
||||
@@ -862,7 +792,7 @@ async def create_api_key(
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_api_keys(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> list[APIKeyWithoutHash]:
|
||||
"""List all API keys for the user"""
|
||||
try:
|
||||
|
||||
@@ -68,12 +68,12 @@ async def list_library_agents(
|
||||
if search_term:
|
||||
where_clause["OR"] = [
|
||||
{
|
||||
"AgentGraph": {
|
||||
"Agent": {
|
||||
"is": {"name": {"contains": search_term, "mode": "insensitive"}}
|
||||
}
|
||||
},
|
||||
{
|
||||
"AgentGraph": {
|
||||
"Agent": {
|
||||
"is": {
|
||||
"description": {"contains": search_term, "mode": "insensitive"}
|
||||
}
|
||||
@@ -228,17 +228,16 @@ async def create_library_agent(
|
||||
|
||||
try:
|
||||
return await prisma.models.LibraryAgent.prisma().create(
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
isCreatedByUser=(user_id == graph.user_id),
|
||||
useGraphIsActiveVersion=True,
|
||||
User={"connect": {"id": user_id}},
|
||||
# Creator={"connect": {"id": agent.userId}},
|
||||
AgentGraph={
|
||||
data={
|
||||
"isCreatedByUser": (user_id == graph.user_id),
|
||||
"useGraphIsActiveVersion": True,
|
||||
"User": {"connect": {"id": user_id}},
|
||||
"Agent": {
|
||||
"connect": {
|
||||
"graphVersionId": {"id": graph.id, "version": graph.version}
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error creating agent in library: {e}")
|
||||
@@ -247,41 +246,38 @@ async def create_library_agent(
|
||||
|
||||
async def update_agent_version_in_library(
|
||||
user_id: str,
|
||||
agent_graph_id: str,
|
||||
agent_graph_version: int,
|
||||
agent_id: str,
|
||||
agent_version: int,
|
||||
) -> None:
|
||||
"""
|
||||
Updates the agent version in the library if useGraphIsActiveVersion is True.
|
||||
|
||||
Args:
|
||||
user_id: Owner of the LibraryAgent.
|
||||
agent_graph_id: The agent graph's ID to update.
|
||||
agent_graph_version: The new version of the agent graph.
|
||||
agent_id: The agent's ID to update.
|
||||
agent_version: The new version of the agent.
|
||||
|
||||
Raises:
|
||||
DatabaseError: If there's an error with the update.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Updating agent version in library for user #{user_id}, "
|
||||
f"agent #{agent_graph_id} v{agent_graph_version}"
|
||||
f"agent #{agent_id} v{agent_version}"
|
||||
)
|
||||
try:
|
||||
library_agent = await prisma.models.LibraryAgent.prisma().find_first_or_raise(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"agentGraphId": agent_graph_id,
|
||||
"agentId": agent_id,
|
||||
"useGraphIsActiveVersion": True,
|
||||
},
|
||||
)
|
||||
await prisma.models.LibraryAgent.prisma().update(
|
||||
where={"id": library_agent.id},
|
||||
data={
|
||||
"AgentGraph": {
|
||||
"Agent": {
|
||||
"connect": {
|
||||
"graphVersionId": {
|
||||
"id": agent_graph_id,
|
||||
"version": agent_graph_version,
|
||||
}
|
||||
"graphVersionId": {"id": agent_id, "version": agent_version}
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -345,7 +341,7 @@ async def delete_library_agent_by_graph_id(graph_id: str, user_id: str) -> None:
|
||||
"""
|
||||
try:
|
||||
await prisma.models.LibraryAgent.prisma().delete_many(
|
||||
where={"agentGraphId": graph_id, "userId": user_id}
|
||||
where={"agentId": graph_id, "userId": user_id}
|
||||
)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error deleting library agent: {e}")
|
||||
@@ -378,10 +374,10 @@ async def add_store_agent_to_library(
|
||||
async with locked_transaction(f"add_agent_trx_{user_id}"):
|
||||
store_listing_version = (
|
||||
await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id}, include={"AgentGraph": True}
|
||||
where={"id": store_listing_version_id}, include={"Agent": True}
|
||||
)
|
||||
)
|
||||
if not store_listing_version or not store_listing_version.AgentGraph:
|
||||
if not store_listing_version or not store_listing_version.Agent:
|
||||
logger.warning(
|
||||
f"Store listing version not found: {store_listing_version_id}"
|
||||
)
|
||||
@@ -389,7 +385,7 @@ async def add_store_agent_to_library(
|
||||
f"Store listing version {store_listing_version_id} not found or invalid"
|
||||
)
|
||||
|
||||
graph = store_listing_version.AgentGraph
|
||||
graph = store_listing_version.Agent
|
||||
if graph.userId == user_id:
|
||||
logger.warning(
|
||||
f"User #{user_id} attempted to add their own agent to their library"
|
||||
@@ -401,8 +397,8 @@ async def add_store_agent_to_library(
|
||||
await prisma.models.LibraryAgent.prisma().find_first(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
"agentId": graph.id,
|
||||
"agentVersion": graph.version,
|
||||
},
|
||||
include=library_agent_include(user_id),
|
||||
)
|
||||
@@ -422,17 +418,17 @@ async def add_store_agent_to_library(
|
||||
|
||||
# Create LibraryAgent entry
|
||||
added_agent = await prisma.models.LibraryAgent.prisma().create(
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
userId=user_id,
|
||||
agentGraphId=graph.id,
|
||||
agentGraphVersion=graph.version,
|
||||
isCreatedByUser=False,
|
||||
),
|
||||
data={
|
||||
"userId": user_id,
|
||||
"agentId": graph.id,
|
||||
"agentVersion": graph.version,
|
||||
"isCreatedByUser": False,
|
||||
},
|
||||
include=library_agent_include(user_id),
|
||||
)
|
||||
logger.debug(
|
||||
f"Added graph #{graph.id} v{graph.version}"
|
||||
f"for store listing version #{store_listing_version.id} "
|
||||
f"Added graph #{graph.id} "
|
||||
f"for store listing #{store_listing_version.id} "
|
||||
f"to library for user #{user_id}"
|
||||
)
|
||||
return library_model.LibraryAgent.from_db(added_agent)
|
||||
@@ -471,8 +467,8 @@ async def set_is_deleted_for_library_agent(
|
||||
count = await prisma.models.LibraryAgent.prisma().update_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"agentGraphId": agent_id,
|
||||
"agentGraphVersion": agent_version,
|
||||
"agentId": agent_id,
|
||||
"agentVersion": agent_version,
|
||||
},
|
||||
data={"isDeleted": is_deleted},
|
||||
)
|
||||
@@ -601,12 +597,6 @@ async def upsert_preset(
|
||||
f"Upserting preset #{preset_id} ({repr(preset.name)}) for user #{user_id}",
|
||||
)
|
||||
try:
|
||||
inputs = [
|
||||
prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput(
|
||||
name=name, data=prisma.fields.Json(data)
|
||||
)
|
||||
for name, data in preset.inputs.items()
|
||||
]
|
||||
if preset_id:
|
||||
# Update existing preset
|
||||
updated = await prisma.models.AgentPreset.prisma().update(
|
||||
@@ -615,7 +605,12 @@ async def upsert_preset(
|
||||
"name": preset.name,
|
||||
"description": preset.description,
|
||||
"isActive": preset.is_active,
|
||||
"InputPresets": {"create": inputs},
|
||||
"InputPresets": {
|
||||
"create": [
|
||||
{"name": name, "data": prisma.fields.Json(data)}
|
||||
for name, data in preset.inputs.items()
|
||||
]
|
||||
},
|
||||
},
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
@@ -625,15 +620,20 @@ async def upsert_preset(
|
||||
else:
|
||||
# Create new preset
|
||||
new_preset = await prisma.models.AgentPreset.prisma().create(
|
||||
data=prisma.types.AgentPresetCreateInput(
|
||||
userId=user_id,
|
||||
name=preset.name,
|
||||
description=preset.description,
|
||||
agentGraphId=preset.graph_id,
|
||||
agentGraphVersion=preset.graph_version,
|
||||
isActive=preset.is_active,
|
||||
InputPresets={"create": inputs},
|
||||
),
|
||||
data={
|
||||
"userId": user_id,
|
||||
"name": preset.name,
|
||||
"description": preset.description,
|
||||
"agentId": preset.agent_id,
|
||||
"agentVersion": preset.agent_version,
|
||||
"isActive": preset.is_active,
|
||||
"InputPresets": {
|
||||
"create": [
|
||||
{"name": name, "data": prisma.fields.Json(data)}
|
||||
for name, data in preset.inputs.items()
|
||||
]
|
||||
},
|
||||
},
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
return library_model.LibraryAgentPreset.from_db(new_preset)
|
||||
|
||||
@@ -30,8 +30,8 @@ async def test_get_library_agents(mocker):
|
||||
prisma.models.LibraryAgent(
|
||||
id="ua1",
|
||||
userId="test-user",
|
||||
agentGraphId="agent2",
|
||||
agentGraphVersion=1,
|
||||
agentId="agent2",
|
||||
agentVersion=1,
|
||||
isCreatedByUser=False,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
@@ -39,7 +39,7 @@ async def test_get_library_agents(mocker):
|
||||
updatedAt=datetime.now(),
|
||||
isFavorite=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
AgentGraph=prisma.models.AgentGraph(
|
||||
Agent=prisma.models.AgentGraph(
|
||||
id="agent2",
|
||||
version=1,
|
||||
name="Test Agent 2",
|
||||
@@ -71,8 +71,8 @@ async def test_get_library_agents(mocker):
|
||||
assert result.agents[0].id == "ua1"
|
||||
assert result.agents[0].name == "Test Agent 2"
|
||||
assert result.agents[0].description == "Test Description 2"
|
||||
assert result.agents[0].graph_id == "agent2"
|
||||
assert result.agents[0].graph_version == 1
|
||||
assert result.agents[0].agent_id == "agent2"
|
||||
assert result.agents[0].agent_version == 1
|
||||
assert result.agents[0].can_access_graph is False
|
||||
assert result.agents[0].is_latest_version is True
|
||||
assert result.pagination.total_items == 1
|
||||
@@ -81,7 +81,7 @@ async def test_get_library_agents(mocker):
|
||||
assert result.pagination.page_size == 50
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_add_agent_to_library(mocker):
|
||||
await connect()
|
||||
# Mock data
|
||||
@@ -90,8 +90,8 @@ async def test_add_agent_to_library(mocker):
|
||||
version=1,
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
agentGraphId="agent1",
|
||||
agentGraphVersion=1,
|
||||
agentId="agent1",
|
||||
agentVersion=1,
|
||||
name="Test Agent",
|
||||
subHeading="Test Agent Subheading",
|
||||
imageUrls=["https://example.com/image.jpg"],
|
||||
@@ -102,7 +102,7 @@ async def test_add_agent_to_library(mocker):
|
||||
isAvailable=True,
|
||||
storeListingId="listing123",
|
||||
submissionStatus=prisma.enums.SubmissionStatus.APPROVED,
|
||||
AgentGraph=prisma.models.AgentGraph(
|
||||
Agent=prisma.models.AgentGraph(
|
||||
id="agent1",
|
||||
version=1,
|
||||
name="Test Agent",
|
||||
@@ -116,8 +116,8 @@ async def test_add_agent_to_library(mocker):
|
||||
mock_library_agent_data = prisma.models.LibraryAgent(
|
||||
id="ua1",
|
||||
userId="test-user",
|
||||
agentGraphId=mock_store_listing_data.agentGraphId,
|
||||
agentGraphVersion=1,
|
||||
agentId=mock_store_listing_data.agentId,
|
||||
agentVersion=1,
|
||||
isCreatedByUser=False,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
@@ -125,7 +125,7 @@ async def test_add_agent_to_library(mocker):
|
||||
updatedAt=datetime.now(),
|
||||
isFavorite=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
AgentGraph=mock_store_listing_data.AgentGraph,
|
||||
Agent=mock_store_listing_data.Agent,
|
||||
)
|
||||
|
||||
# Mock prisma calls
|
||||
@@ -147,28 +147,25 @@ async def test_add_agent_to_library(mocker):
|
||||
|
||||
# Verify mocks called correctly
|
||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||
where={"id": "version123"}, include={"AgentGraph": True}
|
||||
where={"id": "version123"}, include={"Agent": True}
|
||||
)
|
||||
mock_library_agent.return_value.find_first.assert_called_once_with(
|
||||
where={
|
||||
"userId": "test-user",
|
||||
"agentGraphId": "agent1",
|
||||
"agentGraphVersion": 1,
|
||||
"agentId": "agent1",
|
||||
"agentVersion": 1,
|
||||
},
|
||||
include=library_agent_include("test-user"),
|
||||
)
|
||||
mock_library_agent.return_value.create.assert_called_once_with(
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
userId="test-user",
|
||||
agentGraphId="agent1",
|
||||
agentGraphVersion=1,
|
||||
isCreatedByUser=False,
|
||||
userId="test-user", agentId="agent1", agentVersion=1, isCreatedByUser=False
|
||||
),
|
||||
include=library_agent_include("test-user"),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_add_agent_to_library_not_found(mocker):
|
||||
await connect()
|
||||
# Mock prisma calls
|
||||
@@ -185,5 +182,5 @@ async def test_add_agent_to_library_not_found(mocker):
|
||||
|
||||
# Verify mock called correctly
|
||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||
where={"id": "version123"}, include={"AgentGraph": True}
|
||||
where={"id": "version123"}, include={"Agent": True}
|
||||
)
|
||||
|
||||
@@ -25,8 +25,8 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
"""
|
||||
|
||||
id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
|
||||
image_url: str | None
|
||||
|
||||
@@ -58,12 +58,12 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
Factory method that constructs a LibraryAgent from a Prisma LibraryAgent
|
||||
model instance.
|
||||
"""
|
||||
if not agent.AgentGraph:
|
||||
if not agent.Agent:
|
||||
raise ValueError("Associated Agent record is required.")
|
||||
|
||||
graph = graph_model.GraphModel.from_db(agent.AgentGraph)
|
||||
graph = graph_model.GraphModel.from_db(agent.Agent)
|
||||
|
||||
agent_updated_at = agent.AgentGraph.updatedAt
|
||||
agent_updated_at = agent.Agent.updatedAt
|
||||
lib_agent_updated_at = agent.updatedAt
|
||||
|
||||
# Compute updated_at as the latest between library agent and graph
|
||||
@@ -83,21 +83,21 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
week_ago = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
|
||||
days=7
|
||||
)
|
||||
executions = agent.AgentGraph.Executions or []
|
||||
executions = agent.Agent.AgentGraphExecution or []
|
||||
status_result = _calculate_agent_status(executions, week_ago)
|
||||
status = status_result.status
|
||||
new_output = status_result.new_output
|
||||
|
||||
# Check if user can access the graph
|
||||
can_access_graph = agent.AgentGraph.userId == agent.userId
|
||||
can_access_graph = agent.Agent.userId == agent.userId
|
||||
|
||||
# Hard-coded to True until a method to check is implemented
|
||||
is_latest_version = True
|
||||
|
||||
return LibraryAgent(
|
||||
id=agent.id,
|
||||
graph_id=agent.agentGraphId,
|
||||
graph_version=agent.agentGraphVersion,
|
||||
agent_id=agent.agentId,
|
||||
agent_version=agent.agentVersion,
|
||||
image_url=agent.imageUrl,
|
||||
creator_name=creator_name,
|
||||
creator_image_url=creator_image_url,
|
||||
@@ -174,8 +174,8 @@ class LibraryAgentPreset(pydantic.BaseModel):
|
||||
id: str
|
||||
updated_at: datetime.datetime
|
||||
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
|
||||
name: str
|
||||
description: str
|
||||
@@ -194,8 +194,8 @@ class LibraryAgentPreset(pydantic.BaseModel):
|
||||
return cls(
|
||||
id=preset.id,
|
||||
updated_at=preset.updatedAt,
|
||||
graph_id=preset.agentGraphId,
|
||||
graph_version=preset.agentGraphVersion,
|
||||
agent_id=preset.agentId,
|
||||
agent_version=preset.agentVersion,
|
||||
name=preset.name,
|
||||
description=preset.description,
|
||||
is_active=preset.isActive,
|
||||
@@ -218,8 +218,8 @@ class CreateLibraryAgentPresetRequest(pydantic.BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
inputs: block_model.BlockInput
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
is_active: bool
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import prisma.models
|
||||
import pytest
|
||||
|
||||
import backend.server.v2.library.model as library_model
|
||||
from backend.util import json
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -14,8 +15,8 @@ async def test_agent_preset_from_db():
|
||||
id="test-agent-123",
|
||||
createdAt=datetime.datetime.now(),
|
||||
updatedAt=datetime.datetime.now(),
|
||||
agentGraphId="agent-123",
|
||||
agentGraphVersion=1,
|
||||
agentId="agent-123",
|
||||
agentVersion=1,
|
||||
name="Test Agent",
|
||||
description="Test agent description",
|
||||
isActive=True,
|
||||
@@ -26,7 +27,7 @@ async def test_agent_preset_from_db():
|
||||
id="input-123",
|
||||
time=datetime.datetime.now(),
|
||||
name="input1",
|
||||
data=prisma.Json({"type": "string", "value": "test value"}),
|
||||
data=json.dumps({"type": "string", "value": "test value"}), # type: ignore
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -35,7 +36,7 @@ async def test_agent_preset_from_db():
|
||||
agent = library_model.LibraryAgentPreset.from_db(db_agent)
|
||||
|
||||
assert agent.id == "test-agent-123"
|
||||
assert agent.graph_version == 1
|
||||
assert agent.agent_version == 1
|
||||
assert agent.is_active is True
|
||||
assert agent.name == "Test Agent"
|
||||
assert agent.description == "Test agent description"
|
||||
|
||||
@@ -2,16 +2,25 @@ import logging
|
||||
from typing import Annotated, Any
|
||||
|
||||
import autogpt_libs.auth as autogpt_auth_lib
|
||||
import autogpt_libs.utils.cache
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, status
|
||||
|
||||
import backend.executor
|
||||
import backend.server.v2.library.db as db
|
||||
import backend.server.v2.library.model as models
|
||||
import backend.util.service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@autogpt_libs.utils.cache.thread_cached
|
||||
def execution_manager_client() -> backend.executor.ExecutionManager:
|
||||
"""Return a cached instance of ExecutionManager client."""
|
||||
return backend.util.service.get_service_client(backend.executor.ExecutionManager)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/presets",
|
||||
summary="List presets",
|
||||
@@ -207,8 +216,6 @@ async def execute_preset(
|
||||
HTTPException: If the preset is not found or an error occurs while executing the preset.
|
||||
"""
|
||||
try:
|
||||
from backend.server.routers import v1 as internal_api_routes
|
||||
|
||||
preset = await db.get_preset(user_id, preset_id)
|
||||
if not preset:
|
||||
raise HTTPException(
|
||||
@@ -219,10 +226,10 @@ async def execute_preset(
|
||||
# Merge input overrides with preset inputs
|
||||
merged_node_input = preset.inputs | node_input
|
||||
|
||||
execution = await internal_api_routes.execute_graph(
|
||||
execution = execution_manager_client().add_execution(
|
||||
graph_id=graph_id,
|
||||
node_input=merged_node_input,
|
||||
graph_version=graph_version,
|
||||
data=merged_node_input,
|
||||
user_id=user_id,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
|
||||
@@ -35,8 +35,8 @@ async def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
|
||||
agents=[
|
||||
library_model.LibraryAgent(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
agent_id="test-agent-1",
|
||||
agent_version=1,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
@@ -51,8 +51,8 @@ async def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
|
||||
),
|
||||
library_model.LibraryAgent(
|
||||
id="test-agent-2",
|
||||
graph_id="test-agent-2",
|
||||
graph_version=1,
|
||||
agent_id="test-agent-2",
|
||||
agent_version=1,
|
||||
name="Test Agent 2",
|
||||
description="Test Description 2",
|
||||
image_url=None,
|
||||
@@ -78,9 +78,9 @@ async def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
|
||||
|
||||
data = library_model.LibraryAgentResponse.model_validate(response.json())
|
||||
assert len(data.agents) == 2
|
||||
assert data.agents[0].graph_id == "test-agent-1"
|
||||
assert data.agents[0].agent_id == "test-agent-1"
|
||||
assert data.agents[0].can_access_graph is True
|
||||
assert data.agents[1].graph_id == "test-agent-2"
|
||||
assert data.agents[1].agent_id == "test-agent-2"
|
||||
assert data.agents[1].can_access_graph is False
|
||||
mock_db_call.assert_called_once_with(
|
||||
user_id="test-user-id",
|
||||
|
||||
@@ -10,7 +10,7 @@ from backend.data.user import (
|
||||
set_user_email_verification,
|
||||
unsubscribe_user_by_token,
|
||||
)
|
||||
from backend.server.routers.postmark.models import (
|
||||
from backend.server.v2.postmark.models import (
|
||||
PostmarkBounceEnum,
|
||||
PostmarkBounceWebhook,
|
||||
PostmarkClickWebhook,
|
||||
@@ -200,17 +200,17 @@ async def get_available_graph(
|
||||
"isAvailable": True,
|
||||
"isDeleted": False,
|
||||
},
|
||||
include={"AgentGraph": {"include": {"Nodes": True}}},
|
||||
include={"Agent": {"include": {"AgentNodes": True}}},
|
||||
)
|
||||
)
|
||||
|
||||
if not store_listing_version or not store_listing_version.AgentGraph:
|
||||
if not store_listing_version or not store_listing_version.Agent:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Store listing version {store_listing_version_id} not found",
|
||||
)
|
||||
|
||||
graph = GraphModel.from_db(store_listing_version.AgentGraph)
|
||||
graph = GraphModel.from_db(store_listing_version.Agent)
|
||||
# We return graph meta, without nodes, they cannot be just removed
|
||||
# because then input_schema would be empty
|
||||
return {
|
||||
@@ -516,7 +516,7 @@ async def delete_store_submission(
|
||||
try:
|
||||
# Verify the submission belongs to this user
|
||||
submission = await prisma.models.StoreListing.prisma().find_first(
|
||||
where={"agentGraphId": submission_id, "owningUserId": user_id}
|
||||
where={"agentId": submission_id, "owningUserId": user_id}
|
||||
)
|
||||
|
||||
if not submission:
|
||||
@@ -598,7 +598,7 @@ async def create_store_submission(
|
||||
# Check if listing already exists for this agent
|
||||
existing_listing = await prisma.models.StoreListing.prisma().find_first(
|
||||
where=prisma.types.StoreListingWhereInput(
|
||||
agentGraphId=agent_id, owningUserId=user_id
|
||||
agentId=agent_id, owningUserId=user_id
|
||||
)
|
||||
)
|
||||
|
||||
@@ -625,15 +625,15 @@ async def create_store_submission(
|
||||
# If no existing listing, create a new one
|
||||
data = prisma.types.StoreListingCreateInput(
|
||||
slug=slug,
|
||||
agentGraphId=agent_id,
|
||||
agentGraphVersion=agent_version,
|
||||
agentId=agent_id,
|
||||
agentVersion=agent_version,
|
||||
owningUserId=user_id,
|
||||
createdAt=datetime.now(tz=timezone.utc),
|
||||
Versions={
|
||||
"create": [
|
||||
prisma.types.StoreListingVersionCreateInput(
|
||||
agentGraphId=agent_id,
|
||||
agentGraphVersion=agent_version,
agentId=agent_id,
agentVersion=agent_version,
name=name,
videoUrl=video_url,
imageUrls=image_urls,
@@ -758,8 +758,8 @@ async def create_store_version(
new_version = await prisma.models.StoreListingVersion.prisma().create(
data=prisma.types.StoreListingVersionCreateInput(
version=next_version,
agentGraphId=agent_id,
agentGraphVersion=agent_version,
agentId=agent_id,
agentVersion=agent_version,
name=name,
videoUrl=video_url,
imageUrls=image_urls,
@@ -959,17 +959,17 @@ async def get_my_agents(
try:
search_filter: prisma.types.LibraryAgentWhereInput = {
"userId": user_id,
"AgentGraph": {"is": {"StoreListings": {"none": {"isDeleted": False}}}},
"Agent": {"is": {"StoreListing": {"none": {"isDeleted": False}}}},
"isArchived": False,
"isDeleted": False,
}

library_agents = await prisma.models.LibraryAgent.prisma().find_many(
where=search_filter,
order=[{"agentGraphVersion": "desc"}],
order=[{"agentVersion": "desc"}],
skip=(page - 1) * page_size,
take=page_size,
include={"AgentGraph": True},
include={"Agent": True},
)

total = await prisma.models.LibraryAgent.prisma().count(where=search_filter)
@@ -985,7 +985,7 @@ async def get_my_agents(
agent_image=library_agent.imageUrl,
)
for library_agent in library_agents
if (graph := library_agent.AgentGraph)
if (graph := library_agent.Agent)
]

return backend.server.v2.store.model.MyAgentsResponse(
@@ -1020,13 +1020,13 @@ async def get_agent(

graph = await backend.data.graph.get_graph(
user_id=user_id,
graph_id=store_listing_version.agentGraphId,
version=store_listing_version.agentGraphVersion,
graph_id=store_listing_version.agentId,
version=store_listing_version.agentVersion,
for_export=True,
)
if not graph:
raise ValueError(
f"Agent {store_listing_version.agentGraphId} v{store_listing_version.agentGraphVersion} not found"
f"Agent {store_listing_version.agentId} v{store_listing_version.agentVersion} not found"
)

return graph
@@ -1050,14 +1050,11 @@ async def _get_missing_sub_store_listing(

# Fetch all the sub-graphs that are listed, and return the ones missing.
store_listed_sub_graphs = {
(listing.agentGraphId, listing.agentGraphVersion)
(listing.agentId, listing.agentVersion)
for listing in await prisma.models.StoreListingVersion.prisma().find_many(
where={
"OR": [
{
"agentGraphId": sub_graph.id,
"agentGraphVersion": sub_graph.version,
}
{"agentId": sub_graph.id, "agentVersion": sub_graph.version}
for sub_graph in sub_graphs
],
"submissionStatus": prisma.enums.SubmissionStatus.APPROVED,
@@ -1087,7 +1084,7 @@ async def review_store_submission(
where={"id": store_listing_version_id},
include={
"StoreListing": True,
"AgentGraph": {"include": AGENT_GRAPH_INCLUDE},
"Agent": {"include": AGENT_GRAPH_INCLUDE}, # type: ignore
},
)
)
@@ -1099,23 +1096,23 @@ async def review_store_submission(
)

# If approving, update the listing to indicate it has an approved version
if is_approved and store_listing_version.AgentGraph:
heading = f"Sub-graph of {store_listing_version.name}v{store_listing_version.agentGraphVersion}"
if is_approved and store_listing_version.Agent:
heading = f"Sub-graph of {store_listing_version.name}v{store_listing_version.agentVersion}"

sub_store_listing_versions = [
prisma.types.StoreListingVersionCreateWithoutRelationsInput(
agentGraphId=sub_graph.id,
agentGraphVersion=sub_graph.version,
agentId=sub_graph.id,
agentVersion=sub_graph.version,
name=sub_graph.name or heading,
submissionStatus=prisma.enums.SubmissionStatus.APPROVED,
subHeading=heading,
description=f"{heading}: {sub_graph.description}",
changesSummary=f"This listing is added as a {heading} / #{store_listing_version.agentGraphId}.",
changesSummary=f"This listing is added as a {heading} / #{store_listing_version.agentId}.",
isAvailable=False, # Hide sub-graphs from the store by default.
submittedAt=datetime.now(tz=timezone.utc),
)
for sub_graph in await _get_missing_sub_store_listing(
store_listing_version.AgentGraph
store_listing_version.Agent
)
]

@@ -1158,8 +1155,8 @@ async def review_store_submission(

# Convert to Pydantic model for consistency
return backend.server.v2.store.model.StoreSubmission(
agent_id=submission.agentGraphId,
agent_version=submission.agentGraphVersion,
agent_id=submission.agentId,
agent_version=submission.agentVersion,
name=submission.name,
sub_heading=submission.subHeading,
slug=(
@@ -1297,8 +1294,8 @@ async def get_admin_listings_with_versions(
# If we have versions, turn them into StoreSubmission models
for version in listing.Versions or []:
version_model = backend.server.v2.store.model.StoreSubmission(
agent_id=version.agentGraphId,
agent_version=version.agentGraphVersion,
agent_id=version.agentId,
agent_version=version.agentVersion,
name=version.name,
sub_heading=version.subHeading,
slug=listing.slug,
@@ -1327,8 +1324,8 @@ async def get_admin_listings_with_versions(
backend.server.v2.store.model.StoreListingWithVersions(
listing_id=listing.id,
slug=listing.slug,
agent_id=listing.agentGraphId,
agent_version=listing.agentGraphVersion,
agent_id=listing.agentId,
agent_version=listing.agentVersion,
active_version_id=listing.activeVersionId,
has_approved_version=listing.hasApprovedVersion,
creator_email=creator_email,

@@ -170,14 +170,14 @@ async def test_create_store_submission(mocker):
isDeleted=False,
hasApprovedVersion=False,
slug="test-agent",
agentGraphId="agent-id",
agentGraphVersion=1,
agentId="agent-id",
agentVersion=1,
owningUserId="user-id",
Versions=[
prisma.models.StoreListingVersion(
id="version-id",
agentGraphId="agent-id",
agentGraphVersion=1,
agentId="agent-id",
agentVersion=1,
name="Test Agent",
description="Test description",
createdAt=datetime.now(),

@@ -5,11 +5,11 @@ from typing import Protocol

import uvicorn
from autogpt_libs.auth import parse_jwt_token
from autogpt_libs.logging.utils import generate_uvicorn_config
from autogpt_libs.utils.cache import thread_cached
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
from starlette.middleware.cors import CORSMiddleware

from backend.data import redis
from backend.data.execution import AsyncRedisExecutionEventBus
from backend.data.user import DEFAULT_USER_ID
from backend.server.conn_manager import ConnectionManager
@@ -55,12 +55,15 @@ def get_db_client():

async def event_broadcaster(manager: ConnectionManager):
try:
redis.connect()
event_queue = AsyncRedisExecutionEventBus()
async for event in event_queue.listen("*"):
await manager.send_execution_update(event)
except Exception as e:
logger.exception(f"Event broadcaster error: {e}")
raise
finally:
redis.disconnect()


async def authenticate_websocket(websocket: WebSocket) -> str:
@@ -283,14 +286,8 @@ class WebsocketServer(AppProcess):
allow_methods=["*"],
allow_headers=["*"],
)

uvicorn.run(
server_app,
host=Config().websocket_server_host,
port=Config().websocket_server_port,
log_config=generate_uvicorn_config(),
)

def cleanup(self):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Shutting down WebSocket Server...")

@@ -15,25 +15,21 @@ def to_dict(data) -> dict:


def dumps(data) -> str:
return json.dumps(to_dict(data))
return json.dumps(jsonable_encoder(data))


T = TypeVar("T")


@overload
def loads(data: str | bytes, *args, target_type: Type[T], **kwargs) -> T: ...
def loads(data: str, *args, target_type: Type[T], **kwargs) -> T: ...


@overload
def loads(data: str | bytes, *args, **kwargs) -> Any: ...
def loads(data: str, *args, **kwargs) -> Any: ...


def loads(
data: str | bytes, *args, target_type: Type[T] | None = None, **kwargs
) -> Any:
if isinstance(data, bytes):
data = data.decode("utf-8")
def loads(data: str, *args, target_type: Type[T] | None = None, **kwargs) -> Any:
parsed = json.loads(data, *args, **kwargs)
if target_type:
return type_match(parsed, target_type)

@@ -1,24 +1,8 @@
import logging

import sentry_sdk
from sentry_sdk.integrations.anthropic import AnthropicIntegration
from sentry_sdk.integrations.logging import LoggingIntegration

from backend.util.settings import Settings


def sentry_init():
sentry_dsn = Settings().secrets.sentry_dsn
sentry_sdk.init(
dsn=sentry_dsn,
traces_sample_rate=1.0,
profiles_sample_rate=1.0,
environment=f"app:{Settings().config.app_env.value}-behave:{Settings().config.behave_as.value}",
_experiments={"enable_logs": True},
integrations=[
LoggingIntegration(sentry_logs_level=logging.INFO),
AnthropicIntegration(
include_prompts=False,
),
],
)
sentry_sdk.init(dsn=sentry_dsn, traces_sample_rate=1.0, profiles_sample_rate=1.0)

@@ -28,7 +28,6 @@ class AppProcess(ABC):
"""

process: Optional[Process] = None
cleaned_up = False

set_start_method("spawn", force=True)
configure_logging()
@@ -48,7 +47,6 @@ class AppProcess(ABC):
def service_name(cls) -> str:
return cls.__name__

@abstractmethod
def cleanup(self):
"""
Implement this method on a subclass to do post-execution cleanup,
@@ -64,7 +62,6 @@ class AppProcess(ABC):

def execute_run_command(self, silent):
signal.signal(signal.SIGTERM, self._self_terminate)
signal.signal(signal.SIGINT, self._self_terminate)

try:
if silent:
@@ -76,16 +73,9 @@ class AppProcess(ABC):
self.run()
except (KeyboardInterrupt, SystemExit) as e:
logger.warning(f"[{self.service_name}] Terminated: {e}; quitting...")
finally:
if not self.cleaned_up:
self.cleanup()
self.cleaned_up = True
logger.info(f"[{self.service_name}] Terminated.")

def _self_terminate(self, signum: int, frame):
if not self.cleaned_up:
self.cleanup()
self.cleaned_up = True
self.cleanup()
sys.exit(0)

# Methods that are executed OUTSIDE the process #

@@ -142,7 +142,7 @@ def validate_url(

# Resolve all IP addresses for the hostname
try:
ip_list = [str(res[4][0]) for res in socket.getaddrinfo(ascii_hostname, None)]
ip_list = [res[4][0] for res in socket.getaddrinfo(ascii_hostname, None)]
ipv4 = [ip for ip in ip_list if ":" not in ip]
ipv6 = [ip for ip in ip_list if ":" in ip]
ip_addresses = ipv4 + ipv6 # Prefer IPv4 over IPv6

@@ -34,7 +34,7 @@ def conn_retry(
def on_retry(retry_state):
prefix = _log_prefix(resource_name, conn_id)
exception = retry_state.outcome.exception()
logger.warning(f"{prefix} {action_name} failed: {exception}. Retrying now...")
logger.error(f"{prefix} {action_name} failed: {exception}. Retrying now...")

def decorator(func):
is_coroutine = asyncio.iscoroutinefunction(func)

@@ -1,34 +1,52 @@
import asyncio
import builtins
import inspect
import logging
import os
import threading
import time
import typing
from abc import ABC, abstractmethod
from enum import Enum
from functools import wraps
from types import NoneType, UnionType
from typing import (
Annotated,
Any,
Awaitable,
Callable,
Concatenate,
Coroutine,
Dict,
FrozenSet,
Iterator,
List,
Optional,
ParamSpec,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
get_args,
get_origin,
)

import httpx
import Pyro5.api
import uvicorn
from fastapi import FastAPI, Request, responses
from pydantic import BaseModel, TypeAdapter, create_model
from Pyro5 import api as pyro
from Pyro5 import config as pyro_config

from backend.data import db, rabbitmq, redis
from backend.util.exceptions import InsufficientBalanceError
from backend.util.json import to_dict
from backend.util.metrics import sentry_init
from backend.util.process import AppProcess, get_service_name
from backend.util.retry import conn_retry
from backend.util.settings import Config
from backend.util.settings import Config, Secrets

logger = logging.getLogger(__name__)
T = TypeVar("T")
@@ -39,18 +57,21 @@ api_host = config.pyro_host
api_comm_retry = config.pyro_client_comm_retry
api_comm_timeout = config.pyro_client_comm_timeout
api_call_timeout = config.rpc_client_call_timeout
pyro_config.MAX_RETRIES = api_comm_retry # type: ignore
pyro_config.COMMTIMEOUT = api_comm_timeout # type: ignore


P = ParamSpec("P")
R = TypeVar("R")


def expose(func: C) -> C:
def fastapi_expose(func: C) -> C:
func = getattr(func, "__func__", func)
setattr(func, "__exposed__", True)
return func


def exposed_run_and_wait(
def fastapi_exposed_run_and_wait(
f: Callable[P, Coroutine[None, None, R]]
) -> Callable[Concatenate[object, P], R]:
# TODO:
@@ -60,11 +81,107 @@ def exposed_run_and_wait(
return expose(f) # type: ignore


# ----- Begin Pyro Expose Block ---- #
def pyro_expose(func: C) -> C:
"""
Decorator to mark a method or class to be exposed for remote calls.

## ⚠️ Gotcha
Aside from "simple" types, only Pydantic models are passed unscathed *if annotated*.
Any other passed or returned class objects are converted to dictionaries by Pyro.
"""

def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
msg = f"Error in {func.__name__}: {e}"
if isinstance(e, ValueError):
logger.warning(msg)
else:
logger.exception(msg)
raise

register_pydantic_serializers(func)

return pyro.expose(wrapper) # type: ignore


def register_pydantic_serializers(func: Callable):
"""Register custom serializers and deserializers for annotated Pydantic models"""
for name, annotation in func.__annotations__.items():
try:
pydantic_types = _pydantic_models_from_type_annotation(annotation)
except Exception as e:
raise TypeError(f"Error while exposing {func.__name__}: {e}")

for model in pydantic_types:
logger.debug(
f"Registering Pyro (de)serializers for {func.__name__} annotation "
f"'{name}': {model.__qualname__}"
)
pyro.register_class_to_dict(model, _make_custom_serializer(model))
pyro.register_dict_to_class(
model.__qualname__, _make_custom_deserializer(model)
)


def _make_custom_serializer(model: Type[BaseModel]):
def custom_class_to_dict(obj):
data = {
"__class__": obj.__class__.__qualname__,
**obj.model_dump(),
}
logger.debug(f"Serializing {obj.__class__.__qualname__} with data: {data}")
return data

return custom_class_to_dict


def _make_custom_deserializer(model: Type[BaseModel]):
def custom_dict_to_class(qualname, data: dict):
logger.debug(f"Deserializing {model.__qualname__} from data: {data}")
return model(**data)

return custom_dict_to_class


def pyro_exposed_run_and_wait(
f: Callable[P, Coroutine[None, None, R]]
) -> Callable[Concatenate[object, P], R]:
@expose
@wraps(f)
def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
coroutine = f(*args, **kwargs)
res = self.run_and_wait(coroutine)
return res

# Register serializers for annotations on bare function
register_pydantic_serializers(f)

return wrapper


if config.use_http_based_rpc:
expose = fastapi_expose
exposed_run_and_wait = fastapi_exposed_run_and_wait
else:
expose = pyro_expose
exposed_run_and_wait = pyro_exposed_run_and_wait

# ----- End Pyro Expose Block ---- #


# --------------------------------------------------
# AppService for IPC service based on HTTP request through FastAPI
# --------------------------------------------------
class BaseAppService(AppProcess, ABC):
shared_event_loop: asyncio.AbstractEventLoop
use_db: bool = False
use_redis: bool = False
rabbitmq_config: Optional[rabbitmq.RabbitMQConfig] = None
rabbitmq_service: Optional[rabbitmq.AsyncRabbitMQ] = None
use_supabase: bool = False

@classmethod
@abstractmethod
@@ -85,6 +202,20 @@ class BaseAppService(AppProcess, ABC):

return target_host

@property
def rabbit(self) -> rabbitmq.AsyncRabbitMQ:
"""Access the RabbitMQ service. Will raise if not configured."""
if not self.rabbitmq_service:
raise RuntimeError("RabbitMQ not configured for this service")
return self.rabbitmq_service

@property
def rabbit_config(self) -> rabbitmq.RabbitMQConfig:
"""Access the RabbitMQ config. Will raise if not configured."""
if not self.rabbitmq_config:
raise RuntimeError("RabbitMQ not configured for this service")
return self.rabbitmq_config

def run_service(self) -> None:
while True:
time.sleep(10)
@@ -94,6 +225,31 @@ class BaseAppService(AppProcess, ABC):

def run(self):
self.shared_event_loop = asyncio.get_event_loop()
if self.use_db:
self.shared_event_loop.run_until_complete(db.connect())
if self.use_redis:
redis.connect()
if self.rabbitmq_config:
logger.info(f"[{self.__class__.__name__}] ⏳ Configuring RabbitMQ...")
self.rabbitmq_service = rabbitmq.AsyncRabbitMQ(self.rabbitmq_config)
self.shared_event_loop.run_until_complete(self.rabbitmq_service.connect())
if self.use_supabase:
from supabase import create_client

secrets = Secrets()
self.supabase = create_client(
secrets.supabase_url, secrets.supabase_service_role_key
)

def cleanup(self):
if self.use_db:
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting DB...")
self.run_and_wait(db.disconnect())
if self.use_redis:
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting Redis...")
redis.disconnect()
if self.rabbitmq_config:
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting RabbitMQ...")


class RemoteCallError(BaseModel):
@@ -112,7 +268,7 @@ EXCEPTION_MAPPING = {
}


class AppService(BaseAppService, ABC):
class FastApiAppService(BaseAppService, ABC):
fastapi_app: FastAPI

@staticmethod
@@ -168,16 +324,14 @@ class AppService(BaseAppService, ABC):

async def async_endpoint(body: RequestBodyModel): # type: ignore #RequestBodyModel being variable
return await f(
**{name: getattr(body, name) for name in type(body).model_fields}
**{name: getattr(body, name) for name in body.model_fields}
)

return async_endpoint
else:

def sync_endpoint(body: RequestBodyModel): # type: ignore #RequestBodyModel being variable
return f(
**{name: getattr(body, name) for name in type(body).model_fields}
)
return f(**{name: getattr(body, name) for name in body.model_fields})

return sync_endpoint

@@ -197,7 +351,6 @@ class AppService(BaseAppService, ABC):
self.shared_event_loop.run_until_complete(server.serve())

def run(self):
sentry_init()
super().run()
self.fastapi_app = FastAPI()

@@ -228,13 +381,62 @@ class AppService(BaseAppService, ABC):
self.run_service()


# ----- Begin Pyro AppService Block ---- #


class PyroAppService(BaseAppService, ABC):

@conn_retry("Pyro", "Starting Pyro Service")
def __start_pyro(self):
maximum_connection_thread_count = max(
Pyro5.config.THREADPOOL_SIZE,
config.num_node_workers * config.num_graph_workers,
)

Pyro5.config.THREADPOOL_SIZE = maximum_connection_thread_count # type: ignore
daemon = Pyro5.api.Daemon(host=api_host, port=self.get_port())
self.uri = daemon.register(self, objectId=self.service_name)
logger.info(f"[{self.service_name}] Connected to Pyro; URI = {self.uri}")
daemon.requestLoop()

def run(self):
super().run()

# Initialize the async loop.
async_thread = threading.Thread(target=self.shared_event_loop.run_forever)
async_thread.daemon = True
async_thread.start()

# Initialize pyro service
daemon_thread = threading.Thread(target=self.__start_pyro)
daemon_thread.daemon = True
daemon_thread.start()

# Run the main service loop (blocking).
self.run_service()


if config.use_http_based_rpc:

class AppService(FastApiAppService, ABC): # type: ignore #AppService defined twice
pass

else:

class AppService(PyroAppService, ABC):
pass


# ----- End Pyro AppService Block ---- #


# --------------------------------------------------
# HTTP Client utilities for dynamic service client abstraction
# --------------------------------------------------
AS = TypeVar("AS", bound=AppService)


def close_service_client(client: Any) -> None:
def fastapi_close_service_client(client: Any) -> None:
if hasattr(client, "close"):
client.close()
else:
@@ -242,7 +444,7 @@ def close_service_client(client: Any) -> None:


@conn_retry("FastAPI client", "Creating service client", max_retry=api_comm_retry)
def get_service_client(
def fastapi_get_service_client(
service_type: Type[AS],
call_timeout: int | None = api_call_timeout,
) -> AS:
@@ -302,3 +504,93 @@ def get_service_client(
client.health_check()

return cast(AS, client)


# ----- Begin Pyro Client Block ---- #
class PyroClient:
proxy: Pyro5.api.Proxy


def pyro_close_service_client(client: BaseAppService) -> None:
if isinstance(client, PyroClient):
client.proxy._pyroRelease()
else:
raise RuntimeError(f"Client {client.__class__} is not a Pyro client.")


def pyro_get_service_client(service_type: Type[AS]) -> AS:
service_name = service_type.service_name

class DynamicClient(PyroClient):
@conn_retry("Pyro", f"Connecting to [{service_name}]")
def __init__(self):
uri = f"PYRO:{service_type.service_name}@{service_type.get_host()}:{service_type.get_port()}"
logger.debug(f"Connecting to service [{service_name}]. URI = {uri}")
self.proxy = Pyro5.api.Proxy(uri)
# Attempt to bind to ensure the connection is established
self.proxy._pyroBind()
logger.debug(f"Successfully connected to service [{service_name}]")

def __getattr__(self, name: str) -> Callable[..., Any]:
res = getattr(self.proxy, name)
return res

return cast(AS, DynamicClient())


builtin_types = [*vars(builtins).values(), NoneType, Enum]


def _pydantic_models_from_type_annotation(annotation) -> Iterator[type[BaseModel]]:
# Peel Annotated parameters
if (origin := get_origin(annotation)) and origin is Annotated:
annotation = get_args(annotation)[0]

origin = get_origin(annotation)
args = get_args(annotation)

if origin in (
Union,
UnionType,
list,
List,
tuple,
Tuple,
set,
Set,
frozenset,
FrozenSet,
):
for arg in args:
yield from _pydantic_models_from_type_annotation(arg)
elif origin in (dict, Dict):
key_type, value_type = args
yield from _pydantic_models_from_type_annotation(key_type)
yield from _pydantic_models_from_type_annotation(value_type)
elif origin in (Awaitable, Coroutine):
# For coroutines and awaitables, check the return type
return_type = args[-1]
yield from _pydantic_models_from_type_annotation(return_type)
else:
annotype = annotation if origin is None else origin

# Exclude generic types and aliases
if (
annotype is not None
and not hasattr(typing, getattr(annotype, "__name__", ""))
and isinstance(annotype, type)
):
if issubclass(annotype, BaseModel):
yield annotype
elif annotype not in builtin_types and not issubclass(annotype, Enum):
raise TypeError(f"Unsupported type encountered: {annotype}")


if config.use_http_based_rpc:
close_service_client = fastapi_close_service_client
get_service_client = fastapi_get_service_client
else:
close_service_client = pyro_close_service_client
get_service_client = pyro_get_service_client

# ----- End Pyro Client Block ---- #

@@ -31,12 +31,12 @@ class UpdateTrackingModel(BaseModel, Generic[T]):
_updated_fields: Set[str] = PrivateAttr(default_factory=set)

def __setattr__(self, name: str, value) -> None:
if name in UpdateTrackingModel.model_fields:
if name in self.model_fields:
self._updated_fields.add(name)
super().__setattr__(name, value)

def mark_updated(self, field_name: str) -> None:
if field_name in UpdateTrackingModel.model_fields:
if field_name in self.model_fields:
self._updated_fields.add(field_name)

def clear_updates(self) -> None:
@@ -65,6 +65,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
le=1000,
description="Maximum number of workers to use for node execution within a single graph.",
)
use_http_based_rpc: bool = Field(
default=True,
description="Whether to use HTTP-based RPC for communication between services.",
)
pyro_host: str = Field(
default="localhost",
description="The default hostname of the Pyro server.",
@@ -137,10 +141,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default=8002,
description="The port for execution manager daemon to run on",
)
execution_manager_loop_max_retry: int = Field(
default=5,
description="The maximum number of retries for the execution manager loop",
)

execution_scheduler_port: int = Field(
default=8003,

@@ -182,7 +182,6 @@ def _try_convert(value: Any, target_type: Type, raise_on_mismatch: bool) -> Any:


T = TypeVar("T")
TT = TypeVar("TT")


def type_match(value: Any, target_type: Type[T]) -> T:

@@ -21,27 +21,18 @@ def run(*command: str) -> None:
)
except subprocess.CalledProcessError as e:
print(e.output.decode("utf-8"), file=sys.stderr)
raise


def lint():
lint_step_args: list[list[str]] = [
["ruff", "check", *TARGET_DIRS, "--exit-zero"],
["ruff", "format", "--diff", "--check", LIBS_DIR],
["isort", "--diff", "--check", "--profile", "black", BACKEND_DIR],
["black", "--diff", "--check", BACKEND_DIR],
["pyright", *TARGET_DIRS],
]
lint_error = None
for args in lint_step_args:
try:
run(*args)
except subprocess.CalledProcessError as e:
lint_error = e

if lint_error:
print("Lint failed, try running `poetry run format` to fix the issues")
sys.exit(1)
try:
run("ruff", "check", *TARGET_DIRS, "--exit-zero")
run("ruff", "format", "--diff", "--check", LIBS_DIR)
run("isort", "--diff", "--check", "--profile", "black", BACKEND_DIR)
run("black", "--diff", "--check", BACKEND_DIR)
run("pyright", *TARGET_DIRS)
except subprocess.CalledProcessError as e:
print("Lint failed, try running `poetry run format` to fix the issues: ", e)
raise e


def format():

@@ -1,50 +0,0 @@
/*
Warnings:
- The relation LibraryAgent:AgentPreset was REMOVED
- A unique constraint covering the columns `[userId,agentGraphId,agentGraphVersion]` on the table `LibraryAgent` will be added. If there are existing duplicate values, this will fail.
- The foreign key constraints on AgentPreset and LibraryAgent are being changed from CASCADE to RESTRICT for AgentGraph deletion, which means you cannot delete AgentGraphs that have associated LibraryAgents or AgentPresets.

Use the following query to check whether these conditions are satisfied:

-- Check for duplicate LibraryAgent userId + agentGraphId + agentGraphVersion combinations that would violate the new unique constraint
SELECT la."userId",
la."agentId" as graph_id,
la."agentVersion" as graph_version,
COUNT(*) as multiplicity
FROM "LibraryAgent" la
GROUP BY la."userId",
la."agentId",
la."agentVersion"
HAVING COUNT(*) > 1;
*/

-- Drop foreign key constraints on columns we're about to rename
ALTER TABLE "AgentPreset" DROP CONSTRAINT "AgentPreset_agentId_agentVersion_fkey";
ALTER TABLE "LibraryAgent" DROP CONSTRAINT "LibraryAgent_agentId_agentVersion_fkey";
ALTER TABLE "LibraryAgent" DROP CONSTRAINT "LibraryAgent_agentPresetId_fkey";

-- Rename columns in AgentPreset
ALTER TABLE "AgentPreset" RENAME COLUMN "agentId" TO "agentGraphId";
ALTER TABLE "AgentPreset" RENAME COLUMN "agentVersion" TO "agentGraphVersion";

-- Rename columns in LibraryAgent
ALTER TABLE "LibraryAgent" RENAME COLUMN "agentId" TO "agentGraphId";
ALTER TABLE "LibraryAgent" RENAME COLUMN "agentVersion" TO "agentGraphVersion";

-- Drop LibraryAgent.agentPresetId column
ALTER TABLE "LibraryAgent" DROP COLUMN "agentPresetId";

-- Replace userId index with unique index on userId + agentGraphId + agentGraphVersion
DROP INDEX "LibraryAgent_userId_idx";
CREATE UNIQUE INDEX "LibraryAgent_userId_agentGraphId_agentGraphVersion_key" ON "LibraryAgent"("userId", "agentGraphId", "agentGraphVersion");

-- Re-add the foreign key constraints with new column names
ALTER TABLE "LibraryAgent" ADD CONSTRAINT "LibraryAgent_agentGraphId_agentGraphVersion_fkey"
FOREIGN KEY ("agentGraphId", "agentGraphVersion") REFERENCES "AgentGraph"("id", "version")
ON DELETE RESTRICT -- Disallow deleting AgentGraph when still referenced by existing LibraryAgents
ON UPDATE CASCADE;

ALTER TABLE "AgentPreset" ADD CONSTRAINT "AgentPreset_agentGraphId_agentGraphVersion_fkey"
FOREIGN KEY ("agentGraphId", "agentGraphVersion") REFERENCES "AgentGraph"("id", "version")
ON DELETE RESTRICT -- Disallow deleting AgentGraph when still referenced by existing AgentPresets
ON UPDATE CASCADE;
Some files were not shown because too many files have changed in this diff