test(backend): add integration tests for SQLQueryBlock SSRF, SQLite, and error handling

Add run()-level tests covering SSRF private IP rejection (127.0.0.1,
10.x, 172.16.x, 192.168.x), Unix socket blocking, missing hostname
rejection, SQLite disabled error, credential sanitization on connection
failure, a clean error on query timeout, URL/database-type mismatch
rejection, read-only (INSERT) rejection, and the happy path. Also adds an
EXECUTE case to the disallowed-statement validation tests and a time
serialization test.
This commit is contained in:
Zamil Majdy
2026-03-26 14:05:09 +07:00
parent 292be77b86
commit c636b6f310

View File

@@ -1,12 +1,18 @@
"""Tests for SQLQueryBlock query validation, URL validation, and error sanitization."""
from datetime import date, datetime
from datetime import date, datetime, time
from decimal import Decimal
from typing import Any
from unittest.mock import AsyncMock
import pytest
from pydantic import SecretStr
from sqlalchemy.exc import OperationalError
from backend.blocks.sql_query_block import (
APIKeyCredentials,
DatabaseType,
SQLQueryBlock,
_sanitize_error,
_serialize_value,
_validate_connection_url,
@@ -66,6 +72,7 @@ class TestValidateQueryIsReadOnly:
"GRANT SELECT ON users TO public",
"REVOKE SELECT ON users FROM public",
"COPY users TO '/tmp/out.csv'",
"EXECUTE my_prepared_statement",
],
)
def test_disallowed_statements_rejected(self, query: str):
@@ -240,6 +247,10 @@ class TestSerializeValue:
d = date(2024, 6, 15)
assert _serialize_value(d) == "2024-06-15"
def test_time(self):
    """``time(14, 30, 45)`` must serialize to the string ``"14:30:45"``."""
    half_past_two = time(14, 30, 45)
    serialized = _serialize_value(half_past_two)
    assert serialized == "14:30:45"
def test_memoryview(self):
    """A memoryview over ``de ad be ef`` must serialize to ``"deadbeef"``."""
    raw = b"\xde\xad\xbe\xef"
    assert _serialize_value(memoryview(raw)) == "deadbeef"
@@ -296,3 +307,232 @@ class TestSanitizeError:
result = _sanitize_error(error, conn)
assert "s3cret" not in result
assert "<connection_string>" in result
# ---------------------------------------------------------------------------
# Helpers for run()-level integration tests
# ---------------------------------------------------------------------------
def _make_credentials(url: str) -> APIKeyCredentials:
    """Build fixed-id test credentials whose ``api_key`` carries *url*.

    The connection string is wrapped in a ``SecretStr`` and stored as the
    API key; id, provider and title are constant test values.
    """
    credentials = APIKeyCredentials(
        title="test creds",
        provider="database",
        id="01234567-89ab-cdef-0123-456789abcdef",
        api_key=SecretStr(url),
    )
    return credentials
def _make_input(
    creds: APIKeyCredentials,
    query: str = "SELECT 1",
    database_type: DatabaseType = DatabaseType.POSTGRES,
    timeout: int = 30,
    max_rows: int = 100,
) -> SQLQueryBlock.Input:
    """Assemble a ``SQLQueryBlock.Input`` that references *creds* by id.

    Only ``creds.id`` is copied into the credentials metadata; the other
    metadata fields are fixed literals. The remaining arguments are passed
    through to the input model unchanged.
    """
    credentials_meta = {
        "provider": "database",
        "id": creds.id,
        "type": "api_key",
        "title": "t",
    }
    return SQLQueryBlock.Input(
        credentials=credentials_meta,  # type: ignore[arg-type]
        max_rows=max_rows,
        timeout=timeout,
        database_type=database_type,
        query=query,
    )
async def _collect_outputs(
    block: SQLQueryBlock,
    input_data: SQLQueryBlock.Input,
    credentials: APIKeyCredentials,
) -> dict[str, Any]:
    """Drain ``block.run()`` into a dict keyed by output name.

    If the same output name is yielded more than once, the last value wins.
    """
    return {
        name: value
        async for name, value in block.run(input_data, credentials=credentials)
    }
# ---------------------------------------------------------------------------
# Integration tests: SQLQueryBlock.run() — SSRF, SQLite, error handling
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestSQLQueryBlockRunSSRF:
    """SSRF protection tests exercised through the block's run() method."""

    async def _assert_private_ip_blocked(self, ip: str) -> None:
        """Run the block against *ip* and assert a "Blocked host" error.

        ``check_host_allowed`` is replaced with a mock that raises the
        rejection ``ValueError``, simulating the real SSRF check. These tests
        therefore verify how ``run()`` reports a rejected host, not the IP
        classification itself.
        """
        block = SQLQueryBlock()
        creds = _make_credentials(f"postgresql://user:pass@{ip}:5432/db")
        input_data = _make_input(creds)
        # Mock check_host_allowed to simulate the real SSRF check raising
        block.check_host_allowed = AsyncMock(  # type: ignore[assignment]
            side_effect=ValueError(
                f"Access to blocked or private IP address {ip} "
                f"for hostname {ip} is not allowed."
            )
        )
        outputs = await _collect_outputs(block, input_data, creds)
        assert "error" in outputs
        assert "Blocked host" in outputs["error"]

    async def test_private_ip_127_rejected(self):
        """Scenario 6: loopback 127.0.0.1 must be blocked."""
        await self._assert_private_ip_blocked("127.0.0.1")

    async def test_private_ip_10_rejected(self):
        """Scenario 7: internal 10.x.x.x must be blocked."""
        await self._assert_private_ip_blocked("10.0.0.1")

    async def test_unix_socket_rejected(self):
        """Scenario 8: Unix socket paths in host must be blocked."""
        block = SQLQueryBlock()
        creds = _make_credentials("postgresql://user:pass@/var/run/postgresql/db")
        input_data = _make_input(creds)
        outputs = await _collect_outputs(block, input_data, creds)
        assert "error" in outputs
        # Should be either "Unix socket" or "must specify a database host"
        assert "Unix socket" in outputs["error"] or "must specify" in outputs["error"]

    async def test_missing_hostname_rejected(self):
        """Scenario 9: Connection string without a hostname must be rejected."""
        block = SQLQueryBlock()
        creds = _make_credentials("postgresql:///db")
        input_data = _make_input(creds)
        outputs = await _collect_outputs(block, input_data, creds)
        assert "error" in outputs
        assert "must specify" in outputs["error"] or "host" in outputs["error"].lower()

    async def test_private_ip_172_rejected(self):
        """172.16.x.x private range must be blocked."""
        await self._assert_private_ip_blocked("172.16.0.1")

    async def test_private_ip_192_168_rejected(self):
        """192.168.x.x private range must be blocked."""
        await self._assert_private_ip_blocked("192.168.1.1")
