mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-09 23:28:07 -05:00
feat(backend): implement hybrid search with BM25, vector, and RRF ranking
Implement hybrid search for the store combining: - BM25 full-text search (PostgreSQL tsvector with ts_rank_cd) - Vector semantic similarity (pgvector cosine distance) - Popularity signal (run counts as PageRank proxy) Results are ranked using Reciprocal Rank Fusion (RRF) formula. Key changes: - Add migration for BM25 trigger with weighted fields and GIN index - Add SearchFilterMode enum (strict/permissive/combined) - Update get_store_agents() with hybrid search SQL using CTEs - Add filter_mode parameter to API endpoint (default: permissive) - Add RRF score threshold (0.02) to filter irrelevant results Thresholds: - Vector similarity: >= 0.4 - BM25 relevance: >= 0.05 - RRF score: >= 0.02 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -27,8 +27,9 @@ async def _get_cached_store_agents(
|
||||
category: str | None,
|
||||
page: int,
|
||||
page_size: int,
|
||||
filter_mode: Literal["strict", "permissive", "combined"] = "permissive",
|
||||
):
|
||||
"""Cached helper to get store agents."""
|
||||
"""Cached helper to get store agents with hybrid search support."""
|
||||
return await backend.server.v2.store.db.get_store_agents(
|
||||
featured=featured,
|
||||
creators=[creator] if creator else None,
|
||||
@@ -37,6 +38,7 @@ async def _get_cached_store_agents(
|
||||
category=category,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
filter_mode=filter_mode,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -41,9 +41,23 @@ DEFAULT_ADMIN_EMAIL = "admin@autogpt.co"
|
||||
|
||||
# Minimum similarity threshold for vector search results
|
||||
# Cosine similarity ranges from -1 to 1, where 1 is identical
|
||||
# 0.4 filters out loosely related or unrelated results
|
||||
# 0.4 filters loosely related results while keeping semantically relevant ones
|
||||
VECTOR_SEARCH_SIMILARITY_THRESHOLD = 0.4
|
||||
|
||||
# Minimum relevance threshold for BM25 full-text search results
|
||||
# ts_rank_cd returns values typically in range 0-1 (can exceed 1 for exact matches)
|
||||
# 0.05 allows partial keyword matches
|
||||
BM25_RELEVANCE_THRESHOLD = 0.05
|
||||
|
||||
# RRF constant (k) - standard value that balances influence of top vs lower ranks
|
||||
# Higher k values reduce the influence of high-ranking items
|
||||
RRF_K = 60
|
||||
|
||||
# Minimum RRF score threshold for combined mode
|
||||
# Filters out results that rank poorly across all signals
|
||||
# For reference: rank #1 in all = ~0.041, rank #100 in all = ~0.016
|
||||
RRF_SCORE_THRESHOLD = 0.02
|
||||
|
||||
|
||||
async def get_store_agents(
|
||||
featured: bool = False,
|
||||
@@ -53,79 +67,189 @@ async def get_store_agents(
|
||||
category: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
filter_mode: Literal["strict", "permissive", "combined"] = "permissive",
|
||||
) -> backend.server.v2.store.model.StoreAgentsResponse:
|
||||
"""
|
||||
Get PUBLIC store agents from the StoreAgent view
|
||||
Get PUBLIC store agents from the StoreAgent view.
|
||||
|
||||
When search_query is provided, uses hybrid search combining:
|
||||
- BM25 full-text search (lexical matching via PostgreSQL tsvector)
|
||||
- Vector semantic similarity (meaning-based matching via pgvector)
|
||||
- Popularity signal (run counts as PageRank proxy)
|
||||
|
||||
Results are ranked using Reciprocal Rank Fusion (RRF).
|
||||
|
||||
Args:
|
||||
featured: Filter to only show featured agents.
|
||||
creators: Filter agents by creator usernames.
|
||||
sorted_by: Sort agents by "runs", "rating", "name", or "updated_at".
|
||||
search_query: Search query for hybrid search.
|
||||
category: Filter agents by category.
|
||||
page: Page number for pagination.
|
||||
page_size: Number of agents per page.
|
||||
filter_mode: Controls how results are filtered when searching:
|
||||
- "strict": Must match BOTH BM25 AND vector thresholds
|
||||
- "permissive": Must match EITHER BM25 OR vector threshold
|
||||
- "combined": No threshold filtering, rely on RRF score (default)
|
||||
|
||||
Returns:
|
||||
StoreAgentsResponse with paginated list of agents.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
|
||||
f"Getting store agents. featured={featured}, creators={creators}, "
|
||||
f"sorted_by={sorted_by}, search={search_query}, category={category}, "
|
||||
f"page={page}, filter_mode={filter_mode}"
|
||||
)
|
||||
|
||||
try:
|
||||
# If search_query is provided, use vector similarity search
|
||||
# If search_query is provided, use hybrid search (BM25 + vector + popularity)
|
||||
if search_query:
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# Generate embedding for search query
|
||||
# Generate embedding for vector search
|
||||
embedding_service = get_embedding_service()
|
||||
query_embedding = await embedding_service.generate_embedding(search_query)
|
||||
# Convert embedding to PostgreSQL array format
|
||||
embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"
|
||||
|
||||
# Whitelist allowed order_by columns
|
||||
# For vector search, we use similarity instead of rank
|
||||
ALLOWED_ORDER_BY = {
|
||||
"rating": "rating DESC, similarity DESC",
|
||||
"runs": "runs DESC, similarity DESC",
|
||||
"name": "agent_name ASC, similarity DESC",
|
||||
"updated_at": "updated_at DESC, similarity DESC",
|
||||
}
|
||||
|
||||
# Validate and get order clause
|
||||
if sorted_by and sorted_by in ALLOWED_ORDER_BY:
|
||||
order_by_clause = ALLOWED_ORDER_BY[sorted_by]
|
||||
else:
|
||||
# Default: order by vector similarity (most similar first)
|
||||
order_by_clause = "similarity DESC, updated_at DESC"
|
||||
|
||||
# Build WHERE conditions and parameters list
|
||||
where_parts: list[str] = []
|
||||
params: list[typing.Any] = [embedding_str] # $1 - query embedding
|
||||
param_index = 2 # Start at $2 for next parameter
|
||||
|
||||
# Always filter for available agents and agents with embeddings
|
||||
# Always filter for available agents
|
||||
where_parts.append("is_available = true")
|
||||
where_parts.append("embedding IS NOT NULL")
|
||||
# Filter out results below similarity threshold
|
||||
where_parts.append(
|
||||
f"1 - (embedding <=> $1::vector) >= {VECTOR_SEARCH_SIMILARITY_THRESHOLD}"
|
||||
)
|
||||
|
||||
# Require at least one search signal to be present
|
||||
if filter_mode == "strict":
|
||||
# Strict mode: require both embedding AND search to be available
|
||||
where_parts.append("embedding IS NOT NULL")
|
||||
where_parts.append("search IS NOT NULL")
|
||||
else:
|
||||
# Permissive/combined: require at least one signal
|
||||
where_parts.append("(embedding IS NOT NULL OR search IS NOT NULL)")
|
||||
|
||||
if featured:
|
||||
where_parts.append("featured = true")
|
||||
|
||||
if creators and creators:
|
||||
if creators:
|
||||
# Use ANY with array parameter
|
||||
where_parts.append(f"creator_username = ANY(${param_index})")
|
||||
params.append(creators)
|
||||
param_index += 1
|
||||
|
||||
if category and category:
|
||||
if category:
|
||||
where_parts.append(f"${param_index} = ANY(categories)")
|
||||
params.append(category)
|
||||
param_index += 1
|
||||
|
||||
# Add search query for BM25
|
||||
params.append(search_query)
|
||||
bm25_query_param = f"${param_index}"
|
||||
param_index += 1
|
||||
|
||||
sql_where_clause: str = " AND ".join(where_parts) if where_parts else "1=1"
|
||||
|
||||
# Build score filter based on filter_mode
|
||||
# This filter is applied BEFORE RRF ranking in the filtered_agents CTE
|
||||
if filter_mode == "strict":
|
||||
score_filter = f"""
|
||||
bm25_score >= {BM25_RELEVANCE_THRESHOLD}
|
||||
AND vector_score >= {VECTOR_SEARCH_SIMILARITY_THRESHOLD}
|
||||
"""
|
||||
elif filter_mode == "permissive":
|
||||
score_filter = f"""
|
||||
bm25_score >= {BM25_RELEVANCE_THRESHOLD}
|
||||
OR vector_score >= {VECTOR_SEARCH_SIMILARITY_THRESHOLD}
|
||||
"""
|
||||
else: # combined - no pre-filtering on individual scores
|
||||
score_filter = "1=1"
|
||||
|
||||
# RRF score filter is applied AFTER ranking to filter irrelevant results
|
||||
rrf_score_filter = f"rrf_score >= {RRF_SCORE_THRESHOLD}"
|
||||
|
||||
# Add pagination params
|
||||
params.extend([page_size, offset])
|
||||
limit_param = f"${param_index}"
|
||||
offset_param = f"${param_index + 1}"
|
||||
|
||||
# Vector similarity search query using cosine distance
|
||||
# The <=> operator returns cosine distance (0 = identical, 2 = opposite)
|
||||
# We convert to similarity: 1 - distance/2 gives range [0, 1]
|
||||
# Hybrid search SQL with Reciprocal Rank Fusion (RRF)
|
||||
# CTEs: scored_agents -> filtered_agents -> ranked_agents -> rrf_scored
|
||||
sql_query = f"""
|
||||
WITH scored_agents AS (
|
||||
SELECT
|
||||
slug,
|
||||
agent_name,
|
||||
agent_image,
|
||||
creator_username,
|
||||
creator_avatar,
|
||||
sub_heading,
|
||||
description,
|
||||
runs,
|
||||
rating,
|
||||
categories,
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
-- BM25 score using ts_rank_cd (covers density normalization)
|
||||
COALESCE(
|
||||
ts_rank_cd(
|
||||
search,
|
||||
plainto_tsquery('english', {bm25_query_param}),
|
||||
32 -- normalization: divide by document length
|
||||
),
|
||||
0
|
||||
) AS bm25_score,
|
||||
-- Vector similarity score (cosine: 1 - distance)
|
||||
CASE
|
||||
WHEN embedding IS NOT NULL
|
||||
THEN 1 - (embedding <=> $1::vector)
|
||||
ELSE 0
|
||||
END AS vector_score,
|
||||
-- Popularity score (log-normalized run count)
|
||||
CASE
|
||||
WHEN runs > 0
|
||||
THEN LN(runs + 1)
|
||||
ELSE 0
|
||||
END AS popularity_score
|
||||
FROM {{schema_prefix}}"StoreAgent"
|
||||
WHERE {sql_where_clause}
|
||||
),
|
||||
max_popularity AS (
|
||||
SELECT GREATEST(MAX(popularity_score), 1) AS max_pop
|
||||
FROM scored_agents
|
||||
),
|
||||
normalized_agents AS (
|
||||
SELECT
|
||||
sa.*,
|
||||
-- Normalize popularity to [0, 1] range
|
||||
sa.popularity_score / mp.max_pop AS norm_popularity_score
|
||||
FROM scored_agents sa
|
||||
CROSS JOIN max_popularity mp
|
||||
),
|
||||
filtered_agents AS (
|
||||
SELECT *
|
||||
FROM normalized_agents
|
||||
WHERE {score_filter}
|
||||
),
|
||||
ranked_agents AS (
|
||||
SELECT
|
||||
*,
|
||||
ROW_NUMBER() OVER (ORDER BY bm25_score DESC NULLS LAST) AS bm25_rank,
|
||||
ROW_NUMBER() OVER (ORDER BY vector_score DESC NULLS LAST) AS vector_rank,
|
||||
ROW_NUMBER() OVER (ORDER BY norm_popularity_score DESC NULLS LAST) AS popularity_rank
|
||||
FROM filtered_agents
|
||||
),
|
||||
rrf_scored AS (
|
||||
SELECT
|
||||
*,
|
||||
-- RRF formula with weighted contributions
|
||||
-- BM25 and vector get full weight, popularity gets 0.5x weight
|
||||
(1.0 / ({RRF_K} + bm25_rank)) +
|
||||
(1.0 / ({RRF_K} + vector_rank)) +
|
||||
(0.5 / ({RRF_K} + popularity_rank)) AS rrf_score
|
||||
FROM ranked_agents
|
||||
)
|
||||
SELECT
|
||||
slug,
|
||||
agent_name,
|
||||
@@ -140,21 +264,77 @@ async def get_store_agents(
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
1 - (embedding <=> $1::vector) AS similarity
|
||||
FROM {{schema_prefix}}"StoreAgent"
|
||||
WHERE {sql_where_clause}
|
||||
ORDER BY {order_by_clause}
|
||||
rrf_score
|
||||
FROM rrf_scored
|
||||
WHERE {rrf_score_filter}
|
||||
ORDER BY rrf_score DESC, updated_at DESC
|
||||
LIMIT {limit_param} OFFSET {offset_param}
|
||||
"""
|
||||
|
||||
# Count query for pagination
|
||||
# Count query (without pagination) - needs same CTEs for filtering
|
||||
# Must compute RRF scores to filter by rrf_score_filter
|
||||
count_query = f"""
|
||||
WITH scored_agents AS (
|
||||
SELECT
|
||||
runs,
|
||||
COALESCE(
|
||||
ts_rank_cd(
|
||||
search,
|
||||
plainto_tsquery('english', {bm25_query_param}),
|
||||
32
|
||||
),
|
||||
0
|
||||
) AS bm25_score,
|
||||
CASE
|
||||
WHEN embedding IS NOT NULL
|
||||
THEN 1 - (embedding <=> $1::vector)
|
||||
ELSE 0
|
||||
END AS vector_score,
|
||||
CASE
|
||||
WHEN runs > 0
|
||||
THEN LN(runs + 1)
|
||||
ELSE 0
|
||||
END AS popularity_score
|
||||
FROM {{schema_prefix}}"StoreAgent"
|
||||
WHERE {sql_where_clause}
|
||||
),
|
||||
max_popularity AS (
|
||||
SELECT GREATEST(MAX(popularity_score), 1) AS max_pop
|
||||
FROM scored_agents
|
||||
),
|
||||
normalized_agents AS (
|
||||
SELECT
|
||||
sa.*,
|
||||
sa.popularity_score / mp.max_pop AS norm_popularity_score
|
||||
FROM scored_agents sa
|
||||
CROSS JOIN max_popularity mp
|
||||
),
|
||||
filtered_agents AS (
|
||||
SELECT *
|
||||
FROM normalized_agents
|
||||
WHERE {score_filter}
|
||||
),
|
||||
ranked_agents AS (
|
||||
SELECT
|
||||
*,
|
||||
ROW_NUMBER() OVER (ORDER BY bm25_score DESC NULLS LAST) AS bm25_rank,
|
||||
ROW_NUMBER() OVER (ORDER BY vector_score DESC NULLS LAST) AS vector_rank,
|
||||
ROW_NUMBER() OVER (ORDER BY norm_popularity_score DESC NULLS LAST) AS popularity_rank
|
||||
FROM filtered_agents
|
||||
),
|
||||
rrf_scored AS (
|
||||
SELECT
|
||||
(1.0 / ({RRF_K} + bm25_rank)) +
|
||||
(1.0 / ({RRF_K} + vector_rank)) +
|
||||
(0.5 / ({RRF_K} + popularity_rank)) AS rrf_score
|
||||
FROM ranked_agents
|
||||
)
|
||||
SELECT COUNT(*) as count
|
||||
FROM {{schema_prefix}}"StoreAgent"
|
||||
WHERE {sql_where_clause}
|
||||
FROM rrf_scored
|
||||
WHERE {rrf_score_filter}
|
||||
"""
|
||||
|
||||
# Execute both queries with parameters
|
||||
# Execute queries
|
||||
agents = await query_raw_with_schema(sql_query, *params)
|
||||
|
||||
# For count, use params without pagination (last 2 params)
|
||||
|
||||
@@ -407,12 +407,12 @@ async def test_get_store_agents_search_category_array_injection():
|
||||
assert isinstance(result.agents, list)
|
||||
|
||||
|
||||
# Vector search tests
|
||||
# Hybrid search tests (BM25 + vector + popularity with RRF ranking)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_agents_vector_search_mocked(mocker):
|
||||
"""Test vector search uses embedding service and executes query safely."""
|
||||
async def test_get_store_agents_hybrid_search_mocked(mocker):
|
||||
"""Test hybrid search uses embedding service and executes query safely."""
|
||||
from backend.integrations.embeddings import EMBEDDING_DIMENSIONS
|
||||
|
||||
# Mock embedding service
|
||||
@@ -444,8 +444,8 @@ async def test_get_store_agents_vector_search_mocked(mocker):
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_agents_vector_search_with_results(mocker):
|
||||
"""Test vector search returns properly formatted results."""
|
||||
async def test_get_store_agents_hybrid_search_with_results(mocker):
|
||||
"""Test hybrid search returns properly formatted results with RRF scoring."""
|
||||
from backend.integrations.embeddings import EMBEDDING_DIMENSIONS
|
||||
|
||||
# Mock embedding service
|
||||
@@ -459,7 +459,7 @@ async def test_get_store_agents_vector_search_with_results(mocker):
|
||||
mocker.MagicMock(return_value=mock_embedding_service),
|
||||
)
|
||||
|
||||
# Mock query results
|
||||
# Mock query results (hybrid search returns rrf_score instead of similarity)
|
||||
mock_agents = [
|
||||
{
|
||||
"slug": "test-agent",
|
||||
@@ -475,7 +475,7 @@ async def test_get_store_agents_vector_search_with_results(mocker):
|
||||
"featured": False,
|
||||
"is_available": True,
|
||||
"updated_at": datetime.now(),
|
||||
"similarity": 0.95,
|
||||
"rrf_score": 0.048, # RRF score from combined rankings
|
||||
}
|
||||
]
|
||||
mock_count = [{"count": 1}]
|
||||
@@ -496,8 +496,8 @@ async def test_get_store_agents_vector_search_with_results(mocker):
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_agents_vector_search_with_filters(mocker):
|
||||
"""Test vector search works correctly with additional filters."""
|
||||
async def test_get_store_agents_hybrid_search_with_filters(mocker):
|
||||
"""Test hybrid search works correctly with additional filters."""
|
||||
from backend.integrations.embeddings import EMBEDDING_DIMENSIONS
|
||||
|
||||
# Mock embedding service
|
||||
@@ -523,7 +523,6 @@ async def test_get_store_agents_vector_search_with_filters(mocker):
|
||||
featured=True,
|
||||
creators=["creator1", "creator2"],
|
||||
category="AI",
|
||||
sorted_by="rating",
|
||||
)
|
||||
|
||||
# Verify query was called with parameterized values
|
||||
@@ -534,13 +533,124 @@ async def test_get_store_agents_vector_search_with_filters(mocker):
|
||||
first_call_args = mock_query.call_args_list[0]
|
||||
sql_query = first_call_args[0][0]
|
||||
|
||||
# Verify key elements of the query
|
||||
assert "embedding <=> $1::vector" in sql_query
|
||||
# Verify key elements of hybrid search query
|
||||
assert "embedding <=> $1::vector" in sql_query # Vector search
|
||||
assert "ts_rank_cd" in sql_query # BM25 search
|
||||
assert "rrf_score" in sql_query # RRF ranking
|
||||
assert "featured = true" in sql_query
|
||||
assert "creator_username = ANY($" in sql_query
|
||||
assert "= ANY(categories)" in sql_query
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_agents_hybrid_search_strict_filter_mode(mocker):
|
||||
"""Test hybrid search with strict filter mode requires both BM25 and vector matches."""
|
||||
from backend.integrations.embeddings import EMBEDDING_DIMENSIONS
|
||||
|
||||
# Mock embedding service
|
||||
mock_embedding = [0.1] * EMBEDDING_DIMENSIONS
|
||||
mock_embedding_service = mocker.MagicMock()
|
||||
mock_embedding_service.generate_embedding = mocker.AsyncMock(
|
||||
return_value=mock_embedding
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.server.v2.store.db.get_embedding_service",
|
||||
mocker.MagicMock(return_value=mock_embedding_service),
|
||||
)
|
||||
|
||||
# Mock query_raw_with_schema
|
||||
mock_query = mocker.patch(
|
||||
"backend.server.v2.store.db.query_raw_with_schema",
|
||||
mocker.AsyncMock(side_effect=[[], [{"count": 0}]]),
|
||||
)
|
||||
|
||||
# Call function with strict filter mode
|
||||
await db.get_store_agents(search_query="test query", filter_mode="strict")
|
||||
|
||||
# Check that the SQL query includes strict filtering conditions
|
||||
first_call_args = mock_query.call_args_list[0]
|
||||
sql_query = first_call_args[0][0]
|
||||
|
||||
# Strict mode requires both embedding AND search to be present
|
||||
assert "embedding IS NOT NULL" in sql_query
|
||||
assert "search IS NOT NULL" in sql_query
|
||||
# Strict score filter requires both thresholds to be met
|
||||
assert "bm25_score >=" in sql_query
|
||||
assert "AND vector_score >=" in sql_query
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_agents_hybrid_search_permissive_filter_mode(mocker):
|
||||
"""Test hybrid search with permissive filter mode requires either BM25 or vector match."""
|
||||
from backend.integrations.embeddings import EMBEDDING_DIMENSIONS
|
||||
|
||||
# Mock embedding service
|
||||
mock_embedding = [0.1] * EMBEDDING_DIMENSIONS
|
||||
mock_embedding_service = mocker.MagicMock()
|
||||
mock_embedding_service.generate_embedding = mocker.AsyncMock(
|
||||
return_value=mock_embedding
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.server.v2.store.db.get_embedding_service",
|
||||
mocker.MagicMock(return_value=mock_embedding_service),
|
||||
)
|
||||
|
||||
# Mock query_raw_with_schema
|
||||
mock_query = mocker.patch(
|
||||
"backend.server.v2.store.db.query_raw_with_schema",
|
||||
mocker.AsyncMock(side_effect=[[], [{"count": 0}]]),
|
||||
)
|
||||
|
||||
# Call function with permissive filter mode
|
||||
await db.get_store_agents(search_query="test query", filter_mode="permissive")
|
||||
|
||||
# Check that the SQL query includes permissive filtering conditions
|
||||
first_call_args = mock_query.call_args_list[0]
|
||||
sql_query = first_call_args[0][0]
|
||||
|
||||
# Permissive mode requires at least one signal
|
||||
assert "(embedding IS NOT NULL OR search IS NOT NULL)" in sql_query
|
||||
# Permissive score filter requires either threshold to be met
|
||||
assert "OR vector_score >=" in sql_query
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_agents_hybrid_search_combined_filter_mode(mocker):
|
||||
"""Test hybrid search with combined filter mode (default) filters by RRF score."""
|
||||
from backend.integrations.embeddings import EMBEDDING_DIMENSIONS
|
||||
|
||||
# Mock embedding service
|
||||
mock_embedding = [0.1] * EMBEDDING_DIMENSIONS
|
||||
mock_embedding_service = mocker.MagicMock()
|
||||
mock_embedding_service.generate_embedding = mocker.AsyncMock(
|
||||
return_value=mock_embedding
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.server.v2.store.db.get_embedding_service",
|
||||
mocker.MagicMock(return_value=mock_embedding_service),
|
||||
)
|
||||
|
||||
# Mock query_raw_with_schema
|
||||
mock_query = mocker.patch(
|
||||
"backend.server.v2.store.db.query_raw_with_schema",
|
||||
mocker.AsyncMock(side_effect=[[], [{"count": 0}]]),
|
||||
)
|
||||
|
||||
# Call function with combined filter mode (default)
|
||||
await db.get_store_agents(search_query="test query", filter_mode="combined")
|
||||
|
||||
# Check that the SQL query includes combined filtering
|
||||
first_call_args = mock_query.call_args_list[0]
|
||||
sql_query = first_call_args[0][0]
|
||||
|
||||
# Combined mode requires at least one signal
|
||||
assert "(embedding IS NOT NULL OR search IS NOT NULL)" in sql_query
|
||||
# Combined mode uses "1=1" as pre-filter (no individual score filtering)
|
||||
# But applies RRF score threshold to filter irrelevant results
|
||||
assert "rrf_score" in sql_query
|
||||
assert "rrf_score >=" in sql_query # RRF threshold filter applied
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_generate_and_store_embedding_success(mocker):
|
||||
"""Test that embedding generation and storage works correctly."""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import datetime
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
import prisma.enums
|
||||
@@ -7,6 +8,19 @@ import pydantic
|
||||
from backend.util.models import Pagination
|
||||
|
||||
|
||||
class SearchFilterMode(str, Enum):
|
||||
"""How to combine BM25 and vector search results for filtering.
|
||||
|
||||
- STRICT: Must pass BOTH BM25 AND vector similarity thresholds
|
||||
- PERMISSIVE: Must pass EITHER BM25 OR vector similarity threshold
|
||||
- COMBINED: No pre-filtering, only the combined RRF score matters (default)
|
||||
"""
|
||||
|
||||
STRICT = "strict"
|
||||
PERMISSIVE = "permissive"
|
||||
COMBINED = "combined"
|
||||
|
||||
|
||||
class MyAgent(pydantic.BaseModel):
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
|
||||
@@ -99,18 +99,30 @@ async def get_agents(
|
||||
category: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
filter_mode: Literal["strict", "permissive", "combined"] = "permissive",
|
||||
):
|
||||
"""
|
||||
Get a paginated list of agents from the store with optional filtering and sorting.
|
||||
|
||||
When search_query is provided, uses hybrid search combining:
|
||||
- BM25 full-text search (lexical matching)
|
||||
- Vector semantic similarity (meaning-based matching)
|
||||
- Popularity signal (run counts)
|
||||
|
||||
Results are ranked using Reciprocal Rank Fusion (RRF).
|
||||
|
||||
Args:
|
||||
featured (bool, optional): Filter to only show featured agents. Defaults to False.
|
||||
creator (str | None, optional): Filter agents by creator username. Defaults to None.
|
||||
sorted_by (str | None, optional): Sort agents by "runs" or "rating". Defaults to None.
|
||||
search_query (str | None, optional): Search agents by name, subheading and description. Defaults to None.
|
||||
search_query (str | None, optional): Search agents by name, subheading and description.
|
||||
category (str | None, optional): Filter agents by category. Defaults to None.
|
||||
page (int, optional): Page number for pagination. Defaults to 1.
|
||||
page_size (int, optional): Number of agents per page. Defaults to 20.
|
||||
filter_mode (str, optional): Controls result filtering when searching:
|
||||
- "strict": Must match BOTH BM25 AND vector thresholds
|
||||
- "permissive": Must match EITHER BM25 OR vector threshold
|
||||
- "combined": No threshold filtering, rely on RRF score (default)
|
||||
|
||||
Returns:
|
||||
StoreAgentsResponse: Paginated list of agents matching the filters
|
||||
@@ -144,6 +156,7 @@ async def get_agents(
|
||||
category=category,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
filter_mode=filter_mode,
|
||||
)
|
||||
return agents
|
||||
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
-- Migration: Add hybrid search infrastructure (BM25 + vector + popularity)
|
||||
-- This migration:
|
||||
-- 1. Creates/updates the tsvector trigger with weighted fields
|
||||
-- 2. Adds GIN index for full-text search performance
|
||||
-- 3. Backfills existing records with tsvector data
|
||||
|
||||
-- Create or replace the trigger function with WEIGHTED tsvector
|
||||
-- Weight A = name (highest priority), B = subHeading, C = description
|
||||
CREATE OR REPLACE FUNCTION update_tsvector_column() RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
NEW.search := setweight(to_tsvector('english', COALESCE(NEW.name, '')), 'A') ||
|
||||
setweight(to_tsvector('english', COALESCE(NEW."subHeading", '')), 'B') ||
|
||||
setweight(to_tsvector('english', COALESCE(NEW.description, '')), 'C');
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
-- Drop and recreate trigger to ensure it's active with the updated function
|
||||
DROP TRIGGER IF EXISTS "update_tsvector" ON "StoreListingVersion";
|
||||
CREATE TRIGGER "update_tsvector"
|
||||
BEFORE INSERT OR UPDATE OF name, "subHeading", description ON "StoreListingVersion"
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION update_tsvector_column();
|
||||
|
||||
-- Create GIN index for full-text search performance
|
||||
CREATE INDEX IF NOT EXISTS idx_store_listing_version_search_gin
|
||||
ON "StoreListingVersion" USING GIN (search);
|
||||
|
||||
-- Backfill existing records with weighted tsvector
|
||||
UPDATE "StoreListingVersion"
|
||||
SET search = setweight(to_tsvector('english', COALESCE(name, '')), 'A') ||
|
||||
setweight(to_tsvector('english', COALESCE("subHeading", '')), 'B') ||
|
||||
setweight(to_tsvector('english', COALESCE(description, '')), 'C')
|
||||
WHERE search IS NULL
|
||||
OR search = ''::tsvector;
|
||||
Reference in New Issue
Block a user