mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
refactor(backend/copilot): use BaseGraph type for graph field
Use BaseGraph instead of Graph to get typed nodes+links without causing the Pydantic OpenAPI schema split. BaseGraph-Input/Output already exists on dev so no frontend imports break. Fetches via graph_db().get_graph().
This commit is contained in:
@@ -10,10 +10,10 @@ if TYPE_CHECKING:
|
||||
from backend.api.features.library.model import LibraryAgent
|
||||
from backend.api.features.store.model import StoreAgent, StoreAgentDetails
|
||||
|
||||
from backend.data.db_accessors import graph_db as get_graph_db
|
||||
from backend.data.db_accessors import library_db, store_db
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .agent_generator import get_agent_as_json
|
||||
from .models import (
|
||||
AgentInfo,
|
||||
AgentsFoundResponse,
|
||||
@@ -222,12 +222,12 @@ async def _enrich_agents_with_graph(agents: list[AgentInfo], user_id: str) -> No
|
||||
|
||||
async def _fetch(agent: AgentInfo) -> None:
|
||||
try:
|
||||
graph_json = await get_agent_as_json(
|
||||
agent.graph_id, user_id # type: ignore[arg-type]
|
||||
graph = await get_graph_db().get_graph(
|
||||
agent.graph_id, version=None, user_id=user_id # type: ignore[arg-type]
|
||||
)
|
||||
if graph_json is None:
|
||||
if graph is None:
|
||||
logger.warning(f"Graph not found for agent {agent.graph_id}")
|
||||
agent.graph = graph_json
|
||||
agent.graph = graph
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch graph for agent {agent.graph_id}: {e}")
|
||||
|
||||
|
||||
@@ -177,20 +177,18 @@ class TestLibraryUUIDLookup:
|
||||
assert response.agents[0].name == "My Library Agent"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_include_graph_fetches_nodes_and_links(self):
|
||||
"""include_graph=True attaches full graph JSON to agent results."""
|
||||
async def test_include_graph_fetches_graph(self):
|
||||
"""include_graph=True attaches BaseGraph to agent results."""
|
||||
from backend.data.graph import BaseGraph
|
||||
|
||||
agent_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
|
||||
mock_agent = self._make_mock_library_agent(agent_id)
|
||||
|
||||
mock_lib_db = MagicMock()
|
||||
mock_lib_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
|
||||
|
||||
fake_graph = {
|
||||
"id": agent_id,
|
||||
"name": "My Library Agent",
|
||||
"nodes": [{"id": "node-1", "block_id": "block-1"}],
|
||||
"links": [{"id": "link-1", "source_id": "node-1", "sink_id": "node-2"}],
|
||||
}
|
||||
fake_graph = BaseGraph(id=agent_id, name="My Library Agent", description="test")
|
||||
mock_graph_db = MagicMock()
|
||||
mock_graph_db.get_graph = AsyncMock(return_value=fake_graph)
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -198,116 +196,118 @@ class TestLibraryUUIDLookup:
|
||||
return_value=mock_lib_db,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.agent_search.get_agent_as_json",
|
||||
new_callable=AsyncMock,
|
||||
return_value=fake_graph,
|
||||
) as mock_get_json,
|
||||
"backend.copilot.tools.agent_search.get_graph_db",
|
||||
return_value=mock_graph_db,
|
||||
),
|
||||
):
|
||||
response = await search_agents(
|
||||
query=agent_id,
|
||||
source="library",
|
||||
session_id="test-session",
|
||||
session_id="s",
|
||||
user_id=_TEST_USER_ID,
|
||||
include_graph=True,
|
||||
)
|
||||
|
||||
assert isinstance(response, AgentsFoundResponse)
|
||||
assert response.agents[0].graph is not None
|
||||
assert response.agents[0].graph["nodes"] == fake_graph["nodes"]
|
||||
assert response.agents[0].graph["links"] == fake_graph["links"]
|
||||
mock_get_json.assert_awaited_once_with(agent_id, _TEST_USER_ID)
|
||||
assert response.agents[0].graph.id == agent_id
|
||||
mock_graph_db.get_graph.assert_awaited_once_with(
|
||||
agent_id, version=None, user_id=_TEST_USER_ID
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_include_graph_false_does_not_fetch(self):
|
||||
async def test_include_graph_false_skips_fetch(self):
|
||||
"""include_graph=False (default) does not fetch graph data."""
|
||||
agent_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
|
||||
mock_agent = self._make_mock_library_agent(agent_id)
|
||||
|
||||
mock_lib_db = MagicMock()
|
||||
mock_lib_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
|
||||
|
||||
mock_graph_db = MagicMock()
|
||||
mock_graph_db.get_graph = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.agent_search.library_db",
|
||||
return_value=mock_lib_db,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.agent_search.get_agent_as_json",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_json,
|
||||
"backend.copilot.tools.agent_search.get_graph_db",
|
||||
return_value=mock_graph_db,
|
||||
),
|
||||
):
|
||||
response = await search_agents(
|
||||
query=agent_id,
|
||||
source="library",
|
||||
session_id="test-session",
|
||||
session_id="s",
|
||||
user_id=_TEST_USER_ID,
|
||||
include_graph=False,
|
||||
)
|
||||
|
||||
assert isinstance(response, AgentsFoundResponse)
|
||||
assert response.agents[0].graph is None
|
||||
mock_get_json.assert_not_awaited()
|
||||
mock_graph_db.get_graph.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_include_graph_handles_fetch_failure(self):
|
||||
"""include_graph=True still returns agents when graph fetch fails."""
|
||||
agent_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
|
||||
mock_agent = self._make_mock_library_agent(agent_id)
|
||||
|
||||
mock_lib_db = MagicMock()
|
||||
mock_lib_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
|
||||
|
||||
mock_graph_db = MagicMock()
|
||||
mock_graph_db.get_graph = AsyncMock(side_effect=Exception("DB down"))
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.agent_search.library_db",
|
||||
return_value=mock_lib_db,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.agent_search.get_agent_as_json",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("DB connection failed"),
|
||||
"backend.copilot.tools.agent_search.get_graph_db",
|
||||
return_value=mock_graph_db,
|
||||
),
|
||||
):
|
||||
response = await search_agents(
|
||||
query=agent_id,
|
||||
source="library",
|
||||
session_id="test-session",
|
||||
session_id="s",
|
||||
user_id=_TEST_USER_ID,
|
||||
include_graph=True,
|
||||
)
|
||||
|
||||
assert isinstance(response, AgentsFoundResponse)
|
||||
assert response.count == 1
|
||||
assert response.agents[0].graph is None
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_include_graph_handles_none_return(self):
|
||||
"""include_graph=True handles get_agent_as_json returning None."""
|
||||
"""include_graph=True handles get_graph returning None."""
|
||||
agent_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
|
||||
mock_agent = self._make_mock_library_agent(agent_id)
|
||||
|
||||
mock_lib_db = MagicMock()
|
||||
mock_lib_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
|
||||
|
||||
mock_graph_db = MagicMock()
|
||||
mock_graph_db.get_graph = AsyncMock(return_value=None)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.agent_search.library_db",
|
||||
return_value=mock_lib_db,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.agent_search.get_agent_as_json",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
"backend.copilot.tools.agent_search.get_graph_db",
|
||||
return_value=mock_graph_db,
|
||||
),
|
||||
):
|
||||
response = await search_agents(
|
||||
query=agent_id,
|
||||
source="library",
|
||||
session_id="test-session",
|
||||
session_id="s",
|
||||
user_id=_TEST_USER_ID,
|
||||
include_graph=True,
|
||||
)
|
||||
|
||||
assert isinstance(response, AgentsFoundResponse)
|
||||
assert response.count == 1
|
||||
assert response.agents[0].graph is None
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.graph import BaseGraph
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
|
||||
|
||||
@@ -122,7 +123,7 @@ class AgentInfo(BaseModel):
|
||||
default=None,
|
||||
description="Input schema for the agent, including field names, types, and defaults",
|
||||
)
|
||||
graph: dict[str, Any] | None = Field(
|
||||
graph: BaseGraph | None = Field(
|
||||
default=None,
|
||||
description="Full graph structure (nodes + links) when include_graph is requested",
|
||||
)
|
||||
|
||||
@@ -7463,10 +7463,9 @@
|
||||
},
|
||||
"graph": {
|
||||
"anyOf": [
|
||||
{ "additionalProperties": true, "type": "object" },
|
||||
{ "$ref": "#/components/schemas/BaseGraph-Output" },
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "Graph",
|
||||
"description": "Full graph structure (nodes + links) when include_graph is requested"
|
||||
}
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user