mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Merge branch 'dev' into abhi-9708/add-better-skeleton-on-agent-run-page
This commit is contained in:
@@ -34,6 +34,7 @@ jobs:
|
||||
python -m prisma migrate deploy
|
||||
env:
|
||||
DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
|
||||
DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}
|
||||
|
||||
|
||||
trigger:
|
||||
|
||||
@@ -36,6 +36,7 @@ jobs:
|
||||
python -m prisma migrate deploy
|
||||
env:
|
||||
DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
|
||||
DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}
|
||||
|
||||
trigger:
|
||||
needs: migrate
|
||||
|
||||
12
.github/workflows/platform-backend-ci.yml
vendored
12
.github/workflows/platform-backend-ci.yml
vendored
@@ -135,6 +135,7 @@ jobs:
|
||||
run: poetry run prisma migrate dev --name updates
|
||||
env:
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
|
||||
- id: lint
|
||||
name: Run Linter
|
||||
@@ -151,12 +152,13 @@ jobs:
|
||||
env:
|
||||
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
|
||||
SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
|
||||
SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
|
||||
REDIS_HOST: 'localhost'
|
||||
REDIS_PORT: '6379'
|
||||
REDIS_PASSWORD: 'testpassword'
|
||||
REDIS_HOST: "localhost"
|
||||
REDIS_PORT: "6379"
|
||||
REDIS_PASSWORD: "testpassword"
|
||||
|
||||
env:
|
||||
CI: true
|
||||
@@ -169,8 +171,8 @@ jobs:
|
||||
# If you want to replace this, you can do so by making our entire system generate
|
||||
# new credentials for each local user and update the environment variables in
|
||||
# the backend service, docker composes, and examples
|
||||
RABBITMQ_DEFAULT_USER: 'rabbitmq_user_default'
|
||||
RABBITMQ_DEFAULT_PASS: 'k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7'
|
||||
RABBITMQ_DEFAULT_USER: "rabbitmq_user_default"
|
||||
RABBITMQ_DEFAULT_PASS: "k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7"
|
||||
|
||||
# - name: Upload coverage reports to Codecov
|
||||
# uses: codecov/codecov-action@v4
|
||||
|
||||
@@ -1,20 +1,59 @@
|
||||
import inspect
|
||||
import threading
|
||||
from typing import Callable, ParamSpec, TypeVar
|
||||
from typing import Awaitable, Callable, ParamSpec, TypeVar, cast, overload
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
|
||||
@overload
|
||||
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def thread_cached(func: Callable[P, R]) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
def thread_cached(
|
||||
func: Callable[P, R] | Callable[P, Awaitable[R]],
|
||||
) -> Callable[P, R] | Callable[P, Awaitable[R]]:
|
||||
thread_local = threading.local()
|
||||
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = (args, tuple(sorted(kwargs.items())))
|
||||
if key not in cache:
|
||||
cache[key] = func(*args, **kwargs)
|
||||
return cache[key]
|
||||
def _clear():
|
||||
if hasattr(thread_local, "cache"):
|
||||
del thread_local.cache
|
||||
|
||||
return wrapper
|
||||
if inspect.iscoroutinefunction(func):
|
||||
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = (args, tuple(sorted(kwargs.items())))
|
||||
if key not in cache:
|
||||
cache[key] = await cast(Callable[P, Awaitable[R]], func)(
|
||||
*args, **kwargs
|
||||
)
|
||||
return cache[key]
|
||||
|
||||
setattr(async_wrapper, "clear_cache", _clear)
|
||||
return async_wrapper
|
||||
|
||||
else:
|
||||
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = (args, tuple(sorted(kwargs.items())))
|
||||
if key not in cache:
|
||||
cache[key] = func(*args, **kwargs)
|
||||
return cache[key]
|
||||
|
||||
setattr(sync_wrapper, "clear_cache", _clear)
|
||||
return sync_wrapper
|
||||
|
||||
|
||||
def clear_thread_cache(func: Callable) -> None:
|
||||
if clear := getattr(func, "clear_cache", None):
|
||||
clear()
|
||||
|
||||
@@ -8,6 +8,7 @@ DB_CONNECT_TIMEOUT=60
|
||||
DB_POOL_TIMEOUT=300
|
||||
DB_SCHEMA=platform
|
||||
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
|
||||
DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
|
||||
PRISMA_SCHEMA="postgres/schema.prisma"
|
||||
|
||||
# EXECUTOR
|
||||
|
||||
@@ -73,7 +73,6 @@ FROM server_dependencies AS server
|
||||
COPY autogpt_platform/backend /app/autogpt_platform/backend
|
||||
RUN poetry install --no-ansi --only-root
|
||||
|
||||
ENV DATABASE_URL=""
|
||||
ENV PORT=8000
|
||||
|
||||
CMD ["poetry", "run", "rest"]
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
@@ -19,21 +17,6 @@ from backend.util import json
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_executor_manager_client():
|
||||
from backend.executor import ExecutionManager
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
return get_service_client(ExecutionManager)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_event_bus():
|
||||
from backend.data.execution import RedisExecutionEventBus
|
||||
|
||||
return RedisExecutionEventBus()
|
||||
|
||||
|
||||
class AgentExecutorBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
user_id: str = SchemaField(description="User ID")
|
||||
@@ -76,23 +59,23 @@ class AgentExecutorBlock(Block):
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
from backend.data.execution import ExecutionEventType
|
||||
from backend.executor import utils as execution_utils
|
||||
|
||||
executor_manager = get_executor_manager_client()
|
||||
event_bus = get_event_bus()
|
||||
event_bus = execution_utils.get_execution_event_bus()
|
||||
|
||||
graph_exec = executor_manager.add_execution(
|
||||
graph_exec = execution_utils.add_graph_execution(
|
||||
graph_id=input_data.graph_id,
|
||||
graph_version=input_data.graph_version,
|
||||
user_id=input_data.user_id,
|
||||
data=input_data.data,
|
||||
inputs=input_data.data,
|
||||
)
|
||||
log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.graph_exec_id}"
|
||||
log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.id}"
|
||||
logger.info(f"Starting execution of {log_id}")
|
||||
|
||||
for event in event_bus.listen(
|
||||
user_id=graph_exec.user_id,
|
||||
graph_id=graph_exec.graph_id,
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
graph_exec_id=graph_exec.id,
|
||||
):
|
||||
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
|
||||
if event.status in [
|
||||
@@ -105,7 +88,7 @@ class AgentExecutorBlock(Block):
|
||||
else:
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
|
||||
)
|
||||
|
||||
@@ -123,5 +106,7 @@ class AgentExecutorBlock(Block):
|
||||
continue
|
||||
|
||||
for output_data in event.output_data.get("output", []):
|
||||
logger.info(f"Execution {log_id} produced {output_name}: {output_data}")
|
||||
logger.debug(
|
||||
f"Execution {log_id} produced {output_name}: {output_data}"
|
||||
)
|
||||
yield output_name, output_data
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import Any, Iterable, List, Literal, NamedTuple, Optional
|
||||
import anthropic
|
||||
import ollama
|
||||
import openai
|
||||
from anthropic import NotGiven
|
||||
from anthropic.types import ToolParam
|
||||
from groq import Groq
|
||||
from pydantic import BaseModel, SecretStr
|
||||
@@ -249,7 +248,7 @@ class LLMResponse(BaseModel):
|
||||
|
||||
def convert_openai_tool_fmt_to_anthropic(
|
||||
openai_tools: list[dict] | None = None,
|
||||
) -> Iterable[ToolParam] | NotGiven:
|
||||
) -> Iterable[ToolParam] | anthropic.NotGiven:
|
||||
"""
|
||||
Convert OpenAI tool format to Anthropic tool format.
|
||||
"""
|
||||
@@ -287,6 +286,7 @@ def llm_call(
|
||||
max_tokens: int | None,
|
||||
tools: list[dict] | None = None,
|
||||
ollama_host: str = "localhost:11434",
|
||||
parallel_tool_calls: bool | None = None,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Make a call to a language model.
|
||||
@@ -332,6 +332,9 @@ def llm_call(
|
||||
response_format=response_format, # type: ignore
|
||||
max_completion_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
parallel_tool_calls=(
|
||||
openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
|
||||
),
|
||||
)
|
||||
|
||||
if response.choices[0].message.tool_calls:
|
||||
@@ -487,6 +490,9 @@ def llm_call(
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
parallel_tool_calls=(
|
||||
openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
|
||||
),
|
||||
)
|
||||
|
||||
# If there's no response, raise an error
|
||||
|
||||
@@ -491,6 +491,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
max_tokens=input_data.max_tokens,
|
||||
tools=tool_functions,
|
||||
ollama_host=input_data.ollama_host,
|
||||
parallel_tool_calls=False,
|
||||
)
|
||||
|
||||
if not response.tool_calls:
|
||||
|
||||
@@ -28,6 +28,7 @@ from backend.util.settings import Config
|
||||
from .model import (
|
||||
ContributorDetails,
|
||||
Credentials,
|
||||
CredentialsFieldInfo,
|
||||
CredentialsMetaInput,
|
||||
is_credentials_field_name,
|
||||
)
|
||||
@@ -203,6 +204,15 @@ class BlockSchema(BaseModel):
|
||||
)
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_credentials_fields_info(cls) -> dict[str, CredentialsFieldInfo]:
|
||||
return {
|
||||
field_name: CredentialsFieldInfo.model_validate(
|
||||
cls.get_field_schema(field_name), by_alias=True
|
||||
)
|
||||
for field_name in cls.get_credentials_fields().keys()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
|
||||
return data # Return as is, by default.
|
||||
@@ -509,6 +519,7 @@ async def initialize_blocks() -> None:
|
||||
)
|
||||
|
||||
|
||||
def get_block(block_id: str) -> Block | None:
|
||||
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
|
||||
def get_block(block_id: str) -> Block[BlockSchema, BlockSchema] | None:
|
||||
cls = get_blocks().get(block_id)
|
||||
return cls() if cls else None
|
||||
|
||||
@@ -11,6 +11,7 @@ from prisma.enums import (
|
||||
CreditRefundRequestStatus,
|
||||
CreditTransactionType,
|
||||
NotificationType,
|
||||
OnboardingStep,
|
||||
)
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditRefundRequest, CreditTransaction, User
|
||||
@@ -121,6 +122,18 @@ class UserCreditBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
|
||||
"""
|
||||
Reward the user with credits for completing an onboarding step.
|
||||
Won't reward if the user has already received credits for the step.
|
||||
|
||||
Args:
|
||||
user_id (str): The user ID.
|
||||
step (OnboardingStep): The onboarding step.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def top_up_intent(self, user_id: str, amount: int) -> str:
|
||||
"""
|
||||
@@ -213,7 +226,7 @@ class UserCreditBase(ABC):
|
||||
"userId": user_id,
|
||||
"createdAt": {"lte": top_time},
|
||||
"isActive": True,
|
||||
"runningBalance": {"not": None}, # type: ignore
|
||||
"NOT": [{"runningBalance": None}],
|
||||
},
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
@@ -408,6 +421,24 @@ class UserCredit(UserCreditBase):
|
||||
async def top_up_credits(self, user_id: str, amount: int):
|
||||
await self._top_up_credits(user_id, amount)
|
||||
|
||||
async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
|
||||
key = f"REWARD-{user_id}-{step.value}"
|
||||
if not await CreditTransaction.prisma().find_first(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"transactionKey": key,
|
||||
}
|
||||
):
|
||||
await self._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=credits,
|
||||
transaction_type=CreditTransactionType.GRANT,
|
||||
transaction_key=key,
|
||||
metadata=Json(
|
||||
{"reason": f"Reward for completing {step.value} onboarding step."}
|
||||
),
|
||||
)
|
||||
|
||||
async def top_up_refund(
|
||||
self, user_id: str, transaction_key: str, metadata: dict[str, str]
|
||||
) -> int:
|
||||
@@ -895,6 +926,9 @@ class DisabledUserCredit(UserCreditBase):
|
||||
async def top_up_credits(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def onboarding_reward(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def top_up_intent(self, *args, **kwargs) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
@@ -62,10 +62,10 @@ async def connect():
|
||||
|
||||
# Connection acquired from a pool like Supabase somehow still possibly allows
|
||||
# the db client obtains a connection but still reject query connection afterward.
|
||||
try:
|
||||
await prisma.execute_raw("SELECT 1")
|
||||
except Exception as e:
|
||||
raise ConnectionError("Failed to connect to Prisma.") from e
|
||||
# try:
|
||||
# await prisma.execute_raw("SELECT 1")
|
||||
# except Exception as e:
|
||||
# raise ConnectionError("Failed to connect to Prisma.") from e
|
||||
|
||||
|
||||
@conn_retry("Prisma", "Releasing connection")
|
||||
@@ -89,7 +89,7 @@ async def transaction():
|
||||
async def locked_transaction(key: str):
|
||||
lock_key = zlib.crc32(key.encode("utf-8"))
|
||||
async with transaction() as tx:
|
||||
await tx.execute_raw(f"SELECT pg_advisory_xact_lock({lock_key})")
|
||||
await tx.execute_raw("SELECT pg_advisory_xact_lock($1)", lock_key)
|
||||
yield tx
|
||||
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ from prisma.models import (
|
||||
AgentNodeExecutionInputOutput,
|
||||
)
|
||||
from prisma.types import (
|
||||
AgentGraphExecutionCreateInput,
|
||||
AgentGraphExecutionWhereInput,
|
||||
AgentNodeExecutionCreateInput,
|
||||
AgentNodeExecutionInputOutputCreateInput,
|
||||
@@ -33,18 +34,17 @@ from pydantic import BaseModel
|
||||
from pydantic.fields import Field
|
||||
|
||||
from backend.server.v2.store.exceptions import DatabaseError
|
||||
from backend.util import mock
|
||||
from backend.util import type as type_utils
|
||||
from backend.util.settings import Config
|
||||
|
||||
from .block import BlockData, BlockInput, BlockType, CompletedBlockOutput, get_block
|
||||
from .block import BlockInput, BlockType, CompletedBlockOutput, get_block
|
||||
from .db import BaseDbModel
|
||||
from .includes import (
|
||||
EXECUTION_RESULT_INCLUDE,
|
||||
GRAPH_EXECUTION_INCLUDE,
|
||||
GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
)
|
||||
from .model import GraphExecutionStats, NodeExecutionStats
|
||||
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
|
||||
from .queue import AsyncRedisEventBus, RedisEventBus
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -121,7 +121,7 @@ class GraphExecution(GraphExecutionMeta):
|
||||
|
||||
@staticmethod
|
||||
def from_db(_graph_exec: AgentGraphExecution):
|
||||
if _graph_exec.AgentNodeExecutions is None:
|
||||
if _graph_exec.NodeExecutions is None:
|
||||
raise ValueError("Node executions must be included in query")
|
||||
|
||||
graph_exec = GraphExecutionMeta.from_db(_graph_exec)
|
||||
@@ -129,7 +129,7 @@ class GraphExecution(GraphExecutionMeta):
|
||||
complete_node_executions = sorted(
|
||||
[
|
||||
NodeExecutionResult.from_db(ne, _graph_exec.userId)
|
||||
for ne in _graph_exec.AgentNodeExecutions
|
||||
for ne in _graph_exec.NodeExecutions
|
||||
if ne.executionStatus != ExecutionStatus.INCOMPLETE
|
||||
],
|
||||
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
|
||||
@@ -181,7 +181,7 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
|
||||
@staticmethod
|
||||
def from_db(_graph_exec: AgentGraphExecution):
|
||||
if _graph_exec.AgentNodeExecutions is None:
|
||||
if _graph_exec.NodeExecutions is None:
|
||||
raise ValueError("Node executions must be included in query")
|
||||
|
||||
graph_exec_with_io = GraphExecution.from_db(_graph_exec)
|
||||
@@ -189,7 +189,7 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
node_executions = sorted(
|
||||
[
|
||||
NodeExecutionResult.from_db(ne, _graph_exec.userId)
|
||||
for ne in _graph_exec.AgentNodeExecutions
|
||||
for ne in _graph_exec.NodeExecutions
|
||||
],
|
||||
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
|
||||
)
|
||||
@@ -202,6 +202,27 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
node_executions=node_executions,
|
||||
)
|
||||
|
||||
def to_graph_execution_entry(self):
|
||||
return GraphExecutionEntry(
|
||||
user_id=self.user_id,
|
||||
graph_id=self.graph_id,
|
||||
graph_version=self.graph_version or 0,
|
||||
graph_exec_id=self.id,
|
||||
start_node_execs=[
|
||||
NodeExecutionEntry(
|
||||
user_id=self.user_id,
|
||||
graph_exec_id=node_exec.graph_exec_id,
|
||||
graph_id=node_exec.graph_id,
|
||||
node_exec_id=node_exec.node_exec_id,
|
||||
node_id=node_exec.node_id,
|
||||
block_id=node_exec.block_id,
|
||||
data=node_exec.input_data,
|
||||
)
|
||||
for node_exec in self.node_executions
|
||||
],
|
||||
node_credentials_input_map={}, # FIXME
|
||||
)
|
||||
|
||||
|
||||
class NodeExecutionResult(BaseModel):
|
||||
user_id: str
|
||||
@@ -220,21 +241,21 @@ class NodeExecutionResult(BaseModel):
|
||||
end_time: datetime | None
|
||||
|
||||
@staticmethod
|
||||
def from_db(execution: AgentNodeExecution, user_id: Optional[str] = None):
|
||||
if execution.executionData:
|
||||
def from_db(_node_exec: AgentNodeExecution, user_id: Optional[str] = None):
|
||||
if _node_exec.executionData:
|
||||
# Execution that has been queued for execution will persist its data.
|
||||
input_data = type_utils.convert(execution.executionData, dict[str, Any])
|
||||
input_data = type_utils.convert(_node_exec.executionData, dict[str, Any])
|
||||
else:
|
||||
# For incomplete execution, executionData will not be yet available.
|
||||
input_data: BlockInput = defaultdict()
|
||||
for data in execution.Input or []:
|
||||
for data in _node_exec.Input or []:
|
||||
input_data[data.name] = type_utils.convert(data.data, type[Any])
|
||||
|
||||
output_data: CompletedBlockOutput = defaultdict(list)
|
||||
for data in execution.Output or []:
|
||||
for data in _node_exec.Output or []:
|
||||
output_data[data.name].append(type_utils.convert(data.data, type[Any]))
|
||||
|
||||
graph_execution: AgentGraphExecution | None = execution.AgentGraphExecution
|
||||
graph_execution: AgentGraphExecution | None = _node_exec.GraphExecution
|
||||
if graph_execution:
|
||||
user_id = graph_execution.userId
|
||||
elif not user_id:
|
||||
@@ -246,17 +267,17 @@ class NodeExecutionResult(BaseModel):
|
||||
user_id=user_id,
|
||||
graph_id=graph_execution.agentGraphId if graph_execution else "",
|
||||
graph_version=graph_execution.agentGraphVersion if graph_execution else 0,
|
||||
graph_exec_id=execution.agentGraphExecutionId,
|
||||
block_id=execution.AgentNode.agentBlockId if execution.AgentNode else "",
|
||||
node_exec_id=execution.id,
|
||||
node_id=execution.agentNodeId,
|
||||
status=execution.executionStatus,
|
||||
graph_exec_id=_node_exec.agentGraphExecutionId,
|
||||
block_id=_node_exec.Node.agentBlockId if _node_exec.Node else "",
|
||||
node_exec_id=_node_exec.id,
|
||||
node_id=_node_exec.agentNodeId,
|
||||
status=_node_exec.executionStatus,
|
||||
input_data=input_data,
|
||||
output_data=output_data,
|
||||
add_time=execution.addedTime,
|
||||
queue_time=execution.queuedTime,
|
||||
start_time=execution.startedTime,
|
||||
end_time=execution.endedTime,
|
||||
add_time=_node_exec.addedTime,
|
||||
queue_time=_node_exec.queuedTime,
|
||||
start_time=_node_exec.startedTime,
|
||||
end_time=_node_exec.endedTime,
|
||||
)
|
||||
|
||||
|
||||
@@ -341,7 +362,7 @@ async def get_graph_execution(
|
||||
async def create_graph_execution(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
nodes_input: list[tuple[str, BlockInput]],
|
||||
starting_nodes_input: list[tuple[str, BlockInput]],
|
||||
user_id: str,
|
||||
preset_id: str | None = None,
|
||||
) -> GraphExecutionWithNodes:
|
||||
@@ -351,29 +372,29 @@ async def create_graph_execution(
|
||||
The id of the AgentGraphExecution and the list of ExecutionResult for each node.
|
||||
"""
|
||||
result = await AgentGraphExecution.prisma().create(
|
||||
data={
|
||||
"agentGraphId": graph_id,
|
||||
"agentGraphVersion": graph_version,
|
||||
"executionStatus": ExecutionStatus.QUEUED,
|
||||
"AgentNodeExecutions": {
|
||||
"create": [ # type: ignore
|
||||
{
|
||||
"agentNodeId": node_id,
|
||||
"executionStatus": ExecutionStatus.QUEUED,
|
||||
"queuedTime": datetime.now(tz=timezone.utc),
|
||||
"Input": {
|
||||
data=AgentGraphExecutionCreateInput(
|
||||
agentGraphId=graph_id,
|
||||
agentGraphVersion=graph_version,
|
||||
executionStatus=ExecutionStatus.QUEUED,
|
||||
NodeExecutions={
|
||||
"create": [
|
||||
AgentNodeExecutionCreateInput(
|
||||
agentNodeId=node_id,
|
||||
executionStatus=ExecutionStatus.QUEUED,
|
||||
queuedTime=datetime.now(tz=timezone.utc),
|
||||
Input={
|
||||
"create": [
|
||||
{"name": name, "data": Json(data)}
|
||||
for name, data in node_input.items()
|
||||
]
|
||||
},
|
||||
}
|
||||
for node_id, node_input in nodes_input
|
||||
)
|
||||
for node_id, node_input in starting_nodes_input
|
||||
]
|
||||
},
|
||||
"userId": user_id,
|
||||
"agentPresetId": preset_id,
|
||||
},
|
||||
userId=user_id,
|
||||
agentPresetId=preset_id,
|
||||
),
|
||||
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
)
|
||||
|
||||
@@ -468,19 +489,27 @@ async def upsert_execution_output(
|
||||
)
|
||||
|
||||
|
||||
async def update_graph_execution_start_time(graph_exec_id: str) -> GraphExecution:
|
||||
res = await AgentGraphExecution.prisma().update(
|
||||
where={"id": graph_exec_id},
|
||||
async def update_graph_execution_start_time(
|
||||
graph_exec_id: str,
|
||||
) -> GraphExecution | None:
|
||||
count = await AgentGraphExecution.prisma().update_many(
|
||||
where={
|
||||
"id": graph_exec_id,
|
||||
"executionStatus": ExecutionStatus.QUEUED,
|
||||
},
|
||||
data={
|
||||
"executionStatus": ExecutionStatus.RUNNING,
|
||||
"startedAt": datetime.now(tz=timezone.utc),
|
||||
},
|
||||
)
|
||||
if count == 0:
|
||||
return None
|
||||
|
||||
res = await AgentGraphExecution.prisma().find_unique(
|
||||
where={"id": graph_exec_id},
|
||||
include=GRAPH_EXECUTION_INCLUDE,
|
||||
)
|
||||
if not res:
|
||||
raise ValueError(f"Graph execution #{graph_exec_id} not found")
|
||||
|
||||
return GraphExecution.from_db(res)
|
||||
return GraphExecution.from_db(res) if res else None
|
||||
|
||||
|
||||
async def update_graph_execution_stats(
|
||||
@@ -491,7 +520,8 @@ async def update_graph_execution_stats(
|
||||
data = stats.model_dump() if stats else {}
|
||||
if isinstance(data.get("error"), Exception):
|
||||
data["error"] = str(data["error"])
|
||||
res = await AgentGraphExecution.prisma().update(
|
||||
|
||||
updated_count = await AgentGraphExecution.prisma().update_many(
|
||||
where={
|
||||
"id": graph_exec_id,
|
||||
"OR": [
|
||||
@@ -503,10 +533,15 @@ async def update_graph_execution_stats(
|
||||
"executionStatus": status,
|
||||
"stats": Json(data),
|
||||
},
|
||||
)
|
||||
if updated_count == 0:
|
||||
return None
|
||||
|
||||
graph_exec = await AgentGraphExecution.prisma().find_unique_or_raise(
|
||||
where={"id": graph_exec_id},
|
||||
include=GRAPH_EXECUTION_INCLUDE,
|
||||
)
|
||||
|
||||
return GraphExecution.from_db(res) if res else None
|
||||
return GraphExecution.from_db(graph_exec)
|
||||
|
||||
|
||||
async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionStats):
|
||||
@@ -600,7 +635,7 @@ async def get_node_execution_results(
|
||||
"agentGraphExecutionId": graph_exec_id,
|
||||
}
|
||||
if block_ids:
|
||||
where_clause["AgentNode"] = {"is": {"agentBlockId": {"in": block_ids}}}
|
||||
where_clause["Node"] = {"is": {"agentBlockId": {"in": block_ids}}}
|
||||
if statuses:
|
||||
where_clause["OR"] = [{"executionStatus": status} for status in statuses]
|
||||
|
||||
@@ -642,7 +677,7 @@ async def get_latest_node_execution(
|
||||
where={
|
||||
"agentNodeId": node_id,
|
||||
"agentGraphExecutionId": graph_eid,
|
||||
"executionStatus": {"not": ExecutionStatus.INCOMPLETE}, # type: ignore
|
||||
"NOT": [{"executionStatus": ExecutionStatus.INCOMPLETE}],
|
||||
},
|
||||
order=[
|
||||
{"queuedTime": "desc"},
|
||||
@@ -678,6 +713,7 @@ class GraphExecutionEntry(BaseModel):
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
start_node_execs: list["NodeExecutionEntry"]
|
||||
node_credentials_input_map: Optional[dict[str, dict[str, CredentialsMetaInput]]]
|
||||
|
||||
|
||||
class NodeExecutionEntry(BaseModel):
|
||||
@@ -710,144 +746,6 @@ class ExecutionQueue(Generic[T]):
|
||||
return self.queue.empty()
|
||||
|
||||
|
||||
# ------------------- Execution Utilities -------------------- #
|
||||
|
||||
|
||||
LIST_SPLIT = "_$_"
|
||||
DICT_SPLIT = "_#_"
|
||||
OBJC_SPLIT = "_@_"
|
||||
|
||||
|
||||
def parse_execution_output(output: BlockData, name: str) -> Any | None:
|
||||
"""
|
||||
Extracts partial output data by name from a given BlockData.
|
||||
|
||||
The function supports extracting data from lists, dictionaries, and objects
|
||||
using specific naming conventions:
|
||||
- For lists: <output_name>_$_<index>
|
||||
- For dictionaries: <output_name>_#_<key>
|
||||
- For objects: <output_name>_@_<attribute>
|
||||
|
||||
Args:
|
||||
output (BlockData): A tuple containing the output name and data.
|
||||
name (str): The name used to extract specific data from the output.
|
||||
|
||||
Returns:
|
||||
Any | None: The extracted data if found, otherwise None.
|
||||
|
||||
Examples:
|
||||
>>> output = ("result", [10, 20, 30])
|
||||
>>> parse_execution_output(output, "result_$_1")
|
||||
20
|
||||
|
||||
>>> output = ("config", {"key1": "value1", "key2": "value2"})
|
||||
>>> parse_execution_output(output, "config_#_key1")
|
||||
'value1'
|
||||
|
||||
>>> class Sample:
|
||||
... attr1 = "value1"
|
||||
... attr2 = "value2"
|
||||
>>> output = ("object", Sample())
|
||||
>>> parse_execution_output(output, "object_@_attr1")
|
||||
'value1'
|
||||
"""
|
||||
output_name, output_data = output
|
||||
|
||||
if name == output_name:
|
||||
return output_data
|
||||
|
||||
if name.startswith(f"{output_name}{LIST_SPLIT}"):
|
||||
index = int(name.split(LIST_SPLIT)[1])
|
||||
if not isinstance(output_data, list) or len(output_data) <= index:
|
||||
return None
|
||||
return output_data[int(name.split(LIST_SPLIT)[1])]
|
||||
|
||||
if name.startswith(f"{output_name}{DICT_SPLIT}"):
|
||||
index = name.split(DICT_SPLIT)[1]
|
||||
if not isinstance(output_data, dict) or index not in output_data:
|
||||
return None
|
||||
return output_data[index]
|
||||
|
||||
if name.startswith(f"{output_name}{OBJC_SPLIT}"):
|
||||
index = name.split(OBJC_SPLIT)[1]
|
||||
if isinstance(output_data, object) and hasattr(output_data, index):
|
||||
return getattr(output_data, index)
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def merge_execution_input(data: BlockInput) -> BlockInput:
|
||||
"""
|
||||
Merges dynamic input pins into a single list, dictionary, or object based on naming patterns.
|
||||
|
||||
This function processes input keys that follow specific patterns to merge them into a unified structure:
|
||||
- `<input_name>_$_<index>` for list inputs.
|
||||
- `<input_name>_#_<index>` for dictionary inputs.
|
||||
- `<input_name>_@_<index>` for object inputs.
|
||||
|
||||
Args:
|
||||
data (BlockInput): A dictionary containing input keys and their corresponding values.
|
||||
|
||||
Returns:
|
||||
BlockInput: A dictionary with merged inputs.
|
||||
|
||||
Raises:
|
||||
ValueError: If a list index is not an integer.
|
||||
|
||||
Examples:
|
||||
>>> data = {
|
||||
... "list_$_0": "a",
|
||||
... "list_$_1": "b",
|
||||
... "dict_#_key1": "value1",
|
||||
... "dict_#_key2": "value2",
|
||||
... "object_@_attr1": "value1",
|
||||
... "object_@_attr2": "value2"
|
||||
... }
|
||||
>>> merge_execution_input(data)
|
||||
{
|
||||
"list": ["a", "b"],
|
||||
"dict": {"key1": "value1", "key2": "value2"},
|
||||
"object": <MockObject attr1="value1" attr2="value2">
|
||||
}
|
||||
"""
|
||||
|
||||
# Merge all input with <input_name>_$_<index> into a single list.
|
||||
items = list(data.items())
|
||||
|
||||
for key, value in items:
|
||||
if LIST_SPLIT not in key:
|
||||
continue
|
||||
name, index = key.split(LIST_SPLIT)
|
||||
if not index.isdigit():
|
||||
raise ValueError(f"Invalid key: {key}, #{index} index must be an integer.")
|
||||
|
||||
data[name] = data.get(name, [])
|
||||
if int(index) >= len(data[name]):
|
||||
# Pad list with empty string on missing indices.
|
||||
data[name].extend([""] * (int(index) - len(data[name]) + 1))
|
||||
data[name][int(index)] = value
|
||||
|
||||
# Merge all input with <input_name>_#_<index> into a single dict.
|
||||
for key, value in items:
|
||||
if DICT_SPLIT not in key:
|
||||
continue
|
||||
name, index = key.split(DICT_SPLIT)
|
||||
data[name] = data.get(name, {})
|
||||
data[name][index] = value
|
||||
|
||||
# Merge all input with <input_name>_@_<index> into a single object.
|
||||
for key, value in items:
|
||||
if OBJC_SPLIT not in key:
|
||||
continue
|
||||
name, index = key.split(OBJC_SPLIT)
|
||||
if name not in data or not isinstance(data[name], object):
|
||||
data[name] = mock.MockObject()
|
||||
setattr(data[name], index, value)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# --------------------- Event Bus --------------------- #
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Any, Literal, Optional, Type
|
||||
from typing import Any, Literal, Optional, cast
|
||||
|
||||
import prisma
|
||||
from prisma import Json
|
||||
@@ -13,12 +13,19 @@ from prisma.types import (
|
||||
AgentNodeCreateInput,
|
||||
AgentNodeLinkCreateInput,
|
||||
)
|
||||
from pydantic import create_model
|
||||
from pydantic.fields import computed_field
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.db import prisma as db
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
CredentialsFieldInfo,
|
||||
CredentialsMetaInput,
|
||||
is_credentials_field_name,
|
||||
)
|
||||
from backend.util import type as type_utils
|
||||
|
||||
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
|
||||
@@ -190,14 +197,19 @@ class BaseGraph(BaseDbModel):
|
||||
)
|
||||
)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def credentials_input_schema(self) -> dict[str, Any]:
|
||||
return self._credentials_input_schema.jsonschema()
|
||||
|
||||
@staticmethod
|
||||
def _generate_schema(
|
||||
*props: tuple[Type[AgentInputBlock.Input] | Type[AgentOutputBlock.Input], dict],
|
||||
*props: tuple[type[AgentInputBlock.Input] | type[AgentOutputBlock.Input], dict],
|
||||
) -> dict[str, Any]:
|
||||
schema = []
|
||||
schema_fields: list[AgentInputBlock.Input | AgentOutputBlock.Input] = []
|
||||
for type_class, input_default in props:
|
||||
try:
|
||||
schema.append(type_class(**input_default))
|
||||
schema_fields.append(type_class(**input_default))
|
||||
except Exception as e:
|
||||
logger.warning(f"Invalid {type_class}: {input_default}, {e}")
|
||||
|
||||
@@ -217,9 +229,93 @@ class BaseGraph(BaseDbModel):
|
||||
**({"description": p.description} if p.description else {}),
|
||||
**({"default": p.value} if p.value is not None else {}),
|
||||
}
|
||||
for p in schema
|
||||
for p in schema_fields
|
||||
},
|
||||
"required": [p.name for p in schema if p.value is None],
|
||||
"required": [p.name for p in schema_fields if p.value is None],
|
||||
}
|
||||
|
||||
@property
|
||||
def _credentials_input_schema(self) -> type[BlockSchema]:
|
||||
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||
logger.debug(
|
||||
f"Combined credentials input fields for graph #{self.id} ({self.name}): "
|
||||
f"{graph_credentials_inputs}"
|
||||
)
|
||||
|
||||
# Warn if same-provider credentials inputs can't be combined (= bad UX)
|
||||
graph_cred_fields = list(graph_credentials_inputs.values())
|
||||
for i, (field, keys) in enumerate(graph_cred_fields):
|
||||
for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
|
||||
if field.provider != other_field.provider:
|
||||
continue
|
||||
|
||||
# If this happens, that means a block implementation probably needs
|
||||
# to be updated.
|
||||
logger.warning(
|
||||
"Multiple combined credentials fields "
|
||||
f"for provider {field.provider} "
|
||||
f"on graph #{self.id} ({self.name}); "
|
||||
f"fields: {field} <> {other_field};"
|
||||
f"keys: {keys} <> {other_keys}."
|
||||
)
|
||||
|
||||
fields: dict[str, tuple[type[CredentialsMetaInput], CredentialsMetaInput]] = {
|
||||
agg_field_key: (
|
||||
CredentialsMetaInput[
|
||||
Literal[tuple(field_info.provider)], # type: ignore
|
||||
Literal[tuple(field_info.supported_types)], # type: ignore
|
||||
],
|
||||
CredentialsField(
|
||||
required_scopes=set(field_info.required_scopes or []),
|
||||
discriminator=field_info.discriminator,
|
||||
discriminator_mapping=field_info.discriminator_mapping,
|
||||
),
|
||||
)
|
||||
for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
|
||||
}
|
||||
|
||||
return create_model(
|
||||
self.name.replace(" ", "") + "CredentialsInputSchema",
|
||||
__base__=BlockSchema,
|
||||
**fields, # type: ignore
|
||||
)
|
||||
|
||||
def aggregate_credentials_inputs(
|
||||
self,
|
||||
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]]]]:
|
||||
"""
|
||||
Returns:
|
||||
dict[aggregated_field_key, tuple(
|
||||
CredentialsFieldInfo: A spec for one aggregated credentials field
|
||||
set[(node_id, field_name)]: Node credentials fields that are
|
||||
compatible with this aggregated field spec
|
||||
)]
|
||||
"""
|
||||
return {
|
||||
"_".join(sorted(agg_field_info.provider))
|
||||
+ "_"
|
||||
+ "_".join(sorted(agg_field_info.supported_types))
|
||||
+ "_credentials": (agg_field_info, node_fields)
|
||||
for agg_field_info, node_fields in CredentialsFieldInfo.combine(
|
||||
*(
|
||||
(
|
||||
# Apply discrimination before aggregating credentials inputs
|
||||
(
|
||||
field_info.discriminate(
|
||||
node.input_default[field_info.discriminator]
|
||||
)
|
||||
if (
|
||||
field_info.discriminator
|
||||
and node.input_default.get(field_info.discriminator)
|
||||
)
|
||||
else field_info
|
||||
),
|
||||
(node.id, field_name),
|
||||
)
|
||||
for node in self.nodes
|
||||
for field_name, field_info in node.block.input_schema.get_credentials_fields_info().items()
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -320,8 +416,6 @@ class GraphModel(Graph):
|
||||
return sanitized_name
|
||||
|
||||
# Validate smart decision maker nodes
|
||||
smart_decision_maker_nodes = set()
|
||||
agent_nodes = set()
|
||||
nodes_block = {
|
||||
node.id: block
|
||||
for node in graph.nodes
|
||||
@@ -332,13 +426,6 @@ class GraphModel(Graph):
|
||||
if (block := nodes_block.get(node.id)) is None:
|
||||
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
|
||||
|
||||
# Smart decision maker nodes
|
||||
if block.block_type == BlockType.AI:
|
||||
smart_decision_maker_nodes.add(node.id)
|
||||
# Agent nodes
|
||||
elif block.block_type == BlockType.AGENT:
|
||||
agent_nodes.add(node.id)
|
||||
|
||||
input_links = defaultdict(list)
|
||||
|
||||
for link in graph.links:
|
||||
@@ -353,16 +440,21 @@ class GraphModel(Graph):
|
||||
[sanitize(name) for name in node.input_default]
|
||||
+ [sanitize(link.sink_name) for link in input_links.get(node.id, [])]
|
||||
)
|
||||
for name in block.input_schema.get_required_fields():
|
||||
input_schema = block.input_schema
|
||||
for name in (required_fields := input_schema.get_required_fields()):
|
||||
if (
|
||||
name not in provided_inputs
|
||||
# Webhook payload is passed in by ExecutionManager
|
||||
and not (
|
||||
name == "payload"
|
||||
and block.block_type
|
||||
in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
|
||||
)
|
||||
# Checking availability of credentials is done by ExecutionManager
|
||||
and name not in input_schema.get_credentials_fields()
|
||||
# Validate only I/O nodes, or validate everything when executing
|
||||
and (
|
||||
for_run # Skip input completion validation, unless when executing.
|
||||
for_run
|
||||
or block.block_type
|
||||
in [
|
||||
BlockType.INPUT,
|
||||
@@ -375,9 +467,18 @@ class GraphModel(Graph):
|
||||
f"Node {block.name} #{node.id} required input missing: `{name}`"
|
||||
)
|
||||
|
||||
if (
|
||||
block.block_type == BlockType.INPUT
|
||||
and (input_key := node.input_default.get("name"))
|
||||
and is_credentials_field_name(input_key)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Agent input node uses reserved name '{input_key}'; "
|
||||
"'credentials' and `*_credentials` are reserved input names"
|
||||
)
|
||||
|
||||
# Get input schema properties and check dependencies
|
||||
input_schema = block.input_schema.model_fields
|
||||
required_fields = block.input_schema.get_required_fields()
|
||||
input_fields = input_schema.model_fields
|
||||
|
||||
def has_value(name):
|
||||
return (
|
||||
@@ -385,14 +486,21 @@ class GraphModel(Graph):
|
||||
and name in node.input_default
|
||||
and node.input_default[name] is not None
|
||||
and str(node.input_default[name]).strip() != ""
|
||||
) or (name in input_schema and input_schema[name].default is not None)
|
||||
) or (name in input_fields and input_fields[name].default is not None)
|
||||
|
||||
# Validate dependencies between fields
|
||||
for field_name, field_info in input_schema.items():
|
||||
for field_name, field_info in input_fields.items():
|
||||
# Apply input dependency validation only on run & field with depends_on
|
||||
json_schema_extra = field_info.json_schema_extra or {}
|
||||
dependencies = json_schema_extra.get("depends_on", [])
|
||||
if not for_run or not dependencies:
|
||||
if not (
|
||||
for_run
|
||||
and isinstance(json_schema_extra, dict)
|
||||
and (
|
||||
dependencies := cast(
|
||||
list[str], json_schema_extra.get("depends_on", [])
|
||||
)
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
# Check if dependent field has value in input_default
|
||||
@@ -465,13 +573,11 @@ class GraphModel(Graph):
|
||||
is_active=graph.isActive,
|
||||
name=graph.name or "",
|
||||
description=graph.description or "",
|
||||
nodes=[
|
||||
NodeModel.from_db(node, for_export) for node in graph.AgentNodes or []
|
||||
],
|
||||
nodes=[NodeModel.from_db(node, for_export) for node in graph.Nodes or []],
|
||||
links=list(
|
||||
{
|
||||
Link.from_db(link)
|
||||
for node in graph.AgentNodes or []
|
||||
for node in graph.Nodes or []
|
||||
for link in (node.Input or []) + (node.Output or [])
|
||||
}
|
||||
),
|
||||
@@ -602,8 +708,8 @@ async def get_graph(
|
||||
and not (
|
||||
await StoreListingVersion.prisma().find_first(
|
||||
where={
|
||||
"agentId": graph_id,
|
||||
"agentVersion": version or graph.version,
|
||||
"agentGraphId": graph_id,
|
||||
"agentGraphVersion": version or graph.version,
|
||||
"isDeleted": False,
|
||||
"submissionStatus": SubmissionStatus.APPROVED,
|
||||
}
|
||||
@@ -637,12 +743,16 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
|
||||
sub_graph_ids = [
|
||||
(graph_id, graph_version)
|
||||
for graph in search_graphs
|
||||
for node in graph.AgentNodes or []
|
||||
for node in graph.Nodes or []
|
||||
if (
|
||||
node.AgentBlock
|
||||
and node.AgentBlock.id == agent_block_id
|
||||
and (graph_id := dict(node.constantInput).get("graph_id"))
|
||||
and (graph_version := dict(node.constantInput).get("graph_version"))
|
||||
and (graph_id := cast(str, dict(node.constantInput).get("graph_id")))
|
||||
and (
|
||||
graph_version := cast(
|
||||
int, dict(node.constantInput).get("graph_version")
|
||||
)
|
||||
)
|
||||
)
|
||||
]
|
||||
if not sub_graph_ids:
|
||||
@@ -657,7 +767,7 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
|
||||
"userId": graph.userId, # Ensure the sub-graph is owned by the same user
|
||||
}
|
||||
for graph_id, graph_version in sub_graph_ids
|
||||
] # type: ignore
|
||||
]
|
||||
},
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
)
|
||||
@@ -671,7 +781,7 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
|
||||
async def get_connected_output_nodes(node_id: str) -> list[tuple[Link, Node]]:
|
||||
links = await AgentNodeLink.prisma().find_many(
|
||||
where={"agentNodeSourceId": node_id},
|
||||
include={"AgentNodeSink": {"include": AGENT_NODE_INCLUDE}}, # type: ignore
|
||||
include={"AgentNodeSink": {"include": AGENT_NODE_INCLUDE}},
|
||||
)
|
||||
return [
|
||||
(Link.from_db(link), NodeModel.from_db(link.AgentNodeSink))
|
||||
@@ -829,12 +939,12 @@ async def fix_llm_provider_credentials():
|
||||
SELECT graph."userId" user_id,
|
||||
node.id node_id,
|
||||
node."constantInput" node_preset_input
|
||||
FROM platform."AgentNode" node
|
||||
LEFT JOIN platform."AgentGraph" graph
|
||||
ON node."agentGraphId" = graph.id
|
||||
WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm'
|
||||
ORDER BY graph."userId";
|
||||
"""
|
||||
FROM platform."AgentNode" node
|
||||
LEFT JOIN platform."AgentGraph" graph
|
||||
ON node."agentGraphId" = graph.id
|
||||
WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm'
|
||||
ORDER BY graph."userId";
|
||||
"""
|
||||
)
|
||||
logger.info(f"Fixing LLM credential inputs on {len(broken_nodes)} nodes")
|
||||
except Exception as e:
|
||||
@@ -912,17 +1022,24 @@ async def migrate_llm_models(migrate_to: LlmModel):
|
||||
if field.annotation == LlmModel:
|
||||
llm_model_fields[block.id] = field_name
|
||||
|
||||
# Convert enum values to a list of strings for the SQL query
|
||||
enum_values = [v.value for v in LlmModel]
|
||||
escaped_enum_values = repr(tuple(enum_values)) # hack but works
|
||||
|
||||
# Update each block
|
||||
for id, path in llm_model_fields.items():
|
||||
# Convert enum values to a list of strings for the SQL query
|
||||
enum_values = [v.value for v in LlmModel.__members__.values()]
|
||||
|
||||
query = f"""
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = jsonb_set("constantInput", '{{{path}}}', '"{migrate_to.value}"', true)
|
||||
WHERE "agentBlockId" = '{id}'
|
||||
AND "constantInput" ? '{path}'
|
||||
AND "constantInput"->>'{path}' NOT IN ({','.join(f"'{value}'" for value in enum_values)})
|
||||
UPDATE platform."AgentNode"
|
||||
SET "constantInput" = jsonb_set("constantInput", $1, to_jsonb($2), true)
|
||||
WHERE "agentBlockId" = $3
|
||||
AND "constantInput" ? ($4)::text
|
||||
AND "constantInput"->>($4)::text NOT IN {escaped_enum_values}
|
||||
"""
|
||||
|
||||
await db.execute_raw(query)
|
||||
await db.execute_raw(
|
||||
query, # type: ignore - is supposed to be LiteralString
|
||||
[path],
|
||||
migrate_to.value,
|
||||
id,
|
||||
path,
|
||||
)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import cast
|
||||
|
||||
import prisma.enums
|
||||
import prisma.types
|
||||
|
||||
@@ -11,25 +13,25 @@ AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
|
||||
}
|
||||
|
||||
AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
|
||||
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
|
||||
"Nodes": {"include": AGENT_NODE_INCLUDE}
|
||||
}
|
||||
|
||||
EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
|
||||
"Input": True,
|
||||
"Output": True,
|
||||
"AgentNode": True,
|
||||
"AgentGraphExecution": True,
|
||||
"Node": True,
|
||||
"GraphExecution": True,
|
||||
}
|
||||
|
||||
MAX_NODE_EXECUTIONS_FETCH = 1000
|
||||
|
||||
GRAPH_EXECUTION_INCLUDE_WITH_NODES: prisma.types.AgentGraphExecutionInclude = {
|
||||
"AgentNodeExecutions": {
|
||||
"NodeExecutions": {
|
||||
"include": {
|
||||
"Input": True,
|
||||
"Output": True,
|
||||
"AgentNode": True,
|
||||
"AgentGraphExecution": True,
|
||||
"Node": True,
|
||||
"GraphExecution": True,
|
||||
},
|
||||
"order_by": [
|
||||
{"queuedTime": "desc"},
|
||||
@@ -41,31 +43,30 @@ GRAPH_EXECUTION_INCLUDE_WITH_NODES: prisma.types.AgentGraphExecutionInclude = {
|
||||
}
|
||||
|
||||
GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
|
||||
"AgentNodeExecutions": {
|
||||
**GRAPH_EXECUTION_INCLUDE_WITH_NODES["AgentNodeExecutions"], # type: ignore
|
||||
"NodeExecutions": {
|
||||
**cast(
|
||||
prisma.types.FindManyAgentNodeExecutionArgsFromAgentGraphExecution,
|
||||
GRAPH_EXECUTION_INCLUDE_WITH_NODES["NodeExecutions"],
|
||||
),
|
||||
"where": {
|
||||
"AgentNode": {
|
||||
"AgentBlock": {"id": {"in": IO_BLOCK_IDs}}, # type: ignore
|
||||
},
|
||||
"NOT": {
|
||||
"executionStatus": prisma.enums.AgentExecutionStatus.INCOMPLETE,
|
||||
},
|
||||
"Node": {"is": {"AgentBlock": {"is": {"id": {"in": IO_BLOCK_IDs}}}}},
|
||||
"NOT": [{"executionStatus": prisma.enums.AgentExecutionStatus.INCOMPLETE}],
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
|
||||
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
|
||||
"AgentNodes": {"include": AGENT_NODE_INCLUDE}
|
||||
}
|
||||
|
||||
|
||||
def library_agent_include(user_id: str) -> prisma.types.LibraryAgentInclude:
|
||||
return {
|
||||
"Agent": {
|
||||
"AgentGraph": {
|
||||
"include": {
|
||||
**AGENT_GRAPH_INCLUDE,
|
||||
"AgentGraphExecution": {"where": {"userId": user_id}},
|
||||
"Executions": {"where": {"userId": user_id}},
|
||||
}
|
||||
},
|
||||
"Creator": True,
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -12,6 +13,7 @@ from typing import (
|
||||
Generic,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
get_args,
|
||||
@@ -300,9 +302,7 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
)
|
||||
field_schema = model.jsonschema()["properties"][field_name]
|
||||
try:
|
||||
schema_extra = _CredentialsFieldSchemaExtra[CP, CT].model_validate(
|
||||
field_schema
|
||||
)
|
||||
schema_extra = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
|
||||
except ValidationError as e:
|
||||
if "Field required [type=missing" not in str(e):
|
||||
raise
|
||||
@@ -328,14 +328,90 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
)
|
||||
|
||||
|
||||
class _CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]):
|
||||
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
# TODO: move discrimination mechanism out of CredentialsField (frontend + backend)
|
||||
credentials_provider: list[CP]
|
||||
credentials_scopes: Optional[list[str]] = None
|
||||
credentials_types: list[CT]
|
||||
provider: frozenset[CP] = Field(..., alias="credentials_provider")
|
||||
supported_types: frozenset[CT] = Field(..., alias="credentials_types")
|
||||
required_scopes: Optional[frozenset[str]] = Field(None, alias="credentials_scopes")
|
||||
discriminator: Optional[str] = None
|
||||
discriminator_mapping: Optional[dict[str, CP]] = None
|
||||
|
||||
@classmethod
|
||||
def combine(
|
||||
cls, *fields: tuple[CredentialsFieldInfo[CP, CT], T]
|
||||
) -> Sequence[tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
|
||||
"""
|
||||
Combines multiple CredentialsFieldInfo objects into as few as possible.
|
||||
|
||||
Rules:
|
||||
- Items can only be combined if they have the same supported credentials types
|
||||
and the same supported providers.
|
||||
- When combining items, the `required_scopes` of the result is a join
|
||||
of the `required_scopes` of the original items.
|
||||
|
||||
Params:
|
||||
*fields: (CredentialsFieldInfo, key) objects to group and combine
|
||||
|
||||
Returns:
|
||||
A sequence of tuples containing combined CredentialsFieldInfo objects and
|
||||
the set of keys of the respective original items that were grouped together.
|
||||
"""
|
||||
if not fields:
|
||||
return []
|
||||
|
||||
# Group fields by their provider and supported_types
|
||||
grouped_fields: defaultdict[
|
||||
tuple[frozenset[CP], frozenset[CT]],
|
||||
list[tuple[T, CredentialsFieldInfo[CP, CT]]],
|
||||
] = defaultdict(list)
|
||||
|
||||
for field, key in fields:
|
||||
group_key = (frozenset(field.provider), frozenset(field.supported_types))
|
||||
grouped_fields[group_key].append((key, field))
|
||||
|
||||
# Combine fields within each group
|
||||
result: list[tuple[CredentialsFieldInfo[CP, CT], set[T]]] = []
|
||||
|
||||
for group in grouped_fields.values():
|
||||
# Start with the first field in the group
|
||||
_, combined = group[0]
|
||||
|
||||
# Track the keys that were combined
|
||||
combined_keys = {key for key, _ in group}
|
||||
|
||||
# Combine required_scopes from all fields in the group
|
||||
all_scopes = set()
|
||||
for _, field in group:
|
||||
if field.required_scopes:
|
||||
all_scopes.update(field.required_scopes)
|
||||
|
||||
# Create a new combined field
|
||||
result.append(
|
||||
(
|
||||
CredentialsFieldInfo[CP, CT](
|
||||
credentials_provider=combined.provider,
|
||||
credentials_types=combined.supported_types,
|
||||
credentials_scopes=frozenset(all_scopes) or None,
|
||||
discriminator=combined.discriminator,
|
||||
discriminator_mapping=combined.discriminator_mapping,
|
||||
),
|
||||
combined_keys,
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def discriminate(self, discriminator_value: Any) -> CredentialsFieldInfo:
|
||||
if not (self.discriminator and self.discriminator_mapping):
|
||||
return self
|
||||
|
||||
discriminator_value = self.discriminator_mapping[discriminator_value]
|
||||
return CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([discriminator_value]),
|
||||
credentials_types=self.supported_types,
|
||||
credentials_scopes=self.required_scopes,
|
||||
)
|
||||
|
||||
|
||||
def CredentialsField(
|
||||
required_scopes: set[str] = set(),
|
||||
|
||||
@@ -6,9 +6,11 @@ import pydantic
|
||||
from prisma import Json
|
||||
from prisma.enums import OnboardingStep
|
||||
from prisma.models import UserOnboarding
|
||||
from prisma.types import UserOnboardingUpdateInput
|
||||
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
|
||||
|
||||
from backend.data import db
|
||||
from backend.data.block import get_blocks
|
||||
from backend.data.credit import get_user_credit_model
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.server.v2.store.model import StoreAgentDetails
|
||||
@@ -24,21 +26,26 @@ REASON_MAPPING: dict[str, list[str]] = {
|
||||
POINTS_AGENT_COUNT = 50 # Number of agents to calculate points for
|
||||
MIN_AGENT_COUNT = 2 # Minimum number of marketplace agents to enable onboarding
|
||||
|
||||
user_credit = get_user_credit_model()
|
||||
|
||||
|
||||
class UserOnboardingUpdate(pydantic.BaseModel):
|
||||
completedSteps: Optional[list[OnboardingStep]] = None
|
||||
notificationDot: Optional[bool] = None
|
||||
notified: Optional[list[OnboardingStep]] = None
|
||||
usageReason: Optional[str] = None
|
||||
integrations: Optional[list[str]] = None
|
||||
otherIntegrations: Optional[str] = None
|
||||
selectedStoreListingVersionId: Optional[str] = None
|
||||
agentInput: Optional[dict[str, Any]] = None
|
||||
onboardingAgentExecutionId: Optional[str] = None
|
||||
|
||||
|
||||
async def get_user_onboarding(user_id: str):
|
||||
return await UserOnboarding.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id}, # type: ignore
|
||||
"create": UserOnboardingCreateInput(userId=user_id),
|
||||
"update": {},
|
||||
},
|
||||
)
|
||||
@@ -48,6 +55,20 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
|
||||
update: UserOnboardingUpdateInput = {}
|
||||
if data.completedSteps is not None:
|
||||
update["completedSteps"] = list(set(data.completedSteps))
|
||||
for step in (
|
||||
OnboardingStep.AGENT_NEW_RUN,
|
||||
OnboardingStep.GET_RESULTS,
|
||||
OnboardingStep.MARKETPLACE_ADD_AGENT,
|
||||
OnboardingStep.MARKETPLACE_RUN_AGENT,
|
||||
OnboardingStep.BUILDER_SAVE_AGENT,
|
||||
OnboardingStep.BUILDER_RUN_AGENT,
|
||||
):
|
||||
if step in data.completedSteps:
|
||||
await reward_user(user_id, step)
|
||||
if data.notificationDot is not None:
|
||||
update["notificationDot"] = data.notificationDot
|
||||
if data.notified is not None:
|
||||
update["notified"] = list(set(data.notified))
|
||||
if data.usageReason is not None:
|
||||
update["usageReason"] = data.usageReason
|
||||
if data.integrations is not None:
|
||||
@@ -58,16 +79,57 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
|
||||
update["selectedStoreListingVersionId"] = data.selectedStoreListingVersionId
|
||||
if data.agentInput is not None:
|
||||
update["agentInput"] = Json(data.agentInput)
|
||||
if data.onboardingAgentExecutionId is not None:
|
||||
update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId
|
||||
|
||||
return await UserOnboarding.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, **update}, # type: ignore
|
||||
"create": {"userId": user_id, **update},
|
||||
"update": update,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def reward_user(user_id: str, step: OnboardingStep):
|
||||
async with db.locked_transaction(f"usr_trx_{user_id}-reward"):
|
||||
reward = 0
|
||||
match step:
|
||||
# Reward user when they clicked New Run during onboarding
|
||||
# This is because they need credits before scheduling a run (next step)
|
||||
case OnboardingStep.AGENT_NEW_RUN:
|
||||
reward = 300
|
||||
case OnboardingStep.GET_RESULTS:
|
||||
reward = 300
|
||||
case OnboardingStep.MARKETPLACE_ADD_AGENT:
|
||||
reward = 100
|
||||
case OnboardingStep.MARKETPLACE_RUN_AGENT:
|
||||
reward = 100
|
||||
case OnboardingStep.BUILDER_SAVE_AGENT:
|
||||
reward = 100
|
||||
case OnboardingStep.BUILDER_RUN_AGENT:
|
||||
reward = 100
|
||||
|
||||
if reward == 0:
|
||||
return
|
||||
|
||||
onboarding = await get_user_onboarding(user_id)
|
||||
|
||||
# Skip if already rewarded
|
||||
if step in onboarding.rewardedFor:
|
||||
return
|
||||
|
||||
onboarding.rewardedFor.append(step)
|
||||
await user_credit.onboarding_reward(user_id, reward, step)
|
||||
await UserOnboarding.prisma().update(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"completedSteps": list(set(onboarding.completedSteps + [step])),
|
||||
"rewardedFor": onboarding.rewardedFor,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def clean_and_split(text: str) -> list[str]:
|
||||
"""
|
||||
Removes all special characters from a string, truncates it to 100 characters,
|
||||
@@ -186,11 +248,11 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
|
||||
where={
|
||||
"id": {"in": [agent.storeListingVersionId for agent in storeAgents]},
|
||||
},
|
||||
include={"Agent": True},
|
||||
include={"AgentGraph": True},
|
||||
)
|
||||
|
||||
for listing in agentListings:
|
||||
agent = listing.Agent
|
||||
agent = listing.AgentGraph
|
||||
if agent is None:
|
||||
continue
|
||||
graph = GraphModel.from_db(agent)
|
||||
|
||||
@@ -4,10 +4,18 @@ from enum import Enum
|
||||
from typing import Awaitable, Optional
|
||||
|
||||
import aio_pika
|
||||
import aio_pika.exceptions as aio_ex
|
||||
import pika
|
||||
import pika.adapters.blocking_connection
|
||||
from pika.exceptions import AMQPError
|
||||
from pika.spec import BasicProperties
|
||||
from pydantic import BaseModel
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from backend.util.retry import conn_retry
|
||||
from backend.util.settings import Settings
|
||||
@@ -161,6 +169,12 @@ class SyncRabbitMQ(RabbitMQBase):
|
||||
routing_key=queue.routing_key or queue.name,
|
||||
)
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type((AMQPError, ConnectionError)),
|
||||
wait=wait_random_exponential(multiplier=1, max=5),
|
||||
stop=stop_after_attempt(5),
|
||||
reraise=True,
|
||||
)
|
||||
def publish_message(
|
||||
self,
|
||||
routing_key: str,
|
||||
@@ -258,6 +272,12 @@ class AsyncRabbitMQ(RabbitMQBase):
|
||||
exchange, routing_key=queue.routing_key or queue.name
|
||||
)
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type((aio_ex.AMQPError, ConnectionError)),
|
||||
wait=wait_random_exponential(multiplier=1, max=5),
|
||||
stop=stop_after_attempt(5),
|
||||
reraise=True,
|
||||
)
|
||||
async def publish_message(
|
||||
self,
|
||||
routing_key: str,
|
||||
|
||||
@@ -11,7 +11,7 @@ from fastapi import HTTPException
|
||||
from prisma import Json
|
||||
from prisma.enums import NotificationType
|
||||
from prisma.models import User
|
||||
from prisma.types import UserCreateInput, UserUpdateInput
|
||||
from prisma.types import JsonFilter, UserCreateInput, UserUpdateInput
|
||||
|
||||
from backend.data.db import prisma
|
||||
from backend.data.model import UserIntegrations, UserMetadata, UserMetadataRaw
|
||||
@@ -135,16 +135,21 @@ async def migrate_and_encrypt_user_integrations():
|
||||
"""Migrate integration credentials and OAuth states from metadata to integrations column."""
|
||||
users = await User.prisma().find_many(
|
||||
where={
|
||||
"metadata": {
|
||||
"path": ["integration_credentials"],
|
||||
"not": Json({"a": "yolo"}), # bogus value works to check if key exists
|
||||
} # type: ignore
|
||||
"metadata": cast(
|
||||
JsonFilter,
|
||||
{
|
||||
"path": ["integration_credentials"],
|
||||
"not": Json(
|
||||
{"a": "yolo"}
|
||||
), # bogus value works to check if key exists
|
||||
},
|
||||
)
|
||||
}
|
||||
)
|
||||
logger.info(f"Migrating integration credentials for {len(users)} users")
|
||||
|
||||
for user in users:
|
||||
raw_metadata = cast(UserMetadataRaw, user.metadata)
|
||||
raw_metadata = cast(dict, user.metadata)
|
||||
metadata = UserMetadata.model_validate(raw_metadata)
|
||||
|
||||
# Get existing integrations data
|
||||
@@ -160,7 +165,6 @@ async def migrate_and_encrypt_user_integrations():
|
||||
await update_user_integrations(user_id=user.id, data=integrations)
|
||||
|
||||
# Remove from metadata
|
||||
raw_metadata = dict(raw_metadata)
|
||||
raw_metadata.pop("integration_credentials", None)
|
||||
raw_metadata.pop("integration_oauth_states", None)
|
||||
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
import logging
|
||||
|
||||
from backend.data import db, redis
|
||||
from backend.data import db
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
from backend.data.execution import (
|
||||
GraphExecution,
|
||||
NodeExecutionResult,
|
||||
RedisExecutionEventBus,
|
||||
create_graph_execution,
|
||||
get_graph_execution,
|
||||
get_incomplete_node_executions,
|
||||
@@ -42,7 +39,7 @@ from backend.data.user import (
|
||||
update_user_integrations,
|
||||
update_user_metadata,
|
||||
)
|
||||
from backend.util.service import AppService, expose, exposed_run_and_wait
|
||||
from backend.util.service import AppService, exposed_run_and_wait
|
||||
from backend.util.settings import Config
|
||||
|
||||
config = Config()
|
||||
@@ -57,21 +54,14 @@ async def _spend_credits(
|
||||
|
||||
|
||||
class DatabaseManager(AppService):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.execution_event_bus = RedisExecutionEventBus()
|
||||
|
||||
def run_service(self) -> None:
|
||||
logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
|
||||
self.run_and_wait(db.connect())
|
||||
logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...")
|
||||
redis.connect()
|
||||
super().run_service()
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
logger.info(f"[{self.service_name}] ⏳ Disconnecting Redis...")
|
||||
redis.disconnect()
|
||||
logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
|
||||
self.run_and_wait(db.disconnect())
|
||||
|
||||
@@ -79,12 +69,6 @@ class DatabaseManager(AppService):
|
||||
def get_port(cls) -> int:
|
||||
return config.database_api_port
|
||||
|
||||
@expose
|
||||
def send_execution_update(
|
||||
self, execution_result: GraphExecution | NodeExecutionResult
|
||||
):
|
||||
self.execution_event_bus.publish(execution_result)
|
||||
|
||||
# Executions
|
||||
get_graph_execution = exposed_run_and_wait(get_graph_execution)
|
||||
create_graph_execution = exposed_run_and_wait(create_graph_execution)
|
||||
|
||||
@@ -5,11 +5,14 @@ import os
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import Future, ProcessPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from multiprocessing.pool import AsyncResult, Pool
|
||||
from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar, cast
|
||||
from typing import TYPE_CHECKING, Any, Generator, TypeVar, cast
|
||||
|
||||
from pika.adapters.blocking_connection import BlockingChannel
|
||||
from pika.spec import Basic, BasicProperties
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from backend.blocks.io import AgentOutputBlock
|
||||
@@ -26,47 +29,40 @@ if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManager
|
||||
from backend.notifications.notifications import NotificationManager
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from autogpt_libs.utils.cache import clear_thread_cache, thread_cached
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.data import redis
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockData,
|
||||
BlockInput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.block import BlockData, BlockInput, BlockSchema, get_block
|
||||
from backend.data.execution import (
|
||||
ExecutionQueue,
|
||||
ExecutionStatus,
|
||||
GraphExecution,
|
||||
GraphExecutionEntry,
|
||||
NodeExecutionEntry,
|
||||
NodeExecutionResult,
|
||||
merge_execution_input,
|
||||
parse_execution_output,
|
||||
)
|
||||
from backend.data.graph import GraphModel, Link, Node
|
||||
from backend.data.graph import Link, Node
|
||||
from backend.executor.utils import (
|
||||
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_QUEUE_NAME,
|
||||
CancelExecutionEvent,
|
||||
UsageTransactionMetadata,
|
||||
block_usage_cost,
|
||||
execution_usage_cost,
|
||||
get_execution_event_bus,
|
||||
get_execution_queue,
|
||||
parse_execution_output,
|
||||
validate_exec,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util import json
|
||||
from backend.util.decorator import error_logged, time_measured
|
||||
from backend.util.file import clean_exec_files
|
||||
from backend.util.logging import configure_logging
|
||||
from backend.util.process import set_service_name
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
close_service_client,
|
||||
expose,
|
||||
get_service_client,
|
||||
)
|
||||
from backend.util.process import AppProcess, set_service_name
|
||||
from backend.util.service import close_service_client, get_service_client
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.type import convert
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
@@ -91,7 +87,7 @@ class LogMetadata:
|
||||
"node_id": node_id,
|
||||
"block_name": block_name,
|
||||
}
|
||||
self.prefix = f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|nid:{node_eid}|{block_name}]"
|
||||
self.prefix = f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|neid:{node_eid}|{block_name}]"
|
||||
|
||||
def info(self, msg: str, **extra):
|
||||
msg = self._wrap(msg, **extra)
|
||||
@@ -152,7 +148,7 @@ def execute_node(
|
||||
def update_execution_status(status: ExecutionStatus) -> NodeExecutionResult:
|
||||
"""Sets status and fetches+broadcasts the latest state of the node execution"""
|
||||
exec_update = db_client.update_node_execution_status(node_exec_id, status)
|
||||
db_client.send_execution_update(exec_update)
|
||||
send_execution_update(exec_update)
|
||||
return exec_update
|
||||
|
||||
node = db_client.get_node(node_id)
|
||||
@@ -192,7 +188,7 @@ def execute_node(
|
||||
# Execute the node
|
||||
input_data_str = json.dumps(input_data)
|
||||
input_size = len(input_data_str)
|
||||
log_metadata.info("Executed node with input", input=input_data_str)
|
||||
log_metadata.debug("Executed node with input", input=input_data_str)
|
||||
update_execution_status(ExecutionStatus.RUNNING)
|
||||
|
||||
# Inject extra execution arguments for the blocks via kwargs
|
||||
@@ -223,7 +219,7 @@ def execute_node(
|
||||
):
|
||||
output_data = json.convert_pydantic_to_json(output_data)
|
||||
output_size += len(json.dumps(output_data))
|
||||
log_metadata.info("Node produced output", **{output_name: output_data})
|
||||
log_metadata.debug("Node produced output", **{output_name: output_data})
|
||||
push_output(output_name, output_data)
|
||||
outputs[output_name] = output_data
|
||||
for execution in _enqueue_next_nodes(
|
||||
@@ -288,7 +284,7 @@ def _enqueue_next_nodes(
|
||||
exec_update = db_client.update_node_execution_status(
|
||||
node_exec_id, ExecutionStatus.QUEUED, data
|
||||
)
|
||||
db_client.send_execution_update(exec_update)
|
||||
send_execution_update(exec_update)
|
||||
return NodeExecutionEntry(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
@@ -400,60 +396,6 @@ def _enqueue_next_nodes(
|
||||
]
|
||||
|
||||
|
||||
def validate_exec(
|
||||
node: Node,
|
||||
data: BlockInput,
|
||||
resolve_input: bool = True,
|
||||
) -> tuple[BlockInput | None, str]:
|
||||
"""
|
||||
Validate the input data for a node execution.
|
||||
|
||||
Args:
|
||||
node: The node to execute.
|
||||
data: The input data for the node execution.
|
||||
resolve_input: Whether to resolve dynamic pins into dict/list/object.
|
||||
|
||||
Returns:
|
||||
A tuple of the validated data and the block name.
|
||||
If the data is invalid, the first element will be None, and the second element
|
||||
will be an error message.
|
||||
If the data is valid, the first element will be the resolved input data, and
|
||||
the second element will be the block name.
|
||||
"""
|
||||
node_block: Block | None = get_block(node.block_id)
|
||||
if not node_block:
|
||||
return None, f"Block for {node.block_id} not found."
|
||||
schema = node_block.input_schema
|
||||
|
||||
# Convert non-matching data types to the expected input schema.
|
||||
for name, data_type in schema.__annotations__.items():
|
||||
if (value := data.get(name)) and (type(value) is not data_type):
|
||||
data[name] = convert(value, data_type)
|
||||
|
||||
# Input data (without default values) should contain all required fields.
|
||||
error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
|
||||
if missing_links := schema.get_missing_links(data, node.input_links):
|
||||
return None, f"{error_prefix} unpopulated links {missing_links}"
|
||||
|
||||
# Merge input data with default values and resolve dynamic dict/list/object pins.
|
||||
input_default = schema.get_input_defaults(node.input_default)
|
||||
data = {**input_default, **data}
|
||||
if resolve_input:
|
||||
data = merge_execution_input(data)
|
||||
|
||||
# Input data post-merge should contain all required fields from the schema.
|
||||
if missing_input := schema.get_missing_input(data):
|
||||
return None, f"{error_prefix} missing input {missing_input}"
|
||||
|
||||
# Last validation: Validate the input values against the schema.
|
||||
if error := schema.get_mismatch_error(data):
|
||||
error_message = f"{error_prefix} {error}"
|
||||
logger.error(error_message)
|
||||
return None, error_message
|
||||
|
||||
return data, node_block.name
|
||||
|
||||
|
||||
class Executor:
|
||||
"""
|
||||
This class contains event handlers for the process pool executor events.
|
||||
@@ -633,7 +575,13 @@ class Executor:
|
||||
exec_meta = cls.db_client.update_graph_execution_start_time(
|
||||
graph_exec.graph_exec_id
|
||||
)
|
||||
cls.db_client.send_execution_update(exec_meta)
|
||||
if exec_meta is None:
|
||||
logger.warning(
|
||||
f"Skipped graph execution {graph_exec.graph_exec_id}, the graph execution is not found or not currently in the QUEUED state."
|
||||
)
|
||||
return
|
||||
|
||||
send_execution_update(exec_meta)
|
||||
timing_info, (exec_stats, status, error) = cls._on_graph_execution(
|
||||
graph_exec, cancel, log_metadata
|
||||
)
|
||||
@@ -646,7 +594,7 @@ class Executor:
|
||||
status=status,
|
||||
stats=exec_stats,
|
||||
):
|
||||
cls.db_client.send_execution_update(graph_exec_result)
|
||||
send_execution_update(graph_exec_result)
|
||||
|
||||
cls._handle_agent_run_notif(graph_exec, exec_stats)
|
||||
|
||||
@@ -759,7 +707,7 @@ class Executor:
|
||||
status=execution_status,
|
||||
stats=execution_stats,
|
||||
):
|
||||
cls.db_client.send_execution_update(_graph_exec)
|
||||
send_execution_update(_graph_exec)
|
||||
else:
|
||||
logger.error(
|
||||
"Callback for "
|
||||
@@ -776,10 +724,10 @@ class Executor:
|
||||
execution_status = ExecutionStatus.TERMINATED
|
||||
return execution_stats, execution_status, error
|
||||
|
||||
exec_data = queue.get()
|
||||
queued_node_exec = queue.get()
|
||||
|
||||
# Avoid parallel execution of the same node.
|
||||
execution = running_executions.get(exec_data.node_id)
|
||||
execution = running_executions.get(queued_node_exec.node_id)
|
||||
if execution and not execution.ready():
|
||||
# TODO (performance improvement):
|
||||
# Wait for the completion of the same node execution is blocking.
|
||||
@@ -788,18 +736,18 @@ class Executor:
|
||||
execution.wait()
|
||||
|
||||
log_metadata.debug(
|
||||
f"Dispatching node execution {exec_data.node_exec_id} "
|
||||
f"for node {exec_data.node_id}",
|
||||
f"Dispatching node execution {queued_node_exec.node_exec_id} "
|
||||
f"for node {queued_node_exec.node_id}",
|
||||
)
|
||||
|
||||
try:
|
||||
exec_cost_counter = cls._charge_usage(
|
||||
node_exec=exec_data,
|
||||
node_exec=queued_node_exec,
|
||||
execution_count=exec_cost_counter + 1,
|
||||
execution_stats=execution_stats,
|
||||
)
|
||||
except InsufficientBalanceError as error:
|
||||
node_exec_id = exec_data.node_exec_id
|
||||
node_exec_id = queued_node_exec.node_exec_id
|
||||
cls.db_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name="error",
|
||||
@@ -810,7 +758,7 @@ class Executor:
|
||||
exec_update = cls.db_client.update_node_execution_status(
|
||||
node_exec_id, execution_status
|
||||
)
|
||||
cls.db_client.send_execution_update(exec_update)
|
||||
send_execution_update(exec_update)
|
||||
|
||||
cls._handle_low_balance_notif(
|
||||
graph_exec.user_id,
|
||||
@@ -820,10 +768,23 @@ class Executor:
|
||||
)
|
||||
raise
|
||||
|
||||
running_executions[exec_data.node_id] = cls.executor.apply_async(
|
||||
# Add credentials input overrides
|
||||
node_id = queued_node_exec.node_id
|
||||
if (node_creds_map := graph_exec.node_credentials_input_map) and (
|
||||
node_field_creds_map := node_creds_map.get(node_id)
|
||||
):
|
||||
queued_node_exec.data.update(
|
||||
{
|
||||
field_name: creds_meta.model_dump()
|
||||
for field_name, creds_meta in node_field_creds_map.items()
|
||||
}
|
||||
)
|
||||
|
||||
# Initiate node execution
|
||||
running_executions[queued_node_exec.node_id] = cls.executor.apply_async(
|
||||
cls.on_node_execution,
|
||||
(queue, exec_data),
|
||||
callback=make_exec_callback(exec_data),
|
||||
(queue, queued_node_exec),
|
||||
callback=make_exec_callback(queued_node_exec),
|
||||
)
|
||||
|
||||
# Avoid terminating graph execution when some nodes are still running.
|
||||
@@ -927,22 +888,43 @@ class Executor:
|
||||
)
|
||||
|
||||
|
||||
class ExecutionManager(AppService):
|
||||
class ExecutionManager(AppProcess):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.pool_size = settings.config.num_graph_workers
|
||||
self.queue = ExecutionQueue[GraphExecutionEntry]()
|
||||
self.running = True
|
||||
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
|
||||
|
||||
@classmethod
|
||||
def get_port(cls) -> int:
|
||||
return settings.config.execution_manager_port
|
||||
|
||||
def run_service(self):
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
def run(self):
|
||||
retry_count_max = settings.config.execution_manager_loop_max_retry
|
||||
retry_count = 0
|
||||
|
||||
self.credentials_store = IntegrationCredentialsStore()
|
||||
for retry_count in range(retry_count_max):
|
||||
try:
|
||||
self._run()
|
||||
except Exception as e:
|
||||
if not self.running:
|
||||
break
|
||||
logger.exception(
|
||||
f"[{self.service_name}] Error in execution manager: {e}"
|
||||
)
|
||||
|
||||
if retry_count >= retry_count_max:
|
||||
logger.error(
|
||||
f"[{self.service_name}] Max retries reached ({retry_count_max}), exiting..."
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.info(
|
||||
f"[{self.service_name}] Retrying execution loop in {retry_count} seconds..."
|
||||
)
|
||||
time.sleep(retry_count)
|
||||
|
||||
def _run(self):
|
||||
logger.info(f"[{self.service_name}] ⏳ Spawn max-{self.pool_size} workers...")
|
||||
self.executor = ProcessPoolExecutor(
|
||||
max_workers=self.pool_size,
|
||||
@@ -952,25 +934,122 @@ class ExecutionManager(AppService):
|
||||
logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...")
|
||||
redis.connect()
|
||||
|
||||
sync_manager = multiprocessing.Manager()
|
||||
while True:
|
||||
graph_exec_data = self.queue.get()
|
||||
graph_exec_id = graph_exec_data.graph_exec_id
|
||||
logger.debug(
|
||||
f"[ExecutionManager] Dispatching graph execution {graph_exec_id}"
|
||||
)
|
||||
cancel_event = sync_manager.Event()
|
||||
future = self.executor.submit(
|
||||
Executor.on_graph_execution, graph_exec_data, cancel_event
|
||||
)
|
||||
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
|
||||
future.add_done_callback(
|
||||
lambda _: self.active_graph_runs.pop(graph_exec_id, None)
|
||||
# Consume Cancel & Run execution requests.
|
||||
clear_thread_cache(get_execution_queue)
|
||||
channel = get_execution_queue().get_channel()
|
||||
channel.basic_qos(prefetch_count=self.pool_size)
|
||||
channel.basic_consume(
|
||||
queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
on_message_callback=self._handle_cancel_message,
|
||||
auto_ack=True,
|
||||
)
|
||||
channel.basic_consume(
|
||||
queue=GRAPH_EXECUTION_QUEUE_NAME,
|
||||
on_message_callback=self._handle_run_message,
|
||||
auto_ack=False,
|
||||
)
|
||||
|
||||
logger.info(f"[{self.service_name}] Ready to consume messages...")
|
||||
channel.start_consuming()
|
||||
|
||||
def _handle_cancel_message(
|
||||
self,
|
||||
channel: BlockingChannel,
|
||||
method: Basic.Deliver,
|
||||
properties: BasicProperties,
|
||||
body: bytes,
|
||||
):
|
||||
"""
|
||||
Called whenever we receive a CANCEL message from the queue.
|
||||
(With auto_ack=True, message is considered 'acked' automatically.)
|
||||
"""
|
||||
try:
|
||||
request = CancelExecutionEvent.model_validate_json(body)
|
||||
graph_exec_id = request.graph_exec_id
|
||||
if not graph_exec_id:
|
||||
logger.warning(
|
||||
f"[{self.service_name}] Cancel message missing 'graph_exec_id'"
|
||||
)
|
||||
return
|
||||
if graph_exec_id not in self.active_graph_runs:
|
||||
logger.debug(
|
||||
f"[{self.service_name}] Cancel received for {graph_exec_id} but not active."
|
||||
)
|
||||
return
|
||||
|
||||
_, cancel_event = self.active_graph_runs[graph_exec_id]
|
||||
logger.info(f"[{self.service_name}] Received cancel for {graph_exec_id}")
|
||||
if not cancel_event.is_set():
|
||||
cancel_event.set()
|
||||
else:
|
||||
logger.debug(
|
||||
f"[{self.service_name}] Cancel already set for {graph_exec_id}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error handling cancel message: {e}")
|
||||
|
||||
def _handle_run_message(
|
||||
self,
|
||||
channel: BlockingChannel,
|
||||
method: Basic.Deliver,
|
||||
properties: BasicProperties,
|
||||
body: bytes,
|
||||
):
|
||||
delivery_tag = method.delivery_tag
|
||||
try:
|
||||
graph_exec_entry = GraphExecutionEntry.model_validate_json(body)
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.service_name}] Could not parse run message: {e}")
|
||||
channel.basic_nack(delivery_tag, requeue=False)
|
||||
return
|
||||
|
||||
graph_exec_id = graph_exec_entry.graph_exec_id
|
||||
logger.info(
|
||||
f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}"
|
||||
)
|
||||
if graph_exec_id in self.active_graph_runs:
|
||||
logger.warning(
|
||||
f"[{self.service_name}] Graph {graph_exec_id} already running; rejecting duplicate run."
|
||||
)
|
||||
channel.basic_nack(delivery_tag, requeue=False)
|
||||
return
|
||||
|
||||
cancel_event = multiprocessing.Manager().Event()
|
||||
future = self.executor.submit(
|
||||
Executor.on_graph_execution, graph_exec_entry, cancel_event
|
||||
)
|
||||
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
|
||||
|
||||
def _on_run_done(f: Future):
|
||||
logger.info(f"[{self.service_name}] Run completed for {graph_exec_id}")
|
||||
try:
|
||||
self.active_graph_runs.pop(graph_exec_id, None)
|
||||
if f.exception():
|
||||
logger.error(
|
||||
f"[{self.service_name}] Execution for {graph_exec_id} failed: {f.exception()}"
|
||||
)
|
||||
channel.connection.add_callback_threadsafe(
|
||||
lambda: channel.basic_nack(delivery_tag, requeue=False)
|
||||
)
|
||||
else:
|
||||
channel.connection.add_callback_threadsafe(
|
||||
lambda: channel.basic_ack(delivery_tag)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.service_name}] Error acknowledging message: {e}")
|
||||
|
||||
future.add_done_callback(_on_run_done)
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
|
||||
logger.info(f"[{self.service_name}] ⏳ Shutting down service loop...")
|
||||
self.running = False
|
||||
|
||||
logger.info(f"[{self.service_name}] ⏳ Shutting down RabbitMQ channel...")
|
||||
get_execution_queue().get_channel().stop_consuming()
|
||||
|
||||
logger.info(f"[{self.service_name}] ⏳ Shutting down graph executor pool...")
|
||||
self.executor.shutdown(cancel_futures=True)
|
||||
|
||||
@@ -981,175 +1060,6 @@ class ExecutionManager(AppService):
|
||||
def db_client(self) -> "DatabaseManager":
|
||||
return get_db_client()
|
||||
|
||||
@expose
|
||||
def add_execution(
|
||||
self,
|
||||
graph_id: str,
|
||||
data: BlockInput,
|
||||
user_id: str,
|
||||
graph_version: Optional[int] = None,
|
||||
preset_id: str | None = None,
|
||||
) -> GraphExecutionEntry:
|
||||
graph: GraphModel | None = self.db_client.get_graph(
|
||||
graph_id=graph_id, user_id=user_id, version=graph_version
|
||||
)
|
||||
if not graph:
|
||||
raise ValueError(f"Graph #{graph_id} not found.")
|
||||
|
||||
graph.validate_graph(for_run=True)
|
||||
self._validate_node_input_credentials(graph, user_id)
|
||||
|
||||
nodes_input = []
|
||||
for node in graph.starting_nodes:
|
||||
input_data = {}
|
||||
block = node.block
|
||||
|
||||
# Note block should never be executed.
|
||||
if block.block_type == BlockType.NOTE:
|
||||
continue
|
||||
|
||||
# Extract request input data, and assign it to the input pin.
|
||||
if block.block_type == BlockType.INPUT:
|
||||
input_name = node.input_default.get("name")
|
||||
if input_name and input_name in data:
|
||||
input_data = {"value": data[input_name]}
|
||||
|
||||
# Extract webhook payload, and assign it to the input pin
|
||||
webhook_payload_key = f"webhook_{node.webhook_id}_payload"
|
||||
if (
|
||||
block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
|
||||
and node.webhook_id
|
||||
):
|
||||
if webhook_payload_key not in data:
|
||||
raise ValueError(
|
||||
f"Node {block.name} #{node.id} webhook payload is missing"
|
||||
)
|
||||
input_data = {"payload": data[webhook_payload_key]}
|
||||
|
||||
input_data, error = validate_exec(node, input_data)
|
||||
if input_data is None:
|
||||
raise ValueError(error)
|
||||
else:
|
||||
nodes_input.append((node.id, input_data))
|
||||
|
||||
if not nodes_input:
|
||||
raise ValueError(
|
||||
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
|
||||
)
|
||||
|
||||
graph_exec = self.db_client.create_graph_execution(
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
nodes_input=nodes_input,
|
||||
user_id=user_id,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
self.db_client.send_execution_update(graph_exec)
|
||||
|
||||
graph_exec_entry = GraphExecutionEntry(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version or 0,
|
||||
graph_exec_id=graph_exec.id,
|
||||
start_node_execs=[
|
||||
NodeExecutionEntry(
|
||||
user_id=user_id,
|
||||
graph_exec_id=node_exec.graph_exec_id,
|
||||
graph_id=node_exec.graph_id,
|
||||
node_exec_id=node_exec.node_exec_id,
|
||||
node_id=node_exec.node_id,
|
||||
block_id=node_exec.block_id,
|
||||
data=node_exec.input_data,
|
||||
)
|
||||
for node_exec in graph_exec.node_executions
|
||||
],
|
||||
)
|
||||
self.queue.add(graph_exec_entry)
|
||||
|
||||
return graph_exec_entry
|
||||
|
||||
@expose
|
||||
def cancel_execution(self, graph_exec_id: str) -> None:
|
||||
"""
|
||||
Mechanism:
|
||||
1. Set the cancel event
|
||||
2. Graph executor's cancel handler thread detects the event, terminates workers,
|
||||
reinitializes worker pool, and returns.
|
||||
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
|
||||
"""
|
||||
if graph_exec_id not in self.active_graph_runs:
|
||||
logger.warning(
|
||||
f"Graph execution #{graph_exec_id} not active/running: "
|
||||
"possibly already completed/cancelled."
|
||||
)
|
||||
else:
|
||||
future, cancel_event = self.active_graph_runs[graph_exec_id]
|
||||
if not cancel_event.is_set():
|
||||
cancel_event.set()
|
||||
future.result()
|
||||
|
||||
# Update the status of the graph & node executions
|
||||
self.db_client.update_graph_execution_stats(
|
||||
graph_exec_id,
|
||||
ExecutionStatus.TERMINATED,
|
||||
)
|
||||
node_execs = self.db_client.get_node_execution_results(
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=[
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.RUNNING,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
],
|
||||
)
|
||||
self.db_client.update_node_execution_status_batch(
|
||||
[node_exec.node_exec_id for node_exec in node_execs],
|
||||
ExecutionStatus.TERMINATED,
|
||||
)
|
||||
for node_exec in node_execs:
|
||||
node_exec.status = ExecutionStatus.TERMINATED
|
||||
self.db_client.send_execution_update(node_exec)
|
||||
|
||||
def _validate_node_input_credentials(self, graph: GraphModel, user_id: str):
|
||||
"""Checks all credentials for all nodes of the graph"""
|
||||
|
||||
for node in graph.nodes:
|
||||
block = node.block
|
||||
|
||||
# Find any fields of type CredentialsMetaInput
|
||||
credentials_fields = cast(
|
||||
type[BlockSchema], block.input_schema
|
||||
).get_credentials_fields()
|
||||
if not credentials_fields:
|
||||
continue
|
||||
|
||||
for field_name, credentials_meta_type in credentials_fields.items():
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node.input_default[field_name]
|
||||
)
|
||||
# Fetch the corresponding Credentials and perform sanity checks
|
||||
credentials = self.credentials_store.get_creds_by_id(
|
||||
user_id, credentials_meta.id
|
||||
)
|
||||
if not credentials:
|
||||
raise ValueError(
|
||||
f"Unknown credentials #{credentials_meta.id} "
|
||||
f"for node #{node.id} input '{field_name}'"
|
||||
)
|
||||
if (
|
||||
credentials.provider != credentials_meta.provider
|
||||
or credentials.type != credentials_meta.type
|
||||
):
|
||||
logger.warning(
|
||||
f"Invalid credentials #{credentials.id} for node #{node.id}: "
|
||||
"type/provider mismatch: "
|
||||
f"{credentials_meta.type}<>{credentials.type};"
|
||||
f"{credentials_meta.provider}<>{credentials.provider}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Invalid credentials #{credentials.id} for node #{node.id}: "
|
||||
"type/provider mismatch"
|
||||
)
|
||||
|
||||
|
||||
# ------- UTILITIES ------- #
|
||||
|
||||
@@ -1168,6 +1078,10 @@ def get_notification_service() -> "NotificationManager":
|
||||
return get_service_client(NotificationManager)
|
||||
|
||||
|
||||
def send_execution_update(entry: GraphExecution | NodeExecutionResult):
|
||||
return get_execution_event_bus().publish(entry)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def synchronized(key: str, timeout: int = 60):
|
||||
lock: RedisLock = redis.get_redis().lock(f"lock:{key}", timeout=timeout)
|
||||
|
||||
@@ -16,7 +16,7 @@ from pydantic import BaseModel
|
||||
from sqlalchemy import MetaData, create_engine
|
||||
|
||||
from backend.data.block import BlockInput
|
||||
from backend.executor.manager import ExecutionManager
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.notifications.notifications import NotificationManager
|
||||
from backend.util.service import AppService, expose, get_service_client
|
||||
from backend.util.settings import Config
|
||||
@@ -57,11 +57,6 @@ def job_listener(event):
|
||||
log(f"Job {event.job_id} completed successfully.")
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_execution_client() -> ExecutionManager:
|
||||
return get_service_client(ExecutionManager)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_notification_client():
|
||||
from backend.notifications import NotificationManager
|
||||
@@ -73,9 +68,9 @@ def execute_graph(**kwargs):
|
||||
args = ExecutionJobArgs(**kwargs)
|
||||
try:
|
||||
log(f"Executing recurring job for graph #{args.graph_id}")
|
||||
get_execution_client().add_execution(
|
||||
execution_utils.add_graph_execution(
|
||||
graph_id=args.graph_id,
|
||||
data=args.input_data,
|
||||
inputs=args.input_data,
|
||||
user_id=args.user_id,
|
||||
graph_version=args.graph_version,
|
||||
)
|
||||
@@ -164,11 +159,6 @@ class Scheduler(AppService):
|
||||
def db_pool_size(cls) -> int:
|
||||
return config.scheduler_db_pool_size
|
||||
|
||||
@property
|
||||
@thread_cached
|
||||
def execution_client(self) -> ExecutionManager:
|
||||
return get_service_client(ExecutionManager)
|
||||
|
||||
@property
|
||||
@thread_cached
|
||||
def notification_client(self) -> NotificationManager:
|
||||
@@ -176,7 +166,7 @@ class Scheduler(AppService):
|
||||
|
||||
def run_service(self):
|
||||
load_dotenv()
|
||||
db_schema, db_url = _extract_schema_from_url(os.getenv("DATABASE_URL"))
|
||||
db_schema, db_url = _extract_schema_from_url(os.getenv("DIRECT_URL"))
|
||||
self.scheduler = BlockingScheduler(
|
||||
jobstores={
|
||||
Jobstores.EXECUTION.value: SQLAlchemyJobStore(
|
||||
|
||||
@@ -1,11 +1,94 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import Block, BlockInput
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockData,
|
||||
BlockInput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.cost import BlockCostType
|
||||
from backend.data.execution import (
|
||||
AsyncRedisExecutionEventBus,
|
||||
ExecutionStatus,
|
||||
GraphExecutionStats,
|
||||
GraphExecutionWithNodes,
|
||||
RedisExecutionEventBus,
|
||||
create_graph_execution,
|
||||
update_graph_execution_stats,
|
||||
update_node_execution_status_batch,
|
||||
)
|
||||
from backend.data.graph import GraphModel, Node, get_graph
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.rabbitmq import (
|
||||
AsyncRabbitMQ,
|
||||
Exchange,
|
||||
ExchangeType,
|
||||
Queue,
|
||||
RabbitMQConfig,
|
||||
SyncRabbitMQ,
|
||||
)
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.mock import MockObject
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Config
|
||||
from backend.util.type import convert
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManager
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
config = Config()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ============ Resource Helpers ============ #
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_execution_event_bus() -> RedisExecutionEventBus:
|
||||
return RedisExecutionEventBus()
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_async_execution_event_bus() -> AsyncRedisExecutionEventBus:
|
||||
return AsyncRedisExecutionEventBus()
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_execution_queue() -> SyncRabbitMQ:
|
||||
client = SyncRabbitMQ(create_execution_queue_config())
|
||||
client.connect()
|
||||
return client
|
||||
|
||||
|
||||
@thread_cached
|
||||
async def get_async_execution_queue() -> AsyncRabbitMQ:
|
||||
client = AsyncRabbitMQ(create_execution_queue_config())
|
||||
await client.connect()
|
||||
return client
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_integration_credentials_store() -> "IntegrationCredentialsStore":
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
return IntegrationCredentialsStore()
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_db_client() -> "DatabaseManager":
|
||||
from backend.executor import DatabaseManager
|
||||
|
||||
return get_service_client(DatabaseManager)
|
||||
|
||||
|
||||
# ============ Execution Cost Helpers ============ #
|
||||
|
||||
|
||||
class UsageTransactionMetadata(BaseModel):
|
||||
@@ -95,3 +178,571 @@ def _is_cost_filter_match(cost_filter: BlockInput, input_data: BlockInput) -> bo
|
||||
or (input_data.get(k) and _is_cost_filter_match(v, input_data[k]))
|
||||
for k, v in cost_filter.items()
|
||||
)
|
||||
|
||||
|
||||
# ============ Execution Input Helpers ============ #
|
||||
|
||||
LIST_SPLIT = "_$_"
|
||||
DICT_SPLIT = "_#_"
|
||||
OBJC_SPLIT = "_@_"
|
||||
|
||||
|
||||
def parse_execution_output(output: BlockData, name: str) -> Any | None:
|
||||
"""
|
||||
Extracts partial output data by name from a given BlockData.
|
||||
|
||||
The function supports extracting data from lists, dictionaries, and objects
|
||||
using specific naming conventions:
|
||||
- For lists: <output_name>_$_<index>
|
||||
- For dictionaries: <output_name>_#_<key>
|
||||
- For objects: <output_name>_@_<attribute>
|
||||
|
||||
Args:
|
||||
output (BlockData): A tuple containing the output name and data.
|
||||
name (str): The name used to extract specific data from the output.
|
||||
|
||||
Returns:
|
||||
Any | None: The extracted data if found, otherwise None.
|
||||
|
||||
Examples:
|
||||
>>> output = ("result", [10, 20, 30])
|
||||
>>> parse_execution_output(output, "result_$_1")
|
||||
20
|
||||
|
||||
>>> output = ("config", {"key1": "value1", "key2": "value2"})
|
||||
>>> parse_execution_output(output, "config_#_key1")
|
||||
'value1'
|
||||
|
||||
>>> class Sample:
|
||||
... attr1 = "value1"
|
||||
... attr2 = "value2"
|
||||
>>> output = ("object", Sample())
|
||||
>>> parse_execution_output(output, "object_@_attr1")
|
||||
'value1'
|
||||
"""
|
||||
output_name, output_data = output
|
||||
|
||||
if name == output_name:
|
||||
return output_data
|
||||
|
||||
if name.startswith(f"{output_name}{LIST_SPLIT}"):
|
||||
index = int(name.split(LIST_SPLIT)[1])
|
||||
if not isinstance(output_data, list) or len(output_data) <= index:
|
||||
return None
|
||||
return output_data[int(name.split(LIST_SPLIT)[1])]
|
||||
|
||||
if name.startswith(f"{output_name}{DICT_SPLIT}"):
|
||||
index = name.split(DICT_SPLIT)[1]
|
||||
if not isinstance(output_data, dict) or index not in output_data:
|
||||
return None
|
||||
return output_data[index]
|
||||
|
||||
if name.startswith(f"{output_name}{OBJC_SPLIT}"):
|
||||
index = name.split(OBJC_SPLIT)[1]
|
||||
if isinstance(output_data, object) and hasattr(output_data, index):
|
||||
return getattr(output_data, index)
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def validate_exec(
|
||||
node: Node,
|
||||
data: BlockInput,
|
||||
resolve_input: bool = True,
|
||||
) -> tuple[BlockInput | None, str]:
|
||||
"""
|
||||
Validate the input data for a node execution.
|
||||
|
||||
Args:
|
||||
node: The node to execute.
|
||||
data: The input data for the node execution.
|
||||
resolve_input: Whether to resolve dynamic pins into dict/list/object.
|
||||
|
||||
Returns:
|
||||
A tuple of the validated data and the block name.
|
||||
If the data is invalid, the first element will be None, and the second element
|
||||
will be an error message.
|
||||
If the data is valid, the first element will be the resolved input data, and
|
||||
the second element will be the block name.
|
||||
"""
|
||||
node_block: Block | None = get_block(node.block_id)
|
||||
if not node_block:
|
||||
return None, f"Block for {node.block_id} not found."
|
||||
schema = node_block.input_schema
|
||||
|
||||
# Convert non-matching data types to the expected input schema.
|
||||
for name, data_type in schema.__annotations__.items():
|
||||
if (value := data.get(name)) and (type(value) is not data_type):
|
||||
data[name] = convert(value, data_type)
|
||||
|
||||
# Input data (without default values) should contain all required fields.
|
||||
error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
|
||||
if missing_links := schema.get_missing_links(data, node.input_links):
|
||||
return None, f"{error_prefix} unpopulated links {missing_links}"
|
||||
|
||||
# Merge input data with default values and resolve dynamic dict/list/object pins.
|
||||
input_default = schema.get_input_defaults(node.input_default)
|
||||
data = {**input_default, **data}
|
||||
if resolve_input:
|
||||
data = merge_execution_input(data)
|
||||
|
||||
# Input data post-merge should contain all required fields from the schema.
|
||||
if missing_input := schema.get_missing_input(data):
|
||||
return None, f"{error_prefix} missing input {missing_input}"
|
||||
|
||||
# Last validation: Validate the input values against the schema.
|
||||
if error := schema.get_mismatch_error(data):
|
||||
error_message = f"{error_prefix} {error}"
|
||||
logger.error(error_message)
|
||||
return None, error_message
|
||||
|
||||
return data, node_block.name
|
||||
|
||||
|
||||
def merge_execution_input(data: BlockInput) -> BlockInput:
|
||||
"""
|
||||
Merges dynamic input pins into a single list, dictionary, or object based on naming patterns.
|
||||
|
||||
This function processes input keys that follow specific patterns to merge them into a unified structure:
|
||||
- `<input_name>_$_<index>` for list inputs.
|
||||
- `<input_name>_#_<index>` for dictionary inputs.
|
||||
- `<input_name>_@_<index>` for object inputs.
|
||||
|
||||
Args:
|
||||
data (BlockInput): A dictionary containing input keys and their corresponding values.
|
||||
|
||||
Returns:
|
||||
BlockInput: A dictionary with merged inputs.
|
||||
|
||||
Raises:
|
||||
ValueError: If a list index is not an integer.
|
||||
|
||||
Examples:
|
||||
>>> data = {
|
||||
... "list_$_0": "a",
|
||||
... "list_$_1": "b",
|
||||
... "dict_#_key1": "value1",
|
||||
... "dict_#_key2": "value2",
|
||||
... "object_@_attr1": "value1",
|
||||
... "object_@_attr2": "value2"
|
||||
... }
|
||||
>>> merge_execution_input(data)
|
||||
{
|
||||
"list": ["a", "b"],
|
||||
"dict": {"key1": "value1", "key2": "value2"},
|
||||
"object": <MockObject attr1="value1" attr2="value2">
|
||||
}
|
||||
"""
|
||||
|
||||
# Merge all input with <input_name>_$_<index> into a single list.
|
||||
items = list(data.items())
|
||||
|
||||
for key, value in items:
|
||||
if LIST_SPLIT not in key:
|
||||
continue
|
||||
name, index = key.split(LIST_SPLIT)
|
||||
if not index.isdigit():
|
||||
raise ValueError(f"Invalid key: {key}, #{index} index must be an integer.")
|
||||
|
||||
data[name] = data.get(name, [])
|
||||
if int(index) >= len(data[name]):
|
||||
# Pad list with empty string on missing indices.
|
||||
data[name].extend([""] * (int(index) - len(data[name]) + 1))
|
||||
data[name][int(index)] = value
|
||||
|
||||
# Merge all input with <input_name>_#_<index> into a single dict.
|
||||
for key, value in items:
|
||||
if DICT_SPLIT not in key:
|
||||
continue
|
||||
name, index = key.split(DICT_SPLIT)
|
||||
data[name] = data.get(name, {})
|
||||
data[name][index] = value
|
||||
|
||||
# Merge all input with <input_name>_@_<index> into a single object.
|
||||
for key, value in items:
|
||||
if OBJC_SPLIT not in key:
|
||||
continue
|
||||
name, index = key.split(OBJC_SPLIT)
|
||||
if name not in data or not isinstance(data[name], object):
|
||||
data[name] = MockObject()
|
||||
setattr(data[name], index, value)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _validate_node_input_credentials(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
):
|
||||
"""Checks all credentials for all nodes of the graph"""
|
||||
|
||||
for node in graph.nodes:
|
||||
block = node.block
|
||||
|
||||
# Find any fields of type CredentialsMetaInput
|
||||
credentials_fields = cast(
|
||||
type[BlockSchema], block.input_schema
|
||||
).get_credentials_fields()
|
||||
if not credentials_fields:
|
||||
continue
|
||||
|
||||
for field_name, credentials_meta_type in credentials_fields.items():
|
||||
if (
|
||||
node_credentials_input_map
|
||||
and (node_credentials_inputs := node_credentials_input_map.get(node.id))
|
||||
and field_name in node_credentials_inputs
|
||||
):
|
||||
credentials_meta = node_credentials_input_map[node.id][field_name]
|
||||
elif field_name in node.input_default:
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node.input_default[field_name]
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Credentials absent for {block.name} node #{node.id} "
|
||||
f"input '{field_name}'"
|
||||
)
|
||||
|
||||
# Fetch the corresponding Credentials and perform sanity checks
|
||||
credentials = get_integration_credentials_store().get_creds_by_id(
|
||||
user_id, credentials_meta.id
|
||||
)
|
||||
if not credentials:
|
||||
raise ValueError(
|
||||
f"Unknown credentials #{credentials_meta.id} "
|
||||
f"for node #{node.id} input '{field_name}'"
|
||||
)
|
||||
if (
|
||||
credentials.provider != credentials_meta.provider
|
||||
or credentials.type != credentials_meta.type
|
||||
):
|
||||
logger.warning(
|
||||
f"Invalid credentials #{credentials.id} for node #{node.id}: "
|
||||
"type/provider mismatch: "
|
||||
f"{credentials_meta.type}<>{credentials.type};"
|
||||
f"{credentials_meta.provider}<>{credentials.provider}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Invalid credentials #{credentials.id} for node #{node.id}: "
|
||||
"type/provider mismatch"
|
||||
)
|
||||
|
||||
|
||||
def make_node_credentials_input_map(
|
||||
graph: GraphModel,
|
||||
graph_credentials_input: dict[str, CredentialsMetaInput],
|
||||
) -> dict[str, dict[str, CredentialsMetaInput]]:
|
||||
"""
|
||||
Maps credentials for an execution to the correct nodes.
|
||||
|
||||
Params:
|
||||
graph: The graph to be executed.
|
||||
graph_credentials_input: A (graph_input_name, credentials_meta) map.
|
||||
|
||||
Returns:
|
||||
dict[node_id, dict[field_name, CredentialsMetaInput]]: Node credentials input map.
|
||||
"""
|
||||
result: dict[str, dict[str, CredentialsMetaInput]] = {}
|
||||
|
||||
# Get aggregated credentials fields for the graph
|
||||
graph_cred_inputs = graph.aggregate_credentials_inputs()
|
||||
|
||||
for graph_input_name, (_, compatible_node_fields) in graph_cred_inputs.items():
|
||||
# Best-effort map: skip missing items
|
||||
if graph_input_name not in graph_credentials_input:
|
||||
continue
|
||||
|
||||
# Use passed-in credentials for all compatible node input fields
|
||||
for node_id, node_field_name in compatible_node_fields:
|
||||
if node_id not in result:
|
||||
result[node_id] = {}
|
||||
result[node_id][node_field_name] = graph_credentials_input[graph_input_name]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def construct_node_execution_input(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
graph_inputs: BlockInput,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
) -> list[tuple[str, BlockInput]]:
|
||||
"""
|
||||
Validates and prepares the input data for executing a graph.
|
||||
This function checks the graph for starting nodes, validates the input data
|
||||
against the schema, and resolves dynamic input pins into a single list,
|
||||
dictionary, or object.
|
||||
|
||||
Args:
|
||||
graph (GraphModel): The graph model to execute.
|
||||
user_id (str): The ID of the user executing the graph.
|
||||
data (BlockInput): The input data for the graph execution.
|
||||
node_credentials_map: `dict[node_id, dict[input_name, CredentialsMetaInput]]`
|
||||
|
||||
Returns:
|
||||
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
|
||||
the corresponding input data for that node.
|
||||
"""
|
||||
graph.validate_graph(for_run=True)
|
||||
_validate_node_input_credentials(graph, user_id, node_credentials_input_map)
|
||||
|
||||
nodes_input = []
|
||||
for node in graph.starting_nodes:
|
||||
input_data = {}
|
||||
block = node.block
|
||||
|
||||
# Note block should never be executed.
|
||||
if block.block_type == BlockType.NOTE:
|
||||
continue
|
||||
|
||||
# Extract request input data, and assign it to the input pin.
|
||||
if block.block_type == BlockType.INPUT:
|
||||
input_name = node.input_default.get("name")
|
||||
if input_name and input_name in graph_inputs:
|
||||
input_data = {"value": graph_inputs[input_name]}
|
||||
|
||||
# Extract webhook payload, and assign it to the input pin
|
||||
webhook_payload_key = f"webhook_{node.webhook_id}_payload"
|
||||
if (
|
||||
block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
|
||||
and node.webhook_id
|
||||
):
|
||||
if webhook_payload_key not in graph_inputs:
|
||||
raise ValueError(
|
||||
f"Node {block.name} #{node.id} webhook payload is missing"
|
||||
)
|
||||
input_data = {"payload": graph_inputs[webhook_payload_key]}
|
||||
|
||||
# Apply node credentials overrides
|
||||
if node_credentials_input_map and (
|
||||
node_credentials := node_credentials_input_map.get(node.id)
|
||||
):
|
||||
input_data.update({k: v.model_dump() for k, v in node_credentials.items()})
|
||||
|
||||
input_data, error = validate_exec(node, input_data)
|
||||
if input_data is None:
|
||||
raise ValueError(error)
|
||||
else:
|
||||
nodes_input.append((node.id, input_data))
|
||||
|
||||
if not nodes_input:
|
||||
raise ValueError(
|
||||
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
|
||||
)
|
||||
|
||||
return nodes_input
|
||||
|
||||
|
||||
# ============ Execution Queue Helpers ============ #
|
||||
|
||||
|
||||
class CancelExecutionEvent(BaseModel):
|
||||
graph_exec_id: str
|
||||
|
||||
|
||||
GRAPH_EXECUTION_EXCHANGE = Exchange(
|
||||
name="graph_execution",
|
||||
type=ExchangeType.DIRECT,
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
)
|
||||
GRAPH_EXECUTION_QUEUE_NAME = "graph_execution_queue"
|
||||
GRAPH_EXECUTION_ROUTING_KEY = "graph_execution.run"
|
||||
|
||||
GRAPH_EXECUTION_CANCEL_EXCHANGE = Exchange(
|
||||
name="graph_execution_cancel",
|
||||
type=ExchangeType.FANOUT,
|
||||
durable=True,
|
||||
auto_delete=True,
|
||||
)
|
||||
GRAPH_EXECUTION_CANCEL_QUEUE_NAME = "graph_execution_cancel_queue"
|
||||
|
||||
|
||||
def create_execution_queue_config() -> RabbitMQConfig:
|
||||
"""
|
||||
Define two exchanges and queues:
|
||||
- 'graph_execution' (DIRECT) for run tasks.
|
||||
- 'graph_execution_cancel' (FANOUT) for cancel requests.
|
||||
"""
|
||||
run_queue = Queue(
|
||||
name=GRAPH_EXECUTION_QUEUE_NAME,
|
||||
exchange=GRAPH_EXECUTION_EXCHANGE,
|
||||
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
)
|
||||
cancel_queue = Queue(
|
||||
name=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
exchange=GRAPH_EXECUTION_CANCEL_EXCHANGE,
|
||||
routing_key="", # not used for FANOUT
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
)
|
||||
return RabbitMQConfig(
|
||||
vhost="/",
|
||||
exchanges=[GRAPH_EXECUTION_EXCHANGE, GRAPH_EXECUTION_CANCEL_EXCHANGE],
|
||||
queues=[run_queue, cancel_queue],
|
||||
)
|
||||
|
||||
|
||||
async def add_graph_execution_async(
|
||||
graph_id: str,
|
||||
user_id: str,
|
||||
inputs: BlockInput,
|
||||
preset_id: Optional[str] = None,
|
||||
graph_version: Optional[int] = None,
|
||||
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
|
||||
) -> GraphExecutionWithNodes:
|
||||
"""
|
||||
Adds a graph execution to the queue and returns the execution entry.
|
||||
|
||||
Args:
|
||||
graph_id: The ID of the graph to execute.
|
||||
user_id: The ID of the user executing the graph.
|
||||
inputs: The input data for the graph execution.
|
||||
preset_id: The ID of the preset to use.
|
||||
graph_version: The version of the graph to execute.
|
||||
graph_credentials_inputs: Credentials inputs to use in the execution.
|
||||
Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
|
||||
Returns:
|
||||
GraphExecutionEntry: The entry for the graph execution.
|
||||
Raises:
|
||||
ValueError: If the graph is not found or if there are validation errors.
|
||||
""" # noqa
|
||||
graph: GraphModel | None = await get_graph(
|
||||
graph_id=graph_id, user_id=user_id, version=graph_version
|
||||
)
|
||||
if not graph:
|
||||
raise NotFoundError(f"Graph #{graph_id} not found.")
|
||||
|
||||
node_credentials_input_map = (
|
||||
make_node_credentials_input_map(graph, graph_credentials_inputs)
|
||||
if graph_credentials_inputs
|
||||
else None
|
||||
)
|
||||
|
||||
graph_exec = await create_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
starting_nodes_input=construct_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=inputs,
|
||||
node_credentials_input_map=node_credentials_input_map,
|
||||
),
|
||||
preset_id=preset_id,
|
||||
)
|
||||
try:
|
||||
queue = await get_async_execution_queue()
|
||||
graph_exec_entry = graph_exec.to_graph_execution_entry()
|
||||
if node_credentials_input_map:
|
||||
graph_exec_entry.node_credentials_input_map = node_credentials_input_map
|
||||
await queue.publish_message(
|
||||
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
||||
message=graph_exec_entry.model_dump_json(),
|
||||
exchange=GRAPH_EXECUTION_EXCHANGE,
|
||||
)
|
||||
|
||||
bus = get_async_execution_event_bus()
|
||||
await bus.publish(graph_exec)
|
||||
|
||||
return graph_exec
|
||||
except Exception as e:
|
||||
logger.error(f"Unable to publish graph #{graph_id} exec #{graph_exec.id}: {e}")
|
||||
|
||||
await update_node_execution_status_batch(
|
||||
[node_exec.node_exec_id for node_exec in graph_exec.node_executions],
|
||||
ExecutionStatus.FAILED,
|
||||
)
|
||||
await update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.id,
|
||||
status=ExecutionStatus.FAILED,
|
||||
stats=GraphExecutionStats(error=str(e)),
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def add_graph_execution(
|
||||
graph_id: str,
|
||||
user_id: str,
|
||||
inputs: BlockInput,
|
||||
preset_id: Optional[str] = None,
|
||||
graph_version: Optional[int] = None,
|
||||
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
|
||||
) -> GraphExecutionWithNodes:
|
||||
"""
|
||||
Adds a graph execution to the queue and returns the execution entry.
|
||||
|
||||
Args:
|
||||
graph_id: The ID of the graph to execute.
|
||||
user_id: The ID of the user executing the graph.
|
||||
inputs: The input data for the graph execution.
|
||||
preset_id: The ID of the preset to use.
|
||||
graph_version: The version of the graph to execute.
|
||||
graph_credentials_inputs: Credentials inputs to use in the execution.
|
||||
Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
|
||||
Returns:
|
||||
GraphExecutionEntry: The entry for the graph execution.
|
||||
Raises:
|
||||
ValueError: If the graph is not found or if there are validation errors.
|
||||
"""
|
||||
db = get_db_client()
|
||||
graph: GraphModel | None = db.get_graph(
|
||||
graph_id=graph_id, user_id=user_id, version=graph_version
|
||||
)
|
||||
if not graph:
|
||||
raise NotFoundError(f"Graph #{graph_id} not found.")
|
||||
|
||||
node_credentials_input_map = (
|
||||
make_node_credentials_input_map(graph, graph_credentials_inputs)
|
||||
if graph_credentials_inputs
|
||||
else None
|
||||
)
|
||||
|
||||
graph_exec = db.create_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
starting_nodes_input=construct_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=inputs,
|
||||
node_credentials_input_map=node_credentials_input_map,
|
||||
),
|
||||
preset_id=preset_id,
|
||||
)
|
||||
try:
|
||||
queue = get_execution_queue()
|
||||
graph_exec_entry = graph_exec.to_graph_execution_entry()
|
||||
if node_credentials_input_map:
|
||||
graph_exec_entry.node_credentials_input_map = node_credentials_input_map
|
||||
queue.publish_message(
|
||||
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
||||
message=graph_exec_entry.model_dump_json(),
|
||||
exchange=GRAPH_EXECUTION_EXCHANGE,
|
||||
)
|
||||
|
||||
bus = get_execution_event_bus()
|
||||
bus.publish(graph_exec)
|
||||
|
||||
return graph_exec
|
||||
except Exception as e:
|
||||
logger.error(f"Unable to publish graph #{graph_id} exec #{graph_exec.id}: {e}")
|
||||
|
||||
db.update_node_execution_status_batch(
|
||||
[node_exec.node_exec_id for node_exec in graph_exec.node_executions],
|
||||
ExecutionStatus.FAILED,
|
||||
)
|
||||
db.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.id,
|
||||
status=ExecutionStatus.FAILED,
|
||||
stats=GraphExecutionStats(error=str(e)),
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -2,7 +2,6 @@ import logging
|
||||
from collections import defaultdict
|
||||
from typing import Annotated, Any, Dict, List, Optional, Sequence
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
from prisma.enums import AgentExecutionStatus, APIKeyPermission
|
||||
from typing_extensions import TypedDict
|
||||
@@ -13,17 +12,10 @@ from backend.data import graph as graph_db
|
||||
from backend.data.api_key import APIKey
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.data.execution import NodeExecutionResult
|
||||
from backend.executor import ExecutionManager
|
||||
from backend.executor.utils import add_graph_execution_async
|
||||
from backend.server.external.middleware import require_permission
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Settings
|
||||
|
||||
|
||||
@thread_cached
|
||||
def execution_manager_client() -> ExecutionManager:
|
||||
return get_service_client(ExecutionManager)
|
||||
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -98,20 +90,20 @@ def execute_graph_block(
|
||||
path="/graphs/{graph_id}/execute/{graph_version}",
|
||||
tags=["graphs"],
|
||||
)
|
||||
def execute_graph(
|
||||
async def execute_graph(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
|
||||
api_key: APIKey = Depends(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
graph_exec = execution_manager_client().add_execution(
|
||||
graph_id,
|
||||
graph_version=graph_version,
|
||||
data=node_input,
|
||||
graph_exec = await add_graph_execution_async(
|
||||
graph_id=graph_id,
|
||||
user_id=api_key.user_id,
|
||||
inputs=node_input,
|
||||
graph_version=graph_version,
|
||||
)
|
||||
return {"id": graph_exec.graph_exec_id}
|
||||
return {"id": graph_exec.id}
|
||||
except Exception as e:
|
||||
msg = str(e).encode().decode("unicode_escape")
|
||||
raise HTTPException(status_code=400, detail=msg)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Annotated, Literal
|
||||
from typing import TYPE_CHECKING, Annotated, Awaitable, Literal
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -14,13 +15,12 @@ from backend.data.integrations import (
|
||||
wait_for_webhook_event,
|
||||
)
|
||||
from backend.data.model import Credentials, CredentialsType, OAuth2Credentials
|
||||
from backend.executor.manager import ExecutionManager
|
||||
from backend.executor.utils import add_graph_execution_async
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
from backend.util.exceptions import NeedConfirmation, NotFoundError
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -309,19 +309,22 @@ async def webhook_ingress_generic(
|
||||
if not webhook.attached_nodes:
|
||||
return
|
||||
|
||||
executor = get_service_client(ExecutionManager)
|
||||
executions: list[Awaitable] = []
|
||||
for node in webhook.attached_nodes:
|
||||
logger.debug(f"Webhook-attached node: {node}")
|
||||
if not node.is_triggered_by_event_type(event_type):
|
||||
logger.debug(f"Node #{node.id} doesn't trigger on event {event_type}")
|
||||
continue
|
||||
logger.debug(f"Executing graph #{node.graph_id} node #{node.id}")
|
||||
executor.add_execution(
|
||||
graph_id=node.graph_id,
|
||||
graph_version=node.graph_version,
|
||||
data={f"webhook_{webhook_id}_payload": payload},
|
||||
user_id=webhook.user_id,
|
||||
executions.append(
|
||||
add_graph_execution_async(
|
||||
user_id=webhook.user_id,
|
||||
graph_id=node.graph_id,
|
||||
graph_version=node.graph_version,
|
||||
inputs={f"webhook_{webhook_id}_payload": payload},
|
||||
)
|
||||
)
|
||||
asyncio.gather(*executions)
|
||||
|
||||
|
||||
@router.post("/webhooks/{webhook_id}/ping")
|
||||
|
||||
@@ -17,7 +17,6 @@ import backend.data.block
|
||||
import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.user
|
||||
import backend.server.integrations.router
|
||||
import backend.server.routers.postmark.postmark
|
||||
import backend.server.routers.v1
|
||||
import backend.server.v2.admin.store_admin_routes
|
||||
@@ -29,6 +28,7 @@ import backend.server.v2.store.model
|
||||
import backend.server.v2.store.routes
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.server.external.api import external_app
|
||||
@@ -57,8 +57,7 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
await backend.data.block.initialize_blocks()
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
await backend.data.graph.fix_llm_provider_credentials()
|
||||
# FIXME ERROR: operator does not exist: text ? unknown
|
||||
# await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
|
||||
await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
|
||||
with launch_darkly_context():
|
||||
yield
|
||||
await backend.data.db.disconnect()
|
||||
@@ -156,11 +155,12 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
graph_version: Optional[int] = None,
|
||||
node_input: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
return backend.server.routers.v1.execute_graph(
|
||||
return await backend.server.routers.v1.execute_graph(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
node_input=node_input or {},
|
||||
inputs=node_input or {},
|
||||
credentials_inputs={},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -275,7 +275,9 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
provider: ProviderName,
|
||||
credentials: Credentials,
|
||||
) -> Credentials:
|
||||
return backend.server.integrations.router.create_credentials(
|
||||
from backend.server.integrations.router import create_credentials
|
||||
|
||||
return create_credentials(
|
||||
user_id=user_id, provider=provider, credentials=credentials
|
||||
)
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
|
||||
from typing_extensions import Optional, TypedDict
|
||||
|
||||
import backend.data.block
|
||||
import backend.server.integrations.router
|
||||
import backend.server.routers.analytics
|
||||
import backend.server.v2.library.db as library_db
|
||||
@@ -31,7 +30,7 @@ from backend.data.api_key import (
|
||||
suspend_api_key,
|
||||
update_api_key_permissions,
|
||||
)
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
|
||||
from backend.data.credit import (
|
||||
AutoTopUpConfig,
|
||||
RefundRequest,
|
||||
@@ -41,6 +40,8 @@ from backend.data.credit import (
|
||||
get_user_credit_model,
|
||||
set_auto_top_up,
|
||||
)
|
||||
from backend.data.execution import AsyncRedisExecutionEventBus
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
from backend.data.onboarding import (
|
||||
UserOnboardingUpdate,
|
||||
@@ -49,13 +50,16 @@ from backend.data.onboarding import (
|
||||
onboarding_enabled,
|
||||
update_user_onboarding,
|
||||
)
|
||||
from backend.data.rabbitmq import AsyncRabbitMQ
|
||||
from backend.data.user import (
|
||||
get_or_create_user,
|
||||
get_user_notification_preference,
|
||||
update_user_email,
|
||||
update_user_notification_preference,
|
||||
)
|
||||
from backend.executor import ExecutionManager, Scheduler, scheduler
|
||||
from backend.executor import Scheduler, scheduler
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.executor.utils import create_execution_queue_config
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks.graph_lifecycle_hooks import (
|
||||
on_graph_activate,
|
||||
@@ -79,13 +83,20 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
@thread_cached
|
||||
def execution_manager_client() -> ExecutionManager:
|
||||
return get_service_client(ExecutionManager)
|
||||
def execution_scheduler_client() -> Scheduler:
|
||||
return get_service_client(Scheduler)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def execution_scheduler_client() -> Scheduler:
|
||||
return get_service_client(Scheduler)
|
||||
async def execution_queue_client() -> AsyncRabbitMQ:
|
||||
client = AsyncRabbitMQ(create_execution_queue_config())
|
||||
await client.connect()
|
||||
return client
|
||||
|
||||
|
||||
@thread_cached
|
||||
def execution_event_bus() -> AsyncRedisExecutionEventBus:
|
||||
return AsyncRedisExecutionEventBus()
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -206,7 +217,7 @@ async def is_onboarding_enabled():
|
||||
|
||||
@v1_router.get(path="/blocks", tags=["blocks"], dependencies=[Depends(auth_middleware)])
|
||||
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
blocks = [block() for block in backend.data.block.get_blocks().values()]
|
||||
blocks = [block() for block in get_blocks().values()]
|
||||
costs = get_block_costs()
|
||||
return [
|
||||
{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks if not b.disabled
|
||||
@@ -219,7 +230,7 @@ def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlockOutput:
|
||||
obj = backend.data.block.get_block(block_id)
|
||||
obj = get_block(block_id)
|
||||
if not obj:
|
||||
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
|
||||
|
||||
@@ -308,7 +319,7 @@ async def configure_user_auto_top_up(
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_user_auto_top_up(
|
||||
user_id: Annotated[str, Depends(get_user_id)]
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> AutoTopUpConfig:
|
||||
return await get_auto_top_up(user_id)
|
||||
|
||||
@@ -375,7 +386,7 @@ async def get_credit_history(
|
||||
|
||||
@v1_router.get(path="/credits/refunds", dependencies=[Depends(auth_middleware)])
|
||||
async def get_refund_requests(
|
||||
user_id: Annotated[str, Depends(get_user_id)]
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> list[RefundRequest]:
|
||||
return await _user_credit_model.get_refund_requests(user_id)
|
||||
|
||||
@@ -391,7 +402,7 @@ class DeleteGraphResponse(TypedDict):
|
||||
|
||||
@v1_router.get(path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)])
|
||||
async def get_graphs(
|
||||
user_id: Annotated[str, Depends(get_user_id)]
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> Sequence[graph_db.GraphModel]:
|
||||
return await graph_db.get_graphs(filter_by="active", user_id=user_id)
|
||||
|
||||
@@ -580,16 +591,25 @@ async def set_graph_active_version(
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
def execute_graph(
|
||||
async def execute_graph(
|
||||
graph_id: str,
|
||||
node_input: Annotated[dict[str, Any], Body(..., default_factory=dict)],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
inputs: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
|
||||
credentials_inputs: Annotated[
|
||||
dict[str, CredentialsMetaInput], Body(..., embed=True, default_factory=dict)
|
||||
],
|
||||
graph_version: Optional[int] = None,
|
||||
preset_id: Optional[str] = None,
|
||||
) -> ExecuteGraphResponse:
|
||||
graph_exec = execution_manager_client().add_execution(
|
||||
graph_id, node_input, user_id=user_id, graph_version=graph_version
|
||||
graph_exec = await execution_utils.add_graph_execution_async(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
inputs=inputs,
|
||||
preset_id=preset_id,
|
||||
graph_version=graph_version,
|
||||
graph_credentials_inputs=credentials_inputs,
|
||||
)
|
||||
return ExecuteGraphResponse(graph_exec_id=graph_exec.graph_exec_id)
|
||||
return ExecuteGraphResponse(graph_exec_id=graph_exec.id)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
@@ -605,9 +625,7 @@ async def stop_graph_run(
|
||||
):
|
||||
raise HTTPException(404, detail=f"Agent execution #{graph_exec_id} not found")
|
||||
|
||||
await asyncio.to_thread(
|
||||
lambda: execution_manager_client().cancel_execution(graph_exec_id)
|
||||
)
|
||||
await _cancel_execution(graph_exec_id)
|
||||
|
||||
# Retrieve & return canceled graph execution in its final state
|
||||
result = await execution_db.get_graph_execution(
|
||||
@@ -621,6 +639,49 @@ async def stop_graph_run(
|
||||
return result
|
||||
|
||||
|
||||
async def _cancel_execution(graph_exec_id: str):
|
||||
"""
|
||||
Mechanism:
|
||||
1. Set the cancel event
|
||||
2. Graph executor's cancel handler thread detects the event, terminates workers,
|
||||
reinitializes worker pool, and returns.
|
||||
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
|
||||
"""
|
||||
queue_client = await execution_queue_client()
|
||||
await queue_client.publish_message(
|
||||
routing_key="",
|
||||
message=execution_utils.CancelExecutionEvent(
|
||||
graph_exec_id=graph_exec_id
|
||||
).model_dump_json(),
|
||||
exchange=execution_utils.GRAPH_EXECUTION_CANCEL_EXCHANGE,
|
||||
)
|
||||
|
||||
# Update the status of the graph & node executions
|
||||
await execution_db.update_graph_execution_stats(
|
||||
graph_exec_id,
|
||||
execution_db.ExecutionStatus.TERMINATED,
|
||||
)
|
||||
node_execs = [
|
||||
node_exec.model_copy(update={"status": execution_db.ExecutionStatus.TERMINATED})
|
||||
for node_exec in await execution_db.get_node_execution_results(
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=[
|
||||
execution_db.ExecutionStatus.QUEUED,
|
||||
execution_db.ExecutionStatus.RUNNING,
|
||||
execution_db.ExecutionStatus.INCOMPLETE,
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
await execution_db.update_node_execution_status_batch(
|
||||
[node_exec.node_exec_id for node_exec in node_execs],
|
||||
execution_db.ExecutionStatus.TERMINATED,
|
||||
)
|
||||
await asyncio.gather(
|
||||
*[execution_event_bus().publish(node_exec) for node_exec in node_execs]
|
||||
)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/executions",
|
||||
tags=["graphs"],
|
||||
@@ -792,7 +853,7 @@ async def create_api_key(
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_api_keys(
|
||||
user_id: Annotated[str, Depends(get_user_id)]
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> list[APIKeyWithoutHash]:
|
||||
"""List all API keys for the user"""
|
||||
try:
|
||||
|
||||
@@ -6,7 +6,6 @@ import prisma.errors
|
||||
import prisma.fields
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
from prisma.types import AgentPresetCreateInput
|
||||
|
||||
import backend.data.graph
|
||||
import backend.server.model
|
||||
@@ -69,12 +68,12 @@ async def list_library_agents(
|
||||
if search_term:
|
||||
where_clause["OR"] = [
|
||||
{
|
||||
"Agent": {
|
||||
"AgentGraph": {
|
||||
"is": {"name": {"contains": search_term, "mode": "insensitive"}}
|
||||
}
|
||||
},
|
||||
{
|
||||
"Agent": {
|
||||
"AgentGraph": {
|
||||
"is": {
|
||||
"description": {"contains": search_term, "mode": "insensitive"}
|
||||
}
|
||||
@@ -233,7 +232,8 @@ async def create_library_agent(
|
||||
isCreatedByUser=(user_id == graph.user_id),
|
||||
useGraphIsActiveVersion=True,
|
||||
User={"connect": {"id": user_id}},
|
||||
Agent={
|
||||
# Creator={"connect": {"id": agent.userId}},
|
||||
AgentGraph={
|
||||
"connect": {
|
||||
"graphVersionId": {"id": graph.id, "version": graph.version}
|
||||
}
|
||||
@@ -247,38 +247,41 @@ async def create_library_agent(
|
||||
|
||||
async def update_agent_version_in_library(
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
agent_version: int,
|
||||
agent_graph_id: str,
|
||||
agent_graph_version: int,
|
||||
) -> None:
|
||||
"""
|
||||
Updates the agent version in the library if useGraphIsActiveVersion is True.
|
||||
|
||||
Args:
|
||||
user_id: Owner of the LibraryAgent.
|
||||
agent_id: The agent's ID to update.
|
||||
agent_version: The new version of the agent.
|
||||
agent_graph_id: The agent graph's ID to update.
|
||||
agent_graph_version: The new version of the agent graph.
|
||||
|
||||
Raises:
|
||||
DatabaseError: If there's an error with the update.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Updating agent version in library for user #{user_id}, "
|
||||
f"agent #{agent_id} v{agent_version}"
|
||||
f"agent #{agent_graph_id} v{agent_graph_version}"
|
||||
)
|
||||
try:
|
||||
library_agent = await prisma.models.LibraryAgent.prisma().find_first_or_raise(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"agentId": agent_id,
|
||||
"agentGraphId": agent_graph_id,
|
||||
"useGraphIsActiveVersion": True,
|
||||
},
|
||||
)
|
||||
await prisma.models.LibraryAgent.prisma().update(
|
||||
where={"id": library_agent.id},
|
||||
data={
|
||||
"Agent": {
|
||||
"AgentGraph": {
|
||||
"connect": {
|
||||
"graphVersionId": {"id": agent_id, "version": agent_version}
|
||||
"graphVersionId": {
|
||||
"id": agent_graph_id,
|
||||
"version": agent_graph_version,
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -342,7 +345,7 @@ async def delete_library_agent_by_graph_id(graph_id: str, user_id: str) -> None:
|
||||
"""
|
||||
try:
|
||||
await prisma.models.LibraryAgent.prisma().delete_many(
|
||||
where={"agentId": graph_id, "userId": user_id}
|
||||
where={"agentGraphId": graph_id, "userId": user_id}
|
||||
)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error deleting library agent: {e}")
|
||||
@@ -375,10 +378,10 @@ async def add_store_agent_to_library(
|
||||
async with locked_transaction(f"add_agent_trx_{user_id}"):
|
||||
store_listing_version = (
|
||||
await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id}, include={"Agent": True}
|
||||
where={"id": store_listing_version_id}, include={"AgentGraph": True}
|
||||
)
|
||||
)
|
||||
if not store_listing_version or not store_listing_version.Agent:
|
||||
if not store_listing_version or not store_listing_version.AgentGraph:
|
||||
logger.warning(
|
||||
f"Store listing version not found: {store_listing_version_id}"
|
||||
)
|
||||
@@ -386,7 +389,7 @@ async def add_store_agent_to_library(
|
||||
f"Store listing version {store_listing_version_id} not found or invalid"
|
||||
)
|
||||
|
||||
graph = store_listing_version.Agent
|
||||
graph = store_listing_version.AgentGraph
|
||||
if graph.userId == user_id:
|
||||
logger.warning(
|
||||
f"User #{user_id} attempted to add their own agent to their library"
|
||||
@@ -398,8 +401,8 @@ async def add_store_agent_to_library(
|
||||
await prisma.models.LibraryAgent.prisma().find_first(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"agentId": graph.id,
|
||||
"agentVersion": graph.version,
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
},
|
||||
include=library_agent_include(user_id),
|
||||
)
|
||||
@@ -421,15 +424,15 @@ async def add_store_agent_to_library(
|
||||
added_agent = await prisma.models.LibraryAgent.prisma().create(
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
userId=user_id,
|
||||
agentId=graph.id,
|
||||
agentVersion=graph.version,
|
||||
agentGraphId=graph.id,
|
||||
agentGraphVersion=graph.version,
|
||||
isCreatedByUser=False,
|
||||
),
|
||||
include=library_agent_include(user_id),
|
||||
)
|
||||
logger.debug(
|
||||
f"Added graph #{graph.id} "
|
||||
f"for store listing #{store_listing_version.id} "
|
||||
f"Added graph #{graph.id} v{graph.version}"
|
||||
f"for store listing version #{store_listing_version.id} "
|
||||
f"to library for user #{user_id}"
|
||||
)
|
||||
return library_model.LibraryAgent.from_db(added_agent)
|
||||
@@ -468,8 +471,8 @@ async def set_is_deleted_for_library_agent(
|
||||
count = await prisma.models.LibraryAgent.prisma().update_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"agentId": agent_id,
|
||||
"agentVersion": agent_version,
|
||||
"agentGraphId": agent_id,
|
||||
"agentGraphVersion": agent_version,
|
||||
},
|
||||
data={"isDeleted": is_deleted},
|
||||
)
|
||||
@@ -598,21 +601,22 @@ async def upsert_preset(
|
||||
f"Upserting preset #{preset_id} ({repr(preset.name)}) for user #{user_id}",
|
||||
)
|
||||
try:
|
||||
inputs = [
|
||||
prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput(
|
||||
name=name, data=prisma.fields.Json(data)
|
||||
)
|
||||
for name, data in preset.inputs.items()
|
||||
]
|
||||
if preset_id:
|
||||
# Update existing preset
|
||||
updated = await prisma.models.AgentPreset.prisma().update(
|
||||
where={"id": preset_id},
|
||||
data=AgentPresetCreateInput(
|
||||
name=preset.name,
|
||||
description=preset.description,
|
||||
isActive=preset.is_active,
|
||||
InputPresets={
|
||||
"create": [
|
||||
{"name": name, "data": prisma.fields.Json(data)}
|
||||
for name, data in preset.inputs.items()
|
||||
]
|
||||
},
|
||||
),
|
||||
data={
|
||||
"name": preset.name,
|
||||
"description": preset.description,
|
||||
"isActive": preset.is_active,
|
||||
"InputPresets": {"create": inputs},
|
||||
},
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
if not updated:
|
||||
@@ -625,15 +629,10 @@ async def upsert_preset(
|
||||
userId=user_id,
|
||||
name=preset.name,
|
||||
description=preset.description,
|
||||
agentId=preset.agent_id,
|
||||
agentVersion=preset.agent_version,
|
||||
agentGraphId=preset.graph_id,
|
||||
agentGraphVersion=preset.graph_version,
|
||||
isActive=preset.is_active,
|
||||
InputPresets={
|
||||
"create": [
|
||||
{"name": name, "data": prisma.fields.Json(data)}
|
||||
for name, data in preset.inputs.items()
|
||||
]
|
||||
},
|
||||
InputPresets={"create": inputs},
|
||||
),
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
|
||||
@@ -30,8 +30,8 @@ async def test_get_library_agents(mocker):
|
||||
prisma.models.LibraryAgent(
|
||||
id="ua1",
|
||||
userId="test-user",
|
||||
agentId="agent2",
|
||||
agentVersion=1,
|
||||
agentGraphId="agent2",
|
||||
agentGraphVersion=1,
|
||||
isCreatedByUser=False,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
@@ -39,7 +39,7 @@ async def test_get_library_agents(mocker):
|
||||
updatedAt=datetime.now(),
|
||||
isFavorite=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
Agent=prisma.models.AgentGraph(
|
||||
AgentGraph=prisma.models.AgentGraph(
|
||||
id="agent2",
|
||||
version=1,
|
||||
name="Test Agent 2",
|
||||
@@ -71,8 +71,8 @@ async def test_get_library_agents(mocker):
|
||||
assert result.agents[0].id == "ua1"
|
||||
assert result.agents[0].name == "Test Agent 2"
|
||||
assert result.agents[0].description == "Test Description 2"
|
||||
assert result.agents[0].agent_id == "agent2"
|
||||
assert result.agents[0].agent_version == 1
|
||||
assert result.agents[0].graph_id == "agent2"
|
||||
assert result.agents[0].graph_version == 1
|
||||
assert result.agents[0].can_access_graph is False
|
||||
assert result.agents[0].is_latest_version is True
|
||||
assert result.pagination.total_items == 1
|
||||
@@ -90,8 +90,8 @@ async def test_add_agent_to_library(mocker):
|
||||
version=1,
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
agentId="agent1",
|
||||
agentVersion=1,
|
||||
agentGraphId="agent1",
|
||||
agentGraphVersion=1,
|
||||
name="Test Agent",
|
||||
subHeading="Test Agent Subheading",
|
||||
imageUrls=["https://example.com/image.jpg"],
|
||||
@@ -102,7 +102,7 @@ async def test_add_agent_to_library(mocker):
|
||||
isAvailable=True,
|
||||
storeListingId="listing123",
|
||||
submissionStatus=prisma.enums.SubmissionStatus.APPROVED,
|
||||
Agent=prisma.models.AgentGraph(
|
||||
AgentGraph=prisma.models.AgentGraph(
|
||||
id="agent1",
|
||||
version=1,
|
||||
name="Test Agent",
|
||||
@@ -116,8 +116,8 @@ async def test_add_agent_to_library(mocker):
|
||||
mock_library_agent_data = prisma.models.LibraryAgent(
|
||||
id="ua1",
|
||||
userId="test-user",
|
||||
agentId=mock_store_listing_data.agentId,
|
||||
agentVersion=1,
|
||||
agentGraphId=mock_store_listing_data.agentGraphId,
|
||||
agentGraphVersion=1,
|
||||
isCreatedByUser=False,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
@@ -125,7 +125,7 @@ async def test_add_agent_to_library(mocker):
|
||||
updatedAt=datetime.now(),
|
||||
isFavorite=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
Agent=mock_store_listing_data.Agent,
|
||||
AgentGraph=mock_store_listing_data.AgentGraph,
|
||||
)
|
||||
|
||||
# Mock prisma calls
|
||||
@@ -147,19 +147,22 @@ async def test_add_agent_to_library(mocker):
|
||||
|
||||
# Verify mocks called correctly
|
||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||
where={"id": "version123"}, include={"Agent": True}
|
||||
where={"id": "version123"}, include={"AgentGraph": True}
|
||||
)
|
||||
mock_library_agent.return_value.find_first.assert_called_once_with(
|
||||
where={
|
||||
"userId": "test-user",
|
||||
"agentId": "agent1",
|
||||
"agentVersion": 1,
|
||||
"agentGraphId": "agent1",
|
||||
"agentGraphVersion": 1,
|
||||
},
|
||||
include=library_agent_include("test-user"),
|
||||
)
|
||||
mock_library_agent.return_value.create.assert_called_once_with(
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
userId="test-user", agentId="agent1", agentVersion=1, isCreatedByUser=False
|
||||
userId="test-user",
|
||||
agentGraphId="agent1",
|
||||
agentGraphVersion=1,
|
||||
isCreatedByUser=False,
|
||||
),
|
||||
include=library_agent_include("test-user"),
|
||||
)
|
||||
@@ -182,5 +185,5 @@ async def test_add_agent_to_library_not_found(mocker):
|
||||
|
||||
# Verify mock called correctly
|
||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||
where={"id": "version123"}, include={"Agent": True}
|
||||
where={"id": "version123"}, include={"AgentGraph": True}
|
||||
)
|
||||
|
||||
@@ -25,8 +25,8 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
"""
|
||||
|
||||
id: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
|
||||
image_url: str | None
|
||||
|
||||
@@ -58,12 +58,12 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
Factory method that constructs a LibraryAgent from a Prisma LibraryAgent
|
||||
model instance.
|
||||
"""
|
||||
if not agent.Agent:
|
||||
if not agent.AgentGraph:
|
||||
raise ValueError("Associated Agent record is required.")
|
||||
|
||||
graph = graph_model.GraphModel.from_db(agent.Agent)
|
||||
graph = graph_model.GraphModel.from_db(agent.AgentGraph)
|
||||
|
||||
agent_updated_at = agent.Agent.updatedAt
|
||||
agent_updated_at = agent.AgentGraph.updatedAt
|
||||
lib_agent_updated_at = agent.updatedAt
|
||||
|
||||
# Compute updated_at as the latest between library agent and graph
|
||||
@@ -83,21 +83,21 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
week_ago = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
|
||||
days=7
|
||||
)
|
||||
executions = agent.Agent.AgentGraphExecution or []
|
||||
executions = agent.AgentGraph.Executions or []
|
||||
status_result = _calculate_agent_status(executions, week_ago)
|
||||
status = status_result.status
|
||||
new_output = status_result.new_output
|
||||
|
||||
# Check if user can access the graph
|
||||
can_access_graph = agent.Agent.userId == agent.userId
|
||||
can_access_graph = agent.AgentGraph.userId == agent.userId
|
||||
|
||||
# Hard-coded to True until a method to check is implemented
|
||||
is_latest_version = True
|
||||
|
||||
return LibraryAgent(
|
||||
id=agent.id,
|
||||
agent_id=agent.agentId,
|
||||
agent_version=agent.agentVersion,
|
||||
graph_id=agent.agentGraphId,
|
||||
graph_version=agent.agentGraphVersion,
|
||||
image_url=agent.imageUrl,
|
||||
creator_name=creator_name,
|
||||
creator_image_url=creator_image_url,
|
||||
@@ -174,8 +174,8 @@ class LibraryAgentPreset(pydantic.BaseModel):
|
||||
id: str
|
||||
updated_at: datetime.datetime
|
||||
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
|
||||
name: str
|
||||
description: str
|
||||
@@ -194,8 +194,8 @@ class LibraryAgentPreset(pydantic.BaseModel):
|
||||
return cls(
|
||||
id=preset.id,
|
||||
updated_at=preset.updatedAt,
|
||||
agent_id=preset.agentId,
|
||||
agent_version=preset.agentVersion,
|
||||
graph_id=preset.agentGraphId,
|
||||
graph_version=preset.agentGraphVersion,
|
||||
name=preset.name,
|
||||
description=preset.description,
|
||||
is_active=preset.isActive,
|
||||
@@ -218,8 +218,8 @@ class CreateLibraryAgentPresetRequest(pydantic.BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
inputs: block_model.BlockInput
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
is_active: bool
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import prisma.models
|
||||
import pytest
|
||||
|
||||
import backend.server.v2.library.model as library_model
|
||||
from backend.util import json
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -15,8 +14,8 @@ async def test_agent_preset_from_db():
|
||||
id="test-agent-123",
|
||||
createdAt=datetime.datetime.now(),
|
||||
updatedAt=datetime.datetime.now(),
|
||||
agentId="agent-123",
|
||||
agentVersion=1,
|
||||
agentGraphId="agent-123",
|
||||
agentGraphVersion=1,
|
||||
name="Test Agent",
|
||||
description="Test agent description",
|
||||
isActive=True,
|
||||
@@ -27,7 +26,7 @@ async def test_agent_preset_from_db():
|
||||
id="input-123",
|
||||
time=datetime.datetime.now(),
|
||||
name="input1",
|
||||
data=json.dumps({"type": "string", "value": "test value"}), # type: ignore
|
||||
data=prisma.Json({"type": "string", "value": "test value"}),
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -36,7 +35,7 @@ async def test_agent_preset_from_db():
|
||||
agent = library_model.LibraryAgentPreset.from_db(db_agent)
|
||||
|
||||
assert agent.id == "test-agent-123"
|
||||
assert agent.agent_version == 1
|
||||
assert agent.graph_version == 1
|
||||
assert agent.is_active is True
|
||||
assert agent.name == "Test Agent"
|
||||
assert agent.description == "Test agent description"
|
||||
|
||||
@@ -2,25 +2,17 @@ import logging
|
||||
from typing import Annotated, Any
|
||||
|
||||
import autogpt_libs.auth as autogpt_auth_lib
|
||||
import autogpt_libs.utils.cache
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, status
|
||||
|
||||
import backend.executor
|
||||
import backend.server.v2.library.db as db
|
||||
import backend.server.v2.library.model as models
|
||||
import backend.util.service
|
||||
from backend.executor.utils import add_graph_execution_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@autogpt_libs.utils.cache.thread_cached
|
||||
def execution_manager_client() -> backend.executor.ExecutionManager:
|
||||
"""Return a cached instance of ExecutionManager client."""
|
||||
return backend.util.service.get_service_client(backend.executor.ExecutionManager)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/presets",
|
||||
summary="List presets",
|
||||
@@ -226,17 +218,17 @@ async def execute_preset(
|
||||
# Merge input overrides with preset inputs
|
||||
merged_node_input = preset.inputs | node_input
|
||||
|
||||
execution = execution_manager_client().add_execution(
|
||||
execution = await add_graph_execution_async(
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
data=merged_node_input,
|
||||
user_id=user_id,
|
||||
inputs=merged_node_input,
|
||||
preset_id=preset_id,
|
||||
graph_version=graph_version,
|
||||
)
|
||||
|
||||
logger.debug(f"Execution added: {execution} with input: {merged_node_input}")
|
||||
|
||||
return {"id": execution.graph_exec_id}
|
||||
return {"id": execution.id}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
@@ -35,8 +35,8 @@ async def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
|
||||
agents=[
|
||||
library_model.LibraryAgent(
|
||||
id="test-agent-1",
|
||||
agent_id="test-agent-1",
|
||||
agent_version=1,
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
@@ -51,8 +51,8 @@ async def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
|
||||
),
|
||||
library_model.LibraryAgent(
|
||||
id="test-agent-2",
|
||||
agent_id="test-agent-2",
|
||||
agent_version=1,
|
||||
graph_id="test-agent-2",
|
||||
graph_version=1,
|
||||
name="Test Agent 2",
|
||||
description="Test Description 2",
|
||||
image_url=None,
|
||||
@@ -78,9 +78,9 @@ async def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
|
||||
|
||||
data = library_model.LibraryAgentResponse.model_validate(response.json())
|
||||
assert len(data.agents) == 2
|
||||
assert data.agents[0].agent_id == "test-agent-1"
|
||||
assert data.agents[0].graph_id == "test-agent-1"
|
||||
assert data.agents[0].can_access_graph is True
|
||||
assert data.agents[1].agent_id == "test-agent-2"
|
||||
assert data.agents[1].graph_id == "test-agent-2"
|
||||
assert data.agents[1].can_access_graph is False
|
||||
mock_db_call.assert_called_once_with(
|
||||
user_id="test-user-id",
|
||||
|
||||
@@ -200,17 +200,17 @@ async def get_available_graph(
|
||||
"isAvailable": True,
|
||||
"isDeleted": False,
|
||||
},
|
||||
include={"Agent": {"include": {"AgentNodes": True}}},
|
||||
include={"AgentGraph": {"include": {"Nodes": True}}},
|
||||
)
|
||||
)
|
||||
|
||||
if not store_listing_version or not store_listing_version.Agent:
|
||||
if not store_listing_version or not store_listing_version.AgentGraph:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Store listing version {store_listing_version_id} not found",
|
||||
)
|
||||
|
||||
graph = GraphModel.from_db(store_listing_version.Agent)
|
||||
graph = GraphModel.from_db(store_listing_version.AgentGraph)
|
||||
# We return graph meta, without nodes, they cannot be just removed
|
||||
# because then input_schema would be empty
|
||||
return {
|
||||
@@ -516,7 +516,7 @@ async def delete_store_submission(
|
||||
try:
|
||||
# Verify the submission belongs to this user
|
||||
submission = await prisma.models.StoreListing.prisma().find_first(
|
||||
where={"agentId": submission_id, "owningUserId": user_id}
|
||||
where={"agentGraphId": submission_id, "owningUserId": user_id}
|
||||
)
|
||||
|
||||
if not submission:
|
||||
@@ -598,7 +598,7 @@ async def create_store_submission(
|
||||
# Check if listing already exists for this agent
|
||||
existing_listing = await prisma.models.StoreListing.prisma().find_first(
|
||||
where=prisma.types.StoreListingWhereInput(
|
||||
agentId=agent_id, owningUserId=user_id
|
||||
agentGraphId=agent_id, owningUserId=user_id
|
||||
)
|
||||
)
|
||||
|
||||
@@ -625,15 +625,15 @@ async def create_store_submission(
|
||||
# If no existing listing, create a new one
|
||||
data = prisma.types.StoreListingCreateInput(
|
||||
slug=slug,
|
||||
agentId=agent_id,
|
||||
agentVersion=agent_version,
|
||||
agentGraphId=agent_id,
|
||||
agentGraphVersion=agent_version,
|
||||
owningUserId=user_id,
|
||||
createdAt=datetime.now(tz=timezone.utc),
|
||||
Versions={
|
||||
"create": [
|
||||
prisma.types.StoreListingVersionCreateInput(
|
||||
agentId=agent_id,
|
||||
agentVersion=agent_version,
|
||||
agentGraphId=agent_id,
|
||||
agentGraphVersion=agent_version,
|
||||
name=name,
|
||||
videoUrl=video_url,
|
||||
imageUrls=image_urls,
|
||||
@@ -758,8 +758,8 @@ async def create_store_version(
|
||||
new_version = await prisma.models.StoreListingVersion.prisma().create(
|
||||
data=prisma.types.StoreListingVersionCreateInput(
|
||||
version=next_version,
|
||||
agentId=agent_id,
|
||||
agentVersion=agent_version,
|
||||
agentGraphId=agent_id,
|
||||
agentGraphVersion=agent_version,
|
||||
name=name,
|
||||
videoUrl=video_url,
|
||||
imageUrls=image_urls,
|
||||
@@ -959,17 +959,17 @@ async def get_my_agents(
|
||||
try:
|
||||
search_filter: prisma.types.LibraryAgentWhereInput = {
|
||||
"userId": user_id,
|
||||
"Agent": {"is": {"StoreListing": {"none": {"isDeleted": False}}}},
|
||||
"AgentGraph": {"is": {"StoreListings": {"none": {"isDeleted": False}}}},
|
||||
"isArchived": False,
|
||||
"isDeleted": False,
|
||||
}
|
||||
|
||||
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
|
||||
where=search_filter,
|
||||
order=[{"agentVersion": "desc"}],
|
||||
order=[{"agentGraphVersion": "desc"}],
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
include={"Agent": True},
|
||||
include={"AgentGraph": True},
|
||||
)
|
||||
|
||||
total = await prisma.models.LibraryAgent.prisma().count(where=search_filter)
|
||||
@@ -985,7 +985,7 @@ async def get_my_agents(
|
||||
agent_image=library_agent.imageUrl,
|
||||
)
|
||||
for library_agent in library_agents
|
||||
if (graph := library_agent.Agent)
|
||||
if (graph := library_agent.AgentGraph)
|
||||
]
|
||||
|
||||
return backend.server.v2.store.model.MyAgentsResponse(
|
||||
@@ -1020,13 +1020,13 @@ async def get_agent(
|
||||
|
||||
graph = await backend.data.graph.get_graph(
|
||||
user_id=user_id,
|
||||
graph_id=store_listing_version.agentId,
|
||||
version=store_listing_version.agentVersion,
|
||||
graph_id=store_listing_version.agentGraphId,
|
||||
version=store_listing_version.agentGraphVersion,
|
||||
for_export=True,
|
||||
)
|
||||
if not graph:
|
||||
raise ValueError(
|
||||
f"Agent {store_listing_version.agentId} v{store_listing_version.agentVersion} not found"
|
||||
f"Agent {store_listing_version.agentGraphId} v{store_listing_version.agentGraphVersion} not found"
|
||||
)
|
||||
|
||||
return graph
|
||||
@@ -1050,11 +1050,14 @@ async def _get_missing_sub_store_listing(
|
||||
|
||||
# Fetch all the sub-graphs that are listed, and return the ones missing.
|
||||
store_listed_sub_graphs = {
|
||||
(listing.agentId, listing.agentVersion)
|
||||
(listing.agentGraphId, listing.agentGraphVersion)
|
||||
for listing in await prisma.models.StoreListingVersion.prisma().find_many(
|
||||
where={
|
||||
"OR": [
|
||||
{"agentId": sub_graph.id, "agentVersion": sub_graph.version}
|
||||
{
|
||||
"agentGraphId": sub_graph.id,
|
||||
"agentGraphVersion": sub_graph.version,
|
||||
}
|
||||
for sub_graph in sub_graphs
|
||||
],
|
||||
"submissionStatus": prisma.enums.SubmissionStatus.APPROVED,
|
||||
@@ -1084,7 +1087,7 @@ async def review_store_submission(
|
||||
where={"id": store_listing_version_id},
|
||||
include={
|
||||
"StoreListing": True,
|
||||
"Agent": {"include": AGENT_GRAPH_INCLUDE}, # type: ignore
|
||||
"AgentGraph": {"include": AGENT_GRAPH_INCLUDE},
|
||||
},
|
||||
)
|
||||
)
|
||||
@@ -1096,23 +1099,23 @@ async def review_store_submission(
|
||||
)
|
||||
|
||||
# If approving, update the listing to indicate it has an approved version
|
||||
if is_approved and store_listing_version.Agent:
|
||||
heading = f"Sub-graph of {store_listing_version.name}v{store_listing_version.agentVersion}"
|
||||
if is_approved and store_listing_version.AgentGraph:
|
||||
heading = f"Sub-graph of {store_listing_version.name}v{store_listing_version.agentGraphVersion}"
|
||||
|
||||
sub_store_listing_versions = [
|
||||
prisma.types.StoreListingVersionCreateWithoutRelationsInput(
|
||||
agentId=sub_graph.id,
|
||||
agentVersion=sub_graph.version,
|
||||
agentGraphId=sub_graph.id,
|
||||
agentGraphVersion=sub_graph.version,
|
||||
name=sub_graph.name or heading,
|
||||
submissionStatus=prisma.enums.SubmissionStatus.APPROVED,
|
||||
subHeading=heading,
|
||||
description=f"{heading}: {sub_graph.description}",
|
||||
changesSummary=f"This listing is added as a {heading} / #{store_listing_version.agentId}.",
|
||||
changesSummary=f"This listing is added as a {heading} / #{store_listing_version.agentGraphId}.",
|
||||
isAvailable=False, # Hide sub-graphs from the store by default.
|
||||
submittedAt=datetime.now(tz=timezone.utc),
|
||||
)
|
||||
for sub_graph in await _get_missing_sub_store_listing(
|
||||
store_listing_version.Agent
|
||||
store_listing_version.AgentGraph
|
||||
)
|
||||
]
|
||||
|
||||
@@ -1155,8 +1158,8 @@ async def review_store_submission(
|
||||
|
||||
# Convert to Pydantic model for consistency
|
||||
return backend.server.v2.store.model.StoreSubmission(
|
||||
agent_id=submission.agentId,
|
||||
agent_version=submission.agentVersion,
|
||||
agent_id=submission.agentGraphId,
|
||||
agent_version=submission.agentGraphVersion,
|
||||
name=submission.name,
|
||||
sub_heading=submission.subHeading,
|
||||
slug=(
|
||||
@@ -1294,8 +1297,8 @@ async def get_admin_listings_with_versions(
|
||||
# If we have versions, turn them into StoreSubmission models
|
||||
for version in listing.Versions or []:
|
||||
version_model = backend.server.v2.store.model.StoreSubmission(
|
||||
agent_id=version.agentId,
|
||||
agent_version=version.agentVersion,
|
||||
agent_id=version.agentGraphId,
|
||||
agent_version=version.agentGraphVersion,
|
||||
name=version.name,
|
||||
sub_heading=version.subHeading,
|
||||
slug=listing.slug,
|
||||
@@ -1324,8 +1327,8 @@ async def get_admin_listings_with_versions(
|
||||
backend.server.v2.store.model.StoreListingWithVersions(
|
||||
listing_id=listing.id,
|
||||
slug=listing.slug,
|
||||
agent_id=listing.agentId,
|
||||
agent_version=listing.agentVersion,
|
||||
agent_id=listing.agentGraphId,
|
||||
agent_version=listing.agentGraphVersion,
|
||||
active_version_id=listing.activeVersionId,
|
||||
has_approved_version=listing.hasApprovedVersion,
|
||||
creator_email=creator_email,
|
||||
|
||||
@@ -170,14 +170,14 @@ async def test_create_store_submission(mocker):
|
||||
isDeleted=False,
|
||||
hasApprovedVersion=False,
|
||||
slug="test-agent",
|
||||
agentId="agent-id",
|
||||
agentVersion=1,
|
||||
agentGraphId="agent-id",
|
||||
agentGraphVersion=1,
|
||||
owningUserId="user-id",
|
||||
Versions=[
|
||||
prisma.models.StoreListingVersion(
|
||||
id="version-id",
|
||||
agentId="agent-id",
|
||||
agentVersion=1,
|
||||
agentGraphId="agent-id",
|
||||
agentGraphVersion=1,
|
||||
name="Test Agent",
|
||||
description="Test description",
|
||||
createdAt=datetime.now(),
|
||||
|
||||
@@ -15,21 +15,25 @@ def to_dict(data) -> dict:
|
||||
|
||||
|
||||
def dumps(data) -> str:
|
||||
return json.dumps(jsonable_encoder(data))
|
||||
return json.dumps(to_dict(data))
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@overload
|
||||
def loads(data: str, *args, target_type: Type[T], **kwargs) -> T: ...
|
||||
def loads(data: str | bytes, *args, target_type: Type[T], **kwargs) -> T: ...
|
||||
|
||||
|
||||
@overload
|
||||
def loads(data: str, *args, **kwargs) -> Any: ...
|
||||
def loads(data: str | bytes, *args, **kwargs) -> Any: ...
|
||||
|
||||
|
||||
def loads(data: str, *args, target_type: Type[T] | None = None, **kwargs) -> Any:
|
||||
def loads(
|
||||
data: str | bytes, *args, target_type: Type[T] | None = None, **kwargs
|
||||
) -> Any:
|
||||
if isinstance(data, bytes):
|
||||
data = data.decode("utf-8")
|
||||
parsed = json.loads(data, *args, **kwargs)
|
||||
if target_type:
|
||||
return type_match(parsed, target_type)
|
||||
|
||||
@@ -14,9 +14,7 @@ def sentry_init():
|
||||
traces_sample_rate=1.0,
|
||||
profiles_sample_rate=1.0,
|
||||
environment=f"app:{Settings().config.app_env.value}-behave:{Settings().config.behave_as.value}",
|
||||
_experiments={
|
||||
"enable_logs": True,
|
||||
},
|
||||
_experiments={"enable_logs": True},
|
||||
integrations=[
|
||||
LoggingIntegration(sentry_logs_level=logging.INFO),
|
||||
AnthropicIntegration(
|
||||
|
||||
@@ -25,6 +25,7 @@ from pydantic import BaseModel, TypeAdapter, create_model
|
||||
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.json import to_dict
|
||||
from backend.util.metrics import sentry_init
|
||||
from backend.util.process import AppProcess, get_service_name
|
||||
from backend.util.retry import conn_retry
|
||||
from backend.util.settings import Config
|
||||
@@ -196,6 +197,7 @@ class AppService(BaseAppService, ABC):
|
||||
self.shared_event_loop.run_until_complete(server.serve())
|
||||
|
||||
def run(self):
|
||||
sentry_init()
|
||||
super().run()
|
||||
self.fastapi_app = FastAPI()
|
||||
|
||||
|
||||
@@ -137,6 +137,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
default=8002,
|
||||
description="The port for execution manager daemon to run on",
|
||||
)
|
||||
execution_manager_loop_max_retry: int = Field(
|
||||
default=5,
|
||||
description="The maximum number of retries for the execution manager loop",
|
||||
)
|
||||
|
||||
execution_scheduler_port: int = Field(
|
||||
default=8003,
|
||||
|
||||
@@ -182,6 +182,7 @@ def _try_convert(value: Any, target_type: Type, raise_on_mismatch: bool) -> Any:
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
TT = TypeVar("TT")
|
||||
|
||||
|
||||
def type_match(value: Any, target_type: Type[T]) -> T:
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
/*
|
||||
Warnings:
|
||||
- The relation LibraryAgent:AgentPreset was REMOVED
|
||||
- A unique constraint covering the columns `[userId,agentGraphId,agentGraphVersion]` on the table `LibraryAgent` will be added. If there are existing duplicate values, this will fail.
|
||||
- The foreign key constraints on AgentPreset and LibraryAgent are being changed from CASCADE to RESTRICT for AgentGraph deletion, which means you cannot delete AgentGraphs that have associated LibraryAgents or AgentPresets.
|
||||
|
||||
Use the following query to check whether these conditions are satisfied:
|
||||
|
||||
-- Check for duplicate LibraryAgent userId + agentGraphId + agentGraphVersion combinations that would violate the new unique constraint
|
||||
SELECT la."userId",
|
||||
la."agentId" as graph_id,
|
||||
la."agentVersion" as graph_version,
|
||||
COUNT(*) as multiplicity
|
||||
FROM "LibraryAgent" la
|
||||
GROUP BY la."userId",
|
||||
la."agentId",
|
||||
la."agentVersion"
|
||||
HAVING COUNT(*) > 1;
|
||||
*/
|
||||
|
||||
-- Drop foreign key constraints on columns we're about to rename
|
||||
ALTER TABLE "AgentPreset" DROP CONSTRAINT "AgentPreset_agentId_agentVersion_fkey";
|
||||
ALTER TABLE "LibraryAgent" DROP CONSTRAINT "LibraryAgent_agentId_agentVersion_fkey";
|
||||
ALTER TABLE "LibraryAgent" DROP CONSTRAINT "LibraryAgent_agentPresetId_fkey";
|
||||
|
||||
-- Rename columns in AgentPreset
|
||||
ALTER TABLE "AgentPreset" RENAME COLUMN "agentId" TO "agentGraphId";
|
||||
ALTER TABLE "AgentPreset" RENAME COLUMN "agentVersion" TO "agentGraphVersion";
|
||||
|
||||
-- Rename columns in LibraryAgent
|
||||
ALTER TABLE "LibraryAgent" RENAME COLUMN "agentId" TO "agentGraphId";
|
||||
ALTER TABLE "LibraryAgent" RENAME COLUMN "agentVersion" TO "agentGraphVersion";
|
||||
|
||||
-- Drop LibraryAgent.agentPresetId column
|
||||
ALTER TABLE "LibraryAgent" DROP COLUMN "agentPresetId";
|
||||
|
||||
-- Replace userId index with unique index on userId + agentGraphId + agentGraphVersion
|
||||
DROP INDEX "LibraryAgent_userId_idx";
|
||||
CREATE UNIQUE INDEX "LibraryAgent_userId_agentGraphId_agentGraphVersion_key" ON "LibraryAgent"("userId", "agentGraphId", "agentGraphVersion");
|
||||
|
||||
-- Re-add the foreign key constraints with new column names
|
||||
ALTER TABLE "LibraryAgent" ADD CONSTRAINT "LibraryAgent_agentGraphId_agentGraphVersion_fkey"
|
||||
FOREIGN KEY ("agentGraphId", "agentGraphVersion") REFERENCES "AgentGraph"("id", "version")
|
||||
ON DELETE RESTRICT -- Disallow deleting AgentGraph when still referenced by existing LibraryAgents
|
||||
ON UPDATE CASCADE;
|
||||
|
||||
ALTER TABLE "AgentPreset" ADD CONSTRAINT "AgentPreset_agentGraphId_agentGraphVersion_fkey"
|
||||
FOREIGN KEY ("agentGraphId", "agentGraphVersion") REFERENCES "AgentGraph"("id", "version")
|
||||
ON DELETE RESTRICT -- Disallow deleting AgentGraph when still referenced by existing AgentPresets
|
||||
ON UPDATE CASCADE;
|
||||
@@ -0,0 +1,35 @@
|
||||
/*
|
||||
- Rename column StoreListing.agentId to agentGraphId
|
||||
- Rename column StoreListing.agentVersion to agentGraphVersion
|
||||
- Rename column StoreListingVersion.agentId to agentGraphId
|
||||
- Rename column StoreListingVersion.agentVersion to agentGraphVersion
|
||||
*/
|
||||
|
||||
-- Drop foreign key constraints on columns we're about to rename
|
||||
ALTER TABLE "StoreListing" DROP CONSTRAINT "StoreListing_agentId_agentVersion_fkey";
|
||||
ALTER TABLE "StoreListingVersion" DROP CONSTRAINT "StoreListingVersion_agentId_agentVersion_fkey";
|
||||
|
||||
-- Drop indices on columns we're about to rename
|
||||
DROP INDEX "StoreListing_agentId_key";
|
||||
DROP INDEX "StoreListingVersion_agentId_agentVersion_idx";
|
||||
|
||||
-- Rename columns
|
||||
ALTER TABLE "StoreListing" RENAME COLUMN "agentId" TO "agentGraphId";
|
||||
ALTER TABLE "StoreListing" RENAME COLUMN "agentVersion" TO "agentGraphVersion";
|
||||
ALTER TABLE "StoreListingVersion" RENAME COLUMN "agentId" TO "agentGraphId";
|
||||
ALTER TABLE "StoreListingVersion" RENAME COLUMN "agentVersion" TO "agentGraphVersion";
|
||||
|
||||
-- Re-create indices with updated name on renamed columns
|
||||
CREATE UNIQUE INDEX "StoreListing_agentGraphId_key" ON "StoreListing"("agentGraphId");
|
||||
CREATE INDEX "StoreListingVersion_agentGraphId_agentGraphVersion_idx" ON "StoreListingVersion"("agentGraphId", "agentGraphVersion");
|
||||
|
||||
-- Re-create foreign key constraints with updated name on renamed columns
|
||||
ALTER TABLE "StoreListing" ADD CONSTRAINT "StoreListing_agentGraphId_agentGraphVersion_fkey"
|
||||
FOREIGN KEY ("agentGraphId", "agentGraphVersion") REFERENCES "AgentGraph"("id", "version")
|
||||
ON DELETE CASCADE
|
||||
ON UPDATE CASCADE;
|
||||
|
||||
ALTER TABLE "StoreListingVersion" ADD CONSTRAINT "StoreListingVersion_agentGraphId_agentGraphVersion_fkey"
|
||||
FOREIGN KEY ("agentGraphId", "agentGraphVersion") REFERENCES "AgentGraph"("id", "version")
|
||||
ON DELETE RESTRICT
|
||||
ON UPDATE CASCADE;
|
||||
@@ -0,0 +1,16 @@
|
||||
-- Modify the OnboardingStep enum
|
||||
ALTER TYPE "OnboardingStep" ADD VALUE 'GET_RESULTS';
|
||||
ALTER TYPE "OnboardingStep" ADD VALUE 'MARKETPLACE_VISIT';
|
||||
ALTER TYPE "OnboardingStep" ADD VALUE 'MARKETPLACE_ADD_AGENT';
|
||||
ALTER TYPE "OnboardingStep" ADD VALUE 'MARKETPLACE_RUN_AGENT';
|
||||
ALTER TYPE "OnboardingStep" ADD VALUE 'BUILDER_OPEN';
|
||||
ALTER TYPE "OnboardingStep" ADD VALUE 'BUILDER_SAVE_AGENT';
|
||||
ALTER TYPE "OnboardingStep" ADD VALUE 'BUILDER_RUN_AGENT';
|
||||
|
||||
-- Modify the UserOnboarding table
|
||||
ALTER TABLE "UserOnboarding"
|
||||
ADD COLUMN "updatedAt" TIMESTAMP(3),
|
||||
ADD COLUMN "notificationDot" BOOLEAN NOT NULL DEFAULT true,
|
||||
ADD COLUMN "notified" "OnboardingStep"[] DEFAULT '{}',
|
||||
ADD COLUMN "rewardedFor" "OnboardingStep"[] DEFAULT '{}',
|
||||
ADD COLUMN "onboardingAgentExecutionId" TEXT
|
||||
@@ -0,0 +1,56 @@
|
||||
-- Backfill nulls with empty arrays
|
||||
UPDATE "UserOnboarding"
|
||||
SET "integrations" = ARRAY[]::TEXT[]
|
||||
WHERE "integrations" IS NULL;
|
||||
|
||||
UPDATE "UserOnboarding"
|
||||
SET "completedSteps" = '{}'
|
||||
WHERE "completedSteps" IS NULL;
|
||||
|
||||
UPDATE "UserOnboarding"
|
||||
SET "notified" = '{}'
|
||||
WHERE "notified" IS NULL;
|
||||
|
||||
UPDATE "UserOnboarding"
|
||||
SET "rewardedFor" = '{}'
|
||||
WHERE "rewardedFor" IS NULL;
|
||||
|
||||
UPDATE "IntegrationWebhook"
|
||||
SET "events" = ARRAY[]::TEXT[]
|
||||
WHERE "events" IS NULL;
|
||||
|
||||
UPDATE "APIKey"
|
||||
SET "permissions" = '{}'
|
||||
WHERE "permissions" IS NULL;
|
||||
|
||||
UPDATE "Profile"
|
||||
SET "links" = ARRAY[]::TEXT[]
|
||||
WHERE "links" IS NULL;
|
||||
|
||||
UPDATE "StoreListingVersion"
|
||||
SET "imageUrls" = ARRAY[]::TEXT[]
|
||||
WHERE "imageUrls" IS NULL;
|
||||
|
||||
UPDATE "StoreListingVersion"
|
||||
SET "categories" = ARRAY[]::TEXT[]
|
||||
WHERE "categories" IS NULL;
|
||||
|
||||
-- Enforce NOT NULL constraints
|
||||
ALTER TABLE "UserOnboarding"
|
||||
ALTER COLUMN "integrations" SET NOT NULL,
|
||||
ALTER COLUMN "completedSteps" SET NOT NULL,
|
||||
ALTER COLUMN "notified" SET NOT NULL,
|
||||
ALTER COLUMN "rewardedFor" SET NOT NULL;
|
||||
|
||||
ALTER TABLE "IntegrationWebhook"
|
||||
ALTER COLUMN "events" SET NOT NULL;
|
||||
|
||||
ALTER TABLE "APIKey"
|
||||
ALTER COLUMN "permissions" SET NOT NULL;
|
||||
|
||||
ALTER TABLE "Profile"
|
||||
ALTER COLUMN "links" SET NOT NULL;
|
||||
|
||||
ALTER TABLE "StoreListingVersion"
|
||||
ALTER COLUMN "imageUrls" SET NOT NULL,
|
||||
ALTER COLUMN "categories" SET NOT NULL;
|
||||
@@ -1,11 +1,12 @@
|
||||
datasource db {
|
||||
provider = "postgresql"
|
||||
url = env("DATABASE_URL")
|
||||
provider = "postgresql"
|
||||
url = env("DATABASE_URL")
|
||||
directUrl = env("DIRECT_URL")
|
||||
}
|
||||
|
||||
generator client {
|
||||
provider = "prisma-client-py"
|
||||
recursive_type_depth = 5
|
||||
recursive_type_depth = -1
|
||||
interface = "asyncio"
|
||||
previewFeatures = ["views"]
|
||||
}
|
||||
@@ -39,25 +40,26 @@ model User {
|
||||
AgentGraphExecutions AgentGraphExecution[]
|
||||
AnalyticsDetails AnalyticsDetails[]
|
||||
AnalyticsMetrics AnalyticsMetrics[]
|
||||
CreditTransaction CreditTransaction[]
|
||||
CreditTransactions CreditTransaction[]
|
||||
|
||||
AgentPreset AgentPreset[]
|
||||
LibraryAgent LibraryAgent[]
|
||||
AgentPresets AgentPreset[]
|
||||
LibraryAgents LibraryAgent[]
|
||||
|
||||
Profile Profile[]
|
||||
UserOnboarding UserOnboarding?
|
||||
StoreListing StoreListing[]
|
||||
StoreListingReview StoreListingReview[]
|
||||
StoreListings StoreListing[]
|
||||
StoreListingReviews StoreListingReview[]
|
||||
StoreVersionsReviewed StoreListingVersion[]
|
||||
APIKeys APIKey[]
|
||||
IntegrationWebhooks IntegrationWebhook[]
|
||||
UserNotificationBatch UserNotificationBatch[]
|
||||
NotificationBatches UserNotificationBatch[]
|
||||
|
||||
@@index([id])
|
||||
@@index([email])
|
||||
}
|
||||
|
||||
enum OnboardingStep {
|
||||
// Introductory onboarding (Library)
|
||||
WELCOME
|
||||
USAGE_REASON
|
||||
INTEGRATIONS
|
||||
@@ -65,18 +67,32 @@ enum OnboardingStep {
|
||||
AGENT_NEW_RUN
|
||||
AGENT_INPUT
|
||||
CONGRATS
|
||||
GET_RESULTS
|
||||
// Marketplace
|
||||
MARKETPLACE_VISIT
|
||||
MARKETPLACE_ADD_AGENT
|
||||
MARKETPLACE_RUN_AGENT
|
||||
// Builder
|
||||
BUILDER_OPEN
|
||||
BUILDER_SAVE_AGENT
|
||||
BUILDER_RUN_AGENT
|
||||
}
|
||||
|
||||
model UserOnboarding {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime? @updatedAt
|
||||
|
||||
completedSteps OnboardingStep[] @default([])
|
||||
notificationDot Boolean @default(true)
|
||||
notified OnboardingStep[] @default([])
|
||||
rewardedFor OnboardingStep[] @default([])
|
||||
usageReason String?
|
||||
integrations String[] @default([])
|
||||
otherIntegrations String?
|
||||
selectedStoreListingVersionId String?
|
||||
agentInput Json?
|
||||
onboardingAgentExecutionId String?
|
||||
|
||||
userId String @unique
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
@@ -102,13 +118,13 @@ model AgentGraph {
|
||||
// This allows us to delete user data with deleting the agent which maybe in use by other users
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
AgentNodes AgentNode[]
|
||||
AgentGraphExecution AgentGraphExecution[]
|
||||
Nodes AgentNode[]
|
||||
Executions AgentGraphExecution[]
|
||||
|
||||
AgentPreset AgentPreset[]
|
||||
LibraryAgent LibraryAgent[]
|
||||
StoreListing StoreListing[]
|
||||
StoreListingVersion StoreListingVersion[]
|
||||
Presets AgentPreset[]
|
||||
LibraryAgents LibraryAgent[]
|
||||
StoreListings StoreListing[]
|
||||
StoreListingVersions StoreListingVersion[]
|
||||
|
||||
@@id(name: "graphVersionId", [id, version])
|
||||
@@index([userId, isActive])
|
||||
@@ -139,13 +155,12 @@ model AgentPreset {
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
agentId String
|
||||
agentVersion Int
|
||||
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)
|
||||
agentGraphId String
|
||||
agentGraphVersion Int
|
||||
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version], onDelete: Restrict)
|
||||
|
||||
InputPresets AgentNodeExecutionInputOutput[] @relation("AgentPresetsInputData")
|
||||
LibraryAgents LibraryAgent[]
|
||||
AgentExecution AgentGraphExecution[]
|
||||
InputPresets AgentNodeExecutionInputOutput[] @relation("AgentPresetsInputData")
|
||||
Executions AgentGraphExecution[]
|
||||
|
||||
isDeleted Boolean @default(false)
|
||||
|
||||
@@ -194,7 +209,7 @@ model UserNotificationBatch {
|
||||
}
|
||||
|
||||
// For the library page
|
||||
// It is a user controlled list of agents, that they will see in there library
|
||||
// It is a user controlled list of agents, that they will see in their library
|
||||
model LibraryAgent {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
@@ -205,12 +220,9 @@ model LibraryAgent {
|
||||
|
||||
imageUrl String?
|
||||
|
||||
agentId String
|
||||
agentVersion Int
|
||||
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version])
|
||||
|
||||
agentPresetId String?
|
||||
AgentPreset AgentPreset? @relation(fields: [agentPresetId], references: [id])
|
||||
agentGraphId String
|
||||
agentGraphVersion Int
|
||||
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version], onDelete: Restrict)
|
||||
|
||||
creatorId String?
|
||||
Creator Profile? @relation(fields: [creatorId], references: [id])
|
||||
@@ -222,7 +234,7 @@ model LibraryAgent {
|
||||
isArchived Boolean @default(false)
|
||||
isDeleted Boolean @default(false)
|
||||
|
||||
@@index([userId])
|
||||
@@unique([userId, agentGraphId, agentGraphVersion])
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
@@ -256,7 +268,7 @@ model AgentNode {
|
||||
|
||||
metadata Json @default("{}")
|
||||
|
||||
ExecutionHistory AgentNodeExecution[]
|
||||
Executions AgentNodeExecution[]
|
||||
|
||||
@@index([agentGraphId, agentGraphVersion])
|
||||
@@index([agentBlockId])
|
||||
@@ -323,15 +335,15 @@ model AgentGraphExecution {
|
||||
agentGraphVersion Int @default(1)
|
||||
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version], onDelete: Cascade)
|
||||
|
||||
AgentNodeExecutions AgentNodeExecution[]
|
||||
NodeExecutions AgentNodeExecution[]
|
||||
|
||||
// Link to User model -- Executed by this user
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
stats Json?
|
||||
AgentPreset AgentPreset? @relation(fields: [agentPresetId], references: [id])
|
||||
agentPresetId String?
|
||||
AgentPreset AgentPreset? @relation(fields: [agentPresetId], references: [id])
|
||||
|
||||
@@index([agentGraphId, agentGraphVersion])
|
||||
@@index([userId])
|
||||
@@ -342,10 +354,10 @@ model AgentNodeExecution {
|
||||
id String @id @default(uuid())
|
||||
|
||||
agentGraphExecutionId String
|
||||
AgentGraphExecution AgentGraphExecution @relation(fields: [agentGraphExecutionId], references: [id], onDelete: Cascade)
|
||||
GraphExecution AgentGraphExecution @relation(fields: [agentGraphExecutionId], references: [id], onDelete: Cascade)
|
||||
|
||||
agentNodeId String
|
||||
AgentNode AgentNode @relation(fields: [agentNodeId], references: [id], onDelete: Cascade)
|
||||
Node AgentNode @relation(fields: [agentNodeId], references: [id], onDelete: Cascade)
|
||||
|
||||
Input AgentNodeExecutionInputOutput[] @relation("AgentNodeExecutionInput")
|
||||
Output AgentNodeExecutionInputOutput[] @relation("AgentNodeExecutionOutput")
|
||||
@@ -627,9 +639,9 @@ model StoreListing {
|
||||
ActiveVersion StoreListingVersion? @relation("ActiveVersion", fields: [activeVersionId], references: [id])
|
||||
|
||||
// The agent link here is only so we can do lookup on agentId
|
||||
agentId String
|
||||
agentVersion Int
|
||||
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)
|
||||
agentGraphId String
|
||||
agentGraphVersion Int
|
||||
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version], onDelete: Cascade)
|
||||
|
||||
owningUserId String
|
||||
OwningUser User @relation(fields: [owningUserId], references: [id])
|
||||
@@ -638,7 +650,7 @@ model StoreListing {
|
||||
Versions StoreListingVersion[] @relation("ListingVersions")
|
||||
|
||||
// Unique index on agentId to ensure only one listing per agent, regardless of number of versions the agent has.
|
||||
@@unique([agentId])
|
||||
@@unique([agentGraphId])
|
||||
@@unique([owningUserId, slug])
|
||||
// Used in the view query
|
||||
@@index([isDeleted, hasApprovedVersion])
|
||||
@@ -651,9 +663,9 @@ model StoreListingVersion {
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
|
||||
// The agent and version to be listed on the store
|
||||
agentId String
|
||||
agentVersion Int
|
||||
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version])
|
||||
agentGraphId String
|
||||
agentGraphVersion Int
|
||||
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version])
|
||||
|
||||
// Content fields
|
||||
name String
|
||||
@@ -697,7 +709,7 @@ model StoreListingVersion {
|
||||
@@index([storeListingId, submissionStatus, isAvailable])
|
||||
@@index([submissionStatus])
|
||||
@@index([reviewerId])
|
||||
@@index([agentId, agentVersion]) // Non-unique index for efficient lookups
|
||||
@@index([agentGraphId, agentGraphVersion]) // Non-unique index for efficient lookups
|
||||
}
|
||||
|
||||
model StoreListingReview {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from backend.data.execution import merge_execution_input, parse_execution_output
|
||||
from backend.executor.utils import merge_execution_input, parse_execution_output
|
||||
|
||||
|
||||
def test_parse_execution_output():
|
||||
|
||||
@@ -360,8 +360,8 @@ async def test_execute_preset(server: SpinTestServer):
|
||||
preset = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
|
||||
name="Test Preset With Clash",
|
||||
description="Test preset with clashing input values",
|
||||
agent_id=test_graph.id,
|
||||
agent_version=test_graph.version,
|
||||
graph_id=test_graph.id,
|
||||
graph_version=test_graph.version,
|
||||
inputs={
|
||||
"dictionary": {"key1": "Hello", "key2": "World"},
|
||||
"selected_value": "key2",
|
||||
@@ -449,8 +449,8 @@ async def test_execute_preset_with_clash(server: SpinTestServer):
|
||||
preset = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
|
||||
name="Test Preset With Clash",
|
||||
description="Test preset with clashing input values",
|
||||
agent_id=test_graph.id,
|
||||
agent_version=test_graph.version,
|
||||
graph_id=test_graph.id,
|
||||
graph_version=test_graph.version,
|
||||
inputs={
|
||||
"dictionary": {"key1": "Hello", "key2": "World"},
|
||||
"selected_value": "key2",
|
||||
|
||||
@@ -10,16 +10,12 @@ from prisma.types import (
|
||||
AgentGraphCreateInput,
|
||||
AgentNodeCreateInput,
|
||||
AgentNodeLinkCreateInput,
|
||||
AgentPresetCreateInput,
|
||||
AnalyticsDetailsCreateInput,
|
||||
AnalyticsMetricsCreateInput,
|
||||
APIKeyCreateInput,
|
||||
CreditTransactionCreateInput,
|
||||
LibraryAgentCreateInput,
|
||||
ProfileCreateInput,
|
||||
StoreListingCreateInput,
|
||||
StoreListingReviewCreateInput,
|
||||
StoreListingVersionCreateInput,
|
||||
UserCreateInput,
|
||||
)
|
||||
|
||||
@@ -140,14 +136,14 @@ async def main():
|
||||
for _ in range(num_presets): # Create 1 AgentPreset per user
|
||||
graph = random.choice(agent_graphs)
|
||||
preset = await db.agentpreset.create(
|
||||
data=AgentPresetCreateInput(
|
||||
name=faker.sentence(nb_words=3),
|
||||
description=faker.text(max_nb_chars=200),
|
||||
userId=user.id,
|
||||
agentId=graph.id,
|
||||
agentVersion=graph.version,
|
||||
isActive=True,
|
||||
)
|
||||
data={
|
||||
"name": faker.sentence(nb_words=3),
|
||||
"description": faker.text(max_nb_chars=200),
|
||||
"userId": user.id,
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
"isActive": True,
|
||||
}
|
||||
)
|
||||
agent_presets.append(preset)
|
||||
|
||||
@@ -160,16 +156,15 @@ async def main():
|
||||
graph = random.choice(agent_graphs)
|
||||
preset = random.choice(agent_presets)
|
||||
user_agent = await db.libraryagent.create(
|
||||
data=LibraryAgentCreateInput(
|
||||
userId=user.id,
|
||||
agentId=graph.id,
|
||||
agentVersion=graph.version,
|
||||
agentPresetId=preset.id,
|
||||
isFavorite=random.choice([True, False]),
|
||||
isCreatedByUser=random.choice([True, False]),
|
||||
isArchived=random.choice([True, False]),
|
||||
isDeleted=random.choice([True, False]),
|
||||
)
|
||||
data={
|
||||
"userId": user.id,
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
"isFavorite": random.choice([True, False]),
|
||||
"isCreatedByUser": random.choice([True, False]),
|
||||
"isArchived": random.choice([True, False]),
|
||||
"isDeleted": random.choice([True, False]),
|
||||
}
|
||||
)
|
||||
user_agents.append(user_agent)
|
||||
|
||||
@@ -346,13 +341,13 @@ async def main():
|
||||
user = random.choice(users)
|
||||
slug = faker.slug()
|
||||
listing = await db.storelisting.create(
|
||||
data=StoreListingCreateInput(
|
||||
agentId=graph.id,
|
||||
agentVersion=graph.version,
|
||||
owningUserId=user.id,
|
||||
hasApprovedVersion=random.choice([True, False]),
|
||||
slug=slug,
|
||||
)
|
||||
data={
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
"owningUserId": user.id,
|
||||
"hasApprovedVersion": random.choice([True, False]),
|
||||
"slug": slug,
|
||||
}
|
||||
)
|
||||
store_listings.append(listing)
|
||||
|
||||
@@ -362,26 +357,26 @@ async def main():
|
||||
for listing in store_listings:
|
||||
graph = [g for g in agent_graphs if g.id == listing.agentId][0]
|
||||
version = await db.storelistingversion.create(
|
||||
data=StoreListingVersionCreateInput(
|
||||
agentId=graph.id,
|
||||
agentVersion=graph.version,
|
||||
name=graph.name or faker.sentence(nb_words=3),
|
||||
subHeading=faker.sentence(),
|
||||
videoUrl=faker.url(),
|
||||
imageUrls=[get_image() for _ in range(3)],
|
||||
description=faker.text(),
|
||||
categories=[faker.word() for _ in range(3)],
|
||||
isFeatured=random.choice([True, False]),
|
||||
isAvailable=True,
|
||||
storeListingId=listing.id,
|
||||
submissionStatus=random.choice(
|
||||
data={
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
"name": graph.name or faker.sentence(nb_words=3),
|
||||
"subHeading": faker.sentence(),
|
||||
"videoUrl": faker.url(),
|
||||
"imageUrls": [get_image() for _ in range(3)],
|
||||
"description": faker.text(),
|
||||
"categories": [faker.word() for _ in range(3)],
|
||||
"isFeatured": random.choice([True, False]),
|
||||
"isAvailable": True,
|
||||
"storeListingId": listing.id,
|
||||
"submissionStatus": random.choice(
|
||||
[
|
||||
prisma.enums.SubmissionStatus.PENDING,
|
||||
prisma.enums.SubmissionStatus.APPROVED,
|
||||
prisma.enums.SubmissionStatus.REJECTED,
|
||||
]
|
||||
),
|
||||
)
|
||||
}
|
||||
)
|
||||
store_listing_versions.append(version)
|
||||
|
||||
@@ -422,23 +417,12 @@ async def main():
|
||||
)
|
||||
await db.storelistingversion.update(
|
||||
where={"id": version.id},
|
||||
data=StoreListingVersionCreateInput(
|
||||
submissionStatus=status,
|
||||
Reviewer={"connect": {"id": reviewer.id}},
|
||||
reviewComments=faker.text(),
|
||||
reviewedAt=datetime.now(),
|
||||
agentId=version.agentId, # preserving existing fields
|
||||
agentVersion=version.agentVersion,
|
||||
name=version.name,
|
||||
subHeading=version.subHeading,
|
||||
videoUrl=version.videoUrl,
|
||||
imageUrls=version.imageUrls,
|
||||
description=version.description,
|
||||
categories=version.categories,
|
||||
isFeatured=version.isFeatured,
|
||||
isAvailable=version.isAvailable,
|
||||
storeListingId=version.storeListingId,
|
||||
),
|
||||
data={
|
||||
"submissionStatus": status,
|
||||
"Reviewer": {"connect": {"id": reviewer.id}},
|
||||
"reviewComments": faker.text(),
|
||||
"reviewedAt": datetime.now(),
|
||||
},
|
||||
)
|
||||
|
||||
# Insert APIKeys
|
||||
|
||||
@@ -15,6 +15,7 @@ services:
|
||||
condition: service_healthy
|
||||
environment:
|
||||
- DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
|
||||
- DIRECT_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
|
||||
networks:
|
||||
- app-network
|
||||
restart: on-failure
|
||||
@@ -77,6 +78,7 @@ services:
|
||||
- SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
- SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
|
||||
- DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
|
||||
- DIRECT_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
|
||||
- REDIS_HOST=redis
|
||||
- REDIS_PORT=6379
|
||||
- RABBITMQ_HOST=rabbitmq
|
||||
@@ -126,6 +128,7 @@ services:
|
||||
- SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
- SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
|
||||
- DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
|
||||
- DIRECT_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
|
||||
- REDIS_HOST=redis
|
||||
- REDIS_PORT=6379
|
||||
- REDIS_PASSWORD=password
|
||||
@@ -167,6 +170,7 @@ services:
|
||||
- DATABASEMANAGER_HOST=rest_server
|
||||
- SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
- DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
|
||||
- DIRECT_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
|
||||
- REDIS_HOST=redis
|
||||
- REDIS_PORT=6379
|
||||
- REDIS_PASSWORD=password
|
||||
@@ -201,6 +205,7 @@ services:
|
||||
# - NEXT_PUBLIC_SUPABASE_URL=http://kong:8000
|
||||
# - NEXT_PUBLIC_SUPABASE_ANON_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJhbm9uIiwKICAgICJpc3MiOiAic3VwYWJhc2UtZGVtbyIsCiAgICAiaWF0IjogMTY0MTc2OTIwMCwKICAgICJleHAiOiAxNzk5NTM1NjAwCn0.dc_X5iR_VP_qT0zsiyj_I_OZ2T9FtRU2BBNWN8Bu4GE
|
||||
# - DATABASE_URL=postgresql://agpt_user:pass123@postgres:5432/postgres?connect_timeout=60&schema=platform
|
||||
# - DIRECT_URL=postgresql://agpt_user:pass123@postgres:5432/postgres?connect_timeout=60&schema=platform
|
||||
# - NEXT_PUBLIC_AGPT_SERVER_URL=http://localhost:8006/api
|
||||
# - NEXT_PUBLIC_AGPT_WS_SERVER_URL=ws://localhost:8001/ws
|
||||
# - NEXT_PUBLIC_AGPT_MARKETPLACE_URL=http://localhost:8015/api/v1/market
|
||||
|
||||
@@ -4,7 +4,7 @@ NEXT_PUBLIC_AGPT_WS_SERVER_URL=ws://localhost:8001/ws
|
||||
NEXT_PUBLIC_AGPT_MARKETPLACE_URL=http://localhost:8015/api/v1/market
|
||||
NEXT_PUBLIC_LAUNCHDARKLY_ENABLED=false
|
||||
NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID=
|
||||
NEXT_PUBLIC_APP_ENV=dev
|
||||
NEXT_PUBLIC_APP_ENV=local
|
||||
|
||||
## Locale settings
|
||||
|
||||
|
||||
@@ -2,12 +2,14 @@
|
||||
// The config you add here will be used whenever a users loads a page in their browser.
|
||||
// https://docs.sentry.io/platforms/javascript/guides/nextjs/
|
||||
|
||||
import { getEnvironmentStr } from "@/lib/utils";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
|
||||
Sentry.init({
|
||||
dsn: "https://fe4e4aa4a283391808a5da396da20159@o4505260022104064.ingest.us.sentry.io/4507946746380288",
|
||||
|
||||
enabled: process.env.DISABLE_SENTRY !== "true",
|
||||
environment: getEnvironmentStr(),
|
||||
|
||||
// Add optional integrations for additional features
|
||||
integrations: [
|
||||
@@ -28,7 +30,9 @@ Sentry.init({
|
||||
// Set `tracePropagationTargets` to control for which URLs trace propagation should be enabled
|
||||
tracePropagationTargets: [
|
||||
"localhost",
|
||||
"localhost:8006",
|
||||
/^https:\/\/dev\-builder\.agpt\.co\/api/,
|
||||
/^https:\/\/.*\.agpt\.co\/api/,
|
||||
],
|
||||
|
||||
// Define how likely Replay events are sampled.
|
||||
@@ -48,4 +52,8 @@ Sentry.init({
|
||||
// For example, a tracesSampleRate of 0.5 and profilesSampleRate of 0.5 would
|
||||
// result in 25% of transactions being profiled (0.5*0.5=0.25)
|
||||
profilesSampleRate: 1.0,
|
||||
_experiments: {
|
||||
// Enable logs to be sent to Sentry.
|
||||
enableLogs: true,
|
||||
},
|
||||
});
|
||||
@@ -42,6 +42,7 @@
|
||||
"@radix-ui/react-separator": "^1.1.0",
|
||||
"@radix-ui/react-slot": "^1.1.0",
|
||||
"@radix-ui/react-switch": "^1.1.1",
|
||||
"@radix-ui/react-tabs": "^1.1.4",
|
||||
"@radix-ui/react-toast": "^1.2.5",
|
||||
"@radix-ui/react-tooltip": "^1.1.7",
|
||||
"@sentry/nextjs": "^9",
|
||||
@@ -52,7 +53,6 @@
|
||||
"@xyflow/react": "12.4.2",
|
||||
"ajv": "^8.17.1",
|
||||
"boring-avatars": "^1.11.2",
|
||||
"canvas-confetti": "^1.9.3",
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
"cmdk": "1.0.4",
|
||||
@@ -70,6 +70,7 @@
|
||||
"moment": "^2.30.1",
|
||||
"next": "^14.2.26",
|
||||
"next-themes": "^0.4.5",
|
||||
"party-js": "^2.2.0",
|
||||
"react": "^18",
|
||||
"react-day-picker": "^9.6.1",
|
||||
"react-dom": "^18",
|
||||
|
||||
BIN
autogpt_platform/frontend/public/onboarding/builder-open.mp4
Normal file
BIN
autogpt_platform/frontend/public/onboarding/builder-open.mp4
Normal file
Binary file not shown.
BIN
autogpt_platform/frontend/public/onboarding/builder-run.mp4
Normal file
BIN
autogpt_platform/frontend/public/onboarding/builder-run.mp4
Normal file
Binary file not shown.
BIN
autogpt_platform/frontend/public/onboarding/builder-save.mp4
Normal file
BIN
autogpt_platform/frontend/public/onboarding/builder-save.mp4
Normal file
Binary file not shown.
BIN
autogpt_platform/frontend/public/onboarding/get-results.mp4
Normal file
BIN
autogpt_platform/frontend/public/onboarding/get-results.mp4
Normal file
Binary file not shown.
BIN
autogpt_platform/frontend/public/onboarding/marketplace-add.mp4
Normal file
BIN
autogpt_platform/frontend/public/onboarding/marketplace-add.mp4
Normal file
Binary file not shown.
BIN
autogpt_platform/frontend/public/onboarding/marketplace-run.mp4
Normal file
BIN
autogpt_platform/frontend/public/onboarding/marketplace-run.mp4
Normal file
Binary file not shown.
Binary file not shown.
@@ -4,15 +4,28 @@
|
||||
// https://docs.sentry.io/platforms/javascript/guides/nextjs/
|
||||
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
import { getEnvironmentStr } from "./src/lib/utils";
|
||||
|
||||
Sentry.init({
|
||||
dsn: "https://fe4e4aa4a283391808a5da396da20159@o4505260022104064.ingest.us.sentry.io/4507946746380288",
|
||||
|
||||
enabled: process.env.NODE_ENV !== "development",
|
||||
environment: getEnvironmentStr(),
|
||||
|
||||
// Define how likely traces are sampled. Adjust this value in production, or use tracesSampler for greater control.
|
||||
tracesSampleRate: 1,
|
||||
tracePropagationTargets: [
|
||||
"localhost",
|
||||
"localhost:8006",
|
||||
/^https:\/\/dev\-builder\.agpt\.co\/api/,
|
||||
/^https:\/\/.*\.agpt\.co\/api/,
|
||||
],
|
||||
|
||||
// Setting this option to true will print useful information to the console while you're setting up Sentry.
|
||||
debug: false,
|
||||
|
||||
_experiments: {
|
||||
// Enable logs to be sent to Sentry.
|
||||
enableLogs: true,
|
||||
},
|
||||
});
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
// The config you add here will be used whenever the server handles a request.
|
||||
// https://docs.sentry.io/platforms/javascript/guides/nextjs/
|
||||
|
||||
import { getEnvironmentStr } from "@/lib/utils";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
// import { NodeProfilingIntegration } from "@sentry/profiling-node";
|
||||
|
||||
@@ -9,9 +10,16 @@ Sentry.init({
|
||||
dsn: "https://fe4e4aa4a283391808a5da396da20159@o4505260022104064.ingest.us.sentry.io/4507946746380288",
|
||||
|
||||
enabled: process.env.NODE_ENV !== "development",
|
||||
environment: getEnvironmentStr(),
|
||||
|
||||
// Define how likely traces are sampled. Adjust this value in production, or use tracesSampler for greater control.
|
||||
tracesSampleRate: 1,
|
||||
tracePropagationTargets: [
|
||||
"localhost",
|
||||
"localhost:8006",
|
||||
/^https:\/\/dev\-builder\.agpt\.co\/api/,
|
||||
/^https:\/\/.*\.agpt\.co\/api/,
|
||||
],
|
||||
|
||||
// Setting this option to true will print useful information to the console while you're setting up Sentry.
|
||||
debug: false,
|
||||
@@ -22,4 +30,9 @@ Sentry.init({
|
||||
// NodeProfilingIntegration,
|
||||
// Sentry.fsIntegration(),
|
||||
],
|
||||
|
||||
_experiments: {
|
||||
// Enable logs to be sent to Sentry.
|
||||
enableLogs: true,
|
||||
},
|
||||
});
|
||||
|
||||
@@ -3,9 +3,16 @@
|
||||
import { useSearchParams } from "next/navigation";
|
||||
import { GraphID } from "@/lib/autogpt-server-api/types";
|
||||
import FlowEditor from "@/components/Flow";
|
||||
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
|
||||
import { useEffect } from "react";
|
||||
|
||||
export default function Home() {
|
||||
const query = useSearchParams();
|
||||
const { completeStep } = useOnboarding();
|
||||
|
||||
useEffect(() => {
|
||||
completeStep("BUILDER_OPEN");
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<FlowEditor
|
||||
|
||||
@@ -144,3 +144,13 @@
|
||||
text-wrap: balance;
|
||||
}
|
||||
}
|
||||
|
||||
input[type="number"]::-webkit-outer-spin-button,
|
||||
input[type="number"]::-webkit-inner-spin-button {
|
||||
-webkit-appearance: none;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
input[type="number"] {
|
||||
-moz-appearance: textfield;
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ import AgentRunDetailsView from "@/components/agents/agent-run-details-view";
|
||||
import AgentRunsSelectorList from "@/components/agents/agent-runs-selector-list";
|
||||
import AgentScheduleDetailsView from "@/components/agents/agent-schedule-details-view";
|
||||
import LibraryRunLoadingSkeleton from "./loading";
|
||||
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
|
||||
|
||||
export default function AgentRunsPage(): React.ReactElement {
|
||||
const { id: agentID }: { id: LibraryAgentID } = useParams();
|
||||
@@ -50,6 +51,7 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
useState<boolean>(false);
|
||||
const [confirmingDeleteAgentRun, setConfirmingDeleteAgentRun] =
|
||||
useState<GraphExecutionMeta | null>(null);
|
||||
const { state, updateState } = useOnboarding();
|
||||
|
||||
const openRunDraftView = useCallback(() => {
|
||||
selectView({ type: "run" });
|
||||
@@ -79,20 +81,32 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
[api, graphVersions],
|
||||
);
|
||||
|
||||
// Reward user for viewing results of their onboarding agent
|
||||
useEffect(() => {
|
||||
if (!state || !selectedRun || state.completedSteps.includes("GET_RESULTS"))
|
||||
return;
|
||||
|
||||
if (selectedRun.id === state.onboardingAgentExecutionId) {
|
||||
updateState({
|
||||
completedSteps: [...state.completedSteps, "GET_RESULTS"],
|
||||
});
|
||||
}
|
||||
}, [selectedRun, state]);
|
||||
|
||||
const fetchAgents = useCallback(() => {
|
||||
api.getLibraryAgent(agentID).then((agent) => {
|
||||
setAgent(agent);
|
||||
|
||||
getGraphVersion(agent.agent_id, agent.agent_version).then(
|
||||
getGraphVersion(agent.graph_id, agent.graph_version).then(
|
||||
(_graph) =>
|
||||
(graph && graph.version == _graph.version) || setGraph(_graph),
|
||||
);
|
||||
api.getGraphExecutions(agent.agent_id).then((agentRuns) => {
|
||||
api.getGraphExecutions(agent.graph_id).then((agentRuns) => {
|
||||
setAgentRuns(agentRuns);
|
||||
|
||||
// Preload the corresponding graph versions
|
||||
new Set(agentRuns.map((run) => run.graph_version)).forEach((version) =>
|
||||
getGraphVersion(agent.agent_id, version),
|
||||
getGraphVersion(agent.graph_id, version),
|
||||
);
|
||||
|
||||
if (!selectedView.id && isFirstLoad && agentRuns.length > 0) {
|
||||
@@ -110,7 +124,7 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
});
|
||||
if (selectedView.type == "run" && selectedView.id && agent) {
|
||||
api
|
||||
.getGraphExecutionInfo(agent.agent_id, selectedView.id)
|
||||
.getGraphExecutionInfo(agent.graph_id, selectedView.id)
|
||||
.then(setSelectedRun);
|
||||
}
|
||||
}, [api, agentID, getGraphVersion, graph, selectedView, isFirstLoad, agent]);
|
||||
@@ -124,7 +138,7 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
if (!agent) return;
|
||||
|
||||
// Subscribe to all executions for this agent
|
||||
api.subscribeToGraphExecutions(agent.agent_id);
|
||||
api.subscribeToGraphExecutions(agent.graph_id);
|
||||
}, [api, agent]);
|
||||
|
||||
// Handle execution updates
|
||||
@@ -132,6 +146,8 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
const detachExecUpdateHandler = api.onWebSocketMessage(
|
||||
"graph_execution_event",
|
||||
(data) => {
|
||||
if (data.graph_id != agent?.graph_id) return;
|
||||
|
||||
setAgentRuns((prev) => {
|
||||
const index = prev.findIndex((run) => run.id === data.id);
|
||||
if (index === -1) {
|
||||
@@ -150,7 +166,7 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
return () => {
|
||||
detachExecUpdateHandler();
|
||||
};
|
||||
}, [api, selectedView.id]);
|
||||
}, [api, agent?.graph_id, selectedView.id]);
|
||||
|
||||
// load selectedRun based on selectedView
|
||||
useEffect(() => {
|
||||
@@ -163,7 +179,7 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
|
||||
// Ensure corresponding graph version is available before rendering I/O
|
||||
api
|
||||
.getGraphExecutionInfo(agent.agent_id, selectedView.id)
|
||||
.getGraphExecutionInfo(agent.graph_id, selectedView.id)
|
||||
.then(async (run) => {
|
||||
await getGraphVersion(run.graph_id, run.graph_version);
|
||||
setSelectedRun(run);
|
||||
@@ -176,7 +192,7 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
|
||||
// TODO: filter in backend - https://github.com/Significant-Gravitas/AutoGPT/issues/9183
|
||||
setSchedules(
|
||||
(await api.listSchedules()).filter((s) => s.graph_id == agent.agent_id),
|
||||
(await api.listSchedules()).filter((s) => s.graph_id == agent.graph_id),
|
||||
);
|
||||
}, [api, agent]);
|
||||
|
||||
@@ -215,7 +231,7 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
agent &&
|
||||
// Export sanitized graph from backend
|
||||
api
|
||||
.getGraph(agent.agent_id, agent.agent_version, true)
|
||||
.getGraph(agent.graph_id, agent.graph_version, true)
|
||||
.then((graph) =>
|
||||
exportAsJSONFile(graph, `${graph.name}_v${graph.version}.json`),
|
||||
),
|
||||
@@ -228,7 +244,7 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
? [
|
||||
{
|
||||
label: "Open graph in builder",
|
||||
href: `/build?flowID=${agent.agent_id}&flowVersion=${agent.agent_version}`,
|
||||
href: `/build?flowID=${agent.graph_id}&flowVersion=${agent.graph_version}`,
|
||||
},
|
||||
{ label: "Export agent to file", callback: downloadGraph },
|
||||
]
|
||||
@@ -243,7 +259,6 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
);
|
||||
|
||||
if (!agent || !graph) {
|
||||
/* TODO: implement loading indicators / skeleton page */
|
||||
return <LibraryRunLoadingSkeleton />;
|
||||
}
|
||||
|
||||
|
||||
@@ -81,16 +81,19 @@ export default async function Page({
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
<Separator className="mb-[25px] mt-6" />
|
||||
<Separator className="mb-[25px] mt-[60px]" />
|
||||
<AgentsSection
|
||||
margin="32px"
|
||||
agents={otherAgents.agents}
|
||||
sectionTitle={`Other agents by ${agent.creator}`}
|
||||
/>
|
||||
<Separator className="mb-[25px] mt-6" />
|
||||
<Separator className="mb-[25px] mt-[60px]" />
|
||||
<AgentsSection
|
||||
margin="32px"
|
||||
agents={similarAgents.agents}
|
||||
sectionTitle="Similar agents"
|
||||
/>
|
||||
<Separator className="mb-[25px] mt-[60px]" />
|
||||
<BecomeACreator
|
||||
title="Become a Creator"
|
||||
description="Join our ever-growing community of hackers and tinkerers"
|
||||
|
||||
@@ -8,6 +8,7 @@ import { BreadCrumbs } from "@/components/agptui/BreadCrumbs";
|
||||
import { Metadata } from "next";
|
||||
import { CreatorInfoCard } from "@/components/agptui/CreatorInfoCard";
|
||||
import { CreatorLinks } from "@/components/agptui/CreatorLinks";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
|
||||
export async function generateMetadata({
|
||||
params,
|
||||
@@ -78,7 +79,7 @@ export default async function Page({
|
||||
</div>
|
||||
</div>
|
||||
<div className="mt-8 sm:mt-12 md:mt-16 lg:pb-[58px]">
|
||||
<hr className="w-full bg-neutral-700" />
|
||||
<Separator className="mb-6 bg-gray-200" />
|
||||
<AgentsSection
|
||||
agents={creatorAgents.agents}
|
||||
hideAvatars={true}
|
||||
|
||||
@@ -153,16 +153,17 @@ export default async function Page({}: {}) {
|
||||
<main className="px-4">
|
||||
<HeroSection />
|
||||
<FeaturedSection featuredAgents={featuredAgents.agents} />
|
||||
<Separator />
|
||||
{/* 100px margin because our featured sections button are placed 40px below the container */}
|
||||
<Separator className="mb-6 mt-24" />
|
||||
<AgentsSection
|
||||
sectionTitle="Top Agents"
|
||||
agents={topAgents.agents as Agent[]}
|
||||
/>
|
||||
<Separator />
|
||||
<Separator className="mb-[25px] mt-[60px]" />
|
||||
<FeaturedCreators
|
||||
featuredCreators={featuredCreators.creators as FeaturedCreator[]}
|
||||
/>
|
||||
<Separator />
|
||||
<Separator className="mb-[25px] mt-[60px]" />
|
||||
<BecomeACreator
|
||||
title="Become a Creator"
|
||||
description="Join our ever-growing community of hackers and tinkerers"
|
||||
|
||||
@@ -98,7 +98,7 @@ const Monitor = () => {
|
||||
flows={flows}
|
||||
executions={[
|
||||
...(selectedFlow
|
||||
? executions.filter((v) => v.graph_id == selectedFlow.agent_id)
|
||||
? executions.filter((v) => v.graph_id == selectedFlow.graph_id)
|
||||
: executions),
|
||||
].sort((a, b) => b.started_at.getTime() - a.started_at.getTime())}
|
||||
selectedRun={selectedRun}
|
||||
@@ -108,7 +108,7 @@ const Monitor = () => {
|
||||
<FlowRunInfo
|
||||
agent={
|
||||
selectedFlow ||
|
||||
flows.find((f) => f.agent_id == selectedRun.graph_id)!
|
||||
flows.find((f) => f.graph_id == selectedRun.graph_id)!
|
||||
}
|
||||
execution={selectedRun}
|
||||
className={column3}
|
||||
@@ -118,7 +118,7 @@ const Monitor = () => {
|
||||
<FlowInfo
|
||||
flow={selectedFlow}
|
||||
executions={executions.filter(
|
||||
(e) => e.graph_id == selectedFlow.agent_id,
|
||||
(e) => e.graph_id == selectedFlow.graph_id,
|
||||
)}
|
||||
className={column3}
|
||||
refresh={() => {
|
||||
|
||||
@@ -80,12 +80,23 @@ export default function Page() {
|
||||
if (!agent) {
|
||||
return;
|
||||
}
|
||||
api.addMarketplaceAgentToLibrary(
|
||||
storeAgent?.store_listing_version_id || "",
|
||||
);
|
||||
api.executeGraph(agent.id, agent.version, state?.agentInput || {});
|
||||
router.push("/onboarding/6-congrats");
|
||||
}, [api, agent, router, state?.agentInput]);
|
||||
api
|
||||
.addMarketplaceAgentToLibrary(storeAgent?.store_listing_version_id || "")
|
||||
.then((libraryAgent) => {
|
||||
api
|
||||
.executeGraph(
|
||||
libraryAgent.graph_id,
|
||||
libraryAgent.graph_version,
|
||||
state?.agentInput || {},
|
||||
)
|
||||
.then(({ graph_exec_id }) => {
|
||||
updateState({
|
||||
onboardingAgentExecutionId: graph_exec_id,
|
||||
});
|
||||
router.push("/onboarding/6-congrats");
|
||||
});
|
||||
});
|
||||
}, [api, agent, router, state?.agentInput, storeAgent, updateState]);
|
||||
|
||||
const runYourAgent = (
|
||||
<div className="ml-[54px] w-[481px] pl-5">
|
||||
|
||||
@@ -6,9 +6,6 @@ import { redirect } from "next/navigation";
|
||||
export async function finishOnboarding() {
|
||||
const api = new BackendAPI();
|
||||
const onboarding = await api.getUserOnboarding();
|
||||
await api.updateUserOnboarding({
|
||||
completedSteps: [...onboarding.completedSteps, "CONGRATS"],
|
||||
});
|
||||
revalidatePath("/library", "layout");
|
||||
redirect("/library");
|
||||
}
|
||||
|
||||
@@ -1,24 +1,26 @@
|
||||
"use client";
|
||||
import { useEffect, useState } from "react";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { finishOnboarding } from "./actions";
|
||||
import confetti from "canvas-confetti";
|
||||
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
|
||||
import * as party from "party-js";
|
||||
|
||||
export default function Page() {
|
||||
useOnboarding(7, "AGENT_INPUT");
|
||||
const { state, updateState } = useOnboarding(7, "AGENT_INPUT");
|
||||
const [showText, setShowText] = useState(false);
|
||||
const [showSubtext, setShowSubtext] = useState(false);
|
||||
const divRef = useRef(null);
|
||||
|
||||
useEffect(() => {
|
||||
confetti({
|
||||
particleCount: 120,
|
||||
spread: 360,
|
||||
shapes: ["square", "circle"],
|
||||
scalar: 2,
|
||||
decay: 0.93,
|
||||
origin: { y: 0.38, x: 0.51 },
|
||||
});
|
||||
if (divRef.current) {
|
||||
party.confetti(divRef.current, {
|
||||
count: 100,
|
||||
spread: 180,
|
||||
shapes: ["square", "circle"],
|
||||
size: party.variation.range(2, 2), // scalar: 2
|
||||
speed: party.variation.range(300, 1000),
|
||||
});
|
||||
}
|
||||
|
||||
const timer0 = setTimeout(() => {
|
||||
setShowText(true);
|
||||
@@ -29,6 +31,9 @@ export default function Page() {
|
||||
}, 500);
|
||||
|
||||
const timer2 = setTimeout(() => {
|
||||
updateState({
|
||||
completedSteps: [...(state?.completedSteps || []), "CONGRATS"],
|
||||
});
|
||||
finishOnboarding();
|
||||
}, 3000);
|
||||
|
||||
@@ -42,6 +47,7 @@ export default function Page() {
|
||||
return (
|
||||
<div className="flex h-screen w-screen flex-col items-center justify-center bg-violet-100">
|
||||
<div
|
||||
ref={divRef}
|
||||
className={cn(
|
||||
"z-10 -mb-16 text-9xl duration-500",
|
||||
showText ? "opacity-100" : "opacity-0",
|
||||
@@ -63,7 +69,7 @@ export default function Page() {
|
||||
showSubtext ? "opacity-100" : "opacity-0",
|
||||
)}
|
||||
>
|
||||
You earned 15$ for running your first agent
|
||||
You earned 3$ for running your first agent
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -5,11 +5,14 @@ export default async function OnboardingResetPage() {
|
||||
const api = new BackendAPI();
|
||||
await api.updateUserOnboarding({
|
||||
completedSteps: [],
|
||||
notificationDot: true,
|
||||
notified: [],
|
||||
usageReason: null,
|
||||
integrations: [],
|
||||
otherIntegrations: "",
|
||||
selectedStoreListingVersionId: null,
|
||||
agentInput: {},
|
||||
onboardingAgentExecutionId: null,
|
||||
});
|
||||
redirect("/onboarding/1-welcome");
|
||||
}
|
||||
|
||||
@@ -118,8 +118,7 @@ export default function CreditsPage() {
|
||||
{topupStatus === "success" && (
|
||||
<span className="text-green-500">
|
||||
Your payment was successful. Your credits will be updated
|
||||
shortly. You can click the refresh icon 🔄 in case it is not
|
||||
updated.
|
||||
shortly. Try refreshing the page in case it is not updated.
|
||||
</span>
|
||||
)}
|
||||
{topupStatus === "cancel" && (
|
||||
|
||||
@@ -131,11 +131,7 @@ export default function PrivatePage() {
|
||||
|
||||
const allCredentials = providers
|
||||
? Object.values(providers).flatMap((provider) =>
|
||||
[
|
||||
...provider.savedOAuthCredentials,
|
||||
...provider.savedApiKeys,
|
||||
...provider.savedUserPasswordCredentials,
|
||||
]
|
||||
provider.savedCredentials
|
||||
.filter((cred) => !hiddenCredentials.includes(cred.id))
|
||||
.map((credentials) => ({
|
||||
...credentials,
|
||||
|
||||
@@ -178,18 +178,24 @@ export const CustomNode = React.memo(
|
||||
return obj;
|
||||
}, []);
|
||||
|
||||
const setHardcodedValues = (values: any) => {
|
||||
updateNodeData(id, { hardcodedValues: values });
|
||||
};
|
||||
const setHardcodedValues = useCallback(
|
||||
(values: any) => {
|
||||
updateNodeData(id, { hardcodedValues: values });
|
||||
},
|
||||
[id, updateNodeData],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
isInitialSetup.current = false;
|
||||
setHardcodedValues(fillDefaults(data.hardcodedValues, data.inputSchema));
|
||||
}, []);
|
||||
|
||||
const setErrors = (errors: { [key: string]: string }) => {
|
||||
updateNodeData(id, { errors });
|
||||
};
|
||||
const setErrors = useCallback(
|
||||
(errors: { [key: string]: string }) => {
|
||||
updateNodeData(id, { errors });
|
||||
},
|
||||
[id, updateNodeData],
|
||||
);
|
||||
|
||||
const toggleOutput = (checked: boolean) => {
|
||||
setIsOutputOpen(checked);
|
||||
@@ -340,46 +346,49 @@ export const CustomNode = React.memo(
|
||||
});
|
||||
}
|
||||
};
|
||||
const handleInputChange = (path: string, value: any) => {
|
||||
const keys = parseKeys(path);
|
||||
const newValues = JSON.parse(JSON.stringify(data.hardcodedValues));
|
||||
let current = newValues;
|
||||
const handleInputChange = useCallback(
|
||||
(path: string, value: any) => {
|
||||
const keys = parseKeys(path);
|
||||
const newValues = JSON.parse(JSON.stringify(data.hardcodedValues));
|
||||
let current = newValues;
|
||||
|
||||
for (let i = 0; i < keys.length - 1; i++) {
|
||||
const { key: currentKey, index } = keys[i];
|
||||
if (index !== undefined) {
|
||||
if (!current[currentKey]) current[currentKey] = [];
|
||||
if (!current[currentKey][index]) current[currentKey][index] = {};
|
||||
current = current[currentKey][index];
|
||||
} else {
|
||||
if (!current[currentKey]) current[currentKey] = {};
|
||||
current = current[currentKey];
|
||||
for (let i = 0; i < keys.length - 1; i++) {
|
||||
const { key: currentKey, index } = keys[i];
|
||||
if (index !== undefined) {
|
||||
if (!current[currentKey]) current[currentKey] = [];
|
||||
if (!current[currentKey][index]) current[currentKey][index] = {};
|
||||
current = current[currentKey][index];
|
||||
} else {
|
||||
if (!current[currentKey]) current[currentKey] = {};
|
||||
current = current[currentKey];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const lastKey = keys[keys.length - 1];
|
||||
if (lastKey.index !== undefined) {
|
||||
if (!current[lastKey.key]) current[lastKey.key] = [];
|
||||
current[lastKey.key][lastKey.index] = value;
|
||||
} else {
|
||||
current[lastKey.key] = value;
|
||||
}
|
||||
const lastKey = keys[keys.length - 1];
|
||||
if (lastKey.index !== undefined) {
|
||||
if (!current[lastKey.key]) current[lastKey.key] = [];
|
||||
current[lastKey.key][lastKey.index] = value;
|
||||
} else {
|
||||
current[lastKey.key] = value;
|
||||
}
|
||||
|
||||
if (!isInitialSetup.current) {
|
||||
history.push({
|
||||
type: "UPDATE_INPUT",
|
||||
payload: { nodeId: id, oldValues: data.hardcodedValues, newValues },
|
||||
undo: () => setHardcodedValues(data.hardcodedValues),
|
||||
redo: () => setHardcodedValues(newValues),
|
||||
});
|
||||
}
|
||||
if (!isInitialSetup.current) {
|
||||
history.push({
|
||||
type: "UPDATE_INPUT",
|
||||
payload: { nodeId: id, oldValues: data.hardcodedValues, newValues },
|
||||
undo: () => setHardcodedValues(data.hardcodedValues),
|
||||
redo: () => setHardcodedValues(newValues),
|
||||
});
|
||||
}
|
||||
|
||||
setHardcodedValues(newValues);
|
||||
const errors = data.errors || {};
|
||||
// Remove error with the same key
|
||||
setNestedProperty(errors, path, null);
|
||||
setErrors({ ...errors });
|
||||
};
|
||||
setHardcodedValues(newValues);
|
||||
const errors = data.errors || {};
|
||||
// Remove error with the same key
|
||||
setNestedProperty(errors, path, null);
|
||||
setErrors({ ...errors });
|
||||
},
|
||||
[data.hardcodedValues, id, setHardcodedValues, data.errors, setErrors],
|
||||
);
|
||||
|
||||
const isInputHandleConnected = (key: string) => {
|
||||
return (
|
||||
@@ -407,28 +416,34 @@ export const CustomNode = React.memo(
|
||||
);
|
||||
};
|
||||
|
||||
const handleInputClick = (key: string) => {
|
||||
console.debug(`Opening modal for key: ${key}`);
|
||||
setActiveKey(key);
|
||||
const value = getValue(key, data.hardcodedValues);
|
||||
setInputModalValue(
|
||||
typeof value === "object" ? JSON.stringify(value, null, 2) : value,
|
||||
);
|
||||
setIsModalOpen(true);
|
||||
};
|
||||
const handleInputClick = useCallback(
|
||||
(key: string) => {
|
||||
console.debug(`Opening modal for key: ${key}`);
|
||||
setActiveKey(key);
|
||||
const value = getValue(key, data.hardcodedValues);
|
||||
setInputModalValue(
|
||||
typeof value === "object" ? JSON.stringify(value, null, 2) : value,
|
||||
);
|
||||
setIsModalOpen(true);
|
||||
},
|
||||
[data.hardcodedValues],
|
||||
);
|
||||
|
||||
const handleModalSave = (value: string) => {
|
||||
if (activeKey) {
|
||||
try {
|
||||
const parsedValue = JSON.parse(value);
|
||||
handleInputChange(activeKey, parsedValue);
|
||||
} catch (error) {
|
||||
handleInputChange(activeKey, value);
|
||||
const handleModalSave = useCallback(
|
||||
(value: string) => {
|
||||
if (activeKey) {
|
||||
try {
|
||||
const parsedValue = JSON.parse(value);
|
||||
handleInputChange(activeKey, parsedValue);
|
||||
} catch (error) {
|
||||
handleInputChange(activeKey, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
setIsModalOpen(false);
|
||||
setActiveKey(null);
|
||||
};
|
||||
setIsModalOpen(false);
|
||||
setActiveKey(null);
|
||||
},
|
||||
[activeKey, handleInputChange],
|
||||
);
|
||||
|
||||
const handleOutputClick = () => {
|
||||
setIsOutputModalOpen(true);
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"use client";
|
||||
import React, { useCallback, useMemo } from "react";
|
||||
import { isEmpty } from "lodash";
|
||||
import moment from "moment";
|
||||
|
||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||
@@ -164,7 +165,8 @@ export default function AgentRunDetailsView({
|
||||
] satisfies ButtonAction[])
|
||||
: []),
|
||||
...(["success", "failed", "stopped"].includes(runStatus) &&
|
||||
!graph.has_webhook_trigger
|
||||
!graph.has_webhook_trigger &&
|
||||
isEmpty(graph.credentials_input_schema.required) // TODO: enable re-run with credentials - https://linear.app/autogpt/issue/SECRT-1243
|
||||
? [
|
||||
{
|
||||
label: (
|
||||
@@ -193,6 +195,7 @@ export default function AgentRunDetailsView({
|
||||
stopRun,
|
||||
deleteRun,
|
||||
graph.has_webhook_trigger,
|
||||
graph.credentials_input_schema.properties,
|
||||
agent.can_access_graph,
|
||||
run.graph_id,
|
||||
run.graph_version,
|
||||
@@ -258,11 +261,7 @@ export default function AgentRunDetailsView({
|
||||
Object.entries(agentRunInputs).map(([key, { title, value }]) => (
|
||||
<div key={key} className="flex flex-col gap-1.5">
|
||||
<label className="text-sm font-medium">{title || key}</label>
|
||||
<Input
|
||||
defaultValue={value}
|
||||
className="rounded-full"
|
||||
disabled
|
||||
/>
|
||||
<Input value={value} className="rounded-full" disabled />
|
||||
</div>
|
||||
))
|
||||
) : (
|
||||
|
||||
@@ -6,11 +6,13 @@ import { GraphExecutionID, GraphMeta } from "@/lib/autogpt-server-api";
|
||||
|
||||
import type { ButtonAction } from "@/components/agptui/types";
|
||||
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import { CredentialsInput } from "@/components/integrations/credentials-input";
|
||||
import { TypeBasedInput } from "@/components/type-based-input";
|
||||
import { useToastOnFail } from "@/components/ui/use-toast";
|
||||
import ActionButtonGroup from "@/components/agptui/action-button-group";
|
||||
import SchemaTooltip from "@/components/SchemaTooltip";
|
||||
import { IconPlay } from "@/components/ui/icons";
|
||||
import { useOnboarding } from "../onboarding/onboarding-provider";
|
||||
|
||||
export default function AgentRunDraftView({
|
||||
graph,
|
||||
@@ -25,16 +27,32 @@ export default function AgentRunDraftView({
|
||||
const toastOnFail = useToastOnFail();
|
||||
|
||||
const agentInputs = graph.input_schema.properties;
|
||||
const agentCredentialsInputs = graph.credentials_input_schema.properties;
|
||||
const [inputValues, setInputValues] = useState<Record<string, any>>({});
|
||||
|
||||
const doRun = useCallback(
|
||||
() =>
|
||||
api
|
||||
.executeGraph(graph.id, graph.version, inputValues)
|
||||
.then((newRun) => onRun(newRun.graph_exec_id))
|
||||
.catch(toastOnFail("execute agent")),
|
||||
[api, graph, inputValues, onRun, toastOnFail],
|
||||
const [inputCredentials, setInputCredentials] = useState<Record<string, any>>(
|
||||
{},
|
||||
);
|
||||
const { state, completeStep } = useOnboarding();
|
||||
|
||||
const doRun = useCallback(() => {
|
||||
api
|
||||
.executeGraph(graph.id, graph.version, inputValues, inputCredentials)
|
||||
.then((newRun) => onRun(newRun.graph_exec_id))
|
||||
.catch(toastOnFail("execute agent"));
|
||||
// Mark run agent onboarding step as completed
|
||||
if (state?.completedSteps.includes("MARKETPLACE_ADD_AGENT")) {
|
||||
completeStep("MARKETPLACE_RUN_AGENT");
|
||||
}
|
||||
}, [
|
||||
api,
|
||||
graph,
|
||||
inputValues,
|
||||
inputCredentials,
|
||||
onRun,
|
||||
toastOnFail,
|
||||
state,
|
||||
completeStep,
|
||||
]);
|
||||
|
||||
const runActions: ButtonAction[] = useMemo(
|
||||
() => [
|
||||
@@ -60,6 +78,26 @@ export default function AgentRunDraftView({
|
||||
<CardTitle className="font-poppins text-lg">Input</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent className="flex flex-col gap-4">
|
||||
{/* Credentials inputs */}
|
||||
{Object.entries(agentCredentialsInputs).map(
|
||||
([key, inputSubSchema]) => (
|
||||
<CredentialsInput
|
||||
key={key}
|
||||
schema={{ ...inputSubSchema, discriminator: undefined }}
|
||||
selectedCredentials={
|
||||
inputCredentials[key] ?? inputSubSchema.default
|
||||
}
|
||||
onSelectCredentials={(value) =>
|
||||
setInputCredentials((obj) => ({
|
||||
...obj,
|
||||
[key]: value,
|
||||
}))
|
||||
}
|
||||
/>
|
||||
),
|
||||
)}
|
||||
|
||||
{/* Regular inputs */}
|
||||
{Object.entries(agentInputs).map(([key, inputSubSchema]) => (
|
||||
<div key={key} className="flex flex-col space-y-2">
|
||||
<label className="flex items-center gap-1 text-sm font-medium">
|
||||
|
||||
@@ -84,7 +84,7 @@ export default function AgentRunsSelectorList({
|
||||
>
|
||||
<span>Scheduled</span>
|
||||
<span className="text-neutral-600">
|
||||
{schedules.filter((s) => s.graph_id === agent.agent_id).length}
|
||||
{schedules.filter((s) => s.graph_id === agent.graph_id).length}
|
||||
</span>
|
||||
</Badge>
|
||||
</div>
|
||||
@@ -127,7 +127,7 @@ export default function AgentRunsSelectorList({
|
||||
/>
|
||||
))
|
||||
: schedules
|
||||
.filter((schedule) => schedule.graph_id === agent.agent_id)
|
||||
.filter((schedule) => schedule.graph_id === agent.graph_id)
|
||||
.map((schedule) => (
|
||||
<AgentRunSummaryCard
|
||||
className="h-28 w-72 lg:h-32 xl:w-80"
|
||||
|
||||
@@ -109,11 +109,7 @@ export default function AgentScheduleDetailsView({
|
||||
Object.entries(agentRunInputs).map(([key, { title, value }]) => (
|
||||
<div key={key} className="flex flex-col gap-1.5">
|
||||
<label className="text-sm font-medium">{title || key}</label>
|
||||
<Input
|
||||
defaultValue={value}
|
||||
className="rounded-full"
|
||||
disabled
|
||||
/>
|
||||
<Input value={value} className="rounded-full" disabled />
|
||||
</div>
|
||||
))
|
||||
) : (
|
||||
|
||||
@@ -10,6 +10,7 @@ import { useToast } from "@/components/ui/use-toast";
|
||||
|
||||
import useSupabase from "@/hooks/useSupabase";
|
||||
import { DownloadIcon, LoaderIcon } from "lucide-react";
|
||||
import { useOnboarding } from "../onboarding/onboarding-provider";
|
||||
interface AgentInfoProps {
|
||||
name: string;
|
||||
creator: string;
|
||||
@@ -39,6 +40,7 @@ export const AgentInfo: React.FC<AgentInfoProps> = ({
|
||||
const api = React.useMemo(() => new BackendAPI(), []);
|
||||
const { user } = useSupabase();
|
||||
const { toast } = useToast();
|
||||
const { completeStep } = useOnboarding();
|
||||
|
||||
const [downloading, setDownloading] = React.useState(false);
|
||||
|
||||
@@ -47,6 +49,7 @@ export const AgentInfo: React.FC<AgentInfoProps> = ({
|
||||
const newLibraryAgent = await api.addMarketplaceAgentToLibrary(
|
||||
storeListingVersionId,
|
||||
);
|
||||
completeStep("MARKETPLACE_ADD_AGENT");
|
||||
router.push(`/library/agents/${newLibraryAgent.id}`);
|
||||
} catch (error) {
|
||||
console.error("Failed to add agent to library:", error);
|
||||
@@ -170,10 +173,10 @@ export const AgentInfo: React.FC<AgentInfoProps> = ({
|
||||
|
||||
{/* Description Section */}
|
||||
<div className="mb-4 w-full lg:mb-[36px]">
|
||||
<div className="font-geist decoration-skip-ink-none mb-1.5 text-base font-medium leading-6 text-neutral-800 dark:text-neutral-200 sm:mb-2">
|
||||
<div className="mb-1.5 font-sans text-base font-medium leading-6 text-neutral-800 dark:text-neutral-200 sm:mb-2">
|
||||
Description
|
||||
</div>
|
||||
<div className="font-geist decoration-skip-ink-none text-base font-normal leading-6 text-neutral-600 underline-offset-[from-font] dark:text-neutral-400">
|
||||
<div className="whitespace-pre-line font-sans text-base font-normal leading-6 text-neutral-600 dark:text-neutral-400">
|
||||
{longDescription}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -21,17 +21,14 @@ export const BecomeACreator: React.FC<BecomeACreatorProps> = ({
|
||||
|
||||
return (
|
||||
<div className="relative mx-auto h-auto min-h-[300px] w-full max-w-[1360px] md:min-h-[400px] lg:h-[459px]">
|
||||
{/* Top border */}
|
||||
<div className="left-0 top-0 h-px w-full bg-gray-200 dark:bg-gray-700" />
|
||||
|
||||
{/* Title */}
|
||||
<h2 className="mb-8 text-left font-poppins text-[18px] font-[600] leading-9 text-neutral-800 dark:text-neutral-200">
|
||||
<h2 className="mb-[77px] font-poppins text-[18px] font-semibold leading-[28px] text-neutral-800 dark:text-neutral-200">
|
||||
{title}
|
||||
</h2>
|
||||
|
||||
{/* Content Container */}
|
||||
<div className="m-auto w-full max-w-[900px] px-4 py-16 text-center md:px-6 lg:px-0">
|
||||
<h2 className="underline-from-font decoration-skip-ink-none mb-6 text-center font-poppins text-[48px] font-semibold leading-[54px] tracking-[-0.012em] text-neutral-950 dark:text-neutral-50 md:mb-8 lg:mb-12">
|
||||
<div className="mx-auto w-full max-w-[900px] px-4 text-center md:px-6 lg:px-0">
|
||||
<h2 className="mb-6 text-center font-poppins text-[48px] font-semibold leading-[54px] tracking-[-0.012em] text-neutral-950 dark:text-neutral-50 md:mb-8 lg:mb-12">
|
||||
Build AI agents and share
|
||||
<br />
|
||||
<span className="text-violet-600 dark:text-violet-400">
|
||||
|
||||
@@ -22,7 +22,7 @@ export const BreadCrumbs: React.FC<BreadCrumbsProps> = ({ items }) => {
|
||||
<button className="flex h-12 w-12 items-center justify-center rounded-full border border-neutral-200 transition-colors hover:bg-neutral-50 dark:border-neutral-700 dark:hover:bg-neutral-800">
|
||||
<IconRightArrow className="h-5 w-5 text-neutral-900 dark:text-neutral-100" />
|
||||
</button> */}
|
||||
<div className="flex h-auto flex-wrap items-center justify-start gap-4 rounded-[5rem] bg-white dark:bg-transparent">
|
||||
<div className="flex h-auto flex-wrap items-center justify-start gap-4 rounded-[5rem] dark:bg-transparent">
|
||||
{items.map((item, index) => (
|
||||
<React.Fragment key={index}>
|
||||
<Link href={item.link}>
|
||||
|
||||
@@ -33,7 +33,10 @@ export const CreatorInfoCard: React.FC<CreatorInfoCardProps> = ({
|
||||
src={avatarSrc}
|
||||
alt={`${username}'s avatar`}
|
||||
/>
|
||||
<AvatarFallback className="h-[100px] w-[100px] sm:h-[130px] sm:w-[130px]">
|
||||
<AvatarFallback
|
||||
size={130}
|
||||
className="h-[100px] w-[100px] sm:h-[130px] sm:w-[130px]"
|
||||
>
|
||||
{username.charAt(0)}
|
||||
</AvatarFallback>
|
||||
</Avatar>
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
import type { Meta, StoryObj } from "@storybook/react";
|
||||
import CreditsCard from "./CreditsCard";
|
||||
import { userEvent, within } from "@storybook/test";
|
||||
|
||||
const meta: Meta<typeof CreditsCard> = {
|
||||
title: "AGPT UI/Credits Card",
|
||||
component: CreditsCard,
|
||||
tags: ["autodocs"],
|
||||
};
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof CreditsCard>;
|
||||
|
||||
export const Default: Story = {
|
||||
args: {
|
||||
credits: 0,
|
||||
},
|
||||
};
|
||||
|
||||
export const SmallNumber: Story = {
|
||||
args: {
|
||||
credits: 10,
|
||||
},
|
||||
};
|
||||
|
||||
export const LargeNumber: Story = {
|
||||
args: {
|
||||
credits: 1000000,
|
||||
},
|
||||
};
|
||||
|
||||
export const InteractionTest: Story = {
|
||||
args: {
|
||||
credits: 100,
|
||||
},
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
const refreshButton = canvas.getByRole("button", {
|
||||
name: /refresh credits/i,
|
||||
});
|
||||
await userEvent.click(refreshButton);
|
||||
},
|
||||
};
|
||||
@@ -1,48 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { IconRefresh } from "@/components/ui/icons";
|
||||
import { useState } from "react";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||
import useCredits from "@/hooks/useCredits";
|
||||
|
||||
const CreditsCard = () => {
|
||||
const { credits, formatCredits, fetchCredits } = useCredits({
|
||||
fetchInitialCredits: true,
|
||||
});
|
||||
const api = useBackendAPI();
|
||||
|
||||
const onRefresh = async () => {
|
||||
fetchCredits();
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="inline-flex h-[48px] items-center gap-2.5 rounded-2xl bg-neutral-200 p-4 dark:bg-neutral-800">
|
||||
<div className="flex items-center gap-0.5">
|
||||
<span className="p-ui-semibold text-base leading-7 text-neutral-900 dark:text-neutral-50">
|
||||
Balance: {formatCredits(credits)}
|
||||
</span>
|
||||
</div>
|
||||
<Tooltip key="RefreshCredits" delayDuration={500}>
|
||||
<TooltipTrigger asChild>
|
||||
<button
|
||||
onClick={onRefresh}
|
||||
className="h-6 w-6 transition-colors hover:text-neutral-700 dark:hover:text-neutral-300"
|
||||
aria-label="Refresh credits"
|
||||
>
|
||||
<IconRefresh className="h-6 w-6" />
|
||||
</button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
<p>Refresh credits</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default CreditsCard;
|
||||
@@ -27,7 +27,7 @@ export const FeaturedAgentCard: React.FC<FeaturedStoreCardProps> = ({
|
||||
data-testid="featured-store-card"
|
||||
onMouseEnter={() => setIsHovered(true)}
|
||||
onMouseLeave={() => setIsHovered(false)}
|
||||
className={`flex h-full flex-col ${backgroundColor}`}
|
||||
className={`flex h-full flex-col ${backgroundColor} rounded-[1.5rem] border-none`}
|
||||
>
|
||||
<CardHeader>
|
||||
<CardTitle className="line-clamp-2 text-base sm:text-xl">
|
||||
|
||||
@@ -4,7 +4,7 @@ import { ProfilePopoutMenu } from "./ProfilePopoutMenu";
|
||||
import { IconType, IconLogIn, IconAutoGPTLogo } from "@/components/ui/icons";
|
||||
import { MobileNavBar } from "./MobileNavBar";
|
||||
import { Button } from "./Button";
|
||||
import CreditsCard from "./CreditsCard";
|
||||
import Wallet from "./Wallet";
|
||||
import { ProfileDetails } from "@/lib/autogpt-server-api/types";
|
||||
import { NavbarLink } from "./NavbarLink";
|
||||
import getServerUser from "@/lib/supabase/getServerUser";
|
||||
@@ -61,7 +61,7 @@ export const Navbar = async ({ links, menuItemGroups }: NavbarProps) => {
|
||||
<div className="flex items-center gap-4">
|
||||
{isLoggedIn ? (
|
||||
<div className="flex items-center gap-4">
|
||||
{profile && <CreditsCard />}
|
||||
{profile && <Wallet />}
|
||||
<ProfilePopoutMenu
|
||||
menuItemGroups={menuItemGroups}
|
||||
userName={profile?.username}
|
||||
|
||||
@@ -32,7 +32,7 @@ export const StoreCard: React.FC<StoreCardProps> = ({
|
||||
|
||||
return (
|
||||
<div
|
||||
className="inline-flex w-full max-w-[434px] cursor-pointer flex-col items-start justify-start gap-2.5 rounded-[26px] bg-white transition-all duration-300 hover:shadow-lg dark:bg-transparent dark:hover:shadow-gray-700"
|
||||
className="flex h-[27rem] w-full max-w-md cursor-pointer flex-col items-start rounded-3xl bg-white transition-all duration-300 hover:shadow-lg dark:bg-transparent dark:hover:shadow-gray-700"
|
||||
onClick={handleClick}
|
||||
data-testid="store-card"
|
||||
role="button"
|
||||
@@ -44,8 +44,8 @@ export const StoreCard: React.FC<StoreCardProps> = ({
|
||||
}
|
||||
}}
|
||||
>
|
||||
{/* Header Image Section with Avatar */}
|
||||
<div className="relative h-[200px] w-full overflow-hidden rounded-[20px]">
|
||||
{/* First Section: Image with Avatar */}
|
||||
<div className="relative aspect-[2/1.2] w-full overflow-hidden rounded-3xl md:aspect-[2.17/1]">
|
||||
{agentImage && (
|
||||
<Image
|
||||
src={agentImage}
|
||||
@@ -57,14 +57,14 @@ export const StoreCard: React.FC<StoreCardProps> = ({
|
||||
)}
|
||||
{!hideAvatar && (
|
||||
<div className="absolute bottom-4 left-4">
|
||||
<Avatar className="h-16 w-16 border-2 border-white dark:border-gray-800">
|
||||
<Avatar className="h-16 w-16">
|
||||
{avatarSrc && (
|
||||
<AvatarImage
|
||||
src={avatarSrc}
|
||||
alt={`${creatorName || agentName} creator avatar`}
|
||||
/>
|
||||
)}
|
||||
<AvatarFallback>
|
||||
<AvatarFallback size={64}>
|
||||
{(creatorName || agentName).charAt(0)}
|
||||
</AvatarFallback>
|
||||
</Avatar>
|
||||
@@ -72,37 +72,46 @@ export const StoreCard: React.FC<StoreCardProps> = ({
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Content Section */}
|
||||
<div className="w-full px-2 py-4">
|
||||
{/* Title and Creator */}
|
||||
<h3 className="mb-0.5 font-poppins text-2xl font-semibold text-[#272727] dark:text-neutral-100">
|
||||
{agentName}
|
||||
</h3>
|
||||
{!hideAvatar && creatorName && (
|
||||
<p className="mb-2.5 font-sans text-xl font-normal text-neutral-600 dark:text-neutral-400">
|
||||
by {creatorName}
|
||||
</p>
|
||||
)}
|
||||
{/* Description */}
|
||||
<p className="mb-4 font-sans text-base font-normal leading-normal text-neutral-600 dark:text-neutral-400">
|
||||
{description}
|
||||
</p>
|
||||
<div className="mt-3 flex w-full flex-1 flex-col px-4">
|
||||
{/* Second Section: Agent Name and Creator Name */}
|
||||
<div className="flex w-full flex-col">
|
||||
<h3 className="line-clamp-2 font-poppins text-2xl font-semibold text-[#272727] dark:text-neutral-100">
|
||||
{agentName}
|
||||
</h3>
|
||||
{!hideAvatar && creatorName && (
|
||||
<p className="mt-3 truncate font-sans text-xl font-normal text-neutral-600 dark:text-neutral-400">
|
||||
by {creatorName}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Stats Row */}
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="font-sans text-lg font-semibold text-neutral-800 dark:text-neutral-200">
|
||||
{runs.toLocaleString()} runs
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="font-geist text-lg font-semibold text-neutral-800 dark:text-neutral-200">
|
||||
{rating.toFixed(1)}
|
||||
</span>
|
||||
<div
|
||||
className="inline-flex items-center"
|
||||
role="img"
|
||||
aria-label={`Rating: ${rating.toFixed(1)} out of 5 stars`}
|
||||
>
|
||||
{StarRatingIcons(rating)}
|
||||
{/* Third Section: Description */}
|
||||
<div className="mt-2.5 flex w-full flex-col">
|
||||
<p className="line-clamp-3 font-sans text-base font-normal leading-normal text-neutral-600 dark:text-neutral-400">
|
||||
{description}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="flex-grow" />
|
||||
{/* Spacer to push stats to bottom */}
|
||||
|
||||
{/* Fourth Section: Stats Row - aligned to bottom */}
|
||||
<div className="mt-5 w-full">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="font-sans text-lg font-semibold text-neutral-800 dark:text-neutral-200">
|
||||
{runs.toLocaleString()} runs
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="font-sans text-lg font-semibold text-neutral-800 dark:text-neutral-200">
|
||||
{rating.toFixed(1)}
|
||||
</span>
|
||||
<div
|
||||
className="inline-flex items-center"
|
||||
role="img"
|
||||
aria-label={`Rating: ${rating.toFixed(1)} out of 5 stars`}
|
||||
>
|
||||
{StarRatingIcons(rating)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
120
autogpt_platform/frontend/src/components/agptui/Wallet.tsx
Normal file
120
autogpt_platform/frontend/src/components/agptui/Wallet.tsx
Normal file
@@ -0,0 +1,120 @@
|
||||
"use client";
|
||||
|
||||
import useCredits from "@/hooks/useCredits";
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from "@/components/ui/popover";
|
||||
import { X } from "lucide-react";
|
||||
import { PopoverClose } from "@radix-ui/react-popover";
|
||||
import { TaskGroups } from "../onboarding/WalletTaskGroups";
|
||||
import { ScrollArea } from "../ui/scroll-area";
|
||||
import { useOnboarding } from "../onboarding/onboarding-provider";
|
||||
import { useCallback, useEffect, useRef } from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import * as party from "party-js";
|
||||
import WalletRefill from "./WalletRefill";
|
||||
|
||||
export default function Wallet() {
|
||||
const { credits, formatCredits, fetchCredits } = useCredits({
|
||||
fetchInitialCredits: true,
|
||||
});
|
||||
const { state, updateState } = useOnboarding();
|
||||
const walletRef = useRef<HTMLButtonElement | null>(null);
|
||||
|
||||
const onWalletOpen = useCallback(async () => {
|
||||
if (state?.notificationDot) {
|
||||
updateState({ notificationDot: false });
|
||||
}
|
||||
// Refresh credits when the wallet is opened
|
||||
fetchCredits();
|
||||
}, [state?.notificationDot, updateState, fetchCredits]);
|
||||
|
||||
const fadeOut = new party.ModuleBuilder()
|
||||
.drive("opacity")
|
||||
.by((t) => 1 - t)
|
||||
.through("lifetime")
|
||||
.build();
|
||||
|
||||
useEffect(() => {
|
||||
// Check if there are any completed tasks (state?.completedTasks) that
|
||||
// are not in the state?.notified array and play confetti if so
|
||||
const pending = state?.completedSteps
|
||||
.filter((step) => !state?.notified.includes(step))
|
||||
// Ignore steps that are not relevant for notifications
|
||||
.filter(
|
||||
(step) =>
|
||||
step !== "WELCOME" &&
|
||||
step !== "USAGE_REASON" &&
|
||||
step !== "INTEGRATIONS" &&
|
||||
step !== "AGENT_CHOICE" &&
|
||||
step !== "AGENT_NEW_RUN" &&
|
||||
step !== "AGENT_INPUT",
|
||||
);
|
||||
if ((pending?.length || 0) > 0 && walletRef.current) {
|
||||
party.confetti(walletRef.current, {
|
||||
count: 30,
|
||||
spread: 120,
|
||||
shapes: ["square", "circle"],
|
||||
size: party.variation.range(1, 2),
|
||||
speed: party.variation.range(200, 300),
|
||||
modules: [fadeOut],
|
||||
});
|
||||
}
|
||||
}, [state?.completedSteps, state?.notified]);
|
||||
|
||||
return (
|
||||
<Popover>
|
||||
<PopoverTrigger asChild>
|
||||
<button
|
||||
ref={walletRef}
|
||||
className="relative flex items-center gap-1 rounded-md bg-zinc-200 px-3 py-2 text-sm transition-colors duration-200 hover:bg-zinc-300"
|
||||
onClick={onWalletOpen}
|
||||
>
|
||||
Wallet{" "}
|
||||
<span className="text-sm font-semibold">
|
||||
{formatCredits(credits)}
|
||||
</span>
|
||||
{state?.notificationDot && (
|
||||
<span className="absolute right-1 top-1 h-2 w-2 rounded-full bg-violet-600"></span>
|
||||
)}
|
||||
</button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent
|
||||
className={cn(
|
||||
"absolute -right-[7.9rem] -top-[3.2rem] z-50 w-[28.5rem] px-[0.625rem] py-2",
|
||||
"rounded-xl border-zinc-200 bg-zinc-50 shadow-[0_3px_3px] shadow-zinc-300",
|
||||
)}
|
||||
>
|
||||
{/* Header */}
|
||||
<div className="mx-1 flex items-center justify-between border-b border-zinc-300 pb-2">
|
||||
<span className="font-poppins font-medium text-zinc-900">
|
||||
Your wallet
|
||||
</span>
|
||||
<div className="flex items-center font-inter text-sm font-semibold text-violet-700">
|
||||
<div className="rounded-lg bg-violet-100 px-3 py-2">
|
||||
Wallet{" "}
|
||||
<span className="font-semibold">{formatCredits(credits)}</span>
|
||||
</div>
|
||||
<PopoverClose>
|
||||
<X className="ml-[2.8rem] h-5 w-5 text-zinc-800 hover:text-foreground" />
|
||||
</PopoverClose>
|
||||
</div>
|
||||
</div>
|
||||
<ScrollArea className="max-h-[85vh] overflow-y-auto">
|
||||
{/* Top ups */}
|
||||
<WalletRefill />
|
||||
{/* Tasks */}
|
||||
<p className="mx-1 mt-4 font-sans text-xs font-medium text-violet-700">
|
||||
Onboarding tasks
|
||||
</p>
|
||||
<p className="mx-1 my-1 font-sans text-xs font-normal text-zinc-500">
|
||||
Complete the following tasks to earn more credits!
|
||||
</p>
|
||||
<TaskGroups />
|
||||
</ScrollArea>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
265
autogpt_platform/frontend/src/components/agptui/WalletRefill.tsx
Normal file
265
autogpt_platform/frontend/src/components/agptui/WalletRefill.tsx
Normal file
@@ -0,0 +1,265 @@
|
||||
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { useForm } from "react-hook-form";
|
||||
import { z } from "zod";
|
||||
import {
|
||||
Form,
|
||||
FormControl,
|
||||
FormField,
|
||||
FormItem,
|
||||
FormLabel,
|
||||
FormMessage,
|
||||
} from "@/components/ui/form";
|
||||
import { Input } from "../ui/input";
|
||||
import Link from "next/link";
|
||||
import { useToast, useToastOnFail } from "../ui/use-toast";
|
||||
import useCredits from "@/hooks/useCredits";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
|
||||
const topUpSchema = z.object({
|
||||
amount: z
|
||||
.number({ coerce: true, invalid_type_error: "Enter top-up amount" })
|
||||
.min(5, "Top-ups start at $5. Please enter a higher amount."),
|
||||
});
|
||||
|
||||
const autoRefillSchema = z
|
||||
.object({
|
||||
threshold: z
|
||||
.number({ coerce: true, invalid_type_error: "Enter min. balance" })
|
||||
.min(
|
||||
5,
|
||||
"Looks like your balance is too low for auto-refill. Try $5 or more.",
|
||||
),
|
||||
refillAmount: z
|
||||
.number({ coerce: true, invalid_type_error: "Enter top-up amount" })
|
||||
.min(5, "Top-ups start at $5. Please enter a higher amount."),
|
||||
})
|
||||
.refine((data) => data.refillAmount >= data.threshold, {
|
||||
message:
|
||||
"Your refill amount must be equal to or greater than the balance you entered above.",
|
||||
path: ["refillAmount"],
|
||||
});
|
||||
|
||||
export default function WalletRefill() {
|
||||
const { toast } = useToast();
|
||||
const toastOnFail = useToastOnFail();
|
||||
const { requestTopUp, autoTopUpConfig, updateAutoTopUpConfig } = useCredits({
|
||||
fetchInitialAutoTopUpConfig: true,
|
||||
});
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
|
||||
const topUpForm = useForm<z.infer<typeof topUpSchema>>({
|
||||
resolver: zodResolver(topUpSchema),
|
||||
});
|
||||
const autoRefillForm = useForm<z.infer<typeof autoRefillSchema>>({
|
||||
resolver: zodResolver(autoRefillSchema),
|
||||
});
|
||||
|
||||
console.log("autoRefillForm");
|
||||
|
||||
// Pre-fill the auto-refill form with existing values
|
||||
useEffect(() => {
|
||||
const values = autoRefillForm.getValues();
|
||||
if (
|
||||
autoTopUpConfig &&
|
||||
autoTopUpConfig.amount > 0 &&
|
||||
autoTopUpConfig.threshold > 0 &&
|
||||
!autoRefillForm.getFieldState("threshold").isTouched &&
|
||||
!autoRefillForm.getFieldState("refillAmount").isTouched
|
||||
) {
|
||||
autoRefillForm.setValue("threshold", autoTopUpConfig.threshold / 100);
|
||||
autoRefillForm.setValue("refillAmount", autoTopUpConfig.amount / 100);
|
||||
}
|
||||
}, [autoTopUpConfig, autoRefillForm]);
|
||||
|
||||
const submitTopUp = useCallback(
|
||||
async (data: z.infer<typeof topUpSchema>) => {
|
||||
setIsLoading(true);
|
||||
await requestTopUp(data.amount * 100).catch(
|
||||
toastOnFail("request top-up"),
|
||||
);
|
||||
setIsLoading(false);
|
||||
},
|
||||
[requestTopUp, toastOnFail],
|
||||
);
|
||||
|
||||
const submitAutoTopUpConfig = useCallback(
|
||||
async (data: z.infer<typeof autoRefillSchema>) => {
|
||||
setIsLoading(true);
|
||||
await updateAutoTopUpConfig(data.refillAmount * 100, data.threshold * 100)
|
||||
.then(() => {
|
||||
toast({ title: "Auto top-up config updated! 🎉" });
|
||||
})
|
||||
.catch(toastOnFail("update auto top-up config"));
|
||||
setIsLoading(false);
|
||||
},
|
||||
[updateAutoTopUpConfig, toast, toastOnFail],
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="mx-1 border-b border-zinc-300">
|
||||
<p className="mx-0 mt-4 font-sans text-xs font-medium text-violet-700">
|
||||
Add credits to your balance
|
||||
</p>
|
||||
<p className="mx-0 my-1 font-sans text-xs font-normal text-zinc-500">
|
||||
Choose a one-time top-up or set up automatic refills
|
||||
</p>
|
||||
<Tabs
|
||||
defaultValue="top-up"
|
||||
className="mb-6 mt-4 flex w-full flex-col items-center"
|
||||
>
|
||||
<TabsList className="mx-auto">
|
||||
<TabsTrigger value="top-up">One-time top up</TabsTrigger>
|
||||
<TabsTrigger value="auto-refill">Auto-refill</TabsTrigger>
|
||||
</TabsList>
|
||||
<div className="mt-4 w-full rounded-lg px-5 outline outline-1 outline-offset-2 outline-zinc-200">
|
||||
<TabsContent value="top-up" className="flex flex-col">
|
||||
<div className="mt-2 justify-start font-sans text-sm font-medium leading-snug text-zinc-900">
|
||||
One-time top-up
|
||||
</div>
|
||||
<div className="mt-1 justify-start font-sans text-xs font-normal leading-tight text-zinc-500">
|
||||
Enter an amount (min. $5) and add credits instantly.
|
||||
</div>
|
||||
<Form {...topUpForm}>
|
||||
<form onSubmit={topUpForm.handleSubmit(submitTopUp)}>
|
||||
<FormField
|
||||
control={topUpForm.control}
|
||||
name="amount"
|
||||
render={({ field }) => (
|
||||
<FormItem className="mb-6 mt-4">
|
||||
<FormLabel className="font-sans text-sm font-medium leading-snug text-zinc-800">
|
||||
Amount
|
||||
</FormLabel>
|
||||
<FormControl>
|
||||
<>
|
||||
<Input
|
||||
className={cn(
|
||||
"mt-2 rounded-3xl border-0 bg-white py-2 pl-6 pr-4 font-sans outline outline-1 outline-zinc-300",
|
||||
"focus:outline-2 focus:outline-offset-0 focus:outline-violet-700",
|
||||
)}
|
||||
type="number"
|
||||
step="1"
|
||||
{...field}
|
||||
/>
|
||||
<span className="absolute left-10 -translate-y-9 text-sm text-zinc-500">
|
||||
$
|
||||
</span>
|
||||
</>
|
||||
</FormControl>
|
||||
<FormMessage className="mt-2 font-sans text-xs font-normal leading-tight" />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<button
|
||||
className={cn(
|
||||
"mb-2 inline-flex h-10 w-24 items-center justify-center rounded-3xl bg-zinc-800 px-4 py-2",
|
||||
"font-sans text-sm font-medium leading-snug text-white",
|
||||
"transition-colors duration-200 hover:bg-zinc-700 disabled:bg-zinc-500",
|
||||
)}
|
||||
type="submit"
|
||||
disabled={isLoading}
|
||||
>
|
||||
Top up
|
||||
</button>
|
||||
</form>
|
||||
</Form>
|
||||
</TabsContent>
|
||||
<TabsContent value="auto-refill" className="flex flex-col">
|
||||
<div className="justify-start font-sans text-sm font-medium leading-snug text-zinc-900">
|
||||
Auto-refill
|
||||
</div>
|
||||
<div className="mt-1 justify-start font-sans text-xs font-normal leading-tight text-zinc-500">
|
||||
Choose a one-time top-up or set up automatic refills.
|
||||
</div>
|
||||
|
||||
<Form {...autoRefillForm}>
|
||||
<form
|
||||
onSubmit={autoRefillForm.handleSubmit(submitAutoTopUpConfig)}
|
||||
>
|
||||
<FormField
|
||||
control={autoRefillForm.control}
|
||||
name="threshold"
|
||||
render={({ field }) => (
|
||||
<FormItem className="mb-6 mt-4">
|
||||
<FormLabel className="font-sans text-sm font-medium leading-snug text-zinc-800">
|
||||
Refill when balance drops below:
|
||||
</FormLabel>
|
||||
<FormControl>
|
||||
<>
|
||||
<Input
|
||||
className={cn(
|
||||
"mt-2 rounded-3xl border-0 bg-white py-2 pl-6 pr-4 font-sans outline outline-1 outline-zinc-300",
|
||||
"focus:outline-2 focus:outline-offset-0 focus:outline-violet-700",
|
||||
)}
|
||||
type="number"
|
||||
step="1"
|
||||
{...field}
|
||||
/>
|
||||
<span className="absolute left-10 -translate-y-9 text-sm text-zinc-500">
|
||||
$
|
||||
</span>
|
||||
</>
|
||||
</FormControl>
|
||||
<FormMessage className="mt-2 font-sans text-xs font-normal leading-tight" />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={autoRefillForm.control}
|
||||
name="refillAmount"
|
||||
render={({ field }) => (
|
||||
<FormItem className="mb-6">
|
||||
<FormLabel className="font-sans text-sm font-medium leading-snug text-zinc-800">
|
||||
Add this amount:
|
||||
</FormLabel>
|
||||
<FormControl>
|
||||
<>
|
||||
<Input
|
||||
className={cn(
|
||||
"mt-2 rounded-3xl border-0 bg-white py-2 pl-6 pr-4 font-sans outline outline-1 outline-zinc-300",
|
||||
"focus:outline-2 focus:outline-offset-0 focus:outline-violet-700",
|
||||
)}
|
||||
type="number"
|
||||
step="1"
|
||||
{...field}
|
||||
/>
|
||||
<span className="absolute left-10 -translate-y-9 text-sm text-zinc-500">
|
||||
$
|
||||
</span>
|
||||
</>
|
||||
</FormControl>
|
||||
<FormMessage className="mt-2 font-sans text-xs font-normal leading-tight" />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<button
|
||||
className={cn(
|
||||
"mb-4 inline-flex h-10 w-40 items-center justify-center rounded-3xl bg-zinc-800 px-4 py-2",
|
||||
"font-sans text-sm font-medium leading-snug text-white",
|
||||
"transition-colors duration-200 hover:bg-zinc-700 disabled:bg-zinc-500",
|
||||
)}
|
||||
type="submit"
|
||||
disabled={isLoading}
|
||||
>
|
||||
Enable Auto-refill
|
||||
</button>
|
||||
</form>
|
||||
</Form>
|
||||
</TabsContent>
|
||||
<div className="mb-3 justify-start font-sans text-xs font-normal leading-tight">
|
||||
<span className="text-zinc-500">
|
||||
To update your billing details, head to{" "}
|
||||
</span>
|
||||
<Link
|
||||
href="/profile/credits"
|
||||
className="cursor-pointer text-zinc-800 underline"
|
||||
>
|
||||
Billing settings
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
</Tabs>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -25,12 +25,14 @@ interface AgentsSectionProps {
|
||||
sectionTitle: string;
|
||||
agents: Agent[];
|
||||
hideAvatars?: boolean;
|
||||
margin?: string;
|
||||
}
|
||||
|
||||
export const AgentsSection: React.FC<AgentsSectionProps> = ({
|
||||
sectionTitle,
|
||||
agents: allAgents,
|
||||
hideAvatars = false,
|
||||
margin = "37px",
|
||||
}) => {
|
||||
const router = useRouter();
|
||||
|
||||
@@ -44,9 +46,11 @@ export const AgentsSection: React.FC<AgentsSectionProps> = ({
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center pb-4 lg:pb-8">
|
||||
<div className="flex flex-col items-center justify-center">
|
||||
<div className="w-full max-w-[1360px]">
|
||||
<div className="mb-8 font-poppins text-[18px] font-[600] leading-9 text-neutral-800 dark:text-neutral-200">
|
||||
<div
|
||||
className={`mb-[${margin}] font-poppins text-lg font-semibold text-[#282828] dark:text-neutral-200`}
|
||||
>
|
||||
{sectionTitle}
|
||||
</div>
|
||||
{!displayedAgents || displayedAgents.length === 0 ? (
|
||||
@@ -81,7 +85,7 @@ export const AgentsSection: React.FC<AgentsSectionProps> = ({
|
||||
</CarouselContent>
|
||||
</Carousel>
|
||||
|
||||
<div className="hidden grid-cols-1 place-items-center gap-6 md:grid md:grid-cols-2 lg:grid-cols-4">
|
||||
<div className="hidden grid-cols-1 place-items-center gap-6 md:grid md:grid-cols-2 lg:grid-cols-3 2xl:grid-cols-4">
|
||||
{displayedAgents.map((agent, index) => (
|
||||
<StoreCard
|
||||
key={index}
|
||||
|
||||
@@ -31,9 +31,9 @@ export const FeaturedCreators: React.FC<FeaturedCreatorsProps> = ({
|
||||
const displayedCreators = featuredCreators.slice(0, 4);
|
||||
|
||||
return (
|
||||
<div className="flex w-full flex-col items-center justify-center py-16">
|
||||
<div className="flex w-full flex-col items-center justify-center">
|
||||
<div className="w-full max-w-[1360px]">
|
||||
<h2 className="mb-8 text-left font-poppins text-[18px] font-[600] leading-9 text-neutral-800 dark:text-neutral-200">
|
||||
<h2 className="mb-9 font-poppins text-lg font-semibold text-neutral-800 dark:text-neutral-200">
|
||||
{title}
|
||||
</h2>
|
||||
|
||||
|
||||
@@ -46,8 +46,8 @@ export const FeaturedSection: React.FC<FeaturedSectionProps> = ({
|
||||
};
|
||||
|
||||
return (
|
||||
<section className="mx-auto w-full max-w-7xl px-4 pb-16">
|
||||
<h2 className="mb-8 text-left font-poppins text-[18px] font-[600] leading-9 text-neutral-800 dark:text-neutral-200">
|
||||
<section className="w-full">
|
||||
<h2 className="mb-8 font-poppins text-2xl font-semibold leading-7 text-neutral-800 dark:text-neutral-200">
|
||||
Featured agents
|
||||
</h2>
|
||||
|
||||
|
||||
@@ -4,9 +4,16 @@ import * as React from "react";
|
||||
import { SearchBar } from "@/components/agptui/SearchBar";
|
||||
import { FilterChips } from "@/components/agptui/FilterChips";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
|
||||
|
||||
export const HeroSection: React.FC = () => {
|
||||
const router = useRouter();
|
||||
const { completeStep } = useOnboarding();
|
||||
|
||||
// Mark marketplace visit task as completed
|
||||
React.useEffect(() => {
|
||||
completeStep("MARKETPLACE_VISIT");
|
||||
}, [completeStep]);
|
||||
|
||||
function onFilterChange(selectedFilters: string[]) {
|
||||
const encodedTerm = encodeURIComponent(selectedFilters.join(", "));
|
||||
|
||||
@@ -147,13 +147,3 @@
|
||||
.custom-switch {
|
||||
padding-left: 2px;
|
||||
}
|
||||
|
||||
input[type="number"]::-webkit-outer-spin-button,
|
||||
input[type="number"]::-webkit-inner-spin-button {
|
||||
-webkit-appearance: none;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
input[type="number"] {
|
||||
-moz-appearance: textfield;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { FC, useEffect, useMemo, useState } from "react";
|
||||
import { z } from "zod";
|
||||
import { beautifyString, cn } from "@/lib/utils";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useForm } from "react-hook-form";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Button } from "@/components/ui/button";
|
||||
@@ -16,8 +17,8 @@ import {
|
||||
FaKey,
|
||||
FaHubspot,
|
||||
} from "react-icons/fa";
|
||||
import { FC, useMemo, useState } from "react";
|
||||
import {
|
||||
BlockIOCredentialsSubSchema,
|
||||
CredentialsMetaInput,
|
||||
CredentialsProviderName,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
@@ -106,13 +107,18 @@ export type OAuthPopupResultMessage = { message_type: "oauth_popup_result" } & (
|
||||
);
|
||||
|
||||
export const CredentialsInput: FC<{
|
||||
selfKey: string;
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
className?: string;
|
||||
selectedCredentials?: CredentialsMetaInput;
|
||||
onSelectCredentials: (newValue?: CredentialsMetaInput) => void;
|
||||
}> = ({ selfKey, className, selectedCredentials, onSelectCredentials }) => {
|
||||
const api = useBackendAPI();
|
||||
const credentials = useCredentials(selfKey);
|
||||
siblingInputs?: Record<string, any>;
|
||||
}> = ({
|
||||
schema,
|
||||
className,
|
||||
selectedCredentials,
|
||||
onSelectCredentials,
|
||||
siblingInputs,
|
||||
}) => {
|
||||
const [isAPICredentialsModalOpen, setAPICredentialsModalOpen] =
|
||||
useState(false);
|
||||
const [
|
||||
@@ -124,20 +130,47 @@ export const CredentialsInput: FC<{
|
||||
useState<AbortController | null>(null);
|
||||
const [oAuthError, setOAuthError] = useState<string | null>(null);
|
||||
|
||||
if (!credentials || credentials.isLoading) {
|
||||
const api = useBackendAPI();
|
||||
const credentials = useCredentials(schema, siblingInputs);
|
||||
|
||||
// Deselect credentials if they do not exist (e.g. provider was changed)
|
||||
useEffect(() => {
|
||||
if (!credentials || !("savedCredentials" in credentials)) return;
|
||||
if (
|
||||
selectedCredentials &&
|
||||
!credentials.savedCredentials.some((c) => c.id === selectedCredentials.id)
|
||||
) {
|
||||
onSelectCredentials(undefined);
|
||||
}
|
||||
}, [credentials, selectedCredentials, onSelectCredentials]);
|
||||
|
||||
const singleCredential = useMemo(() => {
|
||||
if (!credentials || !("savedCredentials" in credentials)) return null;
|
||||
|
||||
if (credentials.savedCredentials.length === 1)
|
||||
return credentials.savedCredentials[0];
|
||||
|
||||
return null;
|
||||
}, [credentials]);
|
||||
|
||||
// If only 1 credential is available, auto-select it and hide this input
|
||||
useEffect(() => {
|
||||
if (singleCredential && !selectedCredentials) {
|
||||
onSelectCredentials(singleCredential);
|
||||
}
|
||||
}, [singleCredential, selectedCredentials, onSelectCredentials]);
|
||||
|
||||
if (!credentials || credentials.isLoading || singleCredential) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const {
|
||||
schema,
|
||||
provider,
|
||||
providerName,
|
||||
supportsApiKey,
|
||||
supportsOAuth2,
|
||||
supportsUserPassword,
|
||||
savedApiKeys,
|
||||
savedOAuthCredentials,
|
||||
savedUserPasswordCredentials,
|
||||
savedCredentials,
|
||||
oAuthCallback,
|
||||
} = credentials;
|
||||
|
||||
@@ -235,13 +268,14 @@ export const CredentialsInput: FC<{
|
||||
<>
|
||||
{supportsApiKey && (
|
||||
<APIKeyCredentialsModal
|
||||
credentialsFieldName={selfKey}
|
||||
schema={schema}
|
||||
open={isAPICredentialsModalOpen}
|
||||
onClose={() => setAPICredentialsModalOpen(false)}
|
||||
onCredentialsCreate={(credsMeta) => {
|
||||
onSelectCredentials(credsMeta);
|
||||
setAPICredentialsModalOpen(false);
|
||||
}}
|
||||
siblingInputs={siblingInputs}
|
||||
/>
|
||||
)}
|
||||
{supportsOAuth2 && (
|
||||
@@ -253,43 +287,34 @@ export const CredentialsInput: FC<{
|
||||
)}
|
||||
{supportsUserPassword && (
|
||||
<UserPasswordCredentialsModal
|
||||
credentialsFieldName={selfKey}
|
||||
schema={schema}
|
||||
open={isUserPasswordCredentialsModalOpen}
|
||||
onClose={() => setUserPasswordCredentialsModalOpen(false)}
|
||||
onCredentialsCreate={(creds) => {
|
||||
onSelectCredentials(creds);
|
||||
setUserPasswordCredentialsModalOpen(false);
|
||||
}}
|
||||
siblingInputs={siblingInputs}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
|
||||
// Deselect credentials if they do not exist (e.g. provider was changed)
|
||||
if (
|
||||
selectedCredentials &&
|
||||
!savedApiKeys
|
||||
.concat(savedOAuthCredentials)
|
||||
.concat(savedUserPasswordCredentials)
|
||||
.some((c) => c.id === selectedCredentials.id)
|
||||
) {
|
||||
onSelectCredentials(undefined);
|
||||
}
|
||||
const fieldHeader = (
|
||||
<div className="mb-2 flex gap-1">
|
||||
<span className="text-m green text-gray-900">
|
||||
{providerName} Credentials
|
||||
</span>
|
||||
<SchemaTooltip description={schema.description} />
|
||||
</div>
|
||||
);
|
||||
|
||||
// No saved credentials yet
|
||||
if (
|
||||
savedApiKeys.length === 0 &&
|
||||
savedOAuthCredentials.length === 0 &&
|
||||
savedUserPasswordCredentials.length === 0
|
||||
) {
|
||||
if (savedCredentials.length === 0) {
|
||||
return (
|
||||
<>
|
||||
<div className="mb-2 flex gap-1">
|
||||
<span className="text-m green text-gray-900">
|
||||
{providerName} Credentials
|
||||
</span>
|
||||
<SchemaTooltip description={schema.description} />
|
||||
</div>
|
||||
<div>
|
||||
{fieldHeader}
|
||||
|
||||
<div className={cn("flex flex-row space-x-2", className)}>
|
||||
{supportsOAuth2 && (
|
||||
<Button onClick={handleOAuthLogin}>
|
||||
@@ -314,46 +339,10 @@ export const CredentialsInput: FC<{
|
||||
{oAuthError && (
|
||||
<div className="mt-2 text-red-500">Error: {oAuthError}</div>
|
||||
)}
|
||||
</>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const getCredentialCounts = () => ({
|
||||
apiKeys: savedApiKeys.length,
|
||||
oauth: savedOAuthCredentials.length,
|
||||
userPass: savedUserPasswordCredentials.length,
|
||||
});
|
||||
|
||||
const getSingleCredential = () => {
|
||||
const counts = getCredentialCounts();
|
||||
const totalCredentials = Object.values(counts).reduce(
|
||||
(sum, count) => sum + count,
|
||||
0,
|
||||
);
|
||||
|
||||
if (totalCredentials !== 1) return null;
|
||||
|
||||
if (counts.apiKeys === 1) return savedApiKeys[0];
|
||||
if (counts.oauth === 1) return savedOAuthCredentials[0];
|
||||
if (counts.userPass === 1) return savedUserPasswordCredentials[0];
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
const singleCredential = getSingleCredential();
|
||||
|
||||
if (singleCredential) {
|
||||
if (!selectedCredentials) {
|
||||
onSelectCredentials({
|
||||
id: singleCredential.id,
|
||||
type: singleCredential.type,
|
||||
provider,
|
||||
title: singleCredential.title,
|
||||
});
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function handleValueChange(newValue: string) {
|
||||
if (newValue === "sign-in") {
|
||||
// Trigger OAuth2 sign in flow
|
||||
@@ -362,10 +351,7 @@ export const CredentialsInput: FC<{
|
||||
// Open API key dialog
|
||||
setAPICredentialsModalOpen(true);
|
||||
} else {
|
||||
const selectedCreds = savedApiKeys
|
||||
.concat(savedOAuthCredentials)
|
||||
.concat(savedUserPasswordCredentials)
|
||||
.find((c) => c.id == newValue)!;
|
||||
const selectedCreds = savedCredentials.find((c) => c.id == newValue)!;
|
||||
|
||||
onSelectCredentials({
|
||||
id: selectedCreds.id,
|
||||
@@ -378,38 +364,40 @@ export const CredentialsInput: FC<{
|
||||
|
||||
// Saved credentials exist
|
||||
return (
|
||||
<>
|
||||
<div className="flex gap-1">
|
||||
<span className="text-m green mb-0 text-gray-900">
|
||||
{providerName} Credentials
|
||||
</span>
|
||||
<SchemaTooltip description={schema.description} />
|
||||
</div>
|
||||
<div>
|
||||
{fieldHeader}
|
||||
|
||||
<Select value={selectedCredentials?.id} onValueChange={handleValueChange}>
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder={schema.placeholder} />
|
||||
</SelectTrigger>
|
||||
<SelectContent className="nodrag">
|
||||
{savedOAuthCredentials.map((credentials, index) => (
|
||||
<SelectItem key={index} value={credentials.id}>
|
||||
<ProviderIcon className="mr-2 inline h-4 w-4" />
|
||||
{credentials.username}
|
||||
</SelectItem>
|
||||
))}
|
||||
{savedApiKeys.map((credentials, index) => (
|
||||
<SelectItem key={index} value={credentials.id}>
|
||||
<ProviderIcon className="mr-2 inline h-4 w-4" />
|
||||
<IconKey className="mr-1.5 inline" />
|
||||
{credentials.title}
|
||||
</SelectItem>
|
||||
))}
|
||||
{savedUserPasswordCredentials.map((credentials, index) => (
|
||||
<SelectItem key={index} value={credentials.id}>
|
||||
<ProviderIcon className="mr-2 inline h-4 w-4" />
|
||||
<IconUserPlus className="mr-1.5 inline" />
|
||||
{credentials.title}
|
||||
</SelectItem>
|
||||
))}
|
||||
{savedCredentials
|
||||
.filter((c) => c.type == "oauth2")
|
||||
.map((credentials, index) => (
|
||||
<SelectItem key={index} value={credentials.id}>
|
||||
<ProviderIcon className="mr-2 inline h-4 w-4" />
|
||||
{credentials.username}
|
||||
</SelectItem>
|
||||
))}
|
||||
{savedCredentials
|
||||
.filter((c) => c.type == "api_key")
|
||||
.map((credentials, index) => (
|
||||
<SelectItem key={index} value={credentials.id}>
|
||||
<ProviderIcon className="mr-2 inline h-4 w-4" />
|
||||
<IconKey className="mr-1.5 inline" />
|
||||
{credentials.title}
|
||||
</SelectItem>
|
||||
))}
|
||||
{savedCredentials
|
||||
.filter((c) => c.type == "user_password")
|
||||
.map((credentials, index) => (
|
||||
<SelectItem key={index} value={credentials.id}>
|
||||
<ProviderIcon className="mr-2 inline h-4 w-4" />
|
||||
<IconUserPlus className="mr-1.5 inline" />
|
||||
{credentials.title}
|
||||
</SelectItem>
|
||||
))}
|
||||
<SelectSeparator />
|
||||
{supportsOAuth2 && (
|
||||
<SelectItem value="sign-in">
|
||||
@@ -435,17 +423,18 @@ export const CredentialsInput: FC<{
|
||||
{oAuthError && (
|
||||
<div className="mt-2 text-red-500">Error: {oAuthError}</div>
|
||||
)}
|
||||
</>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export const APIKeyCredentialsModal: FC<{
|
||||
credentialsFieldName: string;
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
|
||||
}> = ({ credentialsFieldName, open, onClose, onCredentialsCreate }) => {
|
||||
const credentials = useCredentials(credentialsFieldName);
|
||||
siblingInputs?: Record<string, any>;
|
||||
}> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
|
||||
const credentials = useCredentials(schema, siblingInputs);
|
||||
|
||||
const formSchema = z.object({
|
||||
apiKey: z.string().min(1, "API Key is required"),
|
||||
@@ -466,8 +455,7 @@ export const APIKeyCredentialsModal: FC<{
|
||||
return null;
|
||||
}
|
||||
|
||||
const { schema, provider, providerName, createAPIKeyCredentials } =
|
||||
credentials;
|
||||
const { provider, providerName, createAPIKeyCredentials } = credentials;
|
||||
|
||||
async function onSubmit(values: z.infer<typeof formSchema>) {
|
||||
const expiresAt = values.expiresAt
|
||||
@@ -576,12 +564,13 @@ export const APIKeyCredentialsModal: FC<{
|
||||
};
|
||||
|
||||
export const UserPasswordCredentialsModal: FC<{
|
||||
credentialsFieldName: string;
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
|
||||
}> = ({ credentialsFieldName, open, onClose, onCredentialsCreate }) => {
|
||||
const credentials = useCredentials(credentialsFieldName);
|
||||
siblingInputs?: Record<string, any>;
|
||||
}> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
|
||||
const credentials = useCredentials(schema, siblingInputs);
|
||||
|
||||
const formSchema = z.object({
|
||||
username: z.string().min(1, "Username is required"),
|
||||
@@ -606,8 +595,7 @@ export const UserPasswordCredentialsModal: FC<{
|
||||
return null;
|
||||
}
|
||||
|
||||
const { schema, provider, providerName, createUserPasswordCredentials } =
|
||||
credentials;
|
||||
const { provider, providerName, createUserPasswordCredentials } = credentials;
|
||||
|
||||
async function onSubmit(values: z.infer<typeof formSchema>) {
|
||||
const newCredentials = await createUserPasswordCredentials({
|
||||
|
||||
@@ -68,9 +68,7 @@ type UserPasswordCredentialsCreatable = Omit<
|
||||
export type CredentialsProviderData = {
|
||||
provider: CredentialsProviderName;
|
||||
providerName: string;
|
||||
savedApiKeys: CredentialsMetaResponse[];
|
||||
savedOAuthCredentials: CredentialsMetaResponse[];
|
||||
savedUserPasswordCredentials: CredentialsMetaResponse[];
|
||||
savedCredentials: CredentialsMetaResponse[];
|
||||
oAuthCallback: (
|
||||
code: string,
|
||||
state_token: string,
|
||||
@@ -113,28 +111,12 @@ export default function CredentialsProvider({
|
||||
setProviders((prev) => {
|
||||
if (!prev || !prev[provider]) return prev;
|
||||
|
||||
const updatedProvider = { ...prev[provider] };
|
||||
|
||||
if (credentials.type === "api_key") {
|
||||
updatedProvider.savedApiKeys = [
|
||||
...updatedProvider.savedApiKeys,
|
||||
credentials,
|
||||
];
|
||||
} else if (credentials.type === "oauth2") {
|
||||
updatedProvider.savedOAuthCredentials = [
|
||||
...updatedProvider.savedOAuthCredentials,
|
||||
credentials,
|
||||
];
|
||||
} else if (credentials.type === "user_password") {
|
||||
updatedProvider.savedUserPasswordCredentials = [
|
||||
...updatedProvider.savedUserPasswordCredentials,
|
||||
credentials,
|
||||
];
|
||||
}
|
||||
|
||||
return {
|
||||
...prev,
|
||||
[provider]: updatedProvider,
|
||||
[provider]: {
|
||||
...prev[provider],
|
||||
savedCredentials: [...prev[provider].savedCredentials, credentials],
|
||||
},
|
||||
};
|
||||
});
|
||||
},
|
||||
@@ -203,21 +185,14 @@ export default function CredentialsProvider({
|
||||
setProviders((prev) => {
|
||||
if (!prev || !prev[provider]) return prev;
|
||||
|
||||
const updatedProvider = { ...prev[provider] };
|
||||
updatedProvider.savedApiKeys = updatedProvider.savedApiKeys.filter(
|
||||
(cred) => cred.id !== id,
|
||||
);
|
||||
updatedProvider.savedOAuthCredentials =
|
||||
updatedProvider.savedOAuthCredentials.filter(
|
||||
(cred) => cred.id !== id,
|
||||
);
|
||||
updatedProvider.savedUserPasswordCredentials =
|
||||
updatedProvider.savedUserPasswordCredentials.filter(
|
||||
(cred) => cred.id !== id,
|
||||
);
|
||||
return {
|
||||
...prev,
|
||||
[provider]: updatedProvider,
|
||||
[provider]: {
|
||||
...prev[provider],
|
||||
savedCredentials: prev[provider].savedCredentials.filter(
|
||||
(cred) => cred.id !== id,
|
||||
),
|
||||
},
|
||||
};
|
||||
});
|
||||
return result;
|
||||
@@ -233,29 +208,12 @@ export default function CredentialsProvider({
|
||||
const credentialsByProvider = response.reduce(
|
||||
(acc, cred) => {
|
||||
if (!acc[cred.provider]) {
|
||||
acc[cred.provider] = {
|
||||
oauthCreds: [],
|
||||
apiKeys: [],
|
||||
userPasswordCreds: [],
|
||||
};
|
||||
}
|
||||
if (cred.type === "oauth2") {
|
||||
acc[cred.provider].oauthCreds.push(cred);
|
||||
} else if (cred.type === "api_key") {
|
||||
acc[cred.provider].apiKeys.push(cred);
|
||||
} else if (cred.type === "user_password") {
|
||||
acc[cred.provider].userPasswordCreds.push(cred);
|
||||
acc[cred.provider] = [];
|
||||
}
|
||||
acc[cred.provider].push(cred);
|
||||
return acc;
|
||||
},
|
||||
{} as Record<
|
||||
CredentialsProviderName,
|
||||
{
|
||||
oauthCreds: CredentialsMetaResponse[];
|
||||
apiKeys: CredentialsMetaResponse[];
|
||||
userPasswordCreds: CredentialsMetaResponse[];
|
||||
}
|
||||
>,
|
||||
{} as Record<CredentialsProviderName, CredentialsMetaResponse[]>,
|
||||
);
|
||||
|
||||
setProviders((prev) => ({
|
||||
@@ -265,40 +223,19 @@ export default function CredentialsProvider({
|
||||
provider,
|
||||
{
|
||||
provider,
|
||||
providerName:
|
||||
providerDisplayNames[provider as CredentialsProviderName],
|
||||
savedApiKeys: credentialsByProvider[provider]?.apiKeys ?? [],
|
||||
savedOAuthCredentials:
|
||||
credentialsByProvider[provider]?.oauthCreds ?? [],
|
||||
savedUserPasswordCredentials:
|
||||
credentialsByProvider[provider]?.userPasswordCreds ?? [],
|
||||
providerName: providerDisplayNames[provider],
|
||||
savedCredentials: credentialsByProvider[provider] ?? [],
|
||||
oAuthCallback: (code: string, state_token: string) =>
|
||||
oAuthCallback(
|
||||
provider as CredentialsProviderName,
|
||||
code,
|
||||
state_token,
|
||||
),
|
||||
oAuthCallback(provider, code, state_token),
|
||||
createAPIKeyCredentials: (
|
||||
credentials: APIKeyCredentialsCreatable,
|
||||
) =>
|
||||
createAPIKeyCredentials(
|
||||
provider as CredentialsProviderName,
|
||||
credentials,
|
||||
),
|
||||
) => createAPIKeyCredentials(provider, credentials),
|
||||
createUserPasswordCredentials: (
|
||||
credentials: UserPasswordCredentialsCreatable,
|
||||
) =>
|
||||
createUserPasswordCredentials(
|
||||
provider as CredentialsProviderName,
|
||||
credentials,
|
||||
),
|
||||
) => createUserPasswordCredentials(provider, credentials),
|
||||
deleteCredentials: (id: string, force: boolean = false) =>
|
||||
deleteCredentials(
|
||||
provider as CredentialsProviderName,
|
||||
id,
|
||||
force,
|
||||
),
|
||||
},
|
||||
deleteCredentials(provider, id, force),
|
||||
} satisfies CredentialsProviderData,
|
||||
]),
|
||||
),
|
||||
}));
|
||||
|
||||
@@ -8,7 +8,7 @@ export default function LibraryAgentCard({
|
||||
id,
|
||||
name,
|
||||
description,
|
||||
agent_id,
|
||||
graph_id: agent_id,
|
||||
can_access_graph,
|
||||
creator_image_url,
|
||||
image_url,
|
||||
@@ -48,7 +48,7 @@ export default function LibraryAgentCard({
|
||||
/>
|
||||
)}
|
||||
<div className="absolute bottom-4 left-4">
|
||||
<Avatar className="h-16 w-16 border-2 border-white dark:border-gray-800">
|
||||
<Avatar className="h-16 w-16">
|
||||
<AvatarImage
|
||||
src={
|
||||
creator_image_url
|
||||
@@ -57,7 +57,7 @@ export default function LibraryAgentCard({
|
||||
}
|
||||
alt={`${name} creator avatar`}
|
||||
/>
|
||||
<AvatarFallback>{name.charAt(0)}</AvatarFallback>
|
||||
<AvatarFallback size={64}>{name.charAt(0)}</AvatarFallback>
|
||||
</Avatar>
|
||||
</div>
|
||||
</Link>
|
||||
|
||||
@@ -49,7 +49,7 @@ export default function LibrarySearchBar(): React.ReactNode {
|
||||
onFocus={() => setIsFocused(true)}
|
||||
onBlur={() => !inputRef.current?.value && setIsFocused(false)}
|
||||
onChange={handleSearchInput}
|
||||
className="flex-1 border-none font-sans text-[16px] font-normal leading-7 shadow-none focus:shadow-none"
|
||||
className="flex-1 border-none font-sans text-[16px] font-normal leading-7 shadow-none focus:shadow-none focus:ring-0"
|
||||
type="text"
|
||||
placeholder="Search agents"
|
||||
/>
|
||||
|
||||
@@ -109,7 +109,7 @@ export const AgentFlowList = ({
|
||||
lastRun: GraphExecutionMeta | null = null;
|
||||
if (executions) {
|
||||
const _flowRuns = executions.filter(
|
||||
(r) => r.graph_id == flow.agent_id,
|
||||
(r) => r.graph_id == flow.graph_id,
|
||||
);
|
||||
runCount = _flowRuns.length;
|
||||
lastRun =
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user