[URL Validation] Add URL validation to the HTTP and search blocks to prevent SSRF attacks

This commit is contained in:
jackfromeast
2024-11-03 15:03:58 -05:00
parent 22b4e8b8c0
commit 51f9336e9e
4 changed files with 75 additions and 10 deletions

View File

@@ -1,11 +1,13 @@
import json
from enum import Enum
import ipaddress
import requests
import socket
from urllib.parse import urlparse
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.settings import Config
class HttpMethod(Enum):
GET = "GET"
@@ -16,6 +18,31 @@ class HttpMethod(Enum):
OPTIONS = "OPTIONS"
HEAD = "HEAD"
def _is_trusted_url(parsed_url, trusted_endpoints) -> bool:
    """
    Return True when ``parsed_url`` falls under one of ``trusted_endpoints``.

    Matching is done on parsed URL components rather than on a raw string
    prefix, so that a trusted entry such as ``http://internal`` does not also
    trust ``http://internal.evil.com`` or ``http://internal@evil.com``.
    An entry matches when the scheme and hostname are equal, the port is equal
    (only if the entry specifies one), and the URL path starts with the
    entry's path.
    """
    try:
        url_port = parsed_url.port
    except ValueError:
        # A malformed port can never match a trusted endpoint.
        return False
    url_path = parsed_url.path or "/"

    for endpoint in trusted_endpoints:
        # An empty entry would otherwise act as a wildcard for every URL.
        if not endpoint:
            continue
        try:
            trusted = urlparse(endpoint)
            trusted_port = trusted.port
        except ValueError:
            # Unparseable configuration entries are ignored, not trusted.
            continue
        if not trusted.scheme or not trusted.hostname:
            continue
        if trusted.scheme != parsed_url.scheme:
            continue
        if trusted.hostname != parsed_url.hostname:
            continue
        if trusted_port is not None and trusted_port != url_port:
            continue
        if url_path.startswith(trusted.path or "/"):
            return True
    return False


def validate_url(url: str) -> str:
    """
    To avoid SSRF attacks, the URL should not resolve to a non-global
    (private, loopback, link-local, ...) IP address unless it is whitelisted
    in the TRUST_ENDPOINTS_FOR_REQUESTS config.

    Args:
        url: The absolute URL that is about to be requested.

    Returns:
        The unchanged ``url`` when it is trusted or resolves only to global
        addresses.

    Raises:
        ValueError: If no hostname can be extracted, the hostname cannot be
            resolved, or any resolved address is not global.
    """
    parsed_url = urlparse(url)
    hostname = parsed_url.hostname
    if not hostname:
        raise ValueError(f"Invalid URL: Unable to determine hostname from {url}")

    if _is_trusted_url(parsed_url, Config().trust_endpoints_for_requests):
        return url

    try:
        # getaddrinfo reports IPv4 and IPv6 results, so a host cannot hide a
        # private AAAA record behind a public A record.
        addr_info = socket.getaddrinfo(hostname, None, proto=socket.IPPROTO_TCP)
    except ValueError:
        raise
    except Exception as e:
        # Keep the contract that only ValueError escapes this validator.
        raise ValueError(f"Invalid or unresolvable URL: {url}") from e

    addresses = {info[4][0] for info in addr_info}
    # Every resolved address must be global; a single private address is
    # enough for the request to be routed internally.
    if addresses and all(ipaddress.ip_address(ip).is_global for ip in addresses):
        return url

    # NOTE(review): this checks the addresses seen at validation time only.
    # A caller that resolves the hostname again when connecting may still be
    # exposed to DNS rebinding - confirm whether that needs handling.
    raise ValueError(
        f"Access to private or untrusted IP address at {hostname} is not allowed."
    )
class SendWebRequestBlock(Block):
class Input(BlockSchema):
@@ -54,11 +81,14 @@ class SendWebRequestBlock(Block):
if isinstance(input_data.body, str):
input_data.body = json.loads(input_data.body)
validated_url = validate_url(input_data.url)
response = requests.request(
input_data.method.value,
input_data.url,
validated_url,
headers=input_data.headers,
json=input_data.body,
json=input_data.body
allow_redirects=False
)
if response.status_code // 100 == 2:
yield "response", response.json()

View File

