mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-13 09:08:02 -05:00
Compare commits
2 Commits
dev
...
fix/sql-in
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
224411abd3 | ||
|
|
6b241af79e |
@@ -100,33 +100,50 @@ async def get_store_agents(
|
|||||||
|
|
||||||
offset = (page - 1) * page_size
|
offset = (page - 1) * page_size
|
||||||
|
|
||||||
# Build filter conditions
|
# Whitelist allowed order_by columns
|
||||||
filter_conditions = []
|
ALLOWED_ORDER_BY = {
|
||||||
filter_conditions.append("is_available = true")
|
"rating": "rating DESC, rank DESC",
|
||||||
|
"runs": "runs DESC, rank DESC",
|
||||||
|
"name": "agent_name ASC, rank DESC",
|
||||||
|
"updated_at": "updated_at DESC, rank DESC",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Validate and get order clause
|
||||||
|
if sorted_by and sorted_by in ALLOWED_ORDER_BY:
|
||||||
|
order_by_clause = ALLOWED_ORDER_BY[sorted_by]
|
||||||
|
else:
|
||||||
|
order_by_clause = "updated_at DESC, rank DESC",
|
||||||
|
|
||||||
|
# Build WHERE conditions and parameters list
|
||||||
|
where_parts: list[str] = []
|
||||||
|
params: list[typing.Any] = [search_term] # $1 - search term
|
||||||
|
param_index = 2 # Start at $2 for next parameter
|
||||||
|
|
||||||
|
# Always filter for available agents
|
||||||
|
where_parts.append("is_available = true")
|
||||||
|
|
||||||
if featured:
|
if featured:
|
||||||
filter_conditions.append("featured = true")
|
where_parts.append("featured = true")
|
||||||
if creators:
|
|
||||||
creator_list = "','".join(sanitized_creators)
|
|
||||||
filter_conditions.append(f"creator_username IN ('{creator_list}')")
|
|
||||||
if category:
|
|
||||||
filter_conditions.append(f"'{sanitized_category}' = ANY(categories)")
|
|
||||||
|
|
||||||
where_filter = (
|
if creators and sanitized_creators:
|
||||||
" AND ".join(filter_conditions) if filter_conditions else "1=1"
|
# Use ANY with array parameter
|
||||||
)
|
where_parts.append(f"creator_username = ANY(${param_index})")
|
||||||
|
params.append(sanitized_creators)
|
||||||
|
param_index += 1
|
||||||
|
|
||||||
# Build ORDER BY clause
|
if category and sanitized_category:
|
||||||
if sorted_by == "rating":
|
where_parts.append(f"${param_index} = ANY(categories)")
|
||||||
order_by_clause = "rating DESC, rank DESC"
|
params.append(sanitized_category)
|
||||||
elif sorted_by == "runs":
|
param_index += 1
|
||||||
order_by_clause = "runs DESC, rank DESC"
|
|
||||||
elif sorted_by == "name":
|
|
||||||
order_by_clause = "agent_name ASC, rank DESC"
|
|
||||||
else:
|
|
||||||
order_by_clause = "rank DESC, updated_at DESC"
|
|
||||||
|
|
||||||
# Execute full-text search query
|
sql_where_clause: str = " AND ".join(where_parts) if where_parts else "1=1"
|
||||||
|
|
||||||
|
# Add pagination params
|
||||||
|
params.extend([page_size, offset])
|
||||||
|
limit_param = f"${param_index}"
|
||||||
|
offset_param = f"${param_index + 1}"
|
||||||
|
|
||||||
|
# Execute full-text search query with parameterized values
|
||||||
sql_query = f"""
|
sql_query = f"""
|
||||||
SELECT
|
SELECT
|
||||||
slug,
|
slug,
|
||||||
@@ -144,29 +161,31 @@ async def get_store_agents(
|
|||||||
updated_at,
|
updated_at,
|
||||||
ts_rank_cd(search, query) AS rank
|
ts_rank_cd(search, query) AS rank
|
||||||
FROM "StoreAgent",
|
FROM "StoreAgent",
|
||||||
plainto_tsquery('english', '{search_term}') AS query
|
plainto_tsquery('english', $1) AS query
|
||||||
WHERE {where_filter}
|
WHERE {sql_where_clause}
|
||||||
AND search @@ query
|
AND search @@ query
|
||||||
ORDER BY rank DESC, {order_by_clause}
|
ORDER BY {order_by_clause}
|
||||||
LIMIT {page_size} OFFSET {offset}
|
LIMIT {limit_param} OFFSET {offset_param}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Count query for pagination
|
# Count query for pagination - only uses search term parameter
|
||||||
count_query = f"""
|
count_query = f"""
|
||||||
SELECT COUNT(*) as count
|
SELECT COUNT(*) as count
|
||||||
FROM "StoreAgent",
|
FROM "StoreAgent",
|
||||||
plainto_tsquery('english', '{search_term}') AS query
|
plainto_tsquery('english', $1) AS query
|
||||||
WHERE {where_filter}
|
WHERE {sql_where_clause}
|
||||||
AND search @@ query
|
AND search @@ query
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Execute both queries
|
# Execute both queries with parameters
|
||||||
agents = await prisma.client.get_client().query_raw(
|
agents = await prisma.client.get_client().query_raw(
|
||||||
query=typing.cast(typing.LiteralString, sql_query)
|
typing.cast(typing.LiteralString, sql_query), *params
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# For count, use params without pagination (last 2 params)
|
||||||
|
count_params = params[:-2]
|
||||||
count_result = await prisma.client.get_client().query_raw(
|
count_result = await prisma.client.get_client().query_raw(
|
||||||
query=typing.cast(typing.LiteralString, count_query)
|
typing.cast(typing.LiteralString, count_query), *count_params
|
||||||
)
|
)
|
||||||
|
|
||||||
total = count_result[0]["count"] if count_result else 0
|
total = count_result[0]["count"] if count_result else 0
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ async def setup_prisma():
|
|||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_get_store_agents(mocker):
|
async def test_get_store_agents(mocker):
|
||||||
# Mock data
|
# Mock data
|
||||||
mock_agents = [
|
mock_agents = [
|
||||||
@@ -64,7 +64,7 @@ async def test_get_store_agents(mocker):
|
|||||||
mock_store_agent.return_value.count.assert_called_once()
|
mock_store_agent.return_value.count.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_get_store_agent_details(mocker):
|
async def test_get_store_agent_details(mocker):
|
||||||
# Mock data
|
# Mock data
|
||||||
mock_agent = prisma.models.StoreAgent(
|
mock_agent = prisma.models.StoreAgent(
|
||||||
@@ -173,7 +173,7 @@ async def test_get_store_agent_details(mocker):
|
|||||||
mock_store_listing_db.return_value.find_first.assert_called_once()
|
mock_store_listing_db.return_value.find_first.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_get_store_creator_details(mocker):
|
async def test_get_store_creator_details(mocker):
|
||||||
# Mock data
|
# Mock data
|
||||||
mock_creator_data = prisma.models.Creator(
|
mock_creator_data = prisma.models.Creator(
|
||||||
@@ -210,7 +210,7 @@ async def test_get_store_creator_details(mocker):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_create_store_submission(mocker):
|
async def test_create_store_submission(mocker):
|
||||||
# Mock data
|
# Mock data
|
||||||
mock_agent = prisma.models.AgentGraph(
|
mock_agent = prisma.models.AgentGraph(
|
||||||
@@ -282,7 +282,7 @@ async def test_create_store_submission(mocker):
|
|||||||
mock_store_listing.return_value.create.assert_called_once()
|
mock_store_listing.return_value.create.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_update_profile(mocker):
|
async def test_update_profile(mocker):
|
||||||
# Mock data
|
# Mock data
|
||||||
mock_profile = prisma.models.Profile(
|
mock_profile = prisma.models.Profile(
|
||||||
@@ -327,7 +327,7 @@ async def test_update_profile(mocker):
|
|||||||
mock_profile_db.return_value.update.assert_called_once()
|
mock_profile_db.return_value.update.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_get_user_profile(mocker):
|
async def test_get_user_profile(mocker):
|
||||||
# Mock data
|
# Mock data
|
||||||
mock_profile = prisma.models.Profile(
|
mock_profile = prisma.models.Profile(
|
||||||
@@ -359,3 +359,63 @@ async def test_get_user_profile(mocker):
|
|||||||
assert result.description == "Test description"
|
assert result.description == "Test description"
|
||||||
assert result.links == ["link1", "link2"]
|
assert result.links == ["link1", "link2"]
|
||||||
assert result.avatar_url == "avatar.jpg"
|
assert result.avatar_url == "avatar.jpg"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_get_store_agents_with_search_parameterized(mocker):
|
||||||
|
"""Test that search query uses parameterized SQL - validates the fix works"""
|
||||||
|
|
||||||
|
# Call function with search query containing potential SQL injection
|
||||||
|
malicious_search = "test'; DROP TABLE StoreAgent; --"
|
||||||
|
result = await db.get_store_agents(search_query=malicious_search)
|
||||||
|
|
||||||
|
# Verify query executed safely
|
||||||
|
assert isinstance(result.agents, list)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_get_store_agents_with_search_and_filters_parameterized():
|
||||||
|
"""Test parameterized SQL with multiple filters"""
|
||||||
|
|
||||||
|
# Call with multiple filters including potential injection attempts
|
||||||
|
result = await db.get_store_agents(
|
||||||
|
search_query="test",
|
||||||
|
creators=["creator1'; DROP TABLE Users; --", "creator2"],
|
||||||
|
category="AI'; DELETE FROM StoreAgent; --",
|
||||||
|
featured=True,
|
||||||
|
sorted_by="rating",
|
||||||
|
page=1,
|
||||||
|
page_size=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the query executed without error
|
||||||
|
assert isinstance(result.agents, list)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_get_store_agents_search_with_invalid_sort_by():
|
||||||
|
"""Test that invalid sorted_by value doesn't cause SQL injection""" # Try to inject SQL via sorted_by parameter
|
||||||
|
malicious_sort = "rating; DROP TABLE Users; --"
|
||||||
|
result = await db.get_store_agents(
|
||||||
|
search_query="test",
|
||||||
|
sorted_by=malicious_sort,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the query executed without error
|
||||||
|
# Invalid sort_by should fall back to default, not cause SQL injection
|
||||||
|
assert isinstance(result.agents, list)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_get_store_agents_search_category_array_injection():
|
||||||
|
"""Test that category parameter is safely passed as a parameter"""
|
||||||
|
# Try SQL injection via category
|
||||||
|
malicious_category = "AI'; DROP TABLE StoreAgent; --"
|
||||||
|
result = await db.get_store_agents(
|
||||||
|
search_query="test",
|
||||||
|
category=malicious_category,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the query executed without error
|
||||||
|
# Category should be parameterized, preventing SQL injection
|
||||||
|
assert isinstance(result.agents, list)
|
||||||
|
|||||||
Reference in New Issue
Block a user