@pytest.mark.asyncio
class TestSQLQueryBlockRunSQLite:
    """SQLite must be explicitly disabled (scenario 13)."""

    async def test_sqlite_disabled_with_clear_error(self):
        """A SQLite URL with the SQLITE database type yields a "not supported" error."""
        block = SQLQueryBlock()
        sqlite_creds = _make_credentials("sqlite:///path/to/db.sqlite")
        outputs = await _collect_outputs(
            block,
            _make_input(sqlite_creds, database_type=DatabaseType.SQLITE),
            sqlite_creds,
        )
        assert "error" in outputs
        message = outputs["error"]
        assert "SQLite" in message
        assert "not supported" in message
@pytest.mark.asyncio
class TestSQLQueryBlockRunErrorHandling:
    """Error handling: connection failures and timeouts (scenarios 20-21).

    Also covers run()-level rejection of non-read-only queries, the mocked
    happy path, and rejection of a URL that does not match the database type.
    """

    async def test_connection_failure_sanitized_no_credentials(self):
        """Scenario 20: Connection error must not leak credentials."""
        block = SQLQueryBlock()
        conn_url = "postgresql://admin:supersecret@db.example.com:5432/prod"
        creds = _make_credentials(conn_url)
        input_data = _make_input(creds)

        def _refuse_connection(**kwargs: Any):
            # The message deliberately embeds the full connection URL
            # (password included) so the test can prove it is scrubbed.
            raise OperationalError(
                "could not connect to server: Connection refused\n"
                f'\tIs the server running on host "{conn_url}"?',
                params=None,
                orig=Exception("connection refused"),
            )

        # Mock SSRF check to allow the host
        block.check_host_allowed = AsyncMock(return_value=None)  # type: ignore[assignment]
        # Mock execute_query to raise an OperationalError with credentials in msg
        block.execute_query = _refuse_connection  # type: ignore[assignment]

        outputs = await _collect_outputs(block, input_data, creds)
        assert "error" in outputs
        assert "supersecret" not in outputs["error"]
        assert "connect" in str(outputs["error"]).lower()

    async def test_query_timeout_clean_error(self):
        """Scenario 21: Timeout yields clean user-facing message."""
        block = SQLQueryBlock()
        conn_url = "postgresql://user:pass@db.example.com:5432/db"
        creds = _make_credentials(conn_url)
        input_data = _make_input(creds, query="SELECT pg_sleep(1000)", timeout=5)

        def _time_out(**kwargs: Any):
            raise OperationalError(
                "canceling statement due to statement timeout",
                params=None,
                orig=Exception("timeout"),
            )

        block.check_host_allowed = AsyncMock(return_value=None)  # type: ignore[assignment]
        block.execute_query = _time_out  # type: ignore[assignment]

        outputs = await _collect_outputs(block, input_data, creds)
        assert "error" in outputs
        assert "timed out" in str(outputs["error"]).lower()
        # The message is expected to echo the configured timeout (timeout=5).
        assert "5s" in str(outputs["error"])

    async def test_read_only_query_rejected_at_run_level(self):
        """INSERT is rejected before any connection attempt."""
        block = SQLQueryBlock()
        creds = _make_credentials("postgresql://user:pass@db.example.com:5432/db")
        input_data = _make_input(creds, query="INSERT INTO users VALUES (1, 'test')")
        outputs = await _collect_outputs(block, input_data, creds)
        assert "error" in outputs
        assert "SELECT" in str(outputs["error"])

    async def test_successful_query_returns_results(self):
        """Happy path: mocked query returns correct structure."""
        block = SQLQueryBlock()
        creds = _make_credentials("postgresql://user:pass@db.example.com:5432/db")
        input_data = _make_input(creds, query="SELECT id, name FROM users LIMIT 2")
        mock_rows = [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]
        mock_cols = ["id", "name"]

        def _return_rows(**kwargs: Any):
            return (mock_rows, mock_cols)

        block.check_host_allowed = AsyncMock(return_value=None)  # type: ignore[assignment]
        block.execute_query = _return_rows  # type: ignore[assignment]

        outputs = await _collect_outputs(block, input_data, creds)
        assert outputs["results"] == mock_rows
        assert outputs["columns"] == mock_cols
        assert outputs["row_count"] == 2

    async def test_url_mismatch_rejected(self):
        """MySQL URL with Postgres type is rejected at run level."""
        block = SQLQueryBlock()
        creds = _make_credentials("mysql://user:pass@db.example.com:3306/db")
        input_data = _make_input(creds)
        outputs = await _collect_outputs(block, input_data, creds)
        assert "error" in outputs
        assert "does not match" in str(outputs["error"])