diff --git a/autogpt_platform/backend/backend/api/features/mcp/routes.py b/autogpt_platform/backend/backend/api/features/mcp/routes.py index cdf827feed..87dbb64b77 100644 --- a/autogpt_platform/backend/backend/api/features/mcp/routes.py +++ b/autogpt_platform/backend/backend/api/features/mcp/routes.py @@ -24,7 +24,7 @@ from backend.blocks.mcp.oauth import MCPOAuthHandler from backend.data.model import OAuth2Credentials from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.integrations.providers import ProviderName -from backend.util.request import HTTPClientError, Requests, validate_url +from backend.util.request import HTTPClientError, Requests, validate_url_host from backend.util.settings import Settings logger = logging.getLogger(__name__) @@ -80,7 +80,7 @@ async def discover_tools( """ # Validate URL to prevent SSRF — blocks loopback and private IP ranges. try: - await validate_url(request.server_url, trusted_origins=[]) + await validate_url_host(request.server_url) except ValueError as e: raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}") @@ -167,7 +167,7 @@ async def mcp_oauth_login( """ # Validate URL to prevent SSRF — blocks loopback and private IP ranges. try: - await validate_url(request.server_url, trusted_origins=[]) + await validate_url_host(request.server_url) except ValueError as e: raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}") @@ -187,7 +187,7 @@ async def mcp_oauth_login( # Validate the auth server URL from metadata to prevent SSRF. try: - await validate_url(auth_server_url, trusted_origins=[]) + await validate_url_host(auth_server_url) except ValueError as e: raise fastapi.HTTPException( status_code=400, @@ -234,7 +234,7 @@ async def mcp_oauth_login( if registration_endpoint: # Validate the registration endpoint to prevent SSRF via metadata. try: - await validate_url(registration_endpoint, trusted_origins=[]) + await validate_url_host(registration_endpoint) except ValueError: pass # Skip registration, fall back to default client_id else: @@ -429,7 +429,7 @@ async def mcp_store_token( # Validate URL to prevent SSRF — blocks loopback and private IP ranges. try: - await validate_url(request.server_url, trusted_origins=[]) + await validate_url_host(request.server_url) except ValueError as e: raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}") diff --git a/autogpt_platform/backend/backend/blocks/jina/search.py b/autogpt_platform/backend/backend/blocks/jina/search.py index 5e58ddcab4..007dd5bc12 100644 --- a/autogpt_platform/backend/backend/blocks/jina/search.py +++ b/autogpt_platform/backend/backend/blocks/jina/search.py @@ -17,7 +17,7 @@ from backend.blocks.jina._auth import ( from backend.blocks.search import GetRequest from backend.data.model import SchemaField from backend.util.exceptions import BlockExecutionError -from backend.util.request import HTTPClientError, HTTPServerError, validate_url +from backend.util.request import HTTPClientError, HTTPServerError, validate_url_host class SearchTheWebBlock(Block, GetRequest): @@ -112,7 +112,7 @@ class ExtractWebsiteContentBlock(Block, GetRequest): ) -> BlockOutput: if input_data.raw_content: try: - parsed_url, _, _ = await validate_url(input_data.url, []) + parsed_url, _, _ = await validate_url_host(input_data.url) url = parsed_url.geturl() except ValueError as e: yield "error", f"Invalid URL: {e}" diff --git a/autogpt_platform/backend/backend/blocks/llm.py b/autogpt_platform/backend/backend/blocks/llm.py index c71aec8235..366ab5cca9 100644 --- a/autogpt_platform/backend/backend/blocks/llm.py +++ b/autogpt_platform/backend/backend/blocks/llm.py @@ -34,8 +34,11 @@ from backend.util import json from backend.util.clients import OPENROUTER_BASE_URL from backend.util.logging import TruncatedLogger from backend.util.prompt import compress_context, estimate_token_count +from backend.util.request import validate_url_host +from backend.util.settings import Settings from backend.util.text import TextFormatter +settings = Settings() logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]") fmt = TextFormatter(autoescape=False) @@ -805,6 +808,11 @@ async def llm_call( if tools: raise ValueError("Ollama does not support tools.") + # Validate user-provided Ollama host to prevent SSRF etc. + await validate_url_host( + ollama_host, trusted_hostnames=[settings.config.ollama_host] + ) + client = ollama.AsyncClient(host=ollama_host) sys_messages = [p["content"] for p in prompt if p["role"] == "system"] usr_messages = [p["content"] for p in prompt if p["role"] != "system"] diff --git a/autogpt_platform/backend/backend/copilot/tools/agent_browser.py b/autogpt_platform/backend/backend/copilot/tools/agent_browser.py index 6e76e1c62a..8bec85dd3f 100644 --- a/autogpt_platform/backend/backend/copilot/tools/agent_browser.py +++ b/autogpt_platform/backend/backend/copilot/tools/agent_browser.py @@ -33,7 +33,7 @@ import tempfile from typing import Any from backend.copilot.model import ChatSession -from backend.util.request import validate_url +from backend.util.request import validate_url_host from .base import BaseTool from .models import ( @@ -235,7 +235,7 @@ async def _restore_browser_state( if url: # Validate the saved URL to prevent SSRF via stored redirect targets. try: - await validate_url(url, trusted_origins=[]) + await validate_url_host(url) except ValueError: logger.warning( "[browser] State restore: blocked SSRF URL %s", url[:200] @@ -473,7 +473,7 @@ class BrowserNavigateTool(BaseTool): ) try: - await validate_url(url, trusted_origins=[]) + await validate_url_host(url) except ValueError as e: return ErrorResponse( message=str(e), diff --git a/autogpt_platform/backend/backend/copilot/tools/run_mcp_tool.py b/autogpt_platform/backend/backend/copilot/tools/run_mcp_tool.py index e8f9d3842c..3aead12f63 100644 --- a/autogpt_platform/backend/backend/copilot/tools/run_mcp_tool.py +++ b/autogpt_platform/backend/backend/copilot/tools/run_mcp_tool.py @@ -14,7 +14,7 @@ from backend.blocks.mcp.helpers import ( ) from backend.copilot.model import ChatSession from backend.copilot.tools.utils import build_missing_credentials_from_field_info -from backend.util.request import HTTPClientError, validate_url +from backend.util.request import HTTPClientError, validate_url_host from .base import BaseTool from .models import ( @@ -144,7 +144,7 @@ class RunMCPToolTool(BaseTool): # Validate URL to prevent SSRF — blocks loopback and private IP ranges try: - await validate_url(server_url, trusted_origins=[]) + await validate_url_host(server_url) except ValueError as e: msg = str(e) if "Unable to resolve" in msg or "No IP addresses" in msg: diff --git a/autogpt_platform/backend/backend/util/request.py b/autogpt_platform/backend/backend/util/request.py index 924200ad29..fde69f7d48 100644 --- a/autogpt_platform/backend/backend/util/request.py +++ b/autogpt_platform/backend/backend/util/request.py @@ -144,76 +144,106 @@ async def _resolve_host(hostname: str) -> list[str]: return ip_addresses -async def validate_url( - url: str, trusted_origins: list[str] +async def validate_url_host( + url: str, trusted_hostnames: Optional[list[str]] = None ) -> tuple[URL, bool, list[str]]: """ - Validates the URL to prevent SSRF attacks by ensuring it does not point - to a private, link-local, or otherwise blocked IP address — unless + Validates a (URL's) host string to prevent SSRF attacks by ensuring it does not + point to a private, link-local, or otherwise blocked IP address — unless the hostname is explicitly trusted. + Hosts in `trusted_hostnames` are permitted without checks. + All other hosts are resolved and checked against `BLOCKED_IP_NETWORKS`. + + Params: + url: A hostname, netloc, or URL to validate. + If no scheme is included, `http://` is assumed. + trusted_hostnames: A list of hostnames that don't require validation. + + Raises: + ValueError: + - if the URL has a disallowed URL scheme + - if the URL/host string can't be parsed + - if the hostname contains invalid or unsupported (non-ASCII) characters + - if the host resolves to a blocked IP + Returns: - str: The validated, canonicalized, parsed URL - is_trusted: Boolean indicating if the hostname is in trusted_origins - ip_addresses: List of IP addresses for the host; empty if the host is trusted + 1. The validated, canonicalized, parsed host/URL, + with hostname ASCII-safe encoding + 2. Whether the host is trusted (based on the passed `trusted_hostnames`). + 3. List of resolved IP addresses for the host; empty if the host is trusted. """ parsed = parse_url(url) - # Check scheme if parsed.scheme not in ALLOWED_SCHEMES: raise ValueError( - f"Scheme '{parsed.scheme}' is not allowed. Only HTTP/HTTPS are supported." + f"URL scheme '{parsed.scheme}' is not allowed; allowed schemes: " + f"{', '.join(ALLOWED_SCHEMES)}" ) - # Validate and IDNA encode hostname if not parsed.hostname: - raise ValueError("Invalid URL: No hostname found.") + raise ValueError(f"Invalid host/URL; no host in parse result: {url}") # IDNA encode to prevent Unicode domain attacks try: ascii_hostname = idna.encode(parsed.hostname).decode("ascii") except idna.IDNAError: - raise ValueError("Invalid hostname with unsupported characters.") + raise ValueError(f"Hostname '{parsed.hostname}' has unsupported characters") - # Check hostname characters if not HOSTNAME_REGEX.match(ascii_hostname): - raise ValueError("Hostname contains invalid characters.") + raise ValueError(f"Hostname '{parsed.hostname}' has unsupported characters") - # Check if hostname is trusted - is_trusted = ascii_hostname in trusted_origins - - # If not trusted, validate IP addresses - ip_addresses: list[str] = [] - if not is_trusted: - # Resolve all IP addresses for the hostname - ip_addresses = await _resolve_host(ascii_hostname) - - # Block any IP address that belongs to a blocked range - for ip_str in ip_addresses: - if _is_ip_blocked(ip_str): - raise ValueError( - f"Access to blocked or private IP address {ip_str} " - f"for hostname {ascii_hostname} is not allowed." - ) - - # Reconstruct the netloc with IDNA-encoded hostname and preserve port - netloc = ascii_hostname - if parsed.port: - netloc = f"{ascii_hostname}:{parsed.port}" - - return ( - URL( - parsed.scheme, - netloc, - quote(parsed.path, safe="/%:@"), - parsed.params, - parsed.query, - parsed.fragment, - ), - is_trusted, - ip_addresses, + # Re-create parsed URL object with IDNA-encoded hostname + parsed = URL( + parsed.scheme, + (ascii_hostname if parsed.port is None else f"{ascii_hostname}:{parsed.port}"), + quote(parsed.path, safe="/%:@"), + parsed.params, + parsed.query, + parsed.fragment, ) + is_trusted = trusted_hostnames and any( + matches_allowed_host(parsed, allowed) + for allowed in ( + # Normalize + parse allowlist entries the same way for consistent matching + parse_url(w) + for w in trusted_hostnames + ) + ) + + if is_trusted: + return parsed, True, [] + + # If not allowlisted, go ahead with host resolution and IP target check + return parsed, False, await _resolve_and_check_blocked(ascii_hostname) + + +def matches_allowed_host(url: URL, allowed: URL) -> bool: + if url.hostname != allowed.hostname: + return False + + # Allow any port if not explicitly specified in the allowlist + if allowed.port is None: + return True + + return url.port == allowed.port + + +async def _resolve_and_check_blocked(hostname: str) -> list[str]: + """ + Resolves hostname to IPs and raises ValueError if any resolve to + a blocked network. Returns the list of resolved IP addresses. + """ + ip_addresses = await _resolve_host(hostname) + for ip_str in ip_addresses: + if _is_ip_blocked(ip_str): + raise ValueError( + f"Access to blocked or private IP address {ip_str} " + f"for hostname {hostname} is not allowed." + ) + return ip_addresses + def parse_url(url: str) -> URL: """Canonicalizes and parses a URL string.""" @@ -352,7 +382,7 @@ class Requests: ): self.trusted_origins = [] for url in trusted_origins or []: - hostname = urlparse(url).hostname + hostname = parse_url(url).netloc # {host}[:{port}] if not hostname: raise ValueError(f"Invalid URL: Unable to determine hostname of {url}") self.trusted_origins.append(hostname) @@ -450,7 +480,7 @@ class Requests: data = form # Validate URL and get trust status - parsed_url, is_trusted, ip_addresses = await validate_url( + parsed_url, is_trusted, ip_addresses = await validate_url_host( url, self.trusted_origins ) @@ -503,7 +533,6 @@ class Requests: json=json, **kwargs, ) as response: - if self.raise_for_status: try: response.raise_for_status() diff --git a/autogpt_platform/backend/backend/util/request_test.py b/autogpt_platform/backend/backend/util/request_test.py index eaabbf4e09..aa4f5649a0 100644 --- a/autogpt_platform/backend/backend/util/request_test.py +++ b/autogpt_platform/backend/backend/util/request_test.py @@ -1,7 +1,7 @@ import pytest from aiohttp import web -from backend.util.request import pin_url, validate_url +from backend.util.request import pin_url, validate_url_host @pytest.mark.parametrize( @@ -60,9 +60,9 @@ async def test_validate_url_no_dns_rebinding( ): if should_raise: with pytest.raises(ValueError): - await validate_url(raw_url, trusted_origins) + await validate_url_host(raw_url, trusted_origins) else: - validated_url, _, _ = await validate_url(raw_url, trusted_origins) + validated_url, _, _ = await validate_url_host(raw_url, trusted_origins) assert validated_url.geturl() == expected_value @@ -101,10 +101,10 @@ async def test_dns_rebinding_fix( if expect_error: # If any IP is blocked, we expect a ValueError with pytest.raises(ValueError): - url, _, ip_addresses = await validate_url(hostname, []) + url, _, ip_addresses = await validate_url_host(hostname) pin_url(url, ip_addresses) else: - url, _, ip_addresses = await validate_url(hostname, []) + url, _, ip_addresses = await validate_url_host(hostname) pinned_url = pin_url(url, ip_addresses).geturl() # The pinned_url should contain the first valid IP assert pinned_url.startswith("http://") or pinned_url.startswith("https://") diff --git a/autogpt_platform/backend/backend/util/settings.py b/autogpt_platform/backend/backend/util/settings.py index 39ec1a0bc8..8f12679d14 100644 --- a/autogpt_platform/backend/backend/util/settings.py +++ b/autogpt_platform/backend/backend/util/settings.py @@ -89,6 +89,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings): le=500, description="Thread pool size for FastAPI sync operations. All sync endpoints and dependencies automatically use this pool. Higher values support more concurrent sync operations but use more memory.", ) + ollama_host: str = Field( + default="localhost:11434", + description="Default Ollama host; exempted from SSRF checks.", + ) pyro_host: str = Field( default="localhost", description="The default hostname of the Pyro server.",