Mirror of https://github.com/Significant-Gravitas/AutoGPT.git, synced 2026-04-08 03:00:28 -04:00
Compare commits: pwuts/open...abhi-9274/ (75 commits)
Commits in this comparison (75), oldest first; only the SHA1 column survived extraction:

c3d92a4e06, afb66f75ec, 59ec61ef98, d7077b5161, 475c5a5cc3, 86d5cfe60b, 602f887623, 1edde778c5, 3526986f98, 04c4340ee3, 9fa62c03f6, d5dc687484, fb5ce0a16d, a1f17ca797, 8fdfd75cc4, 5b5b2043e8, 7d83f1db05, f07696e3c1, 96a173a85f, 9715ea5313, ef022720d5, 4ddb206f86, 91f34966c8, 11a69170b5, 0675a41e42, 56ce1a0c1c, 7fbe135ec8, eb6a0b34e1, 1e3236a041, 160a622ba4, e2a226dc49, 5047e99fd1, c80d357149, 20d39f6d44, d5b82c01e0, 69b8d96516, 67af77e179, 2a92970a5f, 9052ee7b95, c783f64b33, 055a231aed, 417d7732af, f16a398a8e, e8bbd945f2, d1730d7b1d, 8ea64327a1, 3cf30c22fb, 05c670eef9, f6a4b036c7, c43924cd4e, e3846c22bd, 9a7a838418, d61d815208, 44e3770003, c0ee71fb27, 71cdc18674, dc9348ec26, 3ccbc31705, 7cf0c6fe46, c69faa2a94, 0c9dbbbe24, 3e0742f9c5, d791cdea76, bb92226f5d, f7ca5ac1ba, 4621a95bf3, 8d8a6e450f, cda07e81d1, 6156fbb731, 07a09d802c, 88b81f8cb2, 7294741001, ea03c404b1, 79651855c2, 7977d1b1e5
@@ -34,6 +34,7 @@ jobs:
           python -m prisma migrate deploy
         env:
           DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
+          DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}

   trigger:
@@ -36,6 +36,7 @@ jobs:
           python -m prisma migrate deploy
         env:
           DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
+          DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}

   trigger:
     needs: migrate
.github/workflows/platform-backend-ci.yml (vendored, 12 lines changed)
@@ -135,6 +135,7 @@ jobs:
         run: poetry run prisma migrate dev --name updates
         env:
           DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
+          DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}

       - id: lint
         name: Run Linter
@@ -151,12 +152,13 @@ jobs:
         env:
           LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
           DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
+          DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
           SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
           SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
           SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
-          REDIS_HOST: 'localhost'
-          REDIS_PORT: '6379'
-          REDIS_PASSWORD: 'testpassword'
+          REDIS_HOST: "localhost"
+          REDIS_PORT: "6379"
+          REDIS_PASSWORD: "testpassword"

     env:
       CI: true
@@ -169,8 +171,8 @@ jobs:
       # If you want to replace this, you can do so by making our entire system generate
       # new credentials for each local user and update the environment variables in
       # the backend service, docker composes, and examples
-      RABBITMQ_DEFAULT_USER: 'rabbitmq_user_default'
-      RABBITMQ_DEFAULT_PASS: 'k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7'
+      RABBITMQ_DEFAULT_USER: "rabbitmq_user_default"
+      RABBITMQ_DEFAULT_PASS: "k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7"

     # - name: Upload coverage reports to Codecov
     #   uses: codecov/codecov-action@v4
@@ -16,7 +16,7 @@ jobs:
           # operations-per-run: 5000
           stale-issue-message: >
             This issue has automatically been marked as _stale_ because it has not had
-            any activity in the last 50 days. You can _unstale_ it by commenting or
+            any activity in the last 170 days. You can _unstale_ it by commenting or
             removing the label. Otherwise, this issue will be closed in 10 days.
           stale-pr-message: >
             This pull request has automatically been marked as _stale_ because it has
@@ -25,7 +25,7 @@ jobs:
           close-issue-message: >
             This issue was closed automatically because it has been stale for 10 days
             with no activity.
-          days-before-stale: 100
+          days-before-stale: 170
           days-before-close: 10
           # Do not touch meta issues:
           exempt-issue-labels: meta,fridge,project management
@@ -1,20 +1,59 @@
+import inspect
 import threading
-from typing import Callable, ParamSpec, TypeVar
+from typing import Awaitable, Callable, ParamSpec, TypeVar, cast, overload

 P = ParamSpec("P")
 R = TypeVar("R")


-def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
+@overload
+def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...
+
+
+@overload
+def thread_cached(func: Callable[P, R]) -> Callable[P, R]: ...
+
+
+def thread_cached(
+    func: Callable[P, R] | Callable[P, Awaitable[R]],
+) -> Callable[P, R] | Callable[P, Awaitable[R]]:
     thread_local = threading.local()

-    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
-        cache = getattr(thread_local, "cache", None)
-        if cache is None:
-            cache = thread_local.cache = {}
-        key = (args, tuple(sorted(kwargs.items())))
-        if key not in cache:
-            cache[key] = func(*args, **kwargs)
-        return cache[key]
+    def _clear():
+        if hasattr(thread_local, "cache"):
+            del thread_local.cache

-    return wrapper
+    if inspect.iscoroutinefunction(func):
+
+        async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+            cache = getattr(thread_local, "cache", None)
+            if cache is None:
+                cache = thread_local.cache = {}
+            key = (args, tuple(sorted(kwargs.items())))
+            if key not in cache:
+                cache[key] = await cast(Callable[P, Awaitable[R]], func)(
+                    *args, **kwargs
+                )
+            return cache[key]
+
+        setattr(async_wrapper, "clear_cache", _clear)
+        return async_wrapper
+
+    else:
+
+        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+            cache = getattr(thread_local, "cache", None)
+            if cache is None:
+                cache = thread_local.cache = {}
+            key = (args, tuple(sorted(kwargs.items())))
+            if key not in cache:
+                cache[key] = func(*args, **kwargs)
+            return cache[key]
+
+        setattr(sync_wrapper, "clear_cache", _clear)
+        return sync_wrapper
+
+
+def clear_thread_cache(func: Callable) -> None:
+    if clear := getattr(func, "clear_cache", None):
+        clear()
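A minimal usage sketch of the extended decorator, for context (the get_db / get_api functions are hypothetical, not from the codebase):

    import asyncio
    from autogpt_libs.utils.cache import clear_thread_cache, thread_cached

    @thread_cached
    def get_db() -> object:
        print("building sync client")  # runs once per thread per argument set
        return object()

    @thread_cached
    async def get_api() -> object:
        print("building async client")  # the awaited result is what gets cached
        return object()

    assert get_db() is get_db()  # second call is served from the thread-local cache
    asyncio.run(get_api())       # async callables are wrapped and cached the same way
    clear_thread_cache(get_db)   # drops the calling thread's cache for this function

Note the cache is per thread: another thread calling get_db() builds its own value.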
@@ -31,7 +31,7 @@ class RedisKeyedMutex:
         try:
             yield
         finally:
-            if lock.locked():
+            if lock.locked() and lock.owned():
                 lock.release()

     def acquire(self, key: Any) -> "RedisLock":
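The added owned() check matters because a redis-py lock can time out and be re-acquired by another client; releasing it then would steal the new holder's lock (or raise LockError). A small illustrative sketch, assuming a local Redis:

    from redis import Redis

    redis = Redis(host="localhost", port=6379)
    lock = redis.lock("my-key", timeout=60)

    if lock.acquire(blocking=True):
        try:
            ...  # critical section; may run longer than the 60s timeout
        finally:
            # locked(): *someone* holds the lock; owned(): *we* hold it.
            if lock.locked() and lock.owned():
                lock.release()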
@@ -8,6 +8,7 @@ DB_CONNECT_TIMEOUT=60
 DB_POOL_TIMEOUT=300
 DB_SCHEMA=platform
 DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
+DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
 PRISMA_SCHEMA="postgres/schema.prisma"

 # EXECUTOR
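Pairing DIRECT_URL with DATABASE_URL follows Prisma's convention: the runtime URL may point at a connection pooler, while directUrl gives migrations an unpooled connection. A sketch of the corresponding datasource block, assuming the project's schema.prisma uses the standard layout:

    datasource db {
      provider  = "postgresql"
      url       = env("DATABASE_URL")
      directUrl = env("DIRECT_URL")
    }

Here both variables carry the same value, so the indirection only pays off once a pooler is put in front of DATABASE_URL.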
@@ -73,7 +73,6 @@ FROM server_dependencies AS server
 COPY autogpt_platform/backend /app/autogpt_platform/backend
 RUN poetry install --no-ansi --only-root

-ENV DATABASE_URL=""
 ENV PORT=8000

 CMD ["poetry", "run", "rest"]
@@ -1,8 +1,6 @@
 import logging
 from typing import Any

-from autogpt_libs.utils.cache import thread_cached
-
 from backend.data.block import (
     Block,
     BlockCategory,
@@ -19,21 +17,6 @@ from backend.util import json
 logger = logging.getLogger(__name__)


-@thread_cached
-def get_executor_manager_client():
-    from backend.executor import ExecutionManager
-    from backend.util.service import get_service_client
-
-    return get_service_client(ExecutionManager)
-
-
-@thread_cached
-def get_event_bus():
-    from backend.data.execution import RedisExecutionEventBus
-
-    return RedisExecutionEventBus()
-
-
 class AgentExecutorBlock(Block):
     class Input(BlockSchema):
         user_id: str = SchemaField(description="User ID")
@@ -76,23 +59,23 @@ class AgentExecutorBlock(Block):

     def run(self, input_data: Input, **kwargs) -> BlockOutput:
         from backend.data.execution import ExecutionEventType
+        from backend.executor import utils as execution_utils

-        executor_manager = get_executor_manager_client()
-        event_bus = get_event_bus()
+        event_bus = execution_utils.get_execution_event_bus()

-        graph_exec = executor_manager.add_execution(
+        graph_exec = execution_utils.add_graph_execution(
             graph_id=input_data.graph_id,
             graph_version=input_data.graph_version,
             user_id=input_data.user_id,
-            data=input_data.data,
+            inputs=input_data.data,
         )
-        log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.graph_exec_id}"
+        log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.id}"
         logger.info(f"Starting execution of {log_id}")

         for event in event_bus.listen(
             user_id=graph_exec.user_id,
             graph_id=graph_exec.graph_id,
-            graph_exec_id=graph_exec.graph_exec_id,
+            graph_exec_id=graph_exec.id,
         ):
             if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
                 if event.status in [
@@ -105,7 +88,7 @@ class AgentExecutorBlock(Block):
             else:
                 continue

-            logger.info(
+            logger.debug(
                 f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
             )
@@ -123,5 +106,7 @@ class AgentExecutorBlock(Block):
                 continue

             for output_data in event.output_data.get("output", []):
-                logger.info(f"Execution {log_id} produced {output_name}: {output_data}")
+                logger.debug(
+                    f"Execution {log_id} produced {output_name}: {output_data}"
+                )
                 yield output_name, output_data
@@ -88,6 +88,33 @@ class StoreValueBlock(Block):
         yield "output", input_data.data or input_data.input


+class PrintToConsoleBlock(Block):
+    class Input(BlockSchema):
+        text: Any = SchemaField(description="The data to print to the console.")
+
+    class Output(BlockSchema):
+        output: Any = SchemaField(description="The data printed to the console.")
+        status: str = SchemaField(description="The status of the print operation.")
+
+    def __init__(self):
+        super().__init__(
+            id="f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c",
+            description="Print the given text to the console, this is used for a debugging purpose.",
+            categories={BlockCategory.BASIC},
+            input_schema=PrintToConsoleBlock.Input,
+            output_schema=PrintToConsoleBlock.Output,
+            test_input={"text": "Hello, World!"},
+            test_output=[
+                ("output", "Hello, World!"),
+                ("status", "printed"),
+            ],
+        )
+
+    def run(self, input_data: Input, **kwargs) -> BlockOutput:
+        yield "output", input_data.text
+        yield "status", "printed"
+
+
 class FindInDictionaryBlock(Block):
     class Input(BlockSchema):
         input: Any = SchemaField(description="Dictionary to lookup from")
@@ -3,7 +3,7 @@ from googleapiclient.discovery import build

 from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import SchemaField
-from backend.util.settings import Settings
+from backend.util.settings import AppEnvironment, Settings

 from ._auth import (
     GOOGLE_OAUTH_IS_CONFIGURED,
@@ -36,13 +36,15 @@ class GoogleSheetsReadBlock(Block):
     )

     def __init__(self):
+        settings = Settings()
         super().__init__(
             id="5724e902-3635-47e9-a108-aaa0263a4988",
             description="This block reads data from a Google Sheets spreadsheet.",
             categories={BlockCategory.DATA},
             input_schema=GoogleSheetsReadBlock.Input,
             output_schema=GoogleSheetsReadBlock.Output,
-            disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
+            disabled=not GOOGLE_OAUTH_IS_CONFIGURED
+            or settings.config.app_env == AppEnvironment.PRODUCTION,
             test_input={
                 "spreadsheet_id": "1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms",
                 "range": "Sheet1!A1:B2",
@@ -82,7 +82,15 @@ class SendWebRequestBlock(Block):
                 json=body if input_data.json_format else None,
                 data=body if not input_data.json_format else None,
             )
-            result = response.json() if input_data.json_format else response.text
+
+            if input_data.json_format:
+                if response.status_code == 204 or not response.content.strip():
+                    result = None
+                else:
+                    result = response.json()
+            else:
+                result = response.text
+
             yield "response", result

         except HTTPError as e:
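The new branch avoids the JSONDecodeError that response.json() raises on an empty body. The same guard in isolation, as a runnable sketch with requests (the URL is illustrative):

    import requests

    response = requests.delete("https://httpbin.org/status/204")

    # A 204 response (or any empty body) has nothing to parse, so short-circuit
    # to None instead of letting response.json() raise on empty input.
    if response.status_code == 204 or not response.content.strip():
        result = None
    else:
        result = response.json()

    print(result)  # -> None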
@@ -9,7 +9,6 @@ from typing import Any, Iterable, List, Literal, NamedTuple, Optional
 import anthropic
 import ollama
 import openai
-from anthropic import NotGiven
 from anthropic.types import ToolParam
 from groq import Groq
 from pydantic import BaseModel, SecretStr
@@ -90,14 +89,17 @@ class LlmModelMeta(EnumMeta):
 class LlmModel(str, Enum, metaclass=LlmModelMeta):
     # OpenAI models
     O3_MINI = "o3-mini"
+    O3 = "o3-2025-04-16"
     O1 = "o1"
     O1_PREVIEW = "o1-preview"
     O1_MINI = "o1-mini"
+    GPT41 = "gpt-4.1-2025-04-14"
     GPT4O_MINI = "gpt-4o-mini"
     GPT4O = "gpt-4o"
     GPT4_TURBO = "gpt-4-turbo"
     GPT3_5_TURBO = "gpt-3.5-turbo"
     # Anthropic models
+    CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
     CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest"
     CLAUDE_3_5_HAIKU = "claude-3-5-haiku-latest"
     CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
@@ -118,6 +120,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
     OLLAMA_DOLPHIN = "dolphin-mistral:latest"
     # OpenRouter models
     GEMINI_FLASH_1_5 = "google/gemini-flash-1.5"
+    GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
     GROK_BETA = "x-ai/grok-beta"
     MISTRAL_NEMO = "mistralai/mistral-nemo"
     COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
@@ -157,12 +160,14 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):

 MODEL_METADATA = {
     # https://platform.openai.com/docs/models
+    LlmModel.O3: ModelMetadata("openai", 200000, 100000),
     LlmModel.O3_MINI: ModelMetadata("openai", 200000, 100000),  # o3-mini-2025-01-31
     LlmModel.O1: ModelMetadata("openai", 200000, 100000),  # o1-2024-12-17
     LlmModel.O1_PREVIEW: ModelMetadata(
         "openai", 128000, 32768
     ),  # o1-preview-2024-09-12
     LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536),  # o1-mini-2024-09-12
+    LlmModel.GPT41: ModelMetadata("openai", 1047576, 32768),
     LlmModel.GPT4O_MINI: ModelMetadata(
         "openai", 128000, 16384
     ),  # gpt-4o-mini-2024-07-18
@@ -172,6 +177,9 @@ MODEL_METADATA = {
     ),  # gpt-4-turbo-2024-04-09
     LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, 4096),  # gpt-3.5-turbo-0125
     # https://docs.anthropic.com/en/docs/about-claude/models
+    LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
+        "anthropic", 200000, 8192
+    ),  # claude-3-7-sonnet-20250219
     LlmModel.CLAUDE_3_5_SONNET: ModelMetadata(
         "anthropic", 200000, 8192
     ),  # claude-3-5-sonnet-20241022
@@ -197,6 +205,7 @@ MODEL_METADATA = {
     LlmModel.OLLAMA_DOLPHIN: ModelMetadata("ollama", 32768, None),
     # https://openrouter.ai/models
     LlmModel.GEMINI_FLASH_1_5: ModelMetadata("open_router", 1000000, 8192),
+    LlmModel.GEMINI_2_5_PRO: ModelMetadata("open_router", 1050000, 8192),
     LlmModel.GROK_BETA: ModelMetadata("open_router", 131072, 131072),
     LlmModel.MISTRAL_NEMO: ModelMetadata("open_router", 128000, 4096),
     LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata("open_router", 128000, 4096),
@@ -249,7 +258,7 @@ class LLMResponse(BaseModel):

 def convert_openai_tool_fmt_to_anthropic(
     openai_tools: list[dict] | None = None,
-) -> Iterable[ToolParam] | NotGiven:
+) -> Iterable[ToolParam] | anthropic.NotGiven:
     """
     Convert OpenAI tool format to Anthropic tool format.
     """
@@ -279,6 +288,13 @@ def convert_openai_tool_fmt_to_anthropic(
     return anthropic_tools


+def estimate_token_count(prompt_messages: list[dict]) -> int:
+    char_count = sum(len(str(msg.get("content", ""))) for msg in prompt_messages)
+    message_overhead = len(prompt_messages) * 4
+    estimated_tokens = (char_count // 4) + message_overhead
+    return int(estimated_tokens * 1.2)
+
+
 def llm_call(
     credentials: APIKeyCredentials,
     llm_model: LlmModel,
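Worked through by hand, the heuristic is roughly one token per four characters plus four tokens of per-message overhead, padded by 20%. For two messages carrying 400 and 200 characters: char_count = 600, message_overhead = 8, (600 // 4) + 8 = 158, and int(158 * 1.2) = 189 estimated tokens. The same check as a snippet:

    messages = [
        {"role": "system", "content": "x" * 400},
        {"role": "user", "content": "y" * 200},
    ]
    char_count = sum(len(str(m.get("content", ""))) for m in messages)  # 600
    estimate = int(((char_count // 4) + len(messages) * 4) * 1.2)       # 189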
@@ -287,6 +303,7 @@ def llm_call(
     max_tokens: int | None,
     tools: list[dict] | None = None,
     ollama_host: str = "localhost:11434",
+    parallel_tool_calls: bool | None = None,
 ) -> LLMResponse:
     """
     Make a call to a language model.
@@ -309,7 +326,14 @@ def llm_call(
     - completion_tokens: The number of tokens used in the completion.
     """
     provider = llm_model.metadata.provider
-    max_tokens = max_tokens or llm_model.max_output_tokens or 4096
+
+    # Calculate available tokens based on context window and input length
+    estimated_input_tokens = estimate_token_count(prompt)
+    context_window = llm_model.context_window
+    model_max_output = llm_model.max_output_tokens or 4096
+    user_max = max_tokens or model_max_output
+    available_tokens = max(context_window - estimated_input_tokens, 0)
+    max_tokens = max(min(available_tokens, model_max_output, user_max), 0)

     if provider == "openai":
         tools_param = tools if tools else openai.NOT_GIVEN
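Plugging numbers into the new budget makes the clamping visible; for example, with a 128,000-token window, a 4,096-token model max output, and an 8,000-token caller cap:

    context_window, model_max_output, user_max = 128_000, 4_096, 8_000
    for estimated_input in (120_500, 127_000):
        available = max(context_window - estimated_input, 0)
        budget = max(min(available, model_max_output, user_max), 0)
        print(estimated_input, "->", budget)  # 120500 -> 4096, 127000 -> 1000

The request can no longer ask for more output tokens than the window has room for.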
@@ -332,6 +356,9 @@ def llm_call(
             response_format=response_format,  # type: ignore
             max_completion_tokens=max_tokens,
             tools=tools_param,  # type: ignore
+            parallel_tool_calls=(
+                openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
+            ),
         )

         if response.choices[0].message.tool_calls:
@@ -462,6 +489,7 @@ def llm_call(
             model=llm_model.value,
             prompt=f"{sys_messages}\n\n{usr_messages}",
             stream=False,
+            options={"num_ctx": max_tokens},
         )
         return LLMResponse(
             raw_response=response.get("response") or "",
@@ -487,6 +515,9 @@ def llm_call(
             messages=prompt,  # type: ignore
             max_tokens=max_tokens,
             tools=tools_param,  # type: ignore
+            parallel_tool_calls=(
+                openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
+            ),
         )

         # If there's no response, raise an error
@@ -757,6 +788,16 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
                 prompt.append({"role": "user", "content": retry_prompt})
             except Exception as e:
                 logger.exception(f"Error calling LLM: {e}")
+                if (
+                    "maximum context length" in str(e).lower()
+                    or "token limit" in str(e).lower()
+                ):
+                    if input_data.max_tokens is None:
+                        input_data.max_tokens = llm_model.max_output_tokens or 4096
+                    input_data.max_tokens = int(input_data.max_tokens * 0.85)
+                    logger.debug(
+                        f"Reducing max_tokens to {input_data.max_tokens} for next attempt"
+                    )
                 retry_prompt = f"Error calling LLM: {e}"
         finally:
             self.merge_stats(
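The 0.85 factor gives a geometric back-off on the output budget across retries, so repeated context-length failures shrink the request quickly; from 4096 the sequence runs 4096 → 3481 → 2958 → 2514:

    max_tokens = 4096
    for _ in range(3):
        max_tokens = int(max_tokens * 0.85)  # 3481, then 2958, then 2514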
autogpt_platform/backend/backend/blocks/postgres.py (new file, 734 lines)
@@ -0,0 +1,734 @@
from enum import Enum
from typing import Any, List, Literal, Optional

import psycopg2
from psycopg2.extras import RealDictCursor
from pydantic import BaseModel, SecretStr

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
    CredentialsField,
    CredentialsMetaInput,
    SchemaField,
    UserPasswordCredentials,
)
from backend.integrations.providers import ProviderName

PostgresCredentials = UserPasswordCredentials
PostgresCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.POSTGRES],
    Literal["user_password"],
]


def PostgresCredentialsField() -> PostgresCredentialsInput:
    """Creates a Postgres credentials input on a block."""
    return CredentialsField(
        description="The Postgres integration requires a username and password.",
    )


TEST_POSTGRES_CREDENTIALS = UserPasswordCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="postgres",
    username=SecretStr("mock-postgres-username"),
    password=SecretStr("mock-postgres-password"),
    title="Mock Postgres credentials",
)

TEST_POSTGRES_CREDENTIALS_INPUT = {
    "provider": TEST_POSTGRES_CREDENTIALS.provider,
    "id": TEST_POSTGRES_CREDENTIALS.id,
    "type": TEST_POSTGRES_CREDENTIALS.type,
    "title": TEST_POSTGRES_CREDENTIALS.title,
}


class CommandType(str, Enum):
    TRUNCATE = "TRUNCATE"
    DELETE = "DELETE"
    DROP = "DROP"


class ConditionOperator(str, Enum):
    EQUALS = "="
    NOT_EQUALS = "<>"
    GREATER_THAN = ">"
    LESS_THAN = "<"
    GREATER_EQUALS = ">="
    LESS_EQUALS = "<="
    LIKE = "LIKE"
    IN = "IN"


class Condition(BaseModel):
    column: str
    operator: ConditionOperator
    value: Any


class CombineCondition(str, Enum):
    AND = "AND"
    OR = "OR"
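To make the query building in the blocks below concrete: two conditions on id (IN [1, 2, 3]) and name (LIKE "Test%") combined with AND should render as id IN (%s, %s, %s) AND name LIKE %s with parameters [1, 2, 3, "Test%"]. A standalone sketch that mirrors (rather than imports) the rendering logic:

    conditions = [("id", "IN", [1, 2, 3]), ("name", "LIKE", "Test%")]
    clauses, values = [], []
    for column, op, value in conditions:
        if op == "IN":
            placeholders = ", ".join(["%s"] * len(value))
            clauses.append(f"{column} IN ({placeholders})")
            values.extend(value)
        else:
            clauses.append(f"{column} {op} %s")
            values.append(value)
    print(" AND ".join(clauses), values)
    # id IN (%s, %s, %s) AND name LIKE %s [1, 2, 3, 'Test%']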
class PostgresDeleteBlock(Block):
    class Input(BlockSchema):
        credentials: PostgresCredentialsInput = PostgresCredentialsField()
        host: str = SchemaField(description="Database host", advanced=False)
        port: int = SchemaField(description="Database port", advanced=False)
        database: str = SchemaField(
            description="Database name", default="postgres", advanced=False
        )
        schema_: str = SchemaField(
            description="Schema name", default="public", advanced=False
        )
        table: str = SchemaField(description="Table name")
        command: CommandType = SchemaField(
            description="Command type to execute",
            default=CommandType.DELETE,
            advanced=False,
        )
        conditions: List[Condition] = SchemaField(
            description="Conditions for DELETE command",
            default=[],
            advanced=False,
        )
        combine_conditions: CombineCondition = SchemaField(
            description="How to combine multiple conditions",
            default=CombineCondition.AND,
            advanced=False,
        )
        restart_sequences: bool = SchemaField(
            description="Restart any auto-incrementing counters associated with the table after truncate",
            default=False,
        )
        cascade: bool = SchemaField(
            description="Automatically truncate or drop any tables that reference the target table via foreign keys; only used for TRUNCATE and DROP",
            default=False,
        )

    class Output(BlockSchema):
        success: bool = SchemaField(description="Operation succeeded")
        rows_affected: Optional[int] = SchemaField(
            description="Number of rows affected"
        )
        error: str = SchemaField(description="Error message if operation failed")

    def __init__(self):
        super().__init__(
            id="81b103ad-0fa9-47d3-a18f-2ea96579e3bb",
            description="Delete, truncate or drop data from a PostgreSQL table",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=PostgresDeleteBlock.Input,
            output_schema=PostgresDeleteBlock.Output,
            test_credentials=TEST_POSTGRES_CREDENTIALS,
            test_input={
                "credentials": TEST_POSTGRES_CREDENTIALS_INPUT,
                "host": "localhost",
                "port": 5432,
                "database": "test_db",
                "schema_": "public",
                "table": "users",
                "command": CommandType.DELETE,
                "conditions": [
                    {"column": "id", "operator": ConditionOperator.EQUALS, "value": 1}
                ],
            },
            test_output=[
                ("success", True),
                ("rows_affected", 1),
            ],
            test_mock={
                "run": lambda *args, **kwargs: [
                    ("success", True),
                    ("rows_affected", 1),
                ]
            },
        )

    def run(
        self, input_data: Input, *, credentials: PostgresCredentials, **kwargs
    ) -> BlockOutput:
        conn = None
        try:
            conn = psycopg2.connect(
                host=input_data.host,
                port=input_data.port,
                database=input_data.database,
                user=credentials.username.get_secret_value(),
                password=credentials.password.get_secret_value(),
            )

            with conn.cursor() as cursor:
                rows_affected = 0

                if input_data.command == CommandType.TRUNCATE:
                    sql = f"TRUNCATE TABLE {input_data.schema_}.{input_data.table}"
                    if input_data.restart_sequences:
                        sql += " RESTART IDENTITY"
                    if input_data.cascade:
                        sql += " CASCADE"
                    cursor.execute(sql)

                elif input_data.command == CommandType.DELETE:
                    if input_data.conditions:
                        where_clauses = []
                        values = []

                        for condition in input_data.conditions:
                            if condition.operator == ConditionOperator.IN:
                                placeholders = ", ".join(["%s"] * len(condition.value))
                                where_clauses.append(
                                    f"{condition.column} IN ({placeholders})"
                                )
                                values.extend(condition.value)
                            else:
                                where_clauses.append(
                                    f"{condition.column} {condition.operator.value} %s"
                                )
                                values.append(condition.value)

                        where_clause = f" {input_data.combine_conditions.value} ".join(
                            where_clauses
                        )
                        sql = (
                            f"DELETE FROM {input_data.schema_}.{input_data.table} "
                            f"WHERE {where_clause}"
                        )
                        cursor.execute(sql, values)
                    else:
                        sql = f"DELETE FROM {input_data.schema_}.{input_data.table}"
                        cursor.execute(sql)

                    rows_affected = cursor.rowcount

                elif input_data.command == CommandType.DROP:
                    sql = f"DROP TABLE {input_data.schema_}.{input_data.table}"
                    if input_data.cascade:
                        sql += " CASCADE"
                    cursor.execute(sql)

                conn.commit()
                yield "success", True
                yield "rows_affected", rows_affected

        except Exception as e:
            if conn:
                conn.rollback()
            yield "error", str(e)

        finally:
            if conn:
                conn.close()  # just for extra safety
class PostgresExecuteQueryBlock(Block):
    class Input(BlockSchema):
        credentials: PostgresCredentialsInput = PostgresCredentialsField()
        host: str = SchemaField(description="Database host", advanced=False)
        port: int = SchemaField(description="Database port", advanced=False)
        database: str = SchemaField(
            description="Database name", default="postgres", advanced=False
        )
        schema_: str = SchemaField(
            description="Schema name", default="public", advanced=False
        )
        query: str = SchemaField(description="SQL query to execute")
        parameters: List[Any] = SchemaField(
            description="Query parameters", default=[], advanced=False
        )

    class Output(BlockSchema):
        success: bool = SchemaField(description="Operation succeeded")
        result: Any = SchemaField(description="Query results or affected rows")
        error: str = SchemaField(description="Error message if operation failed")

    def __init__(self):
        super().__init__(
            id="c5d18dc8-ee3c-4366-ba99-a3996b7a4e78",
            description="Executes an SQL query on a PostgreSQL database.",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=PostgresExecuteQueryBlock.Input,
            output_schema=PostgresExecuteQueryBlock.Output,
            test_credentials=TEST_POSTGRES_CREDENTIALS,
            test_input={
                "credentials": TEST_POSTGRES_CREDENTIALS_INPUT,
                "host": "localhost",
                "port": 5432,
                "database": "test_db",
                "schema_": "public",
                "query": "SELECT * FROM users WHERE id = %s",
                "parameters": [1],
            },
            test_output=[
                ("success", True),
                ("result", [{"id": 1, "name": "Test User"}]),
            ],
            test_mock={
                "run": lambda *args, **kwargs: [
                    ("success", True),
                    ("result", [{"id": 1, "name": "Test User"}]),
                ]
            },
        )

    def run(
        self, input_data: Input, *, credentials: PostgresCredentials, **kwargs
    ) -> BlockOutput:
        conn = None
        try:
            conn = psycopg2.connect(
                host=input_data.host,
                port=input_data.port,
                database=input_data.database,
                user=credentials.username.get_secret_value(),
                password=credentials.password.get_secret_value(),
            )

            # RealDictCursor returns each row as a dict; the default cursor
            # returns plain tuples.
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(input_data.query, input_data.parameters)
                if cursor.description:
                    result = [dict(row) for row in cursor.fetchall()]
                else:
                    # Query doesn't return data (INSERT, UPDATE, DELETE);
                    # report the number of rows affected instead.
                    result = cursor.rowcount

                conn.commit()
                yield "success", True
                yield "result", result

        except Exception as e:
            if conn:
                conn.rollback()
            yield "error", str(e)

        finally:
            if conn:
                conn.close()
class PostgresInsertBlock(Block):
    class Input(BlockSchema):
        credentials: PostgresCredentialsInput = PostgresCredentialsField()
        host: str = SchemaField(description="Database host", advanced=False)
        port: int = SchemaField(description="Database port", advanced=False)
        database: str = SchemaField(
            description="Database name", default="postgres", advanced=False
        )
        schema_: str = SchemaField(
            description="Schema name", default="public", advanced=False
        )
        table: str = SchemaField(description="Table name")
        data: List[dict] = SchemaField(description="Data to insert", default=[])
        return_inserted_rows: bool = SchemaField(
            description="Return inserted rows", default=False
        )

    class Output(BlockSchema):
        success: bool = SchemaField(description="Operation succeeded")
        inserted_rows: List[dict] = SchemaField(
            description="Inserted rows if requested"
        )
        rows_affected: int = SchemaField(description="Number of rows affected")
        error: str = SchemaField(description="Error message if operation failed")

    def __init__(self):
        super().__init__(
            id="82a6c2d5-4c6f-4e3a-aba2-feae15c03cbe",
            description="Inserts rows into a PostgreSQL table",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=PostgresInsertBlock.Input,
            output_schema=PostgresInsertBlock.Output,
            test_credentials=TEST_POSTGRES_CREDENTIALS,
            test_input={
                "credentials": TEST_POSTGRES_CREDENTIALS_INPUT,
                "host": "localhost",
                "port": 5432,
                "database": "test_db",
                "schema_": "public",
                "table": "users",
                "data": [{"name": "Test User", "email": "test@example.com"}],
                "return_inserted_rows": True,
            },
            test_output=[
                ("success", True),
                ("rows_affected", 1),
                ("inserted_rows", [{"id": 1, "name": "Test User", "email": "test@example.com"}]),
            ],
            test_mock={
                "run": lambda *args, **kwargs: [
                    ("success", True),
                    ("rows_affected", 1),
                    ("inserted_rows", [{"id": 1, "name": "Test User", "email": "test@example.com"}]),
                ]
            },
        )

    def run(
        self, input_data: Input, *, credentials: PostgresCredentials, **kwargs
    ) -> BlockOutput:
        conn = None
        try:
            conn = psycopg2.connect(
                host=input_data.host,
                port=input_data.port,
                database=input_data.database,
                user=credentials.username.get_secret_value(),
                password=credentials.password.get_secret_value(),
            )

            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                if not input_data.data:
                    yield "success", True
                    yield "rows_affected", 0
                    yield "inserted_rows", []
                    return

                columns = list(input_data.data[0].keys())
                cols_str = ", ".join(columns)
                placeholders = ", ".join(["%s"] * len(columns))
                sql = (
                    f"INSERT INTO {input_data.schema_}.{input_data.table} "
                    f"({cols_str}) VALUES ({placeholders})"
                )

                if input_data.return_inserted_rows:
                    sql += " RETURNING *"

                inserted_rows = []
                rows_affected = 0

                for row in input_data.data:
                    values = [row[col] for col in columns]
                    cursor.execute(sql, values)
                    rows_affected += cursor.rowcount

                    if input_data.return_inserted_rows:
                        inserted_rows.extend(dict(r) for r in cursor.fetchall())

                conn.commit()
                yield "success", True
                yield "rows_affected", rows_affected
                yield "inserted_rows", inserted_rows

        except Exception as e:
            if conn:
                conn.rollback()
            yield "success", False
            yield "error", str(e)

        finally:
            if conn:
                conn.close()
class PostgresInsertOrUpdateBlock(Block):
    class Input(BlockSchema):
        credentials: PostgresCredentialsInput = PostgresCredentialsField()
        host: str = SchemaField(description="Database host", advanced=False)
        port: int = SchemaField(description="Database port", advanced=False)
        database: str = SchemaField(
            description="Database name", default="postgres", advanced=False
        )
        schema_: str = SchemaField(
            description="Schema name", default="public", advanced=False
        )
        table: str = SchemaField(description="Table name")
        data: List[dict] = SchemaField(
            description="Data to insert or update", default=[]
        )
        key_columns: List[str] = SchemaField(
            description="Columns to use as unique constraint", default=[]
        )
        return_affected_rows: bool = SchemaField(
            description="Return affected rows", default=False
        )

    class Output(BlockSchema):
        success: bool = SchemaField(description="Operation succeeded")
        affected_rows: List[dict] = SchemaField(
            description="Affected rows if requested"
        )
        rows_affected: int = SchemaField(description="Number of rows affected")
        error: str = SchemaField(description="Error message if operation failed")

    def __init__(self):
        super().__init__(
            id="fa8e0ce3-5b8c-49e2-a3b7-dca21f5c4a72",
            description="Inserts or updates rows in a PostgreSQL table using ON CONFLICT",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=PostgresInsertOrUpdateBlock.Input,
            output_schema=PostgresInsertOrUpdateBlock.Output,
            test_credentials=TEST_POSTGRES_CREDENTIALS,
            test_input={
                "credentials": TEST_POSTGRES_CREDENTIALS_INPUT,
                "host": "localhost",
                "port": 5432,
                "database": "test_db",
                "schema_": "public",
                "table": "users",
                "data": [{"id": 1, "name": "Updated User", "email": "updated@example.com"}],
                "key_columns": ["id"],
                "return_affected_rows": True,
            },
            test_output=[
                ("success", True),
                ("rows_affected", 1),
                ("affected_rows", [{"id": 1, "name": "Updated User", "email": "updated@example.com"}]),
            ],
            test_mock={
                "run": lambda *args, **kwargs: [
                    ("success", True),
                    ("rows_affected", 1),
                    ("affected_rows", [{"id": 1, "name": "Updated User", "email": "updated@example.com"}]),
                ]
            },
        )

    def run(
        self, input_data: Input, *, credentials: PostgresCredentials, **kwargs
    ) -> BlockOutput:
        conn = None
        try:
            conn = psycopg2.connect(
                host=input_data.host,
                port=input_data.port,
                database=input_data.database,
                user=credentials.username.get_secret_value(),
                password=credentials.password.get_secret_value(),
            )

            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                if not input_data.data or not input_data.key_columns:
                    yield "success", True
                    yield "rows_affected", 0
                    yield "affected_rows", []
                    return

                affected_rows = []
                rows_affected = 0

                for row in input_data.data:
                    columns = list(row.keys())
                    cols_str = ", ".join(columns)
                    placeholders = ", ".join(["%s"] * len(columns))
                    conflict_cols = ", ".join(input_data.key_columns)
                    update_cols = ", ".join(
                        f"{col} = EXCLUDED.{col}"
                        for col in columns
                        if col not in input_data.key_columns
                    )

                    sql = (
                        f"INSERT INTO {input_data.schema_}.{input_data.table} ({cols_str}) "
                        f"VALUES ({placeholders}) ON CONFLICT ({conflict_cols}) DO UPDATE SET {update_cols}"
                    )

                    if input_data.return_affected_rows:
                        sql += " RETURNING *"

                    values = [row[col] for col in columns]
                    cursor.execute(sql, values)
                    rows_affected += cursor.rowcount

                    if input_data.return_affected_rows:
                        affected_rows.extend(dict(r) for r in cursor.fetchall())

                conn.commit()
                yield "success", True
                yield "rows_affected", rows_affected
                yield "affected_rows", affected_rows

        except Exception as e:
            if conn:
                conn.rollback()
            yield "success", False
            yield "error", str(e)

        finally:
            if conn:
                conn.close()
class PostgresSelectBlock(Block):
    class Input(BlockSchema):
        credentials: PostgresCredentialsInput = PostgresCredentialsField()
        host: str = SchemaField(description="Database host", advanced=False)
        port: int = SchemaField(description="Database port", advanced=False)
        database: str = SchemaField(
            description="Database name", default="postgres", advanced=False
        )
        schema_: str = SchemaField(
            description="Schema name", default="public", advanced=False
        )
        table: str = SchemaField(description="Table name")
        columns: List[str] = SchemaField(
            description="Columns to select (empty for all columns)", default=[]
        )
        conditions: List[Condition] = SchemaField(
            description="Conditions for WHERE clause", default=[], advanced=False
        )
        combine_conditions: CombineCondition = SchemaField(
            description="How to combine multiple conditions",
            default=CombineCondition.AND,
            advanced=False,
        )
        limit: Optional[int] = SchemaField(
            description="Maximum number of rows to return", default=None
        )

    class Output(BlockSchema):
        success: bool = SchemaField(description="Operation succeeded")
        rows: List[dict] = SchemaField(description="Selected rows")
        error: str = SchemaField(description="Error message if operation failed")

    def __init__(self):
        super().__init__(
            id="e7c92ea5-1d2a-4e9c-bb89-376dfcbea342",
            description="Selects rows from a PostgreSQL table",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=PostgresSelectBlock.Input,
            output_schema=PostgresSelectBlock.Output,
            test_credentials=TEST_POSTGRES_CREDENTIALS,
            test_input={
                "credentials": TEST_POSTGRES_CREDENTIALS_INPUT,
                "host": "localhost",
                "port": 5432,
                "database": "test_db",
                "schema_": "public",
                "table": "users",
                "columns": ["id", "name", "email"],
                "conditions": [
                    {"column": "id", "operator": ConditionOperator.GREATER_THAN, "value": 0}
                ],
                "limit": 100,
            },
            test_output=[
                ("success", True),
                ("rows", [{"id": 1, "name": "Test User", "email": "test@example.com"}]),
            ],
            test_mock={
                "run": lambda *args, **kwargs: [
                    ("success", True),
                    ("rows", [{"id": 1, "name": "Test User", "email": "test@example.com"}]),
                ]
            },
        )

    def run(
        self, input_data: Input, *, credentials: PostgresCredentials, **kwargs
    ) -> BlockOutput:
        conn = None
        try:
            conn = psycopg2.connect(
                host=input_data.host,
                port=input_data.port,
                database=input_data.database,
                user=credentials.username.get_secret_value(),
                password=credentials.password.get_secret_value(),
            )

            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cols = ", ".join(input_data.columns) if input_data.columns else "*"
                sql = f"SELECT {cols} FROM {input_data.schema_}.{input_data.table}"

                values = []
                if input_data.conditions:
                    where_clauses = []

                    for condition in input_data.conditions:
                        if condition.operator == ConditionOperator.IN:
                            placeholders = ", ".join(["%s"] * len(condition.value))
                            where_clauses.append(
                                f"{condition.column} IN ({placeholders})"
                            )
                            values.extend(condition.value)
                        else:
                            where_clauses.append(
                                f"{condition.column} {condition.operator.value} %s"
                            )
                            values.append(condition.value)

                    where_clause = f" {input_data.combine_conditions.value} ".join(
                        where_clauses
                    )
                    sql += f" WHERE {where_clause}"

                if input_data.limit is not None:
                    sql += f" LIMIT {input_data.limit}"

                cursor.execute(sql, values)
                rows = [dict(row) for row in cursor.fetchall()]

                yield "success", True
                yield "rows", rows

        except Exception as e:
            if conn:
                conn.rollback()
            yield "success", False
            yield "error", str(e)

        finally:
            if conn:
                conn.close()
class PostgresUpdateBlock(Block):
    class Input(BlockSchema):
        credentials: PostgresCredentialsInput = PostgresCredentialsField()
        host: str = SchemaField(description="Database host", advanced=False)
        port: int = SchemaField(description="Database port", advanced=False)
        database: str = SchemaField(
            description="Database name", default="postgres", advanced=False
        )
        schema_: str = SchemaField(
            description="Schema name", default="public", advanced=False
        )
        table: str = SchemaField(description="Table name")
        set_data: dict = SchemaField(
            description="Column-value pairs to update", default={}
        )
        conditions: List[Condition] = SchemaField(
            description="Conditions for WHERE clause", default=[], advanced=False
        )
        combine_conditions: CombineCondition = SchemaField(
            description="How to combine multiple conditions",
            default=CombineCondition.AND,
            advanced=False,
        )
        return_updated_rows: bool = SchemaField(
            description="Return updated rows", default=False
        )

    class Output(BlockSchema):
        success: bool = SchemaField(description="Operation succeeded")
        rows_affected: int = SchemaField(description="Number of rows affected")
        updated_rows: List[dict] = SchemaField(description="Updated rows if requested")
        error: str = SchemaField(description="Error message if operation failed")

    def __init__(self):
        super().__init__(
            id="a4e3d8c2-7f1b-49d0-8bc6-e479ea3d5752",
            description="Updates rows in a PostgreSQL table",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=PostgresUpdateBlock.Input,
            output_schema=PostgresUpdateBlock.Output,
            test_credentials=TEST_POSTGRES_CREDENTIALS,
            test_input={
                "credentials": TEST_POSTGRES_CREDENTIALS_INPUT,
                "host": "localhost",
                "port": 5432,
                "database": "test_db",
                "schema_": "public",
                "table": "users",
                "set_data": {"name": "Updated User", "email": "updated@example.com"},
                "conditions": [
                    {"column": "id", "operator": ConditionOperator.EQUALS, "value": 1}
                ],
                "return_updated_rows": True,
            },
            test_output=[
                ("success", True),
                ("rows_affected", 1),
                ("updated_rows", [{"id": 1, "name": "Updated User", "email": "updated@example.com"}]),
            ],
            test_mock={
                "run": lambda *args, **kwargs: [
                    ("success", True),
                    ("rows_affected", 1),
                    ("updated_rows", [{"id": 1, "name": "Updated User", "email": "updated@example.com"}]),
                ]
            },
        )

    def run(
        self, input_data: Input, *, credentials: PostgresCredentials, **kwargs
    ) -> BlockOutput:
        conn = None
        try:
            conn = psycopg2.connect(
                host=input_data.host,
                port=input_data.port,
                database=input_data.database,
                user=credentials.username.get_secret_value(),
                password=credentials.password.get_secret_value(),
            )

            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                if not input_data.set_data:
                    yield "success", True
                    yield "rows_affected", 0
                    yield "updated_rows", []
                    return

                set_clause = ", ".join(f"{k} = %s" for k in input_data.set_data.keys())
                sql = f"UPDATE {input_data.schema_}.{input_data.table} SET {set_clause}"

                values = list(input_data.set_data.values())

                if input_data.conditions:
                    where_clauses = []

                    for condition in input_data.conditions:
                        if condition.operator == ConditionOperator.IN:
                            placeholders = ", ".join(["%s"] * len(condition.value))
                            where_clauses.append(
                                f"{condition.column} IN ({placeholders})"
                            )
                            values.extend(condition.value)
                        else:
                            where_clauses.append(
                                f"{condition.column} {condition.operator.value} %s"
                            )
                            values.append(condition.value)

                    where_clause = f" {input_data.combine_conditions.value} ".join(
                        where_clauses
                    )
                    sql += f" WHERE {where_clause}"

                if input_data.return_updated_rows:
                    sql += " RETURNING *"

                cursor.execute(sql, values)
                rows_affected = cursor.rowcount

                updated_rows = []
                if input_data.return_updated_rows:
                    updated_rows = [dict(row) for row in cursor.fetchall()]

                conn.commit()
                yield "success", True
                yield "rows_affected", rows_affected
                yield "updated_rows", updated_rows

        except Exception as e:
            if conn:
                conn.rollback()
            yield "success", False
            yield "error", str(e)

        finally:
            if conn:
                conn.close()
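One caveat on the blocks above: values are parameterized via %s, but the schema, table, and column names are interpolated into the SQL string directly, so they must come from trusted input. If identifier escaping were wanted, psycopg2's sql module composes it safely; a sketch of the alternative (not what the blocks above do):

    from psycopg2 import sql

    query = sql.SQL("DELETE FROM {}.{} WHERE {} = %s").format(
        sql.Identifier("public"),  # schema, rendered double-quoted
        sql.Identifier("users"),   # table
        sql.Identifier("id"),      # column
    )
    # cursor.execute(query, [1]) executes:
    #   DELETE FROM "public"."users" WHERE "id" = %s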
@@ -26,10 +26,10 @@ logger = logging.getLogger(__name__)

 @thread_cached
 def get_database_manager_client():
-    from backend.executor import DatabaseManager
+    from backend.executor import DatabaseManagerClient
     from backend.util.service import get_service_client

-    return get_service_client(DatabaseManager)
+    return get_service_client(DatabaseManagerClient)


 def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
@@ -246,6 +246,10 @@ class SmartDecisionMakerBlock(Block):
             test_credentials=llm.TEST_CREDENTIALS,
         )

+    @staticmethod
+    def cleanup(s: str):
+        return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()
+
     @staticmethod
     def _create_block_function_signature(
         sink_node: "Node", links: list["Link"]
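The helper centralizes the sanitization previously inlined at each call site: every character outside [a-zA-Z0-9_-] becomes an underscore, then the result is lowercased, e.g.:

    import re

    def cleanup(s: str):  # same body as the new static method
        return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()

    print(cleanup("Get Web Page!"))  # -> get_web_page_
    print(cleanup("my-tool_v2"))     # -> my-tool_v2 (already clean)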
@@ -266,7 +270,7 @@ class SmartDecisionMakerBlock(Block):
         block = sink_node.block

         tool_function: dict[str, Any] = {
-            "name": re.sub(r"[^a-zA-Z0-9_-]", "_", block.name).lower(),
+            "name": SmartDecisionMakerBlock.cleanup(block.name),
             "description": block.description,
         }
@@ -281,7 +285,7 @@ class SmartDecisionMakerBlock(Block):
             and sink_block_input_schema.model_fields[link.sink_name].description
             else f"The {link.sink_name} of the tool"
         )
-        properties[link.sink_name.lower()] = {
+        properties[SmartDecisionMakerBlock.cleanup(link.sink_name)] = {
             "type": "string",
             "description": description,
         }
@@ -326,7 +330,7 @@ class SmartDecisionMakerBlock(Block):
         )

         tool_function: dict[str, Any] = {
-            "name": re.sub(r"[^a-zA-Z0-9_-]", "_", sink_graph_meta.name).lower(),
+            "name": SmartDecisionMakerBlock.cleanup(sink_graph_meta.name),
             "description": sink_graph_meta.description,
         }
@@ -341,7 +345,7 @@ class SmartDecisionMakerBlock(Block):
             in sink_block_input_schema["properties"][link.sink_name]
             else f"The {link.sink_name} of the tool"
         )
-        properties[link.sink_name.lower()] = {
+        properties[SmartDecisionMakerBlock.cleanup(link.sink_name)] = {
             "type": "string",
             "description": description,
         }
@@ -491,6 +495,7 @@ class SmartDecisionMakerBlock(Block):
             max_tokens=input_data.max_tokens,
             tools=tool_functions,
             ollama_host=input_data.ollama_host,
+            parallel_tool_calls=False,
         )

         if not response.tool_calls:
@@ -502,7 +507,7 @@ class SmartDecisionMakerBlock(Block):
             tool_args = json.loads(tool_call.function.arguments)

             for arg_name, arg_value in tool_args.items():
-                yield f"tools_^_{tool_name}_{arg_name}".lower(), arg_value
+                yield f"tools_^_{tool_name}_~_{arg_name}", arg_value

         response.prompt.append(response.raw_response)
         yield "conversations", response.prompt
@@ -28,6 +28,7 @@ from backend.util.settings import Config
 from .model import (
     ContributorDetails,
     Credentials,
+    CredentialsFieldInfo,
     CredentialsMetaInput,
     is_credentials_field_name,
 )
@@ -203,6 +204,15 @@ class BlockSchema(BaseModel):
             )
         }

+    @classmethod
+    def get_credentials_fields_info(cls) -> dict[str, CredentialsFieldInfo]:
+        return {
+            field_name: CredentialsFieldInfo.model_validate(
+                cls.get_field_schema(field_name), by_alias=True
+            )
+            for field_name in cls.get_credentials_fields().keys()
+        }
+
     @classmethod
     def get_input_defaults(cls, data: BlockInput) -> BlockInput:
         return data  # Return as is, by default.
@@ -509,6 +519,7 @@ async def initialize_blocks() -> None:
     )


-def get_block(block_id: str) -> Block | None:
+# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
+def get_block(block_id: str) -> Block[BlockSchema, BlockSchema] | None:
     cls = get_blocks().get(block_id)
     return cls() if cls else None
@@ -36,14 +36,17 @@ from backend.integrations.credentials_store import (
 # =============== Configure the cost for each LLM Model call =============== #

 MODEL_COST: dict[LlmModel, int] = {
+    LlmModel.O3: 7,
     LlmModel.O3_MINI: 2,  # $1.10 / $4.40
     LlmModel.O1: 16,  # $15 / $60
     LlmModel.O1_PREVIEW: 16,
     LlmModel.O1_MINI: 4,
+    LlmModel.GPT41: 2,
     LlmModel.GPT4O_MINI: 1,
     LlmModel.GPT4O: 3,
     LlmModel.GPT4_TURBO: 10,
     LlmModel.GPT3_5_TURBO: 1,
+    LlmModel.CLAUDE_3_7_SONNET: 5,
     LlmModel.CLAUDE_3_5_SONNET: 4,
     LlmModel.CLAUDE_3_5_HAIKU: 1,  # $0.80 / $4.00
     LlmModel.CLAUDE_3_HAIKU: 1,
@@ -60,6 +63,7 @@ MODEL_COST: dict[LlmModel, int] = {
     LlmModel.DEEPSEEK_LLAMA_70B: 1,  # ? / ?
     LlmModel.OLLAMA_DOLPHIN: 1,
     LlmModel.GEMINI_FLASH_1_5: 1,
+    LlmModel.GEMINI_2_5_PRO: 4,
     LlmModel.GROK_BETA: 5,
     LlmModel.MISTRAL_NEMO: 1,
     LlmModel.COHERE_COMMAND_R_08_2024: 1,
@@ -1,8 +1,8 @@
 import asyncio
 import logging
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from datetime import datetime, timezone
 from typing import Any, cast

 import stripe
 from autogpt_libs.utils.cache import thread_cached
@@ -11,6 +11,7 @@ from prisma.enums import (
     CreditRefundRequestStatus,
     CreditTransactionType,
     NotificationType,
+    OnboardingStep,
 )
 from prisma.errors import UniqueViolationError
 from prisma.models import CreditRefundRequest, CreditTransaction, User
@@ -19,7 +20,7 @@ from prisma.types import (
     CreditTransactionCreateInput,
     CreditTransactionWhereInput,
 )
-from tenacity import retry, stop_after_attempt, wait_exponential
+from pydantic import BaseModel

 from backend.data import db
 from backend.data.block_cost_config import BLOCK_COSTS
@@ -27,14 +28,17 @@ from backend.data.cost import BlockCost
 from backend.data.model import (
     AutoTopUpConfig,
     RefundRequest,
+    TopUpType,
     TransactionHistory,
     UserTransaction,
 )
 from backend.data.notifications import NotificationEventDTO, RefundRequestData
-from backend.data.user import get_user_by_id
-from backend.executor.utils import UsageTransactionMetadata
-from backend.notifications import NotificationManager
+from backend.data.user import get_user_by_id, get_user_email_by_id
+from backend.notifications import NotificationManagerClient
 from backend.server.model import Pagination
 from backend.server.v2.admin.model import UserHistoryResponse
 from backend.util.exceptions import InsufficientBalanceError
+from backend.util.retry import func_retry
 from backend.util.service import get_service_client
 from backend.util.settings import Settings
@@ -44,6 +48,17 @@ logger = logging.getLogger(__name__)
 base_url = settings.config.frontend_base_url or settings.config.platform_base_url


+class UsageTransactionMetadata(BaseModel):
+    graph_exec_id: str | None = None
+    graph_id: str | None = None
+    node_id: str | None = None
+    node_exec_id: str | None = None
+    block_id: str | None = None
+    block: str | None = None
+    input: dict[str, Any] | None = None
+    reason: str | None = None
+
+
 class UserCreditBase(ABC):
     @abstractmethod
     async def get_credits(self, user_id: str) -> int:
@@ -121,6 +136,18 @@ class UserCreditBase(ABC):
         """
         pass

+    @abstractmethod
+    async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
+        """
+        Reward the user with credits for completing an onboarding step.
+        Won't reward if the user has already received credits for the step.
+
+        Args:
+            user_id (str): The user ID.
+            step (OnboardingStep): The onboarding step.
+        """
+        pass
+
     @abstractmethod
     async def top_up_intent(self, user_id: str, amount: int) -> str:
         """
@@ -249,11 +276,7 @@ class UserCreditBase(ABC):
         )
         return transaction_balance, transaction_time

-    @retry(
-        stop=stop_after_attempt(5),
-        wait=wait_exponential(multiplier=1, min=1, max=10),
-        reraise=True,
-    )
+    @func_retry
     async def _enable_transaction(
         self,
         transaction_key: str,
@@ -352,21 +375,19 @@ class UserCreditBase(ABC):

 class UserCredit(UserCreditBase):
     @thread_cached
-    def notification_client(self) -> NotificationManager:
-        return get_service_client(NotificationManager)
+    def notification_client(self) -> NotificationManagerClient:
+        return get_service_client(NotificationManagerClient)

     async def _send_refund_notification(
         self,
         notification_request: RefundRequestData,
         notification_type: NotificationType,
     ):
-        await asyncio.to_thread(
-            lambda: self.notification_client().queue_notification(
-                NotificationEventDTO(
-                    user_id=notification_request.user_id,
-                    type=notification_type,
-                    data=notification_request.model_dump(),
-                )
+        await self.notification_client().queue_notification_async(
+            NotificationEventDTO(
+                user_id=notification_request.user_id,
+                type=notification_type,
+                data=notification_request.model_dump(),
             )
         )
@@ -396,6 +417,7 @@ class UserCredit(UserCreditBase):
|
||||
# Avoid multiple auto top-ups within the same graph execution.
|
||||
key=f"AUTO-TOP-UP-{user_id}-{metadata.graph_exec_id}",
|
||||
ceiling_balance=auto_top_up.threshold,
|
||||
top_up_type=TopUpType.AUTO,
|
||||
)
|
||||
except Exception as e:
|
||||
# Failed top-up is not critical, we can move on.
|
||||
@@ -405,8 +427,30 @@ class UserCredit(UserCreditBase):
|
||||
|
||||
return balance
|
||||
|
||||
async def top_up_credits(self, user_id: str, amount: int):
|
||||
await self._top_up_credits(user_id, amount)
|
||||
async def top_up_credits(
|
||||
self,
|
||||
user_id: str,
|
||||
amount: int,
|
||||
top_up_type: TopUpType = TopUpType.UNCATEGORIZED,
|
||||
):
|
||||
await self._top_up_credits(
|
||||
user_id=user_id, amount=amount, top_up_type=top_up_type
|
||||
)
|
||||
|
||||
async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
|
||||
try:
|
||||
await self._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=credits,
|
||||
transaction_type=CreditTransactionType.GRANT,
|
||||
transaction_key=f"REWARD-{user_id}-{step.value}",
|
||||
metadata=Json(
|
||||
{"reason": f"Reward for completing {step.value} onboarding step."}
|
||||
),
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# Already rewarded for this step
|
||||
pass
|
||||
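The REWARD-{user_id}-{step} transaction key doubles as an idempotency guard: re-running the reward for the same step violates the unique constraint on the transaction key and is swallowed. A minimal usage sketch (reward value taken from the onboarding rewards elsewhere in this diff):

# Safe to call twice; the second call hits UniqueViolationError and is ignored.
await user_credit.onboarding_reward(user_id, 300, OnboardingStep.AGENT_NEW_RUN)
await user_credit.onboarding_reward(user_id, 300, OnboardingStep.AGENT_NEW_RUN)
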
    async def top_up_refund(
        self, user_id: str, transaction_key: str, metadata: dict[str, str]
@@ -571,7 +615,7 @@ class UserCredit(UserCreditBase):

            evidence_text += (
                f"- {tx.description}: Amount ${tx.amount / 100:.2f} on {tx.transaction_time.isoformat()}, "
                f"resulting balance ${tx.balance / 100:.2f} {additional_comment}\n"
                f"resulting balance ${tx.running_balance / 100:.2f} {additional_comment}\n"
            )
        evidence_text += (
            "\nThis evidence demonstrates that the transaction was authorized and that the charged amount was used to render the service as agreed."
@@ -590,7 +634,24 @@ class UserCredit(UserCreditBase):
        amount: int,
        key: str | None = None,
        ceiling_balance: int | None = None,
        top_up_type: TopUpType = TopUpType.UNCATEGORIZED,
        metadata: dict | None = None,
    ):
        # Initialize metadata locally so callers don't share a mutable default.
        metadata = metadata or {}
        if not metadata.get("reason"):
            match top_up_type:
                case TopUpType.MANUAL:
                    metadata["reason"] = f"Top up credits for {user_id}"
                case TopUpType.AUTO:
                    metadata["reason"] = f"Auto top up credits for {user_id}"
                case _:
                    metadata["reason"] = f"Top up reason unknown for {user_id}"

        if amount < 0:
            raise ValueError(f"Top up amount must not be negative: {amount}")

@@ -613,6 +674,7 @@ class UserCredit(UserCreditBase):
            is_active=False,
            transaction_key=key,
            ceiling_balance=ceiling_balance,
            metadata=(Json(metadata)),
        )

        customer_id = await get_stripe_customer_id(user_id)
@@ -755,10 +817,15 @@ class UserCredit(UserCreditBase):
        # Check the Checkout Session's payment_status property
        # to determine if fulfillment should be performed
        if checkout_session.payment_status in ["paid", "no_payment_required"]:
            assert isinstance(checkout_session.payment_intent, stripe.PaymentIntent)
            if payment_intent := checkout_session.payment_intent:
                assert isinstance(payment_intent, stripe.PaymentIntent)
                new_transaction_key = payment_intent.id
            else:
                new_transaction_key = None

            await self._enable_transaction(
                transaction_key=credit_transaction.transactionKey,
                new_transaction_key=checkout_session.payment_intent.id,
                new_transaction_key=new_transaction_key,
                user_id=credit_transaction.userId,
                metadata=Json(checkout_session),
            )
@@ -791,8 +858,9 @@ class UserCredit(UserCreditBase):
            take=transaction_count_limit,
        )

        # doesn't fill current_balance, reason, user_email, admin_email, or extra_data
        grouped_transactions: dict[str, UserTransaction] = defaultdict(
            lambda: UserTransaction()
            lambda: UserTransaction(user_id=user_id)
        )
        tx_time = None
        for t in transactions:
@@ -822,7 +890,7 @@ class UserCredit(UserCreditBase):

            if tx_time > gt.transaction_time:
                gt.transaction_time = tx_time
                gt.balance = t.runningBalance or 0
                gt.running_balance = t.runningBalance or 0

        return TransactionHistory(
            transactions=list(grouped_transactions.values()),
@@ -872,6 +940,7 @@ class BetaUserCredit(UserCredit):
                amount=max(self.num_user_credits_refill - balance, 0),
                transaction_type=CreditTransactionType.GRANT,
                transaction_key=f"MONTHLY-CREDIT-TOP-UP-{cur_time}",
                metadata=Json({"reason": "Monthly credit refill"}),
            )
            return balance
        except UniqueViolationError:
@@ -881,7 +950,7 @@ class BetaUserCredit(UserCredit):

class DisabledUserCredit(UserCreditBase):
    async def get_credits(self, *args, **kwargs) -> int:
        return 0
        return 100

    async def get_transaction_history(self, *args, **kwargs) -> TransactionHistory:
        return TransactionHistory(transactions=[], next_transaction_time=None)
@@ -895,6 +964,9 @@ class DisabledUserCredit(UserCreditBase):
    async def top_up_credits(self, *args, **kwargs):
        pass

    async def onboarding_reward(self, *args, **kwargs):
        pass

    async def top_up_intent(self, *args, **kwargs) -> str:
        return ""

@@ -956,3 +1028,81 @@ async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
        return AutoTopUpConfig(threshold=0, amount=0)

    return AutoTopUpConfig.model_validate(user.topUpConfig)

async def admin_get_user_history(
    page: int = 1,
    page_size: int = 20,
    search: str | None = None,
    transaction_filter: CreditTransactionType | None = None,
) -> UserHistoryResponse:

    if page < 1 or page_size < 1:
        raise ValueError("Invalid pagination input")

    where_clause: CreditTransactionWhereInput = {}
    if transaction_filter:
        where_clause["type"] = transaction_filter
    if search:
        where_clause["OR"] = [
            {"userId": {"contains": search, "mode": "insensitive"}},
            {"User": {"is": {"email": {"contains": search, "mode": "insensitive"}}}},
            {"User": {"is": {"name": {"contains": search, "mode": "insensitive"}}}},
        ]
    transactions = await CreditTransaction.prisma().find_many(
        where=where_clause,
        skip=(page - 1) * page_size,
        take=page_size,
        include={"User": True},
        order={"createdAt": "desc"},
    )
    total = await CreditTransaction.prisma().count(where=where_clause)
    total_pages = (total + page_size - 1) // page_size

    history = []
    for tx in transactions:
        admin_id = ""
        admin_email = ""
        reason = ""

        metadata: dict = cast(dict, tx.metadata) or {}

        if metadata:
            admin_id = metadata.get("admin_id")
            admin_email = (
                (await get_user_email_by_id(admin_id) or f"Unknown Admin: {admin_id}")
                if admin_id
                else ""
            )
            reason = metadata.get("reason", "No reason provided")

        balance, last_update = await get_user_credit_model()._get_credits(tx.userId)

        history.append(
            UserTransaction(
                transaction_key=tx.transactionKey,
                transaction_time=tx.createdAt,
                transaction_type=tx.type,
                amount=tx.amount,
                current_balance=balance,
                running_balance=tx.runningBalance or 0,
                user_id=tx.userId,
                user_email=(
                    tx.User.email
                    if tx.User
                    else (await get_user_by_id(tx.userId)).email
                ),
                reason=reason,
                admin_email=admin_email,
                extra_data=str(metadata),
            )
        )
    return UserHistoryResponse(
        history=history,
        pagination=Pagination(
            total_items=total,
            total_pages=total_pages,
            current_page=page,
            page_size=page_size,
        ),
    )

@@ -62,10 +62,10 @@ async def connect():

    # A connection acquired from a pooler like Supabase's can still be handed
    # to the client successfully and then reject queries afterward.
    try:
        await prisma.execute_raw("SELECT 1")
    except Exception as e:
        raise ConnectionError("Failed to connect to Prisma.") from e
    # try:
    #     await prisma.execute_raw("SELECT 1")
    # except Exception as e:
    #     raise ConnectionError("Failed to connect to Prisma.") from e


@conn_retry("Prisma", "Releasing connection")

@@ -34,18 +34,17 @@ from pydantic import BaseModel
from pydantic.fields import Field

from backend.server.v2.store.exceptions import DatabaseError
from backend.util import mock
from backend.util import type as type_utils
from backend.util.settings import Config

from .block import BlockData, BlockInput, BlockType, CompletedBlockOutput, get_block
from .block import BlockInput, BlockType, CompletedBlockOutput, get_block
from .db import BaseDbModel
from .includes import (
    EXECUTION_RESULT_INCLUDE,
    GRAPH_EXECUTION_INCLUDE,
    GRAPH_EXECUTION_INCLUDE_WITH_NODES,
)
from .model import GraphExecutionStats, NodeExecutionStats
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
from .queue import AsyncRedisEventBus, RedisEventBus

T = TypeVar("T")
@@ -203,6 +202,15 @@ class GraphExecutionWithNodes(GraphExecution):
            node_executions=node_executions,
        )

    def to_graph_execution_entry(self):
        return GraphExecutionEntry(
            user_id=self.user_id,
            graph_id=self.graph_id,
            graph_version=self.graph_version or 0,
            graph_exec_id=self.id,
            node_credentials_input_map={},  # FIXME
        )


class NodeExecutionResult(BaseModel):
    user_id: str
@@ -260,6 +268,17 @@ class NodeExecutionResult(BaseModel):
            end_time=_node_exec.endedTime,
        )

    def to_node_execution_entry(self) -> "NodeExecutionEntry":
        return NodeExecutionEntry(
            user_id=self.user_id,
            graph_exec_id=self.graph_exec_id,
            graph_id=self.graph_id,
            node_exec_id=self.node_exec_id,
            node_id=self.node_id,
            block_id=self.block_id,
            data=self.input_data,
        )


# --------------------- Model functions --------------------- #

@@ -342,7 +361,7 @@ async def get_graph_execution(
async def create_graph_execution(
    graph_id: str,
    graph_version: int,
    nodes_input: list[tuple[str, BlockInput]],
    starting_nodes_input: list[tuple[str, BlockInput]],
    user_id: str,
    preset_id: str | None = None,
) -> GraphExecutionWithNodes:
@@ -369,7 +388,7 @@ async def create_graph_execution(
                        ]
                    },
                )
                for node_id, node_input in nodes_input
                for node_id, node_input in starting_nodes_input
            ]
        },
        userId=user_id,
@@ -469,7 +488,9 @@ async def upsert_execution_output(
    )


async def update_graph_execution_start_time(graph_exec_id: str) -> GraphExecution:
async def update_graph_execution_start_time(
    graph_exec_id: str,
) -> GraphExecution | None:
    res = await AgentGraphExecution.prisma().update(
        where={"id": graph_exec_id},
        data={
@@ -478,10 +499,7 @@ async def update_graph_execution_start_time(graph_exec_id: str) -> GraphExecutio
        },
        include=GRAPH_EXECUTION_INCLUDE,
    )
    if not res:
        raise ValueError(f"Graph execution #{graph_exec_id} not found")

    return GraphExecution.from_db(res)
    return GraphExecution.from_db(res) if res else None


async def update_graph_execution_stats(
@@ -597,8 +615,9 @@ async def delete_graph_execution(
    )


async def get_node_execution_results(
async def get_node_executions(
    graph_exec_id: str,
    node_id: str | None = None,
    block_ids: list[str] | None = None,
    statuses: list[ExecutionStatus] | None = None,
    limit: int | None = None,
@@ -606,6 +625,8 @@ async def get_node_execution_results(
    where_clause: AgentNodeExecutionWhereInput = {
        "agentGraphExecutionId": graph_exec_id,
    }
    if node_id:
        where_clause["agentNodeId"] = node_id
    if block_ids:
        where_clause["Node"] = {"is": {"agentBlockId": {"in": block_ids}}}
    if statuses:
@@ -662,20 +683,6 @@ async def get_latest_node_execution(
    return NodeExecutionResult.from_db(execution)


async def get_incomplete_node_executions(
    node_id: str, graph_eid: str
) -> list[NodeExecutionResult]:
    executions = await AgentNodeExecution.prisma().find_many(
        where={
            "agentNodeId": node_id,
            "agentGraphExecutionId": graph_eid,
            "executionStatus": ExecutionStatus.INCOMPLETE,
        },
        include=EXECUTION_RESULT_INCLUDE,
    )
    return [NodeExecutionResult.from_db(execution) for execution in executions]


# ----------------- Execution Infrastructure ----------------- #


@@ -684,7 +691,7 @@ class GraphExecutionEntry(BaseModel):
    graph_exec_id: str
    graph_id: str
    graph_version: int
    start_node_execs: list["NodeExecutionEntry"]
    node_credentials_input_map: Optional[dict[str, dict[str, CredentialsMetaInput]]]


class NodeExecutionEntry(BaseModel):
@@ -717,144 +724,6 @@ class ExecutionQueue(Generic[T]):
        return self.queue.empty()


# ------------------- Execution Utilities -------------------- #


LIST_SPLIT = "_$_"
DICT_SPLIT = "_#_"
OBJC_SPLIT = "_@_"


def parse_execution_output(output: BlockData, name: str) -> Any | None:
    """
    Extracts partial output data by name from a given BlockData.

    The function supports extracting data from lists, dictionaries, and objects
    using specific naming conventions:
    - For lists: <output_name>_$_<index>
    - For dictionaries: <output_name>_#_<key>
    - For objects: <output_name>_@_<attribute>

    Args:
        output (BlockData): A tuple containing the output name and data.
        name (str): The name used to extract specific data from the output.

    Returns:
        Any | None: The extracted data if found, otherwise None.

    Examples:
        >>> output = ("result", [10, 20, 30])
        >>> parse_execution_output(output, "result_$_1")
        20

        >>> output = ("config", {"key1": "value1", "key2": "value2"})
        >>> parse_execution_output(output, "config_#_key1")
        'value1'

        >>> class Sample:
        ...     attr1 = "value1"
        ...     attr2 = "value2"
        >>> output = ("object", Sample())
        >>> parse_execution_output(output, "object_@_attr1")
        'value1'
    """
    output_name, output_data = output

    if name == output_name:
        return output_data

    if name.startswith(f"{output_name}{LIST_SPLIT}"):
        index = int(name.split(LIST_SPLIT)[1])
        if not isinstance(output_data, list) or len(output_data) <= index:
            return None
        return output_data[index]

    if name.startswith(f"{output_name}{DICT_SPLIT}"):
        index = name.split(DICT_SPLIT)[1]
        if not isinstance(output_data, dict) or index not in output_data:
            return None
        return output_data[index]

    if name.startswith(f"{output_name}{OBJC_SPLIT}"):
        index = name.split(OBJC_SPLIT)[1]
        if isinstance(output_data, object) and hasattr(output_data, index):
            return getattr(output_data, index)
        return None

    return None


def merge_execution_input(data: BlockInput) -> BlockInput:
    """
    Merges dynamic input pins into a single list, dictionary, or object based on naming patterns.

    This function processes input keys that follow specific patterns to merge them into a unified structure:
    - `<input_name>_$_<index>` for list inputs.
    - `<input_name>_#_<index>` for dictionary inputs.
    - `<input_name>_@_<index>` for object inputs.

    Args:
        data (BlockInput): A dictionary containing input keys and their corresponding values.

    Returns:
        BlockInput: A dictionary with merged inputs.

    Raises:
        ValueError: If a list index is not an integer.

    Examples:
        >>> data = {
        ...     "list_$_0": "a",
        ...     "list_$_1": "b",
        ...     "dict_#_key1": "value1",
        ...     "dict_#_key2": "value2",
        ...     "object_@_attr1": "value1",
        ...     "object_@_attr2": "value2"
        ... }
        >>> merge_execution_input(data)
        {
            "list": ["a", "b"],
            "dict": {"key1": "value1", "key2": "value2"},
            "object": <MockObject attr1="value1" attr2="value2">
        }
    """

    # Merge all input with <input_name>_$_<index> into a single list.
    items = list(data.items())

    for key, value in items:
        if LIST_SPLIT not in key:
            continue
        name, index = key.split(LIST_SPLIT)
        if not index.isdigit():
            raise ValueError(f"Invalid key: {key}, #{index} index must be an integer.")

        data[name] = data.get(name, [])
        if int(index) >= len(data[name]):
            # Pad list with empty string on missing indices.
            data[name].extend([""] * (int(index) - len(data[name]) + 1))
        data[name][int(index)] = value

    # Merge all input with <input_name>_#_<index> into a single dict.
    for key, value in items:
        if DICT_SPLIT not in key:
            continue
        name, index = key.split(DICT_SPLIT)
        data[name] = data.get(name, {})
        data[name][index] = value

    # Merge all input with <input_name>_@_<index> into a single object.
    for key, value in items:
        if OBJC_SPLIT not in key:
            continue
        name, index = key.split(OBJC_SPLIT)
        if name not in data or not isinstance(data[name], object):
            data[name] = mock.MockObject()
        setattr(data[name], index, value)

    return data


# --------------------- Event Bus --------------------- #


@@ -1,7 +1,7 @@
import logging
import uuid
from collections import defaultdict
from typing import Any, Literal, Optional, Type, cast
from typing import Any, Literal, Optional, cast

import prisma
from prisma import Json
@@ -13,12 +13,19 @@ from prisma.types import (
    AgentNodeCreateInput,
    AgentNodeLinkCreateInput,
)
from pydantic import create_model
from pydantic.fields import computed_field

from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.blocks.llm import LlmModel
from backend.data.db import prisma as db
from backend.data.model import (
    CredentialsField,
    CredentialsFieldInfo,
    CredentialsMetaInput,
    is_credentials_field_name,
)
from backend.util import type as type_utils

from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
@@ -165,6 +172,8 @@ class BaseGraph(BaseDbModel):
    description: str
    nodes: list[Node] = []
    links: list[Link] = []
    forked_from_id: str | None = None
    forked_from_version: int | None = None

    @computed_field
    @property
@@ -190,14 +199,19 @@ class BaseGraph(BaseDbModel):
            )
        )

    @computed_field
    @property
    def credentials_input_schema(self) -> dict[str, Any]:
        return self._credentials_input_schema.jsonschema()

    @staticmethod
    def _generate_schema(
        *props: tuple[Type[AgentInputBlock.Input] | Type[AgentOutputBlock.Input], dict],
        *props: tuple[type[AgentInputBlock.Input] | type[AgentOutputBlock.Input], dict],
    ) -> dict[str, Any]:
        schema = []
        schema_fields: list[AgentInputBlock.Input | AgentOutputBlock.Input] = []
        for type_class, input_default in props:
            try:
                schema.append(type_class(**input_default))
                schema_fields.append(type_class(**input_default))
            except Exception as e:
                logger.warning(f"Invalid {type_class}: {input_default}, {e}")

@@ -217,9 +231,93 @@ class BaseGraph(BaseDbModel):
                    **({"description": p.description} if p.description else {}),
                    **({"default": p.value} if p.value is not None else {}),
                }
                for p in schema
                for p in schema_fields
            },
            "required": [p.name for p in schema if p.value is None],
            "required": [p.name for p in schema_fields if p.value is None],
        }

    @property
    def _credentials_input_schema(self) -> type[BlockSchema]:
        graph_credentials_inputs = self.aggregate_credentials_inputs()
        logger.debug(
            f"Combined credentials input fields for graph #{self.id} ({self.name}): "
            f"{graph_credentials_inputs}"
        )

        # Warn if same-provider credentials inputs can't be combined (= bad UX)
        graph_cred_fields = list(graph_credentials_inputs.values())
        for i, (field, keys) in enumerate(graph_cred_fields):
            for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
                if field.provider != other_field.provider:
                    continue

                # If this happens, that means a block implementation probably needs
                # to be updated.
                logger.warning(
                    "Multiple combined credentials fields "
                    f"for provider {field.provider} "
                    f"on graph #{self.id} ({self.name}); "
                    f"fields: {field} <> {other_field}; "
                    f"keys: {keys} <> {other_keys}."
                )

        fields: dict[str, tuple[type[CredentialsMetaInput], CredentialsMetaInput]] = {
            agg_field_key: (
                CredentialsMetaInput[
                    Literal[tuple(field_info.provider)],  # type: ignore
                    Literal[tuple(field_info.supported_types)],  # type: ignore
                ],
                CredentialsField(
                    required_scopes=set(field_info.required_scopes or []),
                    discriminator=field_info.discriminator,
                    discriminator_mapping=field_info.discriminator_mapping,
                ),
            )
            for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
        }

        return create_model(
            self.name.replace(" ", "") + "CredentialsInputSchema",
            __base__=BlockSchema,
            **fields,  # type: ignore
        )

    def aggregate_credentials_inputs(
        self,
    ) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]]]]:
        """
        Returns:
            dict[aggregated_field_key, tuple(
                CredentialsFieldInfo: A spec for one aggregated credentials field
                set[(node_id, field_name)]: Node credentials fields that are
                    compatible with this aggregated field spec
            )]
        """
        return {
            "_".join(sorted(agg_field_info.provider))
            + "_"
            + "_".join(sorted(agg_field_info.supported_types))
            + "_credentials": (agg_field_info, node_fields)
            for agg_field_info, node_fields in CredentialsFieldInfo.combine(
                *(
                    (
                        # Apply discrimination before aggregating credentials inputs
                        (
                            field_info.discriminate(
                                node.input_default[field_info.discriminator]
                            )
                            if (
                                field_info.discriminator
                                and node.input_default.get(field_info.discriminator)
                            )
                            else field_info
                        ),
                        (node.id, field_name),
                    )
                    for node in self.nodes
                    for field_name, field_info in node.block.input_schema.get_credentials_fields_info().items()
                )
            )
        }

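To make the key construction above concrete, a small self-contained illustration (the provider and type values are hypothetical, not from the diff):

provider = frozenset({"github"})
supported_types = frozenset({"api_key", "oauth2"})
agg_field_key = (
    "_".join(sorted(provider)) + "_" + "_".join(sorted(supported_types)) + "_credentials"
)
assert agg_field_key == "github_api_key_oauth2_credentials"
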

@@ -313,15 +411,16 @@ class GraphModel(Graph):

    @staticmethod
    def _validate_graph(graph: BaseGraph, for_run: bool = False):
        def is_tool_pin(name: str) -> bool:
            return name.startswith("tools_^_")

        def sanitize(name):
            sanitized_name = name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
            if sanitized_name.startswith("tools_^_"):
                return sanitized_name.split("_^_")[0]
            if is_tool_pin(sanitized_name):
                return "tools"
            return sanitized_name

        # Validate smart decision maker nodes
        smart_decision_maker_nodes = set()
        agent_nodes = set()
        nodes_block = {
            node.id: block
            for node in graph.nodes
@@ -332,13 +431,6 @@ class GraphModel(Graph):
            if (block := nodes_block.get(node.id)) is None:
                raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")

            # Smart decision maker nodes
            if block.block_type == BlockType.AI:
                smart_decision_maker_nodes.add(node.id)
            # Agent nodes
            elif block.block_type == BlockType.AGENT:
                agent_nodes.add(node.id)

        input_links = defaultdict(list)

        for link in graph.links:
@@ -353,16 +445,21 @@ class GraphModel(Graph):
                [sanitize(name) for name in node.input_default]
                + [sanitize(link.sink_name) for link in input_links.get(node.id, [])]
            )
            for name in block.input_schema.get_required_fields():
            input_schema = block.input_schema
            for name in (required_fields := input_schema.get_required_fields()):
                if (
                    name not in provided_inputs
                    # Webhook payload is passed in by ExecutionManager
                    and not (
                        name == "payload"
                        and block.block_type
                        in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
                    )
                    # Checking availability of credentials is done by ExecutionManager
                    and name not in input_schema.get_credentials_fields()
                    # Validate only I/O nodes, or validate everything when executing
                    and (
                        for_run  # Skip input completion validation unless executing.
                        for_run
                        or block.block_type
                        in [
                            BlockType.INPUT,
@@ -375,9 +472,18 @@ class GraphModel(Graph):
                        f"Node {block.name} #{node.id} required input missing: `{name}`"
                    )

            if (
                block.block_type == BlockType.INPUT
                and (input_key := node.input_default.get("name"))
                and is_credentials_field_name(input_key)
            ):
                raise ValueError(
                    f"Agent input node uses reserved name '{input_key}'; "
                    "'credentials' and `*_credentials` are reserved input names"
                )

            # Get input schema properties and check dependencies
            input_schema = block.input_schema.model_fields
            required_fields = block.input_schema.get_required_fields()
            input_fields = input_schema.model_fields

            def has_value(name):
                return (
@@ -385,14 +491,21 @@ class GraphModel(Graph):
                    and name in node.input_default
                    and node.input_default[name] is not None
                    and str(node.input_default[name]).strip() != ""
                ) or (name in input_schema and input_schema[name].default is not None)
                ) or (name in input_fields and input_fields[name].default is not None)

            # Validate dependencies between fields
            for field_name, field_info in input_schema.items():
            for field_name, field_info in input_fields.items():
                # Apply input dependency validation only on run & field with depends_on
                json_schema_extra = field_info.json_schema_extra or {}
                dependencies = json_schema_extra.get("depends_on", [])
                if not for_run or not dependencies:
                if not (
                    for_run
                    and isinstance(json_schema_extra, dict)
                    and (
                        dependencies := cast(
                            list[str], json_schema_extra.get("depends_on", [])
                        )
                    )
                ):
                    continue

                # Check if dependent field has value in input_default
@@ -445,7 +558,7 @@ class GraphModel(Graph):
                if block.block_type not in [BlockType.AGENT]
                else vals.get("input_schema", {}).get("properties", {}).keys()
            )
            if sanitized_name not in fields and not name.startswith("tools_^_"):
            if sanitized_name not in fields and not is_tool_pin(name):
                fields_msg = f"Allowed fields: {fields}"
                raise ValueError(f"{prefix}, `{name}` invalid, {fields_msg}")

@@ -462,6 +575,8 @@ class GraphModel(Graph):
            id=graph.id,
            user_id=graph.userId if not for_export else "",
            version=graph.version,
            forked_from_id=graph.forkedFromId,
            forked_from_version=graph.forkedFromVersion,
            is_active=graph.isActive,
            name=graph.name or "",
            description=graph.description or "",
@@ -621,6 +736,58 @@ async def get_graph(
    return GraphModel.from_db(graph, for_export)


async def get_graph_as_admin(
    graph_id: str,
    version: int | None = None,
    user_id: str | None = None,
    for_export: bool = False,
) -> GraphModel | None:
    """
    Intentionally parallels get_graph, but should only be used for admin tasks,
    because it can return any graph that has been submitted to the store.
    Retrieves a graph from the DB.
    Defaults to the version with `is_active` if `version` is not passed.

    Returns `None` if the record is not found.
    """
    logger.warning(f"Getting {graph_id=} {version=} as ADMIN {user_id=} {for_export=}")
    where_clause: AgentGraphWhereInput = {
        "id": graph_id,
    }

    if version is not None:
        where_clause["version"] = version

    graph = await AgentGraph.prisma().find_first(
        where=where_clause,
        include=AGENT_GRAPH_INCLUDE,
        order={"version": "desc"},
    )

    # For access, the graph must be owned by the user or listed in the store
    if graph is None or (
        graph.userId != user_id
        and not (
            await StoreListingVersion.prisma().find_first(
                where={
                    "agentGraphId": graph_id,
                    "agentGraphVersion": version or graph.version,
                }
            )
        )
    ):
        return None

    if for_export:
        sub_graphs = await get_sub_graphs(graph)
        return GraphModel.from_db(
            graph=graph,
            sub_graphs=sub_graphs,
            for_export=for_export,
        )

    return GraphModel.from_db(graph, for_export)


async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
    """
    Iteratively fetches all sub-graphs of a given graph, and flattens them into a list.
@@ -739,6 +906,27 @@ async def create_graph(graph: Graph, user_id: str) -> GraphModel:
    raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")


async def fork_graph(graph_id: str, graph_version: int, user_id: str) -> GraphModel:
    """
    Forks a graph by copying it and all its nodes and links to a new graph.
    """
    async with transaction() as tx:
        graph = await get_graph(graph_id, graph_version, user_id, True)
        if not graph:
            raise ValueError(f"Graph {graph_id} v{graph_version} not found")

        # Set forked-from ID and version to the graph itself, as it's about to be copied
        graph.forked_from_id = graph.id
        graph.forked_from_version = graph.version
        graph.name = f"{graph.name} (copy)"
        graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
        graph.validate_graph(for_run=False)

        await __create_graph(tx, graph, user_id)

    return graph


async def __create_graph(tx, graph: Graph, user_id: str):
    graphs = [graph] + graph.sub_graphs

@@ -751,6 +939,8 @@ async def __create_graph(tx, graph: Graph, user_id: str):
            description=graph.description,
            isActive=graph.is_active,
            userId=user_id,
            forkedFromId=graph.forked_from_id,
            forkedFromVersion=graph.forked_from_version,
        )
        for graph in graphs
    ]
@@ -914,24 +1104,24 @@ async def migrate_llm_models(migrate_to: LlmModel):
        if field.annotation == LlmModel:
            llm_model_fields[block.id] = field_name

    # Convert enum values to a list of strings for the SQL query
    enum_values = [v.value for v in LlmModel]
    escaped_enum_values = repr(tuple(enum_values))  # hack but works

    # Update each block
    for id, path in llm_model_fields.items():
        # Convert enum values to a list of strings for the SQL query
        enum_values = [v.value for v in LlmModel.__members__.values()]

        escaped_enum_values = repr(tuple(enum_values))  # hack but works
        query = f"""
        UPDATE "AgentNode"
        SET "constantInput" = jsonb_set("constantInput", $1, $2, true)
        UPDATE platform."AgentNode"
        SET "constantInput" = jsonb_set("constantInput", $1, to_jsonb($2), true)
        WHERE "agentBlockId" = $3
        AND "constantInput" ? $4
        AND "constantInput"->>$4 NOT IN {escaped_enum_values}
        AND "constantInput" ? ($4)::text
        AND "constantInput"->>($4)::text NOT IN {escaped_enum_values}
        """

        await db.execute_raw(
            query,  # type: ignore - is supposed to be LiteralString
            "{" + path + "}",
            f'"{migrate_to.value}"',
            [path],
            migrate_to.value,
            id,
            path,
        )


@@ -1,7 +1,9 @@
from __future__ import annotations

import base64
import enum
import logging
from collections import defaultdict
from datetime import datetime, timezone
from typing import (
    TYPE_CHECKING,
@@ -12,6 +14,7 @@ from typing import (
    Generic,
    Literal,
    Optional,
    Sequence,
    TypedDict,
    TypeVar,
    get_args,
@@ -300,9 +303,7 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
        )
        field_schema = model.jsonschema()["properties"][field_name]
        try:
            schema_extra = _CredentialsFieldSchemaExtra[CP, CT].model_validate(
                field_schema
            )
            schema_extra = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
        except ValidationError as e:
            if "Field required [type=missing" not in str(e):
                raise
@@ -328,14 +329,90 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
    )


class _CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]):
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
    # TODO: move discrimination mechanism out of CredentialsField (frontend + backend)
    credentials_provider: list[CP]
    credentials_scopes: Optional[list[str]] = None
    credentials_types: list[CT]
    provider: frozenset[CP] = Field(..., alias="credentials_provider")
    supported_types: frozenset[CT] = Field(..., alias="credentials_types")
    required_scopes: Optional[frozenset[str]] = Field(None, alias="credentials_scopes")
    discriminator: Optional[str] = None
    discriminator_mapping: Optional[dict[str, CP]] = None

    @classmethod
    def combine(
        cls, *fields: tuple[CredentialsFieldInfo[CP, CT], T]
    ) -> Sequence[tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
        """
        Combines multiple CredentialsFieldInfo objects into as few as possible.

        Rules:
        - Items can only be combined if they have the same supported credentials types
          and the same supported providers.
        - When combining items, the `required_scopes` of the result is a join
          of the `required_scopes` of the original items.

        Params:
            *fields: (CredentialsFieldInfo, key) objects to group and combine

        Returns:
            A sequence of tuples containing combined CredentialsFieldInfo objects and
            the set of keys of the respective original items that were grouped together.
        """
        if not fields:
            return []

        # Group fields by their provider and supported_types
        grouped_fields: defaultdict[
            tuple[frozenset[CP], frozenset[CT]],
            list[tuple[T, CredentialsFieldInfo[CP, CT]]],
        ] = defaultdict(list)

        for field, key in fields:
            group_key = (frozenset(field.provider), frozenset(field.supported_types))
            grouped_fields[group_key].append((key, field))

        # Combine fields within each group
        result: list[tuple[CredentialsFieldInfo[CP, CT], set[T]]] = []

        for group in grouped_fields.values():
            # Start with the first field in the group
            _, combined = group[0]

            # Track the keys that were combined
            combined_keys = {key for key, _ in group}

            # Combine required_scopes from all fields in the group
            all_scopes = set()
            for _, field in group:
                if field.required_scopes:
                    all_scopes.update(field.required_scopes)

            # Create a new combined field
            result.append(
                (
                    CredentialsFieldInfo[CP, CT](
                        credentials_provider=combined.provider,
                        credentials_types=combined.supported_types,
                        credentials_scopes=frozenset(all_scopes) or None,
                        discriminator=combined.discriminator,
                        discriminator_mapping=combined.discriminator_mapping,
                    ),
                    combined_keys,
                )
            )

        return result

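A minimal usage sketch of combine (hypothetical provider, type, and key values, not from the diff): two fields with identical providers and supported types collapse into one, with their required scopes unioned.

a = CredentialsFieldInfo(
    credentials_provider=["github"],
    credentials_types=["oauth2"],
    credentials_scopes=["repo"],
)
b = CredentialsFieldInfo(
    credentials_provider=["github"],
    credentials_types=["oauth2"],
    credentials_scopes=["read:org"],
)
[(combined, keys)] = CredentialsFieldInfo.combine((a, "node1"), (b, "node2"))
# combined.required_scopes == frozenset({"repo", "read:org"}); keys == {"node1", "node2"}
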
    def discriminate(self, discriminator_value: Any) -> CredentialsFieldInfo:
        if not (self.discriminator and self.discriminator_mapping):
            return self

        discriminator_value = self.discriminator_mapping[discriminator_value]
        return CredentialsFieldInfo(
            credentials_provider=frozenset([discriminator_value]),
            credentials_types=self.supported_types,
            credentials_scopes=self.required_scopes,
        )

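For context, a hedged sketch of discriminate with a hypothetical mapping: a multi-provider credentials field is narrowed to a single provider based on a value from the node's input (here, the selected model).

field = CredentialsFieldInfo(
    credentials_provider=["openai", "anthropic"],
    credentials_types=["api_key"],
    discriminator="model",
    discriminator_mapping={"gpt-4": "openai", "claude-3": "anthropic"},
)
narrowed = field.discriminate("gpt-4")
# narrowed.provider == frozenset({"openai"})
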

def CredentialsField(
    required_scopes: set[str] = set(),
@@ -373,6 +450,12 @@ class ContributorDetails(BaseModel):
    name: str = Field(title="Name", description="The name of the contributor.")


class TopUpType(enum.Enum):
    AUTO = "AUTO"
    MANUAL = "MANUAL"
    UNCATEGORIZED = "UNCATEGORIZED"


class AutoTopUpConfig(BaseModel):
    amount: int
    """Amount of credits to top up."""
@@ -385,12 +468,18 @@ class UserTransaction(BaseModel):
    transaction_time: datetime = datetime.min.replace(tzinfo=timezone.utc)
    transaction_type: CreditTransactionType = CreditTransactionType.USAGE
    amount: int = 0
    balance: int = 0
    running_balance: int = 0
    current_balance: int = 0
    description: str | None = None
    usage_graph_id: str | None = None
    usage_execution_id: str | None = None
    usage_node_count: int = 0
    usage_start_time: datetime = datetime.max.replace(tzinfo=timezone.utc)
    user_id: str
    user_email: str | None = None
    reason: str | None = None
    admin_email: str | None = None
    extra_data: str | None = None


class TransactionHistory(BaseModel):

@@ -8,7 +8,9 @@ from prisma.enums import OnboardingStep
from prisma.models import UserOnboarding
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput

from backend.data import db
from backend.data.block import get_blocks
from backend.data.credit import get_user_credit_model
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.server.v2.store.model import StoreAgentDetails
@@ -24,14 +26,19 @@ REASON_MAPPING: dict[str, list[str]] = {
POINTS_AGENT_COUNT = 50  # Number of agents to calculate points for
MIN_AGENT_COUNT = 2  # Minimum number of marketplace agents to enable onboarding

user_credit = get_user_credit_model()


class UserOnboardingUpdate(pydantic.BaseModel):
    completedSteps: Optional[list[OnboardingStep]] = None
    notificationDot: Optional[bool] = None
    notified: Optional[list[OnboardingStep]] = None
    usageReason: Optional[str] = None
    integrations: Optional[list[str]] = None
    otherIntegrations: Optional[str] = None
    selectedStoreListingVersionId: Optional[str] = None
    agentInput: Optional[dict[str, Any]] = None
    onboardingAgentExecutionId: Optional[str] = None


async def get_user_onboarding(user_id: str):
@@ -48,6 +55,20 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
    update: UserOnboardingUpdateInput = {}
    if data.completedSteps is not None:
        update["completedSteps"] = list(set(data.completedSteps))
        for step in (
            OnboardingStep.AGENT_NEW_RUN,
            OnboardingStep.GET_RESULTS,
            OnboardingStep.MARKETPLACE_ADD_AGENT,
            OnboardingStep.MARKETPLACE_RUN_AGENT,
            OnboardingStep.BUILDER_SAVE_AGENT,
            OnboardingStep.BUILDER_RUN_AGENT,
        ):
            if step in data.completedSteps:
                await reward_user(user_id, step)
    if data.notificationDot is not None:
        update["notificationDot"] = data.notificationDot
    if data.notified is not None:
        update["notified"] = list(set(data.notified))
    if data.usageReason is not None:
        update["usageReason"] = data.usageReason
    if data.integrations is not None:
@@ -58,6 +79,8 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
        update["selectedStoreListingVersionId"] = data.selectedStoreListingVersionId
    if data.agentInput is not None:
        update["agentInput"] = Json(data.agentInput)
    if data.onboardingAgentExecutionId is not None:
        update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId

    return await UserOnboarding.prisma().upsert(
        where={"userId": user_id},
@@ -68,6 +91,45 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
    )


async def reward_user(user_id: str, step: OnboardingStep):
    async with db.locked_transaction(f"usr_trx_{user_id}-reward"):
        reward = 0
        match step:
            # Reward the user when they click New Run during onboarding,
            # because they need credits before scheduling a run (the next step)
            case OnboardingStep.AGENT_NEW_RUN:
                reward = 300
            case OnboardingStep.GET_RESULTS:
                reward = 300
            case OnboardingStep.MARKETPLACE_ADD_AGENT:
                reward = 100
            case OnboardingStep.MARKETPLACE_RUN_AGENT:
                reward = 100
            case OnboardingStep.BUILDER_SAVE_AGENT:
                reward = 100
            case OnboardingStep.BUILDER_RUN_AGENT:
                reward = 100

        if reward == 0:
            return

        onboarding = await get_user_onboarding(user_id)

        # Skip if already rewarded
        if step in onboarding.rewardedFor:
            return

        onboarding.rewardedFor.append(step)
        await user_credit.onboarding_reward(user_id, reward, step)
        await UserOnboarding.prisma().update(
            where={"userId": user_id},
            data={
                "completedSteps": list(set(onboarding.completedSteps + [step])),
                "rewardedFor": onboarding.rewardedFor,
            },
        )


def clean_and_split(text: str) -> list[str]:
    """
    Removes all special characters from a string, truncates it to 100 characters,


@@ -4,10 +4,18 @@ from enum import Enum
from typing import Awaitable, Optional

import aio_pika
import aio_pika.exceptions as aio_ex
import pika
import pika.adapters.blocking_connection
from pika.exceptions import AMQPError
from pika.spec import BasicProperties
from pydantic import BaseModel
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)

from backend.util.retry import conn_retry
from backend.util.settings import Settings
@@ -161,6 +169,12 @@ class SyncRabbitMQ(RabbitMQBase):
            routing_key=queue.routing_key or queue.name,
        )

    @retry(
        retry=retry_if_exception_type((AMQPError, ConnectionError)),
        wait=wait_random_exponential(multiplier=1, max=5),
        stop=stop_after_attempt(5),
        reraise=True,
    )
    def publish_message(
        self,
        routing_key: str,
@@ -258,6 +272,12 @@ class AsyncRabbitMQ(RabbitMQBase):
            exchange, routing_key=queue.routing_key or queue.name
        )

    @retry(
        retry=retry_if_exception_type((aio_ex.AMQPError, ConnectionError)),
        wait=wait_random_exponential(multiplier=1, max=5),
        stop=stop_after_attempt(5),
        reraise=True,
    )
    async def publish_message(
        self,
        routing_key: str,

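Both publishers now share the same retry policy: up to five attempts on AMQP or connection errors, jittered exponential backoff capped at 5 seconds, and the last error re-raised. A minimal sketch of the same pattern on a hypothetical standalone function:

from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential

@retry(
    retry=retry_if_exception_type(ConnectionError),
    wait=wait_random_exponential(multiplier=1, max=5),
    stop=stop_after_attempt(5),
    reraise=True,
)
def flaky_publish(message: str) -> None:
    # Hypothetical: raises ConnectionError on transient broker failures.
    ...
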

@@ -1,9 +1,10 @@
from .database import DatabaseManager
from .database import DatabaseManager, DatabaseManagerClient
from .manager import ExecutionManager
from .scheduler import Scheduler

__all__ = [
    "DatabaseManager",
    "DatabaseManagerClient",
    "ExecutionManager",
    "Scheduler",
]

@@ -1,16 +1,14 @@
import logging
from typing import Callable, Concatenate, ParamSpec, TypeVar, cast

from backend.data import db, redis
from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
    GraphExecution,
    NodeExecutionResult,
    RedisExecutionEventBus,
    create_graph_execution,
    get_graph_execution,
    get_incomplete_node_executions,
    get_graph_execution_meta,
    get_latest_node_execution,
    get_node_execution_results,
    get_node_executions,
    update_graph_execution_start_time,
    update_graph_execution_stats,
    update_node_execution_stats,
@@ -42,12 +40,14 @@ from backend.data.user import (
    update_user_integrations,
    update_user_metadata,
)
from backend.util.service import AppService, expose, exposed_run_and_wait
from backend.util.service import AppService, AppServiceClient, endpoint_to_sync, expose
from backend.util.settings import Config

config = Config()
_user_credit_model = get_user_credit_model()
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")


async def _spend_credits(
@@ -56,22 +56,19 @@ async def _spend_credits(
    return await _user_credit_model.spend_credits(user_id, cost, metadata)


async def _get_credits(user_id: str) -> int:
    return await _user_credit_model.get_credits(user_id)


class DatabaseManager(AppService):
    def __init__(self):
        super().__init__()
        self.execution_event_bus = RedisExecutionEventBus()

    def run_service(self) -> None:
        logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
        self.run_and_wait(db.connect())
        logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...")
        redis.connect()
        super().run_service()

    def cleanup(self):
        super().cleanup()
        logger.info(f"[{self.service_name}] ⏳ Disconnecting Redis...")
        redis.disconnect()
        logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
        self.run_and_wait(db.disconnect())

@@ -79,64 +76,113 @@ class DatabaseManager(AppService):
    def get_port(cls) -> int:
        return config.database_api_port

    @expose
    def send_execution_update(
        self, execution_result: GraphExecution | NodeExecutionResult
    ):
        self.execution_event_bus.publish(execution_result)
    @staticmethod
    def _(
        f: Callable[P, R], name: str | None = None
    ) -> Callable[Concatenate[object, P], R]:
        if name is not None:
            f.__name__ = name
        return cast(Callable[Concatenate[object, P], R], expose(f))

    # Executions
    get_graph_execution = exposed_run_and_wait(get_graph_execution)
    create_graph_execution = exposed_run_and_wait(create_graph_execution)
    get_node_execution_results = exposed_run_and_wait(get_node_execution_results)
    get_incomplete_node_executions = exposed_run_and_wait(
        get_incomplete_node_executions
    )
    get_latest_node_execution = exposed_run_and_wait(get_latest_node_execution)
    update_node_execution_status = exposed_run_and_wait(update_node_execution_status)
    update_node_execution_status_batch = exposed_run_and_wait(
        update_node_execution_status_batch
    )
    update_graph_execution_start_time = exposed_run_and_wait(
        update_graph_execution_start_time
    )
    update_graph_execution_stats = exposed_run_and_wait(update_graph_execution_stats)
    update_node_execution_stats = exposed_run_and_wait(update_node_execution_stats)
    upsert_execution_input = exposed_run_and_wait(upsert_execution_input)
    upsert_execution_output = exposed_run_and_wait(upsert_execution_output)
    get_graph_execution = _(get_graph_execution)
    get_graph_execution_meta = _(get_graph_execution_meta)
    create_graph_execution = _(create_graph_execution)
    get_node_executions = _(get_node_executions)
    get_latest_node_execution = _(get_latest_node_execution)
    update_node_execution_status = _(update_node_execution_status)
    update_node_execution_status_batch = _(update_node_execution_status_batch)
    update_graph_execution_start_time = _(update_graph_execution_start_time)
    update_graph_execution_stats = _(update_graph_execution_stats)
    update_node_execution_stats = _(update_node_execution_stats)
    upsert_execution_input = _(upsert_execution_input)
    upsert_execution_output = _(upsert_execution_output)

    # Graphs
    get_node = exposed_run_and_wait(get_node)
    get_graph = exposed_run_and_wait(get_graph)
    get_connected_output_nodes = exposed_run_and_wait(get_connected_output_nodes)
    get_graph_metadata = exposed_run_and_wait(get_graph_metadata)
    get_node = _(get_node)
    get_graph = _(get_graph)
    get_connected_output_nodes = _(get_connected_output_nodes)
    get_graph_metadata = _(get_graph_metadata)

    # Credits
    spend_credits = exposed_run_and_wait(_spend_credits)
    spend_credits = _(_spend_credits, name="spend_credits")
    get_credits = _(_get_credits, name="get_credits")

    # User + User Metadata + User Integrations
    get_user_metadata = exposed_run_and_wait(get_user_metadata)
    update_user_metadata = exposed_run_and_wait(update_user_metadata)
    get_user_integrations = exposed_run_and_wait(get_user_integrations)
    update_user_integrations = exposed_run_and_wait(update_user_integrations)
    get_user_metadata = _(get_user_metadata)
    update_user_metadata = _(update_user_metadata)
    get_user_integrations = _(get_user_integrations)
    update_user_integrations = _(update_user_integrations)

    # User Comms - async
    get_active_user_ids_in_timerange = exposed_run_and_wait(
        get_active_user_ids_in_timerange
    )
    get_user_email_by_id = exposed_run_and_wait(get_user_email_by_id)
    get_user_email_verification = exposed_run_and_wait(get_user_email_verification)
    get_user_notification_preference = exposed_run_and_wait(
        get_user_notification_preference
    )
    get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
    get_user_email_by_id = _(get_user_email_by_id)
    get_user_email_verification = _(get_user_email_verification)
    get_user_notification_preference = _(get_user_notification_preference)

    # Notifications - async
    create_or_add_to_user_notification_batch = exposed_run_and_wait(
    create_or_add_to_user_notification_batch = _(
        create_or_add_to_user_notification_batch
    )
    empty_user_notification_batch = exposed_run_and_wait(empty_user_notification_batch)
    get_all_batches_by_type = exposed_run_and_wait(get_all_batches_by_type)
    get_user_notification_batch = exposed_run_and_wait(get_user_notification_batch)
    get_user_notification_oldest_message_in_batch = exposed_run_and_wait(
    empty_user_notification_batch = _(empty_user_notification_batch)
    get_all_batches_by_type = _(get_all_batches_by_type)
    get_user_notification_batch = _(get_user_notification_batch)
    get_user_notification_oldest_message_in_batch = _(
        get_user_notification_oldest_message_in_batch
    )

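The `_` helper above replaces `exposed_run_and_wait`: it optionally renames a module-level function (so private helpers like `_spend_credits` are exposed under a clean endpoint name) and wraps it with `expose`, casting the result so it type-checks as a method once assigned as a class attribute. A minimal self-contained sketch of the same trick (names are hypothetical; `expose` is stubbed out):

from typing import Callable, Concatenate, ParamSpec, TypeVar, cast

P = ParamSpec("P")
R = TypeVar("R")

def expose_stub(f):  # stand-in for backend.util.service.expose (assumption)
    return f

def make_endpoint(
    f: Callable[P, R], name: str | None = None
) -> Callable[Concatenate[object, P], R]:
    if name is not None:
        f.__name__ = name  # controls the exposed endpoint name
    # Concatenate[object, P] accounts for `self` once the function
    # becomes a method by being assigned on the class.
    return cast(Callable[Concatenate[object, P], R], expose_stub(f))

async def _spend(user_id: str, cost: int) -> int:  # hypothetical helper
    return cost

class Demo:
    spend = make_endpoint(_spend, name="spend")  # exposed as "spend"
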
|
||||
class DatabaseManagerClient(AppServiceClient):
|
||||
d = DatabaseManager
|
||||
_ = endpoint_to_sync
|
||||
|
||||
@classmethod
|
||||
def get_service_type(cls):
|
||||
return DatabaseManager
|
||||
|
||||
# Executions
|
||||
get_graph_execution = _(d.get_graph_execution)
|
||||
get_graph_execution_meta = _(d.get_graph_execution_meta)
|
||||
create_graph_execution = _(d.create_graph_execution)
|
||||
get_node_executions = _(d.get_node_executions)
|
||||
get_latest_node_execution = _(d.get_latest_node_execution)
|
||||
update_node_execution_status = _(d.update_node_execution_status)
|
||||
update_node_execution_status_batch = _(d.update_node_execution_status_batch)
|
||||
update_graph_execution_start_time = _(d.update_graph_execution_start_time)
|
||||
update_graph_execution_stats = _(d.update_graph_execution_stats)
|
||||
update_node_execution_stats = _(d.update_node_execution_stats)
|
||||
upsert_execution_input = _(d.upsert_execution_input)
|
||||
upsert_execution_output = _(d.upsert_execution_output)
|
||||
|
||||
# Graphs
|
||||
get_node = _(d.get_node)
|
||||
get_graph = _(d.get_graph)
|
||||
get_connected_output_nodes = _(d.get_connected_output_nodes)
|
||||
get_graph_metadata = _(d.get_graph_metadata)
|
||||
|
||||
# Credits
|
||||
spend_credits = _(d.spend_credits)
|
||||
get_credits = _(d.get_credits)
|
||||
|
||||
# User + User Metadata + User Integrations
|
||||
get_user_metadata = _(d.get_user_metadata)
|
||||
update_user_metadata = _(d.update_user_metadata)
|
||||
get_user_integrations = _(d.get_user_integrations)
|
||||
update_user_integrations = _(d.update_user_integrations)
|
||||
|
||||
# User Comms - async
|
||||
get_active_user_ids_in_timerange = _(d.get_active_user_ids_in_timerange)
|
||||
get_user_email_by_id = _(d.get_user_email_by_id)
|
||||
get_user_email_verification = _(d.get_user_email_verification)
|
||||
get_user_notification_preference = _(d.get_user_notification_preference)
|
||||
|
||||
# Notifications - async
|
||||
create_or_add_to_user_notification_batch = _(
|
||||
d.create_or_add_to_user_notification_batch
|
||||
)
|
||||
empty_user_notification_batch = _(d.empty_user_notification_batch)
|
||||
get_all_batches_by_type = _(d.get_all_batches_by_type)
|
||||
get_user_notification_batch = _(d.get_user_notification_batch)
|
||||
get_user_notification_oldest_message_in_batch = _(
|
||||
d.get_user_notification_oldest_message_in_batch
|
||||
)

@@ -8,8 +8,10 @@ import threading
from concurrent.futures import Future, ProcessPoolExecutor
from contextlib import contextmanager
from multiprocessing.pool import AsyncResult, Pool
from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generator, TypeVar, cast

from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic, BasicProperties
from redis.lock import Lock as RedisLock

from backend.blocks.io import AgentOutputBlock
@@ -20,57 +22,65 @@ from backend.data.notifications import (
    NotificationEventDTO,
    NotificationType,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.utils import create_execution_queue_config
from backend.util.exceptions import InsufficientBalanceError

if TYPE_CHECKING:
    from backend.executor import DatabaseManager
    from backend.notifications.notifications import NotificationManager
    from backend.executor import DatabaseManagerClient
    from backend.notifications.notifications import NotificationManagerClient

from autogpt_libs.utils.cache import thread_cached
from prometheus_client import Gauge, start_http_server

from backend.blocks.agent import AgentExecutorBlock
from backend.data import redis
from backend.data.block import (
    Block,
    BlockData,
    BlockInput,
    BlockSchema,
    BlockType,
    get_block,
)
from backend.data.block import BlockData, BlockInput, BlockSchema, get_block
from backend.data.credit import UsageTransactionMetadata
from backend.data.execution import (
    ExecutionQueue,
    ExecutionStatus,
    GraphExecution,
    GraphExecutionEntry,
    NodeExecutionEntry,
    NodeExecutionResult,
    merge_execution_input,
    parse_execution_output,
)
from backend.data.graph import GraphModel, Link, Node
from backend.data.graph import Link, Node
from backend.executor.utils import (
    UsageTransactionMetadata,
    GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
    GRAPH_EXECUTION_QUEUE_NAME,
    CancelExecutionEvent,
    block_usage_cost,
    execution_usage_cost,
    get_execution_event_bus,
    get_execution_queue,
    parse_execution_output,
    validate_exec,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util import json
from backend.util.decorator import error_logged, time_measured
from backend.util.file import clean_exec_files
from backend.util.logging import configure_logging
from backend.util.process import set_service_name
from backend.util.service import (
    AppService,
    close_service_client,
    expose,
    get_service_client,
)
from backend.util.process import AppProcess, set_service_name
from backend.util.retry import func_retry
from backend.util.service import get_service_client
from backend.util.settings import Settings
from backend.util.type import convert

logger = logging.getLogger(__name__)
settings = Settings()

active_runs_gauge = Gauge(
    "execution_manager_active_runs", "Number of active graph runs"
)
pool_size_gauge = Gauge(
    "execution_manager_pool_size", "Maximum number of graph workers"
)
utilization_gauge = Gauge(
    "execution_manager_utilization_ratio",
    "Ratio of active graph runs to max graph workers",
)


class LogMetadata:
    def __init__(
@@ -91,7 +101,7 @@ class LogMetadata:
            "node_id": node_id,
            "block_name": block_name,
        }
        self.prefix = f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|nid:{node_eid}|{block_name}]"
        self.prefix = f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|neid:{node_eid}|{block_name}]"

    def info(self, msg: str, **extra):
        msg = self._wrap(msg, **extra)
@@ -125,7 +135,7 @@ ExecutionStream = Generator[NodeExecutionEntry, None, None]


def execute_node(
    db_client: "DatabaseManager",
    db_client: "DatabaseManagerClient",
    creds_manager: IntegrationCredentialsManager,
    data: NodeExecutionEntry,
    execution_stats: NodeExecutionStats | None = None,
@@ -152,7 +162,7 @@ def execute_node(
    def update_execution_status(status: ExecutionStatus) -> NodeExecutionResult:
        """Sets status and fetches+broadcasts the latest state of the node execution"""
        exec_update = db_client.update_node_execution_status(node_exec_id, status)
        db_client.send_execution_update(exec_update)
        send_execution_update(exec_update)
        return exec_update

    node = db_client.get_node(node_id)
@@ -192,7 +202,7 @@ def execute_node(
    # Execute the node
    input_data_str = json.dumps(input_data)
    input_size = len(input_data_str)
    log_metadata.info("Executed node with input", input=input_data_str)
    log_metadata.debug("Executed node with input", input=input_data_str)
    update_execution_status(ExecutionStatus.RUNNING)

    # Inject extra execution arguments for the blocks via kwargs
@@ -223,7 +233,7 @@ def execute_node(
            ):
                output_data = json.convert_pydantic_to_json(output_data)
            output_size += len(json.dumps(output_data))
            log_metadata.info("Node produced output", **{output_name: output_data})
            log_metadata.debug("Node produced output", **{output_name: output_data})
            push_output(output_name, output_data)
            outputs[output_name] = output_data
            for execution in _enqueue_next_nodes(
@@ -258,7 +268,7 @@ def execute_node(
        raise e
    finally:
        # Ensure credentials are released even if execution fails
        if creds_lock and creds_lock.locked():
        if creds_lock and creds_lock.locked() and creds_lock.owned():
            try:
                creds_lock.release()
            except Exception as e:
@@ -274,7 +284,7 @@ def execute_node(


def _enqueue_next_nodes(
    db_client: "DatabaseManager",
    db_client: "DatabaseManagerClient",
    node: Node,
    output: BlockData,
    user_id: str,
@@ -288,7 +298,7 @@ def _enqueue_next_nodes(
        exec_update = db_client.update_node_execution_status(
            node_exec_id, ExecutionStatus.QUEUED, data
        )
        db_client.send_execution_update(exec_update)
        send_execution_update(exec_update)
        return NodeExecutionEntry(
            user_id=user_id,
            graph_exec_id=graph_exec_id,
@@ -363,8 +373,10 @@ def _enqueue_next_nodes(

            # If link is static, there could be some incomplete executions waiting for it.
            # Load and complete the input missing input data, and try to re-enqueue them.
            for iexec in db_client.get_incomplete_node_executions(
                next_node_id, graph_exec_id
            for iexec in db_client.get_node_executions(
                node_id=next_node_id,
                graph_exec_id=graph_exec_id,
                statuses=[ExecutionStatus.INCOMPLETE],
            ):
                idata = iexec.input_data
                ineid = iexec.node_exec_id
@@ -400,60 +412,6 @@ def _enqueue_next_nodes(
    ]


def validate_exec(
    node: Node,
    data: BlockInput,
    resolve_input: bool = True,
) -> tuple[BlockInput | None, str]:
    """
    Validate the input data for a node execution.

    Args:
        node: The node to execute.
        data: The input data for the node execution.
        resolve_input: Whether to resolve dynamic pins into dict/list/object.

    Returns:
        A tuple of the validated data and the block name.
        If the data is invalid, the first element will be None, and the second element
        will be an error message.
        If the data is valid, the first element will be the resolved input data, and
        the second element will be the block name.
    """
    node_block: Block | None = get_block(node.block_id)
    if not node_block:
        return None, f"Block for {node.block_id} not found."
    schema = node_block.input_schema

    # Convert non-matching data types to the expected input schema.
    for name, data_type in schema.__annotations__.items():
        if (value := data.get(name)) and (type(value) is not data_type):
            data[name] = convert(value, data_type)

    # Input data (without default values) should contain all required fields.
    error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
    if missing_links := schema.get_missing_links(data, node.input_links):
        return None, f"{error_prefix} unpopulated links {missing_links}"

    # Merge input data with default values and resolve dynamic dict/list/object pins.
    input_default = schema.get_input_defaults(node.input_default)
    data = {**input_default, **data}
    if resolve_input:
        data = merge_execution_input(data)

    # Input data post-merge should contain all required fields from the schema.
    if missing_input := schema.get_missing_input(data):
        return None, f"{error_prefix} missing input {missing_input}"

    # Last validation: Validate the input values against the schema.
    if error := schema.get_mismatch_error(data):
        error_message = f"{error_prefix} {error}"
        logger.error(error_message)
        return None, error_message

    return data, node_block.name


class Executor:
    """
    This class contains event handlers for the process pool executor events.
@@ -480,6 +438,7 @@ class Executor:
    """

    @classmethod
    @func_retry
    def on_node_executor_start(cls):
        configure_logging()
        set_service_name("NodeExecutor")
@@ -490,36 +449,28 @@ class Executor:

        # Set up shutdown handlers
        cls.shutdown_lock = threading.Lock()
        atexit.register(cls.on_node_executor_stop)  # handle regular shutdown
        signal.signal(  # handle termination
            signal.SIGTERM, lambda _, __: cls.on_node_executor_sigterm()
        )
        atexit.register(cls.on_node_executor_stop)
        signal.signal(signal.SIGTERM, lambda _, __: cls.on_node_executor_sigterm())
        signal.signal(signal.SIGINT, lambda _, __: cls.on_node_executor_sigterm())

    @classmethod
    def on_node_executor_stop(cls):
    def on_node_executor_stop(cls, log=logger.info):
        if not cls.shutdown_lock.acquire(blocking=False):
            return  # already shutting down

        logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
        log(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
        cls.creds_manager.release_all_locks()
        logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
        log(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
        redis.disconnect()
        logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB manager...")
        close_service_client(cls.db_client)
        logger.info(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
        log(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB manager...")
        cls.db_client.close()
        log(f"[on_node_executor_stop {cls.pid}] ✅ Finished NodeExec cleanup")
        sys.exit(0)

    @classmethod
    def on_node_executor_sigterm(cls):
        llprint(f"[on_node_executor_sigterm {cls.pid}] ⚠️ SIGTERM received")
        if not cls.shutdown_lock.acquire(blocking=False):
            return  # already shutting down

        llprint(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
        cls.creds_manager.release_all_locks()
        llprint(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
        redis.disconnect()
        llprint(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
        sys.exit(0)
        llprint(f"[on_node_executor_sigterm {cls.pid}] ⚠️ NodeExec SIGTERM received")
        cls.on_node_executor_stop(log=llprint)
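
The refactor above funnels atexit, SIGTERM, and SIGINT into one handler guarded by a non-blocking lock acquire, so cleanup runs at most once no matter how many triggers fire. A self-contained sketch of the same pattern (illustrative names, not the project's):

import atexit
import signal
import sys
import threading

_shutdown_lock = threading.Lock()


def shutdown(log=print):
    if not _shutdown_lock.acquire(blocking=False):
        return  # another trigger already ran the cleanup
    log("releasing locks / disconnecting clients ...")


def _on_signal(_sig, _frame):
    shutdown(log=print)  # in a real handler, use a signal-safe writer (see llprint below)
    sys.exit(0)


atexit.register(shutdown)
signal.signal(signal.SIGTERM, _on_signal)
signal.signal(signal.SIGINT, _on_signal)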

    @classmethod
    @error_logged
@@ -585,6 +536,7 @@ class Executor:
            stats.error = e

    @classmethod
    @func_retry
    def on_graph_executor_start(cls):
        configure_logging()
        set_service_name("GraphExecutor")
@@ -594,21 +546,7 @@ class Executor:
        cls.pid = os.getpid()
        cls.notification_service = get_notification_service()
        cls._init_node_executor_pool()
        logger.info(
            f"Graph executor {cls.pid} started with {cls.pool_size} node workers"
        )

        # Set up shutdown handler
        atexit.register(cls.on_graph_executor_stop)

    @classmethod
    def on_graph_executor_stop(cls):
        prefix = f"[on_graph_executor_stop {cls.pid}]"
        logger.info(f"{prefix} ⏳ Terminating node executor pool...")
        cls.executor.terminate()
        logger.info(f"{prefix} ⏳ Disconnecting DB manager...")
        close_service_client(cls.db_client)
        logger.info(f"{prefix} ✅ Finished cleanup")
        logger.info(f"GraphExec {cls.pid} started with {cls.pool_size} node workers")

    @classmethod
    def _init_node_executor_pool(cls):
@@ -630,10 +568,35 @@ class Executor:
            node_eid="*",
            block_name="-",
        )
        exec_meta = cls.db_client.update_graph_execution_start_time(
            graph_exec.graph_exec_id

        exec_meta = cls.db_client.get_graph_execution_meta(
            user_id=graph_exec.user_id,
            execution_id=graph_exec.graph_exec_id,
        )
        cls.db_client.send_execution_update(exec_meta)
        if exec_meta is None:
            log_metadata.warning(
                f"Skipped graph execution #{graph_exec.graph_exec_id}, the graph execution is not found."
            )
            return

        if exec_meta.status == ExecutionStatus.QUEUED:
            log_metadata.info(f"⚙️ Starting graph execution #{graph_exec.graph_exec_id}")
            exec_meta.status = ExecutionStatus.RUNNING
            send_execution_update(
                cls.db_client.update_graph_execution_start_time(
                    graph_exec.graph_exec_id
                )
            )
        elif exec_meta.status == ExecutionStatus.RUNNING:
            log_metadata.info(
                f"⚙️ Graph execution #{graph_exec.graph_exec_id} is already running, continuing where it left off."
            )
        else:
            log_metadata.warning(
                f"Skipped graph execution {graph_exec.graph_exec_id}, the graph execution status is `{exec_meta.status}`."
            )
            return

        timing_info, (exec_stats, status, error) = cls._on_graph_execution(
            graph_exec, cancel, log_metadata
        )
@@ -646,7 +609,7 @@ class Executor:
            status=status,
            stats=exec_stats,
        ):
            cls.db_client.send_execution_update(graph_exec_result)
            send_execution_update(graph_exec_result)

        cls._handle_agent_run_notif(graph_exec, exec_stats)

@@ -656,11 +619,11 @@ class Executor:
        node_exec: NodeExecutionEntry,
        execution_count: int,
        execution_stats: GraphExecutionStats,
    ) -> int:
    ):
        block = get_block(node_exec.block_id)
        if not block:
            logger.error(f"Block {node_exec.block_id} not found.")
            return execution_count
            return

        cost, matching_filter = block_usage_cost(block=block, input_data=node_exec.data)
        if cost > 0:
@@ -675,11 +638,12 @@ class Executor:
                    block_id=node_exec.block_id,
                    block=block.name,
                    input=matching_filter,
                    reason=f"Ran block {node_exec.block_id} {block.name}",
                ),
            )
            execution_stats.cost += cost

        cost, execution_count = execution_usage_cost(execution_count)
        cost, usage_count = execution_usage_cost(execution_count)
        if cost > 0:
            cls.db_client.spend_credits(
                user_id=node_exec.user_id,
@@ -688,15 +652,14 @@ class Executor:
                    graph_exec_id=node_exec.graph_exec_id,
                    graph_id=node_exec.graph_id,
                    input={
                        "execution_count": execution_count,
                        "execution_count": usage_count,
                        "charge": "Execution Cost",
                    },
                    reason=f"Execution Cost for {usage_count} blocks of ex_id:{node_exec.graph_exec_id} g_id:{node_exec.graph_id}",
                ),
            )
            execution_stats.cost += cost

        return execution_count

    @classmethod
    @time_measured
    def _on_graph_execution(
@@ -711,7 +674,6 @@ class Executor:
            ExecutionStatus: The final status of the graph execution.
            Exception | None: The error that occurred during the execution, if any.
        """
        log_metadata.info(f"Start graph execution {graph_exec.graph_exec_id}")
        execution_stats = GraphExecutionStats()
        execution_status = ExecutionStatus.RUNNING
        error = None
@@ -733,11 +695,21 @@ class Executor:
        cancel_thread.start()

        try:
            queue = ExecutionQueue[NodeExecutionEntry]()
            for node_exec in graph_exec.start_node_execs:
                queue.add(node_exec)
            if cls.db_client.get_credits(graph_exec.user_id) <= 0:
                raise InsufficientBalanceError(
                    user_id=graph_exec.user_id,
                    message="You have no credits left to run an agent.",
                    balance=0,
                    amount=1,
                )

            queue = ExecutionQueue[NodeExecutionEntry]()
            for node_exec in cls.db_client.get_node_executions(
                graph_exec.graph_exec_id,
                statuses=[ExecutionStatus.RUNNING, ExecutionStatus.QUEUED],
            ):
                queue.add(node_exec.to_node_execution_entry())

            exec_cost_counter = 0
            running_executions: dict[str, AsyncResult] = {}

            def make_exec_callback(exec_data: NodeExecutionEntry):
@@ -759,7 +731,7 @@ class Executor:
                    status=execution_status,
                    stats=execution_stats,
                ):
                    send_execution_update(_graph_exec)
                    cls.db_client.send_execution_update(_graph_exec)
                else:
                    logger.error(
                        "Callback for "
@@ -776,10 +748,10 @@ class Executor:
                    execution_status = ExecutionStatus.TERMINATED
                    return execution_stats, execution_status, error

                exec_data = queue.get()
                queued_node_exec = queue.get()

                # Avoid parallel execution of the same node.
                execution = running_executions.get(exec_data.node_id)
                execution = running_executions.get(queued_node_exec.node_id)
                if execution and not execution.ready():
                    # TODO (performance improvement):
                    #   Wait for the completion of the same node execution is blocking.
@@ -788,18 +760,18 @@ class Executor:
                    execution.wait()

                log_metadata.debug(
                    f"Dispatching node execution {exec_data.node_exec_id} "
                    f"for node {exec_data.node_id}",
                    f"Dispatching node execution {queued_node_exec.node_exec_id} "
                    f"for node {queued_node_exec.node_id}",
                )

                try:
                    exec_cost_counter = cls._charge_usage(
                        node_exec=exec_data,
                        execution_count=exec_cost_counter + 1,
                    cls._charge_usage(
                        node_exec=queued_node_exec,
                        execution_count=increment_execution_count(graph_exec.user_id),
                        execution_stats=execution_stats,
                    )
                except InsufficientBalanceError as error:
                    node_exec_id = exec_data.node_exec_id
                    node_exec_id = queued_node_exec.node_exec_id
                    cls.db_client.upsert_execution_output(
                        node_exec_id=node_exec_id,
                        output_name="error",
@@ -810,7 +782,7 @@ class Executor:
                    exec_update = cls.db_client.update_node_execution_status(
                        node_exec_id, execution_status
                    )
                    cls.db_client.send_execution_update(exec_update)
                    send_execution_update(exec_update)

                    cls._handle_low_balance_notif(
                        graph_exec.user_id,
@@ -820,10 +792,23 @@ class Executor:
                    )
                    raise

                running_executions[exec_data.node_id] = cls.executor.apply_async(
                # Add credentials input overrides
                node_id = queued_node_exec.node_id
                if (node_creds_map := graph_exec.node_credentials_input_map) and (
                    node_field_creds_map := node_creds_map.get(node_id)
                ):
                    queued_node_exec.data.update(
                        {
                            field_name: creds_meta.model_dump()
                            for field_name, creds_meta in node_field_creds_map.items()
                        }
                    )

                # Initiate node execution
                running_executions[queued_node_exec.node_id] = cls.executor.apply_async(
                    cls.on_node_execution,
                    (queue, exec_data),
                    callback=make_exec_callback(exec_data),
                    (queue, queued_node_exec),
                    callback=make_exec_callback(queued_node_exec),
                )

                # Avoid terminating graph execution when some nodes are still running.
@@ -843,24 +828,21 @@ class Executor:
                    execution.wait(3)

            log_metadata.info(f"Finished graph execution {graph_exec.graph_exec_id}")
            execution_status = ExecutionStatus.COMPLETED

        except Exception as e:
            error = e
        finally:
            if error:
                log_metadata.error(
                    f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
                )
                execution_status = ExecutionStatus.FAILED
            else:
                execution_status = ExecutionStatus.COMPLETED
            log_metadata.error(
                f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
            )
            execution_status = ExecutionStatus.FAILED

        finally:
            if not cancel.is_set():
                finished = True
                cancel.set()
            cancel_thread.join()
            clean_exec_files(graph_exec.graph_exec_id)

            return execution_stats, execution_status, error

    @classmethod
@@ -872,7 +854,7 @@ class Executor:
        metadata = cls.db_client.get_graph_metadata(
            graph_exec.graph_id, graph_exec.graph_version
        )
        outputs = cls.db_client.get_node_execution_results(
        outputs = cls.db_client.get_node_executions(
            graph_exec.graph_exec_id,
            block_ids=[AgentOutputBlock().id],
        )
@@ -927,22 +909,31 @@ class Executor:
        )


class ExecutionManager(AppService):
class ExecutionManager(AppProcess):
    def __init__(self):
        super().__init__()
        self.pool_size = settings.config.num_graph_workers
        self.queue = ExecutionQueue[GraphExecutionEntry]()
        self.running = True
        self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
        atexit.register(self._on_cleanup)
        signal.signal(signal.SIGTERM, lambda sig, frame: self._on_sigterm())
        signal.signal(signal.SIGINT, lambda sig, frame: self._on_sigterm())

    @classmethod
    def get_port(cls) -> int:
        return settings.config.execution_manager_port
    def run(self):
        pool_size_gauge.set(self.pool_size)
        active_runs_gauge.set(0)
        utilization_gauge.set(0)

    def run_service(self):
        from backend.integrations.credentials_store import IntegrationCredentialsStore

        self.credentials_store = IntegrationCredentialsStore()
        self.metrics_server = threading.Thread(
            target=start_http_server,
            args=(settings.config.execution_manager_port,),
            daemon=True,
        )
        self.metrics_server.start()
        logger.info(f"[{self.service_name}] Starting execution manager...")
        self._run()

    def _run(self):
        logger.info(f"[{self.service_name}] ⏳ Spawn max-{self.pool_size} workers...")
        self.executor = ProcessPoolExecutor(
            max_workers=self.pool_size,
@@ -952,220 +943,174 @@ class ExecutionManager(AppService):
        logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...")
        redis.connect()

        sync_manager = multiprocessing.Manager()
        while True:
            graph_exec_data = self.queue.get()
            graph_exec_id = graph_exec_data.graph_exec_id
            logger.debug(
                f"[ExecutionManager] Dispatching graph execution {graph_exec_id}"
            )
            cancel_event = sync_manager.Event()
            future = self.executor.submit(
                Executor.on_graph_execution, graph_exec_data, cancel_event
            )
            self.active_graph_runs[graph_exec_id] = (future, cancel_event)
            future.add_done_callback(
                lambda _: self.active_graph_runs.pop(graph_exec_id, None)
        cancel_client = SyncRabbitMQ(create_execution_queue_config())
        cancel_client.connect()
        cancel_channel = cancel_client.get_channel()
        logger.info(f"[{self.service_name}] ⏳ Starting cancel message consumer...")
        threading.Thread(
            target=lambda: (
                cancel_channel.basic_consume(
                    queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
                    on_message_callback=self._handle_cancel_message,
                    auto_ack=True,
                ),
                cancel_channel.start_consuming(),
            ),
            daemon=True,
        ).start()

        run_client = SyncRabbitMQ(create_execution_queue_config())
        run_client.connect()
        run_channel = run_client.get_channel()
        run_channel.basic_qos(prefetch_count=self.pool_size)
        run_channel.basic_consume(
            queue=GRAPH_EXECUTION_QUEUE_NAME,
            on_message_callback=self._handle_run_message,
            auto_ack=False,
        )
        logger.info(f"[{self.service_name}] ⏳ Starting to consume run messages...")
        run_channel.start_consuming()

    def _handle_cancel_message(
        self,
        channel: BlockingChannel,
        method: Basic.Deliver,
        properties: BasicProperties,
        body: bytes,
    ):
        """
        Called whenever we receive a CANCEL message from the queue.
        (With auto_ack=True, message is considered 'acked' automatically.)
        """
        try:
            request = CancelExecutionEvent.model_validate_json(body)
            graph_exec_id = request.graph_exec_id
            if not graph_exec_id:
                logger.warning(
                    f"[{self.service_name}] Cancel message missing 'graph_exec_id'"
                )
                return
            if graph_exec_id not in self.active_graph_runs:
                logger.debug(
                    f"[{self.service_name}] Cancel received for {graph_exec_id} but not active."
                )
                return

            _, cancel_event = self.active_graph_runs[graph_exec_id]
            logger.info(f"[{self.service_name}] Received cancel for {graph_exec_id}")
            if not cancel_event.is_set():
                cancel_event.set()
            else:
                logger.debug(
                    f"[{self.service_name}] Cancel already set for {graph_exec_id}"
                )

        except Exception as e:
            logger.exception(f"Error handling cancel message: {e}")

    def _handle_run_message(
        self,
        channel: BlockingChannel,
        method: Basic.Deliver,
        properties: BasicProperties,
        body: bytes,
    ):
        delivery_tag = method.delivery_tag
        try:
            graph_exec_entry = GraphExecutionEntry.model_validate_json(body)
        except Exception as e:
            logger.error(f"[{self.service_name}] Could not parse run message: {e}")
            channel.basic_nack(delivery_tag, requeue=False)
            return

        graph_exec_id = graph_exec_entry.graph_exec_id
        logger.info(
            f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}"
        )
        if graph_exec_id in self.active_graph_runs:
            logger.warning(
                f"[{self.service_name}] Graph {graph_exec_id} already running; rejecting duplicate run."
            )
            channel.basic_nack(delivery_tag, requeue=False)
            return

        cancel_event = multiprocessing.Manager().Event()
        future = self.executor.submit(
            Executor.on_graph_execution, graph_exec_entry, cancel_event
        )
        self.active_graph_runs[graph_exec_id] = (future, cancel_event)
        active_runs_gauge.set(len(self.active_graph_runs))
        utilization_gauge.set(len(self.active_graph_runs) / self.pool_size)

        def _on_run_done(f: Future):
            logger.info(f"[{self.service_name}] Run completed for {graph_exec_id}")
            try:
                self.active_graph_runs.pop(graph_exec_id, None)
                active_runs_gauge.set(len(self.active_graph_runs))
                utilization_gauge.set(len(self.active_graph_runs) / self.pool_size)
                if f.exception():
                    logger.error(
                        f"[{self.service_name}] Execution for {graph_exec_id} failed: {f.exception()}"
                    )
                    channel.connection.add_callback_threadsafe(
                        lambda: channel.basic_nack(delivery_tag, requeue=False)
                    )
                else:
                    channel.connection.add_callback_threadsafe(
                        lambda: channel.basic_ack(delivery_tag)
                    )
            except Exception as e:
                logger.error(f"[{self.service_name}] Error acknowledging message: {e}")

        future.add_done_callback(_on_run_done)

    def cleanup(self):
        super().cleanup()
        self._on_cleanup()

        logger.info(f"[{self.service_name}] ⏳ Shutting down graph executor pool...")
        self.executor.shutdown(cancel_futures=True)
    def _on_sigterm(self):
        llprint(f"[{self.service_name}] ⚠️ GraphExec SIGTERM received")
        self._on_cleanup(log=llprint)

        logger.info(f"[{self.service_name}] ⏳ Disconnecting Redis...")
    def _on_cleanup(self, log=logger.info):
        prefix = f"[{self.service_name}][on_graph_executor_stop {os.getpid()}]"
        log(f"{prefix} ⏳ Shutting down service loop...")
        self.running = False

        log(f"{prefix} ⏳ Shutting down RabbitMQ channel...")
        get_execution_queue().get_channel().stop_consuming()

        if hasattr(self, "executor"):
            log(f"{prefix} ⏳ Shutting down GraphExec pool...")
            self.executor.shutdown(cancel_futures=False, wait=True)

        log(f"{prefix} ⏳ Disconnecting Redis...")
        redis.disconnect()

    @property
    def db_client(self) -> "DatabaseManager":
        return get_db_client()

    @expose
    def add_execution(
        self,
        graph_id: str,
        data: BlockInput,
        user_id: str,
        graph_version: Optional[int] = None,
        preset_id: str | None = None,
    ) -> GraphExecutionEntry:
        graph: GraphModel | None = self.db_client.get_graph(
            graph_id=graph_id, user_id=user_id, version=graph_version
        )
        if not graph:
            raise ValueError(f"Graph #{graph_id} not found.")

        graph.validate_graph(for_run=True)
        self._validate_node_input_credentials(graph, user_id)

        nodes_input = []
        for node in graph.starting_nodes:
            input_data = {}
            block = node.block

            # Note block should never be executed.
            if block.block_type == BlockType.NOTE:
                continue

            # Extract request input data, and assign it to the input pin.
            if block.block_type == BlockType.INPUT:
                input_name = node.input_default.get("name")
                if input_name and input_name in data:
                    input_data = {"value": data[input_name]}

            # Extract webhook payload, and assign it to the input pin
            webhook_payload_key = f"webhook_{node.webhook_id}_payload"
            if (
                block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
                and node.webhook_id
            ):
                if webhook_payload_key not in data:
                    raise ValueError(
                        f"Node {block.name} #{node.id} webhook payload is missing"
                    )
                input_data = {"payload": data[webhook_payload_key]}

            input_data, error = validate_exec(node, input_data)
            if input_data is None:
                raise ValueError(error)
            else:
                nodes_input.append((node.id, input_data))

        if not nodes_input:
            raise ValueError(
                "No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
            )

        graph_exec = self.db_client.create_graph_execution(
            graph_id=graph_id,
            graph_version=graph.version,
            nodes_input=nodes_input,
            user_id=user_id,
            preset_id=preset_id,
        )
        self.db_client.send_execution_update(graph_exec)

        graph_exec_entry = GraphExecutionEntry(
            user_id=user_id,
            graph_id=graph_id,
            graph_version=graph_version or 0,
            graph_exec_id=graph_exec.id,
            start_node_execs=[
                NodeExecutionEntry(
                    user_id=user_id,
                    graph_exec_id=node_exec.graph_exec_id,
                    graph_id=node_exec.graph_id,
                    node_exec_id=node_exec.node_exec_id,
                    node_id=node_exec.node_id,
                    block_id=node_exec.block_id,
                    data=node_exec.input_data,
                )
                for node_exec in graph_exec.node_executions
            ],
        )
        self.queue.add(graph_exec_entry)

        return graph_exec_entry

    @expose
    def cancel_execution(self, graph_exec_id: str) -> None:
        """
        Mechanism:
        1. Set the cancel event
        2. Graph executor's cancel handler thread detects the event, terminates workers,
           reinitializes worker pool, and returns.
        3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
        """
        if graph_exec_id not in self.active_graph_runs:
            logger.warning(
                f"Graph execution #{graph_exec_id} not active/running: "
                "possibly already completed/cancelled."
            )
        else:
            future, cancel_event = self.active_graph_runs[graph_exec_id]
            if not cancel_event.is_set():
                cancel_event.set()
                future.result()

        # Update the status of the graph & node executions
        self.db_client.update_graph_execution_stats(
            graph_exec_id,
            ExecutionStatus.TERMINATED,
        )
        node_execs = self.db_client.get_node_execution_results(
            graph_exec_id=graph_exec_id,
            statuses=[
                ExecutionStatus.QUEUED,
                ExecutionStatus.RUNNING,
                ExecutionStatus.INCOMPLETE,
            ],
        )
        self.db_client.update_node_execution_status_batch(
            [node_exec.node_exec_id for node_exec in node_execs],
            ExecutionStatus.TERMINATED,
        )
        for node_exec in node_execs:
            node_exec.status = ExecutionStatus.TERMINATED
            self.db_client.send_execution_update(node_exec)
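
The cancellation path above hinges on a `multiprocessing.Manager().Event()` shared between the manager process and the pool worker: the cancel consumer sets the event, and the worker's cancel-watcher reacts to it. A self-contained sketch of that handshake (illustrative names, not the project's API):

import multiprocessing
import threading
import time
from concurrent.futures import ProcessPoolExecutor


def run_graph(cancel_event) -> str:
    # Worker: poll the shared event between units of work.
    for _ in range(100):
        if cancel_event.is_set():
            return "TERMINATED"
        time.sleep(0.05)  # stand-in for executing one node
    return "COMPLETED"


if __name__ == "__main__":
    manager = multiprocessing.Manager()
    cancel_event = manager.Event()  # proxy object, safe to pass to a pool worker
    with ProcessPoolExecutor(max_workers=1) as pool:
        future = pool.submit(run_graph, cancel_event)
        # Elsewhere (e.g. a cancel-message consumer thread) the event is set:
        threading.Timer(0.2, cancel_event.set).start()
        print(future.result())  # -> "TERMINATED"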

    def _validate_node_input_credentials(self, graph: GraphModel, user_id: str):
        """Checks all credentials for all nodes of the graph"""

        for node in graph.nodes:
            block = node.block

            # Find any fields of type CredentialsMetaInput
            credentials_fields = cast(
                type[BlockSchema], block.input_schema
            ).get_credentials_fields()
            if not credentials_fields:
                continue

            for field_name, credentials_meta_type in credentials_fields.items():
                credentials_meta = credentials_meta_type.model_validate(
                    node.input_default[field_name]
                )
                # Fetch the corresponding Credentials and perform sanity checks
                credentials = self.credentials_store.get_creds_by_id(
                    user_id, credentials_meta.id
                )
                if not credentials:
                    raise ValueError(
                        f"Unknown credentials #{credentials_meta.id} "
                        f"for node #{node.id} input '{field_name}'"
                    )
                if (
                    credentials.provider != credentials_meta.provider
                    or credentials.type != credentials_meta.type
                ):
                    logger.warning(
                        f"Invalid credentials #{credentials.id} for node #{node.id}: "
                        "type/provider mismatch: "
                        f"{credentials_meta.type}<>{credentials.type};"
                        f"{credentials_meta.provider}<>{credentials.provider}"
                    )
                    raise ValueError(
                        f"Invalid credentials #{credentials.id} for node #{node.id}: "
                        "type/provider mismatch"
                    )
        log(f"{prefix} ✅ Finished GraphExec cleanup")


# ------- UTILITIES ------- #


@thread_cached
def get_db_client() -> "DatabaseManager":
    from backend.executor import DatabaseManager
def get_db_client() -> "DatabaseManagerClient":
    from backend.executor import DatabaseManagerClient

    return get_service_client(DatabaseManager)
    return get_service_client(DatabaseManagerClient)


@thread_cached
def get_notification_service() -> "NotificationManager":
    from backend.notifications import NotificationManager
def get_notification_service() -> "NotificationManagerClient":
    from backend.notifications import NotificationManagerClient

    return get_service_client(NotificationManager)
    return get_service_client(NotificationManagerClient)


def send_execution_update(entry: GraphExecution | NodeExecutionResult | None):
    if entry is None:
        return
    return get_execution_event_bus().publish(entry)


@contextmanager
@@ -1175,14 +1120,26 @@ def synchronized(key: str, timeout: int = 60):
        lock.acquire()
        yield
    finally:
        if lock.locked():
        if lock.locked() and lock.owned():
            lock.release()
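
The `synchronized` fix on the last lines releases the Redis lock only if this process still owns it: if the TTL expired and another process re-acquired the lock, an unconditional release would raise. An illustrative re-statement of the pattern (not the project code; assumes a redis-py connection):

from contextlib import contextmanager

from redis.lock import Lock as RedisLock


@contextmanager
def synchronized_sketch(r, key: str, timeout: int = 60):
    lock = RedisLock(r, f"lock:{key}", timeout=timeout)
    try:
        lock.acquire()
        yield
    finally:
        # Only release a lock we still hold; releasing someone else's
        # re-acquired lock would raise redis.exceptions.LockError.
        if lock.locked() and lock.owned():
            lock.release()

Usage would look like `with synchronized_sketch(redis_conn, f"graph_exec:{graph_exec_id}"): ...`.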


def increment_execution_count(user_id: str) -> int:
    """
    Increment the execution count for a given user,
    this will be used to charge the user for the execution cost.
    """
    r = redis.get_redis()
    k = f"uec:{user_id}"  # User Execution Count global key
    counter = cast(int, r.incr(k))
    if counter == 1:
        r.expire(k, settings.config.execution_counter_expiration_time)
    return counter
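
This is the classic Redis fixed-window counter: the first INCR creates the key at 1 and only then attaches a TTL, so the count resets once the window expires. The same pattern against a plain redis-py client (key name and window below are illustrative):

import redis

r = redis.Redis()


def windowed_count(user_id: str, window_seconds: int = 3600) -> int:
    key = f"uec:{user_id}"
    count = r.incr(key)  # atomic increment; creates the key at 1 if absent
    if count == 1:
        r.expire(key, window_seconds)  # start the window on the first hit
    return int(count)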


def llprint(message: str):
    """
    Low-level print/log helper function for use in signal handlers.
    Regular log/print statements are not allowed in signal handlers.
    """
    if logger.getEffectiveLevel() == logging.DEBUG:
        os.write(sys.stdout.fileno(), (message + "\n").encode())
    os.write(sys.stdout.fileno(), (message + "\n").encode())
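
The reason for `os.write` here: `print` and the `logging` module take locks and allocate, so calling them from inside a signal handler can deadlock or re-enter, while a raw write to a file descriptor is a single syscall. Minimal illustration:

import os
import signal
import sys


def handler(signum, frame):
    # Avoid logging/print here; write raw bytes straight to stdout instead.
    os.write(sys.stdout.fileno(), b"signal received\n")


signal.signal(signal.SIGTERM, handler)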

@@ -16,9 +16,15 @@ from pydantic import BaseModel
from sqlalchemy import MetaData, create_engine

from backend.data.block import BlockInput
from backend.executor.manager import ExecutionManager
from backend.notifications.notifications import NotificationManager
from backend.util.service import AppService, expose, get_service_client
from backend.executor import utils as execution_utils
from backend.notifications.notifications import NotificationManagerClient
from backend.util.service import (
    AppService,
    AppServiceClient,
    endpoint_to_async,
    expose,
    get_service_client,
)
from backend.util.settings import Config


@@ -57,25 +63,18 @@ def job_listener(event):
        log(f"Job {event.job_id} completed successfully.")


@thread_cached
def get_execution_client() -> ExecutionManager:
    return get_service_client(ExecutionManager)


@thread_cached
def get_notification_client():
    from backend.notifications import NotificationManager

    return get_service_client(NotificationManager)
    return get_service_client(NotificationManagerClient)


def execute_graph(**kwargs):
    args = ExecutionJobArgs(**kwargs)
    try:
        log(f"Executing recurring job for graph #{args.graph_id}")
        get_execution_client().add_execution(
        execution_utils.add_graph_execution(
            graph_id=args.graph_id,
            data=args.input_data,
            inputs=args.input_data,
            user_id=args.user_id,
            graph_version=args.graph_version,
        )
@@ -164,19 +163,9 @@ class Scheduler(AppService):
    def db_pool_size(cls) -> int:
        return config.scheduler_db_pool_size

    @property
    @thread_cached
    def execution_client(self) -> ExecutionManager:
        return get_service_client(ExecutionManager)

    @property
    @thread_cached
    def notification_client(self) -> NotificationManager:
        return get_service_client(NotificationManager)

    def run_service(self):
        load_dotenv()
        db_schema, db_url = _extract_schema_from_url(os.getenv("DATABASE_URL"))
        db_schema, db_url = _extract_schema_from_url(os.getenv("DIRECT_URL"))
        self.scheduler = BlockingScheduler(
            jobstores={
                Jobstores.EXECUTION.value: SQLAlchemyJobStore(
@@ -310,3 +299,15 @@ class Scheduler(AppService):
            ),
            job,
        )


class SchedulerClient(AppServiceClient):
    @classmethod
    def get_service_type(cls):
        return Scheduler

    add_execution_schedule = endpoint_to_async(Scheduler.add_execution_schedule)
    delete_schedule = endpoint_to_async(Scheduler.delete_schedule)
    get_execution_schedules = endpoint_to_async(Scheduler.get_execution_schedules)
    add_batched_notification_schedule = Scheduler.add_batched_notification_schedule
    add_weekly_notification_schedule = Scheduler.add_weekly_notification_schedule

@@ -1,38 +1,113 @@
import logging
from typing import TYPE_CHECKING, Any, Optional, cast

from autogpt_libs.utils.cache import thread_cached
from pydantic import BaseModel

from backend.data.block import Block, BlockInput
from backend.data.block import (
    Block,
    BlockData,
    BlockInput,
    BlockSchema,
    BlockType,
    get_block,
)
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.cost import BlockCostType
from backend.data.execution import (
    AsyncRedisExecutionEventBus,
    ExecutionStatus,
    GraphExecutionStats,
    GraphExecutionWithNodes,
    RedisExecutionEventBus,
    create_graph_execution,
    update_graph_execution_stats,
    update_node_execution_status_batch,
)
from backend.data.graph import GraphModel, Node, get_graph
from backend.data.model import CredentialsMetaInput
from backend.data.rabbitmq import (
    AsyncRabbitMQ,
    Exchange,
    ExchangeType,
    Queue,
    RabbitMQConfig,
    SyncRabbitMQ,
)
from backend.util.exceptions import NotFoundError
from backend.util.mock import MockObject
from backend.util.service import get_service_client
from backend.util.settings import Config
from backend.util.type import convert

if TYPE_CHECKING:
    from backend.executor import DatabaseManagerClient
    from backend.integrations.credentials_store import IntegrationCredentialsStore

config = Config()
logger = logging.getLogger(__name__)

# ============ Resource Helpers ============ #


class UsageTransactionMetadata(BaseModel):
    graph_exec_id: str | None = None
    graph_id: str | None = None
    node_id: str | None = None
    node_exec_id: str | None = None
    block_id: str | None = None
    block: str | None = None
    input: BlockInput | None = None
@thread_cached
def get_execution_event_bus() -> RedisExecutionEventBus:
    return RedisExecutionEventBus()


@thread_cached
def get_async_execution_event_bus() -> AsyncRedisExecutionEventBus:
    return AsyncRedisExecutionEventBus()


@thread_cached
def get_execution_queue() -> SyncRabbitMQ:
    client = SyncRabbitMQ(create_execution_queue_config())
    client.connect()
    return client


@thread_cached
async def get_async_execution_queue() -> AsyncRabbitMQ:
    client = AsyncRabbitMQ(create_execution_queue_config())
    await client.connect()
    return client


@thread_cached
def get_integration_credentials_store() -> "IntegrationCredentialsStore":
    from backend.integrations.credentials_store import IntegrationCredentialsStore

    return IntegrationCredentialsStore()


@thread_cached
def get_db_client() -> "DatabaseManagerClient":
    from backend.executor import DatabaseManagerClient

    return get_service_client(DatabaseManagerClient)


# ============ Execution Cost Helpers ============ #


def execution_usage_cost(execution_count: int) -> tuple[int, int]:
    """
    Calculate the cost of executing a graph based on the number of executions.
    Calculate the cost of executing a graph based on the current number of node executions.

    Args:
        execution_count: Number of executions
        execution_count: Number of node executions

    Returns:
        Tuple of cost amount and remaining execution count
        Tuple of cost amount and the number of execution count that is included in the cost.
    """
    return (
        execution_count
        // config.execution_cost_count_threshold
        * config.execution_cost_per_threshold,
        execution_count % config.execution_cost_count_threshold,
        (
            config.execution_cost_per_threshold
            if execution_count % config.execution_cost_count_threshold == 0
            else 0
        ),
        config.execution_cost_count_threshold,
    )
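
With, say, a threshold of 100 executions and 5 credits per threshold (illustrative numbers; the real values come from `Config`), the new implementation charges exactly when the running counter hits a multiple of the threshold and reports how many executions that charge covers:

# Stand-ins for config.execution_cost_per_threshold (5)
# and config.execution_cost_count_threshold (100).
def execution_usage_cost(execution_count: int) -> tuple[int, int]:
    return (5 if execution_count % 100 == 0 else 0, 100)


assert execution_usage_cost(99) == (0, 100)   # no charge yet
assert execution_usage_cost(100) == (5, 100)  # 100th execution triggers a charge
assert execution_usage_cost(101) == (0, 100)  # counter keeps climbing toward 200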

@@ -95,3 +170,571 @@ def _is_cost_filter_match(cost_filter: BlockInput, input_data: BlockInput) -> bo
        or (input_data.get(k) and _is_cost_filter_match(v, input_data[k]))
        for k, v in cost_filter.items()
    )


# ============ Execution Input Helpers ============ #

LIST_SPLIT = "_$_"
DICT_SPLIT = "_#_"
OBJC_SPLIT = "_@_"


def parse_execution_output(output: BlockData, name: str) -> Any | None:
    """
    Extracts partial output data by name from a given BlockData.

    The function supports extracting data from lists, dictionaries, and objects
    using specific naming conventions:
    - For lists: <output_name>_$_<index>
    - For dictionaries: <output_name>_#_<key>
    - For objects: <output_name>_@_<attribute>

    Args:
        output (BlockData): A tuple containing the output name and data.
        name (str): The name used to extract specific data from the output.

    Returns:
        Any | None: The extracted data if found, otherwise None.

    Examples:
        >>> output = ("result", [10, 20, 30])
        >>> parse_execution_output(output, "result_$_1")
        20

        >>> output = ("config", {"key1": "value1", "key2": "value2"})
        >>> parse_execution_output(output, "config_#_key1")
        'value1'

        >>> class Sample:
        ...     attr1 = "value1"
        ...     attr2 = "value2"
        >>> output = ("object", Sample())
        >>> parse_execution_output(output, "object_@_attr1")
        'value1'
    """
    output_name, output_data = output

    if name == output_name:
        return output_data

    if name.startswith(f"{output_name}{LIST_SPLIT}"):
        index = int(name.split(LIST_SPLIT)[1])
        if not isinstance(output_data, list) or len(output_data) <= index:
            return None
        return output_data[int(name.split(LIST_SPLIT)[1])]

    if name.startswith(f"{output_name}{DICT_SPLIT}"):
        index = name.split(DICT_SPLIT)[1]
        if not isinstance(output_data, dict) or index not in output_data:
            return None
        return output_data[index]

    if name.startswith(f"{output_name}{OBJC_SPLIT}"):
        index = name.split(OBJC_SPLIT)[1]
        if isinstance(output_data, object) and hasattr(output_data, index):
            return getattr(output_data, index)
        return None

    return None


def validate_exec(
    node: Node,
    data: BlockInput,
    resolve_input: bool = True,
) -> tuple[BlockInput | None, str]:
    """
    Validate the input data for a node execution.

    Args:
        node: The node to execute.
        data: The input data for the node execution.
        resolve_input: Whether to resolve dynamic pins into dict/list/object.

    Returns:
        A tuple of the validated data and the block name.
        If the data is invalid, the first element will be None, and the second element
        will be an error message.
        If the data is valid, the first element will be the resolved input data, and
        the second element will be the block name.
    """
    node_block: Block | None = get_block(node.block_id)
    if not node_block:
        return None, f"Block for {node.block_id} not found."
    schema = node_block.input_schema

    # Convert non-matching data types to the expected input schema.
    for name, data_type in schema.__annotations__.items():
        if (value := data.get(name)) and (type(value) is not data_type):
            data[name] = convert(value, data_type)

    # Input data (without default values) should contain all required fields.
    error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
    if missing_links := schema.get_missing_links(data, node.input_links):
        return None, f"{error_prefix} unpopulated links {missing_links}"

    # Merge input data with default values and resolve dynamic dict/list/object pins.
    input_default = schema.get_input_defaults(node.input_default)
    data = {**input_default, **data}
    if resolve_input:
        data = merge_execution_input(data)

    # Input data post-merge should contain all required fields from the schema.
    if missing_input := schema.get_missing_input(data):
        return None, f"{error_prefix} missing input {missing_input}"

    # Last validation: Validate the input values against the schema.
    if error := schema.get_mismatch_error(data):
        error_message = f"{error_prefix} {error}"
        logger.error(error_message)
        return None, error_message

    return data, node_block.name


def merge_execution_input(data: BlockInput) -> BlockInput:
    """
    Merges dynamic input pins into a single list, dictionary, or object based on naming patterns.

    This function processes input keys that follow specific patterns to merge them into a unified structure:
    - `<input_name>_$_<index>` for list inputs.
    - `<input_name>_#_<index>` for dictionary inputs.
    - `<input_name>_@_<index>` for object inputs.

    Args:
        data (BlockInput): A dictionary containing input keys and their corresponding values.

    Returns:
        BlockInput: A dictionary with merged inputs.

    Raises:
        ValueError: If a list index is not an integer.

    Examples:
        >>> data = {
        ...     "list_$_0": "a",
        ...     "list_$_1": "b",
        ...     "dict_#_key1": "value1",
        ...     "dict_#_key2": "value2",
        ...     "object_@_attr1": "value1",
        ...     "object_@_attr2": "value2"
        ... }
        >>> merge_execution_input(data)
        {
            "list": ["a", "b"],
            "dict": {"key1": "value1", "key2": "value2"},
            "object": <MockObject attr1="value1" attr2="value2">
        }
    """

    # Merge all input with <input_name>_$_<index> into a single list.
    items = list(data.items())

    for key, value in items:
        if LIST_SPLIT not in key:
            continue
        name, index = key.split(LIST_SPLIT)
        if not index.isdigit():
            raise ValueError(f"Invalid key: {key}, #{index} index must be an integer.")

        data[name] = data.get(name, [])
        if int(index) >= len(data[name]):
            # Pad list with empty string on missing indices.
            data[name].extend([""] * (int(index) - len(data[name]) + 1))
        data[name][int(index)] = value

    # Merge all input with <input_name>_#_<index> into a single dict.
    for key, value in items:
        if DICT_SPLIT not in key:
            continue
        name, index = key.split(DICT_SPLIT)
        data[name] = data.get(name, {})
        data[name][index] = value

    # Merge all input with <input_name>_@_<index> into a single object.
    for key, value in items:
        if OBJC_SPLIT not in key:
            continue
        name, index = key.split(OBJC_SPLIT)
        if name not in data or not isinstance(data[name], object):
            data[name] = MockObject()
        setattr(data[name], index, value)

    return data


def _validate_node_input_credentials(
    graph: GraphModel,
    user_id: str,
    node_credentials_input_map: Optional[
        dict[str, dict[str, CredentialsMetaInput]]
    ] = None,
):
    """Checks all credentials for all nodes of the graph"""

    for node in graph.nodes:
        block = node.block

        # Find any fields of type CredentialsMetaInput
        credentials_fields = cast(
            type[BlockSchema], block.input_schema
        ).get_credentials_fields()
        if not credentials_fields:
            continue

        for field_name, credentials_meta_type in credentials_fields.items():
            if (
                node_credentials_input_map
                and (node_credentials_inputs := node_credentials_input_map.get(node.id))
                and field_name in node_credentials_inputs
            ):
                credentials_meta = node_credentials_input_map[node.id][field_name]
            elif field_name in node.input_default:
                credentials_meta = credentials_meta_type.model_validate(
                    node.input_default[field_name]
                )
            else:
                raise ValueError(
                    f"Credentials absent for {block.name} node #{node.id} "
                    f"input '{field_name}'"
                )

            # Fetch the corresponding Credentials and perform sanity checks
            credentials = get_integration_credentials_store().get_creds_by_id(
                user_id, credentials_meta.id
            )
            if not credentials:
                raise ValueError(
                    f"Unknown credentials #{credentials_meta.id} "
                    f"for node #{node.id} input '{field_name}'"
                )
            if (
                credentials.provider != credentials_meta.provider
                or credentials.type != credentials_meta.type
            ):
                logger.warning(
                    f"Invalid credentials #{credentials.id} for node #{node.id}: "
                    "type/provider mismatch: "
                    f"{credentials_meta.type}<>{credentials.type};"
                    f"{credentials_meta.provider}<>{credentials.provider}"
                )
                raise ValueError(
                    f"Invalid credentials #{credentials.id} for node #{node.id}: "
                    "type/provider mismatch"
                )


def make_node_credentials_input_map(
    graph: GraphModel,
    graph_credentials_input: dict[str, CredentialsMetaInput],
) -> dict[str, dict[str, CredentialsMetaInput]]:
    """
    Maps credentials for an execution to the correct nodes.

    Params:
        graph: The graph to be executed.
        graph_credentials_input: A (graph_input_name, credentials_meta) map.

    Returns:
        dict[node_id, dict[field_name, CredentialsMetaInput]]: Node credentials input map.
    """
    result: dict[str, dict[str, CredentialsMetaInput]] = {}

    # Get aggregated credentials fields for the graph
    graph_cred_inputs = graph.aggregate_credentials_inputs()

    for graph_input_name, (_, compatible_node_fields) in graph_cred_inputs.items():
        # Best-effort map: skip missing items
        if graph_input_name not in graph_credentials_input:
            continue

        # Use passed-in credentials for all compatible node input fields
        for node_id, node_field_name in compatible_node_fields:
            if node_id not in result:
                result[node_id] = {}
            result[node_id][node_field_name] = graph_credentials_input[graph_input_name]

    return result
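
In practice the map fans one graph-level credentials entry out to every compatible (node_id, field_name) pair. Illustrative shapes only (hypothetical node IDs and abbreviated values, not actual objects):

# graph_credentials_input = {"github": <CredentialsMetaInput id="c1" provider="github" ...>}
# graph.aggregate_credentials_inputs() reports that "github" feeds two node fields:
#     [("node-1", "credentials"), ("node-2", "credentials")]
# make_node_credentials_input_map(graph, graph_credentials_input) then returns:
# {
#     "node-1": {"credentials": <CredentialsMetaInput id="c1" ...>},
#     "node-2": {"credentials": <CredentialsMetaInput id="c1" ...>},
# }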
|
||||
|
||||
|
||||

def construct_node_execution_input(
    graph: GraphModel,
    user_id: str,
    graph_inputs: BlockInput,
    node_credentials_input_map: Optional[
        dict[str, dict[str, CredentialsMetaInput]]
    ] = None,
) -> list[tuple[str, BlockInput]]:
    """
    Validates and prepares the input data for executing a graph.
    This function checks the graph for starting nodes, validates the input data
    against the schema, and resolves dynamic input pins into a single list,
    dictionary, or object.

    Args:
        graph (GraphModel): The graph model to execute.
        user_id (str): The ID of the user executing the graph.
        graph_inputs (BlockInput): The input data for the graph execution.
        node_credentials_input_map: `dict[node_id, dict[input_name, CredentialsMetaInput]]`

    Returns:
        list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID
            and the corresponding input data for that node.
    """
    graph.validate_graph(for_run=True)
    _validate_node_input_credentials(graph, user_id, node_credentials_input_map)

    nodes_input = []
    for node in graph.starting_nodes:
        input_data = {}
        block = node.block

        # Note block should never be executed.
        if block.block_type == BlockType.NOTE:
            continue

        # Extract request input data, and assign it to the input pin.
        if block.block_type == BlockType.INPUT:
            input_name = node.input_default.get("name")
            if input_name and input_name in graph_inputs:
                input_data = {"value": graph_inputs[input_name]}

        # Extract webhook payload, and assign it to the input pin
        webhook_payload_key = f"webhook_{node.webhook_id}_payload"
        if (
            block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
            and node.webhook_id
        ):
            if webhook_payload_key not in graph_inputs:
                raise ValueError(
                    f"Node {block.name} #{node.id} webhook payload is missing"
                )
            input_data = {"payload": graph_inputs[webhook_payload_key]}

        # Apply node credentials overrides
        if node_credentials_input_map and (
            node_credentials := node_credentials_input_map.get(node.id)
        ):
            input_data.update({k: v.model_dump() for k, v in node_credentials.items()})

        input_data, error = validate_exec(node, input_data)
        if input_data is None:
            raise ValueError(error)
        else:
            nodes_input.append((node.id, input_data))

    if not nodes_input:
        raise ValueError(
            "No starting nodes found for the graph; make sure an AgentInput block "
            "or blocks with no inbound links are present as starting nodes."
        )

    return nodes_input

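
# Illustrative example (not part of the diff; node IDs and input names are
# hypothetical): a graph with one AgentInput node and one webhook-triggered
# node yields one (node_id, input_data) pair per starting node:
#
#   construct_node_execution_input(
#       graph, user_id, {"topic": "cats", "webhook_wh1_payload": {"a": 1}}
#   )
#   # == [("node-in", {"value": "cats"}), ("node-wh", {"payload": {"a": 1}})]
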

# ============ Execution Queue Helpers ============ #


class CancelExecutionEvent(BaseModel):
    graph_exec_id: str


GRAPH_EXECUTION_EXCHANGE = Exchange(
    name="graph_execution",
    type=ExchangeType.DIRECT,
    durable=True,
    auto_delete=False,
)
GRAPH_EXECUTION_QUEUE_NAME = "graph_execution_queue"
GRAPH_EXECUTION_ROUTING_KEY = "graph_execution.run"

GRAPH_EXECUTION_CANCEL_EXCHANGE = Exchange(
    name="graph_execution_cancel",
    type=ExchangeType.FANOUT,
    durable=True,
    auto_delete=True,
)
GRAPH_EXECUTION_CANCEL_QUEUE_NAME = "graph_execution_cancel_queue"


def create_execution_queue_config() -> RabbitMQConfig:
    """
    Define two exchanges and queues:
    - 'graph_execution' (DIRECT) for run tasks.
    - 'graph_execution_cancel' (FANOUT) for cancel requests.
    """
    run_queue = Queue(
        name=GRAPH_EXECUTION_QUEUE_NAME,
        exchange=GRAPH_EXECUTION_EXCHANGE,
        routing_key=GRAPH_EXECUTION_ROUTING_KEY,
        durable=True,
        auto_delete=False,
    )
    cancel_queue = Queue(
        name=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
        exchange=GRAPH_EXECUTION_CANCEL_EXCHANGE,
        routing_key="",  # not used for FANOUT
        durable=True,
        auto_delete=False,
    )
    return RabbitMQConfig(
        vhost="/",
        exchanges=[GRAPH_EXECUTION_EXCHANGE, GRAPH_EXECUTION_CANCEL_EXCHANGE],
        queues=[run_queue, cancel_queue],
    )

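
# Usage sketch, assembled from this diff's own call sites: the API and the
# executor both build their connection from this one config, so the exchanges
# and queues are declared identically on every connection.
#
#   client = AsyncRabbitMQ(create_execution_queue_config())
#   await client.connect()
#   await client.publish_message(
#       routing_key=GRAPH_EXECUTION_ROUTING_KEY,
#       message=graph_exec_entry.model_dump_json(),
#       exchange=GRAPH_EXECUTION_EXCHANGE,
#   )
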

async def add_graph_execution_async(
    graph_id: str,
    user_id: str,
    inputs: BlockInput,
    preset_id: Optional[str] = None,
    graph_version: Optional[int] = None,
    graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
) -> GraphExecutionWithNodes:
    """
    Adds a graph execution to the queue and returns the execution entry.

    Args:
        graph_id: The ID of the graph to execute.
        user_id: The ID of the user executing the graph.
        inputs: The input data for the graph execution.
        preset_id: The ID of the preset to use.
        graph_version: The version of the graph to execute.
        graph_credentials_inputs: Credentials inputs to use in the execution.
            Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
    Returns:
        GraphExecutionWithNodes: The entry for the graph execution.
    Raises:
        ValueError: If the graph is not found or if there are validation errors.
    """  # noqa
    graph: GraphModel | None = await get_graph(
        graph_id=graph_id, user_id=user_id, version=graph_version
    )
    if not graph:
        raise NotFoundError(f"Graph #{graph_id} not found.")

    node_credentials_input_map = (
        make_node_credentials_input_map(graph, graph_credentials_inputs)
        if graph_credentials_inputs
        else None
    )

    graph_exec = await create_graph_execution(
        user_id=user_id,
        graph_id=graph_id,
        graph_version=graph.version,
        starting_nodes_input=construct_node_execution_input(
            graph=graph,
            user_id=user_id,
            graph_inputs=inputs,
            node_credentials_input_map=node_credentials_input_map,
        ),
        preset_id=preset_id,
    )
    try:
        queue = await get_async_execution_queue()
        graph_exec_entry = graph_exec.to_graph_execution_entry()
        if node_credentials_input_map:
            graph_exec_entry.node_credentials_input_map = node_credentials_input_map
        await queue.publish_message(
            routing_key=GRAPH_EXECUTION_ROUTING_KEY,
            message=graph_exec_entry.model_dump_json(),
            exchange=GRAPH_EXECUTION_EXCHANGE,
        )

        bus = get_async_execution_event_bus()
        await bus.publish(graph_exec)

        return graph_exec
    except Exception as e:
        logger.error(f"Unable to publish graph #{graph_id} exec #{graph_exec.id}: {e}")

        await update_node_execution_status_batch(
            [node_exec.node_exec_id for node_exec in graph_exec.node_executions],
            ExecutionStatus.FAILED,
        )
        await update_graph_execution_stats(
            graph_exec_id=graph_exec.id,
            status=ExecutionStatus.FAILED,
            stats=GraphExecutionStats(error=str(e)),
        )
        raise

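
# Usage sketch (mirrors the route handlers later in this diff; input values
# are hypothetical):
#
#   graph_exec = await add_graph_execution_async(
#       graph_id=graph_id,
#       user_id=user_id,
#       inputs={"topic": "cats"},
#       graph_credentials_inputs={},  # keyed per aggregate_credentials_inputs()
#   )
#   graph_exec.id  # ID of the queued execution
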

def add_graph_execution(
    graph_id: str,
    user_id: str,
    inputs: BlockInput,
    preset_id: Optional[str] = None,
    graph_version: Optional[int] = None,
    graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
) -> GraphExecutionWithNodes:
    """
    Adds a graph execution to the queue and returns the execution entry.

    Args:
        graph_id: The ID of the graph to execute.
        user_id: The ID of the user executing the graph.
        inputs: The input data for the graph execution.
        preset_id: The ID of the preset to use.
        graph_version: The version of the graph to execute.
        graph_credentials_inputs: Credentials inputs to use in the execution.
            Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
    Returns:
        GraphExecutionWithNodes: The entry for the graph execution.
    Raises:
        ValueError: If the graph is not found or if there are validation errors.
    """
    db = get_db_client()
    graph: GraphModel | None = db.get_graph(
        graph_id=graph_id, user_id=user_id, version=graph_version
    )
    if not graph:
        raise NotFoundError(f"Graph #{graph_id} not found.")

    node_credentials_input_map = (
        make_node_credentials_input_map(graph, graph_credentials_inputs)
        if graph_credentials_inputs
        else None
    )

    graph_exec = db.create_graph_execution(
        user_id=user_id,
        graph_id=graph_id,
        graph_version=graph.version,
        starting_nodes_input=construct_node_execution_input(
            graph=graph,
            user_id=user_id,
            graph_inputs=inputs,
            node_credentials_input_map=node_credentials_input_map,
        ),
        preset_id=preset_id,
    )
    try:
        queue = get_execution_queue()
        graph_exec_entry = graph_exec.to_graph_execution_entry()
        if node_credentials_input_map:
            graph_exec_entry.node_credentials_input_map = node_credentials_input_map
        queue.publish_message(
            routing_key=GRAPH_EXECUTION_ROUTING_KEY,
            message=graph_exec_entry.model_dump_json(),
            exchange=GRAPH_EXECUTION_EXCHANGE,
        )

        bus = get_execution_event_bus()
        bus.publish(graph_exec)

        return graph_exec
    except Exception as e:
        logger.error(f"Unable to publish graph #{graph_id} exec #{graph_exec.id}: {e}")

        db.update_node_execution_status_batch(
            [node_exec.node_exec_id for node_exec in graph_exec.node_executions],
            ExecutionStatus.FAILED,
        )
        db.update_graph_execution_stats(
            graph_exec_id=graph_exec.id,
            status=ExecutionStatus.FAILED,
            stats=GraphExecutionStats(error=str(e)),
        )
        raise

@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Optional
from pydantic import SecretStr

if TYPE_CHECKING:
    from backend.executor.database import DatabaseManager
    from backend.executor.database import DatabaseManagerClient

from autogpt_libs.utils.cache import thread_cached
from autogpt_libs.utils.synchronize import RedisKeyedMutex
@@ -161,6 +161,14 @@ smartlead_credentials = APIKeyCredentials(
    expires_at=None,
)

google_maps_credentials = APIKeyCredentials(
    id="9aa1bde0-4947-4a70-a20c-84daa3850d52",
    provider="google_maps",
    api_key=SecretStr(settings.secrets.google_maps_api_key),
    title="Use Credits for Google Maps",
    expires_at=None,
)

zerobounce_credentials = APIKeyCredentials(
    id="63a6e279-2dc2-448e-bf57-85776f7176dc",
    provider="zerobounce",
@@ -190,6 +198,7 @@ DEFAULT_CREDENTIALS = [
    apollo_credentials,
    smartlead_credentials,
    zerobounce_credentials,
    google_maps_credentials,
]


@@ -201,11 +210,11 @@ class IntegrationCredentialsStore:

    @property
    @thread_cached
    def db_manager(self) -> "DatabaseManager":
        from backend.executor.database import DatabaseManager
    def db_manager(self) -> "DatabaseManagerClient":
        from backend.executor.database import DatabaseManagerClient
        from backend.util.service import get_service_client

        return get_service_client(DatabaseManager)
        return get_service_client(DatabaseManagerClient)

    def add_creds(self, user_id: str, credentials: Credentials) -> None:
        with self.locked_user_integrations(user_id):
@@ -263,6 +272,8 @@ class IntegrationCredentialsStore:
            all_credentials.append(smartlead_credentials)
        if settings.secrets.zerobounce_api_key:
            all_credentials.append(zerobounce_credentials)
        if settings.secrets.google_maps_api_key:
            all_credentials.append(google_maps_credentials)
        return all_credentials

    def get_creds_by_id(self, user_id: str, credentials_id: str) -> Credentials | None:

@@ -93,7 +93,7 @@ class IntegrationCredentialsManager:

        fresh_credentials = oauth_handler.refresh_tokens(credentials)
        self.store.update_creds(user_id, fresh_credentials)
        if _lock and _lock.locked():
        if _lock and _lock.locked() and _lock.owned():
            _lock.release()

        credentials = fresh_credentials
@@ -145,7 +145,7 @@ class IntegrationCredentialsManager:
        try:
            yield
        finally:
            if lock.locked():
            if lock.locked() and lock.owned():
                lock.release()

    def release_all_locks(self):

@@ -29,6 +29,7 @@ class ProviderName(str, Enum):
    OPENWEATHERMAP = "openweathermap"
    OPEN_ROUTER = "open_router"
    PINECONE = "pinecone"
    POSTGRES = "postgres"
    REDDIT = "reddit"
    REPLICATE = "replicate"
    REVID = "revid"

@@ -1,5 +1,6 @@
from .notifications import NotificationManager
from .notifications import NotificationManager, NotificationManagerClient

__all__ = [
    "NotificationManager",
    "NotificationManagerClient",
]

@@ -31,7 +31,13 @@ from backend.data.notifications import (
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.data.user import generate_unsubscribe_link
from backend.notifications.email import EmailSender
from backend.util.service import AppService, expose, get_service_client
from backend.util.service import (
    AppService,
    AppServiceClient,
    endpoint_to_async,
    expose,
    get_service_client,
)
from backend.util.settings import Settings

logger = logging.getLogger(__name__)
@@ -108,16 +114,16 @@ def create_notification_config() -> RabbitMQConfig:

@thread_cached
def get_scheduler():
    from backend.executor import Scheduler
    from backend.executor.scheduler import SchedulerClient

    return get_service_client(Scheduler)
    return get_service_client(SchedulerClient)


@thread_cached
def get_db():
    from backend.executor.database import DatabaseManager
    from backend.executor.database import DatabaseManagerClient

    return get_service_client(DatabaseManager)
    return get_service_client(DatabaseManagerClient)


class NotificationManager(AppService):
@@ -774,3 +780,14 @@ class NotificationManager(AppService):
        super().cleanup()
        logger.info(f"[{self.service_name}] ⏳ Disconnecting RabbitMQ...")
        self.run_and_wait(self.rabbitmq_service.disconnect())


class NotificationManagerClient(AppServiceClient):
    @classmethod
    def get_service_type(cls):
        return NotificationManager

    queue_notification_async = endpoint_to_async(NotificationManager.queue_notification)
    queue_notification = NotificationManager.queue_notification
    process_existing_batches = NotificationManager.process_existing_batches
    queue_weekly_summary = NotificationManager.queue_weekly_summary

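# Usage sketch (assumed from the service-client pattern used throughout this
# diff): callers resolve the thin client instead of the service class itself;
# endpoint_to_async wraps a sync endpoint for async callers.
#
#   client = get_service_client(NotificationManagerClient)
#   client.queue_notification(event)              # sync call
#   await client.queue_notification_async(event)  # async variant
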
@@ -2,7 +2,6 @@ import logging
from collections import defaultdict
from typing import Annotated, Any, Dict, List, Optional, Sequence

from autogpt_libs.utils.cache import thread_cached
from fastapi import APIRouter, Body, Depends, HTTPException
from prisma.enums import AgentExecutionStatus, APIKeyPermission
from typing_extensions import TypedDict
@@ -13,17 +12,10 @@ from backend.data import graph as graph_db
from backend.data.api_key import APIKey
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.execution import NodeExecutionResult
from backend.executor import ExecutionManager
from backend.executor.utils import add_graph_execution_async
from backend.server.external.middleware import require_permission
from backend.util.service import get_service_client
from backend.util.settings import Settings


@thread_cached
def execution_manager_client() -> ExecutionManager:
    return get_service_client(ExecutionManager)


settings = Settings()
logger = logging.getLogger(__name__)

@@ -98,20 +90,20 @@ def execute_graph_block(
    path="/graphs/{graph_id}/execute/{graph_version}",
    tags=["graphs"],
)
def execute_graph(
async def execute_graph(
    graph_id: str,
    graph_version: int,
    node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
    api_key: APIKey = Depends(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
) -> dict[str, Any]:
    try:
        graph_exec = execution_manager_client().add_execution(
            graph_id,
            graph_version=graph_version,
            data=node_input,
        graph_exec = await add_graph_execution_async(
            graph_id=graph_id,
            user_id=api_key.user_id,
            inputs=node_input,
            graph_version=graph_version,
        )
        return {"id": graph_exec.graph_exec_id}
        return {"id": graph_exec.id}
    except Exception as e:
        msg = str(e).encode().decode("unicode_escape")
        raise HTTPException(status_code=400, detail=msg)
@@ -130,7 +122,7 @@ async def get_graph_execution_results(
    if not graph:
        raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")

    results = await execution_db.get_node_execution_results(graph_exec_id)
    results = await execution_db.get_node_executions(graph_exec_id)
    last_result = results[-1] if results else None
    execution_status = (
        last_result.status if last_result else AgentExecutionStatus.INCOMPLETE

@@ -1,5 +1,6 @@
import asyncio
import logging
from typing import TYPE_CHECKING, Annotated, Literal
from typing import TYPE_CHECKING, Annotated, Awaitable, Literal

from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
from pydantic import BaseModel, Field
@@ -14,13 +15,12 @@ from backend.data.integrations import (
    wait_for_webhook_event,
)
from backend.data.model import Credentials, CredentialsType, OAuth2Credentials
from backend.executor.manager import ExecutionManager
from backend.executor.utils import add_graph_execution_async
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
from backend.util.exceptions import NeedConfirmation, NotFoundError
from backend.util.service import get_service_client
from backend.util.settings import Settings

if TYPE_CHECKING:
@@ -309,19 +309,22 @@ async def webhook_ingress_generic(
    if not webhook.attached_nodes:
        return

    executor = get_service_client(ExecutionManager)
    executions: list[Awaitable] = []
    for node in webhook.attached_nodes:
        logger.debug(f"Webhook-attached node: {node}")
        if not node.is_triggered_by_event_type(event_type):
            logger.debug(f"Node #{node.id} doesn't trigger on event {event_type}")
            continue
        logger.debug(f"Executing graph #{node.graph_id} node #{node.id}")
        executor.add_execution(
            graph_id=node.graph_id,
            graph_version=node.graph_version,
            data={f"webhook_{webhook_id}_payload": payload},
            user_id=webhook.user_id,
        executions.append(
            add_graph_execution_async(
                user_id=webhook.user_id,
                graph_id=node.graph_id,
                graph_version=node.graph_version,
                inputs={f"webhook_{webhook_id}_payload": payload},
            )
        )
    asyncio.gather(*executions)


@router.post("/webhooks/{webhook_id}/ping")

@@ -17,9 +17,9 @@ import backend.data.block
import backend.data.db
import backend.data.graph
import backend.data.user
import backend.server.integrations.router
import backend.server.routers.postmark.postmark
import backend.server.routers.v1
import backend.server.v2.admin.credit_admin_routes
import backend.server.v2.admin.store_admin_routes
import backend.server.v2.library.db
import backend.server.v2.library.model
@@ -29,6 +29,7 @@ import backend.server.v2.store.model
import backend.server.v2.store.routes
import backend.util.service
import backend.util.settings
from backend.blocks.llm import LlmModel
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.server.external.api import external_app
@@ -57,8 +58,7 @@ async def lifespan_context(app: fastapi.FastAPI):
    await backend.data.block.initialize_blocks()
    await backend.data.user.migrate_and_encrypt_user_integrations()
    await backend.data.graph.fix_llm_provider_credentials()
    # FIXME ERROR: operator does not exist: text ? unknown
    # await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
    await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
    with launch_darkly_context():
        yield
    await backend.data.db.disconnect()
@@ -108,6 +108,11 @@ app.include_router(
    tags=["v2", "admin"],
    prefix="/api/store",
)
app.include_router(
    backend.server.v2.admin.credit_admin_routes.router,
    tags=["v2", "admin"],
    prefix="/api/credits",
)
app.include_router(
    backend.server.v2.library.routes.router, tags=["v2"], prefix="/api/library"
)
@@ -156,11 +161,12 @@ class AgentServer(backend.util.service.AppProcess):
        graph_version: Optional[int] = None,
        node_input: Optional[dict[str, Any]] = None,
    ):
        return backend.server.routers.v1.execute_graph(
        return await backend.server.routers.v1.execute_graph(
            user_id=user_id,
            graph_id=graph_id,
            graph_version=graph_version,
            node_input=node_input or {},
            inputs=node_input or {},
            credentials_inputs={},
        )

    @staticmethod
@@ -201,7 +207,7 @@ class AgentServer(backend.util.service.AppProcess):

    @staticmethod
    async def test_get_presets(user_id: str, page: int = 1, page_size: int = 10):
        return await backend.server.v2.library.routes.presets.list_presets(
        return await backend.server.v2.library.routes.presets.get_presets(
            user_id=user_id, page=page, page_size=page_size
        )

@@ -213,7 +219,7 @@ class AgentServer(backend.util.service.AppProcess):

    @staticmethod
    async def test_create_preset(
        preset: backend.server.v2.library.model.LibraryAgentPresetCreatable,
        preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
        user_id: str,
    ):
        return await backend.server.v2.library.routes.presets.create_preset(
@@ -223,7 +229,7 @@ class AgentServer(backend.util.service.AppProcess):
    @staticmethod
    async def test_update_preset(
        preset_id: str,
        preset: backend.server.v2.library.model.LibraryAgentPresetUpdatable,
        preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
        user_id: str,
    ):
        return await backend.server.v2.library.routes.presets.update_preset(
@@ -275,7 +281,9 @@ class AgentServer(backend.util.service.AppProcess):
        provider: ProviderName,
        credentials: Credentials,
    ) -> Credentials:
        return backend.server.integrations.router.create_credentials(
        from backend.server.integrations.router import create_credentials

        return create_credentials(
            user_id=user_id, provider=provider, credentials=credentials
        )

@@ -13,7 +13,6 @@ from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
from typing_extensions import Optional, TypedDict

import backend.data.block
import backend.server.integrations.router
import backend.server.routers.analytics
import backend.server.v2.library.db as library_db
@@ -31,7 +30,7 @@ from backend.data.api_key import (
    suspend_api_key,
    update_api_key_permissions,
)
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
from backend.data.credit import (
    AutoTopUpConfig,
    RefundRequest,
@@ -41,6 +40,8 @@ from backend.data.credit import (
    get_user_credit_model,
    set_auto_top_up,
)
from backend.data.execution import AsyncRedisExecutionEventBus
from backend.data.model import CredentialsMetaInput
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.data.onboarding import (
    UserOnboardingUpdate,
@@ -49,13 +50,16 @@ from backend.data.onboarding import (
    onboarding_enabled,
    update_user_onboarding,
)
from backend.data.rabbitmq import AsyncRabbitMQ
from backend.data.user import (
    get_or_create_user,
    get_user_notification_preference,
    update_user_email,
    update_user_notification_preference,
)
from backend.executor import ExecutionManager, Scheduler, scheduler
from backend.executor import scheduler
from backend.executor import utils as execution_utils
from backend.executor.utils import create_execution_queue_config
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import (
    on_graph_activate,
@@ -79,13 +83,20 @@ if TYPE_CHECKING:


@thread_cached
def execution_manager_client() -> ExecutionManager:
    return get_service_client(ExecutionManager)
def execution_scheduler_client() -> scheduler.SchedulerClient:
    return get_service_client(scheduler.SchedulerClient)


@thread_cached
def execution_scheduler_client() -> Scheduler:
    return get_service_client(Scheduler)
async def execution_queue_client() -> AsyncRabbitMQ:
    client = AsyncRabbitMQ(create_execution_queue_config())
    await client.connect()
    return client


@thread_cached
def execution_event_bus() -> AsyncRedisExecutionEventBus:
    return AsyncRedisExecutionEventBus()


settings = Settings()
@@ -206,7 +217,7 @@ async def is_onboarding_enabled():

@v1_router.get(path="/blocks", tags=["blocks"], dependencies=[Depends(auth_middleware)])
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
    blocks = [block() for block in backend.data.block.get_blocks().values()]
    blocks = [block() for block in get_blocks().values()]
    costs = get_block_costs()
    return [
        {**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks if not b.disabled
@@ -219,7 +230,7 @@ def get_graph_blocks() -> Sequence[dict[Any, Any]]:
    dependencies=[Depends(auth_middleware)],
)
def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlockOutput:
    obj = backend.data.block.get_block(block_id)
    obj = get_block(block_id)
    if not obj:
        raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")

@@ -308,7 +319,7 @@ async def configure_user_auto_top_up(
    dependencies=[Depends(auth_middleware)],
)
async def get_user_auto_top_up(
    user_id: Annotated[str, Depends(get_user_id)]
    user_id: Annotated[str, Depends(get_user_id)],
) -> AutoTopUpConfig:
    return await get_auto_top_up(user_id)

@@ -375,7 +386,7 @@ async def get_credit_history(

@v1_router.get(path="/credits/refunds", dependencies=[Depends(auth_middleware)])
async def get_refund_requests(
    user_id: Annotated[str, Depends(get_user_id)]
    user_id: Annotated[str, Depends(get_user_id)],
) -> list[RefundRequest]:
    return await _user_credit_model.get_refund_requests(user_id)

@@ -391,7 +402,7 @@ class DeleteGraphResponse(TypedDict):

@v1_router.get(path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)])
async def get_graphs(
    user_id: Annotated[str, Depends(get_user_id)]
    user_id: Annotated[str, Depends(get_user_id)],
) -> Sequence[graph_db.GraphModel]:
    return await graph_db.get_graphs(filter_by="active", user_id=user_id)

@@ -580,16 +591,25 @@ async def set_graph_active_version(
    tags=["graphs"],
    dependencies=[Depends(auth_middleware)],
)
def execute_graph(
async def execute_graph(
    graph_id: str,
    node_input: Annotated[dict[str, Any], Body(..., default_factory=dict)],
    user_id: Annotated[str, Depends(get_user_id)],
    inputs: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
    credentials_inputs: Annotated[
        dict[str, CredentialsMetaInput], Body(..., embed=True, default_factory=dict)
    ],
    graph_version: Optional[int] = None,
    preset_id: Optional[str] = None,
) -> ExecuteGraphResponse:
    graph_exec = execution_manager_client().add_execution(
        graph_id, node_input, user_id=user_id, graph_version=graph_version
    graph_exec = await execution_utils.add_graph_execution_async(
        graph_id=graph_id,
        user_id=user_id,
        inputs=inputs,
        preset_id=preset_id,
        graph_version=graph_version,
        graph_credentials_inputs=credentials_inputs,
    )
    return ExecuteGraphResponse(graph_exec_id=graph_exec.graph_exec_id)
    return ExecuteGraphResponse(graph_exec_id=graph_exec.id)

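
# Illustrative request (values hypothetical; with embed=True both parameters
# are wrapped in one JSON body). Assuming the route's existing
# /graphs/{graph_id}/execute path:
#
#   POST /api/graphs/{graph_id}/execute
#   {
#       "inputs": {"topic": "cats"},
#       "credentials_inputs": {
#           "github_credentials": {
#               "id": "...", "provider": "github", "type": "oauth2", "title": "..."
#           }
#       }
#   }
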

@v1_router.post(
@@ -605,9 +625,7 @@ async def stop_graph_run(
    ):
        raise HTTPException(404, detail=f"Agent execution #{graph_exec_id} not found")

    await asyncio.to_thread(
        lambda: execution_manager_client().cancel_execution(graph_exec_id)
    )
    await _cancel_execution(graph_exec_id)

    # Retrieve & return canceled graph execution in its final state
    result = await execution_db.get_graph_execution(
@@ -621,6 +639,49 @@ async def stop_graph_run(
    return result


async def _cancel_execution(graph_exec_id: str):
    """
    Mechanism:
    1. Set the cancel event
    2. Graph executor's cancel handler thread detects the event, terminates workers,
       reinitializes worker pool, and returns.
    3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
    """
    queue_client = await execution_queue_client()
    await queue_client.publish_message(
        routing_key="",
        message=execution_utils.CancelExecutionEvent(
            graph_exec_id=graph_exec_id
        ).model_dump_json(),
        exchange=execution_utils.GRAPH_EXECUTION_CANCEL_EXCHANGE,
    )

    # Update the status of the graph & node executions
    await execution_db.update_graph_execution_stats(
        graph_exec_id,
        execution_db.ExecutionStatus.TERMINATED,
    )
    node_execs = [
        node_exec.model_copy(update={"status": execution_db.ExecutionStatus.TERMINATED})
        for node_exec in await execution_db.get_node_executions(
            graph_exec_id=graph_exec_id,
            statuses=[
                execution_db.ExecutionStatus.QUEUED,
                execution_db.ExecutionStatus.RUNNING,
                execution_db.ExecutionStatus.INCOMPLETE,
            ],
        )
    ]

    await execution_db.update_node_execution_status_batch(
        [node_exec.node_exec_id for node_exec in node_execs],
        execution_db.ExecutionStatus.TERMINATED,
    )
    await asyncio.gather(
        *[execution_event_bus().publish(node_exec) for node_exec in node_execs]
    )

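
# Consumer-side sketch (an assumption — the executor's cancel handler is not
# part of this diff): because the cancel exchange is FANOUT, every executor
# instance receives every CancelExecutionEvent and stops the run if it owns it.
#
#   async def handle_cancel_message(message: str):
#       event = execution_utils.CancelExecutionEvent.model_validate_json(message)
#       stop_local_graph_run(event.graph_exec_id)  # hypothetical helper
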

@v1_router.get(
    path="/executions",
    tags=["graphs"],
@@ -718,14 +779,12 @@ async def create_schedule(
            detail=f"Graph #{schedule.graph_id} v.{schedule.graph_version} not found.",
        )

    return await asyncio.to_thread(
        lambda: execution_scheduler_client().add_execution_schedule(
            graph_id=schedule.graph_id,
            graph_version=graph.version,
            cron=schedule.cron,
            input_data=schedule.input_data,
            user_id=user_id,
        )
    return await execution_scheduler_client().add_execution_schedule(
        graph_id=schedule.graph_id,
        graph_version=graph.version,
        cron=schedule.cron,
        input_data=schedule.input_data,
        user_id=user_id,
    )


@@ -734,11 +793,11 @@ async def create_schedule(
    tags=["schedules"],
    dependencies=[Depends(auth_middleware)],
)
def delete_schedule(
async def delete_schedule(
    schedule_id: str,
    user_id: Annotated[str, Depends(get_user_id)],
) -> dict[Any, Any]:
    execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
    await execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
    return {"id": schedule_id}


@@ -747,11 +806,11 @@ def delete_schedule(
    tags=["schedules"],
    dependencies=[Depends(auth_middleware)],
)
def get_execution_schedules(
async def get_execution_schedules(
    user_id: Annotated[str, Depends(get_user_id)],
    graph_id: str | None = None,
) -> list[scheduler.ExecutionJobInfo]:
    return execution_scheduler_client().get_execution_schedules(
    return await execution_scheduler_client().get_execution_schedules(
        user_id=user_id,
        graph_id=graph_id,
    )
@@ -792,7 +851,7 @@ async def create_api_key(
    dependencies=[Depends(auth_middleware)],
)
async def get_api_keys(
    user_id: Annotated[str, Depends(get_user_id)]
    user_id: Annotated[str, Depends(get_user_id)],
) -> list[APIKeyWithoutHash]:
    """List all API keys for the user"""
    try:

@@ -0,0 +1,77 @@
import logging
import typing

from autogpt_libs.auth import requires_admin_user
from autogpt_libs.auth.depends import get_user_id
from fastapi import APIRouter, Body, Depends
from prisma import Json
from prisma.enums import CreditTransactionType

from backend.data.credit import admin_get_user_history, get_user_credit_model
from backend.server.v2.admin.model import AddUserCreditsResponse, UserHistoryResponse

logger = logging.getLogger(__name__)

_user_credit_model = get_user_credit_model()


router = APIRouter(
    prefix="/admin",
    tags=["credits", "admin"],
    dependencies=[Depends(requires_admin_user)],
)


@router.post("/add_credits", response_model=AddUserCreditsResponse)
async def add_user_credits(
    user_id: typing.Annotated[str, Body()],
    amount: typing.Annotated[int, Body()],
    comments: typing.Annotated[str, Body()],
    admin_user: typing.Annotated[
        str,
        Depends(get_user_id),
    ],
):
    """Grant credits to a user's account as an admin."""
    logger.info(f"Admin user {admin_user} is adding {amount} credits to user {user_id}")
    new_balance, transaction_key = await _user_credit_model._add_transaction(
        user_id,
        amount,
        transaction_type=CreditTransactionType.GRANT,
        metadata=Json({"admin_id": admin_user, "reason": comments}),
    )
    return {
        "new_balance": new_balance,
        "transaction_key": transaction_key,
    }

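
# Illustrative call (path assumes this router's "/admin" prefix plus the
# "/api/credits" mount added in rest_api.py in this diff; values hypothetical):
#
#   import httpx
#
#   httpx.post(
#       "https://server/api/credits/admin/add_credits",
#       headers={"Authorization": f"Bearer {admin_jwt}"},
#       json={"user_id": "user-123", "amount": 500, "comments": "Promo grant"},
#   )
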

@router.get(
    "/users_history",
    response_model=UserHistoryResponse,
)
async def admin_get_all_user_history(
    admin_user: typing.Annotated[
        str,
        Depends(get_user_id),
    ],
    search: typing.Optional[str] = None,
    page: int = 1,
    page_size: int = 20,
    transaction_filter: typing.Optional[CreditTransactionType] = None,
):
    """Get a paginated, searchable, filterable credit transaction history across users."""
    logger.info(f"Admin user {admin_user} is getting grant history")

    try:
        resp = await admin_get_user_history(
            page=page,
            page_size=page_size,
            search=search,
            transaction_filter=transaction_filter,
        )
        logger.info(f"Admin user {admin_user} got {len(resp.history)} grant history records")
        return resp
    except Exception as e:
        logger.exception(f"Error getting grant history: {e}")
        raise e
16 autogpt_platform/backend/backend/server/v2/admin/model.py (new file)
@@ -0,0 +1,16 @@
from pydantic import BaseModel

from backend.data.model import UserTransaction
from backend.server.model import Pagination


class UserHistoryResponse(BaseModel):
    """Response model for paginated user transaction history"""

    history: list[UserTransaction]
    pagination: Pagination


class AddUserCreditsResponse(BaseModel):
    new_balance: int
    transaction_key: str
@@ -1,4 +1,5 @@
import logging
import tempfile
import typing

import autogpt_libs.auth.depends
@@ -9,6 +10,7 @@ import prisma.enums
import backend.server.v2.store.db
import backend.server.v2.store.exceptions
import backend.server.v2.store.model
import backend.util.json

logger = logging.getLogger(__name__)

@@ -98,3 +100,47 @@ async def review_submission(
            status_code=500,
            content={"detail": "An error occurred while reviewing the submission"},
        )


@router.get(
    "/submissions/download/{store_listing_version_id}",
    tags=["store", "admin"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
)
async def admin_download_agent_file(
    user: typing.Annotated[
        autogpt_libs.auth.models.User,
        fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user),
    ],
    store_listing_version_id: str = fastapi.Path(
        ..., description="The ID of the agent to download"
    ),
) -> fastapi.responses.FileResponse:
    """
    Download the agent file by streaming its content.

    Args:
        store_listing_version_id (str): The ID of the agent to download

    Returns:
        FileResponse: A response containing the agent's graph data as a JSON file.

    Raises:
        HTTPException: If the agent is not found or an unexpected error occurs.
    """
    graph_data = await backend.server.v2.store.db.get_agent(
        user_id=user.user_id,
        store_listing_version_id=store_listing_version_id,
    )
    file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"

    # Sending graph as a stream (similar to marketplace v1)
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False
    ) as tmp_file:
        tmp_file.write(backend.util.json.dumps(graph_data))
        tmp_file.flush()

    return fastapi.responses.FileResponse(
        tmp_file.name, filename=file_name, media_type="application/json"
    )

@@ -13,14 +13,17 @@ import backend.server.v2.library.model as library_model
import backend.server.v2.store.exceptions as store_exceptions
import backend.server.v2.store.image_gen as store_image_gen
import backend.server.v2.store.media as store_media
from backend.data import db
from backend.data import graph as graph_db
from backend.data.db import locked_transaction
from backend.data.execution import get_graph_execution
from backend.data.includes import library_agent_include
from backend.util.exceptions import NotFoundError
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
from backend.util.settings import Config

logger = logging.getLogger(__name__)
config = Config()
integration_creds_manager = IntegrationCredentialsManager()


async def list_library_agents(
@@ -142,7 +145,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
    Get a specific agent from the user's library.

    Args:
        id: ID of the library agent to retrieve.
        library_agent_id: ID of the library agent to retrieve.
        user_id: ID of the authenticated user.

    Returns:
@@ -208,7 +211,7 @@ async def add_generated_agent_image(
async def create_library_agent(
    graph: backend.data.graph.GraphModel,
    user_id: str,
) -> prisma.models.LibraryAgent:
) -> library_model.LibraryAgent:
    """
    Adds an agent to the user's library (LibraryAgent table).

@@ -229,19 +232,21 @@ async def create_library_agent(
    )

    try:
        return await prisma.models.LibraryAgent.prisma().create(
        agent = await prisma.models.LibraryAgent.prisma().create(
            data=prisma.types.LibraryAgentCreateInput(
                isCreatedByUser=(user_id == graph.user_id),
                useGraphIsActiveVersion=True,
                User={"connect": {"id": user_id}},
                # Creator={"connect": {"id": graph.user_id}},
                # Creator={"connect": {"id": agent.userId}},
                AgentGraph={
                    "connect": {
                        "graphVersionId": {"id": graph.id, "version": graph.version}
                    }
                },
            )
            ),
            include={"AgentGraph": True},
        )
        return library_model.LibraryAgent.from_db(agent)
    except prisma.errors.PrismaError as e:
        logger.error(f"Database error creating agent in library: {e}")
        raise store_exceptions.DatabaseError("Failed to create agent in library") from e
@@ -400,13 +405,11 @@ async def add_store_agent_to_library(

        # Check if user already has this agent
        existing_library_agent = (
            await prisma.models.LibraryAgent.prisma().find_unique(
            await prisma.models.LibraryAgent.prisma().find_first(
                where={
                    "userId_agentGraphId_agentGraphVersion": {
                        "userId": user_id,
                        "agentGraphId": graph.id,
                        "agentGraphVersion": graph.version,
                    }
                    "userId": user_id,
                    "agentGraphId": graph.id,
                    "agentGraphVersion": graph.version,
                },
                include=library_agent_include(user_id),
            )
@@ -414,13 +417,13 @@ async def add_store_agent_to_library(
        if existing_library_agent:
            if existing_library_agent.isDeleted:
                # Even if agent exists it needs to be marked as not deleted
                await update_library_agent(
                    existing_library_agent.id, user_id, is_deleted=False
                await set_is_deleted_for_library_agent(
                    user_id, graph.id, graph.version, False
                )
            else:
                logger.debug(
                    f"User #{user_id} already has graph #{graph.id} "
                    f"v{graph.version} in their library"
                    "in their library"
                )
                return library_model.LibraryAgent.from_db(existing_library_agent)

@@ -428,11 +431,8 @@ async def add_store_agent_to_library(
        added_agent = await prisma.models.LibraryAgent.prisma().create(
            data=prisma.types.LibraryAgentCreateInput(
                userId=user_id,
                AgentGraph={
                    "connect": {
                        "graphVersionId": {"id": graph.id, "version": graph.version}
                    }
                },
                agentGraphId=graph.id,
                agentGraphVersion=graph.version,
                isCreatedByUser=False,
            ),
            include=library_agent_include(user_id),
@@ -452,22 +452,60 @@ async def add_store_agent_to_library(
        raise store_exceptions.DatabaseError("Failed to add agent to library") from e


async def set_is_deleted_for_library_agent(
    user_id: str, agent_id: str, agent_version: int, is_deleted: bool
) -> None:
    """
    Changes the isDeleted flag for a library agent.

    Args:
        user_id: The ID of the user whose library is being updated.
        agent_id: The ID of the agent to remove.
        agent_version: The version of the agent to remove.
        is_deleted: Whether the agent is being marked as deleted.

    Raises:
        DatabaseError: If there's an issue updating the Library
    """
    logger.debug(
        f"Setting isDeleted={is_deleted} for agent {agent_id} v{agent_version} "
        f"in library for user {user_id}"
    )
    try:
        logger.warning(
            f"Setting isDeleted={is_deleted} for agent {agent_id} v{agent_version} in library for user {user_id}"
        )
        count = await prisma.models.LibraryAgent.prisma().update_many(
            where={
                "userId": user_id,
                "agentGraphId": agent_id,
                "agentGraphVersion": agent_version,
            },
            data={"isDeleted": is_deleted},
        )
        logger.warning(f"Updated {count} isDeleted library agents")
    except prisma.errors.PrismaError as e:
        logger.error(f"Database error setting agent isDeleted: {e}")
        raise store_exceptions.DatabaseError(
            "Failed to set agent isDeleted in library"
        ) from e


##############################################
########### Presets DB Functions #############
##############################################

async def list_presets(
    user_id: str, page: int, page_size: int, graph_id: Optional[str] = None
async def get_presets(
    user_id: str, page: int, page_size: int
) -> library_model.LibraryAgentPresetResponse:
    """
    Retrieves a paginated list of AgentPresets for the specified user.

    Args:
        user_id: The user ID whose presets are being retrieved.
        page: The current page index (1-based).
        page: The current page index (0-based).
        page_size: Number of items to retrieve per page.
        graph_id: Agent Graph ID to filter by.

    Returns:
        A LibraryAgentPresetResponse containing a list of presets and pagination info.
@@ -479,24 +517,21 @@ async def list_presets(
        f"Fetching presets for user #{user_id}, page={page}, page_size={page_size}"
    )

    if page < 1 or page_size < 1:
    if page < 0 or page_size < 1:
        logger.warning(
            "Invalid pagination input: page=%d, page_size=%d", page, page_size
        )
        raise store_exceptions.DatabaseError("Invalid pagination parameters")

    query_filter: prisma.types.AgentPresetWhereInput = {"userId": user_id}
    if graph_id:
        query_filter["agentGraphId"] = graph_id

    try:
        presets_records = await prisma.models.AgentPreset.prisma().find_many(
            where=query_filter,
            skip=(page - 1) * page_size,
            where={"userId": user_id},
            skip=page * page_size,
            take=page_size,
            include={"InputPresets": True},
        )
        total_items = await prisma.models.AgentPreset.prisma().count(where=query_filter)
        total_items = await prisma.models.AgentPreset.prisma().count(
            where={"userId": user_id}
        )
        total_pages = (total_items + page_size - 1) // page_size

        presets = [
@@ -549,142 +584,69 @@ async def get_preset(
        raise store_exceptions.DatabaseError("Failed to fetch preset") from e


async def create_preset(
async def upsert_preset(
    user_id: str,
    preset: library_model.LibraryAgentPresetCreatable,
    preset: library_model.CreateLibraryAgentPresetRequest,
    preset_id: Optional[str] = None,
) -> library_model.LibraryAgentPreset:
    """
    Creates a new AgentPreset for a user.
    Creates or updates an AgentPreset for a user.

    Args:
        user_id: The ID of the user creating the preset.
        preset: The preset data used for creation.
        user_id: The ID of the user creating/updating the preset.
        preset: The preset data used for creation or update.
        preset_id: An optional preset ID to update; if None, a new preset is created.

    Returns:
        The newly created LibraryAgentPreset.
        The newly created or updated LibraryAgentPreset.

    Raises:
        DatabaseError: If there's a database error in creating the preset.
    """
    logger.debug(
        f"Creating preset ({repr(preset.name)}) for user #{user_id}",
    )
    try:
        new_preset = await prisma.models.AgentPreset.prisma().create(
            data=prisma.types.AgentPresetCreateInput(
                userId=user_id,
                name=preset.name,
                description=preset.description,
                agentGraphId=preset.graph_id,
                agentGraphVersion=preset.graph_version,
                isActive=preset.is_active,
                InputPresets={
                    "create": [
                        prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput(  # noqa
                            name=name, data=prisma.fields.Json(data)
                        )
                        for name, data in preset.inputs.items()
                    ]
                },
            ),
            include={"InputPresets": True},
        )
        return library_model.LibraryAgentPreset.from_db(new_preset)
    except prisma.errors.PrismaError as e:
        logger.error(f"Database error creating preset: {e}")
        raise store_exceptions.DatabaseError("Failed to create preset") from e


async def create_preset_from_graph_execution(
    user_id: str,
    create_request: library_model.LibraryAgentPresetCreatableFromGraphExecution,
) -> library_model.LibraryAgentPreset:
    """
    Creates a new AgentPreset from an AgentGraphExecution.

    Params:
        user_id: The ID of the user creating the preset.
        create_request: The data used for creation.

    Returns:
        The newly created LibraryAgentPreset.

    Raises:
        DatabaseError: If there's a database error in creating the preset.
    """
    graph_exec_id = create_request.graph_execution_id
    graph_execution = await get_graph_execution(user_id, graph_exec_id)
    if not graph_execution:
        raise NotFoundError(f"Graph execution #{graph_exec_id} not found")

    logger.debug(
        f"Creating preset for user #{user_id} from graph execution #{graph_exec_id}",
    )
    return await create_preset(
        user_id=user_id,
        preset=library_model.LibraryAgentPresetCreatable(
            inputs=graph_execution.inputs,
            graph_id=graph_execution.graph_id,
            graph_version=graph_execution.graph_version,
            name=create_request.name,
            description=create_request.description,
            is_active=create_request.is_active,
        ),
    )


async def update_preset(
    user_id: str,
    preset_id: str,
    preset: library_model.LibraryAgentPresetUpdatable,
) -> library_model.LibraryAgentPreset:
    """
    Updates an existing AgentPreset for a user.

    Args:
        user_id: The ID of the user updating the preset.
        preset_id: The ID of the preset to update.
        preset: The preset data used for the update.

    Returns:
        The updated LibraryAgentPreset.

    Raises:
        DatabaseError: If there's a database error in updating the preset.
        DatabaseError: If there's a database error in creating or updating the preset.
        ValueError: If attempting to update a non-existent preset.
    """
    logger.debug(
        f"Updating preset #{preset_id} ({repr(preset.name)}) for user #{user_id}",
        f"Upserting preset #{preset_id} ({repr(preset.name)}) for user #{user_id}",
    )
    try:
        update_data: prisma.types.AgentPresetUpdateInput = {}
        if preset.name:
            update_data["name"] = preset.name
        if preset.description:
            update_data["description"] = preset.description
        if preset.inputs:
            update_data["InputPresets"] = {
                "create": [
                    prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput(  # noqa
                        name=name, data=prisma.fields.Json(data)
                    )
                    for name, data in preset.inputs.items()
                ]
            }
        if preset.is_active:
            update_data["isActive"] = preset.is_active

        updated = await prisma.models.AgentPreset.prisma().update(
            where={"id": preset_id},
            data=update_data,
            include={"InputPresets": True},
        )
        if not updated:
            raise ValueError(f"AgentPreset #{preset_id} not found")
        return library_model.LibraryAgentPreset.from_db(updated)
        inputs = [
            prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput(
                name=name, data=prisma.fields.Json(data)
            )
            for name, data in preset.inputs.items()
        ]
        if preset_id:
            # Update existing preset
            updated = await prisma.models.AgentPreset.prisma().update(
                where={"id": preset_id},
                data={
                    "name": preset.name,
                    "description": preset.description,
                    "isActive": preset.is_active,
                    "InputPresets": {"create": inputs},
                },
                include={"InputPresets": True},
            )
            if not updated:
                raise ValueError(f"AgentPreset #{preset_id} not found")
            return library_model.LibraryAgentPreset.from_db(updated)
        else:
            # Create new preset
            new_preset = await prisma.models.AgentPreset.prisma().create(
                data=prisma.types.AgentPresetCreateInput(
                    userId=user_id,
                    name=preset.name,
                    description=preset.description,
                    agentGraphId=preset.graph_id,
                    agentGraphVersion=preset.graph_version,
                    isActive=preset.is_active,
                    InputPresets={"create": inputs},
                ),
                include={"InputPresets": True},
            )
            return library_model.LibraryAgentPreset.from_db(new_preset)
    except prisma.errors.PrismaError as e:
        logger.error(f"Database error updating preset: {e}")
        raise store_exceptions.DatabaseError("Failed to update preset") from e
        logger.error(f"Database error upserting preset: {e}")
        raise store_exceptions.DatabaseError("Failed to create preset") from e

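
# Usage sketch (request fields per CreateLibraryAgentPresetRequest later in
# this diff; IDs hypothetical):
#
#   preset = await upsert_preset(user_id, request)                        # create
#   preset = await upsert_preset(user_id, request, preset_id="preset-1")  # update
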

async def delete_preset(user_id: str, preset_id: str) -> None:
@@ -707,3 +669,47 @@ async def delete_preset(user_id: str, preset_id: str) -> None:
    except prisma.errors.PrismaError as e:
        logger.error(f"Database error deleting preset: {e}")
        raise store_exceptions.DatabaseError("Failed to delete preset") from e


async def fork_library_agent(library_agent_id: str, user_id: str):
    """
    Clones a library agent and its underlying graph and nodes (with new IDs) for the given user.

    Args:
        library_agent_id: The ID of the library agent to fork.
        user_id: The ID of the user who owns the library agent.

    Returns:
        The forked LibraryAgent.

    Raises:
        DatabaseError: If there's an error during the forking process.
    """
    logger.debug(f"Forking library agent {library_agent_id} for user {user_id}")
    try:
        async with db.locked_transaction(f"usr_trx_{user_id}-fork_agent"):
            # Fetch the original agent
            original_agent = await get_library_agent(library_agent_id, user_id)

            # Check if user owns the library agent
            # TODO: once we have open/closed sourced agents this needs to be enabled ~kcze
            # + update library/agents/[id]/page.tsx agent actions
            # if not original_agent.can_access_graph:
            #     raise store_exceptions.DatabaseError(
            #         f"User {user_id} cannot access library agent graph {library_agent_id}"
            #     )

            # Fork the underlying graph and nodes
            new_graph = await graph_db.fork_graph(
                original_agent.graph_id, original_agent.graph_version, user_id
            )
            new_graph = await on_graph_activate(
                new_graph,
                get_credentials=lambda id: integration_creds_manager.get(user_id, id),
            )

            # Create a library agent for the new graph
            return await create_library_agent(new_graph, user_id)
    except prisma.errors.PrismaError as e:
        logger.error(f"Database error cloning library agent: {e}")
        raise store_exceptions.DatabaseError("Failed to fork library agent") from e

@@ -168,62 +168,27 @@ class LibraryAgentResponse(pydantic.BaseModel):
    pagination: server_model.Pagination


class LibraryAgentPresetCreatable(pydantic.BaseModel):
    """
    Request model used when creating a new preset for a library agent.
    """

    graph_id: str
    graph_version: int

    inputs: block_model.BlockInput

    name: str
    description: str

    is_active: bool = True


class LibraryAgentPresetCreatableFromGraphExecution(pydantic.BaseModel):
    """
    Request model used when creating a new preset for a library agent.
    """

    graph_execution_id: str

    name: str
    description: str

    is_active: bool = True


class LibraryAgentPresetUpdatable(pydantic.BaseModel):
    """
    Request model used when updating a preset for a library agent.
    """

    inputs: Optional[block_model.BlockInput] = None

    name: Optional[str] = None
    description: Optional[str] = None

    is_active: Optional[bool] = None


class LibraryAgentPreset(LibraryAgentPresetCreatable):
class LibraryAgentPreset(pydantic.BaseModel):
    """Represents a preset configuration for a library agent."""

    id: str
    updated_at: datetime.datetime

    graph_id: str
    graph_version: int

    name: str
    description: str

    is_active: bool

    inputs: block_model.BlockInput

    @classmethod
    def from_db(cls, preset: prisma.models.AgentPreset) -> "LibraryAgentPreset":
        if preset.InputPresets is None:
            raise ValueError("Input values must be included in object")

        input_data: block_model.BlockInput = {}

        for preset_input in preset.InputPresets:
        for preset_input in preset.InputPresets or []:
            input_data[preset_input.name] = preset_input.data

        return cls(
@@ -245,6 +210,19 @@ class LibraryAgentPresetResponse(pydantic.BaseModel):
    pagination: server_model.Pagination


class CreateLibraryAgentPresetRequest(pydantic.BaseModel):
    """
    Request model used when creating a new preset for a library agent.
    """

    name: str
    description: str
    inputs: block_model.BlockInput
    graph_id: str
    graph_version: int
    is_active: bool


class LibraryAgentFilter(str, Enum):
    """Possible filters for searching library agents."""
@@ -22,11 +22,13 @@ async def test_agent_preset_from_db():
        userId="test-user-123",
        isDeleted=False,
        InputPresets=[
            prisma.models.AgentNodeExecutionInputOutput(
                id="input-123",
                time=datetime.datetime.now(),
                name="input1",
                data=prisma.Json({"type": "string", "value": "test value"}),
            prisma.models.AgentNodeExecutionInputOutput.model_validate(
                {
                    "id": "input-123",
                    "time": datetime.datetime.now(),
                    "name": "input1",
                    "data": '{"type": "string", "value": "test value"}',
                }
            )
        ],
    )
@@ -190,3 +190,14 @@ async def update_library_agent(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to update library agent",
        ) from e


@router.post("/{library_agent_id}/fork")
async def fork_library_agent(
    library_agent_id: str,
    user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> library_model.LibraryAgent:
    return await library_db.fork_library_agent(
        library_agent_id=library_agent_id,
        user_id=user_id,
    )
@@ -1,39 +1,27 @@
import logging
from typing import Annotated, Any, Optional
from typing import Annotated, Any

import autogpt_libs.auth as autogpt_auth_lib
import autogpt_libs.utils.cache
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
from fastapi import APIRouter, Body, Depends, HTTPException, status

import backend.executor
import backend.server.v2.library.db as db
import backend.server.v2.library.model as models
import backend.util.service
from backend.util.exceptions import NotFoundError
from backend.executor.utils import add_graph_execution_async

logger = logging.getLogger(__name__)

router = APIRouter()


@autogpt_libs.utils.cache.thread_cached
def execution_manager_client() -> backend.executor.ExecutionManager:
    """Return a cached instance of ExecutionManager client."""
    return backend.util.service.get_service_client(backend.executor.ExecutionManager)


@router.get(
    "/presets",
    summary="List presets",
    description="Retrieve a paginated list of presets for the current user.",
)
async def list_presets(
async def get_presets(
    user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
    page: int = Query(default=1, ge=1),
    page_size: int = Query(default=10, ge=1),
    graph_id: Optional[str] = Query(
        description="Allows filtering presets by a specific agent graph"
    ),
    page: int = 1,
    page_size: int = 10,
) -> models.LibraryAgentPresetResponse:
    """
    Retrieve a paginated list of presets for the current user.
@@ -42,18 +30,12 @@ async def list_presets(
        user_id (str): ID of the authenticated user.
        page (int): Page number for pagination.
        page_size (int): Number of items per page.
        graph_id: Allows filtering presets by a specific agent graph.

    Returns:
        models.LibraryAgentPresetResponse: A response containing the list of presets.
    """
    try:
        return await db.list_presets(
            user_id=user_id,
            graph_id=graph_id,
            page=page,
            page_size=page_size,
        )
        return await db.get_presets(user_id, page, page_size)
    except Exception as e:
        logger.exception(f"Exception occurred while getting presets: {e}")
        raise HTTPException(
@@ -106,17 +88,14 @@ async def get_preset(
    description="Create a new preset for the current user.",
)
async def create_preset(
    preset: (
        models.LibraryAgentPresetCreatable
        | models.LibraryAgentPresetCreatableFromGraphExecution
    ),
    preset: models.CreateLibraryAgentPresetRequest,
    user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> models.LibraryAgentPreset:
    """
    Create a new library agent preset. Automatically corrects node_input format if needed.

    Args:
        preset (models.LibraryAgentPresetCreatable): The preset data to create.
        preset (models.CreateLibraryAgentPresetRequest): The preset data to create.
        user_id (str): ID of the authenticated user.

    Returns:
@@ -126,12 +105,7 @@ async def create_preset(
        HTTPException: If an error occurs while creating the preset.
    """
    try:
        if isinstance(preset, models.LibraryAgentPresetCreatable):
            return await db.create_preset(user_id, preset)
        else:
            return await db.create_preset_from_graph_execution(user_id, preset)
    except NotFoundError as e:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
        return await db.upsert_preset(user_id, preset)
    except Exception as e:
        logger.exception(f"Exception occurred while creating preset: {e}")
        raise HTTPException(
@@ -140,22 +114,22 @@ async def create_preset(
    )


@router.patch(
@router.put(
    "/presets/{preset_id}",
    summary="Update an existing preset",
    description="Update an existing preset by its ID.",
)
async def update_preset(
    preset_id: str,
    preset: models.LibraryAgentPresetUpdatable,
    preset: models.CreateLibraryAgentPresetRequest,
    user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> models.LibraryAgentPreset:
    """
    Update an existing library agent preset.
    Update an existing library agent preset. If the preset doesn't exist, it may be created.

    Args:
        preset_id (str): ID of the preset to update.
        preset (models.LibraryAgentPresetUpdatable): The preset data to update.
        preset (models.CreateLibraryAgentPresetRequest): The preset data to update.
        user_id (str): ID of the authenticated user.

    Returns:
@@ -165,9 +139,7 @@ async def update_preset(
        HTTPException: If an error occurs while updating the preset.
    """
    try:
        return await db.update_preset(
            user_id=user_id, preset_id=preset_id, preset=preset
        )
        return await db.upsert_preset(user_id, preset, preset_id)
    except Exception as e:
        logger.exception(f"Exception occurred whilst updating preset: {e}")
        raise HTTPException(
@@ -246,17 +218,17 @@ async def execute_preset(
        # Merge input overrides with preset inputs
        merged_node_input = preset.inputs | node_input

        execution = execution_manager_client().add_execution(
        execution = await add_graph_execution_async(
            graph_id=graph_id,
            graph_version=graph_version,
            data=merged_node_input,
            user_id=user_id,
            inputs=merged_node_input,
            preset_id=preset_id,
            graph_version=graph_version,
        )

        logger.debug(f"Execution added: {execution} with input: {merged_node_input}")

        return {"id": execution.graph_exec_id}
        return {"id": execution.id}
    except HTTPException:
        raise
    except Exception as e:

@@ -793,6 +793,7 @@ async def create_store_version(
            changes_summary=changes_summary,
            version=next_version,
        )

    except prisma.errors.PrismaError as e:
        raise backend.server.v2.store.exceptions.DatabaseError(
            "Failed to create new store version"
@@ -1361,3 +1362,31 @@ async def get_admin_listings_with_versions(
            page_size=page_size,
        ),
    )


async def get_agent_as_admin(
    user_id: str | None,
    store_listing_version_id: str,
) -> GraphModel:
    """Get agent using the version ID and store listing version ID."""
    store_listing_version = (
        await prisma.models.StoreListingVersion.prisma().find_unique(
            where={"id": store_listing_version_id}
        )
    )

    if not store_listing_version:
        raise ValueError(f"Store listing version {store_listing_version_id} not found")

    graph = await backend.data.graph.get_graph_as_admin(
        user_id=user_id,
        graph_id=store_listing_version.agentGraphId,
        version=store_listing_version.agentGraphVersion,
        for_export=True,
    )
    if not graph:
        raise ValueError(
            f"Agent {store_listing_version.agentGraphId} v{store_listing_version.agentGraphVersion} not found"
        )

    return graph
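A short usage sketch for the new admin helper; the caller context and IDs are assumptions for illustration:

# Hypothetical admin-side usage of get_agent_as_admin; IDs are placeholders.
async def export_listing_graph(store_listing_version_id: str):
    graph = await get_agent_as_admin(
        user_id=None,  # user_id is Optional in the signature above
        store_listing_version_id=store_listing_version_id,
    )
    return graph  # a GraphModel fetched with for_export=True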
@@ -4,20 +4,7 @@ from typing import List
import prisma.enums
import pydantic


class Pagination(pydantic.BaseModel):
    total_items: int = pydantic.Field(
        description="Total number of items.", examples=[42]
    )
    total_pages: int = pydantic.Field(
        description="Total number of pages.", examples=[97]
    )
    current_page: int = pydantic.Field(
        description="Current page number.", examples=[1]
    )
    page_size: int = pydantic.Field(
        description="Number of items per page.", examples=[25]
    )
from backend.server.model import Pagination


class MyAgent(pydantic.BaseModel):
@@ -6,7 +6,6 @@ from typing import Protocol
import uvicorn
from autogpt_libs.auth import parse_jwt_token
from autogpt_libs.logging.utils import generate_uvicorn_config
from autogpt_libs.utils.cache import thread_cached
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
from starlette.middleware.cors import CORSMiddleware

@@ -19,7 +18,7 @@ from backend.server.model import (
    WSSubscribeGraphExecutionRequest,
    WSSubscribeGraphExecutionsRequest,
)
from backend.util.service import AppProcess, get_service_client
from backend.util.service import AppProcess
from backend.util.settings import AppEnvironment, Config, Settings

logger = logging.getLogger(__name__)
@@ -46,13 +45,6 @@ def get_connection_manager():
    return _connection_manager


@thread_cached
def get_db_client():
    from backend.executor import DatabaseManager

    return get_service_client(DatabaseManager)


async def event_broadcaster(manager: ConnectionManager):
    try:
        event_queue = AsyncRedisExecutionEventBus()
@@ -15,21 +15,25 @@ def to_dict(data) -> dict:


def dumps(data) -> str:
    return json.dumps(jsonable_encoder(data))
    return json.dumps(to_dict(data))


T = TypeVar("T")


@overload
def loads(data: str, *args, target_type: Type[T], **kwargs) -> T: ...
def loads(data: str | bytes, *args, target_type: Type[T], **kwargs) -> T: ...


@overload
def loads(data: str, *args, **kwargs) -> Any: ...
def loads(data: str | bytes, *args, **kwargs) -> Any: ...


def loads(data: str, *args, target_type: Type[T] | None = None, **kwargs) -> Any:
def loads(
    data: str | bytes, *args, target_type: Type[T] | None = None, **kwargs
) -> Any:
    if isinstance(data, bytes):
        data = data.decode("utf-8")
    parsed = json.loads(data, *args, **kwargs)
    if target_type:
        return type_match(parsed, target_type)
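To illustrate the widened loads() signature, a small sketch; the Point model is a stand-in assumption, and the coercion step assumes type_match performs pydantic-style validation:

# Sketch: bytes input is decoded as UTF-8 before json.loads, and
# target_type coerces the parsed value (assumed pydantic-style coercion).
import pydantic

class Point(pydantic.BaseModel):  # stand-in model, not from the codebase
    x: int
    y: int

raw: bytes = b'{"x": 1, "y": 2}'
point = loads(raw, target_type=Point)
assert point == Point(x=1, y=2)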
@@ -14,9 +14,7 @@ def sentry_init():
        traces_sample_rate=1.0,
        profiles_sample_rate=1.0,
        environment=f"app:{Settings().config.app_env.value}-behave:{Settings().config.behave_as.value}",
        _experiments={
            "enable_logs": True,
        },
        _experiments={"enable_logs": True},
        integrations=[
            LoggingIntegration(sentry_logs_level=logging.INFO),
            AnthropicIntegration(
@@ -3,7 +3,7 @@ import os
import signal
import sys
from abc import ABC, abstractmethod
from multiprocessing import Process, set_start_method
from multiprocessing import Process, get_all_start_methods, set_start_method
from typing import Optional

from backend.util.logging import configure_logging
@@ -30,7 +30,12 @@ class AppProcess(ABC):
    process: Optional[Process] = None
    cleaned_up = False

    set_start_method("spawn", force=True)
    if "forkserver" in get_all_start_methods():
        set_start_method("forkserver", force=True)
    else:
        logger.warning("Forkserver start method is not available. Using spawn instead.")
        set_start_method("spawn", force=True)

    configure_logging()
    sentry_init()
@@ -73,3 +73,10 @@ def conn_retry(
        return async_wrapper if is_coroutine else sync_wrapper

    return decorator


func_retry = retry(
    reraise=False,
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=1, max=30),
)
@@ -5,8 +5,10 @@ import os
import threading
import time
from abc import ABC, abstractmethod
from functools import cached_property, update_wrapper
from typing import (
    Any,
    Awaitable,
    Callable,
    Concatenate,
    Coroutine,
@@ -42,24 +44,15 @@ api_call_timeout = config.rpc_client_call_timeout

P = ParamSpec("P")
R = TypeVar("R")
EXPOSED_FLAG = "__exposed__"


def expose(func: C) -> C:
    func = getattr(func, "__func__", func)
    setattr(func, "__exposed__", True)
    setattr(func, EXPOSED_FLAG, True)
    return func


def exposed_run_and_wait(
    f: Callable[P, Coroutine[None, None, R]]
) -> Callable[Concatenate[object, P], R]:
    # TODO:
    # This function lies about its return type to make the DynamicClient
    # call the function synchronously; fix this when DynamicClient can choose
    # to call a function synchronously or asynchronously.
    return expose(f)  # type: ignore


# --------------------------------------------------
# AppService for IPC service based on HTTP requests through FastAPI
# --------------------------------------------------
@@ -203,7 +196,7 @@ class AppService(BaseAppService, ABC):

        # Register the exposed API routes.
        for attr_name, attr in vars(type(self)).items():
            if getattr(attr, "__exposed__", False):
            if getattr(attr, EXPOSED_FLAG, False):
                route_path = f"/{attr_name}"
                self.fastapi_app.add_api_route(
                    route_path,
@@ -234,31 +227,52 @@ class AppService(BaseAppService, ABC):
AS = TypeVar("AS", bound=AppService)


def close_service_client(client: Any) -> None:
    if hasattr(client, "close"):
        client.close()
    else:
        logger.warning(f"Client {client} is not closable")
class AppServiceClient(ABC):
    @classmethod
    @abstractmethod
    def get_service_type(cls) -> Type[AppService]:
        pass

    def health_check(self):
        pass

    def close(self):
        pass


@conn_retry("FastAPI client", "Creating service client", max_retry=api_comm_retry)
ASC = TypeVar("ASC", bound=AppServiceClient)


@conn_retry("AppService client", "Creating service client", max_retry=api_comm_retry)
def get_service_client(
    service_type: Type[AS],
    service_client_type: Type[ASC],
    call_timeout: int | None = api_call_timeout,
) -> AS:
) -> ASC:
    class DynamicClient:
        def __init__(self):
            service_type = service_client_type.get_service_type()
            host = service_type.get_host()
            port = service_type.get_port()
            self.base_url = f"http://{host}:{port}".rstrip("/")
            self.client = httpx.Client(

        @cached_property
        def sync_client(self) -> httpx.Client:
            return httpx.Client(
                base_url=self.base_url,
                timeout=call_timeout,
            )

        def _call_method(self, method_name: str, **kwargs) -> Any:
        @cached_property
        def async_client(self) -> httpx.AsyncClient:
            return httpx.AsyncClient(
                base_url=self.base_url,
                timeout=call_timeout,
            )

        def _handle_call_method_response(
            self, response: httpx.Response, method_name: str
        ) -> Any:
            try:
                response = self.client.post(method_name, json=to_dict(kwargs))
                response.raise_for_status()
                return response.json()
            except httpx.HTTPStatusError as e:
@@ -269,36 +283,102 @@ def get_service_client(
                    *(error.args or [str(e)])
                )

        def _call_method_sync(self, method_name: str, **kwargs) -> Any:
            return self._handle_call_method_response(
                method_name=method_name,
                response=self.sync_client.post(method_name, json=to_dict(kwargs)),
            )

        async def _call_method_async(self, method_name: str, **kwargs) -> Any:
            return self._handle_call_method_response(
                method_name=method_name,
                response=await self.async_client.post(
                    method_name, json=to_dict(kwargs)
                ),
            )

        async def aclose(self):
            self.sync_client.close()
            await self.async_client.aclose()

        def close(self):
            self.client.close()
            self.sync_client.close()

        def _get_params(self, signature: inspect.Signature, *args, **kwargs) -> dict:
            if args:
                arg_names = list(signature.parameters.keys())
                if arg_names[0] in ("self", "cls"):
                    arg_names = arg_names[1:]
                kwargs.update(dict(zip(arg_names, args)))
            return kwargs

        def _get_return(self, expected_return: TypeAdapter | None, result: Any) -> Any:
            if expected_return:
                return expected_return.validate_python(result)
            return result

        def __getattr__(self, name: str) -> Callable[..., Any]:
            # Try to get the original function from the service type.
            orig_func = getattr(service_type, name, None)
            if orig_func is None:
                raise AttributeError(f"Method {name} not found in {service_type}")
            original_func = getattr(service_client_type, name, None)
            if original_func is None:
                raise AttributeError(
                    f"Method {name} not found in {service_client_type}"
                )
            else:
                name = original_func.__name__

            sig = inspect.signature(orig_func)
            sig = inspect.signature(original_func)
            ret_ann = sig.return_annotation
            if ret_ann != inspect.Signature.empty:
                expected_return = TypeAdapter(ret_ann)
            else:
                expected_return = None

            def method(*args, **kwargs) -> Any:
                if args:
                    arg_names = list(sig.parameters.keys())
                    if arg_names[0] in ("self", "cls"):
                        arg_names = arg_names[1:]
                    kwargs.update(dict(zip(arg_names, args)))
                result = self._call_method(name, **kwargs)
                if expected_return:
                    return expected_return.validate_python(result)
                return result
            if inspect.iscoroutinefunction(original_func):

            return method
                async def async_method(*args, **kwargs) -> Any:
                    params = self._get_params(sig, *args, **kwargs)
                    result = await self._call_method_async(name, **params)
                    return self._get_return(expected_return, result)

    client = cast(AS, DynamicClient())
                return async_method
            else:

                def sync_method(*args, **kwargs) -> Any:
                    params = self._get_params(sig, *args, **kwargs)
                    result = self._call_method_sync(name, **params)
                    return self._get_return(expected_return, result)

                return sync_method

    client = cast(ASC, DynamicClient())
    client.health_check()

    return cast(AS, client)
    return client


def endpoint_to_sync(
    func: Callable[Concatenate[Any, P], Awaitable[R]],
) -> Callable[Concatenate[Any, P], R]:
    """
    Produce a *typed* stub that **looks** synchronous to the type-checker.
    """

    def _stub(*args: P.args, **kwargs: P.kwargs) -> R:  # pragma: no cover
        raise RuntimeError("should be intercepted by __getattr__")

    update_wrapper(_stub, func)
    return cast(Callable[Concatenate[Any, P], R], _stub)


def endpoint_to_async(
    func: Callable[Concatenate[Any, P], R],
) -> Callable[Concatenate[Any, P], Awaitable[R]]:
    """
    The async mirror of `endpoint_to_sync`.
    """

    async def _stub(*args: P.args, **kwargs: P.kwargs) -> R:  # pragma: no cover
        raise RuntimeError("should be intercepted by __getattr__")

    update_wrapper(_stub, func)
    return cast(Callable[Concatenate[Any, P], Awaitable[R]], _stub)
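Tying the service/client split together, a minimal sketch of the new pattern (the same shape the updated service test further below uses); MyService and its do_work endpoint are hypothetical:

# Hypothetical AppService/AppServiceClient pair using the new client pattern.
class MyServiceClient(AppServiceClient):
    @classmethod
    def get_service_type(cls):
        return MyService  # hypothetical AppService subclass with @expose endpoints

    # Same-name attributes are proxied synchronously by DynamicClient;
    # endpoint_to_async re-exposes one as an awaitable stub.
    do_work = MyService.do_work
    do_work_async = endpoint_to_async(MyService.do_work)


async def demo() -> None:
    client = get_service_client(MyServiceClient)
    client.do_work(42)              # sync call over HTTP via httpx.Client
    await client.do_work_async(42)  # async call via httpx.AsyncClient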
@@ -117,6 +117,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
        default=1,
        description="Cost per execution in cents after each threshold.",
    )
    execution_counter_expiration_time: int = Field(
        default=60 * 60 * 24,
        description="Time in seconds after which the execution counter is reset.",
    )

    model_config = SettingsConfigDict(
        env_file=".env",
@@ -0,0 +1,16 @@
-- Modify the OnboardingStep enum
ALTER TYPE "OnboardingStep" ADD VALUE 'GET_RESULTS';
ALTER TYPE "OnboardingStep" ADD VALUE 'MARKETPLACE_VISIT';
ALTER TYPE "OnboardingStep" ADD VALUE 'MARKETPLACE_ADD_AGENT';
ALTER TYPE "OnboardingStep" ADD VALUE 'MARKETPLACE_RUN_AGENT';
ALTER TYPE "OnboardingStep" ADD VALUE 'BUILDER_OPEN';
ALTER TYPE "OnboardingStep" ADD VALUE 'BUILDER_SAVE_AGENT';
ALTER TYPE "OnboardingStep" ADD VALUE 'BUILDER_RUN_AGENT';

-- Modify the UserOnboarding table
ALTER TABLE "UserOnboarding"
ADD COLUMN "updatedAt" TIMESTAMP(3),
ADD COLUMN "notificationDot" BOOLEAN NOT NULL DEFAULT true,
ADD COLUMN "notified" "OnboardingStep"[] DEFAULT '{}',
ADD COLUMN "rewardedFor" "OnboardingStep"[] DEFAULT '{}',
ADD COLUMN "onboardingAgentExecutionId" TEXT
@@ -0,0 +1,56 @@
-- Backfill nulls with empty arrays
UPDATE "UserOnboarding"
SET "integrations" = ARRAY[]::TEXT[]
WHERE "integrations" IS NULL;

UPDATE "UserOnboarding"
SET "completedSteps" = '{}'
WHERE "completedSteps" IS NULL;

UPDATE "UserOnboarding"
SET "notified" = '{}'
WHERE "notified" IS NULL;

UPDATE "UserOnboarding"
SET "rewardedFor" = '{}'
WHERE "rewardedFor" IS NULL;

UPDATE "IntegrationWebhook"
SET "events" = ARRAY[]::TEXT[]
WHERE "events" IS NULL;

UPDATE "APIKey"
SET "permissions" = '{}'
WHERE "permissions" IS NULL;

UPDATE "Profile"
SET "links" = ARRAY[]::TEXT[]
WHERE "links" IS NULL;

UPDATE "StoreListingVersion"
SET "imageUrls" = ARRAY[]::TEXT[]
WHERE "imageUrls" IS NULL;

UPDATE "StoreListingVersion"
SET "categories" = ARRAY[]::TEXT[]
WHERE "categories" IS NULL;

-- Enforce NOT NULL constraints
ALTER TABLE "UserOnboarding"
ALTER COLUMN "integrations" SET NOT NULL,
ALTER COLUMN "completedSteps" SET NOT NULL,
ALTER COLUMN "notified" SET NOT NULL,
ALTER COLUMN "rewardedFor" SET NOT NULL;

ALTER TABLE "IntegrationWebhook"
ALTER COLUMN "events" SET NOT NULL;

ALTER TABLE "APIKey"
ALTER COLUMN "permissions" SET NOT NULL;

ALTER TABLE "Profile"
ALTER COLUMN "links" SET NOT NULL;

ALTER TABLE "StoreListingVersion"
ALTER COLUMN "imageUrls" SET NOT NULL,
ALTER COLUMN "categories" SET NOT NULL;
@@ -0,0 +1,7 @@
-- AlterTable
ALTER TABLE "AgentGraph"
ADD COLUMN "forkedFromId" TEXT,
ADD COLUMN "forkedFromVersion" INTEGER;

-- AddForeignKey
ALTER TABLE "AgentGraph" ADD CONSTRAINT "AgentGraph_forkedFromId_forkedFromVersion_fkey" FOREIGN KEY ("forkedFromId", "forkedFromVersion") REFERENCES "AgentGraph"("id", "version") ON DELETE SET NULL ON UPDATE CASCADE;
17  autogpt_platform/backend/poetry.lock  (generated)
@@ -3568,6 +3568,21 @@ files = [
[package.dependencies]
tqdm = "*"

[[package]]
name = "prometheus-client"
version = "0.21.1"
description = "Python client for the Prometheus monitoring system."
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
    {file = "prometheus_client-0.21.1-py3-none-any.whl", hash = "sha256:594b45c410d6f4f8888940fe80b5cc2521b305a1fafe1c58609ef715a001f301"},
    {file = "prometheus_client-0.21.1.tar.gz", hash = "sha256:252505a722ac04b0456be05c05f75f45d760c2911ffc45f2a06bcaed9f3ae3fb"},
]

[package.extras]
twisted = ["twisted"]

[[package]]
name = "propcache"
version = "0.2.1"
@@ -6310,4 +6325,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.13"
content-hash = "781f77ec77cfce78b34fb57063dcc81df8e9c5a4be9a644033a0c197e0063730"
content-hash = "29ccee704d8296c57156daab98bb0cbbf5a43e83526b7f08a14c91fb7a4898f4"
@@ -64,6 +64,7 @@ websockets = "^14.2"
youtube-transcript-api = "^0.6.2"
zerobouncesdk = "^1.1.1"
# NOTE: please insert new dependencies in their alphabetical location
prometheus-client = "^0.21.1"

[tool.poetry.group.dev.dependencies]
aiohappyeyeballs = "^2.6.1"
@@ -1,6 +1,7 @@
datasource db {
  provider = "postgresql"
  url      = env("DATABASE_URL")
  provider  = "postgresql"
  url       = env("DATABASE_URL")
  directUrl = env("DIRECT_URL")
}

generator client {
@@ -58,6 +59,7 @@ model User {
}

enum OnboardingStep {
  // Introductory onboarding (Library)
  WELCOME
  USAGE_REASON
  INTEGRATIONS
@@ -65,18 +67,32 @@ enum OnboardingStep {
  AGENT_NEW_RUN
  AGENT_INPUT
  CONGRATS
  GET_RESULTS
  // Marketplace
  MARKETPLACE_VISIT
  MARKETPLACE_ADD_AGENT
  MARKETPLACE_RUN_AGENT
  // Builder
  BUILDER_OPEN
  BUILDER_SAVE_AGENT
  BUILDER_RUN_AGENT
}

model UserOnboarding {
  id        String    @id @default(uuid())
  createdAt DateTime  @default(now())
  updatedAt DateTime? @updatedAt

  completedSteps                OnboardingStep[] @default([])
  notificationDot               Boolean          @default(true)
  notified                      OnboardingStep[] @default([])
  rewardedFor                   OnboardingStep[] @default([])
  usageReason                   String?
  integrations                  String[]         @default([])
  otherIntegrations             String?
  selectedStoreListingVersionId String?
  agentInput                    Json?
  onboardingAgentExecutionId    String?

  userId String @unique
  User   User   @relation(fields: [userId], references: [id], onDelete: Cascade)
@@ -102,6 +118,11 @@ model AgentGraph {
  // This allows us to delete user data without deleting the agent, which may be in use by other users
  User User @relation(fields: [userId], references: [id], onDelete: Cascade)

  forkedFromId      String?
  forkedFromVersion Int?
  forkedFrom        AgentGraph?  @relation("AgentGraphForks", fields: [forkedFromId, forkedFromVersion], references: [id, version])
  forks             AgentGraph[] @relation("AgentGraphForks")

  Nodes      AgentNode[]
  Executions AgentGraphExecution[]
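As a usage sketch against the new self-relation, assuming an initialized Prisma Python client; the query shape is illustrative, while the field names come from the schema above:

# List direct forks of a given graph version via the AgentGraphForks relation.
import prisma

async def list_forks(graph_id: str, graph_version: int):
    return await prisma.models.AgentGraph.prisma().find_many(
        where={
            "forkedFromId": graph_id,
            "forkedFromVersion": graph_version,
        }
    )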
@@ -6,10 +6,10 @@ from prisma.models import CreditTransaction

from backend.blocks.llm import AITextGeneratorBlock
from backend.data.block import get_block
from backend.data.credit import BetaUserCredit
from backend.data.credit import BetaUserCredit, UsageTransactionMetadata
from backend.data.execution import NodeExecutionEntry
from backend.data.user import DEFAULT_USER_ID
from backend.executor.utils import UsageTransactionMetadata, block_usage_cost
from backend.executor.utils import block_usage_cost
from backend.integrations.credentials_store import openai_credentials
from backend.util.test import SpinTestServer

@@ -46,6 +46,7 @@ async def spend_credits(entry: NodeExecutionEntry) -> int:
            block_id=entry.block_id,
            block=entry.block_id,
            input=matching_filter,
            reason=f"Ran block {entry.block_id} {block.name}",
        ),
    )
@@ -1,4 +1,4 @@
from backend.data.execution import merge_execution_input, parse_execution_output
from backend.executor.utils import merge_execution_input, parse_execution_output


def test_parse_execution_output():
@@ -357,7 +357,7 @@ async def test_execute_preset(server: SpinTestServer):
    test_graph = await create_graph(server, test_graph, test_user)

    # Create preset with initial values
    preset = backend.server.v2.library.model.LibraryAgentPresetCreatable(
    preset = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
        name="Test Preset With Clash",
        description="Test preset with clashing input values",
        graph_id=test_graph.id,
@@ -446,7 +446,7 @@ async def test_execute_preset_with_clash(server: SpinTestServer):
    test_graph = await create_graph(server, test_graph, test_user)

    # Create preset with initial values
    preset = backend.server.v2.library.model.LibraryAgentPresetCreatable(
    preset = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
        name="Test Preset With Clash",
        description="Test preset with clashing input values",
        graph_id=test_graph.id,
@@ -1,7 +1,7 @@
import pytest

from backend.data import db
from backend.executor import Scheduler
from backend.executor.scheduler import SchedulerClient
from backend.server.model import CreateGraph
from backend.usecases.sample import create_test_graph, create_test_user
from backend.util.service import get_service_client
@@ -17,11 +17,11 @@ async def test_agent_schedule(server: SpinTestServer):
        user_id=test_user.id,
    )

    scheduler = get_service_client(Scheduler)
    schedules = scheduler.get_execution_schedules(test_graph.id, test_user.id)
    scheduler = get_service_client(SchedulerClient)
    schedules = await scheduler.get_execution_schedules(test_graph.id, test_user.id)
    assert len(schedules) == 0

    schedule = scheduler.add_execution_schedule(
    schedule = await scheduler.add_execution_schedule(
        graph_id=test_graph.id,
        user_id=test_user.id,
        graph_version=1,
@@ -30,10 +30,12 @@ async def test_agent_schedule(server: SpinTestServer):
    )
    assert schedule

    schedules = scheduler.get_execution_schedules(test_graph.id, test_user.id)
    schedules = await scheduler.get_execution_schedules(test_graph.id, test_user.id)
    assert len(schedules) == 1
    assert schedules[0].cron == "0 0 * * *"

    scheduler.delete_schedule(schedule.id, user_id=test_user.id)
    schedules = scheduler.get_execution_schedules(test_graph.id, user_id=test_user.id)
    await scheduler.delete_schedule(schedule.id, user_id=test_user.id)
    schedules = await scheduler.get_execution_schedules(
        test_graph.id, user_id=test_user.id
    )
    assert len(schedules) == 0
@@ -1,6 +1,12 @@
import pytest

from backend.util.service import AppService, expose, get_service_client
from backend.util.service import (
    AppService,
    AppServiceClient,
    endpoint_to_async,
    expose,
    get_service_client,
)

TEST_SERVICE_PORT = 8765

@@ -32,10 +38,25 @@ class ServiceTest(AppService):
        return self.run_and_wait(add_async(a, b))


class ServiceTestClient(AppServiceClient):
    @classmethod
    def get_service_type(cls):
        return ServiceTest

    add = ServiceTest.add
    subtract = ServiceTest.subtract
    fun_with_async = ServiceTest.fun_with_async

    add_async = endpoint_to_async(ServiceTest.add)
    subtract_async = endpoint_to_async(ServiceTest.subtract)


@pytest.mark.asyncio(loop_scope="session")
async def test_service_creation(server):
    with ServiceTest():
        client = get_service_client(ServiceTest)
        client = get_service_client(ServiceTestClient)
        assert client.add(5, 3) == 8
        assert client.subtract(10, 4) == 6
        assert client.fun_with_async(5, 3) == 8
        assert await client.add_async(5, 3) == 8
        assert await client.subtract_async(10, 4) == 6
@@ -15,6 +15,7 @@ services:
        condition: service_healthy
    environment:
      - DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
      - DIRECT_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
    networks:
      - app-network
    restart: on-failure
@@ -77,6 +78,7 @@ services:
      - SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
      - SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
      - DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
      - DIRECT_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
      - REDIS_HOST=redis
      - REDIS_PORT=6379
      - RABBITMQ_HOST=rabbitmq
@@ -126,6 +128,7 @@ services:
      - SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
      - SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
      - DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
      - DIRECT_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
      - REDIS_HOST=redis
      - REDIS_PORT=6379
      - REDIS_PASSWORD=password
@@ -139,7 +142,7 @@ services:
      - NOTIFICATIONMANAGER_HOST=rest_server
      - ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw= # DO NOT USE IN PRODUCTION!!
    ports:
      - "8002:8000"
      - "8002:8002"
    networks:
      - app-network

@@ -167,6 +170,7 @@ services:
      - DATABASEMANAGER_HOST=rest_server
      - SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
      - DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
      - DIRECT_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
      - REDIS_HOST=redis
      - REDIS_PORT=6379
      - REDIS_PASSWORD=password
@@ -201,6 +205,7 @@ services:
      # - NEXT_PUBLIC_SUPABASE_URL=http://kong:8000
      # - NEXT_PUBLIC_SUPABASE_ANON_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJhbm9uIiwKICAgICJpc3MiOiAic3VwYWJhc2UtZGVtbyIsCiAgICAiaWF0IjogMTY0MTc2OTIwMCwKICAgICJleHAiOiAxNzk5NTM1NjAwCn0.dc_X5iR_VP_qT0zsiyj_I_OZ2T9FtRU2BBNWN8Bu4GE
      # - DATABASE_URL=postgresql://agpt_user:pass123@postgres:5432/postgres?connect_timeout=60&schema=platform
      # - DIRECT_URL=postgresql://agpt_user:pass123@postgres:5432/postgres?connect_timeout=60&schema=platform
      # - NEXT_PUBLIC_AGPT_SERVER_URL=http://localhost:8006/api
      # - NEXT_PUBLIC_AGPT_WS_SERVER_URL=ws://localhost:8001/ws
      # - NEXT_PUBLIC_AGPT_MARKETPLACE_URL=http://localhost:8015/api/v1/market
@@ -24,3 +24,4 @@ GA_MEASUREMENT_ID=G-FH2XK2W4GN

# When running locally, set NEXT_PUBLIC_BEHAVE_AS=CLOUD to use a locally hosted marketplace (as is typical in development and in the cloud deployment); otherwise set it to LOCAL to have the marketplace open in a new tab
NEXT_PUBLIC_BEHAVE_AS=LOCAL
NEXT_PUBLIC_SHOW_BILLING_PAGE=false
@@ -42,6 +42,7 @@
    "@radix-ui/react-separator": "^1.1.0",
    "@radix-ui/react-slot": "^1.1.0",
    "@radix-ui/react-switch": "^1.1.1",
    "@radix-ui/react-tabs": "^1.1.4",
    "@radix-ui/react-toast": "^1.2.5",
    "@radix-ui/react-tooltip": "^1.1.7",
    "@sentry/nextjs": "^9",
@@ -52,7 +53,6 @@
    "@xyflow/react": "12.4.2",
    "ajv": "^8.17.1",
    "boring-avatars": "^1.11.2",
    "canvas-confetti": "^1.9.3",
    "class-variance-authority": "^0.7.1",
    "clsx": "^2.1.1",
    "cmdk": "1.0.4",
@@ -70,6 +70,7 @@
    "moment": "^2.30.1",
    "next": "^14.2.26",
    "next-themes": "^0.4.5",
    "party-js": "^2.2.0",
    "react": "^18",
    "react-day-picker": "^9.6.1",
    "react-dom": "^18",
BIN  autogpt_platform/frontend/public/onboarding/builder-open.mp4  (new file; binary file not shown)
BIN  autogpt_platform/frontend/public/onboarding/builder-run.mp4  (new file; binary file not shown)
BIN  autogpt_platform/frontend/public/onboarding/builder-save.mp4  (new file; binary file not shown)
BIN  autogpt_platform/frontend/public/onboarding/get-results.mp4  (new file; binary file not shown)
BIN  autogpt_platform/frontend/public/onboarding/marketplace-add.mp4  (new file; binary file not shown)
BIN  autogpt_platform/frontend/public/onboarding/marketplace-run.mp4  (new file; binary file not shown)
Binary file not shown.
@@ -163,14 +163,7 @@ export default function Page() {
      </div>

      <OnboardingFooter>
        <OnboardingButton
          className="mb-2"
          href="/onboarding/4-agent"
          disabled={
            state?.integrations.length === 0 &&
            isEmptyOrWhitespace(state.otherIntegrations)
          }
        >
        <OnboardingButton className="mb-2" href="/onboarding/4-agent">
          Next
        </OnboardingButton>
      </OnboardingFooter>
@@ -59,7 +59,7 @@ export default function Page() {

      <div className="my-12 flex items-center justify-between gap-5">
        <OnboardingAgentCard
          {...(agents[0] || {})}
          agent={agents[0]}
          selected={
            agents[0] !== undefined
              ? state?.selectedStoreListingVersionId ==
@@ -74,7 +74,7 @@ export default function Page() {
          }
        />
        <OnboardingAgentCard
          {...(agents[1] || {})}
          agent={agents[1]}
          selected={
            agents[1] !== undefined
              ? state?.selectedStoreListingVersionId ==
@@ -9,7 +9,6 @@ import StarRating from "@/components/onboarding/StarRating";
import { Play } from "lucide-react";
import { cn } from "@/lib/utils";
import { useCallback, useEffect, useState } from "react";
import Image from "next/image";
import { GraphMeta, StoreAgentDetails } from "@/lib/autogpt-server-api";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import { useRouter } from "next/navigation";
@@ -17,6 +16,7 @@ import { useOnboarding } from "@/components/onboarding/onboarding-provider";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import SchemaTooltip from "@/components/SchemaTooltip";
import { TypeBasedInput } from "@/components/type-based-input";
import SmartImage from "@/components/agptui/SmartImage";

export default function Page() {
  const { state, updateState, setStep } = useOnboarding(
@@ -80,15 +80,26 @@ export default function Page() {
    if (!agent) {
      return;
    }
    api.addMarketplaceAgentToLibrary(
      storeAgent?.store_listing_version_id || "",
    );
    api.executeGraph(agent.id, agent.version, state?.agentInput || {});
    router.push("/onboarding/6-congrats");
  }, [api, agent, router, state?.agentInput]);
    api
      .addMarketplaceAgentToLibrary(storeAgent?.store_listing_version_id || "")
      .then((libraryAgent) => {
        api
          .executeGraph(
            libraryAgent.graph_id,
            libraryAgent.graph_version,
            state?.agentInput || {},
          )
          .then(({ graph_exec_id }) => {
            updateState({
              onboardingAgentExecutionId: graph_exec_id,
            });
            router.push("/onboarding/6-congrats");
          });
      });
  }, [api, agent, router, state?.agentInput, storeAgent, updateState]);

  const runYourAgent = (
    <div className="ml-[54px] w-[481px] pl-5">
    <div className="ml-[104px] w-[481px] pl-5">
      <div className="flex flex-col">
        <OnboardingText variant="header">Run your first agent</OnboardingText>
        <span className="mt-9 text-base font-normal leading-normal text-zinc-600">
@@ -136,32 +147,25 @@ export default function Page() {
  return (
    <OnboardingStep dotted>
      <OnboardingHeader backHref={"/onboarding/4-agent"} transparent />
      <div
        className={cn(
          "flex w-full items-center justify-center",
          showInput ? "mt-[32px]" : "mt-[192px]",
        )}
      >
        {/* Left side */}
        <div className="mr-[52px] w-[481px]">
          <div className="h-[156px] w-[481px] rounded-xl bg-white px-6 pb-5 pt-4">
            <span className="font-sans text-xs font-medium tracking-wide text-zinc-500">
              SELECTED AGENT
            </span>
      {/* Agent card */}
      <div className="fixed left-1/4 top-1/2 w-[481px] -translate-x-1/2 -translate-y-1/2">
        <div className="h-[156px] w-[481px] rounded-xl bg-white px-6 pb-5 pt-4">
          <span className="font-sans text-xs font-medium tracking-wide text-zinc-500">
            SELECTED AGENT
          </span>
          {storeAgent ? (
            <div className="mt-4 flex h-20 rounded-lg bg-violet-50 p-2">
              {/* Left image */}
              <Image
                src={storeAgent?.agent_image[0] || ""}
                alt="Description"
                width={350}
                height={196}
                className="h-full w-auto rounded-lg object-contain"
              <SmartImage
                src={storeAgent?.agent_image[0]}
                alt="Agent cover"
                imageContain
                className="w-[350px] rounded-lg"
              />

              {/* Right content */}
              <div className="ml-2 flex flex-1 flex-col">
                <span className="w-[292px] truncate font-sans text-[14px] font-medium leading-normal text-zinc-800">
                  {agent?.name}
                  {storeAgent?.agent_name}
                </span>
                <span className="mt-[5px] w-[292px] truncate font-sans text-xs font-normal leading-tight text-zinc-600">
                  by {storeAgent?.creator}
@@ -178,13 +182,19 @@ export default function Page() {
                </div>
              </div>
            </div>
          ) : (
            <div className="mt-4 flex h-20 animate-pulse rounded-lg bg-gray-300 p-2" />
          )}
        </div>
      </div>
      <div className="flex min-h-[80vh] items-center justify-center">
        {/* Left side */}
        <div className="w-[481px]" />
        {/* Right side */}
        {!showInput ? (
          runYourAgent
        ) : (
          <div className="ml-[54px] w-[481px] pl-5">
          <div className="ml-[104px] w-[481px] pl-5">
            <div className="flex flex-col">
              <OnboardingText variant="header">
                Provide details for your agent
@@ -0,0 +1,18 @@
"use server";
import BackendAPI from "@/lib/autogpt-server-api";
import { revalidatePath } from "next/cache";
import { redirect } from "next/navigation";

export async function finishOnboarding() {
  const api = new BackendAPI();
  const onboarding = await api.getUserOnboarding();
  const listingId = onboarding?.selectedStoreListingVersionId;
  if (listingId) {
    const libraryAgent = await api.addMarketplaceAgentToLibrary(listingId);
    revalidatePath(`/library/agents/${libraryAgent.id}`, "layout");
    redirect(`/library/agents/${libraryAgent.id}`);
  } else {
    revalidatePath("/library", "layout");
    redirect("/library");
  }
}
@@ -1,24 +1,26 @@
"use client";
import { useEffect, useState } from "react";
import { useEffect, useRef, useState } from "react";
import { cn } from "@/lib/utils";
import { finishOnboarding } from "./actions";
import confetti from "canvas-confetti";
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
import * as party from "party-js";

export default function Page() {
  useOnboarding(7, "AGENT_INPUT");
  const { state, updateState } = useOnboarding(7, "AGENT_INPUT");
  const [showText, setShowText] = useState(false);
  const [showSubtext, setShowSubtext] = useState(false);
  const divRef = useRef(null);

  useEffect(() => {
    confetti({
      particleCount: 120,
      spread: 360,
      shapes: ["square", "circle"],
      scalar: 2,
      decay: 0.93,
      origin: { y: 0.38, x: 0.51 },
    });
    if (divRef.current) {
      party.confetti(divRef.current, {
        count: 100,
        spread: 180,
        shapes: ["square", "circle"],
        size: party.variation.range(2, 2), // scalar: 2
        speed: party.variation.range(300, 1000),
      });
    }

    const timer0 = setTimeout(() => {
      setShowText(true);
@@ -29,6 +31,9 @@ export default function Page() {
    }, 500);

    const timer2 = setTimeout(() => {
      updateState({
        completedSteps: [...(state?.completedSteps || []), "CONGRATS"],
      });
      finishOnboarding();
    }, 3000);

@@ -42,6 +47,7 @@ export default function Page() {
  return (
    <div className="flex h-screen w-screen flex-col items-center justify-center bg-violet-100">
      <div
        ref={divRef}
        className={cn(
          "z-10 -mb-16 text-9xl duration-500",
          showText ? "opacity-100" : "opacity-0",
@@ -63,7 +69,7 @@ export default function Page() {
          showSubtext ? "opacity-100" : "opacity-0",
        )}
      >
        You earned 15$ for running your first agent
        You earned 3$ for running your first agent
      </p>
    </div>
  );
@@ -7,9 +7,9 @@ export default function OnboardingLayout({
}) {
  return (
    <div className="flex min-h-screen w-full items-center justify-center bg-gray-100">
      <div className="mx-auto flex w-full flex-col items-center">
      <main className="mx-auto flex w-full flex-col items-center">
        {children}
      </div>
      </main>
    </div>
  );
}
@@ -13,7 +13,7 @@ export default async function OnboardingPage() {
  // CONGRATS is the last step in intro onboarding
  if (onboarding.completedSteps.includes("CONGRATS")) redirect("/marketplace");
  else if (onboarding.completedSteps.includes("AGENT_INPUT"))
    redirect("/onboarding/6-congrats");
    redirect("/onboarding/5-run");
  else if (onboarding.completedSteps.includes("AGENT_NEW_RUN"))
    redirect("/onboarding/5-run");
  else if (onboarding.completedSteps.includes("AGENT_CHOICE"))
@@ -5,11 +5,14 @@ export default async function OnboardingResetPage() {
  const api = new BackendAPI();
  await api.updateUserOnboarding({
    completedSteps: [],
    notificationDot: true,
    notified: [],
    usageReason: null,
    integrations: [],
    otherIntegrations: "",
    selectedStoreListingVersionId: null,
    agentInput: {},
    onboardingAgentExecutionId: null,
  });
  redirect("/onboarding/1-welcome");
}
@@ -56,3 +56,9 @@ export async function getAdminListingsWithVersions(
  const response = await api.getAdminListingsWithVersions(data);
  return response;
}

export async function downloadAsAdmin(storeListingVersion: string) {
  const api = new BackendApi();
  const file = await api.downloadStoreAgentAdmin(storeListingVersion);
  return file;
}
@@ -3,9 +3,16 @@

import { useSearchParams } from "next/navigation";
import { GraphID } from "@/lib/autogpt-server-api/types";
import FlowEditor from "@/components/Flow";
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
import { useEffect } from "react";

export default function Home() {
  const query = useSearchParams();
  const { completeStep } = useOnboarding();

  useEffect(() => {
    completeStep("BUILDER_OPEN");
  }, []);

  return (
    <FlowEditor
Some files were not shown because too many files have changed in this diff.