mirror of https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-08 13:55:06 -05:00

Merge branch 'dev' into zamilmajdy/secrt-1222-move-scheduler-into-a-singleton
@@ -1,8 +1,6 @@
import logging
from typing import Any

from autogpt_libs.utils.cache import thread_cached

from backend.data.block import (
    Block,
    BlockCategory,
@@ -19,21 +17,6 @@ from backend.util import json
logger = logging.getLogger(__name__)


@thread_cached
def get_executor_manager_client():
    from backend.executor import ExecutionManager
    from backend.util.service import get_service_client

    return get_service_client(ExecutionManager)


@thread_cached
def get_event_bus():
    from backend.data.execution import RedisExecutionEventBus

    return RedisExecutionEventBus()


class AgentExecutorBlock(Block):
    class Input(BlockSchema):
        user_id: str = SchemaField(description="User ID")
@@ -76,11 +59,11 @@ class AgentExecutorBlock(Block):

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        from backend.data.execution import ExecutionEventType
        from backend.executor import utils as execution_utils

        executor_manager = get_executor_manager_client()
        event_bus = get_event_bus()
        event_bus = execution_utils.get_execution_event_bus()

        graph_exec = executor_manager.add_execution(
        graph_exec = execution_utils.add_graph_execution(
            graph_id=input_data.graph_id,
            graph_version=input_data.graph_version,
            user_id=input_data.user_id,
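Note: the removed get_executor_manager_client/get_event_bus helpers rely on @thread_cached to memoize one instance per thread, and their replacements in backend.executor.utils (shown later in this diff) use the same decorator. A minimal sketch of what a decorator of this shape typically does — an assumption for illustration, since the real implementation lives in autogpt_libs.utils.cache:

    import threading
    from functools import wraps

    def thread_cached(func):
        """Memoize a function's results separately for each thread."""
        local = threading.local()

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Each thread lazily gets its own cache dict, so per-thread
            # singletons (clients, connections) are never shared across threads.
            cache = getattr(local, "cache", None)
            if cache is None:
                cache = local.cache = {}
            key = (args, tuple(sorted(kwargs.items())))
            if key not in cache:
                cache[key] = func(*args, **kwargs)
            return cache[key]

        return wrapper
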
@@ -3,7 +3,6 @@ import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timezone
from typing import cast

import stripe
from autogpt_libs.utils.cache import thread_cached
@@ -12,6 +11,7 @@ from prisma.enums (
    CreditRefundRequestStatus,
    CreditTransactionType,
    NotificationType,
    OnboardingStep,
)
from prisma.errors import UniqueViolationError
from prisma.models import CreditRefundRequest, CreditTransaction, User
@@ -19,7 +19,6 @@ from prisma.types (
    CreditRefundRequestCreateInput,
    CreditTransactionCreateInput,
    CreditTransactionWhereInput,
    IntFilter,
)
from tenacity import retry, stop_after_attempt, wait_exponential

@@ -123,6 +122,18 @@ class UserCreditBase(ABC):
        """
        pass

    @abstractmethod
    async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
        """
        Reward the user with credits for completing an onboarding step.
        Won't reward if the user has already received credits for the step.

        Args:
            user_id (str): The user ID.
            step (OnboardingStep): The onboarding step.
        """
        pass

    @abstractmethod
    async def top_up_intent(self, user_id: str, amount: int) -> str:
        """
@@ -215,7 +226,7 @@
                "userId": user_id,
                "createdAt": {"lte": top_time},
                "isActive": True,
                "runningBalance": cast(IntFilter, {"not": None}),
                "NOT": [{"runningBalance": None}],
            },
            order={"createdAt": "desc"},
        )
@@ -410,6 +421,24 @@ class UserCredit(UserCreditBase):
    async def top_up_credits(self, user_id: str, amount: int):
        await self._top_up_credits(user_id, amount)

    async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
        key = f"REWARD-{user_id}-{step.value}"
        if not await CreditTransaction.prisma().find_first(
            where={
                "userId": user_id,
                "transactionKey": key,
            }
        ):
            await self._add_transaction(
                user_id=user_id,
                amount=credits,
                transaction_type=CreditTransactionType.GRANT,
                transaction_key=key,
                metadata=Json(
                    {"reason": f"Reward for completing {step.value} onboarding step."}
                ),
            )

    async def top_up_refund(
        self, user_id: str, transaction_key: str, metadata: dict[str, str]
    ) -> int:
@@ -897,6 +926,9 @@ class DisabledUserCredit(UserCreditBase):
    async def top_up_credits(self, *args, **kwargs):
        pass

    async def onboarding_reward(self, *args, **kwargs):
        pass

    async def top_up_intent(self, *args, **kwargs) -> str:
        return ""

@@ -89,7 +89,7 @@ async def transaction():
async def locked_transaction(key: str):
    lock_key = zlib.crc32(key.encode("utf-8"))
    async with transaction() as tx:
        await tx.execute_raw(f"SELECT pg_advisory_xact_lock({lock_key})")
        await tx.execute_raw("SELECT pg_advisory_xact_lock($1)", lock_key)
        yield tx
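Note: the one-line change above binds the advisory-lock key as a query parameter instead of interpolating it into the SQL string, so the statement text stays constant. A standalone sketch of the same pattern, written against asyncpg purely for illustration (the repository itself goes through Prisma's execute_raw):

    import zlib
    import asyncpg  # illustrative driver choice, not what this codebase uses

    async def run_locked(conn: asyncpg.Connection, key: str) -> None:
        # crc32 folds an arbitrary string key into the integer key space
        # that pg_advisory_xact_lock expects.
        lock_key = zlib.crc32(key.encode("utf-8"))
        async with conn.transaction():
            # Bound parameter: only the argument varies, never the SQL text.
            await conn.execute("SELECT pg_advisory_xact_lock($1)", lock_key)
            ...  # critical section; the lock is released at transaction end
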
@@ -34,11 +34,10 @@ from pydantic import BaseModel
from pydantic.fields import Field

from backend.server.v2.store.exceptions import DatabaseError
from backend.util import mock
from backend.util import type as type_utils
from backend.util.settings import Config

from .block import BlockData, BlockInput, BlockType, CompletedBlockOutput, get_block
from .block import BlockInput, BlockType, CompletedBlockOutput, get_block
from .db import BaseDbModel
from .includes import (
    EXECUTION_RESULT_INCLUDE,
@@ -203,6 +202,26 @@ class GraphExecutionWithNodes(GraphExecution):
            node_executions=node_executions,
        )

    def to_graph_execution_entry(self):
        return GraphExecutionEntry(
            user_id=self.user_id,
            graph_id=self.graph_id,
            graph_version=self.graph_version or 0,
            graph_exec_id=self.id,
            start_node_execs=[
                NodeExecutionEntry(
                    user_id=self.user_id,
                    graph_exec_id=node_exec.graph_exec_id,
                    graph_id=node_exec.graph_id,
                    node_exec_id=node_exec.node_exec_id,
                    node_id=node_exec.node_id,
                    block_id=node_exec.block_id,
                    data=node_exec.input_data,
                )
                for node_exec in self.node_executions
            ],
        )


class NodeExecutionResult(BaseModel):
    user_id: str
@@ -469,19 +488,27 @@ async def upsert_execution_output(
    )


async def update_graph_execution_start_time(graph_exec_id: str) -> GraphExecution:
    res = await AgentGraphExecution.prisma().update(
        where={"id": graph_exec_id},
async def update_graph_execution_start_time(
    graph_exec_id: str,
) -> GraphExecution | None:
    count = await AgentGraphExecution.prisma().update_many(
        where={
            "id": graph_exec_id,
            "executionStatus": ExecutionStatus.QUEUED,
        },
        data={
            "executionStatus": ExecutionStatus.RUNNING,
            "startedAt": datetime.now(tz=timezone.utc),
        },
    )
    if count == 0:
        return None

    res = await AgentGraphExecution.prisma().find_unique(
        where={"id": graph_exec_id},
        include=GRAPH_EXECUTION_INCLUDE,
    )
    if not res:
        raise ValueError(f"Graph execution #{graph_exec_id} not found")

    return GraphExecution.from_db(res)
    return GraphExecution.from_db(res) if res else None


async def update_graph_execution_stats(
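Note: the rewrite above turns the start-time update into a compare-and-set. Because update_many filters on executionStatus == QUEUED, a concurrent worker that already flipped the row to RUNNING makes the call match zero rows, and the function returns None instead of double-starting the run. A sketch of the caller-side contract this implies (illustrative; the actual call site appears later in this diff, in manager.py, via the sync service client):

    exec_meta = await update_graph_execution_start_time(graph_exec_id)
    if exec_meta is None:
        # Lost the race: another worker already moved QUEUED -> RUNNING,
        # or the execution was cancelled before it started. Skip this run.
        return
    send_execution_update(exec_meta)  # broadcast the RUNNING transition
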
@@ -492,7 +519,8 @@ async def update_graph_execution_stats(
    data = stats.model_dump() if stats else {}
    if isinstance(data.get("error"), Exception):
        data["error"] = str(data["error"])
    res = await AgentGraphExecution.prisma().update(

    updated_count = await AgentGraphExecution.prisma().update_many(
        where={
            "id": graph_exec_id,
            "OR": [
@@ -504,10 +532,15 @@
            "executionStatus": status,
            "stats": Json(data),
        },
    )
    if updated_count == 0:
        return None

    graph_exec = await AgentGraphExecution.prisma().find_unique_or_raise(
        where={"id": graph_exec_id},
        include=GRAPH_EXECUTION_INCLUDE,
    )

    return GraphExecution.from_db(res) if res else None
    return GraphExecution.from_db(graph_exec)


async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionStats):
@@ -643,7 +676,7 @@ async def get_latest_node_execution(
        where={
            "agentNodeId": node_id,
            "agentGraphExecutionId": graph_eid,
            "executionStatus": {"not": ExecutionStatus.INCOMPLETE},  # type: ignore
            "NOT": [{"executionStatus": ExecutionStatus.INCOMPLETE}],
        },
        order=[
            {"queuedTime": "desc"},
@@ -711,144 +744,6 @@ class ExecutionQueue(Generic[T]):
        return self.queue.empty()


# ------------------- Execution Utilities -------------------- #


LIST_SPLIT = "_$_"
DICT_SPLIT = "_#_"
OBJC_SPLIT = "_@_"


def parse_execution_output(output: BlockData, name: str) -> Any | None:
    """
    Extracts partial output data by name from a given BlockData.

    The function supports extracting data from lists, dictionaries, and objects
    using specific naming conventions:
    - For lists: <output_name>_$_<index>
    - For dictionaries: <output_name>_#_<key>
    - For objects: <output_name>_@_<attribute>

    Args:
        output (BlockData): A tuple containing the output name and data.
        name (str): The name used to extract specific data from the output.

    Returns:
        Any | None: The extracted data if found, otherwise None.

    Examples:
        >>> output = ("result", [10, 20, 30])
        >>> parse_execution_output(output, "result_$_1")
        20

        >>> output = ("config", {"key1": "value1", "key2": "value2"})
        >>> parse_execution_output(output, "config_#_key1")
        'value1'

        >>> class Sample:
        ...     attr1 = "value1"
        ...     attr2 = "value2"
        >>> output = ("object", Sample())
        >>> parse_execution_output(output, "object_@_attr1")
        'value1'
    """
    output_name, output_data = output

    if name == output_name:
        return output_data

    if name.startswith(f"{output_name}{LIST_SPLIT}"):
        index = int(name.split(LIST_SPLIT)[1])
        if not isinstance(output_data, list) or len(output_data) <= index:
            return None
        return output_data[int(name.split(LIST_SPLIT)[1])]

    if name.startswith(f"{output_name}{DICT_SPLIT}"):
        index = name.split(DICT_SPLIT)[1]
        if not isinstance(output_data, dict) or index not in output_data:
            return None
        return output_data[index]

    if name.startswith(f"{output_name}{OBJC_SPLIT}"):
        index = name.split(OBJC_SPLIT)[1]
        if isinstance(output_data, object) and hasattr(output_data, index):
            return getattr(output_data, index)
        return None

    return None


def merge_execution_input(data: BlockInput) -> BlockInput:
    """
    Merges dynamic input pins into a single list, dictionary, or object based on naming patterns.

    This function processes input keys that follow specific patterns to merge them into a unified structure:
    - `<input_name>_$_<index>` for list inputs.
    - `<input_name>_#_<index>` for dictionary inputs.
    - `<input_name>_@_<index>` for object inputs.

    Args:
        data (BlockInput): A dictionary containing input keys and their corresponding values.

    Returns:
        BlockInput: A dictionary with merged inputs.

    Raises:
        ValueError: If a list index is not an integer.

    Examples:
        >>> data = {
        ...     "list_$_0": "a",
        ...     "list_$_1": "b",
        ...     "dict_#_key1": "value1",
        ...     "dict_#_key2": "value2",
        ...     "object_@_attr1": "value1",
        ...     "object_@_attr2": "value2"
        ... }
        >>> merge_execution_input(data)
        {
            "list": ["a", "b"],
            "dict": {"key1": "value1", "key2": "value2"},
            "object": <MockObject attr1="value1" attr2="value2">
        }
    """

    # Merge all input with <input_name>_$_<index> into a single list.
    items = list(data.items())

    for key, value in items:
        if LIST_SPLIT not in key:
            continue
        name, index = key.split(LIST_SPLIT)
        if not index.isdigit():
            raise ValueError(f"Invalid key: {key}, #{index} index must be an integer.")

        data[name] = data.get(name, [])
        if int(index) >= len(data[name]):
            # Pad list with empty string on missing indices.
            data[name].extend([""] * (int(index) - len(data[name]) + 1))
        data[name][int(index)] = value

    # Merge all input with <input_name>_#_<index> into a single dict.
    for key, value in items:
        if DICT_SPLIT not in key:
            continue
        name, index = key.split(DICT_SPLIT)
        data[name] = data.get(name, {})
        data[name][index] = value

    # Merge all input with <input_name>_@_<index> into a single object.
    for key, value in items:
        if OBJC_SPLIT not in key:
            continue
        name, index = key.split(OBJC_SPLIT)
        if name not in data or not isinstance(data[name], object):
            data[name] = mock.MockObject()
        setattr(data[name], index, value)

    return data


# --------------------- Event Bus --------------------- #

@@ -10,9 +10,7 @@ from prisma.models import AgentGraph, AgentNode, AgentNodeLink, StoreListingVers
from prisma.types import (
    AgentGraphCreateInput,
    AgentGraphWhereInput,
    AgentGraphWhereInputRecursive1,
    AgentNodeCreateInput,
    AgentNodeIncludeFromAgentNodeRecursive1,
    AgentNodeLinkCreateInput,
)
from pydantic.fields import computed_field
@@ -655,14 +653,11 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
    graphs = await AgentGraph.prisma().find_many(
        where={
            "OR": [
                type_utils.typed(
                    AgentGraphWhereInputRecursive1,
                    {
                        "id": graph_id,
                        "version": graph_version,
                        "userId": graph.userId,  # Ensure the sub-graph is owned by the same user
                    },
                )
                {
                    "id": graph_id,
                    "version": graph_version,
                    "userId": graph.userId,  # Ensure the sub-graph is owned by the same user
                }
                for graph_id, graph_version in sub_graph_ids
            ]
        },
@@ -678,13 +673,7 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
async def get_connected_output_nodes(node_id: str) -> list[tuple[Link, Node]]:
    links = await AgentNodeLink.prisma().find_many(
        where={"agentNodeSourceId": node_id},
        include={
            "AgentNodeSink": {
                "include": cast(
                    AgentNodeIncludeFromAgentNodeRecursive1, AGENT_NODE_INCLUDE
                )
            }
        },
        include={"AgentNodeSink": {"include": AGENT_NODE_INCLUDE}},
    )
    return [
        (Link.from_db(link), NodeModel.from_db(link.AgentNodeSink))
@@ -930,12 +919,19 @@ async def migrate_llm_models(migrate_to: LlmModel):
    # Convert enum values to a list of strings for the SQL query
    enum_values = [v.value for v in LlmModel.__members__.values()]

    escaped_enum_values = repr(tuple(enum_values))  # hack but works
    query = f"""
    UPDATE "AgentNode"
    SET "constantInput" = jsonb_set("constantInput", '{{{path}}}', '"{migrate_to.value}"', true)
    WHERE "agentBlockId" = '{id}'
    AND "constantInput" ? '{path}'
    AND "constantInput"->>'{path}' NOT IN ({','.join(f"'{value}'" for value in enum_values)})
    SET "constantInput" = jsonb_set("constantInput", $1, $2, true)
    WHERE "agentBlockId" = $3
    AND "constantInput" ? $4
    AND "constantInput"->>$4 NOT IN {escaped_enum_values}
    """

    await db.execute_raw(query)
    await db.execute_raw(
        query,  # type: ignore - is supposed to be LiteralString
        "{" + path + "}",
        f'"{migrate_to.value}"',
        id,
        path,
    )
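Note: after this change the jsonb path, replacement value, block id, and key travel as bound parameters $1-$4, while the enum list is still interpolated via repr(tuple(enum_values)) because a variable-length IN list cannot be bound to a single placeholder. A hypothetical variant that binds the list as well, sketched as an assumption rather than what the commit does (whether execute_raw accepts array parameters depends on the client):

    query = """
    UPDATE "AgentNode"
    SET "constantInput" = jsonb_set("constantInput", $1, $2, true)
    WHERE "agentBlockId" = $3
      AND "constantInput" ? $4
      AND "constantInput"->>$4 <> ALL($5::text[])
    """
    # <> ALL(array) is the array-parameter equivalent of NOT IN (...).
    await db.execute_raw(
        query, "{" + path + "}", f'"{migrate_to.value}"', id, path, enum_values
    )
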
@@ -4,7 +4,6 @@ import prisma.enums
import prisma.types

from backend.blocks.io import IO_BLOCK_IDs
from backend.util.type import typed_cast

AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
    "Input": True,
@@ -14,13 +13,7 @@ AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
}

AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
    "Nodes": {
        "include": typed_cast(
            prisma.types.AgentNodeIncludeFromAgentNodeRecursive1,
            prisma.types.AgentNodeIncludeFromAgentNode,
            AGENT_NODE_INCLUDE,
        )
    }
    "Nodes": {"include": AGENT_NODE_INCLUDE}
}

EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
@@ -56,13 +49,7 @@ GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
            GRAPH_EXECUTION_INCLUDE_WITH_NODES["NodeExecutions"],
        ),
        "where": {
            "Node": typed_cast(
                prisma.types.AgentNodeRelationFilter,
                prisma.types.AgentNodeWhereInput,
                {
                    "AgentBlock": {"id": {"in": IO_BLOCK_IDs}},
                },
            ),
            "Node": {"is": {"AgentBlock": {"is": {"id": {"in": IO_BLOCK_IDs}}}}},
            "NOT": [{"executionStatus": prisma.enums.AgentExecutionStatus.INCOMPLETE}],
        },
    }
@@ -70,13 +57,7 @@


INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
    "AgentNodes": {
        "include": typed_cast(
            prisma.types.AgentNodeIncludeFromAgentNodeRecursive1,
            prisma.types.AgentNodeInclude,
            AGENT_NODE_INCLUDE,
        )
    }
    "AgentNodes": {"include": AGENT_NODE_INCLUDE}
}

@@ -8,7 +8,9 @@ from prisma.enums import OnboardingStep
from prisma.models import UserOnboarding
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput

from backend.data import db
from backend.data.block import get_blocks
from backend.data.credit import get_user_credit_model
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.server.v2.store.model import StoreAgentDetails
@@ -24,14 +26,19 @@ REASON_MAPPING: dict[str, list[str]] = {
POINTS_AGENT_COUNT = 50  # Number of agents to calculate points for
MIN_AGENT_COUNT = 2  # Minimum number of marketplace agents to enable onboarding

user_credit = get_user_credit_model()


class UserOnboardingUpdate(pydantic.BaseModel):
    completedSteps: Optional[list[OnboardingStep]] = None
    notificationDot: Optional[bool] = None
    notified: Optional[list[OnboardingStep]] = None
    usageReason: Optional[str] = None
    integrations: Optional[list[str]] = None
    otherIntegrations: Optional[str] = None
    selectedStoreListingVersionId: Optional[str] = None
    agentInput: Optional[dict[str, Any]] = None
    onboardingAgentExecutionId: Optional[str] = None


async def get_user_onboarding(user_id: str):
@@ -48,6 +55,20 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
    update: UserOnboardingUpdateInput = {}
    if data.completedSteps is not None:
        update["completedSteps"] = list(set(data.completedSteps))
        for step in (
            OnboardingStep.AGENT_NEW_RUN,
            OnboardingStep.GET_RESULTS,
            OnboardingStep.MARKETPLACE_ADD_AGENT,
            OnboardingStep.MARKETPLACE_RUN_AGENT,
            OnboardingStep.BUILDER_SAVE_AGENT,
            OnboardingStep.BUILDER_RUN_AGENT,
        ):
            if step in data.completedSteps:
                await reward_user(user_id, step)
    if data.notificationDot is not None:
        update["notificationDot"] = data.notificationDot
    if data.notified is not None:
        update["notified"] = list(set(data.notified))
    if data.usageReason is not None:
        update["usageReason"] = data.usageReason
    if data.integrations is not None:
@@ -58,16 +79,57 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
        update["selectedStoreListingVersionId"] = data.selectedStoreListingVersionId
    if data.agentInput is not None:
        update["agentInput"] = Json(data.agentInput)
    if data.onboardingAgentExecutionId is not None:
        update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId

    return await UserOnboarding.prisma().upsert(
        where={"userId": user_id},
        data={
            "create": {"userId": user_id, **update},  # type: ignore
            "create": {"userId": user_id, **update},
            "update": update,
        },
    )


async def reward_user(user_id: str, step: OnboardingStep):
    async with db.locked_transaction(f"usr_trx_{user_id}-reward"):
        reward = 0
        match step:
            # Reward user when they clicked New Run during onboarding
            # This is because they need credits before scheduling a run (next step)
            case OnboardingStep.AGENT_NEW_RUN:
                reward = 300
            case OnboardingStep.GET_RESULTS:
                reward = 300
            case OnboardingStep.MARKETPLACE_ADD_AGENT:
                reward = 100
            case OnboardingStep.MARKETPLACE_RUN_AGENT:
                reward = 100
            case OnboardingStep.BUILDER_SAVE_AGENT:
                reward = 100
            case OnboardingStep.BUILDER_RUN_AGENT:
                reward = 100

        if reward == 0:
            return

        onboarding = await get_user_onboarding(user_id)

        # Skip if already rewarded
        if step in onboarding.rewardedFor:
            return

        onboarding.rewardedFor.append(step)
        await user_credit.onboarding_reward(user_id, reward, step)
        await UserOnboarding.prisma().update(
            where={"userId": user_id},
            data={
                "completedSteps": list(set(onboarding.completedSteps + [step])),
                "rewardedFor": onboarding.rewardedFor,
            },
        )


def clean_and_split(text: str) -> list[str]:
    """
    Removes all special characters from a string, truncates it to 100 characters,

@@ -149,7 +149,7 @@ async def migrate_and_encrypt_user_integrations():
    logger.info(f"Migrating integration credentials for {len(users)} users")

    for user in users:
        raw_metadata = cast(UserMetadataRaw, user.metadata)
        raw_metadata = cast(dict, user.metadata)
        metadata = UserMetadata.model_validate(raw_metadata)

        # Get existing integrations data
@@ -165,7 +165,6 @@ async def migrate_and_encrypt_user_integrations():
        await update_user_integrations(user_id=user.id, data=integrations)

        # Remove from metadata
        raw_metadata = dict(raw_metadata)
        raw_metadata.pop("integration_credentials", None)
        raw_metadata.pop("integration_oauth_states", None)

@@ -1,11 +1,8 @@
import logging

from backend.data import db, redis
from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
    GraphExecution,
    NodeExecutionResult,
    RedisExecutionEventBus,
    create_graph_execution,
    get_graph_execution,
    get_incomplete_node_executions,
@@ -42,7 +39,7 @@ from backend.data.user import (
    update_user_integrations,
    update_user_metadata,
)
from backend.util.service import AppService, expose, exposed_run_and_wait
from backend.util.service import AppService, exposed_run_and_wait
from backend.util.settings import Config

config = Config()
@@ -57,21 +54,14 @@ async def _spend_credits(


class DatabaseManager(AppService):
    def __init__(self):
        super().__init__()
        self.execution_event_bus = RedisExecutionEventBus()

    def run_service(self) -> None:
        logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
        self.run_and_wait(db.connect())
        logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...")
        redis.connect()
        super().run_service()

    def cleanup(self):
        super().cleanup()
        logger.info(f"[{self.service_name}] ⏳ Disconnecting Redis...")
        redis.disconnect()
        logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
        self.run_and_wait(db.disconnect())

@@ -79,12 +69,6 @@ class DatabaseManager(AppService):
    def get_port(cls) -> int:
        return config.database_api_port

    @expose
    def send_execution_update(
        self, execution_result: GraphExecution | NodeExecutionResult
    ):
        self.execution_event_bus.publish(execution_result)

    # Executions
    get_graph_execution = exposed_run_and_wait(get_graph_execution)
    create_graph_execution = exposed_run_and_wait(create_graph_execution)

@@ -5,11 +5,14 @@ import os
import signal
import sys
import threading
import time
from concurrent.futures import Future, ProcessPoolExecutor
from contextlib import contextmanager
from multiprocessing.pool import AsyncResult, Pool
from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generator, TypeVar, cast

from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic
from redis.lock import Lock as RedisLock

from backend.blocks.io import AgentOutputBlock
@@ -30,43 +33,36 @@ from autogpt_libs.utils.cache import thread_cached

from backend.blocks.agent import AgentExecutorBlock
from backend.data import redis
from backend.data.block import (
    Block,
    BlockData,
    BlockInput,
    BlockSchema,
    BlockType,
    get_block,
)
from backend.data.block import BlockData, BlockInput, BlockSchema, get_block
from backend.data.execution import (
    ExecutionQueue,
    ExecutionStatus,
    GraphExecution,
    GraphExecutionEntry,
    NodeExecutionEntry,
    NodeExecutionResult,
    merge_execution_input,
    parse_execution_output,
)
from backend.data.graph import GraphModel, Link, Node
from backend.data.graph import Link, Node
from backend.executor.utils import (
    GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
    GRAPH_EXECUTION_QUEUE_NAME,
    CancelExecutionEvent,
    UsageTransactionMetadata,
    block_usage_cost,
    execution_usage_cost,
    get_execution_event_bus,
    get_execution_queue,
    parse_execution_output,
    validate_exec,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util import json
from backend.util.decorator import error_logged, time_measured
from backend.util.file import clean_exec_files
from backend.util.logging import configure_logging
from backend.util.process import set_service_name
from backend.util.service import (
    AppService,
    close_service_client,
    expose,
    get_service_client,
)
from backend.util.process import AppProcess, set_service_name
from backend.util.service import close_service_client, get_service_client
from backend.util.settings import Settings
from backend.util.type import convert

logger = logging.getLogger(__name__)
settings = Settings()
@@ -152,7 +148,7 @@ def execute_node(
    def update_execution_status(status: ExecutionStatus) -> NodeExecutionResult:
        """Sets status and fetches+broadcasts the latest state of the node execution"""
        exec_update = db_client.update_node_execution_status(node_exec_id, status)
        db_client.send_execution_update(exec_update)
        send_execution_update(exec_update)
        return exec_update

    node = db_client.get_node(node_id)
@@ -288,7 +284,7 @@ def _enqueue_next_nodes(
            exec_update = db_client.update_node_execution_status(
                node_exec_id, ExecutionStatus.QUEUED, data
            )
            db_client.send_execution_update(exec_update)
            send_execution_update(exec_update)
            return NodeExecutionEntry(
                user_id=user_id,
                graph_exec_id=graph_exec_id,
@@ -400,60 +396,6 @@
    ]


def validate_exec(
    node: Node,
    data: BlockInput,
    resolve_input: bool = True,
) -> tuple[BlockInput | None, str]:
    """
    Validate the input data for a node execution.

    Args:
        node: The node to execute.
        data: The input data for the node execution.
        resolve_input: Whether to resolve dynamic pins into dict/list/object.

    Returns:
        A tuple of the validated data and the block name.
        If the data is invalid, the first element will be None, and the second element
        will be an error message.
        If the data is valid, the first element will be the resolved input data, and
        the second element will be the block name.
    """
    node_block: Block | None = get_block(node.block_id)
    if not node_block:
        return None, f"Block for {node.block_id} not found."
    schema = node_block.input_schema

    # Convert non-matching data types to the expected input schema.
    for name, data_type in schema.__annotations__.items():
        if (value := data.get(name)) and (type(value) is not data_type):
            data[name] = convert(value, data_type)

    # Input data (without default values) should contain all required fields.
    error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
    if missing_links := schema.get_missing_links(data, node.input_links):
        return None, f"{error_prefix} unpopulated links {missing_links}"

    # Merge input data with default values and resolve dynamic dict/list/object pins.
    input_default = schema.get_input_defaults(node.input_default)
    data = {**input_default, **data}
    if resolve_input:
        data = merge_execution_input(data)

    # Input data post-merge should contain all required fields from the schema.
    if missing_input := schema.get_missing_input(data):
        return None, f"{error_prefix} missing input {missing_input}"

    # Last validation: Validate the input values against the schema.
    if error := schema.get_mismatch_error(data):
        error_message = f"{error_prefix} {error}"
        logger.error(error_message)
        return None, error_message

    return data, node_block.name


class Executor:
    """
    This class contains event handlers for the process pool executor events.
@@ -633,7 +575,13 @@
        exec_meta = cls.db_client.update_graph_execution_start_time(
            graph_exec.graph_exec_id
        )
        cls.db_client.send_execution_update(exec_meta)
        if exec_meta is None:
            logger.warning(
                f"Skipped graph execution {graph_exec.graph_exec_id}, the graph execution is not found or not currently in the QUEUED state."
            )
            return

        send_execution_update(exec_meta)
        timing_info, (exec_stats, status, error) = cls._on_graph_execution(
            graph_exec, cancel, log_metadata
        )
@@ -646,7 +594,7 @@
            status=status,
            stats=exec_stats,
        ):
            cls.db_client.send_execution_update(graph_exec_result)
            send_execution_update(graph_exec_result)

        cls._handle_agent_run_notif(graph_exec, exec_stats)

@@ -759,7 +707,7 @@
            status=execution_status,
            stats=execution_stats,
        ):
            cls.db_client.send_execution_update(_graph_exec)
            send_execution_update(_graph_exec)
        else:
            logger.error(
                "Callback for "
@@ -810,7 +758,7 @@
        exec_update = cls.db_client.update_node_execution_status(
            node_exec_id, execution_status
        )
        cls.db_client.send_execution_update(exec_update)
        send_execution_update(exec_update)

        cls._handle_low_balance_notif(
            graph_exec.user_id,
@@ -927,22 +875,25 @@
        )


class ExecutionManager(AppService):
class ExecutionManager(AppProcess):
    def __init__(self):
        super().__init__()
        self.pool_size = settings.config.num_graph_workers
        self.queue = ExecutionQueue[GraphExecutionEntry]()
        self.running = True
        self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}

    @classmethod
    def get_port(cls) -> int:
        return settings.config.execution_manager_port

    def run_service(self):
        from backend.integrations.credentials_store import IntegrationCredentialsStore

        self.credentials_store = IntegrationCredentialsStore()
    def run(self):
        while True:
            try:
                self._run()
            except Exception:
                logger.exception(f"[{self.service_name}] error in graph executor loop")

    def _run(self):
        logger.info(f"[{self.service_name}] ⏳ Spawn max-{self.pool_size} workers...")
        self.executor = ProcessPoolExecutor(
            max_workers=self.pool_size,
@@ -952,25 +903,103 @@
        logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...")
        redis.connect()

        sync_manager = multiprocessing.Manager()
        logger.info(f"[{self.service_name}] Ready to consume messages...")
        while True:
            graph_exec_data = self.queue.get()
            graph_exec_id = graph_exec_data.graph_exec_id
            logger.debug(
                f"[ExecutionManager] Dispatching graph execution {graph_exec_id}"
            channel = get_execution_queue().get_channel()

            # cancel graph execution requests
            method_frame, _, body = channel.basic_get(
                queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
                auto_ack=True,
            )
            cancel_event = sync_manager.Event()
            future = self.executor.submit(
                Executor.on_graph_execution, graph_exec_data, cancel_event
            if method_frame:
                self._handle_cancel_message(body)

            # start graph execution requests
            method_frame, _, body = channel.basic_get(
                queue=GRAPH_EXECUTION_QUEUE_NAME,
                auto_ack=False,
            )
            self.active_graph_runs[graph_exec_id] = (future, cancel_event)
            future.add_done_callback(
                lambda _: self.active_graph_runs.pop(graph_exec_id, None)
            if method_frame:
                self._handle_run_message(channel, method_frame, body)
            else:
                time.sleep(0.2)

    def _handle_cancel_message(self, body: bytes):
        try:
            request = CancelExecutionEvent.model_validate_json(body)
            graph_exec_id = request.graph_exec_id
            if not graph_exec_id:
                logger.warning(
                    f"[{self.service_name}] Cancel message missing 'graph_exec_id'"
                )
                return
            if graph_exec_id not in self.active_graph_runs:
                logger.debug(
                    f"[{self.service_name}] Cancel received for {graph_exec_id} but not active."
                )
                return

            _, cancel_event = self.active_graph_runs[graph_exec_id]
            logger.info(f"[{self.service_name}] Received cancel for {graph_exec_id}")
            if not cancel_event.is_set():
                cancel_event.set()
            else:
                logger.debug(
                    f"[{self.service_name}] Cancel already set for {graph_exec_id}"
                )

        except Exception as e:
            logger.exception(f"Error handling cancel message: {e}")

    def _handle_run_message(
        self, channel: BlockingChannel, method_frame: Basic.GetOk, body: bytes
    ):
        delivery_tag = method_frame.delivery_tag
        try:
            graph_exec_entry = GraphExecutionEntry.model_validate_json(body)
        except Exception as e:
            logger.error(f"[{self.service_name}] Could not parse run message: {e}")
            channel.basic_nack(delivery_tag, requeue=False)
            return

        graph_exec_id = graph_exec_entry.graph_exec_id
        logger.info(
            f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}"
        )
        if graph_exec_id in self.active_graph_runs:
            logger.warning(
                f"[{self.service_name}] Graph {graph_exec_id} already running; rejecting duplicate run."
            )
            channel.basic_nack(delivery_tag, requeue=False)
            return

        cancel_event = multiprocessing.Manager().Event()
        future = self.executor.submit(
            Executor.on_graph_execution, graph_exec_entry, cancel_event
        )
        self.active_graph_runs[graph_exec_id] = (future, cancel_event)

        def _on_run_done(f: Future):
            logger.info(f"[{self.service_name}] Run completed for {graph_exec_id}")
            try:
                channel.basic_ack(delivery_tag)
                self.active_graph_runs.pop(graph_exec_id, None)
                if f.exception():
                    logger.error(
                        f"[{self.service_name}] Execution for {graph_exec_id} failed: {f.exception()}"
                    )
            except Exception as e:
                logger.error(f"[{self.service_name}] Error acknowledging message: {e}")

        future.add_done_callback(_on_run_done)

    def cleanup(self):
        super().cleanup()

        logger.info(f"[{self.service_name}] ⏳ Shutting down service loop...")
        self.running = False

        logger.info(f"[{self.service_name}] ⏳ Shutting down graph executor pool...")
        self.executor.shutdown(cancel_futures=True)
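Note: the consume loop above fetches cancel requests with auto_ack=True but run requests with auto_ack=False, acking a run only inside the future's done-callback. That gives runs at-least-once delivery: if the worker dies mid-run, the broker re-delivers the message. A minimal standalone sketch of that pattern with pika (the queue name matches the constants in this diff; the handler is hypothetical):

    import pika

    connection = pika.BlockingConnection(pika.ConnectionParameters("localhost"))
    channel = connection.channel()
    channel.queue_declare(queue="graph_execution_queue", durable=True)

    def handle(body: bytes) -> None:  # hypothetical work function
        print("processing", body)

    method, _props, body = channel.basic_get(
        queue="graph_execution_queue", auto_ack=False
    )
    if method is not None:
        try:
            handle(body)
            channel.basic_ack(method.delivery_tag)  # ack only after the work succeeds
        except Exception:
            # requeue=False drops (or dead-letters) a poison message instead of looping
            channel.basic_nack(method.delivery_tag, requeue=False)
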
@@ -981,175 +1010,6 @@
    def db_client(self) -> "DatabaseManager":
        return get_db_client()

    @expose
    def add_execution(
        self,
        graph_id: str,
        data: BlockInput,
        user_id: str,
        graph_version: Optional[int] = None,
        preset_id: str | None = None,
    ) -> GraphExecutionEntry:
        graph: GraphModel | None = self.db_client.get_graph(
            graph_id=graph_id, user_id=user_id, version=graph_version
        )
        if not graph:
            raise ValueError(f"Graph #{graph_id} not found.")

        graph.validate_graph(for_run=True)
        self._validate_node_input_credentials(graph, user_id)

        nodes_input = []
        for node in graph.starting_nodes:
            input_data = {}
            block = node.block

            # Note block should never be executed.
            if block.block_type == BlockType.NOTE:
                continue

            # Extract request input data, and assign it to the input pin.
            if block.block_type == BlockType.INPUT:
                input_name = node.input_default.get("name")
                if input_name and input_name in data:
                    input_data = {"value": data[input_name]}

            # Extract webhook payload, and assign it to the input pin
            webhook_payload_key = f"webhook_{node.webhook_id}_payload"
            if (
                block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
                and node.webhook_id
            ):
                if webhook_payload_key not in data:
                    raise ValueError(
                        f"Node {block.name} #{node.id} webhook payload is missing"
                    )
                input_data = {"payload": data[webhook_payload_key]}

            input_data, error = validate_exec(node, input_data)
            if input_data is None:
                raise ValueError(error)
            else:
                nodes_input.append((node.id, input_data))

        if not nodes_input:
            raise ValueError(
                "No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
            )

        graph_exec = self.db_client.create_graph_execution(
            graph_id=graph_id,
            graph_version=graph.version,
            nodes_input=nodes_input,
            user_id=user_id,
            preset_id=preset_id,
        )
        self.db_client.send_execution_update(graph_exec)

        graph_exec_entry = GraphExecutionEntry(
            user_id=user_id,
            graph_id=graph_id,
            graph_version=graph_version or 0,
            graph_exec_id=graph_exec.id,
            start_node_execs=[
                NodeExecutionEntry(
                    user_id=user_id,
                    graph_exec_id=node_exec.graph_exec_id,
                    graph_id=node_exec.graph_id,
                    node_exec_id=node_exec.node_exec_id,
                    node_id=node_exec.node_id,
                    block_id=node_exec.block_id,
                    data=node_exec.input_data,
                )
                for node_exec in graph_exec.node_executions
            ],
        )
        self.queue.add(graph_exec_entry)

        return graph_exec_entry

    @expose
    def cancel_execution(self, graph_exec_id: str) -> None:
        """
        Mechanism:
        1. Set the cancel event
        2. Graph executor's cancel handler thread detects the event, terminates workers,
           reinitializes worker pool, and returns.
        3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
        """
        if graph_exec_id not in self.active_graph_runs:
            logger.warning(
                f"Graph execution #{graph_exec_id} not active/running: "
                "possibly already completed/cancelled."
            )
        else:
            future, cancel_event = self.active_graph_runs[graph_exec_id]
            if not cancel_event.is_set():
                cancel_event.set()
                future.result()

        # Update the status of the graph & node executions
        self.db_client.update_graph_execution_stats(
            graph_exec_id,
            ExecutionStatus.TERMINATED,
        )
        node_execs = self.db_client.get_node_execution_results(
            graph_exec_id=graph_exec_id,
            statuses=[
                ExecutionStatus.QUEUED,
                ExecutionStatus.RUNNING,
                ExecutionStatus.INCOMPLETE,
            ],
        )
        self.db_client.update_node_execution_status_batch(
            [node_exec.node_exec_id for node_exec in node_execs],
            ExecutionStatus.TERMINATED,
        )
        for node_exec in node_execs:
            node_exec.status = ExecutionStatus.TERMINATED
            self.db_client.send_execution_update(node_exec)

    def _validate_node_input_credentials(self, graph: GraphModel, user_id: str):
        """Checks all credentials for all nodes of the graph"""

        for node in graph.nodes:
            block = node.block

            # Find any fields of type CredentialsMetaInput
            credentials_fields = cast(
                type[BlockSchema], block.input_schema
            ).get_credentials_fields()
            if not credentials_fields:
                continue

            for field_name, credentials_meta_type in credentials_fields.items():
                credentials_meta = credentials_meta_type.model_validate(
                    node.input_default[field_name]
                )
                # Fetch the corresponding Credentials and perform sanity checks
                credentials = self.credentials_store.get_creds_by_id(
                    user_id, credentials_meta.id
                )
                if not credentials:
                    raise ValueError(
                        f"Unknown credentials #{credentials_meta.id} "
                        f"for node #{node.id} input '{field_name}'"
                    )
                if (
                    credentials.provider != credentials_meta.provider
                    or credentials.type != credentials_meta.type
                ):
                    logger.warning(
                        f"Invalid credentials #{credentials.id} for node #{node.id}: "
                        "type/provider mismatch: "
                        f"{credentials_meta.type}<>{credentials.type};"
                        f"{credentials_meta.provider}<>{credentials.provider}"
                    )
                    raise ValueError(
                        f"Invalid credentials #{credentials.id} for node #{node.id}: "
                        "type/provider mismatch"
                    )


# ------- UTILITIES ------- #

@@ -1168,6 +1028,10 @@ def get_notification_service() -> "NotificationManager":
    return get_service_client(NotificationManager)


def send_execution_update(entry: GraphExecution | NodeExecutionResult):
    return get_execution_event_bus().publish(entry)


@contextmanager
def synchronized(key: str, timeout: int = 60):
    lock: RedisLock = redis.get_redis().lock(f"lock:{key}", timeout=timeout)

@@ -16,7 +16,7 @@ from pydantic import BaseModel
from sqlalchemy import MetaData, create_engine

from backend.data.block import BlockInput
from backend.executor.manager import ExecutionManager
from backend.executor import utils as execution_utils
from backend.notifications.notifications import NotificationManager
from backend.util.service import AppService, expose, get_service_client
from backend.util.settings import Config
@@ -57,11 +57,6 @@ def job_listener(event):
    log(f"Job {event.job_id} completed successfully.")


@thread_cached
def get_execution_client() -> ExecutionManager:
    return get_service_client(ExecutionManager)


@thread_cached
def get_notification_client():
    from backend.notifications import NotificationManager
@@ -73,7 +68,7 @@ def execute_graph(**kwargs):
    args = ExecutionJobArgs(**kwargs)
    try:
        log(f"Executing recurring job for graph #{args.graph_id}")
        get_execution_client().add_execution(
        execution_utils.add_graph_execution(
            graph_id=args.graph_id,
            data=args.input_data,
            user_id=args.user_id,
@@ -164,11 +159,6 @@ class Scheduler(AppService):
    def db_pool_size(cls) -> int:
        return config.scheduler_db_pool_size

    @property
    @thread_cached
    def execution_client(self) -> ExecutionManager:
        return get_service_client(ExecutionManager)

    @property
    @thread_cached
    def notification_client(self) -> NotificationManager:

@@ -1,11 +1,70 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import Block, BlockInput
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockData,
|
||||
BlockInput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.cost import BlockCostType
|
||||
from backend.data.execution import GraphExecutionEntry, RedisExecutionEventBus
|
||||
from backend.data.graph import GraphModel, Node
|
||||
from backend.data.rabbitmq import (
|
||||
Exchange,
|
||||
ExchangeType,
|
||||
Queue,
|
||||
RabbitMQConfig,
|
||||
SyncRabbitMQ,
|
||||
)
|
||||
from backend.util.mock import MockObject
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Config
|
||||
from backend.util.type import convert
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManager
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
config = Config()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ============ Resource Helpers ============ #
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_execution_event_bus() -> RedisExecutionEventBus:
|
||||
return RedisExecutionEventBus()
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_execution_queue() -> SyncRabbitMQ:
|
||||
client = SyncRabbitMQ(create_execution_queue_config())
|
||||
client.connect()
|
||||
return client
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_integration_credentials_store() -> "IntegrationCredentialsStore":
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
return IntegrationCredentialsStore()
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_db_client() -> "DatabaseManager":
|
||||
from backend.executor import DatabaseManager
|
||||
|
||||
return get_service_client(DatabaseManager)
|
||||
|
||||
|
||||
# ============ Execution Cost Helpers ============ #
|
||||
|
||||
|
||||
class UsageTransactionMetadata(BaseModel):
|
||||
@@ -95,3 +154,398 @@ def _is_cost_filter_match(cost_filter: BlockInput, input_data: BlockInput) -> bo
|
||||
or (input_data.get(k) and _is_cost_filter_match(v, input_data[k]))
|
||||
for k, v in cost_filter.items()
|
||||
)
|
||||
|
||||
|
||||
# ============ Execution Input Helpers ============ #
|
||||
|
||||
LIST_SPLIT = "_$_"
|
||||
DICT_SPLIT = "_#_"
|
||||
OBJC_SPLIT = "_@_"
|
||||
|
||||
|
||||
def parse_execution_output(output: BlockData, name: str) -> Any | None:
|
||||
"""
|
||||
Extracts partial output data by name from a given BlockData.
|
||||
|
||||
The function supports extracting data from lists, dictionaries, and objects
|
||||
using specific naming conventions:
|
||||
- For lists: <output_name>_$_<index>
|
||||
- For dictionaries: <output_name>_#_<key>
|
||||
- For objects: <output_name>_@_<attribute>
|
||||
|
||||
Args:
|
||||
output (BlockData): A tuple containing the output name and data.
|
||||
name (str): The name used to extract specific data from the output.
|
||||
|
||||
Returns:
|
||||
Any | None: The extracted data if found, otherwise None.
|
||||
|
||||
Examples:
|
||||
>>> output = ("result", [10, 20, 30])
|
||||
>>> parse_execution_output(output, "result_$_1")
|
||||
20
|
||||
|
||||
>>> output = ("config", {"key1": "value1", "key2": "value2"})
|
||||
>>> parse_execution_output(output, "config_#_key1")
|
||||
'value1'
|
||||
|
||||
>>> class Sample:
|
||||
... attr1 = "value1"
|
||||
... attr2 = "value2"
|
||||
>>> output = ("object", Sample())
|
||||
>>> parse_execution_output(output, "object_@_attr1")
|
||||
'value1'
|
||||
"""
|
||||
output_name, output_data = output
|
||||
|
||||
if name == output_name:
|
||||
return output_data
|
||||
|
||||
if name.startswith(f"{output_name}{LIST_SPLIT}"):
|
||||
index = int(name.split(LIST_SPLIT)[1])
|
||||
if not isinstance(output_data, list) or len(output_data) <= index:
|
||||
return None
|
||||
return output_data[int(name.split(LIST_SPLIT)[1])]
|
||||
|
||||
if name.startswith(f"{output_name}{DICT_SPLIT}"):
|
||||
index = name.split(DICT_SPLIT)[1]
|
||||
if not isinstance(output_data, dict) or index not in output_data:
|
||||
return None
|
||||
return output_data[index]
|
||||
|
||||
if name.startswith(f"{output_name}{OBJC_SPLIT}"):
|
||||
index = name.split(OBJC_SPLIT)[1]
|
||||
if isinstance(output_data, object) and hasattr(output_data, index):
|
||||
return getattr(output_data, index)
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def validate_exec(
|
||||
node: Node,
|
||||
data: BlockInput,
|
||||
resolve_input: bool = True,
|
||||
) -> tuple[BlockInput | None, str]:
|
||||
"""
|
||||
Validate the input data for a node execution.
|
||||
|
||||
Args:
|
||||
node: The node to execute.
|
||||
data: The input data for the node execution.
|
||||
resolve_input: Whether to resolve dynamic pins into dict/list/object.
|
||||
|
||||
Returns:
|
||||
A tuple of the validated data and the block name.
|
||||
If the data is invalid, the first element will be None, and the second element
|
||||
will be an error message.
|
||||
If the data is valid, the first element will be the resolved input data, and
|
||||
the second element will be the block name.
|
||||
"""
|
||||
node_block: Block | None = get_block(node.block_id)
|
||||
if not node_block:
|
||||
return None, f"Block for {node.block_id} not found."
|
||||
schema = node_block.input_schema
|
||||
|
||||
# Convert non-matching data types to the expected input schema.
|
||||
for name, data_type in schema.__annotations__.items():
|
||||
if (value := data.get(name)) and (type(value) is not data_type):
|
||||
data[name] = convert(value, data_type)
|
||||
|
||||
# Input data (without default values) should contain all required fields.
|
||||
error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
|
||||
if missing_links := schema.get_missing_links(data, node.input_links):
|
||||
return None, f"{error_prefix} unpopulated links {missing_links}"
|
||||
|
||||
# Merge input data with default values and resolve dynamic dict/list/object pins.
|
||||
input_default = schema.get_input_defaults(node.input_default)
|
||||
data = {**input_default, **data}
|
||||
if resolve_input:
|
||||
data = merge_execution_input(data)
|
||||
|
||||
# Input data post-merge should contain all required fields from the schema.
|
||||
if missing_input := schema.get_missing_input(data):
|
||||
return None, f"{error_prefix} missing input {missing_input}"
|
||||
|
||||
# Last validation: Validate the input values against the schema.
|
||||
if error := schema.get_mismatch_error(data):
|
||||
error_message = f"{error_prefix} {error}"
|
||||
logger.error(error_message)
|
||||
return None, error_message
|
||||
|
||||
return data, node_block.name
|
||||
|
||||
|
||||
def merge_execution_input(data: BlockInput) -> BlockInput:
|
||||
"""
|
||||
Merges dynamic input pins into a single list, dictionary, or object based on naming patterns.
|
||||
|
||||
This function processes input keys that follow specific patterns to merge them into a unified structure:
|
||||
- `<input_name>_$_<index>` for list inputs.
|
||||
- `<input_name>_#_<index>` for dictionary inputs.
|
||||
- `<input_name>_@_<index>` for object inputs.
|
||||
|
||||
Args:
|
||||
data (BlockInput): A dictionary containing input keys and their corresponding values.
|
||||
|
||||
Returns:
|
||||
BlockInput: A dictionary with merged inputs.
|
||||
|
||||
Raises:
|
||||
ValueError: If a list index is not an integer.
|
||||
|
||||
Examples:
|
||||
>>> data = {
|
||||
... "list_$_0": "a",
|
||||
... "list_$_1": "b",
|
||||
... "dict_#_key1": "value1",
|
||||
... "dict_#_key2": "value2",
|
||||
... "object_@_attr1": "value1",
|
||||
... "object_@_attr2": "value2"
|
||||
... }
|
||||
>>> merge_execution_input(data)
|
||||
{
|
||||
"list": ["a", "b"],
|
||||
"dict": {"key1": "value1", "key2": "value2"},
|
||||
"object": <MockObject attr1="value1" attr2="value2">
|
||||
}
|
||||
"""
|
||||
|
||||
# Merge all input with <input_name>_$_<index> into a single list.
|
||||
items = list(data.items())
|
||||
|
||||
for key, value in items:
|
||||
if LIST_SPLIT not in key:
|
||||
continue
|
||||
name, index = key.split(LIST_SPLIT)
|
||||
if not index.isdigit():
|
||||
raise ValueError(f"Invalid key: {key}, #{index} index must be an integer.")
|
||||
|
||||
data[name] = data.get(name, [])
|
||||
if int(index) >= len(data[name]):
|
||||
# Pad list with empty string on missing indices.
|
||||
data[name].extend([""] * (int(index) - len(data[name]) + 1))
|
||||
data[name][int(index)] = value
|
||||
|
||||
# Merge all input with <input_name>_#_<index> into a single dict.
|
||||
for key, value in items:
|
||||
if DICT_SPLIT not in key:
|
||||
continue
|
||||
name, index = key.split(DICT_SPLIT)
|
||||
data[name] = data.get(name, {})
|
||||
data[name][index] = value
|
||||
|
||||
# Merge all input with <input_name>_@_<index> into a single object.
|
||||
for key, value in items:
|
||||
if OBJC_SPLIT not in key:
|
||||
continue
|
||||
name, index = key.split(OBJC_SPLIT)
|
||||
if name not in data or not isinstance(data[name], object):
|
||||
data[name] = MockObject()
|
||||
setattr(data[name], index, value)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _validate_node_input_credentials(graph: GraphModel, user_id: str):
    """Checks all credentials for all nodes of the graph"""

    for node in graph.nodes:
        block = node.block

        # Find any fields of type CredentialsMetaInput
        credentials_fields = cast(
            type[BlockSchema], block.input_schema
        ).get_credentials_fields()
        if not credentials_fields:
            continue

        for field_name, credentials_meta_type in credentials_fields.items():
            credentials_meta = credentials_meta_type.model_validate(
                node.input_default[field_name]
            )
            # Fetch the corresponding Credentials and perform sanity checks
            credentials = get_integration_credentials_store().get_creds_by_id(
                user_id, credentials_meta.id
            )
            if not credentials:
                raise ValueError(
                    f"Unknown credentials #{credentials_meta.id} "
                    f"for node #{node.id} input '{field_name}'"
                )
            if (
                credentials.provider != credentials_meta.provider
                or credentials.type != credentials_meta.type
            ):
                logger.warning(
                    f"Invalid credentials #{credentials.id} for node #{node.id}: "
                    "type/provider mismatch: "
                    f"{credentials_meta.type}<>{credentials.type};"
                    f"{credentials_meta.provider}<>{credentials.provider}"
                )
                raise ValueError(
                    f"Invalid credentials #{credentials.id} for node #{node.id}: "
                    "type/provider mismatch"
                )

def construct_node_execution_input(
    graph: GraphModel,
    user_id: str,
    data: BlockInput,
) -> list[tuple[str, BlockInput]]:
    """
    Validates and prepares the input data for executing a graph.
    This function checks the graph for starting nodes, validates the input data
    against the schema, and resolves dynamic input pins into a single list,
    dictionary, or object.

    Args:
        graph (GraphModel): The graph model to execute.
        user_id (str): The ID of the user executing the graph.
        data (BlockInput): The input data for the graph execution.

    Returns:
        list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
            the corresponding input data for that node.
    """
    graph.validate_graph(for_run=True)
    _validate_node_input_credentials(graph, user_id)

    nodes_input = []
    for node in graph.starting_nodes:
        input_data = {}
        block = node.block

        # Note blocks should never be executed.
        if block.block_type == BlockType.NOTE:
            continue

        # Extract request input data, and assign it to the input pin.
        if block.block_type == BlockType.INPUT:
            input_name = node.input_default.get("name")
            if input_name and input_name in data:
                input_data = {"value": data[input_name]}

        # Extract the webhook payload, and assign it to the input pin.
        webhook_payload_key = f"webhook_{node.webhook_id}_payload"
        if (
            block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
            and node.webhook_id
        ):
            if webhook_payload_key not in data:
                raise ValueError(
                    f"Node {block.name} #{node.id} webhook payload is missing"
                )
            input_data = {"payload": data[webhook_payload_key]}

        input_data, error = validate_exec(node, input_data)
        if input_data is None:
            raise ValueError(error)
        else:
            nodes_input.append((node.id, input_data))

    if not nodes_input:
        raise ValueError(
            "No starting nodes found for the graph; make sure an AgentInput block or blocks with no inbound links are present as starting nodes."
        )

    return nodes_input

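# Illustrative sketch (not part of the diff): the shape of the `data` mapping
# this function consumes. Input blocks are fed by their configured input name;
# webhook-triggered nodes are fed under a "webhook_<webhook_id>_payload" key.
# The "topic" input name and "wh-123" webhook ID are hypothetical.
example_data: BlockInput = {
    "topic": "AI news",  # consumed by an AgentInput node named "topic"
    "webhook_wh-123_payload": {"event": "push"},  # consumed by the node with webhook_id "wh-123"
}
# nodes_input = construct_node_execution_input(graph, user_id, example_data)
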
# ============ Execution Queue Helpers ============ #


class CancelExecutionEvent(BaseModel):
    graph_exec_id: str


GRAPH_EXECUTION_EXCHANGE = Exchange(
    name="graph_execution",
    type=ExchangeType.DIRECT,
    durable=True,
    auto_delete=False,
)
GRAPH_EXECUTION_QUEUE_NAME = "graph_execution_queue"
GRAPH_EXECUTION_ROUTING_KEY = "graph_execution.run"

GRAPH_EXECUTION_CANCEL_EXCHANGE = Exchange(
    name="graph_execution_cancel",
    type=ExchangeType.FANOUT,
    durable=True,
    auto_delete=True,
)
GRAPH_EXECUTION_CANCEL_QUEUE_NAME = "graph_execution_cancel_queue"


def create_execution_queue_config() -> RabbitMQConfig:
    """
    Define two exchanges and queues:
    - 'graph_execution' (DIRECT) for run tasks.
    - 'graph_execution_cancel' (FANOUT) for cancel requests.
    """
    run_queue = Queue(
        name=GRAPH_EXECUTION_QUEUE_NAME,
        exchange=GRAPH_EXECUTION_EXCHANGE,
        routing_key=GRAPH_EXECUTION_ROUTING_KEY,
        durable=True,
        auto_delete=False,
    )
    cancel_queue = Queue(
        name=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
        exchange=GRAPH_EXECUTION_CANCEL_EXCHANGE,
        routing_key="",  # not used for FANOUT
        durable=True,
        auto_delete=False,
    )
    return RabbitMQConfig(
        vhost="/",
        exchanges=[GRAPH_EXECUTION_EXCHANGE, GRAPH_EXECUTION_CANCEL_EXCHANGE],
        queues=[run_queue, cancel_queue],
    )

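# Illustrative sketch (not part of the diff): consuming the topology declared
# above with plain aio_pika. This is an assumption for illustration only; the
# actual services use the repo's own AsyncRabbitMQ wrapper, and the AMQP URL
# here is a placeholder. DIRECT routes run tasks to the single shared work
# queue; FANOUT broadcasts cancel events to every bound consumer.
import aio_pika


async def consume_cancel_events() -> None:
    connection = await aio_pika.connect_robust("amqp://guest:guest@localhost/")
    channel = await connection.channel()
    exchange = await channel.declare_exchange(
        "graph_execution_cancel",
        aio_pika.ExchangeType.FANOUT,
        durable=True,
        auto_delete=True,
    )
    # FANOUT ignores routing keys: every bound queue receives every event.
    queue = await channel.declare_queue("graph_execution_cancel_queue", durable=True)
    await queue.bind(exchange)
    async with queue.iterator() as messages:
        async for message in messages:
            async with message.process():
                event = CancelExecutionEvent.model_validate_json(message.body)
                print(f"cancel requested for graph execution {event.graph_exec_id}")
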
def add_graph_execution(
    graph_id: str,
    data: BlockInput,
    user_id: str,
    graph_version: int | None = None,
    preset_id: str | None = None,
) -> GraphExecutionEntry:
    """
    Adds a graph execution to the queue and returns the execution entry.

    Args:
        graph_id (str): The ID of the graph to execute.
        data (BlockInput): The input data for the graph execution.
        user_id (str): The ID of the user executing the graph.
        graph_version (int | None): The version of the graph to execute. Defaults to None.
        preset_id (str | None): The ID of the preset to use. Defaults to None.
    Returns:
        GraphExecutionEntry: The entry for the graph execution.
    Raises:
        ValueError: If the graph is not found or if there are validation errors.
    """
    graph: GraphModel | None = get_db_client().get_graph(
        graph_id=graph_id, user_id=user_id, version=graph_version
    )
    if not graph:
        raise ValueError(f"Graph #{graph_id} not found.")

    graph_exec = get_db_client().create_graph_execution(
        graph_id=graph_id,
        graph_version=graph.version,
        nodes_input=construct_node_execution_input(graph, user_id, data),
        user_id=user_id,
        preset_id=preset_id,
    )
    get_execution_event_bus().publish(graph_exec)

    graph_exec_entry = graph_exec.to_graph_execution_entry()
    get_execution_queue().publish_message(
        routing_key=GRAPH_EXECUTION_ROUTING_KEY,
        message=graph_exec_entry.model_dump_json(),
        exchange=GRAPH_EXECUTION_EXCHANGE,
    )

    return graph_exec_entry

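# Illustrative usage (not part of the diff): validate, persist, and enqueue a
# run for the latest version of a graph. The graph/user IDs and the "topic"
# input name are hypothetical.
entry = add_graph_execution(
    graph_id="11111111-2222-3333-4444-555555555555",
    data={"topic": "AI news"},
    user_id="user-123",
)
# `entry` is the GraphExecutionEntry that was just published to the run queue.
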
@@ -2,7 +2,6 @@ import logging
from collections import defaultdict
from typing import Annotated, Any, Dict, List, Optional, Sequence

from autogpt_libs.utils.cache import thread_cached
from fastapi import APIRouter, Body, Depends, HTTPException
from prisma.enums import AgentExecutionStatus, APIKeyPermission
from typing_extensions import TypedDict

@@ -13,17 +12,10 @@ from backend.data import graph as graph_db
from backend.data.api_key import APIKey
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.execution import NodeExecutionResult
from backend.executor import ExecutionManager
from backend.server.external.middleware import require_permission
from backend.util.service import get_service_client
from backend.server.routers import v1 as internal_api_routes
from backend.util.settings import Settings


@thread_cached
def execution_manager_client() -> ExecutionManager:
    return get_service_client(ExecutionManager)


settings = Settings()
logger = logging.getLogger(__name__)

@@ -98,18 +90,18 @@ def execute_graph_block(
    path="/graphs/{graph_id}/execute/{graph_version}",
    tags=["graphs"],
)
def execute_graph(
async def execute_graph(
    graph_id: str,
    graph_version: int,
    node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
    api_key: APIKey = Depends(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
) -> dict[str, Any]:
    try:
        graph_exec = execution_manager_client().add_execution(
            graph_id,
            graph_version=graph_version,
            data=node_input,
        graph_exec = await internal_api_routes.execute_graph(
            graph_id=graph_id,
            node_input=node_input,
            user_id=api_key.user_id,
            graph_version=graph_version,
        )
        return {"id": graph_exec.graph_exec_id}
    except Exception as e:

@@ -1,3 +1,4 @@
import asyncio
import logging
from typing import TYPE_CHECKING, Annotated, Literal

@@ -14,13 +15,12 @@ from backend.data.integrations import (
    wait_for_webhook_event,
)
from backend.data.model import Credentials, CredentialsType, OAuth2Credentials
from backend.executor.manager import ExecutionManager
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
from backend.server.routers import v1 as internal_api_routes
from backend.util.exceptions import NeedConfirmation, NotFoundError
from backend.util.service import get_service_client
from backend.util.settings import Settings

if TYPE_CHECKING:

@@ -309,19 +309,22 @@ async def webhook_ingress_generic(
    if not webhook.attached_nodes:
        return

    executor = get_service_client(ExecutionManager)
    executions = []
    for node in webhook.attached_nodes:
        logger.debug(f"Webhook-attached node: {node}")
        if not node.is_triggered_by_event_type(event_type):
            logger.debug(f"Node #{node.id} doesn't trigger on event {event_type}")
            continue
        logger.debug(f"Executing graph #{node.graph_id} node #{node.id}")
        executor.add_execution(
            graph_id=node.graph_id,
            graph_version=node.graph_version,
            data={f"webhook_{webhook_id}_payload": payload},
            user_id=webhook.user_id,
        executions.append(
            internal_api_routes.execute_graph(
                graph_id=node.graph_id,
                graph_version=node.graph_version,
                node_input={f"webhook_{webhook_id}_payload": payload},
                user_id=webhook.user_id,
            )
        )
    asyncio.gather(*executions)


@router.post("/webhooks/{webhook_id}/ping")

@@ -17,7 +17,6 @@ import backend.data.block
import backend.data.db
import backend.data.graph
import backend.data.user
import backend.server.integrations.router
import backend.server.routers.postmark.postmark
import backend.server.routers.v1
import backend.server.v2.admin.store_admin_routes

@@ -156,7 +155,7 @@ class AgentServer(backend.util.service.AppProcess):
        graph_version: Optional[int] = None,
        node_input: Optional[dict[str, Any]] = None,
    ):
        return backend.server.routers.v1.execute_graph(
        return await backend.server.routers.v1.execute_graph(
            user_id=user_id,
            graph_id=graph_id,
            graph_version=graph_version,

@@ -275,7 +274,9 @@ class AgentServer(backend.util.service.AppProcess):
        provider: ProviderName,
        credentials: Credentials,
    ) -> Credentials:
        return backend.server.integrations.router.create_credentials(
        from backend.server.integrations.router import create_credentials

        return create_credentials(
            user_id=user_id, provider=provider, credentials=credentials
        )

@@ -2,7 +2,7 @@ import asyncio
import logging
from collections import defaultdict
from datetime import datetime
from typing import TYPE_CHECKING, Annotated, Any, Sequence
from typing import TYPE_CHECKING, Annotated, Any, Coroutine, Sequence

import pydantic
import stripe

@@ -13,7 +13,6 @@ from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
from typing_extensions import Optional, TypedDict

import backend.data.block
import backend.server.integrations.router
import backend.server.routers.analytics
import backend.server.v2.library.db as library_db

@@ -31,7 +30,7 @@ from backend.data.api_key import (
    suspend_api_key,
    update_api_key_permissions,
)
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
from backend.data.credit import (
    AutoTopUpConfig,
    RefundRequest,

@@ -41,6 +40,7 @@ from backend.data.credit import (
    get_user_credit_model,
    set_auto_top_up,
)
from backend.data.execution import AsyncRedisExecutionEventBus
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.data.onboarding import (
    UserOnboardingUpdate,

@@ -49,13 +49,16 @@ from backend.data.onboarding import (
    onboarding_enabled,
    update_user_onboarding,
)
from backend.data.rabbitmq import AsyncRabbitMQ
from backend.data.user import (
    get_or_create_user,
    get_user_notification_preference,
    update_user_email,
    update_user_notification_preference,
)
from backend.executor import ExecutionManager, Scheduler, scheduler
from backend.executor import Scheduler, scheduler
from backend.executor import utils as execution_utils
from backend.executor.utils import create_execution_queue_config
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import (
    on_graph_activate,

@@ -79,13 +82,23 @@ if TYPE_CHECKING:


@thread_cached
def execution_manager_client() -> ExecutionManager:
    return get_service_client(ExecutionManager)
def execution_scheduler_client() -> Scheduler:
    return get_service_client(Scheduler)


@thread_cached
def execution_scheduler_client() -> Scheduler:
    return get_service_client(Scheduler)
def execution_queue_client() -> Coroutine[None, None, AsyncRabbitMQ]:
    async def f() -> AsyncRabbitMQ:
        client = AsyncRabbitMQ(create_execution_queue_config())
        await client.connect()
        return client

    return f()


@thread_cached
def execution_event_bus() -> AsyncRedisExecutionEventBus:
    return AsyncRedisExecutionEventBus()


settings = Settings()

@@ -206,7 +219,7 @@ async def is_onboarding_enabled():

@v1_router.get(path="/blocks", tags=["blocks"], dependencies=[Depends(auth_middleware)])
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
    blocks = [block() for block in backend.data.block.get_blocks().values()]
    blocks = [block() for block in get_blocks().values()]
    costs = get_block_costs()
    return [
        {**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks if not b.disabled

@@ -219,7 +232,7 @@ def get_graph_blocks() -> Sequence[dict[Any, Any]]:
    dependencies=[Depends(auth_middleware)],
)
def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlockOutput:
    obj = backend.data.block.get_block(block_id)
    obj = get_block(block_id)
    if not obj:
        raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")

@@ -308,7 +321,7 @@ async def configure_user_auto_top_up(
    dependencies=[Depends(auth_middleware)],
)
async def get_user_auto_top_up(
    user_id: Annotated[str, Depends(get_user_id)]
    user_id: Annotated[str, Depends(get_user_id)],
) -> AutoTopUpConfig:
    return await get_auto_top_up(user_id)

@@ -375,7 +388,7 @@ async def get_credit_history(

@v1_router.get(path="/credits/refunds", dependencies=[Depends(auth_middleware)])
async def get_refund_requests(
    user_id: Annotated[str, Depends(get_user_id)]
    user_id: Annotated[str, Depends(get_user_id)],
) -> list[RefundRequest]:
    return await _user_credit_model.get_refund_requests(user_id)

@@ -391,7 +404,7 @@ class DeleteGraphResponse(TypedDict):

@v1_router.get(path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)])
async def get_graphs(
    user_id: Annotated[str, Depends(get_user_id)]
    user_id: Annotated[str, Depends(get_user_id)],
) -> Sequence[graph_db.GraphModel]:
    return await graph_db.get_graphs(filter_by="active", user_id=user_id)

@@ -580,16 +593,35 @@ async def set_graph_active_version(
    tags=["graphs"],
    dependencies=[Depends(auth_middleware)],
)
def execute_graph(
async def execute_graph(
    graph_id: str,
    node_input: Annotated[dict[str, Any], Body(..., default_factory=dict)],
    user_id: Annotated[str, Depends(get_user_id)],
    graph_version: Optional[int] = None,
    preset_id: Optional[str] = None,
) -> ExecuteGraphResponse:
    graph_exec = execution_manager_client().add_execution(
        graph_id, node_input, user_id=user_id, graph_version=graph_version
    graph: graph_db.GraphModel | None = await graph_db.get_graph(
        graph_id=graph_id, user_id=user_id, version=graph_version
    )
    return ExecuteGraphResponse(graph_exec_id=graph_exec.graph_exec_id)
    if not graph:
        raise ValueError(f"Graph #{graph_id} not found.")

    graph_exec = await execution_db.create_graph_execution(
        graph_id=graph_id,
        graph_version=graph.version,
        nodes_input=execution_utils.construct_node_execution_input(
            graph, user_id, node_input
        ),
        user_id=user_id,
        preset_id=preset_id,
    )
    execution_utils.get_execution_event_bus().publish(graph_exec)
    execution_utils.get_execution_queue().publish_message(
        routing_key=execution_utils.GRAPH_EXECUTION_ROUTING_KEY,
        message=graph_exec.to_graph_execution_entry().model_dump_json(),
        exchange=execution_utils.GRAPH_EXECUTION_EXCHANGE,
    )
    return ExecuteGraphResponse(graph_exec_id=graph_exec.id)

@v1_router.post(

@@ -605,9 +637,7 @@ async def stop_graph_run(
    ):
        raise HTTPException(404, detail=f"Agent execution #{graph_exec_id} not found")

    await asyncio.to_thread(
        lambda: execution_manager_client().cancel_execution(graph_exec_id)
    )
    await _cancel_execution(graph_exec_id)

    # Retrieve & return canceled graph execution in its final state
    result = await execution_db.get_graph_execution(

@@ -621,6 +651,49 @@ async def stop_graph_run(
    return result


async def _cancel_execution(graph_exec_id: str):
    """
    Mechanism:
    1. Set the cancel event.
    2. The graph executor's cancel handler thread detects the event, terminates workers,
       reinitializes the worker pool, and returns.
    3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
    """
    queue_client = await execution_queue_client()
    await queue_client.publish_message(
        routing_key="",
        message=execution_utils.CancelExecutionEvent(
            graph_exec_id=graph_exec_id
        ).model_dump_json(),
        exchange=execution_utils.GRAPH_EXECUTION_CANCEL_EXCHANGE,
    )

    # Update the status of the graph & node executions
    await execution_db.update_graph_execution_stats(
        graph_exec_id,
        execution_db.ExecutionStatus.TERMINATED,
    )
    node_execs = [
        node_exec.model_copy(update={"status": execution_db.ExecutionStatus.TERMINATED})
        for node_exec in await execution_db.get_node_execution_results(
            graph_exec_id=graph_exec_id,
            statuses=[
                execution_db.ExecutionStatus.QUEUED,
                execution_db.ExecutionStatus.RUNNING,
                execution_db.ExecutionStatus.INCOMPLETE,
            ],
        )
    ]

    await execution_db.update_node_execution_status_batch(
        [node_exec.node_exec_id for node_exec in node_execs],
        execution_db.ExecutionStatus.TERMINATED,
    )
    await asyncio.gather(
        *[execution_event_bus().publish(node_exec) for node_exec in node_execs]
    )


@v1_router.get(
    path="/executions",
    tags=["graphs"],

@@ -792,7 +865,7 @@ async def create_api_key(
    dependencies=[Depends(auth_middleware)],
)
async def get_api_keys(
    user_id: Annotated[str, Depends(get_user_id)]
    user_id: Annotated[str, Depends(get_user_id)],
) -> list[APIKeyWithoutHash]:
    """List all API keys for the user"""
    try:

@@ -2,25 +2,16 @@ import logging
from typing import Annotated, Any

import autogpt_libs.auth as autogpt_auth_lib
import autogpt_libs.utils.cache
from fastapi import APIRouter, Body, Depends, HTTPException, status

import backend.executor
import backend.server.v2.library.db as db
import backend.server.v2.library.model as models
import backend.util.service

logger = logging.getLogger(__name__)

router = APIRouter()


@autogpt_libs.utils.cache.thread_cached
def execution_manager_client() -> backend.executor.ExecutionManager:
    """Return a cached instance of ExecutionManager client."""
    return backend.util.service.get_service_client(backend.executor.ExecutionManager)


@router.get(
    "/presets",
    summary="List presets",

@@ -216,6 +207,8 @@ async def execute_preset(
        HTTPException: If the preset is not found or an error occurs while executing the preset.
    """
    try:
        from backend.server.routers import v1 as internal_api_routes

        preset = await db.get_preset(user_id, preset_id)
        if not preset:
            raise HTTPException(

@@ -226,10 +219,10 @@ async def execute_preset(
        # Merge input overrides with preset inputs
        merged_node_input = preset.inputs | node_input

        execution = execution_manager_client().add_execution(
        execution = await internal_api_routes.execute_graph(
            graph_id=graph_id,
            node_input=merged_node_input,
            graph_version=graph_version,
            data=merged_node_input,
            user_id=user_id,
            preset_id=preset_id,
        )

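# Illustrative note (not part of the diff): dict union keeps the right-hand
# value on key conflicts, so an explicit node_input overrides preset defaults:
#     {"topic": "preset default"} | {"topic": "override"}  ->  {"topic": "override"}
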
@@ -12,7 +12,6 @@ import backend.server.v2.store.exceptions
import backend.server.v2.store.model
from backend.data.graph import GraphModel, get_sub_graphs
from backend.data.includes import AGENT_GRAPH_INCLUDE
from backend.util.type import typed_cast

logger = logging.getLogger(__name__)

@@ -960,7 +959,7 @@ async def get_my_agents(
    try:
        search_filter: prisma.types.LibraryAgentWhereInput = {
            "userId": user_id,
            "AgentGraph": {"is": {"StoreListing": {"none": {"isDeleted": False}}}},
            "AgentGraph": {"is": {"StoreListings": {"none": {"isDeleted": False}}}},
            "isArchived": False,
            "isDeleted": False,
        }

@@ -1088,13 +1087,7 @@ async def review_store_submission(
        where={"id": store_listing_version_id},
        include={
            "StoreListing": True,
            "AgentGraph": {
                "include": typed_cast(
                    prisma.types.AgentGraphIncludeFromAgentGraphRecursive1,
                    prisma.types.AgentGraphInclude,
                    AGENT_GRAPH_INCLUDE,
                )
            },
            "AgentGraph": {"include": AGENT_GRAPH_INCLUDE},
        },
    )
)

@@ -15,21 +15,25 @@ def to_dict(data) -> dict:


def dumps(data) -> str:
    return json.dumps(jsonable_encoder(data))
    return json.dumps(to_dict(data))


T = TypeVar("T")


@overload
def loads(data: str, *args, target_type: Type[T], **kwargs) -> T: ...
def loads(data: str | bytes, *args, target_type: Type[T], **kwargs) -> T: ...


@overload
def loads(data: str, *args, **kwargs) -> Any: ...
def loads(data: str | bytes, *args, **kwargs) -> Any: ...


def loads(data: str, *args, target_type: Type[T] | None = None, **kwargs) -> Any:
def loads(
    data: str | bytes, *args, target_type: Type[T] | None = None, **kwargs
) -> Any:
    if isinstance(data, bytes):
        data = data.decode("utf-8")
    parsed = json.loads(data, *args, **kwargs)
    if target_type:
        return type_match(parsed, target_type)

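# Illustrative usage (not part of the diff): loads() now accepts raw bytes
# straight off the wire, and target_type coerces the parsed JSON via
# type_match (assuming type_match can coerce into the given type).
raw: bytes = b'[1, 2, 3]'
as_any = loads(raw)                          # -> [1, 2, 3]
as_ints = loads(raw, target_type=list[int])  # -> [1, 2, 3], typed as list[int]
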
@@ -14,9 +14,7 @@ def sentry_init():
        traces_sample_rate=1.0,
        profiles_sample_rate=1.0,
        environment=f"app:{Settings().config.app_env.value}-behave:{Settings().config.behave_as.value}",
        _experiments={
            "enable_logs": True,
        },
        _experiments={"enable_logs": True},
        integrations=[
            LoggingIntegration(sentry_logs_level=logging.INFO),
            AnthropicIntegration(

@@ -198,18 +198,6 @@ def convert(value: Any, target_type: Type[T]) -> T:
        raise ConversionError(f"Failed to convert {value} to {target_type}") from e


def typed(type: type[T], value: T) -> T:
    """
    Add an explicit type to a value. Useful in nested statements, e.g. dict literals.
    """
    return value


def typed_cast(to_type: type[TT], from_type: type[T], value: T) -> TT:
    """Strict cast to preserve type checking abilities."""
    return cast(TT, value)


class FormattedStringType(str):
    string_format: str

@@ -0,0 +1,16 @@
-- Modify the OnboardingStep enum
ALTER TYPE "OnboardingStep" ADD VALUE 'GET_RESULTS';
ALTER TYPE "OnboardingStep" ADD VALUE 'MARKETPLACE_VISIT';
ALTER TYPE "OnboardingStep" ADD VALUE 'MARKETPLACE_ADD_AGENT';
ALTER TYPE "OnboardingStep" ADD VALUE 'MARKETPLACE_RUN_AGENT';
ALTER TYPE "OnboardingStep" ADD VALUE 'BUILDER_OPEN';
ALTER TYPE "OnboardingStep" ADD VALUE 'BUILDER_SAVE_AGENT';
ALTER TYPE "OnboardingStep" ADD VALUE 'BUILDER_RUN_AGENT';

-- Modify the UserOnboarding table
ALTER TABLE "UserOnboarding"
ADD COLUMN "updatedAt" TIMESTAMP(3),
ADD COLUMN "notificationDot" BOOLEAN NOT NULL DEFAULT true,
ADD COLUMN "notified" "OnboardingStep"[] DEFAULT '{}',
ADD COLUMN "rewardedFor" "OnboardingStep"[] DEFAULT '{}',
ADD COLUMN "onboardingAgentExecutionId" TEXT

@@ -5,7 +5,7 @@ datasource db {

generator client {
  provider             = "prisma-client-py"
  recursive_type_depth = 5
  recursive_type_depth = -1
  interface            = "asyncio"
  previewFeatures      = ["views"]
}

@@ -58,6 +58,7 @@ model User {
}

enum OnboardingStep {
  // Introductory onboarding (Library)
  WELCOME
  USAGE_REASON
  INTEGRATIONS

@@ -65,18 +66,32 @@ enum OnboardingStep {
  AGENT_NEW_RUN
  AGENT_INPUT
  CONGRATS
  GET_RESULTS
  // Marketplace
  MARKETPLACE_VISIT
  MARKETPLACE_ADD_AGENT
  MARKETPLACE_RUN_AGENT
  // Builder
  BUILDER_OPEN
  BUILDER_SAVE_AGENT
  BUILDER_RUN_AGENT
}

model UserOnboarding {
  id        String    @id @default(uuid())
  createdAt DateTime  @default(now())
  updatedAt DateTime? @updatedAt

  completedSteps                OnboardingStep[] @default([])
  notificationDot               Boolean          @default(true)
  notified                      OnboardingStep[] @default([])
  rewardedFor                   OnboardingStep[] @default([])
  usageReason                   String?
  integrations                  String[]         @default([])
  otherIntegrations             String?
  selectedStoreListingVersionId String?
  agentInput                    Json?
  onboardingAgentExecutionId    String?

  userId String @unique
  User   User   @relation(fields: [userId], references: [id], onDelete: Cascade)

@@ -1,4 +1,4 @@
from backend.data.execution import merge_execution_input, parse_execution_output
from backend.executor.utils import merge_execution_input, parse_execution_output


def test_parse_execution_output():

@@ -10,16 +10,12 @@ from prisma.types import (
    AgentGraphCreateInput,
    AgentNodeCreateInput,
    AgentNodeLinkCreateInput,
    AgentPresetCreateInput,
    AnalyticsDetailsCreateInput,
    AnalyticsMetricsCreateInput,
    APIKeyCreateInput,
    CreditTransactionCreateInput,
    LibraryAgentCreateInput,
    ProfileCreateInput,
    StoreListingCreateInput,
    StoreListingReviewCreateInput,
    StoreListingVersionCreateInput,
    UserCreateInput,
)

@@ -140,14 +136,14 @@ async def main():
        for _ in range(num_presets):  # Create 1 AgentPreset per user
            graph = random.choice(agent_graphs)
            preset = await db.agentpreset.create(
                data=AgentPresetCreateInput(
                    name=faker.sentence(nb_words=3),
                    description=faker.text(max_nb_chars=200),
                    userId=user.id,
                    agentId=graph.id,
                    agentVersion=graph.version,
                    isActive=True,
                )
                data={
                    "name": faker.sentence(nb_words=3),
                    "description": faker.text(max_nb_chars=200),
                    "userId": user.id,
                    "agentGraphId": graph.id,
                    "agentGraphVersion": graph.version,
                    "isActive": True,
                }
            )
            agent_presets.append(preset)

@@ -160,16 +156,15 @@ async def main():
            graph = random.choice(agent_graphs)
            preset = random.choice(agent_presets)
            user_agent = await db.libraryagent.create(
                data=LibraryAgentCreateInput(
                    userId=user.id,
                    agentId=graph.id,
                    agentVersion=graph.version,
                    agentPresetId=preset.id,
                    isFavorite=random.choice([True, False]),
                    isCreatedByUser=random.choice([True, False]),
                    isArchived=random.choice([True, False]),
                    isDeleted=random.choice([True, False]),
                )
                data={
                    "userId": user.id,
                    "agentGraphId": graph.id,
                    "agentGraphVersion": graph.version,
                    "isFavorite": random.choice([True, False]),
                    "isCreatedByUser": random.choice([True, False]),
                    "isArchived": random.choice([True, False]),
                    "isDeleted": random.choice([True, False]),
                }
            )
            user_agents.append(user_agent)

@@ -346,13 +341,13 @@ async def main():
        user = random.choice(users)
        slug = faker.slug()
        listing = await db.storelisting.create(
            data=StoreListingCreateInput(
                agentId=graph.id,
                agentVersion=graph.version,
                owningUserId=user.id,
                hasApprovedVersion=random.choice([True, False]),
                slug=slug,
            )
            data={
                "agentGraphId": graph.id,
                "agentGraphVersion": graph.version,
                "owningUserId": user.id,
                "hasApprovedVersion": random.choice([True, False]),
                "slug": slug,
            }
        )
        store_listings.append(listing)

@@ -362,26 +357,26 @@ async def main():
    for listing in store_listings:
        graph = [g for g in agent_graphs if g.id == listing.agentId][0]
        version = await db.storelistingversion.create(
            data=StoreListingVersionCreateInput(
                agentId=graph.id,
                agentVersion=graph.version,
                name=graph.name or faker.sentence(nb_words=3),
                subHeading=faker.sentence(),
                videoUrl=faker.url(),
                imageUrls=[get_image() for _ in range(3)],
                description=faker.text(),
                categories=[faker.word() for _ in range(3)],
                isFeatured=random.choice([True, False]),
                isAvailable=True,
                storeListingId=listing.id,
                submissionStatus=random.choice(
            data={
                "agentGraphId": graph.id,
                "agentGraphVersion": graph.version,
                "name": graph.name or faker.sentence(nb_words=3),
                "subHeading": faker.sentence(),
                "videoUrl": faker.url(),
                "imageUrls": [get_image() for _ in range(3)],
                "description": faker.text(),
                "categories": [faker.word() for _ in range(3)],
                "isFeatured": random.choice([True, False]),
                "isAvailable": True,
                "storeListingId": listing.id,
                "submissionStatus": random.choice(
                    [
                        prisma.enums.SubmissionStatus.PENDING,
                        prisma.enums.SubmissionStatus.APPROVED,
                        prisma.enums.SubmissionStatus.REJECTED,
                    ]
                ),
            )
            }
        )
        store_listing_versions.append(version)

@@ -422,23 +417,12 @@ async def main():
        )
        await db.storelistingversion.update(
            where={"id": version.id},
            data=StoreListingVersionCreateInput(
                submissionStatus=status,
                Reviewer={"connect": {"id": reviewer.id}},
                reviewComments=faker.text(),
                reviewedAt=datetime.now(),
                agentId=version.agentId,  # preserving existing fields
                agentVersion=version.agentVersion,
                name=version.name,
                subHeading=version.subHeading,
                videoUrl=version.videoUrl,
                imageUrls=version.imageUrls,
                description=version.description,
                categories=version.categories,
                isFeatured=version.isFeatured,
                isAvailable=version.isAvailable,
                storeListingId=version.storeListingId,
            ),
            data={
                "submissionStatus": status,
                "Reviewer": {"connect": {"id": reviewer.id}},
                "reviewComments": faker.text(),
                "reviewedAt": datetime.now(),
            },
        )

    # Insert APIKeys

@@ -52,7 +52,6 @@
    "@xyflow/react": "12.4.2",
    "ajv": "^8.17.1",
    "boring-avatars": "^1.11.2",
    "canvas-confetti": "^1.9.3",
    "class-variance-authority": "^0.7.1",
    "clsx": "^2.1.1",
    "cmdk": "1.0.4",

@@ -70,6 +69,7 @@
    "moment": "^2.30.1",
    "next": "^14.2.26",
    "next-themes": "^0.4.5",
    "party-js": "^2.2.0",
    "react": "^18",
    "react-day-picker": "^9.6.1",
    "react-dom": "^18",

BIN  autogpt_platform/frontend/public/onboarding/builder-open.mp4 (new file; binary not shown)
BIN  autogpt_platform/frontend/public/onboarding/builder-run.mp4 (new file; binary not shown)
BIN  autogpt_platform/frontend/public/onboarding/builder-save.mp4 (new file; binary not shown)
BIN  autogpt_platform/frontend/public/onboarding/get-results.mp4 (new file; binary not shown)
BIN  autogpt_platform/frontend/public/onboarding/marketplace-add.mp4 (new file; binary not shown)
BIN  autogpt_platform/frontend/public/onboarding/marketplace-run.mp4 (new file; binary not shown)
BIN  (one additional changed binary file not shown)

@@ -3,9 +3,16 @@
import { useSearchParams } from "next/navigation";
import { GraphID } from "@/lib/autogpt-server-api/types";
import FlowEditor from "@/components/Flow";
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
import { useEffect } from "react";

export default function Home() {
  const query = useSearchParams();
  const { completeStep } = useOnboarding();

  useEffect(() => {
    completeStep("BUILDER_OPEN");
  }, []);

  return (
    <FlowEditor

@@ -22,6 +22,7 @@ import AgentRunDraftView from "@/components/agents/agent-run-draft-view";
import AgentRunDetailsView from "@/components/agents/agent-run-details-view";
import AgentRunsSelectorList from "@/components/agents/agent-runs-selector-list";
import AgentScheduleDetailsView from "@/components/agents/agent-schedule-details-view";
import { useOnboarding } from "@/components/onboarding/onboarding-provider";

export default function AgentRunsPage(): React.ReactElement {
  const { id: agentID }: { id: LibraryAgentID } = useParams();

@@ -49,6 +50,7 @@ export default function AgentRunsPage(): React.ReactElement {
    useState<boolean>(false);
  const [confirmingDeleteAgentRun, setConfirmingDeleteAgentRun] =
    useState<GraphExecutionMeta | null>(null);
  const { state, updateState } = useOnboarding();

  const openRunDraftView = useCallback(() => {
    selectView({ type: "run" });

@@ -78,6 +80,18 @@ export default function AgentRunsPage(): React.ReactElement {
    [api, graphVersions],
  );

  // Reward user for viewing results of their onboarding agent
  useEffect(() => {
    if (!state || !selectedRun || state.completedSteps.includes("GET_RESULTS"))
      return;

    if (selectedRun.id === state.onboardingAgentExecutionId) {
      updateState({
        completedSteps: [...state.completedSteps, "GET_RESULTS"],
      });
    }
  }, [selectedRun, state]);

  const fetchAgents = useCallback(() => {
    api.getLibraryAgent(agentID).then((agent) => {
      setAgent(agent);

@@ -8,6 +8,7 @@ import { BreadCrumbs } from "@/components/agptui/BreadCrumbs";
import { Metadata } from "next";
import { CreatorInfoCard } from "@/components/agptui/CreatorInfoCard";
import { CreatorLinks } from "@/components/agptui/CreatorLinks";
import { Separator } from "@/components/ui/separator";

export async function generateMetadata({
  params,

@@ -78,7 +79,7 @@ export default async function Page({
        </div>
      </div>
      <div className="mt-8 sm:mt-12 md:mt-16 lg:pb-[58px]">
        <hr className="w-full bg-neutral-700" />
        <Separator className="mb-6 bg-gray-200" />
        <AgentsSection
          agents={creatorAgents.agents}
          hideAvatars={true}

@@ -154,7 +154,7 @@ export default async function Page({}: {}) {
      <HeroSection />
      <FeaturedSection featuredAgents={featuredAgents.agents} />
      {/* 100px margin because the featured section's buttons are placed 40px below the container */}
      <Separator className="mb-[25px] mt-[100px]" />
      <Separator className="mb-6 mt-24" />
      <AgentsSection
        sectionTitle="Top Agents"
        agents={topAgents.agents as Agent[]}

@@ -83,8 +83,14 @@ export default function Page() {
    api.addMarketplaceAgentToLibrary(
      storeAgent?.store_listing_version_id || "",
    );
    api.executeGraph(agent.id, agent.version, state?.agentInput || {});
    router.push("/onboarding/6-congrats");
    api
      .executeGraph(agent.id, agent.version, state?.agentInput || {})
      .then(({ graph_exec_id }) => {
        updateState({
          onboardingAgentExecutionId: graph_exec_id,
        });
        router.push("/onboarding/6-congrats");
      });
  }, [api, agent, router, state?.agentInput]);

  const runYourAgent = (

@@ -6,9 +6,6 @@ import { redirect } from "next/navigation";
export async function finishOnboarding() {
  const api = new BackendAPI();
  const onboarding = await api.getUserOnboarding();
  await api.updateUserOnboarding({
    completedSteps: [...onboarding.completedSteps, "CONGRATS"],
  });
  revalidatePath("/library", "layout");
  redirect("/library");
}

@@ -1,24 +1,26 @@
"use client";
import { useEffect, useState } from "react";
import { useEffect, useRef, useState } from "react";
import { cn } from "@/lib/utils";
import { finishOnboarding } from "./actions";
import confetti from "canvas-confetti";
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
import * as party from "party-js";

export default function Page() {
  useOnboarding(7, "AGENT_INPUT");
  const { state, updateState } = useOnboarding(7, "AGENT_INPUT");
  const [showText, setShowText] = useState(false);
  const [showSubtext, setShowSubtext] = useState(false);
  const divRef = useRef(null);

  useEffect(() => {
    confetti({
      particleCount: 120,
      spread: 360,
      shapes: ["square", "circle"],
      scalar: 2,
      decay: 0.93,
      origin: { y: 0.38, x: 0.51 },
    });
    if (divRef.current) {
      party.confetti(divRef.current, {
        count: 100,
        spread: 180,
        shapes: ["square", "circle"],
        size: party.variation.range(2, 2), // scalar: 2
        speed: party.variation.range(300, 1000),
      });
    }

    const timer0 = setTimeout(() => {
      setShowText(true);

@@ -29,6 +31,9 @@ export default function Page() {
    }, 500);

    const timer2 = setTimeout(() => {
      updateState({
        completedSteps: [...(state?.completedSteps || []), "CONGRATS"],
      });
      finishOnboarding();
    }, 3000);

@@ -42,6 +47,7 @@ export default function Page() {
  return (
    <div className="flex h-screen w-screen flex-col items-center justify-center bg-violet-100">
      <div
        ref={divRef}
        className={cn(
          "z-10 -mb-16 text-9xl duration-500",
          showText ? "opacity-100" : "opacity-0",

@@ -63,7 +69,7 @@ export default function Page() {
          showSubtext ? "opacity-100" : "opacity-0",
        )}
      >
        You earned 15$ for running your first agent
        You earned 3$ for running your first agent
      </p>
    </div>
  );

@@ -5,11 +5,14 @@ export default async function OnboardingResetPage() {
  const api = new BackendAPI();
  await api.updateUserOnboarding({
    completedSteps: [],
    notificationDot: true,
    notified: [],
    usageReason: null,
    integrations: [],
    otherIntegrations: "",
    selectedStoreListingVersionId: null,
    agentInput: {},
    onboardingAgentExecutionId: null,
  });
  redirect("/onboarding/1-welcome");
}

@@ -11,6 +11,7 @@ import { useToastOnFail } from "@/components/ui/use-toast";
import ActionButtonGroup from "@/components/agptui/action-button-group";
import SchemaTooltip from "@/components/SchemaTooltip";
import { IconPlay } from "@/components/ui/icons";
import { useOnboarding } from "../onboarding/onboarding-provider";

export default function AgentRunDraftView({
  graph,

@@ -26,15 +27,18 @@ export default function AgentRunDraftView({

  const agentInputs = graph.input_schema.properties;
  const [inputValues, setInputValues] = useState<Record<string, any>>({});
  const { state, completeStep } = useOnboarding();

  const doRun = useCallback(
    () =>
      api
        .executeGraph(graph.id, graph.version, inputValues)
        .then((newRun) => onRun(newRun.graph_exec_id))
        .catch(toastOnFail("execute agent")),
    [api, graph, inputValues, onRun, toastOnFail],
  );
  const doRun = useCallback(() => {
    api
      .executeGraph(graph.id, graph.version, inputValues)
      .then((newRun) => onRun(newRun.graph_exec_id))
      .catch(toastOnFail("execute agent"));
    // Mark run agent onboarding step as completed
    if (state?.completedSteps.includes("MARKETPLACE_ADD_AGENT")) {
      completeStep("MARKETPLACE_RUN_AGENT");
    }
  }, [api, graph, inputValues, onRun, state]);

  const runActions: ButtonAction[] = useMemo(
    () => [

@@ -10,6 +10,7 @@ import { useToast } from "@/components/ui/use-toast";

import useSupabase from "@/hooks/useSupabase";
import { DownloadIcon, LoaderIcon } from "lucide-react";
import { useOnboarding } from "../onboarding/onboarding-provider";
interface AgentInfoProps {
  name: string;
  creator: string;

@@ -39,6 +40,7 @@ export const AgentInfo: React.FC<AgentInfoProps> = ({
  const api = React.useMemo(() => new BackendAPI(), []);
  const { user } = useSupabase();
  const { toast } = useToast();
  const { completeStep } = useOnboarding();

  const [downloading, setDownloading] = React.useState(false);

@@ -47,6 +49,7 @@ export const AgentInfo: React.FC<AgentInfoProps> = ({
      const newLibraryAgent = await api.addMarketplaceAgentToLibrary(
        storeListingVersionId,
      );
      completeStep("MARKETPLACE_ADD_AGENT");
      router.push(`/library/agents/${newLibraryAgent.id}`);
    } catch (error) {
      console.error("Failed to add agent to library:", error);

@@ -170,10 +173,10 @@ export const AgentInfo: React.FC<AgentInfoProps> = ({

      {/* Description Section */}
      <div className="mb-4 w-full lg:mb-[36px]">
        <div className="font-geist decoration-skip-ink-none mb-1.5 text-base font-medium leading-6 text-neutral-800 dark:text-neutral-200 sm:mb-2">
        <div className="mb-1.5 font-sans text-base font-medium leading-6 text-neutral-800 dark:text-neutral-200 sm:mb-2">
          Description
        </div>
        <div className="font-geist decoration-skip-ink-none text-base font-normal leading-6 text-neutral-600 underline-offset-[from-font] dark:text-neutral-400">
        <div className="whitespace-pre-line font-sans text-base font-normal leading-6 text-neutral-600 dark:text-neutral-400">
          {longDescription}
        </div>
      </div>

@@ -22,7 +22,7 @@ export const BreadCrumbs: React.FC<BreadCrumbsProps> = ({ items }) => {
      <button className="flex h-12 w-12 items-center justify-center rounded-full border border-neutral-200 transition-colors hover:bg-neutral-50 dark:border-neutral-700 dark:hover:bg-neutral-800">
        <IconRightArrow className="h-5 w-5 text-neutral-900 dark:text-neutral-100" />
      </button> */}
      <div className="flex h-auto flex-wrap items-center justify-start gap-4 rounded-[5rem] bg-white dark:bg-transparent">
      <div className="flex h-auto flex-wrap items-center justify-start gap-4 rounded-[5rem] dark:bg-transparent">
        {items.map((item, index) => (
          <React.Fragment key={index}>
            <Link href={item.link}>

@@ -1,43 +0,0 @@
import type { Meta, StoryObj } from "@storybook/react";
import CreditsCard from "./CreditsCard";
import { userEvent, within } from "@storybook/test";

const meta: Meta<typeof CreditsCard> = {
  title: "AGPT UI/Credits Card",
  component: CreditsCard,
  tags: ["autodocs"],
};

export default meta;
type Story = StoryObj<typeof CreditsCard>;

export const Default: Story = {
  args: {
    credits: 0,
  },
};

export const SmallNumber: Story = {
  args: {
    credits: 10,
  },
};

export const LargeNumber: Story = {
  args: {
    credits: 1000000,
  },
};

export const InteractionTest: Story = {
  args: {
    credits: 100,
  },
  play: async ({ canvasElement }) => {
    const canvas = within(canvasElement);
    const refreshButton = canvas.getByRole("button", {
      name: /refresh credits/i,
    });
    await userEvent.click(refreshButton);
  },
};

@@ -1,48 +0,0 @@
"use client";

import { IconRefresh } from "@/components/ui/icons";
import { useState } from "react";
import {
  Tooltip,
  TooltipContent,
  TooltipTrigger,
} from "@/components/ui/tooltip";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import useCredits from "@/hooks/useCredits";

const CreditsCard = () => {
  const { credits, formatCredits, fetchCredits } = useCredits({
    fetchInitialCredits: true,
  });
  const api = useBackendAPI();

  const onRefresh = async () => {
    fetchCredits();
  };

  return (
    <div className="inline-flex h-[48px] items-center gap-2.5 rounded-2xl bg-neutral-200 p-4 dark:bg-neutral-800">
      <div className="flex items-center gap-0.5">
        <span className="p-ui-semibold text-base leading-7 text-neutral-900 dark:text-neutral-50">
          Balance: {formatCredits(credits)}
        </span>
      </div>
      <Tooltip key="RefreshCredits" delayDuration={500}>
        <TooltipTrigger asChild>
          <button
            onClick={onRefresh}
            className="h-6 w-6 transition-colors hover:text-neutral-700 dark:hover:text-neutral-300"
            aria-label="Refresh credits"
          >
            <IconRefresh className="h-6 w-6" />
          </button>
        </TooltipTrigger>
        <TooltipContent>
          <p>Refresh credits</p>
        </TooltipContent>
      </Tooltip>
    </div>
  );
};

export default CreditsCard;

@@ -27,7 +27,7 @@ export const FeaturedAgentCard: React.FC<FeaturedStoreCardProps> = ({
      data-testid="featured-store-card"
      onMouseEnter={() => setIsHovered(true)}
      onMouseLeave={() => setIsHovered(false)}
      className={`flex h-full flex-col ${backgroundColor}`}
      className={`flex h-full flex-col ${backgroundColor} rounded-[1.5rem] border-none`}
    >
      <CardHeader>
        <CardTitle className="line-clamp-2 text-base sm:text-xl">

@@ -4,7 +4,7 @@ import { ProfilePopoutMenu } from "./ProfilePopoutMenu";
import { IconType, IconLogIn, IconAutoGPTLogo } from "@/components/ui/icons";
import { MobileNavBar } from "./MobileNavBar";
import { Button } from "./Button";
import CreditsCard from "./CreditsCard";
import Wallet from "./Wallet";
import { ProfileDetails } from "@/lib/autogpt-server-api/types";
import { NavbarLink } from "./NavbarLink";
import getServerUser from "@/lib/supabase/getServerUser";

@@ -61,7 +61,7 @@ export const Navbar = async ({ links, menuItemGroups }: NavbarProps) => {
    <div className="flex items-center gap-4">
      {isLoggedIn ? (
        <div className="flex items-center gap-4">
          {profile && <CreditsCard />}
          {profile && <Wallet />}
          <ProfilePopoutMenu
            menuItemGroups={menuItemGroups}
            userName={profile?.username}

114  autogpt_platform/frontend/src/components/agptui/Wallet.tsx (new file)
@@ -0,0 +1,114 @@
"use client";

import useCredits from "@/hooks/useCredits";
import {
  Popover,
  PopoverContent,
  PopoverTrigger,
} from "@/components/ui/popover";
import { X } from "lucide-react";
import { PopoverClose } from "@radix-ui/react-popover";
import { TaskGroups } from "../onboarding/WalletTaskGroups";
import { ScrollArea } from "../ui/scroll-area";
import { useOnboarding } from "../onboarding/onboarding-provider";
import { useCallback, useEffect, useRef } from "react";
import { cn } from "@/lib/utils";
import * as party from "party-js";

export default function Wallet() {
  const { credits, formatCredits, fetchCredits } = useCredits({
    fetchInitialCredits: true,
  });
  const { state, updateState } = useOnboarding();
  const walletRef = useRef<HTMLButtonElement | null>(null);

  const onWalletOpen = useCallback(async () => {
    if (state?.notificationDot) {
      updateState({ notificationDot: false });
    }
    // Refresh credits when the wallet is opened
    fetchCredits();
  }, [state?.notificationDot, updateState, fetchCredits]);

  const fadeOut = new party.ModuleBuilder()
    .drive("opacity")
    .by((t) => 1 - t)
    .through("lifetime")
    .build();

  useEffect(() => {
    // Check if there are any completed steps (state?.completedSteps) that
    // are not in the state?.notified array and play confetti if so
    const pending = state?.completedSteps
      .filter((step) => !state?.notified.includes(step))
      // Ignore steps that are not relevant for notifications
      .filter(
        (step) =>
          step !== "WELCOME" &&
          step !== "USAGE_REASON" &&
          step !== "INTEGRATIONS" &&
          step !== "AGENT_CHOICE" &&
          step !== "AGENT_NEW_RUN" &&
          step !== "AGENT_INPUT",
      );
    if ((pending?.length || 0) > 0 && walletRef.current) {
      party.confetti(walletRef.current, {
        count: 30,
        spread: 120,
        shapes: ["square", "circle"],
        size: party.variation.range(1, 2),
        speed: party.variation.range(200, 300),
        modules: [fadeOut],
      });
    }
  }, [state?.completedSteps, state?.notified]);

  return (
    <Popover>
      <PopoverTrigger asChild>
        <button
          ref={walletRef}
          className="relative flex items-center gap-1 rounded-md bg-zinc-200 px-3 py-2 text-sm transition-colors duration-200 hover:bg-zinc-300"
          onClick={onWalletOpen}
        >
          Wallet{" "}
          <span className="text-sm font-semibold">
            {formatCredits(credits)}
          </span>
          {state?.notificationDot && (
            <span className="absolute right-1 top-1 h-2 w-2 rounded-full bg-violet-600"></span>
          )}
        </button>
      </PopoverTrigger>
      <PopoverContent
        className={cn(
          "absolute -right-[7.9rem] -top-[3.2rem] z-50 w-[28.5rem] px-[0.625rem] py-2",
          "rounded-xl border-zinc-200 bg-zinc-50 shadow-[0_3px_3px] shadow-zinc-300",
        )}
      >
        <div>
          <div className="mx-1 flex items-center justify-between border-b border-zinc-300 pb-2">
            <span className="font-poppins font-medium text-zinc-900">
              Your wallet
            </span>
            <div className="flex items-center font-inter text-sm font-semibold text-violet-700">
              <div className="rounded-lg bg-violet-100 px-3 py-2">
                Wallet{" "}
                <span className="font-semibold">{formatCredits(credits)}</span>
              </div>
              <PopoverClose>
                <X className="ml-[2.8rem] h-5 w-5 text-zinc-800 hover:text-foreground" />
              </PopoverClose>
            </div>
          </div>
          <p className="mx-1 mt-3 font-inter text-xs text-muted-foreground text-zinc-400">
            Complete the following tasks to earn more credits!
          </p>
        </div>
        <ScrollArea className="max-h-[80vh] overflow-y-auto">
          <TaskGroups />
        </ScrollArea>
      </PopoverContent>
    </Popover>
  );
}
@@ -49,7 +49,7 @@ export const AgentsSection: React.FC<AgentsSectionProps> = ({
    <div className="flex flex-col items-center justify-center">
      <div className="w-full max-w-[1360px]">
        <div
          className={`mb-[${margin}] font-poppins text-[18px] font-[600] leading-7 text-[#282828] dark:text-neutral-200`}
          className={`mb-[${margin}] font-poppins text-lg font-semibold text-[#282828] dark:text-neutral-200`}
        >
          {sectionTitle}
        </div>

@@ -33,7 +33,7 @@ export const FeaturedCreators: React.FC<FeaturedCreatorsProps> = ({
  return (
    <div className="flex w-full flex-col items-center justify-center">
      <div className="w-full max-w-[1360px]">
        <h2 className="mb-[37px] font-poppins text-2xl font-semibold leading-7 text-neutral-800 dark:text-neutral-200">
        <h2 className="mb-9 font-poppins text-lg font-semibold text-neutral-800 dark:text-neutral-200">
          {title}
        </h2>

@@ -46,7 +46,7 @@ export const FeaturedSection: React.FC<FeaturedSectionProps> = ({
  };

  return (
    <section className="w-full pb-16">
    <section className="w-full">
      <h2 className="mb-8 font-poppins text-2xl font-semibold leading-7 text-neutral-800 dark:text-neutral-200">
        Featured agents
      </h2>

@@ -4,9 +4,16 @@ import * as React from "react";
import { SearchBar } from "@/components/agptui/SearchBar";
import { FilterChips } from "@/components/agptui/FilterChips";
import { useRouter } from "next/navigation";
import { useOnboarding } from "@/components/onboarding/onboarding-provider";

export const HeroSection: React.FC = () => {
  const router = useRouter();
  const { completeStep } = useOnboarding();

  // Mark marketplace visit task as completed
  React.useEffect(() => {
    completeStep("MARKETPLACE_VISIT");
  }, [completeStep]);

  function onFilterChange(selectedFilters: string[]) {
    const encodedTerm = encodeURIComponent(selectedFilters.join(", "));

@@ -49,7 +49,7 @@ export default function LibrarySearchBar(): React.ReactNode {
      onFocus={() => setIsFocused(true)}
      onBlur={() => !inputRef.current?.value && setIsFocused(false)}
      onChange={handleSearchInput}
      className="flex-1 border-none font-sans text-[16px] font-normal leading-7 shadow-none focus:shadow-none"
      className="flex-1 border-none font-sans text-[16px] font-normal leading-7 shadow-none focus:shadow-none focus:ring-0"
      type="text"
      placeholder="Search agents"
    />

@@ -0,0 +1,331 @@
import { useCallback, useEffect, useRef, useState } from "react";
import { ChevronDown, Check } from "lucide-react";
import { OnboardingStep } from "@/lib/autogpt-server-api";
import { useOnboarding } from "./onboarding-provider";
import { cn } from "@/lib/utils";
import * as party from "party-js";

interface Task {
  id: OnboardingStep;
  name: string;
  amount: number;
  details: string;
  video?: string;
}

interface TaskGroup {
  name: string;
  tasks: Task[];
  isOpen: boolean;
}

export function TaskGroups() {
  const [groups, setGroups] = useState<TaskGroup[]>([
    {
      name: "Run your first agent",
      isOpen: false,
      tasks: [
        {
          id: "CONGRATS",
          name: "Finish onboarding",
          amount: 3,
          details: "Go through our step-by-step tutorial",
        },
        {
          id: "GET_RESULTS",
          name: "Get results from first agent",
          amount: 3,
          details:
            "Sit back and relax - your agent is running and will finish soon! See the results in the Library once it's done",
          video: "/onboarding/get-results.mp4",
        },
      ],
    },
    {
      name: "Explore the Marketplace",
      isOpen: false,
      tasks: [
        {
          id: "MARKETPLACE_VISIT",
          name: "Go to Marketplace",
          amount: 0,
          details: "Click Marketplace in the top navigation",
          video: "/onboarding/marketplace-visit.mp4",
        },
        {
          id: "MARKETPLACE_ADD_AGENT",
          name: "Find an agent",
          amount: 1,
          details:
            "Search for an agent in the Marketplace, like a code generator or research assistant, and add it to your Library",
          video: "/onboarding/marketplace-add.mp4",
        },
        {
          id: "MARKETPLACE_RUN_AGENT",
          name: "Try out your agent",
          amount: 1,
          details:
            "Run the agent you found in the Marketplace from the Library - whether it's a writing assistant, data analyzer, or something else",
          video: "/onboarding/marketplace-run.mp4",
        },
      ],
    },
    {
      name: "Build your own agent",
      isOpen: false,
      tasks: [
        {
          id: "BUILDER_OPEN",
          name: "Open the Builder",
          amount: 0,
          details: "Click Builder in the top navigation",
          video: "/onboarding/builder-open.mp4",
        },
        {
          id: "BUILDER_SAVE_AGENT",
          name: "Place your first blocks and save your agent",
          amount: 1,
          details:
            "Open the block library on the left, add a block to the canvas, then save your agent",
          video: "/onboarding/builder-save.mp4",
        },
        {
          id: "BUILDER_RUN_AGENT",
          name: "Run your agent",
          amount: 1,
          details: "Run your agent from the Builder",
          video: "/onboarding/builder-run.mp4",
        },
      ],
    },
  ]);
  const { state, updateState } = useOnboarding();
  const refs = useRef<Record<string, HTMLDivElement | null>>({});

  const toggleGroup = useCallback((name: string) => {
    setGroups((prevGroups) =>
      prevGroups.map((group) =>
        group.name === name ? { ...group, isOpen: !group.isOpen } : group,
      ),
    );
  }, []);

  const isTaskCompleted = useCallback(
    (task: Task) => {
      return state?.completedSteps?.includes(task.id) || false;
    },
    [state?.completedSteps],
  );

  const getCompletedCount = useCallback(
    (tasks: Task[]) => {
      return tasks.filter((task) => isTaskCompleted(task)).length;
    },
    [isTaskCompleted],
  );

  const isGroupCompleted = useCallback(
    (group: TaskGroup) => {
      return group.tasks.every((task) => isTaskCompleted(task));
    },
    [isTaskCompleted],
  );

  const setRef = (name: string) => (el: HTMLDivElement | null) => {
    if (el) {
      refs.current[name] = el;
    }
  };

  useEffect(() => {
    groups.forEach((group) => {
      const groupCompleted = isGroupCompleted(group);
      // A group counts as already celebrated once its last task is in `notified`
      const alreadyCelebrated = state?.notified.includes(
        group.tasks[group.tasks.length - 1].id,
      );

      if (groupCompleted) {
        const el = refs.current[group.name];
        if (el && !alreadyCelebrated) {
          party.confetti(el, {
            count: 50,
            spread: 120,
            shapes: ["square", "circle"],
            size: party.variation.range(1, 2),
            speed: party.variation.range(200, 300),
          });
          // Mark all of the group's tasks as notified so the confetti
          // effect isn't perpetually re-triggered on the Wallet
          const notifiedTasks = group.tasks.map((task) => task.id);
          updateState({
            notified: [...(state?.notified || []), ...notifiedTasks],
          });
        }
        return;
      }

      group.tasks.forEach((task) => {
        const el = refs.current[task.id];
        if (el && isTaskCompleted(task) && !state?.notified.includes(task.id)) {
          party.confetti(el, {
            count: 40,
            spread: 120,
            shapes: ["square", "circle"],
            size: party.variation.range(1, 1.5),
            speed: party.variation.range(200, 300),
          });
          // Mark the single task as notified
          updateState({ notified: [...(state?.notified || []), task.id] });
        }
      });
    });
  }, [state?.completedSteps]);

  return (
    <div className="space-y-2">
      {groups.map((group) => (
        <div
          key={group.name}
          ref={setRef(group.name)}
          className="mt-3 overflow-hidden rounded-lg border border-zinc-200 bg-zinc-100"
        >
          {/* Group header */}
          <div
            className="flex cursor-pointer items-center justify-between p-3"
            onClick={() => toggleGroup(group.name)}
          >
            {/* Name and completed count */}
            <div className="flex-1">
              <div
                className={cn(
                  "text-sm font-medium text-zinc-900",
                  isGroupCompleted(group) ? "text-zinc-600 line-through" : "",
                )}
              >
                {group.name}
              </div>
              <div
                className={cn(
                  "mt-1 text-xs font-normal leading-tight text-zinc-500",
                  isGroupCompleted(group) ? "line-through" : "",
                )}
              >
                {getCompletedCount(group.tasks)} of {group.tasks.length}{" "}
                completed
              </div>
            </div>
            {/* Reward and chevron */}
            <div className="flex items-center gap-2">
              <div
                className={cn(
                  "text-xs font-medium leading-tight text-violet-600",
                  isGroupCompleted(group) ? "line-through" : "",
                )}
              >
                $
                {group.tasks
                  .reduce((sum, task) => sum + task.amount, 0)
                  .toFixed(2)}
              </div>
              <ChevronDown
                className={`h-5 w-5 text-slate-950 transition-transform duration-300 ease-in-out ${
                  group.isOpen ? "rotate-180" : ""
                }`}
              />
            </div>
          </div>

          {/* Tasks */}
          <div
            className={cn(
              "overflow-hidden transition-all duration-300 ease-in-out",
              group.isOpen || !isGroupCompleted(group)
                ? "max-h-[1000px] opacity-100"
                : "max-h-0 opacity-0",
            )}
          >
            {group.tasks.map((task) => (
              <div
                key={task.id}
                ref={setRef(task.id)}
                className="mx-3 border-t border-zinc-300 px-1 pb-1 pt-3"
              >
                <div className="flex items-center justify-between">
                  {/* Checkmark and name */}
                  <div className="flex items-center gap-2">
                    <div
                      className={cn(
                        "flex h-4 w-4 items-center justify-center rounded-full border",
                        isTaskCompleted(task)
                          ? "border-emerald-600"
                          : "border-zinc-600",
                      )}
                    >
                      {isTaskCompleted(task) && (
                        <Check className="h-3 w-3 text-emerald-600" />
                      )}
                    </div>
                    <span
                      className={cn(
                        "text-sm font-normal",
                        isTaskCompleted(task)
                          ? "text-zinc-500 line-through"
                          : "text-zinc-800",
                      )}
                    >
                      {task.name}
                    </span>
                  </div>
                  {/* Reward */}
                  {task.amount > 0 && (
                    <span
                      className={cn(
                        "text-xs font-normal text-zinc-500",
                        isTaskCompleted(task) ? "line-through" : "",
                      )}
                    >
                      ${task.amount.toFixed(2)}
                    </span>
                  )}
                </div>

                {/* Details section */}
                <div
                  className={cn(
                    "mt-2 overflow-hidden pl-6 text-xs font-normal text-zinc-500 transition-all duration-300 ease-in-out",
                    isTaskCompleted(task) && "line-through",
                    group.isOpen
                      ? "max-h-[100px] opacity-100"
                      : "max-h-0 opacity-0",
                  )}
                >
                  {task.details}
                </div>
                {task.video && (
                  <div
                    className={cn(
                      "relative mx-6 aspect-video overflow-hidden rounded-lg transition-all duration-300 ease-in-out",
                      group.isOpen
                        ? "my-2 max-h-[200px] opacity-100"
                        : "max-h-0 opacity-0",
                    )}
                  >
                    <video
                      src={task.video}
                      autoPlay
                      loop
                      muted
                      playsInline
                      className={cn("h-full w-full object-cover object-center")}
                    ></video>
                  </div>
                )}
              </div>
            ))}
          </div>
        </div>
      ))}
    </div>
  );
}
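The component above gates every `party.confetti` call on the `notified` list so each celebration fires exactly once per task or group. Distilled into a sketch (the helper name and `StepId` alias are hypothetical; the party-js calls match the ones used above):

import * as party from "party-js";

type StepId = string;

function celebrateOnce(
  el: HTMLElement,
  step: StepId,
  notified: StepId[],
  markNotified: (steps: StepId[]) => void,
) {
  // Skip if this step was already celebrated on a previous render
  if (notified.includes(step)) return;
  party.confetti(el, {
    count: 40,
    spread: 120,
    size: party.variation.range(1, 1.5),
  });
  // Persist the step so re-renders and re-mounts don't re-fire the effect
  markNotified([...notified, step]);
}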
@@ -14,9 +14,12 @@ import {
const OnboardingContext = createContext<
  | {
      state: UserOnboarding | null;
      updateState: (state: Partial<UserOnboarding>) => void;
      updateState: (
        state: Omit<Partial<UserOnboarding>, "rewardedFor">,
      ) => void;
      step: number;
      setStep: (step: number) => void;
      completeStep: (step: OnboardingStep) => void;
    }
  | undefined
>(undefined);
@@ -84,19 +87,23 @@ export default function OnboardingProvider({
  }, [api, pathname, router]);

  const updateState = useCallback(
    (newState: Partial<UserOnboarding>) => {
    (newState: Omit<Partial<UserOnboarding>, "rewardedFor">) => {
      setState((prev) => {
        api.updateUserOnboarding({ ...prev, ...newState });
        api.updateUserOnboarding(newState);

        if (!prev) {
          // Handle initial state
          return {
            completedSteps: [],
            notificationDot: false,
            notified: [],
            rewardedFor: [],
            usageReason: null,
            integrations: [],
            otherIntegrations: null,
            selectedStoreListingVersionId: null,
            agentInput: null,
            onboardingAgentExecutionId: null,
            ...newState,
          };
        }
@@ -106,8 +113,21 @@ export default function OnboardingProvider({
    [api, setState],
  );

  const completeStep = useCallback(
    (step: OnboardingStep) => {
      // Idempotent: do nothing if the step is already recorded
      if (!state || state.completedSteps.includes(step)) return;

      updateState({
        completedSteps: [...state.completedSteps, step],
      });
    },
    [state, updateState],
  );

  return (
    <OnboardingContext.Provider value={{ state, updateState, step, setStep }}>
    <OnboardingContext.Provider
      value={{ state, updateState, step, setStep, completeStep }}
    >
      {children}
    </OnboardingContext.Provider>
  );

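With `completeStep` now exposed on the context, any component under the provider can mark a step done, mirroring the HeroSection change earlier in this commit. A minimal consumer sketch (the component name is hypothetical):

import { useEffect } from "react";
import { useOnboarding } from "@/components/onboarding/onboarding-provider";

export function MarketplaceVisitTracker() {
  const { completeStep } = useOnboarding();

  useEffect(() => {
    // No-op if the step is already in completedSteps
    completeStep("MARKETPLACE_VISIT");
  }, [completeStep]);

  return null;
}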
@@ -213,7 +213,7 @@ const CarouselPrevious = React.forwardRef<
  className={cn(
    "absolute h-[52px] w-[52px] rounded-full",
    orientation === "horizontal"
      ? "right-24 top-0"
      ? "right-20 top-0"
      : "-top-12 left-1/2 -translate-x-1/2 rotate-90",
    className,
  )}

@@ -24,6 +24,7 @@ import { useToast } from "@/components/ui/use-toast";
import { InputItem } from "@/components/RunnerUIWrapper";
import { GraphMeta } from "@/lib/autogpt-server-api";
import { default as NextLink } from "next/link";
import { useOnboarding } from "@/components/onboarding/onboarding-provider";

const ajv = new Ajv({ strict: false, allErrors: true });

@@ -77,6 +78,7 @@ export default function useAgentGraph(
    useState(false);
  const [nodes, setNodes] = useState<CustomNode[]>([]);
  const [edges, setEdges] = useState<CustomEdge[]>([]);
  const { state, completeStep } = useOnboarding();

  const api = useMemo(
    () => new BackendAPI(process.env.NEXT_PUBLIC_AGPT_SERVER_URL!),
@@ -576,6 +578,9 @@ export default function useAgentGraph(
        path.set("flowVersion", savedAgent.version.toString());
        path.set("flowExecutionID", graphExecution.graph_exec_id);
        router.push(`${pathname}?${path.toString()}`);
        // Only credit BUILDER_RUN_AGENT once the agent has been saved first
        if (state?.completedSteps.includes("BUILDER_SAVE_AGENT")) {
          completeStep("BUILDER_RUN_AGENT");
        }
      })
      .catch((error) => {
        const errorMessage =
@@ -966,6 +971,7 @@ export default function useAgentGraph(
  const saveAgent = useCallback(async () => {
    try {
      await _saveAgent();
      completeStep("BUILDER_SAVE_AGENT");
    } catch (error) {
      const errorMessage =
        error instanceof Error ? error.message : String(error);

@@ -180,7 +180,9 @@ export default class BackendAPI {
    return this._get("/onboarding");
  }

  updateUserOnboarding(onboarding: Partial<UserOnboarding>): Promise<void> {
  updateUserOnboarding(
    onboarding: Omit<Partial<UserOnboarding>, "rewardedFor">,
  ): Promise<void> {
    return this._request("PATCH", "/onboarding", onboarding);
  }

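The narrowed signature makes `rewardedFor` unpatchable from the client; rewards are presumably granted server-side only. The type arithmetic, spelled out (the alias name is hypothetical):

type OnboardingPatch = Omit<Partial<UserOnboarding>, "rewardedFor">;

const ok: OnboardingPatch = { notified: ["CONGRATS"] };
// const rejected: OnboardingPatch = { rewardedFor: ["CONGRATS"] }; // compile error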
@@ -801,15 +801,26 @@ export type OnboardingStep =
  | "AGENT_CHOICE"
  | "AGENT_NEW_RUN"
  | "AGENT_INPUT"
  | "CONGRATS";
  | "CONGRATS"
  | "GET_RESULTS"
  | "MARKETPLACE_VISIT"
  | "MARKETPLACE_ADD_AGENT"
  | "MARKETPLACE_RUN_AGENT"
  | "BUILDER_OPEN"
  | "BUILDER_SAVE_AGENT"
  | "BUILDER_RUN_AGENT";

export interface UserOnboarding {
  completedSteps: OnboardingStep[];
  notificationDot: boolean;
  notified: OnboardingStep[];
  rewardedFor: OnboardingStep[];
  usageReason: string | null;
  integrations: string[];
  otherIntegrations: string | null;
  selectedStoreListingVersionId: string | null;
  agentInput: { [key: string]: string } | null;
  agentInput: { [key: string]: string | number } | null;
  onboardingAgentExecutionId: GraphExecutionID | null;
}

/* *** UTILITIES *** */

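Given these fields, onboarding progress for the wallet UI can be derived entirely client-side. A hypothetical helper, not part of the diff:

function onboardingProgress(
  onboarding: UserOnboarding,
  steps: OnboardingStep[],
): { done: number; total: number } {
  // Count how many of the given steps the user has completed
  const done = steps.filter((s) =>
    onboarding.completedSteps.includes(s),
  ).length;
  return { done, total: steps.length };
}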
@@ -5026,11 +5026,6 @@ caniuse-lite@^1.0.30001579, caniuse-lite@^1.0.30001688:
  resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001697.tgz#040bbbb54463c4b4b3377c716b34a322d16e6fc7"
  integrity sha512-GwNPlWJin8E+d7Gxq96jxM6w0w+VFeyyXRsjU58emtkYqnbwHqXm5uT2uCmO0RQE9htWknOP4xtBlLmM/gWxvQ==

canvas-confetti@^1.9.3:
  version "1.9.3"
  resolved "https://registry.yarnpkg.com/canvas-confetti/-/canvas-confetti-1.9.3.tgz#ef4c857420ad8045ab4abe8547261c8cdf229845"
  integrity sha512-rFfTURMvmVEX1gyXFgn5QMn81bYk70qa0HLzcIOSVEyl57n6o9ItHeBtUSWdvKAPY0xlvBHno4/v3QPrT83q9g==

case-sensitive-paths-webpack-plugin@^2.4.0:
  version "2.4.0"
  resolved "https://registry.yarnpkg.com/case-sensitive-paths-webpack-plugin/-/case-sensitive-paths-webpack-plugin-2.4.0.tgz#db64066c6422eed2e08cc14b986ca43796dbc6d4"
@@ -9590,6 +9585,11 @@ parse-passwd@^1.0.0:
  resolved "https://registry.yarnpkg.com/parse-passwd/-/parse-passwd-1.0.0.tgz#6d5b934a456993b23d37f40a382d6f1666a8e5c6"
  integrity sha512-1Y1A//QUXEZK7YKz+rD9WydcE1+EuPr6ZBgKecAB8tmoW6UFv0NREVJe1p+jRxtThkcbbKkfwIbWJe/IeE6m2Q==

party-js@^2.2.0:
  version "2.2.0"
  resolved "https://registry.yarnpkg.com/party-js/-/party-js-2.2.0.tgz#3340026971c9e62fd34db102daaa645fbc9130b8"
  integrity sha512-50hGuALCpvDTrQLPQ1fgUgxKIWAH28ShVkmeK/3zhO0YJyCqkhrZhQEkWPxDYLvbFJ7YAXyROmFEu35gKpZLtQ==

pascal-case@^3.1.2:
  version "3.1.2"
  resolved "https://registry.yarnpkg.com/pascal-case/-/pascal-case-3.1.2.tgz#b48e0ef2b98e205e7c1dae747d0b1508237660eb"