Apply secure request to all blocks

This commit is contained in:
Zamil Majdy
2024-11-06 11:33:14 +07:00
parent 6ac4132e64
commit 4908f4633d
14 changed files with 174 additions and 130 deletions

View File

@@ -3,12 +3,12 @@ import time
from enum import Enum
from typing import Literal
import requests
from autogpt_libs.supabase_integration_credentials_store.types import APIKeyCredentials
from pydantic import SecretStr
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import CredentialsField, CredentialsMetaInput, SchemaField
from backend.util.request import requests
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",

View File

@@ -1,68 +1,50 @@
from urllib.parse import urlparse
import requests
from ._auth import GithubCredentials
from backend.blocks.github._auth import GithubCredentials
from backend.util.request import Requests
class GitHubAPI:
def __init__(self, credentials: GithubCredentials):
self.credentials = credentials
def _validate_github_url(url: str) -> None:
parsed_url = urlparse(url)
if parsed_url.netloc != "github.com":
raise ValueError("The input URL must be a valid GitHub URL.")
@staticmethod
def _validate_github_url(url: str) -> None:
parsed_url = urlparse(url)
if parsed_url.netloc != "github.com":
raise ValueError("The input URL must be a valid GitHub URL.")
@staticmethod
def _convert_to_api_url(url: str) -> str:
"""
Converts a standard GitHub URL to the corresponding GitHub API URL.
Handles repository URLs, issue URLs, pull request URLs, and more.
"""
GitHubAPI._validate_github_url(url)
parsed_url = urlparse(url)
path_parts = parsed_url.path.strip("/").split("/")
def _convert_to_api_url(url: str) -> str:
"""
Converts a standard GitHub URL to the corresponding GitHub API URL.
Handles repository URLs, issue URLs, pull request URLs, and more.
"""
_validate_github_url(url)
parsed_url = urlparse(url)
path_parts = parsed_url.path.strip("/").split("/")
if len(path_parts) >= 2:
owner, repo = path_parts[0], path_parts[1]
api_base = f"https://api.github.com/repos/{owner}/{repo}"
if len(path_parts) >= 2:
owner, repo = path_parts[0], path_parts[1]
api_base = f"https://api.github.com/repos/{owner}/{repo}"
if len(path_parts) > 2:
additional_path = "/".join(path_parts[2:])
api_url = f"{api_base}/{additional_path}"
else:
# Repository base URL
api_url = api_base
if len(path_parts) > 2:
additional_path = "/".join(path_parts[2:])
api_url = f"{api_base}/{additional_path}"
else:
raise ValueError("Invalid GitHub URL format.")
# Repository base URL
api_url = api_base
else:
raise ValueError("Invalid GitHub URL format.")
return api_url
return api_url
def _get_headers(self) -> dict:
return {
"Authorization": self.credentials.bearer(),
"Accept": "application/vnd.github.v3+json",
}
def get(self, url: str, **kwargs) -> requests.Response:
api_url = self._convert_to_api_url(url)
headers = self._get_headers()
response = requests.get(api_url, headers=headers, **kwargs)
response.raise_for_status()
return response
def _get_headers(credentials: GithubCredentials) -> dict[str, str]:
return {
"Authorization": credentials.bearer(),
"Accept": "application/vnd.github.v3+json",
}
def post(self, url: str, json=None, **kwargs) -> requests.Response:
api_url = self._convert_to_api_url(url)
headers = self._get_headers()
response = requests.post(api_url, headers=headers, json=json, **kwargs)
response.raise_for_status()
return response
def delete(self, url: str, json=None, **kwargs) -> requests.Response:
api_url = self._convert_to_api_url(url)
headers = self._get_headers()
response = requests.delete(api_url, headers=headers, json=json, **kwargs)
response.raise_for_status()
return response
def get_api(credentials: GithubCredentials) -> Requests:
return Requests(
trusted_origins=["https://api.github.com", "https://github.com"],
extra_url_validator=_convert_to_api_url,
extra_headers=_get_headers(credentials),
)

