mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(market): better search
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Literal
|
||||
from typing import List, Literal
|
||||
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
@@ -145,9 +145,9 @@ async def search_db(
|
||||
query: str,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
category: str | None = None,
|
||||
categories: List[str] | None = None,
|
||||
description_threshold: int = 60,
|
||||
sort_by: str = "createdAt",
|
||||
sort_by: str = "rank",
|
||||
sort_order: Literal["desc"] | Literal["asc"] = "desc",
|
||||
):
|
||||
"""Perform a search for agents based on the provided query string.
|
||||
@@ -156,9 +156,9 @@ async def search_db(
|
||||
query (str): the search string
|
||||
page (int, optional): page for searching. Defaults to 1.
|
||||
page_size (int, optional): the number of results to return. Defaults to 10.
|
||||
category (str | None, optional): categorization filters. Defaults to None.
|
||||
categories (List[str] | None, optional): list of category filters. Defaults to None.
|
||||
description_threshold (int, optional): number of characters to return. Defaults to 60.
|
||||
sort_by (str, optional): sort by option. Defaults to "createdAt".
|
||||
sort_by (str, optional): sort by option. Defaults to "rank".
|
||||
sort_order ("asc" | "desc", optional): the sort order. Defaults to "desc".
|
||||
|
||||
Raises:
|
||||
@@ -166,41 +166,54 @@ async def search_db(
|
||||
AgentQueryError: Raises if an unexpected error occurs.
|
||||
|
||||
Returns:
|
||||
_type_: _description_
|
||||
List[AgentsWithRank]: List of agents matching the search criteria.
|
||||
"""
|
||||
try:
|
||||
# This can all be replaced with a one line full text search when it's supported :')
|
||||
a = await prisma.client.get_client().query_raw(
|
||||
query=f"""
|
||||
WITH query AS (
|
||||
SELECT to_tsquery(string_agg(lexeme || ':*', ' & ' ORDER BY positions)) AS q
|
||||
FROM unnest(to_tsvector('${query}'))
|
||||
)
|
||||
SELECT
|
||||
subq.*,
|
||||
ts_rank(subq.search_text::tsvector, query.q) AS rank
|
||||
FROM (
|
||||
SELECT
|
||||
id,
|
||||
"createdAt",
|
||||
"updatedAt",
|
||||
version,
|
||||
name,
|
||||
description,
|
||||
author,
|
||||
keywords,
|
||||
categories,
|
||||
CAST(search AS TEXT) AS search_text,
|
||||
graph
|
||||
FROM "Agents"
|
||||
) subq, query
|
||||
ORDER BY rank DESC
|
||||
LIMIT {page_size};
|
||||
""",
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
category_filter = ""
|
||||
if categories:
|
||||
category_conditions = [f"'{cat}' = ANY(categories)" for cat in categories]
|
||||
category_filter = "AND (" + " OR ".join(category_conditions) + ")"
|
||||
|
||||
# Construct the ORDER BY clause based on the sort_by parameter
|
||||
if sort_by in ["createdAt", "updatedAt"]:
|
||||
order_by_clause = f'"{sort_by}" {sort_order.upper()}, rank DESC'
|
||||
elif sort_by == "name":
|
||||
order_by_clause = f"name {sort_order.upper()}, rank DESC"
|
||||
else:
|
||||
order_by_clause = 'rank DESC, "createdAt" DESC'
|
||||
|
||||
sql_query = f"""
|
||||
WITH query AS (
|
||||
SELECT to_tsquery(string_agg(lexeme || ':*', ' & ' ORDER BY positions)) AS q
|
||||
FROM unnest(to_tsvector('{query}'))
|
||||
)
|
||||
SELECT
|
||||
id,
|
||||
"createdAt",
|
||||
"updatedAt",
|
||||
version,
|
||||
name,
|
||||
LEFT(description, {description_threshold}) AS description,
|
||||
author,
|
||||
keywords,
|
||||
categories,
|
||||
graph,
|
||||
ts_rank(CAST(search AS tsvector), query.q) AS rank
|
||||
FROM "Agents", query
|
||||
WHERE 1=1 {category_filter}
|
||||
ORDER BY {order_by_clause}
|
||||
LIMIT {page_size}
|
||||
OFFSET {offset};
|
||||
"""
|
||||
|
||||
results = await prisma.client.get_client().query_raw(
|
||||
query=sql_query,
|
||||
model=AgentsWithRank,
|
||||
)
|
||||
|
||||
return a
|
||||
return results
|
||||
|
||||
except PrismaError as e:
|
||||
raise AgentQueryError(f"Database query failed: {str(e)}")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import List, Literal
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, Query
|
||||
|
||||
from market.db import search_db
|
||||
from market.utils.extension_types import AgentsWithRank
|
||||
@@ -11,14 +11,18 @@ router = APIRouter()
|
||||
@router.get("/search")
|
||||
async def search(
|
||||
query: str,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
category: str | None = None,
|
||||
description_threshold: int = 60,
|
||||
sort_by: str = "createdAt",
|
||||
sort_order: Literal["desc"] | Literal["asc"] = "desc",
|
||||
page: int = Query(1, description="The pagination page to start on"),
|
||||
page_size: int = Query(10, description="The number of items to return per page"),
|
||||
categories: List[str] = Query(None, description="The categories to filter by"),
|
||||
description_threshold: int = Query(
|
||||
60, description="The number of characters to return from the description"
|
||||
),
|
||||
sort_by: str = Query("rank", description="Sorting by column"),
|
||||
sort_order: Literal["desc", "asc"] = Query(
|
||||
"desc", description="The sort order based on sort_by"
|
||||
),
|
||||
) -> List[AgentsWithRank]:
|
||||
"""searches endpoint for agnets
|
||||
"""searches endpoint for agents
|
||||
|
||||
Args:
|
||||
query (str): the search query
|
||||
@@ -26,14 +30,14 @@ async def search(
|
||||
page_size (int, optional): the number of items to return per page. Defaults to 10.
|
||||
category (str | None, optional): the agent category to filter by. None is no filter. Defaults to None.
|
||||
description_threshold (int, optional): the number of characters to return from the description. Defaults to 60.
|
||||
sort_by (str, optional): Sorting options. Defaults to "createdAt".
|
||||
sort_by (str, optional): Sorting by column. Defaults to "rank".
|
||||
sort_order ('asc' | 'desc', optional): the sort order based on sort_by. Defaults to "desc".
|
||||
"""
|
||||
return await search_db(
|
||||
query=query,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
category=category,
|
||||
categories=categories,
|
||||
description_threshold=description_threshold,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
|
||||
@@ -3,4 +3,3 @@ from prisma.models import Agents
|
||||
|
||||
class AgentsWithRank(Agents):
|
||||
rank: float
|
||||
search_text: str
|
||||
|
||||
Reference in New Issue
Block a user