Compare commits

...

2 Commits

Author SHA1 Message Date
Swifty
b119bc788c fmt 2025-11-03 09:15:10 +01:00
Swifty
e66b51d256 revert pg full text search 2025-11-03 09:11:55 +01:00

View File

@@ -2,7 +2,6 @@ import asyncio
import logging import logging
import typing import typing
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Literal
import fastapi import fastapi
import prisma.enums import prisma.enums
@@ -39,10 +38,29 @@ DEFAULT_ADMIN_NAME = "AutoGPT Admin"
DEFAULT_ADMIN_EMAIL = "admin@autogpt.co" DEFAULT_ADMIN_EMAIL = "admin@autogpt.co"
def sanitize_query(query: str | None) -> str | None:
if query is None:
return query
query = query.strip()[:100]
return (
query.replace("\\", "\\\\")
.replace("%", "\\%")
.replace("_", "\\_")
.replace("[", "\\[")
.replace("]", "\\]")
.replace("'", "\\'")
.replace('"', '\\"')
.replace(";", "\\;")
.replace("--", "\\--")
.replace("/*", "\\/*")
.replace("*/", "\\*/")
)
async def get_store_agents( async def get_store_agents(
featured: bool = False, featured: bool = False,
creators: list[str] | None = None, creators: list[str] | None = None,
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None, sorted_by: str | None = None,
search_query: str | None = None, search_query: str | None = None,
category: str | None = None, category: str | None = None,
page: int = 1, page: int = 1,
@@ -58,52 +76,48 @@ async def get_store_agents(
try: try:
# If search_query is provided, use full-text search # If search_query is provided, use full-text search
if search_query: if search_query:
search_term = sanitize_query(search_query)
if not search_term:
# Return empty results for invalid search query
return backend.server.v2.store.model.StoreAgentsResponse(
agents=[],
pagination=backend.server.v2.store.model.Pagination(
current_page=page,
total_items=0,
total_pages=0,
page_size=page_size,
),
)
offset = (page - 1) * page_size offset = (page - 1) * page_size
# Whitelist allowed order_by columns # Build filter conditions
ALLOWED_ORDER_BY = { filter_conditions = []
"rating": "rating DESC, rank DESC", filter_conditions.append("is_available = true")
"runs": "runs DESC, rank DESC",
"name": "agent_name ASC, rank ASC",
"updated_at": "updated_at DESC, rank DESC",
}
# Validate and get order clause
if sorted_by and sorted_by in ALLOWED_ORDER_BY:
order_by_clause = ALLOWED_ORDER_BY[sorted_by]
else:
order_by_clause = "updated_at DESC, rank DESC"
# Build WHERE conditions and parameters list
where_parts: list[str] = []
params: list[typing.Any] = [search_query] # $1 - search term
param_index = 2 # Start at $2 for next parameter
# Always filter for available agents
where_parts.append("is_available = true")
if featured: if featured:
where_parts.append("featured = true") filter_conditions.append("featured = true")
if creators:
creator_list = "','".join(creators)
filter_conditions.append(f"creator_username IN ('{creator_list}')")
if category:
filter_conditions.append(f"'{category}' = ANY(categories)")
if creators and creators: where_filter = (
# Use ANY with array parameter " AND ".join(filter_conditions) if filter_conditions else "1=1"
where_parts.append(f"creator_username = ANY(${param_index})") )
params.append(creators)
param_index += 1
if category and category: # Build ORDER BY clause
where_parts.append(f"${param_index} = ANY(categories)") if sorted_by == "rating":
params.append(category) order_by_clause = "rating DESC, rank DESC"
param_index += 1 elif sorted_by == "runs":
order_by_clause = "runs DESC, rank DESC"
elif sorted_by == "name":
order_by_clause = "agent_name ASC, rank DESC"
else:
order_by_clause = "rank DESC, updated_at DESC"
sql_where_clause: str = " AND ".join(where_parts) if where_parts else "1=1" # Execute full-text search query
# Add pagination params
params.extend([page_size, offset])
limit_param = f"${param_index}"
offset_param = f"${param_index + 1}"
# Execute full-text search query with parameterized values
sql_query = f""" sql_query = f"""
SELECT SELECT
slug, slug,
@@ -121,31 +135,29 @@ async def get_store_agents(
updated_at, updated_at,
ts_rank_cd(search, query) AS rank ts_rank_cd(search, query) AS rank
FROM "StoreAgent", FROM "StoreAgent",
plainto_tsquery('english', $1) AS query plainto_tsquery('english', '{search_term}') AS query
WHERE {sql_where_clause} WHERE {where_filter}
AND search @@ query AND search @@ query
ORDER BY {order_by_clause} ORDER BY rank DESC, {order_by_clause}
LIMIT {limit_param} OFFSET {offset_param} LIMIT {page_size} OFFSET {offset}
""" """
# Count query for pagination - only uses search term parameter # Count query for pagination
count_query = f""" count_query = f"""
SELECT COUNT(*) as count SELECT COUNT(*) as count
FROM "StoreAgent", FROM "StoreAgent",
plainto_tsquery('english', $1) AS query plainto_tsquery('english', '{search_term}') AS query
WHERE {sql_where_clause} WHERE {where_filter}
AND search @@ query AND search @@ query
""" """
# Execute both queries with parameters # Execute both queries
agents = await prisma.client.get_client().query_raw( agents = await prisma.client.get_client().query_raw(
typing.cast(typing.LiteralString, sql_query), *params query=typing.cast(typing.LiteralString, sql_query)
) )
# For count, use params without pagination (last 2 params)
count_params = params[:-2]
count_result = await prisma.client.get_client().query_raw( count_result = await prisma.client.get_client().query_raw(
typing.cast(typing.LiteralString, count_query), *count_params query=typing.cast(typing.LiteralString, count_query)
) )
total = count_result[0]["count"] if count_result else 0 total = count_result[0]["count"] if count_result else 0
@@ -422,7 +434,7 @@ async def get_store_agent_by_version_id(
async def get_store_creators( async def get_store_creators(
featured: bool = False, featured: bool = False,
search_query: str | None = None, search_query: str | None = None,
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None, sorted_by: str | None = None,
page: int = 1, page: int = 1,
page_size: int = 20, page_size: int = 20,
) -> backend.server.v2.store.model.CreatorsResponse: ) -> backend.server.v2.store.model.CreatorsResponse:
@@ -1736,21 +1748,22 @@ async def get_admin_listings_with_versions(
if status: if status:
where_dict["Versions"] = {"some": {"submissionStatus": status}} where_dict["Versions"] = {"some": {"submissionStatus": status}}
if search_query: sanitized_query = sanitize_query(search_query)
if sanitized_query:
# Find users with matching email # Find users with matching email
matching_users = await prisma.models.User.prisma().find_many( matching_users = await prisma.models.User.prisma().find_many(
where={"email": {"contains": search_query, "mode": "insensitive"}}, where={"email": {"contains": sanitized_query, "mode": "insensitive"}},
) )
user_ids = [user.id for user in matching_users] user_ids = [user.id for user in matching_users]
# Set up OR conditions # Set up OR conditions
where_dict["OR"] = [ where_dict["OR"] = [
{"slug": {"contains": search_query, "mode": "insensitive"}}, {"slug": {"contains": sanitized_query, "mode": "insensitive"}},
{ {
"Versions": { "Versions": {
"some": { "some": {
"name": {"contains": search_query, "mode": "insensitive"} "name": {"contains": sanitized_query, "mode": "insensitive"}
} }
} }
}, },
@@ -1758,7 +1771,7 @@ async def get_admin_listings_with_versions(
"Versions": { "Versions": {
"some": { "some": {
"description": { "description": {
"contains": search_query, "contains": sanitized_query,
"mode": "insensitive", "mode": "insensitive",
} }
} }
@@ -1768,7 +1781,7 @@ async def get_admin_listings_with_versions(
"Versions": { "Versions": {
"some": { "some": {
"subHeading": { "subHeading": {
"contains": search_query, "contains": sanitized_query,
"mode": "insensitive", "mode": "insensitive",
} }
} }