Merge branch 'dev' into zamilmajdy/secrt-1222-move-scheduler-into-a-singleton

This commit is contained in:
Swifty
2025-04-14 17:16:38 +02:00
committed by GitHub
61 changed files with 1549 additions and 808 deletions

View File

@@ -1,8 +1,6 @@
import logging
from typing import Any
from autogpt_libs.utils.cache import thread_cached
from backend.data.block import (
Block,
BlockCategory,
@@ -19,21 +17,6 @@ from backend.util import json
logger = logging.getLogger(__name__)
@thread_cached
def get_executor_manager_client():
from backend.executor import ExecutionManager
from backend.util.service import get_service_client
return get_service_client(ExecutionManager)
@thread_cached
def get_event_bus():
from backend.data.execution import RedisExecutionEventBus
return RedisExecutionEventBus()
class AgentExecutorBlock(Block):
class Input(BlockSchema):
user_id: str = SchemaField(description="User ID")
@@ -76,11 +59,11 @@ class AgentExecutorBlock(Block):
def run(self, input_data: Input, **kwargs) -> BlockOutput:
from backend.data.execution import ExecutionEventType
from backend.executor import utils as execution_utils
executor_manager = get_executor_manager_client()
event_bus = get_event_bus()
event_bus = execution_utils.get_execution_event_bus()
graph_exec = executor_manager.add_execution(
graph_exec = execution_utils.add_graph_execution(
graph_id=input_data.graph_id,
graph_version=input_data.graph_version,
user_id=input_data.user_id,

View File

@@ -3,7 +3,6 @@ import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timezone
from typing import cast
import stripe
from autogpt_libs.utils.cache import thread_cached
@@ -12,6 +11,7 @@ from prisma.enums import (
CreditRefundRequestStatus,
CreditTransactionType,
NotificationType,
OnboardingStep,
)
from prisma.errors import UniqueViolationError
from prisma.models import CreditRefundRequest, CreditTransaction, User
@@ -19,7 +19,6 @@ from prisma.types import (
CreditRefundRequestCreateInput,
CreditTransactionCreateInput,
CreditTransactionWhereInput,
IntFilter,
)
from tenacity import retry, stop_after_attempt, wait_exponential
@@ -123,6 +122,18 @@ class UserCreditBase(ABC):
"""
pass
@abstractmethod
async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
"""
Reward the user with credits for completing an onboarding step.
Won't reward if the user has already received credits for the step.
Args:
user_id (str): The user ID.
step (OnboardingStep): The onboarding step.
"""
pass
@abstractmethod
async def top_up_intent(self, user_id: str, amount: int) -> str:
"""
@@ -215,7 +226,7 @@ class UserCreditBase(ABC):
"userId": user_id,
"createdAt": {"lte": top_time},
"isActive": True,
"runningBalance": cast(IntFilter, {"not": None}),
"NOT": [{"runningBalance": None}],
},
order={"createdAt": "desc"},
)
@@ -410,6 +421,24 @@ class UserCredit(UserCreditBase):
async def top_up_credits(self, user_id: str, amount: int):
await self._top_up_credits(user_id, amount)
async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
key = f"REWARD-{user_id}-{step.value}"
if not await CreditTransaction.prisma().find_first(
where={
"userId": user_id,
"transactionKey": key,
}
):
await self._add_transaction(
user_id=user_id,
amount=credits,
transaction_type=CreditTransactionType.GRANT,
transaction_key=key,
metadata=Json(
{"reason": f"Reward for completing {step.value} onboarding step."}
),
)
async def top_up_refund(
self, user_id: str, transaction_key: str, metadata: dict[str, str]
) -> int:
@@ -897,6 +926,9 @@ class DisabledUserCredit(UserCreditBase):
async def top_up_credits(self, *args, **kwargs):
pass
async def onboarding_reward(self, *args, **kwargs):
pass
async def top_up_intent(self, *args, **kwargs) -> str:
return ""

View File

@@ -89,7 +89,7 @@ async def transaction():
async def locked_transaction(key: str):
lock_key = zlib.crc32(key.encode("utf-8"))
async with transaction() as tx:
await tx.execute_raw(f"SELECT pg_advisory_xact_lock({lock_key})")
await tx.execute_raw("SELECT pg_advisory_xact_lock($1)", lock_key)
yield tx

View File

@@ -34,11 +34,10 @@ from pydantic import BaseModel
from pydantic.fields import Field
from backend.server.v2.store.exceptions import DatabaseError
from backend.util import mock
from backend.util import type as type_utils
from backend.util.settings import Config
from .block import BlockData, BlockInput, BlockType, CompletedBlockOutput, get_block
from .block import BlockInput, BlockType, CompletedBlockOutput, get_block
from .db import BaseDbModel
from .includes import (
EXECUTION_RESULT_INCLUDE,
@@ -203,6 +202,26 @@ class GraphExecutionWithNodes(GraphExecution):
node_executions=node_executions,
)
def to_graph_execution_entry(self):
return GraphExecutionEntry(
user_id=self.user_id,
graph_id=self.graph_id,
graph_version=self.graph_version or 0,
graph_exec_id=self.id,
start_node_execs=[
NodeExecutionEntry(
user_id=self.user_id,
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
node_exec_id=node_exec.node_exec_id,
node_id=node_exec.node_id,
block_id=node_exec.block_id,
data=node_exec.input_data,
)
for node_exec in self.node_executions
],
)
class NodeExecutionResult(BaseModel):
user_id: str
@@ -469,19 +488,27 @@ async def upsert_execution_output(
)
async def update_graph_execution_start_time(graph_exec_id: str) -> GraphExecution:
res = await AgentGraphExecution.prisma().update(
where={"id": graph_exec_id},
async def update_graph_execution_start_time(
graph_exec_id: str,
) -> GraphExecution | None:
count = await AgentGraphExecution.prisma().update_many(
where={
"id": graph_exec_id,
"executionStatus": ExecutionStatus.QUEUED,
},
data={
"executionStatus": ExecutionStatus.RUNNING,
"startedAt": datetime.now(tz=timezone.utc),
},
)
if count == 0:
return None
res = await AgentGraphExecution.prisma().find_unique(
where={"id": graph_exec_id},
include=GRAPH_EXECUTION_INCLUDE,
)
if not res:
raise ValueError(f"Graph execution #{graph_exec_id} not found")
return GraphExecution.from_db(res)
return GraphExecution.from_db(res) if res else None
async def update_graph_execution_stats(
@@ -492,7 +519,8 @@ async def update_graph_execution_stats(
data = stats.model_dump() if stats else {}
if isinstance(data.get("error"), Exception):
data["error"] = str(data["error"])
res = await AgentGraphExecution.prisma().update(
updated_count = await AgentGraphExecution.prisma().update_many(
where={
"id": graph_exec_id,
"OR": [
@@ -504,10 +532,15 @@ async def update_graph_execution_stats(
"executionStatus": status,
"stats": Json(data),
},
)
if updated_count == 0:
return None
graph_exec = await AgentGraphExecution.prisma().find_unique_or_raise(
where={"id": graph_exec_id},
include=GRAPH_EXECUTION_INCLUDE,
)
return GraphExecution.from_db(res) if res else None
return GraphExecution.from_db(graph_exec)
async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionStats):
@@ -643,7 +676,7 @@ async def get_latest_node_execution(
where={
"agentNodeId": node_id,
"agentGraphExecutionId": graph_eid,
"executionStatus": {"not": ExecutionStatus.INCOMPLETE}, # type: ignore
"NOT": [{"executionStatus": ExecutionStatus.INCOMPLETE}],
},
order=[
{"queuedTime": "desc"},
@@ -711,144 +744,6 @@ class ExecutionQueue(Generic[T]):
return self.queue.empty()
# ------------------- Execution Utilities -------------------- #
LIST_SPLIT = "_$_"
DICT_SPLIT = "_#_"
OBJC_SPLIT = "_@_"
def parse_execution_output(output: BlockData, name: str) -> Any | None:
"""
Extracts partial output data by name from a given BlockData.
The function supports extracting data from lists, dictionaries, and objects
using specific naming conventions:
- For lists: <output_name>_$_<index>
- For dictionaries: <output_name>_#_<key>
- For objects: <output_name>_@_<attribute>
Args:
output (BlockData): A tuple containing the output name and data.
name (str): The name used to extract specific data from the output.
Returns:
Any | None: The extracted data if found, otherwise None.
Examples:
>>> output = ("result", [10, 20, 30])
>>> parse_execution_output(output, "result_$_1")
20
>>> output = ("config", {"key1": "value1", "key2": "value2"})
>>> parse_execution_output(output, "config_#_key1")
'value1'
>>> class Sample:
... attr1 = "value1"
... attr2 = "value2"
>>> output = ("object", Sample())
>>> parse_execution_output(output, "object_@_attr1")
'value1'
"""
output_name, output_data = output
if name == output_name:
return output_data
if name.startswith(f"{output_name}{LIST_SPLIT}"):
index = int(name.split(LIST_SPLIT)[1])
if not isinstance(output_data, list) or len(output_data) <= index:
return None
return output_data[int(name.split(LIST_SPLIT)[1])]
if name.startswith(f"{output_name}{DICT_SPLIT}"):
index = name.split(DICT_SPLIT)[1]
if not isinstance(output_data, dict) or index not in output_data:
return None
return output_data[index]
if name.startswith(f"{output_name}{OBJC_SPLIT}"):
index = name.split(OBJC_SPLIT)[1]
if isinstance(output_data, object) and hasattr(output_data, index):
return getattr(output_data, index)
return None
return None
def merge_execution_input(data: BlockInput) -> BlockInput:
"""
Merges dynamic input pins into a single list, dictionary, or object based on naming patterns.
This function processes input keys that follow specific patterns to merge them into a unified structure:
- `<input_name>_$_<index>` for list inputs.
- `<input_name>_#_<index>` for dictionary inputs.
- `<input_name>_@_<index>` for object inputs.
Args:
data (BlockInput): A dictionary containing input keys and their corresponding values.
Returns:
BlockInput: A dictionary with merged inputs.
Raises:
ValueError: If a list index is not an integer.
Examples:
>>> data = {
... "list_$_0": "a",
... "list_$_1": "b",
... "dict_#_key1": "value1",
... "dict_#_key2": "value2",
... "object_@_attr1": "value1",
... "object_@_attr2": "value2"
... }
>>> merge_execution_input(data)
{
"list": ["a", "b"],
"dict": {"key1": "value1", "key2": "value2"},
"object": <MockObject attr1="value1" attr2="value2">
}
"""
# Merge all input with <input_name>_$_<index> into a single list.
items = list(data.items())
for key, value in items:
if LIST_SPLIT not in key:
continue
name, index = key.split(LIST_SPLIT)
if not index.isdigit():
raise ValueError(f"Invalid key: {key}, #{index} index must be an integer.")
data[name] = data.get(name, [])
if int(index) >= len(data[name]):
# Pad list with empty string on missing indices.
data[name].extend([""] * (int(index) - len(data[name]) + 1))
data[name][int(index)] = value
# Merge all input with <input_name>_#_<index> into a single dict.
for key, value in items:
if DICT_SPLIT not in key:
continue
name, index = key.split(DICT_SPLIT)
data[name] = data.get(name, {})
data[name][index] = value
# Merge all input with <input_name>_@_<index> into a single object.
for key, value in items:
if OBJC_SPLIT not in key:
continue
name, index = key.split(OBJC_SPLIT)
if name not in data or not isinstance(data[name], object):
data[name] = mock.MockObject()
setattr(data[name], index, value)
return data
# --------------------- Event Bus --------------------- #

View File

@@ -10,9 +10,7 @@ from prisma.models import AgentGraph, AgentNode, AgentNodeLink, StoreListingVers
from prisma.types import (
AgentGraphCreateInput,
AgentGraphWhereInput,
AgentGraphWhereInputRecursive1,
AgentNodeCreateInput,
AgentNodeIncludeFromAgentNodeRecursive1,
AgentNodeLinkCreateInput,
)
from pydantic.fields import computed_field
@@ -655,14 +653,11 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
graphs = await AgentGraph.prisma().find_many(
where={
"OR": [
type_utils.typed(
AgentGraphWhereInputRecursive1,
{
"id": graph_id,
"version": graph_version,
"userId": graph.userId, # Ensure the sub-graph is owned by the same user
},
)
{
"id": graph_id,
"version": graph_version,
"userId": graph.userId, # Ensure the sub-graph is owned by the same user
}
for graph_id, graph_version in sub_graph_ids
]
},
@@ -678,13 +673,7 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
async def get_connected_output_nodes(node_id: str) -> list[tuple[Link, Node]]:
links = await AgentNodeLink.prisma().find_many(
where={"agentNodeSourceId": node_id},
include={
"AgentNodeSink": {
"include": cast(
AgentNodeIncludeFromAgentNodeRecursive1, AGENT_NODE_INCLUDE
)
}
},
include={"AgentNodeSink": {"include": AGENT_NODE_INCLUDE}},
)
return [
(Link.from_db(link), NodeModel.from_db(link.AgentNodeSink))
@@ -930,12 +919,19 @@ async def migrate_llm_models(migrate_to: LlmModel):
# Convert enum values to a list of strings for the SQL query
enum_values = [v.value for v in LlmModel.__members__.values()]
escaped_enum_values = repr(tuple(enum_values)) # hack but works
query = f"""
UPDATE "AgentNode"
SET "constantInput" = jsonb_set("constantInput", '{{{path}}}', '"{migrate_to.value}"', true)
WHERE "agentBlockId" = '{id}'
AND "constantInput" ? '{path}'
AND "constantInput"->>'{path}' NOT IN ({','.join(f"'{value}'" for value in enum_values)})
SET "constantInput" = jsonb_set("constantInput", $1, $2, true)
WHERE "agentBlockId" = $3
AND "constantInput" ? $4
AND "constantInput"->>$4 NOT IN {escaped_enum_values}
"""
await db.execute_raw(query)
await db.execute_raw(
query, # type: ignore - is supposed to be LiteralString
"{" + path + "}",
f'"{migrate_to.value}"',
id,
path,
)

View File

@@ -4,7 +4,6 @@ import prisma.enums
import prisma.types
from backend.blocks.io import IO_BLOCK_IDs
from backend.util.type import typed_cast
AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
"Input": True,
@@ -14,13 +13,7 @@ AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
}
AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
"Nodes": {
"include": typed_cast(
prisma.types.AgentNodeIncludeFromAgentNodeRecursive1,
prisma.types.AgentNodeIncludeFromAgentNode,
AGENT_NODE_INCLUDE,
)
}
"Nodes": {"include": AGENT_NODE_INCLUDE}
}
EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
@@ -56,13 +49,7 @@ GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
GRAPH_EXECUTION_INCLUDE_WITH_NODES["NodeExecutions"],
),
"where": {
"Node": typed_cast(
prisma.types.AgentNodeRelationFilter,
prisma.types.AgentNodeWhereInput,
{
"AgentBlock": {"id": {"in": IO_BLOCK_IDs}},
},
),
"Node": {"is": {"AgentBlock": {"is": {"id": {"in": IO_BLOCK_IDs}}}}},
"NOT": [{"executionStatus": prisma.enums.AgentExecutionStatus.INCOMPLETE}],
},
}
@@ -70,13 +57,7 @@ GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
"AgentNodes": {
"include": typed_cast(
prisma.types.AgentNodeIncludeFromAgentNodeRecursive1,
prisma.types.AgentNodeInclude,
AGENT_NODE_INCLUDE,
)
}
"AgentNodes": {"include": AGENT_NODE_INCLUDE}
}

View File

@@ -8,7 +8,9 @@ from prisma.enums import OnboardingStep
from prisma.models import UserOnboarding
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
from backend.data import db
from backend.data.block import get_blocks
from backend.data.credit import get_user_credit_model
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.server.v2.store.model import StoreAgentDetails
@@ -24,14 +26,19 @@ REASON_MAPPING: dict[str, list[str]] = {
POINTS_AGENT_COUNT = 50 # Number of agents to calculate points for
MIN_AGENT_COUNT = 2 # Minimum number of marketplace agents to enable onboarding
user_credit = get_user_credit_model()
class UserOnboardingUpdate(pydantic.BaseModel):
completedSteps: Optional[list[OnboardingStep]] = None
notificationDot: Optional[bool] = None
notified: Optional[list[OnboardingStep]] = None
usageReason: Optional[str] = None
integrations: Optional[list[str]] = None
otherIntegrations: Optional[str] = None
selectedStoreListingVersionId: Optional[str] = None
agentInput: Optional[dict[str, Any]] = None
onboardingAgentExecutionId: Optional[str] = None
async def get_user_onboarding(user_id: str):
@@ -48,6 +55,20 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
update: UserOnboardingUpdateInput = {}
if data.completedSteps is not None:
update["completedSteps"] = list(set(data.completedSteps))
for step in (
OnboardingStep.AGENT_NEW_RUN,
OnboardingStep.GET_RESULTS,
OnboardingStep.MARKETPLACE_ADD_AGENT,
OnboardingStep.MARKETPLACE_RUN_AGENT,
OnboardingStep.BUILDER_SAVE_AGENT,
OnboardingStep.BUILDER_RUN_AGENT,
):
if step in data.completedSteps:
await reward_user(user_id, step)
if data.notificationDot is not None:
update["notificationDot"] = data.notificationDot
if data.notified is not None:
update["notified"] = list(set(data.notified))
if data.usageReason is not None:
update["usageReason"] = data.usageReason
if data.integrations is not None:
@@ -58,16 +79,57 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
update["selectedStoreListingVersionId"] = data.selectedStoreListingVersionId
if data.agentInput is not None:
update["agentInput"] = Json(data.agentInput)
if data.onboardingAgentExecutionId is not None:
update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId
return await UserOnboarding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, **update}, # type: ignore
"create": {"userId": user_id, **update},
"update": update,
},
)
async def reward_user(user_id: str, step: OnboardingStep):
async with db.locked_transaction(f"usr_trx_{user_id}-reward"):
reward = 0
match step:
# Reward user when they clicked New Run during onboarding
# This is because they need credits before scheduling a run (next step)
case OnboardingStep.AGENT_NEW_RUN:
reward = 300
case OnboardingStep.GET_RESULTS:
reward = 300
case OnboardingStep.MARKETPLACE_ADD_AGENT:
reward = 100
case OnboardingStep.MARKETPLACE_RUN_AGENT:
reward = 100
case OnboardingStep.BUILDER_SAVE_AGENT:
reward = 100
case OnboardingStep.BUILDER_RUN_AGENT:
reward = 100
if reward == 0:
return
onboarding = await get_user_onboarding(user_id)
# Skip if already rewarded
if step in onboarding.rewardedFor:
return
onboarding.rewardedFor.append(step)
await user_credit.onboarding_reward(user_id, reward, step)
await UserOnboarding.prisma().update(
where={"userId": user_id},
data={
"completedSteps": list(set(onboarding.completedSteps + [step])),
"rewardedFor": onboarding.rewardedFor,
},
)
def clean_and_split(text: str) -> list[str]:
"""
Removes all special characters from a string, truncates it to 100 characters,

View File

@@ -149,7 +149,7 @@ async def migrate_and_encrypt_user_integrations():
logger.info(f"Migrating integration credentials for {len(users)} users")
for user in users:
raw_metadata = cast(UserMetadataRaw, user.metadata)
raw_metadata = cast(dict, user.metadata)
metadata = UserMetadata.model_validate(raw_metadata)
# Get existing integrations data
@@ -165,7 +165,6 @@ async def migrate_and_encrypt_user_integrations():
await update_user_integrations(user_id=user.id, data=integrations)
# Remove from metadata
raw_metadata = dict(raw_metadata)
raw_metadata.pop("integration_credentials", None)
raw_metadata.pop("integration_oauth_states", None)

View File

@@ -1,11 +1,8 @@
import logging
from backend.data import db, redis
from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
GraphExecution,
NodeExecutionResult,
RedisExecutionEventBus,
create_graph_execution,
get_graph_execution,
get_incomplete_node_executions,
@@ -42,7 +39,7 @@ from backend.data.user import (
update_user_integrations,
update_user_metadata,
)
from backend.util.service import AppService, expose, exposed_run_and_wait
from backend.util.service import AppService, exposed_run_and_wait
from backend.util.settings import Config
config = Config()
@@ -57,21 +54,14 @@ async def _spend_credits(
class DatabaseManager(AppService):
def __init__(self):
super().__init__()
self.execution_event_bus = RedisExecutionEventBus()
def run_service(self) -> None:
logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
self.run_and_wait(db.connect())
logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...")
redis.connect()
super().run_service()
def cleanup(self):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Disconnecting Redis...")
redis.disconnect()
logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
self.run_and_wait(db.disconnect())
@@ -79,12 +69,6 @@ class DatabaseManager(AppService):
def get_port(cls) -> int:
return config.database_api_port
@expose
def send_execution_update(
self, execution_result: GraphExecution | NodeExecutionResult
):
self.execution_event_bus.publish(execution_result)
# Executions
get_graph_execution = exposed_run_and_wait(get_graph_execution)
create_graph_execution = exposed_run_and_wait(create_graph_execution)

View File

@@ -5,11 +5,14 @@ import os
import signal
import sys
import threading
import time
from concurrent.futures import Future, ProcessPoolExecutor
from contextlib import contextmanager
from multiprocessing.pool import AsyncResult, Pool
from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generator, TypeVar, cast
from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic
from redis.lock import Lock as RedisLock
from backend.blocks.io import AgentOutputBlock
@@ -30,43 +33,36 @@ from autogpt_libs.utils.cache import thread_cached
from backend.blocks.agent import AgentExecutorBlock
from backend.data import redis
from backend.data.block import (
Block,
BlockData,
BlockInput,
BlockSchema,
BlockType,
get_block,
)
from backend.data.block import BlockData, BlockInput, BlockSchema, get_block
from backend.data.execution import (
ExecutionQueue,
ExecutionStatus,
GraphExecution,
GraphExecutionEntry,
NodeExecutionEntry,
NodeExecutionResult,
merge_execution_input,
parse_execution_output,
)
from backend.data.graph import GraphModel, Link, Node
from backend.data.graph import Link, Node
from backend.executor.utils import (
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
GRAPH_EXECUTION_QUEUE_NAME,
CancelExecutionEvent,
UsageTransactionMetadata,
block_usage_cost,
execution_usage_cost,
get_execution_event_bus,
get_execution_queue,
parse_execution_output,
validate_exec,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util import json
from backend.util.decorator import error_logged, time_measured
from backend.util.file import clean_exec_files
from backend.util.logging import configure_logging
from backend.util.process import set_service_name
from backend.util.service import (
AppService,
close_service_client,
expose,
get_service_client,
)
from backend.util.process import AppProcess, set_service_name
from backend.util.service import close_service_client, get_service_client
from backend.util.settings import Settings
from backend.util.type import convert
logger = logging.getLogger(__name__)
settings = Settings()
@@ -152,7 +148,7 @@ def execute_node(
def update_execution_status(status: ExecutionStatus) -> NodeExecutionResult:
"""Sets status and fetches+broadcasts the latest state of the node execution"""
exec_update = db_client.update_node_execution_status(node_exec_id, status)
db_client.send_execution_update(exec_update)
send_execution_update(exec_update)
return exec_update
node = db_client.get_node(node_id)
@@ -288,7 +284,7 @@ def _enqueue_next_nodes(
exec_update = db_client.update_node_execution_status(
node_exec_id, ExecutionStatus.QUEUED, data
)
db_client.send_execution_update(exec_update)
send_execution_update(exec_update)
return NodeExecutionEntry(
user_id=user_id,
graph_exec_id=graph_exec_id,
@@ -400,60 +396,6 @@ def _enqueue_next_nodes(
]
def validate_exec(
node: Node,
data: BlockInput,
resolve_input: bool = True,
) -> tuple[BlockInput | None, str]:
"""
Validate the input data for a node execution.
Args:
node: The node to execute.
data: The input data for the node execution.
resolve_input: Whether to resolve dynamic pins into dict/list/object.
Returns:
A tuple of the validated data and the block name.
If the data is invalid, the first element will be None, and the second element
will be an error message.
If the data is valid, the first element will be the resolved input data, and
the second element will be the block name.
"""
node_block: Block | None = get_block(node.block_id)
if not node_block:
return None, f"Block for {node.block_id} not found."
schema = node_block.input_schema
# Convert non-matching data types to the expected input schema.
for name, data_type in schema.__annotations__.items():
if (value := data.get(name)) and (type(value) is not data_type):
data[name] = convert(value, data_type)
# Input data (without default values) should contain all required fields.
error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
if missing_links := schema.get_missing_links(data, node.input_links):
return None, f"{error_prefix} unpopulated links {missing_links}"
# Merge input data with default values and resolve dynamic dict/list/object pins.
input_default = schema.get_input_defaults(node.input_default)
data = {**input_default, **data}
if resolve_input:
data = merge_execution_input(data)
# Input data post-merge should contain all required fields from the schema.
if missing_input := schema.get_missing_input(data):
return None, f"{error_prefix} missing input {missing_input}"
# Last validation: Validate the input values against the schema.
if error := schema.get_mismatch_error(data):
error_message = f"{error_prefix} {error}"
logger.error(error_message)
return None, error_message
return data, node_block.name
class Executor:
"""
This class contains event handlers for the process pool executor events.
@@ -633,7 +575,13 @@ class Executor:
exec_meta = cls.db_client.update_graph_execution_start_time(
graph_exec.graph_exec_id
)
cls.db_client.send_execution_update(exec_meta)
if exec_meta is None:
logger.warning(
f"Skipped graph execution {graph_exec.graph_exec_id}, the graph execution is not found or not currently in the QUEUED state."
)
return
send_execution_update(exec_meta)
timing_info, (exec_stats, status, error) = cls._on_graph_execution(
graph_exec, cancel, log_metadata
)
@@ -646,7 +594,7 @@ class Executor:
status=status,
stats=exec_stats,
):
cls.db_client.send_execution_update(graph_exec_result)
send_execution_update(graph_exec_result)
cls._handle_agent_run_notif(graph_exec, exec_stats)
@@ -759,7 +707,7 @@ class Executor:
status=execution_status,
stats=execution_stats,
):
cls.db_client.send_execution_update(_graph_exec)
send_execution_update(_graph_exec)
else:
logger.error(
"Callback for "
@@ -810,7 +758,7 @@ class Executor:
exec_update = cls.db_client.update_node_execution_status(
node_exec_id, execution_status
)
cls.db_client.send_execution_update(exec_update)
send_execution_update(exec_update)
cls._handle_low_balance_notif(
graph_exec.user_id,
@@ -927,22 +875,25 @@ class Executor:
)
class ExecutionManager(AppService):
class ExecutionManager(AppProcess):
def __init__(self):
super().__init__()
self.pool_size = settings.config.num_graph_workers
self.queue = ExecutionQueue[GraphExecutionEntry]()
self.running = True
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
@classmethod
def get_port(cls) -> int:
return settings.config.execution_manager_port
def run_service(self):
from backend.integrations.credentials_store import IntegrationCredentialsStore
self.credentials_store = IntegrationCredentialsStore()
def run(self):
while True:
try:
self._run()
except Exception:
logger.exception(f"[{self.service_name}] error in graph executor loop")
def _run(self):
logger.info(f"[{self.service_name}] ⏳ Spawn max-{self.pool_size} workers...")
self.executor = ProcessPoolExecutor(
max_workers=self.pool_size,
@@ -952,25 +903,103 @@ class ExecutionManager(AppService):
logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...")
redis.connect()
sync_manager = multiprocessing.Manager()
logger.info(f"[{self.service_name}] Ready to consume messages...")
while True:
graph_exec_data = self.queue.get()
graph_exec_id = graph_exec_data.graph_exec_id
logger.debug(
f"[ExecutionManager] Dispatching graph execution {graph_exec_id}"
channel = get_execution_queue().get_channel()
# cancel graph execution requests
method_frame, _, body = channel.basic_get(
queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
auto_ack=True,
)
cancel_event = sync_manager.Event()
future = self.executor.submit(
Executor.on_graph_execution, graph_exec_data, cancel_event
if method_frame:
self._handle_cancel_message(body)
# start graph execution requests
method_frame, _, body = channel.basic_get(
queue=GRAPH_EXECUTION_QUEUE_NAME,
auto_ack=False,
)
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
future.add_done_callback(
lambda _: self.active_graph_runs.pop(graph_exec_id, None)
if method_frame:
self._handle_run_message(channel, method_frame, body)
else:
time.sleep(0.2)
def _handle_cancel_message(self, body: bytes):
try:
request = CancelExecutionEvent.model_validate_json(body)
graph_exec_id = request.graph_exec_id
if not graph_exec_id:
logger.warning(
f"[{self.service_name}] Cancel message missing 'graph_exec_id'"
)
return
if graph_exec_id not in self.active_graph_runs:
logger.debug(
f"[{self.service_name}] Cancel received for {graph_exec_id} but not active."
)
return
_, cancel_event = self.active_graph_runs[graph_exec_id]
logger.info(f"[{self.service_name}] Received cancel for {graph_exec_id}")
if not cancel_event.is_set():
cancel_event.set()
else:
logger.debug(
f"[{self.service_name}] Cancel already set for {graph_exec_id}"
)
except Exception as e:
logger.exception(f"Error handling cancel message: {e}")
def _handle_run_message(
self, channel: BlockingChannel, method_frame: Basic.GetOk, body: bytes
):
delivery_tag = method_frame.delivery_tag
try:
graph_exec_entry = GraphExecutionEntry.model_validate_json(body)
except Exception as e:
logger.error(f"[{self.service_name}] Could not parse run message: {e}")
channel.basic_nack(delivery_tag, requeue=False)
return
graph_exec_id = graph_exec_entry.graph_exec_id
logger.info(
f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}"
)
if graph_exec_id in self.active_graph_runs:
logger.warning(
f"[{self.service_name}] Graph {graph_exec_id} already running; rejecting duplicate run."
)
channel.basic_nack(delivery_tag, requeue=False)
return
cancel_event = multiprocessing.Manager().Event()
future = self.executor.submit(
Executor.on_graph_execution, graph_exec_entry, cancel_event
)
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
def _on_run_done(f: Future):
logger.info(f"[{self.service_name}] Run completed for {graph_exec_id}")
try:
channel.basic_ack(delivery_tag)
self.active_graph_runs.pop(graph_exec_id, None)
if f.exception():
logger.error(
f"[{self.service_name}] Execution for {graph_exec_id} failed: {f.exception()}"
)
except Exception as e:
logger.error(f"[{self.service_name}] Error acknowledging message: {e}")
future.add_done_callback(_on_run_done)
def cleanup(self):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Shutting down service loop...")
self.running = False
logger.info(f"[{self.service_name}] ⏳ Shutting down graph executor pool...")
self.executor.shutdown(cancel_futures=True)
@@ -981,175 +1010,6 @@ class ExecutionManager(AppService):
def db_client(self) -> "DatabaseManager":
return get_db_client()
@expose
def add_execution(
self,
graph_id: str,
data: BlockInput,
user_id: str,
graph_version: Optional[int] = None,
preset_id: str | None = None,
) -> GraphExecutionEntry:
graph: GraphModel | None = self.db_client.get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
)
if not graph:
raise ValueError(f"Graph #{graph_id} not found.")
graph.validate_graph(for_run=True)
self._validate_node_input_credentials(graph, user_id)
nodes_input = []
for node in graph.starting_nodes:
input_data = {}
block = node.block
# Note block should never be executed.
if block.block_type == BlockType.NOTE:
continue
# Extract request input data, and assign it to the input pin.
if block.block_type == BlockType.INPUT:
input_name = node.input_default.get("name")
if input_name and input_name in data:
input_data = {"value": data[input_name]}
# Extract webhook payload, and assign it to the input pin
webhook_payload_key = f"webhook_{node.webhook_id}_payload"
if (
block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
and node.webhook_id
):
if webhook_payload_key not in data:
raise ValueError(
f"Node {block.name} #{node.id} webhook payload is missing"
)
input_data = {"payload": data[webhook_payload_key]}
input_data, error = validate_exec(node, input_data)
if input_data is None:
raise ValueError(error)
else:
nodes_input.append((node.id, input_data))
if not nodes_input:
raise ValueError(
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
)
graph_exec = self.db_client.create_graph_execution(
graph_id=graph_id,
graph_version=graph.version,
nodes_input=nodes_input,
user_id=user_id,
preset_id=preset_id,
)
self.db_client.send_execution_update(graph_exec)
graph_exec_entry = GraphExecutionEntry(
user_id=user_id,
graph_id=graph_id,
graph_version=graph_version or 0,
graph_exec_id=graph_exec.id,
start_node_execs=[
NodeExecutionEntry(
user_id=user_id,
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
node_exec_id=node_exec.node_exec_id,
node_id=node_exec.node_id,
block_id=node_exec.block_id,
data=node_exec.input_data,
)
for node_exec in graph_exec.node_executions
],
)
self.queue.add(graph_exec_entry)
return graph_exec_entry
@expose
def cancel_execution(self, graph_exec_id: str) -> None:
"""
Mechanism:
1. Set the cancel event
2. Graph executor's cancel handler thread detects the event, terminates workers,
reinitializes worker pool, and returns.
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
"""
if graph_exec_id not in self.active_graph_runs:
logger.warning(
f"Graph execution #{graph_exec_id} not active/running: "
"possibly already completed/cancelled."
)
else:
future, cancel_event = self.active_graph_runs[graph_exec_id]
if not cancel_event.is_set():
cancel_event.set()
future.result()
# Update the status of the graph & node executions
self.db_client.update_graph_execution_stats(
graph_exec_id,
ExecutionStatus.TERMINATED,
)
node_execs = self.db_client.get_node_execution_results(
graph_exec_id=graph_exec_id,
statuses=[
ExecutionStatus.QUEUED,
ExecutionStatus.RUNNING,
ExecutionStatus.INCOMPLETE,
],
)
self.db_client.update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in node_execs],
ExecutionStatus.TERMINATED,
)
for node_exec in node_execs:
node_exec.status = ExecutionStatus.TERMINATED
self.db_client.send_execution_update(node_exec)
def _validate_node_input_credentials(self, graph: GraphModel, user_id: str):
"""Checks all credentials for all nodes of the graph"""
for node in graph.nodes:
block = node.block
# Find any fields of type CredentialsMetaInput
credentials_fields = cast(
type[BlockSchema], block.input_schema
).get_credentials_fields()
if not credentials_fields:
continue
for field_name, credentials_meta_type in credentials_fields.items():
credentials_meta = credentials_meta_type.model_validate(
node.input_default[field_name]
)
# Fetch the corresponding Credentials and perform sanity checks
credentials = self.credentials_store.get_creds_by_id(
user_id, credentials_meta.id
)
if not credentials:
raise ValueError(
f"Unknown credentials #{credentials_meta.id} "
f"for node #{node.id} input '{field_name}'"
)
if (
credentials.provider != credentials_meta.provider
or credentials.type != credentials_meta.type
):
logger.warning(
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch: "
f"{credentials_meta.type}<>{credentials.type};"
f"{credentials_meta.provider}<>{credentials.provider}"
)
raise ValueError(
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch"
)
# ------- UTILITIES ------- #
@@ -1168,6 +1028,10 @@ def get_notification_service() -> "NotificationManager":
return get_service_client(NotificationManager)
def send_execution_update(entry: GraphExecution | NodeExecutionResult):
return get_execution_event_bus().publish(entry)
@contextmanager
def synchronized(key: str, timeout: int = 60):
lock: RedisLock = redis.get_redis().lock(f"lock:{key}", timeout=timeout)

View File

@@ -16,7 +16,7 @@ from pydantic import BaseModel
from sqlalchemy import MetaData, create_engine
from backend.data.block import BlockInput
from backend.executor.manager import ExecutionManager
from backend.executor import utils as execution_utils
from backend.notifications.notifications import NotificationManager
from backend.util.service import AppService, expose, get_service_client
from backend.util.settings import Config
@@ -57,11 +57,6 @@ def job_listener(event):
log(f"Job {event.job_id} completed successfully.")
@thread_cached
def get_execution_client() -> ExecutionManager:
return get_service_client(ExecutionManager)
@thread_cached
def get_notification_client():
from backend.notifications import NotificationManager
@@ -73,7 +68,7 @@ def execute_graph(**kwargs):
args = ExecutionJobArgs(**kwargs)
try:
log(f"Executing recurring job for graph #{args.graph_id}")
get_execution_client().add_execution(
execution_utils.add_graph_execution(
graph_id=args.graph_id,
data=args.input_data,
user_id=args.user_id,
@@ -164,11 +159,6 @@ class Scheduler(AppService):
def db_pool_size(cls) -> int:
return config.scheduler_db_pool_size
@property
@thread_cached
def execution_client(self) -> ExecutionManager:
return get_service_client(ExecutionManager)
@property
@thread_cached
def notification_client(self) -> NotificationManager:

View File

@@ -1,11 +1,70 @@
import logging
from typing import TYPE_CHECKING, Any, cast
from autogpt_libs.utils.cache import thread_cached
from pydantic import BaseModel
from backend.data.block import Block, BlockInput
from backend.data.block import (
Block,
BlockData,
BlockInput,
BlockSchema,
BlockType,
get_block,
)
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.cost import BlockCostType
from backend.data.execution import GraphExecutionEntry, RedisExecutionEventBus
from backend.data.graph import GraphModel, Node
from backend.data.rabbitmq import (
Exchange,
ExchangeType,
Queue,
RabbitMQConfig,
SyncRabbitMQ,
)
from backend.util.mock import MockObject
from backend.util.service import get_service_client
from backend.util.settings import Config
from backend.util.type import convert
if TYPE_CHECKING:
from backend.executor import DatabaseManager
from backend.integrations.credentials_store import IntegrationCredentialsStore
config = Config()
logger = logging.getLogger(__name__)
# ============ Resource Helpers ============ #
@thread_cached
def get_execution_event_bus() -> RedisExecutionEventBus:
return RedisExecutionEventBus()
@thread_cached
def get_execution_queue() -> SyncRabbitMQ:
client = SyncRabbitMQ(create_execution_queue_config())
client.connect()
return client
@thread_cached
def get_integration_credentials_store() -> "IntegrationCredentialsStore":
from backend.integrations.credentials_store import IntegrationCredentialsStore
return IntegrationCredentialsStore()
@thread_cached
def get_db_client() -> "DatabaseManager":
from backend.executor import DatabaseManager
return get_service_client(DatabaseManager)
# ============ Execution Cost Helpers ============ #
class UsageTransactionMetadata(BaseModel):
@@ -95,3 +154,398 @@ def _is_cost_filter_match(cost_filter: BlockInput, input_data: BlockInput) -> bo
or (input_data.get(k) and _is_cost_filter_match(v, input_data[k]))
for k, v in cost_filter.items()
)
# ============ Execution Input Helpers ============ #
LIST_SPLIT = "_$_"
DICT_SPLIT = "_#_"
OBJC_SPLIT = "_@_"
def parse_execution_output(output: BlockData, name: str) -> Any | None:
"""
Extracts partial output data by name from a given BlockData.
The function supports extracting data from lists, dictionaries, and objects
using specific naming conventions:
- For lists: <output_name>_$_<index>
- For dictionaries: <output_name>_#_<key>
- For objects: <output_name>_@_<attribute>
Args:
output (BlockData): A tuple containing the output name and data.
name (str): The name used to extract specific data from the output.
Returns:
Any | None: The extracted data if found, otherwise None.
Examples:
>>> output = ("result", [10, 20, 30])
>>> parse_execution_output(output, "result_$_1")
20
>>> output = ("config", {"key1": "value1", "key2": "value2"})
>>> parse_execution_output(output, "config_#_key1")
'value1'
>>> class Sample:
... attr1 = "value1"
... attr2 = "value2"
>>> output = ("object", Sample())
>>> parse_execution_output(output, "object_@_attr1")
'value1'
"""
output_name, output_data = output
if name == output_name:
return output_data
if name.startswith(f"{output_name}{LIST_SPLIT}"):
index = int(name.split(LIST_SPLIT)[1])
if not isinstance(output_data, list) or len(output_data) <= index:
return None
return output_data[int(name.split(LIST_SPLIT)[1])]
if name.startswith(f"{output_name}{DICT_SPLIT}"):
index = name.split(DICT_SPLIT)[1]
if not isinstance(output_data, dict) or index not in output_data:
return None
return output_data[index]
if name.startswith(f"{output_name}{OBJC_SPLIT}"):
index = name.split(OBJC_SPLIT)[1]
if isinstance(output_data, object) and hasattr(output_data, index):
return getattr(output_data, index)
return None
return None
def validate_exec(
node: Node,
data: BlockInput,
resolve_input: bool = True,
) -> tuple[BlockInput | None, str]:
"""
Validate the input data for a node execution.
Args:
node: The node to execute.
data: The input data for the node execution.
resolve_input: Whether to resolve dynamic pins into dict/list/object.
Returns:
A tuple of the validated data and the block name.
If the data is invalid, the first element will be None, and the second element
will be an error message.
If the data is valid, the first element will be the resolved input data, and
the second element will be the block name.
"""
node_block: Block | None = get_block(node.block_id)
if not node_block:
return None, f"Block for {node.block_id} not found."
schema = node_block.input_schema
# Convert non-matching data types to the expected input schema.
for name, data_type in schema.__annotations__.items():
if (value := data.get(name)) and (type(value) is not data_type):
data[name] = convert(value, data_type)
# Input data (without default values) should contain all required fields.
error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
if missing_links := schema.get_missing_links(data, node.input_links):
return None, f"{error_prefix} unpopulated links {missing_links}"
# Merge input data with default values and resolve dynamic dict/list/object pins.
input_default = schema.get_input_defaults(node.input_default)
data = {**input_default, **data}
if resolve_input:
data = merge_execution_input(data)
# Input data post-merge should contain all required fields from the schema.
if missing_input := schema.get_missing_input(data):
return None, f"{error_prefix} missing input {missing_input}"
# Last validation: Validate the input values against the schema.
if error := schema.get_mismatch_error(data):
error_message = f"{error_prefix} {error}"
logger.error(error_message)
return None, error_message
return data, node_block.name
def merge_execution_input(data: BlockInput) -> BlockInput:
"""
Merges dynamic input pins into a single list, dictionary, or object based on naming patterns.
This function processes input keys that follow specific patterns to merge them into a unified structure:
- `<input_name>_$_<index>` for list inputs.
- `<input_name>_#_<index>` for dictionary inputs.
- `<input_name>_@_<index>` for object inputs.
Args:
data (BlockInput): A dictionary containing input keys and their corresponding values.
Returns:
BlockInput: A dictionary with merged inputs.
Raises:
ValueError: If a list index is not an integer.
Examples:
>>> data = {
... "list_$_0": "a",
... "list_$_1": "b",
... "dict_#_key1": "value1",
... "dict_#_key2": "value2",
... "object_@_attr1": "value1",
... "object_@_attr2": "value2"
... }
>>> merge_execution_input(data)
{
"list": ["a", "b"],
"dict": {"key1": "value1", "key2": "value2"},
"object": <MockObject attr1="value1" attr2="value2">
}
"""
# Merge all input with <input_name>_$_<index> into a single list.
items = list(data.items())
for key, value in items:
if LIST_SPLIT not in key:
continue
name, index = key.split(LIST_SPLIT)
if not index.isdigit():
raise ValueError(f"Invalid key: {key}, #{index} index must be an integer.")
data[name] = data.get(name, [])
if int(index) >= len(data[name]):
# Pad list with empty string on missing indices.
data[name].extend([""] * (int(index) - len(data[name]) + 1))
data[name][int(index)] = value
# Merge all input with <input_name>_#_<index> into a single dict.
for key, value in items:
if DICT_SPLIT not in key:
continue
name, index = key.split(DICT_SPLIT)
data[name] = data.get(name, {})
data[name][index] = value
# Merge all input with <input_name>_@_<index> into a single object.
for key, value in items:
if OBJC_SPLIT not in key:
continue
name, index = key.split(OBJC_SPLIT)
if name not in data or not isinstance(data[name], object):
data[name] = MockObject()
setattr(data[name], index, value)
return data
def _validate_node_input_credentials(graph: GraphModel, user_id: str):
"""Checks all credentials for all nodes of the graph"""
for node in graph.nodes:
block = node.block
# Find any fields of type CredentialsMetaInput
credentials_fields = cast(
type[BlockSchema], block.input_schema
).get_credentials_fields()
if not credentials_fields:
continue
for field_name, credentials_meta_type in credentials_fields.items():
credentials_meta = credentials_meta_type.model_validate(
node.input_default[field_name]
)
# Fetch the corresponding Credentials and perform sanity checks
credentials = get_integration_credentials_store().get_creds_by_id(
user_id, credentials_meta.id
)
if not credentials:
raise ValueError(
f"Unknown credentials #{credentials_meta.id} "
f"for node #{node.id} input '{field_name}'"
)
if (
credentials.provider != credentials_meta.provider
or credentials.type != credentials_meta.type
):
logger.warning(
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch: "
f"{credentials_meta.type}<>{credentials.type};"
f"{credentials_meta.provider}<>{credentials.provider}"
)
raise ValueError(
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch"
)
def construct_node_execution_input(
graph: GraphModel,
user_id: str,
data: BlockInput,
) -> list[tuple[str, BlockInput]]:
"""
Validates and prepares the input data for executing a graph.
This function checks the graph for starting nodes, validates the input data
against the schema, and resolves dynamic input pins into a single list,
dictionary, or object.
Args:
graph (GraphModel): The graph model to execute.
user_id (str): The ID of the user executing the graph.
data (BlockInput): The input data for the graph execution.
Returns:
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
the corresponding input data for that node.
"""
graph.validate_graph(for_run=True)
_validate_node_input_credentials(graph, user_id)
nodes_input = []
for node in graph.starting_nodes:
input_data = {}
block = node.block
# Note block should never be executed.
if block.block_type == BlockType.NOTE:
continue
# Extract request input data, and assign it to the input pin.
if block.block_type == BlockType.INPUT:
input_name = node.input_default.get("name")
if input_name and input_name in data:
input_data = {"value": data[input_name]}
# Extract webhook payload, and assign it to the input pin
webhook_payload_key = f"webhook_{node.webhook_id}_payload"
if (
block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
and node.webhook_id
):
if webhook_payload_key not in data:
raise ValueError(
f"Node {block.name} #{node.id} webhook payload is missing"
)
input_data = {"payload": data[webhook_payload_key]}
input_data, error = validate_exec(node, input_data)
if input_data is None:
raise ValueError(error)
else:
nodes_input.append((node.id, input_data))
if not nodes_input:
raise ValueError(
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
)
return nodes_input
# ============ Execution Queue Helpers ============ #
class CancelExecutionEvent(BaseModel):
graph_exec_id: str
GRAPH_EXECUTION_EXCHANGE = Exchange(
name="graph_execution",
type=ExchangeType.DIRECT,
durable=True,
auto_delete=False,
)
GRAPH_EXECUTION_QUEUE_NAME = "graph_execution_queue"
GRAPH_EXECUTION_ROUTING_KEY = "graph_execution.run"
GRAPH_EXECUTION_CANCEL_EXCHANGE = Exchange(
name="graph_execution_cancel",
type=ExchangeType.FANOUT,
durable=True,
auto_delete=True,
)
GRAPH_EXECUTION_CANCEL_QUEUE_NAME = "graph_execution_cancel_queue"
def create_execution_queue_config() -> RabbitMQConfig:
"""
Define two exchanges and queues:
- 'graph_execution' (DIRECT) for run tasks.
- 'graph_execution_cancel' (FANOUT) for cancel requests.
"""
run_queue = Queue(
name=GRAPH_EXECUTION_QUEUE_NAME,
exchange=GRAPH_EXECUTION_EXCHANGE,
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
durable=True,
auto_delete=False,
)
cancel_queue = Queue(
name=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
exchange=GRAPH_EXECUTION_CANCEL_EXCHANGE,
routing_key="", # not used for FANOUT
durable=True,
auto_delete=False,
)
return RabbitMQConfig(
vhost="/",
exchanges=[GRAPH_EXECUTION_EXCHANGE, GRAPH_EXECUTION_CANCEL_EXCHANGE],
queues=[run_queue, cancel_queue],
)
def add_graph_execution(
graph_id: str,
data: BlockInput,
user_id: str,
graph_version: int | None = None,
preset_id: str | None = None,
) -> GraphExecutionEntry:
"""
Adds a graph execution to the queue and returns the execution entry.
Args:
graph_id (str): The ID of the graph to execute.
data (BlockInput): The input data for the graph execution.
user_id (str): The ID of the user executing the graph.
graph_version (int | None): The version of the graph to execute. Defaults to None.
preset_id (str | None): The ID of the preset to use. Defaults to None.
Returns:
GraphExecutionEntry: The entry for the graph execution.
Raises:
ValueError: If the graph is not found or if there are validation errors.
"""
graph: GraphModel | None = get_db_client().get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
)
if not graph:
raise ValueError(f"Graph #{graph_id} not found.")
graph_exec = get_db_client().create_graph_execution(
graph_id=graph_id,
graph_version=graph.version,
nodes_input=construct_node_execution_input(graph, user_id, data),
user_id=user_id,
preset_id=preset_id,
)
get_execution_event_bus().publish(graph_exec)
graph_exec_entry = graph_exec.to_graph_execution_entry()
get_execution_queue().publish_message(
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
message=graph_exec_entry.model_dump_json(),
exchange=GRAPH_EXECUTION_EXCHANGE,
)
return graph_exec_entry

View File

@@ -2,7 +2,6 @@ import logging
from collections import defaultdict
from typing import Annotated, Any, Dict, List, Optional, Sequence
from autogpt_libs.utils.cache import thread_cached
from fastapi import APIRouter, Body, Depends, HTTPException
from prisma.enums import AgentExecutionStatus, APIKeyPermission
from typing_extensions import TypedDict
@@ -13,17 +12,10 @@ from backend.data import graph as graph_db
from backend.data.api_key import APIKey
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.execution import NodeExecutionResult
from backend.executor import ExecutionManager
from backend.server.external.middleware import require_permission
from backend.util.service import get_service_client
from backend.server.routers import v1 as internal_api_routes
from backend.util.settings import Settings
@thread_cached
def execution_manager_client() -> ExecutionManager:
return get_service_client(ExecutionManager)
settings = Settings()
logger = logging.getLogger(__name__)
@@ -98,18 +90,18 @@ def execute_graph_block(
path="/graphs/{graph_id}/execute/{graph_version}",
tags=["graphs"],
)
def execute_graph(
async def execute_graph(
graph_id: str,
graph_version: int,
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
api_key: APIKey = Depends(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
) -> dict[str, Any]:
try:
graph_exec = execution_manager_client().add_execution(
graph_id,
graph_version=graph_version,
data=node_input,
graph_exec = await internal_api_routes.execute_graph(
graph_id=graph_id,
node_input=node_input,
user_id=api_key.user_id,
graph_version=graph_version,
)
return {"id": graph_exec.graph_exec_id}
except Exception as e:

View File

@@ -1,3 +1,4 @@
import asyncio
import logging
from typing import TYPE_CHECKING, Annotated, Literal
@@ -14,13 +15,12 @@ from backend.data.integrations import (
wait_for_webhook_event,
)
from backend.data.model import Credentials, CredentialsType, OAuth2Credentials
from backend.executor.manager import ExecutionManager
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
from backend.server.routers import v1 as internal_api_routes
from backend.util.exceptions import NeedConfirmation, NotFoundError
from backend.util.service import get_service_client
from backend.util.settings import Settings
if TYPE_CHECKING:
@@ -309,19 +309,22 @@ async def webhook_ingress_generic(
if not webhook.attached_nodes:
return
executor = get_service_client(ExecutionManager)
executions = []
for node in webhook.attached_nodes:
logger.debug(f"Webhook-attached node: {node}")
if not node.is_triggered_by_event_type(event_type):
logger.debug(f"Node #{node.id} doesn't trigger on event {event_type}")
continue
logger.debug(f"Executing graph #{node.graph_id} node #{node.id}")
executor.add_execution(
graph_id=node.graph_id,
graph_version=node.graph_version,
data={f"webhook_{webhook_id}_payload": payload},
user_id=webhook.user_id,
executions.append(
internal_api_routes.execute_graph(
graph_id=node.graph_id,
graph_version=node.graph_version,
node_input={f"webhook_{webhook_id}_payload": payload},
user_id=webhook.user_id,
)
)
asyncio.gather(*executions)
@router.post("/webhooks/{webhook_id}/ping")

View File

@@ -17,7 +17,6 @@ import backend.data.block
import backend.data.db
import backend.data.graph
import backend.data.user
import backend.server.integrations.router
import backend.server.routers.postmark.postmark
import backend.server.routers.v1
import backend.server.v2.admin.store_admin_routes
@@ -156,7 +155,7 @@ class AgentServer(backend.util.service.AppProcess):
graph_version: Optional[int] = None,
node_input: Optional[dict[str, Any]] = None,
):
return backend.server.routers.v1.execute_graph(
return await backend.server.routers.v1.execute_graph(
user_id=user_id,
graph_id=graph_id,
graph_version=graph_version,
@@ -275,7 +274,9 @@ class AgentServer(backend.util.service.AppProcess):
provider: ProviderName,
credentials: Credentials,
) -> Credentials:
return backend.server.integrations.router.create_credentials(
from backend.server.integrations.router import create_credentials
return create_credentials(
user_id=user_id, provider=provider, credentials=credentials
)

View File

@@ -2,7 +2,7 @@ import asyncio
import logging
from collections import defaultdict
from datetime import datetime
from typing import TYPE_CHECKING, Annotated, Any, Sequence
from typing import TYPE_CHECKING, Annotated, Any, Coroutine, Sequence
import pydantic
import stripe
@@ -13,7 +13,6 @@ from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
from typing_extensions import Optional, TypedDict
import backend.data.block
import backend.server.integrations.router
import backend.server.routers.analytics
import backend.server.v2.library.db as library_db
@@ -31,7 +30,7 @@ from backend.data.api_key import (
suspend_api_key,
update_api_key_permissions,
)
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
from backend.data.credit import (
AutoTopUpConfig,
RefundRequest,
@@ -41,6 +40,7 @@ from backend.data.credit import (
get_user_credit_model,
set_auto_top_up,
)
from backend.data.execution import AsyncRedisExecutionEventBus
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.data.onboarding import (
UserOnboardingUpdate,
@@ -49,13 +49,16 @@ from backend.data.onboarding import (
onboarding_enabled,
update_user_onboarding,
)
from backend.data.rabbitmq import AsyncRabbitMQ
from backend.data.user import (
get_or_create_user,
get_user_notification_preference,
update_user_email,
update_user_notification_preference,
)
from backend.executor import ExecutionManager, Scheduler, scheduler
from backend.executor import Scheduler, scheduler
from backend.executor import utils as execution_utils
from backend.executor.utils import create_execution_queue_config
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import (
on_graph_activate,
@@ -79,13 +82,23 @@ if TYPE_CHECKING:
@thread_cached
def execution_manager_client() -> ExecutionManager:
return get_service_client(ExecutionManager)
def execution_scheduler_client() -> Scheduler:
return get_service_client(Scheduler)
@thread_cached
def execution_scheduler_client() -> Scheduler:
return get_service_client(Scheduler)
def execution_queue_client() -> Coroutine[None, None, AsyncRabbitMQ]:
async def f() -> AsyncRabbitMQ:
client = AsyncRabbitMQ(create_execution_queue_config())
await client.connect()
return client
return f()
@thread_cached
def execution_event_bus() -> AsyncRedisExecutionEventBus:
return AsyncRedisExecutionEventBus()
settings = Settings()
@@ -206,7 +219,7 @@ async def is_onboarding_enabled():
@v1_router.get(path="/blocks", tags=["blocks"], dependencies=[Depends(auth_middleware)])
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
blocks = [block() for block in backend.data.block.get_blocks().values()]
blocks = [block() for block in get_blocks().values()]
costs = get_block_costs()
return [
{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks if not b.disabled
@@ -219,7 +232,7 @@ def get_graph_blocks() -> Sequence[dict[Any, Any]]:
dependencies=[Depends(auth_middleware)],
)
def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlockOutput:
obj = backend.data.block.get_block(block_id)
obj = get_block(block_id)
if not obj:
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
@@ -308,7 +321,7 @@ async def configure_user_auto_top_up(
dependencies=[Depends(auth_middleware)],
)
async def get_user_auto_top_up(
user_id: Annotated[str, Depends(get_user_id)]
user_id: Annotated[str, Depends(get_user_id)],
) -> AutoTopUpConfig:
return await get_auto_top_up(user_id)
@@ -375,7 +388,7 @@ async def get_credit_history(
@v1_router.get(path="/credits/refunds", dependencies=[Depends(auth_middleware)])
async def get_refund_requests(
user_id: Annotated[str, Depends(get_user_id)]
user_id: Annotated[str, Depends(get_user_id)],
) -> list[RefundRequest]:
return await _user_credit_model.get_refund_requests(user_id)
@@ -391,7 +404,7 @@ class DeleteGraphResponse(TypedDict):
@v1_router.get(path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)])
async def get_graphs(
user_id: Annotated[str, Depends(get_user_id)]
user_id: Annotated[str, Depends(get_user_id)],
) -> Sequence[graph_db.GraphModel]:
return await graph_db.get_graphs(filter_by="active", user_id=user_id)
@@ -580,16 +593,35 @@ async def set_graph_active_version(
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
def execute_graph(
async def execute_graph(
graph_id: str,
node_input: Annotated[dict[str, Any], Body(..., default_factory=dict)],
user_id: Annotated[str, Depends(get_user_id)],
graph_version: Optional[int] = None,
preset_id: Optional[str] = None,
) -> ExecuteGraphResponse:
graph_exec = execution_manager_client().add_execution(
graph_id, node_input, user_id=user_id, graph_version=graph_version
graph: graph_db.GraphModel | None = await graph_db.get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
)
return ExecuteGraphResponse(graph_exec_id=graph_exec.graph_exec_id)
if not graph:
raise ValueError(f"Graph #{graph_id} not found.")
graph_exec = await execution_db.create_graph_execution(
graph_id=graph_id,
graph_version=graph.version,
nodes_input=execution_utils.construct_node_execution_input(
graph, user_id, node_input
),
user_id=user_id,
preset_id=preset_id,
)
execution_utils.get_execution_event_bus().publish(graph_exec)
execution_utils.get_execution_queue().publish_message(
routing_key=execution_utils.GRAPH_EXECUTION_ROUTING_KEY,
message=graph_exec.to_graph_execution_entry().model_dump_json(),
exchange=execution_utils.GRAPH_EXECUTION_EXCHANGE,
)
return ExecuteGraphResponse(graph_exec_id=graph_exec.id)
@v1_router.post(
@@ -605,9 +637,7 @@ async def stop_graph_run(
):
raise HTTPException(404, detail=f"Agent execution #{graph_exec_id} not found")
await asyncio.to_thread(
lambda: execution_manager_client().cancel_execution(graph_exec_id)
)
await _cancel_execution(graph_exec_id)
# Retrieve & return canceled graph execution in its final state
result = await execution_db.get_graph_execution(
@@ -621,6 +651,49 @@ async def stop_graph_run(
return result
async def _cancel_execution(graph_exec_id: str):
"""
Mechanism:
1. Set the cancel event
2. Graph executor's cancel handler thread detects the event, terminates workers,
reinitializes worker pool, and returns.
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
"""
queue_client = await execution_queue_client()
await queue_client.publish_message(
routing_key="",
message=execution_utils.CancelExecutionEvent(
graph_exec_id=graph_exec_id
).model_dump_json(),
exchange=execution_utils.GRAPH_EXECUTION_CANCEL_EXCHANGE,
)
# Update the status of the graph & node executions
await execution_db.update_graph_execution_stats(
graph_exec_id,
execution_db.ExecutionStatus.TERMINATED,
)
node_execs = [
node_exec.model_copy(update={"status": execution_db.ExecutionStatus.TERMINATED})
for node_exec in await execution_db.get_node_execution_results(
graph_exec_id=graph_exec_id,
statuses=[
execution_db.ExecutionStatus.QUEUED,
execution_db.ExecutionStatus.RUNNING,
execution_db.ExecutionStatus.INCOMPLETE,
],
)
]
await execution_db.update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in node_execs],
execution_db.ExecutionStatus.TERMINATED,
)
await asyncio.gather(
*[execution_event_bus().publish(node_exec) for node_exec in node_execs]
)
@v1_router.get(
path="/executions",
tags=["graphs"],
@@ -792,7 +865,7 @@ async def create_api_key(
dependencies=[Depends(auth_middleware)],
)
async def get_api_keys(
user_id: Annotated[str, Depends(get_user_id)]
user_id: Annotated[str, Depends(get_user_id)],
) -> list[APIKeyWithoutHash]:
"""List all API keys for the user"""
try:

View File

@@ -2,25 +2,16 @@ import logging
from typing import Annotated, Any
import autogpt_libs.auth as autogpt_auth_lib
import autogpt_libs.utils.cache
from fastapi import APIRouter, Body, Depends, HTTPException, status
import backend.executor
import backend.server.v2.library.db as db
import backend.server.v2.library.model as models
import backend.util.service
logger = logging.getLogger(__name__)
router = APIRouter()
@autogpt_libs.utils.cache.thread_cached
def execution_manager_client() -> backend.executor.ExecutionManager:
"""Return a cached instance of ExecutionManager client."""
return backend.util.service.get_service_client(backend.executor.ExecutionManager)
@router.get(
"/presets",
summary="List presets",
@@ -216,6 +207,8 @@ async def execute_preset(
HTTPException: If the preset is not found or an error occurs while executing the preset.
"""
try:
from backend.server.routers import v1 as internal_api_routes
preset = await db.get_preset(user_id, preset_id)
if not preset:
raise HTTPException(
@@ -226,10 +219,10 @@ async def execute_preset(
# Merge input overrides with preset inputs
merged_node_input = preset.inputs | node_input
execution = execution_manager_client().add_execution(
execution = await internal_api_routes.execute_graph(
graph_id=graph_id,
node_input=merged_node_input,
graph_version=graph_version,
data=merged_node_input,
user_id=user_id,
preset_id=preset_id,
)

View File

@@ -12,7 +12,6 @@ import backend.server.v2.store.exceptions
import backend.server.v2.store.model
from backend.data.graph import GraphModel, get_sub_graphs
from backend.data.includes import AGENT_GRAPH_INCLUDE
from backend.util.type import typed_cast
logger = logging.getLogger(__name__)
@@ -960,7 +959,7 @@ async def get_my_agents(
try:
search_filter: prisma.types.LibraryAgentWhereInput = {
"userId": user_id,
"AgentGraph": {"is": {"StoreListing": {"none": {"isDeleted": False}}}},
"AgentGraph": {"is": {"StoreListings": {"none": {"isDeleted": False}}}},
"isArchived": False,
"isDeleted": False,
}
@@ -1088,13 +1087,7 @@ async def review_store_submission(
where={"id": store_listing_version_id},
include={
"StoreListing": True,
"AgentGraph": {
"include": typed_cast(
prisma.types.AgentGraphIncludeFromAgentGraphRecursive1,
prisma.types.AgentGraphInclude,
AGENT_GRAPH_INCLUDE,
)
},
"AgentGraph": {"include": AGENT_GRAPH_INCLUDE},
},
)
)

View File

@@ -15,21 +15,25 @@ def to_dict(data) -> dict:
def dumps(data) -> str:
return json.dumps(jsonable_encoder(data))
return json.dumps(to_dict(data))
T = TypeVar("T")
@overload
def loads(data: str, *args, target_type: Type[T], **kwargs) -> T: ...
def loads(data: str | bytes, *args, target_type: Type[T], **kwargs) -> T: ...
@overload
def loads(data: str, *args, **kwargs) -> Any: ...
def loads(data: str | bytes, *args, **kwargs) -> Any: ...
def loads(data: str, *args, target_type: Type[T] | None = None, **kwargs) -> Any:
def loads(
data: str | bytes, *args, target_type: Type[T] | None = None, **kwargs
) -> Any:
if isinstance(data, bytes):
data = data.decode("utf-8")
parsed = json.loads(data, *args, **kwargs)
if target_type:
return type_match(parsed, target_type)

View File

@@ -14,9 +14,7 @@ def sentry_init():
traces_sample_rate=1.0,
profiles_sample_rate=1.0,
environment=f"app:{Settings().config.app_env.value}-behave:{Settings().config.behave_as.value}",
_experiments={
"enable_logs": True,
},
_experiments={"enable_logs": True},
integrations=[
LoggingIntegration(sentry_logs_level=logging.INFO),
AnthropicIntegration(

View File

@@ -198,18 +198,6 @@ def convert(value: Any, target_type: Type[T]) -> T:
raise ConversionError(f"Failed to convert {value} to {target_type}") from e
def typed(type: type[T], value: T) -> T:
"""
Add an explicit type to a value. Useful in nested statements, e.g. dict literals.
"""
return value
def typed_cast(to_type: type[TT], from_type: type[T], value: T) -> TT:
"""Strict cast to preserve type checking abilities."""
return cast(TT, value)
class FormattedStringType(str):
string_format: str

View File

@@ -0,0 +1,16 @@
-- Modify the OnboardingStep enum
ALTER TYPE "OnboardingStep" ADD VALUE 'GET_RESULTS';
ALTER TYPE "OnboardingStep" ADD VALUE 'MARKETPLACE_VISIT';
ALTER TYPE "OnboardingStep" ADD VALUE 'MARKETPLACE_ADD_AGENT';
ALTER TYPE "OnboardingStep" ADD VALUE 'MARKETPLACE_RUN_AGENT';
ALTER TYPE "OnboardingStep" ADD VALUE 'BUILDER_OPEN';
ALTER TYPE "OnboardingStep" ADD VALUE 'BUILDER_SAVE_AGENT';
ALTER TYPE "OnboardingStep" ADD VALUE 'BUILDER_RUN_AGENT';
-- Modify the UserOnboarding table
ALTER TABLE "UserOnboarding"
ADD COLUMN "updatedAt" TIMESTAMP(3),
ADD COLUMN "notificationDot" BOOLEAN NOT NULL DEFAULT true,
ADD COLUMN "notified" "OnboardingStep"[] DEFAULT '{}',
ADD COLUMN "rewardedFor" "OnboardingStep"[] DEFAULT '{}',
ADD COLUMN "onboardingAgentExecutionId" TEXT

View File

@@ -5,7 +5,7 @@ datasource db {
generator client {
provider = "prisma-client-py"
recursive_type_depth = 5
recursive_type_depth = -1
interface = "asyncio"
previewFeatures = ["views"]
}
@@ -58,6 +58,7 @@ model User {
}
enum OnboardingStep {
// Introductory onboarding (Library)
WELCOME
USAGE_REASON
INTEGRATIONS
@@ -65,18 +66,32 @@ enum OnboardingStep {
AGENT_NEW_RUN
AGENT_INPUT
CONGRATS
GET_RESULTS
// Marketplace
MARKETPLACE_VISIT
MARKETPLACE_ADD_AGENT
MARKETPLACE_RUN_AGENT
// Builder
BUILDER_OPEN
BUILDER_SAVE_AGENT
BUILDER_RUN_AGENT
}
model UserOnboarding {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime? @updatedAt
completedSteps OnboardingStep[] @default([])
notificationDot Boolean @default(true)
notified OnboardingStep[] @default([])
rewardedFor OnboardingStep[] @default([])
usageReason String?
integrations String[] @default([])
otherIntegrations String?
selectedStoreListingVersionId String?
agentInput Json?
onboardingAgentExecutionId String?
userId String @unique
User User @relation(fields: [userId], references: [id], onDelete: Cascade)

View File

@@ -1,4 +1,4 @@
from backend.data.execution import merge_execution_input, parse_execution_output
from backend.executor.utils import merge_execution_input, parse_execution_output
def test_parse_execution_output():

View File

@@ -10,16 +10,12 @@ from prisma.types import (
AgentGraphCreateInput,
AgentNodeCreateInput,
AgentNodeLinkCreateInput,
AgentPresetCreateInput,
AnalyticsDetailsCreateInput,
AnalyticsMetricsCreateInput,
APIKeyCreateInput,
CreditTransactionCreateInput,
LibraryAgentCreateInput,
ProfileCreateInput,
StoreListingCreateInput,
StoreListingReviewCreateInput,
StoreListingVersionCreateInput,
UserCreateInput,
)
@@ -140,14 +136,14 @@ async def main():
for _ in range(num_presets): # Create 1 AgentPreset per user
graph = random.choice(agent_graphs)
preset = await db.agentpreset.create(
data=AgentPresetCreateInput(
name=faker.sentence(nb_words=3),
description=faker.text(max_nb_chars=200),
userId=user.id,
agentId=graph.id,
agentVersion=graph.version,
isActive=True,
)
data={
"name": faker.sentence(nb_words=3),
"description": faker.text(max_nb_chars=200),
"userId": user.id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"isActive": True,
}
)
agent_presets.append(preset)
@@ -160,16 +156,15 @@ async def main():
graph = random.choice(agent_graphs)
preset = random.choice(agent_presets)
user_agent = await db.libraryagent.create(
data=LibraryAgentCreateInput(
userId=user.id,
agentId=graph.id,
agentVersion=graph.version,
agentPresetId=preset.id,
isFavorite=random.choice([True, False]),
isCreatedByUser=random.choice([True, False]),
isArchived=random.choice([True, False]),
isDeleted=random.choice([True, False]),
)
data={
"userId": user.id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"isFavorite": random.choice([True, False]),
"isCreatedByUser": random.choice([True, False]),
"isArchived": random.choice([True, False]),
"isDeleted": random.choice([True, False]),
}
)
user_agents.append(user_agent)
@@ -346,13 +341,13 @@ async def main():
user = random.choice(users)
slug = faker.slug()
listing = await db.storelisting.create(
data=StoreListingCreateInput(
agentId=graph.id,
agentVersion=graph.version,
owningUserId=user.id,
hasApprovedVersion=random.choice([True, False]),
slug=slug,
)
data={
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"owningUserId": user.id,
"hasApprovedVersion": random.choice([True, False]),
"slug": slug,
}
)
store_listings.append(listing)
@@ -362,26 +357,26 @@ async def main():
for listing in store_listings:
graph = [g for g in agent_graphs if g.id == listing.agentId][0]
version = await db.storelistingversion.create(
data=StoreListingVersionCreateInput(
agentId=graph.id,
agentVersion=graph.version,
name=graph.name or faker.sentence(nb_words=3),
subHeading=faker.sentence(),
videoUrl=faker.url(),
imageUrls=[get_image() for _ in range(3)],
description=faker.text(),
categories=[faker.word() for _ in range(3)],
isFeatured=random.choice([True, False]),
isAvailable=True,
storeListingId=listing.id,
submissionStatus=random.choice(
data={
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"name": graph.name or faker.sentence(nb_words=3),
"subHeading": faker.sentence(),
"videoUrl": faker.url(),
"imageUrls": [get_image() for _ in range(3)],
"description": faker.text(),
"categories": [faker.word() for _ in range(3)],
"isFeatured": random.choice([True, False]),
"isAvailable": True,
"storeListingId": listing.id,
"submissionStatus": random.choice(
[
prisma.enums.SubmissionStatus.PENDING,
prisma.enums.SubmissionStatus.APPROVED,
prisma.enums.SubmissionStatus.REJECTED,
]
),
)
}
)
store_listing_versions.append(version)
@@ -422,23 +417,12 @@ async def main():
)
await db.storelistingversion.update(
where={"id": version.id},
data=StoreListingVersionCreateInput(
submissionStatus=status,
Reviewer={"connect": {"id": reviewer.id}},
reviewComments=faker.text(),
reviewedAt=datetime.now(),
agentId=version.agentId, # preserving existing fields
agentVersion=version.agentVersion,
name=version.name,
subHeading=version.subHeading,
videoUrl=version.videoUrl,
imageUrls=version.imageUrls,
description=version.description,
categories=version.categories,
isFeatured=version.isFeatured,
isAvailable=version.isAvailable,
storeListingId=version.storeListingId,
),
data={
"submissionStatus": status,
"Reviewer": {"connect": {"id": reviewer.id}},
"reviewComments": faker.text(),
"reviewedAt": datetime.now(),
},
)
# Insert APIKeys

View File

@@ -52,7 +52,6 @@
"@xyflow/react": "12.4.2",
"ajv": "^8.17.1",
"boring-avatars": "^1.11.2",
"canvas-confetti": "^1.9.3",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"cmdk": "1.0.4",
@@ -70,6 +69,7 @@
"moment": "^2.30.1",
"next": "^14.2.26",
"next-themes": "^0.4.5",
"party-js": "^2.2.0",
"react": "^18",
"react-day-picker": "^9.6.1",
"react-dom": "^18",

View File

@@ -3,9 +3,16 @@
import { useSearchParams } from "next/navigation";
import { GraphID } from "@/lib/autogpt-server-api/types";
import FlowEditor from "@/components/Flow";
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
import { useEffect } from "react";
export default function Home() {
const query = useSearchParams();
const { completeStep } = useOnboarding();
useEffect(() => {
completeStep("BUILDER_OPEN");
}, []);
return (
<FlowEditor

View File

@@ -22,6 +22,7 @@ import AgentRunDraftView from "@/components/agents/agent-run-draft-view";
import AgentRunDetailsView from "@/components/agents/agent-run-details-view";
import AgentRunsSelectorList from "@/components/agents/agent-runs-selector-list";
import AgentScheduleDetailsView from "@/components/agents/agent-schedule-details-view";
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
export default function AgentRunsPage(): React.ReactElement {
const { id: agentID }: { id: LibraryAgentID } = useParams();
@@ -49,6 +50,7 @@ export default function AgentRunsPage(): React.ReactElement {
useState<boolean>(false);
const [confirmingDeleteAgentRun, setConfirmingDeleteAgentRun] =
useState<GraphExecutionMeta | null>(null);
const { state, updateState } = useOnboarding();
const openRunDraftView = useCallback(() => {
selectView({ type: "run" });
@@ -78,6 +80,18 @@ export default function AgentRunsPage(): React.ReactElement {
[api, graphVersions],
);
// Reward user for viewing results of their onboarding agent
useEffect(() => {
if (!state || !selectedRun || state.completedSteps.includes("GET_RESULTS"))
return;
if (selectedRun.id === state.onboardingAgentExecutionId) {
updateState({
completedSteps: [...state.completedSteps, "GET_RESULTS"],
});
}
}, [selectedRun, state]);
const fetchAgents = useCallback(() => {
api.getLibraryAgent(agentID).then((agent) => {
setAgent(agent);

View File

@@ -8,6 +8,7 @@ import { BreadCrumbs } from "@/components/agptui/BreadCrumbs";
import { Metadata } from "next";
import { CreatorInfoCard } from "@/components/agptui/CreatorInfoCard";
import { CreatorLinks } from "@/components/agptui/CreatorLinks";
import { Separator } from "@/components/ui/separator";
export async function generateMetadata({
params,
@@ -78,7 +79,7 @@ export default async function Page({
</div>
</div>
<div className="mt-8 sm:mt-12 md:mt-16 lg:pb-[58px]">
<hr className="w-full bg-neutral-700" />
<Separator className="mb-6 bg-gray-200" />
<AgentsSection
agents={creatorAgents.agents}
hideAvatars={true}

View File

@@ -154,7 +154,7 @@ export default async function Page({}: {}) {
<HeroSection />
<FeaturedSection featuredAgents={featuredAgents.agents} />
{/* 100px margin because our featured sections button are placed 40px below the container */}
<Separator className="mb-[25px] mt-[100px]" />
<Separator className="mb-6 mt-24" />
<AgentsSection
sectionTitle="Top Agents"
agents={topAgents.agents as Agent[]}

View File

@@ -83,8 +83,14 @@ export default function Page() {
api.addMarketplaceAgentToLibrary(
storeAgent?.store_listing_version_id || "",
);
api.executeGraph(agent.id, agent.version, state?.agentInput || {});
router.push("/onboarding/6-congrats");
api
.executeGraph(agent.id, agent.version, state?.agentInput || {})
.then(({ graph_exec_id }) => {
updateState({
onboardingAgentExecutionId: graph_exec_id,
});
router.push("/onboarding/6-congrats");
});
}, [api, agent, router, state?.agentInput]);
const runYourAgent = (

View File

@@ -6,9 +6,6 @@ import { redirect } from "next/navigation";
export async function finishOnboarding() {
const api = new BackendAPI();
const onboarding = await api.getUserOnboarding();
await api.updateUserOnboarding({
completedSteps: [...onboarding.completedSteps, "CONGRATS"],
});
revalidatePath("/library", "layout");
redirect("/library");
}

View File

@@ -1,24 +1,26 @@
"use client";
import { useEffect, useState } from "react";
import { useEffect, useRef, useState } from "react";
import { cn } from "@/lib/utils";
import { finishOnboarding } from "./actions";
import confetti from "canvas-confetti";
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
import * as party from "party-js";
export default function Page() {
useOnboarding(7, "AGENT_INPUT");
const { state, updateState } = useOnboarding(7, "AGENT_INPUT");
const [showText, setShowText] = useState(false);
const [showSubtext, setShowSubtext] = useState(false);
const divRef = useRef(null);
useEffect(() => {
confetti({
particleCount: 120,
spread: 360,
shapes: ["square", "circle"],
scalar: 2,
decay: 0.93,
origin: { y: 0.38, x: 0.51 },
});
if (divRef.current) {
party.confetti(divRef.current, {
count: 100,
spread: 180,
shapes: ["square", "circle"],
size: party.variation.range(2, 2), // scalar: 2
speed: party.variation.range(300, 1000),
});
}
const timer0 = setTimeout(() => {
setShowText(true);
@@ -29,6 +31,9 @@ export default function Page() {
}, 500);
const timer2 = setTimeout(() => {
updateState({
completedSteps: [...(state?.completedSteps || []), "CONGRATS"],
});
finishOnboarding();
}, 3000);
@@ -42,6 +47,7 @@ export default function Page() {
return (
<div className="flex h-screen w-screen flex-col items-center justify-center bg-violet-100">
<div
ref={divRef}
className={cn(
"z-10 -mb-16 text-9xl duration-500",
showText ? "opacity-100" : "opacity-0",
@@ -63,7 +69,7 @@ export default function Page() {
showSubtext ? "opacity-100" : "opacity-0",
)}
>
You earned 15$ for running your first agent
You earned 3$ for running your first agent
</p>
</div>
);

View File

@@ -5,11 +5,14 @@ export default async function OnboardingResetPage() {
const api = new BackendAPI();
await api.updateUserOnboarding({
completedSteps: [],
notificationDot: true,
notified: [],
usageReason: null,
integrations: [],
otherIntegrations: "",
selectedStoreListingVersionId: null,
agentInput: {},
onboardingAgentExecutionId: null,
});
redirect("/onboarding/1-welcome");
}

View File

@@ -11,6 +11,7 @@ import { useToastOnFail } from "@/components/ui/use-toast";
import ActionButtonGroup from "@/components/agptui/action-button-group";
import SchemaTooltip from "@/components/SchemaTooltip";
import { IconPlay } from "@/components/ui/icons";
import { useOnboarding } from "../onboarding/onboarding-provider";
export default function AgentRunDraftView({
graph,
@@ -26,15 +27,18 @@ export default function AgentRunDraftView({
const agentInputs = graph.input_schema.properties;
const [inputValues, setInputValues] = useState<Record<string, any>>({});
const { state, completeStep } = useOnboarding();
const doRun = useCallback(
() =>
api
.executeGraph(graph.id, graph.version, inputValues)
.then((newRun) => onRun(newRun.graph_exec_id))
.catch(toastOnFail("execute agent")),
[api, graph, inputValues, onRun, toastOnFail],
);
const doRun = useCallback(() => {
api
.executeGraph(graph.id, graph.version, inputValues)
.then((newRun) => onRun(newRun.graph_exec_id))
.catch(toastOnFail("execute agent"));
// Mark run agent onboarding step as completed
if (state?.completedSteps.includes("MARKETPLACE_ADD_AGENT")) {
completeStep("MARKETPLACE_RUN_AGENT");
}
}, [api, graph, inputValues, onRun, state]);
const runActions: ButtonAction[] = useMemo(
() => [

View File

@@ -10,6 +10,7 @@ import { useToast } from "@/components/ui/use-toast";
import useSupabase from "@/hooks/useSupabase";
import { DownloadIcon, LoaderIcon } from "lucide-react";
import { useOnboarding } from "../onboarding/onboarding-provider";
interface AgentInfoProps {
name: string;
creator: string;
@@ -39,6 +40,7 @@ export const AgentInfo: React.FC<AgentInfoProps> = ({
const api = React.useMemo(() => new BackendAPI(), []);
const { user } = useSupabase();
const { toast } = useToast();
const { completeStep } = useOnboarding();
const [downloading, setDownloading] = React.useState(false);
@@ -47,6 +49,7 @@ export const AgentInfo: React.FC<AgentInfoProps> = ({
const newLibraryAgent = await api.addMarketplaceAgentToLibrary(
storeListingVersionId,
);
completeStep("MARKETPLACE_ADD_AGENT");
router.push(`/library/agents/${newLibraryAgent.id}`);
} catch (error) {
console.error("Failed to add agent to library:", error);
@@ -170,10 +173,10 @@ export const AgentInfo: React.FC<AgentInfoProps> = ({
{/* Description Section */}
<div className="mb-4 w-full lg:mb-[36px]">
<div className="font-geist decoration-skip-ink-none mb-1.5 text-base font-medium leading-6 text-neutral-800 dark:text-neutral-200 sm:mb-2">
<div className="mb-1.5 font-sans text-base font-medium leading-6 text-neutral-800 dark:text-neutral-200 sm:mb-2">
Description
</div>
<div className="font-geist decoration-skip-ink-none text-base font-normal leading-6 text-neutral-600 underline-offset-[from-font] dark:text-neutral-400">
<div className="whitespace-pre-line font-sans text-base font-normal leading-6 text-neutral-600 dark:text-neutral-400">
{longDescription}
</div>
</div>

View File

@@ -22,7 +22,7 @@ export const BreadCrumbs: React.FC<BreadCrumbsProps> = ({ items }) => {
<button className="flex h-12 w-12 items-center justify-center rounded-full border border-neutral-200 transition-colors hover:bg-neutral-50 dark:border-neutral-700 dark:hover:bg-neutral-800">
<IconRightArrow className="h-5 w-5 text-neutral-900 dark:text-neutral-100" />
</button> */}
<div className="flex h-auto flex-wrap items-center justify-start gap-4 rounded-[5rem] bg-white dark:bg-transparent">
<div className="flex h-auto flex-wrap items-center justify-start gap-4 rounded-[5rem] dark:bg-transparent">
{items.map((item, index) => (
<React.Fragment key={index}>
<Link href={item.link}>

View File

@@ -1,43 +0,0 @@
import type { Meta, StoryObj } from "@storybook/react";
import CreditsCard from "./CreditsCard";
import { userEvent, within } from "@storybook/test";
const meta: Meta<typeof CreditsCard> = {
title: "AGPT UI/Credits Card",
component: CreditsCard,
tags: ["autodocs"],
};
export default meta;
type Story = StoryObj<typeof CreditsCard>;
export const Default: Story = {
args: {
credits: 0,
},
};
export const SmallNumber: Story = {
args: {
credits: 10,
},
};
export const LargeNumber: Story = {
args: {
credits: 1000000,
},
};
export const InteractionTest: Story = {
args: {
credits: 100,
},
play: async ({ canvasElement }) => {
const canvas = within(canvasElement);
const refreshButton = canvas.getByRole("button", {
name: /refresh credits/i,
});
await userEvent.click(refreshButton);
},
};

View File

@@ -1,48 +0,0 @@
"use client";
import { IconRefresh } from "@/components/ui/icons";
import { useState } from "react";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import useCredits from "@/hooks/useCredits";
const CreditsCard = () => {
const { credits, formatCredits, fetchCredits } = useCredits({
fetchInitialCredits: true,
});
const api = useBackendAPI();
const onRefresh = async () => {
fetchCredits();
};
return (
<div className="inline-flex h-[48px] items-center gap-2.5 rounded-2xl bg-neutral-200 p-4 dark:bg-neutral-800">
<div className="flex items-center gap-0.5">
<span className="p-ui-semibold text-base leading-7 text-neutral-900 dark:text-neutral-50">
Balance: {formatCredits(credits)}
</span>
</div>
<Tooltip key="RefreshCredits" delayDuration={500}>
<TooltipTrigger asChild>
<button
onClick={onRefresh}
className="h-6 w-6 transition-colors hover:text-neutral-700 dark:hover:text-neutral-300"
aria-label="Refresh credits"
>
<IconRefresh className="h-6 w-6" />
</button>
</TooltipTrigger>
<TooltipContent>
<p>Refresh credits</p>
</TooltipContent>
</Tooltip>
</div>
);
};
export default CreditsCard;

View File

@@ -27,7 +27,7 @@ export const FeaturedAgentCard: React.FC<FeaturedStoreCardProps> = ({
data-testid="featured-store-card"
onMouseEnter={() => setIsHovered(true)}
onMouseLeave={() => setIsHovered(false)}
className={`flex h-full flex-col ${backgroundColor}`}
className={`flex h-full flex-col ${backgroundColor} rounded-[1.5rem] border-none`}
>
<CardHeader>
<CardTitle className="line-clamp-2 text-base sm:text-xl">

View File

@@ -4,7 +4,7 @@ import { ProfilePopoutMenu } from "./ProfilePopoutMenu";
import { IconType, IconLogIn, IconAutoGPTLogo } from "@/components/ui/icons";
import { MobileNavBar } from "./MobileNavBar";
import { Button } from "./Button";
import CreditsCard from "./CreditsCard";
import Wallet from "./Wallet";
import { ProfileDetails } from "@/lib/autogpt-server-api/types";
import { NavbarLink } from "./NavbarLink";
import getServerUser from "@/lib/supabase/getServerUser";
@@ -61,7 +61,7 @@ export const Navbar = async ({ links, menuItemGroups }: NavbarProps) => {
<div className="flex items-center gap-4">
{isLoggedIn ? (
<div className="flex items-center gap-4">
{profile && <CreditsCard />}
{profile && <Wallet />}
<ProfilePopoutMenu
menuItemGroups={menuItemGroups}
userName={profile?.username}

View File

@@ -0,0 +1,114 @@
"use client";
import useCredits from "@/hooks/useCredits";
import {
Popover,
PopoverContent,
PopoverTrigger,
} from "@/components/ui/popover";
import { X } from "lucide-react";
import { PopoverClose } from "@radix-ui/react-popover";
import { TaskGroups } from "../onboarding/WalletTaskGroups";
import { ScrollArea } from "../ui/scroll-area";
import { useOnboarding } from "../onboarding/onboarding-provider";
import { useCallback, useEffect, useRef } from "react";
import { cn } from "@/lib/utils";
import * as party from "party-js";
export default function Wallet() {
const { credits, formatCredits, fetchCredits } = useCredits({
fetchInitialCredits: true,
});
const { state, updateState } = useOnboarding();
const walletRef = useRef<HTMLButtonElement | null>(null);
const onWalletOpen = useCallback(async () => {
if (state?.notificationDot) {
updateState({ notificationDot: false });
}
// Refresh credits when the wallet is opened
fetchCredits();
}, [state?.notificationDot, updateState, fetchCredits]);
const fadeOut = new party.ModuleBuilder()
.drive("opacity")
.by((t) => 1 - t)
.through("lifetime")
.build();
useEffect(() => {
// Check if there are any completed tasks (state?.completedTasks) that
// are not in the state?.notified array and play confetti if so
const pending = state?.completedSteps
.filter((step) => !state?.notified.includes(step))
// Ignore steps that are not relevant for notifications
.filter(
(step) =>
step !== "WELCOME" &&
step !== "USAGE_REASON" &&
step !== "INTEGRATIONS" &&
step !== "AGENT_CHOICE" &&
step !== "AGENT_NEW_RUN" &&
step !== "AGENT_INPUT",
);
if ((pending?.length || 0) > 0 && walletRef.current) {
party.confetti(walletRef.current, {
count: 30,
spread: 120,
shapes: ["square", "circle"],
size: party.variation.range(1, 2),
speed: party.variation.range(200, 300),
modules: [fadeOut],
});
}
}, [state?.completedSteps, state?.notified]);
return (
<Popover>
<PopoverTrigger asChild>
<button
ref={walletRef}
className="relative flex items-center gap-1 rounded-md bg-zinc-200 px-3 py-2 text-sm transition-colors duration-200 hover:bg-zinc-300"
onClick={onWalletOpen}
>
Wallet{" "}
<span className="text-sm font-semibold">
{formatCredits(credits)}
</span>
{state?.notificationDot && (
<span className="absolute right-1 top-1 h-2 w-2 rounded-full bg-violet-600"></span>
)}
</button>
</PopoverTrigger>
<PopoverContent
className={cn(
"absolute -right-[7.9rem] -top-[3.2rem] z-50 w-[28.5rem] px-[0.625rem] py-2",
"rounded-xl border-zinc-200 bg-zinc-50 shadow-[0_3px_3px] shadow-zinc-300",
)}
>
<div>
<div className="mx-1 flex items-center justify-between border-b border-zinc-300 pb-2">
<span className="font-poppins font-medium text-zinc-900">
Your wallet
</span>
<div className="flex items-center font-inter text-sm font-semibold text-violet-700">
<div className="rounded-lg bg-violet-100 px-3 py-2">
Wallet{" "}
<span className="font-semibold">{formatCredits(credits)}</span>
</div>
<PopoverClose>
<X className="ml-[2.8rem] h-5 w-5 text-zinc-800 hover:text-foreground" />
</PopoverClose>
</div>
</div>
<p className="mx-1 mt-3 font-inter text-xs text-muted-foreground text-zinc-400">
Complete the following tasks to earn more credits!
</p>
</div>
<ScrollArea className="max-h-[80vh] overflow-y-auto">
<TaskGroups />
</ScrollArea>
</PopoverContent>
</Popover>
);
}

View File

@@ -49,7 +49,7 @@ export const AgentsSection: React.FC<AgentsSectionProps> = ({
<div className="flex flex-col items-center justify-center">
<div className="w-full max-w-[1360px]">
<div
className={`mb-[${margin}] font-poppins text-[18px] font-[600] leading-7 text-[#282828] dark:text-neutral-200`}
className={`mb-[${margin}] font-poppins text-lg font-semibold text-[#282828] dark:text-neutral-200`}
>
{sectionTitle}
</div>

View File

@@ -33,7 +33,7 @@ export const FeaturedCreators: React.FC<FeaturedCreatorsProps> = ({
return (
<div className="flex w-full flex-col items-center justify-center">
<div className="w-full max-w-[1360px]">
<h2 className="mb-[37px] font-poppins text-2xl font-semibold leading-7 text-neutral-800 dark:text-neutral-200">
<h2 className="mb-9 font-poppins text-lg font-semibold text-neutral-800 dark:text-neutral-200">
{title}
</h2>

View File

@@ -46,7 +46,7 @@ export const FeaturedSection: React.FC<FeaturedSectionProps> = ({
};
return (
<section className="w-full pb-16">
<section className="w-full">
<h2 className="mb-8 font-poppins text-2xl font-semibold leading-7 text-neutral-800 dark:text-neutral-200">
Featured agents
</h2>

View File

@@ -4,9 +4,16 @@ import * as React from "react";
import { SearchBar } from "@/components/agptui/SearchBar";
import { FilterChips } from "@/components/agptui/FilterChips";
import { useRouter } from "next/navigation";
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
export const HeroSection: React.FC = () => {
const router = useRouter();
const { completeStep } = useOnboarding();
// Mark marketplace visit task as completed
React.useEffect(() => {
completeStep("MARKETPLACE_VISIT");
}, [completeStep]);
function onFilterChange(selectedFilters: string[]) {
const encodedTerm = encodeURIComponent(selectedFilters.join(", "));

View File

@@ -49,7 +49,7 @@ export default function LibrarySearchBar(): React.ReactNode {
onFocus={() => setIsFocused(true)}
onBlur={() => !inputRef.current?.value && setIsFocused(false)}
onChange={handleSearchInput}
className="flex-1 border-none font-sans text-[16px] font-normal leading-7 shadow-none focus:shadow-none"
className="flex-1 border-none font-sans text-[16px] font-normal leading-7 shadow-none focus:shadow-none focus:ring-0"
type="text"
placeholder="Search agents"
/>

View File

@@ -0,0 +1,331 @@
import { useCallback, useEffect, useRef, useState } from "react";
import { ChevronDown, Check } from "lucide-react";
import { OnboardingStep } from "@/lib/autogpt-server-api";
import { useOnboarding } from "./onboarding-provider";
import { cn } from "@/lib/utils";
import * as party from "party-js";
interface Task {
id: OnboardingStep;
name: string;
amount: number;
details: string;
video?: string;
}
interface TaskGroup {
name: string;
tasks: Task[];
isOpen: boolean;
}
export function TaskGroups() {
const [groups, setGroups] = useState<TaskGroup[]>([
{
name: "Run your first agent",
isOpen: false,
tasks: [
{
id: "CONGRATS",
name: "Finish onboarding",
amount: 3,
details: "Go through our step by step tutorial",
},
{
id: "GET_RESULTS",
name: "Get results from first agent",
amount: 3,
details:
"Sit back and relax - your agent is running and will finish soon! See the results in the Library once it's done",
video: "/onboarding/get-results.mp4",
},
],
},
{
name: "Explore the Marketplace",
isOpen: false,
tasks: [
{
id: "MARKETPLACE_VISIT",
name: "Go to Marketplace",
amount: 0,
details: "Click Marketplace in the top navigation",
video: "/onboarding/marketplace-visit.mp4",
},
{
id: "MARKETPLACE_ADD_AGENT",
name: "Find an agent",
amount: 1,
details:
"Search for an agent in the Marketplace, like a code generator or research assistant and add it to your Library",
video: "/onboarding/marketplace-add.mp4",
},
{
id: "MARKETPLACE_RUN_AGENT",
name: "Try out your agent",
amount: 1,
details:
"Run the agent you found in the Marketplace from the Library - whether it's a writing assistant, data analyzer, or something else",
video: "/onboarding/marketplace-run.mp4",
},
],
},
{
name: "Build your own agent",
isOpen: false,
tasks: [
{
id: "BUILDER_OPEN",
name: "Open the Builder",
amount: 0,
details: "Click Builder in the top navigation",
video: "/onboarding/builder-open.mp4",
},
{
id: "BUILDER_SAVE_AGENT",
name: "Place your first blocks and save your agent",
amount: 1,
details:
"Open block library on the left and add a block to the canvas then save your agent",
video: "/onboarding/builder-save.mp4",
},
{
id: "BUILDER_RUN_AGENT",
name: "Run your agent",
amount: 1,
details: "Run your agent from the Builder",
video: "/onboarding/builder-run.mp4",
},
],
},
]);
const { state, updateState } = useOnboarding();
const refs = useRef<Record<string, HTMLDivElement | null>>({});
const toggleGroup = useCallback((name: string) => {
setGroups((prevGroups) =>
prevGroups.map((group) =>
group.name === name ? { ...group, isOpen: !group.isOpen } : group,
),
);
}, []);
const isTaskCompleted = useCallback(
(task: Task) => {
return state?.completedSteps?.includes(task.id) || false;
},
[state?.completedSteps],
);
const getCompletedCount = useCallback(
(tasks: Task[]) => {
return tasks.filter((task) => isTaskCompleted(task)).length;
},
[isTaskCompleted],
);
const isGroupCompleted = useCallback(
(group: TaskGroup) => {
return group.tasks.every((task) => isTaskCompleted(task));
},
[isTaskCompleted],
);
const setRef = (name: string) => (el: HTMLDivElement | null) => {
if (el) {
refs.current[name] = el;
}
};
useEffect(() => {
groups.forEach((group) => {
const groupCompleted = isGroupCompleted(group);
// Check if the last task in the group is completed
const alreadyCelebrated = state?.notified.includes(
group.tasks[group.tasks.length - 1].id,
);
if (groupCompleted) {
const el = refs.current[group.name];
if (el && !alreadyCelebrated) {
party.confetti(el, {
count: 50,
spread: 120,
shapes: ["square", "circle"],
size: party.variation.range(1, 2),
speed: party.variation.range(200, 300),
});
// Update the state to include all group tasks as notified
// This ensures that the confetti effect isn't perpetually triggered on Wallet
const notifiedTasks = group.tasks.map((task) => task.id);
updateState({
notified: [...(state?.notified || []), ...notifiedTasks],
});
}
return;
}
group.tasks.forEach((task) => {
const el = refs.current[task.id];
if (el && isTaskCompleted(task) && !state?.notified.includes(task.id)) {
party.confetti(el, {
count: 40,
spread: 120,
shapes: ["square", "circle"],
size: party.variation.range(1, 1.5),
speed: party.variation.range(200, 300),
});
// Update the state to include the task as notified
updateState({ notified: [...(state?.notified || []), task.id] });
}
});
});
}, [state?.completedSteps]);
return (
<div className="space-y-2">
{groups.map((group) => (
<div
key={group.name}
ref={setRef(group.name)}
className="mt-3 overflow-hidden rounded-lg border border-zinc-200 bg-zinc-100"
>
{/* Group Header - unchanged */}
<div
className="flex cursor-pointer items-center justify-between p-3"
onClick={() => toggleGroup(group.name)}
>
{/* Name and completed count */}
<div className="flex-1">
<div
className={cn(
"text-sm font-medium text-zinc-900",
isGroupCompleted(group) ? "text-zinc-600 line-through" : "",
)}
>
{group.name}
</div>
<div
className={cn(
"mt-1 text-xs font-normal leading-tight text-zinc-500",
isGroupCompleted(group) ? "line-through" : "",
)}
>
{getCompletedCount(group.tasks)} of {group.tasks.length}{" "}
completed
</div>
</div>
{/* Reward and chevron */}
<div className="flex items-center gap-2">
<div
className={cn(
"text-xs font-medium leading-tight text-violet-600",
isGroupCompleted(group) ? "line-through" : "",
)}
>
$
{group.tasks
.reduce((sum, task) => sum + task.amount, 0)
.toFixed(2)}
</div>
<ChevronDown
className={`h-5 w-5 text-slate-950 transition-transform duration-300 ease-in-out ${
group.isOpen ? "rotate-180" : ""
}`}
/>
</div>
</div>
{/* Tasks */}
<div
className={cn(
"overflow-hidden transition-all duration-300 ease-in-out",
group.isOpen || !isGroupCompleted(group)
? "max-h-[1000px] opacity-100"
: "max-h-0 opacity-0",
)}
>
{group.tasks.map((task) => (
<div
key={task.id}
ref={setRef(task.id)}
className="mx-3 border-t border-zinc-300 px-1 pb-1 pt-3"
>
<div className="flex items-center justify-between">
{/* Checkmark and name */}
<div className="flex items-center gap-2">
<div
className={cn(
"flex h-4 w-4 items-center justify-center rounded-full border",
isTaskCompleted(task)
? "border-emerald-600"
: "border-zinc-600",
)}
>
{isTaskCompleted(task) && (
<Check className="h-3 w-3 text-emerald-600" />
)}
</div>
<span
className={cn(
"text-sm font-normal",
isTaskCompleted(task)
? "text-zinc-500 line-through"
: "text-zinc-800",
)}
>
{task.name}
</span>
</div>
{/* Reward */}
{task.amount > 0 && (
<span
className={cn(
"text-xs font-normal text-zinc-500",
isTaskCompleted(task) ? "line-through" : "",
)}
>
${task.amount.toFixed(2)}
</span>
)}
</div>
{/* Details section */}
<div
className={cn(
"mt-2 overflow-hidden pl-6 text-xs font-normal text-zinc-500 transition-all duration-300 ease-in-out",
isTaskCompleted(task) && "line-through",
group.isOpen
? "max-h-[100px] opacity-100"
: "max-h-0 opacity-0",
)}
>
{task.details}
</div>
{task.video && (
<div
className={cn(
"relative mx-6 aspect-video overflow-hidden rounded-lg transition-all duration-300 ease-in-out",
group.isOpen
? "my-2 max-h-[200px] opacity-100"
: "max-h-0 opacity-0",
)}
>
<video
src={task.video}
autoPlay
loop
muted
playsInline
className={cn("h-full w-full object-cover object-center")}
></video>
</div>
)}
</div>
))}
</div>
</div>
))}
</div>
);
}

View File

@@ -14,9 +14,12 @@ import {
const OnboardingContext = createContext<
| {
state: UserOnboarding | null;
updateState: (state: Partial<UserOnboarding>) => void;
updateState: (
state: Omit<Partial<UserOnboarding>, "rewardedFor">,
) => void;
step: number;
setStep: (step: number) => void;
completeStep: (step: OnboardingStep) => void;
}
| undefined
>(undefined);
@@ -84,19 +87,23 @@ export default function OnboardingProvider({
}, [api, pathname, router]);
const updateState = useCallback(
(newState: Partial<UserOnboarding>) => {
(newState: Omit<Partial<UserOnboarding>, "rewardedFor">) => {
setState((prev) => {
api.updateUserOnboarding({ ...prev, ...newState });
api.updateUserOnboarding(newState);
if (!prev) {
// Handle initial state
return {
completedSteps: [],
notificationDot: false,
notified: [],
rewardedFor: [],
usageReason: null,
integrations: [],
otherIntegrations: null,
selectedStoreListingVersionId: null,
agentInput: null,
onboardingAgentExecutionId: null,
...newState,
};
}
@@ -106,8 +113,21 @@ export default function OnboardingProvider({
[api, setState],
);
const completeStep = useCallback(
(step: OnboardingStep) => {
if (!state || state.completedSteps.includes(step)) return;
updateState({
completedSteps: [...state.completedSteps, step],
});
},
[api, state],
);
return (
<OnboardingContext.Provider value={{ state, updateState, step, setStep }}>
<OnboardingContext.Provider
value={{ state, updateState, step, setStep, completeStep }}
>
{children}
</OnboardingContext.Provider>
);

View File

@@ -213,7 +213,7 @@ const CarouselPrevious = React.forwardRef<
className={cn(
"absolute h-[52px] w-[52px] rounded-full",
orientation === "horizontal"
? "right-24 top-0"
? "right-20 top-0"
: "-top-12 left-1/2 -translate-x-1/2 rotate-90",
className,
)}

View File

@@ -24,6 +24,7 @@ import { useToast } from "@/components/ui/use-toast";
import { InputItem } from "@/components/RunnerUIWrapper";
import { GraphMeta } from "@/lib/autogpt-server-api";
import { default as NextLink } from "next/link";
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
const ajv = new Ajv({ strict: false, allErrors: true });
@@ -77,6 +78,7 @@ export default function useAgentGraph(
useState(false);
const [nodes, setNodes] = useState<CustomNode[]>([]);
const [edges, setEdges] = useState<CustomEdge[]>([]);
const { state, completeStep } = useOnboarding();
const api = useMemo(
() => new BackendAPI(process.env.NEXT_PUBLIC_AGPT_SERVER_URL!),
@@ -576,6 +578,9 @@ export default function useAgentGraph(
path.set("flowVersion", savedAgent.version.toString());
path.set("flowExecutionID", graphExecution.graph_exec_id);
router.push(`${pathname}?${path.toString()}`);
if (state?.completedSteps.includes("BUILDER_SAVE_AGENT")) {
completeStep("BUILDER_RUN_AGENT");
}
})
.catch((error) => {
const errorMessage =
@@ -966,6 +971,7 @@ export default function useAgentGraph(
const saveAgent = useCallback(async () => {
try {
await _saveAgent();
completeStep("BUILDER_SAVE_AGENT");
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : String(error);

View File

@@ -180,7 +180,9 @@ export default class BackendAPI {
return this._get("/onboarding");
}
updateUserOnboarding(onboarding: Partial<UserOnboarding>): Promise<void> {
updateUserOnboarding(
onboarding: Omit<Partial<UserOnboarding>, "rewardedFor">,
): Promise<void> {
return this._request("PATCH", "/onboarding", onboarding);
}

View File

@@ -801,15 +801,26 @@ export type OnboardingStep =
| "AGENT_CHOICE"
| "AGENT_NEW_RUN"
| "AGENT_INPUT"
| "CONGRATS";
| "CONGRATS"
| "GET_RESULTS"
| "MARKETPLACE_VISIT"
| "MARKETPLACE_ADD_AGENT"
| "MARKETPLACE_RUN_AGENT"
| "BUILDER_OPEN"
| "BUILDER_SAVE_AGENT"
| "BUILDER_RUN_AGENT";
export interface UserOnboarding {
completedSteps: OnboardingStep[];
notificationDot: boolean;
notified: OnboardingStep[];
rewardedFor: OnboardingStep[];
usageReason: string | null;
integrations: string[];
otherIntegrations: string | null;
selectedStoreListingVersionId: string | null;
agentInput: { [key: string]: string } | null;
agentInput: { [key: string]: string | number } | null;
onboardingAgentExecutionId: GraphExecutionID | null;
}
/* *** UTILITIES *** */

View File

@@ -5026,11 +5026,6 @@ caniuse-lite@^1.0.30001579, caniuse-lite@^1.0.30001688:
resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001697.tgz#040bbbb54463c4b4b3377c716b34a322d16e6fc7"
integrity sha512-GwNPlWJin8E+d7Gxq96jxM6w0w+VFeyyXRsjU58emtkYqnbwHqXm5uT2uCmO0RQE9htWknOP4xtBlLmM/gWxvQ==
canvas-confetti@^1.9.3:
version "1.9.3"
resolved "https://registry.yarnpkg.com/canvas-confetti/-/canvas-confetti-1.9.3.tgz#ef4c857420ad8045ab4abe8547261c8cdf229845"
integrity sha512-rFfTURMvmVEX1gyXFgn5QMn81bYk70qa0HLzcIOSVEyl57n6o9ItHeBtUSWdvKAPY0xlvBHno4/v3QPrT83q9g==
case-sensitive-paths-webpack-plugin@^2.4.0:
version "2.4.0"
resolved "https://registry.yarnpkg.com/case-sensitive-paths-webpack-plugin/-/case-sensitive-paths-webpack-plugin-2.4.0.tgz#db64066c6422eed2e08cc14b986ca43796dbc6d4"
@@ -9590,6 +9585,11 @@ parse-passwd@^1.0.0:
resolved "https://registry.yarnpkg.com/parse-passwd/-/parse-passwd-1.0.0.tgz#6d5b934a456993b23d37f40a382d6f1666a8e5c6"
integrity sha512-1Y1A//QUXEZK7YKz+rD9WydcE1+EuPr6ZBgKecAB8tmoW6UFv0NREVJe1p+jRxtThkcbbKkfwIbWJe/IeE6m2Q==
party-js@^2.2.0:
version "2.2.0"
resolved "https://registry.yarnpkg.com/party-js/-/party-js-2.2.0.tgz#3340026971c9e62fd34db102daaa645fbc9130b8"
integrity sha512-50hGuALCpvDTrQLPQ1fgUgxKIWAH28ShVkmeK/3zhO0YJyCqkhrZhQEkWPxDYLvbFJ7YAXyROmFEu35gKpZLtQ==
pascal-case@^3.1.2:
version "3.1.2"
resolved "https://registry.yarnpkg.com/pascal-case/-/pascal-case-3.1.2.tgz#b48e0ef2b98e205e7c1dae747d0b1508237660eb"