mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(backend): address round 2 review — port validation, comment fix, dead fallback
- Add ge=1, le=65535 port validation to Input schema - Fix inaccurate comment: pymssql not pyodbc - Replace _DATABASE_TYPE_DEFAULT_PORT.get() with direct dict access (all types have entries after SQLite removal) - Update default port tests to use port=None instead of port=0
This commit is contained in:
@@ -285,6 +285,8 @@ class SQLQueryBlock(Block):
|
||||
"Database port (leave empty for default: "
|
||||
"PostgreSQL: 5432, MySQL: 3306, MSSQL: 1433)"
|
||||
),
|
||||
ge=1,
|
||||
le=65535,
|
||||
)
|
||||
database: str = SchemaField(
|
||||
description="Name of the database to connect to",
|
||||
@@ -394,7 +396,7 @@ class SQLQueryBlock(Block):
|
||||
mode and the transaction is always rolled back.
|
||||
"""
|
||||
# Determine driver-specific connection timeout argument.
|
||||
# pyodbc (MSSQL) uses "timeout", while PostgreSQL/MySQL use "connect_timeout".
|
||||
# pymssql uses "login_timeout", while PostgreSQL/MySQL use "connect_timeout".
|
||||
if database_type == DatabaseType.MSSQL:
|
||||
connect_args = {"login_timeout": 10}
|
||||
else:
|
||||
@@ -511,9 +513,7 @@ class SQLQueryBlock(Block):
|
||||
# URL.create() accepts the raw password without URL-encoding,
|
||||
# so special characters like @, #, ! work correctly.
|
||||
drivername = _DATABASE_TYPE_TO_DRIVER[input_data.database_type]
|
||||
port = input_data.port or _DATABASE_TYPE_DEFAULT_PORT.get(
|
||||
input_data.database_type, 5432
|
||||
)
|
||||
port = input_data.port or _DATABASE_TYPE_DEFAULT_PORT[input_data.database_type]
|
||||
username = credentials.username.get_secret_value()
|
||||
password = credentials.password.get_secret_value()
|
||||
connection_url = URL.create(
|
||||
|
||||
@@ -775,14 +775,14 @@ class TestSQLQueryBlockDefaultPort:
|
||||
"""Port should default based on the selected database_type."""
|
||||
|
||||
async def test_mysql_default_port_3306(self):
|
||||
"""MySQL should default to port 3306 when port is 0."""
|
||||
"""MySQL should default to port 3306 when port is None."""
|
||||
block = SQLQueryBlock()
|
||||
creds = _make_credentials()
|
||||
input_data = _make_input(
|
||||
creds,
|
||||
query="SELECT 1",
|
||||
database_type=DatabaseType.MYSQL,
|
||||
port=0,
|
||||
port=None,
|
||||
read_only=False,
|
||||
)
|
||||
block.check_host_allowed = AsyncMock(return_value=["1.2.3.4"]) # type: ignore[assignment]
|
||||
@@ -798,14 +798,14 @@ class TestSQLQueryBlockDefaultPort:
|
||||
assert ":3306/" in captured_conn_str["value"]
|
||||
|
||||
async def test_mssql_default_port_1433(self):
|
||||
"""MSSQL should default to port 1433 when port is 0."""
|
||||
"""MSSQL should default to port 1433 when port is None."""
|
||||
block = SQLQueryBlock()
|
||||
creds = _make_credentials()
|
||||
input_data = _make_input(
|
||||
creds,
|
||||
query="SELECT 1",
|
||||
database_type=DatabaseType.MSSQL,
|
||||
port=0,
|
||||
port=None,
|
||||
read_only=False,
|
||||
)
|
||||
block.check_host_allowed = AsyncMock(return_value=["1.2.3.4"]) # type: ignore[assignment]
|
||||
@@ -821,14 +821,14 @@ class TestSQLQueryBlockDefaultPort:
|
||||
assert ":1433/" in captured_conn_str["value"]
|
||||
|
||||
async def test_postgres_default_port_5432(self):
|
||||
"""PostgreSQL should default to port 5432 when port is 0."""
|
||||
"""PostgreSQL should default to port 5432 when port is None."""
|
||||
block = SQLQueryBlock()
|
||||
creds = _make_credentials()
|
||||
input_data = _make_input(
|
||||
creds,
|
||||
query="SELECT 1",
|
||||
database_type=DatabaseType.POSTGRES,
|
||||
port=0,
|
||||
port=None,
|
||||
read_only=False,
|
||||
)
|
||||
block.check_host_allowed = AsyncMock(return_value=["1.2.3.4"]) # type: ignore[assignment]
|
||||
|
||||
Reference in New Issue
Block a user