mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-10 07:38:04 -05:00
feat(platform): Add Host-scoped credentials support for blocks HTTP requests (#10215)
Currently, we don't have a secure way to pass Authorization headers when calling the `SendWebRequestBlock`. This hinders the integration of third-party applications that do not yet have native block support. ### Changes 🏗️ Add Host-scoped credentials support for the newly introduced SendAuthenticatedWebRequestBlock. <img width="1000" alt="image" src="https://github.com/user-attachments/assets/0d3d577a-2b9b-4f0f-9377-0e00a069ba37" /> <img width="1000" alt="image" src="https://github.com/user-attachments/assets/a59b9f16-c89c-453d-a628-1df0dfd60fb5" /> ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: <!-- Put your test plan here: --> - [x] Uses `https://api.openai.com/v1/images/edits` through SendWebRequestBlock by passing the api-key through host-scoped credentials.
This commit is contained in:
5
.github/workflows/platform-backend-ci.yml
vendored
5
.github/workflows/platform-backend-ci.yml
vendored
@@ -190,9 +190,9 @@ jobs:
|
||||
- name: Run pytest with coverage
|
||||
run: |
|
||||
if [[ "${{ runner.debug }}" == "1" ]]; then
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG test
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
|
||||
else
|
||||
poetry run pytest -s -vv test
|
||||
poetry run pytest -s -vv
|
||||
fi
|
||||
if: success() || (failure() && steps.lint.outcome == 'failure')
|
||||
env:
|
||||
@@ -205,6 +205,7 @@ jobs:
|
||||
REDIS_HOST: "localhost"
|
||||
REDIS_PORT: "6379"
|
||||
REDIS_PASSWORD: "testpassword"
|
||||
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
|
||||
|
||||
env:
|
||||
CI: true
|
||||
|
||||
@@ -20,7 +20,7 @@ def load_all_blocks() -> dict[str, type["Block"]]:
|
||||
modules = [
|
||||
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
|
||||
for f in current_dir.rglob("*.py")
|
||||
if f.is_file() and f.name != "__init__.py"
|
||||
if f.is_file() and f.name != "__init__.py" and not f.name.startswith("test_")
|
||||
]
|
||||
for module in modules:
|
||||
if not re.match("^[a-z0-9_.]+$", module):
|
||||
|
||||
@@ -11,7 +11,7 @@ from backend.blocks.apollo.models import (
|
||||
SearchOrganizationsRequest,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class SearchOrganizationsBlock(Block):
|
||||
@@ -65,7 +65,7 @@ To find IDs, identify the values for organization_id when you call this endpoint
|
||||
le=50000,
|
||||
advanced=True,
|
||||
)
|
||||
credentials: ApolloCredentialsInput = SchemaField(
|
||||
credentials: ApolloCredentialsInput = CredentialsField(
|
||||
description="Apollo credentials",
|
||||
)
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from backend.blocks.apollo.models import (
|
||||
SenorityLevels,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class SearchPeopleBlock(Block):
|
||||
@@ -97,7 +97,7 @@ class SearchPeopleBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
credentials: ApolloCredentialsInput = SchemaField(
|
||||
credentials: ApolloCredentialsInput = CredentialsField(
|
||||
description="Apollo credentials",
|
||||
)
|
||||
|
||||
|
||||
@@ -3,11 +3,19 @@ import logging
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import aiofiles
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
HostScopedCredentials,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.file import (
|
||||
MediaFileType,
|
||||
get_exec_file_path,
|
||||
@@ -19,6 +27,30 @@ from backend.util.request import Requests
|
||||
logger = logging.getLogger(name=__name__)
|
||||
|
||||
|
||||
# Host-scoped credentials for HTTP requests
|
||||
HttpCredentials = CredentialsMetaInput[
|
||||
Literal[ProviderName.HTTP], Literal["host_scoped"]
|
||||
]
|
||||
|
||||
|
||||
TEST_CREDENTIALS = HostScopedCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="http",
|
||||
host="api.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer test-token"),
|
||||
"X-API-Key": SecretStr("test-api-key"),
|
||||
},
|
||||
title="Mock HTTP Host-Scoped Credentials",
|
||||
)
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
class HttpMethod(Enum):
|
||||
GET = "GET"
|
||||
POST = "POST"
|
||||
@@ -169,3 +201,62 @@ class SendWebRequestBlock(Block):
|
||||
yield "client_error", result
|
||||
else:
|
||||
yield "server_error", result
|
||||
|
||||
|
||||
class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
|
||||
class Input(SendWebRequestBlock.Input):
|
||||
credentials: HttpCredentials = CredentialsField(
|
||||
description="HTTP host-scoped credentials for automatic header injection",
|
||||
discriminator="url",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
Block.__init__(
|
||||
self,
|
||||
id="fff86bcd-e001-4bad-a7f6-2eae4720c8dc",
|
||||
description="Make an authenticated HTTP request with host-scoped credentials (JSON / form / multipart).",
|
||||
categories={BlockCategory.OUTPUT},
|
||||
input_schema=SendAuthenticatedWebRequestBlock.Input,
|
||||
output_schema=SendWebRequestBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run( # type: ignore[override]
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
graph_exec_id: str,
|
||||
credentials: HostScopedCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Create SendWebRequestBlock.Input from our input (removing credentials field)
|
||||
base_input = SendWebRequestBlock.Input(
|
||||
url=input_data.url,
|
||||
method=input_data.method,
|
||||
headers=input_data.headers,
|
||||
json_format=input_data.json_format,
|
||||
body=input_data.body,
|
||||
files_name=input_data.files_name,
|
||||
files=input_data.files,
|
||||
)
|
||||
|
||||
# Apply host-scoped credentials to headers
|
||||
extra_headers = {}
|
||||
if credentials.matches_url(input_data.url):
|
||||
logger.debug(
|
||||
f"Applying host-scoped credentials {credentials.id} for URL {input_data.url}"
|
||||
)
|
||||
extra_headers.update(credentials.get_headers_dict())
|
||||
else:
|
||||
logger.warning(
|
||||
f"Host-scoped credentials {credentials.id} do not match URL {input_data.url}"
|
||||
)
|
||||
|
||||
# Merge with user-provided headers (user headers take precedence)
|
||||
base_input.headers = {**extra_headers, **input_data.headers}
|
||||
|
||||
# Use parent class run method
|
||||
async for output_name, output_data in super().run(
|
||||
base_input, graph_exec_id=graph_exec_id, **kwargs
|
||||
):
|
||||
yield output_name, output_data
|
||||
|
||||
@@ -40,7 +40,7 @@ LLMProviderName = Literal[
|
||||
AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]]
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
|
||||
id="769f6af7-820b-4d5d-9b7a-ab82bbc165f",
|
||||
provider="openai",
|
||||
api_key=SecretStr("mock-openai-api-key"),
|
||||
title="Mock OpenAI API key",
|
||||
|
||||
@@ -13,7 +13,7 @@ from backend.data.model import (
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
|
||||
id="8cc8b2c5-d3e4-4b1c-84ad-e1e9fe2a0122",
|
||||
provider="mem0",
|
||||
api_key=SecretStr("mock-mem0-api-key"),
|
||||
title="Mock Mem0 API key",
|
||||
|
||||
@@ -17,7 +17,7 @@ from backend.blocks.smartlead.models import (
|
||||
Sequence,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class CreateCampaignBlock(Block):
|
||||
@@ -27,7 +27,7 @@ class CreateCampaignBlock(Block):
|
||||
name: str = SchemaField(
|
||||
description="The name of the campaign",
|
||||
)
|
||||
credentials: SmartLeadCredentialsInput = SchemaField(
|
||||
credentials: SmartLeadCredentialsInput = CredentialsField(
|
||||
description="SmartLead credentials",
|
||||
)
|
||||
|
||||
@@ -119,7 +119,7 @@ class AddLeadToCampaignBlock(Block):
|
||||
description="Settings for lead upload",
|
||||
default=LeadUploadSettings(),
|
||||
)
|
||||
credentials: SmartLeadCredentialsInput = SchemaField(
|
||||
credentials: SmartLeadCredentialsInput = CredentialsField(
|
||||
description="SmartLead credentials",
|
||||
)
|
||||
|
||||
@@ -251,7 +251,7 @@ class SaveCampaignSequencesBlock(Block):
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
credentials: SmartLeadCredentialsInput = SchemaField(
|
||||
credentials: SmartLeadCredentialsInput = CredentialsField(
|
||||
description="SmartLead credentials",
|
||||
)
|
||||
|
||||
|
||||
485
autogpt_platform/backend/backend/blocks/test/test_http.py
Normal file
485
autogpt_platform/backend/backend/blocks/test/test_http.py
Normal file
@@ -0,0 +1,485 @@
|
||||
"""Comprehensive tests for HTTP block with HostScopedCredentials functionality."""
|
||||
|
||||
from typing import cast
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.http import (
|
||||
HttpCredentials,
|
||||
HttpMethod,
|
||||
SendAuthenticatedWebRequestBlock,
|
||||
)
|
||||
from backend.data.model import HostScopedCredentials
|
||||
from backend.util.request import Response
|
||||
|
||||
|
||||
class TestHttpBlockWithHostScopedCredentials:
|
||||
"""Test suite for HTTP block integration with HostScopedCredentials."""
|
||||
|
||||
@pytest.fixture
|
||||
def http_block(self):
|
||||
"""Create an HTTP block instance."""
|
||||
return SendAuthenticatedWebRequestBlock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response(self):
|
||||
"""Mock a successful HTTP response."""
|
||||
response = MagicMock(spec=Response)
|
||||
response.status = 200
|
||||
response.headers = {"content-type": "application/json"}
|
||||
response.json.return_value = {"success": True, "data": "test"}
|
||||
return response
|
||||
|
||||
@pytest.fixture
|
||||
def exact_match_credentials(self):
|
||||
"""Create host-scoped credentials for exact domain matching."""
|
||||
return HostScopedCredentials(
|
||||
provider="http",
|
||||
host="api.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer exact-match-token"),
|
||||
"X-API-Key": SecretStr("api-key-123"),
|
||||
},
|
||||
title="Exact Match API Credentials",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def wildcard_credentials(self):
|
||||
"""Create host-scoped credentials with wildcard pattern."""
|
||||
return HostScopedCredentials(
|
||||
provider="http",
|
||||
host="*.github.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("token ghp_wildcard123"),
|
||||
},
|
||||
title="GitHub Wildcard Credentials",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def non_matching_credentials(self):
|
||||
"""Create credentials that don't match test URLs."""
|
||||
return HostScopedCredentials(
|
||||
provider="http",
|
||||
host="different.api.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer non-matching-token"),
|
||||
},
|
||||
title="Non-matching Credentials",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_http_block_with_exact_host_match(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
exact_match_credentials,
|
||||
mock_response,
|
||||
):
|
||||
"""Test HTTP block with exact host matching credentials."""
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
# Prepare input data
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url="https://api.example.com/data",
|
||||
method=HttpMethod.GET,
|
||||
headers={"User-Agent": "test-agent"},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": exact_match_credentials.id,
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": exact_match_credentials.title,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with credentials provided by execution manager
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=exact_match_credentials,
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify request headers include both credential and user headers
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
expected_headers = {
|
||||
"Authorization": "Bearer exact-match-token",
|
||||
"X-API-Key": "api-key-123",
|
||||
"User-Agent": "test-agent",
|
||||
}
|
||||
assert call_args.kwargs["headers"] == expected_headers
|
||||
|
||||
# Verify response handling
|
||||
assert len(result) == 1
|
||||
assert result[0][0] == "response"
|
||||
assert result[0][1] == {"success": True, "data": "test"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_http_block_with_wildcard_host_match(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
wildcard_credentials,
|
||||
mock_response,
|
||||
):
|
||||
"""Test HTTP block with wildcard host pattern matching."""
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
# Test with subdomain that should match *.github.com
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url="https://api.github.com/user",
|
||||
method=HttpMethod.GET,
|
||||
headers={},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": wildcard_credentials.id,
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": wildcard_credentials.title,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with wildcard credentials
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=wildcard_credentials,
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify wildcard matching works
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
expected_headers = {"Authorization": "token ghp_wildcard123"}
|
||||
assert call_args.kwargs["headers"] == expected_headers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_http_block_with_non_matching_credentials(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
non_matching_credentials,
|
||||
mock_response,
|
||||
):
|
||||
"""Test HTTP block when credentials don't match the target URL."""
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
# Test with URL that doesn't match the credentials
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url="https://api.example.com/data",
|
||||
method=HttpMethod.GET,
|
||||
headers={"User-Agent": "test-agent"},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": non_matching_credentials.id,
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": non_matching_credentials.title,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with non-matching credentials
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=non_matching_credentials,
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify only user headers are included (no credential headers)
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
expected_headers = {"User-Agent": "test-agent"}
|
||||
assert call_args.kwargs["headers"] == expected_headers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_user_headers_override_credential_headers(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
exact_match_credentials,
|
||||
mock_response,
|
||||
):
|
||||
"""Test that user-provided headers take precedence over credential headers."""
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
# Test with user header that conflicts with credential header
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url="https://api.example.com/data",
|
||||
method=HttpMethod.POST,
|
||||
headers={
|
||||
"Authorization": "Bearer user-override-token", # Should override
|
||||
"Content-Type": "application/json", # Additional user header
|
||||
},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": exact_match_credentials.id,
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": exact_match_credentials.title,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with conflicting headers
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=exact_match_credentials,
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify user headers take precedence
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
expected_headers = {
|
||||
"X-API-Key": "api-key-123", # From credentials
|
||||
"Authorization": "Bearer user-override-token", # User override
|
||||
"Content-Type": "application/json", # User header
|
||||
}
|
||||
assert call_args.kwargs["headers"] == expected_headers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_auto_discovered_credentials_flow(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
mock_response,
|
||||
):
|
||||
"""Test the auto-discovery flow where execution manager provides matching credentials."""
|
||||
# Create auto-discovered credentials
|
||||
auto_discovered_creds = HostScopedCredentials(
|
||||
provider="http",
|
||||
host="*.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer auto-discovered-token"),
|
||||
},
|
||||
title="Auto-discovered Credentials",
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
# Test with empty credentials field (triggers auto-discovery)
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url="https://api.example.com/data",
|
||||
method=HttpMethod.GET,
|
||||
headers={},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": "", # Empty ID triggers auto-discovery in execution manager
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": "",
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with auto-discovered credentials provided by execution manager
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=auto_discovered_creds, # Execution manager found these
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify auto-discovered credentials were applied
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
expected_headers = {"Authorization": "Bearer auto-discovered-token"}
|
||||
assert call_args.kwargs["headers"] == expected_headers
|
||||
|
||||
# Verify response handling
|
||||
assert len(result) == 1
|
||||
assert result[0][0] == "response"
|
||||
assert result[0][1] == {"success": True, "data": "test"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_multiple_header_credentials(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
mock_response,
|
||||
):
|
||||
"""Test credentials with multiple headers are all applied."""
|
||||
# Create credentials with multiple headers
|
||||
multi_header_creds = HostScopedCredentials(
|
||||
provider="http",
|
||||
host="api.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer multi-token"),
|
||||
"X-API-Key": SecretStr("api-key-456"),
|
||||
"X-Client-ID": SecretStr("client-789"),
|
||||
"X-Custom-Header": SecretStr("custom-value"),
|
||||
},
|
||||
title="Multi-Header Credentials",
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
# Test with credentials containing multiple headers
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url="https://api.example.com/data",
|
||||
method=HttpMethod.GET,
|
||||
headers={"User-Agent": "test-agent"},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": multi_header_creds.id,
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": multi_header_creds.title,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with multi-header credentials
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=multi_header_creds,
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify all headers are included
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
expected_headers = {
|
||||
"Authorization": "Bearer multi-token",
|
||||
"X-API-Key": "api-key-456",
|
||||
"X-Client-ID": "client-789",
|
||||
"X-Custom-Header": "custom-value",
|
||||
"User-Agent": "test-agent",
|
||||
}
|
||||
assert call_args.kwargs["headers"] == expected_headers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_credentials_with_complex_url_patterns(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
mock_response,
|
||||
):
|
||||
"""Test credentials matching various URL patterns."""
|
||||
# Test cases for different URL patterns
|
||||
test_cases = [
|
||||
{
|
||||
"host_pattern": "api.example.com",
|
||||
"test_url": "https://api.example.com/v1/users",
|
||||
"should_match": True,
|
||||
},
|
||||
{
|
||||
"host_pattern": "*.example.com",
|
||||
"test_url": "https://api.example.com/v1/users",
|
||||
"should_match": True,
|
||||
},
|
||||
{
|
||||
"host_pattern": "*.example.com",
|
||||
"test_url": "https://subdomain.example.com/data",
|
||||
"should_match": True,
|
||||
},
|
||||
{
|
||||
"host_pattern": "api.example.com",
|
||||
"test_url": "https://api.different.com/data",
|
||||
"should_match": False,
|
||||
},
|
||||
]
|
||||
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
for case in test_cases:
|
||||
# Reset mock for each test case
|
||||
mock_requests.reset_mock()
|
||||
|
||||
# Create credentials for this test case
|
||||
test_creds = HostScopedCredentials(
|
||||
provider="http",
|
||||
host=case["host_pattern"],
|
||||
headers={
|
||||
"Authorization": SecretStr(f"Bearer {case['host_pattern']}-token"),
|
||||
},
|
||||
title=f"Credentials for {case['host_pattern']}",
|
||||
)
|
||||
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url=case["test_url"],
|
||||
method=HttpMethod.GET,
|
||||
headers={"User-Agent": "test-agent"},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": test_creds.id,
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": test_creds.title,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with test credentials
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=test_creds,
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify headers based on whether pattern should match
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
headers = call_args.kwargs["headers"]
|
||||
|
||||
if case["should_match"]:
|
||||
# Should include both user and credential headers
|
||||
expected_auth = f"Bearer {case['host_pattern']}-token"
|
||||
assert headers["Authorization"] == expected_auth
|
||||
assert headers["User-Agent"] == "test-agent"
|
||||
else:
|
||||
# Should only include user headers
|
||||
assert "Authorization" not in headers
|
||||
assert headers["User-Agent"] == "test-agent"
|
||||
@@ -25,12 +25,7 @@ async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Grap
|
||||
async def create_credentials(s: SpinTestServer, u: User):
|
||||
provider = ProviderName.OPENAI
|
||||
credentials = llm.TEST_CREDENTIALS
|
||||
try:
|
||||
await s.agent_server.test_create_credentials(u.id, provider, credentials)
|
||||
except Exception:
|
||||
# ValueErrors is raised trying to recreate the same credentials
|
||||
# so hidding the error
|
||||
pass
|
||||
return await s.agent_server.test_create_credentials(u.id, provider, credentials)
|
||||
|
||||
|
||||
async def execute_graph(
|
||||
@@ -60,19 +55,18 @@ async def execute_graph(
|
||||
return graph_exec_id
|
||||
|
||||
|
||||
@pytest.mark.skip()
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
|
||||
test_user = await create_test_user()
|
||||
test_tool_graph = await create_graph(server, create_test_graph(), test_user)
|
||||
await create_credentials(server, test_user)
|
||||
creds = await create_credentials(server, test_user)
|
||||
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=SmartDecisionMakerBlock().id,
|
||||
input_default={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": llm.TEST_CREDENTIALS_INPUT,
|
||||
"credentials": creds,
|
||||
},
|
||||
),
|
||||
graph.Node(
|
||||
@@ -110,80 +104,18 @@ async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
|
||||
@pytest.mark.skip()
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_graph_validation_with_tool_nodes_raises_error(server: SpinTestServer):
|
||||
|
||||
test_user = await create_test_user()
|
||||
test_tool_graph = await create_graph(server, create_test_graph(), test_user)
|
||||
await create_credentials(server, test_user)
|
||||
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=SmartDecisionMakerBlock().id,
|
||||
input_default={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": llm.TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
),
|
||||
graph.Node(
|
||||
block_id=AgentExecutorBlock().id,
|
||||
input_default={
|
||||
"graph_id": test_tool_graph.id,
|
||||
"graph_version": test_tool_graph.version,
|
||||
"input_schema": test_tool_graph.input_schema,
|
||||
"output_schema": test_tool_graph.output_schema,
|
||||
},
|
||||
),
|
||||
graph.Node(
|
||||
block_id=StoreValueBlock().id,
|
||||
),
|
||||
]
|
||||
|
||||
links = [
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[1].id,
|
||||
source_name="tools_^_sample_tool_input_1",
|
||||
sink_name="input_1",
|
||||
),
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[1].id,
|
||||
source_name="tools_^_sample_tool_input_2",
|
||||
sink_name="input_2",
|
||||
),
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[2].id,
|
||||
source_name="tools_^_store_value_input",
|
||||
sink_name="input",
|
||||
),
|
||||
]
|
||||
|
||||
test_graph = graph.Graph(
|
||||
name="TestGraph",
|
||||
description="Test graph",
|
||||
nodes=nodes,
|
||||
links=links,
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
|
||||
@pytest.mark.skip()
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
test_user = await create_test_user()
|
||||
test_tool_graph = await create_graph(server, create_test_graph(), test_user)
|
||||
await create_credentials(server, test_user)
|
||||
creds = await create_credentials(server, test_user)
|
||||
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=SmartDecisionMakerBlock().id,
|
||||
input_default={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": llm.TEST_CREDENTIALS_INPUT,
|
||||
"credentials": creds,
|
||||
},
|
||||
),
|
||||
graph.Node(
|
||||
@@ -15,7 +15,7 @@ from backend.blocks.zerobounce._auth import (
|
||||
ZeroBounceCredentialsInput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class Response(BaseModel):
|
||||
@@ -90,7 +90,7 @@ class ValidateEmailsBlock(Block):
|
||||
description="IP address to validate",
|
||||
default="",
|
||||
)
|
||||
credentials: ZeroBounceCredentialsInput = SchemaField(
|
||||
credentials: ZeroBounceCredentialsInput = CredentialsField(
|
||||
description="ZeroBounce credentials",
|
||||
)
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ from dotenv import load_dotenv
|
||||
|
||||
from backend.util.logging import configure_logging
|
||||
|
||||
os.environ["ENABLE_AUTH"] = "false"
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# NOTE: You can run tests like with the --log-cli-level=INFO to see the logs
|
||||
@@ -7,7 +7,7 @@ from pydantic import BaseModel
|
||||
from redis.asyncio.client import PubSub as AsyncPubSub
|
||||
from redis.client import PubSub
|
||||
|
||||
from backend.data import redis
|
||||
from backend.data import redis_client as redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -48,6 +48,7 @@ from .block import (
|
||||
get_webhook_block_ids,
|
||||
)
|
||||
from .db import BaseDbModel
|
||||
from .event_bus import AsyncRedisEventBus, RedisEventBus
|
||||
from .includes import (
|
||||
EXECUTION_RESULT_INCLUDE,
|
||||
EXECUTION_RESULT_ORDER,
|
||||
@@ -55,7 +56,6 @@ from .includes import (
|
||||
graph_execution_include,
|
||||
)
|
||||
from .model import GraphExecutionStats, NodeExecutionStats
|
||||
from .queue import AsyncRedisEventBus, RedisEventBus
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from backend.data.model import (
|
||||
CredentialsMetaInput,
|
||||
is_credentials_field_name,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import type as type_utils
|
||||
|
||||
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
|
||||
@@ -243,6 +244,8 @@ class Graph(BaseGraph):
|
||||
for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
|
||||
if field.provider != other_field.provider:
|
||||
continue
|
||||
if ProviderName.HTTP in field.provider:
|
||||
continue
|
||||
|
||||
# If this happens, that means a block implementation probably needs
|
||||
# to be updated.
|
||||
@@ -264,6 +267,7 @@ class Graph(BaseGraph):
|
||||
required_scopes=set(field_info.required_scopes or []),
|
||||
discriminator=field_info.discriminator,
|
||||
discriminator_mapping=field_info.discriminator_mapping,
|
||||
discriminator_values=field_info.discriminator_values,
|
||||
),
|
||||
)
|
||||
for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
|
||||
@@ -282,37 +286,40 @@ class Graph(BaseGraph):
|
||||
Returns:
|
||||
dict[aggregated_field_key, tuple(
|
||||
CredentialsFieldInfo: A spec for one aggregated credentials field
|
||||
(now includes discriminator_values from matching nodes)
|
||||
set[(node_id, field_name)]: Node credentials fields that are
|
||||
compatible with this aggregated field spec
|
||||
)]
|
||||
"""
|
||||
return {
|
||||
"_".join(sorted(agg_field_info.provider))
|
||||
+ "_"
|
||||
+ "_".join(sorted(agg_field_info.supported_types))
|
||||
+ "_credentials": (agg_field_info, node_fields)
|
||||
for agg_field_info, node_fields in CredentialsFieldInfo.combine(
|
||||
*(
|
||||
(
|
||||
# Apply discrimination before aggregating credentials inputs
|
||||
(
|
||||
field_info.discriminate(
|
||||
node.input_default[field_info.discriminator]
|
||||
)
|
||||
if (
|
||||
field_info.discriminator
|
||||
and node.input_default.get(field_info.discriminator)
|
||||
)
|
||||
else field_info
|
||||
),
|
||||
(node.id, field_name),
|
||||
# First collect all credential field data with input defaults
|
||||
node_credential_data = []
|
||||
|
||||
for graph in [self] + self.sub_graphs:
|
||||
for node in graph.nodes:
|
||||
for (
|
||||
field_name,
|
||||
field_info,
|
||||
) in node.block.input_schema.get_credentials_fields_info().items():
|
||||
|
||||
discriminator = field_info.discriminator
|
||||
if not discriminator:
|
||||
node_credential_data.append((field_info, (node.id, field_name)))
|
||||
continue
|
||||
|
||||
discriminator_value = node.input_default.get(discriminator)
|
||||
if discriminator_value is None:
|
||||
node_credential_data.append((field_info, (node.id, field_name)))
|
||||
continue
|
||||
|
||||
discriminated_info = field_info.discriminate(discriminator_value)
|
||||
discriminated_info.discriminator_values.add(discriminator_value)
|
||||
|
||||
node_credential_data.append(
|
||||
(discriminated_info, (node.id, field_name))
|
||||
)
|
||||
for graph in [self] + self.sub_graphs
|
||||
for node in graph.nodes
|
||||
for field_name, field_info in node.block.input_schema.get_credentials_fields_info().items()
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
# Combine credential field info (this will merge discriminator_values automatically)
|
||||
return CredentialsFieldInfo.combine(*node_credential_data)
|
||||
|
||||
|
||||
class GraphModel(Graph):
|
||||
@@ -391,7 +398,9 @@ class GraphModel(Graph):
|
||||
continue
|
||||
node.input_default["user_id"] = user_id
|
||||
node.input_default.setdefault("inputs", {})
|
||||
if (graph_id := node.input_default.get("graph_id")) in graph_id_map:
|
||||
if (
|
||||
graph_id := node.input_default.get("graph_id")
|
||||
) and graph_id in graph_id_map:
|
||||
node.input_default["graph_id"] = graph_id_map[graph_id]
|
||||
|
||||
def validate_graph(
|
||||
|
||||
@@ -11,8 +11,8 @@ from prisma.types import (
|
||||
)
|
||||
from pydantic import Field, computed_field
|
||||
|
||||
from backend.data.event_bus import AsyncRedisEventBus
|
||||
from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
|
||||
from backend.data.queue import AsyncRedisEventBus
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks.utils import webhook_ingress_url
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
|
||||
@@ -14,11 +14,12 @@ from typing import (
|
||||
Generic,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
cast,
|
||||
get_args,
|
||||
)
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
from prisma.enums import CreditTransactionType
|
||||
@@ -240,13 +241,65 @@ class UserPasswordCredentials(_BaseCredentials):
|
||||
return f"Basic {base64.b64encode(f'{self.username.get_secret_value()}:{self.password.get_secret_value()}'.encode()).decode()}"
|
||||
|
||||
|
||||
class HostScopedCredentials(_BaseCredentials):
|
||||
type: Literal["host_scoped"] = "host_scoped"
|
||||
host: str = Field(description="The host/URI pattern to match against request URLs")
|
||||
headers: dict[str, SecretStr] = Field(
|
||||
description="Key-value header map to add to matching requests",
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
def _extract_headers(self, headers: dict[str, SecretStr]) -> dict[str, str]:
|
||||
"""Helper to extract secret values from headers."""
|
||||
return {key: value.get_secret_value() for key, value in headers.items()}
|
||||
|
||||
@field_serializer("headers")
|
||||
def serialize_headers(self, headers: dict[str, SecretStr]) -> dict[str, str]:
|
||||
"""Serialize headers by extracting secret values."""
|
||||
return self._extract_headers(headers)
|
||||
|
||||
def get_headers_dict(self) -> dict[str, str]:
|
||||
"""Get headers with secret values extracted."""
|
||||
return self._extract_headers(self.headers)
|
||||
|
||||
def auth_header(self) -> str:
|
||||
"""Get authorization header for backward compatibility."""
|
||||
auth_headers = self.get_headers_dict()
|
||||
if "Authorization" in auth_headers:
|
||||
return auth_headers["Authorization"]
|
||||
return ""
|
||||
|
||||
def matches_url(self, url: str) -> bool:
|
||||
"""Check if this credential should be applied to the given URL."""
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
# Extract hostname without port
|
||||
request_host = parsed_url.hostname
|
||||
if not request_host:
|
||||
return False
|
||||
|
||||
# Simple host matching - exact match or wildcard subdomain match
|
||||
if self.host == request_host:
|
||||
return True
|
||||
|
||||
# Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
|
||||
if self.host.startswith("*."):
|
||||
domain = self.host[2:] # Remove "*."
|
||||
return request_host.endswith(f".{domain}") or request_host == domain
|
||||
|
||||
return False
|
||||
|
||||
|
||||
Credentials = Annotated[
|
||||
OAuth2Credentials | APIKeyCredentials | UserPasswordCredentials,
|
||||
OAuth2Credentials
|
||||
| APIKeyCredentials
|
||||
| UserPasswordCredentials
|
||||
| HostScopedCredentials,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
CredentialsType = Literal["api_key", "oauth2", "user_password"]
|
||||
CredentialsType = Literal["api_key", "oauth2", "user_password", "host_scoped"]
|
||||
|
||||
|
||||
class OAuthState(BaseModel):
|
||||
@@ -320,15 +373,29 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _add_json_schema_extra(schema, cls: CredentialsMetaInput):
|
||||
schema["credentials_provider"] = cls.allowed_providers()
|
||||
schema["credentials_types"] = cls.allowed_cred_types()
|
||||
def _add_json_schema_extra(schema: dict, model_class: type):
|
||||
# Use model_class for allowed_providers/cred_types
|
||||
if hasattr(model_class, "allowed_providers") and hasattr(
|
||||
model_class, "allowed_cred_types"
|
||||
):
|
||||
schema["credentials_provider"] = model_class.allowed_providers()
|
||||
schema["credentials_types"] = model_class.allowed_cred_types()
|
||||
# Do not return anything, just mutate schema in place
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra=_add_json_schema_extra, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
def _extract_host_from_url(url: str) -> str:
|
||||
"""Extract host from URL for grouping host-scoped credentials."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
return parsed.hostname or url
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
# TODO: move discrimination mechanism out of CredentialsField (frontend + backend)
|
||||
provider: frozenset[CP] = Field(..., alias="credentials_provider")
|
||||
@@ -336,11 +403,12 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
required_scopes: Optional[frozenset[str]] = Field(None, alias="credentials_scopes")
|
||||
discriminator: Optional[str] = None
|
||||
discriminator_mapping: Optional[dict[str, CP]] = None
|
||||
discriminator_values: set[Any] = Field(default_factory=set)
|
||||
|
||||
@classmethod
|
||||
def combine(
|
||||
cls, *fields: tuple[CredentialsFieldInfo[CP, CT], T]
|
||||
) -> Sequence[tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
|
||||
) -> dict[str, tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
|
||||
"""
|
||||
Combines multiple CredentialsFieldInfo objects into as few as possible.
|
||||
|
||||
@@ -358,22 +426,36 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
the set of keys of the respective original items that were grouped together.
|
||||
"""
|
||||
if not fields:
|
||||
return []
|
||||
return {}
|
||||
|
||||
# Group fields by their provider and supported_types
|
||||
# For HTTP host-scoped credentials, also group by host
|
||||
grouped_fields: defaultdict[
|
||||
tuple[frozenset[CP], frozenset[CT]],
|
||||
list[tuple[T, CredentialsFieldInfo[CP, CT]]],
|
||||
] = defaultdict(list)
|
||||
|
||||
for field, key in fields:
|
||||
group_key = (frozenset(field.provider), frozenset(field.supported_types))
|
||||
if field.provider == frozenset([ProviderName.HTTP]):
|
||||
# HTTP host-scoped credentials can have different hosts that reqires different credential sets.
|
||||
# Group by host extracted from the URL
|
||||
providers = frozenset(
|
||||
[cast(CP, "http")]
|
||||
+ [
|
||||
cast(CP, _extract_host_from_url(str(value)))
|
||||
for value in field.discriminator_values
|
||||
]
|
||||
)
|
||||
else:
|
||||
providers = frozenset(field.provider)
|
||||
|
||||
group_key = (providers, frozenset(field.supported_types))
|
||||
grouped_fields[group_key].append((key, field))
|
||||
|
||||
# Combine fields within each group
|
||||
result: list[tuple[CredentialsFieldInfo[CP, CT], set[T]]] = []
|
||||
result: dict[str, tuple[CredentialsFieldInfo[CP, CT], set[T]]] = {}
|
||||
|
||||
for group in grouped_fields.values():
|
||||
for key, group in grouped_fields.items():
|
||||
# Start with the first field in the group
|
||||
_, combined = group[0]
|
||||
|
||||
@@ -386,18 +468,32 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
if field.required_scopes:
|
||||
all_scopes.update(field.required_scopes)
|
||||
|
||||
# Create a new combined field
|
||||
result.append(
|
||||
(
|
||||
CredentialsFieldInfo[CP, CT](
|
||||
credentials_provider=combined.provider,
|
||||
credentials_types=combined.supported_types,
|
||||
credentials_scopes=frozenset(all_scopes) or None,
|
||||
discriminator=combined.discriminator,
|
||||
discriminator_mapping=combined.discriminator_mapping,
|
||||
),
|
||||
combined_keys,
|
||||
)
|
||||
# Combine discriminator_values from all fields in the group (removing duplicates)
|
||||
all_discriminator_values = []
|
||||
for _, field in group:
|
||||
for value in field.discriminator_values:
|
||||
if value not in all_discriminator_values:
|
||||
all_discriminator_values.append(value)
|
||||
|
||||
# Generate the key for the combined result
|
||||
providers_key, supported_types_key = key
|
||||
group_key = (
|
||||
"-".join(sorted(providers_key))
|
||||
+ "_"
|
||||
+ "-".join(sorted(supported_types_key))
|
||||
+ "_credentials"
|
||||
)
|
||||
|
||||
result[group_key] = (
|
||||
CredentialsFieldInfo[CP, CT](
|
||||
credentials_provider=combined.provider,
|
||||
credentials_types=combined.supported_types,
|
||||
credentials_scopes=frozenset(all_scopes) or None,
|
||||
discriminator=combined.discriminator,
|
||||
discriminator_mapping=combined.discriminator_mapping,
|
||||
discriminator_values=set(all_discriminator_values),
|
||||
),
|
||||
combined_keys,
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -406,11 +502,15 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
if not (self.discriminator and self.discriminator_mapping):
|
||||
return self
|
||||
|
||||
discriminator_value = self.discriminator_mapping[discriminator_value]
|
||||
return CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([discriminator_value]),
|
||||
credentials_provider=frozenset(
|
||||
[self.discriminator_mapping[discriminator_value]]
|
||||
),
|
||||
credentials_types=self.supported_types,
|
||||
credentials_scopes=self.required_scopes,
|
||||
discriminator=self.discriminator,
|
||||
discriminator_mapping=self.discriminator_mapping,
|
||||
discriminator_values=self.discriminator_values,
|
||||
)
|
||||
|
||||
|
||||
@@ -419,6 +519,7 @@ def CredentialsField(
|
||||
*,
|
||||
discriminator: Optional[str] = None,
|
||||
discriminator_mapping: Optional[dict[str, Any]] = None,
|
||||
discriminator_values: Optional[set[Any]] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
**kwargs,
|
||||
@@ -434,6 +535,7 @@ def CredentialsField(
|
||||
"credentials_scopes": list(required_scopes) or None,
|
||||
"discriminator": discriminator,
|
||||
"discriminator_mapping": discriminator_mapping,
|
||||
"discriminator_values": discriminator_values,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
143
autogpt_platform/backend/backend/data/model_test.py
Normal file
143
autogpt_platform/backend/backend/data/model_test.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import HostScopedCredentials
|
||||
|
||||
|
||||
class TestHostScopedCredentials:
|
||||
def test_host_scoped_credentials_creation(self):
|
||||
"""Test creating HostScopedCredentials with required fields."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="api.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer secret-token"),
|
||||
"X-API-Key": SecretStr("api-key-123"),
|
||||
},
|
||||
title="Example API Credentials",
|
||||
)
|
||||
|
||||
assert creds.type == "host_scoped"
|
||||
assert creds.provider == "custom"
|
||||
assert creds.host == "api.example.com"
|
||||
assert creds.title == "Example API Credentials"
|
||||
assert len(creds.headers) == 2
|
||||
assert "Authorization" in creds.headers
|
||||
assert "X-API-Key" in creds.headers
|
||||
|
||||
def test_get_headers_dict(self):
|
||||
"""Test getting headers with secret values extracted."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="api.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer secret-token"),
|
||||
"X-Custom-Header": SecretStr("custom-value"),
|
||||
},
|
||||
)
|
||||
|
||||
headers_dict = creds.get_headers_dict()
|
||||
|
||||
assert headers_dict == {
|
||||
"Authorization": "Bearer secret-token",
|
||||
"X-Custom-Header": "custom-value",
|
||||
}
|
||||
|
||||
def test_matches_url_exact_host(self):
|
||||
"""Test URL matching with exact host match."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="api.example.com",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
)
|
||||
|
||||
assert creds.matches_url("https://api.example.com/v1/data")
|
||||
assert creds.matches_url("http://api.example.com/endpoint")
|
||||
assert not creds.matches_url("https://other.example.com/v1/data")
|
||||
assert not creds.matches_url("https://subdomain.api.example.com/v1/data")
|
||||
|
||||
def test_matches_url_wildcard_subdomain(self):
|
||||
"""Test URL matching with wildcard subdomain pattern."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="*.example.com",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
)
|
||||
|
||||
assert creds.matches_url("https://api.example.com/v1/data")
|
||||
assert creds.matches_url("https://subdomain.example.com/endpoint")
|
||||
assert creds.matches_url("https://deep.nested.example.com/path")
|
||||
assert creds.matches_url("https://example.com/path") # Base domain should match
|
||||
assert not creds.matches_url("https://example.org/v1/data")
|
||||
assert not creds.matches_url("https://notexample.com/v1/data")
|
||||
|
||||
def test_matches_url_with_port_and_path(self):
|
||||
"""Test URL matching with ports and paths."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="localhost",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
)
|
||||
|
||||
assert creds.matches_url("http://localhost:8080/api/v1")
|
||||
assert creds.matches_url("https://localhost:443/secure/endpoint")
|
||||
assert creds.matches_url("http://localhost/simple")
|
||||
|
||||
def test_empty_headers_dict(self):
|
||||
"""Test HostScopedCredentials with empty headers."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom", host="api.example.com", headers={}
|
||||
)
|
||||
|
||||
assert creds.get_headers_dict() == {}
|
||||
assert creds.matches_url("https://api.example.com/test")
|
||||
|
||||
def test_credential_serialization(self):
|
||||
"""Test that credentials can be serialized/deserialized properly."""
|
||||
original_creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="api.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer secret-token"),
|
||||
"X-API-Key": SecretStr("api-key-123"),
|
||||
},
|
||||
title="Test Credentials",
|
||||
)
|
||||
|
||||
# Serialize to dict (simulating storage)
|
||||
serialized = original_creds.model_dump()
|
||||
|
||||
# Deserialize back
|
||||
restored_creds = HostScopedCredentials.model_validate(serialized)
|
||||
|
||||
assert restored_creds.id == original_creds.id
|
||||
assert restored_creds.provider == original_creds.provider
|
||||
assert restored_creds.host == original_creds.host
|
||||
assert restored_creds.title == original_creds.title
|
||||
assert restored_creds.type == "host_scoped"
|
||||
|
||||
# Check that headers are properly restored
|
||||
assert restored_creds.get_headers_dict() == original_creds.get_headers_dict()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"host,test_url,expected",
|
||||
[
|
||||
("api.example.com", "https://api.example.com/test", True),
|
||||
("api.example.com", "https://different.example.com/test", False),
|
||||
("*.example.com", "https://api.example.com/test", True),
|
||||
("*.example.com", "https://sub.api.example.com/test", True),
|
||||
("*.example.com", "https://example.com/test", True),
|
||||
("*.example.com", "https://example.org/test", False),
|
||||
("localhost", "http://localhost:3000/test", True),
|
||||
("localhost", "http://127.0.0.1:3000/test", False),
|
||||
],
|
||||
)
|
||||
def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool):
|
||||
"""Parametrized test for various URL matching scenarios."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="test",
|
||||
host=host,
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
)
|
||||
|
||||
assert creds.matches_url(test_url) == expected
|
||||
@@ -35,7 +35,7 @@ from autogpt_libs.utils.cache import thread_cached
|
||||
from prometheus_client import Gauge, start_http_server
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.data import redis
|
||||
from backend.data import redis_client as redis
|
||||
from backend.data.block import (
|
||||
BlockData,
|
||||
BlockInput,
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.redis import get_redis_async
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor.database import DatabaseManagerAsyncClient
|
||||
|
||||
@@ -7,7 +7,7 @@ from autogpt_libs.utils.synchronize import AsyncRedisKeyedMutex
|
||||
from redis.asyncio.lock import Lock as AsyncRedisLock
|
||||
|
||||
from backend.data.model import Credentials, OAuth2Credentials
|
||||
from backend.data.redis import get_redis_async
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
from backend.integrations.oauth import HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
@@ -17,6 +17,7 @@ class ProviderName(str, Enum):
|
||||
GOOGLE = "google"
|
||||
GOOGLE_MAPS = "google_maps"
|
||||
GROQ = "groq"
|
||||
HTTP = "http"
|
||||
HUBSPOT = "hubspot"
|
||||
IDEOGRAM = "ideogram"
|
||||
JINA = "jina"
|
||||
|
||||
@@ -22,7 +22,12 @@ from backend.data.integrations import (
|
||||
publish_webhook_event,
|
||||
wait_for_webhook_event,
|
||||
)
|
||||
from backend.data.model import Credentials, CredentialsType, OAuth2Credentials
|
||||
from backend.data.model import (
|
||||
Credentials,
|
||||
CredentialsType,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
)
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import HANDLERS_BY_NAME
|
||||
@@ -82,6 +87,9 @@ class CredentialsMetaResponse(BaseModel):
|
||||
title: str | None
|
||||
scopes: list[str] | None
|
||||
username: str | None
|
||||
host: str | None = Field(
|
||||
default=None, description="Host pattern for host-scoped credentials"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{provider}/callback")
|
||||
@@ -156,6 +164,9 @@ async def callback(
|
||||
title=credentials.title,
|
||||
scopes=credentials.scopes,
|
||||
username=credentials.username,
|
||||
host=(
|
||||
credentials.host if isinstance(credentials, HostScopedCredentials) else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -172,6 +183,7 @@ async def list_credentials(
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
)
|
||||
for cred in credentials
|
||||
]
|
||||
@@ -193,6 +205,7 @@ async def list_credentials_by_provider(
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
)
|
||||
for cred in credentials
|
||||
]
|
||||
|
||||
@@ -357,11 +357,22 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
provider: ProviderName,
|
||||
credentials: Credentials,
|
||||
) -> Credentials:
|
||||
from backend.server.integrations.router import create_credentials
|
||||
|
||||
return await create_credentials(
|
||||
user_id=user_id, provider=provider, credentials=credentials
|
||||
from backend.server.integrations.router import (
|
||||
create_credentials,
|
||||
get_credential,
|
||||
)
|
||||
|
||||
try:
|
||||
return await create_credentials(
|
||||
user_id=user_id, provider=provider, credentials=credentials
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating credentials: {e}")
|
||||
return await get_credential(
|
||||
provider=provider,
|
||||
user_id=user_id,
|
||||
cred_id=credentials.id,
|
||||
)
|
||||
|
||||
def set_test_dependency_overrides(self, overrides: dict):
|
||||
app.dependency_overrides.update(overrides)
|
||||
|
||||
@@ -4,6 +4,7 @@ import logging
|
||||
from typing import Annotated
|
||||
|
||||
import fastapi
|
||||
import pydantic
|
||||
|
||||
import backend.data.analytics
|
||||
from backend.server.utils import get_user_id
|
||||
@@ -12,24 +13,28 @@ router = fastapi.APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LogRawMetricRequest(pydantic.BaseModel):
|
||||
metric_name: str = pydantic.Field(..., min_length=1)
|
||||
metric_value: float = pydantic.Field(..., allow_inf_nan=False)
|
||||
data_string: str = pydantic.Field(..., min_length=1)
|
||||
|
||||
|
||||
@router.post(path="/log_raw_metric")
|
||||
async def log_raw_metric(
|
||||
user_id: Annotated[str, fastapi.Depends(get_user_id)],
|
||||
metric_name: Annotated[str, fastapi.Body(..., embed=True)],
|
||||
metric_value: Annotated[float, fastapi.Body(..., embed=True)],
|
||||
data_string: Annotated[str, fastapi.Body(..., embed=True)],
|
||||
request: LogRawMetricRequest,
|
||||
):
|
||||
try:
|
||||
result = await backend.data.analytics.log_raw_metric(
|
||||
user_id=user_id,
|
||||
metric_name=metric_name,
|
||||
metric_value=metric_value,
|
||||
data_string=data_string,
|
||||
metric_name=request.metric_name,
|
||||
metric_value=request.metric_value,
|
||||
data_string=request.data_string,
|
||||
)
|
||||
return result.id
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to log metric %s for user %s: %s", metric_name, user_id, e
|
||||
"Failed to log metric %s for user %s: %s", request.metric_name, user_id, e
|
||||
)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
|
||||
@@ -97,8 +97,17 @@ def test_log_raw_metric_invalid_request_improved() -> None:
|
||||
assert "data_string" in error_fields, "Should report missing data_string"
|
||||
|
||||
|
||||
def test_log_raw_metric_type_validation_improved() -> None:
|
||||
def test_log_raw_metric_type_validation_improved(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Test metric type validation with improved assertions."""
|
||||
# Mock the analytics function to avoid event loop issues
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=Mock(id="test-id"),
|
||||
)
|
||||
|
||||
invalid_requests = [
|
||||
{
|
||||
"data": {
|
||||
@@ -119,10 +128,10 @@ def test_log_raw_metric_type_validation_improved() -> None:
|
||||
{
|
||||
"data": {
|
||||
"metric_name": "test",
|
||||
"metric_value": float("inf"), # Infinity
|
||||
"data_string": "test",
|
||||
"metric_value": 123, # Valid number
|
||||
"data_string": "", # Empty data_string
|
||||
},
|
||||
"expected_error": "ensure this value is finite",
|
||||
"expected_error": "String should have at least 1 character",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@@ -93,10 +93,18 @@ def test_log_raw_metric_values_parametrized(
|
||||
],
|
||||
)
|
||||
def test_log_raw_metric_invalid_requests_parametrized(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
invalid_data: dict,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test invalid metric requests with parametrize."""
|
||||
# Mock the analytics function to avoid event loop issues
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=Mock(id="test-id"),
|
||||
)
|
||||
|
||||
response = client.post("/log_raw_metric", json=invalid_data)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
@@ -108,11 +108,16 @@ class TestDatabaseIsolation:
|
||||
where={"email": {"contains": "@test.example"}}
|
||||
)
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def test_create_user(self, test_db_connection):
|
||||
"""Test that demonstrates proper isolation."""
|
||||
# This test has access to a clean database
|
||||
user = await test_db_connection.user.create(
|
||||
data={"email": "test@test.example", "name": "Test User"}
|
||||
data={
|
||||
"id": "test-user-id",
|
||||
"email": "test@test.example",
|
||||
"name": "Test User",
|
||||
}
|
||||
)
|
||||
assert user.email == "test@test.example"
|
||||
# User will be cleaned up automatically
|
||||
|
||||
@@ -143,7 +143,7 @@ async def test_add_agent_to_library(mocker):
|
||||
)
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.create = mocker.AsyncMock(
|
||||
return_value=mock_library_agent_data
|
||||
)
|
||||
@@ -159,21 +159,24 @@ async def test_add_agent_to_library(mocker):
|
||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||
where={"id": "version123"}, include={"AgentGraph": True}
|
||||
)
|
||||
mock_library_agent.return_value.find_first.assert_called_once_with(
|
||||
mock_library_agent.return_value.find_unique.assert_called_once_with(
|
||||
where={
|
||||
"userId": "test-user",
|
||||
"agentGraphId": "agent1",
|
||||
"agentGraphVersion": 1,
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": "test-user",
|
||||
"agentGraphId": "agent1",
|
||||
"agentGraphVersion": 1,
|
||||
}
|
||||
},
|
||||
include=library_agent_include("test-user"),
|
||||
include={"AgentGraph": True},
|
||||
)
|
||||
mock_library_agent.return_value.create.assert_called_once_with(
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
userId="test-user",
|
||||
agentGraphId="agent1",
|
||||
agentGraphVersion=1,
|
||||
isCreatedByUser=False,
|
||||
),
|
||||
data={
|
||||
"User": {"connect": {"id": "test-user"}},
|
||||
"AgentGraph": {
|
||||
"connect": {"graphVersionId": {"id": "agent1", "version": 1}}
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
},
|
||||
include=library_agent_include("test-user"),
|
||||
)
|
||||
|
||||
|
||||
@@ -121,26 +121,57 @@ def test_get_library_agents_error(mocker: pytest_mock.MockFixture):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Mocker Not implemented")
|
||||
def test_add_agent_to_library_success(mocker: pytest_mock.MockFixture):
|
||||
mock_db_call = mocker.patch("backend.server.v2.library.db.add_agent_to_library")
|
||||
mock_db_call.return_value = None
|
||||
mock_library_agent = library_model.LibraryAgent(
|
||||
id="test-library-agent-id",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
creator_name="Test Creator",
|
||||
creator_image_url="",
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
credentials_input_schema={"type": "object", "properties": {}},
|
||||
has_external_trigger=False,
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
new_output=False,
|
||||
can_access_graph=True,
|
||||
is_latest_version=True,
|
||||
updated_at=FIXED_NOW,
|
||||
)
|
||||
|
||||
response = client.post("/agents/test-version-id")
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.add_store_agent_to_library"
|
||||
)
|
||||
mock_db_call.return_value = mock_library_agent
|
||||
|
||||
response = client.post(
|
||||
"/agents", json={"store_listing_version_id": "test-version-id"}
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
# Verify the response contains the library agent data
|
||||
data = library_model.LibraryAgent.model_validate(response.json())
|
||||
assert data.id == "test-library-agent-id"
|
||||
assert data.graph_id == "test-agent-1"
|
||||
|
||||
mock_db_call.assert_called_once_with(
|
||||
store_listing_version_id="test-version-id", user_id="test-user-id"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Mocker Not implemented")
|
||||
def test_add_agent_to_library_error(mocker: pytest_mock.MockFixture):
|
||||
mock_db_call = mocker.patch("backend.server.v2.library.db.add_agent_to_library")
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.add_store_agent_to_library"
|
||||
)
|
||||
mock_db_call.side_effect = Exception("Test error")
|
||||
|
||||
response = client.post("/agents/test-version-id")
|
||||
response = client.post(
|
||||
"/agents", json={"store_listing_version_id": "test-version-id"}
|
||||
)
|
||||
assert response.status_code == 500
|
||||
assert response.json()["detail"] == "Failed to add agent to library"
|
||||
assert "detail" in response.json() # Verify error response structure
|
||||
mock_db_call.assert_called_once_with(
|
||||
store_listing_version_id="test-version-id", user_id="test-user-id"
|
||||
)
|
||||
|
||||
@@ -259,8 +259,8 @@ def test_ask_otto_unauthenticated(mocker: pytest_mock.MockFixture) -> None:
|
||||
}
|
||||
|
||||
response = client.post("/ask", json=request_data)
|
||||
# When auth is disabled and Otto API URL is not configured, we get 503
|
||||
assert response.status_code == 503
|
||||
# When auth is disabled and Otto API URL is not configured, we get 502 (wrapped from 503)
|
||||
assert response.status_code == 502
|
||||
|
||||
# Restore the override
|
||||
app.dependency_overrides[autogpt_libs.auth.middleware.auth_middleware] = (
|
||||
|
||||
@@ -93,6 +93,14 @@ async def test_get_store_agent_details(mocker):
|
||||
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
|
||||
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
|
||||
|
||||
# Mock Profile prisma call
|
||||
mock_profile = mocker.MagicMock()
|
||||
mock_profile.userId = "user-id-123"
|
||||
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
|
||||
mock_profile_db.return_value.find_first = mocker.AsyncMock(
|
||||
return_value=mock_profile
|
||||
)
|
||||
|
||||
# Mock StoreListing prisma call - this is what was missing
|
||||
mock_store_listing_db = mocker.patch("prisma.models.StoreListing.prisma")
|
||||
mock_store_listing_db.return_value.find_first = mocker.AsyncMock(
|
||||
|
||||
@@ -19,7 +19,9 @@ from backend.server.ws_api import (
|
||||
|
||||
@pytest.fixture
|
||||
def mock_websocket() -> AsyncMock:
|
||||
return AsyncMock(spec=WebSocket)
|
||||
mock = AsyncMock(spec=WebSocket)
|
||||
mock.query_params = {} # Add query_params attribute for authentication
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -29,8 +31,13 @@ def mock_manager() -> AsyncMock:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_router_subscribe(
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot, mocker
|
||||
) -> None:
|
||||
# Mock the authenticate_websocket function to ensure it returns a valid user_id
|
||||
mocker.patch(
|
||||
"backend.server.ws_api.authenticate_websocket", return_value=DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
mock_websocket.receive_text.side_effect = [
|
||||
WSMessage(
|
||||
method=WSMethod.SUBSCRIBE_GRAPH_EXEC,
|
||||
@@ -70,8 +77,13 @@ async def test_websocket_router_subscribe(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_router_unsubscribe(
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot, mocker
|
||||
) -> None:
|
||||
# Mock the authenticate_websocket function to ensure it returns a valid user_id
|
||||
mocker.patch(
|
||||
"backend.server.ws_api.authenticate_websocket", return_value=DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
mock_websocket.receive_text.side_effect = [
|
||||
WSMessage(
|
||||
method=WSMethod.UNSUBSCRIBE,
|
||||
@@ -108,8 +120,13 @@ async def test_websocket_router_unsubscribe(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_router_invalid_method(
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock, mocker
|
||||
) -> None:
|
||||
# Mock the authenticate_websocket function to ensure it returns a valid user_id
|
||||
mocker.patch(
|
||||
"backend.server.ws_api.authenticate_websocket", return_value=DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
mock_websocket.receive_text.side_effect = [
|
||||
WSMessage(method=WSMethod.GRAPH_EXECUTION_EVENT).model_dump_json(),
|
||||
WebSocketDisconnect(),
|
||||
@@ -113,6 +113,11 @@ ignore_patterns = []
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "session"
|
||||
filterwarnings = [
|
||||
"ignore:'audioop' is deprecated:DeprecationWarning:discord.player",
|
||||
"ignore:invalid escape sequence:DeprecationWarning:tweepy.api",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py310"
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"metric_id": "metric-123-uuid"
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"metric_id": "metric-float_precision-uuid",
|
||||
"test_case": "float_precision"
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"metric_id": "metric-integer_value-uuid",
|
||||
"test_case": "integer_value"
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"metric_id": "metric-large_number-uuid",
|
||||
"test_case": "large_number"
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"metric_id": "metric-negative_value-uuid",
|
||||
"test_case": "negative_value"
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"metric_id": "metric-tiny_number-uuid",
|
||||
"test_case": "tiny_number"
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"metric_id": "metric-zero_value-uuid",
|
||||
"test_case": "zero_value"
|
||||
}
|
||||
@@ -15,6 +15,12 @@
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
},
|
||||
"credentials_input_schema": {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
},
|
||||
"has_external_trigger": false,
|
||||
"trigger_setup_info": null,
|
||||
"new_output": false,
|
||||
"can_access_graph": true,
|
||||
"is_latest_version": true
|
||||
@@ -34,6 +40,12 @@
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
},
|
||||
"credentials_input_schema": {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
},
|
||||
"has_external_trigger": false,
|
||||
"trigger_setup_info": null,
|
||||
"new_output": false,
|
||||
"can_access_graph": false,
|
||||
"is_latest_version": true
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
import os
|
||||
|
||||
os.environ["ENABLE_AUTH"] = "false"
|
||||
@@ -5,6 +5,7 @@ import { useCallback, useContext, useEffect, useMemo, useState } from "react";
|
||||
import { useToast } from "@/components/ui/use-toast";
|
||||
import { IconKey, IconUser } from "@/components/ui/icons";
|
||||
import { Trash2Icon } from "lucide-react";
|
||||
import { KeyIcon } from "@phosphor-icons/react/dist/ssr";
|
||||
import { providerIcons } from "@/components/integrations/credentials-input";
|
||||
import { CredentialsProvidersContext } from "@/components/integrations/credentials-provider";
|
||||
import {
|
||||
@@ -140,11 +141,12 @@ export default function UserIntegrationsPage() {
|
||||
...credentials,
|
||||
provider: provider.provider,
|
||||
providerName: provider.providerName,
|
||||
ProviderIcon: providerIcons[provider.provider],
|
||||
ProviderIcon: providerIcons[provider.provider] || KeyIcon,
|
||||
TypeIcon: {
|
||||
oauth2: IconUser,
|
||||
api_key: IconKey,
|
||||
user_password: IconKey,
|
||||
host_scoped: IconKey,
|
||||
}[credentials.type],
|
||||
})),
|
||||
)
|
||||
@@ -181,6 +183,7 @@ export default function UserIntegrationsPage() {
|
||||
oauth2: "OAuth2 credentials",
|
||||
api_key: "API key",
|
||||
user_password: "Username & password",
|
||||
host_scoped: "Host-scoped credentials",
|
||||
}[cred.type]
|
||||
}{" "}
|
||||
- <code>{cred.id}</code>
|
||||
|
||||
@@ -0,0 +1,163 @@
|
||||
import { FC } from "react";
|
||||
import { z } from "zod";
|
||||
import { useForm } from "react-hook-form";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
import {
|
||||
Form,
|
||||
FormControl,
|
||||
FormDescription,
|
||||
FormField,
|
||||
FormItem,
|
||||
FormLabel,
|
||||
FormMessage,
|
||||
} from "@/components/ui/form";
|
||||
import useCredentials from "@/hooks/useCredentials";
|
||||
import {
|
||||
BlockIOCredentialsSubSchema,
|
||||
CredentialsMetaInput,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
|
||||
export const APIKeyCredentialsModal: FC<{
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
|
||||
siblingInputs?: Record<string, any>;
|
||||
}> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
|
||||
const credentials = useCredentials(schema, siblingInputs);
|
||||
|
||||
const formSchema = z.object({
|
||||
apiKey: z.string().min(1, "API Key is required"),
|
||||
title: z.string().min(1, "Name is required"),
|
||||
expiresAt: z.string().optional(),
|
||||
});
|
||||
|
||||
const form = useForm<z.infer<typeof formSchema>>({
|
||||
resolver: zodResolver(formSchema),
|
||||
defaultValues: {
|
||||
apiKey: "",
|
||||
title: "",
|
||||
expiresAt: "",
|
||||
},
|
||||
});
|
||||
|
||||
if (!credentials || credentials.isLoading || !credentials.supportsApiKey) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const { provider, providerName, createAPIKeyCredentials } = credentials;
|
||||
|
||||
async function onSubmit(values: z.infer<typeof formSchema>) {
|
||||
const expiresAt = values.expiresAt
|
||||
? new Date(values.expiresAt).getTime() / 1000
|
||||
: undefined;
|
||||
const newCredentials = await createAPIKeyCredentials({
|
||||
api_key: values.apiKey,
|
||||
title: values.title,
|
||||
expires_at: expiresAt,
|
||||
});
|
||||
onCredentialsCreate({
|
||||
provider,
|
||||
id: newCredentials.id,
|
||||
type: "api_key",
|
||||
title: newCredentials.title,
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) onClose();
|
||||
}}
|
||||
>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>Add new API key for {providerName}</DialogTitle>
|
||||
{schema.description && (
|
||||
<DialogDescription>{schema.description}</DialogDescription>
|
||||
)}
|
||||
</DialogHeader>
|
||||
|
||||
<Form {...form}>
|
||||
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="apiKey"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>API Key</FormLabel>
|
||||
{schema.credentials_scopes && (
|
||||
<FormDescription>
|
||||
Required scope(s) for this block:{" "}
|
||||
{schema.credentials_scopes?.map((s, i, a) => (
|
||||
<span key={i}>
|
||||
<code>{s}</code>
|
||||
{i < a.length - 1 && ", "}
|
||||
</span>
|
||||
))}
|
||||
</FormDescription>
|
||||
)}
|
||||
<FormControl>
|
||||
<Input
|
||||
type="password"
|
||||
placeholder="Enter API key..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="title"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Name</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="text"
|
||||
placeholder="Enter a name for this API key..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="expiresAt"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Expiration Date (Optional)</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="datetime-local"
|
||||
placeholder="Select expiration date..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<Button type="submit" className="w-full">
|
||||
Save & use this API key
|
||||
</Button>
|
||||
</form>
|
||||
</Form>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
@@ -1,12 +1,8 @@
|
||||
import { FC, useEffect, useMemo, useState } from "react";
|
||||
import { z } from "zod";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useForm } from "react-hook-form";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import SchemaTooltip from "@/components/SchemaTooltip";
|
||||
import useCredentials from "@/hooks/useCredentials";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { NotionLogoIcon } from "@radix-ui/react-icons";
|
||||
import {
|
||||
FaDiscord,
|
||||
@@ -23,22 +19,6 @@ import {
|
||||
CredentialsProviderName,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
import { IconKey, IconKeyPlus, IconUserPlus } from "@/components/ui/icons";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
import {
|
||||
Form,
|
||||
FormControl,
|
||||
FormDescription,
|
||||
FormField,
|
||||
FormItem,
|
||||
FormLabel,
|
||||
FormMessage,
|
||||
} from "@/components/ui/form";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
@@ -48,6 +28,11 @@ import {
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||
import { APIKeyCredentialsModal } from "./api-key-credentials-modal";
|
||||
import { UserPasswordCredentialsModal } from "./user-password-credentials-modal";
|
||||
import { HostScopedCredentialsModal } from "./host-scoped-credentials-modal";
|
||||
import { OAuth2FlowWaitingModal } from "./oauth2-flow-waiting-modal";
|
||||
import { getHostFromUrl } from "@/lib/utils/url";
|
||||
|
||||
const fallbackIcon = FaKey;
|
||||
|
||||
@@ -63,6 +48,7 @@ export const providerIcons: Record<
|
||||
github: FaGithub,
|
||||
google: FaGoogle,
|
||||
groq: fallbackIcon,
|
||||
http: fallbackIcon,
|
||||
notion: NotionLogoIcon,
|
||||
nvidia: fallbackIcon,
|
||||
discord: FaDiscord,
|
||||
@@ -129,6 +115,8 @@ export const CredentialsInput: FC<{
|
||||
isUserPasswordCredentialsModalOpen,
|
||||
setUserPasswordCredentialsModalOpen,
|
||||
] = useState(false);
|
||||
const [isHostScopedCredentialsModalOpen, setHostScopedCredentialsModalOpen] =
|
||||
useState(false);
|
||||
const [isOAuth2FlowInProgress, setOAuth2FlowInProgress] = useState(false);
|
||||
const [oAuthPopupController, setOAuthPopupController] =
|
||||
useState<AbortController | null>(null);
|
||||
@@ -148,13 +136,27 @@ export const CredentialsInput: FC<{
|
||||
}
|
||||
}, [credentials, selectedCredentials, onSelectCredentials]);
|
||||
|
||||
const singleCredential = useMemo(() => {
|
||||
if (!credentials || !("savedCredentials" in credentials)) return null;
|
||||
const { hasRelevantCredentials, singleCredential } = useMemo(() => {
|
||||
if (!credentials || !("savedCredentials" in credentials)) {
|
||||
return {
|
||||
hasRelevantCredentials: false,
|
||||
singleCredential: null,
|
||||
};
|
||||
}
|
||||
|
||||
if (credentials.savedCredentials.length === 1)
|
||||
return credentials.savedCredentials[0];
|
||||
// Simple logic: if we have any saved credentials, we have relevant credentials
|
||||
const hasRelevant = credentials.savedCredentials.length > 0;
|
||||
|
||||
return null;
|
||||
// Auto-select single credential if only one exists
|
||||
const single =
|
||||
credentials.savedCredentials.length === 1
|
||||
? credentials.savedCredentials[0]
|
||||
: null;
|
||||
|
||||
return {
|
||||
hasRelevantCredentials: hasRelevant,
|
||||
singleCredential: single,
|
||||
};
|
||||
}, [credentials]);
|
||||
|
||||
// If only 1 credential is available, auto-select it and hide this input
|
||||
@@ -178,6 +180,7 @@ export const CredentialsInput: FC<{
|
||||
supportsApiKey,
|
||||
supportsOAuth2,
|
||||
supportsUserPassword,
|
||||
supportsHostScoped,
|
||||
savedCredentials,
|
||||
oAuthCallback,
|
||||
} = credentials;
|
||||
@@ -271,7 +274,7 @@ export const CredentialsInput: FC<{
|
||||
);
|
||||
}
|
||||
|
||||
const ProviderIcon = providerIcons[provider];
|
||||
const ProviderIcon = providerIcons[provider] || fallbackIcon;
|
||||
const modals = (
|
||||
<>
|
||||
{supportsApiKey && (
|
||||
@@ -305,6 +308,18 @@ export const CredentialsInput: FC<{
|
||||
siblingInputs={siblingInputs}
|
||||
/>
|
||||
)}
|
||||
{supportsHostScoped && (
|
||||
<HostScopedCredentialsModal
|
||||
schema={schema}
|
||||
open={isHostScopedCredentialsModalOpen}
|
||||
onClose={() => setHostScopedCredentialsModalOpen(false)}
|
||||
onCredentialsCreate={(creds) => {
|
||||
onSelectCredentials(creds);
|
||||
setHostScopedCredentialsModalOpen(false);
|
||||
}}
|
||||
siblingInputs={siblingInputs}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -317,8 +332,8 @@ export const CredentialsInput: FC<{
|
||||
</div>
|
||||
);
|
||||
|
||||
// No saved credentials yet
|
||||
if (savedCredentials.length === 0) {
|
||||
// Show credentials creation UI when no relevant credentials exist
|
||||
if (!hasRelevantCredentials) {
|
||||
return (
|
||||
<div>
|
||||
{fieldHeader}
|
||||
@@ -342,6 +357,12 @@ export const CredentialsInput: FC<{
|
||||
Enter username and password
|
||||
</Button>
|
||||
)}
|
||||
{supportsHostScoped && credentials.discriminatorValue && (
|
||||
<Button onClick={() => setHostScopedCredentialsModalOpen(true)}>
|
||||
<ProviderIcon className="mr-2 h-4 w-4" />
|
||||
{`Enter sensitive headers for ${getHostFromUrl(credentials.discriminatorValue)}`}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
{modals}
|
||||
{oAuthError && (
|
||||
@@ -358,6 +379,12 @@ export const CredentialsInput: FC<{
|
||||
} else if (newValue === "add-api-key") {
|
||||
// Open API key dialog
|
||||
setAPICredentialsModalOpen(true);
|
||||
} else if (newValue === "add-user-password") {
|
||||
// Open user password dialog
|
||||
setUserPasswordCredentialsModalOpen(true);
|
||||
} else if (newValue === "add-host-scoped") {
|
||||
// Open host-scoped credentials dialog
|
||||
setHostScopedCredentialsModalOpen(true);
|
||||
} else {
|
||||
const selectedCreds = savedCredentials.find((c) => c.id == newValue)!;
|
||||
|
||||
@@ -406,6 +433,15 @@ export const CredentialsInput: FC<{
|
||||
{credentials.title}
|
||||
</SelectItem>
|
||||
))}
|
||||
{savedCredentials
|
||||
.filter((c) => c.type == "host_scoped")
|
||||
.map((credentials, index) => (
|
||||
<SelectItem key={index} value={credentials.id}>
|
||||
<ProviderIcon className="mr-2 inline h-4 w-4" />
|
||||
<IconKey className="mr-1.5 inline" />
|
||||
{credentials.title}
|
||||
</SelectItem>
|
||||
))}
|
||||
<SelectSeparator />
|
||||
{supportsOAuth2 && (
|
||||
<SelectItem value="sign-in">
|
||||
@@ -425,6 +461,12 @@ export const CredentialsInput: FC<{
|
||||
Add new user password
|
||||
</SelectItem>
|
||||
)}
|
||||
{supportsHostScoped && (
|
||||
<SelectItem value="add-host-scoped">
|
||||
<IconKey className="mr-1.5 inline" />
|
||||
Add host-scoped headers
|
||||
</SelectItem>
|
||||
)}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
{modals}
|
||||
@@ -434,291 +476,3 @@ export const CredentialsInput: FC<{
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export const APIKeyCredentialsModal: FC<{
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
|
||||
siblingInputs?: Record<string, any>;
|
||||
}> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
|
||||
const credentials = useCredentials(schema, siblingInputs);
|
||||
|
||||
const formSchema = z.object({
|
||||
apiKey: z.string().min(1, "API Key is required"),
|
||||
title: z.string().min(1, "Name is required"),
|
||||
expiresAt: z.string().optional(),
|
||||
});
|
||||
|
||||
const form = useForm<z.infer<typeof formSchema>>({
|
||||
resolver: zodResolver(formSchema),
|
||||
defaultValues: {
|
||||
apiKey: "",
|
||||
title: "",
|
||||
expiresAt: "",
|
||||
},
|
||||
});
|
||||
|
||||
if (!credentials || credentials.isLoading || !credentials.supportsApiKey) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const { provider, providerName, createAPIKeyCredentials } = credentials;
|
||||
|
||||
async function onSubmit(values: z.infer<typeof formSchema>) {
|
||||
const expiresAt = values.expiresAt
|
||||
? new Date(values.expiresAt).getTime() / 1000
|
||||
: undefined;
|
||||
const newCredentials = await createAPIKeyCredentials({
|
||||
api_key: values.apiKey,
|
||||
title: values.title,
|
||||
expires_at: expiresAt,
|
||||
});
|
||||
onCredentialsCreate({
|
||||
provider,
|
||||
id: newCredentials.id,
|
||||
type: "api_key",
|
||||
title: newCredentials.title,
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) onClose();
|
||||
}}
|
||||
>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>Add new API key for {providerName}</DialogTitle>
|
||||
{schema.description && (
|
||||
<DialogDescription>{schema.description}</DialogDescription>
|
||||
)}
|
||||
</DialogHeader>
|
||||
|
||||
<Form {...form}>
|
||||
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="apiKey"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>API Key</FormLabel>
|
||||
{schema.credentials_scopes && (
|
||||
<FormDescription>
|
||||
Required scope(s) for this block:{" "}
|
||||
{schema.credentials_scopes?.map((s, i, a) => (
|
||||
<span key={i}>
|
||||
<code>{s}</code>
|
||||
{i < a.length - 1 && ", "}
|
||||
</span>
|
||||
))}
|
||||
</FormDescription>
|
||||
)}
|
||||
<FormControl>
|
||||
<Input
|
||||
type="password"
|
||||
placeholder="Enter API key..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="title"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Name</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="text"
|
||||
placeholder="Enter a name for this API key..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="expiresAt"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Expiration Date (Optional)</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="datetime-local"
|
||||
placeholder="Select expiration date..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<Button type="submit" className="w-full">
|
||||
Save & use this API key
|
||||
</Button>
|
||||
</form>
|
||||
</Form>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
|
||||
export const UserPasswordCredentialsModal: FC<{
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
|
||||
siblingInputs?: Record<string, any>;
|
||||
}> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
|
||||
const credentials = useCredentials(schema, siblingInputs);
|
||||
|
||||
const formSchema = z.object({
|
||||
username: z.string().min(1, "Username is required"),
|
||||
password: z.string().min(1, "Password is required"),
|
||||
title: z.string().min(1, "Name is required"),
|
||||
});
|
||||
|
||||
const form = useForm<z.infer<typeof formSchema>>({
|
||||
resolver: zodResolver(formSchema),
|
||||
defaultValues: {
|
||||
username: "",
|
||||
password: "",
|
||||
title: "",
|
||||
},
|
||||
});
|
||||
|
||||
if (
|
||||
!credentials ||
|
||||
credentials.isLoading ||
|
||||
!credentials.supportsUserPassword
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const { provider, providerName, createUserPasswordCredentials } = credentials;
|
||||
|
||||
async function onSubmit(values: z.infer<typeof formSchema>) {
|
||||
const newCredentials = await createUserPasswordCredentials({
|
||||
username: values.username,
|
||||
password: values.password,
|
||||
title: values.title,
|
||||
});
|
||||
onCredentialsCreate({
|
||||
provider,
|
||||
id: newCredentials.id,
|
||||
type: "user_password",
|
||||
title: newCredentials.title,
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) onClose();
|
||||
}}
|
||||
>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>
|
||||
Add new username & password for {providerName}
|
||||
</DialogTitle>
|
||||
</DialogHeader>
|
||||
<Form {...form}>
|
||||
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="username"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Username</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="text"
|
||||
placeholder="Enter username..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="password"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Password</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="password"
|
||||
placeholder="Enter password..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="title"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Name</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="text"
|
||||
placeholder="Enter a name for this user login..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<Button type="submit" className="w-full">
|
||||
Save & use this user login
|
||||
</Button>
|
||||
</form>
|
||||
</Form>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
|
||||
export const OAuth2FlowWaitingModal: FC<{
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
providerName: string;
|
||||
}> = ({ open, onClose, providerName }) => {
|
||||
return (
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) onClose();
|
||||
}}
|
||||
>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>
|
||||
Waiting on {providerName} sign-in process...
|
||||
</DialogTitle>
|
||||
<DialogDescription>
|
||||
Complete the sign-in process in the pop-up window.
|
||||
<br />
|
||||
Closing this dialog will cancel the sign-in process.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -6,10 +6,12 @@ import {
|
||||
CredentialsDeleteResponse,
|
||||
CredentialsMetaResponse,
|
||||
CredentialsProviderName,
|
||||
HostScopedCredentials,
|
||||
PROVIDER_NAMES,
|
||||
UserPasswordCredentials,
|
||||
} from "@/lib/autogpt-server-api";
|
||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||
import { useToastOnFail } from "@/components/ui/use-toast";
|
||||
|
||||
// Get keys from CredentialsProviderName type
|
||||
const CREDENTIALS_PROVIDER_NAMES = Object.values(
|
||||
@@ -30,6 +32,7 @@ const providerDisplayNames: Record<CredentialsProviderName, string> = {
|
||||
google: "Google",
|
||||
google_maps: "Google Maps",
|
||||
groq: "Groq",
|
||||
http: "HTTP",
|
||||
hubspot: "Hubspot",
|
||||
ideogram: "Ideogram",
|
||||
jina: "Jina",
|
||||
@@ -68,6 +71,11 @@ type UserPasswordCredentialsCreatable = Omit<
|
||||
"id" | "provider" | "type"
|
||||
>;
|
||||
|
||||
type HostScopedCredentialsCreatable = Omit<
|
||||
HostScopedCredentials,
|
||||
"id" | "provider" | "type"
|
||||
>;
|
||||
|
||||
export type CredentialsProviderData = {
|
||||
provider: CredentialsProviderName;
|
||||
providerName: string;
|
||||
@@ -82,6 +90,9 @@ export type CredentialsProviderData = {
|
||||
createUserPasswordCredentials: (
|
||||
credentials: UserPasswordCredentialsCreatable,
|
||||
) => Promise<CredentialsMetaResponse>;
|
||||
createHostScopedCredentials: (
|
||||
credentials: HostScopedCredentialsCreatable,
|
||||
) => Promise<CredentialsMetaResponse>;
|
||||
deleteCredentials: (
|
||||
id: string,
|
||||
force?: boolean,
|
||||
@@ -106,6 +117,7 @@ export default function CredentialsProvider({
|
||||
useState<CredentialsProvidersContextType | null>(null);
|
||||
const { isLoggedIn } = useSupabase();
|
||||
const api = useBackendAPI();
|
||||
const onFailToast = useToastOnFail();
|
||||
|
||||
const addCredentials = useCallback(
|
||||
(
|
||||
@@ -134,11 +146,16 @@ export default function CredentialsProvider({
|
||||
code: string,
|
||||
state_token: string,
|
||||
): Promise<CredentialsMetaResponse> => {
|
||||
const credsMeta = await api.oAuthCallback(provider, code, state_token);
|
||||
addCredentials(provider, credsMeta);
|
||||
return credsMeta;
|
||||
try {
|
||||
const credsMeta = await api.oAuthCallback(provider, code, state_token);
|
||||
addCredentials(provider, credsMeta);
|
||||
return credsMeta;
|
||||
} catch (error) {
|
||||
onFailToast("complete OAuth authentication")(error);
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[api, addCredentials],
|
||||
[api, addCredentials, onFailToast],
|
||||
);
|
||||
|
||||
/** Wraps `BackendAPI.createAPIKeyCredentials`, and adds the result to the internal credentials store. */
|
||||
@@ -147,14 +164,19 @@ export default function CredentialsProvider({
|
||||
provider: CredentialsProviderName,
|
||||
credentials: APIKeyCredentialsCreatable,
|
||||
): Promise<CredentialsMetaResponse> => {
|
||||
const credsMeta = await api.createAPIKeyCredentials({
|
||||
provider,
|
||||
...credentials,
|
||||
});
|
||||
addCredentials(provider, credsMeta);
|
||||
return credsMeta;
|
||||
try {
|
||||
const credsMeta = await api.createAPIKeyCredentials({
|
||||
provider,
|
||||
...credentials,
|
||||
});
|
||||
addCredentials(provider, credsMeta);
|
||||
return credsMeta;
|
||||
} catch (error) {
|
||||
onFailToast("create API key credentials")(error);
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[api, addCredentials],
|
||||
[api, addCredentials, onFailToast],
|
||||
);
|
||||
|
||||
/** Wraps `BackendAPI.createUserPasswordCredentials`, and adds the result to the internal credentials store. */
|
||||
@@ -163,14 +185,40 @@ export default function CredentialsProvider({
|
||||
provider: CredentialsProviderName,
|
||||
credentials: UserPasswordCredentialsCreatable,
|
||||
): Promise<CredentialsMetaResponse> => {
|
||||
const credsMeta = await api.createUserPasswordCredentials({
|
||||
provider,
|
||||
...credentials,
|
||||
});
|
||||
addCredentials(provider, credsMeta);
|
||||
return credsMeta;
|
||||
try {
|
||||
const credsMeta = await api.createUserPasswordCredentials({
|
||||
provider,
|
||||
...credentials,
|
||||
});
|
||||
addCredentials(provider, credsMeta);
|
||||
return credsMeta;
|
||||
} catch (error) {
|
||||
onFailToast("create user/password credentials")(error);
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[api, addCredentials],
|
||||
[api, addCredentials, onFailToast],
|
||||
);
|
||||
|
||||
/** Wraps `BackendAPI.createHostScopedCredentials`, and adds the result to the internal credentials store. */
|
||||
const createHostScopedCredentials = useCallback(
|
||||
async (
|
||||
provider: CredentialsProviderName,
|
||||
credentials: HostScopedCredentialsCreatable,
|
||||
): Promise<CredentialsMetaResponse> => {
|
||||
try {
|
||||
const credsMeta = await api.createHostScopedCredentials({
|
||||
provider,
|
||||
...credentials,
|
||||
});
|
||||
addCredentials(provider, credsMeta);
|
||||
return credsMeta;
|
||||
} catch (error) {
|
||||
onFailToast("create host-scoped credentials")(error);
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[api, addCredentials, onFailToast],
|
||||
);
|
||||
|
||||
/** Wraps `BackendAPI.deleteCredentials`, and removes the credentials from the internal store. */
|
||||
@@ -182,26 +230,31 @@ export default function CredentialsProvider({
|
||||
): Promise<
|
||||
CredentialsDeleteResponse | CredentialsDeleteNeedConfirmationResponse
|
||||
> => {
|
||||
const result = await api.deleteCredentials(provider, id, force);
|
||||
if (!result.deleted) {
|
||||
return result;
|
||||
}
|
||||
setProviders((prev) => {
|
||||
if (!prev || !prev[provider]) return prev;
|
||||
try {
|
||||
const result = await api.deleteCredentials(provider, id, force);
|
||||
if (!result.deleted) {
|
||||
return result;
|
||||
}
|
||||
setProviders((prev) => {
|
||||
if (!prev || !prev[provider]) return prev;
|
||||
|
||||
return {
|
||||
...prev,
|
||||
[provider]: {
|
||||
...prev[provider],
|
||||
savedCredentials: prev[provider].savedCredentials.filter(
|
||||
(cred) => cred.id !== id,
|
||||
),
|
||||
},
|
||||
};
|
||||
});
|
||||
return result;
|
||||
return {
|
||||
...prev,
|
||||
[provider]: {
|
||||
...prev[provider],
|
||||
savedCredentials: prev[provider].savedCredentials.filter(
|
||||
(cred) => cred.id !== id,
|
||||
),
|
||||
},
|
||||
};
|
||||
});
|
||||
return result;
|
||||
} catch (error) {
|
||||
onFailToast("delete credentials")(error);
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[api],
|
||||
[api, onFailToast],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
@@ -210,47 +263,54 @@ export default function CredentialsProvider({
|
||||
return;
|
||||
}
|
||||
|
||||
api.listCredentials().then((response) => {
|
||||
const credentialsByProvider = response.reduce(
|
||||
(acc, cred) => {
|
||||
if (!acc[cred.provider]) {
|
||||
acc[cred.provider] = [];
|
||||
}
|
||||
acc[cred.provider].push(cred);
|
||||
return acc;
|
||||
},
|
||||
{} as Record<CredentialsProviderName, CredentialsMetaResponse[]>,
|
||||
);
|
||||
api
|
||||
.listCredentials()
|
||||
.then((response) => {
|
||||
const credentialsByProvider = response.reduce(
|
||||
(acc, cred) => {
|
||||
if (!acc[cred.provider]) {
|
||||
acc[cred.provider] = [];
|
||||
}
|
||||
acc[cred.provider].push(cred);
|
||||
return acc;
|
||||
},
|
||||
{} as Record<CredentialsProviderName, CredentialsMetaResponse[]>,
|
||||
);
|
||||
|
||||
setProviders((prev) => ({
|
||||
...prev,
|
||||
...Object.fromEntries(
|
||||
CREDENTIALS_PROVIDER_NAMES.map((provider) => [
|
||||
provider,
|
||||
{
|
||||
setProviders((prev) => ({
|
||||
...prev,
|
||||
...Object.fromEntries(
|
||||
CREDENTIALS_PROVIDER_NAMES.map((provider) => [
|
||||
provider,
|
||||
providerName: providerDisplayNames[provider],
|
||||
savedCredentials: credentialsByProvider[provider] ?? [],
|
||||
oAuthCallback: (code: string, state_token: string) =>
|
||||
oAuthCallback(provider, code, state_token),
|
||||
createAPIKeyCredentials: (
|
||||
credentials: APIKeyCredentialsCreatable,
|
||||
) => createAPIKeyCredentials(provider, credentials),
|
||||
createUserPasswordCredentials: (
|
||||
credentials: UserPasswordCredentialsCreatable,
|
||||
) => createUserPasswordCredentials(provider, credentials),
|
||||
deleteCredentials: (id: string, force: boolean = false) =>
|
||||
deleteCredentials(provider, id, force),
|
||||
} satisfies CredentialsProviderData,
|
||||
]),
|
||||
),
|
||||
}));
|
||||
});
|
||||
{
|
||||
provider,
|
||||
providerName: providerDisplayNames[provider],
|
||||
savedCredentials: credentialsByProvider[provider] ?? [],
|
||||
oAuthCallback: (code: string, state_token: string) =>
|
||||
oAuthCallback(provider, code, state_token),
|
||||
createAPIKeyCredentials: (
|
||||
credentials: APIKeyCredentialsCreatable,
|
||||
) => createAPIKeyCredentials(provider, credentials),
|
||||
createUserPasswordCredentials: (
|
||||
credentials: UserPasswordCredentialsCreatable,
|
||||
) => createUserPasswordCredentials(provider, credentials),
|
||||
createHostScopedCredentials: (
|
||||
credentials: HostScopedCredentialsCreatable,
|
||||
) => createHostScopedCredentials(provider, credentials),
|
||||
deleteCredentials: (id: string, force: boolean = false) =>
|
||||
deleteCredentials(provider, id, force),
|
||||
} satisfies CredentialsProviderData,
|
||||
]),
|
||||
),
|
||||
}));
|
||||
})
|
||||
.catch(onFailToast("load credentials"));
|
||||
}, [
|
||||
api,
|
||||
isLoggedIn,
|
||||
createAPIKeyCredentials,
|
||||
createUserPasswordCredentials,
|
||||
createHostScopedCredentials,
|
||||
deleteCredentials,
|
||||
oAuthCallback,
|
||||
]);
|
||||
|
||||
@@ -0,0 +1,235 @@
|
||||
import { FC, useEffect, useState } from "react";
|
||||
import { z } from "zod";
|
||||
import { useForm } from "react-hook-form";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
import {
|
||||
Form,
|
||||
FormControl,
|
||||
FormDescription,
|
||||
FormField,
|
||||
FormItem,
|
||||
FormLabel,
|
||||
FormMessage,
|
||||
} from "@/components/ui/form";
|
||||
import useCredentials from "@/hooks/useCredentials";
|
||||
import {
|
||||
BlockIOCredentialsSubSchema,
|
||||
CredentialsMetaInput,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
import { getHostFromUrl } from "@/lib/utils/url";
|
||||
|
||||
export const HostScopedCredentialsModal: FC<{
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
|
||||
siblingInputs?: Record<string, any>;
|
||||
}> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
|
||||
const credentials = useCredentials(schema, siblingInputs);
|
||||
|
||||
// Get current host from siblingInputs or discriminator_values
|
||||
const currentUrl = credentials?.discriminatorValue;
|
||||
const currentHost = currentUrl ? getHostFromUrl(currentUrl) : "";
|
||||
|
||||
const formSchema = z.object({
|
||||
host: z.string().min(1, "Host is required"),
|
||||
title: z.string().optional(),
|
||||
headers: z.record(z.string()).optional(),
|
||||
});
|
||||
|
||||
const form = useForm<z.infer<typeof formSchema>>({
|
||||
resolver: zodResolver(formSchema),
|
||||
defaultValues: {
|
||||
host: currentHost || "",
|
||||
title: currentHost || "Manual Entry",
|
||||
headers: {},
|
||||
},
|
||||
});
|
||||
|
||||
const [headerPairs, setHeaderPairs] = useState<
|
||||
Array<{ key: string; value: string }>
|
||||
>([{ key: "", value: "" }]);
|
||||
|
||||
// Update form values when siblingInputs change
|
||||
useEffect(() => {
|
||||
if (currentHost) {
|
||||
form.setValue("host", currentHost);
|
||||
form.setValue("title", currentHost);
|
||||
} else {
|
||||
// Reset to empty when no current host
|
||||
form.setValue("host", "");
|
||||
form.setValue("title", "Manual Entry");
|
||||
}
|
||||
}, [currentHost, form]);
|
||||
|
||||
if (
|
||||
!credentials ||
|
||||
credentials.isLoading ||
|
||||
!credentials.supportsHostScoped
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const { provider, providerName, createHostScopedCredentials } = credentials;
|
||||
|
||||
const addHeaderPair = () => {
|
||||
setHeaderPairs([...headerPairs, { key: "", value: "" }]);
|
||||
};
|
||||
|
||||
const removeHeaderPair = (index: number) => {
|
||||
if (headerPairs.length > 1) {
|
||||
setHeaderPairs(headerPairs.filter((_, i) => i !== index));
|
||||
}
|
||||
};
|
||||
|
||||
const updateHeaderPair = (
|
||||
index: number,
|
||||
field: "key" | "value",
|
||||
value: string,
|
||||
) => {
|
||||
const newPairs = [...headerPairs];
|
||||
newPairs[index][field] = value;
|
||||
setHeaderPairs(newPairs);
|
||||
};
|
||||
|
||||
async function onSubmit(values: z.infer<typeof formSchema>) {
|
||||
// Convert header pairs to object, filtering out empty pairs
|
||||
const headers = headerPairs.reduce(
|
||||
(acc, pair) => {
|
||||
if (pair.key.trim() && pair.value.trim()) {
|
||||
acc[pair.key.trim()] = pair.value.trim();
|
||||
}
|
||||
return acc;
|
||||
},
|
||||
{} as Record<string, string>,
|
||||
);
|
||||
|
||||
const newCredentials = await createHostScopedCredentials({
|
||||
host: values.host,
|
||||
title: currentHost || values.host,
|
||||
headers,
|
||||
});
|
||||
|
||||
onCredentialsCreate({
|
||||
provider,
|
||||
id: newCredentials.id,
|
||||
type: "host_scoped",
|
||||
title: newCredentials.title,
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) onClose();
|
||||
}}
|
||||
>
|
||||
<DialogContent className="max-h-[90vh] max-w-2xl overflow-y-auto">
|
||||
<DialogHeader>
|
||||
<DialogTitle>Add sensitive headers for {providerName}</DialogTitle>
|
||||
{schema.description && (
|
||||
<DialogDescription>{schema.description}</DialogDescription>
|
||||
)}
|
||||
</DialogHeader>
|
||||
|
||||
<Form {...form}>
|
||||
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="host"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Host Pattern</FormLabel>
|
||||
<FormDescription>
|
||||
{currentHost
|
||||
? "Auto-populated from the URL field. Headers will be applied to requests to this host."
|
||||
: "Enter the host/domain to match against request URLs (e.g., api.example.com)."}
|
||||
</FormDescription>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="text"
|
||||
readOnly={!!currentHost}
|
||||
placeholder={
|
||||
currentHost
|
||||
? undefined
|
||||
: "Enter host (e.g., api.example.com)"
|
||||
}
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
<div className="space-y-2">
|
||||
<FormLabel>Headers</FormLabel>
|
||||
<FormDescription>
|
||||
Add sensitive headers (like Authorization, X-API-Key) that
|
||||
should be automatically included in requests to the specified
|
||||
host.
|
||||
</FormDescription>
|
||||
|
||||
{headerPairs.map((pair, index) => (
|
||||
<div key={index} className="flex items-end gap-2">
|
||||
<div className="flex-1">
|
||||
<Input
|
||||
placeholder="Header name (e.g., Authorization)"
|
||||
value={pair.key}
|
||||
onChange={(e) =>
|
||||
updateHeaderPair(index, "key", e.target.value)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<Input
|
||||
type="password"
|
||||
placeholder="Header value (e.g., Bearer token123)"
|
||||
value={pair.value}
|
||||
onChange={(e) =>
|
||||
updateHeaderPair(index, "value", e.target.value)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => removeHeaderPair(index)}
|
||||
disabled={headerPairs.length === 1}
|
||||
>
|
||||
Remove
|
||||
</Button>
|
||||
</div>
|
||||
))}
|
||||
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={addHeaderPair}
|
||||
className="w-full"
|
||||
>
|
||||
Add Another Header
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<Button type="submit" className="w-full">
|
||||
Save & use these credentials
|
||||
</Button>
|
||||
</form>
|
||||
</Form>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,36 @@
|
||||
import { FC } from "react";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
|
||||
export const OAuth2FlowWaitingModal: FC<{
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
providerName: string;
|
||||
}> = ({ open, onClose, providerName }) => {
|
||||
return (
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) onClose();
|
||||
}}
|
||||
>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>
|
||||
Waiting on {providerName} sign-in process...
|
||||
</DialogTitle>
|
||||
<DialogDescription>
|
||||
Complete the sign-in process in the pop-up window.
|
||||
<br />
|
||||
Closing this dialog will cancel the sign-in process.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,149 @@
|
||||
import { FC } from "react";
|
||||
import { z } from "zod";
|
||||
import { useForm } from "react-hook-form";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
import {
|
||||
Form,
|
||||
FormControl,
|
||||
FormField,
|
||||
FormItem,
|
||||
FormLabel,
|
||||
FormMessage,
|
||||
} from "@/components/ui/form";
|
||||
import useCredentials from "@/hooks/useCredentials";
|
||||
import {
|
||||
BlockIOCredentialsSubSchema,
|
||||
CredentialsMetaInput,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
|
||||
export const UserPasswordCredentialsModal: FC<{
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
|
||||
siblingInputs?: Record<string, any>;
|
||||
}> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
|
||||
const credentials = useCredentials(schema, siblingInputs);
|
||||
|
||||
const formSchema = z.object({
|
||||
username: z.string().min(1, "Username is required"),
|
||||
password: z.string().min(1, "Password is required"),
|
||||
title: z.string().min(1, "Name is required"),
|
||||
});
|
||||
|
||||
const form = useForm<z.infer<typeof formSchema>>({
|
||||
resolver: zodResolver(formSchema),
|
||||
defaultValues: {
|
||||
username: "",
|
||||
password: "",
|
||||
title: "",
|
||||
},
|
||||
});
|
||||
|
||||
if (
|
||||
!credentials ||
|
||||
credentials.isLoading ||
|
||||
!credentials.supportsUserPassword
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const { provider, providerName, createUserPasswordCredentials } = credentials;
|
||||
|
||||
async function onSubmit(values: z.infer<typeof formSchema>) {
|
||||
const newCredentials = await createUserPasswordCredentials({
|
||||
username: values.username,
|
||||
password: values.password,
|
||||
title: values.title,
|
||||
});
|
||||
onCredentialsCreate({
|
||||
provider,
|
||||
id: newCredentials.id,
|
||||
type: "user_password",
|
||||
title: newCredentials.title,
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) onClose();
|
||||
}}
|
||||
>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>
|
||||
Add new username & password for {providerName}
|
||||
</DialogTitle>
|
||||
</DialogHeader>
|
||||
<Form {...form}>
|
||||
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="username"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Username</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="text"
|
||||
placeholder="Enter username..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="password"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Password</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="password"
|
||||
placeholder="Enter password..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="title"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Name</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="text"
|
||||
placeholder="Enter a name for this user login..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<Button type="submit" className="w-full">
|
||||
Save & use this user login
|
||||
</Button>
|
||||
</form>
|
||||
</Form>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
BlockIOCredentialsSubSchema,
|
||||
CredentialsProviderName,
|
||||
} from "@/lib/autogpt-server-api";
|
||||
import { getHostFromUrl } from "@/lib/utils/url";
|
||||
|
||||
export type CredentialsData =
|
||||
| {
|
||||
@@ -17,14 +18,18 @@ export type CredentialsData =
|
||||
supportsApiKey: boolean;
|
||||
supportsOAuth2: boolean;
|
||||
supportsUserPassword: boolean;
|
||||
supportsHostScoped: boolean;
|
||||
isLoading: true;
|
||||
discriminatorValue?: string;
|
||||
}
|
||||
| (CredentialsProviderData & {
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
supportsApiKey: boolean;
|
||||
supportsOAuth2: boolean;
|
||||
supportsUserPassword: boolean;
|
||||
supportsHostScoped: boolean;
|
||||
isLoading: false;
|
||||
discriminatorValue?: string;
|
||||
});
|
||||
|
||||
export default function useCredentials(
|
||||
@@ -33,12 +38,16 @@ export default function useCredentials(
|
||||
): CredentialsData | null {
|
||||
const allProviders = useContext(CredentialsProvidersContext);
|
||||
|
||||
const discriminatorValue: CredentialsProviderName | null =
|
||||
(credsInputSchema.discriminator &&
|
||||
credsInputSchema.discriminator_mapping![
|
||||
getValue(credsInputSchema.discriminator, nodeInputValues)
|
||||
]) ||
|
||||
null;
|
||||
const discriminatorValue = [
|
||||
credsInputSchema.discriminator
|
||||
? getValue(credsInputSchema.discriminator, nodeInputValues)
|
||||
: null,
|
||||
...(credsInputSchema.discriminator_values || []),
|
||||
].find(Boolean);
|
||||
|
||||
const discriminatedProvider = credsInputSchema.discriminator_mapping
|
||||
? credsInputSchema.discriminator_mapping[discriminatorValue]
|
||||
: null;
|
||||
|
||||
let providerName: CredentialsProviderName;
|
||||
if (credsInputSchema.credentials_provider.length > 1) {
|
||||
@@ -47,14 +56,14 @@ export default function useCredentials(
|
||||
"Multi-provider credential input requires discriminator!",
|
||||
);
|
||||
}
|
||||
if (!discriminatorValue) {
|
||||
console.log(
|
||||
if (!discriminatedProvider) {
|
||||
console.warn(
|
||||
`Missing discriminator value from '${credsInputSchema.discriminator}': ` +
|
||||
"hiding credentials input until it is set.",
|
||||
);
|
||||
return null;
|
||||
}
|
||||
providerName = discriminatorValue;
|
||||
providerName = discriminatedProvider;
|
||||
} else {
|
||||
providerName = credsInputSchema.credentials_provider[0];
|
||||
}
|
||||
@@ -69,6 +78,8 @@ export default function useCredentials(
|
||||
const supportsOAuth2 = credsInputSchema.credentials_types.includes("oauth2");
|
||||
const supportsUserPassword =
|
||||
credsInputSchema.credentials_types.includes("user_password");
|
||||
const supportsHostScoped =
|
||||
credsInputSchema.credentials_types.includes("host_scoped");
|
||||
|
||||
// No provider means maybe it's still loading
|
||||
if (!provider) {
|
||||
@@ -82,15 +93,24 @@ export default function useCredentials(
|
||||
return null;
|
||||
}
|
||||
|
||||
// Filter by OAuth credentials that have sufficient scopes for this block
|
||||
const requiredScopes = credsInputSchema.credentials_scopes;
|
||||
const savedCredentials = requiredScopes
|
||||
? provider.savedCredentials.filter(
|
||||
(c) =>
|
||||
c.type != "oauth2" ||
|
||||
new Set(c.scopes).isSupersetOf(new Set(requiredScopes)),
|
||||
)
|
||||
: provider.savedCredentials;
|
||||
const savedCredentials = provider.savedCredentials.filter((c) => {
|
||||
// Filter by OAuth credentials that have sufficient scopes for this block
|
||||
if (c.type === "oauth2") {
|
||||
const requiredScopes = credsInputSchema.credentials_scopes;
|
||||
return (
|
||||
!requiredScopes ||
|
||||
new Set(c.scopes).isSupersetOf(new Set(requiredScopes))
|
||||
);
|
||||
}
|
||||
|
||||
// Filter host_scoped credentials by host matching
|
||||
if (c.type === "host_scoped") {
|
||||
return discriminatorValue && getHostFromUrl(discriminatorValue) == c.host;
|
||||
}
|
||||
|
||||
// Include all other credential types
|
||||
return true;
|
||||
});
|
||||
|
||||
return {
|
||||
...provider,
|
||||
@@ -99,7 +119,9 @@ export default function useCredentials(
|
||||
supportsApiKey,
|
||||
supportsOAuth2,
|
||||
supportsUserPassword,
|
||||
supportsHostScoped,
|
||||
savedCredentials,
|
||||
discriminatorValue,
|
||||
isLoading: false,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -62,6 +62,7 @@ import type {
|
||||
User,
|
||||
UserOnboarding,
|
||||
UserPasswordCredentials,
|
||||
HostScopedCredentials,
|
||||
UsersBalanceHistoryResponse,
|
||||
} from "./types";
|
||||
|
||||
@@ -347,6 +348,16 @@ export default class BackendAPI {
|
||||
);
|
||||
}
|
||||
|
||||
createHostScopedCredentials(
|
||||
credentials: Omit<HostScopedCredentials, "id" | "type">,
|
||||
): Promise<HostScopedCredentials> {
|
||||
return this._request(
|
||||
"POST",
|
||||
`/integrations/${credentials.provider}/credentials`,
|
||||
{ ...credentials, type: "host_scoped" },
|
||||
);
|
||||
}
|
||||
|
||||
listCredentials(provider?: string): Promise<CredentialsMetaResponse[]> {
|
||||
return this._get(
|
||||
provider
|
||||
|
||||
@@ -140,12 +140,17 @@ export type BlockIOBooleanSubSchema = BlockIOSubSchemaMeta & {
|
||||
secret?: boolean;
|
||||
};
|
||||
|
||||
export type CredentialsType = "api_key" | "oauth2" | "user_password";
|
||||
export type CredentialsType =
|
||||
| "api_key"
|
||||
| "oauth2"
|
||||
| "user_password"
|
||||
| "host_scoped";
|
||||
|
||||
export type Credentials =
|
||||
| APIKeyCredentials
|
||||
| OAuth2Credentials
|
||||
| UserPasswordCredentials;
|
||||
| UserPasswordCredentials
|
||||
| HostScopedCredentials;
|
||||
|
||||
// --8<-- [start:BlockIOCredentialsSubSchema]
|
||||
export const PROVIDER_NAMES = {
|
||||
@@ -161,6 +166,7 @@ export const PROVIDER_NAMES = {
|
||||
GOOGLE: "google",
|
||||
GOOGLE_MAPS: "google_maps",
|
||||
GROQ: "groq",
|
||||
HTTP: "http",
|
||||
HUBSPOT: "hubspot",
|
||||
IDEOGRAM: "ideogram",
|
||||
JINA: "jina",
|
||||
@@ -199,6 +205,7 @@ export type BlockIOCredentialsSubSchema = BlockIOObjectSubSchema & {
|
||||
credentials_types: Array<CredentialsType>;
|
||||
discriminator?: string;
|
||||
discriminator_mapping?: { [key: string]: CredentialsProviderName };
|
||||
discriminator_values?: any[];
|
||||
secret?: boolean;
|
||||
};
|
||||
|
||||
@@ -501,6 +508,7 @@ export type CredentialsMetaResponse = {
|
||||
title?: string;
|
||||
scopes?: Array<string>;
|
||||
username?: string;
|
||||
host?: string;
|
||||
};
|
||||
|
||||
/* Mirror of backend/server/integrations/router.py:CredentialsDeletionResponse */
|
||||
@@ -559,6 +567,14 @@ export type UserPasswordCredentials = BaseCredentials & {
|
||||
password: string;
|
||||
};
|
||||
|
||||
/* Mirror of backend/backend/data/model.py:HostScopedCredentials */
|
||||
export type HostScopedCredentials = BaseCredentials & {
|
||||
type: "host_scoped";
|
||||
title: string;
|
||||
host: string;
|
||||
headers: Record<string, string>;
|
||||
};
|
||||
|
||||
// Mirror of backend/backend/data/notifications.py:NotificationType
|
||||
export type NotificationType =
|
||||
| "AGENT_RUN"
|
||||
|
||||
16
autogpt_platform/frontend/src/lib/utils/url.ts
Normal file
16
autogpt_platform/frontend/src/lib/utils/url.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
/**
|
||||
* Extracts the hostname from a URL string.
|
||||
* @param url - The URL string to extract the hostname from
|
||||
* @returns The hostname if valid, null if invalid
|
||||
*/
|
||||
export const getHostFromUrl = (url: string): string | null => {
|
||||
try {
|
||||
if (!url.startsWith("http://") && !url.startsWith("https://")) {
|
||||
url = "http://" + url; // Add a scheme if missing for URL parsing
|
||||
}
|
||||
const urlObj = new URL(url);
|
||||
return urlObj.hostname;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
};
|
||||
Reference in New Issue
Block a user