refactor(backend/copilot): use BaseGraph type for graph field

Use BaseGraph instead of Graph to get typed nodes+links without causing
the Pydantic OpenAPI schema split. BaseGraph-Input/Output already exists
on dev so no frontend imports break. Fetches via graph_db().get_graph().
This commit is contained in:
Zamil Majdy
2026-03-31 17:34:20 +02:00
parent 77fd8648a7
commit c47fcc1925
4 changed files with 45 additions and 45 deletions

View File

@@ -10,10 +10,10 @@ if TYPE_CHECKING:
from backend.api.features.library.model import LibraryAgent
from backend.api.features.store.model import StoreAgent, StoreAgentDetails
from backend.data.db_accessors import graph_db as get_graph_db
from backend.data.db_accessors import library_db, store_db
from backend.util.exceptions import DatabaseError, NotFoundError
from .agent_generator import get_agent_as_json
from .models import (
AgentInfo,
AgentsFoundResponse,
@@ -222,12 +222,12 @@ async def _enrich_agents_with_graph(agents: list[AgentInfo], user_id: str) -> No
async def _fetch(agent: AgentInfo) -> None:
try:
graph_json = await get_agent_as_json(
agent.graph_id, user_id # type: ignore[arg-type]
graph = await get_graph_db().get_graph(
agent.graph_id, version=None, user_id=user_id # type: ignore[arg-type]
)
if graph_json is None:
if graph is None:
logger.warning(f"Graph not found for agent {agent.graph_id}")
agent.graph = graph_json
agent.graph = graph
except Exception as e:
logger.warning(f"Failed to fetch graph for agent {agent.graph_id}: {e}")

View File

@@ -177,20 +177,18 @@ class TestLibraryUUIDLookup:
assert response.agents[0].name == "My Library Agent"
@pytest.mark.asyncio(loop_scope="session")
async def test_include_graph_fetches_nodes_and_links(self):
"""include_graph=True attaches full graph JSON to agent results."""
async def test_include_graph_fetches_graph(self):
"""include_graph=True attaches BaseGraph to agent results."""
from backend.data.graph import BaseGraph
agent_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
mock_agent = self._make_mock_library_agent(agent_id)
mock_lib_db = MagicMock()
mock_lib_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
fake_graph = {
"id": agent_id,
"name": "My Library Agent",
"nodes": [{"id": "node-1", "block_id": "block-1"}],
"links": [{"id": "link-1", "source_id": "node-1", "sink_id": "node-2"}],
}
fake_graph = BaseGraph(id=agent_id, name="My Library Agent", description="test")
mock_graph_db = MagicMock()
mock_graph_db.get_graph = AsyncMock(return_value=fake_graph)
with (
patch(
@@ -198,116 +196,118 @@ class TestLibraryUUIDLookup:
return_value=mock_lib_db,
),
patch(
"backend.copilot.tools.agent_search.get_agent_as_json",
new_callable=AsyncMock,
return_value=fake_graph,
) as mock_get_json,
"backend.copilot.tools.agent_search.get_graph_db",
return_value=mock_graph_db,
),
):
response = await search_agents(
query=agent_id,
source="library",
session_id="test-session",
session_id="s",
user_id=_TEST_USER_ID,
include_graph=True,
)
assert isinstance(response, AgentsFoundResponse)
assert response.agents[0].graph is not None
assert response.agents[0].graph["nodes"] == fake_graph["nodes"]
assert response.agents[0].graph["links"] == fake_graph["links"]
mock_get_json.assert_awaited_once_with(agent_id, _TEST_USER_ID)
assert response.agents[0].graph.id == agent_id
mock_graph_db.get_graph.assert_awaited_once_with(
agent_id, version=None, user_id=_TEST_USER_ID
)
@pytest.mark.asyncio(loop_scope="session")
async def test_include_graph_false_does_not_fetch(self):
async def test_include_graph_false_skips_fetch(self):
"""include_graph=False (default) does not fetch graph data."""
agent_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
mock_agent = self._make_mock_library_agent(agent_id)
mock_lib_db = MagicMock()
mock_lib_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
mock_graph_db = MagicMock()
mock_graph_db.get_graph = AsyncMock()
with (
patch(
"backend.copilot.tools.agent_search.library_db",
return_value=mock_lib_db,
),
patch(
"backend.copilot.tools.agent_search.get_agent_as_json",
new_callable=AsyncMock,
) as mock_get_json,
"backend.copilot.tools.agent_search.get_graph_db",
return_value=mock_graph_db,
),
):
response = await search_agents(
query=agent_id,
source="library",
session_id="test-session",
session_id="s",
user_id=_TEST_USER_ID,
include_graph=False,
)
assert isinstance(response, AgentsFoundResponse)
assert response.agents[0].graph is None
mock_get_json.assert_not_awaited()
mock_graph_db.get_graph.assert_not_awaited()
@pytest.mark.asyncio(loop_scope="session")
async def test_include_graph_handles_fetch_failure(self):
"""include_graph=True still returns agents when graph fetch fails."""
agent_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
mock_agent = self._make_mock_library_agent(agent_id)
mock_lib_db = MagicMock()
mock_lib_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
mock_graph_db = MagicMock()
mock_graph_db.get_graph = AsyncMock(side_effect=Exception("DB down"))
with (
patch(
"backend.copilot.tools.agent_search.library_db",
return_value=mock_lib_db,
),
patch(
"backend.copilot.tools.agent_search.get_agent_as_json",
new_callable=AsyncMock,
side_effect=Exception("DB connection failed"),
"backend.copilot.tools.agent_search.get_graph_db",
return_value=mock_graph_db,
),
):
response = await search_agents(
query=agent_id,
source="library",
session_id="test-session",
session_id="s",
user_id=_TEST_USER_ID,
include_graph=True,
)
assert isinstance(response, AgentsFoundResponse)
assert response.count == 1
assert response.agents[0].graph is None
@pytest.mark.asyncio(loop_scope="session")
async def test_include_graph_handles_none_return(self):
"""include_graph=True handles get_agent_as_json returning None."""
"""include_graph=True handles get_graph returning None."""
agent_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
mock_agent = self._make_mock_library_agent(agent_id)
mock_lib_db = MagicMock()
mock_lib_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
mock_graph_db = MagicMock()
mock_graph_db.get_graph = AsyncMock(return_value=None)
with (
patch(
"backend.copilot.tools.agent_search.library_db",
return_value=mock_lib_db,
),
patch(
"backend.copilot.tools.agent_search.get_agent_as_json",
new_callable=AsyncMock,
return_value=None,
"backend.copilot.tools.agent_search.get_graph_db",
return_value=mock_graph_db,
),
):
response = await search_agents(
query=agent_id,
source="library",
session_id="test-session",
session_id="s",
user_id=_TEST_USER_ID,
include_graph=True,
)
assert isinstance(response, AgentsFoundResponse)
assert response.count == 1
assert response.agents[0].graph is None

View File

@@ -6,6 +6,7 @@ from typing import Any, Literal
from pydantic import BaseModel, Field
from backend.data.graph import BaseGraph
from backend.data.model import CredentialsMetaInput
@@ -122,7 +123,7 @@ class AgentInfo(BaseModel):
default=None,
description="Input schema for the agent, including field names, types, and defaults",
)
graph: dict[str, Any] | None = Field(
graph: BaseGraph | None = Field(
default=None,
description="Full graph structure (nodes + links) when include_graph is requested",
)

View File

@@ -7463,10 +7463,9 @@
},
"graph": {
"anyOf": [
{ "additionalProperties": true, "type": "object" },
{ "$ref": "#/components/schemas/BaseGraph-Output" },
{ "type": "null" }
],
"title": "Graph",
"description": "Full graph structure (nodes + links) when include_graph is requested"
}
},