mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
[URL Validation] Add the URL validation on http and search blocks to avoid SSRF attack
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
|
||||
import ipaddress
|
||||
import requests
|
||||
import socket
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from backend.util.settings import Config
|
||||
|
||||
class HttpMethod(Enum):
|
||||
GET = "GET"
|
||||
@@ -16,6 +18,31 @@ class HttpMethod(Enum):
|
||||
OPTIONS = "OPTIONS"
|
||||
HEAD = "HEAD"
|
||||
|
||||
def validate_url(url: str) -> str:
    """
    Validate a URL before the backend requests it, to avoid SSRF attacks.

    The URL is accepted when it belongs to an origin whitelisted in the
    TRUST_ENDPOINTS_FOR_REQUESTS config, or when every IPv4 address that its
    hostname resolves to is globally routable.

    A whitelist entry matches only on a URL boundary: the URL must equal the
    entry, or continue with "/", "?" or "#" (entries that already end with
    "/" match any continuation). Entries must therefore spell out the port
    when one is used. Empty entries are ignored.

    Args:
        url: The URL about to be requested.

    Returns:
        The unchanged URL, once it is considered safe to request.

    Raises:
        ValueError: If the hostname cannot be determined or resolved, or if
            any resolved address is private or otherwise not global.
    """
    for origin in Config().trust_endpoints_for_requests:
        # An empty entry is a prefix of every URL and would disable the check.
        if not origin or not url.startswith(origin):
            continue
        remainder = url[len(origin):]
        # A bare prefix match would let "http://trusted" also whitelist
        # "http://trusted.evil.com" or "http://trusted@evil.com".
        if not remainder or origin.endswith("/") or remainder[0] in "/?#":
            return url

    hostname = urlparse(url).hostname
    if not hostname:
        raise ValueError(f"Invalid URL: Unable to determine hostname from {url}")

    try:
        # gethostbyname_ex returns (canonical_name, aliases, ipv4_addresses).
        _, _, addresses = socket.gethostbyname_ex(hostname)
    except Exception as e:
        raise ValueError(f"Invalid or unresolvable URL: {url}") from e

    # Every resolved address must be global: a single private address among
    # public ones is enough for the request to reach an internal service.
    if not addresses or not all(
        ipaddress.ip_address(ip).is_global for ip in addresses
    ):
        raise ValueError(
            f"Access to private or untrusted IP address at {hostname} is not allowed."
        )
    return url
|
||||
|
||||
class SendWebRequestBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
@@ -54,11 +81,14 @@ class SendWebRequestBlock(Block):
|
||||
if isinstance(input_data.body, str):
|
||||
input_data.body = json.loads(input_data.body)
|
||||
|
||||
validated_url = validate_url(input_data.url)
|
||||
|
||||
response = requests.request(
|
||||
input_data.method.value,
|
||||
input_data.url,
|
||||
validated_url,
|
||||
headers=input_data.headers,
|
||||
json=input_data.body,
|
||||
json=input_data.body
|
||||
allow_redirects=False
|
||||
)
|
||||
if response.status_code // 100 == 2:
|
||||
yield "response", response.json()
|
||||
|
||||
@@ -1,20 +1,50 @@
|
||||
from typing import Any, Literal
|
||||
from urllib.parse import quote
|
||||
|
||||
from urllib.parse import quote, urlparse
|
||||
import ipaddress
|
||||
import requests
|
||||
import socket
|
||||
from autogpt_libs.supabase_integration_credentials_store.types import APIKeyCredentials
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import CredentialsField, CredentialsMetaInput, SchemaField
|
||||
|
||||
from backend.util.settings import Config
|
||||
|
||||
class GetRequest:
    """Mixin that gives blocks an HTTP GET helper guarded against SSRF."""

    @classmethod
    def get_request(cls, url: str, json=False) -> Any:
        """
        Fetch a URL with GET after validating it.

        Args:
            url: The URL to fetch.
            json: When true, return the decoded JSON body instead of the text.

        Returns:
            The response body, as parsed JSON or as text.

        Raises:
            ValueError: If the URL fails validation (see validate_url).
            requests.HTTPError: If the response status is an error status.
        """
        # Call the classmethod directly; building an instance just to
        # validate a URL is unnecessary work for subclasses.
        validated_url = cls.validate_url(url)

        # Redirects are not followed, because the redirect target would
        # bypass the validation above.
        response = requests.get(validated_url, allow_redirects=False)
        response.raise_for_status()
        return response.json() if json else response.text

    @classmethod
    def validate_url(cls, url: str) -> str:
        """
        Validate a URL before the backend requests it, to avoid SSRF attacks.

        The URL is accepted when it belongs to an origin whitelisted in the
        TRUST_ENDPOINTS_FOR_REQUESTS config, or when every IPv4 address that
        its hostname resolves to is globally routable.

        A whitelist entry matches only on a URL boundary: the URL must equal
        the entry, or continue with "/", "?" or "#" (entries that already end
        with "/" match any continuation). Entries must therefore spell out
        the port when one is used. Empty entries are ignored.

        Args:
            url: The URL about to be requested.

        Returns:
            The unchanged URL, once it is considered safe to request.

        Raises:
            ValueError: If the hostname cannot be determined or resolved, or
                if any resolved address is private or otherwise not global.
        """
        for origin in Config().trust_endpoints_for_requests:
            # An empty entry is a prefix of every URL and would disable the
            # check entirely.
            if not origin or not url.startswith(origin):
                continue
            remainder = url[len(origin):]
            # A bare prefix match would let "http://trusted" also whitelist
            # "http://trusted.evil.com" or "http://trusted@evil.com".
            if not remainder or origin.endswith("/") or remainder[0] in "/?#":
                return url

        hostname = urlparse(url).hostname
        if not hostname:
            raise ValueError(f"Invalid URL: Unable to determine hostname from {url}")

        try:
            # gethostbyname_ex returns (canonical_name, aliases, ipv4_addresses).
            _, _, addresses = socket.gethostbyname_ex(hostname)
        except Exception as e:
            raise ValueError(f"Invalid or unresolvable URL: {url}") from e

        # Every resolved address must be global: a single private address
        # among public ones is enough to reach an internal service.
        if not addresses or not all(
            ipaddress.ip_address(ip).is_global for ip in addresses
        ):
            raise ValueError(
                f"Access to private or untrusted IP address at {hostname} is not allowed."
            )
        return url
|
||||
|
||||
|
||||
class GetWikipediaSummaryBlock(Block, GetRequest):
|
||||
|
||||
@@ -153,8 +153,12 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="Name of the event bus",
|
||||
)
|
||||
|
||||
backend_cors_allow_origins: List[str] = Field(default_factory=list)
|
||||
trust_endpoints_for_requests: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="A whitelist of trusted internal endpoints for the backend to make requests to.",
|
||||
)
|
||||
|
||||
backend_cors_allow_origins: List[str] = Field(default_factory=list)
|
||||
@field_validator("backend_cors_allow_origins")
|
||||
@classmethod
|
||||
def validate_cors_allow_origins(cls, v: List[str]) -> List[str]:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
{
  "num_graph_workers": 10,
  "num_node_workers": 5,
  "num_user_credits_refill": 1500,
  "trust_endpoints_for_requests": []
}
|
||||
|
||||
Reference in New Issue
Block a user