mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(backend): Increase block request security; Prevent DNS rebinding & open redirect attack (#9688)
The current block web requests utility has a logic to avoid the system firing into blocklisted IPs. However, the current logic is still prone to a few security issues: * DNS rebinding attack: due to the lack of guarantee on the used IP not being changed during the IP checking and firing step. * Open redirect: due to the request sensitive request headers are still being propagated throughout the web redirect. ### Changes 🏗️ * Uses IP pinning to request the web. * Strip `Authorization`, `Proxy-Authorization`, `Cookie` upon web redirects. ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] Test the web request block, add more tests with different validation scenarios.
This commit is contained in:
@@ -1,11 +1,14 @@
|
||||
import ipaddress
|
||||
import re
|
||||
import socket
|
||||
import ssl
|
||||
from typing import Callable
|
||||
from urllib.parse import urljoin, urlparse, urlunparse
|
||||
from urllib.parse import quote, urljoin, urlparse, urlunparse
|
||||
|
||||
import idna
|
||||
import requests as req
|
||||
from requests.adapters import HTTPAdapter
|
||||
from urllib3 import PoolManager
|
||||
|
||||
from backend.util.settings import Config
|
||||
|
||||
@@ -41,11 +44,56 @@ def _is_ip_blocked(ip: str) -> bool:
|
||||
return any(ip_addr in network for network in BLOCKED_IP_NETWORKS)
|
||||
|
||||
|
||||
def validate_url(url: str, trusted_origins: list[str]) -> str:
|
||||
def _remove_insecure_headers(headers: dict, old_url: str, new_url: str) -> dict:
|
||||
"""
|
||||
Removes sensitive headers (Authorization, Proxy-Authorization, Cookie)
|
||||
if the scheme/host/port of new_url differ from old_url.
|
||||
"""
|
||||
old_parsed = urlparse(old_url)
|
||||
new_parsed = urlparse(new_url)
|
||||
if (
|
||||
(old_parsed.scheme != new_parsed.scheme)
|
||||
or (old_parsed.hostname != new_parsed.hostname)
|
||||
or (old_parsed.port != new_parsed.port)
|
||||
):
|
||||
headers.pop("Authorization", None)
|
||||
headers.pop("Proxy-Authorization", None)
|
||||
headers.pop("Cookie", None)
|
||||
return headers
|
||||
|
||||
|
||||
class HostSSLAdapter(HTTPAdapter):
|
||||
"""
|
||||
A custom adapter that connects to an IP address but still
|
||||
sets the TLS SNI to the original host name so the cert can match.
|
||||
"""
|
||||
|
||||
def __init__(self, ssl_hostname, *args, **kwargs):
|
||||
self.ssl_hostname = ssl_hostname
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def init_poolmanager(self, *args, **kwargs):
|
||||
self.poolmanager = PoolManager(
|
||||
*args,
|
||||
ssl_context=ssl.create_default_context(),
|
||||
server_hostname=self.ssl_hostname, # This works for urllib3>=2
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def validate_url(
|
||||
url: str,
|
||||
trusted_origins: list[str],
|
||||
enable_dns_rebinding: bool = True,
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Validates the URL to prevent SSRF attacks by ensuring it does not point
|
||||
to a private, link-local, or otherwise blocked IP address — unless
|
||||
the hostname is explicitly trusted.
|
||||
|
||||
Returns a tuple of:
|
||||
- pinned_url: a URL that has the netloc replaced with the validated IP
|
||||
- ascii_hostname: the original ASCII hostname (IDNA-decoded) for use in the Host header
|
||||
"""
|
||||
# Canonicalize URL
|
||||
url = url.strip("/ ").replace("\\", "/")
|
||||
@@ -74,17 +122,30 @@ def validate_url(url: str, trusted_origins: list[str]) -> str:
|
||||
if not HOSTNAME_REGEX.match(ascii_hostname):
|
||||
raise ValueError("Hostname contains invalid characters.")
|
||||
|
||||
# Rebuild URL with IDNA-encoded hostname
|
||||
parsed = parsed._replace(netloc=ascii_hostname)
|
||||
url = str(urlunparse(parsed))
|
||||
|
||||
# If hostname is trusted, skip IP-based checks
|
||||
# If hostname is trusted, skip IP-based checks but still return pinned URL
|
||||
if ascii_hostname in trusted_origins:
|
||||
return url
|
||||
pinned_netloc = ascii_hostname
|
||||
if parsed.port:
|
||||
pinned_netloc += f":{parsed.port}"
|
||||
|
||||
pinned_url = urlunparse(
|
||||
(
|
||||
parsed.scheme,
|
||||
pinned_netloc,
|
||||
quote(parsed.path, safe="/%:@"),
|
||||
parsed.params,
|
||||
parsed.query,
|
||||
parsed.fragment,
|
||||
)
|
||||
)
|
||||
return pinned_url, ascii_hostname
|
||||
|
||||
# Resolve all IP addresses for the hostname
|
||||
try:
|
||||
ip_addresses = {res[4][0] for res in socket.getaddrinfo(ascii_hostname, None)}
|
||||
ip_list = [res[4][0] for res in socket.getaddrinfo(ascii_hostname, None)]
|
||||
ipv4 = [ip for ip in ip_list if ":" not in ip]
|
||||
ipv6 = [ip for ip in ip_list if ":" in ip]
|
||||
ip_addresses = ipv4 + ipv6 # Prefer IPv4 over IPv6
|
||||
except socket.gaierror:
|
||||
raise ValueError(f"Unable to resolve IP address for hostname {ascii_hostname}")
|
||||
|
||||
@@ -99,7 +160,33 @@ def validate_url(url: str, trusted_origins: list[str]) -> str:
|
||||
f"for hostname {ascii_hostname} is not allowed."
|
||||
)
|
||||
|
||||
return url
|
||||
# Pin to the first valid IP (for SSRF defense).
|
||||
pinned_ip = ip_addresses[0]
|
||||
|
||||
# If it's IPv6, bracket it
|
||||
if ":" in pinned_ip:
|
||||
pinned_netloc = f"[{pinned_ip}]"
|
||||
else:
|
||||
pinned_netloc = pinned_ip
|
||||
|
||||
if parsed.port:
|
||||
pinned_netloc += f":{parsed.port}"
|
||||
|
||||
if not enable_dns_rebinding:
|
||||
pinned_netloc = ascii_hostname
|
||||
|
||||
pinned_url = urlunparse(
|
||||
(
|
||||
parsed.scheme,
|
||||
pinned_netloc,
|
||||
quote(parsed.path, safe="/%:@"),
|
||||
parsed.params,
|
||||
parsed.query,
|
||||
parsed.fragment,
|
||||
)
|
||||
)
|
||||
|
||||
return pinned_url, ascii_hostname # (pinned_url, original_hostname)
|
||||
|
||||
|
||||
class Requests:
|
||||
@@ -137,28 +224,49 @@ class Requests:
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> req.Response:
|
||||
# Merge any extra headers
|
||||
if self.extra_headers is not None:
|
||||
headers = {**(headers or {}), **self.extra_headers}
|
||||
# Validate URL and get pinned URL + original hostname
|
||||
pinned_url, original_hostname = validate_url(url, self.trusted_origins)
|
||||
|
||||
# Validate the URL (with optional extra validator)
|
||||
url = validate_url(url, self.trusted_origins)
|
||||
# Apply any extra user-defined validation/transformation
|
||||
if self.extra_url_validator is not None:
|
||||
url = self.extra_url_validator(url)
|
||||
pinned_url = self.extra_url_validator(pinned_url)
|
||||
|
||||
# Merge any extra headers
|
||||
headers = dict(headers) if headers else {}
|
||||
if self.extra_headers is not None:
|
||||
headers.update(self.extra_headers)
|
||||
|
||||
# Force the Host header to the original hostname
|
||||
headers["Host"] = original_hostname
|
||||
|
||||
# Create a fresh session & mount our HostSSLAdapter if pinned to IP
|
||||
session = req.Session()
|
||||
pinned_parsed = urlparse(pinned_url)
|
||||
|
||||
# If pinned_url netloc is an IP (not in trusted_origins),
|
||||
# then we attach the custom SNI adapter:
|
||||
if pinned_parsed.hostname and pinned_parsed.hostname != original_hostname:
|
||||
# That means we definitely pinned to an IP
|
||||
mount_prefix = f"{pinned_parsed.scheme}://{pinned_parsed.hostname}"
|
||||
if pinned_parsed.port:
|
||||
mount_prefix += f":{pinned_parsed.port}"
|
||||
adapter = HostSSLAdapter(ssl_hostname=original_hostname)
|
||||
session.mount("https://", adapter)
|
||||
|
||||
# Perform the request with redirects disabled for manual handling
|
||||
response = req.request(
|
||||
response = session.request(
|
||||
method,
|
||||
url,
|
||||
pinned_url,
|
||||
headers=headers,
|
||||
allow_redirects=False,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if self.raise_for_status:
|
||||
response.raise_for_status()
|
||||
|
||||
# If allowed and a redirect is received, follow the redirect
|
||||
# If allowed and a redirect is received, follow the redirect manually
|
||||
if allow_redirects and response.is_redirect:
|
||||
if max_redirects <= 0:
|
||||
raise Exception("Too many redirects.")
|
||||
@@ -167,14 +275,16 @@ class Requests:
|
||||
if not location:
|
||||
return response
|
||||
|
||||
new_url = validate_url(urljoin(url, location), self.trusted_origins)
|
||||
if self.extra_url_validator is not None:
|
||||
new_url = self.extra_url_validator(new_url)
|
||||
# The base URL is the pinned_url we just used
|
||||
# so that relative redirects resolve correctly.
|
||||
new_url = urljoin(pinned_url, location)
|
||||
# Carry forward the same headers but update Host
|
||||
new_headers = _remove_insecure_headers(dict(headers), url, new_url)
|
||||
|
||||
return self.request(
|
||||
method,
|
||||
new_url,
|
||||
headers=headers,
|
||||
headers=new_headers,
|
||||
allow_redirects=allow_redirects,
|
||||
max_redirects=max_redirects - 1,
|
||||
*args,
|
||||
|
||||
@@ -3,77 +3,104 @@ import pytest
|
||||
from backend.util.request import validate_url
|
||||
|
||||
|
||||
def test_validate_url():
|
||||
# Rejected IP ranges
|
||||
with pytest.raises(ValueError):
|
||||
validate_url("localhost", [])
|
||||
@pytest.mark.parametrize(
|
||||
"url, trusted_origins, expected_value, should_raise",
|
||||
[
|
||||
# Rejected IP ranges
|
||||
("localhost", [], None, True),
|
||||
("192.168.1.1", [], None, True),
|
||||
("127.0.0.1", [], None, True),
|
||||
("0.0.0.0", [], None, True),
|
||||
# Normal URLs (should default to http:// if no scheme provided)
|
||||
("google.com/a?b=c", [], "http://google.com/a?b=c", False),
|
||||
("github.com?key=!@!@", [], "http://github.com?key=!@!@", False),
|
||||
# Scheme Enforcement
|
||||
("ftp://example.com", [], None, True),
|
||||
("file://example.com", [], None, True),
|
||||
# International domain converting to punycode (allowed if public)
|
||||
("http://xn--exmple-cua.com", [], "http://xn--exmple-cua.com", False),
|
||||
# Invalid domain (IDNA failure)
|
||||
("http://exa◌mple.com", [], None, True),
|
||||
# IPv6 addresses (loopback/blocked)
|
||||
("::1", [], None, True),
|
||||
("http://[::1]", [], None, True),
|
||||
# Suspicious Characters in Hostname
|
||||
("http://example_underscore.com", [], None, True),
|
||||
("http://exa mple.com", [], None, True),
|
||||
# Malformed URLs
|
||||
("http://", [], None, True), # No hostname
|
||||
("://missing-scheme", [], None, True), # Missing proper scheme
|
||||
# Trusted Origins
|
||||
(
|
||||
"internal-api.company.com",
|
||||
["internal-api.company.com", "10.0.0.5"],
|
||||
"http://internal-api.company.com",
|
||||
False,
|
||||
),
|
||||
("10.0.0.5", ["10.0.0.5"], "http://10.0.0.5", False),
|
||||
# Special Characters in Path
|
||||
(
|
||||
"example.com/path%20with%20spaces",
|
||||
[],
|
||||
"http://example.com/path%20with%20spaces",
|
||||
False,
|
||||
),
|
||||
# Backslashes should be replaced with forward slashes
|
||||
("http://example.com\\backslash", [], "http://example.com/backslash", False),
|
||||
# Check default-scheme behavior for valid domains
|
||||
("example.com", [], "http://example.com", False),
|
||||
("https://secure.com", [], "https://secure.com", False),
|
||||
# Non-ASCII Characters in Query/Fragment
|
||||
("example.com?param=äöü", [], "http://example.com?param=äöü", False),
|
||||
],
|
||||
)
|
||||
def test_validate_url_no_dns_rebinding(
|
||||
url, trusted_origins, expected_value, should_raise
|
||||
):
|
||||
if should_raise:
|
||||
with pytest.raises(ValueError):
|
||||
validate_url(url, trusted_origins, enable_dns_rebinding=False)
|
||||
else:
|
||||
url, host = validate_url(url, trusted_origins, enable_dns_rebinding=False)
|
||||
assert url == expected_value
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
validate_url("192.168.1.1", [])
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
validate_url("127.0.0.1", [])
|
||||
@pytest.mark.parametrize(
|
||||
"hostname, resolved_ips, expect_error, expected_ip",
|
||||
[
|
||||
# Multiple public IPs, none blocked
|
||||
("public-example.com", ["8.8.8.8", "9.9.9.9"], False, "8.8.8.8"),
|
||||
# Includes a blocked IP (e.g. link-local 169.254.x.x) => should raise
|
||||
("rebinding.com", ["1.2.3.4", "169.254.169.254"], True, None),
|
||||
# Single public IP
|
||||
("single-public.com", ["8.8.8.8"], False, "8.8.8.8"),
|
||||
# Single blocked IP
|
||||
("blocked.com", ["127.0.0.1"], True, None),
|
||||
],
|
||||
)
|
||||
def test_dns_rebinding_fix(
|
||||
monkeypatch, hostname, resolved_ips, expect_error, expected_ip
|
||||
):
|
||||
"""
|
||||
Tests that validate_url pins the first valid public IP address, and rejects
|
||||
the domain if any of the resolved IPs are blocked (i.e., DNS Rebinding scenario).
|
||||
"""
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
validate_url("0.0.0.0", [])
|
||||
def mock_getaddrinfo(host, port, *args, **kwargs):
|
||||
# Simulate multiple IPs returned for the given hostname
|
||||
return [(None, None, None, None, (ip, port)) for ip in resolved_ips]
|
||||
|
||||
# Normal URLs
|
||||
assert validate_url("google.com/a?b=c", []) == "http://google.com/a?b=c"
|
||||
assert validate_url("github.com?key=!@!@", []) == "http://github.com?key=!@!@"
|
||||
# Patch socket.getaddrinfo so we control the DNS resolution in the test
|
||||
monkeypatch.setattr("socket.getaddrinfo", mock_getaddrinfo)
|
||||
|
||||
# Scheme Enforcement
|
||||
with pytest.raises(ValueError):
|
||||
validate_url("ftp://example.com", [])
|
||||
with pytest.raises(ValueError):
|
||||
validate_url("file://example.com", [])
|
||||
|
||||
# International domain that converts to punycode - should be allowed if public
|
||||
assert validate_url("http://xn--exmple-cua.com", []) == "http://xn--exmple-cua.com"
|
||||
# If the domain fails IDNA encoding or is invalid, it should raise an error
|
||||
with pytest.raises(ValueError):
|
||||
validate_url("http://exa◌mple.com", [])
|
||||
|
||||
# IPv6 Addresses
|
||||
with pytest.raises(ValueError):
|
||||
validate_url("::1", []) # IPv6 loopback should be blocked
|
||||
with pytest.raises(ValueError):
|
||||
validate_url("http://[::1]", []) # IPv6 loopback in URL form
|
||||
|
||||
# Suspicious Characters in Hostname
|
||||
with pytest.raises(ValueError):
|
||||
validate_url("http://example_underscore.com", [])
|
||||
with pytest.raises(ValueError):
|
||||
validate_url("http://exa mple.com", []) # Space in hostname
|
||||
|
||||
# Malformed URLs
|
||||
with pytest.raises(ValueError):
|
||||
validate_url("http://", []) # No hostname
|
||||
with pytest.raises(ValueError):
|
||||
validate_url("://missing-scheme", []) # Missing proper scheme
|
||||
|
||||
# Trusted Origins
|
||||
trusted = ["internal-api.company.com", "10.0.0.5"]
|
||||
assert (
|
||||
validate_url("internal-api.company.com", trusted)
|
||||
== "http://internal-api.company.com"
|
||||
)
|
||||
assert validate_url("10.0.0.5", ["10.0.0.5"]) == "http://10.0.0.5"
|
||||
|
||||
# Special Characters in Path or Query
|
||||
assert (
|
||||
validate_url("example.com/path%20with%20spaces", [])
|
||||
== "http://example.com/path%20with%20spaces"
|
||||
)
|
||||
|
||||
# Backslashes should be replaced with forward slashes
|
||||
assert (
|
||||
validate_url("http://example.com\\backslash", [])
|
||||
== "http://example.com/backslash"
|
||||
)
|
||||
|
||||
# Check defaulting scheme behavior for valid domains
|
||||
assert validate_url("example.com", []) == "http://example.com"
|
||||
assert validate_url("https://secure.com", []) == "https://secure.com"
|
||||
|
||||
# Non-ASCII Characters in Query/Fragment
|
||||
assert validate_url("example.com?param=äöü", []) == "http://example.com?param=äöü"
|
||||
if expect_error:
|
||||
# If any IP is blocked, we expect a ValueError
|
||||
with pytest.raises(ValueError):
|
||||
validate_url(hostname, [])
|
||||
else:
|
||||
pinned_url, ascii_hostname = validate_url(hostname, [])
|
||||
# The pinned_url should contain the first valid IP
|
||||
assert pinned_url.startswith("http://") or pinned_url.startswith("https://")
|
||||
assert expected_ip in pinned_url
|
||||
# The ascii_hostname should match our original hostname after IDNA encoding
|
||||
assert ascii_hostname == hostname
|
||||
|
||||
Reference in New Issue
Block a user