View File

@@ -5,7 +5,7 @@ from typing_extensions import TypedDict
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from ._api import GitHubAPI
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
@@ -68,7 +68,7 @@ class GithubCommentBlock(Block):
def post_comment(
credentials: GithubCredentials, issue_url: str, body_text: str
) -> tuple[int, str]:
api = GitHubAPI(credentials)
api = get_api(credentials)
data = {"body": body_text}
comments_url = issue_url + "/comments"
response = api.post(comments_url, json=data)
@@ -145,7 +145,7 @@ class GithubMakeIssueBlock(Block):
def create_issue(
credentials: GithubCredentials, repo_url: str, title: str, body: str
) -> tuple[int, str]:
api = GitHubAPI(credentials)
api = get_api(credentials)
data = {"title": title, "body": body}
issues_url = repo_url + "/issues"
response = api.post(issues_url, json=data)
@@ -215,7 +215,7 @@ class GithubReadIssueBlock(Block):
def read_issue(
credentials: GithubCredentials, issue_url: str
) -> tuple[str, str, str]:
api = GitHubAPI(credentials)
api = get_api(credentials)
response = api.get(issue_url)
data = response.json()
title = data.get("title", "No title found")
@@ -292,7 +292,7 @@ class GithubListIssuesBlock(Block):
def list_issues(
credentials: GithubCredentials, repo_url: str
) -> list[Output.IssueItem]:
api = GitHubAPI(credentials)
api = get_api(credentials)
issues_url = repo_url + "/issues"
response = api.get(issues_url)
data = response.json()
@@ -352,7 +352,7 @@ class GithubAddLabelBlock(Block):
@staticmethod
def add_label(credentials: GithubCredentials, issue_url: str, label: str) -> str:
api = GitHubAPI(credentials)
api = get_api(credentials)
data = {"labels": [label]}
labels_url = issue_url + "/labels"
response = api.post(labels_url, json=data)
@@ -413,7 +413,7 @@ class GithubRemoveLabelBlock(Block):
@staticmethod
def remove_label(credentials: GithubCredentials, issue_url: str, label: str) -> str:
api = GitHubAPI(credentials)
api = get_api(credentials)
label_url = issue_url + f"/labels/{label}"
response = api.delete(label_url)
response.raise_for_status()
@@ -479,7 +479,7 @@ class GithubAssignIssueBlock(Block):
issue_url: str,
assignee: str,
) -> str:
api = GitHubAPI(credentials)
api = get_api(credentials)
assignees_url = issue_url + "/assignees"
data = {"assignees": [assignee]}
response = api.post(assignees_url, json=data)
@@ -546,7 +546,7 @@ class GithubUnassignIssueBlock(Block):
issue_url: str,
assignee: str,
) -> str:
api = GitHubAPI(credentials)
api = get_api(credentials)
assignees_url = issue_url + "/assignees"
data = {"assignees": [assignee]}
response = api.delete(assignees_url, json=data)

View File

