Compare commits

...

2 Commits

Author SHA1 Message Date
Swifty
b119bc788c fmt 2025-11-03 09:15:10 +01:00
Swifty
e66b51d256 revert pg full text search 2025-11-03 09:11:55 +01:00

View File

@@ -2,7 +2,6 @@ import asyncio
import logging
import typing
from datetime import datetime, timezone
from typing import Literal
import fastapi
import prisma.enums
@@ -39,10 +38,29 @@ DEFAULT_ADMIN_NAME = "AutoGPT Admin"
DEFAULT_ADMIN_EMAIL = "admin@autogpt.co"
def sanitize_query(query: str | None) -> str | None:
if query is None:
return query
query = query.strip()[:100]
return (
query.replace("\\", "\\\\")
.replace("%", "\\%")
.replace("_", "\\_")
.replace("[", "\\[")
.replace("]", "\\]")
.replace("'", "\\'")
.replace('"', '\\"')
.replace(";", "\\;")
.replace("--", "\\--")
.replace("/*", "\\/*")
.replace("*/", "\\*/")
)
async def get_store_agents(
featured: bool = False,
creators: list[str] | None = None,
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
sorted_by: str | None = None,
search_query: str | None = None,
category: str | None = None,
page: int = 1,
@@ -58,52 +76,48 @@ async def get_store_agents(
try:
# If search_query is provided, use full-text search
if search_query:
search_term = sanitize_query(search_query)
if not search_term:
# Return empty results for invalid search query
return backend.server.v2.store.model.StoreAgentsResponse(
agents=[],
pagination=backend.server.v2.store.model.Pagination(
current_page=page,
total_items=0,
total_pages=0,
page_size=page_size,
),
)
offset = (page - 1) * page_size
# Whitelist allowed order_by columns
ALLOWED_ORDER_BY = {
"rating": "rating DESC, rank DESC",
"runs": "runs DESC, rank DESC",
"name": "agent_name ASC, rank ASC",
"updated_at": "updated_at DESC, rank DESC",
}
# Validate and get order clause
if sorted_by and sorted_by in ALLOWED_ORDER_BY:
order_by_clause = ALLOWED_ORDER_BY[sorted_by]
else:
order_by_clause = "updated_at DESC, rank DESC"
# Build WHERE conditions and parameters list
where_parts: list[str] = []
params: list[typing.Any] = [search_query] # $1 - search term
param_index = 2 # Start at $2 for next parameter
# Always filter for available agents
where_parts.append("is_available = true")
# Build filter conditions
filter_conditions = []
filter_conditions.append("is_available = true")
if featured:
where_parts.append("featured = true")
filter_conditions.append("featured = true")
if creators:
creator_list = "','".join(creators)
filter_conditions.append(f"creator_username IN ('{creator_list}')")
if category:
filter_conditions.append(f"'{category}' = ANY(categories)")
if creators and creators:
# Use ANY with array parameter
where_parts.append(f"creator_username = ANY(${param_index})")
params.append(creators)
param_index += 1
where_filter = (
" AND ".join(filter_conditions) if filter_conditions else "1=1"
)
if category and category:
where_parts.append(f"${param_index} = ANY(categories)")
params.append(category)
param_index += 1
# Build ORDER BY clause
if sorted_by == "rating":
order_by_clause = "rating DESC, rank DESC"
elif sorted_by == "runs":
order_by_clause = "runs DESC, rank DESC"
elif sorted_by == "name":
order_by_clause = "agent_name ASC, rank DESC"
else:
order_by_clause = "rank DESC, updated_at DESC"
sql_where_clause: str = " AND ".join(where_parts) if where_parts else "1=1"
# Add pagination params
params.extend([page_size, offset])
limit_param = f"${param_index}"
offset_param = f"${param_index + 1}"
# Execute full-text search query with parameterized values
# Execute full-text search query
sql_query = f"""
SELECT
slug,
@@ -121,31 +135,29 @@ async def get_store_agents(
updated_at,
ts_rank_cd(search, query) AS rank
FROM "StoreAgent",
plainto_tsquery('english', $1) AS query
WHERE {sql_where_clause}
plainto_tsquery('english', '{search_term}') AS query
WHERE {where_filter}
AND search @@ query
ORDER BY {order_by_clause}
LIMIT {limit_param} OFFSET {offset_param}
ORDER BY rank DESC, {order_by_clause}
LIMIT {page_size} OFFSET {offset}
"""
# Count query for pagination - only uses search term parameter
# Count query for pagination
count_query = f"""
SELECT COUNT(*) as count
FROM "StoreAgent",
plainto_tsquery('english', $1) AS query
WHERE {sql_where_clause}
plainto_tsquery('english', '{search_term}') AS query
WHERE {where_filter}
AND search @@ query
"""
# Execute both queries with parameters
# Execute both queries
agents = await prisma.client.get_client().query_raw(
typing.cast(typing.LiteralString, sql_query), *params
query=typing.cast(typing.LiteralString, sql_query)
)
# For count, use params without pagination (last 2 params)
count_params = params[:-2]
count_result = await prisma.client.get_client().query_raw(
typing.cast(typing.LiteralString, count_query), *count_params
query=typing.cast(typing.LiteralString, count_query)
)
total = count_result[0]["count"] if count_result else 0
@@ -422,7 +434,7 @@ async def get_store_agent_by_version_id(
async def get_store_creators(
featured: bool = False,
search_query: str | None = None,
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
sorted_by: str | None = None,
page: int = 1,
page_size: int = 20,
) -> backend.server.v2.store.model.CreatorsResponse:
@@ -1736,21 +1748,22 @@ async def get_admin_listings_with_versions(
if status:
where_dict["Versions"] = {"some": {"submissionStatus": status}}
if search_query:
sanitized_query = sanitize_query(search_query)
if sanitized_query:
# Find users with matching email
matching_users = await prisma.models.User.prisma().find_many(
where={"email": {"contains": search_query, "mode": "insensitive"}},
where={"email": {"contains": sanitized_query, "mode": "insensitive"}},
)
user_ids = [user.id for user in matching_users]
# Set up OR conditions
where_dict["OR"] = [
{"slug": {"contains": search_query, "mode": "insensitive"}},
{"slug": {"contains": sanitized_query, "mode": "insensitive"}},
{
"Versions": {
"some": {
"name": {"contains": search_query, "mode": "insensitive"}
"name": {"contains": sanitized_query, "mode": "insensitive"}
}
}
},
@@ -1758,7 +1771,7 @@ async def get_admin_listings_with_versions(
"Versions": {
"some": {
"description": {
"contains": search_query,
"contains": sanitized_query,
"mode": "insensitive",
}
}
@@ -1768,7 +1781,7 @@ async def get_admin_listings_with_versions(
"Versions": {
"some": {
"subHeading": {
"contains": search_query,
"contains": sanitized_query,
"mode": "insensitive",
}
}