Merge branch 'ntindle/secrt-1079-add-ability-to-send-emails-from-notification-service' into ntindle/secrt-1077-add-email-service

This commit is contained in:
Nicholas Tindle
2025-02-14 00:17:13 -06:00
committed by GitHub
39 changed files with 1824 additions and 490 deletions

View File

@@ -137,9 +137,9 @@ jobs:
SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
REDIS_HOST: "localhost"
REDIS_PORT: "6379"
REDIS_PASSWORD: "testpassword"
REDIS_HOST: 'localhost'
REDIS_PORT: '6379'
REDIS_PASSWORD: 'testpassword'
env:
CI: true
@@ -152,8 +152,8 @@ jobs:
# If you want to replace this, you can do so by making our entire system generate
# new credentials for each local user and update the environment variables in
# the backend service, docker composes, and examples
RABBITMQ_DEFAULT_USER: rabbitmq_user_default
RABBITMQ_DEFAULT_PASS: k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
RABBITMQ_DEFAULT_USER: 'rabbitmq_user_default'
RABBITMQ_DEFAULT_PASS: 'k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7'
# - name: Upload coverage reports to Codecov
# uses: codecov/codecov-action@v4

View File

@@ -77,8 +77,8 @@ jobs:
- name: Free Disk Space (Ubuntu)
uses: jlumbroso/free-disk-space@main
with:
large-packages: false # slow
docker-images: false # limited benefit
large-packages: false # slow
docker-images: false # limited benefit
- name: Copy default supabase .env
run: |
@@ -104,11 +104,12 @@ jobs:
run: yarn playwright install --with-deps ${{ matrix.browser }}
- name: Run tests
timeout-minutes: 20
run: |
yarn test --project=${{ matrix.browser }}
- name: Print Docker Compose logs in debug mode
if: runner.debug
- name: Print Final Docker Compose logs
if: always()
run: |
docker compose -f ../docker-compose.yml logs

View File

@@ -141,6 +141,7 @@ async def create_graph_execution(
graph_version: int,
nodes_input: list[tuple[str, BlockInput]],
user_id: str,
preset_id: str | None = None,
) -> tuple[str, list[ExecutionResult]]:
"""
Create a new AgentGraphExecution record.
@@ -168,6 +169,7 @@ async def create_graph_execution(
]
},
"userId": user_id,
"agentPresetId": preset_id,
},
include=GRAPH_EXECUTION_INCLUDE,
)

View File

@@ -70,11 +70,9 @@ class NodeModel(Node):
@staticmethod
def from_db(node: AgentNode):
if not node.AgentBlock:
raise ValueError(f"Invalid node {node.id}, invalid AgentBlock.")
obj = NodeModel(
id=node.id,
block_id=node.AgentBlock.id,
block_id=node.agentBlockId,
input_default=type.convert(node.constantInput, dict[str, Any]),
metadata=type.convert(node.metadata, dict[str, Any]),
graph_id=node.agentGraphId,
@@ -534,7 +532,7 @@ async def get_execution(user_id: str, execution_id: str) -> GraphExecution | Non
async def get_graph(
graph_id: str,
version: int | None = None,
template: bool = False,
template: bool = False, # note: currently not in use; TODO: remove from DB entirely
user_id: str | None = None,
for_export: bool = False,
) -> GraphModel | None:
@@ -716,11 +714,9 @@ async def fix_llm_provider_credentials():
store = IntegrationCredentialsStore()
broken_nodes = []
try:
broken_nodes = await prisma.get_client().query_raw(
"""
SELECT graph."userId" user_id,
broken_nodes = await prisma.get_client().query_raw(
"""
SELECT graph."userId" user_id,
node.id node_id,
node."constantInput" node_preset_input
FROM platform."AgentNode" node
@@ -729,10 +725,8 @@ async def fix_llm_provider_credentials():
WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm'
ORDER BY graph."userId";
"""
)
logger.info(f"Fixing LLM credential inputs on {len(broken_nodes)} nodes")
except Exception as e:
logger.error(f"Error fixing LLM credential inputs: {e}")
)
logger.info(f"Fixing LLM credential inputs on {len(broken_nodes)} nodes")
user_id: str = ""
user_integrations = None

View File

@@ -1,7 +1,7 @@
import logging
from datetime import datetime, timedelta
from enum import Enum
from typing import Annotated, Generic, Optional, TypeVar, Union
from typing import Annotated, Any, Generic, Optional, TypeVar, Union
from prisma import Json
from prisma.enums import NotificationType
@@ -35,10 +35,10 @@ class BaseNotificationData(BaseModel):
class AgentRunData(BaseNotificationData):
agent_name: str
credits_used: float
# remaining_balance: float
execution_time: float
graph_id: str
node_count: int = Field(..., description="Number of nodes executed")
graph_id: str
outputs: dict[str, Any] = Field(..., description="Outputs of the agent")
class ZeroBalanceData(BaseNotificationData):
@@ -203,6 +203,19 @@ class NotificationTypeOverride:
NotificationType.MONTHLY_SUMMARY: "monthly_summary.html",
}[self.notification_type]
@property
def subject(self) -> str:
return {
NotificationType.AGENT_RUN: "Agent Run Report",
NotificationType.ZERO_BALANCE: "You're out of credits!",
NotificationType.LOW_BALANCE: "Low Balance Warning!",
NotificationType.BLOCK_EXECUTION_FAILED: "Uh oh! Block Execution Failed",
NotificationType.CONTINUOUS_AGENT_ERROR: "Shoot! Continuous Agent Error",
NotificationType.DAILY_SUMMARY: "Here's your daily summary!",
NotificationType.WEEKLY_SUMMARY: "Look at all the cool stuff you did last week!",
NotificationType.MONTHLY_SUMMARY: "We did a lot this month!",
}[self.notification_type]
class NotificationPreference(BaseModel):
user_id: str

View File

