diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py index 5dd8ffe787..42887d9ff4 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py @@ -10,10 +10,12 @@ from .core import ( decompose_goal, enrich_library_agents_from_steps, extract_search_terms_from_steps, + extract_uuids_from_text, generate_agent, generate_agent_patch, get_agent_as_json, get_all_relevant_agents_for_generation, + get_library_agent_by_graph_id, get_library_agents_for_generation, json_to_graph, save_agent_to_library, @@ -34,10 +36,12 @@ __all__ = [ "decompose_goal", "enrich_library_agents_from_steps", "extract_search_terms_from_steps", + "extract_uuids_from_text", "generate_agent", "generate_agent_patch", "get_agent_as_json", "get_all_relevant_agents_for_generation", + "get_library_agent_by_graph_id", "get_library_agents_for_generation", "get_user_message_for_error", "is_external_service_configured", diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py index 8e1db2ad90..950aa37924 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py @@ -1,6 +1,7 @@ """Core agent generation functions.""" import logging +import re import uuid from typing import Any, TypedDict @@ -104,6 +105,56 @@ def _check_service_configured() -> None: ) +# UUID v4 pattern for extracting agent IDs from text +_UUID_PATTERN = re.compile( + r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}", + re.IGNORECASE, +) + + +def extract_uuids_from_text(text: str) -> list[str]: + """Extract all UUID v4 strings from text. + + Args: + text: Text that may contain UUIDs (e.g., user's goal description) + + Returns: + List of unique UUIDs found in the text (lowercase) + """ + matches = _UUID_PATTERN.findall(text) + # Deduplicate and normalize to lowercase + return list({m.lower() for m in matches}) + + +async def get_library_agent_by_graph_id( + user_id: str, graph_id: str +) -> LibraryAgentSummary | None: + """Fetch a specific library agent by its graph_id. + + Args: + user_id: The user ID + graph_id: The graph ID to look up + + Returns: + LibraryAgentSummary if found, None otherwise + """ + try: + agent = await library_db.get_library_agent_by_graph_id(user_id, graph_id) + if not agent: + return None + return LibraryAgentSummary( + graph_id=agent.graph_id, + graph_version=agent.graph_version, + name=agent.name, + description=agent.description, + input_schema=agent.input_schema, + output_schema=agent.output_schema, + ) + except Exception as e: + logger.debug(f"Could not fetch library agent by graph_id {graph_id}: {e}") + return None + + async def get_library_agents_for_generation( user_id: str, search_query: str | None = None, @@ -207,6 +258,9 @@ async def get_all_relevant_agents_for_generation( Combines search results from user's library and public marketplace, with library agents taking priority (they have full schemas). + Also extracts UUIDs from the search_query and fetches those agents + directly to ensure explicitly referenced agents are included. + Args: user_id: The user ID search_query: Search term to find relevant agents (user's goal/description) @@ -220,15 +274,32 @@ async def get_all_relevant_agents_for_generation( then marketplace agents (basic info only) """ agents: list[AgentSummary] = [] + seen_graph_ids: set[str] = set() - # Get library agents (these have full schemas) + # Extract UUIDs from search_query and fetch those agents directly + # This ensures explicitly referenced agents are always included + if search_query: + mentioned_uuids = extract_uuids_from_text(search_query) + for graph_id in mentioned_uuids: + if graph_id == exclude_graph_id: + continue + agent = await get_library_agent_by_graph_id(user_id, graph_id) + if agent and agent["graph_id"] not in seen_graph_ids: + agents.append(agent) + seen_graph_ids.add(agent["graph_id"]) + logger.debug(f"Found explicitly mentioned agent: {agent['name']}") + + # Get library agents via search (these have full schemas) library_agents = await get_library_agents_for_generation( user_id=user_id, search_query=search_query, exclude_graph_id=exclude_graph_id, max_results=max_library_results, ) - agents.extend(library_agents) + for agent in library_agents: + if agent["graph_id"] not in seen_graph_ids: + agents.append(agent) + seen_graph_ids.add(agent["graph_id"]) # Optionally add marketplace agents if include_marketplace and search_query: @@ -237,10 +308,8 @@ async def get_all_relevant_agents_for_generation( max_results=max_marketplace_results, ) # Add marketplace agents that aren't already in library (by name) - # LibraryAgentSummary always has 'name', so access directly - library_names = {a["name"].lower() for a in library_agents} + library_names = {a["name"].lower() for a in agents if "name" in a} for agent in marketplace_agents: - # MarketplaceAgentSummary always has 'name' if agent["name"].lower() not in library_names: agents.append(agent) diff --git a/autogpt_platform/backend/test/agent_generator/test_library_agents.py b/autogpt_platform/backend/test/agent_generator/test_library_agents.py index 7fff60870d..ef179c6d55 100644 --- a/autogpt_platform/backend/test/agent_generator/test_library_agents.py +++ b/autogpt_platform/backend/test/agent_generator/test_library_agents.py @@ -637,5 +637,150 @@ class TestEnrichLibraryAgentsFromSteps: assert call_count == 3 +class TestExtractUuidsFromText: + """Test extract_uuids_from_text function.""" + + def test_extracts_single_uuid(self): + """Test extraction of a single UUID from text.""" + text = "Use my agent 46631191-e8a8-486f-ad90-84f89738321d for this task" + result = core.extract_uuids_from_text(text) + assert len(result) == 1 + assert "46631191-e8a8-486f-ad90-84f89738321d" in result + + def test_extracts_multiple_uuids(self): + """Test extraction of multiple UUIDs from text.""" + text = ( + "Combine agents 11111111-1111-4111-8111-111111111111 " + "and 22222222-2222-4222-9222-222222222222" + ) + result = core.extract_uuids_from_text(text) + assert len(result) == 2 + assert "11111111-1111-4111-8111-111111111111" in result + assert "22222222-2222-4222-9222-222222222222" in result + + def test_deduplicates_uuids(self): + """Test that duplicate UUIDs are deduplicated.""" + text = ( + "Use 46631191-e8a8-486f-ad90-84f89738321d twice: " + "46631191-e8a8-486f-ad90-84f89738321d" + ) + result = core.extract_uuids_from_text(text) + assert len(result) == 1 + + def test_normalizes_to_lowercase(self): + """Test that UUIDs are normalized to lowercase.""" + text = "Use 46631191-E8A8-486F-AD90-84F89738321D" + result = core.extract_uuids_from_text(text) + assert result[0] == "46631191-e8a8-486f-ad90-84f89738321d" + + def test_returns_empty_for_no_uuids(self): + """Test that empty list is returned when no UUIDs found.""" + text = "Create an email agent that sends notifications" + result = core.extract_uuids_from_text(text) + assert result == [] + + def test_ignores_invalid_uuids(self): + """Test that invalid UUID-like strings are ignored.""" + text = "Not a valid UUID: 12345678-1234-1234-1234-123456789abc" + result = core.extract_uuids_from_text(text) + # UUID v4 requires specific patterns (4 in third group, 8/9/a/b in fourth) + assert len(result) == 0 + + +class TestGetLibraryAgentByGraphId: + """Test get_library_agent_by_graph_id function.""" + + @pytest.mark.asyncio + async def test_returns_agent_when_found(self): + """Test that agent is returned when found by graph_id.""" + mock_agent = MagicMock() + mock_agent.graph_id = "agent-123" + mock_agent.graph_version = 1 + mock_agent.name = "Test Agent" + mock_agent.description = "Test description" + mock_agent.input_schema = {"properties": {}} + mock_agent.output_schema = {"properties": {}} + + with patch.object( + core.library_db, + "get_library_agent_by_graph_id", + new_callable=AsyncMock, + return_value=mock_agent, + ): + result = await core.get_library_agent_by_graph_id("user-123", "agent-123") + + assert result is not None + assert result["graph_id"] == "agent-123" + assert result["name"] == "Test Agent" + + @pytest.mark.asyncio + async def test_returns_none_when_not_found(self): + """Test that None is returned when agent not found.""" + with patch.object( + core.library_db, + "get_library_agent_by_graph_id", + new_callable=AsyncMock, + return_value=None, + ): + result = await core.get_library_agent_by_graph_id("user-123", "nonexistent") + + assert result is None + + @pytest.mark.asyncio + async def test_returns_none_on_exception(self): + """Test that None is returned when exception occurs.""" + with patch.object( + core.library_db, + "get_library_agent_by_graph_id", + new_callable=AsyncMock, + side_effect=Exception("Database error"), + ): + result = await core.get_library_agent_by_graph_id("user-123", "agent-123") + + assert result is None + + +class TestGetAllRelevantAgentsWithUuids: + """Test UUID extraction in get_all_relevant_agents_for_generation.""" + + @pytest.mark.asyncio + async def test_fetches_explicitly_mentioned_agents(self): + """Test that agents mentioned by UUID are fetched directly.""" + mock_agent = MagicMock() + mock_agent.graph_id = "46631191-e8a8-486f-ad90-84f89738321d" + mock_agent.graph_version = 1 + mock_agent.name = "Mentioned Agent" + mock_agent.description = "Explicitly mentioned" + mock_agent.input_schema = {} + mock_agent.output_schema = {} + + mock_response = MagicMock() + mock_response.agents = [] + + with ( + patch.object( + core.library_db, + "get_library_agent_by_graph_id", + new_callable=AsyncMock, + return_value=mock_agent, + ), + patch.object( + core.library_db, + "list_library_agents", + new_callable=AsyncMock, + return_value=mock_response, + ), + ): + result = await core.get_all_relevant_agents_for_generation( + user_id="user-123", + search_query="Use agent 46631191-e8a8-486f-ad90-84f89738321d", + include_marketplace=False, + ) + + assert len(result) == 1 + assert result[0].get("graph_id") == "46631191-e8a8-486f-ad90-84f89738321d" + assert result[0].get("name") == "Mentioned Agent" + + if __name__ == "__main__": pytest.main([__file__, "-v"])