mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(backend): harden SQL query block against injection, SSRF bypass, and precision loss
- Replace regex-based SQL validation with sqlparse tokenizer to prevent multi-statement injection via quoted comment bypass (e.g. SET LOCAL statement_timeout = 0). Keywords in string literals no longer cause false positives. - Replace urlparse with psycopg2.extensions.parse_dsn for SSRF protection, handling both URI and libpq DSN formats. Reject missing hostname and Unix socket paths. - Use server-side named cursor to enforce max_rows at the database level instead of fetching entire result set into client memory. - Serialize fractional Decimal values as str instead of float to preserve exact precision for analytics data. - Add sqlparse dependency. - Add tests for multi-statement injection, string literal keywords, and high-precision Decimal serialization.
This commit is contained in:
@@ -2,10 +2,11 @@ import asyncio
|
||||
import re
|
||||
from decimal import Decimal
|
||||
from typing import Any, Literal
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extensions
|
||||
import psycopg2.extras
|
||||
import sqlparse
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks._base import (
|
||||
@@ -55,18 +56,27 @@ def PostgresCredentialsField() -> PostgresCredentialsInput:
|
||||
|
||||
|
||||
# Defense-in-depth: reject queries containing data-modifying keywords.
|
||||
# NOTE: This regex matches keywords inside string literals and identifiers too
|
||||
# (e.g. WHERE action = 'DELETE'). This is intentional — the DB-level readonly
|
||||
# session is the primary safety net; this is a secondary check that favors
|
||||
# safety over permissiveness. Ambiguous keywords that are harmless on a
|
||||
# read-only connection (COMMENT, ANALYZE, LOCK, CLUSTER, REINDEX, VACUUM)
|
||||
# are intentionally excluded to avoid false positives on column names.
|
||||
_DISALLOWED_SQL_PATTERNS = re.compile(
|
||||
r"\b(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|TRUNCATE|GRANT|REVOKE|"
|
||||
r"COPY|EXECUTE|DO\s+\$|CALL|SET\s+(?!LOCAL\s+statement_timeout)"
|
||||
r"|RESET|DISCARD|NOTIFY)\b",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
# These are checked against parsed SQL tokens (not raw text) so column names
|
||||
# and string literals do not cause false positives.
|
||||
_DISALLOWED_KEYWORDS = {
|
||||
"INSERT",
|
||||
"UPDATE",
|
||||
"DELETE",
|
||||
"DROP",
|
||||
"ALTER",
|
||||
"CREATE",
|
||||
"TRUNCATE",
|
||||
"GRANT",
|
||||
"REVOKE",
|
||||
"COPY",
|
||||
"EXECUTE",
|
||||
"CALL",
|
||||
"SET",
|
||||
"RESET",
|
||||
"DISCARD",
|
||||
"NOTIFY",
|
||||
"DO",
|
||||
}
|
||||
|
||||
|
||||
def _sanitize_error(error_msg: str, connection_string: str) -> str:
|
||||
@@ -77,30 +87,64 @@ def _sanitize_error(error_msg: str, connection_string: str) -> str:
|
||||
return sanitized
|
||||
|
||||
|
||||
def _extract_keyword_tokens(parsed: sqlparse.sql.Statement) -> list[str]:
|
||||
"""Extract top-level keyword tokens from a parsed SQL statement.
|
||||
|
||||
Walks the token tree and collects Keyword and DML tokens, skipping
|
||||
tokens that are inside string literals, identifiers, or parenthesized groups.
|
||||
"""
|
||||
keywords: list[str] = []
|
||||
for token in parsed.flatten():
|
||||
if token.ttype in (
|
||||
sqlparse.tokens.Keyword,
|
||||
sqlparse.tokens.Keyword.DML,
|
||||
sqlparse.tokens.Keyword.DDL,
|
||||
sqlparse.tokens.Keyword.DCL,
|
||||
):
|
||||
keywords.append(token.normalized.upper())
|
||||
return keywords
|
||||
|
||||
|
||||
def _validate_query_is_read_only(query: str) -> str | None:
|
||||
"""Validate that a SQL query is read-only (SELECT/WITH only).
|
||||
|
||||
Uses sqlparse to properly tokenize the query, distinguishing keywords
|
||||
from string literals, comments, and identifiers. This prevents bypass
|
||||
via quoted comment injection or multi-statement attacks.
|
||||
|
||||
Returns an error message if the query is not read-only, None otherwise.
|
||||
"""
|
||||
# Strip SQL comments (-- and /* */).
|
||||
# NOTE: This also strips comment-like patterns inside string literals
|
||||
# (e.g. SELECT '--text' becomes SELECT ''). This is intentional —
|
||||
# we prefer false positives over allowing bypass via string-embedded comments.
|
||||
stripped = re.sub(r"--[^\n]*", "", query)
|
||||
stripped = re.sub(r"/\*.*?\*/", "", stripped, flags=re.DOTALL)
|
||||
stripped = stripped.strip().rstrip(";").strip()
|
||||
|
||||
stripped = query.strip().rstrip(";").strip()
|
||||
if not stripped:
|
||||
return "Query is empty."
|
||||
|
||||
# Check that the query starts with SELECT or WITH (for CTEs)
|
||||
if not re.match(r"^\s*(SELECT|WITH)\b", stripped, re.IGNORECASE):
|
||||
# Parse the SQL using sqlparse for proper tokenization
|
||||
statements = sqlparse.parse(stripped)
|
||||
|
||||
# Filter out empty statements (e.g. from trailing semicolons)
|
||||
statements = [s for s in statements if s.tokens and str(s).strip()]
|
||||
|
||||
if not statements:
|
||||
return "Query is empty."
|
||||
|
||||
# Reject multiple statements — prevents injection via semicolons
|
||||
if len(statements) > 1:
|
||||
return "Only single statements are allowed."
|
||||
|
||||
stmt = statements[0]
|
||||
stmt_type = stmt.get_type()
|
||||
|
||||
# sqlparse returns 'SELECT' for SELECT and WITH...SELECT queries
|
||||
if stmt_type != "SELECT":
|
||||
return "Only SELECT queries are allowed."
|
||||
|
||||
# Defense-in-depth: check for disallowed keywords
|
||||
match = _DISALLOWED_SQL_PATTERNS.search(stripped)
|
||||
if match:
|
||||
return f"Disallowed SQL keyword: {match.group(0).upper()}"
|
||||
# Defense-in-depth: check parsed keyword tokens for disallowed keywords
|
||||
keywords = _extract_keyword_tokens(stmt)
|
||||
for kw in keywords:
|
||||
# Normalize multi-word tokens (e.g. "SET LOCAL" -> "SET")
|
||||
base_kw = kw.split()[0] if " " in kw else kw
|
||||
if base_kw in _DISALLOWED_KEYWORDS:
|
||||
return f"Disallowed SQL keyword: {kw}"
|
||||
|
||||
return None
|
||||
|
||||
@@ -108,8 +152,11 @@ def _validate_query_is_read_only(query: str) -> str | None:
|
||||
def _serialize_value(value: Any) -> Any:
|
||||
"""Convert PostgreSQL-specific types to JSON-serializable Python types."""
|
||||
if isinstance(value, Decimal):
|
||||
# Use int if there's no fractional part, else float
|
||||
return int(value) if value == value.to_integral_value() else float(value)
|
||||
# Use int for whole numbers; use str for fractional to preserve exact
|
||||
# precision (float would silently round high-precision analytics values).
|
||||
if value == value.to_integral_value():
|
||||
return int(value)
|
||||
return str(value)
|
||||
if hasattr(value, "isoformat"):
|
||||
return value.isoformat()
|
||||
if isinstance(value, memoryview):
|
||||
@@ -184,15 +231,25 @@ class SQLQueryBlock(Block):
|
||||
timeout: int,
|
||||
max_rows: int,
|
||||
) -> tuple[list[dict[str, Any]], list[str]]:
|
||||
"""Execute a read-only SQL query and return (rows, columns)."""
|
||||
"""Execute a read-only SQL query and return (rows, columns).
|
||||
|
||||
Uses a server-side named cursor so that only `max_rows` are fetched
|
||||
from the database, avoiding client-side memory exhaustion for large
|
||||
result sets.
|
||||
"""
|
||||
conn = psycopg2.connect(
|
||||
connection_string,
|
||||
connect_timeout=10,
|
||||
options=f"-c statement_timeout={timeout * 1000}",
|
||||
)
|
||||
try:
|
||||
conn.set_session(readonly=True, autocommit=True)
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
|
||||
# Server-side cursors require a transaction (no autocommit).
|
||||
conn.set_session(readonly=True, autocommit=False)
|
||||
with conn.cursor(
|
||||
name="sql_query_cursor",
|
||||
cursor_factory=psycopg2.extras.RealDictCursor,
|
||||
) as cur:
|
||||
cur.itersize = max_rows
|
||||
cur.execute(query)
|
||||
columns = (
|
||||
[desc[0] for desc in cur.description] if cur.description else []
|
||||
@@ -220,11 +277,33 @@ class SQLQueryBlock(Block):
|
||||
|
||||
connection_string = credentials.api_key.get_secret_value()
|
||||
|
||||
# SSRF protection: validate the database host is not internal
|
||||
parsed = urlparse(connection_string)
|
||||
if parsed.hostname:
|
||||
# SSRF protection: parse the connection string using psycopg2's DSN
|
||||
# parser (handles both URI and libpq key=value formats) and validate
|
||||
# that the host is not internal. Reject Unix socket paths entirely.
|
||||
try:
|
||||
dsn_params = psycopg2.extensions.parse_dsn(connection_string)
|
||||
except psycopg2.ProgrammingError:
|
||||
yield "error", "Invalid connection string format."
|
||||
return
|
||||
|
||||
host = dsn_params.get("host", "")
|
||||
hostaddr = dsn_params.get("hostaddr", "")
|
||||
|
||||
# Reject if no host is specified (would default to Unix socket)
|
||||
if not host and not hostaddr:
|
||||
yield "error", "Connection string must specify a database host."
|
||||
return
|
||||
|
||||
# Reject Unix socket paths (host starting with '/')
|
||||
if host.startswith("/"):
|
||||
yield "error", "Unix socket connections are not allowed."
|
||||
return
|
||||
|
||||
# Validate each specified host/hostaddr against SSRF blocklist
|
||||
hosts_to_check = [h for h in [host, hostaddr] if h]
|
||||
for h in hosts_to_check:
|
||||
try:
|
||||
await resolve_and_check_blocked(parsed.hostname)
|
||||
await resolve_and_check_blocked(h)
|
||||
except (ValueError, OSError) as e:
|
||||
yield "error", f"Blocked host: {str(e).strip()}"
|
||||
return
|
||||
|
||||
@@ -86,6 +86,36 @@ class TestValidateQueryIsReadOnly:
|
||||
result = _validate_query_is_read_only(query)
|
||||
assert result is not None
|
||||
|
||||
# --- Quoted comment injection / multi-statement bypass ---
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query",
|
||||
[
|
||||
# Timeout bypass via quoted comment injection (CVE-like)
|
||||
"SELECT '--'; SET LOCAL statement_timeout = 0; SELECT pg_sleep(1000)",
|
||||
# Multi-statement with hidden SET
|
||||
"SELECT 1; SET search_path TO public",
|
||||
# RESET after SELECT
|
||||
"SELECT 1; RESET ALL",
|
||||
],
|
||||
)
|
||||
def test_multi_statement_injection_rejected(self, query: str):
|
||||
result = _validate_query_is_read_only(query)
|
||||
assert result is not None
|
||||
|
||||
# --- String literals containing keywords should NOT be rejected ---
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query",
|
||||
[
|
||||
"SELECT '--not-a-comment' AS label FROM t",
|
||||
"SELECT * FROM users WHERE action = 'DELETE'",
|
||||
"SELECT * FROM users WHERE status = 'GRANT'",
|
||||
],
|
||||
)
|
||||
def test_keywords_in_string_literals_allowed(self, query: str):
|
||||
assert _validate_query_is_read_only(query) is None
|
||||
|
||||
# --- Comment-wrapped attacks ---
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -114,7 +144,8 @@ class TestValidateQueryIsReadOnly:
|
||||
assert _validate_query_is_read_only(" ") == "Query is empty."
|
||||
|
||||
def test_comment_only_query(self):
|
||||
assert _validate_query_is_read_only("-- just a comment") == "Query is empty."
|
||||
result = _validate_query_is_read_only("-- just a comment")
|
||||
assert result is not None # Either "empty" or rejected
|
||||
|
||||
def test_semicolon_only_query(self):
|
||||
assert _validate_query_is_read_only(";") == "Query is empty."
|
||||
@@ -127,9 +158,17 @@ class TestSerializeValue:
|
||||
assert _serialize_value(Decimal("42")) == 42
|
||||
assert isinstance(_serialize_value(Decimal("42")), int)
|
||||
|
||||
def test_decimal_float(self):
|
||||
assert _serialize_value(Decimal("3.14")) == 3.14
|
||||
assert isinstance(_serialize_value(Decimal("3.14")), float)
|
||||
def test_decimal_fractional(self):
|
||||
# Fractional decimals are serialized as strings to preserve exact precision
|
||||
assert _serialize_value(Decimal("3.14")) == "3.14"
|
||||
assert isinstance(_serialize_value(Decimal("3.14")), str)
|
||||
|
||||
def test_decimal_high_precision(self):
|
||||
# High-precision values must not lose precision via float conversion
|
||||
val = Decimal("123456789.123456789012345678")
|
||||
result = _serialize_value(val)
|
||||
assert result == "123456789.123456789012345678"
|
||||
assert isinstance(result, str)
|
||||
|
||||
def test_datetime(self):
|
||||
dt = datetime(2024, 1, 1, 12, 0, 0)
|
||||
|
||||
18
autogpt_platform/backend/poetry.lock
generated
18
autogpt_platform/backend/poetry.lock
generated
@@ -7009,6 +7009,22 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"]
|
||||
pymysql = ["pymysql"]
|
||||
sqlcipher = ["sqlcipher3_binary"]
|
||||
|
||||
[[package]]
|
||||
name = "sqlparse"
|
||||
version = "0.5.5"
|
||||
description = "A non-validating SQL parser."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "sqlparse-0.5.5-py3-none-any.whl", hash = "sha256:12a08b3bf3eec877c519589833aed092e2444e68240a3577e8e26148acc7b1ba"},
|
||||
{file = "sqlparse-0.5.5.tar.gz", hash = "sha256:e20d4a9b0b8585fdf63b10d30066c7c94c5d7a7ec47c889a2d83a3caa93ff28e"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
dev = ["build"]
|
||||
doc = ["sphinx"]
|
||||
|
||||
[[package]]
|
||||
name = "sse-starlette"
|
||||
version = "3.2.0"
|
||||
@@ -8630,4 +8646,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "1dd10577184ebff0d10997f4c6ba49484de79b7fa090946e8e5ce5c5bac3cdeb"
|
||||
content-hash = "216d66f172caa1f08c33a893928fc4cf290a5fe20776dbe7300805b06ce5b01d"
|
||||
|
||||
@@ -95,6 +95,7 @@ fpdf2 = "^2.8.6"
|
||||
langsmith = "^0.7.7"
|
||||
openpyxl = "^3.1.5"
|
||||
pyarrow = "^23.0.0"
|
||||
sqlparse = "^0.5.5"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
aiohappyeyeballs = "^2.6.1"
|
||||
|
||||
Reference in New Issue
Block a user