Merge 'dev' into 'chore/storybook-test-setup'

This commit is contained in:
Lluis Agusti
2025-06-27 15:11:31 +04:00
107 changed files with 5497 additions and 2171 deletions

View File

@@ -190,9 +190,9 @@ jobs:
- name: Run pytest with coverage
run: |
if [[ "${{ runner.debug }}" == "1" ]]; then
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG test
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
else
poetry run pytest -s -vv test
poetry run pytest -s -vv
fi
if: success() || (failure() && steps.lint.outcome == 'failure')
env:
@@ -205,6 +205,7 @@ jobs:
REDIS_HOST: "localhost"
REDIS_PORT: "6379"
REDIS_PASSWORD: "testpassword"
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
env:
CI: true

View File

@@ -20,7 +20,7 @@ def load_all_blocks() -> dict[str, type["Block"]]:
modules = [
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
for f in current_dir.rglob("*.py")
if f.is_file() and f.name != "__init__.py"
if f.is_file() and f.name != "__init__.py" and not f.name.startswith("test_")
]
for module in modules:
if not re.match("^[a-z0-9_.]+$", module):

View File

@@ -53,6 +53,7 @@ class AudioTrack(str, Enum):
REFRESHER = ("Refresher",)
TOURIST = ("Tourist",)
TWIN_TYCHES = ("Twin Tyches",)
DONT_STOP_ME_ABSTRACT_FUTURE_BASS = ("Dont Stop Me Abstract Future Bass",)
@property
def audio_url(self):
@@ -78,6 +79,7 @@ class AudioTrack(str, Enum):
AudioTrack.REFRESHER: "https://cdn.tfrv.xyz/audio/refresher.mp3",
AudioTrack.TOURIST: "https://cdn.tfrv.xyz/audio/tourist.mp3",
AudioTrack.TWIN_TYCHES: "https://cdn.tfrv.xyz/audio/twin-tynches.mp3",
AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS: "https://cdn.revid.ai/audio/_dont-stop-me-abstract-future-bass.mp3",
}
return audio_urls[self]
@@ -105,6 +107,7 @@ class GenerationPreset(str, Enum):
MOVIE = ("Movie",)
STYLIZED_ILLUSTRATION = ("Stylized Illustration",)
MANGA = ("Manga",)
DEFAULT = ("DEFAULT",)
class Voice(str, Enum):
@@ -114,6 +117,7 @@ class Voice(str, Enum):
JESSICA = "Jessica"
CHARLOTTE = "Charlotte"
CALLUM = "Callum"
EVA = "Eva"
@property
def voice_id(self):
@@ -124,6 +128,7 @@ class Voice(str, Enum):
Voice.JESSICA: "cgSgspJ2msm6clMCkdW9",
Voice.CHARLOTTE: "XB0fDUnXU5powFXDhCwa",
Voice.CALLUM: "N2lVS1w4EtoT3dr4eOWO",
Voice.EVA: "FGY2WhTYpPnrIDTdsKH5",
}
return voice_id_map[self]
@@ -141,6 +146,8 @@ logger = logging.getLogger(__name__)
class AIShortformVideoCreatorBlock(Block):
"""Creates a shortform texttovideo clip using stock or AI imagery."""
class Input(BlockSchema):
credentials: CredentialsMetaInput[
Literal[ProviderName.REVID], Literal["api_key"]
@@ -184,40 +191,8 @@ class AIShortformVideoCreatorBlock(Block):
video_url: str = SchemaField(description="The URL of the created video")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="361697fb-0c4f-4feb-aed3-8320c88c771b",
description="Creates a shortform video using revid.ai",
categories={BlockCategory.SOCIAL, BlockCategory.AI},
input_schema=AIShortformVideoCreatorBlock.Input,
output_schema=AIShortformVideoCreatorBlock.Output,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"script": "[close-up of a cat] Meow!",
"ratio": "9 / 16",
"resolution": "720p",
"frame_rate": 60,
"generation_preset": GenerationPreset.LEONARDO,
"background_music": AudioTrack.HIGHWAY_NOCTURNE,
"voice": Voice.LILY,
"video_style": VisualMediaType.STOCK_VIDEOS,
},
test_output=(
"video_url",
"https://example.com/video.mp4",
),
test_mock={
"create_webhook": lambda: (
"test_uuid",
"https://webhook.site/test_uuid",
),
"create_video": lambda api_key, payload: {"pid": "test_pid"},
"wait_for_video": lambda api_key, pid, webhook_token, max_wait_time=1000: "https://example.com/video.mp4",
},
test_credentials=TEST_CREDENTIALS,
)
async def create_webhook(self):
async def create_webhook(self) -> tuple[str, str]:
"""Create a new webhook URL for receiving notifications."""
url = "https://webhook.site/token"
headers = {"Accept": "application/json", "Content-Type": "application/json"}
response = await Requests().post(url, headers=headers)
@@ -225,6 +200,7 @@ class AIShortformVideoCreatorBlock(Block):
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
"""Create a video using the Revid API."""
url = "https://www.revid.ai/api/public/v2/render"
headers = {"key": api_key.get_secret_value()}
response = await Requests().post(url, json=payload, headers=headers)
@@ -234,6 +210,7 @@ class AIShortformVideoCreatorBlock(Block):
return response.json()
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
"""Check the status of a video creation job."""
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
headers = {"key": api_key.get_secret_value()}
response = await Requests().get(url, headers=headers)
@@ -243,9 +220,9 @@ class AIShortformVideoCreatorBlock(Block):
self,
api_key: SecretStr,
pid: str,
webhook_token: str,
max_wait_time: int = 1000,
) -> str:
"""Wait for video creation to complete and return the video URL."""
start_time = time.time()
while time.time() - start_time < max_wait_time:
status = await self.check_video_status(api_key, pid)
@@ -266,6 +243,40 @@ class AIShortformVideoCreatorBlock(Block):
logger.error("Video creation timed out")
raise TimeoutError("Video creation timed out")
def __init__(self):
super().__init__(
id="361697fb-0c4f-4feb-aed3-8320c88c771b",
description="Creates a shortform video using revid.ai",
categories={BlockCategory.SOCIAL, BlockCategory.AI},
input_schema=AIShortformVideoCreatorBlock.Input,
output_schema=AIShortformVideoCreatorBlock.Output,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"script": "[close-up of a cat] Meow!",
"ratio": "9 / 16",
"resolution": "720p",
"frame_rate": 60,
"generation_preset": GenerationPreset.LEONARDO,
"background_music": AudioTrack.HIGHWAY_NOCTURNE,
"voice": Voice.LILY,
"video_style": VisualMediaType.STOCK_VIDEOS,
},
test_output=("video_url", "https://example.com/video.mp4"),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"test_uuid",
"https://webhook.site/test_uuid",
),
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "https://example.com/video.mp4",
},
"wait_for_video": lambda *args, **kwargs: "https://example.com/video.mp4",
},
test_credentials=TEST_CREDENTIALS,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
@@ -273,20 +284,18 @@ class AIShortformVideoCreatorBlock(Block):
webhook_token, webhook_url = await self.create_webhook()
logger.debug(f"Webhook URL: {webhook_url}")
audio_url = input_data.background_music.audio_url
payload = {
"frameRate": input_data.frame_rate,
"resolution": input_data.resolution,
"frameDurationMultiplier": 18,
"webhook": webhook_url,
"webhook": None,
"creationParams": {
"mediaType": input_data.video_style,
"captionPresetName": "Wrap 1",
"selectedVoice": input_data.voice.voice_id,
"hasEnhancedGeneration": True,
"generationPreset": input_data.generation_preset.name,
"selectedAudio": input_data.background_music,
"selectedAudio": input_data.background_music.value,
"origin": "/create",
"inputText": input_data.script,
"flowType": "text-to-video",
@@ -302,7 +311,7 @@ class AIShortformVideoCreatorBlock(Block):
"selectedStoryStyle": {"value": "custom", "label": "Custom"},
"hasToGenerateVideos": input_data.video_style
!= VisualMediaType.STOCK_VIDEOS,
"audioUrl": audio_url,
"audioUrl": input_data.background_music.audio_url,
},
}
@@ -319,8 +328,370 @@ class AIShortformVideoCreatorBlock(Block):
logger.debug(
f"Video created with project ID: {pid}. Waiting for completion..."
)
video_url = await self.wait_for_video(
credentials.api_key, pid, webhook_token
)
video_url = await self.wait_for_video(credentials.api_key, pid)
logger.debug(f"Video ready: {video_url}")
yield "video_url", video_url
class AIAdMakerVideoCreatorBlock(Block):
"""Generates a 30second vertical AI advert using optional usersupplied imagery."""
class Input(BlockSchema):
credentials: CredentialsMetaInput[
Literal[ProviderName.REVID], Literal["api_key"]
] = CredentialsField(
description="Credentials for Revid.ai API access.",
)
script: str = SchemaField(
description="Short advertising copy. Line breaks create new scenes.",
placeholder="Introducing Foobar [show product photo] the gadget that does it all.",
)
ratio: str = SchemaField(description="Aspect ratio", default="9 / 16")
target_duration: int = SchemaField(
description="Desired length of the ad in seconds.", default=30
)
voice: Voice = SchemaField(
description="Narration voice", default=Voice.EVA, placeholder=Voice.EVA
)
background_music: AudioTrack = SchemaField(
description="Background track",
default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS,
)
input_media_urls: list[str] = SchemaField(
description="List of image URLs to feature in the advert.", default=[]
)
use_only_provided_media: bool = SchemaField(
description="Restrict visuals to supplied images only.", default=True
)
class Output(BlockSchema):
video_url: str = SchemaField(description="URL of the finished advert")
error: str = SchemaField(description="Error message on failure")
async def create_webhook(self) -> tuple[str, str]:
"""Create a new webhook URL for receiving notifications."""
url = "https://webhook.site/token"
headers = {"Accept": "application/json", "Content-Type": "application/json"}
response = await Requests().post(url, headers=headers)
webhook_data = response.json()
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
"""Create a video using the Revid API."""
url = "https://www.revid.ai/api/public/v2/render"
headers = {"key": api_key.get_secret_value()}
response = await Requests().post(url, json=payload, headers=headers)
logger.debug(
f"API Response Status Code: {response.status}, Content: {response.text}"
)
return response.json()
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
"""Check the status of a video creation job."""
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
headers = {"key": api_key.get_secret_value()}
response = await Requests().get(url, headers=headers)
return response.json()
async def wait_for_video(
self,
api_key: SecretStr,
pid: str,
max_wait_time: int = 1000,
) -> str:
"""Wait for video creation to complete and return the video URL."""
start_time = time.time()
while time.time() - start_time < max_wait_time:
status = await self.check_video_status(api_key, pid)
logger.debug(f"Video status: {status}")
if status.get("status") == "ready" and "videoUrl" in status:
return status["videoUrl"]
elif status.get("status") == "error":
error_message = status.get("error", "Unknown error occurred")
logger.error(f"Video creation failed: {error_message}")
raise ValueError(f"Video creation failed: {error_message}")
elif status.get("status") in ["FAILED", "CANCELED"]:
logger.error(f"Video creation failed: {status.get('message')}")
raise ValueError(f"Video creation failed: {status.get('message')}")
await asyncio.sleep(10)
logger.error("Video creation timed out")
raise TimeoutError("Video creation timed out")
def __init__(self):
super().__init__(
id="58bd2a19-115d-4fd1-8ca4-13b9e37fa6a0",
description="Creates an AIgenerated 30second advert (text + images)",
categories={BlockCategory.MARKETING, BlockCategory.AI},
input_schema=AIAdMakerVideoCreatorBlock.Input,
output_schema=AIAdMakerVideoCreatorBlock.Output,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"script": "Test product launch!",
"input_media_urls": [
"https://cdn.revid.ai/uploads/1747076315114-image.png",
],
},
test_output=("video_url", "https://example.com/ad.mp4"),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"test_uuid",
"https://webhook.site/test_uuid",
),
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "https://example.com/ad.mp4",
},
"wait_for_video": lambda *args, **kwargs: "https://example.com/ad.mp4",
},
test_credentials=TEST_CREDENTIALS,
)
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
webhook_token, webhook_url = await self.create_webhook()
payload = {
"webhook": webhook_url,
"creationParams": {
"targetDuration": input_data.target_duration,
"ratio": input_data.ratio,
"mediaType": "aiVideo",
"inputText": input_data.script,
"flowType": "text-to-video",
"slug": "ai-ad-generator",
"slugNew": "",
"isCopiedFrom": False,
"hasToGenerateVoice": True,
"hasToTranscript": False,
"hasToSearchMedia": True,
"hasAvatar": False,
"hasWebsiteRecorder": False,
"hasTextSmallAtBottom": False,
"selectedAudio": input_data.background_music.value,
"selectedVoice": input_data.voice.voice_id,
"selectedAvatar": "https://cdn.revid.ai/avatars/young-woman.mp4",
"selectedAvatarType": "video/mp4",
"websiteToRecord": "",
"hasToGenerateCover": True,
"nbGenerations": 1,
"disableCaptions": False,
"mediaMultiplier": "medium",
"characters": [],
"captionPresetName": "Revid",
"sourceType": "contentScraping",
"selectedStoryStyle": {"value": "custom", "label": "General"},
"generationPreset": "DEFAULT",
"hasToGenerateMusic": False,
"isOptimizedForChinese": False,
"generationUserPrompt": "",
"enableNsfwFilter": False,
"addStickers": False,
"typeMovingImageAnim": "dynamic",
"hasToGenerateSoundEffects": False,
"forceModelType": "gpt-image-1",
"selectedCharacters": [],
"lang": "",
"voiceSpeed": 1,
"disableAudio": False,
"disableVoice": False,
"useOnlyProvidedMedia": input_data.use_only_provided_media,
"imageGenerationModel": "ultra",
"videoGenerationModel": "pro",
"hasEnhancedGeneration": True,
"hasEnhancedGenerationPro": True,
"inputMedias": [
{"url": url, "title": "", "type": "image"}
for url in input_data.input_media_urls
],
"hasToGenerateVideos": True,
"audioUrl": input_data.background_music.audio_url,
"watermark": None,
},
}
response = await self.create_video(credentials.api_key, payload)
pid = response.get("pid")
if not pid:
raise RuntimeError("Failed to create video: No project ID returned")
video_url = await self.wait_for_video(credentials.api_key, pid)
yield "video_url", video_url
class AIScreenshotToVideoAdBlock(Block):
"""Creates an advert where the supplied screenshot is narrated by an AI avatar."""
class Input(BlockSchema):
credentials: CredentialsMetaInput[
Literal[ProviderName.REVID], Literal["api_key"]
] = CredentialsField(description="Revid.ai API key")
script: str = SchemaField(
description="Narration that will accompany the screenshot.",
placeholder="Check out these amazing stats!",
)
screenshot_url: str = SchemaField(
description="Screenshot or image URL to showcase."
)
ratio: str = SchemaField(default="9 / 16")
target_duration: int = SchemaField(default=30)
voice: Voice = SchemaField(default=Voice.EVA)
background_music: AudioTrack = SchemaField(
default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS
)
class Output(BlockSchema):
video_url: str = SchemaField(description="Rendered video URL")
error: str = SchemaField(description="Error, if encountered")
async def create_webhook(self) -> tuple[str, str]:
"""Create a new webhook URL for receiving notifications."""
url = "https://webhook.site/token"
headers = {"Accept": "application/json", "Content-Type": "application/json"}
response = await Requests().post(url, headers=headers)
webhook_data = response.json()
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
"""Create a video using the Revid API."""
url = "https://www.revid.ai/api/public/v2/render"
headers = {"key": api_key.get_secret_value()}
response = await Requests().post(url, json=payload, headers=headers)
logger.debug(
f"API Response Status Code: {response.status}, Content: {response.text}"
)
return response.json()
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
"""Check the status of a video creation job."""
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
headers = {"key": api_key.get_secret_value()}
response = await Requests().get(url, headers=headers)
return response.json()
async def wait_for_video(
self,
api_key: SecretStr,
pid: str,
max_wait_time: int = 1000,
) -> str:
"""Wait for video creation to complete and return the video URL."""
start_time = time.time()
while time.time() - start_time < max_wait_time:
status = await self.check_video_status(api_key, pid)
logger.debug(f"Video status: {status}")
if status.get("status") == "ready" and "videoUrl" in status:
return status["videoUrl"]
elif status.get("status") == "error":
error_message = status.get("error", "Unknown error occurred")
logger.error(f"Video creation failed: {error_message}")
raise ValueError(f"Video creation failed: {error_message}")
elif status.get("status") in ["FAILED", "CANCELED"]:
logger.error(f"Video creation failed: {status.get('message')}")
raise ValueError(f"Video creation failed: {status.get('message')}")
await asyncio.sleep(10)
logger.error("Video creation timed out")
raise TimeoutError("Video creation timed out")
def __init__(self):
super().__init__(
id="0f3e4635-e810-43d9-9e81-49e6f4e83b7c",
description="Turns a screenshot into an engaging, avatarnarrated video advert.",
categories={BlockCategory.AI, BlockCategory.MARKETING},
input_schema=AIScreenshotToVideoAdBlock.Input,
output_schema=AIScreenshotToVideoAdBlock.Output,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"script": "Amazing numbers!",
"screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
},
test_output=("video_url", "https://example.com/screenshot.mp4"),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"test_uuid",
"https://webhook.site/test_uuid",
),
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "https://example.com/screenshot.mp4",
},
"wait_for_video": lambda *args, **kwargs: "https://example.com/screenshot.mp4",
},
test_credentials=TEST_CREDENTIALS,
)
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
webhook_token, webhook_url = await self.create_webhook()
payload = {
"webhook": webhook_url,
"creationParams": {
"targetDuration": input_data.target_duration,
"ratio": input_data.ratio,
"mediaType": "aiVideo",
"hasAvatar": True,
"removeAvatarBackground": True,
"inputText": input_data.script,
"flowType": "text-to-video",
"slug": "ai-ad-generator",
"slugNew": "screenshot-to-video-ad",
"isCopiedFrom": "ai-ad-generator",
"hasToGenerateVoice": True,
"hasToTranscript": False,
"hasToSearchMedia": True,
"hasWebsiteRecorder": False,
"hasTextSmallAtBottom": False,
"selectedAudio": input_data.background_music.value,
"selectedVoice": input_data.voice.voice_id,
"selectedAvatar": "https://cdn.revid.ai/avatars/young-woman.mp4",
"selectedAvatarType": "video/mp4",
"websiteToRecord": "",
"hasToGenerateCover": True,
"nbGenerations": 1,
"disableCaptions": False,
"mediaMultiplier": "medium",
"characters": [],
"captionPresetName": "Revid",
"sourceType": "contentScraping",
"selectedStoryStyle": {"value": "custom", "label": "General"},
"generationPreset": "DEFAULT",
"hasToGenerateMusic": False,
"isOptimizedForChinese": False,
"generationUserPrompt": "",
"enableNsfwFilter": False,
"addStickers": False,
"typeMovingImageAnim": "dynamic",
"hasToGenerateSoundEffects": False,
"forceModelType": "gpt-image-1",
"selectedCharacters": [],
"lang": "",
"voiceSpeed": 1,
"disableAudio": False,
"disableVoice": False,
"useOnlyProvidedMedia": True,
"imageGenerationModel": "ultra",
"videoGenerationModel": "ultra",
"hasEnhancedGeneration": True,
"hasEnhancedGenerationPro": True,
"inputMedias": [
{"url": input_data.screenshot_url, "title": "", "type": "image"}
],
"hasToGenerateVideos": True,
"audioUrl": input_data.background_music.audio_url,
"watermark": None,
},
}
response = await self.create_video(credentials.api_key, payload)
pid = response.get("pid")
if not pid:
raise RuntimeError("Failed to create video: No project ID returned")
video_url = await self.wait_for_video(credentials.api_key, pid)
yield "video_url", video_url

View File

