Merge commit from fork

* Fix SSRF via user-controlled ollama_host field

Validate ollama_host against BLOCKED_IP_NETWORKS before passing to
ollama.AsyncClient(). The server-configured default (env: OLLAMA_HOST)
is allowed without validation; user-supplied values that differ are
checked for private/internal IP resolution.

Fixes GHSA-6jx2-4h7q-3fx3

* Generalize validate_ollama_host to validate_host; fix description line length

* Rename to validate_untrusted_host with whitelist parameter

* Apply PR suggestion: include whitelist in error message; run formatting

* Move whitelist check after URL normalization; match on netloc

* revert unrelated formatting changes

* Dedup validate_url and validate_untrusted_host; normalize whitelist

* Move _resolve_and_check_blocked after calling functions

* dedup and clean up

* make trusted_hostnames truly optional

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
This commit is contained in:
Otto
2026-03-10 14:51:58 +00:00
committed by GitHub
parent 684845d946
commit bfb843a56e
8 changed files with 109 additions and 68 deletions

View File

@@ -24,7 +24,7 @@ from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.request import HTTPClientError, Requests, validate_url
from backend.util.request import HTTPClientError, Requests, validate_url_host
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -80,7 +80,7 @@ async def discover_tools(
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
await validate_url_host(request.server_url)
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
@@ -167,7 +167,7 @@ async def mcp_oauth_login(
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
await validate_url_host(request.server_url)
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
@@ -187,7 +187,7 @@ async def mcp_oauth_login(
# Validate the auth server URL from metadata to prevent SSRF.
try:
await validate_url(auth_server_url, trusted_origins=[])
await validate_url_host(auth_server_url)
except ValueError as e:
raise fastapi.HTTPException(
status_code=400,
@@ -234,7 +234,7 @@ async def mcp_oauth_login(
if registration_endpoint:
# Validate the registration endpoint to prevent SSRF via metadata.
try:
await validate_url(registration_endpoint, trusted_origins=[])
await validate_url_host(registration_endpoint)
except ValueError:
pass # Skip registration, fall back to default client_id
else:
@@ -429,7 +429,7 @@ async def mcp_store_token(
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
await validate_url_host(request.server_url)
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")

View File

@@ -17,7 +17,7 @@ from backend.blocks.jina._auth import (
from backend.blocks.search import GetRequest
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.request import HTTPClientError, HTTPServerError, validate_url
from backend.util.request import HTTPClientError, HTTPServerError, validate_url_host
class SearchTheWebBlock(Block, GetRequest):
@@ -112,7 +112,7 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
) -> BlockOutput:
if input_data.raw_content:
try:
parsed_url, _, _ = await validate_url(input_data.url, [])
parsed_url, _, _ = await validate_url_host(input_data.url)
url = parsed_url.geturl()
except ValueError as e:
yield "error", f"Invalid URL: {e}"

View File

@@ -34,8 +34,11 @@ from backend.util import json
from backend.util.clients import OPENROUTER_BASE_URL
from backend.util.logging import TruncatedLogger
from backend.util.prompt import compress_context, estimate_token_count
from backend.util.request import validate_url_host
from backend.util.settings import Settings
from backend.util.text import TextFormatter
settings = Settings()
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
fmt = TextFormatter(autoescape=False)
@@ -805,6 +808,11 @@ async def llm_call(
if tools:
raise ValueError("Ollama does not support tools.")
# Validate user-provided Ollama host to prevent SSRF etc.
await validate_url_host(
ollama_host, trusted_hostnames=[settings.config.ollama_host]
)
client = ollama.AsyncClient(host=ollama_host)
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]

View File

@@ -33,7 +33,7 @@ import tempfile
from typing import Any
from backend.copilot.model import ChatSession
from backend.util.request import validate_url
from backend.util.request import validate_url_host
from .base import BaseTool
from .models import (
@@ -235,7 +235,7 @@ async def _restore_browser_state(
if url:
# Validate the saved URL to prevent SSRF via stored redirect targets.
try:
await validate_url(url, trusted_origins=[])
await validate_url_host(url)
except ValueError:
logger.warning(
"[browser] State restore: blocked SSRF URL %s", url[:200]
@@ -473,7 +473,7 @@ class BrowserNavigateTool(BaseTool):
)
try:
await validate_url(url, trusted_origins=[])
await validate_url_host(url)
except ValueError as e:
return ErrorResponse(
message=str(e),

View File

@@ -14,7 +14,7 @@ from backend.blocks.mcp.helpers import (
)
from backend.copilot.model import ChatSession
from backend.copilot.tools.utils import build_missing_credentials_from_field_info
from backend.util.request import HTTPClientError, validate_url
from backend.util.request import HTTPClientError, validate_url_host
from .base import BaseTool
from .models import (
@@ -144,7 +144,7 @@ class RunMCPToolTool(BaseTool):
# Validate URL to prevent SSRF — blocks loopback and private IP ranges
try:
await validate_url(server_url, trusted_origins=[])
await validate_url_host(server_url)
except ValueError as e:
msg = str(e)
if "Unable to resolve" in msg or "No IP addresses" in msg:

View File

@@ -144,76 +144,106 @@ async def _resolve_host(hostname: str) -> list[str]:
return ip_addresses
async def validate_url(
url: str, trusted_origins: list[str]
async def validate_url_host(
url: str, trusted_hostnames: Optional[list[str]] = None
) -> tuple[URL, bool, list[str]]:
"""
Validates the URL to prevent SSRF attacks by ensuring it does not point
to a private, link-local, or otherwise blocked IP address — unless
Validates a (URL's) host string to prevent SSRF attacks by ensuring it does not
point to a private, link-local, or otherwise blocked IP address — unless
the hostname is explicitly trusted.
Hosts in `trusted_hostnames` are permitted without checks.
All other hosts are resolved and checked against `BLOCKED_IP_NETWORKS`.
Params:
url: A hostname, netloc, or URL to validate.
If no scheme is included, `http://` is assumed.
trusted_hostnames: A list of hostnames that don't require validation.
Raises:
ValueError:
- if the URL has a disallowed URL scheme
- if the URL/host string can't be parsed
- if the hostname contains invalid or unsupported (non-ASCII) characters
- if the host resolves to a blocked IP
Returns:
str: The validated, canonicalized, parsed URL
is_trusted: Boolean indicating if the hostname is in trusted_origins
ip_addresses: List of IP addresses for the host; empty if the host is trusted
1. The validated, canonicalized, parsed host/URL,
with hostname ASCII-safe encoding
2. Whether the host is trusted (based on the passed `trusted_hostnames`).
3. List of resolved IP addresses for the host; empty if the host is trusted.
"""
parsed = parse_url(url)
# Check scheme
if parsed.scheme not in ALLOWED_SCHEMES:
raise ValueError(
f"Scheme '{parsed.scheme}' is not allowed. Only HTTP/HTTPS are supported."
f"URL scheme '{parsed.scheme}' is not allowed; allowed schemes: "
f"{', '.join(ALLOWED_SCHEMES)}"
)
# Validate and IDNA encode hostname
if not parsed.hostname:
raise ValueError("Invalid URL: No hostname found.")
raise ValueError(f"Invalid host/URL; no host in parse result: {url}")
# IDNA encode to prevent Unicode domain attacks
try:
ascii_hostname = idna.encode(parsed.hostname).decode("ascii")
except idna.IDNAError:
raise ValueError("Invalid hostname with unsupported characters.")
raise ValueError(f"Hostname '{parsed.hostname}' has unsupported characters")
# Check hostname characters
if not HOSTNAME_REGEX.match(ascii_hostname):
raise ValueError("Hostname contains invalid characters.")
raise ValueError(f"Hostname '{parsed.hostname}' has unsupported characters")
# Check if hostname is trusted
is_trusted = ascii_hostname in trusted_origins
# If not trusted, validate IP addresses
ip_addresses: list[str] = []
if not is_trusted:
# Resolve all IP addresses for the hostname
ip_addresses = await _resolve_host(ascii_hostname)
# Block any IP address that belongs to a blocked range
for ip_str in ip_addresses:
if _is_ip_blocked(ip_str):
raise ValueError(
f"Access to blocked or private IP address {ip_str} "
f"for hostname {ascii_hostname} is not allowed."
)
# Reconstruct the netloc with IDNA-encoded hostname and preserve port
netloc = ascii_hostname
if parsed.port:
netloc = f"{ascii_hostname}:{parsed.port}"
return (
URL(
parsed.scheme,
netloc,
quote(parsed.path, safe="/%:@"),
parsed.params,
parsed.query,
parsed.fragment,
),
is_trusted,
ip_addresses,
# Re-create parsed URL object with IDNA-encoded hostname
parsed = URL(
parsed.scheme,
(ascii_hostname if parsed.port is None else f"{ascii_hostname}:{parsed.port}"),
quote(parsed.path, safe="/%:@"),
parsed.params,
parsed.query,
parsed.fragment,
)
is_trusted = trusted_hostnames and any(
matches_allowed_host(parsed, allowed)
for allowed in (
# Normalize + parse allowlist entries the same way for consistent matching
parse_url(w)
for w in trusted_hostnames
)
)
if is_trusted:
return parsed, True, []
# If not allowlisted, go ahead with host resolution and IP target check
return parsed, False, await _resolve_and_check_blocked(ascii_hostname)
def matches_allowed_host(url: URL, allowed: URL) -> bool:
if url.hostname != allowed.hostname:
return False
# Allow any port if not explicitly specified in the allowlist
if allowed.port is None:
return True
return url.port == allowed.port
async def _resolve_and_check_blocked(hostname: str) -> list[str]:
"""
Resolves hostname to IPs and raises ValueError if any resolve to
a blocked network. Returns the list of resolved IP addresses.
"""
ip_addresses = await _resolve_host(hostname)
for ip_str in ip_addresses:
if _is_ip_blocked(ip_str):
raise ValueError(
f"Access to blocked or private IP address {ip_str} "
f"for hostname {hostname} is not allowed."
)
return ip_addresses
def parse_url(url: str) -> URL:
"""Canonicalizes and parses a URL string."""
@@ -352,7 +382,7 @@ class Requests:
):
self.trusted_origins = []
for url in trusted_origins or []:
hostname = urlparse(url).hostname
hostname = parse_url(url).netloc # {host}[:{port}]
if not hostname:
raise ValueError(f"Invalid URL: Unable to determine hostname of {url}")
self.trusted_origins.append(hostname)
@@ -450,7 +480,7 @@ class Requests:
data = form
# Validate URL and get trust status
parsed_url, is_trusted, ip_addresses = await validate_url(
parsed_url, is_trusted, ip_addresses = await validate_url_host(
url, self.trusted_origins
)
@@ -503,7 +533,6 @@ class Requests:
json=json,
**kwargs,
) as response:
if self.raise_for_status:
try:
response.raise_for_status()

View File

@@ -1,7 +1,7 @@
import pytest
from aiohttp import web
from backend.util.request import pin_url, validate_url
from backend.util.request import pin_url, validate_url_host
@pytest.mark.parametrize(
@@ -60,9 +60,9 @@ async def test_validate_url_no_dns_rebinding(
):
if should_raise:
with pytest.raises(ValueError):
await validate_url(raw_url, trusted_origins)
await validate_url_host(raw_url, trusted_origins)
else:
validated_url, _, _ = await validate_url(raw_url, trusted_origins)
validated_url, _, _ = await validate_url_host(raw_url, trusted_origins)
assert validated_url.geturl() == expected_value
@@ -101,10 +101,10 @@ async def test_dns_rebinding_fix(
if expect_error:
# If any IP is blocked, we expect a ValueError
with pytest.raises(ValueError):
url, _, ip_addresses = await validate_url(hostname, [])
url, _, ip_addresses = await validate_url_host(hostname)
pin_url(url, ip_addresses)
else:
url, _, ip_addresses = await validate_url(hostname, [])
url, _, ip_addresses = await validate_url_host(hostname)
pinned_url = pin_url(url, ip_addresses).geturl()
# The pinned_url should contain the first valid IP
assert pinned_url.startswith("http://") or pinned_url.startswith("https://")

View File

@@ -89,6 +89,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
le=500,
description="Thread pool size for FastAPI sync operations. All sync endpoints and dependencies automatically use this pool. Higher values support more concurrent sync operations but use more memory.",
)
ollama_host: str = Field(
default="localhost:11434",
description="Default Ollama host; exempted from SSRF checks.",
)
pyro_host: str = Field(
default="localhost",
description="The default hostname of the Pyro server.",