diff --git a/autogpt_platform/backend/backend/blocks/agent.py b/autogpt_platform/backend/backend/blocks/agent.py index c2b600dd4c..b979721988 100644 --- a/autogpt_platform/backend/backend/blocks/agent.py +++ b/autogpt_platform/backend/backend/blocks/agent.py @@ -1,8 +1,6 @@ import logging from typing import Any -from autogpt_libs.utils.cache import thread_cached - from backend.data.block import ( Block, BlockCategory, @@ -19,21 +17,6 @@ from backend.util import json logger = logging.getLogger(__name__) -@thread_cached -def get_executor_manager_client(): - from backend.executor import ExecutionManager - from backend.util.service import get_service_client - - return get_service_client(ExecutionManager) - - -@thread_cached -def get_event_bus(): - from backend.data.execution import RedisExecutionEventBus - - return RedisExecutionEventBus() - - class AgentExecutorBlock(Block): class Input(BlockSchema): user_id: str = SchemaField(description="User ID") @@ -76,11 +59,11 @@ class AgentExecutorBlock(Block): def run(self, input_data: Input, **kwargs) -> BlockOutput: from backend.data.execution import ExecutionEventType + from backend.executor import utils as execution_utils - executor_manager = get_executor_manager_client() - event_bus = get_event_bus() + event_bus = execution_utils.get_execution_event_bus() - graph_exec = executor_manager.add_execution( + graph_exec = execution_utils.add_graph_execution( graph_id=input_data.graph_id, graph_version=input_data.graph_version, user_id=input_data.user_id, diff --git a/autogpt_platform/backend/backend/data/credit.py b/autogpt_platform/backend/backend/data/credit.py index f76c36a6c3..b106efc3ac 100644 --- a/autogpt_platform/backend/backend/data/credit.py +++ b/autogpt_platform/backend/backend/data/credit.py @@ -3,7 +3,6 @@ import logging from abc import ABC, abstractmethod from collections import defaultdict from datetime import datetime, timezone -from typing import cast import stripe from autogpt_libs.utils.cache import thread_cached @@ -12,6 +11,7 @@ from prisma.enums import ( CreditRefundRequestStatus, CreditTransactionType, NotificationType, + OnboardingStep, ) from prisma.errors import UniqueViolationError from prisma.models import CreditRefundRequest, CreditTransaction, User @@ -19,7 +19,6 @@ from prisma.types import ( CreditRefundRequestCreateInput, CreditTransactionCreateInput, CreditTransactionWhereInput, - IntFilter, ) from tenacity import retry, stop_after_attempt, wait_exponential @@ -123,6 +122,18 @@ class UserCreditBase(ABC): """ pass + @abstractmethod + async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep): + """ + Reward the user with credits for completing an onboarding step. + Won't reward if the user has already received credits for the step. + + Args: + user_id (str): The user ID. + credits (int): The number of credits to reward. + step (OnboardingStep): The onboarding step.
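+ + Example (illustrative sketch only; the amount and step are placeholders, mirroring the call made by reward_user in backend.data.onboarding): + >>> await user_credit.onboarding_reward(user_id, 300, OnboardingStep.GET_RESULTS)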
+ """ + pass + @abstractmethod async def top_up_intent(self, user_id: str, amount: int) -> str: """ @@ -215,7 +226,7 @@ class UserCreditBase(ABC): "userId": user_id, "createdAt": {"lte": top_time}, "isActive": True, - "runningBalance": cast(IntFilter, {"not": None}), + "NOT": [{"runningBalance": None}], }, order={"createdAt": "desc"}, ) @@ -410,6 +421,24 @@ class UserCredit(UserCreditBase): async def top_up_credits(self, user_id: str, amount: int): await self._top_up_credits(user_id, amount) + async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep): + key = f"REWARD-{user_id}-{step.value}" + if not await CreditTransaction.prisma().find_first( + where={ + "userId": user_id, + "transactionKey": key, + } + ): + await self._add_transaction( + user_id=user_id, + amount=credits, + transaction_type=CreditTransactionType.GRANT, + transaction_key=key, + metadata=Json( + {"reason": f"Reward for completing {step.value} onboarding step."} + ), + ) + async def top_up_refund( self, user_id: str, transaction_key: str, metadata: dict[str, str] ) -> int: @@ -897,6 +926,9 @@ class DisabledUserCredit(UserCreditBase): async def top_up_credits(self, *args, **kwargs): pass + async def onboarding_reward(self, *args, **kwargs): + pass + async def top_up_intent(self, *args, **kwargs) -> str: return "" diff --git a/autogpt_platform/backend/backend/data/db.py b/autogpt_platform/backend/backend/data/db.py index 0702c47730..f962e9d393 100644 --- a/autogpt_platform/backend/backend/data/db.py +++ b/autogpt_platform/backend/backend/data/db.py @@ -89,7 +89,7 @@ async def transaction(): async def locked_transaction(key: str): lock_key = zlib.crc32(key.encode("utf-8")) async with transaction() as tx: - await tx.execute_raw(f"SELECT pg_advisory_xact_lock({lock_key})") + await tx.execute_raw("SELECT pg_advisory_xact_lock($1)", lock_key) yield tx diff --git a/autogpt_platform/backend/backend/data/execution.py b/autogpt_platform/backend/backend/data/execution.py index 842c36c699..da49461888 100644 --- a/autogpt_platform/backend/backend/data/execution.py +++ b/autogpt_platform/backend/backend/data/execution.py @@ -34,11 +34,10 @@ from pydantic import BaseModel from pydantic.fields import Field from backend.server.v2.store.exceptions import DatabaseError -from backend.util import mock from backend.util import type as type_utils from backend.util.settings import Config -from .block import BlockData, BlockInput, BlockType, CompletedBlockOutput, get_block +from .block import BlockInput, BlockType, CompletedBlockOutput, get_block from .db import BaseDbModel from .includes import ( EXECUTION_RESULT_INCLUDE, @@ -203,6 +202,26 @@ class GraphExecutionWithNodes(GraphExecution): node_executions=node_executions, ) + def to_graph_execution_entry(self): + return GraphExecutionEntry( + user_id=self.user_id, + graph_id=self.graph_id, + graph_version=self.graph_version or 0, + graph_exec_id=self.id, + start_node_execs=[ + NodeExecutionEntry( + user_id=self.user_id, + graph_exec_id=node_exec.graph_exec_id, + graph_id=node_exec.graph_id, + node_exec_id=node_exec.node_exec_id, + node_id=node_exec.node_id, + block_id=node_exec.block_id, + data=node_exec.input_data, + ) + for node_exec in self.node_executions + ], + ) + class NodeExecutionResult(BaseModel): user_id: str @@ -469,19 +488,27 @@ async def upsert_execution_output( ) -async def update_graph_execution_start_time(graph_exec_id: str) -> GraphExecution: - res = await AgentGraphExecution.prisma().update( - where={"id": graph_exec_id}, +async def 
update_graph_execution_start_time( + graph_exec_id: str, +) -> GraphExecution | None: + count = await AgentGraphExecution.prisma().update_many( + where={ + "id": graph_exec_id, + "executionStatus": ExecutionStatus.QUEUED, + }, data={ "executionStatus": ExecutionStatus.RUNNING, "startedAt": datetime.now(tz=timezone.utc), }, + ) + if count == 0: + return None + + res = await AgentGraphExecution.prisma().find_unique( + where={"id": graph_exec_id}, include=GRAPH_EXECUTION_INCLUDE, ) - if not res: - raise ValueError(f"Graph execution #{graph_exec_id} not found") - - return GraphExecution.from_db(res) + return GraphExecution.from_db(res) if res else None async def update_graph_execution_stats( @@ -492,7 +519,8 @@ async def update_graph_execution_stats( data = stats.model_dump() if stats else {} if isinstance(data.get("error"), Exception): data["error"] = str(data["error"]) - res = await AgentGraphExecution.prisma().update( + + updated_count = await AgentGraphExecution.prisma().update_many( where={ "id": graph_exec_id, "OR": [ @@ -504,10 +532,15 @@ async def update_graph_execution_stats( "executionStatus": status, "stats": Json(data), }, + ) + if updated_count == 0: + return None + + graph_exec = await AgentGraphExecution.prisma().find_unique_or_raise( + where={"id": graph_exec_id}, include=GRAPH_EXECUTION_INCLUDE, ) - - return GraphExecution.from_db(res) if res else None + return GraphExecution.from_db(graph_exec) async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionStats): @@ -643,7 +676,7 @@ async def get_latest_node_execution( where={ "agentNodeId": node_id, "agentGraphExecutionId": graph_eid, - "executionStatus": {"not": ExecutionStatus.INCOMPLETE}, # type: ignore + "NOT": [{"executionStatus": ExecutionStatus.INCOMPLETE}], }, order=[ {"queuedTime": "desc"}, @@ -711,144 +744,6 @@ class ExecutionQueue(Generic[T]): return self.queue.empty() -# ------------------- Execution Utilities -------------------- # - - -LIST_SPLIT = "_$_" -DICT_SPLIT = "_#_" -OBJC_SPLIT = "_@_" - - -def parse_execution_output(output: BlockData, name: str) -> Any | None: - """ - Extracts partial output data by name from a given BlockData. - - The function supports extracting data from lists, dictionaries, and objects - using specific naming conventions: - - For lists: _$_ - - For dictionaries: _#_ - - For objects: _@_ - - Args: - output (BlockData): A tuple containing the output name and data. - name (str): The name used to extract specific data from the output. - - Returns: - Any | None: The extracted data if found, otherwise None. - - Examples: - >>> output = ("result", [10, 20, 30]) - >>> parse_execution_output(output, "result_$_1") - 20 - - >>> output = ("config", {"key1": "value1", "key2": "value2"}) - >>> parse_execution_output(output, "config_#_key1") - 'value1' - - >>> class Sample: - ... attr1 = "value1" - ... 
attr2 = "value2" - >>> output = ("object", Sample()) - >>> parse_execution_output(output, "object_@_attr1") - 'value1' - """ - output_name, output_data = output - - if name == output_name: - return output_data - - if name.startswith(f"{output_name}{LIST_SPLIT}"): - index = int(name.split(LIST_SPLIT)[1]) - if not isinstance(output_data, list) or len(output_data) <= index: - return None - return output_data[int(name.split(LIST_SPLIT)[1])] - - if name.startswith(f"{output_name}{DICT_SPLIT}"): - index = name.split(DICT_SPLIT)[1] - if not isinstance(output_data, dict) or index not in output_data: - return None - return output_data[index] - - if name.startswith(f"{output_name}{OBJC_SPLIT}"): - index = name.split(OBJC_SPLIT)[1] - if isinstance(output_data, object) and hasattr(output_data, index): - return getattr(output_data, index) - return None - - return None - - -def merge_execution_input(data: BlockInput) -> BlockInput: - """ - Merges dynamic input pins into a single list, dictionary, or object based on naming patterns. - - This function processes input keys that follow specific patterns to merge them into a unified structure: - - `_$_` for list inputs. - - `_#_` for dictionary inputs. - - `_@_` for object inputs. - - Args: - data (BlockInput): A dictionary containing input keys and their corresponding values. - - Returns: - BlockInput: A dictionary with merged inputs. - - Raises: - ValueError: If a list index is not an integer. - - Examples: - >>> data = { - ... "list_$_0": "a", - ... "list_$_1": "b", - ... "dict_#_key1": "value1", - ... "dict_#_key2": "value2", - ... "object_@_attr1": "value1", - ... "object_@_attr2": "value2" - ... } - >>> merge_execution_input(data) - { - "list": ["a", "b"], - "dict": {"key1": "value1", "key2": "value2"}, - "object": - } - """ - - # Merge all input with _$_ into a single list. - items = list(data.items()) - - for key, value in items: - if LIST_SPLIT not in key: - continue - name, index = key.split(LIST_SPLIT) - if not index.isdigit(): - raise ValueError(f"Invalid key: {key}, #{index} index must be an integer.") - - data[name] = data.get(name, []) - if int(index) >= len(data[name]): - # Pad list with empty string on missing indices. - data[name].extend([""] * (int(index) - len(data[name]) + 1)) - data[name][int(index)] = value - - # Merge all input with _#_ into a single dict. - for key, value in items: - if DICT_SPLIT not in key: - continue - name, index = key.split(DICT_SPLIT) - data[name] = data.get(name, {}) - data[name][index] = value - - # Merge all input with _@_ into a single object. 
- for key, value in items: - if OBJC_SPLIT not in key: - continue - name, index = key.split(OBJC_SPLIT) - if name not in data or not isinstance(data[name], object): - data[name] = mock.MockObject() - setattr(data[name], index, value) - - return data - - # --------------------- Event Bus --------------------- # diff --git a/autogpt_platform/backend/backend/data/graph.py b/autogpt_platform/backend/backend/data/graph.py index 98ef52e0b8..b0af47a263 100644 --- a/autogpt_platform/backend/backend/data/graph.py +++ b/autogpt_platform/backend/backend/data/graph.py @@ -10,9 +10,7 @@ from prisma.models import AgentGraph, AgentNode, AgentNodeLink, StoreListingVers from prisma.types import ( AgentGraphCreateInput, AgentGraphWhereInput, - AgentGraphWhereInputRecursive1, AgentNodeCreateInput, - AgentNodeIncludeFromAgentNodeRecursive1, AgentNodeLinkCreateInput, ) from pydantic.fields import computed_field @@ -655,14 +653,11 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]: graphs = await AgentGraph.prisma().find_many( where={ "OR": [ - type_utils.typed( - AgentGraphWhereInputRecursive1, - { - "id": graph_id, - "version": graph_version, - "userId": graph.userId, # Ensure the sub-graph is owned by the same user - }, - ) + { + "id": graph_id, + "version": graph_version, + "userId": graph.userId, # Ensure the sub-graph is owned by the same user + } for graph_id, graph_version in sub_graph_ids ] }, @@ -678,13 +673,7 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]: async def get_connected_output_nodes(node_id: str) -> list[tuple[Link, Node]]: links = await AgentNodeLink.prisma().find_many( where={"agentNodeSourceId": node_id}, - include={ - "AgentNodeSink": { - "include": cast( - AgentNodeIncludeFromAgentNodeRecursive1, AGENT_NODE_INCLUDE - ) - } - }, + include={"AgentNodeSink": {"include": AGENT_NODE_INCLUDE}}, ) return [ (Link.from_db(link), NodeModel.from_db(link.AgentNodeSink)) @@ -930,12 +919,19 @@ async def migrate_llm_models(migrate_to: LlmModel): # Convert enum values to a list of strings for the SQL query enum_values = [v.value for v in LlmModel.__members__.values()] + escaped_enum_values = repr(tuple(enum_values)) # hack but works query = f""" UPDATE "AgentNode" - SET "constantInput" = jsonb_set("constantInput", '{{{path}}}', '"{migrate_to.value}"', true) - WHERE "agentBlockId" = '{id}' - AND "constantInput" ? '{path}' - AND "constantInput"->>'{path}' NOT IN ({','.join(f"'{value}'" for value in enum_values)}) + SET "constantInput" = jsonb_set("constantInput", $1, $2, true) + WHERE "agentBlockId" = $3 + AND "constantInput" ? 
$4 + AND "constantInput"->>$4 NOT IN {escaped_enum_values} """ - await db.execute_raw(query) + await db.execute_raw( + query, # type: ignore - is supposed to be LiteralString + "{" + path + "}", + f'"{migrate_to.value}"', + id, + path, + ) diff --git a/autogpt_platform/backend/backend/data/includes.py b/autogpt_platform/backend/backend/data/includes.py index 8e8b15e121..f6b1ea2592 100644 --- a/autogpt_platform/backend/backend/data/includes.py +++ b/autogpt_platform/backend/backend/data/includes.py @@ -4,7 +4,6 @@ import prisma.enums import prisma.types from backend.blocks.io import IO_BLOCK_IDs -from backend.util.type import typed_cast AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = { "Input": True, @@ -14,13 +13,7 @@ AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = { } AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = { - "Nodes": { - "include": typed_cast( - prisma.types.AgentNodeIncludeFromAgentNodeRecursive1, - prisma.types.AgentNodeIncludeFromAgentNode, - AGENT_NODE_INCLUDE, - ) - } + "Nodes": {"include": AGENT_NODE_INCLUDE} } EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = { @@ -56,13 +49,7 @@ GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = { GRAPH_EXECUTION_INCLUDE_WITH_NODES["NodeExecutions"], ), "where": { - "Node": typed_cast( - prisma.types.AgentNodeRelationFilter, - prisma.types.AgentNodeWhereInput, - { - "AgentBlock": {"id": {"in": IO_BLOCK_IDs}}, - }, - ), + "Node": {"is": {"AgentBlock": {"is": {"id": {"in": IO_BLOCK_IDs}}}}}, "NOT": [{"executionStatus": prisma.enums.AgentExecutionStatus.INCOMPLETE}], }, } @@ -70,13 +57,7 @@ GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = { INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = { - "AgentNodes": { - "include": typed_cast( - prisma.types.AgentNodeIncludeFromAgentNodeRecursive1, - prisma.types.AgentNodeInclude, - AGENT_NODE_INCLUDE, - ) - } + "AgentNodes": {"include": AGENT_NODE_INCLUDE} } diff --git a/autogpt_platform/backend/backend/data/onboarding.py b/autogpt_platform/backend/backend/data/onboarding.py index 1f7503b5b6..c4ad817b5c 100644 --- a/autogpt_platform/backend/backend/data/onboarding.py +++ b/autogpt_platform/backend/backend/data/onboarding.py @@ -8,7 +8,9 @@ from prisma.enums import OnboardingStep from prisma.models import UserOnboarding from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput +from backend.data import db from backend.data.block import get_blocks +from backend.data.credit import get_user_credit_model from backend.data.graph import GraphModel from backend.data.model import CredentialsMetaInput from backend.server.v2.store.model import StoreAgentDetails @@ -24,14 +26,19 @@ REASON_MAPPING: dict[str, list[str]] = { POINTS_AGENT_COUNT = 50 # Number of agents to calculate points for MIN_AGENT_COUNT = 2 # Minimum number of marketplace agents to enable onboarding +user_credit = get_user_credit_model() + class UserOnboardingUpdate(pydantic.BaseModel): completedSteps: Optional[list[OnboardingStep]] = None + notificationDot: Optional[bool] = None + notified: Optional[list[OnboardingStep]] = None usageReason: Optional[str] = None integrations: Optional[list[str]] = None otherIntegrations: Optional[str] = None selectedStoreListingVersionId: Optional[str] = None agentInput: Optional[dict[str, Any]] = None + onboardingAgentExecutionId: Optional[str] = None async def get_user_onboarding(user_id: str): @@ -48,6 +55,20 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate): update: 
UserOnboardingUpdateInput = {} if data.completedSteps is not None: update["completedSteps"] = list(set(data.completedSteps)) + for step in ( + OnboardingStep.AGENT_NEW_RUN, + OnboardingStep.GET_RESULTS, + OnboardingStep.MARKETPLACE_ADD_AGENT, + OnboardingStep.MARKETPLACE_RUN_AGENT, + OnboardingStep.BUILDER_SAVE_AGENT, + OnboardingStep.BUILDER_RUN_AGENT, + ): + if step in data.completedSteps: + await reward_user(user_id, step) + if data.notificationDot is not None: + update["notificationDot"] = data.notificationDot + if data.notified is not None: + update["notified"] = list(set(data.notified)) if data.usageReason is not None: update["usageReason"] = data.usageReason if data.integrations is not None: @@ -58,16 +79,57 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate): update["selectedStoreListingVersionId"] = data.selectedStoreListingVersionId if data.agentInput is not None: update["agentInput"] = Json(data.agentInput) + if data.onboardingAgentExecutionId is not None: + update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId return await UserOnboarding.prisma().upsert( where={"userId": user_id}, data={ - "create": {"userId": user_id, **update}, # type: ignore + "create": {"userId": user_id, **update}, "update": update, }, ) +async def reward_user(user_id: str, step: OnboardingStep): + async with db.locked_transaction(f"usr_trx_{user_id}-reward"): + reward = 0 + match step: + # Reward user when they clicked New Run during onboarding + # This is because they need credits before scheduling a run (next step) + case OnboardingStep.AGENT_NEW_RUN: + reward = 300 + case OnboardingStep.GET_RESULTS: + reward = 300 + case OnboardingStep.MARKETPLACE_ADD_AGENT: + reward = 100 + case OnboardingStep.MARKETPLACE_RUN_AGENT: + reward = 100 + case OnboardingStep.BUILDER_SAVE_AGENT: + reward = 100 + case OnboardingStep.BUILDER_RUN_AGENT: + reward = 100 + + if reward == 0: + return + + onboarding = await get_user_onboarding(user_id) + + # Skip if already rewarded + if step in onboarding.rewardedFor: + return + + onboarding.rewardedFor.append(step) + await user_credit.onboarding_reward(user_id, reward, step) + await UserOnboarding.prisma().update( + where={"userId": user_id}, + data={ + "completedSteps": list(set(onboarding.completedSteps + [step])), + "rewardedFor": onboarding.rewardedFor, + }, + ) + + def clean_and_split(text: str) -> list[str]: """ Removes all special characters from a string, truncates it to 100 characters, diff --git a/autogpt_platform/backend/backend/data/user.py b/autogpt_platform/backend/backend/data/user.py index 739f3456dd..31e4c46b51 100644 --- a/autogpt_platform/backend/backend/data/user.py +++ b/autogpt_platform/backend/backend/data/user.py @@ -149,7 +149,7 @@ async def migrate_and_encrypt_user_integrations(): logger.info(f"Migrating integration credentials for {len(users)} users") for user in users: - raw_metadata = cast(UserMetadataRaw, user.metadata) + raw_metadata = cast(dict, user.metadata) metadata = UserMetadata.model_validate(raw_metadata) # Get existing integrations data @@ -165,7 +165,6 @@ async def migrate_and_encrypt_user_integrations(): await update_user_integrations(user_id=user.id, data=integrations) # Remove from metadata - raw_metadata = dict(raw_metadata) raw_metadata.pop("integration_credentials", None) raw_metadata.pop("integration_oauth_states", None) diff --git a/autogpt_platform/backend/backend/executor/database.py b/autogpt_platform/backend/backend/executor/database.py index 02645f12a9..5ad4791e29 100644 --- 
a/autogpt_platform/backend/backend/executor/database.py +++ b/autogpt_platform/backend/backend/executor/database.py @@ -1,11 +1,8 @@ import logging -from backend.data import db, redis +from backend.data import db from backend.data.credit import UsageTransactionMetadata, get_user_credit_model from backend.data.execution import ( - GraphExecution, - NodeExecutionResult, - RedisExecutionEventBus, create_graph_execution, get_graph_execution, get_incomplete_node_executions, @@ -42,7 +39,7 @@ from backend.data.user import ( update_user_integrations, update_user_metadata, ) -from backend.util.service import AppService, expose, exposed_run_and_wait +from backend.util.service import AppService, exposed_run_and_wait from backend.util.settings import Config config = Config() @@ -57,21 +54,14 @@ async def _spend_credits( class DatabaseManager(AppService): - def __init__(self): - super().__init__() - self.execution_event_bus = RedisExecutionEventBus() def run_service(self) -> None: logger.info(f"[{self.service_name}] ⏳ Connecting to Database...") self.run_and_wait(db.connect()) - logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...") - redis.connect() super().run_service() def cleanup(self): super().cleanup() - logger.info(f"[{self.service_name}] ⏳ Disconnecting Redis...") - redis.disconnect() logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...") self.run_and_wait(db.disconnect()) @@ -79,12 +69,6 @@ class DatabaseManager(AppService): def get_port(cls) -> int: return config.database_api_port - @expose - def send_execution_update( - self, execution_result: GraphExecution | NodeExecutionResult - ): - self.execution_event_bus.publish(execution_result) - # Executions get_graph_execution = exposed_run_and_wait(get_graph_execution) create_graph_execution = exposed_run_and_wait(create_graph_execution) diff --git a/autogpt_platform/backend/backend/executor/manager.py b/autogpt_platform/backend/backend/executor/manager.py index 9596434806..4b5cb97755 100644 --- a/autogpt_platform/backend/backend/executor/manager.py +++ b/autogpt_platform/backend/backend/executor/manager.py @@ -5,11 +5,14 @@ import os import signal import sys import threading +import time from concurrent.futures import Future, ProcessPoolExecutor from contextlib import contextmanager from multiprocessing.pool import AsyncResult, Pool -from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generator, TypeVar, cast +from pika.adapters.blocking_connection import BlockingChannel +from pika.spec import Basic from redis.lock import Lock as RedisLock from backend.blocks.io import AgentOutputBlock @@ -30,43 +33,36 @@ from autogpt_libs.utils.cache import thread_cached from backend.blocks.agent import AgentExecutorBlock from backend.data import redis -from backend.data.block import ( - Block, - BlockData, - BlockInput, - BlockSchema, - BlockType, - get_block, -) +from backend.data.block import BlockData, BlockInput, BlockSchema, get_block from backend.data.execution import ( ExecutionQueue, ExecutionStatus, + GraphExecution, GraphExecutionEntry, NodeExecutionEntry, NodeExecutionResult, - merge_execution_input, - parse_execution_output, ) -from backend.data.graph import GraphModel, Link, Node +from backend.data.graph import Link, Node from backend.executor.utils import ( + GRAPH_EXECUTION_CANCEL_QUEUE_NAME, + GRAPH_EXECUTION_QUEUE_NAME, + CancelExecutionEvent, UsageTransactionMetadata, block_usage_cost, execution_usage_cost, + get_execution_event_bus, + get_execution_queue, + 
parse_execution_output, + validate_exec, ) from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.util import json from backend.util.decorator import error_logged, time_measured from backend.util.file import clean_exec_files from backend.util.logging import configure_logging -from backend.util.process import set_service_name -from backend.util.service import ( - AppService, - close_service_client, - expose, - get_service_client, -) +from backend.util.process import AppProcess, set_service_name +from backend.util.service import close_service_client, get_service_client from backend.util.settings import Settings -from backend.util.type import convert logger = logging.getLogger(__name__) settings = Settings() @@ -152,7 +148,7 @@ def execute_node( def update_execution_status(status: ExecutionStatus) -> NodeExecutionResult: """Sets status and fetches+broadcasts the latest state of the node execution""" exec_update = db_client.update_node_execution_status(node_exec_id, status) - db_client.send_execution_update(exec_update) + send_execution_update(exec_update) return exec_update node = db_client.get_node(node_id) @@ -288,7 +284,7 @@ def _enqueue_next_nodes( exec_update = db_client.update_node_execution_status( node_exec_id, ExecutionStatus.QUEUED, data ) - db_client.send_execution_update(exec_update) + send_execution_update(exec_update) return NodeExecutionEntry( user_id=user_id, graph_exec_id=graph_exec_id, @@ -400,60 +396,6 @@ def _enqueue_next_nodes( ] -def validate_exec( - node: Node, - data: BlockInput, - resolve_input: bool = True, -) -> tuple[BlockInput | None, str]: - """ - Validate the input data for a node execution. - - Args: - node: The node to execute. - data: The input data for the node execution. - resolve_input: Whether to resolve dynamic pins into dict/list/object. - - Returns: - A tuple of the validated data and the block name. - If the data is invalid, the first element will be None, and the second element - will be an error message. - If the data is valid, the first element will be the resolved input data, and - the second element will be the block name. - """ - node_block: Block | None = get_block(node.block_id) - if not node_block: - return None, f"Block for {node.block_id} not found." - schema = node_block.input_schema - - # Convert non-matching data types to the expected input schema. - for name, data_type in schema.__annotations__.items(): - if (value := data.get(name)) and (type(value) is not data_type): - data[name] = convert(value, data_type) - - # Input data (without default values) should contain all required fields. - error_prefix = f"Input data missing or mismatch for `{node_block.name}`:" - if missing_links := schema.get_missing_links(data, node.input_links): - return None, f"{error_prefix} unpopulated links {missing_links}" - - # Merge input data with default values and resolve dynamic dict/list/object pins. - input_default = schema.get_input_defaults(node.input_default) - data = {**input_default, **data} - if resolve_input: - data = merge_execution_input(data) - - # Input data post-merge should contain all required fields from the schema. - if missing_input := schema.get_missing_input(data): - return None, f"{error_prefix} missing input {missing_input}" - - # Last validation: Validate the input values against the schema. 
- if error := schema.get_mismatch_error(data): - error_message = f"{error_prefix} {error}" - logger.error(error_message) - return None, error_message - - return data, node_block.name - - class Executor: """ This class contains event handlers for the process pool executor events. @@ -633,7 +575,13 @@ class Executor: exec_meta = cls.db_client.update_graph_execution_start_time( graph_exec.graph_exec_id ) - cls.db_client.send_execution_update(exec_meta) + if exec_meta is None: + logger.warning( + f"Skipping graph execution {graph_exec.graph_exec_id}: execution not found or not currently in the QUEUED state." + ) + return + + send_execution_update(exec_meta) timing_info, (exec_stats, status, error) = cls._on_graph_execution( graph_exec, cancel, log_metadata ) @@ -646,7 +594,7 @@ status=status, stats=exec_stats, ): - cls.db_client.send_execution_update(graph_exec_result) + send_execution_update(graph_exec_result) cls._handle_agent_run_notif(graph_exec, exec_stats) @@ -759,7 +707,7 @@ status=execution_status, stats=execution_stats, ): - cls.db_client.send_execution_update(_graph_exec) + send_execution_update(_graph_exec) else: logger.error( "Callback for " @@ -810,7 +758,7 @@ exec_update = cls.db_client.update_node_execution_status( node_exec_id, execution_status ) - cls.db_client.send_execution_update(exec_update) + send_execution_update(exec_update) cls._handle_low_balance_notif( graph_exec.user_id, @@ -927,22 +875,25 @@ ) -class ExecutionManager(AppService): +class ExecutionManager(AppProcess): def __init__(self): super().__init__() self.pool_size = settings.config.num_graph_workers - self.queue = ExecutionQueue[GraphExecutionEntry]() + self.running = True self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {} @classmethod def get_port(cls) -> int: return settings.config.execution_manager_port - def run_service(self): - from backend.integrations.credentials_store import IntegrationCredentialsStore - - self.credentials_store = IntegrationCredentialsStore() + def run(self): + # Honor the shutdown flag cleared by cleanup(); a bare `while True` would restart the loop even after shutdown. + while self.running: + try: + self._run() + except Exception: + logger.exception(f"[{self.service_name}] Error in graph executor loop") + def _run(self): logger.info(f"[{self.service_name}] ⏳ Spawn max-{self.pool_size} workers...") self.executor = ProcessPoolExecutor( max_workers=self.pool_size, @@ -952,25 +903,103 @@ logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...") redis.connect() - sync_manager = multiprocessing.Manager() + logger.info(f"[{self.service_name}] Ready to consume messages...") while True: - graph_exec_data = self.queue.get() - graph_exec_id = graph_exec_data.graph_exec_id - logger.debug( - f"[ExecutionManager] Dispatching graph execution {graph_exec_id}" + channel = get_execution_queue().get_channel() + + # cancel graph execution requests + method_frame, _, body = channel.basic_get( + queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME, + auto_ack=True, ) - cancel_event = sync_manager.Event() - future = self.executor.submit( - Executor.on_graph_execution, graph_exec_data, cancel_event + if method_frame: + self._handle_cancel_message(body) + + # start graph execution requests + method_frame, _, body = channel.basic_get( + queue=GRAPH_EXECUTION_QUEUE_NAME, + auto_ack=False, ) - self.active_graph_runs[graph_exec_id] = (future, cancel_event) - future.add_done_callback( - lambda _: self.active_graph_runs.pop(graph_exec_id, None) + if method_frame: + self._handle_run_message(channel,
method_frame, body) + else: + time.sleep(0.2) + + def _handle_cancel_message(self, body: bytes): + try: + request = CancelExecutionEvent.model_validate_json(body) + graph_exec_id = request.graph_exec_id + if not graph_exec_id: + logger.warning( + f"[{self.service_name}] Cancel message missing 'graph_exec_id'" + ) + return + if graph_exec_id not in self.active_graph_runs: + logger.debug( + f"[{self.service_name}] Cancel received for {graph_exec_id} but not active." + ) + return + + _, cancel_event = self.active_graph_runs[graph_exec_id] + logger.info(f"[{self.service_name}] Received cancel for {graph_exec_id}") + if not cancel_event.is_set(): + cancel_event.set() + else: + logger.debug( + f"[{self.service_name}] Cancel already set for {graph_exec_id}" + ) + + except Exception as e: + logger.exception(f"Error handling cancel message: {e}") + + def _handle_run_message( + self, channel: BlockingChannel, method_frame: Basic.GetOk, body: bytes + ): + delivery_tag = method_frame.delivery_tag + try: + graph_exec_entry = GraphExecutionEntry.model_validate_json(body) + except Exception as e: + logger.error(f"[{self.service_name}] Could not parse run message: {e}") + channel.basic_nack(delivery_tag, requeue=False) + return + + graph_exec_id = graph_exec_entry.graph_exec_id + logger.info( + f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}" + ) + if graph_exec_id in self.active_graph_runs: + logger.warning( + f"[{self.service_name}] Graph {graph_exec_id} already running; rejecting duplicate run." ) + channel.basic_nack(delivery_tag, requeue=False) + return + + cancel_event = multiprocessing.Manager().Event() + future = self.executor.submit( + Executor.on_graph_execution, graph_exec_entry, cancel_event + ) + self.active_graph_runs[graph_exec_id] = (future, cancel_event) + + def _on_run_done(f: Future): + logger.info(f"[{self.service_name}] Run completed for {graph_exec_id}") + try: + channel.basic_ack(delivery_tag) + self.active_graph_runs.pop(graph_exec_id, None) + if f.exception(): + logger.error( + f"[{self.service_name}] Execution for {graph_exec_id} failed: {f.exception()}" + ) + except Exception as e: + logger.error(f"[{self.service_name}] Error acknowledging message: {e}") + + future.add_done_callback(_on_run_done) def cleanup(self): super().cleanup() + logger.info(f"[{self.service_name}] ⏳ Shutting down service loop...") + self.running = False + logger.info(f"[{self.service_name}] ⏳ Shutting down graph executor pool...") self.executor.shutdown(cancel_futures=True) @@ -981,175 +1010,6 @@ class ExecutionManager(AppService): def db_client(self) -> "DatabaseManager": return get_db_client() - @expose - def add_execution( - self, - graph_id: str, - data: BlockInput, - user_id: str, - graph_version: Optional[int] = None, - preset_id: str | None = None, - ) -> GraphExecutionEntry: - graph: GraphModel | None = self.db_client.get_graph( - graph_id=graph_id, user_id=user_id, version=graph_version - ) - if not graph: - raise ValueError(f"Graph #{graph_id} not found.") - - graph.validate_graph(for_run=True) - self._validate_node_input_credentials(graph, user_id) - - nodes_input = [] - for node in graph.starting_nodes: - input_data = {} - block = node.block - - # Note block should never be executed. - if block.block_type == BlockType.NOTE: - continue - - # Extract request input data, and assign it to the input pin. 
- if block.block_type == BlockType.INPUT: - input_name = node.input_default.get("name") - if input_name and input_name in data: - input_data = {"value": data[input_name]} - - # Extract webhook payload, and assign it to the input pin - webhook_payload_key = f"webhook_{node.webhook_id}_payload" - if ( - block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL) - and node.webhook_id - ): - if webhook_payload_key not in data: - raise ValueError( - f"Node {block.name} #{node.id} webhook payload is missing" - ) - input_data = {"payload": data[webhook_payload_key]} - - input_data, error = validate_exec(node, input_data) - if input_data is None: - raise ValueError(error) - else: - nodes_input.append((node.id, input_data)) - - if not nodes_input: - raise ValueError( - "No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes." - ) - - graph_exec = self.db_client.create_graph_execution( - graph_id=graph_id, - graph_version=graph.version, - nodes_input=nodes_input, - user_id=user_id, - preset_id=preset_id, - ) - self.db_client.send_execution_update(graph_exec) - - graph_exec_entry = GraphExecutionEntry( - user_id=user_id, - graph_id=graph_id, - graph_version=graph_version or 0, - graph_exec_id=graph_exec.id, - start_node_execs=[ - NodeExecutionEntry( - user_id=user_id, - graph_exec_id=node_exec.graph_exec_id, - graph_id=node_exec.graph_id, - node_exec_id=node_exec.node_exec_id, - node_id=node_exec.node_id, - block_id=node_exec.block_id, - data=node_exec.input_data, - ) - for node_exec in graph_exec.node_executions - ], - ) - self.queue.add(graph_exec_entry) - - return graph_exec_entry - - @expose - def cancel_execution(self, graph_exec_id: str) -> None: - """ - Mechanism: - 1. Set the cancel event - 2. Graph executor's cancel handler thread detects the event, terminates workers, - reinitializes worker pool, and returns. - 3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`. - """ - if graph_exec_id not in self.active_graph_runs: - logger.warning( - f"Graph execution #{graph_exec_id} not active/running: " - "possibly already completed/cancelled." 
- ) - else: - future, cancel_event = self.active_graph_runs[graph_exec_id] - if not cancel_event.is_set(): - cancel_event.set() - future.result() - - # Update the status of the graph & node executions - self.db_client.update_graph_execution_stats( - graph_exec_id, - ExecutionStatus.TERMINATED, - ) - node_execs = self.db_client.get_node_execution_results( - graph_exec_id=graph_exec_id, - statuses=[ - ExecutionStatus.QUEUED, - ExecutionStatus.RUNNING, - ExecutionStatus.INCOMPLETE, - ], - ) - self.db_client.update_node_execution_status_batch( - [node_exec.node_exec_id for node_exec in node_execs], - ExecutionStatus.TERMINATED, - ) - for node_exec in node_execs: - node_exec.status = ExecutionStatus.TERMINATED - self.db_client.send_execution_update(node_exec) - - def _validate_node_input_credentials(self, graph: GraphModel, user_id: str): - """Checks all credentials for all nodes of the graph""" - - for node in graph.nodes: - block = node.block - - # Find any fields of type CredentialsMetaInput - credentials_fields = cast( - type[BlockSchema], block.input_schema - ).get_credentials_fields() - if not credentials_fields: - continue - - for field_name, credentials_meta_type in credentials_fields.items(): - credentials_meta = credentials_meta_type.model_validate( - node.input_default[field_name] - ) - # Fetch the corresponding Credentials and perform sanity checks - credentials = self.credentials_store.get_creds_by_id( - user_id, credentials_meta.id - ) - if not credentials: - raise ValueError( - f"Unknown credentials #{credentials_meta.id} " - f"for node #{node.id} input '{field_name}'" - ) - if ( - credentials.provider != credentials_meta.provider - or credentials.type != credentials_meta.type - ): - logger.warning( - f"Invalid credentials #{credentials.id} for node #{node.id}: " - "type/provider mismatch: " - f"{credentials_meta.type}<>{credentials.type};" - f"{credentials_meta.provider}<>{credentials.provider}" - ) - raise ValueError( - f"Invalid credentials #{credentials.id} for node #{node.id}: " - "type/provider mismatch" - ) - # ------- UTILITIES ------- # @@ -1168,6 +1028,10 @@ def get_notification_service() -> "NotificationManager": return get_service_client(NotificationManager) +def send_execution_update(entry: GraphExecution | NodeExecutionResult): + return get_execution_event_bus().publish(entry) + + @contextmanager def synchronized(key: str, timeout: int = 60): lock: RedisLock = redis.get_redis().lock(f"lock:{key}", timeout=timeout) diff --git a/autogpt_platform/backend/backend/executor/scheduler.py b/autogpt_platform/backend/backend/executor/scheduler.py index 3b6a74c360..09a10a56a2 100644 --- a/autogpt_platform/backend/backend/executor/scheduler.py +++ b/autogpt_platform/backend/backend/executor/scheduler.py @@ -16,7 +16,7 @@ from pydantic import BaseModel from sqlalchemy import MetaData, create_engine from backend.data.block import BlockInput -from backend.executor.manager import ExecutionManager +from backend.executor import utils as execution_utils from backend.notifications.notifications import NotificationManager from backend.util.service import AppService, expose, get_service_client from backend.util.settings import Config @@ -57,11 +57,6 @@ def job_listener(event): log(f"Job {event.job_id} completed successfully.") -@thread_cached -def get_execution_client() -> ExecutionManager: - return get_service_client(ExecutionManager) - - @thread_cached def get_notification_client(): from backend.notifications import NotificationManager @@ -73,7 +68,7 @@ def execute_graph(**kwargs): 
args = ExecutionJobArgs(**kwargs) try: log(f"Executing recurring job for graph #{args.graph_id}") - get_execution_client().add_execution( + execution_utils.add_graph_execution( graph_id=args.graph_id, data=args.input_data, user_id=args.user_id, @@ -164,11 +159,6 @@ class Scheduler(AppService): def db_pool_size(cls) -> int: return config.scheduler_db_pool_size - @property - @thread_cached - def execution_client(self) -> ExecutionManager: - return get_service_client(ExecutionManager) - @property @thread_cached def notification_client(self) -> NotificationManager: diff --git a/autogpt_platform/backend/backend/executor/utils.py b/autogpt_platform/backend/backend/executor/utils.py index 9952510690..0baee372b1 100644 --- a/autogpt_platform/backend/backend/executor/utils.py +++ b/autogpt_platform/backend/backend/executor/utils.py @@ -1,11 +1,70 @@ +import logging +from typing import TYPE_CHECKING, Any, cast + +from autogpt_libs.utils.cache import thread_cached from pydantic import BaseModel -from backend.data.block import Block, BlockInput +from backend.data.block import ( + Block, + BlockData, + BlockInput, + BlockSchema, + BlockType, + get_block, +) from backend.data.block_cost_config import BLOCK_COSTS from backend.data.cost import BlockCostType +from backend.data.execution import GraphExecutionEntry, RedisExecutionEventBus +from backend.data.graph import GraphModel, Node +from backend.data.rabbitmq import ( + Exchange, + ExchangeType, + Queue, + RabbitMQConfig, + SyncRabbitMQ, +) +from backend.util.mock import MockObject +from backend.util.service import get_service_client from backend.util.settings import Config +from backend.util.type import convert + +if TYPE_CHECKING: + from backend.executor import DatabaseManager + from backend.integrations.credentials_store import IntegrationCredentialsStore config = Config() +logger = logging.getLogger(__name__) + +# ============ Resource Helpers ============ # + + +@thread_cached +def get_execution_event_bus() -> RedisExecutionEventBus: + return RedisExecutionEventBus() + + +@thread_cached +def get_execution_queue() -> SyncRabbitMQ: + client = SyncRabbitMQ(create_execution_queue_config()) + client.connect() + return client + + +@thread_cached +def get_integration_credentials_store() -> "IntegrationCredentialsStore": + from backend.integrations.credentials_store import IntegrationCredentialsStore + + return IntegrationCredentialsStore() + + +@thread_cached +def get_db_client() -> "DatabaseManager": + from backend.executor import DatabaseManager + + return get_service_client(DatabaseManager) + + +# ============ Execution Cost Helpers ============ # class UsageTransactionMetadata(BaseModel): @@ -95,3 +154,398 @@ def _is_cost_filter_match(cost_filter: BlockInput, input_data: BlockInput) -> bo or (input_data.get(k) and _is_cost_filter_match(v, input_data[k])) for k, v in cost_filter.items() ) + + +# ============ Execution Input Helpers ============ # + +LIST_SPLIT = "_$_" +DICT_SPLIT = "_#_" +OBJC_SPLIT = "_@_" + + +def parse_execution_output(output: BlockData, name: str) -> Any | None: + """ + Extracts partial output data by name from a given BlockData. + + The function supports extracting data from lists, dictionaries, and objects + using specific naming conventions: + - For lists: _$_ + - For dictionaries: _#_ + - For objects: _@_ + + Args: + output (BlockData): A tuple containing the output name and data. + name (str): The name used to extract specific data from the output. + + Returns: + Any | None: The extracted data if found, otherwise None. 
+ + Examples: + >>> output = ("result", [10, 20, 30]) + >>> parse_execution_output(output, "result_$_1") + 20 + + >>> output = ("config", {"key1": "value1", "key2": "value2"}) + >>> parse_execution_output(output, "config_#_key1") + 'value1' + + >>> class Sample: + ... attr1 = "value1" + ... attr2 = "value2" + >>> output = ("object", Sample()) + >>> parse_execution_output(output, "object_@_attr1") + 'value1' + """ + output_name, output_data = output + + if name == output_name: + return output_data + + if name.startswith(f"{output_name}{LIST_SPLIT}"): + index = int(name.split(LIST_SPLIT)[1]) + if not isinstance(output_data, list) or len(output_data) <= index: + return None + return output_data[int(name.split(LIST_SPLIT)[1])] + + if name.startswith(f"{output_name}{DICT_SPLIT}"): + index = name.split(DICT_SPLIT)[1] + if not isinstance(output_data, dict) or index not in output_data: + return None + return output_data[index] + + if name.startswith(f"{output_name}{OBJC_SPLIT}"): + index = name.split(OBJC_SPLIT)[1] + if isinstance(output_data, object) and hasattr(output_data, index): + return getattr(output_data, index) + return None + + return None + + +def validate_exec( + node: Node, + data: BlockInput, + resolve_input: bool = True, +) -> tuple[BlockInput | None, str]: + """ + Validate the input data for a node execution. + + Args: + node: The node to execute. + data: The input data for the node execution. + resolve_input: Whether to resolve dynamic pins into dict/list/object. + + Returns: + A tuple of the validated data and the block name. + If the data is invalid, the first element will be None, and the second element + will be an error message. + If the data is valid, the first element will be the resolved input data, and + the second element will be the block name. + """ + node_block: Block | None = get_block(node.block_id) + if not node_block: + return None, f"Block for {node.block_id} not found." + schema = node_block.input_schema + + # Convert non-matching data types to the expected input schema. + for name, data_type in schema.__annotations__.items(): + if (value := data.get(name)) and (type(value) is not data_type): + data[name] = convert(value, data_type) + + # Input data (without default values) should contain all required fields. + error_prefix = f"Input data missing or mismatch for `{node_block.name}`:" + if missing_links := schema.get_missing_links(data, node.input_links): + return None, f"{error_prefix} unpopulated links {missing_links}" + + # Merge input data with default values and resolve dynamic dict/list/object pins. + input_default = schema.get_input_defaults(node.input_default) + data = {**input_default, **data} + if resolve_input: + data = merge_execution_input(data) + + # Input data post-merge should contain all required fields from the schema. + if missing_input := schema.get_missing_input(data): + return None, f"{error_prefix} missing input {missing_input}" + + # Last validation: Validate the input values against the schema. + if error := schema.get_mismatch_error(data): + error_message = f"{error_prefix} {error}" + logger.error(error_message) + return None, error_message + + return data, node_block.name + + +def merge_execution_input(data: BlockInput) -> BlockInput: + """ + Merges dynamic input pins into a single list, dictionary, or object based on naming patterns. + + This function processes input keys that follow specific patterns to merge them into a unified structure: + - `_$_` for list inputs. + - `_#_` for dictionary inputs. + - `_@_` for object inputs. 
+ + Args: + data (BlockInput): A dictionary containing input keys and their corresponding values. + + Returns: + BlockInput: A dictionary with merged inputs. + + Raises: + ValueError: If a list index is not an integer. + + Examples: + >>> data = { + ... "list_$_0": "a", + ... "list_$_1": "b", + ... "dict_#_key1": "value1", + ... "dict_#_key2": "value2", + ... "object_@_attr1": "value1", + ... "object_@_attr2": "value2" + ... } + >>> merge_execution_input(data) + { + "list": ["a", "b"], + "dict": {"key1": "value1", "key2": "value2"}, + "object": + } + """ + + # Merge all input with _$_ into a single list. + items = list(data.items()) + + for key, value in items: + if LIST_SPLIT not in key: + continue + name, index = key.split(LIST_SPLIT) + if not index.isdigit(): + raise ValueError(f"Invalid key: {key}, #{index} index must be an integer.") + + data[name] = data.get(name, []) + if int(index) >= len(data[name]): + # Pad list with empty string on missing indices. + data[name].extend([""] * (int(index) - len(data[name]) + 1)) + data[name][int(index)] = value + + # Merge all input with _#_ into a single dict. + for key, value in items: + if DICT_SPLIT not in key: + continue + name, index = key.split(DICT_SPLIT) + data[name] = data.get(name, {}) + data[name][index] = value + + # Merge all input with _@_ into a single object. + for key, value in items: + if OBJC_SPLIT not in key: + continue + name, index = key.split(OBJC_SPLIT) + if name not in data or not isinstance(data[name], object): + data[name] = MockObject() + setattr(data[name], index, value) + + return data + + +def _validate_node_input_credentials(graph: GraphModel, user_id: str): + """Checks all credentials for all nodes of the graph""" + + for node in graph.nodes: + block = node.block + + # Find any fields of type CredentialsMetaInput + credentials_fields = cast( + type[BlockSchema], block.input_schema + ).get_credentials_fields() + if not credentials_fields: + continue + + for field_name, credentials_meta_type in credentials_fields.items(): + credentials_meta = credentials_meta_type.model_validate( + node.input_default[field_name] + ) + # Fetch the corresponding Credentials and perform sanity checks + credentials = get_integration_credentials_store().get_creds_by_id( + user_id, credentials_meta.id + ) + if not credentials: + raise ValueError( + f"Unknown credentials #{credentials_meta.id} " + f"for node #{node.id} input '{field_name}'" + ) + if ( + credentials.provider != credentials_meta.provider + or credentials.type != credentials_meta.type + ): + logger.warning( + f"Invalid credentials #{credentials.id} for node #{node.id}: " + "type/provider mismatch: " + f"{credentials_meta.type}<>{credentials.type};" + f"{credentials_meta.provider}<>{credentials.provider}" + ) + raise ValueError( + f"Invalid credentials #{credentials.id} for node #{node.id}: " + "type/provider mismatch" + ) + + +def construct_node_execution_input( + graph: GraphModel, + user_id: str, + data: BlockInput, +) -> list[tuple[str, BlockInput]]: + """ + Validates and prepares the input data for executing a graph. + This function checks the graph for starting nodes, validates the input data + against the schema, and resolves dynamic input pins into a single list, + dictionary, or object. + + Args: + graph (GraphModel): The graph model to execute. + user_id (str): The ID of the user executing the graph. + data (BlockInput): The input data for the graph execution. 
+ + Returns: + list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and + the corresponding input data for that node. + """ + graph.validate_graph(for_run=True) + _validate_node_input_credentials(graph, user_id) + + nodes_input = [] + for node in graph.starting_nodes: + input_data = {} + block = node.block + + # Note block should never be executed. + if block.block_type == BlockType.NOTE: + continue + + # Extract request input data, and assign it to the input pin. + if block.block_type == BlockType.INPUT: + input_name = node.input_default.get("name") + if input_name and input_name in data: + input_data = {"value": data[input_name]} + + # Extract webhook payload, and assign it to the input pin + webhook_payload_key = f"webhook_{node.webhook_id}_payload" + if ( + block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL) + and node.webhook_id + ): + if webhook_payload_key not in data: + raise ValueError( + f"Node {block.name} #{node.id} webhook payload is missing" + ) + input_data = {"payload": data[webhook_payload_key]} + + input_data, error = validate_exec(node, input_data) + if input_data is None: + raise ValueError(error) + else: + nodes_input.append((node.id, input_data)) + + if not nodes_input: + raise ValueError( + "No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes." + ) + + return nodes_input + + +# ============ Execution Queue Helpers ============ # + + +class CancelExecutionEvent(BaseModel): + graph_exec_id: str + + +GRAPH_EXECUTION_EXCHANGE = Exchange( + name="graph_execution", + type=ExchangeType.DIRECT, + durable=True, + auto_delete=False, +) +GRAPH_EXECUTION_QUEUE_NAME = "graph_execution_queue" +GRAPH_EXECUTION_ROUTING_KEY = "graph_execution.run" + +GRAPH_EXECUTION_CANCEL_EXCHANGE = Exchange( + name="graph_execution_cancel", + type=ExchangeType.FANOUT, + durable=True, + auto_delete=True, +) +GRAPH_EXECUTION_CANCEL_QUEUE_NAME = "graph_execution_cancel_queue" + + +def create_execution_queue_config() -> RabbitMQConfig: + """ + Define two exchanges and queues: + - 'graph_execution' (DIRECT) for run tasks. + - 'graph_execution_cancel' (FANOUT) for cancel requests. + """ + run_queue = Queue( + name=GRAPH_EXECUTION_QUEUE_NAME, + exchange=GRAPH_EXECUTION_EXCHANGE, + routing_key=GRAPH_EXECUTION_ROUTING_KEY, + durable=True, + auto_delete=False, + ) + cancel_queue = Queue( + name=GRAPH_EXECUTION_CANCEL_QUEUE_NAME, + exchange=GRAPH_EXECUTION_CANCEL_EXCHANGE, + routing_key="", # not used for FANOUT + durable=True, + auto_delete=False, + ) + return RabbitMQConfig( + vhost="/", + exchanges=[GRAPH_EXECUTION_EXCHANGE, GRAPH_EXECUTION_CANCEL_EXCHANGE], + queues=[run_queue, cancel_queue], + ) + + +def add_graph_execution( + graph_id: str, + data: BlockInput, + user_id: str, + graph_version: int | None = None, + preset_id: str | None = None, +) -> GraphExecutionEntry: + """ + Adds a graph execution to the queue and returns the execution entry. + + Args: + graph_id (str): The ID of the graph to execute. + data (BlockInput): The input data for the graph execution. + user_id (str): The ID of the user executing the graph. + graph_version (int | None): The version of the graph to execute. Defaults to None. + preset_id (str | None): The ID of the preset to use. Defaults to None. + Returns: + GraphExecutionEntry: The entry for the graph execution. + Raises: + ValueError: If the graph is not found or if there are validation errors. 
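+ + Example (illustrative sketch only; the IDs and input name below are placeholders): + >>> entry = add_graph_execution( + ... graph_id="some-graph-id", + ... data={"input_name": "value"}, + ... user_id="some-user-id", + ... ) + >>> entry.graph_exec_id # ID of the newly queued graph execution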
+ """ + graph: GraphModel | None = get_db_client().get_graph( + graph_id=graph_id, user_id=user_id, version=graph_version + ) + if not graph: + raise ValueError(f"Graph #{graph_id} not found.") + + graph_exec = get_db_client().create_graph_execution( + graph_id=graph_id, + graph_version=graph.version, + nodes_input=construct_node_execution_input(graph, user_id, data), + user_id=user_id, + preset_id=preset_id, + ) + get_execution_event_bus().publish(graph_exec) + + graph_exec_entry = graph_exec.to_graph_execution_entry() + get_execution_queue().publish_message( + routing_key=GRAPH_EXECUTION_ROUTING_KEY, + message=graph_exec_entry.model_dump_json(), + exchange=GRAPH_EXECUTION_EXCHANGE, + ) + + return graph_exec_entry diff --git a/autogpt_platform/backend/backend/server/external/routes/v1.py b/autogpt_platform/backend/backend/server/external/routes/v1.py index 68760b1a54..6da2a13123 100644 --- a/autogpt_platform/backend/backend/server/external/routes/v1.py +++ b/autogpt_platform/backend/backend/server/external/routes/v1.py @@ -2,7 +2,6 @@ import logging from collections import defaultdict from typing import Annotated, Any, Dict, List, Optional, Sequence -from autogpt_libs.utils.cache import thread_cached from fastapi import APIRouter, Body, Depends, HTTPException from prisma.enums import AgentExecutionStatus, APIKeyPermission from typing_extensions import TypedDict @@ -13,17 +12,10 @@ from backend.data import graph as graph_db from backend.data.api_key import APIKey from backend.data.block import BlockInput, CompletedBlockOutput from backend.data.execution import NodeExecutionResult -from backend.executor import ExecutionManager from backend.server.external.middleware import require_permission -from backend.util.service import get_service_client +from backend.server.routers import v1 as internal_api_routes from backend.util.settings import Settings - -@thread_cached -def execution_manager_client() -> ExecutionManager: - return get_service_client(ExecutionManager) - - settings = Settings() logger = logging.getLogger(__name__) @@ -98,18 +90,18 @@ def execute_graph_block( path="/graphs/{graph_id}/execute/{graph_version}", tags=["graphs"], ) -def execute_graph( +async def execute_graph( graph_id: str, graph_version: int, node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)], api_key: APIKey = Depends(require_permission(APIKeyPermission.EXECUTE_GRAPH)), ) -> dict[str, Any]: try: - graph_exec = execution_manager_client().add_execution( - graph_id, - graph_version=graph_version, - data=node_input, + graph_exec = await internal_api_routes.execute_graph( + graph_id=graph_id, + node_input=node_input, user_id=api_key.user_id, + graph_version=graph_version, ) return {"id": graph_exec.graph_exec_id} except Exception as e: diff --git a/autogpt_platform/backend/backend/server/integrations/router.py b/autogpt_platform/backend/backend/server/integrations/router.py index 00e3cedd7d..8efe7673fd 100644 --- a/autogpt_platform/backend/backend/server/integrations/router.py +++ b/autogpt_platform/backend/backend/server/integrations/router.py @@ -1,3 +1,4 @@ +import asyncio import logging from typing import TYPE_CHECKING, Annotated, Literal @@ -14,13 +15,12 @@ from backend.data.integrations import ( wait_for_webhook_event, ) from backend.data.model import Credentials, CredentialsType, OAuth2Credentials -from backend.executor.manager import ExecutionManager from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.integrations.oauth import HANDLERS_BY_NAME 
from backend.integrations.providers import ProviderName
 from backend.integrations.webhooks import get_webhook_manager
+from backend.server.routers import v1 as internal_api_routes
 from backend.util.exceptions import NeedConfirmation, NotFoundError
-from backend.util.service import get_service_client
 from backend.util.settings import Settings
 
 if TYPE_CHECKING:
@@ -309,19 +309,22 @@ async def webhook_ingress_generic(
     if not webhook.attached_nodes:
         return
 
-    executor = get_service_client(ExecutionManager)
+    executions = []
     for node in webhook.attached_nodes:
         logger.debug(f"Webhook-attached node: {node}")
         if not node.is_triggered_by_event_type(event_type):
             logger.debug(f"Node #{node.id} doesn't trigger on event {event_type}")
             continue
         logger.debug(f"Executing graph #{node.graph_id} node #{node.id}")
-        executor.add_execution(
-            graph_id=node.graph_id,
-            graph_version=node.graph_version,
-            data={f"webhook_{webhook_id}_payload": payload},
-            user_id=webhook.user_id,
+        executions.append(
+            internal_api_routes.execute_graph(
+                graph_id=node.graph_id,
+                graph_version=node.graph_version,
+                node_input={f"webhook_{webhook_id}_payload": payload},
+                user_id=webhook.user_id,
+            )
         )
+    await asyncio.gather(*executions)
 
 
 @router.post("/webhooks/{webhook_id}/ping")
diff --git a/autogpt_platform/backend/backend/server/rest_api.py b/autogpt_platform/backend/backend/server/rest_api.py
index 4c211420ed..4ca11d3703 100644
--- a/autogpt_platform/backend/backend/server/rest_api.py
+++ b/autogpt_platform/backend/backend/server/rest_api.py
@@ -17,7 +17,6 @@ import backend.data.block
 import backend.data.db
 import backend.data.graph
 import backend.data.user
-import backend.server.integrations.router
 import backend.server.routers.postmark.postmark
 import backend.server.routers.v1
 import backend.server.v2.admin.store_admin_routes
@@ -156,7 +155,7 @@ class AgentServer(backend.util.service.AppProcess):
         graph_version: Optional[int] = None,
         node_input: Optional[dict[str, Any]] = None,
     ):
-        return backend.server.routers.v1.execute_graph(
+        return await backend.server.routers.v1.execute_graph(
             user_id=user_id,
             graph_id=graph_id,
             graph_version=graph_version,
@@ -275,7 +274,9 @@ class AgentServer(backend.util.service.AppProcess):
         provider: ProviderName,
         credentials: Credentials,
     ) -> Credentials:
-        return backend.server.integrations.router.create_credentials(
+        from backend.server.integrations.router import create_credentials
+
+        return create_credentials(
             user_id=user_id, provider=provider, credentials=credentials
         )
 
diff --git a/autogpt_platform/backend/backend/server/routers/v1.py b/autogpt_platform/backend/backend/server/routers/v1.py
index 7dc0efefdd..2bcc522396 100644
--- a/autogpt_platform/backend/backend/server/routers/v1.py
+++ b/autogpt_platform/backend/backend/server/routers/v1.py
@@ -2,7 +2,7 @@ import asyncio
 import logging
 from collections import defaultdict
 from datetime import datetime
-from typing import TYPE_CHECKING, Annotated, Any, Sequence
+from typing import TYPE_CHECKING, Annotated, Any, Awaitable, Sequence
 
 import pydantic
 import stripe
@@ -13,7 +13,6 @@ from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response
 from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
 from typing_extensions import Optional, TypedDict
 
-import backend.data.block
 import backend.server.integrations.router
 import backend.server.routers.analytics
 import backend.server.v2.library.db as library_db
@@ -31,7 +30,7 @@ from backend.data.api_key import (
     suspend_api_key,
     update_api_key_permissions,
 )
-from backend.data.block import BlockInput, CompletedBlockOutput
+from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
 from backend.data.credit import (
     AutoTopUpConfig,
     RefundRequest,
@@ -41,6 +40,7 @@ from backend.data.credit import (
     get_user_credit_model,
     set_auto_top_up,
 )
+from backend.data.execution import AsyncRedisExecutionEventBus
 from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
 from backend.data.onboarding import (
     UserOnboardingUpdate,
@@ -49,13 +49,16 @@ from backend.data.onboarding import (
     onboarding_enabled,
     update_user_onboarding,
 )
+from backend.data.rabbitmq import AsyncRabbitMQ
 from backend.data.user import (
     get_or_create_user,
     get_user_notification_preference,
     update_user_email,
     update_user_notification_preference,
 )
-from backend.executor import ExecutionManager, Scheduler, scheduler
+from backend.executor import Scheduler, scheduler
+from backend.executor import utils as execution_utils
+from backend.executor.utils import create_execution_queue_config
 from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.integrations.webhooks.graph_lifecycle_hooks import (
     on_graph_activate,
@@ -79,13 +82,23 @@ if TYPE_CHECKING:
 
 @thread_cached
-def execution_manager_client() -> ExecutionManager:
-    return get_service_client(ExecutionManager)
+def execution_scheduler_client() -> Scheduler:
+    return get_service_client(Scheduler)
 
 
 @thread_cached
-def execution_scheduler_client() -> Scheduler:
-    return get_service_client(Scheduler)
+def execution_queue_client() -> Awaitable[AsyncRabbitMQ]:
+    async def f() -> AsyncRabbitMQ:
+        client = AsyncRabbitMQ(create_execution_queue_config())
+        await client.connect()
+        return client
+
+    # Cache an asyncio.Task rather than the bare coroutine: a coroutine object
+    # can only be awaited once, while the cached value here is awaited on
+    # every call.
+    return asyncio.ensure_future(f())
+
+
+@thread_cached
+def execution_event_bus() -> AsyncRedisExecutionEventBus:
+    return AsyncRedisExecutionEventBus()
 
 
 settings = Settings()
@@ -206,7 +219,7 @@ async def is_onboarding_enabled():
 @v1_router.get(path="/blocks", tags=["blocks"], dependencies=[Depends(auth_middleware)])
 def get_graph_blocks() -> Sequence[dict[Any, Any]]:
-    blocks = [block() for block in backend.data.block.get_blocks().values()]
+    blocks = [block() for block in get_blocks().values()]
     costs = get_block_costs()
     return [
         {**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks if not b.disabled
@@ -219,7 +232,7 @@ def get_graph_blocks() -> Sequence[dict[Any, Any]]:
     dependencies=[Depends(auth_middleware)],
 )
 def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlockOutput:
-    obj = backend.data.block.get_block(block_id)
+    obj = get_block(block_id)
     if not obj:
         raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
@@ -308,7 +321,7 @@ async def configure_user_auto_top_up(
     dependencies=[Depends(auth_middleware)],
 )
 async def get_user_auto_top_up(
-    user_id: Annotated[str, Depends(get_user_id)]
+    user_id: Annotated[str, Depends(get_user_id)],
 ) -> AutoTopUpConfig:
     return await get_auto_top_up(user_id)
@@ -375,7 +388,7 @@ async def get_credit_history(
 @v1_router.get(path="/credits/refunds", dependencies=[Depends(auth_middleware)])
 async def get_refund_requests(
-    user_id: Annotated[str, Depends(get_user_id)]
+    user_id: Annotated[str, Depends(get_user_id)],
 ) -> list[RefundRequest]:
     return await _user_credit_model.get_refund_requests(user_id)
@@ -391,7 +404,7 @@ class DeleteGraphResponse(TypedDict):
 @v1_router.get(path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)])
 async def get_graphs(
-    user_id: Annotated[str,
Depends(get_user_id)] + user_id: Annotated[str, Depends(get_user_id)], ) -> Sequence[graph_db.GraphModel]: return await graph_db.get_graphs(filter_by="active", user_id=user_id) @@ -580,16 +593,35 @@ async def set_graph_active_version( tags=["graphs"], dependencies=[Depends(auth_middleware)], ) -def execute_graph( +async def execute_graph( graph_id: str, node_input: Annotated[dict[str, Any], Body(..., default_factory=dict)], user_id: Annotated[str, Depends(get_user_id)], graph_version: Optional[int] = None, + preset_id: Optional[str] = None, ) -> ExecuteGraphResponse: - graph_exec = execution_manager_client().add_execution( - graph_id, node_input, user_id=user_id, graph_version=graph_version + graph: graph_db.GraphModel | None = await graph_db.get_graph( + graph_id=graph_id, user_id=user_id, version=graph_version ) - return ExecuteGraphResponse(graph_exec_id=graph_exec.graph_exec_id) + if not graph: + raise ValueError(f"Graph #{graph_id} not found.") + + graph_exec = await execution_db.create_graph_execution( + graph_id=graph_id, + graph_version=graph.version, + nodes_input=execution_utils.construct_node_execution_input( + graph, user_id, node_input + ), + user_id=user_id, + preset_id=preset_id, + ) + execution_utils.get_execution_event_bus().publish(graph_exec) + execution_utils.get_execution_queue().publish_message( + routing_key=execution_utils.GRAPH_EXECUTION_ROUTING_KEY, + message=graph_exec.to_graph_execution_entry().model_dump_json(), + exchange=execution_utils.GRAPH_EXECUTION_EXCHANGE, + ) + return ExecuteGraphResponse(graph_exec_id=graph_exec.id) @v1_router.post( @@ -605,9 +637,7 @@ async def stop_graph_run( ): raise HTTPException(404, detail=f"Agent execution #{graph_exec_id} not found") - await asyncio.to_thread( - lambda: execution_manager_client().cancel_execution(graph_exec_id) - ) + await _cancel_execution(graph_exec_id) # Retrieve & return canceled graph execution in its final state result = await execution_db.get_graph_execution( @@ -621,6 +651,49 @@ async def stop_graph_run( return result +async def _cancel_execution(graph_exec_id: str): + """ + Mechanism: + 1. Set the cancel event + 2. Graph executor's cancel handler thread detects the event, terminates workers, + reinitializes worker pool, and returns. + 3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`. 
+ """ + queue_client = await execution_queue_client() + await queue_client.publish_message( + routing_key="", + message=execution_utils.CancelExecutionEvent( + graph_exec_id=graph_exec_id + ).model_dump_json(), + exchange=execution_utils.GRAPH_EXECUTION_CANCEL_EXCHANGE, + ) + + # Update the status of the graph & node executions + await execution_db.update_graph_execution_stats( + graph_exec_id, + execution_db.ExecutionStatus.TERMINATED, + ) + node_execs = [ + node_exec.model_copy(update={"status": execution_db.ExecutionStatus.TERMINATED}) + for node_exec in await execution_db.get_node_execution_results( + graph_exec_id=graph_exec_id, + statuses=[ + execution_db.ExecutionStatus.QUEUED, + execution_db.ExecutionStatus.RUNNING, + execution_db.ExecutionStatus.INCOMPLETE, + ], + ) + ] + + await execution_db.update_node_execution_status_batch( + [node_exec.node_exec_id for node_exec in node_execs], + execution_db.ExecutionStatus.TERMINATED, + ) + await asyncio.gather( + *[execution_event_bus().publish(node_exec) for node_exec in node_execs] + ) + + @v1_router.get( path="/executions", tags=["graphs"], @@ -792,7 +865,7 @@ async def create_api_key( dependencies=[Depends(auth_middleware)], ) async def get_api_keys( - user_id: Annotated[str, Depends(get_user_id)] + user_id: Annotated[str, Depends(get_user_id)], ) -> list[APIKeyWithoutHash]: """List all API keys for the user""" try: diff --git a/autogpt_platform/backend/backend/server/v2/library/routes/presets.py b/autogpt_platform/backend/backend/server/v2/library/routes/presets.py index 73d94cab46..0983a2ff94 100644 --- a/autogpt_platform/backend/backend/server/v2/library/routes/presets.py +++ b/autogpt_platform/backend/backend/server/v2/library/routes/presets.py @@ -2,25 +2,16 @@ import logging from typing import Annotated, Any import autogpt_libs.auth as autogpt_auth_lib -import autogpt_libs.utils.cache from fastapi import APIRouter, Body, Depends, HTTPException, status -import backend.executor import backend.server.v2.library.db as db import backend.server.v2.library.model as models -import backend.util.service logger = logging.getLogger(__name__) router = APIRouter() -@autogpt_libs.utils.cache.thread_cached -def execution_manager_client() -> backend.executor.ExecutionManager: - """Return a cached instance of ExecutionManager client.""" - return backend.util.service.get_service_client(backend.executor.ExecutionManager) - - @router.get( "/presets", summary="List presets", @@ -216,6 +207,8 @@ async def execute_preset( HTTPException: If the preset is not found or an error occurs while executing the preset. 
""" try: + from backend.server.routers import v1 as internal_api_routes + preset = await db.get_preset(user_id, preset_id) if not preset: raise HTTPException( @@ -226,10 +219,10 @@ async def execute_preset( # Merge input overrides with preset inputs merged_node_input = preset.inputs | node_input - execution = execution_manager_client().add_execution( + execution = await internal_api_routes.execute_graph( graph_id=graph_id, + node_input=merged_node_input, graph_version=graph_version, - data=merged_node_input, user_id=user_id, preset_id=preset_id, ) diff --git a/autogpt_platform/backend/backend/server/v2/store/db.py b/autogpt_platform/backend/backend/server/v2/store/db.py index 3538b8f777..233ed05421 100644 --- a/autogpt_platform/backend/backend/server/v2/store/db.py +++ b/autogpt_platform/backend/backend/server/v2/store/db.py @@ -12,7 +12,6 @@ import backend.server.v2.store.exceptions import backend.server.v2.store.model from backend.data.graph import GraphModel, get_sub_graphs from backend.data.includes import AGENT_GRAPH_INCLUDE -from backend.util.type import typed_cast logger = logging.getLogger(__name__) @@ -960,7 +959,7 @@ async def get_my_agents( try: search_filter: prisma.types.LibraryAgentWhereInput = { "userId": user_id, - "AgentGraph": {"is": {"StoreListing": {"none": {"isDeleted": False}}}}, + "AgentGraph": {"is": {"StoreListings": {"none": {"isDeleted": False}}}}, "isArchived": False, "isDeleted": False, } @@ -1088,13 +1087,7 @@ async def review_store_submission( where={"id": store_listing_version_id}, include={ "StoreListing": True, - "AgentGraph": { - "include": typed_cast( - prisma.types.AgentGraphIncludeFromAgentGraphRecursive1, - prisma.types.AgentGraphInclude, - AGENT_GRAPH_INCLUDE, - ) - }, + "AgentGraph": {"include": AGENT_GRAPH_INCLUDE}, }, ) ) diff --git a/autogpt_platform/backend/backend/util/json.py b/autogpt_platform/backend/backend/util/json.py index 106d79b99d..d71b61a246 100644 --- a/autogpt_platform/backend/backend/util/json.py +++ b/autogpt_platform/backend/backend/util/json.py @@ -15,21 +15,25 @@ def to_dict(data) -> dict: def dumps(data) -> str: - return json.dumps(jsonable_encoder(data)) + return json.dumps(to_dict(data)) T = TypeVar("T") @overload -def loads(data: str, *args, target_type: Type[T], **kwargs) -> T: ... +def loads(data: str | bytes, *args, target_type: Type[T], **kwargs) -> T: ... @overload -def loads(data: str, *args, **kwargs) -> Any: ... +def loads(data: str | bytes, *args, **kwargs) -> Any: ... 
-def loads(data: str, *args, target_type: Type[T] | None = None, **kwargs) -> Any: +def loads( + data: str | bytes, *args, target_type: Type[T] | None = None, **kwargs +) -> Any: + if isinstance(data, bytes): + data = data.decode("utf-8") parsed = json.loads(data, *args, **kwargs) if target_type: return type_match(parsed, target_type) diff --git a/autogpt_platform/backend/backend/util/metrics.py b/autogpt_platform/backend/backend/util/metrics.py index f6a2ab5676..06b0ce776e 100644 --- a/autogpt_platform/backend/backend/util/metrics.py +++ b/autogpt_platform/backend/backend/util/metrics.py @@ -14,9 +14,7 @@ def sentry_init(): traces_sample_rate=1.0, profiles_sample_rate=1.0, environment=f"app:{Settings().config.app_env.value}-behave:{Settings().config.behave_as.value}", - _experiments={ - "enable_logs": True, - }, + _experiments={"enable_logs": True}, integrations=[ LoggingIntegration(sentry_logs_level=logging.INFO), AnthropicIntegration( diff --git a/autogpt_platform/backend/backend/util/type.py b/autogpt_platform/backend/backend/util/type.py index c47b9a7b99..b49da4f590 100644 --- a/autogpt_platform/backend/backend/util/type.py +++ b/autogpt_platform/backend/backend/util/type.py @@ -198,18 +198,6 @@ def convert(value: Any, target_type: Type[T]) -> T: raise ConversionError(f"Failed to convert {value} to {target_type}") from e -def typed(type: type[T], value: T) -> T: - """ - Add an explicit type to a value. Useful in nested statements, e.g. dict literals. - """ - return value - - -def typed_cast(to_type: type[TT], from_type: type[T], value: T) -> TT: - """Strict cast to preserve type checking abilities.""" - return cast(TT, value) - - class FormattedStringType(str): string_format: str diff --git a/autogpt_platform/backend/migrations/20250411130000_update_onboarding_step/migration.sql b/autogpt_platform/backend/migrations/20250411130000_update_onboarding_step/migration.sql new file mode 100644 index 0000000000..d4b80091e0 --- /dev/null +++ b/autogpt_platform/backend/migrations/20250411130000_update_onboarding_step/migration.sql @@ -0,0 +1,16 @@ +-- Modify the OnboardingStep enum +ALTER TYPE "OnboardingStep" ADD VALUE 'GET_RESULTS'; +ALTER TYPE "OnboardingStep" ADD VALUE 'MARKETPLACE_VISIT'; +ALTER TYPE "OnboardingStep" ADD VALUE 'MARKETPLACE_ADD_AGENT'; +ALTER TYPE "OnboardingStep" ADD VALUE 'MARKETPLACE_RUN_AGENT'; +ALTER TYPE "OnboardingStep" ADD VALUE 'BUILDER_OPEN'; +ALTER TYPE "OnboardingStep" ADD VALUE 'BUILDER_SAVE_AGENT'; +ALTER TYPE "OnboardingStep" ADD VALUE 'BUILDER_RUN_AGENT'; + +-- Modify the UserOnboarding table +ALTER TABLE "UserOnboarding" + ADD COLUMN "updatedAt" TIMESTAMP(3), + ADD COLUMN "notificationDot" BOOLEAN NOT NULL DEFAULT true, + ADD COLUMN "notified" "OnboardingStep"[] DEFAULT '{}', + ADD COLUMN "rewardedFor" "OnboardingStep"[] DEFAULT '{}', + ADD COLUMN "onboardingAgentExecutionId" TEXT diff --git a/autogpt_platform/backend/schema.prisma b/autogpt_platform/backend/schema.prisma index 339d9d2df4..71ff441e1b 100644 --- a/autogpt_platform/backend/schema.prisma +++ b/autogpt_platform/backend/schema.prisma @@ -5,7 +5,7 @@ datasource db { generator client { provider = "prisma-client-py" - recursive_type_depth = 5 + recursive_type_depth = -1 interface = "asyncio" previewFeatures = ["views"] } @@ -58,6 +58,7 @@ model User { } enum OnboardingStep { + // Introductory onboarding (Library) WELCOME USAGE_REASON INTEGRATIONS @@ -65,18 +66,32 @@ enum OnboardingStep { AGENT_NEW_RUN AGENT_INPUT CONGRATS + GET_RESULTS + // Marketplace + MARKETPLACE_VISIT + MARKETPLACE_ADD_AGENT + 
MARKETPLACE_RUN_AGENT + // Builder + BUILDER_OPEN + BUILDER_SAVE_AGENT + BUILDER_RUN_AGENT } model UserOnboarding { id String @id @default(uuid()) createdAt DateTime @default(now()) + updatedAt DateTime? @updatedAt completedSteps OnboardingStep[] @default([]) + notificationDot Boolean @default(true) + notified OnboardingStep[] @default([]) + rewardedFor OnboardingStep[] @default([]) usageReason String? integrations String[] @default([]) otherIntegrations String? selectedStoreListingVersionId String? agentInput Json? + onboardingAgentExecutionId String? userId String @unique User User @relation(fields: [userId], references: [id], onDelete: Cascade) diff --git a/autogpt_platform/backend/test/executor/test_execution_functions.py b/autogpt_platform/backend/test/executor/test_execution_functions.py index 4fce8f49a8..9cdf8430e4 100644 --- a/autogpt_platform/backend/test/executor/test_execution_functions.py +++ b/autogpt_platform/backend/test/executor/test_execution_functions.py @@ -1,4 +1,4 @@ -from backend.data.execution import merge_execution_input, parse_execution_output +from backend.executor.utils import merge_execution_input, parse_execution_output def test_parse_execution_output(): diff --git a/autogpt_platform/backend/test/test_data_creator.py b/autogpt_platform/backend/test/test_data_creator.py index 512e977765..a61c836caf 100644 --- a/autogpt_platform/backend/test/test_data_creator.py +++ b/autogpt_platform/backend/test/test_data_creator.py @@ -10,16 +10,12 @@ from prisma.types import ( AgentGraphCreateInput, AgentNodeCreateInput, AgentNodeLinkCreateInput, - AgentPresetCreateInput, AnalyticsDetailsCreateInput, AnalyticsMetricsCreateInput, APIKeyCreateInput, CreditTransactionCreateInput, - LibraryAgentCreateInput, ProfileCreateInput, - StoreListingCreateInput, StoreListingReviewCreateInput, - StoreListingVersionCreateInput, UserCreateInput, ) @@ -140,14 +136,14 @@ async def main(): for _ in range(num_presets): # Create 1 AgentPreset per user graph = random.choice(agent_graphs) preset = await db.agentpreset.create( - data=AgentPresetCreateInput( - name=faker.sentence(nb_words=3), - description=faker.text(max_nb_chars=200), - userId=user.id, - agentId=graph.id, - agentVersion=graph.version, - isActive=True, - ) + data={ + "name": faker.sentence(nb_words=3), + "description": faker.text(max_nb_chars=200), + "userId": user.id, + "agentGraphId": graph.id, + "agentGraphVersion": graph.version, + "isActive": True, + } ) agent_presets.append(preset) @@ -160,16 +156,15 @@ async def main(): graph = random.choice(agent_graphs) preset = random.choice(agent_presets) user_agent = await db.libraryagent.create( - data=LibraryAgentCreateInput( - userId=user.id, - agentId=graph.id, - agentVersion=graph.version, - agentPresetId=preset.id, - isFavorite=random.choice([True, False]), - isCreatedByUser=random.choice([True, False]), - isArchived=random.choice([True, False]), - isDeleted=random.choice([True, False]), - ) + data={ + "userId": user.id, + "agentGraphId": graph.id, + "agentGraphVersion": graph.version, + "isFavorite": random.choice([True, False]), + "isCreatedByUser": random.choice([True, False]), + "isArchived": random.choice([True, False]), + "isDeleted": random.choice([True, False]), + } ) user_agents.append(user_agent) @@ -346,13 +341,13 @@ async def main(): user = random.choice(users) slug = faker.slug() listing = await db.storelisting.create( - data=StoreListingCreateInput( - agentId=graph.id, - agentVersion=graph.version, - owningUserId=user.id, - hasApprovedVersion=random.choice([True, 
False]), - slug=slug, - ) + data={ + "agentGraphId": graph.id, + "agentGraphVersion": graph.version, + "owningUserId": user.id, + "hasApprovedVersion": random.choice([True, False]), + "slug": slug, + } ) store_listings.append(listing) @@ -362,26 +357,26 @@ async def main(): for listing in store_listings: graph = [g for g in agent_graphs if g.id == listing.agentId][0] version = await db.storelistingversion.create( - data=StoreListingVersionCreateInput( - agentId=graph.id, - agentVersion=graph.version, - name=graph.name or faker.sentence(nb_words=3), - subHeading=faker.sentence(), - videoUrl=faker.url(), - imageUrls=[get_image() for _ in range(3)], - description=faker.text(), - categories=[faker.word() for _ in range(3)], - isFeatured=random.choice([True, False]), - isAvailable=True, - storeListingId=listing.id, - submissionStatus=random.choice( + data={ + "agentGraphId": graph.id, + "agentGraphVersion": graph.version, + "name": graph.name or faker.sentence(nb_words=3), + "subHeading": faker.sentence(), + "videoUrl": faker.url(), + "imageUrls": [get_image() for _ in range(3)], + "description": faker.text(), + "categories": [faker.word() for _ in range(3)], + "isFeatured": random.choice([True, False]), + "isAvailable": True, + "storeListingId": listing.id, + "submissionStatus": random.choice( [ prisma.enums.SubmissionStatus.PENDING, prisma.enums.SubmissionStatus.APPROVED, prisma.enums.SubmissionStatus.REJECTED, ] ), - ) + } ) store_listing_versions.append(version) @@ -422,23 +417,12 @@ async def main(): ) await db.storelistingversion.update( where={"id": version.id}, - data=StoreListingVersionCreateInput( - submissionStatus=status, - Reviewer={"connect": {"id": reviewer.id}}, - reviewComments=faker.text(), - reviewedAt=datetime.now(), - agentId=version.agentId, # preserving existing fields - agentVersion=version.agentVersion, - name=version.name, - subHeading=version.subHeading, - videoUrl=version.videoUrl, - imageUrls=version.imageUrls, - description=version.description, - categories=version.categories, - isFeatured=version.isFeatured, - isAvailable=version.isAvailable, - storeListingId=version.storeListingId, - ), + data={ + "submissionStatus": status, + "Reviewer": {"connect": {"id": reviewer.id}}, + "reviewComments": faker.text(), + "reviewedAt": datetime.now(), + }, ) # Insert APIKeys diff --git a/autogpt_platform/frontend/package.json b/autogpt_platform/frontend/package.json index 5747b3e55f..be7c5b88c8 100644 --- a/autogpt_platform/frontend/package.json +++ b/autogpt_platform/frontend/package.json @@ -52,7 +52,6 @@ "@xyflow/react": "12.4.2", "ajv": "^8.17.1", "boring-avatars": "^1.11.2", - "canvas-confetti": "^1.9.3", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", "cmdk": "1.0.4", @@ -70,6 +69,7 @@ "moment": "^2.30.1", "next": "^14.2.26", "next-themes": "^0.4.5", + "party-js": "^2.2.0", "react": "^18", "react-day-picker": "^9.6.1", "react-dom": "^18", diff --git a/autogpt_platform/frontend/public/onboarding/builder-open.mp4 b/autogpt_platform/frontend/public/onboarding/builder-open.mp4 new file mode 100644 index 0000000000..23444cc8a9 Binary files /dev/null and b/autogpt_platform/frontend/public/onboarding/builder-open.mp4 differ diff --git a/autogpt_platform/frontend/public/onboarding/builder-run.mp4 b/autogpt_platform/frontend/public/onboarding/builder-run.mp4 new file mode 100644 index 0000000000..c635a27aa5 Binary files /dev/null and b/autogpt_platform/frontend/public/onboarding/builder-run.mp4 differ diff --git 
a/autogpt_platform/frontend/public/onboarding/builder-save.mp4 b/autogpt_platform/frontend/public/onboarding/builder-save.mp4 new file mode 100644 index 0000000000..4bcf12e71b Binary files /dev/null and b/autogpt_platform/frontend/public/onboarding/builder-save.mp4 differ diff --git a/autogpt_platform/frontend/public/onboarding/get-results.mp4 b/autogpt_platform/frontend/public/onboarding/get-results.mp4 new file mode 100644 index 0000000000..787782f2f0 Binary files /dev/null and b/autogpt_platform/frontend/public/onboarding/get-results.mp4 differ diff --git a/autogpt_platform/frontend/public/onboarding/marketplace-add.mp4 b/autogpt_platform/frontend/public/onboarding/marketplace-add.mp4 new file mode 100644 index 0000000000..24df26e74f Binary files /dev/null and b/autogpt_platform/frontend/public/onboarding/marketplace-add.mp4 differ diff --git a/autogpt_platform/frontend/public/onboarding/marketplace-run.mp4 b/autogpt_platform/frontend/public/onboarding/marketplace-run.mp4 new file mode 100644 index 0000000000..99e270a2eb Binary files /dev/null and b/autogpt_platform/frontend/public/onboarding/marketplace-run.mp4 differ diff --git a/autogpt_platform/frontend/public/onboarding/marketplace-visit.mp4 b/autogpt_platform/frontend/public/onboarding/marketplace-visit.mp4 new file mode 100644 index 0000000000..33735a19d6 Binary files /dev/null and b/autogpt_platform/frontend/public/onboarding/marketplace-visit.mp4 differ diff --git a/autogpt_platform/frontend/src/app/build/page.tsx b/autogpt_platform/frontend/src/app/build/page.tsx index b9de310e41..3ba2ca35c8 100644 --- a/autogpt_platform/frontend/src/app/build/page.tsx +++ b/autogpt_platform/frontend/src/app/build/page.tsx @@ -3,9 +3,16 @@ import { useSearchParams } from "next/navigation"; import { GraphID } from "@/lib/autogpt-server-api/types"; import FlowEditor from "@/components/Flow"; +import { useOnboarding } from "@/components/onboarding/onboarding-provider"; +import { useEffect } from "react"; export default function Home() { const query = useSearchParams(); + const { completeStep } = useOnboarding(); + + useEffect(() => { + completeStep("BUILDER_OPEN"); + }, []); return ( (false); const [confirmingDeleteAgentRun, setConfirmingDeleteAgentRun] = useState(null); + const { state, updateState } = useOnboarding(); const openRunDraftView = useCallback(() => { selectView({ type: "run" }); @@ -78,6 +80,18 @@ export default function AgentRunsPage(): React.ReactElement { [api, graphVersions], ); + // Reward user for viewing results of their onboarding agent + useEffect(() => { + if (!state || !selectedRun || state.completedSteps.includes("GET_RESULTS")) + return; + + if (selectedRun.id === state.onboardingAgentExecutionId) { + updateState({ + completedSteps: [...state.completedSteps, "GET_RESULTS"], + }); + } + }, [selectedRun, state]); + const fetchAgents = useCallback(() => { api.getLibraryAgent(agentID).then((agent) => { setAgent(agent); diff --git a/autogpt_platform/frontend/src/app/marketplace/creator/[creator]/page.tsx b/autogpt_platform/frontend/src/app/marketplace/creator/[creator]/page.tsx index 7dec9660ac..fe70af6968 100644 --- a/autogpt_platform/frontend/src/app/marketplace/creator/[creator]/page.tsx +++ b/autogpt_platform/frontend/src/app/marketplace/creator/[creator]/page.tsx @@ -8,6 +8,7 @@ import { BreadCrumbs } from "@/components/agptui/BreadCrumbs"; import { Metadata } from "next"; import { CreatorInfoCard } from "@/components/agptui/CreatorInfoCard"; import { CreatorLinks } from "@/components/agptui/CreatorLinks"; +import { 
Separator } from "@/components/ui/separator"; export async function generateMetadata({ params, @@ -78,7 +79,7 @@ export default async function Page({
-
+ {/* 100px margin because our featured section buttons are placed 40px below the container */} - + { + updateState({ + onboardingAgentExecutionId: graph_exec_id, + }); + router.push("/onboarding/6-congrats"); + }); }, [api, agent, router, state?.agentInput]); const runYourAgent = ( diff --git a/autogpt_platform/frontend/src/app/onboarding/6-congrats/actions.ts b/autogpt_platform/frontend/src/app/onboarding/6-congrats/actions.ts index 95a7a3abaa..58cadc0bad 100644 --- a/autogpt_platform/frontend/src/app/onboarding/6-congrats/actions.ts +++ b/autogpt_platform/frontend/src/app/onboarding/6-congrats/actions.ts @@ -6,9 +6,6 @@ import { redirect } from "next/navigation"; export async function finishOnboarding() { const api = new BackendAPI(); const onboarding = await api.getUserOnboarding(); - await api.updateUserOnboarding({ - completedSteps: [...onboarding.completedSteps, "CONGRATS"], - }); revalidatePath("/library", "layout"); redirect("/library"); } diff --git a/autogpt_platform/frontend/src/app/onboarding/6-congrats/page.tsx b/autogpt_platform/frontend/src/app/onboarding/6-congrats/page.tsx index 1c985b96fa..90333916e3 100644 --- a/autogpt_platform/frontend/src/app/onboarding/6-congrats/page.tsx +++ b/autogpt_platform/frontend/src/app/onboarding/6-congrats/page.tsx @@ -1,24 +1,26 @@ "use client"; -import { useEffect, useState } from "react"; +import { useEffect, useRef, useState } from "react"; import { cn } from "@/lib/utils"; import { finishOnboarding } from "./actions"; -import confetti from "canvas-confetti"; import { useOnboarding } from "@/components/onboarding/onboarding-provider"; +import * as party from "party-js"; export default function Page() { - useOnboarding(7, "AGENT_INPUT"); + const { state, updateState } = useOnboarding(7, "AGENT_INPUT"); const [showText, setShowText] = useState(false); const [showSubtext, setShowSubtext] = useState(false); + const divRef = useRef(null); useEffect(() => { - confetti({ - particleCount: 120, - spread: 360, - shapes: ["square", "circle"], - scalar: 2, - decay: 0.93, - origin: { y: 0.38, x: 0.51 }, - }); + if (divRef.current) { + party.confetti(divRef.current, { + count: 100, + spread: 180, + shapes: ["square", "circle"], + size: party.variation.range(2, 2), // scalar: 2 + speed: party.variation.range(300, 1000), + }); + } const timer0 = setTimeout(() => { setShowText(true); @@ -29,6 +31,9 @@ }, 500); const timer2 = setTimeout(() => { + updateState({ + completedSteps: [...(state?.completedSteps || []), "CONGRATS"], + }); finishOnboarding(); }, 3000); @@ -42,6 +47,7 @@ return (
- You earned 15$ for running your first agent + You earned $3 for running your first agent

); diff --git a/autogpt_platform/frontend/src/app/onboarding/reset/page.ts b/autogpt_platform/frontend/src/app/onboarding/reset/page.ts index 1c7904b06d..65f36a14f2 100644 --- a/autogpt_platform/frontend/src/app/onboarding/reset/page.ts +++ b/autogpt_platform/frontend/src/app/onboarding/reset/page.ts @@ -5,11 +5,14 @@ export default async function OnboardingResetPage() { const api = new BackendAPI(); await api.updateUserOnboarding({ completedSteps: [], + notificationDot: true, + notified: [], usageReason: null, integrations: [], otherIntegrations: "", selectedStoreListingVersionId: null, agentInput: {}, + onboardingAgentExecutionId: null, }); redirect("/onboarding/1-welcome"); } diff --git a/autogpt_platform/frontend/src/components/agents/agent-run-draft-view.tsx b/autogpt_platform/frontend/src/components/agents/agent-run-draft-view.tsx index f73fa4978f..db87661193 100644 --- a/autogpt_platform/frontend/src/components/agents/agent-run-draft-view.tsx +++ b/autogpt_platform/frontend/src/components/agents/agent-run-draft-view.tsx @@ -11,6 +11,7 @@ import { useToastOnFail } from "@/components/ui/use-toast"; import ActionButtonGroup from "@/components/agptui/action-button-group"; import SchemaTooltip from "@/components/SchemaTooltip"; import { IconPlay } from "@/components/ui/icons"; +import { useOnboarding } from "../onboarding/onboarding-provider"; export default function AgentRunDraftView({ graph, @@ -26,15 +27,18 @@ export default function AgentRunDraftView({ const agentInputs = graph.input_schema.properties; const [inputValues, setInputValues] = useState>({}); + const { state, completeStep } = useOnboarding(); - const doRun = useCallback( - () => - api - .executeGraph(graph.id, graph.version, inputValues) - .then((newRun) => onRun(newRun.graph_exec_id)) - .catch(toastOnFail("execute agent")), - [api, graph, inputValues, onRun, toastOnFail], - ); + const doRun = useCallback(() => { + api + .executeGraph(graph.id, graph.version, inputValues) + .then((newRun) => onRun(newRun.graph_exec_id)) + .catch(toastOnFail("execute agent")); + // Mark run agent onboarding step as completed + if (state?.completedSteps.includes("MARKETPLACE_ADD_AGENT")) { + completeStep("MARKETPLACE_RUN_AGENT"); + } + }, [api, graph, inputValues, onRun, state]); const runActions: ButtonAction[] = useMemo( () => [ diff --git a/autogpt_platform/frontend/src/components/agptui/AgentInfo.tsx b/autogpt_platform/frontend/src/components/agptui/AgentInfo.tsx index a9531221f1..ecf165ed7e 100644 --- a/autogpt_platform/frontend/src/components/agptui/AgentInfo.tsx +++ b/autogpt_platform/frontend/src/components/agptui/AgentInfo.tsx @@ -10,6 +10,7 @@ import { useToast } from "@/components/ui/use-toast"; import useSupabase from "@/hooks/useSupabase"; import { DownloadIcon, LoaderIcon } from "lucide-react"; +import { useOnboarding } from "../onboarding/onboarding-provider"; interface AgentInfoProps { name: string; creator: string; @@ -39,6 +40,7 @@ export const AgentInfo: React.FC = ({ const api = React.useMemo(() => new BackendAPI(), []); const { user } = useSupabase(); const { toast } = useToast(); + const { completeStep } = useOnboarding(); const [downloading, setDownloading] = React.useState(false); @@ -47,6 +49,7 @@ export const AgentInfo: React.FC = ({ const newLibraryAgent = await api.addMarketplaceAgentToLibrary( storeListingVersionId, ); + completeStep("MARKETPLACE_ADD_AGENT"); router.push(`/library/agents/${newLibraryAgent.id}`); } catch (error) { console.error("Failed to add agent to library:", error); @@ -170,10 +173,10 @@ 
export const AgentInfo: React.FC = ({ {/* Description Section */}
-
+
Description
-
+
{longDescription}
diff --git a/autogpt_platform/frontend/src/components/agptui/BreadCrumbs.tsx b/autogpt_platform/frontend/src/components/agptui/BreadCrumbs.tsx index 21ed54deb4..5be5505150 100644 --- a/autogpt_platform/frontend/src/components/agptui/BreadCrumbs.tsx +++ b/autogpt_platform/frontend/src/components/agptui/BreadCrumbs.tsx @@ -22,7 +22,7 @@ export const BreadCrumbs: React.FC = ({ items }) => { */} -
+
{items.map((item, index) => ( diff --git a/autogpt_platform/frontend/src/components/agptui/CreditsCard.stories.tsx b/autogpt_platform/frontend/src/components/agptui/CreditsCard.stories.tsx deleted file mode 100644 index cc9028eeaa..0000000000 --- a/autogpt_platform/frontend/src/components/agptui/CreditsCard.stories.tsx +++ /dev/null @@ -1,43 +0,0 @@ -import type { Meta, StoryObj } from "@storybook/react"; -import CreditsCard from "./CreditsCard"; -import { userEvent, within } from "@storybook/test"; - -const meta: Meta = { - title: "AGPT UI/Credits Card", - component: CreditsCard, - tags: ["autodocs"], -}; - -export default meta; -type Story = StoryObj; - -export const Default: Story = { - args: { - credits: 0, - }, -}; - -export const SmallNumber: Story = { - args: { - credits: 10, - }, -}; - -export const LargeNumber: Story = { - args: { - credits: 1000000, - }, -}; - -export const InteractionTest: Story = { - args: { - credits: 100, - }, - play: async ({ canvasElement }) => { - const canvas = within(canvasElement); - const refreshButton = canvas.getByRole("button", { - name: /refresh credits/i, - }); - await userEvent.click(refreshButton); - }, -}; diff --git a/autogpt_platform/frontend/src/components/agptui/CreditsCard.tsx b/autogpt_platform/frontend/src/components/agptui/CreditsCard.tsx deleted file mode 100644 index f44cca19f4..0000000000 --- a/autogpt_platform/frontend/src/components/agptui/CreditsCard.tsx +++ /dev/null @@ -1,48 +0,0 @@ -"use client"; - -import { IconRefresh } from "@/components/ui/icons"; -import { useState } from "react"; -import { - Tooltip, - TooltipContent, - TooltipTrigger, -} from "@/components/ui/tooltip"; -import { useBackendAPI } from "@/lib/autogpt-server-api/context"; -import useCredits from "@/hooks/useCredits"; - -const CreditsCard = () => { - const { credits, formatCredits, fetchCredits } = useCredits({ - fetchInitialCredits: true, - }); - const api = useBackendAPI(); - - const onRefresh = async () => { - fetchCredits(); - }; - - return ( -
-
- - Balance: {formatCredits(credits)} - -
- - - - - -

Refresh credits

-
-
-
- ); -}; - -export default CreditsCard; diff --git a/autogpt_platform/frontend/src/components/agptui/FeaturedAgentCard.tsx b/autogpt_platform/frontend/src/components/agptui/FeaturedAgentCard.tsx index b1a6571138..8cbaad4b65 100644 --- a/autogpt_platform/frontend/src/components/agptui/FeaturedAgentCard.tsx +++ b/autogpt_platform/frontend/src/components/agptui/FeaturedAgentCard.tsx @@ -27,7 +27,7 @@ export const FeaturedAgentCard: React.FC = ({ data-testid="featured-store-card" onMouseEnter={() => setIsHovered(true)} onMouseLeave={() => setIsHovered(false)} - className={`flex h-full flex-col ${backgroundColor}`} + className={`flex h-full flex-col ${backgroundColor} rounded-[1.5rem] border-none`} > diff --git a/autogpt_platform/frontend/src/components/agptui/Navbar.tsx b/autogpt_platform/frontend/src/components/agptui/Navbar.tsx index 92f7f94f86..61bcce4836 100644 --- a/autogpt_platform/frontend/src/components/agptui/Navbar.tsx +++ b/autogpt_platform/frontend/src/components/agptui/Navbar.tsx @@ -4,7 +4,7 @@ import { ProfilePopoutMenu } from "./ProfilePopoutMenu"; import { IconType, IconLogIn, IconAutoGPTLogo } from "@/components/ui/icons"; import { MobileNavBar } from "./MobileNavBar"; import { Button } from "./Button"; -import CreditsCard from "./CreditsCard"; +import Wallet from "./Wallet"; import { ProfileDetails } from "@/lib/autogpt-server-api/types"; import { NavbarLink } from "./NavbarLink"; import getServerUser from "@/lib/supabase/getServerUser"; @@ -61,7 +61,7 @@ export const Navbar = async ({ links, menuItemGroups }: NavbarProps) => {
{isLoggedIn ? (
- {profile && } + {profile && } (null); + + const onWalletOpen = useCallback(async () => { + if (state?.notificationDot) { + updateState({ notificationDot: false }); + } + // Refresh credits when the wallet is opened + fetchCredits(); + }, [state?.notificationDot, updateState, fetchCredits]); + + const fadeOut = new party.ModuleBuilder() + .drive("opacity") + .by((t) => 1 - t) + .through("lifetime") + .build(); + + useEffect(() => { + // Check if there are any completed tasks (state?.completedTasks) that + // are not in the state?.notified array and play confetti if so + const pending = state?.completedSteps + .filter((step) => !state?.notified.includes(step)) + // Ignore steps that are not relevant for notifications + .filter( + (step) => + step !== "WELCOME" && + step !== "USAGE_REASON" && + step !== "INTEGRATIONS" && + step !== "AGENT_CHOICE" && + step !== "AGENT_NEW_RUN" && + step !== "AGENT_INPUT", + ); + if ((pending?.length || 0) > 0 && walletRef.current) { + party.confetti(walletRef.current, { + count: 30, + spread: 120, + shapes: ["square", "circle"], + size: party.variation.range(1, 2), + speed: party.variation.range(200, 300), + modules: [fadeOut], + }); + } + }, [state?.completedSteps, state?.notified]); + + return ( + + + + + +
+
+ + Your wallet + +
+
+ Wallet{" "} + {formatCredits(credits)} +
+ + + +
+
+

+ Complete the following tasks to earn more credits! +

+
+ + + +
+
+ ); +} diff --git a/autogpt_platform/frontend/src/components/agptui/composite/AgentsSection.tsx b/autogpt_platform/frontend/src/components/agptui/composite/AgentsSection.tsx index f40bf6b357..2f827811fb 100644 --- a/autogpt_platform/frontend/src/components/agptui/composite/AgentsSection.tsx +++ b/autogpt_platform/frontend/src/components/agptui/composite/AgentsSection.tsx @@ -49,7 +49,7 @@ export const AgentsSection: React.FC = ({
{sectionTitle}
diff --git a/autogpt_platform/frontend/src/components/agptui/composite/FeaturedCreators.tsx b/autogpt_platform/frontend/src/components/agptui/composite/FeaturedCreators.tsx index ffaa6bff51..9accb662d6 100644 --- a/autogpt_platform/frontend/src/components/agptui/composite/FeaturedCreators.tsx +++ b/autogpt_platform/frontend/src/components/agptui/composite/FeaturedCreators.tsx @@ -33,7 +33,7 @@ export const FeaturedCreators: React.FC = ({ return (
-

+

{title}

diff --git a/autogpt_platform/frontend/src/components/agptui/composite/FeaturedSection.tsx b/autogpt_platform/frontend/src/components/agptui/composite/FeaturedSection.tsx index 5b930238e9..84c61295e5 100644 --- a/autogpt_platform/frontend/src/components/agptui/composite/FeaturedSection.tsx +++ b/autogpt_platform/frontend/src/components/agptui/composite/FeaturedSection.tsx @@ -46,7 +46,7 @@ export const FeaturedSection: React.FC = ({ }; return ( -
+

Featured agents

diff --git a/autogpt_platform/frontend/src/components/agptui/composite/HeroSection.tsx b/autogpt_platform/frontend/src/components/agptui/composite/HeroSection.tsx index 7fae3b88fb..e666d9f267 100644 --- a/autogpt_platform/frontend/src/components/agptui/composite/HeroSection.tsx +++ b/autogpt_platform/frontend/src/components/agptui/composite/HeroSection.tsx @@ -4,9 +4,16 @@ import * as React from "react"; import { SearchBar } from "@/components/agptui/SearchBar"; import { FilterChips } from "@/components/agptui/FilterChips"; import { useRouter } from "next/navigation"; +import { useOnboarding } from "@/components/onboarding/onboarding-provider"; export const HeroSection: React.FC = () => { const router = useRouter(); + const { completeStep } = useOnboarding(); + + // Mark marketplace visit task as completed + React.useEffect(() => { + completeStep("MARKETPLACE_VISIT"); + }, [completeStep]); function onFilterChange(selectedFilters: string[]) { const encodedTerm = encodeURIComponent(selectedFilters.join(", ")); diff --git a/autogpt_platform/frontend/src/components/library/library-search-bar.tsx b/autogpt_platform/frontend/src/components/library/library-search-bar.tsx index 89b580edcc..0c5343bf82 100644 --- a/autogpt_platform/frontend/src/components/library/library-search-bar.tsx +++ b/autogpt_platform/frontend/src/components/library/library-search-bar.tsx @@ -49,7 +49,7 @@ export default function LibrarySearchBar(): React.ReactNode { onFocus={() => setIsFocused(true)} onBlur={() => !inputRef.current?.value && setIsFocused(false)} onChange={handleSearchInput} - className="flex-1 border-none font-sans text-[16px] font-normal leading-7 shadow-none focus:shadow-none" + className="flex-1 border-none font-sans text-[16px] font-normal leading-7 shadow-none focus:shadow-none focus:ring-0" type="text" placeholder="Search agents" /> diff --git a/autogpt_platform/frontend/src/components/onboarding/WalletTaskGroups.tsx b/autogpt_platform/frontend/src/components/onboarding/WalletTaskGroups.tsx new file mode 100644 index 0000000000..c63712fc9e --- /dev/null +++ b/autogpt_platform/frontend/src/components/onboarding/WalletTaskGroups.tsx @@ -0,0 +1,331 @@ +import { useCallback, useEffect, useRef, useState } from "react"; +import { ChevronDown, Check } from "lucide-react"; +import { OnboardingStep } from "@/lib/autogpt-server-api"; +import { useOnboarding } from "./onboarding-provider"; +import { cn } from "@/lib/utils"; +import * as party from "party-js"; + +interface Task { + id: OnboardingStep; + name: string; + amount: number; + details: string; + video?: string; +} + +interface TaskGroup { + name: string; + tasks: Task[]; + isOpen: boolean; +} + +export function TaskGroups() { + const [groups, setGroups] = useState([ + { + name: "Run your first agent", + isOpen: false, + tasks: [ + { + id: "CONGRATS", + name: "Finish onboarding", + amount: 3, + details: "Go through our step by step tutorial", + }, + { + id: "GET_RESULTS", + name: "Get results from first agent", + amount: 3, + details: + "Sit back and relax - your agent is running and will finish soon! 
See the results in the Library once it's done", + video: "/onboarding/get-results.mp4", + }, + ], + }, + { + name: "Explore the Marketplace", + isOpen: false, + tasks: [ + { + id: "MARKETPLACE_VISIT", + name: "Go to Marketplace", + amount: 0, + details: "Click Marketplace in the top navigation", + video: "/onboarding/marketplace-visit.mp4", + }, + { + id: "MARKETPLACE_ADD_AGENT", + name: "Find an agent", + amount: 1, + details: + "Search for an agent in the Marketplace, like a code generator or research assistant and add it to your Library", + video: "/onboarding/marketplace-add.mp4", + }, + { + id: "MARKETPLACE_RUN_AGENT", + name: "Try out your agent", + amount: 1, + details: + "Run the agent you found in the Marketplace from the Library - whether it's a writing assistant, data analyzer, or something else", + video: "/onboarding/marketplace-run.mp4", + }, + ], + }, + { + name: "Build your own agent", + isOpen: false, + tasks: [ + { + id: "BUILDER_OPEN", + name: "Open the Builder", + amount: 0, + details: "Click Builder in the top navigation", + video: "/onboarding/builder-open.mp4", + }, + { + id: "BUILDER_SAVE_AGENT", + name: "Place your first blocks and save your agent", + amount: 1, + details: + "Open block library on the left and add a block to the canvas then save your agent", + video: "/onboarding/builder-save.mp4", + }, + { + id: "BUILDER_RUN_AGENT", + name: "Run your agent", + amount: 1, + details: "Run your agent from the Builder", + video: "/onboarding/builder-run.mp4", + }, + ], + }, + ]); + const { state, updateState } = useOnboarding(); + const refs = useRef>({}); + + const toggleGroup = useCallback((name: string) => { + setGroups((prevGroups) => + prevGroups.map((group) => + group.name === name ? { ...group, isOpen: !group.isOpen } : group, + ), + ); + }, []); + + const isTaskCompleted = useCallback( + (task: Task) => { + return state?.completedSteps?.includes(task.id) || false; + }, + [state?.completedSteps], + ); + + const getCompletedCount = useCallback( + (tasks: Task[]) => { + return tasks.filter((task) => isTaskCompleted(task)).length; + }, + [isTaskCompleted], + ); + + const isGroupCompleted = useCallback( + (group: TaskGroup) => { + return group.tasks.every((task) => isTaskCompleted(task)); + }, + [isTaskCompleted], + ); + + const setRef = (name: string) => (el: HTMLDivElement | null) => { + if (el) { + refs.current[name] = el; + } + }; + + useEffect(() => { + groups.forEach((group) => { + const groupCompleted = isGroupCompleted(group); + // Check if the last task in the group is completed + const alreadyCelebrated = state?.notified.includes( + group.tasks[group.tasks.length - 1].id, + ); + + if (groupCompleted) { + const el = refs.current[group.name]; + if (el && !alreadyCelebrated) { + party.confetti(el, { + count: 50, + spread: 120, + shapes: ["square", "circle"], + size: party.variation.range(1, 2), + speed: party.variation.range(200, 300), + }); + // Update the state to include all group tasks as notified + // This ensures that the confetti effect isn't perpetually triggered on Wallet + const notifiedTasks = group.tasks.map((task) => task.id); + updateState({ + notified: [...(state?.notified || []), ...notifiedTasks], + }); + } + return; + } + + group.tasks.forEach((task) => { + const el = refs.current[task.id]; + if (el && isTaskCompleted(task) && !state?.notified.includes(task.id)) { + party.confetti(el, { + count: 40, + spread: 120, + shapes: ["square", "circle"], + size: party.variation.range(1, 1.5), + speed: party.variation.range(200, 300), + }); + // 
Update the state to include the task as notified + updateState({ notified: [...(state?.notified || []), task.id] }); + } + }); + }); + }, [state?.completedSteps]); + + return ( +
+ {groups.map((group) => ( +
+ {/* Group Header - unchanged */} +
toggleGroup(group.name)} + > + {/* Name and completed count */} +
+
+ {group.name} +
+
+ {getCompletedCount(group.tasks)} of {group.tasks.length}{" "} + completed +
+
+ {/* Reward and chevron */} +
+
+ $ + {group.tasks + .reduce((sum, task) => sum + task.amount, 0) + .toFixed(2)} +
+ +
+
+ + {/* Tasks */} +
+ {group.tasks.map((task) => ( +
+
+ {/* Checkmark and name */} +
+
+ {isTaskCompleted(task) && ( + + )} +
+ + {task.name} + +
+ {/* Reward */} + {task.amount > 0 && ( + + ${task.amount.toFixed(2)} + + )} +
+ + {/* Details section */} +
+ {task.details} +
+ {task.video && ( +
+ +
+ )} +
+ ))} +
+
+ ))} +
+ ); +} diff --git a/autogpt_platform/frontend/src/components/onboarding/onboarding-provider.tsx b/autogpt_platform/frontend/src/components/onboarding/onboarding-provider.tsx index 62472dc421..36519fd3dc 100644 --- a/autogpt_platform/frontend/src/components/onboarding/onboarding-provider.tsx +++ b/autogpt_platform/frontend/src/components/onboarding/onboarding-provider.tsx @@ -14,9 +14,12 @@ import { const OnboardingContext = createContext< | { state: UserOnboarding | null; - updateState: (state: Partial) => void; + updateState: ( + state: Omit, "rewardedFor">, + ) => void; step: number; setStep: (step: number) => void; + completeStep: (step: OnboardingStep) => void; } | undefined >(undefined); @@ -84,19 +87,23 @@ export default function OnboardingProvider({ }, [api, pathname, router]); const updateState = useCallback( - (newState: Partial) => { + (newState: Omit, "rewardedFor">) => { setState((prev) => { - api.updateUserOnboarding({ ...prev, ...newState }); + api.updateUserOnboarding(newState); if (!prev) { // Handle initial state return { completedSteps: [], + notificationDot: false, + notified: [], + rewardedFor: [], usageReason: null, integrations: [], otherIntegrations: null, selectedStoreListingVersionId: null, agentInput: null, + onboardingAgentExecutionId: null, ...newState, }; } @@ -106,8 +113,21 @@ export default function OnboardingProvider({ [api, setState], ); + const completeStep = useCallback( + (step: OnboardingStep) => { + if (!state || state.completedSteps.includes(step)) return; + + updateState({ + completedSteps: [...state.completedSteps, step], + }); + }, + [api, state], + ); + return ( - + {children} ); diff --git a/autogpt_platform/frontend/src/components/ui/carousel.tsx b/autogpt_platform/frontend/src/components/ui/carousel.tsx index 8a391b5b39..500723b6cf 100644 --- a/autogpt_platform/frontend/src/components/ui/carousel.tsx +++ b/autogpt_platform/frontend/src/components/ui/carousel.tsx @@ -213,7 +213,7 @@ const CarouselPrevious = React.forwardRef< className={cn( "absolute h-[52px] w-[52px] rounded-full", orientation === "horizontal" - ? "right-24 top-0" + ? 
"right-20 top-0" : "-top-12 left-1/2 -translate-x-1/2 rotate-90", className, )} diff --git a/autogpt_platform/frontend/src/hooks/useAgentGraph.tsx b/autogpt_platform/frontend/src/hooks/useAgentGraph.tsx index 4b5187cbfa..94d3729a6e 100644 --- a/autogpt_platform/frontend/src/hooks/useAgentGraph.tsx +++ b/autogpt_platform/frontend/src/hooks/useAgentGraph.tsx @@ -24,6 +24,7 @@ import { useToast } from "@/components/ui/use-toast"; import { InputItem } from "@/components/RunnerUIWrapper"; import { GraphMeta } from "@/lib/autogpt-server-api"; import { default as NextLink } from "next/link"; +import { useOnboarding } from "@/components/onboarding/onboarding-provider"; const ajv = new Ajv({ strict: false, allErrors: true }); @@ -77,6 +78,7 @@ export default function useAgentGraph( useState(false); const [nodes, setNodes] = useState([]); const [edges, setEdges] = useState([]); + const { state, completeStep } = useOnboarding(); const api = useMemo( () => new BackendAPI(process.env.NEXT_PUBLIC_AGPT_SERVER_URL!), @@ -576,6 +578,9 @@ export default function useAgentGraph( path.set("flowVersion", savedAgent.version.toString()); path.set("flowExecutionID", graphExecution.graph_exec_id); router.push(`${pathname}?${path.toString()}`); + if (state?.completedSteps.includes("BUILDER_SAVE_AGENT")) { + completeStep("BUILDER_RUN_AGENT"); + } }) .catch((error) => { const errorMessage = @@ -966,6 +971,7 @@ export default function useAgentGraph( const saveAgent = useCallback(async () => { try { await _saveAgent(); + completeStep("BUILDER_SAVE_AGENT"); } catch (error) { const errorMessage = error instanceof Error ? error.message : String(error); diff --git a/autogpt_platform/frontend/src/lib/autogpt-server-api/client.ts b/autogpt_platform/frontend/src/lib/autogpt-server-api/client.ts index e3c9d1b73f..60293c44e8 100644 --- a/autogpt_platform/frontend/src/lib/autogpt-server-api/client.ts +++ b/autogpt_platform/frontend/src/lib/autogpt-server-api/client.ts @@ -180,7 +180,9 @@ export default class BackendAPI { return this._get("/onboarding"); } - updateUserOnboarding(onboarding: Partial): Promise { + updateUserOnboarding( + onboarding: Omit, "rewardedFor">, + ): Promise { return this._request("PATCH", "/onboarding", onboarding); } diff --git a/autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts b/autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts index 92db06841e..ea8f63380f 100644 --- a/autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts +++ b/autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts @@ -801,15 +801,26 @@ export type OnboardingStep = | "AGENT_CHOICE" | "AGENT_NEW_RUN" | "AGENT_INPUT" - | "CONGRATS"; + | "CONGRATS" + | "GET_RESULTS" + | "MARKETPLACE_VISIT" + | "MARKETPLACE_ADD_AGENT" + | "MARKETPLACE_RUN_AGENT" + | "BUILDER_OPEN" + | "BUILDER_SAVE_AGENT" + | "BUILDER_RUN_AGENT"; export interface UserOnboarding { completedSteps: OnboardingStep[]; + notificationDot: boolean; + notified: OnboardingStep[]; + rewardedFor: OnboardingStep[]; usageReason: string | null; integrations: string[]; otherIntegrations: string | null; selectedStoreListingVersionId: string | null; - agentInput: { [key: string]: string } | null; + agentInput: { [key: string]: string | number } | null; + onboardingAgentExecutionId: GraphExecutionID | null; } /* *** UTILITIES *** */ diff --git a/autogpt_platform/frontend/yarn.lock b/autogpt_platform/frontend/yarn.lock index 74fc88b276..249f0bce62 100644 --- a/autogpt_platform/frontend/yarn.lock +++ b/autogpt_platform/frontend/yarn.lock @@ -5026,11 
+5026,6 @@ caniuse-lite@^1.0.30001579, caniuse-lite@^1.0.30001688: resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001697.tgz#040bbbb54463c4b4b3377c716b34a322d16e6fc7" integrity sha512-GwNPlWJin8E+d7Gxq96jxM6w0w+VFeyyXRsjU58emtkYqnbwHqXm5uT2uCmO0RQE9htWknOP4xtBlLmM/gWxvQ== -canvas-confetti@^1.9.3: - version "1.9.3" - resolved "https://registry.yarnpkg.com/canvas-confetti/-/canvas-confetti-1.9.3.tgz#ef4c857420ad8045ab4abe8547261c8cdf229845" - integrity sha512-rFfTURMvmVEX1gyXFgn5QMn81bYk70qa0HLzcIOSVEyl57n6o9ItHeBtUSWdvKAPY0xlvBHno4/v3QPrT83q9g== - case-sensitive-paths-webpack-plugin@^2.4.0: version "2.4.0" resolved "https://registry.yarnpkg.com/case-sensitive-paths-webpack-plugin/-/case-sensitive-paths-webpack-plugin-2.4.0.tgz#db64066c6422eed2e08cc14b986ca43796dbc6d4" @@ -9590,6 +9585,11 @@ parse-passwd@^1.0.0: resolved "https://registry.yarnpkg.com/parse-passwd/-/parse-passwd-1.0.0.tgz#6d5b934a456993b23d37f40a382d6f1666a8e5c6" integrity sha512-1Y1A//QUXEZK7YKz+rD9WydcE1+EuPr6ZBgKecAB8tmoW6UFv0NREVJe1p+jRxtThkcbbKkfwIbWJe/IeE6m2Q== +party-js@^2.2.0: + version "2.2.0" + resolved "https://registry.yarnpkg.com/party-js/-/party-js-2.2.0.tgz#3340026971c9e62fd34db102daaa645fbc9130b8" + integrity sha512-50hGuALCpvDTrQLPQ1fgUgxKIWAH28ShVkmeK/3zhO0YJyCqkhrZhQEkWPxDYLvbFJ7YAXyROmFEu35gKpZLtQ== + pascal-case@^3.1.2: version "3.1.2" resolved "https://registry.yarnpkg.com/pascal-case/-/pascal-case-3.1.2.tgz#b48e0ef2b98e205e7c1dae747d0b1508237660eb"