Merge branch 'dev' into abhi-9274/postgres-integration
(AutoGPT platform, mirror of https://github.com/Significant-Gravitas/AutoGPT.git)
@@ -34,6 +34,7 @@ jobs:
         python -m prisma migrate deploy
       env:
         DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
+        DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}

   trigger:

@@ -36,6 +36,7 @@ jobs:
         python -m prisma migrate deploy
       env:
         DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
+        DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}

   trigger:
     needs: migrate
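Context for the DIRECT_URL additions in these workflow hunks (and in the backend .env further down): by Prisma convention, DATABASE_URL may point at a connection pooler while a directUrl entry in the datasource block, fed here by DIRECT_URL, gives the migration engine an unpooled connection. Pointing both variables at the same secret keeps prisma migrate working in either setup; this rationale is inferred from Prisma's conventions rather than stated in the commit.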
.github/workflows/platform-backend-ci.yml vendored (12 lines changed)
@@ -135,6 +135,7 @@ jobs:
         run: poetry run prisma migrate dev --name updates
         env:
           DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
+          DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}

       - id: lint
         name: Run Linter

@@ -151,12 +152,13 @@ jobs:
         env:
           LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
           DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
+          DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
           SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
           SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
           SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
-          REDIS_HOST: 'localhost'
-          REDIS_PORT: '6379'
-          REDIS_PASSWORD: 'testpassword'
+          REDIS_HOST: "localhost"
+          REDIS_PORT: "6379"
+          REDIS_PASSWORD: "testpassword"

     env:
       CI: true

@@ -169,8 +171,8 @@ jobs:
           # If you want to replace this, you can do so by making our entire system generate
           # new credentials for each local user and update the environment variables in
           # the backend service, docker composes, and examples
-          RABBITMQ_DEFAULT_USER: 'rabbitmq_user_default'
-          RABBITMQ_DEFAULT_PASS: 'k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7'
+          RABBITMQ_DEFAULT_USER: "rabbitmq_user_default"
+          RABBITMQ_DEFAULT_PASS: "k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7"

       # - name: Upload coverage reports to Codecov
       #   uses: codecov/codecov-action@v4
@@ -16,7 +16,7 @@ jobs:
           # operations-per-run: 5000
           stale-issue-message: >
             This issue has automatically been marked as _stale_ because it has not had
-            any activity in the last 50 days. You can _unstale_ it by commenting or
+            any activity in the last 170 days. You can _unstale_ it by commenting or
             removing the label. Otherwise, this issue will be closed in 10 days.
           stale-pr-message: >
             This pull request has automatically been marked as _stale_ because it has

@@ -25,7 +25,7 @@ jobs:
           close-issue-message: >
             This issue was closed automatically because it has been stale for 10 days
             with no activity.
-          days-before-stale: 100
+          days-before-stale: 170
           days-before-close: 10
           # Do not touch meta issues:
           exempt-issue-labels: meta,fridge,project management
@@ -15,7 +15,11 @@
 > Setting up and hosting the AutoGPT Platform yourself is a technical process.
 > If you'd rather something that just works, we recommend [joining the waitlist](https://bit.ly/3ZDijAI) for the cloud-hosted beta.

 https://github.com/user-attachments/assets/d04273a5-b36a-4a37-818e-f631ce72d603

+### Updated Setup Instructions:
+We’ve moved to a fully maintained and regularly updated documentation site.
+
+👉 [Follow the official self-hosting guide here](https://docs.agpt.co/platform/getting-started/)

 This tutorial assumes you have Docker, VSCode, git and npm installed.
@@ -8,7 +8,7 @@ from pydantic import Field, field_validator
 from pydantic_settings import BaseSettings, SettingsConfigDict

 from .filters import BelowLevelFilter
-from .formatters import AGPTFormatter, StructuredLoggingFormatter
+from .formatters import AGPTFormatter

 LOG_DIR = Path(__file__).parent.parent.parent.parent / "logs"
 LOG_FILE = "activity.log"

@@ -81,9 +81,26 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
     """

     config = LoggingConfig()

     log_handlers: list[logging.Handler] = []

+    # Console output handlers
+    stdout = logging.StreamHandler(stream=sys.stdout)
+    stdout.setLevel(config.level)
+    stdout.addFilter(BelowLevelFilter(logging.WARNING))
+    if config.level == logging.DEBUG:
+        stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
+    else:
+        stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
+
+    stderr = logging.StreamHandler()
+    stderr.setLevel(logging.WARNING)
+    if config.level == logging.DEBUG:
+        stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
+    else:
+        stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
+
+    log_handlers += [stdout, stderr]
+
     # Cloud logging setup
     if config.enable_cloud_logging or force_cloud_logging:
         import google.cloud.logging

@@ -97,26 +114,7 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
             transport=SyncTransport,
         )
         cloud_handler.setLevel(config.level)
-        cloud_handler.setFormatter(StructuredLoggingFormatter())
         log_handlers.append(cloud_handler)
-    else:
-        # Console output handlers
-        stdout = logging.StreamHandler(stream=sys.stdout)
-        stdout.setLevel(config.level)
-        stdout.addFilter(BelowLevelFilter(logging.WARNING))
-        if config.level == logging.DEBUG:
-            stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
-        else:
-            stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
-
-        stderr = logging.StreamHandler()
-        stderr.setLevel(logging.WARNING)
-        if config.level == logging.DEBUG:
-            stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
-        else:
-            stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
-
-        log_handlers += [stdout, stderr]

     # File logging setup
     if config.enable_file_logging:
@@ -1,7 +1,6 @@
 import logging

 from colorama import Fore, Style
-from google.cloud.logging_v2.handlers import CloudLoggingFilter, StructuredLogHandler

 from .utils import remove_color_codes

@@ -80,16 +79,3 @@ class AGPTFormatter(FancyConsoleFormatter):
             return remove_color_codes(super().format(record))
         else:
             return super().format(record)
-
-
-class StructuredLoggingFormatter(StructuredLogHandler, logging.Formatter):
-    def __init__(self):
-        # Set up CloudLoggingFilter to add diagnostic info to the log records
-        self.cloud_logging_filter = CloudLoggingFilter()
-
-        # Init StructuredLogHandler
-        super().__init__()
-
-    def format(self, record: logging.LogRecord) -> str:
-        self.cloud_logging_filter.filter(record)
-        return super().format(record)
@@ -2,6 +2,7 @@ import logging
 import re
 from typing import Any

+import uvicorn.config
 from colorama import Fore

@@ -25,3 +26,14 @@ def print_attribute(
             "color": value_color,
         },
     )
+
+
+def generate_uvicorn_config():
+    """
+    Generates a uvicorn logging config that silences uvicorn's default logging and tells it to use the native logging module.
+    """
+    log_config = dict(uvicorn.config.LOGGING_CONFIG)
+    log_config["loggers"]["uvicorn"] = {"handlers": []}
+    log_config["loggers"]["uvicorn.error"] = {"handlers": []}
+    log_config["loggers"]["uvicorn.access"] = {"handlers": []}
+    return log_config
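For orientation, a minimal sketch of how a config like this is typically handed to uvicorn; the app path and port are illustrative, and the import path of the helper is assumed rather than taken from this commit.

import uvicorn

from backend.util.logging import generate_uvicorn_config  # import path assumed

# With empty handler lists, uvicorn's records propagate to the root logger,
# where they pick up the console/cloud/file handlers configured above.
uvicorn.run("backend.server.app:app", port=8000, log_config=generate_uvicorn_config())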
@@ -1,20 +1,59 @@
+import inspect
 import threading
-from typing import Callable, ParamSpec, TypeVar
+from typing import Awaitable, Callable, ParamSpec, TypeVar, cast, overload

 P = ParamSpec("P")
 R = TypeVar("R")


-def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
+@overload
+def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...
+
+
+@overload
+def thread_cached(func: Callable[P, R]) -> Callable[P, R]: ...
+
+
+def thread_cached(
+    func: Callable[P, R] | Callable[P, Awaitable[R]],
+) -> Callable[P, R] | Callable[P, Awaitable[R]]:
     thread_local = threading.local()

-    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
-        cache = getattr(thread_local, "cache", None)
-        if cache is None:
-            cache = thread_local.cache = {}
-        key = (args, tuple(sorted(kwargs.items())))
-        if key not in cache:
-            cache[key] = func(*args, **kwargs)
-        return cache[key]
+    def _clear():
+        if hasattr(thread_local, "cache"):
+            del thread_local.cache

-    return wrapper
+    if inspect.iscoroutinefunction(func):
+
+        async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+            cache = getattr(thread_local, "cache", None)
+            if cache is None:
+                cache = thread_local.cache = {}
+            key = (args, tuple(sorted(kwargs.items())))
+            if key not in cache:
+                cache[key] = await cast(Callable[P, Awaitable[R]], func)(
+                    *args, **kwargs
+                )
+            return cache[key]
+
+        setattr(async_wrapper, "clear_cache", _clear)
+        return async_wrapper
+
+    else:
+
+        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+            cache = getattr(thread_local, "cache", None)
+            if cache is None:
+                cache = thread_local.cache = {}
+            key = (args, tuple(sorted(kwargs.items())))
+            if key not in cache:
+                cache[key] = func(*args, **kwargs)
+            return cache[key]
+
+        setattr(sync_wrapper, "clear_cache", _clear)
+        return sync_wrapper
+
+
+def clear_thread_cache(func: Callable) -> None:
+    if clear := getattr(func, "clear_cache", None):
+        clear()
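A usage sketch of the reworked decorator (the function bodies below are invented for illustration): the overloads let one decorator wrap both sync and async callables, and the clear_cache hook plus clear_thread_cache drop the calling thread's memoized values.

from autogpt_libs.utils.cache import clear_thread_cache, thread_cached

@thread_cached
def get_client() -> dict:
    # Runs once per thread per argument tuple; later calls hit the cache.
    return {"connected": True}

@thread_cached
async def fetch_config(name: str) -> str:
    # Async overload: the coroutine is awaited and its result is cached.
    return f"config:{name}"

a = get_client()
b = get_client()
assert a is b                   # same thread, same key -> same cached object

clear_thread_cache(get_client)  # invokes the wrapper's clear_cache hook
assert get_client() is not a    # cache dropped, so the value is rebuilt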
@@ -31,7 +31,7 @@ class RedisKeyedMutex:
         try:
             yield
         finally:
-            if lock.locked():
+            if lock.locked() and lock.owned():
                 lock.release()

     def acquire(self, key: Any) -> "RedisLock":
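Why the added owned() check matters: with Redis locks, if the lock's timeout elapses mid-critical-section, another process can acquire the same key, so locked() still reports True while owned() is False; redis-py raises an error when you release a lock whose token you no longer hold. Guarding on both means only locks this instance still owns get released. (That is standard redis-py behavior, inferred rather than stated in the diff.)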
@@ -8,6 +8,7 @@ DB_CONNECT_TIMEOUT=60
 DB_POOL_TIMEOUT=300
 DB_SCHEMA=platform
 DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
+DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
 PRISMA_SCHEMA="postgres/schema.prisma"

 # EXECUTOR
@@ -73,7 +73,6 @@ FROM server_dependencies AS server
 COPY autogpt_platform/backend /app/autogpt_platform/backend
 RUN poetry install --no-ansi --only-root

-ENV DATABASE_URL=""
 ENV PORT=8000

 CMD ["poetry", "run", "rest"]
@@ -1,8 +1,6 @@
 import logging
 from typing import Any

-from autogpt_libs.utils.cache import thread_cached
-
 from backend.data.block import (
     Block,
     BlockCategory,

@@ -19,21 +17,6 @@ from backend.util import json
 logger = logging.getLogger(__name__)


-@thread_cached
-def get_executor_manager_client():
-    from backend.executor import ExecutionManager
-    from backend.util.service import get_service_client
-
-    return get_service_client(ExecutionManager)
-
-
-@thread_cached
-def get_event_bus():
-    from backend.data.execution import RedisExecutionEventBus
-
-    return RedisExecutionEventBus()
-
-
 class AgentExecutorBlock(Block):
     class Input(BlockSchema):
         user_id: str = SchemaField(description="User ID")

@@ -76,23 +59,23 @@ class AgentExecutorBlock(Block):
     def run(self, input_data: Input, **kwargs) -> BlockOutput:
         from backend.data.execution import ExecutionEventType
+        from backend.executor import utils as execution_utils

-        executor_manager = get_executor_manager_client()
-        event_bus = get_event_bus()
+        event_bus = execution_utils.get_execution_event_bus()

-        graph_exec = executor_manager.add_execution(
+        graph_exec = execution_utils.add_graph_execution(
             graph_id=input_data.graph_id,
             graph_version=input_data.graph_version,
             user_id=input_data.user_id,
-            data=input_data.data,
+            inputs=input_data.data,
         )
-        log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.graph_exec_id}"
+        log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.id}"
         logger.info(f"Starting execution of {log_id}")

         for event in event_bus.listen(
             user_id=graph_exec.user_id,
             graph_id=graph_exec.graph_id,
-            graph_exec_id=graph_exec.graph_exec_id,
+            graph_exec_id=graph_exec.id,
         ):
             if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
                 if event.status in [

@@ -105,7 +88,7 @@ class AgentExecutorBlock(Block):
             else:
                 continue

-            logger.info(
+            logger.debug(
                 f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
             )

@@ -123,5 +106,7 @@ class AgentExecutorBlock(Block):
                 continue

             for output_data in event.output_data.get("output", []):
-                logger.info(f"Execution {log_id} produced {output_name}: {output_data}")
+                logger.debug(
+                    f"Execution {log_id} produced {output_name}: {output_data}"
+                )
                 yield output_name, output_data
@@ -1,7 +1,7 @@
 from enum import Enum
 from typing import Any, Optional

-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict

 from backend.data.model import SchemaField

@@ -143,11 +143,12 @@ class ContactEmail(BaseModel):
 class EmploymentHistory(BaseModel):
     """An employment history in Apollo"""

-    class Config:
-        extra = "allow"
-        arbitrary_types_allowed = True
-        from_attributes = True
-        populate_by_name = True
+    model_config = ConfigDict(
+        extra="allow",
+        arbitrary_types_allowed=True,
+        from_attributes=True,
+        populate_by_name=True,
+    )

     _id: Optional[str] = None
     created_at: Optional[str] = None
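The same substitution repeats across the models below, so one gloss here: pydantic v2 deprecates the inner class Config in favor of a model_config = ConfigDict(...) attribute. A minimal equivalent, with the field invented for illustration:

from pydantic import BaseModel, ConfigDict

class ApolloStyleModel(BaseModel):
    model_config = ConfigDict(
        extra="allow",               # keep unknown fields from the API response
        arbitrary_types_allowed=True,
        from_attributes=True,        # v2 name for v1's orm_mode
        populate_by_name=True,       # accept the field name as well as its alias
    )

    name: str = "N/A"

m = ApolloStyleModel.model_validate({"name": "Acme", "unmapped_field": 1})
assert m.model_extra == {"unmapped_field": 1}   # extra="allow" retained it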
@@ -188,11 +189,12 @@ class TypedCustomField(BaseModel):
 class Pagination(BaseModel):
     """Pagination in Apollo"""

-    class Config:
-        extra = "allow"  # Allow extra fields
-        arbitrary_types_allowed = True  # Allow any type
-        from_attributes = True  # Allow from_orm
-        populate_by_name = True  # Allow field aliases to work both ways
+    model_config = ConfigDict(
+        extra="allow",
+        arbitrary_types_allowed=True,
+        from_attributes=True,
+        populate_by_name=True,
+    )

     page: int = 0
     per_page: int = 0

@@ -230,11 +232,12 @@ class PhoneNumber(BaseModel):
 class Organization(BaseModel):
     """An organization in Apollo"""

-    class Config:
-        extra = "allow"
-        arbitrary_types_allowed = True
-        from_attributes = True
-        populate_by_name = True
+    model_config = ConfigDict(
+        extra="allow",
+        arbitrary_types_allowed=True,
+        from_attributes=True,
+        populate_by_name=True,
+    )

     id: Optional[str] = "N/A"
     name: Optional[str] = "N/A"

@@ -268,11 +271,12 @@ class Organization(BaseModel):
 class Contact(BaseModel):
     """A contact in Apollo"""

-    class Config:
-        extra = "allow"
-        arbitrary_types_allowed = True
-        from_attributes = True
-        populate_by_name = True
+    model_config = ConfigDict(
+        extra="allow",
+        arbitrary_types_allowed=True,
+        from_attributes=True,
+        populate_by_name=True,
+    )

     contact_roles: list[Any] = []
     id: Optional[str] = None

@@ -369,14 +373,14 @@ If a company has several office locations, results are still based on the headqu

 To exclude companies based on location, use the organization_not_locations parameter.
 """,
-        default=[],
+        default_factory=list,
     )
     organizations_not_locations: list[str] = SchemaField(
         description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.

 This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
 """,
-        default=[],
+        default_factory=list,
     )
     q_organization_keyword_tags: list[str] = SchemaField(
         description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry."""
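Most hunks from here on make the same default=[] to default_factory=list (or default={} to default_factory=dict) substitution, so one gloss: default_factory builds a fresh container per model instance, making the no-shared-state guarantee explicit instead of leaning on pydantic's copying of mutable defaults. A sketch with plain pydantic Field standing in for SchemaField:

from pydantic import BaseModel, Field

class Query(BaseModel):
    tags: list[str] = Field(default_factory=list)  # fresh list per instance

a, b = Query(), Query()
a.tags.append("mining")
assert b.tags == []  # no instance ever sees another's mutations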
@@ -390,7 +394,7 @@ If the value you enter for this parameter does not match with a company's name,
         description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.

 To find IDs, identify the values for organization_id when you call this endpoint.""",
-        default=[],
+        default_factory=list,
     )
     max_results: int = SchemaField(
         description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",

@@ -443,14 +447,14 @@ Results also include job titles with the same terms, even if they are not exact

 Use this parameter in combination with the person_seniorities[] parameter to find people based on specific job functions and seniority levels.
 """,
-        default=[],
+        default_factory=list,
         placeholder="marketing manager",
     )
     person_locations: list[str] = SchemaField(
         description="""The location where people live. You can search across cities, US states, and countries.

 To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
-        default=[],
+        default_factory=list,
     )
     person_seniorities: list[SenorityLevels] = SchemaField(
         description="""The job seniority that people hold within their current employer. This enables you to find people that currently hold positions at certain reporting levels, such as Director level or senior IC level.

@@ -460,7 +464,7 @@ For a person to be included in search results, they only need to match 1 of the
 Searches only return results based on their current job title, so searching for Director-level employees only returns people that currently hold a Director-level title. If someone was previously a Director, but is currently a VP, they would not be included in your search results.

 Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
-        default=[],
+        default_factory=list,
     )
     organization_locations: list[str] = SchemaField(
         description="""The location of the company headquarters for a person's current employer. You can search across cities, US states, and countries.

@@ -468,7 +472,7 @@ Use this parameter in combination with the person_titles[] parameter to find peo
 If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.

 To find people based on their personal location, use the person_locations parameter.""",
-        default=[],
+        default_factory=list,
     )
     q_organization_domains: list[str] = SchemaField(
         description="""The domain name for the person's employer. This can be the current employer or a previous employer. Do not include www., the @ symbol, or similar.

@@ -476,23 +480,23 @@ To find people based on their personal location, use the person_locations parame
 You can add multiple domains to search across companies.

 Examples: apollo.io and microsoft.com""",
-        default=[],
+        default_factory=list,
     )
     contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
         description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
-        default=[],
+        default_factory=list,
     )
     organization_ids: list[str] = SchemaField(
         description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.

 To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
-        default=[],
+        default_factory=list,
     )
     organization_num_empoloyees_range: list[int] = SchemaField(
         description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.

 Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
-        default=[],
+        default_factory=list,
     )
     q_keywords: str = SchemaField(
         description="""A string of words over which we want to filter the results""",

@@ -522,11 +526,12 @@ Use the page parameter to search the different pages of data.""",
 class SearchPeopleResponse(BaseModel):
     """Response from Apollo's search people API"""

-    class Config:
-        extra = "allow"  # Allow extra fields
-        arbitrary_types_allowed = True  # Allow any type
-        from_attributes = True  # Allow from_orm
-        populate_by_name = True  # Allow field aliases to work both ways
+    model_config = ConfigDict(
+        extra="allow",
+        arbitrary_types_allowed=True,
+        from_attributes=True,
+        populate_by_name=True,
+    )

     breadcrumbs: list[Breadcrumb] = []
     partial_results_only: bool = True
@@ -32,18 +32,18 @@ If a company has several office locations, results are still based on the headqu

 To exclude companies based on location, use the organization_not_locations parameter.
 """,
-        default=[],
+        default_factory=list,
     )
     organizations_not_locations: list[str] = SchemaField(
         description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.

 This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
 """,
-        default=[],
+        default_factory=list,
     )
     q_organization_keyword_tags: list[str] = SchemaField(
         description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry.""",
-        default=[],
+        default_factory=list,
     )
     q_organization_name: str = SchemaField(
         description="""Filter search results to include a specific company name.

@@ -56,7 +56,7 @@ If the value you enter for this parameter does not match with a company's name,
         description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.

 To find IDs, identify the values for organization_id when you call this endpoint.""",
-        default=[],
+        default_factory=list,
     )
     max_results: int = SchemaField(
         description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",

@@ -72,7 +72,7 @@ To find IDs, identify the values for organization_id when you call this endpoint
     class Output(BlockSchema):
         organizations: list[Organization] = SchemaField(
             description="List of organizations found",
-            default=[],
+            default_factory=list,
         )
         organization: Organization = SchemaField(
             description="Each found organization, one at a time",
@@ -26,14 +26,14 @@ class SearchPeopleBlock(Block):

 Use this parameter in combination with the person_seniorities[] parameter to find people based on specific job functions and seniority levels.
 """,
-            default=[],
+            default_factory=list,
             advanced=False,
         )
         person_locations: list[str] = SchemaField(
             description="""The location where people live. You can search across cities, US states, and countries.

 To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
-            default=[],
+            default_factory=list,
             advanced=False,
         )
         person_seniorities: list[SenorityLevels] = SchemaField(

@@ -44,7 +44,7 @@ class SearchPeopleBlock(Block):
 Searches only return results based on their current job title, so searching for Director-level employees only returns people that currently hold a Director-level title. If someone was previously a Director, but is currently a VP, they would not be included in your search results.

 Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
-            default=[],
+            default_factory=list,
             advanced=False,
         )
         organization_locations: list[str] = SchemaField(

@@ -53,7 +53,7 @@ class SearchPeopleBlock(Block):
 If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.

 To find people based on their personal location, use the person_locations parameter.""",
-            default=[],
+            default_factory=list,
             advanced=False,
         )
         q_organization_domains: list[str] = SchemaField(

@@ -62,26 +62,26 @@ class SearchPeopleBlock(Block):
 You can add multiple domains to search across companies.

 Examples: apollo.io and microsoft.com""",
-            default=[],
+            default_factory=list,
             advanced=False,
         )
         contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
             description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
-            default=[],
+            default_factory=list,
             advanced=False,
         )
         organization_ids: list[str] = SchemaField(
             description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.

 To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
-            default=[],
+            default_factory=list,
             advanced=False,
         )
         organization_num_empoloyees_range: list[int] = SchemaField(
             description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.

 Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
-            default=[],
+            default_factory=list,
             advanced=False,
         )
         q_keywords: str = SchemaField(

@@ -104,7 +104,7 @@ class SearchPeopleBlock(Block):
     class Output(BlockSchema):
         people: list[Contact] = SchemaField(
             description="List of people found",
-            default=[],
+            default_factory=list,
         )
         person: Contact = SchemaField(
             description="Each found person, one at a time",
@@ -88,6 +88,33 @@ class StoreValueBlock(Block):
         yield "output", input_data.data or input_data.input


+class PrintToConsoleBlock(Block):
+    class Input(BlockSchema):
+        text: Any = SchemaField(description="The data to print to the console.")
+
+    class Output(BlockSchema):
+        output: Any = SchemaField(description="The data printed to the console.")
+        status: str = SchemaField(description="The status of the print operation.")
+
+    def __init__(self):
+        super().__init__(
+            id="f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c",
+            description="Print the given text to the console, this is used for a debugging purpose.",
+            categories={BlockCategory.BASIC},
+            input_schema=PrintToConsoleBlock.Input,
+            output_schema=PrintToConsoleBlock.Output,
+            test_input={"text": "Hello, World!"},
+            test_output=[
+                ("output", "Hello, World!"),
+                ("status", "printed"),
+            ],
+        )
+
+    def run(self, input_data: Input, **kwargs) -> BlockOutput:
+        yield "output", input_data.text
+        yield "status", "printed"
+
+
 class FindInDictionaryBlock(Block):
     class Input(BlockSchema):
         input: Any = SchemaField(description="Dictionary to lookup from")

@@ -151,7 +178,7 @@ class FindInDictionaryBlock(Block):
 class AddToDictionaryBlock(Block):
     class Input(BlockSchema):
         dictionary: dict[Any, Any] = SchemaField(
-            default={},
+            default_factory=dict,
             description="The dictionary to add the entry to. If not provided, a new dictionary will be created.",
         )
         key: str = SchemaField(

@@ -167,7 +194,7 @@ class AddToDictionaryBlock(Block):
             advanced=False,
         )
         entries: dict[Any, Any] = SchemaField(
-            default={},
+            default_factory=dict,
             description="The entries to add to the dictionary. This is the batch version of the `key` and `value` fields.",
             advanced=True,
         )

@@ -229,7 +256,7 @@ class AddToDictionaryBlock(Block):
 class AddToListBlock(Block):
     class Input(BlockSchema):
         list: List[Any] = SchemaField(
-            default=[],
+            default_factory=list,
             advanced=False,
             description="The list to add the entry to. If not provided, a new list will be created.",
         )

@@ -239,7 +266,7 @@ class AddToListBlock(Block):
             default=None,
         )
         entries: List[Any] = SchemaField(
-            default=[],
+            default_factory=lambda: list(),
             description="The entries to add to the list. This is the batch version of the `entry` field.",
             advanced=True,
         )
@@ -55,7 +55,7 @@ class CodeExecutionBlock(Block):
                 "These commands are executed with `sh`, in the foreground."
             ),
             placeholder="pip install cowsay",
-            default=[],
+            default_factory=list,
             advanced=False,
         )

@@ -207,7 +207,7 @@ class InstantiationBlock(Block):
                 "These commands are executed with `sh`, in the foreground."
             ),
             placeholder="pip install cowsay",
-            default=[],
+            default_factory=list,
             advanced=False,
         )
@@ -34,7 +34,7 @@ class ReadCsvBlock(Block):
         )
         skip_columns: list[str] = SchemaField(
             description="The columns to skip from the start of the row",
-            default=[],
+            default_factory=list,
         )

     class Output(BlockSchema):
@@ -49,7 +49,7 @@ class ExaContentsBlock(Block):
     class Output(BlockSchema):
         results: list = SchemaField(
             description="List of document contents",
-            default=[],
+            default_factory=list,
         )
         error: str = SchemaField(description="Error message if the request failed")

@@ -38,11 +38,11 @@ class ExaSearchBlock(Block):
         )
         include_domains: List[str] = SchemaField(
             description="Domains to include in search",
-            default=[],
+            default_factory=list,
         )
         exclude_domains: List[str] = SchemaField(
             description="Domains to exclude from search",
-            default=[],
+            default_factory=list,
             advanced=True,
         )
         start_crawl_date: datetime = SchemaField(

@@ -59,12 +59,12 @@ class ExaSearchBlock(Block):
         )
         include_text: List[str] = SchemaField(
             description="Text patterns to include",
-            default=[],
+            default_factory=list,
             advanced=True,
         )
         exclude_text: List[str] = SchemaField(
             description="Text patterns to exclude",
-            default=[],
+            default_factory=list,
             advanced=True,
         )
         contents: ContentSettings = SchemaField(

@@ -76,7 +76,7 @@ class ExaSearchBlock(Block):
     class Output(BlockSchema):
         results: list = SchemaField(
             description="List of search results",
-            default=[],
+            default_factory=list,
         )

     def __init__(self):

@@ -26,12 +26,12 @@ class ExaFindSimilarBlock(Block):
         )
         include_domains: List[str] = SchemaField(
             description="Domains to include in search",
-            default=[],
+            default_factory=list,
             advanced=True,
         )
         exclude_domains: List[str] = SchemaField(
             description="Domains to exclude from search",
-            default=[],
+            default_factory=list,
             advanced=True,
         )
         start_crawl_date: datetime = SchemaField(

@@ -48,12 +48,12 @@ class ExaFindSimilarBlock(Block):
         )
         include_text: List[str] = SchemaField(
             description="Text patterns to include (max 1 string, up to 5 words)",
-            default=[],
+            default_factory=list,
             advanced=True,
         )
         exclude_text: List[str] = SchemaField(
             description="Text patterns to exclude (max 1 string, up to 5 words)",
-            default=[],
+            default_factory=list,
             advanced=True,
         )
         contents: ContentSettings = SchemaField(

@@ -65,7 +65,7 @@ class ExaFindSimilarBlock(Block):
     class Output(BlockSchema):
         results: List[Any] = SchemaField(
             description="List of similar documents with title, URL, published date, author, and score",
-            default=[],
+            default_factory=list,
         )

     def __init__(self):
@@ -42,7 +42,7 @@ class AIVideoGeneratorBlock(Block):
             description="Error message if video generation failed."
         )
         logs: list[str] = SchemaField(
-            description="Generation progress logs.", optional=True
+            description="Generation progress logs.",
         )

     def __init__(self):
@@ -12,10 +12,10 @@ from backend.integrations.webhooks.generic import GenericWebhookType

 class GenericWebhookTriggerBlock(Block):
     class Input(BlockSchema):
-        payload: dict = SchemaField(hidden=True, default={})
+        payload: dict = SchemaField(hidden=True, default_factory=dict)
         constants: dict = SchemaField(
             description="The constants to be set when the block is put on the graph",
-            default={},
+            default_factory=dict,
         )

     class Output(BlockSchema):
@@ -37,7 +37,7 @@ class GitHubTriggerBase:
             placeholder="{owner}/{repo}",
         )
         # --8<-- [start:example-payload-field]
-        payload: dict = SchemaField(hidden=True, default={})
+        payload: dict = SchemaField(hidden=True, default_factory=dict)
         # --8<-- [end:example-payload-field]

     class Output(BlockSchema):
@@ -3,7 +3,7 @@ from googleapiclient.discovery import build

 from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import SchemaField
-from backend.util.settings import Settings
+from backend.util.settings import AppEnvironment, Settings

 from ._auth import (
     GOOGLE_OAUTH_IS_CONFIGURED,

@@ -36,13 +36,15 @@ class GoogleSheetsReadBlock(Block):
     )

     def __init__(self):
+        settings = Settings()
         super().__init__(
             id="5724e902-3635-47e9-a108-aaa0263a4988",
             description="This block reads data from a Google Sheets spreadsheet.",
             categories={BlockCategory.DATA},
             input_schema=GoogleSheetsReadBlock.Input,
             output_schema=GoogleSheetsReadBlock.Output,
-            disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
+            disabled=not GOOGLE_OAUTH_IS_CONFIGURED
+            or settings.config.app_env == AppEnvironment.PRODUCTION,
             test_input={
                 "spreadsheet_id": "1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms",
                 "range": "Sheet1!A1:B2",
@@ -34,7 +34,7 @@ class SendWebRequestBlock(Block):
         )
         headers: dict[str, str] = SchemaField(
             description="The headers to include in the request",
-            default={},
+            default_factory=dict,
         )
         json_format: bool = SchemaField(
             title="JSON format",

@@ -82,7 +82,15 @@ class SendWebRequestBlock(Block):
                 json=body if input_data.json_format else None,
                 data=body if not input_data.json_format else None,
             )
-            result = response.json() if input_data.json_format else response.text
+
+            if input_data.json_format:
+                if response.status_code == 204 or not response.content.strip():
+                    result = None
+                else:
+                    result = response.json()
+            else:
+                result = response.text
+
             yield "response", result

         except HTTPError as e:
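The second hunk fixes a concrete failure mode: .json() raises on an empty body, so a 204 No Content (or any empty response) previously crashed the block. A standalone sketch of the same guard, assuming a requests-style response object:

import requests

def parse_json_response(response: requests.Response):
    # 204 responses and empty bodies carry no JSON; return None instead of
    # letting response.json() raise a decode error.
    if response.status_code == 204 or not response.content.strip():
        return None
    return response.json()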
@@ -15,7 +15,8 @@ class HubSpotCompanyBlock(Block):
             description="Operation to perform (create, update, get)", default="get"
         )
         company_data: dict = SchemaField(
-            description="Company data for create/update operations", default={}
+            description="Company data for create/update operations",
+            default_factory=dict,
         )
         domain: str = SchemaField(
             description="Company domain for get/update operations", default=""

@@ -15,7 +15,8 @@ class HubSpotContactBlock(Block):
             description="Operation to perform (create, update, get)", default="get"
         )
         contact_data: dict = SchemaField(
-            description="Contact data for create/update operations", default={}
+            description="Contact data for create/update operations",
+            default_factory=dict,
         )
         email: str = SchemaField(
             description="Email address for get/update operations", default=""

@@ -19,7 +19,7 @@ class HubSpotEngagementBlock(Block):
         )
         email_data: dict = SchemaField(
             description="Email data including recipient, subject, content",
-            default={},
+            default_factory=dict,
         )
         contact_id: str = SchemaField(
             description="Contact ID for engagement tracking", default=""

@@ -27,7 +27,6 @@ class HubSpotEngagementBlock(Block):
         timeframe_days: int = SchemaField(
             description="Number of days to look back for engagement",
             default=30,
-            optional=True,
         )

     class Output(BlockSchema):
@@ -1,3 +1,4 @@
+import copy
 from datetime import date, time
 from typing import Any, Optional

@@ -38,7 +39,7 @@ class AgentInputBlock(Block):
         )
         placeholder_values: list = SchemaField(
             description="The placeholder values to be passed as input.",
-            default=[],
+            default_factory=list,
             advanced=True,
             hidden=True,
         )

@@ -54,7 +55,7 @@ class AgentInputBlock(Block):
     )

     def generate_schema(self):
-        schema = self.get_field_schema("value")
+        schema = copy.deepcopy(self.get_field_schema("value"))
         if possible_values := self.placeholder_values:
             schema["enum"] = possible_values
         return schema
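The copy.deepcopy in generate_schema looks small but matters: if get_field_schema returns a shared or cached dict, writing schema["enum"] into it would leak enum values into every later caller. A minimal reproduction of that aliasing bug (the schema dict below is illustrative):

import copy

base_schema = {"type": "string"}       # stands in for get_field_schema("value")

polluted = base_schema                 # old behavior: same object
polluted["enum"] = ["a", "b"]
assert "enum" in base_schema           # the shared schema was mutated

base_schema = {"type": "string"}
isolated = copy.deepcopy(base_schema)  # new behavior: private copy
isolated["enum"] = ["a", "b"]
assert "enum" not in base_schema       # original stays clean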
@@ -467,7 +468,7 @@ class AgentDropdownInputBlock(AgentInputBlock):
         )
         placeholder_values: list = SchemaField(
             description="Possible values for the dropdown.",
-            default=[],
+            default_factory=list,
             advanced=False,
             title="Dropdown Options",
         )
@@ -11,13 +11,13 @@ class StepThroughItemsBlock(Block):
             advanced=False,
             description="The list or dictionary of items to iterate over",
             placeholder="[1, 2, 3, 4, 5] or {'key1': 'value1', 'key2': 'value2'}",
-            default=[],
+            default_factory=list,
         )
         items_object: dict = SchemaField(
             advanced=False,
             description="The list or dictionary of items to iterate over",
             placeholder="[1, 2, 3, 4, 5] or {'key1': 'value1', 'key2': 'value2'}",
-            default={},
+            default_factory=dict,
         )
         items_str: str = SchemaField(
             advanced=False,
@@ -23,7 +23,7 @@ class JinaChunkingBlock(Block):
     class Output(BlockSchema):
         chunks: list = SchemaField(description="List of chunked texts")
         tokens: list = SchemaField(
-            description="List of token information for each chunk", optional=True
+            description="List of token information for each chunk",
         )

     def __init__(self):

@@ -1,4 +1,4 @@
-from groq._utils._utils import quote
+from urllib.parse import quote

 from backend.blocks.jina._auth import (
     TEST_CREDENTIALS,
@@ -28,8 +28,8 @@ class LinearCreateIssueBlock(Block):
         priority: int | None = SchemaField(
             description="Priority of the issue",
             default=None,
-            minimum=0,
-            maximum=4,
+            ge=0,
+            le=4,
         )
         project_name: str | None = SchemaField(
             description="Name of the project to create the issue on",
@@ -4,30 +4,24 @@ from abc import ABC
 from enum import Enum, EnumMeta
 from json import JSONDecodeError
 from types import MappingProxyType
-from typing import TYPE_CHECKING, Any, Iterable, List, Literal, NamedTuple, Optional
-
-from pydantic import BaseModel, SecretStr
-
-from backend.data.model import NodeExecutionStats
-from backend.integrations.providers import ProviderName
-
-if TYPE_CHECKING:
-    from enum import _EnumMemberT
+from typing import Any, Iterable, List, Literal, NamedTuple, Optional

 import anthropic
 import ollama
 import openai
-from anthropic._types import NotGiven
 from anthropic.types import ToolParam
 from groq import Groq
+from pydantic import BaseModel, SecretStr

 from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
     CredentialsMetaInput,
+    NodeExecutionStats,
     SchemaField,
 )
+from backend.integrations.providers import ProviderName
 from backend.util import json
 from backend.util.settings import BehaveAs, Settings
 from backend.util.text import TextFormatter

@@ -77,12 +71,10 @@ class ModelMetadata(NamedTuple):

 class LlmModelMeta(EnumMeta):
     @property
-    def __members__(
-        self: type["_EnumMemberT"],
-    ) -> MappingProxyType[str, "_EnumMemberT"]:
+    def __members__(self) -> MappingProxyType:
         if Settings().config.behave_as == BehaveAs.LOCAL:
             members = super().__members__
-            return members
+            return MappingProxyType(members)
         else:
             removed_providers = ["ollama"]
             existing_members = super().__members__

@@ -97,14 +89,17 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
 class LlmModel(str, Enum, metaclass=LlmModelMeta):
     # OpenAI models
     O3_MINI = "o3-mini"
+    O3 = "o3-2025-04-16"
     O1 = "o1"
     O1_PREVIEW = "o1-preview"
     O1_MINI = "o1-mini"
+    GPT41 = "gpt-4.1-2025-04-14"
     GPT4O_MINI = "gpt-4o-mini"
     GPT4O = "gpt-4o"
     GPT4_TURBO = "gpt-4-turbo"
     GPT3_5_TURBO = "gpt-3.5-turbo"
     # Anthropic models
     CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
     CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest"
     CLAUDE_3_5_HAIKU = "claude-3-5-haiku-latest"
     CLAUDE_3_HAIKU = "claude-3-haiku-20240307"

@@ -125,6 +120,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
     OLLAMA_DOLPHIN = "dolphin-mistral:latest"
     # OpenRouter models
     GEMINI_FLASH_1_5 = "google/gemini-flash-1.5"
+    GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
     GROK_BETA = "x-ai/grok-beta"
     MISTRAL_NEMO = "mistralai/mistral-nemo"
     COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"

@@ -142,6 +138,8 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
     AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
     MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
     GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
+    META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
+    META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"

     @property
     def metadata(self) -> ModelMetadata:

@@ -162,12 +160,14 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):

 MODEL_METADATA = {
     # https://platform.openai.com/docs/models
+    LlmModel.O3: ModelMetadata("openai", 200000, 100000),
     LlmModel.O3_MINI: ModelMetadata("openai", 200000, 100000),  # o3-mini-2025-01-31
     LlmModel.O1: ModelMetadata("openai", 200000, 100000),  # o1-2024-12-17
     LlmModel.O1_PREVIEW: ModelMetadata(
         "openai", 128000, 32768
     ),  # o1-preview-2024-09-12
     LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536),  # o1-mini-2024-09-12
+    LlmModel.GPT41: ModelMetadata("openai", 1047576, 32768),
     LlmModel.GPT4O_MINI: ModelMetadata(
         "openai", 128000, 16384
     ),  # gpt-4o-mini-2024-07-18

@@ -177,6 +177,9 @@ MODEL_METADATA = {
     ),  # gpt-4-turbo-2024-04-09
     LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, 4096),  # gpt-3.5-turbo-0125
     # https://docs.anthropic.com/en/docs/about-claude/models
+    LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
+        "anthropic", 200000, 8192
+    ),  # claude-3-7-sonnet-20250219
     LlmModel.CLAUDE_3_5_SONNET: ModelMetadata(
         "anthropic", 200000, 8192
     ),  # claude-3-5-sonnet-20241022

@@ -202,6 +205,7 @@ MODEL_METADATA = {
     LlmModel.OLLAMA_DOLPHIN: ModelMetadata("ollama", 32768, None),
     # https://openrouter.ai/models
     LlmModel.GEMINI_FLASH_1_5: ModelMetadata("open_router", 1000000, 8192),
+    LlmModel.GEMINI_2_5_PRO: ModelMetadata("open_router", 1050000, 8192),
     LlmModel.GROK_BETA: ModelMetadata("open_router", 131072, 131072),
     LlmModel.MISTRAL_NEMO: ModelMetadata("open_router", 128000, 4096),
     LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata("open_router", 128000, 4096),

@@ -223,6 +227,8 @@ MODEL_METADATA = {
     LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata("open_router", 300000, 5120),
     LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata("open_router", 65536, 4096),
     LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata("open_router", 4096, 4096),
+    LlmModel.META_LLAMA_4_SCOUT: ModelMetadata("open_router", 131072, 131072),
+    LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata("open_router", 1048576, 1000000),
 }

 for model in LlmModel:

@@ -252,7 +258,7 @@ class LLMResponse(BaseModel):

 def convert_openai_tool_fmt_to_anthropic(
     openai_tools: list[dict] | None = None,
-) -> Iterable[ToolParam] | NotGiven:
+) -> Iterable[ToolParam] | anthropic.NotGiven:
     """
     Convert OpenAI tool format to Anthropic tool format.
     """

@@ -282,6 +288,13 @@ def convert_openai_tool_fmt_to_anthropic(
     return anthropic_tools


+def estimate_token_count(prompt_messages: list[dict]) -> int:
+    char_count = sum(len(str(msg.get("content", ""))) for msg in prompt_messages)
+    message_overhead = len(prompt_messages) * 4
+    estimated_tokens = (char_count // 4) + message_overhead
+    return int(estimated_tokens * 1.2)
+
+
 def llm_call(
     credentials: APIKeyCredentials,
     llm_model: LlmModel,
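A quick worked example of the heuristic above: roughly 4 characters per token, plus 4 tokens of per-message overhead, then a 20% safety margin.

messages = [{"role": "user", "content": "x" * 400}]
# char_count = 400; overhead = 1 * 4; (400 // 4) + 4 = 104; int(104 * 1.2) = 124
assert estimate_token_count(messages) == 124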
@@ -290,6 +303,7 @@ def llm_call(
     max_tokens: int | None,
     tools: list[dict] | None = None,
     ollama_host: str = "localhost:11434",
+    parallel_tool_calls: bool | None = None,
 ) -> LLMResponse:
     """
     Make a call to a language model.

@@ -312,7 +326,14 @@ def llm_call(
         - completion_tokens: The number of tokens used in the completion.
     """
     provider = llm_model.metadata.provider
-    max_tokens = max_tokens or llm_model.max_output_tokens or 4096
+
+    # Calculate available tokens based on context window and input length
+    estimated_input_tokens = estimate_token_count(prompt)
+    context_window = llm_model.context_window
+    model_max_output = llm_model.max_output_tokens or 4096
+    user_max = max_tokens or model_max_output
+    available_tokens = max(context_window - estimated_input_tokens, 0)
+    max_tokens = max(min(available_tokens, model_max_output, user_max), 0)

     if provider == "openai":
         tools_param = tools if tools else openai.NOT_GIVEN
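Plugging illustrative numbers into the new budgeting logic makes it concrete: given a 128,000-token window, a 16,384-token output ceiling, an input estimated at 120,000 tokens, and no caller override, the requested max_tokens shrinks to what still fits in the context window.

context_window = 128000
estimated_input_tokens = 120000
model_max_output = 16384
user_max = model_max_output  # caller passed max_tokens=None
available_tokens = max(context_window - estimated_input_tokens, 0)      # 8000
max_tokens = max(min(available_tokens, model_max_output, user_max), 0)  # 8000
assert max_tokens == 8000  # clamped by remaining context, not the model ceiling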
@@ -335,6 +356,9 @@ def llm_call(
             response_format=response_format,  # type: ignore
             max_completion_tokens=max_tokens,
             tools=tools_param,  # type: ignore
+            parallel_tool_calls=(
+                openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
+            ),
         )

         if response.choices[0].message.tool_calls:

@@ -424,7 +448,7 @@ def llm_call(
             response=(
                 resp.content[0].name
                 if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
-                else resp.content[0].text
+                else getattr(resp.content[0], "text", "")
             ),
             tool_calls=tool_calls,
             prompt_tokens=resp.usage.input_tokens,

@@ -465,6 +489,7 @@ def llm_call(
             model=llm_model.value,
             prompt=f"{sys_messages}\n\n{usr_messages}",
             stream=False,
+            options={"num_ctx": max_tokens},
         )
         return LLMResponse(
             raw_response=response.get("response") or "",

@@ -490,6 +515,9 @@ def llm_call(
             messages=prompt,  # type: ignore
             max_tokens=max_tokens,
             tools=tools_param,  # type: ignore
+            parallel_tool_calls=(
+                openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
+            ),
         )

         # If there's no response, raise an error

@@ -528,7 +556,7 @@ def llm_call(
 class AIBlockBase(Block, ABC):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.prompt = ""
+        self.prompt = []

     def merge_llm_stats(self, block: "AIBlockBase"):
         self.merge_stats(block.execution_stats)

@@ -558,7 +586,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
             description="The system prompt to provide additional context to the model.",
         )
         conversation_history: list[dict] = SchemaField(
-            default=[],
+            default_factory=list,
             description="The conversation history to provide context for the prompt.",
         )
         retry: int = SchemaField(

@@ -568,7 +596,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
         )
         prompt_values: dict[str, str] = SchemaField(
             advanced=False,
-            default={},
+            default_factory=dict,
             description="Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}.",
         )
         max_tokens: int | None = SchemaField(

@@ -587,7 +615,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
         response: dict[str, Any] = SchemaField(
             description="The response object generated by the language model."
         )
-        prompt: str = SchemaField(description="The prompt sent to the language model.")
+        prompt: list = SchemaField(description="The prompt sent to the language model.")
         error: str = SchemaField(description="Error message if the API call failed.")

     def __init__(self):

@@ -609,7 +637,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
             test_credentials=TEST_CREDENTIALS,
             test_output=[
                 ("response", {"key1": "key1Value", "key2": "key2Value"}),
-                ("prompt", str),
+                ("prompt", list),
             ],
             test_mock={
                 "llm_call": lambda *args, **kwargs: LLMResponse(

@@ -642,6 +670,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
         Test mocks work only on class functions, this wraps the llm_call function
         so that it can be mocked withing the block testing framework.
         """
+        self.prompt = prompt
         return llm_call(
             credentials=credentials,
             llm_model=llm_model,

@@ -759,6 +788,16 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
                     prompt.append({"role": "user", "content": retry_prompt})
             except Exception as e:
                 logger.exception(f"Error calling LLM: {e}")
+                if (
+                    "maximum context length" in str(e).lower()
+                    or "token limit" in str(e).lower()
+                ):
+                    if input_data.max_tokens is None:
+                        input_data.max_tokens = llm_model.max_output_tokens or 4096
+                    input_data.max_tokens = int(input_data.max_tokens * 0.85)
+                    logger.debug(
+                        f"Reducing max_tokens to {input_data.max_tokens} for next attempt"
+                    )
                 retry_prompt = f"Error calling LLM: {e}"
             finally:
                 self.merge_stats(
@@ -796,7 +835,7 @@ class AITextGeneratorBlock(AIBlockBase):
         )
         prompt_values: dict[str, str] = SchemaField(
             advanced=False,
-            default={},
+            default_factory=dict,
             description="Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}.",
         )
         ollama_host: str = SchemaField(

@@ -814,7 +853,7 @@ class AITextGeneratorBlock(AIBlockBase):
         response: str = SchemaField(
             description="The response generated by the language model."
         )
-        prompt: str = SchemaField(description="The prompt sent to the language model.")
+        prompt: list = SchemaField(description="The prompt sent to the language model.")
         error: str = SchemaField(description="Error message if the API call failed.")

     def __init__(self):

@@ -831,7 +870,7 @@ class AITextGeneratorBlock(AIBlockBase):
             test_credentials=TEST_CREDENTIALS,
             test_output=[
                 ("response", "Response text"),
-                ("prompt", str),
+                ("prompt", list),
             ],
             test_mock={"llm_call": lambda *args, **kwargs: "Response text"},
         )

@@ -850,7 +889,10 @@ class AITextGeneratorBlock(AIBlockBase):
         self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
     ) -> BlockOutput:
         object_input_data = AIStructuredResponseGeneratorBlock.Input(
-            **{attr: getattr(input_data, attr) for attr in input_data.model_fields},
+            **{
+                attr: getattr(input_data, attr)
+                for attr in AITextGeneratorBlock.Input.model_fields
+            },
             expected_format={},
         )
         yield "response", self.llm_call(object_input_data, credentials)

@@ -907,7 +949,7 @@ class AITextSummarizerBlock(AIBlockBase):

     class Output(BlockSchema):
         summary: str = SchemaField(description="The final summary of the text.")
-        prompt: str = SchemaField(description="The prompt sent to the language model.")
+        prompt: list = SchemaField(description="The prompt sent to the language model.")
         error: str = SchemaField(description="Error message if the API call failed.")

     def __init__(self):

@@ -924,7 +966,7 @@ class AITextSummarizerBlock(AIBlockBase):
             test_credentials=TEST_CREDENTIALS,
             test_output=[
                 ("summary", "Final summary of a long text"),
-                ("prompt", str),
+                ("prompt", list),
             ],
             test_mock={
                 "llm_call": lambda input_data, credentials: (

@@ -1033,8 +1075,14 @@ class AITextSummarizerBlock(AIBlockBase):

 class AIConversationBlock(AIBlockBase):
     class Input(BlockSchema):
+        prompt: str = SchemaField(
+            description="The prompt to send to the language model.",
+            placeholder="Enter your prompt here...",
+            default="",
+            advanced=False,
+        )
         messages: List[Any] = SchemaField(
-            description="List of messages in the conversation.", min_length=1
+            description="List of messages in the conversation.",
         )
         model: LlmModel = SchemaField(
             title="LLM Model",

@@ -1057,7 +1105,7 @@ class AIConversationBlock(AIBlockBase):
         response: str = SchemaField(
             description="The model's response to the conversation."
         )
-        prompt: str = SchemaField(description="The prompt sent to the language model.")
+        prompt: list = SchemaField(description="The prompt sent to the language model.")
         error: str = SchemaField(description="Error message if the API call failed.")

     def __init__(self):

@@ -1086,7 +1134,7 @@ class AIConversationBlock(AIBlockBase):
                 (
                     "response",
                     "The 2020 World Series was played at Globe Life Field in Arlington, Texas.",
                 ),
-                ("prompt", str),
+                ("prompt", list),
             ],
             test_mock={
                 "llm_call": lambda *args, **kwargs: "The 2020 World Series was played at Globe Life Field in Arlington, Texas."

@@ -1108,7 +1156,7 @@ class AIConversationBlock(AIBlockBase):
     ) -> BlockOutput:
         response = self.llm_call(
             AIStructuredResponseGeneratorBlock.Input(
-                prompt="",
+                prompt=input_data.prompt,
                 credentials=input_data.credentials,
                 model=input_data.model,
                 conversation_history=input_data.messages,

@@ -1166,7 +1214,7 @@ class AIListGeneratorBlock(AIBlockBase):
         list_item: str = SchemaField(
             description="Each individual item in the list.",
         )
-        prompt: str = SchemaField(description="The prompt sent to the language model.")
+        prompt: list = SchemaField(description="The prompt sent to the language model.")
         error: str = SchemaField(
             description="Error message if the list generation failed."
         )

@@ -1198,7 +1246,7 @@ class AIListGeneratorBlock(AIBlockBase):
                 (
                     "generated_list",
                     ["Zylora Prime", "Kharon-9", "Vortexia", "Oceara", "Draknos"],
                 ),
-                ("prompt", str),
+                ("prompt", list),
                 ("list_item", "Zylora Prime"),
                 ("list_item", "Kharon-9"),
                 ("list_item", "Vortexia"),
@@ -65,7 +65,7 @@ class AddMemoryBlock(Block, Mem0Base):
default=Content(discriminator="content", content="I'm a vegetarian"),
)
metadata: dict[str, Any] = SchemaField(
description="Optional metadata for the memory", default={}
description="Optional metadata for the memory", default_factory=dict
)

limit_memory_to_run: bool = SchemaField(
@@ -173,7 +173,7 @@ class SearchMemoryBlock(Block, Mem0Base):
)
categories_filter: list[str] = SchemaField(
description="Categories to filter by",
default=[],
default_factory=list,
advanced=True,
)
limit_memory_to_run: bool = SchemaField(

@@ -177,7 +177,8 @@ class PineconeInsertBlock(Block):
description="Namespace to use in Pinecone", default=""
)
metadata: dict = SchemaField(
description="Additional metadata to store with each vector", default={}
description="Additional metadata to store with each vector",
default_factory=dict,
)

class Output(BlockSchema):

@@ -26,7 +26,7 @@ class Slant3DTriggerBase:
class Input(BlockSchema):
credentials: Slant3DCredentialsInput = Slant3DCredentialsField()
# Webhook URL is handled by the webhook system
payload: dict = SchemaField(hidden=True, default={})
payload: dict = SchemaField(hidden=True, default_factory=dict)

class Output(BlockSchema):
payload: dict = SchemaField(
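
Note: the hunks above and below all make the same mechanical change: mutable defaults such as default={} and default=[] become default_factory, which builds a fresh object per model instance instead of a single literal created at class-definition time. Whether the shared literal could actually leak state here depends on SchemaField's internals, so treat this as the general motivation rather than a confirmed bug. The classic pitfall, in plain Python:

    def add_item(item, bucket=[]):  # one list object shared by every call
        bucket.append(item)
        return bucket

    add_item(1)
    print(add_item(2))  # [1, 2] -- state leaked across calls

    def add_item_safe(item, bucket=None):
        bucket = [] if bucket is None else bucket  # fresh list per call
        bucket.append(item)
        return bucket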
@@ -14,7 +14,6 @@ from backend.data.block import (
BlockOutput,
BlockSchema,
BlockType,
get_block,
)
from backend.data.model import SchemaField
from backend.util import json
@@ -27,10 +26,10 @@ logger = logging.getLogger(__name__)

@thread_cached
def get_database_manager_client():
from backend.executor import DatabaseManager
from backend.executor import DatabaseManagerClient
from backend.util.service import get_service_client

return get_service_client(DatabaseManager)
return get_service_client(DatabaseManagerClient)


def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
@@ -155,7 +154,7 @@ class SmartDecisionMakerBlock(Block):
description="The system prompt to provide additional context to the model.",
)
conversation_history: list[dict] = SchemaField(
default=[],
default_factory=list,
description="The conversation history to provide context for the prompt.",
)
last_tool_output: Any = SchemaField(
@@ -169,7 +168,7 @@ class SmartDecisionMakerBlock(Block):
)
prompt_values: dict[str, str] = SchemaField(
advanced=False,
default={},
default_factory=dict,
description="Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}.",
)
max_tokens: int | None = SchemaField(
@@ -247,6 +246,10 @@ class SmartDecisionMakerBlock(Block):
test_credentials=llm.TEST_CREDENTIALS,
)

@staticmethod
def cleanup(s: str):
return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()

@staticmethod
def _create_block_function_signature(
sink_node: "Node", links: list["Link"]
@@ -264,12 +267,10 @@ class SmartDecisionMakerBlock(Block):
Raises:
ValueError: If the block specified by sink_node.block_id is not found.
"""
block = get_block(sink_node.block_id)
if not block:
raise ValueError(f"Block not found: {sink_node.block_id}")
block = sink_node.block

tool_function: dict[str, Any] = {
"name": re.sub(r"[^a-zA-Z0-9_-]", "_", block.name).lower(),
"name": SmartDecisionMakerBlock.cleanup(block.name),
"description": block.description,
}

@@ -284,7 +285,7 @@ class SmartDecisionMakerBlock(Block):
and sink_block_input_schema.model_fields[link.sink_name].description
else f"The {link.sink_name} of the tool"
)
properties[link.sink_name.lower()] = {
properties[SmartDecisionMakerBlock.cleanup(link.sink_name)] = {
"type": "string",
"description": description,
}
@@ -329,7 +330,7 @@ class SmartDecisionMakerBlock(Block):
)

tool_function: dict[str, Any] = {
"name": re.sub(r"[^a-zA-Z0-9_-]", "_", sink_graph_meta.name).lower(),
"name": SmartDecisionMakerBlock.cleanup(sink_graph_meta.name),
"description": sink_graph_meta.description,
}

@@ -344,7 +345,7 @@ class SmartDecisionMakerBlock(Block):
in sink_block_input_schema["properties"][link.sink_name]
else f"The {link.sink_name} of the tool"
)
properties[link.sink_name.lower()] = {
properties[SmartDecisionMakerBlock.cleanup(link.sink_name)] = {
"type": "string",
"description": description,
}
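
Note: the new cleanup() staticmethod centralizes the re.sub pattern that was previously inlined at each call site, so tool names and property names are normalized identically everywhere. Restated on its own (same regex as the hunks above):

    import re

    def cleanup(s: str) -> str:
        """Normalize a name to the [a-zA-Z0-9_-] charset LLM tool schemas accept."""
        return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()

    print(cleanup("My Fancy Block!"))  # -> my_fancy_block_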
@@ -494,6 +495,7 @@ class SmartDecisionMakerBlock(Block):
max_tokens=input_data.max_tokens,
tools=tool_functions,
ollama_host=input_data.ollama_host,
parallel_tool_calls=False,
)

if not response.tool_calls:
@@ -505,7 +507,7 @@ class SmartDecisionMakerBlock(Block):
tool_args = json.loads(tool_call.function.arguments)

for arg_name, arg_value in tool_args.items():
yield f"tools_^_{tool_name}_{arg_name}".lower(), arg_value
yield f"tools_^_{tool_name}_~_{arg_name}", arg_value

response.prompt.append(response.raw_response)
yield "conversations", response.prompt
@@ -112,7 +112,7 @@ class AddLeadToCampaignBlock(Block):
lead_list: list[LeadInput] = SchemaField(
description="An array of JSON objects, each representing a lead's details. Can hold max 100 leads.",
max_length=100,
default=[],
default_factory=list,
advanced=False,
)
settings: LeadUploadSettings = SchemaField(
@@ -248,7 +248,7 @@ class SaveCampaignSequencesBlock(Block):
)
sequences: list[Sequence] = SchemaField(
description="The sequences to save",
default=[],
default_factory=list,
advanced=False,
)
credentials: SmartLeadCredentialsInput = SchemaField(

@@ -39,7 +39,7 @@ class LeadCustomFields(BaseModel):
fields: dict[str, str] = SchemaField(
description="Custom fields for a lead (max 20 fields)",
max_length=20,
default={},
default_factory=dict,
)


@@ -85,7 +85,7 @@ class AddLeadsRequest(BaseModel):
lead_list: list[LeadInput] = SchemaField(
description="List of leads to add to the campaign",
max_length=100,
default=[],
default_factory=list,
)
settings: LeadUploadSettings
campaign_id: int

@@ -156,7 +156,7 @@
# participant_ids: list[str] = SchemaField(
# description="Array of User IDs to create conversation with (max 50)",
# placeholder="Enter participant user IDs",
# default=[],
# default_factory=list,
# advanced=False
# )
@@ -39,7 +39,6 @@ class TwitterGetListBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to lookup",
placeholder="Enter list ID",
required=True,
)

class Output(BlockSchema):
@@ -184,7 +183,6 @@ class TwitterGetOwnedListsBlock(Block):
user_id: str = SchemaField(
description="The user ID whose owned Lists to retrieve",
placeholder="Enter user ID",
required=True,
)

max_results: int | None = SchemaField(

@@ -45,13 +45,11 @@ class TwitterRemoveListMemberBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to remove the member from",
placeholder="Enter list ID",
required=True,
)

user_id: str = SchemaField(
description="The ID of the user to remove from the List",
placeholder="Enter user ID to remove",
required=True,
)

class Output(BlockSchema):
@@ -120,13 +118,11 @@ class TwitterAddListMemberBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to add the member to",
placeholder="Enter list ID",
required=True,
)

user_id: str = SchemaField(
description="The ID of the user to add to the List",
placeholder="Enter user ID to add",
required=True,
)

class Output(BlockSchema):
@@ -195,7 +191,6 @@ class TwitterGetListMembersBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to get members from",
placeholder="Enter list ID",
required=True,
)

max_results: int | None = SchemaField(
@@ -376,7 +371,6 @@ class TwitterGetListMembershipsBlock(Block):
user_id: str = SchemaField(
description="The ID of the user whose List memberships to retrieve",
placeholder="Enter user ID",
required=True,
)

max_results: int | None = SchemaField(

@@ -42,7 +42,6 @@ class TwitterGetListTweetsBlock(Block):
list_id: str = SchemaField(
description="The ID of the List whose Tweets you would like to retrieve",
placeholder="Enter list ID",
required=True,
)

max_results: int | None = SchemaField(

@@ -28,7 +28,6 @@ class TwitterDeleteListBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to be deleted",
placeholder="Enter list ID",
required=True,
)

class Output(BlockSchema):

@@ -39,7 +39,6 @@ class TwitterUnpinListBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to unpin",
placeholder="Enter list ID",
required=True,
)

class Output(BlockSchema):
@@ -103,7 +102,6 @@ class TwitterPinListBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to pin",
placeholder="Enter list ID",
required=True,
)

class Output(BlockSchema):

@@ -44,7 +44,7 @@ class SpaceList(BaseModel):
space_ids: list[str] = SchemaField(
description="List of Space IDs to lookup (up to 100)",
placeholder="Enter Space IDs",
default=[],
default_factory=list,
advanced=False,
)

@@ -54,7 +54,7 @@ class UserList(BaseModel):
user_ids: list[str] = SchemaField(
description="List of user IDs to lookup their Spaces (up to 100)",
placeholder="Enter user IDs",
default=[],
default_factory=list,
advanced=False,
)

@@ -227,7 +227,6 @@ class TwitterGetSpaceByIdBlock(Block):
space_id: str = SchemaField(
description="Space ID to lookup",
placeholder="Enter Space ID",
required=True,
)

class Output(BlockSchema):
@@ -389,7 +388,6 @@ class TwitterGetSpaceBuyersBlock(Block):
space_id: str = SchemaField(
description="Space ID to lookup buyers for",
placeholder="Enter Space ID",
required=True,
)

class Output(BlockSchema):
@@ -517,7 +515,6 @@ class TwitterGetSpaceTweetsBlock(Block):
space_id: str = SchemaField(
description="Space ID to lookup tweets for",
placeholder="Enter Space ID",
required=True,
)

class Output(BlockSchema):

@@ -200,7 +200,7 @@ class UserIdList(BaseModel):
user_ids: list[str] = SchemaField(
description="List of user IDs to lookup (max 100)",
placeholder="Enter user IDs",
default=[],
default_factory=list,
advanced=False,
)

@@ -210,7 +210,7 @@ class UsernameList(BaseModel):
usernames: list[str] = SchemaField(
description="List of Twitter usernames/handles to lookup (max 100)",
placeholder="Enter usernames",
default=[],
default_factory=list,
advanced=False,
)
@@ -8,7 +8,6 @@ import pathlib
import click
import psutil

from backend import app
from backend.util.process import AppProcess


@@ -42,8 +41,13 @@ def write_pid(pid: int):

class MainApp(AppProcess):
def run(self):
from backend import app

app.main(silent=True)

def cleanup(self):
pass


@click.group()
def main():
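
Note: `from backend import app` moves from module scope into MainApp.run(), so the heavy application module is imported only by the process that actually runs it, not whenever the CLI module is imported. A hedged sketch of the pattern (the module path is assumed from the hunk above):

    import importlib

    class LazyService:
        def run(self) -> None:
            # Deferred import: pay the import cost only when run() executes,
            # and avoid import-time side effects in the parent CLI process.
            app = importlib.import_module("backend.app")
            app.main(silent=True)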
@@ -12,12 +12,12 @@ async def log_raw_analytics(
data_index: str,
):
details = await prisma.models.AnalyticsDetails.prisma().create(
data={
"userId": user_id,
"type": type,
"data": prisma.Json(data),
"dataIndex": data_index,
}
data=prisma.types.AnalyticsDetailsCreateInput(
userId=user_id,
type=type,
data=prisma.Json(data),
dataIndex=data_index,
)
)
return details

@@ -32,12 +32,12 @@ async def log_raw_metric(
raise ValueError("metric_value must be non-negative")

result = await prisma.models.AnalyticsMetrics.prisma().create(
data={
"value": metric_value,
"analyticMetric": metric_name,
"userId": user_id,
"dataString": data_string,
},
data=prisma.types.AnalyticsMetricsCreateInput(
value=metric_value,
analyticMetric=metric_name,
userId=user_id,
dataString=data_string,
)
)

return result
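
Note: swapping the raw dict for prisma.types.AnalyticsDetailsCreateInput / AnalyticsMetricsCreateInput keeps the runtime payload the same (the generated classes are TypedDicts) but lets a type checker validate column names and value types against the Prisma schema. The idea with a stand-in TypedDict:

    from typing import TypedDict

    class AnalyticsCreateInput(TypedDict):  # stand-in for the generated type
        userId: str
        dataIndex: str

    def build(user_id: str) -> AnalyticsCreateInput:
        # A checker rejects {"userid": ...} or a missing dataIndex here;
        # a plain dict literal would let both through silently.
        return AnalyticsCreateInput(userId=user_id, dataIndex="example")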
@@ -17,6 +17,7 @@ from typing import (
import jsonref
import jsonschema
from prisma.models import AgentBlock
from prisma.types import AgentBlockCreateInput
from pydantic import BaseModel

from backend.data.model import NodeExecutionStats
@@ -27,6 +28,7 @@ from backend.util.settings import Config
from .model import (
ContributorDetails,
Credentials,
CredentialsFieldInfo,
CredentialsMetaInput,
is_credentials_field_name,
)
@@ -202,6 +204,15 @@ class BlockSchema(BaseModel):
)
}

@classmethod
def get_credentials_fields_info(cls) -> dict[str, CredentialsFieldInfo]:
return {
field_name: CredentialsFieldInfo.model_validate(
cls.get_field_schema(field_name), by_alias=True
)
for field_name in cls.get_credentials_fields().keys()
}

@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
return data # Return as is, by default.
@@ -480,12 +491,12 @@ async def initialize_blocks() -> None:
)
if not existing_block:
await AgentBlock.prisma().create(
data={
"id": block.id,
"name": block.name,
"inputSchema": json.dumps(block.input_schema.jsonschema()),
"outputSchema": json.dumps(block.output_schema.jsonschema()),
}
data=AgentBlockCreateInput(
id=block.id,
name=block.name,
inputSchema=json.dumps(block.input_schema.jsonschema()),
outputSchema=json.dumps(block.output_schema.jsonschema()),
)
)
continue

@@ -508,6 +519,7 @@ async def initialize_blocks() -> None:
)


def get_block(block_id: str) -> Block | None:
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
def get_block(block_id: str) -> Block[BlockSchema, BlockSchema] | None:
cls = get_blocks().get(block_id)
return cls() if cls else None

@@ -36,14 +36,17 @@ from backend.integrations.credentials_store import (
# =============== Configure the cost for each LLM Model call =============== #

MODEL_COST: dict[LlmModel, int] = {
LlmModel.O3: 7,
LlmModel.O3_MINI: 2, # $1.10 / $4.40
LlmModel.O1: 16, # $15 / $60
LlmModel.O1_PREVIEW: 16,
LlmModel.O1_MINI: 4,
LlmModel.GPT41: 2,
LlmModel.GPT4O_MINI: 1,
LlmModel.GPT4O: 3,
LlmModel.GPT4_TURBO: 10,
LlmModel.GPT3_5_TURBO: 1,
LlmModel.CLAUDE_3_7_SONNET: 5,
LlmModel.CLAUDE_3_5_SONNET: 4,
LlmModel.CLAUDE_3_5_HAIKU: 1, # $0.80 / $4.00
LlmModel.CLAUDE_3_HAIKU: 1,
@@ -60,6 +63,7 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.DEEPSEEK_LLAMA_70B: 1, # ? / ?
LlmModel.OLLAMA_DOLPHIN: 1,
LlmModel.GEMINI_FLASH_1_5: 1,
LlmModel.GEMINI_2_5_PRO: 4,
LlmModel.GROK_BETA: 5,
LlmModel.MISTRAL_NEMO: 1,
LlmModel.COHERE_COMMAND_R_08_2024: 1,
@@ -75,6 +79,8 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.AMAZON_NOVA_PRO_V1: 1,
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: 1,
LlmModel.GRYPHE_MYTHOMAX_L2_13B: 1,
LlmModel.META_LLAMA_4_SCOUT: 1,
LlmModel.META_LLAMA_4_MAVERICK: 1,
}

for model in LlmModel:
@@ -1,8 +1,8 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timezone
from typing import Any, cast

import stripe
from autogpt_libs.utils.cache import thread_cached
@@ -11,11 +11,16 @@ from prisma.enums import (
CreditRefundRequestStatus,
CreditTransactionType,
NotificationType,
OnboardingStep,
)
from prisma.errors import UniqueViolationError
from prisma.models import CreditRefundRequest, CreditTransaction, User
from prisma.types import CreditTransactionCreateInput, CreditTransactionWhereInput
from tenacity import retry, stop_after_attempt, wait_exponential
from prisma.types import (
CreditRefundRequestCreateInput,
CreditTransactionCreateInput,
CreditTransactionWhereInput,
)
from pydantic import BaseModel

from backend.data import db
from backend.data.block_cost_config import BLOCK_COSTS
@@ -23,14 +28,17 @@ from backend.data.cost import BlockCost
from backend.data.model import (
AutoTopUpConfig,
RefundRequest,
TopUpType,
TransactionHistory,
UserTransaction,
)
from backend.data.notifications import NotificationEventDTO, RefundRequestData
from backend.data.user import get_user_by_id
from backend.executor.utils import UsageTransactionMetadata
from backend.notifications import NotificationManager
from backend.data.user import get_user_by_id, get_user_email_by_id
from backend.notifications import NotificationManagerClient
from backend.server.model import Pagination
from backend.server.v2.admin.model import UserHistoryResponse
from backend.util.exceptions import InsufficientBalanceError
from backend.util.retry import func_retry
from backend.util.service import get_service_client
from backend.util.settings import Settings

@@ -40,6 +48,17 @@ logger = logging.getLogger(__name__)
base_url = settings.config.frontend_base_url or settings.config.platform_base_url


class UsageTransactionMetadata(BaseModel):
graph_exec_id: str | None = None
graph_id: str | None = None
node_id: str | None = None
node_exec_id: str | None = None
block_id: str | None = None
block: str | None = None
input: dict[str, Any] | None = None
reason: str | None = None


class UserCreditBase(ABC):
@abstractmethod
async def get_credits(self, user_id: str) -> int:
@@ -117,6 +136,18 @@ class UserCreditBase(ABC):
"""
pass

@abstractmethod
async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
"""
Reward the user with credits for completing an onboarding step.
Won't reward if the user has already received credits for the step.

Args:
user_id (str): The user ID.
step (OnboardingStep): The onboarding step.
"""
pass

@abstractmethod
async def top_up_intent(self, user_id: str, amount: int) -> str:
"""
@@ -209,7 +240,7 @@ class UserCreditBase(ABC):
"userId": user_id,
"createdAt": {"lte": top_time},
"isActive": True,
"runningBalance": {"not": None}, # type: ignore
"NOT": [{"runningBalance": None}],
},
order={"createdAt": "desc"},
)
@@ -245,11 +276,7 @@ class UserCreditBase(ABC):
)
return transaction_balance, transaction_time

@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=1, max=10),
reraise=True,
)
@func_retry
async def _enable_transaction(
self,
transaction_key: str,
@@ -331,15 +358,15 @@ class UserCreditBase(ABC):
amount = min(-user_balance, 0)

# Create the transaction
transaction_data: CreditTransactionCreateInput = {
"userId": user_id,
"amount": amount,
"runningBalance": user_balance + amount,
"type": transaction_type,
"metadata": metadata,
"isActive": is_active,
"createdAt": self.time_now(),
}
transaction_data = CreditTransactionCreateInput(
userId=user_id,
amount=amount,
runningBalance=user_balance + amount,
type=transaction_type,
metadata=metadata,
isActive=is_active,
createdAt=self.time_now(),
)
if transaction_key:
transaction_data["transactionKey"] = transaction_key
tx = await CreditTransaction.prisma().create(data=transaction_data)
@@ -348,21 +375,19 @@ class UserCreditBase(ABC):

class UserCredit(UserCreditBase):
@thread_cached
def notification_client(self) -> NotificationManager:
return get_service_client(NotificationManager)
def notification_client(self) -> NotificationManagerClient:
return get_service_client(NotificationManagerClient)

async def _send_refund_notification(
self,
notification_request: RefundRequestData,
notification_type: NotificationType,
):
await asyncio.to_thread(
lambda: self.notification_client().queue_notification(
NotificationEventDTO(
user_id=notification_request.user_id,
type=notification_type,
data=notification_request.model_dump(),
)
await self.notification_client().queue_notification_async(
NotificationEventDTO(
user_id=notification_request.user_id,
type=notification_type,
data=notification_request.model_dump(),
)
)

@@ -392,6 +417,7 @@ class UserCredit(UserCreditBase):
# Avoid multiple auto top-ups within the same graph execution.
key=f"AUTO-TOP-UP-{user_id}-{metadata.graph_exec_id}",
ceiling_balance=auto_top_up.threshold,
top_up_type=TopUpType.AUTO,
)
except Exception as e:
# Failed top-up is not critical, we can move on.
@@ -401,8 +427,30 @@ class UserCredit(UserCreditBase):

return balance

async def top_up_credits(self, user_id: str, amount: int):
await self._top_up_credits(user_id, amount)
async def top_up_credits(
self,
user_id: str,
amount: int,
top_up_type: TopUpType = TopUpType.UNCATEGORIZED,
):
await self._top_up_credits(
user_id=user_id, amount=amount, top_up_type=top_up_type
)

async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
try:
await self._add_transaction(
user_id=user_id,
amount=credits,
transaction_type=CreditTransactionType.GRANT,
transaction_key=f"REWARD-{user_id}-{step.value}",
metadata=Json(
{"reason": f"Reward for completing {step.value} onboarding step."}
),
)
except UniqueViolationError:
# Already rewarded for this step
pass

async def top_up_refund(
self, user_id: str, transaction_key: str, metadata: dict[str, str]
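
Note: the onboarding_reward implementation above is idempotent by construction: the transaction key is derived from the user and step, so a repeated grant trips the unique constraint and is swallowed as UniqueViolationError. The same pattern in isolation (insert_row is a hypothetical stand-in for Prisma's create() raising on duplicate keys):

    def grant_once(insert_row, user_id: str, step: str, credits: int) -> bool:
        """Insert a reward keyed by (user, step); return False if already granted."""
        key = f"REWARD-{user_id}-{step}"
        try:
            insert_row(key, credits)  # raises KeyError on a duplicate key
            return True
        except KeyError:
            return False  # already rewarded for this step; nothing to do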
@@ -422,15 +470,15 @@ class UserCredit(UserCreditBase):

try:
refund_request = await CreditRefundRequest.prisma().create(
data={
"id": refund_key,
"transactionKey": transaction_key,
"userId": user_id,
"amount": amount,
"reason": metadata.get("reason", ""),
"status": CreditRefundRequestStatus.PENDING,
"result": "The refund request is under review.",
}
data=CreditRefundRequestCreateInput(
id=refund_key,
transactionKey=transaction_key,
userId=user_id,
amount=amount,
reason=metadata.get("reason", ""),
status=CreditRefundRequestStatus.PENDING,
result="The refund request is under review.",
)
)
except UniqueViolationError:
raise ValueError(
@@ -567,7 +615,7 @@ class UserCredit(UserCreditBase):

evidence_text += (
f"- {tx.description}: Amount ${tx.amount / 100:.2f} on {tx.transaction_time.isoformat()}, "
f"resulting balance ${tx.balance / 100:.2f} {additional_comment}\n"
f"resulting balance ${tx.running_balance / 100:.2f} {additional_comment}\n"
)
evidence_text += (
"\nThis evidence demonstrates that the transaction was authorized and that the charged amount was used to render the service as agreed."
@@ -586,7 +634,24 @@ class UserCredit(UserCreditBase):
amount: int,
key: str | None = None,
ceiling_balance: int | None = None,
top_up_type: TopUpType = TopUpType.UNCATEGORIZED,
metadata: dict | None = None,
):
# init metadata, without sharing it with the world
metadata = metadata or {}
if not metadata.get("reason"):
match top_up_type:
case TopUpType.MANUAL:
metadata["reason"] = f"Top up credits for {user_id}"
case TopUpType.AUTO:
metadata["reason"] = f"Auto top up credits for {user_id}"
case _:
metadata["reason"] = f"Top up reason unknown for {user_id}"

if amount < 0:
raise ValueError(f"Top up amount must not be negative: {amount}")
@@ -609,6 +674,7 @@ class UserCredit(UserCreditBase):
is_active=False,
transaction_key=key,
ceiling_balance=ceiling_balance,
metadata=(Json(metadata)),
)

customer_id = await get_stripe_customer_id(user_id)
@@ -751,10 +817,15 @@ class UserCredit(UserCreditBase):
# Check the Checkout Session's payment_status property
# to determine if fulfillment should be performed
if checkout_session.payment_status in ["paid", "no_payment_required"]:
assert isinstance(checkout_session.payment_intent, stripe.PaymentIntent)
if payment_intent := checkout_session.payment_intent:
assert isinstance(payment_intent, stripe.PaymentIntent)
new_transaction_key = payment_intent.id
else:
new_transaction_key = None

await self._enable_transaction(
transaction_key=credit_transaction.transactionKey,
new_transaction_key=checkout_session.payment_intent.id,
new_transaction_key=new_transaction_key,
user_id=credit_transaction.userId,
metadata=Json(checkout_session),
)
@@ -787,8 +858,9 @@ class UserCredit(UserCreditBase):
take=transaction_count_limit,
)

# doesn't fill current_balance, reason, user_email, admin_email, or extra_data
grouped_transactions: dict[str, UserTransaction] = defaultdict(
lambda: UserTransaction()
lambda: UserTransaction(user_id=user_id)
)
tx_time = None
for t in transactions:
@@ -818,7 +890,7 @@ class UserCredit(UserCreditBase):

if tx_time > gt.transaction_time:
gt.transaction_time = tx_time
gt.balance = t.runningBalance or 0
gt.running_balance = t.runningBalance or 0

return TransactionHistory(
transactions=list(grouped_transactions.values()),
@@ -868,6 +940,7 @@ class BetaUserCredit(UserCredit):
amount=max(self.num_user_credits_refill - balance, 0),
transaction_type=CreditTransactionType.GRANT,
transaction_key=f"MONTHLY-CREDIT-TOP-UP-{cur_time}",
metadata=Json({"reason": "Monthly credit refill"}),
)
return balance
except UniqueViolationError:
@@ -877,7 +950,7 @@ class BetaUserCredit(UserCredit):

class DisabledUserCredit(UserCreditBase):
async def get_credits(self, *args, **kwargs) -> int:
return 0
return 100

async def get_transaction_history(self, *args, **kwargs) -> TransactionHistory:
return TransactionHistory(transactions=[], next_transaction_time=None)
@@ -891,6 +964,9 @@ class DisabledUserCredit(UserCreditBase):
async def top_up_credits(self, *args, **kwargs):
pass

async def onboarding_reward(self, *args, **kwargs):
pass

async def top_up_intent(self, *args, **kwargs) -> str:
return ""

@@ -952,3 +1028,81 @@ async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
return AutoTopUpConfig(threshold=0, amount=0)

return AutoTopUpConfig.model_validate(user.topUpConfig)


async def admin_get_user_history(
page: int = 1,
page_size: int = 20,
search: str | None = None,
transaction_filter: CreditTransactionType | None = None,
) -> UserHistoryResponse:

if page < 1 or page_size < 1:
raise ValueError("Invalid pagination input")

where_clause: CreditTransactionWhereInput = {}
if transaction_filter:
where_clause["type"] = transaction_filter
if search:
where_clause["OR"] = [
{"userId": {"contains": search, "mode": "insensitive"}},
{"User": {"is": {"email": {"contains": search, "mode": "insensitive"}}}},
{"User": {"is": {"name": {"contains": search, "mode": "insensitive"}}}},
]
transactions = await CreditTransaction.prisma().find_many(
where=where_clause,
skip=(page - 1) * page_size,
take=page_size,
include={"User": True},
order={"createdAt": "desc"},
)
total = await CreditTransaction.prisma().count(where=where_clause)
total_pages = (total + page_size - 1) // page_size

history = []
for tx in transactions:
admin_id = ""
admin_email = ""
reason = ""

metadata: dict = cast(dict, tx.metadata) or {}

if metadata:
admin_id = metadata.get("admin_id")
admin_email = (
(await get_user_email_by_id(admin_id) or f"Unknown Admin: {admin_id}")
if admin_id
else ""
)
reason = metadata.get("reason", "No reason provided")

balance, last_update = await get_user_credit_model()._get_credits(tx.userId)

history.append(
UserTransaction(
transaction_key=tx.transactionKey,
transaction_time=tx.createdAt,
transaction_type=tx.type,
amount=tx.amount,
current_balance=balance,
running_balance=tx.runningBalance or 0,
user_id=tx.userId,
user_email=(
tx.User.email
if tx.User
else (await get_user_by_id(tx.userId)).email
),
reason=reason,
admin_email=admin_email,
extra_data=str(metadata),
)
)
return UserHistoryResponse(
history=history,
pagination=Pagination(
total_items=total,
total_pages=total_pages,
current_page=page,
page_size=page_size,
),
)
@@ -62,10 +62,10 @@ async def connect():

# A connection acquired from a pool like Supabase can still let the db client
# connect yet have its queries rejected afterward.
try:
await prisma.execute_raw("SELECT 1")
except Exception as e:
raise ConnectionError("Failed to connect to Prisma.") from e
# try:
# await prisma.execute_raw("SELECT 1")
# except Exception as e:
# raise ConnectionError("Failed to connect to Prisma.") from e


@conn_retry("Prisma", "Releasing connection")
@@ -89,7 +89,7 @@ async def transaction():
async def locked_transaction(key: str):
lock_key = zlib.crc32(key.encode("utf-8"))
async with transaction() as tx:
await tx.execute_raw(f"SELECT pg_advisory_xact_lock({lock_key})")
await tx.execute_raw("SELECT pg_advisory_xact_lock($1)", lock_key)
yield tx
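
Note: the locked_transaction change replaces an f-string SQL literal with a bound parameter. lock_key is a crc32-derived int, so the old interpolation was not injectable, but parameter binding is the safer default if the key derivation ever changes, and it keeps the SQL text constant. A sketch:

    import zlib

    def advisory_lock_query(key: str) -> tuple[str, int]:
        """Build a parameterized Postgres advisory-lock call for a string key."""
        lock_key = zlib.crc32(key.encode("utf-8"))  # stable 32-bit lock id
        # The value travels as a bind parameter, separate from the SQL text.
        return "SELECT pg_advisory_xact_lock($1)", lock_key

    sql, arg = advisory_lock_query("user:123")
    # e.g. await tx.execute_raw(sql, arg) inside the transaction block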
@@ -23,7 +23,10 @@ from prisma.models import (
AgentNodeExecutionInputOutput,
)
from prisma.types import (
AgentGraphExecutionCreateInput,
AgentGraphExecutionWhereInput,
AgentNodeExecutionCreateInput,
AgentNodeExecutionInputOutputCreateInput,
AgentNodeExecutionUpdateInput,
AgentNodeExecutionWhereInput,
)
@@ -31,18 +34,17 @@ from pydantic import BaseModel
from pydantic.fields import Field

from backend.server.v2.store.exceptions import DatabaseError
from backend.util import mock
from backend.util import type as type_utils
from backend.util.settings import Config

from .block import BlockData, BlockInput, BlockType, CompletedBlockOutput, get_block
from .block import BlockInput, BlockType, CompletedBlockOutput, get_block
from .db import BaseDbModel
from .includes import (
EXECUTION_RESULT_INCLUDE,
GRAPH_EXECUTION_INCLUDE,
GRAPH_EXECUTION_INCLUDE_WITH_NODES,
)
from .model import GraphExecutionStats, NodeExecutionStats
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
from .queue import AsyncRedisEventBus, RedisEventBus

T = TypeVar("T")
@@ -59,23 +61,27 @@ ExecutionStatus = AgentExecutionStatus

class GraphExecutionMeta(BaseDbModel):
user_id: str
started_at: datetime
ended_at: datetime
cost: Optional[int] = Field(..., description="Execution cost in credits")
duration: float = Field(..., description="Seconds from start to end of run")
total_run_time: float = Field(..., description="Seconds of node runtime")
status: ExecutionStatus
graph_id: str
graph_version: int
preset_id: Optional[str] = None
status: ExecutionStatus
started_at: datetime
ended_at: datetime

class Stats(BaseModel):
cost: int = Field(..., description="Execution cost (cents)")
duration: float = Field(..., description="Seconds from start to end of run")
node_exec_time: float = Field(..., description="Seconds of total node runtime")
node_exec_count: int = Field(..., description="Number of node executions")

stats: Stats | None

@staticmethod
def from_db(_graph_exec: AgentGraphExecution):
now = datetime.now(timezone.utc)
# TODO: make started_at and ended_at optional
start_time = _graph_exec.startedAt or _graph_exec.createdAt
end_time = _graph_exec.updatedAt or now
duration = (end_time - start_time).total_seconds()
total_run_time = duration

try:
stats = GraphExecutionStats.model_validate(_graph_exec.stats)
@@ -87,21 +93,25 @@ class GraphExecutionMeta(BaseDbModel):
)
stats = None

duration = stats.walltime if stats else duration
total_run_time = stats.nodes_walltime if stats else total_run_time

return GraphExecutionMeta(
id=_graph_exec.id,
user_id=_graph_exec.userId,
started_at=start_time,
ended_at=end_time,
cost=stats.cost if stats else None,
duration=duration,
total_run_time=total_run_time,
status=ExecutionStatus(_graph_exec.executionStatus),
graph_id=_graph_exec.agentGraphId,
graph_version=_graph_exec.agentGraphVersion,
preset_id=_graph_exec.agentPresetId,
status=ExecutionStatus(_graph_exec.executionStatus),
started_at=start_time,
ended_at=end_time,
stats=(
GraphExecutionMeta.Stats(
cost=stats.cost,
duration=stats.walltime,
node_exec_time=stats.nodes_walltime,
node_exec_count=stats.node_count,
)
if stats
else None
),
)


@@ -111,15 +121,16 @@ class GraphExecution(GraphExecutionMeta):

@staticmethod
def from_db(_graph_exec: AgentGraphExecution):
if _graph_exec.AgentNodeExecutions is None:
if _graph_exec.NodeExecutions is None:
raise ValueError("Node executions must be included in query")

graph_exec = GraphExecutionMeta.from_db(_graph_exec)

node_executions = sorted(
complete_node_executions = sorted(
[
NodeExecutionResult.from_db(ne, _graph_exec.userId)
for ne in _graph_exec.AgentNodeExecutions
for ne in _graph_exec.NodeExecutions
if ne.executionStatus != ExecutionStatus.INCOMPLETE
],
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
)
@@ -128,7 +139,7 @@ class GraphExecution(GraphExecutionMeta):
**{
# inputs from Agent Input Blocks
exec.input_data["name"]: exec.input_data.get("value")
for exec in node_executions
for exec in complete_node_executions
if (
(block := get_block(exec.block_id))
and block.block_type == BlockType.INPUT
@@ -137,7 +148,7 @@ class GraphExecution(GraphExecutionMeta):
**{
# input from webhook-triggered block
"payload": exec.input_data["payload"]
for exec in node_executions
for exec in complete_node_executions
if (
(block := get_block(exec.block_id))
and block.block_type
@@ -147,7 +158,7 @@ class GraphExecution(GraphExecutionMeta):
}

outputs: CompletedBlockOutput = defaultdict(list)
for exec in node_executions:
for exec in complete_node_executions:
if (
block := get_block(exec.block_id)
) and block.block_type == BlockType.OUTPUT:
@@ -158,7 +169,7 @@ class GraphExecution(GraphExecutionMeta):
return GraphExecution(
**{
field_name: getattr(graph_exec, field_name)
for field_name in graph_exec.model_fields
for field_name in GraphExecutionMeta.model_fields
},
inputs=inputs,
outputs=outputs,
@@ -170,7 +181,7 @@ class GraphExecutionWithNodes(GraphExecution):

@staticmethod
def from_db(_graph_exec: AgentGraphExecution):
if _graph_exec.AgentNodeExecutions is None:
if _graph_exec.NodeExecutions is None:
raise ValueError("Node executions must be included in query")

graph_exec_with_io = GraphExecution.from_db(_graph_exec)
@@ -178,7 +189,7 @@ class GraphExecutionWithNodes(GraphExecution):
node_executions = sorted(
[
NodeExecutionResult.from_db(ne, _graph_exec.userId)
for ne in _graph_exec.AgentNodeExecutions
for ne in _graph_exec.NodeExecutions
],
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
)
@@ -186,11 +197,20 @@ class GraphExecutionWithNodes(GraphExecution):
return GraphExecutionWithNodes(
**{
field_name: getattr(graph_exec_with_io, field_name)
for field_name in graph_exec_with_io.model_fields
for field_name in GraphExecution.model_fields
},
node_executions=node_executions,
)

def to_graph_execution_entry(self):
return GraphExecutionEntry(
user_id=self.user_id,
graph_id=self.graph_id,
graph_version=self.graph_version or 0,
graph_exec_id=self.id,
node_credentials_input_map={}, # FIXME
)


class NodeExecutionResult(BaseModel):
user_id: str
@@ -209,21 +229,21 @@ class NodeExecutionResult(BaseModel):
end_time: datetime | None

@staticmethod
def from_db(execution: AgentNodeExecution, user_id: Optional[str] = None):
if execution.executionData:
def from_db(_node_exec: AgentNodeExecution, user_id: Optional[str] = None):
if _node_exec.executionData:
# Execution that has been queued for execution will persist its data.
input_data = type_utils.convert(execution.executionData, dict[str, Any])
input_data = type_utils.convert(_node_exec.executionData, dict[str, Any])
else:
# For incomplete execution, executionData will not be yet available.
input_data: BlockInput = defaultdict()
for data in execution.Input or []:
for data in _node_exec.Input or []:
input_data[data.name] = type_utils.convert(data.data, type[Any])

output_data: CompletedBlockOutput = defaultdict(list)
for data in execution.Output or []:
for data in _node_exec.Output or []:
output_data[data.name].append(type_utils.convert(data.data, type[Any]))

graph_execution: AgentGraphExecution | None = execution.AgentGraphExecution
graph_execution: AgentGraphExecution | None = _node_exec.GraphExecution
if graph_execution:
user_id = graph_execution.userId
elif not user_id:
@@ -235,17 +255,28 @@ class NodeExecutionResult(BaseModel):
user_id=user_id,
graph_id=graph_execution.agentGraphId if graph_execution else "",
graph_version=graph_execution.agentGraphVersion if graph_execution else 0,
graph_exec_id=execution.agentGraphExecutionId,
block_id=execution.AgentNode.agentBlockId if execution.AgentNode else "",
node_exec_id=execution.id,
node_id=execution.agentNodeId,
status=execution.executionStatus,
graph_exec_id=_node_exec.agentGraphExecutionId,
block_id=_node_exec.Node.agentBlockId if _node_exec.Node else "",
node_exec_id=_node_exec.id,
node_id=_node_exec.agentNodeId,
status=_node_exec.executionStatus,
input_data=input_data,
output_data=output_data,
add_time=execution.addedTime,
queue_time=execution.queuedTime,
start_time=execution.startedTime,
end_time=execution.endedTime,
add_time=_node_exec.addedTime,
queue_time=_node_exec.queuedTime,
start_time=_node_exec.startedTime,
end_time=_node_exec.endedTime,
)

def to_node_execution_entry(self) -> "NodeExecutionEntry":
return NodeExecutionEntry(
user_id=self.user_id,
graph_exec_id=self.graph_exec_id,
graph_id=self.graph_id,
node_exec_id=self.node_exec_id,
node_id=self.node_id,
block_id=self.block_id,
data=self.input_data,
)
@@ -330,7 +361,7 @@ async def get_graph_execution(
async def create_graph_execution(
graph_id: str,
graph_version: int,
nodes_input: list[tuple[str, BlockInput]],
starting_nodes_input: list[tuple[str, BlockInput]],
user_id: str,
preset_id: str | None = None,
) -> GraphExecutionWithNodes:
@@ -340,29 +371,29 @@ async def create_graph_execution(
The id of the AgentGraphExecution and the list of ExecutionResult for each node.
"""
result = await AgentGraphExecution.prisma().create(
data={
"agentGraphId": graph_id,
"agentGraphVersion": graph_version,
"executionStatus": ExecutionStatus.QUEUED,
"AgentNodeExecutions": {
"create": [ # type: ignore
{
"agentNodeId": node_id,
"executionStatus": ExecutionStatus.QUEUED,
"queuedTime": datetime.now(tz=timezone.utc),
"Input": {
data=AgentGraphExecutionCreateInput(
agentGraphId=graph_id,
agentGraphVersion=graph_version,
executionStatus=ExecutionStatus.QUEUED,
NodeExecutions={
"create": [
AgentNodeExecutionCreateInput(
agentNodeId=node_id,
executionStatus=ExecutionStatus.QUEUED,
queuedTime=datetime.now(tz=timezone.utc),
Input={
"create": [
{"name": name, "data": Json(data)}
for name, data in node_input.items()
]
},
}
for node_id, node_input in nodes_input
)
for node_id, node_input in starting_nodes_input
]
},
"userId": user_id,
"agentPresetId": preset_id,
},
userId=user_id,
agentPresetId=preset_id,
),
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
)

@@ -409,11 +440,11 @@ async def upsert_execution_input(

if existing_execution:
await AgentNodeExecutionInputOutput.prisma().create(
data={
"name": input_name,
"data": json_input_data,
"referencedByInputExecId": existing_execution.id,
}
data=AgentNodeExecutionInputOutputCreateInput(
name=input_name,
data=json_input_data,
referencedByInputExecId=existing_execution.id,
)
)
return existing_execution.id, {
**{
@@ -425,12 +456,12 @@ async def upsert_execution_input(

elif not node_exec_id:
result = await AgentNodeExecution.prisma().create(
data={
"agentNodeId": node_id,
"agentGraphExecutionId": graph_exec_id,
"executionStatus": ExecutionStatus.INCOMPLETE,
"Input": {"create": {"name": input_name, "data": json_input_data}},
}
data=AgentNodeExecutionCreateInput(
agentNodeId=node_id,
agentGraphExecutionId=graph_exec_id,
executionStatus=ExecutionStatus.INCOMPLETE,
Input={"create": {"name": input_name, "data": json_input_data}},
)
)
return result.id, {input_name: input_data}
@@ -449,15 +480,17 @@ async def upsert_execution_output(
Insert an AgentNodeExecutionInputOutput record as one of AgentNodeExecution.Output.
"""
await AgentNodeExecutionInputOutput.prisma().create(
data={
"name": output_name,
"data": Json(output_data),
"referencedByOutputExecId": node_exec_id,
}
data=AgentNodeExecutionInputOutputCreateInput(
name=output_name,
data=Json(output_data),
referencedByOutputExecId=node_exec_id,
)
)


async def update_graph_execution_start_time(graph_exec_id: str) -> GraphExecution:
async def update_graph_execution_start_time(
graph_exec_id: str,
) -> GraphExecution | None:
res = await AgentGraphExecution.prisma().update(
where={"id": graph_exec_id},
data={
@@ -466,10 +499,7 @@ async def update_graph_execution_start_time(graph_exec_id: str) -> GraphExecutio
},
include=GRAPH_EXECUTION_INCLUDE,
)
if not res:
raise ValueError(f"Graph execution #{graph_exec_id} not found")

return GraphExecution.from_db(res)
return GraphExecution.from_db(res) if res else None


async def update_graph_execution_stats(
@@ -480,7 +510,8 @@ async def update_graph_execution_stats(
data = stats.model_dump() if stats else {}
if isinstance(data.get("error"), Exception):
data["error"] = str(data["error"])
res = await AgentGraphExecution.prisma().update(

updated_count = await AgentGraphExecution.prisma().update_many(
where={
"id": graph_exec_id,
"OR": [
@@ -492,10 +523,15 @@ async def update_graph_execution_stats(
"executionStatus": status,
"stats": Json(data),
},
)
if updated_count == 0:
return None

graph_exec = await AgentGraphExecution.prisma().find_unique_or_raise(
where={"id": graph_exec_id},
include=GRAPH_EXECUTION_INCLUDE,
)

return GraphExecution.from_db(res) if res else None
return GraphExecution.from_db(graph_exec)


async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionStats):
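
Note: update_graph_execution_stats above now performs a guarded update_many instead of an unconditional update: the write only lands while the row still matches the OR condition, and a zero match count is returned as None rather than clobbering a state another worker already finalized. A compare-and-set sketch of the same idea over a plain dict store (hypothetical helper, not the Prisma API):

    def finalize(store: dict, exec_id: str, status: str, stats: dict):
        """Apply final stats only if the run is still in a non-terminal state."""
        row = store.get(exec_id)
        if row is None or row["status"] not in ("QUEUED", "RUNNING"):
            return None  # lost the race; another writer finalized it first
        row.update(status=status, stats=stats)  # the guarded write
        return row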
@@ -579,8 +615,9 @@ async def delete_graph_execution(
|
||||
)
|
||||
|
||||
|
||||
async def get_node_execution_results(
|
||||
async def get_node_executions(
|
||||
graph_exec_id: str,
|
||||
node_id: str | None = None,
|
||||
block_ids: list[str] | None = None,
|
||||
statuses: list[ExecutionStatus] | None = None,
|
||||
limit: int | None = None,
|
||||
@@ -588,8 +625,10 @@ async def get_node_execution_results(
|
||||
where_clause: AgentNodeExecutionWhereInput = {
|
||||
"agentGraphExecutionId": graph_exec_id,
|
||||
}
|
||||
if node_id:
|
||||
where_clause["agentNodeId"] = node_id
|
||||
if block_ids:
|
||||
where_clause["AgentNode"] = {"is": {"agentBlockId": {"in": block_ids}}}
|
||||
where_clause["Node"] = {"is": {"agentBlockId": {"in": block_ids}}}
|
||||
if statuses:
|
||||
where_clause["OR"] = [{"executionStatus": status} for status in statuses]
|
||||
|
||||
@@ -631,7 +670,7 @@ async def get_latest_node_execution(
|
||||
where={
|
||||
"agentNodeId": node_id,
|
||||
"agentGraphExecutionId": graph_eid,
|
||||
"executionStatus": {"not": ExecutionStatus.INCOMPLETE}, # type: ignore
|
||||
"NOT": [{"executionStatus": ExecutionStatus.INCOMPLETE}],
|
||||
},
|
||||
order=[
|
||||
{"queuedTime": "desc"},
|
||||
@@ -644,20 +683,6 @@ async def get_latest_node_execution(
|
||||
return NodeExecutionResult.from_db(execution)
|
||||
|
||||
|
||||
async def get_incomplete_node_executions(
|
||||
node_id: str, graph_eid: str
|
||||
) -> list[NodeExecutionResult]:
|
||||
executions = await AgentNodeExecution.prisma().find_many(
|
||||
where={
|
||||
"agentNodeId": node_id,
|
||||
"agentGraphExecutionId": graph_eid,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
},
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
)
|
||||
return [NodeExecutionResult.from_db(execution) for execution in executions]
|
||||
|
||||
|
||||
# ----------------- Execution Infrastructure ----------------- #
|
||||
|
||||
|
||||
@@ -666,7 +691,7 @@ class GraphExecutionEntry(BaseModel):
|
||||
graph_exec_id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
start_node_execs: list["NodeExecutionEntry"]
|
||||
node_credentials_input_map: Optional[dict[str, dict[str, CredentialsMetaInput]]]
|
||||
|
||||
|
||||
class NodeExecutionEntry(BaseModel):
|
||||
@@ -699,144 +724,6 @@ class ExecutionQueue(Generic[T]):
|
||||
return self.queue.empty()
|
||||
|
||||
|
||||
# ------------------- Execution Utilities -------------------- #
|
||||
|
||||
|
||||
LIST_SPLIT = "_$_"
|
||||
DICT_SPLIT = "_#_"
|
||||
OBJC_SPLIT = "_@_"
|
||||
|
||||
|
||||
def parse_execution_output(output: BlockData, name: str) -> Any | None:
|
||||
"""
|
||||
Extracts partial output data by name from a given BlockData.
|
||||
|
||||
The function supports extracting data from lists, dictionaries, and objects
|
||||
using specific naming conventions:
|
||||
- For lists: <output_name>_$_<index>
|
||||
- For dictionaries: <output_name>_#_<key>
|
||||
- For objects: <output_name>_@_<attribute>
|
||||
|
||||
Args:
|
||||
output (BlockData): A tuple containing the output name and data.
|
||||
name (str): The name used to extract specific data from the output.
|
||||
|
||||
Returns:
|
||||
Any | None: The extracted data if found, otherwise None.
|
||||
|
||||
Examples:
|
||||
>>> output = ("result", [10, 20, 30])
|
||||
>>> parse_execution_output(output, "result_$_1")
|
||||
20
|
||||
|
||||
>>> output = ("config", {"key1": "value1", "key2": "value2"})
|
||||
>>> parse_execution_output(output, "config_#_key1")
|
||||
'value1'
|
||||
|
||||
>>> class Sample:
|
||||
... attr1 = "value1"
|
||||
... attr2 = "value2"
|
||||
>>> output = ("object", Sample())
|
||||
>>> parse_execution_output(output, "object_@_attr1")
|
||||
'value1'
|
||||
"""
|
||||
output_name, output_data = output
|
||||
|
||||
if name == output_name:
|
||||
return output_data
|
||||
|
||||
if name.startswith(f"{output_name}{LIST_SPLIT}"):
|
||||
index = int(name.split(LIST_SPLIT)[1])
|
||||
if not isinstance(output_data, list) or len(output_data) <= index:
|
||||
return None
|
||||
return output_data[int(name.split(LIST_SPLIT)[1])]
|
||||
|
||||
if name.startswith(f"{output_name}{DICT_SPLIT}"):
|
||||
index = name.split(DICT_SPLIT)[1]
|
||||
if not isinstance(output_data, dict) or index not in output_data:
|
||||
return None
|
||||
return output_data[index]
|
||||
|
||||
if name.startswith(f"{output_name}{OBJC_SPLIT}"):
|
||||
index = name.split(OBJC_SPLIT)[1]
|
||||
if isinstance(output_data, object) and hasattr(output_data, index):
|
||||
return getattr(output_data, index)
|
||||
return None
|
||||
|
||||
return None
|
||||


def merge_execution_input(data: BlockInput) -> BlockInput:
    """
    Merges dynamic input pins into a single list, dictionary, or object based on naming patterns.

    This function processes input keys that follow specific patterns to merge them into a unified structure:
    - `<input_name>_$_<index>` for list inputs.
    - `<input_name>_#_<key>` for dictionary inputs.
    - `<input_name>_@_<attribute>` for object inputs.

    Args:
        data (BlockInput): A dictionary containing input keys and their corresponding values.

    Returns:
        BlockInput: A dictionary with merged inputs.

    Raises:
        ValueError: If a list index is not an integer.

    Examples:
        >>> data = {
        ...     "list_$_0": "a",
        ...     "list_$_1": "b",
        ...     "dict_#_key1": "value1",
        ...     "dict_#_key2": "value2",
        ...     "object_@_attr1": "value1",
        ...     "object_@_attr2": "value2"
        ... }
        >>> merge_execution_input(data)
        {
            "list": ["a", "b"],
            "dict": {"key1": "value1", "key2": "value2"},
            "object": <MockObject attr1="value1" attr2="value2">
        }
    """

    # Merge all inputs with <input_name>_$_<index> into a single list.
    items = list(data.items())

    for key, value in items:
        if LIST_SPLIT not in key:
            continue
        name, index = key.split(LIST_SPLIT)
        if not index.isdigit():
            raise ValueError(f"Invalid key: {key}, #{index} index must be an integer.")

        data[name] = data.get(name, [])
        if int(index) >= len(data[name]):
            # Pad the list with empty strings on missing indices.
            data[name].extend([""] * (int(index) - len(data[name]) + 1))
        data[name][int(index)] = value

    # Merge all inputs with <input_name>_#_<key> into a single dict.
    for key, value in items:
        if DICT_SPLIT not in key:
            continue
        name, index = key.split(DICT_SPLIT)
        data[name] = data.get(name, {})
        data[name][index] = value

    # Merge all inputs with <input_name>_@_<attribute> into a single object.
    for key, value in items:
        if OBJC_SPLIT not in key:
            continue
        name, index = key.split(OBJC_SPLIT)
        if name not in data or not isinstance(data[name], object):
            data[name] = mock.MockObject()
        setattr(data[name], index, value)

    return data
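
A minimal sketch of the list-padding behavior, assuming merge_execution_input is in scope as defined above:

# Sparse list indices are padded so positions keep their meaning.
data = {"list_$_0": "a", "list_$_3": "d"}
merged = merge_execution_input(data)

# Missing indices 1 and 2 become empty strings: ["a", "", "", "d"].
assert merged["list"] == ["a", "", "", "d"]

# Non-numeric list indices are rejected early.
try:
    merge_execution_input({"list_$_x": "oops"})
except ValueError as e:
    print(e)  # Invalid key: list_$_x, #x index must be an integer.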


# --------------------- Event Bus --------------------- #


@@ -1,19 +1,31 @@
import logging
import uuid
from collections import defaultdict
from typing import Any, Literal, Optional, Type
from typing import Any, Literal, Optional, cast

import prisma
from prisma import Json
from prisma.enums import SubmissionStatus
from prisma.models import AgentGraph, AgentNode, AgentNodeLink, StoreListingVersion
from prisma.types import AgentGraphWhereInput
from prisma.types import (
    AgentGraphCreateInput,
    AgentGraphWhereInput,
    AgentNodeCreateInput,
    AgentNodeLinkCreateInput,
)
from pydantic import create_model
from pydantic.fields import computed_field

from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.blocks.llm import LlmModel
from backend.data.db import prisma as db
from backend.data.model import (
    CredentialsField,
    CredentialsFieldInfo,
    CredentialsMetaInput,
    is_credentials_field_name,
)
from backend.util import type as type_utils

from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
@@ -53,22 +65,23 @@ class Node(BaseDbModel):
    input_links: list[Link] = []
    output_links: list[Link] = []

    webhook_id: Optional[str] = None
    @property
    def block(self) -> Block[BlockSchema, BlockSchema]:
        block = get_block(self.block_id)
        if not block:
            raise ValueError(
                f"Block #{self.block_id} does not exist -> Node #{self.id} is invalid"
            )
        return block


class NodeModel(Node):
    graph_id: str
    graph_version: int

    webhook_id: Optional[str] = None
    webhook: Optional[Webhook] = None

    @property
    def block(self) -> Block[BlockSchema, BlockSchema]:
        block = get_block(self.block_id)
        if not block:
            raise ValueError(f"Block #{self.block_id} does not exist")
        return block

    @staticmethod
    def from_db(node: AgentNode, for_export: bool = False) -> "NodeModel":
        obj = NodeModel(
@@ -88,8 +101,7 @@ class NodeModel(Node):
        return obj

    def is_triggered_by_event_type(self, event_type: str) -> bool:
        if not (block := get_block(self.block_id)):
            raise ValueError(f"Block #{self.block_id} not found for node #{self.id}")
        block = self.block
        if not block.webhook_config:
            raise TypeError("This method can't be used on non-webhook blocks")
        if not block.webhook_config.event_filter_input:
@@ -160,17 +172,18 @@ class BaseGraph(BaseDbModel):
    description: str
    nodes: list[Node] = []
    links: list[Link] = []
    forked_from_id: str | None = None
    forked_from_version: int | None = None

    @computed_field
    @property
    def input_schema(self) -> dict[str, Any]:
        return self._generate_schema(
            *(
                (b.input_schema, node.input_default)
                (block.input_schema, node.input_default)
                for node in self.nodes
                if (b := get_block(node.block_id))
                and b.block_type == BlockType.INPUT
                and issubclass(b.input_schema, AgentInputBlock.Input)
                if (block := node.block).block_type == BlockType.INPUT
                and issubclass(block.input_schema, AgentInputBlock.Input)
            )
        )

@@ -179,22 +192,26 @@ class BaseGraph(BaseDbModel):
    def output_schema(self) -> dict[str, Any]:
        return self._generate_schema(
            *(
                (b.input_schema, node.input_default)
                (block.input_schema, node.input_default)
                for node in self.nodes
                if (b := get_block(node.block_id))
                and b.block_type == BlockType.OUTPUT
                and issubclass(b.input_schema, AgentOutputBlock.Input)
                if (block := node.block).block_type == BlockType.OUTPUT
                and issubclass(block.input_schema, AgentOutputBlock.Input)
            )
        )

    @computed_field
    @property
    def credentials_input_schema(self) -> dict[str, Any]:
        return self._credentials_input_schema.jsonschema()

    @staticmethod
    def _generate_schema(
        *props: tuple[Type[AgentInputBlock.Input] | Type[AgentOutputBlock.Input], dict],
        *props: tuple[type[AgentInputBlock.Input] | type[AgentOutputBlock.Input], dict],
    ) -> dict[str, Any]:
        schema = []
        schema_fields: list[AgentInputBlock.Input | AgentOutputBlock.Input] = []
        for type_class, input_default in props:
            try:
                schema.append(type_class(**input_default))
                schema_fields.append(type_class(**input_default))
            except Exception as e:
                logger.warning(f"Invalid {type_class}: {input_default}, {e}")

@@ -214,9 +231,93 @@ class BaseGraph(BaseDbModel):
                    **({"description": p.description} if p.description else {}),
                    **({"default": p.value} if p.value is not None else {}),
                }
                for p in schema
                for p in schema_fields
            },
            "required": [p.name for p in schema if p.value is None],
            "required": [p.name for p in schema_fields if p.value is None],
        }

    @property
    def _credentials_input_schema(self) -> type[BlockSchema]:
        graph_credentials_inputs = self.aggregate_credentials_inputs()
        logger.debug(
            f"Combined credentials input fields for graph #{self.id} ({self.name}): "
            f"{graph_credentials_inputs}"
        )

        # Warn if same-provider credentials inputs can't be combined (= bad UX)
        graph_cred_fields = list(graph_credentials_inputs.values())
        for i, (field, keys) in enumerate(graph_cred_fields):
            for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
                if field.provider != other_field.provider:
                    continue

                # If this happens, that means a block implementation probably needs
                # to be updated.
                logger.warning(
                    "Multiple combined credentials fields "
                    f"for provider {field.provider} "
                    f"on graph #{self.id} ({self.name}); "
                    f"fields: {field} <> {other_field}; "
                    f"keys: {keys} <> {other_keys}."
                )

        fields: dict[str, tuple[type[CredentialsMetaInput], CredentialsMetaInput]] = {
            agg_field_key: (
                CredentialsMetaInput[
                    Literal[tuple(field_info.provider)],  # type: ignore
                    Literal[tuple(field_info.supported_types)],  # type: ignore
                ],
                CredentialsField(
                    required_scopes=set(field_info.required_scopes or []),
                    discriminator=field_info.discriminator,
                    discriminator_mapping=field_info.discriminator_mapping,
                ),
            )
            for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
        }

        return create_model(
            self.name.replace(" ", "") + "CredentialsInputSchema",
            __base__=BlockSchema,
            **fields,  # type: ignore
        )

    def aggregate_credentials_inputs(
        self,
    ) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]]]]:
        """
        Returns:
            dict[aggregated_field_key, tuple(
                CredentialsFieldInfo: A spec for one aggregated credentials field
                set[(node_id, field_name)]: Node credentials fields that are
                    compatible with this aggregated field spec
            )]
        """
        return {
            "_".join(sorted(agg_field_info.provider))
            + "_"
            + "_".join(sorted(agg_field_info.supported_types))
            + "_credentials": (agg_field_info, node_fields)
            for agg_field_info, node_fields in CredentialsFieldInfo.combine(
                *(
                    (
                        # Apply discrimination before aggregating credentials inputs
                        (
                            field_info.discriminate(
                                node.input_default[field_info.discriminator]
                            )
                            if (
                                field_info.discriminator
                                and node.input_default.get(field_info.discriminator)
                            )
                            else field_info
                        ),
                        (node.id, field_name),
                    )
                    for node in self.nodes
                    for field_name, field_info in node.block.input_schema.get_credentials_fields_info().items()
                )
            )
        }

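For reference, the aggregated field key built above is just the sorted providers and sorted credential types joined with underscores, plus a `_credentials` suffix; a sketch with assumed values:

# Key format produced by aggregate_credentials_inputs (values assumed).
providers = ["github"]
types = ["api_key", "oauth2"]
agg_field_key = (
    "_".join(sorted(providers)) + "_" + "_".join(sorted(types)) + "_credentials"
)
assert agg_field_key == "github_api_key_oauth2_credentials"
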
@@ -228,13 +329,16 @@ class GraphModel(Graph):
    user_id: str
    nodes: list[NodeModel] = []  # type: ignore

    @computed_field
    @property
    def starting_nodes(self) -> list[Node]:
    def has_webhook_trigger(self) -> bool:
        return self.webhook_input_node is not None

    @property
    def starting_nodes(self) -> list[NodeModel]:
        outbound_nodes = {link.sink_id for link in self.links}
        input_nodes = {
            v.id
            for v in self.nodes
            if (b := get_block(v.block_id)) and b.block_type == BlockType.INPUT
            node.id for node in self.nodes if node.block.block_type == BlockType.INPUT
        }
        return [
            node
@@ -242,6 +346,18 @@ class GraphModel(Graph):
            if node.id not in outbound_nodes or node.id in input_nodes
        ]

    @property
    def webhook_input_node(self) -> NodeModel | None:
        return next(
            (
                node
                for node in self.nodes
                if node.block.block_type
                in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
            ),
            None,
        )

    def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
        """
        Reassigns all IDs in the graph to new UUIDs.
@@ -295,15 +411,16 @@ class GraphModel(Graph):

    @staticmethod
    def _validate_graph(graph: BaseGraph, for_run: bool = False):
        def is_tool_pin(name: str) -> bool:
            return name.startswith("tools_^_")

        def sanitize(name):
            sanitized_name = name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
            if sanitized_name.startswith("tools_^_"):
                return sanitized_name.split("_^_")[0]
            if is_tool_pin(sanitized_name):
                return "tools"
            return sanitized_name

        # Validate smart decision maker nodes
        smart_decision_maker_nodes = set()
        agent_nodes = set()
        nodes_block = {
            node.id: block
            for node in graph.nodes
@@ -314,13 +431,6 @@ class GraphModel(Graph):
            if (block := nodes_block.get(node.id)) is None:
                raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")

            # Smart decision maker nodes
            if block.block_type == BlockType.AI:
                smart_decision_maker_nodes.add(node.id)
            # Agent nodes
            elif block.block_type == BlockType.AGENT:
                agent_nodes.add(node.id)

        input_links = defaultdict(list)

        for link in graph.links:
@@ -335,16 +445,21 @@ class GraphModel(Graph):
                [sanitize(name) for name in node.input_default]
                + [sanitize(link.sink_name) for link in input_links.get(node.id, [])]
            )
            for name in block.input_schema.get_required_fields():
            input_schema = block.input_schema
            for name in (required_fields := input_schema.get_required_fields()):
                if (
                    name not in provided_inputs
                    # Webhook payload is passed in by ExecutionManager
                    and not (
                        name == "payload"
                        and block.block_type
                        in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
                    )
                    # Checking availability of credentials is done by ExecutionManager
                    and name not in input_schema.get_credentials_fields()
                    # Validate only I/O nodes, or validate everything when executing
                    and (
                        for_run  # Skip input completion validation, unless when executing.
                        for_run
                        or block.block_type
                        in [
                            BlockType.INPUT,
@@ -357,9 +472,18 @@ class GraphModel(Graph):
                        f"Node {block.name} #{node.id} required input missing: `{name}`"
                    )

            if (
                block.block_type == BlockType.INPUT
                and (input_key := node.input_default.get("name"))
                and is_credentials_field_name(input_key)
            ):
                raise ValueError(
                    f"Agent input node uses reserved name '{input_key}'; "
                    "'credentials' and `*_credentials` are reserved input names"
                )

            # Get input schema properties and check dependencies
            input_schema = block.input_schema.model_fields
            required_fields = block.input_schema.get_required_fields()
            input_fields = input_schema.model_fields

            def has_value(name):
                return (
@@ -367,14 +491,21 @@ class GraphModel(Graph):
                    and name in node.input_default
                    and node.input_default[name] is not None
                    and str(node.input_default[name]).strip() != ""
                ) or (name in input_schema and input_schema[name].default is not None)
                ) or (name in input_fields and input_fields[name].default is not None)

            # Validate dependencies between fields
            for field_name, field_info in input_schema.items():
            for field_name, field_info in input_fields.items():
                # Apply input dependency validation only on run & field with depends_on
                json_schema_extra = field_info.json_schema_extra or {}
                dependencies = json_schema_extra.get("depends_on", [])
                if not for_run or not dependencies:
                if not (
                    for_run
                    and isinstance(json_schema_extra, dict)
                    and (
                        dependencies := cast(
                            list[str], json_schema_extra.get("depends_on", [])
                        )
                    )
                ):
                    continue

                # Check if dependent field has value in input_default
@@ -391,9 +522,7 @@ class GraphModel(Graph):
        node_map = {v.id: v for v in graph.nodes}

        def is_static_output_block(nid: str) -> bool:
            bid = node_map[nid].block_id
            b = get_block(bid)
            return b.static_output if b else False
            return node_map[nid].block.static_output

        # Links: each link must be connected, and the connected pins' data types must be compatible.
        for link in graph.links:
@@ -429,7 +558,7 @@ class GraphModel(Graph):
                if block.block_type not in [BlockType.AGENT]
                else vals.get("input_schema", {}).get("properties", {}).keys()
            )
            if sanitized_name not in fields and not name.startswith("tools_^_"):
            if sanitized_name not in fields and not is_tool_pin(name):
                fields_msg = f"Allowed fields: {fields}"
                raise ValueError(f"{prefix}, `{name}` invalid, {fields_msg}")

@@ -446,16 +575,16 @@ class GraphModel(Graph):
            id=graph.id,
            user_id=graph.userId if not for_export else "",
            version=graph.version,
            forked_from_id=graph.forkedFromId,
            forked_from_version=graph.forkedFromVersion,
            is_active=graph.isActive,
            name=graph.name or "",
            description=graph.description or "",
            nodes=[
                NodeModel.from_db(node, for_export) for node in graph.AgentNodes or []
            ],
            nodes=[NodeModel.from_db(node, for_export) for node in graph.Nodes or []],
            links=list(
                {
                    Link.from_db(link)
                    for node in graph.AgentNodes or []
                    for node in graph.Nodes or []
                    for link in (node.Input or []) + (node.Output or [])
                }
            ),
@@ -586,8 +715,8 @@ async def get_graph(
        and not (
            await StoreListingVersion.prisma().find_first(
                where={
                    "agentId": graph_id,
                    "agentVersion": version or graph.version,
                    "agentGraphId": graph_id,
                    "agentGraphVersion": version or graph.version,
                    "isDeleted": False,
                    "submissionStatus": SubmissionStatus.APPROVED,
                }
@@ -607,6 +736,58 @@ async def get_graph(
    return GraphModel.from_db(graph, for_export)


async def get_graph_as_admin(
    graph_id: str,
    version: int | None = None,
    user_id: str | None = None,
    for_export: bool = False,
) -> GraphModel | None:
    """
    Intentionally parallels `get_graph`, but should only be used for admin tasks,
    because it can return any graph that has been submitted.
    Retrieves a graph from the DB.
    Defaults to the version with `is_active` if `version` is not passed.

    Returns `None` if the record is not found.
    """
    logger.warning(f"Getting {graph_id=} {version=} as ADMIN {user_id=} {for_export=}")
    where_clause: AgentGraphWhereInput = {
        "id": graph_id,
    }

    if version is not None:
        where_clause["version"] = version

    graph = await AgentGraph.prisma().find_first(
        where=where_clause,
        include=AGENT_GRAPH_INCLUDE,
        order={"version": "desc"},
    )

    # For access, the graph must be owned by the user or listed in the store
    if graph is None or (
        graph.userId != user_id
        and not (
            await StoreListingVersion.prisma().find_first(
                where={
                    "agentGraphId": graph_id,
                    "agentGraphVersion": version or graph.version,
                }
            )
        )
    ):
        return None

    if for_export:
        sub_graphs = await get_sub_graphs(graph)
        return GraphModel.from_db(
            graph=graph,
            sub_graphs=sub_graphs,
            for_export=for_export,
        )

    return GraphModel.from_db(graph, for_export)


async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
    """
    Iteratively fetches all sub-graphs of a given graph, and flattens them into a list.
@@ -621,12 +802,16 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
        sub_graph_ids = [
            (graph_id, graph_version)
            for graph in search_graphs
            for node in graph.AgentNodes or []
            for node in graph.Nodes or []
            if (
                node.AgentBlock
                and node.AgentBlock.id == agent_block_id
                and (graph_id := dict(node.constantInput).get("graph_id"))
                and (graph_version := dict(node.constantInput).get("graph_version"))
                and (graph_id := cast(str, dict(node.constantInput).get("graph_id")))
                and (
                    graph_version := cast(
                        int, dict(node.constantInput).get("graph_version")
                    )
                )
            )
        ]
        if not sub_graph_ids:
@@ -641,7 +826,7 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
                    "userId": graph.userId,  # Ensure the sub-graph is owned by the same user
                }
                for graph_id, graph_version in sub_graph_ids
            ]  # type: ignore
            ]
        },
        include=AGENT_GRAPH_INCLUDE,
    )
@@ -655,7 +840,7 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
async def get_connected_output_nodes(node_id: str) -> list[tuple[Link, Node]]:
    links = await AgentNodeLink.prisma().find_many(
        where={"agentNodeSourceId": node_id},
        include={"AgentNodeSink": {"include": AGENT_NODE_INCLUDE}},  # type: ignore
        include={"AgentNodeSink": {"include": AGENT_NODE_INCLUDE}},
    )
    return [
        (Link.from_db(link), NodeModel.from_db(link.AgentNodeSink))
@@ -721,34 +906,56 @@ async def create_graph(graph: Graph, user_id: str) -> GraphModel:
        raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")


async def fork_graph(graph_id: str, graph_version: int, user_id: str) -> GraphModel:
    """
    Forks a graph by copying it and all its nodes and links to a new graph.
    """
    async with transaction() as tx:
        graph = await get_graph(graph_id, graph_version, user_id, True)
        if not graph:
            raise ValueError(f"Graph {graph_id} v{graph_version} not found")

        # Set the forked-from ID and version to the graph itself, as it's about to be copied
        graph.forked_from_id = graph.id
        graph.forked_from_version = graph.version
        graph.name = f"{graph.name} (copy)"
        graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
        graph.validate_graph(for_run=False)

        await __create_graph(tx, graph, user_id)

    return graph


async def __create_graph(tx, graph: Graph, user_id: str):
    graphs = [graph] + graph.sub_graphs

    await AgentGraph.prisma(tx).create_many(
        data=[
            {
                "id": graph.id,
                "version": graph.version,
                "name": graph.name,
                "description": graph.description,
                "isActive": graph.is_active,
                "userId": user_id,
            }
            AgentGraphCreateInput(
                id=graph.id,
                version=graph.version,
                name=graph.name,
                description=graph.description,
                isActive=graph.is_active,
                userId=user_id,
                forkedFromId=graph.forked_from_id,
                forkedFromVersion=graph.forked_from_version,
            )
            for graph in graphs
        ]
    )

    await AgentNode.prisma(tx).create_many(
        data=[
            {
                "id": node.id,
                "agentGraphId": graph.id,
                "agentGraphVersion": graph.version,
                "agentBlockId": node.block_id,
                "constantInput": Json(node.input_default),
                "metadata": Json(node.metadata),
                "webhookId": node.webhook_id,
            }
            AgentNodeCreateInput(
                id=node.id,
                agentGraphId=graph.id,
                agentGraphVersion=graph.version,
                agentBlockId=node.block_id,
                constantInput=Json(node.input_default),
                metadata=Json(node.metadata),
            )
            for graph in graphs
            for node in graph.nodes
        ]
@@ -756,14 +963,14 @@ async def __create_graph(tx, graph: Graph, user_id: str):

    await AgentNodeLink.prisma(tx).create_many(
        data=[
            {
                "id": str(uuid.uuid4()),
                "sourceName": link.source_name,
                "sinkName": link.sink_name,
                "agentNodeSourceId": link.source_id,
                "agentNodeSinkId": link.sink_id,
                "isStatic": link.is_static,
            }
            AgentNodeLinkCreateInput(
                id=str(uuid.uuid4()),
                sourceName=link.source_name,
                sinkName=link.sink_name,
                agentNodeSourceId=link.source_id,
                agentNodeSinkId=link.sink_id,
                isStatic=link.is_static,
            )
            for graph in graphs
            for link in graph.links
        ]
@@ -814,12 +1021,12 @@ async def fix_llm_provider_credentials():
        SELECT graph."userId" user_id,
               node.id node_id,
               node."constantInput" node_preset_input
        FROM platform."AgentNode" node
        LEFT JOIN platform."AgentGraph" graph
        ON node."agentGraphId" = graph.id
        WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm'
        ORDER BY graph."userId";
        """
        )
        logger.info(f"Fixing LLM credential inputs on {len(broken_nodes)} nodes")
    except Exception as e:
@@ -897,17 +1104,24 @@ async def migrate_llm_models(migrate_to: LlmModel):
        if field.annotation == LlmModel:
            llm_model_fields[block.id] = field_name

    # Convert enum values to a list of strings for the SQL query
    enum_values = [v.value for v in LlmModel]
    escaped_enum_values = repr(tuple(enum_values))  # hack but works

    # Update each block
    for id, path in llm_model_fields.items():
        # Convert enum values to a list of strings for the SQL query
        enum_values = [v.value for v in LlmModel.__members__.values()]

        query = f"""
            UPDATE "AgentNode"
            SET "constantInput" = jsonb_set("constantInput", '{{{path}}}', '"{migrate_to.value}"', true)
            WHERE "agentBlockId" = '{id}'
            AND "constantInput" ? '{path}'
            AND "constantInput"->>'{path}' NOT IN ({','.join(f"'{value}'" for value in enum_values)})
            UPDATE platform."AgentNode"
            SET "constantInput" = jsonb_set("constantInput", $1, to_jsonb($2), true)
            WHERE "agentBlockId" = $3
            AND "constantInput" ? ($4)::text
            AND "constantInput"->>($4)::text NOT IN {escaped_enum_values}
        """

        await db.execute_raw(query)
        await db.execute_raw(
            query,  # type: ignore - is supposed to be LiteralString
            [path],
            migrate_to.value,
            id,
            path,
        )

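The rewritten migration above stops interpolating `path`, `id`, and the target value into the SQL text and binds them as numbered parameters instead. A schematic contrast, with the placeholder value assumed for illustration:

# Schematic only: interpolated vs. bound parameters.
block_id = "some-block-id"  # assumed value

# Interpolated: the value becomes part of the SQL text, so a value containing
# a quote character can break quoting or inject SQL.
unsafe = f'SELECT 1 FROM platform."AgentNode" WHERE "agentBlockId" = \'{block_id}\''

# Bound: the driver sends the value separately from the query text.
safe = 'SELECT 1 FROM platform."AgentNode" WHERE "agentBlockId" = $1'
# await db.execute_raw(safe, block_id)  # same calling convention as above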
@@ -1,4 +1,7 @@
import prisma
from typing import cast

import prisma.enums
import prisma.types

from backend.blocks.io import IO_BLOCK_IDs

@@ -10,25 +13,25 @@ AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
}

AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
    "AgentNodes": {"include": AGENT_NODE_INCLUDE}  # type: ignore
    "Nodes": {"include": AGENT_NODE_INCLUDE}
}

EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
    "Input": True,
    "Output": True,
    "AgentNode": True,
    "AgentGraphExecution": True,
    "Node": True,
    "GraphExecution": True,
}

MAX_NODE_EXECUTIONS_FETCH = 1000

GRAPH_EXECUTION_INCLUDE_WITH_NODES: prisma.types.AgentGraphExecutionInclude = {
    "AgentNodeExecutions": {
    "NodeExecutions": {
        "include": {
            "Input": True,
            "Output": True,
            "AgentNode": True,
            "AgentGraphExecution": True,
            "Node": True,
            "GraphExecution": True,
        },
        "order_by": [
            {"queuedTime": "desc"},
@@ -40,28 +43,30 @@ GRAPH_EXECUTION_INCLUDE_WITH_NODES: prisma.types.AgentGraphExecutionInclude = {
}

GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
    "AgentNodeExecutions": {
        **GRAPH_EXECUTION_INCLUDE_WITH_NODES["AgentNodeExecutions"],  # type: ignore
    "NodeExecutions": {
        **cast(
            prisma.types.FindManyAgentNodeExecutionArgsFromAgentGraphExecution,
            GRAPH_EXECUTION_INCLUDE_WITH_NODES["NodeExecutions"],
        ),
        "where": {
            "AgentNode": {
                "AgentBlock": {"id": {"in": IO_BLOCK_IDs}},  # type: ignore
            },
            "Node": {"is": {"AgentBlock": {"is": {"id": {"in": IO_BLOCK_IDs}}}}},
            "NOT": [{"executionStatus": prisma.enums.AgentExecutionStatus.INCOMPLETE}],
        },
    }
}


INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
    "AgentNodes": {"include": AGENT_NODE_INCLUDE}  # type: ignore
    "AgentNodes": {"include": AGENT_NODE_INCLUDE}
}


def library_agent_include(user_id: str) -> prisma.types.LibraryAgentInclude:
    return {
        "Agent": {
        "AgentGraph": {
            "include": {
                **AGENT_GRAPH_INCLUDE,
                "AgentGraphExecution": {"where": {"userId": user_id}},
                "Executions": {"where": {"userId": user_id}},
            }
        },
        "Creator": True,

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, AsyncGenerator, Optional

from prisma import Json
from prisma.models import IntegrationWebhook
from prisma.types import IntegrationWebhookCreateInput
from pydantic import Field, computed_field

from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
@@ -66,18 +67,18 @@ class Webhook(BaseDbModel):

async def create_webhook(webhook: Webhook) -> Webhook:
    created_webhook = await IntegrationWebhook.prisma().create(
        data={
            "id": webhook.id,
            "userId": webhook.user_id,
            "provider": webhook.provider.value,
            "credentialsId": webhook.credentials_id,
            "webhookType": webhook.webhook_type,
            "resource": webhook.resource,
            "events": webhook.events,
            "config": Json(webhook.config),
            "secret": webhook.secret,
            "providerWebhookId": webhook.provider_webhook_id,
        }
        data=IntegrationWebhookCreateInput(
            id=webhook.id,
            userId=webhook.user_id,
            provider=webhook.provider.value,
            credentialsId=webhook.credentials_id,
            webhookType=webhook.webhook_type,
            resource=webhook.resource,
            events=webhook.events,
            config=Json(webhook.config),
            secret=webhook.secret,
            providerWebhookId=webhook.provider_webhook_id,
        )
    )
    return Webhook.from_db(created_webhook)


@@ -1,7 +1,9 @@
from __future__ import annotations

import base64
import enum
import logging
from collections import defaultdict
from datetime import datetime, timezone
from typing import (
    TYPE_CHECKING,
@@ -12,6 +14,7 @@ from typing import (
    Generic,
    Literal,
    Optional,
    Sequence,
    TypedDict,
    TypeVar,
    get_args,
@@ -142,8 +145,12 @@ def SchemaField(
    exclude: bool = False,
    hidden: Optional[bool] = None,
    depends_on: Optional[list[str]] = None,
    ge: Optional[float] = None,
    le: Optional[float] = None,
    min_length: Optional[int] = None,
    max_length: Optional[int] = None,
    discriminator: Optional[str] = None,
    json_schema_extra: Optional[dict[str, Any]] = None,
    **kwargs,
) -> T:
    if default is PydanticUndefined and default_factory is None:
        advanced = False
@@ -170,8 +177,12 @@ def SchemaField(
        title=title,
        description=description,
        exclude=exclude,
        ge=ge,
        le=le,
        min_length=min_length,
        max_length=max_length,
        discriminator=discriminator,
        json_schema_extra=json_schema_extra,
        **kwargs,
    )  # type: ignore

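The new ge/le/min_length/max_length parameters are forwarded straight to pydantic's Field, so they produce both runtime validation and JSON-schema constraints. A hypothetical block input using them (field names invented for illustration):

# Illustrative only; SchemaField and BlockSchema as defined in this codebase.
class Input(BlockSchema):
    temperature: float = SchemaField(
        description="Sampling temperature",
        ge=0.0,   # lower bound, enforced by pydantic
        le=2.0,   # upper bound, also emitted into the JSON schema
    )
    prompt: str = SchemaField(
        description="Prompt text",
        min_length=1,
        max_length=4000,
    )
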
@@ -292,9 +303,7 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
    )
    field_schema = model.jsonschema()["properties"][field_name]
    try:
        schema_extra = _CredentialsFieldSchemaExtra[CP, CT].model_validate(
            field_schema
        )
        schema_extra = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
    except ValidationError as e:
        if "Field required [type=missing" not in str(e):
            raise
@@ -320,14 +329,90 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
    )


class _CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]):
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
    # TODO: move discrimination mechanism out of CredentialsField (frontend + backend)
    credentials_provider: list[CP]
    credentials_scopes: Optional[list[str]] = None
    credentials_types: list[CT]
    provider: frozenset[CP] = Field(..., alias="credentials_provider")
    supported_types: frozenset[CT] = Field(..., alias="credentials_types")
    required_scopes: Optional[frozenset[str]] = Field(None, alias="credentials_scopes")
    discriminator: Optional[str] = None
    discriminator_mapping: Optional[dict[str, CP]] = None

    @classmethod
    def combine(
        cls, *fields: tuple[CredentialsFieldInfo[CP, CT], T]
    ) -> Sequence[tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
        """
        Combines multiple CredentialsFieldInfo objects into as few as possible.

        Rules:
        - Items can only be combined if they have the same supported credentials types
          and the same supported providers.
        - When combining items, the `required_scopes` of the result is the union
          of the `required_scopes` of the original items.

        Params:
            *fields: (CredentialsFieldInfo, key) objects to group and combine

        Returns:
            A sequence of tuples containing combined CredentialsFieldInfo objects and
            the set of keys of the respective original items that were grouped together.
        """
        if not fields:
            return []

        # Group fields by their provider and supported_types
        grouped_fields: defaultdict[
            tuple[frozenset[CP], frozenset[CT]],
            list[tuple[T, CredentialsFieldInfo[CP, CT]]],
        ] = defaultdict(list)

        for field, key in fields:
            group_key = (frozenset(field.provider), frozenset(field.supported_types))
            grouped_fields[group_key].append((key, field))

        # Combine fields within each group
        result: list[tuple[CredentialsFieldInfo[CP, CT], set[T]]] = []

        for group in grouped_fields.values():
            # Start with the first field in the group
            _, combined = group[0]

            # Track the keys that were combined
            combined_keys = {key for key, _ in group}

            # Combine required_scopes from all fields in the group
            all_scopes = set()
            for _, field in group:
                if field.required_scopes:
                    all_scopes.update(field.required_scopes)

            # Create a new combined field
            result.append(
                (
                    CredentialsFieldInfo[CP, CT](
                        credentials_provider=combined.provider,
                        credentials_types=combined.supported_types,
                        credentials_scopes=frozenset(all_scopes) or None,
                        discriminator=combined.discriminator,
                        discriminator_mapping=combined.discriminator_mapping,
                    ),
                    combined_keys,
                )
            )

        return result

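A minimal sketch of combine() with hand-built inputs (provider and scope values assumed for illustration):

# Two fields with the same provider and credential types collapse into one,
# with the union of their required scopes.
a = CredentialsFieldInfo(
    credentials_provider=frozenset({"github"}),
    credentials_types=frozenset({"oauth2"}),
    credentials_scopes=frozenset({"repo"}),
)
b = CredentialsFieldInfo(
    credentials_provider=frozenset({"github"}),
    credentials_types=frozenset({"oauth2"}),
    credentials_scopes=frozenset({"read:org"}),
)

[(combined, keys)] = CredentialsFieldInfo.combine((a, "node1"), (b, "node2"))
assert keys == {"node1", "node2"}
assert combined.required_scopes == {"repo", "read:org"}
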
    def discriminate(self, discriminator_value: Any) -> CredentialsFieldInfo:
        if not (self.discriminator and self.discriminator_mapping):
            return self

        discriminator_value = self.discriminator_mapping[discriminator_value]
        return CredentialsFieldInfo(
            credentials_provider=frozenset([discriminator_value]),
            credentials_types=self.supported_types,
            credentials_scopes=self.required_scopes,
        )

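And a sketch of discriminate(), with an assumed discriminator mapping:

# With a discriminator_mapping, the discriminating value (here a model name)
# narrows the field to a single provider. Mapping values are assumed.
info = CredentialsFieldInfo(
    credentials_provider=frozenset({"openai", "anthropic"}),
    credentials_types=frozenset({"api_key"}),
    credentials_scopes=None,
    discriminator="model",
    discriminator_mapping={"gpt-4o": "openai", "claude-3-opus": "anthropic"},
)
narrowed = info.discriminate("claude-3-opus")
assert narrowed.provider == {"anthropic"}
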

def CredentialsField(
    required_scopes: set[str] = set(),
@@ -365,6 +450,12 @@ class ContributorDetails(BaseModel):
    name: str = Field(title="Name", description="The name of the contributor.")


class TopUpType(enum.Enum):
    AUTO = "AUTO"
    MANUAL = "MANUAL"
    UNCATEGORIZED = "UNCATEGORIZED"


class AutoTopUpConfig(BaseModel):
    amount: int
    """Amount of credits to top up."""
@@ -377,12 +468,18 @@ class UserTransaction(BaseModel):
    transaction_time: datetime = datetime.min.replace(tzinfo=timezone.utc)
    transaction_type: CreditTransactionType = CreditTransactionType.USAGE
    amount: int = 0
    balance: int = 0
    running_balance: int = 0
    current_balance: int = 0
    description: str | None = None
    usage_graph_id: str | None = None
    usage_execution_id: str | None = None
    usage_node_count: int = 0
    usage_start_time: datetime = datetime.max.replace(tzinfo=timezone.utc)
    user_id: str
    user_email: str | None = None
    reason: str | None = None
    admin_email: str | None = None
    extra_data: str | None = None


class TransactionHistory(BaseModel):
@@ -405,9 +502,10 @@ class RefundRequest(BaseModel):
class NodeExecutionStats(BaseModel):
    """Execution statistics for a node execution."""

    class Config:
        arbitrary_types_allowed = True
        extra = "allow"
    model_config = ConfigDict(
        extra="allow",
        arbitrary_types_allowed=True,
    )

    error: Optional[Exception | str] = None
    walltime: float = 0
@@ -423,9 +521,10 @@ class NodeExecutionStats(BaseModel):
class GraphExecutionStats(BaseModel):
    """Execution statistics for a graph execution."""

    class Config:
        arbitrary_types_allowed = True
        extra = "allow"
    model_config = ConfigDict(
        extra="allow",
        arbitrary_types_allowed=True,
    )

    error: Optional[Exception | str] = None
    walltime: float = Field(

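These hunks port pydantic v1-style inner Config classes to v2's model_config. The pattern in isolation, on an illustrative model:

# Pydantic v2: model_config replaces the inner Config class, with the same
# options passed as ConfigDict keywords.
from pydantic import BaseModel, ConfigDict

class Stats(BaseModel):  # illustrative model, not from this diff
    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

    walltime: float = 0
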
@@ -6,10 +6,14 @@ from typing import Annotated, Any, Generic, Optional, TypeVar, Union
from prisma import Json
from prisma.enums import NotificationType
from prisma.models import NotificationEvent, UserNotificationBatch
from prisma.types import UserNotificationBatchWhereInput
from prisma.types import (
    NotificationEventCreateInput,
    UserNotificationBatchCreateInput,
    UserNotificationBatchWhereInput,
)

# from backend.notifications.models import NotificationEvent
from pydantic import BaseModel, EmailStr, Field, field_validator
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator

from backend.server.v2.store.exceptions import DatabaseError

@@ -35,8 +39,7 @@ class QueueType(Enum):


class BaseNotificationData(BaseModel):
    class Config:
        extra = "allow"
    model_config = ConfigDict(extra="allow")


class AgentRunData(BaseNotificationData):
@@ -418,30 +421,30 @@ async def create_or_add_to_user_notification_batch(
    if not existing_batch:
        async with transaction() as tx:
            notification_event = await tx.notificationevent.create(
                data={
                    "type": notification_type,
                    "data": json_data,
                }
                data=NotificationEventCreateInput(
                    type=notification_type,
                    data=json_data,
                )
            )

            # Create new batch
            resp = await tx.usernotificationbatch.create(
                data={
                    "userId": user_id,
                    "type": notification_type,
                    "Notifications": {"connect": [{"id": notification_event.id}]},
                },
                data=UserNotificationBatchCreateInput(
                    userId=user_id,
                    type=notification_type,
                    Notifications={"connect": [{"id": notification_event.id}]},
                ),
                include={"Notifications": True},
            )
            return UserNotificationBatchDTO.from_db(resp)
    else:
        async with transaction() as tx:
            notification_event = await tx.notificationevent.create(
                data={
                    "type": notification_type,
                    "data": json_data,
                    "UserNotificationBatch": {"connect": {"id": existing_batch.id}},
                }
                data=NotificationEventCreateInput(
                    type=notification_type,
                    data=json_data,
                    UserNotificationBatch={"connect": {"id": existing_batch.id}},
                )
            )
            # Add to existing batch
            resp = await tx.usernotificationbatch.update(

@@ -6,9 +6,11 @@ import pydantic
from prisma import Json
from prisma.enums import OnboardingStep
from prisma.models import UserOnboarding
from prisma.types import UserOnboardingUpdateInput
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput

from backend.data import db
from backend.data.block import get_blocks
from backend.data.credit import get_user_credit_model
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.server.v2.store.model import StoreAgentDetails
@@ -24,21 +26,26 @@ REASON_MAPPING: dict[str, list[str]] = {
POINTS_AGENT_COUNT = 50  # Number of agents to calculate points for
MIN_AGENT_COUNT = 2  # Minimum number of marketplace agents to enable onboarding

user_credit = get_user_credit_model()


class UserOnboardingUpdate(pydantic.BaseModel):
    completedSteps: Optional[list[OnboardingStep]] = None
    notificationDot: Optional[bool] = None
    notified: Optional[list[OnboardingStep]] = None
    usageReason: Optional[str] = None
    integrations: Optional[list[str]] = None
    otherIntegrations: Optional[str] = None
    selectedStoreListingVersionId: Optional[str] = None
    agentInput: Optional[dict[str, Any]] = None
    onboardingAgentExecutionId: Optional[str] = None


async def get_user_onboarding(user_id: str):
    return await UserOnboarding.prisma().upsert(
        where={"userId": user_id},
        data={
            "create": {"userId": user_id},  # type: ignore
            "create": UserOnboardingCreateInput(userId=user_id),
            "update": {},
        },
    )
@@ -48,6 +55,20 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
    update: UserOnboardingUpdateInput = {}
    if data.completedSteps is not None:
        update["completedSteps"] = list(set(data.completedSteps))
        for step in (
            OnboardingStep.AGENT_NEW_RUN,
            OnboardingStep.GET_RESULTS,
            OnboardingStep.MARKETPLACE_ADD_AGENT,
            OnboardingStep.MARKETPLACE_RUN_AGENT,
            OnboardingStep.BUILDER_SAVE_AGENT,
            OnboardingStep.BUILDER_RUN_AGENT,
        ):
            if step in data.completedSteps:
                await reward_user(user_id, step)
    if data.notificationDot is not None:
        update["notificationDot"] = data.notificationDot
    if data.notified is not None:
        update["notified"] = list(set(data.notified))
    if data.usageReason is not None:
        update["usageReason"] = data.usageReason
    if data.integrations is not None:
@@ -58,16 +79,57 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
        update["selectedStoreListingVersionId"] = data.selectedStoreListingVersionId
    if data.agentInput is not None:
        update["agentInput"] = Json(data.agentInput)
    if data.onboardingAgentExecutionId is not None:
        update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId

    return await UserOnboarding.prisma().upsert(
        where={"userId": user_id},
        data={
            "create": {"userId": user_id, **update},  # type: ignore
            "create": {"userId": user_id, **update},
            "update": update,
        },
    )


async def reward_user(user_id: str, step: OnboardingStep):
    async with db.locked_transaction(f"usr_trx_{user_id}-reward"):
        reward = 0
        match step:
            # Reward the user when they click New Run during onboarding,
            # because they need credits before scheduling a run (the next step)
            case OnboardingStep.AGENT_NEW_RUN:
                reward = 300
            case OnboardingStep.GET_RESULTS:
                reward = 300
            case OnboardingStep.MARKETPLACE_ADD_AGENT:
                reward = 100
            case OnboardingStep.MARKETPLACE_RUN_AGENT:
                reward = 100
            case OnboardingStep.BUILDER_SAVE_AGENT:
                reward = 100
            case OnboardingStep.BUILDER_RUN_AGENT:
                reward = 100

        if reward == 0:
            return

        onboarding = await get_user_onboarding(user_id)

        # Skip if already rewarded
        if step in onboarding.rewardedFor:
            return

        onboarding.rewardedFor.append(step)
        await user_credit.onboarding_reward(user_id, reward, step)
        await UserOnboarding.prisma().update(
            where={"userId": user_id},
            data={
                "completedSteps": list(set(onboarding.completedSteps + [step])),
                "rewardedFor": onboarding.rewardedFor,
            },
        )


def clean_and_split(text: str) -> list[str]:
    """
    Removes all special characters from a string, truncates it to 100 characters,
@@ -186,11 +248,11 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
        where={
            "id": {"in": [agent.storeListingVersionId for agent in storeAgents]},
        },
        include={"Agent": True},
        include={"AgentGraph": True},
    )

    for listing in agentListings:
        agent = listing.Agent
        agent = listing.AgentGraph
        if agent is None:
            continue
        graph = GraphModel.from_db(agent)

@@ -4,10 +4,18 @@ from enum import Enum
from typing import Awaitable, Optional

import aio_pika
import aio_pika.exceptions as aio_ex
import pika
import pika.adapters.blocking_connection
from pika.exceptions import AMQPError
from pika.spec import BasicProperties
from pydantic import BaseModel
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)

from backend.util.retry import conn_retry
from backend.util.settings import Settings
@@ -161,6 +169,12 @@ class SyncRabbitMQ(RabbitMQBase):
            routing_key=queue.routing_key or queue.name,
        )

    @retry(
        retry=retry_if_exception_type((AMQPError, ConnectionError)),
        wait=wait_random_exponential(multiplier=1, max=5),
        stop=stop_after_attempt(5),
        reraise=True,
    )
    def publish_message(
        self,
        routing_key: str,
@@ -258,6 +272,12 @@ class AsyncRabbitMQ(RabbitMQBase):
            exchange, routing_key=queue.routing_key or queue.name
        )

    @retry(
        retry=retry_if_exception_type((aio_ex.AMQPError, ConnectionError)),
        wait=wait_random_exponential(multiplier=1, max=5),
        stop=stop_after_attempt(5),
        reraise=True,
    )
    async def publish_message(
        self,
        routing_key: str,

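The decorators added above come from tenacity. The same policy in a standalone sketch (the failing function is invented for illustration): up to 5 attempts, jittered exponential backoff capped at 5 seconds, retrying only on the listed exception types, and re-raising the last error.

from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)

@retry(
    retry=retry_if_exception_type((ConnectionError,)),
    wait=wait_random_exponential(multiplier=1, max=5),
    stop=stop_after_attempt(5),
    reraise=True,
)
def flaky_publish() -> None:  # stand-in for publish_message
    # Always fails here, so tenacity retries 5 times, then re-raises
    # the final ConnectionError because reraise=True.
    raise ConnectionError("broker unavailable")
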
@@ -11,7 +11,7 @@ from fastapi import HTTPException
from prisma import Json
from prisma.enums import NotificationType
from prisma.models import User
from prisma.types import UserUpdateInput
from prisma.types import JsonFilter, UserCreateInput, UserUpdateInput

from backend.data.db import prisma
from backend.data.model import UserIntegrations, UserMetadata, UserMetadataRaw
@@ -36,11 +36,11 @@ async def get_or_create_user(user_data: dict) -> User:
    user = await prisma.user.find_unique(where={"id": user_id})
    if not user:
        user = await prisma.user.create(
            data={
                "id": user_id,
                "email": user_email,
                "name": user_data.get("user_metadata", {}).get("name"),
            }
            data=UserCreateInput(
                id=user_id,
                email=user_email,
                name=user_data.get("user_metadata", {}).get("name"),
            )
        )

    return User.model_validate(user)
@@ -84,11 +84,11 @@ async def create_default_user() -> Optional[User]:
    user = await prisma.user.find_unique(where={"id": DEFAULT_USER_ID})
    if not user:
        user = await prisma.user.create(
            data={
                "id": DEFAULT_USER_ID,
                "email": "default@example.com",
                "name": "Default User",
            }
            data=UserCreateInput(
                id=DEFAULT_USER_ID,
                email="default@example.com",
                name="Default User",
            )
        )
    return User.model_validate(user)

@@ -135,16 +135,21 @@ async def migrate_and_encrypt_user_integrations():
    """Migrate integration credentials and OAuth states from metadata to integrations column."""
    users = await User.prisma().find_many(
        where={
            "metadata": {
                "path": ["integration_credentials"],
                "not": Json({"a": "yolo"}),  # bogus value works to check if key exists
            }  # type: ignore
            "metadata": cast(
                JsonFilter,
                {
                    "path": ["integration_credentials"],
                    "not": Json(
                        {"a": "yolo"}
                    ),  # bogus value works to check if key exists
                },
            )
        }
    )
    logger.info(f"Migrating integration credentials for {len(users)} users")

    for user in users:
        raw_metadata = cast(UserMetadataRaw, user.metadata)
        raw_metadata = cast(dict, user.metadata)
        metadata = UserMetadata.model_validate(raw_metadata)

        # Get existing integrations data
@@ -160,7 +165,6 @@ async def migrate_and_encrypt_user_integrations():
        await update_user_integrations(user_id=user.id, data=integrations)

        # Remove from metadata
        raw_metadata = dict(raw_metadata)
        raw_metadata.pop("integration_credentials", None)
        raw_metadata.pop("integration_oauth_states", None)


@@ -1,9 +1,10 @@
from .database import DatabaseManager
from .database import DatabaseManager, DatabaseManagerClient
from .manager import ExecutionManager
from .scheduler import Scheduler

__all__ = [
    "DatabaseManager",
    "DatabaseManagerClient",
    "ExecutionManager",
    "Scheduler",
]

@@ -1,13 +1,14 @@
import logging
from typing import Callable, Concatenate, ParamSpec, TypeVar, cast

from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
    GraphExecution,
    NodeExecutionResult,
    RedisExecutionEventBus,
    create_graph_execution,
    get_graph_execution,
    get_incomplete_node_executions,
    get_graph_execution_meta,
    get_latest_node_execution,
    get_node_execution_results,
    get_node_executions,
    update_graph_execution_start_time,
    update_graph_execution_stats,
    update_node_execution_stats,
@@ -39,11 +40,14 @@ from backend.data.user import (
    update_user_integrations,
    update_user_metadata,
)
from backend.util.service import AppService, expose, exposed_run_and_wait
from backend.util.service import AppService, AppServiceClient, endpoint_to_sync, expose
from backend.util.settings import Config

config = Config()
_user_credit_model = get_user_credit_model()
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")


async def _spend_credits(
@@ -52,75 +56,133 @@ async def _spend_credits(
    return await _user_credit_model.spend_credits(user_id, cost, metadata)


async def _get_credits(user_id: str) -> int:
    return await _user_credit_model.get_credits(user_id)


class DatabaseManager(AppService):
    def __init__(self):
        super().__init__()
        self.use_db = True
        self.use_redis = True
        self.execution_event_bus = RedisExecutionEventBus()

    def run_service(self) -> None:
        logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
        self.run_and_wait(db.connect())
        super().run_service()

    def cleanup(self):
        super().cleanup()
        logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
        self.run_and_wait(db.disconnect())

    @classmethod
    def get_port(cls) -> int:
        return config.database_api_port

    @expose
    def send_execution_update(
        self, execution_result: GraphExecution | NodeExecutionResult
    ):
        self.execution_event_bus.publish(execution_result)
    @staticmethod
    def _(
        f: Callable[P, R], name: str | None = None
    ) -> Callable[Concatenate[object, P], R]:
        if name is not None:
            f.__name__ = name
        return cast(Callable[Concatenate[object, P], R], expose(f))

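For readability: the `_` helper above is what turns each plain module-level function below into an exposed service endpoint. Schematically, using names from the code above:

# What `_` does for each assignment below, step by step:
#   spend_credits = _(_spend_credits, name="spend_credits")
#     1. rename: sets _spend_credits.__name__ = "spend_credits", so the
#        endpoint is registered under the public name, not the private helper's
#     2. expose(): registers the function as an RPC endpoint on the service
#     3. cast(...): retypes it to accept `self`, so it binds as a method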
# Executions
|
||||
get_graph_execution = exposed_run_and_wait(get_graph_execution)
|
||||
create_graph_execution = exposed_run_and_wait(create_graph_execution)
get_node_execution_results = exposed_run_and_wait(get_node_execution_results)
get_incomplete_node_executions = exposed_run_and_wait(
get_incomplete_node_executions
)
get_latest_node_execution = exposed_run_and_wait(get_latest_node_execution)
update_node_execution_status = exposed_run_and_wait(update_node_execution_status)
update_node_execution_status_batch = exposed_run_and_wait(
update_node_execution_status_batch
)
update_graph_execution_start_time = exposed_run_and_wait(
update_graph_execution_start_time
)
update_graph_execution_stats = exposed_run_and_wait(update_graph_execution_stats)
update_node_execution_stats = exposed_run_and_wait(update_node_execution_stats)
upsert_execution_input = exposed_run_and_wait(upsert_execution_input)
upsert_execution_output = exposed_run_and_wait(upsert_execution_output)
get_graph_execution = _(get_graph_execution)
get_graph_execution_meta = _(get_graph_execution_meta)
create_graph_execution = _(create_graph_execution)
get_node_executions = _(get_node_executions)
get_latest_node_execution = _(get_latest_node_execution)
update_node_execution_status = _(update_node_execution_status)
update_node_execution_status_batch = _(update_node_execution_status_batch)
update_graph_execution_start_time = _(update_graph_execution_start_time)
update_graph_execution_stats = _(update_graph_execution_stats)
update_node_execution_stats = _(update_node_execution_stats)
upsert_execution_input = _(upsert_execution_input)
upsert_execution_output = _(upsert_execution_output)

# Graphs
get_node = exposed_run_and_wait(get_node)
get_graph = exposed_run_and_wait(get_graph)
get_connected_output_nodes = exposed_run_and_wait(get_connected_output_nodes)
get_graph_metadata = exposed_run_and_wait(get_graph_metadata)
get_node = _(get_node)
get_graph = _(get_graph)
get_connected_output_nodes = _(get_connected_output_nodes)
get_graph_metadata = _(get_graph_metadata)

# Credits
spend_credits = exposed_run_and_wait(_spend_credits)
spend_credits = _(_spend_credits, name="spend_credits")
get_credits = _(_get_credits, name="get_credits")

# User + User Metadata + User Integrations
get_user_metadata = exposed_run_and_wait(get_user_metadata)
update_user_metadata = exposed_run_and_wait(update_user_metadata)
get_user_integrations = exposed_run_and_wait(get_user_integrations)
update_user_integrations = exposed_run_and_wait(update_user_integrations)
get_user_metadata = _(get_user_metadata)
update_user_metadata = _(update_user_metadata)
get_user_integrations = _(get_user_integrations)
update_user_integrations = _(update_user_integrations)

# User Comms - async
get_active_user_ids_in_timerange = exposed_run_and_wait(
get_active_user_ids_in_timerange
)
get_user_email_by_id = exposed_run_and_wait(get_user_email_by_id)
get_user_email_verification = exposed_run_and_wait(get_user_email_verification)
get_user_notification_preference = exposed_run_and_wait(
get_user_notification_preference
)
get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
get_user_email_by_id = _(get_user_email_by_id)
get_user_email_verification = _(get_user_email_verification)
get_user_notification_preference = _(get_user_notification_preference)

# Notifications - async
create_or_add_to_user_notification_batch = exposed_run_and_wait(
create_or_add_to_user_notification_batch = _(
create_or_add_to_user_notification_batch
)
empty_user_notification_batch = exposed_run_and_wait(empty_user_notification_batch)
get_all_batches_by_type = exposed_run_and_wait(get_all_batches_by_type)
get_user_notification_batch = exposed_run_and_wait(get_user_notification_batch)
get_user_notification_oldest_message_in_batch = exposed_run_and_wait(
empty_user_notification_batch = _(empty_user_notification_batch)
get_all_batches_by_type = _(get_all_batches_by_type)
get_user_notification_batch = _(get_user_notification_batch)
get_user_notification_oldest_message_in_batch = _(
get_user_notification_oldest_message_in_batch
)


class DatabaseManagerClient(AppServiceClient):
d = DatabaseManager
_ = endpoint_to_sync

@classmethod
def get_service_type(cls):
return DatabaseManager

# Executions
get_graph_execution = _(d.get_graph_execution)
get_graph_execution_meta = _(d.get_graph_execution_meta)
create_graph_execution = _(d.create_graph_execution)
get_node_executions = _(d.get_node_executions)
get_latest_node_execution = _(d.get_latest_node_execution)
update_node_execution_status = _(d.update_node_execution_status)
update_node_execution_status_batch = _(d.update_node_execution_status_batch)
update_graph_execution_start_time = _(d.update_graph_execution_start_time)
update_graph_execution_stats = _(d.update_graph_execution_stats)
update_node_execution_stats = _(d.update_node_execution_stats)
upsert_execution_input = _(d.upsert_execution_input)
upsert_execution_output = _(d.upsert_execution_output)

# Graphs
get_node = _(d.get_node)
get_graph = _(d.get_graph)
get_connected_output_nodes = _(d.get_connected_output_nodes)
get_graph_metadata = _(d.get_graph_metadata)

# Credits
spend_credits = _(d.spend_credits)
get_credits = _(d.get_credits)

# User + User Metadata + User Integrations
get_user_metadata = _(d.get_user_metadata)
update_user_metadata = _(d.update_user_metadata)
get_user_integrations = _(d.get_user_integrations)
update_user_integrations = _(d.update_user_integrations)

# User Comms - async
get_active_user_ids_in_timerange = _(d.get_active_user_ids_in_timerange)
get_user_email_by_id = _(d.get_user_email_by_id)
get_user_email_verification = _(d.get_user_email_verification)
get_user_notification_preference = _(d.get_user_notification_preference)

# Notifications - async
create_or_add_to_user_notification_batch = _(
d.create_or_add_to_user_notification_batch
)
empty_user_notification_batch = _(d.empty_user_notification_batch)
get_all_batches_by_type = _(d.get_all_batches_by_type)
get_user_notification_batch = _(d.get_user_notification_batch)
get_user_notification_oldest_message_in_batch = _(
d.get_user_notification_oldest_message_in_batch
)
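The block above wires every DatabaseManager endpoint to a synchronous client method through endpoint_to_sync. A minimal sketch of what such a converter can look like, assuming nothing about the real backend.util.service internals (which dispatch over RPC rather than running the coroutine locally):

import asyncio
from typing import Any, Callable, Coroutine, TypeVar

R = TypeVar("R")

def endpoint_to_sync_sketch(
    endpoint: Callable[..., Coroutine[Any, Any, R]],
) -> Callable[..., R]:
    # Hypothetical stand-in: wrap an async service endpoint so callers can
    # invoke it synchronously. The real helper issues a call to the service
    # process instead of running the coroutine in-process.
    def wrapper(*args: Any, **kwargs: Any) -> R:
        return asyncio.run(endpoint(*args, **kwargs))
    return wrapper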
File diff suppressed because it is too large
@@ -16,9 +16,15 @@ from pydantic import BaseModel
|
||||
from sqlalchemy import MetaData, create_engine
|
||||
|
||||
from backend.data.block import BlockInput
|
||||
from backend.executor.manager import ExecutionManager
|
||||
from backend.notifications.notifications import NotificationManager
|
||||
from backend.util.service import AppService, expose, get_service_client
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.notifications.notifications import NotificationManagerClient
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
endpoint_to_async,
|
||||
expose,
|
||||
get_service_client,
|
||||
)
|
||||
from backend.util.settings import Config
|
||||
|
||||
|
||||
@@ -57,25 +63,18 @@ def job_listener(event):
|
||||
log(f"Job {event.job_id} completed successfully.")
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_execution_client() -> ExecutionManager:
|
||||
return get_service_client(ExecutionManager)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_notification_client():
|
||||
from backend.notifications import NotificationManager
|
||||
|
||||
return get_service_client(NotificationManager)
|
||||
return get_service_client(NotificationManagerClient)
|
||||
|
||||
|
||||
def execute_graph(**kwargs):
|
||||
args = ExecutionJobArgs(**kwargs)
|
||||
try:
|
||||
log(f"Executing recurring job for graph #{args.graph_id}")
|
||||
get_execution_client().add_execution(
|
||||
execution_utils.add_graph_execution(
|
||||
graph_id=args.graph_id,
|
||||
data=args.input_data,
|
||||
inputs=args.input_data,
|
||||
user_id=args.user_id,
|
||||
graph_version=args.graph_version,
|
||||
)
|
||||
@@ -164,19 +163,9 @@ class Scheduler(AppService):
|
||||
def db_pool_size(cls) -> int:
|
||||
return config.scheduler_db_pool_size
|
||||
|
||||
@property
|
||||
@thread_cached
|
||||
def execution_client(self) -> ExecutionManager:
|
||||
return get_service_client(ExecutionManager)
|
||||
|
||||
@property
|
||||
@thread_cached
|
||||
def notification_client(self) -> NotificationManager:
|
||||
return get_service_client(NotificationManager)
|
||||
|
||||
def run_service(self):
|
||||
load_dotenv()
|
||||
db_schema, db_url = _extract_schema_from_url(os.getenv("DATABASE_URL"))
|
||||
db_schema, db_url = _extract_schema_from_url(os.getenv("DIRECT_URL"))
|
||||
self.scheduler = BlockingScheduler(
|
||||
jobstores={
|
||||
Jobstores.EXECUTION.value: SQLAlchemyJobStore(
|
||||
@@ -206,6 +195,12 @@ class Scheduler(AppService):
|
||||
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
|
||||
self.scheduler.start()
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
logger.info(f"[{self.service_name}] ⏳ Shutting down scheduler...")
|
||||
if self.scheduler:
|
||||
self.scheduler.shutdown(wait=False)
|
||||
|
||||
@expose
|
||||
def add_execution_schedule(
|
||||
self,
|
||||
@@ -304,3 +299,15 @@ class Scheduler(AppService):
|
||||
),
|
||||
job,
|
||||
)
|
||||
|
||||
|
||||
class SchedulerClient(AppServiceClient):
|
||||
@classmethod
|
||||
def get_service_type(cls):
|
||||
return Scheduler
|
||||
|
||||
add_execution_schedule = endpoint_to_async(Scheduler.add_execution_schedule)
|
||||
delete_schedule = endpoint_to_async(Scheduler.delete_schedule)
|
||||
get_execution_schedules = endpoint_to_async(Scheduler.get_execution_schedules)
|
||||
add_batched_notification_schedule = Scheduler.add_batched_notification_schedule
|
||||
add_weekly_notification_schedule = Scheduler.add_weekly_notification_schedule
|
||||
|
||||
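Calling the new SchedulerClient then looks like the create_schedule route later in this diff; a minimal usage sketch, assuming get_service_client resolves clients as shown above (the cron string and IDs are invented):

from backend.executor.scheduler import SchedulerClient
from backend.util.service import get_service_client

async def schedule_daily_run(graph_id: str, graph_version: int, user_id: str):
    scheduler = get_service_client(SchedulerClient)
    # endpoint_to_async turns the exposed endpoint into an awaitable call.
    return await scheduler.add_execution_schedule(
        graph_id=graph_id,
        graph_version=graph_version,
        cron="0 9 * * *",  # hypothetical: every day at 09:00
        input_data={},
        user_id=user_id,
    )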
@@ -1,38 +1,113 @@
import logging
from typing import TYPE_CHECKING, Any, Optional, cast

from autogpt_libs.utils.cache import thread_cached
from pydantic import BaseModel

from backend.data.block import Block, BlockInput
from backend.data.block import (
Block,
BlockData,
BlockInput,
BlockSchema,
BlockType,
get_block,
)
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.cost import BlockCostType
from backend.data.execution import (
AsyncRedisExecutionEventBus,
ExecutionStatus,
GraphExecutionStats,
GraphExecutionWithNodes,
RedisExecutionEventBus,
create_graph_execution,
update_graph_execution_stats,
update_node_execution_status_batch,
)
from backend.data.graph import GraphModel, Node, get_graph
from backend.data.model import CredentialsMetaInput
from backend.data.rabbitmq import (
AsyncRabbitMQ,
Exchange,
ExchangeType,
Queue,
RabbitMQConfig,
SyncRabbitMQ,
)
from backend.util.exceptions import NotFoundError
from backend.util.mock import MockObject
from backend.util.service import get_service_client
from backend.util.settings import Config
from backend.util.type import convert

if TYPE_CHECKING:
from backend.executor import DatabaseManagerClient
from backend.integrations.credentials_store import IntegrationCredentialsStore

config = Config()
logger = logging.getLogger(__name__)

# ============ Resource Helpers ============ #


class UsageTransactionMetadata(BaseModel):
graph_exec_id: str | None = None
graph_id: str | None = None
node_id: str | None = None
node_exec_id: str | None = None
block_id: str | None = None
block: str | None = None
input: BlockInput | None = None
@thread_cached
def get_execution_event_bus() -> RedisExecutionEventBus:
return RedisExecutionEventBus()


@thread_cached
def get_async_execution_event_bus() -> AsyncRedisExecutionEventBus:
return AsyncRedisExecutionEventBus()


@thread_cached
def get_execution_queue() -> SyncRabbitMQ:
client = SyncRabbitMQ(create_execution_queue_config())
client.connect()
return client


@thread_cached
async def get_async_execution_queue() -> AsyncRabbitMQ:
client = AsyncRabbitMQ(create_execution_queue_config())
await client.connect()
return client


@thread_cached
def get_integration_credentials_store() -> "IntegrationCredentialsStore":
from backend.integrations.credentials_store import IntegrationCredentialsStore

return IntegrationCredentialsStore()


@thread_cached
def get_db_client() -> "DatabaseManagerClient":
from backend.executor import DatabaseManagerClient

return get_service_client(DatabaseManagerClient)


# ============ Execution Cost Helpers ============ #


def execution_usage_cost(execution_count: int) -> tuple[int, int]:
"""
Calculate the cost of executing a graph based on the number of executions.
Calculate the cost of executing a graph based on the current number of node executions.

Args:
execution_count: Number of executions
execution_count: Number of node executions

Returns:
Tuple of cost amount and remaining execution count
Tuple of cost amount and the number of execution count that is included in the cost.
"""
return (
execution_count
// config.execution_cost_count_threshold
* config.execution_cost_per_threshold,
execution_count % config.execution_cost_count_threshold,
(
config.execution_cost_per_threshold
if execution_count % config.execution_cost_count_threshold == 0
else 0
),
config.execution_cost_count_threshold,
)

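To make the reworked return value concrete: instead of a bulk charge plus a remainder, the function now charges once each time the running node-execution count lands exactly on a threshold boundary, and reports the threshold size as the count covered. A self-contained illustration with assumed config values (threshold 100, cost 5; the real values come from backend.util.settings.Config):

def execution_usage_cost_sketch(execution_count: int) -> tuple[int, int]:
    threshold, cost_per_threshold = 100, 5  # assumed values, not the real config
    return (
        cost_per_threshold if execution_count % threshold == 0 else 0,
        threshold,
    )

assert execution_usage_cost_sketch(99) == (0, 100)   # not yet at a boundary
assert execution_usage_cost_sketch(100) == (5, 100)  # charged at each full threshold
assert execution_usage_cost_sketch(150) == (0, 100)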
@@ -95,3 +170,571 @@ def _is_cost_filter_match(cost_filter: BlockInput, input_data: BlockInput) -> bo
or (input_data.get(k) and _is_cost_filter_match(v, input_data[k]))
for k, v in cost_filter.items()
)


# ============ Execution Input Helpers ============ #

LIST_SPLIT = "_$_"
DICT_SPLIT = "_#_"
OBJC_SPLIT = "_@_"


def parse_execution_output(output: BlockData, name: str) -> Any | None:
"""
Extracts partial output data by name from a given BlockData.

The function supports extracting data from lists, dictionaries, and objects
using specific naming conventions:
- For lists: <output_name>_$_<index>
- For dictionaries: <output_name>_#_<key>
- For objects: <output_name>_@_<attribute>

Args:
output (BlockData): A tuple containing the output name and data.
name (str): The name used to extract specific data from the output.

Returns:
Any | None: The extracted data if found, otherwise None.

Examples:
>>> output = ("result", [10, 20, 30])
>>> parse_execution_output(output, "result_$_1")
20

>>> output = ("config", {"key1": "value1", "key2": "value2"})
>>> parse_execution_output(output, "config_#_key1")
'value1'

>>> class Sample:
...     attr1 = "value1"
...     attr2 = "value2"
>>> output = ("object", Sample())
>>> parse_execution_output(output, "object_@_attr1")
'value1'
"""
output_name, output_data = output

if name == output_name:
return output_data

if name.startswith(f"{output_name}{LIST_SPLIT}"):
index = int(name.split(LIST_SPLIT)[1])
if not isinstance(output_data, list) or len(output_data) <= index:
return None
return output_data[int(name.split(LIST_SPLIT)[1])]

if name.startswith(f"{output_name}{DICT_SPLIT}"):
index = name.split(DICT_SPLIT)[1]
if not isinstance(output_data, dict) or index not in output_data:
return None
return output_data[index]

if name.startswith(f"{output_name}{OBJC_SPLIT}"):
index = name.split(OBJC_SPLIT)[1]
if isinstance(output_data, object) and hasattr(output_data, index):
return getattr(output_data, index)
return None

return None


def validate_exec(
node: Node,
data: BlockInput,
resolve_input: bool = True,
) -> tuple[BlockInput | None, str]:
"""
Validate the input data for a node execution.

Args:
node: The node to execute.
data: The input data for the node execution.
resolve_input: Whether to resolve dynamic pins into dict/list/object.

Returns:
A tuple of the validated data and the block name.
If the data is invalid, the first element will be None, and the second element
will be an error message.
If the data is valid, the first element will be the resolved input data, and
the second element will be the block name.
"""
node_block: Block | None = get_block(node.block_id)
if not node_block:
return None, f"Block for {node.block_id} not found."
schema = node_block.input_schema

# Convert non-matching data types to the expected input schema.
for name, data_type in schema.__annotations__.items():
if (value := data.get(name)) and (type(value) is not data_type):
data[name] = convert(value, data_type)

# Input data (without default values) should contain all required fields.
error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
if missing_links := schema.get_missing_links(data, node.input_links):
return None, f"{error_prefix} unpopulated links {missing_links}"

# Merge input data with default values and resolve dynamic dict/list/object pins.
input_default = schema.get_input_defaults(node.input_default)
data = {**input_default, **data}
if resolve_input:
data = merge_execution_input(data)

# Input data post-merge should contain all required fields from the schema.
if missing_input := schema.get_missing_input(data):
return None, f"{error_prefix} missing input {missing_input}"

# Last validation: Validate the input values against the schema.
if error := schema.get_mismatch_error(data):
error_message = f"{error_prefix} {error}"
logger.error(error_message)
return None, error_message

return data, node_block.name


def merge_execution_input(data: BlockInput) -> BlockInput:
"""
Merges dynamic input pins into a single list, dictionary, or object based on naming patterns.

This function processes input keys that follow specific patterns to merge them into a unified structure:
- `<input_name>_$_<index>` for list inputs.
- `<input_name>_#_<index>` for dictionary inputs.
- `<input_name>_@_<index>` for object inputs.

Args:
data (BlockInput): A dictionary containing input keys and their corresponding values.

Returns:
BlockInput: A dictionary with merged inputs.

Raises:
ValueError: If a list index is not an integer.

Examples:
>>> data = {
...     "list_$_0": "a",
...     "list_$_1": "b",
...     "dict_#_key1": "value1",
...     "dict_#_key2": "value2",
...     "object_@_attr1": "value1",
...     "object_@_attr2": "value2"
... }
>>> merge_execution_input(data)
{
"list": ["a", "b"],
"dict": {"key1": "value1", "key2": "value2"},
"object": <MockObject attr1="value1" attr2="value2">
}
"""

# Merge all input with <input_name>_$_<index> into a single list.
items = list(data.items())

for key, value in items:
if LIST_SPLIT not in key:
continue
name, index = key.split(LIST_SPLIT)
if not index.isdigit():
raise ValueError(f"Invalid key: {key}, #{index} index must be an integer.")

data[name] = data.get(name, [])
if int(index) >= len(data[name]):
# Pad list with empty string on missing indices.
data[name].extend([""] * (int(index) - len(data[name]) + 1))
data[name][int(index)] = value

# Merge all input with <input_name>_#_<index> into a single dict.
for key, value in items:
if DICT_SPLIT not in key:
continue
name, index = key.split(DICT_SPLIT)
data[name] = data.get(name, {})
data[name][index] = value

# Merge all input with <input_name>_@_<index> into a single object.
for key, value in items:
if OBJC_SPLIT not in key:
continue
name, index = key.split(OBJC_SPLIT)
if name not in data or not isinstance(data[name], object):
data[name] = MockObject()
setattr(data[name], index, value)

return data

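The list-pin convention above is easy to test in isolation. A standalone sketch that mirrors the merging step without any backend imports (it handles only the list case, and pads with empty strings exactly as the function above does):

LIST_SPLIT = "_$_"

def merge_list_pins(data: dict) -> dict:
    # Collapse {"values_$_0": "a", "values_$_1": "b"} into {"values": ["a", "b"]}.
    merged: dict = {}
    for key, value in data.items():
        if LIST_SPLIT not in key:
            merged[key] = value
            continue
        name, index = key.split(LIST_SPLIT)
        items = merged.setdefault(name, [])
        # Pad with empty strings up to the target index, then assign.
        items.extend([""] * (int(index) - len(items) + 1))
        items[int(index)] = value
    return merged

assert merge_list_pins({"values_$_0": "a", "values_$_1": "b"}) == {"values": ["a", "b"]}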
def _validate_node_input_credentials(
graph: GraphModel,
user_id: str,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
):
"""Checks all credentials for all nodes of the graph"""

for node in graph.nodes:
block = node.block

# Find any fields of type CredentialsMetaInput
credentials_fields = cast(
type[BlockSchema], block.input_schema
).get_credentials_fields()
if not credentials_fields:
continue

for field_name, credentials_meta_type in credentials_fields.items():
if (
node_credentials_input_map
and (node_credentials_inputs := node_credentials_input_map.get(node.id))
and field_name in node_credentials_inputs
):
credentials_meta = node_credentials_input_map[node.id][field_name]
elif field_name in node.input_default:
credentials_meta = credentials_meta_type.model_validate(
node.input_default[field_name]
)
else:
raise ValueError(
f"Credentials absent for {block.name} node #{node.id} "
f"input '{field_name}'"
)

# Fetch the corresponding Credentials and perform sanity checks
credentials = get_integration_credentials_store().get_creds_by_id(
user_id, credentials_meta.id
)
if not credentials:
raise ValueError(
f"Unknown credentials #{credentials_meta.id} "
f"for node #{node.id} input '{field_name}'"
)
if (
credentials.provider != credentials_meta.provider
or credentials.type != credentials_meta.type
):
logger.warning(
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch: "
f"{credentials_meta.type}<>{credentials.type};"
f"{credentials_meta.provider}<>{credentials.provider}"
)
raise ValueError(
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch"
)


def make_node_credentials_input_map(
graph: GraphModel,
graph_credentials_input: dict[str, CredentialsMetaInput],
) -> dict[str, dict[str, CredentialsMetaInput]]:
"""
Maps credentials for an execution to the correct nodes.

Params:
graph: The graph to be executed.
graph_credentials_input: A (graph_input_name, credentials_meta) map.

Returns:
dict[node_id, dict[field_name, CredentialsMetaInput]]: Node credentials input map.
"""
result: dict[str, dict[str, CredentialsMetaInput]] = {}

# Get aggregated credentials fields for the graph
graph_cred_inputs = graph.aggregate_credentials_inputs()

for graph_input_name, (_, compatible_node_fields) in graph_cred_inputs.items():
# Best-effort map: skip missing items
if graph_input_name not in graph_credentials_input:
continue

# Use passed-in credentials for all compatible node input fields
for node_id, node_field_name in compatible_node_fields:
if node_id not in result:
result[node_id] = {}
result[node_id][node_field_name] = graph_credentials_input[graph_input_name]

return result

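For orientation, the returned map nests node IDs over field names, with the same credentials reference fanned out to every compatible node input. A purely illustrative shape, with plain dicts standing in for CredentialsMetaInput objects and every ID invented:

node_credentials_input_map = {
    "node-123": {
        "credentials": {"id": "cred-abc", "provider": "openai", "type": "api_key"},
    },
    "node-456": {
        "credentials": {"id": "cred-abc", "provider": "openai", "type": "api_key"},
    },
}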
def construct_node_execution_input(
graph: GraphModel,
user_id: str,
graph_inputs: BlockInput,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
) -> list[tuple[str, BlockInput]]:
"""
Validates and prepares the input data for executing a graph.
This function checks the graph for starting nodes, validates the input data
against the schema, and resolves dynamic input pins into a single list,
dictionary, or object.

Args:
graph (GraphModel): The graph model to execute.
user_id (str): The ID of the user executing the graph.
data (BlockInput): The input data for the graph execution.
node_credentials_map: `dict[node_id, dict[input_name, CredentialsMetaInput]]`

Returns:
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
the corresponding input data for that node.
"""
graph.validate_graph(for_run=True)
_validate_node_input_credentials(graph, user_id, node_credentials_input_map)

nodes_input = []
for node in graph.starting_nodes:
input_data = {}
block = node.block

# Note block should never be executed.
if block.block_type == BlockType.NOTE:
continue

# Extract request input data, and assign it to the input pin.
if block.block_type == BlockType.INPUT:
input_name = node.input_default.get("name")
if input_name and input_name in graph_inputs:
input_data = {"value": graph_inputs[input_name]}

# Extract webhook payload, and assign it to the input pin
webhook_payload_key = f"webhook_{node.webhook_id}_payload"
if (
block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
and node.webhook_id
):
if webhook_payload_key not in graph_inputs:
raise ValueError(
f"Node {block.name} #{node.id} webhook payload is missing"
)
input_data = {"payload": graph_inputs[webhook_payload_key]}

# Apply node credentials overrides
if node_credentials_input_map and (
node_credentials := node_credentials_input_map.get(node.id)
):
input_data.update({k: v.model_dump() for k, v in node_credentials.items()})

input_data, error = validate_exec(node, input_data)
if input_data is None:
raise ValueError(error)
else:
nodes_input.append((node.id, input_data))

if not nodes_input:
raise ValueError(
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
)

return nodes_input


# ============ Execution Queue Helpers ============ #


class CancelExecutionEvent(BaseModel):
graph_exec_id: str


GRAPH_EXECUTION_EXCHANGE = Exchange(
name="graph_execution",
type=ExchangeType.DIRECT,
durable=True,
auto_delete=False,
)
GRAPH_EXECUTION_QUEUE_NAME = "graph_execution_queue"
GRAPH_EXECUTION_ROUTING_KEY = "graph_execution.run"

GRAPH_EXECUTION_CANCEL_EXCHANGE = Exchange(
name="graph_execution_cancel",
type=ExchangeType.FANOUT,
durable=True,
auto_delete=True,
)
GRAPH_EXECUTION_CANCEL_QUEUE_NAME = "graph_execution_cancel_queue"


def create_execution_queue_config() -> RabbitMQConfig:
"""
Define two exchanges and queues:
- 'graph_execution' (DIRECT) for run tasks.
- 'graph_execution_cancel' (FANOUT) for cancel requests.
"""
run_queue = Queue(
name=GRAPH_EXECUTION_QUEUE_NAME,
exchange=GRAPH_EXECUTION_EXCHANGE,
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
durable=True,
auto_delete=False,
)
cancel_queue = Queue(
name=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
exchange=GRAPH_EXECUTION_CANCEL_EXCHANGE,
routing_key="",  # not used for FANOUT
durable=True,
auto_delete=False,
)
return RabbitMQConfig(
vhost="/",
exchanges=[GRAPH_EXECUTION_EXCHANGE, GRAPH_EXECUTION_CANCEL_EXCHANGE],
queues=[run_queue, cancel_queue],
)

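The same topology can be sketched with the plain pika client, purely for illustration; the project itself declares it through its own SyncRabbitMQ/AsyncRabbitMQ wrappers and the RabbitMQConfig above, and the connection parameters here are assumptions:

import pika

conn = pika.BlockingConnection(pika.ConnectionParameters(host="localhost"))
ch = conn.channel()

# DIRECT exchange + queue for run tasks, bound on the run routing key.
ch.exchange_declare(exchange="graph_execution", exchange_type="direct", durable=True)
ch.queue_declare(queue="graph_execution_queue", durable=True)
ch.queue_bind(
    queue="graph_execution_queue",
    exchange="graph_execution",
    routing_key="graph_execution.run",
)

# FANOUT exchange + queue for cancel requests; no routing key is needed.
ch.exchange_declare(
    exchange="graph_execution_cancel",
    exchange_type="fanout",
    durable=True,
    auto_delete=True,
)
ch.queue_declare(queue="graph_execution_cancel_queue", durable=True)
ch.queue_bind(queue="graph_execution_cancel_queue", exchange="graph_execution_cancel")

The DIRECT exchange gives exactly-once routing of run tasks to the single work queue, while the FANOUT exchange lets every consumer observe cancel requests.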
async def add_graph_execution_async(
graph_id: str,
user_id: str,
inputs: BlockInput,
preset_id: Optional[str] = None,
graph_version: Optional[int] = None,
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
) -> GraphExecutionWithNodes:
"""
Adds a graph execution to the queue and returns the execution entry.

Args:
graph_id: The ID of the graph to execute.
user_id: The ID of the user executing the graph.
inputs: The input data for the graph execution.
preset_id: The ID of the preset to use.
graph_version: The version of the graph to execute.
graph_credentials_inputs: Credentials inputs to use in the execution.
Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
Returns:
GraphExecutionEntry: The entry for the graph execution.
Raises:
ValueError: If the graph is not found or if there are validation errors.
"""  # noqa
graph: GraphModel | None = await get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
)
if not graph:
raise NotFoundError(f"Graph #{graph_id} not found.")

node_credentials_input_map = (
make_node_credentials_input_map(graph, graph_credentials_inputs)
if graph_credentials_inputs
else None
)

graph_exec = await create_graph_execution(
user_id=user_id,
graph_id=graph_id,
graph_version=graph.version,
starting_nodes_input=construct_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=inputs,
node_credentials_input_map=node_credentials_input_map,
),
preset_id=preset_id,
)
try:
queue = await get_async_execution_queue()
graph_exec_entry = graph_exec.to_graph_execution_entry()
if node_credentials_input_map:
graph_exec_entry.node_credentials_input_map = node_credentials_input_map
await queue.publish_message(
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
message=graph_exec_entry.model_dump_json(),
exchange=GRAPH_EXECUTION_EXCHANGE,
)

bus = get_async_execution_event_bus()
await bus.publish(graph_exec)

return graph_exec
except Exception as e:
logger.error(f"Unable to publish graph #{graph_id} exec #{graph_exec.id}: {e}")

await update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in graph_exec.node_executions],
ExecutionStatus.FAILED,
)
await update_graph_execution_stats(
graph_exec_id=graph_exec.id,
status=ExecutionStatus.FAILED,
stats=GraphExecutionStats(error=str(e)),
)
raise


def add_graph_execution(
graph_id: str,
user_id: str,
inputs: BlockInput,
preset_id: Optional[str] = None,
graph_version: Optional[int] = None,
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
) -> GraphExecutionWithNodes:
"""
Adds a graph execution to the queue and returns the execution entry.

Args:
graph_id: The ID of the graph to execute.
user_id: The ID of the user executing the graph.
inputs: The input data for the graph execution.
preset_id: The ID of the preset to use.
graph_version: The version of the graph to execute.
graph_credentials_inputs: Credentials inputs to use in the execution.
Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
Returns:
GraphExecutionEntry: The entry for the graph execution.
Raises:
ValueError: If the graph is not found or if there are validation errors.
"""
db = get_db_client()
graph: GraphModel | None = db.get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
)
if not graph:
raise NotFoundError(f"Graph #{graph_id} not found.")

node_credentials_input_map = (
make_node_credentials_input_map(graph, graph_credentials_inputs)
if graph_credentials_inputs
else None
)

graph_exec = db.create_graph_execution(
user_id=user_id,
graph_id=graph_id,
graph_version=graph.version,
starting_nodes_input=construct_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=inputs,
node_credentials_input_map=node_credentials_input_map,
),
preset_id=preset_id,
)
try:
queue = get_execution_queue()
graph_exec_entry = graph_exec.to_graph_execution_entry()
if node_credentials_input_map:
graph_exec_entry.node_credentials_input_map = node_credentials_input_map
queue.publish_message(
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
message=graph_exec_entry.model_dump_json(),
exchange=GRAPH_EXECUTION_EXCHANGE,
)

bus = get_execution_event_bus()
bus.publish(graph_exec)

return graph_exec
except Exception as e:
logger.error(f"Unable to publish graph #{graph_id} exec #{graph_exec.id}: {e}")

db.update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in graph_exec.node_executions],
ExecutionStatus.FAILED,
)
db.update_graph_execution_stats(
graph_exec_id=graph_exec.id,
status=ExecutionStatus.FAILED,
stats=GraphExecutionStats(error=str(e)),
)
raise
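A minimal usage sketch of the async variant, mirroring how the v1 route later in this diff calls it; the IDs and input keys below are invented:

from backend.executor.utils import add_graph_execution_async

async def run_my_graph() -> str:
    graph_exec = await add_graph_execution_async(
        graph_id="my-graph-id",      # hypothetical ID
        user_id="my-user-id",        # hypothetical ID
        inputs={"topic": "coffee"},  # keys must match the graph's AgentInput names
        graph_version=None,          # None selects the active version
    )
    return graph_exec.id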
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Optional
from pydantic import SecretStr

if TYPE_CHECKING:
from backend.executor.database import DatabaseManager
from backend.executor.database import DatabaseManagerClient

from autogpt_libs.utils.cache import thread_cached
from autogpt_libs.utils.synchronize import RedisKeyedMutex
@@ -161,6 +161,14 @@ smartlead_credentials = APIKeyCredentials(
expires_at=None,
)

google_maps_credentials = APIKeyCredentials(
id="9aa1bde0-4947-4a70-a20c-84daa3850d52",
provider="google_maps",
api_key=SecretStr(settings.secrets.google_maps_api_key),
title="Use Credits for Google Maps",
expires_at=None,
)

zerobounce_credentials = APIKeyCredentials(
id="63a6e279-2dc2-448e-bf57-85776f7176dc",
provider="zerobounce",
@@ -190,6 +198,7 @@ DEFAULT_CREDENTIALS = [
apollo_credentials,
smartlead_credentials,
zerobounce_credentials,
google_maps_credentials,
]


@@ -201,11 +210,11 @@ class IntegrationCredentialsStore:

@property
@thread_cached
def db_manager(self) -> "DatabaseManager":
from backend.executor.database import DatabaseManager
def db_manager(self) -> "DatabaseManagerClient":
from backend.executor.database import DatabaseManagerClient
from backend.util.service import get_service_client

return get_service_client(DatabaseManager)
return get_service_client(DatabaseManagerClient)

def add_creds(self, user_id: str, credentials: Credentials) -> None:
with self.locked_user_integrations(user_id):
@@ -263,6 +272,8 @@ class IntegrationCredentialsStore:
all_credentials.append(smartlead_credentials)
if settings.secrets.zerobounce_api_key:
all_credentials.append(zerobounce_credentials)
if settings.secrets.google_maps_api_key:
all_credentials.append(google_maps_credentials)
return all_credentials

def get_creds_by_id(self, user_id: str, credentials_id: str) -> Credentials | None:

@@ -10,6 +10,7 @@ from backend.data import redis
from backend.data.model import Credentials
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.util.exceptions import MissingConfigError
from backend.util.settings import Settings

@@ -92,7 +93,7 @@ class IntegrationCredentialsManager:

fresh_credentials = oauth_handler.refresh_tokens(credentials)
self.store.update_creds(user_id, fresh_credentials)
if _lock and _lock.locked():
if _lock and _lock.locked() and _lock.owned():
_lock.release()

credentials = fresh_credentials
@@ -144,7 +145,7 @@ class IntegrationCredentialsManager:
try:
yield
finally:
if lock.locked():
if lock.locked() and lock.owned():
lock.release()

def release_all_locks(self):
@@ -153,12 +154,13 @@ class IntegrationCredentialsManager:
self.store.locks.release_all_locks()


def _get_provider_oauth_handler(provider_name: str) -> "BaseOAuthHandler":
def _get_provider_oauth_handler(provider_name_str: str) -> "BaseOAuthHandler":
provider_name = ProviderName(provider_name_str)
if provider_name not in HANDLERS_BY_NAME:
raise KeyError(f"Unknown provider '{provider_name}'")

client_id = getattr(settings.secrets, f"{provider_name}_client_id")
client_secret = getattr(settings.secrets, f"{provider_name}_client_secret")
client_id = getattr(settings.secrets, f"{provider_name.value}_client_id")
client_secret = getattr(settings.secrets, f"{provider_name.value}_client_secret")
if not (client_id and client_secret):
raise MissingConfigError(
f"Integration with provider '{provider_name}' is not configured",
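The added lock.owned() guard matters because a lock acquired with a timeout can expire mid-operation and be re-acquired by another process; releasing it then would either raise or free a lock someone else now holds. A hedged illustration using plain redis-py locks (the key name and timeout are invented, not the project's values):

import redis

r = redis.Redis()
lock = r.lock("user-integrations-123", timeout=60)  # hypothetical key/timeout
if lock.acquire(blocking=True):
    try:
        ...  # refresh tokens, update stored credentials
    finally:
        # locked() says the lock is held by someone; owned() says it is held
        # by *this* instance's token. Only release when both are true.
        if lock.locked() and lock.owned():
            lock.release()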
@@ -1,7 +1,7 @@
import logging
from typing import TYPE_CHECKING, Callable, Optional, cast

from backend.data.block import BlockSchema, BlockWebhookConfig, get_block
from backend.data.block import BlockSchema, BlockWebhookConfig
from backend.data.graph import set_node_webhook
from backend.integrations.webhooks import get_webhook_manager, supports_webhooks

@@ -29,12 +29,7 @@ async def on_graph_activate(
# Compare nodes in new_graph_version with previous_graph_version
updated_nodes = []
for new_node in graph.nodes:
block = get_block(new_node.block_id)
if not block:
raise ValueError(
f"Node #{new_node.id} is instance of unknown block #{new_node.block_id}"
)
block_input_schema = cast(BlockSchema, block.input_schema)
block_input_schema = cast(BlockSchema, new_node.block.input_schema)

node_credentials = None
if (
@@ -75,12 +70,7 @@ async def on_graph_deactivate(
"""
updated_nodes = []
for node in graph.nodes:
block = get_block(node.block_id)
if not block:
raise ValueError(
f"Node #{node.id} is instance of unknown block #{node.block_id}"
)
block_input_schema = cast(BlockSchema, block.input_schema)
block_input_schema = cast(BlockSchema, node.block.input_schema)

node_credentials = None
if (
@@ -113,11 +103,7 @@ async def on_node_activate(
) -> "NodeModel":
"""Hook to be called when the node is activated/created"""

block = get_block(node.block_id)
if not block:
raise ValueError(
f"Node #{node.id} is instance of unknown block #{node.block_id}"
)
block = node.block

if not block.webhook_config:
return node
@@ -224,11 +210,7 @@ async def on_node_deactivate(
"""Hook to be called when node is deactivated/deleted"""

logger.debug(f"Deactivating node #{node.id}")
block = get_block(node.block_id)
if not block:
raise ValueError(
f"Node #{node.id} is instance of unknown block #{node.block_id}"
)
block = node.block

if not block.webhook_config:
return node

@@ -1,5 +1,6 @@
from .notifications import NotificationManager
from .notifications import NotificationManager, NotificationManagerClient

__all__ = [
"NotificationManager",
"NotificationManagerClient",
]

@@ -9,6 +9,7 @@ from autogpt_libs.utils.cache import thread_cached
from prisma.enums import NotificationType
from pydantic import BaseModel

from backend.data import rabbitmq
from backend.data.notifications import (
BaseSummaryData,
BaseSummaryParams,
@@ -30,7 +31,13 @@ from backend.data.notifications import (
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.data.user import generate_unsubscribe_link
from backend.notifications.email import EmailSender
from backend.util.service import AppService, expose, get_service_client
from backend.util.service import (
AppService,
AppServiceClient,
endpoint_to_async,
expose,
get_service_client,
)
from backend.util.settings import Settings

logger = logging.getLogger(__name__)
@@ -107,16 +114,16 @@ def create_notification_config() -> RabbitMQConfig:

@thread_cached
def get_scheduler():
from backend.executor import Scheduler
from backend.executor.scheduler import SchedulerClient

return get_service_client(Scheduler)
return get_service_client(SchedulerClient)


@thread_cached
def get_db():
from backend.executor.database import DatabaseManager
from backend.executor.database import DatabaseManagerClient

return get_service_client(DatabaseManager)
return get_service_client(DatabaseManagerClient)


class NotificationManager(AppService):
@@ -128,6 +135,20 @@ class NotificationManager(AppService):
self.running = True
self.email_sender = EmailSender()

@property
def rabbit(self) -> rabbitmq.AsyncRabbitMQ:
"""Access the RabbitMQ service. Will raise if not configured."""
if not self.rabbitmq_service:
raise RuntimeError("RabbitMQ not configured for this service")
return self.rabbitmq_service

@property
def rabbit_config(self) -> rabbitmq.RabbitMQConfig:
"""Access the RabbitMQ config. Will raise if not configured."""
if not self.rabbitmq_config:
raise RuntimeError("RabbitMQ not configured for this service")
return self.rabbitmq_config

@classmethod
def get_port(cls) -> int:
return settings.config.notification_service_port
@@ -688,10 +709,14 @@ class NotificationManager(AppService):
)

def run_service(self):
logger.info(f"[{self.service_name}] ⏳ Configuring RabbitMQ...")
self.rabbitmq_service = rabbitmq.AsyncRabbitMQ(self.rabbitmq_config)
self.run_and_wait(self.rabbitmq_service.connect())

logger.info(f"[{self.service_name}] Started notification service")

# Set up scheduler for batch processing of all notification types
# this can be changed later to spawn differnt cleanups on different schedules
# this can be changed later to spawn different cleanups on different schedules
try:
get_scheduler().add_batched_notification_schedule(
notification_types=list(NotificationType),
@@ -753,3 +778,16 @@ class NotificationManager(AppService):
"""Cleanup service resources"""
self.running = False
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Disconnecting RabbitMQ...")
self.run_and_wait(self.rabbitmq_service.disconnect())


class NotificationManagerClient(AppServiceClient):
@classmethod
def get_service_type(cls):
return NotificationManager

queue_notification_async = endpoint_to_async(NotificationManager.queue_notification)
queue_notification = NotificationManager.queue_notification
process_existing_batches = NotificationManager.process_existing_batches
queue_weekly_summary = NotificationManager.queue_weekly_summary
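Callers interact with the new client the same way get_notification_client() does earlier in this diff; a minimal sketch, assuming a notification event object built elsewhere:

from backend.notifications import NotificationManagerClient
from backend.util.service import get_service_client

def notify(event) -> None:
    # `event` is assumed to be a valid notification event model instance.
    client = get_service_client(NotificationManagerClient)
    client.queue_notification(event)  # sync endpoint; queue_notification_async also exists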
@@ -2,7 +2,6 @@ import logging
from collections import defaultdict
from typing import Annotated, Any, Dict, List, Optional, Sequence

from autogpt_libs.utils.cache import thread_cached
from fastapi import APIRouter, Body, Depends, HTTPException
from prisma.enums import AgentExecutionStatus, APIKeyPermission
from typing_extensions import TypedDict
@@ -13,17 +12,10 @@ from backend.data import graph as graph_db
from backend.data.api_key import APIKey
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.execution import NodeExecutionResult
from backend.executor import ExecutionManager
from backend.executor.utils import add_graph_execution_async
from backend.server.external.middleware import require_permission
from backend.util.service import get_service_client
from backend.util.settings import Settings


@thread_cached
def execution_manager_client() -> ExecutionManager:
return get_service_client(ExecutionManager)


settings = Settings()
logger = logging.getLogger(__name__)

@@ -98,20 +90,20 @@ def execute_graph_block(
path="/graphs/{graph_id}/execute/{graph_version}",
tags=["graphs"],
)
def execute_graph(
async def execute_graph(
graph_id: str,
graph_version: int,
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
api_key: APIKey = Depends(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
) -> dict[str, Any]:
try:
graph_exec = execution_manager_client().add_execution(
graph_id,
graph_version=graph_version,
data=node_input,
graph_exec = await add_graph_execution_async(
graph_id=graph_id,
user_id=api_key.user_id,
inputs=node_input,
graph_version=graph_version,
)
return {"id": graph_exec.graph_exec_id}
return {"id": graph_exec.id}
except Exception as e:
msg = str(e).encode().decode("unicode_escape")
raise HTTPException(status_code=400, detail=msg)
@@ -130,7 +122,7 @@ async def get_graph_execution_results(
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")

results = await execution_db.get_node_execution_results(graph_exec_id)
results = await execution_db.get_node_executions(graph_exec_id)
last_result = results[-1] if results else None
execution_status = (
last_result.status if last_result else AgentExecutionStatus.INCOMPLETE

@@ -1,5 +1,6 @@
import asyncio
import logging
from typing import TYPE_CHECKING, Annotated, Literal
from typing import TYPE_CHECKING, Annotated, Awaitable, Literal

from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
from pydantic import BaseModel, Field
@@ -14,13 +15,12 @@ from backend.data.integrations import (
wait_for_webhook_event,
)
from backend.data.model import Credentials, CredentialsType, OAuth2Credentials
from backend.executor.manager import ExecutionManager
from backend.executor.utils import add_graph_execution_async
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
from backend.util.exceptions import NeedConfirmation, NotFoundError
from backend.util.service import get_service_client
from backend.util.settings import Settings

if TYPE_CHECKING:
@@ -309,19 +309,22 @@ async def webhook_ingress_generic(
if not webhook.attached_nodes:
return

executor = get_service_client(ExecutionManager)
executions: list[Awaitable] = []
for node in webhook.attached_nodes:
logger.debug(f"Webhook-attached node: {node}")
if not node.is_triggered_by_event_type(event_type):
logger.debug(f"Node #{node.id} doesn't trigger on event {event_type}")
continue
logger.debug(f"Executing graph #{node.graph_id} node #{node.id}")
executor.add_execution(
graph_id=node.graph_id,
graph_version=node.graph_version,
data={f"webhook_{webhook_id}_payload": payload},
user_id=webhook.user_id,
executions.append(
add_graph_execution_async(
user_id=webhook.user_id,
graph_id=node.graph_id,
graph_version=node.graph_version,
inputs={f"webhook_{webhook_id}_payload": payload},
)
)
asyncio.gather(*executions)


@router.post("/webhooks/{webhook_id}/ping")
@@ -11,14 +11,15 @@ from autogpt_libs.feature_flag.client import (
|
||||
initialize_launchdarkly,
|
||||
shutdown_launchdarkly,
|
||||
)
|
||||
from autogpt_libs.logging.utils import generate_uvicorn_config
|
||||
|
||||
import backend.data.block
|
||||
import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.user
|
||||
import backend.server.integrations.router
|
||||
import backend.server.routers.postmark.postmark
|
||||
import backend.server.routers.v1
|
||||
import backend.server.v2.admin.credit_admin_routes
|
||||
import backend.server.v2.admin.store_admin_routes
|
||||
import backend.server.v2.library.db
|
||||
import backend.server.v2.library.model
|
||||
@@ -28,6 +29,7 @@ import backend.server.v2.store.model
|
||||
import backend.server.v2.store.routes
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.server.external.api import external_app
|
||||
@@ -56,8 +58,7 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
await backend.data.block.initialize_blocks()
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
await backend.data.graph.fix_llm_provider_credentials()
|
||||
# FIXME ERROR: operator does not exist: text ? unknown
|
||||
# await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
|
||||
await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
|
||||
with launch_darkly_context():
|
||||
yield
|
||||
await backend.data.db.disconnect()
|
||||
@@ -107,6 +108,11 @@ app.include_router(
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/store",
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.admin.credit_admin_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/credits",
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.library.routes.router, tags=["v2"], prefix="/api/library"
|
||||
)
|
||||
@@ -141,8 +147,13 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
server_app,
|
||||
host=backend.util.settings.Config().agent_api_host,
|
||||
port=backend.util.settings.Config().agent_api_port,
|
||||
log_config=generate_uvicorn_config(),
|
||||
)
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
logger.info(f"[{self.service_name}] ⏳ Shutting down Agent Server...")
|
||||
|
||||
@staticmethod
|
||||
async def test_execute_graph(
|
||||
graph_id: str,
|
||||
@@ -150,11 +161,12 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
graph_version: Optional[int] = None,
|
||||
node_input: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
return backend.server.routers.v1.execute_graph(
|
||||
return await backend.server.routers.v1.execute_graph(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
node_input=node_input or {},
|
||||
inputs=node_input or {},
|
||||
credentials_inputs={},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -269,7 +281,9 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
provider: ProviderName,
|
||||
credentials: Credentials,
|
||||
) -> Credentials:
|
||||
return backend.server.integrations.router.create_credentials(
|
||||
from backend.server.integrations.router import create_credentials
|
||||
|
||||
return create_credentials(
|
||||
user_id=user_id, provider=provider, credentials=credentials
|
||||
)
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
|
||||
from typing_extensions import Optional, TypedDict
|
||||
|
||||
import backend.data.block
|
||||
import backend.server.integrations.router
|
||||
import backend.server.routers.analytics
|
||||
import backend.server.v2.library.db as library_db
|
||||
@@ -31,7 +30,7 @@ from backend.data.api_key import (
|
||||
suspend_api_key,
|
||||
update_api_key_permissions,
|
||||
)
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
|
||||
from backend.data.credit import (
|
||||
AutoTopUpConfig,
|
||||
RefundRequest,
|
||||
@@ -41,6 +40,8 @@ from backend.data.credit import (
|
||||
get_user_credit_model,
|
||||
set_auto_top_up,
|
||||
)
|
||||
from backend.data.execution import AsyncRedisExecutionEventBus
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
from backend.data.onboarding import (
|
||||
UserOnboardingUpdate,
|
||||
@@ -49,13 +50,16 @@ from backend.data.onboarding import (
|
||||
onboarding_enabled,
|
||||
update_user_onboarding,
|
||||
)
|
||||
from backend.data.rabbitmq import AsyncRabbitMQ
|
||||
from backend.data.user import (
|
||||
get_or_create_user,
|
||||
get_user_notification_preference,
|
||||
update_user_email,
|
||||
update_user_notification_preference,
|
||||
)
|
||||
from backend.executor import ExecutionManager, Scheduler, scheduler
|
||||
from backend.executor import scheduler
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.executor.utils import create_execution_queue_config
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks.graph_lifecycle_hooks import (
|
||||
on_graph_activate,
|
||||
@@ -79,13 +83,20 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
@thread_cached
|
||||
def execution_manager_client() -> ExecutionManager:
|
||||
return get_service_client(ExecutionManager)
|
||||
def execution_scheduler_client() -> scheduler.SchedulerClient:
|
||||
return get_service_client(scheduler.SchedulerClient)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def execution_scheduler_client() -> Scheduler:
|
||||
return get_service_client(Scheduler)
|
||||
async def execution_queue_client() -> AsyncRabbitMQ:
|
||||
client = AsyncRabbitMQ(create_execution_queue_config())
|
||||
await client.connect()
|
||||
return client
|
||||
|
||||
|
||||
@thread_cached
|
||||
def execution_event_bus() -> AsyncRedisExecutionEventBus:
|
||||
return AsyncRedisExecutionEventBus()
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -206,7 +217,7 @@ async def is_onboarding_enabled():
|
||||
|
||||
@v1_router.get(path="/blocks", tags=["blocks"], dependencies=[Depends(auth_middleware)])
|
||||
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
blocks = [block() for block in backend.data.block.get_blocks().values()]
|
||||
blocks = [block() for block in get_blocks().values()]
|
||||
costs = get_block_costs()
|
||||
return [
|
||||
{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks if not b.disabled
|
||||
@@ -219,7 +230,7 @@ def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlockOutput:
|
||||
obj = backend.data.block.get_block(block_id)
|
||||
obj = get_block(block_id)
|
||||
if not obj:
|
||||
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
|
||||
|
||||
@@ -308,7 +319,7 @@ async def configure_user_auto_top_up(
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_user_auto_top_up(
|
||||
user_id: Annotated[str, Depends(get_user_id)]
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> AutoTopUpConfig:
|
||||
return await get_auto_top_up(user_id)
|
||||
|
||||
@@ -375,7 +386,7 @@ async def get_credit_history(
|
||||
|
||||
@v1_router.get(path="/credits/refunds", dependencies=[Depends(auth_middleware)])
|
||||
async def get_refund_requests(
|
||||
user_id: Annotated[str, Depends(get_user_id)]
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> list[RefundRequest]:
|
||||
return await _user_credit_model.get_refund_requests(user_id)
|
||||
|
||||
@@ -391,7 +402,7 @@ class DeleteGraphResponse(TypedDict):
|
||||
|
||||
@v1_router.get(path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)])
|
||||
async def get_graphs(
|
||||
user_id: Annotated[str, Depends(get_user_id)]
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> Sequence[graph_db.GraphModel]:
|
||||
return await graph_db.get_graphs(filter_by="active", user_id=user_id)
|
||||
|
||||
@@ -580,16 +591,25 @@ async def set_graph_active_version(
    tags=["graphs"],
    dependencies=[Depends(auth_middleware)],
)
def execute_graph(
async def execute_graph(
    graph_id: str,
    node_input: Annotated[dict[str, Any], Body(..., default_factory=dict)],
    user_id: Annotated[str, Depends(get_user_id)],
    inputs: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
    credentials_inputs: Annotated[
        dict[str, CredentialsMetaInput], Body(..., embed=True, default_factory=dict)
    ],
    graph_version: Optional[int] = None,
    preset_id: Optional[str] = None,
) -> ExecuteGraphResponse:
    graph_exec = execution_manager_client().add_execution(
        graph_id, node_input, user_id=user_id, graph_version=graph_version
    graph_exec = await execution_utils.add_graph_execution_async(
        graph_id=graph_id,
        user_id=user_id,
        inputs=inputs,
        preset_id=preset_id,
        graph_version=graph_version,
        graph_credentials_inputs=credentials_inputs,
    )
    return ExecuteGraphResponse(graph_exec_id=graph_exec.graph_exec_id)
    return ExecuteGraphResponse(graph_exec_id=graph_exec.id)
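For context, the updated endpoint now expects `inputs` and `credentials_inputs` embedded as separate keys in the request body (both are declared with `Body(..., embed=True)`). A minimal sketch of a call; the host, route, token, and payload values are hypothetical and only illustrate the new body shape:

import httpx

response = httpx.post(
    "http://localhost:8006/api/v1/graphs/my-graph-id/execute",  # hypothetical URL
    headers={"Authorization": "Bearer <token>"},
    json={
        # embed=True nests each parameter under its own key
        "inputs": {"topic": "hello world"},
        "credentials_inputs": {},
    },
)
print(response.json())  # e.g. {"graph_exec_id": "..."}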
@v1_router.post(

@@ -605,9 +625,7 @@ async def stop_graph_run(
    ):
        raise HTTPException(404, detail=f"Agent execution #{graph_exec_id} not found")

    await asyncio.to_thread(
        lambda: execution_manager_client().cancel_execution(graph_exec_id)
    )
    await _cancel_execution(graph_exec_id)

    # Retrieve & return canceled graph execution in its final state
    result = await execution_db.get_graph_execution(

@@ -621,6 +639,49 @@ async def stop_graph_run(
    return result


async def _cancel_execution(graph_exec_id: str):
    """
    Mechanism:
    1. Set the cancel event
    2. Graph executor's cancel handler thread detects the event, terminates workers,
       reinitializes worker pool, and returns.
    3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
    """
    queue_client = await execution_queue_client()
    await queue_client.publish_message(
        routing_key="",
        message=execution_utils.CancelExecutionEvent(
            graph_exec_id=graph_exec_id
        ).model_dump_json(),
        exchange=execution_utils.GRAPH_EXECUTION_CANCEL_EXCHANGE,
    )

    # Update the status of the graph & node executions
    await execution_db.update_graph_execution_stats(
        graph_exec_id,
        execution_db.ExecutionStatus.TERMINATED,
    )
    node_execs = [
        node_exec.model_copy(update={"status": execution_db.ExecutionStatus.TERMINATED})
        for node_exec in await execution_db.get_node_executions(
            graph_exec_id=graph_exec_id,
            statuses=[
                execution_db.ExecutionStatus.QUEUED,
                execution_db.ExecutionStatus.RUNNING,
                execution_db.ExecutionStatus.INCOMPLETE,
            ],
        )
    ]

    await execution_db.update_node_execution_status_batch(
        [node_exec.node_exec_id for node_exec in node_execs],
        execution_db.ExecutionStatus.TERMINATED,
    )
    await asyncio.gather(
        *[execution_event_bus().publish(node_exec) for node_exec in node_execs]
    )
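The cancel event published above fans out to every executor through a dedicated exchange. A sketch of the consuming side using plain aio-pika (the project wraps this in its own AsyncRabbitMQ helper; the exchange name and connection URL here are assumptions for illustration):

import asyncio
import json

import aio_pika


async def listen_for_cancellations() -> None:
    connection = await aio_pika.connect_robust("amqp://guest:guest@localhost/")
    async with connection:
        channel = await connection.channel()
        # Fan-out exchange: every executor instance receives every cancel event.
        exchange = await channel.declare_exchange(
            "graph_execution_cancel", aio_pika.ExchangeType.FANOUT
        )
        queue = await channel.declare_queue(exclusive=True)
        await queue.bind(exchange)
        async with queue.iterator() as messages:
            async for message in messages:
                async with message.process():
                    event = json.loads(message.body)
                    print(f"Cancel requested for {event['graph_exec_id']}")


asyncio.run(listen_for_cancellations())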
@v1_router.get(
    path="/executions",
    tags=["graphs"],

@@ -718,14 +779,12 @@ async def create_schedule(
            detail=f"Graph #{schedule.graph_id} v.{schedule.graph_version} not found.",
        )

    return await asyncio.to_thread(
        lambda: execution_scheduler_client().add_execution_schedule(
            graph_id=schedule.graph_id,
            graph_version=graph.version,
            cron=schedule.cron,
            input_data=schedule.input_data,
            user_id=user_id,
        )
    return await execution_scheduler_client().add_execution_schedule(
        graph_id=schedule.graph_id,
        graph_version=graph.version,
        cron=schedule.cron,
        input_data=schedule.input_data,
        user_id=user_id,
    )


@@ -734,11 +793,11 @@ async def create_schedule(
    tags=["schedules"],
    dependencies=[Depends(auth_middleware)],
)
def delete_schedule(
async def delete_schedule(
    schedule_id: str,
    user_id: Annotated[str, Depends(get_user_id)],
) -> dict[Any, Any]:
    execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
    await execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
    return {"id": schedule_id}


@@ -747,11 +806,11 @@ def delete_schedule(
    tags=["schedules"],
    dependencies=[Depends(auth_middleware)],
)
def get_execution_schedules(
async def get_execution_schedules(
    user_id: Annotated[str, Depends(get_user_id)],
    graph_id: str | None = None,
) -> list[scheduler.ExecutionJobInfo]:
    return execution_scheduler_client().get_execution_schedules(
    return await execution_scheduler_client().get_execution_schedules(
        user_id=user_id,
        graph_id=graph_id,
    )

@@ -792,7 +851,7 @@ async def create_api_key(
    dependencies=[Depends(auth_middleware)],
)
async def get_api_keys(
    user_id: Annotated[str, Depends(get_user_id)]
    user_id: Annotated[str, Depends(get_user_id)],
) -> list[APIKeyWithoutHash]:
    """List all API keys for the user"""
    try:
@@ -0,0 +1,77 @@
import logging
import typing

from autogpt_libs.auth import requires_admin_user
from autogpt_libs.auth.depends import get_user_id
from fastapi import APIRouter, Body, Depends
from prisma import Json
from prisma.enums import CreditTransactionType

from backend.data.credit import admin_get_user_history, get_user_credit_model
from backend.server.v2.admin.model import AddUserCreditsResponse, UserHistoryResponse

logger = logging.getLogger(__name__)

_user_credit_model = get_user_credit_model()


router = APIRouter(
    prefix="/admin",
    tags=["credits", "admin"],
    dependencies=[Depends(requires_admin_user)],
)


@router.post("/add_credits", response_model=AddUserCreditsResponse)
async def add_user_credits(
    user_id: typing.Annotated[str, Body()],
    amount: typing.Annotated[int, Body()],
    comments: typing.Annotated[str, Body()],
    admin_user: typing.Annotated[
        str,
        Depends(get_user_id),
    ],
):
""" """
    logger.info(f"Admin user {admin_user} is adding {amount} credits to user {user_id}")
    new_balance, transaction_key = await _user_credit_model._add_transaction(
        user_id,
        amount,
        transaction_type=CreditTransactionType.GRANT,
        metadata=Json({"admin_id": admin_user, "reason": comments}),
    )
    return {
        "new_balance": new_balance,
        "transaction_key": transaction_key,
    }

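Since the three scalar parameters are all declared with `Body()`, FastAPI expects them together as one JSON object. A hypothetical call (host and token are placeholders):

import httpx

resp = httpx.post(
    "http://localhost:8006/admin/add_credits",  # router prefix is "/admin"
    headers={"Authorization": "Bearer <admin-token>"},
    json={"user_id": "user-123", "amount": 500, "comments": "Goodwill grant"},
)
print(resp.json())  # e.g. {"new_balance": 1500, "transaction_key": "..."}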
@router.get(
    "/users_history",
    response_model=UserHistoryResponse,
)
async def admin_get_all_user_history(
    admin_user: typing.Annotated[
        str,
        Depends(get_user_id),
    ],
    search: typing.Optional[str] = None,
    page: int = 1,
    page_size: int = 20,
    transaction_filter: typing.Optional[CreditTransactionType] = None,
):
""" """
    logger.info(f"Admin user {admin_user} is getting grant history")

    try:
        resp = await admin_get_user_history(
            page=page,
            page_size=page_size,
            search=search,
            transaction_filter=transaction_filter,
        )
        logger.info(f"Admin user {admin_user} got {len(resp.history)} grant history")
        return resp
    except Exception as e:
        logger.exception(f"Error getting grant history: {e}")
        raise e
16
autogpt_platform/backend/backend/server/v2/admin/model.py
Normal file
@@ -0,0 +1,16 @@
from pydantic import BaseModel

from backend.data.model import UserTransaction
from backend.server.model import Pagination


class UserHistoryResponse(BaseModel):
    """Response model for listings with version history"""

    history: list[UserTransaction]
    pagination: Pagination


class AddUserCreditsResponse(BaseModel):
    new_balance: int
    transaction_key: str

@@ -1,4 +1,5 @@
import logging
import tempfile
import typing

import autogpt_libs.auth.depends

@@ -9,6 +10,7 @@ import prisma.enums
import backend.server.v2.store.db
import backend.server.v2.store.exceptions
import backend.server.v2.store.model
import backend.util.json

logger = logging.getLogger(__name__)

@@ -98,3 +100,47 @@ async def review_submission(
            status_code=500,
            content={"detail": "An error occurred while reviewing the submission"},
        )


@router.get(
    "/submissions/download/{store_listing_version_id}",
    tags=["store", "admin"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
)
async def admin_download_agent_file(
    user: typing.Annotated[
        autogpt_libs.auth.models.User,
        fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user),
    ],
    store_listing_version_id: str = fastapi.Path(
        ..., description="The ID of the agent to download"
    ),
) -> fastapi.responses.FileResponse:
    """
    Download the agent file by streaming its content.

    Args:
        store_listing_version_id (str): The ID of the agent to download

    Returns:
        FileResponse: A JSON file response containing the agent's graph data.

    Raises:
        HTTPException: If the agent is not found or an unexpected error occurs.
    """
    graph_data = await backend.server.v2.store.db.get_agent(
        user_id=user.user_id,
        store_listing_version_id=store_listing_version_id,
    )
    file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"

    # Sending graph as a stream (similar to marketplace v1)
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False
    ) as tmp_file:
        tmp_file.write(backend.util.json.dumps(graph_data))
        tmp_file.flush()

    return fastapi.responses.FileResponse(
        tmp_file.name, filename=file_name, media_type="application/json"
    )
@@ -13,12 +13,17 @@ import backend.server.v2.library.model as library_model
import backend.server.v2.store.exceptions as store_exceptions
import backend.server.v2.store.image_gen as store_image_gen
import backend.server.v2.store.media as store_media
from backend.data import db
from backend.data import graph as graph_db
from backend.data.db import locked_transaction
from backend.data.includes import library_agent_include
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
from backend.util.settings import Config

logger = logging.getLogger(__name__)
config = Config()
integration_creds_manager = IntegrationCredentialsManager()


async def list_library_agents(

@@ -68,12 +73,12 @@ async def list_library_agents(
    if search_term:
        where_clause["OR"] = [
            {
                "Agent": {
                "AgentGraph": {
                    "is": {"name": {"contains": search_term, "mode": "insensitive"}}
                }
            },
            {
                "Agent": {
                "AgentGraph": {
                    "is": {
                        "description": {"contains": search_term, "mode": "insensitive"}
                    }
@@ -206,7 +211,7 @@ async def add_generated_agent_image(
async def create_library_agent(
    graph: backend.data.graph.GraphModel,
    user_id: str,
) -> prisma.models.LibraryAgent:
) -> library_model.LibraryAgent:
    """
    Adds an agent to the user's library (LibraryAgent table).

@@ -227,18 +232,21 @@ async def create_library_agent(
    )

    try:
        return await prisma.models.LibraryAgent.prisma().create(
            data={
                "isCreatedByUser": (user_id == graph.user_id),
                "useGraphIsActiveVersion": True,
                "User": {"connect": {"id": user_id}},
                "Agent": {
        agent = await prisma.models.LibraryAgent.prisma().create(
            data=prisma.types.LibraryAgentCreateInput(
                isCreatedByUser=(user_id == graph.user_id),
                useGraphIsActiveVersion=True,
                User={"connect": {"id": user_id}},
                # Creator={"connect": {"id": agent.userId}},
                AgentGraph={
                    "connect": {
                        "graphVersionId": {"id": graph.id, "version": graph.version}
                    }
                },
            }
            ),
            include={"AgentGraph": True},
        )
        return library_model.LibraryAgent.from_db(agent)
    except prisma.errors.PrismaError as e:
        logger.error(f"Database error creating agent in library: {e}")
        raise store_exceptions.DatabaseError("Failed to create agent in library") from e
@@ -246,38 +254,41 @@ async def create_library_agent(

async def update_agent_version_in_library(
    user_id: str,
    agent_id: str,
    agent_version: int,
    agent_graph_id: str,
    agent_graph_version: int,
) -> None:
    """
    Updates the agent version in the library if useGraphIsActiveVersion is True.

    Args:
        user_id: Owner of the LibraryAgent.
        agent_id: The agent's ID to update.
        agent_version: The new version of the agent.
        agent_graph_id: The agent graph's ID to update.
        agent_graph_version: The new version of the agent graph.

    Raises:
        DatabaseError: If there's an error with the update.
    """
    logger.debug(
        f"Updating agent version in library for user #{user_id}, "
        f"agent #{agent_id} v{agent_version}"
        f"agent #{agent_graph_id} v{agent_graph_version}"
    )
    try:
        library_agent = await prisma.models.LibraryAgent.prisma().find_first_or_raise(
            where={
                "userId": user_id,
                "agentId": agent_id,
                "agentGraphId": agent_graph_id,
                "useGraphIsActiveVersion": True,
            },
        )
        await prisma.models.LibraryAgent.prisma().update(
            where={"id": library_agent.id},
            data={
                "Agent": {
                "AgentGraph": {
                    "connect": {
                        "graphVersionId": {"id": agent_id, "version": agent_version}
                        "graphVersionId": {
                            "id": agent_graph_id,
                            "version": agent_graph_version,
                        }
                    },
                },
            },

@@ -341,7 +352,7 @@ async def delete_library_agent_by_graph_id(graph_id: str, user_id: str) -> None:
    """
    try:
        await prisma.models.LibraryAgent.prisma().delete_many(
            where={"agentId": graph_id, "userId": user_id}
            where={"agentGraphId": graph_id, "userId": user_id}
        )
    except prisma.errors.PrismaError as e:
        logger.error(f"Database error deleting library agent: {e}")
@@ -374,10 +385,10 @@ async def add_store_agent_to_library(
    async with locked_transaction(f"add_agent_trx_{user_id}"):
        store_listing_version = (
            await prisma.models.StoreListingVersion.prisma().find_unique(
                where={"id": store_listing_version_id}, include={"Agent": True}
                where={"id": store_listing_version_id}, include={"AgentGraph": True}
            )
        )
        if not store_listing_version or not store_listing_version.Agent:
        if not store_listing_version or not store_listing_version.AgentGraph:
            logger.warning(
                f"Store listing version not found: {store_listing_version_id}"
            )

@@ -385,7 +396,7 @@ async def add_store_agent_to_library(
                f"Store listing version {store_listing_version_id} not found or invalid"
            )

        graph = store_listing_version.Agent
        graph = store_listing_version.AgentGraph
        if graph.userId == user_id:
            logger.warning(
                f"User #{user_id} attempted to add their own agent to their library"

@@ -397,8 +408,8 @@ async def add_store_agent_to_library(
            await prisma.models.LibraryAgent.prisma().find_first(
                where={
                    "userId": user_id,
                    "agentId": graph.id,
                    "agentVersion": graph.version,
                    "agentGraphId": graph.id,
                    "agentGraphVersion": graph.version,
                },
                include=library_agent_include(user_id),
            )

@@ -418,17 +429,17 @@ async def add_store_agent_to_library(

        # Create LibraryAgent entry
        added_agent = await prisma.models.LibraryAgent.prisma().create(
            data={
                "userId": user_id,
                "agentId": graph.id,
                "agentVersion": graph.version,
                "isCreatedByUser": False,
            },
            data=prisma.types.LibraryAgentCreateInput(
                userId=user_id,
                agentGraphId=graph.id,
                agentGraphVersion=graph.version,
                isCreatedByUser=False,
            ),
            include=library_agent_include(user_id),
        )
        logger.debug(
            f"Added graph #{graph.id} "
            f"for store listing #{store_listing_version.id} "
            f"Added graph #{graph.id} v{graph.version}"
            f"for store listing version #{store_listing_version.id} "
            f"to library for user #{user_id}"
        )
        return library_model.LibraryAgent.from_db(added_agent)

@@ -467,8 +478,8 @@ async def set_is_deleted_for_library_agent(
    count = await prisma.models.LibraryAgent.prisma().update_many(
        where={
            "userId": user_id,
            "agentId": agent_id,
            "agentVersion": agent_version,
            "agentGraphId": agent_id,
            "agentGraphVersion": agent_version,
        },
        data={"isDeleted": is_deleted},
    )
@@ -597,6 +608,12 @@ async def upsert_preset(
        f"Upserting preset #{preset_id} ({repr(preset.name)}) for user #{user_id}",
    )
    try:
        inputs = [
            prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput(
                name=name, data=prisma.fields.Json(data)
            )
            for name, data in preset.inputs.items()
        ]
        if preset_id:
            # Update existing preset
            updated = await prisma.models.AgentPreset.prisma().update(

@@ -605,12 +622,7 @@ async def upsert_preset(
                    "name": preset.name,
                    "description": preset.description,
                    "isActive": preset.is_active,
                    "InputPresets": {
                        "create": [
                            {"name": name, "data": prisma.fields.Json(data)}
                            for name, data in preset.inputs.items()
                        ]
                    },
                    "InputPresets": {"create": inputs},
                },
                include={"InputPresets": True},
            )

@@ -620,20 +632,15 @@ async def upsert_preset(
        else:
            # Create new preset
            new_preset = await prisma.models.AgentPreset.prisma().create(
                data={
                    "userId": user_id,
                    "name": preset.name,
                    "description": preset.description,
                    "agentId": preset.agent_id,
                    "agentVersion": preset.agent_version,
                    "isActive": preset.is_active,
                    "InputPresets": {
                        "create": [
                            {"name": name, "data": prisma.fields.Json(data)}
                            for name, data in preset.inputs.items()
                        ]
                    },
                },
                data=prisma.types.AgentPresetCreateInput(
                    userId=user_id,
                    name=preset.name,
                    description=preset.description,
                    agentGraphId=preset.graph_id,
                    agentGraphVersion=preset.graph_version,
                    isActive=preset.is_active,
                    InputPresets={"create": inputs},
                ),
                include={"InputPresets": True},
            )
        return library_model.LibraryAgentPreset.from_db(new_preset)
@@ -662,3 +669,47 @@ async def delete_preset(user_id: str, preset_id: str) -> None:
    except prisma.errors.PrismaError as e:
        logger.error(f"Database error deleting preset: {e}")
        raise store_exceptions.DatabaseError("Failed to delete preset") from e


async def fork_library_agent(library_agent_id: str, user_id: str):
    """
    Clones a library agent and its underlying graph and nodes (with new ids) for the given user.

    Args:
        library_agent_id: The ID of the library agent to fork.
        user_id: The ID of the user who owns the library agent.

    Returns:
        The forked LibraryAgent.

    Raises:
        DatabaseError: If there's an error during the forking process.
    """
    logger.debug(f"Forking library agent {library_agent_id} for user {user_id}")
    try:
        async with db.locked_transaction(f"usr_trx_{user_id}-fork_agent"):
            # Fetch the original agent
            original_agent = await get_library_agent(library_agent_id, user_id)

            # Check if user owns the library agent
            # TODO: once we have open/closed sourced agents this needs to be enabled ~kcze
            # + update library/agents/[id]/page.tsx agent actions
            # if not original_agent.can_access_graph:
            #     raise store_exceptions.DatabaseError(
            #         f"User {user_id} cannot access library agent graph {library_agent_id}"
            #     )

            # Fork the underlying graph and nodes
            new_graph = await graph_db.fork_graph(
                original_agent.graph_id, original_agent.graph_version, user_id
            )
            new_graph = await on_graph_activate(
                new_graph,
                get_credentials=lambda id: integration_creds_manager.get(user_id, id),
            )

            # Create a library agent for the new graph
            return await create_library_agent(new_graph, user_id)
    except prisma.errors.PrismaError as e:
        logger.error(f"Database error cloning library agent: {e}")
        raise store_exceptions.DatabaseError("Failed to fork library agent") from e
@@ -30,8 +30,8 @@ async def test_get_library_agents(mocker):
        prisma.models.LibraryAgent(
            id="ua1",
            userId="test-user",
            agentId="agent2",
            agentVersion=1,
            agentGraphId="agent2",
            agentGraphVersion=1,
            isCreatedByUser=False,
            isDeleted=False,
            isArchived=False,

@@ -39,7 +39,7 @@ async def test_get_library_agents(mocker):
            updatedAt=datetime.now(),
            isFavorite=False,
            useGraphIsActiveVersion=True,
            Agent=prisma.models.AgentGraph(
            AgentGraph=prisma.models.AgentGraph(
                id="agent2",
                version=1,
                name="Test Agent 2",

@@ -71,8 +71,8 @@ async def test_get_library_agents(mocker):
    assert result.agents[0].id == "ua1"
    assert result.agents[0].name == "Test Agent 2"
    assert result.agents[0].description == "Test Description 2"
    assert result.agents[0].agent_id == "agent2"
    assert result.agents[0].agent_version == 1
    assert result.agents[0].graph_id == "agent2"
    assert result.agents[0].graph_version == 1
    assert result.agents[0].can_access_graph is False
    assert result.agents[0].is_latest_version is True
    assert result.pagination.total_items == 1

@@ -81,7 +81,7 @@ async def test_get_library_agents(mocker):
    assert result.pagination.page_size == 50


@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_add_agent_to_library(mocker):
    await connect()
    # Mock data

@@ -90,8 +90,8 @@ async def test_add_agent_to_library(mocker):
        version=1,
        createdAt=datetime.now(),
        updatedAt=datetime.now(),
        agentId="agent1",
        agentVersion=1,
        agentGraphId="agent1",
        agentGraphVersion=1,
        name="Test Agent",
        subHeading="Test Agent Subheading",
        imageUrls=["https://example.com/image.jpg"],

@@ -102,7 +102,7 @@ async def test_add_agent_to_library(mocker):
        isAvailable=True,
        storeListingId="listing123",
        submissionStatus=prisma.enums.SubmissionStatus.APPROVED,
        Agent=prisma.models.AgentGraph(
        AgentGraph=prisma.models.AgentGraph(
            id="agent1",
            version=1,
            name="Test Agent",

@@ -116,8 +116,8 @@ async def test_add_agent_to_library(mocker):
    mock_library_agent_data = prisma.models.LibraryAgent(
        id="ua1",
        userId="test-user",
        agentId=mock_store_listing_data.agentId,
        agentVersion=1,
        agentGraphId=mock_store_listing_data.agentGraphId,
        agentGraphVersion=1,
        isCreatedByUser=False,
        isDeleted=False,
        isArchived=False,

@@ -125,7 +125,7 @@ async def test_add_agent_to_library(mocker):
        updatedAt=datetime.now(),
        isFavorite=False,
        useGraphIsActiveVersion=True,
        Agent=mock_store_listing_data.Agent,
        AgentGraph=mock_store_listing_data.AgentGraph,
    )

    # Mock prisma calls

@@ -147,25 +147,28 @@ async def test_add_agent_to_library(mocker):

    # Verify mocks called correctly
    mock_store_listing_version.return_value.find_unique.assert_called_once_with(
        where={"id": "version123"}, include={"Agent": True}
        where={"id": "version123"}, include={"AgentGraph": True}
    )
    mock_library_agent.return_value.find_first.assert_called_once_with(
        where={
            "userId": "test-user",
            "agentId": "agent1",
            "agentVersion": 1,
            "agentGraphId": "agent1",
            "agentGraphVersion": 1,
        },
        include=library_agent_include("test-user"),
    )
    mock_library_agent.return_value.create.assert_called_once_with(
        data=prisma.types.LibraryAgentCreateInput(
            userId="test-user", agentId="agent1", agentVersion=1, isCreatedByUser=False
            userId="test-user",
            agentGraphId="agent1",
            agentGraphVersion=1,
            isCreatedByUser=False,
        ),
        include=library_agent_include("test-user"),
    )


@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_add_agent_to_library_not_found(mocker):
    await connect()
    # Mock prisma calls

@@ -182,5 +185,5 @@ async def test_add_agent_to_library_not_found(mocker):

    # Verify mock called correctly
    mock_store_listing_version.return_value.find_unique.assert_called_once_with(
        where={"id": "version123"}, include={"Agent": True}
        where={"id": "version123"}, include={"AgentGraph": True}
    )
@@ -25,8 +25,8 @@ class LibraryAgent(pydantic.BaseModel):
    """

    id: str
    agent_id: str
    agent_version: int
    graph_id: str
    graph_version: int

    image_url: str | None

@@ -58,12 +58,12 @@ class LibraryAgent(pydantic.BaseModel):
        Factory method that constructs a LibraryAgent from a Prisma LibraryAgent
        model instance.
        """
        if not agent.Agent:
        if not agent.AgentGraph:
            raise ValueError("Associated Agent record is required.")

        graph = graph_model.GraphModel.from_db(agent.Agent)
        graph = graph_model.GraphModel.from_db(agent.AgentGraph)

        agent_updated_at = agent.Agent.updatedAt
        agent_updated_at = agent.AgentGraph.updatedAt
        lib_agent_updated_at = agent.updatedAt

        # Compute updated_at as the latest between library agent and graph

@@ -83,21 +83,21 @@ class LibraryAgent(pydantic.BaseModel):
        week_ago = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
            days=7
        )
        executions = agent.Agent.AgentGraphExecution or []
        executions = agent.AgentGraph.Executions or []
        status_result = _calculate_agent_status(executions, week_ago)
        status = status_result.status
        new_output = status_result.new_output

        # Check if user can access the graph
        can_access_graph = agent.Agent.userId == agent.userId
        can_access_graph = agent.AgentGraph.userId == agent.userId

        # Hard-coded to True until a method to check is implemented
        is_latest_version = True

        return LibraryAgent(
            id=agent.id,
            agent_id=agent.agentId,
            agent_version=agent.agentVersion,
            graph_id=agent.agentGraphId,
            graph_version=agent.agentGraphVersion,
            image_url=agent.imageUrl,
            creator_name=creator_name,
            creator_image_url=creator_image_url,

@@ -174,8 +174,8 @@ class LibraryAgentPreset(pydantic.BaseModel):
    id: str
    updated_at: datetime.datetime

    agent_id: str
    agent_version: int
    graph_id: str
    graph_version: int

    name: str
    description: str

@@ -194,8 +194,8 @@ class LibraryAgentPreset(pydantic.BaseModel):
        return cls(
            id=preset.id,
            updated_at=preset.updatedAt,
            agent_id=preset.agentId,
            agent_version=preset.agentVersion,
            graph_id=preset.agentGraphId,
            graph_version=preset.agentGraphVersion,
            name=preset.name,
            description=preset.description,
            is_active=preset.isActive,

@@ -218,8 +218,8 @@ class CreateLibraryAgentPresetRequest(pydantic.BaseModel):
    name: str
    description: str
    inputs: block_model.BlockInput
    agent_id: str
    agent_version: int
    graph_id: str
    graph_version: int
    is_active: bool
@@ -5,7 +5,6 @@ import prisma.models
import pytest

import backend.server.v2.library.model as library_model
from backend.util import json


@pytest.mark.asyncio

@@ -15,19 +14,21 @@ async def test_agent_preset_from_db():
        id="test-agent-123",
        createdAt=datetime.datetime.now(),
        updatedAt=datetime.datetime.now(),
        agentId="agent-123",
        agentVersion=1,
        agentGraphId="agent-123",
        agentGraphVersion=1,
        name="Test Agent",
        description="Test agent description",
        isActive=True,
        userId="test-user-123",
        isDeleted=False,
        InputPresets=[
            prisma.models.AgentNodeExecutionInputOutput(
                id="input-123",
                time=datetime.datetime.now(),
                name="input1",
                data=json.dumps({"type": "string", "value": "test value"}),  # type: ignore
            prisma.models.AgentNodeExecutionInputOutput.model_validate(
                {
                    "id": "input-123",
                    "time": datetime.datetime.now(),
                    "name": "input1",
                    "data": '{"type": "string", "value": "test value"}',
                }
            )
        ],
    )

@@ -36,7 +37,7 @@ async def test_agent_preset_from_db():
    agent = library_model.LibraryAgentPreset.from_db(db_agent)

    assert agent.id == "test-agent-123"
    assert agent.agent_version == 1
    assert agent.graph_version == 1
    assert agent.is_active is True
    assert agent.name == "Test Agent"
    assert agent.description == "Test agent description"
@@ -190,3 +190,14 @@ async def update_library_agent(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to update library agent",
        ) from e


@router.post("/{library_agent_id}/fork")
async def fork_library_agent(
    library_agent_id: str,
    user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> library_model.LibraryAgent:
    return await library_db.fork_library_agent(
        library_agent_id=library_agent_id,
        user_id=user_id,
    )
@@ -2,25 +2,17 @@ import logging
from typing import Annotated, Any

import autogpt_libs.auth as autogpt_auth_lib
import autogpt_libs.utils.cache
from fastapi import APIRouter, Body, Depends, HTTPException, status

import backend.executor
import backend.server.v2.library.db as db
import backend.server.v2.library.model as models
import backend.util.service
from backend.executor.utils import add_graph_execution_async

logger = logging.getLogger(__name__)

router = APIRouter()


@autogpt_libs.utils.cache.thread_cached
def execution_manager_client() -> backend.executor.ExecutionManager:
    """Return a cached instance of ExecutionManager client."""
    return backend.util.service.get_service_client(backend.executor.ExecutionManager)


@router.get(
    "/presets",
    summary="List presets",

@@ -226,17 +218,17 @@ async def execute_preset(
        # Merge input overrides with preset inputs
        merged_node_input = preset.inputs | node_input

        execution = execution_manager_client().add_execution(
        execution = await add_graph_execution_async(
            graph_id=graph_id,
            graph_version=graph_version,
            data=merged_node_input,
            user_id=user_id,
            inputs=merged_node_input,
            preset_id=preset_id,
            graph_version=graph_version,
        )

        logger.debug(f"Execution added: {execution} with input: {merged_node_input}")

        return {"id": execution.graph_exec_id}
        return {"id": execution.id}
    except HTTPException:
        raise
    except Exception as e:
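The `preset.inputs | node_input` merge above relies on dict-union semantics: on key conflicts the right-hand operand wins, so caller-supplied overrides take precedence over stored preset values. A standalone illustration with made-up values:

preset_inputs = {"query": "weekly report", "limit": 10}
node_input = {"limit": 25}

merged = preset_inputs | node_input
assert merged == {"query": "weekly report", "limit": 25}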
@@ -35,8 +35,8 @@ async def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
        agents=[
            library_model.LibraryAgent(
                id="test-agent-1",
                agent_id="test-agent-1",
                agent_version=1,
                graph_id="test-agent-1",
                graph_version=1,
                name="Test Agent 1",
                description="Test Description 1",
                image_url=None,

@@ -51,8 +51,8 @@ async def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
            ),
            library_model.LibraryAgent(
                id="test-agent-2",
                agent_id="test-agent-2",
                agent_version=1,
                graph_id="test-agent-2",
                graph_version=1,
                name="Test Agent 2",
                description="Test Description 2",
                image_url=None,

@@ -78,9 +78,9 @@ async def test_get_library_agents_success(mocker: pytest_mock.MockFixture):

    data = library_model.LibraryAgentResponse.model_validate(response.json())
    assert len(data.agents) == 2
    assert data.agents[0].agent_id == "test-agent-1"
    assert data.agents[0].graph_id == "test-agent-1"
    assert data.agents[0].can_access_graph is True
    assert data.agents[1].agent_id == "test-agent-2"
    assert data.agents[1].graph_id == "test-agent-2"
    assert data.agents[1].can_access_graph is False
    mock_db_call.assert_called_once_with(
        user_id="test-user-id",
|
||||
"isAvailable": True,
|
||||
"isDeleted": False,
|
||||
},
|
||||
include={"Agent": {"include": {"AgentNodes": True}}},
|
||||
include={"AgentGraph": {"include": {"Nodes": True}}},
|
||||
)
|
||||
)
|
||||
|
||||
if not store_listing_version or not store_listing_version.Agent:
|
||||
if not store_listing_version or not store_listing_version.AgentGraph:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Store listing version {store_listing_version_id} not found",
|
||||
)
|
||||
|
||||
graph = GraphModel.from_db(store_listing_version.Agent)
|
||||
graph = GraphModel.from_db(store_listing_version.AgentGraph)
|
||||
# We return graph meta, without nodes, they cannot be just removed
|
||||
# because then input_schema would be empty
|
||||
return {
|
||||
@@ -516,7 +516,7 @@ async def delete_store_submission(
|
||||
try:
|
||||
# Verify the submission belongs to this user
|
||||
submission = await prisma.models.StoreListing.prisma().find_first(
|
||||
where={"agentId": submission_id, "owningUserId": user_id}
|
||||
where={"agentGraphId": submission_id, "owningUserId": user_id}
|
||||
)
|
||||
|
||||
if not submission:
|
||||
@@ -598,7 +598,7 @@ async def create_store_submission(
|
||||
# Check if listing already exists for this agent
|
||||
existing_listing = await prisma.models.StoreListing.prisma().find_first(
|
||||
where=prisma.types.StoreListingWhereInput(
|
||||
agentId=agent_id, owningUserId=user_id
|
||||
agentGraphId=agent_id, owningUserId=user_id
|
||||
)
|
||||
)
|
||||
|
||||
@@ -625,15 +625,15 @@ async def create_store_submission(
|
||||
# If no existing listing, create a new one
|
||||
data = prisma.types.StoreListingCreateInput(
|
||||
slug=slug,
|
||||
agentId=agent_id,
|
||||
agentVersion=agent_version,
|
||||
agentGraphId=agent_id,
|
||||
agentGraphVersion=agent_version,
|
||||
owningUserId=user_id,
|
||||
createdAt=datetime.now(tz=timezone.utc),
|
||||
Versions={
|
||||
"create": [
|
||||
prisma.types.StoreListingVersionCreateInput(
|
||||
agentId=agent_id,
|
||||
agentVersion=agent_version,
|
||||
agentGraphId=agent_id,
|
||||
agentGraphVersion=agent_version,
|
||||
name=name,
|
||||
videoUrl=video_url,
|
||||
imageUrls=image_urls,
|
||||
@@ -758,8 +758,8 @@ async def create_store_version(
|
||||
new_version = await prisma.models.StoreListingVersion.prisma().create(
|
||||
data=prisma.types.StoreListingVersionCreateInput(
|
||||
version=next_version,
|
||||
agentId=agent_id,
|
||||
agentVersion=agent_version,
|
||||
agentGraphId=agent_id,
|
||||
agentGraphVersion=agent_version,
|
||||
name=name,
|
||||
videoUrl=video_url,
|
||||
imageUrls=image_urls,
|
||||
@@ -793,6 +793,7 @@ async def create_store_version(
|
||||
changes_summary=changes_summary,
|
||||
version=next_version,
|
||||
)
|
||||
|
||||
except prisma.errors.PrismaError as e:
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to create new store version"
|
||||
@@ -959,17 +960,17 @@ async def get_my_agents(
    try:
        search_filter: prisma.types.LibraryAgentWhereInput = {
            "userId": user_id,
            "Agent": {"is": {"StoreListing": {"none": {"isDeleted": False}}}},
            "AgentGraph": {"is": {"StoreListings": {"none": {"isDeleted": False}}}},
            "isArchived": False,
            "isDeleted": False,
        }

        library_agents = await prisma.models.LibraryAgent.prisma().find_many(
            where=search_filter,
            order=[{"agentVersion": "desc"}],
            order=[{"agentGraphVersion": "desc"}],
            skip=(page - 1) * page_size,
            take=page_size,
            include={"Agent": True},
            include={"AgentGraph": True},
        )

        total = await prisma.models.LibraryAgent.prisma().count(where=search_filter)

@@ -985,7 +986,7 @@ async def get_my_agents(
                agent_image=library_agent.imageUrl,
            )
            for library_agent in library_agents
            if (graph := library_agent.Agent)
            if (graph := library_agent.AgentGraph)
        ]

        return backend.server.v2.store.model.MyAgentsResponse(

@@ -1020,13 +1021,13 @@ async def get_agent(

    graph = await backend.data.graph.get_graph(
        user_id=user_id,
        graph_id=store_listing_version.agentId,
        version=store_listing_version.agentVersion,
        graph_id=store_listing_version.agentGraphId,
        version=store_listing_version.agentGraphVersion,
        for_export=True,
    )
    if not graph:
        raise ValueError(
            f"Agent {store_listing_version.agentId} v{store_listing_version.agentVersion} not found"
            f"Agent {store_listing_version.agentGraphId} v{store_listing_version.agentGraphVersion} not found"
        )

    return graph
@@ -1050,11 +1051,14 @@ async def _get_missing_sub_store_listing(

    # Fetch all the sub-graphs that are listed, and return the ones missing.
    store_listed_sub_graphs = {
        (listing.agentId, listing.agentVersion)
        (listing.agentGraphId, listing.agentGraphVersion)
        for listing in await prisma.models.StoreListingVersion.prisma().find_many(
            where={
                "OR": [
                    {"agentId": sub_graph.id, "agentVersion": sub_graph.version}
                    {
                        "agentGraphId": sub_graph.id,
                        "agentGraphVersion": sub_graph.version,
                    }
                    for sub_graph in sub_graphs
                ],
                "submissionStatus": prisma.enums.SubmissionStatus.APPROVED,

@@ -1084,7 +1088,7 @@ async def review_store_submission(
            where={"id": store_listing_version_id},
            include={
                "StoreListing": True,
                "Agent": {"include": AGENT_GRAPH_INCLUDE},  # type: ignore
                "AgentGraph": {"include": AGENT_GRAPH_INCLUDE},
            },
        )
    )

@@ -1096,23 +1100,23 @@ async def review_store_submission(
            )

        # If approving, update the listing to indicate it has an approved version
        if is_approved and store_listing_version.Agent:
            heading = f"Sub-graph of {store_listing_version.name}v{store_listing_version.agentVersion}"
        if is_approved and store_listing_version.AgentGraph:
            heading = f"Sub-graph of {store_listing_version.name}v{store_listing_version.agentGraphVersion}"

            sub_store_listing_versions = [
                prisma.types.StoreListingVersionCreateWithoutRelationsInput(
                    agentId=sub_graph.id,
                    agentVersion=sub_graph.version,
                    agentGraphId=sub_graph.id,
                    agentGraphVersion=sub_graph.version,
                    name=sub_graph.name or heading,
                    submissionStatus=prisma.enums.SubmissionStatus.APPROVED,
                    subHeading=heading,
                    description=f"{heading}: {sub_graph.description}",
                    changesSummary=f"This listing is added as a {heading} / #{store_listing_version.agentId}.",
                    changesSummary=f"This listing is added as a {heading} / #{store_listing_version.agentGraphId}.",
                    isAvailable=False,  # Hide sub-graphs from the store by default.
                    submittedAt=datetime.now(tz=timezone.utc),
                )
                for sub_graph in await _get_missing_sub_store_listing(
                    store_listing_version.Agent
                    store_listing_version.AgentGraph
                )
            ]

@@ -1155,8 +1159,8 @@ async def review_store_submission(

        # Convert to Pydantic model for consistency
        return backend.server.v2.store.model.StoreSubmission(
            agent_id=submission.agentId,
            agent_version=submission.agentVersion,
            agent_id=submission.agentGraphId,
            agent_version=submission.agentGraphVersion,
            name=submission.name,
            sub_heading=submission.subHeading,
            slug=(
@@ -1294,8 +1298,8 @@ async def get_admin_listings_with_versions(
            # If we have versions, turn them into StoreSubmission models
            for version in listing.Versions or []:
                version_model = backend.server.v2.store.model.StoreSubmission(
                    agent_id=version.agentId,
                    agent_version=version.agentVersion,
                    agent_id=version.agentGraphId,
                    agent_version=version.agentGraphVersion,
                    name=version.name,
                    sub_heading=version.subHeading,
                    slug=listing.slug,

@@ -1324,8 +1328,8 @@ async def get_admin_listings_with_versions(
                backend.server.v2.store.model.StoreListingWithVersions(
                    listing_id=listing.id,
                    slug=listing.slug,
                    agent_id=listing.agentId,
                    agent_version=listing.agentVersion,
                    agent_id=listing.agentGraphId,
                    agent_version=listing.agentGraphVersion,
                    active_version_id=listing.activeVersionId,
                    has_approved_version=listing.hasApprovedVersion,
                    creator_email=creator_email,

@@ -1358,3 +1362,31 @@ async def get_admin_listings_with_versions(
                page_size=page_size,
            ),
        )


async def get_agent_as_admin(
    user_id: str | None,
    store_listing_version_id: str,
) -> GraphModel:
    """Get agent using the version ID and store listing version ID."""
    store_listing_version = (
        await prisma.models.StoreListingVersion.prisma().find_unique(
            where={"id": store_listing_version_id}
        )
    )

    if not store_listing_version:
        raise ValueError(f"Store listing version {store_listing_version_id} not found")

    graph = await backend.data.graph.get_graph_as_admin(
        user_id=user_id,
        graph_id=store_listing_version.agentGraphId,
        version=store_listing_version.agentGraphVersion,
        for_export=True,
    )
    if not graph:
        raise ValueError(
            f"Agent {store_listing_version.agentGraphId} v{store_listing_version.agentGraphVersion} not found"
        )

    return graph

@@ -170,14 +170,14 @@ async def test_create_store_submission(mocker):
        isDeleted=False,
        hasApprovedVersion=False,
        slug="test-agent",
        agentId="agent-id",
        agentVersion=1,
        agentGraphId="agent-id",
        agentGraphVersion=1,
        owningUserId="user-id",
        Versions=[
            prisma.models.StoreListingVersion(
                id="version-id",
                agentId="agent-id",
                agentVersion=1,
                agentGraphId="agent-id",
                agentGraphVersion=1,
                name="Test Agent",
                description="Test description",
                createdAt=datetime.now(),
@@ -4,20 +4,7 @@ from typing import List
import prisma.enums
import pydantic


class Pagination(pydantic.BaseModel):
    total_items: int = pydantic.Field(
        description="Total number of items.", examples=[42]
    )
    total_pages: int = pydantic.Field(
        description="Total number of pages.", examples=[97]
    )
    current_page: int = pydantic.Field(
        description="Current_page page number.", examples=[1]
    )
    page_size: int = pydantic.Field(
        description="Number of items per page.", examples=[25]
    )
from backend.server.model import Pagination


class MyAgent(pydantic.BaseModel):
@@ -5,11 +5,10 @@ from typing import Protocol

import uvicorn
from autogpt_libs.auth import parse_jwt_token
from autogpt_libs.utils.cache import thread_cached
from autogpt_libs.logging.utils import generate_uvicorn_config
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
from starlette.middleware.cors import CORSMiddleware

from backend.data import redis
from backend.data.execution import AsyncRedisExecutionEventBus
from backend.data.user import DEFAULT_USER_ID
from backend.server.conn_manager import ConnectionManager

@@ -19,7 +18,7 @@ from backend.server.model import (
    WSSubscribeGraphExecutionRequest,
    WSSubscribeGraphExecutionsRequest,
)
from backend.util.service import AppProcess, get_service_client
from backend.util.service import AppProcess
from backend.util.settings import AppEnvironment, Config, Settings

logger = logging.getLogger(__name__)

@@ -46,24 +45,14 @@ def get_connection_manager():
    return _connection_manager


@thread_cached
def get_db_client():
    from backend.executor import DatabaseManager

    return get_service_client(DatabaseManager)


async def event_broadcaster(manager: ConnectionManager):
    try:
        redis.connect()
        event_queue = AsyncRedisExecutionEventBus()
        async for event in event_queue.listen("*"):
            await manager.send_execution_update(event)
    except Exception as e:
        logger.exception(f"Event broadcaster error: {e}")
        raise
    finally:
        redis.disconnect()


async def authenticate_websocket(websocket: WebSocket) -> str:

@@ -286,8 +275,14 @@ class WebsocketServer(AppProcess):
            allow_methods=["*"],
            allow_headers=["*"],
        )

        uvicorn.run(
            server_app,
            host=Config().websocket_server_host,
            port=Config().websocket_server_port,
            log_config=generate_uvicorn_config(),
        )

    def cleanup(self):
        super().cleanup()
        logger.info(f"[{self.service_name}] ⏳ Shutting down WebSocket Server...")
@@ -15,21 +15,25 @@ def to_dict(data) -> dict:


def dumps(data) -> str:
    return json.dumps(jsonable_encoder(data))
    return json.dumps(to_dict(data))


T = TypeVar("T")


@overload
def loads(data: str, *args, target_type: Type[T], **kwargs) -> T: ...
def loads(data: str | bytes, *args, target_type: Type[T], **kwargs) -> T: ...


@overload
def loads(data: str, *args, **kwargs) -> Any: ...
def loads(data: str | bytes, *args, **kwargs) -> Any: ...


def loads(data: str, *args, target_type: Type[T] | None = None, **kwargs) -> Any:
def loads(
    data: str | bytes, *args, target_type: Type[T] | None = None, **kwargs
) -> Any:
    if isinstance(data, bytes):
        data = data.decode("utf-8")
    parsed = json.loads(data, *args, **kwargs)
    if target_type:
        return type_match(parsed, target_type)
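A usage sketch of the widened loads() signature: bytes payloads (e.g. raw message bodies) are decoded before parsing, and target_type coerces the result, assuming type_match validates the parsed JSON into the requested model (the Point model is made up for illustration):

from pydantic import BaseModel

from backend.util import json


class Point(BaseModel):  # example model, not part of the codebase
    x: int
    y: int


raw: bytes = b'{"x": 1, "y": 2}'
point = json.loads(raw, target_type=Point)
print(point)  # -> x=1 y=2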
@@ -1,8 +1,24 @@
import logging

import sentry_sdk
from sentry_sdk.integrations.anthropic import AnthropicIntegration
from sentry_sdk.integrations.logging import LoggingIntegration

from backend.util.settings import Settings


def sentry_init():
    sentry_dsn = Settings().secrets.sentry_dsn
    sentry_sdk.init(dsn=sentry_dsn, traces_sample_rate=1.0, profiles_sample_rate=1.0)
    sentry_sdk.init(
        dsn=sentry_dsn,
        traces_sample_rate=1.0,
        profiles_sample_rate=1.0,
        environment=f"app:{Settings().config.app_env.value}-behave:{Settings().config.behave_as.value}",
        _experiments={"enable_logs": True},
        integrations=[
            LoggingIntegration(sentry_logs_level=logging.INFO),
            AnthropicIntegration(
                include_prompts=False,
            ),
        ],
    )
@@ -3,7 +3,7 @@ import os
import signal
import sys
from abc import ABC, abstractmethod
from multiprocessing import Process, set_start_method
from multiprocessing import Process, get_all_start_methods, set_start_method
from typing import Optional

from backend.util.logging import configure_logging

@@ -30,7 +30,12 @@ class AppProcess(ABC):
    process: Optional[Process] = None
    cleaned_up = False

    set_start_method("spawn", force=True)
    if "forkserver" in get_all_start_methods():
        set_start_method("forkserver", force=True)
    else:
        logger.warning("Forkserver start method is not available. Using spawn instead.")
        set_start_method("spawn", force=True)

    configure_logging()
    sentry_init()

@@ -48,6 +53,7 @@ class AppProcess(ABC):
    def service_name(cls) -> str:
        return cls.__name__

    @abstractmethod
    def cleanup(self):
        """
        Implement this method on a subclass to do post-execution cleanup,
@@ -142,7 +142,7 @@ def validate_url(

    # Resolve all IP addresses for the hostname
    try:
        ip_list = [res[4][0] for res in socket.getaddrinfo(ascii_hostname, None)]
        ip_list = [str(res[4][0]) for res in socket.getaddrinfo(ascii_hostname, None)]
        ipv4 = [ip for ip in ip_list if ":" not in ip]
        ipv6 = [ip for ip in ip_list if ":" in ip]
        ip_addresses = ipv4 + ipv6  # Prefer IPv4 over IPv6
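Wrapping each resolved address in str() guarantees plain strings before the ":" membership check that splits IPv4 from IPv6. A standalone illustration (the hostname is an example; output depends on the resolver):

import socket

ip_list = [str(res[4][0]) for res in socket.getaddrinfo("example.com", None)]
ipv4 = [ip for ip in ip_list if ":" not in ip]
ipv6 = [ip for ip in ip_list if ":" in ip]
print(ipv4 + ipv6)  # IPv4 entries first, matching the "prefer IPv4" ordering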
@@ -34,7 +34,7 @@ def conn_retry(
    def on_retry(retry_state):
        prefix = _log_prefix(resource_name, conn_id)
        exception = retry_state.outcome.exception()
        logger.error(f"{prefix} {action_name} failed: {exception}. Retrying now...")
        logger.warning(f"{prefix} {action_name} failed: {exception}. Retrying now...")

    def decorator(func):
        is_coroutine = asyncio.iscoroutinefunction(func)

@@ -73,3 +73,10 @@ def conn_retry(
        return async_wrapper if is_coroutine else sync_wrapper

    return decorator


func_retry = retry(
    reraise=False,
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=1, max=30),
)
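A sketch of applying the module-level func_retry decorator added above; the flaky function and its failure mode are made up for illustration, and the import path is an assumption. With reraise=False, tenacity raises RetryError once the five attempts are exhausted:

import random

from backend.util.retry import func_retry  # assuming this module path


@func_retry
def flaky_fetch() -> str:
    # Fails randomly to demonstrate the retry/backoff behavior.
    if random.random() < 0.5:
        raise ConnectionError("transient failure")
    return "ok"


print(flaky_fetch())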
@@ -1,52 +1,36 @@
import asyncio
import builtins
import inspect
import logging
import os
import threading
import time
import typing
from abc import ABC, abstractmethod
from enum import Enum
from functools import wraps
from types import NoneType, UnionType
from functools import cached_property, update_wrapper
from typing import (
    Annotated,
    Any,
    Awaitable,
    Callable,
    Concatenate,
    Coroutine,
    Dict,
    FrozenSet,
    Iterator,
    List,
    Optional,
    ParamSpec,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
    get_args,
    get_origin,
)

import httpx
import Pyro5.api
import uvicorn
from fastapi import FastAPI, Request, responses
from pydantic import BaseModel, TypeAdapter, create_model
from Pyro5 import api as pyro
from Pyro5 import config as pyro_config

from backend.data import db, rabbitmq, redis
from backend.util.exceptions import InsufficientBalanceError
from backend.util.json import to_dict
from backend.util.metrics import sentry_init
from backend.util.process import AppProcess, get_service_name
from backend.util.retry import conn_retry
from backend.util.settings import Config, Secrets
from backend.util.settings import Config

logger = logging.getLogger(__name__)
T = TypeVar("T")
@@ -57,131 +41,23 @@ api_host = config.pyro_host
api_comm_retry = config.pyro_client_comm_retry
api_comm_timeout = config.pyro_client_comm_timeout
api_call_timeout = config.rpc_client_call_timeout
pyro_config.MAX_RETRIES = api_comm_retry  # type: ignore
pyro_config.COMMTIMEOUT = api_comm_timeout  # type: ignore


P = ParamSpec("P")
R = TypeVar("R")
EXPOSED_FLAG = "__exposed__"


def fastapi_expose(func: C) -> C:
def expose(func: C) -> C:
    func = getattr(func, "__func__", func)
    setattr(func, "__exposed__", True)
    setattr(func, EXPOSED_FLAG, True)
    return func


def fastapi_exposed_run_and_wait(
    f: Callable[P, Coroutine[None, None, R]]
) -> Callable[Concatenate[object, P], R]:
    # TODO:
    # This function lies about its return type to make the DynamicClient
    # call the function synchronously, fix this when DynamicClient can choose
    # to call a function synchronously or asynchronously.
    return expose(f)  # type: ignore


# ----- Begin Pyro Expose Block ---- #
def pyro_expose(func: C) -> C:
    """
    Decorator to mark a method or class to be exposed for remote calls.

    ## ⚠️ Gotcha
    Aside from "simple" types, only Pydantic models are passed unscathed *if annotated*.
    Any other passed or returned class objects are converted to dictionaries by Pyro.
    """

    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            msg = f"Error in {func.__name__}: {e}"
            if isinstance(e, ValueError):
                logger.warning(msg)
            else:
                logger.exception(msg)
            raise

    register_pydantic_serializers(func)

    return pyro.expose(wrapper)  # type: ignore


def register_pydantic_serializers(func: Callable):
    """Register custom serializers and deserializers for annotated Pydantic models"""
    for name, annotation in func.__annotations__.items():
        try:
            pydantic_types = _pydantic_models_from_type_annotation(annotation)
        except Exception as e:
            raise TypeError(f"Error while exposing {func.__name__}: {e}")

        for model in pydantic_types:
            logger.debug(
                f"Registering Pyro (de)serializers for {func.__name__} annotation "
                f"'{name}': {model.__qualname__}"
            )
            pyro.register_class_to_dict(model, _make_custom_serializer(model))
            pyro.register_dict_to_class(
                model.__qualname__, _make_custom_deserializer(model)
            )


def _make_custom_serializer(model: Type[BaseModel]):
    def custom_class_to_dict(obj):
        data = {
            "__class__": obj.__class__.__qualname__,
            **obj.model_dump(),
        }
        logger.debug(f"Serializing {obj.__class__.__qualname__} with data: {data}")
        return data

    return custom_class_to_dict


def _make_custom_deserializer(model: Type[BaseModel]):
    def custom_dict_to_class(qualname, data: dict):
        logger.debug(f"Deserializing {model.__qualname__} from data: {data}")
        return model(**data)

    return custom_dict_to_class
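What the registered (de)serializers produce on the wire, shown with a made-up model; this mirrors _make_custom_serializer and is an illustration, not project code:

from pydantic import BaseModel


class Point(BaseModel):
    x: int
    y: int


def custom_class_to_dict(obj: BaseModel) -> dict:
    # The qualified class name travels alongside the field data so the
    # receiving side can rebuild the model via the dict-to-class hook.
    return {"__class__": obj.__class__.__qualname__, **obj.model_dump()}


print(custom_class_to_dict(Point(x=1, y=2)))
# -> {'__class__': 'Point', 'x': 1, 'y': 2}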
def pyro_exposed_run_and_wait(
|
||||
f: Callable[P, Coroutine[None, None, R]]
|
||||
) -> Callable[Concatenate[object, P], R]:
|
||||
@expose
|
||||
@wraps(f)
|
||||
def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
coroutine = f(*args, **kwargs)
|
||||
res = self.run_and_wait(coroutine)
|
||||
return res
|
||||
|
||||
# Register serializers for annotations on bare function
|
||||
register_pydantic_serializers(f)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
if config.use_http_based_rpc:
|
||||
expose = fastapi_expose
|
||||
exposed_run_and_wait = fastapi_exposed_run_and_wait
|
||||
else:
|
||||
expose = pyro_expose
|
||||
exposed_run_and_wait = pyro_exposed_run_and_wait
|
||||
|
||||
# ----- End Pyro Expose Block ---- #
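
To make the (de)serializer pair concrete, a small hedged round-trip example; the `User` model is invented, and the Pyro registration itself is elided:

# Hypothetical round-trip through the serializer pair above.
from pydantic import BaseModel

class User(BaseModel):
    id: int
    name: str

def class_to_dict(obj: BaseModel) -> dict:
    # Mirrors _make_custom_serializer: tag the dump with the class name.
    return {"__class__": obj.__class__.__qualname__, **obj.model_dump()}

def dict_to_class(qualname: str, data: dict) -> User:
    # Mirrors _make_custom_deserializer: rebuild the model from the dict.
    return User(**{k: v for k, v in data.items() if k != "__class__"})

wire = class_to_dict(User(id=1, name="Ada"))
assert wire == {"__class__": "User", "id": 1, "name": "Ada"}
assert dict_to_class("User", wire) == User(id=1, name="Ada")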


# --------------------------------------------------
# AppService for IPC service based on HTTP request through FastAPI
# --------------------------------------------------
class BaseAppService(AppProcess, ABC):
    shared_event_loop: asyncio.AbstractEventLoop
    use_db: bool = False
    use_redis: bool = False
    rabbitmq_config: Optional[rabbitmq.RabbitMQConfig] = None
    rabbitmq_service: Optional[rabbitmq.AsyncRabbitMQ] = None
    use_supabase: bool = False

    @classmethod
    @abstractmethod
@@ -202,20 +78,6 @@ class BaseAppService(AppProcess, ABC):

        return target_host

    @property
    def rabbit(self) -> rabbitmq.AsyncRabbitMQ:
        """Access the RabbitMQ service. Will raise if not configured."""
        if not self.rabbitmq_service:
            raise RuntimeError("RabbitMQ not configured for this service")
        return self.rabbitmq_service

    @property
    def rabbit_config(self) -> rabbitmq.RabbitMQConfig:
        """Access the RabbitMQ config. Will raise if not configured."""
        if not self.rabbitmq_config:
            raise RuntimeError("RabbitMQ not configured for this service")
        return self.rabbitmq_config

    def run_service(self) -> None:
        while True:
            time.sleep(10)
@@ -225,31 +87,6 @@ class BaseAppService(AppProcess, ABC):

    def run(self):
        self.shared_event_loop = asyncio.get_event_loop()
        if self.use_db:
            self.shared_event_loop.run_until_complete(db.connect())
        if self.use_redis:
            redis.connect()
        if self.rabbitmq_config:
            logger.info(f"[{self.__class__.__name__}] ⏳ Configuring RabbitMQ...")
            self.rabbitmq_service = rabbitmq.AsyncRabbitMQ(self.rabbitmq_config)
            self.shared_event_loop.run_until_complete(self.rabbitmq_service.connect())
        if self.use_supabase:
            from supabase import create_client

            secrets = Secrets()
            self.supabase = create_client(
                secrets.supabase_url, secrets.supabase_service_role_key
            )

    def cleanup(self):
        if self.use_db:
            logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting DB...")
            self.run_and_wait(db.disconnect())
        if self.use_redis:
            logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting Redis...")
            redis.disconnect()
        if self.rabbitmq_config:
            logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting RabbitMQ...")


class RemoteCallError(BaseModel):
@@ -268,7 +105,7 @@ EXCEPTION_MAPPING = {
}


class FastApiAppService(BaseAppService, ABC):
class AppService(BaseAppService, ABC):
    fastapi_app: FastAPI

    @staticmethod
@@ -324,14 +161,16 @@ class FastApiAppService(BaseAppService, ABC):

            async def async_endpoint(body: RequestBodyModel):  # type: ignore #RequestBodyModel being variable
                return await f(
                    **{name: getattr(body, name) for name in body.model_fields}
                    **{name: getattr(body, name) for name in type(body).model_fields}
                )

            return async_endpoint
        else:

            def sync_endpoint(body: RequestBodyModel):  # type: ignore #RequestBodyModel being variable
                return f(**{name: getattr(body, name) for name in body.model_fields})
                return f(
                    **{name: getattr(body, name) for name in type(body).model_fields}
                )

            return sync_endpoint

@@ -351,12 +190,13 @@ class FastApiAppService(BaseAppService, ABC):
        self.shared_event_loop.run_until_complete(server.serve())

    def run(self):
        sentry_init()
        super().run()
        self.fastapi_app = FastAPI()

        # Register the exposed API routes.
        for attr_name, attr in vars(type(self)).items():
            if getattr(attr, "__exposed__", False):
            if getattr(attr, EXPOSED_FLAG, False):
                route_path = f"/{attr_name}"
                self.fastapi_app.add_api_route(
                    route_path,
@@ -381,86 +221,58 @@ class FastApiAppService(BaseAppService, ABC):
        self.run_service()
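
A compressed, hedged sketch of what this registration loop amounts to; the `AddBody` model stands in for the generated `RequestBodyModel`, and the real model generation and error mapping are elided:

# Self-contained illustration: an exposed method becomes a POST route
# whose request-body fields are splatted into the function call.
from fastapi import FastAPI
from pydantic import BaseModel

class AddBody(BaseModel):  # stand-in for the generated RequestBodyModel
    a: int
    b: int

def add(a: int, b: int) -> int:
    return a + b

def make_endpoint(f):
    def endpoint(body: AddBody):
        return f(**{name: getattr(body, name) for name in type(body).model_fields})
    return endpoint

app = FastAPI()
app.add_api_route("/add", make_endpoint(add), methods=["POST"])
# POST /add with {"a": 2, "b": 3} now returns 5.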


# ----- Begin Pyro AppService Block ---- #


class PyroAppService(BaseAppService, ABC):

    @conn_retry("Pyro", "Starting Pyro Service")
    def __start_pyro(self):
        maximum_connection_thread_count = max(
            Pyro5.config.THREADPOOL_SIZE,
            config.num_node_workers * config.num_graph_workers,
        )

        Pyro5.config.THREADPOOL_SIZE = maximum_connection_thread_count  # type: ignore
        daemon = Pyro5.api.Daemon(host=api_host, port=self.get_port())
        self.uri = daemon.register(self, objectId=self.service_name)
        logger.info(f"[{self.service_name}] Connected to Pyro; URI = {self.uri}")
        daemon.requestLoop()

    def run(self):
        super().run()

        # Initialize the async loop.
        async_thread = threading.Thread(target=self.shared_event_loop.run_forever)
        async_thread.daemon = True
        async_thread.start()

        # Initialize pyro service
        daemon_thread = threading.Thread(target=self.__start_pyro)
        daemon_thread.daemon = True
        daemon_thread.start()

        # Run the main service loop (blocking).
        self.run_service()


if config.use_http_based_rpc:

    class AppService(FastApiAppService, ABC):  # type: ignore #AppService defined twice
        pass

else:

    class AppService(PyroAppService, ABC):
        pass


# ----- End Pyro AppService Block ---- #


# --------------------------------------------------
# HTTP Client utilities for dynamic service client abstraction
# --------------------------------------------------
AS = TypeVar("AS", bound=AppService)


def fastapi_close_service_client(client: Any) -> None:
    if hasattr(client, "close"):
        client.close()
    else:
        logger.warning(f"Client {client} is not closable")
class AppServiceClient(ABC):
    @classmethod
    @abstractmethod
    def get_service_type(cls) -> Type[AppService]:
        pass

    def health_check(self):
        pass

    def close(self):
        pass


@conn_retry("FastAPI client", "Creating service client", max_retry=api_comm_retry)
def fastapi_get_service_client(
    service_type: Type[AS],
ASC = TypeVar("ASC", bound=AppServiceClient)


@conn_retry("AppService client", "Creating service client", max_retry=api_comm_retry)
def get_service_client(
    service_client_type: Type[ASC],
    call_timeout: int | None = api_call_timeout,
) -> AS:
) -> ASC:
    class DynamicClient:
        def __init__(self):
            service_type = service_client_type.get_service_type()
            host = service_type.get_host()
            port = service_type.get_port()
            self.base_url = f"http://{host}:{port}".rstrip("/")
            self.client = httpx.Client(

        @cached_property
        def sync_client(self) -> httpx.Client:
            return httpx.Client(
                base_url=self.base_url,
                timeout=call_timeout,
            )

        def _call_method(self, method_name: str, **kwargs) -> Any:
        @cached_property
        def async_client(self) -> httpx.AsyncClient:
            return httpx.AsyncClient(
                base_url=self.base_url,
                timeout=call_timeout,
            )

        def _handle_call_method_response(
            self, response: httpx.Response, method_name: str
        ) -> Any:
            try:
                response = self.client.post(method_name, json=to_dict(kwargs))
                response.raise_for_status()
                return response.json()
            except httpx.HTTPStatusError as e:
@@ -471,126 +283,102 @@ def fastapi_get_service_client(
                    *(error.args or [str(e)])
                )

        def _call_method_sync(self, method_name: str, **kwargs) -> Any:
            return self._handle_call_method_response(
                method_name=method_name,
                response=self.sync_client.post(method_name, json=to_dict(kwargs)),
            )

        async def _call_method_async(self, method_name: str, **kwargs) -> Any:
            return self._handle_call_method_response(
                method_name=method_name,
                response=await self.async_client.post(
                    method_name, json=to_dict(kwargs)
                ),
            )

        async def aclose(self):
            self.sync_client.close()
            await self.async_client.aclose()

        def close(self):
            self.client.close()
            self.sync_client.close()

        def _get_params(self, signature: inspect.Signature, *args, **kwargs) -> dict:
            if args:
                arg_names = list(signature.parameters.keys())
                if arg_names[0] in ("self", "cls"):
                    arg_names = arg_names[1:]
                kwargs.update(dict(zip(arg_names, args)))
            return kwargs

        def _get_return(self, expected_return: TypeAdapter | None, result: Any) -> Any:
            if expected_return:
                return expected_return.validate_python(result)
            return result

        def __getattr__(self, name: str) -> Callable[..., Any]:
            # Try to get the original function from the service type.
            orig_func = getattr(service_type, name, None)
            if orig_func is None:
                raise AttributeError(f"Method {name} not found in {service_type}")
            original_func = getattr(service_client_type, name, None)
            if original_func is None:
                raise AttributeError(
                    f"Method {name} not found in {service_client_type}"
                )
            else:
                name = original_func.__name__

            sig = inspect.signature(orig_func)
            sig = inspect.signature(original_func)
            ret_ann = sig.return_annotation
            if ret_ann != inspect.Signature.empty:
                expected_return = TypeAdapter(ret_ann)
            else:
                expected_return = None

            def method(*args, **kwargs) -> Any:
                if args:
                    arg_names = list(sig.parameters.keys())
                    if arg_names[0] in ("self", "cls"):
                        arg_names = arg_names[1:]
                    kwargs.update(dict(zip(arg_names, args)))
                result = self._call_method(name, **kwargs)
                if expected_return:
                    return expected_return.validate_python(result)
                return result
            if inspect.iscoroutinefunction(original_func):

            return method
                async def async_method(*args, **kwargs) -> Any:
                    params = self._get_params(sig, *args, **kwargs)
                    result = await self._call_method_async(name, **params)
                    return self._get_return(expected_return, result)

    client = cast(AS, DynamicClient())
                return async_method
            else:

                def sync_method(*args, **kwargs) -> Any:
                    params = self._get_params(sig, *args, **kwargs)
                    result = self._call_method_sync(name, **params)
                    return self._get_return(expected_return, result)

                return sync_method

    client = cast(ASC, DynamicClient())
    client.health_check()

    return cast(AS, client)
    return client
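
Usage, sketched with invented names (`MyServiceClient`, `MyService`, `get_user`): the client class only declares typed signatures, and `DynamicClient.__getattr__` turns each lookup into an HTTP POST against the service:

# Hypothetical client declaration and call; names are illustrative.
class MyServiceClient(AppServiceClient):
    @classmethod
    def get_service_type(cls) -> Type[AppService]:
        return MyService  # assumed AppService subclass defined elsewhere

    def get_user(self, user_id: str) -> dict: ...  # signature only

client = get_service_client(MyServiceClient)
user = client.get_user("42")  # POSTs {"user_id": "42"} to /get_user
client.close()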


# ----- Begin Pyro Client Block ---- #
class PyroClient:
    proxy: Pyro5.api.Proxy
def endpoint_to_sync(
    func: Callable[Concatenate[Any, P], Awaitable[R]],
) -> Callable[Concatenate[Any, P], R]:
    """
    Produce a *typed* stub that **looks** synchronous to the type-checker.
    """

    def _stub(*args: P.args, **kwargs: P.kwargs) -> R:  # pragma: no cover
        raise RuntimeError("should be intercepted by __getattr__")

    update_wrapper(_stub, func)
    return cast(Callable[Concatenate[Any, P], R], _stub)


def pyro_close_service_client(client: BaseAppService) -> None:
    if isinstance(client, PyroClient):
        client.proxy._pyroRelease()
    else:
        raise RuntimeError(f"Client {client.__class__} is not a Pyro client.")
def endpoint_to_async(
    func: Callable[Concatenate[Any, P], R],
) -> Callable[Concatenate[Any, P], Awaitable[R]]:
    """
    The async mirror of `to_sync`.
    """

    async def _stub(*args: P.args, **kwargs: P.kwargs) -> R:  # pragma: no cover
        raise RuntimeError("should be intercepted by __getattr__")


def pyro_get_service_client(service_type: Type[AS]) -> AS:
    service_name = service_type.service_name

    class DynamicClient(PyroClient):
        @conn_retry("Pyro", f"Connecting to [{service_name}]")
        def __init__(self):
            uri = f"PYRO:{service_type.service_name}@{service_type.get_host()}:{service_type.get_port()}"
            logger.debug(f"Connecting to service [{service_name}]. URI = {uri}")
            self.proxy = Pyro5.api.Proxy(uri)
            # Attempt to bind to ensure the connection is established
            self.proxy._pyroBind()
            logger.debug(f"Successfully connected to service [{service_name}]")

        def __getattr__(self, name: str) -> Callable[..., Any]:
            res = getattr(self.proxy, name)
            return res

    return cast(AS, DynamicClient())


builtin_types = [*vars(builtins).values(), NoneType, Enum]


def _pydantic_models_from_type_annotation(annotation) -> Iterator[type[BaseModel]]:
    # Peel Annotated parameters
    if (origin := get_origin(annotation)) and origin is Annotated:
        annotation = get_args(annotation)[0]

    origin = get_origin(annotation)
    args = get_args(annotation)

    if origin in (
        Union,
        UnionType,
        list,
        List,
        tuple,
        Tuple,
        set,
        Set,
        frozenset,
        FrozenSet,
    ):
        for arg in args:
            yield from _pydantic_models_from_type_annotation(arg)
    elif origin in (dict, Dict):
        key_type, value_type = args
        yield from _pydantic_models_from_type_annotation(key_type)
        yield from _pydantic_models_from_type_annotation(value_type)
    elif origin in (Awaitable, Coroutine):
        # For coroutines and awaitables, check the return type
        return_type = args[-1]
        yield from _pydantic_models_from_type_annotation(return_type)
    else:
        annotype = annotation if origin is None else origin

        # Exclude generic types and aliases
        if (
            annotype is not None
            and not hasattr(typing, getattr(annotype, "__name__", ""))
            and isinstance(annotype, type)
        ):
            if issubclass(annotype, BaseModel):
                yield annotype
            elif annotype not in builtin_types and not issubclass(annotype, Enum):
                raise TypeError(f"Unsupported type encountered: {annotype}")

if config.use_http_based_rpc:
    close_service_client = fastapi_close_service_client
    get_service_client = fastapi_get_service_client
else:
    close_service_client = pyro_close_service_client
    get_service_client = pyro_get_service_client

# ----- End Pyro Client Block ---- #
    update_wrapper(_stub, func)
    return cast(Callable[Concatenate[Any, P], Awaitable[R]], _stub)
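
A hedged sketch of how these stubs are meant to be used in a client interface; the class and method names are invented. The assignment gives the type-checker a synchronous signature, while at runtime `DynamicClient.__getattr__` never finds the stub on `DynamicClient` itself and swaps in the real HTTP call:

# Hypothetical typed client built from the stub helpers above;
# ExecutionService and run_graph are invented names.
class ExecutionServiceClient(AppServiceClient):
    @classmethod
    def get_service_type(cls) -> Type[AppService]:
        return ExecutionService  # assumed AppService subclass

    # Async endpoint on the service, presented as sync on the client:
    run_graph = endpoint_to_sync(ExecutionService.run_graph)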
Some files were not shown because too many files have changed in this diff.