mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(backend): Further enhance sanitization of SQL raw queries (#11279)
### Changes 🏗️ Enhanced SQL query security in the store search functionality by implementing proper parameterization to prevent SQL injection vulnerabilities. **Security Improvements:** - Replaced string interpolation with PostgreSQL positional parameters (`$1`, `$2`, etc.) for all user inputs - Added ORDER BY whitelist validation to prevent injection via `sorted_by` parameter - Parameterized search term, creators array, category, and pagination values - Fixed variable naming conflict (`sql_where_clause` vs `where_clause`) **Testing:** - Added 4 comprehensive tests validating SQL injection prevention across different attack vectors - Tests verify that malicious input in search queries, filters, sorting, and categories are safely handled - All 10 tests in db_test.py pass successfully ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] All existing tests pass (10/10 tests passing) - [x] New security tests validate SQL injection prevention - [x] Verified parameterized queries handle malicious input safely - [x] Code formatting passes (`poetry run format`) #### For configuration changes: - [x] `.env.default` is updated or already compatible with my changes - [x] `docker-compose.yml` is updated or already compatible with my changes - [x] I have included a list of my configuration changes in the PR description (under **Changes**) *Note: No configuration changes required for this security fix*
This commit is contained in:
@@ -38,25 +38,6 @@ DEFAULT_ADMIN_NAME = "AutoGPT Admin"
|
||||
DEFAULT_ADMIN_EMAIL = "admin@autogpt.co"
|
||||
|
||||
|
||||
def sanitize_query(query: str | None) -> str | None:
|
||||
if query is None:
|
||||
return query
|
||||
query = query.strip()[:100]
|
||||
return (
|
||||
query.replace("\\", "\\\\")
|
||||
.replace("%", "\\%")
|
||||
.replace("_", "\\_")
|
||||
.replace("[", "\\[")
|
||||
.replace("]", "\\]")
|
||||
.replace("'", "\\'")
|
||||
.replace('"', '\\"')
|
||||
.replace(";", "\\;")
|
||||
.replace("--", "\\--")
|
||||
.replace("/*", "\\/*")
|
||||
.replace("*/", "\\*/")
|
||||
)
|
||||
|
||||
|
||||
async def get_store_agents(
|
||||
featured: bool = False,
|
||||
creators: list[str] | None = None,
|
||||
@@ -73,60 +54,55 @@ async def get_store_agents(
|
||||
f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
|
||||
)
|
||||
|
||||
sanitized_creators = []
|
||||
if creators:
|
||||
for c in creators:
|
||||
sanitized_creators.append(sanitize_query(c))
|
||||
|
||||
sanitized_category = None
|
||||
if category:
|
||||
sanitized_category = sanitize_query(category)
|
||||
|
||||
try:
|
||||
# If search_query is provided, use full-text search
|
||||
if search_query:
|
||||
search_term = sanitize_query(search_query)
|
||||
if not search_term:
|
||||
# Return empty results for invalid search query
|
||||
return backend.server.v2.store.model.StoreAgentsResponse(
|
||||
agents=[],
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
current_page=page,
|
||||
total_items=0,
|
||||
total_pages=0,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# Build filter conditions
|
||||
filter_conditions = []
|
||||
filter_conditions.append("is_available = true")
|
||||
# Whitelist allowed order_by columns
|
||||
ALLOWED_ORDER_BY = {
|
||||
"rating": "rating DESC, rank DESC",
|
||||
"runs": "runs DESC, rank DESC",
|
||||
"name": "agent_name ASC, rank DESC",
|
||||
"updated_at": "updated_at DESC, rank DESC",
|
||||
}
|
||||
|
||||
# Validate and get order clause
|
||||
if sorted_by and sorted_by in ALLOWED_ORDER_BY:
|
||||
order_by_clause = ALLOWED_ORDER_BY[sorted_by]
|
||||
else:
|
||||
order_by_clause = "updated_at DESC, rank DESC"
|
||||
|
||||
# Build WHERE conditions and parameters list
|
||||
where_parts: list[str] = []
|
||||
params: list[typing.Any] = [search_query] # $1 - search term
|
||||
param_index = 2 # Start at $2 for next parameter
|
||||
|
||||
# Always filter for available agents
|
||||
where_parts.append("is_available = true")
|
||||
|
||||
if featured:
|
||||
filter_conditions.append("featured = true")
|
||||
if creators:
|
||||
creator_list = "','".join(sanitized_creators)
|
||||
filter_conditions.append(f"creator_username IN ('{creator_list}')")
|
||||
if category:
|
||||
filter_conditions.append(f"'{sanitized_category}' = ANY(categories)")
|
||||
where_parts.append("featured = true")
|
||||
|
||||
where_filter = (
|
||||
" AND ".join(filter_conditions) if filter_conditions else "1=1"
|
||||
)
|
||||
if creators and creators:
|
||||
# Use ANY with array parameter
|
||||
where_parts.append(f"creator_username = ANY(${param_index})")
|
||||
params.append(creators)
|
||||
param_index += 1
|
||||
|
||||
# Build ORDER BY clause
|
||||
if sorted_by == "rating":
|
||||
order_by_clause = "rating DESC, rank DESC"
|
||||
elif sorted_by == "runs":
|
||||
order_by_clause = "runs DESC, rank DESC"
|
||||
elif sorted_by == "name":
|
||||
order_by_clause = "agent_name ASC, rank DESC"
|
||||
else:
|
||||
order_by_clause = "rank DESC, updated_at DESC"
|
||||
if category and category:
|
||||
where_parts.append(f"${param_index} = ANY(categories)")
|
||||
params.append(category)
|
||||
param_index += 1
|
||||
|
||||
# Execute full-text search query
|
||||
sql_where_clause: str = " AND ".join(where_parts) if where_parts else "1=1"
|
||||
|
||||
# Add pagination params
|
||||
params.extend([page_size, offset])
|
||||
limit_param = f"${param_index}"
|
||||
offset_param = f"${param_index + 1}"
|
||||
|
||||
# Execute full-text search query with parameterized values
|
||||
sql_query = f"""
|
||||
SELECT
|
||||
slug,
|
||||
@@ -144,29 +120,31 @@ async def get_store_agents(
|
||||
updated_at,
|
||||
ts_rank_cd(search, query) AS rank
|
||||
FROM "StoreAgent",
|
||||
plainto_tsquery('english', '{search_term}') AS query
|
||||
WHERE {where_filter}
|
||||
plainto_tsquery('english', $1) AS query
|
||||
WHERE {sql_where_clause}
|
||||
AND search @@ query
|
||||
ORDER BY rank DESC, {order_by_clause}
|
||||
LIMIT {page_size} OFFSET {offset}
|
||||
ORDER BY {order_by_clause}
|
||||
LIMIT {limit_param} OFFSET {offset_param}
|
||||
"""
|
||||
|
||||
# Count query for pagination
|
||||
# Count query for pagination - only uses search term parameter
|
||||
count_query = f"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM "StoreAgent",
|
||||
plainto_tsquery('english', '{search_term}') AS query
|
||||
WHERE {where_filter}
|
||||
plainto_tsquery('english', $1) AS query
|
||||
WHERE {sql_where_clause}
|
||||
AND search @@ query
|
||||
"""
|
||||
|
||||
# Execute both queries
|
||||
# Execute both queries with parameters
|
||||
agents = await prisma.client.get_client().query_raw(
|
||||
query=typing.cast(typing.LiteralString, sql_query)
|
||||
typing.cast(typing.LiteralString, sql_query), *params
|
||||
)
|
||||
|
||||
# For count, use params without pagination (last 2 params)
|
||||
count_params = params[:-2]
|
||||
count_result = await prisma.client.get_client().query_raw(
|
||||
query=typing.cast(typing.LiteralString, count_query)
|
||||
typing.cast(typing.LiteralString, count_query), *count_params
|
||||
)
|
||||
|
||||
total = count_result[0]["count"] if count_result else 0
|
||||
@@ -200,9 +178,9 @@ async def get_store_agents(
|
||||
if featured:
|
||||
where_clause["featured"] = featured
|
||||
if creators:
|
||||
where_clause["creator_username"] = {"in": sanitized_creators}
|
||||
if sanitized_category:
|
||||
where_clause["categories"] = {"has": sanitized_category}
|
||||
where_clause["creator_username"] = {"in": creators}
|
||||
if category:
|
||||
where_clause["categories"] = {"has": category}
|
||||
|
||||
order_by = []
|
||||
if sorted_by == "rating":
|
||||
@@ -1757,22 +1735,21 @@ async def get_admin_listings_with_versions(
|
||||
if status:
|
||||
where_dict["Versions"] = {"some": {"submissionStatus": status}}
|
||||
|
||||
sanitized_query = sanitize_query(search_query)
|
||||
if sanitized_query:
|
||||
if search_query:
|
||||
# Find users with matching email
|
||||
matching_users = await prisma.models.User.prisma().find_many(
|
||||
where={"email": {"contains": sanitized_query, "mode": "insensitive"}},
|
||||
where={"email": {"contains": search_query, "mode": "insensitive"}},
|
||||
)
|
||||
|
||||
user_ids = [user.id for user in matching_users]
|
||||
|
||||
# Set up OR conditions
|
||||
where_dict["OR"] = [
|
||||
{"slug": {"contains": sanitized_query, "mode": "insensitive"}},
|
||||
{"slug": {"contains": search_query, "mode": "insensitive"}},
|
||||
{
|
||||
"Versions": {
|
||||
"some": {
|
||||
"name": {"contains": sanitized_query, "mode": "insensitive"}
|
||||
"name": {"contains": search_query, "mode": "insensitive"}
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -1780,7 +1757,7 @@ async def get_admin_listings_with_versions(
|
||||
"Versions": {
|
||||
"some": {
|
||||
"description": {
|
||||
"contains": sanitized_query,
|
||||
"contains": search_query,
|
||||
"mode": "insensitive",
|
||||
}
|
||||
}
|
||||
@@ -1790,7 +1767,7 @@ async def get_admin_listings_with_versions(
|
||||
"Versions": {
|
||||
"some": {
|
||||
"subHeading": {
|
||||
"contains": sanitized_query,
|
||||
"contains": search_query,
|
||||
"mode": "insensitive",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ async def setup_prisma():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_agents(mocker):
|
||||
# Mock data
|
||||
mock_agents = [
|
||||
@@ -64,7 +64,7 @@ async def test_get_store_agents(mocker):
|
||||
mock_store_agent.return_value.count.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_agent_details(mocker):
|
||||
# Mock data
|
||||
mock_agent = prisma.models.StoreAgent(
|
||||
@@ -173,7 +173,7 @@ async def test_get_store_agent_details(mocker):
|
||||
mock_store_listing_db.return_value.find_first.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_creator_details(mocker):
|
||||
# Mock data
|
||||
mock_creator_data = prisma.models.Creator(
|
||||
@@ -210,7 +210,7 @@ async def test_get_store_creator_details(mocker):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_create_store_submission(mocker):
|
||||
# Mock data
|
||||
mock_agent = prisma.models.AgentGraph(
|
||||
@@ -282,7 +282,7 @@ async def test_create_store_submission(mocker):
|
||||
mock_store_listing.return_value.create.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_update_profile(mocker):
|
||||
# Mock data
|
||||
mock_profile = prisma.models.Profile(
|
||||
@@ -327,7 +327,7 @@ async def test_update_profile(mocker):
|
||||
mock_profile_db.return_value.update.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_user_profile(mocker):
|
||||
# Mock data
|
||||
mock_profile = prisma.models.Profile(
|
||||
@@ -359,3 +359,63 @@ async def test_get_user_profile(mocker):
|
||||
assert result.description == "Test description"
|
||||
assert result.links == ["link1", "link2"]
|
||||
assert result.avatar_url == "avatar.jpg"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_agents_with_search_parameterized(mocker):
|
||||
"""Test that search query uses parameterized SQL - validates the fix works"""
|
||||
|
||||
# Call function with search query containing potential SQL injection
|
||||
malicious_search = "test'; DROP TABLE StoreAgent; --"
|
||||
result = await db.get_store_agents(search_query=malicious_search)
|
||||
|
||||
# Verify query executed safely
|
||||
assert isinstance(result.agents, list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_agents_with_search_and_filters_parameterized():
|
||||
"""Test parameterized SQL with multiple filters"""
|
||||
|
||||
# Call with multiple filters including potential injection attempts
|
||||
result = await db.get_store_agents(
|
||||
search_query="test",
|
||||
creators=["creator1'; DROP TABLE Users; --", "creator2"],
|
||||
category="AI'; DELETE FROM StoreAgent; --",
|
||||
featured=True,
|
||||
sorted_by="rating",
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Verify the query executed without error
|
||||
assert isinstance(result.agents, list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_agents_search_with_invalid_sort_by():
|
||||
"""Test that invalid sorted_by value doesn't cause SQL injection""" # Try to inject SQL via sorted_by parameter
|
||||
malicious_sort = "rating; DROP TABLE Users; --"
|
||||
result = await db.get_store_agents(
|
||||
search_query="test",
|
||||
sorted_by=malicious_sort,
|
||||
)
|
||||
|
||||
# Verify the query executed without error
|
||||
# Invalid sort_by should fall back to default, not cause SQL injection
|
||||
assert isinstance(result.agents, list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_agents_search_category_array_injection():
|
||||
"""Test that category parameter is safely passed as a parameter"""
|
||||
# Try SQL injection via category
|
||||
malicious_category = "AI'; DROP TABLE StoreAgent; --"
|
||||
result = await db.get_store_agents(
|
||||
search_query="test",
|
||||
category=malicious_category,
|
||||
)
|
||||
|
||||
# Verify the query executed without error
|
||||
# Category should be parameterized, preventing SQL injection
|
||||
assert isinstance(result.agents, list)
|
||||
|
||||
Reference in New Issue
Block a user