feat(backend): Support sub-agent on export/import agent feature (#9640)

Agents using Agent blocks should be seamlessly downloaded from the
marketplace to a file and imported from a file.

Requirements:
* A recursive export process that exports all the required agents to a
single file, no matter how many layers deep (taking care of potential
loops).
* An import process that expects and extracts several agents from a
single file into your library at once.

Considerations:
We need to ensure the reference IDs in the Agent Blocks match/are
updated to match the imported sub-agent ids to prevent broken
references.

### Changes 🏗️

* Add sub_graphs field on Graph model 
* Improve graph creation query to support inserting graph + subgraphs in
batch
* Deprecate graph template & remove its column
* Update on marketplace download agent (unified the used method, with
more secure cleanup & proper ownership check).
* Fix failing test cases

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] Export graph with sub agents.
  - [x] Import the exported graph with sub agents.
This commit is contained in:
Zamil Majdy
2025-03-17 23:38:27 +07:00
parent 596b29f53a
commit 17f3a19bc3
14 changed files with 284 additions and 239 deletions

View File

@@ -1,4 +1,3 @@
import asyncio
import logging
import uuid
from collections import defaultdict
@@ -123,6 +122,12 @@ class NodeModel(Node):
stripped_node.input_default, self.block.input_schema.jsonschema()
)
if (
stripped_node.block.block_type == BlockType.INPUT
and "value" in stripped_node.input_default
):
stripped_node.input_default["value"] = ""
# Remove webhook info
stripped_node.webhook_id = None
stripped_node.webhook = None
@@ -249,10 +254,9 @@ class GraphExecution(GraphExecutionMeta):
)
class Graph(BaseDbModel):
class BaseGraph(BaseDbModel):
version: int = 1
is_active: bool = True
is_template: bool = False
name: str
description: str
nodes: list[Node] = []
@@ -315,6 +319,10 @@ class Graph(BaseDbModel):
}
class Graph(BaseGraph):
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs, only used in export
class GraphModel(Graph):
user_id: str
nodes: list[NodeModel] = [] # type: ignore
@@ -338,31 +346,55 @@ class GraphModel(Graph):
Reassigns all IDs in the graph to new UUIDs.
This method can be used before storing a new graph to the database.
"""
if reassign_graph_id:
graph_id_map = {
self.id: str(uuid.uuid4()),
**{sub_graph.id: str(uuid.uuid4()) for sub_graph in self.sub_graphs},
}
else:
graph_id_map = {}
self._reassign_ids(self, user_id, graph_id_map)
for sub_graph in self.sub_graphs:
self._reassign_ids(sub_graph, user_id, graph_id_map)
@staticmethod
def _reassign_ids(
graph: BaseGraph,
user_id: str,
graph_id_map: dict[str, str],
):
# Reassign Graph ID
id_map = {node.id: str(uuid.uuid4()) for node in self.nodes}
if reassign_graph_id:
self.id = str(uuid.uuid4())
if graph.id in graph_id_map:
graph.id = graph_id_map[graph.id]
# Reassign Node IDs
for node in self.nodes:
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
for node in graph.nodes:
node.id = id_map[node.id]
# Reassign Link IDs
for link in self.links:
for link in graph.links:
link.source_id = id_map[link.source_id]
link.sink_id = id_map[link.sink_id]
# Reassign User IDs for agent blocks
for node in self.nodes:
for node in graph.nodes:
if node.block_id != AgentExecutorBlock().id:
continue
node.input_default["user_id"] = user_id
node.input_default.setdefault("data", {})
self.validate_graph()
if (graph_id := node.input_default.get("graph_id")) in graph_id_map:
node.input_default["graph_id"] = graph_id_map[graph_id]
def validate_graph(self, for_run: bool = False):
self._validate_graph(self, for_run)
for sub_graph in self.sub_graphs:
self._validate_graph(sub_graph, for_run)
@staticmethod
def _validate_graph(graph: BaseGraph, for_run: bool = False):
def sanitize(name):
sanitized_name = name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
if sanitized_name.startswith("tools_^_"):
@@ -374,11 +406,11 @@ class GraphModel(Graph):
agent_nodes = set()
nodes_block = {
node.id: block
for node in self.nodes
for node in graph.nodes
if (block := get_block(node.block_id)) is not None
}
for node in self.nodes:
for node in graph.nodes:
if (block := nodes_block.get(node.id)) is None:
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
@@ -391,11 +423,11 @@ class GraphModel(Graph):
input_links = defaultdict(list)
for link in self.links:
for link in graph.links:
input_links[link.sink_id].append(link)
# Nodes: required fields are filled or connected and dependencies are satisfied
for node in self.nodes:
for node in graph.nodes:
if (block := nodes_block.get(node.id)) is None:
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
@@ -456,7 +488,7 @@ class GraphModel(Graph):
f"Node {block.name} #{node.id}: Field `{field_name}` requires [{', '.join(missing_deps)}] to be set"
)
node_map = {v.id: v for v in self.nodes}
node_map = {v.id: v for v in graph.nodes}
def is_static_output_block(nid: str) -> bool:
bid = node_map[nid].block_id
@@ -464,7 +496,7 @@ class GraphModel(Graph):
return b.static_output if b else False
# Links: links are connected and the connected pin data type are compatible.
for link in self.links:
for link in graph.links:
source = (link.source_id, link.source_name)
sink = (link.sink_id, link.sink_name)
prefix = f"Link {source} <-> {sink}"
@@ -505,13 +537,16 @@ class GraphModel(Graph):
link.is_static = True # Each value block output should be static.
@staticmethod
def from_db(graph: AgentGraph, for_export: bool = False):
def from_db(
graph: AgentGraph,
for_export: bool = False,
sub_graphs: list[AgentGraph] | None = None,
):
return GraphModel(
id=graph.id,
user_id=graph.userId,
user_id=graph.userId if not for_export else "",
version=graph.version,
is_active=graph.isActive,
is_template=graph.isTemplate,
name=graph.name or "",
description=graph.description or "",
nodes=[
@@ -524,28 +559,12 @@ class GraphModel(Graph):
for link in (node.Input or []) + (node.Output or [])
}
),
sub_graphs=[
GraphModel.from_db(sub_graph, for_export)
for sub_graph in sub_graphs or []
],
)
def clean_graph(self):
blocks = [block() for block in get_blocks().values()]
input_blocks = [
node
for node in self.nodes
if next(
(
b
for b in blocks
if b.id == node.block_id and b.block_type == BlockType.INPUT
),
None,
)
]
for node in self.nodes:
if any(input_block.id == node.id for input_block in input_blocks):
node.input_default["value"] = ""
# --------------------- CRUD functions --------------------- #
@@ -575,14 +594,14 @@ async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:
async def get_graphs(
user_id: str,
filter_by: Literal["active", "template"] | None = "active",
filter_by: Literal["active"] | None = "active",
) -> list[GraphModel]:
"""
Retrieves graph metadata objects.
Default behaviour is to get all currently active graphs.
Args:
filter_by: An optional filter to either select templates or active graphs.
filter_by: An optional filter to either select graphs.
user_id: The ID of the user that owns the graph.
Returns:
@@ -592,8 +611,6 @@ async def get_graphs(
if filter_by == "active":
where_clause["isActive"] = True
elif filter_by == "template":
where_clause["isTemplate"] = True
graphs = await AgentGraph.prisma().find_many(
where=where_clause,
@@ -682,21 +699,18 @@ async def get_graph_metadata(graph_id: str, version: int | None = None) -> Graph
description=graph.description or "",
version=graph.version,
is_active=graph.isActive,
is_template=graph.isTemplate,
)
async def get_graph(
graph_id: str,
version: int | None = None,
template: bool = False, # note: currently not in use; TODO: remove from DB entirely
user_id: str | None = None,
for_export: bool = False,
) -> GraphModel | None:
"""
Retrieves a graph from the DB.
Defaults to the version with `is_active` if `version` is not passed,
or the latest version with `is_template` if `template=True`.
Defaults to the version with `is_active` if `version` is not passed.
Returns `None` if the record is not found.
"""
@@ -706,8 +720,6 @@ async def get_graph(
if version is not None:
where_clause["version"] = version
elif not template:
where_clause["isActive"] = True
graph = await AgentGraph.prisma().find_first(
where=where_clause,
@@ -731,9 +743,62 @@ async def get_graph(
):
return None
if for_export:
sub_graphs = await _get_sub_graphs(graph)
return GraphModel.from_db(
graph=graph,
sub_graphs=sub_graphs,
for_export=for_export,
)
return GraphModel.from_db(graph, for_export)
async def _get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
"""
Iteratively fetches all sub-graphs of a given graph, and flattens them into a list.
This call involves a DB fetch in batch, breadth-first, per-level of graph depth.
On each DB fetch we will only fetch the sub-graphs that are not already in the list.
"""
sub_graphs = {graph.id: graph}
search_graphs = [graph]
agent_block_id = AgentExecutorBlock().id
while search_graphs:
sub_graph_ids = [
(graph_id, graph_version)
for graph in search_graphs
for node in graph.AgentNodes or []
if (
node.AgentBlock
and node.AgentBlock.id == agent_block_id
and (graph_id := dict(node.constantInput).get("graph_id"))
and (graph_version := dict(node.constantInput).get("graph_version"))
)
]
if not sub_graph_ids:
break
graphs = await AgentGraph.prisma().find_many(
where={
"OR": [
{
"id": graph_id,
"version": graph_version,
"userId": graph.userId, # Ensure the sub-graph is owned by the same user
}
for graph_id, graph_version in sub_graph_ids
] # type: ignore
},
include=AGENT_GRAPH_INCLUDE,
)
search_graphs = [graph for graph in graphs if graph.id not in sub_graphs]
sub_graphs.update({graph.id: graph for graph in search_graphs})
return [g for g in sub_graphs.values() if g.id != graph.id]
async def get_connected_output_nodes(node_id: str) -> list[tuple[Link, Node]]:
links = await AgentNodeLink.prisma().find_many(
where={"agentNodeSourceId": node_id},
@@ -797,50 +862,56 @@ async def create_graph(graph: Graph, user_id: str) -> GraphModel:
async with transaction() as tx:
await __create_graph(tx, graph, user_id)
if created_graph := await get_graph(
graph.id, graph.version, template=graph.is_template, user_id=user_id
):
if created_graph := await get_graph(graph.id, graph.version, user_id=user_id):
return created_graph
raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
async def __create_graph(tx, graph: Graph, user_id: str):
await AgentGraph.prisma(tx).create(
data={
"id": graph.id,
"version": graph.version,
"name": graph.name,
"description": graph.description,
"isTemplate": graph.is_template,
"isActive": graph.is_active,
"userId": user_id,
"AgentNodes": {
"create": [
{
"id": node.id,
"agentBlockId": node.block_id,
"constantInput": Json(node.input_default),
"metadata": Json(node.metadata),
}
for node in graph.nodes
]
},
}
graphs = [graph] + graph.sub_graphs
await AgentGraph.prisma(tx).create_many(
data=[
{
"id": graph.id,
"version": graph.version,
"name": graph.name,
"description": graph.description,
"isActive": graph.is_active,
"userId": user_id,
}
for graph in graphs
]
)
await asyncio.gather(
*[
AgentNodeLink.prisma(tx).create(
{
"id": str(uuid.uuid4()),
"sourceName": link.source_name,
"sinkName": link.sink_name,
"agentNodeSourceId": link.source_id,
"agentNodeSinkId": link.sink_id,
"isStatic": link.is_static,
}
)
await AgentNode.prisma(tx).create_many(
data=[
{
"id": node.id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"agentBlockId": node.block_id,
"constantInput": Json(node.input_default),
"metadata": Json(node.metadata),
"webhookId": node.webhook_id,
}
for graph in graphs
for node in graph.nodes
]
)
await AgentNodeLink.prisma(tx).create_many(
data=[
{
"id": str(uuid.uuid4()),
"sourceName": link.source_name,
"sinkName": link.sink_name,
"agentNodeSourceId": link.source_id,
"agentNodeSinkId": link.sink_id,
"isStatic": link.is_static,
}
for graph in graphs
for link in graph.links
]
)

View File

@@ -154,9 +154,10 @@ class AgentServer(backend.util.service.AppProcess):
graph_id: str,
graph_version: int,
user_id: str,
for_export: bool = False,
):
return await backend.server.routers.v1.get_graph(
graph_id, user_id, graph_version
graph_id, user_id, graph_version, for_export
)
@staticmethod

View File

@@ -396,10 +396,10 @@ async def get_graph(
graph_id: str,
user_id: Annotated[str, Depends(get_user_id)],
version: int | None = None,
hide_credentials: bool = False,
for_export: bool = False,
) -> graph_db.GraphModel:
graph = await graph_db.get_graph(
graph_id, version, user_id=user_id, for_export=hide_credentials
graph_id, version, user_id=user_id, for_export=for_export
)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
@@ -429,6 +429,7 @@ async def create_new_graph(
) -> graph_db.GraphModel:
graph = graph_db.make_graph_model(create_graph.graph, user_id)
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
graph.validate_graph(for_run=False)
graph = await graph_db.create_graph(graph, user_id=user_id)
@@ -480,17 +481,10 @@ async def update_graph(
latest_version_number = max(g.version for g in existing_versions)
graph.version = latest_version_number + 1
latest_version_graph = next(
v for v in existing_versions if v.version == latest_version_number
)
current_active_version = next((v for v in existing_versions if v.is_active), None)
if latest_version_graph.is_template != graph.is_template:
raise HTTPException(
400, detail="Changing is_template on an existing graph is forbidden"
)
graph.is_active = not graph.is_template
graph = graph_db.make_graph_model(graph, user_id)
graph.reassign_ids(user_id=user_id)
graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
graph.validate_graph(for_run=False)
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)

View File

@@ -3,20 +3,11 @@ from datetime import datetime
import prisma.errors
import prisma.models
import pytest
from prisma import Prisma
import backend.server.v2.library.db as db
import backend.server.v2.store.exceptions
@pytest.fixture(autouse=True)
async def setup_prisma():
# Don't register client if already registered
try:
Prisma()
except prisma.errors.ClientAlreadyRegisteredError:
pass
yield
from backend.data.db import connect
from backend.data.includes import library_agent_include
@pytest.mark.asyncio
@@ -31,7 +22,6 @@ async def test_get_library_agents(mocker):
userId="test-user",
isActive=True,
createdAt=datetime.now(),
isTemplate=False,
)
]
@@ -56,7 +46,6 @@ async def test_get_library_agents(mocker):
userId="other-user",
isActive=True,
createdAt=datetime.now(),
isTemplate=False,
),
)
]
@@ -91,10 +80,11 @@ async def test_get_library_agents(mocker):
assert result.pagination.page_size == 50
@pytest.mark.asyncio
@pytest.mark.asyncio(scope="session")
async def test_add_agent_to_library(mocker):
await connect()
# Mock data
mock_store_listing = prisma.models.StoreListingVersion(
mock_store_listing_data = prisma.models.StoreListingVersion(
id="version123",
version=1,
createdAt=datetime.now(),
@@ -119,21 +109,37 @@ async def test_add_agent_to_library(mocker):
userId="creator",
isActive=True,
createdAt=datetime.now(),
isTemplate=False,
),
)
mock_library_agent_data = prisma.models.LibraryAgent(
id="ua1",
userId="test-user",
agentId=mock_store_listing_data.agentId,
agentVersion=1,
isCreatedByUser=False,
isDeleted=False,
isArchived=False,
createdAt=datetime.now(),
updatedAt=datetime.now(),
isFavorite=False,
useGraphIsActiveVersion=True,
Agent=mock_store_listing_data.Agent,
)
# Mock prisma calls
mock_store_listing_version = mocker.patch(
"prisma.models.StoreListingVersion.prisma"
)
mock_store_listing_version.return_value.find_unique = mocker.AsyncMock(
return_value=mock_store_listing
return_value=mock_store_listing_data
)
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.create = mocker.AsyncMock()
mock_library_agent.return_value.create = mocker.AsyncMock(
return_value=mock_library_agent_data
)
# Call function
await db.add_store_agent_to_library("version123", "test-user")
@@ -147,17 +153,20 @@ async def test_add_agent_to_library(mocker):
"userId": "test-user",
"agentId": "agent1",
"agentVersion": 1,
}
},
include=library_agent_include("test-user"),
)
mock_library_agent.return_value.create.assert_called_once_with(
data=prisma.types.LibraryAgentCreateInput(
userId="test-user", agentId="agent1", agentVersion=1, isCreatedByUser=False
)
),
include=library_agent_include("test-user"),
)
@pytest.mark.asyncio
@pytest.mark.asyncio(scope="session")
async def test_add_agent_to_library_not_found(mocker):
await connect()
# Mock prisma calls
mock_store_listing_version = mocker.patch(
"prisma.models.StoreListingVersion.prisma"

View File

@@ -2,11 +2,14 @@ import datetime
import prisma.fields
import prisma.models
import pytest
import backend.server.v2.library.model as library_model
from backend.util import json
def test_agent_preset_from_db():
@pytest.mark.asyncio
async def test_agent_preset_from_db():
# Create mock DB agent
db_agent = prisma.models.AgentPreset(
id="test-agent-123",
@@ -24,7 +27,7 @@ def test_agent_preset_from_db():
id="input-123",
time=datetime.datetime.now(),
name="input1",
data=prisma.fields.Json({"type": "string", "value": "test value"}),
data=json.dumps({"type": "string", "value": "test value"}), # type: ignore
)
],
)