@@ -206,13 +206,14 @@ def execute_node(
# This is fine because for now, there is no block that is charged by time.
cost = db_client.spend_credits(data, input_size + output_size, 0)
outputs: dict[str, Any] = {}
for output_name, output_data in node_block.execute(
input_data, **extra_exec_kwargs
):
output_size += len(json.dumps(output_data))
log_metadata.info("Node produced output", **{output_name: output_data})
db_client.upsert_execution_output(node_exec_id, output_name, output_data)
outputs[output_name] = output_data
for execution in _enqueue_next_nodes(
db_client=db_client,
node=node,
@@ -230,6 +231,7 @@ def execute_node(
user_id=user_id,
type=NotificationType.AGENT_RUN,
data=AgentRunData(
outputs=outputs,
agent_name=node_block.name,
credits_used=cost,
execution_time=0,
@@ -831,6 +833,7 @@ class ExecutionManager(AppService):
data: BlockInput,
user_id: str,
graph_version: Optional[int] = None,
preset_id: str | None = None,
) -> GraphExecutionEntry:
graph: GraphModel | None = self.db_client.get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
@@ -852,9 +855,9 @@ class ExecutionManager(AppService):
# Extract request input data, and assign it to the input pin.
if block.block_type == BlockType.INPUT:
name = node.input_default.get("name")
if name in data.get("node_input", {}):
input_data = {"value": data["node_input"][name]}
input_name = node.input_default.get("name")
if input_name and input_name in data:
input_data = {"value": data[input_name]}
# Extract webhook payload, and assign it to the input pin
webhook_payload_key = f"webhook_{node.webhook_id}_payload"
@@ -879,6 +882,7 @@ class ExecutionManager(AppService):
graph_version=graph.version,
nodes_input=nodes_input,
user_id=user_id,
preset_id=preset_id,
)
starting_node_execs = []

View File

@@ -4,6 +4,7 @@ import pathlib
from postmarker.core import PostmarkClient
from postmarker.models.emails import EmailManager
from prisma.enums import NotificationType
from pydantic import BaseModel
from backend.data.notifications import (
NotificationEventModel,
@@ -24,6 +25,12 @@ class TypedPostmarkClient(PostmarkClient):
emails: EmailManager
class Template(BaseModel):
subject: str
body: str
base_template: str
class EmailSender:
def __init__(self):
if settings.secrets.postmark_server_api_token:
@@ -42,33 +49,48 @@ class EmailSender:
user_email: str,
data: NotificationEventModel[T_co] | list[NotificationEventModel[T_co]],
):
"""Send an email to a user using a template pulled from the notification type"""
if not self.postmark:
logger.warning("Postmark client not initialized, email not sent")
return
body = self._get_template(notification)
# use the jinja2 library to render the template
body = self.formatter.format_string(body, data)
logger.info(
f"Sending email to {user_email} with subject {"subject"} and body {body}"
)
self._send_email(user_email, "subject", body)
template = self._get_template(notification)
try:
subject, full_message = self.formatter.format_email(
base_template=template.base_template,
subject_template=template.subject,
content_template=template.body,
data=data,
unsubscribe_link="https://autogpt.com/unsubscribe",
)
except Exception as e:
logger.error(f"Error formatting full message: {e}")
raise e
self._send_email(user_email, subject, full_message)
def _get_template(self, notification: NotificationType):
# convert the notification type to a notification type override
notification_type_override = NotificationTypeOverride(notification)
# find the template in templates/name.html (the .template returns with the .html)
template_path = f"templates/{notification_type_override.template}.jinja2"
logger.info(
logger.debug(
f"Template full path: {pathlib.Path(__file__).parent / template_path}"
)
base_template_path = "templates/base.html.jinja2"
with open(pathlib.Path(__file__).parent / base_template_path, "r") as file:
base_template = file.read()
with open(pathlib.Path(__file__).parent / template_path, "r") as file:
template = file.read()
return template
return Template(
subject=notification_type_override.subject,
body=template,
base_template=base_template,
)
def _send_email(self, user_email: str, subject: str, body: str):
logger.info(
f"Sending email to {user_email} with subject {subject} and body {body}"
)
logger.debug(f"Sending email to {user_email} with subject {subject}")
self.postmark.emails.send(
From=settings.config.postmark_sender_email,
To=user_email,

View File

@@ -131,9 +131,8 @@ class NotificationManager(AppService):
def __init__(self):
super().__init__()
self.use_db = True
self.use_async = False # Use async RabbitMQ client
self.use_rabbitmq = create_notification_config()
self.summary_manager = SummaryManager()
self.rabbitmq_config = create_notification_config()
self.running = True
self.email_sender = EmailSender()
@@ -296,7 +295,7 @@ class NotificationManager(AppService):
return datetime(now.year, now.month - 1, 1)
async def _process_immediate(self, message: str) -> bool:
"""Process a single notification immediately"""
"""Process a single notification immediately, returning whether to put into the failed queue"""
try:
event = NotificationEventDTO.model_validate_json(message)
parsed_event = NotificationEventModel[

View File

@@ -1,6 +1,75 @@
AGENT RUN
{{data.name}}
{{data.node_count}}
{{data.execution_time}}
{{data.graph_id}}
{# Agent Run #}
{# Template variables:
data.name: the name of the agent
data.credits_used: the number of credits used by the agent
data.node_count: the number of nodes the agent ran on
data.execution_time: the time it took to run the agent
data.graph_id: the id of the graph the agent ran on
data.outputs: the dict[str, Any] of outputs of the agent
#}
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
line-height: 165%;
margin-top: 0;
margin-bottom: 10px;
">
Hi,
</p>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
line-height: 165%;
margin-top: 0;
margin-bottom: 10px;
">
We've run your agent {{ data.name }} and it took {{ data.execution_time }} seconds to complete.
</p>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
line-height: 165%;
margin-top: 0;
margin-bottom: 10px;
">
It ran on {{ data.node_count }} nodes and used {{ data.credits_used }} credits.
</p>
<ul style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
line-height: 165%;
margin-top: 0;
margin-bottom: 10px;
">
It output the following:
{# jinja2 list iteration thorugh data.outputs #}
{% for key, value in data.outputs.items() %}
<li>{{ key }}: {{ value }}</li>
{% endfor %}
</ul>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
line-height: 165%;
margin-top: 0;
margin-bottom: 10px;
">
Your feedback has been instrumental in shaping AutoGPT, and we couldn't have
done it without you. We look forward to continuing this journey together as we
bring AI-powered automation to the world.
</p>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
line-height: 165%;
margin-top: 0;
margin-bottom: 0;
">
Thank you again for your time and support.
</p>

View File

@@ -0,0 +1,352 @@
{# Base Template #}
{# Template variables:
data.message: the message to display in the email
data.title: the title of the email
data.unsubscribe_link: the link to unsubscribe from the email
#}
<!doctype html>
<html lang="ltr" xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office">
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=yes">
<meta name="format-detection" content="telephone=no, date=no, address=no, email=no, url=no">
<meta name="x-apple-disable-message-reformatting">
<!--[if !mso]>
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<![endif]-->
<!--[if mso]>
<style>
* { font-family: sans-serif !important; }
</style>
<noscript>
<xml>
<o:OfficeDocumentSettings>
<o:PixelsPerInch>96</o:PixelsPerInch>
</o:OfficeDocumentSettings>
</xml>
</noscript>
<![endif]-->
<style type="text/css">
/* RESET STYLES */
html,
body {
margin: 0 !important;
padding: 0 !important;
width: 100% !important;
height: 100% !important;
}
body {
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
text-rendering: optimizeLegibility;
}
.document {
margin: 0 !important;
padding: 0 !important;
width: 100% !important;
}
img {
border: 0;
outline: none;
text-decoration: none;
-ms-interpolation-mode: bicubic;
}
table {
border-collapse: collapse;
}
table,
td {
mso-table-lspace: 0pt;
mso-table-rspace: 0pt;
}
body,
table,
td,
a {
-webkit-text-size-adjust: 100%;
-ms-text-size-adjust: 100%;
}
h1,
h2,
h3,
h4,
h5,
p {
margin: 0;
word-break: break-word;
}
/* iOS BLUE LINKS */
a[x-apple-data-detectors] {
color: inherit !important;
text-decoration: none !important;
font-size: inherit !important;
font-family: inherit !important;
font-weight: inherit !important;
line-height: inherit !important;
}
/* ANDROID CENTER FIX */
div[style*="margin: 16px 0;"] {
margin: 0 !important;
}
/* MEDIA QUERIES */
@media all and (max-width:639px) {
.wrapper {
width: 100% !important;
}
.container {
width: 100% !important;
min-width: 100% !important;
padding: 0 !important;
}
.row {
padding-left: 20px !important;
padding-right: 20px !important;
}
.col-mobile {
width: 20px !important;
}
.col {
display: block !important;
width: 100% !important;
}
.mobile-center {
text-align: center !important;
float: none !important;
}
.mobile-mx-auto {
margin: 0 auto !important;
float: none !important;
}
.mobile-left {
text-align: center !important;
float: left !important;
}
.mobile-hide {
display: none !important;
}
.img {
width: 100% !important;
height: auto !important;
}
.ml-btn {
width: 100% !important;
max-width: 100% !important;
}
.ml-btn-container {
width: 100% !important;
max-width: 100% !important;
}
}
</style>
<style type="text/css">
@import url("https://assets.mlcdn.com/fonts-v2.css?version=1729862");
</style>
<style type="text/css">
@media screen {
body {
font-family: 'Poppins', sans-serif;
}
}
</style>
<title>{{data.title}}</title>
</head>
<body style="margin: 0 !important; padding: 0 !important; background-color:#070629;">
<div class="document" role="article" aria-roledescription="email" aria-label lang dir="ltr"
style="background-color:#070629; line-height: 100%; font-size:medium; font-size:max(16px, 1rem);">
<!-- Main Content -->
<table width="100%" align="center" cellspacing="0" cellpadding="0" border="0">
<tr>
<td class="background" bgcolor="#070629" align="center" valign="top" style="padding: 0 8px;">
<!-- Email Content -->
<table class="container" align="center" width="640" cellpadding="0" cellspacing="0" border="0"
style="max-width: 640px;">
<tr>
<td align="center">
<!-- Logo Section -->
<table class="container ml-4 ml-default-border" width="640" bgcolor="#E2ECFD" align="center" border="0"
cellspacing="0" cellpadding="0" style="width: 640px; min-width: 640px;">
<tr>
<td class="ml-default-border container" height="40" style="line-height: 40px; min-width: 640px;">
</td>
</tr>
<tr>
<td>
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
<tr>
<td class="row" align="center" style="padding: 0 50px;">
<img
src="https://storage.mlcdn.com/account_image/597379/8QJ8kOjXakVvfe1kJLY2wWCObU1mp5EiDLfBlbQa.png"
border="0" alt="" width="120" class="logo"
style="max-width: 120px; display: inline-block;">
</td>
</tr>
</table>
</td>
</tr>
</table>
<!-- Main Content Section -->
<table class="container ml-6 ml-default-border" width="640" bgcolor="#E2ECFD" align="center" border="0"
cellspacing="0" cellpadding="0" style="color: #070629; width: 640px; min-width: 640px;">
<tr>
<td class="row" style="padding: 0 50px;">
{{data.message|safe}}
</td>
</tr>
</table>
<!-- Signature Section -->
<table class="container ml-8 ml-default-border" width="640" bgcolor="#E2ECFD" align="center" border="0"
cellspacing="0" cellpadding="0" style="color: #070629; width: 640px; min-width: 640px;">
<tr>
<td class="row mobile-center" align="left" style="padding: 0 50px;">
<table class="ml-8 wrapper" border="0" cellspacing="0" cellpadding="0"
style="color: #070629; text-align: left;">
<tr>
<td class="col mobile-center" align="center" width="80">
<img
src="https://storage.mlcdn.com/account_image/597379/68W8w94Zwl52yQyrKdFERRquu2CivAcn17ST22HF.jpg"
border="0" alt="" width="80" class="avatar"
style="display: inline-block; max-width: 80px; border-radius: 80px;">
</td>
<td class="col" width="30" height="30" style="line-height: 30px;"></td>
<td class="col center mobile-center" align>
<p
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-top: 0; margin-bottom: 0;">
John Ababseh<br>Product Manager<br>
<a href="mailto:john.ababseh@agpt.co" target="_blank"
style="color: #4285F4; font-weight: normal; font-style: normal; text-decoration: underline;">john.ababseh@agpt.co</a>
</p>
</td>
</tr>
</table>
</td>
</tr>
</table>
<!-- Footer Section -->
<table class="container ml-10 ml-default-border" width="640" bgcolor="#ffffff" align="center" border="0"
cellspacing="0" cellpadding="0" style="width: 640px; min-width: 640px;">
<tr>
<td class="row" style="padding: 0 50px;">
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
<tr>
<td>
<!-- Footer Content -->
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
<tr>
<td class="col" align="left" valign="middle" width="120">
<img
src="https://storage.mlcdn.com/account_image/597379/8QJ8kOjXakVvfe1kJLY2wWCObU1mp5EiDLfBlbQa.png"
border="0" alt="" width="120" class="logo"
style="max-width: 120px; display: inline-block;">
</td>
<td class="col" width="40" height="30" style="line-height: 30px;"></td>
<td class="col mobile-left" align="right" valign="middle" width="250">
<table role="presentation" cellpadding="0" cellspacing="0" border="0">
<tr>
<td align="center" valign="middle" width="18" style="padding: 0 5px 0 0;">
<a href="https://x.com/auto_gpt" target="blank" style="text-decoration: none;">
<img
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/black/x.png"
width="18" alt="x">
</a>
</td>
<td align="center" valign="middle" width="18" style="padding: 0 5px;">
<a href="https://discord.gg/autogpt" target="blank"
style="text-decoration: none;">
<img
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/black/discord.png"
width="18" alt="discord">
</a>
</td>
<td align="center" valign="middle" width="18" style="padding: 0 0 0 5px;">
<a href="https://agpt.co/" target="blank" style="text-decoration: none;">
<img
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/black/website.png"
width="18" alt="website">
</a>
</td>
</tr>
</table>
</td>
</tr>
</table>
</td>
</tr>
<tr>
<td align="center" style="text-align: left!important;">
<h5
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 15px; line-height: 125%; font-weight: bold; font-style: normal; text-decoration: none; margin-bottom: 6px;">
AutoGPT
</h5>
</td>
</tr>
<tr>
<td align="center" style="text-align: left!important;">
<p
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 14px; line-height: 150%; display: inline-block; margin-bottom: 0;">
3rd Floor 1 Ashley Road, Cheshire, United Kingdom, WA14 2DT, Altrincham<br>United Kingdom
</p>
</td>
</tr>
<tr>
<td height="8" style="line-height: 8px;"></td>
</tr>
<tr>
<td align="left" style="text-align: left!important;">
<p
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 14px; line-height: 150%; display: inline-block; margin-bottom: 0;">
You received this email because you signed up on our website.</p>
</td>
</tr>
<tr>
<td height="1" style="line-height: 12px;"></td>
</tr>
<tr>
<td align="left">
<p
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 14px; line-height: 150%; display: inline-block; margin-bottom: 0;">
<a href="{{data.unsubscribe_link}}"
style="color: #4285F4; font-weight: normal; font-style: normal; text-decoration: underline;">Unsubscribe</a>
</p>
</td>
</tr>
</table>
</td>
</tr>
</table>
</td>
</tr>
</table>
</td>
</tr>
</table>
</div>
</body>
</html>

View File

@@ -1,9 +1,9 @@
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Sequence
from typing import Annotated, Any, Dict, List, Optional, Sequence
from autogpt_libs.utils.cache import thread_cached
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Body, Depends, HTTPException
from prisma.enums import AgentExecutionStatus, APIKeyPermission
from typing_extensions import TypedDict
@@ -101,7 +101,7 @@ def execute_graph_block(
def execute_graph(
graph_id: str,
graph_version: int,
node_input: dict[Any, Any],
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
api_key: APIKey = Depends(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
) -> dict[str, Any]:
try:
@@ -113,7 +113,7 @@ def execute_graph(
)
return {"id": graph_exec.graph_exec_id}
except Exception as e:
msg = e.__str__().encode().decode("unicode_escape")
msg = str(e).encode().decode("unicode_escape")
raise HTTPException(status_code=400, detail=msg)

View File

@@ -33,9 +33,7 @@ class ExecuteGraphResponse(pydantic.BaseModel):
class CreateGraph(pydantic.BaseModel):
template_id: str | None = None
template_version: int | None = None
graph: backend.data.graph.Graph | None = None
graph: backend.data.graph.Graph
class CreateAPIKeyRequest(pydantic.BaseModel):
@@ -57,5 +55,20 @@ class UpdatePermissionsRequest(pydantic.BaseModel):
permissions: List[APIKeyPermission]
class Pagination(pydantic.BaseModel):
total_items: int = pydantic.Field(
description="Total number of items.", examples=[42]
)
total_pages: int = pydantic.Field(
description="Total number of pages.", examples=[2]
)
current_page: int = pydantic.Field(
description="Current_page page number.", examples=[1]
)
page_size: int = pydantic.Field(
description="Number of items per page.", examples=[25]
)
class RequestTopUp(pydantic.BaseModel):
credit_amount: int

View File

@@ -17,6 +17,8 @@ import backend.data.db
import backend.data.graph
import backend.data.user
import backend.server.routers.v1
import backend.server.v2.library.db
import backend.server.v2.library.model
import backend.server.v2.library.routes
import backend.server.v2.store.model
import backend.server.v2.store.routes
@@ -123,15 +125,15 @@ class AgentServer(backend.util.service.AppProcess):
@staticmethod
async def test_execute_graph(
graph_id: str,
node_input: dict[str, Any],
user_id: str,
graph_version: Optional[int] = None,
node_input: Optional[dict[str, Any]] = None,
):
return backend.server.routers.v1.execute_graph(
user_id=user_id,
graph_id=graph_id,
graph_version=graph_version,
node_input=node_input,
node_input=node_input or {},
)
@staticmethod
@@ -170,8 +172,64 @@ class AgentServer(backend.util.service.AppProcess):
@staticmethod
async def test_delete_graph(graph_id: str, user_id: str):
await backend.server.v2.library.db.delete_library_agent_by_graph_id(
graph_id=graph_id, user_id=user_id
)
return await backend.server.routers.v1.delete_graph(graph_id, user_id)
@staticmethod
async def test_get_presets(user_id: str, page: int = 1, page_size: int = 10):
return await backend.server.v2.library.routes.presets.get_presets(
user_id=user_id, page=page, page_size=page_size
)
@staticmethod
async def test_get_preset(preset_id: str, user_id: str):
return await backend.server.v2.library.routes.presets.get_preset(
preset_id=preset_id, user_id=user_id
)
@staticmethod
async def test_create_preset(
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
user_id: str,
):
return await backend.server.v2.library.routes.presets.create_preset(
preset=preset, user_id=user_id
)
@staticmethod
async def test_update_preset(
preset_id: str,
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
user_id: str,
):
return await backend.server.v2.library.routes.presets.update_preset(
preset_id=preset_id, preset=preset, user_id=user_id
)
@staticmethod
async def test_delete_preset(preset_id: str, user_id: str):
return await backend.server.v2.library.routes.presets.delete_preset(
preset_id=preset_id, user_id=user_id
)
@staticmethod
async def test_execute_preset(
graph_id: str,
graph_version: int,
preset_id: str,
user_id: str,
node_input: Optional[dict[str, Any]] = None,
):
return await backend.server.v2.library.routes.presets.execute_preset(
graph_id=graph_id,
graph_version=graph_version,
preset_id=preset_id,
node_input=node_input or {},
user_id=user_id,
)
@staticmethod
async def test_create_store_listing(
request: backend.server.v2.store.model.StoreSubmissionRequest, user_id: str

View File

@@ -9,12 +9,13 @@ import stripe
from autogpt_libs.auth.middleware import auth_middleware
from autogpt_libs.feature_flag.client import feature_flag
from autogpt_libs.utils.cache import thread_cached
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response
from typing_extensions import Optional, TypedDict
import backend.data.block
import backend.server.integrations.router
import backend.server.routers.analytics
import backend.server.v2.library.db as library_db
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.api_key import (
@@ -310,11 +311,6 @@ async def get_graph(
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
@v1_router.get(
path="/templates/{graph_id}/versions",
tags=["templates", "graphs"],
dependencies=[Depends(auth_middleware)],
)
async def get_graph_all_versions(
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> Sequence[graph_db.GraphModel]:
@@ -330,41 +326,18 @@ async def get_graph_all_versions(
async def create_new_graph(
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.GraphModel:
return await do_create_graph(create_graph, is_template=False, user_id=user_id)
async def do_create_graph(
create_graph: CreateGraph,
is_template: bool,
# user_id doesn't have to be annotated like on other endpoints,
# because create_graph isn't used directly as an endpoint
user_id: str,
) -> graph_db.GraphModel:
if create_graph.graph:
graph = graph_db.make_graph_model(create_graph.graph, user_id)
elif create_graph.template_id:
# Create a new graph from a template
graph = await graph_db.get_graph(
create_graph.template_id,
create_graph.template_version,
template=True,
user_id=user_id,
)
if not graph:
raise HTTPException(
400, detail=f"Template #{create_graph.template_id} not found"
)
graph.version = 1
else:
raise HTTPException(
status_code=400, detail="Either graph or template_id must be provided."
)
graph.is_template = is_template
graph.is_active = not is_template
graph = graph_db.make_graph_model(create_graph.graph, user_id)
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
graph = await graph_db.create_graph(graph, user_id=user_id)
# Create a library agent for the new graph
await library_db.create_library_agent(
graph.id,
graph.version,
user_id,
)
graph = await on_graph_activate(
graph,
get_credentials=lambda id: integration_creds_manager.get(user_id, id),
@@ -391,11 +364,6 @@ async def delete_graph(
@v1_router.put(
path="/graphs/{graph_id}", tags=["graphs"], dependencies=[Depends(auth_middleware)]
)
@v1_router.put(
path="/templates/{graph_id}",
tags=["templates", "graphs"],
dependencies=[Depends(auth_middleware)],
)
async def update_graph(
graph_id: str,
graph: graph_db.Graph,
@@ -427,6 +395,10 @@ async def update_graph(
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
if new_graph_version.is_active:
# Keep the library agent up to date with the new active version
await library_db.update_agent_version_in_library(
user_id, graph.id, graph.version
)
def get_credentials(credentials_id: str) -> "Credentials | None":
return integration_creds_manager.get(user_id, credentials_id)
@@ -483,6 +455,12 @@ async def set_graph_active_version(
version=new_active_version,
user_id=user_id,
)
# Keep the library agent up to date with the new active version
await library_db.update_agent_version_in_library(
user_id, new_active_graph.id, new_active_graph.version
)
if current_active_graph and current_active_graph.version != new_active_version:
# Handle deactivation of the previously active version
await on_graph_deactivate(
@@ -498,7 +476,7 @@ async def set_graph_active_version(
)
def execute_graph(
graph_id: str,
node_input: dict[Any, Any],
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
user_id: Annotated[str, Depends(get_user_id)],
graph_version: Optional[int] = None,
) -> ExecuteGraphResponse:
@@ -508,7 +486,7 @@ def execute_graph(
)
return ExecuteGraphResponse(graph_exec_id=graph_exec.graph_exec_id)
except Exception as e:
msg = e.__str__().encode().decode("unicode_escape")
msg = str(e).encode().decode("unicode_escape")
raise HTTPException(status_code=400, detail=msg)
@@ -559,47 +537,6 @@ async def get_graph_run_node_execution_results(
return await execution_db.get_execution_results(graph_exec_id)
########################################################
##################### Templates ########################
########################################################
@v1_router.get(
path="/templates",
tags=["graphs", "templates"],
dependencies=[Depends(auth_middleware)],
)
async def get_templates(
user_id: Annotated[str, Depends(get_user_id)]
) -> Sequence[graph_db.GraphModel]:
return await graph_db.get_graphs(filter_by="template", user_id=user_id)
@v1_router.get(
path="/templates/{graph_id}",
tags=["templates", "graphs"],
dependencies=[Depends(auth_middleware)],
)
async def get_template(
graph_id: str, version: int | None = None
) -> graph_db.GraphModel:
graph = await graph_db.get_graph(graph_id, version, template=True)
if not graph:
raise HTTPException(status_code=404, detail=f"Template #{graph_id} not found.")
return graph
@v1_router.post(
path="/templates",
tags=["templates", "graphs"],
dependencies=[Depends(auth_middleware)],
)
async def create_new_template(
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.GraphModel:
return await do_create_graph(create_graph, is_template=True, user_id=user_id)
########################################################
##################### Schedules ########################
########################################################

View File

@@ -1,103 +1,166 @@
import logging
from typing import List
import prisma.errors
import prisma.fields
import prisma.models
import prisma.types
import backend.data.graph
import backend.data.includes
import backend.server.v2.library.model
import backend.server.v2.store.exceptions
import backend.server.model
import backend.server.v2.library.model as library_model
import backend.server.v2.store.exceptions as store_exceptions
logger = logging.getLogger(__name__)
async def get_library_agents(
user_id: str,
) -> List[backend.server.v2.library.model.LibraryAgent]:
"""
Returns all agents (AgentGraph) that belong to the user and all agents in their library (UserAgent table)
"""
logger.debug(f"Getting library agents for user {user_id}")
user_id: str, search_query: str | None = None
) -> list[library_model.LibraryAgent]:
logger.debug(
f"Fetching library agents for user_id={user_id} search_query={search_query}"
)
try:
# Get agents created by user with nodes and links
user_created = await prisma.models.AgentGraph.prisma().find_many(
where=prisma.types.AgentGraphWhereInput(userId=user_id, isActive=True),
include=backend.data.includes.AGENT_GRAPH_INCLUDE,
)
if search_query and len(search_query.strip()) > 100:
logger.warning(f"Search query too long: {search_query}")
raise store_exceptions.DatabaseError("Search query is too long.")
# Get agents in user's library with nodes and links
library_agents = await prisma.models.UserAgent.prisma().find_many(
where=prisma.types.UserAgentWhereInput(
userId=user_id, isDeleted=False, isArchived=False
),
include={
where_clause: prisma.types.LibraryAgentWhereInput = {
"userId": user_id,
"isDeleted": False,
"isArchived": False,
}
if search_query:
where_clause["OR"] = [
{
"Agent": {
"include": {
"AgentNodes": {
"include": {
"Input": True,
"Output": True,
"Webhook": True,
"AgentBlock": True,
}
}
"is": {"name": {"contains": search_query, "mode": "insensitive"}}
}
},
{
"Agent": {
"is": {
"description": {"contains": search_query, "mode": "insensitive"}
}
}
},
]
try:
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
where=where_clause,
include={
"Agent": {
"include": {
"AgentNodes": {"include": {"Input": True, "Output": True}}
}
}
},
order=[{"updatedAt": "desc"}],
)
# Convert to Graph models first
graphs = []
# Add user created agents
for agent in user_created:
try:
graphs.append(backend.data.graph.GraphModel.from_db(agent))
except Exception as e:
logger.error(f"Error processing user created agent {agent.id}: {e}")
continue
# Add library agents
for agent in library_agents:
if agent.Agent:
try:
graphs.append(backend.data.graph.GraphModel.from_db(agent.Agent))
except Exception as e:
logger.error(f"Error processing library agent {agent.agentId}: {e}")
continue
# Convert Graph models to LibraryAgent models
result = []
for graph in graphs:
result.append(
backend.server.v2.library.model.LibraryAgent(
id=graph.id,
version=graph.version,
is_active=graph.is_active,
name=graph.name,
description=graph.description,
isCreatedByUser=any(a.id == graph.id for a in user_created),
input_schema=graph.input_schema,
output_schema=graph.output_schema,
)
)
logger.debug(f"Found {len(result)} library agents")
return result
logger.debug(f"Retrieved {len(library_agents)} agents for user_id={user_id}.")
return [library_model.LibraryAgent.from_db(agent) for agent in library_agents]
except prisma.errors.PrismaError as e:
logger.error(f"Database error getting library agents: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch library agents"
logger.error(f"Database error fetching library agents: {e}")
raise store_exceptions.DatabaseError("Unable to fetch library agents.")
async def create_library_agent(
agent_id: str, agent_version: int, user_id: str
) -> prisma.models.LibraryAgent:
"""
Adds an agent to the user's library (LibraryAgent table)
"""
try:
return await prisma.models.LibraryAgent.prisma().create(
data={
"userId": user_id,
"agentId": agent_id,
"agentVersion": agent_version,
"isCreatedByUser": False,
"useGraphIsActiveVersion": True,
}
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error creating agent to library: {str(e)}")
raise store_exceptions.DatabaseError("Failed to create agent to library") from e
async def update_agent_version_in_library(
user_id: str, agent_id: str, agent_version: int
) -> None:
"""
Updates the agent version in the library
"""
try:
library_agent = await prisma.models.LibraryAgent.prisma().find_first_or_raise(
where={
"userId": user_id,
"agentId": agent_id,
"useGraphIsActiveVersion": True,
},
)
await prisma.models.LibraryAgent.prisma().update(
where={"id": library_agent.id},
data={
"Agent": {
"connect": {
"graphVersionId": {"id": agent_id, "version": agent_version}
},
},
},
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error updating agent version in library: {str(e)}")
raise store_exceptions.DatabaseError(
"Failed to update agent version in library"
) from e
async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> None:
async def update_library_agent(
library_agent_id: str,
user_id: str,
auto_update_version: bool = False,
is_favorite: bool = False,
is_archived: bool = False,
is_deleted: bool = False,
) -> None:
"""
Finds the agent from the store listing version and adds it to the user's library (UserAgent table)
Updates the library agent with the given fields
"""
try:
await prisma.models.LibraryAgent.prisma().update_many(
where={"id": library_agent_id, "userId": user_id},
data={
"useGraphIsActiveVersion": auto_update_version,
"isFavorite": is_favorite,
"isArchived": is_archived,
"isDeleted": is_deleted,
},
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error updating library agent: {str(e)}")
raise store_exceptions.DatabaseError("Failed to update library agent") from e
async def delete_library_agent_by_graph_id(graph_id: str, user_id: str) -> None:
"""
Deletes a library agent for the given user
"""
try:
await prisma.models.LibraryAgent.prisma().delete_many(
where={"agentId": graph_id, "userId": user_id}
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error deleting library agent: {str(e)}")
raise store_exceptions.DatabaseError("Failed to delete library agent") from e
async def add_store_agent_to_library(
store_listing_version_id: str, user_id: str
) -> None:
"""
Finds the agent from the store listing version and adds it to the user's library (LibraryAgent table)
if they don't already have it
"""
logger.debug(
@@ -116,7 +179,7 @@ async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> N
logger.warning(
f"Store listing version not found: {store_listing_version_id}"
)
raise backend.server.v2.store.exceptions.AgentNotFoundError(
raise store_exceptions.AgentNotFoundError(
f"Store listing version {store_listing_version_id} not found"
)
@@ -126,12 +189,10 @@ async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> N
logger.warning(
f"User {user_id} cannot add their own agent to their library"
)
raise backend.server.v2.store.exceptions.DatabaseError(
"Cannot add own agent to library"
)
raise store_exceptions.DatabaseError("Cannot add own agent to library")
# Check if user already has this agent
existing_user_agent = await prisma.models.UserAgent.prisma().find_first(
existing_user_agent = await prisma.models.LibraryAgent.prisma().find_first(
where={
"userId": user_id,
"agentId": agent.id,
@@ -145,21 +206,134 @@ async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> N
)
return
# Create UserAgent entry
await prisma.models.UserAgent.prisma().create(
data=prisma.types.UserAgentCreateInput(
userId=user_id,
agentId=agent.id,
agentVersion=agent.version,
isCreatedByUser=False,
)
# Create LibraryAgent entry
await prisma.models.LibraryAgent.prisma().create(
data={
"userId": user_id,
"agentId": agent.id,
"agentVersion": agent.version,
"isCreatedByUser": False,
}
)
logger.debug(f"Added agent {agent.id} to library for user {user_id}")
except backend.server.v2.store.exceptions.AgentNotFoundError:
except store_exceptions.AgentNotFoundError:
raise
except prisma.errors.PrismaError as e:
logger.error(f"Database error adding agent to library: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to add agent to library"
) from e
raise store_exceptions.DatabaseError("Failed to add agent to library") from e
##############################################
########### Presets DB Functions #############
##############################################
async def get_presets(
user_id: str, page: int, page_size: int
) -> library_model.LibraryAgentPresetResponse:
try:
presets = await prisma.models.AgentPreset.prisma().find_many(
where={"userId": user_id},
skip=page * page_size,
take=page_size,
)
total_items = await prisma.models.AgentPreset.prisma().count(
where={"userId": user_id},
)
total_pages = (total_items + page_size - 1) // page_size
presets = [
library_model.LibraryAgentPreset.from_db(preset) for preset in presets
]
return library_model.LibraryAgentPresetResponse(
presets=presets,
pagination=backend.server.model.Pagination(
total_items=total_items,
total_pages=total_pages,
current_page=page,
page_size=page_size,
),
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error getting presets: {str(e)}")
raise store_exceptions.DatabaseError("Failed to fetch presets") from e
async def get_preset(
user_id: str, preset_id: str
) -> library_model.LibraryAgentPreset | None:
try:
preset = await prisma.models.AgentPreset.prisma().find_unique(
where={"id": preset_id}, include={"InputPresets": True}
)
if not preset or preset.userId != user_id:
return None
return library_model.LibraryAgentPreset.from_db(preset)
except prisma.errors.PrismaError as e:
logger.error(f"Database error getting preset: {str(e)}")
raise store_exceptions.DatabaseError("Failed to fetch preset") from e
async def upsert_preset(
user_id: str,
preset: library_model.CreateLibraryAgentPresetRequest,
preset_id: str | None = None,
) -> library_model.LibraryAgentPreset:
try:
if preset_id:
# Update existing preset
new_preset = await prisma.models.AgentPreset.prisma().update(
where={"id": preset_id},
data={
"name": preset.name,
"description": preset.description,
"isActive": preset.is_active,
"InputPresets": {
"create": [
{"name": name, "data": prisma.fields.Json(data)}
for name, data in preset.inputs.items()
]
},
},
include={"InputPresets": True},
)
if not new_preset:
raise ValueError(f"AgentPreset #{preset_id} not found")
else:
# Create new preset
new_preset = await prisma.models.AgentPreset.prisma().create(
data={
"userId": user_id,
"name": preset.name,
"description": preset.description,
"agentId": preset.agent_id,
"agentVersion": preset.agent_version,
"isActive": preset.is_active,
"InputPresets": {
"create": [
{"name": name, "data": prisma.fields.Json(data)}
for name, data in preset.inputs.items()
]
},
},
include={"InputPresets": True},
)
return library_model.LibraryAgentPreset.from_db(new_preset)
except prisma.errors.PrismaError as e:
logger.error(f"Database error creating preset: {str(e)}")
raise store_exceptions.DatabaseError("Failed to create preset") from e
async def delete_preset(user_id: str, preset_id: str) -> None:
try:
await prisma.models.AgentPreset.prisma().update_many(
where={"id": preset_id, "userId": user_id},
data={"isDeleted": True},
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error deleting preset: {str(e)}")
raise store_exceptions.DatabaseError("Failed to delete preset") from e

View File

@@ -37,7 +37,7 @@ async def test_get_library_agents(mocker):
]
mock_library_agents = [
prisma.models.UserAgent(
prisma.models.LibraryAgent(
id="ua1",
userId="test-user",
agentId="agent2",
@@ -48,6 +48,7 @@ async def test_get_library_agents(mocker):
createdAt=datetime.now(),
updatedAt=datetime.now(),
isFavorite=False,
useGraphIsActiveVersion=True,
Agent=prisma.models.AgentGraph(
id="agent2",
version=1,
@@ -67,8 +68,8 @@ async def test_get_library_agents(mocker):
return_value=mock_user_created
)
mock_user_agent = mocker.patch("prisma.models.UserAgent.prisma")
mock_user_agent.return_value.find_many = mocker.AsyncMock(
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_many = mocker.AsyncMock(
return_value=mock_library_agents
)
@@ -76,40 +77,16 @@ async def test_get_library_agents(mocker):
result = await db.get_library_agents("test-user")
# Verify results
assert len(result) == 2
assert result[0].id == "agent1"
assert result[0].name == "Test Agent 1"
assert result[0].description == "Test Description 1"
assert result[0].isCreatedByUser is True
assert result[1].id == "agent2"
assert result[1].name == "Test Agent 2"
assert result[1].description == "Test Description 2"
assert result[1].isCreatedByUser is False
# Verify mocks called correctly
mock_agent_graph.return_value.find_many.assert_called_once_with(
where=prisma.types.AgentGraphWhereInput(userId="test-user", isActive=True),
include=backend.data.includes.AGENT_GRAPH_INCLUDE,
)
mock_user_agent.return_value.find_many.assert_called_once_with(
where=prisma.types.UserAgentWhereInput(
userId="test-user", isDeleted=False, isArchived=False
),
include={
"Agent": {
"include": {
"AgentNodes": {
"include": {
"Input": True,
"Output": True,
"Webhook": True,
"AgentBlock": True,
}
}
}
}
},
)
assert len(result) == 1
assert result[0].id == "ua1"
assert result[0].name == "Test Agent 2"
assert result[0].description == "Test Description 2"
assert result[0].is_created_by_user is False
assert result[0].is_latest_version is True
assert result[0].is_favorite is False
assert result[0].agent_id == "agent2"
assert result[0].agent_version == 1
assert result[0].preset_id is None
@pytest.mark.asyncio
@@ -152,26 +129,26 @@ async def test_add_agent_to_library(mocker):
return_value=mock_store_listing
)
mock_user_agent = mocker.patch("prisma.models.UserAgent.prisma")
mock_user_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_user_agent.return_value.create = mocker.AsyncMock()
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.create = mocker.AsyncMock()
# Call function
await db.add_agent_to_library("version123", "test-user")
await db.add_store_agent_to_library("version123", "test-user")
# Verify mocks called correctly
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
where={"id": "version123"}, include={"Agent": True}
)
mock_user_agent.return_value.find_first.assert_called_once_with(
mock_library_agent.return_value.find_first.assert_called_once_with(
where={
"userId": "test-user",
"agentId": "agent1",
"agentVersion": 1,
}
)
mock_user_agent.return_value.create.assert_called_once_with(
data=prisma.types.UserAgentCreateInput(
mock_library_agent.return_value.create.assert_called_once_with(
data=prisma.types.LibraryAgentCreateInput(
userId="test-user", agentId="agent1", agentVersion=1, isCreatedByUser=False
)
)
@@ -189,7 +166,7 @@ async def test_add_agent_to_library_not_found(mocker):
# Call function and verify exception
with pytest.raises(backend.server.v2.store.exceptions.AgentNotFoundError):
await db.add_agent_to_library("version123", "test-user")
await db.add_store_agent_to_library("version123", "test-user")
# Verify mock called correctly
mock_store_listing_version.return_value.find_unique.assert_called_once_with(

View File

@@ -1,16 +1,111 @@
import typing
import datetime
from typing import Any
import prisma.models
import pydantic
import backend.data.block as block_model
import backend.data.graph as graph_model
import backend.server.model as server_model
class LibraryAgent(pydantic.BaseModel):
id: str # Changed from agent_id to match GraphMeta
version: int # Changed from agent_version to match GraphMeta
is_active: bool # Added to match GraphMeta
agent_id: str
agent_version: int # Changed from agent_version to match GraphMeta
preset_id: str | None
updated_at: datetime.datetime
name: str
description: str
isCreatedByUser: bool
# Made input_schema and output_schema match GraphMeta's type
input_schema: dict[str, typing.Any] # Should be BlockIOObjectSubSchema in frontend
output_schema: dict[str, typing.Any] # Should be BlockIOObjectSubSchema in frontend
input_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend
output_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend
is_favorite: bool
is_created_by_user: bool
is_latest_version: bool
@staticmethod
def from_db(agent: prisma.models.LibraryAgent):
if not agent.Agent:
raise ValueError("AgentGraph is required")
graph = graph_model.GraphModel.from_db(agent.Agent)
agent_updated_at = agent.Agent.updatedAt
lib_agent_updated_at = agent.updatedAt
# Take the latest updated_at timestamp either when the graph was updated or the library agent was updated
updated_at = (
max(agent_updated_at, lib_agent_updated_at)
if agent_updated_at
else lib_agent_updated_at
)
return LibraryAgent(
id=agent.id,
agent_id=agent.agentId,
agent_version=agent.agentVersion,
updated_at=updated_at,
name=graph.name,
description=graph.description,
input_schema=graph.input_schema,
output_schema=graph.output_schema,
is_favorite=agent.isFavorite,
is_created_by_user=agent.isCreatedByUser,
is_latest_version=graph.is_active,
preset_id=agent.AgentPreset.id if agent.AgentPreset else None,
)
class LibraryAgentPreset(pydantic.BaseModel):
id: str
updated_at: datetime.datetime
agent_id: str
agent_version: int
name: str
description: str
is_active: bool
inputs: block_model.BlockInput
@staticmethod
def from_db(preset: prisma.models.AgentPreset):
input_data: block_model.BlockInput = {}
for preset_input in preset.InputPresets or []:
input_data[preset_input.name] = preset_input.data
return LibraryAgentPreset(
id=preset.id,
updated_at=preset.updatedAt,
agent_id=preset.agentId,
agent_version=preset.agentVersion,
name=preset.name,
description=preset.description,
is_active=preset.isActive,
inputs=input_data,
)
class LibraryAgentPresetResponse(pydantic.BaseModel):
presets: list[LibraryAgentPreset]
pagination: server_model.Pagination
class CreateLibraryAgentPresetRequest(pydantic.BaseModel):
name: str
description: str
inputs: block_model.BlockInput
agent_id: str
agent_version: int
is_active: bool

View File

@@ -1,23 +1,36 @@
import datetime
import prisma.fields
import prisma.models
import backend.data.block
import backend.server.model
import backend.server.v2.library.model
def test_library_agent():
agent = backend.server.v2.library.model.LibraryAgent(
id="test-agent-123",
version=1,
is_active=True,
agent_id="agent-123",
agent_version=1,
preset_id=None,
updated_at=datetime.datetime.now(),
name="Test Agent",
description="Test description",
isCreatedByUser=False,
input_schema={"type": "object", "properties": {}},
output_schema={"type": "object", "properties": {}},
is_favorite=False,
is_created_by_user=False,
is_latest_version=True,
)
assert agent.id == "test-agent-123"
assert agent.version == 1
assert agent.is_active is True
assert agent.agent_id == "agent-123"
assert agent.agent_version == 1
assert agent.name == "Test Agent"
assert agent.description == "Test description"
assert agent.isCreatedByUser is False
assert agent.is_favorite is False
assert agent.is_created_by_user is False
assert agent.is_latest_version is True
assert agent.input_schema == {"type": "object", "properties": {}}
assert agent.output_schema == {"type": "object", "properties": {}}
@@ -25,19 +38,140 @@ def test_library_agent():
def test_library_agent_with_user_created():
agent = backend.server.v2.library.model.LibraryAgent(
id="user-agent-456",
version=2,
is_active=True,
agent_id="agent-456",
agent_version=2,
preset_id=None,
updated_at=datetime.datetime.now(),
name="User Created Agent",
description="An agent created by the user",
isCreatedByUser=True,
input_schema={"type": "object", "properties": {}},
output_schema={"type": "object", "properties": {}},
is_favorite=False,
is_created_by_user=True,
is_latest_version=True,
)
assert agent.id == "user-agent-456"
assert agent.version == 2
assert agent.is_active is True
assert agent.agent_id == "agent-456"
assert agent.agent_version == 2
assert agent.name == "User Created Agent"
assert agent.description == "An agent created by the user"
assert agent.isCreatedByUser is True
assert agent.is_favorite is False
assert agent.is_created_by_user is True
assert agent.is_latest_version is True
assert agent.input_schema == {"type": "object", "properties": {}}
assert agent.output_schema == {"type": "object", "properties": {}}
def test_library_agent_preset():
preset = backend.server.v2.library.model.LibraryAgentPreset(
id="preset-123",
name="Test Preset",
description="Test preset description",
agent_id="test-agent-123",
agent_version=1,
is_active=True,
inputs={
"dictionary": {"key1": "Hello", "key2": "World"},
"selected_value": "key2",
},
updated_at=datetime.datetime.now(),
)
assert preset.id == "preset-123"
assert preset.name == "Test Preset"
assert preset.description == "Test preset description"
assert preset.agent_id == "test-agent-123"
assert preset.agent_version == 1
assert preset.is_active is True
assert preset.inputs == {
"dictionary": {"key1": "Hello", "key2": "World"},
"selected_value": "key2",
}
def test_library_agent_preset_response():
preset = backend.server.v2.library.model.LibraryAgentPreset(
id="preset-123",
name="Test Preset",
description="Test preset description",
agent_id="test-agent-123",
agent_version=1,
is_active=True,
inputs={
"dictionary": {"key1": "Hello", "key2": "World"},
"selected_value": "key2",
},
updated_at=datetime.datetime.now(),
)
pagination = backend.server.model.Pagination(
total_items=1, total_pages=1, current_page=1, page_size=10
)
response = backend.server.v2.library.model.LibraryAgentPresetResponse(
presets=[preset], pagination=pagination
)
assert len(response.presets) == 1
assert response.presets[0].id == "preset-123"
assert response.pagination.total_items == 1
assert response.pagination.total_pages == 1
assert response.pagination.current_page == 1
assert response.pagination.page_size == 10
def test_create_library_agent_preset_request():
request = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
name="New Preset",
description="New preset description",
agent_id="agent-123",
agent_version=1,
is_active=True,
inputs={
"dictionary": {"key1": "Hello", "key2": "World"},
"selected_value": "key2",
},
)
assert request.name == "New Preset"
assert request.description == "New preset description"
assert request.agent_id == "agent-123"
assert request.agent_version == 1
assert request.is_active is True
assert request.inputs == {
"dictionary": {"key1": "Hello", "key2": "World"},
"selected_value": "key2",
}
def test_library_agent_from_db():
# Create mock DB agent
db_agent = prisma.models.AgentPreset(
id="test-agent-123",
createdAt=datetime.datetime.now(),
updatedAt=datetime.datetime.now(),
agentId="agent-123",
agentVersion=1,
name="Test Agent",
description="Test agent description",
isActive=True,
userId="test-user-123",
isDeleted=False,
InputPresets=[
prisma.models.AgentNodeExecutionInputOutput(
id="input-123",
time=datetime.datetime.now(),
name="input1",
data=prisma.fields.Json({"type": "string", "value": "test value"}),
)
],
)
# Convert to LibraryAgentPreset
agent = backend.server.v2.library.model.LibraryAgentPreset.from_db(db_agent)
assert agent.id == "test-agent-123"
assert agent.agent_version == 1
assert agent.is_active is True
assert agent.name == "Test Agent"
assert agent.description == "Test agent description"
assert agent.inputs == {"input1": {"type": "string", "value": "test value"}}

View File

@@ -1,123 +0,0 @@
import logging
import typing
import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import fastapi
import prisma
import backend.data.graph
import backend.integrations.creds_manager
import backend.integrations.webhooks.graph_lifecycle_hooks
import backend.server.v2.library.db
import backend.server.v2.library.model
logger = logging.getLogger(__name__)
router = fastapi.APIRouter()
integration_creds_manager = (
backend.integrations.creds_manager.IntegrationCredentialsManager()
)
@router.get(
"/agents",
tags=["library", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def get_library_agents(
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
]
) -> typing.Sequence[backend.server.v2.library.model.LibraryAgent]:
"""
Get all agents in the user's library, including both created and saved agents.
"""
try:
agents = await backend.server.v2.library.db.get_library_agents(user_id)
return agents
except Exception:
logger.exception("Exception occurred whilst getting library agents")
raise fastapi.HTTPException(
status_code=500, detail="Failed to get library agents"
)
@router.post(
"/agents/{store_listing_version_id}",
tags=["library", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
status_code=201,
)
async def add_agent_to_library(
store_listing_version_id: str,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
) -> fastapi.Response:
"""
Add an agent from the store to the user's library.
Args:
store_listing_version_id (str): ID of the store listing version to add
user_id (str): ID of the authenticated user
Returns:
fastapi.Response: 201 status code on success
Raises:
HTTPException: If there is an error adding the agent to the library
"""
try:
# Get the graph from the store listing
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id}, include={"Agent": True}
)
)
if not store_listing_version or not store_listing_version.Agent:
raise fastapi.HTTPException(
status_code=404,
detail=f"Store listing version {store_listing_version_id} not found",
)
agent = store_listing_version.Agent
if agent.userId == user_id:
raise fastapi.HTTPException(
status_code=400, detail="Cannot add own agent to library"
)
# Create a new graph from the template
graph = await backend.data.graph.get_graph(
agent.id, agent.version, user_id=user_id
)
if not graph:
raise fastapi.HTTPException(
status_code=404, detail=f"Agent {agent.id} not found"
)
# Create a deep copy with new IDs
graph.version = 1
graph.is_template = False
graph.is_active = True
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
# Save the new graph
graph = await backend.data.graph.create_graph(graph, user_id=user_id)
graph = (
await backend.integrations.webhooks.graph_lifecycle_hooks.on_graph_activate(
graph,
get_credentials=lambda id: integration_creds_manager.get(user_id, id),
)
)
return fastapi.Response(status_code=201)
except Exception:
logger.exception("Exception occurred whilst adding agent to library")
raise fastapi.HTTPException(
status_code=500, detail="Failed to add agent to library"
)

View File

@@ -0,0 +1,9 @@
import fastapi
from .agents import router as agents_router
from .presets import router as presets_router
router = fastapi.APIRouter()
router.include_router(presets_router)
router.include_router(agents_router)

View File

@@ -0,0 +1,138 @@
import logging
from typing import Annotated, Sequence
import autogpt_libs.auth as autogpt_auth_lib
import fastapi
import backend.server.v2.library.db as library_db
import backend.server.v2.library.model as library_model
import backend.server.v2.store.exceptions as store_exceptions
logger = logging.getLogger(__name__)
router = fastapi.APIRouter()
@router.get(
"/agents",
tags=["library", "private"],
dependencies=[fastapi.Depends(autogpt_auth_lib.auth_middleware)],
)
async def get_library_agents(
user_id: Annotated[str, fastapi.Depends(autogpt_auth_lib.depends.get_user_id)]
) -> Sequence[library_model.LibraryAgent]:
"""
Get all agents in the user's library, including both created and saved agents.
"""
try:
agents = await library_db.get_library_agents(user_id)
return agents
except Exception as e:
logger.exception(f"Exception occurred whilst getting library agents: {e}")
raise fastapi.HTTPException(
status_code=500, detail="Failed to get library agents"
)
@router.post(
"/agents/{store_listing_version_id}",
tags=["library", "private"],
dependencies=[fastapi.Depends(autogpt_auth_lib.auth_middleware)],
status_code=201,
)
async def add_agent_to_library(
store_listing_version_id: str,
user_id: Annotated[str, fastapi.Depends(autogpt_auth_lib.depends.get_user_id)],
) -> fastapi.Response:
"""
Add an agent from the store to the user's library.
Args:
store_listing_version_id (str): ID of the store listing version to add
user_id (str): ID of the authenticated user
Returns:
fastapi.Response: 201 status code on success
Raises:
HTTPException: If there is an error adding the agent to the library
"""
try:
# Use the database function to add the agent to the library
await library_db.add_store_agent_to_library(store_listing_version_id, user_id)
return fastapi.Response(status_code=201)
except store_exceptions.AgentNotFoundError:
raise fastapi.HTTPException(
status_code=404,
detail=f"Store listing version {store_listing_version_id} not found",
)
except store_exceptions.DatabaseError as e:
logger.exception(f"Database error occurred whilst adding agent to library: {e}")
raise fastapi.HTTPException(
status_code=500, detail="Failed to add agent to library"
)
except Exception as e:
logger.exception(
f"Unexpected exception occurred whilst adding agent to library: {e}"
)
raise fastapi.HTTPException(
status_code=500, detail="Failed to add agent to library"
)
@router.put(
"/agents/{library_agent_id}",
tags=["library", "private"],
dependencies=[fastapi.Depends(autogpt_auth_lib.auth_middleware)],
status_code=204,
)
async def update_library_agent(
library_agent_id: str,
user_id: Annotated[str, fastapi.Depends(autogpt_auth_lib.depends.get_user_id)],
auto_update_version: bool = False,
is_favorite: bool = False,
is_archived: bool = False,
is_deleted: bool = False,
) -> fastapi.Response:
"""
Update the library agent with the given fields.
Args:
library_agent_id (str): ID of the library agent to update
user_id (str): ID of the authenticated user
auto_update_version (bool): Whether to auto-update the agent version
is_favorite (bool): Whether the agent is marked as favorite
is_archived (bool): Whether the agent is archived
is_deleted (bool): Whether the agent is deleted
Returns:
fastapi.Response: 204 status code on success
Raises:
HTTPException: If there is an error updating the library agent
"""
try:
# Use the database function to update the library agent
await library_db.update_library_agent(
library_agent_id,
user_id,
auto_update_version,
is_favorite,
is_archived,
is_deleted,
)
return fastapi.Response(status_code=204)
except store_exceptions.DatabaseError as e:
logger.exception(f"Database error occurred whilst updating library agent: {e}")
raise fastapi.HTTPException(
status_code=500, detail="Failed to update library agent"
)
except Exception as e:
logger.exception(
f"Unexpected exception occurred whilst updating library agent: {e}"
)
raise fastapi.HTTPException(
status_code=500, detail="Failed to update library agent"
)

View File

@@ -0,0 +1,130 @@
import logging
from typing import Annotated, Any
import autogpt_libs.auth as autogpt_auth_lib
import autogpt_libs.utils.cache
import fastapi
import backend.executor
import backend.server.v2.library.db as library_db
import backend.server.v2.library.model as library_model
import backend.util.service
logger = logging.getLogger(__name__)
router = fastapi.APIRouter()
@autogpt_libs.utils.cache.thread_cached
def execution_manager_client() -> backend.executor.ExecutionManager:
return backend.util.service.get_service_client(backend.executor.ExecutionManager)
@router.get("/presets")
async def get_presets(
user_id: Annotated[str, fastapi.Depends(autogpt_auth_lib.depends.get_user_id)],
page: int = 1,
page_size: int = 10,
) -> library_model.LibraryAgentPresetResponse:
try:
presets = await library_db.get_presets(user_id, page, page_size)
return presets
except Exception as e:
logger.exception(f"Exception occurred whilst getting presets: {e}")
raise fastapi.HTTPException(status_code=500, detail="Failed to get presets")
@router.get("/presets/{preset_id}")
async def get_preset(
preset_id: str,
user_id: Annotated[str, fastapi.Depends(autogpt_auth_lib.depends.get_user_id)],
) -> library_model.LibraryAgentPreset:
try:
preset = await library_db.get_preset(user_id, preset_id)
if not preset:
raise fastapi.HTTPException(
status_code=404,
detail=f"Preset {preset_id} not found",
)
return preset
except Exception as e:
logger.exception(f"Exception occurred whilst getting preset: {e}")
raise fastapi.HTTPException(status_code=500, detail="Failed to get preset")
@router.post("/presets")
async def create_preset(
preset: library_model.CreateLibraryAgentPresetRequest,
user_id: Annotated[str, fastapi.Depends(autogpt_auth_lib.depends.get_user_id)],
) -> library_model.LibraryAgentPreset:
try:
return await library_db.upsert_preset(user_id, preset)
except Exception as e:
logger.exception(f"Exception occurred whilst creating preset: {e}")
raise fastapi.HTTPException(status_code=500, detail="Failed to create preset")
@router.put("/presets/{preset_id}")
async def update_preset(
preset_id: str,
preset: library_model.CreateLibraryAgentPresetRequest,
user_id: Annotated[str, fastapi.Depends(autogpt_auth_lib.depends.get_user_id)],
) -> library_model.LibraryAgentPreset:
try:
return await library_db.upsert_preset(user_id, preset, preset_id)
except Exception as e:
logger.exception(f"Exception occurred whilst updating preset: {e}")
raise fastapi.HTTPException(status_code=500, detail="Failed to update preset")
@router.delete("/presets/{preset_id}")
async def delete_preset(
preset_id: str,
user_id: Annotated[str, fastapi.Depends(autogpt_auth_lib.depends.get_user_id)],
):
try:
await library_db.delete_preset(user_id, preset_id)
return fastapi.Response(status_code=204)
except Exception as e:
logger.exception(f"Exception occurred whilst deleting preset: {e}")
raise fastapi.HTTPException(status_code=500, detail="Failed to delete preset")
@router.post(
path="/presets/{preset_id}/execute",
tags=["presets"],
dependencies=[fastapi.Depends(autogpt_auth_lib.auth_middleware)],
)
async def execute_preset(
graph_id: str,
graph_version: int,
preset_id: str,
node_input: Annotated[
dict[str, Any], fastapi.Body(..., embed=True, default_factory=dict)
],
user_id: Annotated[str, fastapi.Depends(autogpt_auth_lib.depends.get_user_id)],
) -> dict[str, Any]: # FIXME: add proper return type
try:
preset = await library_db.get_preset(user_id, preset_id)
if not preset:
raise fastapi.HTTPException(status_code=404, detail="Preset not found")
logger.debug(f"Preset inputs: {preset.inputs}")
# Merge input overrides with preset inputs
merged_node_input = preset.inputs | node_input
execution = execution_manager_client().add_execution(
graph_id=graph_id,
graph_version=graph_version,
data=merged_node_input,
user_id=user_id,
preset_id=preset_id,
)
logger.debug(f"Execution added: {execution} with input: {merged_node_input}")
return {"id": execution.graph_exec_id}
except Exception as e:
msg = str(e).encode().decode("unicode_escape")
raise fastapi.HTTPException(status_code=400, detail=msg)

View File

@@ -1,16 +1,16 @@
import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import datetime
import autogpt_libs.auth as autogpt_auth_lib
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
import backend.server.v2.library.db
import backend.server.v2.library.model
import backend.server.v2.library.routes
import backend.server.v2.library.model as library_model
from backend.server.v2.library.routes import router as library_router
app = fastapi.FastAPI()
app.include_router(backend.server.v2.library.routes.router)
app.include_router(library_router)
client = fastapi.testclient.TestClient(app)
@@ -25,31 +25,37 @@ def override_get_user_id():
return "test-user-id"
app.dependency_overrides[autogpt_libs.auth.middleware.auth_middleware] = (
override_auth_middleware
)
app.dependency_overrides[autogpt_libs.auth.depends.get_user_id] = override_get_user_id
app.dependency_overrides[autogpt_auth_lib.auth_middleware] = override_auth_middleware
app.dependency_overrides[autogpt_auth_lib.depends.get_user_id] = override_get_user_id
def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
mocked_value = [
backend.server.v2.library.model.LibraryAgent(
library_model.LibraryAgent(
id="test-agent-1",
version=1,
is_active=True,
agent_id="test-agent-1",
agent_version=1,
preset_id="preset-1",
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
is_favorite=False,
is_created_by_user=True,
is_latest_version=True,
name="Test Agent 1",
description="Test Description 1",
isCreatedByUser=True,
input_schema={"type": "object", "properties": {}},
output_schema={"type": "object", "properties": {}},
),
backend.server.v2.library.model.LibraryAgent(
library_model.LibraryAgent(
id="test-agent-2",
version=1,
is_active=True,
agent_id="test-agent-2",
agent_version=1,
preset_id="preset-2",
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
is_favorite=False,
is_created_by_user=False,
is_latest_version=True,
name="Test Agent 2",
description="Test Description 2",
isCreatedByUser=False,
input_schema={"type": "object", "properties": {}},
output_schema={"type": "object", "properties": {}},
),
@@ -61,14 +67,13 @@ def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
assert response.status_code == 200
data = [
backend.server.v2.library.model.LibraryAgent.model_validate(agent)
for agent in response.json()
library_model.LibraryAgent.model_validate(agent) for agent in response.json()
]
assert len(data) == 2
assert data[0].id == "test-agent-1"
assert data[0].isCreatedByUser is True
assert data[1].id == "test-agent-2"
assert data[1].isCreatedByUser is False
assert data[0].agent_id == "test-agent-1"
assert data[0].is_created_by_user is True
assert data[1].agent_id == "test-agent-2"
assert data[1].is_created_by_user is False
mock_db_call.assert_called_once_with("test-user-id")

View File

@@ -253,7 +253,9 @@ async def block_autogen_agent():
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
input_data = {"input": "Write me a block that writes a string into a file."}
response = await server.agent_server.test_execute_graph(
test_graph.id, input_data, test_user.id
graph_id=test_graph.id,
user_id=test_user.id,
node_input=input_data,
)
print(response)
result = await wait_execution(

View File

@@ -157,7 +157,9 @@ async def reddit_marketing_agent():
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
input_data = {"subreddit": "AutoGPT"}
response = await server.agent_server.test_execute_graph(
test_graph.id, input_data, test_user.id
graph_id=test_graph.id,
user_id=test_user.id,
node_input=input_data,
)
print(response)
result = await wait_execution(

View File

@@ -86,7 +86,9 @@ async def sample_agent():
test_graph = await create_graph(create_test_graph(), test_user.id)
input_data = {"input_1": "Hello", "input_2": "World"}
response = await server.agent_server.test_execute_graph(
test_graph.id, input_data, test_user.id
graph_id=test_graph.id,
user_id=test_user.id,
node_input=input_data,
)
print(response)
result = await wait_execution(

View File

@@ -62,7 +62,7 @@ def expose(func: C) -> C:
try:
return func(*args, **kwargs)
except Exception as e:
msg = f"Error in {func.__name__}: {e.__str__()}"
msg = f"Error in {func.__name__}: {e}"
if isinstance(e, ValueError):
logger.warning(msg)
else:
@@ -80,7 +80,7 @@ def register_pydantic_serializers(func: Callable):
try:
pydantic_types = _pydantic_models_from_type_annotation(annotation)
except Exception as e:
raise TypeError(f"Error while exposing {func.__name__}: {e.__str__()}")
raise TypeError(f"Error while exposing {func.__name__}: {e}")
for model in pydantic_types:
logger.debug(
@@ -117,7 +117,7 @@ class AppService(AppProcess, ABC):
shared_event_loop: asyncio.AbstractEventLoop
use_db: bool = False
use_redis: bool = False
use_rabbitmq: Optional[rabbitmq.RabbitMQConfig] = None
rabbitmq_config: Optional[rabbitmq.RabbitMQConfig] = None
rabbitmq_service: Optional[rabbitmq.AsyncRabbitMQ] = None
use_supabase: bool = False
@@ -143,9 +143,9 @@ class AppService(AppProcess, ABC):
@property
def rabbit_config(self) -> rabbitmq.RabbitMQConfig:
"""Access the RabbitMQ config. Will raise if not configured."""
if not self.use_rabbitmq:
if not self.rabbitmq_config:
raise RuntimeError("RabbitMQ not configured for this service")
return self.use_rabbitmq
return self.rabbitmq_config
def run_service(self) -> None:
while True:
@@ -164,13 +164,13 @@ class AppService(AppProcess, ABC):
self.shared_event_loop.run_until_complete(db.connect())
if self.use_redis:
redis.connect()
if self.use_rabbitmq:
if self.rabbitmq_config:
logger.info(f"[{self.__class__.__name__}] ⏳ Configuring RabbitMQ...")
# if self.use_async:
self.rabbitmq_service = rabbitmq.AsyncRabbitMQ(self.use_rabbitmq)
self.rabbitmq_service = rabbitmq.AsyncRabbitMQ(self.rabbitmq_config)
self.shared_event_loop.run_until_complete(self.rabbitmq_service.connect())
# else:
# self.rabbitmq_service = rabbitmq.SyncRabbitMQ(self.use_rabbitmq)
# self.rabbitmq_service = rabbitmq.SyncRabbitMQ(self.rabbitmq_config)
# self.rabbitmq_service.connect()
if self.use_supabase:
from supabase import create_client
@@ -200,7 +200,7 @@ class AppService(AppProcess, ABC):
if self.use_redis:
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting Redis...")
redis.disconnect()
if self.use_rabbitmq:
if self.rabbitmq_config:
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting RabbitMQ...")
@conn_retry("Pyro", "Starting Pyro Service")

View File

@@ -22,7 +22,7 @@ class SpinTestServer:
self.exec_manager = ExecutionManager()
self.agent_server = AgentServer()
self.scheduler = ExecutionScheduler()
self.notifications = NotificationManager()
self.notif_manager = NotificationManager()
@staticmethod
def test_get_user_id():
@@ -34,7 +34,7 @@ class SpinTestServer:
self.agent_server.__enter__()
self.exec_manager.__enter__()
self.scheduler.__enter__()
self.notifications.__enter__()
self.notif_manager.__enter__()
await db.connect()
await initialize_blocks()
@@ -49,7 +49,7 @@ class SpinTestServer:
self.exec_manager.__exit__(exc_type, exc_val, exc_tb)
self.agent_server.__exit__(exc_type, exc_val, exc_tb)
self.db_api.__exit__(exc_type, exc_val, exc_tb)
self.notifications.__exit__(exc_type, exc_val, exc_tb)
self.notif_manager.__exit__(exc_type, exc_val, exc_tb)
def setup_dependency_overrides(self):
# Override get_user_id for testing

View File

@@ -1,17 +1,69 @@
import logging
import bleach
from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment
from markupsafe import Markup
logger = logging.getLogger(__name__)
class TextFormatter:
def __init__(self):
# Create a sandboxed environment
self.env = SandboxedEnvironment(loader=BaseLoader(), autoescape=True)
# Clear any registered filters, tests, and globals to minimize attack surface
self.env.filters.clear()
self.env.tests.clear()
self.env.globals.clear()
self.allowed_tags = ["p", "b", "i", "u", "ul", "li", "br", "strong", "em"]
self.allowed_attributes = {"*": ["style", "class"]}
def format_string(self, template_str: str, values=None, **kwargs) -> str:
"""Regular template rendering with escaping"""
template = self.env.from_string(template_str)
return template.render(values or {}, **kwargs)
def format_email(
self,
subject_template: str,
base_template: str,
content_template: str,
data=None,
**kwargs,
) -> tuple[str, str]:
"""
Special handling for email templates where content needs to be rendered as HTML
"""
# First render the content template
content = self.format_string(content_template, data, **kwargs)
# Clean the HTML but don't escape it
clean_content = bleach.clean(
content,
tags=self.allowed_tags,
attributes=self.allowed_attributes,
strip=True,
)
# Mark the cleaned HTML as safe using Markup
safe_content = Markup(clean_content)
rendered_subject_template = self.format_string(subject_template, data, **kwargs)
# Create new env just for HTML template
html_env = SandboxedEnvironment(loader=BaseLoader(), autoescape=True)
html_env.filters["safe"] = lambda x: (
x if isinstance(x, Markup) else Markup(str(x))
)
# Render base template with the safe content
template = html_env.from_string(base_template)
rendered_base_template = template.render(
data={
"message": safe_content,
"title": rendered_subject_template,
"unsubscribe_link": kwargs.get("unsubscribe_link", ""),
}
)
return rendered_subject_template, rendered_base_template

View File

@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "AgentPreset" ADD COLUMN "isDeleted" BOOLEAN NOT NULL DEFAULT false;

View File

@@ -0,0 +1,46 @@
/*
Warnings:
- You are about to drop the `UserAgent` table. If the table is not empty, all the data it contains will be lost.
*/
-- DropForeignKey
ALTER TABLE "UserAgent" DROP CONSTRAINT "UserAgent_agentId_agentVersion_fkey";
-- DropForeignKey
ALTER TABLE "UserAgent" DROP CONSTRAINT "UserAgent_agentPresetId_fkey";
-- DropForeignKey
ALTER TABLE "UserAgent" DROP CONSTRAINT "UserAgent_userId_fkey";
-- DropTable
DROP TABLE "UserAgent";
-- CreateTable
CREATE TABLE "LibraryAgent" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"userId" TEXT NOT NULL,
"agentId" TEXT NOT NULL,
"agentVersion" INTEGER NOT NULL,
"agentPresetId" TEXT,
"isFavorite" BOOLEAN NOT NULL DEFAULT false,
"isCreatedByUser" BOOLEAN NOT NULL DEFAULT false,
"isArchived" BOOLEAN NOT NULL DEFAULT false,
"isDeleted" BOOLEAN NOT NULL DEFAULT false,
CONSTRAINT "LibraryAgent_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE INDEX "LibraryAgent_userId_idx" ON "LibraryAgent"("userId");
-- AddForeignKey
ALTER TABLE "LibraryAgent" ADD CONSTRAINT "LibraryAgent_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "LibraryAgent" ADD CONSTRAINT "LibraryAgent_agentId_agentVersion_fkey" FOREIGN KEY ("agentId", "agentVersion") REFERENCES "AgentGraph"("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "LibraryAgent" ADD CONSTRAINT "LibraryAgent_agentPresetId_fkey" FOREIGN KEY ("agentPresetId") REFERENCES "AgentPreset"("id") ON DELETE SET NULL ON UPDATE CASCADE;

View File

@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "LibraryAgent" ADD COLUMN "useGraphIsActiveVersion" BOOLEAN NOT NULL DEFAULT false;

View File

@@ -365,6 +365,24 @@ d = ["aiohttp (>=3.10)"]
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
uvloop = ["uvloop (>=0.15.2)"]
[[package]]
name = "bleach"
version = "6.2.0"
description = "An easy safelist-based HTML-sanitizing tool."
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "bleach-6.2.0-py3-none-any.whl", hash = "sha256:117d9c6097a7c3d22fd578fcd8d35ff1e125df6736f554da4e432fdd63f31e5e"},
{file = "bleach-6.2.0.tar.gz", hash = "sha256:123e894118b8a599fd80d3ec1a6d4cc7ce4e5882b1317a7e1ba69b56e95f991f"},
]
[package.dependencies]
webencodings = "*"
[package.extras]
css = ["tinycss2 (>=1.1.0,<1.5)"]
[[package]]
name = "cachetools"
version = "5.5.1"
@@ -4799,6 +4817,18 @@ files = [
[package.dependencies]
anyio = ">=3.0.0"
[[package]]
name = "webencodings"
version = "0.5.1"
description = "Character encoding aliases for legacy web content"
optional = false
python-versions = "*"
groups = ["main"]
files = [
{file = "webencodings-0.5.1-py2.py3-none-any.whl", hash = "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78"},
{file = "webencodings-0.5.1.tar.gz", hash = "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923"},
]
[[package]]
name = "websocket-client"
version = "1.8.0"
@@ -5137,4 +5167,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.13"
content-hash = "e06603a4c5a82e6e40e8b76c7c4a5dc0b60b31fbe5e47cce3b89d66c53bb704c"
content-hash = "fd396e770353328ac3fe71aaba5b73b5a147dd445e4866e5328d6173a13bf7a3"

View File

@@ -62,6 +62,7 @@ uvicorn = { extras = ["standard"], version = "^0.34.0" }
websockets = "^13.1"
youtube-transcript-api = "^0.6.2"
# NOTE: please insert new dependencies in their alphabetical location
bleach = "^6.2.0"
[tool.poetry.group.dev.dependencies]
aiohappyeyeballs = "^2.4.4"

View File

@@ -41,8 +41,8 @@ model User {
AnalyticsMetrics AnalyticsMetrics[]
CreditTransaction CreditTransaction[]
AgentPreset AgentPreset[]
UserAgent UserAgent[]
AgentPreset AgentPreset[]
LibraryAgent LibraryAgent[]
Profile Profile[]
StoreListing StoreListing[]
@@ -78,7 +78,7 @@ model AgentGraph {
AgentGraphExecution AgentGraphExecution[]
AgentPreset AgentPreset[]
UserAgent UserAgent[]
LibraryAgent LibraryAgent[]
StoreListing StoreListing[]
StoreListingVersion StoreListingVersion?
@@ -116,9 +116,11 @@ model AgentPreset {
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)
InputPresets AgentNodeExecutionInputOutput[] @relation("AgentPresetsInputData")
UserAgents UserAgent[]
LibraryAgents LibraryAgent[]
AgentExecution AgentGraphExecution[]
isDeleted Boolean @default(false)
@@index([userId])
}
@@ -163,7 +165,7 @@ model UserNotificationBatch {
// For the library page
// It is a user controlled list of agents, that they will see in there library
model UserAgent {
model LibraryAgent {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
@@ -178,6 +180,8 @@ model UserAgent {
agentPresetId String?
AgentPreset AgentPreset? @relation(fields: [agentPresetId], references: [id])
useGraphIsActiveVersion Boolean @default(false)
isFavorite Boolean @default(false)
isCreatedByUser Boolean @default(false)
isArchived Boolean @default(false)

View File

@@ -5,8 +5,9 @@ import fastapi.responses
import pytest
from prisma.models import User
import backend.server.v2.library.model
import backend.server.v2.store.model
from backend.blocks.basic import FindInDictionaryBlock, StoreValueBlock
from backend.blocks.basic import AgentInputBlock, FindInDictionaryBlock, StoreValueBlock
from backend.blocks.maths import CalculatorBlock, Operation
from backend.data import execution, graph
from backend.server.model import CreateGraph
@@ -131,7 +132,7 @@ async def test_agent_execution(server: SpinTestServer):
logger.info("Starting test_agent_execution")
test_user = await create_test_user()
test_graph = await create_graph(server, create_test_graph(), test_user)
data = {"node_input": {"input_1": "Hello", "input_2": "World"}}
data = {"input_1": "Hello", "input_2": "World"}
graph_exec_id = await execute_graph(
server.agent_server,
test_graph,
@@ -295,6 +296,192 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
logger.info("Completed test_static_input_link_on_graph")
@pytest.mark.asyncio(scope="session")
async def test_execute_preset(server: SpinTestServer):
"""
Test executing a preset.
This test ensures that:
1. A preset can be successfully executed
2. The execution results are correct
Args:
server (SpinTestServer): The test server instance.
"""
# Create test graph and user
nodes = [
graph.Node( # 0
block_id=AgentInputBlock().id,
input_default={"name": "dictionary"},
),
graph.Node( # 1
block_id=AgentInputBlock().id,
input_default={"name": "selected_value"},
),
graph.Node( # 2
block_id=StoreValueBlock().id,
input_default={"input": {"key1": "Hi", "key2": "Everyone"}},
),
graph.Node( # 3
block_id=FindInDictionaryBlock().id,
input_default={"key": "", "input": {}},
),
]
links = [
graph.Link(
source_id=nodes[0].id,
sink_id=nodes[2].id,
source_name="result",
sink_name="input",
),
graph.Link(
source_id=nodes[1].id,
sink_id=nodes[3].id,
source_name="result",
sink_name="key",
),
graph.Link(
source_id=nodes[2].id,
sink_id=nodes[3].id,
source_name="output",
sink_name="input",
),
]
test_graph = graph.Graph(
name="TestGraph",
description="Test graph",
nodes=nodes,
links=links,
)
test_user = await create_test_user()
test_graph = await create_graph(server, test_graph, test_user)
# Create preset with initial values
preset = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
name="Test Preset With Clash",
description="Test preset with clashing input values",
agent_id=test_graph.id,
agent_version=test_graph.version,
inputs={
"dictionary": {"key1": "Hello", "key2": "World"},
"selected_value": "key2",
},
is_active=True,
)
created_preset = await server.agent_server.test_create_preset(preset, test_user.id)
# Execute preset with overriding values
result = await server.agent_server.test_execute_preset(
graph_id=test_graph.id,
graph_version=test_graph.version,
preset_id=created_preset.id,
user_id=test_user.id,
)
# Verify execution
assert result is not None
graph_exec_id = result["id"]
# Wait for execution to complete
executions = await wait_execution(test_user.id, test_graph.id, graph_exec_id)
assert len(executions) == 4
# FindInDictionaryBlock should wait for the input pin to be provided,
# Hence executing extraction of "key" from {"key1": "value1", "key2": "value2"}
assert executions[3].status == execution.ExecutionStatus.COMPLETED
assert executions[3].output_data == {"output": ["World"]}
@pytest.mark.asyncio(scope="session")
async def test_execute_preset_with_clash(server: SpinTestServer):
"""
Test executing a preset with clashing input data.
"""
# Create test graph and user
nodes = [
graph.Node( # 0
block_id=AgentInputBlock().id,
input_default={"name": "dictionary"},
),
graph.Node( # 1
block_id=AgentInputBlock().id,
input_default={"name": "selected_value"},
),
graph.Node( # 2
block_id=StoreValueBlock().id,
input_default={"input": {"key1": "Hi", "key2": "Everyone"}},
),
graph.Node( # 3
block_id=FindInDictionaryBlock().id,
input_default={"key": "", "input": {}},
),
]
links = [
graph.Link(
source_id=nodes[0].id,
sink_id=nodes[2].id,
source_name="result",
sink_name="input",
),
graph.Link(
source_id=nodes[1].id,
sink_id=nodes[3].id,
source_name="result",
sink_name="key",
),
graph.Link(
source_id=nodes[2].id,
sink_id=nodes[3].id,
source_name="output",
sink_name="input",
),
]
test_graph = graph.Graph(
name="TestGraph",
description="Test graph",
nodes=nodes,
links=links,
)
test_user = await create_test_user()
test_graph = await create_graph(server, test_graph, test_user)
# Create preset with initial values
preset = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
name="Test Preset With Clash",
description="Test preset with clashing input values",
agent_id=test_graph.id,
agent_version=test_graph.version,
inputs={
"dictionary": {"key1": "Hello", "key2": "World"},
"selected_value": "key2",
},
is_active=True,
)
created_preset = await server.agent_server.test_create_preset(preset, test_user.id)
# Execute preset with overriding values
result = await server.agent_server.test_execute_preset(
graph_id=test_graph.id,
graph_version=test_graph.version,
preset_id=created_preset.id,
node_input={"selected_value": "key1"},
user_id=test_user.id,
)
# Verify execution
assert result is not None
graph_exec_id = result["id"]
# Wait for execution to complete
executions = await wait_execution(test_user.id, test_graph.id, graph_exec_id)
assert len(executions) == 4
# FindInDictionaryBlock should wait for the input pin to be provided,
# Hence executing extraction of "key" from {"key1": "value1", "key2": "value2"}
assert executions[3].status == execution.ExecutionStatus.COMPLETED
assert executions[3].output_data == {"output": ["Hello"]}
@pytest.mark.asyncio(scope="session")
async def test_store_listing_graph(server: SpinTestServer):
logger.info("Starting test_agent_execution")
@@ -344,7 +531,7 @@ async def test_store_listing_graph(server: SpinTestServer):
)
alt_test_user = admin_user
data = {"node_input": {"input_1": "Hello", "input_2": "World"}}
data = {"input_1": "Hello", "input_2": "World"}
graph_exec_id = await execute_graph(
server.agent_server,
test_graph,

View File

@@ -140,10 +140,10 @@ async def main():
print(f"Inserting {NUM_USERS * MAX_AGENTS_PER_USER} user agents")
for user in users:
num_agents = random.randint(MIN_AGENTS_PER_USER, MAX_AGENTS_PER_USER)
for _ in range(num_agents): # Create 1 UserAgent per user
for _ in range(num_agents): # Create 1 LibraryAgent per user
graph = random.choice(agent_graphs)
preset = random.choice(agent_presets)
user_agent = await db.useragent.create(
user_agent = await db.libraryagent.create(
data={
"userId": user.id,
"agentId": graph.id,

View File

@@ -81,18 +81,20 @@ services:
- REDIS_PORT=6379
- RABBITMQ_HOST=rabbitmq
- RABBITMQ_PORT=5672
- RABBITMQ_USER=rabbitmq_user_default
- RABBITMQ_PASSWORD=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
- RABBITMQ_DEFAULT_USER=rabbitmq_user_default
- RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
- REDIS_PASSWORD=password
- ENABLE_AUTH=true
- PYRO_HOST=0.0.0.0
- EXECUTIONSCHEDULER_HOST=rest_server
- EXECUTIONMANAGER_HOST=executor
- NOTIFICATIONMANAGER_HOST=rest_server
- FRONTEND_BASE_URL=http://localhost:3000
- BACKEND_CORS_ALLOW_ORIGINS=["http://localhost:3000"]
- ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw= # DO NOT USE IN PRODUCTION!!
ports:
- "8006:8006"
- "8007:8007"
- "8003:8003" # execution scheduler
networks:
- app-network
@@ -127,11 +129,12 @@ services:
- REDIS_PASSWORD=password
- RABBITMQ_HOST=rabbitmq
- RABBITMQ_PORT=5672
- RABBITMQ_USER=rabbitmq_user_default
- RABBITMQ_PASSWORD=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
- RABBITMQ_DEFAULT_USER=rabbitmq_user_default
- RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
- ENABLE_AUTH=true
- PYRO_HOST=0.0.0.0
- AGENTSERVER_HOST=rest_server
- NOTIFICATIONMANAGER_HOST=rest_server
- ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw= # DO NOT USE IN PRODUCTION!!
ports:
- "8002:8000"
@@ -157,17 +160,17 @@ services:
# rabbitmq:
# condition: service_healthy
migrate:
condition: service_completed_successfully
condition: service_completed_successfully
environment:
- SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
- DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
- REDIS_HOST=redis
- REDIS_PORT=6379
- REDIS_PASSWORD=password
# - RABBITMQ_HOST=rabbitmq # TODO: Uncomment this when we have a need for it in websocket (like nofifying when stuff went down)
# - RABBITMQ_HOST=rabbitmq
# - RABBITMQ_PORT=5672
# - RABBITMQ_USER=rabbitmq_user_default
# - RABBITMQ_PASSWORD=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
# - RABBITMQ_DEFAULT_USER=rabbitmq_user_default
# - RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
- ENABLE_AUTH=true
- PYRO_HOST=0.0.0.0
- BACKEND_CORS_ALLOW_ORIGINS=["http://localhost:3000"]

View File

@@ -160,10 +160,8 @@ export default class BackendAPI {
return this._get(`/graphs/${id}/versions`);
}
createGraph(graphCreateBody: GraphCreatable): Promise<Graph>;
createGraph(graphID: GraphCreatable | string): Promise<Graph> {
let requestBody = { graph: graphID } as GraphCreateRequestBody;
createGraph(graph: GraphCreatable): Promise<Graph> {
let requestBody = { graph } as GraphCreateRequestBody;
return this._request("POST", "/graphs", requestBody);
}