mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(backend): Register agent subgraphs as library entries during agent import (#10409)
Currently, we only create a library entry of the top-most graph when importing the graph from an exported file. This can cause some complications, as there is no way to remove the library entry of it. ### Changes 🏗️ Create the library entry for all the subgraphs during the import process. ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: <!-- Put your test plan here: --> - [x] Export an agent with subgraphs and import it back.
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional, cast
|
||||
from typing import TYPE_CHECKING, Optional, cast, overload
|
||||
|
||||
from backend.data.block import BlockSchema
|
||||
from backend.data.graph import set_node_webhook
|
||||
@@ -9,7 +10,7 @@ from . import get_webhook_manager, supports_webhooks
|
||||
from .utils import setup_webhook_for_block
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.graph import GraphModel, NodeModel
|
||||
from backend.data.graph import BaseGraph, GraphModel, Node, NodeModel
|
||||
from backend.data.model import Credentials
|
||||
|
||||
from ._base import BaseWebhooksManager
|
||||
@@ -18,13 +19,29 @@ logger = logging.getLogger(__name__)
|
||||
credentials_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
async def on_graph_activate(graph: "GraphModel", user_id: str):
|
||||
async def on_graph_activate(graph: "GraphModel", user_id: str) -> "GraphModel":
|
||||
"""
|
||||
Hook to be called when a graph is activated/created.
|
||||
|
||||
⚠️ Assuming node entities are not re-used between graph versions, ⚠️
|
||||
this hook calls `on_node_activate` on all nodes in this graph.
|
||||
"""
|
||||
graph = await _on_graph_activate(graph, user_id)
|
||||
graph.sub_graphs = await asyncio.gather(
|
||||
*(_on_graph_activate(sub_graph, user_id) for sub_graph in graph.sub_graphs)
|
||||
)
|
||||
return graph
|
||||
|
||||
|
||||
@overload
|
||||
async def _on_graph_activate(graph: "GraphModel", user_id: str) -> "GraphModel": ...
|
||||
|
||||
|
||||
@overload
|
||||
async def _on_graph_activate(graph: "BaseGraph", user_id: str) -> "BaseGraph": ...
|
||||
|
||||
|
||||
async def _on_graph_activate(graph: "BaseGraph | GraphModel", user_id: str):
|
||||
get_credentials = credentials_manager.cached_getter(user_id)
|
||||
updated_nodes = []
|
||||
for new_node in graph.nodes:
|
||||
@@ -47,7 +64,7 @@ async def on_graph_activate(graph: "GraphModel", user_id: str):
|
||||
)
|
||||
|
||||
updated_node = await on_node_activate(
|
||||
graph.user_id, new_node, credentials=node_credentials
|
||||
user_id, graph.id, new_node, credentials=node_credentials
|
||||
)
|
||||
updated_nodes.append(updated_node)
|
||||
|
||||
@@ -94,10 +111,11 @@ async def on_graph_deactivate(graph: "GraphModel", user_id: str):
|
||||
|
||||
async def on_node_activate(
|
||||
user_id: str,
|
||||
node: "NodeModel",
|
||||
graph_id: str,
|
||||
node: "Node",
|
||||
*,
|
||||
credentials: Optional["Credentials"] = None,
|
||||
) -> "NodeModel":
|
||||
) -> "Node":
|
||||
"""Hook to be called when the node is activated/created"""
|
||||
|
||||
if node.block.webhook_config:
|
||||
@@ -105,7 +123,7 @@ async def on_node_activate(
|
||||
user_id=user_id,
|
||||
trigger_block=node.block,
|
||||
trigger_config=node.input_default,
|
||||
for_graph_id=node.graph_id,
|
||||
for_graph_id=graph_id,
|
||||
)
|
||||
if new_webhook:
|
||||
node = await set_node_webhook(node.id, new_webhook.id)
|
||||
|
||||
@@ -621,16 +621,11 @@ async def create_new_graph(
|
||||
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
|
||||
graph.validate_graph(for_run=False)
|
||||
|
||||
graph = await graph_db.create_graph(graph, user_id=user_id)
|
||||
|
||||
# Create a library agent for the new graph
|
||||
library_agent = await library_db.create_library_agent(graph, user_id)
|
||||
_ = asyncio.create_task(
|
||||
library_db.add_generated_agent_image(graph, library_agent.id)
|
||||
)
|
||||
|
||||
graph = await on_graph_activate(graph, user_id=user_id)
|
||||
return graph
|
||||
# The return value of the create graph & library function is intentionally not used here,
|
||||
# as the graph already valid and no sub-graphs are returned back.
|
||||
await graph_db.create_graph(graph, user_id=user_id)
|
||||
await library_db.create_library_agent(graph, user_id=user_id)
|
||||
return await on_graph_activate(graph, user_id=user_id)
|
||||
|
||||
|
||||
@v1_router.delete(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Literal, Optional
|
||||
|
||||
@@ -246,13 +247,13 @@ async def get_library_agent_by_graph_id(
|
||||
|
||||
|
||||
async def add_generated_agent_image(
|
||||
graph: graph_db.GraphModel,
|
||||
graph: graph_db.BaseGraph,
|
||||
user_id: str,
|
||||
library_agent_id: str,
|
||||
) -> Optional[prisma.models.LibraryAgent]:
|
||||
"""
|
||||
Generates an image for the specified LibraryAgent and updates its record.
|
||||
"""
|
||||
user_id = graph.user_id
|
||||
graph_id = graph.id
|
||||
|
||||
# Use .jpeg here since we are generating JPEG images
|
||||
@@ -281,16 +282,19 @@ async def add_generated_agent_image(
|
||||
async def create_library_agent(
|
||||
graph: graph_db.GraphModel,
|
||||
user_id: str,
|
||||
) -> library_model.LibraryAgent:
|
||||
create_library_agents_for_sub_graphs: bool = True,
|
||||
) -> list[library_model.LibraryAgent]:
|
||||
"""
|
||||
Adds an agent to the user's library (LibraryAgent table).
|
||||
|
||||
Args:
|
||||
agent: The agent/Graph to add to the library.
|
||||
user_id: The user to whom the agent will be added.
|
||||
create_library_agents_for_sub_graphs: If True, creates LibraryAgent records for sub-graphs as well.
|
||||
|
||||
Returns:
|
||||
The newly created LibraryAgent record.
|
||||
The newly created LibraryAgent records.
|
||||
If the graph has sub-graphs, the parent graph will always be the first entry in the list.
|
||||
|
||||
Raises:
|
||||
AgentNotFoundError: If the specified agent does not exist.
|
||||
@@ -300,26 +304,39 @@ async def create_library_agent(
|
||||
f"Creating library agent for graph #{graph.id} v{graph.version}; "
|
||||
f"user #{user_id}"
|
||||
)
|
||||
graph_entries = (
|
||||
[graph, *graph.sub_graphs] if create_library_agents_for_sub_graphs else [graph]
|
||||
)
|
||||
|
||||
try:
|
||||
agent = await prisma.models.LibraryAgent.prisma().create(
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
isCreatedByUser=(user_id == graph.user_id),
|
||||
useGraphIsActiveVersion=True,
|
||||
User={"connect": {"id": user_id}},
|
||||
# Creator={"connect": {"id": graph.user_id}},
|
||||
AgentGraph={
|
||||
"connect": {
|
||||
"graphVersionId": {"id": graph.id, "version": graph.version}
|
||||
}
|
||||
},
|
||||
),
|
||||
include={"AgentGraph": True},
|
||||
async with transaction() as tx:
|
||||
library_agents = await asyncio.gather(
|
||||
*(
|
||||
prisma.models.LibraryAgent.prisma(tx).create(
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
isCreatedByUser=(user_id == user_id),
|
||||
useGraphIsActiveVersion=True,
|
||||
User={"connect": {"id": user_id}},
|
||||
# Creator={"connect": {"id": user_id}},
|
||||
AgentGraph={
|
||||
"connect": {
|
||||
"graphVersionId": {
|
||||
"id": graph_entry.id,
|
||||
"version": graph_entry.version,
|
||||
}
|
||||
}
|
||||
},
|
||||
),
|
||||
include=library_agent_include(user_id),
|
||||
)
|
||||
for graph_entry in graph_entries
|
||||
)
|
||||
)
|
||||
return library_model.LibraryAgent.from_db(agent)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error creating agent in library: {e}")
|
||||
raise store_exceptions.DatabaseError("Failed to create agent in library") from e
|
||||
|
||||
# Generate images for the main graph and sub-graphs
|
||||
for agent, graph in zip(library_agents, graph_entries):
|
||||
asyncio.create_task(add_generated_agent_image(graph, user_id, agent.id))
|
||||
|
||||
return [library_model.LibraryAgent.from_db(agent) for agent in library_agents]
|
||||
|
||||
|
||||
async def update_agent_version_in_library(
|
||||
@@ -872,7 +889,9 @@ async def delete_preset(user_id: str, preset_id: str) -> None:
|
||||
raise store_exceptions.DatabaseError("Failed to delete preset") from e
|
||||
|
||||
|
||||
async def fork_library_agent(library_agent_id: str, user_id: str):
|
||||
async def fork_library_agent(
|
||||
library_agent_id: str, user_id: str
|
||||
) -> library_model.LibraryAgent:
|
||||
"""
|
||||
Clones a library agent and its underyling graph and nodes (with new ids) for the given user.
|
||||
|
||||
@@ -881,7 +900,7 @@ async def fork_library_agent(library_agent_id: str, user_id: str):
|
||||
user_id: The ID of the user who owns the library agent.
|
||||
|
||||
Returns:
|
||||
The forked LibraryAgent.
|
||||
The forked parent (if it has sub-graphs) LibraryAgent.
|
||||
|
||||
Raises:
|
||||
DatabaseError: If there's an error during the forking process.
|
||||
@@ -907,7 +926,7 @@ async def fork_library_agent(library_agent_id: str, user_id: str):
|
||||
new_graph = await on_graph_activate(new_graph, user_id=user_id)
|
||||
|
||||
# Create a library agent for the new graph
|
||||
return await create_library_agent(new_graph, user_id)
|
||||
return (await create_library_agent(new_graph, user_id))[0]
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error cloning library agent: {e}")
|
||||
raise store_exceptions.DatabaseError("Failed to fork library agent") from e
|
||||
|
||||
@@ -16,7 +16,7 @@ from backend.blocks.ideogram import (
|
||||
StyleType,
|
||||
UpscaleOption,
|
||||
)
|
||||
from backend.data.graph import Graph
|
||||
from backend.data.graph import BaseGraph
|
||||
from backend.data.model import CredentialsMetaInput, ProviderName
|
||||
from backend.integrations.credentials_store import ideogram_credentials
|
||||
from backend.util.request import Requests
|
||||
@@ -34,14 +34,14 @@ class ImageStyle(str, Enum):
|
||||
DIGITAL_ART = "digital art"
|
||||
|
||||
|
||||
async def generate_agent_image(agent: Graph | AgentGraph) -> io.BytesIO:
|
||||
async def generate_agent_image(agent: BaseGraph | AgentGraph) -> io.BytesIO:
|
||||
if settings.config.use_agent_image_generation_v2:
|
||||
return await generate_agent_image_v2(graph=agent)
|
||||
else:
|
||||
return await generate_agent_image_v1(agent=agent)
|
||||
|
||||
|
||||
async def generate_agent_image_v2(graph: Graph | AgentGraph) -> io.BytesIO:
|
||||
async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO:
|
||||
"""
|
||||
Generate an image for an agent using Ideogram model.
|
||||
Returns:
|
||||
@@ -99,7 +99,7 @@ async def generate_agent_image_v2(graph: Graph | AgentGraph) -> io.BytesIO:
|
||||
return io.BytesIO(response.content)
|
||||
|
||||
|
||||
async def generate_agent_image_v1(agent: Graph | AgentGraph) -> io.BytesIO:
|
||||
async def generate_agent_image_v1(agent: BaseGraph | AgentGraph) -> io.BytesIO:
|
||||
"""
|
||||
Generate an image for an agent using Flux model via Replicate API.
|
||||
|
||||
|
||||
@@ -386,8 +386,10 @@ class TestDataCreator:
|
||||
)
|
||||
if graph:
|
||||
# Use the API function to create library agent
|
||||
library_agent = await create_library_agent(graph, user["id"])
|
||||
library_agents.append(library_agent.model_dump())
|
||||
library_agents.extend(
|
||||
v.model_dump()
|
||||
for v in await create_library_agent(graph, user["id"])
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error creating library agent: {e}")
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user