Merge branch 'dev' into abhi-9708/add-better-skeleton-on-agent-run-page

This commit is contained in:
Abhimanyu Yadav
2025-04-21 09:21:07 +05:30
committed by GitHub
120 changed files with 3506 additions and 1546 deletions

View File

@@ -34,6 +34,7 @@ jobs:
python -m prisma migrate deploy
env:
DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}
trigger:

View File

@@ -36,6 +36,7 @@ jobs:
python -m prisma migrate deploy
env:
DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}
trigger:
needs: migrate

View File

@@ -135,6 +135,7 @@ jobs:
run: poetry run prisma migrate dev --name updates
env:
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
- id: lint
name: Run Linter
@@ -151,12 +152,13 @@ jobs:
env:
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
REDIS_HOST: 'localhost'
REDIS_PORT: '6379'
REDIS_PASSWORD: 'testpassword'
REDIS_HOST: "localhost"
REDIS_PORT: "6379"
REDIS_PASSWORD: "testpassword"
env:
CI: true
@@ -169,8 +171,8 @@ jobs:
# If you want to replace this, you can do so by making our entire system generate
# new credentials for each local user and update the environment variables in
# the backend service, docker composes, and examples
RABBITMQ_DEFAULT_USER: 'rabbitmq_user_default'
RABBITMQ_DEFAULT_PASS: 'k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7'
RABBITMQ_DEFAULT_USER: "rabbitmq_user_default"
RABBITMQ_DEFAULT_PASS: "k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7"
# - name: Upload coverage reports to Codecov
# uses: codecov/codecov-action@v4

View File

@@ -1,20 +1,59 @@
import inspect
import threading
from typing import Callable, ParamSpec, TypeVar
from typing import Awaitable, Callable, ParamSpec, TypeVar, cast, overload
P = ParamSpec("P")
R = TypeVar("R")
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
@overload
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...
@overload
def thread_cached(func: Callable[P, R]) -> Callable[P, R]: ...
def thread_cached(
func: Callable[P, R] | Callable[P, Awaitable[R]],
) -> Callable[P, R] | Callable[P, Awaitable[R]]:
thread_local = threading.local()
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = (args, tuple(sorted(kwargs.items())))
if key not in cache:
cache[key] = func(*args, **kwargs)
return cache[key]
def _clear():
if hasattr(thread_local, "cache"):
del thread_local.cache
return wrapper
if inspect.iscoroutinefunction(func):
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = (args, tuple(sorted(kwargs.items())))
if key not in cache:
cache[key] = await cast(Callable[P, Awaitable[R]], func)(
*args, **kwargs
)
return cache[key]
setattr(async_wrapper, "clear_cache", _clear)
return async_wrapper
else:
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = (args, tuple(sorted(kwargs.items())))
if key not in cache:
cache[key] = func(*args, **kwargs)
return cache[key]
setattr(sync_wrapper, "clear_cache", _clear)
return sync_wrapper
def clear_thread_cache(func: Callable) -> None:
if clear := getattr(func, "clear_cache", None):
clear()

View File

@@ -8,6 +8,7 @@ DB_CONNECT_TIMEOUT=60
DB_POOL_TIMEOUT=300
DB_SCHEMA=platform
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
PRISMA_SCHEMA="postgres/schema.prisma"
# EXECUTOR

View File

@@ -73,7 +73,6 @@ FROM server_dependencies AS server
COPY autogpt_platform/backend /app/autogpt_platform/backend
RUN poetry install --no-ansi --only-root
ENV DATABASE_URL=""
ENV PORT=8000
CMD ["poetry", "run", "rest"]

View File