@@ -3,7 +3,7 @@ from typing_extensions import TypedDict
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from ._api import GitHubAPI
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
@@ -64,7 +64,7 @@ class GithubListPullRequestsBlock(Block):
@staticmethod
def list_prs(credentials: GithubCredentials, repo_url: str) -> list[Output.PRItem]:
api = GitHubAPI(credentials)
api = get_api(credentials)
pulls_url = repo_url + "/pulls"
response = api.get(pulls_url)
data = response.json()
@@ -159,7 +159,7 @@ class GithubMakePullRequestBlock(Block):
head: str,
base: str,
) -> tuple[int, str]:
api = GitHubAPI(credentials)
api = get_api(credentials)
pulls_url = repo_url + "/pulls"
data = {"title": title, "body": body, "head": head, "base": base}
response = api.post(pulls_url, json=data)
@@ -240,7 +240,7 @@ class GithubReadPullRequestBlock(Block):
@staticmethod
def read_pr(credentials: GithubCredentials, pr_url: str) -> tuple[str, str, str]:
api = GitHubAPI(credentials)
api = get_api(credentials)
# Adjust the URL to access the issue endpoint for PR metadata
issue_url = pr_url.replace("/pull/", "/issues/")
response = api.get(issue_url)
@@ -252,7 +252,7 @@ class GithubReadPullRequestBlock(Block):
@staticmethod
def read_pr_changes(credentials: GithubCredentials, pr_url: str) -> str:
api = GitHubAPI(credentials)
api = get_api(credentials)
files_url = pr_url + "/files"
response = api.get(files_url)
files = response.json()
@@ -330,7 +330,7 @@ class GithubAssignPRReviewerBlock(Block):
def assign_reviewer(
credentials: GithubCredentials, pr_url: str, reviewer: str
) -> str:
api = GitHubAPI(credentials)
api = get_api(credentials)
reviewers_url = pr_url + "/requested_reviewers"
data = {"reviewers": [reviewer]}
api.post(reviewers_url, json=data)
@@ -397,7 +397,7 @@ class GithubUnassignPRReviewerBlock(Block):
def unassign_reviewer(
credentials: GithubCredentials, pr_url: str, reviewer: str
) -> str:
api = GitHubAPI(credentials)
api = get_api(credentials)
reviewers_url = pr_url + "/requested_reviewers"
data = {"reviewers": [reviewer]}
api.delete(reviewers_url, json=data)
@@ -477,7 +477,7 @@ class GithubListPRReviewersBlock(Block):
def list_reviewers(
credentials: GithubCredentials, pr_url: str
) -> list[Output.ReviewerItem]:
api = GitHubAPI(credentials)
api = get_api(credentials)
reviewers_url = pr_url + "/requested_reviewers"
response = api.get(reviewers_url)
data = response.json()

View File

@@ -5,7 +5,7 @@ from typing_extensions import TypedDict
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from ._api import GitHubAPI
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
@@ -68,7 +68,7 @@ class GithubListTagsBlock(Block):
def list_tags(
credentials: GithubCredentials, repo_url: str
) -> list[Output.TagItem]:
api = GitHubAPI(credentials)
api = get_api(credentials)
tags_url = repo_url + "/tags"
response = api.get(tags_url)
data = response.json()
@@ -150,7 +150,7 @@ class GithubListBranchesBlock(Block):
def list_branches(
credentials: GithubCredentials, repo_url: str
) -> list[Output.BranchItem]:
api = GitHubAPI(credentials)
api = get_api(credentials)
branches_url = repo_url + "/branches"
response = api.get(branches_url)
data = response.json()
@@ -237,7 +237,7 @@ class GithubListDiscussionsBlock(Block):
def list_discussions(
credentials: GithubCredentials, repo_url: str, num_discussions: int
) -> list[Output.DiscussionItem]:
api = GitHubAPI(credentials)
api = get_api(credentials)
# GitHub GraphQL API endpoint is different; we'll use api.post with custom URL
repo_path = repo_url.replace("https://github.com/", "")
owner, repo = repo_path.split("/")
@@ -332,7 +332,7 @@ class GithubListReleasesBlock(Block):
def list_releases(
credentials: GithubCredentials, repo_url: str
) -> list[Output.ReleaseItem]:
api = GitHubAPI(credentials)
api = get_api(credentials)
releases_url = repo_url + "/releases"
response = api.get(releases_url)
data = response.json()
@@ -408,7 +408,7 @@ class GithubReadFileBlock(Block):
def read_file(
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
) -> tuple[str, int]:
api = GitHubAPI(credentials)
api = get_api(credentials)
content_url = repo_url + f"/contents/{file_path}?ref={branch}"
response = api.get(content_url)
content = response.json()
@@ -518,7 +518,7 @@ class GithubReadFolderBlock(Block):
def read_folder(
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
api = GitHubAPI(credentials)
api = get_api(credentials)
contents_url = repo_url + f"/contents/{folder_path}?ref={branch}"
response = api.get(contents_url)
content = response.json()
@@ -612,7 +612,7 @@ class GithubMakeBranchBlock(Block):
new_branch: str,
source_branch: str,
) -> str:
api = GitHubAPI(credentials)
api = get_api(credentials)
# Get the SHA of the source branch
ref_url = repo_url + f"/git/refs/heads/{source_branch}"
response = api.get(ref_url)
@@ -681,7 +681,7 @@ class GithubDeleteBranchBlock(Block):
def delete_branch(
credentials: GithubCredentials, repo_url: str, branch: str
) -> str:
api = GitHubAPI(credentials)
api = get_api(credentials)
ref_url = repo_url + f"/git/refs/heads/{branch}"
api.delete(ref_url)
return "Branch deleted successfully"

