fix(mcp): Remove trusted_origins to prevent SSRF on user-provided URLs

User-provided MCP server URLs should not bypass SSRF IP-blocking
validation. Remove trusted_origins from all MCP code so that
private/internal IPs are properly blocked. Keep ThreadedResolver
in HostResolver fallback for DNS reliability in subprocess
environments.
This commit is contained in:
Zamil Majdy
2026-02-09 18:55:17 +04:00
parent 340520ba85
commit fe70b6929f
6 changed files with 5 additions and 28 deletions

View File

@@ -103,11 +103,7 @@ async def discover_tools(
logger.debug("Could not look up stored MCP credentials", exc_info=True)
try:
client = MCPClient(
request.server_url,
auth_token=auth_token,
trusted_origins=[request.server_url],
)
client = MCPClient(request.server_url, auth_token=auth_token)
init_result = await client.initialize()
tools = await client.list_tools()
@@ -177,7 +173,7 @@ async def mcp_oauth_login(
3. Performs Dynamic Client Registration (RFC 7591) if available
4. Returns the authorization URL for the frontend to open in a popup
"""
client = MCPClient(request.server_url, trusted_origins=[request.server_url])
client = MCPClient(request.server_url)
# Step 1: Discover protected-resource metadata (RFC 9728)
try:

View File

@@ -84,7 +84,6 @@ class TestDiscoverTools:
MockClient.assert_called_once_with(
"https://mcp.example.com/mcp",
auth_token="my-secret-token",
trusted_origins=["https://mcp.example.com/mcp"],
)
def test_discover_tools_auto_uses_stored_credential(self):
@@ -124,7 +123,6 @@ class TestDiscoverTools:
MockClient.assert_called_once_with(
"https://mcp.example.com/mcp",
auth_token="stored-token-123",
trusted_origins=["https://mcp.example.com/mcp"],
)
def test_discover_tools_mcp_error(self):

View File

@@ -148,12 +148,7 @@ class MCPToolBlock(Block):
auth_token: str | None = None,
) -> Any:
"""Call a tool on the MCP server. Extracted for easy mocking in tests."""
# Trust the user-configured server URL to allow internal/localhost servers
client = MCPClient(
server_url,
auth_token=auth_token,
trusted_origins=[server_url],
)
client = MCPClient(server_url, auth_token=auth_token)
await client.initialize()
result = await client.call_tool(tool_name, arguments)

View File

@@ -54,11 +54,9 @@ class MCPClient:
self,
server_url: str,
auth_token: str | None = None,
trusted_origins: list[str] | None = None,
):
self.server_url = server_url.rstrip("/")
self.auth_token = auth_token
self.trusted_origins = trusted_origins or []
self._request_id = 0
self._session_id: str | None = None
@@ -132,7 +130,6 @@ class MCPClient:
requests = Requests(
raise_for_status=True,
extra_headers=headers,
trusted_origins=self.trusted_origins,
)
response = await requests.post(self.server_url, json=payload)
@@ -171,7 +168,6 @@ class MCPClient:
requests = Requests(
raise_for_status=False,
extra_headers=headers,
trusted_origins=self.trusted_origins,
)
await requests.post(self.server_url, json=notification)
@@ -201,7 +197,6 @@ class MCPClient:
requests = Requests(
raise_for_status=False,
trusted_origins=self.trusted_origins,
)
for url in candidates:
try:
@@ -243,7 +238,6 @@ class MCPClient:
requests = Requests(
raise_for_status=False,
trusted_origins=self.trusted_origins,
)
for url in candidates:
try:

View File

@@ -85,8 +85,8 @@ def mcp_server_with_auth():
def _make_client(url: str, auth_token: str | None = None) -> MCPClient:
"""Create an MCPClient with localhost trusted for integration tests."""
return MCPClient(url, auth_token=auth_token, trusted_origins=[url])
"""Create an MCPClient for integration tests."""
return MCPClient(url, auth_token=auth_token)
def _make_fake_creds(api_key: str = "FAKE_API_KEY") -> APIKeyCredentials:

View File

@@ -467,12 +467,6 @@ class Requests:
resolver = HostResolver(ssl_hostname=hostname, ip_addresses=ip_addresses)
ssl_context = ssl.create_default_context()
connector = aiohttp.TCPConnector(resolver=resolver, ssl=ssl_context)
else:
# Use ThreadedResolver for trusted origins to avoid c-ares DNS issues
# in subprocess environments (e.g. ExecutionManager on macOS).
connector = aiohttp.TCPConnector(
resolver=aiohttp.ThreadedResolver()
)
session_kwargs: dict = {}
if connector:
session_kwargs["connector"] = connector