refactor(backend): split SQL query block into block + helpers module

- Extract validation, sanitization, serialization, and query execution
  into sql_query_helpers.py to meet ~300-line file guideline
- Fix duck typing in _serialize_value: replace hasattr(value, "isoformat")
  with explicit isinstance(value, (datetime, date, time))
- Extract _configure_session and _run_in_transaction helpers to bring
  execute_query under ~40-line function guideline
- Extract _validate_query, _resolve_host, _format_operational_error
  helpers to simplify the run method
- Add database name scrubbing to _sanitize_error
This commit is contained in:
Zamil Majdy
2026-04-01 08:18:15 +02:00
parent 8de9880f43
commit 991969612c
3 changed files with 410 additions and 381 deletions

View File

@@ -1,12 +1,7 @@
import asyncio
import re
from decimal import Decimal
from enum import Enum
from typing import Any, Literal
import sqlparse
from pydantic import SecretStr
from sqlalchemy import create_engine, text
from sqlalchemy.engine.url import URL
from sqlalchemy.exc import DBAPIError, OperationalError, ProgrammingError
@@ -17,6 +12,15 @@ from backend.blocks._base import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.sql_query_helpers import (
_DATABASE_TYPE_DEFAULT_PORT,
_DATABASE_TYPE_TO_DRIVER,
DatabaseType,
_execute_query,
_sanitize_error,
_validate_query_is_read_only,
_validate_single_statement,
)
from backend.data.model import (
CredentialsField,
CredentialsMetaInput,
@@ -54,219 +58,6 @@ def DatabaseCredentialsField() -> DatabaseCredentialsInput:
)
class DatabaseType(str, Enum):
POSTGRES = "postgres"
MYSQL = "mysql"
MSSQL = "mssql"
# Defense-in-depth: reject queries containing data-modifying keywords.
# These are checked against parsed SQL tokens (not raw text) so column names
# and string literals do not cause false positives.
_DISALLOWED_KEYWORDS = {
"INSERT",
"UPDATE",
"DELETE",
"DROP",
"ALTER",
"CREATE",
"TRUNCATE",
"GRANT",
"REVOKE",
"COPY",
"EXECUTE",
"CALL",
"SET",
"RESET",
"DISCARD",
"NOTIFY",
"DO",
# SELECT ... INTO creates tables (PG/MSSQL) or writes files (MySQL OUTFILE/DUMPFILE)
"INTO",
"OUTFILE",
"DUMPFILE",
}
def _sanitize_error(
error_msg: str,
connection_string: str,
*,
host: str = "",
original_host: str = "",
username: str = "",
port: int = 0,
) -> str:
"""Remove connection string, credentials, and infrastructure details
from error messages so they are safe to expose to the LLM.
Scrubs:
- The full connection string
- URL-embedded credentials (``://user:pass@``)
- ``password=<value>`` key-value pairs
- The database hostname / IP used for the connection
- The original (pre-resolution) hostname provided by the user
- Any IPv4 addresses that appear in the message
- Any bracketed IPv6 addresses (e.g. ``[::1]``, ``[fe80::1%eth0]``)
- The database username
- The database port number
"""
sanitized = error_msg.replace(connection_string, "<connection_string>")
sanitized = re.sub(r"password=[^\s&]+", "password=***", sanitized)
sanitized = re.sub(r"://[^@]+@", "://***:***@", sanitized)
# Replace the known host (may be an IP already) before the generic IP pass.
# Also replace the original (pre-DNS-resolution) hostname if it differs.
if original_host and original_host != host:
sanitized = sanitized.replace(original_host, "<host>")
if host:
sanitized = sanitized.replace(host, "<host>")
# Replace any remaining IPv4 addresses (e.g. resolved IPs the driver logs)
sanitized = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", "<ip>", sanitized)
# Replace bracketed IPv6 addresses (e.g. "[::1]", "[fe80::1%eth0]")
sanitized = re.sub(r"\[[0-9a-fA-F:]+(?:%[^\]]+)?\]", "<ip>", sanitized)
# Replace the database username (handles double-quoted, single-quoted,
# and unquoted formats across PostgreSQL, MySQL, and MSSQL error messages).
if username:
sanitized = re.sub(
r"""for user ["']?""" + re.escape(username) + r"""["']?""",
"for user <user>",
sanitized,
)
# Catch remaining bare occurrences in various quote styles:
# - PostgreSQL: "FATAL: role "myuser" does not exist"
# - MySQL: "Access denied for user 'myuser'@'host'"
# - MSSQL: "Login failed for user 'myuser'"
sanitized = sanitized.replace(f'"{username}"', "<user>")
sanitized = sanitized.replace(f"'{username}'", "<user>")
# Replace the port number (handles "port 5432" and ":5432" formats)
if port:
port_str = re.escape(str(port))
sanitized = re.sub(
r"(?:port |:)" + port_str + r"(?![0-9])",
lambda m: ("port " if m.group().startswith("p") else ":") + "<port>",
sanitized,
)
return sanitized
def _extract_keyword_tokens(parsed: sqlparse.sql.Statement) -> list[str]:
"""Extract keyword tokens from a parsed SQL statement.
Uses sqlparse token type classification to collect Keyword/DML/DDL/DCL
tokens. String literals and identifiers have different token types, so
they are naturally excluded from the result.
"""
keywords: list[str] = []
for token in parsed.flatten():
if token.ttype in (
sqlparse.tokens.Keyword,
sqlparse.tokens.Keyword.DML,
sqlparse.tokens.Keyword.DDL,
sqlparse.tokens.Keyword.DCL,
):
keywords.append(token.normalized.upper())
return keywords
def _validate_query_is_read_only(stmt: sqlparse.sql.Statement) -> str | None:
"""Validate that a parsed SQL statement is read-only (SELECT/WITH only).
Accepts an already-parsed statement from ``_validate_single_statement``
to avoid re-parsing. Checks:
1. Statement type must be SELECT (sqlparse classifies WITH...SELECT as SELECT)
2. No disallowed keywords (INSERT, UPDATE, DELETE, DROP, INTO, etc.)
Returns an error message if the query is not read-only, None otherwise.
"""
# sqlparse returns 'SELECT' for SELECT and WITH...SELECT queries
if stmt.get_type() != "SELECT":
return "Only SELECT queries are allowed."
# Defense-in-depth: check parsed keyword tokens for disallowed keywords
keywords = _extract_keyword_tokens(stmt)
for kw in keywords:
# Normalize multi-word tokens (e.g. "SET LOCAL" -> "SET")
base_kw = kw.split()[0] if " " in kw else kw
if base_kw in _DISALLOWED_KEYWORDS:
return f"Disallowed SQL keyword: {kw}"
return None
def _validate_single_statement(
query: str,
) -> tuple[str | None, sqlparse.sql.Statement | None]:
"""Validate that the query contains exactly one non-empty SQL statement.
Returns (error_message, parsed_statement). If error_message is not None,
the query is invalid and parsed_statement will be None.
"""
stripped = query.strip().rstrip(";").strip()
if not stripped:
return "Query is empty.", None
# Parse the SQL using sqlparse for proper tokenization
statements = sqlparse.parse(stripped)
# Filter out empty statements and comment-only statements
statements = [
s
for s in statements
if s.tokens
and str(s).strip()
and not all(
t.is_whitespace or t.ttype in sqlparse.tokens.Comment for t in s.flatten()
)
]
if not statements:
return "Query is empty.", None
# Reject multiple statements -- prevents injection via semicolons
if len(statements) > 1:
return "Only single statements are allowed.", None
return None, statements[0]
def _serialize_value(value: Any) -> Any:
"""Convert database-specific types to JSON-serializable Python types."""
if isinstance(value, Decimal):
# Use int for whole numbers; use str for fractional to preserve exact
# precision (float would silently round high-precision analytics values).
if value == value.to_integral_value():
return int(value)
return str(value)
if hasattr(value, "isoformat"):
return value.isoformat()
if isinstance(value, (bytes, memoryview)):
if isinstance(value, memoryview):
return bytes(value).hex()
return value.hex()
return value
# Map DatabaseType enum values to the expected SQLAlchemy driver prefix.
_DATABASE_TYPE_TO_DRIVER = {
DatabaseType.POSTGRES: "postgresql",
DatabaseType.MYSQL: "mysql+pymysql",
DatabaseType.MSSQL: "mssql+pymssql",
}
# Default ports for each database type.
_DATABASE_TYPE_DEFAULT_PORT = {
DatabaseType.POSTGRES: 5432,
DatabaseType.MYSQL: 3306,
DatabaseType.MSSQL: 1433,
}
class SQLQueryBlock(Block):
class Input(BlockSchemaInput):
database_type: DatabaseType = SchemaField(
@@ -389,87 +180,17 @@ class SQLQueryBlock(Block):
) -> tuple[list[dict[str, Any]], list[str], int]:
"""Execute a SQL query and return (rows, columns, affected_rows).
Uses SQLAlchemy to connect to any supported database.
For SELECT queries, rows are limited to ``max_rows`` via DBAPI fetchmany.
For write queries, affected_rows contains the rowcount from the driver.
When ``read_only`` is True, the database session is set to read-only
mode and the transaction is always rolled back.
Delegates to ``_execute_query`` in ``sql_query_helpers``.
Extracted as a method so it can be mocked during block tests.
"""
# Determine driver-specific connection timeout argument.
# pymssql uses "login_timeout", while PostgreSQL/MySQL use "connect_timeout".
if database_type == DatabaseType.MSSQL:
connect_args = {"login_timeout": 10}
else:
connect_args = {"connect_timeout": 10}
engine = create_engine(
connection_url,
connect_args=connect_args,
return _execute_query(
connection_url=connection_url,
query=query,
timeout=timeout,
max_rows=max_rows,
read_only=read_only,
database_type=database_type,
)
try:
with engine.connect() as conn:
# Use AUTOCOMMIT so SET commands take effect immediately and
# apply to the explicit transaction we open below.
conn = conn.execution_options(isolation_level="AUTOCOMMIT")
# Set session-level timeout (always) and read-only
# (when read_only=True) before starting the transaction.
# Compute timeout in milliseconds once. The value is
# Pydantic-validated (ge=1, le=120) so it is always a safe
# integer, but we use an explicit int() cast as
# defense-in-depth to guarantee no SQL injection even if
# upstream validation changes.
# NOTE: SET commands do not support bind parameters in most
# databases, so we use str(int(...)) for safe interpolation.
timeout_ms = str(int(timeout * 1000))
if engine.dialect.name == "postgresql":
conn.execute(text("SET statement_timeout = " + timeout_ms))
if read_only:
conn.execute(text("SET default_transaction_read_only = ON"))
elif engine.dialect.name == "mysql":
# NOTE: MAX_EXECUTION_TIME only applies to SELECT statements.
# Write queries (INSERT/UPDATE/DELETE) are not bounded by this
# setting; they rely on the database's wait_timeout instead.
conn.execute(text("SET SESSION MAX_EXECUTION_TIME = " + timeout_ms))
if read_only:
conn.execute(text("SET SESSION TRANSACTION READ ONLY"))
elif engine.dialect.name == "mssql":
# MSSQL: SET LOCK_TIMEOUT limits lock-wait time (ms).
# pymssql's connect_args "login_timeout" handles the connection
# timeout, but LOCK_TIMEOUT covers in-query lock waits.
conn.execute(text("SET LOCK_TIMEOUT " + timeout_ms))
# MSSQL lacks a session-level read-only mode like
# PostgreSQL/MySQL. Read-only enforcement is handled by
# the SQL validation layer (_validate_query_is_read_only)
# and the ROLLBACK in the finally block.
# Execute the user query inside an explicit transaction so
# the read-only setting (if enabled) applies to it.
# MSSQL uses T-SQL "BEGIN TRANSACTION"; others use "BEGIN".
if engine.dialect.name == "mssql":
conn.execute(text("BEGIN TRANSACTION"))
else:
conn.execute(text("BEGIN"))
try:
result = conn.execute(text(query))
affected = result.rowcount if not result.returns_rows else -1
columns = list(result.keys()) if result.returns_rows else []
rows = result.fetchmany(max_rows) if result.returns_rows else []
results = [
{col: _serialize_value(val) for col, val in zip(columns, row)}
for row in rows
]
finally:
if read_only:
conn.execute(text("ROLLBACK"))
else:
conn.execute(text("COMMIT"))
return results, columns, affected
finally:
engine.dispose()
async def run(
self,
@@ -478,63 +199,37 @@ class SQLQueryBlock(Block):
credentials: DatabaseCredentials,
**_kwargs: Any,
) -> BlockOutput:
# Multi-statement prevention applies in both modes (security against injection)
stmt_error, parsed_stmt = _validate_single_statement(input_data.query)
if stmt_error:
yield "error", stmt_error
# Validate query structure and read-only constraints.
error = self._validate_query(input_data)
if error:
yield "error", error
return
# When read_only (default), enforce SELECT-only queries.
# Reuse the already-parsed statement to avoid duplicate sqlparse.parse().
assert parsed_stmt is not None # guaranteed by stmt_error check above
if input_data.read_only:
ro_error = _validate_query_is_read_only(parsed_stmt)
if ro_error:
yield "error", ro_error
return
host = input_data.host.get_secret_value().strip()
if not host:
yield "error", "Database host is required."
# Validate host and resolve for SSRF protection.
host, pinned_host, error = await self._resolve_host(input_data)
if error:
yield "error", error
return
# Reject Unix socket paths (host starting with '/')
if host.startswith("/"):
yield "error", "Unix socket connections are not allowed."
return
# SSRF protection: validate that the host is not internal.
# We use the resolved IP address for the actual connection to prevent
# DNS rebinding (TOCTOU) attacks where the DNS record changes between
# the check and the connection.
try:
resolved_ips = await self.check_host_allowed(host)
except (ValueError, OSError) as e:
yield "error", f"Blocked host: {str(e).strip()}"
return
# Pin the connection to the first resolved IP to prevent DNS rebinding.
pinned_host = resolved_ips[0]
# Build the SQLAlchemy connection URL from discrete fields.
# URL.create() accepts the raw password without URL-encoding,
# so special characters like @, #, ! work correctly.
drivername = _DATABASE_TYPE_TO_DRIVER[input_data.database_type]
# Build connection URL and execute.
port = input_data.port or _DATABASE_TYPE_DEFAULT_PORT[input_data.database_type]
username = credentials.username.get_secret_value()
password = credentials.password.get_secret_value()
connection_url = URL.create(
drivername=drivername,
drivername=_DATABASE_TYPE_TO_DRIVER[input_data.database_type],
username=username,
password=password,
password=credentials.password.get_secret_value(),
host=pinned_host,
port=port,
database=input_data.database,
)
# Render the connection string for error sanitization only.
# The URL object is passed directly to create_engine() to prevent
# database name injection (e.g. "db?options=-c statement_timeout=0").
connection_string = connection_url.render_as_string(hide_password=True)
sanitize_kwargs = dict(
host=pinned_host,
original_host=host,
username=username,
port=port,
database=input_data.database,
)
conn_str = connection_url.render_as_string(hide_password=True)
try:
results, columns, affected = await asyncio.to_thread(
@@ -552,46 +247,59 @@ class SQLQueryBlock(Block):
if affected >= 0:
yield "affected_rows", affected
except OperationalError as e:
error_msg = _sanitize_error(
str(e).strip(),
connection_string,
host=pinned_host,
original_host=host,
username=username,
port=port,
yield "error", self._format_operational_error(
e, conn_str, input_data.timeout, **sanitize_kwargs
)
if "timeout" in error_msg.lower() or "cancel" in error_msg.lower():
yield "error", f"Query timed out after {input_data.timeout}s."
elif "connect" in error_msg.lower():
yield "error", f"Failed to connect to database: {error_msg}"
else:
yield "error", f"Database error: {error_msg}"
except ProgrammingError as e:
msg = _sanitize_error(
str(e).strip(),
connection_string,
host=pinned_host,
original_host=host,
username=username,
port=port,
)
msg = _sanitize_error(str(e).strip(), conn_str, **sanitize_kwargs)
yield "error", f"SQL error: {msg}"
except DBAPIError as e:
msg = _sanitize_error(
str(e).strip(),
connection_string,
host=pinned_host,
original_host=host,
username=username,
port=port,
)
msg = _sanitize_error(str(e).strip(), conn_str, **sanitize_kwargs)
yield "error", f"Database error: {msg}"
except ModuleNotFoundError:
yield (
"error",
(
f"Database driver not available for "
f"{input_data.database_type.value}. "
f"Please contact the platform administrator."
),
yield "error", (
f"Database driver not available for "
f"{input_data.database_type.value}. "
f"Please contact the platform administrator."
)
@staticmethod
def _validate_query(input_data: "SQLQueryBlock.Input") -> str | None:
"""Validate query structure and read-only constraints."""
stmt_error, parsed_stmt = _validate_single_statement(input_data.query)
if stmt_error:
return stmt_error
assert parsed_stmt is not None
if input_data.read_only:
return _validate_query_is_read_only(parsed_stmt)
return None
async def _resolve_host(
self, input_data: "SQLQueryBlock.Input"
) -> tuple[str, str, str | None]:
"""Validate and resolve the database host. Returns (host, pinned_ip, error)."""
host = input_data.host.get_secret_value().strip()
if not host:
return "", "", "Database host is required."
if host.startswith("/"):
return host, "", "Unix socket connections are not allowed."
try:
resolved_ips = await self.check_host_allowed(host)
except (ValueError, OSError) as e:
return host, "", f"Blocked host: {str(e).strip()}"
return host, resolved_ips[0], None
@staticmethod
def _format_operational_error(
e: OperationalError,
conn_str: str,
timeout: int,
**sanitize_kwargs: Any,
) -> str:
"""Sanitize and classify an OperationalError for user display."""
msg = _sanitize_error(str(e).strip(), conn_str, **sanitize_kwargs)
if "timeout" in msg.lower() or "cancel" in msg.lower():
return f"Query timed out after {timeout}s."
if "connect" in msg.lower():
return f"Failed to connect to database: {msg}"
return f"Database error: {msg}"

View File

@@ -10,10 +10,9 @@ import pytest
from pydantic import SecretStr
from sqlalchemy.exc import OperationalError
from backend.blocks.sql_query_block import (
from backend.blocks.sql_query_block import SQLQueryBlock, UserPasswordCredentials
from backend.blocks.sql_query_helpers import (
DatabaseType,
SQLQueryBlock,
UserPasswordCredentials,
_sanitize_error,
_serialize_value,
_validate_query_is_read_only,

View File

@@ -0,0 +1,322 @@
import re
from datetime import date, datetime, time
from decimal import Decimal
from enum import Enum
from typing import Any
import sqlparse
from sqlalchemy import create_engine, text
from sqlalchemy.engine.url import URL
class DatabaseType(str, Enum):
POSTGRES = "postgres"
MYSQL = "mysql"
MSSQL = "mssql"
# Defense-in-depth: reject queries containing data-modifying keywords.
# These are checked against parsed SQL tokens (not raw text) so column names
# and string literals do not cause false positives.
_DISALLOWED_KEYWORDS = {
"INSERT",
"UPDATE",
"DELETE",
"DROP",
"ALTER",
"CREATE",
"TRUNCATE",
"GRANT",
"REVOKE",
"COPY",
"EXECUTE",
"CALL",
"SET",
"RESET",
"DISCARD",
"NOTIFY",
"DO",
# SELECT ... INTO creates tables (PG/MSSQL) or writes files (MySQL OUTFILE/DUMPFILE)
"INTO",
"OUTFILE",
"DUMPFILE",
}
# Map DatabaseType enum values to the expected SQLAlchemy driver prefix.
_DATABASE_TYPE_TO_DRIVER = {
DatabaseType.POSTGRES: "postgresql",
DatabaseType.MYSQL: "mysql+pymysql",
DatabaseType.MSSQL: "mssql+pymssql",
}
# Default ports for each database type.
_DATABASE_TYPE_DEFAULT_PORT = {
DatabaseType.POSTGRES: 5432,
DatabaseType.MYSQL: 3306,
DatabaseType.MSSQL: 1433,
}
def _sanitize_error(
error_msg: str,
connection_string: str,
*,
host: str = "",
original_host: str = "",
username: str = "",
port: int = 0,
database: str = "",
) -> str:
"""Remove connection string, credentials, and infrastructure details
from error messages so they are safe to expose to the LLM.
Scrubs:
- The full connection string
- URL-embedded credentials (``://user:pass@``)
- ``password=<value>`` key-value pairs
- The database hostname / IP used for the connection
- The original (pre-resolution) hostname provided by the user
- Any IPv4 addresses that appear in the message
- Any bracketed IPv6 addresses (e.g. ``[::1]``, ``[fe80::1%eth0]``)
- The database username
- The database port number
- The database name
"""
sanitized = error_msg.replace(connection_string, "<connection_string>")
sanitized = re.sub(r"password=[^\s&]+", "password=***", sanitized)
sanitized = re.sub(r"://[^@]+@", "://***:***@", sanitized)
# Replace the known host (may be an IP already) before the generic IP pass.
# Also replace the original (pre-DNS-resolution) hostname if it differs.
if original_host and original_host != host:
sanitized = sanitized.replace(original_host, "<host>")
if host:
sanitized = sanitized.replace(host, "<host>")
# Replace any remaining IPv4 addresses (e.g. resolved IPs the driver logs)
sanitized = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", "<ip>", sanitized)
# Replace bracketed IPv6 addresses (e.g. "[::1]", "[fe80::1%eth0]")
sanitized = re.sub(r"\[[0-9a-fA-F:]+(?:%[^\]]+)?\]", "<ip>", sanitized)
# Replace the database username (handles double-quoted, single-quoted,
# and unquoted formats across PostgreSQL, MySQL, and MSSQL error messages).
if username:
sanitized = re.sub(
r"""for user ["']?""" + re.escape(username) + r"""["']?""",
"for user <user>",
sanitized,
)
# Catch remaining bare occurrences in various quote styles:
# - PostgreSQL: "FATAL: role "myuser" does not exist"
# - MySQL: "Access denied for user 'myuser'@'host'"
# - MSSQL: "Login failed for user 'myuser'"
sanitized = sanitized.replace(f'"{username}"', "<user>")
sanitized = sanitized.replace(f"'{username}'", "<user>")
# Replace the port number (handles "port 5432" and ":5432" formats)
if port:
port_str = re.escape(str(port))
sanitized = re.sub(
r"(?:port |:)" + port_str + r"(?![0-9])",
lambda m: ("port " if m.group().startswith("p") else ":") + "<port>",
sanitized,
)
# Replace the database name to avoid leaking internal infrastructure names.
if database:
sanitized = sanitized.replace(database, "<database>")
return sanitized
def _extract_keyword_tokens(parsed: sqlparse.sql.Statement) -> list[str]:
"""Extract keyword tokens from a parsed SQL statement.
Uses sqlparse token type classification to collect Keyword/DML/DDL/DCL
tokens. String literals and identifiers have different token types, so
they are naturally excluded from the result.
"""
return [
token.normalized.upper()
for token in parsed.flatten()
if token.ttype
in (
sqlparse.tokens.Keyword,
sqlparse.tokens.Keyword.DML,
sqlparse.tokens.Keyword.DDL,
sqlparse.tokens.Keyword.DCL,
)
]
def _validate_query_is_read_only(stmt: sqlparse.sql.Statement) -> str | None:
"""Validate that a parsed SQL statement is read-only (SELECT/WITH only).
Accepts an already-parsed statement from ``_validate_single_statement``
to avoid re-parsing. Checks:
1. Statement type must be SELECT (sqlparse classifies WITH...SELECT as SELECT)
2. No disallowed keywords (INSERT, UPDATE, DELETE, DROP, INTO, etc.)
Returns an error message if the query is not read-only, None otherwise.
"""
# sqlparse returns 'SELECT' for SELECT and WITH...SELECT queries
if stmt.get_type() != "SELECT":
return "Only SELECT queries are allowed."
# Defense-in-depth: check parsed keyword tokens for disallowed keywords
for kw in _extract_keyword_tokens(stmt):
# Normalize multi-word tokens (e.g. "SET LOCAL" -> "SET")
base_kw = kw.split()[0] if " " in kw else kw
if base_kw in _DISALLOWED_KEYWORDS:
return f"Disallowed SQL keyword: {kw}"
return None
def _validate_single_statement(
query: str,
) -> tuple[str | None, sqlparse.sql.Statement | None]:
"""Validate that the query contains exactly one non-empty SQL statement.
Returns (error_message, parsed_statement). If error_message is not None,
the query is invalid and parsed_statement will be None.
"""
stripped = query.strip().rstrip(";").strip()
if not stripped:
return "Query is empty.", None
# Parse the SQL using sqlparse for proper tokenization
statements = sqlparse.parse(stripped)
# Filter out empty statements and comment-only statements
statements = [
s
for s in statements
if s.tokens
and str(s).strip()
and not all(
t.is_whitespace or t.ttype in sqlparse.tokens.Comment for t in s.flatten()
)
]
if not statements:
return "Query is empty.", None
# Reject multiple statements -- prevents injection via semicolons
if len(statements) > 1:
return "Only single statements are allowed.", None
return None, statements[0]
def _serialize_value(value: Any) -> Any:
"""Convert database-specific types to JSON-serializable Python types."""
if isinstance(value, Decimal):
# Use int for whole numbers; use str for fractional to preserve exact
# precision (float would silently round high-precision analytics values).
if value == value.to_integral_value():
return int(value)
return str(value)
if isinstance(value, (datetime, date, time)):
return value.isoformat()
if isinstance(value, memoryview):
return bytes(value).hex()
if isinstance(value, bytes):
return value.hex()
return value
def _configure_session(
conn: Any,
dialect_name: str,
timeout_ms: str,
read_only: bool,
) -> None:
"""Set session-level timeout and read-only mode for the given dialect."""
if dialect_name == "postgresql":
conn.execute(text("SET statement_timeout = " + timeout_ms))
if read_only:
conn.execute(text("SET default_transaction_read_only = ON"))
elif dialect_name == "mysql":
# NOTE: MAX_EXECUTION_TIME only applies to SELECT statements.
# Write queries (INSERT/UPDATE/DELETE) are not bounded by this
# setting; they rely on the database's wait_timeout instead.
conn.execute(text("SET SESSION MAX_EXECUTION_TIME = " + timeout_ms))
if read_only:
conn.execute(text("SET SESSION TRANSACTION READ ONLY"))
elif dialect_name == "mssql":
# MSSQL: SET LOCK_TIMEOUT limits lock-wait time (ms).
# pymssql's connect_args "login_timeout" handles the connection
# timeout, but LOCK_TIMEOUT covers in-query lock waits.
conn.execute(text("SET LOCK_TIMEOUT " + timeout_ms))
# MSSQL lacks a session-level read-only mode like
# PostgreSQL/MySQL. Read-only enforcement is handled by
# the SQL validation layer (_validate_query_is_read_only)
# and the ROLLBACK in the finally block.
def _run_in_transaction(
conn: Any,
dialect_name: str,
query: str,
max_rows: int,
read_only: bool,
) -> tuple[list[dict[str, Any]], list[str], int]:
"""Execute a query inside an explicit transaction, returning results."""
# MSSQL uses T-SQL "BEGIN TRANSACTION"; others use "BEGIN".
begin_stmt = "BEGIN TRANSACTION" if dialect_name == "mssql" else "BEGIN"
conn.execute(text(begin_stmt))
try:
result = conn.execute(text(query))
affected = result.rowcount if not result.returns_rows else -1
columns = list(result.keys()) if result.returns_rows else []
rows = result.fetchmany(max_rows) if result.returns_rows else []
results = [
{col: _serialize_value(val) for col, val in zip(columns, row)}
for row in rows
]
finally:
conn.execute(text("ROLLBACK" if read_only else "COMMIT"))
return results, columns, affected
def _execute_query(
connection_url: URL | str,
query: str,
timeout: int,
max_rows: int,
read_only: bool = True,
database_type: DatabaseType = DatabaseType.POSTGRES,
) -> tuple[list[dict[str, Any]], list[str], int]:
"""Execute a SQL query and return (rows, columns, affected_rows).
Uses SQLAlchemy to connect to any supported database.
For SELECT queries, rows are limited to ``max_rows`` via DBAPI fetchmany.
For write queries, affected_rows contains the rowcount from the driver.
When ``read_only`` is True, the database session is set to read-only
mode and the transaction is always rolled back.
"""
# Determine driver-specific connection timeout argument.
# pymssql uses "login_timeout", while PostgreSQL/MySQL use "connect_timeout".
timeout_key = (
"login_timeout" if database_type == DatabaseType.MSSQL else "connect_timeout"
)
engine = create_engine(connection_url, connect_args={timeout_key: 10})
try:
with engine.connect() as conn:
# Use AUTOCOMMIT so SET commands take effect immediately.
conn = conn.execution_options(isolation_level="AUTOCOMMIT")
# Compute timeout in milliseconds. The value is Pydantic-validated
# (ge=1, le=120), but we use int() as defense-in-depth.
# NOTE: SET commands do not support bind parameters in most
# databases, so we use str(int(...)) for safe interpolation.
timeout_ms = str(int(timeout * 1000))
_configure_session(conn, engine.dialect.name, timeout_ms, read_only)
return _run_in_transaction(
conn, engine.dialect.name, query, max_rows, read_only
)
finally:
engine.dispose()