feat(backend): Register agent subgraphs as library entries during agent import (#10409)

Currently, we only create a library entry of the top-most graph when
importing the graph from an exported file.
This can cause some complications, as there is no way to remove the
library entry of it.

### Changes 🏗️

Create the library entry for all the subgraphs during the import
process.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] Export an agent with subgraphs and import it back.
This commit is contained in:
Zamil Majdy
2025-07-21 19:54:42 +08:00
committed by GitHub
parent e28eec6ff9
commit 0c9b7334c1
5 changed files with 82 additions and 48 deletions

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, Optional, cast, overload
from backend.data.block import BlockSchema
from backend.data.graph import set_node_webhook
@@ -9,7 +10,7 @@ from . import get_webhook_manager, supports_webhooks
from .utils import setup_webhook_for_block
if TYPE_CHECKING:
from backend.data.graph import GraphModel, NodeModel
from backend.data.graph import BaseGraph, GraphModel, Node, NodeModel
from backend.data.model import Credentials
from ._base import BaseWebhooksManager
@@ -18,13 +19,29 @@ logger = logging.getLogger(__name__)
credentials_manager = IntegrationCredentialsManager()
async def on_graph_activate(graph: "GraphModel", user_id: str):
async def on_graph_activate(graph: "GraphModel", user_id: str) -> "GraphModel":
"""
Hook to be called when a graph is activated/created.
⚠️ Assuming node entities are not re-used between graph versions, ⚠️
this hook calls `on_node_activate` on all nodes in this graph.
"""
graph = await _on_graph_activate(graph, user_id)
graph.sub_graphs = await asyncio.gather(
*(_on_graph_activate(sub_graph, user_id) for sub_graph in graph.sub_graphs)
)
return graph
@overload
async def _on_graph_activate(graph: "GraphModel", user_id: str) -> "GraphModel": ...
@overload
async def _on_graph_activate(graph: "BaseGraph", user_id: str) -> "BaseGraph": ...
async def _on_graph_activate(graph: "BaseGraph | GraphModel", user_id: str):
get_credentials = credentials_manager.cached_getter(user_id)
updated_nodes = []
for new_node in graph.nodes:
@@ -47,7 +64,7 @@ async def on_graph_activate(graph: "GraphModel", user_id: str):
)
updated_node = await on_node_activate(
graph.user_id, new_node, credentials=node_credentials
user_id, graph.id, new_node, credentials=node_credentials
)
updated_nodes.append(updated_node)
@@ -94,10 +111,11 @@ async def on_graph_deactivate(graph: "GraphModel", user_id: str):
async def on_node_activate(
user_id: str,
node: "NodeModel",
graph_id: str,
node: "Node",
*,
credentials: Optional["Credentials"] = None,
) -> "NodeModel":
) -> "Node":
"""Hook to be called when the node is activated/created"""
if node.block.webhook_config:
@@ -105,7 +123,7 @@ async def on_node_activate(
user_id=user_id,
trigger_block=node.block,
trigger_config=node.input_default,
for_graph_id=node.graph_id,
for_graph_id=graph_id,
)
if new_webhook:
node = await set_node_webhook(node.id, new_webhook.id)

View File

