mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Merge commit from fork
* Fix SSRF via user-controlled ollama_host field Validate ollama_host against BLOCKED_IP_NETWORKS before passing to ollama.AsyncClient(). The server-configured default (env: OLLAMA_HOST) is allowed without validation; user-supplied values that differ are checked for private/internal IP resolution. Fixes GHSA-6jx2-4h7q-3fx3 * Generalize validate_ollama_host to validate_host; fix description line length * Rename to validate_untrusted_host with whitelist parameter * Apply PR suggestion: include whitelist in error message; run formatting * Move whitelist check after URL normalization; match on netloc * revert unrelated formatting changes * Dedup validate_url and validate_untrusted_host; normalize whitelist * Move _resolve_and_check_blocked after calling functions * dedup and clean up * make trusted_hostnames truly optional --------- Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
This commit is contained in:
@@ -24,7 +24,7 @@ from backend.blocks.mcp.oauth import MCPOAuthHandler
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import HTTPClientError, Requests, validate_url
|
||||
from backend.util.request import HTTPClientError, Requests, validate_url_host
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -80,7 +80,7 @@ async def discover_tools(
|
||||
"""
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||
try:
|
||||
await validate_url(request.server_url, trusted_origins=[])
|
||||
await validate_url_host(request.server_url)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||
|
||||
@@ -167,7 +167,7 @@ async def mcp_oauth_login(
|
||||
"""
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||
try:
|
||||
await validate_url(request.server_url, trusted_origins=[])
|
||||
await validate_url_host(request.server_url)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||
|
||||
@@ -187,7 +187,7 @@ async def mcp_oauth_login(
|
||||
|
||||
# Validate the auth server URL from metadata to prevent SSRF.
|
||||
try:
|
||||
await validate_url(auth_server_url, trusted_origins=[])
|
||||
await validate_url_host(auth_server_url)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=400,
|
||||
@@ -234,7 +234,7 @@ async def mcp_oauth_login(
|
||||
if registration_endpoint:
|
||||
# Validate the registration endpoint to prevent SSRF via metadata.
|
||||
try:
|
||||
await validate_url(registration_endpoint, trusted_origins=[])
|
||||
await validate_url_host(registration_endpoint)
|
||||
except ValueError:
|
||||
pass # Skip registration, fall back to default client_id
|
||||
else:
|
||||
@@ -429,7 +429,7 @@ async def mcp_store_token(
|
||||
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||
try:
|
||||
await validate_url(request.server_url, trusted_origins=[])
|
||||
await validate_url_host(request.server_url)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from backend.blocks.jina._auth import (
|
||||
from backend.blocks.search import GetRequest
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.request import HTTPClientError, HTTPServerError, validate_url
|
||||
from backend.util.request import HTTPClientError, HTTPServerError, validate_url_host
|
||||
|
||||
|
||||
class SearchTheWebBlock(Block, GetRequest):
|
||||
@@ -112,7 +112,7 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
|
||||
) -> BlockOutput:
|
||||
if input_data.raw_content:
|
||||
try:
|
||||
parsed_url, _, _ = await validate_url(input_data.url, [])
|
||||
parsed_url, _, _ = await validate_url_host(input_data.url)
|
||||
url = parsed_url.geturl()
|
||||
except ValueError as e:
|
||||
yield "error", f"Invalid URL: {e}"
|
||||
|
||||
@@ -34,8 +34,11 @@ from backend.util import json
|
||||
from backend.util.clients import OPENROUTER_BASE_URL
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.prompt import compress_context, estimate_token_count
|
||||
from backend.util.request import validate_url_host
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.text import TextFormatter
|
||||
|
||||
settings = Settings()
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
|
||||
fmt = TextFormatter(autoescape=False)
|
||||
|
||||
@@ -805,6 +808,11 @@ async def llm_call(
|
||||
if tools:
|
||||
raise ValueError("Ollama does not support tools.")
|
||||
|
||||
# Validate user-provided Ollama host to prevent SSRF etc.
|
||||
await validate_url_host(
|
||||
ollama_host, trusted_hostnames=[settings.config.ollama_host]
|
||||
)
|
||||
|
||||
client = ollama.AsyncClient(host=ollama_host)
|
||||
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
||||
|
||||
@@ -33,7 +33,7 @@ import tempfile
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.util.request import validate_url
|
||||
from backend.util.request import validate_url_host
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
@@ -235,7 +235,7 @@ async def _restore_browser_state(
|
||||
if url:
|
||||
# Validate the saved URL to prevent SSRF via stored redirect targets.
|
||||
try:
|
||||
await validate_url(url, trusted_origins=[])
|
||||
await validate_url_host(url)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"[browser] State restore: blocked SSRF URL %s", url[:200]
|
||||
@@ -473,7 +473,7 @@ class BrowserNavigateTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
await validate_url(url, trusted_origins=[])
|
||||
await validate_url_host(url)
|
||||
except ValueError as e:
|
||||
return ErrorResponse(
|
||||
message=str(e),
|
||||
|
||||
@@ -14,7 +14,7 @@ from backend.blocks.mcp.helpers import (
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.utils import build_missing_credentials_from_field_info
|
||||
from backend.util.request import HTTPClientError, validate_url
|
||||
from backend.util.request import HTTPClientError, validate_url_host
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
@@ -144,7 +144,7 @@ class RunMCPToolTool(BaseTool):
|
||||
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges
|
||||
try:
|
||||
await validate_url(server_url, trusted_origins=[])
|
||||
await validate_url_host(server_url)
|
||||
except ValueError as e:
|
||||
msg = str(e)
|
||||
if "Unable to resolve" in msg or "No IP addresses" in msg:
|
||||
|
||||
@@ -144,76 +144,106 @@ async def _resolve_host(hostname: str) -> list[str]:
|
||||
return ip_addresses
|
||||
|
||||
|
||||
async def validate_url(
|
||||
url: str, trusted_origins: list[str]
|
||||
async def validate_url_host(
|
||||
url: str, trusted_hostnames: Optional[list[str]] = None
|
||||
) -> tuple[URL, bool, list[str]]:
|
||||
"""
|
||||
Validates the URL to prevent SSRF attacks by ensuring it does not point
|
||||
to a private, link-local, or otherwise blocked IP address — unless
|
||||
Validates a (URL's) host string to prevent SSRF attacks by ensuring it does not
|
||||
point to a private, link-local, or otherwise blocked IP address — unless
|
||||
the hostname is explicitly trusted.
|
||||
|
||||
Hosts in `trusted_hostnames` are permitted without checks.
|
||||
All other hosts are resolved and checked against `BLOCKED_IP_NETWORKS`.
|
||||
|
||||
Params:
|
||||
url: A hostname, netloc, or URL to validate.
|
||||
If no scheme is included, `http://` is assumed.
|
||||
trusted_hostnames: A list of hostnames that don't require validation.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
- if the URL has a disallowed URL scheme
|
||||
- if the URL/host string can't be parsed
|
||||
- if the hostname contains invalid or unsupported (non-ASCII) characters
|
||||
- if the host resolves to a blocked IP
|
||||
|
||||
Returns:
|
||||
str: The validated, canonicalized, parsed URL
|
||||
is_trusted: Boolean indicating if the hostname is in trusted_origins
|
||||
ip_addresses: List of IP addresses for the host; empty if the host is trusted
|
||||
1. The validated, canonicalized, parsed host/URL,
|
||||
with hostname ASCII-safe encoding
|
||||
2. Whether the host is trusted (based on the passed `trusted_hostnames`).
|
||||
3. List of resolved IP addresses for the host; empty if the host is trusted.
|
||||
"""
|
||||
parsed = parse_url(url)
|
||||
|
||||
# Check scheme
|
||||
if parsed.scheme not in ALLOWED_SCHEMES:
|
||||
raise ValueError(
|
||||
f"Scheme '{parsed.scheme}' is not allowed. Only HTTP/HTTPS are supported."
|
||||
f"URL scheme '{parsed.scheme}' is not allowed; allowed schemes: "
|
||||
f"{', '.join(ALLOWED_SCHEMES)}"
|
||||
)
|
||||
|
||||
# Validate and IDNA encode hostname
|
||||
if not parsed.hostname:
|
||||
raise ValueError("Invalid URL: No hostname found.")
|
||||
raise ValueError(f"Invalid host/URL; no host in parse result: {url}")
|
||||
|
||||
# IDNA encode to prevent Unicode domain attacks
|
||||
try:
|
||||
ascii_hostname = idna.encode(parsed.hostname).decode("ascii")
|
||||
except idna.IDNAError:
|
||||
raise ValueError("Invalid hostname with unsupported characters.")
|
||||
raise ValueError(f"Hostname '{parsed.hostname}' has unsupported characters")
|
||||
|
||||
# Check hostname characters
|
||||
if not HOSTNAME_REGEX.match(ascii_hostname):
|
||||
raise ValueError("Hostname contains invalid characters.")
|
||||
raise ValueError(f"Hostname '{parsed.hostname}' has unsupported characters")
|
||||
|
||||
# Check if hostname is trusted
|
||||
is_trusted = ascii_hostname in trusted_origins
|
||||
|
||||
# If not trusted, validate IP addresses
|
||||
ip_addresses: list[str] = []
|
||||
if not is_trusted:
|
||||
# Resolve all IP addresses for the hostname
|
||||
ip_addresses = await _resolve_host(ascii_hostname)
|
||||
|
||||
# Block any IP address that belongs to a blocked range
|
||||
for ip_str in ip_addresses:
|
||||
if _is_ip_blocked(ip_str):
|
||||
raise ValueError(
|
||||
f"Access to blocked or private IP address {ip_str} "
|
||||
f"for hostname {ascii_hostname} is not allowed."
|
||||
)
|
||||
|
||||
# Reconstruct the netloc with IDNA-encoded hostname and preserve port
|
||||
netloc = ascii_hostname
|
||||
if parsed.port:
|
||||
netloc = f"{ascii_hostname}:{parsed.port}"
|
||||
|
||||
return (
|
||||
URL(
|
||||
parsed.scheme,
|
||||
netloc,
|
||||
quote(parsed.path, safe="/%:@"),
|
||||
parsed.params,
|
||||
parsed.query,
|
||||
parsed.fragment,
|
||||
),
|
||||
is_trusted,
|
||||
ip_addresses,
|
||||
# Re-create parsed URL object with IDNA-encoded hostname
|
||||
parsed = URL(
|
||||
parsed.scheme,
|
||||
(ascii_hostname if parsed.port is None else f"{ascii_hostname}:{parsed.port}"),
|
||||
quote(parsed.path, safe="/%:@"),
|
||||
parsed.params,
|
||||
parsed.query,
|
||||
parsed.fragment,
|
||||
)
|
||||
|
||||
is_trusted = trusted_hostnames and any(
|
||||
matches_allowed_host(parsed, allowed)
|
||||
for allowed in (
|
||||
# Normalize + parse allowlist entries the same way for consistent matching
|
||||
parse_url(w)
|
||||
for w in trusted_hostnames
|
||||
)
|
||||
)
|
||||
|
||||
if is_trusted:
|
||||
return parsed, True, []
|
||||
|
||||
# If not allowlisted, go ahead with host resolution and IP target check
|
||||
return parsed, False, await _resolve_and_check_blocked(ascii_hostname)
|
||||
|
||||
|
||||
def matches_allowed_host(url: URL, allowed: URL) -> bool:
|
||||
if url.hostname != allowed.hostname:
|
||||
return False
|
||||
|
||||
# Allow any port if not explicitly specified in the allowlist
|
||||
if allowed.port is None:
|
||||
return True
|
||||
|
||||
return url.port == allowed.port
|
||||
|
||||
|
||||
async def _resolve_and_check_blocked(hostname: str) -> list[str]:
|
||||
"""
|
||||
Resolves hostname to IPs and raises ValueError if any resolve to
|
||||
a blocked network. Returns the list of resolved IP addresses.
|
||||
"""
|
||||
ip_addresses = await _resolve_host(hostname)
|
||||
for ip_str in ip_addresses:
|
||||
if _is_ip_blocked(ip_str):
|
||||
raise ValueError(
|
||||
f"Access to blocked or private IP address {ip_str} "
|
||||
f"for hostname {hostname} is not allowed."
|
||||
)
|
||||
return ip_addresses
|
||||
|
||||
|
||||
def parse_url(url: str) -> URL:
|
||||
"""Canonicalizes and parses a URL string."""
|
||||
@@ -352,7 +382,7 @@ class Requests:
|
||||
):
|
||||
self.trusted_origins = []
|
||||
for url in trusted_origins or []:
|
||||
hostname = urlparse(url).hostname
|
||||
hostname = parse_url(url).netloc # {host}[:{port}]
|
||||
if not hostname:
|
||||
raise ValueError(f"Invalid URL: Unable to determine hostname of {url}")
|
||||
self.trusted_origins.append(hostname)
|
||||
@@ -450,7 +480,7 @@ class Requests:
|
||||
data = form
|
||||
|
||||
# Validate URL and get trust status
|
||||
parsed_url, is_trusted, ip_addresses = await validate_url(
|
||||
parsed_url, is_trusted, ip_addresses = await validate_url_host(
|
||||
url, self.trusted_origins
|
||||
)
|
||||
|
||||
@@ -503,7 +533,6 @@ class Requests:
|
||||
json=json,
|
||||
**kwargs,
|
||||
) as response:
|
||||
|
||||
if self.raise_for_status:
|
||||
try:
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from backend.util.request import pin_url, validate_url
|
||||
from backend.util.request import pin_url, validate_url_host
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -60,9 +60,9 @@ async def test_validate_url_no_dns_rebinding(
|
||||
):
|
||||
if should_raise:
|
||||
with pytest.raises(ValueError):
|
||||
await validate_url(raw_url, trusted_origins)
|
||||
await validate_url_host(raw_url, trusted_origins)
|
||||
else:
|
||||
validated_url, _, _ = await validate_url(raw_url, trusted_origins)
|
||||
validated_url, _, _ = await validate_url_host(raw_url, trusted_origins)
|
||||
assert validated_url.geturl() == expected_value
|
||||
|
||||
|
||||
@@ -101,10 +101,10 @@ async def test_dns_rebinding_fix(
|
||||
if expect_error:
|
||||
# If any IP is blocked, we expect a ValueError
|
||||
with pytest.raises(ValueError):
|
||||
url, _, ip_addresses = await validate_url(hostname, [])
|
||||
url, _, ip_addresses = await validate_url_host(hostname)
|
||||
pin_url(url, ip_addresses)
|
||||
else:
|
||||
url, _, ip_addresses = await validate_url(hostname, [])
|
||||
url, _, ip_addresses = await validate_url_host(hostname)
|
||||
pinned_url = pin_url(url, ip_addresses).geturl()
|
||||
# The pinned_url should contain the first valid IP
|
||||
assert pinned_url.startswith("http://") or pinned_url.startswith("https://")
|
||||
|
||||
@@ -89,6 +89,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
le=500,
|
||||
description="Thread pool size for FastAPI sync operations. All sync endpoints and dependencies automatically use this pool. Higher values support more concurrent sync operations but use more memory.",
|
||||
)
|
||||
ollama_host: str = Field(
|
||||
default="localhost:11434",
|
||||
description="Default Ollama host; exempted from SSRF checks.",
|
||||
)
|
||||
pyro_host: str = Field(
|
||||
default="localhost",
|
||||
description="The default hostname of the Pyro server.",
|
||||
|
||||
Reference in New Issue
Block a user