Merge branch 'dev' into abhi-9274/postgres-integration

This commit is contained in:
Abhimanyu Yadav
2025-05-05 13:44:14 +05:30
committed by GitHub
265 changed files with 7039 additions and 3371 deletions

View File

@@ -34,6 +34,7 @@ jobs:
python -m prisma migrate deploy
env:
DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}
trigger:

View File

@@ -36,6 +36,7 @@ jobs:
python -m prisma migrate deploy
env:
DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}
trigger:
needs: migrate

View File

@@ -135,6 +135,7 @@ jobs:
run: poetry run prisma migrate dev --name updates
env:
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
- id: lint
name: Run Linter
@@ -151,12 +152,13 @@ jobs:
env:
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
REDIS_HOST: 'localhost'
REDIS_PORT: '6379'
REDIS_PASSWORD: 'testpassword'
REDIS_HOST: "localhost"
REDIS_PORT: "6379"
REDIS_PASSWORD: "testpassword"
env:
CI: true
@@ -169,8 +171,8 @@ jobs:
# If you want to replace this, you can do so by making our entire system generate
# new credentials for each local user and update the environment variables in
# the backend service, docker composes, and examples
RABBITMQ_DEFAULT_USER: 'rabbitmq_user_default'
RABBITMQ_DEFAULT_PASS: 'k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7'
RABBITMQ_DEFAULT_USER: "rabbitmq_user_default"
RABBITMQ_DEFAULT_PASS: "k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7"
# - name: Upload coverage reports to Codecov
# uses: codecov/codecov-action@v4

View File

@@ -16,7 +16,7 @@ jobs:
# operations-per-run: 5000
stale-issue-message: >
This issue has automatically been marked as _stale_ because it has not had
any activity in the last 50 days. You can _unstale_ it by commenting or
any activity in the last 170 days. You can _unstale_ it by commenting or
removing the label. Otherwise, this issue will be closed in 10 days.
stale-pr-message: >
This pull request has automatically been marked as _stale_ because it has
@@ -25,7 +25,7 @@ jobs:
close-issue-message: >
This issue was closed automatically because it has been stale for 10 days
with no activity.
days-before-stale: 100
days-before-stale: 170
days-before-close: 10
# Do not touch meta issues:
exempt-issue-labels: meta,fridge,project management

View File

@@ -15,7 +15,11 @@
> Setting up and hosting the AutoGPT Platform yourself is a technical process.
> If you'd rather something that just works, we recommend [joining the waitlist](https://bit.ly/3ZDijAI) for the cloud-hosted beta.
https://github.com/user-attachments/assets/d04273a5-b36a-4a37-818e-f631ce72d603
### Updated Setup Instructions:
Weve moved to a fully maintained and regularly updated documentation site.
👉 [Follow the official self-hosting guide here](https://docs.agpt.co/platform/getting-started/)
This tutorial assumes you have Docker, VSCode, git and npm installed.

View File

@@ -8,7 +8,7 @@ from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from .filters import BelowLevelFilter
from .formatters import AGPTFormatter, StructuredLoggingFormatter
from .formatters import AGPTFormatter
LOG_DIR = Path(__file__).parent.parent.parent.parent / "logs"
LOG_FILE = "activity.log"
@@ -81,9 +81,26 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
"""
config = LoggingConfig()
log_handlers: list[logging.Handler] = []
# Console output handlers
stdout = logging.StreamHandler(stream=sys.stdout)
stdout.setLevel(config.level)
stdout.addFilter(BelowLevelFilter(logging.WARNING))
if config.level == logging.DEBUG:
stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
else:
stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
stderr = logging.StreamHandler()
stderr.setLevel(logging.WARNING)
if config.level == logging.DEBUG:
stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
else:
stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
log_handlers += [stdout, stderr]
# Cloud logging setup
if config.enable_cloud_logging or force_cloud_logging:
import google.cloud.logging
@@ -97,26 +114,7 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
transport=SyncTransport,
)
cloud_handler.setLevel(config.level)
cloud_handler.setFormatter(StructuredLoggingFormatter())
log_handlers.append(cloud_handler)
else:
# Console output handlers
stdout = logging.StreamHandler(stream=sys.stdout)
stdout.setLevel(config.level)
stdout.addFilter(BelowLevelFilter(logging.WARNING))
if config.level == logging.DEBUG:
stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
else:
stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
stderr = logging.StreamHandler()
stderr.setLevel(logging.WARNING)
if config.level == logging.DEBUG:
stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
else:
stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
log_handlers += [stdout, stderr]
# File logging setup
if config.enable_file_logging:

View File

@@ -1,7 +1,6 @@
import logging
from colorama import Fore, Style
from google.cloud.logging_v2.handlers import CloudLoggingFilter, StructuredLogHandler
from .utils import remove_color_codes
@@ -80,16 +79,3 @@ class AGPTFormatter(FancyConsoleFormatter):
return remove_color_codes(super().format(record))
else:
return super().format(record)
class StructuredLoggingFormatter(StructuredLogHandler, logging.Formatter):
def __init__(self):
# Set up CloudLoggingFilter to add diagnostic info to the log records
self.cloud_logging_filter = CloudLoggingFilter()
# Init StructuredLogHandler
super().__init__()
def format(self, record: logging.LogRecord) -> str:
self.cloud_logging_filter.filter(record)
return super().format(record)

View File

@@ -2,6 +2,7 @@ import logging
import re
from typing import Any
import uvicorn.config
from colorama import Fore
@@ -25,3 +26,14 @@ def print_attribute(
"color": value_color,
},
)
def generate_uvicorn_config():
"""
Generates a uvicorn logging config that silences uvicorn's default logging and tells it to use the native logging module.
"""
log_config = dict(uvicorn.config.LOGGING_CONFIG)
log_config["loggers"]["uvicorn"] = {"handlers": []}
log_config["loggers"]["uvicorn.error"] = {"handlers": []}
log_config["loggers"]["uvicorn.access"] = {"handlers": []}
return log_config

View File

@@ -1,20 +1,59 @@
import inspect
import threading
from typing import Callable, ParamSpec, TypeVar
from typing import Awaitable, Callable, ParamSpec, TypeVar, cast, overload
P = ParamSpec("P")
R = TypeVar("R")
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
@overload
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...
@overload
def thread_cached(func: Callable[P, R]) -> Callable[P, R]: ...
def thread_cached(
func: Callable[P, R] | Callable[P, Awaitable[R]],
) -> Callable[P, R] | Callable[P, Awaitable[R]]:
thread_local = threading.local()
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = (args, tuple(sorted(kwargs.items())))
if key not in cache:
cache[key] = func(*args, **kwargs)
return cache[key]
def _clear():
if hasattr(thread_local, "cache"):
del thread_local.cache
return wrapper
if inspect.iscoroutinefunction(func):
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = (args, tuple(sorted(kwargs.items())))
if key not in cache:
cache[key] = await cast(Callable[P, Awaitable[R]], func)(
*args, **kwargs
)
return cache[key]
setattr(async_wrapper, "clear_cache", _clear)
return async_wrapper
else:
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = (args, tuple(sorted(kwargs.items())))
if key not in cache:
cache[key] = func(*args, **kwargs)
return cache[key]
setattr(sync_wrapper, "clear_cache", _clear)
return sync_wrapper
def clear_thread_cache(func: Callable) -> None:
if clear := getattr(func, "clear_cache", None):
clear()

View File

@@ -31,7 +31,7 @@ class RedisKeyedMutex:
try:
yield
finally:
if lock.locked():
if lock.locked() and lock.owned():
lock.release()
def acquire(self, key: Any) -> "RedisLock":

View File

@@ -8,6 +8,7 @@ DB_CONNECT_TIMEOUT=60
DB_POOL_TIMEOUT=300
DB_SCHEMA=platform
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
PRISMA_SCHEMA="postgres/schema.prisma"
# EXECUTOR

View File

@@ -73,7 +73,6 @@ FROM server_dependencies AS server
COPY autogpt_platform/backend /app/autogpt_platform/backend
RUN poetry install --no-ansi --only-root
ENV DATABASE_URL=""
ENV PORT=8000
CMD ["poetry", "run", "rest"]

View File

@@ -1,8 +1,6 @@
import logging
from typing import Any
from autogpt_libs.utils.cache import thread_cached
from backend.data.block import (
Block,
BlockCategory,
@@ -19,21 +17,6 @@ from backend.util import json
logger = logging.getLogger(__name__)
@thread_cached
def get_executor_manager_client():
from backend.executor import ExecutionManager
from backend.util.service import get_service_client
return get_service_client(ExecutionManager)
@thread_cached
def get_event_bus():
from backend.data.execution import RedisExecutionEventBus
return RedisExecutionEventBus()
class AgentExecutorBlock(Block):
class Input(BlockSchema):
user_id: str = SchemaField(description="User ID")
@@ -76,23 +59,23 @@ class AgentExecutorBlock(Block):
def run(self, input_data: Input, **kwargs) -> BlockOutput:
from backend.data.execution import ExecutionEventType
from backend.executor import utils as execution_utils
executor_manager = get_executor_manager_client()
event_bus = get_event_bus()
event_bus = execution_utils.get_execution_event_bus()
graph_exec = executor_manager.add_execution(
graph_exec = execution_utils.add_graph_execution(
graph_id=input_data.graph_id,
graph_version=input_data.graph_version,
user_id=input_data.user_id,
data=input_data.data,
inputs=input_data.data,
)
log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.graph_exec_id}"
log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.id}"
logger.info(f"Starting execution of {log_id}")
for event in event_bus.listen(
user_id=graph_exec.user_id,
graph_id=graph_exec.graph_id,
graph_exec_id=graph_exec.graph_exec_id,
graph_exec_id=graph_exec.id,
):
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
if event.status in [
@@ -105,7 +88,7 @@ class AgentExecutorBlock(Block):
else:
continue
logger.info(
logger.debug(
f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
)
@@ -123,5 +106,7 @@ class AgentExecutorBlock(Block):
continue
for output_data in event.output_data.get("output", []):
logger.info(f"Execution {log_id} produced {output_name}: {output_data}")
logger.debug(
f"Execution {log_id} produced {output_name}: {output_data}"
)
yield output_name, output_data

View File

@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from backend.data.model import SchemaField
@@ -143,11 +143,12 @@ class ContactEmail(BaseModel):
class EmploymentHistory(BaseModel):
"""An employment history in Apollo"""
class Config:
extra = "allow"
arbitrary_types_allowed = True
from_attributes = True
populate_by_name = True
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
from_attributes=True,
populate_by_name=True,
)
_id: Optional[str] = None
created_at: Optional[str] = None
@@ -188,11 +189,12 @@ class TypedCustomField(BaseModel):
class Pagination(BaseModel):
"""Pagination in Apollo"""
class Config:
extra = "allow" # Allow extra fields
arbitrary_types_allowed = True # Allow any type
from_attributes = True # Allow from_orm
populate_by_name = True # Allow field aliases to work both ways
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
from_attributes=True,
populate_by_name=True,
)
page: int = 0
per_page: int = 0
@@ -230,11 +232,12 @@ class PhoneNumber(BaseModel):
class Organization(BaseModel):
"""An organization in Apollo"""
class Config:
extra = "allow"
arbitrary_types_allowed = True
from_attributes = True
populate_by_name = True
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
from_attributes=True,
populate_by_name=True,
)
id: Optional[str] = "N/A"
name: Optional[str] = "N/A"
@@ -268,11 +271,12 @@ class Organization(BaseModel):
class Contact(BaseModel):
"""A contact in Apollo"""
class Config:
extra = "allow"
arbitrary_types_allowed = True
from_attributes = True
populate_by_name = True
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
from_attributes=True,
populate_by_name=True,
)
contact_roles: list[Any] = []
id: Optional[str] = None
@@ -369,14 +373,14 @@ If a company has several office locations, results are still based on the headqu
To exclude companies based on location, use the organization_not_locations parameter.
""",
default=[],
default_factory=list,
)
organizations_not_locations: list[str] = SchemaField(
description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.
This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
""",
default=[],
default_factory=list,
)
q_organization_keyword_tags: list[str] = SchemaField(
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry."""
@@ -390,7 +394,7 @@ If the value you enter for this parameter does not match with a company's name,
description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
To find IDs, identify the values for organization_id when you call this endpoint.""",
default=[],
default_factory=list,
)
max_results: int = SchemaField(
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
@@ -443,14 +447,14 @@ Results also include job titles with the same terms, even if they are not exact
Use this parameter in combination with the person_seniorities[] parameter to find people based on specific job functions and seniority levels.
""",
default=[],
default_factory=list,
placeholder="marketing manager",
)
person_locations: list[str] = SchemaField(
description="""The location where people live. You can search across cities, US states, and countries.
To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
default=[],
default_factory=list,
)
person_seniorities: list[SenorityLevels] = SchemaField(
description="""The job seniority that people hold within their current employer. This enables you to find people that currently hold positions at certain reporting levels, such as Director level or senior IC level.
@@ -460,7 +464,7 @@ For a person to be included in search results, they only need to match 1 of the
Searches only return results based on their current job title, so searching for Director-level employees only returns people that currently hold a Director-level title. If someone was previously a Director, but is currently a VP, they would not be included in your search results.
Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
default=[],
default_factory=list,
)
organization_locations: list[str] = SchemaField(
description="""The location of the company headquarters for a person's current employer. You can search across cities, US states, and countries.
@@ -468,7 +472,7 @@ Use this parameter in combination with the person_titles[] parameter to find peo
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.
To find people based on their personal location, use the person_locations parameter.""",
default=[],
default_factory=list,
)
q_organization_domains: list[str] = SchemaField(
description="""The domain name for the person's employer. This can be the current employer or a previous employer. Do not include www., the @ symbol, or similar.
@@ -476,23 +480,23 @@ To find people based on their personal location, use the person_locations parame
You can add multiple domains to search across companies.
Examples: apollo.io and microsoft.com""",
default=[],
default_factory=list,
)
contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
default=[],
default_factory=list,
)
organization_ids: list[str] = SchemaField(
description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
default=[],
default_factory=list,
)
organization_num_empoloyees_range: list[int] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
default=[],
default_factory=list,
)
q_keywords: str = SchemaField(
description="""A string of words over which we want to filter the results""",
@@ -522,11 +526,12 @@ Use the page parameter to search the different pages of data.""",
class SearchPeopleResponse(BaseModel):
"""Response from Apollo's search people API"""
class Config:
extra = "allow" # Allow extra fields
arbitrary_types_allowed = True # Allow any type
from_attributes = True # Allow from_orm
populate_by_name = True # Allow field aliases to work both ways
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
from_attributes=True,
populate_by_name=True,
)
breadcrumbs: list[Breadcrumb] = []
partial_results_only: bool = True

View File