@@ -1,20 +1,50 @@
from typing import Any, Literal
from urllib.parse import quote
from urllib.parse import quote, urlparse
import ipaddress
import requests
import socket
from autogpt_libs.supabase_integration_credentials_store.types import APIKeyCredentials
from pydantic import SecretStr
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import CredentialsField, CredentialsMetaInput, SchemaField
from backend.util.settings import Config
class GetRequest:
    """
    Helper base class providing an HTTP GET that validates its target URL
    first, so that blocks inheriting from it cannot be used for SSRF.
    """

    @classmethod
    def get_request(cls, url: str, json=False) -> Any:
        """
        Validate ``url`` and fetch it with an HTTP GET.

        Args:
            url: The absolute URL to fetch.
            json: When true, return the decoded JSON body instead of the text.

        Returns:
            ``response.json()`` when ``json`` is true, otherwise
            ``response.text``.

        Raises:
            ValueError: If ``url`` is rejected by ``validate_url``.
            requests.HTTPError: If the response status is 4xx or 5xx.
        """
        # Call the classmethod on the class itself; instantiating ``cls``
        # would construct a full subclass instance just to validate a URL.
        validated_url = cls.validate_url(url)
        # Redirects are not followed, because the redirect target would not
        # have gone through validation.
        response = requests.get(validated_url, allow_redirects=False)
        response.raise_for_status()
        return response.json() if json else response.text

    @staticmethod
    def _is_trusted_url(parsed_url, trusted_endpoints) -> bool:
        """
        Return True when ``parsed_url`` falls under one of
        ``trusted_endpoints``.

        Matching is done on parsed URL components rather than on a raw string
        prefix, so that a trusted entry such as ``http://internal`` does not
        also trust ``http://internal.evil.com`` or ``http://internal@evil.com``.
        An entry matches when the scheme and hostname are equal, the port is
        equal (only if the entry specifies one), and the URL path starts with
        the entry's path.
        """
        try:
            url_port = parsed_url.port
        except ValueError:
            # A malformed port can never match a trusted endpoint.
            return False
        url_path = parsed_url.path or "/"

        for endpoint in trusted_endpoints:
            # An empty entry would otherwise act as a wildcard for every URL.
            if not endpoint:
                continue
            try:
                trusted = urlparse(endpoint)
                trusted_port = trusted.port
            except ValueError:
                # Unparseable configuration entries are ignored, not trusted.
                continue
            if not trusted.scheme or not trusted.hostname:
                continue
            if trusted.scheme != parsed_url.scheme:
                continue
            if trusted.hostname != parsed_url.hostname:
                continue
            if trusted_port is not None and trusted_port != url_port:
                continue
            if url_path.startswith(trusted.path or "/"):
                return True
        return False

    @classmethod
    def validate_url(cls, url: str) -> str:
        """
        To avoid SSRF attacks, the URL should not resolve to a non-global
        (private, loopback, link-local, ...) IP address unless it is
        whitelisted in the TRUST_ENDPOINTS_FOR_REQUESTS config.

        Args:
            url: The absolute URL that is about to be requested.

        Returns:
            The unchanged ``url`` when it is trusted or resolves only to
            global addresses.

        Raises:
            ValueError: If no hostname can be extracted, the hostname cannot
                be resolved, or any resolved address is not global.
        """
        parsed_url = urlparse(url)
        hostname = parsed_url.hostname
        if not hostname:
            raise ValueError(f"Invalid URL: Unable to determine hostname from {url}")

        if cls._is_trusted_url(parsed_url, Config().trust_endpoints_for_requests):
            return url

        try:
            # getaddrinfo reports IPv4 and IPv6 results, so a host cannot
            # hide a private AAAA record behind a public A record.
            addr_info = socket.getaddrinfo(hostname, None, proto=socket.IPPROTO_TCP)
        except ValueError:
            raise
        except Exception as e:
            # Keep the contract that only ValueError escapes this validator.
            raise ValueError(f"Invalid or unresolvable URL: {url}") from e

        addresses = {info[4][0] for info in addr_info}
        # Every resolved address must be global; a single private address is
        # enough for the request to be routed internally.
        if addresses and all(ipaddress.ip_address(ip).is_global for ip in addresses):
            return url

        # NOTE(review): this checks the addresses seen at validation time
        # only. The subsequent request resolves the hostname again, so DNS
        # rebinding may still be possible - confirm whether that needs handling.
        raise ValueError(
            f"Access to private or untrusted IP address at {hostname} is not allowed."
        )
class GetWikipediaSummaryBlock(Block, GetRequest):

View File

@@ -153,8 +153,12 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="Name of the event bus",
)
backend_cors_allow_origins: List[str] = Field(default_factory=list)
trust_endpoints_for_requests: List[str] = Field(
default_factory=list,
description="A whitelist of trusted internal endpoints for the backend to make requests to.",
)
backend_cors_allow_origins: List[str] = Field(default_factory=list)
@field_validator("backend_cors_allow_origins")
@classmethod
def validate_cors_allow_origins(cls, v: List[str]) -> List[str]:

View File

@@ -1,5 +1,6 @@
{
"num_graph_workers": 10,
"num_node_workers": 5,
"num_user_credits_refill": 1500
"num_user_credits_refill": 1500,
"trust_endpoints_for_requests": [""]
}