View File

@@ -1,14 +1,9 @@
import ipaddress
import json
import socket
from enum import Enum
from urllib.parse import urlparse
import requests
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.settings import Config
from backend.util.request import requests
class HttpMethod(Enum):
@@ -21,35 +16,6 @@ class HttpMethod(Enum):
HEAD = "HEAD"
def validate_url(url: str) -> str:
"""
To avoid SSRF attacks, the URL should not be a private IP address
unless it is whitelisted in TRUST_ENDPOINTS_FOR_REQUESTS config.
"""
if any(url.startswith(origin) for origin in Config().trust_endpoints_for_requests):
return url
parsed_url = urlparse(url)
hostname = parsed_url.hostname
if not hostname:
raise ValueError(f"Invalid URL: Unable to determine hostname from {url}")
try:
host = socket.gethostbyname_ex(hostname)
for ip in host[2]:
ip_addr = ipaddress.ip_address(ip)
if ip_addr.is_global:
return url
raise ValueError(
f"Access to private or untrusted IP address at {hostname} is not allowed."
)
except ValueError:
raise
except Exception as e:
raise ValueError(f"Invalid or unresolvable URL: {url}") from e
class SendWebRequestBlock(Block):
class Input(BlockSchema):
url: str = SchemaField(
@@ -87,11 +53,9 @@ class SendWebRequestBlock(Block):
if isinstance(input_data.body, str):
input_data.body = json.loads(input_data.body)
validated_url = validate_url(input_data.url)
response = requests.request(
input_data.method.value,
validated_url,
input_data.url,
headers=input_data.headers,
json=input_data.body,
allow_redirects=False,

View File

@@ -1,12 +1,13 @@
from enum import Enum
from typing import Any, Dict, Literal, Optional
import requests
from autogpt_libs.supabase_integration_credentials_store.types import APIKeyCredentials
from pydantic import SecretStr
from requests.exceptions import RequestException
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import CredentialsField, CredentialsMetaInput, SchemaField
from backend.util.request import requests
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
@@ -244,7 +245,7 @@ class IdeogramModelBlock(Block):
response = requests.post(url, json=data, headers=headers)
response.raise_for_status()
return response.json()["data"][0]["url"]
except requests.exceptions.RequestException as e:
except RequestException as e:
raise Exception(f"Failed to fetch image: {str(e)}")
def upscale_image(self, api_key: SecretStr, image_url: str):
@@ -275,5 +276,5 @@ class IdeogramModelBlock(Block):
response.raise_for_status()
return response.json()["data"][0]["url"]
except requests.exceptions.RequestException as e:
except RequestException as e:
raise Exception(f"Failed to upscale image: {str(e)}")

View File

@@ -1,5 +1,3 @@
import requests
from backend.blocks.jina._auth import (
JinaCredentials,
JinaCredentialsField,
@@ -7,6 +5,7 @@ from backend.blocks.jina._auth import (
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import requests
class JinaChunkingBlock(Block):

View File

@@ -1,5 +1,3 @@
import requests
from backend.blocks.jina._auth import (
JinaCredentials,
JinaCredentialsField,
@@ -7,6 +5,7 @@ from backend.blocks.jina._auth import (
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import requests
class JinaEmbeddingBlock(Block):

View File

@@ -1,7 +1,6 @@
from enum import Enum
from typing import List, Literal
import requests
from autogpt_libs.supabase_integration_credentials_store.types import APIKeyCredentials
from pydantic import SecretStr
@@ -13,6 +12,7 @@ from backend.data.model import (
SchemaField,
SecretField,
)
from backend.util.request import requests
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",

View File

@@ -3,12 +3,12 @@ import socket
from typing import Any, Literal
from urllib.parse import quote, urlparse
import requests
from autogpt_libs.supabase_integration_credentials_store.types import APIKeyCredentials
from pydantic import SecretStr
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import CredentialsField, CredentialsMetaInput, SchemaField
from backend.util.request import requests
from backend.util.settings import Config

View File

@@ -1,12 +1,12 @@
import time
from typing import Literal
import requests
from autogpt_libs.supabase_integration_credentials_store.types import APIKeyCredentials
from pydantic import SecretStr
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import CredentialsField, CredentialsMetaInput, SchemaField
from backend.util.request import requests
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",

View File

@@ -1,11 +1,11 @@
from typing import Any, Literal
import requests
from autogpt_libs.supabase_integration_credentials_store.types import APIKeyCredentials
from pydantic import SecretStr
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import CredentialsField, CredentialsMetaInput, SchemaField
from backend.util.request import requests
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",

View File

@@ -0,0 +1,99 @@
import ipaddress
import socket
from typing import Callable
from urllib.parse import urlparse
import requests as req
from backend.util.settings import Config
def is_ip_allowed(ip: str) -> bool:
    """
    Checks if the IP address is allowed (i.e., it's a global IP address).

    Args:
        ip: Textual IPv4 or IPv6 address, e.g. as returned by `socket.getaddrinfo`.

    Returns:
        True when `ipaddress` classifies the address as global, False otherwise.

    Raises:
        ValueError: If `ip` is not a valid IP address literal.
    """
    return ipaddress.ip_address(ip).is_global


def _is_trusted_url(url: str, origin: str) -> bool:
    """
    Tells whether `url` lies under the trusted `origin` prefix.

    A bare prefix match is not sufficient: "https://api.github.com" is also a
    prefix of "https://api.github.com.evil.tld/" and of
    "https://api.github.com@127.0.0.1/". The match therefore has to end exactly
    where the URL's authority component ends.
    """
    try:
        origin_has_host = bool(origin) and bool(urlparse(origin).hostname)
    except ValueError:
        # A malformed whitelist entry must never grant trust.
        origin_has_host = False

    # An empty or host-less entry (e.g. "" or "https://") would prefix-match
    # practically every URL, so it is ignored rather than treated as a wildcard.
    if not origin_has_host or not url.startswith(origin):
        return False

    # If the origin already ends in "/", the authority is fully pinned by it.
    if origin.endswith("/"):
        return True

    # Otherwise the URL must stop here or continue with a path, query or fragment.
    remainder = url[len(origin):]
    return remainder == "" or remainder[0] in "/?#"


def validate_url(url: str, trusted_origins: list[str]) -> str:
    """
    Validates the URL to prevent SSRF attacks by ensuring it does not point to a private
    or untrusted IP address, unless whitelisted.

    Args:
        url: The URL about to be requested.
        trusted_origins: Whitelisted URL prefixes (e.g. "https://api.github.com").
            A URL under one of them is accepted without resolving its hostname.

    Returns:
        The unchanged `url` when it is trusted or resolves only to global addresses.

    Raises:
        ValueError: If the hostname cannot be determined or resolved, or if any
            resolved address is not a global IP address.
    """
    if any(_is_trusted_url(url, origin) for origin in trusted_origins):
        return url

    hostname = urlparse(url).hostname
    if not hostname:
        raise ValueError(f"Invalid URL: Unable to determine hostname from {url}")

    # Only the lookup is guarded, so the "private or untrusted" error raised
    # below reaches the caller with its own message instead of being re-wrapped.
    try:
        # Resolve all IP addresses for the hostname
        ip_addresses = {info[4][0] for info in socket.getaddrinfo(hostname, None)}
    except (OSError, UnicodeError) as e:
        raise ValueError(f"Invalid or unresolvable URL: {url}") from e

    # Every resolved address must be global; an empty result is never accepted.
    if ip_addresses and all(is_ip_allowed(ip) for ip in ip_addresses):
        return url

    raise ValueError(
        f"Access to private or untrusted IP address at {hostname} is not allowed."
    )
class Requests:
    """
    Facade over the `requests` library that vets each URL before sending.

    Every call funnels through `request`, which merges the instance-level
    headers, applies the optional extra URL validator, and then runs
    `validate_url` against this instance's trusted origins.

    NOTE(review): only the initial URL is validated here; whether redirect
    targets need the same check depends on the `allow_redirects` value callers
    pass through -- confirm.
    """

    def __init__(
        self,
        trusted_origins: list[str],
        raise_for_status: bool = True,
        extra_url_validator: Callable[[str], str] | None = None,
        extra_headers: dict[str, str] | None = None,
    ):
        # URL prefixes that validate_url accepts without resolving the hostname.
        self.trusted_origins = trusted_origins
        # When true, request() calls raise_for_status() on every response.
        self.raise_for_status = raise_for_status
        # Optional hook that may rewrite or reject the URL before validation.
        self.extra_url_validator = extra_url_validator
        # Headers added to every request; they win over per-call headers.
        self.extra_headers = extra_headers

    def request(self, method, url, headers=None, *args, **kwargs) -> req.Response:
        """
        Validate `url` and send the request via the `requests` library.

        NOTE(review): `headers` is the third positional parameter, so a
        positional argument after `url` in the verb helpers lands here rather
        than in `data`/`params` -- pass those by keyword.
        """
        if self.extra_headers is not None:
            per_call_headers = headers or {}
            headers = {**per_call_headers, **self.extra_headers}

        target = url
        if self.extra_url_validator is not None:
            target = self.extra_url_validator(target)
        target = validate_url(target, self.trusted_origins)

        response = req.request(method, target, headers=headers, *args, **kwargs)
        if self.raise_for_status:
            response.raise_for_status()
        return response

    def get(self, url, *args, **kwargs) -> req.Response:
        """Send a validated GET request."""
        return self.request("GET", url, *args, **kwargs)

    def post(self, url, *args, **kwargs) -> req.Response:
        """Send a validated POST request."""
        return self.request("POST", url, *args, **kwargs)

    def put(self, url, *args, **kwargs) -> req.Response:
        """Send a validated PUT request."""
        return self.request("PUT", url, *args, **kwargs)

    def delete(self, url, *args, **kwargs) -> req.Response:
        """Send a validated DELETE request."""
        return self.request("DELETE", url, *args, **kwargs)

    def head(self, url, *args, **kwargs) -> req.Response:
        """Send a validated HEAD request."""
        return self.request("HEAD", url, *args, **kwargs)

    def options(self, url, *args, **kwargs) -> req.Response:
        """Send a validated OPTIONS request."""
        return self.request("OPTIONS", url, *args, **kwargs)

    def patch(self, url, *args, **kwargs) -> req.Response:
        """Send a validated PATCH request."""
        return self.request("PATCH", url, *args, **kwargs)
# Shared module-level client, named `requests` so call sites that previously
# used the `requests` library can import this object instead and have every
# call pass through URL validation. The trusted origins are read from
# Config().trust_endpoints_for_requests once, when this module is imported.
requests = Requests(Config().trust_endpoints_for_requests)