Compare commits

...

4 Commits

4 changed files with 60 additions and 22 deletions

View File

@@ -8,7 +8,13 @@ from backend.api.features.library import model as library_model
from backend.api.features.store import db as store_db
from backend.data import graph as graph_db
from backend.data.graph import GraphModel
from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput
from backend.data.model import (
CredentialsFieldInfo,
CredentialsMetaInput,
HostScopedCredentials,
OAuth2Credentials,
_extract_host_from_url,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import NotFoundError
@@ -273,7 +279,14 @@ async def match_user_credentials_to_graph(
for cred in available_creds
if cred.provider in credential_requirements.provider
and cred.type in credential_requirements.supported_types
and _credential_has_required_scopes(cred, credential_requirements)
and (
cred.type != "oauth2"
or _credential_has_required_scopes(cred, credential_requirements)
)
and (
cred.type != "host_scoped"
or _credential_is_for_host(cred, credential_requirements)
)
),
None,
)
@@ -318,19 +331,10 @@ async def match_user_credentials_to_graph(
def _credential_has_required_scopes(
credential: Credentials,
credential: OAuth2Credentials,
requirements: CredentialsFieldInfo,
) -> bool:
"""
Check if a credential has all the scopes required by the block.
For OAuth2 credentials, verifies that the credential's scopes are a superset
of the required scopes. For other credential types, returns True (no scope check).
"""
# Only OAuth2 credentials have scopes to check
if credential.type != "oauth2":
return True
"""Check if an OAuth2 credential has all the scopes required by the input."""
# If no scopes are required, any credential matches
if not requirements.required_scopes:
return True
@@ -339,6 +343,23 @@ def _credential_has_required_scopes(
return set(credential.scopes).issuperset(requirements.required_scopes)
def _credential_is_for_host(
credential: HostScopedCredentials,
requirements: CredentialsFieldInfo,
) -> bool:
"""Check if a host-scoped credential matches the host required by the input."""
# We need to know the host to match host-scoped credentials to.
# Graph.aggregate_credentials_inputs() adds the node's set URL value (if any)
# to discriminator_values. No discriminator_values -> no host to match against.
if not requirements.discriminator_values:
return True
host = _extract_host_from_url(requirements.discriminator_values.pop())
# Check that credential host matches required host
return credential.host == host
async def check_user_has_required_credentials(
user_id: str,
required_credentials: list[CredentialsMetaInput],

View File

@@ -42,6 +42,7 @@ from typing_extensions import TypedDict
from backend.integrations.providers import ProviderName
from backend.util.json import loads as json_loads
from backend.util.request import parse_url
from backend.util.settings import Secrets
# Type alias for any provider name (including custom ones)
@@ -554,8 +555,8 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
def _extract_host_from_url(url: str) -> str:
"""Extract host from URL for grouping host-scoped credentials."""
try:
parsed = urlparse(url)
return parsed.hostname or url
parsed = parse_url(url)
return parsed.netloc or url
except Exception:
return ""

View File

@@ -157,12 +157,7 @@ async def validate_url(
is_trusted: Boolean indicating if the hostname is in trusted_origins
ip_addresses: List of IP addresses for the host; empty if the host is trusted
"""
# Canonicalize URL
url = url.strip("/ ").replace("\\", "/")
parsed = urlparse(url)
if not parsed.scheme:
url = f"http://{url}"
parsed = urlparse(url)
parsed = parse_url(url)
# Check scheme
if parsed.scheme not in ALLOWED_SCHEMES:
@@ -220,6 +215,17 @@ async def validate_url(
)
def parse_url(url: str) -> URL:
"""Canonicalizes and parses a URL string."""
url = url.strip("/ ").replace("\\", "/")
parsed = urlparse(url)
if not parsed.scheme:
url = f"http://{url}"
parsed = urlparse(url)
return parsed
def pin_url(url: URL, ip_addresses: Optional[list[str]] = None) -> URL:
"""
Pins a URL to a specific IP address to prevent DNS rebinding attacks.

View File

@@ -41,7 +41,17 @@ export function HostScopedCredentialsModal({
const currentHost = currentUrl ? getHostFromUrl(currentUrl) : "";
const formSchema = z.object({
host: z.string().min(1, "Host is required"),
host: z
.string()
.min(1, "Host is required")
.refine((val) => !/^[a-zA-Z][a-zA-Z\d+\-.]*:\/\//.test(val), {
message: "Enter only the host (e.g. api.example.com), not a full URL",
})
.refine((val) => !val.includes("/"), {
message:
"Enter only the host (e.g. api.example.com), without a trailing path. " +
"You may specify a port (e.g. api.example.com:8080) if needed.",
}),
title: z.string().optional(),
headers: z.record(z.string()).optional(),
});