@@ -621,16 +621,11 @@ async def create_new_graph(
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
graph.validate_graph(for_run=False)
graph = await graph_db.create_graph(graph, user_id=user_id)
# Create a library agent for the new graph
library_agent = await library_db.create_library_agent(graph, user_id)
_ = asyncio.create_task(
library_db.add_generated_agent_image(graph, library_agent.id)
)
graph = await on_graph_activate(graph, user_id=user_id)
return graph
# The return value of the create graph & library function is intentionally not used here,
# as the graph already valid and no sub-graphs are returned back.
await graph_db.create_graph(graph, user_id=user_id)
await library_db.create_library_agent(graph, user_id=user_id)
return await on_graph_activate(graph, user_id=user_id)
@v1_router.delete(

View File

@@ -1,3 +1,4 @@
import asyncio
import logging
from typing import Literal, Optional
@@ -246,13 +247,13 @@ async def get_library_agent_by_graph_id(
async def add_generated_agent_image(
graph: graph_db.GraphModel,
graph: graph_db.BaseGraph,
user_id: str,
library_agent_id: str,
) -> Optional[prisma.models.LibraryAgent]:
"""
Generates an image for the specified LibraryAgent and updates its record.
"""
user_id = graph.user_id
graph_id = graph.id
# Use .jpeg here since we are generating JPEG images
@@ -281,16 +282,19 @@ async def add_generated_agent_image(
async def create_library_agent(
graph: graph_db.GraphModel,
user_id: str,
) -> library_model.LibraryAgent:
create_library_agents_for_sub_graphs: bool = True,
) -> list[library_model.LibraryAgent]:
"""
Adds an agent to the user's library (LibraryAgent table).
Args:
agent: The agent/Graph to add to the library.
user_id: The user to whom the agent will be added.
create_library_agents_for_sub_graphs: If True, creates LibraryAgent records for sub-graphs as well.
Returns:
The newly created LibraryAgent record.
The newly created LibraryAgent records.
If the graph has sub-graphs, the parent graph will always be the first entry in the list.
Raises:
AgentNotFoundError: If the specified agent does not exist.
@@ -300,26 +304,39 @@ async def create_library_agent(
f"Creating library agent for graph #{graph.id} v{graph.version}; "
f"user #{user_id}"
)
graph_entries = (
[graph, *graph.sub_graphs] if create_library_agents_for_sub_graphs else [graph]
)
try:
agent = await prisma.models.LibraryAgent.prisma().create(
data=prisma.types.LibraryAgentCreateInput(
isCreatedByUser=(user_id == graph.user_id),
useGraphIsActiveVersion=True,
User={"connect": {"id": user_id}},
# Creator={"connect": {"id": graph.user_id}},
AgentGraph={
"connect": {
"graphVersionId": {"id": graph.id, "version": graph.version}
}
},
),
include={"AgentGraph": True},
async with transaction() as tx:
library_agents = await asyncio.gather(
*(
prisma.models.LibraryAgent.prisma(tx).create(
data=prisma.types.LibraryAgentCreateInput(
isCreatedByUser=(user_id == user_id),
useGraphIsActiveVersion=True,
User={"connect": {"id": user_id}},
# Creator={"connect": {"id": user_id}},
AgentGraph={
"connect": {
"graphVersionId": {
"id": graph_entry.id,
"version": graph_entry.version,
}
}
},
),
include=library_agent_include(user_id),
)
for graph_entry in graph_entries
)
)
return library_model.LibraryAgent.from_db(agent)
except prisma.errors.PrismaError as e:
logger.error(f"Database error creating agent in library: {e}")
raise store_exceptions.DatabaseError("Failed to create agent in library") from e
# Generate images for the main graph and sub-graphs
for agent, graph in zip(library_agents, graph_entries):
asyncio.create_task(add_generated_agent_image(graph, user_id, agent.id))
return [library_model.LibraryAgent.from_db(agent) for agent in library_agents]
async def update_agent_version_in_library(
@@ -872,7 +889,9 @@ async def delete_preset(user_id: str, preset_id: str) -> None:
raise store_exceptions.DatabaseError("Failed to delete preset") from e
async def fork_library_agent(library_agent_id: str, user_id: str):
async def fork_library_agent(
library_agent_id: str, user_id: str
) -> library_model.LibraryAgent:
"""
Clones a library agent and its underyling graph and nodes (with new ids) for the given user.
@@ -881,7 +900,7 @@ async def fork_library_agent(library_agent_id: str, user_id: str):
user_id: The ID of the user who owns the library agent.
Returns:
The forked LibraryAgent.
The forked parent (if it has sub-graphs) LibraryAgent.
Raises:
DatabaseError: If there's an error during the forking process.
@@ -907,7 +926,7 @@ async def fork_library_agent(library_agent_id: str, user_id: str):
new_graph = await on_graph_activate(new_graph, user_id=user_id)
# Create a library agent for the new graph
return await create_library_agent(new_graph, user_id)
return (await create_library_agent(new_graph, user_id))[0]
except prisma.errors.PrismaError as e:
logger.error(f"Database error cloning library agent: {e}")
raise store_exceptions.DatabaseError("Failed to fork library agent") from e

View File

@@ -16,7 +16,7 @@ from backend.blocks.ideogram import (
StyleType,
UpscaleOption,
)
from backend.data.graph import Graph
from backend.data.graph import BaseGraph
from backend.data.model import CredentialsMetaInput, ProviderName
from backend.integrations.credentials_store import ideogram_credentials
from backend.util.request import Requests
@@ -34,14 +34,14 @@ class ImageStyle(str, Enum):
DIGITAL_ART = "digital art"
async def generate_agent_image(agent: Graph | AgentGraph) -> io.BytesIO:
async def generate_agent_image(agent: BaseGraph | AgentGraph) -> io.BytesIO:
if settings.config.use_agent_image_generation_v2:
return await generate_agent_image_v2(graph=agent)
else:
return await generate_agent_image_v1(agent=agent)
async def generate_agent_image_v2(graph: Graph | AgentGraph) -> io.BytesIO:
async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO:
"""
Generate an image for an agent using Ideogram model.
Returns:
@@ -99,7 +99,7 @@ async def generate_agent_image_v2(graph: Graph | AgentGraph) -> io.BytesIO:
return io.BytesIO(response.content)
async def generate_agent_image_v1(agent: Graph | AgentGraph) -> io.BytesIO:
async def generate_agent_image_v1(agent: BaseGraph | AgentGraph) -> io.BytesIO:
"""
Generate an image for an agent using Flux model via Replicate API.

View File

@@ -386,8 +386,10 @@ class TestDataCreator:
)
if graph:
# Use the API function to create library agent
library_agent = await create_library_agent(graph, user["id"])
library_agents.append(library_agent.model_dump())
library_agents.extend(
v.model_dump()
for v in await create_library_agent(graph, user["id"])
)
except Exception as e:
print(f"Error creating library agent: {e}")
continue