feat(market): better search

This commit is contained in:
Nicholas Tindle
2024-07-31 14:41:39 -05:00
parent a0e4735951
commit f024a8e6ec
3 changed files with 62 additions and 46 deletions

View File

@@ -1,4 +1,4 @@
from typing import Literal
from typing import List, Literal
import prisma.models
import prisma.types
@@ -145,9 +145,9 @@ async def search_db(
query: str,
page: int = 1,
page_size: int = 10,
category: str | None = None,
categories: List[str] | None = None,
description_threshold: int = 60,
sort_by: str = "createdAt",
sort_by: str = "rank",
sort_order: Literal["desc"] | Literal["asc"] = "desc",
):
"""Perform a search for agents based on the provided query string.
@@ -156,9 +156,9 @@ async def search_db(
query (str): the search string
page (int, optional): page for searching. Defaults to 1.
page_size (int, optional): the number of results to return. Defaults to 10.
category (str | None, optional): categorization filters. Defaults to None.
categories (List[str] | None, optional): list of category filters. Defaults to None.
description_threshold (int, optional): number of characters to return. Defaults to 60.
sort_by (str, optional): sort by option. Defaults to "createdAt".
sort_by (str, optional): sort by option. Defaults to "rank".
sort_order ("asc" | "desc", optional): the sort order. Defaults to "desc".
Raises:
@@ -166,41 +166,54 @@ async def search_db(
AgentQueryError: Raises if an unexpected error occurs.
Returns:
_type_: _description_
List[AgentsWithRank]: List of agents matching the search criteria.
"""
try:
# This can all be replaced with a one line full text search when it's supported :')
a = await prisma.client.get_client().query_raw(
query=f"""
WITH query AS (
SELECT to_tsquery(string_agg(lexeme || ':*', ' & ' ORDER BY positions)) AS q
FROM unnest(to_tsvector('${query}'))
)
SELECT
subq.*,
ts_rank(subq.search_text::tsvector, query.q) AS rank
FROM (
SELECT
id,
"createdAt",
"updatedAt",
version,
name,
description,
author,
keywords,
categories,
CAST(search AS TEXT) AS search_text,
graph
FROM "Agents"
) subq, query
ORDER BY rank DESC
LIMIT {page_size};
""",
offset = (page - 1) * page_size
category_filter = ""
if categories:
category_conditions = [f"'{cat}' = ANY(categories)" for cat in categories]
category_filter = "AND (" + " OR ".join(category_conditions) + ")"
# Construct the ORDER BY clause based on the sort_by parameter
if sort_by in ["createdAt", "updatedAt"]:
order_by_clause = f'"{sort_by}" {sort_order.upper()}, rank DESC'
elif sort_by == "name":
order_by_clause = f"name {sort_order.upper()}, rank DESC"
else:
order_by_clause = 'rank DESC, "createdAt" DESC'
sql_query = f"""
WITH query AS (
SELECT to_tsquery(string_agg(lexeme || ':*', ' & ' ORDER BY positions)) AS q
FROM unnest(to_tsvector('{query}'))
)
SELECT
id,
"createdAt",
"updatedAt",
version,
name,
LEFT(description, {description_threshold}) AS description,
author,
keywords,
categories,
graph,
ts_rank(CAST(search AS tsvector), query.q) AS rank
FROM "Agents", query
WHERE 1=1 {category_filter}
ORDER BY {order_by_clause}
LIMIT {page_size}
OFFSET {offset};
"""
results = await prisma.client.get_client().query_raw(
query=sql_query,
model=AgentsWithRank,
)
return a
return results
except PrismaError as e:
raise AgentQueryError(f"Database query failed: {str(e)}")

View File

@@ -1,6 +1,6 @@
from typing import List, Literal
from fastapi import APIRouter
from fastapi import APIRouter, Query
from market.db import search_db
from market.utils.extension_types import AgentsWithRank
@@ -11,14 +11,18 @@ router = APIRouter()
@router.get("/search")
async def search(
query: str,
page: int = 1,
page_size: int = 10,
category: str | None = None,
description_threshold: int = 60,
sort_by: str = "createdAt",
sort_order: Literal["desc"] | Literal["asc"] = "desc",
page: int = Query(1, description="The pagination page to start on"),
page_size: int = Query(10, description="The number of items to return per page"),
categories: List[str] = Query(None, description="The categories to filter by"),
description_threshold: int = Query(
60, description="The number of characters to return from the description"
),
sort_by: str = Query("rank", description="Sorting by column"),
sort_order: Literal["desc", "asc"] = Query(
"desc", description="The sort order based on sort_by"
),
) -> List[AgentsWithRank]:
"""searches endpoint for agnets
"""searches endpoint for agents
Args:
query (str): the search query
@@ -26,14 +30,14 @@ async def search(
page_size (int, optional): the number of items to return per page. Defaults to 10.
categories (List[str] | None, optional): the agent categories to filter by. None is no filter. Defaults to None.
description_threshold (int, optional): the number of characters to return from the description. Defaults to 60.
sort_by (str, optional): Sorting options. Defaults to "createdAt".
sort_by (str, optional): Sorting by column. Defaults to "rank".
sort_order ('asc' | 'desc', optional): the sort order based on sort_by. Defaults to "desc".
"""
return await search_db(
query=query,
page=page,
page_size=page_size,
category=category,
categories=categories,
description_threshold=description_threshold,
sort_by=sort_by,
sort_order=sort_order,

View File

@@ -3,4 +3,3 @@ from prisma.models import Agents
class AgentsWithRank(Agents):
rank: float
search_text: str