Compare commits


61 Commits

Author SHA1 Message Date
Abhimanyu Yadav  54bbafc431  Merge branch 'dev' into ci-chromatic  2025-04-22 20:26:42 +05:30
Abhimanyu Yadav  5662783624  Merge branch 'dev' into ci-chromatic  2025-04-22 10:29:24 +05:30
Nicholas Tindle  a5f448af98  Merge branch 'dev' into ci-chromatic  2025-03-06 11:24:49 -06:00
Nicholas Tindle  c766bd66e1  fix(frontend): typechecking  2025-02-05 17:15:24 -06:00
Nicholas Tindle  6d11ad8051  fix(frontend): format  2025-02-05 17:12:33 -06:00
Nicholas Tindle  d476983bd2  fix: doesn't crash  2025-02-05 17:10:21 -06:00
Nicholas Tindle  3ac1ce5a3f  fix: format  2025-02-05 16:50:51 -06:00
Nicholas Tindle  3b89e6d2b7  Merge branch 'ci-chromatic' of https://github.com/Significant-Gravitas/AutoGPT into ci-chromatic  2025-02-05 16:49:32 -06:00
Nicholas Tindle  c7a7652b9f  Merge branch 'dev' into ci-chromatic  2025-02-05 16:47:46 -06:00
Nicholas Tindle  b6b0d0b209  Merge branch 'dev' into ci-chromatic  2025-02-03 11:55:31 -06:00
Nicholas Tindle  a5b1495062  Merge branch 'dev' into ci-chromatic  2025-02-03 07:13:53 -06:00
Nicholas Tindle  026f16c10f  Merge branch 'dev' into ci-chromatic  2025-01-31 04:50:48 -06:00
Nicholas Tindle  c468201c53  Update mock_client.ts  2025-01-29 07:10:03 -06:00
Nicholas Tindle  5beb581d1c  feat(frontend): minimocking  2025-01-29 07:04:38 -06:00
Nicholas Tindle  df2339c1cf  feat: add mock backend for rendering the storybook stuff  2025-01-29 06:46:50 -06:00
Nicholas Tindle  327db54321  Merge branch 'open-2047-add-type-checking-step-to-front-end-ci' into ci-chromatic  2025-01-29 12:26:19 +00:00
Nicholas Tindle  234d6f78ba  Merge branch 'dev' into open-2047-add-type-checking-step-to-front-end-ci  2025-01-29 12:17:23 +00:00
Nicholas Tindle  43088ddff8  fix: incorrect meshing of types and test  2025-01-29 06:16:28 -06:00
Nicholas Tindle  fd955fba25  ref: add providers to the story previews  2025-01-29 05:35:56 -06:00
Nicholas Tindle  83943d9ddb  Merge branch 'open-2047-add-type-checking-step-to-front-end-ci' into ci-chromatic  2025-01-29 10:57:00 +00:00
Nicholas Tindle  60c26e62f6  Merge branch 'dev' into open-2047-add-type-checking-step-to-front-end-ci  2025-01-29 10:53:51 +00:00
Nicholas Tindle  1fc8f9ba66  fix: handle conditions better for feature flagging  2025-01-28 18:04:41 +00:00
Nicholas Tindle  33d747f457  ref: remove unused code  2025-01-28 16:39:11 +00:00
Nicholas Tindle  06fa001a37  ref: use data structure for copy and paste data  2025-01-28 16:36:07 +00:00
Nicholas Tindle  4e7b56b814  ref: pr changes  2025-01-28 16:31:25 +00:00
Nicholas Tindle  d6b03a4f18  ref: pr change request  2025-01-28 16:30:13 +00:00
Nicholas Tindle  fae9aeb49a  fix: linting  2025-01-28 16:30:05 +00:00
Nicholas Tindle  5e8c1e274e  fix: linting  2025-01-28 16:29:58 +00:00
Nicholas Tindle  55f7dc4853  fix: drop classname unused  2025-01-28 16:25:23 +00:00
Nicholas Tindle  b317adb9cf  ref: remove classname from navbar link  2025-01-28 16:23:14 +00:00
Nicholas Tindle  c873ba04b8  ref: split out type-check step + fix tsc error  2025-01-28 15:38:05 +00:00
Nicholas Tindle  00f0311dd0  ref: split out type-check step + fix tsc error  2025-01-28 15:31:52 +00:00
Nicholas Tindle  9b2bd756fa  Update platform-frontend-ci.yml  2025-01-28 15:17:42 +00:00
Nicholas Tindle  bceb83ca30  fix: workingdir required  2025-01-28 15:17:42 +00:00
Nicholas Tindle  eadbfcd920  Update platform-frontend-ci.yml  2025-01-28 15:17:41 +00:00
SwiftyOS  9768540b60  Merge branch 'dev' into open-2047-add-type-checking-step-to-front-end-ci  2025-01-28 15:46:21 +01:00
Nicholas Tindle  697436be07  Merge branch 'dev' into open-2047-add-type-checking-step-to-front-end-ci  2025-01-28 07:53:27 +00:00
Nicholas Tindle  d725e105a0  Merge branch 'dev' into open-2047-add-type-checking-step-to-front-end-ci  2025-01-26 15:27:50 +01:00
Nicholas Tindle  927f43f52f  fix: formatting  2025-01-26 12:18:11 +00:00
Nicholas Tindle  eedcc92d6f  fix: add secret to all the subschemas  2025-01-26 12:15:37 +00:00
Nicholas Tindle  f0c378c70d  fix: missing type addition  2025-01-26 12:12:12 +00:00
Nicholas Tindle  c6c2b852df  fix: missing inputs based on changes  2025-01-26 12:11:37 +00:00
Nicholas Tindle  aaab8b1e0e  fix: more formatting  2025-01-26 11:56:07 +00:00
Nicholas Tindle  a4eeb4535a  fix: formatting  2025-01-26 11:55:56 +00:00
Nicholas Tindle  db068c598c  fix: missing types  2025-01-26 11:46:35 +00:00
Nicholas Tindle  d4d9efc73e  fix: missing attribute  2025-01-26 11:46:25 +00:00
Nicholas Tindle  ffaf77df4e  fix: type the params  2025-01-26 11:46:10 +00:00
Nicholas Tindle  2daf08434e  fix: type the params  2025-01-26 11:46:01 +00:00
Nicholas Tindle  745137f4c2  fix: pass correct subclass  2025-01-26 11:45:46 +00:00
Nicholas Tindle  3a2c3deb0e  fix: remove import + impossible case  2025-01-26 11:45:34 +00:00
Nicholas Tindle  66a15a7b8c  fix: user correct object when deleting  2025-01-26 11:44:25 +00:00
Nicholas Tindle  669c61de76  fix: take in classnames as used by the outer component  2025-01-26 11:44:04 +00:00
    we probbaly shouldn't be doing this?
Nicholas Tindle  e860bde3d4  fix: coalesce types and use a default  2025-01-26 11:43:40 +00:00
    @aarushik93 is this okay?
Nicholas Tindle  f5394f6d65  fix: expose interface for sub object so it can be used in other places to fix type errors  2025-01-26 11:43:00 +00:00
Nicholas Tindle  06e845abe7  feat: take in props for navbar  2025-01-26 11:42:28 +00:00
    Is this desired?
Nicholas Tindle  c2c3c29018  fix: use proper state object  2025-01-26 11:42:08 +00:00
Nicholas Tindle  31fd0b557a  fix: add missing import  2025-01-26 11:41:44 +00:00
Nicholas Tindle  9350fe1d2b  fix: fully disable unused page  2025-01-26 11:41:12 +00:00
Nicholas Tindle  5ae92820b4  fix: remove unused classnames  2025-01-26 11:40:57 +00:00
Nicholas Tindle  66a87e5a14  ci: typechecker for frontend  2025-01-26 11:22:38 +00:00
Nicholas Tindle  e1f8882e2d  fix: stories being broken  2025-01-26 11:18:12 +00:00
121 changed files with 1489 additions and 5337 deletions


@@ -56,6 +56,30 @@ jobs:
run: |
yarn type-check
+ design:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+ - name: Set up Node.js
+ uses: actions/setup-node@v4
+ with:
+ node-version: "21"
+ - name: Install dependencies
+ run: |
+ yarn install --frozen-lockfile
+ - name: Run Chromatic
+ uses: chromaui/action@latest
+ with:
+ # ⚠️ Make sure to configure a `CHROMATIC_PROJECT_TOKEN` repository secret
+ projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
+ workingDir: autogpt_platform/frontend
test:
runs-on: ubuntu-latest
strategy:


@@ -16,7 +16,7 @@ jobs:
# operations-per-run: 5000
stale-issue-message: >
This issue has automatically been marked as _stale_ because it has not had
- any activity in the last 170 days. You can _unstale_ it by commenting or
+ any activity in the last 50 days. You can _unstale_ it by commenting or
removing the label. Otherwise, this issue will be closed in 10 days.
stale-pr-message: >
This pull request has automatically been marked as _stale_ because it has
@@ -25,7 +25,7 @@ jobs:
close-issue-message: >
This issue was closed automatically because it has been stale for 10 days
with no activity.
- days-before-stale: 170
+ days-before-stale: 100
days-before-close: 10
# Do not touch meta issues:
exempt-issue-labels: meta,fridge,project management


@@ -31,7 +31,7 @@ class RedisKeyedMutex:
try:
yield
finally:
- if lock.locked() and lock.owned():
+ if lock.locked():
lock.release()
def acquire(self, key: Any) -> "RedisLock":
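The dropped guard is what keeps a client from releasing a lock it no longer owns, e.g. one whose TTL expired and which was then re-acquired by another client. A minimal stand-in (this is not redis-py; it only mimics the locked-versus-owned distinction) sketching why the pre-change guard matters:

```python
from contextlib import contextmanager

class FakeRedisLock:
    """Toy lock: locked() asks 'held by anyone?', owned_by() asks 'by me?'."""

    def __init__(self):
        self._holder = None

    def acquire(self, client):
        if self._holder is None:
            self._holder = client

    def locked(self):
        return self._holder is not None

    def owned_by(self, client):
        return self._holder == client

    def release(self, client):
        # redis-py raises LockError when releasing a lock you don't own
        if self._holder != client:
            raise RuntimeError("cannot release an unowned lock")
        self._holder = None

@contextmanager
def locked_section(lock, client):
    lock.acquire(client)
    try:
        yield
    finally:
        # The guard removed by this hunk: only release while we still own
        # the lock, so a lock re-acquired elsewhere is left alone.
        if lock.locked() and lock.owned_by(client):
            lock.release(client)
```

With only `lock.locked()`, the `finally` block would try to release another holder's lock and raise instead of exiting cleanly.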


@@ -66,13 +66,6 @@ MEDIA_GCS_BUCKET_NAME=
## and tunnel it to your locally running backend.
PLATFORM_BASE_URL=http://localhost:3000
- ## Cloudflare Turnstile (CAPTCHA) Configuration
- ## Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
- ## This is the backend secret key
- TURNSTILE_SECRET_KEY=
- ## This is the verify URL
- TURNSTILE_VERIFY_URL=https://challenges.cloudflare.com/turnstile/v0/siteverify
## == INTEGRATION CREDENTIALS == ##
# Each set of server side credentials is required for the corresponding 3rd party
# integration to work.


@@ -1,4 +1,3 @@
- import functools
import importlib
import os
import re
@@ -11,11 +10,17 @@ if TYPE_CHECKING:
T = TypeVar("T")
- @functools.cache
+ _AVAILABLE_BLOCKS: dict[str, type["Block"]] = {}
def load_all_blocks() -> dict[str, type["Block"]]:
from backend.data.block import Block
+ if _AVAILABLE_BLOCKS:
+ return _AVAILABLE_BLOCKS
# Dynamically load all modules under backend.blocks
AVAILABLE_MODULES = []
current_dir = Path(__file__).parent
modules = [
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
@@ -30,9 +35,9 @@ def load_all_blocks() -> dict[str, type["Block"]]:
)
importlib.import_module(f".{module}", package=__name__)
AVAILABLE_MODULES.append(module)
# Load all Block instances from the available modules
available_blocks: dict[str, type["Block"]] = {}
for block_cls in all_subclasses(Block):
class_name = block_cls.__name__
@@ -53,7 +58,7 @@ def load_all_blocks() -> dict[str, type["Block"]]:
f"Block ID {block.name} error: {block.id} is not a valid UUID"
)
- if block.id in available_blocks:
+ if block.id in _AVAILABLE_BLOCKS:
raise ValueError(
f"Block ID {block.name} error: {block.id} is already in use"
)
@@ -84,9 +89,9 @@ def load_all_blocks() -> dict[str, type["Block"]]:
f"{block.name} has a boolean field with no default value"
)
- available_blocks[block.id] = block_cls
+ _AVAILABLE_BLOCKS[block.id] = block_cls
- return available_blocks
+ return _AVAILABLE_BLOCKS
__all__ = ["load_all_blocks"]
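The diff above swaps `@functools.cache` on the loader for a module-level dict with an early return. A small sketch of that caching pattern, with `discover` standing in for the dynamic scan of `backend.blocks` modules (the name is illustrative, not from the source):

```python
# Module-level registry populated once; later calls return the cached dict.
_REGISTRY: dict[str, type] = {}

def load_all(discover) -> dict[str, type]:
    if _REGISTRY:  # already populated: skip the expensive scan
        return _REGISTRY
    # `discover` yields (id, cls) pairs, standing in for module imports.
    for obj_id, cls in discover():
        if obj_id in _REGISTRY:
            raise ValueError(f"ID {obj_id} is already in use")
        _REGISTRY[obj_id] = cls
    return _REGISTRY
```

Unlike `functools.cache`, the dict is not keyed by arguments, and duplicate-ID checks run against the shared registry itself, which is what the `available_blocks` → `_AVAILABLE_BLOCKS` substitutions in the hunk reflect.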


@@ -1,5 +1,5 @@
import logging
- from typing import Any, Optional
+ from typing import Any
from backend.data.block import (
Block,
@@ -11,7 +11,7 @@ from backend.data.block import (
get_block,
)
from backend.data.execution import ExecutionStatus
- from backend.data.model import CredentialsMetaInput, SchemaField
+ from backend.data.model import SchemaField
from backend.util import json
logger = logging.getLogger(__name__)
@@ -23,21 +23,17 @@ class AgentExecutorBlock(Block):
graph_id: str = SchemaField(description="Graph ID")
graph_version: int = SchemaField(description="Graph Version")
- inputs: BlockInput = SchemaField(description="Input data for the graph")
+ data: BlockInput = SchemaField(description="Input data for the graph")
input_schema: dict = SchemaField(description="Input schema for the graph")
output_schema: dict = SchemaField(description="Output schema for the graph")
- node_credentials_input_map: Optional[
- dict[str, dict[str, CredentialsMetaInput]]
- ] = SchemaField(default=None, hidden=True)
@classmethod
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
return data.get("input_schema", {})
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
- return data.get("inputs", {})
+ return data.get("data", {})
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
@@ -71,8 +67,7 @@ class AgentExecutorBlock(Block):
graph_id=input_data.graph_id,
graph_version=input_data.graph_version,
user_id=input_data.user_id,
- inputs=input_data.inputs,
- node_credentials_input_map=input_data.node_credentials_input_map,
+ inputs=input_data.data,
)
log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.id}"
logger.info(f"Starting execution of {log_id}")
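The rename means graph inputs now travel under a `data` key instead of `inputs`. A hypothetical sketch of the classmethod logic around that key (plain functions here, names illustrative), including how missing inputs fall out of the schema:

```python
# Defaults are read from the "data" key (previously "inputs"); missing
# inputs are whatever the schema requires that defaults don't supply.
def get_input_schema(node_input: dict) -> dict:
    return node_input.get("input_schema", {})

def get_input_defaults(node_input: dict) -> dict:
    return node_input.get("data", {})

def get_missing_input(node_input: dict) -> set[str]:
    required = get_input_schema(node_input).get("required", [])
    return set(required) - set(get_input_defaults(node_input))
```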


@@ -88,33 +88,6 @@ class StoreValueBlock(Block):
yield "output", input_data.data or input_data.input
- class PrintToConsoleBlock(Block):
- class Input(BlockSchema):
- text: Any = SchemaField(description="The data to print to the console.")
- class Output(BlockSchema):
- output: Any = SchemaField(description="The data printed to the console.")
- status: str = SchemaField(description="The status of the print operation.")
- def __init__(self):
- super().__init__(
- id="f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c",
- description="Print the given text to the console, this is used for a debugging purpose.",
- categories={BlockCategory.BASIC},
- input_schema=PrintToConsoleBlock.Input,
- output_schema=PrintToConsoleBlock.Output,
- test_input={"text": "Hello, World!"},
- test_output=[
- ("output", "Hello, World!"),
- ("status", "printed"),
- ],
- )
- def run(self, input_data: Input, **kwargs) -> BlockOutput:
- yield "output", input_data.text
- yield "status", "printed"
class FindInDictionaryBlock(Block):
class Input(BlockSchema):
input: Any = SchemaField(description="Dictionary to lookup from")
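The removed block followed the same generator protocol as the blocks around it: `run` yields `(output_name, value)` pairs that map onto the declared `Output` schema. A tiny sketch of that protocol using the deleted block's two outputs:

```python
from typing import Any, Generator

# Blocks emit (output_name, value) tuples; consumers route each value to
# the output named by the first element.
def run_print(text: Any) -> Generator[tuple[str, Any], None, None]:
    yield "output", text        # echo the input back out
    yield "status", "printed"   # then report the operation's status
```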


@@ -1,212 +0,0 @@
import json
import logging
from pathlib import Path
from pydantic import BaseModel
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockWebhookConfig,
)
from backend.data.model import SchemaField, APIKeyCredentials
from backend.integrations.providers import ProviderName
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
ExaCredentialsField,
ExaCredentialsInput,
)
logger = logging.getLogger(__name__)
class ExaTriggerBase:
"""Base class for Exa webhook triggers."""
class Input(BlockSchema):
"""Base input schema for Exa triggers."""
credentials: ExaCredentialsInput = ExaCredentialsField()
# --8<-- [start:example-payload-field]
payload: dict = SchemaField(hidden=True, default_factory=dict)
# --8<-- [end:example-payload-field]
class Output(BlockSchema):
"""Base output schema for Exa triggers."""
payload: dict = SchemaField(
description="The complete webhook payload that was received from Exa. "
"Includes information about the event type, data, and creation timestamp."
)
error: str = SchemaField(
description="Error message if the payload could not be processed"
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
"""Process the webhook payload from Exa.
Args:
input_data: The input data containing the webhook payload
Yields:
The complete webhook payload
"""
yield "payload", input_data.payload
class ExaWebsetTriggerBlock(ExaTriggerBase, Block):
"""Block for handling Exa Webset webhook events.
This block triggers on various Exa Webset events such as webset creation,
deletion, search completion, etc. and outputs the event details.
"""
EXAMPLE_PAYLOAD_FILE = (
Path(__file__).parent / "example_payloads" / "webset.created.json"
)
class Input(ExaTriggerBase.Input):
"""Input schema for Exa Webset trigger with event filtering options."""
class EventsFilter(BaseModel):
"""
Event filter options for Exa Webset webhooks.
See: https://docs.exa.ai/api-reference/webhooks
"""
# Webset events
webset_created: bool = False
webset_deleted: bool = False
webset_paused: bool = False
webset_idle: bool = False
# Search events
webset_search_created: bool = False
webset_search_updated: bool = False
webset_search_completed: bool = False
webset_search_canceled: bool = False
# Item events
webset_item_created: bool = False
webset_item_enriched: bool = False
# Export events
webset_export_created: bool = False
webset_export_completed: bool = False
events: EventsFilter = SchemaField(
title="Events", description="The events to subscribe to"
)
class Output(ExaTriggerBase.Output):
"""Output schema for Exa Webset trigger with event-specific fields."""
event_type: str = SchemaField(
description="The type of event that triggered the webhook"
)
webset_id: str = SchemaField(
description="The ID of the affected webset"
)
created_at: str = SchemaField(
description="Timestamp when the event was created"
)
data: dict = SchemaField(
description="Object containing the full resource that triggered the event"
)
def __init__(self):
"""Initialize the Exa Webset trigger block with its configuration."""
# Define a webhook type constant for Exa
class ExaWebhookType:
"""Constants for Exa webhook types."""
WEBSET = "webset"
# Create example payload
example_payload = {
"id": "663de972-bfe7-47ef-b4d7-179cfed7aa44",
"object": "event",
"type": "webset.created",
"data": {
"id": "wbs_123456789",
"name": "Example Webset",
"description": "An example webset for testing"
},
"createdAt": "2023-06-01T12:00:00Z"
}
# Map UI event names to API event names
self.event_mapping = {
"webset_created": "webset.created",
"webset_deleted": "webset.deleted",
"webset_paused": "webset.paused",
"webset_idle": "webset.idle",
"webset_search_created": "webset.search.created",
"webset_search_updated": "webset.search.updated",
"webset_search_completed": "webset.search.completed",
"webset_search_canceled": "webset.search.canceled",
"webset_item_created": "webset.item.created",
"webset_item_enriched": "webset.item.enriched",
"webset_export_created": "webset.export.created",
"webset_export_completed": "webset.export.completed"
}
super().__init__(
id="804ac1ed-d692-4ccb-a390-739a846a2667",
description="This block triggers on Exa Webset events and outputs the event type and payload.",
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.INPUT},
input_schema=ExaWebsetTriggerBlock.Input,
output_schema=ExaWebsetTriggerBlock.Output,
webhook_config=BlockWebhookConfig(
provider=ProviderName.EXA,
webhook_type=ExaWebhookType.WEBSET,
resource_format="", # Exa doesn't require a specific resource format
event_filter_input="events",
event_format="webset.{event}",
),
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"events": {"webset_created": True, "webset_search_completed": True},
"payload": example_payload,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("payload", example_payload),
("event_type", example_payload["type"]),
("webset_id", "wbs_123456789"),
("created_at", "2023-06-01T12:00:00Z"),
("data", example_payload["data"]),
],
)
def run(self, input_data: Input, **kwargs) -> BlockOutput: # type: ignore
"""Process Exa Webset webhook events.
Args:
input_data: The input data containing the webhook payload and event filter
Yields:
Event details including event type, webset ID, creation timestamp, and data,
or an error message if the event type doesn't match the filter or if required
fields are missing from the payload.
"""
yield from super().run(input_data, **kwargs)
try:
# Get the event type from the payload
event_type = input_data.payload["type"]
# Check if this event type is in the user's selected events
# Convert API event name to UI event name (reverse mapping)
ui_event_name = next((k for k, v in self.event_mapping.items() if v == event_type), None)
# Only process events that match the filter
if ui_event_name and getattr(input_data.events, ui_event_name, False):
yield "event_type", event_type
yield "webset_id", input_data.payload["data"]["id"]
yield "created_at", input_data.payload["createdAt"]
yield "data", input_data.payload["data"]
else:
yield "error", f"Event type {event_type} not in selected events filter"
except KeyError as e:
yield "error", f"Missing expected field in payload: {str(e)}"
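The deleted trigger's core logic is a reverse lookup from API event names (`webset.created`) back to UI filter fields (`webset_created`), then a filter check before emitting outputs. A condensed, testable sketch of that dispatch (the mapping below is a subset of the full one above):

```python
# Subset of the block's UI-name -> API-name event mapping.
EVENT_MAPPING = {
    "webset_created": "webset.created",
    "webset_item_enriched": "webset.item.enriched",
}

def dispatch(payload: dict, enabled: set[str]) -> list[tuple[str, object]]:
    event_type = payload["type"]
    # Reverse lookup: API event name -> UI field name, None if unknown.
    ui_name = next((k for k, v in EVENT_MAPPING.items() if v == event_type), None)
    if ui_name and ui_name in enabled:
        return [
            ("event_type", event_type),
            ("webset_id", payload["data"]["id"]),
            ("data", payload["data"]),
        ]
    return [("error", f"Event type {event_type} not in selected events filter")]
```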


@@ -1,321 +0,0 @@
from typing import Any, Dict, List, Optional
import requests
from backend.blocks.exa._auth import (
ExaCredentials,
ExaCredentialsField,
ExaCredentialsInput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
# --- Create Webset Block ---
class ExaCreateWebsetBlock(Block):
"""Block for creating a Webset using Exa's Websets API."""
class Input(BlockSchema):
credentials: ExaCredentialsInput = ExaCredentialsField()
query: str = SchemaField(description="The search query for the webset")
count: int = SchemaField(description="Number of results to return", default=5)
enrichments: Optional[List[Dict[str, Any]]] = SchemaField(
description="List of enrichment dicts (optional)", default_factory=list, advanced=True
)
external_id: Optional[str] = SchemaField(
description="Optional external identifier", default=None, advanced=True
)
metadata: Optional[Dict[str, Any]] = SchemaField(
description="Optional metadata", default_factory=dict, advanced=True
)
class Output(BlockSchema):
webset: Optional[Dict[str, Any]] = SchemaField(
description="The created webset object (or None if error)", default=None
)
error: str = SchemaField(
description="Error message if the request failed", default=""
)
def __init__(self):
"""Initialize the ExaCreateWebsetBlock with its configuration."""
super().__init__(
id="322351cc-35d7-45ec-8920-9a3c98920411",
description="Creates a Webset using Exa's Websets API",
categories={BlockCategory.SEARCH},
input_schema=ExaCreateWebsetBlock.Input,
output_schema=ExaCreateWebsetBlock.Output,
)
def run(self, input_data: Input, *, credentials: ExaCredentials, **kwargs) -> BlockOutput:
"""
Execute the block to create a webset with Exa's API.
Args:
input_data: The input parameters for creating a webset
credentials: The Exa API credentials
Yields:
Either the created webset object or an error message
"""
url = "https://api.exa.ai/websets/v0/websets"
headers = {
"Content-Type": "application/json",
"x-api-key": credentials.api_key.get_secret_value(),
}
payload = {
"search": {
"query": input_data.query,
"count": input_data.count,
}
}
optional_fields = {}
if isinstance(input_data.enrichments, list) and input_data.enrichments:
optional_fields["enrichments"] = input_data.enrichments
if isinstance(input_data.external_id, str) and input_data.external_id:
optional_fields["externalId"] = input_data.external_id
if isinstance(input_data.metadata, dict) and input_data.metadata:
optional_fields["metadata"] = input_data.metadata
payload.update(optional_fields)
try:
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
yield "webset", data
except Exception as e:
yield "error", str(e)
# --- Get Webset Block ---
class ExaGetWebsetBlock(Block):
"""Block for retrieving a Webset by ID using Exa's Websets API."""
class Input(BlockSchema):
credentials: ExaCredentialsInput = ExaCredentialsField()
webset_id: str = SchemaField(description="The Webset ID or externalId")
expand_items: bool = SchemaField(description="Expand with items", default=False, advanced=True)
class Output(BlockSchema):
webset: Optional[Dict[str, Any]] = SchemaField(description="The webset object (or None if error)", default=None)
error: str = SchemaField(description="Error message if the request failed", default="")
def __init__(self):
"""Initialize the ExaGetWebsetBlock with its configuration."""
super().__init__(
id="f9229293-cddf-43fc-94b3-48cbd1a44618",
description="Retrieves a Webset by ID using Exa's Websets API",
categories={BlockCategory.SEARCH},
input_schema=ExaGetWebsetBlock.Input,
output_schema=ExaGetWebsetBlock.Output,
)
def run(self, input_data: Input, *, credentials: ExaCredentials, **kwargs) -> BlockOutput:
"""
Execute the block to retrieve a webset by ID from Exa's API.
Args:
input_data: The input parameters including the webset ID
credentials: The Exa API credentials
Yields:
Either the retrieved webset object or an error message
"""
url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}"
headers = {
"Content-Type": "application/json",
"x-api-key": credentials.api_key.get_secret_value(),
}
params = {"expand": "items"} if input_data.expand_items else None
try:
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
data = response.json()
yield "webset", data
except Exception as e:
yield "error", str(e)
# --- Delete Webset Block ---
class ExaDeleteWebsetBlock(Block):
"""Block for deleting a Webset by ID using Exa's Websets API."""
class Input(BlockSchema):
credentials: ExaCredentialsInput = ExaCredentialsField()
webset_id: str = SchemaField(description="The Webset ID or externalId")
class Output(BlockSchema):
deleted: Optional[Dict[str, Any]] = SchemaField(description="The deleted webset object (or None if error)", default=None)
error: str = SchemaField(description="Error message if the request failed", default="")
def __init__(self):
"""Initialize the ExaDeleteWebsetBlock with its configuration."""
super().__init__(
id="a082e162-274e-4167-a467-a1839e644cbd",
description="Deletes a Webset by ID using Exa's Websets API",
categories={BlockCategory.SEARCH},
input_schema=ExaDeleteWebsetBlock.Input,
output_schema=ExaDeleteWebsetBlock.Output,
)
def run(self, input_data: Input, *, credentials: ExaCredentials, **kwargs) -> BlockOutput:
"""
Execute the block to delete a webset by ID using Exa's API.
Args:
input_data: The input parameters including the webset ID
credentials: The Exa API credentials
Yields:
Either the deleted webset object or an error message
"""
url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}"
headers = {
"Content-Type": "application/json",
"x-api-key": credentials.api_key.get_secret_value(),
}
try:
response = requests.delete(url, headers=headers)
response.raise_for_status()
data = response.json()
yield "deleted", data
except Exception as e:
yield "error", str(e)
# --- Update Webset Block ---
class ExaUpdateWebsetBlock(Block):
"""Block for updating a Webset's metadata using Exa's Websets API."""
class Input(BlockSchema):
credentials: ExaCredentialsInput = ExaCredentialsField()
webset_id: str = SchemaField(description="The Webset ID or externalId")
metadata: Dict[str, Any] = SchemaField(description="Metadata to update", default_factory=dict)
class Output(BlockSchema):
webset: Optional[Dict[str, Any]] = SchemaField(description="The updated webset object (or None if error)", default=None)
error: str = SchemaField(description="Error message if the request failed", default="")
def __init__(self):
"""Initialize the ExaUpdateWebsetBlock with its configuration."""
super().__init__(
id="e0c81b70-ac38-4239-8ecd-a75c1737c9ef",
description="Updates a Webset's metadata using Exa's Websets API",
categories={BlockCategory.SEARCH},
input_schema=ExaUpdateWebsetBlock.Input,
output_schema=ExaUpdateWebsetBlock.Output,
)
def run(self, input_data: Input, *, credentials: ExaCredentials, **kwargs) -> BlockOutput:
"""
Execute the block to update a webset's metadata using Exa's API.
Args:
input_data: The input parameters including the webset ID and metadata
credentials: The Exa API credentials
Yields:
Either the updated webset object or an error message
"""
url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}"
headers = {
"Content-Type": "application/json",
"x-api-key": credentials.api_key.get_secret_value(),
}
payload = {"metadata": input_data.metadata}
try:
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
yield "webset", data
except Exception as e:
yield "error", str(e)
# --- List Websets Block ---
class ExaListWebsetsBlock(Block):
"""Block for listing all Websets using Exa's Websets API with pagination support."""
class Input(BlockSchema):
credentials: ExaCredentialsInput = ExaCredentialsField()
limit: int = SchemaField(description="Number of websets to return (max 100)", default=25)
cursor: Optional[str] = SchemaField(description="Pagination cursor (optional)", default=None, advanced=True)
class Output(BlockSchema):
data: Optional[List[Dict[str, Any]]] = SchemaField(description="List of websets", default=None)
has_more: Optional[bool] = SchemaField(description="Whether there are more results", default=None)
next_cursor: Optional[str] = SchemaField(description="Cursor for next page", default=None)
error: str = SchemaField(description="Error message if the request failed", default="")
def __init__(self):
"""Initialize the ExaListWebsetsBlock with its configuration."""
super().__init__(
id="887a2dae-c9c3-4ae5-a079-fe3b52be64e4",
description="Lists all Websets using Exa's Websets API",
categories={BlockCategory.SEARCH},
input_schema=ExaListWebsetsBlock.Input,
output_schema=ExaListWebsetsBlock.Output,
)
def run(self, input_data: Input, *, credentials: ExaCredentials, **kwargs) -> BlockOutput:
"""
Execute the block to list websets with pagination using Exa's API.
Args:
input_data: The input parameters including limit and optional cursor
credentials: The Exa API credentials
Yields:
The list of websets, pagination info, or an error message
"""
url = "https://api.exa.ai/websets/v0/websets"
headers = {
"Content-Type": "application/json",
"x-api-key": credentials.api_key.get_secret_value(),
}
params: dict[str, Any] = {"limit": int(input_data.limit)}
if isinstance(input_data.cursor, str) and input_data.cursor:
params["cursor"] = input_data.cursor
try:
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
data = response.json()
yield "data", data.get("data")
yield "has_more", data.get("hasMore")
yield "next_cursor", data.get("nextCursor")
except Exception as e:
yield "error", str(e)
# --- Cancel Webset Block ---
class ExaCancelWebsetBlock(Block):
"""Block for canceling a running Webset using Exa's Websets API."""
class Input(BlockSchema):
credentials: ExaCredentialsInput = ExaCredentialsField()
webset_id: str = SchemaField(description="The Webset ID or externalId")
class Output(BlockSchema):
webset: Optional[Dict[str, Any]] = SchemaField(description="The canceled webset object (or None if error)", default=None)
error: str = SchemaField(description="Error message if the request failed", default="")
def __init__(self):
"""Initialize the ExaCancelWebsetBlock with its configuration."""
super().__init__(
id="f7f0b19c-71e8-4c2f-bc68-904a6a61faf7",
description="Cancels a running Webset using Exa's Websets API",
categories={BlockCategory.SEARCH},
input_schema=ExaCancelWebsetBlock.Input,
output_schema=ExaCancelWebsetBlock.Output,
)
def run(self, input_data: Input, *, credentials: ExaCredentials, **kwargs) -> BlockOutput:
"""
Execute the block to cancel a running webset using Exa's API.
Args:
input_data: The input parameters including the webset ID
credentials: The Exa API credentials
Yields:
Either the canceled webset object or an error message
"""
url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}/cancel"
headers = {
"Content-Type": "application/json",
"x-api-key": credentials.api_key.get_secret_value(),
}
try:
response = requests.post(url, headers=headers)
response.raise_for_status()
data = response.json()
yield "webset", data
except Exception as e:
yield "error", str(e)
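The deleted list block surfaced the API's raw `hasMore`/`nextCursor` pagination fields. A hedged sketch of driving that contract in a loop, with `fetch` standing in for the authenticated GET to `https://api.exa.ai/websets/v0/websets`:

```python
# Page through websets via the hasMore/nextCursor contract the list block
# exposed. `fetch(limit=..., cursor=...)` is a stand-in for the real
# requests.get call with the x-api-key header.
def iter_websets(fetch, limit: int = 25):
    cursor = None
    while True:
        page = fetch(limit=limit, cursor=cursor)
        yield from page.get("data") or []
        if not page.get("hasMore"):
            break
        cursor = page.get("nextCursor")
```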


@@ -1,598 +0,0 @@
import enum
import uuid
from datetime import datetime, timedelta, timezone
from typing import Literal
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from pydantic import BaseModel
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.settings import AppEnvironment, Settings
from ._auth import (
GOOGLE_OAUTH_IS_CONFIGURED,
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
GoogleCredentials,
GoogleCredentialsField,
GoogleCredentialsInput,
)
class CalendarEvent(BaseModel):
"""Structured representation of a Google Calendar event."""
id: str
title: str
start_time: str
end_time: str
is_all_day: bool
location: str | None
description: str | None
organizer: str | None
attendees: list[str]
has_video_call: bool
video_link: str | None
calendar_link: str
is_recurring: bool
class GoogleCalendarReadEventsBlock(Block):
class Input(BlockSchema):
credentials: GoogleCredentialsInput = GoogleCredentialsField(
["https://www.googleapis.com/auth/calendar.readonly"]
)
calendar_id: str = SchemaField(
description="Calendar ID (use 'primary' for your main calendar)",
default="primary",
)
max_events: int = SchemaField(
description="Maximum number of events to retrieve", default=10
)
start_time: datetime = SchemaField(
description="Retrieve events starting from this time",
default_factory=lambda: datetime.now(tz=timezone.utc),
)
time_range_days: int = SchemaField(
description="Number of days to look ahead for events", default=30
)
search_term: str | None = SchemaField(
description="Optional search term to filter events by", default=None
)
page_token: str | None = SchemaField(
description="Page token from previous request to get the next batch of events. You can use this if you have lots of events you want to process in a loop",
default=None,
)
include_declined_events: bool = SchemaField(
description="Include events you've declined", default=False
)
class Output(BlockSchema):
events: list[CalendarEvent] = SchemaField(
description="List of calendar events in the requested time range",
default_factory=list,
)
event: CalendarEvent = SchemaField(
description="One of the calendar events in the requested time range"
)
next_page_token: str | None = SchemaField(
description="Token for retrieving the next page of events if more exist",
default=None,
)
error: str = SchemaField(
description="Error message if the request failed",
)
def __init__(self):
settings = Settings()
# Create realistic test data for events
test_now = datetime.now(tz=timezone.utc)
test_tomorrow = test_now + timedelta(days=1)
test_event_dict = {
"id": "event1id",
"title": "Team Meeting",
"start_time": test_tomorrow.strftime("%Y-%m-%d %H:%M"),
"end_time": (test_tomorrow + timedelta(hours=1)).strftime("%Y-%m-%d %H:%M"),
"is_all_day": False,
"location": "Conference Room A",
"description": "Weekly team sync",
"organizer": "manager@example.com",
"attendees": ["colleague1@example.com", "colleague2@example.com"],
"has_video_call": True,
"video_link": "https://meet.google.com/abc-defg-hij",
"calendar_link": "https://calendar.google.com/calendar/event?eid=event1id",
"is_recurring": True,
}
super().__init__(
id="80bc3ed1-e9a4-449e-8163-a8fc86f74f6a",
description="Retrieves upcoming events from a Google Calendar with filtering options",
categories={BlockCategory.PRODUCTIVITY, BlockCategory.DATA},
input_schema=GoogleCalendarReadEventsBlock.Input,
output_schema=GoogleCalendarReadEventsBlock.Output,
disabled=not GOOGLE_OAUTH_IS_CONFIGURED
or settings.config.app_env == AppEnvironment.PRODUCTION,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"calendar_id": "primary",
"max_events": 5,
"start_time": test_now.isoformat(),
"time_range_days": 7,
"search_term": None,
"include_declined_events": False,
"page_token": None,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("event", test_event_dict),
("events", [test_event_dict]),
],
test_mock={
"_read_calendar": lambda *args, **kwargs: {
"items": [
{
"id": "event1id",
"summary": "Team Meeting",
"start": {
"dateTime": test_tomorrow.isoformat(),
"timeZone": "UTC",
},
"end": {
"dateTime": (
test_tomorrow + timedelta(hours=1)
).isoformat(),
"timeZone": "UTC",
},
"location": "Conference Room A",
"description": "Weekly team sync",
"organizer": {"email": "manager@example.com"},
"attendees": [
{"email": "colleague1@example.com"},
{"email": "colleague2@example.com"},
],
"conferenceData": {
"conferenceUrl": "https://meet.google.com/abc-defg-hij"
},
"htmlLink": "https://calendar.google.com/calendar/event?eid=event1id",
"recurrence": ["RRULE:FREQ=WEEKLY;COUNT=10"],
}
],
"nextPageToken": None,
},
"_format_events": lambda *args, **kwargs: [test_event_dict],
},
)
def run(
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
) -> BlockOutput:
try:
service = self._build_service(credentials, **kwargs)
# Calculate end time based on start time and time range
end_time = input_data.start_time + timedelta(
days=input_data.time_range_days
)
# Call Google Calendar API
result = self._read_calendar(
service=service,
calendarId=input_data.calendar_id,
time_min=input_data.start_time.isoformat(),
time_max=end_time.isoformat(),
max_results=input_data.max_events,
single_events=True,
search_term=input_data.search_term,
show_deleted=False,
show_hidden=input_data.include_declined_events,
page_token=input_data.page_token,
)
# Format events into a user-friendly structure
formatted_events = self._format_events(result.get("items", []))
# Include next page token if available
if next_page_token := result.get("nextPageToken"):
yield "next_page_token", next_page_token
for event in formatted_events:
yield "event", event
yield "events", formatted_events
except Exception as e:
yield "error", str(e)
@staticmethod
def _build_service(credentials: GoogleCredentials, **kwargs):
creds = Credentials(
token=(
credentials.access_token.get_secret_value()
if credentials.access_token
else None
),
refresh_token=(
credentials.refresh_token.get_secret_value()
if credentials.refresh_token
else None
),
token_uri="https://oauth2.googleapis.com/token",
client_id=Settings().secrets.google_client_id,
client_secret=Settings().secrets.google_client_secret,
scopes=credentials.scopes,
)
return build("calendar", "v3", credentials=creds)
def _read_calendar(
self,
service,
calendarId: str,
time_min: str,
time_max: str,
max_results: int,
single_events: bool,
search_term: str | None = None,
show_deleted: bool = False,
show_hidden: bool = False,
page_token: str | None = None,
) -> dict:
"""Read calendar events with optional filtering."""
calendar = service.events()
# Build query parameters
params = {
"calendarId": calendarId,
"timeMin": time_min,
"timeMax": time_max,
"maxResults": max_results,
"singleEvents": single_events,
"orderBy": "startTime",
"showDeleted": show_deleted,
"showHiddenInvitations": show_hidden,
**({"pageToken": page_token} if page_token else {}),
}
# Add search term if provided
if search_term:
params["q"] = search_term
result = calendar.list(**params).execute()
return result
def _format_events(self, events: list[dict]) -> list[CalendarEvent]:
"""Format Google Calendar API events into user-friendly structure."""
formatted_events = []
for event in events:
# Determine if all-day event
is_all_day = "date" in event.get("start", {})
# Format start and end times
if is_all_day:
start_time = event.get("start", {}).get("date", "")
end_time = event.get("end", {}).get("date", "")
else:
# Convert ISO format to more readable format
start_datetime = datetime.fromisoformat(
event.get("start", {}).get("dateTime", "").replace("Z", "+00:00")
)
end_datetime = datetime.fromisoformat(
event.get("end", {}).get("dateTime", "").replace("Z", "+00:00")
)
start_time = start_datetime.strftime("%Y-%m-%d %H:%M")
end_time = end_datetime.strftime("%Y-%m-%d %H:%M")
# Extract attendees
attendees = []
for attendee in event.get("attendees", []):
if email := attendee.get("email"):
attendees.append(email)
# Check for video call link
has_video_call = False
video_link = None
if conf_data := event.get("conferenceData"):
if conf_url := conf_data.get("conferenceUrl"):
has_video_call = True
video_link = conf_url
elif entry_points := conf_data.get("entryPoints", []):
for entry in entry_points:
if entry.get("entryPointType") == "video":
has_video_call = True
video_link = entry.get("uri")
break
# Create formatted event
formatted_event = CalendarEvent(
id=event.get("id", ""),
title=event.get("summary", "Untitled Event"),
start_time=start_time,
end_time=end_time,
is_all_day=is_all_day,
location=event.get("location"),
description=event.get("description"),
organizer=event.get("organizer", {}).get("email"),
attendees=attendees,
has_video_call=has_video_call,
video_link=video_link,
calendar_link=event.get("htmlLink", ""),
is_recurring=bool(event.get("recurrence")),
)
formatted_events.append(formatted_event)
return formatted_events
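`_format_events` distinguishes all-day events by whether the API's `start` object carries a `date` key (all-day) or a `dateTime` key (timed), and normalizes the `Z` suffix for `fromisoformat`. A standalone sketch of just that branch, with the same formats as above:

```python
from datetime import datetime

def format_times(event: dict) -> tuple[str, str, bool]:
    # All-day events use {"date": "YYYY-MM-DD"}; timed events use
    # {"dateTime": "...Z"} in the Calendar API payload.
    is_all_day = "date" in event.get("start", {})
    if is_all_day:
        return event["start"]["date"], event["end"]["date"], True
    start = datetime.fromisoformat(event["start"]["dateTime"].replace("Z", "+00:00"))
    end = datetime.fromisoformat(event["end"]["dateTime"].replace("Z", "+00:00"))
    return start.strftime("%Y-%m-%d %H:%M"), end.strftime("%Y-%m-%d %H:%M"), False
```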
class ReminderPreset(enum.Enum):
"""Common reminder times before an event."""
TEN_MINUTES = 10
THIRTY_MINUTES = 30
ONE_HOUR = 60
ONE_DAY = 1440 # 24 hours in minutes
class RecurrenceFrequency(enum.Enum):
"""Frequency options for recurring events."""
DAILY = "DAILY"
WEEKLY = "WEEKLY"
MONTHLY = "MONTHLY"
YEARLY = "YEARLY"
class ExactTiming(BaseModel):
"""Model for specifying start and end times."""
discriminator: Literal["exact_timing"]
start_datetime: datetime
end_datetime: datetime
class DurationTiming(BaseModel):
"""Model for specifying start time and duration."""
discriminator: Literal["duration_timing"]
start_datetime: datetime
duration_minutes: int
class OneTimeEvent(BaseModel):
"""Model for a one-time event."""
discriminator: Literal["one_time"]
class RecurringEvent(BaseModel):
"""Model for a recurring event."""
discriminator: Literal["recurring"]
frequency: RecurrenceFrequency
count: int
class GoogleCalendarCreateEventBlock(Block):
class Input(BlockSchema):
credentials: GoogleCredentialsInput = GoogleCredentialsField(
["https://www.googleapis.com/auth/calendar"]
)
# Event Details
event_title: str = SchemaField(description="Title of the event")
location: str | None = SchemaField(
description="Location of the event", default=None
)
description: str | None = SchemaField(
description="Description of the event", default=None
)
# Timing
timing: ExactTiming | DurationTiming = SchemaField(
discriminator="discriminator",
advanced=False,
description="Specify when the event starts and ends",
default_factory=lambda: DurationTiming(
discriminator="duration_timing",
start_datetime=datetime.now().replace(microsecond=0, second=0, minute=0)
+ timedelta(hours=1),
duration_minutes=60,
),
)
# Calendar selection
calendar_id: str = SchemaField(
description="Calendar ID (use 'primary' for your main calendar)",
default="primary",
)
# Guests
guest_emails: list[str] = SchemaField(
description="Email addresses of guests to invite", default_factory=list
)
send_notifications: bool = SchemaField(
description="Send email notifications to guests", default=True
)
# Extras
add_google_meet: bool = SchemaField(
description="Include a Google Meet video conference link", default=False
)
recurrence: OneTimeEvent | RecurringEvent = SchemaField(
discriminator="discriminator",
description="Whether the event repeats",
default_factory=lambda: OneTimeEvent(discriminator="one_time"),
)
reminder_minutes: list[ReminderPreset] = SchemaField(
description="When to send reminders before the event",
default_factory=lambda: [ReminderPreset.TEN_MINUTES],
)
class Output(BlockSchema):
event_id: str = SchemaField(description="ID of the created event")
event_link: str = SchemaField(
description="Link to view the event in Google Calendar"
)
error: str = SchemaField(description="Error message if event creation failed")
def __init__(self):
settings = Settings()
super().__init__(
id="ed2ec950-fbff-4204-94c0-023fb1d625e0",
description="This block creates a new event in Google Calendar with customizable parameters.",
categories={BlockCategory.PRODUCTIVITY},
input_schema=GoogleCalendarCreateEventBlock.Input,
output_schema=GoogleCalendarCreateEventBlock.Output,
disabled=not GOOGLE_OAUTH_IS_CONFIGURED
or settings.config.app_env == AppEnvironment.PRODUCTION,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"event_title": "Team Meeting",
"location": "Conference Room A",
"description": "Weekly team sync-up",
"calendar_id": "primary",
"guest_emails": ["colleague1@example.com", "colleague2@example.com"],
"add_google_meet": True,
"send_notifications": True,
"reminder_minutes": [
ReminderPreset.TEN_MINUTES.value,
ReminderPreset.ONE_HOUR.value,
],
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("event_id", "abc123event_id"),
("event_link", "https://calendar.google.com/calendar/event?eid=abc123"),
],
test_mock={
"_create_event": lambda *args, **kwargs: {
"id": "abc123event_id",
"htmlLink": "https://calendar.google.com/calendar/event?eid=abc123",
}
},
)
def run(
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
) -> BlockOutput:
try:
service = self._build_service(credentials, **kwargs)
# Get start and end times based on the timing option
if input_data.timing.discriminator == "exact_timing":
start_datetime = input_data.timing.start_datetime
end_datetime = input_data.timing.end_datetime
else: # duration_timing
start_datetime = input_data.timing.start_datetime
end_datetime = start_datetime + timedelta(
minutes=input_data.timing.duration_minutes
)
# Format datetimes for Google Calendar API
start_time_str = start_datetime.isoformat()
end_time_str = end_datetime.isoformat()
# Build the event body
event_body = {
"summary": input_data.event_title,
"start": {"dateTime": start_time_str},
"end": {"dateTime": end_time_str},
}
# Add optional fields
if input_data.location:
event_body["location"] = input_data.location
if input_data.description:
event_body["description"] = input_data.description
# Add guests
if input_data.guest_emails:
event_body["attendees"] = [
{"email": email} for email in input_data.guest_emails
]
# Add reminders
if input_data.reminder_minutes:
event_body["reminders"] = {
"useDefault": False,
"overrides": [
{"method": "popup", "minutes": reminder.value}
for reminder in input_data.reminder_minutes
],
}
# Add Google Meet
if input_data.add_google_meet:
event_body["conferenceData"] = {
"createRequest": {
"requestId": f"meet-{uuid.uuid4()}",
"conferenceSolutionKey": {"type": "hangoutsMeet"},
}
}
# Add recurrence
if input_data.recurrence.discriminator == "recurring":
rule = f"RRULE:FREQ={input_data.recurrence.frequency.value}"
rule += f";COUNT={input_data.recurrence.count}"
event_body["recurrence"] = [rule]
# Create the event
result = self._create_event(
service=service,
calendar_id=input_data.calendar_id,
event_body=event_body,
send_notifications=input_data.send_notifications,
conference_data_version=1 if input_data.add_google_meet else 0,
)
yield "event_id", result.get("id", "")
yield "event_link", result.get("htmlLink", "")
except Exception as e:
yield "error", str(e)
@staticmethod
def _build_service(credentials: GoogleCredentials, **kwargs):
creds = Credentials(
token=(
credentials.access_token.get_secret_value()
if credentials.access_token
else None
),
refresh_token=(
credentials.refresh_token.get_secret_value()
if credentials.refresh_token
else None
),
token_uri="https://oauth2.googleapis.com/token",
client_id=Settings().secrets.google_client_id,
client_secret=Settings().secrets.google_client_secret,
scopes=credentials.scopes,
)
return build("calendar", "v3", credentials=creds)
def _create_event(
self,
service,
calendar_id: str,
event_body: dict,
send_notifications: bool = False,
conference_data_version: int = 0,
) -> dict:
"""Create a new event in Google Calendar."""
calendar = service.events()
# Make the API call
result = calendar.insert(
calendarId=calendar_id,
body=event_body,
sendNotifications=send_notifications,
conferenceDataVersion=conference_data_version,
).execute()
return result
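The recurrence handling in `GoogleCalendarCreateEventBlock.run` concatenates an RFC 5545 `RRULE` string from the frequency and count. A minimal sketch of the same construction, reusing the enum defined above:

```python
import enum

class RecurrenceFrequency(enum.Enum):
    # Mirrors the enum in the block file above.
    DAILY = "DAILY"
    WEEKLY = "WEEKLY"
    MONTHLY = "MONTHLY"
    YEARLY = "YEARLY"

def build_rrule(frequency: RecurrenceFrequency, count: int) -> str:
    # Same construction as the block: FREQ first, then COUNT.
    rule = f"RRULE:FREQ={frequency.value}"
    rule += f";COUNT={count}"
    return rule
```

The resulting string goes into the event body as `event_body["recurrence"] = [rule]`, matching the `"RRULE:FREQ=WEEKLY;COUNT=10"` value used in the test mock.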


@@ -3,7 +3,7 @@ from googleapiclient.discovery import build
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.settings import AppEnvironment, Settings
from backend.util.settings import Settings
from ._auth import (
GOOGLE_OAUTH_IS_CONFIGURED,
@@ -36,15 +36,13 @@ class GoogleSheetsReadBlock(Block):
)
def __init__(self):
settings = Settings()
super().__init__(
id="5724e902-3635-47e9-a108-aaa0263a4988",
description="This block reads data from a Google Sheets spreadsheet.",
categories={BlockCategory.DATA},
input_schema=GoogleSheetsReadBlock.Input,
output_schema=GoogleSheetsReadBlock.Output,
disabled=not GOOGLE_OAUTH_IS_CONFIGURED
or settings.config.app_env == AppEnvironment.PRODUCTION,
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
test_input={
"spreadsheet_id": "1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms",
"range": "Sheet1!A1:B2",


@@ -82,15 +82,7 @@ class SendWebRequestBlock(Block):
json=body if input_data.json_format else None,
data=body if not input_data.json_format else None,
)
if input_data.json_format:
if response.status_code == 204 or not response.content.strip():
result = None
else:
result = response.json()
else:
result = response.text
result = response.json() if input_data.json_format else response.text
yield "response", result
except HTTPError as e:
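This hunk trades the 204/empty-body guard for a bare `response.json()`. The guard exists because decoding an empty payload raises; a minimal reproduction of the guarded variant from the left side of the hunk, without any HTTP (body and status supplied directly):

```python
import json

def parse_body(content: bytes, status_code: int, json_format: bool):
    # Guarded parsing: 204 No Content and empty bodies have no JSON
    # to decode, so return None instead of raising.
    if not json_format:
        return content.decode()
    if status_code == 204 or not content.strip():
        return None
    return json.loads(content)
```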


@@ -288,13 +288,6 @@ def convert_openai_tool_fmt_to_anthropic(
return anthropic_tools
def estimate_token_count(prompt_messages: list[dict]) -> int:
char_count = sum(len(str(msg.get("content", ""))) for msg in prompt_messages)
message_overhead = len(prompt_messages) * 4
estimated_tokens = (char_count // 4) + message_overhead
return int(estimated_tokens * 1.2)
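The `estimate_token_count` helper removed in this hunk uses a rough chars/4 heuristic plus a fixed per-message overhead and a 20% safety margin. A standalone copy, runnable as-is:

```python
def estimate_token_count(prompt_messages: list[dict]) -> int:
    # ~4 characters per token, plus ~4 tokens of framing per message,
    # inflated by 20% as a safety margin.
    char_count = sum(len(str(msg.get("content", ""))) for msg in prompt_messages)
    message_overhead = len(prompt_messages) * 4
    estimated_tokens = (char_count // 4) + message_overhead
    return int(estimated_tokens * 1.2)

# 400 chars -> 100 tokens, +4 overhead, *1.2 -> 124
print(estimate_token_count([{"role": "user", "content": "x" * 400}]))  # 124
```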
def llm_call(
credentials: APIKeyCredentials,
llm_model: LlmModel,
@@ -326,14 +319,7 @@ def llm_call(
- completion_tokens: The number of tokens used in the completion.
"""
provider = llm_model.metadata.provider
# Calculate available tokens based on context window and input length
estimated_input_tokens = estimate_token_count(prompt)
context_window = llm_model.context_window
model_max_output = llm_model.max_output_tokens or 4096
user_max = max_tokens or model_max_output
available_tokens = max(context_window - estimated_input_tokens, 0)
max_tokens = max(min(available_tokens, model_max_output, user_max), 0)
max_tokens = max_tokens or llm_model.max_output_tokens or 4096
if provider == "openai":
tools_param = tools if tools else openai.NOT_GIVEN
@@ -489,7 +475,6 @@ def llm_call(
model=llm_model.value,
prompt=f"{sys_messages}\n\n{usr_messages}",
stream=False,
options={"num_ctx": max_tokens},
)
return LLMResponse(
raw_response=response.get("response") or "",
@@ -788,16 +773,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
prompt.append({"role": "user", "content": retry_prompt})
except Exception as e:
logger.exception(f"Error calling LLM: {e}")
if (
"maximum context length" in str(e).lower()
or "token limit" in str(e).lower()
):
if input_data.max_tokens is None:
input_data.max_tokens = llm_model.max_output_tokens or 4096
input_data.max_tokens = int(input_data.max_tokens * 0.85)
logger.debug(
f"Reducing max_tokens to {input_data.max_tokens} for next attempt"
)
retry_prompt = f"Error calling LLM: {e}"
finally:
self.merge_stats(


@@ -26,10 +26,10 @@ logger = logging.getLogger(__name__)
@thread_cached
def get_database_manager_client():
from backend.executor import DatabaseManagerClient
from backend.executor import DatabaseManager
from backend.util.service import get_service_client
return get_service_client(DatabaseManagerClient)
return get_service_client(DatabaseManager)
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
@@ -246,10 +246,6 @@ class SmartDecisionMakerBlock(Block):
test_credentials=llm.TEST_CREDENTIALS,
)
@staticmethod
def cleanup(s: str):
return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()
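The `cleanup` helper introduced in this hunk sanitizes tool and argument names down to `[a-zA-Z0-9_-]`, the character set function-calling APIs typically accept. Extracted as a standalone function:

```python
import re

def cleanup(s: str) -> str:
    # Replace every character outside [a-zA-Z0-9_-] with "_" and lowercase,
    # matching the inline re.sub calls this hunk replaces.
    return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()
```

Note the substitution is not reversible: distinct names like `My Block` and `My.Block` collapse to the same sanitized key.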
@staticmethod
def _create_block_function_signature(
sink_node: "Node", links: list["Link"]
@@ -270,7 +266,7 @@ class SmartDecisionMakerBlock(Block):
block = sink_node.block
tool_function: dict[str, Any] = {
"name": SmartDecisionMakerBlock.cleanup(block.name),
"name": re.sub(r"[^a-zA-Z0-9_-]", "_", block.name).lower(),
"description": block.description,
}
@@ -285,7 +281,7 @@ class SmartDecisionMakerBlock(Block):
and sink_block_input_schema.model_fields[link.sink_name].description
else f"The {link.sink_name} of the tool"
)
properties[SmartDecisionMakerBlock.cleanup(link.sink_name)] = {
properties[link.sink_name.lower()] = {
"type": "string",
"description": description,
}
@@ -330,7 +326,7 @@ class SmartDecisionMakerBlock(Block):
)
tool_function: dict[str, Any] = {
"name": SmartDecisionMakerBlock.cleanup(sink_graph_meta.name),
"name": re.sub(r"[^a-zA-Z0-9_-]", "_", sink_graph_meta.name).lower(),
"description": sink_graph_meta.description,
}
@@ -345,7 +341,7 @@ class SmartDecisionMakerBlock(Block):
in sink_block_input_schema["properties"][link.sink_name]
else f"The {link.sink_name} of the tool"
)
properties[SmartDecisionMakerBlock.cleanup(link.sink_name)] = {
properties[link.sink_name.lower()] = {
"type": "string",
"description": description,
}
@@ -507,7 +503,7 @@ class SmartDecisionMakerBlock(Block):
tool_args = json.loads(tool_call.function.arguments)
for arg_name, arg_value in tool_args.items():
yield f"tools_^_{tool_name}_~_{arg_name}", arg_value
yield f"tools_^_{tool_name}_{arg_name}".lower(), arg_value
response.prompt.append(response.raw_response)
yield "conversations", response.prompt


@@ -1,10 +1,11 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timezone
from typing import Any, cast
import stripe
from autogpt_libs.utils.cache import thread_cached
from prisma import Json
from prisma.enums import (
CreditRefundRequestStatus,
@@ -19,7 +20,7 @@ from prisma.types import (
CreditTransactionCreateInput,
CreditTransactionWhereInput,
)
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_exponential
from backend.data import db
from backend.data.block_cost_config import BLOCK_COSTS
@@ -27,17 +28,15 @@ from backend.data.cost import BlockCost
from backend.data.model import (
AutoTopUpConfig,
RefundRequest,
TopUpType,
TransactionHistory,
UserTransaction,
)
from backend.data.notifications import NotificationEventModel, RefundRequestData
from backend.data.user import get_user_by_id, get_user_email_by_id
from backend.notifications.notifications import queue_notification_async
from backend.server.model import Pagination
from backend.server.v2.admin.model import UserHistoryResponse
from backend.data.notifications import NotificationEventDTO, RefundRequestData
from backend.data.user import get_user_by_id
from backend.executor.utils import UsageTransactionMetadata
from backend.notifications import NotificationManager
from backend.util.exceptions import InsufficientBalanceError
from backend.util.retry import func_retry
from backend.util.service import get_service_client
from backend.util.settings import Settings
settings = Settings()
@@ -46,17 +45,6 @@ logger = logging.getLogger(__name__)
base_url = settings.config.frontend_base_url or settings.config.platform_base_url
class UsageTransactionMetadata(BaseModel):
graph_exec_id: str | None = None
graph_id: str | None = None
node_id: str | None = None
node_exec_id: str | None = None
block_id: str | None = None
block: str | None = None
input: dict[str, Any] | None = None
reason: str | None = None
class UserCreditBase(ABC):
@abstractmethod
async def get_credits(self, user_id: str) -> int:
@@ -274,7 +262,11 @@ class UserCreditBase(ABC):
)
return transaction_balance, transaction_time
@func_retry
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=1, max=10),
reraise=True,
)
async def _enable_transaction(
self,
transaction_key: str,
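This hunk swaps a project-local `@func_retry` for tenacity's decorator: up to 5 attempts with exponential backoff clamped between 1 and 10 seconds, re-raising the last error. A simplified stand-in for that policy (hand-rolled loop rather than the tenacity decorator, with tiny waits so it runs fast):

```python
import time

def retry_call(fn, attempts=5, multiplier=0.01, min_wait=0.01, max_wait=0.05):
    # Approximates stop_after_attempt(5) + wait_exponential(min, max) with
    # reraise=True: wait roughly multiplier * 2**attempt, clamped to
    # [min_wait, max_wait], and surface the final exception unchanged.
    for attempt in range(attempts):
        try:
            return fn()
        except Exception:
            if attempt == attempts - 1:
                raise
            time.sleep(min(max(multiplier * 2**attempt, min_wait), max_wait))

calls = {"n": 0}

def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("transient")
    return "ok"
```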
@@ -372,17 +364,22 @@ class UserCreditBase(ABC):
class UserCredit(UserCreditBase):
@thread_cached
def notification_client(self) -> NotificationManager:
return get_service_client(NotificationManager)
async def _send_refund_notification(
self,
notification_request: RefundRequestData,
notification_type: NotificationType,
):
await queue_notification_async(
NotificationEventModel(
user_id=notification_request.user_id,
type=notification_type,
data=notification_request,
await asyncio.to_thread(
lambda: self.notification_client().queue_notification(
NotificationEventDTO(
user_id=notification_request.user_id,
type=notification_type,
data=notification_request.model_dump(),
)
)
)
@@ -412,7 +409,6 @@ class UserCredit(UserCreditBase):
# Avoid multiple auto top-ups within the same graph execution.
key=f"AUTO-TOP-UP-{user_id}-{metadata.graph_exec_id}",
ceiling_balance=auto_top_up.threshold,
top_up_type=TopUpType.AUTO,
)
except Exception as e:
# Failed top-up is not critical, we can move on.
@@ -422,30 +418,26 @@ class UserCredit(UserCreditBase):
return balance
async def top_up_credits(
self,
user_id: str,
amount: int,
top_up_type: TopUpType = TopUpType.UNCATEGORIZED,
):
await self._top_up_credits(
user_id=user_id, amount=amount, top_up_type=top_up_type
)
async def top_up_credits(self, user_id: str, amount: int):
await self._top_up_credits(user_id, amount)
async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
try:
key = f"REWARD-{user_id}-{step.value}"
if not await CreditTransaction.prisma().find_first(
where={
"userId": user_id,
"transactionKey": key,
}
):
await self._add_transaction(
user_id=user_id,
amount=credits,
transaction_type=CreditTransactionType.GRANT,
transaction_key=f"REWARD-{user_id}-{step.value}",
transaction_key=key,
metadata=Json(
{"reason": f"Reward for completing {step.value} onboarding step."}
),
)
except UniqueViolationError:
# Already rewarded for this step
pass
async def top_up_refund(
self, user_id: str, transaction_key: str, metadata: dict[str, str]
@@ -610,7 +602,7 @@ class UserCredit(UserCreditBase):
evidence_text += (
f"- {tx.description}: Amount ${tx.amount / 100:.2f} on {tx.transaction_time.isoformat()}, "
f"resulting balance ${tx.running_balance / 100:.2f} {additional_comment}\n"
f"resulting balance ${tx.balance / 100:.2f} {additional_comment}\n"
)
evidence_text += (
"\nThis evidence demonstrates that the transaction was authorized and that the charged amount was used to render the service as agreed."
@@ -629,24 +621,7 @@ class UserCredit(UserCreditBase):
amount: int,
key: str | None = None,
ceiling_balance: int | None = None,
top_up_type: TopUpType = TopUpType.UNCATEGORIZED,
metadata: dict | None = None,
):
        # init metadata without mutating a shared default
        metadata = metadata or {}
        if not metadata.get("reason"):
            match top_up_type:
                case TopUpType.MANUAL:
                    metadata["reason"] = f"Top up credits for {user_id}"
                case TopUpType.AUTO:
                    metadata["reason"] = f"Auto top up credits for {user_id}"
                case _:
                    metadata["reason"] = f"Top up reason unknown for {user_id}"
if amount < 0:
raise ValueError(f"Top up amount must not be negative: {amount}")
@@ -669,7 +644,6 @@ class UserCredit(UserCreditBase):
is_active=False,
transaction_key=key,
ceiling_balance=ceiling_balance,
metadata=(Json(metadata)),
)
customer_id = await get_stripe_customer_id(user_id)
@@ -812,15 +786,10 @@ class UserCredit(UserCreditBase):
# Check the Checkout Session's payment_status property
# to determine if fulfillment should be performed
if checkout_session.payment_status in ["paid", "no_payment_required"]:
if payment_intent := checkout_session.payment_intent:
assert isinstance(payment_intent, stripe.PaymentIntent)
new_transaction_key = payment_intent.id
else:
new_transaction_key = None
assert isinstance(checkout_session.payment_intent, stripe.PaymentIntent)
await self._enable_transaction(
transaction_key=credit_transaction.transactionKey,
new_transaction_key=new_transaction_key,
new_transaction_key=checkout_session.payment_intent.id,
user_id=credit_transaction.userId,
metadata=Json(checkout_session),
)
@@ -853,9 +822,8 @@ class UserCredit(UserCreditBase):
take=transaction_count_limit,
)
# doesn't fill current_balance, reason, user_email, admin_email, or extra_data
grouped_transactions: dict[str, UserTransaction] = defaultdict(
lambda: UserTransaction(user_id=user_id)
lambda: UserTransaction()
)
tx_time = None
for t in transactions:
@@ -885,7 +853,7 @@ class UserCredit(UserCreditBase):
if tx_time > gt.transaction_time:
gt.transaction_time = tx_time
gt.running_balance = t.runningBalance or 0
gt.balance = t.runningBalance or 0
return TransactionHistory(
transactions=list(grouped_transactions.values()),
@@ -935,7 +903,6 @@ class BetaUserCredit(UserCredit):
amount=max(self.num_user_credits_refill - balance, 0),
transaction_type=CreditTransactionType.GRANT,
transaction_key=f"MONTHLY-CREDIT-TOP-UP-{cur_time}",
metadata=Json({"reason": "Monthly credit refill"}),
)
return balance
except UniqueViolationError:
@@ -945,7 +912,7 @@ class BetaUserCredit(UserCredit):
class DisabledUserCredit(UserCreditBase):
async def get_credits(self, *args, **kwargs) -> int:
return 100
return 0
async def get_transaction_history(self, *args, **kwargs) -> TransactionHistory:
return TransactionHistory(transactions=[], next_transaction_time=None)
@@ -1023,81 +990,3 @@ async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
return AutoTopUpConfig(threshold=0, amount=0)
return AutoTopUpConfig.model_validate(user.topUpConfig)
async def admin_get_user_history(
page: int = 1,
page_size: int = 20,
search: str | None = None,
transaction_filter: CreditTransactionType | None = None,
) -> UserHistoryResponse:
if page < 1 or page_size < 1:
raise ValueError("Invalid pagination input")
where_clause: CreditTransactionWhereInput = {}
if transaction_filter:
where_clause["type"] = transaction_filter
if search:
where_clause["OR"] = [
{"userId": {"contains": search, "mode": "insensitive"}},
{"User": {"is": {"email": {"contains": search, "mode": "insensitive"}}}},
{"User": {"is": {"name": {"contains": search, "mode": "insensitive"}}}},
]
transactions = await CreditTransaction.prisma().find_many(
where=where_clause,
skip=(page - 1) * page_size,
take=page_size,
include={"User": True},
order={"createdAt": "desc"},
)
total = await CreditTransaction.prisma().count(where=where_clause)
total_pages = (total + page_size - 1) // page_size
history = []
for tx in transactions:
admin_id = ""
admin_email = ""
reason = ""
metadata: dict = cast(dict, tx.metadata) or {}
if metadata:
admin_id = metadata.get("admin_id")
admin_email = (
(await get_user_email_by_id(admin_id) or f"Unknown Admin: {admin_id}")
if admin_id
else ""
)
reason = metadata.get("reason", "No reason provided")
balance, last_update = await get_user_credit_model()._get_credits(tx.userId)
history.append(
UserTransaction(
transaction_key=tx.transactionKey,
transaction_time=tx.createdAt,
transaction_type=tx.type,
amount=tx.amount,
current_balance=balance,
running_balance=tx.runningBalance or 0,
user_id=tx.userId,
user_email=(
tx.User.email
if tx.User
else (await get_user_by_id(tx.userId)).email
),
reason=reason,
admin_email=admin_email,
extra_data=str(metadata),
)
)
return UserHistoryResponse(
history=history,
pagination=Pagination(
total_items=total,
total_pages=total_pages,
current_page=page,
page_size=page_size,
),
)
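`admin_get_user_history` computes `total_pages` with the integer ceiling-division idiom `(total + page_size - 1) // page_size`. Isolated for clarity:

```python
def total_pages(total: int, page_size: int) -> int:
    # Equivalent to math.ceil(total / page_size) without touching floats,
    # so it stays exact for arbitrarily large counts.
    return (total + page_size - 1) // page_size
```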


@@ -30,7 +30,7 @@ from prisma.types import (
AgentNodeExecutionUpdateInput,
AgentNodeExecutionWhereInput,
)
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel
from pydantic.fields import Field
from backend.server.v2.store.exceptions import DatabaseError
@@ -69,55 +69,10 @@ class GraphExecutionMeta(BaseDbModel):
ended_at: datetime
class Stats(BaseModel):
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
)
cost: int = Field(
default=0,
description="Execution cost (cents)",
)
duration: float = Field(
default=0,
description="Seconds from start to end of run",
)
duration_cpu_only: float = Field(
default=0,
description="CPU sec of duration",
)
node_exec_time: float = Field(
default=0,
description="Seconds of total node runtime",
)
node_exec_time_cpu_only: float = Field(
default=0,
description="CPU sec of node_exec_time",
)
node_exec_count: int = Field(
default=0,
description="Number of node executions",
)
node_error_count: int = Field(
default=0,
description="Number of node errors",
)
error: str | None = Field(
default=None,
description="Error message if any",
)
def to_db(self) -> GraphExecutionStats:
return GraphExecutionStats(
cost=self.cost,
walltime=self.duration,
cputime=self.duration_cpu_only,
nodes_walltime=self.node_exec_time,
nodes_cputime=self.node_exec_time_cpu_only,
node_count=self.node_exec_count,
node_error_count=self.node_error_count,
error=self.error,
)
cost: int = Field(..., description="Execution cost (cents)")
duration: float = Field(..., description="Seconds from start to end of run")
node_exec_time: float = Field(..., description="Seconds of total node runtime")
node_exec_count: int = Field(..., description="Number of node executions")
stats: Stats | None
@@ -151,16 +106,8 @@ class GraphExecutionMeta(BaseDbModel):
GraphExecutionMeta.Stats(
cost=stats.cost,
duration=stats.walltime,
duration_cpu_only=stats.cputime,
node_exec_time=stats.nodes_walltime,
node_exec_time_cpu_only=stats.nodes_cputime,
node_exec_count=stats.node_count,
node_error_count=stats.node_error_count,
error=(
str(stats.error)
if isinstance(stats.error, Exception)
else stats.error
),
)
if stats
else None
@@ -261,6 +208,18 @@ class GraphExecutionWithNodes(GraphExecution):
graph_id=self.graph_id,
graph_version=self.graph_version or 0,
graph_exec_id=self.id,
start_node_execs=[
NodeExecutionEntry(
user_id=self.user_id,
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
node_exec_id=node_exec.node_exec_id,
node_id=node_exec.node_id,
block_id=node_exec.block_id,
data=node_exec.input_data,
)
for node_exec in self.node_executions
],
node_credentials_input_map={}, # FIXME
)
@@ -321,28 +280,13 @@ class NodeExecutionResult(BaseModel):
end_time=_node_exec.endedTime,
)
def to_node_execution_entry(self) -> "NodeExecutionEntry":
return NodeExecutionEntry(
user_id=self.user_id,
graph_exec_id=self.graph_exec_id,
graph_id=self.graph_id,
node_exec_id=self.node_exec_id,
node_id=self.node_id,
block_id=self.block_id,
inputs=self.input_data,
)
# --------------------- Model functions --------------------- #
async def get_graph_executions(
graph_id: str | None = None,
user_id: str | None = None,
statuses: list[ExecutionStatus] | None = None,
created_time_gte: datetime | None = None,
created_time_lte: datetime | None = None,
limit: int | None = None,
graph_id: Optional[str] = None,
user_id: Optional[str] = None,
) -> list[GraphExecutionMeta]:
where_filter: AgentGraphExecutionWhereInput = {
"isDeleted": False,
@@ -351,18 +295,10 @@ async def get_graph_executions(
where_filter["userId"] = user_id
if graph_id:
where_filter["agentGraphId"] = graph_id
if created_time_gte or created_time_lte:
where_filter["createdAt"] = {
"gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc),
"lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc),
}
if statuses:
where_filter["OR"] = [{"executionStatus": status} for status in statuses]
executions = await AgentGraphExecution.prisma().find_many(
where=where_filter,
order={"createdAt": "desc"},
take=limit,
)
return [GraphExecutionMeta.from_db(execution) for execution in executions]
@@ -556,12 +492,21 @@ async def upsert_execution_output(
async def update_graph_execution_start_time(
graph_exec_id: str,
) -> GraphExecution | None:
res = await AgentGraphExecution.prisma().update(
where={"id": graph_exec_id},
count = await AgentGraphExecution.prisma().update_many(
where={
"id": graph_exec_id,
"executionStatus": ExecutionStatus.QUEUED,
},
data={
"executionStatus": ExecutionStatus.RUNNING,
"startedAt": datetime.now(tz=timezone.utc),
},
)
if count == 0:
return None
res = await AgentGraphExecution.prisma().find_unique(
where={"id": graph_exec_id},
include=GRAPH_EXECUTION_INCLUDE,
)
return GraphExecution.from_db(res) if res else None
@@ -680,9 +625,8 @@ async def delete_graph_execution(
)
async def get_node_executions(
async def get_node_execution_results(
graph_exec_id: str,
node_id: str | None = None,
block_ids: list[str] | None = None,
statuses: list[ExecutionStatus] | None = None,
limit: int | None = None,
@@ -690,8 +634,6 @@ async def get_node_executions(
where_clause: AgentNodeExecutionWhereInput = {
"agentGraphExecutionId": graph_exec_id,
}
if node_id:
where_clause["agentNodeId"] = node_id
if block_ids:
where_clause["Node"] = {"is": {"agentBlockId": {"in": block_ids}}}
if statuses:
@@ -706,6 +648,28 @@ async def get_node_executions(
return res
async def get_graph_executions_in_timerange(
user_id: str, start_time: str, end_time: str
) -> list[GraphExecution]:
try:
executions = await AgentGraphExecution.prisma().find_many(
where={
"startedAt": {
"gte": datetime.fromisoformat(start_time),
"lte": datetime.fromisoformat(end_time),
},
"userId": user_id,
"isDeleted": False,
},
include=GRAPH_EXECUTION_INCLUDE,
)
return [GraphExecution.from_db(execution) for execution in executions]
except Exception as e:
raise DatabaseError(
f"Failed to get executions in timerange {start_time} to {end_time} for user {user_id}: {e}"
) from e
async def get_latest_node_execution(
node_id: str, graph_eid: str
) -> NodeExecutionResult | None:
@@ -726,6 +690,20 @@ async def get_latest_node_execution(
return NodeExecutionResult.from_db(execution)
async def get_incomplete_node_executions(
node_id: str, graph_eid: str
) -> list[NodeExecutionResult]:
executions = await AgentNodeExecution.prisma().find_many(
where={
"agentNodeId": node_id,
"agentGraphExecutionId": graph_eid,
"executionStatus": ExecutionStatus.INCOMPLETE,
},
include=EXECUTION_RESULT_INCLUDE,
)
return [NodeExecutionResult.from_db(execution) for execution in executions]
# ----------------- Execution Infrastructure ----------------- #
@@ -734,6 +712,7 @@ class GraphExecutionEntry(BaseModel):
graph_exec_id: str
graph_id: str
graph_version: int
start_node_execs: list["NodeExecutionEntry"]
node_credentials_input_map: Optional[dict[str, dict[str, CredentialsMetaInput]]]
@@ -744,7 +723,7 @@ class NodeExecutionEntry(BaseModel):
node_exec_id: str
node_id: str
block_id: str
inputs: BlockInput
data: BlockInput
class ExecutionQueue(Generic[T]):


@@ -172,8 +172,6 @@ class BaseGraph(BaseDbModel):
description: str
nodes: list[Node] = []
links: list[Link] = []
forked_from_id: str | None = None
forked_from_version: int | None = None
@computed_field
@property
@@ -199,6 +197,11 @@ class BaseGraph(BaseDbModel):
)
)
@computed_field
@property
def credentials_input_schema(self) -> dict[str, Any]:
return self._credentials_input_schema.jsonschema()
@staticmethod
def _generate_schema(
*props: tuple[type[AgentInputBlock.Input] | type[AgentOutputBlock.Input], dict],
@@ -231,15 +234,6 @@ class BaseGraph(BaseDbModel):
"required": [p.name for p in schema_fields if p.value is None],
}
class Graph(BaseGraph):
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs
@computed_field
@property
def credentials_input_schema(self) -> dict[str, Any]:
return self._credentials_input_schema.jsonschema()
@property
def _credentials_input_schema(self) -> type[BlockSchema]:
graph_credentials_inputs = self.aggregate_credentials_inputs()
@@ -318,14 +312,17 @@ class Graph(BaseGraph):
),
(node.id, field_name),
)
for graph in [self] + self.sub_graphs
for node in graph.nodes
for node in self.nodes
for field_name, field_info in node.block.input_schema.get_credentials_fields_info().items()
)
)
}
class Graph(BaseGraph):
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs, only used in export
class GraphModel(Graph):
user_id: str
nodes: list[NodeModel] = [] # type: ignore
@@ -401,7 +398,7 @@ class GraphModel(Graph):
if node.block_id != AgentExecutorBlock().id:
continue
node.input_default["user_id"] = user_id
node.input_default.setdefault("inputs", {})
node.input_default.setdefault("data", {})
if (graph_id := node.input_default.get("graph_id")) in graph_id_map:
node.input_default["graph_id"] = graph_id_map[graph_id]
@@ -412,13 +409,10 @@ class GraphModel(Graph):
@staticmethod
def _validate_graph(graph: BaseGraph, for_run: bool = False):
def is_tool_pin(name: str) -> bool:
return name.startswith("tools_^_")
def sanitize(name):
sanitized_name = name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
if is_tool_pin(sanitized_name):
return "tools"
if sanitized_name.startswith("tools_^_"):
return sanitized_name.split("_^_")[0]
return sanitized_name
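The `sanitize` helper in this hunk trims the delimiter suffixes used on pin names and collapses tool pins onto a shared field. A standalone sketch matching the `is_tool_pin` side of the diff (behaviour inferred from the lines above; a sketch, not the canonical implementation):

```python
def is_tool_pin(name: str) -> bool:
    # Tool pins are namespaced with the "_^_" delimiter.
    return name.startswith("tools_^_")

def sanitize(name: str) -> str:
    # Strip the "_#_", "_@_" and "_$_" suffix delimiters, keeping the base field name.
    sanitized = name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
    # All tool pins map onto the shared "tools" field.
    if is_tool_pin(sanitized):
        return "tools"
    return sanitized
```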
# Validate smart decision maker nodes
@@ -428,6 +422,10 @@ class GraphModel(Graph):
if (block := get_block(node.block_id)) is not None
}
for node in graph.nodes:
if (block := nodes_block.get(node.id)) is None:
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
input_links = defaultdict(list)
for link in graph.links:
@@ -442,8 +440,8 @@ class GraphModel(Graph):
[sanitize(name) for name in node.input_default]
+ [sanitize(link.sink_name) for link in input_links.get(node.id, [])]
)
InputSchema = block.input_schema
for name in (required_fields := InputSchema.get_required_fields()):
input_schema = block.input_schema
for name in (required_fields := input_schema.get_required_fields()):
if (
name not in provided_inputs
# Webhook payload is passed in by ExecutionManager
@@ -453,7 +451,7 @@ class GraphModel(Graph):
in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
)
# Checking availability of credentials is done by ExecutionManager
and name not in InputSchema.get_credentials_fields()
and name not in input_schema.get_credentials_fields()
# Validate only I/O nodes, or validate everything when executing
and (
for_run
@@ -480,43 +478,37 @@ class GraphModel(Graph):
)
# Get input schema properties and check dependencies
input_fields = InputSchema.model_fields
input_fields = input_schema.model_fields
def has_value(node: Node, name: str):
def has_value(name):
return (
name in node.input_default
node is not None
and name in node.input_default
and node.input_default[name] is not None
and str(node.input_default[name]).strip() != ""
) or (name in input_fields and input_fields[name].default is not None)
# Validate dependencies between fields
for field_name in input_fields.keys():
field_json_schema = InputSchema.get_field_schema(field_name)
dependencies: list[str] = []
# Check regular field dependencies (only pre graph execution)
if for_run:
dependencies.extend(field_json_schema.get("depends_on", []))
# Require presence of credentials discriminator (always).
# The `discriminator` is either the name of a sibling field (str),
# or an object that discriminates between possible types for this field:
# {"propertyName": prop_name, "mapping": {prop_value: sub_schema}}
if (
discriminator := field_json_schema.get("discriminator")
) and isinstance(discriminator, str):
dependencies.append(discriminator)
if not dependencies:
for field_name, field_info in input_fields.items():
# Apply input dependency validation only on run & field with depends_on
json_schema_extra = field_info.json_schema_extra or {}
if not (
for_run
and isinstance(json_schema_extra, dict)
and (
dependencies := cast(
list[str], json_schema_extra.get("depends_on", [])
)
)
):
continue
# Check if dependent field has value in input_default
field_has_value = has_value(node, field_name)
field_has_value = has_value(field_name)
field_is_required = field_name in required_fields
# Check for missing dependencies when dependent field is present
missing_deps = [dep for dep in dependencies if not has_value(node, dep)]
missing_deps = [dep for dep in dependencies if not has_value(dep)]
if missing_deps and (field_has_value or field_is_required):
raise ValueError(
f"Node {block.name} #{node.id}: Field `{field_name}` requires [{', '.join(missing_deps)}] to be set"
@@ -561,7 +553,7 @@ class GraphModel(Graph):
if block.block_type not in [BlockType.AGENT]
else vals.get("input_schema", {}).get("properties", {}).keys()
)
if sanitized_name not in fields and not is_tool_pin(name):
if sanitized_name not in fields and not name.startswith("tools_^_"):
fields_msg = f"Allowed fields: {fields}"
raise ValueError(f"{prefix}, `{name}` invalid, {fields_msg}")
@@ -578,8 +570,6 @@ class GraphModel(Graph):
id=graph.id,
user_id=graph.userId if not for_export else "",
version=graph.version,
forked_from_id=graph.forkedFromId,
forked_from_version=graph.forkedFromVersion,
is_active=graph.isActive,
name=graph.name or "",
description=graph.description or "",
@@ -692,7 +682,6 @@ async def get_graph(
version: int | None = None,
user_id: str | None = None,
for_export: bool = False,
include_subgraphs: bool = False,
) -> GraphModel | None:
"""
Retrieves a graph from the DB.
@@ -729,58 +718,6 @@ async def get_graph(
):
return None
if include_subgraphs or for_export:
sub_graphs = await get_sub_graphs(graph)
return GraphModel.from_db(
graph=graph,
sub_graphs=sub_graphs,
for_export=for_export,
)
return GraphModel.from_db(graph, for_export)
async def get_graph_as_admin(
graph_id: str,
version: int | None = None,
user_id: str | None = None,
for_export: bool = False,
) -> GraphModel | None:
"""
Intentionally parallels `get_graph`, but should only be used for admin tasks, because it can return any graph that's been submitted
Retrieves a graph from the DB.
Defaults to the version with `is_active` if `version` is not passed.
Returns `None` if the record is not found.
"""
logger.warning(f"Getting {graph_id=} {version=} as ADMIN {user_id=} {for_export=}")
where_clause: AgentGraphWhereInput = {
"id": graph_id,
}
if version is not None:
where_clause["version"] = version
graph = await AgentGraph.prisma().find_first(
where=where_clause,
include=AGENT_GRAPH_INCLUDE,
order={"version": "desc"},
)
# For access, the graph must be owned by the user or listed in the store
if graph is None or (
graph.userId != user_id
and not (
await StoreListingVersion.prisma().find_first(
where={
"agentGraphId": graph_id,
"agentGraphVersion": version or graph.version,
}
)
)
):
return None
if for_export:
sub_graphs = await get_sub_graphs(graph)
return GraphModel.from_db(
@@ -910,27 +847,6 @@ async def create_graph(graph: Graph, user_id: str) -> GraphModel:
raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
async def fork_graph(graph_id: str, graph_version: int, user_id: str) -> GraphModel:
"""
Forks a graph by copying it and all its nodes and links to a new graph.
"""
async with transaction() as tx:
graph = await get_graph(graph_id, graph_version, user_id, True)
if not graph:
raise ValueError(f"Graph {graph_id} v{graph_version} not found")
# Set forked-from ID and version to itself, as it's about to be copied
graph.forked_from_id = graph.id
graph.forked_from_version = graph.version
graph.name = f"{graph.name} (copy)"
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
graph.validate_graph(for_run=False)
await __create_graph(tx, graph, user_id)
return graph
async def __create_graph(tx, graph: Graph, user_id: str):
graphs = [graph] + graph.sub_graphs
@@ -943,8 +859,6 @@ async def __create_graph(tx, graph: Graph, user_id: str):
description=graph.description,
isActive=graph.is_active,
userId=user_id,
forkedFromId=graph.forked_from_id,
forkedFromVersion=graph.forked_from_version,
)
for graph in graphs
]


@@ -1,7 +1,6 @@
from __future__ import annotations
import base64
import enum
import logging
from collections import defaultdict
from datetime import datetime, timezone
@@ -450,12 +449,6 @@ class ContributorDetails(BaseModel):
name: str = Field(title="Name", description="The name of the contributor.")
class TopUpType(enum.Enum):
AUTO = "AUTO"
MANUAL = "MANUAL"
UNCATEGORIZED = "UNCATEGORIZED"
class AutoTopUpConfig(BaseModel):
amount: int
"""Amount of credits to top up."""
@@ -468,18 +461,12 @@ class UserTransaction(BaseModel):
transaction_time: datetime = datetime.min.replace(tzinfo=timezone.utc)
transaction_type: CreditTransactionType = CreditTransactionType.USAGE
amount: int = 0
running_balance: int = 0
current_balance: int = 0
balance: int = 0
description: str | None = None
usage_graph_id: str | None = None
usage_execution_id: str | None = None
usage_node_count: int = 0
usage_start_time: datetime = datetime.max.replace(tzinfo=timezone.utc)
user_id: str
user_email: str | None = None
reason: str | None = None
admin_email: str | None = None
extra_data: str | None = None
class TransactionHistory(BaseModel):


@@ -189,14 +189,26 @@ NotificationData = Annotated[
]
class BaseEventModel(BaseModel):
type: NotificationType
class NotificationEventDTO(BaseModel):
user_id: str
type: NotificationType
data: dict
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
retry_count: int = 0
class SummaryParamsEventDTO(BaseModel):
user_id: str
type: NotificationType
data: dict
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
class NotificationEventModel(BaseEventModel, Generic[NotificationDataType_co]):
class NotificationEventModel(BaseModel, Generic[NotificationDataType_co]):
user_id: str
type: NotificationType
data: NotificationDataType_co
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
@property
def strategy(self) -> QueueType:
@@ -213,8 +225,11 @@ class NotificationEventModel(BaseEventModel, Generic[NotificationDataType_co]):
return NotificationTypeOverride(self.type).template
class SummaryParamsEventModel(BaseEventModel, Generic[SummaryParamsType_co]):
class SummaryParamsEventModel(BaseModel, Generic[SummaryParamsType_co]):
user_id: str
type: NotificationType
data: SummaryParamsType_co
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
def get_notif_data_type(
@@ -369,7 +384,7 @@ class UserNotificationBatchDTO(BaseModel):
def get_batch_delay(notification_type: NotificationType) -> timedelta:
return {
NotificationType.AGENT_RUN: timedelta(days=1),
NotificationType.AGENT_RUN: timedelta(minutes=60),
NotificationType.ZERO_BALANCE: timedelta(minutes=60),
NotificationType.LOW_BALANCE: timedelta(minutes=60),
NotificationType.BLOCK_EXECUTION_FAILED: timedelta(minutes=60),


@@ -1,10 +1,9 @@
from .database import DatabaseManager, DatabaseManagerClient
from .database import DatabaseManager
from .manager import ExecutionManager
from .scheduler import Scheduler
__all__ = [
"DatabaseManager",
"DatabaseManagerClient",
"ExecutionManager",
"Scheduler",
]


@@ -1,15 +1,13 @@
import logging
from typing import Callable, Concatenate, ParamSpec, TypeVar, cast
from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
create_graph_execution,
get_graph_execution,
get_graph_execution_meta,
get_graph_executions,
get_incomplete_node_executions,
get_latest_node_execution,
get_node_executions,
get_node_execution_results,
update_graph_execution_start_time,
update_graph_execution_stats,
update_node_execution_stats,
@@ -41,14 +39,12 @@ from backend.data.user import (
update_user_integrations,
update_user_metadata,
)
from backend.util.service import AppService, AppServiceClient, endpoint_to_sync, expose
from backend.util.service import AppService, exposed_run_and_wait
from backend.util.settings import Config
config = Config()
_user_credit_model = get_user_credit_model()
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
async def _spend_credits(
@@ -57,10 +53,6 @@ async def _spend_credits(
return await _user_credit_model.spend_credits(user_id, cost, metadata)
async def _get_credits(user_id: str) -> int:
return await _user_credit_model.get_credits(user_id)
class DatabaseManager(AppService):
def run_service(self) -> None:
@@ -77,115 +69,58 @@ class DatabaseManager(AppService):
def get_port(cls) -> int:
return config.database_api_port
@staticmethod
def _(
f: Callable[P, R], name: str | None = None
) -> Callable[Concatenate[object, P], R]:
if name is not None:
f.__name__ = name
return cast(Callable[Concatenate[object, P], R], expose(f))
# Executions
get_graph_execution = _(get_graph_execution)
get_graph_executions = _(get_graph_executions)
get_graph_execution_meta = _(get_graph_execution_meta)
create_graph_execution = _(create_graph_execution)
get_node_executions = _(get_node_executions)
get_latest_node_execution = _(get_latest_node_execution)
update_node_execution_status = _(update_node_execution_status)
update_node_execution_status_batch = _(update_node_execution_status_batch)
update_graph_execution_start_time = _(update_graph_execution_start_time)
update_graph_execution_stats = _(update_graph_execution_stats)
update_node_execution_stats = _(update_node_execution_stats)
upsert_execution_input = _(upsert_execution_input)
upsert_execution_output = _(upsert_execution_output)
get_graph_execution = exposed_run_and_wait(get_graph_execution)
create_graph_execution = exposed_run_and_wait(create_graph_execution)
get_node_execution_results = exposed_run_and_wait(get_node_execution_results)
get_incomplete_node_executions = exposed_run_and_wait(
get_incomplete_node_executions
)
get_latest_node_execution = exposed_run_and_wait(get_latest_node_execution)
update_node_execution_status = exposed_run_and_wait(update_node_execution_status)
update_node_execution_status_batch = exposed_run_and_wait(
update_node_execution_status_batch
)
update_graph_execution_start_time = exposed_run_and_wait(
update_graph_execution_start_time
)
update_graph_execution_stats = exposed_run_and_wait(update_graph_execution_stats)
update_node_execution_stats = exposed_run_and_wait(update_node_execution_stats)
upsert_execution_input = exposed_run_and_wait(upsert_execution_input)
upsert_execution_output = exposed_run_and_wait(upsert_execution_output)
# Graphs
get_node = _(get_node)
get_graph = _(get_graph)
get_connected_output_nodes = _(get_connected_output_nodes)
get_graph_metadata = _(get_graph_metadata)
get_node = exposed_run_and_wait(get_node)
get_graph = exposed_run_and_wait(get_graph)
get_connected_output_nodes = exposed_run_and_wait(get_connected_output_nodes)
get_graph_metadata = exposed_run_and_wait(get_graph_metadata)
# Credits
spend_credits = _(_spend_credits, name="spend_credits")
get_credits = _(_get_credits, name="get_credits")
spend_credits = exposed_run_and_wait(_spend_credits)
# User + User Metadata + User Integrations
get_user_metadata = _(get_user_metadata)
update_user_metadata = _(update_user_metadata)
get_user_integrations = _(get_user_integrations)
update_user_integrations = _(update_user_integrations)
get_user_metadata = exposed_run_and_wait(get_user_metadata)
update_user_metadata = exposed_run_and_wait(update_user_metadata)
get_user_integrations = exposed_run_and_wait(get_user_integrations)
update_user_integrations = exposed_run_and_wait(update_user_integrations)
# User Comms - async
get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
get_user_email_by_id = _(get_user_email_by_id)
get_user_email_verification = _(get_user_email_verification)
get_user_notification_preference = _(get_user_notification_preference)
get_active_user_ids_in_timerange = exposed_run_and_wait(
get_active_user_ids_in_timerange
)
get_user_email_by_id = exposed_run_and_wait(get_user_email_by_id)
get_user_email_verification = exposed_run_and_wait(get_user_email_verification)
get_user_notification_preference = exposed_run_and_wait(
get_user_notification_preference
)
# Notifications - async
create_or_add_to_user_notification_batch = _(
create_or_add_to_user_notification_batch = exposed_run_and_wait(
create_or_add_to_user_notification_batch
)
empty_user_notification_batch = _(empty_user_notification_batch)
get_all_batches_by_type = _(get_all_batches_by_type)
get_user_notification_batch = _(get_user_notification_batch)
get_user_notification_oldest_message_in_batch = _(
empty_user_notification_batch = exposed_run_and_wait(empty_user_notification_batch)
get_all_batches_by_type = exposed_run_and_wait(get_all_batches_by_type)
get_user_notification_batch = exposed_run_and_wait(get_user_notification_batch)
get_user_notification_oldest_message_in_batch = exposed_run_and_wait(
get_user_notification_oldest_message_in_batch
)
class DatabaseManagerClient(AppServiceClient):
d = DatabaseManager
_ = endpoint_to_sync
@classmethod
def get_service_type(cls):
return DatabaseManager
# Executions
get_graph_execution = _(d.get_graph_execution)
get_graph_executions = _(d.get_graph_executions)
get_graph_execution_meta = _(d.get_graph_execution_meta)
create_graph_execution = _(d.create_graph_execution)
get_node_executions = _(d.get_node_executions)
get_latest_node_execution = _(d.get_latest_node_execution)
update_node_execution_status = _(d.update_node_execution_status)
update_node_execution_status_batch = _(d.update_node_execution_status_batch)
update_graph_execution_start_time = _(d.update_graph_execution_start_time)
update_graph_execution_stats = _(d.update_graph_execution_stats)
update_node_execution_stats = _(d.update_node_execution_stats)
upsert_execution_input = _(d.upsert_execution_input)
upsert_execution_output = _(d.upsert_execution_output)
# Graphs
get_node = _(d.get_node)
get_graph = _(d.get_graph)
get_connected_output_nodes = _(d.get_connected_output_nodes)
get_graph_metadata = _(d.get_graph_metadata)
# Credits
spend_credits = _(d.spend_credits)
get_credits = _(d.get_credits)
# User + User Metadata + User Integrations
get_user_metadata = _(d.get_user_metadata)
update_user_metadata = _(d.update_user_metadata)
get_user_integrations = _(d.get_user_integrations)
update_user_integrations = _(d.update_user_integrations)
# User Comms - async
get_active_user_ids_in_timerange = _(d.get_active_user_ids_in_timerange)
get_user_email_by_id = _(d.get_user_email_by_id)
get_user_email_verification = _(d.get_user_email_verification)
get_user_notification_preference = _(d.get_user_notification_preference)
# Notifications - async
create_or_add_to_user_notification_batch = _(
d.create_or_add_to_user_notification_batch
)
empty_user_notification_batch = _(d.empty_user_notification_batch)
get_all_batches_by_type = _(d.get_all_batches_by_type)
get_user_notification_batch = _(d.get_user_notification_batch)
get_user_notification_oldest_message_in_batch = _(
d.get_user_notification_oldest_message_in_batch
)


@@ -5,42 +5,35 @@ import os
import signal
import sys
import threading
import time
from concurrent.futures import Future, ProcessPoolExecutor
from contextlib import contextmanager
from multiprocessing.pool import AsyncResult, Pool
from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generator, TypeVar, cast
from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic, BasicProperties
from redis.lock import Lock as RedisLock
from backend.blocks.io import AgentOutputBlock
from backend.data.model import (
CredentialsMetaInput,
GraphExecutionStats,
NodeExecutionStats,
)
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
AgentRunData,
LowBalanceData,
NotificationEventModel,
NotificationEventDTO,
NotificationType,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.utils import create_execution_queue_config
from backend.notifications.notifications import queue_notification
from backend.util.exceptions import InsufficientBalanceError
if TYPE_CHECKING:
from backend.executor import DatabaseManagerClient
from backend.executor import DatabaseManager
from backend.notifications.notifications import NotificationManager
from autogpt_libs.utils.cache import thread_cached
from prometheus_client import Gauge, start_http_server
from autogpt_libs.utils.cache import clear_thread_cache, thread_cached
from backend.blocks.agent import AgentExecutorBlock
from backend.data import redis
from backend.data.block import BlockData, BlockInput, BlockSchema, get_block
from backend.data.credit import UsageTransactionMetadata
from backend.data.execution import (
ExecutionQueue,
ExecutionStatus,
@@ -54,6 +47,7 @@ from backend.executor.utils import (
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
GRAPH_EXECUTION_QUEUE_NAME,
CancelExecutionEvent,
UsageTransactionMetadata,
block_usage_cost,
execution_usage_cost,
get_execution_event_bus,
@@ -67,24 +61,12 @@ from backend.util.decorator import error_logged, time_measured
from backend.util.file import clean_exec_files
from backend.util.logging import configure_logging
from backend.util.process import AppProcess, set_service_name
from backend.util.retry import func_retry
from backend.util.service import get_service_client
from backend.util.service import close_service_client, get_service_client
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
active_runs_gauge = Gauge(
"execution_manager_active_runs", "Number of active graph runs"
)
pool_size_gauge = Gauge(
"execution_manager_pool_size", "Maximum number of graph workers"
)
utilization_gauge = Gauge(
"execution_manager_utilization_ratio",
"Ratio of active graph runs to max graph workers",
)
class LogMetadata:
def __init__(
@@ -139,13 +121,10 @@ ExecutionStream = Generator[NodeExecutionEntry, None, None]
def execute_node(
db_client: "DatabaseManagerClient",
db_client: "DatabaseManager",
creds_manager: IntegrationCredentialsManager,
data: NodeExecutionEntry,
execution_stats: NodeExecutionStats | None = None,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
) -> ExecutionStream:
"""
Execute a node in the graph. This will trigger a block execution on a node,
@@ -193,7 +172,7 @@ def execute_node(
)
# Sanity check: validate the execution input.
input_data, error = validate_exec(node, data.inputs, resolve_input=False)
input_data, error = validate_exec(node, data.data, resolve_input=False)
if input_data is None:
log_metadata.error(f"Skip execution, input validation error: {error}")
push_output("error", error)
@@ -203,12 +182,8 @@ def execute_node(
# Re-shape the input data for agent block.
# AgentExecutorBlock specially separate the node input_data & its input_default.
if isinstance(node_block, AgentExecutorBlock):
_input_data = AgentExecutorBlock.Input(**node.input_default)
_input_data.inputs = input_data
if node_credentials_input_map:
_input_data.node_credentials_input_map = node_credentials_input_map
input_data = _input_data.model_dump()
data.inputs = input_data
input_data = {**node.input_default, "data": input_data}
data.data = input_data
# Execute the node
input_data_str = json.dumps(input_data)
@@ -255,7 +230,6 @@ def execute_node(
graph_exec_id=graph_exec_id,
graph_id=graph_id,
log_metadata=log_metadata,
node_credentials_input_map=node_credentials_input_map,
):
yield execution
@@ -274,14 +248,13 @@ def execute_node(
graph_exec_id=graph_exec_id,
graph_id=graph_id,
log_metadata=log_metadata,
node_credentials_input_map=node_credentials_input_map,
):
yield execution
raise e
finally:
# Ensure credentials are released even if execution fails
if creds_lock and creds_lock.locked() and creds_lock.owned():
if creds_lock and creds_lock.locked():
try:
creds_lock.release()
except Exception as e:
@@ -297,14 +270,13 @@ def execute_node(
def _enqueue_next_nodes(
db_client: "DatabaseManagerClient",
db_client: "DatabaseManager",
node: Node,
output: BlockData,
user_id: str,
graph_exec_id: str,
graph_id: str,
log_metadata: LogMetadata,
node_credentials_input_map: Optional[dict[str, dict[str, CredentialsMetaInput]]],
) -> list[NodeExecutionEntry]:
def add_enqueued_execution(
node_exec_id: str, node_id: str, block_id: str, data: BlockInput
@@ -320,7 +292,7 @@ def _enqueue_next_nodes(
node_exec_id=node_exec_id,
node_id=node_id,
block_id=block_id,
inputs=data,
data=data,
)
def register_next_executions(node_link: Link) -> list[NodeExecutionEntry]:
@@ -361,15 +333,6 @@ def _enqueue_next_nodes(
for name in static_link_names:
next_node_input[name] = latest_execution.input_data.get(name)
# Apply node credentials overrides
node_credentials = None
if node_credentials_input_map and (
node_credentials := node_credentials_input_map.get(next_node.id)
):
next_node_input.update(
{k: v.model_dump() for k, v in node_credentials.items()}
)
# Validate the input data for the next node.
next_node_input, validation_msg = validate_exec(next_node, next_node_input)
suffix = f"{next_output_name}>{next_input_name}~{next_node_exec_id}:{validation_msg}"
@@ -396,10 +359,8 @@ def _enqueue_next_nodes(
# If link is static, there could be some incomplete executions waiting for it.
# Load and complete the missing input data, and try to re-enqueue them.
for iexec in db_client.get_node_executions(
node_id=next_node_id,
graph_exec_id=graph_exec_id,
statuses=[ExecutionStatus.INCOMPLETE],
for iexec in db_client.get_incomplete_node_executions(
next_node_id, graph_exec_id
):
idata = iexec.input_data
ineid = iexec.node_exec_id
@@ -412,12 +373,6 @@ def _enqueue_next_nodes(
for input_name in static_link_names:
idata[input_name] = next_node_input[input_name]
# Apply node credentials overrides
if node_credentials:
idata.update(
{k: v.model_dump() for k, v in node_credentials.items()}
)
idata, msg = validate_exec(next_node, idata)
suffix = f"{next_output_name}>{next_input_name}~{ineid}:{msg}"
if not idata:
@@ -467,7 +422,6 @@ class Executor:
"""
@classmethod
@func_retry
def on_node_executor_start(cls):
configure_logging()
set_service_name("NodeExecutor")
@@ -478,28 +432,36 @@ class Executor:
# Set up shutdown handlers
cls.shutdown_lock = threading.Lock()
atexit.register(cls.on_node_executor_stop)
signal.signal(signal.SIGTERM, lambda _, __: cls.on_node_executor_sigterm())
signal.signal(signal.SIGINT, lambda _, __: cls.on_node_executor_sigterm())
atexit.register(cls.on_node_executor_stop) # handle regular shutdown
signal.signal( # handle termination
signal.SIGTERM, lambda _, __: cls.on_node_executor_sigterm()
)
@classmethod
def on_node_executor_stop(cls, log=logger.info):
def on_node_executor_stop(cls):
if not cls.shutdown_lock.acquire(blocking=False):
return # already shutting down
log(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
cls.creds_manager.release_all_locks()
log(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
redis.disconnect()
log(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB manager...")
cls.db_client.close()
log(f"[on_node_executor_stop {cls.pid}] ✅ Finished NodeExec cleanup")
sys.exit(0)
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB manager...")
close_service_client(cls.db_client)
logger.info(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
@classmethod
def on_node_executor_sigterm(cls):
llprint(f"[on_node_executor_sigterm {cls.pid}] ⚠️ NodeExec SIGTERM received")
cls.on_node_executor_stop(log=llprint)
llprint(f"[on_node_executor_sigterm {cls.pid}] ⚠️ SIGTERM received")
if not cls.shutdown_lock.acquire(blocking=False):
return # already shutting down
llprint(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
cls.creds_manager.release_all_locks()
llprint(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
redis.disconnect()
llprint(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
sys.exit(0)
@classmethod
@error_logged
@@ -507,9 +469,6 @@ class Executor:
cls,
q: ExecutionQueue[NodeExecutionEntry],
node_exec: NodeExecutionEntry,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
) -> NodeExecutionStats:
log_metadata = LogMetadata(
user_id=node_exec.user_id,
@@ -522,7 +481,7 @@ class Executor:
execution_stats = NodeExecutionStats()
timing_info, _ = cls._on_node_execution(
q, node_exec, log_metadata, execution_stats, node_credentials_input_map
q, node_exec, log_metadata, execution_stats
)
execution_stats.walltime = timing_info.wall_time
execution_stats.cputime = timing_info.cpu_time
@@ -542,9 +501,6 @@ class Executor:
node_exec: NodeExecutionEntry,
log_metadata: LogMetadata,
stats: NodeExecutionStats | None = None,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
):
try:
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
@@ -553,7 +509,6 @@ class Executor:
creds_manager=cls.creds_manager,
data=node_exec,
execution_stats=stats,
node_credentials_input_map=node_credentials_input_map,
):
q.add(execution)
log_metadata.info(f"Finished node execution {node_exec.node_exec_id}")
@@ -572,7 +527,6 @@ class Executor:
stats.error = e
@classmethod
@func_retry
def on_graph_executor_start(cls):
configure_logging()
set_service_name("GraphExecutor")
@@ -580,8 +534,23 @@ class Executor:
cls.db_client = get_db_client()
cls.pool_size = settings.config.num_node_workers
cls.pid = os.getpid()
cls.notification_service = get_notification_service()
cls._init_node_executor_pool()
logger.info(f"GraphExec {cls.pid} started with {cls.pool_size} node workers")
logger.info(
f"Graph executor {cls.pid} started with {cls.pool_size} node workers"
)
# Set up shutdown handler
atexit.register(cls.on_graph_executor_stop)
@classmethod
def on_graph_executor_stop(cls):
prefix = f"[on_graph_executor_stop {cls.pid}]"
logger.info(f"{prefix} ⏳ Terminating node executor pool...")
cls.executor.terminate()
logger.info(f"{prefix} ⏳ Disconnecting DB manager...")
close_service_client(cls.db_client)
logger.info(f"{prefix} ✅ Finished cleanup")
@classmethod
def _init_node_executor_pool(cls):
@@ -603,46 +572,22 @@ class Executor:
node_eid="*",
block_name="-",
)
exec_meta = cls.db_client.get_graph_execution_meta(
user_id=graph_exec.user_id,
execution_id=graph_exec.graph_exec_id,
exec_meta = cls.db_client.update_graph_execution_start_time(
graph_exec.graph_exec_id
)
if exec_meta is None:
log_metadata.warning(
f"Skipped graph execution #{graph_exec.graph_exec_id}, the graph execution is not found."
)
return
if exec_meta.status == ExecutionStatus.QUEUED:
log_metadata.info(f"⚙️ Starting graph execution #{graph_exec.graph_exec_id}")
exec_meta.status = ExecutionStatus.RUNNING
send_execution_update(
cls.db_client.update_graph_execution_start_time(
graph_exec.graph_exec_id
)
)
elif exec_meta.status == ExecutionStatus.RUNNING:
log_metadata.info(
f"⚙️ Graph execution #{graph_exec.graph_exec_id} is already running, continuing where it left off."
)
else:
log_metadata.warning(
f"Skipped graph execution {graph_exec.graph_exec_id}, the graph execution status is `{exec_meta.status}`."
logger.warning(
f"Skipped graph execution {graph_exec.graph_exec_id}, the graph execution is not found or not currently in the QUEUED state."
)
return
send_execution_update(exec_meta)
timing_info, (exec_stats, status, error) = cls._on_graph_execution(
graph_exec=graph_exec,
cancel=cancel,
log_metadata=log_metadata,
execution_stats=(
exec_meta.stats.to_db() if exec_meta.stats else GraphExecutionStats()
),
graph_exec, cancel, log_metadata
)
exec_stats.walltime += timing_info.wall_time
exec_stats.cputime += timing_info.cpu_time
exec_stats.error = str(error) if error else exec_stats.error
exec_stats.walltime = timing_info.wall_time
exec_stats.cputime = timing_info.cpu_time
exec_stats.error = str(error)
if graph_exec_result := cls.db_client.update_graph_execution_stats(
graph_exec_id=graph_exec.graph_exec_id,
@@ -659,15 +604,13 @@ class Executor:
node_exec: NodeExecutionEntry,
execution_count: int,
execution_stats: GraphExecutionStats,
):
) -> int:
block = get_block(node_exec.block_id)
if not block:
logger.error(f"Block {node_exec.block_id} not found.")
return
return execution_count
cost, matching_filter = block_usage_cost(
block=block, input_data=node_exec.inputs
)
cost, matching_filter = block_usage_cost(block=block, input_data=node_exec.data)
if cost > 0:
cls.db_client.spend_credits(
user_id=node_exec.user_id,
@@ -680,12 +623,11 @@ class Executor:
block_id=node_exec.block_id,
block=block.name,
input=matching_filter,
reason=f"Ran block {node_exec.block_id} {block.name}",
),
)
execution_stats.cost += cost
cost, usage_count = execution_usage_cost(execution_count)
cost, execution_count = execution_usage_cost(execution_count)
if cost > 0:
cls.db_client.spend_credits(
user_id=node_exec.user_id,
@@ -694,14 +636,15 @@ class Executor:
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
input={
"execution_count": usage_count,
"execution_count": execution_count,
"charge": "Execution Cost",
},
reason=f"Execution Cost for {usage_count} blocks of ex_id:{node_exec.graph_exec_id} g_id:{node_exec.graph_id}",
),
)
execution_stats.cost += cost
return execution_count
@classmethod
@time_measured
def _on_graph_execution(
@@ -709,7 +652,6 @@ class Executor:
graph_exec: GraphExecutionEntry,
cancel: threading.Event,
log_metadata: LogMetadata,
execution_stats: GraphExecutionStats,
) -> tuple[GraphExecutionStats, ExecutionStatus, Exception | None]:
"""
Returns:
@@ -717,6 +659,8 @@ class Executor:
ExecutionStatus: The final status of the graph execution.
Exception | None: The error that occurred during the execution, if any.
"""
log_metadata.info(f"Start graph execution {graph_exec.graph_exec_id}")
execution_stats = GraphExecutionStats()
execution_status = ExecutionStatus.RUNNING
error = None
finished = False
@@ -737,21 +681,11 @@ class Executor:
cancel_thread.start()
try:
if cls.db_client.get_credits(graph_exec.user_id) <= 0:
raise InsufficientBalanceError(
user_id=graph_exec.user_id,
message="You have no credits left to run an agent.",
balance=0,
amount=1,
)
queue = ExecutionQueue[NodeExecutionEntry]()
for node_exec in cls.db_client.get_node_executions(
graph_exec.graph_exec_id,
statuses=[ExecutionStatus.RUNNING, ExecutionStatus.QUEUED],
):
queue.add(node_exec.to_node_execution_entry())
for node_exec in graph_exec.start_node_execs:
queue.add(node_exec)
exec_cost_counter = 0
running_executions: dict[str, AsyncResult] = {}
def make_exec_callback(exec_data: NodeExecutionEntry):
@@ -807,9 +741,9 @@ class Executor:
)
try:
cls._charge_usage(
exec_cost_counter = cls._charge_usage(
node_exec=queued_node_exec,
execution_count=increment_execution_count(graph_exec.user_id),
execution_count=exec_cost_counter + 1,
execution_stats=execution_stats,
)
except InsufficientBalanceError as error:
@@ -839,7 +773,7 @@ class Executor:
if (node_creds_map := graph_exec.node_credentials_input_map) and (
node_field_creds_map := node_creds_map.get(node_id)
):
queued_node_exec.inputs.update(
queued_node_exec.data.update(
{
field_name: creds_meta.model_dump()
for field_name, creds_meta in node_field_creds_map.items()
@@ -849,7 +783,7 @@ class Executor:
# Initiate node execution
running_executions[queued_node_exec.node_id] = cls.executor.apply_async(
cls.on_node_execution,
(queue, queued_node_exec, node_creds_map),
(queue, queued_node_exec),
callback=make_exec_callback(queued_node_exec),
)
@@ -870,21 +804,24 @@ class Executor:
execution.wait(3)
log_metadata.info(f"Finished graph execution {graph_exec.graph_exec_id}")
execution_status = ExecutionStatus.COMPLETED
except Exception as e:
error = e
log_metadata.error(
f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
)
execution_status = ExecutionStatus.FAILED
finally:
if error:
log_metadata.error(
f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
)
execution_status = ExecutionStatus.FAILED
else:
execution_status = ExecutionStatus.COMPLETED
if not cancel.is_set():
finished = True
cancel.set()
cancel_thread.join()
clean_exec_files(graph_exec.graph_exec_id)
return execution_stats, execution_status, error
@classmethod
@@ -896,7 +833,7 @@ class Executor:
metadata = cls.db_client.get_graph_metadata(
graph_exec.graph_id, graph_exec.graph_version
)
outputs = cls.db_client.get_node_executions(
outputs = cls.db_client.get_node_execution_results(
graph_exec.graph_exec_id,
block_ids=[AgentOutputBlock().id],
)
@@ -909,21 +846,21 @@ class Executor:
for output in outputs
]
queue_notification(
NotificationEventModel(
user_id=graph_exec.user_id,
type=NotificationType.AGENT_RUN,
data=AgentRunData(
outputs=named_outputs,
agent_name=metadata.name if metadata else "Unknown Agent",
credits_used=exec_stats.cost,
execution_time=exec_stats.walltime,
graph_id=graph_exec.graph_id,
node_count=exec_stats.node_count,
),
)
event = NotificationEventDTO(
user_id=graph_exec.user_id,
type=NotificationType.AGENT_RUN,
data=AgentRunData(
outputs=named_outputs,
agent_name=metadata.name if metadata else "Unknown Agent",
credits_used=exec_stats.cost,
execution_time=exec_stats.walltime,
graph_id=graph_exec.graph_id,
node_count=exec_stats.node_count,
).model_dump(),
)
cls.notification_service.queue_notification(event)
@classmethod
def _handle_low_balance_notif(
cls,
@@ -937,8 +874,8 @@ class Executor:
base_url = (
settings.config.frontend_base_url or settings.config.platform_base_url
)
queue_notification(
NotificationEventModel(
cls.notification_service.queue_notification(
NotificationEventDTO(
user_id=user_id,
type=NotificationType.LOW_BALANCE,
data=LowBalanceData(
@@ -946,7 +883,7 @@ class Executor:
billing_page_link=f"{base_url}/profile/credits",
shortfall=shortfall,
agent_name=metadata.name if metadata else "Unknown Agent",
),
).model_dump(),
)
)
@@ -957,23 +894,35 @@ class ExecutionManager(AppProcess):
self.pool_size = settings.config.num_graph_workers
self.running = True
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
atexit.register(self._on_cleanup)
signal.signal(signal.SIGTERM, lambda sig, frame: self._on_sigterm())
signal.signal(signal.SIGINT, lambda sig, frame: self._on_sigterm())
@classmethod
def get_port(cls) -> int:
return settings.config.execution_manager_port
def run(self):
pool_size_gauge.set(self.pool_size)
active_runs_gauge.set(0)
utilization_gauge.set(0)
retry_count_max = settings.config.execution_manager_loop_max_retry
retry_count = 0
self.metrics_server = threading.Thread(
target=start_http_server,
args=(settings.config.execution_manager_port,),
daemon=True,
)
self.metrics_server.start()
logger.info(f"[{self.service_name}] Starting execution manager...")
self._run()
for retry_count in range(retry_count_max):
try:
self._run()
except Exception as e:
if not self.running:
break
logger.exception(
f"[{self.service_name}] Error in execution manager: {e}"
)
if retry_count >= retry_count_max:
logger.error(
f"[{self.service_name}] Max retries reached ({retry_count_max}), exiting..."
)
break
else:
logger.info(
f"[{self.service_name}] Retrying execution loop in {retry_count} seconds..."
)
time.sleep(retry_count)
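The retry loop above — run, log the failure, back off a little longer each attempt, give up after a maximum — can be sketched standalone. This is a simplified illustration of the pattern, not the manager's real loop; `task` and the injectable `sleep` are hypothetical stand-ins:

```python
import time


def run_with_retries(task, max_retries: int, sleep=time.sleep):
    """Retry `task` with linear backoff, re-raising after `max_retries` failures."""
    for attempt in range(1, max_retries + 1):
        try:
            return task()
        except Exception:
            if attempt == max_retries:
                raise  # retries exhausted: surface the last error
            sleep(attempt)  # wait a little longer before each new attempt
```

Injecting `sleep` keeps the loop unit-testable without real waiting.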
def _run(self):
logger.info(f"[{self.service_name}] ⏳ Spawn max-{self.pool_size} workers...")
@@ -985,33 +934,23 @@ class ExecutionManager(AppProcess):
logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...")
redis.connect()
cancel_client = SyncRabbitMQ(create_execution_queue_config())
cancel_client.connect()
cancel_channel = cancel_client.get_channel()
logger.info(f"[{self.service_name}] ⏳ Starting cancel message consumer...")
threading.Thread(
target=lambda: (
cancel_channel.basic_consume(
queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
on_message_callback=self._handle_cancel_message,
auto_ack=True,
),
cancel_channel.start_consuming(),
),
daemon=True,
).start()
run_client = SyncRabbitMQ(create_execution_queue_config())
run_client.connect()
run_channel = run_client.get_channel()
run_channel.basic_qos(prefetch_count=self.pool_size)
run_channel.basic_consume(
# Consume Cancel & Run execution requests.
clear_thread_cache(get_execution_queue)
channel = get_execution_queue().get_channel()
channel.basic_qos(prefetch_count=self.pool_size)
channel.basic_consume(
queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
on_message_callback=self._handle_cancel_message,
auto_ack=True,
)
channel.basic_consume(
queue=GRAPH_EXECUTION_QUEUE_NAME,
on_message_callback=self._handle_run_message,
auto_ack=False,
)
logger.info(f"[{self.service_name}] ⏳ Starting to consume run messages...")
run_channel.start_consuming()
logger.info(f"[{self.service_name}] Ready to consume messages...")
channel.start_consuming()
def _handle_cancel_message(
self,
@@ -1081,15 +1020,11 @@ class ExecutionManager(AppProcess):
Executor.on_graph_execution, graph_exec_entry, cancel_event
)
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
active_runs_gauge.set(len(self.active_graph_runs))
utilization_gauge.set(len(self.active_graph_runs) / self.pool_size)
def _on_run_done(f: Future):
logger.info(f"[{self.service_name}] Run completed for {graph_exec_id}")
try:
self.active_graph_runs.pop(graph_exec_id, None)
active_runs_gauge.set(len(self.active_graph_runs))
utilization_gauge.set(len(self.active_graph_runs) / self.pool_size)
if f.exception():
logger.error(
f"[{self.service_name}] Execution for {graph_exec_id} failed: {f.exception()}"
@@ -1108,44 +1043,42 @@ class ExecutionManager(AppProcess):
def cleanup(self):
super().cleanup()
self._on_cleanup()
def _on_sigterm(self):
llprint(f"[{self.service_name}] ⚠️ GraphExec SIGTERM received")
self._on_cleanup(log=llprint)
def _on_cleanup(self, log=logger.info):
prefix = f"[{self.service_name}][on_graph_executor_stop {os.getpid()}]"
log(f"{prefix} ⏳ Shutting down service loop...")
logger.info(f"[{self.service_name}] ⏳ Shutting down service loop...")
self.running = False
log(f"{prefix} ⏳ Shutting down RabbitMQ channel...")
logger.info(f"[{self.service_name}] ⏳ Shutting down RabbitMQ channel...")
get_execution_queue().get_channel().stop_consuming()
if hasattr(self, "executor"):
log(f"{prefix} ⏳ Shutting down GraphExec pool...")
self.executor.shutdown(cancel_futures=True, wait=False)
logger.info(f"[{self.service_name}] ⏳ Shutting down graph executor pool...")
self.executor.shutdown(cancel_futures=True)
log(f"{prefix} ⏳ Disconnecting Redis...")
logger.info(f"[{self.service_name}] ⏳ Disconnecting Redis...")
redis.disconnect()
log(f"{prefix} ✅ Finished GraphExec cleanup")
@property
def db_client(self) -> "DatabaseManager":
return get_db_client()
# ------- UTILITIES ------- #
@thread_cached
def get_db_client() -> "DatabaseManagerClient":
from backend.executor import DatabaseManagerClient
def get_db_client() -> "DatabaseManager":
from backend.executor import DatabaseManager
# Disable health check for the service client to avoid breaking process initializer.
return get_service_client(DatabaseManagerClient, health_check=False)
return get_service_client(DatabaseManager)
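`@thread_cached` gives each thread its own memoized client, so worker threads do not share one RPC connection. A rough sketch of what such a decorator does — illustrative only (the real implementation lives in `autogpt_libs.utils.cache` and may differ, e.g. this sketch ignores call arguments):

```python
import threading
from functools import wraps


def thread_cached(func):
    """Cache func's result once per thread (argument-insensitive sketch)."""
    local = threading.local()  # separate storage slot per thread

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not hasattr(local, "value"):
            local.value = func(*args, **kwargs)  # first call in this thread
        return local.value

    return wrapper
```

Each thread pays the construction cost once and then reuses its own instance.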
def send_execution_update(entry: GraphExecution | NodeExecutionResult | None):
if entry is None:
return
@thread_cached
def get_notification_service() -> "NotificationManager":
from backend.notifications import NotificationManager
return get_service_client(NotificationManager)
def send_execution_update(entry: GraphExecution | NodeExecutionResult):
return get_execution_event_bus().publish(entry)
@@ -1156,26 +1089,14 @@ def synchronized(key: str, timeout: int = 60):
lock.acquire()
yield
finally:
if lock.locked() and lock.owned():
if lock.locked():
lock.release()
def increment_execution_count(user_id: str) -> int:
"""
Increment the execution count for a given user;
this is used to charge the user for the execution cost.
"""
r = redis.get_redis()
k = f"uec:{user_id}" # User Execution Count global key
counter = cast(int, r.incr(k))
if counter == 1:
r.expire(k, settings.config.execution_counter_expiration_time)
return counter
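`increment_execution_count` uses a common Redis idiom: `INCR` a per-user key and set its TTL only when the counter is first created, so the count resets once the window expires. A rough in-memory stand-in (no Redis; the class and names are hypothetical) behaves the same way:

```python
import time


class WindowedCounter:
    """Dict-backed imitation of the INCR + expire-on-first-increment pattern."""

    def __init__(self, window_secs: float, clock=time.monotonic):
        self.window = window_secs
        self.clock = clock
        self._data: dict[str, tuple[int, float]] = {}  # key -> (count, expiry)

    def incr(self, user_id: str) -> int:
        key = f"uec:{user_id}"  # mirrors the "User Execution Count" key above
        now = self.clock()
        count, expiry = self._data.get(key, (0, 0.0))
        if now >= expiry:  # window elapsed: start a fresh count
            count, expiry = 0, now + self.window
        count += 1
        self._data[key] = (count, expiry)
        return count
```

The injectable `clock` makes the expiry behavior testable without sleeping.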
def llprint(message: str):
"""
Low-level print/log helper function for use in signal handlers.
Regular log/print statements are not allowed in signal handlers.
"""
os.write(sys.stdout.fileno(), (message + "\n").encode())
if logger.getEffectiveLevel() == logging.DEBUG:
os.write(sys.stdout.fileno(), (message + "\n").encode())
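`llprint` works because `os.write` is a single unbuffered syscall with no Python-level locks, whereas `print`/`logging` can deadlock if a signal interrupts a thread holding the I/O lock. A small sketch writing to an arbitrary file descriptor (the function name here is illustrative):

```python
import os


def llprint_to(fd: int, message: str) -> None:
    # os.write is async-signal-safe: no buffering, no locks to deadlock on.
    os.write(fd, (message + "\n").encode())
```

Writing to a pipe instead of stdout makes the behavior easy to verify in tests.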


@@ -1,6 +1,5 @@
import logging
import os
from datetime import datetime, timedelta, timezone
from enum import Enum
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
@@ -13,21 +12,13 @@ from apscheduler.triggers.cron import CronTrigger
from autogpt_libs.utils.cache import thread_cached
from dotenv import load_dotenv
from prisma.enums import NotificationType
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel
from sqlalchemy import MetaData, create_engine
from backend.data.block import BlockInput
from backend.data.execution import ExecutionStatus
from backend.executor import utils as execution_utils
from backend.notifications.notifications import NotificationManagerClient
from backend.util.metrics import sentry_capture_error
from backend.util.service import (
AppService,
AppServiceClient,
endpoint_to_async,
expose,
get_service_client,
)
from backend.notifications.notifications import NotificationManager
from backend.util.service import AppService, expose, get_service_client
from backend.util.settings import Config
@@ -68,11 +59,13 @@ def job_listener(event):
@thread_cached
def get_notification_client():
return get_service_client(NotificationManagerClient)
from backend.notifications import NotificationManager
return get_service_client(NotificationManager)
def execute_graph(**kwargs):
args = GraphExecutionJobArgs(**kwargs)
args = ExecutionJobArgs(**kwargs)
try:
log(f"Executing recurring job for graph #{args.graph_id}")
execution_utils.add_graph_execution(
@@ -85,37 +78,6 @@ def execute_graph(**kwargs):
logger.exception(f"Error executing graph {args.graph_id}: {e}")
class LateExecutionException(Exception):
pass
def report_late_executions() -> str:
late_executions = execution_utils.get_db_client().get_graph_executions(
statuses=[ExecutionStatus.QUEUED],
created_time_gte=datetime.now(timezone.utc)
- timedelta(seconds=config.execution_late_notification_checkrange_secs),
created_time_lte=datetime.now(timezone.utc)
- timedelta(seconds=config.execution_late_notification_threshold_secs),
limit=1000,
)
if not late_executions:
return "No late executions detected."
num_late_executions = len(late_executions)
num_users = len(set([r.user_id for r in late_executions]))
error = LateExecutionException(
f"Late executions detected: {num_late_executions} late executions from {num_users} users "
f"in the last {config.execution_late_notification_checkrange_secs} seconds. "
f"Graph has been queued for more than {config.execution_late_notification_threshold_secs} seconds. "
"Please check the executor status."
)
msg = str(error)
sentry_capture_error(error)
get_notification_client().discord_system_alert(msg)
return msg
def process_existing_batches(**kwargs):
args = NotificationJobArgs(**kwargs)
try:
@@ -141,7 +103,7 @@ class Jobstores(Enum):
WEEKLY_NOTIFICATIONS = "weekly_notifications"
class GraphExecutionJobArgs(BaseModel):
class ExecutionJobArgs(BaseModel):
graph_id: str
input_data: BlockInput
user_id: str
@@ -149,16 +111,14 @@ class GraphExecutionJobArgs(BaseModel):
cron: str
class GraphExecutionJobInfo(GraphExecutionJobArgs):
class ExecutionJobInfo(ExecutionJobArgs):
id: str
name: str
next_run_time: str
@staticmethod
def from_db(
job_args: GraphExecutionJobArgs, job_obj: JobObj
) -> "GraphExecutionJobInfo":
return GraphExecutionJobInfo(
def from_db(job_args: ExecutionJobArgs, job_obj: JobObj) -> "ExecutionJobInfo":
return ExecutionJobInfo(
id=job_obj.id,
name=job_obj.name,
next_run_time=job_obj.next_run_time.isoformat(),
@@ -191,9 +151,6 @@ class NotificationJobInfo(NotificationJobArgs):
class Scheduler(AppService):
scheduler: BlockingScheduler
def __init__(self, register_system_tasks: bool = True):
self.register_system_tasks = register_system_tasks
@classmethod
def get_port(cls) -> int:
return config.execution_scheduler_port
@@ -202,6 +159,11 @@ class Scheduler(AppService):
def db_pool_size(cls) -> int:
return config.scheduler_db_pool_size
@property
@thread_cached
def notification_client(self) -> NotificationManager:
return get_service_client(NotificationManager)
def run_service(self):
load_dotenv()
db_schema, db_url = _extract_schema_from_url(os.getenv("DIRECT_URL"))
@@ -231,37 +193,6 @@ class Scheduler(AppService):
Jobstores.WEEKLY_NOTIFICATIONS.value: MemoryJobStore(),
}
)
if self.register_system_tasks:
# Notification PROCESS WEEKLY SUMMARY
self.scheduler.add_job(
process_weekly_summary,
CronTrigger.from_crontab("0 * * * *"),
id="process_weekly_summary",
kwargs={},
replace_existing=True,
jobstore=Jobstores.WEEKLY_NOTIFICATIONS.value,
)
# Notification PROCESS EXISTING BATCHES
# self.scheduler.add_job(
# process_existing_batches,
# id="process_existing_batches",
# CronTrigger.from_crontab("0 12 * * 5"),
# replace_existing=True,
# jobstore=Jobstores.BATCHED_NOTIFICATIONS.value,
# )
# Notification LATE EXECUTIONS ALERT
self.scheduler.add_job(
report_late_executions,
id="report_late_executions",
trigger="interval",
replace_existing=True,
seconds=config.execution_late_notification_threshold_secs,
jobstore=Jobstores.EXECUTION.value,
)
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
self.scheduler.start()
@@ -272,15 +203,15 @@ class Scheduler(AppService):
self.scheduler.shutdown(wait=False)
@expose
def add_graph_execution_schedule(
def add_execution_schedule(
self,
graph_id: str,
graph_version: int,
cron: str,
input_data: BlockInput,
user_id: str,
) -> GraphExecutionJobInfo:
job_args = GraphExecutionJobArgs(
) -> ExecutionJobInfo:
job_args = ExecutionJobArgs(
graph_id=graph_id,
input_data=input_data,
user_id=user_id,
@@ -295,66 +226,77 @@ class Scheduler(AppService):
jobstore=Jobstores.EXECUTION.value,
)
log(f"Added job {job.id} with cron schedule '{cron}' input data: {input_data}")
return GraphExecutionJobInfo.from_db(job_args, job)
return ExecutionJobInfo.from_db(job_args, job)
@expose
def delete_graph_execution_schedule(
self, schedule_id: str, user_id: str
) -> GraphExecutionJobInfo:
def delete_schedule(self, schedule_id: str, user_id: str) -> ExecutionJobInfo:
job = self.scheduler.get_job(schedule_id, jobstore=Jobstores.EXECUTION.value)
if not job:
log(f"Job {schedule_id} not found.")
raise ValueError(f"Job #{schedule_id} not found.")
job_args = GraphExecutionJobArgs(**job.kwargs)
job_args = ExecutionJobArgs(**job.kwargs)
if job_args.user_id != user_id:
raise ValueError("User ID does not match the job's user ID.")
log(f"Deleting job {schedule_id}")
job.remove()
return GraphExecutionJobInfo.from_db(job_args, job)
return ExecutionJobInfo.from_db(job_args, job)
@expose
def get_graph_execution_schedules(
def get_execution_schedules(
self, graph_id: str | None = None, user_id: str | None = None
) -> list[GraphExecutionJobInfo]:
jobs: list[JobObj] = self.scheduler.get_jobs(jobstore=Jobstores.EXECUTION.value)
) -> list[ExecutionJobInfo]:
schedules = []
for job in jobs:
logger.debug(
for job in self.scheduler.get_jobs(jobstore=Jobstores.EXECUTION.value):
logger.info(
f"Found job {job.id} with cron schedule {job.trigger} and args {job.kwargs}"
)
try:
job_args = GraphExecutionJobArgs.model_validate(job.kwargs)
except ValidationError:
continue
job_args = ExecutionJobArgs(**job.kwargs)
if (
job.next_run_time is not None
and (graph_id is None or job_args.graph_id == graph_id)
and (user_id is None or job_args.user_id == user_id)
):
schedules.append(GraphExecutionJobInfo.from_db(job_args, job))
schedules.append(ExecutionJobInfo.from_db(job_args, job))
return schedules
@expose
def execute_process_existing_batches(self, kwargs: dict):
process_existing_batches(**kwargs)
def add_batched_notification_schedule(
self,
notification_types: list[NotificationType],
data: dict,
cron: str,
) -> NotificationJobInfo:
job_args = NotificationJobArgs(
notification_types=notification_types,
cron=cron,
)
job = self.scheduler.add_job(
process_existing_batches,
CronTrigger.from_crontab(cron),
kwargs=job_args.model_dump(),
replace_existing=True,
jobstore=Jobstores.BATCHED_NOTIFICATIONS.value,
)
log(f"Added job {job.id} with cron schedule '{cron}' input data: {data}")
return NotificationJobInfo.from_db(job_args, job)
@expose
def execute_process_weekly_summary(self):
process_weekly_summary()
def add_weekly_notification_schedule(self, cron: str) -> NotificationJobInfo:
@expose
def execute_report_late_executions(self):
return report_late_executions()
class SchedulerClient(AppServiceClient):
@classmethod
def get_service_type(cls):
return Scheduler
add_execution_schedule = endpoint_to_async(Scheduler.add_graph_execution_schedule)
delete_schedule = endpoint_to_async(Scheduler.delete_graph_execution_schedule)
get_execution_schedules = endpoint_to_async(Scheduler.get_graph_execution_schedules)
job = self.scheduler.add_job(
process_weekly_summary,
CronTrigger.from_crontab(cron),
kwargs={},
replace_existing=True,
jobstore=Jobstores.WEEKLY_NOTIFICATIONS.value,
)
log(f"Added job {job.id} with cron schedule '{cron}'")
return NotificationJobInfo.from_db(
NotificationJobArgs(
cron=cron, notification_types=[NotificationType.WEEKLY_SUMMARY]
),
job,
)


@@ -41,7 +41,7 @@ from backend.util.settings import Config
from backend.util.type import convert
if TYPE_CHECKING:
from backend.executor import DatabaseManagerClient
from backend.executor import DatabaseManager
from backend.integrations.credentials_store import IntegrationCredentialsStore
config = Config()
@@ -82,32 +82,40 @@ def get_integration_credentials_store() -> "IntegrationCredentialsStore":
@thread_cached
def get_db_client() -> "DatabaseManagerClient":
from backend.executor import DatabaseManagerClient
def get_db_client() -> "DatabaseManager":
from backend.executor import DatabaseManager
return get_service_client(DatabaseManagerClient)
return get_service_client(DatabaseManager)
# ============ Execution Cost Helpers ============ #
class UsageTransactionMetadata(BaseModel):
graph_exec_id: str | None = None
graph_id: str | None = None
node_id: str | None = None
node_exec_id: str | None = None
block_id: str | None = None
block: str | None = None
input: BlockInput | None = None
def execution_usage_cost(execution_count: int) -> tuple[int, int]:
"""
Calculate the cost of executing a graph based on the current number of node executions.
Calculate the cost of executing a graph based on the number of executions.
Args:
execution_count: Number of node executions
execution_count: Number of executions
Returns:
Tuple of the cost amount and the number of executions included in the cost.
Tuple of cost amount and remaining execution count
"""
return (
(
config.execution_cost_per_threshold
if execution_count % config.execution_cost_count_threshold == 0
else 0
),
config.execution_cost_count_threshold,
execution_count
// config.execution_cost_count_threshold
* config.execution_cost_per_threshold,
execution_count % config.execution_cost_count_threshold,
)
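In the modulo-based variant above, a flat fee is charged only when the running execution count lands exactly on a multiple of the threshold; all other executions cost nothing. A minimal sketch of that arithmetic (parameter names are illustrative, not the real `config` fields):

```python
def usage_cost(execution_count: int, threshold: int, fee: int) -> tuple[int, int]:
    """Charge `fee` when the running count hits a multiple of `threshold`."""
    cost = fee if execution_count % threshold == 0 else 0
    return cost, threshold  # second element: executions covered by the fee
```

So with a threshold of 100, execution number 100 triggers the charge and numbers 101-199 are free until 200.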
@@ -258,7 +266,7 @@ def validate_exec(
If the data is valid, the first element will be the resolved input data, and
the second element will be the block name.
"""
node_block = get_block(node.block_id)
node_block: Block | None = get_block(node.block_id)
if not node_block:
return None, f"Block for {node.block_id} not found."
schema = node_block.input_schema
@@ -608,10 +616,7 @@ async def add_graph_execution_async(
ValueError: If the graph is not found or if there are validation errors.
""" # noqa
graph: GraphModel | None = await get_graph(
graph_id=graph_id,
user_id=user_id,
version=graph_version,
include_subgraphs=True,
graph_id=graph_id, user_id=user_id, version=graph_version
)
if not graph:
raise NotFoundError(f"Graph #{graph_id} not found.")
@@ -671,9 +676,6 @@ def add_graph_execution(
preset_id: Optional[str] = None,
graph_version: Optional[int] = None,
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
) -> GraphExecutionWithNodes:
"""
Adds a graph execution to the queue and returns the execution entry.
@@ -686,7 +688,6 @@ def add_graph_execution(
graph_version: The version of the graph to execute.
graph_credentials_inputs: Credentials inputs to use in the execution.
Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
node_credentials_input_map: Credentials inputs to use in the execution, mapped to specific nodes.
Returns:
GraphExecutionEntry: The entry for the graph execution.
Raises:
@@ -694,15 +695,12 @@ def add_graph_execution(
"""
db = get_db_client()
graph: GraphModel | None = db.get_graph(
graph_id=graph_id,
user_id=user_id,
version=graph_version,
include_subgraphs=True,
graph_id=graph_id, user_id=user_id, version=graph_version
)
if not graph:
raise NotFoundError(f"Graph #{graph_id} not found.")
node_credentials_input_map = node_credentials_input_map or (
node_credentials_input_map = (
make_node_credentials_input_map(graph, graph_credentials_inputs)
if graph_credentials_inputs
else None


@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Optional
from pydantic import SecretStr
if TYPE_CHECKING:
from backend.executor.database import DatabaseManagerClient
from backend.executor.database import DatabaseManager
from autogpt_libs.utils.cache import thread_cached
from autogpt_libs.utils.synchronize import RedisKeyedMutex
@@ -210,11 +210,11 @@ class IntegrationCredentialsStore:
@property
@thread_cached
def db_manager(self) -> "DatabaseManagerClient":
from backend.executor.database import DatabaseManagerClient
def db_manager(self) -> "DatabaseManager":
from backend.executor.database import DatabaseManager
from backend.util.service import get_service_client
return get_service_client(DatabaseManagerClient)
return get_service_client(DatabaseManager)
def add_creds(self, user_id: str, credentials: Credentials) -> None:
with self.locked_user_integrations(user_id):


@@ -93,7 +93,7 @@ class IntegrationCredentialsManager:
fresh_credentials = oauth_handler.refresh_tokens(credentials)
self.store.update_creds(user_id, fresh_credentials)
if _lock and _lock.locked() and _lock.owned():
if _lock and _lock.locked():
_lock.release()
credentials = fresh_credentials
@@ -145,7 +145,7 @@ class IntegrationCredentialsManager:
try:
yield
finally:
if lock.locked() and lock.owned():
if lock.locked():
lock.release()
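The `owned()` guard in the old line matters for TTL-based Redis locks: if our lock expired and another holder acquired the same key, releasing by key alone would steal their lock. A dict-backed sketch of token-based ownership (illustrative, not redis-py's actual implementation):

```python
import uuid


class OwnedLock:
    """Each acquirer stores a unique token; release only while that token is ours."""

    _store: dict[str, str] = {}  # shared key -> holder-token map (stands in for Redis)

    def __init__(self, key: str):
        self.key = key
        self.token = uuid.uuid4().hex  # unique per acquirer

    def acquire(self) -> bool:
        if self.key in self._store:
            return False  # someone else holds it
        self._store[self.key] = self.token
        return True

    def locked(self) -> bool:
        return self.key in self._store

    def owned(self) -> bool:
        return self._store.get(self.key) == self.token

    def release(self) -> None:
        if self.owned():  # never delete a lock now held by another token
            del self._store[self.key]
```

Checking `locked() and owned()` before release is exactly this: the key may still exist, but under someone else's token.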
def release_all_locks(self):


@@ -16,7 +16,6 @@ def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]
from .generic import GenericWebhooksManager
from .github import GithubWebhooksManager
from .slant3d import Slant3DWebhooksManager
from .exa import ExaWebhooksManager
_WEBHOOK_MANAGERS.update(
{
@@ -26,7 +25,6 @@ def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]
GithubWebhooksManager,
Slant3DWebhooksManager,
GenericWebhooksManager,
ExaWebhooksManager,
]
}
)


@@ -1,119 +0,0 @@
import logging
import requests
from fastapi import Request
from backend.data import integrations
from backend.data.model import APIKeyCredentials, Credentials
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks._base import BaseWebhooksManager
logger = logging.getLogger(__name__)
class ExaWebhooksManager(BaseWebhooksManager):
"""Manager for Exa webhooks"""
PROVIDER_NAME = ProviderName.EXA
BASE_URL = "https://api.exa.ai/websets/v0"
async def _register_webhook(
self,
credentials: Credentials,
webhook_type: str,
resource: str,
events: list[str],
ingress_url: str,
secret: str,
) -> tuple[str, dict]:
"""Register a new webhook with Exa"""
if not isinstance(credentials, APIKeyCredentials):
raise ValueError("API key is required to register a webhook")
headers = {
"x-api-key": credentials.api_key.get_secret_value(),
"Content-Type": "application/json",
}
payload = {
"events": events,
"url": ingress_url,
"metadata": {} # Optional metadata can be added here
}
response = requests.post(
f"{self.BASE_URL}/webhooks", headers=headers, json=payload
)
if not response.ok:
error = response.json().get("error", "Unknown error")
raise RuntimeError(f"Failed to register webhook: {error}")
response_data = response.json()
webhook_id = response_data.get("id", "")
webhook_config = {
"endpoint": ingress_url,
"provider": self.PROVIDER_NAME,
"events": events,
"type": webhook_type,
"webhook_id": webhook_id,
"secret": response_data.get("secret", "")
}
return webhook_id, webhook_config
@classmethod
async def validate_payload(
cls, webhook: integrations.Webhook, request: Request
) -> tuple[dict, str]:
"""Validate incoming webhook payload from Exa"""
payload = await request.json()
# Validate required fields from Exa API spec
required_fields = ["id", "object", "type", "data", "createdAt"]
missing_fields = [field for field in required_fields if field not in payload]
if missing_fields:
raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")
# Normalize payload structure
normalized_payload = {
"id": payload["id"],
"type": payload["type"],
"data": payload["data"],
"createdAt": payload["createdAt"]
}
# Extract event type from the payload
event_type = payload["type"]
return normalized_payload, event_type
async def _deregister_webhook(
self, webhook: integrations.Webhook, credentials: Credentials
) -> None:
"""Deregister a webhook with Exa"""
if not isinstance(credentials, APIKeyCredentials):
raise ValueError("API key is required to deregister a webhook")
webhook_id = webhook.config.get("webhook_id")
if not webhook_id:
logger.warning(f"No webhook ID found for webhook {webhook.id}, cannot deregister")
return
headers = {
"x-api-key": credentials.api_key.get_secret_value(),
}
response = requests.delete(
f"{self.BASE_URL}/webhooks/{webhook_id}", headers=headers
)
if not response.ok:
error = response.json().get("error", "Unknown error")
logger.error(f"Failed to deregister webhook {webhook_id}: {error}")
raise RuntimeError(f"Failed to deregister webhook: {error}")

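The removed `ExaWebhooksManager.validate_payload` boils down to a required-field check plus normalization. A dependency-free sketch of that logic (field names taken from the diff above; the function name is illustrative):

```python
def validate_exa_payload(payload: dict) -> tuple[dict, str]:
    """Check required fields, then return (normalized_payload, event_type)."""
    required_fields = ["id", "object", "type", "data", "createdAt"]
    missing = [f for f in required_fields if f not in payload]
    if missing:
        raise ValueError(f"Missing required fields: {', '.join(missing)}")
    # Keep only the fields downstream consumers need.
    normalized = {k: payload[k] for k in ("id", "type", "data", "createdAt")}
    return normalized, payload["type"]

ok = {
    "id": "wh_1",
    "object": "event",
    "type": "webset.created",
    "data": {},
    "createdAt": "2025-01-01T00:00:00Z",
}
normalized, event_type = validate_exa_payload(ok)
assert event_type == "webset.created"
assert "object" not in normalized  # dropped during normalization

try:
    validate_exa_payload({"id": "wh_2"})
except ValueError as e:
    assert "Missing required fields" in str(e)
```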

@@ -1,6 +1,5 @@
from .notifications import NotificationManager, NotificationManagerClient
from .notifications import NotificationManager
__all__ = [
"NotificationManager",
"NotificationManagerClient",
]


@@ -1,6 +1,5 @@
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta, timezone
from typing import Callable
@@ -8,18 +7,20 @@ import aio_pika
from aio_pika.exceptions import QueueEmpty
from autogpt_libs.utils.cache import thread_cached
from prisma.enums import NotificationType
from pydantic import BaseModel
from backend.data import rabbitmq
from backend.data.notifications import (
BaseEventModel,
BaseSummaryData,
BaseSummaryParams,
DailySummaryData,
DailySummaryParams,
NotificationEventDTO,
NotificationEventModel,
NotificationResult,
NotificationTypeOverride,
QueueType,
SummaryParamsEventDTO,
SummaryParamsEventModel,
WeeklySummaryData,
WeeklySummaryParams,
@@ -27,178 +28,96 @@ from backend.data.notifications import (
get_notif_data_type,
get_summary_params_type,
)
from backend.data.rabbitmq import (
AsyncRabbitMQ,
Exchange,
ExchangeType,
Queue,
RabbitMQConfig,
SyncRabbitMQ,
)
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.data.user import generate_unsubscribe_link
from backend.notifications.email import EmailSender
from backend.util.metrics import discord_send_alert
from backend.util.service import (
AppService,
AppServiceClient,
expose,
get_service_client,
)
from backend.util.service import AppService, expose, get_service_client
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
NOTIFICATION_EXCHANGE = Exchange(name="notifications", type=ExchangeType.TOPIC)
DEAD_LETTER_EXCHANGE = Exchange(name="dead_letter", type=ExchangeType.TOPIC)
EXCHANGES = [NOTIFICATION_EXCHANGE, DEAD_LETTER_EXCHANGE]
background_executor = ThreadPoolExecutor(max_workers=2)
class NotificationEvent(BaseModel):
event: NotificationEventDTO
model: NotificationEventModel
def create_notification_config() -> RabbitMQConfig:
"""Create RabbitMQ configuration for notifications"""
notification_exchange = Exchange(name="notifications", type=ExchangeType.TOPIC)
dead_letter_exchange = Exchange(name="dead_letter", type=ExchangeType.TOPIC)
queues = [
# Main notification queues
Queue(
name="immediate_notifications",
exchange=NOTIFICATION_EXCHANGE,
exchange=notification_exchange,
routing_key="notification.immediate.#",
arguments={
"x-dead-letter-exchange": DEAD_LETTER_EXCHANGE.name,
"x-dead-letter-exchange": dead_letter_exchange.name,
"x-dead-letter-routing-key": "failed.immediate",
},
),
Queue(
name="admin_notifications",
exchange=NOTIFICATION_EXCHANGE,
exchange=notification_exchange,
routing_key="notification.admin.#",
arguments={
"x-dead-letter-exchange": DEAD_LETTER_EXCHANGE.name,
"x-dead-letter-exchange": dead_letter_exchange.name,
"x-dead-letter-routing-key": "failed.admin",
},
),
# Summary notification queues
Queue(
name="summary_notifications",
exchange=NOTIFICATION_EXCHANGE,
exchange=notification_exchange,
routing_key="notification.summary.#",
arguments={
"x-dead-letter-exchange": DEAD_LETTER_EXCHANGE.name,
"x-dead-letter-exchange": dead_letter_exchange.name,
"x-dead-letter-routing-key": "failed.summary",
},
),
# Batch Queue
Queue(
name="batch_notifications",
exchange=NOTIFICATION_EXCHANGE,
exchange=notification_exchange,
routing_key="notification.batch.#",
arguments={
"x-dead-letter-exchange": DEAD_LETTER_EXCHANGE.name,
"x-dead-letter-exchange": dead_letter_exchange.name,
"x-dead-letter-routing-key": "failed.batch",
},
),
# Failed notifications queue
Queue(
name="failed_notifications",
exchange=DEAD_LETTER_EXCHANGE,
exchange=dead_letter_exchange,
routing_key="failed.#",
),
]
return RabbitMQConfig(
exchanges=EXCHANGES,
exchanges=[
notification_exchange,
dead_letter_exchange,
],
queues=queues,
)
@thread_cached
def get_scheduler():
from backend.executor import Scheduler
return get_service_client(Scheduler)
@thread_cached
def get_db():
from backend.executor.database import DatabaseManagerClient
from backend.executor.database import DatabaseManager
return get_service_client(DatabaseManagerClient)
@thread_cached
def get_notification_queue() -> SyncRabbitMQ:
client = SyncRabbitMQ(create_notification_config())
client.connect()
return client
@thread_cached
async def get_async_notification_queue() -> AsyncRabbitMQ:
client = AsyncRabbitMQ(create_notification_config())
await client.connect()
return client
def get_routing_key(event_type: NotificationType) -> str:
strategy = NotificationTypeOverride(event_type).strategy
"""Get the appropriate routing key for an event"""
if strategy == QueueType.IMMEDIATE:
return f"notification.immediate.{event_type.value}"
elif strategy == QueueType.BACKOFF:
return f"notification.backoff.{event_type.value}"
elif strategy == QueueType.ADMIN:
return f"notification.admin.{event_type.value}"
elif strategy == QueueType.BATCH:
return f"notification.batch.{event_type.value}"
elif strategy == QueueType.SUMMARY:
return f"notification.summary.{event_type.value}"
return f"notification.{event_type.value}"
def queue_notification(event: NotificationEventModel) -> NotificationResult:
"""Queue a notification - exposed method for other services to call"""
try:
logger.debug(f"Received Request to queue {event=}")
exchange = "notifications"
routing_key = get_routing_key(event.type)
queue = get_notification_queue()
queue.publish_message(
routing_key=routing_key,
message=event.model_dump_json(),
exchange=next(ex for ex in EXCHANGES if ex.name == exchange),
)
return NotificationResult(
success=True,
message=f"Notification queued with routing key: {routing_key}",
)
except Exception as e:
logger.exception(f"Error queueing notification: {e}")
return NotificationResult(success=False, message=str(e))
async def queue_notification_async(event: NotificationEventModel) -> NotificationResult:
"""Queue a notification - exposed method for other services to call"""
try:
logger.debug(f"Received Request to queue {event=}")
exchange = "notifications"
routing_key = get_routing_key(event.type)
queue = await get_async_notification_queue()
await queue.publish_message(
routing_key=routing_key,
message=event.model_dump_json(),
exchange=next(ex for ex in EXCHANGES if ex.name == exchange),
)
return NotificationResult(
success=True,
message=f"Notification queued with routing key: {routing_key}",
)
except Exception as e:
logger.exception(f"Error queueing notification: {e}")
return NotificationResult(success=False, message=str(e))
return get_service_client(DatabaseManager)
class NotificationManager(AppService):
@@ -228,11 +147,23 @@ class NotificationManager(AppService):
def get_port(cls) -> int:
return settings.config.notification_service_port
def get_routing_key(self, event_type: NotificationType) -> str:
strategy = NotificationTypeOverride(event_type).strategy
"""Get the appropriate routing key for an event"""
if strategy == QueueType.IMMEDIATE:
return f"notification.immediate.{event_type.value}"
elif strategy == QueueType.BACKOFF:
return f"notification.backoff.{event_type.value}"
elif strategy == QueueType.ADMIN:
return f"notification.admin.{event_type.value}"
elif strategy == QueueType.BATCH:
return f"notification.batch.{event_type.value}"
elif strategy == QueueType.SUMMARY:
return f"notification.summary.{event_type.value}"
return f"notification.{event_type.value}"
@expose
def queue_weekly_summary(self):
background_executor.submit(self._queue_weekly_summary)
def _queue_weekly_summary(self):
"""Process weekly summary for specified notification types"""
try:
logger.info("Processing weekly summary queuing operation")
@@ -246,13 +177,13 @@ class NotificationManager(AppService):
for user in users:
self._queue_scheduled_notification(
SummaryParamsEventModel(
SummaryParamsEventDTO(
user_id=user,
type=NotificationType.WEEKLY_SUMMARY,
data=WeeklySummaryParams(
start_date=start_time,
end_date=current_time,
),
).model_dump(),
),
)
processed_count += 1
@@ -264,9 +195,6 @@ class NotificationManager(AppService):
@expose
def process_existing_batches(self, notification_types: list[NotificationType]):
background_executor.submit(self._process_existing_batches, notification_types)
def _process_existing_batches(self, notification_types: list[NotificationType]):
"""Process existing batches for specified notification types"""
try:
processed_count = 0
@@ -386,23 +314,65 @@ class NotificationManager(AppService):
}
@expose
def discord_system_alert(self, content: str):
discord_send_alert(content)
def _queue_scheduled_notification(self, event: SummaryParamsEventModel):
"""Queue a scheduled notification - exposed method for other services to call"""
def queue_notification(self, event: NotificationEventDTO) -> NotificationResult:
"""Queue a notification - exposed method for other services to call"""
try:
logger.debug(f"Received Request to queue scheduled notification {event=}")
logger.info(f"Received Request to queue {event=}")
# Workaround for not being able to serialize generics over the expose bus
parsed_event = NotificationEventModel[
get_notif_data_type(event.type)
].model_validate(event.model_dump())
routing_key = self.get_routing_key(parsed_event.type)
message = parsed_event.model_dump_json()
logger.info(f"Received Request to queue {message=}")
exchange = "notifications"
routing_key = get_routing_key(event.type)
# Publish to RabbitMQ
self.run_and_wait(
self.rabbit.publish_message(
routing_key=routing_key,
message=event.model_dump_json(),
exchange=next(ex for ex in EXCHANGES if ex.name == exchange),
message=message,
exchange=next(
ex for ex in self.rabbit_config.exchanges if ex.name == exchange
),
)
)
return NotificationResult(
success=True,
message=f"Notification queued with routing key: {routing_key}",
)
except Exception as e:
logger.exception(f"Error queueing notification: {e}")
return NotificationResult(success=False, message=str(e))
def _queue_scheduled_notification(self, event: SummaryParamsEventDTO):
"""Queue a scheduled notification - exposed method for other services to call"""
try:
logger.info(f"Received Request to queue scheduled notification {event=}")
parsed_event = SummaryParamsEventModel[
get_summary_params_type(event.type)
].model_validate(event.model_dump())
routing_key = self.get_routing_key(event.type)
message = parsed_event.model_dump_json()
logger.info(f"Received Request to queue {message=}")
exchange = "notifications"
# Publish to RabbitMQ
self.run_and_wait(
self.rabbit.publish_message(
routing_key=routing_key,
message=message,
exchange=next(
ex for ex in self.rabbit_config.exchanges if ex.name == exchange
),
)
)
@@ -528,12 +498,13 @@ class NotificationManager(AppService):
)
return False
def _parse_message(self, message: str) -> NotificationEventModel | None:
def _parse_message(self, message: str) -> NotificationEvent | None:
try:
event = BaseEventModel.model_validate_json(message)
return NotificationEventModel[
event = NotificationEventDTO.model_validate_json(message)
model = NotificationEventModel[
get_notif_data_type(event.type)
].model_validate_json(message)
return NotificationEvent(event=event, model=model)
except Exception as e:
logger.error(f"Error parsing message due to non matching schema {e}")
return None
@@ -541,12 +512,14 @@ class NotificationManager(AppService):
def _process_admin_message(self, message: str) -> bool:
"""Process a single notification, sending to an admin, returning whether to put into the failed queue"""
try:
event = self._parse_message(message)
if not event:
parsed = self._parse_message(message)
if not parsed:
return False
logger.debug(f"Processing notification for admin: {event}")
event = parsed.event
model = parsed.model
logger.debug(f"Processing notification for admin: {model}")
recipient_email = settings.config.refund_notification_email
self.email_sender.send_templated(event.type, recipient_email, event)
self.email_sender.send_templated(event.type, recipient_email, model)
return True
except Exception as e:
logger.exception(f"Error processing notification for admin queue: {e}")
@@ -555,10 +528,12 @@ class NotificationManager(AppService):
def _process_immediate(self, message: str) -> bool:
"""Process a single notification immediately, returning whether to put into the failed queue"""
try:
event = self._parse_message(message)
if not event:
parsed = self._parse_message(message)
if not parsed:
return False
logger.debug(f"Processing immediate notification: {event}")
event = parsed.event
model = parsed.model
logger.debug(f"Processing immediate notification: {model}")
recipient_email = get_db().get_user_email_by_id(event.user_id)
if not recipient_email:
@@ -579,7 +554,7 @@ class NotificationManager(AppService):
self.email_sender.send_templated(
notification=event.type,
user_email=recipient_email,
data=event,
data=model,
user_unsub_link=unsub_link,
)
return True
@@ -590,10 +565,12 @@ class NotificationManager(AppService):
def _process_batch(self, message: str) -> bool:
"""Process a single notification with a batching strategy, returning whether to put into the failed queue"""
try:
event = self._parse_message(message)
if not event:
parsed = self._parse_message(message)
if not parsed:
return False
logger.info(f"Processing batch notification: {event}")
event = parsed.event
model = parsed.model
logger.info(f"Processing batch notification: {model}")
recipient_email = get_db().get_user_email_by_id(event.user_id)
if not recipient_email:
@@ -609,7 +586,7 @@ class NotificationManager(AppService):
)
return True
should_send = self._should_batch(event.user_id, event.type, event)
should_send = self._should_batch(event.user_id, event.type, model)
if not should_send:
logger.info("Batch not old enough to send")
@@ -651,7 +628,7 @@ class NotificationManager(AppService):
"""Process a single notification with a summary strategy, returning whether to put into the failed queue"""
try:
logger.info(f"Processing summary notification: {message}")
event = BaseEventModel.model_validate_json(message)
event = SummaryParamsEventDTO.model_validate_json(message)
model = SummaryParamsEventModel[
get_summary_params_type(event.type)
].model_validate_json(message)
@@ -732,6 +709,22 @@ class NotificationManager(AppService):
logger.info(f"[{self.service_name}] Started notification service")
# Set up scheduler for batch processing of all notification types
# this can be changed later to spawn different cleanups on different schedules
try:
get_scheduler().add_batched_notification_schedule(
notification_types=list(NotificationType),
data={},
cron="0 * * * *",
)
# get_scheduler().add_weekly_notification_schedule(
# # weekly on Friday at 12pm
# cron="0 12 * * 5",
# )
logger.info("Scheduled notification cleanup")
except Exception as e:
logger.error(f"Error scheduling notification cleanup: {e}")
# Set up queue consumers
channel = self.run_and_wait(self.rabbit.get_channel())
@@ -781,13 +774,3 @@ class NotificationManager(AppService):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Disconnecting RabbitMQ...")
self.run_and_wait(self.rabbitmq_service.disconnect())
class NotificationManagerClient(AppServiceClient):
@classmethod
def get_service_type(cls):
return NotificationManager
process_existing_batches = NotificationManager.process_existing_batches
queue_weekly_summary = NotificationManager.queue_weekly_summary
discord_system_alert = NotificationManager.discord_system_alert

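`get_routing_key` above maps each `QueueType` strategy to a topic routing key, which the queue bindings (`notification.immediate.#`, `notification.admin.#`, `notification.batch.#`, …) then match. A simplified sketch, with a stand-in `QueueType` enum and the strategy passed in directly rather than derived via `NotificationTypeOverride`:

```python
from enum import Enum

class QueueType(Enum):
    # Stand-in for backend.data.notifications.QueueType
    IMMEDIATE = "immediate"
    BACKOFF = "backoff"
    ADMIN = "admin"
    BATCH = "batch"
    SUMMARY = "summary"

def get_routing_key(strategy: QueueType, event_value: str) -> str:
    """Map a delivery strategy to a topic routing key.

    A key like "notification.batch.AGENT_RUN" matches the
    "notification.batch.#" binding of the batch_notifications queue.
    """
    prefixes = {
        QueueType.IMMEDIATE: "immediate",
        QueueType.BACKOFF: "backoff",
        QueueType.ADMIN: "admin",
        QueueType.BATCH: "batch",
        QueueType.SUMMARY: "summary",
    }
    prefix = prefixes.get(strategy)
    if prefix:
        return f"notification.{prefix}.{event_value}"
    return f"notification.{event_value}"

assert get_routing_key(QueueType.BATCH, "AGENT_RUN") == "notification.batch.AGENT_RUN"
```

Note that the queues in the diff also set `x-dead-letter-exchange`, so a message rejected from any of these queues is re-routed to the `failed.*` keys on the dead-letter exchange.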

@@ -1,5 +1,5 @@
from backend.app import run_processes
from backend.executor import DatabaseManager
from backend.executor import DatabaseManager, Scheduler
from backend.notifications.notifications import NotificationManager
from backend.server.rest_api import AgentServer
@@ -11,6 +11,7 @@ def main():
run_processes(
NotificationManager(),
DatabaseManager(),
Scheduler(),
AgentServer(),
)


@@ -1,13 +0,0 @@
from backend.app import run_processes
from backend.executor.scheduler import Scheduler
def main():
"""
Run all the processes required for the AutoGPT-server Scheduling System.
"""
run_processes(Scheduler())
if __name__ == "__main__":
main()


@@ -122,7 +122,7 @@ async def get_graph_execution_results(
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
results = await execution_db.get_node_executions(graph_exec_id)
results = await execution_db.get_node_execution_results(graph_exec_id)
last_result = results[-1] if results else None
execution_status = (
last_result.status if last_result else AgentExecutionStatus.INCOMPLETE


@@ -19,7 +19,6 @@ import backend.data.graph
import backend.data.user
import backend.server.routers.postmark.postmark
import backend.server.routers.v1
import backend.server.v2.admin.credit_admin_routes
import backend.server.v2.admin.store_admin_routes
import backend.server.v2.library.db
import backend.server.v2.library.model
@@ -27,7 +26,6 @@ import backend.server.v2.library.routes
import backend.server.v2.otto.routes
import backend.server.v2.store.model
import backend.server.v2.store.routes
import backend.server.v2.turnstile.routes
import backend.util.service
import backend.util.settings
from backend.blocks.llm import LlmModel
@@ -109,20 +107,12 @@ app.include_router(
tags=["v2", "admin"],
prefix="/api/store",
)
app.include_router(
backend.server.v2.admin.credit_admin_routes.router,
tags=["v2", "admin"],
prefix="/api/credits",
)
app.include_router(
backend.server.v2.library.routes.router, tags=["v2"], prefix="/api/library"
)
app.include_router(
backend.server.v2.otto.routes.router, tags=["v2"], prefix="/api/otto"
)
app.include_router(
backend.server.v2.turnstile.routes.router, tags=["v2"], prefix="/api/turnstile"
)
app.include_router(
backend.server.routers.postmark.postmark.router,


@@ -57,7 +57,7 @@ from backend.data.user import (
update_user_email,
update_user_notification_preference,
)
from backend.executor import scheduler
from backend.executor import Scheduler, scheduler
from backend.executor import utils as execution_utils
from backend.executor.utils import create_execution_queue_config
from backend.integrations.creds_manager import IntegrationCredentialsManager
@@ -83,8 +83,8 @@ if TYPE_CHECKING:
@thread_cached
def execution_scheduler_client() -> scheduler.SchedulerClient:
return get_service_client(scheduler.SchedulerClient)
def execution_scheduler_client() -> Scheduler:
return get_service_client(Scheduler)
@thread_cached
@@ -422,11 +422,7 @@ async def get_graph(
for_export: bool = False,
) -> graph_db.GraphModel:
graph = await graph_db.get_graph(
graph_id,
version,
user_id=user_id,
for_export=for_export,
include_subgraphs=True, # needed to construct full credentials input schema
graph_id, version, user_id=user_id, for_export=for_export
)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
@@ -667,7 +663,7 @@ async def _cancel_execution(graph_exec_id: str):
)
node_execs = [
node_exec.model_copy(update={"status": execution_db.ExecutionStatus.TERMINATED})
for node_exec in await execution_db.get_node_executions(
for node_exec in await execution_db.get_node_execution_results(
graph_exec_id=graph_exec_id,
statuses=[
execution_db.ExecutionStatus.QUEUED,
@@ -773,7 +769,7 @@ class ScheduleCreationRequest(pydantic.BaseModel):
async def create_schedule(
user_id: Annotated[str, Depends(get_user_id)],
schedule: ScheduleCreationRequest,
) -> scheduler.GraphExecutionJobInfo:
) -> scheduler.ExecutionJobInfo:
graph = await graph_db.get_graph(
schedule.graph_id, schedule.graph_version, user_id=user_id
)
@@ -783,12 +779,14 @@ async def create_schedule(
detail=f"Graph #{schedule.graph_id} v.{schedule.graph_version} not found.",
)
return await execution_scheduler_client().add_execution_schedule(
graph_id=schedule.graph_id,
graph_version=graph.version,
cron=schedule.cron,
input_data=schedule.input_data,
user_id=user_id,
return await asyncio.to_thread(
lambda: execution_scheduler_client().add_execution_schedule(
graph_id=schedule.graph_id,
graph_version=graph.version,
cron=schedule.cron,
input_data=schedule.input_data,
user_id=user_id,
)
)
@@ -797,11 +795,11 @@ async def create_schedule(
tags=["schedules"],
dependencies=[Depends(auth_middleware)],
)
async def delete_schedule(
def delete_schedule(
schedule_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[Any, Any]:
await execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
return {"id": schedule_id}
@@ -810,11 +808,11 @@ async def delete_schedule(
tags=["schedules"],
dependencies=[Depends(auth_middleware)],
)
async def get_execution_schedules(
def get_execution_schedules(
user_id: Annotated[str, Depends(get_user_id)],
graph_id: str | None = None,
) -> list[scheduler.GraphExecutionJobInfo]:
return await execution_scheduler_client().get_execution_schedules(
) -> list[scheduler.ExecutionJobInfo]:
return execution_scheduler_client().get_execution_schedules(
user_id=user_id,
graph_id=graph_id,
)

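The `create_schedule` change above wraps the now-synchronous scheduler client in `asyncio.to_thread`, keeping the event loop free while the blocking call runs. A self-contained sketch of the pattern (the `add_execution_schedule` stub is hypothetical, standing in for the real Scheduler client method):

```python
import asyncio
import time

def add_execution_schedule(graph_id: str, cron: str) -> dict:
    # Hypothetical stand-in for the synchronous Scheduler client call;
    # sleeps briefly to imitate blocking I/O.
    time.sleep(0.01)
    return {"graph_id": graph_id, "cron": cron}

async def create_schedule(graph_id: str, cron: str) -> dict:
    # Run the blocking client call in a worker thread so the event
    # loop stays responsive, as the diff does with asyncio.to_thread.
    return await asyncio.to_thread(
        lambda: add_execution_schedule(graph_id, cron)
    )

result = asyncio.run(create_schedule("graph-123", "0 * * * *"))
assert result == {"graph_id": "graph-123", "cron": "0 * * * *"}
```

Without the wrapper, a slow synchronous client call inside an `async def` route would stall every other request being served by the same event loop.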

@@ -1,77 +0,0 @@
import logging
import typing
from autogpt_libs.auth import requires_admin_user
from autogpt_libs.auth.depends import get_user_id
from fastapi import APIRouter, Body, Depends
from prisma import Json
from prisma.enums import CreditTransactionType
from backend.data.credit import admin_get_user_history, get_user_credit_model
from backend.server.v2.admin.model import AddUserCreditsResponse, UserHistoryResponse
logger = logging.getLogger(__name__)
_user_credit_model = get_user_credit_model()
router = APIRouter(
prefix="/admin",
tags=["credits", "admin"],
dependencies=[Depends(requires_admin_user)],
)
@router.post("/add_credits", response_model=AddUserCreditsResponse)
async def add_user_credits(
user_id: typing.Annotated[str, Body()],
amount: typing.Annotated[int, Body()],
comments: typing.Annotated[str, Body()],
admin_user: typing.Annotated[
str,
Depends(get_user_id),
],
):
""" """
logger.info(f"Admin user {admin_user} is adding {amount} credits to user {user_id}")
new_balance, transaction_key = await _user_credit_model._add_transaction(
user_id,
amount,
transaction_type=CreditTransactionType.GRANT,
metadata=Json({"admin_id": admin_user, "reason": comments}),
)
return {
"new_balance": new_balance,
"transaction_key": transaction_key,
}
@router.get(
"/users_history",
response_model=UserHistoryResponse,
)
async def admin_get_all_user_history(
admin_user: typing.Annotated[
str,
Depends(get_user_id),
],
search: typing.Optional[str] = None,
page: int = 1,
page_size: int = 20,
transaction_filter: typing.Optional[CreditTransactionType] = None,
):
""" """
logger.info(f"Admin user {admin_user} is getting grant history")
try:
resp = await admin_get_user_history(
page=page,
page_size=page_size,
search=search,
transaction_filter=transaction_filter,
)
logger.info(f"Admin user {admin_user} got {len(resp.history)} grant history")
return resp
except Exception as e:
logger.exception(f"Error getting grant history: {e}")
raise e


@@ -1,16 +0,0 @@
from pydantic import BaseModel
from backend.data.model import UserTransaction
from backend.server.model import Pagination
class UserHistoryResponse(BaseModel):
"""Response model for listings with version history"""
history: list[UserTransaction]
pagination: Pagination
class AddUserCreditsResponse(BaseModel):
new_balance: int
transaction_key: str


@@ -1,5 +1,4 @@
import logging
import tempfile
import typing
import autogpt_libs.auth.depends
@@ -10,7 +9,6 @@ import prisma.enums
import backend.server.v2.store.db
import backend.server.v2.store.exceptions
import backend.server.v2.store.model
import backend.util.json
logger = logging.getLogger(__name__)
@@ -100,47 +98,3 @@ async def review_submission(
status_code=500,
content={"detail": "An error occurred while reviewing the submission"},
)
@router.get(
"/submissions/download/{store_listing_version_id}",
tags=["store", "admin"],
dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
)
async def admin_download_agent_file(
user: typing.Annotated[
autogpt_libs.auth.models.User,
fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user),
],
store_listing_version_id: str = fastapi.Path(
..., description="The ID of the agent to download"
),
) -> fastapi.responses.FileResponse:
"""
Download the agent file by streaming its content.
Args:
store_listing_version_id (str): The ID of the agent to download
Returns:
StreamingResponse: A streaming response containing the agent's graph data.
Raises:
HTTPException: If the agent is not found or an unexpected error occurs.
"""
graph_data = await backend.server.v2.store.db.get_agent(
user_id=user.user_id,
store_listing_version_id=store_listing_version_id,
)
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
# Sending graph as a stream (similar to marketplace v1)
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as tmp_file:
tmp_file.write(backend.util.json.dumps(graph_data))
tmp_file.flush()
return fastapi.responses.FileResponse(
tmp_file.name, filename=file_name, media_type="application/json"
)


@@ -13,17 +13,12 @@ import backend.server.v2.library.model as library_model
import backend.server.v2.store.exceptions as store_exceptions
import backend.server.v2.store.image_gen as store_image_gen
import backend.server.v2.store.media as store_media
from backend.data import db
from backend.data import graph as graph_db
from backend.data.db import locked_transaction
from backend.data.includes import library_agent_include
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
from backend.util.settings import Config
logger = logging.getLogger(__name__)
config = Config()
integration_creds_manager = IntegrationCredentialsManager()
async def list_library_agents(
@@ -175,44 +170,6 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
raise store_exceptions.DatabaseError("Failed to fetch library agent") from e
async def get_library_agent_by_store_version_id(
store_listing_version_id: str,
user_id: str,
):
"""
Get the library agent metadata for a given store listing version ID and user ID.
"""
logger.debug(
f"Getting library agent for store listing ID: {store_listing_version_id}"
)
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id},
)
)
if not store_listing_version:
logger.warning(f"Store listing version not found: {store_listing_version_id}")
raise store_exceptions.AgentNotFoundError(
f"Store listing version {store_listing_version_id} not found or invalid"
)
# Check if user already has this agent
agent = await prisma.models.LibraryAgent.prisma().find_first(
where={
"userId": user_id,
"agentGraphId": store_listing_version.agentGraphId,
"agentGraphVersion": store_listing_version.agentGraphVersion,
"isDeleted": False,
},
include={"AgentGraph": True},
)
if agent:
return library_model.LibraryAgent.from_db(agent)
else:
return None
async def add_generated_agent_image(
graph: backend.data.graph.GraphModel,
library_agent_id: str,
@@ -249,7 +206,7 @@ async def add_generated_agent_image(
async def create_library_agent(
graph: backend.data.graph.GraphModel,
user_id: str,
) -> library_model.LibraryAgent:
) -> prisma.models.LibraryAgent:
"""
Adds an agent to the user's library (LibraryAgent table).
@@ -270,7 +227,7 @@ async def create_library_agent(
)
try:
agent = await prisma.models.LibraryAgent.prisma().create(
return await prisma.models.LibraryAgent.prisma().create(
data=prisma.types.LibraryAgentCreateInput(
isCreatedByUser=(user_id == graph.user_id),
useGraphIsActiveVersion=True,
@@ -281,10 +238,8 @@ async def create_library_agent(
"graphVersionId": {"id": graph.id, "version": graph.version}
}
},
),
include={"AgentGraph": True},
)
)
return library_model.LibraryAgent.from_db(agent)
except prisma.errors.PrismaError as e:
logger.error(f"Database error creating agent in library: {e}")
raise store_exceptions.DatabaseError("Failed to create agent in library") from e
@@ -435,6 +390,11 @@ async def add_store_agent_to_library(
)
graph = store_listing_version.AgentGraph
if graph.userId == user_id:
logger.warning(
f"User #{user_id} attempted to add their own agent to their library"
)
raise store_exceptions.DatabaseError("Cannot add own agent to library")
# Check if user already has this agent
existing_library_agent = (
@@ -444,7 +404,7 @@ async def add_store_agent_to_library(
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
},
include={"AgentGraph": True},
include=library_agent_include(user_id),
)
)
if existing_library_agent:
@@ -702,47 +662,3 @@ async def delete_preset(user_id: str, preset_id: str) -> None:
except prisma.errors.PrismaError as e:
logger.error(f"Database error deleting preset: {e}")
raise store_exceptions.DatabaseError("Failed to delete preset") from e
async def fork_library_agent(library_agent_id: str, user_id: str):
"""
Clones a library agent and its underlying graph and nodes (with new ids) for the given user.
Args:
library_agent_id: The ID of the library agent to fork.
user_id: The ID of the user who owns the library agent.
Returns:
The forked LibraryAgent.
Raises:
DatabaseError: If there's an error during the forking process.
"""
logger.debug(f"Forking library agent {library_agent_id} for user {user_id}")
try:
async with db.locked_transaction(f"usr_trx_{user_id}-fork_agent"):
# Fetch the original agent
original_agent = await get_library_agent(library_agent_id, user_id)
# Check if user owns the library agent
# TODO: once we have open/closed sourced agents this needs to be enabled ~kcze
# + update library/agents/[id]/page.tsx agent actions
# if not original_agent.can_access_graph:
# raise store_exceptions.DatabaseError(
# f"User {user_id} cannot access library agent graph {library_agent_id}"
# )
# Fork the underlying graph and nodes
new_graph = await graph_db.fork_graph(
original_agent.graph_id, original_agent.graph_version, user_id
)
new_graph = await on_graph_activate(
new_graph,
get_credentials=lambda id: integration_creds_manager.get(user_id, id),
)
# Create a library agent for the new graph
return await create_library_agent(new_graph, user_id)
except prisma.errors.PrismaError as e:
logger.error(f"Database error cloning library agent: {e}")
raise store_exceptions.DatabaseError("Failed to fork library agent") from e
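
`fork_library_agent` above clones the underlying graph with fresh IDs before re-activating it and creating a library entry. A minimal standalone sketch of the re-identification step — the dict shape and helper name are hypothetical, not the actual `graph_db.fork_graph` implementation:

```python
import uuid


def fork_graph_dict(graph: dict) -> dict:
    """Return a copy of a graph dict with new IDs for the graph and its nodes.

    Hypothetical sketch: the real fork_graph operates on Prisma models, not dicts.
    """
    id_map = {graph["id"]: str(uuid.uuid4())}
    nodes = []
    for node in graph["nodes"]:
        id_map[node["id"]] = str(uuid.uuid4())
        nodes.append({**node, "id": id_map[node["id"]]})
    # Re-point edges at the re-issued node IDs so the copy stays self-consistent.
    edges = [
        {**e, "source": id_map[e["source"]], "sink": id_map[e["sink"]]}
        for e in graph.get("edges", [])
    ]
    # Assume a fork restarts at version 1 (an illustration, not confirmed behavior).
    return {**graph, "id": id_map[graph["id"]], "version": 1, "nodes": nodes, "edges": edges}
```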


@@ -22,13 +22,11 @@ async def test_agent_preset_from_db():
userId="test-user-123",
isDeleted=False,
InputPresets=[
prisma.models.AgentNodeExecutionInputOutput.model_validate(
{
"id": "input-123",
"time": datetime.datetime.now(),
"name": "input1",
"data": '{"type": "string", "value": "test value"}',
}
prisma.models.AgentNodeExecutionInputOutput(
id="input-123",
time=datetime.datetime.now(),
name="input1",
data=prisma.Json({"type": "string", "value": "test value"}),
)
],
)


@@ -85,30 +85,6 @@ async def get_library_agent(
return await library_db.get_library_agent(id=library_agent_id, user_id=user_id)
@router.get(
"/marketplace/{store_listing_version_id}/",
tags=["store, library"],
response_model=library_model.LibraryAgent | None,
)
async def get_library_agent_by_store_listing_version_id(
store_listing_version_id: str,
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
):
"""
Get Library Agent from Store Listing Version ID.
"""
try:
return await library_db.get_library_agent_by_store_version_id(
store_listing_version_id, user_id
)
except Exception as e:
logger.error(f"Could not fetch library agent from store version ID: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to add agent to library",
) from e
@router.post(
"",
status_code=status.HTTP_201_CREATED,
@@ -214,14 +190,3 @@ async def update_library_agent(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update library agent",
) from e
@router.post("/{library_agent_id}/fork")
async def fork_library_agent(
library_agent_id: str,
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> library_model.LibraryAgent:
return await library_db.fork_library_agent(
library_agent_id=library_agent_id,
user_id=user_id,
)


@@ -793,7 +793,6 @@ async def create_store_version(
changes_summary=changes_summary,
version=next_version,
)
except prisma.errors.PrismaError as e:
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to create new store version"
@@ -967,7 +966,7 @@ async def get_my_agents(
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
where=search_filter,
order=[{"updatedAt": "desc"}],
order=[{"agentGraphVersion": "desc"}],
skip=(page - 1) * page_size,
take=page_size,
include={"AgentGraph": True},
@@ -1362,31 +1361,3 @@ async def get_admin_listings_with_versions(
page_size=page_size,
),
)
async def get_agent_as_admin(
user_id: str | None,
store_listing_version_id: str,
) -> GraphModel:
"""Get agent using the version ID and store listing version ID."""
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id}
)
)
if not store_listing_version:
raise ValueError(f"Store listing version {store_listing_version_id} not found")
graph = await backend.data.graph.get_graph_as_admin(
user_id=user_id,
graph_id=store_listing_version.agentGraphId,
version=store_listing_version.agentGraphVersion,
for_export=True,
)
if not graph:
raise ValueError(
f"Agent {store_listing_version.agentGraphId} v{store_listing_version.agentGraphVersion} not found"
)
return graph


@@ -4,7 +4,20 @@ from typing import List
import prisma.enums
import pydantic
from backend.server.model import Pagination
class Pagination(pydantic.BaseModel):
total_items: int = pydantic.Field(
description="Total number of items.", examples=[42]
)
total_pages: int = pydantic.Field(
description="Total number of pages.", examples=[97]
)
current_page: int = pydantic.Field(
description="Current page number.", examples=[1]
)
page_size: int = pydantic.Field(
description="Number of items per page.", examples=[25]
)
class MyAgent(pydantic.BaseModel):
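
The inlined `Pagination` model above duplicates `backend.server.model.Pagination`, whose removal from the import this hunk shows. Its fields hang together via ceil division; a sketch of how such a payload might be produced (the `paginate` helper is hypothetical, shown with a plain dict for brevity):

```python
import math


def paginate(total_items: int, current_page: int, page_size: int) -> dict:
    """Build a Pagination-shaped payload; total_pages is ceil(total_items / page_size)."""
    return {
        "total_items": total_items,
        "total_pages": math.ceil(total_items / page_size) if page_size else 0,
        "current_page": current_page,
        "page_size": page_size,
    }
```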


@@ -1,30 +0,0 @@
from typing import Optional
from pydantic import BaseModel, Field
class TurnstileVerifyRequest(BaseModel):
"""Request model for verifying a Turnstile token."""
token: str = Field(description="The Turnstile token to verify")
action: Optional[str] = Field(
default=None, description="The action that the user is attempting to perform"
)
class TurnstileVerifyResponse(BaseModel):
"""Response model for the Turnstile verification endpoint."""
success: bool = Field(description="Whether the token verification was successful")
error: Optional[str] = Field(
default=None, description="Error message if verification failed"
)
challenge_timestamp: Optional[str] = Field(
default=None, description="Timestamp of the challenge (ISO format)"
)
hostname: Optional[str] = Field(
default=None, description="Hostname of the site where the challenge was solved"
)
action: Optional[str] = Field(
default=None, description="The action associated with this verification"
)


@@ -1,108 +0,0 @@
import logging
import aiohttp
from fastapi import APIRouter
from backend.util.settings import Settings
from .models import TurnstileVerifyRequest, TurnstileVerifyResponse
logger = logging.getLogger(__name__)
router = APIRouter()
settings = Settings()
@router.post("/verify", response_model=TurnstileVerifyResponse)
async def verify_turnstile_token(
request: TurnstileVerifyRequest,
) -> TurnstileVerifyResponse:
"""
Verify a Cloudflare Turnstile token.
This endpoint verifies a token returned by the Cloudflare Turnstile challenge
on the client side. It returns whether the verification was successful.
"""
logger.info(f"Verifying Turnstile token for action: {request.action}")
return await verify_token(request)
async def verify_token(request: TurnstileVerifyRequest) -> TurnstileVerifyResponse:
"""
Verify a Cloudflare Turnstile token by making a request to the Cloudflare API.
"""
# Get the secret key from settings
turnstile_secret_key = settings.secrets.turnstile_secret_key
turnstile_verify_url = settings.secrets.turnstile_verify_url
if not turnstile_secret_key:
logger.error("Turnstile secret key is not configured")
return TurnstileVerifyResponse(
success=False,
error="CONFIGURATION_ERROR",
challenge_timestamp=None,
hostname=None,
action=None,
)
try:
async with aiohttp.ClientSession() as session:
payload = {
"secret": turnstile_secret_key,
"response": request.token,
}
if request.action:
payload["action"] = request.action
logger.debug(f"Verifying Turnstile token with action: {request.action}")
async with session.post(
turnstile_verify_url,
data=payload,
timeout=aiohttp.ClientTimeout(total=10),
) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"Turnstile API error: {error_text}")
return TurnstileVerifyResponse(
success=False,
error=f"API_ERROR: {response.status}",
challenge_timestamp=None,
hostname=None,
action=None,
)
data = await response.json()
logger.debug(f"Turnstile API response: {data}")
# Parse the response and return a structured object
return TurnstileVerifyResponse(
success=data.get("success", False),
error=(
data.get("error-codes", None)[0]
if data.get("error-codes")
else None
),
challenge_timestamp=data.get("challenge_timestamp"),
hostname=data.get("hostname"),
action=data.get("action"),
)
except aiohttp.ClientError as e:
logger.error(f"Connection error to Turnstile API: {str(e)}")
return TurnstileVerifyResponse(
success=False,
error=f"CONNECTION_ERROR: {str(e)}",
challenge_timestamp=None,
hostname=None,
action=None,
)
except Exception as e:
logger.error(f"Unexpected error in Turnstile verification: {str(e)}")
return TurnstileVerifyResponse(
success=False,
error=f"UNEXPECTED_ERROR: {str(e)}",
challenge_timestamp=None,
hostname=None,
action=None,
)
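
The removed `verify_token` above maps Cloudflare's siteverify JSON onto `TurnstileVerifyResponse`, taking the first entry of `error-codes` when present. That parsing step can be sketched in isolation — a standalone parser mirroring the logic, not the actual endpoint:

```python
from typing import Optional


def parse_siteverify(data: dict) -> tuple[bool, Optional[str]]:
    """Extract (success, first error code) from a Cloudflare siteverify payload."""
    success = data.get("success", False)
    codes = data.get("error-codes")
    return success, (codes[0] if codes else None)
```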


@@ -6,6 +6,7 @@ from typing import Protocol
import uvicorn
from autogpt_libs.auth import parse_jwt_token
from autogpt_libs.logging.utils import generate_uvicorn_config
from autogpt_libs.utils.cache import thread_cached
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
from starlette.middleware.cors import CORSMiddleware
@@ -18,7 +19,7 @@ from backend.server.model import (
WSSubscribeGraphExecutionRequest,
WSSubscribeGraphExecutionsRequest,
)
from backend.util.service import AppProcess
from backend.util.service import AppProcess, get_service_client
from backend.util.settings import AppEnvironment, Config, Settings
logger = logging.getLogger(__name__)
@@ -45,6 +46,13 @@ def get_connection_manager():
return _connection_manager
@thread_cached
def get_db_client():
from backend.executor import DatabaseManager
return get_service_client(DatabaseManager)
async def event_broadcaster(manager: ConnectionManager):
try:
event_queue = AsyncRedisExecutionEventBus()


@@ -1,8 +1,6 @@
import asyncio
import logging
import sentry_sdk
from pydantic import SecretStr
from sentry_sdk.integrations.anthropic import AnthropicIntegration
from sentry_sdk.integrations.logging import LoggingIntegration
@@ -24,43 +22,3 @@ def sentry_init():
),
],
)
def sentry_capture_error(error: Exception):
sentry_sdk.capture_exception(error)
sentry_sdk.flush()
def discord_send_alert(content: str):
from backend.blocks.discord import SendDiscordMessageBlock
from backend.data.model import APIKeyCredentials, CredentialsMetaInput, ProviderName
from backend.util.settings import Settings
settings = Settings()
creds = APIKeyCredentials(
provider="discord",
api_key=SecretStr(settings.secrets.discord_bot_token),
title="Provide Discord Bot Token for the platform alert",
expires_at=None,
)
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return SendDiscordMessageBlock().run_once(
SendDiscordMessageBlock.Input(
credentials=CredentialsMetaInput(
id=creds.id,
title=creds.title,
type=creds.type,
provider=ProviderName.DISCORD,
),
message_content=content,
channel_name=settings.config.platform_alert_discord_channel,
),
"status",
credentials=creds,
)


@@ -3,7 +3,7 @@ import os
import signal
import sys
from abc import ABC, abstractmethod
from multiprocessing import Process, get_all_start_methods, set_start_method
from multiprocessing import Process, set_start_method
from typing import Optional
from backend.util.logging import configure_logging
@@ -30,12 +30,7 @@ class AppProcess(ABC):
process: Optional[Process] = None
cleaned_up = False
if "forkserver" in get_all_start_methods():
set_start_method("forkserver", force=True)
else:
logger.warning("Forkserver start method is not available. Using spawn instead.")
set_start_method("spawn", force=True)
set_start_method("spawn", force=True)
configure_logging()
sentry_init()
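
The hunk above drops the forkserver preference and unconditionally forces spawn. The removed fallback amounts to a pure selection over the available start methods (hypothetical helper name, shown for contrast with the new unconditional `set_start_method("spawn", force=True)`):

```python
def pick_start_method(available: list[str]) -> str:
    """Old behavior: prefer forkserver when the platform offers it, else spawn."""
    return "forkserver" if "forkserver" in available else "spawn"
```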


@@ -73,10 +73,3 @@ def conn_retry(
return async_wrapper if is_coroutine else sync_wrapper
return decorator
func_retry = retry(
reraise=False,
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=1, max=30),
)
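
The deleted `func_retry` combined `stop_after_attempt(5)` with `wait_exponential(multiplier=1, min=1, max=30)`. A sketch of the wait schedule that configuration implies — a hand-rolled approximation of tenacity's exponential formula, not tenacity itself, and the exact exponent indexing is an assumption:

```python
def backoff_waits(attempts: int, multiplier: float = 1, lo: float = 1, hi: float = 30) -> list[float]:
    """Exponential waits multiplier * 2**n, clamped to [lo, hi], one per retry."""
    return [min(hi, max(lo, multiplier * 2**n)) for n in range(attempts)]
```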


@@ -5,10 +5,8 @@ import os
import threading
import time
from abc import ABC, abstractmethod
from functools import cached_property, update_wrapper
from typing import (
Any,
Awaitable,
Callable,
Concatenate,
Coroutine,
@@ -44,15 +42,24 @@ api_call_timeout = config.rpc_client_call_timeout
P = ParamSpec("P")
R = TypeVar("R")
EXPOSED_FLAG = "__exposed__"
def expose(func: C) -> C:
func = getattr(func, "__func__", func)
setattr(func, EXPOSED_FLAG, True)
setattr(func, "__exposed__", True)
return func
def exposed_run_and_wait(
f: Callable[P, Coroutine[None, None, R]]
) -> Callable[Concatenate[object, P], R]:
# TODO:
# This function lies about its return type to make the DynamicClient
# call the function synchronously, fix this when DynamicClient can choose
# to call a function synchronously or asynchronously.
return expose(f) # type: ignore
# --------------------------------------------------
# AppService for IPC service based on HTTP request through FastAPI
# --------------------------------------------------
@@ -196,7 +203,7 @@ class AppService(BaseAppService, ABC):
# Register the exposed API routes.
for attr_name, attr in vars(type(self)).items():
if getattr(attr, EXPOSED_FLAG, False):
if getattr(attr, "__exposed__", False):
route_path = f"/{attr_name}"
self.fastapi_app.add_api_route(
route_path,
@@ -227,53 +234,31 @@ class AppService(BaseAppService, ABC):
AS = TypeVar("AS", bound=AppService)
class AppServiceClient(ABC):
@classmethod
@abstractmethod
def get_service_type(cls) -> Type[AppService]:
pass
def health_check(self):
pass
def close(self):
pass
def close_service_client(client: Any) -> None:
if hasattr(client, "close"):
client.close()
else:
logger.warning(f"Client {client} is not closable")
ASC = TypeVar("ASC", bound=AppServiceClient)
@conn_retry("AppService client", "Creating service client", max_retry=api_comm_retry)
@conn_retry("FastAPI client", "Creating service client", max_retry=api_comm_retry)
def get_service_client(
service_client_type: Type[ASC],
service_type: Type[AS],
call_timeout: int | None = api_call_timeout,
health_check: bool = True,
) -> ASC:
) -> AS:
class DynamicClient:
def __init__(self):
service_type = service_client_type.get_service_type()
host = service_type.get_host()
port = service_type.get_port()
self.base_url = f"http://{host}:{port}".rstrip("/")
@cached_property
def sync_client(self) -> httpx.Client:
return httpx.Client(
self.client = httpx.Client(
base_url=self.base_url,
timeout=call_timeout,
)
@cached_property
def async_client(self) -> httpx.AsyncClient:
return httpx.AsyncClient(
base_url=self.base_url,
timeout=call_timeout,
)
def _handle_call_method_response(
self, response: httpx.Response, method_name: str
) -> Any:
def _call_method(self, method_name: str, **kwargs) -> Any:
try:
response = self.client.post(method_name, json=to_dict(kwargs))
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
@@ -284,103 +269,36 @@ def get_service_client(
*(error.args or [str(e)])
)
def _call_method_sync(self, method_name: str, **kwargs) -> Any:
return self._handle_call_method_response(
method_name=method_name,
response=self.sync_client.post(method_name, json=to_dict(kwargs)),
)
async def _call_method_async(self, method_name: str, **kwargs) -> Any:
return self._handle_call_method_response(
method_name=method_name,
response=await self.async_client.post(
method_name, json=to_dict(kwargs)
),
)
async def aclose(self):
self.sync_client.close()
await self.async_client.aclose()
def close(self):
self.sync_client.close()
def _get_params(self, signature: inspect.Signature, *args, **kwargs) -> dict:
if args:
arg_names = list(signature.parameters.keys())
if arg_names[0] in ("self", "cls"):
arg_names = arg_names[1:]
kwargs.update(dict(zip(arg_names, args)))
return kwargs
def _get_return(self, expected_return: TypeAdapter | None, result: Any) -> Any:
if expected_return:
return expected_return.validate_python(result)
return result
self.client.close()
def __getattr__(self, name: str) -> Callable[..., Any]:
original_func = getattr(service_client_type, name, None)
if original_func is None:
raise AttributeError(
f"Method {name} not found in {service_client_type}"
)
else:
name = original_func.__name__
# Try to get the original function from the service type.
orig_func = getattr(service_type, name, None)
if orig_func is None:
raise AttributeError(f"Method {name} not found in {service_type}")
sig = inspect.signature(original_func)
sig = inspect.signature(orig_func)
ret_ann = sig.return_annotation
if ret_ann != inspect.Signature.empty:
expected_return = TypeAdapter(ret_ann)
else:
expected_return = None
if inspect.iscoroutinefunction(original_func):
def method(*args, **kwargs) -> Any:
if args:
arg_names = list(sig.parameters.keys())
if arg_names[0] in ("self", "cls"):
arg_names = arg_names[1:]
kwargs.update(dict(zip(arg_names, args)))
result = self._call_method(name, **kwargs)
if expected_return:
return expected_return.validate_python(result)
return result
async def async_method(*args, **kwargs) -> Any:
params = self._get_params(sig, *args, **kwargs)
result = await self._call_method_async(name, **params)
return self._get_return(expected_return, result)
return method
return async_method
else:
client = cast(AS, DynamicClient())
client.health_check()
def sync_method(*args, **kwargs) -> Any:
params = self._get_params(sig, *args, **kwargs)
result = self._call_method_sync(name, **params)
return self._get_return(expected_return, result)
return sync_method
client = cast(ASC, DynamicClient())
if health_check:
client.health_check()
return client
def endpoint_to_sync(
func: Callable[Concatenate[Any, P], Awaitable[R]],
) -> Callable[Concatenate[Any, P], R]:
"""
Produce a *typed* stub that **looks** synchronous to the typechecker.
"""
def _stub(*args: P.args, **kwargs: P.kwargs) -> R: # pragma: no cover
raise RuntimeError("should be intercepted by __getattr__")
update_wrapper(_stub, func)
return cast(Callable[Concatenate[Any, P], R], _stub)
def endpoint_to_async(
func: Callable[Concatenate[Any, P], R],
) -> Callable[Concatenate[Any, P], Awaitable[R]]:
"""
The async mirror of `endpoint_to_sync`.
"""
async def _stub(*args: P.args, **kwargs: P.kwargs) -> R: # pragma: no cover
raise RuntimeError("should be intercepted by __getattr__")
update_wrapper(_stub, func)
return cast(Callable[Concatenate[Any, P], Awaitable[R]], _stub)
return cast(AS, client)
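
The rewritten `get_service_client` builds a `DynamicClient` whose `__getattr__` turns attribute access into an HTTP POST against the service, binding positional arguments to parameter names from the service method's signature. A minimal in-memory sketch of that proxy pattern, with the transport injected so it runs without a server (names are illustrative, not the real API):

```python
import inspect
from typing import Any, Callable


class ProxyClient:
    """Forward any method of `service_type` through a transport callable."""

    def __init__(self, service_type: type, transport: Callable[[str, dict], Any]):
        self._service_type = service_type
        self._transport = transport

    def __getattr__(self, name: str) -> Callable[..., Any]:
        orig = getattr(self._service_type, name, None)
        if orig is None:
            raise AttributeError(f"Method {name} not found in {self._service_type}")
        sig = inspect.signature(orig)

        def method(*args, **kwargs):
            # Bind positional args to parameter names, skipping self/cls,
            # the same trick the DynamicClient uses before posting.
            names = [p for p in sig.parameters if p not in ("self", "cls")]
            kwargs.update(dict(zip(names, args)))
            return self._transport(name, kwargs)

        return method


class MathService:
    def add(self, a: int, b: int) -> int:
        return a + b


def local_transport(method: str, payload: dict) -> Any:
    # Stand-in for httpx.Client.post(method, json=payload) hitting the service.
    return getattr(MathService(), method)(**payload)
```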


@@ -117,18 +117,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default=1,
description="Cost per execution in cents after each threshold.",
)
execution_counter_expiration_time: int = Field(
default=60 * 60 * 24,
description="Time in seconds after which the execution counter is reset.",
)
execution_late_notification_threshold_secs: int = Field(
default=5 * 60,
description="Time in seconds after which the execution stuck on QUEUED status is considered late.",
)
execution_late_notification_checkrange_secs: int = Field(
default=60 * 60,
description="Time in seconds for how far back to check for the late executions.",
)
model_config = SettingsConfigDict(
env_file=".env",
@@ -149,6 +137,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default=8002,
description="The port for execution manager daemon to run on",
)
execution_manager_loop_max_retry: int = Field(
default=5,
description="The maximum number of retries for the execution manager loop",
)
execution_scheduler_port: int = Field(
default=8003,
@@ -239,10 +231,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default=True,
description="Whether to enable the agent input subtype blocks",
)
platform_alert_discord_channel: str = Field(
default="local-alerts",
description="The Discord channel for the platform",
)
@field_validator("platform_base_url", "frontend_base_url")
@classmethod
@@ -354,16 +342,6 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
description="The secret key to use for the unsubscribe user by token",
)
# Cloudflare Turnstile credentials
turnstile_secret_key: str = Field(
default="",
description="Cloudflare Turnstile backend secret key",
)
turnstile_verify_url: str = Field(
default="https://challenges.cloudflare.com/turnstile/v0/siteverify",
description="Cloudflare Turnstile verify URL",
)
# OAuth server credentials for integrations
# --8<-- [start:OAuthServerCredentialsExample]
github_client_id: str = Field(default="", description="GitHub OAuth client ID")


@@ -25,7 +25,7 @@ class SpinTestServer:
self.db_api = DatabaseManager()
self.exec_manager = ExecutionManager()
self.agent_server = AgentServer()
self.scheduler = Scheduler(register_system_tasks=False)
self.scheduler = Scheduler()
self.notif_manager = NotificationManager()
@staticmethod


@@ -1,7 +0,0 @@
-- AlterTable
ALTER TABLE "AgentGraph"
ADD COLUMN "forkedFromId" TEXT,
ADD COLUMN "forkedFromVersion" INTEGER;
-- AddForeignKey
ALTER TABLE "AgentGraph" ADD CONSTRAINT "AgentGraph_forkedFromId_forkedFromVersion_fkey" FOREIGN KEY ("forkedFromId", "forkedFromVersion") REFERENCES "AgentGraph"("id", "version") ON DELETE SET NULL ON UPDATE CASCADE;


@@ -1,5 +0,0 @@
-- CreateIndex
CREATE INDEX "AgentGraphExecution_createdAt_idx" ON "AgentGraphExecution"("createdAt");
-- CreateIndex
CREATE INDEX "AgentNodeExecution_addedTime_idx" ON "AgentNodeExecution"("addedTime");


@@ -1,9 +0,0 @@
-- Rename 'data' input to 'inputs' on all Agent Executor nodes
UPDATE "AgentNode" AS node
SET "constantInput" = jsonb_set(
"constantInput",
'{inputs}',
"constantInput"->'data'
) - 'data'
WHERE node."agentBlockId" = 'e189baac-8c20-45a1-94a7-55177ea42565'
AND node."constantInput" ? 'data';
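
The migration above renames the `data` key to `inputs` inside `constantInput` via `jsonb_set(...) - 'data'`. The same transformation in Python terms — a sketch of the semantics, not the migration itself:

```python
def rename_json_key(obj: dict, old: str, new: str) -> dict:
    """Return a copy with obj[old] moved to obj[new]; no-op if old is absent."""
    if old not in obj:
        return dict(obj)
    out = {k: v for k, v in obj.items() if k != old}
    out[new] = obj[old]
    return out
```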


@@ -3568,21 +3568,6 @@ files = [
[package.dependencies]
tqdm = "*"
[[package]]
name = "prometheus-client"
version = "0.21.1"
description = "Python client for the Prometheus monitoring system."
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "prometheus_client-0.21.1-py3-none-any.whl", hash = "sha256:594b45c410d6f4f8888940fe80b5cc2521b305a1fafe1c58609ef715a001f301"},
{file = "prometheus_client-0.21.1.tar.gz", hash = "sha256:252505a722ac04b0456be05c05f75f45d760c2911ffc45f2a06bcaed9f3ae3fb"},
]
[package.extras]
twisted = ["twisted"]
[[package]]
name = "propcache"
version = "0.2.1"
@@ -6325,4 +6310,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.13"
content-hash = "29ccee704d8296c57156daab98bb0cbbf5a43e83526b7f08a14c91fb7a4898f4"
content-hash = "781f77ec77cfce78b34fb57063dcc81df8e9c5a4be9a644033a0c197e0063730"


@@ -64,7 +64,6 @@ websockets = "^14.2"
youtube-transcript-api = "^0.6.2"
zerobouncesdk = "^1.1.1"
# NOTE: please insert new dependencies in their alphabetical location
prometheus-client = "^0.21.1"
[tool.poetry.group.dev.dependencies]
aiohappyeyeballs = "^2.6.1"
@@ -88,7 +87,6 @@ build-backend = "poetry.core.masonry.api"
app = "backend.app:main"
rest = "backend.rest:main"
ws = "backend.ws:main"
scheduler = "backend.scheduler:main"
executor = "backend.exec:main"
cli = "backend.cli:main"
format = "linter:format"


@@ -118,11 +118,6 @@ model AgentGraph {
// This allows us to delete user data along with the agent, which may be in use by other users
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
forkedFromId String?
forkedFromVersion Int?
forkedFrom AgentGraph? @relation("AgentGraphForks", fields: [forkedFromId, forkedFromVersion], references: [id, version])
forks AgentGraph[] @relation("AgentGraphForks")
Nodes AgentNode[]
Executions AgentGraphExecution[]
@@ -352,7 +347,6 @@ model AgentGraphExecution {
@@index([agentGraphId, agentGraphVersion])
@@index([userId])
@@index([createdAt])
}
// This model describes the execution of an AgentNode.
@@ -379,7 +373,6 @@ model AgentNodeExecution {
@@index([agentGraphExecutionId])
@@index([agentNodeId])
@@index([addedTime])
}
// This model describes the output of an AgentNodeExecution.


@@ -6,10 +6,10 @@ from prisma.models import CreditTransaction
from backend.blocks.llm import AITextGeneratorBlock
from backend.data.block import get_block
from backend.data.credit import BetaUserCredit, UsageTransactionMetadata
from backend.data.credit import BetaUserCredit
from backend.data.execution import NodeExecutionEntry
from backend.data.user import DEFAULT_USER_ID
from backend.executor.utils import block_usage_cost
from backend.executor.utils import UsageTransactionMetadata, block_usage_cost
from backend.integrations.credentials_store import openai_credentials
from backend.util.test import SpinTestServer
@@ -34,7 +34,7 @@ async def spend_credits(entry: NodeExecutionEntry) -> int:
if not block:
raise RuntimeError(f"Block {entry.block_id} not found")
cost, matching_filter = block_usage_cost(block=block, input_data=entry.inputs)
cost, matching_filter = block_usage_cost(block=block, input_data=entry.data)
await user_credit.spend_credits(
entry.user_id,
cost,
@@ -46,7 +46,6 @@ async def spend_credits(entry: NodeExecutionEntry) -> int:
block_id=entry.block_id,
block=entry.block_id,
input=matching_filter,
reason=f"Ran block {entry.block_id} {block.name}",
),
)
@@ -67,7 +66,7 @@ async def test_block_credit_usage(server: SpinTestServer):
graph_exec_id="test_graph_exec",
node_exec_id="test_node_exec",
block_id=AITextGeneratorBlock().id,
inputs={
data={
"model": "gpt-4-turbo",
"credentials": {
"id": openai_credentials.id,
@@ -87,7 +86,7 @@ async def test_block_credit_usage(server: SpinTestServer):
graph_exec_id="test_graph_exec",
node_exec_id="test_node_exec",
block_id=AITextGeneratorBlock().id,
inputs={"model": "gpt-4-turbo", "api_key": "owned_api_key"},
data={"model": "gpt-4-turbo", "api_key": "owned_api_key"},
),
)
assert spending_amount_2 == 0


@@ -1,7 +1,7 @@
import pytest
from backend.data import db
from backend.executor.scheduler import SchedulerClient
from backend.executor import Scheduler
from backend.server.model import CreateGraph
from backend.usecases.sample import create_test_graph, create_test_user
from backend.util.service import get_service_client
@@ -17,11 +17,11 @@ async def test_agent_schedule(server: SpinTestServer):
user_id=test_user.id,
)
scheduler = get_service_client(SchedulerClient)
schedules = await scheduler.get_execution_schedules(test_graph.id, test_user.id)
scheduler = get_service_client(Scheduler)
schedules = scheduler.get_execution_schedules(test_graph.id, test_user.id)
assert len(schedules) == 0
schedule = await scheduler.add_execution_schedule(
schedule = scheduler.add_execution_schedule(
graph_id=test_graph.id,
user_id=test_user.id,
graph_version=1,
@@ -30,12 +30,10 @@ async def test_agent_schedule(server: SpinTestServer):
)
assert schedule
schedules = await scheduler.get_execution_schedules(test_graph.id, test_user.id)
schedules = scheduler.get_execution_schedules(test_graph.id, test_user.id)
assert len(schedules) == 1
assert schedules[0].cron == "0 0 * * *"
await scheduler.delete_schedule(schedule.id, user_id=test_user.id)
schedules = await scheduler.get_execution_schedules(
test_graph.id, user_id=test_user.id
)
scheduler.delete_schedule(schedule.id, user_id=test_user.id)
schedules = scheduler.get_execution_schedules(test_graph.id, user_id=test_user.id)
assert len(schedules) == 0


@@ -1,12 +1,6 @@
import pytest
from backend.util.service import (
AppService,
AppServiceClient,
endpoint_to_async,
expose,
get_service_client,
)
from backend.util.service import AppService, expose, get_service_client
TEST_SERVICE_PORT = 8765
@@ -38,25 +32,10 @@ class ServiceTest(AppService):
return self.run_and_wait(add_async(a, b))
class ServiceTestClient(AppServiceClient):
@classmethod
def get_service_type(cls):
return ServiceTest
add = ServiceTest.add
subtract = ServiceTest.subtract
fun_with_async = ServiceTest.fun_with_async
add_async = endpoint_to_async(ServiceTest.add)
subtract_async = endpoint_to_async(ServiceTest.subtract)
@pytest.mark.asyncio(loop_scope="session")
async def test_service_creation(server):
with ServiceTest():
client = get_service_client(ServiceTestClient)
client = get_service_client(ServiceTest)
assert client.add(5, 3) == 8
assert client.subtract(10, 4) == 6
assert client.fun_with_async(5, 3) == 8
assert await client.add_async(5, 3) == 8
assert await client.subtract_async(10, 4) == 6


@@ -73,8 +73,6 @@ services:
condition: service_completed_successfully
rabbitmq:
condition: service_healthy
# scheduler_server:
# condition: service_healthy
environment:
- SUPABASE_URL=http://kong:8000
- SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
@@ -90,7 +88,7 @@ services:
- REDIS_PASSWORD=password
- ENABLE_AUTH=true
- PYRO_HOST=0.0.0.0
- SCHEDULER_HOST=scheduler_server
- EXECUTIONSCHEDULER_HOST=rest_server
- EXECUTIONMANAGER_HOST=executor
- NOTIFICATIONMANAGER_HOST=rest_server
- FRONTEND_BASE_URL=http://localhost:3000
@@ -100,6 +98,7 @@ services:
ports:
- "8006:8006"
- "8007:8007"
- "8003:8003" # execution scheduler
networks:
- app-network
@@ -143,7 +142,7 @@ services:
- NOTIFICATIONMANAGER_HOST=rest_server
- ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw= # DO NOT USE IN PRODUCTION!!
ports:
- "8002:8002"
- "8002:8000"
networks:
- app-network
@@ -188,61 +187,6 @@ services:
networks:
- app-network
scheduler_server:
build:
context: ../
dockerfile: autogpt_platform/backend/Dockerfile
target: server
command: ["python", "-m", "backend.scheduler"]
develop:
watch:
- path: ./
target: autogpt_platform/backend/
action: rebuild
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
rabbitmq:
condition: service_healthy
migrate:
condition: service_completed_successfully
# healthcheck:
# test:
# [
# "CMD",
# "curl",
# "-f",
# "-X",
# "POST",
# "http://localhost:8003/health_check",
# ]
# interval: 10s
# timeout: 10s
# retries: 5
environment:
- DATABASEMANAGER_HOST=rest_server
- NOTIFICATIONMANAGER_HOST=rest_server
- SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
- DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
- DIRECT_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
- REDIS_HOST=redis
- REDIS_PORT=6379
- REDIS_PASSWORD=password
- RABBITMQ_HOST=rabbitmq
- RABBITMQ_PORT=5672
- RABBITMQ_DEFAULT_USER=rabbitmq_user_default
- RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
- ENABLE_AUTH=true
- PYRO_HOST=0.0.0.0
- BACKEND_CORS_ALLOW_ORIGINS=["http://localhost:3000"]
ports:
- "8003:8003"
networks:
- app-network
# frontend:
# build:
# context: ../


@@ -57,17 +57,11 @@ services:
file: ./docker-compose.platform.yml
service: websocket_server
scheduler_server:
<<: *agpt-services
extends:
file: ./docker-compose.platform.yml
service: scheduler_server
# frontend:
# <<: *agpt-services
# extends:
# file: ./docker-compose.platform.yml
# service: frontend
# frontend:
# <<: *agpt-services
# extends:
# file: ./docker-compose.platform.yml
# service: frontend
# Supabase services
studio:


@@ -24,9 +24,3 @@ GA_MEASUREMENT_ID=G-FH2XK2W4GN
# When running locally, set NEXT_PUBLIC_BEHAVE_AS=CLOUD to use a locally hosted marketplace (as is typical in development and in the cloud deployment); otherwise set it to LOCAL to have the marketplace open in a new tab
NEXT_PUBLIC_BEHAVE_AS=LOCAL
NEXT_PUBLIC_SHOW_BILLING_PAGE=false
## Cloudflare Turnstile (CAPTCHA) Configuration
## Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
## This is the frontend site key
NEXT_PUBLIC_CLOUDFLARE_TURNSTILE_SITE_KEY=


@@ -1,8 +1,9 @@
import type { Preview } from "@storybook/react";
import { initialize, mswLoader } from "msw-storybook-addon";
import "../src/app/globals.css";
import { Providers } from "../src/app/providers";
// Initialize MSW
import React from "react";
initialize();
const preview: Preview = {
@@ -18,6 +19,17 @@ const preview: Preview = {
},
},
loaders: [mswLoader],
decorators: [
(Story, context) => {
const mockOptions = context.parameters.mockBackend || {};
return (
<Providers useMockBackend mockClientProps={mockOptions}>
<Story />
</Providers>
);
},
],
};
export default preview;


@@ -163,7 +163,14 @@ export default function Page() {
</div>
<OnboardingFooter>
<OnboardingButton className="mb-2" href="/onboarding/4-agent">
<OnboardingButton
className="mb-2"
href="/onboarding/4-agent"
disabled={
state?.integrations.length === 0 &&
isEmptyOrWhitespace(state.otherIntegrations)
}
>
Next
</OnboardingButton>
</OnboardingFooter>
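The new `disabled` condition leans on an `isEmptyOrWhitespace` helper whose implementation is not shown in this diff. A plausible minimal version (an assumption, not the project's actual code) is:

```typescript
// Plausible sketch of the helper used above; the real implementation
// lives elsewhere in the codebase and may differ.
function isEmptyOrWhitespace(value: string | undefined | null): boolean {
  // Treat null/undefined, "", and whitespace-only strings as empty.
  return value == null || value.trim().length === 0;
}
```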


@@ -59,7 +59,7 @@ export default function Page() {
<div className="my-12 flex items-center justify-between gap-5">
<OnboardingAgentCard
agent={agents[0]}
{...(agents[0] || {})}
selected={
agents[0] !== undefined
? state?.selectedStoreListingVersionId ==
@@ -74,7 +74,7 @@ export default function Page() {
}
/>
<OnboardingAgentCard
agent={agents[1]}
{...(agents[1] || {})}
selected={
agents[1] !== undefined
? state?.selectedStoreListingVersionId ==


@@ -9,6 +9,7 @@ import StarRating from "@/components/onboarding/StarRating";
import { Play } from "lucide-react";
import { cn } from "@/lib/utils";
import { useCallback, useEffect, useState } from "react";
import Image from "next/image";
import { GraphMeta, StoreAgentDetails } from "@/lib/autogpt-server-api";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import { useRouter } from "next/navigation";
@@ -16,7 +17,6 @@ import { useOnboarding } from "@/components/onboarding/onboarding-provider";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import SchemaTooltip from "@/components/SchemaTooltip";
import { TypeBasedInput } from "@/components/type-based-input";
import SmartImage from "@/components/agptui/SmartImage";
export default function Page() {
const { state, updateState, setStep } = useOnboarding(
@@ -99,7 +99,7 @@ export default function Page() {
}, [api, agent, router, state?.agentInput, storeAgent, updateState]);
const runYourAgent = (
<div className="ml-[104px] w-[481px] pl-5">
<div className="ml-[54px] w-[481px] pl-5">
<div className="flex flex-col">
<OnboardingText variant="header">Run your first agent</OnboardingText>
<span className="mt-9 text-base font-normal leading-normal text-zinc-600">
@@ -147,25 +147,32 @@ export default function Page() {
return (
<OnboardingStep dotted>
<OnboardingHeader backHref={"/onboarding/4-agent"} transparent />
{/* Agent card */}
<div className="fixed left-1/4 top-1/2 w-[481px] -translate-x-1/2 -translate-y-1/2">
<div className="h-[156px] w-[481px] rounded-xl bg-white px-6 pb-5 pt-4">
<span className="font-sans text-xs font-medium tracking-wide text-zinc-500">
SELECTED AGENT
</span>
{storeAgent ? (
<div
className={cn(
"flex w-full items-center justify-center",
showInput ? "mt-[32px]" : "mt-[192px]",
)}
>
{/* Left side */}
<div className="mr-[52px] w-[481px]">
<div className="h-[156px] w-[481px] rounded-xl bg-white px-6 pb-5 pt-4">
<span className="font-sans text-xs font-medium tracking-wide text-zinc-500">
SELECTED AGENT
</span>
<div className="mt-4 flex h-20 rounded-lg bg-violet-50 p-2">
{/* Left image */}
<SmartImage
src={storeAgent?.agent_image[0]}
alt="Agent cover"
imageContain
className="w-[350px] rounded-lg"
<Image
src={storeAgent?.agent_image[0] || ""}
alt="Description"
width={350}
height={196}
className="h-full w-auto rounded-lg object-contain"
/>
{/* Right content */}
<div className="ml-2 flex flex-1 flex-col">
<span className="w-[292px] truncate font-sans text-[14px] font-medium leading-normal text-zinc-800">
{storeAgent?.agent_name}
{agent?.name}
</span>
<span className="mt-[5px] w-[292px] truncate font-sans text-xs font-normal leading-tight text-zinc-600">
by {storeAgent?.creator}
@@ -182,19 +189,13 @@ export default function Page() {
</div>
</div>
</div>
) : (
<div className="mt-4 flex h-20 animate-pulse rounded-lg bg-gray-300 p-2" />
)}
</div>
</div>
</div>
<div className="flex min-h-[80vh] items-center justify-center">
{/* Left side */}
<div className="w-[481px]" />
{/* Right side */}
{!showInput ? (
runYourAgent
) : (
<div className="ml-[104px] w-[481px] pl-5">
<div className="ml-[54px] w-[481px] pl-5">
<div className="flex flex-col">
<OnboardingText variant="header">
Provide details for your agent


@@ -6,13 +6,6 @@ import { redirect } from "next/navigation";
export async function finishOnboarding() {
const api = new BackendAPI();
const onboarding = await api.getUserOnboarding();
const listingId = onboarding?.selectedStoreListingVersionId;
if (listingId) {
const libraryAgent = await api.addMarketplaceAgentToLibrary(listingId);
revalidatePath(`/library/agents/${libraryAgent.id}`, "layout");
redirect(`/library/agents/${libraryAgent.id}`);
} else {
revalidatePath("/library", "layout");
redirect("/library");
}
revalidatePath("/library", "layout");
redirect("/library");
}


@@ -13,7 +13,7 @@ export default async function OnboardingPage() {
// CONGRATS is the last step in intro onboarding
if (onboarding.completedSteps.includes("CONGRATS")) redirect("/marketplace");
else if (onboarding.completedSteps.includes("AGENT_INPUT"))
redirect("/onboarding/5-run");
redirect("/onboarding/6-congrats");
else if (onboarding.completedSteps.includes("AGENT_NEW_RUN"))
redirect("/onboarding/5-run");
else if (onboarding.completedSteps.includes("AGENT_CHOICE"))


@@ -56,9 +56,3 @@ export async function getAdminListingsWithVersions(
const response = await api.getAdminListingsWithVersions(data);
return response;
}
export async function downloadAsAdmin(storeListingVersion: string) {
const api = new BackendApi();
const file = await api.downloadStoreAgentAdmin(storeListingVersion);
return file;
}


@@ -1,11 +1,5 @@
"use client";
import React, {
useCallback,
useEffect,
useMemo,
useRef,
useState,
} from "react";
import React, { useCallback, useEffect, useMemo, useState } from "react";
import { useParams, useRouter } from "next/navigation";
import { exportAsJSONFile } from "@/lib/utils";
@@ -29,16 +23,6 @@ import AgentRunDetailsView from "@/components/agents/agent-run-details-view";
import AgentRunsSelectorList from "@/components/agents/agent-runs-selector-list";
import AgentScheduleDetailsView from "@/components/agents/agent-schedule-details-view";
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import { Button } from "@/components/ui/button";
import { useToast } from "@/components/ui/use-toast";
export default function AgentRunsPage(): React.ReactElement {
const { id: agentID }: { id: LibraryAgentID } = useParams();
@@ -47,7 +31,7 @@ export default function AgentRunsPage(): React.ReactElement {
// ============================ STATE =============================
const [graph, setGraph] = useState<Graph | null>(null); // Graph version corresponding to LibraryAgent
const [graph, setGraph] = useState<Graph | null>(null);
const [agent, setAgent] = useState<LibraryAgent | null>(null);
const [agentRuns, setAgentRuns] = useState<GraphExecutionMeta[]>([]);
const [schedules, setSchedules] = useState<Schedule[]>([]);
@@ -66,10 +50,7 @@ export default function AgentRunsPage(): React.ReactElement {
useState<boolean>(false);
const [confirmingDeleteAgentRun, setConfirmingDeleteAgentRun] =
useState<GraphExecutionMeta | null>(null);
const { state: onboardingState, updateState: updateOnboardingState } =
useOnboarding();
const [copyAgentDialogOpen, setCopyAgentDialogOpen] = useState(false);
const { toast } = useToast();
const { state, updateState } = useOnboarding();
const openRunDraftView = useCallback(() => {
selectView({ type: "run" });
@@ -84,43 +65,34 @@ export default function AgentRunsPage(): React.ReactElement {
setSelectedSchedule(schedule);
}, []);
const graphVersions = useRef<Record<number, Graph>>({});
const loadingGraphVersions = useRef<Record<number, Promise<Graph>>>({});
const [graphVersions, setGraphVersions] = useState<Record<number, Graph>>({});
const getGraphVersion = useCallback(
async (graphID: GraphID, version: number) => {
if (version in graphVersions.current)
return graphVersions.current[version];
if (version in loadingGraphVersions.current)
return loadingGraphVersions.current[version];
if (graphVersions[version]) return graphVersions[version];
const pendingGraph = api.getGraph(graphID, version).then((graph) => {
graphVersions.current[version] = graph;
return graph;
});
// Cache promise as well to avoid duplicate requests
loadingGraphVersions.current[version] = pendingGraph;
return pendingGraph;
const graphVersion = await api.getGraph(graphID, version);
setGraphVersions((prev) => ({
...prev,
[version]: graphVersion,
}));
return graphVersion;
},
[api, graphVersions, loadingGraphVersions],
[api, graphVersions],
);
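The ref-based cache removed in this hunk deduplicates concurrent fetches of the same graph version by also caching the in-flight promise, so simultaneous callers share one request. Stripped of React, that pattern looks like this (a generic sketch, not the project's code):

```typescript
// Generic in-flight promise cache: concurrent callers for the same key
// share one promise, so the underlying fetcher runs at most once per key.
function createPromiseCache<K, V>(fetcher: (key: K) => Promise<V>) {
  const resolved = new Map<K, V>();          // completed results
  const inFlight = new Map<K, Promise<V>>(); // pending requests

  return async function get(key: K): Promise<V> {
    if (resolved.has(key)) return resolved.get(key) as V;
    const pending = inFlight.get(key);
    if (pending) return pending; // reuse the request already in progress

    const p = fetcher(key).then(
      (value) => {
        resolved.set(key, value);
        inFlight.delete(key);
        return value;
      },
      (err) => {
        inFlight.delete(key); // don't cache failures
        throw err;
      },
    );
    inFlight.set(key, p); // cache the promise so duplicates coalesce
    return p;
  };
}
```

The `useState`-based replacement on the right-hand side of the hunk trades this deduplication away for simpler code: two components asking for the same uncached version will each trigger a fetch.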
// Reward user for viewing results of their onboarding agent
useEffect(() => {
if (
!onboardingState ||
!selectedRun ||
onboardingState.completedSteps.includes("GET_RESULTS")
)
if (!state || !selectedRun || state.completedSteps.includes("GET_RESULTS"))
return;
if (selectedRun.id === onboardingState.onboardingAgentExecutionId) {
updateOnboardingState({
completedSteps: [...onboardingState.completedSteps, "GET_RESULTS"],
if (selectedRun.id === state.onboardingAgentExecutionId) {
updateState({
completedSteps: [...state.completedSteps, "GET_RESULTS"],
});
}
}, [selectedRun, onboardingState, updateOnboardingState]);
}, [selectedRun, state]);
const refreshPageData = useCallback(() => {
const fetchAgents = useCallback(() => {
api.getLibraryAgent(agentID).then((agent) => {
setAgent(agent);
@@ -135,40 +107,38 @@ export default function AgentRunsPage(): React.ReactElement {
new Set(agentRuns.map((run) => run.graph_version)).forEach((version) =>
getGraphVersion(agent.graph_id, version),
);
if (!selectedView.id && isFirstLoad && agentRuns.length > 0) {
// only for first load or first execution
setIsFirstLoad(false);
const latestRun = agentRuns.reduce((latest, current) => {
if (latest.started_at && !current.started_at) return current;
else if (!latest.started_at) return latest;
return latest.started_at > current.started_at ? latest : current;
}, agentRuns[0]);
selectView({ type: "run", id: latestRun.id });
}
});
});
}, [api, agentID, getGraphVersion, graph]);
if (selectedView.type == "run" && selectedView.id && agent) {
api
.getGraphExecutionInfo(agent.graph_id, selectedView.id)
.then(setSelectedRun);
}
}, [api, agentID, getGraphVersion, graph, selectedView, isFirstLoad, agent]);
// On first load: select the latest run
useEffect(() => {
// Only for first load or first execution
if (selectedView.id || !isFirstLoad || agentRuns.length == 0) return;
setIsFirstLoad(false);
const latestRun = agentRuns.reduce((latest, current) => {
if (latest.started_at && !current.started_at) return current;
else if (!latest.started_at) return latest;
return latest.started_at > current.started_at ? latest : current;
}, agentRuns[0]);
selectView({ type: "run", id: latestRun.id });
}, [agentRuns, isFirstLoad, selectedView.id, selectView]);
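The first-load effect above selects the run to show with a reduce that treats runs without a `started_at` timestamp as the most recent, and otherwise prefers the latest start time. Extracted as a plain function for clarity (a sketch, not the project's code):

```typescript
// Sketch of the selection logic above, outside React. Runs that have not
// started yet (no started_at) are treated as the most recent.
interface RunLike {
  id: string;
  started_at?: Date;
}

function pickLatestRun(runs: RunLike[]): RunLike {
  return runs.reduce((latest, current) => {
    if (latest.started_at && !current.started_at) return current; // unstarted wins
    if (!latest.started_at) return latest;                        // keep unstarted
    if (!current.started_at) return current; // covered above; satisfies strict TS
    return latest.started_at > current.started_at ? latest : current;
  }, runs[0]);
}
```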
// Initial load
useEffect(() => {
refreshPageData();
fetchAgents();
}, []);
// Subscribe to WebSocket updates for agent runs
// Subscribe to websocket updates for agent runs
useEffect(() => {
if (!agent?.graph_id) return;
if (!agent) return;
return api.onWebSocketConnect(() => {
refreshPageData(); // Sync up on (re)connect
// Subscribe to all executions for this agent
api.subscribeToGraphExecutions(agent.graph_id);
});
}, [api, agent?.graph_id, refreshPageData]);
// Subscribe to all executions for this agent
api.subscribeToGraphExecutions(agent.graph_id);
}, [api, agent]);
// Handle execution updates
useEffect(() => {
@@ -197,29 +167,24 @@ export default function AgentRunsPage(): React.ReactElement {
};
}, [api, agent?.graph_id, selectedView.id]);
// Pre-load selectedRun based on selectedView
// load selectedRun based on selectedView
useEffect(() => {
if (selectedView.type != "run" || !selectedView.id) return;
if (selectedView.type != "run" || !selectedView.id || !agent) return;
const newSelectedRun = agentRuns.find((run) => run.id == selectedView.id);
if (selectedView.id !== selectedRun?.id) {
// Pull partial data from "cache" while waiting for the rest to load
setSelectedRun(newSelectedRun ?? null);
// Ensure corresponding graph version is available before rendering I/O
api
.getGraphExecutionInfo(agent.graph_id, selectedView.id)
.then(async (run) => {
await getGraphVersion(run.graph_id, run.graph_version);
setSelectedRun(run);
});
}
}, [api, selectedView, agentRuns, selectedRun?.id]);
// Load selectedRun based on selectedView; refresh on agent refresh
useEffect(() => {
if (selectedView.type != "run" || !selectedView.id || !agent) return;
api
.getGraphExecutionInfo(agent.graph_id, selectedView.id)
.then(async (run) => {
// Ensure corresponding graph version is available before rendering I/O
await getGraphVersion(run.graph_id, run.graph_version);
setSelectedRun(run);
});
}, [api, selectedView, agent, getGraphVersion]);
}, [api, selectedView, agent, agentRuns, selectedRun?.id, getGraphVersion]);
const fetchSchedules = useCallback(async () => {
if (!agent) return;
@@ -272,41 +237,20 @@ export default function AgentRunsPage(): React.ReactElement {
[api, agent],
);
const copyAgent = useCallback(async () => {
setCopyAgentDialogOpen(false);
api
.forkLibraryAgent(agentID)
.then((newAgent) => {
router.push(`/library/agents/${newAgent.id}`);
})
.catch((error) => {
console.error("Error copying agent:", error);
toast({
title: "Error copying agent",
description: `An error occurred while copying the agent: ${error.message}`,
variant: "destructive",
});
});
}, [agentID, api, router, toast]);
const agentActions: ButtonAction[] = useMemo(
() => [
{
label: "Customize agent",
href: `/build?flowID=${agent?.graph_id}&flowVersion=${agent?.graph_version}`,
disabled: !agent?.can_access_graph,
},
{ label: "Export agent to file", callback: downloadGraph },
...(!agent?.can_access_graph
...(agent?.can_access_graph
? [
{
label: "Edit a copy",
callback: () => setCopyAgentDialogOpen(true),
label: "Open graph in builder",
href: `/build?flowID=${agent.graph_id}&flowVersion=${agent.graph_version}`,
},
{ label: "Export agent to file", callback: downloadGraph },
]
: []),
{
label: "Delete agent",
variant: "destructive",
callback: () => setAgentDeleteDialogOpen(true),
},
],
@@ -351,7 +295,7 @@ export default function AgentRunsPage(): React.ReactElement {
selectedRun && (
<AgentRunDetailsView
agent={agent}
graph={graphVersions.current[selectedRun.graph_version] ?? graph}
graph={graphVersions[selectedRun.graph_version] ?? graph}
run={selectedRun}
agentActions={agentActions}
onRun={(runID) => selectRun(runID)}
@@ -395,36 +339,6 @@ export default function AgentRunsPage(): React.ReactElement {
confirmingDeleteAgentRun && deleteRun(confirmingDeleteAgentRun)
}
/>
{/* Copy agent confirmation dialog */}
<Dialog
onOpenChange={setCopyAgentDialogOpen}
open={copyAgentDialogOpen}
>
<DialogContent>
<DialogHeader>
<DialogTitle>You&apos;re making an editable copy</DialogTitle>
<DialogDescription className="pt-2">
The original Marketplace agent stays the same and cannot be
edited. We&apos;ll save a new version of this agent to your
Library. From there, you can customize it however you&apos;d
like by clicking &quot;Customize agent&quot; this will open
the builder where you can see and modify the inner workings.
</DialogDescription>
</DialogHeader>
<DialogFooter className="justify-end">
<Button
type="button"
variant="outline"
onClick={() => setCopyAgentDialogOpen(false)}
>
Cancel
</Button>
<Button type="button" onClick={copyAgent}>
Continue
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</div>
</div>
);


@@ -6,7 +6,6 @@ import * as Sentry from "@sentry/nextjs";
import getServerSupabase from "@/lib/supabase/getServerSupabase";
import BackendAPI from "@/lib/autogpt-server-api";
import { loginFormSchema, LoginProvider } from "@/types/auth";
import { verifyTurnstileToken } from "@/lib/turnstile";
export async function logout() {
return await Sentry.withServerActionInstrumentation(
@@ -40,10 +39,7 @@ async function shouldShowOnboarding() {
);
}
export async function login(
values: z.infer<typeof loginFormSchema>,
turnstileToken: string,
) {
export async function login(values: z.infer<typeof loginFormSchema>) {
return await Sentry.withServerActionInstrumentation("login", {}, async () => {
const supabase = getServerSupabase();
const api = new BackendAPI();
@@ -52,12 +48,6 @@ export async function login(
redirect("/error");
}
// Verify Turnstile token if provided
const success = await verifyTurnstileToken(turnstileToken, "login");
if (!success) {
return "CAPTCHA verification failed. Please try again.";
}
// We are sure that the values are of the correct type because zod validates the form
const { data, error } = await supabase.auth.signInWithPassword(values);


@@ -24,11 +24,9 @@ import {
AuthFeedback,
AuthBottomText,
PasswordInput,
Turnstile,
} from "@/components/auth";
import { loginFormSchema } from "@/types/auth";
import { getBehaveAs } from "@/lib/utils";
import { useTurnstile } from "@/hooks/useTurnstile";
export default function LoginPage() {
const { supabase, user, isUserLoading } = useSupabase();
@@ -36,12 +34,6 @@ export default function LoginPage() {
const router = useRouter();
const [isLoading, setIsLoading] = useState(false);
const turnstile = useTurnstile({
action: "login",
autoVerify: false,
resetOnError: true,
});
const form = useForm<z.infer<typeof loginFormSchema>>({
resolver: zodResolver(loginFormSchema),
defaultValues: {
@@ -73,23 +65,15 @@ export default function LoginPage() {
return;
}
if (!turnstile.verified) {
setFeedback("Please complete the CAPTCHA challenge.");
setIsLoading(false);
return;
}
const error = await login(data, turnstile.token as string);
const error = await login(data);
setIsLoading(false);
if (error) {
setFeedback(error);
// Always reset the turnstile on any error
turnstile.reset();
return;
}
setFeedback(null);
},
[form, turnstile],
[form],
);
if (user) {
@@ -156,17 +140,6 @@ export default function LoginPage() {
</FormItem>
)}
/>
{/* Turnstile CAPTCHA Component */}
<Turnstile
siteKey={turnstile.siteKey}
onVerify={turnstile.handleVerify}
onExpire={turnstile.handleExpire}
onError={turnstile.handleError}
action="login"
shouldRender={turnstile.shouldRender}
/>
<AuthButton
onClick={() => onLogin(form.getValues())}
isLoading={isLoading}
@@ -176,7 +149,6 @@ export default function LoginPage() {
</AuthButton>
</form>
<AuthFeedback
type="login"
message={feedback}
isError={!!feedback}
behaveAs={getBehaveAs()}


@@ -6,7 +6,6 @@ import { AgentsSection } from "@/components/agptui/composite/AgentsSection";
import { BecomeACreator } from "@/components/agptui/BecomeACreator";
import { Separator } from "@/components/ui/separator";
import { Metadata } from "next";
import getServerUser from "@/lib/supabase/getServerUser";
export async function generateMetadata({
params,
@@ -17,7 +16,7 @@ export async function generateMetadata({
const agent = await api.getStoreAgent(params.creator, params.slug);
return {
title: `${agent.agent_name} - AutoGPT Marketplace`,
title: `${agent.agent_name} - AutoGPT Store`,
description: agent.description,
};
}
@@ -37,7 +36,6 @@ export default async function Page({
params: { creator: string; slug: string };
}) {
const creator_lower = params.creator.toLowerCase();
const { user } = await getServerUser();
const api = new BackendAPI();
const agent = await api.getStoreAgent(creator_lower, params.slug);
const otherAgents = await api.getStoreAgents({ creator: creator_lower });
@@ -45,17 +43,9 @@ export default async function Page({
// We are using slug as we know its has been sanitized and is not null
search_query: agent.slug.replace(/-/g, " "),
});
const libraryAgent = user
? await api
.getLibraryAgentByStoreListingVersionID(agent.store_listing_version_id)
.catch((error) => {
console.error("Failed to fetch library agent:", error);
return null;
})
: null;
const breadcrumbs = [
{ name: "Marketplace", link: "/marketplace" },
{ name: "Store", link: "/marketplace" },
{
name: agent.creator,
link: `/marketplace/creator/${encodeURIComponent(agent.creator)}`,
@@ -71,10 +61,9 @@ export default async function Page({
<div className="mt-4 flex flex-col items-start gap-4 sm:mt-6 sm:gap-6 md:mt-8 md:flex-row md:gap-8">
<div className="w-full md:w-auto md:shrink-0">
<AgentInfo
user={user}
name={agent.agent_name}
creator={agent.creator}
shortDescription={agent.sub_heading}
shortDescription={agent.description}
longDescription={agent.description}
rating={agent.rating}
runs={agent.runs}
@@ -82,7 +71,6 @@ export default async function Page({
lastUpdated={agent.updated_at}
version={agent.versions[agent.versions.length - 1]}
storeListingVersionId={agent.store_listing_version_id}
libraryAgent={libraryAgent}
/>
</div>
<AgentImages


@@ -298,9 +298,7 @@ export default function CreditsPage() {
>
<b>{formatCredits(transaction.amount)}</b>
</TableCell>
<TableCell>
{formatCredits(transaction.running_balance)}
</TableCell>
<TableCell>{formatCredits(transaction.balance)}</TableCell>
</TableRow>
))}
</TableBody>


@@ -116,7 +116,6 @@ export default function PrivatePage() {
"544c62b5-1d0f-4156-8fb4-9525f11656eb", // Apollo
"3bcdbda3-84a3-46af-8fdb-bfd2472298b8", // SmartLead
"63a6e279-2dc2-448e-bf57-85776f7176dc", // ZeroBounce
"9aa1bde0-4947-4a70-a20c-84daa3850d52", // Google Maps
],
[],
);


@@ -18,15 +18,11 @@ export default function Layout({ children }: { children: React.ReactNode }) {
href: "/profile/dashboard",
icon: <IconDashboardLayout className="h-6 w-6" />,
},
...(process.env.NEXT_PUBLIC_SHOW_BILLING_PAGE === "true"
? [
{
text: "Billing",
href: "/profile/credits",
icon: <IconCoin className="h-6 w-6" />,
},
]
: []),
{
text: "Billing",
href: "/profile/credits",
icon: <IconCoin className="h-6 w-6" />,
},
{
text: "Integrations",
href: "/profile/integrations",


@@ -3,9 +3,8 @@ import getServerSupabase from "@/lib/supabase/getServerSupabase";
import { redirect } from "next/navigation";
import * as Sentry from "@sentry/nextjs";
import { headers } from "next/headers";
import { verifyTurnstileToken } from "@/lib/turnstile";
export async function sendResetEmail(email: string, turnstileToken: string) {
export async function sendResetEmail(email: string) {
return await Sentry.withServerActionInstrumentation(
"sendResetEmail",
{},
@@ -21,15 +20,6 @@ export async function sendResetEmail(email: string, turnstileToken: string) {
redirect("/error");
}
// Verify Turnstile token if provided
const success = await verifyTurnstileToken(
turnstileToken,
"reset_password",
);
if (!success) {
return "CAPTCHA verification failed. Please try again.";
}
const { error } = await supabase.auth.resetPasswordForEmail(email, {
redirectTo: `${origin}/reset_password`,
});
@@ -44,7 +34,7 @@ export async function sendResetEmail(email: string, turnstileToken: string) {
);
}
export async function changePassword(password: string, turnstileToken: string) {
export async function changePassword(password: string) {
return await Sentry.withServerActionInstrumentation(
"changePassword",
{},
@@ -55,15 +45,6 @@ export async function changePassword(password: string, turnstileToken: string) {
redirect("/error");
}
// Verify Turnstile token if provided
const success = await verifyTurnstileToken(
turnstileToken,
"change_password",
);
if (!success) {
return "CAPTCHA verification failed. Please try again.";
}
const { error } = await supabase.auth.updateUser({ password });
if (error) {


@@ -5,7 +5,6 @@ import {
AuthButton,
AuthFeedback,
PasswordInput,
Turnstile,
} from "@/components/auth";
import {
Form,
@@ -26,7 +25,6 @@ import { z } from "zod";
import { changePassword, sendResetEmail } from "./actions";
import Spinner from "@/components/Spinner";
import { getBehaveAs } from "@/lib/utils";
import { useTurnstile } from "@/hooks/useTurnstile";
export default function ResetPasswordPage() {
const { supabase, user, isUserLoading } = useSupabase();
@@ -35,18 +33,6 @@ export default function ResetPasswordPage() {
const [isError, setIsError] = useState(false);
const [disabled, setDisabled] = useState(false);
const sendEmailTurnstile = useTurnstile({
action: "reset_password",
autoVerify: false,
resetOnError: true,
});
const changePasswordTurnstile = useTurnstile({
action: "change_password",
autoVerify: false,
resetOnError: true,
});
const sendEmailForm = useForm<z.infer<typeof sendEmailFormSchema>>({
resolver: zodResolver(sendEmailFormSchema),
defaultValues: {
@@ -72,22 +58,11 @@ export default function ResetPasswordPage() {
return;
}
if (!sendEmailTurnstile.verified) {
setFeedback("Please complete the CAPTCHA challenge.");
setIsError(true);
setIsLoading(false);
return;
}
const error = await sendResetEmail(
data.email,
sendEmailTurnstile.token as string,
);
const error = await sendResetEmail(data.email);
setIsLoading(false);
if (error) {
setFeedback(error);
setIsError(true);
sendEmailTurnstile.reset();
return;
}
setDisabled(true);
@@ -96,7 +71,7 @@ export default function ResetPasswordPage() {
);
setIsError(false);
},
[sendEmailForm, sendEmailTurnstile],
[sendEmailForm],
);
const onChangePassword = useCallback(
@@ -109,28 +84,17 @@ export default function ResetPasswordPage() {
return;
}
if (!changePasswordTurnstile.verified) {
setFeedback("Please complete the CAPTCHA challenge.");
setIsError(true);
setIsLoading(false);
return;
}
const error = await changePassword(
data.password,
changePasswordTurnstile.token as string,
);
const error = await changePassword(data.password);
setIsLoading(false);
if (error) {
setFeedback(error);
setIsError(true);
changePasswordTurnstile.reset();
return;
}
setFeedback("Password changed successfully. Redirecting to login.");
setIsError(false);
},
[changePasswordForm, changePasswordTurnstile],
[changePasswordForm],
);
if (isUserLoading) {
@@ -181,17 +145,6 @@ export default function ResetPasswordPage() {
</FormItem>
)}
/>
{/* Turnstile CAPTCHA Component for password change */}
<Turnstile
siteKey={changePasswordTurnstile.siteKey}
onVerify={changePasswordTurnstile.handleVerify}
onExpire={changePasswordTurnstile.handleExpire}
onError={changePasswordTurnstile.handleError}
action="change_password"
shouldRender={changePasswordTurnstile.shouldRender}
/>
<AuthButton
onClick={() => onChangePassword(changePasswordForm.getValues())}
isLoading={isLoading}
@@ -200,7 +153,6 @@ export default function ResetPasswordPage() {
Update password
</AuthButton>
<AuthFeedback
type="login"
message={feedback}
isError={isError}
behaveAs={getBehaveAs()}
@@ -223,17 +175,6 @@ export default function ResetPasswordPage() {
</FormItem>
)}
/>
{/* Turnstile CAPTCHA Component for reset email */}
<Turnstile
siteKey={sendEmailTurnstile.siteKey}
onVerify={sendEmailTurnstile.handleVerify}
onExpire={sendEmailTurnstile.handleExpire}
onError={sendEmailTurnstile.handleError}
action="reset_password"
shouldRender={sendEmailTurnstile.shouldRender}
/>
<AuthButton
onClick={() => onSendEmail(sendEmailForm.getValues())}
isLoading={isLoading}
@@ -243,7 +184,6 @@ export default function ResetPasswordPage() {
Send reset email
</AuthButton>
<AuthFeedback
type="login"
message={feedback}
isError={isError}
behaveAs={getBehaveAs()}


@@ -6,12 +6,8 @@ import * as Sentry from "@sentry/nextjs";
import getServerSupabase from "@/lib/supabase/getServerSupabase";
import { signupFormSchema } from "@/types/auth";
import BackendAPI from "@/lib/autogpt-server-api";
import { verifyTurnstileToken } from "@/lib/turnstile";
export async function signup(
values: z.infer<typeof signupFormSchema>,
turnstileToken: string,
) {
export async function signup(values: z.infer<typeof signupFormSchema>) {
"use server";
return await Sentry.withServerActionInstrumentation(
"signup",
@@ -23,12 +19,6 @@ export async function signup(
redirect("/error");
}
// Verify Turnstile token if provided
const success = await verifyTurnstileToken(turnstileToken, "signup");
if (!success) {
return "CAPTCHA verification failed. Please try again.";
}
// We are sure that the values are of the correct type because zod validates the form
const { data, error } = await supabase.auth.signUp(values);


@@ -25,12 +25,10 @@ import {
AuthButton,
AuthBottomText,
PasswordInput,
Turnstile,
} from "@/components/auth";
import AuthFeedback from "@/components/auth/AuthFeedback";
import { signupFormSchema } from "@/types/auth";
import { getBehaveAs } from "@/lib/utils";
import { useTurnstile } from "@/hooks/useTurnstile";
export default function SignupPage() {
const { supabase, user, isUserLoading } = useSupabase();
@@ -39,12 +37,6 @@ export default function SignupPage() {
const [isLoading, setIsLoading] = useState(false);
//TODO: Remove after closed beta
const turnstile = useTurnstile({
action: "signup",
autoVerify: false,
resetOnError: true,
});
const form = useForm<z.infer<typeof signupFormSchema>>({
resolver: zodResolver(signupFormSchema),
defaultValues: {
@@ -64,28 +56,20 @@ export default function SignupPage() {
return;
}
if (!turnstile.verified) {
setFeedback("Please complete the CAPTCHA challenge.");
setIsLoading(false);
return;
}
const error = await signup(data, turnstile.token as string);
const error = await signup(data);
setIsLoading(false);
if (error) {
if (error === "user_already_exists") {
setFeedback("User with this email already exists");
turnstile.reset();
return;
} else {
setFeedback(error);
turnstile.reset();
}
return;
}
setFeedback(null);
},
[form, turnstile],
[form],
);
if (user) {
@@ -106,7 +90,7 @@ export default function SignupPage() {
}
return (
<AuthCard className="mx-auto mt-12">
<AuthCard className="mx-auto">
<AuthHeader>Create a new account</AuthHeader>
<Form {...form}>
<form onSubmit={form.handleSubmit(onSignup)}>
@@ -157,17 +141,6 @@ export default function SignupPage() {
</FormItem>
)}
/>
{/* Turnstile CAPTCHA Component */}
<Turnstile
siteKey={turnstile.siteKey}
onVerify={turnstile.handleVerify}
onExpire={turnstile.handleExpire}
onError={turnstile.handleError}
action="signup"
shouldRender={turnstile.shouldRender}
/>
<AuthButton
onClick={() => onSignup(form.getValues())}
isLoading={isLoading}
@@ -215,7 +188,6 @@ export default function SignupPage() {
</form>
</Form>
<AuthFeedback
type="signup"
message={feedback}
isError={!!feedback}
behaveAs={getBehaveAs()}


@@ -1,45 +0,0 @@
"use server";
import { revalidatePath } from "next/cache";
import BackendApi from "@/lib/autogpt-server-api";
import {
UsersBalanceHistoryResponse,
CreditTransactionType,
} from "@/lib/autogpt-server-api/types";
export async function addDollars(formData: FormData) {
const data = {
user_id: formData.get("id") as string,
amount: parseInt(formData.get("amount") as string),
comments: formData.get("comments") as string,
};
const api = new BackendApi();
const resp = await api.addUserCredits(
data.user_id,
data.amount,
data.comments,
);
console.log(resp);
revalidatePath("/admin/spending");
}
export async function getUsersTransactionHistory(
page: number = 1,
pageSize: number = 20,
search?: string,
transactionType?: CreditTransactionType,
): Promise<UsersBalanceHistoryResponse> {
const data: Record<string, any> = {
page,
page_size: pageSize,
};
if (search) {
data.search = search;
}
if (transactionType) {
data.transaction_filter = transactionType;
}
const api = new BackendApi();
const history = await api.getUsersHistory(data);
return history;
}


@@ -1,58 +0,0 @@
import { AdminUserGrantHistory } from "@/components/admin/spending/admin-grant-history-data-table";
import type { CreditTransactionType } from "@/lib/autogpt-server-api";
import { withRoleAccess } from "@/lib/withRoleAccess";
import { Suspense } from "react";
function SpendingDashboard({
searchParams,
}: {
searchParams: {
page?: string;
status?: string;
search?: string;
};
}) {
const page = searchParams.page ? Number.parseInt(searchParams.page) : 1;
const search = searchParams.search;
const status = searchParams.status as CreditTransactionType | undefined;
return (
<div className="mx-auto p-6">
<div className="flex flex-col gap-4">
<div className="flex items-center justify-between">
<div>
<h1 className="text-3xl font-bold">User Spending</h1>
<p className="text-gray-500">Manage user spending balances</p>
</div>
</div>
<Suspense
fallback={
<div className="py-10 text-center">Loading transactions...</div>
}
>
<AdminUserGrantHistory
initialPage={page}
initialStatus={status}
initialSearch={search}
/>
</Suspense>
</div>
</div>
);
}
export default async function SpendingDashboardPage({
searchParams,
}: {
searchParams: {
page?: string;
status?: string;
search?: string;
};
}) {
"use server";
const withAdminAccess = await withRoleAccess(["admin"]);
const ProtectedSpendingDashboard = await withAdminAccess(SpendingDashboard);
return <ProtectedSpendingDashboard searchParams={searchParams} />;
}

@@ -7,12 +7,27 @@ import { BackendAPIProvider } from "@/lib/autogpt-server-api/context";
import { TooltipProvider } from "@/components/ui/tooltip";
import CredentialsProvider from "@/components/integrations/credentials-provider";
import { LaunchDarklyProvider } from "@/components/feature-flag/feature-flag-provider";
import { MockClientProps } from "@/lib/autogpt-server-api/mock_client";
import OnboardingProvider from "@/components/onboarding/onboarding-provider";
export function Providers({ children, ...props }: ThemeProviderProps) {
export interface ProvidersProps extends ThemeProviderProps {
children: React.ReactNode;
useMockBackend?: boolean;
mockClientProps?: MockClientProps;
}
export function Providers({
children,
useMockBackend,
mockClientProps,
...props
}: ProvidersProps) {
return (
<NextThemesProvider {...props}>
<BackendAPIProvider>
<BackendAPIProvider
useMockBackend={useMockBackend}
mockClientProps={mockClientProps}
>
<CredentialsProvider>
<LaunchDarklyProvider>
<OnboardingProvider>

@@ -1,58 +0,0 @@
"use client";
import { downloadAsAdmin } from "@/app/(platform)/admin/marketplace/actions";
import { Button } from "@/components/ui/button";
import { ExternalLink } from "lucide-react";
import { useState } from "react";
export function DownloadAgentAdminButton({
storeListingVersionId,
}: {
storeListingVersionId: string;
}) {
const [isLoading, setIsLoading] = useState(false);
const handleDownload = async () => {
try {
setIsLoading(true);
// Call the server action to get the data
const fileData = await downloadAsAdmin(storeListingVersionId);
// Client-side download logic
const jsonData = JSON.stringify(fileData, null, 2);
const blob = new Blob([jsonData], { type: "application/json" });
// Create a temporary URL for the Blob
const url = window.URL.createObjectURL(blob);
// Create a temporary anchor element
const a = document.createElement("a");
a.href = url;
a.download = `agent_${storeListingVersionId}.json`;
// Append the anchor to the body, click it, and remove it
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
// Revoke the temporary URL
window.URL.revokeObjectURL(url);
} catch (error) {
console.error("Download failed:", error);
} finally {
setIsLoading(false);
}
};
return (
<Button
size="sm"
variant="outline"
onClick={handleDownload}
disabled={isLoading}
>
<ExternalLink className="mr-2 h-4 w-4" />
{isLoading ? "Downloading..." : "Download"}
</Button>
);
}

@@ -19,8 +19,6 @@ import {
SubmissionStatus,
} from "@/lib/autogpt-server-api/types";
import { ApproveRejectButtons } from "./approve-reject-buttons";
import { downloadAsAdmin } from "@/app/(platform)/admin/marketplace/actions";
import { DownloadAgentAdminButton } from "./download-agent-button";
// Moved the getStatusBadge function into the client component
const getStatusBadge = (status: SubmissionStatus) => {
@@ -79,11 +77,10 @@ export function ExpandableRow({
</TableCell>
<TableCell className="text-right">
<div className="flex justify-end gap-2">
{latestVersion?.store_listing_version_id && (
<DownloadAgentAdminButton
storeListingVersionId={latestVersion.store_listing_version_id}
/>
)}
<Button size="sm" variant="outline">
<ExternalLink className="mr-2 h-4 w-4" />
Builder
</Button>
{latestVersion?.status === SubmissionStatus.PENDING && (
<ApproveRejectButtons version={latestVersion} />
@@ -183,13 +180,17 @@ export function ExpandableRow({
{/* <TableCell>{version.categories.join(", ")}</TableCell> */}
<TableCell className="text-right">
<div className="flex justify-end gap-2">
{version.store_listing_version_id && (
<DownloadAgentAdminButton
storeListingVersionId={
version.store_listing_version_id
}
/>
)}
<Button
size="sm"
variant="outline"
onClick={() =>
(window.location.href = `/admin/agents/${version.store_listing_version_id}`)
}
>
<ExternalLink className="mr-2 h-4 w-4" />
Builder
</Button>
{version.status === SubmissionStatus.PENDING && (
<ApproveRejectButtons version={version} />
)}

@@ -1,139 +0,0 @@
"use client";
import { useState } from "react";
import { Button } from "@/components/ui/button";
import {
Dialog,
DialogContent,
DialogDescription,
DialogHeader,
DialogTitle,
DialogFooter,
} from "@/components/ui/dialog";
import { Label } from "@/components/ui/label";
import { Textarea } from "@/components/ui/textarea";
import { Input } from "@/components/ui/input";
import { useRouter } from "next/navigation";
import { addDollars } from "@/app/admin/spending/actions";
import useCredits from "@/hooks/useCredits";
export function AdminAddMoneyButton({
userId,
userEmail,
currentBalance,
defaultAmount,
defaultComments,
}: {
userId: string;
userEmail: string;
currentBalance: number;
defaultAmount?: number;
defaultComments?: string;
}) {
const router = useRouter();
const [isAddMoneyDialogOpen, setIsAddMoneyDialogOpen] = useState(false);
const [dollarAmount, setDollarAmount] = useState(
defaultAmount ? Math.abs(defaultAmount / 100).toFixed(2) : "1.00",
);
const { formatCredits } = useCredits();
const handleApproveSubmit = async (formData: FormData) => {
setIsAddMoneyDialogOpen(false);
try {
await addDollars(formData);
router.refresh(); // Refresh the current route
} catch (error) {
console.error("Error adding dollars:", error);
}
};
return (
<>
<Button
size="sm"
variant="default"
onClick={(e) => {
e.stopPropagation();
setIsAddMoneyDialogOpen(true);
}}
>
Add Dollars
</Button>
{/* Add $$$ Dialog */}
<Dialog
open={isAddMoneyDialogOpen}
onOpenChange={setIsAddMoneyDialogOpen}
>
<DialogContent>
<DialogHeader>
<DialogTitle>Add Dollars</DialogTitle>
<DialogDescription className="pt-2">
<div className="mb-2">
<span className="font-medium">User:</span> {userEmail}
</div>
<div>
<span className="font-medium">Current balance:</span> $
{(currentBalance / 100).toFixed(2)}
</div>
</DialogDescription>
</DialogHeader>
<form action={handleApproveSubmit}>
<input type="hidden" name="id" value={userId} />
<input
type="hidden"
name="amount"
value={Math.round(parseFloat(dollarAmount) * 100)}
/>
<div className="grid gap-4 py-4">
<div className="grid gap-2">
<Label htmlFor="dollarAmount">Amount (in dollars)</Label>
<div className="flex">
<div className="flex items-center justify-center rounded-l-md border border-r-0 bg-gray-50 px-3 text-gray-500">
$
</div>
<Input
id="dollarAmount"
type="number"
step="0.01"
min="0"
className="rounded-l-none"
value={dollarAmount}
onChange={(e) => setDollarAmount(e.target.value)}
placeholder="0.00"
/>
</div>
</div>
</div>
<div className="grid gap-4 py-4">
<div className="grid gap-2">
<Label htmlFor="comments">Comments (Optional)</Label>
<Textarea
id="comments"
name="comments"
placeholder="Why are you adding dollars?"
defaultValue={defaultComments || "We love you!"}
/>
</div>
</div>
<DialogFooter>
<Button
type="button"
variant="outline"
onClick={() => setIsAddMoneyDialogOpen(false)}
>
Cancel
</Button>
<Button type="submit">Add Dollars</Button>
</DialogFooter>
</form>
</DialogContent>
</Dialog>
</>
);
}

@@ -1,183 +0,0 @@
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
import { PaginationControls } from "../../ui/pagination-controls";
import { SearchAndFilterAdminSpending } from "./search-filter-form";
import { getUsersTransactionHistory } from "@/app/admin/spending/actions";
import { AdminAddMoneyButton } from "./add-money-button";
import { CreditTransactionType } from "@/lib/autogpt-server-api";
export async function AdminUserGrantHistory({
initialPage = 1,
initialStatus,
initialSearch,
}: {
initialPage?: number;
initialStatus?: CreditTransactionType;
initialSearch?: string;
}) {
// Server-side data fetching
const { history, pagination } = await getUsersTransactionHistory(
initialPage,
15,
initialSearch,
initialStatus,
);
// Helper function to format the amount with color based on transaction type
const formatAmount = (amount: number, type: CreditTransactionType) => {
const isPositive = type === CreditTransactionType.GRANT;
const isNeutral = type === CreditTransactionType.TOP_UP;
const color = isPositive
? "text-green-600"
: isNeutral
? "text-blue-600"
: "text-red-600";
return <span className={color}>${(Math.abs(amount) / 100).toFixed(2)}</span>;
};
// Helper function to format the transaction type with color
const formatType = (type: CreditTransactionType) => {
const isGrant = type === CreditTransactionType.GRANT;
const isPurchased = type === CreditTransactionType.TOP_UP;
const isSpent = type === CreditTransactionType.USAGE;
let displayText = type;
let bgColor = "";
if (isGrant) {
bgColor = "bg-green-100 text-green-800";
} else if (isPurchased) {
bgColor = "bg-blue-100 text-blue-800";
} else if (isSpent) {
bgColor = "bg-red-100 text-red-800";
}
return (
<span className={`rounded-full px-2 py-1 text-xs font-medium ${bgColor}`}>
{displayText.valueOf()}
</span>
);
};
// Helper function to format the date
const formatDate = (date: Date) => {
return new Intl.DateTimeFormat("en-US", {
month: "short",
day: "numeric",
year: "numeric",
hour: "numeric",
minute: "numeric",
hour12: true,
}).format(new Date(date));
};
return (
<div className="space-y-4">
<SearchAndFilterAdminSpending
initialStatus={initialStatus}
initialSearch={initialSearch}
/>
<div className="rounded-md border bg-white">
<Table>
<TableHeader className="bg-gray-50">
<TableRow>
<TableHead className="font-medium">User</TableHead>
<TableHead className="font-medium">Type</TableHead>
<TableHead className="font-medium">Date</TableHead>
<TableHead className="font-medium">Reason</TableHead>
<TableHead className="font-medium">Admin</TableHead>
<TableHead className="font-medium">Starting Balance</TableHead>
<TableHead className="font-medium">Amount</TableHead>
<TableHead className="font-medium">Ending Balance</TableHead>
{/* <TableHead className="font-medium">Current Balance</TableHead> */}
<TableHead className="text-right font-medium">Actions</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{history.length === 0 ? (
<TableRow>
<TableCell
colSpan={9}
className="py-10 text-center text-gray-500"
>
No transactions found
</TableCell>
</TableRow>
) : (
history.map((transaction) => (
<TableRow
key={transaction.user_id}
className="hover:bg-gray-50"
>
<TableCell className="font-medium">
{transaction.user_email}
</TableCell>
<TableCell>
{formatType(transaction.transaction_type)}
</TableCell>
<TableCell className="text-gray-600">
{formatDate(transaction.transaction_time)}
</TableCell>
<TableCell>{transaction.reason}</TableCell>
<TableCell className="text-gray-600">
{transaction.admin_email}
</TableCell>
<TableCell className="font-medium text-green-600">
${(transaction.running_balance - transaction.amount) / 100}
</TableCell>
<TableCell>
{formatAmount(
transaction.amount,
transaction.transaction_type,
)}
</TableCell>
<TableCell className="font-medium text-green-600">
${transaction.running_balance / 100}
</TableCell>
{/* <TableCell className="font-medium text-green-600">
${transaction.current_balance / 100}
</TableCell> */}
<TableCell className="text-right">
<AdminAddMoneyButton
userId={transaction.user_id}
userEmail={
transaction.user_email ?? "User Email wasn't attached"
}
currentBalance={transaction.current_balance}
defaultAmount={
transaction.transaction_type ===
CreditTransactionType.USAGE
? -transaction.amount
: undefined
}
defaultComments={
transaction.transaction_type ===
CreditTransactionType.USAGE
? "Refund for usage"
: undefined
}
/>
</TableCell>
</TableRow>
))
)}
</TableBody>
</Table>
</div>
<PaginationControls
currentPage={initialPage}
totalPages={pagination.total_pages}
/>
</div>
);
}

@@ -1,105 +0,0 @@
"use client";
import { useState, useEffect } from "react";
import { useRouter, usePathname, useSearchParams } from "next/navigation";
import { Input } from "@/components/ui/input";
import { Button } from "@/components/ui/button";
import { Search } from "lucide-react";
import { CreditTransactionType } from "@/lib/autogpt-server-api";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
export function SearchAndFilterAdminSpending({
initialStatus,
initialSearch,
}: {
initialStatus?: CreditTransactionType;
initialSearch?: string;
}) {
const router = useRouter();
const pathname = usePathname();
const searchParams = useSearchParams();
// Initialize state from URL parameters
const [searchQuery, setSearchQuery] = useState(initialSearch || "");
const [selectedStatus, setSelectedStatus] = useState<string>(
searchParams.get("status") || "ALL",
);
// Update local state when URL parameters change
useEffect(() => {
const status = searchParams.get("status");
setSelectedStatus(status || "ALL");
setSearchQuery(searchParams.get("search") || "");
}, [searchParams]);
const handleSearch = () => {
const params = new URLSearchParams(searchParams.toString());
if (searchQuery) {
params.set("search", searchQuery);
} else {
params.delete("search");
}
if (selectedStatus !== "ALL") {
params.set("status", selectedStatus);
} else {
params.delete("status");
}
params.set("page", "1"); // Reset to first page on new search
router.push(`${pathname}?${params.toString()}`);
};
return (
<div className="flex items-center justify-between">
<div className="flex w-full items-center gap-2">
<Input
placeholder="Search users by Name or Email..."
value={searchQuery}
onChange={(e) => setSearchQuery(e.target.value)}
onKeyDown={(e) => e.key === "Enter" && handleSearch()}
/>
<Button variant="outline" onClick={handleSearch}>
<Search className="h-4 w-4" />
</Button>
</div>
<Select
value={selectedStatus}
onValueChange={(value) => {
setSelectedStatus(value);
const params = new URLSearchParams(searchParams.toString());
if (value === "ALL") {
params.delete("status");
} else {
params.set("status", value);
}
params.set("page", "1");
router.push(`${pathname}?${params.toString()}`);
}}
>
<SelectTrigger className="w-1/4">
<SelectValue placeholder="Select Status" />
</SelectTrigger>
<SelectContent>
<SelectItem value="ALL">All</SelectItem>
<SelectItem value={CreditTransactionType.TOP_UP}>Top Up</SelectItem>
<SelectItem value={CreditTransactionType.USAGE}>Usage</SelectItem>
<SelectItem value={CreditTransactionType.REFUND}>Refund</SelectItem>
<SelectItem value={CreditTransactionType.GRANT}>Grant</SelectItem>
<SelectItem value={CreditTransactionType.CARD_CHECK}>
Card Check
</SelectItem>
</SelectContent>
</Select>
</div>
);
}

@@ -239,10 +239,7 @@ export default function AgentRunDetailsView({
{title || key}
</label>
{values.map((value, i) => (
<p
className="resize-none whitespace-pre-wrap break-words border-none text-sm text-neutral-700 disabled:cursor-not-allowed"
key={i}
>
<p className="text-sm text-neutral-700" key={i}>
{value}
</p>
))}

@@ -27,8 +27,6 @@ type Story = StoryObj<typeof meta>;
export const Default: Story = {
args: {
user: null,
libraryAgent: null,
name: "AI Video Generator",
storeListingVersionId: "123",
creator: "Toran Richards",

@@ -1,19 +1,17 @@
"use client";
import { StarRatingIcons } from "@/components/ui/icons";
import * as React from "react";
import { IconPlay, StarRatingIcons } from "@/components/ui/icons";
import { Separator } from "@/components/ui/separator";
import BackendAPI, { LibraryAgent } from "@/lib/autogpt-server-api";
import BackendAPI from "@/lib/autogpt-server-api";
import { useRouter } from "next/navigation";
import Link from "next/link";
import { useToast } from "@/components/ui/use-toast";
import useSupabase from "@/hooks/useSupabase";
import { DownloadIcon, LoaderIcon } from "lucide-react";
import { useOnboarding } from "../onboarding/onboarding-provider";
import { User } from "@supabase/supabase-js";
import { cn } from "@/lib/utils";
import { FC, useCallback, useMemo, useState } from "react";
interface AgentInfoProps {
user: User | null;
name: string;
creator: string;
shortDescription: string;
@@ -24,11 +22,9 @@ interface AgentInfoProps {
lastUpdated: string;
version: string;
storeListingVersionId: string;
libraryAgent: LibraryAgent | null;
}
export const AgentInfo: FC<AgentInfoProps> = ({
user,
export const AgentInfo: React.FC<AgentInfoProps> = ({
name,
creator,
shortDescription,
@@ -39,48 +35,28 @@ export const AgentInfo: FC<AgentInfoProps> = ({
lastUpdated,
version,
storeListingVersionId,
libraryAgent,
}) => {
const router = useRouter();
const api = useMemo(() => new BackendAPI(), []);
const api = React.useMemo(() => new BackendAPI(), []);
const { user } = useSupabase();
const { toast } = useToast();
const { completeStep } = useOnboarding();
const [adding, setAdding] = useState(false);
const [downloading, setDownloading] = useState(false);
const libraryAction = useCallback(async () => {
setAdding(true);
if (libraryAgent) {
toast({
description: "Redirecting to your library...",
duration: 2000,
});
// Redirect to the library agent page
router.push(`/library/agents/${libraryAgent.id}`);
return;
}
const [downloading, setDownloading] = React.useState(false);
const handleAddToLibrary = async () => {
try {
const newLibraryAgent = await api.addMarketplaceAgentToLibrary(
storeListingVersionId,
);
completeStep("MARKETPLACE_ADD_AGENT");
router.push(`/library/agents/${newLibraryAgent.id}`);
toast({
title: "Agent Added",
description: "Redirecting to your library...",
duration: 2000,
});
} catch (error) {
console.error("Failed to add agent to library:", error);
toast({
title: "Error",
description: "Failed to add agent to library. Please try again.",
variant: "destructive",
});
}
}, [toast, api, storeListingVersionId, completeStep, router]);
};
const handleDownload = useCallback(async () => {
const handleDownloadToLibrary = async () => {
const downloadAgent = async (): Promise<void> => {
setDownloading(true);
try {
@@ -113,16 +89,12 @@ export const AgentInfo: FC<AgentInfoProps> = ({
});
} catch (error) {
console.error(`Error downloading agent:`, error);
toast({
title: "Error",
description: "Failed to download agent. Please try again.",
variant: "destructive",
});
throw error;
}
};
await downloadAgent();
setDownloading(false);
}, [setDownloading, api, storeListingVersionId, toast]);
};
return (
<div className="w-full max-w-[396px] px-4 sm:px-6 lg:w-[396px] lg:px-0">
@@ -133,61 +105,65 @@ export const AgentInfo: FC<AgentInfoProps> = ({
{/* Creator */}
<div className="mb-3 flex w-full items-center gap-1.5 lg:mb-4">
<div className="font-sans text-base font-normal text-neutral-800 dark:text-neutral-200 sm:text-lg lg:text-xl">
<div className="font-geist text-base font-normal text-neutral-800 dark:text-neutral-200 sm:text-lg lg:text-xl">
by
</div>
<Link
href={`/marketplace/creator/${encodeURIComponent(creator)}`}
className="font-sans text-base font-medium text-neutral-800 hover:underline dark:text-neutral-200 sm:text-lg lg:text-xl"
className="font-geist text-base font-medium text-neutral-800 hover:underline dark:text-neutral-200 sm:text-lg lg:text-xl"
>
{creator}
</Link>
</div>
{/* Short Description */}
<div className="mb-4 line-clamp-2 w-full font-sans text-base font-normal leading-normal text-neutral-600 dark:text-neutral-300 sm:text-lg lg:mb-6 lg:text-xl lg:leading-7">
<div className="font-geist mb-4 line-clamp-2 w-full text-base font-normal leading-normal text-neutral-600 dark:text-neutral-300 sm:text-lg lg:mb-6 lg:text-xl lg:leading-7">
{shortDescription}
</div>
{/* Buttons */}
<div className="mb-4 flex w-full gap-3 lg:mb-[60px]">
{user && (
{/* Run Agent Button */}
<div className="mb-4 w-full lg:mb-[60px]">
{user ? (
<button
className={cn(
"inline-flex min-w-24 items-center justify-center rounded-full bg-violet-600 px-4 py-3",
"transition-colors duration-200 hover:bg-violet-500 disabled:bg-zinc-400",
)}
onClick={libraryAction}
disabled={adding}
onClick={handleAddToLibrary}
className="inline-flex w-full items-center justify-center gap-2 rounded-[38px] bg-violet-600 px-4 py-3 transition-colors hover:bg-violet-700 sm:w-auto sm:gap-2.5 sm:px-5 sm:py-3.5 lg:px-6 lg:py-4"
>
<span className="justify-start font-sans text-sm font-medium leading-snug text-primary-foreground">
{libraryAgent ? "See runs" : "Add to library"}
<IconPlay className="h-5 w-5 text-white sm:h-5 sm:w-5 lg:h-6 lg:w-6" />
<span className="font-poppins text-base font-medium text-neutral-50 sm:text-lg">
Add To Library
</span>
</button>
) : (
<button
onClick={handleDownloadToLibrary}
className={`inline-flex w-full items-center justify-center gap-2 rounded-[38px] px-4 py-3 transition-colors sm:w-auto sm:gap-2.5 sm:px-5 sm:py-3.5 lg:px-6 lg:py-4 ${
downloading
? "bg-neutral-400"
: "bg-violet-600 hover:bg-violet-700"
}`}
disabled={downloading}
>
{downloading ? (
<LoaderIcon className="h-5 w-5 animate-spin text-white sm:h-5 sm:w-5 lg:h-6 lg:w-6" />
) : (
<DownloadIcon className="h-5 w-5 text-white sm:h-5 sm:w-5 lg:h-6 lg:w-6" />
)}
<span className="font-poppins text-base font-medium text-neutral-50 sm:text-lg">
{downloading ? "Downloading..." : "Download Agent as File"}
</span>
</button>
)}
<button
className={cn(
"inline-flex min-w-24 items-center justify-center rounded-full bg-zinc-200 px-4 py-3",
"transition-colors duration-200 hover:bg-zinc-200/70 disabled:bg-zinc-200/40",
)}
onClick={handleDownload}
disabled={downloading}
>
<div className="justify-start text-center font-sans text-sm font-medium leading-snug text-zinc-800">
Download agent
</div>
</button>
</div>
{/* Rating and Runs */}
<div className="mb-4 flex w-full items-center justify-between lg:mb-[44px]">
<div className="flex items-center gap-1.5 sm:gap-2">
<span className="whitespace-nowrap font-sans text-base font-semibold text-neutral-800 dark:text-neutral-200 sm:text-lg">
<span className="font-geist whitespace-nowrap text-base font-semibold text-neutral-800 dark:text-neutral-200 sm:text-lg">
{rating.toFixed(1)}
</span>
<div className="flex gap-0.5">{StarRatingIcons(rating)}</div>
</div>
<div className="whitespace-nowrap font-sans text-base font-semibold text-neutral-800 dark:text-neutral-200 sm:text-lg">
<div className="font-geist whitespace-nowrap text-base font-semibold text-neutral-800 dark:text-neutral-200 sm:text-lg">
{runs.toLocaleString()} runs
</div>
</div>
@@ -207,14 +183,14 @@ export const AgentInfo: FC<AgentInfoProps> = ({
{/* Categories */}
<div className="mb-4 flex w-full flex-col gap-1.5 sm:gap-2 lg:mb-[36px]">
<div className="decoration-skip-ink-none mb-1.5 font-sans text-base font-medium leading-6 text-neutral-800 dark:text-neutral-200 sm:mb-2">
<div className="font-geist decoration-skip-ink-none mb-1.5 text-base font-medium leading-6 text-neutral-800 dark:text-neutral-200 sm:mb-2">
Categories
</div>
<div className="flex flex-wrap gap-1.5 sm:gap-2">
{categories.map((category, index) => (
<div
key={index}
className="decoration-skip-ink-none whitespace-nowrap rounded-full border border-neutral-600 bg-white px-2 py-0.5 font-sans text-base font-normal leading-6 text-neutral-800 underline-offset-[from-font] dark:border-neutral-700 dark:bg-neutral-800 dark:text-neutral-200 sm:px-[16px] sm:py-[10px]"
className="font-geist decoration-skip-ink-none whitespace-nowrap rounded-full border border-neutral-600 bg-white px-2 py-0.5 text-base font-normal leading-6 text-neutral-800 underline-offset-[from-font] dark:border-neutral-700 dark:bg-neutral-800 dark:text-neutral-200 sm:px-[16px] sm:py-[10px]"
>
{category}
</div>
@@ -224,10 +200,10 @@ export const AgentInfo: FC<AgentInfoProps> = ({
{/* Version History */}
<div className="flex w-full flex-col gap-0.5 sm:gap-1">
<div className="decoration-skip-ink-none mb-1.5 font-sans text-base font-medium leading-6 text-neutral-800 dark:text-neutral-200 sm:mb-2">
<div className="font-geist decoration-skip-ink-none mb-1.5 text-base font-medium leading-6 text-neutral-800 dark:text-neutral-200 sm:mb-2">
Version history
</div>
<div className="decoration-skip-ink-none font-sans text-base font-normal leading-6 text-neutral-600 underline-offset-[from-font] dark:text-neutral-400">
<div className="font-geist decoration-skip-ink-none text-base font-normal leading-6 text-neutral-600 underline-offset-[from-font] dark:text-neutral-400">
Last updated {lastUpdated}
</div>
<div className="text-xs text-neutral-600 dark:text-neutral-400 sm:text-sm">

@@ -15,8 +15,8 @@ export interface AgentTableCardProps {
imageSrc: string[];
dateSubmitted: string;
status: StatusType;
runs: number;
rating: number;
runs?: number;
rating?: number;
id: number;
onEditSubmission: (submission: StoreSubmissionRequest) => void;
}
@@ -82,11 +82,11 @@ export const AgentTableCard: React.FC<AgentTableCardProps> = ({
{dateSubmitted}
</div>
<div className="text-sm text-neutral-600 dark:text-neutral-400">
{runs.toLocaleString()} runs
{runs ? runs.toLocaleString() : "N/A"} runs
</div>
<div className="flex items-center gap-1">
<span className="text-sm font-medium text-neutral-800 dark:text-neutral-200">
{rating.toFixed(1)}
{rating ? rating.toFixed(1) : "N/A"}
</span>
<IconStarFilled className="h-4 w-4 text-neutral-800 dark:text-neutral-200" />
</div>

@@ -3,8 +3,6 @@ import { Navbar } from "./Navbar";
import { userEvent, within } from "@storybook/test";
import { IconType } from "../ui/icons";
import { ProfileDetails } from "@/lib/autogpt-server-api/types";
// You can't import this here, jest is not available in storybook and will crash it
// import { jest } from "@jest/globals";
// Mock the API responses
const mockProfileData: ProfileDetails = {
@@ -15,40 +13,6 @@ const mockProfileData: ProfileDetails = {
avatar_url: "https://avatars.githubusercontent.com/u/123456789?v=4",
};
const mockCreditData = {
credits: 1500,
};
// Mock the API module
// jest.mock("@/lib/autogpt-server-api", () => {
// return function () {
// return {
// getStoreProfile: () => Promise.resolve(mockProfileData),
// getUserCredit: () => Promise.resolve(mockCreditData),
// };
// };
// });
const meta = {
title: "AGPT UI/Navbar",
component: Navbar,
parameters: {
layout: "fullscreen",
},
tags: ["autodocs"],
argTypes: {
// isLoggedIn: { control: "boolean" },
// avatarSrc: { control: "text" },
links: { control: "object" },
// activeLink: { control: "text" },
menuItemGroups: { control: "object" },
// params: { control: { type: "object", defaultValue: { lang: "en" } } },
},
} satisfies Meta<typeof Navbar>;
export default meta;
type Story = StoryObj<typeof meta>;
const defaultMenuItemGroups = [
{
items: [
@@ -89,35 +53,83 @@ const defaultLinks = [
{ name: "Build", href: "/builder" },
];
const meta = {
title: "AGPT UI/Navbar",
component: Navbar,
parameters: {
layout: "fullscreen",
},
tags: ["autodocs"],
argTypes: {
links: { control: "object" },
menuItemGroups: { control: "object" },
mockUser: { control: "object" },
mockClientProps: { control: "object" },
},
} satisfies Meta<typeof Navbar>;
export default meta;
type Story = StoryObj<typeof meta>;
export const Default: Story = {
args: {
// params: { lang: "en" },
// isLoggedIn: true,
links: defaultLinks,
// activeLink: "/marketplace",
// avatarSrc: mockProfileData.avatar_url,
menuItemGroups: defaultMenuItemGroups,
mockUser: {
id: "123",
email: "test@test.com",
user_metadata: {
name: "Test User",
},
app_metadata: {
provider: "email",
},
aud: "test",
created_at: new Date().toISOString(),
},
mockClientProps: {
credits: 1500,
profile: mockProfileData,
},
},
parameters: {
mockBackend: {
credits: 1500,
profile: mockProfileData,
},
},
};
export const WithActiveLink: Story = {
export const WithCredits: Story = {
args: {
...Default.args,
// activeLink: "/library",
},
parameters: {
mockBackend: {
credits: 1500,
},
},
};
export const LongUserName: Story = {
export const WithLargeCredits: Story = {
args: {
...Default.args,
// avatarSrc: "https://avatars.githubusercontent.com/u/987654321?v=4",
},
parameters: {
mockBackend: {
credits: 999999,
},
},
};
export const NoAvatar: Story = {
export const WithZeroCredits: Story = {
args: {
...Default.args,
// avatarSrc: undefined,
},
parameters: {
mockBackend: {
credits: 0,
},
},
};
@@ -125,6 +137,12 @@ export const WithInteraction: Story = {
args: {
...Default.args,
},
parameters: {
mockBackend: {
credits: 1500,
profile: mockProfileData,
},
},
play: async ({ canvasElement }) => {
const canvas = within(canvasElement);
const profileTrigger = canvas.getByRole("button");
@@ -135,29 +153,3 @@ export const WithInteraction: Story = {
await canvas.findByText("Edit profile");
},
};
export const NotLoggedIn: Story = {
args: {
...Default.args,
// isLoggedIn: false,
// avatarSrc: undefined,
},
};
export const WithCredits: Story = {
args: {
...Default.args,
},
};
export const WithLargeCredits: Story = {
args: {
...Default.args,
},
};
export const WithZeroCredits: Story = {
args: {
...Default.args,
},
};

@@ -9,6 +9,10 @@ import { ProfileDetails } from "@/lib/autogpt-server-api/types";
import { NavbarLink } from "./NavbarLink";
import getServerUser from "@/lib/supabase/getServerUser";
import BackendAPI from "@/lib/autogpt-server-api";
import { User } from "@supabase/supabase-js";
import MockClient, {
MockClientProps,
} from "@/lib/autogpt-server-api/mock_client";
// Disable theme toggle for now
// import { ThemeToggle } from "./ThemeToggle";
@@ -29,21 +33,33 @@ interface NavbarProps {
onClick?: () => void;
}[];
}[];
mockUser?: User;
mockClientProps?: MockClientProps;
}
async function getProfileData() {
async function getProfileData(mockClientProps?: MockClientProps) {
if (mockClientProps) {
const api = new MockClient(mockClientProps);
const profile = await api.getStoreProfile("navbar");
return profile;
}
const api = new BackendAPI();
const profile = await api.getStoreProfile();
return profile;
}
export const Navbar = async ({ links, menuItemGroups }: NavbarProps) => {
const { user } = await getServerUser();
export const Navbar = async ({
links,
menuItemGroups,
mockUser,
mockClientProps,
}: NavbarProps) => {
const { user } = await getServerUser(mockUser);
const isLoggedIn = user !== null;
let profile: ProfileDetails | null = null;
if (isLoggedIn) {
profile = await getProfileData();
profile = await getProfileData(mockClientProps);
}
return (

@@ -1,43 +0,0 @@
"use client";
import { cn } from "@/lib/utils";
import Image from "next/image";
import { useState } from "react";
interface SmartImageProps {
src?: string | null;
alt: string;
imageContain?: boolean;
className?: string;
}
export default function SmartImage({
src,
alt,
imageContain,
className,
}: SmartImageProps) {
const [isLoading, setIsLoading] = useState(true);
const shouldShowSkeleton = isLoading || !src;
return (
<div className={cn("relative overflow-hidden", className)}>
{src && (
<Image
src={src}
alt={alt}
fill
onLoad={() => setIsLoading(false)}
className={cn(
"rounded-inherit object-center transition-opacity duration-300",
isLoading ? "opacity-0" : "opacity-100",
imageContain ? "object-contain" : "object-cover",
)}
sizes="100%"
/>
)}
{shouldShowSkeleton && (
<div className="rounded-inherit absolute inset-0 animate-pulse bg-gray-300 dark:bg-gray-700" />
)}
</div>
);
}

@@ -11,7 +11,7 @@ import { PopoverClose } from "@radix-ui/react-popover";
import { TaskGroups } from "../onboarding/WalletTaskGroups";
import { ScrollArea } from "../ui/scroll-area";
import { useOnboarding } from "../onboarding/onboarding-provider";
import { useCallback, useEffect, useRef, useState } from "react";
import { useCallback, useEffect, useRef } from "react";
import { cn } from "@/lib/utils";
import * as party from "party-js";
import WalletRefill from "./WalletRefill";
@@ -21,11 +21,6 @@ export default function Wallet() {
fetchInitialCredits: true,
});
const { state, updateState } = useOnboarding();
const [prevCredits, setPrevCredits] = useState<number | null>(credits);
const [flash, setFlash] = useState(false);
const [stepsLength, setStepsLength] = useState<number | null>(
state?.completedSteps?.length || null,
);
const walletRef = useRef<HTMLButtonElement | null>(null);
const onWalletOpen = useCallback(async () => {
@@ -42,81 +37,49 @@ export default function Wallet() {
.through("lifetime")
.build();
// Confetti effect on the wallet button
useEffect(() => {
if (!state?.completedSteps) {
return;
}
// If we haven't set the length yet, just set it and return
if (stepsLength === null) {
setStepsLength(state?.completedSteps?.length);
return;
}
// It's enough to compare array lengths,
// because the order of completed steps is not important
// If the length is the same, we don't need to do anything
if (state?.completedSteps?.length === stepsLength) {
return;
}
// Otherwise, we need to set the new length
setStepsLength(state?.completedSteps?.length);
// And make confetti
if (walletRef.current) {
setTimeout(() => {
fetchCredits();
party.confetti(walletRef.current!, {
count: 30,
spread: 120,
shapes: ["square", "circle"],
size: party.variation.range(1, 2),
speed: party.variation.range(200, 300),
modules: [fadeOut],
});
}, 800);
// Check if there are any completed tasks (state?.completedTasks) that
// are not in the state?.notified array and play confetti if so
const pending = state?.completedSteps
.filter((step) => !state?.notified.includes(step))
// Ignore steps that are not relevant for notifications
.filter(
(step) =>
step !== "WELCOME" &&
step !== "USAGE_REASON" &&
step !== "INTEGRATIONS" &&
step !== "AGENT_CHOICE" &&
step !== "AGENT_NEW_RUN" &&
step !== "AGENT_INPUT",
);
if ((pending?.length || 0) > 0 && walletRef.current) {
party.confetti(walletRef.current, {
count: 30,
spread: 120,
shapes: ["square", "circle"],
size: party.variation.range(1, 2),
speed: party.variation.range(200, 300),
modules: [fadeOut],
});
}
}, [state?.completedSteps, state?.notified]);
// Wallet flash on credits change
useEffect(() => {
if (credits === prevCredits) {
return;
}
setPrevCredits(credits);
if (prevCredits === null) {
return;
}
setFlash(true);
setTimeout(() => {
setFlash(false);
}, 300);
}, [credits]);
return (
<Popover>
<PopoverTrigger asChild>
<div className="relative inline-block">
<button
ref={walletRef}
className={cn(
"relative flex items-center gap-1 rounded-md bg-zinc-200 px-3 py-2 text-sm transition-colors duration-200 hover:bg-zinc-300",
)}
onClick={onWalletOpen}
>
Wallet{" "}
<span className="text-sm font-semibold">
{formatCredits(credits)}
</span>
{state?.notificationDot && (
<span className="absolute right-1 top-1 h-2 w-2 rounded-full bg-violet-600"></span>
)}
</button>
<div
className={cn(
"pointer-events-none absolute inset-0 rounded-md bg-violet-400 duration-2000 ease-in-out",
flash ? "opacity-50 duration-0" : "opacity-0",
)}
/>
</div>
<button
ref={walletRef}
className="relative flex items-center gap-1 rounded-md bg-zinc-200 px-3 py-2 text-sm transition-colors duration-200 hover:bg-zinc-300"
onClick={onWalletOpen}
>
Wallet{" "}
<span className="text-sm font-semibold">
{formatCredits(credits)}
</span>
{state?.notificationDot && (
<span className="absolute right-1 top-1 h-2 w-2 rounded-full bg-violet-600"></span>
)}
</button>
</PopoverTrigger>
<PopoverContent
className={cn(
@@ -141,9 +104,7 @@ export default function Wallet() {
</div>
<ScrollArea className="max-h-[85vh] overflow-y-auto">
{/* Top ups */}
{process.env.NEXT_PUBLIC_SHOW_BILLING_PAGE === "true" && (
<WalletRefill />
)}
<WalletRefill />
{/* Tasks */}
<p className="mx-1 mt-4 font-sans text-xs font-medium text-violet-700">
Onboarding tasks
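The rewritten confetti effect above fires only for completed steps the user has not yet been notified about, ignoring the early onboarding steps. That filter can be sketched as a pure function — the step names are copied from the diff, but the function name is mine:

```typescript
// Steps that never trigger a confetti notification (from the diff above).
const IGNORED_STEPS = [
  "WELCOME",
  "USAGE_REASON",
  "INTEGRATIONS",
  "AGENT_CHOICE",
  "AGENT_NEW_RUN",
  "AGENT_INPUT",
];

// Completed steps the user has not been notified about yet.
function pendingNotifications(
  completedSteps: string[],
  notified: string[],
): string[] {
  return completedSteps
    .filter((step) => !notified.includes(step))
    .filter((step) => !IGNORED_STEPS.includes(step));
}
```

Comparing membership against `notified` avoids the `stepsLength` bookkeeping state that the earlier hunks remove.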


@@ -22,7 +22,6 @@ export default function ActionButtonGroup({
<Button
key={i}
variant={action.variant ?? "outline"}
disabled={action.disabled}
onClick={action.callback}
>
{action.label}
@@ -30,11 +29,7 @@ export default function ActionButtonGroup({
) : (
<Link
key={i}
className={cn(
buttonVariants({ variant: action.variant }),
action.disabled &&
"pointer-events-none border-zinc-400 text-zinc-400",
)}
className={buttonVariants({ variant: action.variant })}
href={action.href}
>
{action.label}

Some files were not shown because too many files have changed in this diff.