feat: add two-phase library search for better sub-agent discovery

- Add TypedDict types for agent summaries (LibraryAgentSummary, MarketplaceAgentSummary, DecompositionResult)
- Add extract_search_terms_from_steps() to extract keywords from decomposed instructions
- Add enrich_library_agents_from_steps() for two-phase search after decomposition
- Integrate enrichment into create_agent.py flow
- Add comprehensive tests for new functionality
This commit is contained in:
Zamil Majdy
2026-01-29 18:51:07 -06:00
parent a3fe1ede55
commit c039a2e3ad
4 changed files with 578 additions and 46 deletions

View File

@@ -1,8 +1,15 @@
"""Agent generator package - Creates agents from natural language."""
from .core import (
from .core import ( # Types; Exceptions; Functions
AgentGeneratorNotConfiguredError,
AgentSummary,
DecompositionResult,
DecompositionStep,
LibraryAgentSummary,
MarketplaceAgentSummary,
decompose_goal,
enrich_library_agents_from_steps,
extract_search_terms_from_steps,
generate_agent,
generate_agent_patch,
get_agent_as_json,
@@ -17,6 +24,12 @@ from .service import health_check as check_external_service_health
from .service import is_external_service_configured
__all__ = [
# Types
"AgentSummary",
"DecompositionResult",
"DecompositionStep",
"LibraryAgentSummary",
"MarketplaceAgentSummary",
# Core functions
"decompose_goal",
"generate_agent",
@@ -26,6 +39,8 @@ __all__ = [
"get_library_agents_for_generation",
"get_all_relevant_agents_for_generation",
"search_marketplace_agents_for_generation",
"enrich_library_agents_from_steps",
"extract_search_terms_from_steps",
"json_to_graph",
# Exceptions
"AgentGeneratorNotConfiguredError",

View File

@@ -2,7 +2,7 @@
import logging
import uuid
from typing import Any
from typing import Any, TypedDict
from backend.api.features.library import db as library_db
from backend.data.graph import Graph, Link, Node, create_graph
@@ -17,6 +17,65 @@ from .service import (
logger = logging.getLogger(__name__)
# =============================================================================
# Type Definitions
# =============================================================================
class LibraryAgentSummary(TypedDict):
"""Summary of a library agent for sub-agent composition."""
graph_id: str
graph_version: int
name: str
description: str
input_schema: dict[str, Any]
output_schema: dict[str, Any]
class MarketplaceAgentSummary(TypedDict):
"""Summary of a marketplace agent for sub-agent composition."""
name: str
description: str
sub_heading: str
creator: str
is_marketplace_agent: bool
class DecompositionStep(TypedDict, total=False):
"""A single step in decomposed instructions."""
description: str
action: str
block_name: str
tool: str
name: str
class DecompositionResult(TypedDict, total=False):
"""Result from decompose_goal - can be instructions, questions, or error."""
type: str # "instructions", "clarifying_questions", "error", etc.
steps: list[DecompositionStep]
questions: list[dict[str, Any]]
error: str
error_type: str
# Type alias for agent summaries (can be either library or marketplace)
AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
def _to_dict_list(
agents: list[AgentSummary] | list[dict[str, Any]] | None,
) -> list[dict[str, Any]] | None:
"""Convert typed agent summaries to plain dicts for external service calls."""
if agents is None:
return None
return [dict(a) for a in agents]
class AgentGeneratorNotConfiguredError(Exception):
"""Raised when the external Agent Generator service is not configured."""
@@ -41,7 +100,7 @@ async def get_library_agents_for_generation(
search_query: str | None = None,
exclude_graph_id: str | None = None,
max_results: int = 15,
) -> list[dict[str, Any]]:
) -> list[LibraryAgentSummary]:
"""Fetch user's library agents formatted for Agent Generator.
Uses search-based fetching to return relevant agents instead of all agents.
@@ -54,29 +113,33 @@ async def get_library_agents_for_generation(
max_results: Maximum number of agents to return (default 15)
Returns:
List of library agent dicts with schemas for sub-agent composition
List of LibraryAgentSummary with schemas for sub-agent composition
"""
try:
response = await library_db.list_library_agents(
user_id=user_id,
search_term=search_query, # Use search API
search_term=search_query,
page=1,
page_size=max_results,
)
return [
{
"graph_id": agent.graph_id,
"graph_version": agent.graph_version,
"name": agent.name,
"description": agent.description,
"input_schema": agent.input_schema,
"output_schema": agent.output_schema,
}
for agent in response.agents
results: list[LibraryAgentSummary] = []
for agent in response.agents:
# Exclude the agent being generated/edited to prevent circular references
if exclude_graph_id is None or agent.graph_id != exclude_graph_id
]
if exclude_graph_id is not None and agent.graph_id == exclude_graph_id:
continue
results.append(
LibraryAgentSummary(
graph_id=agent.graph_id,
graph_version=agent.graph_version,
name=agent.name,
description=agent.description,
input_schema=agent.input_schema,
output_schema=agent.output_schema,
)
)
return results
except Exception as e:
logger.warning(f"Failed to fetch library agents: {e}")
return []
@@ -85,7 +148,7 @@ async def get_library_agents_for_generation(
async def search_marketplace_agents_for_generation(
search_query: str,
max_results: int = 10,
) -> list[dict[str, Any]]:
) -> list[MarketplaceAgentSummary]:
"""Search marketplace agents formatted for Agent Generator.
Note: This returns basic agent info. Full input/output schemas would require
@@ -96,7 +159,7 @@ async def search_marketplace_agents_for_generation(
max_results: Maximum number of agents to return (default 10)
Returns:
List of marketplace agent dicts (without detailed schemas for now)
List of MarketplaceAgentSummary (without detailed schemas for now)
"""
from backend.api.features.store import db as store_db
@@ -107,18 +170,18 @@ async def search_marketplace_agents_for_generation(
page_size=max_results,
)
# Return basic info - full schemas would require fetching each agent's graph
return [
{
"name": agent.agent_name,
"description": agent.description,
"sub_heading": agent.sub_heading,
"creator": agent.creator,
"is_marketplace_agent": True,
# Note: graph_id and schemas not available without additional fetches
}
for agent in response.agents
]
results: list[MarketplaceAgentSummary] = []
for agent in response.agents:
results.append(
MarketplaceAgentSummary(
name=agent.agent_name,
description=agent.description,
sub_heading=agent.sub_heading,
creator=agent.creator,
is_marketplace_agent=True,
)
)
return results
except Exception as e:
logger.warning(f"Failed to search marketplace agents: {e}")
return []
@@ -131,7 +194,7 @@ async def get_all_relevant_agents_for_generation(
include_marketplace: bool = True,
max_library_results: int = 15,
max_marketplace_results: int = 10,
) -> list[dict[str, Any]]:
) -> list[AgentSummary]:
"""Fetch relevant agents from library and optionally marketplace.
Combines search results from user's library and public marketplace,
@@ -146,10 +209,10 @@ async def get_all_relevant_agents_for_generation(
max_marketplace_results: Max marketplace agents to return (default 10)
Returns:
List of agent dicts, library agents first (with full schemas),
List of AgentSummary, library agents first (with full schemas),
then marketplace agents (basic info only)
"""
agents: list[dict[str, Any]] = []
agents: list[AgentSummary] = []
# Get library agents (these have full schemas)
library_agents = await get_library_agents_for_generation(
@@ -167,20 +230,157 @@ async def get_all_relevant_agents_for_generation(
max_results=max_marketplace_results,
)
# Add marketplace agents that aren't already in library (by name)
library_names = {a["name"].lower() for a in library_agents if a.get("name")}
# LibraryAgentSummary always has 'name', so access directly
library_names = {a["name"].lower() for a in library_agents}
for agent in marketplace_agents:
agent_name = agent.get("name")
if agent_name and agent_name.lower() not in library_names:
# MarketplaceAgentSummary always has 'name'
if agent["name"].lower() not in library_names:
agents.append(agent)
return agents
def extract_search_terms_from_steps(
decomposition_result: DecompositionResult | dict[str, Any],
) -> list[str]:
"""Extract search terms from decomposed instruction steps.
Analyzes the decomposition result to extract relevant keywords
for additional library agent searches.
Args:
decomposition_result: Result from decompose_goal containing steps
Returns:
List of unique search terms extracted from steps
"""
search_terms: list[str] = []
# Handle instructions type (contains steps)
if decomposition_result.get("type") != "instructions":
return search_terms
steps = decomposition_result.get("steps", [])
if not steps:
return search_terms
# Keys that might contain useful search terms in DecompositionStep
step_keys: list[str] = ["description", "action", "block_name", "tool", "name"]
for step in steps:
# Extract text values from step dictionary
for key in step_keys:
value = step.get(key) # type: ignore[union-attr]
if value and len(value) > 3:
search_terms.append(value)
# Deduplicate while preserving order
seen: set[str] = set()
unique_terms: list[str] = []
for term in search_terms:
term_lower = term.lower()
if term_lower not in seen:
seen.add(term_lower)
unique_terms.append(term)
return unique_terms
async def enrich_library_agents_from_steps(
user_id: str,
decomposition_result: DecompositionResult | dict[str, Any],
existing_agents: list[AgentSummary] | list[dict[str, Any]],
exclude_graph_id: str | None = None,
include_marketplace: bool = True,
max_additional_results: int = 10,
) -> list[AgentSummary] | list[dict[str, Any]]:
"""Enrich library agents list with additional searches based on decomposed steps.
This implements two-phase search: after decomposition, we search for additional
relevant agents based on the specific steps identified.
Args:
user_id: The user ID
decomposition_result: Result from decompose_goal containing steps
existing_agents: Already fetched library agents from initial search
exclude_graph_id: Optional graph ID to exclude
include_marketplace: Whether to also search marketplace
max_additional_results: Max additional agents per search term (default 10)
Returns:
Combined list of library agents (existing + newly discovered)
"""
# Extract search terms from steps
search_terms = extract_search_terms_from_steps(decomposition_result)
if not search_terms:
return existing_agents
# Track existing agent IDs and names to avoid duplicates
# graph_id is only present in LibraryAgentSummary, not MarketplaceAgentSummary
existing_ids: set[str] = set()
existing_names: set[str] = set()
for agent in existing_agents:
# All agent summaries have 'name'
existing_names.add(agent["name"].lower())
# Only library agents have 'graph_id'
graph_id = agent.get("graph_id") # type: ignore[call-overload]
if graph_id:
existing_ids.add(graph_id)
all_agents: list[AgentSummary] | list[dict[str, Any]] = list(existing_agents)
# Search for additional agents using step-derived terms
# Limit to first 3 search terms to avoid too many API calls
for term in search_terms[:3]:
try:
additional_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
search_query=term,
exclude_graph_id=exclude_graph_id,
include_marketplace=include_marketplace,
max_library_results=max_additional_results,
max_marketplace_results=5, # Smaller limit for step-based searches
)
# Add only new agents (deduplicate by graph_id and name)
for agent in additional_agents:
agent_name = agent["name"].lower()
# Skip if already have this agent by name
if agent_name in existing_names:
continue
# Also check by graph_id for library agents
graph_id = agent.get("graph_id") # type: ignore[call-overload]
if graph_id and graph_id in existing_ids:
continue
all_agents.append(agent)
existing_names.add(agent_name)
if graph_id:
existing_ids.add(graph_id)
except Exception as e:
# Log but don't fail - continue with other search terms
logger.warning(
f"Failed to search for additional agents with term '{term}': {e}"
)
logger.debug(
f"Enriched library agents: {len(existing_agents)} initial + "
f"{len(all_agents) - len(existing_agents)} additional = {len(all_agents)} total"
)
return all_agents
async def decompose_goal(
description: str,
context: str = "",
library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
library_agents: list[AgentSummary] | None = None,
) -> DecompositionResult | None:
"""Break down a goal into steps or return clarifying questions.
Args:
@@ -189,7 +389,7 @@ async def decompose_goal(
library_agents: User's library agents available for sub-agent composition
Returns:
Dict with either:
DecompositionResult with either:
- {"type": "clarifying_questions", "questions": [...]}
- {"type": "instructions", "steps": [...]}
Or None on error
@@ -199,12 +399,17 @@ async def decompose_goal(
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for decompose_goal")
return await decompose_goal_external(description, context, library_agents)
# Convert typed dicts to plain dicts for external service
result = await decompose_goal_external(
description, context, _to_dict_list(library_agents)
)
# Cast the result to DecompositionResult (external service returns dict)
return result # type: ignore[return-value]
async def generate_agent(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
instructions: DecompositionResult | dict[str, Any],
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Generate agent JSON from instructions.
@@ -220,7 +425,10 @@ async def generate_agent(
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent")
result = await generate_agent_external(instructions, library_agents)
# Convert typed dicts to plain dicts for external service
result = await generate_agent_external(
dict(instructions), _to_dict_list(library_agents)
)
if result:
# Check if it's an error response - pass through as-is
if isinstance(result, dict) and result.get("type") == "error":
@@ -407,7 +615,7 @@ async def get_agent_as_json(
async def generate_agent_patch(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
library_agents: list[AgentSummary] | None = None,
) -> dict[str, Any] | None:
"""Update an existing agent using natural language.
@@ -430,6 +638,7 @@ async def generate_agent_patch(
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent_patch")
# Convert typed dicts to plain dicts for external service
return await generate_agent_patch_external(
update_request, current_agent, library_agents
update_request, current_agent, _to_dict_list(library_agents)
)

View File

@@ -8,6 +8,7 @@ from backend.api.features.chat.model import ChatSession
from .agent_generator import (
AgentGeneratorNotConfiguredError,
decompose_goal,
enrich_library_agents_from_steps,
generate_agent,
get_all_relevant_agents_for_generation,
get_user_message_for_error,
@@ -209,6 +210,23 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Step 1.5: Enrich library agents with step-based search (two-phase search)
# After decomposition, search for additional relevant agents based on the steps
if user_id and library_agents is not None:
try:
library_agents = await enrich_library_agents_from_steps(
user_id=user_id,
decomposition_result=decomposition_result,
existing_agents=library_agents,
include_marketplace=True,
)
logger.debug(
f"After enrichment: {len(library_agents)} total agents for sub-agent composition"
)
except Exception as e:
# Log but don't fail - continue with existing agents
logger.warning(f"Failed to enrich library agents from steps: {e}")
# Step 2: Generate agent JSON (external service handles fixing and validation)
try:
agent_json = await generate_agent(decomposition_result, library_agents)

View File

@@ -347,5 +347,295 @@ class TestGetAllRelevantAgentsForGeneration:
assert len(result) == 1
class TestExtractSearchTermsFromSteps:
"""Test extract_search_terms_from_steps function."""
def test_extracts_terms_from_instructions_type(self):
"""Test extraction from valid instructions decomposition result."""
decomposition_result = {
"type": "instructions",
"steps": [
{
"description": "Send an email notification",
"block_name": "GmailSendBlock",
},
{"description": "Fetch weather data", "action": "Get weather API"},
],
}
result = core.extract_search_terms_from_steps(decomposition_result)
assert "Send an email notification" in result
assert "GmailSendBlock" in result
assert "Fetch weather data" in result
assert "Get weather API" in result
def test_returns_empty_for_non_instructions_type(self):
"""Test that non-instructions types return empty list."""
decomposition_result = {
"type": "clarifying_questions",
"questions": [{"question": "What email?"}],
}
result = core.extract_search_terms_from_steps(decomposition_result)
assert result == []
def test_deduplicates_terms_case_insensitively(self):
"""Test that duplicate terms are removed (case-insensitive)."""
decomposition_result = {
"type": "instructions",
"steps": [
{"description": "Send Email", "name": "send email"},
{"description": "Other task"},
],
}
result = core.extract_search_terms_from_steps(decomposition_result)
# Should only have one "send email" variant
email_terms = [t for t in result if "email" in t.lower()]
assert len(email_terms) == 1
def test_filters_short_terms(self):
"""Test that terms with 3 or fewer characters are filtered out."""
decomposition_result = {
"type": "instructions",
"steps": [
{"description": "ab", "action": "xyz"}, # Both too short
{"description": "Valid term here"},
],
}
result = core.extract_search_terms_from_steps(decomposition_result)
assert "ab" not in result
assert "xyz" not in result
assert "Valid term here" in result
def test_handles_empty_steps(self):
"""Test handling of empty steps list."""
decomposition_result = {
"type": "instructions",
"steps": [],
}
result = core.extract_search_terms_from_steps(decomposition_result)
assert result == []
class TestEnrichLibraryAgentsFromSteps:
"""Test enrich_library_agents_from_steps function."""
@pytest.mark.asyncio
async def test_enriches_with_additional_agents(self):
"""Test that additional agents are found based on steps."""
existing_agents = [
{
"graph_id": "existing-123",
"graph_version": 1,
"name": "Existing Agent",
"description": "Already fetched",
"input_schema": {},
"output_schema": {},
}
]
additional_agents = [
{
"graph_id": "new-456",
"graph_version": 1,
"name": "Email Agent",
"description": "For sending emails",
"input_schema": {},
"output_schema": {},
}
]
decomposition_result = {
"type": "instructions",
"steps": [
{"description": "Send email notification"},
],
}
with patch.object(
core,
"get_all_relevant_agents_for_generation",
new_callable=AsyncMock,
return_value=additional_agents,
):
result = await core.enrich_library_agents_from_steps(
user_id="user-123",
decomposition_result=decomposition_result,
existing_agents=existing_agents,
)
# Should have both existing and new agents
assert len(result) == 2
names = [a["name"] for a in result]
assert "Existing Agent" in names
assert "Email Agent" in names
@pytest.mark.asyncio
async def test_deduplicates_by_graph_id(self):
"""Test that agents with same graph_id are not duplicated."""
existing_agents = [
{
"graph_id": "agent-123",
"graph_version": 1,
"name": "Existing Agent",
"description": "Already fetched",
"input_schema": {},
"output_schema": {},
}
]
# Additional search returns same agent
additional_agents = [
{
"graph_id": "agent-123", # Same ID
"graph_version": 1,
"name": "Existing Agent Copy",
"description": "Same agent different name",
"input_schema": {},
"output_schema": {},
}
]
decomposition_result = {
"type": "instructions",
"steps": [{"description": "Some action"}],
}
with patch.object(
core,
"get_all_relevant_agents_for_generation",
new_callable=AsyncMock,
return_value=additional_agents,
):
result = await core.enrich_library_agents_from_steps(
user_id="user-123",
decomposition_result=decomposition_result,
existing_agents=existing_agents,
)
# Should not duplicate
assert len(result) == 1
@pytest.mark.asyncio
async def test_deduplicates_by_name(self):
"""Test that agents with same name are not duplicated."""
existing_agents = [
{
"graph_id": "agent-123",
"graph_version": 1,
"name": "Email Agent",
"description": "Already fetched",
"input_schema": {},
"output_schema": {},
}
]
# Additional search returns agent with same name but different ID
additional_agents = [
{
"graph_id": "agent-456", # Different ID
"graph_version": 1,
"name": "Email Agent", # Same name
"description": "Different agent same name",
"input_schema": {},
"output_schema": {},
}
]
decomposition_result = {
"type": "instructions",
"steps": [{"description": "Send email"}],
}
with patch.object(
core,
"get_all_relevant_agents_for_generation",
new_callable=AsyncMock,
return_value=additional_agents,
):
result = await core.enrich_library_agents_from_steps(
user_id="user-123",
decomposition_result=decomposition_result,
existing_agents=existing_agents,
)
# Should not duplicate by name
assert len(result) == 1
assert result[0].get("graph_id") == "agent-123" # Original kept
@pytest.mark.asyncio
async def test_returns_existing_when_no_steps(self):
"""Test that existing agents are returned when no search terms extracted."""
existing_agents = [
{
"graph_id": "existing-123",
"graph_version": 1,
"name": "Existing Agent",
"description": "Already fetched",
"input_schema": {},
"output_schema": {},
}
]
decomposition_result = {
"type": "clarifying_questions", # Not instructions type
"questions": [],
}
result = await core.enrich_library_agents_from_steps(
user_id="user-123",
decomposition_result=decomposition_result,
existing_agents=existing_agents,
)
# Should return existing unchanged
assert result == existing_agents
@pytest.mark.asyncio
async def test_limits_search_terms_to_three(self):
"""Test that only first 3 search terms are used."""
existing_agents = []
decomposition_result = {
"type": "instructions",
"steps": [
{"description": "First action"},
{"description": "Second action"},
{"description": "Third action"},
{"description": "Fourth action"},
{"description": "Fifth action"},
],
}
call_count = 0
async def mock_get_agents(*args, **kwargs):
nonlocal call_count
call_count += 1
return []
with patch.object(
core,
"get_all_relevant_agents_for_generation",
side_effect=mock_get_agents,
):
await core.enrich_library_agents_from_steps(
user_id="user-123",
decomposition_result=decomposition_result,
existing_agents=existing_agents,
)
# Should only make 3 calls (limited to first 3 terms)
assert call_count == 3
if __name__ == "__main__":
pytest.main([__file__, "-v"])