mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Apply secure request to all blocks
This commit is contained in:
@@ -3,12 +3,12 @@ import time
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
import requests
|
||||
from autogpt_libs.supabase_integration_credentials_store.types import APIKeyCredentials
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import CredentialsField, CredentialsMetaInput, SchemaField
|
||||
from backend.util.request import requests
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
|
||||
@@ -1,68 +1,50 @@
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
|
||||
from ._auth import GithubCredentials
|
||||
from backend.blocks.github._auth import GithubCredentials
|
||||
from backend.util.request import Requests
|
||||
|
||||
|
||||
class GitHubAPI:
|
||||
def __init__(self, credentials: GithubCredentials):
|
||||
self.credentials = credentials
|
||||
def _validate_github_url(url: str) -> None:
|
||||
parsed_url = urlparse(url)
|
||||
if parsed_url.netloc != "github.com":
|
||||
raise ValueError("The input URL must be a valid GitHub URL.")
|
||||
|
||||
@staticmethod
|
||||
def _validate_github_url(url: str) -> None:
|
||||
parsed_url = urlparse(url)
|
||||
if parsed_url.netloc != "github.com":
|
||||
raise ValueError("The input URL must be a valid GitHub URL.")
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_api_url(url: str) -> str:
|
||||
"""
|
||||
Converts a standard GitHub URL to the corresponding GitHub API URL.
|
||||
Handles repository URLs, issue URLs, pull request URLs, and more.
|
||||
"""
|
||||
GitHubAPI._validate_github_url(url)
|
||||
parsed_url = urlparse(url)
|
||||
path_parts = parsed_url.path.strip("/").split("/")
|
||||
def _convert_to_api_url(url: str) -> str:
|
||||
"""
|
||||
Converts a standard GitHub URL to the corresponding GitHub API URL.
|
||||
Handles repository URLs, issue URLs, pull request URLs, and more.
|
||||
"""
|
||||
_validate_github_url(url)
|
||||
parsed_url = urlparse(url)
|
||||
path_parts = parsed_url.path.strip("/").split("/")
|
||||
|
||||
if len(path_parts) >= 2:
|
||||
owner, repo = path_parts[0], path_parts[1]
|
||||
api_base = f"https://api.github.com/repos/{owner}/{repo}"
|
||||
if len(path_parts) >= 2:
|
||||
owner, repo = path_parts[0], path_parts[1]
|
||||
api_base = f"https://api.github.com/repos/{owner}/{repo}"
|
||||
|
||||
if len(path_parts) > 2:
|
||||
additional_path = "/".join(path_parts[2:])
|
||||
api_url = f"{api_base}/{additional_path}"
|
||||
else:
|
||||
# Repository base URL
|
||||
api_url = api_base
|
||||
if len(path_parts) > 2:
|
||||
additional_path = "/".join(path_parts[2:])
|
||||
api_url = f"{api_base}/{additional_path}"
|
||||
else:
|
||||
raise ValueError("Invalid GitHub URL format.")
|
||||
# Repository base URL
|
||||
api_url = api_base
|
||||
else:
|
||||
raise ValueError("Invalid GitHub URL format.")
|
||||
|
||||
return api_url
|
||||
return api_url
|
||||
|
||||
def _get_headers(self) -> dict:
|
||||
return {
|
||||
"Authorization": self.credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
|
||||
def get(self, url: str, **kwargs) -> requests.Response:
|
||||
api_url = self._convert_to_api_url(url)
|
||||
headers = self._get_headers()
|
||||
response = requests.get(api_url, headers=headers, **kwargs)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
def _get_headers(credentials: GithubCredentials) -> dict[str, str]:
|
||||
return {
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
|
||||
def post(self, url: str, json=None, **kwargs) -> requests.Response:
|
||||
api_url = self._convert_to_api_url(url)
|
||||
headers = self._get_headers()
|
||||
response = requests.post(api_url, headers=headers, json=json, **kwargs)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
def delete(self, url: str, json=None, **kwargs) -> requests.Response:
|
||||
api_url = self._convert_to_api_url(url)
|
||||
headers = self._get_headers()
|
||||
response = requests.delete(api_url, headers=headers, json=json, **kwargs)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
def get_api(credentials: GithubCredentials) -> Requests:
|
||||
return Requests(
|
||||
trusted_origins=["https://api.github.com", "https://github.com"],
|
||||
extra_url_validator=_convert_to_api_url,
|
||||
extra_headers=_get_headers(credentials),
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing_extensions import TypedDict
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import GitHubAPI
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
@@ -68,7 +68,7 @@ class GithubCommentBlock(Block):
|
||||
def post_comment(
|
||||
credentials: GithubCredentials, issue_url: str, body_text: str
|
||||
) -> tuple[int, str]:
|
||||
api = GitHubAPI(credentials)
|
||||
api = get_api(credentials)
|
||||
data = {"body": body_text}
|
||||
comments_url = issue_url + "/comments"
|
||||
response = api.post(comments_url, json=data)
|
||||
@@ -145,7 +145,7 @@ class GithubMakeIssueBlock(Block):
|
||||
def create_issue(
|
||||
credentials: GithubCredentials, repo_url: str, title: str, body: str
|
||||
) -> tuple[int, str]:
|
||||
api = GitHubAPI(credentials)
|
||||
api = get_api(credentials)
|
||||
data = {"title": title, "body": body}
|
||||
issues_url = repo_url + "/issues"
|
||||
response = api.post(issues_url, json=data)
|
||||
@@ -215,7 +215,7 @@ class GithubReadIssueBlock(Block):
|
||||
def read_issue(
|
||||
credentials: GithubCredentials, issue_url: str
|
||||
) -> tuple[str, str, str]:
|
||||
api = GitHubAPI(credentials)
|
||||
api = get_api(credentials)
|
||||
response = api.get(issue_url)
|
||||
data = response.json()
|
||||
title = data.get("title", "No title found")
|
||||
@@ -292,7 +292,7 @@ class GithubListIssuesBlock(Block):
|
||||
def list_issues(
|
||||
credentials: GithubCredentials, repo_url: str
|
||||
) -> list[Output.IssueItem]:
|
||||
api = GitHubAPI(credentials)
|
||||
api = get_api(credentials)
|
||||
issues_url = repo_url + "/issues"
|
||||
response = api.get(issues_url)
|
||||
data = response.json()
|
||||
@@ -352,7 +352,7 @@ class GithubAddLabelBlock(Block):
|
||||
|
||||
@staticmethod
|
||||
def add_label(credentials: GithubCredentials, issue_url: str, label: str) -> str:
|
||||
api = GitHubAPI(credentials)
|
||||
api = get_api(credentials)
|
||||
data = {"labels": [label]}
|
||||
labels_url = issue_url + "/labels"
|
||||
response = api.post(labels_url, json=data)
|
||||
@@ -413,7 +413,7 @@ class GithubRemoveLabelBlock(Block):
|
||||
|
||||
@staticmethod
|
||||
def remove_label(credentials: GithubCredentials, issue_url: str, label: str) -> str:
|
||||
api = GitHubAPI(credentials)
|
||||
api = get_api(credentials)
|
||||
label_url = issue_url + f"/labels/{label}"
|
||||
response = api.delete(label_url)
|
||||
response.raise_for_status()
|
||||
@@ -479,7 +479,7 @@ class GithubAssignIssueBlock(Block):
|
||||
issue_url: str,
|
||||
assignee: str,
|
||||
) -> str:
|
||||
api = GitHubAPI(credentials)
|
||||
api = get_api(credentials)
|
||||
assignees_url = issue_url + "/assignees"
|
||||
data = {"assignees": [assignee]}
|
||||
response = api.post(assignees_url, json=data)
|
||||
@@ -546,7 +546,7 @@ class GithubUnassignIssueBlock(Block):
|
||||
issue_url: str,
|
||||
assignee: str,
|
||||
) -> str:
|
||||
api = GitHubAPI(credentials)
|
||||
api = get_api(credentials)
|
||||
assignees_url = issue_url + "/assignees"
|
||||
data = {"assignees": [assignee]}
|
||||
response = api.delete(assignees_url, json=data)
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing_extensions import TypedDict
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import GitHubAPI
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
@@ -64,7 +64,7 @@ class GithubListPullRequestsBlock(Block):
|
||||
|
||||
@staticmethod
|
||||
def list_prs(credentials: GithubCredentials, repo_url: str) -> list[Output.PRItem]:
|
||||
api = GitHubAPI(credentials)
|
||||
api = get_api(credentials)
|
||||
pulls_url = repo_url + "/pulls"
|
||||
response = api.get(pulls_url)
|
||||
data = response.json()
|
||||
@@ -159,7 +159,7 @@ class GithubMakePullRequestBlock(Block):
|
||||
head: str,
|
||||
base: str,
|
||||
) -> tuple[int, str]:
|
||||
api = GitHubAPI(credentials)
|
||||
api = get_api(credentials)
|
||||
pulls_url = repo_url + "/pulls"
|
||||
data = {"title": title, "body": body, "head": head, "base": base}
|
||||
response = api.post(pulls_url, json=data)
|
||||
@@ -240,7 +240,7 @@ class GithubReadPullRequestBlock(Block):
|
||||
|
||||
@staticmethod
|
||||
def read_pr(credentials: GithubCredentials, pr_url: str) -> tuple[str, str, str]:
|
||||
api = GitHubAPI(credentials)
|
||||
api = get_api(credentials)
|
||||
# Adjust the URL to access the issue endpoint for PR metadata
|
||||
issue_url = pr_url.replace("/pull/", "/issues/")
|
||||
response = api.get(issue_url)
|
||||
@@ -252,7 +252,7 @@ class GithubReadPullRequestBlock(Block):
|
||||
|
||||
@staticmethod
|
||||
def read_pr_changes(credentials: GithubCredentials, pr_url: str) -> str:
|
||||
api = GitHubAPI(credentials)
|
||||
api = get_api(credentials)
|
||||
files_url = pr_url + "/files"
|
||||
response = api.get(files_url)
|
||||
files = response.json()
|
||||
@@ -330,7 +330,7 @@ class GithubAssignPRReviewerBlock(Block):
|
||||
def assign_reviewer(
|
||||
credentials: GithubCredentials, pr_url: str, reviewer: str
|
||||
) -> str:
|
||||
api = GitHubAPI(credentials)
|
||||
api = get_api(credentials)
|
||||
reviewers_url = pr_url + "/requested_reviewers"
|
||||
data = {"reviewers": [reviewer]}
|
||||
api.post(reviewers_url, json=data)
|
||||
@@ -397,7 +397,7 @@ class GithubUnassignPRReviewerBlock(Block):
|
||||
def unassign_reviewer(
|
||||
credentials: GithubCredentials, pr_url: str, reviewer: str
|
||||
) -> str:
|
||||
api = GitHubAPI(credentials)
|
||||
api = get_api(credentials)
|
||||
reviewers_url = pr_url + "/requested_reviewers"
|
||||
data = {"reviewers": [reviewer]}
|
||||
api.delete(reviewers_url, json=data)
|
||||
@@ -477,7 +477,7 @@ class GithubListPRReviewersBlock(Block):
|
||||
def list_reviewers(
|
||||
credentials: GithubCredentials, pr_url: str
|
||||
) -> list[Output.ReviewerItem]:
|
||||
api = GitHubAPI(credentials)
|
||||
api = get_api(credentials)
|
||||
reviewers_url = pr_url + "/requested_reviewers"
|
||||
response = api.get(reviewers_url)
|
||||
data = response.json()
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing_extensions import TypedDict
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import GitHubAPI
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
@@ -68,7 +68,7 @@ class GithubListTagsBlock(Block):
|
||||
def list_tags(
|
||||
credentials: GithubCredentials, repo_url: str
|
||||
) -> list[Output.TagItem]:
|
||||
api = GitHubAPI(credentials)
|
||||
api = get_api(credentials)
|
||||
tags_url = repo_url + "/tags"
|
||||
response = api.get(tags_url)
|
||||
data = response.json()
|
||||
@@ -150,7 +150,7 @@ class GithubListBranchesBlock(Block):
|
||||
def list_branches(
|
||||
credentials: GithubCredentials, repo_url: str
|
||||
) -> list[Output.BranchItem]:
|
||||
api = GitHubAPI(credentials)
|
||||
api = get_api(credentials)
|
||||
branches_url = repo_url + "/branches"
|
||||
response = api.get(branches_url)
|
||||
data = response.json()
|
||||
@@ -237,7 +237,7 @@ class GithubListDiscussionsBlock(Block):
|
||||
def list_discussions(
|
||||
credentials: GithubCredentials, repo_url: str, num_discussions: int
|
||||
) -> list[Output.DiscussionItem]:
|
||||
api = GitHubAPI(credentials)
|
||||
api = get_api(credentials)
|
||||
# GitHub GraphQL API endpoint is different; we'll use api.post with custom URL
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
owner, repo = repo_path.split("/")
|
||||
@@ -332,7 +332,7 @@ class GithubListReleasesBlock(Block):
|
||||
def list_releases(
|
||||
credentials: GithubCredentials, repo_url: str
|
||||
) -> list[Output.ReleaseItem]:
|
||||
api = GitHubAPI(credentials)
|
||||
api = get_api(credentials)
|
||||
releases_url = repo_url + "/releases"
|
||||
response = api.get(releases_url)
|
||||
data = response.json()
|
||||
@@ -408,7 +408,7 @@ class GithubReadFileBlock(Block):
|
||||
def read_file(
|
||||
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
|
||||
) -> tuple[str, int]:
|
||||
api = GitHubAPI(credentials)
|
||||
api = get_api(credentials)
|
||||
content_url = repo_url + f"/contents/{file_path}?ref={branch}"
|
||||
response = api.get(content_url)
|
||||
content = response.json()
|
||||
@@ -518,7 +518,7 @@ class GithubReadFolderBlock(Block):
|
||||
def read_folder(
|
||||
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
|
||||
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
|
||||
api = GitHubAPI(credentials)
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{folder_path}?ref={branch}"
|
||||
response = api.get(contents_url)
|
||||
content = response.json()
|
||||
@@ -612,7 +612,7 @@ class GithubMakeBranchBlock(Block):
|
||||
new_branch: str,
|
||||
source_branch: str,
|
||||
) -> str:
|
||||
api = GitHubAPI(credentials)
|
||||
api = get_api(credentials)
|
||||
# Get the SHA of the source branch
|
||||
ref_url = repo_url + f"/git/refs/heads/{source_branch}"
|
||||
response = api.get(ref_url)
|
||||
@@ -681,7 +681,7 @@ class GithubDeleteBranchBlock(Block):
|
||||
def delete_branch(
|
||||
credentials: GithubCredentials, repo_url: str, branch: str
|
||||
) -> str:
|
||||
api = GitHubAPI(credentials)
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{branch}"
|
||||
api.delete(ref_url)
|
||||
return "Branch deleted successfully"
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
import ipaddress
|
||||
import json
|
||||
import socket
|
||||
from enum import Enum
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.settings import Config
|
||||
from backend.util.request import requests
|
||||
|
||||
|
||||
class HttpMethod(Enum):
|
||||
@@ -21,35 +16,6 @@ class HttpMethod(Enum):
|
||||
HEAD = "HEAD"
|
||||
|
||||
|
||||
def validate_url(url: str) -> str:
|
||||
"""
|
||||
To avoid SSRF attacks, the URL should not be a private IP address
|
||||
unless it is whitelisted in TRUST_ENDPOINTS_FOR_REQUESTS config.
|
||||
"""
|
||||
if any(url.startswith(origin) for origin in Config().trust_endpoints_for_requests):
|
||||
return url
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
hostname = parsed_url.hostname
|
||||
|
||||
if not hostname:
|
||||
raise ValueError(f"Invalid URL: Unable to determine hostname from {url}")
|
||||
|
||||
try:
|
||||
host = socket.gethostbyname_ex(hostname)
|
||||
for ip in host[2]:
|
||||
ip_addr = ipaddress.ip_address(ip)
|
||||
if ip_addr.is_global:
|
||||
return url
|
||||
raise ValueError(
|
||||
f"Access to private or untrusted IP address at {hostname} is not allowed."
|
||||
)
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid or unresolvable URL: {url}") from e
|
||||
|
||||
|
||||
class SendWebRequestBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
url: str = SchemaField(
|
||||
@@ -87,11 +53,9 @@ class SendWebRequestBlock(Block):
|
||||
if isinstance(input_data.body, str):
|
||||
input_data.body = json.loads(input_data.body)
|
||||
|
||||
validated_url = validate_url(input_data.url)
|
||||
|
||||
response = requests.request(
|
||||
input_data.method.value,
|
||||
validated_url,
|
||||
input_data.url,
|
||||
headers=input_data.headers,
|
||||
json=input_data.body,
|
||||
allow_redirects=False,
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
import requests
|
||||
from autogpt_libs.supabase_integration_credentials_store.types import APIKeyCredentials
|
||||
from pydantic import SecretStr
|
||||
from requests.exceptions import RequestException
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import CredentialsField, CredentialsMetaInput, SchemaField
|
||||
from backend.util.request import requests
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
@@ -244,7 +245,7 @@ class IdeogramModelBlock(Block):
|
||||
response = requests.post(url, json=data, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()["data"][0]["url"]
|
||||
except requests.exceptions.RequestException as e:
|
||||
except RequestException as e:
|
||||
raise Exception(f"Failed to fetch image: {str(e)}")
|
||||
|
||||
def upscale_image(self, api_key: SecretStr, image_url: str):
|
||||
@@ -275,5 +276,5 @@ class IdeogramModelBlock(Block):
|
||||
response.raise_for_status()
|
||||
return response.json()["data"][0]["url"]
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
except RequestException as e:
|
||||
raise Exception(f"Failed to upscale image: {str(e)}")
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import requests
|
||||
|
||||
from backend.blocks.jina._auth import (
|
||||
JinaCredentials,
|
||||
JinaCredentialsField,
|
||||
@@ -7,6 +5,7 @@ from backend.blocks.jina._auth import (
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.request import requests
|
||||
|
||||
|
||||
class JinaChunkingBlock(Block):
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import requests
|
||||
|
||||
from backend.blocks.jina._auth import (
|
||||
JinaCredentials,
|
||||
JinaCredentialsField,
|
||||
@@ -7,6 +5,7 @@ from backend.blocks.jina._auth import (
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.request import requests
|
||||
|
||||
|
||||
class JinaEmbeddingBlock(Block):
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from enum import Enum
|
||||
from typing import List, Literal
|
||||
|
||||
import requests
|
||||
from autogpt_libs.supabase_integration_credentials_store.types import APIKeyCredentials
|
||||
from pydantic import SecretStr
|
||||
|
||||
@@ -13,6 +12,7 @@ from backend.data.model import (
|
||||
SchemaField,
|
||||
SecretField,
|
||||
)
|
||||
from backend.util.request import requests
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
|
||||
@@ -3,12 +3,12 @@ import socket
|
||||
from typing import Any, Literal
|
||||
from urllib.parse import quote, urlparse
|
||||
|
||||
import requests
|
||||
from autogpt_libs.supabase_integration_credentials_store.types import APIKeyCredentials
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import CredentialsField, CredentialsMetaInput, SchemaField
|
||||
from backend.util.request import requests
|
||||
from backend.util.settings import Config
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import time
|
||||
from typing import Literal
|
||||
|
||||
import requests
|
||||
from autogpt_libs.supabase_integration_credentials_store.types import APIKeyCredentials
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import CredentialsField, CredentialsMetaInput, SchemaField
|
||||
from backend.util.request import requests
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
import requests
|
||||
from autogpt_libs.supabase_integration_credentials_store.types import APIKeyCredentials
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import CredentialsField, CredentialsMetaInput, SchemaField
|
||||
from backend.util.request import requests
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
|
||||
99
autogpt_platform/backend/backend/util/request.py
Normal file
99
autogpt_platform/backend/backend/util/request.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import ipaddress
|
||||
import socket
|
||||
from typing import Callable
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests as req
|
||||
|
||||
from backend.util.settings import Config
|
||||
|
||||
|
||||
def is_ip_allowed(ip: str) -> bool:
|
||||
"""
|
||||
Checks if the IP address is allowed (i.e., it's a global IP address).
|
||||
"""
|
||||
ip_addr = ipaddress.ip_address(ip)
|
||||
return ip_addr.is_global
|
||||
|
||||
|
||||
def validate_url(url: str, trusted_origins: list[str]) -> str:
|
||||
"""
|
||||
Validates the URL to prevent SSRF attacks by ensuring it does not point to a private
|
||||
or untrusted IP address, unless whitelisted.
|
||||
"""
|
||||
if any(url.startswith(origin) for origin in trusted_origins):
|
||||
return url
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
hostname = parsed_url.hostname
|
||||
if not hostname:
|
||||
raise ValueError(f"Invalid URL: Unable to determine hostname from {url}")
|
||||
|
||||
try:
|
||||
# Resolve all IP addresses for the hostname
|
||||
ip_addresses = {result[4][0] for result in socket.getaddrinfo(hostname, None)}
|
||||
# Check if all IP addresses are global
|
||||
if all(is_ip_allowed(ip) for ip in ip_addresses):
|
||||
return url
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Access to private or untrusted IP address at {hostname} is not allowed."
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid or unresolvable URL: {url}") from e
|
||||
|
||||
|
||||
class Requests:
|
||||
"""
|
||||
A wrapper around the requests library that validates URLs before making requests.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
trusted_origins: list[str],
|
||||
raise_for_status: bool = True,
|
||||
extra_url_validator: Callable[[str], str] | None = None,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
):
|
||||
self.trusted_origins = trusted_origins
|
||||
self.raise_for_status = raise_for_status
|
||||
self.extra_url_validator = extra_url_validator
|
||||
self.extra_headers = extra_headers
|
||||
|
||||
def request(self, method, url, headers=None, *args, **kwargs) -> req.Response:
|
||||
if self.extra_headers is not None:
|
||||
headers = {**(headers or {}), **self.extra_headers}
|
||||
|
||||
if self.extra_url_validator is not None:
|
||||
url = self.extra_url_validator(url)
|
||||
url = validate_url(url, self.trusted_origins)
|
||||
|
||||
response = req.request(method, url, headers=headers, *args, **kwargs)
|
||||
if self.raise_for_status:
|
||||
response.raise_for_status()
|
||||
|
||||
return response
|
||||
|
||||
def get(self, url, *args, **kwargs) -> req.Response:
|
||||
return self.request("GET", url, *args, **kwargs)
|
||||
|
||||
def post(self, url, *args, **kwargs) -> req.Response:
|
||||
return self.request("POST", url, *args, **kwargs)
|
||||
|
||||
def put(self, url, *args, **kwargs) -> req.Response:
|
||||
return self.request("PUT", url, *args, **kwargs)
|
||||
|
||||
def delete(self, url, *args, **kwargs) -> req.Response:
|
||||
return self.request("DELETE", url, *args, **kwargs)
|
||||
|
||||
def head(self, url, *args, **kwargs) -> req.Response:
|
||||
return self.request("HEAD", url, *args, **kwargs)
|
||||
|
||||
def options(self, url, *args, **kwargs) -> req.Response:
|
||||
return self.request("OPTIONS", url, *args, **kwargs)
|
||||
|
||||
def patch(self, url, *args, **kwargs) -> req.Response:
|
||||
return self.request("PATCH", url, *args, **kwargs)
|
||||
|
||||
|
||||
requests = Requests(Config().trust_endpoints_for_requests)
|
||||
Reference in New Issue
Block a user