@@ -32,18 +32,18 @@ If a company has several office locations, results are still based on the headqu
To exclude companies based on location, use the organization_not_locations parameter.
""",
default=[],
default_factory=list,
)
organizations_not_locations: list[str] = SchemaField(
description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.
This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
""",
default=[],
default_factory=list,
)
q_organization_keyword_tags: list[str] = SchemaField(
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry.""",
default=[],
default_factory=list,
)
q_organization_name: str = SchemaField(
description="""Filter search results to include a specific company name.
@@ -56,7 +56,7 @@ If the value you enter for this parameter does not match with a company's name,
description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
To find IDs, identify the values for organization_id when you call this endpoint.""",
default=[],
default_factory=list,
)
max_results: int = SchemaField(
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
@@ -72,7 +72,7 @@ To find IDs, identify the values for organization_id when you call this endpoint
class Output(BlockSchema):
organizations: list[Organization] = SchemaField(
description="List of organizations found",
default=[],
default_factory=list,
)
organization: Organization = SchemaField(
description="Each found organization, one at a time",

View File

@@ -26,14 +26,14 @@ class SearchPeopleBlock(Block):
Use this parameter in combination with the person_seniorities[] parameter to find people based on specific job functions and seniority levels.
""",
default=[],
default_factory=list,
advanced=False,
)
person_locations: list[str] = SchemaField(
description="""The location where people live. You can search across cities, US states, and countries.
To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
default=[],
default_factory=list,
advanced=False,
)
person_seniorities: list[SenorityLevels] = SchemaField(
@@ -44,7 +44,7 @@ class SearchPeopleBlock(Block):
Searches only return results based on their current job title, so searching for Director-level employees only returns people that currently hold a Director-level title. If someone was previously a Director, but is currently a VP, they would not be included in your search results.
Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
default=[],
default_factory=list,
advanced=False,
)
organization_locations: list[str] = SchemaField(
@@ -53,7 +53,7 @@ class SearchPeopleBlock(Block):
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.
To find people based on their personal location, use the person_locations parameter.""",
default=[],
default_factory=list,
advanced=False,
)
q_organization_domains: list[str] = SchemaField(
@@ -62,26 +62,26 @@ class SearchPeopleBlock(Block):
You can add multiple domains to search across companies.
Examples: apollo.io and microsoft.com""",
default=[],
default_factory=list,
advanced=False,
)
contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
default=[],
default_factory=list,
advanced=False,
)
organization_ids: list[str] = SchemaField(
description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
default=[],
default_factory=list,
advanced=False,
)
organization_num_empoloyees_range: list[int] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
default=[],
default_factory=list,
advanced=False,
)
q_keywords: str = SchemaField(
@@ -104,7 +104,7 @@ class SearchPeopleBlock(Block):
class Output(BlockSchema):
people: list[Contact] = SchemaField(
description="List of people found",
default=[],
default_factory=list,
)
person: Contact = SchemaField(
description="Each found person, one at a time",

View File

@@ -88,6 +88,33 @@ class StoreValueBlock(Block):
yield "output", input_data.data or input_data.input
class PrintToConsoleBlock(Block):
class Input(BlockSchema):
text: Any = SchemaField(description="The data to print to the console.")
class Output(BlockSchema):
output: Any = SchemaField(description="The data printed to the console.")
status: str = SchemaField(description="The status of the print operation.")
def __init__(self):
super().__init__(
id="f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c",
description="Print the given text to the console, this is used for a debugging purpose.",
categories={BlockCategory.BASIC},
input_schema=PrintToConsoleBlock.Input,
output_schema=PrintToConsoleBlock.Output,
test_input={"text": "Hello, World!"},
test_output=[
("output", "Hello, World!"),
("status", "printed"),
],
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
yield "output", input_data.text
yield "status", "printed"
class FindInDictionaryBlock(Block):
class Input(BlockSchema):
input: Any = SchemaField(description="Dictionary to lookup from")
@@ -151,7 +178,7 @@ class FindInDictionaryBlock(Block):
class AddToDictionaryBlock(Block):
class Input(BlockSchema):
dictionary: dict[Any, Any] = SchemaField(
default={},
default_factory=dict,
description="The dictionary to add the entry to. If not provided, a new dictionary will be created.",
)
key: str = SchemaField(
@@ -167,7 +194,7 @@ class AddToDictionaryBlock(Block):
advanced=False,
)
entries: dict[Any, Any] = SchemaField(
default={},
default_factory=dict,
description="The entries to add to the dictionary. This is the batch version of the `key` and `value` fields.",
advanced=True,
)
@@ -229,7 +256,7 @@ class AddToDictionaryBlock(Block):
class AddToListBlock(Block):
class Input(BlockSchema):
list: List[Any] = SchemaField(
default=[],
default_factory=list,
advanced=False,
description="The list to add the entry to. If not provided, a new list will be created.",
)
@@ -239,7 +266,7 @@ class AddToListBlock(Block):
default=None,
)
entries: List[Any] = SchemaField(
default=[],
default_factory=lambda: list(),
description="The entries to add to the list. This is the batch version of the `entry` field.",
advanced=True,
)

View File

@@ -55,7 +55,7 @@ class CodeExecutionBlock(Block):
"These commands are executed with `sh`, in the foreground."
),
placeholder="pip install cowsay",
default=[],
default_factory=list,
advanced=False,
)
@@ -207,7 +207,7 @@ class InstantiationBlock(Block):
"These commands are executed with `sh`, in the foreground."
),
placeholder="pip install cowsay",
default=[],
default_factory=list,
advanced=False,
)

View File

@@ -34,7 +34,7 @@ class ReadCsvBlock(Block):
)
skip_columns: list[str] = SchemaField(
description="The columns to skip from the start of the row",
default=[],
default_factory=list,
)
class Output(BlockSchema):

View File

@@ -49,7 +49,7 @@ class ExaContentsBlock(Block):
class Output(BlockSchema):
results: list = SchemaField(
description="List of document contents",
default=[],
default_factory=list,
)
error: str = SchemaField(description="Error message if the request failed")

View File

@@ -38,11 +38,11 @@ class ExaSearchBlock(Block):
)
include_domains: List[str] = SchemaField(
description="Domains to include in search",
default=[],
default_factory=list,
)
exclude_domains: List[str] = SchemaField(
description="Domains to exclude from search",
default=[],
default_factory=list,
advanced=True,
)
start_crawl_date: datetime = SchemaField(
@@ -59,12 +59,12 @@ class ExaSearchBlock(Block):
)
include_text: List[str] = SchemaField(
description="Text patterns to include",
default=[],
default_factory=list,
advanced=True,
)
exclude_text: List[str] = SchemaField(
description="Text patterns to exclude",
default=[],
default_factory=list,
advanced=True,
)
contents: ContentSettings = SchemaField(
@@ -76,7 +76,7 @@ class ExaSearchBlock(Block):
class Output(BlockSchema):
results: list = SchemaField(
description="List of search results",
default=[],
default_factory=list,
)
def __init__(self):

View File

@@ -26,12 +26,12 @@ class ExaFindSimilarBlock(Block):
)
include_domains: List[str] = SchemaField(
description="Domains to include in search",
default=[],
default_factory=list,
advanced=True,
)
exclude_domains: List[str] = SchemaField(
description="Domains to exclude from search",
default=[],
default_factory=list,
advanced=True,
)
start_crawl_date: datetime = SchemaField(
@@ -48,12 +48,12 @@ class ExaFindSimilarBlock(Block):
)
include_text: List[str] = SchemaField(
description="Text patterns to include (max 1 string, up to 5 words)",
default=[],
default_factory=list,
advanced=True,
)
exclude_text: List[str] = SchemaField(
description="Text patterns to exclude (max 1 string, up to 5 words)",
default=[],
default_factory=list,
advanced=True,
)
contents: ContentSettings = SchemaField(
@@ -65,7 +65,7 @@ class ExaFindSimilarBlock(Block):
class Output(BlockSchema):
results: List[Any] = SchemaField(
description="List of similar documents with title, URL, published date, author, and score",
default=[],
default_factory=list,
)
def __init__(self):

View File

@@ -42,7 +42,7 @@ class AIVideoGeneratorBlock(Block):
description="Error message if video generation failed."
)
logs: list[str] = SchemaField(
description="Generation progress logs.", optional=True
description="Generation progress logs.",
)
def __init__(self):

View File

@@ -12,10 +12,10 @@ from backend.integrations.webhooks.generic import GenericWebhookType
class GenericWebhookTriggerBlock(Block):
class Input(BlockSchema):
payload: dict = SchemaField(hidden=True, default={})
payload: dict = SchemaField(hidden=True, default_factory=dict)
constants: dict = SchemaField(
description="The constants to be set when the block is put on the graph",
default={},
default_factory=dict,
)
class Output(BlockSchema):

View File

@@ -37,7 +37,7 @@ class GitHubTriggerBase:
placeholder="{owner}/{repo}",
)
# --8<-- [start:example-payload-field]
payload: dict = SchemaField(hidden=True, default={})
payload: dict = SchemaField(hidden=True, default_factory=dict)
# --8<-- [end:example-payload-field]
class Output(BlockSchema):

View File

@@ -3,7 +3,7 @@ from googleapiclient.discovery import build
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.settings import Settings
from backend.util.settings import AppEnvironment, Settings
from ._auth import (
GOOGLE_OAUTH_IS_CONFIGURED,
@@ -36,13 +36,15 @@ class GoogleSheetsReadBlock(Block):
)
def __init__(self):
settings = Settings()
super().__init__(
id="5724e902-3635-47e9-a108-aaa0263a4988",
description="This block reads data from a Google Sheets spreadsheet.",
categories={BlockCategory.DATA},
input_schema=GoogleSheetsReadBlock.Input,
output_schema=GoogleSheetsReadBlock.Output,
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
disabled=not GOOGLE_OAUTH_IS_CONFIGURED
or settings.config.app_env == AppEnvironment.PRODUCTION,
test_input={
"spreadsheet_id": "1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms",
"range": "Sheet1!A1:B2",

View File

@@ -34,7 +34,7 @@ class SendWebRequestBlock(Block):
)
headers: dict[str, str] = SchemaField(
description="The headers to include in the request",
default={},
default_factory=dict,
)
json_format: bool = SchemaField(
title="JSON format",
@@ -82,7 +82,15 @@ class SendWebRequestBlock(Block):
json=body if input_data.json_format else None,
data=body if not input_data.json_format else None,
)
result = response.json() if input_data.json_format else response.text
if input_data.json_format:
if response.status_code == 204 or not response.content.strip():
result = None
else:
result = response.json()
else:
result = response.text
yield "response", result
except HTTPError as e:

View File

@@ -15,7 +15,8 @@ class HubSpotCompanyBlock(Block):
description="Operation to perform (create, update, get)", default="get"
)
company_data: dict = SchemaField(
description="Company data for create/update operations", default={}
description="Company data for create/update operations",
default_factory=dict,
)
domain: str = SchemaField(
description="Company domain for get/update operations", default=""

View File

@@ -15,7 +15,8 @@ class HubSpotContactBlock(Block):
description="Operation to perform (create, update, get)", default="get"
)
contact_data: dict = SchemaField(
description="Contact data for create/update operations", default={}
description="Contact data for create/update operations",
default_factory=dict,
)
email: str = SchemaField(
description="Email address for get/update operations", default=""

View File

@@ -19,7 +19,7 @@ class HubSpotEngagementBlock(Block):
)
email_data: dict = SchemaField(
description="Email data including recipient, subject, content",
default={},
default_factory=dict,
)
contact_id: str = SchemaField(
description="Contact ID for engagement tracking", default=""
@@ -27,7 +27,6 @@ class HubSpotEngagementBlock(Block):
timeframe_days: int = SchemaField(
description="Number of days to look back for engagement",
default=30,
optional=True,
)
class Output(BlockSchema):

View File

@@ -1,3 +1,4 @@
import copy
from datetime import date, time
from typing import Any, Optional
@@ -38,7 +39,7 @@ class AgentInputBlock(Block):
)
placeholder_values: list = SchemaField(
description="The placeholder values to be passed as input.",
default=[],
default_factory=list,
advanced=True,
hidden=True,
)
@@ -54,7 +55,7 @@ class AgentInputBlock(Block):
)
def generate_schema(self):
schema = self.get_field_schema("value")
schema = copy.deepcopy(self.get_field_schema("value"))
if possible_values := self.placeholder_values:
schema["enum"] = possible_values
return schema
@@ -467,7 +468,7 @@ class AgentDropdownInputBlock(AgentInputBlock):
)
placeholder_values: list = SchemaField(
description="Possible values for the dropdown.",
default=[],
default_factory=list,
advanced=False,
title="Dropdown Options",
)

View File

@@ -11,13 +11,13 @@ class StepThroughItemsBlock(Block):
advanced=False,
description="The list or dictionary of items to iterate over",
placeholder="[1, 2, 3, 4, 5] or {'key1': 'value1', 'key2': 'value2'}",
default=[],
default_factory=list,
)
items_object: dict = SchemaField(
advanced=False,
description="The list or dictionary of items to iterate over",
placeholder="[1, 2, 3, 4, 5] or {'key1': 'value1', 'key2': 'value2'}",
default={},
default_factory=dict,
)
items_str: str = SchemaField(
advanced=False,

View File

@@ -23,7 +23,7 @@ class JinaChunkingBlock(Block):
class Output(BlockSchema):
chunks: list = SchemaField(description="List of chunked texts")
tokens: list = SchemaField(
description="List of token information for each chunk", optional=True
description="List of token information for each chunk",
)
def __init__(self):

View File

@@ -1,4 +1,4 @@
from groq._utils._utils import quote
from urllib.parse import quote
from backend.blocks.jina._auth import (
TEST_CREDENTIALS,

View File

@@ -28,8 +28,8 @@ class LinearCreateIssueBlock(Block):
priority: int | None = SchemaField(
description="Priority of the issue",
default=None,
minimum=0,
maximum=4,
ge=0,
le=4,
)
project_name: str | None = SchemaField(
description="Name of the project to create the issue on",

View File

@@ -4,30 +4,24 @@ from abc import ABC
from enum import Enum, EnumMeta
from json import JSONDecodeError
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Iterable, List, Literal, NamedTuple, Optional
from pydantic import BaseModel, SecretStr
from backend.data.model import NodeExecutionStats
from backend.integrations.providers import ProviderName
if TYPE_CHECKING:
from enum import _EnumMemberT
from typing import Any, Iterable, List, Literal, NamedTuple, Optional
import anthropic
import ollama
import openai
from anthropic._types import NotGiven
from anthropic.types import ToolParam
from groq import Groq
from pydantic import BaseModel, SecretStr
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.settings import BehaveAs, Settings
from backend.util.text import TextFormatter
@@ -77,12 +71,10 @@ class ModelMetadata(NamedTuple):
class LlmModelMeta(EnumMeta):
@property
def __members__(
self: type["_EnumMemberT"],
) -> MappingProxyType[str, "_EnumMemberT"]:
def __members__(self) -> MappingProxyType:
if Settings().config.behave_as == BehaveAs.LOCAL:
members = super().__members__
return members
return MappingProxyType(members)
else:
removed_providers = ["ollama"]
existing_members = super().__members__
@@ -97,14 +89,17 @@ class LlmModelMeta(EnumMeta):
class LlmModel(str, Enum, metaclass=LlmModelMeta):
# OpenAI models
O3_MINI = "o3-mini"
O3 = "o3-2025-04-16"
O1 = "o1"
O1_PREVIEW = "o1-preview"
O1_MINI = "o1-mini"
GPT41 = "gpt-4.1-2025-04-14"
GPT4O_MINI = "gpt-4o-mini"
GPT4O = "gpt-4o"
GPT4_TURBO = "gpt-4-turbo"
GPT3_5_TURBO = "gpt-3.5-turbo"
# Anthropic models
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest"
CLAUDE_3_5_HAIKU = "claude-3-5-haiku-latest"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
@@ -125,6 +120,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
OLLAMA_DOLPHIN = "dolphin-mistral:latest"
# OpenRouter models
GEMINI_FLASH_1_5 = "google/gemini-flash-1.5"
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
GROK_BETA = "x-ai/grok-beta"
MISTRAL_NEMO = "mistralai/mistral-nemo"
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
@@ -142,6 +138,8 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
@property
def metadata(self) -> ModelMetadata:
@@ -162,12 +160,14 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
MODEL_METADATA = {
# https://platform.openai.com/docs/models
LlmModel.O3: ModelMetadata("openai", 200000, 100000),
LlmModel.O3_MINI: ModelMetadata("openai", 200000, 100000), # o3-mini-2025-01-31
LlmModel.O1: ModelMetadata("openai", 200000, 100000), # o1-2024-12-17
LlmModel.O1_PREVIEW: ModelMetadata(
"openai", 128000, 32768
), # o1-preview-2024-09-12
LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536), # o1-mini-2024-09-12
LlmModel.GPT41: ModelMetadata("openai", 1047576, 32768),
LlmModel.GPT4O_MINI: ModelMetadata(
"openai", 128000, 16384
), # gpt-4o-mini-2024-07-18
@@ -177,6 +177,9 @@ MODEL_METADATA = {
), # gpt-4-turbo-2024-04-09
LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, 4096), # gpt-3.5-turbo-0125
# https://docs.anthropic.com/en/docs/about-claude/models
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
"anthropic", 200000, 8192
), # claude-3-7-sonnet-20250219
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata(
"anthropic", 200000, 8192
), # claude-3-5-sonnet-20241022
@@ -202,6 +205,7 @@ MODEL_METADATA = {
LlmModel.OLLAMA_DOLPHIN: ModelMetadata("ollama", 32768, None),
# https://openrouter.ai/models
LlmModel.GEMINI_FLASH_1_5: ModelMetadata("open_router", 1000000, 8192),
LlmModel.GEMINI_2_5_PRO: ModelMetadata("open_router", 1050000, 8192),
LlmModel.GROK_BETA: ModelMetadata("open_router", 131072, 131072),
LlmModel.MISTRAL_NEMO: ModelMetadata("open_router", 128000, 4096),
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata("open_router", 128000, 4096),
@@ -223,6 +227,8 @@ MODEL_METADATA = {
LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata("open_router", 300000, 5120),
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata("open_router", 65536, 4096),
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata("open_router", 4096, 4096),
LlmModel.META_LLAMA_4_SCOUT: ModelMetadata("open_router", 131072, 131072),
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata("open_router", 1048576, 1000000),
}
for model in LlmModel:
@@ -252,7 +258,7 @@ class LLMResponse(BaseModel):
def convert_openai_tool_fmt_to_anthropic(
openai_tools: list[dict] | None = None,
) -> Iterable[ToolParam] | NotGiven:
) -> Iterable[ToolParam] | anthropic.NotGiven:
"""
Convert OpenAI tool format to Anthropic tool format.
"""
@@ -282,6 +288,13 @@ def convert_openai_tool_fmt_to_anthropic(
return anthropic_tools
def estimate_token_count(prompt_messages: list[dict]) -> int:
char_count = sum(len(str(msg.get("content", ""))) for msg in prompt_messages)
message_overhead = len(prompt_messages) * 4
estimated_tokens = (char_count // 4) + message_overhead
return int(estimated_tokens * 1.2)
def llm_call(
credentials: APIKeyCredentials,
llm_model: LlmModel,
@@ -290,6 +303,7 @@ def llm_call(
max_tokens: int | None,
tools: list[dict] | None = None,
ollama_host: str = "localhost:11434",
parallel_tool_calls: bool | None = None,
) -> LLMResponse:
"""
Make a call to a language model.
@@ -312,7 +326,14 @@ def llm_call(
- completion_tokens: The number of tokens used in the completion.
"""
provider = llm_model.metadata.provider
max_tokens = max_tokens or llm_model.max_output_tokens or 4096
# Calculate available tokens based on context window and input length
estimated_input_tokens = estimate_token_count(prompt)
context_window = llm_model.context_window
model_max_output = llm_model.max_output_tokens or 4096
user_max = max_tokens or model_max_output
available_tokens = max(context_window - estimated_input_tokens, 0)
max_tokens = max(min(available_tokens, model_max_output, user_max), 0)
if provider == "openai":
tools_param = tools if tools else openai.NOT_GIVEN
@@ -335,6 +356,9 @@ def llm_call(
response_format=response_format, # type: ignore
max_completion_tokens=max_tokens,
tools=tools_param, # type: ignore
parallel_tool_calls=(
openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
),
)
if response.choices[0].message.tool_calls:
@@ -424,7 +448,7 @@ def llm_call(
response=(
resp.content[0].name
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
else resp.content[0].text
else getattr(resp.content[0], "text", "")
),
tool_calls=tool_calls,
prompt_tokens=resp.usage.input_tokens,
@@ -465,6 +489,7 @@ def llm_call(
model=llm_model.value,
prompt=f"{sys_messages}\n\n{usr_messages}",
stream=False,
options={"num_ctx": max_tokens},
)
return LLMResponse(
raw_response=response.get("response") or "",
@@ -490,6 +515,9 @@ def llm_call(
messages=prompt, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore
parallel_tool_calls=(
openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
),
)
# If there's no response, raise an error
@@ -528,7 +556,7 @@ def llm_call(
class AIBlockBase(Block, ABC):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.prompt = ""
self.prompt = []
def merge_llm_stats(self, block: "AIBlockBase"):
self.merge_stats(block.execution_stats)
@@ -558,7 +586,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
description="The system prompt to provide additional context to the model.",
)
conversation_history: list[dict] = SchemaField(
default=[],
default_factory=list,
description="The conversation history to provide context for the prompt.",
)
retry: int = SchemaField(
@@ -568,7 +596,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
)
prompt_values: dict[str, str] = SchemaField(
advanced=False,
default={},
default_factory=dict,
description="Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}.",
)
max_tokens: int | None = SchemaField(
@@ -587,7 +615,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
response: dict[str, Any] = SchemaField(
description="The response object generated by the language model."
)
prompt: str = SchemaField(description="The prompt sent to the language model.")
prompt: list = SchemaField(description="The prompt sent to the language model.")
error: str = SchemaField(description="Error message if the API call failed.")
def __init__(self):
@@ -609,7 +637,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
test_credentials=TEST_CREDENTIALS,
test_output=[
("response", {"key1": "key1Value", "key2": "key2Value"}),
("prompt", str),
("prompt", list),
],
test_mock={
"llm_call": lambda *args, **kwargs: LLMResponse(
@@ -642,6 +670,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
Test mocks work only on class functions, this wraps the llm_call function
so that it can be mocked withing the block testing framework.
"""
self.prompt = prompt
return llm_call(
credentials=credentials,
llm_model=llm_model,
@@ -759,6 +788,16 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
prompt.append({"role": "user", "content": retry_prompt})
except Exception as e:
logger.exception(f"Error calling LLM: {e}")
if (
"maximum context length" in str(e).lower()
or "token limit" in str(e).lower()
):
if input_data.max_tokens is None:
input_data.max_tokens = llm_model.max_output_tokens or 4096
input_data.max_tokens = int(input_data.max_tokens * 0.85)
logger.debug(
f"Reducing max_tokens to {input_data.max_tokens} for next attempt"
)
retry_prompt = f"Error calling LLM: {e}"
finally:
self.merge_stats(
@@ -796,7 +835,7 @@ class AITextGeneratorBlock(AIBlockBase):
)
prompt_values: dict[str, str] = SchemaField(
advanced=False,
default={},
default_factory=dict,
description="Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}.",
)
ollama_host: str = SchemaField(
@@ -814,7 +853,7 @@ class AITextGeneratorBlock(AIBlockBase):
response: str = SchemaField(
description="The response generated by the language model."
)
prompt: str = SchemaField(description="The prompt sent to the language model.")
prompt: list = SchemaField(description="The prompt sent to the language model.")
error: str = SchemaField(description="Error message if the API call failed.")
def __init__(self):
@@ -831,7 +870,7 @@ class AITextGeneratorBlock(AIBlockBase):
test_credentials=TEST_CREDENTIALS,
test_output=[
("response", "Response text"),
("prompt", str),
("prompt", list),
],
test_mock={"llm_call": lambda *args, **kwargs: "Response text"},
)
@@ -850,7 +889,10 @@ class AITextGeneratorBlock(AIBlockBase):
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
object_input_data = AIStructuredResponseGeneratorBlock.Input(
**{attr: getattr(input_data, attr) for attr in input_data.model_fields},
**{
attr: getattr(input_data, attr)
for attr in AITextGeneratorBlock.Input.model_fields
},
expected_format={},
)
yield "response", self.llm_call(object_input_data, credentials)
@@ -907,7 +949,7 @@ class AITextSummarizerBlock(AIBlockBase):
class Output(BlockSchema):
summary: str = SchemaField(description="The final summary of the text.")
prompt: str = SchemaField(description="The prompt sent to the language model.")
prompt: list = SchemaField(description="The prompt sent to the language model.")
error: str = SchemaField(description="Error message if the API call failed.")
def __init__(self):
@@ -924,7 +966,7 @@ class AITextSummarizerBlock(AIBlockBase):
test_credentials=TEST_CREDENTIALS,
test_output=[
("summary", "Final summary of a long text"),
("prompt", str),
("prompt", list),
],
test_mock={
"llm_call": lambda input_data, credentials: (
@@ -1033,8 +1075,14 @@ class AITextSummarizerBlock(AIBlockBase):
class AIConversationBlock(AIBlockBase):
class Input(BlockSchema):
prompt: str = SchemaField(
description="The prompt to send to the language model.",
placeholder="Enter your prompt here...",
default="",
advanced=False,
)
messages: List[Any] = SchemaField(
description="List of messages in the conversation.", min_length=1
description="List of messages in the conversation.",
)
model: LlmModel = SchemaField(
title="LLM Model",
@@ -1057,7 +1105,7 @@ class AIConversationBlock(AIBlockBase):
response: str = SchemaField(
description="The model's response to the conversation."
)
prompt: str = SchemaField(description="The prompt sent to the language model.")
prompt: list = SchemaField(description="The prompt sent to the language model.")
error: str = SchemaField(description="Error message if the API call failed.")
def __init__(self):
@@ -1086,7 +1134,7 @@ class AIConversationBlock(AIBlockBase):
"response",
"The 2020 World Series was played at Globe Life Field in Arlington, Texas.",
),
("prompt", str),
("prompt", list),
],
test_mock={
"llm_call": lambda *args, **kwargs: "The 2020 World Series was played at Globe Life Field in Arlington, Texas."
@@ -1108,7 +1156,7 @@ class AIConversationBlock(AIBlockBase):
) -> BlockOutput:
response = self.llm_call(
AIStructuredResponseGeneratorBlock.Input(
prompt="",
prompt=input_data.prompt,
credentials=input_data.credentials,
model=input_data.model,
conversation_history=input_data.messages,
@@ -1166,7 +1214,7 @@ class AIListGeneratorBlock(AIBlockBase):
list_item: str = SchemaField(
description="Each individual item in the list.",
)
prompt: str = SchemaField(description="The prompt sent to the language model.")
prompt: list = SchemaField(description="The prompt sent to the language model.")
error: str = SchemaField(
description="Error message if the list generation failed."
)
@@ -1198,7 +1246,7 @@ class AIListGeneratorBlock(AIBlockBase):
"generated_list",
["Zylora Prime", "Kharon-9", "Vortexia", "Oceara", "Draknos"],
),
("prompt", str),
("prompt", list),
("list_item", "Zylora Prime"),
("list_item", "Kharon-9"),
("list_item", "Vortexia"),

View File

@@ -65,7 +65,7 @@ class AddMemoryBlock(Block, Mem0Base):
default=Content(discriminator="content", content="I'm a vegetarian"),
)
metadata: dict[str, Any] = SchemaField(
description="Optional metadata for the memory", default={}
description="Optional metadata for the memory", default_factory=dict
)
limit_memory_to_run: bool = SchemaField(
@@ -173,7 +173,7 @@ class SearchMemoryBlock(Block, Mem0Base):
)
categories_filter: list[str] = SchemaField(
description="Categories to filter by",
default=[],
default_factory=list,
advanced=True,
)
limit_memory_to_run: bool = SchemaField(

View File

@@ -177,7 +177,8 @@ class PineconeInsertBlock(Block):
description="Namespace to use in Pinecone", default=""
)
metadata: dict = SchemaField(
description="Additional metadata to store with each vector", default={}
description="Additional metadata to store with each vector",
default_factory=dict,
)
class Output(BlockSchema):

View File

@@ -26,7 +26,7 @@ class Slant3DTriggerBase:
class Input(BlockSchema):
credentials: Slant3DCredentialsInput = Slant3DCredentialsField()
# Webhook URL is handled by the webhook system
payload: dict = SchemaField(hidden=True, default={})
payload: dict = SchemaField(hidden=True, default_factory=dict)
class Output(BlockSchema):
payload: dict = SchemaField(

View File

@@ -14,7 +14,6 @@ from backend.data.block import (
BlockOutput,
BlockSchema,
BlockType,
get_block,
)
from backend.data.model import SchemaField
from backend.util import json
@@ -27,10 +26,10 @@ logger = logging.getLogger(__name__)
@thread_cached
def get_database_manager_client():
from backend.executor import DatabaseManager
from backend.executor import DatabaseManagerClient
from backend.util.service import get_service_client
return get_service_client(DatabaseManager)
return get_service_client(DatabaseManagerClient)
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
@@ -155,7 +154,7 @@ class SmartDecisionMakerBlock(Block):
description="The system prompt to provide additional context to the model.",
)
conversation_history: list[dict] = SchemaField(
default=[],
default_factory=list,
description="The conversation history to provide context for the prompt.",
)
last_tool_output: Any = SchemaField(
@@ -169,7 +168,7 @@ class SmartDecisionMakerBlock(Block):
)
prompt_values: dict[str, str] = SchemaField(
advanced=False,
default={},
default_factory=dict,
description="Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}.",
)
max_tokens: int | None = SchemaField(
@@ -247,6 +246,10 @@ class SmartDecisionMakerBlock(Block):
test_credentials=llm.TEST_CREDENTIALS,
)
@staticmethod
def cleanup(s: str):
return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()
@staticmethod
def _create_block_function_signature(
sink_node: "Node", links: list["Link"]
@@ -264,12 +267,10 @@ class SmartDecisionMakerBlock(Block):
Raises:
ValueError: If the block specified by sink_node.block_id is not found.
"""
block = get_block(sink_node.block_id)
if not block:
raise ValueError(f"Block not found: {sink_node.block_id}")
block = sink_node.block
tool_function: dict[str, Any] = {
"name": re.sub(r"[^a-zA-Z0-9_-]", "_", block.name).lower(),
"name": SmartDecisionMakerBlock.cleanup(block.name),
"description": block.description,
}
@@ -284,7 +285,7 @@ class SmartDecisionMakerBlock(Block):
and sink_block_input_schema.model_fields[link.sink_name].description
else f"The {link.sink_name} of the tool"
)
properties[link.sink_name.lower()] = {
properties[SmartDecisionMakerBlock.cleanup(link.sink_name)] = {
"type": "string",
"description": description,
}
@@ -329,7 +330,7 @@ class SmartDecisionMakerBlock(Block):
)
tool_function: dict[str, Any] = {
"name": re.sub(r"[^a-zA-Z0-9_-]", "_", sink_graph_meta.name).lower(),
"name": SmartDecisionMakerBlock.cleanup(sink_graph_meta.name),
"description": sink_graph_meta.description,
}
@@ -344,7 +345,7 @@ class SmartDecisionMakerBlock(Block):
in sink_block_input_schema["properties"][link.sink_name]
else f"The {link.sink_name} of the tool"
)
properties[link.sink_name.lower()] = {
properties[SmartDecisionMakerBlock.cleanup(link.sink_name)] = {
"type": "string",
"description": description,
}
@@ -494,6 +495,7 @@ class SmartDecisionMakerBlock(Block):
max_tokens=input_data.max_tokens,
tools=tool_functions,
ollama_host=input_data.ollama_host,
parallel_tool_calls=False,
)
if not response.tool_calls:
@@ -505,7 +507,7 @@ class SmartDecisionMakerBlock(Block):
tool_args = json.loads(tool_call.function.arguments)
for arg_name, arg_value in tool_args.items():
yield f"tools_^_{tool_name}_{arg_name}".lower(), arg_value
yield f"tools_^_{tool_name}_~_{arg_name}", arg_value
response.prompt.append(response.raw_response)
yield "conversations", response.prompt

View File

@@ -112,7 +112,7 @@ class AddLeadToCampaignBlock(Block):
lead_list: list[LeadInput] = SchemaField(
description="An array of JSON objects, each representing a lead's details. Can hold max 100 leads.",
max_length=100,
default=[],
default_factory=list,
advanced=False,
)
settings: LeadUploadSettings = SchemaField(
@@ -248,7 +248,7 @@ class SaveCampaignSequencesBlock(Block):
)
sequences: list[Sequence] = SchemaField(
description="The sequences to save",
default=[],
default_factory=list,
advanced=False,
)
credentials: SmartLeadCredentialsInput = SchemaField(

View File

@@ -39,7 +39,7 @@ class LeadCustomFields(BaseModel):
fields: dict[str, str] = SchemaField(
description="Custom fields for a lead (max 20 fields)",
max_length=20,
default={},
default_factory=dict,
)
@@ -85,7 +85,7 @@ class AddLeadsRequest(BaseModel):
lead_list: list[LeadInput] = SchemaField(
description="List of leads to add to the campaign",
max_length=100,
default=[],
default_factory=list,
)
settings: LeadUploadSettings
campaign_id: int

View File

@@ -156,7 +156,7 @@
# participant_ids: list[str] = SchemaField(
# description="Array of User IDs to create conversation with (max 50)",
# placeholder="Enter participant user IDs",
# default=[],
# default_factory=list,
# advanced=False
# )

View File

@@ -39,7 +39,6 @@ class TwitterGetListBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to lookup",
placeholder="Enter list ID",
required=True,
)
class Output(BlockSchema):
@@ -184,7 +183,6 @@ class TwitterGetOwnedListsBlock(Block):
user_id: str = SchemaField(
description="The user ID whose owned Lists to retrieve",
placeholder="Enter user ID",
required=True,
)
max_results: int | None = SchemaField(

View File

@@ -45,13 +45,11 @@ class TwitterRemoveListMemberBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to remove the member from",
placeholder="Enter list ID",
required=True,
)
user_id: str = SchemaField(
description="The ID of the user to remove from the List",
placeholder="Enter user ID to remove",
required=True,
)
class Output(BlockSchema):
@@ -120,13 +118,11 @@ class TwitterAddListMemberBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to add the member to",
placeholder="Enter list ID",
required=True,
)
user_id: str = SchemaField(
description="The ID of the user to add to the List",
placeholder="Enter user ID to add",
required=True,
)
class Output(BlockSchema):
@@ -195,7 +191,6 @@ class TwitterGetListMembersBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to get members from",
placeholder="Enter list ID",
required=True,
)
max_results: int | None = SchemaField(
@@ -376,7 +371,6 @@ class TwitterGetListMembershipsBlock(Block):
user_id: str = SchemaField(
description="The ID of the user whose List memberships to retrieve",
placeholder="Enter user ID",
required=True,
)
max_results: int | None = SchemaField(

View File

@@ -42,7 +42,6 @@ class TwitterGetListTweetsBlock(Block):
list_id: str = SchemaField(
description="The ID of the List whose Tweets you would like to retrieve",
placeholder="Enter list ID",
required=True,
)
max_results: int | None = SchemaField(

View File

@@ -28,7 +28,6 @@ class TwitterDeleteListBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to be deleted",
placeholder="Enter list ID",
required=True,
)
class Output(BlockSchema):

View File

@@ -39,7 +39,6 @@ class TwitterUnpinListBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to unpin",
placeholder="Enter list ID",
required=True,
)
class Output(BlockSchema):
@@ -103,7 +102,6 @@ class TwitterPinListBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to pin",
placeholder="Enter list ID",
required=True,
)
class Output(BlockSchema):

View File

@@ -44,7 +44,7 @@ class SpaceList(BaseModel):
space_ids: list[str] = SchemaField(
description="List of Space IDs to lookup (up to 100)",
placeholder="Enter Space IDs",
default=[],
default_factory=list,
advanced=False,
)
@@ -54,7 +54,7 @@ class UserList(BaseModel):
user_ids: list[str] = SchemaField(
description="List of user IDs to lookup their Spaces (up to 100)",
placeholder="Enter user IDs",
default=[],
default_factory=list,
advanced=False,
)
@@ -227,7 +227,6 @@ class TwitterGetSpaceByIdBlock(Block):
space_id: str = SchemaField(
description="Space ID to lookup",
placeholder="Enter Space ID",
required=True,
)
class Output(BlockSchema):
@@ -389,7 +388,6 @@ class TwitterGetSpaceBuyersBlock(Block):
space_id: str = SchemaField(
description="Space ID to lookup buyers for",
placeholder="Enter Space ID",
required=True,
)
class Output(BlockSchema):
@@ -517,7 +515,6 @@ class TwitterGetSpaceTweetsBlock(Block):
space_id: str = SchemaField(
description="Space ID to lookup tweets for",
placeholder="Enter Space ID",
required=True,
)
class Output(BlockSchema):

View File

@@ -200,7 +200,7 @@ class UserIdList(BaseModel):
user_ids: list[str] = SchemaField(
description="List of user IDs to lookup (max 100)",
placeholder="Enter user IDs",
default=[],
default_factory=list,
advanced=False,
)
@@ -210,7 +210,7 @@ class UsernameList(BaseModel):
usernames: list[str] = SchemaField(
description="List of Twitter usernames/handles to lookup (max 100)",
placeholder="Enter usernames",
default=[],
default_factory=list,
advanced=False,
)

View File

@@ -8,7 +8,6 @@ import pathlib
import click
import psutil
from backend import app
from backend.util.process import AppProcess
@@ -42,8 +41,13 @@ def write_pid(pid: int):
class MainApp(AppProcess):
def run(self):
from backend import app
app.main(silent=True)
def cleanup(self):
pass
@click.group()
def main():

View File

@@ -12,12 +12,12 @@ async def log_raw_analytics(
data_index: str,
):
details = await prisma.models.AnalyticsDetails.prisma().create(
data={
"userId": user_id,
"type": type,
"data": prisma.Json(data),
"dataIndex": data_index,
}
data=prisma.types.AnalyticsDetailsCreateInput(
userId=user_id,
type=type,
data=prisma.Json(data),
dataIndex=data_index,
)
)
return details
@@ -32,12 +32,12 @@ async def log_raw_metric(
raise ValueError("metric_value must be non-negative")
result = await prisma.models.AnalyticsMetrics.prisma().create(
data={
"value": metric_value,
"analyticMetric": metric_name,
"userId": user_id,
"dataString": data_string,
},
data=prisma.types.AnalyticsMetricsCreateInput(
value=metric_value,
analyticMetric=metric_name,
userId=user_id,
dataString=data_string,
)
)
return result

View File

@@ -17,6 +17,7 @@ from typing import (
import jsonref
import jsonschema
from prisma.models import AgentBlock
from prisma.types import AgentBlockCreateInput
from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
@@ -27,6 +28,7 @@ from backend.util.settings import Config
from .model import (
ContributorDetails,
Credentials,
CredentialsFieldInfo,
CredentialsMetaInput,
is_credentials_field_name,
)
@@ -202,6 +204,15 @@ class BlockSchema(BaseModel):
)
}
@classmethod
def get_credentials_fields_info(cls) -> dict[str, CredentialsFieldInfo]:
return {
field_name: CredentialsFieldInfo.model_validate(
cls.get_field_schema(field_name), by_alias=True
)
for field_name in cls.get_credentials_fields().keys()
}
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
return data # Return as is, by default.
@@ -480,12 +491,12 @@ async def initialize_blocks() -> None:
)
if not existing_block:
await AgentBlock.prisma().create(
data={
"id": block.id,
"name": block.name,
"inputSchema": json.dumps(block.input_schema.jsonschema()),
"outputSchema": json.dumps(block.output_schema.jsonschema()),
}
data=AgentBlockCreateInput(
id=block.id,
name=block.name,
inputSchema=json.dumps(block.input_schema.jsonschema()),
outputSchema=json.dumps(block.output_schema.jsonschema()),
)
)
continue
@@ -508,6 +519,7 @@ async def initialize_blocks() -> None:
)
def get_block(block_id: str) -> Block | None:
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
def get_block(block_id: str) -> Block[BlockSchema, BlockSchema] | None:
cls = get_blocks().get(block_id)
return cls() if cls else None

View File

@@ -36,14 +36,17 @@ from backend.integrations.credentials_store import (
# =============== Configure the cost for each LLM Model call =============== #
MODEL_COST: dict[LlmModel, int] = {
LlmModel.O3: 7,
LlmModel.O3_MINI: 2, # $1.10 / $4.40
LlmModel.O1: 16, # $15 / $60
LlmModel.O1_PREVIEW: 16,
LlmModel.O1_MINI: 4,
LlmModel.GPT41: 2,
LlmModel.GPT4O_MINI: 1,
LlmModel.GPT4O: 3,
LlmModel.GPT4_TURBO: 10,
LlmModel.GPT3_5_TURBO: 1,
LlmModel.CLAUDE_3_7_SONNET: 5,
LlmModel.CLAUDE_3_5_SONNET: 4,
LlmModel.CLAUDE_3_5_HAIKU: 1, # $0.80 / $4.00
LlmModel.CLAUDE_3_HAIKU: 1,
@@ -60,6 +63,7 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.DEEPSEEK_LLAMA_70B: 1, # ? / ?
LlmModel.OLLAMA_DOLPHIN: 1,
LlmModel.GEMINI_FLASH_1_5: 1,
LlmModel.GEMINI_2_5_PRO: 4,
LlmModel.GROK_BETA: 5,
LlmModel.MISTRAL_NEMO: 1,
LlmModel.COHERE_COMMAND_R_08_2024: 1,
@@ -75,6 +79,8 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.AMAZON_NOVA_PRO_V1: 1,
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: 1,
LlmModel.GRYPHE_MYTHOMAX_L2_13B: 1,
LlmModel.META_LLAMA_4_SCOUT: 1,
LlmModel.META_LLAMA_4_MAVERICK: 1,
}
for model in LlmModel:

View File

@@ -1,8 +1,8 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timezone
from typing import Any, cast
import stripe
from autogpt_libs.utils.cache import thread_cached
@@ -11,11 +11,16 @@ from prisma.enums import (
CreditRefundRequestStatus,
CreditTransactionType,
NotificationType,
OnboardingStep,
)
from prisma.errors import UniqueViolationError
from prisma.models import CreditRefundRequest, CreditTransaction, User
from prisma.types import CreditTransactionCreateInput, CreditTransactionWhereInput
from tenacity import retry, stop_after_attempt, wait_exponential
from prisma.types import (
CreditRefundRequestCreateInput,
CreditTransactionCreateInput,
CreditTransactionWhereInput,
)
from pydantic import BaseModel
from backend.data import db
from backend.data.block_cost_config import BLOCK_COSTS
@@ -23,14 +28,17 @@ from backend.data.cost import BlockCost
from backend.data.model import (
AutoTopUpConfig,
RefundRequest,
TopUpType,
TransactionHistory,
UserTransaction,
)
from backend.data.notifications import NotificationEventDTO, RefundRequestData
from backend.data.user import get_user_by_id
from backend.executor.utils import UsageTransactionMetadata
from backend.notifications import NotificationManager
from backend.data.user import get_user_by_id, get_user_email_by_id
from backend.notifications import NotificationManagerClient
from backend.server.model import Pagination
from backend.server.v2.admin.model import UserHistoryResponse
from backend.util.exceptions import InsufficientBalanceError
from backend.util.retry import func_retry
from backend.util.service import get_service_client
from backend.util.settings import Settings
@@ -40,6 +48,17 @@ logger = logging.getLogger(__name__)
base_url = settings.config.frontend_base_url or settings.config.platform_base_url
class UsageTransactionMetadata(BaseModel):
graph_exec_id: str | None = None
graph_id: str | None = None
node_id: str | None = None
node_exec_id: str | None = None
block_id: str | None = None
block: str | None = None
input: dict[str, Any] | None = None
reason: str | None = None
class UserCreditBase(ABC):
@abstractmethod
async def get_credits(self, user_id: str) -> int:
@@ -117,6 +136,18 @@ class UserCreditBase(ABC):
"""
pass
@abstractmethod
async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
"""
Reward the user with credits for completing an onboarding step.
Won't reward if the user has already received credits for the step.
Args:
user_id (str): The user ID.
step (OnboardingStep): The onboarding step.
"""
pass
@abstractmethod
async def top_up_intent(self, user_id: str, amount: int) -> str:
"""
@@ -209,7 +240,7 @@ class UserCreditBase(ABC):
"userId": user_id,
"createdAt": {"lte": top_time},
"isActive": True,
"runningBalance": {"not": None}, # type: ignore
"NOT": [{"runningBalance": None}],
},
order={"createdAt": "desc"},
)
@@ -245,11 +276,7 @@ class UserCreditBase(ABC):
)
return transaction_balance, transaction_time
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=1, max=10),
reraise=True,
)
@func_retry
async def _enable_transaction(
self,
transaction_key: str,
@@ -331,15 +358,15 @@ class UserCreditBase(ABC):
amount = min(-user_balance, 0)
# Create the transaction
transaction_data: CreditTransactionCreateInput = {
"userId": user_id,
"amount": amount,
"runningBalance": user_balance + amount,
"type": transaction_type,
"metadata": metadata,
"isActive": is_active,
"createdAt": self.time_now(),
}
transaction_data = CreditTransactionCreateInput(
userId=user_id,
amount=amount,
runningBalance=user_balance + amount,
type=transaction_type,
metadata=metadata,
isActive=is_active,
createdAt=self.time_now(),
)
if transaction_key:
transaction_data["transactionKey"] = transaction_key
tx = await CreditTransaction.prisma().create(data=transaction_data)
@@ -348,21 +375,19 @@ class UserCreditBase(ABC):
class UserCredit(UserCreditBase):
@thread_cached
def notification_client(self) -> NotificationManager:
return get_service_client(NotificationManager)
def notification_client(self) -> NotificationManagerClient:
return get_service_client(NotificationManagerClient)
async def _send_refund_notification(
self,
notification_request: RefundRequestData,
notification_type: NotificationType,
):
await asyncio.to_thread(
lambda: self.notification_client().queue_notification(
NotificationEventDTO(
user_id=notification_request.user_id,
type=notification_type,
data=notification_request.model_dump(),
)
await self.notification_client().queue_notification_async(
NotificationEventDTO(
user_id=notification_request.user_id,
type=notification_type,
data=notification_request.model_dump(),
)
)
@@ -392,6 +417,7 @@ class UserCredit(UserCreditBase):
# Avoid multiple auto top-ups within the same graph execution.
key=f"AUTO-TOP-UP-{user_id}-{metadata.graph_exec_id}",
ceiling_balance=auto_top_up.threshold,
top_up_type=TopUpType.AUTO,
)
except Exception as e:
# Failed top-up is not critical, we can move on.
@@ -401,8 +427,30 @@ class UserCredit(UserCreditBase):
return balance
async def top_up_credits(self, user_id: str, amount: int):
await self._top_up_credits(user_id, amount)
async def top_up_credits(
self,
user_id: str,
amount: int,
top_up_type: TopUpType = TopUpType.UNCATEGORIZED,
):
await self._top_up_credits(
user_id=user_id, amount=amount, top_up_type=top_up_type
)
async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
try:
await self._add_transaction(
user_id=user_id,
amount=credits,
transaction_type=CreditTransactionType.GRANT,
transaction_key=f"REWARD-{user_id}-{step.value}",
metadata=Json(
{"reason": f"Reward for completing {step.value} onboarding step."}
),
)
except UniqueViolationError:
# Already rewarded for this step
pass
async def top_up_refund(
self, user_id: str, transaction_key: str, metadata: dict[str, str]
@@ -422,15 +470,15 @@ class UserCredit(UserCreditBase):
try:
refund_request = await CreditRefundRequest.prisma().create(
data={
"id": refund_key,
"transactionKey": transaction_key,
"userId": user_id,
"amount": amount,
"reason": metadata.get("reason", ""),
"status": CreditRefundRequestStatus.PENDING,
"result": "The refund request is under review.",
}
data=CreditRefundRequestCreateInput(
id=refund_key,
transactionKey=transaction_key,
userId=user_id,
amount=amount,
reason=metadata.get("reason", ""),
status=CreditRefundRequestStatus.PENDING,
result="The refund request is under review.",
)
)
except UniqueViolationError:
raise ValueError(
@@ -567,7 +615,7 @@ class UserCredit(UserCreditBase):
evidence_text += (
f"- {tx.description}: Amount ${tx.amount / 100:.2f} on {tx.transaction_time.isoformat()}, "
f"resulting balance ${tx.balance / 100:.2f} {additional_comment}\n"
f"resulting balance ${tx.running_balance / 100:.2f} {additional_comment}\n"
)
evidence_text += (
"\nThis evidence demonstrates that the transaction was authorized and that the charged amount was used to render the service as agreed."
@@ -586,7 +634,24 @@ class UserCredit(UserCreditBase):
amount: int,
key: str | None = None,
ceiling_balance: int | None = None,
top_up_type: TopUpType = TopUpType.UNCATEGORIZED,
metadata: dict | None = None,
):
# init metadata, without sharing it with the world
metadata = metadata or {}
if not metadata["reason"]:
match top_up_type:
case TopUpType.MANUAL:
metadata["reason"] = {"reason": f"Top up credits for {user_id}"}
case TopUpType.AUTO:
metadata["reason"] = {
"reason": f"Auto top up credits for {user_id}"
}
case _:
metadata["reason"] = {
"reason": f"Top up reason unknown for {user_id}"
}
if amount < 0:
raise ValueError(f"Top up amount must not be negative: {amount}")
@@ -609,6 +674,7 @@ class UserCredit(UserCreditBase):
is_active=False,
transaction_key=key,
ceiling_balance=ceiling_balance,
metadata=(Json(metadata)),
)
customer_id = await get_stripe_customer_id(user_id)
@@ -751,10 +817,15 @@ class UserCredit(UserCreditBase):
# Check the Checkout Session's payment_status property
# to determine if fulfillment should be performed
if checkout_session.payment_status in ["paid", "no_payment_required"]:
assert isinstance(checkout_session.payment_intent, stripe.PaymentIntent)
if payment_intent := checkout_session.payment_intent:
assert isinstance(payment_intent, stripe.PaymentIntent)
new_transaction_key = payment_intent.id
else:
new_transaction_key = None
await self._enable_transaction(
transaction_key=credit_transaction.transactionKey,
new_transaction_key=checkout_session.payment_intent.id,
new_transaction_key=new_transaction_key,
user_id=credit_transaction.userId,
metadata=Json(checkout_session),
)
@@ -787,8 +858,9 @@ class UserCredit(UserCreditBase):
take=transaction_count_limit,
)
# doesn't fill current_balance, reason, user_email, admin_email, or extra_data
grouped_transactions: dict[str, UserTransaction] = defaultdict(
lambda: UserTransaction()
lambda: UserTransaction(user_id=user_id)
)
tx_time = None
for t in transactions:
@@ -818,7 +890,7 @@ class UserCredit(UserCreditBase):
if tx_time > gt.transaction_time:
gt.transaction_time = tx_time
gt.balance = t.runningBalance or 0
gt.running_balance = t.runningBalance or 0
return TransactionHistory(
transactions=list(grouped_transactions.values()),
@@ -868,6 +940,7 @@ class BetaUserCredit(UserCredit):
amount=max(self.num_user_credits_refill - balance, 0),
transaction_type=CreditTransactionType.GRANT,
transaction_key=f"MONTHLY-CREDIT-TOP-UP-{cur_time}",
metadata=Json({"reason": "Monthly credit refill"}),
)
return balance
except UniqueViolationError:
@@ -877,7 +950,7 @@ class BetaUserCredit(UserCredit):
class DisabledUserCredit(UserCreditBase):
async def get_credits(self, *args, **kwargs) -> int:
return 0
return 100
async def get_transaction_history(self, *args, **kwargs) -> TransactionHistory:
return TransactionHistory(transactions=[], next_transaction_time=None)
@@ -891,6 +964,9 @@ class DisabledUserCredit(UserCreditBase):
async def top_up_credits(self, *args, **kwargs):
pass
async def onboarding_reward(self, *args, **kwargs):
pass
async def top_up_intent(self, *args, **kwargs) -> str:
return ""
@@ -952,3 +1028,81 @@ async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
return AutoTopUpConfig(threshold=0, amount=0)
return AutoTopUpConfig.model_validate(user.topUpConfig)
async def admin_get_user_history(
page: int = 1,
page_size: int = 20,
search: str | None = None,
transaction_filter: CreditTransactionType | None = None,
) -> UserHistoryResponse:
if page < 1 or page_size < 1:
raise ValueError("Invalid pagination input")
where_clause: CreditTransactionWhereInput = {}
if transaction_filter:
where_clause["type"] = transaction_filter
if search:
where_clause["OR"] = [
{"userId": {"contains": search, "mode": "insensitive"}},
{"User": {"is": {"email": {"contains": search, "mode": "insensitive"}}}},
{"User": {"is": {"name": {"contains": search, "mode": "insensitive"}}}},
]
transactions = await CreditTransaction.prisma().find_many(
where=where_clause,
skip=(page - 1) * page_size,
take=page_size,
include={"User": True},
order={"createdAt": "desc"},
)
total = await CreditTransaction.prisma().count(where=where_clause)
total_pages = (total + page_size - 1) // page_size
history = []
for tx in transactions:
admin_id = ""
admin_email = ""
reason = ""
metadata: dict = cast(dict, tx.metadata) or {}
if metadata:
admin_id = metadata.get("admin_id")
admin_email = (
(await get_user_email_by_id(admin_id) or f"Unknown Admin: {admin_id}")
if admin_id
else ""
)
reason = metadata.get("reason", "No reason provided")
balance, last_update = await get_user_credit_model()._get_credits(tx.userId)
history.append(
UserTransaction(
transaction_key=tx.transactionKey,
transaction_time=tx.createdAt,
transaction_type=tx.type,
amount=tx.amount,
current_balance=balance,
running_balance=tx.runningBalance or 0,
user_id=tx.userId,
user_email=(
tx.User.email
if tx.User
else (await get_user_by_id(tx.userId)).email
),
reason=reason,
admin_email=admin_email,
extra_data=str(metadata),
)
)
return UserHistoryResponse(
history=history,
pagination=Pagination(
total_items=total,
total_pages=total_pages,
current_page=page,
page_size=page_size,
),
)

View File

@@ -62,10 +62,10 @@ async def connect():
# Connection acquired from a pool like Supabase somehow still possibly allows
# the db client obtains a connection but still reject query connection afterward.
try:
await prisma.execute_raw("SELECT 1")
except Exception as e:
raise ConnectionError("Failed to connect to Prisma.") from e
# try:
# await prisma.execute_raw("SELECT 1")
# except Exception as e:
# raise ConnectionError("Failed to connect to Prisma.") from e
@conn_retry("Prisma", "Releasing connection")
@@ -89,7 +89,7 @@ async def transaction():
async def locked_transaction(key: str):
lock_key = zlib.crc32(key.encode("utf-8"))
async with transaction() as tx:
await tx.execute_raw(f"SELECT pg_advisory_xact_lock({lock_key})")
await tx.execute_raw("SELECT pg_advisory_xact_lock($1)", lock_key)
yield tx

View File

@@ -23,7 +23,10 @@ from prisma.models import (
AgentNodeExecutionInputOutput,
)
from prisma.types import (
AgentGraphExecutionCreateInput,
AgentGraphExecutionWhereInput,
AgentNodeExecutionCreateInput,
AgentNodeExecutionInputOutputCreateInput,
AgentNodeExecutionUpdateInput,
AgentNodeExecutionWhereInput,
)
@@ -31,18 +34,17 @@ from pydantic import BaseModel
from pydantic.fields import Field
from backend.server.v2.store.exceptions import DatabaseError
from backend.util import mock
from backend.util import type as type_utils
from backend.util.settings import Config
from .block import BlockData, BlockInput, BlockType, CompletedBlockOutput, get_block
from .block import BlockInput, BlockType, CompletedBlockOutput, get_block
from .db import BaseDbModel
from .includes import (
EXECUTION_RESULT_INCLUDE,
GRAPH_EXECUTION_INCLUDE,
GRAPH_EXECUTION_INCLUDE_WITH_NODES,
)
from .model import GraphExecutionStats, NodeExecutionStats
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
from .queue import AsyncRedisEventBus, RedisEventBus
T = TypeVar("T")
@@ -59,23 +61,27 @@ ExecutionStatus = AgentExecutionStatus
class GraphExecutionMeta(BaseDbModel):
user_id: str
started_at: datetime
ended_at: datetime
cost: Optional[int] = Field(..., description="Execution cost in credits")
duration: float = Field(..., description="Seconds from start to end of run")
total_run_time: float = Field(..., description="Seconds of node runtime")
status: ExecutionStatus
graph_id: str
graph_version: int
preset_id: Optional[str] = None
status: ExecutionStatus
started_at: datetime
ended_at: datetime
class Stats(BaseModel):
cost: int = Field(..., description="Execution cost (cents)")
duration: float = Field(..., description="Seconds from start to end of run")
node_exec_time: float = Field(..., description="Seconds of total node runtime")
node_exec_count: int = Field(..., description="Number of node executions")
stats: Stats | None
@staticmethod
def from_db(_graph_exec: AgentGraphExecution):
now = datetime.now(timezone.utc)
# TODO: make started_at and ended_at optional
start_time = _graph_exec.startedAt or _graph_exec.createdAt
end_time = _graph_exec.updatedAt or now
duration = (end_time - start_time).total_seconds()
total_run_time = duration
try:
stats = GraphExecutionStats.model_validate(_graph_exec.stats)
@@ -87,21 +93,25 @@ class GraphExecutionMeta(BaseDbModel):
)
stats = None
duration = stats.walltime if stats else duration
total_run_time = stats.nodes_walltime if stats else total_run_time
return GraphExecutionMeta(
id=_graph_exec.id,
user_id=_graph_exec.userId,
started_at=start_time,
ended_at=end_time,
cost=stats.cost if stats else None,
duration=duration,
total_run_time=total_run_time,
status=ExecutionStatus(_graph_exec.executionStatus),
graph_id=_graph_exec.agentGraphId,
graph_version=_graph_exec.agentGraphVersion,
preset_id=_graph_exec.agentPresetId,
status=ExecutionStatus(_graph_exec.executionStatus),
started_at=start_time,
ended_at=end_time,
stats=(
GraphExecutionMeta.Stats(
cost=stats.cost,
duration=stats.walltime,
node_exec_time=stats.nodes_walltime,
node_exec_count=stats.node_count,
)
if stats
else None
),
)
@@ -111,15 +121,16 @@ class GraphExecution(GraphExecutionMeta):
@staticmethod
def from_db(_graph_exec: AgentGraphExecution):
if _graph_exec.AgentNodeExecutions is None:
if _graph_exec.NodeExecutions is None:
raise ValueError("Node executions must be included in query")
graph_exec = GraphExecutionMeta.from_db(_graph_exec)
node_executions = sorted(
complete_node_executions = sorted(
[
NodeExecutionResult.from_db(ne, _graph_exec.userId)
for ne in _graph_exec.AgentNodeExecutions
for ne in _graph_exec.NodeExecutions
if ne.executionStatus != ExecutionStatus.INCOMPLETE
],
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
)
@@ -128,7 +139,7 @@ class GraphExecution(GraphExecutionMeta):
**{
# inputs from Agent Input Blocks
exec.input_data["name"]: exec.input_data.get("value")
for exec in node_executions
for exec in complete_node_executions
if (
(block := get_block(exec.block_id))
and block.block_type == BlockType.INPUT
@@ -137,7 +148,7 @@ class GraphExecution(GraphExecutionMeta):
**{
# input from webhook-triggered block
"payload": exec.input_data["payload"]
for exec in node_executions
for exec in complete_node_executions
if (
(block := get_block(exec.block_id))
and block.block_type
@@ -147,7 +158,7 @@ class GraphExecution(GraphExecutionMeta):
}
outputs: CompletedBlockOutput = defaultdict(list)
for exec in node_executions:
for exec in complete_node_executions:
if (
block := get_block(exec.block_id)
) and block.block_type == BlockType.OUTPUT:
@@ -158,7 +169,7 @@ class GraphExecution(GraphExecutionMeta):
return GraphExecution(
**{
field_name: getattr(graph_exec, field_name)
for field_name in graph_exec.model_fields
for field_name in GraphExecutionMeta.model_fields
},
inputs=inputs,
outputs=outputs,
@@ -170,7 +181,7 @@ class GraphExecutionWithNodes(GraphExecution):
@staticmethod
def from_db(_graph_exec: AgentGraphExecution):
if _graph_exec.AgentNodeExecutions is None:
if _graph_exec.NodeExecutions is None:
raise ValueError("Node executions must be included in query")
graph_exec_with_io = GraphExecution.from_db(_graph_exec)
@@ -178,7 +189,7 @@ class GraphExecutionWithNodes(GraphExecution):
node_executions = sorted(
[
NodeExecutionResult.from_db(ne, _graph_exec.userId)
for ne in _graph_exec.AgentNodeExecutions
for ne in _graph_exec.NodeExecutions
],
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
)
@@ -186,11 +197,20 @@ class GraphExecutionWithNodes(GraphExecution):
return GraphExecutionWithNodes(
**{
field_name: getattr(graph_exec_with_io, field_name)
for field_name in graph_exec_with_io.model_fields
for field_name in GraphExecution.model_fields
},
node_executions=node_executions,
)
def to_graph_execution_entry(self):
return GraphExecutionEntry(
user_id=self.user_id,
graph_id=self.graph_id,
graph_version=self.graph_version or 0,
graph_exec_id=self.id,
node_credentials_input_map={}, # FIXME
)
class NodeExecutionResult(BaseModel):
user_id: str
@@ -209,21 +229,21 @@ class NodeExecutionResult(BaseModel):
end_time: datetime | None
@staticmethod
def from_db(execution: AgentNodeExecution, user_id: Optional[str] = None):
if execution.executionData:
def from_db(_node_exec: AgentNodeExecution, user_id: Optional[str] = None):
if _node_exec.executionData:
# Execution that has been queued for execution will persist its data.
input_data = type_utils.convert(execution.executionData, dict[str, Any])
input_data = type_utils.convert(_node_exec.executionData, dict[str, Any])
else:
# For incomplete execution, executionData will not be yet available.
input_data: BlockInput = defaultdict()
for data in execution.Input or []:
for data in _node_exec.Input or []:
input_data[data.name] = type_utils.convert(data.data, type[Any])
output_data: CompletedBlockOutput = defaultdict(list)
for data in execution.Output or []:
for data in _node_exec.Output or []:
output_data[data.name].append(type_utils.convert(data.data, type[Any]))
graph_execution: AgentGraphExecution | None = execution.AgentGraphExecution
graph_execution: AgentGraphExecution | None = _node_exec.GraphExecution
if graph_execution:
user_id = graph_execution.userId
elif not user_id:
@@ -235,17 +255,28 @@ class NodeExecutionResult(BaseModel):
user_id=user_id,
graph_id=graph_execution.agentGraphId if graph_execution else "",
graph_version=graph_execution.agentGraphVersion if graph_execution else 0,
graph_exec_id=execution.agentGraphExecutionId,
block_id=execution.AgentNode.agentBlockId if execution.AgentNode else "",
node_exec_id=execution.id,
node_id=execution.agentNodeId,
status=execution.executionStatus,
graph_exec_id=_node_exec.agentGraphExecutionId,
block_id=_node_exec.Node.agentBlockId if _node_exec.Node else "",
node_exec_id=_node_exec.id,
node_id=_node_exec.agentNodeId,
status=_node_exec.executionStatus,
input_data=input_data,
output_data=output_data,
add_time=execution.addedTime,
queue_time=execution.queuedTime,
start_time=execution.startedTime,
end_time=execution.endedTime,
add_time=_node_exec.addedTime,
queue_time=_node_exec.queuedTime,
start_time=_node_exec.startedTime,
end_time=_node_exec.endedTime,
)
def to_node_execution_entry(self) -> "NodeExecutionEntry":
return NodeExecutionEntry(
user_id=self.user_id,
graph_exec_id=self.graph_exec_id,
graph_id=self.graph_id,
node_exec_id=self.node_exec_id,
node_id=self.node_id,
block_id=self.block_id,
data=self.input_data,
)
@@ -330,7 +361,7 @@ async def get_graph_execution(
async def create_graph_execution(
graph_id: str,
graph_version: int,
nodes_input: list[tuple[str, BlockInput]],
starting_nodes_input: list[tuple[str, BlockInput]],
user_id: str,
preset_id: str | None = None,
) -> GraphExecutionWithNodes:
@@ -340,29 +371,29 @@ async def create_graph_execution(
The id of the AgentGraphExecution and the list of ExecutionResult for each node.
"""
result = await AgentGraphExecution.prisma().create(
data={
"agentGraphId": graph_id,
"agentGraphVersion": graph_version,
"executionStatus": ExecutionStatus.QUEUED,
"AgentNodeExecutions": {
"create": [ # type: ignore
{
"agentNodeId": node_id,
"executionStatus": ExecutionStatus.QUEUED,
"queuedTime": datetime.now(tz=timezone.utc),
"Input": {
data=AgentGraphExecutionCreateInput(
agentGraphId=graph_id,
agentGraphVersion=graph_version,
executionStatus=ExecutionStatus.QUEUED,
NodeExecutions={
"create": [
AgentNodeExecutionCreateInput(
agentNodeId=node_id,
executionStatus=ExecutionStatus.QUEUED,
queuedTime=datetime.now(tz=timezone.utc),
Input={
"create": [
{"name": name, "data": Json(data)}
for name, data in node_input.items()
]
},
}
for node_id, node_input in nodes_input
)
for node_id, node_input in starting_nodes_input
]
},
"userId": user_id,
"agentPresetId": preset_id,
},
userId=user_id,
agentPresetId=preset_id,
),
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
)
@@ -409,11 +440,11 @@ async def upsert_execution_input(
if existing_execution:
await AgentNodeExecutionInputOutput.prisma().create(
data={
"name": input_name,
"data": json_input_data,
"referencedByInputExecId": existing_execution.id,
}
data=AgentNodeExecutionInputOutputCreateInput(
name=input_name,
data=json_input_data,
referencedByInputExecId=existing_execution.id,
)
)
return existing_execution.id, {
**{
@@ -425,12 +456,12 @@ async def upsert_execution_input(
elif not node_exec_id:
result = await AgentNodeExecution.prisma().create(
data={
"agentNodeId": node_id,
"agentGraphExecutionId": graph_exec_id,
"executionStatus": ExecutionStatus.INCOMPLETE,
"Input": {"create": {"name": input_name, "data": json_input_data}},
}
data=AgentNodeExecutionCreateInput(
agentNodeId=node_id,
agentGraphExecutionId=graph_exec_id,
executionStatus=ExecutionStatus.INCOMPLETE,
Input={"create": {"name": input_name, "data": json_input_data}},
)
)
return result.id, {input_name: input_data}
@@ -449,15 +480,17 @@ async def upsert_execution_output(
Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output.
"""
await AgentNodeExecutionInputOutput.prisma().create(
data={
"name": output_name,
"data": Json(output_data),
"referencedByOutputExecId": node_exec_id,
}
data=AgentNodeExecutionInputOutputCreateInput(
name=output_name,
data=Json(output_data),
referencedByOutputExecId=node_exec_id,
)
)
async def update_graph_execution_start_time(graph_exec_id: str) -> GraphExecution:
async def update_graph_execution_start_time(
graph_exec_id: str,
) -> GraphExecution | None:
res = await AgentGraphExecution.prisma().update(
where={"id": graph_exec_id},
data={
@@ -466,10 +499,7 @@ async def update_graph_execution_start_time(graph_exec_id: str) -> GraphExecutio
},
include=GRAPH_EXECUTION_INCLUDE,
)
if not res:
raise ValueError(f"Graph execution #{graph_exec_id} not found")
return GraphExecution.from_db(res)
return GraphExecution.from_db(res) if res else None
async def update_graph_execution_stats(
@@ -480,7 +510,8 @@ async def update_graph_execution_stats(
data = stats.model_dump() if stats else {}
if isinstance(data.get("error"), Exception):
data["error"] = str(data["error"])
res = await AgentGraphExecution.prisma().update(
updated_count = await AgentGraphExecution.prisma().update_many(
where={
"id": graph_exec_id,
"OR": [
@@ -492,10 +523,15 @@ async def update_graph_execution_stats(
"executionStatus": status,
"stats": Json(data),
},
)
if updated_count == 0:
return None
graph_exec = await AgentGraphExecution.prisma().find_unique_or_raise(
where={"id": graph_exec_id},
include=GRAPH_EXECUTION_INCLUDE,
)
return GraphExecution.from_db(res) if res else None
return GraphExecution.from_db(graph_exec)
async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionStats):
@@ -579,8 +615,9 @@ async def delete_graph_execution(
)
async def get_node_execution_results(
async def get_node_executions(
graph_exec_id: str,
node_id: str | None = None,
block_ids: list[str] | None = None,
statuses: list[ExecutionStatus] | None = None,
limit: int | None = None,
@@ -588,8 +625,10 @@ async def get_node_execution_results(
where_clause: AgentNodeExecutionWhereInput = {
"agentGraphExecutionId": graph_exec_id,
}
if node_id:
where_clause["agentNodeId"] = node_id
if block_ids:
where_clause["AgentNode"] = {"is": {"agentBlockId": {"in": block_ids}}}
where_clause["Node"] = {"is": {"agentBlockId": {"in": block_ids}}}
if statuses:
where_clause["OR"] = [{"executionStatus": status} for status in statuses]
@@ -631,7 +670,7 @@ async def get_latest_node_execution(
where={
"agentNodeId": node_id,
"agentGraphExecutionId": graph_eid,
"executionStatus": {"not": ExecutionStatus.INCOMPLETE}, # type: ignore
"NOT": [{"executionStatus": ExecutionStatus.INCOMPLETE}],
},
order=[
{"queuedTime": "desc"},
@@ -644,20 +683,6 @@ async def get_latest_node_execution(
return NodeExecutionResult.from_db(execution)
async def get_incomplete_node_executions(
node_id: str, graph_eid: str
) -> list[NodeExecutionResult]:
executions = await AgentNodeExecution.prisma().find_many(
where={
"agentNodeId": node_id,
"agentGraphExecutionId": graph_eid,
"executionStatus": ExecutionStatus.INCOMPLETE,
},
include=EXECUTION_RESULT_INCLUDE,
)
return [NodeExecutionResult.from_db(execution) for execution in executions]
# ----------------- Execution Infrastructure ----------------- #
@@ -666,7 +691,7 @@ class GraphExecutionEntry(BaseModel):
graph_exec_id: str
graph_id: str
graph_version: int
start_node_execs: list["NodeExecutionEntry"]
node_credentials_input_map: Optional[dict[str, dict[str, CredentialsMetaInput]]]
class NodeExecutionEntry(BaseModel):
@@ -699,144 +724,6 @@ class ExecutionQueue(Generic[T]):
return self.queue.empty()
# ------------------- Execution Utilities -------------------- #
LIST_SPLIT = "_$_"
DICT_SPLIT = "_#_"
OBJC_SPLIT = "_@_"
def parse_execution_output(output: BlockData, name: str) -> Any | None:
"""
Extracts partial output data by name from a given BlockData.
The function supports extracting data from lists, dictionaries, and objects
using specific naming conventions:
- For lists: <output_name>_$_<index>
- For dictionaries: <output_name>_#_<key>
- For objects: <output_name>_@_<attribute>
Args:
output (BlockData): A tuple containing the output name and data.
name (str): The name used to extract specific data from the output.
Returns:
Any | None: The extracted data if found, otherwise None.
Examples:
>>> output = ("result", [10, 20, 30])
>>> parse_execution_output(output, "result_$_1")
20
>>> output = ("config", {"key1": "value1", "key2": "value2"})
>>> parse_execution_output(output, "config_#_key1")
'value1'
>>> class Sample:
... attr1 = "value1"
... attr2 = "value2"
>>> output = ("object", Sample())
>>> parse_execution_output(output, "object_@_attr1")
'value1'
"""
output_name, output_data = output
if name == output_name:
return output_data
if name.startswith(f"{output_name}{LIST_SPLIT}"):
index = int(name.split(LIST_SPLIT)[1])
if not isinstance(output_data, list) or len(output_data) <= index:
return None
return output_data[int(name.split(LIST_SPLIT)[1])]
if name.startswith(f"{output_name}{DICT_SPLIT}"):
index = name.split(DICT_SPLIT)[1]
if not isinstance(output_data, dict) or index not in output_data:
return None
return output_data[index]
if name.startswith(f"{output_name}{OBJC_SPLIT}"):
index = name.split(OBJC_SPLIT)[1]
if isinstance(output_data, object) and hasattr(output_data, index):
return getattr(output_data, index)
return None
return None
def merge_execution_input(data: BlockInput) -> BlockInput:
"""
Merges dynamic input pins into a single list, dictionary, or object based on naming patterns.
This function processes input keys that follow specific patterns to merge them into a unified structure:
- `<input_name>_$_<index>` for list inputs.
- `<input_name>_#_<index>` for dictionary inputs.
- `<input_name>_@_<index>` for object inputs.
Args:
data (BlockInput): A dictionary containing input keys and their corresponding values.
Returns:
BlockInput: A dictionary with merged inputs.
Raises:
ValueError: If a list index is not an integer.
Examples:
>>> data = {
... "list_$_0": "a",
... "list_$_1": "b",
... "dict_#_key1": "value1",
... "dict_#_key2": "value2",
... "object_@_attr1": "value1",
... "object_@_attr2": "value2"
... }
>>> merge_execution_input(data)
{
"list": ["a", "b"],
"dict": {"key1": "value1", "key2": "value2"},
"object": <MockObject attr1="value1" attr2="value2">
}
"""
# Merge all input with <input_name>_$_<index> into a single list.
items = list(data.items())
for key, value in items:
if LIST_SPLIT not in key:
continue
name, index = key.split(LIST_SPLIT)
if not index.isdigit():
raise ValueError(f"Invalid key: {key}, #{index} index must be an integer.")
data[name] = data.get(name, [])
if int(index) >= len(data[name]):
# Pad list with empty string on missing indices.
data[name].extend([""] * (int(index) - len(data[name]) + 1))
data[name][int(index)] = value
# Merge all input with <input_name>_#_<index> into a single dict.
for key, value in items:
if DICT_SPLIT not in key:
continue
name, index = key.split(DICT_SPLIT)
data[name] = data.get(name, {})
data[name][index] = value
# Merge all input with <input_name>_@_<index> into a single object.
for key, value in items:
if OBJC_SPLIT not in key:
continue
name, index = key.split(OBJC_SPLIT)
if name not in data or not isinstance(data[name], object):
data[name] = mock.MockObject()
setattr(data[name], index, value)
return data
# --------------------- Event Bus --------------------- #

View File

@@ -1,19 +1,31 @@
import logging
import uuid
from collections import defaultdict
from typing import Any, Literal, Optional, Type
from typing import Any, Literal, Optional, cast
import prisma
from prisma import Json
from prisma.enums import SubmissionStatus
from prisma.models import AgentGraph, AgentNode, AgentNodeLink, StoreListingVersion
from prisma.types import AgentGraphWhereInput
from prisma.types import (
AgentGraphCreateInput,
AgentGraphWhereInput,
AgentNodeCreateInput,
AgentNodeLinkCreateInput,
)
from pydantic import create_model
from pydantic.fields import computed_field
from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.blocks.llm import LlmModel
from backend.data.db import prisma as db
from backend.data.model import (
CredentialsField,
CredentialsFieldInfo,
CredentialsMetaInput,
is_credentials_field_name,
)
from backend.util import type as type_utils
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
@@ -53,22 +65,23 @@ class Node(BaseDbModel):
input_links: list[Link] = []
output_links: list[Link] = []
webhook_id: Optional[str] = None
@property
def block(self) -> Block[BlockSchema, BlockSchema]:
block = get_block(self.block_id)
if not block:
raise ValueError(
f"Block #{self.block_id} does not exist -> Node #{self.id} is invalid"
)
return block
class NodeModel(Node):
graph_id: str
graph_version: int
webhook_id: Optional[str] = None
webhook: Optional[Webhook] = None
@property
def block(self) -> Block[BlockSchema, BlockSchema]:
block = get_block(self.block_id)
if not block:
raise ValueError(f"Block #{self.block_id} does not exist")
return block
@staticmethod
def from_db(node: AgentNode, for_export: bool = False) -> "NodeModel":
obj = NodeModel(
@@ -88,8 +101,7 @@ class NodeModel(Node):
return obj
def is_triggered_by_event_type(self, event_type: str) -> bool:
if not (block := get_block(self.block_id)):
raise ValueError(f"Block #{self.block_id} not found for node #{self.id}")
block = self.block
if not block.webhook_config:
raise TypeError("This method can't be used on non-webhook blocks")
if not block.webhook_config.event_filter_input:
@@ -160,17 +172,18 @@ class BaseGraph(BaseDbModel):
description: str
nodes: list[Node] = []
links: list[Link] = []
forked_from_id: str | None = None
forked_from_version: int | None = None
@computed_field
@property
def input_schema(self) -> dict[str, Any]:
return self._generate_schema(
*(
(b.input_schema, node.input_default)
(block.input_schema, node.input_default)
for node in self.nodes
if (b := get_block(node.block_id))
and b.block_type == BlockType.INPUT
and issubclass(b.input_schema, AgentInputBlock.Input)
if (block := node.block).block_type == BlockType.INPUT
and issubclass(block.input_schema, AgentInputBlock.Input)
)
)
@@ -179,22 +192,26 @@ class BaseGraph(BaseDbModel):
def output_schema(self) -> dict[str, Any]:
return self._generate_schema(
*(
(b.input_schema, node.input_default)
(block.input_schema, node.input_default)
for node in self.nodes
if (b := get_block(node.block_id))
and b.block_type == BlockType.OUTPUT
and issubclass(b.input_schema, AgentOutputBlock.Input)
if (block := node.block).block_type == BlockType.OUTPUT
and issubclass(block.input_schema, AgentOutputBlock.Input)
)
)
@computed_field
@property
def credentials_input_schema(self) -> dict[str, Any]:
return self._credentials_input_schema.jsonschema()
@staticmethod
def _generate_schema(
*props: tuple[Type[AgentInputBlock.Input] | Type[AgentOutputBlock.Input], dict],
*props: tuple[type[AgentInputBlock.Input] | type[AgentOutputBlock.Input], dict],
) -> dict[str, Any]:
schema = []
schema_fields: list[AgentInputBlock.Input | AgentOutputBlock.Input] = []
for type_class, input_default in props:
try:
schema.append(type_class(**input_default))
schema_fields.append(type_class(**input_default))
except Exception as e:
logger.warning(f"Invalid {type_class}: {input_default}, {e}")
@@ -214,9 +231,93 @@ class BaseGraph(BaseDbModel):
**({"description": p.description} if p.description else {}),
**({"default": p.value} if p.value is not None else {}),
}
for p in schema
for p in schema_fields
},
"required": [p.name for p in schema if p.value is None],
"required": [p.name for p in schema_fields if p.value is None],
}
@property
def _credentials_input_schema(self) -> type[BlockSchema]:
graph_credentials_inputs = self.aggregate_credentials_inputs()
logger.debug(
f"Combined credentials input fields for graph #{self.id} ({self.name}): "
f"{graph_credentials_inputs}"
)
# Warn if same-provider credentials inputs can't be combined (= bad UX)
graph_cred_fields = list(graph_credentials_inputs.values())
for i, (field, keys) in enumerate(graph_cred_fields):
for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
if field.provider != other_field.provider:
continue
# If this happens, that means a block implementation probably needs
# to be updated.
logger.warning(
"Multiple combined credentials fields "
f"for provider {field.provider} "
f"on graph #{self.id} ({self.name}); "
f"fields: {field} <> {other_field};"
f"keys: {keys} <> {other_keys}."
)
fields: dict[str, tuple[type[CredentialsMetaInput], CredentialsMetaInput]] = {
agg_field_key: (
CredentialsMetaInput[
Literal[tuple(field_info.provider)], # type: ignore
Literal[tuple(field_info.supported_types)], # type: ignore
],
CredentialsField(
required_scopes=set(field_info.required_scopes or []),
discriminator=field_info.discriminator,
discriminator_mapping=field_info.discriminator_mapping,
),
)
for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
}
return create_model(
self.name.replace(" ", "") + "CredentialsInputSchema",
__base__=BlockSchema,
**fields, # type: ignore
)
def aggregate_credentials_inputs(
self,
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]]]]:
"""
Returns:
dict[aggregated_field_key, tuple(
CredentialsFieldInfo: A spec for one aggregated credentials field
set[(node_id, field_name)]: Node credentials fields that are
compatible with this aggregated field spec
)]
"""
return {
"_".join(sorted(agg_field_info.provider))
+ "_"
+ "_".join(sorted(agg_field_info.supported_types))
+ "_credentials": (agg_field_info, node_fields)
for agg_field_info, node_fields in CredentialsFieldInfo.combine(
*(
(
# Apply discrimination before aggregating credentials inputs
(
field_info.discriminate(
node.input_default[field_info.discriminator]
)
if (
field_info.discriminator
and node.input_default.get(field_info.discriminator)
)
else field_info
),
(node.id, field_name),
)
for node in self.nodes
for field_name, field_info in node.block.input_schema.get_credentials_fields_info().items()
)
)
}
@@ -228,13 +329,16 @@ class GraphModel(Graph):
user_id: str
nodes: list[NodeModel] = [] # type: ignore
@computed_field
@property
def starting_nodes(self) -> list[Node]:
def has_webhook_trigger(self) -> bool:
return self.webhook_input_node is not None
@property
def starting_nodes(self) -> list[NodeModel]:
outbound_nodes = {link.sink_id for link in self.links}
input_nodes = {
v.id
for v in self.nodes
if (b := get_block(v.block_id)) and b.block_type == BlockType.INPUT
node.id for node in self.nodes if node.block.block_type == BlockType.INPUT
}
return [
node
@@ -242,6 +346,18 @@ class GraphModel(Graph):
if node.id not in outbound_nodes or node.id in input_nodes
]
@property
def webhook_input_node(self) -> NodeModel | None:
return next(
(
node
for node in self.nodes
if node.block.block_type
in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
),
None,
)
def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
"""
Reassigns all IDs in the graph to new UUIDs.
@@ -295,15 +411,16 @@ class GraphModel(Graph):
@staticmethod
def _validate_graph(graph: BaseGraph, for_run: bool = False):
def is_tool_pin(name: str) -> bool:
return name.startswith("tools_^_")
def sanitize(name):
sanitized_name = name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
if sanitized_name.startswith("tools_^_"):
return sanitized_name.split("_^_")[0]
if is_tool_pin(sanitized_name):
return "tools"
return sanitized_name
# Validate smart decision maker nodes
smart_decision_maker_nodes = set()
agent_nodes = set()
nodes_block = {
node.id: block
for node in graph.nodes
@@ -314,13 +431,6 @@ class GraphModel(Graph):
if (block := nodes_block.get(node.id)) is None:
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
# Smart decision maker nodes
if block.block_type == BlockType.AI:
smart_decision_maker_nodes.add(node.id)
# Agent nodes
elif block.block_type == BlockType.AGENT:
agent_nodes.add(node.id)
input_links = defaultdict(list)
for link in graph.links:
@@ -335,16 +445,21 @@ class GraphModel(Graph):
[sanitize(name) for name in node.input_default]
+ [sanitize(link.sink_name) for link in input_links.get(node.id, [])]
)
for name in block.input_schema.get_required_fields():
input_schema = block.input_schema
for name in (required_fields := input_schema.get_required_fields()):
if (
name not in provided_inputs
# Webhook payload is passed in by ExecutionManager
and not (
name == "payload"
and block.block_type
in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
)
# Checking availability of credentials is done by ExecutionManager
and name not in input_schema.get_credentials_fields()
# Validate only I/O nodes, or validate everything when executing
and (
for_run # Skip input completion validation, unless when executing.
for_run
or block.block_type
in [
BlockType.INPUT,
@@ -357,9 +472,18 @@ class GraphModel(Graph):
f"Node {block.name} #{node.id} required input missing: `{name}`"
)
if (
block.block_type == BlockType.INPUT
and (input_key := node.input_default.get("name"))
and is_credentials_field_name(input_key)
):
raise ValueError(
f"Agent input node uses reserved name '{input_key}'; "
"'credentials' and `*_credentials` are reserved input names"
)
# Get input schema properties and check dependencies
input_schema = block.input_schema.model_fields
required_fields = block.input_schema.get_required_fields()
input_fields = input_schema.model_fields
def has_value(name):
return (
@@ -367,14 +491,21 @@ class GraphModel(Graph):
and name in node.input_default
and node.input_default[name] is not None
and str(node.input_default[name]).strip() != ""
) or (name in input_schema and input_schema[name].default is not None)
) or (name in input_fields and input_fields[name].default is not None)
# Validate dependencies between fields
for field_name, field_info in input_schema.items():
for field_name, field_info in input_fields.items():
# Apply input dependency validation only on run & field with depends_on
json_schema_extra = field_info.json_schema_extra or {}
dependencies = json_schema_extra.get("depends_on", [])
if not for_run or not dependencies:
if not (
for_run
and isinstance(json_schema_extra, dict)
and (
dependencies := cast(
list[str], json_schema_extra.get("depends_on", [])
)
)
):
continue
# Check if dependent field has value in input_default
@@ -391,9 +522,7 @@ class GraphModel(Graph):
node_map = {v.id: v for v in graph.nodes}
def is_static_output_block(nid: str) -> bool:
bid = node_map[nid].block_id
b = get_block(bid)
return b.static_output if b else False
return node_map[nid].block.static_output
# Links: links are connected and the connected pin data type are compatible.
for link in graph.links:
@@ -429,7 +558,7 @@ class GraphModel(Graph):
if block.block_type not in [BlockType.AGENT]
else vals.get("input_schema", {}).get("properties", {}).keys()
)
if sanitized_name not in fields and not name.startswith("tools_^_"):
if sanitized_name not in fields and not is_tool_pin(name):
fields_msg = f"Allowed fields: {fields}"
raise ValueError(f"{prefix}, `{name}` invalid, {fields_msg}")
@@ -446,16 +575,16 @@ class GraphModel(Graph):
id=graph.id,
user_id=graph.userId if not for_export else "",
version=graph.version,
forked_from_id=graph.forkedFromId,
forked_from_version=graph.forkedFromVersion,
is_active=graph.isActive,
name=graph.name or "",
description=graph.description or "",
nodes=[
NodeModel.from_db(node, for_export) for node in graph.AgentNodes or []
],
nodes=[NodeModel.from_db(node, for_export) for node in graph.Nodes or []],
links=list(
{
Link.from_db(link)
for node in graph.AgentNodes or []
for node in graph.Nodes or []
for link in (node.Input or []) + (node.Output or [])
}
),
@@ -586,8 +715,8 @@ async def get_graph(
and not (
await StoreListingVersion.prisma().find_first(
where={
"agentId": graph_id,
"agentVersion": version or graph.version,
"agentGraphId": graph_id,
"agentGraphVersion": version or graph.version,
"isDeleted": False,
"submissionStatus": SubmissionStatus.APPROVED,
}
@@ -607,6 +736,58 @@ async def get_graph(
return GraphModel.from_db(graph, for_export)
async def get_graph_as_admin(
graph_id: str,
version: int | None = None,
user_id: str | None = None,
for_export: bool = False,
) -> GraphModel | None:
"""
Intentionally parallels the get_graph but should only be used for admin tasks, because can return any graph that's been submitted
Retrieves a graph from the DB.
Defaults to the version with `is_active` if `version` is not passed.
Returns `None` if the record is not found.
"""
logger.warning(f"Getting {graph_id=} {version=} as ADMIN {user_id=} {for_export=}")
where_clause: AgentGraphWhereInput = {
"id": graph_id,
}
if version is not None:
where_clause["version"] = version
graph = await AgentGraph.prisma().find_first(
where=where_clause,
include=AGENT_GRAPH_INCLUDE,
order={"version": "desc"},
)
# For access, the graph must be owned by the user or listed in the store
if graph is None or (
graph.userId != user_id
and not (
await StoreListingVersion.prisma().find_first(
where={
"agentGraphId": graph_id,
"agentGraphVersion": version or graph.version,
}
)
)
):
return None
if for_export:
sub_graphs = await get_sub_graphs(graph)
return GraphModel.from_db(
graph=graph,
sub_graphs=sub_graphs,
for_export=for_export,
)
return GraphModel.from_db(graph, for_export)
async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
"""
Iteratively fetches all sub-graphs of a given graph, and flattens them into a list.
@@ -621,12 +802,16 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
sub_graph_ids = [
(graph_id, graph_version)
for graph in search_graphs
for node in graph.AgentNodes or []
for node in graph.Nodes or []
if (
node.AgentBlock
and node.AgentBlock.id == agent_block_id
and (graph_id := dict(node.constantInput).get("graph_id"))
and (graph_version := dict(node.constantInput).get("graph_version"))
and (graph_id := cast(str, dict(node.constantInput).get("graph_id")))
and (
graph_version := cast(
int, dict(node.constantInput).get("graph_version")
)
)
)
]
if not sub_graph_ids:
@@ -641,7 +826,7 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
"userId": graph.userId, # Ensure the sub-graph is owned by the same user
}
for graph_id, graph_version in sub_graph_ids
] # type: ignore
]
},
include=AGENT_GRAPH_INCLUDE,
)
@@ -655,7 +840,7 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
async def get_connected_output_nodes(node_id: str) -> list[tuple[Link, Node]]:
links = await AgentNodeLink.prisma().find_many(
where={"agentNodeSourceId": node_id},
include={"AgentNodeSink": {"include": AGENT_NODE_INCLUDE}}, # type: ignore
include={"AgentNodeSink": {"include": AGENT_NODE_INCLUDE}},
)
return [
(Link.from_db(link), NodeModel.from_db(link.AgentNodeSink))
@@ -721,34 +906,56 @@ async def create_graph(graph: Graph, user_id: str) -> GraphModel:
raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
async def fork_graph(graph_id: str, graph_version: int, user_id: str) -> GraphModel:
"""
Forks a graph by copying it and all its nodes and links to a new graph.
"""
async with transaction() as tx:
graph = await get_graph(graph_id, graph_version, user_id, True)
if not graph:
raise ValueError(f"Graph {graph_id} v{graph_version} not found")
# Set forked from ID and version as itself as it's about ot be copied
graph.forked_from_id = graph.id
graph.forked_from_version = graph.version
graph.name = f"{graph.name} (copy)"
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
graph.validate_graph(for_run=False)
await __create_graph(tx, graph, user_id)
return graph
async def __create_graph(tx, graph: Graph, user_id: str):
graphs = [graph] + graph.sub_graphs
await AgentGraph.prisma(tx).create_many(
data=[
{
"id": graph.id,
"version": graph.version,
"name": graph.name,
"description": graph.description,
"isActive": graph.is_active,
"userId": user_id,
}
AgentGraphCreateInput(
id=graph.id,
version=graph.version,
name=graph.name,
description=graph.description,
isActive=graph.is_active,
userId=user_id,
forkedFromId=graph.forked_from_id,
forkedFromVersion=graph.forked_from_version,
)
for graph in graphs
]
)
await AgentNode.prisma(tx).create_many(
data=[
{
"id": node.id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"agentBlockId": node.block_id,
"constantInput": Json(node.input_default),
"metadata": Json(node.metadata),
"webhookId": node.webhook_id,
}
AgentNodeCreateInput(
id=node.id,
agentGraphId=graph.id,
agentGraphVersion=graph.version,
agentBlockId=node.block_id,
constantInput=Json(node.input_default),
metadata=Json(node.metadata),
)
for graph in graphs
for node in graph.nodes
]
@@ -756,14 +963,14 @@ async def __create_graph(tx, graph: Graph, user_id: str):
await AgentNodeLink.prisma(tx).create_many(
data=[
{
"id": str(uuid.uuid4()),
"sourceName": link.source_name,
"sinkName": link.sink_name,
"agentNodeSourceId": link.source_id,
"agentNodeSinkId": link.sink_id,
"isStatic": link.is_static,
}
AgentNodeLinkCreateInput(
id=str(uuid.uuid4()),
sourceName=link.source_name,
sinkName=link.sink_name,
agentNodeSourceId=link.source_id,
agentNodeSinkId=link.sink_id,
isStatic=link.is_static,
)
for graph in graphs
for link in graph.links
]
@@ -814,12 +1021,12 @@ async def fix_llm_provider_credentials():
SELECT graph."userId" user_id,
node.id node_id,
node."constantInput" node_preset_input
FROM platform."AgentNode" node
LEFT JOIN platform."AgentGraph" graph
ON node."agentGraphId" = graph.id
WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm'
ORDER BY graph."userId";
"""
FROM platform."AgentNode" node
LEFT JOIN platform."AgentGraph" graph
ON node."agentGraphId" = graph.id
WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm'
ORDER BY graph."userId";
"""
)
logger.info(f"Fixing LLM credential inputs on {len(broken_nodes)} nodes")
except Exception as e:
@@ -897,17 +1104,24 @@ async def migrate_llm_models(migrate_to: LlmModel):
if field.annotation == LlmModel:
llm_model_fields[block.id] = field_name
# Convert enum values to a list of strings for the SQL query
enum_values = [v.value for v in LlmModel]
escaped_enum_values = repr(tuple(enum_values)) # hack but works
# Update each block
for id, path in llm_model_fields.items():
# Convert enum values to a list of strings for the SQL query
enum_values = [v.value for v in LlmModel.__members__.values()]
query = f"""
UPDATE "AgentNode"
SET "constantInput" = jsonb_set("constantInput", '{{{path}}}', '"{migrate_to.value}"', true)
WHERE "agentBlockId" = '{id}'
AND "constantInput" ? '{path}'
AND "constantInput"->>'{path}' NOT IN ({','.join(f"'{value}'" for value in enum_values)})
UPDATE platform."AgentNode"
SET "constantInput" = jsonb_set("constantInput", $1, to_jsonb($2), true)
WHERE "agentBlockId" = $3
AND "constantInput" ? ($4)::text
AND "constantInput"->>($4)::text NOT IN {escaped_enum_values}
"""
await db.execute_raw(query)
await db.execute_raw(
query, # type: ignore - is supposed to be LiteralString
[path],
migrate_to.value,
id,
path,
)

View File

@@ -1,4 +1,7 @@
import prisma
from typing import cast
import prisma.enums
import prisma.types
from backend.blocks.io import IO_BLOCK_IDs
@@ -10,25 +13,25 @@ AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
}
AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
"Nodes": {"include": AGENT_NODE_INCLUDE}
}
EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
"Input": True,
"Output": True,
"AgentNode": True,
"AgentGraphExecution": True,
"Node": True,
"GraphExecution": True,
}
MAX_NODE_EXECUTIONS_FETCH = 1000
GRAPH_EXECUTION_INCLUDE_WITH_NODES: prisma.types.AgentGraphExecutionInclude = {
"AgentNodeExecutions": {
"NodeExecutions": {
"include": {
"Input": True,
"Output": True,
"AgentNode": True,
"AgentGraphExecution": True,
"Node": True,
"GraphExecution": True,
},
"order_by": [
{"queuedTime": "desc"},
@@ -40,28 +43,30 @@ GRAPH_EXECUTION_INCLUDE_WITH_NODES: prisma.types.AgentGraphExecutionInclude = {
}
GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
"AgentNodeExecutions": {
**GRAPH_EXECUTION_INCLUDE_WITH_NODES["AgentNodeExecutions"], # type: ignore
"NodeExecutions": {
**cast(
prisma.types.FindManyAgentNodeExecutionArgsFromAgentGraphExecution,
GRAPH_EXECUTION_INCLUDE_WITH_NODES["NodeExecutions"],
),
"where": {
"AgentNode": {
"AgentBlock": {"id": {"in": IO_BLOCK_IDs}}, # type: ignore
},
"Node": {"is": {"AgentBlock": {"is": {"id": {"in": IO_BLOCK_IDs}}}}},
"NOT": [{"executionStatus": prisma.enums.AgentExecutionStatus.INCOMPLETE}],
},
}
}
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
"AgentNodes": {"include": AGENT_NODE_INCLUDE}
}
def library_agent_include(user_id: str) -> prisma.types.LibraryAgentInclude:
return {
"Agent": {
"AgentGraph": {
"include": {
**AGENT_GRAPH_INCLUDE,
"AgentGraphExecution": {"where": {"userId": user_id}},
"Executions": {"where": {"userId": user_id}},
}
},
"Creator": True,

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, AsyncGenerator, Optional
from prisma import Json
from prisma.models import IntegrationWebhook
from prisma.types import IntegrationWebhookCreateInput
from pydantic import Field, computed_field
from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
@@ -66,18 +67,18 @@ class Webhook(BaseDbModel):
async def create_webhook(webhook: Webhook) -> Webhook:
created_webhook = await IntegrationWebhook.prisma().create(
data={
"id": webhook.id,
"userId": webhook.user_id,
"provider": webhook.provider.value,
"credentialsId": webhook.credentials_id,
"webhookType": webhook.webhook_type,
"resource": webhook.resource,
"events": webhook.events,
"config": Json(webhook.config),
"secret": webhook.secret,
"providerWebhookId": webhook.provider_webhook_id,
}
data=IntegrationWebhookCreateInput(
id=webhook.id,
userId=webhook.user_id,
provider=webhook.provider.value,
credentialsId=webhook.credentials_id,
webhookType=webhook.webhook_type,
resource=webhook.resource,
events=webhook.events,
config=Json(webhook.config),
secret=webhook.secret,
providerWebhookId=webhook.provider_webhook_id,
)
)
return Webhook.from_db(created_webhook)

View File

@@ -1,7 +1,9 @@
from __future__ import annotations
import base64
import enum
import logging
from collections import defaultdict
from datetime import datetime, timezone
from typing import (
TYPE_CHECKING,
@@ -12,6 +14,7 @@ from typing import (
Generic,
Literal,
Optional,
Sequence,
TypedDict,
TypeVar,
get_args,
@@ -142,8 +145,12 @@ def SchemaField(
exclude: bool = False,
hidden: Optional[bool] = None,
depends_on: Optional[list[str]] = None,
ge: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
discriminator: Optional[str] = None,
json_schema_extra: Optional[dict[str, Any]] = None,
**kwargs,
) -> T:
if default is PydanticUndefined and default_factory is None:
advanced = False
@@ -170,8 +177,12 @@ def SchemaField(
title=title,
description=description,
exclude=exclude,
ge=ge,
le=le,
min_length=min_length,
max_length=max_length,
discriminator=discriminator,
json_schema_extra=json_schema_extra,
**kwargs,
) # type: ignore
@@ -292,9 +303,7 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
)
field_schema = model.jsonschema()["properties"][field_name]
try:
schema_extra = _CredentialsFieldSchemaExtra[CP, CT].model_validate(
field_schema
)
schema_extra = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
except ValidationError as e:
if "Field required [type=missing" not in str(e):
raise
@@ -320,14 +329,90 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
)
class _CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]):
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
# TODO: move discrimination mechanism out of CredentialsField (frontend + backend)
credentials_provider: list[CP]
credentials_scopes: Optional[list[str]] = None
credentials_types: list[CT]
provider: frozenset[CP] = Field(..., alias="credentials_provider")
supported_types: frozenset[CT] = Field(..., alias="credentials_types")
required_scopes: Optional[frozenset[str]] = Field(None, alias="credentials_scopes")
discriminator: Optional[str] = None
discriminator_mapping: Optional[dict[str, CP]] = None
@classmethod
def combine(
cls, *fields: tuple[CredentialsFieldInfo[CP, CT], T]
) -> Sequence[tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
"""
Combines multiple CredentialsFieldInfo objects into as few as possible.
Rules:
- Items can only be combined if they have the same supported credentials types
and the same supported providers.
- When combining items, the `required_scopes` of the result is a join
of the `required_scopes` of the original items.
Params:
*fields: (CredentialsFieldInfo, key) objects to group and combine
Returns:
A sequence of tuples containing combined CredentialsFieldInfo objects and
the set of keys of the respective original items that were grouped together.
"""
if not fields:
return []
# Group fields by their provider and supported_types
grouped_fields: defaultdict[
tuple[frozenset[CP], frozenset[CT]],
list[tuple[T, CredentialsFieldInfo[CP, CT]]],
] = defaultdict(list)
for field, key in fields:
group_key = (frozenset(field.provider), frozenset(field.supported_types))
grouped_fields[group_key].append((key, field))
# Combine fields within each group
result: list[tuple[CredentialsFieldInfo[CP, CT], set[T]]] = []
for group in grouped_fields.values():
# Start with the first field in the group
_, combined = group[0]
# Track the keys that were combined
combined_keys = {key for key, _ in group}
# Combine required_scopes from all fields in the group
all_scopes = set()
for _, field in group:
if field.required_scopes:
all_scopes.update(field.required_scopes)
# Create a new combined field
result.append(
(
CredentialsFieldInfo[CP, CT](
credentials_provider=combined.provider,
credentials_types=combined.supported_types,
credentials_scopes=frozenset(all_scopes) or None,
discriminator=combined.discriminator,
discriminator_mapping=combined.discriminator_mapping,
),
combined_keys,
)
)
return result
def discriminate(self, discriminator_value: Any) -> CredentialsFieldInfo:
if not (self.discriminator and self.discriminator_mapping):
return self
discriminator_value = self.discriminator_mapping[discriminator_value]
return CredentialsFieldInfo(
credentials_provider=frozenset([discriminator_value]),
credentials_types=self.supported_types,
credentials_scopes=self.required_scopes,
)
def CredentialsField(
required_scopes: set[str] = set(),
@@ -365,6 +450,12 @@ class ContributorDetails(BaseModel):
name: str = Field(title="Name", description="The name of the contributor.")
class TopUpType(enum.Enum):
AUTO = "AUTO"
MANUAL = "MANUAL"
UNCATEGORIZED = "UNCATEGORIZED"
class AutoTopUpConfig(BaseModel):
amount: int
"""Amount of credits to top up."""
@@ -377,12 +468,18 @@ class UserTransaction(BaseModel):
transaction_time: datetime = datetime.min.replace(tzinfo=timezone.utc)
transaction_type: CreditTransactionType = CreditTransactionType.USAGE
amount: int = 0
balance: int = 0
running_balance: int = 0
current_balance: int = 0
description: str | None = None
usage_graph_id: str | None = None
usage_execution_id: str | None = None
usage_node_count: int = 0
usage_start_time: datetime = datetime.max.replace(tzinfo=timezone.utc)
user_id: str
user_email: str | None = None
reason: str | None = None
admin_email: str | None = None
extra_data: str | None = None
class TransactionHistory(BaseModel):
@@ -405,9 +502,10 @@ class RefundRequest(BaseModel):
class NodeExecutionStats(BaseModel):
"""Execution statistics for a node execution."""
class Config:
arbitrary_types_allowed = True
extra = "allow"
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
)
error: Optional[Exception | str] = None
walltime: float = 0
@@ -423,9 +521,10 @@ class NodeExecutionStats(BaseModel):
class GraphExecutionStats(BaseModel):
"""Execution statistics for a graph execution."""
class Config:
arbitrary_types_allowed = True
extra = "allow"
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
)
error: Optional[Exception | str] = None
walltime: float = Field(

View File

@@ -6,10 +6,14 @@ from typing import Annotated, Any, Generic, Optional, TypeVar, Union
from prisma import Json
from prisma.enums import NotificationType
from prisma.models import NotificationEvent, UserNotificationBatch
from prisma.types import UserNotificationBatchWhereInput
from prisma.types import (
NotificationEventCreateInput,
UserNotificationBatchCreateInput,
UserNotificationBatchWhereInput,
)
# from backend.notifications.models import NotificationEvent
from pydantic import BaseModel, EmailStr, Field, field_validator
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
from backend.server.v2.store.exceptions import DatabaseError
@@ -35,8 +39,7 @@ class QueueType(Enum):
class BaseNotificationData(BaseModel):
class Config:
extra = "allow"
model_config = ConfigDict(extra="allow")
class AgentRunData(BaseNotificationData):
@@ -418,30 +421,30 @@ async def create_or_add_to_user_notification_batch(
if not existing_batch:
async with transaction() as tx:
notification_event = await tx.notificationevent.create(
data={
"type": notification_type,
"data": json_data,
}
data=NotificationEventCreateInput(
type=notification_type,
data=json_data,
)
)
# Create new batch
resp = await tx.usernotificationbatch.create(
data={
"userId": user_id,
"type": notification_type,
"Notifications": {"connect": [{"id": notification_event.id}]},
},
data=UserNotificationBatchCreateInput(
userId=user_id,
type=notification_type,
Notifications={"connect": [{"id": notification_event.id}]},
),
include={"Notifications": True},
)
return UserNotificationBatchDTO.from_db(resp)
else:
async with transaction() as tx:
notification_event = await tx.notificationevent.create(
data={
"type": notification_type,
"data": json_data,
"UserNotificationBatch": {"connect": {"id": existing_batch.id}},
}
data=NotificationEventCreateInput(
type=notification_type,
data=json_data,
UserNotificationBatch={"connect": {"id": existing_batch.id}},
)
)
# Add to existing batch
resp = await tx.usernotificationbatch.update(

View File

@@ -6,9 +6,11 @@ import pydantic
from prisma import Json
from prisma.enums import OnboardingStep
from prisma.models import UserOnboarding
from prisma.types import UserOnboardingUpdateInput
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
from backend.data import db
from backend.data.block import get_blocks
from backend.data.credit import get_user_credit_model
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.server.v2.store.model import StoreAgentDetails
@@ -24,21 +26,26 @@ REASON_MAPPING: dict[str, list[str]] = {
POINTS_AGENT_COUNT = 50 # Number of agents to calculate points for
MIN_AGENT_COUNT = 2 # Minimum number of marketplace agents to enable onboarding
user_credit = get_user_credit_model()
class UserOnboardingUpdate(pydantic.BaseModel):
completedSteps: Optional[list[OnboardingStep]] = None
notificationDot: Optional[bool] = None
notified: Optional[list[OnboardingStep]] = None
usageReason: Optional[str] = None
integrations: Optional[list[str]] = None
otherIntegrations: Optional[str] = None
selectedStoreListingVersionId: Optional[str] = None
agentInput: Optional[dict[str, Any]] = None
onboardingAgentExecutionId: Optional[str] = None
async def get_user_onboarding(user_id: str):
return await UserOnboarding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id}, # type: ignore
"create": UserOnboardingCreateInput(userId=user_id),
"update": {},
},
)
@@ -48,6 +55,20 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
update: UserOnboardingUpdateInput = {}
if data.completedSteps is not None:
update["completedSteps"] = list(set(data.completedSteps))
for step in (
OnboardingStep.AGENT_NEW_RUN,
OnboardingStep.GET_RESULTS,
OnboardingStep.MARKETPLACE_ADD_AGENT,
OnboardingStep.MARKETPLACE_RUN_AGENT,
OnboardingStep.BUILDER_SAVE_AGENT,
OnboardingStep.BUILDER_RUN_AGENT,
):
if step in data.completedSteps:
await reward_user(user_id, step)
if data.notificationDot is not None:
update["notificationDot"] = data.notificationDot
if data.notified is not None:
update["notified"] = list(set(data.notified))
if data.usageReason is not None:
update["usageReason"] = data.usageReason
if data.integrations is not None:
@@ -58,16 +79,57 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
update["selectedStoreListingVersionId"] = data.selectedStoreListingVersionId
if data.agentInput is not None:
update["agentInput"] = Json(data.agentInput)
if data.onboardingAgentExecutionId is not None:
update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId
return await UserOnboarding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, **update}, # type: ignore
"create": {"userId": user_id, **update},
"update": update,
},
)
async def reward_user(user_id: str, step: OnboardingStep):
async with db.locked_transaction(f"usr_trx_{user_id}-reward"):
reward = 0
match step:
# Reward user when they clicked New Run during onboarding
# This is because they need credits before scheduling a run (next step)
case OnboardingStep.AGENT_NEW_RUN:
reward = 300
case OnboardingStep.GET_RESULTS:
reward = 300
case OnboardingStep.MARKETPLACE_ADD_AGENT:
reward = 100
case OnboardingStep.MARKETPLACE_RUN_AGENT:
reward = 100
case OnboardingStep.BUILDER_SAVE_AGENT:
reward = 100
case OnboardingStep.BUILDER_RUN_AGENT:
reward = 100
if reward == 0:
return
onboarding = await get_user_onboarding(user_id)
# Skip if already rewarded
if step in onboarding.rewardedFor:
return
onboarding.rewardedFor.append(step)
await user_credit.onboarding_reward(user_id, reward, step)
await UserOnboarding.prisma().update(
where={"userId": user_id},
data={
"completedSteps": list(set(onboarding.completedSteps + [step])),
"rewardedFor": onboarding.rewardedFor,
},
)
def clean_and_split(text: str) -> list[str]:
"""
Removes all special characters from a string, truncates it to 100 characters,
@@ -186,11 +248,11 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
where={
"id": {"in": [agent.storeListingVersionId for agent in storeAgents]},
},
include={"Agent": True},
include={"AgentGraph": True},
)
for listing in agentListings:
agent = listing.Agent
agent = listing.AgentGraph
if agent is None:
continue
graph = GraphModel.from_db(agent)

View File

@@ -4,10 +4,18 @@ from enum import Enum
from typing import Awaitable, Optional
import aio_pika
import aio_pika.exceptions as aio_ex
import pika
import pika.adapters.blocking_connection
from pika.exceptions import AMQPError
from pika.spec import BasicProperties
from pydantic import BaseModel
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from backend.util.retry import conn_retry
from backend.util.settings import Settings
@@ -161,6 +169,12 @@ class SyncRabbitMQ(RabbitMQBase):
routing_key=queue.routing_key or queue.name,
)
@retry(
retry=retry_if_exception_type((AMQPError, ConnectionError)),
wait=wait_random_exponential(multiplier=1, max=5),
stop=stop_after_attempt(5),
reraise=True,
)
def publish_message(
self,
routing_key: str,
@@ -258,6 +272,12 @@ class AsyncRabbitMQ(RabbitMQBase):
exchange, routing_key=queue.routing_key or queue.name
)
@retry(
retry=retry_if_exception_type((aio_ex.AMQPError, ConnectionError)),
wait=wait_random_exponential(multiplier=1, max=5),
stop=stop_after_attempt(5),
reraise=True,
)
async def publish_message(
self,
routing_key: str,

View File

@@ -11,7 +11,7 @@ from fastapi import HTTPException
from prisma import Json
from prisma.enums import NotificationType
from prisma.models import User
from prisma.types import UserUpdateInput
from prisma.types import JsonFilter, UserCreateInput, UserUpdateInput
from backend.data.db import prisma
from backend.data.model import UserIntegrations, UserMetadata, UserMetadataRaw
@@ -36,11 +36,11 @@ async def get_or_create_user(user_data: dict) -> User:
user = await prisma.user.find_unique(where={"id": user_id})
if not user:
user = await prisma.user.create(
data={
"id": user_id,
"email": user_email,
"name": user_data.get("user_metadata", {}).get("name"),
}
data=UserCreateInput(
id=user_id,
email=user_email,
name=user_data.get("user_metadata", {}).get("name"),
)
)
return User.model_validate(user)
@@ -84,11 +84,11 @@ async def create_default_user() -> Optional[User]:
user = await prisma.user.find_unique(where={"id": DEFAULT_USER_ID})
if not user:
user = await prisma.user.create(
data={
"id": DEFAULT_USER_ID,
"email": "default@example.com",
"name": "Default User",
}
data=UserCreateInput(
id=DEFAULT_USER_ID,
email="default@example.com",
name="Default User",
)
)
return User.model_validate(user)
@@ -135,16 +135,21 @@ async def migrate_and_encrypt_user_integrations():
"""Migrate integration credentials and OAuth states from metadata to integrations column."""
users = await User.prisma().find_many(
where={
"metadata": {
"path": ["integration_credentials"],
"not": Json({"a": "yolo"}), # bogus value works to check if key exists
} # type: ignore
"metadata": cast(
JsonFilter,
{
"path": ["integration_credentials"],
"not": Json(
{"a": "yolo"}
), # bogus value works to check if key exists
},
)
}
)
logger.info(f"Migrating integration credentials for {len(users)} users")
for user in users:
raw_metadata = cast(UserMetadataRaw, user.metadata)
raw_metadata = cast(dict, user.metadata)
metadata = UserMetadata.model_validate(raw_metadata)
# Get existing integrations data
@@ -160,7 +165,6 @@ async def migrate_and_encrypt_user_integrations():
await update_user_integrations(user_id=user.id, data=integrations)
# Remove from metadata
raw_metadata = dict(raw_metadata)
raw_metadata.pop("integration_credentials", None)
raw_metadata.pop("integration_oauth_states", None)

View File

@@ -1,9 +1,10 @@
from .database import DatabaseManager
from .database import DatabaseManager, DatabaseManagerClient
from .manager import ExecutionManager
from .scheduler import Scheduler
__all__ = [
"DatabaseManager",
"DatabaseManagerClient",
"ExecutionManager",
"Scheduler",
]

View File

@@ -1,13 +1,14 @@
import logging
from typing import Callable, Concatenate, ParamSpec, TypeVar, cast
from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
GraphExecution,
NodeExecutionResult,
RedisExecutionEventBus,
create_graph_execution,
get_graph_execution,
get_incomplete_node_executions,
get_graph_execution_meta,
get_latest_node_execution,
get_node_execution_results,
get_node_executions,
update_graph_execution_start_time,
update_graph_execution_stats,
update_node_execution_stats,
@@ -39,11 +40,14 @@ from backend.data.user import (
update_user_integrations,
update_user_metadata,
)
from backend.util.service import AppService, expose, exposed_run_and_wait
from backend.util.service import AppService, AppServiceClient, endpoint_to_sync, expose
from backend.util.settings import Config
config = Config()
_user_credit_model = get_user_credit_model()
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
async def _spend_credits(
@@ -52,75 +56,133 @@ async def _spend_credits(
return await _user_credit_model.spend_credits(user_id, cost, metadata)
async def _get_credits(user_id: str) -> int:
return await _user_credit_model.get_credits(user_id)
class DatabaseManager(AppService):
def __init__(self):
super().__init__()
self.use_db = True
self.use_redis = True
self.execution_event_bus = RedisExecutionEventBus()
def run_service(self) -> None:
logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
self.run_and_wait(db.connect())
super().run_service()
def cleanup(self):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
self.run_and_wait(db.disconnect())
@classmethod
def get_port(cls) -> int:
return config.database_api_port
@expose
def send_execution_update(
self, execution_result: GraphExecution | NodeExecutionResult
):
self.execution_event_bus.publish(execution_result)
@staticmethod
def _(
f: Callable[P, R], name: str | None = None
) -> Callable[Concatenate[object, P], R]:
if name is not None:
f.__name__ = name
return cast(Callable[Concatenate[object, P], R], expose(f))
# Executions
get_graph_execution = exposed_run_and_wait(get_graph_execution)
create_graph_execution = exposed_run_and_wait(create_graph_execution)
get_node_execution_results = exposed_run_and_wait(get_node_execution_results)
get_incomplete_node_executions = exposed_run_and_wait(
get_incomplete_node_executions
)
get_latest_node_execution = exposed_run_and_wait(get_latest_node_execution)
update_node_execution_status = exposed_run_and_wait(update_node_execution_status)
update_node_execution_status_batch = exposed_run_and_wait(
update_node_execution_status_batch
)
update_graph_execution_start_time = exposed_run_and_wait(
update_graph_execution_start_time
)
update_graph_execution_stats = exposed_run_and_wait(update_graph_execution_stats)
update_node_execution_stats = exposed_run_and_wait(update_node_execution_stats)
upsert_execution_input = exposed_run_and_wait(upsert_execution_input)
upsert_execution_output = exposed_run_and_wait(upsert_execution_output)
get_graph_execution = _(get_graph_execution)
get_graph_execution_meta = _(get_graph_execution_meta)
create_graph_execution = _(create_graph_execution)
get_node_executions = _(get_node_executions)
get_latest_node_execution = _(get_latest_node_execution)
update_node_execution_status = _(update_node_execution_status)
update_node_execution_status_batch = _(update_node_execution_status_batch)
update_graph_execution_start_time = _(update_graph_execution_start_time)
update_graph_execution_stats = _(update_graph_execution_stats)
update_node_execution_stats = _(update_node_execution_stats)
upsert_execution_input = _(upsert_execution_input)
upsert_execution_output = _(upsert_execution_output)
# Graphs
get_node = exposed_run_and_wait(get_node)
get_graph = exposed_run_and_wait(get_graph)
get_connected_output_nodes = exposed_run_and_wait(get_connected_output_nodes)
get_graph_metadata = exposed_run_and_wait(get_graph_metadata)
get_node = _(get_node)
get_graph = _(get_graph)
get_connected_output_nodes = _(get_connected_output_nodes)
get_graph_metadata = _(get_graph_metadata)
# Credits
spend_credits = exposed_run_and_wait(_spend_credits)
spend_credits = _(_spend_credits, name="spend_credits")
get_credits = _(_get_credits, name="get_credits")
# User + User Metadata + User Integrations
get_user_metadata = exposed_run_and_wait(get_user_metadata)
update_user_metadata = exposed_run_and_wait(update_user_metadata)
get_user_integrations = exposed_run_and_wait(get_user_integrations)
update_user_integrations = exposed_run_and_wait(update_user_integrations)
get_user_metadata = _(get_user_metadata)
update_user_metadata = _(update_user_metadata)
get_user_integrations = _(get_user_integrations)
update_user_integrations = _(update_user_integrations)
# User Comms - async
get_active_user_ids_in_timerange = exposed_run_and_wait(
get_active_user_ids_in_timerange
)
get_user_email_by_id = exposed_run_and_wait(get_user_email_by_id)
get_user_email_verification = exposed_run_and_wait(get_user_email_verification)
get_user_notification_preference = exposed_run_and_wait(
get_user_notification_preference
)
get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
get_user_email_by_id = _(get_user_email_by_id)
get_user_email_verification = _(get_user_email_verification)
get_user_notification_preference = _(get_user_notification_preference)
# Notifications - async
create_or_add_to_user_notification_batch = exposed_run_and_wait(
create_or_add_to_user_notification_batch = _(
create_or_add_to_user_notification_batch
)
empty_user_notification_batch = exposed_run_and_wait(empty_user_notification_batch)
get_all_batches_by_type = exposed_run_and_wait(get_all_batches_by_type)
get_user_notification_batch = exposed_run_and_wait(get_user_notification_batch)
get_user_notification_oldest_message_in_batch = exposed_run_and_wait(
empty_user_notification_batch = _(empty_user_notification_batch)
get_all_batches_by_type = _(get_all_batches_by_type)
get_user_notification_batch = _(get_user_notification_batch)
get_user_notification_oldest_message_in_batch = _(
get_user_notification_oldest_message_in_batch
)
class DatabaseManagerClient(AppServiceClient):
d = DatabaseManager
_ = endpoint_to_sync
@classmethod
def get_service_type(cls):
return DatabaseManager
# Executions
get_graph_execution = _(d.get_graph_execution)
get_graph_execution_meta = _(d.get_graph_execution_meta)
create_graph_execution = _(d.create_graph_execution)
get_node_executions = _(d.get_node_executions)
get_latest_node_execution = _(d.get_latest_node_execution)
update_node_execution_status = _(d.update_node_execution_status)
update_node_execution_status_batch = _(d.update_node_execution_status_batch)
update_graph_execution_start_time = _(d.update_graph_execution_start_time)
update_graph_execution_stats = _(d.update_graph_execution_stats)
update_node_execution_stats = _(d.update_node_execution_stats)
upsert_execution_input = _(d.upsert_execution_input)
upsert_execution_output = _(d.upsert_execution_output)
# Graphs
get_node = _(d.get_node)
get_graph = _(d.get_graph)
get_connected_output_nodes = _(d.get_connected_output_nodes)
get_graph_metadata = _(d.get_graph_metadata)
# Credits
spend_credits = _(d.spend_credits)
get_credits = _(d.get_credits)
# User + User Metadata + User Integrations
get_user_metadata = _(d.get_user_metadata)
update_user_metadata = _(d.update_user_metadata)
get_user_integrations = _(d.get_user_integrations)
update_user_integrations = _(d.update_user_integrations)
# User Comms - async
get_active_user_ids_in_timerange = _(d.get_active_user_ids_in_timerange)
get_user_email_by_id = _(d.get_user_email_by_id)
get_user_email_verification = _(d.get_user_email_verification)
get_user_notification_preference = _(d.get_user_notification_preference)
# Notifications - async
create_or_add_to_user_notification_batch = _(
d.create_or_add_to_user_notification_batch
)
empty_user_notification_batch = _(d.empty_user_notification_batch)
get_all_batches_by_type = _(d.get_all_batches_by_type)
get_user_notification_batch = _(d.get_user_notification_batch)
get_user_notification_oldest_message_in_batch = _(
d.get_user_notification_oldest_message_in_batch
)

File diff suppressed because it is too large Load Diff

View File

@@ -16,9 +16,15 @@ from pydantic import BaseModel
from sqlalchemy import MetaData, create_engine
from backend.data.block import BlockInput
from backend.executor.manager import ExecutionManager
from backend.notifications.notifications import NotificationManager
from backend.util.service import AppService, expose, get_service_client
from backend.executor import utils as execution_utils
from backend.notifications.notifications import NotificationManagerClient
from backend.util.service import (
AppService,
AppServiceClient,
endpoint_to_async,
expose,
get_service_client,
)
from backend.util.settings import Config
@@ -57,25 +63,18 @@ def job_listener(event):
log(f"Job {event.job_id} completed successfully.")
@thread_cached
def get_execution_client() -> ExecutionManager:
return get_service_client(ExecutionManager)
@thread_cached
def get_notification_client():
from backend.notifications import NotificationManager
return get_service_client(NotificationManager)
return get_service_client(NotificationManagerClient)
def execute_graph(**kwargs):
args = ExecutionJobArgs(**kwargs)
try:
log(f"Executing recurring job for graph #{args.graph_id}")
get_execution_client().add_execution(
execution_utils.add_graph_execution(
graph_id=args.graph_id,
data=args.input_data,
inputs=args.input_data,
user_id=args.user_id,
graph_version=args.graph_version,
)
@@ -164,19 +163,9 @@ class Scheduler(AppService):
def db_pool_size(cls) -> int:
return config.scheduler_db_pool_size
@property
@thread_cached
def execution_client(self) -> ExecutionManager:
return get_service_client(ExecutionManager)
@property
@thread_cached
def notification_client(self) -> NotificationManager:
return get_service_client(NotificationManager)
def run_service(self):
load_dotenv()
db_schema, db_url = _extract_schema_from_url(os.getenv("DATABASE_URL"))
db_schema, db_url = _extract_schema_from_url(os.getenv("DIRECT_URL"))
self.scheduler = BlockingScheduler(
jobstores={
Jobstores.EXECUTION.value: SQLAlchemyJobStore(
@@ -206,6 +195,12 @@ class Scheduler(AppService):
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
self.scheduler.start()
def cleanup(self):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Shutting down scheduler...")
if self.scheduler:
self.scheduler.shutdown(wait=False)
@expose
def add_execution_schedule(
self,
@@ -304,3 +299,15 @@ class Scheduler(AppService):
),
job,
)
class SchedulerClient(AppServiceClient):
@classmethod
def get_service_type(cls):
return Scheduler
add_execution_schedule = endpoint_to_async(Scheduler.add_execution_schedule)
delete_schedule = endpoint_to_async(Scheduler.delete_schedule)
get_execution_schedules = endpoint_to_async(Scheduler.get_execution_schedules)
add_batched_notification_schedule = Scheduler.add_batched_notification_schedule
add_weekly_notification_schedule = Scheduler.add_weekly_notification_schedule

View File

@@ -1,38 +1,113 @@
import logging
from typing import TYPE_CHECKING, Any, Optional, cast
from autogpt_libs.utils.cache import thread_cached
from pydantic import BaseModel
from backend.data.block import Block, BlockInput
from backend.data.block import (
Block,
BlockData,
BlockInput,
BlockSchema,
BlockType,
get_block,
)
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.cost import BlockCostType
from backend.data.execution import (
AsyncRedisExecutionEventBus,
ExecutionStatus,
GraphExecutionStats,
GraphExecutionWithNodes,
RedisExecutionEventBus,
create_graph_execution,
update_graph_execution_stats,
update_node_execution_status_batch,
)
from backend.data.graph import GraphModel, Node, get_graph
from backend.data.model import CredentialsMetaInput
from backend.data.rabbitmq import (
AsyncRabbitMQ,
Exchange,
ExchangeType,
Queue,
RabbitMQConfig,
SyncRabbitMQ,
)
from backend.util.exceptions import NotFoundError
from backend.util.mock import MockObject
from backend.util.service import get_service_client
from backend.util.settings import Config
from backend.util.type import convert
if TYPE_CHECKING:
from backend.executor import DatabaseManagerClient
from backend.integrations.credentials_store import IntegrationCredentialsStore
config = Config()
logger = logging.getLogger(__name__)
# ============ Resource Helpers ============ #
class UsageTransactionMetadata(BaseModel):
graph_exec_id: str | None = None
graph_id: str | None = None
node_id: str | None = None
node_exec_id: str | None = None
block_id: str | None = None
block: str | None = None
input: BlockInput | None = None
@thread_cached
def get_execution_event_bus() -> RedisExecutionEventBus:
return RedisExecutionEventBus()
@thread_cached
def get_async_execution_event_bus() -> AsyncRedisExecutionEventBus:
return AsyncRedisExecutionEventBus()
@thread_cached
def get_execution_queue() -> SyncRabbitMQ:
client = SyncRabbitMQ(create_execution_queue_config())
client.connect()
return client
@thread_cached
async def get_async_execution_queue() -> AsyncRabbitMQ:
client = AsyncRabbitMQ(create_execution_queue_config())
await client.connect()
return client
@thread_cached
def get_integration_credentials_store() -> "IntegrationCredentialsStore":
from backend.integrations.credentials_store import IntegrationCredentialsStore
return IntegrationCredentialsStore()
@thread_cached
def get_db_client() -> "DatabaseManagerClient":
from backend.executor import DatabaseManagerClient
return get_service_client(DatabaseManagerClient)
# ============ Execution Cost Helpers ============ #
def execution_usage_cost(execution_count: int) -> tuple[int, int]:
"""
Calculate the cost of executing a graph based on the number of executions.
Calculate the cost of executing a graph based on the current number of node executions.
Args:
execution_count: Number of executions
execution_count: Number of node executions
Returns:
Tuple of cost amount and remaining execution count
Tuple of cost amount and the number of execution count that is included in the cost.
"""
return (
execution_count
// config.execution_cost_count_threshold
* config.execution_cost_per_threshold,
execution_count % config.execution_cost_count_threshold,
(
config.execution_cost_per_threshold
if execution_count % config.execution_cost_count_threshold == 0
else 0
),
config.execution_cost_count_threshold,
)
@@ -95,3 +170,571 @@ def _is_cost_filter_match(cost_filter: BlockInput, input_data: BlockInput) -> bo
or (input_data.get(k) and _is_cost_filter_match(v, input_data[k]))
for k, v in cost_filter.items()
)
# ============ Execution Input Helpers ============ #
LIST_SPLIT = "_$_"
DICT_SPLIT = "_#_"
OBJC_SPLIT = "_@_"
def parse_execution_output(output: BlockData, name: str) -> Any | None:
"""
Extracts partial output data by name from a given BlockData.
The function supports extracting data from lists, dictionaries, and objects
using specific naming conventions:
- For lists: <output_name>_$_<index>
- For dictionaries: <output_name>_#_<key>
- For objects: <output_name>_@_<attribute>
Args:
output (BlockData): A tuple containing the output name and data.
name (str): The name used to extract specific data from the output.
Returns:
Any | None: The extracted data if found, otherwise None.
Examples:
>>> output = ("result", [10, 20, 30])
>>> parse_execution_output(output, "result_$_1")
20
>>> output = ("config", {"key1": "value1", "key2": "value2"})
>>> parse_execution_output(output, "config_#_key1")
'value1'
>>> class Sample:
... attr1 = "value1"
... attr2 = "value2"
>>> output = ("object", Sample())
>>> parse_execution_output(output, "object_@_attr1")
'value1'
"""
output_name, output_data = output
if name == output_name:
return output_data
if name.startswith(f"{output_name}{LIST_SPLIT}"):
index = int(name.split(LIST_SPLIT)[1])
if not isinstance(output_data, list) or len(output_data) <= index:
return None
return output_data[int(name.split(LIST_SPLIT)[1])]
if name.startswith(f"{output_name}{DICT_SPLIT}"):
index = name.split(DICT_SPLIT)[1]
if not isinstance(output_data, dict) or index not in output_data:
return None
return output_data[index]
if name.startswith(f"{output_name}{OBJC_SPLIT}"):
index = name.split(OBJC_SPLIT)[1]
if isinstance(output_data, object) and hasattr(output_data, index):
return getattr(output_data, index)
return None
return None
def validate_exec(
node: Node,
data: BlockInput,
resolve_input: bool = True,
) -> tuple[BlockInput | None, str]:
"""
Validate the input data for a node execution.
Args:
node: The node to execute.
data: The input data for the node execution.
resolve_input: Whether to resolve dynamic pins into dict/list/object.
Returns:
A tuple of the validated data and the block name.
If the data is invalid, the first element will be None, and the second element
will be an error message.
If the data is valid, the first element will be the resolved input data, and
the second element will be the block name.
"""
node_block: Block | None = get_block(node.block_id)
if not node_block:
return None, f"Block for {node.block_id} not found."
schema = node_block.input_schema
# Convert non-matching data types to the expected input schema.
for name, data_type in schema.__annotations__.items():
if (value := data.get(name)) and (type(value) is not data_type):
data[name] = convert(value, data_type)
# Input data (without default values) should contain all required fields.
error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
if missing_links := schema.get_missing_links(data, node.input_links):
return None, f"{error_prefix} unpopulated links {missing_links}"
# Merge input data with default values and resolve dynamic dict/list/object pins.
input_default = schema.get_input_defaults(node.input_default)
data = {**input_default, **data}
if resolve_input:
data = merge_execution_input(data)
# Input data post-merge should contain all required fields from the schema.
if missing_input := schema.get_missing_input(data):
return None, f"{error_prefix} missing input {missing_input}"
# Last validation: Validate the input values against the schema.
if error := schema.get_mismatch_error(data):
error_message = f"{error_prefix} {error}"
logger.error(error_message)
return None, error_message
return data, node_block.name
def merge_execution_input(data: BlockInput) -> BlockInput:
"""
Merges dynamic input pins into a single list, dictionary, or object based on naming patterns.
This function processes input keys that follow specific patterns to merge them into a unified structure:
- `<input_name>_$_<index>` for list inputs.
- `<input_name>_#_<index>` for dictionary inputs.
- `<input_name>_@_<index>` for object inputs.
Args:
data (BlockInput): A dictionary containing input keys and their corresponding values.
Returns:
BlockInput: A dictionary with merged inputs.
Raises:
ValueError: If a list index is not an integer.
Examples:
>>> data = {
... "list_$_0": "a",
... "list_$_1": "b",
... "dict_#_key1": "value1",
... "dict_#_key2": "value2",
... "object_@_attr1": "value1",
... "object_@_attr2": "value2"
... }
>>> merge_execution_input(data)
{
"list": ["a", "b"],
"dict": {"key1": "value1", "key2": "value2"},
"object": <MockObject attr1="value1" attr2="value2">
}
"""
# Merge all input with <input_name>_$_<index> into a single list.
items = list(data.items())
for key, value in items:
if LIST_SPLIT not in key:
continue
name, index = key.split(LIST_SPLIT)
if not index.isdigit():
raise ValueError(f"Invalid key: {key}, #{index} index must be an integer.")
data[name] = data.get(name, [])
if int(index) >= len(data[name]):
# Pad list with empty string on missing indices.
data[name].extend([""] * (int(index) - len(data[name]) + 1))
data[name][int(index)] = value
# Merge all input with <input_name>_#_<index> into a single dict.
for key, value in items:
if DICT_SPLIT not in key:
continue
name, index = key.split(DICT_SPLIT)
data[name] = data.get(name, {})
data[name][index] = value
# Merge all input with <input_name>_@_<index> into a single object.
for key, value in items:
if OBJC_SPLIT not in key:
continue
name, index = key.split(OBJC_SPLIT)
if name not in data or not isinstance(data[name], object):
data[name] = MockObject()
setattr(data[name], index, value)
return data
def _validate_node_input_credentials(
graph: GraphModel,
user_id: str,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
):
"""Checks all credentials for all nodes of the graph"""
for node in graph.nodes:
block = node.block
# Find any fields of type CredentialsMetaInput
credentials_fields = cast(
type[BlockSchema], block.input_schema
).get_credentials_fields()
if not credentials_fields:
continue
for field_name, credentials_meta_type in credentials_fields.items():
if (
node_credentials_input_map
and (node_credentials_inputs := node_credentials_input_map.get(node.id))
and field_name in node_credentials_inputs
):
credentials_meta = node_credentials_input_map[node.id][field_name]
elif field_name in node.input_default:
credentials_meta = credentials_meta_type.model_validate(
node.input_default[field_name]
)
else:
raise ValueError(
f"Credentials absent for {block.name} node #{node.id} "
f"input '{field_name}'"
)
# Fetch the corresponding Credentials and perform sanity checks
credentials = get_integration_credentials_store().get_creds_by_id(
user_id, credentials_meta.id
)
if not credentials:
raise ValueError(
f"Unknown credentials #{credentials_meta.id} "
f"for node #{node.id} input '{field_name}'"
)
if (
credentials.provider != credentials_meta.provider
or credentials.type != credentials_meta.type
):
logger.warning(
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch: "
f"{credentials_meta.type}<>{credentials.type};"
f"{credentials_meta.provider}<>{credentials.provider}"
)
raise ValueError(
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch"
)
def make_node_credentials_input_map(
graph: GraphModel,
graph_credentials_input: dict[str, CredentialsMetaInput],
) -> dict[str, dict[str, CredentialsMetaInput]]:
"""
Maps credentials for an execution to the correct nodes.
Params:
graph: The graph to be executed.
graph_credentials_input: A (graph_input_name, credentials_meta) map.
Returns:
dict[node_id, dict[field_name, CredentialsMetaInput]]: Node credentials input map.
"""
result: dict[str, dict[str, CredentialsMetaInput]] = {}
# Get aggregated credentials fields for the graph
graph_cred_inputs = graph.aggregate_credentials_inputs()
for graph_input_name, (_, compatible_node_fields) in graph_cred_inputs.items():
# Best-effort map: skip missing items
if graph_input_name not in graph_credentials_input:
continue
# Use passed-in credentials for all compatible node input fields
for node_id, node_field_name in compatible_node_fields:
if node_id not in result:
result[node_id] = {}
result[node_id][node_field_name] = graph_credentials_input[graph_input_name]
return result
def construct_node_execution_input(
graph: GraphModel,
user_id: str,
graph_inputs: BlockInput,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
) -> list[tuple[str, BlockInput]]:
"""
Validates and prepares the input data for executing a graph.
This function checks the graph for starting nodes, validates the input data
against the schema, and resolves dynamic input pins into a single list,
dictionary, or object.
Args:
graph (GraphModel): The graph model to execute.
user_id (str): The ID of the user executing the graph.
data (BlockInput): The input data for the graph execution.
node_credentials_map: `dict[node_id, dict[input_name, CredentialsMetaInput]]`
Returns:
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
the corresponding input data for that node.
"""
graph.validate_graph(for_run=True)
_validate_node_input_credentials(graph, user_id, node_credentials_input_map)
nodes_input = []
for node in graph.starting_nodes:
input_data = {}
block = node.block
# Note block should never be executed.
if block.block_type == BlockType.NOTE:
continue
# Extract request input data, and assign it to the input pin.
if block.block_type == BlockType.INPUT:
input_name = node.input_default.get("name")
if input_name and input_name in graph_inputs:
input_data = {"value": graph_inputs[input_name]}
# Extract webhook payload, and assign it to the input pin
webhook_payload_key = f"webhook_{node.webhook_id}_payload"
if (
block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
and node.webhook_id
):
if webhook_payload_key not in graph_inputs:
raise ValueError(
f"Node {block.name} #{node.id} webhook payload is missing"
)
input_data = {"payload": graph_inputs[webhook_payload_key]}
# Apply node credentials overrides
if node_credentials_input_map and (
node_credentials := node_credentials_input_map.get(node.id)
):
input_data.update({k: v.model_dump() for k, v in node_credentials.items()})
input_data, error = validate_exec(node, input_data)
if input_data is None:
raise ValueError(error)
else:
nodes_input.append((node.id, input_data))
if not nodes_input:
raise ValueError(
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
)
return nodes_input
# ============ Execution Queue Helpers ============ #
class CancelExecutionEvent(BaseModel):
graph_exec_id: str
GRAPH_EXECUTION_EXCHANGE = Exchange(
name="graph_execution",
type=ExchangeType.DIRECT,
durable=True,
auto_delete=False,
)
GRAPH_EXECUTION_QUEUE_NAME = "graph_execution_queue"
GRAPH_EXECUTION_ROUTING_KEY = "graph_execution.run"
GRAPH_EXECUTION_CANCEL_EXCHANGE = Exchange(
name="graph_execution_cancel",
type=ExchangeType.FANOUT,
durable=True,
auto_delete=True,
)
GRAPH_EXECUTION_CANCEL_QUEUE_NAME = "graph_execution_cancel_queue"
def create_execution_queue_config() -> RabbitMQConfig:
"""
Define two exchanges and queues:
- 'graph_execution' (DIRECT) for run tasks.
- 'graph_execution_cancel' (FANOUT) for cancel requests.
"""
run_queue = Queue(
name=GRAPH_EXECUTION_QUEUE_NAME,
exchange=GRAPH_EXECUTION_EXCHANGE,
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
durable=True,
auto_delete=False,
)
cancel_queue = Queue(
name=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
exchange=GRAPH_EXECUTION_CANCEL_EXCHANGE,
routing_key="", # not used for FANOUT
durable=True,
auto_delete=False,
)
return RabbitMQConfig(
vhost="/",
exchanges=[GRAPH_EXECUTION_EXCHANGE, GRAPH_EXECUTION_CANCEL_EXCHANGE],
queues=[run_queue, cancel_queue],
)
async def add_graph_execution_async(
graph_id: str,
user_id: str,
inputs: BlockInput,
preset_id: Optional[str] = None,
graph_version: Optional[int] = None,
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
) -> GraphExecutionWithNodes:
"""
Adds a graph execution to the queue and returns the execution entry.
Args:
graph_id: The ID of the graph to execute.
user_id: The ID of the user executing the graph.
inputs: The input data for the graph execution.
preset_id: The ID of the preset to use.
graph_version: The version of the graph to execute.
graph_credentials_inputs: Credentials inputs to use in the execution.
Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
Returns:
GraphExecutionEntry: The entry for the graph execution.
Raises:
ValueError: If the graph is not found or if there are validation errors.
""" # noqa
graph: GraphModel | None = await get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
)
if not graph:
raise NotFoundError(f"Graph #{graph_id} not found.")
node_credentials_input_map = (
make_node_credentials_input_map(graph, graph_credentials_inputs)
if graph_credentials_inputs
else None
)
graph_exec = await create_graph_execution(
user_id=user_id,
graph_id=graph_id,
graph_version=graph.version,
starting_nodes_input=construct_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=inputs,
node_credentials_input_map=node_credentials_input_map,
),
preset_id=preset_id,
)
try:
queue = await get_async_execution_queue()
graph_exec_entry = graph_exec.to_graph_execution_entry()
if node_credentials_input_map:
graph_exec_entry.node_credentials_input_map = node_credentials_input_map
await queue.publish_message(
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
message=graph_exec_entry.model_dump_json(),
exchange=GRAPH_EXECUTION_EXCHANGE,
)
bus = get_async_execution_event_bus()
await bus.publish(graph_exec)
return graph_exec
except Exception as e:
logger.error(f"Unable to publish graph #{graph_id} exec #{graph_exec.id}: {e}")
await update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in graph_exec.node_executions],
ExecutionStatus.FAILED,
)
await update_graph_execution_stats(
graph_exec_id=graph_exec.id,
status=ExecutionStatus.FAILED,
stats=GraphExecutionStats(error=str(e)),
)
raise
def add_graph_execution(
graph_id: str,
user_id: str,
inputs: BlockInput,
preset_id: Optional[str] = None,
graph_version: Optional[int] = None,
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
) -> GraphExecutionWithNodes:
"""
Adds a graph execution to the queue and returns the execution entry.
Args:
graph_id: The ID of the graph to execute.
user_id: The ID of the user executing the graph.
inputs: The input data for the graph execution.
preset_id: The ID of the preset to use.
graph_version: The version of the graph to execute.
graph_credentials_inputs: Credentials inputs to use in the execution.
Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
Returns:
GraphExecutionEntry: The entry for the graph execution.
Raises:
ValueError: If the graph is not found or if there are validation errors.
"""
db = get_db_client()
graph: GraphModel | None = db.get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
)
if not graph:
raise NotFoundError(f"Graph #{graph_id} not found.")
node_credentials_input_map = (
make_node_credentials_input_map(graph, graph_credentials_inputs)
if graph_credentials_inputs
else None
)
graph_exec = db.create_graph_execution(
user_id=user_id,
graph_id=graph_id,
graph_version=graph.version,
starting_nodes_input=construct_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=inputs,
node_credentials_input_map=node_credentials_input_map,
),
preset_id=preset_id,
)
try:
queue = get_execution_queue()
graph_exec_entry = graph_exec.to_graph_execution_entry()
if node_credentials_input_map:
graph_exec_entry.node_credentials_input_map = node_credentials_input_map
queue.publish_message(
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
message=graph_exec_entry.model_dump_json(),
exchange=GRAPH_EXECUTION_EXCHANGE,
)
bus = get_execution_event_bus()
bus.publish(graph_exec)
return graph_exec
except Exception as e:
logger.error(f"Unable to publish graph #{graph_id} exec #{graph_exec.id}: {e}")
db.update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in graph_exec.node_executions],
ExecutionStatus.FAILED,
)
db.update_graph_execution_stats(
graph_exec_id=graph_exec.id,
status=ExecutionStatus.FAILED,
stats=GraphExecutionStats(error=str(e)),
)
raise

View File

@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Optional
from pydantic import SecretStr
if TYPE_CHECKING:
from backend.executor.database import DatabaseManager
from backend.executor.database import DatabaseManagerClient
from autogpt_libs.utils.cache import thread_cached
from autogpt_libs.utils.synchronize import RedisKeyedMutex
@@ -161,6 +161,14 @@ smartlead_credentials = APIKeyCredentials(
expires_at=None,
)
google_maps_credentials = APIKeyCredentials(
id="9aa1bde0-4947-4a70-a20c-84daa3850d52",
provider="google_maps",
api_key=SecretStr(settings.secrets.google_maps_api_key),
title="Use Credits for Google Maps",
expires_at=None,
)
zerobounce_credentials = APIKeyCredentials(
id="63a6e279-2dc2-448e-bf57-85776f7176dc",
provider="zerobounce",
@@ -190,6 +198,7 @@ DEFAULT_CREDENTIALS = [
apollo_credentials,
smartlead_credentials,
zerobounce_credentials,
google_maps_credentials,
]
@@ -201,11 +210,11 @@ class IntegrationCredentialsStore:
@property
@thread_cached
def db_manager(self) -> "DatabaseManager":
from backend.executor.database import DatabaseManager
def db_manager(self) -> "DatabaseManagerClient":
from backend.executor.database import DatabaseManagerClient
from backend.util.service import get_service_client
return get_service_client(DatabaseManager)
return get_service_client(DatabaseManagerClient)
def add_creds(self, user_id: str, credentials: Credentials) -> None:
with self.locked_user_integrations(user_id):
@@ -263,6 +272,8 @@ class IntegrationCredentialsStore:
all_credentials.append(smartlead_credentials)
if settings.secrets.zerobounce_api_key:
all_credentials.append(zerobounce_credentials)
if settings.secrets.google_maps_api_key:
all_credentials.append(google_maps_credentials)
return all_credentials
def get_creds_by_id(self, user_id: str, credentials_id: str) -> Credentials | None:

View File

@@ -10,6 +10,7 @@ from backend.data import redis
from backend.data.model import Credentials
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.util.exceptions import MissingConfigError
from backend.util.settings import Settings
@@ -92,7 +93,7 @@ class IntegrationCredentialsManager:
fresh_credentials = oauth_handler.refresh_tokens(credentials)
self.store.update_creds(user_id, fresh_credentials)
if _lock and _lock.locked():
if _lock and _lock.locked() and _lock.owned():
_lock.release()
credentials = fresh_credentials
@@ -144,7 +145,7 @@ class IntegrationCredentialsManager:
try:
yield
finally:
if lock.locked():
if lock.locked() and lock.owned():
lock.release()
def release_all_locks(self):
@@ -153,12 +154,13 @@ class IntegrationCredentialsManager:
self.store.locks.release_all_locks()
def _get_provider_oauth_handler(provider_name: str) -> "BaseOAuthHandler":
def _get_provider_oauth_handler(provider_name_str: str) -> "BaseOAuthHandler":
provider_name = ProviderName(provider_name_str)
if provider_name not in HANDLERS_BY_NAME:
raise KeyError(f"Unknown provider '{provider_name}'")
client_id = getattr(settings.secrets, f"{provider_name}_client_id")
client_secret = getattr(settings.secrets, f"{provider_name}_client_secret")
client_id = getattr(settings.secrets, f"{provider_name.value}_client_id")
client_secret = getattr(settings.secrets, f"{provider_name.value}_client_secret")
if not (client_id and client_secret):
raise MissingConfigError(
f"Integration with provider '{provider_name}' is not configured",

View File

@@ -1,7 +1,7 @@
import logging
from typing import TYPE_CHECKING, Callable, Optional, cast
from backend.data.block import BlockSchema, BlockWebhookConfig, get_block
from backend.data.block import BlockSchema, BlockWebhookConfig
from backend.data.graph import set_node_webhook
from backend.integrations.webhooks import get_webhook_manager, supports_webhooks
@@ -29,12 +29,7 @@ async def on_graph_activate(
# Compare nodes in new_graph_version with previous_graph_version
updated_nodes = []
for new_node in graph.nodes:
block = get_block(new_node.block_id)
if not block:
raise ValueError(
f"Node #{new_node.id} is instance of unknown block #{new_node.block_id}"
)
block_input_schema = cast(BlockSchema, block.input_schema)
block_input_schema = cast(BlockSchema, new_node.block.input_schema)
node_credentials = None
if (
@@ -75,12 +70,7 @@ async def on_graph_deactivate(
"""
updated_nodes = []
for node in graph.nodes:
block = get_block(node.block_id)
if not block:
raise ValueError(
f"Node #{node.id} is instance of unknown block #{node.block_id}"
)
block_input_schema = cast(BlockSchema, block.input_schema)
block_input_schema = cast(BlockSchema, node.block.input_schema)
node_credentials = None
if (
@@ -113,11 +103,7 @@ async def on_node_activate(
) -> "NodeModel":
"""Hook to be called when the node is activated/created"""
block = get_block(node.block_id)
if not block:
raise ValueError(
f"Node #{node.id} is instance of unknown block #{node.block_id}"
)
block = node.block
if not block.webhook_config:
return node
@@ -224,11 +210,7 @@ async def on_node_deactivate(
"""Hook to be called when node is deactivated/deleted"""
logger.debug(f"Deactivating node #{node.id}")
block = get_block(node.block_id)
if not block:
raise ValueError(
f"Node #{node.id} is instance of unknown block #{node.block_id}"
)
block = node.block
if not block.webhook_config:
return node

View File

@@ -1,5 +1,6 @@
from .notifications import NotificationManager
from .notifications import NotificationManager, NotificationManagerClient
__all__ = [
"NotificationManager",
"NotificationManagerClient",
]

View File

@@ -9,6 +9,7 @@ from autogpt_libs.utils.cache import thread_cached
from prisma.enums import NotificationType
from pydantic import BaseModel
from backend.data import rabbitmq
from backend.data.notifications import (
BaseSummaryData,
BaseSummaryParams,
@@ -30,7 +31,13 @@ from backend.data.notifications import (
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.data.user import generate_unsubscribe_link
from backend.notifications.email import EmailSender
from backend.util.service import AppService, expose, get_service_client
from backend.util.service import (
AppService,
AppServiceClient,
endpoint_to_async,
expose,
get_service_client,
)
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -107,16 +114,16 @@ def create_notification_config() -> RabbitMQConfig:
@thread_cached
def get_scheduler():
from backend.executor import Scheduler
from backend.executor.scheduler import SchedulerClient
return get_service_client(Scheduler)
return get_service_client(SchedulerClient)
@thread_cached
def get_db():
from backend.executor.database import DatabaseManager
from backend.executor.database import DatabaseManagerClient
return get_service_client(DatabaseManager)
return get_service_client(DatabaseManagerClient)
class NotificationManager(AppService):
@@ -128,6 +135,20 @@ class NotificationManager(AppService):
self.running = True
self.email_sender = EmailSender()
@property
def rabbit(self) -> rabbitmq.AsyncRabbitMQ:
"""Access the RabbitMQ service. Will raise if not configured."""
if not self.rabbitmq_service:
raise RuntimeError("RabbitMQ not configured for this service")
return self.rabbitmq_service
@property
def rabbit_config(self) -> rabbitmq.RabbitMQConfig:
"""Access the RabbitMQ config. Will raise if not configured."""
if not self.rabbitmq_config:
raise RuntimeError("RabbitMQ not configured for this service")
return self.rabbitmq_config
@classmethod
def get_port(cls) -> int:
return settings.config.notification_service_port
@@ -688,10 +709,14 @@ class NotificationManager(AppService):
)
def run_service(self):
logger.info(f"[{self.service_name}] ⏳ Configuring RabbitMQ...")
self.rabbitmq_service = rabbitmq.AsyncRabbitMQ(self.rabbitmq_config)
self.run_and_wait(self.rabbitmq_service.connect())
logger.info(f"[{self.service_name}] Started notification service")
# Set up scheduler for batch processing of all notification types
# this can be changed later to spawn differnt cleanups on different schedules
# this can be changed later to spawn different cleanups on different schedules
try:
get_scheduler().add_batched_notification_schedule(
notification_types=list(NotificationType),
@@ -753,3 +778,16 @@ class NotificationManager(AppService):
"""Cleanup service resources"""
self.running = False
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Disconnecting RabbitMQ...")
self.run_and_wait(self.rabbitmq_service.disconnect())
class NotificationManagerClient(AppServiceClient):
@classmethod
def get_service_type(cls):
return NotificationManager
queue_notification_async = endpoint_to_async(NotificationManager.queue_notification)
queue_notification = NotificationManager.queue_notification
process_existing_batches = NotificationManager.process_existing_batches
queue_weekly_summary = NotificationManager.queue_weekly_summary

View File

@@ -2,7 +2,6 @@ import logging
from collections import defaultdict
from typing import Annotated, Any, Dict, List, Optional, Sequence
from autogpt_libs.utils.cache import thread_cached
from fastapi import APIRouter, Body, Depends, HTTPException
from prisma.enums import AgentExecutionStatus, APIKeyPermission
from typing_extensions import TypedDict
@@ -13,17 +12,10 @@ from backend.data import graph as graph_db
from backend.data.api_key import APIKey
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.execution import NodeExecutionResult
from backend.executor import ExecutionManager
from backend.executor.utils import add_graph_execution_async
from backend.server.external.middleware import require_permission
from backend.util.service import get_service_client
from backend.util.settings import Settings
@thread_cached
def execution_manager_client() -> ExecutionManager:
return get_service_client(ExecutionManager)
settings = Settings()
logger = logging.getLogger(__name__)
@@ -98,20 +90,20 @@ def execute_graph_block(
path="/graphs/{graph_id}/execute/{graph_version}",
tags=["graphs"],
)
def execute_graph(
async def execute_graph(
graph_id: str,
graph_version: int,
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
api_key: APIKey = Depends(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
) -> dict[str, Any]:
try:
graph_exec = execution_manager_client().add_execution(
graph_id,
graph_version=graph_version,
data=node_input,
graph_exec = await add_graph_execution_async(
graph_id=graph_id,
user_id=api_key.user_id,
inputs=node_input,
graph_version=graph_version,
)
return {"id": graph_exec.graph_exec_id}
return {"id": graph_exec.id}
except Exception as e:
msg = str(e).encode().decode("unicode_escape")
raise HTTPException(status_code=400, detail=msg)
@@ -130,7 +122,7 @@ async def get_graph_execution_results(
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
results = await execution_db.get_node_execution_results(graph_exec_id)
results = await execution_db.get_node_executions(graph_exec_id)
last_result = results[-1] if results else None
execution_status = (
last_result.status if last_result else AgentExecutionStatus.INCOMPLETE

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
from typing import TYPE_CHECKING, Annotated, Literal
from typing import TYPE_CHECKING, Annotated, Awaitable, Literal
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
from pydantic import BaseModel, Field
@@ -14,13 +15,12 @@ from backend.data.integrations import (
wait_for_webhook_event,
)
from backend.data.model import Credentials, CredentialsType, OAuth2Credentials
from backend.executor.manager import ExecutionManager
from backend.executor.utils import add_graph_execution_async
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
from backend.util.exceptions import NeedConfirmation, NotFoundError
from backend.util.service import get_service_client
from backend.util.settings import Settings
if TYPE_CHECKING:
@@ -309,19 +309,22 @@ async def webhook_ingress_generic(
if not webhook.attached_nodes:
return
executor = get_service_client(ExecutionManager)
executions: list[Awaitable] = []
for node in webhook.attached_nodes:
logger.debug(f"Webhook-attached node: {node}")
if not node.is_triggered_by_event_type(event_type):
logger.debug(f"Node #{node.id} doesn't trigger on event {event_type}")
continue
logger.debug(f"Executing graph #{node.graph_id} node #{node.id}")
executor.add_execution(
graph_id=node.graph_id,
graph_version=node.graph_version,
data={f"webhook_{webhook_id}_payload": payload},
user_id=webhook.user_id,
executions.append(
add_graph_execution_async(
user_id=webhook.user_id,
graph_id=node.graph_id,
graph_version=node.graph_version,
inputs={f"webhook_{webhook_id}_payload": payload},
)
)
asyncio.gather(*executions)
@router.post("/webhooks/{webhook_id}/ping")

View File

@@ -11,14 +11,15 @@ from autogpt_libs.feature_flag.client import (
initialize_launchdarkly,
shutdown_launchdarkly,
)
from autogpt_libs.logging.utils import generate_uvicorn_config
import backend.data.block
import backend.data.db
import backend.data.graph
import backend.data.user
import backend.server.integrations.router
import backend.server.routers.postmark.postmark
import backend.server.routers.v1
import backend.server.v2.admin.credit_admin_routes
import backend.server.v2.admin.store_admin_routes
import backend.server.v2.library.db
import backend.server.v2.library.model
@@ -28,6 +29,7 @@ import backend.server.v2.store.model
import backend.server.v2.store.routes
import backend.util.service
import backend.util.settings
from backend.blocks.llm import LlmModel
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.server.external.api import external_app
@@ -56,8 +58,7 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()
# FIXME ERROR: operator does not exist: text ? unknown
# await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
with launch_darkly_context():
yield
await backend.data.db.disconnect()
@@ -107,6 +108,11 @@ app.include_router(
tags=["v2", "admin"],
prefix="/api/store",
)
app.include_router(
backend.server.v2.admin.credit_admin_routes.router,
tags=["v2", "admin"],
prefix="/api/credits",
)
app.include_router(
backend.server.v2.library.routes.router, tags=["v2"], prefix="/api/library"
)
@@ -141,8 +147,13 @@ class AgentServer(backend.util.service.AppProcess):
server_app,
host=backend.util.settings.Config().agent_api_host,
port=backend.util.settings.Config().agent_api_port,
log_config=generate_uvicorn_config(),
)
def cleanup(self):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Shutting down Agent Server...")
@staticmethod
async def test_execute_graph(
graph_id: str,
@@ -150,11 +161,12 @@ class AgentServer(backend.util.service.AppProcess):
graph_version: Optional[int] = None,
node_input: Optional[dict[str, Any]] = None,
):
return backend.server.routers.v1.execute_graph(
return await backend.server.routers.v1.execute_graph(
user_id=user_id,
graph_id=graph_id,
graph_version=graph_version,
node_input=node_input or {},
inputs=node_input or {},
credentials_inputs={},
)
@staticmethod
@@ -269,7 +281,9 @@ class AgentServer(backend.util.service.AppProcess):
provider: ProviderName,
credentials: Credentials,
) -> Credentials:
return backend.server.integrations.router.create_credentials(
from backend.server.integrations.router import create_credentials
return create_credentials(
user_id=user_id, provider=provider, credentials=credentials
)

View File

@@ -13,7 +13,6 @@ from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
from typing_extensions import Optional, TypedDict
import backend.data.block
import backend.server.integrations.router
import backend.server.routers.analytics
import backend.server.v2.library.db as library_db
@@ -31,7 +30,7 @@ from backend.data.api_key import (
suspend_api_key,
update_api_key_permissions,
)
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
from backend.data.credit import (
AutoTopUpConfig,
RefundRequest,
@@ -41,6 +40,8 @@ from backend.data.credit import (
get_user_credit_model,
set_auto_top_up,
)
from backend.data.execution import AsyncRedisExecutionEventBus
from backend.data.model import CredentialsMetaInput
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.data.onboarding import (
UserOnboardingUpdate,
@@ -49,13 +50,16 @@ from backend.data.onboarding import (
onboarding_enabled,
update_user_onboarding,
)
from backend.data.rabbitmq import AsyncRabbitMQ
from backend.data.user import (
get_or_create_user,
get_user_notification_preference,
update_user_email,
update_user_notification_preference,
)
from backend.executor import ExecutionManager, Scheduler, scheduler
from backend.executor import scheduler
from backend.executor import utils as execution_utils
from backend.executor.utils import create_execution_queue_config
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import (
on_graph_activate,
@@ -79,13 +83,20 @@ if TYPE_CHECKING:
@thread_cached
def execution_manager_client() -> ExecutionManager:
return get_service_client(ExecutionManager)
def execution_scheduler_client() -> scheduler.SchedulerClient:
return get_service_client(scheduler.SchedulerClient)
@thread_cached
def execution_scheduler_client() -> Scheduler:
return get_service_client(Scheduler)
async def execution_queue_client() -> AsyncRabbitMQ:
client = AsyncRabbitMQ(create_execution_queue_config())
await client.connect()
return client
@thread_cached
def execution_event_bus() -> AsyncRedisExecutionEventBus:
return AsyncRedisExecutionEventBus()
settings = Settings()
@@ -206,7 +217,7 @@ async def is_onboarding_enabled():
@v1_router.get(path="/blocks", tags=["blocks"], dependencies=[Depends(auth_middleware)])
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
blocks = [block() for block in backend.data.block.get_blocks().values()]
blocks = [block() for block in get_blocks().values()]
costs = get_block_costs()
return [
{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks if not b.disabled
@@ -219,7 +230,7 @@ def get_graph_blocks() -> Sequence[dict[Any, Any]]:
dependencies=[Depends(auth_middleware)],
)
def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlockOutput:
obj = backend.data.block.get_block(block_id)
obj = get_block(block_id)
if not obj:
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
@@ -308,7 +319,7 @@ async def configure_user_auto_top_up(
dependencies=[Depends(auth_middleware)],
)
async def get_user_auto_top_up(
user_id: Annotated[str, Depends(get_user_id)]
user_id: Annotated[str, Depends(get_user_id)],
) -> AutoTopUpConfig:
return await get_auto_top_up(user_id)
@@ -375,7 +386,7 @@ async def get_credit_history(
@v1_router.get(path="/credits/refunds", dependencies=[Depends(auth_middleware)])
async def get_refund_requests(
user_id: Annotated[str, Depends(get_user_id)]
user_id: Annotated[str, Depends(get_user_id)],
) -> list[RefundRequest]:
return await _user_credit_model.get_refund_requests(user_id)
@@ -391,7 +402,7 @@ class DeleteGraphResponse(TypedDict):
@v1_router.get(path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)])
async def get_graphs(
user_id: Annotated[str, Depends(get_user_id)]
user_id: Annotated[str, Depends(get_user_id)],
) -> Sequence[graph_db.GraphModel]:
return await graph_db.get_graphs(filter_by="active", user_id=user_id)
@@ -580,16 +591,25 @@ async def set_graph_active_version(
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
def execute_graph(
async def execute_graph(
graph_id: str,
node_input: Annotated[dict[str, Any], Body(..., default_factory=dict)],
user_id: Annotated[str, Depends(get_user_id)],
inputs: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
credentials_inputs: Annotated[
dict[str, CredentialsMetaInput], Body(..., embed=True, default_factory=dict)
],
graph_version: Optional[int] = None,
preset_id: Optional[str] = None,
) -> ExecuteGraphResponse:
graph_exec = execution_manager_client().add_execution(
graph_id, node_input, user_id=user_id, graph_version=graph_version
graph_exec = await execution_utils.add_graph_execution_async(
graph_id=graph_id,
user_id=user_id,
inputs=inputs,
preset_id=preset_id,
graph_version=graph_version,
graph_credentials_inputs=credentials_inputs,
)
return ExecuteGraphResponse(graph_exec_id=graph_exec.graph_exec_id)
return ExecuteGraphResponse(graph_exec_id=graph_exec.id)
@v1_router.post(
@@ -605,9 +625,7 @@ async def stop_graph_run(
):
raise HTTPException(404, detail=f"Agent execution #{graph_exec_id} not found")
await asyncio.to_thread(
lambda: execution_manager_client().cancel_execution(graph_exec_id)
)
await _cancel_execution(graph_exec_id)
# Retrieve & return canceled graph execution in its final state
result = await execution_db.get_graph_execution(
@@ -621,6 +639,49 @@ async def stop_graph_run(
return result
async def _cancel_execution(graph_exec_id: str):
"""
Mechanism:
1. Set the cancel event
2. Graph executor's cancel handler thread detects the event, terminates workers,
reinitializes worker pool, and returns.
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
"""
queue_client = await execution_queue_client()
await queue_client.publish_message(
routing_key="",
message=execution_utils.CancelExecutionEvent(
graph_exec_id=graph_exec_id
).model_dump_json(),
exchange=execution_utils.GRAPH_EXECUTION_CANCEL_EXCHANGE,
)
# Update the status of the graph & node executions
await execution_db.update_graph_execution_stats(
graph_exec_id,
execution_db.ExecutionStatus.TERMINATED,
)
node_execs = [
node_exec.model_copy(update={"status": execution_db.ExecutionStatus.TERMINATED})
for node_exec in await execution_db.get_node_executions(
graph_exec_id=graph_exec_id,
statuses=[
execution_db.ExecutionStatus.QUEUED,
execution_db.ExecutionStatus.RUNNING,
execution_db.ExecutionStatus.INCOMPLETE,
],
)
]
await execution_db.update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in node_execs],
execution_db.ExecutionStatus.TERMINATED,
)
await asyncio.gather(
*[execution_event_bus().publish(node_exec) for node_exec in node_execs]
)
@v1_router.get(
path="/executions",
tags=["graphs"],
@@ -718,14 +779,12 @@ async def create_schedule(
detail=f"Graph #{schedule.graph_id} v.{schedule.graph_version} not found.",
)
return await asyncio.to_thread(
lambda: execution_scheduler_client().add_execution_schedule(
graph_id=schedule.graph_id,
graph_version=graph.version,
cron=schedule.cron,
input_data=schedule.input_data,
user_id=user_id,
)
return await execution_scheduler_client().add_execution_schedule(
graph_id=schedule.graph_id,
graph_version=graph.version,
cron=schedule.cron,
input_data=schedule.input_data,
user_id=user_id,
)
@@ -734,11 +793,11 @@ async def create_schedule(
tags=["schedules"],
dependencies=[Depends(auth_middleware)],
)
def delete_schedule(
async def delete_schedule(
schedule_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[Any, Any]:
execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
await execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
return {"id": schedule_id}
@@ -747,11 +806,11 @@ def delete_schedule(
tags=["schedules"],
dependencies=[Depends(auth_middleware)],
)
def get_execution_schedules(
async def get_execution_schedules(
user_id: Annotated[str, Depends(get_user_id)],
graph_id: str | None = None,
) -> list[scheduler.ExecutionJobInfo]:
return execution_scheduler_client().get_execution_schedules(
return await execution_scheduler_client().get_execution_schedules(
user_id=user_id,
graph_id=graph_id,
)
@@ -792,7 +851,7 @@ async def create_api_key(
dependencies=[Depends(auth_middleware)],
)
async def get_api_keys(
user_id: Annotated[str, Depends(get_user_id)]
user_id: Annotated[str, Depends(get_user_id)],
) -> list[APIKeyWithoutHash]:
"""List all API keys for the user"""
try:

View File

@@ -0,0 +1,77 @@
import logging
import typing
from autogpt_libs.auth import requires_admin_user
from autogpt_libs.auth.depends import get_user_id
from fastapi import APIRouter, Body, Depends
from prisma import Json
from prisma.enums import CreditTransactionType
from backend.data.credit import admin_get_user_history, get_user_credit_model
from backend.server.v2.admin.model import AddUserCreditsResponse, UserHistoryResponse
logger = logging.getLogger(__name__)
_user_credit_model = get_user_credit_model()
router = APIRouter(
prefix="/admin",
tags=["credits", "admin"],
dependencies=[Depends(requires_admin_user)],
)
@router.post("/add_credits", response_model=AddUserCreditsResponse)
async def add_user_credits(
user_id: typing.Annotated[str, Body()],
amount: typing.Annotated[int, Body()],
comments: typing.Annotated[str, Body()],
admin_user: typing.Annotated[
str,
Depends(get_user_id),
],
):
""" """
logger.info(f"Admin user {admin_user} is adding {amount} credits to user {user_id}")
new_balance, transaction_key = await _user_credit_model._add_transaction(
user_id,
amount,
transaction_type=CreditTransactionType.GRANT,
metadata=Json({"admin_id": admin_user, "reason": comments}),
)
return {
"new_balance": new_balance,
"transaction_key": transaction_key,
}
@router.get(
"/users_history",
response_model=UserHistoryResponse,
)
async def admin_get_all_user_history(
admin_user: typing.Annotated[
str,
Depends(get_user_id),
],
search: typing.Optional[str] = None,
page: int = 1,
page_size: int = 20,
transaction_filter: typing.Optional[CreditTransactionType] = None,
):
""" """
logger.info(f"Admin user {admin_user} is getting grant history")
try:
resp = await admin_get_user_history(
page=page,
page_size=page_size,
search=search,
transaction_filter=transaction_filter,
)
logger.info(f"Admin user {admin_user} got {len(resp.history)} grant history")
return resp
except Exception as e:
logger.exception(f"Error getting grant history: {e}")
raise e

View File

@@ -0,0 +1,16 @@
from pydantic import BaseModel
from backend.data.model import UserTransaction
from backend.server.model import Pagination
class UserHistoryResponse(BaseModel):
"""Response model for listings with version history"""
history: list[UserTransaction]
pagination: Pagination
class AddUserCreditsResponse(BaseModel):
new_balance: int
transaction_key: str

View File

@@ -1,4 +1,5 @@
import logging
import tempfile
import typing
import autogpt_libs.auth.depends
@@ -9,6 +10,7 @@ import prisma.enums
import backend.server.v2.store.db
import backend.server.v2.store.exceptions
import backend.server.v2.store.model
import backend.util.json
logger = logging.getLogger(__name__)
@@ -98,3 +100,47 @@ async def review_submission(
status_code=500,
content={"detail": "An error occurred while reviewing the submission"},
)
@router.get(
"/submissions/download/{store_listing_version_id}",
tags=["store", "admin"],
dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
)
async def admin_download_agent_file(
user: typing.Annotated[
autogpt_libs.auth.models.User,
fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user),
],
store_listing_version_id: str = fastapi.Path(
..., description="The ID of the agent to download"
),
) -> fastapi.responses.FileResponse:
"""
Download the agent file by streaming its content.
Args:
store_listing_version_id (str): The ID of the agent to download
Returns:
StreamingResponse: A streaming response containing the agent's graph data.
Raises:
HTTPException: If the agent is not found or an unexpected error occurs.
"""
graph_data = await backend.server.v2.store.db.get_agent(
user_id=user.user_id,
store_listing_version_id=store_listing_version_id,
)
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
# Sending graph as a stream (similar to marketplace v1)
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as tmp_file:
tmp_file.write(backend.util.json.dumps(graph_data))
tmp_file.flush()
return fastapi.responses.FileResponse(
tmp_file.name, filename=file_name, media_type="application/json"
)

View File

@@ -13,12 +13,17 @@ import backend.server.v2.library.model as library_model
import backend.server.v2.store.exceptions as store_exceptions
import backend.server.v2.store.image_gen as store_image_gen
import backend.server.v2.store.media as store_media
from backend.data import db
from backend.data import graph as graph_db
from backend.data.db import locked_transaction
from backend.data.includes import library_agent_include
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
from backend.util.settings import Config
logger = logging.getLogger(__name__)
config = Config()
integration_creds_manager = IntegrationCredentialsManager()
async def list_library_agents(
@@ -68,12 +73,12 @@ async def list_library_agents(
if search_term:
where_clause["OR"] = [
{
"Agent": {
"AgentGraph": {
"is": {"name": {"contains": search_term, "mode": "insensitive"}}
}
},
{
"Agent": {
"AgentGraph": {
"is": {
"description": {"contains": search_term, "mode": "insensitive"}
}
@@ -206,7 +211,7 @@ async def add_generated_agent_image(
async def create_library_agent(
graph: backend.data.graph.GraphModel,
user_id: str,
) -> prisma.models.LibraryAgent:
) -> library_model.LibraryAgent:
"""
Adds an agent to the user's library (LibraryAgent table).
@@ -227,18 +232,21 @@ async def create_library_agent(
)
try:
return await prisma.models.LibraryAgent.prisma().create(
data={
"isCreatedByUser": (user_id == graph.user_id),
"useGraphIsActiveVersion": True,
"User": {"connect": {"id": user_id}},
"Agent": {
agent = await prisma.models.LibraryAgent.prisma().create(
data=prisma.types.LibraryAgentCreateInput(
isCreatedByUser=(user_id == graph.user_id),
useGraphIsActiveVersion=True,
User={"connect": {"id": user_id}},
# Creator={"connect": {"id": agent.userId}},
AgentGraph={
"connect": {
"graphVersionId": {"id": graph.id, "version": graph.version}
}
},
}
),
include={"AgentGraph": True},
)
return library_model.LibraryAgent.from_db(agent)
except prisma.errors.PrismaError as e:
logger.error(f"Database error creating agent in library: {e}")
raise store_exceptions.DatabaseError("Failed to create agent in library") from e
@@ -246,38 +254,41 @@ async def create_library_agent(
async def update_agent_version_in_library(
user_id: str,
agent_id: str,
agent_version: int,
agent_graph_id: str,
agent_graph_version: int,
) -> None:
"""
Updates the agent version in the library if useGraphIsActiveVersion is True.
Args:
user_id: Owner of the LibraryAgent.
agent_id: The agent's ID to update.
agent_version: The new version of the agent.
agent_graph_id: The agent graph's ID to update.
agent_graph_version: The new version of the agent graph.
Raises:
DatabaseError: If there's an error with the update.
"""
logger.debug(
f"Updating agent version in library for user #{user_id}, "
f"agent #{agent_id} v{agent_version}"
f"agent #{agent_graph_id} v{agent_graph_version}"
)
try:
library_agent = await prisma.models.LibraryAgent.prisma().find_first_or_raise(
where={
"userId": user_id,
"agentId": agent_id,
"agentGraphId": agent_graph_id,
"useGraphIsActiveVersion": True,
},
)
await prisma.models.LibraryAgent.prisma().update(
where={"id": library_agent.id},
data={
"Agent": {
"AgentGraph": {
"connect": {
"graphVersionId": {"id": agent_id, "version": agent_version}
"graphVersionId": {
"id": agent_graph_id,
"version": agent_graph_version,
}
},
},
},
@@ -341,7 +352,7 @@ async def delete_library_agent_by_graph_id(graph_id: str, user_id: str) -> None:
"""
try:
await prisma.models.LibraryAgent.prisma().delete_many(
where={"agentId": graph_id, "userId": user_id}
where={"agentGraphId": graph_id, "userId": user_id}
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error deleting library agent: {e}")
@@ -374,10 +385,10 @@ async def add_store_agent_to_library(
async with locked_transaction(f"add_agent_trx_{user_id}"):
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id}, include={"Agent": True}
where={"id": store_listing_version_id}, include={"AgentGraph": True}
)
)
if not store_listing_version or not store_listing_version.Agent:
if not store_listing_version or not store_listing_version.AgentGraph:
logger.warning(
f"Store listing version not found: {store_listing_version_id}"
)
@@ -385,7 +396,7 @@ async def add_store_agent_to_library(
f"Store listing version {store_listing_version_id} not found or invalid"
)
graph = store_listing_version.Agent
graph = store_listing_version.AgentGraph
if graph.userId == user_id:
logger.warning(
f"User #{user_id} attempted to add their own agent to their library"
@@ -397,8 +408,8 @@ async def add_store_agent_to_library(
await prisma.models.LibraryAgent.prisma().find_first(
where={
"userId": user_id,
"agentId": graph.id,
"agentVersion": graph.version,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
},
include=library_agent_include(user_id),
)
@@ -418,17 +429,17 @@ async def add_store_agent_to_library(
# Create LibraryAgent entry
added_agent = await prisma.models.LibraryAgent.prisma().create(
data={
"userId": user_id,
"agentId": graph.id,
"agentVersion": graph.version,
"isCreatedByUser": False,
},
data=prisma.types.LibraryAgentCreateInput(
userId=user_id,
agentGraphId=graph.id,
agentGraphVersion=graph.version,
isCreatedByUser=False,
),
include=library_agent_include(user_id),
)
logger.debug(
f"Added graph #{graph.id} "
f"for store listing #{store_listing_version.id} "
f"Added graph #{graph.id} v{graph.version}"
f"for store listing version #{store_listing_version.id} "
f"to library for user #{user_id}"
)
return library_model.LibraryAgent.from_db(added_agent)
@@ -467,8 +478,8 @@ async def set_is_deleted_for_library_agent(
count = await prisma.models.LibraryAgent.prisma().update_many(
where={
"userId": user_id,
"agentId": agent_id,
"agentVersion": agent_version,
"agentGraphId": agent_id,
"agentGraphVersion": agent_version,
},
data={"isDeleted": is_deleted},
)
@@ -597,6 +608,12 @@ async def upsert_preset(
f"Upserting preset #{preset_id} ({repr(preset.name)}) for user #{user_id}",
)
try:
inputs = [
prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput(
name=name, data=prisma.fields.Json(data)
)
for name, data in preset.inputs.items()
]
if preset_id:
# Update existing preset
updated = await prisma.models.AgentPreset.prisma().update(
@@ -605,12 +622,7 @@ async def upsert_preset(
"name": preset.name,
"description": preset.description,
"isActive": preset.is_active,
"InputPresets": {
"create": [
{"name": name, "data": prisma.fields.Json(data)}
for name, data in preset.inputs.items()
]
},
"InputPresets": {"create": inputs},
},
include={"InputPresets": True},
)
@@ -620,20 +632,15 @@ async def upsert_preset(
else:
# Create new preset
new_preset = await prisma.models.AgentPreset.prisma().create(
data={
"userId": user_id,
"name": preset.name,
"description": preset.description,
"agentId": preset.agent_id,
"agentVersion": preset.agent_version,
"isActive": preset.is_active,
"InputPresets": {
"create": [
{"name": name, "data": prisma.fields.Json(data)}
for name, data in preset.inputs.items()
]
},
},
data=prisma.types.AgentPresetCreateInput(
userId=user_id,
name=preset.name,
description=preset.description,
agentGraphId=preset.graph_id,
agentGraphVersion=preset.graph_version,
isActive=preset.is_active,
InputPresets={"create": inputs},
),
include={"InputPresets": True},
)
return library_model.LibraryAgentPreset.from_db(new_preset)
@@ -662,3 +669,47 @@ async def delete_preset(user_id: str, preset_id: str) -> None:
except prisma.errors.PrismaError as e:
logger.error(f"Database error deleting preset: {e}")
raise store_exceptions.DatabaseError("Failed to delete preset") from e
async def fork_library_agent(library_agent_id: str, user_id: str):
"""
Clones a library agent and its underyling graph and nodes (with new ids) for the given user.
Args:
library_agent_id: The ID of the library agent to fork.
user_id: The ID of the user who owns the library agent.
Returns:
The forked LibraryAgent.
Raises:
DatabaseError: If there's an error during the forking process.
"""
logger.debug(f"Forking library agent {library_agent_id} for user {user_id}")
try:
async with db.locked_transaction(f"usr_trx_{user_id}-fork_agent"):
# Fetch the original agent
original_agent = await get_library_agent(library_agent_id, user_id)
# Check if user owns the library agent
# TODO: once we have open/closed sourced agents this needs to be enabled ~kcze
# + update library/agents/[id]/page.tsx agent actions
# if not original_agent.can_access_graph:
# raise store_exceptions.DatabaseError(
# f"User {user_id} cannot access library agent graph {library_agent_id}"
# )
# Fork the underlying graph and nodes
new_graph = await graph_db.fork_graph(
original_agent.graph_id, original_agent.graph_version, user_id
)
new_graph = await on_graph_activate(
new_graph,
get_credentials=lambda id: integration_creds_manager.get(user_id, id),
)
# Create a library agent for the new graph
return await create_library_agent(new_graph, user_id)
except prisma.errors.PrismaError as e:
logger.error(f"Database error cloning library agent: {e}")
raise store_exceptions.DatabaseError("Failed to fork library agent") from e

View File

@@ -30,8 +30,8 @@ async def test_get_library_agents(mocker):
prisma.models.LibraryAgent(
id="ua1",
userId="test-user",
agentId="agent2",
agentVersion=1,
agentGraphId="agent2",
agentGraphVersion=1,
isCreatedByUser=False,
isDeleted=False,
isArchived=False,
@@ -39,7 +39,7 @@ async def test_get_library_agents(mocker):
updatedAt=datetime.now(),
isFavorite=False,
useGraphIsActiveVersion=True,
Agent=prisma.models.AgentGraph(
AgentGraph=prisma.models.AgentGraph(
id="agent2",
version=1,
name="Test Agent 2",
@@ -71,8 +71,8 @@ async def test_get_library_agents(mocker):
assert result.agents[0].id == "ua1"
assert result.agents[0].name == "Test Agent 2"
assert result.agents[0].description == "Test Description 2"
assert result.agents[0].agent_id == "agent2"
assert result.agents[0].agent_version == 1
assert result.agents[0].graph_id == "agent2"
assert result.agents[0].graph_version == 1
assert result.agents[0].can_access_graph is False
assert result.agents[0].is_latest_version is True
assert result.pagination.total_items == 1
@@ -81,7 +81,7 @@ async def test_get_library_agents(mocker):
assert result.pagination.page_size == 50
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_add_agent_to_library(mocker):
await connect()
# Mock data
@@ -90,8 +90,8 @@ async def test_add_agent_to_library(mocker):
version=1,
createdAt=datetime.now(),
updatedAt=datetime.now(),
agentId="agent1",
agentVersion=1,
agentGraphId="agent1",
agentGraphVersion=1,
name="Test Agent",
subHeading="Test Agent Subheading",
imageUrls=["https://example.com/image.jpg"],
@@ -102,7 +102,7 @@ async def test_add_agent_to_library(mocker):
isAvailable=True,
storeListingId="listing123",
submissionStatus=prisma.enums.SubmissionStatus.APPROVED,
Agent=prisma.models.AgentGraph(
AgentGraph=prisma.models.AgentGraph(
id="agent1",
version=1,
name="Test Agent",
@@ -116,8 +116,8 @@ async def test_add_agent_to_library(mocker):
mock_library_agent_data = prisma.models.LibraryAgent(
id="ua1",
userId="test-user",
agentId=mock_store_listing_data.agentId,
agentVersion=1,
agentGraphId=mock_store_listing_data.agentGraphId,
agentGraphVersion=1,
isCreatedByUser=False,
isDeleted=False,
isArchived=False,
@@ -125,7 +125,7 @@ async def test_add_agent_to_library(mocker):
updatedAt=datetime.now(),
isFavorite=False,
useGraphIsActiveVersion=True,
Agent=mock_store_listing_data.Agent,
AgentGraph=mock_store_listing_data.AgentGraph,
)
# Mock prisma calls
@@ -147,25 +147,28 @@ async def test_add_agent_to_library(mocker):
# Verify mocks called correctly
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
where={"id": "version123"}, include={"Agent": True}
where={"id": "version123"}, include={"AgentGraph": True}
)
mock_library_agent.return_value.find_first.assert_called_once_with(
where={
"userId": "test-user",
"agentId": "agent1",
"agentVersion": 1,
"agentGraphId": "agent1",
"agentGraphVersion": 1,
},
include=library_agent_include("test-user"),
)
mock_library_agent.return_value.create.assert_called_once_with(
data=prisma.types.LibraryAgentCreateInput(
userId="test-user", agentId="agent1", agentVersion=1, isCreatedByUser=False
userId="test-user",
agentGraphId="agent1",
agentGraphVersion=1,
isCreatedByUser=False,
),
include=library_agent_include("test-user"),
)
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_add_agent_to_library_not_found(mocker):
await connect()
# Mock prisma calls
@@ -182,5 +185,5 @@ async def test_add_agent_to_library_not_found(mocker):
# Verify mock called correctly
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
where={"id": "version123"}, include={"Agent": True}
where={"id": "version123"}, include={"AgentGraph": True}
)

View File

@@ -25,8 +25,8 @@ class LibraryAgent(pydantic.BaseModel):
"""
id: str
agent_id: str
agent_version: int
graph_id: str
graph_version: int
image_url: str | None
@@ -58,12 +58,12 @@ class LibraryAgent(pydantic.BaseModel):
Factory method that constructs a LibraryAgent from a Prisma LibraryAgent
model instance.
"""
if not agent.Agent:
if not agent.AgentGraph:
raise ValueError("Associated Agent record is required.")
graph = graph_model.GraphModel.from_db(agent.Agent)
graph = graph_model.GraphModel.from_db(agent.AgentGraph)
agent_updated_at = agent.Agent.updatedAt
agent_updated_at = agent.AgentGraph.updatedAt
lib_agent_updated_at = agent.updatedAt
# Compute updated_at as the latest between library agent and graph
@@ -83,21 +83,21 @@ class LibraryAgent(pydantic.BaseModel):
week_ago = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
days=7
)
executions = agent.Agent.AgentGraphExecution or []
executions = agent.AgentGraph.Executions or []
status_result = _calculate_agent_status(executions, week_ago)
status = status_result.status
new_output = status_result.new_output
# Check if user can access the graph
can_access_graph = agent.Agent.userId == agent.userId
can_access_graph = agent.AgentGraph.userId == agent.userId
# Hard-coded to True until a method to check is implemented
is_latest_version = True
return LibraryAgent(
id=agent.id,
agent_id=agent.agentId,
agent_version=agent.agentVersion,
graph_id=agent.agentGraphId,
graph_version=agent.agentGraphVersion,
image_url=agent.imageUrl,
creator_name=creator_name,
creator_image_url=creator_image_url,
@@ -174,8 +174,8 @@ class LibraryAgentPreset(pydantic.BaseModel):
id: str
updated_at: datetime.datetime
agent_id: str
agent_version: int
graph_id: str
graph_version: int
name: str
description: str
@@ -194,8 +194,8 @@ class LibraryAgentPreset(pydantic.BaseModel):
return cls(
id=preset.id,
updated_at=preset.updatedAt,
agent_id=preset.agentId,
agent_version=preset.agentVersion,
graph_id=preset.agentGraphId,
graph_version=preset.agentGraphVersion,
name=preset.name,
description=preset.description,
is_active=preset.isActive,
@@ -218,8 +218,8 @@ class CreateLibraryAgentPresetRequest(pydantic.BaseModel):
name: str
description: str
inputs: block_model.BlockInput
agent_id: str
agent_version: int
graph_id: str
graph_version: int
is_active: bool

View File

@@ -5,7 +5,6 @@ import prisma.models
import pytest
import backend.server.v2.library.model as library_model
from backend.util import json
@pytest.mark.asyncio
@@ -15,19 +14,21 @@ async def test_agent_preset_from_db():
id="test-agent-123",
createdAt=datetime.datetime.now(),
updatedAt=datetime.datetime.now(),
agentId="agent-123",
agentVersion=1,
agentGraphId="agent-123",
agentGraphVersion=1,
name="Test Agent",
description="Test agent description",
isActive=True,
userId="test-user-123",
isDeleted=False,
InputPresets=[
prisma.models.AgentNodeExecutionInputOutput(
id="input-123",
time=datetime.datetime.now(),
name="input1",
data=json.dumps({"type": "string", "value": "test value"}), # type: ignore
prisma.models.AgentNodeExecutionInputOutput.model_validate(
{
"id": "input-123",
"time": datetime.datetime.now(),
"name": "input1",
"data": '{"type": "string", "value": "test value"}',
}
)
],
)
@@ -36,7 +37,7 @@ async def test_agent_preset_from_db():
agent = library_model.LibraryAgentPreset.from_db(db_agent)
assert agent.id == "test-agent-123"
assert agent.agent_version == 1
assert agent.graph_version == 1
assert agent.is_active is True
assert agent.name == "Test Agent"
assert agent.description == "Test agent description"

View File

@@ -190,3 +190,14 @@ async def update_library_agent(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update library agent",
) from e
@router.post("/{library_agent_id}/fork")
async def fork_library_agent(
library_agent_id: str,
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> library_model.LibraryAgent:
return await library_db.fork_library_agent(
library_agent_id=library_agent_id,
user_id=user_id,
)

View File

@@ -2,25 +2,17 @@ import logging
from typing import Annotated, Any
import autogpt_libs.auth as autogpt_auth_lib
import autogpt_libs.utils.cache
from fastapi import APIRouter, Body, Depends, HTTPException, status
import backend.executor
import backend.server.v2.library.db as db
import backend.server.v2.library.model as models
import backend.util.service
from backend.executor.utils import add_graph_execution_async
logger = logging.getLogger(__name__)
router = APIRouter()
@autogpt_libs.utils.cache.thread_cached
def execution_manager_client() -> backend.executor.ExecutionManager:
"""Return a cached instance of ExecutionManager client."""
return backend.util.service.get_service_client(backend.executor.ExecutionManager)
@router.get(
"/presets",
summary="List presets",
@@ -226,17 +218,17 @@ async def execute_preset(
# Merge input overrides with preset inputs
merged_node_input = preset.inputs | node_input
execution = execution_manager_client().add_execution(
execution = await add_graph_execution_async(
graph_id=graph_id,
graph_version=graph_version,
data=merged_node_input,
user_id=user_id,
inputs=merged_node_input,
preset_id=preset_id,
graph_version=graph_version,
)
logger.debug(f"Execution added: {execution} with input: {merged_node_input}")
return {"id": execution.graph_exec_id}
return {"id": execution.id}
except HTTPException:
raise
except Exception as e:

View File

@@ -35,8 +35,8 @@ async def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
agents=[
library_model.LibraryAgent(
id="test-agent-1",
agent_id="test-agent-1",
agent_version=1,
graph_id="test-agent-1",
graph_version=1,
name="Test Agent 1",
description="Test Description 1",
image_url=None,
@@ -51,8 +51,8 @@ async def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
),
library_model.LibraryAgent(
id="test-agent-2",
agent_id="test-agent-2",
agent_version=1,
graph_id="test-agent-2",
graph_version=1,
name="Test Agent 2",
description="Test Description 2",
image_url=None,
@@ -78,9 +78,9 @@ async def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
data = library_model.LibraryAgentResponse.model_validate(response.json())
assert len(data.agents) == 2
assert data.agents[0].agent_id == "test-agent-1"
assert data.agents[0].graph_id == "test-agent-1"
assert data.agents[0].can_access_graph is True
assert data.agents[1].agent_id == "test-agent-2"
assert data.agents[1].graph_id == "test-agent-2"
assert data.agents[1].can_access_graph is False
mock_db_call.assert_called_once_with(
user_id="test-user-id",

View File

@@ -200,17 +200,17 @@ async def get_available_graph(
"isAvailable": True,
"isDeleted": False,
},
include={"Agent": {"include": {"AgentNodes": True}}},
include={"AgentGraph": {"include": {"Nodes": True}}},
)
)
if not store_listing_version or not store_listing_version.Agent:
if not store_listing_version or not store_listing_version.AgentGraph:
raise fastapi.HTTPException(
status_code=404,
detail=f"Store listing version {store_listing_version_id} not found",
)
graph = GraphModel.from_db(store_listing_version.Agent)
graph = GraphModel.from_db(store_listing_version.AgentGraph)
# We return graph meta, without nodes, they cannot be just removed
# because then input_schema would be empty
return {
@@ -516,7 +516,7 @@ async def delete_store_submission(
try:
# Verify the submission belongs to this user
submission = await prisma.models.StoreListing.prisma().find_first(
where={"agentId": submission_id, "owningUserId": user_id}
where={"agentGraphId": submission_id, "owningUserId": user_id}
)
if not submission:
@@ -598,7 +598,7 @@ async def create_store_submission(
# Check if listing already exists for this agent
existing_listing = await prisma.models.StoreListing.prisma().find_first(
where=prisma.types.StoreListingWhereInput(
agentId=agent_id, owningUserId=user_id
agentGraphId=agent_id, owningUserId=user_id
)
)
@@ -625,15 +625,15 @@ async def create_store_submission(
# If no existing listing, create a new one
data = prisma.types.StoreListingCreateInput(
slug=slug,
agentId=agent_id,
agentVersion=agent_version,
agentGraphId=agent_id,
agentGraphVersion=agent_version,
owningUserId=user_id,
createdAt=datetime.now(tz=timezone.utc),
Versions={
"create": [
prisma.types.StoreListingVersionCreateInput(
agentId=agent_id,
agentVersion=agent_version,
agentGraphId=agent_id,
agentGraphVersion=agent_version,
name=name,
videoUrl=video_url,
imageUrls=image_urls,
@@ -758,8 +758,8 @@ async def create_store_version(
new_version = await prisma.models.StoreListingVersion.prisma().create(
data=prisma.types.StoreListingVersionCreateInput(
version=next_version,
agentId=agent_id,
agentVersion=agent_version,
agentGraphId=agent_id,
agentGraphVersion=agent_version,
name=name,
videoUrl=video_url,
imageUrls=image_urls,
@@ -793,6 +793,7 @@ async def create_store_version(
changes_summary=changes_summary,
version=next_version,
)
except prisma.errors.PrismaError as e:
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to create new store version"
@@ -959,17 +960,17 @@ async def get_my_agents(
try:
search_filter: prisma.types.LibraryAgentWhereInput = {
"userId": user_id,
"Agent": {"is": {"StoreListing": {"none": {"isDeleted": False}}}},
"AgentGraph": {"is": {"StoreListings": {"none": {"isDeleted": False}}}},
"isArchived": False,
"isDeleted": False,
}
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
where=search_filter,
order=[{"agentVersion": "desc"}],
order=[{"agentGraphVersion": "desc"}],
skip=(page - 1) * page_size,
take=page_size,
include={"Agent": True},
include={"AgentGraph": True},
)
total = await prisma.models.LibraryAgent.prisma().count(where=search_filter)
@@ -985,7 +986,7 @@ async def get_my_agents(
agent_image=library_agent.imageUrl,
)
for library_agent in library_agents
if (graph := library_agent.Agent)
if (graph := library_agent.AgentGraph)
]
return backend.server.v2.store.model.MyAgentsResponse(
@@ -1020,13 +1021,13 @@ async def get_agent(
graph = await backend.data.graph.get_graph(
user_id=user_id,
graph_id=store_listing_version.agentId,
version=store_listing_version.agentVersion,
graph_id=store_listing_version.agentGraphId,
version=store_listing_version.agentGraphVersion,
for_export=True,
)
if not graph:
raise ValueError(
f"Agent {store_listing_version.agentId} v{store_listing_version.agentVersion} not found"
f"Agent {store_listing_version.agentGraphId} v{store_listing_version.agentGraphVersion} not found"
)
return graph
@@ -1050,11 +1051,14 @@ async def _get_missing_sub_store_listing(
# Fetch all the sub-graphs that are listed, and return the ones missing.
store_listed_sub_graphs = {
(listing.agentId, listing.agentVersion)
(listing.agentGraphId, listing.agentGraphVersion)
for listing in await prisma.models.StoreListingVersion.prisma().find_many(
where={
"OR": [
{"agentId": sub_graph.id, "agentVersion": sub_graph.version}
{
"agentGraphId": sub_graph.id,
"agentGraphVersion": sub_graph.version,
}
for sub_graph in sub_graphs
],
"submissionStatus": prisma.enums.SubmissionStatus.APPROVED,
@@ -1084,7 +1088,7 @@ async def review_store_submission(
where={"id": store_listing_version_id},
include={
"StoreListing": True,
"Agent": {"include": AGENT_GRAPH_INCLUDE}, # type: ignore
"AgentGraph": {"include": AGENT_GRAPH_INCLUDE},
},
)
)
@@ -1096,23 +1100,23 @@ async def review_store_submission(
)
# If approving, update the listing to indicate it has an approved version
if is_approved and store_listing_version.Agent:
heading = f"Sub-graph of {store_listing_version.name}v{store_listing_version.agentVersion}"
if is_approved and store_listing_version.AgentGraph:
heading = f"Sub-graph of {store_listing_version.name}v{store_listing_version.agentGraphVersion}"
sub_store_listing_versions = [
prisma.types.StoreListingVersionCreateWithoutRelationsInput(
agentId=sub_graph.id,
agentVersion=sub_graph.version,
agentGraphId=sub_graph.id,
agentGraphVersion=sub_graph.version,
name=sub_graph.name or heading,
submissionStatus=prisma.enums.SubmissionStatus.APPROVED,
subHeading=heading,
description=f"{heading}: {sub_graph.description}",
changesSummary=f"This listing is added as a {heading} / #{store_listing_version.agentId}.",
changesSummary=f"This listing is added as a {heading} / #{store_listing_version.agentGraphId}.",
isAvailable=False, # Hide sub-graphs from the store by default.
submittedAt=datetime.now(tz=timezone.utc),
)
for sub_graph in await _get_missing_sub_store_listing(
store_listing_version.Agent
store_listing_version.AgentGraph
)
]
@@ -1155,8 +1159,8 @@ async def review_store_submission(
# Convert to Pydantic model for consistency
return backend.server.v2.store.model.StoreSubmission(
agent_id=submission.agentId,
agent_version=submission.agentVersion,
agent_id=submission.agentGraphId,
agent_version=submission.agentGraphVersion,
name=submission.name,
sub_heading=submission.subHeading,
slug=(
@@ -1294,8 +1298,8 @@ async def get_admin_listings_with_versions(
# If we have versions, turn them into StoreSubmission models
for version in listing.Versions or []:
version_model = backend.server.v2.store.model.StoreSubmission(
agent_id=version.agentId,
agent_version=version.agentVersion,
agent_id=version.agentGraphId,
agent_version=version.agentGraphVersion,
name=version.name,
sub_heading=version.subHeading,
slug=listing.slug,
@@ -1324,8 +1328,8 @@ async def get_admin_listings_with_versions(
backend.server.v2.store.model.StoreListingWithVersions(
listing_id=listing.id,
slug=listing.slug,
agent_id=listing.agentId,
agent_version=listing.agentVersion,
agent_id=listing.agentGraphId,
agent_version=listing.agentGraphVersion,
active_version_id=listing.activeVersionId,
has_approved_version=listing.hasApprovedVersion,
creator_email=creator_email,
@@ -1358,3 +1362,31 @@ async def get_admin_listings_with_versions(
page_size=page_size,
),
)
async def get_agent_as_admin(
user_id: str | None,
store_listing_version_id: str,
) -> GraphModel:
"""Get agent using the version ID and store listing version ID."""
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id}
)
)
if not store_listing_version:
raise ValueError(f"Store listing version {store_listing_version_id} not found")
graph = await backend.data.graph.get_graph_as_admin(
user_id=user_id,
graph_id=store_listing_version.agentGraphId,
version=store_listing_version.agentGraphVersion,
for_export=True,
)
if not graph:
raise ValueError(
f"Agent {store_listing_version.agentGraphId} v{store_listing_version.agentGraphVersion} not found"
)
return graph

View File

@@ -170,14 +170,14 @@ async def test_create_store_submission(mocker):
isDeleted=False,
hasApprovedVersion=False,
slug="test-agent",
agentId="agent-id",
agentVersion=1,
agentGraphId="agent-id",
agentGraphVersion=1,
owningUserId="user-id",
Versions=[
prisma.models.StoreListingVersion(
id="version-id",
agentId="agent-id",
agentVersion=1,
agentGraphId="agent-id",
agentGraphVersion=1,
name="Test Agent",
description="Test description",
createdAt=datetime.now(),

View File

@@ -4,20 +4,7 @@ from typing import List
import prisma.enums
import pydantic
class Pagination(pydantic.BaseModel):
total_items: int = pydantic.Field(
description="Total number of items.", examples=[42]
)
total_pages: int = pydantic.Field(
description="Total number of pages.", examples=[97]
)
current_page: int = pydantic.Field(
description="Current_page page number.", examples=[1]
)
page_size: int = pydantic.Field(
description="Number of items per page.", examples=[25]
)
from backend.server.model import Pagination
class MyAgent(pydantic.BaseModel):

View File

@@ -5,11 +5,10 @@ from typing import Protocol
import uvicorn
from autogpt_libs.auth import parse_jwt_token
from autogpt_libs.utils.cache import thread_cached
from autogpt_libs.logging.utils import generate_uvicorn_config
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
from starlette.middleware.cors import CORSMiddleware
from backend.data import redis
from backend.data.execution import AsyncRedisExecutionEventBus
from backend.data.user import DEFAULT_USER_ID
from backend.server.conn_manager import ConnectionManager
@@ -19,7 +18,7 @@ from backend.server.model import (
WSSubscribeGraphExecutionRequest,
WSSubscribeGraphExecutionsRequest,
)
from backend.util.service import AppProcess, get_service_client
from backend.util.service import AppProcess
from backend.util.settings import AppEnvironment, Config, Settings
logger = logging.getLogger(__name__)
@@ -46,24 +45,14 @@ def get_connection_manager():
return _connection_manager
@thread_cached
def get_db_client():
from backend.executor import DatabaseManager
return get_service_client(DatabaseManager)
async def event_broadcaster(manager: ConnectionManager):
try:
redis.connect()
event_queue = AsyncRedisExecutionEventBus()
async for event in event_queue.listen("*"):
await manager.send_execution_update(event)
except Exception as e:
logger.exception(f"Event broadcaster error: {e}")
raise
finally:
redis.disconnect()
async def authenticate_websocket(websocket: WebSocket) -> str:
@@ -286,8 +275,14 @@ class WebsocketServer(AppProcess):
allow_methods=["*"],
allow_headers=["*"],
)
uvicorn.run(
server_app,
host=Config().websocket_server_host,
port=Config().websocket_server_port,
log_config=generate_uvicorn_config(),
)
def cleanup(self):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Shutting down WebSocket Server...")

View File

@@ -15,21 +15,25 @@ def to_dict(data) -> dict:
def dumps(data) -> str:
return json.dumps(jsonable_encoder(data))
return json.dumps(to_dict(data))
T = TypeVar("T")
@overload
def loads(data: str, *args, target_type: Type[T], **kwargs) -> T: ...
def loads(data: str | bytes, *args, target_type: Type[T], **kwargs) -> T: ...
@overload
def loads(data: str, *args, **kwargs) -> Any: ...
def loads(data: str | bytes, *args, **kwargs) -> Any: ...
def loads(data: str, *args, target_type: Type[T] | None = None, **kwargs) -> Any:
def loads(
data: str | bytes, *args, target_type: Type[T] | None = None, **kwargs
) -> Any:
if isinstance(data, bytes):
data = data.decode("utf-8")
parsed = json.loads(data, *args, **kwargs)
if target_type:
return type_match(parsed, target_type)

View File

@@ -1,8 +1,24 @@
import logging
import sentry_sdk
from sentry_sdk.integrations.anthropic import AnthropicIntegration
from sentry_sdk.integrations.logging import LoggingIntegration
from backend.util.settings import Settings
def sentry_init():
sentry_dsn = Settings().secrets.sentry_dsn
sentry_sdk.init(dsn=sentry_dsn, traces_sample_rate=1.0, profiles_sample_rate=1.0)
sentry_sdk.init(
dsn=sentry_dsn,
traces_sample_rate=1.0,
profiles_sample_rate=1.0,
environment=f"app:{Settings().config.app_env.value}-behave:{Settings().config.behave_as.value}",
_experiments={"enable_logs": True},
integrations=[
LoggingIntegration(sentry_logs_level=logging.INFO),
AnthropicIntegration(
include_prompts=False,
),
],
)

View File

@@ -3,7 +3,7 @@ import os
import signal
import sys
from abc import ABC, abstractmethod
from multiprocessing import Process, set_start_method
from multiprocessing import Process, get_all_start_methods, set_start_method
from typing import Optional
from backend.util.logging import configure_logging
@@ -30,7 +30,12 @@ class AppProcess(ABC):
process: Optional[Process] = None
cleaned_up = False
set_start_method("spawn", force=True)
if "forkserver" in get_all_start_methods():
set_start_method("forkserver", force=True)
else:
logger.warning("Forkserver start method is not available. Using spawn instead.")
set_start_method("spawn", force=True)
configure_logging()
sentry_init()
@@ -48,6 +53,7 @@ class AppProcess(ABC):
def service_name(cls) -> str:
return cls.__name__
@abstractmethod
def cleanup(self):
"""
Implement this method on a subclass to do post-execution cleanup,

View File

@@ -142,7 +142,7 @@ def validate_url(
# Resolve all IP addresses for the hostname
try:
ip_list = [res[4][0] for res in socket.getaddrinfo(ascii_hostname, None)]
ip_list = [str(res[4][0]) for res in socket.getaddrinfo(ascii_hostname, None)]
ipv4 = [ip for ip in ip_list if ":" not in ip]
ipv6 = [ip for ip in ip_list if ":" in ip]
ip_addresses = ipv4 + ipv6 # Prefer IPv4 over IPv6

View File

@@ -34,7 +34,7 @@ def conn_retry(
def on_retry(retry_state):
prefix = _log_prefix(resource_name, conn_id)
exception = retry_state.outcome.exception()
logger.error(f"{prefix} {action_name} failed: {exception}. Retrying now...")
logger.warning(f"{prefix} {action_name} failed: {exception}. Retrying now...")
def decorator(func):
is_coroutine = asyncio.iscoroutinefunction(func)
@@ -73,3 +73,10 @@ def conn_retry(
return async_wrapper if is_coroutine else sync_wrapper
return decorator
func_retry = retry(
reraise=False,
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=1, max=30),
)

View File

@@ -1,52 +1,36 @@
import asyncio
import builtins
import inspect
import logging
import os
import threading
import time
import typing
from abc import ABC, abstractmethod
from enum import Enum
from functools import wraps
from types import NoneType, UnionType
from functools import cached_property, update_wrapper
from typing import (
Annotated,
Any,
Awaitable,
Callable,
Concatenate,
Coroutine,
Dict,
FrozenSet,
Iterator,
List,
Optional,
ParamSpec,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
get_args,
get_origin,
)
import httpx
import Pyro5.api
import uvicorn
from fastapi import FastAPI, Request, responses
from pydantic import BaseModel, TypeAdapter, create_model
from Pyro5 import api as pyro
from Pyro5 import config as pyro_config
from backend.data import db, rabbitmq, redis
from backend.util.exceptions import InsufficientBalanceError
from backend.util.json import to_dict
from backend.util.metrics import sentry_init
from backend.util.process import AppProcess, get_service_name
from backend.util.retry import conn_retry
from backend.util.settings import Config, Secrets
from backend.util.settings import Config
logger = logging.getLogger(__name__)
T = TypeVar("T")
@@ -57,131 +41,23 @@ api_host = config.pyro_host
api_comm_retry = config.pyro_client_comm_retry
api_comm_timeout = config.pyro_client_comm_timeout
api_call_timeout = config.rpc_client_call_timeout
pyro_config.MAX_RETRIES = api_comm_retry # type: ignore
pyro_config.COMMTIMEOUT = api_comm_timeout # type: ignore
P = ParamSpec("P")
R = TypeVar("R")
EXPOSED_FLAG = "__exposed__"
def fastapi_expose(func: C) -> C:
def expose(func: C) -> C:
func = getattr(func, "__func__", func)
setattr(func, "__exposed__", True)
setattr(func, EXPOSED_FLAG, True)
return func
def fastapi_exposed_run_and_wait(
f: Callable[P, Coroutine[None, None, R]]
) -> Callable[Concatenate[object, P], R]:
# TODO:
# This function lies about its return type to make the DynamicClient
# call the function synchronously, fix this when DynamicClient can choose
# to call a function synchronously or asynchronously.
return expose(f) # type: ignore
# ----- Begin Pyro Expose Block ---- #
def pyro_expose(func: C) -> C:
"""
Decorator to mark a method or class to be exposed for remote calls.
## ⚠️ Gotcha
Aside from "simple" types, only Pydantic models are passed unscathed *if annotated*.
Any other passed or returned class objects are converted to dictionaries by Pyro.
"""
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
msg = f"Error in {func.__name__}: {e}"
if isinstance(e, ValueError):
logger.warning(msg)
else:
logger.exception(msg)
raise
register_pydantic_serializers(func)
return pyro.expose(wrapper) # type: ignore
def register_pydantic_serializers(func: Callable):
"""Register custom serializers and deserializers for annotated Pydantic models"""
for name, annotation in func.__annotations__.items():
try:
pydantic_types = _pydantic_models_from_type_annotation(annotation)
except Exception as e:
raise TypeError(f"Error while exposing {func.__name__}: {e}")
for model in pydantic_types:
logger.debug(
f"Registering Pyro (de)serializers for {func.__name__} annotation "
f"'{name}': {model.__qualname__}"
)
pyro.register_class_to_dict(model, _make_custom_serializer(model))
pyro.register_dict_to_class(
model.__qualname__, _make_custom_deserializer(model)
)
def _make_custom_serializer(model: Type[BaseModel]):
def custom_class_to_dict(obj):
data = {
"__class__": obj.__class__.__qualname__,
**obj.model_dump(),
}
logger.debug(f"Serializing {obj.__class__.__qualname__} with data: {data}")
return data
return custom_class_to_dict
def _make_custom_deserializer(model: Type[BaseModel]):
def custom_dict_to_class(qualname, data: dict):
logger.debug(f"Deserializing {model.__qualname__} from data: {data}")
return model(**data)
return custom_dict_to_class
def pyro_exposed_run_and_wait(
f: Callable[P, Coroutine[None, None, R]]
) -> Callable[Concatenate[object, P], R]:
@expose
@wraps(f)
def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
coroutine = f(*args, **kwargs)
res = self.run_and_wait(coroutine)
return res
# Register serializers for annotations on bare function
register_pydantic_serializers(f)
return wrapper
if config.use_http_based_rpc:
expose = fastapi_expose
exposed_run_and_wait = fastapi_exposed_run_and_wait
else:
expose = pyro_expose
exposed_run_and_wait = pyro_exposed_run_and_wait
# ----- End Pyro Expose Block ---- #
# --------------------------------------------------
# AppService for IPC service based on HTTP request through FastAPI
# --------------------------------------------------
class BaseAppService(AppProcess, ABC):
shared_event_loop: asyncio.AbstractEventLoop
use_db: bool = False
use_redis: bool = False
rabbitmq_config: Optional[rabbitmq.RabbitMQConfig] = None
rabbitmq_service: Optional[rabbitmq.AsyncRabbitMQ] = None
use_supabase: bool = False
@classmethod
@abstractmethod
@@ -202,20 +78,6 @@ class BaseAppService(AppProcess, ABC):
return target_host
@property
def rabbit(self) -> rabbitmq.AsyncRabbitMQ:
"""Access the RabbitMQ service. Will raise if not configured."""
if not self.rabbitmq_service:
raise RuntimeError("RabbitMQ not configured for this service")
return self.rabbitmq_service
@property
def rabbit_config(self) -> rabbitmq.RabbitMQConfig:
"""Access the RabbitMQ config. Will raise if not configured."""
if not self.rabbitmq_config:
raise RuntimeError("RabbitMQ not configured for this service")
return self.rabbitmq_config
def run_service(self) -> None:
while True:
time.sleep(10)
@@ -225,31 +87,6 @@ class BaseAppService(AppProcess, ABC):
def run(self):
self.shared_event_loop = asyncio.get_event_loop()
if self.use_db:
self.shared_event_loop.run_until_complete(db.connect())
if self.use_redis:
redis.connect()
if self.rabbitmq_config:
logger.info(f"[{self.__class__.__name__}] ⏳ Configuring RabbitMQ...")
self.rabbitmq_service = rabbitmq.AsyncRabbitMQ(self.rabbitmq_config)
self.shared_event_loop.run_until_complete(self.rabbitmq_service.connect())
if self.use_supabase:
from supabase import create_client
secrets = Secrets()
self.supabase = create_client(
secrets.supabase_url, secrets.supabase_service_role_key
)
def cleanup(self):
if self.use_db:
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting DB...")
self.run_and_wait(db.disconnect())
if self.use_redis:
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting Redis...")
redis.disconnect()
if self.rabbitmq_config:
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting RabbitMQ...")
class RemoteCallError(BaseModel):
@@ -268,7 +105,7 @@ EXCEPTION_MAPPING = {
}
class FastApiAppService(BaseAppService, ABC):
class AppService(BaseAppService, ABC):
fastapi_app: FastAPI
@staticmethod
@@ -324,14 +161,16 @@ class FastApiAppService(BaseAppService, ABC):
async def async_endpoint(body: RequestBodyModel): # type: ignore #RequestBodyModel being variable
return await f(
**{name: getattr(body, name) for name in body.model_fields}
**{name: getattr(body, name) for name in type(body).model_fields}
)
return async_endpoint
else:
def sync_endpoint(body: RequestBodyModel): # type: ignore #RequestBodyModel being variable
return f(**{name: getattr(body, name) for name in body.model_fields})
return f(
**{name: getattr(body, name) for name in type(body).model_fields}
)
return sync_endpoint
@@ -351,12 +190,13 @@ class FastApiAppService(BaseAppService, ABC):
self.shared_event_loop.run_until_complete(server.serve())
def run(self):
sentry_init()
super().run()
self.fastapi_app = FastAPI()
# Register the exposed API routes.
for attr_name, attr in vars(type(self)).items():
if getattr(attr, "__exposed__", False):
if getattr(attr, EXPOSED_FLAG, False):
route_path = f"/{attr_name}"
self.fastapi_app.add_api_route(
route_path,
@@ -381,86 +221,58 @@ class FastApiAppService(BaseAppService, ABC):
self.run_service()
# ----- Begin Pyro AppService Block ---- #
class PyroAppService(BaseAppService, ABC):
@conn_retry("Pyro", "Starting Pyro Service")
def __start_pyro(self):
maximum_connection_thread_count = max(
Pyro5.config.THREADPOOL_SIZE,
config.num_node_workers * config.num_graph_workers,
)
Pyro5.config.THREADPOOL_SIZE = maximum_connection_thread_count # type: ignore
daemon = Pyro5.api.Daemon(host=api_host, port=self.get_port())
self.uri = daemon.register(self, objectId=self.service_name)
logger.info(f"[{self.service_name}] Connected to Pyro; URI = {self.uri}")
daemon.requestLoop()
def run(self):
super().run()
# Initialize the async loop.
async_thread = threading.Thread(target=self.shared_event_loop.run_forever)
async_thread.daemon = True
async_thread.start()
# Initialize pyro service
daemon_thread = threading.Thread(target=self.__start_pyro)
daemon_thread.daemon = True
daemon_thread.start()
# Run the main service loop (blocking).
self.run_service()
if config.use_http_based_rpc:
class AppService(FastApiAppService, ABC): # type: ignore #AppService defined twice
pass
else:
class AppService(PyroAppService, ABC):
pass
# ----- End Pyro AppService Block ---- #
# --------------------------------------------------
# HTTP Client utilities for dynamic service client abstraction
# --------------------------------------------------
AS = TypeVar("AS", bound=AppService)
def fastapi_close_service_client(client: Any) -> None:
if hasattr(client, "close"):
client.close()
else:
logger.warning(f"Client {client} is not closable")
class AppServiceClient(ABC):
@classmethod
@abstractmethod
def get_service_type(cls) -> Type[AppService]:
pass
def health_check(self):
pass
def close(self):
pass
@conn_retry("FastAPI client", "Creating service client", max_retry=api_comm_retry)
def fastapi_get_service_client(
service_type: Type[AS],
ASC = TypeVar("ASC", bound=AppServiceClient)
@conn_retry("AppService client", "Creating service client", max_retry=api_comm_retry)
def get_service_client(
service_client_type: Type[ASC],
call_timeout: int | None = api_call_timeout,
) -> AS:
) -> ASC:
class DynamicClient:
def __init__(self):
service_type = service_client_type.get_service_type()
host = service_type.get_host()
port = service_type.get_port()
self.base_url = f"http://{host}:{port}".rstrip("/")
self.client = httpx.Client(
@cached_property
def sync_client(self) -> httpx.Client:
return httpx.Client(
base_url=self.base_url,
timeout=call_timeout,
)
def _call_method(self, method_name: str, **kwargs) -> Any:
@cached_property
def async_client(self) -> httpx.AsyncClient:
return httpx.AsyncClient(
base_url=self.base_url,
timeout=call_timeout,
)
def _handle_call_method_response(
self, response: httpx.Response, method_name: str
) -> Any:
try:
response = self.client.post(method_name, json=to_dict(kwargs))
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
@@ -471,126 +283,102 @@ def fastapi_get_service_client(
*(error.args or [str(e)])
)
def _call_method_sync(self, method_name: str, **kwargs) -> Any:
return self._handle_call_method_response(
method_name=method_name,
response=self.sync_client.post(method_name, json=to_dict(kwargs)),
)
async def _call_method_async(self, method_name: str, **kwargs) -> Any:
return self._handle_call_method_response(
method_name=method_name,
response=await self.async_client.post(
method_name, json=to_dict(kwargs)
),
)
async def aclose(self):
self.sync_client.close()
await self.async_client.aclose()
def close(self):
self.client.close()
self.sync_client.close()
def _get_params(self, signature: inspect.Signature, *args, **kwargs) -> dict:
if args:
arg_names = list(signature.parameters.keys())
if arg_names[0] in ("self", "cls"):
arg_names = arg_names[1:]
kwargs.update(dict(zip(arg_names, args)))
return kwargs
def _get_return(self, expected_return: TypeAdapter | None, result: Any) -> Any:
if expected_return:
return expected_return.validate_python(result)
return result
def __getattr__(self, name: str) -> Callable[..., Any]:
# Try to get the original function from the service type.
orig_func = getattr(service_type, name, None)
if orig_func is None:
raise AttributeError(f"Method {name} not found in {service_type}")
original_func = getattr(service_client_type, name, None)
if original_func is None:
raise AttributeError(
f"Method {name} not found in {service_client_type}"
)
else:
name = original_func.__name__
sig = inspect.signature(orig_func)
sig = inspect.signature(original_func)
ret_ann = sig.return_annotation
if ret_ann != inspect.Signature.empty:
expected_return = TypeAdapter(ret_ann)
else:
expected_return = None
def method(*args, **kwargs) -> Any:
if args:
arg_names = list(sig.parameters.keys())
if arg_names[0] in ("self", "cls"):
arg_names = arg_names[1:]
kwargs.update(dict(zip(arg_names, args)))
result = self._call_method(name, **kwargs)
if expected_return:
return expected_return.validate_python(result)
return result
if inspect.iscoroutinefunction(original_func):
return method
async def async_method(*args, **kwargs) -> Any:
params = self._get_params(sig, *args, **kwargs)
result = await self._call_method_async(name, **params)
return self._get_return(expected_return, result)
client = cast(AS, DynamicClient())
return async_method
else:
def sync_method(*args, **kwargs) -> Any:
params = self._get_params(sig, *args, **kwargs)
result = self._call_method_sync(name, **params)
return self._get_return(expected_return, result)
return sync_method
client = cast(ASC, DynamicClient())
client.health_check()
return cast(AS, client)
return client
# ----- Begin Pyro Client Block ---- #
class PyroClient:
proxy: Pyro5.api.Proxy
def endpoint_to_sync(
func: Callable[Concatenate[Any, P], Awaitable[R]],
) -> Callable[Concatenate[Any, P], R]:
"""
Produce a *typed* stub that **looks** synchronous to the typechecker.
"""
def _stub(*args: P.args, **kwargs: P.kwargs) -> R: # pragma: no cover
raise RuntimeError("should be intercepted by __getattr__")
update_wrapper(_stub, func)
return cast(Callable[Concatenate[Any, P], R], _stub)
def pyro_close_service_client(client: BaseAppService) -> None:
if isinstance(client, PyroClient):
client.proxy._pyroRelease()
else:
raise RuntimeError(f"Client {client.__class__} is not a Pyro client.")
def endpoint_to_async(
func: Callable[Concatenate[Any, P], R],
) -> Callable[Concatenate[Any, P], Awaitable[R]]:
"""
The async mirror of `to_sync`.
"""
async def _stub(*args: P.args, **kwargs: P.kwargs) -> R: # pragma: no cover
raise RuntimeError("should be intercepted by __getattr__")
def pyro_get_service_client(service_type: Type[AS]) -> AS:
service_name = service_type.service_name
class DynamicClient(PyroClient):
@conn_retry("Pyro", f"Connecting to [{service_name}]")
def __init__(self):
uri = f"PYRO:{service_type.service_name}@{service_type.get_host()}:{service_type.get_port()}"
logger.debug(f"Connecting to service [{service_name}]. URI = {uri}")
self.proxy = Pyro5.api.Proxy(uri)
# Attempt to bind to ensure the connection is established
self.proxy._pyroBind()
logger.debug(f"Successfully connected to service [{service_name}]")
def __getattr__(self, name: str) -> Callable[..., Any]:
res = getattr(self.proxy, name)
return res
return cast(AS, DynamicClient())
builtin_types = [*vars(builtins).values(), NoneType, Enum]
def _pydantic_models_from_type_annotation(annotation) -> Iterator[type[BaseModel]]:
# Peel Annotated parameters
if (origin := get_origin(annotation)) and origin is Annotated:
annotation = get_args(annotation)[0]
origin = get_origin(annotation)
args = get_args(annotation)
if origin in (
Union,
UnionType,
list,
List,
tuple,
Tuple,
set,
Set,
frozenset,
FrozenSet,
):
for arg in args:
yield from _pydantic_models_from_type_annotation(arg)
elif origin in (dict, Dict):
key_type, value_type = args
yield from _pydantic_models_from_type_annotation(key_type)
yield from _pydantic_models_from_type_annotation(value_type)
elif origin in (Awaitable, Coroutine):
# For coroutines and awaitables, check the return type
return_type = args[-1]
yield from _pydantic_models_from_type_annotation(return_type)
else:
annotype = annotation if origin is None else origin
# Exclude generic types and aliases
if (
annotype is not None
and not hasattr(typing, getattr(annotype, "__name__", ""))
and isinstance(annotype, type)
):
if issubclass(annotype, BaseModel):
yield annotype
elif annotype not in builtin_types and not issubclass(annotype, Enum):
raise TypeError(f"Unsupported type encountered: {annotype}")
if config.use_http_based_rpc:
close_service_client = fastapi_close_service_client
get_service_client = fastapi_get_service_client
else:
close_service_client = pyro_close_service_client
get_service_client = pyro_get_service_client
# ----- End Pyro Client Block ---- #
update_wrapper(_stub, func)
return cast(Callable[Concatenate[Any, P], Awaitable[R]], _stub)

Some files were not shown because too many files have changed in this diff Show More