mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(backend): address review round — remove SQLite, hide password, cleanup dead code
- Remove DatabaseType.SQLITE from enum (rejected at runtime, confusing UX) - Remove all SQLite dead code paths (driver map, connect_args, runtime check) - Change render_as_string(hide_password=False) to hide_password=True to avoid materializing plaintext credentials in local variable - Simplify pinned_host assignment (remove unreachable fallback branch) - Remove SQLite-related test cases - Add doc comment to _make_input noting read_only default deviation
This commit is contained in:
@@ -57,7 +57,6 @@ def DatabaseCredentialsField() -> DatabaseCredentialsInput:
|
||||
class DatabaseType(str, Enum):
|
||||
POSTGRES = "postgres"
|
||||
MYSQL = "mysql"
|
||||
SQLITE = "sqlite"
|
||||
MSSQL = "mssql"
|
||||
|
||||
|
||||
@@ -257,7 +256,6 @@ def _serialize_value(value: Any) -> Any:
|
||||
_DATABASE_TYPE_TO_DRIVER = {
|
||||
DatabaseType.POSTGRES: "postgresql",
|
||||
DatabaseType.MYSQL: "mysql+pymysql",
|
||||
DatabaseType.SQLITE: "sqlite",
|
||||
DatabaseType.MSSQL: "mssql+pymssql",
|
||||
}
|
||||
|
||||
@@ -397,9 +395,7 @@ class SQLQueryBlock(Block):
|
||||
"""
|
||||
# Determine driver-specific connection timeout argument.
|
||||
# pyodbc (MSSQL) uses "timeout", while PostgreSQL/MySQL use "connect_timeout".
|
||||
if database_type == DatabaseType.SQLITE:
|
||||
connect_args: dict[str, Any] = {}
|
||||
elif database_type == DatabaseType.MSSQL:
|
||||
if database_type == DatabaseType.MSSQL:
|
||||
connect_args = {"login_timeout": 10}
|
||||
else:
|
||||
connect_args = {"connect_timeout": 10}
|
||||
@@ -488,11 +484,6 @@ class SQLQueryBlock(Block):
|
||||
yield "error", ro_error
|
||||
return
|
||||
|
||||
# SQLite is not supported for remote execution
|
||||
if input_data.database_type == DatabaseType.SQLITE:
|
||||
yield "error", "SQLite is not supported for remote execution."
|
||||
return
|
||||
|
||||
host = input_data.host.get_secret_value().strip()
|
||||
if not host:
|
||||
yield "error", "Database host is required."
|
||||
@@ -514,7 +505,7 @@ class SQLQueryBlock(Block):
|
||||
return
|
||||
|
||||
# Pin the connection to the first resolved IP to prevent DNS rebinding.
|
||||
pinned_host = resolved_ips[0] if resolved_ips else host
|
||||
pinned_host = resolved_ips[0]
|
||||
|
||||
# Build the SQLAlchemy connection URL from discrete fields.
|
||||
# URL.create() accepts the raw password without URL-encoding,
|
||||
@@ -536,7 +527,7 @@ class SQLQueryBlock(Block):
|
||||
# Render the connection string for error sanitization only.
|
||||
# The URL object is passed directly to create_engine() to prevent
|
||||
# database name injection (e.g. "db?options=-c statement_timeout=0").
|
||||
connection_string = connection_url.render_as_string(hide_password=False)
|
||||
connection_string = connection_url.render_as_string(hide_password=True)
|
||||
|
||||
try:
|
||||
results, columns, affected = await asyncio.to_thread(
|
||||
|
||||
@@ -324,6 +324,9 @@ def _make_input(
|
||||
timeout: int = 30,
|
||||
max_rows: int = 100,
|
||||
) -> SQLQueryBlock.Input:
|
||||
"""Build a test input. Note: ``read_only`` defaults to ``False`` here (unlike the
|
||||
block's default of ``True``) so write-mode tests don't need to specify it.
|
||||
Set ``read_only=True`` explicitly when testing read-only behaviour."""
|
||||
return SQLQueryBlock.Input(
|
||||
query=query,
|
||||
database_type=database_type,
|
||||
@@ -355,7 +358,7 @@ async def _collect_outputs(
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration tests: SQLQueryBlock.run() -- SSRF, SQLite, error handling
|
||||
# Integration tests: SQLQueryBlock.run() -- SSRF, error handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -443,20 +446,6 @@ class TestSQLQueryBlockRunSSRF:
|
||||
assert "Blocked host" in outputs["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestSQLQueryBlockRunSQLite:
|
||||
"""SQLite must be explicitly disabled (scenario 13)."""
|
||||
|
||||
async def test_sqlite_disabled_with_clear_error(self):
|
||||
block = SQLQueryBlock()
|
||||
creds = _make_credentials()
|
||||
input_data = _make_input(creds, database_type=DatabaseType.SQLITE)
|
||||
outputs = await _collect_outputs(block, input_data, creds)
|
||||
assert "error" in outputs
|
||||
assert "SQLite" in outputs["error"]
|
||||
assert "not supported" in outputs["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestSQLQueryBlockRunErrorHandling:
|
||||
"""Error handling: connection failures and timeouts."""
|
||||
@@ -681,20 +670,6 @@ class TestSQLQueryBlockWriteMode:
|
||||
assert "error" in outputs
|
||||
assert "Blocked host" in outputs["error"]
|
||||
|
||||
async def test_sqlite_blocked_in_write_mode(self):
|
||||
"""SQLite remains disabled regardless of read_only setting."""
|
||||
block = SQLQueryBlock()
|
||||
creds = _make_credentials()
|
||||
input_data = _make_input(
|
||||
creds,
|
||||
query="INSERT INTO t VALUES (1)",
|
||||
database_type=DatabaseType.SQLITE,
|
||||
read_only=False,
|
||||
)
|
||||
outputs = await _collect_outputs(block, input_data, creds)
|
||||
assert "error" in outputs
|
||||
assert "SQLite" in outputs["error"]
|
||||
|
||||
async def test_affected_rows_returned_for_write(self):
|
||||
"""Write queries return affected_rows count."""
|
||||
block = SQLQueryBlock()
|
||||
|
||||
Reference in New Issue
Block a user