View File

@@ -1,7 +1,6 @@
import datetime
import autogpt_libs.auth as autogpt_auth_lib
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
@@ -30,49 +29,48 @@ app.dependency_overrides[autogpt_auth_lib.auth_middleware] = override_auth_middl
app.dependency_overrides[autogpt_auth_lib.depends.get_user_id] = override_get_user_id
def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
mocked_value = [
library_model.LibraryAgentResponse(
agents=[
library_model.LibraryAgent(
id="test-agent-1",
agent_id="test-agent-1",
agent_version=1,
name="Test Agent 1",
description="Test Description 1",
image_url=None,
creator_name="Test Creator",
creator_image_url="",
input_schema={"type": "object", "properties": {}},
status=library_model.LibraryAgentStatus.COMPLETED,
new_output=False,
can_access_graph=True,
is_latest_version=True,
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
),
library_model.LibraryAgent(
id="test-agent-2",
agent_id="test-agent-2",
agent_version=1,
name="Test Agent 2",
description="Test Description 2",
image_url=None,
creator_name="Test Creator",
creator_image_url="",
input_schema={"type": "object", "properties": {}},
status=library_model.LibraryAgentStatus.COMPLETED,
new_output=False,
can_access_graph=False,
is_latest_version=True,
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
),
],
pagination=server_model.Pagination(
total_items=2, total_pages=1, current_page=1, page_size=50
@pytest.mark.asyncio
async def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
mocked_value = library_model.LibraryAgentResponse(
agents=[
library_model.LibraryAgent(
id="test-agent-1",
agent_id="test-agent-1",
agent_version=1,
name="Test Agent 1",
description="Test Description 1",
image_url=None,
creator_name="Test Creator",
creator_image_url="",
input_schema={"type": "object", "properties": {}},
status=library_model.LibraryAgentStatus.COMPLETED,
new_output=False,
can_access_graph=True,
is_latest_version=True,
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
),
library_model.LibraryAgent(
id="test-agent-2",
agent_id="test-agent-2",
agent_version=1,
name="Test Agent 2",
description="Test Description 2",
image_url=None,
creator_name="Test Creator",
creator_image_url="",
input_schema={"type": "object", "properties": {}},
status=library_model.LibraryAgentStatus.COMPLETED,
new_output=False,
can_access_graph=False,
is_latest_version=True,
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
),
],
pagination=server_model.Pagination(
total_items=2, total_pages=1, current_page=1, page_size=50
),
]
mock_db_call = mocker.patch("backend.server.v2.library.db.get_library_agents")
)
mock_db_call = mocker.patch("backend.server.v2.library.db.list_library_agents")
mock_db_call.return_value = mocked_value
response = client.get("/agents?search_term=test")
@@ -94,7 +92,7 @@ def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
def test_get_library_agents_error(mocker: pytest_mock.MockFixture):
mock_db_call = mocker.patch("backend.server.v2.library.db.get_library_agents")
mock_db_call = mocker.patch("backend.server.v2.library.db.list_library_agents")
mock_db_call.side_effect = Exception("Test error")
response = client.get("/agents?search_term=test")

View File

@@ -1,6 +1,5 @@
import logging
from datetime import datetime
from typing import Optional
import fastapi
import prisma.enums
@@ -754,47 +753,31 @@ async def get_my_agents(
async def get_agent(
store_listing_version_id: str, version_id: Optional[int]
user_id: str,
store_listing_version_id: str,
) -> GraphModel:
"""Get agent using the version ID and store listing version ID."""
try:
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id}, include={"Agent": True}
)
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id}
)
)
if not store_listing_version:
raise ValueError(f"Store listing version {store_listing_version_id} not found")
graph = await backend.data.graph.get_graph(
user_id=user_id,
graph_id=store_listing_version.agentId,
version=store_listing_version.agentVersion,
for_export=True,
)
if not graph:
raise ValueError(
f"Agent {store_listing_version.agentId} v{store_listing_version.agentVersion} not found"
)
if not store_listing_version or not store_listing_version.Agent:
raise fastapi.HTTPException(
status_code=404,
detail=f"Store listing version {store_listing_version_id} not found",
)
graph_id = store_listing_version.agentId
graph_version = store_listing_version.agentVersion
graph = await backend.data.graph.get_graph(graph_id, graph_version)
if not graph:
raise fastapi.HTTPException(
status_code=404,
detail=(
f"Agent #{graph_id} not found "
f"for store listing version #{store_listing_version_id}"
),
)
graph.version = 1
graph.is_template = False
graph.is_active = True
delattr(graph, "user_id")
return graph
except Exception as e:
logger.error(f"Error getting agent: {e}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch agent"
) from e
return graph
async def review_store_submission(

View File

@@ -146,7 +146,6 @@ async def test_create_store_submission(mocker):
userId="user-id",
createdAt=datetime.now(),
isActive=True,
isTemplate=False,
)
mock_listing = prisma.models.StoreListing(

View File

@@ -1,4 +1,3 @@
import json
import logging
import tempfile
import typing
@@ -8,7 +7,6 @@ import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import fastapi
import fastapi.responses
from fastapi.encoders import jsonable_encoder
import backend.data.block
import backend.data.graph
@@ -16,6 +14,7 @@ import backend.server.v2.store.db
import backend.server.v2.store.image_gen
import backend.server.v2.store.media
import backend.server.v2.store.model
import backend.util.json
logger = logging.getLogger(__name__)
@@ -591,19 +590,18 @@ async def generate_image(
tags=["store", "public"],
)
async def download_agent_file(
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
store_listing_version_id: str = fastapi.Path(
..., description="The ID of the agent to download"
),
version: typing.Optional[int] = fastapi.Query(
None, description="Specific version of the agent"
),
) -> fastapi.responses.FileResponse:
"""
Download the agent file by streaming its content.
Args:
agent_id (str): The ID of the agent to download.
version (Optional[int]): Specific version of the agent to download.
store_listing_version_id (str): The ID of the agent to download
Returns:
StreamingResponse: A streaming response containing the agent's graph data.
@@ -613,35 +611,16 @@ async def download_agent_file(
"""
graph_data = await backend.server.v2.store.db.get_agent(
store_listing_version_id=store_listing_version_id, version_id=version
user_id=user_id,
store_listing_version_id=store_listing_version_id,
)
graph_data.clean_graph()
graph_date_dict = jsonable_encoder(graph_data)
def remove_credentials(obj):
if obj and isinstance(obj, dict):
if "credentials" in obj:
del obj["credentials"]
if "creds" in obj:
del obj["creds"]
for value in obj.values():
remove_credentials(value)
elif isinstance(obj, list):
for item in obj:
remove_credentials(item)
return obj
graph_date_dict = remove_credentials(graph_date_dict)
file_name = f"agent_{store_listing_version_id}_v{version or 'latest'}.json"
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
# Sending graph as a stream (similar to marketplace v1)
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as tmp_file:
tmp_file.write(json.dumps(graph_date_dict))
tmp_file.write(backend.util.json.dumps(graph_data))
tmp_file.flush()
return fastapi.responses.FileResponse(

View File

@@ -0,0 +1,8 @@
/*
Warnings:
- You are about to drop the column `isTemplate` on the `AgentGraph` table. All the data in the column will be lost.
*/
-- AlterTable
ALTER TABLE "AgentGraph" DROP COLUMN "isTemplate";

View File

@@ -87,7 +87,6 @@ model AgentGraph {
description String?
isActive Boolean @default(true)
isTemplate Boolean @default(false)
// Link to User model
userId String

View File

@@ -199,7 +199,9 @@ async def test_clean_graph(server: SpinTestServer):
)
# Clean the graph
created_graph.clean_graph()
created_graph = await server.agent_server.test_get_graph(
created_graph.id, created_graph.version, DEFAULT_USER_ID, for_export=True
)
# # Verify input block value is cleared
input_node = next(

View File

@@ -91,7 +91,6 @@ async def main():
"description": faker.text(max_nb_chars=200),
"userId": user.id,
"isActive": True,
"isTemplate": False,
}
)
agent_graphs.append(graph)

View File

@@ -197,14 +197,14 @@ export default class BackendAPI {
getGraph(
id: GraphID,
version?: number,
hide_credentials?: boolean,
for_export?: boolean,
): Promise<Graph> {
let query: Record<string, any> = {};
if (version !== undefined) {
query["version"] = version;
}
if (hide_credentials !== undefined) {
query["hide_credentials"] = hide_credentials;
if (for_export !== undefined) {
query["for_export"] = for_export;
}
return this._get(`/graphs/${id}`, query);
}