@@ -4,6 +4,7 @@ from typing import List
from backend.blocks.apollo._auth import ApolloCredentials
from backend.blocks.apollo.models import (
Contact,
EnrichPersonRequest,
Organization,
SearchOrganizationsRequest,
SearchOrganizationsResponse,
@@ -110,3 +111,21 @@ class ApolloClient:
return (
organizations[: query.max_results] if query.max_results else organizations
)
async def enrich_person(self, query: EnrichPersonRequest) -> Contact:
"""Enrich a person's data including email & phone reveal"""
response = await self.requests.post(
f"{self.API_URL}/people/match",
headers=self._get_headers(),
json=query.model_dump(),
params={
"reveal_personal_emails": "true",
},
)
data = response.json()
if "person" not in data:
raise ValueError(f"Person not found or enrichment failed: {data}")
contact = Contact(**data["person"])
contact.email = contact.email or "-"
return contact

View File

@@ -23,9 +23,9 @@ class BaseModel(OriginalBaseModel):
class PrimaryPhone(BaseModel):
"""A primary phone in Apollo"""
number: str = ""
source: str = ""
sanitized_number: str = ""
number: Optional[str] = ""
source: Optional[str] = ""
sanitized_number: Optional[str] = ""
class SenorityLevels(str, Enum):
@@ -56,102 +56,102 @@ class ContactEmailStatuses(str, Enum):
class RuleConfigStatus(BaseModel):
"""A rule config status in Apollo"""
_id: str = ""
created_at: str = ""
rule_action_config_id: str = ""
rule_config_id: str = ""
status_cd: str = ""
updated_at: str = ""
id: str = ""
key: str = ""
_id: Optional[str] = ""
created_at: Optional[str] = ""
rule_action_config_id: Optional[str] = ""
rule_config_id: Optional[str] = ""
status_cd: Optional[str] = ""
updated_at: Optional[str] = ""
id: Optional[str] = ""
key: Optional[str] = ""
class ContactCampaignStatus(BaseModel):
"""A contact campaign status in Apollo"""
id: str = ""
emailer_campaign_id: str = ""
send_email_from_user_id: str = ""
inactive_reason: str = ""
status: str = ""
added_at: str = ""
added_by_user_id: str = ""
finished_at: str = ""
paused_at: str = ""
auto_unpause_at: str = ""
send_email_from_email_address: str = ""
send_email_from_email_account_id: str = ""
manually_set_unpause: str = ""
failure_reason: str = ""
current_step_id: str = ""
in_response_to_emailer_message_id: str = ""
cc_emails: str = ""
bcc_emails: str = ""
to_emails: str = ""
id: Optional[str] = ""
emailer_campaign_id: Optional[str] = ""
send_email_from_user_id: Optional[str] = ""
inactive_reason: Optional[str] = ""
status: Optional[str] = ""
added_at: Optional[str] = ""
added_by_user_id: Optional[str] = ""
finished_at: Optional[str] = ""
paused_at: Optional[str] = ""
auto_unpause_at: Optional[str] = ""
send_email_from_email_address: Optional[str] = ""
send_email_from_email_account_id: Optional[str] = ""
manually_set_unpause: Optional[str] = ""
failure_reason: Optional[str] = ""
current_step_id: Optional[str] = ""
in_response_to_emailer_message_id: Optional[str] = ""
cc_emails: Optional[str] = ""
bcc_emails: Optional[str] = ""
to_emails: Optional[str] = ""
class Account(BaseModel):
"""An account in Apollo"""
id: str = ""
name: str = ""
website_url: str = ""
blog_url: str = ""
angellist_url: str = ""
linkedin_url: str = ""
twitter_url: str = ""
facebook_url: str = ""
primary_phone: PrimaryPhone = PrimaryPhone()
languages: list[str]
alexa_ranking: int = 0
phone: str = ""
linkedin_uid: str = ""
founded_year: int = 0
publicly_traded_symbol: str = ""
publicly_traded_exchange: str = ""
logo_url: str = ""
chrunchbase_url: str = ""
primary_domain: str = ""
domain: str = ""
team_id: str = ""
organization_id: str = ""
account_stage_id: str = ""
source: str = ""
original_source: str = ""
creator_id: str = ""
owner_id: str = ""
created_at: str = ""
phone_status: str = ""
hubspot_id: str = ""
salesforce_id: str = ""
crm_owner_id: str = ""
parent_account_id: str = ""
sanitized_phone: str = ""
id: Optional[str] = ""
name: Optional[str] = ""
website_url: Optional[str] = ""
blog_url: Optional[str] = ""
angellist_url: Optional[str] = ""
linkedin_url: Optional[str] = ""
twitter_url: Optional[str] = ""
facebook_url: Optional[str] = ""
primary_phone: Optional[PrimaryPhone] = PrimaryPhone()
languages: Optional[list[str]] = []
alexa_ranking: Optional[int] = 0
phone: Optional[str] = ""
linkedin_uid: Optional[str] = ""
founded_year: Optional[int] = 0
publicly_traded_symbol: Optional[str] = ""
publicly_traded_exchange: Optional[str] = ""
logo_url: Optional[str] = ""
chrunchbase_url: Optional[str] = ""
primary_domain: Optional[str] = ""
domain: Optional[str] = ""
team_id: Optional[str] = ""
organization_id: Optional[str] = ""
account_stage_id: Optional[str] = ""
source: Optional[str] = ""
original_source: Optional[str] = ""
creator_id: Optional[str] = ""
owner_id: Optional[str] = ""
created_at: Optional[str] = ""
phone_status: Optional[str] = ""
hubspot_id: Optional[str] = ""
salesforce_id: Optional[str] = ""
crm_owner_id: Optional[str] = ""
parent_account_id: Optional[str] = ""
sanitized_phone: Optional[str] = ""
# no listed type on the API docs
account_playbook_statues: list[Any] = []
account_rule_config_statuses: list[RuleConfigStatus] = []
existence_level: str = ""
label_ids: list[str] = []
typed_custom_fields: Any
custom_field_errors: Any
modality: str = ""
source_display_name: str = ""
salesforce_record_id: str = ""
crm_record_url: str = ""
account_playbook_statues: Optional[list[Any]] = []
account_rule_config_statuses: Optional[list[RuleConfigStatus]] = []
existence_level: Optional[str] = ""
label_ids: Optional[list[str]] = []
typed_custom_fields: Optional[Any] = {}
custom_field_errors: Optional[Any] = {}
modality: Optional[str] = ""
source_display_name: Optional[str] = ""
salesforce_record_id: Optional[str] = ""
crm_record_url: Optional[str] = ""
class ContactEmail(BaseModel):
"""A contact email in Apollo"""
email: str = ""
email_md5: str = ""
email_sha256: str = ""
email_status: str = ""
email_source: str = ""
extrapolated_email_confidence: str = ""
position: int = 0
email_from_customer: str = ""
free_domain: bool = True
email: Optional[str] = ""
email_md5: Optional[str] = ""
email_sha256: Optional[str] = ""
email_status: Optional[str] = ""
email_source: Optional[str] = ""
extrapolated_email_confidence: Optional[str] = ""
position: Optional[int] = 0
email_from_customer: Optional[str] = ""
free_domain: Optional[bool] = True
class EmploymentHistory(BaseModel):
@@ -164,40 +164,40 @@ class EmploymentHistory(BaseModel):
populate_by_name=True,
)
_id: Optional[str] = None
created_at: Optional[str] = None
current: Optional[bool] = None
degree: Optional[str] = None
description: Optional[str] = None
emails: Optional[str] = None
end_date: Optional[str] = None
grade_level: Optional[str] = None
kind: Optional[str] = None
major: Optional[str] = None
organization_id: Optional[str] = None
organization_name: Optional[str] = None
raw_address: Optional[str] = None
start_date: Optional[str] = None
title: Optional[str] = None
updated_at: Optional[str] = None
id: Optional[str] = None
key: Optional[str] = None
_id: Optional[str] = ""
created_at: Optional[str] = ""
current: Optional[bool] = False
degree: Optional[str] = ""
description: Optional[str] = ""
emails: Optional[str] = ""
end_date: Optional[str] = ""
grade_level: Optional[str] = ""
kind: Optional[str] = ""
major: Optional[str] = ""
organization_id: Optional[str] = ""
organization_name: Optional[str] = ""
raw_address: Optional[str] = ""
start_date: Optional[str] = ""
title: Optional[str] = ""
updated_at: Optional[str] = ""
id: Optional[str] = ""
key: Optional[str] = ""
class Breadcrumb(BaseModel):
"""A breadcrumb in Apollo"""
label: Optional[str] = "N/A"
signal_field_name: Optional[str] = "N/A"
value: str | list | None = "N/A"
display_name: Optional[str] = "N/A"
label: Optional[str] = ""
signal_field_name: Optional[str] = ""
value: str | list | None = ""
display_name: Optional[str] = ""
class TypedCustomField(BaseModel):
"""A typed custom field in Apollo"""
id: Optional[str] = "N/A"
value: Optional[str] = "N/A"
id: Optional[str] = ""
value: Optional[str] = ""
class Pagination(BaseModel):
@@ -219,23 +219,23 @@ class Pagination(BaseModel):
class DialerFlags(BaseModel):
"""A dialer flags in Apollo"""
country_name: str = ""
country_enabled: bool
high_risk_calling_enabled: bool
potential_high_risk_number: bool
country_name: Optional[str] = ""
country_enabled: Optional[bool] = True
high_risk_calling_enabled: Optional[bool] = True
potential_high_risk_number: Optional[bool] = True
class PhoneNumber(BaseModel):
"""A phone number in Apollo"""
raw_number: str = ""
sanitized_number: str = ""
type: str = ""
position: int = 0
status: str = ""
dnc_status: str = ""
dnc_other_info: str = ""
dailer_flags: DialerFlags = DialerFlags(
raw_number: Optional[str] = ""
sanitized_number: Optional[str] = ""
type: Optional[str] = ""
position: Optional[int] = 0
status: Optional[str] = ""
dnc_status: Optional[str] = ""
dnc_other_info: Optional[str] = ""
dailer_flags: Optional[DialerFlags] = DialerFlags(
country_name="",
country_enabled=True,
high_risk_calling_enabled=True,
@@ -253,33 +253,31 @@ class Organization(BaseModel):
populate_by_name=True,
)
id: Optional[str] = "N/A"
name: Optional[str] = "N/A"
website_url: Optional[str] = "N/A"
blog_url: Optional[str] = "N/A"
angellist_url: Optional[str] = "N/A"
linkedin_url: Optional[str] = "N/A"
twitter_url: Optional[str] = "N/A"
facebook_url: Optional[str] = "N/A"
primary_phone: Optional[PrimaryPhone] = PrimaryPhone(
number="N/A", source="N/A", sanitized_number="N/A"
)
languages: list[str] = []
id: Optional[str] = ""
name: Optional[str] = ""
website_url: Optional[str] = ""
blog_url: Optional[str] = ""
angellist_url: Optional[str] = ""
linkedin_url: Optional[str] = ""
twitter_url: Optional[str] = ""
facebook_url: Optional[str] = ""
primary_phone: Optional[PrimaryPhone] = PrimaryPhone()
languages: Optional[list[str]] = []
alexa_ranking: Optional[int] = 0
phone: Optional[str] = "N/A"
linkedin_uid: Optional[str] = "N/A"
phone: Optional[str] = ""
linkedin_uid: Optional[str] = ""
founded_year: Optional[int] = 0
publicly_traded_symbol: Optional[str] = "N/A"
publicly_traded_exchange: Optional[str] = "N/A"
logo_url: Optional[str] = "N/A"
chrunchbase_url: Optional[str] = "N/A"
primary_domain: Optional[str] = "N/A"
sanitized_phone: Optional[str] = "N/A"
owned_by_organization_id: Optional[str] = "N/A"
intent_strength: Optional[str] = "N/A"
show_intent: bool = True
publicly_traded_symbol: Optional[str] = ""
publicly_traded_exchange: Optional[str] = ""
logo_url: Optional[str] = ""
chrunchbase_url: Optional[str] = ""
primary_domain: Optional[str] = ""
sanitized_phone: Optional[str] = ""
owned_by_organization_id: Optional[str] = ""
intent_strength: Optional[str] = ""
show_intent: Optional[bool] = True
has_intent_signal_account: Optional[bool] = True
intent_signal_account: Optional[str] = "N/A"
intent_signal_account: Optional[str] = ""
class Contact(BaseModel):
@@ -292,95 +290,95 @@ class Contact(BaseModel):
populate_by_name=True,
)
contact_roles: list[Any] = []
id: Optional[str] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
name: Optional[str] = None
linkedin_url: Optional[str] = None
title: Optional[str] = None
contact_stage_id: Optional[str] = None
owner_id: Optional[str] = None
creator_id: Optional[str] = None
person_id: Optional[str] = None
email_needs_tickling: bool = True
organization_name: Optional[str] = None
source: Optional[str] = None
original_source: Optional[str] = None
organization_id: Optional[str] = None
headline: Optional[str] = None
photo_url: Optional[str] = None
present_raw_address: Optional[str] = None
linkededin_uid: Optional[str] = None
extrapolated_email_confidence: Optional[float] = None
salesforce_id: Optional[str] = None
salesforce_lead_id: Optional[str] = None
salesforce_contact_id: Optional[str] = None
saleforce_account_id: Optional[str] = None
crm_owner_id: Optional[str] = None
created_at: Optional[str] = None
emailer_campaign_ids: list[str] = []
direct_dial_status: Optional[str] = None
direct_dial_enrichment_failed_at: Optional[str] = None
email_status: Optional[str] = None
email_source: Optional[str] = None
account_id: Optional[str] = None
last_activity_date: Optional[str] = None
hubspot_vid: Optional[str] = None
hubspot_company_id: Optional[str] = None
crm_id: Optional[str] = None
sanitized_phone: Optional[str] = None
merged_crm_ids: Optional[str] = None
updated_at: Optional[str] = None
queued_for_crm_push: bool = True
suggested_from_rule_engine_config_id: Optional[str] = None
email_unsubscribed: Optional[str] = None
label_ids: list[Any] = []
has_pending_email_arcgate_request: bool = True
has_email_arcgate_request: bool = True
existence_level: Optional[str] = None
email: Optional[str] = None
email_from_customer: Optional[str] = None
typed_custom_fields: list[TypedCustomField] = []
custom_field_errors: Any = None
salesforce_record_id: Optional[str] = None
crm_record_url: Optional[str] = None
email_status_unavailable_reason: Optional[str] = None
email_true_status: Optional[str] = None
updated_email_true_status: bool = True
contact_rule_config_statuses: list[RuleConfigStatus] = []
source_display_name: Optional[str] = None
twitter_url: Optional[str] = None
contact_campaign_statuses: list[ContactCampaignStatus] = []
state: Optional[str] = None
city: Optional[str] = None
country: Optional[str] = None
account: Optional[Account] = None
contact_emails: list[ContactEmail] = []
organization: Optional[Organization] = None
employment_history: list[EmploymentHistory] = []
time_zone: Optional[str] = None
intent_strength: Optional[str] = None
show_intent: bool = True
phone_numbers: list[PhoneNumber] = []
account_phone_note: Optional[str] = None
free_domain: bool = True
is_likely_to_engage: bool = True
email_domain_catchall: bool = True
contact_job_change_event: Optional[str] = None
contact_roles: Optional[list[Any]] = []
id: Optional[str] = ""
first_name: Optional[str] = ""
last_name: Optional[str] = ""
name: Optional[str] = ""
linkedin_url: Optional[str] = ""
title: Optional[str] = ""
contact_stage_id: Optional[str] = ""
owner_id: Optional[str] = ""
creator_id: Optional[str] = ""
person_id: Optional[str] = ""
email_needs_tickling: Optional[bool] = True
organization_name: Optional[str] = ""
source: Optional[str] = ""
original_source: Optional[str] = ""
organization_id: Optional[str] = ""
headline: Optional[str] = ""
photo_url: Optional[str] = ""
present_raw_address: Optional[str] = ""
linkededin_uid: Optional[str] = ""
extrapolated_email_confidence: Optional[float] = 0.0
salesforce_id: Optional[str] = ""
salesforce_lead_id: Optional[str] = ""
salesforce_contact_id: Optional[str] = ""
saleforce_account_id: Optional[str] = ""
crm_owner_id: Optional[str] = ""
created_at: Optional[str] = ""
emailer_campaign_ids: Optional[list[str]] = []
direct_dial_status: Optional[str] = ""
direct_dial_enrichment_failed_at: Optional[str] = ""
email_status: Optional[str] = ""
email_source: Optional[str] = ""
account_id: Optional[str] = ""
last_activity_date: Optional[str] = ""
hubspot_vid: Optional[str] = ""
hubspot_company_id: Optional[str] = ""
crm_id: Optional[str] = ""
sanitized_phone: Optional[str] = ""
merged_crm_ids: Optional[str] = ""
updated_at: Optional[str] = ""
queued_for_crm_push: Optional[bool] = True
suggested_from_rule_engine_config_id: Optional[str] = ""
email_unsubscribed: Optional[str] = ""
label_ids: Optional[list[Any]] = []
has_pending_email_arcgate_request: Optional[bool] = True
has_email_arcgate_request: Optional[bool] = True
existence_level: Optional[str] = ""
email: Optional[str] = ""
email_from_customer: Optional[str] = ""
typed_custom_fields: Optional[list[TypedCustomField]] = []
custom_field_errors: Optional[Any] = {}
salesforce_record_id: Optional[str] = ""
crm_record_url: Optional[str] = ""
email_status_unavailable_reason: Optional[str] = ""
email_true_status: Optional[str] = ""
updated_email_true_status: Optional[bool] = True
contact_rule_config_statuses: Optional[list[RuleConfigStatus]] = []
source_display_name: Optional[str] = ""
twitter_url: Optional[str] = ""
contact_campaign_statuses: Optional[list[ContactCampaignStatus]] = []
state: Optional[str] = ""
city: Optional[str] = ""
country: Optional[str] = ""
account: Optional[Account] = Account()
contact_emails: Optional[list[ContactEmail]] = []
organization: Optional[Organization] = Organization()
employment_history: Optional[list[EmploymentHistory]] = []
time_zone: Optional[str] = ""
intent_strength: Optional[str] = ""
show_intent: Optional[bool] = True
phone_numbers: Optional[list[PhoneNumber]] = []
account_phone_note: Optional[str] = ""
free_domain: Optional[bool] = True
is_likely_to_engage: Optional[bool] = True
email_domain_catchall: Optional[bool] = True
contact_job_change_event: Optional[str] = ""
class SearchOrganizationsRequest(BaseModel):
"""Request for Apollo's search organizations API"""
organization_num_empoloyees_range: list[int] = SchemaField(
organization_num_employees_range: Optional[list[int]] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
default=[0, 1000000],
)
organization_locations: list[str] = SchemaField(
organization_locations: Optional[list[str]] = SchemaField(
description="""The location of the company headquarters. You can search across cities, US states, and countries.
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, any Boston-based companies will not appearch in your search results, even if they match other parameters.
@@ -389,28 +387,30 @@ To exclude companies based on location, use the organization_not_locations param
""",
default_factory=list,
)
organizations_not_locations: list[str] = SchemaField(
organizations_not_locations: Optional[list[str]] = SchemaField(
description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.
This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
""",
default_factory=list,
)
q_organization_keyword_tags: list[str] = SchemaField(
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry."""
q_organization_keyword_tags: Optional[list[str]] = SchemaField(
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry.""",
default_factory=list,
)
q_organization_name: str = SchemaField(
q_organization_name: Optional[str] = SchemaField(
description="""Filter search results to include a specific company name.
If the value you enter for this parameter does not match with a company's name, the company will not appear in search results, even if it matches other parameters. Partial matches are accepted. For example, if you filter by the value marketing, a company called NY Marketing Unlimited would still be eligible as a search result, but NY Market Analysis would not be eligible."""
If the value you enter for this parameter does not match with a company's name, the company will not appear in search results, even if it matches other parameters. Partial matches are accepted. For example, if you filter by the value marketing, a company called NY Marketing Unlimited would still be eligible as a search result, but NY Market Analysis would not be eligible.""",
default="",
)
organization_ids: list[str] = SchemaField(
organization_ids: Optional[list[str]] = SchemaField(
description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
To find IDs, identify the values for organization_id when you call this endpoint.""",
default_factory=list,
)
max_results: int = SchemaField(
max_results: Optional[int] = SchemaField(
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
default=100,
ge=1,
@@ -435,11 +435,11 @@ Use the page parameter to search the different pages of data.""",
class SearchOrganizationsResponse(BaseModel):
"""Response from Apollo's search organizations API"""
breadcrumbs: list[Breadcrumb] = []
partial_results_only: bool = True
has_join: bool = True
disable_eu_prospecting: bool = True
partial_results_limit: int = 0
breadcrumbs: Optional[list[Breadcrumb]] = []
partial_results_only: Optional[bool] = True
has_join: Optional[bool] = True
disable_eu_prospecting: Optional[bool] = True
partial_results_limit: Optional[int] = 0
pagination: Pagination = Pagination(
page=0, per_page=0, total_entries=0, total_pages=0
)
@@ -447,14 +447,14 @@ class SearchOrganizationsResponse(BaseModel):
accounts: list[Any] = []
organizations: list[Organization] = []
models_ids: list[str] = []
num_fetch_result: Optional[str] = "N/A"
derived_params: Optional[str] = "N/A"
num_fetch_result: Optional[str] = ""
derived_params: Optional[str] = ""
class SearchPeopleRequest(BaseModel):
"""Request for Apollo's search people API"""
person_titles: list[str] = SchemaField(
person_titles: Optional[list[str]] = SchemaField(
description="""Job titles held by the people you want to find. For a person to be included in search results, they only need to match 1 of the job titles you add. Adding more job titles expands your search results.
Results also include job titles with the same terms, even if they are not exact matches. For example, searching for marketing manager might return people with the job title content marketing manager.
@@ -464,13 +464,13 @@ Use this parameter in combination with the person_seniorities[] parameter to fin
default_factory=list,
placeholder="marketing manager",
)
person_locations: list[str] = SchemaField(
person_locations: Optional[list[str]] = SchemaField(
description="""The location where people live. You can search across cities, US states, and countries.
To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
default_factory=list,
)
person_seniorities: list[SenorityLevels] = SchemaField(
person_seniorities: Optional[list[SenorityLevels]] = SchemaField(
description="""The job seniority that people hold within their current employer. This enables you to find people that currently hold positions at certain reporting levels, such as Director level or senior IC level.
For a person to be included in search results, they only need to match 1 of the seniorities you add. Adding more seniorities expands your search results.
@@ -480,7 +480,7 @@ Searches only return results based on their current job title, so searching for
Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
default_factory=list,
)
organization_locations: list[str] = SchemaField(
organization_locations: Optional[list[str]] = SchemaField(
description="""The location of the company headquarters for a person's current employer. You can search across cities, US states, and countries.
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.
@@ -488,7 +488,7 @@ If a company has several office locations, results are still based on the headqu
To find people based on their personal location, use the person_locations parameter.""",
default_factory=list,
)
q_organization_domains: list[str] = SchemaField(
q_organization_domains: Optional[list[str]] = SchemaField(
description="""The domain name for the person's employer. This can be the current employer or a previous employer. Do not include www., the @ symbol, or similar.
You can add multiple domains to search across companies.
@@ -496,23 +496,23 @@ You can add multiple domains to search across companies.
Examples: apollo.io and microsoft.com""",
default_factory=list,
)
contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
contact_email_statuses: Optional[list[ContactEmailStatuses]] = SchemaField(
description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
default_factory=list,
)
organization_ids: list[str] = SchemaField(
organization_ids: Optional[list[str]] = SchemaField(
description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
default_factory=list,
)
organization_num_empoloyees_range: list[int] = SchemaField(
organization_num_employees_range: Optional[list[int]] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
default_factory=list,
)
q_keywords: str = SchemaField(
q_keywords: Optional[str] = SchemaField(
description="""A string of words over which we want to filter the results""",
default="",
)
@@ -528,7 +528,7 @@ Use this parameter in combination with the per_page parameter to make search res
Use the page parameter to search the different pages of data.""",
default=100,
)
max_results: int = SchemaField(
max_results: Optional[int] = SchemaField(
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
default=100,
ge=1,
@@ -547,16 +547,61 @@ class SearchPeopleResponse(BaseModel):
populate_by_name=True,
)
breadcrumbs: list[Breadcrumb] = []
partial_results_only: bool = True
has_join: bool = True
disable_eu_prospecting: bool = True
partial_results_limit: int = 0
breadcrumbs: Optional[list[Breadcrumb]] = []
partial_results_only: Optional[bool] = True
has_join: Optional[bool] = True
disable_eu_prospecting: Optional[bool] = True
partial_results_limit: Optional[int] = 0
pagination: Pagination = Pagination(
page=0, per_page=0, total_entries=0, total_pages=0
)
contacts: list[Contact] = []
people: list[Contact] = []
model_ids: list[str] = []
num_fetch_result: Optional[str] = "N/A"
derived_params: Optional[str] = "N/A"
num_fetch_result: Optional[str] = ""
derived_params: Optional[str] = ""
class EnrichPersonRequest(BaseModel):
"""Request for Apollo's person enrichment API"""
person_id: Optional[str] = SchemaField(
description="Apollo person ID to enrich (most accurate method)",
default="",
)
first_name: Optional[str] = SchemaField(
description="First name of the person to enrich",
default="",
)
last_name: Optional[str] = SchemaField(
description="Last name of the person to enrich",
default="",
)
name: Optional[str] = SchemaField(
description="Full name of the person to enrich",
default="",
)
email: Optional[str] = SchemaField(
description="Email address of the person to enrich",
default="",
)
domain: Optional[str] = SchemaField(
description="Company domain of the person to enrich",
default="",
)
company: Optional[str] = SchemaField(
description="Company name of the person to enrich",
default="",
)
linkedin_url: Optional[str] = SchemaField(
description="LinkedIn URL of the person to enrich",
default="",
)
organization_id: Optional[str] = SchemaField(
description="Apollo organization ID of the person's company",
default="",
)
title: Optional[str] = SchemaField(
description="Job title of the person to enrich",
default="",
)

View File

@@ -11,14 +11,14 @@ from backend.blocks.apollo.models import (
SearchOrganizationsRequest,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.data.model import CredentialsField, SchemaField
class SearchOrganizationsBlock(Block):
"""Search for organizations in Apollo"""
class Input(BlockSchema):
organization_num_empoloyees_range: list[int] = SchemaField(
organization_num_employees_range: list[int] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
@@ -65,7 +65,7 @@ To find IDs, identify the values for organization_id when you call this endpoint
le=50000,
advanced=True,
)
credentials: ApolloCredentialsInput = SchemaField(
credentials: ApolloCredentialsInput = CredentialsField(
description="Apollo credentials",
)

View File

@@ -1,3 +1,5 @@
import asyncio
from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS,
@@ -8,11 +10,12 @@ from backend.blocks.apollo._auth import (
from backend.blocks.apollo.models import (
Contact,
ContactEmailStatuses,
EnrichPersonRequest,
SearchPeopleRequest,
SenorityLevels,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.data.model import CredentialsField, SchemaField
class SearchPeopleBlock(Block):
@@ -77,7 +80,7 @@ class SearchPeopleBlock(Block):
default_factory=list,
advanced=False,
)
organization_num_empoloyees_range: list[int] = SchemaField(
organization_num_employees_range: list[int] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
@@ -90,14 +93,19 @@ class SearchPeopleBlock(Block):
advanced=False,
)
max_results: int = SchemaField(
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
default=100,
description="""The maximum number of results to return. If you don't specify this parameter, the default is 25. Limited to 500 to prevent overspending.""",
default=25,
ge=1,
le=50000,
le=500,
advanced=True,
)
enrich_info: bool = SchemaField(
description="""Whether to enrich contacts with detailed information including real email addresses. This will double the search cost.""",
default=False,
advanced=True,
)
credentials: ApolloCredentialsInput = SchemaField(
credentials: ApolloCredentialsInput = CredentialsField(
description="Apollo credentials",
)
@@ -106,10 +114,6 @@ class SearchPeopleBlock(Block):
description="List of people found",
default_factory=list,
)
person: Contact = SchemaField(
title="Person",
description="Each found person, one at a time",
)
error: str = SchemaField(
description="Error message if the search failed",
default="",
@@ -125,87 +129,6 @@ class SearchPeopleBlock(Block):
test_credentials=TEST_CREDENTIALS,
test_input={"credentials": TEST_CREDENTIALS_INPUT},
test_output=[
(
"person",
Contact(
contact_roles=[],
id="1",
name="John Doe",
first_name="John",
last_name="Doe",
linkedin_url="https://www.linkedin.com/in/johndoe",
title="Software Engineer",
organization_name="Google",
organization_id="123456",
contact_stage_id="1",
owner_id="1",
creator_id="1",
person_id="1",
email_needs_tickling=True,
source="apollo",
original_source="apollo",
headline="Software Engineer",
photo_url="https://www.linkedin.com/in/johndoe",
present_raw_address="123 Main St, Anytown, USA",
linkededin_uid="123456",
extrapolated_email_confidence=0.8,
salesforce_id="123456",
salesforce_lead_id="123456",
salesforce_contact_id="123456",
saleforce_account_id="123456",
crm_owner_id="123456",
created_at="2021-01-01",
emailer_campaign_ids=[],
direct_dial_status="active",
direct_dial_enrichment_failed_at="2021-01-01",
email_status="active",
email_source="apollo",
account_id="123456",
last_activity_date="2021-01-01",
hubspot_vid="123456",
hubspot_company_id="123456",
crm_id="123456",
sanitized_phone="123456",
merged_crm_ids="123456",
updated_at="2021-01-01",
queued_for_crm_push=True,
suggested_from_rule_engine_config_id="123456",
email_unsubscribed=None,
label_ids=[],
has_pending_email_arcgate_request=True,
has_email_arcgate_request=True,
existence_level=None,
email=None,
email_from_customer=None,
typed_custom_fields=[],
custom_field_errors=None,
salesforce_record_id=None,
crm_record_url=None,
email_status_unavailable_reason=None,
email_true_status=None,
updated_email_true_status=True,
contact_rule_config_statuses=[],
source_display_name=None,
twitter_url=None,
contact_campaign_statuses=[],
state=None,
city=None,
country=None,
account=None,
contact_emails=[],
organization=None,
employment_history=[],
time_zone=None,
intent_strength=None,
show_intent=True,
phone_numbers=[],
account_phone_note=None,
free_domain=True,
is_likely_to_engage=True,
email_domain_catchall=True,
contact_job_change_event=None,
),
),
(
"people",
[
@@ -380,6 +303,34 @@ class SearchPeopleBlock(Block):
client = ApolloClient(credentials)
return await client.search_people(query)
@staticmethod
async def enrich_person(
query: EnrichPersonRequest, credentials: ApolloCredentials
) -> Contact:
client = ApolloClient(credentials)
return await client.enrich_person(query)
@staticmethod
def merge_contact_data(original: Contact, enriched: Contact) -> Contact:
"""
Merge contact data from original search with enriched data.
Enriched data complements original data, only filling in missing values.
"""
merged_data = original.model_dump()
enriched_data = enriched.model_dump()
# Only update fields that are None, empty string, empty list, or default values in original
for key, enriched_value in enriched_data.items():
# Skip if enriched value is None, empty string, or empty list
if enriched_value is None or enriched_value == "" or enriched_value == []:
continue
# Update if original value is None, empty string, empty list, or zero
if enriched_value:
merged_data[key] = enriched_value
return Contact(**merged_data)
async def run(
self,
input_data: Input,
@@ -390,6 +341,23 @@ class SearchPeopleBlock(Block):
query = SearchPeopleRequest(**input_data.model_dump())
people = await self.search_people(query, credentials)
for person in people:
yield "person", person
# Enrich with detailed info if requested
if input_data.enrich_info:
async def enrich_or_fallback(person: Contact):
try:
enrich_query = EnrichPersonRequest(person_id=person.id)
enriched_person = await self.enrich_person(
enrich_query, credentials
)
# Merge enriched data with original data, complementing instead of replacing
return self.merge_contact_data(person, enriched_person)
except Exception:
return person # If enrichment fails, use original person data
people = await asyncio.gather(
*(enrich_or_fallback(person) for person in people)
)
yield "people", people

View File

@@ -0,0 +1,138 @@
from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
ApolloCredentials,
ApolloCredentialsInput,
)
from backend.blocks.apollo.models import Contact, EnrichPersonRequest
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import CredentialsField, SchemaField
class GetPersonDetailBlock(Block):
"""Get detailed person data with Apollo API, including email reveal"""
class Input(BlockSchema):
person_id: str = SchemaField(
description="Apollo person ID to enrich (most accurate method)",
default="",
advanced=False,
)
first_name: str = SchemaField(
description="First name of the person to enrich",
default="",
advanced=False,
)
last_name: str = SchemaField(
description="Last name of the person to enrich",
default="",
advanced=False,
)
name: str = SchemaField(
description="Full name of the person to enrich (alternative to first_name + last_name)",
default="",
advanced=False,
)
email: str = SchemaField(
description="Known email address of the person (helps with matching)",
default="",
advanced=False,
)
domain: str = SchemaField(
description="Company domain of the person (e.g., 'google.com')",
default="",
advanced=False,
)
company: str = SchemaField(
description="Company name of the person",
default="",
advanced=False,
)
linkedin_url: str = SchemaField(
description="LinkedIn URL of the person",
default="",
advanced=False,
)
organization_id: str = SchemaField(
description="Apollo organization ID of the person's company",
default="",
advanced=True,
)
title: str = SchemaField(
description="Job title of the person to enrich",
default="",
advanced=True,
)
credentials: ApolloCredentialsInput = CredentialsField(
description="Apollo credentials",
)
class Output(BlockSchema):
contact: Contact = SchemaField(
description="Enriched contact information",
)
error: str = SchemaField(
description="Error message if enrichment failed",
default="",
)
def __init__(self):
super().__init__(
id="3b18d46c-3db6-42ae-a228-0ba441bdd176",
description="Get detailed person data with Apollo API, including email reveal",
categories={BlockCategory.SEARCH},
input_schema=GetPersonDetailBlock.Input,
output_schema=GetPersonDetailBlock.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"first_name": "John",
"last_name": "Doe",
"company": "Google",
},
test_output=[
(
"contact",
Contact(
id="1",
name="John Doe",
first_name="John",
last_name="Doe",
email="john.doe@gmail.com",
title="Software Engineer",
organization_name="Google",
linkedin_url="https://www.linkedin.com/in/johndoe",
),
),
],
test_mock={
"enrich_person": lambda query, credentials: Contact(
id="1",
name="John Doe",
first_name="John",
last_name="Doe",
email="john.doe@gmail.com",
title="Software Engineer",
organization_name="Google",
linkedin_url="https://www.linkedin.com/in/johndoe",
)
},
)
@staticmethod
async def enrich_person(
query: EnrichPersonRequest, credentials: ApolloCredentials
) -> Contact:
client = ApolloClient(credentials)
return await client.enrich_person(query)
async def run(
self,
input_data: Input,
*,
credentials: ApolloCredentials,
**kwargs,
) -> BlockOutput:
query = EnrichPersonRequest(**input_data.model_dump())
yield "contact", await self.enrich_person(query, credentials)

View File

@@ -3,6 +3,7 @@ from typing import Any
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.type import convert
class ComparisonOperator(Enum):
@@ -181,7 +182,23 @@ class IfInputMatchesBlock(Block):
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
if input_data.input == input_data.value or input_data.input is input_data.value:
# If input_data.value is not matching input_data.input, convert value to type of input
if (
input_data.input != input_data.value
and input_data.input is not input_data.value
):
try:
# Only attempt conversion if input is not None and value is not None
if input_data.input is not None and input_data.value is not None:
input_type = type(input_data.input)
# Avoid converting if input_type is Any or object
if input_type not in (Any, object):
input_data.value = convert(input_data.value, input_type)
except Exception:
pass # If conversion fails, just leave value as is
if input_data.input == input_data.value:
yield "result", True
yield "yes_output", input_data.yes_value
else:

View File

@@ -3,11 +3,19 @@ import logging
from enum import Enum
from io import BytesIO
from pathlib import Path
from typing import Literal
import aiofiles
from pydantic import SecretStr
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.data.model import (
CredentialsField,
CredentialsMetaInput,
HostScopedCredentials,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import (
MediaFileType,
get_exec_file_path,
@@ -19,6 +27,30 @@ from backend.util.request import Requests
logger = logging.getLogger(name=__name__)
# Host-scoped credentials for HTTP requests
HttpCredentials = CredentialsMetaInput[
Literal[ProviderName.HTTP], Literal["host_scoped"]
]
TEST_CREDENTIALS = HostScopedCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="http",
host="api.example.com",
headers={
"Authorization": SecretStr("Bearer test-token"),
"X-API-Key": SecretStr("test-api-key"),
},
title="Mock HTTP Host-Scoped Credentials",
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
class HttpMethod(Enum):
GET = "GET"
POST = "POST"
@@ -169,3 +201,62 @@ class SendWebRequestBlock(Block):
yield "client_error", result
else:
yield "server_error", result
class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
class Input(SendWebRequestBlock.Input):
credentials: HttpCredentials = CredentialsField(
description="HTTP host-scoped credentials for automatic header injection",
discriminator="url",
)
def __init__(self):
Block.__init__(
self,
id="fff86bcd-e001-4bad-a7f6-2eae4720c8dc",
description="Make an authenticated HTTP request with host-scoped credentials (JSON / form / multipart).",
categories={BlockCategory.OUTPUT},
input_schema=SendAuthenticatedWebRequestBlock.Input,
output_schema=SendWebRequestBlock.Output,
test_credentials=TEST_CREDENTIALS,
)
async def run( # type: ignore[override]
self,
input_data: Input,
*,
graph_exec_id: str,
credentials: HostScopedCredentials,
**kwargs,
) -> BlockOutput:
# Create SendWebRequestBlock.Input from our input (removing credentials field)
base_input = SendWebRequestBlock.Input(
url=input_data.url,
method=input_data.method,
headers=input_data.headers,
json_format=input_data.json_format,
body=input_data.body,
files_name=input_data.files_name,
files=input_data.files,
)
# Apply host-scoped credentials to headers
extra_headers = {}
if credentials.matches_url(input_data.url):
logger.debug(
f"Applying host-scoped credentials {credentials.id} for URL {input_data.url}"
)
extra_headers.update(credentials.get_headers_dict())
else:
logger.warning(
f"Host-scoped credentials {credentials.id} do not match URL {input_data.url}"
)
# Merge with user-provided headers (user headers take precedence)
base_input.headers = {**extra_headers, **input_data.headers}
# Use parent class run method
async for output_name, output_data in super().run(
base_input, graph_exec_id=graph_exec_id, **kwargs
):
yield output_name, output_data

View File

@@ -40,7 +40,7 @@ LLMProviderName = Literal[
AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]]
TEST_CREDENTIALS = APIKeyCredentials(
id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
id="769f6af7-820b-4d5d-9b7a-ab82bbc165f",
provider="openai",
api_key=SecretStr("mock-openai-api-key"),
title="Mock OpenAI API key",

View File

@@ -13,7 +13,7 @@ from backend.data.model import (
from backend.integrations.providers import ProviderName
TEST_CREDENTIALS = APIKeyCredentials(
id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
id="8cc8b2c5-d3e4-4b1c-84ad-e1e9fe2a0122",
provider="mem0",
api_key=SecretStr("mock-mem0-api-key"),
title="Mock Mem0 API key",

View File

@@ -17,7 +17,7 @@ from backend.blocks.smartlead.models import (
Sequence,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.data.model import CredentialsField, SchemaField
class CreateCampaignBlock(Block):
@@ -27,7 +27,7 @@ class CreateCampaignBlock(Block):
name: str = SchemaField(
description="The name of the campaign",
)
credentials: SmartLeadCredentialsInput = SchemaField(
credentials: SmartLeadCredentialsInput = CredentialsField(
description="SmartLead credentials",
)
@@ -119,7 +119,7 @@ class AddLeadToCampaignBlock(Block):
description="Settings for lead upload",
default=LeadUploadSettings(),
)
credentials: SmartLeadCredentialsInput = SchemaField(
credentials: SmartLeadCredentialsInput = CredentialsField(
description="SmartLead credentials",
)
@@ -251,7 +251,7 @@ class SaveCampaignSequencesBlock(Block):
default_factory=list,
advanced=False,
)
credentials: SmartLeadCredentialsInput = SchemaField(
credentials: SmartLeadCredentialsInput = CredentialsField(
description="SmartLead credentials",
)

View File

@@ -0,0 +1,485 @@
"""Comprehensive tests for HTTP block with HostScopedCredentials functionality."""
from typing import cast
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import SecretStr
from backend.blocks.http import (
HttpCredentials,
HttpMethod,
SendAuthenticatedWebRequestBlock,
)
from backend.data.model import HostScopedCredentials
from backend.util.request import Response
class TestHttpBlockWithHostScopedCredentials:
"""Test suite for HTTP block integration with HostScopedCredentials."""
@pytest.fixture
def http_block(self):
"""Create an HTTP block instance."""
return SendAuthenticatedWebRequestBlock()
@pytest.fixture
def mock_response(self):
"""Mock a successful HTTP response."""
response = MagicMock(spec=Response)
response.status = 200
response.headers = {"content-type": "application/json"}
response.json.return_value = {"success": True, "data": "test"}
return response
@pytest.fixture
def exact_match_credentials(self):
"""Create host-scoped credentials for exact domain matching."""
return HostScopedCredentials(
provider="http",
host="api.example.com",
headers={
"Authorization": SecretStr("Bearer exact-match-token"),
"X-API-Key": SecretStr("api-key-123"),
},
title="Exact Match API Credentials",
)
@pytest.fixture
def wildcard_credentials(self):
"""Create host-scoped credentials with wildcard pattern."""
return HostScopedCredentials(
provider="http",
host="*.github.com",
headers={
"Authorization": SecretStr("token ghp_wildcard123"),
},
title="GitHub Wildcard Credentials",
)
@pytest.fixture
def non_matching_credentials(self):
"""Create credentials that don't match test URLs."""
return HostScopedCredentials(
provider="http",
host="different.api.com",
headers={
"Authorization": SecretStr("Bearer non-matching-token"),
},
title="Non-matching Credentials",
)
@pytest.mark.asyncio
@patch("backend.blocks.http.Requests")
async def test_http_block_with_exact_host_match(
self,
mock_requests_class,
http_block,
exact_match_credentials,
mock_response,
):
"""Test HTTP block with exact host matching credentials."""
# Setup mocks
mock_requests = AsyncMock()
mock_requests.request.return_value = mock_response
mock_requests_class.return_value = mock_requests
# Prepare input data
input_data = SendAuthenticatedWebRequestBlock.Input(
url="https://api.example.com/data",
method=HttpMethod.GET,
headers={"User-Agent": "test-agent"},
credentials=cast(
HttpCredentials,
{
"id": exact_match_credentials.id,
"provider": "http",
"type": "host_scoped",
"title": exact_match_credentials.title,
},
),
)
# Execute with credentials provided by execution manager
result = []
async for output_name, output_data in http_block.run(
input_data,
credentials=exact_match_credentials,
graph_exec_id="test-exec-id",
):
result.append((output_name, output_data))
# Verify request headers include both credential and user headers
mock_requests.request.assert_called_once()
call_args = mock_requests.request.call_args
expected_headers = {
"Authorization": "Bearer exact-match-token",
"X-API-Key": "api-key-123",
"User-Agent": "test-agent",
}
assert call_args.kwargs["headers"] == expected_headers
# Verify response handling
assert len(result) == 1
assert result[0][0] == "response"
assert result[0][1] == {"success": True, "data": "test"}
@pytest.mark.asyncio
@patch("backend.blocks.http.Requests")
async def test_http_block_with_wildcard_host_match(
self,
mock_requests_class,
http_block,
wildcard_credentials,
mock_response,
):
"""Test HTTP block with wildcard host pattern matching."""
# Setup mocks
mock_requests = AsyncMock()
mock_requests.request.return_value = mock_response
mock_requests_class.return_value = mock_requests
# Test with subdomain that should match *.github.com
input_data = SendAuthenticatedWebRequestBlock.Input(
url="https://api.github.com/user",
method=HttpMethod.GET,
headers={},
credentials=cast(
HttpCredentials,
{
"id": wildcard_credentials.id,
"provider": "http",
"type": "host_scoped",
"title": wildcard_credentials.title,
},
),
)
# Execute with wildcard credentials
result = []
async for output_name, output_data in http_block.run(
input_data,
credentials=wildcard_credentials,
graph_exec_id="test-exec-id",
):
result.append((output_name, output_data))
# Verify wildcard matching works
mock_requests.request.assert_called_once()
call_args = mock_requests.request.call_args
expected_headers = {"Authorization": "token ghp_wildcard123"}
assert call_args.kwargs["headers"] == expected_headers
@pytest.mark.asyncio
@patch("backend.blocks.http.Requests")
async def test_http_block_with_non_matching_credentials(
self,
mock_requests_class,
http_block,
non_matching_credentials,
mock_response,
):
"""Test HTTP block when credentials don't match the target URL."""
# Setup mocks
mock_requests = AsyncMock()
mock_requests.request.return_value = mock_response
mock_requests_class.return_value = mock_requests
# Test with URL that doesn't match the credentials
input_data = SendAuthenticatedWebRequestBlock.Input(
url="https://api.example.com/data",
method=HttpMethod.GET,
headers={"User-Agent": "test-agent"},
credentials=cast(
HttpCredentials,
{
"id": non_matching_credentials.id,
"provider": "http",
"type": "host_scoped",
"title": non_matching_credentials.title,
},
),
)
# Execute with non-matching credentials
result = []
async for output_name, output_data in http_block.run(
input_data,
credentials=non_matching_credentials,
graph_exec_id="test-exec-id",
):
result.append((output_name, output_data))
# Verify only user headers are included (no credential headers)
mock_requests.request.assert_called_once()
call_args = mock_requests.request.call_args
expected_headers = {"User-Agent": "test-agent"}
assert call_args.kwargs["headers"] == expected_headers
@pytest.mark.asyncio
@patch("backend.blocks.http.Requests")
async def test_user_headers_override_credential_headers(
self,
mock_requests_class,
http_block,
exact_match_credentials,
mock_response,
):
"""Test that user-provided headers take precedence over credential headers."""
# Setup mocks
mock_requests = AsyncMock()
mock_requests.request.return_value = mock_response
mock_requests_class.return_value = mock_requests
# Test with user header that conflicts with credential header
input_data = SendAuthenticatedWebRequestBlock.Input(
url="https://api.example.com/data",
method=HttpMethod.POST,
headers={
"Authorization": "Bearer user-override-token", # Should override
"Content-Type": "application/json", # Additional user header
},
credentials=cast(
HttpCredentials,
{
"id": exact_match_credentials.id,
"provider": "http",
"type": "host_scoped",
"title": exact_match_credentials.title,
},
),
)
# Execute with conflicting headers
result = []
async for output_name, output_data in http_block.run(
input_data,
credentials=exact_match_credentials,
graph_exec_id="test-exec-id",
):
result.append((output_name, output_data))
# Verify user headers take precedence
mock_requests.request.assert_called_once()
call_args = mock_requests.request.call_args
expected_headers = {
"X-API-Key": "api-key-123", # From credentials
"Authorization": "Bearer user-override-token", # User override
"Content-Type": "application/json", # User header
}
assert call_args.kwargs["headers"] == expected_headers
@pytest.mark.asyncio
@patch("backend.blocks.http.Requests")
async def test_auto_discovered_credentials_flow(
self,
mock_requests_class,
http_block,
mock_response,
):
"""Test the auto-discovery flow where execution manager provides matching credentials."""
# Create auto-discovered credentials
auto_discovered_creds = HostScopedCredentials(
provider="http",
host="*.example.com",
headers={
"Authorization": SecretStr("Bearer auto-discovered-token"),
},
title="Auto-discovered Credentials",
)
# Setup mocks
mock_requests = AsyncMock()
mock_requests.request.return_value = mock_response
mock_requests_class.return_value = mock_requests
# Test with empty credentials field (triggers auto-discovery)
input_data = SendAuthenticatedWebRequestBlock.Input(
url="https://api.example.com/data",
method=HttpMethod.GET,
headers={},
credentials=cast(
HttpCredentials,
{
"id": "", # Empty ID triggers auto-discovery in execution manager
"provider": "http",
"type": "host_scoped",
"title": "",
},
),
)
# Execute with auto-discovered credentials provided by execution manager
result = []
async for output_name, output_data in http_block.run(
input_data,
credentials=auto_discovered_creds, # Execution manager found these
graph_exec_id="test-exec-id",
):
result.append((output_name, output_data))
# Verify auto-discovered credentials were applied
mock_requests.request.assert_called_once()
call_args = mock_requests.request.call_args
expected_headers = {"Authorization": "Bearer auto-discovered-token"}
assert call_args.kwargs["headers"] == expected_headers
# Verify response handling
assert len(result) == 1
assert result[0][0] == "response"
assert result[0][1] == {"success": True, "data": "test"}
@pytest.mark.asyncio
@patch("backend.blocks.http.Requests")
async def test_multiple_header_credentials(
self,
mock_requests_class,
http_block,
mock_response,
):
"""Test credentials with multiple headers are all applied."""
# Create credentials with multiple headers
multi_header_creds = HostScopedCredentials(
provider="http",
host="api.example.com",
headers={
"Authorization": SecretStr("Bearer multi-token"),
"X-API-Key": SecretStr("api-key-456"),
"X-Client-ID": SecretStr("client-789"),
"X-Custom-Header": SecretStr("custom-value"),
},
title="Multi-Header Credentials",
)
# Setup mocks
mock_requests = AsyncMock()
mock_requests.request.return_value = mock_response
mock_requests_class.return_value = mock_requests
# Test with credentials containing multiple headers
input_data = SendAuthenticatedWebRequestBlock.Input(
url="https://api.example.com/data",
method=HttpMethod.GET,
headers={"User-Agent": "test-agent"},
credentials=cast(
HttpCredentials,
{
"id": multi_header_creds.id,
"provider": "http",
"type": "host_scoped",
"title": multi_header_creds.title,
},
),
)
# Execute with multi-header credentials
result = []
async for output_name, output_data in http_block.run(
input_data,
credentials=multi_header_creds,
graph_exec_id="test-exec-id",
):
result.append((output_name, output_data))
# Verify all headers are included
mock_requests.request.assert_called_once()
call_args = mock_requests.request.call_args
expected_headers = {
"Authorization": "Bearer multi-token",
"X-API-Key": "api-key-456",
"X-Client-ID": "client-789",
"X-Custom-Header": "custom-value",
"User-Agent": "test-agent",
}
assert call_args.kwargs["headers"] == expected_headers
@pytest.mark.asyncio
@patch("backend.blocks.http.Requests")
async def test_credentials_with_complex_url_patterns(
self,
mock_requests_class,
http_block,
mock_response,
):
"""Test credentials matching various URL patterns."""
# Test cases for different URL patterns
test_cases = [
{
"host_pattern": "api.example.com",
"test_url": "https://api.example.com/v1/users",
"should_match": True,
},
{
"host_pattern": "*.example.com",
"test_url": "https://api.example.com/v1/users",
"should_match": True,
},
{
"host_pattern": "*.example.com",
"test_url": "https://subdomain.example.com/data",
"should_match": True,
},
{
"host_pattern": "api.example.com",
"test_url": "https://api.different.com/data",
"should_match": False,
},
]
# Setup mocks
mock_requests = AsyncMock()
mock_requests.request.return_value = mock_response
mock_requests_class.return_value = mock_requests
for case in test_cases:
# Reset mock for each test case
mock_requests.reset_mock()
# Create credentials for this test case
test_creds = HostScopedCredentials(
provider="http",
host=case["host_pattern"],
headers={
"Authorization": SecretStr(f"Bearer {case['host_pattern']}-token"),
},
title=f"Credentials for {case['host_pattern']}",
)
input_data = SendAuthenticatedWebRequestBlock.Input(
url=case["test_url"],
method=HttpMethod.GET,
headers={"User-Agent": "test-agent"},
credentials=cast(
HttpCredentials,
{
"id": test_creds.id,
"provider": "http",
"type": "host_scoped",
"title": test_creds.title,
},
),
)
# Execute with test credentials
result = []
async for output_name, output_data in http_block.run(
input_data,
credentials=test_creds,
graph_exec_id="test-exec-id",
):
result.append((output_name, output_data))
# Verify headers based on whether pattern should match
mock_requests.request.assert_called_once()
call_args = mock_requests.request.call_args
headers = call_args.kwargs["headers"]
if case["should_match"]:
# Should include both user and credential headers
expected_auth = f"Bearer {case['host_pattern']}-token"
assert headers["Authorization"] == expected_auth
assert headers["User-Agent"] == "test-agent"
else:
# Should only include user headers
assert "Authorization" not in headers
assert headers["User-Agent"] == "test-agent"

View File

@@ -25,12 +25,7 @@ async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Grap
async def create_credentials(s: SpinTestServer, u: User):
provider = ProviderName.OPENAI
credentials = llm.TEST_CREDENTIALS
try:
await s.agent_server.test_create_credentials(u.id, provider, credentials)
except Exception:
# ValueErrors is raised trying to recreate the same credentials
# so hidding the error
pass
return await s.agent_server.test_create_credentials(u.id, provider, credentials)
async def execute_graph(
@@ -60,19 +55,18 @@ async def execute_graph(
return graph_exec_id
@pytest.mark.skip()
@pytest.mark.asyncio(loop_scope="session")
async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
test_user = await create_test_user()
test_tool_graph = await create_graph(server, create_test_graph(), test_user)
await create_credentials(server, test_user)
creds = await create_credentials(server, test_user)
nodes = [
graph.Node(
block_id=SmartDecisionMakerBlock().id,
input_default={
"prompt": "Hello, World!",
"credentials": llm.TEST_CREDENTIALS_INPUT,
"credentials": creds,
},
),
graph.Node(
@@ -110,80 +104,18 @@ async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
test_graph = await create_graph(server, test_graph, test_user)
@pytest.mark.skip()
@pytest.mark.asyncio(loop_scope="session")
async def test_graph_validation_with_tool_nodes_raises_error(server: SpinTestServer):
test_user = await create_test_user()
test_tool_graph = await create_graph(server, create_test_graph(), test_user)
await create_credentials(server, test_user)
nodes = [
graph.Node(
block_id=SmartDecisionMakerBlock().id,
input_default={
"prompt": "Hello, World!",
"credentials": llm.TEST_CREDENTIALS_INPUT,
},
),
graph.Node(
block_id=AgentExecutorBlock().id,
input_default={
"graph_id": test_tool_graph.id,
"graph_version": test_tool_graph.version,
"input_schema": test_tool_graph.input_schema,
"output_schema": test_tool_graph.output_schema,
},
),
graph.Node(
block_id=StoreValueBlock().id,
),
]
links = [
graph.Link(
source_id=nodes[0].id,
sink_id=nodes[1].id,
source_name="tools_^_sample_tool_input_1",
sink_name="input_1",
),
graph.Link(
source_id=nodes[0].id,
sink_id=nodes[1].id,
source_name="tools_^_sample_tool_input_2",
sink_name="input_2",
),
graph.Link(
source_id=nodes[0].id,
sink_id=nodes[2].id,
source_name="tools_^_store_value_input",
sink_name="input",
),
]
test_graph = graph.Graph(
name="TestGraph",
description="Test graph",
nodes=nodes,
links=links,
)
with pytest.raises(ValueError):
test_graph = await create_graph(server, test_graph, test_user)
@pytest.mark.skip()
@pytest.mark.asyncio(loop_scope="session")
async def test_smart_decision_maker_function_signature(server: SpinTestServer):
test_user = await create_test_user()
test_tool_graph = await create_graph(server, create_test_graph(), test_user)
await create_credentials(server, test_user)
creds = await create_credentials(server, test_user)
nodes = [
graph.Node(
block_id=SmartDecisionMakerBlock().id,
input_default={
"prompt": "Hello, World!",
"credentials": llm.TEST_CREDENTIALS_INPUT,
"credentials": creds,
},
),
graph.Node(

View File

@@ -15,7 +15,7 @@ from backend.blocks.zerobounce._auth import (
ZeroBounceCredentialsInput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.data.model import CredentialsField, SchemaField
class Response(BaseModel):
@@ -90,7 +90,7 @@ class ValidateEmailsBlock(Block):
description="IP address to validate",
default="",
)
credentials: ZeroBounceCredentialsInput = SchemaField(
credentials: ZeroBounceCredentialsInput = CredentialsField(
description="ZeroBounce credentials",
)

View File

@@ -6,6 +6,8 @@ from dotenv import load_dotenv
from backend.util.logging import configure_logging
os.environ["ENABLE_AUTH"] = "false"
load_dotenv()
# NOTE: You can run tests like with the --log-cli-level=INFO to see the logs

View File

@@ -78,6 +78,7 @@ class BlockCategory(Enum):
PRODUCTIVITY = "Block that helps with productivity"
ISSUE_TRACKING = "Block that helps with issue tracking"
MULTIMEDIA = "Block that interacts with multimedia content"
MARKETING = "Block that helps with marketing"
def dict(self) -> dict[str, str]:
return {"category": self.name, "description": self.value}

View File

@@ -4,6 +4,7 @@ from backend.blocks.ai_music_generator import AIMusicGeneratorBlock
from backend.blocks.ai_shortform_video_block import AIShortformVideoCreatorBlock
from backend.blocks.apollo.organization import SearchOrganizationsBlock
from backend.blocks.apollo.people import SearchPeopleBlock
from backend.blocks.apollo.person import GetPersonDetailBlock
from backend.blocks.flux_kontext import AIImageEditorBlock, FluxKontextModelName
from backend.blocks.ideogram import IdeogramModelBlock
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
@@ -362,7 +363,31 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
],
SearchPeopleBlock: [
BlockCost(
cost_amount=2,
cost_amount=10,
cost_filter={
"enrich_info": False,
"credentials": {
"id": apollo_credentials.id,
"provider": apollo_credentials.provider,
"type": apollo_credentials.type,
},
},
),
BlockCost(
cost_amount=20,
cost_filter={
"enrich_info": True,
"credentials": {
"id": apollo_credentials.id,
"provider": apollo_credentials.provider,
"type": apollo_credentials.type,
},
},
),
],
GetPersonDetailBlock: [
BlockCost(
cost_amount=1,
cost_filter={
"credentials": {
"id": apollo_credentials.id,

View File

@@ -7,7 +7,7 @@ from pydantic import BaseModel
from redis.asyncio.client import PubSub as AsyncPubSub
from redis.client import PubSub
from backend.data import redis
from backend.data import redis_client as redis
logger = logging.getLogger(__name__)

View File

@@ -48,6 +48,7 @@ from .block import (
get_webhook_block_ids,
)
from .db import BaseDbModel
from .event_bus import AsyncRedisEventBus, RedisEventBus
from .includes import (
EXECUTION_RESULT_INCLUDE,
EXECUTION_RESULT_ORDER,
@@ -55,7 +56,6 @@ from .includes import (
graph_execution_include,
)
from .model import GraphExecutionStats, NodeExecutionStats
from .queue import AsyncRedisEventBus, RedisEventBus
T = TypeVar("T")

View File

@@ -27,6 +27,7 @@ from backend.data.model import (
CredentialsMetaInput,
is_credentials_field_name,
)
from backend.integrations.providers import ProviderName
from backend.util import type as type_utils
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
@@ -243,6 +244,8 @@ class Graph(BaseGraph):
for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
if field.provider != other_field.provider:
continue
if ProviderName.HTTP in field.provider:
continue
# If this happens, that means a block implementation probably needs
# to be updated.
@@ -264,6 +267,7 @@ class Graph(BaseGraph):
required_scopes=set(field_info.required_scopes or []),
discriminator=field_info.discriminator,
discriminator_mapping=field_info.discriminator_mapping,
discriminator_values=field_info.discriminator_values,
),
)
for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
@@ -282,37 +286,40 @@ class Graph(BaseGraph):
Returns:
dict[aggregated_field_key, tuple(
CredentialsFieldInfo: A spec for one aggregated credentials field
(now includes discriminator_values from matching nodes)
set[(node_id, field_name)]: Node credentials fields that are
compatible with this aggregated field spec
)]
"""
return {
"_".join(sorted(agg_field_info.provider))
+ "_"
+ "_".join(sorted(agg_field_info.supported_types))
+ "_credentials": (agg_field_info, node_fields)
for agg_field_info, node_fields in CredentialsFieldInfo.combine(
*(
(
# Apply discrimination before aggregating credentials inputs
(
field_info.discriminate(
node.input_default[field_info.discriminator]
)
if (
field_info.discriminator
and node.input_default.get(field_info.discriminator)
)
else field_info
),
(node.id, field_name),
# First collect all credential field data with input defaults
node_credential_data = []
for graph in [self] + self.sub_graphs:
for node in graph.nodes:
for (
field_name,
field_info,
) in node.block.input_schema.get_credentials_fields_info().items():
discriminator = field_info.discriminator
if not discriminator:
node_credential_data.append((field_info, (node.id, field_name)))
continue
discriminator_value = node.input_default.get(discriminator)
if discriminator_value is None:
node_credential_data.append((field_info, (node.id, field_name)))
continue
discriminated_info = field_info.discriminate(discriminator_value)
discriminated_info.discriminator_values.add(discriminator_value)
node_credential_data.append(
(discriminated_info, (node.id, field_name))
)
for graph in [self] + self.sub_graphs
for node in graph.nodes
for field_name, field_info in node.block.input_schema.get_credentials_fields_info().items()
)
)
}
# Combine credential field info (this will merge discriminator_values automatically)
return CredentialsFieldInfo.combine(*node_credential_data)
class GraphModel(Graph):
@@ -391,7 +398,9 @@ class GraphModel(Graph):
continue
node.input_default["user_id"] = user_id
node.input_default.setdefault("inputs", {})
if (graph_id := node.input_default.get("graph_id")) in graph_id_map:
if (
graph_id := node.input_default.get("graph_id")
) and graph_id in graph_id_map:
node.input_default["graph_id"] = graph_id_map[graph_id]
def validate_graph(

View File

@@ -11,8 +11,8 @@ from prisma.types import (
)
from pydantic import Field, computed_field
from backend.data.event_bus import AsyncRedisEventBus
from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
from backend.data.queue import AsyncRedisEventBus
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.utils import webhook_ingress_url
from backend.server.v2.library.model import LibraryAgentPreset

View File

@@ -14,11 +14,12 @@ from typing import (
Generic,
Literal,
Optional,
Sequence,
TypedDict,
TypeVar,
cast,
get_args,
)
from urllib.parse import urlparse
from uuid import uuid4
from prisma.enums import CreditTransactionType
@@ -240,13 +241,65 @@ class UserPasswordCredentials(_BaseCredentials):
return f"Basic {base64.b64encode(f'{self.username.get_secret_value()}:{self.password.get_secret_value()}'.encode()).decode()}"
class HostScopedCredentials(_BaseCredentials):
type: Literal["host_scoped"] = "host_scoped"
host: str = Field(description="The host/URI pattern to match against request URLs")
headers: dict[str, SecretStr] = Field(
description="Key-value header map to add to matching requests",
default_factory=dict,
)
def _extract_headers(self, headers: dict[str, SecretStr]) -> dict[str, str]:
"""Helper to extract secret values from headers."""
return {key: value.get_secret_value() for key, value in headers.items()}
@field_serializer("headers")
def serialize_headers(self, headers: dict[str, SecretStr]) -> dict[str, str]:
"""Serialize headers by extracting secret values."""
return self._extract_headers(headers)
def get_headers_dict(self) -> dict[str, str]:
"""Get headers with secret values extracted."""
return self._extract_headers(self.headers)
def auth_header(self) -> str:
"""Get authorization header for backward compatibility."""
auth_headers = self.get_headers_dict()
if "Authorization" in auth_headers:
return auth_headers["Authorization"]
return ""
def matches_url(self, url: str) -> bool:
"""Check if this credential should be applied to the given URL."""
parsed_url = urlparse(url)
# Extract hostname without port
request_host = parsed_url.hostname
if not request_host:
return False
# Simple host matching - exact match or wildcard subdomain match
if self.host == request_host:
return True
# Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
if self.host.startswith("*."):
domain = self.host[2:] # Remove "*."
return request_host.endswith(f".{domain}") or request_host == domain
return False
Credentials = Annotated[
OAuth2Credentials | APIKeyCredentials | UserPasswordCredentials,
OAuth2Credentials
| APIKeyCredentials
| UserPasswordCredentials
| HostScopedCredentials,
Field(discriminator="type"),
]
CredentialsType = Literal["api_key", "oauth2", "user_password"]
CredentialsType = Literal["api_key", "oauth2", "user_password", "host_scoped"]
class OAuthState(BaseModel):
@@ -320,15 +373,29 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
)
@staticmethod
def _add_json_schema_extra(schema, cls: CredentialsMetaInput):
schema["credentials_provider"] = cls.allowed_providers()
schema["credentials_types"] = cls.allowed_cred_types()
def _add_json_schema_extra(schema: dict, model_class: type):
# Use model_class for allowed_providers/cred_types
if hasattr(model_class, "allowed_providers") and hasattr(
model_class, "allowed_cred_types"
):
schema["credentials_provider"] = model_class.allowed_providers()
schema["credentials_types"] = model_class.allowed_cred_types()
# Do not return anything, just mutate schema in place
model_config = ConfigDict(
json_schema_extra=_add_json_schema_extra, # type: ignore
)
def _extract_host_from_url(url: str) -> str:
"""Extract host from URL for grouping host-scoped credentials."""
try:
parsed = urlparse(url)
return parsed.hostname or url
except Exception:
return ""
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
# TODO: move discrimination mechanism out of CredentialsField (frontend + backend)
provider: frozenset[CP] = Field(..., alias="credentials_provider")
@@ -336,11 +403,12 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
required_scopes: Optional[frozenset[str]] = Field(None, alias="credentials_scopes")
discriminator: Optional[str] = None
discriminator_mapping: Optional[dict[str, CP]] = None
discriminator_values: set[Any] = Field(default_factory=set)
@classmethod
def combine(
cls, *fields: tuple[CredentialsFieldInfo[CP, CT], T]
) -> Sequence[tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
) -> dict[str, tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
"""
Combines multiple CredentialsFieldInfo objects into as few as possible.
@@ -358,22 +426,36 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
the set of keys of the respective original items that were grouped together.
"""
if not fields:
return []
return {}
# Group fields by their provider and supported_types
# For HTTP host-scoped credentials, also group by host
grouped_fields: defaultdict[
tuple[frozenset[CP], frozenset[CT]],
list[tuple[T, CredentialsFieldInfo[CP, CT]]],
] = defaultdict(list)
for field, key in fields:
group_key = (frozenset(field.provider), frozenset(field.supported_types))
if field.provider == frozenset([ProviderName.HTTP]):
# HTTP host-scoped credentials can have different hosts that reqires different credential sets.
# Group by host extracted from the URL
providers = frozenset(
[cast(CP, "http")]
+ [
cast(CP, _extract_host_from_url(str(value)))
for value in field.discriminator_values
]
)
else:
providers = frozenset(field.provider)
group_key = (providers, frozenset(field.supported_types))
grouped_fields[group_key].append((key, field))
# Combine fields within each group
result: list[tuple[CredentialsFieldInfo[CP, CT], set[T]]] = []
result: dict[str, tuple[CredentialsFieldInfo[CP, CT], set[T]]] = {}
for group in grouped_fields.values():
for key, group in grouped_fields.items():
# Start with the first field in the group
_, combined = group[0]
@@ -386,18 +468,32 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
if field.required_scopes:
all_scopes.update(field.required_scopes)
# Create a new combined field
result.append(
(
CredentialsFieldInfo[CP, CT](
credentials_provider=combined.provider,
credentials_types=combined.supported_types,
credentials_scopes=frozenset(all_scopes) or None,
discriminator=combined.discriminator,
discriminator_mapping=combined.discriminator_mapping,
),
combined_keys,
)
# Combine discriminator_values from all fields in the group (removing duplicates)
all_discriminator_values = []
for _, field in group:
for value in field.discriminator_values:
if value not in all_discriminator_values:
all_discriminator_values.append(value)
# Generate the key for the combined result
providers_key, supported_types_key = key
group_key = (
"-".join(sorted(providers_key))
+ "_"
+ "-".join(sorted(supported_types_key))
+ "_credentials"
)
result[group_key] = (
CredentialsFieldInfo[CP, CT](
credentials_provider=combined.provider,
credentials_types=combined.supported_types,
credentials_scopes=frozenset(all_scopes) or None,
discriminator=combined.discriminator,
discriminator_mapping=combined.discriminator_mapping,
discriminator_values=set(all_discriminator_values),
),
combined_keys,
)
return result
@@ -406,11 +502,15 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
if not (self.discriminator and self.discriminator_mapping):
return self
discriminator_value = self.discriminator_mapping[discriminator_value]
return CredentialsFieldInfo(
credentials_provider=frozenset([discriminator_value]),
credentials_provider=frozenset(
[self.discriminator_mapping[discriminator_value]]
),
credentials_types=self.supported_types,
credentials_scopes=self.required_scopes,
discriminator=self.discriminator,
discriminator_mapping=self.discriminator_mapping,
discriminator_values=self.discriminator_values,
)
@@ -419,6 +519,7 @@ def CredentialsField(
*,
discriminator: Optional[str] = None,
discriminator_mapping: Optional[dict[str, Any]] = None,
discriminator_values: Optional[set[Any]] = None,
title: Optional[str] = None,
description: Optional[str] = None,
**kwargs,
@@ -434,6 +535,7 @@ def CredentialsField(
"credentials_scopes": list(required_scopes) or None,
"discriminator": discriminator,
"discriminator_mapping": discriminator_mapping,
"discriminator_values": discriminator_values,
}.items()
if v is not None
}

View File

@@ -0,0 +1,143 @@
import pytest
from pydantic import SecretStr
from backend.data.model import HostScopedCredentials
class TestHostScopedCredentials:
def test_host_scoped_credentials_creation(self):
"""Test creating HostScopedCredentials with required fields."""
creds = HostScopedCredentials(
provider="custom",
host="api.example.com",
headers={
"Authorization": SecretStr("Bearer secret-token"),
"X-API-Key": SecretStr("api-key-123"),
},
title="Example API Credentials",
)
assert creds.type == "host_scoped"
assert creds.provider == "custom"
assert creds.host == "api.example.com"
assert creds.title == "Example API Credentials"
assert len(creds.headers) == 2
assert "Authorization" in creds.headers
assert "X-API-Key" in creds.headers
def test_get_headers_dict(self):
"""Test getting headers with secret values extracted."""
creds = HostScopedCredentials(
provider="custom",
host="api.example.com",
headers={
"Authorization": SecretStr("Bearer secret-token"),
"X-Custom-Header": SecretStr("custom-value"),
},
)
headers_dict = creds.get_headers_dict()
assert headers_dict == {
"Authorization": "Bearer secret-token",
"X-Custom-Header": "custom-value",
}
def test_matches_url_exact_host(self):
"""Test URL matching with exact host match."""
creds = HostScopedCredentials(
provider="custom",
host="api.example.com",
headers={"Authorization": SecretStr("Bearer token")},
)
assert creds.matches_url("https://api.example.com/v1/data")
assert creds.matches_url("http://api.example.com/endpoint")
assert not creds.matches_url("https://other.example.com/v1/data")
assert not creds.matches_url("https://subdomain.api.example.com/v1/data")
def test_matches_url_wildcard_subdomain(self):
"""Test URL matching with wildcard subdomain pattern."""
creds = HostScopedCredentials(
provider="custom",
host="*.example.com",
headers={"Authorization": SecretStr("Bearer token")},
)
assert creds.matches_url("https://api.example.com/v1/data")
assert creds.matches_url("https://subdomain.example.com/endpoint")
assert creds.matches_url("https://deep.nested.example.com/path")
assert creds.matches_url("https://example.com/path") # Base domain should match
assert not creds.matches_url("https://example.org/v1/data")
assert not creds.matches_url("https://notexample.com/v1/data")
def test_matches_url_with_port_and_path(self):
"""Test URL matching with ports and paths."""
creds = HostScopedCredentials(
provider="custom",
host="localhost",
headers={"Authorization": SecretStr("Bearer token")},
)
assert creds.matches_url("http://localhost:8080/api/v1")
assert creds.matches_url("https://localhost:443/secure/endpoint")
assert creds.matches_url("http://localhost/simple")
def test_empty_headers_dict(self):
"""Test HostScopedCredentials with empty headers."""
creds = HostScopedCredentials(
provider="custom", host="api.example.com", headers={}
)
assert creds.get_headers_dict() == {}
assert creds.matches_url("https://api.example.com/test")
def test_credential_serialization(self):
"""Test that credentials can be serialized/deserialized properly."""
original_creds = HostScopedCredentials(
provider="custom",
host="api.example.com",
headers={
"Authorization": SecretStr("Bearer secret-token"),
"X-API-Key": SecretStr("api-key-123"),
},
title="Test Credentials",
)
# Serialize to dict (simulating storage)
serialized = original_creds.model_dump()
# Deserialize back
restored_creds = HostScopedCredentials.model_validate(serialized)
assert restored_creds.id == original_creds.id
assert restored_creds.provider == original_creds.provider
assert restored_creds.host == original_creds.host
assert restored_creds.title == original_creds.title
assert restored_creds.type == "host_scoped"
# Check that headers are properly restored
assert restored_creds.get_headers_dict() == original_creds.get_headers_dict()
@pytest.mark.parametrize(
"host,test_url,expected",
[
("api.example.com", "https://api.example.com/test", True),
("api.example.com", "https://different.example.com/test", False),
("*.example.com", "https://api.example.com/test", True),
("*.example.com", "https://sub.api.example.com/test", True),
("*.example.com", "https://example.com/test", True),
("*.example.com", "https://example.org/test", False),
("localhost", "http://localhost:3000/test", True),
("localhost", "http://127.0.0.1:3000/test", False),
],
)
def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool):
"""Parametrized test for various URL matching scenarios."""
creds = HostScopedCredentials(
provider="test",
host=host,
headers={"Authorization": SecretStr("Bearer token")},
)
assert creds.matches_url(test_url) == expected

View File

@@ -35,7 +35,7 @@ from autogpt_libs.utils.cache import thread_cached
from prometheus_client import Gauge, start_http_server
from backend.blocks.agent import AgentExecutorBlock
from backend.data import redis
from backend.data import redis_client as redis
from backend.data.block import (
BlockData,
BlockInput,

View File

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
from pydantic import SecretStr
from backend.data.redis import get_redis_async
from backend.data.redis_client import get_redis_async
if TYPE_CHECKING:
from backend.executor.database import DatabaseManagerAsyncClient

View File

@@ -7,7 +7,7 @@ from autogpt_libs.utils.synchronize import AsyncRedisKeyedMutex
from redis.asyncio.lock import Lock as AsyncRedisLock
from backend.data.model import Credentials, OAuth2Credentials
from backend.data.redis import get_redis_async
from backend.data.redis_client import get_redis_async
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName

View File

@@ -17,6 +17,7 @@ class ProviderName(str, Enum):
GOOGLE = "google"
GOOGLE_MAPS = "google_maps"
GROQ = "groq"
HTTP = "http"
HUBSPOT = "hubspot"
IDEOGRAM = "ideogram"
JINA = "jina"

View File

@@ -22,7 +22,12 @@ from backend.data.integrations import (
publish_webhook_event,
wait_for_webhook_event,
)
from backend.data.model import Credentials, CredentialsType, OAuth2Credentials
from backend.data.model import (
Credentials,
CredentialsType,
HostScopedCredentials,
OAuth2Credentials,
)
from backend.executor.utils import add_graph_execution
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME
@@ -82,6 +87,9 @@ class CredentialsMetaResponse(BaseModel):
title: str | None
scopes: list[str] | None
username: str | None
host: str | None = Field(
default=None, description="Host pattern for host-scoped credentials"
)
@router.post("/{provider}/callback")
@@ -156,6 +164,9 @@ async def callback(
title=credentials.title,
scopes=credentials.scopes,
username=credentials.username,
host=(
credentials.host if isinstance(credentials, HostScopedCredentials) else None
),
)
@@ -172,6 +183,7 @@ async def list_credentials(
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
)
for cred in credentials
]
@@ -193,6 +205,7 @@ async def list_credentials_by_provider(
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
)
for cred in credentials
]

View File

@@ -357,11 +357,22 @@ class AgentServer(backend.util.service.AppProcess):
provider: ProviderName,
credentials: Credentials,
) -> Credentials:
from backend.server.integrations.router import create_credentials
return await create_credentials(
user_id=user_id, provider=provider, credentials=credentials
from backend.server.integrations.router import (
create_credentials,
get_credential,
)
try:
return await create_credentials(
user_id=user_id, provider=provider, credentials=credentials
)
except Exception as e:
logger.error(f"Error creating credentials: {e}")
return await get_credential(
provider=provider,
user_id=user_id,
cred_id=credentials.id,
)
def set_test_dependency_overrides(self, overrides: dict):
app.dependency_overrides.update(overrides)

View File

@@ -4,6 +4,7 @@ import logging
from typing import Annotated
import fastapi
import pydantic
import backend.data.analytics
from backend.server.utils import get_user_id
@@ -12,24 +13,28 @@ router = fastapi.APIRouter()
logger = logging.getLogger(__name__)
class LogRawMetricRequest(pydantic.BaseModel):
metric_name: str = pydantic.Field(..., min_length=1)
metric_value: float = pydantic.Field(..., allow_inf_nan=False)
data_string: str = pydantic.Field(..., min_length=1)
@router.post(path="/log_raw_metric")
async def log_raw_metric(
user_id: Annotated[str, fastapi.Depends(get_user_id)],
metric_name: Annotated[str, fastapi.Body(..., embed=True)],
metric_value: Annotated[float, fastapi.Body(..., embed=True)],
data_string: Annotated[str, fastapi.Body(..., embed=True)],
request: LogRawMetricRequest,
):
try:
result = await backend.data.analytics.log_raw_metric(
user_id=user_id,
metric_name=metric_name,
metric_value=metric_value,
data_string=data_string,
metric_name=request.metric_name,
metric_value=request.metric_value,
data_string=request.data_string,
)
return result.id
except Exception as e:
logger.exception(
"Failed to log metric %s for user %s: %s", metric_name, user_id, e
"Failed to log metric %s for user %s: %s", request.metric_name, user_id, e
)
raise fastapi.HTTPException(
status_code=500,

View File

@@ -97,8 +97,17 @@ def test_log_raw_metric_invalid_request_improved() -> None:
assert "data_string" in error_fields, "Should report missing data_string"
def test_log_raw_metric_type_validation_improved() -> None:
def test_log_raw_metric_type_validation_improved(
mocker: pytest_mock.MockFixture,
) -> None:
"""Test metric type validation with improved assertions."""
# Mock the analytics function to avoid event loop issues
mocker.patch(
"backend.data.analytics.log_raw_metric",
new_callable=AsyncMock,
return_value=Mock(id="test-id"),
)
invalid_requests = [
{
"data": {
@@ -119,10 +128,10 @@ def test_log_raw_metric_type_validation_improved() -> None:
{
"data": {
"metric_name": "test",
"metric_value": float("inf"), # Infinity
"data_string": "test",
"metric_value": 123, # Valid number
"data_string": "", # Empty data_string
},
"expected_error": "ensure this value is finite",
"expected_error": "String should have at least 1 character",
},
]

View File

@@ -93,10 +93,18 @@ def test_log_raw_metric_values_parametrized(
],
)
def test_log_raw_metric_invalid_requests_parametrized(
mocker: pytest_mock.MockFixture,
invalid_data: dict,
expected_error: str,
) -> None:
"""Test invalid metric requests with parametrize."""
# Mock the analytics function to avoid event loop issues
mocker.patch(
"backend.data.analytics.log_raw_metric",
new_callable=AsyncMock,
return_value=Mock(id="test-id"),
)
response = client.post("/log_raw_metric", json=invalid_data)
assert response.status_code == 422

View File

@@ -108,11 +108,16 @@ class TestDatabaseIsolation:
where={"email": {"contains": "@test.example"}}
)
@pytest.fixture(scope="session")
async def test_create_user(self, test_db_connection):
"""Test that demonstrates proper isolation."""
# This test has access to a clean database
user = await test_db_connection.user.create(
data={"email": "test@test.example", "name": "Test User"}
data={
"id": "test-user-id",
"email": "test@test.example",
"name": "Test User",
}
)
assert user.email == "test@test.example"
# User will be cleaned up automatically

View File

@@ -143,7 +143,7 @@ async def test_add_agent_to_library(mocker):
)
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.create = mocker.AsyncMock(
return_value=mock_library_agent_data
)
@@ -159,21 +159,24 @@ async def test_add_agent_to_library(mocker):
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
where={"id": "version123"}, include={"AgentGraph": True}
)
mock_library_agent.return_value.find_first.assert_called_once_with(
mock_library_agent.return_value.find_unique.assert_called_once_with(
where={
"userId": "test-user",
"agentGraphId": "agent1",
"agentGraphVersion": 1,
"userId_agentGraphId_agentGraphVersion": {
"userId": "test-user",
"agentGraphId": "agent1",
"agentGraphVersion": 1,
}
},
include=library_agent_include("test-user"),
include={"AgentGraph": True},
)
mock_library_agent.return_value.create.assert_called_once_with(
data=prisma.types.LibraryAgentCreateInput(
userId="test-user",
agentGraphId="agent1",
agentGraphVersion=1,
isCreatedByUser=False,
),
data={
"User": {"connect": {"id": "test-user"}},
"AgentGraph": {
"connect": {"graphVersionId": {"id": "agent1", "version": 1}}
},
"isCreatedByUser": False,
},
include=library_agent_include("test-user"),
)

View File

@@ -121,26 +121,57 @@ def test_get_library_agents_error(mocker: pytest_mock.MockFixture):
)
@pytest.mark.skip(reason="Mocker Not implemented")
def test_add_agent_to_library_success(mocker: pytest_mock.MockFixture):
mock_db_call = mocker.patch("backend.server.v2.library.db.add_agent_to_library")
mock_db_call.return_value = None
mock_library_agent = library_model.LibraryAgent(
id="test-library-agent-id",
graph_id="test-agent-1",
graph_version=1,
name="Test Agent 1",
description="Test Description 1",
image_url=None,
creator_name="Test Creator",
creator_image_url="",
input_schema={"type": "object", "properties": {}},
credentials_input_schema={"type": "object", "properties": {}},
has_external_trigger=False,
status=library_model.LibraryAgentStatus.COMPLETED,
new_output=False,
can_access_graph=True,
is_latest_version=True,
updated_at=FIXED_NOW,
)
response = client.post("/agents/test-version-id")
mock_db_call = mocker.patch(
"backend.server.v2.library.db.add_store_agent_to_library"
)
mock_db_call.return_value = mock_library_agent
response = client.post(
"/agents", json={"store_listing_version_id": "test-version-id"}
)
assert response.status_code == 201
# Verify the response contains the library agent data
data = library_model.LibraryAgent.model_validate(response.json())
assert data.id == "test-library-agent-id"
assert data.graph_id == "test-agent-1"
mock_db_call.assert_called_once_with(
store_listing_version_id="test-version-id", user_id="test-user-id"
)
@pytest.mark.skip(reason="Mocker Not implemented")
def test_add_agent_to_library_error(mocker: pytest_mock.MockFixture):
mock_db_call = mocker.patch("backend.server.v2.library.db.add_agent_to_library")
mock_db_call = mocker.patch(
"backend.server.v2.library.db.add_store_agent_to_library"
)
mock_db_call.side_effect = Exception("Test error")
response = client.post("/agents/test-version-id")
response = client.post(
"/agents", json={"store_listing_version_id": "test-version-id"}
)
assert response.status_code == 500
assert response.json()["detail"] == "Failed to add agent to library"
assert "detail" in response.json() # Verify error response structure
mock_db_call.assert_called_once_with(
store_listing_version_id="test-version-id", user_id="test-user-id"
)

View File

@@ -259,8 +259,8 @@ def test_ask_otto_unauthenticated(mocker: pytest_mock.MockFixture) -> None:
}
response = client.post("/ask", json=request_data)
# When auth is disabled and Otto API URL is not configured, we get 503
assert response.status_code == 503
# When auth is disabled and Otto API URL is not configured, we get 502 (wrapped from 503)
assert response.status_code == 502
# Restore the override
app.dependency_overrides[autogpt_libs.auth.middleware.auth_middleware] = (

View File

@@ -93,6 +93,14 @@ async def test_get_store_agent_details(mocker):
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
# Mock Profile prisma call
mock_profile = mocker.MagicMock()
mock_profile.userId = "user-id-123"
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
mock_profile_db.return_value.find_first = mocker.AsyncMock(
return_value=mock_profile
)
# Mock StoreListing prisma call - this is what was missing
mock_store_listing_db = mocker.patch("prisma.models.StoreListing.prisma")
mock_store_listing_db.return_value.find_first = mocker.AsyncMock(

View File

@@ -19,7 +19,9 @@ from backend.server.ws_api import (
@pytest.fixture
def mock_websocket() -> AsyncMock:
return AsyncMock(spec=WebSocket)
mock = AsyncMock(spec=WebSocket)
mock.query_params = {} # Add query_params attribute for authentication
return mock
@pytest.fixture
@@ -29,8 +31,13 @@ def mock_manager() -> AsyncMock:
@pytest.mark.asyncio
async def test_websocket_router_subscribe(
mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot
mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot, mocker
) -> None:
# Mock the authenticate_websocket function to ensure it returns a valid user_id
mocker.patch(
"backend.server.ws_api.authenticate_websocket", return_value=DEFAULT_USER_ID
)
mock_websocket.receive_text.side_effect = [
WSMessage(
method=WSMethod.SUBSCRIBE_GRAPH_EXEC,
@@ -70,8 +77,13 @@ async def test_websocket_router_subscribe(
@pytest.mark.asyncio
async def test_websocket_router_unsubscribe(
mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot
mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot, mocker
) -> None:
# Mock the authenticate_websocket function to ensure it returns a valid user_id
mocker.patch(
"backend.server.ws_api.authenticate_websocket", return_value=DEFAULT_USER_ID
)
mock_websocket.receive_text.side_effect = [
WSMessage(
method=WSMethod.UNSUBSCRIBE,
@@ -108,8 +120,13 @@ async def test_websocket_router_unsubscribe(
@pytest.mark.asyncio
async def test_websocket_router_invalid_method(
mock_websocket: AsyncMock, mock_manager: AsyncMock
mock_websocket: AsyncMock, mock_manager: AsyncMock, mocker
) -> None:
# Mock the authenticate_websocket function to ensure it returns a valid user_id
mocker.patch(
"backend.server.ws_api.authenticate_websocket", return_value=DEFAULT_USER_ID
)
mock_websocket.receive_text.side_effect = [
WSMessage(method=WSMethod.GRAPH_EXECUTION_EVENT).model_dump_json(),
WebSocketDisconnect(),

View File

@@ -113,6 +113,11 @@ ignore_patterns = []
[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "session"
filterwarnings = [
"ignore:'audioop' is deprecated:DeprecationWarning:discord.player",
"ignore:invalid escape sequence:DeprecationWarning:tweepy.api",
]
[tool.ruff]
target-version = "py310"

View File

@@ -0,0 +1,3 @@
{
"metric_id": "metric-123-uuid"
}

View File

@@ -0,0 +1,4 @@
{
"metric_id": "metric-float_precision-uuid",
"test_case": "float_precision"
}

View File

@@ -0,0 +1,4 @@
{
"metric_id": "metric-integer_value-uuid",
"test_case": "integer_value"
}

View File

@@ -0,0 +1,4 @@
{
"metric_id": "metric-large_number-uuid",
"test_case": "large_number"
}

View File

@@ -0,0 +1,4 @@
{
"metric_id": "metric-negative_value-uuid",
"test_case": "negative_value"
}

View File

@@ -0,0 +1,4 @@
{
"metric_id": "metric-tiny_number-uuid",
"test_case": "tiny_number"
}

View File

@@ -0,0 +1,4 @@
{
"metric_id": "metric-zero_value-uuid",
"test_case": "zero_value"
}

View File

@@ -15,6 +15,12 @@
"type": "object",
"properties": {}
},
"credentials_input_schema": {
"type": "object",
"properties": {}
},
"has_external_trigger": false,
"trigger_setup_info": null,
"new_output": false,
"can_access_graph": true,
"is_latest_version": true
@@ -34,6 +40,12 @@
"type": "object",
"properties": {}
},
"credentials_input_schema": {
"type": "object",
"properties": {}
},
"has_external_trigger": false,
"trigger_setup_info": null,
"new_output": false,
"can_access_graph": false,
"is_latest_version": true

View File

@@ -1,3 +0,0 @@
import os
os.environ["ENABLE_AUTH"] = "false"

View File

@@ -4,7 +4,6 @@
"private": true,
"scripts": {
"dev": "next dev --turbo",
"dev:test": "NODE_ENV=test && next dev --turbo",
"build": "pnpm run generate:api-client && SKIP_STORYBOOK_TESTS=true next build",
"start": "next start",
"start:standalone": "cd .next/standalone && node server.js",
@@ -28,7 +27,7 @@
"dependencies": {
"@faker-js/faker": "9.8.0",
"@hookform/resolvers": "5.1.1",
"@next/third-parties": "15.3.3",
"@next/third-parties": "15.3.4",
"@phosphor-icons/react": "2.1.10",
"@radix-ui/react-alert-dialog": "1.1.14",
"@radix-ui/react-avatar": "1.1.10",
@@ -49,13 +48,13 @@
"@radix-ui/react-tabs": "1.1.12",
"@radix-ui/react-toast": "1.2.14",
"@radix-ui/react-tooltip": "1.2.7",
"@sentry/nextjs": "9.27.0",
"@sentry/nextjs": "9.33.0",
"@supabase/ssr": "0.6.1",
"@supabase/supabase-js": "2.50.0",
"@tanstack/react-query": "5.80.7",
"@supabase/supabase-js": "2.50.2",
"@tanstack/react-query": "5.81.2",
"@tanstack/react-table": "8.21.3",
"@types/jaro-winkler": "0.2.4",
"@xyflow/react": "12.6.4",
"@xyflow/react": "12.8.0",
"ajv": "8.17.1",
"boring-avatars": "1.11.2",
"class-variance-authority": "0.7.1",
@@ -63,24 +62,24 @@
"cmdk": "1.1.1",
"cookie": "1.0.2",
"date-fns": "4.1.0",
"dotenv": "16.5.0",
"dotenv": "16.6.0",
"elliptic": "6.6.1",
"embla-carousel-react": "8.6.0",
"framer-motion": "12.16.0",
"framer-motion": "12.19.2",
"geist": "1.4.2",
"jaro-winkler": "0.2.8",
"launchdarkly-react-client-sdk": "3.8.1",
"lodash": "4.17.21",
"lucide-react": "0.513.0",
"lucide-react": "0.524.0",
"moment": "2.30.1",
"next": "15.3.3",
"next": "15.3.4",
"next-themes": "0.4.6",
"party-js": "2.2.0",
"react": "18.3.1",
"react-day-picker": "9.7.0",
"react-dom": "18.3.1",
"react-drag-drop-files": "2.4.0",
"react-hook-form": "7.57.0",
"react-hook-form": "7.58.1",
"react-icons": "5.5.0",
"react-markdown": "9.0.3",
"react-modal": "3.16.3",
@@ -90,23 +89,23 @@
"tailwind-merge": "2.6.0",
"tailwindcss-animate": "1.0.7",
"uuid": "11.1.0",
"zod": "3.25.56"
"zod": "3.25.67"
},
"devDependencies": {
"@chromatic-com/storybook": "4.0.0",
"@chromatic-com/storybook": "4.0.1",
"@playwright/test": "1.53.1",
"@storybook/addon-a11y": "9.0.12",
"@storybook/addon-docs": "9.0.12",
"@storybook/addon-links": "9.0.12",
"@storybook/addon-onboarding": "9.0.12",
"@storybook/nextjs": "9.0.12",
"@tanstack/eslint-plugin-query": "5.78.0",
"@tanstack/react-query-devtools": "5.80.10",
"@storybook/addon-a11y": "9.0.13",
"@storybook/addon-docs": "9.0.13",
"@storybook/addon-links": "9.0.13",
"@storybook/addon-onboarding": "9.0.13",
"@storybook/nextjs": "9.0.13",
"@tanstack/eslint-plugin-query": "5.81.2",
"@tanstack/react-query-devtools": "5.81.2",
"@testing-library/jest-dom": "6.6.3",
"@testing-library/react": "16.3.0",
"@testing-library/user-event": "14.6.1",
"@types/canvas-confetti": "1.9.0",
"@types/lodash": "4.17.18",
"@types/lodash": "4.17.19",
"@types/negotiator": "0.6.4",
"@types/node": "22.15.30",
"@types/react": "18.3.17",
@@ -115,20 +114,20 @@
"@vitest/browser": "3.2.4",
"axe-playwright": "2.1.0",
"chromatic": "11.25.2",
"concurrently": "9.1.2",
"concurrently": "9.2.0",
"eslint": "8.57.1",
"eslint-config-next": "15.3.4",
"eslint-plugin-storybook": "9.0.12",
"eslint-plugin-storybook": "9.0.13",
"import-in-the-middle": "1.14.2",
"jsdom": "26.1.0",
"msw": "2.10.2",
"msw-storybook-addon": "2.0.5",
"orval": "7.10.0",
"postcss": "8.5.6",
"prettier": "3.5.3",
"prettier-plugin-tailwindcss": "0.6.12",
"prettier": "3.6.1",
"prettier-plugin-tailwindcss": "0.6.13",
"require-in-the-middle": "7.5.2",
"storybook": "9.0.12",
"storybook": "9.0.13",
"tailwindcss": "3.4.17",
"typescript": "5.8.3",
"vite": "7.0.0",

View File

@@ -34,7 +34,15 @@ export default defineConfig({
bypassCSP: true,
},
/* Maximum time one test can run for */
timeout: 60000,
timeout: 30000,
/* Configure web server to start automatically */
webServer: {
command: "NEXT_PUBLIC_PW_TEST=true pnpm start",
url: "http://localhost:3000",
reuseExistingServer: !process.env.CI,
timeout: 120 * 1000,
},
/* Configure projects for major browsers */
projects: [
@@ -73,15 +81,4 @@ export default defineConfig({
// use: { ...devices['Desktop Chrome'], channel: 'chrome' },
// },
],
/* Run your local server before starting the tests */
webServer: {
command: "pnpm start",
url: "http://localhost:3000/",
reuseExistingServer: !process.env.CI,
timeout: 10 * 1000,
env: {
NODE_ENV: "test",
},
},
});

File diff suppressed because it is too large Load Diff

View File

@@ -75,7 +75,7 @@ export const customMutator = async <T = any>(
return {
status: response.status,
response_data,
data: response_data,
headers: response.headers,
} as T;
};

View File

@@ -2931,18 +2931,6 @@
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Preset Id" }
},
{
"name": "graph_id",
"in": "query",
"required": true,
"schema": { "type": "string", "title": "Graph Id" }
},
{
"name": "graph_version",
"in": "query",
"required": true,
"schema": { "type": "integer", "title": "Graph Version" }
}
],
"requestBody": {
@@ -3128,11 +3116,11 @@
}
}
},
"put": {
"patch": {
"tags": ["v2", "library", "private"],
"summary": "Update Library Agent",
"description": "Update the library agent with the given fields.\n\nArgs:\n library_agent_id: ID of the library agent to update.\n payload: Fields to update (auto_update_version, is_favorite, etc.).\n user_id: ID of the authenticated user.\n\nReturns:\n 204 (No Content) on success.\n\nRaises:\n HTTPException(500): If a server/database error occurs.",
"operationId": "putV2Update library agent",
"description": "Update the library agent with the given fields.\n\nArgs:\n library_agent_id: ID of the library agent to update.\n payload: Fields to update (auto_update_version, is_favorite, etc.).\n user_id: ID of the authenticated user.\n\nRaises:\n HTTPException(500): If a server/database error occurs.",
"operationId": "patchV2Update library agent",
"parameters": [
{
"name": "library_agent_id",
@@ -3152,7 +3140,45 @@
}
},
"responses": {
"204": { "description": "Agent updated successfully" },
"200": {
"description": "Agent updated successfully",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/LibraryAgent" }
}
}
},
"500": { "description": "Server error" },
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
},
"delete": {
"tags": ["v2", "library", "private"],
"summary": "Delete Library Agent",
"description": "Soft-delete the specified library agent.\n\nArgs:\n library_agent_id: ID of the library agent to delete.\n user_id: ID of the authenticated user.\n\nReturns:\n 204 No Content if successful.\n\nRaises:\n HTTPException(404): If the agent does not exist.\n HTTPException(500): If a server/database error occurs.",
"operationId": "deleteV2Delete library agent",
"parameters": [
{
"name": "library_agent_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Library Agent Id" }
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": { "application/json": { "schema": {} } }
},
"204": { "description": "Agent deleted successfully" },
"404": { "description": "Agent not found" },
"500": { "description": "Server error" },
"422": {
"description": "Validation Error",
@@ -3238,6 +3264,55 @@
}
}
},
"/api/library/agents/{library_agent_id}/setup-trigger": {
"post": {
"tags": ["v2", "library", "private"],
"summary": "Setup Trigger",
"description": "Sets up a webhook-triggered `LibraryAgentPreset` for a `LibraryAgent`.\nReturns the correspondingly created `LibraryAgentPreset` with `webhook_id` set.",
"operationId": "postV2SetupTrigger",
"parameters": [
{
"name": "library_agent_id",
"in": "path",
"required": true,
"schema": {
"type": "string",
"description": "ID of the library agent",
"title": "Library Agent Id"
},
"description": "ID of the library agent"
}
],
"requestBody": {
"required": true,
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/TriggeredPresetSetupParams"
}
}
}
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/LibraryAgentPreset" }
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/otto/ask": {
"post": {
"tags": ["v2", "otto"],
@@ -3713,10 +3788,10 @@
},
"Body_postV2Execute_a_preset": {
"properties": {
"node_input": {
"inputs": {
"additionalProperties": true,
"type": "object",
"title": "Node Input"
"title": "Inputs"
}
},
"type": "object",
@@ -4303,6 +4378,23 @@
"type": "object",
"title": "Input Schema"
},
"credentials_input_schema": {
"additionalProperties": true,
"type": "object",
"title": "Credentials Input Schema",
"description": "Input schema for credentials required by the agent"
},
"has_external_trigger": {
"type": "boolean",
"title": "Has External Trigger",
"description": "Whether the agent has an external trigger (e.g. webhook) node"
},
"trigger_setup_info": {
"anyOf": [
{ "$ref": "#/components/schemas/LibraryAgentTriggerInfo" },
{ "type": "null" }
]
},
"new_output": { "type": "boolean", "title": "New Output" },
"can_access_graph": {
"type": "boolean",
@@ -4326,6 +4418,8 @@
"name",
"description",
"input_schema",
"credentials_input_schema",
"has_external_trigger",
"new_output",
"can_access_graph",
"is_latest_version"
@@ -4342,6 +4436,13 @@
"type": "object",
"title": "Inputs"
},
"credentials": {
"additionalProperties": {
"$ref": "#/components/schemas/CredentialsMetaInput"
},
"type": "object",
"title": "Credentials"
},
"name": { "type": "string", "title": "Name" },
"description": { "type": "string", "title": "Description" },
"is_active": {
@@ -4349,7 +4450,12 @@
"title": "Is Active",
"default": true
},
"webhook_id": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Webhook Id"
},
"id": { "type": "string", "title": "Id" },
"user_id": { "type": "string", "title": "User Id" },
"updated_at": {
"type": "string",
"format": "date-time",
@@ -4361,9 +4467,11 @@
"graph_id",
"graph_version",
"inputs",
"credentials",
"name",
"description",
"id",
"user_id",
"updated_at"
],
"title": "LibraryAgentPreset",
@@ -4378,12 +4486,23 @@
"type": "object",
"title": "Inputs"
},
"credentials": {
"additionalProperties": {
"$ref": "#/components/schemas/CredentialsMetaInput"
},
"type": "object",
"title": "Credentials"
},
"name": { "type": "string", "title": "Name" },
"description": { "type": "string", "title": "Description" },
"is_active": {
"type": "boolean",
"title": "Is Active",
"default": true
},
"webhook_id": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Webhook Id"
}
},
"type": "object",
@@ -4391,6 +4510,7 @@
"graph_id",
"graph_version",
"inputs",
"credentials",
"name",
"description"
],
@@ -4439,6 +4559,18 @@
],
"title": "Inputs"
},
"credentials": {
"anyOf": [
{
"additionalProperties": {
"$ref": "#/components/schemas/CredentialsMetaInput"
},
"type": "object"
},
{ "type": "null" }
],
"title": "Credentials"
},
"name": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Name"
@@ -4481,6 +4613,24 @@
"enum": ["COMPLETED", "HEALTHY", "WAITING", "ERROR"],
"title": "LibraryAgentStatus"
},
"LibraryAgentTriggerInfo": {
"properties": {
"provider": { "$ref": "#/components/schemas/ProviderName" },
"config_schema": {
"additionalProperties": true,
"type": "object",
"title": "Config Schema",
"description": "Input schema for the trigger block"
},
"credentials_input_name": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Credentials Input Name"
}
},
"type": "object",
"required": ["provider", "config_schema", "credentials_input_name"],
"title": "LibraryAgentTriggerInfo"
},
"LibraryAgentUpdateRequest": {
"properties": {
"auto_update_version": {
@@ -4497,11 +4647,6 @@
"anyOf": [{ "type": "boolean" }, { "type": "null" }],
"title": "Is Archived",
"description": "Archive the agent"
},
"is_deleted": {
"anyOf": [{ "type": "boolean" }, { "type": "null" }],
"title": "Is Deleted",
"description": "Delete the agent"
}
},
"type": "object",
@@ -5773,6 +5918,31 @@
"required": ["transactions", "next_transaction_time"],
"title": "TransactionHistory"
},
"TriggeredPresetSetupParams": {
"properties": {
"name": { "type": "string", "title": "Name" },
"description": {
"type": "string",
"title": "Description",
"default": ""
},
"trigger_config": {
"additionalProperties": true,
"type": "object",
"title": "Trigger Config"
},
"agent_credentials": {
"additionalProperties": {
"$ref": "#/components/schemas/CredentialsMetaInput"
},
"type": "object",
"title": "Agent Credentials"
}
},
"type": "object",
"required": ["name", "trigger_config"],
"title": "TriggeredPresetSetupParams"
},
"TurnstileVerifyRequest": {
"properties": {
"token": {
@@ -6055,16 +6225,6 @@
"type": "string",
"title": "Provider Webhook Id"
},
"attached_nodes": {
"anyOf": [
{
"items": { "$ref": "#/components/schemas/NodeModel" },
"type": "array"
},
{ "type": "null" }
],
"title": "Attached Nodes"
},
"url": { "type": "string", "title": "Url", "readOnly": true }
},
"type": "object",

View File

@@ -1,4 +1,14 @@
"use client";
import {
AuthBottomText,
AuthButton,
AuthCard,
AuthFeedback,
AuthHeader,
GoogleOAuthButton,
PasswordInput,
Turnstile,
} from "@/components/auth";
import {
Form,
FormControl,
@@ -8,19 +18,9 @@ import {
FormMessage,
} from "@/components/ui/form";
import { Input } from "@/components/ui/input";
import Link from "next/link";
import LoadingBox from "@/components/ui/loading";
import {
AuthCard,
AuthHeader,
AuthButton,
AuthFeedback,
AuthBottomText,
GoogleOAuthButton,
PasswordInput,
Turnstile,
} from "@/components/auth";
import { getBehaveAs } from "@/lib/utils";
import Link from "next/link";
import { useLoginPage } from "./useLoginPage";
export default function LoginPage() {
@@ -100,7 +100,7 @@ export default function LoginPage() {
<FormLabel className="flex w-full items-center justify-between">
<span>Password</span>
<Link
href="/reset_password"
href="/reset-password"
className="text-sm font-normal leading-normal text-black underline"
>
Forgot your password?

View File

@@ -0,0 +1,102 @@
"use client";
import { Loader2, MoreVertical } from "lucide-react";
import { Button } from "@/components/ui/button";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
import { Badge } from "@/components/ui/badge";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { useAPISection } from "./useAPISection";
export function APIKeysSection() {
const { apiKeys, isLoading, isDeleting, handleRevokeKey } = useAPISection();
return (
<>
{isLoading ? (
<div className="flex justify-center p-4">
<Loader2 className="h-6 w-6 animate-spin" />
</div>
) : (
apiKeys &&
apiKeys.length > 0 && (
<Table>
<TableHeader>
<TableRow>
<TableHead>Name</TableHead>
<TableHead>API Key</TableHead>
<TableHead>Status</TableHead>
<TableHead>Created</TableHead>
<TableHead>Last Used</TableHead>
<TableHead></TableHead>
</TableRow>
</TableHeader>
<TableBody>
{apiKeys.map((key) => (
<TableRow key={key.id}>
<TableCell>{key.name}</TableCell>
<TableCell>
<div className="rounded-md border p-1 px-2 text-xs">
{`${key.prefix}******************${key.postfix}`}
</div>
</TableCell>
<TableCell>
<Badge
variant={
key.status === "ACTIVE" ? "default" : "destructive"
}
className={
key.status === "ACTIVE"
? "border-green-600 bg-green-100 text-green-800"
: "border-red-600 bg-red-100 text-red-800"
}
>
{key.status}
</Badge>
</TableCell>
<TableCell>
{new Date(key.created_at).toLocaleDateString()}
</TableCell>
<TableCell>
{key.last_used_at
? new Date(key.last_used_at).toLocaleDateString()
: "Never"}
</TableCell>
<TableCell>
<DropdownMenu>
<DropdownMenuTrigger asChild>
<Button variant="ghost" size="sm">
<MoreVertical className="h-4 w-4" />
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuItem
className="text-destructive"
onClick={() => handleRevokeKey(key.id)}
disabled={isDeleting}
>
Revoke
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
)
)}
</>
);
}

View File

@@ -0,0 +1,61 @@
"use client";
import {
getGetV1ListUserApiKeysQueryKey,
useDeleteV1RevokeApiKey,
useGetV1ListUserApiKeys,
} from "@/api/__generated__/endpoints/api-keys/api-keys";
import { APIKeyWithoutHash } from "@/api/__generated__/models/aPIKeyWithoutHash";
import { useToast } from "@/components/ui/use-toast";
import { getQueryClient } from "@/lib/react-query/queryClient";
export const useAPISection = () => {
const queryClient = getQueryClient();
const { toast } = useToast();
const { data: apiKeys, isLoading } = useGetV1ListUserApiKeys({
query: {
select: (res) => {
return (res.data as APIKeyWithoutHash[]).filter(
(key) => key.status === "ACTIVE",
);
},
},
});
const { mutateAsync: revokeAPIKey, isPending: isDeleting } =
useDeleteV1RevokeApiKey({
mutation: {
onSettled: () => {
return queryClient.invalidateQueries({
queryKey: getGetV1ListUserApiKeysQueryKey(),
});
},
},
});
const handleRevokeKey = async (keyId: string) => {
try {
await revokeAPIKey({
keyId: keyId,
});
toast({
title: "Success",
description: "AutoGPT Platform API key revoked successfully",
});
} catch {
toast({
title: "Error",
description: "Failed to revoke AutoGPT Platform API key",
variant: "destructive",
});
}
};
return {
apiKeys,
isLoading,
isDeleting,
handleRevokeKey,
};
};

View File

@@ -0,0 +1,133 @@
"use client";
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
DialogTrigger,
} from "@/components/ui/dialog";
import { LuCopy } from "react-icons/lu";
import { Label } from "@/components/ui/label";
import { Input } from "@/components/ui/input";
import { Checkbox } from "@/components/ui/checkbox";
import { Button } from "@/components/ui/button";
import { useAPIkeysModals } from "./useAPIkeysModals";
import { APIKeyPermission } from "@/api/__generated__/models/aPIKeyPermission";
export const APIKeysModals = () => {
const {
isCreating,
handleCreateKey,
handleCopyKey,
setIsCreateOpen,
setIsKeyDialogOpen,
isCreateOpen,
isKeyDialogOpen,
keyState,
setKeyState,
} = useAPIkeysModals();
return (
<div className="mb-4 flex justify-end">
<Dialog open={isCreateOpen} onOpenChange={setIsCreateOpen}>
<DialogTrigger asChild>
<Button>Create Key</Button>
</DialogTrigger>
<DialogContent>
<DialogHeader>
<DialogTitle>Create New API Key</DialogTitle>
<DialogDescription>
Create a new AutoGPT Platform API key
</DialogDescription>
</DialogHeader>
<div className="grid gap-4 py-4">
<div className="grid gap-2">
<Label htmlFor="name">Name</Label>
<Input
id="name"
value={keyState.newKeyName}
onChange={(e) =>
setKeyState((prev) => ({
...prev,
newKeyName: e.target.value,
}))
}
placeholder="My AutoGPT Platform API Key"
/>
</div>
<div className="grid gap-2">
<Label htmlFor="description">Description (Optional)</Label>
<Input
id="description"
value={keyState.newKeyDescription}
onChange={(e) =>
setKeyState((prev) => ({
...prev,
newKeyDescription: e.target.value,
}))
}
placeholder="Used for..."
/>
</div>
<div className="grid gap-2">
<Label>Permissions</Label>
{Object.values(APIKeyPermission).map((permission) => (
<div className="flex items-center space-x-2" key={permission}>
<Checkbox
id={permission}
checked={keyState.selectedPermissions.includes(permission)}
onCheckedChange={(checked: boolean) => {
setKeyState((prev) => ({
...prev,
selectedPermissions: checked
? [...prev.selectedPermissions, permission]
: prev.selectedPermissions.filter(
(p) => p !== permission,
),
}));
}}
/>
<Label htmlFor={permission}>{permission}</Label>
</div>
))}
</div>
</div>
<DialogFooter>
<Button variant="outline" onClick={() => setIsCreateOpen(false)}>
Cancel
</Button>
<Button onClick={handleCreateKey} disabled={isCreating}>
Create
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
<Dialog open={isKeyDialogOpen} onOpenChange={setIsKeyDialogOpen}>
<DialogContent>
<DialogHeader>
<DialogTitle>AutoGPT Platform API Key Created</DialogTitle>
<DialogDescription>
Please copy your AutoGPT API key now. You won&apos;t be able to
see it again!
</DialogDescription>
</DialogHeader>
<div className="flex items-center space-x-2">
<code className="flex-1 rounded-md bg-secondary p-2 text-sm">
{keyState.newApiKey}
</code>
<Button size="icon" variant="outline" onClick={handleCopyKey}>
<LuCopy className="h-4 w-4" />
</Button>
</div>
<DialogFooter>
<Button onClick={() => setIsKeyDialogOpen(false)}>Close</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</div>
);
};

View File

@@ -0,0 +1,78 @@
"use client";
import {
getGetV1ListUserApiKeysQueryKey,
usePostV1CreateNewApiKey,
} from "@/api/__generated__/endpoints/api-keys/api-keys";
import { APIKeyPermission } from "@/api/__generated__/models/aPIKeyPermission";
import { CreateAPIKeyResponse } from "@/api/__generated__/models/createAPIKeyResponse";
import { useToast } from "@/components/ui/use-toast";
import { getQueryClient } from "@/lib/react-query/queryClient";
import { useState } from "react";
export const useAPIkeysModals = () => {
const [isCreateOpen, setIsCreateOpen] = useState(false);
const [isKeyDialogOpen, setIsKeyDialogOpen] = useState(false);
const [keyState, setKeyState] = useState({
newKeyName: "",
newKeyDescription: "",
newApiKey: "",
selectedPermissions: [] as APIKeyPermission[],
});
const queryClient = getQueryClient();
const { toast } = useToast();
const { mutateAsync: createAPIKey, isPending: isCreating } =
usePostV1CreateNewApiKey({
mutation: {
onSettled: () => {
return queryClient.invalidateQueries({
queryKey: getGetV1ListUserApiKeysQueryKey(),
});
},
},
});
const handleCreateKey = async () => {
try {
const response = await createAPIKey({
data: {
name: keyState.newKeyName,
permissions: keyState.selectedPermissions,
description: keyState.newKeyDescription,
},
});
setKeyState((prev) => ({
...prev,
newApiKey: (response.data as CreateAPIKeyResponse).plain_text_key,
}));
setIsCreateOpen(false);
setIsKeyDialogOpen(true);
} catch {
toast({
title: "Error",
description: "Failed to create AutoGPT Platform API key",
variant: "destructive",
});
}
};
const handleCopyKey = () => {
navigator.clipboard.writeText(keyState.newApiKey);
toast({
title: "Copied",
description: "AutoGPT Platform API key copied to clipboard",
});
};
return {
isCreating,
handleCreateKey,
handleCopyKey,
setIsCreateOpen,
setIsKeyDialogOpen,
isCreateOpen,
isKeyDialogOpen,
keyState,
setKeyState,
};
};

View File

@@ -1,12 +1,31 @@
import { Metadata } from "next/types";
import { APIKeysSection } from "@/components/agptui/composite/APIKeySection";
import { APIKeysSection } from "@/app/(platform)/profile/(user)/api_keys/components/APIKeySection/APIKeySection";
import {
Card,
CardContent,
CardDescription,
CardHeader,
CardTitle,
} from "@/components/ui/card";
import { APIKeysModals } from "./components/APIKeysModals/APIKeysModals";
export const metadata: Metadata = { title: "API Keys - AutoGPT Platform" };
const ApiKeysPage = () => {
return (
<div className="w-full pr-4 pt-24 md:pt-0">
<APIKeysSection />
<Card>
<CardHeader>
<CardTitle>AutoGPT Platform API Keys</CardTitle>
<CardDescription>
Manage your AutoGPT Platform API keys for programmatic access
</CardDescription>
</CardHeader>
<CardContent>
<APIKeysModals />
<APIKeysSection />
</CardContent>
</Card>
</div>
);
};

View File

@@ -5,6 +5,7 @@ import { useCallback, useContext, useEffect, useMemo, useState } from "react";
import { useToast } from "@/components/ui/use-toast";
import { IconKey, IconUser } from "@/components/ui/icons";
import { Trash2Icon } from "lucide-react";
import { KeyIcon } from "@phosphor-icons/react/dist/ssr";
import { providerIcons } from "@/components/integrations/credentials-input";
import { CredentialsProvidersContext } from "@/components/integrations/credentials-provider";
import {
@@ -140,11 +141,12 @@ export default function UserIntegrationsPage() {
...credentials,
provider: provider.provider,
providerName: provider.providerName,
ProviderIcon: providerIcons[provider.provider],
ProviderIcon: providerIcons[provider.provider] || KeyIcon,
TypeIcon: {
oauth2: IconUser,
api_key: IconKey,
user_password: IconKey,
host_scoped: IconKey,
}[credentials.type],
})),
)
@@ -181,6 +183,7 @@ export default function UserIntegrationsPage() {
oauth2: "OAuth2 credentials",
api_key: "API key",
user_password: "Username & password",
host_scoped: "Host-scoped credentials",
}[cred.type]
}{" "}
- <code>{cred.id}</code>

View File

@@ -2,8 +2,11 @@
import { revalidatePath } from "next/cache";
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
import BackendApi from "@/lib/autogpt-server-api";
import { NotificationPreferenceDTO } from "@/lib/autogpt-server-api/types";
import {
postV1UpdateNotificationPreferences,
postV1UpdateUserEmail,
} from "@/api/__generated__/endpoints/auth/auth";
export async function updateSettings(formData: FormData) {
const supabase = await getServerSupabase();
@@ -29,8 +32,7 @@ export async function updateSettings(formData: FormData) {
const { error: emailError } = await supabase.auth.updateUser({
email,
});
const api = new BackendApi();
await api.updateUserEmail(email);
await postV1UpdateUserEmail(email);
if (emailError) {
throw new Error(`${emailError.message}`);
@@ -38,7 +40,6 @@ export async function updateSettings(formData: FormData) {
}
try {
const api = new BackendApi();
const preferences: NotificationPreferenceDTO = {
email: user?.email || "",
preferences: {
@@ -55,7 +56,7 @@ export async function updateSettings(formData: FormData) {
},
daily_limit: 0,
};
await api.updateUserPreferences(preferences);
await postV1UpdateNotificationPreferences(preferences);
} catch (error) {
console.error(error);
throw new Error(`Failed to update preferences: ${error}`);
@@ -64,9 +65,3 @@ export async function updateSettings(formData: FormData) {
revalidatePath("/profile/settings");
return { success: true };
}
export async function getUserPreferences(): Promise<NotificationPreferenceDTO> {
const api = new BackendApi();
const preferences = await api.getUserPreferences();
return preferences;
}

View File

@@ -1,10 +1,5 @@
"use client";
import { zodResolver } from "@hookform/resolvers/zod";
import { useForm } from "react-hook-form";
import * as z from "zod";
import { User } from "@supabase/supabase-js";
import { Button } from "@/components/ui/button";
import {
Form,
@@ -18,100 +13,22 @@ import {
import { Input } from "@/components/ui/input";
import { Switch } from "@/components/ui/switch";
import { Separator } from "@/components/ui/separator";
import { updateSettings } from "@/app/(platform)/profile/(user)/settings/actions";
import { toast } from "@/components/ui/use-toast";
import { NotificationPreferenceDTO } from "@/lib/autogpt-server-api";
import { NotificationPreference } from "@/api/__generated__/models/notificationPreference";
import { User } from "@supabase/supabase-js";
import { useSettingsForm } from "./useSettingsForm";
const formSchema = z
.object({
email: z.string().email(),
password: z
.string()
.optional()
.refine((val) => {
// If password is provided, it must be at least 8 characters
if (val) return val.length >= 12;
return true;
}, "String must contain at least 12 character(s)"),
confirmPassword: z.string().optional(),
notifyOnAgentRun: z.boolean(),
notifyOnZeroBalance: z.boolean(),
notifyOnLowBalance: z.boolean(),
notifyOnBlockExecutionFailed: z.boolean(),
notifyOnContinuousAgentError: z.boolean(),
notifyOnDailySummary: z.boolean(),
notifyOnWeeklySummary: z.boolean(),
notifyOnMonthlySummary: z.boolean(),
})
.refine(
(data) => {
if (data.password || data.confirmPassword) {
return data.password === data.confirmPassword;
}
return true;
},
{
message: "Passwords do not match",
path: ["confirmPassword"],
},
);
interface SettingsFormProps {
export const SettingsForm = ({
preferences,
user,
}: {
preferences: NotificationPreference;
user: User;
preferences: NotificationPreferenceDTO;
}
export default function SettingsForm({ user, preferences }: SettingsFormProps) {
const defaultValues = {
email: user.email || "",
password: "",
confirmPassword: "",
notifyOnAgentRun: preferences.preferences.AGENT_RUN,
notifyOnZeroBalance: preferences.preferences.ZERO_BALANCE,
notifyOnLowBalance: preferences.preferences.LOW_BALANCE,
notifyOnBlockExecutionFailed:
preferences.preferences.BLOCK_EXECUTION_FAILED,
notifyOnContinuousAgentError:
preferences.preferences.CONTINUOUS_AGENT_ERROR,
notifyOnDailySummary: preferences.preferences.DAILY_SUMMARY,
notifyOnWeeklySummary: preferences.preferences.WEEKLY_SUMMARY,
notifyOnMonthlySummary: preferences.preferences.MONTHLY_SUMMARY,
};
const form = useForm<z.infer<typeof formSchema>>({
resolver: zodResolver(formSchema),
defaultValues,
}) => {
const { form, onSubmit, onCancel } = useSettingsForm({
preferences,
user,
});
async function onSubmit(values: z.infer<typeof formSchema>) {
try {
const formData = new FormData();
Object.entries(values).forEach(([key, value]) => {
if (key !== "confirmPassword") {
formData.append(key, value.toString());
}
});
await updateSettings(formData);
toast({
title: "Successfully updated settings",
});
} catch (error) {
toast({
title: "Error",
description:
error instanceof Error ? error.message : "Something went wrong",
variant: "destructive",
});
throw error;
}
}
function onCancel() {
form.reset(defaultValues);
}
return (
<Form {...form}>
<form
@@ -396,4 +313,4 @@ export default function SettingsForm({ user, preferences }: SettingsFormProps) {
</form>
</Form>
);
}
};

View File

@@ -0,0 +1,51 @@
import { z } from "zod";
export const formSchema = z
.object({
email: z.string().email(),
password: z
.string()
.optional()
.refine((val) => {
if (val) return val.length >= 12;
return true;
}, "String must contain at least 12 character(s)"),
confirmPassword: z.string().optional(),
notifyOnAgentRun: z.boolean(),
notifyOnZeroBalance: z.boolean(),
notifyOnLowBalance: z.boolean(),
notifyOnBlockExecutionFailed: z.boolean(),
notifyOnContinuousAgentError: z.boolean(),
notifyOnDailySummary: z.boolean(),
notifyOnWeeklySummary: z.boolean(),
notifyOnMonthlySummary: z.boolean(),
})
.refine((data) => {
if (data.password || data.confirmPassword) {
return data.password === data.confirmPassword;
}
return true;
});
export const createDefaultValues = (
user: { email?: string },
preferences: { preferences?: Record<string, boolean> },
) => {
const defaultValues = {
email: user.email || "",
password: "",
confirmPassword: "",
notifyOnAgentRun: preferences.preferences?.AGENT_RUN,
notifyOnZeroBalance: preferences.preferences?.ZERO_BALANCE,
notifyOnLowBalance: preferences.preferences?.LOW_BALANCE,
notifyOnBlockExecutionFailed:
preferences.preferences?.BLOCK_EXECUTION_FAILED,
notifyOnContinuousAgentError:
preferences.preferences?.CONTINUOUS_AGENT_ERROR,
notifyOnDailySummary: preferences.preferences?.DAILY_SUMMARY,
notifyOnWeeklySummary: preferences.preferences?.WEEKLY_SUMMARY,
notifyOnMonthlySummary: preferences.preferences?.MONTHLY_SUMMARY,
};
return defaultValues;
};

View File

@@ -0,0 +1,57 @@
"use client";
import { useForm } from "react-hook-form";
import { createDefaultValues, formSchema } from "./helper";
import { z } from "zod";
import { zodResolver } from "@hookform/resolvers/zod";
import { updateSettings } from "../../actions";
import { useToast } from "@/components/ui/use-toast";
import { NotificationPreference } from "@/api/__generated__/models/notificationPreference";
import { User } from "@supabase/supabase-js";
export const useSettingsForm = ({
preferences,
user,
}: {
preferences: NotificationPreference;
user: User;
}) => {
const { toast } = useToast();
const defaultValues = createDefaultValues(user, preferences);
const form = useForm<z.infer<typeof formSchema>>({
resolver: zodResolver(formSchema),
defaultValues,
});
async function onSubmit(values: z.infer<typeof formSchema>) {
try {
const formData = new FormData();
Object.entries(values).forEach(([key, value]) => {
if (key !== "confirmPassword") {
formData.append(key, value.toString());
}
});
await updateSettings(formData);
toast({
title: "Successfully updated settings",
});
} catch (error) {
toast({
title: "Error",
description:
error instanceof Error ? error.message : "Something went wrong",
variant: "destructive",
});
throw error;
}
}
function onCancel() {
form.reset(defaultValues);
}
return { form, onSubmit, onCancel };
};

View File

@@ -1,23 +1,37 @@
"use client";
import { useGetV1GetNotificationPreferences } from "@/api/__generated__/endpoints/auth/auth";
import { SettingsForm } from "@/app/(platform)/profile/(user)/settings/components/SettingsForm/SettingsForm";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import * as React from "react";
import { Metadata } from "next";
import SettingsForm from "@/components/profile/settings/SettingsForm";
import { getServerUser } from "@/lib/supabase/server/getServerUser";
import SettingsLoading from "./loading";
import { redirect } from "next/navigation";
import { getUserPreferences } from "./actions";
export const metadata: Metadata = {
title: "Settings - AutoGPT Platform",
description: "Manage your account settings and preferences.",
};
export default function SettingsPage() {
const {
data: preferences,
isError,
isLoading,
} = useGetV1GetNotificationPreferences({
query: {
select: (res) => {
return res.data;
},
},
});
export default async function SettingsPage() {
const { user, error } = await getServerUser();
const { user, isUserLoading } = useSupabase();
if (error || !user) {
if (isLoading || isUserLoading) {
return <SettingsLoading />;
}
if (!user) {
redirect("/login");
}
const preferences = await getUserPreferences();
if (isError || !preferences || !preferences.preferences) {
return "Errror..."; // TODO: Will use a Error reusable components from Block Menu redesign
}
return (
<div className="container max-w-2xl space-y-6 py-10">
@@ -27,7 +41,7 @@ export default async function SettingsPage() {
Manage your account settings and preferences.
</p>
</div>
<SettingsForm user={user} preferences={preferences} />
<SettingsForm preferences={preferences} user={user} />
</div>
);
}

View File

@@ -1,8 +1,8 @@
"use server";
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
import { redirect } from "next/navigation";
import * as Sentry from "@sentry/nextjs";
import { verifyTurnstileToken } from "@/lib/turnstile";
import * as Sentry from "@sentry/nextjs";
import { redirect } from "next/navigation";
export async function sendResetEmail(email: string, turnstileToken: string) {
return await Sentry.withServerActionInstrumentation(
@@ -19,14 +19,14 @@ export async function sendResetEmail(email: string, turnstileToken: string) {
// Verify Turnstile token if provided
const success = await verifyTurnstileToken(
turnstileToken,
"reset_password",
"reset-password",
);
if (!success) {
return "CAPTCHA verification failed. Please try again.";
}
const { error } = await supabase.auth.resetPasswordForEmail(email, {
redirectTo: `${origin}/reset_password`,
redirectTo: `${origin}/reset-password`,
});
if (error) {

View File

@@ -1,9 +1,9 @@
"use client";
import {
AuthCard,
AuthHeader,
AuthButton,
AuthCard,
AuthFeedback,
AuthHeader,
PasswordInput,
Turnstile,
} from "@/components/auth";
@@ -17,16 +17,16 @@ import {
FormMessage,
} from "@/components/ui/form";
import { Input } from "@/components/ui/input";
import LoadingBox from "@/components/ui/loading";
import { useTurnstile } from "@/hooks/useTurnstile";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { sendEmailFormSchema, changePasswordFormSchema } from "@/types/auth";
import { getBehaveAs } from "@/lib/utils";
import { changePasswordFormSchema, sendEmailFormSchema } from "@/types/auth";
import { zodResolver } from "@hookform/resolvers/zod";
import { useCallback, useState } from "react";
import { useForm } from "react-hook-form";
import { z } from "zod";
import { changePassword, sendResetEmail } from "./actions";
import LoadingBox from "@/components/ui/loading";
import { getBehaveAs } from "@/lib/utils";
import { useTurnstile } from "@/hooks/useTurnstile";
export default function ResetPasswordPage() {
const { supabase, user, isUserLoading } = useSupabase();

View File

@@ -1,24 +1,21 @@
"use client";
import * as React from "react";
import { useState } from "react";
import Image from "next/image";
import { Button } from "./Button";
import { IconPersonFill } from "@/components/ui/icons";
import { ProfileDetails } from "@/lib/autogpt-server-api/types";
import { Separator } from "@/components/ui/separator";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import { ProfileDetails } from "@/lib/autogpt-server-api/types";
import { Button } from "./Button";
export const ProfileInfoForm = ({ profile }: { profile: ProfileDetails }) => {
export function ProfileInfoForm({ profile }: { profile: ProfileDetails }) {
const [isSubmitting, setIsSubmitting] = useState(false);
const [profileData, setProfileData] = useState<ProfileDetails>(profile);
const { supabase } = useSupabase();
const api = useBackendAPI();
const submitForm = async () => {
async function submitForm() {
try {
setIsSubmitting(true);
@@ -39,48 +36,12 @@ export const ProfileInfoForm = ({ profile }: { profile: ProfileDetails }) => {
} finally {
setIsSubmitting(false);
}
};
}
const handleImageUpload = async (file: File) => {
async function handleImageUpload(file: File) {
try {
// Create FormData and append file
const formData = new FormData();
formData.append("file", file);
const mediaUrl = await api.uploadStoreSubmissionMedia(file);
// Get auth token
if (!supabase) {
throw new Error("Supabase client not initialized");
}
const {
data: { session },
} = await supabase.auth.getSession();
const token = session?.access_token;
if (!token) {
throw new Error("No authentication token found");
}
// Make upload request
const response = await fetch(
`${process.env.NEXT_PUBLIC_AGPT_SERVER_URL}/store/submissions/media`,
{
method: "POST",
headers: {
Authorization: `Bearer ${token}`,
},
body: formData,
},
);
if (!response.ok) {
throw new Error(`Upload failed: ${response.statusText}`);
}
// Get media URL from response
const mediaUrl = await response.json();
// Update profile with new avatar URL
const updatedProfile = {
...profileData,
avatar_url: mediaUrl,
@@ -91,7 +52,7 @@ export const ProfileInfoForm = ({ profile }: { profile: ProfileDetails }) => {
} catch (error) {
console.error("Error uploading image:", error);
}
};
}
return (
<div className="w-full min-w-[800px] px-4 sm:px-8">
@@ -261,4 +222,4 @@ export const ProfileInfoForm = ({ profile }: { profile: ProfileDetails }) => {
</div>
</div>
);
};
}

View File

@@ -1,296 +0,0 @@
"use client";
import { useState, useEffect } from "react";
import { APIKey, APIKeyPermission } from "@/lib/autogpt-server-api/types";
import { LuCopy } from "react-icons/lu";
import { Loader2, MoreVertical } from "lucide-react";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import { useToast } from "@/components/ui/use-toast";
import {
Card,
CardContent,
CardDescription,
CardHeader,
CardTitle,
} from "@/components/ui/card";
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
DialogTrigger,
} from "@/components/ui/dialog";
import { Button } from "@/components/ui/button";
import { Label } from "@/components/ui/label";
import { Input } from "@/components/ui/input";
import { Checkbox } from "@/components/ui/checkbox";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
import { Badge } from "@/components/ui/badge";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
export function APIKeysSection() {
const [apiKeys, setApiKeys] = useState<APIKey[]>([]);
const [isLoading, setIsLoading] = useState(true);
const [isCreateOpen, setIsCreateOpen] = useState(false);
const [isKeyDialogOpen, setIsKeyDialogOpen] = useState(false);
const [newKeyName, setNewKeyName] = useState("");
const [newKeyDescription, setNewKeyDescription] = useState("");
const [newApiKey, setNewApiKey] = useState("");
const [selectedPermissions, setSelectedPermissions] = useState<
APIKeyPermission[]
>([]);
const { toast } = useToast();
const api = useBackendAPI();
useEffect(() => {
loadAPIKeys();
}, []);
const loadAPIKeys = async () => {
setIsLoading(true);
try {
const keys = await api.listAPIKeys();
setApiKeys(keys.filter((key) => key.status === "ACTIVE"));
} finally {
setIsLoading(false);
}
};
const handleCreateKey = async () => {
try {
const response = await api.createAPIKey(
newKeyName,
selectedPermissions,
newKeyDescription,
);
setNewApiKey(response.plain_text_key);
setIsCreateOpen(false);
setIsKeyDialogOpen(true);
loadAPIKeys();
} catch {
toast({
title: "Error",
description: "Failed to create AutoGPT Platform API key",
variant: "destructive",
});
}
};
const handleCopyKey = () => {
navigator.clipboard.writeText(newApiKey);
toast({
title: "Copied",
description: "AutoGPT Platform API key copied to clipboard",
});
};
const handleRevokeKey = async (keyId: string) => {
try {
await api.revokeAPIKey(keyId);
toast({
title: "Success",
description: "AutoGPT Platform API key revoked successfully",
});
loadAPIKeys();
} catch {
toast({
title: "Error",
description: "Failed to revoke AutoGPT Platform API key",
variant: "destructive",
});
}
};
return (
<Card>
<CardHeader>
<CardTitle>AutoGPT Platform API Keys</CardTitle>
<CardDescription>
Manage your AutoGPT Platform API keys for programmatic access
</CardDescription>
</CardHeader>
<CardContent>
<div className="mb-4 flex justify-end">
<Dialog open={isCreateOpen} onOpenChange={setIsCreateOpen}>
<DialogTrigger asChild>
<Button>Create Key</Button>
</DialogTrigger>
<DialogContent>
<DialogHeader>
<DialogTitle>Create New API Key</DialogTitle>
<DialogDescription>
Create a new AutoGPT Platform API key
</DialogDescription>
</DialogHeader>
<div className="grid gap-4 py-4">
<div className="grid gap-2">
<Label htmlFor="name">Name</Label>
<Input
id="name"
value={newKeyName}
onChange={(e) => setNewKeyName(e.target.value)}
placeholder="My AutoGPT Platform API Key"
/>
</div>
<div className="grid gap-2">
<Label htmlFor="description">Description (Optional)</Label>
<Input
id="description"
value={newKeyDescription}
onChange={(e) => setNewKeyDescription(e.target.value)}
placeholder="Used for..."
/>
</div>
<div className="grid gap-2">
<Label>Permissions</Label>
{Object.values(APIKeyPermission).map((permission) => (
<div
className="flex items-center space-x-2"
key={permission}
>
<Checkbox
id={permission}
checked={selectedPermissions.includes(permission)}
onCheckedChange={(checked: boolean) => {
setSelectedPermissions(
checked
? [...selectedPermissions, permission]
: selectedPermissions.filter(
(p) => p !== permission,
),
);
}}
/>
<Label htmlFor={permission}>{permission}</Label>
</div>
))}
</div>
</div>
<DialogFooter>
<Button
variant="outline"
onClick={() => setIsCreateOpen(false)}
>
Cancel
</Button>
<Button onClick={handleCreateKey}>Create</Button>
</DialogFooter>
</DialogContent>
</Dialog>
<Dialog open={isKeyDialogOpen} onOpenChange={setIsKeyDialogOpen}>
<DialogContent>
<DialogHeader>
<DialogTitle>AutoGPT Platform API Key Created</DialogTitle>
<DialogDescription>
Please copy your AutoGPT API key now. You won&apos;t be able
to see it again!
</DialogDescription>
</DialogHeader>
<div className="flex items-center space-x-2">
<code className="flex-1 rounded-md bg-secondary p-2 text-sm">
{newApiKey}
</code>
<Button size="icon" variant="outline" onClick={handleCopyKey}>
<LuCopy className="h-4 w-4" />
</Button>
</div>
<DialogFooter>
<Button onClick={() => setIsKeyDialogOpen(false)}>Close</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</div>
{isLoading ? (
<div className="flex justify-center p-4">
<Loader2 className="h-6 w-6 animate-spin" />
</div>
) : (
apiKeys.length > 0 && (
<Table>
<TableHeader>
<TableRow>
<TableHead>Name</TableHead>
<TableHead>API Key</TableHead>
<TableHead>Status</TableHead>
<TableHead>Created</TableHead>
<TableHead>Last Used</TableHead>
<TableHead></TableHead>
</TableRow>
</TableHeader>
<TableBody>
{apiKeys.map((key) => (
<TableRow key={key.id}>
<TableCell>{key.name}</TableCell>
<TableCell>
<div className="rounded-md border p-1 px-2 text-xs">
{`${key.prefix}******************${key.postfix}`}
</div>
</TableCell>
<TableCell>
<Badge
variant={
key.status === "ACTIVE" ? "default" : "destructive"
}
className={
key.status === "ACTIVE"
? "border-green-600 bg-green-100 text-green-800"
: "border-red-600 bg-red-100 text-red-800"
}
>
{key.status}
</Badge>
</TableCell>
<TableCell>
{new Date(key.created_at).toLocaleDateString()}
</TableCell>
<TableCell>
{key.last_used_at
? new Date(key.last_used_at).toLocaleDateString()
: "Never"}
</TableCell>
<TableCell>
<DropdownMenu>
<DropdownMenuTrigger asChild>
<Button variant="ghost" size="sm">
<MoreVertical className="h-4 w-4" />
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuItem
className="text-destructive"
onClick={() => handleRevokeKey(key.id)}
>
Revoke
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
)
)}
</CardContent>
</Card>
);
}

View File

@@ -0,0 +1,163 @@
import { FC } from "react";
import { z } from "zod";
import { useForm } from "react-hook-form";
import { zodResolver } from "@hookform/resolvers/zod";
import { Input } from "@/components/ui/input";
import { Button } from "@/components/ui/button";
import {
Dialog,
DialogContent,
DialogDescription,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import {
Form,
FormControl,
FormDescription,
FormField,
FormItem,
FormLabel,
FormMessage,
} from "@/components/ui/form";
import useCredentials from "@/hooks/useCredentials";
import {
BlockIOCredentialsSubSchema,
CredentialsMetaInput,
} from "@/lib/autogpt-server-api/types";
export const APIKeyCredentialsModal: FC<{
schema: BlockIOCredentialsSubSchema;
open: boolean;
onClose: () => void;
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
siblingInputs?: Record<string, any>;
}> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
const credentials = useCredentials(schema, siblingInputs);
const formSchema = z.object({
apiKey: z.string().min(1, "API Key is required"),
title: z.string().min(1, "Name is required"),
expiresAt: z.string().optional(),
});
const form = useForm<z.infer<typeof formSchema>>({
resolver: zodResolver(formSchema),
defaultValues: {
apiKey: "",
title: "",
expiresAt: "",
},
});
if (!credentials || credentials.isLoading || !credentials.supportsApiKey) {
return null;
}
const { provider, providerName, createAPIKeyCredentials } = credentials;
async function onSubmit(values: z.infer<typeof formSchema>) {
const expiresAt = values.expiresAt
? new Date(values.expiresAt).getTime() / 1000
: undefined;
const newCredentials = await createAPIKeyCredentials({
api_key: values.apiKey,
title: values.title,
expires_at: expiresAt,
});
onCredentialsCreate({
provider,
id: newCredentials.id,
type: "api_key",
title: newCredentials.title,
});
}
return (
<Dialog
open={open}
onOpenChange={(open) => {
if (!open) onClose();
}}
>
<DialogContent>
<DialogHeader>
<DialogTitle>Add new API key for {providerName}</DialogTitle>
{schema.description && (
<DialogDescription>{schema.description}</DialogDescription>
)}
</DialogHeader>
<Form {...form}>
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
<FormField
control={form.control}
name="apiKey"
render={({ field }) => (
<FormItem>
<FormLabel>API Key</FormLabel>
{schema.credentials_scopes && (
<FormDescription>
Required scope(s) for this block:{" "}
{schema.credentials_scopes?.map((s, i, a) => (
<span key={i}>
<code>{s}</code>
{i < a.length - 1 && ", "}
</span>
))}
</FormDescription>
)}
<FormControl>
<Input
type="password"
placeholder="Enter API key..."
{...field}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="title"
render={({ field }) => (
<FormItem>
<FormLabel>Name</FormLabel>
<FormControl>
<Input
type="text"
placeholder="Enter a name for this API key..."
{...field}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="expiresAt"
render={({ field }) => (
<FormItem>
<FormLabel>Expiration Date (Optional)</FormLabel>
<FormControl>
<Input
type="datetime-local"
placeholder="Select expiration date..."
{...field}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<Button type="submit" className="w-full">
Save & use this API key
</Button>
</form>
</Form>
</DialogContent>
</Dialog>
);
};

View File

@@ -1,12 +1,8 @@
import { FC, useEffect, useMemo, useState } from "react";
import { z } from "zod";
import { cn } from "@/lib/utils";
import { useForm } from "react-hook-form";
import { Input } from "@/components/ui/input";
import { Button } from "@/components/ui/button";
import SchemaTooltip from "@/components/SchemaTooltip";
import useCredentials from "@/hooks/useCredentials";
import { zodResolver } from "@hookform/resolvers/zod";
import { NotionLogoIcon } from "@radix-ui/react-icons";
import {
FaDiscord,
@@ -23,22 +19,6 @@ import {
CredentialsProviderName,
} from "@/lib/autogpt-server-api/types";
import { IconKey, IconKeyPlus, IconUserPlus } from "@/components/ui/icons";
import {
Dialog,
DialogContent,
DialogDescription,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import {
Form,
FormControl,
FormDescription,
FormField,
FormItem,
FormLabel,
FormMessage,
} from "@/components/ui/form";
import {
Select,
SelectContent,
@@ -48,6 +28,11 @@ import {
SelectValue,
} from "@/components/ui/select";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import { APIKeyCredentialsModal } from "./api-key-credentials-modal";
import { UserPasswordCredentialsModal } from "./user-password-credentials-modal";
import { HostScopedCredentialsModal } from "./host-scoped-credentials-modal";
import { OAuth2FlowWaitingModal } from "./oauth2-flow-waiting-modal";
import { getHostFromUrl } from "@/lib/utils/url";
const fallbackIcon = FaKey;
@@ -63,6 +48,7 @@ export const providerIcons: Record<
github: FaGithub,
google: FaGoogle,
groq: fallbackIcon,
http: fallbackIcon,
notion: NotionLogoIcon,
nvidia: fallbackIcon,
discord: FaDiscord,
@@ -129,6 +115,8 @@ export const CredentialsInput: FC<{
isUserPasswordCredentialsModalOpen,
setUserPasswordCredentialsModalOpen,
] = useState(false);
const [isHostScopedCredentialsModalOpen, setHostScopedCredentialsModalOpen] =
useState(false);
const [isOAuth2FlowInProgress, setOAuth2FlowInProgress] = useState(false);
const [oAuthPopupController, setOAuthPopupController] =
useState<AbortController | null>(null);
@@ -148,13 +136,27 @@ export const CredentialsInput: FC<{
}
}, [credentials, selectedCredentials, onSelectCredentials]);
const singleCredential = useMemo(() => {
if (!credentials || !("savedCredentials" in credentials)) return null;
const { hasRelevantCredentials, singleCredential } = useMemo(() => {
if (!credentials || !("savedCredentials" in credentials)) {
return {
hasRelevantCredentials: false,
singleCredential: null,
};
}
if (credentials.savedCredentials.length === 1)
return credentials.savedCredentials[0];
// Simple logic: if we have any saved credentials, we have relevant credentials
const hasRelevant = credentials.savedCredentials.length > 0;
return null;
// Auto-select single credential if only one exists
const single =
credentials.savedCredentials.length === 1
? credentials.savedCredentials[0]
: null;
return {
hasRelevantCredentials: hasRelevant,
singleCredential: single,
};
}, [credentials]);
// If only 1 credential is available, auto-select it and hide this input
@@ -178,6 +180,7 @@ export const CredentialsInput: FC<{
supportsApiKey,
supportsOAuth2,
supportsUserPassword,
supportsHostScoped,
savedCredentials,
oAuthCallback,
} = credentials;
@@ -271,7 +274,7 @@ export const CredentialsInput: FC<{
);
}
const ProviderIcon = providerIcons[provider];
const ProviderIcon = providerIcons[provider] || fallbackIcon;
const modals = (
<>
{supportsApiKey && (
@@ -305,6 +308,18 @@ export const CredentialsInput: FC<{
siblingInputs={siblingInputs}
/>
)}
{supportsHostScoped && (
<HostScopedCredentialsModal
schema={schema}
open={isHostScopedCredentialsModalOpen}
onClose={() => setHostScopedCredentialsModalOpen(false)}
onCredentialsCreate={(creds) => {
onSelectCredentials(creds);
setHostScopedCredentialsModalOpen(false);
}}
siblingInputs={siblingInputs}
/>
)}
</>
);
@@ -317,8 +332,8 @@ export const CredentialsInput: FC<{
</div>
);
// No saved credentials yet
if (savedCredentials.length === 0) {
// Show credentials creation UI when no relevant credentials exist
if (!hasRelevantCredentials) {
return (
<div>
{fieldHeader}
@@ -342,6 +357,12 @@ export const CredentialsInput: FC<{
Enter username and password
</Button>
)}
{supportsHostScoped && credentials.discriminatorValue && (
<Button onClick={() => setHostScopedCredentialsModalOpen(true)}>
<ProviderIcon className="mr-2 h-4 w-4" />
{`Enter sensitive headers for ${getHostFromUrl(credentials.discriminatorValue)}`}
</Button>
)}
</div>
{modals}
{oAuthError && (
@@ -358,6 +379,12 @@ export const CredentialsInput: FC<{
} else if (newValue === "add-api-key") {
// Open API key dialog
setAPICredentialsModalOpen(true);
} else if (newValue === "add-user-password") {
// Open user password dialog
setUserPasswordCredentialsModalOpen(true);
} else if (newValue === "add-host-scoped") {
// Open host-scoped credentials dialog
setHostScopedCredentialsModalOpen(true);
} else {
const selectedCreds = savedCredentials.find((c) => c.id == newValue)!;
@@ -406,6 +433,15 @@ export const CredentialsInput: FC<{
{credentials.title}
</SelectItem>
))}
{savedCredentials
.filter((c) => c.type == "host_scoped")
.map((credentials, index) => (
<SelectItem key={index} value={credentials.id}>
<ProviderIcon className="mr-2 inline h-4 w-4" />
<IconKey className="mr-1.5 inline" />
{credentials.title}
</SelectItem>
))}
<SelectSeparator />
{supportsOAuth2 && (
<SelectItem value="sign-in">
@@ -425,6 +461,12 @@ export const CredentialsInput: FC<{
Add new user password
</SelectItem>
)}
{supportsHostScoped && (
<SelectItem value="add-host-scoped">
<IconKey className="mr-1.5 inline" />
Add host-scoped headers
</SelectItem>
)}
</SelectContent>
</Select>
{modals}
@@ -434,291 +476,3 @@ export const CredentialsInput: FC<{
</div>
);
};
export const APIKeyCredentialsModal: FC<{
schema: BlockIOCredentialsSubSchema;
open: boolean;
onClose: () => void;
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
siblingInputs?: Record<string, any>;
}> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
const credentials = useCredentials(schema, siblingInputs);
const formSchema = z.object({
apiKey: z.string().min(1, "API Key is required"),
title: z.string().min(1, "Name is required"),
expiresAt: z.string().optional(),
});
const form = useForm<z.infer<typeof formSchema>>({
resolver: zodResolver(formSchema),
defaultValues: {
apiKey: "",
title: "",
expiresAt: "",
},
});
if (!credentials || credentials.isLoading || !credentials.supportsApiKey) {
return null;
}
const { provider, providerName, createAPIKeyCredentials } = credentials;
async function onSubmit(values: z.infer<typeof formSchema>) {
const expiresAt = values.expiresAt
? new Date(values.expiresAt).getTime() / 1000
: undefined;
const newCredentials = await createAPIKeyCredentials({
api_key: values.apiKey,
title: values.title,
expires_at: expiresAt,
});
onCredentialsCreate({
provider,
id: newCredentials.id,
type: "api_key",
title: newCredentials.title,
});
}
return (
<Dialog
open={open}
onOpenChange={(open) => {
if (!open) onClose();
}}
>
<DialogContent>
<DialogHeader>
<DialogTitle>Add new API key for {providerName}</DialogTitle>
{schema.description && (
<DialogDescription>{schema.description}</DialogDescription>
)}
</DialogHeader>
<Form {...form}>
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
<FormField
control={form.control}
name="apiKey"
render={({ field }) => (
<FormItem>
<FormLabel>API Key</FormLabel>
{schema.credentials_scopes && (
<FormDescription>
Required scope(s) for this block:{" "}
{schema.credentials_scopes?.map((s, i, a) => (
<span key={i}>
<code>{s}</code>
{i < a.length - 1 && ", "}
</span>
))}
</FormDescription>
)}
<FormControl>
<Input
type="password"
placeholder="Enter API key..."
{...field}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="title"
render={({ field }) => (
<FormItem>
<FormLabel>Name</FormLabel>
<FormControl>
<Input
type="text"
placeholder="Enter a name for this API key..."
{...field}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="expiresAt"
render={({ field }) => (
<FormItem>
<FormLabel>Expiration Date (Optional)</FormLabel>
<FormControl>
<Input
type="datetime-local"
placeholder="Select expiration date..."
{...field}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<Button type="submit" className="w-full">
Save & use this API key
</Button>
</form>
</Form>
</DialogContent>
</Dialog>
);
};
export const UserPasswordCredentialsModal: FC<{
schema: BlockIOCredentialsSubSchema;
open: boolean;
onClose: () => void;
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
siblingInputs?: Record<string, any>;
}> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
const credentials = useCredentials(schema, siblingInputs);
const formSchema = z.object({
username: z.string().min(1, "Username is required"),
password: z.string().min(1, "Password is required"),
title: z.string().min(1, "Name is required"),
});
const form = useForm<z.infer<typeof formSchema>>({
resolver: zodResolver(formSchema),
defaultValues: {
username: "",
password: "",
title: "",
},
});
if (
!credentials ||
credentials.isLoading ||
!credentials.supportsUserPassword
) {
return null;
}
const { provider, providerName, createUserPasswordCredentials } = credentials;
async function onSubmit(values: z.infer<typeof formSchema>) {
const newCredentials = await createUserPasswordCredentials({
username: values.username,
password: values.password,
title: values.title,
});
onCredentialsCreate({
provider,
id: newCredentials.id,
type: "user_password",
title: newCredentials.title,
});
}
return (
<Dialog
open={open}
onOpenChange={(open) => {
if (!open) onClose();
}}
>
<DialogContent>
<DialogHeader>
<DialogTitle>
Add new username & password for {providerName}
</DialogTitle>
</DialogHeader>
<Form {...form}>
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
<FormField
control={form.control}
name="username"
render={({ field }) => (
<FormItem>
<FormLabel>Username</FormLabel>
<FormControl>
<Input
type="text"
placeholder="Enter username..."
{...field}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="password"
render={({ field }) => (
<FormItem>
<FormLabel>Password</FormLabel>
<FormControl>
<Input
type="password"
placeholder="Enter password..."
{...field}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="title"
render={({ field }) => (
<FormItem>
<FormLabel>Name</FormLabel>
<FormControl>
<Input
type="text"
placeholder="Enter a name for this user login..."
{...field}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<Button type="submit" className="w-full">
Save & use this user login
</Button>
</form>
</Form>
</DialogContent>
</Dialog>
);
};
export const OAuth2FlowWaitingModal: FC<{
open: boolean;
onClose: () => void;
providerName: string;
}> = ({ open, onClose, providerName }) => {
return (
<Dialog
open={open}
onOpenChange={(open) => {
if (!open) onClose();
}}
>
<DialogContent>
<DialogHeader>
<DialogTitle>
Waiting on {providerName} sign-in process...
</DialogTitle>
<DialogDescription>
Complete the sign-in process in the pop-up window.
<br />
Closing this dialog will cancel the sign-in process.
</DialogDescription>
</DialogHeader>
</DialogContent>
</Dialog>
);
};

View File

@@ -6,10 +6,12 @@ import {
CredentialsDeleteResponse,
CredentialsMetaResponse,
CredentialsProviderName,
HostScopedCredentials,
PROVIDER_NAMES,
UserPasswordCredentials,
} from "@/lib/autogpt-server-api";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import { useToastOnFail } from "@/components/ui/use-toast";
// Get keys from CredentialsProviderName type
const CREDENTIALS_PROVIDER_NAMES = Object.values(
@@ -30,6 +32,7 @@ const providerDisplayNames: Record<CredentialsProviderName, string> = {
google: "Google",
google_maps: "Google Maps",
groq: "Groq",
http: "HTTP",
hubspot: "Hubspot",
ideogram: "Ideogram",
jina: "Jina",
@@ -68,6 +71,11 @@ type UserPasswordCredentialsCreatable = Omit<
"id" | "provider" | "type"
>;
type HostScopedCredentialsCreatable = Omit<
HostScopedCredentials,
"id" | "provider" | "type"
>;
export type CredentialsProviderData = {
provider: CredentialsProviderName;
providerName: string;
@@ -82,6 +90,9 @@ export type CredentialsProviderData = {
createUserPasswordCredentials: (
credentials: UserPasswordCredentialsCreatable,
) => Promise<CredentialsMetaResponse>;
createHostScopedCredentials: (
credentials: HostScopedCredentialsCreatable,
) => Promise<CredentialsMetaResponse>;
deleteCredentials: (
id: string,
force?: boolean,
@@ -106,6 +117,7 @@ export default function CredentialsProvider({
useState<CredentialsProvidersContextType | null>(null);
const { isLoggedIn } = useSupabase();
const api = useBackendAPI();
const onFailToast = useToastOnFail();
const addCredentials = useCallback(
(
@@ -134,11 +146,16 @@ export default function CredentialsProvider({
code: string,
state_token: string,
): Promise<CredentialsMetaResponse> => {
const credsMeta = await api.oAuthCallback(provider, code, state_token);
addCredentials(provider, credsMeta);
return credsMeta;
try {
const credsMeta = await api.oAuthCallback(provider, code, state_token);
addCredentials(provider, credsMeta);
return credsMeta;
} catch (error) {
onFailToast("complete OAuth authentication")(error);
throw error;
}
},
[api, addCredentials],
[api, addCredentials, onFailToast],
);
/** Wraps `BackendAPI.createAPIKeyCredentials`, and adds the result to the internal credentials store. */
@@ -147,14 +164,19 @@ export default function CredentialsProvider({
provider: CredentialsProviderName,
credentials: APIKeyCredentialsCreatable,
): Promise<CredentialsMetaResponse> => {
const credsMeta = await api.createAPIKeyCredentials({
provider,
...credentials,
});
addCredentials(provider, credsMeta);
return credsMeta;
try {
const credsMeta = await api.createAPIKeyCredentials({
provider,
...credentials,
});
addCredentials(provider, credsMeta);
return credsMeta;
} catch (error) {
onFailToast("create API key credentials")(error);
throw error;
}
},
[api, addCredentials],
[api, addCredentials, onFailToast],
);
/** Wraps `BackendAPI.createUserPasswordCredentials`, and adds the result to the internal credentials store. */
@@ -163,14 +185,40 @@ export default function CredentialsProvider({
provider: CredentialsProviderName,
credentials: UserPasswordCredentialsCreatable,
): Promise<CredentialsMetaResponse> => {
const credsMeta = await api.createUserPasswordCredentials({
provider,
...credentials,
});
addCredentials(provider, credsMeta);
return credsMeta;
try {
const credsMeta = await api.createUserPasswordCredentials({
provider,
...credentials,
});
addCredentials(provider, credsMeta);
return credsMeta;
} catch (error) {
onFailToast("create user/password credentials")(error);
throw error;
}
},
[api, addCredentials],
[api, addCredentials, onFailToast],
);
/** Wraps `BackendAPI.createHostScopedCredentials`, and adds the result to the internal credentials store. */
const createHostScopedCredentials = useCallback(
async (
provider: CredentialsProviderName,
credentials: HostScopedCredentialsCreatable,
): Promise<CredentialsMetaResponse> => {
try {
const credsMeta = await api.createHostScopedCredentials({
provider,
...credentials,
});
addCredentials(provider, credsMeta);
return credsMeta;
} catch (error) {
onFailToast("create host-scoped credentials")(error);
throw error;
}
},
[api, addCredentials, onFailToast],
);
/** Wraps `BackendAPI.deleteCredentials`, and removes the credentials from the internal store. */
@@ -182,26 +230,31 @@ export default function CredentialsProvider({
): Promise<
CredentialsDeleteResponse | CredentialsDeleteNeedConfirmationResponse
> => {
const result = await api.deleteCredentials(provider, id, force);
if (!result.deleted) {
return result;
}
setProviders((prev) => {
if (!prev || !prev[provider]) return prev;
try {
const result = await api.deleteCredentials(provider, id, force);
if (!result.deleted) {
return result;
}
setProviders((prev) => {
if (!prev || !prev[provider]) return prev;
return {
...prev,
[provider]: {
...prev[provider],
savedCredentials: prev[provider].savedCredentials.filter(
(cred) => cred.id !== id,
),
},
};
});
return result;
return {
...prev,
[provider]: {
...prev[provider],
savedCredentials: prev[provider].savedCredentials.filter(
(cred) => cred.id !== id,
),
},
};
});
return result;
} catch (error) {
onFailToast("delete credentials")(error);
throw error;
}
},
[api],
[api, onFailToast],
);
useEffect(() => {
@@ -210,47 +263,54 @@ export default function CredentialsProvider({
return;
}
api.listCredentials().then((response) => {
const credentialsByProvider = response.reduce(
(acc, cred) => {
if (!acc[cred.provider]) {
acc[cred.provider] = [];
}
acc[cred.provider].push(cred);
return acc;
},
{} as Record<CredentialsProviderName, CredentialsMetaResponse[]>,
);
api
.listCredentials()
.then((response) => {
const credentialsByProvider = response.reduce(
(acc, cred) => {
if (!acc[cred.provider]) {
acc[cred.provider] = [];
}
acc[cred.provider].push(cred);
return acc;
},
{} as Record<CredentialsProviderName, CredentialsMetaResponse[]>,
);
setProviders((prev) => ({
...prev,
...Object.fromEntries(
CREDENTIALS_PROVIDER_NAMES.map((provider) => [
provider,
{
setProviders((prev) => ({
...prev,
...Object.fromEntries(
CREDENTIALS_PROVIDER_NAMES.map((provider) => [
provider,
providerName: providerDisplayNames[provider],
savedCredentials: credentialsByProvider[provider] ?? [],
oAuthCallback: (code: string, state_token: string) =>
oAuthCallback(provider, code, state_token),
createAPIKeyCredentials: (
credentials: APIKeyCredentialsCreatable,
) => createAPIKeyCredentials(provider, credentials),
createUserPasswordCredentials: (
credentials: UserPasswordCredentialsCreatable,
) => createUserPasswordCredentials(provider, credentials),
deleteCredentials: (id: string, force: boolean = false) =>
deleteCredentials(provider, id, force),
} satisfies CredentialsProviderData,
]),
),
}));
});
{
provider,
providerName: providerDisplayNames[provider],
savedCredentials: credentialsByProvider[provider] ?? [],
oAuthCallback: (code: string, state_token: string) =>
oAuthCallback(provider, code, state_token),
createAPIKeyCredentials: (
credentials: APIKeyCredentialsCreatable,
) => createAPIKeyCredentials(provider, credentials),
createUserPasswordCredentials: (
credentials: UserPasswordCredentialsCreatable,
) => createUserPasswordCredentials(provider, credentials),
createHostScopedCredentials: (
credentials: HostScopedCredentialsCreatable,
) => createHostScopedCredentials(provider, credentials),
deleteCredentials: (id: string, force: boolean = false) =>
deleteCredentials(provider, id, force),
} satisfies CredentialsProviderData,
]),
),
}));
})
.catch(onFailToast("load credentials"));
}, [
api,
isLoggedIn,
createAPIKeyCredentials,
createUserPasswordCredentials,
createHostScopedCredentials,
deleteCredentials,
oAuthCallback,
]);

View File

@@ -0,0 +1,235 @@
import { FC, useEffect, useState } from "react";
import { z } from "zod";
import { useForm } from "react-hook-form";
import { zodResolver } from "@hookform/resolvers/zod";
import { Input } from "@/components/ui/input";
import { Button } from "@/components/ui/button";
import {
Dialog,
DialogContent,
DialogDescription,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import {
Form,
FormControl,
FormDescription,
FormField,
FormItem,
FormLabel,
FormMessage,
} from "@/components/ui/form";
import useCredentials from "@/hooks/useCredentials";
import {
BlockIOCredentialsSubSchema,
CredentialsMetaInput,
} from "@/lib/autogpt-server-api/types";
import { getHostFromUrl } from "@/lib/utils/url";
export const HostScopedCredentialsModal: FC<{
schema: BlockIOCredentialsSubSchema;
open: boolean;
onClose: () => void;
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
siblingInputs?: Record<string, any>;
}> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
const credentials = useCredentials(schema, siblingInputs);
// Get current host from siblingInputs or discriminator_values
const currentUrl = credentials?.discriminatorValue;
const currentHost = currentUrl ? getHostFromUrl(currentUrl) : "";
const formSchema = z.object({
host: z.string().min(1, "Host is required"),
title: z.string().optional(),
headers: z.record(z.string()).optional(),
});
const form = useForm<z.infer<typeof formSchema>>({
resolver: zodResolver(formSchema),
defaultValues: {
host: currentHost || "",
title: currentHost || "Manual Entry",
headers: {},
},
});
const [headerPairs, setHeaderPairs] = useState<
Array<{ key: string; value: string }>
>([{ key: "", value: "" }]);
// Update form values when siblingInputs change
useEffect(() => {
if (currentHost) {
form.setValue("host", currentHost);
form.setValue("title", currentHost);
} else {
// Reset to empty when no current host
form.setValue("host", "");
form.setValue("title", "Manual Entry");
}
}, [currentHost, form]);
if (
!credentials ||
credentials.isLoading ||
!credentials.supportsHostScoped
) {
return null;
}
const { provider, providerName, createHostScopedCredentials } = credentials;
const addHeaderPair = () => {
setHeaderPairs([...headerPairs, { key: "", value: "" }]);
};
const removeHeaderPair = (index: number) => {
if (headerPairs.length > 1) {
setHeaderPairs(headerPairs.filter((_, i) => i !== index));
}
};
const updateHeaderPair = (
index: number,
field: "key" | "value",
value: string,
) => {
const newPairs = [...headerPairs];
newPairs[index][field] = value;
setHeaderPairs(newPairs);
};
async function onSubmit(values: z.infer<typeof formSchema>) {
// Convert header pairs to object, filtering out empty pairs
const headers = headerPairs.reduce(
(acc, pair) => {
if (pair.key.trim() && pair.value.trim()) {
acc[pair.key.trim()] = pair.value.trim();
}
return acc;
},
{} as Record<string, string>,
);
const newCredentials = await createHostScopedCredentials({
host: values.host,
title: currentHost || values.host,
headers,
});
onCredentialsCreate({
provider,
id: newCredentials.id,
type: "host_scoped",
title: newCredentials.title,
});
}
return (
<Dialog
open={open}
onOpenChange={(open) => {
if (!open) onClose();
}}
>
<DialogContent className="max-h-[90vh] max-w-2xl overflow-y-auto">
<DialogHeader>
<DialogTitle>Add sensitive headers for {providerName}</DialogTitle>
{schema.description && (
<DialogDescription>{schema.description}</DialogDescription>
)}
</DialogHeader>
<Form {...form}>
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
<FormField
control={form.control}
name="host"
render={({ field }) => (
<FormItem>
<FormLabel>Host Pattern</FormLabel>
<FormDescription>
{currentHost
? "Auto-populated from the URL field. Headers will be applied to requests to this host."
: "Enter the host/domain to match against request URLs (e.g., api.example.com)."}
</FormDescription>
<FormControl>
<Input
type="text"
readOnly={!!currentHost}
placeholder={
currentHost
? undefined
: "Enter host (e.g., api.example.com)"
}
{...field}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<div className="space-y-2">
<FormLabel>Headers</FormLabel>
<FormDescription>
Add sensitive headers (like Authorization, X-API-Key) that
should be automatically included in requests to the specified
host.
</FormDescription>
{headerPairs.map((pair, index) => (
<div key={index} className="flex items-end gap-2">
<div className="flex-1">
<Input
placeholder="Header name (e.g., Authorization)"
value={pair.key}
onChange={(e) =>
updateHeaderPair(index, "key", e.target.value)
}
/>
</div>
<div className="flex-1">
<Input
type="password"
placeholder="Header value (e.g., Bearer token123)"
value={pair.value}
onChange={(e) =>
updateHeaderPair(index, "value", e.target.value)
}
/>
</div>
<Button
type="button"
variant="outline"
size="sm"
onClick={() => removeHeaderPair(index)}
disabled={headerPairs.length === 1}
>
Remove
</Button>
</div>
))}
<Button
type="button"
variant="outline"
size="sm"
onClick={addHeaderPair}
className="w-full"
>
Add Another Header
</Button>
</div>
<Button type="submit" className="w-full">
Save & use these credentials
</Button>
</form>
</Form>
</DialogContent>
</Dialog>
);
};

View File

@@ -0,0 +1,36 @@
import { FC } from "react";
import {
Dialog,
DialogContent,
DialogDescription,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
export const OAuth2FlowWaitingModal: FC<{
open: boolean;
onClose: () => void;
providerName: string;
}> = ({ open, onClose, providerName }) => {
return (
<Dialog
open={open}
onOpenChange={(open) => {
if (!open) onClose();
}}
>
<DialogContent>
<DialogHeader>
<DialogTitle>
Waiting on {providerName} sign-in process...
</DialogTitle>
<DialogDescription>
Complete the sign-in process in the pop-up window.
<br />
Closing this dialog will cancel the sign-in process.
</DialogDescription>
</DialogHeader>
</DialogContent>
</Dialog>
);
};

View File

@@ -0,0 +1,149 @@
import { FC } from "react";
import { z } from "zod";
import { useForm } from "react-hook-form";
import { zodResolver } from "@hookform/resolvers/zod";
import { Input } from "@/components/ui/input";
import { Button } from "@/components/ui/button";
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import {
Form,
FormControl,
FormField,
FormItem,
FormLabel,
FormMessage,
} from "@/components/ui/form";
import useCredentials from "@/hooks/useCredentials";
import {
BlockIOCredentialsSubSchema,
CredentialsMetaInput,
} from "@/lib/autogpt-server-api/types";
export const UserPasswordCredentialsModal: FC<{
schema: BlockIOCredentialsSubSchema;
open: boolean;
onClose: () => void;
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
siblingInputs?: Record<string, any>;
}> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
const credentials = useCredentials(schema, siblingInputs);
const formSchema = z.object({
username: z.string().min(1, "Username is required"),
password: z.string().min(1, "Password is required"),
title: z.string().min(1, "Name is required"),
});
const form = useForm<z.infer<typeof formSchema>>({
resolver: zodResolver(formSchema),
defaultValues: {
username: "",
password: "",
title: "",
},
});
if (
!credentials ||
credentials.isLoading ||
!credentials.supportsUserPassword
) {
return null;
}
const { provider, providerName, createUserPasswordCredentials } = credentials;
async function onSubmit(values: z.infer<typeof formSchema>) {
const newCredentials = await createUserPasswordCredentials({
username: values.username,
password: values.password,
title: values.title,
});
onCredentialsCreate({
provider,
id: newCredentials.id,
type: "user_password",
title: newCredentials.title,
});
}
return (
<Dialog
open={open}
onOpenChange={(open) => {
if (!open) onClose();
}}
>
<DialogContent>
<DialogHeader>
<DialogTitle>
Add new username & password for {providerName}
</DialogTitle>
</DialogHeader>
<Form {...form}>
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
<FormField
control={form.control}
name="username"
render={({ field }) => (
<FormItem>
<FormLabel>Username</FormLabel>
<FormControl>
<Input
type="text"
placeholder="Enter username..."
{...field}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="password"
render={({ field }) => (
<FormItem>
<FormLabel>Password</FormLabel>
<FormControl>
<Input
type="password"
placeholder="Enter password..."
{...field}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="title"
render={({ field }) => (
<FormItem>
<FormLabel>Name</FormLabel>
<FormControl>
<Input
type="text"
placeholder="Enter a name for this user login..."
{...field}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<Button type="submit" className="w-full">
Save & use this user login
</Button>
</form>
</Form>
</DialogContent>
</Dialog>
);
};

View File

@@ -1,7 +1,17 @@
"use client";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { Button } from "@/components/ui/button";
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import { OnboardingStep, UserOnboarding } from "@/lib/autogpt-server-api";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import Link from "next/link";
import { usePathname, useRouter } from "next/navigation";
import {
createContext,
@@ -11,16 +21,6 @@ import {
useEffect,
useState,
} from "react";
import {
Dialog,
DialogContent,
DialogDescription,
DialogHeader,
DialogTitle,
DialogFooter,
} from "@/components/ui/dialog";
import { Button } from "@/components/ui/button";
import Link from "next/link";
const OnboardingContext = createContext<
| {
@@ -106,8 +106,6 @@ export default function OnboardingProvider({
const updateState = useCallback(
(newState: Omit<Partial<UserOnboarding>, "rewardedFor">) => {
setState((prev) => {
api.updateUserOnboarding(newState);
if (!prev) {
// Handle initial state
return {
@@ -127,8 +125,15 @@ export default function OnboardingProvider({
}
return { ...prev, ...newState };
});
// Make the API call asynchronously to not block render
setTimeout(() => {
api.updateUserOnboarding(newState).catch((error) => {
console.error("Failed to update user onboarding:", error);
});
}, 0);
},
[api, setState],
[api],
);
const completeStep = useCallback(
@@ -153,7 +158,7 @@ export default function OnboardingProvider({
completedSteps: [...state.completedSteps, "RUN_AGENTS"],
}),
});
}, [api, state]);
}, [state, updateState]);
return (
<OnboardingContext.Provider

View File

@@ -9,6 +9,7 @@ import {
BlockIOCredentialsSubSchema,
CredentialsProviderName,
} from "@/lib/autogpt-server-api";
import { getHostFromUrl } from "@/lib/utils/url";
export type CredentialsData =
| {
@@ -17,14 +18,18 @@ export type CredentialsData =
supportsApiKey: boolean;
supportsOAuth2: boolean;
supportsUserPassword: boolean;
supportsHostScoped: boolean;
isLoading: true;
discriminatorValue?: string;
}
| (CredentialsProviderData & {
schema: BlockIOCredentialsSubSchema;
supportsApiKey: boolean;
supportsOAuth2: boolean;
supportsUserPassword: boolean;
supportsHostScoped: boolean;
isLoading: false;
discriminatorValue?: string;
});
export default function useCredentials(
@@ -33,12 +38,16 @@ export default function useCredentials(
): CredentialsData | null {
const allProviders = useContext(CredentialsProvidersContext);
const discriminatorValue: CredentialsProviderName | null =
(credsInputSchema.discriminator &&
credsInputSchema.discriminator_mapping![
getValue(credsInputSchema.discriminator, nodeInputValues)
]) ||
null;
const discriminatorValue = [
credsInputSchema.discriminator
? getValue(credsInputSchema.discriminator, nodeInputValues)
: null,
...(credsInputSchema.discriminator_values || []),
].find(Boolean);
const discriminatedProvider = credsInputSchema.discriminator_mapping
? credsInputSchema.discriminator_mapping[discriminatorValue]
: null;
let providerName: CredentialsProviderName;
if (credsInputSchema.credentials_provider.length > 1) {
@@ -47,14 +56,14 @@ export default function useCredentials(
"Multi-provider credential input requires discriminator!",
);
}
if (!discriminatorValue) {
console.log(
if (!discriminatedProvider) {
console.warn(
`Missing discriminator value from '${credsInputSchema.discriminator}': ` +
"hiding credentials input until it is set.",
);
return null;
}
providerName = discriminatorValue;
providerName = discriminatedProvider;
} else {
providerName = credsInputSchema.credentials_provider[0];
}
@@ -69,6 +78,8 @@ export default function useCredentials(
const supportsOAuth2 = credsInputSchema.credentials_types.includes("oauth2");
const supportsUserPassword =
credsInputSchema.credentials_types.includes("user_password");
const supportsHostScoped =
credsInputSchema.credentials_types.includes("host_scoped");
// No provider means maybe it's still loading
if (!provider) {
@@ -82,15 +93,24 @@ export default function useCredentials(
return null;
}
// Filter by OAuth credentials that have sufficient scopes for this block
const requiredScopes = credsInputSchema.credentials_scopes;
const savedCredentials = requiredScopes
? provider.savedCredentials.filter(
(c) =>
c.type != "oauth2" ||
new Set(c.scopes).isSupersetOf(new Set(requiredScopes)),
)
: provider.savedCredentials;
const savedCredentials = provider.savedCredentials.filter((c) => {
// Filter by OAuth credentials that have sufficient scopes for this block
if (c.type === "oauth2") {
const requiredScopes = credsInputSchema.credentials_scopes;
return (
!requiredScopes ||
new Set(c.scopes).isSupersetOf(new Set(requiredScopes))
);
}
// Filter host_scoped credentials by host matching
if (c.type === "host_scoped") {
return discriminatorValue && getHostFromUrl(discriminatorValue) == c.host;
}
// Include all other credential types
return true;
});
return {
...provider,
@@ -99,7 +119,9 @@ export default function useCredentials(
supportsApiKey,
supportsOAuth2,
supportsUserPassword,
supportsHostScoped,
savedCredentials,
discriminatorValue,
isLoading: false,
};
}

View File

@@ -1,6 +1,8 @@
import { getWebSocketToken } from "@/lib/supabase/actions";
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
import { createBrowserClient } from "@supabase/ssr";
import type { SupabaseClient } from "@supabase/supabase-js";
import { proxyApiRequest, proxyFileUpload } from "./proxy-action";
import type {
AddUserCreditsResponse,
AnalyticsDetails,
@@ -60,6 +62,7 @@ import type {
User,
UserOnboarding,
UserPasswordCredentials,
HostScopedCredentials,
UsersBalanceHistoryResponse,
} from "./types";
@@ -345,6 +348,16 @@ export default class BackendAPI {
);
}
createHostScopedCredentials(
credentials: Omit<HostScopedCredentials, "id" | "type">,
): Promise<HostScopedCredentials> {
return this._request(
"POST",
`/integrations/${credentials.provider}/credentials`,
{ ...credentials, type: "host_scoped" },
);
}
listCredentials(provider?: string): Promise<CredentialsMetaResponse[]> {
return this._get(
provider
@@ -761,50 +774,25 @@ export default class BackendAPI {
return this._request("GET", path, query);
}
private async getAuthToken(): Promise<string> {
// Only try client-side session (for WebSocket connections)
// This will return "no-token-found" with httpOnly cookies, which is expected
const supabaseClient = await this.getSupabaseClient();
const {
data: { session },
} = (await supabaseClient?.auth.getSession()) || {
data: { session: null },
};
return session?.access_token || "no-token-found";
}
private async _uploadFile(path: string, file: File): Promise<string> {
// Get session with retry logic
let token = "no-token-found";
let retryCount = 0;
const maxRetries = 3;
while (retryCount < maxRetries) {
const supabaseClient = await this.getSupabaseClient();
const {
data: { session },
} = (await supabaseClient?.auth.getSession()) || {
data: { session: null },
};
if (session?.access_token) {
token = session.access_token;
break;
}
retryCount++;
if (retryCount < maxRetries) {
await new Promise((resolve) => setTimeout(resolve, 100 * retryCount));
}
}
// Create a FormData object and append the file
const formData = new FormData();
formData.append("file", file);
const response = await fetch(this.baseUrl + path, {
method: "POST",
headers: {
...(token && { Authorization: `Bearer ${token}` }),
},
body: formData,
});
if (!response.ok) {
throw new Error(`Error uploading file: ${response.statusText}`);
}
// Parse the response appropriately
const media_url = await response.text();
return media_url;
// Use proxy server action for secure file upload
return await proxyFileUpload(path, formData, this.baseUrl);
}
private async _request(
@@ -816,103 +804,13 @@ export default class BackendAPI {
console.debug(`${method} ${path} payload:`, payload);
}
// Get session with retry logic
let token = "no-token-found";
let retryCount = 0;
const maxRetries = 3;
while (retryCount < maxRetries) {
const supabaseClient = await this.getSupabaseClient();
const {
data: { session },
} = (await supabaseClient?.auth.getSession()) || {
data: { session: null },
};
if (session?.access_token) {
token = session.access_token;
break;
}
retryCount++;
if (retryCount < maxRetries) {
await new Promise((resolve) => setTimeout(resolve, 100 * retryCount));
}
}
let url = this.baseUrl + path;
const payloadAsQuery = ["GET", "DELETE"].includes(method);
if (payloadAsQuery && payload) {
// For GET requests, use payload as query
const queryParams = new URLSearchParams(payload);
url += `?${queryParams.toString()}`;
}
const hasRequestBody = !payloadAsQuery && payload !== undefined;
const response = await fetch(url, {
// Always use proxy server action to not expose any auth tokens to the browser
return await proxyApiRequest({
method,
headers: {
...(hasRequestBody && { "Content-Type": "application/json" }),
...(token && { Authorization: `Bearer ${token}` }),
},
body: hasRequestBody ? JSON.stringify(payload) : undefined,
path,
payload,
baseUrl: this.baseUrl,
});
if (!response.ok) {
console.warn(`${method} ${path} returned non-OK response:`, response);
// console.warn("baseClient is attempting to redirect by changing window location")
// if (
// response.status === 403 &&
// response.statusText === "Not authenticated" &&
// typeof window !== "undefined" // Check if in browser environment
// ) {
// window.location.href = "/login";
// }
let errorDetail;
try {
const errorData = await response.json();
if (
Array.isArray(errorData.detail) &&
errorData.detail.length > 0 &&
errorData.detail[0].loc
) {
// This appears to be a Pydantic validation error
const errors = errorData.detail.map(
(err: _PydanticValidationError) => {
const location = err.loc.join(" -> ");
return `${location}: ${err.msg}`;
},
);
errorDetail = errors.join("\n");
} else {
errorDetail = errorData.detail || response.statusText;
}
} catch {
errorDetail = response.statusText;
}
throw new Error(errorDetail);
}
// Handle responses with no content (like DELETE requests)
if (
response.status === 204 ||
response.headers.get("Content-Length") === "0"
) {
return null;
}
try {
return await response.json();
} catch (e) {
if (e instanceof SyntaxError) {
console.warn(`${method} ${path} returned invalid JSON:`, e);
return null;
}
throw e;
}
}
////////////////////////////////////////
@@ -1000,10 +898,19 @@ export default class BackendAPI {
async connectWebSocket(): Promise<void> {
return (this.wsConnecting ??= new Promise(async (resolve, reject) => {
try {
const supabaseClient = await this.getSupabaseClient();
const token =
(await supabaseClient?.auth.getSession())?.data.session
?.access_token || "";
let token = "";
try {
const { token: serverToken, error } = await getWebSocketToken();
if (serverToken && !error) {
token = serverToken;
} else if (error) {
console.warn("Failed to get WebSocket token from server:", error);
}
} catch (error) {
console.warn("Failed to get token for WebSocket connection:", error);
// Continue with empty token, connection might still work
}
const wsUrlWithToken = `${this.wsUrl}?token=${token}`;
this.webSocket = new WebSocket(wsUrlWithToken);
this.webSocket.state = "connecting";

View File

@@ -0,0 +1,236 @@
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
export function buildRequestUrl(
baseUrl: string,
path: string,
method: string,
payload?: Record<string, any>,
): string {
let url = baseUrl + path;
const payloadAsQuery = ["GET", "DELETE"].includes(method);
if (payloadAsQuery && payload) {
const queryParams = new URLSearchParams(payload);
url += `?${queryParams.toString()}`;
}
return url;
}
export async function getServerAuthToken(): Promise<string> {
const supabase = await getServerSupabase();
if (!supabase) {
throw new Error("Supabase client not available");
}
try {
const {
data: { session },
error,
} = await supabase.auth.getSession();
if (error || !session?.access_token) {
return "no-token-found";
}
return session.access_token;
} catch (error) {
console.error("Failed to get auth token:", error);
return "no-token-found";
}
}
export function createRequestHeaders(
token: string,
hasRequestBody: boolean,
contentType: string = "application/json",
): Record<string, string> {
const headers: Record<string, string> = {};
if (hasRequestBody) {
headers["Content-Type"] = contentType;
}
if (token && token !== "no-token-found") {
headers["Authorization"] = `Bearer ${token}`;
}
return headers;
}
export function serializeRequestBody(
payload: any,
contentType: string = "application/json",
): string {
switch (contentType) {
case "application/json":
return JSON.stringify(payload);
case "application/x-www-form-urlencoded":
return new URLSearchParams(payload).toString();
default:
// For custom content types, assume payload is already properly formatted
return typeof payload === "string" ? payload : JSON.stringify(payload);
}
}
export async function parseApiError(response: Response): Promise<string> {
try {
const errorData = await response.json();
if (
Array.isArray(errorData.detail) &&
errorData.detail.length > 0 &&
errorData.detail[0].loc
) {
// Pydantic validation error
const errors = errorData.detail.map((err: any) => {
const location = err.loc.join(" -> ");
return `${location}: ${err.msg}`;
});
return errors.join("\n");
}
return errorData.detail || response.statusText;
} catch {
return response.statusText;
}
}
export async function parseApiResponse(response: Response): Promise<any> {
// Handle responses with no content
if (
response.status === 204 ||
response.headers.get("Content-Length") === "0"
) {
return null;
}
try {
return await response.json();
} catch (e) {
if (e instanceof SyntaxError) {
return null;
}
throw e;
}
}
function isAuthenticationError(
response: Response,
errorDetail: string,
): boolean {
return (
response.status === 401 ||
response.status === 403 ||
errorDetail.toLowerCase().includes("not authenticated") ||
errorDetail.toLowerCase().includes("unauthorized") ||
errorDetail.toLowerCase().includes("authentication failed")
);
}
function isLogoutInProgress(): boolean {
if (typeof window === "undefined") return false;
try {
// Check if logout was recently triggered
const logoutTimestamp = window.localStorage.getItem("supabase-logout");
if (logoutTimestamp) {
const timeDiff = Date.now() - parseInt(logoutTimestamp);
// Consider logout in progress for 5 seconds after trigger
return timeDiff < 5000;
}
// Check if we're being redirected to login
return (
window.location.pathname.includes("/login") ||
window.location.pathname.includes("/logout")
);
} catch {
return false;
}
}
export async function makeAuthenticatedRequest(
method: string,
url: string,
payload?: Record<string, any>,
contentType: string = "application/json",
): Promise<any> {
const token = await getServerAuthToken();
const payloadAsQuery = ["GET", "DELETE"].includes(method);
const hasRequestBody = !payloadAsQuery && payload !== undefined;
const response = await fetch(url, {
method,
headers: createRequestHeaders(token, hasRequestBody, contentType),
body: hasRequestBody
? serializeRequestBody(payload, contentType)
: undefined,
});
if (!response.ok) {
const errorDetail = await parseApiError(response);
// Handle authentication errors gracefully during logout
if (isAuthenticationError(response, errorDetail)) {
if (isLogoutInProgress()) {
// Silently return null during logout to prevent error noise
console.debug(
"Authentication request failed during logout, ignoring:",
errorDetail,
);
return null;
}
// For authentication errors outside logout, log but don't throw
// This prevents crashes when session expires naturally
console.warn("Authentication failed:", errorDetail);
return null;
}
// For other errors, throw as normal
throw new Error(errorDetail);
}
return parseApiResponse(response);
}
export async function makeAuthenticatedFileUpload(
url: string,
formData: FormData,
): Promise<string> {
const token = await getServerAuthToken();
const headers: Record<string, string> = {};
if (token && token !== "no-token-found") {
headers["Authorization"] = `Bearer ${token}`;
}
// Don't set Content-Type for FormData - let the browser set it with boundary
const response = await fetch(url, {
method: "POST",
headers,
body: formData,
});
if (!response.ok) {
// Handle authentication errors gracefully for file uploads too
const errorMessage = `Error uploading file: ${response.statusText}`;
if (response.status === 401 || response.status === 403) {
if (isLogoutInProgress()) {
console.debug(
"File upload authentication failed during logout, ignoring",
);
return "";
}
console.warn("File upload authentication failed:", errorMessage);
return "";
}
throw new Error(errorMessage);
}
return await response.text();
}

View File

@@ -0,0 +1,51 @@
"use server";
import * as Sentry from "@sentry/nextjs";
import {
buildRequestUrl,
makeAuthenticatedFileUpload,
makeAuthenticatedRequest,
} from "./helpers";
const DEFAULT_BASE_URL = "http://localhost:8006/api";
export interface ProxyRequestOptions {
method: "GET" | "POST" | "PUT" | "PATCH" | "DELETE";
path: string;
payload?: Record<string, any>;
baseUrl?: string;
contentType?: string;
}
export async function proxyApiRequest({
method,
path,
payload,
baseUrl = process.env.NEXT_PUBLIC_AGPT_SERVER_URL || DEFAULT_BASE_URL,
contentType = "application/json",
}: ProxyRequestOptions) {
return await Sentry.withServerActionInstrumentation(
"proxyApiRequest",
{},
async () => {
const url = buildRequestUrl(baseUrl, path, method, payload);
return makeAuthenticatedRequest(method, url, payload, contentType);
},
);
}
export async function proxyFileUpload(
path: string,
formData: FormData,
baseUrl = process.env.NEXT_PUBLIC_AGPT_SERVER_URL ||
"http://localhost:8006/api",
): Promise<string> {
return await Sentry.withServerActionInstrumentation(
"proxyFileUpload",
{},
async () => {
const url = baseUrl + path;
return makeAuthenticatedFileUpload(url, formData);
},
);
}

View File

@@ -140,12 +140,17 @@ export type BlockIOBooleanSubSchema = BlockIOSubSchemaMeta & {
secret?: boolean;
};
export type CredentialsType = "api_key" | "oauth2" | "user_password";
export type CredentialsType =
| "api_key"
| "oauth2"
| "user_password"
| "host_scoped";
export type Credentials =
| APIKeyCredentials
| OAuth2Credentials
| UserPasswordCredentials;
| UserPasswordCredentials
| HostScopedCredentials;
// --8<-- [start:BlockIOCredentialsSubSchema]
export const PROVIDER_NAMES = {
@@ -161,6 +166,7 @@ export const PROVIDER_NAMES = {
GOOGLE: "google",
GOOGLE_MAPS: "google_maps",
GROQ: "groq",
HTTP: "http",
HUBSPOT: "hubspot",
IDEOGRAM: "ideogram",
JINA: "jina",
@@ -199,6 +205,7 @@ export type BlockIOCredentialsSubSchema = BlockIOObjectSubSchema & {
credentials_types: Array<CredentialsType>;
discriminator?: string;
discriminator_mapping?: { [key: string]: CredentialsProviderName };
discriminator_values?: any[];
secret?: boolean;
};
@@ -501,6 +508,7 @@ export type CredentialsMetaResponse = {
title?: string;
scopes?: Array<string>;
username?: string;
host?: string;
};
/* Mirror of backend/server/integrations/router.py:CredentialsDeletionResponse */
@@ -559,6 +567,14 @@ export type UserPasswordCredentials = BaseCredentials & {
password: string;
};
/* Mirror of backend/backend/data/model.py:HostScopedCredentials */
export type HostScopedCredentials = BaseCredentials & {
type: "host_scoped";
title: string;
host: string;
headers: Record<string, string>;
};
// Mirror of backend/backend/data/notifications.py:NotificationType
export type NotificationType =
| "AGENT_RUN"

View File

@@ -0,0 +1,180 @@
# Server-Side Session Validation with httpOnly Cookies
This implementation ensures that Supabase session validation is always performed on the server side using httpOnly cookies for improved security.
## Key Features
- **httpOnly Cookies**: Session cookies are inaccessible to client-side JavaScript, preventing XSS attacks
- **Server-Side Authentication**: All API requests are authenticated on the server using httpOnly cookies
- **Automatic Request Proxying**: All BackendAPI requests are automatically proxied through server actions
- **File Upload Support**: File uploads work seamlessly with httpOnly cookie authentication
- **Zero Code Changes**: Existing BackendAPI usage continues to work without modifications
- **Cross-Tab Logout**: Logout events are still synchronized across browser tabs
## How It Works
All API requests made through `BackendAPI` are automatically proxied through server actions that:
1. Retrieve the JWT token from server-side httpOnly cookies
2. Make the authenticated request to the backend API
3. Return the response to the client
This includes both regular API calls and file uploads, all handled transparently!
## Usage
### Client Components
No changes needed! The existing `useSupabase` hook and `useBackendAPI` continue to work:
```tsx
"use client";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
function MyComponent() {
const { user, isLoggedIn, isUserLoading, logOut } = useSupabase();
const api = useBackendAPI();
if (isUserLoading) return <div>Loading...</div>;
if (!isLoggedIn) return <div>Please log in</div>;
// Regular API calls use secure server-side authentication
const handleGetGraphs = async () => {
const graphs = await api.listGraphs();
console.log(graphs);
};
// File uploads also work with secure authentication
const handleFileUpload = async (file: File) => {
try {
const mediaUrl = await api.uploadStoreSubmissionMedia(file);
console.log("Uploaded:", mediaUrl);
} catch (error) {
console.error("Upload failed:", error);
}
};
return (
<div>
<p>Welcome, {user?.email}!</p>
<button onClick={handleGetGraphs}>Get Graphs</button>
<input
type="file"
onChange={(e) =>
e.target.files?.[0] && handleFileUpload(e.target.files[0])
}
/>
<button onClick={logOut}>Log Out</button>
</div>
);
}
```
### Server Components
No changes needed! Server components continue to work as before:
```tsx
import { validateSession, getCurrentUser } from "@/lib/supabase/actions";
import { redirect } from "next/navigation";
async function MyServerComponent() {
const { user, error } = await getCurrentUser();
if (error || !user) {
redirect("/login");
}
return <div>Welcome, {user.email}!</div>;
}
```
### Server Actions
No changes needed! Server actions continue to work as before:
```tsx
"use server";
import { validateSession } from "@/lib/supabase/actions";
import BackendAPI from "@/lib/autogpt-server-api";
import { redirect } from "next/navigation";
export async function myServerAction() {
const { user, isValid } = await validateSession("/current-path");
if (!isValid || !user) {
redirect("/login");
return;
}
// This automatically uses secure server-side authentication
const api = new BackendAPI();
const graphs = await api.listGraphs();
return graphs;
}
```
### API Calls and File Uploads
All operations use the same simple code everywhere:
```tsx
// Works the same in both client and server contexts
const api = new BackendAPI();
// Regular API requests
const graphs = await api.listGraphs();
const user = await api.createUser();
const onboarding = await api.getUserOnboarding();
// File uploads
const file = new File(["content"], "example.txt", { type: "text/plain" });
const mediaUrl = await api.uploadStoreSubmissionMedia(file);
```
## Available Server Actions
- `validateSession(currentPath)` - Validates the current session and returns user data
- `getCurrentUser()` - Gets the current user without path validation
- `serverLogout()` - Logs out the user server-side
- `refreshSession()` - Refreshes the current session
## Internal Architecture
### Request Flow
All API requests (including file uploads) follow this flow:
1. **Any API call**: `api.listGraphs()` or `api.uploadStoreSubmissionMedia(file)`
2. **Proxy server action**: `proxyApiRequest()` or `proxyFileUpload()` handles the request
3. **Server authentication**: Gets JWT from httpOnly cookies
4. **Backend request**: Makes authenticated request to backend API
5. **Response**: Returns data to the calling code
### File Upload Implementation
File uploads are handled through a dedicated `proxyFileUpload` server action that:
- Receives the file data as FormData on the server
- Retrieves authentication tokens from httpOnly cookies
- Forwards the authenticated upload request to the backend
- Returns the upload result to the client
## Security Benefits
1. **XSS Protection**: httpOnly cookies can't be accessed by malicious scripts
2. **CSRF Protection**: Combined with SameSite cookie settings
3. **Server-Side Validation**: Session validation always happens on the trusted server
4. **Zero Token Exposure**: JWT tokens never exposed to client-side JavaScript
5. **Zero Attack Surface**: No client-side session manipulation possible
6. **Secure File Uploads**: File uploads maintain the same security model
## Migration Notes
- **No code changes required** - all existing BackendAPI usage continues to work
- File uploads now work seamlessly with httpOnly cookies
- Cross-tab logout functionality is preserved
- WebSocket connections may need reconnection after session changes
- All requests now have consistent security behavior

View File

@@ -0,0 +1,223 @@
"use server";
import * as Sentry from "@sentry/nextjs";
import type { User } from "@supabase/supabase-js";
import { revalidatePath } from "next/cache";
import { redirect } from "next/navigation";
import { getRedirectPath } from "./helpers";
import { getServerSupabase } from "./server/getServerSupabase";
export interface SessionValidationResult {
user: User | null;
isValid: boolean;
redirectPath?: string;
}
export async function validateSession(
currentPath: string,
): Promise<SessionValidationResult> {
return await Sentry.withServerActionInstrumentation(
"validateSession",
{},
async () => {
const supabase = await getServerSupabase();
if (!supabase) {
return {
user: null,
isValid: false,
redirectPath: getRedirectPath(currentPath) || undefined,
};
}
try {
const {
data: { user },
error,
} = await supabase.auth.getUser();
if (error || !user) {
const redirectPath = getRedirectPath(currentPath);
return {
user: null,
isValid: false,
redirectPath: redirectPath || undefined,
};
}
return {
user,
isValid: true,
};
} catch (error) {
console.error("Session validation error:", error);
const redirectPath = getRedirectPath(currentPath);
return {
user: null,
isValid: false,
redirectPath: redirectPath || undefined,
};
}
},
);
}
export async function getCurrentUser(): Promise<{
user: User | null;
error?: string;
}> {
return await Sentry.withServerActionInstrumentation(
"getCurrentUser",
{},
async () => {
const supabase = await getServerSupabase();
if (!supabase) {
return {
user: null,
error: "Supabase client not available",
};
}
try {
const {
data: { user },
error,
} = await supabase.auth.getUser();
if (error) {
return {
user: null,
error: error.message,
};
}
return { user };
} catch (error) {
console.error("Get current user error:", error);
return {
user: null,
error: error instanceof Error ? error.message : "Unknown error",
};
}
},
);
}
export async function getWebSocketToken(): Promise<{
token: string | null;
error?: string;
}> {
return await Sentry.withServerActionInstrumentation(
"getWebSocketToken",
{},
async () => {
const supabase = await getServerSupabase();
if (!supabase) {
return {
token: null,
error: "Supabase client not available",
};
}
try {
const {
data: { session },
error,
} = await supabase.auth.getSession();
if (error) {
return {
token: null,
error: error.message,
};
}
return { token: session?.access_token || null };
} catch (error) {
console.error("Get WebSocket token error:", error);
return {
token: null,
error: error instanceof Error ? error.message : "Unknown error",
};
}
},
);
}
export type ServerLogoutOptions = {
globalLogout?: boolean;
};
export async function serverLogout(options: ServerLogoutOptions = {}) {
return await Sentry.withServerActionInstrumentation(
"serverLogout",
{},
async () => {
const supabase = await getServerSupabase();
if (!supabase) {
redirect("/login");
return;
}
try {
const { error } = await supabase.auth.signOut({
scope: options.globalLogout ? "global" : "local",
});
if (error) {
console.error("Error logging out:", error);
}
} catch (error) {
console.error("Logout error:", error);
}
// Clear all cached data and redirect
revalidatePath("/", "layout");
redirect("/login");
},
);
}
export async function refreshSession() {
return await Sentry.withServerActionInstrumentation(
"refreshSession",
{},
async () => {
const supabase = await getServerSupabase();
if (!supabase) {
return {
user: null,
error: "Supabase client not available",
};
}
try {
const {
data: { user },
error,
} = await supabase.auth.refreshSession();
if (error) {
return {
user: null,
error: error.message,
};
}
// Revalidate the layout to update server components
revalidatePath("/", "layout");
return { user };
} catch (error) {
console.error("Refresh session error:", error);
return {
user: null,
error: error instanceof Error ? error.message : "Unknown error",
};
}
},
);
}

Some files were not shown because too many files have changed in this diff Show More