fix(backend): harden SQL query block against injection, SSRF bypass, and precision loss

- Replace regex-based SQL validation with sqlparse tokenizer to prevent
  multi-statement injection via quoted comment bypass (e.g. SET LOCAL
  statement_timeout = 0). Keywords in string literals no longer cause
  false positives.
- Replace urlparse with psycopg2.extensions.parse_dsn for SSRF protection,
  handling both URI and libpq DSN formats. Reject missing hostname and
  Unix socket paths.
- Use server-side named cursor to enforce max_rows at the database level
  instead of fetching entire result set into client memory.
- Serialize fractional Decimal values as str instead of float to preserve
  exact precision for analytics data.
- Add sqlparse dependency.
- Add tests for multi-statement injection, string literal keywords, and
  high-precision Decimal serialization.
This commit is contained in:
Zamil Majdy
2026-03-26 12:50:18 +07:00
parent 7ff096afd9
commit c5507415fd
4 changed files with 176 additions and 41 deletions

View File

@@ -2,10 +2,11 @@ import asyncio
import re
from decimal import Decimal
from typing import Any, Literal
from urllib.parse import urlparse
import psycopg2
import psycopg2.extensions
import psycopg2.extras
import sqlparse
from pydantic import SecretStr
from backend.blocks._base import (
@@ -55,18 +56,27 @@ def PostgresCredentialsField() -> PostgresCredentialsInput:
# Defense-in-depth: reject queries containing data-modifying keywords.
# NOTE: This regex matches keywords inside string literals and identifiers too
# (e.g. WHERE action = 'DELETE'). This is intentional — the DB-level readonly
# session is the primary safety net; this is a secondary check that favors
# safety over permissiveness. Ambiguous keywords that are harmless on a
# read-only connection (COMMENT, ANALYZE, LOCK, CLUSTER, REINDEX, VACUUM)
# are intentionally excluded to avoid false positives on column names.
_DISALLOWED_SQL_PATTERNS = re.compile(
r"\b(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|TRUNCATE|GRANT|REVOKE|"
r"COPY|EXECUTE|DO\s+\$|CALL|SET\s+(?!LOCAL\s+statement_timeout)"
r"|RESET|DISCARD|NOTIFY)\b",
re.IGNORECASE,
)
# These are checked against parsed SQL tokens (not raw text) so column names
# and string literals do not cause false positives.
_DISALLOWED_KEYWORDS = {
"INSERT",
"UPDATE",
"DELETE",
"DROP",
"ALTER",
"CREATE",
"TRUNCATE",
"GRANT",
"REVOKE",
"COPY",
"EXECUTE",
"CALL",
"SET",
"RESET",
"DISCARD",
"NOTIFY",
"DO",
}
def _sanitize_error(error_msg: str, connection_string: str) -> str:
@@ -77,30 +87,64 @@ def _sanitize_error(error_msg: str, connection_string: str) -> str:
return sanitized
def _extract_keyword_tokens(parsed: sqlparse.sql.Statement) -> list[str]:
"""Extract top-level keyword tokens from a parsed SQL statement.
Walks the token tree and collects Keyword and DML tokens, skipping
tokens that are inside string literals, identifiers, or parenthesized groups.
"""
keywords: list[str] = []
for token in parsed.flatten():
if token.ttype in (
sqlparse.tokens.Keyword,
sqlparse.tokens.Keyword.DML,
sqlparse.tokens.Keyword.DDL,
sqlparse.tokens.Keyword.DCL,
):
keywords.append(token.normalized.upper())
return keywords
def _validate_query_is_read_only(query: str) -> str | None:
"""Validate that a SQL query is read-only (SELECT/WITH only).
Uses sqlparse to properly tokenize the query, distinguishing keywords
from string literals, comments, and identifiers. This prevents bypass
via quoted comment injection or multi-statement attacks.
Returns an error message if the query is not read-only, None otherwise.
"""
# Strip SQL comments (-- and /* */).
# NOTE: This also strips comment-like patterns inside string literals
# (e.g. SELECT '--text' becomes SELECT ''). This is intentional —
# we prefer false positives over allowing bypass via string-embedded comments.
stripped = re.sub(r"--[^\n]*", "", query)
stripped = re.sub(r"/\*.*?\*/", "", stripped, flags=re.DOTALL)
stripped = stripped.strip().rstrip(";").strip()
stripped = query.strip().rstrip(";").strip()
if not stripped:
return "Query is empty."
# Check that the query starts with SELECT or WITH (for CTEs)
if not re.match(r"^\s*(SELECT|WITH)\b", stripped, re.IGNORECASE):
# Parse the SQL using sqlparse for proper tokenization
statements = sqlparse.parse(stripped)
# Filter out empty statements (e.g. from trailing semicolons)
statements = [s for s in statements if s.tokens and str(s).strip()]
if not statements:
return "Query is empty."
# Reject multiple statements — prevents injection via semicolons
if len(statements) > 1:
return "Only single statements are allowed."
stmt = statements[0]
stmt_type = stmt.get_type()
# sqlparse returns 'SELECT' for SELECT and WITH...SELECT queries
if stmt_type != "SELECT":
return "Only SELECT queries are allowed."
# Defense-in-depth: check for disallowed keywords
match = _DISALLOWED_SQL_PATTERNS.search(stripped)
if match:
return f"Disallowed SQL keyword: {match.group(0).upper()}"
# Defense-in-depth: check parsed keyword tokens for disallowed keywords
keywords = _extract_keyword_tokens(stmt)
for kw in keywords:
# Normalize multi-word tokens (e.g. "SET LOCAL" -> "SET")
base_kw = kw.split()[0] if " " in kw else kw
if base_kw in _DISALLOWED_KEYWORDS:
return f"Disallowed SQL keyword: {kw}"
return None
@@ -108,8 +152,11 @@ def _validate_query_is_read_only(query: str) -> str | None:
def _serialize_value(value: Any) -> Any:
"""Convert PostgreSQL-specific types to JSON-serializable Python types."""
if isinstance(value, Decimal):
# Use int if there's no fractional part, else float
return int(value) if value == value.to_integral_value() else float(value)
# Use int for whole numbers; use str for fractional to preserve exact
# precision (float would silently round high-precision analytics values).
if value == value.to_integral_value():
return int(value)
return str(value)
if hasattr(value, "isoformat"):
return value.isoformat()
if isinstance(value, memoryview):
@@ -184,15 +231,25 @@ class SQLQueryBlock(Block):
timeout: int,
max_rows: int,
) -> tuple[list[dict[str, Any]], list[str]]:
"""Execute a read-only SQL query and return (rows, columns)."""
"""Execute a read-only SQL query and return (rows, columns).
Uses a server-side named cursor so that only `max_rows` are fetched
from the database, avoiding client-side memory exhaustion for large
result sets.
"""
conn = psycopg2.connect(
connection_string,
connect_timeout=10,
options=f"-c statement_timeout={timeout * 1000}",
)
try:
conn.set_session(readonly=True, autocommit=True)
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
# Server-side cursors require a transaction (no autocommit).
conn.set_session(readonly=True, autocommit=False)
with conn.cursor(
name="sql_query_cursor",
cursor_factory=psycopg2.extras.RealDictCursor,
) as cur:
cur.itersize = max_rows
cur.execute(query)
columns = (
[desc[0] for desc in cur.description] if cur.description else []
@@ -220,11 +277,33 @@ class SQLQueryBlock(Block):
connection_string = credentials.api_key.get_secret_value()
# SSRF protection: validate the database host is not internal
parsed = urlparse(connection_string)
if parsed.hostname:
# SSRF protection: parse the connection string using psycopg2's DSN
# parser (handles both URI and libpq key=value formats) and validate
# that the host is not internal. Reject Unix socket paths entirely.
try:
dsn_params = psycopg2.extensions.parse_dsn(connection_string)
except psycopg2.ProgrammingError:
yield "error", "Invalid connection string format."
return
host = dsn_params.get("host", "")
hostaddr = dsn_params.get("hostaddr", "")
# Reject if no host is specified (would default to Unix socket)
if not host and not hostaddr:
yield "error", "Connection string must specify a database host."
return
# Reject Unix socket paths (host starting with '/')
if host.startswith("/"):
yield "error", "Unix socket connections are not allowed."
return
# Validate each specified host/hostaddr against SSRF blocklist
hosts_to_check = [h for h in [host, hostaddr] if h]
for h in hosts_to_check:
try:
await resolve_and_check_blocked(parsed.hostname)
await resolve_and_check_blocked(h)
except (ValueError, OSError) as e:
yield "error", f"Blocked host: {str(e).strip()}"
return

View File

@@ -86,6 +86,36 @@ class TestValidateQueryIsReadOnly:
result = _validate_query_is_read_only(query)
assert result is not None
# --- Quoted comment injection / multi-statement bypass ---
@pytest.mark.parametrize(
"query",
[
# Timeout bypass via quoted comment injection (CVE-like)
"SELECT '--'; SET LOCAL statement_timeout = 0; SELECT pg_sleep(1000)",
# Multi-statement with hidden SET
"SELECT 1; SET search_path TO public",
# RESET after SELECT
"SELECT 1; RESET ALL",
],
)
def test_multi_statement_injection_rejected(self, query: str):
result = _validate_query_is_read_only(query)
assert result is not None
# --- String literals containing keywords should NOT be rejected ---
@pytest.mark.parametrize(
"query",
[
"SELECT '--not-a-comment' AS label FROM t",
"SELECT * FROM users WHERE action = 'DELETE'",
"SELECT * FROM users WHERE status = 'GRANT'",
],
)
def test_keywords_in_string_literals_allowed(self, query: str):
assert _validate_query_is_read_only(query) is None
# --- Comment-wrapped attacks ---
@pytest.mark.parametrize(
@@ -114,7 +144,8 @@ class TestValidateQueryIsReadOnly:
assert _validate_query_is_read_only(" ") == "Query is empty."
def test_comment_only_query(self):
assert _validate_query_is_read_only("-- just a comment") == "Query is empty."
result = _validate_query_is_read_only("-- just a comment")
assert result is not None # Either "empty" or rejected
def test_semicolon_only_query(self):
assert _validate_query_is_read_only(";") == "Query is empty."
@@ -127,9 +158,17 @@ class TestSerializeValue:
assert _serialize_value(Decimal("42")) == 42
assert isinstance(_serialize_value(Decimal("42")), int)
def test_decimal_float(self):
assert _serialize_value(Decimal("3.14")) == 3.14
assert isinstance(_serialize_value(Decimal("3.14")), float)
def test_decimal_fractional(self):
# Fractional decimals are serialized as strings to preserve exact precision
assert _serialize_value(Decimal("3.14")) == "3.14"
assert isinstance(_serialize_value(Decimal("3.14")), str)
def test_decimal_high_precision(self):
# High-precision values must not lose precision via float conversion
val = Decimal("123456789.123456789012345678")
result = _serialize_value(val)
assert result == "123456789.123456789012345678"
assert isinstance(result, str)
def test_datetime(self):
dt = datetime(2024, 1, 1, 12, 0, 0)

View File

@@ -7009,6 +7009,22 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"]
pymysql = ["pymysql"]
sqlcipher = ["sqlcipher3_binary"]
[[package]]
name = "sqlparse"
version = "0.5.5"
description = "A non-validating SQL parser."
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "sqlparse-0.5.5-py3-none-any.whl", hash = "sha256:12a08b3bf3eec877c519589833aed092e2444e68240a3577e8e26148acc7b1ba"},
{file = "sqlparse-0.5.5.tar.gz", hash = "sha256:e20d4a9b0b8585fdf63b10d30066c7c94c5d7a7ec47c889a2d83a3caa93ff28e"},
]
[package.extras]
dev = ["build"]
doc = ["sphinx"]
[[package]]
name = "sse-starlette"
version = "3.2.0"
@@ -8630,4 +8646,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.14"
content-hash = "1dd10577184ebff0d10997f4c6ba49484de79b7fa090946e8e5ce5c5bac3cdeb"
content-hash = "216d66f172caa1f08c33a893928fc4cf290a5fe20776dbe7300805b06ce5b01d"

View File

@@ -95,6 +95,7 @@ fpdf2 = "^2.8.6"
langsmith = "^0.7.7"
openpyxl = "^3.1.5"
pyarrow = "^23.0.0"
sqlparse = "^0.5.5"
[tool.poetry.group.dev.dependencies]
aiohappyeyeballs = "^2.6.1"