mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-12 15:55:03 -05:00
Compare commits
2 Commits
fix/claude
...
hotfix/rev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b119bc788c | ||
|
|
e66b51d256 |
@@ -2,7 +2,6 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import typing
|
import typing
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
import prisma.enums
|
import prisma.enums
|
||||||
@@ -39,10 +38,29 @@ DEFAULT_ADMIN_NAME = "AutoGPT Admin"
|
|||||||
DEFAULT_ADMIN_EMAIL = "admin@autogpt.co"
|
DEFAULT_ADMIN_EMAIL = "admin@autogpt.co"
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_query(query: str | None) -> str | None:
|
||||||
|
if query is None:
|
||||||
|
return query
|
||||||
|
query = query.strip()[:100]
|
||||||
|
return (
|
||||||
|
query.replace("\\", "\\\\")
|
||||||
|
.replace("%", "\\%")
|
||||||
|
.replace("_", "\\_")
|
||||||
|
.replace("[", "\\[")
|
||||||
|
.replace("]", "\\]")
|
||||||
|
.replace("'", "\\'")
|
||||||
|
.replace('"', '\\"')
|
||||||
|
.replace(";", "\\;")
|
||||||
|
.replace("--", "\\--")
|
||||||
|
.replace("/*", "\\/*")
|
||||||
|
.replace("*/", "\\*/")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_store_agents(
|
async def get_store_agents(
|
||||||
featured: bool = False,
|
featured: bool = False,
|
||||||
creators: list[str] | None = None,
|
creators: list[str] | None = None,
|
||||||
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
|
sorted_by: str | None = None,
|
||||||
search_query: str | None = None,
|
search_query: str | None = None,
|
||||||
category: str | None = None,
|
category: str | None = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
@@ -58,52 +76,48 @@ async def get_store_agents(
|
|||||||
try:
|
try:
|
||||||
# If search_query is provided, use full-text search
|
# If search_query is provided, use full-text search
|
||||||
if search_query:
|
if search_query:
|
||||||
|
search_term = sanitize_query(search_query)
|
||||||
|
if not search_term:
|
||||||
|
# Return empty results for invalid search query
|
||||||
|
return backend.server.v2.store.model.StoreAgentsResponse(
|
||||||
|
agents=[],
|
||||||
|
pagination=backend.server.v2.store.model.Pagination(
|
||||||
|
current_page=page,
|
||||||
|
total_items=0,
|
||||||
|
total_pages=0,
|
||||||
|
page_size=page_size,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
offset = (page - 1) * page_size
|
offset = (page - 1) * page_size
|
||||||
|
|
||||||
# Whitelist allowed order_by columns
|
# Build filter conditions
|
||||||
ALLOWED_ORDER_BY = {
|
filter_conditions = []
|
||||||
"rating": "rating DESC, rank DESC",
|
filter_conditions.append("is_available = true")
|
||||||
"runs": "runs DESC, rank DESC",
|
|
||||||
"name": "agent_name ASC, rank ASC",
|
|
||||||
"updated_at": "updated_at DESC, rank DESC",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Validate and get order clause
|
|
||||||
if sorted_by and sorted_by in ALLOWED_ORDER_BY:
|
|
||||||
order_by_clause = ALLOWED_ORDER_BY[sorted_by]
|
|
||||||
else:
|
|
||||||
order_by_clause = "updated_at DESC, rank DESC"
|
|
||||||
|
|
||||||
# Build WHERE conditions and parameters list
|
|
||||||
where_parts: list[str] = []
|
|
||||||
params: list[typing.Any] = [search_query] # $1 - search term
|
|
||||||
param_index = 2 # Start at $2 for next parameter
|
|
||||||
|
|
||||||
# Always filter for available agents
|
|
||||||
where_parts.append("is_available = true")
|
|
||||||
|
|
||||||
if featured:
|
if featured:
|
||||||
where_parts.append("featured = true")
|
filter_conditions.append("featured = true")
|
||||||
|
if creators:
|
||||||
|
creator_list = "','".join(creators)
|
||||||
|
filter_conditions.append(f"creator_username IN ('{creator_list}')")
|
||||||
|
if category:
|
||||||
|
filter_conditions.append(f"'{category}' = ANY(categories)")
|
||||||
|
|
||||||
if creators and creators:
|
where_filter = (
|
||||||
# Use ANY with array parameter
|
" AND ".join(filter_conditions) if filter_conditions else "1=1"
|
||||||
where_parts.append(f"creator_username = ANY(${param_index})")
|
)
|
||||||
params.append(creators)
|
|
||||||
param_index += 1
|
|
||||||
|
|
||||||
if category and category:
|
# Build ORDER BY clause
|
||||||
where_parts.append(f"${param_index} = ANY(categories)")
|
if sorted_by == "rating":
|
||||||
params.append(category)
|
order_by_clause = "rating DESC, rank DESC"
|
||||||
param_index += 1
|
elif sorted_by == "runs":
|
||||||
|
order_by_clause = "runs DESC, rank DESC"
|
||||||
|
elif sorted_by == "name":
|
||||||
|
order_by_clause = "agent_name ASC, rank DESC"
|
||||||
|
else:
|
||||||
|
order_by_clause = "rank DESC, updated_at DESC"
|
||||||
|
|
||||||
sql_where_clause: str = " AND ".join(where_parts) if where_parts else "1=1"
|
# Execute full-text search query
|
||||||
|
|
||||||
# Add pagination params
|
|
||||||
params.extend([page_size, offset])
|
|
||||||
limit_param = f"${param_index}"
|
|
||||||
offset_param = f"${param_index + 1}"
|
|
||||||
|
|
||||||
# Execute full-text search query with parameterized values
|
|
||||||
sql_query = f"""
|
sql_query = f"""
|
||||||
SELECT
|
SELECT
|
||||||
slug,
|
slug,
|
||||||
@@ -121,31 +135,29 @@ async def get_store_agents(
|
|||||||
updated_at,
|
updated_at,
|
||||||
ts_rank_cd(search, query) AS rank
|
ts_rank_cd(search, query) AS rank
|
||||||
FROM "StoreAgent",
|
FROM "StoreAgent",
|
||||||
plainto_tsquery('english', $1) AS query
|
plainto_tsquery('english', '{search_term}') AS query
|
||||||
WHERE {sql_where_clause}
|
WHERE {where_filter}
|
||||||
AND search @@ query
|
AND search @@ query
|
||||||
ORDER BY {order_by_clause}
|
ORDER BY rank DESC, {order_by_clause}
|
||||||
LIMIT {limit_param} OFFSET {offset_param}
|
LIMIT {page_size} OFFSET {offset}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Count query for pagination - only uses search term parameter
|
# Count query for pagination
|
||||||
count_query = f"""
|
count_query = f"""
|
||||||
SELECT COUNT(*) as count
|
SELECT COUNT(*) as count
|
||||||
FROM "StoreAgent",
|
FROM "StoreAgent",
|
||||||
plainto_tsquery('english', $1) AS query
|
plainto_tsquery('english', '{search_term}') AS query
|
||||||
WHERE {sql_where_clause}
|
WHERE {where_filter}
|
||||||
AND search @@ query
|
AND search @@ query
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Execute both queries with parameters
|
# Execute both queries
|
||||||
agents = await prisma.client.get_client().query_raw(
|
agents = await prisma.client.get_client().query_raw(
|
||||||
typing.cast(typing.LiteralString, sql_query), *params
|
query=typing.cast(typing.LiteralString, sql_query)
|
||||||
)
|
)
|
||||||
|
|
||||||
# For count, use params without pagination (last 2 params)
|
|
||||||
count_params = params[:-2]
|
|
||||||
count_result = await prisma.client.get_client().query_raw(
|
count_result = await prisma.client.get_client().query_raw(
|
||||||
typing.cast(typing.LiteralString, count_query), *count_params
|
query=typing.cast(typing.LiteralString, count_query)
|
||||||
)
|
)
|
||||||
|
|
||||||
total = count_result[0]["count"] if count_result else 0
|
total = count_result[0]["count"] if count_result else 0
|
||||||
@@ -422,7 +434,7 @@ async def get_store_agent_by_version_id(
|
|||||||
async def get_store_creators(
|
async def get_store_creators(
|
||||||
featured: bool = False,
|
featured: bool = False,
|
||||||
search_query: str | None = None,
|
search_query: str | None = None,
|
||||||
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
|
sorted_by: str | None = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
page_size: int = 20,
|
page_size: int = 20,
|
||||||
) -> backend.server.v2.store.model.CreatorsResponse:
|
) -> backend.server.v2.store.model.CreatorsResponse:
|
||||||
@@ -1736,21 +1748,22 @@ async def get_admin_listings_with_versions(
|
|||||||
if status:
|
if status:
|
||||||
where_dict["Versions"] = {"some": {"submissionStatus": status}}
|
where_dict["Versions"] = {"some": {"submissionStatus": status}}
|
||||||
|
|
||||||
if search_query:
|
sanitized_query = sanitize_query(search_query)
|
||||||
|
if sanitized_query:
|
||||||
# Find users with matching email
|
# Find users with matching email
|
||||||
matching_users = await prisma.models.User.prisma().find_many(
|
matching_users = await prisma.models.User.prisma().find_many(
|
||||||
where={"email": {"contains": search_query, "mode": "insensitive"}},
|
where={"email": {"contains": sanitized_query, "mode": "insensitive"}},
|
||||||
)
|
)
|
||||||
|
|
||||||
user_ids = [user.id for user in matching_users]
|
user_ids = [user.id for user in matching_users]
|
||||||
|
|
||||||
# Set up OR conditions
|
# Set up OR conditions
|
||||||
where_dict["OR"] = [
|
where_dict["OR"] = [
|
||||||
{"slug": {"contains": search_query, "mode": "insensitive"}},
|
{"slug": {"contains": sanitized_query, "mode": "insensitive"}},
|
||||||
{
|
{
|
||||||
"Versions": {
|
"Versions": {
|
||||||
"some": {
|
"some": {
|
||||||
"name": {"contains": search_query, "mode": "insensitive"}
|
"name": {"contains": sanitized_query, "mode": "insensitive"}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -1758,7 +1771,7 @@ async def get_admin_listings_with_versions(
|
|||||||
"Versions": {
|
"Versions": {
|
||||||
"some": {
|
"some": {
|
||||||
"description": {
|
"description": {
|
||||||
"contains": search_query,
|
"contains": sanitized_query,
|
||||||
"mode": "insensitive",
|
"mode": "insensitive",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1768,7 +1781,7 @@ async def get_admin_listings_with_versions(
|
|||||||
"Versions": {
|
"Versions": {
|
||||||
"some": {
|
"some": {
|
||||||
"subHeading": {
|
"subHeading": {
|
||||||
"contains": search_query,
|
"contains": sanitized_query,
|
||||||
"mode": "insensitive",
|
"mode": "insensitive",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user