mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-30 09:28:19 -05:00
Compare commits
4 Commits
dev
...
pwuts/secr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8045a67869 | ||
|
|
571947322d | ||
|
|
3310db5495 | ||
|
|
b59feffe25 |
@@ -8,7 +8,13 @@ from backend.api.features.library import model as library_model
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.data.model import (
|
||||
CredentialsFieldInfo,
|
||||
CredentialsMetaInput,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
_extract_host_from_url,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
@@ -273,7 +279,14 @@ async def match_user_credentials_to_graph(
|
||||
for cred in available_creds
|
||||
if cred.provider in credential_requirements.provider
|
||||
and cred.type in credential_requirements.supported_types
|
||||
and _credential_has_required_scopes(cred, credential_requirements)
|
||||
and (
|
||||
cred.type != "oauth2"
|
||||
or _credential_has_required_scopes(cred, credential_requirements)
|
||||
)
|
||||
and (
|
||||
cred.type != "host_scoped"
|
||||
or _credential_is_for_host(cred, credential_requirements)
|
||||
)
|
||||
),
|
||||
None,
|
||||
)
|
||||
@@ -318,19 +331,10 @@ async def match_user_credentials_to_graph(
|
||||
|
||||
|
||||
def _credential_has_required_scopes(
|
||||
credential: Credentials,
|
||||
credential: OAuth2Credentials,
|
||||
requirements: CredentialsFieldInfo,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a credential has all the scopes required by the block.
|
||||
|
||||
For OAuth2 credentials, verifies that the credential's scopes are a superset
|
||||
of the required scopes. For other credential types, returns True (no scope check).
|
||||
"""
|
||||
# Only OAuth2 credentials have scopes to check
|
||||
if credential.type != "oauth2":
|
||||
return True
|
||||
|
||||
"""Check if an OAuth2 credential has all the scopes required by the input."""
|
||||
# If no scopes are required, any credential matches
|
||||
if not requirements.required_scopes:
|
||||
return True
|
||||
@@ -339,6 +343,23 @@ def _credential_has_required_scopes(
|
||||
return set(credential.scopes).issuperset(requirements.required_scopes)
|
||||
|
||||
|
||||
def _credential_is_for_host(
|
||||
credential: HostScopedCredentials,
|
||||
requirements: CredentialsFieldInfo,
|
||||
) -> bool:
|
||||
"""Check if a host-scoped credential matches the host required by the input."""
|
||||
# We need to know the host to match host-scoped credentials to.
|
||||
# Graph.aggregate_credentials_inputs() adds the node's set URL value (if any)
|
||||
# to discriminator_values. No discriminator_values -> no host to match against.
|
||||
if not requirements.discriminator_values:
|
||||
return True
|
||||
|
||||
host = _extract_host_from_url(requirements.discriminator_values.pop())
|
||||
|
||||
# Check that credential host matches required host
|
||||
return credential.host == host
|
||||
|
||||
|
||||
async def check_user_has_required_credentials(
|
||||
user_id: str,
|
||||
required_credentials: list[CredentialsMetaInput],
|
||||
|
||||
@@ -42,6 +42,7 @@ from typing_extensions import TypedDict
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.json import loads as json_loads
|
||||
from backend.util.request import parse_url
|
||||
from backend.util.settings import Secrets
|
||||
|
||||
# Type alias for any provider name (including custom ones)
|
||||
@@ -554,8 +555,8 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
def _extract_host_from_url(url: str) -> str:
|
||||
"""Extract host from URL for grouping host-scoped credentials."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
return parsed.hostname or url
|
||||
parsed = parse_url(url)
|
||||
return parsed.netloc or url
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
@@ -157,12 +157,7 @@ async def validate_url(
|
||||
is_trusted: Boolean indicating if the hostname is in trusted_origins
|
||||
ip_addresses: List of IP addresses for the host; empty if the host is trusted
|
||||
"""
|
||||
# Canonicalize URL
|
||||
url = url.strip("/ ").replace("\\", "/")
|
||||
parsed = urlparse(url)
|
||||
if not parsed.scheme:
|
||||
url = f"http://{url}"
|
||||
parsed = urlparse(url)
|
||||
parsed = parse_url(url)
|
||||
|
||||
# Check scheme
|
||||
if parsed.scheme not in ALLOWED_SCHEMES:
|
||||
@@ -220,6 +215,17 @@ async def validate_url(
|
||||
)
|
||||
|
||||
|
||||
def parse_url(url: str) -> URL:
|
||||
"""Canonicalizes and parses a URL string."""
|
||||
url = url.strip("/ ").replace("\\", "/")
|
||||
parsed = urlparse(url)
|
||||
if not parsed.scheme:
|
||||
url = f"http://{url}"
|
||||
parsed = urlparse(url)
|
||||
|
||||
return parsed
|
||||
|
||||
|
||||
def pin_url(url: URL, ip_addresses: Optional[list[str]] = None) -> URL:
|
||||
"""
|
||||
Pins a URL to a specific IP address to prevent DNS rebinding attacks.
|
||||
|
||||
@@ -41,7 +41,17 @@ export function HostScopedCredentialsModal({
|
||||
const currentHost = currentUrl ? getHostFromUrl(currentUrl) : "";
|
||||
|
||||
const formSchema = z.object({
|
||||
host: z.string().min(1, "Host is required"),
|
||||
host: z
|
||||
.string()
|
||||
.min(1, "Host is required")
|
||||
.refine((val) => !/^[a-zA-Z][a-zA-Z\d+\-.]*:\/\//.test(val), {
|
||||
message: "Enter only the host (e.g. api.example.com), not a full URL",
|
||||
})
|
||||
.refine((val) => !val.includes("/"), {
|
||||
message:
|
||||
"Enter only the host (e.g. api.example.com), without a trailing path. " +
|
||||
"You may specify a port (e.g. api.example.com:8080) if needed.",
|
||||
}),
|
||||
title: z.string().optional(),
|
||||
headers: z.record(z.string()).optional(),
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user