mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Compare commits
61 Commits
swiftyos/a
...
ci-chromat
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
54bbafc431 | ||
|
|
5662783624 | ||
|
|
a5f448af98 | ||
|
|
c766bd66e1 | ||
|
|
6d11ad8051 | ||
|
|
d476983bd2 | ||
|
|
3ac1ce5a3f | ||
|
|
3b89e6d2b7 | ||
|
|
c7a7652b9f | ||
|
|
b6b0d0b209 | ||
|
|
a5b1495062 | ||
|
|
026f16c10f | ||
|
|
c468201c53 | ||
|
|
5beb581d1c | ||
|
|
df2339c1cf | ||
|
|
327db54321 | ||
|
|
234d6f78ba | ||
|
|
43088ddff8 | ||
|
|
fd955fba25 | ||
|
|
83943d9ddb | ||
|
|
60c26e62f6 | ||
|
|
1fc8f9ba66 | ||
|
|
33d747f457 | ||
|
|
06fa001a37 | ||
|
|
4e7b56b814 | ||
|
|
d6b03a4f18 | ||
|
|
fae9aeb49a | ||
|
|
5e8c1e274e | ||
|
|
55f7dc4853 | ||
|
|
b317adb9cf | ||
|
|
c873ba04b8 | ||
|
|
00f0311dd0 | ||
|
|
9b2bd756fa | ||
|
|
bceb83ca30 | ||
|
|
eadbfcd920 | ||
|
|
9768540b60 | ||
|
|
697436be07 | ||
|
|
d725e105a0 | ||
|
|
927f43f52f | ||
|
|
eedcc92d6f | ||
|
|
f0c378c70d | ||
|
|
c6c2b852df | ||
|
|
aaab8b1e0e | ||
|
|
a4eeb4535a | ||
|
|
db068c598c | ||
|
|
d4d9efc73e | ||
|
|
ffaf77df4e | ||
|
|
2daf08434e | ||
|
|
745137f4c2 | ||
|
|
3a2c3deb0e | ||
|
|
66a15a7b8c | ||
|
|
669c61de76 | ||
|
|
e860bde3d4 | ||
|
|
f5394f6d65 | ||
|
|
06e845abe7 | ||
|
|
c2c3c29018 | ||
|
|
31fd0b557a | ||
|
|
9350fe1d2b | ||
|
|
5ae92820b4 | ||
|
|
66a87e5a14 | ||
|
|
e1f8882e2d |
24
.github/workflows/platform-frontend-ci.yml
vendored
24
.github/workflows/platform-frontend-ci.yml
vendored
@@ -56,6 +56,30 @@ jobs:
|
||||
run: |
|
||||
yarn type-check
|
||||
|
||||
design:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
yarn install --frozen-lockfile
|
||||
|
||||
- name: Run Chromatic
|
||||
uses: chromaui/action@latest
|
||||
with:
|
||||
# ⚠️ Make sure to configure a `CHROMATIC_PROJECT_TOKEN` repository secret
|
||||
projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
|
||||
workingDir: autogpt_platform/frontend
|
||||
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
|
||||
@@ -16,7 +16,7 @@ jobs:
|
||||
# operations-per-run: 5000
|
||||
stale-issue-message: >
|
||||
This issue has automatically been marked as _stale_ because it has not had
|
||||
any activity in the last 170 days. You can _unstale_ it by commenting or
|
||||
any activity in the last 50 days. You can _unstale_ it by commenting or
|
||||
removing the label. Otherwise, this issue will be closed in 10 days.
|
||||
stale-pr-message: >
|
||||
This pull request has automatically been marked as _stale_ because it has
|
||||
@@ -25,7 +25,7 @@ jobs:
|
||||
close-issue-message: >
|
||||
This issue was closed automatically because it has been stale for 10 days
|
||||
with no activity.
|
||||
days-before-stale: 170
|
||||
days-before-stale: 100
|
||||
days-before-close: 10
|
||||
# Do not touch meta issues:
|
||||
exempt-issue-labels: meta,fridge,project management
|
||||
|
||||
@@ -31,7 +31,7 @@ class RedisKeyedMutex:
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if lock.locked() and lock.owned():
|
||||
if lock.locked():
|
||||
lock.release()
|
||||
|
||||
def acquire(self, key: Any) -> "RedisLock":
|
||||
|
||||
@@ -66,13 +66,6 @@ MEDIA_GCS_BUCKET_NAME=
|
||||
## and tunnel it to your locally running backend.
|
||||
PLATFORM_BASE_URL=http://localhost:3000
|
||||
|
||||
## Cloudflare Turnstile (CAPTCHA) Configuration
|
||||
## Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
|
||||
## This is the backend secret key
|
||||
TURNSTILE_SECRET_KEY=
|
||||
## This is the verify URL
|
||||
TURNSTILE_VERIFY_URL=https://challenges.cloudflare.com/turnstile/v0/siteverify
|
||||
|
||||
## == INTEGRATION CREDENTIALS == ##
|
||||
# Each set of server side credentials is required for the corresponding 3rd party
|
||||
# integration to work.
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import functools
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
@@ -11,11 +10,17 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@functools.cache
|
||||
_AVAILABLE_BLOCKS: dict[str, type["Block"]] = {}
|
||||
|
||||
|
||||
def load_all_blocks() -> dict[str, type["Block"]]:
|
||||
from backend.data.block import Block
|
||||
|
||||
if _AVAILABLE_BLOCKS:
|
||||
return _AVAILABLE_BLOCKS
|
||||
|
||||
# Dynamically load all modules under backend.blocks
|
||||
AVAILABLE_MODULES = []
|
||||
current_dir = Path(__file__).parent
|
||||
modules = [
|
||||
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
|
||||
@@ -30,9 +35,9 @@ def load_all_blocks() -> dict[str, type["Block"]]:
|
||||
)
|
||||
|
||||
importlib.import_module(f".{module}", package=__name__)
|
||||
AVAILABLE_MODULES.append(module)
|
||||
|
||||
# Load all Block instances from the available modules
|
||||
available_blocks: dict[str, type["Block"]] = {}
|
||||
for block_cls in all_subclasses(Block):
|
||||
class_name = block_cls.__name__
|
||||
|
||||
@@ -53,7 +58,7 @@ def load_all_blocks() -> dict[str, type["Block"]]:
|
||||
f"Block ID {block.name} error: {block.id} is not a valid UUID"
|
||||
)
|
||||
|
||||
if block.id in available_blocks:
|
||||
if block.id in _AVAILABLE_BLOCKS:
|
||||
raise ValueError(
|
||||
f"Block ID {block.name} error: {block.id} is already in use"
|
||||
)
|
||||
@@ -84,9 +89,9 @@ def load_all_blocks() -> dict[str, type["Block"]]:
|
||||
f"{block.name} has a boolean field with no default value"
|
||||
)
|
||||
|
||||
available_blocks[block.id] = block_cls
|
||||
_AVAILABLE_BLOCKS[block.id] = block_cls
|
||||
|
||||
return available_blocks
|
||||
return _AVAILABLE_BLOCKS
|
||||
|
||||
|
||||
__all__ = ["load_all_blocks"]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
@@ -11,7 +11,7 @@ from backend.data.block import (
|
||||
get_block,
|
||||
)
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.model import CredentialsMetaInput, SchemaField
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -23,21 +23,17 @@ class AgentExecutorBlock(Block):
|
||||
graph_id: str = SchemaField(description="Graph ID")
|
||||
graph_version: int = SchemaField(description="Graph Version")
|
||||
|
||||
inputs: BlockInput = SchemaField(description="Input data for the graph")
|
||||
data: BlockInput = SchemaField(description="Input data for the graph")
|
||||
input_schema: dict = SchemaField(description="Input schema for the graph")
|
||||
output_schema: dict = SchemaField(description="Output schema for the graph")
|
||||
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = SchemaField(default=None, hidden=True)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
|
||||
return data.get("input_schema", {})
|
||||
|
||||
@classmethod
|
||||
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
|
||||
return data.get("inputs", {})
|
||||
return data.get("data", {})
|
||||
|
||||
@classmethod
|
||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||
@@ -71,8 +67,7 @@ class AgentExecutorBlock(Block):
|
||||
graph_id=input_data.graph_id,
|
||||
graph_version=input_data.graph_version,
|
||||
user_id=input_data.user_id,
|
||||
inputs=input_data.inputs,
|
||||
node_credentials_input_map=input_data.node_credentials_input_map,
|
||||
inputs=input_data.data,
|
||||
)
|
||||
log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.id}"
|
||||
logger.info(f"Starting execution of {log_id}")
|
||||
|
||||
@@ -88,33 +88,6 @@ class StoreValueBlock(Block):
|
||||
yield "output", input_data.data or input_data.input
|
||||
|
||||
|
||||
class PrintToConsoleBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
text: Any = SchemaField(description="The data to print to the console.")
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: Any = SchemaField(description="The data printed to the console.")
|
||||
status: str = SchemaField(description="The status of the print operation.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c",
|
||||
description="Print the given text to the console, this is used for a debugging purpose.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=PrintToConsoleBlock.Input,
|
||||
output_schema=PrintToConsoleBlock.Output,
|
||||
test_input={"text": "Hello, World!"},
|
||||
test_output=[
|
||||
("output", "Hello, World!"),
|
||||
("status", "printed"),
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "output", input_data.text
|
||||
yield "status", "printed"
|
||||
|
||||
|
||||
class FindInDictionaryBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
input: Any = SchemaField(description="Dictionary to lookup from")
|
||||
|
||||
@@ -1,212 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockWebhookConfig,
|
||||
)
|
||||
from backend.data.model import SchemaField, APIKeyCredentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
ExaCredentialsField,
|
||||
ExaCredentialsInput,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExaTriggerBase:
|
||||
"""Base class for Exa webhook triggers."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
"""Base input schema for Exa triggers."""
|
||||
credentials: ExaCredentialsInput = ExaCredentialsField()
|
||||
# --8<-- [start:example-payload-field]
|
||||
payload: dict = SchemaField(hidden=True, default_factory=dict)
|
||||
# --8<-- [end:example-payload-field]
|
||||
|
||||
class Output(BlockSchema):
|
||||
"""Base output schema for Exa triggers."""
|
||||
payload: dict = SchemaField(
|
||||
description="The complete webhook payload that was received from Exa. "
|
||||
"Includes information about the event type, data, and creation timestamp."
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the payload could not be processed"
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
"""Process the webhook payload from Exa.
|
||||
|
||||
Args:
|
||||
input_data: The input data containing the webhook payload
|
||||
|
||||
Yields:
|
||||
The complete webhook payload
|
||||
"""
|
||||
yield "payload", input_data.payload
|
||||
|
||||
|
||||
class ExaWebsetTriggerBlock(ExaTriggerBase, Block):
|
||||
"""Block for handling Exa Webset webhook events.
|
||||
|
||||
This block triggers on various Exa Webset events such as webset creation,
|
||||
deletion, search completion, etc. and outputs the event details.
|
||||
"""
|
||||
|
||||
EXAMPLE_PAYLOAD_FILE = (
|
||||
Path(__file__).parent / "example_payloads" / "webset.created.json"
|
||||
)
|
||||
|
||||
class Input(ExaTriggerBase.Input):
|
||||
"""Input schema for Exa Webset trigger with event filtering options."""
|
||||
|
||||
class EventsFilter(BaseModel):
|
||||
"""
|
||||
Event filter options for Exa Webset webhooks.
|
||||
|
||||
See: https://docs.exa.ai/api-reference/webhooks
|
||||
"""
|
||||
# Webset events
|
||||
webset_created: bool = False
|
||||
webset_deleted: bool = False
|
||||
webset_paused: bool = False
|
||||
webset_idle: bool = False
|
||||
|
||||
# Search events
|
||||
webset_search_created: bool = False
|
||||
webset_search_updated: bool = False
|
||||
webset_search_completed: bool = False
|
||||
webset_search_canceled: bool = False
|
||||
|
||||
# Item events
|
||||
webset_item_created: bool = False
|
||||
webset_item_enriched: bool = False
|
||||
|
||||
# Export events
|
||||
webset_export_created: bool = False
|
||||
webset_export_completed: bool = False
|
||||
|
||||
events: EventsFilter = SchemaField(
|
||||
title="Events", description="The events to subscribe to"
|
||||
)
|
||||
|
||||
class Output(ExaTriggerBase.Output):
|
||||
"""Output schema for Exa Webset trigger with event-specific fields."""
|
||||
|
||||
event_type: str = SchemaField(
|
||||
description="The type of event that triggered the webhook"
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID of the affected webset"
|
||||
)
|
||||
created_at: str = SchemaField(
|
||||
description="Timestamp when the event was created"
|
||||
)
|
||||
data: dict = SchemaField(
|
||||
description="Object containing the full resource that triggered the event"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the Exa Webset trigger block with its configuration."""
|
||||
|
||||
# Define a webhook type constant for Exa
|
||||
class ExaWebhookType:
|
||||
"""Constants for Exa webhook types."""
|
||||
WEBSET = "webset"
|
||||
|
||||
# Create example payload
|
||||
example_payload = {
|
||||
"id": "663de972-bfe7-47ef-b4d7-179cfed7aa44",
|
||||
"object": "event",
|
||||
"type": "webset.created",
|
||||
"data": {
|
||||
"id": "wbs_123456789",
|
||||
"name": "Example Webset",
|
||||
"description": "An example webset for testing"
|
||||
},
|
||||
"createdAt": "2023-06-01T12:00:00Z"
|
||||
}
|
||||
|
||||
# Map UI event names to API event names
|
||||
self.event_mapping = {
|
||||
"webset_created": "webset.created",
|
||||
"webset_deleted": "webset.deleted",
|
||||
"webset_paused": "webset.paused",
|
||||
"webset_idle": "webset.idle",
|
||||
"webset_search_created": "webset.search.created",
|
||||
"webset_search_updated": "webset.search.updated",
|
||||
"webset_search_completed": "webset.search.completed",
|
||||
"webset_search_canceled": "webset.search.canceled",
|
||||
"webset_item_created": "webset.item.created",
|
||||
"webset_item_enriched": "webset.item.enriched",
|
||||
"webset_export_created": "webset.export.created",
|
||||
"webset_export_completed": "webset.export.completed"
|
||||
}
|
||||
|
||||
super().__init__(
|
||||
id="804ac1ed-d692-4ccb-a390-739a846a2667",
|
||||
description="This block triggers on Exa Webset events and outputs the event type and payload.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.INPUT},
|
||||
input_schema=ExaWebsetTriggerBlock.Input,
|
||||
output_schema=ExaWebsetTriggerBlock.Output,
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider=ProviderName.EXA,
|
||||
webhook_type=ExaWebhookType.WEBSET,
|
||||
resource_format="", # Exa doesn't require a specific resource format
|
||||
event_filter_input="events",
|
||||
event_format="webset.{event}",
|
||||
),
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"events": {"webset_created": True, "webset_search_completed": True},
|
||||
"payload": example_payload,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("payload", example_payload),
|
||||
("event_type", example_payload["type"]),
|
||||
("webset_id", "wbs_123456789"),
|
||||
("created_at", "2023-06-01T12:00:00Z"),
|
||||
("data", example_payload["data"]),
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput: # type: ignore
|
||||
"""Process Exa Webset webhook events.
|
||||
|
||||
Args:
|
||||
input_data: The input data containing the webhook payload and event filter
|
||||
|
||||
Yields:
|
||||
Event details including event type, webset ID, creation timestamp, and data,
|
||||
or an error message if the event type doesn't match the filter or if required
|
||||
fields are missing from the payload.
|
||||
"""
|
||||
yield from super().run(input_data, **kwargs)
|
||||
try:
|
||||
# Get the event type from the payload
|
||||
event_type = input_data.payload["type"]
|
||||
|
||||
# Check if this event type is in the user's selected events
|
||||
# Convert API event name to UI event name (reverse mapping)
|
||||
ui_event_name = next((k for k, v in self.event_mapping.items() if v == event_type), None)
|
||||
|
||||
# Only process events that match the filter
|
||||
if ui_event_name and getattr(input_data.events, ui_event_name, False):
|
||||
yield "event_type", event_type
|
||||
yield "webset_id", input_data.payload["data"]["id"]
|
||||
yield "created_at", input_data.payload["createdAt"]
|
||||
yield "data", input_data.payload["data"]
|
||||
else:
|
||||
yield "error", f"Event type {event_type} not in selected events filter"
|
||||
except KeyError as e:
|
||||
yield "error", f"Missing expected field in payload: {str(e)}"
|
||||
@@ -1,321 +0,0 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from backend.blocks.exa._auth import (
|
||||
ExaCredentials,
|
||||
ExaCredentialsField,
|
||||
ExaCredentialsInput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
# --- Create Webset Block ---
|
||||
class ExaCreateWebsetBlock(Block):
|
||||
"""Block for creating a Webset using Exa's Websets API."""
|
||||
class Input(BlockSchema):
|
||||
credentials: ExaCredentialsInput = ExaCredentialsField()
|
||||
query: str = SchemaField(description="The search query for the webset")
|
||||
count: int = SchemaField(description="Number of results to return", default=5)
|
||||
enrichments: Optional[List[Dict[str, Any]]] = SchemaField(
|
||||
description="List of enrichment dicts (optional)", default_factory=list, advanced=True
|
||||
)
|
||||
external_id: Optional[str] = SchemaField(
|
||||
description="Optional external identifier", default=None, advanced=True
|
||||
)
|
||||
metadata: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="Optional metadata", default_factory=dict, advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
webset: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="The created webset object (or None if error)", default=None
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the request failed", default=""
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the ExaCreateWebsetBlock with its configuration."""
|
||||
super().__init__(
|
||||
id="322351cc-35d7-45ec-8920-9a3c98920411",
|
||||
description="Creates a Webset using Exa's Websets API",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaCreateWebsetBlock.Input,
|
||||
output_schema=ExaCreateWebsetBlock.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, *, credentials: ExaCredentials, **kwargs) -> BlockOutput:
|
||||
"""
|
||||
Execute the block to create a webset with Exa's API.
|
||||
|
||||
Args:
|
||||
input_data: The input parameters for creating a webset
|
||||
credentials: The Exa API credentials
|
||||
|
||||
Yields:
|
||||
Either the created webset object or an error message
|
||||
"""
|
||||
url = "https://api.exa.ai/websets/v0/websets"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
payload = {
|
||||
"search": {
|
||||
"query": input_data.query,
|
||||
"count": input_data.count,
|
||||
}
|
||||
}
|
||||
optional_fields = {}
|
||||
if isinstance(input_data.enrichments, list) and input_data.enrichments:
|
||||
optional_fields["enrichments"] = input_data.enrichments
|
||||
if isinstance(input_data.external_id, str) and input_data.external_id:
|
||||
optional_fields["externalId"] = input_data.external_id
|
||||
if isinstance(input_data.metadata, dict) and input_data.metadata:
|
||||
optional_fields["metadata"] = input_data.metadata
|
||||
payload.update(optional_fields)
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
yield "webset", data
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
# --- Get Webset Block ---
|
||||
class ExaGetWebsetBlock(Block):
|
||||
"""Block for retrieving a Webset by ID using Exa's Websets API."""
|
||||
class Input(BlockSchema):
|
||||
credentials: ExaCredentialsInput = ExaCredentialsField()
|
||||
webset_id: str = SchemaField(description="The Webset ID or externalId")
|
||||
expand_items: bool = SchemaField(description="Expand with items", default=False, advanced=True)
|
||||
|
||||
class Output(BlockSchema):
|
||||
webset: Optional[Dict[str, Any]] = SchemaField(description="The webset object (or None if error)", default=None)
|
||||
error: str = SchemaField(description="Error message if the request failed", default="")
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the ExaGetWebsetBlock with its configuration."""
|
||||
super().__init__(
|
||||
id="f9229293-cddf-43fc-94b3-48cbd1a44618",
|
||||
description="Retrieves a Webset by ID using Exa's Websets API",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaGetWebsetBlock.Input,
|
||||
output_schema=ExaGetWebsetBlock.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, *, credentials: ExaCredentials, **kwargs) -> BlockOutput:
|
||||
"""
|
||||
Execute the block to retrieve a webset by ID from Exa's API.
|
||||
|
||||
Args:
|
||||
input_data: The input parameters including the webset ID
|
||||
credentials: The Exa API credentials
|
||||
|
||||
Yields:
|
||||
Either the retrieved webset object or an error message
|
||||
"""
|
||||
url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
params = {"expand": "items"} if input_data.expand_items else None
|
||||
try:
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
yield "webset", data
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
# --- Delete Webset Block ---
|
||||
class ExaDeleteWebsetBlock(Block):
|
||||
"""Block for deleting a Webset by ID using Exa's Websets API."""
|
||||
class Input(BlockSchema):
|
||||
credentials: ExaCredentialsInput = ExaCredentialsField()
|
||||
webset_id: str = SchemaField(description="The Webset ID or externalId")
|
||||
|
||||
class Output(BlockSchema):
|
||||
deleted: Optional[Dict[str, Any]] = SchemaField(description="The deleted webset object (or None if error)", default=None)
|
||||
error: str = SchemaField(description="Error message if the request failed", default="")
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the ExaDeleteWebsetBlock with its configuration."""
|
||||
super().__init__(
|
||||
id="a082e162-274e-4167-a467-a1839e644cbd",
|
||||
description="Deletes a Webset by ID using Exa's Websets API",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaDeleteWebsetBlock.Input,
|
||||
output_schema=ExaDeleteWebsetBlock.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, *, credentials: ExaCredentials, **kwargs) -> BlockOutput:
|
||||
"""
|
||||
Execute the block to delete a webset by ID using Exa's API.
|
||||
|
||||
Args:
|
||||
input_data: The input parameters including the webset ID
|
||||
credentials: The Exa API credentials
|
||||
|
||||
Yields:
|
||||
Either the deleted webset object or an error message
|
||||
"""
|
||||
url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
try:
|
||||
response = requests.delete(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
yield "deleted", data
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
# --- Update Webset Block ---
|
||||
class ExaUpdateWebsetBlock(Block):
|
||||
"""Block for updating a Webset's metadata using Exa's Websets API."""
|
||||
class Input(BlockSchema):
|
||||
credentials: ExaCredentialsInput = ExaCredentialsField()
|
||||
webset_id: str = SchemaField(description="The Webset ID or externalId")
|
||||
metadata: Dict[str, Any] = SchemaField(description="Metadata to update", default_factory=dict)
|
||||
|
||||
class Output(BlockSchema):
|
||||
webset: Optional[Dict[str, Any]] = SchemaField(description="The updated webset object (or None if error)", default=None)
|
||||
error: str = SchemaField(description="Error message if the request failed", default="")
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the ExaUpdateWebsetBlock with its configuration."""
|
||||
super().__init__(
|
||||
id="e0c81b70-ac38-4239-8ecd-a75c1737c9ef",
|
||||
description="Updates a Webset's metadata using Exa's Websets API",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaUpdateWebsetBlock.Input,
|
||||
output_schema=ExaUpdateWebsetBlock.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, *, credentials: ExaCredentials, **kwargs) -> BlockOutput:
|
||||
"""
|
||||
Execute the block to update a webset's metadata using Exa's API.
|
||||
|
||||
Args:
|
||||
input_data: The input parameters including the webset ID and metadata
|
||||
credentials: The Exa API credentials
|
||||
|
||||
Yields:
|
||||
Either the updated webset object or an error message
|
||||
"""
|
||||
url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
payload = {"metadata": input_data.metadata}
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
yield "webset", data
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
# --- List Websets Block ---
|
||||
class ExaListWebsetsBlock(Block):
|
||||
"""Block for listing all Websets using Exa's Websets API with pagination support."""
|
||||
class Input(BlockSchema):
|
||||
credentials: ExaCredentialsInput = ExaCredentialsField()
|
||||
limit: int = SchemaField(description="Number of websets to return (max 100)", default=25)
|
||||
cursor: Optional[str] = SchemaField(description="Pagination cursor (optional)", default=None, advanced=True)
|
||||
|
||||
class Output(BlockSchema):
|
||||
data: Optional[List[Dict[str, Any]]] = SchemaField(description="List of websets", default=None)
|
||||
has_more: Optional[bool] = SchemaField(description="Whether there are more results", default=None)
|
||||
next_cursor: Optional[str] = SchemaField(description="Cursor for next page", default=None)
|
||||
error: str = SchemaField(description="Error message if the request failed", default="")
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the ExaListWebsetsBlock with its configuration."""
|
||||
super().__init__(
|
||||
id="887a2dae-c9c3-4ae5-a079-fe3b52be64e4",
|
||||
description="Lists all Websets using Exa's Websets API",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaListWebsetsBlock.Input,
|
||||
output_schema=ExaListWebsetsBlock.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, *, credentials: ExaCredentials, **kwargs) -> BlockOutput:
|
||||
"""
|
||||
Execute the block to list websets with pagination using Exa's API.
|
||||
|
||||
Args:
|
||||
input_data: The input parameters including limit and optional cursor
|
||||
credentials: The Exa API credentials
|
||||
|
||||
Yields:
|
||||
The list of websets, pagination info, or an error message
|
||||
"""
|
||||
url = "https://api.exa.ai/websets/v0/websets"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
params: dict[str, Any] = {"limit": int(input_data.limit)}
|
||||
if isinstance(input_data.cursor, str) and input_data.cursor:
|
||||
params["cursor"] = input_data.cursor
|
||||
try:
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
yield "data", data.get("data")
|
||||
yield "has_more", data.get("hasMore")
|
||||
yield "next_cursor", data.get("nextCursor")
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
# --- Cancel Webset Block ---
|
||||
class ExaCancelWebsetBlock(Block):
|
||||
"""Block for canceling a running Webset using Exa's Websets API."""
|
||||
class Input(BlockSchema):
|
||||
credentials: ExaCredentialsInput = ExaCredentialsField()
|
||||
webset_id: str = SchemaField(description="The Webset ID or externalId")
|
||||
|
||||
class Output(BlockSchema):
|
||||
webset: Optional[Dict[str, Any]] = SchemaField(description="The canceled webset object (or None if error)", default=None)
|
||||
error: str = SchemaField(description="Error message if the request failed", default="")
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the ExaCancelWebsetBlock with its configuration."""
|
||||
super().__init__(
|
||||
id="f7f0b19c-71e8-4c2f-bc68-904a6a61faf7",
|
||||
description="Cancels a running Webset using Exa's Websets API",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaCancelWebsetBlock.Input,
|
||||
output_schema=ExaCancelWebsetBlock.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, *, credentials: ExaCredentials, **kwargs) -> BlockOutput:
|
||||
"""
|
||||
Execute the block to cancel a running webset using Exa's API.
|
||||
|
||||
Args:
|
||||
input_data: The input parameters including the webset ID
|
||||
credentials: The Exa API credentials
|
||||
|
||||
Yields:
|
||||
Either the canceled webset object or an error message
|
||||
"""
|
||||
url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}/cancel"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
try:
|
||||
response = requests.post(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
yield "webset", data
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,598 +0,0 @@
|
||||
import enum
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Literal
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.settings import AppEnvironment, Settings
|
||||
|
||||
from ._auth import (
|
||||
GOOGLE_OAUTH_IS_CONFIGURED,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GoogleCredentials,
|
||||
GoogleCredentialsField,
|
||||
GoogleCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class CalendarEvent(BaseModel):
|
||||
"""Structured representation of a Google Calendar event."""
|
||||
|
||||
id: str
|
||||
title: str
|
||||
start_time: str
|
||||
end_time: str
|
||||
is_all_day: bool
|
||||
location: str | None
|
||||
description: str | None
|
||||
organizer: str | None
|
||||
attendees: list[str]
|
||||
has_video_call: bool
|
||||
video_link: str | None
|
||||
calendar_link: str
|
||||
is_recurring: bool
|
||||
|
||||
|
||||
class GoogleCalendarReadEventsBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GoogleCredentialsInput = GoogleCredentialsField(
|
||||
["https://www.googleapis.com/auth/calendar.readonly"]
|
||||
)
|
||||
calendar_id: str = SchemaField(
|
||||
description="Calendar ID (use 'primary' for your main calendar)",
|
||||
default="primary",
|
||||
)
|
||||
max_events: int = SchemaField(
|
||||
description="Maximum number of events to retrieve", default=10
|
||||
)
|
||||
start_time: datetime = SchemaField(
|
||||
description="Retrieve events starting from this time",
|
||||
default_factory=lambda: datetime.now(tz=timezone.utc),
|
||||
)
|
||||
time_range_days: int = SchemaField(
|
||||
description="Number of days to look ahead for events", default=30
|
||||
)
|
||||
search_term: str | None = SchemaField(
|
||||
description="Optional search term to filter events by", default=None
|
||||
)
|
||||
|
||||
page_token: str | None = SchemaField(
|
||||
description="Page token from previous request to get the next batch of events. You can use this if you have lots of events you want to process in a loop",
|
||||
default=None,
|
||||
)
|
||||
include_declined_events: bool = SchemaField(
|
||||
description="Include events you've declined", default=False
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
events: list[CalendarEvent] = SchemaField(
|
||||
description="List of calendar events in the requested time range",
|
||||
default_factory=list,
|
||||
)
|
||||
event: CalendarEvent = SchemaField(
|
||||
description="One of the calendar events in the requested time range"
|
||||
)
|
||||
next_page_token: str | None = SchemaField(
|
||||
description="Token for retrieving the next page of events if more exist",
|
||||
default=None,
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the request failed",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
settings = Settings()
|
||||
|
||||
# Create realistic test data for events
|
||||
test_now = datetime.now(tz=timezone.utc)
|
||||
test_tomorrow = test_now + timedelta(days=1)
|
||||
|
||||
test_event_dict = {
|
||||
"id": "event1id",
|
||||
"title": "Team Meeting",
|
||||
"start_time": test_tomorrow.strftime("%Y-%m-%d %H:%M"),
|
||||
"end_time": (test_tomorrow + timedelta(hours=1)).strftime("%Y-%m-%d %H:%M"),
|
||||
"is_all_day": False,
|
||||
"location": "Conference Room A",
|
||||
"description": "Weekly team sync",
|
||||
"organizer": "manager@example.com",
|
||||
"attendees": ["colleague1@example.com", "colleague2@example.com"],
|
||||
"has_video_call": True,
|
||||
"video_link": "https://meet.google.com/abc-defg-hij",
|
||||
"calendar_link": "https://calendar.google.com/calendar/event?eid=event1id",
|
||||
"is_recurring": True,
|
||||
}
|
||||
|
||||
super().__init__(
|
||||
id="80bc3ed1-e9a4-449e-8163-a8fc86f74f6a",
|
||||
description="Retrieves upcoming events from a Google Calendar with filtering options",
|
||||
categories={BlockCategory.PRODUCTIVITY, BlockCategory.DATA},
|
||||
input_schema=GoogleCalendarReadEventsBlock.Input,
|
||||
output_schema=GoogleCalendarReadEventsBlock.Output,
|
||||
disabled=not GOOGLE_OAUTH_IS_CONFIGURED
|
||||
or settings.config.app_env == AppEnvironment.PRODUCTION,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"calendar_id": "primary",
|
||||
"max_events": 5,
|
||||
"start_time": test_now.isoformat(),
|
||||
"time_range_days": 7,
|
||||
"search_term": None,
|
||||
"include_declined_events": False,
|
||||
"page_token": None,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("event", test_event_dict),
|
||||
("events", [test_event_dict]),
|
||||
],
|
||||
test_mock={
|
||||
"_read_calendar": lambda *args, **kwargs: {
|
||||
"items": [
|
||||
{
|
||||
"id": "event1id",
|
||||
"summary": "Team Meeting",
|
||||
"start": {
|
||||
"dateTime": test_tomorrow.isoformat(),
|
||||
"timeZone": "UTC",
|
||||
},
|
||||
"end": {
|
||||
"dateTime": (
|
||||
test_tomorrow + timedelta(hours=1)
|
||||
).isoformat(),
|
||||
"timeZone": "UTC",
|
||||
},
|
||||
"location": "Conference Room A",
|
||||
"description": "Weekly team sync",
|
||||
"organizer": {"email": "manager@example.com"},
|
||||
"attendees": [
|
||||
{"email": "colleague1@example.com"},
|
||||
{"email": "colleague2@example.com"},
|
||||
],
|
||||
"conferenceData": {
|
||||
"conferenceUrl": "https://meet.google.com/abc-defg-hij"
|
||||
},
|
||||
"htmlLink": "https://calendar.google.com/calendar/event?eid=event1id",
|
||||
"recurrence": ["RRULE:FREQ=WEEKLY;COUNT=10"],
|
||||
}
|
||||
],
|
||||
"nextPageToken": None,
|
||||
},
|
||||
"_format_events": lambda *args, **kwargs: [test_event_dict],
|
||||
},
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
service = self._build_service(credentials, **kwargs)
|
||||
|
||||
# Calculate end time based on start time and time range
|
||||
end_time = input_data.start_time + timedelta(
|
||||
days=input_data.time_range_days
|
||||
)
|
||||
|
||||
# Call Google Calendar API
|
||||
result = self._read_calendar(
|
||||
service=service,
|
||||
calendarId=input_data.calendar_id,
|
||||
time_min=input_data.start_time.isoformat(),
|
||||
time_max=end_time.isoformat(),
|
||||
max_results=input_data.max_events,
|
||||
single_events=True,
|
||||
search_term=input_data.search_term,
|
||||
show_deleted=False,
|
||||
show_hidden=input_data.include_declined_events,
|
||||
page_token=input_data.page_token,
|
||||
)
|
||||
|
||||
# Format events into a user-friendly structure
|
||||
formatted_events = self._format_events(result.get("items", []))
|
||||
|
||||
# Include next page token if available
|
||||
if next_page_token := result.get("nextPageToken"):
|
||||
yield "next_page_token", next_page_token
|
||||
|
||||
for event in formatted_events:
|
||||
yield "event", event
|
||||
|
||||
yield "events", formatted_events
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
@staticmethod
|
||||
def _build_service(credentials: GoogleCredentials, **kwargs):
|
||||
creds = Credentials(
|
||||
token=(
|
||||
credentials.access_token.get_secret_value()
|
||||
if credentials.access_token
|
||||
else None
|
||||
),
|
||||
refresh_token=(
|
||||
credentials.refresh_token.get_secret_value()
|
||||
if credentials.refresh_token
|
||||
else None
|
||||
),
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=Settings().secrets.google_client_id,
|
||||
client_secret=Settings().secrets.google_client_secret,
|
||||
scopes=credentials.scopes,
|
||||
)
|
||||
return build("calendar", "v3", credentials=creds)
|
||||
|
||||
def _read_calendar(
|
||||
self,
|
||||
service,
|
||||
calendarId: str,
|
||||
time_min: str,
|
||||
time_max: str,
|
||||
max_results: int,
|
||||
single_events: bool,
|
||||
search_term: str | None = None,
|
||||
show_deleted: bool = False,
|
||||
show_hidden: bool = False,
|
||||
page_token: str | None = None,
|
||||
) -> dict:
|
||||
"""Read calendar events with optional filtering."""
|
||||
calendar = service.events()
|
||||
|
||||
# Build query parameters
|
||||
params = {
|
||||
"calendarId": calendarId,
|
||||
"timeMin": time_min,
|
||||
"timeMax": time_max,
|
||||
"maxResults": max_results,
|
||||
"singleEvents": single_events,
|
||||
"orderBy": "startTime",
|
||||
"showDeleted": show_deleted,
|
||||
"showHiddenInvitations": show_hidden,
|
||||
**({"pageToken": page_token} if page_token else {}),
|
||||
}
|
||||
|
||||
# Add search term if provided
|
||||
if search_term:
|
||||
params["q"] = search_term
|
||||
|
||||
result = calendar.list(**params).execute()
|
||||
return result
|
||||
|
||||
def _format_events(self, events: list[dict]) -> list[CalendarEvent]:
|
||||
"""Format Google Calendar API events into user-friendly structure."""
|
||||
formatted_events = []
|
||||
|
||||
for event in events:
|
||||
# Determine if all-day event
|
||||
is_all_day = "date" in event.get("start", {})
|
||||
|
||||
# Format start and end times
|
||||
if is_all_day:
|
||||
start_time = event.get("start", {}).get("date", "")
|
||||
end_time = event.get("end", {}).get("date", "")
|
||||
else:
|
||||
# Convert ISO format to more readable format
|
||||
start_datetime = datetime.fromisoformat(
|
||||
event.get("start", {}).get("dateTime", "").replace("Z", "+00:00")
|
||||
)
|
||||
end_datetime = datetime.fromisoformat(
|
||||
event.get("end", {}).get("dateTime", "").replace("Z", "+00:00")
|
||||
)
|
||||
start_time = start_datetime.strftime("%Y-%m-%d %H:%M")
|
||||
end_time = end_datetime.strftime("%Y-%m-%d %H:%M")
|
||||
|
||||
# Extract attendees
|
||||
attendees = []
|
||||
for attendee in event.get("attendees", []):
|
||||
if email := attendee.get("email"):
|
||||
attendees.append(email)
|
||||
|
||||
# Check for video call link
|
||||
has_video_call = False
|
||||
video_link = None
|
||||
if conf_data := event.get("conferenceData"):
|
||||
if conf_url := conf_data.get("conferenceUrl"):
|
||||
has_video_call = True
|
||||
video_link = conf_url
|
||||
elif entry_points := conf_data.get("entryPoints", []):
|
||||
for entry in entry_points:
|
||||
if entry.get("entryPointType") == "video":
|
||||
has_video_call = True
|
||||
video_link = entry.get("uri")
|
||||
break
|
||||
|
||||
# Create formatted event
|
||||
formatted_event = CalendarEvent(
|
||||
id=event.get("id", ""),
|
||||
title=event.get("summary", "Untitled Event"),
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
is_all_day=is_all_day,
|
||||
location=event.get("location"),
|
||||
description=event.get("description"),
|
||||
organizer=event.get("organizer", {}).get("email"),
|
||||
attendees=attendees,
|
||||
has_video_call=has_video_call,
|
||||
video_link=video_link,
|
||||
calendar_link=event.get("htmlLink", ""),
|
||||
is_recurring=bool(event.get("recurrence")),
|
||||
)
|
||||
|
||||
formatted_events.append(formatted_event)
|
||||
|
||||
return formatted_events
|
||||
|
||||
|
||||
class ReminderPreset(enum.Enum):
|
||||
"""Common reminder times before an event."""
|
||||
|
||||
TEN_MINUTES = 10
|
||||
THIRTY_MINUTES = 30
|
||||
ONE_HOUR = 60
|
||||
ONE_DAY = 1440 # 24 hours in minutes
|
||||
|
||||
|
||||
class RecurrenceFrequency(enum.Enum):
|
||||
"""Frequency options for recurring events."""
|
||||
|
||||
DAILY = "DAILY"
|
||||
WEEKLY = "WEEKLY"
|
||||
MONTHLY = "MONTHLY"
|
||||
YEARLY = "YEARLY"
|
||||
|
||||
|
||||
class ExactTiming(BaseModel):
|
||||
"""Model for specifying start and end times."""
|
||||
|
||||
discriminator: Literal["exact_timing"]
|
||||
start_datetime: datetime
|
||||
end_datetime: datetime
|
||||
|
||||
|
||||
class DurationTiming(BaseModel):
|
||||
"""Model for specifying start time and duration."""
|
||||
|
||||
discriminator: Literal["duration_timing"]
|
||||
start_datetime: datetime
|
||||
duration_minutes: int
|
||||
|
||||
|
||||
class OneTimeEvent(BaseModel):
|
||||
"""Model for a one-time event."""
|
||||
|
||||
discriminator: Literal["one_time"]
|
||||
|
||||
|
||||
class RecurringEvent(BaseModel):
|
||||
"""Model for a recurring event."""
|
||||
|
||||
discriminator: Literal["recurring"]
|
||||
frequency: RecurrenceFrequency
|
||||
count: int
|
||||
|
||||
|
||||
class GoogleCalendarCreateEventBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GoogleCredentialsInput = GoogleCredentialsField(
|
||||
["https://www.googleapis.com/auth/calendar"]
|
||||
)
|
||||
# Event Details
|
||||
event_title: str = SchemaField(description="Title of the event")
|
||||
location: str | None = SchemaField(
|
||||
description="Location of the event", default=None
|
||||
)
|
||||
description: str | None = SchemaField(
|
||||
description="Description of the event", default=None
|
||||
)
|
||||
|
||||
# Timing
|
||||
timing: ExactTiming | DurationTiming = SchemaField(
|
||||
discriminator="discriminator",
|
||||
advanced=False,
|
||||
description="Specify when the event starts and ends",
|
||||
default_factory=lambda: DurationTiming(
|
||||
discriminator="duration_timing",
|
||||
start_datetime=datetime.now().replace(microsecond=0, second=0, minute=0)
|
||||
+ timedelta(hours=1),
|
||||
duration_minutes=60,
|
||||
),
|
||||
)
|
||||
|
||||
# Calendar selection
|
||||
calendar_id: str = SchemaField(
|
||||
description="Calendar ID (use 'primary' for your main calendar)",
|
||||
default="primary",
|
||||
)
|
||||
|
||||
# Guests
|
||||
guest_emails: list[str] = SchemaField(
|
||||
description="Email addresses of guests to invite", default_factory=list
|
||||
)
|
||||
send_notifications: bool = SchemaField(
|
||||
description="Send email notifications to guests", default=True
|
||||
)
|
||||
|
||||
# Extras
|
||||
add_google_meet: bool = SchemaField(
|
||||
description="Include a Google Meet video conference link", default=False
|
||||
)
|
||||
recurrence: OneTimeEvent | RecurringEvent = SchemaField(
|
||||
discriminator="discriminator",
|
||||
description="Whether the event repeats",
|
||||
default_factory=lambda: OneTimeEvent(discriminator="one_time"),
|
||||
)
|
||||
reminder_minutes: list[ReminderPreset] = SchemaField(
|
||||
description="When to send reminders before the event",
|
||||
default_factory=lambda: [ReminderPreset.TEN_MINUTES],
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
event_id: str = SchemaField(description="ID of the created event")
|
||||
event_link: str = SchemaField(
|
||||
description="Link to view the event in Google Calendar"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if event creation failed")
|
||||
|
||||
def __init__(self):
|
||||
settings = Settings()
|
||||
|
||||
super().__init__(
|
||||
id="ed2ec950-fbff-4204-94c0-023fb1d625e0",
|
||||
description="This block creates a new event in Google Calendar with customizable parameters.",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=GoogleCalendarCreateEventBlock.Input,
|
||||
output_schema=GoogleCalendarCreateEventBlock.Output,
|
||||
disabled=not GOOGLE_OAUTH_IS_CONFIGURED
|
||||
or settings.config.app_env == AppEnvironment.PRODUCTION,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"event_title": "Team Meeting",
|
||||
"location": "Conference Room A",
|
||||
"description": "Weekly team sync-up",
|
||||
"calendar_id": "primary",
|
||||
"guest_emails": ["colleague1@example.com", "colleague2@example.com"],
|
||||
"add_google_meet": True,
|
||||
"send_notifications": True,
|
||||
"reminder_minutes": [
|
||||
ReminderPreset.TEN_MINUTES.value,
|
||||
ReminderPreset.ONE_HOUR.value,
|
||||
],
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("event_id", "abc123event_id"),
|
||||
("event_link", "https://calendar.google.com/calendar/event?eid=abc123"),
|
||||
],
|
||||
test_mock={
|
||||
"_create_event": lambda *args, **kwargs: {
|
||||
"id": "abc123event_id",
|
||||
"htmlLink": "https://calendar.google.com/calendar/event?eid=abc123",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
service = self._build_service(credentials, **kwargs)
|
||||
|
||||
# Get start and end times based on the timing option
|
||||
if input_data.timing.discriminator == "exact_timing":
|
||||
start_datetime = input_data.timing.start_datetime
|
||||
end_datetime = input_data.timing.end_datetime
|
||||
else: # duration_timing
|
||||
start_datetime = input_data.timing.start_datetime
|
||||
end_datetime = start_datetime + timedelta(
|
||||
minutes=input_data.timing.duration_minutes
|
||||
)
|
||||
|
||||
# Format datetimes for Google Calendar API
|
||||
start_time_str = start_datetime.isoformat()
|
||||
end_time_str = end_datetime.isoformat()
|
||||
|
||||
# Build the event body
|
||||
event_body = {
|
||||
"summary": input_data.event_title,
|
||||
"start": {"dateTime": start_time_str},
|
||||
"end": {"dateTime": end_time_str},
|
||||
}
|
||||
|
||||
# Add optional fields
|
||||
if input_data.location:
|
||||
event_body["location"] = input_data.location
|
||||
|
||||
if input_data.description:
|
||||
event_body["description"] = input_data.description
|
||||
|
||||
# Add guests
|
||||
if input_data.guest_emails:
|
||||
event_body["attendees"] = [
|
||||
{"email": email} for email in input_data.guest_emails
|
||||
]
|
||||
|
||||
# Add reminders
|
||||
if input_data.reminder_minutes:
|
||||
event_body["reminders"] = {
|
||||
"useDefault": False,
|
||||
"overrides": [
|
||||
{"method": "popup", "minutes": reminder.value}
|
||||
for reminder in input_data.reminder_minutes
|
||||
],
|
||||
}
|
||||
|
||||
# Add Google Meet
|
||||
if input_data.add_google_meet:
|
||||
event_body["conferenceData"] = {
|
||||
"createRequest": {
|
||||
"requestId": f"meet-{uuid.uuid4()}",
|
||||
"conferenceSolutionKey": {"type": "hangoutsMeet"},
|
||||
}
|
||||
}
|
||||
|
||||
# Add recurrence
|
||||
if input_data.recurrence.discriminator == "recurring":
|
||||
rule = f"RRULE:FREQ={input_data.recurrence.frequency.value}"
|
||||
rule += f";COUNT={input_data.recurrence.count}"
|
||||
event_body["recurrence"] = [rule]
|
||||
|
||||
# Create the event
|
||||
result = self._create_event(
|
||||
service=service,
|
||||
calendar_id=input_data.calendar_id,
|
||||
event_body=event_body,
|
||||
send_notifications=input_data.send_notifications,
|
||||
conference_data_version=1 if input_data.add_google_meet else 0,
|
||||
)
|
||||
|
||||
yield "event_id", result.get("id", "")
|
||||
yield "event_link", result.get("htmlLink", "")
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
@staticmethod
|
||||
def _build_service(credentials: GoogleCredentials, **kwargs):
|
||||
creds = Credentials(
|
||||
token=(
|
||||
credentials.access_token.get_secret_value()
|
||||
if credentials.access_token
|
||||
else None
|
||||
),
|
||||
refresh_token=(
|
||||
credentials.refresh_token.get_secret_value()
|
||||
if credentials.refresh_token
|
||||
else None
|
||||
),
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=Settings().secrets.google_client_id,
|
||||
client_secret=Settings().secrets.google_client_secret,
|
||||
scopes=credentials.scopes,
|
||||
)
|
||||
return build("calendar", "v3", credentials=creds)
|
||||
|
||||
def _create_event(
|
||||
self,
|
||||
service,
|
||||
calendar_id: str,
|
||||
event_body: dict,
|
||||
send_notifications: bool = False,
|
||||
conference_data_version: int = 0,
|
||||
) -> dict:
|
||||
"""Create a new event in Google Calendar."""
|
||||
calendar = service.events()
|
||||
|
||||
# Make the API call
|
||||
result = calendar.insert(
|
||||
calendarId=calendar_id,
|
||||
body=event_body,
|
||||
sendNotifications=send_notifications,
|
||||
conferenceDataVersion=conference_data_version,
|
||||
).execute()
|
||||
|
||||
return result
|
||||
@@ -3,7 +3,7 @@ from googleapiclient.discovery import build
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.settings import AppEnvironment, Settings
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from ._auth import (
|
||||
GOOGLE_OAUTH_IS_CONFIGURED,
|
||||
@@ -36,15 +36,13 @@ class GoogleSheetsReadBlock(Block):
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
settings = Settings()
|
||||
super().__init__(
|
||||
id="5724e902-3635-47e9-a108-aaa0263a4988",
|
||||
description="This block reads data from a Google Sheets spreadsheet.",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=GoogleSheetsReadBlock.Input,
|
||||
output_schema=GoogleSheetsReadBlock.Output,
|
||||
disabled=not GOOGLE_OAUTH_IS_CONFIGURED
|
||||
or settings.config.app_env == AppEnvironment.PRODUCTION,
|
||||
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"spreadsheet_id": "1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms",
|
||||
"range": "Sheet1!A1:B2",
|
||||
|
||||
@@ -82,15 +82,7 @@ class SendWebRequestBlock(Block):
|
||||
json=body if input_data.json_format else None,
|
||||
data=body if not input_data.json_format else None,
|
||||
)
|
||||
|
||||
if input_data.json_format:
|
||||
if response.status_code == 204 or not response.content.strip():
|
||||
result = None
|
||||
else:
|
||||
result = response.json()
|
||||
else:
|
||||
result = response.text
|
||||
|
||||
result = response.json() if input_data.json_format else response.text
|
||||
yield "response", result
|
||||
|
||||
except HTTPError as e:
|
||||
|
||||
@@ -288,13 +288,6 @@ def convert_openai_tool_fmt_to_anthropic(
|
||||
return anthropic_tools
|
||||
|
||||
|
||||
def estimate_token_count(prompt_messages: list[dict]) -> int:
|
||||
char_count = sum(len(str(msg.get("content", ""))) for msg in prompt_messages)
|
||||
message_overhead = len(prompt_messages) * 4
|
||||
estimated_tokens = (char_count // 4) + message_overhead
|
||||
return int(estimated_tokens * 1.2)
|
||||
|
||||
|
||||
def llm_call(
|
||||
credentials: APIKeyCredentials,
|
||||
llm_model: LlmModel,
|
||||
@@ -326,14 +319,7 @@ def llm_call(
|
||||
- completion_tokens: The number of tokens used in the completion.
|
||||
"""
|
||||
provider = llm_model.metadata.provider
|
||||
|
||||
# Calculate available tokens based on context window and input length
|
||||
estimated_input_tokens = estimate_token_count(prompt)
|
||||
context_window = llm_model.context_window
|
||||
model_max_output = llm_model.max_output_tokens or 4096
|
||||
user_max = max_tokens or model_max_output
|
||||
available_tokens = max(context_window - estimated_input_tokens, 0)
|
||||
max_tokens = max(min(available_tokens, model_max_output, user_max), 0)
|
||||
max_tokens = max_tokens or llm_model.max_output_tokens or 4096
|
||||
|
||||
if provider == "openai":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
@@ -489,7 +475,6 @@ def llm_call(
|
||||
model=llm_model.value,
|
||||
prompt=f"{sys_messages}\n\n{usr_messages}",
|
||||
stream=False,
|
||||
options={"num_ctx": max_tokens},
|
||||
)
|
||||
return LLMResponse(
|
||||
raw_response=response.get("response") or "",
|
||||
@@ -788,16 +773,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
prompt.append({"role": "user", "content": retry_prompt})
|
||||
except Exception as e:
|
||||
logger.exception(f"Error calling LLM: {e}")
|
||||
if (
|
||||
"maximum context length" in str(e).lower()
|
||||
or "token limit" in str(e).lower()
|
||||
):
|
||||
if input_data.max_tokens is None:
|
||||
input_data.max_tokens = llm_model.max_output_tokens or 4096
|
||||
input_data.max_tokens = int(input_data.max_tokens * 0.85)
|
||||
logger.debug(
|
||||
f"Reducing max_tokens to {input_data.max_tokens} for next attempt"
|
||||
)
|
||||
retry_prompt = f"Error calling LLM: {e}"
|
||||
finally:
|
||||
self.merge_stats(
|
||||
|
||||
@@ -26,10 +26,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@thread_cached
|
||||
def get_database_manager_client():
|
||||
from backend.executor import DatabaseManagerClient
|
||||
from backend.executor import DatabaseManager
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
return get_service_client(DatabaseManagerClient)
|
||||
return get_service_client(DatabaseManager)
|
||||
|
||||
|
||||
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
|
||||
@@ -246,10 +246,6 @@ class SmartDecisionMakerBlock(Block):
|
||||
test_credentials=llm.TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def cleanup(s: str):
|
||||
return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()
|
||||
|
||||
@staticmethod
|
||||
def _create_block_function_signature(
|
||||
sink_node: "Node", links: list["Link"]
|
||||
@@ -270,7 +266,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
block = sink_node.block
|
||||
|
||||
tool_function: dict[str, Any] = {
|
||||
"name": SmartDecisionMakerBlock.cleanup(block.name),
|
||||
"name": re.sub(r"[^a-zA-Z0-9_-]", "_", block.name).lower(),
|
||||
"description": block.description,
|
||||
}
|
||||
|
||||
@@ -285,7 +281,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
and sink_block_input_schema.model_fields[link.sink_name].description
|
||||
else f"The {link.sink_name} of the tool"
|
||||
)
|
||||
properties[SmartDecisionMakerBlock.cleanup(link.sink_name)] = {
|
||||
properties[link.sink_name.lower()] = {
|
||||
"type": "string",
|
||||
"description": description,
|
||||
}
|
||||
@@ -330,7 +326,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
)
|
||||
|
||||
tool_function: dict[str, Any] = {
|
||||
"name": SmartDecisionMakerBlock.cleanup(sink_graph_meta.name),
|
||||
"name": re.sub(r"[^a-zA-Z0-9_-]", "_", sink_graph_meta.name).lower(),
|
||||
"description": sink_graph_meta.description,
|
||||
}
|
||||
|
||||
@@ -345,7 +341,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
in sink_block_input_schema["properties"][link.sink_name]
|
||||
else f"The {link.sink_name} of the tool"
|
||||
)
|
||||
properties[SmartDecisionMakerBlock.cleanup(link.sink_name)] = {
|
||||
properties[link.sink_name.lower()] = {
|
||||
"type": "string",
|
||||
"description": description,
|
||||
}
|
||||
@@ -507,7 +503,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
|
||||
for arg_name, arg_value in tool_args.items():
|
||||
yield f"tools_^_{tool_name}_~_{arg_name}", arg_value
|
||||
yield f"tools_^_{tool_name}_{arg_name}".lower(), arg_value
|
||||
|
||||
response.prompt.append(response.raw_response)
|
||||
yield "conversations", response.prompt
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, cast
|
||||
|
||||
import stripe
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from prisma import Json
|
||||
from prisma.enums import (
|
||||
CreditRefundRequestStatus,
|
||||
@@ -19,7 +20,7 @@ from prisma.types import (
|
||||
CreditTransactionCreateInput,
|
||||
CreditTransactionWhereInput,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
from backend.data import db
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
@@ -27,17 +28,15 @@ from backend.data.cost import BlockCost
|
||||
from backend.data.model import (
|
||||
AutoTopUpConfig,
|
||||
RefundRequest,
|
||||
TopUpType,
|
||||
TransactionHistory,
|
||||
UserTransaction,
|
||||
)
|
||||
from backend.data.notifications import NotificationEventModel, RefundRequestData
|
||||
from backend.data.user import get_user_by_id, get_user_email_by_id
|
||||
from backend.notifications.notifications import queue_notification_async
|
||||
from backend.server.model import Pagination
|
||||
from backend.server.v2.admin.model import UserHistoryResponse
|
||||
from backend.data.notifications import NotificationEventDTO, RefundRequestData
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.executor.utils import UsageTransactionMetadata
|
||||
from backend.notifications import NotificationManager
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.retry import func_retry
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
@@ -46,17 +45,6 @@ logger = logging.getLogger(__name__)
|
||||
base_url = settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
|
||||
|
||||
class UsageTransactionMetadata(BaseModel):
|
||||
graph_exec_id: str | None = None
|
||||
graph_id: str | None = None
|
||||
node_id: str | None = None
|
||||
node_exec_id: str | None = None
|
||||
block_id: str | None = None
|
||||
block: str | None = None
|
||||
input: dict[str, Any] | None = None
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
class UserCreditBase(ABC):
|
||||
@abstractmethod
|
||||
async def get_credits(self, user_id: str) -> int:
|
||||
@@ -274,7 +262,11 @@ class UserCreditBase(ABC):
|
||||
)
|
||||
return transaction_balance, transaction_time
|
||||
|
||||
@func_retry
|
||||
@retry(
|
||||
stop=stop_after_attempt(5),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=10),
|
||||
reraise=True,
|
||||
)
|
||||
async def _enable_transaction(
|
||||
self,
|
||||
transaction_key: str,
|
||||
@@ -372,17 +364,22 @@ class UserCreditBase(ABC):
|
||||
|
||||
|
||||
class UserCredit(UserCreditBase):
|
||||
@thread_cached
|
||||
def notification_client(self) -> NotificationManager:
|
||||
return get_service_client(NotificationManager)
|
||||
|
||||
async def _send_refund_notification(
|
||||
self,
|
||||
notification_request: RefundRequestData,
|
||||
notification_type: NotificationType,
|
||||
):
|
||||
await queue_notification_async(
|
||||
NotificationEventModel(
|
||||
user_id=notification_request.user_id,
|
||||
type=notification_type,
|
||||
data=notification_request,
|
||||
await asyncio.to_thread(
|
||||
lambda: self.notification_client().queue_notification(
|
||||
NotificationEventDTO(
|
||||
user_id=notification_request.user_id,
|
||||
type=notification_type,
|
||||
data=notification_request.model_dump(),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -412,7 +409,6 @@ class UserCredit(UserCreditBase):
|
||||
# Avoid multiple auto top-ups within the same graph execution.
|
||||
key=f"AUTO-TOP-UP-{user_id}-{metadata.graph_exec_id}",
|
||||
ceiling_balance=auto_top_up.threshold,
|
||||
top_up_type=TopUpType.AUTO,
|
||||
)
|
||||
except Exception as e:
|
||||
# Failed top-up is not critical, we can move on.
|
||||
@@ -422,30 +418,26 @@ class UserCredit(UserCreditBase):
|
||||
|
||||
return balance
|
||||
|
||||
async def top_up_credits(
|
||||
self,
|
||||
user_id: str,
|
||||
amount: int,
|
||||
top_up_type: TopUpType = TopUpType.UNCATEGORIZED,
|
||||
):
|
||||
await self._top_up_credits(
|
||||
user_id=user_id, amount=amount, top_up_type=top_up_type
|
||||
)
|
||||
async def top_up_credits(self, user_id: str, amount: int):
|
||||
await self._top_up_credits(user_id, amount)
|
||||
|
||||
async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
|
||||
try:
|
||||
key = f"REWARD-{user_id}-{step.value}"
|
||||
if not await CreditTransaction.prisma().find_first(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"transactionKey": key,
|
||||
}
|
||||
):
|
||||
await self._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=credits,
|
||||
transaction_type=CreditTransactionType.GRANT,
|
||||
transaction_key=f"REWARD-{user_id}-{step.value}",
|
||||
transaction_key=key,
|
||||
metadata=Json(
|
||||
{"reason": f"Reward for completing {step.value} onboarding step."}
|
||||
),
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# Already rewarded for this step
|
||||
pass
|
||||
|
||||
async def top_up_refund(
|
||||
self, user_id: str, transaction_key: str, metadata: dict[str, str]
|
||||
@@ -610,7 +602,7 @@ class UserCredit(UserCreditBase):
|
||||
|
||||
evidence_text += (
|
||||
f"- {tx.description}: Amount ${tx.amount / 100:.2f} on {tx.transaction_time.isoformat()}, "
|
||||
f"resulting balance ${tx.running_balance / 100:.2f} {additional_comment}\n"
|
||||
f"resulting balance ${tx.balance / 100:.2f} {additional_comment}\n"
|
||||
)
|
||||
evidence_text += (
|
||||
"\nThis evidence demonstrates that the transaction was authorized and that the charged amount was used to render the service as agreed."
|
||||
@@ -629,24 +621,7 @@ class UserCredit(UserCreditBase):
|
||||
amount: int,
|
||||
key: str | None = None,
|
||||
ceiling_balance: int | None = None,
|
||||
top_up_type: TopUpType = TopUpType.UNCATEGORIZED,
|
||||
metadata: dict | None = None,
|
||||
):
|
||||
# init metadata, without sharing it with the world
|
||||
metadata = metadata or {}
|
||||
if not metadata["reason"]:
|
||||
match top_up_type:
|
||||
case TopUpType.MANUAL:
|
||||
metadata["reason"] = {"reason": f"Top up credits for {user_id}"}
|
||||
case TopUpType.AUTO:
|
||||
metadata["reason"] = {
|
||||
"reason": f"Auto top up credits for {user_id}"
|
||||
}
|
||||
case _:
|
||||
metadata["reason"] = {
|
||||
"reason": f"Top up reason unknown for {user_id}"
|
||||
}
|
||||
|
||||
if amount < 0:
|
||||
raise ValueError(f"Top up amount must not be negative: {amount}")
|
||||
|
||||
@@ -669,7 +644,6 @@ class UserCredit(UserCreditBase):
|
||||
is_active=False,
|
||||
transaction_key=key,
|
||||
ceiling_balance=ceiling_balance,
|
||||
metadata=(Json(metadata)),
|
||||
)
|
||||
|
||||
customer_id = await get_stripe_customer_id(user_id)
|
||||
@@ -812,15 +786,10 @@ class UserCredit(UserCreditBase):
|
||||
# Check the Checkout Session's payment_status property
|
||||
# to determine if fulfillment should be performed
|
||||
if checkout_session.payment_status in ["paid", "no_payment_required"]:
|
||||
if payment_intent := checkout_session.payment_intent:
|
||||
assert isinstance(payment_intent, stripe.PaymentIntent)
|
||||
new_transaction_key = payment_intent.id
|
||||
else:
|
||||
new_transaction_key = None
|
||||
|
||||
assert isinstance(checkout_session.payment_intent, stripe.PaymentIntent)
|
||||
await self._enable_transaction(
|
||||
transaction_key=credit_transaction.transactionKey,
|
||||
new_transaction_key=new_transaction_key,
|
||||
new_transaction_key=checkout_session.payment_intent.id,
|
||||
user_id=credit_transaction.userId,
|
||||
metadata=Json(checkout_session),
|
||||
)
|
||||
@@ -853,9 +822,8 @@ class UserCredit(UserCreditBase):
|
||||
take=transaction_count_limit,
|
||||
)
|
||||
|
||||
# doesn't fill current_balance, reason, user_email, admin_email, or extra_data
|
||||
grouped_transactions: dict[str, UserTransaction] = defaultdict(
|
||||
lambda: UserTransaction(user_id=user_id)
|
||||
lambda: UserTransaction()
|
||||
)
|
||||
tx_time = None
|
||||
for t in transactions:
|
||||
@@ -885,7 +853,7 @@ class UserCredit(UserCreditBase):
|
||||
|
||||
if tx_time > gt.transaction_time:
|
||||
gt.transaction_time = tx_time
|
||||
gt.running_balance = t.runningBalance or 0
|
||||
gt.balance = t.runningBalance or 0
|
||||
|
||||
return TransactionHistory(
|
||||
transactions=list(grouped_transactions.values()),
|
||||
@@ -935,7 +903,6 @@ class BetaUserCredit(UserCredit):
|
||||
amount=max(self.num_user_credits_refill - balance, 0),
|
||||
transaction_type=CreditTransactionType.GRANT,
|
||||
transaction_key=f"MONTHLY-CREDIT-TOP-UP-{cur_time}",
|
||||
metadata=Json({"reason": "Monthly credit refill"}),
|
||||
)
|
||||
return balance
|
||||
except UniqueViolationError:
|
||||
@@ -945,7 +912,7 @@ class BetaUserCredit(UserCredit):
|
||||
|
||||
class DisabledUserCredit(UserCreditBase):
|
||||
async def get_credits(self, *args, **kwargs) -> int:
|
||||
return 100
|
||||
return 0
|
||||
|
||||
async def get_transaction_history(self, *args, **kwargs) -> TransactionHistory:
|
||||
return TransactionHistory(transactions=[], next_transaction_time=None)
|
||||
@@ -1023,81 +990,3 @@ async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
|
||||
return AutoTopUpConfig(threshold=0, amount=0)
|
||||
|
||||
return AutoTopUpConfig.model_validate(user.topUpConfig)
|
||||
|
||||
|
||||
async def admin_get_user_history(
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
search: str | None = None,
|
||||
transaction_filter: CreditTransactionType | None = None,
|
||||
) -> UserHistoryResponse:
|
||||
|
||||
if page < 1 or page_size < 1:
|
||||
raise ValueError("Invalid pagination input")
|
||||
|
||||
where_clause: CreditTransactionWhereInput = {}
|
||||
if transaction_filter:
|
||||
where_clause["type"] = transaction_filter
|
||||
if search:
|
||||
where_clause["OR"] = [
|
||||
{"userId": {"contains": search, "mode": "insensitive"}},
|
||||
{"User": {"is": {"email": {"contains": search, "mode": "insensitive"}}}},
|
||||
{"User": {"is": {"name": {"contains": search, "mode": "insensitive"}}}},
|
||||
]
|
||||
transactions = await CreditTransaction.prisma().find_many(
|
||||
where=where_clause,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
include={"User": True},
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
total = await CreditTransaction.prisma().count(where=where_clause)
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
history = []
|
||||
for tx in transactions:
|
||||
admin_id = ""
|
||||
admin_email = ""
|
||||
reason = ""
|
||||
|
||||
metadata: dict = cast(dict, tx.metadata) or {}
|
||||
|
||||
if metadata:
|
||||
admin_id = metadata.get("admin_id")
|
||||
admin_email = (
|
||||
(await get_user_email_by_id(admin_id) or f"Unknown Admin: {admin_id}")
|
||||
if admin_id
|
||||
else ""
|
||||
)
|
||||
reason = metadata.get("reason", "No reason provided")
|
||||
|
||||
balance, last_update = await get_user_credit_model()._get_credits(tx.userId)
|
||||
|
||||
history.append(
|
||||
UserTransaction(
|
||||
transaction_key=tx.transactionKey,
|
||||
transaction_time=tx.createdAt,
|
||||
transaction_type=tx.type,
|
||||
amount=tx.amount,
|
||||
current_balance=balance,
|
||||
running_balance=tx.runningBalance or 0,
|
||||
user_id=tx.userId,
|
||||
user_email=(
|
||||
tx.User.email
|
||||
if tx.User
|
||||
else (await get_user_by_id(tx.userId)).email
|
||||
),
|
||||
reason=reason,
|
||||
admin_email=admin_email,
|
||||
extra_data=str(metadata),
|
||||
)
|
||||
)
|
||||
return UserHistoryResponse(
|
||||
history=history,
|
||||
pagination=Pagination(
|
||||
total_items=total,
|
||||
total_pages=total_pages,
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -30,7 +30,7 @@ from prisma.types import (
|
||||
AgentNodeExecutionUpdateInput,
|
||||
AgentNodeExecutionWhereInput,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import Field
|
||||
|
||||
from backend.server.v2.store.exceptions import DatabaseError
|
||||
@@ -69,55 +69,10 @@ class GraphExecutionMeta(BaseDbModel):
|
||||
ended_at: datetime
|
||||
|
||||
class Stats(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
extra="allow",
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
cost: int = Field(
|
||||
default=0,
|
||||
description="Execution cost (cents)",
|
||||
)
|
||||
duration: float = Field(
|
||||
default=0,
|
||||
description="Seconds from start to end of run",
|
||||
)
|
||||
duration_cpu_only: float = Field(
|
||||
default=0,
|
||||
description="CPU sec of duration",
|
||||
)
|
||||
node_exec_time: float = Field(
|
||||
default=0,
|
||||
description="Seconds of total node runtime",
|
||||
)
|
||||
node_exec_time_cpu_only: float = Field(
|
||||
default=0,
|
||||
description="CPU sec of node_exec_time",
|
||||
)
|
||||
node_exec_count: int = Field(
|
||||
default=0,
|
||||
description="Number of node executions",
|
||||
)
|
||||
node_error_count: int = Field(
|
||||
default=0,
|
||||
description="Number of node errors",
|
||||
)
|
||||
error: str | None = Field(
|
||||
default=None,
|
||||
description="Error message if any",
|
||||
)
|
||||
|
||||
def to_db(self) -> GraphExecutionStats:
|
||||
return GraphExecutionStats(
|
||||
cost=self.cost,
|
||||
walltime=self.duration,
|
||||
cputime=self.duration_cpu_only,
|
||||
nodes_walltime=self.node_exec_time,
|
||||
nodes_cputime=self.node_exec_time_cpu_only,
|
||||
node_count=self.node_exec_count,
|
||||
node_error_count=self.node_error_count,
|
||||
error=self.error,
|
||||
)
|
||||
cost: int = Field(..., description="Execution cost (cents)")
|
||||
duration: float = Field(..., description="Seconds from start to end of run")
|
||||
node_exec_time: float = Field(..., description="Seconds of total node runtime")
|
||||
node_exec_count: int = Field(..., description="Number of node executions")
|
||||
|
||||
stats: Stats | None
|
||||
|
||||
@@ -151,16 +106,8 @@ class GraphExecutionMeta(BaseDbModel):
|
||||
GraphExecutionMeta.Stats(
|
||||
cost=stats.cost,
|
||||
duration=stats.walltime,
|
||||
duration_cpu_only=stats.cputime,
|
||||
node_exec_time=stats.nodes_walltime,
|
||||
node_exec_time_cpu_only=stats.nodes_cputime,
|
||||
node_exec_count=stats.node_count,
|
||||
node_error_count=stats.node_error_count,
|
||||
error=(
|
||||
str(stats.error)
|
||||
if isinstance(stats.error, Exception)
|
||||
else stats.error
|
||||
),
|
||||
)
|
||||
if stats
|
||||
else None
|
||||
@@ -261,6 +208,18 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
graph_id=self.graph_id,
|
||||
graph_version=self.graph_version or 0,
|
||||
graph_exec_id=self.id,
|
||||
start_node_execs=[
|
||||
NodeExecutionEntry(
|
||||
user_id=self.user_id,
|
||||
graph_exec_id=node_exec.graph_exec_id,
|
||||
graph_id=node_exec.graph_id,
|
||||
node_exec_id=node_exec.node_exec_id,
|
||||
node_id=node_exec.node_id,
|
||||
block_id=node_exec.block_id,
|
||||
data=node_exec.input_data,
|
||||
)
|
||||
for node_exec in self.node_executions
|
||||
],
|
||||
node_credentials_input_map={}, # FIXME
|
||||
)
|
||||
|
||||
@@ -321,28 +280,13 @@ class NodeExecutionResult(BaseModel):
|
||||
end_time=_node_exec.endedTime,
|
||||
)
|
||||
|
||||
def to_node_execution_entry(self) -> "NodeExecutionEntry":
|
||||
return NodeExecutionEntry(
|
||||
user_id=self.user_id,
|
||||
graph_exec_id=self.graph_exec_id,
|
||||
graph_id=self.graph_id,
|
||||
node_exec_id=self.node_exec_id,
|
||||
node_id=self.node_id,
|
||||
block_id=self.block_id,
|
||||
inputs=self.input_data,
|
||||
)
|
||||
|
||||
|
||||
# --------------------- Model functions --------------------- #
|
||||
|
||||
|
||||
async def get_graph_executions(
|
||||
graph_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
statuses: list[ExecutionStatus] | None = None,
|
||||
created_time_gte: datetime | None = None,
|
||||
created_time_lte: datetime | None = None,
|
||||
limit: int | None = None,
|
||||
graph_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> list[GraphExecutionMeta]:
|
||||
where_filter: AgentGraphExecutionWhereInput = {
|
||||
"isDeleted": False,
|
||||
@@ -351,18 +295,10 @@ async def get_graph_executions(
|
||||
where_filter["userId"] = user_id
|
||||
if graph_id:
|
||||
where_filter["agentGraphId"] = graph_id
|
||||
if created_time_gte or created_time_lte:
|
||||
where_filter["createdAt"] = {
|
||||
"gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc),
|
||||
"lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc),
|
||||
}
|
||||
if statuses:
|
||||
where_filter["OR"] = [{"executionStatus": status} for status in statuses]
|
||||
|
||||
executions = await AgentGraphExecution.prisma().find_many(
|
||||
where=where_filter,
|
||||
order={"createdAt": "desc"},
|
||||
take=limit,
|
||||
)
|
||||
return [GraphExecutionMeta.from_db(execution) for execution in executions]
|
||||
|
||||
@@ -556,12 +492,21 @@ async def upsert_execution_output(
|
||||
async def update_graph_execution_start_time(
|
||||
graph_exec_id: str,
|
||||
) -> GraphExecution | None:
|
||||
res = await AgentGraphExecution.prisma().update(
|
||||
where={"id": graph_exec_id},
|
||||
count = await AgentGraphExecution.prisma().update_many(
|
||||
where={
|
||||
"id": graph_exec_id,
|
||||
"executionStatus": ExecutionStatus.QUEUED,
|
||||
},
|
||||
data={
|
||||
"executionStatus": ExecutionStatus.RUNNING,
|
||||
"startedAt": datetime.now(tz=timezone.utc),
|
||||
},
|
||||
)
|
||||
if count == 0:
|
||||
return None
|
||||
|
||||
res = await AgentGraphExecution.prisma().find_unique(
|
||||
where={"id": graph_exec_id},
|
||||
include=GRAPH_EXECUTION_INCLUDE,
|
||||
)
|
||||
return GraphExecution.from_db(res) if res else None
|
||||
@@ -680,9 +625,8 @@ async def delete_graph_execution(
|
||||
)
|
||||
|
||||
|
||||
async def get_node_executions(
|
||||
async def get_node_execution_results(
|
||||
graph_exec_id: str,
|
||||
node_id: str | None = None,
|
||||
block_ids: list[str] | None = None,
|
||||
statuses: list[ExecutionStatus] | None = None,
|
||||
limit: int | None = None,
|
||||
@@ -690,8 +634,6 @@ async def get_node_executions(
|
||||
where_clause: AgentNodeExecutionWhereInput = {
|
||||
"agentGraphExecutionId": graph_exec_id,
|
||||
}
|
||||
if node_id:
|
||||
where_clause["agentNodeId"] = node_id
|
||||
if block_ids:
|
||||
where_clause["Node"] = {"is": {"agentBlockId": {"in": block_ids}}}
|
||||
if statuses:
|
||||
@@ -706,6 +648,28 @@ async def get_node_executions(
|
||||
return res
|
||||
|
||||
|
||||
async def get_graph_executions_in_timerange(
|
||||
user_id: str, start_time: str, end_time: str
|
||||
) -> list[GraphExecution]:
|
||||
try:
|
||||
executions = await AgentGraphExecution.prisma().find_many(
|
||||
where={
|
||||
"startedAt": {
|
||||
"gte": datetime.fromisoformat(start_time),
|
||||
"lte": datetime.fromisoformat(end_time),
|
||||
},
|
||||
"userId": user_id,
|
||||
"isDeleted": False,
|
||||
},
|
||||
include=GRAPH_EXECUTION_INCLUDE,
|
||||
)
|
||||
return [GraphExecution.from_db(execution) for execution in executions]
|
||||
except Exception as e:
|
||||
raise DatabaseError(
|
||||
f"Failed to get executions in timerange {start_time} to {end_time} for user {user_id}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
async def get_latest_node_execution(
|
||||
node_id: str, graph_eid: str
|
||||
) -> NodeExecutionResult | None:
|
||||
@@ -726,6 +690,20 @@ async def get_latest_node_execution(
|
||||
return NodeExecutionResult.from_db(execution)
|
||||
|
||||
|
||||
async def get_incomplete_node_executions(
|
||||
node_id: str, graph_eid: str
|
||||
) -> list[NodeExecutionResult]:
|
||||
executions = await AgentNodeExecution.prisma().find_many(
|
||||
where={
|
||||
"agentNodeId": node_id,
|
||||
"agentGraphExecutionId": graph_eid,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
},
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
)
|
||||
return [NodeExecutionResult.from_db(execution) for execution in executions]
|
||||
|
||||
|
||||
# ----------------- Execution Infrastructure ----------------- #
|
||||
|
||||
|
||||
@@ -734,6 +712,7 @@ class GraphExecutionEntry(BaseModel):
|
||||
graph_exec_id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
start_node_execs: list["NodeExecutionEntry"]
|
||||
node_credentials_input_map: Optional[dict[str, dict[str, CredentialsMetaInput]]]
|
||||
|
||||
|
||||
@@ -744,7 +723,7 @@ class NodeExecutionEntry(BaseModel):
|
||||
node_exec_id: str
|
||||
node_id: str
|
||||
block_id: str
|
||||
inputs: BlockInput
|
||||
data: BlockInput
|
||||
|
||||
|
||||
class ExecutionQueue(Generic[T]):
|
||||
|
||||
@@ -172,8 +172,6 @@ class BaseGraph(BaseDbModel):
|
||||
description: str
|
||||
nodes: list[Node] = []
|
||||
links: list[Link] = []
|
||||
forked_from_id: str | None = None
|
||||
forked_from_version: int | None = None
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
@@ -199,6 +197,11 @@ class BaseGraph(BaseDbModel):
|
||||
)
|
||||
)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def credentials_input_schema(self) -> dict[str, Any]:
|
||||
return self._credentials_input_schema.jsonschema()
|
||||
|
||||
@staticmethod
|
||||
def _generate_schema(
|
||||
*props: tuple[type[AgentInputBlock.Input] | type[AgentOutputBlock.Input], dict],
|
||||
@@ -231,15 +234,6 @@ class BaseGraph(BaseDbModel):
|
||||
"required": [p.name for p in schema_fields if p.value is None],
|
||||
}
|
||||
|
||||
|
||||
class Graph(BaseGraph):
|
||||
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def credentials_input_schema(self) -> dict[str, Any]:
|
||||
return self._credentials_input_schema.jsonschema()
|
||||
|
||||
@property
|
||||
def _credentials_input_schema(self) -> type[BlockSchema]:
|
||||
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||
@@ -318,14 +312,17 @@ class Graph(BaseGraph):
|
||||
),
|
||||
(node.id, field_name),
|
||||
)
|
||||
for graph in [self] + self.sub_graphs
|
||||
for node in graph.nodes
|
||||
for node in self.nodes
|
||||
for field_name, field_info in node.block.input_schema.get_credentials_fields_info().items()
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
class Graph(BaseGraph):
|
||||
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs, only used in export
|
||||
|
||||
|
||||
class GraphModel(Graph):
|
||||
user_id: str
|
||||
nodes: list[NodeModel] = [] # type: ignore
|
||||
@@ -401,7 +398,7 @@ class GraphModel(Graph):
|
||||
if node.block_id != AgentExecutorBlock().id:
|
||||
continue
|
||||
node.input_default["user_id"] = user_id
|
||||
node.input_default.setdefault("inputs", {})
|
||||
node.input_default.setdefault("data", {})
|
||||
if (graph_id := node.input_default.get("graph_id")) in graph_id_map:
|
||||
node.input_default["graph_id"] = graph_id_map[graph_id]
|
||||
|
||||
@@ -412,13 +409,10 @@ class GraphModel(Graph):
|
||||
|
||||
@staticmethod
|
||||
def _validate_graph(graph: BaseGraph, for_run: bool = False):
|
||||
def is_tool_pin(name: str) -> bool:
|
||||
return name.startswith("tools_^_")
|
||||
|
||||
def sanitize(name):
|
||||
sanitized_name = name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
|
||||
if is_tool_pin(sanitized_name):
|
||||
return "tools"
|
||||
if sanitized_name.startswith("tools_^_"):
|
||||
return sanitized_name.split("_^_")[0]
|
||||
return sanitized_name
|
||||
|
||||
# Validate smart decision maker nodes
|
||||
@@ -428,6 +422,10 @@ class GraphModel(Graph):
|
||||
if (block := get_block(node.block_id)) is not None
|
||||
}
|
||||
|
||||
for node in graph.nodes:
|
||||
if (block := nodes_block.get(node.id)) is None:
|
||||
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
|
||||
|
||||
input_links = defaultdict(list)
|
||||
|
||||
for link in graph.links:
|
||||
@@ -442,8 +440,8 @@ class GraphModel(Graph):
|
||||
[sanitize(name) for name in node.input_default]
|
||||
+ [sanitize(link.sink_name) for link in input_links.get(node.id, [])]
|
||||
)
|
||||
InputSchema = block.input_schema
|
||||
for name in (required_fields := InputSchema.get_required_fields()):
|
||||
input_schema = block.input_schema
|
||||
for name in (required_fields := input_schema.get_required_fields()):
|
||||
if (
|
||||
name not in provided_inputs
|
||||
# Webhook payload is passed in by ExecutionManager
|
||||
@@ -453,7 +451,7 @@ class GraphModel(Graph):
|
||||
in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
|
||||
)
|
||||
# Checking availability of credentials is done by ExecutionManager
|
||||
and name not in InputSchema.get_credentials_fields()
|
||||
and name not in input_schema.get_credentials_fields()
|
||||
# Validate only I/O nodes, or validate everything when executing
|
||||
and (
|
||||
for_run
|
||||
@@ -480,43 +478,37 @@ class GraphModel(Graph):
|
||||
)
|
||||
|
||||
# Get input schema properties and check dependencies
|
||||
input_fields = InputSchema.model_fields
|
||||
input_fields = input_schema.model_fields
|
||||
|
||||
def has_value(node: Node, name: str):
|
||||
def has_value(name):
|
||||
return (
|
||||
name in node.input_default
|
||||
node is not None
|
||||
and name in node.input_default
|
||||
and node.input_default[name] is not None
|
||||
and str(node.input_default[name]).strip() != ""
|
||||
) or (name in input_fields and input_fields[name].default is not None)
|
||||
|
||||
# Validate dependencies between fields
|
||||
for field_name in input_fields.keys():
|
||||
field_json_schema = InputSchema.get_field_schema(field_name)
|
||||
|
||||
dependencies: list[str] = []
|
||||
|
||||
# Check regular field dependencies (only pre graph execution)
|
||||
if for_run:
|
||||
dependencies.extend(field_json_schema.get("depends_on", []))
|
||||
|
||||
# Require presence of credentials discriminator (always).
|
||||
# The `discriminator` is either the name of a sibling field (str),
|
||||
# or an object that discriminates between possible types for this field:
|
||||
# {"propertyName": prop_name, "mapping": {prop_value: sub_schema}}
|
||||
if (
|
||||
discriminator := field_json_schema.get("discriminator")
|
||||
) and isinstance(discriminator, str):
|
||||
dependencies.append(discriminator)
|
||||
|
||||
if not dependencies:
|
||||
for field_name, field_info in input_fields.items():
|
||||
# Apply input dependency validation only on run & field with depends_on
|
||||
json_schema_extra = field_info.json_schema_extra or {}
|
||||
if not (
|
||||
for_run
|
||||
and isinstance(json_schema_extra, dict)
|
||||
and (
|
||||
dependencies := cast(
|
||||
list[str], json_schema_extra.get("depends_on", [])
|
||||
)
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
# Check if dependent field has value in input_default
|
||||
field_has_value = has_value(node, field_name)
|
||||
field_has_value = has_value(field_name)
|
||||
field_is_required = field_name in required_fields
|
||||
|
||||
# Check for missing dependencies when dependent field is present
|
||||
missing_deps = [dep for dep in dependencies if not has_value(node, dep)]
|
||||
missing_deps = [dep for dep in dependencies if not has_value(dep)]
|
||||
if missing_deps and (field_has_value or field_is_required):
|
||||
raise ValueError(
|
||||
f"Node {block.name} #{node.id}: Field `{field_name}` requires [{', '.join(missing_deps)}] to be set"
|
||||
@@ -561,7 +553,7 @@ class GraphModel(Graph):
|
||||
if block.block_type not in [BlockType.AGENT]
|
||||
else vals.get("input_schema", {}).get("properties", {}).keys()
|
||||
)
|
||||
if sanitized_name not in fields and not is_tool_pin(name):
|
||||
if sanitized_name not in fields and not name.startswith("tools_^_"):
|
||||
fields_msg = f"Allowed fields: {fields}"
|
||||
raise ValueError(f"{prefix}, `{name}` invalid, {fields_msg}")
|
||||
|
||||
@@ -578,8 +570,6 @@ class GraphModel(Graph):
|
||||
id=graph.id,
|
||||
user_id=graph.userId if not for_export else "",
|
||||
version=graph.version,
|
||||
forked_from_id=graph.forkedFromId,
|
||||
forked_from_version=graph.forkedFromVersion,
|
||||
is_active=graph.isActive,
|
||||
name=graph.name or "",
|
||||
description=graph.description or "",
|
||||
@@ -692,7 +682,6 @@ async def get_graph(
|
||||
version: int | None = None,
|
||||
user_id: str | None = None,
|
||||
for_export: bool = False,
|
||||
include_subgraphs: bool = False,
|
||||
) -> GraphModel | None:
|
||||
"""
|
||||
Retrieves a graph from the DB.
|
||||
@@ -729,58 +718,6 @@ async def get_graph(
|
||||
):
|
||||
return None
|
||||
|
||||
if include_subgraphs or for_export:
|
||||
sub_graphs = await get_sub_graphs(graph)
|
||||
return GraphModel.from_db(
|
||||
graph=graph,
|
||||
sub_graphs=sub_graphs,
|
||||
for_export=for_export,
|
||||
)
|
||||
|
||||
return GraphModel.from_db(graph, for_export)
|
||||
|
||||
|
||||
async def get_graph_as_admin(
|
||||
graph_id: str,
|
||||
version: int | None = None,
|
||||
user_id: str | None = None,
|
||||
for_export: bool = False,
|
||||
) -> GraphModel | None:
|
||||
"""
|
||||
Intentionally parallels the get_graph but should only be used for admin tasks, because can return any graph that's been submitted
|
||||
Retrieves a graph from the DB.
|
||||
Defaults to the version with `is_active` if `version` is not passed.
|
||||
|
||||
Returns `None` if the record is not found.
|
||||
"""
|
||||
logger.warning(f"Getting {graph_id=} {version=} as ADMIN {user_id=} {for_export=}")
|
||||
where_clause: AgentGraphWhereInput = {
|
||||
"id": graph_id,
|
||||
}
|
||||
|
||||
if version is not None:
|
||||
where_clause["version"] = version
|
||||
|
||||
graph = await AgentGraph.prisma().find_first(
|
||||
where=where_clause,
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
order={"version": "desc"},
|
||||
)
|
||||
|
||||
# For access, the graph must be owned by the user or listed in the store
|
||||
if graph is None or (
|
||||
graph.userId != user_id
|
||||
and not (
|
||||
await StoreListingVersion.prisma().find_first(
|
||||
where={
|
||||
"agentGraphId": graph_id,
|
||||
"agentGraphVersion": version or graph.version,
|
||||
}
|
||||
)
|
||||
)
|
||||
):
|
||||
return None
|
||||
|
||||
if for_export:
|
||||
sub_graphs = await get_sub_graphs(graph)
|
||||
return GraphModel.from_db(
|
||||
@@ -910,27 +847,6 @@ async def create_graph(graph: Graph, user_id: str) -> GraphModel:
|
||||
raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
|
||||
|
||||
|
||||
async def fork_graph(graph_id: str, graph_version: int, user_id: str) -> GraphModel:
|
||||
"""
|
||||
Forks a graph by copying it and all its nodes and links to a new graph.
|
||||
"""
|
||||
async with transaction() as tx:
|
||||
graph = await get_graph(graph_id, graph_version, user_id, True)
|
||||
if not graph:
|
||||
raise ValueError(f"Graph {graph_id} v{graph_version} not found")
|
||||
|
||||
# Set forked from ID and version as itself as it's about ot be copied
|
||||
graph.forked_from_id = graph.id
|
||||
graph.forked_from_version = graph.version
|
||||
graph.name = f"{graph.name} (copy)"
|
||||
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
|
||||
graph.validate_graph(for_run=False)
|
||||
|
||||
await __create_graph(tx, graph, user_id)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
async def __create_graph(tx, graph: Graph, user_id: str):
|
||||
graphs = [graph] + graph.sub_graphs
|
||||
|
||||
@@ -943,8 +859,6 @@ async def __create_graph(tx, graph: Graph, user_id: str):
|
||||
description=graph.description,
|
||||
isActive=graph.is_active,
|
||||
userId=user_id,
|
||||
forkedFromId=graph.forked_from_id,
|
||||
forkedFromVersion=graph.forked_from_version,
|
||||
)
|
||||
for graph in graphs
|
||||
]
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import enum
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
@@ -450,12 +449,6 @@ class ContributorDetails(BaseModel):
|
||||
name: str = Field(title="Name", description="The name of the contributor.")
|
||||
|
||||
|
||||
class TopUpType(enum.Enum):
|
||||
AUTO = "AUTO"
|
||||
MANUAL = "MANUAL"
|
||||
UNCATEGORIZED = "UNCATEGORIZED"
|
||||
|
||||
|
||||
class AutoTopUpConfig(BaseModel):
|
||||
amount: int
|
||||
"""Amount of credits to top up."""
|
||||
@@ -468,18 +461,12 @@ class UserTransaction(BaseModel):
|
||||
transaction_time: datetime = datetime.min.replace(tzinfo=timezone.utc)
|
||||
transaction_type: CreditTransactionType = CreditTransactionType.USAGE
|
||||
amount: int = 0
|
||||
running_balance: int = 0
|
||||
current_balance: int = 0
|
||||
balance: int = 0
|
||||
description: str | None = None
|
||||
usage_graph_id: str | None = None
|
||||
usage_execution_id: str | None = None
|
||||
usage_node_count: int = 0
|
||||
usage_start_time: datetime = datetime.max.replace(tzinfo=timezone.utc)
|
||||
user_id: str
|
||||
user_email: str | None = None
|
||||
reason: str | None = None
|
||||
admin_email: str | None = None
|
||||
extra_data: str | None = None
|
||||
|
||||
|
||||
class TransactionHistory(BaseModel):
|
||||
|
||||
@@ -189,14 +189,26 @@ NotificationData = Annotated[
|
||||
]
|
||||
|
||||
|
||||
class BaseEventModel(BaseModel):
|
||||
type: NotificationType
|
||||
class NotificationEventDTO(BaseModel):
|
||||
user_id: str
|
||||
type: NotificationType
|
||||
data: dict
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
|
||||
retry_count: int = 0
|
||||
|
||||
|
||||
class SummaryParamsEventDTO(BaseModel):
|
||||
user_id: str
|
||||
type: NotificationType
|
||||
data: dict
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
|
||||
|
||||
|
||||
class NotificationEventModel(BaseEventModel, Generic[NotificationDataType_co]):
|
||||
class NotificationEventModel(BaseModel, Generic[NotificationDataType_co]):
|
||||
user_id: str
|
||||
type: NotificationType
|
||||
data: NotificationDataType_co
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
|
||||
|
||||
@property
|
||||
def strategy(self) -> QueueType:
|
||||
@@ -213,8 +225,11 @@ class NotificationEventModel(BaseEventModel, Generic[NotificationDataType_co]):
|
||||
return NotificationTypeOverride(self.type).template
|
||||
|
||||
|
||||
class SummaryParamsEventModel(BaseEventModel, Generic[SummaryParamsType_co]):
|
||||
class SummaryParamsEventModel(BaseModel, Generic[SummaryParamsType_co]):
|
||||
user_id: str
|
||||
type: NotificationType
|
||||
data: SummaryParamsType_co
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
|
||||
|
||||
|
||||
def get_notif_data_type(
|
||||
@@ -369,7 +384,7 @@ class UserNotificationBatchDTO(BaseModel):
|
||||
|
||||
def get_batch_delay(notification_type: NotificationType) -> timedelta:
|
||||
return {
|
||||
NotificationType.AGENT_RUN: timedelta(days=1),
|
||||
NotificationType.AGENT_RUN: timedelta(minutes=60),
|
||||
NotificationType.ZERO_BALANCE: timedelta(minutes=60),
|
||||
NotificationType.LOW_BALANCE: timedelta(minutes=60),
|
||||
NotificationType.BLOCK_EXECUTION_FAILED: timedelta(minutes=60),
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from .database import DatabaseManager, DatabaseManagerClient
|
||||
from .database import DatabaseManager
|
||||
from .manager import ExecutionManager
|
||||
from .scheduler import Scheduler
|
||||
|
||||
__all__ = [
|
||||
"DatabaseManager",
|
||||
"DatabaseManagerClient",
|
||||
"ExecutionManager",
|
||||
"Scheduler",
|
||||
]
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import logging
|
||||
from typing import Callable, Concatenate, ParamSpec, TypeVar, cast
|
||||
|
||||
from backend.data import db
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
from backend.data.execution import (
|
||||
create_graph_execution,
|
||||
get_graph_execution,
|
||||
get_graph_execution_meta,
|
||||
get_graph_executions,
|
||||
get_incomplete_node_executions,
|
||||
get_latest_node_execution,
|
||||
get_node_executions,
|
||||
get_node_execution_results,
|
||||
update_graph_execution_start_time,
|
||||
update_graph_execution_stats,
|
||||
update_node_execution_stats,
|
||||
@@ -41,14 +39,12 @@ from backend.data.user import (
|
||||
update_user_integrations,
|
||||
update_user_metadata,
|
||||
)
|
||||
from backend.util.service import AppService, AppServiceClient, endpoint_to_sync, expose
|
||||
from backend.util.service import AppService, exposed_run_and_wait
|
||||
from backend.util.settings import Config
|
||||
|
||||
config = Config()
|
||||
_user_credit_model = get_user_credit_model()
|
||||
logger = logging.getLogger(__name__)
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
async def _spend_credits(
|
||||
@@ -57,10 +53,6 @@ async def _spend_credits(
|
||||
return await _user_credit_model.spend_credits(user_id, cost, metadata)
|
||||
|
||||
|
||||
async def _get_credits(user_id: str) -> int:
|
||||
return await _user_credit_model.get_credits(user_id)
|
||||
|
||||
|
||||
class DatabaseManager(AppService):
|
||||
|
||||
def run_service(self) -> None:
|
||||
@@ -77,115 +69,58 @@ class DatabaseManager(AppService):
|
||||
def get_port(cls) -> int:
|
||||
return config.database_api_port
|
||||
|
||||
@staticmethod
|
||||
def _(
|
||||
f: Callable[P, R], name: str | None = None
|
||||
) -> Callable[Concatenate[object, P], R]:
|
||||
if name is not None:
|
||||
f.__name__ = name
|
||||
return cast(Callable[Concatenate[object, P], R], expose(f))
|
||||
|
||||
# Executions
|
||||
get_graph_execution = _(get_graph_execution)
|
||||
get_graph_executions = _(get_graph_executions)
|
||||
get_graph_execution_meta = _(get_graph_execution_meta)
|
||||
create_graph_execution = _(create_graph_execution)
|
||||
get_node_executions = _(get_node_executions)
|
||||
get_latest_node_execution = _(get_latest_node_execution)
|
||||
update_node_execution_status = _(update_node_execution_status)
|
||||
update_node_execution_status_batch = _(update_node_execution_status_batch)
|
||||
update_graph_execution_start_time = _(update_graph_execution_start_time)
|
||||
update_graph_execution_stats = _(update_graph_execution_stats)
|
||||
update_node_execution_stats = _(update_node_execution_stats)
|
||||
upsert_execution_input = _(upsert_execution_input)
|
||||
upsert_execution_output = _(upsert_execution_output)
|
||||
get_graph_execution = exposed_run_and_wait(get_graph_execution)
|
||||
create_graph_execution = exposed_run_and_wait(create_graph_execution)
|
||||
get_node_execution_results = exposed_run_and_wait(get_node_execution_results)
|
||||
get_incomplete_node_executions = exposed_run_and_wait(
|
||||
get_incomplete_node_executions
|
||||
)
|
||||
get_latest_node_execution = exposed_run_and_wait(get_latest_node_execution)
|
||||
update_node_execution_status = exposed_run_and_wait(update_node_execution_status)
|
||||
update_node_execution_status_batch = exposed_run_and_wait(
|
||||
update_node_execution_status_batch
|
||||
)
|
||||
update_graph_execution_start_time = exposed_run_and_wait(
|
||||
update_graph_execution_start_time
|
||||
)
|
||||
update_graph_execution_stats = exposed_run_and_wait(update_graph_execution_stats)
|
||||
update_node_execution_stats = exposed_run_and_wait(update_node_execution_stats)
|
||||
upsert_execution_input = exposed_run_and_wait(upsert_execution_input)
|
||||
upsert_execution_output = exposed_run_and_wait(upsert_execution_output)
|
||||
|
||||
# Graphs
|
||||
get_node = _(get_node)
|
||||
get_graph = _(get_graph)
|
||||
get_connected_output_nodes = _(get_connected_output_nodes)
|
||||
get_graph_metadata = _(get_graph_metadata)
|
||||
get_node = exposed_run_and_wait(get_node)
|
||||
get_graph = exposed_run_and_wait(get_graph)
|
||||
get_connected_output_nodes = exposed_run_and_wait(get_connected_output_nodes)
|
||||
get_graph_metadata = exposed_run_and_wait(get_graph_metadata)
|
||||
|
||||
# Credits
|
||||
spend_credits = _(_spend_credits, name="spend_credits")
|
||||
get_credits = _(_get_credits, name="get_credits")
|
||||
spend_credits = exposed_run_and_wait(_spend_credits)
|
||||
|
||||
# User + User Metadata + User Integrations
|
||||
get_user_metadata = _(get_user_metadata)
|
||||
update_user_metadata = _(update_user_metadata)
|
||||
get_user_integrations = _(get_user_integrations)
|
||||
update_user_integrations = _(update_user_integrations)
|
||||
get_user_metadata = exposed_run_and_wait(get_user_metadata)
|
||||
update_user_metadata = exposed_run_and_wait(update_user_metadata)
|
||||
get_user_integrations = exposed_run_and_wait(get_user_integrations)
|
||||
update_user_integrations = exposed_run_and_wait(update_user_integrations)
|
||||
|
||||
# User Comms - async
|
||||
get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
|
||||
get_user_email_by_id = _(get_user_email_by_id)
|
||||
get_user_email_verification = _(get_user_email_verification)
|
||||
get_user_notification_preference = _(get_user_notification_preference)
|
||||
get_active_user_ids_in_timerange = exposed_run_and_wait(
|
||||
get_active_user_ids_in_timerange
|
||||
)
|
||||
get_user_email_by_id = exposed_run_and_wait(get_user_email_by_id)
|
||||
get_user_email_verification = exposed_run_and_wait(get_user_email_verification)
|
||||
get_user_notification_preference = exposed_run_and_wait(
|
||||
get_user_notification_preference
|
||||
)
|
||||
|
||||
# Notifications - async
|
||||
create_or_add_to_user_notification_batch = _(
|
||||
create_or_add_to_user_notification_batch = exposed_run_and_wait(
|
||||
create_or_add_to_user_notification_batch
|
||||
)
|
||||
empty_user_notification_batch = _(empty_user_notification_batch)
|
||||
get_all_batches_by_type = _(get_all_batches_by_type)
|
||||
get_user_notification_batch = _(get_user_notification_batch)
|
||||
get_user_notification_oldest_message_in_batch = _(
|
||||
empty_user_notification_batch = exposed_run_and_wait(empty_user_notification_batch)
|
||||
get_all_batches_by_type = exposed_run_and_wait(get_all_batches_by_type)
|
||||
get_user_notification_batch = exposed_run_and_wait(get_user_notification_batch)
|
||||
get_user_notification_oldest_message_in_batch = exposed_run_and_wait(
|
||||
get_user_notification_oldest_message_in_batch
|
||||
)
|
||||
|
||||
|
||||
class DatabaseManagerClient(AppServiceClient):
|
||||
d = DatabaseManager
|
||||
_ = endpoint_to_sync
|
||||
|
||||
@classmethod
|
||||
def get_service_type(cls):
|
||||
return DatabaseManager
|
||||
|
||||
# Executions
|
||||
get_graph_execution = _(d.get_graph_execution)
|
||||
get_graph_executions = _(d.get_graph_executions)
|
||||
get_graph_execution_meta = _(d.get_graph_execution_meta)
|
||||
create_graph_execution = _(d.create_graph_execution)
|
||||
get_node_executions = _(d.get_node_executions)
|
||||
get_latest_node_execution = _(d.get_latest_node_execution)
|
||||
update_node_execution_status = _(d.update_node_execution_status)
|
||||
update_node_execution_status_batch = _(d.update_node_execution_status_batch)
|
||||
update_graph_execution_start_time = _(d.update_graph_execution_start_time)
|
||||
update_graph_execution_stats = _(d.update_graph_execution_stats)
|
||||
update_node_execution_stats = _(d.update_node_execution_stats)
|
||||
upsert_execution_input = _(d.upsert_execution_input)
|
||||
upsert_execution_output = _(d.upsert_execution_output)
|
||||
|
||||
# Graphs
|
||||
get_node = _(d.get_node)
|
||||
get_graph = _(d.get_graph)
|
||||
get_connected_output_nodes = _(d.get_connected_output_nodes)
|
||||
get_graph_metadata = _(d.get_graph_metadata)
|
||||
|
||||
# Credits
|
||||
spend_credits = _(d.spend_credits)
|
||||
get_credits = _(d.get_credits)
|
||||
|
||||
# User + User Metadata + User Integrations
|
||||
get_user_metadata = _(d.get_user_metadata)
|
||||
update_user_metadata = _(d.update_user_metadata)
|
||||
get_user_integrations = _(d.get_user_integrations)
|
||||
update_user_integrations = _(d.update_user_integrations)
|
||||
|
||||
# User Comms - async
|
||||
get_active_user_ids_in_timerange = _(d.get_active_user_ids_in_timerange)
|
||||
get_user_email_by_id = _(d.get_user_email_by_id)
|
||||
get_user_email_verification = _(d.get_user_email_verification)
|
||||
get_user_notification_preference = _(d.get_user_notification_preference)
|
||||
|
||||
# Notifications - async
|
||||
create_or_add_to_user_notification_batch = _(
|
||||
d.create_or_add_to_user_notification_batch
|
||||
)
|
||||
empty_user_notification_batch = _(d.empty_user_notification_batch)
|
||||
get_all_batches_by_type = _(d.get_all_batches_by_type)
|
||||
get_user_notification_batch = _(d.get_user_notification_batch)
|
||||
get_user_notification_oldest_message_in_batch = _(
|
||||
d.get_user_notification_oldest_message_in_batch
|
||||
)
|
||||
|
||||
@@ -5,42 +5,35 @@ import os
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import Future, ProcessPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from multiprocessing.pool import AsyncResult, Pool
|
||||
from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar, cast
|
||||
from typing import TYPE_CHECKING, Any, Generator, TypeVar, cast
|
||||
|
||||
from pika.adapters.blocking_connection import BlockingChannel
|
||||
from pika.spec import Basic, BasicProperties
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from backend.blocks.io import AgentOutputBlock
|
||||
from backend.data.model import (
|
||||
CredentialsMetaInput,
|
||||
GraphExecutionStats,
|
||||
NodeExecutionStats,
|
||||
)
|
||||
from backend.data.model import GraphExecutionStats, NodeExecutionStats
|
||||
from backend.data.notifications import (
|
||||
AgentRunData,
|
||||
LowBalanceData,
|
||||
NotificationEventModel,
|
||||
NotificationEventDTO,
|
||||
NotificationType,
|
||||
)
|
||||
from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.utils import create_execution_queue_config
|
||||
from backend.notifications.notifications import queue_notification
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerClient
|
||||
from backend.executor import DatabaseManager
|
||||
from backend.notifications.notifications import NotificationManager
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from prometheus_client import Gauge, start_http_server
|
||||
from autogpt_libs.utils.cache import clear_thread_cache, thread_cached
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.data import redis
|
||||
from backend.data.block import BlockData, BlockInput, BlockSchema, get_block
|
||||
from backend.data.credit import UsageTransactionMetadata
|
||||
from backend.data.execution import (
|
||||
ExecutionQueue,
|
||||
ExecutionStatus,
|
||||
@@ -54,6 +47,7 @@ from backend.executor.utils import (
|
||||
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_QUEUE_NAME,
|
||||
CancelExecutionEvent,
|
||||
UsageTransactionMetadata,
|
||||
block_usage_cost,
|
||||
execution_usage_cost,
|
||||
get_execution_event_bus,
|
||||
@@ -67,24 +61,12 @@ from backend.util.decorator import error_logged, time_measured
|
||||
from backend.util.file import clean_exec_files
|
||||
from backend.util.logging import configure_logging
|
||||
from backend.util.process import AppProcess, set_service_name
|
||||
from backend.util.retry import func_retry
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.service import close_service_client, get_service_client
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
active_runs_gauge = Gauge(
|
||||
"execution_manager_active_runs", "Number of active graph runs"
|
||||
)
|
||||
pool_size_gauge = Gauge(
|
||||
"execution_manager_pool_size", "Maximum number of graph workers"
|
||||
)
|
||||
utilization_gauge = Gauge(
|
||||
"execution_manager_utilization_ratio",
|
||||
"Ratio of active graph runs to max graph workers",
|
||||
)
|
||||
|
||||
|
||||
class LogMetadata:
|
||||
def __init__(
|
||||
@@ -139,13 +121,10 @@ ExecutionStream = Generator[NodeExecutionEntry, None, None]
|
||||
|
||||
|
||||
def execute_node(
|
||||
db_client: "DatabaseManagerClient",
|
||||
db_client: "DatabaseManager",
|
||||
creds_manager: IntegrationCredentialsManager,
|
||||
data: NodeExecutionEntry,
|
||||
execution_stats: NodeExecutionStats | None = None,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
) -> ExecutionStream:
|
||||
"""
|
||||
Execute a node in the graph. This will trigger a block execution on a node,
|
||||
@@ -193,7 +172,7 @@ def execute_node(
|
||||
)
|
||||
|
||||
# Sanity check: validate the execution input.
|
||||
input_data, error = validate_exec(node, data.inputs, resolve_input=False)
|
||||
input_data, error = validate_exec(node, data.data, resolve_input=False)
|
||||
if input_data is None:
|
||||
log_metadata.error(f"Skip execution, input validation error: {error}")
|
||||
push_output("error", error)
|
||||
@@ -203,12 +182,8 @@ def execute_node(
|
||||
# Re-shape the input data for agent block.
|
||||
# AgentExecutorBlock specially separate the node input_data & its input_default.
|
||||
if isinstance(node_block, AgentExecutorBlock):
|
||||
_input_data = AgentExecutorBlock.Input(**node.input_default)
|
||||
_input_data.inputs = input_data
|
||||
if node_credentials_input_map:
|
||||
_input_data.node_credentials_input_map = node_credentials_input_map
|
||||
input_data = _input_data.model_dump()
|
||||
data.inputs = input_data
|
||||
input_data = {**node.input_default, "data": input_data}
|
||||
data.data = input_data
|
||||
|
||||
# Execute the node
|
||||
input_data_str = json.dumps(input_data)
|
||||
@@ -255,7 +230,6 @@ def execute_node(
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
log_metadata=log_metadata,
|
||||
node_credentials_input_map=node_credentials_input_map,
|
||||
):
|
||||
yield execution
|
||||
|
||||
@@ -274,14 +248,13 @@ def execute_node(
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
log_metadata=log_metadata,
|
||||
node_credentials_input_map=node_credentials_input_map,
|
||||
):
|
||||
yield execution
|
||||
|
||||
raise e
|
||||
finally:
|
||||
# Ensure credentials are released even if execution fails
|
||||
if creds_lock and creds_lock.locked() and creds_lock.owned():
|
||||
if creds_lock and creds_lock.locked():
|
||||
try:
|
||||
creds_lock.release()
|
||||
except Exception as e:
|
||||
@@ -297,14 +270,13 @@ def execute_node(
|
||||
|
||||
|
||||
def _enqueue_next_nodes(
|
||||
db_client: "DatabaseManagerClient",
|
||||
db_client: "DatabaseManager",
|
||||
node: Node,
|
||||
output: BlockData,
|
||||
user_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
log_metadata: LogMetadata,
|
||||
node_credentials_input_map: Optional[dict[str, dict[str, CredentialsMetaInput]]],
|
||||
) -> list[NodeExecutionEntry]:
|
||||
def add_enqueued_execution(
|
||||
node_exec_id: str, node_id: str, block_id: str, data: BlockInput
|
||||
@@ -320,7 +292,7 @@ def _enqueue_next_nodes(
|
||||
node_exec_id=node_exec_id,
|
||||
node_id=node_id,
|
||||
block_id=block_id,
|
||||
inputs=data,
|
||||
data=data,
|
||||
)
|
||||
|
||||
def register_next_executions(node_link: Link) -> list[NodeExecutionEntry]:
|
||||
@@ -361,15 +333,6 @@ def _enqueue_next_nodes(
|
||||
for name in static_link_names:
|
||||
next_node_input[name] = latest_execution.input_data.get(name)
|
||||
|
||||
# Apply node credentials overrides
|
||||
node_credentials = None
|
||||
if node_credentials_input_map and (
|
||||
node_credentials := node_credentials_input_map.get(next_node.id)
|
||||
):
|
||||
next_node_input.update(
|
||||
{k: v.model_dump() for k, v in node_credentials.items()}
|
||||
)
|
||||
|
||||
# Validate the input data for the next node.
|
||||
next_node_input, validation_msg = validate_exec(next_node, next_node_input)
|
||||
suffix = f"{next_output_name}>{next_input_name}~{next_node_exec_id}:{validation_msg}"
|
||||
@@ -396,10 +359,8 @@ def _enqueue_next_nodes(
|
||||
|
||||
# If link is static, there could be some incomplete executions waiting for it.
|
||||
# Load and complete the input missing input data, and try to re-enqueue them.
|
||||
for iexec in db_client.get_node_executions(
|
||||
node_id=next_node_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=[ExecutionStatus.INCOMPLETE],
|
||||
for iexec in db_client.get_incomplete_node_executions(
|
||||
next_node_id, graph_exec_id
|
||||
):
|
||||
idata = iexec.input_data
|
||||
ineid = iexec.node_exec_id
|
||||
@@ -412,12 +373,6 @@ def _enqueue_next_nodes(
|
||||
for input_name in static_link_names:
|
||||
idata[input_name] = next_node_input[input_name]
|
||||
|
||||
# Apply node credentials overrides
|
||||
if node_credentials:
|
||||
idata.update(
|
||||
{k: v.model_dump() for k, v in node_credentials.items()}
|
||||
)
|
||||
|
||||
idata, msg = validate_exec(next_node, idata)
|
||||
suffix = f"{next_output_name}>{next_input_name}~{ineid}:{msg}"
|
||||
if not idata:
|
||||
@@ -467,7 +422,6 @@ class Executor:
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@func_retry
|
||||
def on_node_executor_start(cls):
|
||||
configure_logging()
|
||||
set_service_name("NodeExecutor")
|
||||
@@ -478,28 +432,36 @@ class Executor:
|
||||
|
||||
# Set up shutdown handlers
|
||||
cls.shutdown_lock = threading.Lock()
|
||||
atexit.register(cls.on_node_executor_stop)
|
||||
signal.signal(signal.SIGTERM, lambda _, __: cls.on_node_executor_sigterm())
|
||||
signal.signal(signal.SIGINT, lambda _, __: cls.on_node_executor_sigterm())
|
||||
atexit.register(cls.on_node_executor_stop) # handle regular shutdown
|
||||
signal.signal( # handle termination
|
||||
signal.SIGTERM, lambda _, __: cls.on_node_executor_sigterm()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def on_node_executor_stop(cls, log=logger.info):
|
||||
def on_node_executor_stop(cls):
|
||||
if not cls.shutdown_lock.acquire(blocking=False):
|
||||
return # already shutting down
|
||||
|
||||
log(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
|
||||
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
|
||||
cls.creds_manager.release_all_locks()
|
||||
log(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
|
||||
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
|
||||
redis.disconnect()
|
||||
log(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB manager...")
|
||||
cls.db_client.close()
|
||||
log(f"[on_node_executor_stop {cls.pid}] ✅ Finished NodeExec cleanup")
|
||||
sys.exit(0)
|
||||
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB manager...")
|
||||
close_service_client(cls.db_client)
|
||||
logger.info(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
|
||||
|
||||
@classmethod
|
||||
def on_node_executor_sigterm(cls):
|
||||
llprint(f"[on_node_executor_sigterm {cls.pid}] ⚠️ NodeExec SIGTERM received")
|
||||
cls.on_node_executor_stop(log=llprint)
|
||||
llprint(f"[on_node_executor_sigterm {cls.pid}] ⚠️ SIGTERM received")
|
||||
if not cls.shutdown_lock.acquire(blocking=False):
|
||||
return # already shutting down
|
||||
|
||||
llprint(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
|
||||
cls.creds_manager.release_all_locks()
|
||||
llprint(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
|
||||
redis.disconnect()
|
||||
llprint(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
|
||||
sys.exit(0)
|
||||
|
||||
@classmethod
|
||||
@error_logged
|
||||
@@ -507,9 +469,6 @@ class Executor:
|
||||
cls,
|
||||
q: ExecutionQueue[NodeExecutionEntry],
|
||||
node_exec: NodeExecutionEntry,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
) -> NodeExecutionStats:
|
||||
log_metadata = LogMetadata(
|
||||
user_id=node_exec.user_id,
|
||||
@@ -522,7 +481,7 @@ class Executor:
|
||||
|
||||
execution_stats = NodeExecutionStats()
|
||||
timing_info, _ = cls._on_node_execution(
|
||||
q, node_exec, log_metadata, execution_stats, node_credentials_input_map
|
||||
q, node_exec, log_metadata, execution_stats
|
||||
)
|
||||
execution_stats.walltime = timing_info.wall_time
|
||||
execution_stats.cputime = timing_info.cpu_time
|
||||
@@ -542,9 +501,6 @@ class Executor:
|
||||
node_exec: NodeExecutionEntry,
|
||||
log_metadata: LogMetadata,
|
||||
stats: NodeExecutionStats | None = None,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
):
|
||||
try:
|
||||
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
|
||||
@@ -553,7 +509,6 @@ class Executor:
|
||||
creds_manager=cls.creds_manager,
|
||||
data=node_exec,
|
||||
execution_stats=stats,
|
||||
node_credentials_input_map=node_credentials_input_map,
|
||||
):
|
||||
q.add(execution)
|
||||
log_metadata.info(f"Finished node execution {node_exec.node_exec_id}")
|
||||
@@ -572,7 +527,6 @@ class Executor:
|
||||
stats.error = e
|
||||
|
||||
@classmethod
|
||||
@func_retry
|
||||
def on_graph_executor_start(cls):
|
||||
configure_logging()
|
||||
set_service_name("GraphExecutor")
|
||||
@@ -580,8 +534,23 @@ class Executor:
|
||||
cls.db_client = get_db_client()
|
||||
cls.pool_size = settings.config.num_node_workers
|
||||
cls.pid = os.getpid()
|
||||
cls.notification_service = get_notification_service()
|
||||
cls._init_node_executor_pool()
|
||||
logger.info(f"GraphExec {cls.pid} started with {cls.pool_size} node workers")
|
||||
logger.info(
|
||||
f"Graph executor {cls.pid} started with {cls.pool_size} node workers"
|
||||
)
|
||||
|
||||
# Set up shutdown handler
|
||||
atexit.register(cls.on_graph_executor_stop)
|
||||
|
||||
@classmethod
|
||||
def on_graph_executor_stop(cls):
|
||||
prefix = f"[on_graph_executor_stop {cls.pid}]"
|
||||
logger.info(f"{prefix} ⏳ Terminating node executor pool...")
|
||||
cls.executor.terminate()
|
||||
logger.info(f"{prefix} ⏳ Disconnecting DB manager...")
|
||||
close_service_client(cls.db_client)
|
||||
logger.info(f"{prefix} ✅ Finished cleanup")
|
||||
|
||||
@classmethod
|
||||
def _init_node_executor_pool(cls):
|
||||
@@ -603,46 +572,22 @@ class Executor:
|
||||
node_eid="*",
|
||||
block_name="-",
|
||||
)
|
||||
|
||||
exec_meta = cls.db_client.get_graph_execution_meta(
|
||||
user_id=graph_exec.user_id,
|
||||
execution_id=graph_exec.graph_exec_id,
|
||||
exec_meta = cls.db_client.update_graph_execution_start_time(
|
||||
graph_exec.graph_exec_id
|
||||
)
|
||||
if exec_meta is None:
|
||||
log_metadata.warning(
|
||||
f"Skipped graph execution #{graph_exec.graph_exec_id}, the graph execution is not found."
|
||||
)
|
||||
return
|
||||
|
||||
if exec_meta.status == ExecutionStatus.QUEUED:
|
||||
log_metadata.info(f"⚙️ Starting graph execution #{graph_exec.graph_exec_id}")
|
||||
exec_meta.status = ExecutionStatus.RUNNING
|
||||
send_execution_update(
|
||||
cls.db_client.update_graph_execution_start_time(
|
||||
graph_exec.graph_exec_id
|
||||
)
|
||||
)
|
||||
elif exec_meta.status == ExecutionStatus.RUNNING:
|
||||
log_metadata.info(
|
||||
f"⚙️ Graph execution #{graph_exec.graph_exec_id} is already running, continuing where it left off."
|
||||
)
|
||||
else:
|
||||
log_metadata.warning(
|
||||
f"Skipped graph execution {graph_exec.graph_exec_id}, the graph execution status is `{exec_meta.status}`."
|
||||
logger.warning(
|
||||
f"Skipped graph execution {graph_exec.graph_exec_id}, the graph execution is not found or not currently in the QUEUED state."
|
||||
)
|
||||
return
|
||||
|
||||
send_execution_update(exec_meta)
|
||||
timing_info, (exec_stats, status, error) = cls._on_graph_execution(
|
||||
graph_exec=graph_exec,
|
||||
cancel=cancel,
|
||||
log_metadata=log_metadata,
|
||||
execution_stats=(
|
||||
exec_meta.stats.to_db() if exec_meta.stats else GraphExecutionStats()
|
||||
),
|
||||
graph_exec, cancel, log_metadata
|
||||
)
|
||||
exec_stats.walltime += timing_info.wall_time
|
||||
exec_stats.cputime += timing_info.cpu_time
|
||||
exec_stats.error = str(error) if error else exec_stats.error
|
||||
exec_stats.walltime = timing_info.wall_time
|
||||
exec_stats.cputime = timing_info.cpu_time
|
||||
exec_stats.error = str(error)
|
||||
|
||||
if graph_exec_result := cls.db_client.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
@@ -659,15 +604,13 @@ class Executor:
|
||||
node_exec: NodeExecutionEntry,
|
||||
execution_count: int,
|
||||
execution_stats: GraphExecutionStats,
|
||||
):
|
||||
) -> int:
|
||||
block = get_block(node_exec.block_id)
|
||||
if not block:
|
||||
logger.error(f"Block {node_exec.block_id} not found.")
|
||||
return
|
||||
return execution_count
|
||||
|
||||
cost, matching_filter = block_usage_cost(
|
||||
block=block, input_data=node_exec.inputs
|
||||
)
|
||||
cost, matching_filter = block_usage_cost(block=block, input_data=node_exec.data)
|
||||
if cost > 0:
|
||||
cls.db_client.spend_credits(
|
||||
user_id=node_exec.user_id,
|
||||
@@ -680,12 +623,11 @@ class Executor:
|
||||
block_id=node_exec.block_id,
|
||||
block=block.name,
|
||||
input=matching_filter,
|
||||
reason=f"Ran block {node_exec.block_id} {block.name}",
|
||||
),
|
||||
)
|
||||
execution_stats.cost += cost
|
||||
|
||||
cost, usage_count = execution_usage_cost(execution_count)
|
||||
cost, execution_count = execution_usage_cost(execution_count)
|
||||
if cost > 0:
|
||||
cls.db_client.spend_credits(
|
||||
user_id=node_exec.user_id,
|
||||
@@ -694,14 +636,15 @@ class Executor:
|
||||
graph_exec_id=node_exec.graph_exec_id,
|
||||
graph_id=node_exec.graph_id,
|
||||
input={
|
||||
"execution_count": usage_count,
|
||||
"execution_count": execution_count,
|
||||
"charge": "Execution Cost",
|
||||
},
|
||||
reason=f"Execution Cost for {usage_count} blocks of ex_id:{node_exec.graph_exec_id} g_id:{node_exec.graph_id}",
|
||||
),
|
||||
)
|
||||
execution_stats.cost += cost
|
||||
|
||||
return execution_count
|
||||
|
||||
@classmethod
|
||||
@time_measured
|
||||
def _on_graph_execution(
|
||||
@@ -709,7 +652,6 @@ class Executor:
|
||||
graph_exec: GraphExecutionEntry,
|
||||
cancel: threading.Event,
|
||||
log_metadata: LogMetadata,
|
||||
execution_stats: GraphExecutionStats,
|
||||
) -> tuple[GraphExecutionStats, ExecutionStatus, Exception | None]:
|
||||
"""
|
||||
Returns:
|
||||
@@ -717,6 +659,8 @@ class Executor:
|
||||
ExecutionStatus: The final status of the graph execution.
|
||||
Exception | None: The error that occurred during the execution, if any.
|
||||
"""
|
||||
log_metadata.info(f"Start graph execution {graph_exec.graph_exec_id}")
|
||||
execution_stats = GraphExecutionStats()
|
||||
execution_status = ExecutionStatus.RUNNING
|
||||
error = None
|
||||
finished = False
|
||||
@@ -737,21 +681,11 @@ class Executor:
|
||||
cancel_thread.start()
|
||||
|
||||
try:
|
||||
if cls.db_client.get_credits(graph_exec.user_id) <= 0:
|
||||
raise InsufficientBalanceError(
|
||||
user_id=graph_exec.user_id,
|
||||
message="You have no credits left to run an agent.",
|
||||
balance=0,
|
||||
amount=1,
|
||||
)
|
||||
|
||||
queue = ExecutionQueue[NodeExecutionEntry]()
|
||||
for node_exec in cls.db_client.get_node_executions(
|
||||
graph_exec.graph_exec_id,
|
||||
statuses=[ExecutionStatus.RUNNING, ExecutionStatus.QUEUED],
|
||||
):
|
||||
queue.add(node_exec.to_node_execution_entry())
|
||||
for node_exec in graph_exec.start_node_execs:
|
||||
queue.add(node_exec)
|
||||
|
||||
exec_cost_counter = 0
|
||||
running_executions: dict[str, AsyncResult] = {}
|
||||
|
||||
def make_exec_callback(exec_data: NodeExecutionEntry):
|
||||
@@ -807,9 +741,9 @@ class Executor:
|
||||
)
|
||||
|
||||
try:
|
||||
cls._charge_usage(
|
||||
exec_cost_counter = cls._charge_usage(
|
||||
node_exec=queued_node_exec,
|
||||
execution_count=increment_execution_count(graph_exec.user_id),
|
||||
execution_count=exec_cost_counter + 1,
|
||||
execution_stats=execution_stats,
|
||||
)
|
||||
except InsufficientBalanceError as error:
|
||||
@@ -839,7 +773,7 @@ class Executor:
|
||||
if (node_creds_map := graph_exec.node_credentials_input_map) and (
|
||||
node_field_creds_map := node_creds_map.get(node_id)
|
||||
):
|
||||
queued_node_exec.inputs.update(
|
||||
queued_node_exec.data.update(
|
||||
{
|
||||
field_name: creds_meta.model_dump()
|
||||
for field_name, creds_meta in node_field_creds_map.items()
|
||||
@@ -849,7 +783,7 @@ class Executor:
|
||||
# Initiate node execution
|
||||
running_executions[queued_node_exec.node_id] = cls.executor.apply_async(
|
||||
cls.on_node_execution,
|
||||
(queue, queued_node_exec, node_creds_map),
|
||||
(queue, queued_node_exec),
|
||||
callback=make_exec_callback(queued_node_exec),
|
||||
)
|
||||
|
||||
@@ -870,21 +804,24 @@ class Executor:
|
||||
execution.wait(3)
|
||||
|
||||
log_metadata.info(f"Finished graph execution {graph_exec.graph_exec_id}")
|
||||
execution_status = ExecutionStatus.COMPLETED
|
||||
|
||||
except Exception as e:
|
||||
error = e
|
||||
log_metadata.error(
|
||||
f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
|
||||
)
|
||||
execution_status = ExecutionStatus.FAILED
|
||||
|
||||
finally:
|
||||
if error:
|
||||
log_metadata.error(
|
||||
f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
|
||||
)
|
||||
execution_status = ExecutionStatus.FAILED
|
||||
else:
|
||||
execution_status = ExecutionStatus.COMPLETED
|
||||
|
||||
if not cancel.is_set():
|
||||
finished = True
|
||||
cancel.set()
|
||||
cancel_thread.join()
|
||||
clean_exec_files(graph_exec.graph_exec_id)
|
||||
|
||||
return execution_stats, execution_status, error
|
||||
|
||||
@classmethod
|
||||
@@ -896,7 +833,7 @@ class Executor:
|
||||
metadata = cls.db_client.get_graph_metadata(
|
||||
graph_exec.graph_id, graph_exec.graph_version
|
||||
)
|
||||
outputs = cls.db_client.get_node_executions(
|
||||
outputs = cls.db_client.get_node_execution_results(
|
||||
graph_exec.graph_exec_id,
|
||||
block_ids=[AgentOutputBlock().id],
|
||||
)
|
||||
@@ -909,21 +846,21 @@ class Executor:
|
||||
for output in outputs
|
||||
]
|
||||
|
||||
queue_notification(
|
||||
NotificationEventModel(
|
||||
user_id=graph_exec.user_id,
|
||||
type=NotificationType.AGENT_RUN,
|
||||
data=AgentRunData(
|
||||
outputs=named_outputs,
|
||||
agent_name=metadata.name if metadata else "Unknown Agent",
|
||||
credits_used=exec_stats.cost,
|
||||
execution_time=exec_stats.walltime,
|
||||
graph_id=graph_exec.graph_id,
|
||||
node_count=exec_stats.node_count,
|
||||
),
|
||||
)
|
||||
event = NotificationEventDTO(
|
||||
user_id=graph_exec.user_id,
|
||||
type=NotificationType.AGENT_RUN,
|
||||
data=AgentRunData(
|
||||
outputs=named_outputs,
|
||||
agent_name=metadata.name if metadata else "Unknown Agent",
|
||||
credits_used=exec_stats.cost,
|
||||
execution_time=exec_stats.walltime,
|
||||
graph_id=graph_exec.graph_id,
|
||||
node_count=exec_stats.node_count,
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
cls.notification_service.queue_notification(event)
|
||||
|
||||
@classmethod
|
||||
def _handle_low_balance_notif(
|
||||
cls,
|
||||
@@ -937,8 +874,8 @@ class Executor:
|
||||
base_url = (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
)
|
||||
queue_notification(
|
||||
NotificationEventModel(
|
||||
cls.notification_service.queue_notification(
|
||||
NotificationEventDTO(
|
||||
user_id=user_id,
|
||||
type=NotificationType.LOW_BALANCE,
|
||||
data=LowBalanceData(
|
||||
@@ -946,7 +883,7 @@ class Executor:
|
||||
billing_page_link=f"{base_url}/profile/credits",
|
||||
shortfall=shortfall,
|
||||
agent_name=metadata.name if metadata else "Unknown Agent",
|
||||
),
|
||||
).model_dump(),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -957,23 +894,35 @@ class ExecutionManager(AppProcess):
|
||||
self.pool_size = settings.config.num_graph_workers
|
||||
self.running = True
|
||||
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
|
||||
atexit.register(self._on_cleanup)
|
||||
signal.signal(signal.SIGTERM, lambda sig, frame: self._on_sigterm())
|
||||
signal.signal(signal.SIGINT, lambda sig, frame: self._on_sigterm())
|
||||
|
||||
@classmethod
|
||||
def get_port(cls) -> int:
|
||||
return settings.config.execution_manager_port
|
||||
|
||||
def run(self):
|
||||
pool_size_gauge.set(self.pool_size)
|
||||
active_runs_gauge.set(0)
|
||||
utilization_gauge.set(0)
|
||||
retry_count_max = settings.config.execution_manager_loop_max_retry
|
||||
retry_count = 0
|
||||
|
||||
self.metrics_server = threading.Thread(
|
||||
target=start_http_server,
|
||||
args=(settings.config.execution_manager_port,),
|
||||
daemon=True,
|
||||
)
|
||||
self.metrics_server.start()
|
||||
logger.info(f"[{self.service_name}] Starting execution manager...")
|
||||
self._run()
|
||||
for retry_count in range(retry_count_max):
|
||||
try:
|
||||
self._run()
|
||||
except Exception as e:
|
||||
if not self.running:
|
||||
break
|
||||
logger.exception(
|
||||
f"[{self.service_name}] Error in execution manager: {e}"
|
||||
)
|
||||
|
||||
if retry_count >= retry_count_max:
|
||||
logger.error(
|
||||
f"[{self.service_name}] Max retries reached ({retry_count_max}), exiting..."
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.info(
|
||||
f"[{self.service_name}] Retrying execution loop in {retry_count} seconds..."
|
||||
)
|
||||
time.sleep(retry_count)
|
||||
|
||||
def _run(self):
|
||||
logger.info(f"[{self.service_name}] ⏳ Spawn max-{self.pool_size} workers...")
|
||||
@@ -985,33 +934,23 @@ class ExecutionManager(AppProcess):
|
||||
logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...")
|
||||
redis.connect()
|
||||
|
||||
cancel_client = SyncRabbitMQ(create_execution_queue_config())
|
||||
cancel_client.connect()
|
||||
cancel_channel = cancel_client.get_channel()
|
||||
logger.info(f"[{self.service_name}] ⏳ Starting cancel message consumer...")
|
||||
threading.Thread(
|
||||
target=lambda: (
|
||||
cancel_channel.basic_consume(
|
||||
queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
on_message_callback=self._handle_cancel_message,
|
||||
auto_ack=True,
|
||||
),
|
||||
cancel_channel.start_consuming(),
|
||||
),
|
||||
daemon=True,
|
||||
).start()
|
||||
|
||||
run_client = SyncRabbitMQ(create_execution_queue_config())
|
||||
run_client.connect()
|
||||
run_channel = run_client.get_channel()
|
||||
run_channel.basic_qos(prefetch_count=self.pool_size)
|
||||
run_channel.basic_consume(
|
||||
# Consume Cancel & Run execution requests.
|
||||
clear_thread_cache(get_execution_queue)
|
||||
channel = get_execution_queue().get_channel()
|
||||
channel.basic_qos(prefetch_count=self.pool_size)
|
||||
channel.basic_consume(
|
||||
queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
on_message_callback=self._handle_cancel_message,
|
||||
auto_ack=True,
|
||||
)
|
||||
channel.basic_consume(
|
||||
queue=GRAPH_EXECUTION_QUEUE_NAME,
|
||||
on_message_callback=self._handle_run_message,
|
||||
auto_ack=False,
|
||||
)
|
||||
logger.info(f"[{self.service_name}] ⏳ Starting to consume run messages...")
|
||||
run_channel.start_consuming()
|
||||
|
||||
logger.info(f"[{self.service_name}] Ready to consume messages...")
|
||||
channel.start_consuming()
|
||||
|
||||
def _handle_cancel_message(
|
||||
self,
|
||||
@@ -1081,15 +1020,11 @@ class ExecutionManager(AppProcess):
|
||||
Executor.on_graph_execution, graph_exec_entry, cancel_event
|
||||
)
|
||||
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
|
||||
active_runs_gauge.set(len(self.active_graph_runs))
|
||||
utilization_gauge.set(len(self.active_graph_runs) / self.pool_size)
|
||||
|
||||
def _on_run_done(f: Future):
|
||||
logger.info(f"[{self.service_name}] Run completed for {graph_exec_id}")
|
||||
try:
|
||||
self.active_graph_runs.pop(graph_exec_id, None)
|
||||
active_runs_gauge.set(len(self.active_graph_runs))
|
||||
utilization_gauge.set(len(self.active_graph_runs) / self.pool_size)
|
||||
if f.exception():
|
||||
logger.error(
|
||||
f"[{self.service_name}] Execution for {graph_exec_id} failed: {f.exception()}"
|
||||
@@ -1108,44 +1043,42 @@ class ExecutionManager(AppProcess):
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
self._on_cleanup()
|
||||
|
||||
def _on_sigterm(self):
|
||||
llprint(f"[{self.service_name}] ⚠️ GraphExec SIGTERM received")
|
||||
self._on_cleanup(log=llprint)
|
||||
|
||||
def _on_cleanup(self, log=logger.info):
|
||||
prefix = f"[{self.service_name}][on_graph_executor_stop {os.getpid()}]"
|
||||
log(f"{prefix} ⏳ Shutting down service loop...")
|
||||
logger.info(f"[{self.service_name}] ⏳ Shutting down service loop...")
|
||||
self.running = False
|
||||
|
||||
log(f"{prefix} ⏳ Shutting down RabbitMQ channel...")
|
||||
logger.info(f"[{self.service_name}] ⏳ Shutting down RabbitMQ channel...")
|
||||
get_execution_queue().get_channel().stop_consuming()
|
||||
|
||||
if hasattr(self, "executor"):
|
||||
log(f"{prefix} ⏳ Shutting down GraphExec pool...")
|
||||
self.executor.shutdown(cancel_futures=True, wait=False)
|
||||
logger.info(f"[{self.service_name}] ⏳ Shutting down graph executor pool...")
|
||||
self.executor.shutdown(cancel_futures=True)
|
||||
|
||||
log(f"{prefix} ⏳ Disconnecting Redis...")
|
||||
logger.info(f"[{self.service_name}] ⏳ Disconnecting Redis...")
|
||||
redis.disconnect()
|
||||
|
||||
log(f"{prefix} ✅ Finished GraphExec cleanup")
|
||||
@property
|
||||
def db_client(self) -> "DatabaseManager":
|
||||
return get_db_client()
|
||||
|
||||
|
||||
# ------- UTILITIES ------- #
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_db_client() -> "DatabaseManagerClient":
|
||||
from backend.executor import DatabaseManagerClient
|
||||
def get_db_client() -> "DatabaseManager":
|
||||
from backend.executor import DatabaseManager
|
||||
|
||||
# Disable health check for the service client to avoid breaking process initializer.
|
||||
return get_service_client(DatabaseManagerClient, health_check=False)
|
||||
return get_service_client(DatabaseManager)
|
||||
|
||||
|
||||
def send_execution_update(entry: GraphExecution | NodeExecutionResult | None):
|
||||
if entry is None:
|
||||
return
|
||||
@thread_cached
|
||||
def get_notification_service() -> "NotificationManager":
|
||||
from backend.notifications import NotificationManager
|
||||
|
||||
return get_service_client(NotificationManager)
|
||||
|
||||
|
||||
def send_execution_update(entry: GraphExecution | NodeExecutionResult):
|
||||
return get_execution_event_bus().publish(entry)
|
||||
|
||||
|
||||
@@ -1156,26 +1089,14 @@ def synchronized(key: str, timeout: int = 60):
|
||||
lock.acquire()
|
||||
yield
|
||||
finally:
|
||||
if lock.locked() and lock.owned():
|
||||
if lock.locked():
|
||||
lock.release()
|
||||
|
||||
|
||||
def increment_execution_count(user_id: str) -> int:
|
||||
"""
|
||||
Increment the execution count for a given user,
|
||||
this will be used to charge the user for the execution cost.
|
||||
"""
|
||||
r = redis.get_redis()
|
||||
k = f"uec:{user_id}" # User Execution Count global key
|
||||
counter = cast(int, r.incr(k))
|
||||
if counter == 1:
|
||||
r.expire(k, settings.config.execution_counter_expiration_time)
|
||||
return counter
|
||||
|
||||
|
||||
def llprint(message: str):
|
||||
"""
|
||||
Low-level print/log helper function for use in signal handlers.
|
||||
Regular log/print statements are not allowed in signal handlers.
|
||||
"""
|
||||
os.write(sys.stdout.fileno(), (message + "\n").encode())
|
||||
if logger.getEffectiveLevel() == logging.DEBUG:
|
||||
os.write(sys.stdout.fileno(), (message + "\n").encode())
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
|
||||
@@ -13,21 +12,13 @@ from apscheduler.triggers.cron import CronTrigger
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from dotenv import load_dotenv
|
||||
from prisma.enums import NotificationType
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import MetaData, create_engine
|
||||
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.notifications.notifications import NotificationManagerClient
|
||||
from backend.util.metrics import sentry_capture_error
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
endpoint_to_async,
|
||||
expose,
|
||||
get_service_client,
|
||||
)
|
||||
from backend.notifications.notifications import NotificationManager
|
||||
from backend.util.service import AppService, expose, get_service_client
|
||||
from backend.util.settings import Config
|
||||
|
||||
|
||||
@@ -68,11 +59,13 @@ def job_listener(event):
|
||||
|
||||
@thread_cached
|
||||
def get_notification_client():
|
||||
return get_service_client(NotificationManagerClient)
|
||||
from backend.notifications import NotificationManager
|
||||
|
||||
return get_service_client(NotificationManager)
|
||||
|
||||
|
||||
def execute_graph(**kwargs):
|
||||
args = GraphExecutionJobArgs(**kwargs)
|
||||
args = ExecutionJobArgs(**kwargs)
|
||||
try:
|
||||
log(f"Executing recurring job for graph #{args.graph_id}")
|
||||
execution_utils.add_graph_execution(
|
||||
@@ -85,37 +78,6 @@ def execute_graph(**kwargs):
|
||||
logger.exception(f"Error executing graph {args.graph_id}: {e}")
|
||||
|
||||
|
||||
class LateExecutionException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def report_late_executions() -> str:
|
||||
late_executions = execution_utils.get_db_client().get_graph_executions(
|
||||
statuses=[ExecutionStatus.QUEUED],
|
||||
created_time_gte=datetime.now(timezone.utc)
|
||||
- timedelta(seconds=config.execution_late_notification_checkrange_secs),
|
||||
created_time_lte=datetime.now(timezone.utc)
|
||||
- timedelta(seconds=config.execution_late_notification_threshold_secs),
|
||||
limit=1000,
|
||||
)
|
||||
|
||||
if not late_executions:
|
||||
return "No late executions detected."
|
||||
|
||||
num_late_executions = len(late_executions)
|
||||
num_users = len(set([r.user_id for r in late_executions]))
|
||||
error = LateExecutionException(
|
||||
f"Late executions detected: {num_late_executions} late executions from {num_users} users "
|
||||
f"in the last {config.execution_late_notification_checkrange_secs} seconds. "
|
||||
f"Graph has been queued for more than {config.execution_late_notification_threshold_secs} seconds. "
|
||||
"Please check the executor status."
|
||||
)
|
||||
msg = str(error)
|
||||
sentry_capture_error(error)
|
||||
get_notification_client().discord_system_alert(msg)
|
||||
return msg
|
||||
|
||||
|
||||
def process_existing_batches(**kwargs):
|
||||
args = NotificationJobArgs(**kwargs)
|
||||
try:
|
||||
@@ -141,7 +103,7 @@ class Jobstores(Enum):
|
||||
WEEKLY_NOTIFICATIONS = "weekly_notifications"
|
||||
|
||||
|
||||
class GraphExecutionJobArgs(BaseModel):
|
||||
class ExecutionJobArgs(BaseModel):
|
||||
graph_id: str
|
||||
input_data: BlockInput
|
||||
user_id: str
|
||||
@@ -149,16 +111,14 @@ class GraphExecutionJobArgs(BaseModel):
|
||||
cron: str
|
||||
|
||||
|
||||
class GraphExecutionJobInfo(GraphExecutionJobArgs):
|
||||
class ExecutionJobInfo(ExecutionJobArgs):
|
||||
id: str
|
||||
name: str
|
||||
next_run_time: str
|
||||
|
||||
@staticmethod
|
||||
def from_db(
|
||||
job_args: GraphExecutionJobArgs, job_obj: JobObj
|
||||
) -> "GraphExecutionJobInfo":
|
||||
return GraphExecutionJobInfo(
|
||||
def from_db(job_args: ExecutionJobArgs, job_obj: JobObj) -> "ExecutionJobInfo":
|
||||
return ExecutionJobInfo(
|
||||
id=job_obj.id,
|
||||
name=job_obj.name,
|
||||
next_run_time=job_obj.next_run_time.isoformat(),
|
||||
@@ -191,9 +151,6 @@ class NotificationJobInfo(NotificationJobArgs):
|
||||
class Scheduler(AppService):
|
||||
scheduler: BlockingScheduler
|
||||
|
||||
def __init__(self, register_system_tasks: bool = True):
|
||||
self.register_system_tasks = register_system_tasks
|
||||
|
||||
@classmethod
|
||||
def get_port(cls) -> int:
|
||||
return config.execution_scheduler_port
|
||||
@@ -202,6 +159,11 @@ class Scheduler(AppService):
|
||||
def db_pool_size(cls) -> int:
|
||||
return config.scheduler_db_pool_size
|
||||
|
||||
@property
|
||||
@thread_cached
|
||||
def notification_client(self) -> NotificationManager:
|
||||
return get_service_client(NotificationManager)
|
||||
|
||||
def run_service(self):
|
||||
load_dotenv()
|
||||
db_schema, db_url = _extract_schema_from_url(os.getenv("DIRECT_URL"))
|
||||
@@ -231,37 +193,6 @@ class Scheduler(AppService):
|
||||
Jobstores.WEEKLY_NOTIFICATIONS.value: MemoryJobStore(),
|
||||
}
|
||||
)
|
||||
|
||||
if self.register_system_tasks:
|
||||
# Notification PROCESS WEEKLY SUMMARY
|
||||
self.scheduler.add_job(
|
||||
process_weekly_summary,
|
||||
CronTrigger.from_crontab("0 * * * *"),
|
||||
id="process_weekly_summary",
|
||||
kwargs={},
|
||||
replace_existing=True,
|
||||
jobstore=Jobstores.WEEKLY_NOTIFICATIONS.value,
|
||||
)
|
||||
|
||||
# Notification PROCESS EXISTING BATCHES
|
||||
# self.scheduler.add_job(
|
||||
# process_existing_batches,
|
||||
# id="process_existing_batches",
|
||||
# CronTrigger.from_crontab("0 12 * * 5"),
|
||||
# replace_existing=True,
|
||||
# jobstore=Jobstores.BATCHED_NOTIFICATIONS.value,
|
||||
# )
|
||||
|
||||
# Notification LATE EXECUTIONS ALERT
|
||||
self.scheduler.add_job(
|
||||
report_late_executions,
|
||||
id="report_late_executions",
|
||||
trigger="interval",
|
||||
replace_existing=True,
|
||||
seconds=config.execution_late_notification_threshold_secs,
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
|
||||
self.scheduler.start()
|
||||
|
||||
@@ -272,15 +203,15 @@ class Scheduler(AppService):
|
||||
self.scheduler.shutdown(wait=False)
|
||||
|
||||
@expose
|
||||
def add_graph_execution_schedule(
|
||||
def add_execution_schedule(
|
||||
self,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
cron: str,
|
||||
input_data: BlockInput,
|
||||
user_id: str,
|
||||
) -> GraphExecutionJobInfo:
|
||||
job_args = GraphExecutionJobArgs(
|
||||
) -> ExecutionJobInfo:
|
||||
job_args = ExecutionJobArgs(
|
||||
graph_id=graph_id,
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
@@ -295,66 +226,77 @@ class Scheduler(AppService):
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
log(f"Added job {job.id} with cron schedule '{cron}' input data: {input_data}")
|
||||
return GraphExecutionJobInfo.from_db(job_args, job)
|
||||
return ExecutionJobInfo.from_db(job_args, job)
|
||||
|
||||
@expose
|
||||
def delete_graph_execution_schedule(
|
||||
self, schedule_id: str, user_id: str
|
||||
) -> GraphExecutionJobInfo:
|
||||
def delete_schedule(self, schedule_id: str, user_id: str) -> ExecutionJobInfo:
|
||||
job = self.scheduler.get_job(schedule_id, jobstore=Jobstores.EXECUTION.value)
|
||||
if not job:
|
||||
log(f"Job {schedule_id} not found.")
|
||||
raise ValueError(f"Job #{schedule_id} not found.")
|
||||
|
||||
job_args = GraphExecutionJobArgs(**job.kwargs)
|
||||
job_args = ExecutionJobArgs(**job.kwargs)
|
||||
if job_args.user_id != user_id:
|
||||
raise ValueError("User ID does not match the job's user ID.")
|
||||
|
||||
log(f"Deleting job {schedule_id}")
|
||||
job.remove()
|
||||
|
||||
return GraphExecutionJobInfo.from_db(job_args, job)
|
||||
return ExecutionJobInfo.from_db(job_args, job)
|
||||
|
||||
@expose
|
||||
def get_graph_execution_schedules(
|
||||
def get_execution_schedules(
|
||||
self, graph_id: str | None = None, user_id: str | None = None
|
||||
) -> list[GraphExecutionJobInfo]:
|
||||
jobs: list[JobObj] = self.scheduler.get_jobs(jobstore=Jobstores.EXECUTION.value)
|
||||
) -> list[ExecutionJobInfo]:
|
||||
schedules = []
|
||||
for job in jobs:
|
||||
logger.debug(
|
||||
for job in self.scheduler.get_jobs(jobstore=Jobstores.EXECUTION.value):
|
||||
logger.info(
|
||||
f"Found job {job.id} with cron schedule {job.trigger} and args {job.kwargs}"
|
||||
)
|
||||
try:
|
||||
job_args = GraphExecutionJobArgs.model_validate(job.kwargs)
|
||||
except ValidationError:
|
||||
continue
|
||||
job_args = ExecutionJobArgs(**job.kwargs)
|
||||
if (
|
||||
job.next_run_time is not None
|
||||
and (graph_id is None or job_args.graph_id == graph_id)
|
||||
and (user_id is None or job_args.user_id == user_id)
|
||||
):
|
||||
schedules.append(GraphExecutionJobInfo.from_db(job_args, job))
|
||||
schedules.append(ExecutionJobInfo.from_db(job_args, job))
|
||||
return schedules
|
||||
|
||||
@expose
|
||||
def execute_process_existing_batches(self, kwargs: dict):
|
||||
process_existing_batches(**kwargs)
|
||||
def add_batched_notification_schedule(
|
||||
self,
|
||||
notification_types: list[NotificationType],
|
||||
data: dict,
|
||||
cron: str,
|
||||
) -> NotificationJobInfo:
|
||||
job_args = NotificationJobArgs(
|
||||
notification_types=notification_types,
|
||||
cron=cron,
|
||||
)
|
||||
job = self.scheduler.add_job(
|
||||
process_existing_batches,
|
||||
CronTrigger.from_crontab(cron),
|
||||
kwargs=job_args.model_dump(),
|
||||
replace_existing=True,
|
||||
jobstore=Jobstores.BATCHED_NOTIFICATIONS.value,
|
||||
)
|
||||
log(f"Added job {job.id} with cron schedule '{cron}' input data: {data}")
|
||||
return NotificationJobInfo.from_db(job_args, job)
|
||||
|
||||
@expose
|
||||
def execute_process_weekly_summary(self):
|
||||
process_weekly_summary()
|
||||
def add_weekly_notification_schedule(self, cron: str) -> NotificationJobInfo:
|
||||
|
||||
@expose
|
||||
def execute_report_late_executions(self):
|
||||
return report_late_executions()
|
||||
|
||||
|
||||
class SchedulerClient(AppServiceClient):
|
||||
@classmethod
|
||||
def get_service_type(cls):
|
||||
return Scheduler
|
||||
|
||||
add_execution_schedule = endpoint_to_async(Scheduler.add_graph_execution_schedule)
|
||||
delete_schedule = endpoint_to_async(Scheduler.delete_graph_execution_schedule)
|
||||
get_execution_schedules = endpoint_to_async(Scheduler.get_graph_execution_schedules)
|
||||
job = self.scheduler.add_job(
|
||||
process_weekly_summary,
|
||||
CronTrigger.from_crontab(cron),
|
||||
kwargs={},
|
||||
replace_existing=True,
|
||||
jobstore=Jobstores.WEEKLY_NOTIFICATIONS.value,
|
||||
)
|
||||
log(f"Added job {job.id} with cron schedule '{cron}'")
|
||||
return NotificationJobInfo.from_db(
|
||||
NotificationJobArgs(
|
||||
cron=cron, notification_types=[NotificationType.WEEKLY_SUMMARY]
|
||||
),
|
||||
job,
|
||||
)
|
||||
|
||||
@@ -41,7 +41,7 @@ from backend.util.settings import Config
|
||||
from backend.util.type import convert
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerClient
|
||||
from backend.executor import DatabaseManager
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
config = Config()
|
||||
@@ -82,32 +82,40 @@ def get_integration_credentials_store() -> "IntegrationCredentialsStore":
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_db_client() -> "DatabaseManagerClient":
|
||||
from backend.executor import DatabaseManagerClient
|
||||
def get_db_client() -> "DatabaseManager":
|
||||
from backend.executor import DatabaseManager
|
||||
|
||||
return get_service_client(DatabaseManagerClient)
|
||||
return get_service_client(DatabaseManager)
|
||||
|
||||
|
||||
# ============ Execution Cost Helpers ============ #
|
||||
|
||||
|
||||
class UsageTransactionMetadata(BaseModel):
|
||||
graph_exec_id: str | None = None
|
||||
graph_id: str | None = None
|
||||
node_id: str | None = None
|
||||
node_exec_id: str | None = None
|
||||
block_id: str | None = None
|
||||
block: str | None = None
|
||||
input: BlockInput | None = None
|
||||
|
||||
|
||||
def execution_usage_cost(execution_count: int) -> tuple[int, int]:
|
||||
"""
|
||||
Calculate the cost of executing a graph based on the current number of node executions.
|
||||
Calculate the cost of executing a graph based on the number of executions.
|
||||
|
||||
Args:
|
||||
execution_count: Number of node executions
|
||||
execution_count: Number of executions
|
||||
|
||||
Returns:
|
||||
Tuple of cost amount and the number of execution count that is included in the cost.
|
||||
Tuple of cost amount and remaining execution count
|
||||
"""
|
||||
return (
|
||||
(
|
||||
config.execution_cost_per_threshold
|
||||
if execution_count % config.execution_cost_count_threshold == 0
|
||||
else 0
|
||||
),
|
||||
config.execution_cost_count_threshold,
|
||||
execution_count
|
||||
// config.execution_cost_count_threshold
|
||||
* config.execution_cost_per_threshold,
|
||||
execution_count % config.execution_cost_count_threshold,
|
||||
)
|
||||
|
||||
|
||||
@@ -258,7 +266,7 @@ def validate_exec(
|
||||
If the data is valid, the first element will be the resolved input data, and
|
||||
the second element will be the block name.
|
||||
"""
|
||||
node_block = get_block(node.block_id)
|
||||
node_block: Block | None = get_block(node.block_id)
|
||||
if not node_block:
|
||||
return None, f"Block for {node.block_id} not found."
|
||||
schema = node_block.input_schema
|
||||
@@ -608,10 +616,7 @@ async def add_graph_execution_async(
|
||||
ValueError: If the graph is not found or if there are validation errors.
|
||||
""" # noqa
|
||||
graph: GraphModel | None = await get_graph(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
version=graph_version,
|
||||
include_subgraphs=True,
|
||||
graph_id=graph_id, user_id=user_id, version=graph_version
|
||||
)
|
||||
if not graph:
|
||||
raise NotFoundError(f"Graph #{graph_id} not found.")
|
||||
@@ -671,9 +676,6 @@ def add_graph_execution(
|
||||
preset_id: Optional[str] = None,
|
||||
graph_version: Optional[int] = None,
|
||||
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
) -> GraphExecutionWithNodes:
|
||||
"""
|
||||
Adds a graph execution to the queue and returns the execution entry.
|
||||
@@ -686,7 +688,6 @@ def add_graph_execution(
|
||||
graph_version: The version of the graph to execute.
|
||||
graph_credentials_inputs: Credentials inputs to use in the execution.
|
||||
Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
|
||||
node_credentials_input_map: Credentials inputs to use in the execution, mapped to specific nodes.
|
||||
Returns:
|
||||
GraphExecutionEntry: The entry for the graph execution.
|
||||
Raises:
|
||||
@@ -694,15 +695,12 @@ def add_graph_execution(
|
||||
"""
|
||||
db = get_db_client()
|
||||
graph: GraphModel | None = db.get_graph(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
version=graph_version,
|
||||
include_subgraphs=True,
|
||||
graph_id=graph_id, user_id=user_id, version=graph_version
|
||||
)
|
||||
if not graph:
|
||||
raise NotFoundError(f"Graph #{graph_id} not found.")
|
||||
|
||||
node_credentials_input_map = node_credentials_input_map or (
|
||||
node_credentials_input_map = (
|
||||
make_node_credentials_input_map(graph, graph_credentials_inputs)
|
||||
if graph_credentials_inputs
|
||||
else None
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
from pydantic import SecretStr
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor.database import DatabaseManagerClient
|
||||
from backend.executor.database import DatabaseManager
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from autogpt_libs.utils.synchronize import RedisKeyedMutex
|
||||
@@ -210,11 +210,11 @@ class IntegrationCredentialsStore:
|
||||
|
||||
@property
|
||||
@thread_cached
|
||||
def db_manager(self) -> "DatabaseManagerClient":
|
||||
from backend.executor.database import DatabaseManagerClient
|
||||
def db_manager(self) -> "DatabaseManager":
|
||||
from backend.executor.database import DatabaseManager
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
return get_service_client(DatabaseManagerClient)
|
||||
return get_service_client(DatabaseManager)
|
||||
|
||||
def add_creds(self, user_id: str, credentials: Credentials) -> None:
|
||||
with self.locked_user_integrations(user_id):
|
||||
|
||||
@@ -93,7 +93,7 @@ class IntegrationCredentialsManager:
|
||||
|
||||
fresh_credentials = oauth_handler.refresh_tokens(credentials)
|
||||
self.store.update_creds(user_id, fresh_credentials)
|
||||
if _lock and _lock.locked() and _lock.owned():
|
||||
if _lock and _lock.locked():
|
||||
_lock.release()
|
||||
|
||||
credentials = fresh_credentials
|
||||
@@ -145,7 +145,7 @@ class IntegrationCredentialsManager:
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if lock.locked() and lock.owned():
|
||||
if lock.locked():
|
||||
lock.release()
|
||||
|
||||
def release_all_locks(self):
|
||||
|
||||
@@ -16,7 +16,6 @@ def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]
|
||||
from .generic import GenericWebhooksManager
|
||||
from .github import GithubWebhooksManager
|
||||
from .slant3d import Slant3DWebhooksManager
|
||||
from .exa import ExaWebhooksManager
|
||||
|
||||
_WEBHOOK_MANAGERS.update(
|
||||
{
|
||||
@@ -26,7 +25,6 @@ def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]
|
||||
GithubWebhooksManager,
|
||||
Slant3DWebhooksManager,
|
||||
GenericWebhooksManager,
|
||||
ExaWebhooksManager,
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,119 +0,0 @@
|
||||
import logging
|
||||
|
||||
import requests
|
||||
from fastapi import Request
|
||||
|
||||
from backend.data import integrations
|
||||
from backend.data.model import APIKeyCredentials, Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks._base import BaseWebhooksManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExaWebhooksManager(BaseWebhooksManager):
|
||||
"""Manager for Exa webhooks"""
|
||||
|
||||
PROVIDER_NAME = ProviderName.EXA
|
||||
BASE_URL = "https://api.exa.ai/websets/v0"
|
||||
|
||||
async def _register_webhook(
|
||||
self,
|
||||
credentials: Credentials,
|
||||
webhook_type: str,
|
||||
resource: str,
|
||||
events: list[str],
|
||||
ingress_url: str,
|
||||
secret: str,
|
||||
) -> tuple[str, dict]:
|
||||
"""Register a new webhook with Exa"""
|
||||
|
||||
if not isinstance(credentials, APIKeyCredentials):
|
||||
raise ValueError("API key is required to register a webhook")
|
||||
|
||||
headers = {
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
payload = {
|
||||
"events": events,
|
||||
"url": ingress_url,
|
||||
"metadata": {} # Optional metadata can be added here
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{self.BASE_URL}/webhooks", headers=headers, json=payload
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
error = response.json().get("error", "Unknown error")
|
||||
raise RuntimeError(f"Failed to register webhook: {error}")
|
||||
|
||||
response_data = response.json()
|
||||
webhook_id = response_data.get("id", "")
|
||||
|
||||
webhook_config = {
|
||||
"endpoint": ingress_url,
|
||||
"provider": self.PROVIDER_NAME,
|
||||
"events": events,
|
||||
"type": webhook_type,
|
||||
"webhook_id": webhook_id,
|
||||
"secret": response_data.get("secret", "")
|
||||
}
|
||||
|
||||
return webhook_id, webhook_config
|
||||
|
||||
@classmethod
|
||||
async def validate_payload(
|
||||
cls, webhook: integrations.Webhook, request: Request
|
||||
) -> tuple[dict, str]:
|
||||
"""Validate incoming webhook payload from Exa"""
|
||||
|
||||
payload = await request.json()
|
||||
|
||||
# Validate required fields from Exa API spec
|
||||
required_fields = ["id", "object", "type", "data", "createdAt"]
|
||||
missing_fields = [field for field in required_fields if field not in payload]
|
||||
|
||||
if missing_fields:
|
||||
raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")
|
||||
|
||||
# Normalize payload structure
|
||||
normalized_payload = {
|
||||
"id": payload["id"],
|
||||
"type": payload["type"],
|
||||
"data": payload["data"],
|
||||
"createdAt": payload["createdAt"]
|
||||
}
|
||||
|
||||
# Extract event type from the payload
|
||||
event_type = payload["type"]
|
||||
|
||||
return normalized_payload, event_type
|
||||
|
||||
async def _deregister_webhook(
|
||||
self, webhook: integrations.Webhook, credentials: Credentials
|
||||
) -> None:
|
||||
"""Deregister a webhook with Exa"""
|
||||
|
||||
if not isinstance(credentials, APIKeyCredentials):
|
||||
raise ValueError("API key is required to deregister a webhook")
|
||||
|
||||
webhook_id = webhook.config.get("webhook_id")
|
||||
if not webhook_id:
|
||||
logger.warning(f"No webhook ID found for webhook {webhook.id}, cannot deregister")
|
||||
return
|
||||
|
||||
headers = {
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
|
||||
response = requests.delete(
|
||||
f"{self.BASE_URL}/webhooks/{webhook_id}", headers=headers
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
error = response.json().get("error", "Unknown error")
|
||||
logger.error(f"Failed to deregister webhook {webhook_id}: {error}")
|
||||
raise RuntimeError(f"Failed to deregister webhook: {error}")
|
||||
@@ -1,6 +1,5 @@
|
||||
from .notifications import NotificationManager, NotificationManagerClient
|
||||
from .notifications import NotificationManager
|
||||
|
||||
__all__ = [
|
||||
"NotificationManager",
|
||||
"NotificationManagerClient",
|
||||
]
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Callable
|
||||
|
||||
@@ -8,18 +7,20 @@ import aio_pika
|
||||
from aio_pika.exceptions import QueueEmpty
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from prisma.enums import NotificationType
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data import rabbitmq
|
||||
from backend.data.notifications import (
|
||||
BaseEventModel,
|
||||
BaseSummaryData,
|
||||
BaseSummaryParams,
|
||||
DailySummaryData,
|
||||
DailySummaryParams,
|
||||
NotificationEventDTO,
|
||||
NotificationEventModel,
|
||||
NotificationResult,
|
||||
NotificationTypeOverride,
|
||||
QueueType,
|
||||
SummaryParamsEventDTO,
|
||||
SummaryParamsEventModel,
|
||||
WeeklySummaryData,
|
||||
WeeklySummaryParams,
|
||||
@@ -27,178 +28,96 @@ from backend.data.notifications import (
|
||||
get_notif_data_type,
|
||||
get_summary_params_type,
|
||||
)
|
||||
from backend.data.rabbitmq import (
|
||||
AsyncRabbitMQ,
|
||||
Exchange,
|
||||
ExchangeType,
|
||||
Queue,
|
||||
RabbitMQConfig,
|
||||
SyncRabbitMQ,
|
||||
)
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.data.user import generate_unsubscribe_link
|
||||
from backend.notifications.email import EmailSender
|
||||
from backend.util.metrics import discord_send_alert
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
expose,
|
||||
get_service_client,
|
||||
)
|
||||
from backend.util.service import AppService, expose, get_service_client
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
|
||||
NOTIFICATION_EXCHANGE = Exchange(name="notifications", type=ExchangeType.TOPIC)
|
||||
DEAD_LETTER_EXCHANGE = Exchange(name="dead_letter", type=ExchangeType.TOPIC)
|
||||
EXCHANGES = [NOTIFICATION_EXCHANGE, DEAD_LETTER_EXCHANGE]
|
||||
|
||||
background_executor = ThreadPoolExecutor(max_workers=2)
|
||||
class NotificationEvent(BaseModel):
|
||||
event: NotificationEventDTO
|
||||
model: NotificationEventModel
|
||||
|
||||
|
||||
def create_notification_config() -> RabbitMQConfig:
|
||||
"""Create RabbitMQ configuration for notifications"""
|
||||
notification_exchange = Exchange(name="notifications", type=ExchangeType.TOPIC)
|
||||
|
||||
dead_letter_exchange = Exchange(name="dead_letter", type=ExchangeType.TOPIC)
|
||||
|
||||
queues = [
|
||||
# Main notification queues
|
||||
Queue(
|
||||
name="immediate_notifications",
|
||||
exchange=NOTIFICATION_EXCHANGE,
|
||||
exchange=notification_exchange,
|
||||
routing_key="notification.immediate.#",
|
||||
arguments={
|
||||
"x-dead-letter-exchange": DEAD_LETTER_EXCHANGE.name,
|
||||
"x-dead-letter-exchange": dead_letter_exchange.name,
|
||||
"x-dead-letter-routing-key": "failed.immediate",
|
||||
},
|
||||
),
|
||||
Queue(
|
||||
name="admin_notifications",
|
||||
exchange=NOTIFICATION_EXCHANGE,
|
||||
exchange=notification_exchange,
|
||||
routing_key="notification.admin.#",
|
||||
arguments={
|
||||
"x-dead-letter-exchange": DEAD_LETTER_EXCHANGE.name,
|
||||
"x-dead-letter-exchange": dead_letter_exchange.name,
|
||||
"x-dead-letter-routing-key": "failed.admin",
|
||||
},
|
||||
),
|
||||
# Summary notification queues
|
||||
Queue(
|
||||
name="summary_notifications",
|
||||
exchange=NOTIFICATION_EXCHANGE,
|
||||
exchange=notification_exchange,
|
||||
routing_key="notification.summary.#",
|
||||
arguments={
|
||||
"x-dead-letter-exchange": DEAD_LETTER_EXCHANGE.name,
|
||||
"x-dead-letter-exchange": dead_letter_exchange.name,
|
||||
"x-dead-letter-routing-key": "failed.summary",
|
||||
},
|
||||
),
|
||||
# Batch Queue
|
||||
Queue(
|
||||
name="batch_notifications",
|
||||
exchange=NOTIFICATION_EXCHANGE,
|
||||
exchange=notification_exchange,
|
||||
routing_key="notification.batch.#",
|
||||
arguments={
|
||||
"x-dead-letter-exchange": DEAD_LETTER_EXCHANGE.name,
|
||||
"x-dead-letter-exchange": dead_letter_exchange.name,
|
||||
"x-dead-letter-routing-key": "failed.batch",
|
||||
},
|
||||
),
|
||||
# Failed notifications queue
|
||||
Queue(
|
||||
name="failed_notifications",
|
||||
exchange=DEAD_LETTER_EXCHANGE,
|
||||
exchange=dead_letter_exchange,
|
||||
routing_key="failed.#",
|
||||
),
|
||||
]
|
||||
|
||||
return RabbitMQConfig(
|
||||
exchanges=EXCHANGES,
|
||||
exchanges=[
|
||||
notification_exchange,
|
||||
dead_letter_exchange,
|
||||
],
|
||||
queues=queues,
|
||||
)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_scheduler():
|
||||
from backend.executor import Scheduler
|
||||
|
||||
return get_service_client(Scheduler)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_db():
|
||||
from backend.executor.database import DatabaseManagerClient
|
||||
from backend.executor.database import DatabaseManager
|
||||
|
||||
return get_service_client(DatabaseManagerClient)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_notification_queue() -> SyncRabbitMQ:
|
||||
client = SyncRabbitMQ(create_notification_config())
|
||||
client.connect()
|
||||
return client
|
||||
|
||||
|
||||
@thread_cached
|
||||
async def get_async_notification_queue() -> AsyncRabbitMQ:
|
||||
client = AsyncRabbitMQ(create_notification_config())
|
||||
await client.connect()
|
||||
return client
|
||||
|
||||
|
||||
def get_routing_key(event_type: NotificationType) -> str:
|
||||
strategy = NotificationTypeOverride(event_type).strategy
|
||||
"""Get the appropriate routing key for an event"""
|
||||
if strategy == QueueType.IMMEDIATE:
|
||||
return f"notification.immediate.{event_type.value}"
|
||||
elif strategy == QueueType.BACKOFF:
|
||||
return f"notification.backoff.{event_type.value}"
|
||||
elif strategy == QueueType.ADMIN:
|
||||
return f"notification.admin.{event_type.value}"
|
||||
elif strategy == QueueType.BATCH:
|
||||
return f"notification.batch.{event_type.value}"
|
||||
elif strategy == QueueType.SUMMARY:
|
||||
return f"notification.summary.{event_type.value}"
|
||||
return f"notification.{event_type.value}"
|
||||
|
||||
|
||||
def queue_notification(event: NotificationEventModel) -> NotificationResult:
|
||||
"""Queue a notification - exposed method for other services to call"""
|
||||
try:
|
||||
logger.debug(f"Received Request to queue {event=}")
|
||||
|
||||
exchange = "notifications"
|
||||
routing_key = get_routing_key(event.type)
|
||||
|
||||
queue = get_notification_queue()
|
||||
queue.publish_message(
|
||||
routing_key=routing_key,
|
||||
message=event.model_dump_json(),
|
||||
exchange=next(ex for ex in EXCHANGES if ex.name == exchange),
|
||||
)
|
||||
|
||||
return NotificationResult(
|
||||
success=True,
|
||||
message=f"Notification queued with routing key: {routing_key}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error queueing notification: {e}")
|
||||
return NotificationResult(success=False, message=str(e))
|
||||
|
||||
|
||||
async def queue_notification_async(event: NotificationEventModel) -> NotificationResult:
|
||||
"""Queue a notification - exposed method for other services to call"""
|
||||
try:
|
||||
logger.debug(f"Received Request to queue {event=}")
|
||||
|
||||
exchange = "notifications"
|
||||
routing_key = get_routing_key(event.type)
|
||||
|
||||
queue = await get_async_notification_queue()
|
||||
await queue.publish_message(
|
||||
routing_key=routing_key,
|
||||
message=event.model_dump_json(),
|
||||
exchange=next(ex for ex in EXCHANGES if ex.name == exchange),
|
||||
)
|
||||
|
||||
return NotificationResult(
|
||||
success=True,
|
||||
message=f"Notification queued with routing key: {routing_key}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error queueing notification: {e}")
|
||||
return NotificationResult(success=False, message=str(e))
|
||||
return get_service_client(DatabaseManager)
|
||||
|
||||
|
||||
class NotificationManager(AppService):
|
||||
@@ -228,11 +147,23 @@ class NotificationManager(AppService):
|
||||
def get_port(cls) -> int:
|
||||
return settings.config.notification_service_port
|
||||
|
||||
def get_routing_key(self, event_type: NotificationType) -> str:
|
||||
strategy = NotificationTypeOverride(event_type).strategy
|
||||
"""Get the appropriate routing key for an event"""
|
||||
if strategy == QueueType.IMMEDIATE:
|
||||
return f"notification.immediate.{event_type.value}"
|
||||
elif strategy == QueueType.BACKOFF:
|
||||
return f"notification.backoff.{event_type.value}"
|
||||
elif strategy == QueueType.ADMIN:
|
||||
return f"notification.admin.{event_type.value}"
|
||||
elif strategy == QueueType.BATCH:
|
||||
return f"notification.batch.{event_type.value}"
|
||||
elif strategy == QueueType.SUMMARY:
|
||||
return f"notification.summary.{event_type.value}"
|
||||
return f"notification.{event_type.value}"
|
||||
|
||||
@expose
|
||||
def queue_weekly_summary(self):
|
||||
background_executor.submit(self._queue_weekly_summary)
|
||||
|
||||
def _queue_weekly_summary(self):
|
||||
"""Process weekly summary for specified notification types"""
|
||||
try:
|
||||
logger.info("Processing weekly summary queuing operation")
|
||||
@@ -246,13 +177,13 @@ class NotificationManager(AppService):
|
||||
for user in users:
|
||||
|
||||
self._queue_scheduled_notification(
|
||||
SummaryParamsEventModel(
|
||||
SummaryParamsEventDTO(
|
||||
user_id=user,
|
||||
type=NotificationType.WEEKLY_SUMMARY,
|
||||
data=WeeklySummaryParams(
|
||||
start_date=start_time,
|
||||
end_date=current_time,
|
||||
),
|
||||
).model_dump(),
|
||||
),
|
||||
)
|
||||
processed_count += 1
|
||||
@@ -264,9 +195,6 @@ class NotificationManager(AppService):
|
||||
|
||||
@expose
|
||||
def process_existing_batches(self, notification_types: list[NotificationType]):
|
||||
background_executor.submit(self._process_existing_batches, notification_types)
|
||||
|
||||
def _process_existing_batches(self, notification_types: list[NotificationType]):
|
||||
"""Process existing batches for specified notification types"""
|
||||
try:
|
||||
processed_count = 0
|
||||
@@ -386,23 +314,65 @@ class NotificationManager(AppService):
|
||||
}
|
||||
|
||||
@expose
|
||||
def discord_system_alert(self, content: str):
|
||||
discord_send_alert(content)
|
||||
|
||||
def _queue_scheduled_notification(self, event: SummaryParamsEventModel):
|
||||
"""Queue a scheduled notification - exposed method for other services to call"""
|
||||
def queue_notification(self, event: NotificationEventDTO) -> NotificationResult:
|
||||
"""Queue a notification - exposed method for other services to call"""
|
||||
try:
|
||||
logger.debug(f"Received Request to queue scheduled notification {event=}")
|
||||
logger.info(f"Received Request to queue {event=}")
|
||||
# Workaround for not being able to serialize generics over the expose bus
|
||||
parsed_event = NotificationEventModel[
|
||||
get_notif_data_type(event.type)
|
||||
].model_validate(event.model_dump())
|
||||
routing_key = self.get_routing_key(parsed_event.type)
|
||||
message = parsed_event.model_dump_json()
|
||||
|
||||
logger.info(f"Received Request to queue {message=}")
|
||||
|
||||
exchange = "notifications"
|
||||
routing_key = get_routing_key(event.type)
|
||||
|
||||
# Publish to RabbitMQ
|
||||
self.run_and_wait(
|
||||
self.rabbit.publish_message(
|
||||
routing_key=routing_key,
|
||||
message=event.model_dump_json(),
|
||||
exchange=next(ex for ex in EXCHANGES if ex.name == exchange),
|
||||
message=message,
|
||||
exchange=next(
|
||||
ex for ex in self.rabbit_config.exchanges if ex.name == exchange
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return NotificationResult(
|
||||
success=True,
|
||||
message=f"Notification queued with routing key: {routing_key}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error queueing notification: {e}")
|
||||
return NotificationResult(success=False, message=str(e))
|
||||
|
||||
def _queue_scheduled_notification(self, event: SummaryParamsEventDTO):
|
||||
"""Queue a scheduled notification - exposed method for other services to call"""
|
||||
try:
|
||||
logger.info(f"Received Request to queue scheduled notification {event=}")
|
||||
|
||||
parsed_event = SummaryParamsEventModel[
|
||||
get_summary_params_type(event.type)
|
||||
].model_validate(event.model_dump())
|
||||
|
||||
routing_key = self.get_routing_key(event.type)
|
||||
message = parsed_event.model_dump_json()
|
||||
|
||||
logger.info(f"Received Request to queue {message=}")
|
||||
|
||||
exchange = "notifications"
|
||||
|
||||
# Publish to RabbitMQ
|
||||
self.run_and_wait(
|
||||
self.rabbit.publish_message(
|
||||
routing_key=routing_key,
|
||||
message=message,
|
||||
exchange=next(
|
||||
ex for ex in self.rabbit_config.exchanges if ex.name == exchange
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -528,12 +498,13 @@ class NotificationManager(AppService):
|
||||
)
|
||||
return False
|
||||
|
||||
def _parse_message(self, message: str) -> NotificationEventModel | None:
|
||||
def _parse_message(self, message: str) -> NotificationEvent | None:
|
||||
try:
|
||||
event = BaseEventModel.model_validate_json(message)
|
||||
return NotificationEventModel[
|
||||
event = NotificationEventDTO.model_validate_json(message)
|
||||
model = NotificationEventModel[
|
||||
get_notif_data_type(event.type)
|
||||
].model_validate_json(message)
|
||||
return NotificationEvent(event=event, model=model)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing message due to non matching schema {e}")
|
||||
return None
|
||||
@@ -541,12 +512,14 @@ class NotificationManager(AppService):
|
||||
def _process_admin_message(self, message: str) -> bool:
|
||||
"""Process a single notification, sending to an admin, returning whether to put into the failed queue"""
|
||||
try:
|
||||
event = self._parse_message(message)
|
||||
if not event:
|
||||
parsed = self._parse_message(message)
|
||||
if not parsed:
|
||||
return False
|
||||
logger.debug(f"Processing notification for admin: {event}")
|
||||
event = parsed.event
|
||||
model = parsed.model
|
||||
logger.debug(f"Processing notification for admin: {model}")
|
||||
recipient_email = settings.config.refund_notification_email
|
||||
self.email_sender.send_templated(event.type, recipient_email, event)
|
||||
self.email_sender.send_templated(event.type, recipient_email, model)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing notification for admin queue: {e}")
|
||||
@@ -555,10 +528,12 @@ class NotificationManager(AppService):
|
||||
def _process_immediate(self, message: str) -> bool:
|
||||
"""Process a single notification immediately, returning whether to put into the failed queue"""
|
||||
try:
|
||||
event = self._parse_message(message)
|
||||
if not event:
|
||||
parsed = self._parse_message(message)
|
||||
if not parsed:
|
||||
return False
|
||||
logger.debug(f"Processing immediate notification: {event}")
|
||||
event = parsed.event
|
||||
model = parsed.model
|
||||
logger.debug(f"Processing immediate notification: {model}")
|
||||
|
||||
recipient_email = get_db().get_user_email_by_id(event.user_id)
|
||||
if not recipient_email:
|
||||
@@ -579,7 +554,7 @@ class NotificationManager(AppService):
|
||||
self.email_sender.send_templated(
|
||||
notification=event.type,
|
||||
user_email=recipient_email,
|
||||
data=event,
|
||||
data=model,
|
||||
user_unsub_link=unsub_link,
|
||||
)
|
||||
return True
|
||||
@@ -590,10 +565,12 @@ class NotificationManager(AppService):
|
||||
def _process_batch(self, message: str) -> bool:
|
||||
"""Process a single notification with a batching strategy, returning whether to put into the failed queue"""
|
||||
try:
|
||||
event = self._parse_message(message)
|
||||
if not event:
|
||||
parsed = self._parse_message(message)
|
||||
if not parsed:
|
||||
return False
|
||||
logger.info(f"Processing batch notification: {event}")
|
||||
event = parsed.event
|
||||
model = parsed.model
|
||||
logger.info(f"Processing batch notification: {model}")
|
||||
|
||||
recipient_email = get_db().get_user_email_by_id(event.user_id)
|
||||
if not recipient_email:
|
||||
@@ -609,7 +586,7 @@ class NotificationManager(AppService):
|
||||
)
|
||||
return True
|
||||
|
||||
should_send = self._should_batch(event.user_id, event.type, event)
|
||||
should_send = self._should_batch(event.user_id, event.type, model)
|
||||
|
||||
if not should_send:
|
||||
logger.info("Batch not old enough to send")
|
||||
@@ -651,7 +628,7 @@ class NotificationManager(AppService):
|
||||
"""Process a single notification with a summary strategy, returning whether to put into the failed queue"""
|
||||
try:
|
||||
logger.info(f"Processing summary notification: {message}")
|
||||
event = BaseEventModel.model_validate_json(message)
|
||||
event = SummaryParamsEventDTO.model_validate_json(message)
|
||||
model = SummaryParamsEventModel[
|
||||
get_summary_params_type(event.type)
|
||||
].model_validate_json(message)
|
||||
@@ -732,6 +709,22 @@ class NotificationManager(AppService):
|
||||
|
||||
logger.info(f"[{self.service_name}] Started notification service")
|
||||
|
||||
# Set up scheduler for batch processing of all notification types
|
||||
# this can be changed later to spawn different cleanups on different schedules
|
||||
try:
|
||||
get_scheduler().add_batched_notification_schedule(
|
||||
notification_types=list(NotificationType),
|
||||
data={},
|
||||
cron="0 * * * *",
|
||||
)
|
||||
# get_scheduler().add_weekly_notification_schedule(
|
||||
# # weekly on Friday at 12pm
|
||||
# cron="0 12 * * 5",
|
||||
# )
|
||||
logger.info("Scheduled notification cleanup")
|
||||
except Exception as e:
|
||||
logger.error(f"Error scheduling notification cleanup: {e}")
|
||||
|
||||
# Set up queue consumers
|
||||
channel = self.run_and_wait(self.rabbit.get_channel())
|
||||
|
||||
@@ -781,13 +774,3 @@ class NotificationManager(AppService):
|
||||
super().cleanup()
|
||||
logger.info(f"[{self.service_name}] ⏳ Disconnecting RabbitMQ...")
|
||||
self.run_and_wait(self.rabbitmq_service.disconnect())
|
||||
|
||||
|
||||
class NotificationManagerClient(AppServiceClient):
|
||||
@classmethod
|
||||
def get_service_type(cls):
|
||||
return NotificationManager
|
||||
|
||||
process_existing_batches = NotificationManager.process_existing_batches
|
||||
queue_weekly_summary = NotificationManager.queue_weekly_summary
|
||||
discord_system_alert = NotificationManager.discord_system_alert
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from backend.app import run_processes
|
||||
from backend.executor import DatabaseManager
|
||||
from backend.executor import DatabaseManager, Scheduler
|
||||
from backend.notifications.notifications import NotificationManager
|
||||
from backend.server.rest_api import AgentServer
|
||||
|
||||
@@ -11,6 +11,7 @@ def main():
|
||||
run_processes(
|
||||
NotificationManager(),
|
||||
DatabaseManager(),
|
||||
Scheduler(),
|
||||
AgentServer(),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
from backend.app import run_processes
|
||||
from backend.executor.scheduler import Scheduler
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Run all the processes required for the AutoGPT-server Scheduling System.
|
||||
"""
|
||||
run_processes(Scheduler())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -122,7 +122,7 @@ async def get_graph_execution_results(
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
results = await execution_db.get_node_executions(graph_exec_id)
|
||||
results = await execution_db.get_node_execution_results(graph_exec_id)
|
||||
last_result = results[-1] if results else None
|
||||
execution_status = (
|
||||
last_result.status if last_result else AgentExecutionStatus.INCOMPLETE
|
||||
|
||||
@@ -19,7 +19,6 @@ import backend.data.graph
|
||||
import backend.data.user
|
||||
import backend.server.routers.postmark.postmark
|
||||
import backend.server.routers.v1
|
||||
import backend.server.v2.admin.credit_admin_routes
|
||||
import backend.server.v2.admin.store_admin_routes
|
||||
import backend.server.v2.library.db
|
||||
import backend.server.v2.library.model
|
||||
@@ -27,7 +26,6 @@ import backend.server.v2.library.routes
|
||||
import backend.server.v2.otto.routes
|
||||
import backend.server.v2.store.model
|
||||
import backend.server.v2.store.routes
|
||||
import backend.server.v2.turnstile.routes
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.blocks.llm import LlmModel
|
||||
@@ -109,20 +107,12 @@ app.include_router(
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/store",
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.admin.credit_admin_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/credits",
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.library.routes.router, tags=["v2"], prefix="/api/library"
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.otto.routes.router, tags=["v2"], prefix="/api/otto"
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.turnstile.routes.router, tags=["v2"], prefix="/api/turnstile"
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
backend.server.routers.postmark.postmark.router,
|
||||
|
||||
@@ -57,7 +57,7 @@ from backend.data.user import (
|
||||
update_user_email,
|
||||
update_user_notification_preference,
|
||||
)
|
||||
from backend.executor import scheduler
|
||||
from backend.executor import Scheduler, scheduler
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.executor.utils import create_execution_queue_config
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
@@ -83,8 +83,8 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
@thread_cached
|
||||
def execution_scheduler_client() -> scheduler.SchedulerClient:
|
||||
return get_service_client(scheduler.SchedulerClient)
|
||||
def execution_scheduler_client() -> Scheduler:
|
||||
return get_service_client(Scheduler)
|
||||
|
||||
|
||||
@thread_cached
|
||||
@@ -422,11 +422,7 @@ async def get_graph(
|
||||
for_export: bool = False,
|
||||
) -> graph_db.GraphModel:
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id,
|
||||
version,
|
||||
user_id=user_id,
|
||||
for_export=for_export,
|
||||
include_subgraphs=True, # needed to construct full credentials input schema
|
||||
graph_id, version, user_id=user_id, for_export=for_export
|
||||
)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
@@ -667,7 +663,7 @@ async def _cancel_execution(graph_exec_id: str):
|
||||
)
|
||||
node_execs = [
|
||||
node_exec.model_copy(update={"status": execution_db.ExecutionStatus.TERMINATED})
|
||||
for node_exec in await execution_db.get_node_executions(
|
||||
for node_exec in await execution_db.get_node_execution_results(
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=[
|
||||
execution_db.ExecutionStatus.QUEUED,
|
||||
@@ -773,7 +769,7 @@ class ScheduleCreationRequest(pydantic.BaseModel):
|
||||
async def create_schedule(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
schedule: ScheduleCreationRequest,
|
||||
) -> scheduler.GraphExecutionJobInfo:
|
||||
) -> scheduler.ExecutionJobInfo:
|
||||
graph = await graph_db.get_graph(
|
||||
schedule.graph_id, schedule.graph_version, user_id=user_id
|
||||
)
|
||||
@@ -783,12 +779,14 @@ async def create_schedule(
|
||||
detail=f"Graph #{schedule.graph_id} v.{schedule.graph_version} not found.",
|
||||
)
|
||||
|
||||
return await execution_scheduler_client().add_execution_schedule(
|
||||
graph_id=schedule.graph_id,
|
||||
graph_version=graph.version,
|
||||
cron=schedule.cron,
|
||||
input_data=schedule.input_data,
|
||||
user_id=user_id,
|
||||
return await asyncio.to_thread(
|
||||
lambda: execution_scheduler_client().add_execution_schedule(
|
||||
graph_id=schedule.graph_id,
|
||||
graph_version=graph.version,
|
||||
cron=schedule.cron,
|
||||
input_data=schedule.input_data,
|
||||
user_id=user_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -797,11 +795,11 @@ async def create_schedule(
|
||||
tags=["schedules"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def delete_schedule(
|
||||
def delete_schedule(
|
||||
schedule_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> dict[Any, Any]:
|
||||
await execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
|
||||
execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
|
||||
return {"id": schedule_id}
|
||||
|
||||
|
||||
@@ -810,11 +808,11 @@ async def delete_schedule(
|
||||
tags=["schedules"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_execution_schedules(
|
||||
def get_execution_schedules(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
graph_id: str | None = None,
|
||||
) -> list[scheduler.GraphExecutionJobInfo]:
|
||||
return await execution_scheduler_client().get_execution_schedules(
|
||||
) -> list[scheduler.ExecutionJobInfo]:
|
||||
return execution_scheduler_client().get_execution_schedules(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
)
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
import logging
|
||||
import typing
|
||||
|
||||
from autogpt_libs.auth import requires_admin_user
|
||||
from autogpt_libs.auth.depends import get_user_id
|
||||
from fastapi import APIRouter, Body, Depends
|
||||
from prisma import Json
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import admin_get_user_history, get_user_credit_model
|
||||
from backend.server.v2.admin.model import AddUserCreditsResponse, UserHistoryResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_user_credit_model = get_user_credit_model()
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/admin",
|
||||
tags=["credits", "admin"],
|
||||
dependencies=[Depends(requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/add_credits", response_model=AddUserCreditsResponse)
|
||||
async def add_user_credits(
|
||||
user_id: typing.Annotated[str, Body()],
|
||||
amount: typing.Annotated[int, Body()],
|
||||
comments: typing.Annotated[str, Body()],
|
||||
admin_user: typing.Annotated[
|
||||
str,
|
||||
Depends(get_user_id),
|
||||
],
|
||||
):
|
||||
""" """
|
||||
logger.info(f"Admin user {admin_user} is adding {amount} credits to user {user_id}")
|
||||
new_balance, transaction_key = await _user_credit_model._add_transaction(
|
||||
user_id,
|
||||
amount,
|
||||
transaction_type=CreditTransactionType.GRANT,
|
||||
metadata=Json({"admin_id": admin_user, "reason": comments}),
|
||||
)
|
||||
return {
|
||||
"new_balance": new_balance,
|
||||
"transaction_key": transaction_key,
|
||||
}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/users_history",
|
||||
response_model=UserHistoryResponse,
|
||||
)
|
||||
async def admin_get_all_user_history(
|
||||
admin_user: typing.Annotated[
|
||||
str,
|
||||
Depends(get_user_id),
|
||||
],
|
||||
search: typing.Optional[str] = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
transaction_filter: typing.Optional[CreditTransactionType] = None,
|
||||
):
|
||||
""" """
|
||||
logger.info(f"Admin user {admin_user} is getting grant history")
|
||||
|
||||
try:
|
||||
resp = await admin_get_user_history(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
search=search,
|
||||
transaction_filter=transaction_filter,
|
||||
)
|
||||
logger.info(f"Admin user {admin_user} got {len(resp.history)} grant history")
|
||||
return resp
|
||||
except Exception as e:
|
||||
logger.exception(f"Error getting grant history: {e}")
|
||||
raise e
|
||||
@@ -1,16 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import UserTransaction
|
||||
from backend.server.model import Pagination
|
||||
|
||||
|
||||
class UserHistoryResponse(BaseModel):
|
||||
"""Response model for listings with version history"""
|
||||
|
||||
history: list[UserTransaction]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class AddUserCreditsResponse(BaseModel):
|
||||
new_balance: int
|
||||
transaction_key: str
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import tempfile
|
||||
import typing
|
||||
|
||||
import autogpt_libs.auth.depends
|
||||
@@ -10,7 +9,6 @@ import prisma.enums
|
||||
import backend.server.v2.store.db
|
||||
import backend.server.v2.store.exceptions
|
||||
import backend.server.v2.store.model
|
||||
import backend.util.json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -100,47 +98,3 @@ async def review_submission(
|
||||
status_code=500,
|
||||
content={"detail": "An error occurred while reviewing the submission"},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/submissions/download/{store_listing_version_id}",
|
||||
tags=["store", "admin"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
|
||||
)
|
||||
async def admin_download_agent_file(
|
||||
user: typing.Annotated[
|
||||
autogpt_libs.auth.models.User,
|
||||
fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user),
|
||||
],
|
||||
store_listing_version_id: str = fastapi.Path(
|
||||
..., description="The ID of the agent to download"
|
||||
),
|
||||
) -> fastapi.responses.FileResponse:
|
||||
"""
|
||||
Download the agent file by streaming its content.
|
||||
|
||||
Args:
|
||||
store_listing_version_id (str): The ID of the agent to download
|
||||
|
||||
Returns:
|
||||
StreamingResponse: A streaming response containing the agent's graph data.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the agent is not found or an unexpected error occurs.
|
||||
"""
|
||||
graph_data = await backend.server.v2.store.db.get_agent(
|
||||
user_id=user.user_id,
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
|
||||
|
||||
# Sending graph as a stream (similar to marketplace v1)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False
|
||||
) as tmp_file:
|
||||
tmp_file.write(backend.util.json.dumps(graph_data))
|
||||
tmp_file.flush()
|
||||
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
|
||||
@@ -13,17 +13,12 @@ import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.exceptions as store_exceptions
|
||||
import backend.server.v2.store.image_gen as store_image_gen
|
||||
import backend.server.v2.store.media as store_media
|
||||
from backend.data import db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.db import locked_transaction
|
||||
from backend.data.includes import library_agent_include
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
|
||||
from backend.util.settings import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = Config()
|
||||
integration_creds_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
async def list_library_agents(
|
||||
@@ -175,44 +170,6 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
|
||||
raise store_exceptions.DatabaseError("Failed to fetch library agent") from e
|
||||
|
||||
|
||||
async def get_library_agent_by_store_version_id(
|
||||
store_listing_version_id: str,
|
||||
user_id: str,
|
||||
):
|
||||
"""
|
||||
Get the library agent metadata for a given store listing version ID and user ID.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Getting library agent for store listing ID: {store_listing_version_id}"
|
||||
)
|
||||
|
||||
store_listing_version = (
|
||||
await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id},
|
||||
)
|
||||
)
|
||||
if not store_listing_version:
|
||||
logger.warning(f"Store listing version not found: {store_listing_version_id}")
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
f"Store listing version {store_listing_version_id} not found or invalid"
|
||||
)
|
||||
|
||||
# Check if user already has this agent
|
||||
agent = await prisma.models.LibraryAgent.prisma().find_first(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"agentGraphId": store_listing_version.agentGraphId,
|
||||
"agentGraphVersion": store_listing_version.agentGraphVersion,
|
||||
"isDeleted": False,
|
||||
},
|
||||
include={"AgentGraph": True},
|
||||
)
|
||||
if agent:
|
||||
return library_model.LibraryAgent.from_db(agent)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
async def add_generated_agent_image(
|
||||
graph: backend.data.graph.GraphModel,
|
||||
library_agent_id: str,
|
||||
@@ -249,7 +206,7 @@ async def add_generated_agent_image(
|
||||
async def create_library_agent(
|
||||
graph: backend.data.graph.GraphModel,
|
||||
user_id: str,
|
||||
) -> library_model.LibraryAgent:
|
||||
) -> prisma.models.LibraryAgent:
|
||||
"""
|
||||
Adds an agent to the user's library (LibraryAgent table).
|
||||
|
||||
@@ -270,7 +227,7 @@ async def create_library_agent(
|
||||
)
|
||||
|
||||
try:
|
||||
agent = await prisma.models.LibraryAgent.prisma().create(
|
||||
return await prisma.models.LibraryAgent.prisma().create(
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
isCreatedByUser=(user_id == graph.user_id),
|
||||
useGraphIsActiveVersion=True,
|
||||
@@ -281,10 +238,8 @@ async def create_library_agent(
|
||||
"graphVersionId": {"id": graph.id, "version": graph.version}
|
||||
}
|
||||
},
|
||||
),
|
||||
include={"AgentGraph": True},
|
||||
)
|
||||
)
|
||||
return library_model.LibraryAgent.from_db(agent)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error creating agent in library: {e}")
|
||||
raise store_exceptions.DatabaseError("Failed to create agent in library") from e
|
||||
@@ -435,6 +390,11 @@ async def add_store_agent_to_library(
|
||||
)
|
||||
|
||||
graph = store_listing_version.AgentGraph
|
||||
if graph.userId == user_id:
|
||||
logger.warning(
|
||||
f"User #{user_id} attempted to add their own agent to their library"
|
||||
)
|
||||
raise store_exceptions.DatabaseError("Cannot add own agent to library")
|
||||
|
||||
# Check if user already has this agent
|
||||
existing_library_agent = (
|
||||
@@ -444,7 +404,7 @@ async def add_store_agent_to_library(
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
},
|
||||
include={"AgentGraph": True},
|
||||
include=library_agent_include(user_id),
|
||||
)
|
||||
)
|
||||
if existing_library_agent:
|
||||
@@ -702,47 +662,3 @@ async def delete_preset(user_id: str, preset_id: str) -> None:
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error deleting preset: {e}")
|
||||
raise store_exceptions.DatabaseError("Failed to delete preset") from e
|
||||
|
||||
|
||||
async def fork_library_agent(library_agent_id: str, user_id: str):
|
||||
"""
|
||||
Clones a library agent and its underyling graph and nodes (with new ids) for the given user.
|
||||
|
||||
Args:
|
||||
library_agent_id: The ID of the library agent to fork.
|
||||
user_id: The ID of the user who owns the library agent.
|
||||
|
||||
Returns:
|
||||
The forked LibraryAgent.
|
||||
|
||||
Raises:
|
||||
DatabaseError: If there's an error during the forking process.
|
||||
"""
|
||||
logger.debug(f"Forking library agent {library_agent_id} for user {user_id}")
|
||||
try:
|
||||
async with db.locked_transaction(f"usr_trx_{user_id}-fork_agent"):
|
||||
# Fetch the original agent
|
||||
original_agent = await get_library_agent(library_agent_id, user_id)
|
||||
|
||||
# Check if user owns the library agent
|
||||
# TODO: once we have open/closed sourced agents this needs to be enabled ~kcze
|
||||
# + update library/agents/[id]/page.tsx agent actions
|
||||
# if not original_agent.can_access_graph:
|
||||
# raise store_exceptions.DatabaseError(
|
||||
# f"User {user_id} cannot access library agent graph {library_agent_id}"
|
||||
# )
|
||||
|
||||
# Fork the underlying graph and nodes
|
||||
new_graph = await graph_db.fork_graph(
|
||||
original_agent.graph_id, original_agent.graph_version, user_id
|
||||
)
|
||||
new_graph = await on_graph_activate(
|
||||
new_graph,
|
||||
get_credentials=lambda id: integration_creds_manager.get(user_id, id),
|
||||
)
|
||||
|
||||
# Create a library agent for the new graph
|
||||
return await create_library_agent(new_graph, user_id)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error cloning library agent: {e}")
|
||||
raise store_exceptions.DatabaseError("Failed to fork library agent") from e
|
||||
|
||||
@@ -22,13 +22,11 @@ async def test_agent_preset_from_db():
|
||||
userId="test-user-123",
|
||||
isDeleted=False,
|
||||
InputPresets=[
|
||||
prisma.models.AgentNodeExecutionInputOutput.model_validate(
|
||||
{
|
||||
"id": "input-123",
|
||||
"time": datetime.datetime.now(),
|
||||
"name": "input1",
|
||||
"data": '{"type": "string", "value": "test value"}',
|
||||
}
|
||||
prisma.models.AgentNodeExecutionInputOutput(
|
||||
id="input-123",
|
||||
time=datetime.datetime.now(),
|
||||
name="input1",
|
||||
data=prisma.Json({"type": "string", "value": "test value"}),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@@ -85,30 +85,6 @@ async def get_library_agent(
|
||||
return await library_db.get_library_agent(id=library_agent_id, user_id=user_id)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/marketplace/{store_listing_version_id}/",
|
||||
tags=["store, library"],
|
||||
response_model=library_model.LibraryAgent | None,
|
||||
)
|
||||
async def get_library_agent_by_store_listing_version_id(
|
||||
store_listing_version_id: str,
|
||||
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
|
||||
):
|
||||
"""
|
||||
Get Library Agent from Store Listing Version ID.
|
||||
"""
|
||||
try:
|
||||
return await library_db.get_library_agent_by_store_version_id(
|
||||
store_listing_version_id, user_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not fetch library agent from store version ID: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to add agent to library",
|
||||
) from e
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
@@ -214,14 +190,3 @@ async def update_library_agent(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update library agent",
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/{library_agent_id}/fork")
|
||||
async def fork_library_agent(
|
||||
library_agent_id: str,
|
||||
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
|
||||
) -> library_model.LibraryAgent:
|
||||
return await library_db.fork_library_agent(
|
||||
library_agent_id=library_agent_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
@@ -793,7 +793,6 @@ async def create_store_version(
|
||||
changes_summary=changes_summary,
|
||||
version=next_version,
|
||||
)
|
||||
|
||||
except prisma.errors.PrismaError as e:
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to create new store version"
|
||||
@@ -967,7 +966,7 @@ async def get_my_agents(
|
||||
|
||||
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
|
||||
where=search_filter,
|
||||
order=[{"updatedAt": "desc"}],
|
||||
order=[{"agentGraphVersion": "desc"}],
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
include={"AgentGraph": True},
|
||||
@@ -1362,31 +1361,3 @@ async def get_admin_listings_with_versions(
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def get_agent_as_admin(
|
||||
user_id: str | None,
|
||||
store_listing_version_id: str,
|
||||
) -> GraphModel:
|
||||
"""Get agent using the version ID and store listing version ID."""
|
||||
store_listing_version = (
|
||||
await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id}
|
||||
)
|
||||
)
|
||||
|
||||
if not store_listing_version:
|
||||
raise ValueError(f"Store listing version {store_listing_version_id} not found")
|
||||
|
||||
graph = await backend.data.graph.get_graph_as_admin(
|
||||
user_id=user_id,
|
||||
graph_id=store_listing_version.agentGraphId,
|
||||
version=store_listing_version.agentGraphVersion,
|
||||
for_export=True,
|
||||
)
|
||||
if not graph:
|
||||
raise ValueError(
|
||||
f"Agent {store_listing_version.agentGraphId} v{store_listing_version.agentGraphVersion} not found"
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
@@ -4,7 +4,20 @@ from typing import List
|
||||
import prisma.enums
|
||||
import pydantic
|
||||
|
||||
from backend.server.model import Pagination
|
||||
|
||||
class Pagination(pydantic.BaseModel):
|
||||
total_items: int = pydantic.Field(
|
||||
description="Total number of items.", examples=[42]
|
||||
)
|
||||
total_pages: int = pydantic.Field(
|
||||
description="Total number of pages.", examples=[97]
|
||||
)
|
||||
current_page: int = pydantic.Field(
|
||||
description="Current_page page number.", examples=[1]
|
||||
)
|
||||
page_size: int = pydantic.Field(
|
||||
description="Number of items per page.", examples=[25]
|
||||
)
|
||||
|
||||
|
||||
class MyAgent(pydantic.BaseModel):
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TurnstileVerifyRequest(BaseModel):
|
||||
"""Request model for verifying a Turnstile token."""
|
||||
|
||||
token: str = Field(description="The Turnstile token to verify")
|
||||
action: Optional[str] = Field(
|
||||
default=None, description="The action that the user is attempting to perform"
|
||||
)
|
||||
|
||||
|
||||
class TurnstileVerifyResponse(BaseModel):
|
||||
"""Response model for the Turnstile verification endpoint."""
|
||||
|
||||
success: bool = Field(description="Whether the token verification was successful")
|
||||
error: Optional[str] = Field(
|
||||
default=None, description="Error message if verification failed"
|
||||
)
|
||||
challenge_timestamp: Optional[str] = Field(
|
||||
default=None, description="Timestamp of the challenge (ISO format)"
|
||||
)
|
||||
hostname: Optional[str] = Field(
|
||||
default=None, description="Hostname of the site where the challenge was solved"
|
||||
)
|
||||
action: Optional[str] = Field(
|
||||
default=None, description="The action associated with this verification"
|
||||
)
|
||||
@@ -1,108 +0,0 @@
|
||||
import logging
|
||||
|
||||
import aiohttp
|
||||
from fastapi import APIRouter
|
||||
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .models import TurnstileVerifyRequest, TurnstileVerifyResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
settings = Settings()
|
||||
|
||||
|
||||
@router.post("/verify", response_model=TurnstileVerifyResponse)
|
||||
async def verify_turnstile_token(
|
||||
request: TurnstileVerifyRequest,
|
||||
) -> TurnstileVerifyResponse:
|
||||
"""
|
||||
Verify a Cloudflare Turnstile token.
|
||||
This endpoint verifies a token returned by the Cloudflare Turnstile challenge
|
||||
on the client side. It returns whether the verification was successful.
|
||||
"""
|
||||
logger.info(f"Verifying Turnstile token for action: {request.action}")
|
||||
return await verify_token(request)
|
||||
|
||||
|
||||
async def verify_token(request: TurnstileVerifyRequest) -> TurnstileVerifyResponse:
|
||||
"""
|
||||
Verify a Cloudflare Turnstile token by making a request to the Cloudflare API.
|
||||
"""
|
||||
# Get the secret key from settings
|
||||
turnstile_secret_key = settings.secrets.turnstile_secret_key
|
||||
turnstile_verify_url = settings.secrets.turnstile_verify_url
|
||||
|
||||
if not turnstile_secret_key:
|
||||
logger.error("Turnstile secret key is not configured")
|
||||
return TurnstileVerifyResponse(
|
||||
success=False,
|
||||
error="CONFIGURATION_ERROR",
|
||||
challenge_timestamp=None,
|
||||
hostname=None,
|
||||
action=None,
|
||||
)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
payload = {
|
||||
"secret": turnstile_secret_key,
|
||||
"response": request.token,
|
||||
}
|
||||
|
||||
if request.action:
|
||||
payload["action"] = request.action
|
||||
|
||||
logger.debug(f"Verifying Turnstile token with action: {request.action}")
|
||||
|
||||
async with session.post(
|
||||
turnstile_verify_url,
|
||||
data=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=10),
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"Turnstile API error: {error_text}")
|
||||
return TurnstileVerifyResponse(
|
||||
success=False,
|
||||
error=f"API_ERROR: {response.status}",
|
||||
challenge_timestamp=None,
|
||||
hostname=None,
|
||||
action=None,
|
||||
)
|
||||
|
||||
data = await response.json()
|
||||
logger.debug(f"Turnstile API response: {data}")
|
||||
|
||||
# Parse the response and return a structured object
|
||||
return TurnstileVerifyResponse(
|
||||
success=data.get("success", False),
|
||||
error=(
|
||||
data.get("error-codes", None)[0]
|
||||
if data.get("error-codes")
|
||||
else None
|
||||
),
|
||||
challenge_timestamp=data.get("challenge_timestamp"),
|
||||
hostname=data.get("hostname"),
|
||||
action=data.get("action"),
|
||||
)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"Connection error to Turnstile API: {str(e)}")
|
||||
return TurnstileVerifyResponse(
|
||||
success=False,
|
||||
error=f"CONNECTION_ERROR: {str(e)}",
|
||||
challenge_timestamp=None,
|
||||
hostname=None,
|
||||
action=None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in Turnstile verification: {str(e)}")
|
||||
return TurnstileVerifyResponse(
|
||||
success=False,
|
||||
error=f"UNEXPECTED_ERROR: {str(e)}",
|
||||
challenge_timestamp=None,
|
||||
hostname=None,
|
||||
action=None,
|
||||
)
|
||||
@@ -6,6 +6,7 @@ from typing import Protocol
|
||||
import uvicorn
|
||||
from autogpt_libs.auth import parse_jwt_token
|
||||
from autogpt_libs.logging.utils import generate_uvicorn_config
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
@@ -18,7 +19,7 @@ from backend.server.model import (
|
||||
WSSubscribeGraphExecutionRequest,
|
||||
WSSubscribeGraphExecutionsRequest,
|
||||
)
|
||||
from backend.util.service import AppProcess
|
||||
from backend.util.service import AppProcess, get_service_client
|
||||
from backend.util.settings import AppEnvironment, Config, Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -45,6 +46,13 @@ def get_connection_manager():
|
||||
return _connection_manager
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_db_client():
|
||||
from backend.executor import DatabaseManager
|
||||
|
||||
return get_service_client(DatabaseManager)
|
||||
|
||||
|
||||
async def event_broadcaster(manager: ConnectionManager):
|
||||
try:
|
||||
event_queue = AsyncRedisExecutionEventBus()
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import sentry_sdk
|
||||
from pydantic import SecretStr
|
||||
from sentry_sdk.integrations.anthropic import AnthropicIntegration
|
||||
from sentry_sdk.integrations.logging import LoggingIntegration
|
||||
|
||||
@@ -24,43 +22,3 @@ def sentry_init():
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def sentry_capture_error(error: Exception):
|
||||
sentry_sdk.capture_exception(error)
|
||||
sentry_sdk.flush()
|
||||
|
||||
|
||||
def discord_send_alert(content: str):
|
||||
from backend.blocks.discord import SendDiscordMessageBlock
|
||||
from backend.data.model import APIKeyCredentials, CredentialsMetaInput, ProviderName
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
creds = APIKeyCredentials(
|
||||
provider="discord",
|
||||
api_key=SecretStr(settings.secrets.discord_bot_token),
|
||||
title="Provide Discord Bot Token for the platform alert",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
return SendDiscordMessageBlock().run_once(
|
||||
SendDiscordMessageBlock.Input(
|
||||
credentials=CredentialsMetaInput(
|
||||
id=creds.id,
|
||||
title=creds.title,
|
||||
type=creds.type,
|
||||
provider=ProviderName.DISCORD,
|
||||
),
|
||||
message_content=content,
|
||||
channel_name=settings.config.platform_alert_discord_channel,
|
||||
),
|
||||
"status",
|
||||
credentials=creds,
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import signal
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from multiprocessing import Process, get_all_start_methods, set_start_method
|
||||
from multiprocessing import Process, set_start_method
|
||||
from typing import Optional
|
||||
|
||||
from backend.util.logging import configure_logging
|
||||
@@ -30,12 +30,7 @@ class AppProcess(ABC):
|
||||
process: Optional[Process] = None
|
||||
cleaned_up = False
|
||||
|
||||
if "forkserver" in get_all_start_methods():
|
||||
set_start_method("forkserver", force=True)
|
||||
else:
|
||||
logger.warning("Forkserver start method is not available. Using spawn instead.")
|
||||
set_start_method("spawn", force=True)
|
||||
|
||||
set_start_method("spawn", force=True)
|
||||
configure_logging()
|
||||
sentry_init()
|
||||
|
||||
|
||||
@@ -73,10 +73,3 @@ def conn_retry(
|
||||
return async_wrapper if is_coroutine else sync_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
func_retry = retry(
|
||||
reraise=False,
|
||||
stop=stop_after_attempt(5),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=30),
|
||||
)
|
||||
|
||||
@@ -5,10 +5,8 @@ import os
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property, update_wrapper
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Concatenate,
|
||||
Coroutine,
|
||||
@@ -44,15 +42,24 @@ api_call_timeout = config.rpc_client_call_timeout
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
EXPOSED_FLAG = "__exposed__"
|
||||
|
||||
|
||||
def expose(func: C) -> C:
|
||||
func = getattr(func, "__func__", func)
|
||||
setattr(func, EXPOSED_FLAG, True)
|
||||
setattr(func, "__exposed__", True)
|
||||
return func
|
||||
|
||||
|
||||
def exposed_run_and_wait(
|
||||
f: Callable[P, Coroutine[None, None, R]]
|
||||
) -> Callable[Concatenate[object, P], R]:
|
||||
# TODO:
|
||||
# This function lies about its return type to make the DynamicClient
|
||||
# call the function synchronously, fix this when DynamicClient can choose
|
||||
# to call a function synchronously or asynchronously.
|
||||
return expose(f) # type: ignore
|
||||
|
||||
|
||||
# --------------------------------------------------
|
||||
# AppService for IPC service based on HTTP request through FastAPI
|
||||
# --------------------------------------------------
|
||||
@@ -196,7 +203,7 @@ class AppService(BaseAppService, ABC):
|
||||
|
||||
# Register the exposed API routes.
|
||||
for attr_name, attr in vars(type(self)).items():
|
||||
if getattr(attr, EXPOSED_FLAG, False):
|
||||
if getattr(attr, "__exposed__", False):
|
||||
route_path = f"/{attr_name}"
|
||||
self.fastapi_app.add_api_route(
|
||||
route_path,
|
||||
@@ -227,53 +234,31 @@ class AppService(BaseAppService, ABC):
|
||||
AS = TypeVar("AS", bound=AppService)
|
||||
|
||||
|
||||
class AppServiceClient(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_service_type(cls) -> Type[AppService]:
|
||||
pass
|
||||
|
||||
def health_check(self):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
def close_service_client(client: Any) -> None:
|
||||
if hasattr(client, "close"):
|
||||
client.close()
|
||||
else:
|
||||
logger.warning(f"Client {client} is not closable")
|
||||
|
||||
|
||||
ASC = TypeVar("ASC", bound=AppServiceClient)
|
||||
|
||||
|
||||
@conn_retry("AppService client", "Creating service client", max_retry=api_comm_retry)
|
||||
@conn_retry("FastAPI client", "Creating service client", max_retry=api_comm_retry)
|
||||
def get_service_client(
|
||||
service_client_type: Type[ASC],
|
||||
service_type: Type[AS],
|
||||
call_timeout: int | None = api_call_timeout,
|
||||
health_check: bool = True,
|
||||
) -> ASC:
|
||||
) -> AS:
|
||||
class DynamicClient:
|
||||
def __init__(self):
|
||||
service_type = service_client_type.get_service_type()
|
||||
host = service_type.get_host()
|
||||
port = service_type.get_port()
|
||||
self.base_url = f"http://{host}:{port}".rstrip("/")
|
||||
|
||||
@cached_property
|
||||
def sync_client(self) -> httpx.Client:
|
||||
return httpx.Client(
|
||||
self.client = httpx.Client(
|
||||
base_url=self.base_url,
|
||||
timeout=call_timeout,
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def async_client(self) -> httpx.AsyncClient:
|
||||
return httpx.AsyncClient(
|
||||
base_url=self.base_url,
|
||||
timeout=call_timeout,
|
||||
)
|
||||
|
||||
def _handle_call_method_response(
|
||||
self, response: httpx.Response, method_name: str
|
||||
) -> Any:
|
||||
def _call_method(self, method_name: str, **kwargs) -> Any:
|
||||
try:
|
||||
response = self.client.post(method_name, json=to_dict(kwargs))
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
@@ -284,103 +269,36 @@ def get_service_client(
|
||||
*(error.args or [str(e)])
|
||||
)
|
||||
|
||||
def _call_method_sync(self, method_name: str, **kwargs) -> Any:
|
||||
return self._handle_call_method_response(
|
||||
method_name=method_name,
|
||||
response=self.sync_client.post(method_name, json=to_dict(kwargs)),
|
||||
)
|
||||
|
||||
async def _call_method_async(self, method_name: str, **kwargs) -> Any:
|
||||
return self._handle_call_method_response(
|
||||
method_name=method_name,
|
||||
response=await self.async_client.post(
|
||||
method_name, json=to_dict(kwargs)
|
||||
),
|
||||
)
|
||||
|
||||
async def aclose(self):
|
||||
self.sync_client.close()
|
||||
await self.async_client.aclose()
|
||||
|
||||
def close(self):
|
||||
self.sync_client.close()
|
||||
|
||||
def _get_params(self, signature: inspect.Signature, *args, **kwargs) -> dict:
|
||||
if args:
|
||||
arg_names = list(signature.parameters.keys())
|
||||
if arg_names[0] in ("self", "cls"):
|
||||
arg_names = arg_names[1:]
|
||||
kwargs.update(dict(zip(arg_names, args)))
|
||||
return kwargs
|
||||
|
||||
def _get_return(self, expected_return: TypeAdapter | None, result: Any) -> Any:
|
||||
if expected_return:
|
||||
return expected_return.validate_python(result)
|
||||
return result
|
||||
self.client.close()
|
||||
|
||||
def __getattr__(self, name: str) -> Callable[..., Any]:
|
||||
original_func = getattr(service_client_type, name, None)
|
||||
if original_func is None:
|
||||
raise AttributeError(
|
||||
f"Method {name} not found in {service_client_type}"
|
||||
)
|
||||
else:
|
||||
name = original_func.__name__
|
||||
# Try to get the original function from the service type.
|
||||
orig_func = getattr(service_type, name, None)
|
||||
if orig_func is None:
|
||||
raise AttributeError(f"Method {name} not found in {service_type}")
|
||||
|
||||
sig = inspect.signature(original_func)
|
||||
sig = inspect.signature(orig_func)
|
||||
ret_ann = sig.return_annotation
|
||||
if ret_ann != inspect.Signature.empty:
|
||||
expected_return = TypeAdapter(ret_ann)
|
||||
else:
|
||||
expected_return = None
|
||||
|
||||
if inspect.iscoroutinefunction(original_func):
|
||||
def method(*args, **kwargs) -> Any:
|
||||
if args:
|
||||
arg_names = list(sig.parameters.keys())
|
||||
if arg_names[0] in ("self", "cls"):
|
||||
arg_names = arg_names[1:]
|
||||
kwargs.update(dict(zip(arg_names, args)))
|
||||
result = self._call_method(name, **kwargs)
|
||||
if expected_return:
|
||||
return expected_return.validate_python(result)
|
||||
return result
|
||||
|
||||
async def async_method(*args, **kwargs) -> Any:
|
||||
params = self._get_params(sig, *args, **kwargs)
|
||||
result = await self._call_method_async(name, **params)
|
||||
return self._get_return(expected_return, result)
|
||||
return method
|
||||
|
||||
return async_method
|
||||
else:
|
||||
client = cast(AS, DynamicClient())
|
||||
client.health_check()
|
||||
|
||||
def sync_method(*args, **kwargs) -> Any:
|
||||
params = self._get_params(sig, *args, **kwargs)
|
||||
result = self._call_method_sync(name, **params)
|
||||
return self._get_return(expected_return, result)
|
||||
|
||||
return sync_method
|
||||
|
||||
client = cast(ASC, DynamicClient())
|
||||
if health_check:
|
||||
client.health_check()
|
||||
|
||||
return client
|
||||
|
||||
|
||||
def endpoint_to_sync(
|
||||
func: Callable[Concatenate[Any, P], Awaitable[R]],
|
||||
) -> Callable[Concatenate[Any, P], R]:
|
||||
"""
|
||||
Produce a *typed* stub that **looks** synchronous to the type‑checker.
|
||||
"""
|
||||
|
||||
def _stub(*args: P.args, **kwargs: P.kwargs) -> R: # pragma: no cover
|
||||
raise RuntimeError("should be intercepted by __getattr__")
|
||||
|
||||
update_wrapper(_stub, func)
|
||||
return cast(Callable[Concatenate[Any, P], R], _stub)
|
||||
|
||||
|
||||
def endpoint_to_async(
|
||||
func: Callable[Concatenate[Any, P], R],
|
||||
) -> Callable[Concatenate[Any, P], Awaitable[R]]:
|
||||
"""
|
||||
The async mirror of `to_sync`.
|
||||
"""
|
||||
|
||||
async def _stub(*args: P.args, **kwargs: P.kwargs) -> R: # pragma: no cover
|
||||
raise RuntimeError("should be intercepted by __getattr__")
|
||||
|
||||
update_wrapper(_stub, func)
|
||||
return cast(Callable[Concatenate[Any, P], Awaitable[R]], _stub)
|
||||
return cast(AS, client)
|
||||
|
||||
@@ -117,18 +117,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
default=1,
|
||||
description="Cost per execution in cents after each threshold.",
|
||||
)
|
||||
execution_counter_expiration_time: int = Field(
|
||||
default=60 * 60 * 24,
|
||||
description="Time in seconds after which the execution counter is reset.",
|
||||
)
|
||||
execution_late_notification_threshold_secs: int = Field(
|
||||
default=5 * 60,
|
||||
description="Time in seconds after which the execution stuck on QUEUED status is considered late.",
|
||||
)
|
||||
execution_late_notification_checkrange_secs: int = Field(
|
||||
default=60 * 60,
|
||||
description="Time in seconds for how far back to check for the late executions.",
|
||||
)
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
@@ -149,6 +137,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
default=8002,
|
||||
description="The port for execution manager daemon to run on",
|
||||
)
|
||||
execution_manager_loop_max_retry: int = Field(
|
||||
default=5,
|
||||
description="The maximum number of retries for the execution manager loop",
|
||||
)
|
||||
|
||||
execution_scheduler_port: int = Field(
|
||||
default=8003,
|
||||
@@ -239,10 +231,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
default=True,
|
||||
description="Whether to enable the agent input subtype blocks",
|
||||
)
|
||||
platform_alert_discord_channel: str = Field(
|
||||
default="local-alerts",
|
||||
description="The Discord channel for the platform",
|
||||
)
|
||||
|
||||
@field_validator("platform_base_url", "frontend_base_url")
|
||||
@classmethod
|
||||
@@ -354,16 +342,6 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
||||
description="The secret key to use for the unsubscribe user by token",
|
||||
)
|
||||
|
||||
# Cloudflare Turnstile credentials
|
||||
turnstile_secret_key: str = Field(
|
||||
default="",
|
||||
description="Cloudflare Turnstile backend secret key",
|
||||
)
|
||||
turnstile_verify_url: str = Field(
|
||||
default="https://challenges.cloudflare.com/turnstile/v0/siteverify",
|
||||
description="Cloudflare Turnstile verify URL",
|
||||
)
|
||||
|
||||
# OAuth server credentials for integrations
|
||||
# --8<-- [start:OAuthServerCredentialsExample]
|
||||
github_client_id: str = Field(default="", description="GitHub OAuth client ID")
|
||||
|
||||
@@ -25,7 +25,7 @@ class SpinTestServer:
|
||||
self.db_api = DatabaseManager()
|
||||
self.exec_manager = ExecutionManager()
|
||||
self.agent_server = AgentServer()
|
||||
self.scheduler = Scheduler(register_system_tasks=False)
|
||||
self.scheduler = Scheduler()
|
||||
self.notif_manager = NotificationManager()
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
-- AlterTable
|
||||
ALTER TABLE "AgentGraph"
|
||||
ADD COLUMN "forkedFromId" TEXT,
|
||||
ADD COLUMN "forkedFromVersion" INTEGER;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "AgentGraph" ADD CONSTRAINT "AgentGraph_forkedFromId_forkedFromVersion_fkey" FOREIGN KEY ("forkedFromId", "forkedFromVersion") REFERENCES "AgentGraph"("id", "version") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
@@ -1,5 +0,0 @@
|
||||
-- CreateIndex
|
||||
CREATE INDEX "AgentGraphExecution_createdAt_idx" ON "AgentGraphExecution"("createdAt");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "AgentNodeExecution_addedTime_idx" ON "AgentNodeExecution"("addedTime");
|
||||
@@ -1,9 +0,0 @@
|
||||
-- Rename 'data' input to 'inputs' on all Agent Executor nodes
|
||||
UPDATE "AgentNode" AS node
|
||||
SET "constantInput" = jsonb_set(
|
||||
"constantInput",
|
||||
'{inputs}',
|
||||
"constantInput"->'data'
|
||||
) - 'data'
|
||||
WHERE node."agentBlockId" = 'e189baac-8c20-45a1-94a7-55177ea42565'
|
||||
AND node."constantInput" ? 'data';
|
||||
17
autogpt_platform/backend/poetry.lock
generated
17
autogpt_platform/backend/poetry.lock
generated
@@ -3568,21 +3568,6 @@ files = [
|
||||
[package.dependencies]
|
||||
tqdm = "*"
|
||||
|
||||
[[package]]
|
||||
name = "prometheus-client"
|
||||
version = "0.21.1"
|
||||
description = "Python client for the Prometheus monitoring system."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "prometheus_client-0.21.1-py3-none-any.whl", hash = "sha256:594b45c410d6f4f8888940fe80b5cc2521b305a1fafe1c58609ef715a001f301"},
|
||||
{file = "prometheus_client-0.21.1.tar.gz", hash = "sha256:252505a722ac04b0456be05c05f75f45d760c2911ffc45f2a06bcaed9f3ae3fb"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
twisted = ["twisted"]
|
||||
|
||||
[[package]]
|
||||
name = "propcache"
|
||||
version = "0.2.1"
|
||||
@@ -6325,4 +6310,4 @@ cffi = ["cffi (>=1.11)"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.13"
|
||||
content-hash = "29ccee704d8296c57156daab98bb0cbbf5a43e83526b7f08a14c91fb7a4898f4"
|
||||
content-hash = "781f77ec77cfce78b34fb57063dcc81df8e9c5a4be9a644033a0c197e0063730"
|
||||
|
||||
@@ -64,7 +64,6 @@ websockets = "^14.2"
|
||||
youtube-transcript-api = "^0.6.2"
|
||||
zerobouncesdk = "^1.1.1"
|
||||
# NOTE: please insert new dependencies in their alphabetical location
|
||||
prometheus-client = "^0.21.1"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
aiohappyeyeballs = "^2.6.1"
|
||||
@@ -88,7 +87,6 @@ build-backend = "poetry.core.masonry.api"
|
||||
app = "backend.app:main"
|
||||
rest = "backend.rest:main"
|
||||
ws = "backend.ws:main"
|
||||
scheduler = "backend.scheduler:main"
|
||||
executor = "backend.exec:main"
|
||||
cli = "backend.cli:main"
|
||||
format = "linter:format"
|
||||
|
||||
@@ -118,11 +118,6 @@ model AgentGraph {
|
||||
// This allows us to delete user data with deleting the agent which maybe in use by other users
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
forkedFromId String?
|
||||
forkedFromVersion Int?
|
||||
forkedFrom AgentGraph? @relation("AgentGraphForks", fields: [forkedFromId, forkedFromVersion], references: [id, version])
|
||||
forks AgentGraph[] @relation("AgentGraphForks")
|
||||
|
||||
Nodes AgentNode[]
|
||||
Executions AgentGraphExecution[]
|
||||
|
||||
@@ -352,7 +347,6 @@ model AgentGraphExecution {
|
||||
|
||||
@@index([agentGraphId, agentGraphVersion])
|
||||
@@index([userId])
|
||||
@@index([createdAt])
|
||||
}
|
||||
|
||||
// This model describes the execution of an AgentNode.
|
||||
@@ -379,7 +373,6 @@ model AgentNodeExecution {
|
||||
|
||||
@@index([agentGraphExecutionId])
|
||||
@@index([agentNodeId])
|
||||
@@index([addedTime])
|
||||
}
|
||||
|
||||
// This model describes the output of an AgentNodeExecution.
|
||||
|
||||
@@ -6,10 +6,10 @@ from prisma.models import CreditTransaction
|
||||
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
from backend.data.block import get_block
|
||||
from backend.data.credit import BetaUserCredit, UsageTransactionMetadata
|
||||
from backend.data.credit import BetaUserCredit
|
||||
from backend.data.execution import NodeExecutionEntry
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.executor.utils import block_usage_cost
|
||||
from backend.executor.utils import UsageTransactionMetadata, block_usage_cost
|
||||
from backend.integrations.credentials_store import openai_credentials
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
@@ -34,7 +34,7 @@ async def spend_credits(entry: NodeExecutionEntry) -> int:
|
||||
if not block:
|
||||
raise RuntimeError(f"Block {entry.block_id} not found")
|
||||
|
||||
cost, matching_filter = block_usage_cost(block=block, input_data=entry.inputs)
|
||||
cost, matching_filter = block_usage_cost(block=block, input_data=entry.data)
|
||||
await user_credit.spend_credits(
|
||||
entry.user_id,
|
||||
cost,
|
||||
@@ -46,7 +46,6 @@ async def spend_credits(entry: NodeExecutionEntry) -> int:
|
||||
block_id=entry.block_id,
|
||||
block=entry.block_id,
|
||||
input=matching_filter,
|
||||
reason=f"Ran block {entry.block_id} {block.name}",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -67,7 +66,7 @@ async def test_block_credit_usage(server: SpinTestServer):
|
||||
graph_exec_id="test_graph_exec",
|
||||
node_exec_id="test_node_exec",
|
||||
block_id=AITextGeneratorBlock().id,
|
||||
inputs={
|
||||
data={
|
||||
"model": "gpt-4-turbo",
|
||||
"credentials": {
|
||||
"id": openai_credentials.id,
|
||||
@@ -87,7 +86,7 @@ async def test_block_credit_usage(server: SpinTestServer):
|
||||
graph_exec_id="test_graph_exec",
|
||||
node_exec_id="test_node_exec",
|
||||
block_id=AITextGeneratorBlock().id,
|
||||
inputs={"model": "gpt-4-turbo", "api_key": "owned_api_key"},
|
||||
data={"model": "gpt-4-turbo", "api_key": "owned_api_key"},
|
||||
),
|
||||
)
|
||||
assert spending_amount_2 == 0
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from backend.data import db
|
||||
from backend.executor.scheduler import SchedulerClient
|
||||
from backend.executor import Scheduler
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.usecases.sample import create_test_graph, create_test_user
|
||||
from backend.util.service import get_service_client
|
||||
@@ -17,11 +17,11 @@ async def test_agent_schedule(server: SpinTestServer):
|
||||
user_id=test_user.id,
|
||||
)
|
||||
|
||||
scheduler = get_service_client(SchedulerClient)
|
||||
schedules = await scheduler.get_execution_schedules(test_graph.id, test_user.id)
|
||||
scheduler = get_service_client(Scheduler)
|
||||
schedules = scheduler.get_execution_schedules(test_graph.id, test_user.id)
|
||||
assert len(schedules) == 0
|
||||
|
||||
schedule = await scheduler.add_execution_schedule(
|
||||
schedule = scheduler.add_execution_schedule(
|
||||
graph_id=test_graph.id,
|
||||
user_id=test_user.id,
|
||||
graph_version=1,
|
||||
@@ -30,12 +30,10 @@ async def test_agent_schedule(server: SpinTestServer):
|
||||
)
|
||||
assert schedule
|
||||
|
||||
schedules = await scheduler.get_execution_schedules(test_graph.id, test_user.id)
|
||||
schedules = scheduler.get_execution_schedules(test_graph.id, test_user.id)
|
||||
assert len(schedules) == 1
|
||||
assert schedules[0].cron == "0 0 * * *"
|
||||
|
||||
await scheduler.delete_schedule(schedule.id, user_id=test_user.id)
|
||||
schedules = await scheduler.get_execution_schedules(
|
||||
test_graph.id, user_id=test_user.id
|
||||
)
|
||||
scheduler.delete_schedule(schedule.id, user_id=test_user.id)
|
||||
schedules = scheduler.get_execution_schedules(test_graph.id, user_id=test_user.id)
|
||||
assert len(schedules) == 0
|
||||
|
||||
@@ -1,12 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
endpoint_to_async,
|
||||
expose,
|
||||
get_service_client,
|
||||
)
|
||||
from backend.util.service import AppService, expose, get_service_client
|
||||
|
||||
TEST_SERVICE_PORT = 8765
|
||||
|
||||
@@ -38,25 +32,10 @@ class ServiceTest(AppService):
|
||||
return self.run_and_wait(add_async(a, b))
|
||||
|
||||
|
||||
class ServiceTestClient(AppServiceClient):
|
||||
@classmethod
|
||||
def get_service_type(cls):
|
||||
return ServiceTest
|
||||
|
||||
add = ServiceTest.add
|
||||
subtract = ServiceTest.subtract
|
||||
fun_with_async = ServiceTest.fun_with_async
|
||||
|
||||
add_async = endpoint_to_async(ServiceTest.add)
|
||||
subtract_async = endpoint_to_async(ServiceTest.subtract)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_service_creation(server):
|
||||
with ServiceTest():
|
||||
client = get_service_client(ServiceTestClient)
|
||||
client = get_service_client(ServiceTest)
|
||||
assert client.add(5, 3) == 8
|
||||
assert client.subtract(10, 4) == 6
|
||||
assert client.fun_with_async(5, 3) == 8
|
||||
assert await client.add_async(5, 3) == 8
|
||||
assert await client.subtract_async(10, 4) == 6
|
||||
|
||||
@@ -73,8 +73,6 @@ services:
|
||||
condition: service_completed_successfully
|
||||
rabbitmq:
|
||||
condition: service_healthy
|
||||
# scheduler_server:
|
||||
# condition: service_healthy
|
||||
environment:
|
||||
- SUPABASE_URL=http://kong:8000
|
||||
- SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
@@ -90,7 +88,7 @@ services:
|
||||
- REDIS_PASSWORD=password
|
||||
- ENABLE_AUTH=true
|
||||
- PYRO_HOST=0.0.0.0
|
||||
- SCHEDULER_HOST=scheduler_server
|
||||
- EXECUTIONSCHEDULER_HOST=rest_server
|
||||
- EXECUTIONMANAGER_HOST=executor
|
||||
- NOTIFICATIONMANAGER_HOST=rest_server
|
||||
- FRONTEND_BASE_URL=http://localhost:3000
|
||||
@@ -100,6 +98,7 @@ services:
|
||||
ports:
|
||||
- "8006:8006"
|
||||
- "8007:8007"
|
||||
- "8003:8003" # execution scheduler
|
||||
networks:
|
||||
- app-network
|
||||
|
||||
@@ -143,7 +142,7 @@ services:
|
||||
- NOTIFICATIONMANAGER_HOST=rest_server
|
||||
- ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw= # DO NOT USE IN PRODUCTION!!
|
||||
ports:
|
||||
- "8002:8002"
|
||||
- "8002:8000"
|
||||
networks:
|
||||
- app-network
|
||||
|
||||
@@ -188,61 +187,6 @@ services:
|
||||
networks:
|
||||
- app-network
|
||||
|
||||
scheduler_server:
|
||||
build:
|
||||
context: ../
|
||||
dockerfile: autogpt_platform/backend/Dockerfile
|
||||
target: server
|
||||
command: ["python", "-m", "backend.scheduler"]
|
||||
develop:
|
||||
watch:
|
||||
- path: ./
|
||||
target: autogpt_platform/backend/
|
||||
action: rebuild
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
rabbitmq:
|
||||
condition: service_healthy
|
||||
migrate:
|
||||
condition: service_completed_successfully
|
||||
# healthcheck:
|
||||
# test:
|
||||
# [
|
||||
# "CMD",
|
||||
# "curl",
|
||||
# "-f",
|
||||
# "-X",
|
||||
# "POST",
|
||||
# "http://localhost:8003/health_check",
|
||||
# ]
|
||||
# interval: 10s
|
||||
# timeout: 10s
|
||||
# retries: 5
|
||||
environment:
|
||||
- DATABASEMANAGER_HOST=rest_server
|
||||
- NOTIFICATIONMANAGER_HOST=rest_server
|
||||
- SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
- DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
|
||||
- DIRECT_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
|
||||
- REDIS_HOST=redis
|
||||
- REDIS_PORT=6379
|
||||
- REDIS_PASSWORD=password
|
||||
- RABBITMQ_HOST=rabbitmq
|
||||
- RABBITMQ_PORT=5672
|
||||
- RABBITMQ_DEFAULT_USER=rabbitmq_user_default
|
||||
- RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
|
||||
- ENABLE_AUTH=true
|
||||
- PYRO_HOST=0.0.0.0
|
||||
- BACKEND_CORS_ALLOW_ORIGINS=["http://localhost:3000"]
|
||||
|
||||
ports:
|
||||
- "8003:8003"
|
||||
networks:
|
||||
- app-network
|
||||
|
||||
# frontend:
|
||||
# build:
|
||||
# context: ../
|
||||
|
||||
@@ -57,17 +57,11 @@ services:
|
||||
file: ./docker-compose.platform.yml
|
||||
service: websocket_server
|
||||
|
||||
scheduler_server:
|
||||
<<: *agpt-services
|
||||
extends:
|
||||
file: ./docker-compose.platform.yml
|
||||
service: scheduler_server
|
||||
|
||||
# frontend:
|
||||
# <<: *agpt-services
|
||||
# extends:
|
||||
# file: ./docker-compose.platform.yml
|
||||
# service: frontend
|
||||
# frontend:
|
||||
# <<: *agpt-services
|
||||
# extends:
|
||||
# file: ./docker-compose.platform.yml
|
||||
# service: frontend
|
||||
|
||||
# Supabase services
|
||||
studio:
|
||||
|
||||
@@ -24,9 +24,3 @@ GA_MEASUREMENT_ID=G-FH2XK2W4GN
|
||||
|
||||
# When running locally, set NEXT_PUBLIC_BEHAVE_AS=CLOUD to use the a locally hosted marketplace (as is typical in development, and the cloud deployment), otherwise set it to LOCAL to have the marketplace open in a new tab
|
||||
NEXT_PUBLIC_BEHAVE_AS=LOCAL
|
||||
NEXT_PUBLIC_SHOW_BILLING_PAGE=false
|
||||
|
||||
## Cloudflare Turnstile (CAPTCHA) Configuration
|
||||
## Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
|
||||
## This is the frontend site key
|
||||
NEXT_PUBLIC_CLOUDFLARE_TURNSTILE_SITE_KEY=
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import type { Preview } from "@storybook/react";
|
||||
import { initialize, mswLoader } from "msw-storybook-addon";
|
||||
import "../src/app/globals.css";
|
||||
|
||||
import { Providers } from "../src/app/providers";
|
||||
// Initialize MSW
|
||||
import React from "react";
|
||||
initialize();
|
||||
|
||||
const preview: Preview = {
|
||||
@@ -18,6 +19,17 @@ const preview: Preview = {
|
||||
},
|
||||
},
|
||||
loaders: [mswLoader],
|
||||
decorators: [
|
||||
(Story, context) => {
|
||||
const mockOptions = context.parameters.mockBackend || {};
|
||||
|
||||
return (
|
||||
<Providers useMockBackend mockClientProps={mockOptions}>
|
||||
<Story />
|
||||
</Providers>
|
||||
);
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
export default preview;
|
||||
@@ -163,7 +163,14 @@ export default function Page() {
|
||||
</div>
|
||||
|
||||
<OnboardingFooter>
|
||||
<OnboardingButton className="mb-2" href="/onboarding/4-agent">
|
||||
<OnboardingButton
|
||||
className="mb-2"
|
||||
href="/onboarding/4-agent"
|
||||
disabled={
|
||||
state?.integrations.length === 0 &&
|
||||
isEmptyOrWhitespace(state.otherIntegrations)
|
||||
}
|
||||
>
|
||||
Next
|
||||
</OnboardingButton>
|
||||
</OnboardingFooter>
|
||||
|
||||
@@ -59,7 +59,7 @@ export default function Page() {
|
||||
|
||||
<div className="my-12 flex items-center justify-between gap-5">
|
||||
<OnboardingAgentCard
|
||||
agent={agents[0]}
|
||||
{...(agents[0] || {})}
|
||||
selected={
|
||||
agents[0] !== undefined
|
||||
? state?.selectedStoreListingVersionId ==
|
||||
@@ -74,7 +74,7 @@ export default function Page() {
|
||||
}
|
||||
/>
|
||||
<OnboardingAgentCard
|
||||
agent={agents[1]}
|
||||
{...(agents[1] || {})}
|
||||
selected={
|
||||
agents[1] !== undefined
|
||||
? state?.selectedStoreListingVersionId ==
|
||||
|
||||
@@ -9,6 +9,7 @@ import StarRating from "@/components/onboarding/StarRating";
|
||||
import { Play } from "lucide-react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
import Image from "next/image";
|
||||
import { GraphMeta, StoreAgentDetails } from "@/lib/autogpt-server-api";
|
||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||
import { useRouter } from "next/navigation";
|
||||
@@ -16,7 +17,6 @@ import { useOnboarding } from "@/components/onboarding/onboarding-provider";
|
||||
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import SchemaTooltip from "@/components/SchemaTooltip";
|
||||
import { TypeBasedInput } from "@/components/type-based-input";
|
||||
import SmartImage from "@/components/agptui/SmartImage";
|
||||
|
||||
export default function Page() {
|
||||
const { state, updateState, setStep } = useOnboarding(
|
||||
@@ -99,7 +99,7 @@ export default function Page() {
|
||||
}, [api, agent, router, state?.agentInput, storeAgent, updateState]);
|
||||
|
||||
const runYourAgent = (
|
||||
<div className="ml-[104px] w-[481px] pl-5">
|
||||
<div className="ml-[54px] w-[481px] pl-5">
|
||||
<div className="flex flex-col">
|
||||
<OnboardingText variant="header">Run your first agent</OnboardingText>
|
||||
<span className="mt-9 text-base font-normal leading-normal text-zinc-600">
|
||||
@@ -147,25 +147,32 @@ export default function Page() {
|
||||
return (
|
||||
<OnboardingStep dotted>
|
||||
<OnboardingHeader backHref={"/onboarding/4-agent"} transparent />
|
||||
{/* Agent card */}
|
||||
<div className="fixed left-1/4 top-1/2 w-[481px] -translate-x-1/2 -translate-y-1/2">
|
||||
<div className="h-[156px] w-[481px] rounded-xl bg-white px-6 pb-5 pt-4">
|
||||
<span className="font-sans text-xs font-medium tracking-wide text-zinc-500">
|
||||
SELECTED AGENT
|
||||
</span>
|
||||
{storeAgent ? (
|
||||
<div
|
||||
className={cn(
|
||||
"flex w-full items-center justify-center",
|
||||
showInput ? "mt-[32px]" : "mt-[192px]",
|
||||
)}
|
||||
>
|
||||
{/* Left side */}
|
||||
<div className="mr-[52px] w-[481px]">
|
||||
<div className="h-[156px] w-[481px] rounded-xl bg-white px-6 pb-5 pt-4">
|
||||
<span className="font-sans text-xs font-medium tracking-wide text-zinc-500">
|
||||
SELECTED AGENT
|
||||
</span>
|
||||
<div className="mt-4 flex h-20 rounded-lg bg-violet-50 p-2">
|
||||
{/* Left image */}
|
||||
<SmartImage
|
||||
src={storeAgent?.agent_image[0]}
|
||||
alt="Agent cover"
|
||||
imageContain
|
||||
className="w-[350px] rounded-lg"
|
||||
<Image
|
||||
src={storeAgent?.agent_image[0] || ""}
|
||||
alt="Description"
|
||||
width={350}
|
||||
height={196}
|
||||
className="h-full w-auto rounded-lg object-contain"
|
||||
/>
|
||||
|
||||
{/* Right content */}
|
||||
<div className="ml-2 flex flex-1 flex-col">
|
||||
<span className="w-[292px] truncate font-sans text-[14px] font-medium leading-normal text-zinc-800">
|
||||
{storeAgent?.agent_name}
|
||||
{agent?.name}
|
||||
</span>
|
||||
<span className="mt-[5px] w-[292px] truncate font-sans text-xs font-normal leading-tight text-zinc-600">
|
||||
by {storeAgent?.creator}
|
||||
@@ -182,19 +189,13 @@ export default function Page() {
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="mt-4 flex h-20 animate-pulse rounded-lg bg-gray-300 p-2" />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex min-h-[80vh] items-center justify-center">
|
||||
{/* Left side */}
|
||||
<div className="w-[481px]" />
|
||||
{/* Right side */}
|
||||
{!showInput ? (
|
||||
runYourAgent
|
||||
) : (
|
||||
<div className="ml-[104px] w-[481px] pl-5">
|
||||
<div className="ml-[54px] w-[481px] pl-5">
|
||||
<div className="flex flex-col">
|
||||
<OnboardingText variant="header">
|
||||
Provide details for your agent
|
||||
|
||||
@@ -6,13 +6,6 @@ import { redirect } from "next/navigation";
|
||||
export async function finishOnboarding() {
|
||||
const api = new BackendAPI();
|
||||
const onboarding = await api.getUserOnboarding();
|
||||
const listingId = onboarding?.selectedStoreListingVersionId;
|
||||
if (listingId) {
|
||||
const libraryAgent = await api.addMarketplaceAgentToLibrary(listingId);
|
||||
revalidatePath(`/library/agents/${libraryAgent.id}`, "layout");
|
||||
redirect(`/library/agents/${libraryAgent.id}`);
|
||||
} else {
|
||||
revalidatePath("/library", "layout");
|
||||
redirect("/library");
|
||||
}
|
||||
revalidatePath("/library", "layout");
|
||||
redirect("/library");
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ export default async function OnboardingPage() {
|
||||
// CONGRATS is the last step in intro onboarding
|
||||
if (onboarding.completedSteps.includes("CONGRATS")) redirect("/marketplace");
|
||||
else if (onboarding.completedSteps.includes("AGENT_INPUT"))
|
||||
redirect("/onboarding/5-run");
|
||||
redirect("/onboarding/6-congrats");
|
||||
else if (onboarding.completedSteps.includes("AGENT_NEW_RUN"))
|
||||
redirect("/onboarding/5-run");
|
||||
else if (onboarding.completedSteps.includes("AGENT_CHOICE"))
|
||||
|
||||
@@ -56,9 +56,3 @@ export async function getAdminListingsWithVersions(
|
||||
const response = await api.getAdminListingsWithVersions(data);
|
||||
return response;
|
||||
}
|
||||
|
||||
export async function downloadAsAdmin(storeListingVersion: string) {
|
||||
const api = new BackendApi();
|
||||
const file = await api.downloadStoreAgentAdmin(storeListingVersion);
|
||||
return file;
|
||||
}
|
||||
|
||||
@@ -1,11 +1,5 @@
|
||||
"use client";
|
||||
import React, {
|
||||
useCallback,
|
||||
useEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
} from "react";
|
||||
import React, { useCallback, useEffect, useMemo, useState } from "react";
|
||||
import { useParams, useRouter } from "next/navigation";
|
||||
|
||||
import { exportAsJSONFile } from "@/lib/utils";
|
||||
@@ -29,16 +23,6 @@ import AgentRunDetailsView from "@/components/agents/agent-run-details-view";
|
||||
import AgentRunsSelectorList from "@/components/agents/agent-runs-selector-list";
|
||||
import AgentScheduleDetailsView from "@/components/agents/agent-schedule-details-view";
|
||||
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useToast } from "@/components/ui/use-toast";
|
||||
|
||||
export default function AgentRunsPage(): React.ReactElement {
|
||||
const { id: agentID }: { id: LibraryAgentID } = useParams();
|
||||
@@ -47,7 +31,7 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
|
||||
// ============================ STATE =============================
|
||||
|
||||
const [graph, setGraph] = useState<Graph | null>(null); // Graph version corresponding to LibraryAgent
|
||||
const [graph, setGraph] = useState<Graph | null>(null);
|
||||
const [agent, setAgent] = useState<LibraryAgent | null>(null);
|
||||
const [agentRuns, setAgentRuns] = useState<GraphExecutionMeta[]>([]);
|
||||
const [schedules, setSchedules] = useState<Schedule[]>([]);
|
||||
@@ -66,10 +50,7 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
useState<boolean>(false);
|
||||
const [confirmingDeleteAgentRun, setConfirmingDeleteAgentRun] =
|
||||
useState<GraphExecutionMeta | null>(null);
|
||||
const { state: onboardingState, updateState: updateOnboardingState } =
|
||||
useOnboarding();
|
||||
const [copyAgentDialogOpen, setCopyAgentDialogOpen] = useState(false);
|
||||
const { toast } = useToast();
|
||||
const { state, updateState } = useOnboarding();
|
||||
|
||||
const openRunDraftView = useCallback(() => {
|
||||
selectView({ type: "run" });
|
||||
@@ -84,43 +65,34 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
setSelectedSchedule(schedule);
|
||||
}, []);
|
||||
|
||||
const graphVersions = useRef<Record<number, Graph>>({});
|
||||
const loadingGraphVersions = useRef<Record<number, Promise<Graph>>>({});
|
||||
const [graphVersions, setGraphVersions] = useState<Record<number, Graph>>({});
|
||||
const getGraphVersion = useCallback(
|
||||
async (graphID: GraphID, version: number) => {
|
||||
if (version in graphVersions.current)
|
||||
return graphVersions.current[version];
|
||||
if (version in loadingGraphVersions.current)
|
||||
return loadingGraphVersions.current[version];
|
||||
if (graphVersions[version]) return graphVersions[version];
|
||||
|
||||
const pendingGraph = api.getGraph(graphID, version).then((graph) => {
|
||||
graphVersions.current[version] = graph;
|
||||
return graph;
|
||||
});
|
||||
// Cache promise as well to avoid duplicate requests
|
||||
loadingGraphVersions.current[version] = pendingGraph;
|
||||
return pendingGraph;
|
||||
const graphVersion = await api.getGraph(graphID, version);
|
||||
setGraphVersions((prev) => ({
|
||||
...prev,
|
||||
[version]: graphVersion,
|
||||
}));
|
||||
return graphVersion;
|
||||
},
|
||||
[api, graphVersions, loadingGraphVersions],
|
||||
[api, graphVersions],
|
||||
);
|
||||
|
||||
// Reward user for viewing results of their onboarding agent
|
||||
useEffect(() => {
|
||||
if (
|
||||
!onboardingState ||
|
||||
!selectedRun ||
|
||||
onboardingState.completedSteps.includes("GET_RESULTS")
|
||||
)
|
||||
if (!state || !selectedRun || state.completedSteps.includes("GET_RESULTS"))
|
||||
return;
|
||||
|
||||
if (selectedRun.id === onboardingState.onboardingAgentExecutionId) {
|
||||
updateOnboardingState({
|
||||
completedSteps: [...onboardingState.completedSteps, "GET_RESULTS"],
|
||||
if (selectedRun.id === state.onboardingAgentExecutionId) {
|
||||
updateState({
|
||||
completedSteps: [...state.completedSteps, "GET_RESULTS"],
|
||||
});
|
||||
}
|
||||
}, [selectedRun, onboardingState, updateOnboardingState]);
|
||||
}, [selectedRun, state]);
|
||||
|
||||
const refreshPageData = useCallback(() => {
|
||||
const fetchAgents = useCallback(() => {
|
||||
api.getLibraryAgent(agentID).then((agent) => {
|
||||
setAgent(agent);
|
||||
|
||||
@@ -135,40 +107,38 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
new Set(agentRuns.map((run) => run.graph_version)).forEach((version) =>
|
||||
getGraphVersion(agent.graph_id, version),
|
||||
);
|
||||
|
||||
if (!selectedView.id && isFirstLoad && agentRuns.length > 0) {
|
||||
// only for first load or first execution
|
||||
setIsFirstLoad(false);
|
||||
|
||||
const latestRun = agentRuns.reduce((latest, current) => {
|
||||
if (latest.started_at && !current.started_at) return current;
|
||||
else if (!latest.started_at) return latest;
|
||||
return latest.started_at > current.started_at ? latest : current;
|
||||
}, agentRuns[0]);
|
||||
selectView({ type: "run", id: latestRun.id });
|
||||
}
|
||||
});
|
||||
});
|
||||
}, [api, agentID, getGraphVersion, graph]);
|
||||
if (selectedView.type == "run" && selectedView.id && agent) {
|
||||
api
|
||||
.getGraphExecutionInfo(agent.graph_id, selectedView.id)
|
||||
.then(setSelectedRun);
|
||||
}
|
||||
}, [api, agentID, getGraphVersion, graph, selectedView, isFirstLoad, agent]);
|
||||
|
||||
// On first load: select the latest run
|
||||
useEffect(() => {
|
||||
// Only for first load or first execution
|
||||
if (selectedView.id || !isFirstLoad || agentRuns.length == 0) return;
|
||||
setIsFirstLoad(false);
|
||||
|
||||
const latestRun = agentRuns.reduce((latest, current) => {
|
||||
if (latest.started_at && !current.started_at) return current;
|
||||
else if (!latest.started_at) return latest;
|
||||
return latest.started_at > current.started_at ? latest : current;
|
||||
}, agentRuns[0]);
|
||||
selectView({ type: "run", id: latestRun.id });
|
||||
}, [agentRuns, isFirstLoad, selectedView.id, selectView]);
|
||||
|
||||
// Initial load
|
||||
useEffect(() => {
|
||||
refreshPageData();
|
||||
fetchAgents();
|
||||
}, []);
|
||||
|
||||
// Subscribe to WebSocket updates for agent runs
|
||||
// Subscribe to websocket updates for agent runs
|
||||
useEffect(() => {
|
||||
if (!agent?.graph_id) return;
|
||||
if (!agent) return;
|
||||
|
||||
return api.onWebSocketConnect(() => {
|
||||
refreshPageData(); // Sync up on (re)connect
|
||||
|
||||
// Subscribe to all executions for this agent
|
||||
api.subscribeToGraphExecutions(agent.graph_id);
|
||||
});
|
||||
}, [api, agent?.graph_id, refreshPageData]);
|
||||
// Subscribe to all executions for this agent
|
||||
api.subscribeToGraphExecutions(agent.graph_id);
|
||||
}, [api, agent]);
|
||||
|
||||
// Handle execution updates
|
||||
useEffect(() => {
|
||||
@@ -197,29 +167,24 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
};
|
||||
}, [api, agent?.graph_id, selectedView.id]);
|
||||
|
||||
// Pre-load selectedRun based on selectedView
|
||||
// load selectedRun based on selectedView
|
||||
useEffect(() => {
|
||||
if (selectedView.type != "run" || !selectedView.id) return;
|
||||
if (selectedView.type != "run" || !selectedView.id || !agent) return;
|
||||
|
||||
const newSelectedRun = agentRuns.find((run) => run.id == selectedView.id);
|
||||
if (selectedView.id !== selectedRun?.id) {
|
||||
// Pull partial data from "cache" while waiting for the rest to load
|
||||
setSelectedRun(newSelectedRun ?? null);
|
||||
|
||||
// Ensure corresponding graph version is available before rendering I/O
|
||||
api
|
||||
.getGraphExecutionInfo(agent.graph_id, selectedView.id)
|
||||
.then(async (run) => {
|
||||
await getGraphVersion(run.graph_id, run.graph_version);
|
||||
setSelectedRun(run);
|
||||
});
|
||||
}
|
||||
}, [api, selectedView, agentRuns, selectedRun?.id]);
|
||||
|
||||
// Load selectedRun based on selectedView; refresh on agent refresh
|
||||
useEffect(() => {
|
||||
if (selectedView.type != "run" || !selectedView.id || !agent) return;
|
||||
|
||||
api
|
||||
.getGraphExecutionInfo(agent.graph_id, selectedView.id)
|
||||
.then(async (run) => {
|
||||
// Ensure corresponding graph version is available before rendering I/O
|
||||
await getGraphVersion(run.graph_id, run.graph_version);
|
||||
setSelectedRun(run);
|
||||
});
|
||||
}, [api, selectedView, agent, getGraphVersion]);
|
||||
}, [api, selectedView, agent, agentRuns, selectedRun?.id, getGraphVersion]);
|
||||
|
||||
const fetchSchedules = useCallback(async () => {
|
||||
if (!agent) return;
|
||||
@@ -272,41 +237,20 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
[api, agent],
|
||||
);
|
||||
|
||||
const copyAgent = useCallback(async () => {
|
||||
setCopyAgentDialogOpen(false);
|
||||
api
|
||||
.forkLibraryAgent(agentID)
|
||||
.then((newAgent) => {
|
||||
router.push(`/library/agents/${newAgent.id}`);
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error("Error copying agent:", error);
|
||||
toast({
|
||||
title: "Error copying agent",
|
||||
description: `An error occurred while copying the agent: ${error.message}`,
|
||||
variant: "destructive",
|
||||
});
|
||||
});
|
||||
}, [agentID, api, router, toast]);
|
||||
|
||||
const agentActions: ButtonAction[] = useMemo(
|
||||
() => [
|
||||
{
|
||||
label: "Customize agent",
|
||||
href: `/build?flowID=${agent?.graph_id}&flowVersion=${agent?.graph_version}`,
|
||||
disabled: !agent?.can_access_graph,
|
||||
},
|
||||
{ label: "Export agent to file", callback: downloadGraph },
|
||||
...(!agent?.can_access_graph
|
||||
...(agent?.can_access_graph
|
||||
? [
|
||||
{
|
||||
label: "Edit a copy",
|
||||
callback: () => setCopyAgentDialogOpen(true),
|
||||
label: "Open graph in builder",
|
||||
href: `/build?flowID=${agent.graph_id}&flowVersion=${agent.graph_version}`,
|
||||
},
|
||||
{ label: "Export agent to file", callback: downloadGraph },
|
||||
]
|
||||
: []),
|
||||
{
|
||||
label: "Delete agent",
|
||||
variant: "destructive",
|
||||
callback: () => setAgentDeleteDialogOpen(true),
|
||||
},
|
||||
],
|
||||
@@ -351,7 +295,7 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
selectedRun && (
|
||||
<AgentRunDetailsView
|
||||
agent={agent}
|
||||
graph={graphVersions.current[selectedRun.graph_version] ?? graph}
|
||||
graph={graphVersions[selectedRun.graph_version] ?? graph}
|
||||
run={selectedRun}
|
||||
agentActions={agentActions}
|
||||
onRun={(runID) => selectRun(runID)}
|
||||
@@ -395,36 +339,6 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
confirmingDeleteAgentRun && deleteRun(confirmingDeleteAgentRun)
|
||||
}
|
||||
/>
|
||||
{/* Copy agent confirmation dialog */}
|
||||
<Dialog
|
||||
onOpenChange={setCopyAgentDialogOpen}
|
||||
open={copyAgentDialogOpen}
|
||||
>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>You're making an editable copy</DialogTitle>
|
||||
<DialogDescription className="pt-2">
|
||||
The original Marketplace agent stays the same and cannot be
|
||||
edited. We'll save a new version of this agent to your
|
||||
Library. From there, you can customize it however you'd
|
||||
like by clicking "Customize agent" — this will open
|
||||
the builder where you can see and modify the inner workings.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<DialogFooter className="justify-end">
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
onClick={() => setCopyAgentDialogOpen(false)}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button type="button" onClick={copyAgent}>
|
||||
Continue
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -6,7 +6,6 @@ import * as Sentry from "@sentry/nextjs";
|
||||
import getServerSupabase from "@/lib/supabase/getServerSupabase";
|
||||
import BackendAPI from "@/lib/autogpt-server-api";
|
||||
import { loginFormSchema, LoginProvider } from "@/types/auth";
|
||||
import { verifyTurnstileToken } from "@/lib/turnstile";
|
||||
|
||||
export async function logout() {
|
||||
return await Sentry.withServerActionInstrumentation(
|
||||
@@ -40,10 +39,7 @@ async function shouldShowOnboarding() {
|
||||
);
|
||||
}
|
||||
|
||||
export async function login(
|
||||
values: z.infer<typeof loginFormSchema>,
|
||||
turnstileToken: string,
|
||||
) {
|
||||
export async function login(values: z.infer<typeof loginFormSchema>) {
|
||||
return await Sentry.withServerActionInstrumentation("login", {}, async () => {
|
||||
const supabase = getServerSupabase();
|
||||
const api = new BackendAPI();
|
||||
@@ -52,12 +48,6 @@ export async function login(
|
||||
redirect("/error");
|
||||
}
|
||||
|
||||
// Verify Turnstile token if provided
|
||||
const success = await verifyTurnstileToken(turnstileToken, "login");
|
||||
if (!success) {
|
||||
return "CAPTCHA verification failed. Please try again.";
|
||||
}
|
||||
|
||||
// We are sure that the values are of the correct type because zod validates the form
|
||||
const { data, error } = await supabase.auth.signInWithPassword(values);
|
||||
|
||||
|
||||
@@ -24,11 +24,9 @@ import {
|
||||
AuthFeedback,
|
||||
AuthBottomText,
|
||||
PasswordInput,
|
||||
Turnstile,
|
||||
} from "@/components/auth";
|
||||
import { loginFormSchema } from "@/types/auth";
|
||||
import { getBehaveAs } from "@/lib/utils";
|
||||
import { useTurnstile } from "@/hooks/useTurnstile";
|
||||
|
||||
export default function LoginPage() {
|
||||
const { supabase, user, isUserLoading } = useSupabase();
|
||||
@@ -36,12 +34,6 @@ export default function LoginPage() {
|
||||
const router = useRouter();
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
|
||||
const turnstile = useTurnstile({
|
||||
action: "login",
|
||||
autoVerify: false,
|
||||
resetOnError: true,
|
||||
});
|
||||
|
||||
const form = useForm<z.infer<typeof loginFormSchema>>({
|
||||
resolver: zodResolver(loginFormSchema),
|
||||
defaultValues: {
|
||||
@@ -73,23 +65,15 @@ export default function LoginPage() {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!turnstile.verified) {
|
||||
setFeedback("Please complete the CAPTCHA challenge.");
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const error = await login(data, turnstile.token as string);
|
||||
const error = await login(data);
|
||||
setIsLoading(false);
|
||||
if (error) {
|
||||
setFeedback(error);
|
||||
// Always reset the turnstile on any error
|
||||
turnstile.reset();
|
||||
return;
|
||||
}
|
||||
setFeedback(null);
|
||||
},
|
||||
[form, turnstile],
|
||||
[form],
|
||||
);
|
||||
|
||||
if (user) {
|
||||
@@ -156,17 +140,6 @@ export default function LoginPage() {
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
{/* Turnstile CAPTCHA Component */}
|
||||
<Turnstile
|
||||
siteKey={turnstile.siteKey}
|
||||
onVerify={turnstile.handleVerify}
|
||||
onExpire={turnstile.handleExpire}
|
||||
onError={turnstile.handleError}
|
||||
action="login"
|
||||
shouldRender={turnstile.shouldRender}
|
||||
/>
|
||||
|
||||
<AuthButton
|
||||
onClick={() => onLogin(form.getValues())}
|
||||
isLoading={isLoading}
|
||||
@@ -176,7 +149,6 @@ export default function LoginPage() {
|
||||
</AuthButton>
|
||||
</form>
|
||||
<AuthFeedback
|
||||
type="login"
|
||||
message={feedback}
|
||||
isError={!!feedback}
|
||||
behaveAs={getBehaveAs()}
|
||||
|
||||
@@ -6,7 +6,6 @@ import { AgentsSection } from "@/components/agptui/composite/AgentsSection";
|
||||
import { BecomeACreator } from "@/components/agptui/BecomeACreator";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { Metadata } from "next";
|
||||
import getServerUser from "@/lib/supabase/getServerUser";
|
||||
|
||||
export async function generateMetadata({
|
||||
params,
|
||||
@@ -17,7 +16,7 @@ export async function generateMetadata({
|
||||
const agent = await api.getStoreAgent(params.creator, params.slug);
|
||||
|
||||
return {
|
||||
title: `${agent.agent_name} - AutoGPT Marketplace`,
|
||||
title: `${agent.agent_name} - AutoGPT Store`,
|
||||
description: agent.description,
|
||||
};
|
||||
}
|
||||
@@ -37,7 +36,6 @@ export default async function Page({
|
||||
params: { creator: string; slug: string };
|
||||
}) {
|
||||
const creator_lower = params.creator.toLowerCase();
|
||||
const { user } = await getServerUser();
|
||||
const api = new BackendAPI();
|
||||
const agent = await api.getStoreAgent(creator_lower, params.slug);
|
||||
const otherAgents = await api.getStoreAgents({ creator: creator_lower });
|
||||
@@ -45,17 +43,9 @@ export default async function Page({
|
||||
// We are using slug as we know its has been sanitized and is not null
|
||||
search_query: agent.slug.replace(/-/g, " "),
|
||||
});
|
||||
const libraryAgent = user
|
||||
? await api
|
||||
.getLibraryAgentByStoreListingVersionID(agent.store_listing_version_id)
|
||||
.catch((error) => {
|
||||
console.error("Failed to fetch library agent:", error);
|
||||
return null;
|
||||
})
|
||||
: null;
|
||||
|
||||
const breadcrumbs = [
|
||||
{ name: "Marketplace", link: "/marketplace" },
|
||||
{ name: "Store", link: "/marketplace" },
|
||||
{
|
||||
name: agent.creator,
|
||||
link: `/marketplace/creator/${encodeURIComponent(agent.creator)}`,
|
||||
@@ -71,10 +61,9 @@ export default async function Page({
|
||||
<div className="mt-4 flex flex-col items-start gap-4 sm:mt-6 sm:gap-6 md:mt-8 md:flex-row md:gap-8">
|
||||
<div className="w-full md:w-auto md:shrink-0">
|
||||
<AgentInfo
|
||||
user={user}
|
||||
name={agent.agent_name}
|
||||
creator={agent.creator}
|
||||
shortDescription={agent.sub_heading}
|
||||
shortDescription={agent.description}
|
||||
longDescription={agent.description}
|
||||
rating={agent.rating}
|
||||
runs={agent.runs}
|
||||
@@ -82,7 +71,6 @@ export default async function Page({
|
||||
lastUpdated={agent.updated_at}
|
||||
version={agent.versions[agent.versions.length - 1]}
|
||||
storeListingVersionId={agent.store_listing_version_id}
|
||||
libraryAgent={libraryAgent}
|
||||
/>
|
||||
</div>
|
||||
<AgentImages
|
||||
|
||||
@@ -298,9 +298,7 @@ export default function CreditsPage() {
|
||||
>
|
||||
<b>{formatCredits(transaction.amount)}</b>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{formatCredits(transaction.running_balance)}
|
||||
</TableCell>
|
||||
<TableCell>{formatCredits(transaction.balance)}</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
|
||||
@@ -116,7 +116,6 @@ export default function PrivatePage() {
|
||||
"544c62b5-1d0f-4156-8fb4-9525f11656eb", // Apollo
|
||||
"3bcdbda3-84a3-46af-8fdb-bfd2472298b8", // SmartLead
|
||||
"63a6e279-2dc2-448e-bf57-85776f7176dc", // ZeroBounce
|
||||
"9aa1bde0-4947-4a70-a20c-84daa3850d52", // Google Maps
|
||||
],
|
||||
[],
|
||||
);
|
||||
|
||||
@@ -18,15 +18,11 @@ export default function Layout({ children }: { children: React.ReactNode }) {
|
||||
href: "/profile/dashboard",
|
||||
icon: <IconDashboardLayout className="h-6 w-6" />,
|
||||
},
|
||||
...(process.env.NEXT_PUBLIC_SHOW_BILLING_PAGE === "true"
|
||||
? [
|
||||
{
|
||||
text: "Billing",
|
||||
href: "/profile/credits",
|
||||
icon: <IconCoin className="h-6 w-6" />,
|
||||
},
|
||||
]
|
||||
: []),
|
||||
{
|
||||
text: "Billing",
|
||||
href: "/profile/credits",
|
||||
icon: <IconCoin className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "Integrations",
|
||||
href: "/profile/integrations",
|
||||
|
||||
@@ -3,9 +3,8 @@ import getServerSupabase from "@/lib/supabase/getServerSupabase";
|
||||
import { redirect } from "next/navigation";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
import { headers } from "next/headers";
|
||||
import { verifyTurnstileToken } from "@/lib/turnstile";
|
||||
|
||||
export async function sendResetEmail(email: string, turnstileToken: string) {
|
||||
export async function sendResetEmail(email: string) {
|
||||
return await Sentry.withServerActionInstrumentation(
|
||||
"sendResetEmail",
|
||||
{},
|
||||
@@ -21,15 +20,6 @@ export async function sendResetEmail(email: string, turnstileToken: string) {
|
||||
redirect("/error");
|
||||
}
|
||||
|
||||
// Verify Turnstile token if provided
|
||||
const success = await verifyTurnstileToken(
|
||||
turnstileToken,
|
||||
"reset_password",
|
||||
);
|
||||
if (!success) {
|
||||
return "CAPTCHA verification failed. Please try again.";
|
||||
}
|
||||
|
||||
const { error } = await supabase.auth.resetPasswordForEmail(email, {
|
||||
redirectTo: `${origin}/reset_password`,
|
||||
});
|
||||
@@ -44,7 +34,7 @@ export async function sendResetEmail(email: string, turnstileToken: string) {
|
||||
);
|
||||
}
|
||||
|
||||
export async function changePassword(password: string, turnstileToken: string) {
|
||||
export async function changePassword(password: string) {
|
||||
return await Sentry.withServerActionInstrumentation(
|
||||
"changePassword",
|
||||
{},
|
||||
@@ -55,15 +45,6 @@ export async function changePassword(password: string, turnstileToken: string) {
|
||||
redirect("/error");
|
||||
}
|
||||
|
||||
// Verify Turnstile token if provided
|
||||
const success = await verifyTurnstileToken(
|
||||
turnstileToken,
|
||||
"change_password",
|
||||
);
|
||||
if (!success) {
|
||||
return "CAPTCHA verification failed. Please try again.";
|
||||
}
|
||||
|
||||
const { error } = await supabase.auth.updateUser({ password });
|
||||
|
||||
if (error) {
|
||||
|
||||
@@ -5,7 +5,6 @@ import {
|
||||
AuthButton,
|
||||
AuthFeedback,
|
||||
PasswordInput,
|
||||
Turnstile,
|
||||
} from "@/components/auth";
|
||||
import {
|
||||
Form,
|
||||
@@ -26,7 +25,6 @@ import { z } from "zod";
|
||||
import { changePassword, sendResetEmail } from "./actions";
|
||||
import Spinner from "@/components/Spinner";
|
||||
import { getBehaveAs } from "@/lib/utils";
|
||||
import { useTurnstile } from "@/hooks/useTurnstile";
|
||||
|
||||
export default function ResetPasswordPage() {
|
||||
const { supabase, user, isUserLoading } = useSupabase();
|
||||
@@ -35,18 +33,6 @@ export default function ResetPasswordPage() {
|
||||
const [isError, setIsError] = useState(false);
|
||||
const [disabled, setDisabled] = useState(false);
|
||||
|
||||
const sendEmailTurnstile = useTurnstile({
|
||||
action: "reset_password",
|
||||
autoVerify: false,
|
||||
resetOnError: true,
|
||||
});
|
||||
|
||||
const changePasswordTurnstile = useTurnstile({
|
||||
action: "change_password",
|
||||
autoVerify: false,
|
||||
resetOnError: true,
|
||||
});
|
||||
|
||||
const sendEmailForm = useForm<z.infer<typeof sendEmailFormSchema>>({
|
||||
resolver: zodResolver(sendEmailFormSchema),
|
||||
defaultValues: {
|
||||
@@ -72,22 +58,11 @@ export default function ResetPasswordPage() {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!sendEmailTurnstile.verified) {
|
||||
setFeedback("Please complete the CAPTCHA challenge.");
|
||||
setIsError(true);
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const error = await sendResetEmail(
|
||||
data.email,
|
||||
sendEmailTurnstile.token as string,
|
||||
);
|
||||
const error = await sendResetEmail(data.email);
|
||||
setIsLoading(false);
|
||||
if (error) {
|
||||
setFeedback(error);
|
||||
setIsError(true);
|
||||
sendEmailTurnstile.reset();
|
||||
return;
|
||||
}
|
||||
setDisabled(true);
|
||||
@@ -96,7 +71,7 @@ export default function ResetPasswordPage() {
|
||||
);
|
||||
setIsError(false);
|
||||
},
|
||||
[sendEmailForm, sendEmailTurnstile],
|
||||
[sendEmailForm],
|
||||
);
|
||||
|
||||
const onChangePassword = useCallback(
|
||||
@@ -109,28 +84,17 @@ export default function ResetPasswordPage() {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!changePasswordTurnstile.verified) {
|
||||
setFeedback("Please complete the CAPTCHA challenge.");
|
||||
setIsError(true);
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const error = await changePassword(
|
||||
data.password,
|
||||
changePasswordTurnstile.token as string,
|
||||
);
|
||||
const error = await changePassword(data.password);
|
||||
setIsLoading(false);
|
||||
if (error) {
|
||||
setFeedback(error);
|
||||
setIsError(true);
|
||||
changePasswordTurnstile.reset();
|
||||
return;
|
||||
}
|
||||
setFeedback("Password changed successfully. Redirecting to login.");
|
||||
setIsError(false);
|
||||
},
|
||||
[changePasswordForm, changePasswordTurnstile],
|
||||
[changePasswordForm],
|
||||
);
|
||||
|
||||
if (isUserLoading) {
|
||||
@@ -181,17 +145,6 @@ export default function ResetPasswordPage() {
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
{/* Turnstile CAPTCHA Component for password change */}
|
||||
<Turnstile
|
||||
siteKey={changePasswordTurnstile.siteKey}
|
||||
onVerify={changePasswordTurnstile.handleVerify}
|
||||
onExpire={changePasswordTurnstile.handleExpire}
|
||||
onError={changePasswordTurnstile.handleError}
|
||||
action="change_password"
|
||||
shouldRender={changePasswordTurnstile.shouldRender}
|
||||
/>
|
||||
|
||||
<AuthButton
|
||||
onClick={() => onChangePassword(changePasswordForm.getValues())}
|
||||
isLoading={isLoading}
|
||||
@@ -200,7 +153,6 @@ export default function ResetPasswordPage() {
|
||||
Update password
|
||||
</AuthButton>
|
||||
<AuthFeedback
|
||||
type="login"
|
||||
message={feedback}
|
||||
isError={isError}
|
||||
behaveAs={getBehaveAs()}
|
||||
@@ -223,17 +175,6 @@ export default function ResetPasswordPage() {
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
{/* Turnstile CAPTCHA Component for reset email */}
|
||||
<Turnstile
|
||||
siteKey={sendEmailTurnstile.siteKey}
|
||||
onVerify={sendEmailTurnstile.handleVerify}
|
||||
onExpire={sendEmailTurnstile.handleExpire}
|
||||
onError={sendEmailTurnstile.handleError}
|
||||
action="reset_password"
|
||||
shouldRender={sendEmailTurnstile.shouldRender}
|
||||
/>
|
||||
|
||||
<AuthButton
|
||||
onClick={() => onSendEmail(sendEmailForm.getValues())}
|
||||
isLoading={isLoading}
|
||||
@@ -243,7 +184,6 @@ export default function ResetPasswordPage() {
|
||||
Send reset email
|
||||
</AuthButton>
|
||||
<AuthFeedback
|
||||
type="login"
|
||||
message={feedback}
|
||||
isError={isError}
|
||||
behaveAs={getBehaveAs()}
|
||||
|
||||
@@ -6,12 +6,8 @@ import * as Sentry from "@sentry/nextjs";
|
||||
import getServerSupabase from "@/lib/supabase/getServerSupabase";
|
||||
import { signupFormSchema } from "@/types/auth";
|
||||
import BackendAPI from "@/lib/autogpt-server-api";
|
||||
import { verifyTurnstileToken } from "@/lib/turnstile";
|
||||
|
||||
export async function signup(
|
||||
values: z.infer<typeof signupFormSchema>,
|
||||
turnstileToken: string,
|
||||
) {
|
||||
export async function signup(values: z.infer<typeof signupFormSchema>) {
|
||||
"use server";
|
||||
return await Sentry.withServerActionInstrumentation(
|
||||
"signup",
|
||||
@@ -23,12 +19,6 @@ export async function signup(
|
||||
redirect("/error");
|
||||
}
|
||||
|
||||
// Verify Turnstile token if provided
|
||||
const success = await verifyTurnstileToken(turnstileToken, "signup");
|
||||
if (!success) {
|
||||
return "CAPTCHA verification failed. Please try again.";
|
||||
}
|
||||
|
||||
// We are sure that the values are of the correct type because zod validates the form
|
||||
const { data, error } = await supabase.auth.signUp(values);
|
||||
|
||||
|
||||
@@ -25,12 +25,10 @@ import {
|
||||
AuthButton,
|
||||
AuthBottomText,
|
||||
PasswordInput,
|
||||
Turnstile,
|
||||
} from "@/components/auth";
|
||||
import AuthFeedback from "@/components/auth/AuthFeedback";
|
||||
import { signupFormSchema } from "@/types/auth";
|
||||
import { getBehaveAs } from "@/lib/utils";
|
||||
import { useTurnstile } from "@/hooks/useTurnstile";
|
||||
|
||||
export default function SignupPage() {
|
||||
const { supabase, user, isUserLoading } = useSupabase();
|
||||
@@ -39,12 +37,6 @@ export default function SignupPage() {
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
//TODO: Remove after closed beta
|
||||
|
||||
const turnstile = useTurnstile({
|
||||
action: "signup",
|
||||
autoVerify: false,
|
||||
resetOnError: true,
|
||||
});
|
||||
|
||||
const form = useForm<z.infer<typeof signupFormSchema>>({
|
||||
resolver: zodResolver(signupFormSchema),
|
||||
defaultValues: {
|
||||
@@ -64,28 +56,20 @@ export default function SignupPage() {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!turnstile.verified) {
|
||||
setFeedback("Please complete the CAPTCHA challenge.");
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const error = await signup(data, turnstile.token as string);
|
||||
const error = await signup(data);
|
||||
setIsLoading(false);
|
||||
if (error) {
|
||||
if (error === "user_already_exists") {
|
||||
setFeedback("User with this email already exists");
|
||||
turnstile.reset();
|
||||
return;
|
||||
} else {
|
||||
setFeedback(error);
|
||||
turnstile.reset();
|
||||
}
|
||||
return;
|
||||
}
|
||||
setFeedback(null);
|
||||
},
|
||||
[form, turnstile],
|
||||
[form],
|
||||
);
|
||||
|
||||
if (user) {
|
||||
@@ -106,7 +90,7 @@ export default function SignupPage() {
|
||||
}
|
||||
|
||||
return (
|
||||
<AuthCard className="mx-auto mt-12">
|
||||
<AuthCard className="mx-auto">
|
||||
<AuthHeader>Create a new account</AuthHeader>
|
||||
<Form {...form}>
|
||||
<form onSubmit={form.handleSubmit(onSignup)}>
|
||||
@@ -157,17 +141,6 @@ export default function SignupPage() {
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
{/* Turnstile CAPTCHA Component */}
|
||||
<Turnstile
|
||||
siteKey={turnstile.siteKey}
|
||||
onVerify={turnstile.handleVerify}
|
||||
onExpire={turnstile.handleExpire}
|
||||
onError={turnstile.handleError}
|
||||
action="signup"
|
||||
shouldRender={turnstile.shouldRender}
|
||||
/>
|
||||
|
||||
<AuthButton
|
||||
onClick={() => onSignup(form.getValues())}
|
||||
isLoading={isLoading}
|
||||
@@ -215,7 +188,6 @@ export default function SignupPage() {
|
||||
</form>
|
||||
</Form>
|
||||
<AuthFeedback
|
||||
type="signup"
|
||||
message={feedback}
|
||||
isError={!!feedback}
|
||||
behaveAs={getBehaveAs()}
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
"use server";
|
||||
|
||||
import { revalidatePath } from "next/cache";
|
||||
import BackendApi from "@/lib/autogpt-server-api";
|
||||
import {
|
||||
UsersBalanceHistoryResponse,
|
||||
CreditTransactionType,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
|
||||
export async function addDollars(formData: FormData) {
|
||||
const data = {
|
||||
user_id: formData.get("id") as string,
|
||||
amount: parseInt(formData.get("amount") as string),
|
||||
comments: formData.get("comments") as string,
|
||||
};
|
||||
const api = new BackendApi();
|
||||
const resp = await api.addUserCredits(
|
||||
data.user_id,
|
||||
data.amount,
|
||||
data.comments,
|
||||
);
|
||||
console.log(resp);
|
||||
revalidatePath("/admin/spending");
|
||||
}
|
||||
|
||||
export async function getUsersTransactionHistory(
|
||||
page: number = 1,
|
||||
pageSize: number = 20,
|
||||
search?: string,
|
||||
transactionType?: CreditTransactionType,
|
||||
): Promise<UsersBalanceHistoryResponse> {
|
||||
const data: Record<string, any> = {
|
||||
page,
|
||||
page_size: pageSize,
|
||||
};
|
||||
if (search) {
|
||||
data.search = search;
|
||||
}
|
||||
if (transactionType) {
|
||||
data.transaction_filter = transactionType;
|
||||
}
|
||||
const api = new BackendApi();
|
||||
const history = await api.getUsersHistory(data);
|
||||
return history;
|
||||
}
|
||||
@@ -1,58 +0,0 @@
|
||||
import { AdminUserGrantHistory } from "@/components/admin/spending/admin-grant-history-data-table";
|
||||
import type { CreditTransactionType } from "@/lib/autogpt-server-api";
|
||||
import { withRoleAccess } from "@/lib/withRoleAccess";
|
||||
import { Suspense } from "react";
|
||||
|
||||
function SpendingDashboard({
|
||||
searchParams,
|
||||
}: {
|
||||
searchParams: {
|
||||
page?: string;
|
||||
status?: string;
|
||||
search?: string;
|
||||
};
|
||||
}) {
|
||||
const page = searchParams.page ? Number.parseInt(searchParams.page) : 1;
|
||||
const search = searchParams.search;
|
||||
const status = searchParams.status as CreditTransactionType | undefined;
|
||||
|
||||
return (
|
||||
<div className="mx-auto p-6">
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex items-center justify-between">
|
||||
<div>
|
||||
<h1 className="text-3xl font-bold">User Spending</h1>
|
||||
<p className="text-gray-500">Manage user spending balances</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Suspense
|
||||
fallback={
|
||||
<div className="py-10 text-center">Loading submissions...</div>
|
||||
}
|
||||
>
|
||||
<AdminUserGrantHistory
|
||||
initialPage={page}
|
||||
initialStatus={status}
|
||||
initialSearch={search}
|
||||
/>
|
||||
</Suspense>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default async function SpendingDashboardPage({
|
||||
searchParams,
|
||||
}: {
|
||||
searchParams: {
|
||||
page?: string;
|
||||
status?: string;
|
||||
search?: string;
|
||||
};
|
||||
}) {
|
||||
"use server";
|
||||
const withAdminAccess = await withRoleAccess(["admin"]);
|
||||
const ProtectedSpendingDashboard = await withAdminAccess(SpendingDashboard);
|
||||
return <ProtectedSpendingDashboard searchParams={searchParams} />;
|
||||
}
|
||||
@@ -7,12 +7,27 @@ import { BackendAPIProvider } from "@/lib/autogpt-server-api/context";
|
||||
import { TooltipProvider } from "@/components/ui/tooltip";
|
||||
import CredentialsProvider from "@/components/integrations/credentials-provider";
|
||||
import { LaunchDarklyProvider } from "@/components/feature-flag/feature-flag-provider";
|
||||
import { MockClientProps } from "@/lib/autogpt-server-api/mock_client";
|
||||
import OnboardingProvider from "@/components/onboarding/onboarding-provider";
|
||||
|
||||
export function Providers({ children, ...props }: ThemeProviderProps) {
|
||||
export interface ProvidersProps extends ThemeProviderProps {
|
||||
children: React.ReactNode;
|
||||
useMockBackend?: boolean;
|
||||
mockClientProps?: MockClientProps;
|
||||
}
|
||||
|
||||
export function Providers({
|
||||
children,
|
||||
useMockBackend,
|
||||
mockClientProps,
|
||||
...props
|
||||
}: ProvidersProps) {
|
||||
return (
|
||||
<NextThemesProvider {...props}>
|
||||
<BackendAPIProvider>
|
||||
<BackendAPIProvider
|
||||
useMockBackend={useMockBackend}
|
||||
mockClientProps={mockClientProps}
|
||||
>
|
||||
<CredentialsProvider>
|
||||
<LaunchDarklyProvider>
|
||||
<OnboardingProvider>
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { downloadAsAdmin } from "@/app/(platform)/admin/marketplace/actions";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { ExternalLink } from "lucide-react";
|
||||
import { useState } from "react";
|
||||
|
||||
export function DownloadAgentAdminButton({
|
||||
storeListingVersionId,
|
||||
}: {
|
||||
storeListingVersionId: string;
|
||||
}) {
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
|
||||
const handleDownload = async () => {
|
||||
try {
|
||||
setIsLoading(true);
|
||||
// Call the server action to get the data
|
||||
const fileData = await downloadAsAdmin(storeListingVersionId);
|
||||
|
||||
// Client-side download logic
|
||||
const jsonData = JSON.stringify(fileData, null, 2);
|
||||
const blob = new Blob([jsonData], { type: "application/json" });
|
||||
|
||||
// Create a temporary URL for the Blob
|
||||
const url = window.URL.createObjectURL(blob);
|
||||
|
||||
// Create a temporary anchor element
|
||||
const a = document.createElement("a");
|
||||
a.href = url;
|
||||
a.download = `agent_${storeListingVersionId}.json`;
|
||||
|
||||
// Append the anchor to the body, click it, and remove it
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
|
||||
// Revoke the temporary URL
|
||||
window.URL.revokeObjectURL(url);
|
||||
} catch (error) {
|
||||
console.error("Download failed:", error);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
onClick={handleDownload}
|
||||
disabled={isLoading}
|
||||
>
|
||||
<ExternalLink className="mr-2 h-4 w-4" />
|
||||
{isLoading ? "Downloading..." : "Download"}
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
@@ -19,8 +19,6 @@ import {
|
||||
SubmissionStatus,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
import { ApproveRejectButtons } from "./approve-reject-buttons";
|
||||
import { downloadAsAdmin } from "@/app/(platform)/admin/marketplace/actions";
|
||||
import { DownloadAgentAdminButton } from "./download-agent-button";
|
||||
|
||||
// Moved the getStatusBadge function into the client component
|
||||
const getStatusBadge = (status: SubmissionStatus) => {
|
||||
@@ -79,11 +77,10 @@ export function ExpandableRow({
|
||||
</TableCell>
|
||||
<TableCell className="text-right">
|
||||
<div className="flex justify-end gap-2">
|
||||
{latestVersion?.store_listing_version_id && (
|
||||
<DownloadAgentAdminButton
|
||||
storeListingVersionId={latestVersion.store_listing_version_id}
|
||||
/>
|
||||
)}
|
||||
<Button size="sm" variant="outline">
|
||||
<ExternalLink className="mr-2 h-4 w-4" />
|
||||
Builder
|
||||
</Button>
|
||||
|
||||
{latestVersion?.status === SubmissionStatus.PENDING && (
|
||||
<ApproveRejectButtons version={latestVersion} />
|
||||
@@ -183,13 +180,17 @@ export function ExpandableRow({
|
||||
{/* <TableCell>{version.categories.join(", ")}</TableCell> */}
|
||||
<TableCell className="text-right">
|
||||
<div className="flex justify-end gap-2">
|
||||
{version.store_listing_version_id && (
|
||||
<DownloadAgentAdminButton
|
||||
storeListingVersionId={
|
||||
version.store_listing_version_id
|
||||
}
|
||||
/>
|
||||
)}
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
onClick={() =>
|
||||
(window.location.href = `/admin/agents/${version.store_listing_version_id}`)
|
||||
}
|
||||
>
|
||||
<ExternalLink className="mr-2 h-4 w-4" />
|
||||
Builder
|
||||
</Button>
|
||||
|
||||
{version.status === SubmissionStatus.PENDING && (
|
||||
<ApproveRejectButtons version={version} />
|
||||
)}
|
||||
|
||||
@@ -1,139 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogFooter,
|
||||
} from "@/components/ui/dialog";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { addDollars } from "@/app/admin/spending/actions";
|
||||
import useCredits from "@/hooks/useCredits";
|
||||
|
||||
export function AdminAddMoneyButton({
|
||||
userId,
|
||||
userEmail,
|
||||
currentBalance,
|
||||
defaultAmount,
|
||||
defaultComments,
|
||||
}: {
|
||||
userId: string;
|
||||
userEmail: string;
|
||||
currentBalance: number;
|
||||
defaultAmount?: number;
|
||||
defaultComments?: string;
|
||||
}) {
|
||||
const router = useRouter();
|
||||
const [isAddMoneyDialogOpen, setIsAddMoneyDialogOpen] = useState(false);
|
||||
const [dollarAmount, setDollarAmount] = useState(
|
||||
defaultAmount ? Math.abs(defaultAmount / 100).toFixed(2) : "1.00",
|
||||
);
|
||||
|
||||
const { formatCredits } = useCredits();
|
||||
|
||||
const handleApproveSubmit = async (formData: FormData) => {
|
||||
setIsAddMoneyDialogOpen(false);
|
||||
try {
|
||||
await addDollars(formData);
|
||||
router.refresh(); // Refresh the current route
|
||||
} catch (error) {
|
||||
console.error("Error adding dollars:", error);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<Button
|
||||
size="sm"
|
||||
variant="default"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setIsAddMoneyDialogOpen(true);
|
||||
}}
|
||||
>
|
||||
Add Dollars
|
||||
</Button>
|
||||
|
||||
{/* Add $$$ Dialog */}
|
||||
<Dialog
|
||||
open={isAddMoneyDialogOpen}
|
||||
onOpenChange={setIsAddMoneyDialogOpen}
|
||||
>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>Add Dollars</DialogTitle>
|
||||
<DialogDescription className="pt-2">
|
||||
<div className="mb-2">
|
||||
<span className="font-medium">User:</span> {userEmail}
|
||||
</div>
|
||||
<div>
|
||||
<span className="font-medium">Current balance:</span> $
|
||||
{(currentBalance / 100).toFixed(2)}
|
||||
</div>
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<form action={handleApproveSubmit}>
|
||||
<input type="hidden" name="id" value={userId} />
|
||||
<input
|
||||
type="hidden"
|
||||
name="amount"
|
||||
value={Math.round(parseFloat(dollarAmount) * 100)}
|
||||
/>
|
||||
|
||||
<div className="grid gap-4 py-4">
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="dollarAmount">Amount (in dollars)</Label>
|
||||
<div className="flex">
|
||||
<div className="flex items-center justify-center rounded-l-md border border-r-0 bg-gray-50 px-3 text-gray-500">
|
||||
$
|
||||
</div>
|
||||
<Input
|
||||
id="dollarAmount"
|
||||
type="number"
|
||||
step="0.01"
|
||||
min="0"
|
||||
className="rounded-l-none"
|
||||
value={dollarAmount}
|
||||
onChange={(e) => setDollarAmount(e.target.value)}
|
||||
placeholder="0.00"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-4 py-4">
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="comments">Comments (Optional)</Label>
|
||||
<Textarea
|
||||
id="comments"
|
||||
name="comments"
|
||||
placeholder="Why are you adding dollars?"
|
||||
defaultValue={defaultComments || "We love you!"}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<DialogFooter>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
onClick={() => setIsAddMoneyDialogOpen(false)}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button type="submit">Add Dollars</Button>
|
||||
</DialogFooter>
|
||||
</form>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -1,183 +0,0 @@
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
|
||||
import { PaginationControls } from "../../ui/pagination-controls";
|
||||
import { SearchAndFilterAdminSpending } from "./search-filter-form";
|
||||
import { getUsersTransactionHistory } from "@/app/admin/spending/actions";
|
||||
import { AdminAddMoneyButton } from "./add-money-button";
|
||||
import { CreditTransactionType } from "@/lib/autogpt-server-api";
|
||||
|
||||
export async function AdminUserGrantHistory({
|
||||
initialPage = 1,
|
||||
initialStatus,
|
||||
initialSearch,
|
||||
}: {
|
||||
initialPage?: number;
|
||||
initialStatus?: CreditTransactionType;
|
||||
initialSearch?: string;
|
||||
}) {
|
||||
// Server-side data fetching
|
||||
const { history, pagination } = await getUsersTransactionHistory(
|
||||
initialPage,
|
||||
15,
|
||||
initialSearch,
|
||||
initialStatus,
|
||||
);
|
||||
|
||||
// Helper function to format the amount with color based on transaction type
|
||||
const formatAmount = (amount: number, type: CreditTransactionType) => {
|
||||
const isPositive = type === CreditTransactionType.GRANT;
|
||||
const isNeutral = type === CreditTransactionType.TOP_UP;
|
||||
const color = isPositive
|
||||
? "text-green-600"
|
||||
: isNeutral
|
||||
? "text-blue-600"
|
||||
: "text-red-600";
|
||||
return <span className={color}>${Math.abs(amount / 100)}</span>;
|
||||
};
|
||||
|
||||
// Helper function to format the transaction type with color
|
||||
const formatType = (type: CreditTransactionType) => {
|
||||
const isGrant = type === CreditTransactionType.GRANT;
|
||||
const isPurchased = type === CreditTransactionType.TOP_UP;
|
||||
const isSpent = type === CreditTransactionType.USAGE;
|
||||
|
||||
let displayText = type;
|
||||
let bgColor = "";
|
||||
|
||||
if (isGrant) {
|
||||
bgColor = "bg-green-100 text-green-800";
|
||||
} else if (isPurchased) {
|
||||
bgColor = "bg-blue-100 text-blue-800";
|
||||
} else if (isSpent) {
|
||||
bgColor = "bg-red-100 text-red-800";
|
||||
}
|
||||
|
||||
return (
|
||||
<span className={`rounded-full px-2 py-1 text-xs font-medium ${bgColor}`}>
|
||||
{displayText.valueOf()}
|
||||
</span>
|
||||
);
|
||||
};
|
||||
|
||||
// Helper function to format the date
|
||||
const formatDate = (date: Date) => {
|
||||
return new Intl.DateTimeFormat("en-US", {
|
||||
month: "short",
|
||||
day: "numeric",
|
||||
year: "numeric",
|
||||
hour: "numeric",
|
||||
minute: "numeric",
|
||||
hour12: true,
|
||||
}).format(new Date(date));
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<SearchAndFilterAdminSpending
|
||||
initialStatus={initialStatus}
|
||||
initialSearch={initialSearch}
|
||||
/>
|
||||
|
||||
<div className="rounded-md border bg-white">
|
||||
<Table>
|
||||
<TableHeader className="bg-gray-50">
|
||||
<TableRow>
|
||||
<TableHead className="font-medium">User</TableHead>
|
||||
<TableHead className="font-medium">Type</TableHead>
|
||||
<TableHead className="font-medium">Date</TableHead>
|
||||
<TableHead className="font-medium">Reason</TableHead>
|
||||
<TableHead className="font-medium">Admin</TableHead>
|
||||
<TableHead className="font-medium">Starting Balance</TableHead>
|
||||
<TableHead className="font-medium">Amount</TableHead>
|
||||
<TableHead className="font-medium">Ending Balance</TableHead>
|
||||
{/* <TableHead className="font-medium">Current Balance</TableHead> */}
|
||||
<TableHead className="text-right font-medium">Actions</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{history.length === 0 ? (
|
||||
<TableRow>
|
||||
<TableCell
|
||||
colSpan={8}
|
||||
className="py-10 text-center text-gray-500"
|
||||
>
|
||||
No transactions found
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
) : (
|
||||
history.map((transaction) => (
|
||||
<TableRow
|
||||
key={transaction.user_id}
|
||||
className="hover:bg-gray-50"
|
||||
>
|
||||
<TableCell className="font-medium">
|
||||
{transaction.user_email}
|
||||
</TableCell>
|
||||
|
||||
<TableCell>
|
||||
{formatType(transaction.transaction_type)}
|
||||
</TableCell>
|
||||
<TableCell className="text-gray-600">
|
||||
{formatDate(transaction.transaction_time)}
|
||||
</TableCell>
|
||||
<TableCell>{transaction.reason}</TableCell>
|
||||
<TableCell className="text-gray-600">
|
||||
{transaction.admin_email}
|
||||
</TableCell>
|
||||
<TableCell className="font-medium text-green-600">
|
||||
${(transaction.running_balance + -transaction.amount) / 100}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{formatAmount(
|
||||
transaction.amount,
|
||||
transaction.transaction_type,
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell className="font-medium text-green-600">
|
||||
${transaction.running_balance / 100}
|
||||
</TableCell>
|
||||
{/* <TableCell className="font-medium text-green-600">
|
||||
${transaction.current_balance / 100}
|
||||
</TableCell> */}
|
||||
<TableCell className="text-right">
|
||||
<AdminAddMoneyButton
|
||||
userId={transaction.user_id}
|
||||
userEmail={
|
||||
transaction.user_email ?? "User Email wasn't attached"
|
||||
}
|
||||
currentBalance={transaction.current_balance}
|
||||
defaultAmount={
|
||||
transaction.transaction_type ===
|
||||
CreditTransactionType.USAGE
|
||||
? -transaction.amount
|
||||
: undefined
|
||||
}
|
||||
defaultComments={
|
||||
transaction.transaction_type ===
|
||||
CreditTransactionType.USAGE
|
||||
? "Refund for usage"
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))
|
||||
)}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
|
||||
<PaginationControls
|
||||
currentPage={initialPage}
|
||||
totalPages={pagination.total_pages}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,105 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import { useRouter, usePathname, useSearchParams } from "next/navigation";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Search } from "lucide-react";
|
||||
import { CreditTransactionType } from "@/lib/autogpt-server-api";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
|
||||
export function SearchAndFilterAdminSpending({
|
||||
initialStatus,
|
||||
initialSearch,
|
||||
}: {
|
||||
initialStatus?: CreditTransactionType;
|
||||
initialSearch?: string;
|
||||
}) {
|
||||
const router = useRouter();
|
||||
const pathname = usePathname();
|
||||
const searchParams = useSearchParams();
|
||||
|
||||
// Initialize state from URL parameters
|
||||
const [searchQuery, setSearchQuery] = useState(initialSearch || "");
|
||||
const [selectedStatus, setSelectedStatus] = useState<string>(
|
||||
searchParams.get("status") || "ALL",
|
||||
);
|
||||
|
||||
// Update local state when URL parameters change
|
||||
useEffect(() => {
|
||||
const status = searchParams.get("status");
|
||||
setSelectedStatus(status || "ALL");
|
||||
setSearchQuery(searchParams.get("search") || "");
|
||||
}, [searchParams]);
|
||||
|
||||
const handleSearch = () => {
|
||||
const params = new URLSearchParams(searchParams.toString());
|
||||
|
||||
if (searchQuery) {
|
||||
params.set("search", searchQuery);
|
||||
} else {
|
||||
params.delete("search");
|
||||
}
|
||||
|
||||
if (selectedStatus !== "ALL") {
|
||||
params.set("status", selectedStatus);
|
||||
} else {
|
||||
params.delete("status");
|
||||
}
|
||||
|
||||
params.set("page", "1"); // Reset to first page on new search
|
||||
|
||||
router.push(`${pathname}?${params.toString()}`);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex w-full items-center gap-2">
|
||||
<Input
|
||||
placeholder="Search users by Name or Email..."
|
||||
value={searchQuery}
|
||||
onChange={(e) => setSearchQuery(e.target.value)}
|
||||
onKeyDown={(e) => e.key === "Enter" && handleSearch()}
|
||||
/>
|
||||
<Button variant="outline" onClick={handleSearch}>
|
||||
<Search className="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<Select
|
||||
value={selectedStatus}
|
||||
onValueChange={(value) => {
|
||||
setSelectedStatus(value);
|
||||
const params = new URLSearchParams(searchParams.toString());
|
||||
if (value === "ALL") {
|
||||
params.delete("status");
|
||||
} else {
|
||||
params.set("status", value);
|
||||
}
|
||||
params.set("page", "1");
|
||||
router.push(`${pathname}?${params.toString()}`);
|
||||
}}
|
||||
>
|
||||
<SelectTrigger className="w-1/4">
|
||||
<SelectValue placeholder="Select Status" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="ALL">All</SelectItem>
|
||||
<SelectItem value={CreditTransactionType.TOP_UP}>Top Up</SelectItem>
|
||||
<SelectItem value={CreditTransactionType.USAGE}>Usage</SelectItem>
|
||||
<SelectItem value={CreditTransactionType.REFUND}>Refund</SelectItem>
|
||||
<SelectItem value={CreditTransactionType.GRANT}>Grant</SelectItem>
|
||||
<SelectItem value={CreditTransactionType.CARD_CHECK}>
|
||||
Card Check
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -239,10 +239,7 @@ export default function AgentRunDetailsView({
|
||||
{title || key}
|
||||
</label>
|
||||
{values.map((value, i) => (
|
||||
<p
|
||||
className="resize-none whitespace-pre-wrap break-words border-none text-sm text-neutral-700 disabled:cursor-not-allowed"
|
||||
key={i}
|
||||
>
|
||||
<p className="text-sm text-neutral-700" key={i}>
|
||||
{value}
|
||||
</p>
|
||||
))}
|
||||
|
||||
@@ -27,8 +27,6 @@ type Story = StoryObj<typeof meta>;
|
||||
|
||||
export const Default: Story = {
|
||||
args: {
|
||||
user: null,
|
||||
libraryAgent: null,
|
||||
name: "AI Video Generator",
|
||||
storeListingVersionId: "123",
|
||||
creator: "Toran Richards",
|
||||
|
||||
@@ -1,19 +1,17 @@
|
||||
"use client";
|
||||
|
||||
import { StarRatingIcons } from "@/components/ui/icons";
|
||||
import * as React from "react";
|
||||
import { IconPlay, StarRatingIcons } from "@/components/ui/icons";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import BackendAPI, { LibraryAgent } from "@/lib/autogpt-server-api";
|
||||
import BackendAPI from "@/lib/autogpt-server-api";
|
||||
import { useRouter } from "next/navigation";
|
||||
import Link from "next/link";
|
||||
import { useToast } from "@/components/ui/use-toast";
|
||||
|
||||
import useSupabase from "@/hooks/useSupabase";
|
||||
import { DownloadIcon, LoaderIcon } from "lucide-react";
|
||||
import { useOnboarding } from "../onboarding/onboarding-provider";
|
||||
import { User } from "@supabase/supabase-js";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { FC, useCallback, useMemo, useState } from "react";
|
||||
|
||||
interface AgentInfoProps {
|
||||
user: User | null;
|
||||
name: string;
|
||||
creator: string;
|
||||
shortDescription: string;
|
||||
@@ -24,11 +22,9 @@ interface AgentInfoProps {
|
||||
lastUpdated: string;
|
||||
version: string;
|
||||
storeListingVersionId: string;
|
||||
libraryAgent: LibraryAgent | null;
|
||||
}
|
||||
|
||||
export const AgentInfo: FC<AgentInfoProps> = ({
|
||||
user,
|
||||
export const AgentInfo: React.FC<AgentInfoProps> = ({
|
||||
name,
|
||||
creator,
|
||||
shortDescription,
|
||||
@@ -39,48 +35,28 @@ export const AgentInfo: FC<AgentInfoProps> = ({
|
||||
lastUpdated,
|
||||
version,
|
||||
storeListingVersionId,
|
||||
libraryAgent,
|
||||
}) => {
|
||||
const router = useRouter();
|
||||
const api = useMemo(() => new BackendAPI(), []);
|
||||
const api = React.useMemo(() => new BackendAPI(), []);
|
||||
const { user } = useSupabase();
|
||||
const { toast } = useToast();
|
||||
const { completeStep } = useOnboarding();
|
||||
const [adding, setAdding] = useState(false);
|
||||
const [downloading, setDownloading] = useState(false);
|
||||
|
||||
const libraryAction = useCallback(async () => {
|
||||
setAdding(true);
|
||||
if (libraryAgent) {
|
||||
toast({
|
||||
description: "Redirecting to your library...",
|
||||
duration: 2000,
|
||||
});
|
||||
// Redirect to the library agent page
|
||||
router.push(`/library/agents/${libraryAgent.id}`);
|
||||
return;
|
||||
}
|
||||
const [downloading, setDownloading] = React.useState(false);
|
||||
|
||||
const handleAddToLibrary = async () => {
|
||||
try {
|
||||
const newLibraryAgent = await api.addMarketplaceAgentToLibrary(
|
||||
storeListingVersionId,
|
||||
);
|
||||
completeStep("MARKETPLACE_ADD_AGENT");
|
||||
router.push(`/library/agents/${newLibraryAgent.id}`);
|
||||
toast({
|
||||
title: "Agent Added",
|
||||
description: "Redirecting to your library...",
|
||||
duration: 2000,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Failed to add agent to library:", error);
|
||||
toast({
|
||||
title: "Error",
|
||||
description: "Failed to add agent to library. Please try again.",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
}, [toast, api, storeListingVersionId, completeStep, router]);
|
||||
};
|
||||
|
||||
const handleDownload = useCallback(async () => {
|
||||
const handleDownloadToLibrary = async () => {
|
||||
const downloadAgent = async (): Promise<void> => {
|
||||
setDownloading(true);
|
||||
try {
|
||||
@@ -113,16 +89,12 @@ export const AgentInfo: FC<AgentInfoProps> = ({
|
||||
});
|
||||
} catch (error) {
|
||||
console.error(`Error downloading agent:`, error);
|
||||
toast({
|
||||
title: "Error",
|
||||
description: "Failed to download agent. Please try again.",
|
||||
variant: "destructive",
|
||||
});
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
await downloadAgent();
|
||||
setDownloading(false);
|
||||
}, [setDownloading, api, storeListingVersionId, toast]);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="w-full max-w-[396px] px-4 sm:px-6 lg:w-[396px] lg:px-0">
|
||||
@@ -133,61 +105,65 @@ export const AgentInfo: FC<AgentInfoProps> = ({
|
||||
|
||||
{/* Creator */}
|
||||
<div className="mb-3 flex w-full items-center gap-1.5 lg:mb-4">
|
||||
<div className="font-sans text-base font-normal text-neutral-800 dark:text-neutral-200 sm:text-lg lg:text-xl">
|
||||
<div className="font-geist text-base font-normal text-neutral-800 dark:text-neutral-200 sm:text-lg lg:text-xl">
|
||||
by
|
||||
</div>
|
||||
<Link
|
||||
href={`/marketplace/creator/${encodeURIComponent(creator)}`}
|
||||
className="font-sans text-base font-medium text-neutral-800 hover:underline dark:text-neutral-200 sm:text-lg lg:text-xl"
|
||||
className="font-geist text-base font-medium text-neutral-800 hover:underline dark:text-neutral-200 sm:text-lg lg:text-xl"
|
||||
>
|
||||
{creator}
|
||||
</Link>
|
||||
</div>
|
||||
|
||||
{/* Short Description */}
|
||||
<div className="mb-4 line-clamp-2 w-full font-sans text-base font-normal leading-normal text-neutral-600 dark:text-neutral-300 sm:text-lg lg:mb-6 lg:text-xl lg:leading-7">
|
||||
<div className="font-geist mb-4 line-clamp-2 w-full text-base font-normal leading-normal text-neutral-600 dark:text-neutral-300 sm:text-lg lg:mb-6 lg:text-xl lg:leading-7">
|
||||
{shortDescription}
|
||||
</div>
|
||||
|
||||
{/* Buttons */}
|
||||
<div className="mb-4 flex w-full gap-3 lg:mb-[60px]">
|
||||
{user && (
|
||||
{/* Run Agent Button */}
|
||||
<div className="mb-4 w-full lg:mb-[60px]">
|
||||
{user ? (
|
||||
<button
|
||||
className={cn(
|
||||
"inline-flex min-w-24 items-center justify-center rounded-full bg-violet-600 px-4 py-3",
|
||||
"transition-colors duration-200 hover:bg-violet-500 disabled:bg-zinc-400",
|
||||
)}
|
||||
onClick={libraryAction}
|
||||
disabled={adding}
|
||||
onClick={handleAddToLibrary}
|
||||
className="inline-flex w-full items-center justify-center gap-2 rounded-[38px] bg-violet-600 px-4 py-3 transition-colors hover:bg-violet-700 sm:w-auto sm:gap-2.5 sm:px-5 sm:py-3.5 lg:px-6 lg:py-4"
|
||||
>
|
||||
<span className="justify-start font-sans text-sm font-medium leading-snug text-primary-foreground">
|
||||
{libraryAgent ? "See runs" : "Add to library"}
|
||||
<IconPlay className="h-5 w-5 text-white sm:h-5 sm:w-5 lg:h-6 lg:w-6" />
|
||||
<span className="font-poppins text-base font-medium text-neutral-50 sm:text-lg">
|
||||
Add To Library
|
||||
</span>
|
||||
</button>
|
||||
) : (
|
||||
<button
|
||||
onClick={handleDownloadToLibrary}
|
||||
className={`inline-flex w-full items-center justify-center gap-2 rounded-[38px] px-4 py-3 transition-colors sm:w-auto sm:gap-2.5 sm:px-5 sm:py-3.5 lg:px-6 lg:py-4 ${
|
||||
downloading
|
||||
? "bg-neutral-400"
|
||||
: "bg-violet-600 hover:bg-violet-700"
|
||||
}`}
|
||||
disabled={downloading}
|
||||
>
|
||||
{downloading ? (
|
||||
<LoaderIcon className="h-5 w-5 animate-spin text-white sm:h-5 sm:w-5 lg:h-6 lg:w-6" />
|
||||
) : (
|
||||
<DownloadIcon className="h-5 w-5 text-white sm:h-5 sm:w-5 lg:h-6 lg:w-6" />
|
||||
)}
|
||||
<span className="font-poppins text-base font-medium text-neutral-50 sm:text-lg">
|
||||
{downloading ? "Downloading..." : "Download Agent as File"}
|
||||
</span>
|
||||
</button>
|
||||
)}
|
||||
<button
|
||||
className={cn(
|
||||
"inline-flex min-w-24 items-center justify-center rounded-full bg-zinc-200 px-4 py-3",
|
||||
"transition-colors duration-200 hover:bg-zinc-200/70 disabled:bg-zinc-200/40",
|
||||
)}
|
||||
onClick={handleDownload}
|
||||
disabled={downloading}
|
||||
>
|
||||
<div className="justify-start text-center font-sans text-sm font-medium leading-snug text-zinc-800">
|
||||
Download agent
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Rating and Runs */}
|
||||
<div className="mb-4 flex w-full items-center justify-between lg:mb-[44px]">
|
||||
<div className="flex items-center gap-1.5 sm:gap-2">
|
||||
<span className="whitespace-nowrap font-sans text-base font-semibold text-neutral-800 dark:text-neutral-200 sm:text-lg">
|
||||
<span className="font-geist whitespace-nowrap text-base font-semibold text-neutral-800 dark:text-neutral-200 sm:text-lg">
|
||||
{rating.toFixed(1)}
|
||||
</span>
|
||||
<div className="flex gap-0.5">{StarRatingIcons(rating)}</div>
|
||||
</div>
|
||||
<div className="whitespace-nowrap font-sans text-base font-semibold text-neutral-800 dark:text-neutral-200 sm:text-lg">
|
||||
<div className="font-geist whitespace-nowrap text-base font-semibold text-neutral-800 dark:text-neutral-200 sm:text-lg">
|
||||
{runs.toLocaleString()} runs
|
||||
</div>
|
||||
</div>
|
||||
@@ -207,14 +183,14 @@ export const AgentInfo: FC<AgentInfoProps> = ({
|
||||
|
||||
{/* Categories */}
|
||||
<div className="mb-4 flex w-full flex-col gap-1.5 sm:gap-2 lg:mb-[36px]">
|
||||
<div className="decoration-skip-ink-none mb-1.5 font-sans text-base font-medium leading-6 text-neutral-800 dark:text-neutral-200 sm:mb-2">
|
||||
<div className="font-geist decoration-skip-ink-none mb-1.5 text-base font-medium leading-6 text-neutral-800 dark:text-neutral-200 sm:mb-2">
|
||||
Categories
|
||||
</div>
|
||||
<div className="flex flex-wrap gap-1.5 sm:gap-2">
|
||||
{categories.map((category, index) => (
|
||||
<div
|
||||
key={index}
|
||||
className="decoration-skip-ink-none whitespace-nowrap rounded-full border border-neutral-600 bg-white px-2 py-0.5 font-sans text-base font-normal leading-6 text-neutral-800 underline-offset-[from-font] dark:border-neutral-700 dark:bg-neutral-800 dark:text-neutral-200 sm:px-[16px] sm:py-[10px]"
|
||||
className="font-geist decoration-skip-ink-none whitespace-nowrap rounded-full border border-neutral-600 bg-white px-2 py-0.5 text-base font-normal leading-6 text-neutral-800 underline-offset-[from-font] dark:border-neutral-700 dark:bg-neutral-800 dark:text-neutral-200 sm:px-[16px] sm:py-[10px]"
|
||||
>
|
||||
{category}
|
||||
</div>
|
||||
@@ -224,10 +200,10 @@ export const AgentInfo: FC<AgentInfoProps> = ({
|
||||
|
||||
{/* Version History */}
|
||||
<div className="flex w-full flex-col gap-0.5 sm:gap-1">
|
||||
<div className="decoration-skip-ink-none mb-1.5 font-sans text-base font-medium leading-6 text-neutral-800 dark:text-neutral-200 sm:mb-2">
|
||||
<div className="font-geist decoration-skip-ink-none mb-1.5 text-base font-medium leading-6 text-neutral-800 dark:text-neutral-200 sm:mb-2">
|
||||
Version history
|
||||
</div>
|
||||
<div className="decoration-skip-ink-none font-sans text-base font-normal leading-6 text-neutral-600 underline-offset-[from-font] dark:text-neutral-400">
|
||||
<div className="font-geist decoration-skip-ink-none text-base font-normal leading-6 text-neutral-600 underline-offset-[from-font] dark:text-neutral-400">
|
||||
Last updated {lastUpdated}
|
||||
</div>
|
||||
<div className="text-xs text-neutral-600 dark:text-neutral-400 sm:text-sm">
|
||||
|
||||
@@ -15,8 +15,8 @@ export interface AgentTableCardProps {
|
||||
imageSrc: string[];
|
||||
dateSubmitted: string;
|
||||
status: StatusType;
|
||||
runs: number;
|
||||
rating: number;
|
||||
runs?: number;
|
||||
rating?: number;
|
||||
id: number;
|
||||
onEditSubmission: (submission: StoreSubmissionRequest) => void;
|
||||
}
|
||||
@@ -82,11 +82,11 @@ export const AgentTableCard: React.FC<AgentTableCardProps> = ({
|
||||
{dateSubmitted}
|
||||
</div>
|
||||
<div className="text-sm text-neutral-600 dark:text-neutral-400">
|
||||
{runs.toLocaleString()} runs
|
||||
{runs ? runs.toLocaleString() : "N/A"} runs
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
<span className="text-sm font-medium text-neutral-800 dark:text-neutral-200">
|
||||
{rating.toFixed(1)}
|
||||
{rating ? rating.toFixed(1) : "N/A"}
|
||||
</span>
|
||||
<IconStarFilled className="h-4 w-4 text-neutral-800 dark:text-neutral-200" />
|
||||
</div>
|
||||
|
||||
@@ -3,8 +3,6 @@ import { Navbar } from "./Navbar";
|
||||
import { userEvent, within } from "@storybook/test";
|
||||
import { IconType } from "../ui/icons";
|
||||
import { ProfileDetails } from "@/lib/autogpt-server-api/types";
|
||||
// You can't import this here, jest is not available in storybook and will crash it
|
||||
// import { jest } from "@jest/globals";
|
||||
|
||||
// Mock the API responses
|
||||
const mockProfileData: ProfileDetails = {
|
||||
@@ -15,40 +13,6 @@ const mockProfileData: ProfileDetails = {
|
||||
avatar_url: "https://avatars.githubusercontent.com/u/123456789?v=4",
|
||||
};
|
||||
|
||||
const mockCreditData = {
|
||||
credits: 1500,
|
||||
};
|
||||
|
||||
// Mock the API module
|
||||
// jest.mock("@/lib/autogpt-server-api", () => {
|
||||
// return function () {
|
||||
// return {
|
||||
// getStoreProfile: () => Promise.resolve(mockProfileData),
|
||||
// getUserCredit: () => Promise.resolve(mockCreditData),
|
||||
// };
|
||||
// };
|
||||
// });
|
||||
|
||||
const meta = {
|
||||
title: "AGPT UI/Navbar",
|
||||
component: Navbar,
|
||||
parameters: {
|
||||
layout: "fullscreen",
|
||||
},
|
||||
tags: ["autodocs"],
|
||||
argTypes: {
|
||||
// isLoggedIn: { control: "boolean" },
|
||||
// avatarSrc: { control: "text" },
|
||||
links: { control: "object" },
|
||||
// activeLink: { control: "text" },
|
||||
menuItemGroups: { control: "object" },
|
||||
// params: { control: { type: "object", defaultValue: { lang: "en" } } },
|
||||
},
|
||||
} satisfies Meta<typeof Navbar>;
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof meta>;
|
||||
|
||||
const defaultMenuItemGroups = [
|
||||
{
|
||||
items: [
|
||||
@@ -89,35 +53,83 @@ const defaultLinks = [
|
||||
{ name: "Build", href: "/builder" },
|
||||
];
|
||||
|
||||
const meta = {
|
||||
title: "AGPT UI/Navbar",
|
||||
component: Navbar,
|
||||
parameters: {
|
||||
layout: "fullscreen",
|
||||
},
|
||||
tags: ["autodocs"],
|
||||
argTypes: {
|
||||
links: { control: "object" },
|
||||
menuItemGroups: { control: "object" },
|
||||
mockUser: { control: "object" },
|
||||
mockClientProps: { control: "object" },
|
||||
},
|
||||
} satisfies Meta<typeof Navbar>;
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof meta>;
|
||||
|
||||
export const Default: Story = {
|
||||
args: {
|
||||
// params: { lang: "en" },
|
||||
// isLoggedIn: true,
|
||||
links: defaultLinks,
|
||||
// activeLink: "/marketplace",
|
||||
// avatarSrc: mockProfileData.avatar_url,
|
||||
menuItemGroups: defaultMenuItemGroups,
|
||||
mockUser: {
|
||||
id: "123",
|
||||
email: "test@test.com",
|
||||
user_metadata: {
|
||||
name: "Test User",
|
||||
},
|
||||
app_metadata: {
|
||||
provider: "email",
|
||||
},
|
||||
aud: "test",
|
||||
created_at: new Date().toISOString(),
|
||||
},
|
||||
mockClientProps: {
|
||||
credits: 1500,
|
||||
profile: mockProfileData,
|
||||
},
|
||||
},
|
||||
parameters: {
|
||||
mockBackend: {
|
||||
credits: 1500,
|
||||
profile: mockProfileData,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const WithActiveLink: Story = {
|
||||
export const WithCredits: Story = {
|
||||
args: {
|
||||
...Default.args,
|
||||
// activeLink: "/library",
|
||||
},
|
||||
parameters: {
|
||||
mockBackend: {
|
||||
credits: 1500,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const LongUserName: Story = {
|
||||
export const WithLargeCredits: Story = {
|
||||
args: {
|
||||
...Default.args,
|
||||
// avatarSrc: "https://avatars.githubusercontent.com/u/987654321?v=4",
|
||||
},
|
||||
parameters: {
|
||||
mockBackend: {
|
||||
credits: 999999,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const NoAvatar: Story = {
|
||||
export const WithZeroCredits: Story = {
|
||||
args: {
|
||||
...Default.args,
|
||||
// avatarSrc: undefined,
|
||||
},
|
||||
parameters: {
|
||||
mockBackend: {
|
||||
credits: 0,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
@@ -125,6 +137,12 @@ export const WithInteraction: Story = {
|
||||
args: {
|
||||
...Default.args,
|
||||
},
|
||||
parameters: {
|
||||
mockBackend: {
|
||||
credits: 1500,
|
||||
profile: mockProfileData,
|
||||
},
|
||||
},
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
const profileTrigger = canvas.getByRole("button");
|
||||
@@ -135,29 +153,3 @@ export const WithInteraction: Story = {
|
||||
await canvas.findByText("Edit profile");
|
||||
},
|
||||
};
|
||||
|
||||
export const NotLoggedIn: Story = {
|
||||
args: {
|
||||
...Default.args,
|
||||
// isLoggedIn: false,
|
||||
// avatarSrc: undefined,
|
||||
},
|
||||
};
|
||||
|
||||
export const WithCredits: Story = {
|
||||
args: {
|
||||
...Default.args,
|
||||
},
|
||||
};
|
||||
|
||||
export const WithLargeCredits: Story = {
|
||||
args: {
|
||||
...Default.args,
|
||||
},
|
||||
};
|
||||
|
||||
export const WithZeroCredits: Story = {
|
||||
args: {
|
||||
...Default.args,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -9,6 +9,10 @@ import { ProfileDetails } from "@/lib/autogpt-server-api/types";
|
||||
import { NavbarLink } from "./NavbarLink";
|
||||
import getServerUser from "@/lib/supabase/getServerUser";
|
||||
import BackendAPI from "@/lib/autogpt-server-api";
|
||||
import { User } from "@supabase/supabase-js";
|
||||
import MockClient, {
|
||||
MockClientProps,
|
||||
} from "@/lib/autogpt-server-api/mock_client";
|
||||
|
||||
// Disable theme toggle for now
|
||||
// import { ThemeToggle } from "./ThemeToggle";
|
||||
@@ -29,21 +33,33 @@ interface NavbarProps {
|
||||
onClick?: () => void;
|
||||
}[];
|
||||
}[];
|
||||
mockUser?: User;
|
||||
mockClientProps?: MockClientProps;
|
||||
}
|
||||
|
||||
async function getProfileData() {
|
||||
async function getProfileData(mockClientProps?: MockClientProps) {
|
||||
if (mockClientProps) {
|
||||
const api = new MockClient(mockClientProps);
|
||||
const profile = await Promise.resolve(api.getStoreProfile("navbar"));
|
||||
return profile;
|
||||
}
|
||||
const api = new BackendAPI();
|
||||
const profile = await Promise.resolve(api.getStoreProfile());
|
||||
|
||||
return profile;
|
||||
}
|
||||
|
||||
export const Navbar = async ({ links, menuItemGroups }: NavbarProps) => {
|
||||
const { user } = await getServerUser();
|
||||
export const Navbar = async ({
|
||||
links,
|
||||
menuItemGroups,
|
||||
mockUser,
|
||||
mockClientProps,
|
||||
}: NavbarProps) => {
|
||||
const { user } = await getServerUser(mockUser);
|
||||
const isLoggedIn = user !== null;
|
||||
let profile: ProfileDetails | null = null;
|
||||
if (isLoggedIn) {
|
||||
profile = await getProfileData();
|
||||
profile = await getProfileData(mockClientProps);
|
||||
}
|
||||
|
||||
return (
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
"use client";
|
||||
import { cn } from "@/lib/utils";
|
||||
import Image from "next/image";
|
||||
import { useState } from "react";
|
||||
|
||||
interface SmartImageProps {
|
||||
src?: string | null;
|
||||
alt: string;
|
||||
imageContain?: boolean;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export default function SmartImage({
|
||||
src,
|
||||
alt,
|
||||
imageContain,
|
||||
className,
|
||||
}: SmartImageProps) {
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const shouldShowSkeleton = isLoading || !src;
|
||||
|
||||
return (
|
||||
<div className={cn("relative overflow-hidden", className)}>
|
||||
{src && (
|
||||
<Image
|
||||
src={src}
|
||||
alt={alt}
|
||||
fill
|
||||
onLoad={() => setIsLoading(false)}
|
||||
className={cn(
|
||||
"rounded-inherit object-center transition-opacity duration-300",
|
||||
isLoading ? "opacity-0" : "opacity-100",
|
||||
imageContain ? "object-contain" : "object-cover",
|
||||
)}
|
||||
sizes="100%"
|
||||
/>
|
||||
)}
|
||||
{shouldShowSkeleton && (
|
||||
<div className="rounded-inherit absolute inset-0 animate-pulse bg-gray-300 dark:bg-gray-700" />
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -11,7 +11,7 @@ import { PopoverClose } from "@radix-ui/react-popover";
|
||||
import { TaskGroups } from "../onboarding/WalletTaskGroups";
|
||||
import { ScrollArea } from "../ui/scroll-area";
|
||||
import { useOnboarding } from "../onboarding/onboarding-provider";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { useCallback, useEffect, useRef } from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import * as party from "party-js";
|
||||
import WalletRefill from "./WalletRefill";
|
||||
@@ -21,11 +21,6 @@ export default function Wallet() {
|
||||
fetchInitialCredits: true,
|
||||
});
|
||||
const { state, updateState } = useOnboarding();
|
||||
const [prevCredits, setPrevCredits] = useState<number | null>(credits);
|
||||
const [flash, setFlash] = useState(false);
|
||||
const [stepsLength, setStepsLength] = useState<number | null>(
|
||||
state?.completedSteps?.length || null,
|
||||
);
|
||||
const walletRef = useRef<HTMLButtonElement | null>(null);
|
||||
|
||||
const onWalletOpen = useCallback(async () => {
|
||||
@@ -42,81 +37,49 @@ export default function Wallet() {
|
||||
.through("lifetime")
|
||||
.build();
|
||||
|
||||
// Confetti effect on the wallet button
|
||||
useEffect(() => {
|
||||
if (!state?.completedSteps) {
|
||||
return;
|
||||
}
|
||||
// If we haven't set the length yet, just set it and return
|
||||
if (stepsLength === null) {
|
||||
setStepsLength(state?.completedSteps?.length);
|
||||
return;
|
||||
}
|
||||
// It's enough to compare array lengths,
|
||||
// because the order of completed steps is not important
|
||||
// If the length is the same, we don't need to do anything
|
||||
if (state?.completedSteps?.length === stepsLength) {
|
||||
return;
|
||||
}
|
||||
// Otherwise, we need to set the new length
|
||||
setStepsLength(state?.completedSteps?.length);
|
||||
// And make confetti
|
||||
if (walletRef.current) {
|
||||
setTimeout(() => {
|
||||
fetchCredits();
|
||||
party.confetti(walletRef.current!, {
|
||||
count: 30,
|
||||
spread: 120,
|
||||
shapes: ["square", "circle"],
|
||||
size: party.variation.range(1, 2),
|
||||
speed: party.variation.range(200, 300),
|
||||
modules: [fadeOut],
|
||||
});
|
||||
}, 800);
|
||||
// Check if there are any completed tasks (state?.completedTasks) that
|
||||
// are not in the state?.notified array and play confetti if so
|
||||
const pending = state?.completedSteps
|
||||
.filter((step) => !state?.notified.includes(step))
|
||||
// Ignore steps that are not relevant for notifications
|
||||
.filter(
|
||||
(step) =>
|
||||
step !== "WELCOME" &&
|
||||
step !== "USAGE_REASON" &&
|
||||
step !== "INTEGRATIONS" &&
|
||||
step !== "AGENT_CHOICE" &&
|
||||
step !== "AGENT_NEW_RUN" &&
|
||||
step !== "AGENT_INPUT",
|
||||
);
|
||||
if ((pending?.length || 0) > 0 && walletRef.current) {
|
||||
party.confetti(walletRef.current, {
|
||||
count: 30,
|
||||
spread: 120,
|
||||
shapes: ["square", "circle"],
|
||||
size: party.variation.range(1, 2),
|
||||
speed: party.variation.range(200, 300),
|
||||
modules: [fadeOut],
|
||||
});
|
||||
}
|
||||
}, [state?.completedSteps, state?.notified]);
|
||||
|
||||
// Wallet flash on credits change
|
||||
useEffect(() => {
|
||||
if (credits === prevCredits) {
|
||||
return;
|
||||
}
|
||||
setPrevCredits(credits);
|
||||
if (prevCredits === null) {
|
||||
return;
|
||||
}
|
||||
setFlash(true);
|
||||
setTimeout(() => {
|
||||
setFlash(false);
|
||||
}, 300);
|
||||
}, [credits]);
|
||||
|
||||
return (
|
||||
<Popover>
|
||||
<PopoverTrigger asChild>
|
||||
<div className="relative inline-block">
|
||||
<button
|
||||
ref={walletRef}
|
||||
className={cn(
|
||||
"relative flex items-center gap-1 rounded-md bg-zinc-200 px-3 py-2 text-sm transition-colors duration-200 hover:bg-zinc-300",
|
||||
)}
|
||||
onClick={onWalletOpen}
|
||||
>
|
||||
Wallet{" "}
|
||||
<span className="text-sm font-semibold">
|
||||
{formatCredits(credits)}
|
||||
</span>
|
||||
{state?.notificationDot && (
|
||||
<span className="absolute right-1 top-1 h-2 w-2 rounded-full bg-violet-600"></span>
|
||||
)}
|
||||
</button>
|
||||
<div
|
||||
className={cn(
|
||||
"pointer-events-none absolute inset-0 rounded-md bg-violet-400 duration-2000 ease-in-out",
|
||||
flash ? "opacity-50 duration-0" : "opacity-0",
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
<button
|
||||
ref={walletRef}
|
||||
className="relative flex items-center gap-1 rounded-md bg-zinc-200 px-3 py-2 text-sm transition-colors duration-200 hover:bg-zinc-300"
|
||||
onClick={onWalletOpen}
|
||||
>
|
||||
Wallet{" "}
|
||||
<span className="text-sm font-semibold">
|
||||
{formatCredits(credits)}
|
||||
</span>
|
||||
{state?.notificationDot && (
|
||||
<span className="absolute right-1 top-1 h-2 w-2 rounded-full bg-violet-600"></span>
|
||||
)}
|
||||
</button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent
|
||||
className={cn(
|
||||
@@ -141,9 +104,7 @@ export default function Wallet() {
|
||||
</div>
|
||||
<ScrollArea className="max-h-[85vh] overflow-y-auto">
|
||||
{/* Top ups */}
|
||||
{process.env.NEXT_PUBLIC_SHOW_BILLING_PAGE === "true" && (
|
||||
<WalletRefill />
|
||||
)}
|
||||
<WalletRefill />
|
||||
{/* Tasks */}
|
||||
<p className="mx-1 mt-4 font-sans text-xs font-medium text-violet-700">
|
||||
Onboarding tasks
|
||||
|
||||
@@ -22,7 +22,6 @@ export default function ActionButtonGroup({
|
||||
<Button
|
||||
key={i}
|
||||
variant={action.variant ?? "outline"}
|
||||
disabled={action.disabled}
|
||||
onClick={action.callback}
|
||||
>
|
||||
{action.label}
|
||||
@@ -30,11 +29,7 @@ export default function ActionButtonGroup({
|
||||
) : (
|
||||
<Link
|
||||
key={i}
|
||||
className={cn(
|
||||
buttonVariants({ variant: action.variant }),
|
||||
action.disabled &&
|
||||
"pointer-events-none border-zinc-400 text-zinc-400",
|
||||
)}
|
||||
className={buttonVariants({ variant: action.variant })}
|
||||
href={action.href}
|
||||
>
|
||||
{action.label}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user