@@ -1,8 +1,6 @@
import logging
from typing import Any
from autogpt_libs.utils.cache import thread_cached
from backend.data.block import (
Block,
BlockCategory,
@@ -19,21 +17,6 @@ from backend.util import json
logger = logging.getLogger(__name__)
@thread_cached
def get_executor_manager_client():
from backend.executor import ExecutionManager
from backend.util.service import get_service_client
return get_service_client(ExecutionManager)
@thread_cached
def get_event_bus():
from backend.data.execution import RedisExecutionEventBus
return RedisExecutionEventBus()
class AgentExecutorBlock(Block):
class Input(BlockSchema):
user_id: str = SchemaField(description="User ID")
@@ -76,23 +59,23 @@ class AgentExecutorBlock(Block):
def run(self, input_data: Input, **kwargs) -> BlockOutput:
from backend.data.execution import ExecutionEventType
from backend.executor import utils as execution_utils
executor_manager = get_executor_manager_client()
event_bus = get_event_bus()
event_bus = execution_utils.get_execution_event_bus()
graph_exec = executor_manager.add_execution(
graph_exec = execution_utils.add_graph_execution(
graph_id=input_data.graph_id,
graph_version=input_data.graph_version,
user_id=input_data.user_id,
data=input_data.data,
inputs=input_data.data,
)
log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.graph_exec_id}"
log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.id}"
logger.info(f"Starting execution of {log_id}")
for event in event_bus.listen(
user_id=graph_exec.user_id,
graph_id=graph_exec.graph_id,
graph_exec_id=graph_exec.graph_exec_id,
graph_exec_id=graph_exec.id,
):
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
if event.status in [
@@ -105,7 +88,7 @@ class AgentExecutorBlock(Block):
else:
continue
logger.info(
logger.debug(
f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
)
@@ -123,5 +106,7 @@ class AgentExecutorBlock(Block):
continue
for output_data in event.output_data.get("output", []):
logger.info(f"Execution {log_id} produced {output_name}: {output_data}")
logger.debug(
f"Execution {log_id} produced {output_name}: {output_data}"
)
yield output_name, output_data

View File

@@ -9,7 +9,6 @@ from typing import Any, Iterable, List, Literal, NamedTuple, Optional
import anthropic
import ollama
import openai
from anthropic import NotGiven
from anthropic.types import ToolParam
from groq import Groq
from pydantic import BaseModel, SecretStr
@@ -249,7 +248,7 @@ class LLMResponse(BaseModel):
def convert_openai_tool_fmt_to_anthropic(
openai_tools: list[dict] | None = None,
) -> Iterable[ToolParam] | NotGiven:
) -> Iterable[ToolParam] | anthropic.NotGiven:
"""
Convert OpenAI tool format to Anthropic tool format.
"""
@@ -287,6 +286,7 @@ def llm_call(
max_tokens: int | None,
tools: list[dict] | None = None,
ollama_host: str = "localhost:11434",
parallel_tool_calls: bool | None = None,
) -> LLMResponse:
"""
Make a call to a language model.
@@ -332,6 +332,9 @@ def llm_call(
response_format=response_format, # type: ignore
max_completion_tokens=max_tokens,
tools=tools_param, # type: ignore
parallel_tool_calls=(
openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
),
)
if response.choices[0].message.tool_calls:
@@ -487,6 +490,9 @@ def llm_call(
messages=prompt, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore
parallel_tool_calls=(
openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
),
)
# If there's no response, raise an error

View File

@@ -491,6 +491,7 @@ class SmartDecisionMakerBlock(Block):
max_tokens=input_data.max_tokens,
tools=tool_functions,
ollama_host=input_data.ollama_host,
parallel_tool_calls=False,
)
if not response.tool_calls:

View File

@@ -28,6 +28,7 @@ from backend.util.settings import Config
from .model import (
ContributorDetails,
Credentials,
CredentialsFieldInfo,
CredentialsMetaInput,
is_credentials_field_name,
)
@@ -203,6 +204,15 @@ class BlockSchema(BaseModel):
)
}
@classmethod
def get_credentials_fields_info(cls) -> dict[str, CredentialsFieldInfo]:
return {
field_name: CredentialsFieldInfo.model_validate(
cls.get_field_schema(field_name), by_alias=True
)
for field_name in cls.get_credentials_fields().keys()
}
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
return data # Return as is, by default.
@@ -509,6 +519,7 @@ async def initialize_blocks() -> None:
)
def get_block(block_id: str) -> Block | None:
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
def get_block(block_id: str) -> Block[BlockSchema, BlockSchema] | None:
cls = get_blocks().get(block_id)
return cls() if cls else None

View File

@@ -11,6 +11,7 @@ from prisma.enums import (
CreditRefundRequestStatus,
CreditTransactionType,
NotificationType,
OnboardingStep,
)
from prisma.errors import UniqueViolationError
from prisma.models import CreditRefundRequest, CreditTransaction, User
@@ -121,6 +122,18 @@ class UserCreditBase(ABC):
"""
pass
@abstractmethod
async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
"""
Reward the user with credits for completing an onboarding step.
Won't reward if the user has already received credits for the step.
Args:
user_id (str): The user ID.
step (OnboardingStep): The onboarding step.
"""
pass
@abstractmethod
async def top_up_intent(self, user_id: str, amount: int) -> str:
"""
@@ -213,7 +226,7 @@ class UserCreditBase(ABC):
"userId": user_id,
"createdAt": {"lte": top_time},
"isActive": True,
"runningBalance": {"not": None}, # type: ignore
"NOT": [{"runningBalance": None}],
},
order={"createdAt": "desc"},
)
@@ -408,6 +421,24 @@ class UserCredit(UserCreditBase):
async def top_up_credits(self, user_id: str, amount: int):
await self._top_up_credits(user_id, amount)
async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
key = f"REWARD-{user_id}-{step.value}"
if not await CreditTransaction.prisma().find_first(
where={
"userId": user_id,
"transactionKey": key,
}
):
await self._add_transaction(
user_id=user_id,
amount=credits,
transaction_type=CreditTransactionType.GRANT,
transaction_key=key,
metadata=Json(
{"reason": f"Reward for completing {step.value} onboarding step."}
),
)
async def top_up_refund(
self, user_id: str, transaction_key: str, metadata: dict[str, str]
) -> int:
@@ -895,6 +926,9 @@ class DisabledUserCredit(UserCreditBase):
async def top_up_credits(self, *args, **kwargs):
pass
async def onboarding_reward(self, *args, **kwargs):
pass
async def top_up_intent(self, *args, **kwargs) -> str:
return ""

View File

@@ -62,10 +62,10 @@ async def connect():
# Connection acquired from a pool like Supabase somehow still possibly allows
# the db client obtains a connection but still reject query connection afterward.
try:
await prisma.execute_raw("SELECT 1")
except Exception as e:
raise ConnectionError("Failed to connect to Prisma.") from e
# try:
# await prisma.execute_raw("SELECT 1")
# except Exception as e:
# raise ConnectionError("Failed to connect to Prisma.") from e
@conn_retry("Prisma", "Releasing connection")
@@ -89,7 +89,7 @@ async def transaction():
async def locked_transaction(key: str):
lock_key = zlib.crc32(key.encode("utf-8"))
async with transaction() as tx:
await tx.execute_raw(f"SELECT pg_advisory_xact_lock({lock_key})")
await tx.execute_raw("SELECT pg_advisory_xact_lock($1)", lock_key)
yield tx

View File

@@ -23,6 +23,7 @@ from prisma.models import (
AgentNodeExecutionInputOutput,
)
from prisma.types import (
AgentGraphExecutionCreateInput,
AgentGraphExecutionWhereInput,
AgentNodeExecutionCreateInput,
AgentNodeExecutionInputOutputCreateInput,
@@ -33,18 +34,17 @@ from pydantic import BaseModel
from pydantic.fields import Field
from backend.server.v2.store.exceptions import DatabaseError
from backend.util import mock
from backend.util import type as type_utils
from backend.util.settings import Config
from .block import BlockData, BlockInput, BlockType, CompletedBlockOutput, get_block
from .block import BlockInput, BlockType, CompletedBlockOutput, get_block
from .db import BaseDbModel
from .includes import (
EXECUTION_RESULT_INCLUDE,
GRAPH_EXECUTION_INCLUDE,
GRAPH_EXECUTION_INCLUDE_WITH_NODES,
)
from .model import GraphExecutionStats, NodeExecutionStats
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
from .queue import AsyncRedisEventBus, RedisEventBus
T = TypeVar("T")
@@ -121,7 +121,7 @@ class GraphExecution(GraphExecutionMeta):
@staticmethod
def from_db(_graph_exec: AgentGraphExecution):
if _graph_exec.AgentNodeExecutions is None:
if _graph_exec.NodeExecutions is None:
raise ValueError("Node executions must be included in query")
graph_exec = GraphExecutionMeta.from_db(_graph_exec)
@@ -129,7 +129,7 @@ class GraphExecution(GraphExecutionMeta):
complete_node_executions = sorted(
[
NodeExecutionResult.from_db(ne, _graph_exec.userId)
for ne in _graph_exec.AgentNodeExecutions
for ne in _graph_exec.NodeExecutions
if ne.executionStatus != ExecutionStatus.INCOMPLETE
],
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
@@ -181,7 +181,7 @@ class GraphExecutionWithNodes(GraphExecution):
@staticmethod
def from_db(_graph_exec: AgentGraphExecution):
if _graph_exec.AgentNodeExecutions is None:
if _graph_exec.NodeExecutions is None:
raise ValueError("Node executions must be included in query")
graph_exec_with_io = GraphExecution.from_db(_graph_exec)
@@ -189,7 +189,7 @@ class GraphExecutionWithNodes(GraphExecution):
node_executions = sorted(
[
NodeExecutionResult.from_db(ne, _graph_exec.userId)
for ne in _graph_exec.AgentNodeExecutions
for ne in _graph_exec.NodeExecutions
],
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
)
@@ -202,6 +202,27 @@ class GraphExecutionWithNodes(GraphExecution):
node_executions=node_executions,
)
def to_graph_execution_entry(self):
return GraphExecutionEntry(
user_id=self.user_id,
graph_id=self.graph_id,
graph_version=self.graph_version or 0,
graph_exec_id=self.id,
start_node_execs=[
NodeExecutionEntry(
user_id=self.user_id,
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
node_exec_id=node_exec.node_exec_id,
node_id=node_exec.node_id,
block_id=node_exec.block_id,
data=node_exec.input_data,
)
for node_exec in self.node_executions
],
node_credentials_input_map={}, # FIXME
)
class NodeExecutionResult(BaseModel):
user_id: str
@@ -220,21 +241,21 @@ class NodeExecutionResult(BaseModel):
end_time: datetime | None
@staticmethod
def from_db(execution: AgentNodeExecution, user_id: Optional[str] = None):
if execution.executionData:
def from_db(_node_exec: AgentNodeExecution, user_id: Optional[str] = None):
if _node_exec.executionData:
# Execution that has been queued for execution will persist its data.
input_data = type_utils.convert(execution.executionData, dict[str, Any])
input_data = type_utils.convert(_node_exec.executionData, dict[str, Any])
else:
# For incomplete execution, executionData will not be yet available.
input_data: BlockInput = defaultdict()
for data in execution.Input or []:
for data in _node_exec.Input or []:
input_data[data.name] = type_utils.convert(data.data, type[Any])
output_data: CompletedBlockOutput = defaultdict(list)
for data in execution.Output or []:
for data in _node_exec.Output or []:
output_data[data.name].append(type_utils.convert(data.data, type[Any]))
graph_execution: AgentGraphExecution | None = execution.AgentGraphExecution
graph_execution: AgentGraphExecution | None = _node_exec.GraphExecution
if graph_execution:
user_id = graph_execution.userId
elif not user_id:
@@ -246,17 +267,17 @@ class NodeExecutionResult(BaseModel):
user_id=user_id,
graph_id=graph_execution.agentGraphId if graph_execution else "",
graph_version=graph_execution.agentGraphVersion if graph_execution else 0,
graph_exec_id=execution.agentGraphExecutionId,
block_id=execution.AgentNode.agentBlockId if execution.AgentNode else "",
node_exec_id=execution.id,
node_id=execution.agentNodeId,
status=execution.executionStatus,
graph_exec_id=_node_exec.agentGraphExecutionId,
block_id=_node_exec.Node.agentBlockId if _node_exec.Node else "",
node_exec_id=_node_exec.id,
node_id=_node_exec.agentNodeId,
status=_node_exec.executionStatus,
input_data=input_data,
output_data=output_data,
add_time=execution.addedTime,
queue_time=execution.queuedTime,
start_time=execution.startedTime,
end_time=execution.endedTime,
add_time=_node_exec.addedTime,
queue_time=_node_exec.queuedTime,
start_time=_node_exec.startedTime,
end_time=_node_exec.endedTime,
)
@@ -341,7 +362,7 @@ async def get_graph_execution(
async def create_graph_execution(
graph_id: str,
graph_version: int,
nodes_input: list[tuple[str, BlockInput]],
starting_nodes_input: list[tuple[str, BlockInput]],
user_id: str,
preset_id: str | None = None,
) -> GraphExecutionWithNodes:
@@ -351,29 +372,29 @@ async def create_graph_execution(
The id of the AgentGraphExecution and the list of ExecutionResult for each node.
"""
result = await AgentGraphExecution.prisma().create(
data={
"agentGraphId": graph_id,
"agentGraphVersion": graph_version,
"executionStatus": ExecutionStatus.QUEUED,
"AgentNodeExecutions": {
"create": [ # type: ignore
{
"agentNodeId": node_id,
"executionStatus": ExecutionStatus.QUEUED,
"queuedTime": datetime.now(tz=timezone.utc),
"Input": {
data=AgentGraphExecutionCreateInput(
agentGraphId=graph_id,
agentGraphVersion=graph_version,
executionStatus=ExecutionStatus.QUEUED,
NodeExecutions={
"create": [
AgentNodeExecutionCreateInput(
agentNodeId=node_id,
executionStatus=ExecutionStatus.QUEUED,
queuedTime=datetime.now(tz=timezone.utc),
Input={
"create": [
{"name": name, "data": Json(data)}
for name, data in node_input.items()
]
},
}
for node_id, node_input in nodes_input
)
for node_id, node_input in starting_nodes_input
]
},
"userId": user_id,
"agentPresetId": preset_id,
},
userId=user_id,
agentPresetId=preset_id,
),
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
)
@@ -468,19 +489,27 @@ async def upsert_execution_output(
)
async def update_graph_execution_start_time(graph_exec_id: str) -> GraphExecution:
res = await AgentGraphExecution.prisma().update(
where={"id": graph_exec_id},
async def update_graph_execution_start_time(
graph_exec_id: str,
) -> GraphExecution | None:
count = await AgentGraphExecution.prisma().update_many(
where={
"id": graph_exec_id,
"executionStatus": ExecutionStatus.QUEUED,
},
data={
"executionStatus": ExecutionStatus.RUNNING,
"startedAt": datetime.now(tz=timezone.utc),
},
)
if count == 0:
return None
res = await AgentGraphExecution.prisma().find_unique(
where={"id": graph_exec_id},
include=GRAPH_EXECUTION_INCLUDE,
)
if not res:
raise ValueError(f"Graph execution #{graph_exec_id} not found")
return GraphExecution.from_db(res)
return GraphExecution.from_db(res) if res else None
async def update_graph_execution_stats(
@@ -491,7 +520,8 @@ async def update_graph_execution_stats(
data = stats.model_dump() if stats else {}
if isinstance(data.get("error"), Exception):
data["error"] = str(data["error"])
res = await AgentGraphExecution.prisma().update(
updated_count = await AgentGraphExecution.prisma().update_many(
where={
"id": graph_exec_id,
"OR": [
@@ -503,10 +533,15 @@ async def update_graph_execution_stats(
"executionStatus": status,
"stats": Json(data),
},
)
if updated_count == 0:
return None
graph_exec = await AgentGraphExecution.prisma().find_unique_or_raise(
where={"id": graph_exec_id},
include=GRAPH_EXECUTION_INCLUDE,
)
return GraphExecution.from_db(res) if res else None
return GraphExecution.from_db(graph_exec)
async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionStats):
@@ -600,7 +635,7 @@ async def get_node_execution_results(
"agentGraphExecutionId": graph_exec_id,
}
if block_ids:
where_clause["AgentNode"] = {"is": {"agentBlockId": {"in": block_ids}}}
where_clause["Node"] = {"is": {"agentBlockId": {"in": block_ids}}}
if statuses:
where_clause["OR"] = [{"executionStatus": status} for status in statuses]
@@ -642,7 +677,7 @@ async def get_latest_node_execution(
where={
"agentNodeId": node_id,
"agentGraphExecutionId": graph_eid,
"executionStatus": {"not": ExecutionStatus.INCOMPLETE}, # type: ignore
"NOT": [{"executionStatus": ExecutionStatus.INCOMPLETE}],
},
order=[
{"queuedTime": "desc"},
@@ -678,6 +713,7 @@ class GraphExecutionEntry(BaseModel):
graph_id: str
graph_version: int
start_node_execs: list["NodeExecutionEntry"]
node_credentials_input_map: Optional[dict[str, dict[str, CredentialsMetaInput]]]
class NodeExecutionEntry(BaseModel):
@@ -710,144 +746,6 @@ class ExecutionQueue(Generic[T]):
return self.queue.empty()
# ------------------- Execution Utilities -------------------- #
LIST_SPLIT = "_$_"
DICT_SPLIT = "_#_"
OBJC_SPLIT = "_@_"
def parse_execution_output(output: BlockData, name: str) -> Any | None:
"""
Extracts partial output data by name from a given BlockData.
The function supports extracting data from lists, dictionaries, and objects
using specific naming conventions:
- For lists: <output_name>_$_<index>
- For dictionaries: <output_name>_#_<key>
- For objects: <output_name>_@_<attribute>
Args:
output (BlockData): A tuple containing the output name and data.
name (str): The name used to extract specific data from the output.
Returns:
Any | None: The extracted data if found, otherwise None.
Examples:
>>> output = ("result", [10, 20, 30])
>>> parse_execution_output(output, "result_$_1")
20
>>> output = ("config", {"key1": "value1", "key2": "value2"})
>>> parse_execution_output(output, "config_#_key1")
'value1'
>>> class Sample:
... attr1 = "value1"
... attr2 = "value2"
>>> output = ("object", Sample())
>>> parse_execution_output(output, "object_@_attr1")
'value1'
"""
output_name, output_data = output
if name == output_name:
return output_data
if name.startswith(f"{output_name}{LIST_SPLIT}"):
index = int(name.split(LIST_SPLIT)[1])
if not isinstance(output_data, list) or len(output_data) <= index:
return None
return output_data[int(name.split(LIST_SPLIT)[1])]
if name.startswith(f"{output_name}{DICT_SPLIT}"):
index = name.split(DICT_SPLIT)[1]
if not isinstance(output_data, dict) or index not in output_data:
return None
return output_data[index]
if name.startswith(f"{output_name}{OBJC_SPLIT}"):
index = name.split(OBJC_SPLIT)[1]
if isinstance(output_data, object) and hasattr(output_data, index):
return getattr(output_data, index)
return None
return None
def merge_execution_input(data: BlockInput) -> BlockInput:
"""
Merges dynamic input pins into a single list, dictionary, or object based on naming patterns.
This function processes input keys that follow specific patterns to merge them into a unified structure:
- `<input_name>_$_<index>` for list inputs.
- `<input_name>_#_<index>` for dictionary inputs.
- `<input_name>_@_<index>` for object inputs.
Args:
data (BlockInput): A dictionary containing input keys and their corresponding values.
Returns:
BlockInput: A dictionary with merged inputs.
Raises:
ValueError: If a list index is not an integer.
Examples:
>>> data = {
... "list_$_0": "a",
... "list_$_1": "b",
... "dict_#_key1": "value1",
... "dict_#_key2": "value2",
... "object_@_attr1": "value1",
... "object_@_attr2": "value2"
... }
>>> merge_execution_input(data)
{
"list": ["a", "b"],
"dict": {"key1": "value1", "key2": "value2"},
"object": <MockObject attr1="value1" attr2="value2">
}
"""
# Merge all input with <input_name>_$_<index> into a single list.
items = list(data.items())
for key, value in items:
if LIST_SPLIT not in key:
continue
name, index = key.split(LIST_SPLIT)
if not index.isdigit():
raise ValueError(f"Invalid key: {key}, #{index} index must be an integer.")
data[name] = data.get(name, [])
if int(index) >= len(data[name]):
# Pad list with empty string on missing indices.
data[name].extend([""] * (int(index) - len(data[name]) + 1))
data[name][int(index)] = value
# Merge all input with <input_name>_#_<index> into a single dict.
for key, value in items:
if DICT_SPLIT not in key:
continue
name, index = key.split(DICT_SPLIT)
data[name] = data.get(name, {})
data[name][index] = value
# Merge all input with <input_name>_@_<index> into a single object.
for key, value in items:
if OBJC_SPLIT not in key:
continue
name, index = key.split(OBJC_SPLIT)
if name not in data or not isinstance(data[name], object):
data[name] = mock.MockObject()
setattr(data[name], index, value)
return data
# --------------------- Event Bus --------------------- #

View File

@@ -1,7 +1,7 @@
import logging
import uuid
from collections import defaultdict
from typing import Any, Literal, Optional, Type
from typing import Any, Literal, Optional, cast
import prisma
from prisma import Json
@@ -13,12 +13,19 @@ from prisma.types import (
AgentNodeCreateInput,
AgentNodeLinkCreateInput,
)
from pydantic import create_model
from pydantic.fields import computed_field
from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.blocks.llm import LlmModel
from backend.data.db import prisma as db
from backend.data.model import (
CredentialsField,
CredentialsFieldInfo,
CredentialsMetaInput,
is_credentials_field_name,
)
from backend.util import type as type_utils
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
@@ -190,14 +197,19 @@ class BaseGraph(BaseDbModel):
)
)
@computed_field
@property
def credentials_input_schema(self) -> dict[str, Any]:
return self._credentials_input_schema.jsonschema()
@staticmethod
def _generate_schema(
*props: tuple[Type[AgentInputBlock.Input] | Type[AgentOutputBlock.Input], dict],
*props: tuple[type[AgentInputBlock.Input] | type[AgentOutputBlock.Input], dict],
) -> dict[str, Any]:
schema = []
schema_fields: list[AgentInputBlock.Input | AgentOutputBlock.Input] = []
for type_class, input_default in props:
try:
schema.append(type_class(**input_default))
schema_fields.append(type_class(**input_default))
except Exception as e:
logger.warning(f"Invalid {type_class}: {input_default}, {e}")
@@ -217,9 +229,93 @@ class BaseGraph(BaseDbModel):
**({"description": p.description} if p.description else {}),
**({"default": p.value} if p.value is not None else {}),
}
for p in schema
for p in schema_fields
},
"required": [p.name for p in schema if p.value is None],
"required": [p.name for p in schema_fields if p.value is None],
}
@property
def _credentials_input_schema(self) -> type[BlockSchema]:
graph_credentials_inputs = self.aggregate_credentials_inputs()
logger.debug(
f"Combined credentials input fields for graph #{self.id} ({self.name}): "
f"{graph_credentials_inputs}"
)
# Warn if same-provider credentials inputs can't be combined (= bad UX)
graph_cred_fields = list(graph_credentials_inputs.values())
for i, (field, keys) in enumerate(graph_cred_fields):
for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
if field.provider != other_field.provider:
continue
# If this happens, that means a block implementation probably needs
# to be updated.
logger.warning(
"Multiple combined credentials fields "
f"for provider {field.provider} "
f"on graph #{self.id} ({self.name}); "
f"fields: {field} <> {other_field};"
f"keys: {keys} <> {other_keys}."
)
fields: dict[str, tuple[type[CredentialsMetaInput], CredentialsMetaInput]] = {
agg_field_key: (
CredentialsMetaInput[
Literal[tuple(field_info.provider)], # type: ignore
Literal[tuple(field_info.supported_types)], # type: ignore
],
CredentialsField(
required_scopes=set(field_info.required_scopes or []),
discriminator=field_info.discriminator,
discriminator_mapping=field_info.discriminator_mapping,
),
)
for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
}
return create_model(
self.name.replace(" ", "") + "CredentialsInputSchema",
__base__=BlockSchema,
**fields, # type: ignore
)
def aggregate_credentials_inputs(
self,
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]]]]:
"""
Returns:
dict[aggregated_field_key, tuple(
CredentialsFieldInfo: A spec for one aggregated credentials field
set[(node_id, field_name)]: Node credentials fields that are
compatible with this aggregated field spec
)]
"""
return {
"_".join(sorted(agg_field_info.provider))
+ "_"
+ "_".join(sorted(agg_field_info.supported_types))
+ "_credentials": (agg_field_info, node_fields)
for agg_field_info, node_fields in CredentialsFieldInfo.combine(
*(
(
# Apply discrimination before aggregating credentials inputs
(
field_info.discriminate(
node.input_default[field_info.discriminator]
)
if (
field_info.discriminator
and node.input_default.get(field_info.discriminator)
)
else field_info
),
(node.id, field_name),
)
for node in self.nodes
for field_name, field_info in node.block.input_schema.get_credentials_fields_info().items()
)
)
}
@@ -320,8 +416,6 @@ class GraphModel(Graph):
return sanitized_name
# Validate smart decision maker nodes
smart_decision_maker_nodes = set()
agent_nodes = set()
nodes_block = {
node.id: block
for node in graph.nodes
@@ -332,13 +426,6 @@ class GraphModel(Graph):
if (block := nodes_block.get(node.id)) is None:
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
# Smart decision maker nodes
if block.block_type == BlockType.AI:
smart_decision_maker_nodes.add(node.id)
# Agent nodes
elif block.block_type == BlockType.AGENT:
agent_nodes.add(node.id)
input_links = defaultdict(list)
for link in graph.links:
@@ -353,16 +440,21 @@ class GraphModel(Graph):
[sanitize(name) for name in node.input_default]
+ [sanitize(link.sink_name) for link in input_links.get(node.id, [])]
)
for name in block.input_schema.get_required_fields():
input_schema = block.input_schema
for name in (required_fields := input_schema.get_required_fields()):
if (
name not in provided_inputs
# Webhook payload is passed in by ExecutionManager
and not (
name == "payload"
and block.block_type
in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
)
# Checking availability of credentials is done by ExecutionManager
and name not in input_schema.get_credentials_fields()
# Validate only I/O nodes, or validate everything when executing
and (
for_run # Skip input completion validation, unless when executing.
for_run
or block.block_type
in [
BlockType.INPUT,
@@ -375,9 +467,18 @@ class GraphModel(Graph):
f"Node {block.name} #{node.id} required input missing: `{name}`"
)
if (
block.block_type == BlockType.INPUT
and (input_key := node.input_default.get("name"))
and is_credentials_field_name(input_key)
):
raise ValueError(
f"Agent input node uses reserved name '{input_key}'; "
"'credentials' and `*_credentials` are reserved input names"
)
# Get input schema properties and check dependencies
input_schema = block.input_schema.model_fields
required_fields = block.input_schema.get_required_fields()
input_fields = input_schema.model_fields
def has_value(name):
return (
@@ -385,14 +486,21 @@ class GraphModel(Graph):
and name in node.input_default
and node.input_default[name] is not None
and str(node.input_default[name]).strip() != ""
) or (name in input_schema and input_schema[name].default is not None)
) or (name in input_fields and input_fields[name].default is not None)
# Validate dependencies between fields
for field_name, field_info in input_schema.items():
for field_name, field_info in input_fields.items():
# Apply input dependency validation only on run & field with depends_on
json_schema_extra = field_info.json_schema_extra or {}
dependencies = json_schema_extra.get("depends_on", [])
if not for_run or not dependencies:
if not (
for_run
and isinstance(json_schema_extra, dict)
and (
dependencies := cast(
list[str], json_schema_extra.get("depends_on", [])
)
)
):
continue
# Check if dependent field has value in input_default
@@ -465,13 +573,11 @@ class GraphModel(Graph):
is_active=graph.isActive,
name=graph.name or "",
description=graph.description or "",
nodes=[
NodeModel.from_db(node, for_export) for node in graph.AgentNodes or []
],
nodes=[NodeModel.from_db(node, for_export) for node in graph.Nodes or []],
links=list(
{
Link.from_db(link)
for node in graph.AgentNodes or []
for node in graph.Nodes or []
for link in (node.Input or []) + (node.Output or [])
}
),
@@ -602,8 +708,8 @@ async def get_graph(
and not (
await StoreListingVersion.prisma().find_first(
where={
"agentId": graph_id,
"agentVersion": version or graph.version,
"agentGraphId": graph_id,
"agentGraphVersion": version or graph.version,
"isDeleted": False,
"submissionStatus": SubmissionStatus.APPROVED,
}
@@ -637,12 +743,16 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
sub_graph_ids = [
(graph_id, graph_version)
for graph in search_graphs
for node in graph.AgentNodes or []
for node in graph.Nodes or []
if (
node.AgentBlock
and node.AgentBlock.id == agent_block_id
and (graph_id := dict(node.constantInput).get("graph_id"))
and (graph_version := dict(node.constantInput).get("graph_version"))
and (graph_id := cast(str, dict(node.constantInput).get("graph_id")))
and (
graph_version := cast(
int, dict(node.constantInput).get("graph_version")
)
)
)
]
if not sub_graph_ids:
@@ -657,7 +767,7 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
"userId": graph.userId, # Ensure the sub-graph is owned by the same user
}
for graph_id, graph_version in sub_graph_ids
] # type: ignore
]
},
include=AGENT_GRAPH_INCLUDE,
)
@@ -671,7 +781,7 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
async def get_connected_output_nodes(node_id: str) -> list[tuple[Link, Node]]:
links = await AgentNodeLink.prisma().find_many(
where={"agentNodeSourceId": node_id},
include={"AgentNodeSink": {"include": AGENT_NODE_INCLUDE}}, # type: ignore
include={"AgentNodeSink": {"include": AGENT_NODE_INCLUDE}},
)
return [
(Link.from_db(link), NodeModel.from_db(link.AgentNodeSink))
@@ -829,12 +939,12 @@ async def fix_llm_provider_credentials():
SELECT graph."userId" user_id,
node.id node_id,
node."constantInput" node_preset_input
FROM platform."AgentNode" node
LEFT JOIN platform."AgentGraph" graph
ON node."agentGraphId" = graph.id
WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm'
ORDER BY graph."userId";
"""
FROM platform."AgentNode" node
LEFT JOIN platform."AgentGraph" graph
ON node."agentGraphId" = graph.id
WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm'
ORDER BY graph."userId";
"""
)
logger.info(f"Fixing LLM credential inputs on {len(broken_nodes)} nodes")
except Exception as e:
@@ -912,17 +1022,24 @@ async def migrate_llm_models(migrate_to: LlmModel):
if field.annotation == LlmModel:
llm_model_fields[block.id] = field_name
# Convert enum values to a list of strings for the SQL query
enum_values = [v.value for v in LlmModel]
escaped_enum_values = repr(tuple(enum_values)) # hack but works
# Update each block
for id, path in llm_model_fields.items():
# Convert enum values to a list of strings for the SQL query
enum_values = [v.value for v in LlmModel.__members__.values()]
query = f"""
UPDATE "AgentNode"
SET "constantInput" = jsonb_set("constantInput", '{{{path}}}', '"{migrate_to.value}"', true)
WHERE "agentBlockId" = '{id}'
AND "constantInput" ? '{path}'
AND "constantInput"->>'{path}' NOT IN ({','.join(f"'{value}'" for value in enum_values)})
UPDATE platform."AgentNode"
SET "constantInput" = jsonb_set("constantInput", $1, to_jsonb($2), true)
WHERE "agentBlockId" = $3
AND "constantInput" ? ($4)::text
AND "constantInput"->>($4)::text NOT IN {escaped_enum_values}
"""
await db.execute_raw(query)
await db.execute_raw(
query, # type: ignore - is supposed to be LiteralString
[path],
migrate_to.value,
id,
path,
)

View File

@@ -1,3 +1,5 @@
from typing import cast
import prisma.enums
import prisma.types
@@ -11,25 +13,25 @@ AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
}
AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
"Nodes": {"include": AGENT_NODE_INCLUDE}
}
EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
"Input": True,
"Output": True,
"AgentNode": True,
"AgentGraphExecution": True,
"Node": True,
"GraphExecution": True,
}
MAX_NODE_EXECUTIONS_FETCH = 1000
GRAPH_EXECUTION_INCLUDE_WITH_NODES: prisma.types.AgentGraphExecutionInclude = {
"AgentNodeExecutions": {
"NodeExecutions": {
"include": {
"Input": True,
"Output": True,
"AgentNode": True,
"AgentGraphExecution": True,
"Node": True,
"GraphExecution": True,
},
"order_by": [
{"queuedTime": "desc"},
@@ -41,31 +43,30 @@ GRAPH_EXECUTION_INCLUDE_WITH_NODES: prisma.types.AgentGraphExecutionInclude = {
}
GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
"AgentNodeExecutions": {
**GRAPH_EXECUTION_INCLUDE_WITH_NODES["AgentNodeExecutions"], # type: ignore
"NodeExecutions": {
**cast(
prisma.types.FindManyAgentNodeExecutionArgsFromAgentGraphExecution,
GRAPH_EXECUTION_INCLUDE_WITH_NODES["NodeExecutions"],
),
"where": {
"AgentNode": {
"AgentBlock": {"id": {"in": IO_BLOCK_IDs}}, # type: ignore
},
"NOT": {
"executionStatus": prisma.enums.AgentExecutionStatus.INCOMPLETE,
},
"Node": {"is": {"AgentBlock": {"is": {"id": {"in": IO_BLOCK_IDs}}}}},
"NOT": [{"executionStatus": prisma.enums.AgentExecutionStatus.INCOMPLETE}],
},
}
}
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
"AgentNodes": {"include": AGENT_NODE_INCLUDE}
}
def library_agent_include(user_id: str) -> prisma.types.LibraryAgentInclude:
return {
"Agent": {
"AgentGraph": {
"include": {
**AGENT_GRAPH_INCLUDE,
"AgentGraphExecution": {"where": {"userId": user_id}},
"Executions": {"where": {"userId": user_id}},
}
},
"Creator": True,

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import base64
import logging
from collections import defaultdict
from datetime import datetime, timezone
from typing import (
TYPE_CHECKING,
@@ -12,6 +13,7 @@ from typing import (
Generic,
Literal,
Optional,
Sequence,
TypedDict,
TypeVar,
get_args,
@@ -300,9 +302,7 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
)
field_schema = model.jsonschema()["properties"][field_name]
try:
schema_extra = _CredentialsFieldSchemaExtra[CP, CT].model_validate(
field_schema
)
schema_extra = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
except ValidationError as e:
if "Field required [type=missing" not in str(e):
raise
@@ -328,14 +328,90 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
)
class _CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]):
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
# TODO: move discrimination mechanism out of CredentialsField (frontend + backend)
credentials_provider: list[CP]
credentials_scopes: Optional[list[str]] = None
credentials_types: list[CT]
provider: frozenset[CP] = Field(..., alias="credentials_provider")
supported_types: frozenset[CT] = Field(..., alias="credentials_types")
required_scopes: Optional[frozenset[str]] = Field(None, alias="credentials_scopes")
discriminator: Optional[str] = None
discriminator_mapping: Optional[dict[str, CP]] = None
@classmethod
def combine(
cls, *fields: tuple[CredentialsFieldInfo[CP, CT], T]
) -> Sequence[tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
"""
Combines multiple CredentialsFieldInfo objects into as few as possible.
Rules:
- Items can only be combined if they have the same supported credentials types
and the same supported providers.
- When combining items, the `required_scopes` of the result is a join
of the `required_scopes` of the original items.
Params:
*fields: (CredentialsFieldInfo, key) objects to group and combine
Returns:
A sequence of tuples containing combined CredentialsFieldInfo objects and
the set of keys of the respective original items that were grouped together.
"""
if not fields:
return []
# Group fields by their provider and supported_types
grouped_fields: defaultdict[
tuple[frozenset[CP], frozenset[CT]],
list[tuple[T, CredentialsFieldInfo[CP, CT]]],
] = defaultdict(list)
for field, key in fields:
group_key = (frozenset(field.provider), frozenset(field.supported_types))
grouped_fields[group_key].append((key, field))
# Combine fields within each group
result: list[tuple[CredentialsFieldInfo[CP, CT], set[T]]] = []
for group in grouped_fields.values():
# Start with the first field in the group
_, combined = group[0]
# Track the keys that were combined
combined_keys = {key for key, _ in group}
# Combine required_scopes from all fields in the group
all_scopes = set()
for _, field in group:
if field.required_scopes:
all_scopes.update(field.required_scopes)
# Create a new combined field
result.append(
(
CredentialsFieldInfo[CP, CT](
credentials_provider=combined.provider,
credentials_types=combined.supported_types,
credentials_scopes=frozenset(all_scopes) or None,
discriminator=combined.discriminator,
discriminator_mapping=combined.discriminator_mapping,
),
combined_keys,
)
)
return result
def discriminate(self, discriminator_value: Any) -> CredentialsFieldInfo:
if not (self.discriminator and self.discriminator_mapping):
return self
discriminator_value = self.discriminator_mapping[discriminator_value]
return CredentialsFieldInfo(
credentials_provider=frozenset([discriminator_value]),
credentials_types=self.supported_types,
credentials_scopes=self.required_scopes,
)
def CredentialsField(
required_scopes: set[str] = set(),

View File

@@ -6,9 +6,11 @@ import pydantic
from prisma import Json
from prisma.enums import OnboardingStep
from prisma.models import UserOnboarding
from prisma.types import UserOnboardingUpdateInput
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
from backend.data import db
from backend.data.block import get_blocks
from backend.data.credit import get_user_credit_model
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.server.v2.store.model import StoreAgentDetails
@@ -24,21 +26,26 @@ REASON_MAPPING: dict[str, list[str]] = {
POINTS_AGENT_COUNT = 50 # Number of agents to calculate points for
MIN_AGENT_COUNT = 2 # Minimum number of marketplace agents to enable onboarding
user_credit = get_user_credit_model()
class UserOnboardingUpdate(pydantic.BaseModel):
completedSteps: Optional[list[OnboardingStep]] = None
notificationDot: Optional[bool] = None
notified: Optional[list[OnboardingStep]] = None
usageReason: Optional[str] = None
integrations: Optional[list[str]] = None
otherIntegrations: Optional[str] = None
selectedStoreListingVersionId: Optional[str] = None
agentInput: Optional[dict[str, Any]] = None
onboardingAgentExecutionId: Optional[str] = None
async def get_user_onboarding(user_id: str):
return await UserOnboarding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id}, # type: ignore
"create": UserOnboardingCreateInput(userId=user_id),
"update": {},
},
)
@@ -48,6 +55,20 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
update: UserOnboardingUpdateInput = {}
if data.completedSteps is not None:
update["completedSteps"] = list(set(data.completedSteps))
for step in (
OnboardingStep.AGENT_NEW_RUN,
OnboardingStep.GET_RESULTS,
OnboardingStep.MARKETPLACE_ADD_AGENT,
OnboardingStep.MARKETPLACE_RUN_AGENT,
OnboardingStep.BUILDER_SAVE_AGENT,
OnboardingStep.BUILDER_RUN_AGENT,
):
if step in data.completedSteps:
await reward_user(user_id, step)
if data.notificationDot is not None:
update["notificationDot"] = data.notificationDot
if data.notified is not None:
update["notified"] = list(set(data.notified))
if data.usageReason is not None:
update["usageReason"] = data.usageReason
if data.integrations is not None:
@@ -58,16 +79,57 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
update["selectedStoreListingVersionId"] = data.selectedStoreListingVersionId
if data.agentInput is not None:
update["agentInput"] = Json(data.agentInput)
if data.onboardingAgentExecutionId is not None:
update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId
return await UserOnboarding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, **update}, # type: ignore
"create": {"userId": user_id, **update},
"update": update,
},
)
async def reward_user(user_id: str, step: OnboardingStep):
async with db.locked_transaction(f"usr_trx_{user_id}-reward"):
reward = 0
match step:
# Reward user when they clicked New Run during onboarding
# This is because they need credits before scheduling a run (next step)
case OnboardingStep.AGENT_NEW_RUN:
reward = 300
case OnboardingStep.GET_RESULTS:
reward = 300
case OnboardingStep.MARKETPLACE_ADD_AGENT:
reward = 100
case OnboardingStep.MARKETPLACE_RUN_AGENT:
reward = 100
case OnboardingStep.BUILDER_SAVE_AGENT:
reward = 100
case OnboardingStep.BUILDER_RUN_AGENT:
reward = 100
if reward == 0:
return
onboarding = await get_user_onboarding(user_id)
# Skip if already rewarded
if step in onboarding.rewardedFor:
return
onboarding.rewardedFor.append(step)
await user_credit.onboarding_reward(user_id, reward, step)
await UserOnboarding.prisma().update(
where={"userId": user_id},
data={
"completedSteps": list(set(onboarding.completedSteps + [step])),
"rewardedFor": onboarding.rewardedFor,
},
)
def clean_and_split(text: str) -> list[str]:
"""
Removes all special characters from a string, truncates it to 100 characters,
@@ -186,11 +248,11 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
where={
"id": {"in": [agent.storeListingVersionId for agent in storeAgents]},
},
include={"Agent": True},
include={"AgentGraph": True},
)
for listing in agentListings:
agent = listing.Agent
agent = listing.AgentGraph
if agent is None:
continue
graph = GraphModel.from_db(agent)

View File

@@ -4,10 +4,18 @@ from enum import Enum
from typing import Awaitable, Optional
import aio_pika
import aio_pika.exceptions as aio_ex
import pika
import pika.adapters.blocking_connection
from pika.exceptions import AMQPError
from pika.spec import BasicProperties
from pydantic import BaseModel
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from backend.util.retry import conn_retry
from backend.util.settings import Settings
@@ -161,6 +169,12 @@ class SyncRabbitMQ(RabbitMQBase):
routing_key=queue.routing_key or queue.name,
)
@retry(
retry=retry_if_exception_type((AMQPError, ConnectionError)),
wait=wait_random_exponential(multiplier=1, max=5),
stop=stop_after_attempt(5),
reraise=True,
)
def publish_message(
self,
routing_key: str,
@@ -258,6 +272,12 @@ class AsyncRabbitMQ(RabbitMQBase):
exchange, routing_key=queue.routing_key or queue.name
)
@retry(
retry=retry_if_exception_type((aio_ex.AMQPError, ConnectionError)),
wait=wait_random_exponential(multiplier=1, max=5),
stop=stop_after_attempt(5),
reraise=True,
)
async def publish_message(
self,
routing_key: str,

View File

@@ -11,7 +11,7 @@ from fastapi import HTTPException
from prisma import Json
from prisma.enums import NotificationType
from prisma.models import User
from prisma.types import UserCreateInput, UserUpdateInput
from prisma.types import JsonFilter, UserCreateInput, UserUpdateInput
from backend.data.db import prisma
from backend.data.model import UserIntegrations, UserMetadata, UserMetadataRaw
@@ -135,16 +135,21 @@ async def migrate_and_encrypt_user_integrations():
"""Migrate integration credentials and OAuth states from metadata to integrations column."""
users = await User.prisma().find_many(
where={
"metadata": {
"path": ["integration_credentials"],
"not": Json({"a": "yolo"}), # bogus value works to check if key exists
} # type: ignore
"metadata": cast(
JsonFilter,
{
"path": ["integration_credentials"],
"not": Json(
{"a": "yolo"}
), # bogus value works to check if key exists
},
)
}
)
logger.info(f"Migrating integration credentials for {len(users)} users")
for user in users:
raw_metadata = cast(UserMetadataRaw, user.metadata)
raw_metadata = cast(dict, user.metadata)
metadata = UserMetadata.model_validate(raw_metadata)
# Get existing integrations data
@@ -160,7 +165,6 @@ async def migrate_and_encrypt_user_integrations():
await update_user_integrations(user_id=user.id, data=integrations)
# Remove from metadata
raw_metadata = dict(raw_metadata)
raw_metadata.pop("integration_credentials", None)
raw_metadata.pop("integration_oauth_states", None)

View File

@@ -1,11 +1,8 @@
import logging
from backend.data import db, redis
from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
GraphExecution,
NodeExecutionResult,
RedisExecutionEventBus,
create_graph_execution,
get_graph_execution,
get_incomplete_node_executions,
@@ -42,7 +39,7 @@ from backend.data.user import (
update_user_integrations,
update_user_metadata,
)
from backend.util.service import AppService, expose, exposed_run_and_wait
from backend.util.service import AppService, exposed_run_and_wait
from backend.util.settings import Config
config = Config()
@@ -57,21 +54,14 @@ async def _spend_credits(
class DatabaseManager(AppService):
def __init__(self):
super().__init__()
self.execution_event_bus = RedisExecutionEventBus()
def run_service(self) -> None:
logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
self.run_and_wait(db.connect())
logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...")
redis.connect()
super().run_service()
def cleanup(self):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Disconnecting Redis...")
redis.disconnect()
logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
self.run_and_wait(db.disconnect())
@@ -79,12 +69,6 @@ class DatabaseManager(AppService):
def get_port(cls) -> int:
return config.database_api_port
@expose
def send_execution_update(
self, execution_result: GraphExecution | NodeExecutionResult
):
self.execution_event_bus.publish(execution_result)
# Executions
get_graph_execution = exposed_run_and_wait(get_graph_execution)
create_graph_execution = exposed_run_and_wait(create_graph_execution)

View File

@@ -5,11 +5,14 @@ import os
import signal
import sys
import threading
import time
from concurrent.futures import Future, ProcessPoolExecutor
from contextlib import contextmanager
from multiprocessing.pool import AsyncResult, Pool
from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generator, TypeVar, cast
from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic, BasicProperties
from redis.lock import Lock as RedisLock
from backend.blocks.io import AgentOutputBlock
@@ -26,47 +29,40 @@ if TYPE_CHECKING:
from backend.executor import DatabaseManager
from backend.notifications.notifications import NotificationManager
from autogpt_libs.utils.cache import thread_cached
from autogpt_libs.utils.cache import clear_thread_cache, thread_cached
from backend.blocks.agent import AgentExecutorBlock
from backend.data import redis
from backend.data.block import (
Block,
BlockData,
BlockInput,
BlockSchema,
BlockType,
get_block,
)
from backend.data.block import BlockData, BlockInput, BlockSchema, get_block
from backend.data.execution import (
ExecutionQueue,
ExecutionStatus,
GraphExecution,
GraphExecutionEntry,
NodeExecutionEntry,
NodeExecutionResult,
merge_execution_input,
parse_execution_output,
)
from backend.data.graph import GraphModel, Link, Node
from backend.data.graph import Link, Node
from backend.executor.utils import (
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
GRAPH_EXECUTION_QUEUE_NAME,
CancelExecutionEvent,
UsageTransactionMetadata,
block_usage_cost,
execution_usage_cost,
get_execution_event_bus,
get_execution_queue,
parse_execution_output,
validate_exec,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util import json
from backend.util.decorator import error_logged, time_measured
from backend.util.file import clean_exec_files
from backend.util.logging import configure_logging
from backend.util.process import set_service_name
from backend.util.service import (
AppService,
close_service_client,
expose,
get_service_client,
)
from backend.util.process import AppProcess, set_service_name
from backend.util.service import close_service_client, get_service_client
from backend.util.settings import Settings
from backend.util.type import convert
logger = logging.getLogger(__name__)
settings = Settings()
@@ -91,7 +87,7 @@ class LogMetadata:
"node_id": node_id,
"block_name": block_name,
}
self.prefix = f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|nid:{node_eid}|{block_name}]"
self.prefix = f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|neid:{node_eid}|{block_name}]"
def info(self, msg: str, **extra):
msg = self._wrap(msg, **extra)
@@ -152,7 +148,7 @@ def execute_node(
def update_execution_status(status: ExecutionStatus) -> NodeExecutionResult:
"""Sets status and fetches+broadcasts the latest state of the node execution"""
exec_update = db_client.update_node_execution_status(node_exec_id, status)
db_client.send_execution_update(exec_update)
send_execution_update(exec_update)
return exec_update
node = db_client.get_node(node_id)
@@ -192,7 +188,7 @@ def execute_node(
# Execute the node
input_data_str = json.dumps(input_data)
input_size = len(input_data_str)
log_metadata.info("Executed node with input", input=input_data_str)
log_metadata.debug("Executed node with input", input=input_data_str)
update_execution_status(ExecutionStatus.RUNNING)
# Inject extra execution arguments for the blocks via kwargs
@@ -223,7 +219,7 @@ def execute_node(
):
output_data = json.convert_pydantic_to_json(output_data)
output_size += len(json.dumps(output_data))
log_metadata.info("Node produced output", **{output_name: output_data})
log_metadata.debug("Node produced output", **{output_name: output_data})
push_output(output_name, output_data)
outputs[output_name] = output_data
for execution in _enqueue_next_nodes(
@@ -288,7 +284,7 @@ def _enqueue_next_nodes(
exec_update = db_client.update_node_execution_status(
node_exec_id, ExecutionStatus.QUEUED, data
)
db_client.send_execution_update(exec_update)
send_execution_update(exec_update)
return NodeExecutionEntry(
user_id=user_id,
graph_exec_id=graph_exec_id,
@@ -400,60 +396,6 @@ def _enqueue_next_nodes(
]
def validate_exec(
node: Node,
data: BlockInput,
resolve_input: bool = True,
) -> tuple[BlockInput | None, str]:
"""
Validate the input data for a node execution.
Args:
node: The node to execute.
data: The input data for the node execution.
resolve_input: Whether to resolve dynamic pins into dict/list/object.
Returns:
A tuple of the validated data and the block name.
If the data is invalid, the first element will be None, and the second element
will be an error message.
If the data is valid, the first element will be the resolved input data, and
the second element will be the block name.
"""
node_block: Block | None = get_block(node.block_id)
if not node_block:
return None, f"Block for {node.block_id} not found."
schema = node_block.input_schema
# Convert non-matching data types to the expected input schema.
for name, data_type in schema.__annotations__.items():
if (value := data.get(name)) and (type(value) is not data_type):
data[name] = convert(value, data_type)
# Input data (without default values) should contain all required fields.
error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
if missing_links := schema.get_missing_links(data, node.input_links):
return None, f"{error_prefix} unpopulated links {missing_links}"
# Merge input data with default values and resolve dynamic dict/list/object pins.
input_default = schema.get_input_defaults(node.input_default)
data = {**input_default, **data}
if resolve_input:
data = merge_execution_input(data)
# Input data post-merge should contain all required fields from the schema.
if missing_input := schema.get_missing_input(data):
return None, f"{error_prefix} missing input {missing_input}"
# Last validation: Validate the input values against the schema.
if error := schema.get_mismatch_error(data):
error_message = f"{error_prefix} {error}"
logger.error(error_message)
return None, error_message
return data, node_block.name
class Executor:
"""
This class contains event handlers for the process pool executor events.
@@ -633,7 +575,13 @@ class Executor:
exec_meta = cls.db_client.update_graph_execution_start_time(
graph_exec.graph_exec_id
)
cls.db_client.send_execution_update(exec_meta)
if exec_meta is None:
logger.warning(
f"Skipped graph execution {graph_exec.graph_exec_id}, the graph execution is not found or not currently in the QUEUED state."
)
return
send_execution_update(exec_meta)
timing_info, (exec_stats, status, error) = cls._on_graph_execution(
graph_exec, cancel, log_metadata
)
@@ -646,7 +594,7 @@ class Executor:
status=status,
stats=exec_stats,
):
cls.db_client.send_execution_update(graph_exec_result)
send_execution_update(graph_exec_result)
cls._handle_agent_run_notif(graph_exec, exec_stats)
@@ -759,7 +707,7 @@ class Executor:
status=execution_status,
stats=execution_stats,
):
cls.db_client.send_execution_update(_graph_exec)
send_execution_update(_graph_exec)
else:
logger.error(
"Callback for "
@@ -776,10 +724,10 @@ class Executor:
execution_status = ExecutionStatus.TERMINATED
return execution_stats, execution_status, error
exec_data = queue.get()
queued_node_exec = queue.get()
# Avoid parallel execution of the same node.
execution = running_executions.get(exec_data.node_id)
execution = running_executions.get(queued_node_exec.node_id)
if execution and not execution.ready():
# TODO (performance improvement):
# Wait for the completion of the same node execution is blocking.
@@ -788,18 +736,18 @@ class Executor:
execution.wait()
log_metadata.debug(
f"Dispatching node execution {exec_data.node_exec_id} "
f"for node {exec_data.node_id}",
f"Dispatching node execution {queued_node_exec.node_exec_id} "
f"for node {queued_node_exec.node_id}",
)
try:
exec_cost_counter = cls._charge_usage(
node_exec=exec_data,
node_exec=queued_node_exec,
execution_count=exec_cost_counter + 1,
execution_stats=execution_stats,
)
except InsufficientBalanceError as error:
node_exec_id = exec_data.node_exec_id
node_exec_id = queued_node_exec.node_exec_id
cls.db_client.upsert_execution_output(
node_exec_id=node_exec_id,
output_name="error",
@@ -810,7 +758,7 @@ class Executor:
exec_update = cls.db_client.update_node_execution_status(
node_exec_id, execution_status
)
cls.db_client.send_execution_update(exec_update)
send_execution_update(exec_update)
cls._handle_low_balance_notif(
graph_exec.user_id,
@@ -820,10 +768,23 @@ class Executor:
)
raise
running_executions[exec_data.node_id] = cls.executor.apply_async(
# Add credentials input overrides
node_id = queued_node_exec.node_id
if (node_creds_map := graph_exec.node_credentials_input_map) and (
node_field_creds_map := node_creds_map.get(node_id)
):
queued_node_exec.data.update(
{
field_name: creds_meta.model_dump()
for field_name, creds_meta in node_field_creds_map.items()
}
)
# Initiate node execution
running_executions[queued_node_exec.node_id] = cls.executor.apply_async(
cls.on_node_execution,
(queue, exec_data),
callback=make_exec_callback(exec_data),
(queue, queued_node_exec),
callback=make_exec_callback(queued_node_exec),
)
# Avoid terminating graph execution when some nodes are still running.
@@ -927,22 +888,43 @@ class Executor:
)
class ExecutionManager(AppService):
class ExecutionManager(AppProcess):
def __init__(self):
super().__init__()
self.pool_size = settings.config.num_graph_workers
self.queue = ExecutionQueue[GraphExecutionEntry]()
self.running = True
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
@classmethod
def get_port(cls) -> int:
return settings.config.execution_manager_port
def run_service(self):
from backend.integrations.credentials_store import IntegrationCredentialsStore
def run(self):
retry_count_max = settings.config.execution_manager_loop_max_retry
retry_count = 0
self.credentials_store = IntegrationCredentialsStore()
for retry_count in range(retry_count_max):
try:
self._run()
except Exception as e:
if not self.running:
break
logger.exception(
f"[{self.service_name}] Error in execution manager: {e}"
)
if retry_count >= retry_count_max:
logger.error(
f"[{self.service_name}] Max retries reached ({retry_count_max}), exiting..."
)
break
else:
logger.info(
f"[{self.service_name}] Retrying execution loop in {retry_count} seconds..."
)
time.sleep(retry_count)
def _run(self):
logger.info(f"[{self.service_name}] ⏳ Spawn max-{self.pool_size} workers...")
self.executor = ProcessPoolExecutor(
max_workers=self.pool_size,
@@ -952,25 +934,122 @@ class ExecutionManager(AppService):
logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...")
redis.connect()
sync_manager = multiprocessing.Manager()
while True:
graph_exec_data = self.queue.get()
graph_exec_id = graph_exec_data.graph_exec_id
logger.debug(
f"[ExecutionManager] Dispatching graph execution {graph_exec_id}"
)
cancel_event = sync_manager.Event()
future = self.executor.submit(
Executor.on_graph_execution, graph_exec_data, cancel_event
)
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
future.add_done_callback(
lambda _: self.active_graph_runs.pop(graph_exec_id, None)
# Consume Cancel & Run execution requests.
clear_thread_cache(get_execution_queue)
channel = get_execution_queue().get_channel()
channel.basic_qos(prefetch_count=self.pool_size)
channel.basic_consume(
queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
on_message_callback=self._handle_cancel_message,
auto_ack=True,
)
channel.basic_consume(
queue=GRAPH_EXECUTION_QUEUE_NAME,
on_message_callback=self._handle_run_message,
auto_ack=False,
)
logger.info(f"[{self.service_name}] Ready to consume messages...")
channel.start_consuming()
def _handle_cancel_message(
self,
channel: BlockingChannel,
method: Basic.Deliver,
properties: BasicProperties,
body: bytes,
):
"""
Called whenever we receive a CANCEL message from the queue.
(With auto_ack=True, message is considered 'acked' automatically.)
"""
try:
request = CancelExecutionEvent.model_validate_json(body)
graph_exec_id = request.graph_exec_id
if not graph_exec_id:
logger.warning(
f"[{self.service_name}] Cancel message missing 'graph_exec_id'"
)
return
if graph_exec_id not in self.active_graph_runs:
logger.debug(
f"[{self.service_name}] Cancel received for {graph_exec_id} but not active."
)
return
_, cancel_event = self.active_graph_runs[graph_exec_id]
logger.info(f"[{self.service_name}] Received cancel for {graph_exec_id}")
if not cancel_event.is_set():
cancel_event.set()
else:
logger.debug(
f"[{self.service_name}] Cancel already set for {graph_exec_id}"
)
except Exception as e:
logger.exception(f"Error handling cancel message: {e}")
def _handle_run_message(
    self,
    channel: BlockingChannel,
    method: Basic.Deliver,
    properties: BasicProperties,
    body: bytes,
):
    """
    Called whenever we receive a RUN message from the queue.

    Parses the message into a GraphExecutionEntry, submits the execution to
    the worker pool, and acks/nacks the message only once the run completes,
    so an unfinished run is redelivered if this process dies.
    """
    delivery_tag = method.delivery_tag
    try:
        graph_exec_entry = GraphExecutionEntry.model_validate_json(body)
    except Exception as e:
        logger.error(f"[{self.service_name}] Could not parse run message: {e}")
        # Malformed payload: drop it (requeue=False) so it cannot poison the queue.
        channel.basic_nack(delivery_tag, requeue=False)
        return

    graph_exec_id = graph_exec_entry.graph_exec_id
    logger.info(
        f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}"
    )
    if graph_exec_id in self.active_graph_runs:
        # Same execution already in flight here — reject the duplicate delivery.
        logger.warning(
            f"[{self.service_name}] Graph {graph_exec_id} already running; rejecting duplicate run."
        )
        channel.basic_nack(delivery_tag, requeue=False)
        return

    # NOTE(review): multiprocessing.Manager() spawns a new manager process per
    # message and is never shut down here — consider reusing a single shared
    # Manager instance; confirm before changing.
    cancel_event = multiprocessing.Manager().Event()
    future = self.executor.submit(
        Executor.on_graph_execution, graph_exec_entry, cancel_event
    )
    self.active_graph_runs[graph_exec_id] = (future, cancel_event)

    def _on_run_done(f: Future):
        """Completion hook; runs on an executor thread, not pika's I/O thread."""
        logger.info(f"[{self.service_name}] Run completed for {graph_exec_id}")
        try:
            self.active_graph_runs.pop(graph_exec_id, None)
            if f.exception():
                logger.error(
                    f"[{self.service_name}] Execution for {graph_exec_id} failed: {f.exception()}"
                )
                # pika channels are not thread-safe: marshal the nack onto the
                # connection's own I/O loop via add_callback_threadsafe.
                channel.connection.add_callback_threadsafe(
                    lambda: channel.basic_nack(delivery_tag, requeue=False)
                )
            else:
                channel.connection.add_callback_threadsafe(
                    lambda: channel.basic_ack(delivery_tag)
                )
        except Exception as e:
            logger.error(f"[{self.service_name}] Error acknowledging message: {e}")

    future.add_done_callback(_on_run_done)
def cleanup(self):
    """Gracefully stop the service: halt consuming, then drain the worker pool."""
    super().cleanup()

    logger.info(f"[{self.service_name}] ⏳ Shutting down service loop...")
    self.running = False
    logger.info(f"[{self.service_name}] ⏳ Shutting down RabbitMQ channel...")
    # Makes the blocking start_consuming() call return in the consumer thread.
    get_execution_queue().get_channel().stop_consuming()
    logger.info(f"[{self.service_name}] ⏳ Shutting down graph executor pool...")
    # cancel_futures=True discards queued-but-unstarted runs; shutdown() waits
    # for already-running futures to finish.
    self.executor.shutdown(cancel_futures=True)
@@ -981,175 +1060,6 @@ class ExecutionManager(AppService):
def db_client(self) -> "DatabaseManager":
return get_db_client()
@expose
def add_execution(
    self,
    graph_id: str,
    data: BlockInput,
    user_id: str,
    graph_version: Optional[int] = None,
    preset_id: str | None = None,
) -> GraphExecutionEntry:
    """
    Validate a graph and its inputs, persist a new graph execution, and queue
    it for running.

    Args:
        graph_id: ID of the graph to execute.
        data: Graph-level input values, keyed by input name.
        user_id: ID of the requesting user; scopes graph lookup and credentials.
        graph_version: Specific graph version to run; latest when None.
        preset_id: Optional preset this execution was started from.

    Returns:
        GraphExecutionEntry: The queued execution entry.

    Raises:
        ValueError: If the graph is not found or any input fails validation.
    """
    graph: GraphModel | None = self.db_client.get_graph(
        graph_id=graph_id, user_id=user_id, version=graph_version
    )
    if not graph:
        raise ValueError(f"Graph #{graph_id} not found.")

    graph.validate_graph(for_run=True)
    self._validate_node_input_credentials(graph, user_id)

    # Build the per-node input payload for each starting node.
    nodes_input = []
    for node in graph.starting_nodes:
        input_data = {}
        block = node.block

        # Note block should never be executed.
        if block.block_type == BlockType.NOTE:
            continue

        # Extract request input data, and assign it to the input pin.
        if block.block_type == BlockType.INPUT:
            input_name = node.input_default.get("name")
            if input_name and input_name in data:
                input_data = {"value": data[input_name]}

        # Extract webhook payload, and assign it to the input pin
        webhook_payload_key = f"webhook_{node.webhook_id}_payload"
        if (
            block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
            and node.webhook_id
        ):
            if webhook_payload_key not in data:
                raise ValueError(
                    f"Node {block.name} #{node.id} webhook payload is missing"
                )
            input_data = {"payload": data[webhook_payload_key]}

        input_data, error = validate_exec(node, input_data)
        if input_data is None:
            raise ValueError(error)
        else:
            nodes_input.append((node.id, input_data))

    if not nodes_input:
        raise ValueError(
            "No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
        )

    graph_exec = self.db_client.create_graph_execution(
        graph_id=graph_id,
        graph_version=graph.version,
        nodes_input=nodes_input,
        user_id=user_id,
        preset_id=preset_id,
    )

    # Broadcast the newly created execution to status subscribers.
    self.db_client.send_execution_update(graph_exec)

    graph_exec_entry = GraphExecutionEntry(
        user_id=user_id,
        graph_id=graph_id,
        graph_version=graph_version or 0,
        graph_exec_id=graph_exec.id,
        start_node_execs=[
            NodeExecutionEntry(
                user_id=user_id,
                graph_exec_id=node_exec.graph_exec_id,
                graph_id=node_exec.graph_id,
                node_exec_id=node_exec.node_exec_id,
                node_id=node_exec.node_id,
                block_id=node_exec.block_id,
                data=node_exec.input_data,
            )
            for node_exec in graph_exec.node_executions
        ],
    )
    self.queue.add(graph_exec_entry)
    return graph_exec_entry
@expose
def cancel_execution(self, graph_exec_id: str) -> None:
    """
    Cancel a running graph execution and mark it TERMINATED.

    Mechanism:
    1. Set the cancel event
    2. Graph executor's cancel handler thread detects the event, terminates workers,
       reinitializes worker pool, and returns.
    3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.

    Args:
        graph_exec_id: ID of the graph execution to cancel.
    """
    if graph_exec_id not in self.active_graph_runs:
        logger.warning(
            f"Graph execution #{graph_exec_id} not active/running: "
            "possibly already completed/cancelled."
        )
    else:
        future, cancel_event = self.active_graph_runs[graph_exec_id]
        if not cancel_event.is_set():
            cancel_event.set()
            # Block until the executor acknowledges the cancellation.
            future.result()

    # Update the status of the graph & node executions
    # (done even if the run was no longer active, so the DB state converges).
    self.db_client.update_graph_execution_stats(
        graph_exec_id,
        ExecutionStatus.TERMINATED,
    )
    node_execs = self.db_client.get_node_execution_results(
        graph_exec_id=graph_exec_id,
        statuses=[
            ExecutionStatus.QUEUED,
            ExecutionStatus.RUNNING,
            ExecutionStatus.INCOMPLETE,
        ],
    )
    self.db_client.update_node_execution_status_batch(
        [node_exec.node_exec_id for node_exec in node_execs],
        ExecutionStatus.TERMINATED,
    )
    for node_exec in node_execs:
        node_exec.status = ExecutionStatus.TERMINATED
        self.db_client.send_execution_update(node_exec)
def _validate_node_input_credentials(self, graph: GraphModel, user_id: str):
    """Checks all credentials for all nodes of the graph.

    Args:
        graph: The graph whose nodes' credential inputs are validated.
        user_id: Owner of the referenced credentials.

    Raises:
        ValueError: If a referenced credential is unknown or its
            type/provider does not match what the node expects.
    """
    for node in graph.nodes:
        block = node.block

        # Find any fields of type CredentialsMetaInput
        credentials_fields = cast(
            type[BlockSchema], block.input_schema
        ).get_credentials_fields()
        if not credentials_fields:
            continue

        for field_name, credentials_meta_type in credentials_fields.items():
            credentials_meta = credentials_meta_type.model_validate(
                node.input_default[field_name]
            )
            # Fetch the corresponding Credentials and perform sanity checks
            credentials = self.credentials_store.get_creds_by_id(
                user_id, credentials_meta.id
            )
            if not credentials:
                raise ValueError(
                    f"Unknown credentials #{credentials_meta.id} "
                    f"for node #{node.id} input '{field_name}'"
                )
            if (
                credentials.provider != credentials_meta.provider
                or credentials.type != credentials_meta.type
            ):
                logger.warning(
                    f"Invalid credentials #{credentials.id} for node #{node.id}: "
                    "type/provider mismatch: "
                    f"{credentials_meta.type}<>{credentials.type};"
                    f"{credentials_meta.provider}<>{credentials.provider}"
                )
                raise ValueError(
                    f"Invalid credentials #{credentials.id} for node #{node.id}: "
                    "type/provider mismatch"
                )
# ------- UTILITIES ------- #
@@ -1168,6 +1078,10 @@ def get_notification_service() -> "NotificationManager":
return get_service_client(NotificationManager)
def send_execution_update(entry: GraphExecution | NodeExecutionResult):
    """Publish a graph/node execution status update on the execution event bus."""
    return get_execution_event_bus().publish(entry)
@contextmanager
def synchronized(key: str, timeout: int = 60):
lock: RedisLock = redis.get_redis().lock(f"lock:{key}", timeout=timeout)

View File

@@ -16,7 +16,7 @@ from pydantic import BaseModel
from sqlalchemy import MetaData, create_engine
from backend.data.block import BlockInput
from backend.executor.manager import ExecutionManager
from backend.executor import utils as execution_utils
from backend.notifications.notifications import NotificationManager
from backend.util.service import AppService, expose, get_service_client
from backend.util.settings import Config
@@ -57,11 +57,6 @@ def job_listener(event):
log(f"Job {event.job_id} completed successfully.")
@thread_cached
def get_execution_client() -> ExecutionManager:
return get_service_client(ExecutionManager)
@thread_cached
def get_notification_client():
from backend.notifications import NotificationManager
@@ -73,9 +68,9 @@ def execute_graph(**kwargs):
args = ExecutionJobArgs(**kwargs)
try:
log(f"Executing recurring job for graph #{args.graph_id}")
get_execution_client().add_execution(
execution_utils.add_graph_execution(
graph_id=args.graph_id,
data=args.input_data,
inputs=args.input_data,
user_id=args.user_id,
graph_version=args.graph_version,
)
@@ -164,11 +159,6 @@ class Scheduler(AppService):
def db_pool_size(cls) -> int:
return config.scheduler_db_pool_size
@property
@thread_cached
def execution_client(self) -> ExecutionManager:
return get_service_client(ExecutionManager)
@property
@thread_cached
def notification_client(self) -> NotificationManager:
@@ -176,7 +166,7 @@ class Scheduler(AppService):
def run_service(self):
load_dotenv()
db_schema, db_url = _extract_schema_from_url(os.getenv("DATABASE_URL"))
db_schema, db_url = _extract_schema_from_url(os.getenv("DIRECT_URL"))
self.scheduler = BlockingScheduler(
jobstores={
Jobstores.EXECUTION.value: SQLAlchemyJobStore(

View File

@@ -1,11 +1,94 @@
import logging
from typing import TYPE_CHECKING, Any, Optional, cast
from autogpt_libs.utils.cache import thread_cached
from pydantic import BaseModel
from backend.data.block import Block, BlockInput
from backend.data.block import (
Block,
BlockData,
BlockInput,
BlockSchema,
BlockType,
get_block,
)
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.cost import BlockCostType
from backend.data.execution import (
AsyncRedisExecutionEventBus,
ExecutionStatus,
GraphExecutionStats,
GraphExecutionWithNodes,
RedisExecutionEventBus,
create_graph_execution,
update_graph_execution_stats,
update_node_execution_status_batch,
)
from backend.data.graph import GraphModel, Node, get_graph
from backend.data.model import CredentialsMetaInput
from backend.data.rabbitmq import (
AsyncRabbitMQ,
Exchange,
ExchangeType,
Queue,
RabbitMQConfig,
SyncRabbitMQ,
)
from backend.util.exceptions import NotFoundError
from backend.util.mock import MockObject
from backend.util.service import get_service_client
from backend.util.settings import Config
from backend.util.type import convert
if TYPE_CHECKING:
from backend.executor import DatabaseManager
from backend.integrations.credentials_store import IntegrationCredentialsStore
config = Config()
logger = logging.getLogger(__name__)
# ============ Resource Helpers ============ #
@thread_cached
def get_execution_event_bus() -> RedisExecutionEventBus:
    """Per-thread cached sync Redis event bus for execution updates."""
    return RedisExecutionEventBus()


@thread_cached
def get_async_execution_event_bus() -> AsyncRedisExecutionEventBus:
    """Per-thread cached async Redis event bus for execution updates."""
    return AsyncRedisExecutionEventBus()


@thread_cached
def get_execution_queue() -> SyncRabbitMQ:
    """Per-thread cached sync RabbitMQ client, connected before first use."""
    client = SyncRabbitMQ(create_execution_queue_config())
    client.connect()
    return client
@thread_cached
async def get_async_execution_queue() -> AsyncRabbitMQ:
    """Per-thread cached async RabbitMQ client, connected before first use.

    NOTE(review): @thread_cached over an async def — confirm the decorator
    caches the awaited result rather than the coroutine object; a cached
    coroutine would fail on the second await.
    """
    client = AsyncRabbitMQ(create_execution_queue_config())
    await client.connect()
    return client


@thread_cached
def get_integration_credentials_store() -> "IntegrationCredentialsStore":
    """Per-thread cached credentials store; imported lazily, presumably to avoid a circular import."""
    from backend.integrations.credentials_store import IntegrationCredentialsStore

    return IntegrationCredentialsStore()


@thread_cached
def get_db_client() -> "DatabaseManager":
    """Per-thread cached DatabaseManager service client; imported lazily, presumably to avoid a circular import."""
    from backend.executor import DatabaseManager

    return get_service_client(DatabaseManager)
# ============ Execution Cost Helpers ============ #
class UsageTransactionMetadata(BaseModel):
@@ -95,3 +178,571 @@ def _is_cost_filter_match(cost_filter: BlockInput, input_data: BlockInput) -> bo
or (input_data.get(k) and _is_cost_filter_match(v, input_data[k]))
for k, v in cost_filter.items()
)
# ============ Execution Input Helpers ============ #
# Separators used to address parts of an output pin's value by name suffix.
LIST_SPLIT = "_$_"
DICT_SPLIT = "_#_"
OBJC_SPLIT = "_@_"


def parse_execution_output(output: BlockData, name: str) -> Any | None:
    """
    Extracts partial output data by name from a given BlockData.

    The function supports extracting data from lists, dictionaries, and objects
    using specific naming conventions:
    - For lists: <output_name>_$_<index>
    - For dictionaries: <output_name>_#_<key>
    - For objects: <output_name>_@_<attribute>

    Args:
        output (BlockData): A tuple containing the output name and data.
        name (str): The name used to extract specific data from the output.

    Returns:
        Any | None: The extracted data if found, otherwise None.
        A malformed (non-integer) list index now yields None instead of
        raising ValueError.

    Examples:
        >>> output = ("result", [10, 20, 30])
        >>> parse_execution_output(output, "result_$_1")
        20

        >>> output = ("config", {"key1": "value1", "key2": "value2"})
        >>> parse_execution_output(output, "config_#_key1")
        'value1'

        >>> class Sample:
        ...     attr1 = "value1"
        ...     attr2 = "value2"
        >>> output = ("object", Sample())
        >>> parse_execution_output(output, "object_@_attr1")
        'value1'
    """
    output_name, output_data = output

    # Exact match: return the whole output payload.
    if name == output_name:
        return output_data

    # List access: <output_name>_$_<index>
    if name.startswith(f"{output_name}{LIST_SPLIT}"):
        try:
            # Parse once and reuse (the original re-split and re-parsed the
            # index when indexing the list).
            index = int(name.split(LIST_SPLIT)[1])
        except ValueError:
            return None  # non-integer index: treat as "not found", don't crash
        if not isinstance(output_data, list) or len(output_data) <= index:
            return None
        return output_data[index]

    # Dict access: <output_name>_#_<key>
    if name.startswith(f"{output_name}{DICT_SPLIT}"):
        key = name.split(DICT_SPLIT)[1]
        if not isinstance(output_data, dict) or key not in output_data:
            return None
        return output_data[key]

    # Attribute access: <output_name>_@_<attribute>
    # (The original also checked isinstance(output_data, object), which is
    # always True in Python; hasattr alone decides.)
    if name.startswith(f"{output_name}{OBJC_SPLIT}"):
        attr = name.split(OBJC_SPLIT)[1]
        if hasattr(output_data, attr):
            return getattr(output_data, attr)
        return None

    return None
def validate_exec(
    node: Node,
    data: BlockInput,
    resolve_input: bool = True,
) -> tuple[BlockInput | None, str]:
    """
    Validate the input data for a node execution.

    Args:
        node: The node to execute.
        data: The input data for the node execution.
        resolve_input: Whether to resolve dynamic pins into dict/list/object.

    Returns:
        A tuple of the validated data and the block name.
        If the data is invalid, the first element will be None, and the second element
        will be an error message.
        If the data is valid, the first element will be the resolved input data, and
        the second element will be the block name.
    """
    node_block: Block | None = get_block(node.block_id)
    if not node_block:
        return None, f"Block for {node.block_id} not found."
    schema = node_block.input_schema

    # Convert non-matching data types to the expected input schema.
    # NOTE(review): the walrus truthiness check skips falsy values (0, "",
    # False), so those are never converted — confirm this is intended.
    for name, data_type in schema.__annotations__.items():
        if (value := data.get(name)) and (type(value) is not data_type):
            data[name] = convert(value, data_type)

    # Input data (without default values) should contain all required fields.
    error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
    if missing_links := schema.get_missing_links(data, node.input_links):
        return None, f"{error_prefix} unpopulated links {missing_links}"

    # Merge input data with default values and resolve dynamic dict/list/object pins.
    # Explicit input wins over defaults.
    input_default = schema.get_input_defaults(node.input_default)
    data = {**input_default, **data}
    if resolve_input:
        data = merge_execution_input(data)

    # Input data post-merge should contain all required fields from the schema.
    if missing_input := schema.get_missing_input(data):
        return None, f"{error_prefix} missing input {missing_input}"

    # Last validation: Validate the input values against the schema.
    if error := schema.get_mismatch_error(data):
        error_message = f"{error_prefix} {error}"
        logger.error(error_message)
        return None, error_message

    return data, node_block.name
def merge_execution_input(data: BlockInput) -> BlockInput:
    """
    Merges dynamic input pins into a single list, dictionary, or object based on naming patterns.

    This function processes input keys that follow specific patterns to merge them into a unified structure:
    - `<input_name>_$_<index>` for list inputs.
    - `<input_name>_#_<index>` for dictionary inputs.
    - `<input_name>_@_<index>` for object inputs.

    Args:
        data (BlockInput): A dictionary containing input keys and their corresponding values.

    Returns:
        BlockInput: A dictionary with merged inputs.

    Raises:
        ValueError: If a list index is not an integer.

    Examples:
        >>> data = {
        ...     "list_$_0": "a",
        ...     "list_$_1": "b",
        ...     "dict_#_key1": "value1",
        ...     "dict_#_key2": "value2",
        ...     "object_@_attr1": "value1",
        ...     "object_@_attr2": "value2"
        ... }
        >>> merge_execution_input(data)
        {
            "list": ["a", "b"],
            "dict": {"key1": "value1", "key2": "value2"},
            "object": <MockObject attr1="value1" attr2="value2">
        }
    """
    # Snapshot the entries first: the loops below insert merged containers into
    # `data`, and we must not iterate over those newly added entries.
    items = list(data.items())

    # Merge all input with <input_name>_$_<index> into a single list.
    for key, value in items:
        if LIST_SPLIT not in key:
            continue
        name, index = key.split(LIST_SPLIT)
        # isdigit() also rejects negative indices, not just non-numeric ones.
        if not index.isdigit():
            raise ValueError(f"Invalid key: {key}, #{index} index must be an integer.")

        data[name] = data.get(name, [])
        if int(index) >= len(data[name]):
            # Pad list with empty string on missing indices.
            data[name].extend([""] * (int(index) - len(data[name]) + 1))
        data[name][int(index)] = value

    # Merge all input with <input_name>_#_<index> into a single dict.
    for key, value in items:
        if DICT_SPLIT not in key:
            continue
        name, index = key.split(DICT_SPLIT)
        data[name] = data.get(name, {})
        data[name][index] = value

    # Merge all input with <input_name>_@_<index> into a single object.
    for key, value in items:
        if OBJC_SPLIT not in key:
            continue
        name, index = key.split(OBJC_SPLIT)
        # NOTE(review): isinstance(..., object) is always True, so this
        # condition reduces to `name not in data`; a pre-existing non-MockObject
        # value (e.g. a dict) would be setattr'd and likely fail — confirm.
        if name not in data or not isinstance(data[name], object):
            data[name] = MockObject()
        setattr(data[name], index, value)

    return data
def _validate_node_input_credentials(
    graph: GraphModel,
    user_id: str,
    node_credentials_input_map: Optional[
        dict[str, dict[str, CredentialsMetaInput]]
    ] = None,
):
    """Checks all credentials for all nodes of the graph.

    Args:
        graph: The graph whose nodes' credential inputs are validated.
        user_id: Owner of the referenced credentials.
        node_credentials_input_map: Optional per-node overrides
            (`dict[node_id, dict[field_name, CredentialsMetaInput]]`);
            takes precedence over each node's stored input defaults.

    Raises:
        ValueError: If credentials are absent, unknown, or mismatched in
            type/provider.
    """
    for node in graph.nodes:
        block = node.block

        # Find any fields of type CredentialsMetaInput
        credentials_fields = cast(
            type[BlockSchema], block.input_schema
        ).get_credentials_fields()
        if not credentials_fields:
            continue

        for field_name, credentials_meta_type in credentials_fields.items():
            # Prefer the execution-time override; fall back to the node default.
            if (
                node_credentials_input_map
                and (node_credentials_inputs := node_credentials_input_map.get(node.id))
                and field_name in node_credentials_inputs
            ):
                credentials_meta = node_credentials_input_map[node.id][field_name]
            elif field_name in node.input_default:
                credentials_meta = credentials_meta_type.model_validate(
                    node.input_default[field_name]
                )
            else:
                raise ValueError(
                    f"Credentials absent for {block.name} node #{node.id} "
                    f"input '{field_name}'"
                )

            # Fetch the corresponding Credentials and perform sanity checks
            credentials = get_integration_credentials_store().get_creds_by_id(
                user_id, credentials_meta.id
            )
            if not credentials:
                raise ValueError(
                    f"Unknown credentials #{credentials_meta.id} "
                    f"for node #{node.id} input '{field_name}'"
                )
            if (
                credentials.provider != credentials_meta.provider
                or credentials.type != credentials_meta.type
            ):
                logger.warning(
                    f"Invalid credentials #{credentials.id} for node #{node.id}: "
                    "type/provider mismatch: "
                    f"{credentials_meta.type}<>{credentials.type};"
                    f"{credentials_meta.provider}<>{credentials.provider}"
                )
                raise ValueError(
                    f"Invalid credentials #{credentials.id} for node #{node.id}: "
                    "type/provider mismatch"
                )
def make_node_credentials_input_map(
    graph: GraphModel,
    graph_credentials_input: dict[str, CredentialsMetaInput],
) -> dict[str, dict[str, CredentialsMetaInput]]:
    """
    Maps credentials for an execution to the correct nodes.

    Params:
        graph: The graph to be executed.
        graph_credentials_input: A (graph_input_name, credentials_meta) map.

    Returns:
        dict[node_id, dict[field_name, CredentialsMetaInput]]: Node credentials input map.
    """
    node_credentials_map: dict[str, dict[str, CredentialsMetaInput]] = {}

    # Walk the graph's aggregated credentials inputs and fan each provided
    # credential out to every node field it is compatible with.
    aggregated = graph.aggregate_credentials_inputs()
    for input_name, (_, compatible_node_fields) in aggregated.items():
        # Best-effort map: inputs without a provided credential are skipped.
        if input_name not in graph_credentials_input:
            continue
        credentials_meta = graph_credentials_input[input_name]
        for node_id, field_name in compatible_node_fields:
            node_credentials_map.setdefault(node_id, {})[field_name] = credentials_meta

    return node_credentials_map
def construct_node_execution_input(
    graph: GraphModel,
    user_id: str,
    graph_inputs: BlockInput,
    node_credentials_input_map: Optional[
        dict[str, dict[str, CredentialsMetaInput]]
    ] = None,
) -> list[tuple[str, BlockInput]]:
    """
    Validates and prepares the input data for executing a graph.

    This function checks the graph for starting nodes, validates the input data
    against the schema, and resolves dynamic input pins into a single list,
    dictionary, or object.

    Args:
        graph (GraphModel): The graph model to execute.
        user_id (str): The ID of the user executing the graph.
        graph_inputs (BlockInput): The graph-level input data for the execution.
        node_credentials_input_map: `dict[node_id, dict[input_name, CredentialsMetaInput]]`

    Returns:
        list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
        the corresponding input data for that node.

    Raises:
        ValueError: If any input fails validation or no starting nodes exist.
    """
    graph.validate_graph(for_run=True)
    _validate_node_input_credentials(graph, user_id, node_credentials_input_map)

    nodes_input = []
    for node in graph.starting_nodes:
        input_data = {}
        block = node.block

        # Note block should never be executed.
        if block.block_type == BlockType.NOTE:
            continue

        # Extract request input data, and assign it to the input pin.
        if block.block_type == BlockType.INPUT:
            input_name = node.input_default.get("name")
            if input_name and input_name in graph_inputs:
                input_data = {"value": graph_inputs[input_name]}

        # Extract webhook payload, and assign it to the input pin
        webhook_payload_key = f"webhook_{node.webhook_id}_payload"
        if (
            block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
            and node.webhook_id
        ):
            if webhook_payload_key not in graph_inputs:
                raise ValueError(
                    f"Node {block.name} #{node.id} webhook payload is missing"
                )
            input_data = {"payload": graph_inputs[webhook_payload_key]}

        # Apply node credentials overrides
        if node_credentials_input_map and (
            node_credentials := node_credentials_input_map.get(node.id)
        ):
            input_data.update({k: v.model_dump() for k, v in node_credentials.items()})

        input_data, error = validate_exec(node, input_data)
        if input_data is None:
            raise ValueError(error)
        else:
            nodes_input.append((node.id, input_data))

    if not nodes_input:
        raise ValueError(
            "No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
        )

    return nodes_input
# ============ Execution Queue Helpers ============ #
class CancelExecutionEvent(BaseModel):
    """Message payload published on the cancel exchange to stop a graph run."""

    graph_exec_id: str


# Direct exchange: run requests are routed to the single shared run queue.
GRAPH_EXECUTION_EXCHANGE = Exchange(
    name="graph_execution",
    type=ExchangeType.DIRECT,
    durable=True,
    auto_delete=False,
)
GRAPH_EXECUTION_QUEUE_NAME = "graph_execution_queue"
GRAPH_EXECUTION_ROUTING_KEY = "graph_execution.run"

# Fanout exchange: cancel requests are broadcast to every bound consumer.
GRAPH_EXECUTION_CANCEL_EXCHANGE = Exchange(
    name="graph_execution_cancel",
    type=ExchangeType.FANOUT,
    durable=True,
    auto_delete=True,
)
GRAPH_EXECUTION_CANCEL_QUEUE_NAME = "graph_execution_cancel_queue"
def create_execution_queue_config() -> RabbitMQConfig:
    """
    Define two exchanges and queues:
    - 'graph_execution' (DIRECT) for run tasks.
    - 'graph_execution_cancel' (FANOUT) for cancel requests.

    Returns:
        RabbitMQConfig: Declarative config consumed by the RabbitMQ clients.
    """
    # Durable, non-auto-delete so queued run requests survive broker restarts.
    run_queue = Queue(
        name=GRAPH_EXECUTION_QUEUE_NAME,
        exchange=GRAPH_EXECUTION_EXCHANGE,
        routing_key=GRAPH_EXECUTION_ROUTING_KEY,
        durable=True,
        auto_delete=False,
    )
    cancel_queue = Queue(
        name=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
        exchange=GRAPH_EXECUTION_CANCEL_EXCHANGE,
        routing_key="",  # not used for FANOUT
        durable=True,
        auto_delete=False,
    )
    return RabbitMQConfig(
        vhost="/",
        exchanges=[GRAPH_EXECUTION_EXCHANGE, GRAPH_EXECUTION_CANCEL_EXCHANGE],
        queues=[run_queue, cancel_queue],
    )
async def add_graph_execution_async(
    graph_id: str,
    user_id: str,
    inputs: BlockInput,
    preset_id: Optional[str] = None,
    graph_version: Optional[int] = None,
    graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
) -> GraphExecutionWithNodes:
    """
    Adds a graph execution to the queue and returns the execution entry.

    Args:
        graph_id: The ID of the graph to execute.
        user_id: The ID of the user executing the graph.
        inputs: The input data for the graph execution.
        preset_id: The ID of the preset to use.
        graph_version: The version of the graph to execute.
        graph_credentials_inputs: Credentials inputs to use in the execution.
            Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.

    Returns:
        GraphExecutionEntry: The entry for the graph execution.

    Raises:
        NotFoundError: If the graph is not found.
        ValueError: If there are validation errors.
    """  # noqa
    graph: GraphModel | None = await get_graph(
        graph_id=graph_id, user_id=user_id, version=graph_version
    )
    if not graph:
        raise NotFoundError(f"Graph #{graph_id} not found.")

    # Resolve graph-level credentials inputs to per-node field overrides.
    node_credentials_input_map = (
        make_node_credentials_input_map(graph, graph_credentials_inputs)
        if graph_credentials_inputs
        else None
    )

    graph_exec = await create_graph_execution(
        user_id=user_id,
        graph_id=graph_id,
        graph_version=graph.version,
        starting_nodes_input=construct_node_execution_input(
            graph=graph,
            user_id=user_id,
            graph_inputs=inputs,
            node_credentials_input_map=node_credentials_input_map,
        ),
        preset_id=preset_id,
    )

    # The execution now exists in the DB; if queueing/publishing fails below,
    # mark it FAILED so it does not dangle, then re-raise.
    try:
        queue = await get_async_execution_queue()
        graph_exec_entry = graph_exec.to_graph_execution_entry()
        if node_credentials_input_map:
            graph_exec_entry.node_credentials_input_map = node_credentials_input_map
        await queue.publish_message(
            routing_key=GRAPH_EXECUTION_ROUTING_KEY,
            message=graph_exec_entry.model_dump_json(),
            exchange=GRAPH_EXECUTION_EXCHANGE,
        )

        # Notify status subscribers about the new execution.
        bus = get_async_execution_event_bus()
        await bus.publish(graph_exec)

        return graph_exec
    except Exception as e:
        logger.error(f"Unable to publish graph #{graph_id} exec #{graph_exec.id}: {e}")
        await update_node_execution_status_batch(
            [node_exec.node_exec_id for node_exec in graph_exec.node_executions],
            ExecutionStatus.FAILED,
        )
        await update_graph_execution_stats(
            graph_exec_id=graph_exec.id,
            status=ExecutionStatus.FAILED,
            stats=GraphExecutionStats(error=str(e)),
        )
        raise
def add_graph_execution(
    graph_id: str,
    user_id: str,
    inputs: BlockInput,
    preset_id: Optional[str] = None,
    graph_version: Optional[int] = None,
    graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
) -> GraphExecutionWithNodes:
    """
    Adds a graph execution to the queue and returns the execution entry.
    Sync counterpart of `add_graph_execution_async`, backed by the
    DatabaseManager service client instead of direct async DB calls.

    Args:
        graph_id: The ID of the graph to execute.
        user_id: The ID of the user executing the graph.
        inputs: The input data for the graph execution.
        preset_id: The ID of the preset to use.
        graph_version: The version of the graph to execute.
        graph_credentials_inputs: Credentials inputs to use in the execution.
            Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.

    Returns:
        GraphExecutionEntry: The entry for the graph execution.

    Raises:
        NotFoundError: If the graph is not found.
        ValueError: If there are validation errors.
    """
    db = get_db_client()
    graph: GraphModel | None = db.get_graph(
        graph_id=graph_id, user_id=user_id, version=graph_version
    )
    if not graph:
        raise NotFoundError(f"Graph #{graph_id} not found.")

    # Resolve graph-level credentials inputs to per-node field overrides.
    node_credentials_input_map = (
        make_node_credentials_input_map(graph, graph_credentials_inputs)
        if graph_credentials_inputs
        else None
    )

    graph_exec = db.create_graph_execution(
        user_id=user_id,
        graph_id=graph_id,
        graph_version=graph.version,
        starting_nodes_input=construct_node_execution_input(
            graph=graph,
            user_id=user_id,
            graph_inputs=inputs,
            node_credentials_input_map=node_credentials_input_map,
        ),
        preset_id=preset_id,
    )

    # The execution now exists in the DB; if queueing/publishing fails below,
    # mark it FAILED so it does not dangle, then re-raise.
    try:
        queue = get_execution_queue()
        graph_exec_entry = graph_exec.to_graph_execution_entry()
        if node_credentials_input_map:
            graph_exec_entry.node_credentials_input_map = node_credentials_input_map
        queue.publish_message(
            routing_key=GRAPH_EXECUTION_ROUTING_KEY,
            message=graph_exec_entry.model_dump_json(),
            exchange=GRAPH_EXECUTION_EXCHANGE,
        )

        # Notify status subscribers about the new execution.
        bus = get_execution_event_bus()
        bus.publish(graph_exec)

        return graph_exec
    except Exception as e:
        logger.error(f"Unable to publish graph #{graph_id} exec #{graph_exec.id}: {e}")
        db.update_node_execution_status_batch(
            [node_exec.node_exec_id for node_exec in graph_exec.node_executions],
            ExecutionStatus.FAILED,
        )
        db.update_graph_execution_stats(
            graph_exec_id=graph_exec.id,
            status=ExecutionStatus.FAILED,
            stats=GraphExecutionStats(error=str(e)),
        )
        raise

View File

@@ -2,7 +2,6 @@ import logging
from collections import defaultdict
from typing import Annotated, Any, Dict, List, Optional, Sequence
from autogpt_libs.utils.cache import thread_cached
from fastapi import APIRouter, Body, Depends, HTTPException
from prisma.enums import AgentExecutionStatus, APIKeyPermission
from typing_extensions import TypedDict
@@ -13,17 +12,10 @@ from backend.data import graph as graph_db
from backend.data.api_key import APIKey
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.execution import NodeExecutionResult
from backend.executor import ExecutionManager
from backend.executor.utils import add_graph_execution_async
from backend.server.external.middleware import require_permission
from backend.util.service import get_service_client
from backend.util.settings import Settings
@thread_cached
def execution_manager_client() -> ExecutionManager:
return get_service_client(ExecutionManager)
settings = Settings()
logger = logging.getLogger(__name__)
@@ -98,20 +90,20 @@ def execute_graph_block(
path="/graphs/{graph_id}/execute/{graph_version}",
tags=["graphs"],
)
def execute_graph(
async def execute_graph(
graph_id: str,
graph_version: int,
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
api_key: APIKey = Depends(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
) -> dict[str, Any]:
try:
graph_exec = execution_manager_client().add_execution(
graph_id,
graph_version=graph_version,
data=node_input,
graph_exec = await add_graph_execution_async(
graph_id=graph_id,
user_id=api_key.user_id,
inputs=node_input,
graph_version=graph_version,
)
return {"id": graph_exec.graph_exec_id}
return {"id": graph_exec.id}
except Exception as e:
msg = str(e).encode().decode("unicode_escape")
raise HTTPException(status_code=400, detail=msg)

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
from typing import TYPE_CHECKING, Annotated, Literal
from typing import TYPE_CHECKING, Annotated, Awaitable, Literal
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
from pydantic import BaseModel, Field
@@ -14,13 +15,12 @@ from backend.data.integrations import (
wait_for_webhook_event,
)
from backend.data.model import Credentials, CredentialsType, OAuth2Credentials
from backend.executor.manager import ExecutionManager
from backend.executor.utils import add_graph_execution_async
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
from backend.util.exceptions import NeedConfirmation, NotFoundError
from backend.util.service import get_service_client
from backend.util.settings import Settings
if TYPE_CHECKING:
@@ -309,19 +309,22 @@ async def webhook_ingress_generic(
if not webhook.attached_nodes:
return
executor = get_service_client(ExecutionManager)
executions: list[Awaitable] = []
for node in webhook.attached_nodes:
logger.debug(f"Webhook-attached node: {node}")
if not node.is_triggered_by_event_type(event_type):
logger.debug(f"Node #{node.id} doesn't trigger on event {event_type}")
continue
logger.debug(f"Executing graph #{node.graph_id} node #{node.id}")
executor.add_execution(
graph_id=node.graph_id,
graph_version=node.graph_version,
data={f"webhook_{webhook_id}_payload": payload},
user_id=webhook.user_id,
executions.append(
add_graph_execution_async(
user_id=webhook.user_id,
graph_id=node.graph_id,
graph_version=node.graph_version,
inputs={f"webhook_{webhook_id}_payload": payload},
)
)
asyncio.gather(*executions)
@router.post("/webhooks/{webhook_id}/ping")

View File

@@ -17,7 +17,6 @@ import backend.data.block
import backend.data.db
import backend.data.graph
import backend.data.user
import backend.server.integrations.router
import backend.server.routers.postmark.postmark
import backend.server.routers.v1
import backend.server.v2.admin.store_admin_routes
@@ -29,6 +28,7 @@ import backend.server.v2.store.model
import backend.server.v2.store.routes
import backend.util.service
import backend.util.settings
from backend.blocks.llm import LlmModel
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.server.external.api import external_app
@@ -57,8 +57,7 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()
# FIXME ERROR: operator does not exist: text ? unknown
# await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
with launch_darkly_context():
yield
await backend.data.db.disconnect()
@@ -156,11 +155,12 @@ class AgentServer(backend.util.service.AppProcess):
graph_version: Optional[int] = None,
node_input: Optional[dict[str, Any]] = None,
):
return backend.server.routers.v1.execute_graph(
return await backend.server.routers.v1.execute_graph(
user_id=user_id,
graph_id=graph_id,
graph_version=graph_version,
node_input=node_input or {},
inputs=node_input or {},
credentials_inputs={},
)
@staticmethod
@@ -275,7 +275,9 @@ class AgentServer(backend.util.service.AppProcess):
provider: ProviderName,
credentials: Credentials,
) -> Credentials:
return backend.server.integrations.router.create_credentials(
from backend.server.integrations.router import create_credentials
return create_credentials(
user_id=user_id, provider=provider, credentials=credentials
)

View File

@@ -13,7 +13,6 @@ from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
from typing_extensions import Optional, TypedDict
import backend.data.block
import backend.server.integrations.router
import backend.server.routers.analytics
import backend.server.v2.library.db as library_db
@@ -31,7 +30,7 @@ from backend.data.api_key import (
suspend_api_key,
update_api_key_permissions,
)
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
from backend.data.credit import (
AutoTopUpConfig,
RefundRequest,
@@ -41,6 +40,8 @@ from backend.data.credit import (
get_user_credit_model,
set_auto_top_up,
)
from backend.data.execution import AsyncRedisExecutionEventBus
from backend.data.model import CredentialsMetaInput
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.data.onboarding import (
UserOnboardingUpdate,
@@ -49,13 +50,16 @@ from backend.data.onboarding import (
onboarding_enabled,
update_user_onboarding,
)
from backend.data.rabbitmq import AsyncRabbitMQ
from backend.data.user import (
get_or_create_user,
get_user_notification_preference,
update_user_email,
update_user_notification_preference,
)
from backend.executor import ExecutionManager, Scheduler, scheduler
from backend.executor import Scheduler, scheduler
from backend.executor import utils as execution_utils
from backend.executor.utils import create_execution_queue_config
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import (
on_graph_activate,
@@ -79,13 +83,20 @@ if TYPE_CHECKING:
@thread_cached
def execution_manager_client() -> ExecutionManager:
return get_service_client(ExecutionManager)
def execution_scheduler_client() -> Scheduler:
return get_service_client(Scheduler)
@thread_cached
def execution_scheduler_client() -> Scheduler:
return get_service_client(Scheduler)
async def execution_queue_client() -> AsyncRabbitMQ:
client = AsyncRabbitMQ(create_execution_queue_config())
await client.connect()
return client
@thread_cached
def execution_event_bus() -> AsyncRedisExecutionEventBus:
return AsyncRedisExecutionEventBus()
settings = Settings()
@@ -206,7 +217,7 @@ async def is_onboarding_enabled():
@v1_router.get(path="/blocks", tags=["blocks"], dependencies=[Depends(auth_middleware)])
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
blocks = [block() for block in backend.data.block.get_blocks().values()]
blocks = [block() for block in get_blocks().values()]
costs = get_block_costs()
return [
{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks if not b.disabled
@@ -219,7 +230,7 @@ def get_graph_blocks() -> Sequence[dict[Any, Any]]:
dependencies=[Depends(auth_middleware)],
)
def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlockOutput:
obj = backend.data.block.get_block(block_id)
obj = get_block(block_id)
if not obj:
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
@@ -308,7 +319,7 @@ async def configure_user_auto_top_up(
dependencies=[Depends(auth_middleware)],
)
async def get_user_auto_top_up(
user_id: Annotated[str, Depends(get_user_id)]
user_id: Annotated[str, Depends(get_user_id)],
) -> AutoTopUpConfig:
return await get_auto_top_up(user_id)
@@ -375,7 +386,7 @@ async def get_credit_history(
@v1_router.get(path="/credits/refunds", dependencies=[Depends(auth_middleware)])
async def get_refund_requests(
user_id: Annotated[str, Depends(get_user_id)]
user_id: Annotated[str, Depends(get_user_id)],
) -> list[RefundRequest]:
return await _user_credit_model.get_refund_requests(user_id)
@@ -391,7 +402,7 @@ class DeleteGraphResponse(TypedDict):
@v1_router.get(path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)])
async def get_graphs(
user_id: Annotated[str, Depends(get_user_id)]
user_id: Annotated[str, Depends(get_user_id)],
) -> Sequence[graph_db.GraphModel]:
return await graph_db.get_graphs(filter_by="active", user_id=user_id)
@@ -580,16 +591,25 @@ async def set_graph_active_version(
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
def execute_graph(
async def execute_graph(
graph_id: str,
node_input: Annotated[dict[str, Any], Body(..., default_factory=dict)],
user_id: Annotated[str, Depends(get_user_id)],
inputs: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
credentials_inputs: Annotated[
dict[str, CredentialsMetaInput], Body(..., embed=True, default_factory=dict)
],
graph_version: Optional[int] = None,
preset_id: Optional[str] = None,
) -> ExecuteGraphResponse:
graph_exec = execution_manager_client().add_execution(
graph_id, node_input, user_id=user_id, graph_version=graph_version
graph_exec = await execution_utils.add_graph_execution_async(
graph_id=graph_id,
user_id=user_id,
inputs=inputs,
preset_id=preset_id,
graph_version=graph_version,
graph_credentials_inputs=credentials_inputs,
)
return ExecuteGraphResponse(graph_exec_id=graph_exec.graph_exec_id)
return ExecuteGraphResponse(graph_exec_id=graph_exec.id)
@v1_router.post(
@@ -605,9 +625,7 @@ async def stop_graph_run(
):
raise HTTPException(404, detail=f"Agent execution #{graph_exec_id} not found")
await asyncio.to_thread(
lambda: execution_manager_client().cancel_execution(graph_exec_id)
)
await _cancel_execution(graph_exec_id)
# Retrieve & return canceled graph execution in its final state
result = await execution_db.get_graph_execution(
@@ -621,6 +639,49 @@ async def stop_graph_run(
return result
async def _cancel_execution(graph_exec_id: str):
"""
Mechanism:
1. Set the cancel event
2. Graph executor's cancel handler thread detects the event, terminates workers,
reinitializes worker pool, and returns.
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
"""
queue_client = await execution_queue_client()
await queue_client.publish_message(
routing_key="",
message=execution_utils.CancelExecutionEvent(
graph_exec_id=graph_exec_id
).model_dump_json(),
exchange=execution_utils.GRAPH_EXECUTION_CANCEL_EXCHANGE,
)
# Update the status of the graph & node executions
await execution_db.update_graph_execution_stats(
graph_exec_id,
execution_db.ExecutionStatus.TERMINATED,
)
node_execs = [
node_exec.model_copy(update={"status": execution_db.ExecutionStatus.TERMINATED})
for node_exec in await execution_db.get_node_execution_results(
graph_exec_id=graph_exec_id,
statuses=[
execution_db.ExecutionStatus.QUEUED,
execution_db.ExecutionStatus.RUNNING,
execution_db.ExecutionStatus.INCOMPLETE,
],
)
]
await execution_db.update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in node_execs],
execution_db.ExecutionStatus.TERMINATED,
)
await asyncio.gather(
*[execution_event_bus().publish(node_exec) for node_exec in node_execs]
)
@v1_router.get(
path="/executions",
tags=["graphs"],
@@ -792,7 +853,7 @@ async def create_api_key(
dependencies=[Depends(auth_middleware)],
)
async def get_api_keys(
user_id: Annotated[str, Depends(get_user_id)]
user_id: Annotated[str, Depends(get_user_id)],
) -> list[APIKeyWithoutHash]:
"""List all API keys for the user"""
try:

View File

@@ -6,7 +6,6 @@ import prisma.errors
import prisma.fields
import prisma.models
import prisma.types
from prisma.types import AgentPresetCreateInput
import backend.data.graph
import backend.server.model
@@ -69,12 +68,12 @@ async def list_library_agents(
if search_term:
where_clause["OR"] = [
{
"Agent": {
"AgentGraph": {
"is": {"name": {"contains": search_term, "mode": "insensitive"}}
}
},
{
"Agent": {
"AgentGraph": {
"is": {
"description": {"contains": search_term, "mode": "insensitive"}
}
@@ -233,7 +232,8 @@ async def create_library_agent(
isCreatedByUser=(user_id == graph.user_id),
useGraphIsActiveVersion=True,
User={"connect": {"id": user_id}},
Agent={
# Creator={"connect": {"id": agent.userId}},
AgentGraph={
"connect": {
"graphVersionId": {"id": graph.id, "version": graph.version}
}
@@ -247,38 +247,41 @@ async def create_library_agent(
async def update_agent_version_in_library(
user_id: str,
agent_id: str,
agent_version: int,
agent_graph_id: str,
agent_graph_version: int,
) -> None:
"""
Updates the agent version in the library if useGraphIsActiveVersion is True.
Args:
user_id: Owner of the LibraryAgent.
agent_id: The agent's ID to update.
agent_version: The new version of the agent.
agent_graph_id: The agent graph's ID to update.
agent_graph_version: The new version of the agent graph.
Raises:
DatabaseError: If there's an error with the update.
"""
logger.debug(
f"Updating agent version in library for user #{user_id}, "
f"agent #{agent_id} v{agent_version}"
f"agent #{agent_graph_id} v{agent_graph_version}"
)
try:
library_agent = await prisma.models.LibraryAgent.prisma().find_first_or_raise(
where={
"userId": user_id,
"agentId": agent_id,
"agentGraphId": agent_graph_id,
"useGraphIsActiveVersion": True,
},
)
await prisma.models.LibraryAgent.prisma().update(
where={"id": library_agent.id},
data={
"Agent": {
"AgentGraph": {
"connect": {
"graphVersionId": {"id": agent_id, "version": agent_version}
"graphVersionId": {
"id": agent_graph_id,
"version": agent_graph_version,
}
},
},
},
@@ -342,7 +345,7 @@ async def delete_library_agent_by_graph_id(graph_id: str, user_id: str) -> None:
"""
try:
await prisma.models.LibraryAgent.prisma().delete_many(
where={"agentId": graph_id, "userId": user_id}
where={"agentGraphId": graph_id, "userId": user_id}
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error deleting library agent: {e}")
@@ -375,10 +378,10 @@ async def add_store_agent_to_library(
async with locked_transaction(f"add_agent_trx_{user_id}"):
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id}, include={"Agent": True}
where={"id": store_listing_version_id}, include={"AgentGraph": True}
)
)
if not store_listing_version or not store_listing_version.Agent:
if not store_listing_version or not store_listing_version.AgentGraph:
logger.warning(
f"Store listing version not found: {store_listing_version_id}"
)
@@ -386,7 +389,7 @@ async def add_store_agent_to_library(
f"Store listing version {store_listing_version_id} not found or invalid"
)
graph = store_listing_version.Agent
graph = store_listing_version.AgentGraph
if graph.userId == user_id:
logger.warning(
f"User #{user_id} attempted to add their own agent to their library"
@@ -398,8 +401,8 @@ async def add_store_agent_to_library(
await prisma.models.LibraryAgent.prisma().find_first(
where={
"userId": user_id,
"agentId": graph.id,
"agentVersion": graph.version,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
},
include=library_agent_include(user_id),
)
@@ -421,15 +424,15 @@ async def add_store_agent_to_library(
added_agent = await prisma.models.LibraryAgent.prisma().create(
data=prisma.types.LibraryAgentCreateInput(
userId=user_id,
agentId=graph.id,
agentVersion=graph.version,
agentGraphId=graph.id,
agentGraphVersion=graph.version,
isCreatedByUser=False,
),
include=library_agent_include(user_id),
)
logger.debug(
f"Added graph #{graph.id} "
f"for store listing #{store_listing_version.id} "
f"Added graph #{graph.id} v{graph.version}"
f"for store listing version #{store_listing_version.id} "
f"to library for user #{user_id}"
)
return library_model.LibraryAgent.from_db(added_agent)
@@ -468,8 +471,8 @@ async def set_is_deleted_for_library_agent(
count = await prisma.models.LibraryAgent.prisma().update_many(
where={
"userId": user_id,
"agentId": agent_id,
"agentVersion": agent_version,
"agentGraphId": agent_id,
"agentGraphVersion": agent_version,
},
data={"isDeleted": is_deleted},
)
@@ -598,21 +601,22 @@ async def upsert_preset(
f"Upserting preset #{preset_id} ({repr(preset.name)}) for user #{user_id}",
)
try:
inputs = [
prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput(
name=name, data=prisma.fields.Json(data)
)
for name, data in preset.inputs.items()
]
if preset_id:
# Update existing preset
updated = await prisma.models.AgentPreset.prisma().update(
where={"id": preset_id},
data=AgentPresetCreateInput(
name=preset.name,
description=preset.description,
isActive=preset.is_active,
InputPresets={
"create": [
{"name": name, "data": prisma.fields.Json(data)}
for name, data in preset.inputs.items()
]
},
),
data={
"name": preset.name,
"description": preset.description,
"isActive": preset.is_active,
"InputPresets": {"create": inputs},
},
include={"InputPresets": True},
)
if not updated:
@@ -625,15 +629,10 @@ async def upsert_preset(
userId=user_id,
name=preset.name,
description=preset.description,
agentId=preset.agent_id,
agentVersion=preset.agent_version,
agentGraphId=preset.graph_id,
agentGraphVersion=preset.graph_version,
isActive=preset.is_active,
InputPresets={
"create": [
{"name": name, "data": prisma.fields.Json(data)}
for name, data in preset.inputs.items()
]
},
InputPresets={"create": inputs},
),
include={"InputPresets": True},
)

View File

@@ -30,8 +30,8 @@ async def test_get_library_agents(mocker):
prisma.models.LibraryAgent(
id="ua1",
userId="test-user",
agentId="agent2",
agentVersion=1,
agentGraphId="agent2",
agentGraphVersion=1,
isCreatedByUser=False,
isDeleted=False,
isArchived=False,
@@ -39,7 +39,7 @@ async def test_get_library_agents(mocker):
updatedAt=datetime.now(),
isFavorite=False,
useGraphIsActiveVersion=True,
Agent=prisma.models.AgentGraph(
AgentGraph=prisma.models.AgentGraph(
id="agent2",
version=1,
name="Test Agent 2",
@@ -71,8 +71,8 @@ async def test_get_library_agents(mocker):
assert result.agents[0].id == "ua1"
assert result.agents[0].name == "Test Agent 2"
assert result.agents[0].description == "Test Description 2"
assert result.agents[0].agent_id == "agent2"
assert result.agents[0].agent_version == 1
assert result.agents[0].graph_id == "agent2"
assert result.agents[0].graph_version == 1
assert result.agents[0].can_access_graph is False
assert result.agents[0].is_latest_version is True
assert result.pagination.total_items == 1
@@ -90,8 +90,8 @@ async def test_add_agent_to_library(mocker):
version=1,
createdAt=datetime.now(),
updatedAt=datetime.now(),
agentId="agent1",
agentVersion=1,
agentGraphId="agent1",
agentGraphVersion=1,
name="Test Agent",
subHeading="Test Agent Subheading",
imageUrls=["https://example.com/image.jpg"],
@@ -102,7 +102,7 @@ async def test_add_agent_to_library(mocker):
isAvailable=True,
storeListingId="listing123",
submissionStatus=prisma.enums.SubmissionStatus.APPROVED,
Agent=prisma.models.AgentGraph(
AgentGraph=prisma.models.AgentGraph(
id="agent1",
version=1,
name="Test Agent",
@@ -116,8 +116,8 @@ async def test_add_agent_to_library(mocker):
mock_library_agent_data = prisma.models.LibraryAgent(
id="ua1",
userId="test-user",
agentId=mock_store_listing_data.agentId,
agentVersion=1,
agentGraphId=mock_store_listing_data.agentGraphId,
agentGraphVersion=1,
isCreatedByUser=False,
isDeleted=False,
isArchived=False,
@@ -125,7 +125,7 @@ async def test_add_agent_to_library(mocker):
updatedAt=datetime.now(),
isFavorite=False,
useGraphIsActiveVersion=True,
Agent=mock_store_listing_data.Agent,
AgentGraph=mock_store_listing_data.AgentGraph,
)
# Mock prisma calls
@@ -147,19 +147,22 @@ async def test_add_agent_to_library(mocker):
# Verify mocks called correctly
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
where={"id": "version123"}, include={"Agent": True}
where={"id": "version123"}, include={"AgentGraph": True}
)
mock_library_agent.return_value.find_first.assert_called_once_with(
where={
"userId": "test-user",
"agentId": "agent1",
"agentVersion": 1,
"agentGraphId": "agent1",
"agentGraphVersion": 1,
},
include=library_agent_include("test-user"),
)
mock_library_agent.return_value.create.assert_called_once_with(
data=prisma.types.LibraryAgentCreateInput(
userId="test-user", agentId="agent1", agentVersion=1, isCreatedByUser=False
userId="test-user",
agentGraphId="agent1",
agentGraphVersion=1,
isCreatedByUser=False,
),
include=library_agent_include("test-user"),
)
@@ -182,5 +185,5 @@ async def test_add_agent_to_library_not_found(mocker):
# Verify mock called correctly
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
where={"id": "version123"}, include={"Agent": True}
where={"id": "version123"}, include={"AgentGraph": True}
)

View File

@@ -25,8 +25,8 @@ class LibraryAgent(pydantic.BaseModel):
"""
id: str
agent_id: str
agent_version: int
graph_id: str
graph_version: int
image_url: str | None
@@ -58,12 +58,12 @@ class LibraryAgent(pydantic.BaseModel):
Factory method that constructs a LibraryAgent from a Prisma LibraryAgent
model instance.
"""
if not agent.Agent:
if not agent.AgentGraph:
raise ValueError("Associated Agent record is required.")
graph = graph_model.GraphModel.from_db(agent.Agent)
graph = graph_model.GraphModel.from_db(agent.AgentGraph)
agent_updated_at = agent.Agent.updatedAt
agent_updated_at = agent.AgentGraph.updatedAt
lib_agent_updated_at = agent.updatedAt
# Compute updated_at as the latest between library agent and graph
@@ -83,21 +83,21 @@ class LibraryAgent(pydantic.BaseModel):
week_ago = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
days=7
)
executions = agent.Agent.AgentGraphExecution or []
executions = agent.AgentGraph.Executions or []
status_result = _calculate_agent_status(executions, week_ago)
status = status_result.status
new_output = status_result.new_output
# Check if user can access the graph
can_access_graph = agent.Agent.userId == agent.userId
can_access_graph = agent.AgentGraph.userId == agent.userId
# Hard-coded to True until a method to check is implemented
is_latest_version = True
return LibraryAgent(
id=agent.id,
agent_id=agent.agentId,
agent_version=agent.agentVersion,
graph_id=agent.agentGraphId,
graph_version=agent.agentGraphVersion,
image_url=agent.imageUrl,
creator_name=creator_name,
creator_image_url=creator_image_url,
@@ -174,8 +174,8 @@ class LibraryAgentPreset(pydantic.BaseModel):
id: str
updated_at: datetime.datetime
agent_id: str
agent_version: int
graph_id: str
graph_version: int
name: str
description: str
@@ -194,8 +194,8 @@ class LibraryAgentPreset(pydantic.BaseModel):
return cls(
id=preset.id,
updated_at=preset.updatedAt,
agent_id=preset.agentId,
agent_version=preset.agentVersion,
graph_id=preset.agentGraphId,
graph_version=preset.agentGraphVersion,
name=preset.name,
description=preset.description,
is_active=preset.isActive,
@@ -218,8 +218,8 @@ class CreateLibraryAgentPresetRequest(pydantic.BaseModel):
name: str
description: str
inputs: block_model.BlockInput
agent_id: str
agent_version: int
graph_id: str
graph_version: int
is_active: bool

View File

@@ -5,7 +5,6 @@ import prisma.models
import pytest
import backend.server.v2.library.model as library_model
from backend.util import json
@pytest.mark.asyncio
@@ -15,8 +14,8 @@ async def test_agent_preset_from_db():
id="test-agent-123",
createdAt=datetime.datetime.now(),
updatedAt=datetime.datetime.now(),
agentId="agent-123",
agentVersion=1,
agentGraphId="agent-123",
agentGraphVersion=1,
name="Test Agent",
description="Test agent description",
isActive=True,
@@ -27,7 +26,7 @@ async def test_agent_preset_from_db():
id="input-123",
time=datetime.datetime.now(),
name="input1",
data=json.dumps({"type": "string", "value": "test value"}), # type: ignore
data=prisma.Json({"type": "string", "value": "test value"}),
)
],
)
@@ -36,7 +35,7 @@ async def test_agent_preset_from_db():
agent = library_model.LibraryAgentPreset.from_db(db_agent)
assert agent.id == "test-agent-123"
assert agent.agent_version == 1
assert agent.graph_version == 1
assert agent.is_active is True
assert agent.name == "Test Agent"
assert agent.description == "Test agent description"

View File

@@ -2,25 +2,17 @@ import logging
from typing import Annotated, Any
import autogpt_libs.auth as autogpt_auth_lib
import autogpt_libs.utils.cache
from fastapi import APIRouter, Body, Depends, HTTPException, status
import backend.executor
import backend.server.v2.library.db as db
import backend.server.v2.library.model as models
import backend.util.service
from backend.executor.utils import add_graph_execution_async
logger = logging.getLogger(__name__)
router = APIRouter()
@autogpt_libs.utils.cache.thread_cached
def execution_manager_client() -> backend.executor.ExecutionManager:
"""Return a cached instance of ExecutionManager client."""
return backend.util.service.get_service_client(backend.executor.ExecutionManager)
@router.get(
"/presets",
summary="List presets",
@@ -226,17 +218,17 @@ async def execute_preset(
# Merge input overrides with preset inputs
merged_node_input = preset.inputs | node_input
execution = execution_manager_client().add_execution(
execution = await add_graph_execution_async(
graph_id=graph_id,
graph_version=graph_version,
data=merged_node_input,
user_id=user_id,
inputs=merged_node_input,
preset_id=preset_id,
graph_version=graph_version,
)
logger.debug(f"Execution added: {execution} with input: {merged_node_input}")
return {"id": execution.graph_exec_id}
return {"id": execution.id}
except HTTPException:
raise
except Exception as e:

View File

@@ -35,8 +35,8 @@ async def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
agents=[
library_model.LibraryAgent(
id="test-agent-1",
agent_id="test-agent-1",
agent_version=1,
graph_id="test-agent-1",
graph_version=1,
name="Test Agent 1",
description="Test Description 1",
image_url=None,
@@ -51,8 +51,8 @@ async def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
),
library_model.LibraryAgent(
id="test-agent-2",
agent_id="test-agent-2",
agent_version=1,
graph_id="test-agent-2",
graph_version=1,
name="Test Agent 2",
description="Test Description 2",
image_url=None,
@@ -78,9 +78,9 @@ async def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
data = library_model.LibraryAgentResponse.model_validate(response.json())
assert len(data.agents) == 2
assert data.agents[0].agent_id == "test-agent-1"
assert data.agents[0].graph_id == "test-agent-1"
assert data.agents[0].can_access_graph is True
assert data.agents[1].agent_id == "test-agent-2"
assert data.agents[1].graph_id == "test-agent-2"
assert data.agents[1].can_access_graph is False
mock_db_call.assert_called_once_with(
user_id="test-user-id",

View File

@@ -200,17 +200,17 @@ async def get_available_graph(
"isAvailable": True,
"isDeleted": False,
},
include={"Agent": {"include": {"AgentNodes": True}}},
include={"AgentGraph": {"include": {"Nodes": True}}},
)
)
if not store_listing_version or not store_listing_version.Agent:
if not store_listing_version or not store_listing_version.AgentGraph:
raise fastapi.HTTPException(
status_code=404,
detail=f"Store listing version {store_listing_version_id} not found",
)
graph = GraphModel.from_db(store_listing_version.Agent)
graph = GraphModel.from_db(store_listing_version.AgentGraph)
# We return graph meta, without nodes, they cannot be just removed
# because then input_schema would be empty
return {
@@ -516,7 +516,7 @@ async def delete_store_submission(
try:
# Verify the submission belongs to this user
submission = await prisma.models.StoreListing.prisma().find_first(
where={"agentId": submission_id, "owningUserId": user_id}
where={"agentGraphId": submission_id, "owningUserId": user_id}
)
if not submission:
@@ -598,7 +598,7 @@ async def create_store_submission(
# Check if listing already exists for this agent
existing_listing = await prisma.models.StoreListing.prisma().find_first(
where=prisma.types.StoreListingWhereInput(
agentId=agent_id, owningUserId=user_id
agentGraphId=agent_id, owningUserId=user_id
)
)
@@ -625,15 +625,15 @@ async def create_store_submission(
# If no existing listing, create a new one
data = prisma.types.StoreListingCreateInput(
slug=slug,
agentId=agent_id,
agentVersion=agent_version,
agentGraphId=agent_id,
agentGraphVersion=agent_version,
owningUserId=user_id,
createdAt=datetime.now(tz=timezone.utc),
Versions={
"create": [
prisma.types.StoreListingVersionCreateInput(
agentId=agent_id,
agentVersion=agent_version,
agentGraphId=agent_id,
agentGraphVersion=agent_version,
name=name,
videoUrl=video_url,
imageUrls=image_urls,
@@ -758,8 +758,8 @@ async def create_store_version(
new_version = await prisma.models.StoreListingVersion.prisma().create(
data=prisma.types.StoreListingVersionCreateInput(
version=next_version,
agentId=agent_id,
agentVersion=agent_version,
agentGraphId=agent_id,
agentGraphVersion=agent_version,
name=name,
videoUrl=video_url,
imageUrls=image_urls,
@@ -959,17 +959,17 @@ async def get_my_agents(
try:
search_filter: prisma.types.LibraryAgentWhereInput = {
"userId": user_id,
"Agent": {"is": {"StoreListing": {"none": {"isDeleted": False}}}},
"AgentGraph": {"is": {"StoreListings": {"none": {"isDeleted": False}}}},
"isArchived": False,
"isDeleted": False,
}
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
where=search_filter,
order=[{"agentVersion": "desc"}],
order=[{"agentGraphVersion": "desc"}],
skip=(page - 1) * page_size,
take=page_size,
include={"Agent": True},
include={"AgentGraph": True},
)
total = await prisma.models.LibraryAgent.prisma().count(where=search_filter)
@@ -985,7 +985,7 @@ async def get_my_agents(
agent_image=library_agent.imageUrl,
)
for library_agent in library_agents
if (graph := library_agent.Agent)
if (graph := library_agent.AgentGraph)
]
return backend.server.v2.store.model.MyAgentsResponse(
@@ -1020,13 +1020,13 @@ async def get_agent(
graph = await backend.data.graph.get_graph(
user_id=user_id,
graph_id=store_listing_version.agentId,
version=store_listing_version.agentVersion,
graph_id=store_listing_version.agentGraphId,
version=store_listing_version.agentGraphVersion,
for_export=True,
)
if not graph:
raise ValueError(
f"Agent {store_listing_version.agentId} v{store_listing_version.agentVersion} not found"
f"Agent {store_listing_version.agentGraphId} v{store_listing_version.agentGraphVersion} not found"
)
return graph
@@ -1050,11 +1050,14 @@ async def _get_missing_sub_store_listing(
# Fetch all the sub-graphs that are listed, and return the ones missing.
store_listed_sub_graphs = {
(listing.agentId, listing.agentVersion)
(listing.agentGraphId, listing.agentGraphVersion)
for listing in await prisma.models.StoreListingVersion.prisma().find_many(
where={
"OR": [
{"agentId": sub_graph.id, "agentVersion": sub_graph.version}
{
"agentGraphId": sub_graph.id,
"agentGraphVersion": sub_graph.version,
}
for sub_graph in sub_graphs
],
"submissionStatus": prisma.enums.SubmissionStatus.APPROVED,
@@ -1084,7 +1087,7 @@ async def review_store_submission(
where={"id": store_listing_version_id},
include={
"StoreListing": True,
"Agent": {"include": AGENT_GRAPH_INCLUDE}, # type: ignore
"AgentGraph": {"include": AGENT_GRAPH_INCLUDE},
},
)
)
@@ -1096,23 +1099,23 @@ async def review_store_submission(
)
# If approving, update the listing to indicate it has an approved version
if is_approved and store_listing_version.Agent:
heading = f"Sub-graph of {store_listing_version.name}v{store_listing_version.agentVersion}"
if is_approved and store_listing_version.AgentGraph:
heading = f"Sub-graph of {store_listing_version.name}v{store_listing_version.agentGraphVersion}"
sub_store_listing_versions = [
prisma.types.StoreListingVersionCreateWithoutRelationsInput(
agentId=sub_graph.id,
agentVersion=sub_graph.version,
agentGraphId=sub_graph.id,
agentGraphVersion=sub_graph.version,
name=sub_graph.name or heading,
submissionStatus=prisma.enums.SubmissionStatus.APPROVED,
subHeading=heading,
description=f"{heading}: {sub_graph.description}",
changesSummary=f"This listing is added as a {heading} / #{store_listing_version.agentId}.",
changesSummary=f"This listing is added as a {heading} / #{store_listing_version.agentGraphId}.",
isAvailable=False, # Hide sub-graphs from the store by default.
submittedAt=datetime.now(tz=timezone.utc),
)
for sub_graph in await _get_missing_sub_store_listing(
store_listing_version.Agent
store_listing_version.AgentGraph
)
]
@@ -1155,8 +1158,8 @@ async def review_store_submission(
# Convert to Pydantic model for consistency
return backend.server.v2.store.model.StoreSubmission(
agent_id=submission.agentId,
agent_version=submission.agentVersion,
agent_id=submission.agentGraphId,
agent_version=submission.agentGraphVersion,
name=submission.name,
sub_heading=submission.subHeading,
slug=(
@@ -1294,8 +1297,8 @@ async def get_admin_listings_with_versions(
# If we have versions, turn them into StoreSubmission models
for version in listing.Versions or []:
version_model = backend.server.v2.store.model.StoreSubmission(
agent_id=version.agentId,
agent_version=version.agentVersion,
agent_id=version.agentGraphId,
agent_version=version.agentGraphVersion,
name=version.name,
sub_heading=version.subHeading,
slug=listing.slug,
@@ -1324,8 +1327,8 @@ async def get_admin_listings_with_versions(
backend.server.v2.store.model.StoreListingWithVersions(
listing_id=listing.id,
slug=listing.slug,
agent_id=listing.agentId,
agent_version=listing.agentVersion,
agent_id=listing.agentGraphId,
agent_version=listing.agentGraphVersion,
active_version_id=listing.activeVersionId,
has_approved_version=listing.hasApprovedVersion,
creator_email=creator_email,

View File

@@ -170,14 +170,14 @@ async def test_create_store_submission(mocker):
isDeleted=False,
hasApprovedVersion=False,
slug="test-agent",
agentId="agent-id",
agentVersion=1,
agentGraphId="agent-id",
agentGraphVersion=1,
owningUserId="user-id",
Versions=[
prisma.models.StoreListingVersion(
id="version-id",
agentId="agent-id",
agentVersion=1,
agentGraphId="agent-id",
agentGraphVersion=1,
name="Test Agent",
description="Test description",
createdAt=datetime.now(),

View File

@@ -15,21 +15,25 @@ def to_dict(data) -> dict:
def dumps(data) -> str:
return json.dumps(jsonable_encoder(data))
return json.dumps(to_dict(data))
T = TypeVar("T")
@overload
def loads(data: str, *args, target_type: Type[T], **kwargs) -> T: ...
def loads(data: str | bytes, *args, target_type: Type[T], **kwargs) -> T: ...
@overload
def loads(data: str, *args, **kwargs) -> Any: ...
def loads(data: str | bytes, *args, **kwargs) -> Any: ...
def loads(data: str, *args, target_type: Type[T] | None = None, **kwargs) -> Any:
def loads(
data: str | bytes, *args, target_type: Type[T] | None = None, **kwargs
) -> Any:
if isinstance(data, bytes):
data = data.decode("utf-8")
parsed = json.loads(data, *args, **kwargs)
if target_type:
return type_match(parsed, target_type)

View File

@@ -14,9 +14,7 @@ def sentry_init():
traces_sample_rate=1.0,
profiles_sample_rate=1.0,
environment=f"app:{Settings().config.app_env.value}-behave:{Settings().config.behave_as.value}",
_experiments={
"enable_logs": True,
},
_experiments={"enable_logs": True},
integrations=[
LoggingIntegration(sentry_logs_level=logging.INFO),
AnthropicIntegration(

View File

@@ -25,6 +25,7 @@ from pydantic import BaseModel, TypeAdapter, create_model
from backend.util.exceptions import InsufficientBalanceError
from backend.util.json import to_dict
from backend.util.metrics import sentry_init
from backend.util.process import AppProcess, get_service_name
from backend.util.retry import conn_retry
from backend.util.settings import Config
@@ -196,6 +197,7 @@ class AppService(BaseAppService, ABC):
self.shared_event_loop.run_until_complete(server.serve())
def run(self):
sentry_init()
super().run()
self.fastapi_app = FastAPI()

View File

@@ -137,6 +137,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default=8002,
description="The port for execution manager daemon to run on",
)
execution_manager_loop_max_retry: int = Field(
default=5,
description="The maximum number of retries for the execution manager loop",
)
execution_scheduler_port: int = Field(
default=8003,

View File

@@ -182,6 +182,7 @@ def _try_convert(value: Any, target_type: Type, raise_on_mismatch: bool) -> Any:
T = TypeVar("T")
TT = TypeVar("TT")
def type_match(value: Any, target_type: Type[T]) -> T:

View File

@@ -0,0 +1,50 @@
/*
Warnings:
- The relation LibraryAgent:AgentPreset was REMOVED
- A unique constraint covering the columns `[userId,agentGraphId,agentGraphVersion]` on the table `LibraryAgent` will be added. If there are existing duplicate values, this will fail.
- The foreign key constraints on AgentPreset and LibraryAgent are being changed from CASCADE to RESTRICT for AgentGraph deletion, which means you cannot delete AgentGraphs that have associated LibraryAgents or AgentPresets.
Use the following query to check whether these conditions are satisfied:
-- Check for duplicate LibraryAgent userId + agentGraphId + agentGraphVersion combinations that would violate the new unique constraint
SELECT la."userId",
la."agentId" as graph_id,
la."agentVersion" as graph_version,
COUNT(*) as multiplicity
FROM "LibraryAgent" la
GROUP BY la."userId",
la."agentId",
la."agentVersion"
HAVING COUNT(*) > 1;
*/
-- Drop foreign key constraints on columns we're about to rename
ALTER TABLE "AgentPreset" DROP CONSTRAINT "AgentPreset_agentId_agentVersion_fkey";
ALTER TABLE "LibraryAgent" DROP CONSTRAINT "LibraryAgent_agentId_agentVersion_fkey";
ALTER TABLE "LibraryAgent" DROP CONSTRAINT "LibraryAgent_agentPresetId_fkey";
-- Rename columns in AgentPreset
ALTER TABLE "AgentPreset" RENAME COLUMN "agentId" TO "agentGraphId";
ALTER TABLE "AgentPreset" RENAME COLUMN "agentVersion" TO "agentGraphVersion";
-- Rename columns in LibraryAgent
ALTER TABLE "LibraryAgent" RENAME COLUMN "agentId" TO "agentGraphId";
ALTER TABLE "LibraryAgent" RENAME COLUMN "agentVersion" TO "agentGraphVersion";
-- Drop LibraryAgent.agentPresetId column
ALTER TABLE "LibraryAgent" DROP COLUMN "agentPresetId";
-- Replace userId index with unique index on userId + agentGraphId + agentGraphVersion
DROP INDEX "LibraryAgent_userId_idx";
CREATE UNIQUE INDEX "LibraryAgent_userId_agentGraphId_agentGraphVersion_key" ON "LibraryAgent"("userId", "agentGraphId", "agentGraphVersion");
-- Re-add the foreign key constraints with new column names
ALTER TABLE "LibraryAgent" ADD CONSTRAINT "LibraryAgent_agentGraphId_agentGraphVersion_fkey"
FOREIGN KEY ("agentGraphId", "agentGraphVersion") REFERENCES "AgentGraph"("id", "version")
ON DELETE RESTRICT -- Disallow deleting AgentGraph when still referenced by existing LibraryAgents
ON UPDATE CASCADE;
ALTER TABLE "AgentPreset" ADD CONSTRAINT "AgentPreset_agentGraphId_agentGraphVersion_fkey"
FOREIGN KEY ("agentGraphId", "agentGraphVersion") REFERENCES "AgentGraph"("id", "version")
ON DELETE RESTRICT -- Disallow deleting AgentGraph when still referenced by existing AgentPresets
ON UPDATE CASCADE;

View File

@@ -0,0 +1,35 @@
/*
- Rename column StoreListing.agentId to agentGraphId
- Rename column StoreListing.agentVersion to agentGraphVersion
- Rename column StoreListingVersion.agentId to agentGraphId
- Rename column StoreListingVersion.agentVersion to agentGraphVersion
*/
-- Drop foreign key constraints on columns we're about to rename
ALTER TABLE "StoreListing" DROP CONSTRAINT "StoreListing_agentId_agentVersion_fkey";
ALTER TABLE "StoreListingVersion" DROP CONSTRAINT "StoreListingVersion_agentId_agentVersion_fkey";
-- Drop indices on columns we're about to rename
DROP INDEX "StoreListing_agentId_key";
DROP INDEX "StoreListingVersion_agentId_agentVersion_idx";
-- Rename columns
ALTER TABLE "StoreListing" RENAME COLUMN "agentId" TO "agentGraphId";
ALTER TABLE "StoreListing" RENAME COLUMN "agentVersion" TO "agentGraphVersion";
ALTER TABLE "StoreListingVersion" RENAME COLUMN "agentId" TO "agentGraphId";
ALTER TABLE "StoreListingVersion" RENAME COLUMN "agentVersion" TO "agentGraphVersion";
-- Re-create indices with updated name on renamed columns
CREATE UNIQUE INDEX "StoreListing_agentGraphId_key" ON "StoreListing"("agentGraphId");
CREATE INDEX "StoreListingVersion_agentGraphId_agentGraphVersion_idx" ON "StoreListingVersion"("agentGraphId", "agentGraphVersion");
-- Re-create foreign key constraints with updated name on renamed columns
ALTER TABLE "StoreListing" ADD CONSTRAINT "StoreListing_agentGraphId_agentGraphVersion_fkey"
FOREIGN KEY ("agentGraphId", "agentGraphVersion") REFERENCES "AgentGraph"("id", "version")
ON DELETE CASCADE
ON UPDATE CASCADE;
ALTER TABLE "StoreListingVersion" ADD CONSTRAINT "StoreListingVersion_agentGraphId_agentGraphVersion_fkey"
FOREIGN KEY ("agentGraphId", "agentGraphVersion") REFERENCES "AgentGraph"("id", "version")
ON DELETE RESTRICT
ON UPDATE CASCADE;

View File

@@ -0,0 +1,16 @@
-- Modify the OnboardingStep enum
ALTER TYPE "OnboardingStep" ADD VALUE 'GET_RESULTS';
ALTER TYPE "OnboardingStep" ADD VALUE 'MARKETPLACE_VISIT';
ALTER TYPE "OnboardingStep" ADD VALUE 'MARKETPLACE_ADD_AGENT';
ALTER TYPE "OnboardingStep" ADD VALUE 'MARKETPLACE_RUN_AGENT';
ALTER TYPE "OnboardingStep" ADD VALUE 'BUILDER_OPEN';
ALTER TYPE "OnboardingStep" ADD VALUE 'BUILDER_SAVE_AGENT';
ALTER TYPE "OnboardingStep" ADD VALUE 'BUILDER_RUN_AGENT';
-- Modify the UserOnboarding table
ALTER TABLE "UserOnboarding"
ADD COLUMN "updatedAt" TIMESTAMP(3),
ADD COLUMN "notificationDot" BOOLEAN NOT NULL DEFAULT true,
ADD COLUMN "notified" "OnboardingStep"[] DEFAULT '{}',
ADD COLUMN "rewardedFor" "OnboardingStep"[] DEFAULT '{}',
ADD COLUMN "onboardingAgentExecutionId" TEXT

View File

@@ -0,0 +1,56 @@
-- Backfill nulls with empty arrays
UPDATE "UserOnboarding"
SET "integrations" = ARRAY[]::TEXT[]
WHERE "integrations" IS NULL;
UPDATE "UserOnboarding"
SET "completedSteps" = '{}'
WHERE "completedSteps" IS NULL;
UPDATE "UserOnboarding"
SET "notified" = '{}'
WHERE "notified" IS NULL;
UPDATE "UserOnboarding"
SET "rewardedFor" = '{}'
WHERE "rewardedFor" IS NULL;
UPDATE "IntegrationWebhook"
SET "events" = ARRAY[]::TEXT[]
WHERE "events" IS NULL;
UPDATE "APIKey"
SET "permissions" = '{}'
WHERE "permissions" IS NULL;
UPDATE "Profile"
SET "links" = ARRAY[]::TEXT[]
WHERE "links" IS NULL;
UPDATE "StoreListingVersion"
SET "imageUrls" = ARRAY[]::TEXT[]
WHERE "imageUrls" IS NULL;
UPDATE "StoreListingVersion"
SET "categories" = ARRAY[]::TEXT[]
WHERE "categories" IS NULL;
-- Enforce NOT NULL constraints
ALTER TABLE "UserOnboarding"
ALTER COLUMN "integrations" SET NOT NULL,
ALTER COLUMN "completedSteps" SET NOT NULL,
ALTER COLUMN "notified" SET NOT NULL,
ALTER COLUMN "rewardedFor" SET NOT NULL;
ALTER TABLE "IntegrationWebhook"
ALTER COLUMN "events" SET NOT NULL;
ALTER TABLE "APIKey"
ALTER COLUMN "permissions" SET NOT NULL;
ALTER TABLE "Profile"
ALTER COLUMN "links" SET NOT NULL;
ALTER TABLE "StoreListingVersion"
ALTER COLUMN "imageUrls" SET NOT NULL,
ALTER COLUMN "categories" SET NOT NULL;

View File

@@ -1,11 +1,12 @@
datasource db {
provider = "postgresql"
url = env("DATABASE_URL")
provider = "postgresql"
url = env("DATABASE_URL")
directUrl = env("DIRECT_URL")
}
generator client {
provider = "prisma-client-py"
recursive_type_depth = 5
recursive_type_depth = -1
interface = "asyncio"
previewFeatures = ["views"]
}
@@ -39,25 +40,26 @@ model User {
AgentGraphExecutions AgentGraphExecution[]
AnalyticsDetails AnalyticsDetails[]
AnalyticsMetrics AnalyticsMetrics[]
CreditTransaction CreditTransaction[]
CreditTransactions CreditTransaction[]
AgentPreset AgentPreset[]
LibraryAgent LibraryAgent[]
AgentPresets AgentPreset[]
LibraryAgents LibraryAgent[]
Profile Profile[]
UserOnboarding UserOnboarding?
StoreListing StoreListing[]
StoreListingReview StoreListingReview[]
StoreListings StoreListing[]
StoreListingReviews StoreListingReview[]
StoreVersionsReviewed StoreListingVersion[]
APIKeys APIKey[]
IntegrationWebhooks IntegrationWebhook[]
UserNotificationBatch UserNotificationBatch[]
NotificationBatches UserNotificationBatch[]
@@index([id])
@@index([email])
}
enum OnboardingStep {
// Introductory onboarding (Library)
WELCOME
USAGE_REASON
INTEGRATIONS
@@ -65,18 +67,32 @@ enum OnboardingStep {
AGENT_NEW_RUN
AGENT_INPUT
CONGRATS
GET_RESULTS
// Marketplace
MARKETPLACE_VISIT
MARKETPLACE_ADD_AGENT
MARKETPLACE_RUN_AGENT
// Builder
BUILDER_OPEN
BUILDER_SAVE_AGENT
BUILDER_RUN_AGENT
}
model UserOnboarding {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime? @updatedAt
completedSteps OnboardingStep[] @default([])
notificationDot Boolean @default(true)
notified OnboardingStep[] @default([])
rewardedFor OnboardingStep[] @default([])
usageReason String?
integrations String[] @default([])
otherIntegrations String?
selectedStoreListingVersionId String?
agentInput Json?
onboardingAgentExecutionId String?
userId String @unique
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@ -102,13 +118,13 @@ model AgentGraph {
// This allows us to delete user data with deleting the agent which maybe in use by other users
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
AgentNodes AgentNode[]
AgentGraphExecution AgentGraphExecution[]
Nodes AgentNode[]
Executions AgentGraphExecution[]
AgentPreset AgentPreset[]
LibraryAgent LibraryAgent[]
StoreListing StoreListing[]
StoreListingVersion StoreListingVersion[]
Presets AgentPreset[]
LibraryAgents LibraryAgent[]
StoreListings StoreListing[]
StoreListingVersions StoreListingVersion[]
@@id(name: "graphVersionId", [id, version])
@@index([userId, isActive])
@@ -139,13 +155,12 @@ model AgentPreset {
userId String
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
agentId String
agentVersion Int
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)
agentGraphId String
agentGraphVersion Int
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version], onDelete: Restrict)
InputPresets AgentNodeExecutionInputOutput[] @relation("AgentPresetsInputData")
LibraryAgents LibraryAgent[]
AgentExecution AgentGraphExecution[]
InputPresets AgentNodeExecutionInputOutput[] @relation("AgentPresetsInputData")
Executions AgentGraphExecution[]
isDeleted Boolean @default(false)
@@ -194,7 +209,7 @@ model UserNotificationBatch {
}
// For the library page
// It is a user controlled list of agents, that they will see in there library
// It is a user controlled list of agents, that they will see in their library
model LibraryAgent {
id String @id @default(uuid())
createdAt DateTime @default(now())
@@ -205,12 +220,9 @@ model LibraryAgent {
imageUrl String?
agentId String
agentVersion Int
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version])
agentPresetId String?
AgentPreset AgentPreset? @relation(fields: [agentPresetId], references: [id])
agentGraphId String
agentGraphVersion Int
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version], onDelete: Restrict)
creatorId String?
Creator Profile? @relation(fields: [creatorId], references: [id])
@@ -222,7 +234,7 @@ model LibraryAgent {
isArchived Boolean @default(false)
isDeleted Boolean @default(false)
@@index([userId])
@@unique([userId, agentGraphId, agentGraphVersion])
}
////////////////////////////////////////////////////////////
@@ -256,7 +268,7 @@ model AgentNode {
metadata Json @default("{}")
ExecutionHistory AgentNodeExecution[]
Executions AgentNodeExecution[]
@@index([agentGraphId, agentGraphVersion])
@@index([agentBlockId])
@@ -323,15 +335,15 @@ model AgentGraphExecution {
agentGraphVersion Int @default(1)
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version], onDelete: Cascade)
AgentNodeExecutions AgentNodeExecution[]
NodeExecutions AgentNodeExecution[]
// Link to User model -- Executed by this user
userId String
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
stats Json?
AgentPreset AgentPreset? @relation(fields: [agentPresetId], references: [id])
agentPresetId String?
AgentPreset AgentPreset? @relation(fields: [agentPresetId], references: [id])
@@index([agentGraphId, agentGraphVersion])
@@index([userId])
@@ -342,10 +354,10 @@ model AgentNodeExecution {
id String @id @default(uuid())
agentGraphExecutionId String
AgentGraphExecution AgentGraphExecution @relation(fields: [agentGraphExecutionId], references: [id], onDelete: Cascade)
GraphExecution AgentGraphExecution @relation(fields: [agentGraphExecutionId], references: [id], onDelete: Cascade)
agentNodeId String
AgentNode AgentNode @relation(fields: [agentNodeId], references: [id], onDelete: Cascade)
Node AgentNode @relation(fields: [agentNodeId], references: [id], onDelete: Cascade)
Input AgentNodeExecutionInputOutput[] @relation("AgentNodeExecutionInput")
Output AgentNodeExecutionInputOutput[] @relation("AgentNodeExecutionOutput")
@@ -627,9 +639,9 @@ model StoreListing {
ActiveVersion StoreListingVersion? @relation("ActiveVersion", fields: [activeVersionId], references: [id])
// The agent link here is only so we can do lookup on agentId
agentId String
agentVersion Int
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)
agentGraphId String
agentGraphVersion Int
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version], onDelete: Cascade)
owningUserId String
OwningUser User @relation(fields: [owningUserId], references: [id])
@@ -638,7 +650,7 @@ model StoreListing {
Versions StoreListingVersion[] @relation("ListingVersions")
// Unique index on agentId to ensure only one listing per agent, regardless of number of versions the agent has.
@@unique([agentId])
@@unique([agentGraphId])
@@unique([owningUserId, slug])
// Used in the view query
@@index([isDeleted, hasApprovedVersion])
@@ -651,9 +663,9 @@ model StoreListingVersion {
updatedAt DateTime @default(now()) @updatedAt
// The agent and version to be listed on the store
agentId String
agentVersion Int
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version])
agentGraphId String
agentGraphVersion Int
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version])
// Content fields
name String
@@ -697,7 +709,7 @@ model StoreListingVersion {
@@index([storeListingId, submissionStatus, isAvailable])
@@index([submissionStatus])
@@index([reviewerId])
@@index([agentId, agentVersion]) // Non-unique index for efficient lookups
@@index([agentGraphId, agentGraphVersion]) // Non-unique index for efficient lookups
}
model StoreListingReview {

View File

@@ -1,4 +1,4 @@
from backend.data.execution import merge_execution_input, parse_execution_output
from backend.executor.utils import merge_execution_input, parse_execution_output
def test_parse_execution_output():

View File

@@ -360,8 +360,8 @@ async def test_execute_preset(server: SpinTestServer):
preset = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
name="Test Preset With Clash",
description="Test preset with clashing input values",
agent_id=test_graph.id,
agent_version=test_graph.version,
graph_id=test_graph.id,
graph_version=test_graph.version,
inputs={
"dictionary": {"key1": "Hello", "key2": "World"},
"selected_value": "key2",
@@ -449,8 +449,8 @@ async def test_execute_preset_with_clash(server: SpinTestServer):
preset = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
name="Test Preset With Clash",
description="Test preset with clashing input values",
agent_id=test_graph.id,
agent_version=test_graph.version,
graph_id=test_graph.id,
graph_version=test_graph.version,
inputs={
"dictionary": {"key1": "Hello", "key2": "World"},
"selected_value": "key2",

View File

@@ -10,16 +10,12 @@ from prisma.types import (
AgentGraphCreateInput,
AgentNodeCreateInput,
AgentNodeLinkCreateInput,
AgentPresetCreateInput,
AnalyticsDetailsCreateInput,
AnalyticsMetricsCreateInput,
APIKeyCreateInput,
CreditTransactionCreateInput,
LibraryAgentCreateInput,
ProfileCreateInput,
StoreListingCreateInput,
StoreListingReviewCreateInput,
StoreListingVersionCreateInput,
UserCreateInput,
)
@@ -140,14 +136,14 @@ async def main():
for _ in range(num_presets): # Create 1 AgentPreset per user
graph = random.choice(agent_graphs)
preset = await db.agentpreset.create(
data=AgentPresetCreateInput(
name=faker.sentence(nb_words=3),
description=faker.text(max_nb_chars=200),
userId=user.id,
agentId=graph.id,
agentVersion=graph.version,
isActive=True,
)
data={
"name": faker.sentence(nb_words=3),
"description": faker.text(max_nb_chars=200),
"userId": user.id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"isActive": True,
}
)
agent_presets.append(preset)
@@ -160,16 +156,15 @@ async def main():
graph = random.choice(agent_graphs)
preset = random.choice(agent_presets)
user_agent = await db.libraryagent.create(
data=LibraryAgentCreateInput(
userId=user.id,
agentId=graph.id,
agentVersion=graph.version,
agentPresetId=preset.id,
isFavorite=random.choice([True, False]),
isCreatedByUser=random.choice([True, False]),
isArchived=random.choice([True, False]),
isDeleted=random.choice([True, False]),
)
data={
"userId": user.id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"isFavorite": random.choice([True, False]),
"isCreatedByUser": random.choice([True, False]),
"isArchived": random.choice([True, False]),
"isDeleted": random.choice([True, False]),
}
)
user_agents.append(user_agent)
@@ -346,13 +341,13 @@ async def main():
user = random.choice(users)
slug = faker.slug()
listing = await db.storelisting.create(
data=StoreListingCreateInput(
agentId=graph.id,
agentVersion=graph.version,
owningUserId=user.id,
hasApprovedVersion=random.choice([True, False]),
slug=slug,
)
data={
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"owningUserId": user.id,
"hasApprovedVersion": random.choice([True, False]),
"slug": slug,
}
)
store_listings.append(listing)
@@ -362,26 +357,26 @@ async def main():
for listing in store_listings:
graph = [g for g in agent_graphs if g.id == listing.agentId][0]
version = await db.storelistingversion.create(
data=StoreListingVersionCreateInput(
agentId=graph.id,
agentVersion=graph.version,
name=graph.name or faker.sentence(nb_words=3),
subHeading=faker.sentence(),
videoUrl=faker.url(),
imageUrls=[get_image() for _ in range(3)],
description=faker.text(),
categories=[faker.word() for _ in range(3)],
isFeatured=random.choice([True, False]),
isAvailable=True,
storeListingId=listing.id,
submissionStatus=random.choice(
data={
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"name": graph.name or faker.sentence(nb_words=3),
"subHeading": faker.sentence(),
"videoUrl": faker.url(),
"imageUrls": [get_image() for _ in range(3)],
"description": faker.text(),
"categories": [faker.word() for _ in range(3)],
"isFeatured": random.choice([True, False]),
"isAvailable": True,
"storeListingId": listing.id,
"submissionStatus": random.choice(
[
prisma.enums.SubmissionStatus.PENDING,
prisma.enums.SubmissionStatus.APPROVED,
prisma.enums.SubmissionStatus.REJECTED,
]
),
)
}
)
store_listing_versions.append(version)
@@ -422,23 +417,12 @@ async def main():
)
await db.storelistingversion.update(
where={"id": version.id},
data=StoreListingVersionCreateInput(
submissionStatus=status,
Reviewer={"connect": {"id": reviewer.id}},
reviewComments=faker.text(),
reviewedAt=datetime.now(),
agentId=version.agentId, # preserving existing fields
agentVersion=version.agentVersion,
name=version.name,
subHeading=version.subHeading,
videoUrl=version.videoUrl,
imageUrls=version.imageUrls,
description=version.description,
categories=version.categories,
isFeatured=version.isFeatured,
isAvailable=version.isAvailable,
storeListingId=version.storeListingId,
),
data={
"submissionStatus": status,
"Reviewer": {"connect": {"id": reviewer.id}},
"reviewComments": faker.text(),
"reviewedAt": datetime.now(),
},
)
# Insert APIKeys

View File

@@ -15,6 +15,7 @@ services:
condition: service_healthy
environment:
- DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
- DIRECT_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
networks:
- app-network
restart: on-failure
@@ -77,6 +78,7 @@ services:
- SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
- SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
- DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
- DIRECT_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
- REDIS_HOST=redis
- REDIS_PORT=6379
- RABBITMQ_HOST=rabbitmq
@@ -126,6 +128,7 @@ services:
- SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
- SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
- DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
- DIRECT_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
- REDIS_HOST=redis
- REDIS_PORT=6379
- REDIS_PASSWORD=password
@@ -167,6 +170,7 @@ services:
- DATABASEMANAGER_HOST=rest_server
- SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
- DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
- DIRECT_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
- REDIS_HOST=redis
- REDIS_PORT=6379
- REDIS_PASSWORD=password
@@ -201,6 +205,7 @@ services:
# - NEXT_PUBLIC_SUPABASE_URL=http://kong:8000
# - NEXT_PUBLIC_SUPABASE_ANON_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJhbm9uIiwKICAgICJpc3MiOiAic3VwYWJhc2UtZGVtbyIsCiAgICAiaWF0IjogMTY0MTc2OTIwMCwKICAgICJleHAiOiAxNzk5NTM1NjAwCn0.dc_X5iR_VP_qT0zsiyj_I_OZ2T9FtRU2BBNWN8Bu4GE
# - DATABASE_URL=postgresql://agpt_user:pass123@postgres:5432/postgres?connect_timeout=60&schema=platform
# - DIRECT_URL=postgresql://agpt_user:pass123@postgres:5432/postgres?connect_timeout=60&schema=platform
# - NEXT_PUBLIC_AGPT_SERVER_URL=http://localhost:8006/api
# - NEXT_PUBLIC_AGPT_WS_SERVER_URL=ws://localhost:8001/ws
# - NEXT_PUBLIC_AGPT_MARKETPLACE_URL=http://localhost:8015/api/v1/market

View File

@@ -4,7 +4,7 @@ NEXT_PUBLIC_AGPT_WS_SERVER_URL=ws://localhost:8001/ws
NEXT_PUBLIC_AGPT_MARKETPLACE_URL=http://localhost:8015/api/v1/market
NEXT_PUBLIC_LAUNCHDARKLY_ENABLED=false
NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID=
NEXT_PUBLIC_APP_ENV=dev
NEXT_PUBLIC_APP_ENV=local
## Locale settings

View File

@@ -2,12 +2,14 @@
// The config you add here will be used whenever a users loads a page in their browser.
// https://docs.sentry.io/platforms/javascript/guides/nextjs/
import { getEnvironmentStr } from "@/lib/utils";
import * as Sentry from "@sentry/nextjs";
Sentry.init({
dsn: "https://fe4e4aa4a283391808a5da396da20159@o4505260022104064.ingest.us.sentry.io/4507946746380288",
enabled: process.env.DISABLE_SENTRY !== "true",
environment: getEnvironmentStr(),
// Add optional integrations for additional features
integrations: [
@@ -28,7 +30,9 @@ Sentry.init({
// Set `tracePropagationTargets` to control for which URLs trace propagation should be enabled
tracePropagationTargets: [
"localhost",
"localhost:8006",
/^https:\/\/dev\-builder\.agpt\.co\/api/,
/^https:\/\/.*\.agpt\.co\/api/,
],
// Define how likely Replay events are sampled.
@@ -48,4 +52,8 @@ Sentry.init({
// For example, a tracesSampleRate of 0.5 and profilesSampleRate of 0.5 would
// result in 25% of transactions being profiled (0.5*0.5=0.25)
profilesSampleRate: 1.0,
_experiments: {
// Enable logs to be sent to Sentry.
enableLogs: true,
},
});

View File

@@ -42,6 +42,7 @@
"@radix-ui/react-separator": "^1.1.0",
"@radix-ui/react-slot": "^1.1.0",
"@radix-ui/react-switch": "^1.1.1",
"@radix-ui/react-tabs": "^1.1.4",
"@radix-ui/react-toast": "^1.2.5",
"@radix-ui/react-tooltip": "^1.1.7",
"@sentry/nextjs": "^9",
@@ -52,7 +53,6 @@
"@xyflow/react": "12.4.2",
"ajv": "^8.17.1",
"boring-avatars": "^1.11.2",
"canvas-confetti": "^1.9.3",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"cmdk": "1.0.4",
@@ -70,6 +70,7 @@
"moment": "^2.30.1",
"next": "^14.2.26",
"next-themes": "^0.4.5",
"party-js": "^2.2.0",
"react": "^18",
"react-day-picker": "^9.6.1",
"react-dom": "^18",

View File

@@ -4,15 +4,28 @@
// https://docs.sentry.io/platforms/javascript/guides/nextjs/
import * as Sentry from "@sentry/nextjs";
import { getEnvironmentStr } from "./src/lib/utils";
Sentry.init({
dsn: "https://fe4e4aa4a283391808a5da396da20159@o4505260022104064.ingest.us.sentry.io/4507946746380288",
enabled: process.env.NODE_ENV !== "development",
environment: getEnvironmentStr(),
// Define how likely traces are sampled. Adjust this value in production, or use tracesSampler for greater control.
tracesSampleRate: 1,
tracePropagationTargets: [
"localhost",
"localhost:8006",
/^https:\/\/dev\-builder\.agpt\.co\/api/,
/^https:\/\/.*\.agpt\.co\/api/,
],
// Setting this option to true will print useful information to the console while you're setting up Sentry.
debug: false,
_experiments: {
// Enable logs to be sent to Sentry.
enableLogs: true,
},
});

View File

@@ -2,6 +2,7 @@
// The config you add here will be used whenever the server handles a request.
// https://docs.sentry.io/platforms/javascript/guides/nextjs/
import { getEnvironmentStr } from "@/lib/utils";
import * as Sentry from "@sentry/nextjs";
// import { NodeProfilingIntegration } from "@sentry/profiling-node";
@@ -9,9 +10,16 @@ Sentry.init({
dsn: "https://fe4e4aa4a283391808a5da396da20159@o4505260022104064.ingest.us.sentry.io/4507946746380288",
enabled: process.env.NODE_ENV !== "development",
environment: getEnvironmentStr(),
// Define how likely traces are sampled. Adjust this value in production, or use tracesSampler for greater control.
tracesSampleRate: 1,
tracePropagationTargets: [
"localhost",
"localhost:8006",
/^https:\/\/dev\-builder\.agpt\.co\/api/,
/^https:\/\/.*\.agpt\.co\/api/,
],
// Setting this option to true will print useful information to the console while you're setting up Sentry.
debug: false,
@@ -22,4 +30,9 @@ Sentry.init({
// NodeProfilingIntegration,
// Sentry.fsIntegration(),
],
_experiments: {
// Enable logs to be sent to Sentry.
enableLogs: true,
},
});

View File

@@ -3,9 +3,16 @@
import { useSearchParams } from "next/navigation";
import { GraphID } from "@/lib/autogpt-server-api/types";
import FlowEditor from "@/components/Flow";
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
import { useEffect } from "react";
export default function Home() {
const query = useSearchParams();
const { completeStep } = useOnboarding();
useEffect(() => {
completeStep("BUILDER_OPEN");
}, []);
return (
<FlowEditor

View File

@@ -144,3 +144,13 @@
text-wrap: balance;
}
}
input[type="number"]::-webkit-outer-spin-button,
input[type="number"]::-webkit-inner-spin-button {
-webkit-appearance: none;
margin: 0;
}
input[type="number"] {
-moz-appearance: textfield;
}

View File

@@ -23,6 +23,7 @@ import AgentRunDetailsView from "@/components/agents/agent-run-details-view";
import AgentRunsSelectorList from "@/components/agents/agent-runs-selector-list";
import AgentScheduleDetailsView from "@/components/agents/agent-schedule-details-view";
import LibraryRunLoadingSkeleton from "./loading";
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
export default function AgentRunsPage(): React.ReactElement {
const { id: agentID }: { id: LibraryAgentID } = useParams();
@@ -50,6 +51,7 @@ export default function AgentRunsPage(): React.ReactElement {
useState<boolean>(false);
const [confirmingDeleteAgentRun, setConfirmingDeleteAgentRun] =
useState<GraphExecutionMeta | null>(null);
const { state, updateState } = useOnboarding();
const openRunDraftView = useCallback(() => {
selectView({ type: "run" });
@@ -79,20 +81,32 @@ export default function AgentRunsPage(): React.ReactElement {
[api, graphVersions],
);
// Reward user for viewing results of their onboarding agent
useEffect(() => {
if (!state || !selectedRun || state.completedSteps.includes("GET_RESULTS"))
return;
if (selectedRun.id === state.onboardingAgentExecutionId) {
updateState({
completedSteps: [...state.completedSteps, "GET_RESULTS"],
});
}
}, [selectedRun, state]);
const fetchAgents = useCallback(() => {
api.getLibraryAgent(agentID).then((agent) => {
setAgent(agent);
getGraphVersion(agent.agent_id, agent.agent_version).then(
getGraphVersion(agent.graph_id, agent.graph_version).then(
(_graph) =>
(graph && graph.version == _graph.version) || setGraph(_graph),
);
api.getGraphExecutions(agent.agent_id).then((agentRuns) => {
api.getGraphExecutions(agent.graph_id).then((agentRuns) => {
setAgentRuns(agentRuns);
// Preload the corresponding graph versions
new Set(agentRuns.map((run) => run.graph_version)).forEach((version) =>
getGraphVersion(agent.agent_id, version),
getGraphVersion(agent.graph_id, version),
);
if (!selectedView.id && isFirstLoad && agentRuns.length > 0) {
@@ -110,7 +124,7 @@ export default function AgentRunsPage(): React.ReactElement {
});
if (selectedView.type == "run" && selectedView.id && agent) {
api
.getGraphExecutionInfo(agent.agent_id, selectedView.id)
.getGraphExecutionInfo(agent.graph_id, selectedView.id)
.then(setSelectedRun);
}
}, [api, agentID, getGraphVersion, graph, selectedView, isFirstLoad, agent]);
@@ -124,7 +138,7 @@ export default function AgentRunsPage(): React.ReactElement {
if (!agent) return;
// Subscribe to all executions for this agent
api.subscribeToGraphExecutions(agent.agent_id);
api.subscribeToGraphExecutions(agent.graph_id);
}, [api, agent]);
// Handle execution updates
@@ -132,6 +146,8 @@ export default function AgentRunsPage(): React.ReactElement {
const detachExecUpdateHandler = api.onWebSocketMessage(
"graph_execution_event",
(data) => {
if (data.graph_id != agent?.graph_id) return;
setAgentRuns((prev) => {
const index = prev.findIndex((run) => run.id === data.id);
if (index === -1) {
@@ -150,7 +166,7 @@ export default function AgentRunsPage(): React.ReactElement {
return () => {
detachExecUpdateHandler();
};
}, [api, selectedView.id]);
}, [api, agent?.graph_id, selectedView.id]);
// load selectedRun based on selectedView
useEffect(() => {
@@ -163,7 +179,7 @@ export default function AgentRunsPage(): React.ReactElement {
// Ensure corresponding graph version is available before rendering I/O
api
.getGraphExecutionInfo(agent.agent_id, selectedView.id)
.getGraphExecutionInfo(agent.graph_id, selectedView.id)
.then(async (run) => {
await getGraphVersion(run.graph_id, run.graph_version);
setSelectedRun(run);
@@ -176,7 +192,7 @@ export default function AgentRunsPage(): React.ReactElement {
// TODO: filter in backend - https://github.com/Significant-Gravitas/AutoGPT/issues/9183
setSchedules(
(await api.listSchedules()).filter((s) => s.graph_id == agent.agent_id),
(await api.listSchedules()).filter((s) => s.graph_id == agent.graph_id),
);
}, [api, agent]);
@@ -215,7 +231,7 @@ export default function AgentRunsPage(): React.ReactElement {
agent &&
// Export sanitized graph from backend
api
.getGraph(agent.agent_id, agent.agent_version, true)
.getGraph(agent.graph_id, agent.graph_version, true)
.then((graph) =>
exportAsJSONFile(graph, `${graph.name}_v${graph.version}.json`),
),
@@ -228,7 +244,7 @@ export default function AgentRunsPage(): React.ReactElement {
? [
{
label: "Open graph in builder",
href: `/build?flowID=${agent.agent_id}&flowVersion=${agent.agent_version}`,
href: `/build?flowID=${agent.graph_id}&flowVersion=${agent.graph_version}`,
},
{ label: "Export agent to file", callback: downloadGraph },
]
@@ -243,7 +259,6 @@ export default function AgentRunsPage(): React.ReactElement {
);
if (!agent || !graph) {
/* TODO: implement loading indicators / skeleton page */
return <LibraryRunLoadingSkeleton />;
}

View File

@@ -81,16 +81,19 @@ export default async function Page({
}
/>
</div>
<Separator className="mb-[25px] mt-6" />
<Separator className="mb-[25px] mt-[60px]" />
<AgentsSection
margin="32px"
agents={otherAgents.agents}
sectionTitle={`Other agents by ${agent.creator}`}
/>
<Separator className="mb-[25px] mt-6" />
<Separator className="mb-[25px] mt-[60px]" />
<AgentsSection
margin="32px"
agents={similarAgents.agents}
sectionTitle="Similar agents"
/>
<Separator className="mb-[25px] mt-[60px]" />
<BecomeACreator
title="Become a Creator"
description="Join our ever-growing community of hackers and tinkerers"

View File

@@ -8,6 +8,7 @@ import { BreadCrumbs } from "@/components/agptui/BreadCrumbs";
import { Metadata } from "next";
import { CreatorInfoCard } from "@/components/agptui/CreatorInfoCard";
import { CreatorLinks } from "@/components/agptui/CreatorLinks";
import { Separator } from "@/components/ui/separator";
export async function generateMetadata({
params,
@@ -78,7 +79,7 @@ export default async function Page({
</div>
</div>
<div className="mt-8 sm:mt-12 md:mt-16 lg:pb-[58px]">
<hr className="w-full bg-neutral-700" />
<Separator className="mb-6 bg-gray-200" />
<AgentsSection
agents={creatorAgents.agents}
hideAvatars={true}

View File

@@ -153,16 +153,17 @@ export default async function Page({}: {}) {
<main className="px-4">
<HeroSection />
<FeaturedSection featuredAgents={featuredAgents.agents} />
<Separator />
{/* 100px margin because our featured sections button are placed 40px below the container */}
<Separator className="mb-6 mt-24" />
<AgentsSection
sectionTitle="Top Agents"
agents={topAgents.agents as Agent[]}
/>
<Separator />
<Separator className="mb-[25px] mt-[60px]" />
<FeaturedCreators
featuredCreators={featuredCreators.creators as FeaturedCreator[]}
/>
<Separator />
<Separator className="mb-[25px] mt-[60px]" />
<BecomeACreator
title="Become a Creator"
description="Join our ever-growing community of hackers and tinkerers"

View File

@@ -98,7 +98,7 @@ const Monitor = () => {
flows={flows}
executions={[
...(selectedFlow
? executions.filter((v) => v.graph_id == selectedFlow.agent_id)
? executions.filter((v) => v.graph_id == selectedFlow.graph_id)
: executions),
].sort((a, b) => b.started_at.getTime() - a.started_at.getTime())}
selectedRun={selectedRun}
@@ -108,7 +108,7 @@ const Monitor = () => {
<FlowRunInfo
agent={
selectedFlow ||
flows.find((f) => f.agent_id == selectedRun.graph_id)!
flows.find((f) => f.graph_id == selectedRun.graph_id)!
}
execution={selectedRun}
className={column3}
@@ -118,7 +118,7 @@ const Monitor = () => {
<FlowInfo
flow={selectedFlow}
executions={executions.filter(
(e) => e.graph_id == selectedFlow.agent_id,
(e) => e.graph_id == selectedFlow.graph_id,
)}
className={column3}
refresh={() => {

View File

@@ -80,12 +80,23 @@ export default function Page() {
if (!agent) {
return;
}
api.addMarketplaceAgentToLibrary(
storeAgent?.store_listing_version_id || "",
);
api.executeGraph(agent.id, agent.version, state?.agentInput || {});
router.push("/onboarding/6-congrats");
}, [api, agent, router, state?.agentInput]);
api
.addMarketplaceAgentToLibrary(storeAgent?.store_listing_version_id || "")
.then((libraryAgent) => {
api
.executeGraph(
libraryAgent.graph_id,
libraryAgent.graph_version,
state?.agentInput || {},
)
.then(({ graph_exec_id }) => {
updateState({
onboardingAgentExecutionId: graph_exec_id,
});
router.push("/onboarding/6-congrats");
});
});
}, [api, agent, router, state?.agentInput, storeAgent, updateState]);
const runYourAgent = (
<div className="ml-[54px] w-[481px] pl-5">

View File

@@ -6,9 +6,6 @@ import { redirect } from "next/navigation";
// Server action that finalizes the user-onboarding flow: it records the
// terminal "CONGRATS" step on the backend, then sends the user to their
// library.
export async function finishOnboarding() {
  const api = new BackendAPI();
  // Read current progress first so we append to completedSteps rather than
  // overwrite whatever steps the user has already finished.
  const onboarding = await api.getUserOnboarding();
  await api.updateUserOnboarding({
    completedSteps: [...onboarding.completedSteps, "CONGRATS"],
  });
  // Invalidate the cached /library layout before navigating there;
  // redirect() throws internally, so nothing after it runs.
  revalidatePath("/library", "layout");
  redirect("/library");
}

View File

@@ -1,24 +1,26 @@
"use client";
import { useEffect, useState } from "react";
import { useEffect, useRef, useState } from "react";
import { cn } from "@/lib/utils";
import { finishOnboarding } from "./actions";
import confetti from "canvas-confetti";
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
import * as party from "party-js";
export default function Page() {
useOnboarding(7, "AGENT_INPUT");
const { state, updateState } = useOnboarding(7, "AGENT_INPUT");
const [showText, setShowText] = useState(false);
const [showSubtext, setShowSubtext] = useState(false);
const divRef = useRef(null);
useEffect(() => {
confetti({
particleCount: 120,
spread: 360,
shapes: ["square", "circle"],
scalar: 2,
decay: 0.93,
origin: { y: 0.38, x: 0.51 },
});
if (divRef.current) {
party.confetti(divRef.current, {
count: 100,
spread: 180,
shapes: ["square", "circle"],
size: party.variation.range(2, 2), // scalar: 2
speed: party.variation.range(300, 1000),
});
}
const timer0 = setTimeout(() => {
setShowText(true);
@@ -29,6 +31,9 @@ export default function Page() {
}, 500);
const timer2 = setTimeout(() => {
updateState({
completedSteps: [...(state?.completedSteps || []), "CONGRATS"],
});
finishOnboarding();
}, 3000);
@@ -42,6 +47,7 @@ export default function Page() {
return (
<div className="flex h-screen w-screen flex-col items-center justify-center bg-violet-100">
<div
ref={divRef}
className={cn(
"z-10 -mb-16 text-9xl duration-500",
showText ? "opacity-100" : "opacity-0",
@@ -63,7 +69,7 @@ export default function Page() {
showSubtext ? "opacity-100" : "opacity-0",
)}
>
You earned 15$ for running your first agent
You earned 3$ for running your first agent
</p>
</div>
);

View File

@@ -5,11 +5,14 @@ export default async function OnboardingResetPage() {
const api = new BackendAPI();
await api.updateUserOnboarding({
completedSteps: [],
notificationDot: true,
notified: [],
usageReason: null,
integrations: [],
otherIntegrations: "",
selectedStoreListingVersionId: null,
agentInput: {},
onboardingAgentExecutionId: null,
});
redirect("/onboarding/1-welcome");
}

View File

@@ -118,8 +118,7 @@ export default function CreditsPage() {
{topupStatus === "success" && (
<span className="text-green-500">
Your payment was successful. Your credits will be updated
shortly. You can click the refresh icon 🔄 in case it is not
updated.
shortly. Try refreshing the page in case it is not updated.
</span>
)}
{topupStatus === "cancel" && (

View File

@@ -131,11 +131,7 @@ export default function PrivatePage() {
const allCredentials = providers
? Object.values(providers).flatMap((provider) =>
[
...provider.savedOAuthCredentials,
...provider.savedApiKeys,
...provider.savedUserPasswordCredentials,
]
provider.savedCredentials
.filter((cred) => !hiddenCredentials.includes(cred.id))
.map((credentials) => ({
...credentials,

View File

@@ -178,18 +178,24 @@ export const CustomNode = React.memo(
return obj;
}, []);
const setHardcodedValues = (values: any) => {
updateNodeData(id, { hardcodedValues: values });
};
const setHardcodedValues = useCallback(
(values: any) => {
updateNodeData(id, { hardcodedValues: values });
},
[id, updateNodeData],
);
useEffect(() => {
isInitialSetup.current = false;
setHardcodedValues(fillDefaults(data.hardcodedValues, data.inputSchema));
}, []);
const setErrors = (errors: { [key: string]: string }) => {
updateNodeData(id, { errors });
};
const setErrors = useCallback(
(errors: { [key: string]: string }) => {
updateNodeData(id, { errors });
},
[id, updateNodeData],
);
const toggleOutput = (checked: boolean) => {
setIsOutputOpen(checked);
@@ -340,46 +346,49 @@ export const CustomNode = React.memo(
});
}
};
const handleInputChange = (path: string, value: any) => {
const keys = parseKeys(path);
const newValues = JSON.parse(JSON.stringify(data.hardcodedValues));
let current = newValues;
const handleInputChange = useCallback(
(path: string, value: any) => {
const keys = parseKeys(path);
const newValues = JSON.parse(JSON.stringify(data.hardcodedValues));
let current = newValues;
for (let i = 0; i < keys.length - 1; i++) {
const { key: currentKey, index } = keys[i];
if (index !== undefined) {
if (!current[currentKey]) current[currentKey] = [];
if (!current[currentKey][index]) current[currentKey][index] = {};
current = current[currentKey][index];
} else {
if (!current[currentKey]) current[currentKey] = {};
current = current[currentKey];
for (let i = 0; i < keys.length - 1; i++) {
const { key: currentKey, index } = keys[i];
if (index !== undefined) {
if (!current[currentKey]) current[currentKey] = [];
if (!current[currentKey][index]) current[currentKey][index] = {};
current = current[currentKey][index];
} else {
if (!current[currentKey]) current[currentKey] = {};
current = current[currentKey];
}
}
}
const lastKey = keys[keys.length - 1];
if (lastKey.index !== undefined) {
if (!current[lastKey.key]) current[lastKey.key] = [];
current[lastKey.key][lastKey.index] = value;
} else {
current[lastKey.key] = value;
}
const lastKey = keys[keys.length - 1];
if (lastKey.index !== undefined) {
if (!current[lastKey.key]) current[lastKey.key] = [];
current[lastKey.key][lastKey.index] = value;
} else {
current[lastKey.key] = value;
}
if (!isInitialSetup.current) {
history.push({
type: "UPDATE_INPUT",
payload: { nodeId: id, oldValues: data.hardcodedValues, newValues },
undo: () => setHardcodedValues(data.hardcodedValues),
redo: () => setHardcodedValues(newValues),
});
}
if (!isInitialSetup.current) {
history.push({
type: "UPDATE_INPUT",
payload: { nodeId: id, oldValues: data.hardcodedValues, newValues },
undo: () => setHardcodedValues(data.hardcodedValues),
redo: () => setHardcodedValues(newValues),
});
}
setHardcodedValues(newValues);
const errors = data.errors || {};
// Remove error with the same key
setNestedProperty(errors, path, null);
setErrors({ ...errors });
};
setHardcodedValues(newValues);
const errors = data.errors || {};
// Remove error with the same key
setNestedProperty(errors, path, null);
setErrors({ ...errors });
},
[data.hardcodedValues, id, setHardcodedValues, data.errors, setErrors],
);
const isInputHandleConnected = (key: string) => {
return (
@@ -407,28 +416,34 @@ export const CustomNode = React.memo(
);
};
const handleInputClick = (key: string) => {
console.debug(`Opening modal for key: ${key}`);
setActiveKey(key);
const value = getValue(key, data.hardcodedValues);
setInputModalValue(
typeof value === "object" ? JSON.stringify(value, null, 2) : value,
);
setIsModalOpen(true);
};
const handleInputClick = useCallback(
(key: string) => {
console.debug(`Opening modal for key: ${key}`);
setActiveKey(key);
const value = getValue(key, data.hardcodedValues);
setInputModalValue(
typeof value === "object" ? JSON.stringify(value, null, 2) : value,
);
setIsModalOpen(true);
},
[data.hardcodedValues],
);
const handleModalSave = (value: string) => {
if (activeKey) {
try {
const parsedValue = JSON.parse(value);
handleInputChange(activeKey, parsedValue);
} catch (error) {
handleInputChange(activeKey, value);
const handleModalSave = useCallback(
(value: string) => {
if (activeKey) {
try {
const parsedValue = JSON.parse(value);
handleInputChange(activeKey, parsedValue);
} catch (error) {
handleInputChange(activeKey, value);
}
}
}
setIsModalOpen(false);
setActiveKey(null);
};
setIsModalOpen(false);
setActiveKey(null);
},
[activeKey, handleInputChange],
);
const handleOutputClick = () => {
setIsOutputModalOpen(true);

View File

@@ -1,5 +1,6 @@
"use client";
import React, { useCallback, useMemo } from "react";
import { isEmpty } from "lodash";
import moment from "moment";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
@@ -164,7 +165,8 @@ export default function AgentRunDetailsView({
] satisfies ButtonAction[])
: []),
...(["success", "failed", "stopped"].includes(runStatus) &&
!graph.has_webhook_trigger
!graph.has_webhook_trigger &&
isEmpty(graph.credentials_input_schema.required) // TODO: enable re-run with credentials - https://linear.app/autogpt/issue/SECRT-1243
? [
{
label: (
@@ -193,6 +195,7 @@ export default function AgentRunDetailsView({
stopRun,
deleteRun,
graph.has_webhook_trigger,
graph.credentials_input_schema.properties,
agent.can_access_graph,
run.graph_id,
run.graph_version,
@@ -258,11 +261,7 @@ export default function AgentRunDetailsView({
Object.entries(agentRunInputs).map(([key, { title, value }]) => (
<div key={key} className="flex flex-col gap-1.5">
<label className="text-sm font-medium">{title || key}</label>
<Input
defaultValue={value}
className="rounded-full"
disabled
/>
<Input value={value} className="rounded-full" disabled />
</div>
))
) : (

View File

@@ -6,11 +6,13 @@ import { GraphExecutionID, GraphMeta } from "@/lib/autogpt-server-api";
import type { ButtonAction } from "@/components/agptui/types";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { CredentialsInput } from "@/components/integrations/credentials-input";
import { TypeBasedInput } from "@/components/type-based-input";
import { useToastOnFail } from "@/components/ui/use-toast";
import ActionButtonGroup from "@/components/agptui/action-button-group";
import SchemaTooltip from "@/components/SchemaTooltip";
import { IconPlay } from "@/components/ui/icons";
import { useOnboarding } from "../onboarding/onboarding-provider";
export default function AgentRunDraftView({
graph,
@@ -25,16 +27,32 @@ export default function AgentRunDraftView({
const toastOnFail = useToastOnFail();
const agentInputs = graph.input_schema.properties;
const agentCredentialsInputs = graph.credentials_input_schema.properties;
const [inputValues, setInputValues] = useState<Record<string, any>>({});
const doRun = useCallback(
() =>
api
.executeGraph(graph.id, graph.version, inputValues)
.then((newRun) => onRun(newRun.graph_exec_id))
.catch(toastOnFail("execute agent")),
[api, graph, inputValues, onRun, toastOnFail],
const [inputCredentials, setInputCredentials] = useState<Record<string, any>>(
{},
);
const { state, completeStep } = useOnboarding();
const doRun = useCallback(() => {
api
.executeGraph(graph.id, graph.version, inputValues, inputCredentials)
.then((newRun) => onRun(newRun.graph_exec_id))
.catch(toastOnFail("execute agent"));
// Mark run agent onboarding step as completed
if (state?.completedSteps.includes("MARKETPLACE_ADD_AGENT")) {
completeStep("MARKETPLACE_RUN_AGENT");
}
}, [
api,
graph,
inputValues,
inputCredentials,
onRun,
toastOnFail,
state,
completeStep,
]);
const runActions: ButtonAction[] = useMemo(
() => [
@@ -60,6 +78,26 @@ export default function AgentRunDraftView({
<CardTitle className="font-poppins text-lg">Input</CardTitle>
</CardHeader>
<CardContent className="flex flex-col gap-4">
{/* Credentials inputs */}
{Object.entries(agentCredentialsInputs).map(
([key, inputSubSchema]) => (
<CredentialsInput
key={key}
schema={{ ...inputSubSchema, discriminator: undefined }}
selectedCredentials={
inputCredentials[key] ?? inputSubSchema.default
}
onSelectCredentials={(value) =>
setInputCredentials((obj) => ({
...obj,
[key]: value,
}))
}
/>
),
)}
{/* Regular inputs */}
{Object.entries(agentInputs).map(([key, inputSubSchema]) => (
<div key={key} className="flex flex-col space-y-2">
<label className="flex items-center gap-1 text-sm font-medium">

View File

@@ -84,7 +84,7 @@ export default function AgentRunsSelectorList({
>
<span>Scheduled</span>
<span className="text-neutral-600">
{schedules.filter((s) => s.graph_id === agent.agent_id).length}
{schedules.filter((s) => s.graph_id === agent.graph_id).length}
</span>
</Badge>
</div>
@@ -127,7 +127,7 @@ export default function AgentRunsSelectorList({
/>
))
: schedules
.filter((schedule) => schedule.graph_id === agent.agent_id)
.filter((schedule) => schedule.graph_id === agent.graph_id)
.map((schedule) => (
<AgentRunSummaryCard
className="h-28 w-72 lg:h-32 xl:w-80"

View File

@@ -109,11 +109,7 @@ export default function AgentScheduleDetailsView({
Object.entries(agentRunInputs).map(([key, { title, value }]) => (
<div key={key} className="flex flex-col gap-1.5">
<label className="text-sm font-medium">{title || key}</label>
<Input
defaultValue={value}
className="rounded-full"
disabled
/>
<Input value={value} className="rounded-full" disabled />
</div>
))
) : (

View File

@@ -10,6 +10,7 @@ import { useToast } from "@/components/ui/use-toast";
import useSupabase from "@/hooks/useSupabase";
import { DownloadIcon, LoaderIcon } from "lucide-react";
import { useOnboarding } from "../onboarding/onboarding-provider";
interface AgentInfoProps {
name: string;
creator: string;
@@ -39,6 +40,7 @@ export const AgentInfo: React.FC<AgentInfoProps> = ({
const api = React.useMemo(() => new BackendAPI(), []);
const { user } = useSupabase();
const { toast } = useToast();
const { completeStep } = useOnboarding();
const [downloading, setDownloading] = React.useState(false);
@@ -47,6 +49,7 @@ export const AgentInfo: React.FC<AgentInfoProps> = ({
const newLibraryAgent = await api.addMarketplaceAgentToLibrary(
storeListingVersionId,
);
completeStep("MARKETPLACE_ADD_AGENT");
router.push(`/library/agents/${newLibraryAgent.id}`);
} catch (error) {
console.error("Failed to add agent to library:", error);
@@ -170,10 +173,10 @@ export const AgentInfo: React.FC<AgentInfoProps> = ({
{/* Description Section */}
<div className="mb-4 w-full lg:mb-[36px]">
<div className="font-geist decoration-skip-ink-none mb-1.5 text-base font-medium leading-6 text-neutral-800 dark:text-neutral-200 sm:mb-2">
<div className="mb-1.5 font-sans text-base font-medium leading-6 text-neutral-800 dark:text-neutral-200 sm:mb-2">
Description
</div>
<div className="font-geist decoration-skip-ink-none text-base font-normal leading-6 text-neutral-600 underline-offset-[from-font] dark:text-neutral-400">
<div className="whitespace-pre-line font-sans text-base font-normal leading-6 text-neutral-600 dark:text-neutral-400">
{longDescription}
</div>
</div>

View File

@@ -21,17 +21,14 @@ export const BecomeACreator: React.FC<BecomeACreatorProps> = ({
return (
<div className="relative mx-auto h-auto min-h-[300px] w-full max-w-[1360px] md:min-h-[400px] lg:h-[459px]">
{/* Top border */}
<div className="left-0 top-0 h-px w-full bg-gray-200 dark:bg-gray-700" />
{/* Title */}
<h2 className="mb-8 text-left font-poppins text-[18px] font-[600] leading-9 text-neutral-800 dark:text-neutral-200">
<h2 className="mb-[77px] font-poppins text-[18px] font-semibold leading-[28px] text-neutral-800 dark:text-neutral-200">
{title}
</h2>
{/* Content Container */}
<div className="m-auto w-full max-w-[900px] px-4 py-16 text-center md:px-6 lg:px-0">
<h2 className="underline-from-font decoration-skip-ink-none mb-6 text-center font-poppins text-[48px] font-semibold leading-[54px] tracking-[-0.012em] text-neutral-950 dark:text-neutral-50 md:mb-8 lg:mb-12">
<div className="mx-auto w-full max-w-[900px] px-4 text-center md:px-6 lg:px-0">
<h2 className="mb-6 text-center font-poppins text-[48px] font-semibold leading-[54px] tracking-[-0.012em] text-neutral-950 dark:text-neutral-50 md:mb-8 lg:mb-12">
Build AI agents and share
<br />
<span className="text-violet-600 dark:text-violet-400">

View File

@@ -22,7 +22,7 @@ export const BreadCrumbs: React.FC<BreadCrumbsProps> = ({ items }) => {
<button className="flex h-12 w-12 items-center justify-center rounded-full border border-neutral-200 transition-colors hover:bg-neutral-50 dark:border-neutral-700 dark:hover:bg-neutral-800">
<IconRightArrow className="h-5 w-5 text-neutral-900 dark:text-neutral-100" />
</button> */}
<div className="flex h-auto flex-wrap items-center justify-start gap-4 rounded-[5rem] bg-white dark:bg-transparent">
<div className="flex h-auto flex-wrap items-center justify-start gap-4 rounded-[5rem] dark:bg-transparent">
{items.map((item, index) => (
<React.Fragment key={index}>
<Link href={item.link}>

View File

@@ -33,7 +33,10 @@ export const CreatorInfoCard: React.FC<CreatorInfoCardProps> = ({
src={avatarSrc}
alt={`${username}'s avatar`}
/>
<AvatarFallback className="h-[100px] w-[100px] sm:h-[130px] sm:w-[130px]">
<AvatarFallback
size={130}
className="h-[100px] w-[100px] sm:h-[130px] sm:w-[130px]"
>
{username.charAt(0)}
</AvatarFallback>
</Avatar>

View File

@@ -1,43 +0,0 @@
import type { Meta, StoryObj } from "@storybook/react";
import CreditsCard from "./CreditsCard";
import { userEvent, within } from "@storybook/test";

/** Storybook metadata for the credits-balance card shown in the navbar. */
const meta: Meta<typeof CreditsCard> = {
  title: "AGPT UI/Credits Card",
  component: CreditsCard,
  tags: ["autodocs"],
};

export default meta;

type Story = StoryObj<typeof CreditsCard>;

/** Zero balance — the state a brand-new user sees. */
export const Default: Story = { args: { credits: 0 } };

/** A small two-digit balance. */
export const SmallNumber: Story = { args: { credits: 10 } };

/** A large balance, exercising the credit formatting. */
export const LargeNumber: Story = { args: { credits: 1000000 } };

/** Clicks the refresh button to verify the refresh interaction is wired up. */
export const InteractionTest: Story = {
  args: { credits: 100 },
  play: async ({ canvasElement }) => {
    const story = within(canvasElement);
    const refreshButton = story.getByRole("button", {
      name: /refresh credits/i,
    });
    await userEvent.click(refreshButton);
  },
};

View File

@@ -1,48 +0,0 @@
"use client";

import { IconRefresh } from "@/components/ui/icons";
import { useState } from "react";
import {
  Tooltip,
  TooltipContent,
  TooltipTrigger,
} from "@/components/ui/tooltip";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import useCredits from "@/hooks/useCredits";

/**
 * Navbar card showing the user's current credit balance, with a button to
 * re-fetch it on demand.
 */
const CreditsCard = () => {
  const { credits, formatCredits, fetchCredits } = useCredits({
    fetchInitialCredits: true,
  });

  // Re-query the balance; useCredits updates `credits` when the fetch lands.
  // (Removed an unused `useBackendAPI()` binding and a pointless async
  // wrapper whose promise was never awaited.)
  const onRefresh = () => fetchCredits();

  return (
    <div className="inline-flex h-[48px] items-center gap-2.5 rounded-2xl bg-neutral-200 p-4 dark:bg-neutral-800">
      <div className="flex items-center gap-0.5">
        <span className="p-ui-semibold text-base leading-7 text-neutral-900 dark:text-neutral-50">
          Balance: {formatCredits(credits)}
        </span>
      </div>

      <Tooltip key="RefreshCredits" delayDuration={500}>
        <TooltipTrigger asChild>
          <button
            onClick={onRefresh}
            className="h-6 w-6 transition-colors hover:text-neutral-700 dark:hover:text-neutral-300"
            aria-label="Refresh credits"
          >
            <IconRefresh className="h-6 w-6" />
          </button>
        </TooltipTrigger>
        <TooltipContent>
          <p>Refresh credits</p>
        </TooltipContent>
      </Tooltip>
    </div>
  );
};

export default CreditsCard;

View File

@@ -27,7 +27,7 @@ export const FeaturedAgentCard: React.FC<FeaturedStoreCardProps> = ({
data-testid="featured-store-card"
onMouseEnter={() => setIsHovered(true)}
onMouseLeave={() => setIsHovered(false)}
className={`flex h-full flex-col ${backgroundColor}`}
className={`flex h-full flex-col ${backgroundColor} rounded-[1.5rem] border-none`}
>
<CardHeader>
<CardTitle className="line-clamp-2 text-base sm:text-xl">

View File

@@ -4,7 +4,7 @@ import { ProfilePopoutMenu } from "./ProfilePopoutMenu";
import { IconType, IconLogIn, IconAutoGPTLogo } from "@/components/ui/icons";
import { MobileNavBar } from "./MobileNavBar";
import { Button } from "./Button";
import CreditsCard from "./CreditsCard";
import Wallet from "./Wallet";
import { ProfileDetails } from "@/lib/autogpt-server-api/types";
import { NavbarLink } from "./NavbarLink";
import getServerUser from "@/lib/supabase/getServerUser";
@@ -61,7 +61,7 @@ export const Navbar = async ({ links, menuItemGroups }: NavbarProps) => {
<div className="flex items-center gap-4">
{isLoggedIn ? (
<div className="flex items-center gap-4">
{profile && <CreditsCard />}
{profile && <Wallet />}
<ProfilePopoutMenu
menuItemGroups={menuItemGroups}
userName={profile?.username}

View File

@@ -32,7 +32,7 @@ export const StoreCard: React.FC<StoreCardProps> = ({
return (
<div
className="inline-flex w-full max-w-[434px] cursor-pointer flex-col items-start justify-start gap-2.5 rounded-[26px] bg-white transition-all duration-300 hover:shadow-lg dark:bg-transparent dark:hover:shadow-gray-700"
className="flex h-[27rem] w-full max-w-md cursor-pointer flex-col items-start rounded-3xl bg-white transition-all duration-300 hover:shadow-lg dark:bg-transparent dark:hover:shadow-gray-700"
onClick={handleClick}
data-testid="store-card"
role="button"
@@ -44,8 +44,8 @@ export const StoreCard: React.FC<StoreCardProps> = ({
}
}}
>
{/* Header Image Section with Avatar */}
<div className="relative h-[200px] w-full overflow-hidden rounded-[20px]">
{/* First Section: Image with Avatar */}
<div className="relative aspect-[2/1.2] w-full overflow-hidden rounded-3xl md:aspect-[2.17/1]">
{agentImage && (
<Image
src={agentImage}
@@ -57,14 +57,14 @@ export const StoreCard: React.FC<StoreCardProps> = ({
)}
{!hideAvatar && (
<div className="absolute bottom-4 left-4">
<Avatar className="h-16 w-16 border-2 border-white dark:border-gray-800">
<Avatar className="h-16 w-16">
{avatarSrc && (
<AvatarImage
src={avatarSrc}
alt={`${creatorName || agentName} creator avatar`}
/>
)}
<AvatarFallback>
<AvatarFallback size={64}>
{(creatorName || agentName).charAt(0)}
</AvatarFallback>
</Avatar>
@@ -72,37 +72,46 @@ export const StoreCard: React.FC<StoreCardProps> = ({
)}
</div>
{/* Content Section */}
<div className="w-full px-2 py-4">
{/* Title and Creator */}
<h3 className="mb-0.5 font-poppins text-2xl font-semibold text-[#272727] dark:text-neutral-100">
{agentName}
</h3>
{!hideAvatar && creatorName && (
<p className="mb-2.5 font-sans text-xl font-normal text-neutral-600 dark:text-neutral-400">
by {creatorName}
</p>
)}
{/* Description */}
<p className="mb-4 font-sans text-base font-normal leading-normal text-neutral-600 dark:text-neutral-400">
{description}
</p>
<div className="mt-3 flex w-full flex-1 flex-col px-4">
{/* Second Section: Agent Name and Creator Name */}
<div className="flex w-full flex-col">
<h3 className="line-clamp-2 font-poppins text-2xl font-semibold text-[#272727] dark:text-neutral-100">
{agentName}
</h3>
{!hideAvatar && creatorName && (
<p className="mt-3 truncate font-sans text-xl font-normal text-neutral-600 dark:text-neutral-400">
by {creatorName}
</p>
)}
</div>
{/* Stats Row */}
<div className="flex items-center justify-between">
<div className="font-sans text-lg font-semibold text-neutral-800 dark:text-neutral-200">
{runs.toLocaleString()} runs
</div>
<div className="flex items-center gap-2">
<span className="font-geist text-lg font-semibold text-neutral-800 dark:text-neutral-200">
{rating.toFixed(1)}
</span>
<div
className="inline-flex items-center"
role="img"
aria-label={`Rating: ${rating.toFixed(1)} out of 5 stars`}
>
{StarRatingIcons(rating)}
{/* Third Section: Description */}
<div className="mt-2.5 flex w-full flex-col">
<p className="line-clamp-3 font-sans text-base font-normal leading-normal text-neutral-600 dark:text-neutral-400">
{description}
</p>
</div>
<div className="flex-grow" />
{/* Spacer to push stats to bottom */}
{/* Fourth Section: Stats Row - aligned to bottom */}
<div className="mt-5 w-full">
<div className="flex items-center justify-between">
<div className="font-sans text-lg font-semibold text-neutral-800 dark:text-neutral-200">
{runs.toLocaleString()} runs
</div>
<div className="flex items-center gap-2">
<span className="font-sans text-lg font-semibold text-neutral-800 dark:text-neutral-200">
{rating.toFixed(1)}
</span>
<div
className="inline-flex items-center"
role="img"
aria-label={`Rating: ${rating.toFixed(1)} out of 5 stars`}
>
{StarRatingIcons(rating)}
</div>
</div>
</div>
</div>

View File

@@ -0,0 +1,120 @@
"use client";
import useCredits from "@/hooks/useCredits";
import {
Popover,
PopoverContent,
PopoverTrigger,
} from "@/components/ui/popover";
import { X } from "lucide-react";
import { PopoverClose } from "@radix-ui/react-popover";
import { TaskGroups } from "../onboarding/WalletTaskGroups";
import { ScrollArea } from "../ui/scroll-area";
import { useOnboarding } from "../onboarding/onboarding-provider";
import { useCallback, useEffect, useRef } from "react";
import { cn } from "@/lib/utils";
import * as party from "party-js";
import WalletRefill from "./WalletRefill";
/**
 * Navbar wallet popover: shows the user's credit balance, a top-up panel,
 * and the list of onboarding tasks. Plays confetti on the trigger button
 * when the user has newly completed (un-notified) onboarding steps.
 */
export default function Wallet() {
  const { credits, formatCredits, fetchCredits } = useCredits({
    fetchInitialCredits: true,
  });
  const { state, updateState } = useOnboarding();
  const walletRef = useRef<HTMLButtonElement | null>(null);

  const onWalletOpen = useCallback(async () => {
    // Opening the wallet acknowledges the notification dot...
    if (state?.notificationDot) {
      updateState({ notificationDot: false });
    }
    // ...and refreshes the balance so the displayed amount is current.
    fetchCredits();
  }, [state?.notificationDot, updateState, fetchCredits]);

  useEffect(() => {
    // Check if there are any completed steps (state?.completedSteps) that
    // are not in the state?.notified array and play confetti if so.
    const pending = state?.completedSteps
      .filter((step) => !state?.notified.includes(step))
      // Ignore steps that are not relevant for notifications
      .filter(
        (step) =>
          step !== "WELCOME" &&
          step !== "USAGE_REASON" &&
          step !== "INTEGRATIONS" &&
          step !== "AGENT_CHOICE" &&
          step !== "AGENT_NEW_RUN" &&
          step !== "AGENT_INPUT",
      );

    if ((pending?.length || 0) > 0 && walletRef.current) {
      // Build the fade-out particle module here instead of on every render:
      // it is only used by this effect, and constructing it in the component
      // body made it an unlisted (and ever-changing) effect dependency.
      const fadeOut = new party.ModuleBuilder()
        .drive("opacity")
        .by((t) => 1 - t)
        .through("lifetime")
        .build();

      party.confetti(walletRef.current, {
        count: 30,
        spread: 120,
        shapes: ["square", "circle"],
        size: party.variation.range(1, 2),
        speed: party.variation.range(200, 300),
        modules: [fadeOut],
      });
    }
  }, [state?.completedSteps, state?.notified]);

  return (
    <Popover>
      <PopoverTrigger asChild>
        <button
          ref={walletRef}
          className="relative flex items-center gap-1 rounded-md bg-zinc-200 px-3 py-2 text-sm transition-colors duration-200 hover:bg-zinc-300"
          onClick={onWalletOpen}
        >
          Wallet{" "}
          <span className="text-sm font-semibold">
            {formatCredits(credits)}
          </span>
          {state?.notificationDot && (
            <span className="absolute right-1 top-1 h-2 w-2 rounded-full bg-violet-600"></span>
          )}
        </button>
      </PopoverTrigger>
      <PopoverContent
        className={cn(
          "absolute -right-[7.9rem] -top-[3.2rem] z-50 w-[28.5rem] px-[0.625rem] py-2",
          "rounded-xl border-zinc-200 bg-zinc-50 shadow-[0_3px_3px] shadow-zinc-300",
        )}
      >
        {/* Header */}
        <div className="mx-1 flex items-center justify-between border-b border-zinc-300 pb-2">
          <span className="font-poppins font-medium text-zinc-900">
            Your wallet
          </span>
          <div className="flex items-center font-inter text-sm font-semibold text-violet-700">
            <div className="rounded-lg bg-violet-100 px-3 py-2">
              Wallet{" "}
              <span className="font-semibold">{formatCredits(credits)}</span>
            </div>
            <PopoverClose>
              <X className="ml-[2.8rem] h-5 w-5 text-zinc-800 hover:text-foreground" />
            </PopoverClose>
          </div>
        </div>
        <ScrollArea className="max-h-[85vh] overflow-y-auto">
          {/* Top ups */}
          <WalletRefill />
          {/* Tasks */}
          <p className="mx-1 mt-4 font-sans text-xs font-medium text-violet-700">
            Onboarding tasks
          </p>
          <p className="mx-1 my-1 font-sans text-xs font-normal text-zinc-500">
            Complete the following tasks to earn more credits!
          </p>
          <TaskGroups />
        </ScrollArea>
      </PopoverContent>
    </Popover>
  );
}

View File

@@ -0,0 +1,265 @@
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { cn } from "@/lib/utils";
import { zodResolver } from "@hookform/resolvers/zod";
import { useForm } from "react-hook-form";
import { z } from "zod";
import {
Form,
FormControl,
FormField,
FormItem,
FormLabel,
FormMessage,
} from "@/components/ui/form";
import { Input } from "../ui/input";
import Link from "next/link";
import { useToast, useToastOnFail } from "../ui/use-toast";
import useCredits from "@/hooks/useCredits";
import { useCallback, useEffect, useState } from "react";
// Validation schema for the one-time top-up form.
// `coerce: true` converts the <input type="number"> string value to a number
// before validation; purchases below $5 are rejected with a friendly message.
const topUpSchema = z.object({
  amount: z
    .number({ coerce: true, invalid_type_error: "Enter top-up amount" })
    .min(5, "Top-ups start at $5. Please enter a higher amount."),
});
// Validation schema for the auto-refill form. Both fields are coerced from
// the raw input strings. The cross-field `refine` enforces that the refill
// amount is at least the trigger threshold, so a single refill always brings
// the balance back to (or above) the threshold.
const autoRefillSchema = z
  .object({
    threshold: z
      .number({ coerce: true, invalid_type_error: "Enter min. balance" })
      .min(
        5,
        "Looks like your balance is too low for auto-refill. Try $5 or more.",
      ),
    refillAmount: z
      .number({ coerce: true, invalid_type_error: "Enter top-up amount" })
      .min(5, "Top-ups start at $5. Please enter a higher amount."),
  })
  .refine((data) => data.refillAmount >= data.threshold, {
    message:
      "Your refill amount must be equal to or greater than the balance you entered above.",
    path: ["refillAmount"],
  });
export default function WalletRefill() {
const { toast } = useToast();
const toastOnFail = useToastOnFail();
const { requestTopUp, autoTopUpConfig, updateAutoTopUpConfig } = useCredits({
fetchInitialAutoTopUpConfig: true,
});
const [isLoading, setIsLoading] = useState(false);
const topUpForm = useForm<z.infer<typeof topUpSchema>>({
resolver: zodResolver(topUpSchema),
});
const autoRefillForm = useForm<z.infer<typeof autoRefillSchema>>({
resolver: zodResolver(autoRefillSchema),
});
console.log("autoRefillForm");
// Pre-fill the auto-refill form with existing values
useEffect(() => {
const values = autoRefillForm.getValues();
if (
autoTopUpConfig &&
autoTopUpConfig.amount > 0 &&
autoTopUpConfig.threshold > 0 &&
!autoRefillForm.getFieldState("threshold").isTouched &&
!autoRefillForm.getFieldState("refillAmount").isTouched
) {
autoRefillForm.setValue("threshold", autoTopUpConfig.threshold / 100);
autoRefillForm.setValue("refillAmount", autoTopUpConfig.amount / 100);
}
}, [autoTopUpConfig, autoRefillForm]);
const submitTopUp = useCallback(
async (data: z.infer<typeof topUpSchema>) => {
setIsLoading(true);
await requestTopUp(data.amount * 100).catch(
toastOnFail("request top-up"),
);
setIsLoading(false);
},
[requestTopUp, toastOnFail],
);
const submitAutoTopUpConfig = useCallback(
async (data: z.infer<typeof autoRefillSchema>) => {
setIsLoading(true);
await updateAutoTopUpConfig(data.refillAmount * 100, data.threshold * 100)
.then(() => {
toast({ title: "Auto top-up config updated! 🎉" });
})
.catch(toastOnFail("update auto top-up config"));
setIsLoading(false);
},
[updateAutoTopUpConfig, toast, toastOnFail],
);
return (
<div className="mx-1 border-b border-zinc-300">
<p className="mx-0 mt-4 font-sans text-xs font-medium text-violet-700">
Add credits to your balance
</p>
<p className="mx-0 my-1 font-sans text-xs font-normal text-zinc-500">
Choose a one-time top-up or set up automatic refills
</p>
<Tabs
defaultValue="top-up"
className="mb-6 mt-4 flex w-full flex-col items-center"
>
<TabsList className="mx-auto">
<TabsTrigger value="top-up">One-time top up</TabsTrigger>
<TabsTrigger value="auto-refill">Auto-refill</TabsTrigger>
</TabsList>
<div className="mt-4 w-full rounded-lg px-5 outline outline-1 outline-offset-2 outline-zinc-200">
<TabsContent value="top-up" className="flex flex-col">
<div className="mt-2 justify-start font-sans text-sm font-medium leading-snug text-zinc-900">
One-time top-up
</div>
<div className="mt-1 justify-start font-sans text-xs font-normal leading-tight text-zinc-500">
Enter an amount (min. $5) and add credits instantly.
</div>
<Form {...topUpForm}>
<form onSubmit={topUpForm.handleSubmit(submitTopUp)}>
<FormField
control={topUpForm.control}
name="amount"
render={({ field }) => (
<FormItem className="mb-6 mt-4">
<FormLabel className="font-sans text-sm font-medium leading-snug text-zinc-800">
Amount
</FormLabel>
<FormControl>
<>
<Input
className={cn(
"mt-2 rounded-3xl border-0 bg-white py-2 pl-6 pr-4 font-sans outline outline-1 outline-zinc-300",
"focus:outline-2 focus:outline-offset-0 focus:outline-violet-700",
)}
type="number"
step="1"
{...field}
/>
<span className="absolute left-10 -translate-y-9 text-sm text-zinc-500">
$
</span>
</>
</FormControl>
<FormMessage className="mt-2 font-sans text-xs font-normal leading-tight" />
</FormItem>
)}
/>
<button
className={cn(
"mb-2 inline-flex h-10 w-24 items-center justify-center rounded-3xl bg-zinc-800 px-4 py-2",
"font-sans text-sm font-medium leading-snug text-white",
"transition-colors duration-200 hover:bg-zinc-700 disabled:bg-zinc-500",
)}
type="submit"
disabled={isLoading}
>
Top up
</button>
</form>
</Form>
</TabsContent>
<TabsContent value="auto-refill" className="flex flex-col">
<div className="justify-start font-sans text-sm font-medium leading-snug text-zinc-900">
Auto-refill
</div>
<div className="mt-1 justify-start font-sans text-xs font-normal leading-tight text-zinc-500">
Choose a one-time top-up or set up automatic refills.
</div>
<Form {...autoRefillForm}>
<form
onSubmit={autoRefillForm.handleSubmit(submitAutoTopUpConfig)}
>
<FormField
control={autoRefillForm.control}
name="threshold"
render={({ field }) => (
<FormItem className="mb-6 mt-4">
<FormLabel className="font-sans text-sm font-medium leading-snug text-zinc-800">
Refill when balance drops below:
</FormLabel>
<FormControl>
<>
<Input
className={cn(
"mt-2 rounded-3xl border-0 bg-white py-2 pl-6 pr-4 font-sans outline outline-1 outline-zinc-300",
"focus:outline-2 focus:outline-offset-0 focus:outline-violet-700",
)}
type="number"
step="1"
{...field}
/>
<span className="absolute left-10 -translate-y-9 text-sm text-zinc-500">
$
</span>
</>
</FormControl>
<FormMessage className="mt-2 font-sans text-xs font-normal leading-tight" />
</FormItem>
)}
/>
<FormField
control={autoRefillForm.control}
name="refillAmount"
render={({ field }) => (
<FormItem className="mb-6">
<FormLabel className="font-sans text-sm font-medium leading-snug text-zinc-800">
Add this amount:
</FormLabel>
<FormControl>
<>
<Input
className={cn(
"mt-2 rounded-3xl border-0 bg-white py-2 pl-6 pr-4 font-sans outline outline-1 outline-zinc-300",
"focus:outline-2 focus:outline-offset-0 focus:outline-violet-700",
)}
type="number"
step="1"
{...field}
/>
<span className="absolute left-10 -translate-y-9 text-sm text-zinc-500">
$
</span>
</>
</FormControl>
<FormMessage className="mt-2 font-sans text-xs font-normal leading-tight" />
</FormItem>
)}
/>
<button
className={cn(
"mb-4 inline-flex h-10 w-40 items-center justify-center rounded-3xl bg-zinc-800 px-4 py-2",
"font-sans text-sm font-medium leading-snug text-white",
"transition-colors duration-200 hover:bg-zinc-700 disabled:bg-zinc-500",
)}
type="submit"
disabled={isLoading}
>
Enable Auto-refill
</button>
</form>
</Form>
</TabsContent>
<div className="mb-3 justify-start font-sans text-xs font-normal leading-tight">
<span className="text-zinc-500">
To update your billing details, head to{" "}
</span>
<Link
href="/profile/credits"
className="cursor-pointer text-zinc-800 underline"
>
Billing settings
</Link>
</div>
</div>
</Tabs>
</div>
);
}

View File

@@ -25,12 +25,14 @@ interface AgentsSectionProps {
sectionTitle: string;
agents: Agent[];
hideAvatars?: boolean;
margin?: string;
}
export const AgentsSection: React.FC<AgentsSectionProps> = ({
sectionTitle,
agents: allAgents,
hideAvatars = false,
margin = "37px",
}) => {
const router = useRouter();
@@ -44,9 +46,11 @@ export const AgentsSection: React.FC<AgentsSectionProps> = ({
};
return (
<div className="flex flex-col items-center justify-center pb-4 lg:pb-8">
<div className="flex flex-col items-center justify-center">
<div className="w-full max-w-[1360px]">
<div className="mb-8 font-poppins text-[18px] font-[600] leading-9 text-neutral-800 dark:text-neutral-200">
<div
className={`mb-[${margin}] font-poppins text-lg font-semibold text-[#282828] dark:text-neutral-200`}
>
{sectionTitle}
</div>
{!displayedAgents || displayedAgents.length === 0 ? (
@@ -81,7 +85,7 @@ export const AgentsSection: React.FC<AgentsSectionProps> = ({
</CarouselContent>
</Carousel>
<div className="hidden grid-cols-1 place-items-center gap-6 md:grid md:grid-cols-2 lg:grid-cols-4">
<div className="hidden grid-cols-1 place-items-center gap-6 md:grid md:grid-cols-2 lg:grid-cols-3 2xl:grid-cols-4">
{displayedAgents.map((agent, index) => (
<StoreCard
key={index}

View File

@@ -31,9 +31,9 @@ export const FeaturedCreators: React.FC<FeaturedCreatorsProps> = ({
const displayedCreators = featuredCreators.slice(0, 4);
return (
<div className="flex w-full flex-col items-center justify-center py-16">
<div className="flex w-full flex-col items-center justify-center">
<div className="w-full max-w-[1360px]">
<h2 className="mb-8 text-left font-poppins text-[18px] font-[600] leading-9 text-neutral-800 dark:text-neutral-200">
<h2 className="mb-9 font-poppins text-lg font-semibold text-neutral-800 dark:text-neutral-200">
{title}
</h2>

View File

@@ -46,8 +46,8 @@ export const FeaturedSection: React.FC<FeaturedSectionProps> = ({
};
return (
<section className="mx-auto w-full max-w-7xl px-4 pb-16">
<h2 className="mb-8 text-left font-poppins text-[18px] font-[600] leading-9 text-neutral-800 dark:text-neutral-200">
<section className="w-full">
<h2 className="mb-8 font-poppins text-2xl font-semibold leading-7 text-neutral-800 dark:text-neutral-200">
Featured agents
</h2>

View File

@@ -4,9 +4,16 @@ import * as React from "react";
import { SearchBar } from "@/components/agptui/SearchBar";
import { FilterChips } from "@/components/agptui/FilterChips";
import { useRouter } from "next/navigation";
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
export const HeroSection: React.FC = () => {
const router = useRouter();
const { completeStep } = useOnboarding();
// Mark marketplace visit task as completed
React.useEffect(() => {
completeStep("MARKETPLACE_VISIT");
}, [completeStep]);
function onFilterChange(selectedFilters: string[]) {
const encodedTerm = encodeURIComponent(selectedFilters.join(", "));

View File

@@ -147,13 +147,3 @@
.custom-switch {
padding-left: 2px;
}
input[type="number"]::-webkit-outer-spin-button,
input[type="number"]::-webkit-inner-spin-button {
-webkit-appearance: none;
margin: 0;
}
input[type="number"] {
-moz-appearance: textfield;
}

View File

@@ -1,5 +1,6 @@
import { FC, useEffect, useMemo, useState } from "react";
import { z } from "zod";
import { beautifyString, cn } from "@/lib/utils";
import { cn } from "@/lib/utils";
import { useForm } from "react-hook-form";
import { Input } from "@/components/ui/input";
import { Button } from "@/components/ui/button";
@@ -16,8 +17,8 @@ import {
FaKey,
FaHubspot,
} from "react-icons/fa";
import { FC, useMemo, useState } from "react";
import {
BlockIOCredentialsSubSchema,
CredentialsMetaInput,
CredentialsProviderName,
} from "@/lib/autogpt-server-api/types";
@@ -106,13 +107,18 @@ export type OAuthPopupResultMessage = { message_type: "oauth_popup_result" } & (
);
export const CredentialsInput: FC<{
selfKey: string;
schema: BlockIOCredentialsSubSchema;
className?: string;
selectedCredentials?: CredentialsMetaInput;
onSelectCredentials: (newValue?: CredentialsMetaInput) => void;
}> = ({ selfKey, className, selectedCredentials, onSelectCredentials }) => {
const api = useBackendAPI();
const credentials = useCredentials(selfKey);
siblingInputs?: Record<string, any>;
}> = ({
schema,
className,
selectedCredentials,
onSelectCredentials,
siblingInputs,
}) => {
const [isAPICredentialsModalOpen, setAPICredentialsModalOpen] =
useState(false);
const [
@@ -124,20 +130,47 @@ export const CredentialsInput: FC<{
useState<AbortController | null>(null);
const [oAuthError, setOAuthError] = useState<string | null>(null);
if (!credentials || credentials.isLoading) {
const api = useBackendAPI();
const credentials = useCredentials(schema, siblingInputs);
// Deselect credentials if they do not exist (e.g. provider was changed)
useEffect(() => {
if (!credentials || !("savedCredentials" in credentials)) return;
if (
selectedCredentials &&
!credentials.savedCredentials.some((c) => c.id === selectedCredentials.id)
) {
onSelectCredentials(undefined);
}
}, [credentials, selectedCredentials, onSelectCredentials]);
const singleCredential = useMemo(() => {
if (!credentials || !("savedCredentials" in credentials)) return null;
if (credentials.savedCredentials.length === 1)
return credentials.savedCredentials[0];
return null;
}, [credentials]);
// If only 1 credential is available, auto-select it and hide this input
useEffect(() => {
if (singleCredential && !selectedCredentials) {
onSelectCredentials(singleCredential);
}
}, [singleCredential, selectedCredentials, onSelectCredentials]);
if (!credentials || credentials.isLoading || singleCredential) {
return null;
}
const {
schema,
provider,
providerName,
supportsApiKey,
supportsOAuth2,
supportsUserPassword,
savedApiKeys,
savedOAuthCredentials,
savedUserPasswordCredentials,
savedCredentials,
oAuthCallback,
} = credentials;
@@ -235,13 +268,14 @@ export const CredentialsInput: FC<{
<>
{supportsApiKey && (
<APIKeyCredentialsModal
credentialsFieldName={selfKey}
schema={schema}
open={isAPICredentialsModalOpen}
onClose={() => setAPICredentialsModalOpen(false)}
onCredentialsCreate={(credsMeta) => {
onSelectCredentials(credsMeta);
setAPICredentialsModalOpen(false);
}}
siblingInputs={siblingInputs}
/>
)}
{supportsOAuth2 && (
@@ -253,43 +287,34 @@ export const CredentialsInput: FC<{
)}
{supportsUserPassword && (
<UserPasswordCredentialsModal
credentialsFieldName={selfKey}
schema={schema}
open={isUserPasswordCredentialsModalOpen}
onClose={() => setUserPasswordCredentialsModalOpen(false)}
onCredentialsCreate={(creds) => {
onSelectCredentials(creds);
setUserPasswordCredentialsModalOpen(false);
}}
siblingInputs={siblingInputs}
/>
)}
</>
);
// Deselect credentials if they do not exist (e.g. provider was changed)
if (
selectedCredentials &&
!savedApiKeys
.concat(savedOAuthCredentials)
.concat(savedUserPasswordCredentials)
.some((c) => c.id === selectedCredentials.id)
) {
onSelectCredentials(undefined);
}
const fieldHeader = (
<div className="mb-2 flex gap-1">
<span className="text-m green text-gray-900">
{providerName} Credentials
</span>
<SchemaTooltip description={schema.description} />
</div>
);
// No saved credentials yet
if (
savedApiKeys.length === 0 &&
savedOAuthCredentials.length === 0 &&
savedUserPasswordCredentials.length === 0
) {
if (savedCredentials.length === 0) {
return (
<>
<div className="mb-2 flex gap-1">
<span className="text-m green text-gray-900">
{providerName} Credentials
</span>
<SchemaTooltip description={schema.description} />
</div>
<div>
{fieldHeader}
<div className={cn("flex flex-row space-x-2", className)}>
{supportsOAuth2 && (
<Button onClick={handleOAuthLogin}>
@@ -314,46 +339,10 @@ export const CredentialsInput: FC<{
{oAuthError && (
<div className="mt-2 text-red-500">Error: {oAuthError}</div>
)}
</>
</div>
);
}
const getCredentialCounts = () => ({
apiKeys: savedApiKeys.length,
oauth: savedOAuthCredentials.length,
userPass: savedUserPasswordCredentials.length,
});
const getSingleCredential = () => {
const counts = getCredentialCounts();
const totalCredentials = Object.values(counts).reduce(
(sum, count) => sum + count,
0,
);
if (totalCredentials !== 1) return null;
if (counts.apiKeys === 1) return savedApiKeys[0];
if (counts.oauth === 1) return savedOAuthCredentials[0];
if (counts.userPass === 1) return savedUserPasswordCredentials[0];
return null;
};
const singleCredential = getSingleCredential();
if (singleCredential) {
if (!selectedCredentials) {
onSelectCredentials({
id: singleCredential.id,
type: singleCredential.type,
provider,
title: singleCredential.title,
});
}
return null;
}
function handleValueChange(newValue: string) {
if (newValue === "sign-in") {
// Trigger OAuth2 sign in flow
@@ -362,10 +351,7 @@ export const CredentialsInput: FC<{
// Open API key dialog
setAPICredentialsModalOpen(true);
} else {
const selectedCreds = savedApiKeys
.concat(savedOAuthCredentials)
.concat(savedUserPasswordCredentials)
.find((c) => c.id == newValue)!;
const selectedCreds = savedCredentials.find((c) => c.id == newValue)!;
onSelectCredentials({
id: selectedCreds.id,
@@ -378,38 +364,40 @@ export const CredentialsInput: FC<{
// Saved credentials exist
return (
<>
<div className="flex gap-1">
<span className="text-m green mb-0 text-gray-900">
{providerName} Credentials
</span>
<SchemaTooltip description={schema.description} />
</div>
<div>
{fieldHeader}
<Select value={selectedCredentials?.id} onValueChange={handleValueChange}>
<SelectTrigger>
<SelectValue placeholder={schema.placeholder} />
</SelectTrigger>
<SelectContent className="nodrag">
{savedOAuthCredentials.map((credentials, index) => (
<SelectItem key={index} value={credentials.id}>
<ProviderIcon className="mr-2 inline h-4 w-4" />
{credentials.username}
</SelectItem>
))}
{savedApiKeys.map((credentials, index) => (
<SelectItem key={index} value={credentials.id}>
<ProviderIcon className="mr-2 inline h-4 w-4" />
<IconKey className="mr-1.5 inline" />
{credentials.title}
</SelectItem>
))}
{savedUserPasswordCredentials.map((credentials, index) => (
<SelectItem key={index} value={credentials.id}>
<ProviderIcon className="mr-2 inline h-4 w-4" />
<IconUserPlus className="mr-1.5 inline" />
{credentials.title}
</SelectItem>
))}
{savedCredentials
.filter((c) => c.type == "oauth2")
.map((credentials, index) => (
<SelectItem key={index} value={credentials.id}>
<ProviderIcon className="mr-2 inline h-4 w-4" />
{credentials.username}
</SelectItem>
))}
{savedCredentials
.filter((c) => c.type == "api_key")
.map((credentials, index) => (
<SelectItem key={index} value={credentials.id}>
<ProviderIcon className="mr-2 inline h-4 w-4" />
<IconKey className="mr-1.5 inline" />
{credentials.title}
</SelectItem>
))}
{savedCredentials
.filter((c) => c.type == "user_password")
.map((credentials, index) => (
<SelectItem key={index} value={credentials.id}>
<ProviderIcon className="mr-2 inline h-4 w-4" />
<IconUserPlus className="mr-1.5 inline" />
{credentials.title}
</SelectItem>
))}
<SelectSeparator />
{supportsOAuth2 && (
<SelectItem value="sign-in">
@@ -435,17 +423,18 @@ export const CredentialsInput: FC<{
{oAuthError && (
<div className="mt-2 text-red-500">Error: {oAuthError}</div>
)}
</>
</div>
);
};
export const APIKeyCredentialsModal: FC<{
credentialsFieldName: string;
schema: BlockIOCredentialsSubSchema;
open: boolean;
onClose: () => void;
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
}> = ({ credentialsFieldName, open, onClose, onCredentialsCreate }) => {
const credentials = useCredentials(credentialsFieldName);
siblingInputs?: Record<string, any>;
}> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
const credentials = useCredentials(schema, siblingInputs);
const formSchema = z.object({
apiKey: z.string().min(1, "API Key is required"),
@@ -466,8 +455,7 @@ export const APIKeyCredentialsModal: FC<{
return null;
}
const { schema, provider, providerName, createAPIKeyCredentials } =
credentials;
const { provider, providerName, createAPIKeyCredentials } = credentials;
async function onSubmit(values: z.infer<typeof formSchema>) {
const expiresAt = values.expiresAt
@@ -576,12 +564,13 @@ export const APIKeyCredentialsModal: FC<{
};
export const UserPasswordCredentialsModal: FC<{
credentialsFieldName: string;
schema: BlockIOCredentialsSubSchema;
open: boolean;
onClose: () => void;
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
}> = ({ credentialsFieldName, open, onClose, onCredentialsCreate }) => {
const credentials = useCredentials(credentialsFieldName);
siblingInputs?: Record<string, any>;
}> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
const credentials = useCredentials(schema, siblingInputs);
const formSchema = z.object({
username: z.string().min(1, "Username is required"),
@@ -606,8 +595,7 @@ export const UserPasswordCredentialsModal: FC<{
return null;
}
const { schema, provider, providerName, createUserPasswordCredentials } =
credentials;
const { provider, providerName, createUserPasswordCredentials } = credentials;
async function onSubmit(values: z.infer<typeof formSchema>) {
const newCredentials = await createUserPasswordCredentials({

View File

@@ -68,9 +68,7 @@ type UserPasswordCredentialsCreatable = Omit<
export type CredentialsProviderData = {
provider: CredentialsProviderName;
providerName: string;
savedApiKeys: CredentialsMetaResponse[];
savedOAuthCredentials: CredentialsMetaResponse[];
savedUserPasswordCredentials: CredentialsMetaResponse[];
savedCredentials: CredentialsMetaResponse[];
oAuthCallback: (
code: string,
state_token: string,
@@ -113,28 +111,12 @@ export default function CredentialsProvider({
setProviders((prev) => {
if (!prev || !prev[provider]) return prev;
const updatedProvider = { ...prev[provider] };
if (credentials.type === "api_key") {
updatedProvider.savedApiKeys = [
...updatedProvider.savedApiKeys,
credentials,
];
} else if (credentials.type === "oauth2") {
updatedProvider.savedOAuthCredentials = [
...updatedProvider.savedOAuthCredentials,
credentials,
];
} else if (credentials.type === "user_password") {
updatedProvider.savedUserPasswordCredentials = [
...updatedProvider.savedUserPasswordCredentials,
credentials,
];
}
return {
...prev,
[provider]: updatedProvider,
[provider]: {
...prev[provider],
savedCredentials: [...prev[provider].savedCredentials, credentials],
},
};
});
},
@@ -203,21 +185,14 @@ export default function CredentialsProvider({
setProviders((prev) => {
if (!prev || !prev[provider]) return prev;
const updatedProvider = { ...prev[provider] };
updatedProvider.savedApiKeys = updatedProvider.savedApiKeys.filter(
(cred) => cred.id !== id,
);
updatedProvider.savedOAuthCredentials =
updatedProvider.savedOAuthCredentials.filter(
(cred) => cred.id !== id,
);
updatedProvider.savedUserPasswordCredentials =
updatedProvider.savedUserPasswordCredentials.filter(
(cred) => cred.id !== id,
);
return {
...prev,
[provider]: updatedProvider,
[provider]: {
...prev[provider],
savedCredentials: prev[provider].savedCredentials.filter(
(cred) => cred.id !== id,
),
},
};
});
return result;
@@ -233,29 +208,12 @@ export default function CredentialsProvider({
const credentialsByProvider = response.reduce(
(acc, cred) => {
if (!acc[cred.provider]) {
acc[cred.provider] = {
oauthCreds: [],
apiKeys: [],
userPasswordCreds: [],
};
}
if (cred.type === "oauth2") {
acc[cred.provider].oauthCreds.push(cred);
} else if (cred.type === "api_key") {
acc[cred.provider].apiKeys.push(cred);
} else if (cred.type === "user_password") {
acc[cred.provider].userPasswordCreds.push(cred);
acc[cred.provider] = [];
}
acc[cred.provider].push(cred);
return acc;
},
{} as Record<
CredentialsProviderName,
{
oauthCreds: CredentialsMetaResponse[];
apiKeys: CredentialsMetaResponse[];
userPasswordCreds: CredentialsMetaResponse[];
}
>,
{} as Record<CredentialsProviderName, CredentialsMetaResponse[]>,
);
setProviders((prev) => ({
@@ -265,40 +223,19 @@ export default function CredentialsProvider({
provider,
{
provider,
providerName:
providerDisplayNames[provider as CredentialsProviderName],
savedApiKeys: credentialsByProvider[provider]?.apiKeys ?? [],
savedOAuthCredentials:
credentialsByProvider[provider]?.oauthCreds ?? [],
savedUserPasswordCredentials:
credentialsByProvider[provider]?.userPasswordCreds ?? [],
providerName: providerDisplayNames[provider],
savedCredentials: credentialsByProvider[provider] ?? [],
oAuthCallback: (code: string, state_token: string) =>
oAuthCallback(
provider as CredentialsProviderName,
code,
state_token,
),
oAuthCallback(provider, code, state_token),
createAPIKeyCredentials: (
credentials: APIKeyCredentialsCreatable,
) =>
createAPIKeyCredentials(
provider as CredentialsProviderName,
credentials,
),
) => createAPIKeyCredentials(provider, credentials),
createUserPasswordCredentials: (
credentials: UserPasswordCredentialsCreatable,
) =>
createUserPasswordCredentials(
provider as CredentialsProviderName,
credentials,
),
) => createUserPasswordCredentials(provider, credentials),
deleteCredentials: (id: string, force: boolean = false) =>
deleteCredentials(
provider as CredentialsProviderName,
id,
force,
),
},
deleteCredentials(provider, id, force),
} satisfies CredentialsProviderData,
]),
),
}));

View File

@@ -8,7 +8,7 @@ export default function LibraryAgentCard({
id,
name,
description,
agent_id,
graph_id: agent_id,
can_access_graph,
creator_image_url,
image_url,
@@ -48,7 +48,7 @@ export default function LibraryAgentCard({
/>
)}
<div className="absolute bottom-4 left-4">
<Avatar className="h-16 w-16 border-2 border-white dark:border-gray-800">
<Avatar className="h-16 w-16">
<AvatarImage
src={
creator_image_url
@@ -57,7 +57,7 @@ export default function LibraryAgentCard({
}
alt={`${name} creator avatar`}
/>
<AvatarFallback>{name.charAt(0)}</AvatarFallback>
<AvatarFallback size={64}>{name.charAt(0)}</AvatarFallback>
</Avatar>
</div>
</Link>

View File

@@ -49,7 +49,7 @@ export default function LibrarySearchBar(): React.ReactNode {
onFocus={() => setIsFocused(true)}
onBlur={() => !inputRef.current?.value && setIsFocused(false)}
onChange={handleSearchInput}
className="flex-1 border-none font-sans text-[16px] font-normal leading-7 shadow-none focus:shadow-none"
className="flex-1 border-none font-sans text-[16px] font-normal leading-7 shadow-none focus:shadow-none focus:ring-0"
type="text"
placeholder="Search agents"
/>

View File

@@ -109,7 +109,7 @@ export const AgentFlowList = ({
lastRun: GraphExecutionMeta | null = null;
if (executions) {
const _flowRuns = executions.filter(
(r) => r.graph_id == flow.agent_id,
(r) => r.graph_id == flow.graph_id,
);
runCount = _flowRuns.length;
lastRun =

Some files were not shown because too many files have changed in this diff Show More