mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Merge 'dev' into 'chore/storybook-test-setup'
This commit is contained in:
5
.github/workflows/platform-backend-ci.yml
vendored
5
.github/workflows/platform-backend-ci.yml
vendored
@@ -190,9 +190,9 @@ jobs:
|
||||
- name: Run pytest with coverage
|
||||
run: |
|
||||
if [[ "${{ runner.debug }}" == "1" ]]; then
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG test
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
|
||||
else
|
||||
poetry run pytest -s -vv test
|
||||
poetry run pytest -s -vv
|
||||
fi
|
||||
if: success() || (failure() && steps.lint.outcome == 'failure')
|
||||
env:
|
||||
@@ -205,6 +205,7 @@ jobs:
|
||||
REDIS_HOST: "localhost"
|
||||
REDIS_PORT: "6379"
|
||||
REDIS_PASSWORD: "testpassword"
|
||||
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
|
||||
|
||||
env:
|
||||
CI: true
|
||||
|
||||
@@ -20,7 +20,7 @@ def load_all_blocks() -> dict[str, type["Block"]]:
|
||||
modules = [
|
||||
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
|
||||
for f in current_dir.rglob("*.py")
|
||||
if f.is_file() and f.name != "__init__.py"
|
||||
if f.is_file() and f.name != "__init__.py" and not f.name.startswith("test_")
|
||||
]
|
||||
for module in modules:
|
||||
if not re.match("^[a-z0-9_.]+$", module):
|
||||
|
||||
@@ -53,6 +53,7 @@ class AudioTrack(str, Enum):
|
||||
REFRESHER = ("Refresher",)
|
||||
TOURIST = ("Tourist",)
|
||||
TWIN_TYCHES = ("Twin Tyches",)
|
||||
DONT_STOP_ME_ABSTRACT_FUTURE_BASS = ("Dont Stop Me Abstract Future Bass",)
|
||||
|
||||
@property
|
||||
def audio_url(self):
|
||||
@@ -78,6 +79,7 @@ class AudioTrack(str, Enum):
|
||||
AudioTrack.REFRESHER: "https://cdn.tfrv.xyz/audio/refresher.mp3",
|
||||
AudioTrack.TOURIST: "https://cdn.tfrv.xyz/audio/tourist.mp3",
|
||||
AudioTrack.TWIN_TYCHES: "https://cdn.tfrv.xyz/audio/twin-tynches.mp3",
|
||||
AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS: "https://cdn.revid.ai/audio/_dont-stop-me-abstract-future-bass.mp3",
|
||||
}
|
||||
return audio_urls[self]
|
||||
|
||||
@@ -105,6 +107,7 @@ class GenerationPreset(str, Enum):
|
||||
MOVIE = ("Movie",)
|
||||
STYLIZED_ILLUSTRATION = ("Stylized Illustration",)
|
||||
MANGA = ("Manga",)
|
||||
DEFAULT = ("DEFAULT",)
|
||||
|
||||
|
||||
class Voice(str, Enum):
|
||||
@@ -114,6 +117,7 @@ class Voice(str, Enum):
|
||||
JESSICA = "Jessica"
|
||||
CHARLOTTE = "Charlotte"
|
||||
CALLUM = "Callum"
|
||||
EVA = "Eva"
|
||||
|
||||
@property
|
||||
def voice_id(self):
|
||||
@@ -124,6 +128,7 @@ class Voice(str, Enum):
|
||||
Voice.JESSICA: "cgSgspJ2msm6clMCkdW9",
|
||||
Voice.CHARLOTTE: "XB0fDUnXU5powFXDhCwa",
|
||||
Voice.CALLUM: "N2lVS1w4EtoT3dr4eOWO",
|
||||
Voice.EVA: "FGY2WhTYpPnrIDTdsKH5",
|
||||
}
|
||||
return voice_id_map[self]
|
||||
|
||||
@@ -141,6 +146,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AIShortformVideoCreatorBlock(Block):
|
||||
"""Creates a short‑form text‑to‑video clip using stock or AI imagery."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REVID], Literal["api_key"]
|
||||
@@ -184,40 +191,8 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
video_url: str = SchemaField(description="The URL of the created video")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="361697fb-0c4f-4feb-aed3-8320c88c771b",
|
||||
description="Creates a shortform video using revid.ai",
|
||||
categories={BlockCategory.SOCIAL, BlockCategory.AI},
|
||||
input_schema=AIShortformVideoCreatorBlock.Input,
|
||||
output_schema=AIShortformVideoCreatorBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"script": "[close-up of a cat] Meow!",
|
||||
"ratio": "9 / 16",
|
||||
"resolution": "720p",
|
||||
"frame_rate": 60,
|
||||
"generation_preset": GenerationPreset.LEONARDO,
|
||||
"background_music": AudioTrack.HIGHWAY_NOCTURNE,
|
||||
"voice": Voice.LILY,
|
||||
"video_style": VisualMediaType.STOCK_VIDEOS,
|
||||
},
|
||||
test_output=(
|
||||
"video_url",
|
||||
"https://example.com/video.mp4",
|
||||
),
|
||||
test_mock={
|
||||
"create_webhook": lambda: (
|
||||
"test_uuid",
|
||||
"https://webhook.site/test_uuid",
|
||||
),
|
||||
"create_video": lambda api_key, payload: {"pid": "test_pid"},
|
||||
"wait_for_video": lambda api_key, pid, webhook_token, max_wait_time=1000: "https://example.com/video.mp4",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def create_webhook(self):
|
||||
async def create_webhook(self) -> tuple[str, str]:
|
||||
"""Create a new webhook URL for receiving notifications."""
|
||||
url = "https://webhook.site/token"
|
||||
headers = {"Accept": "application/json", "Content-Type": "application/json"}
|
||||
response = await Requests().post(url, headers=headers)
|
||||
@@ -225,6 +200,7 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
|
||||
|
||||
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
|
||||
"""Create a video using the Revid API."""
|
||||
url = "https://www.revid.ai/api/public/v2/render"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().post(url, json=payload, headers=headers)
|
||||
@@ -234,6 +210,7 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
return response.json()
|
||||
|
||||
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
|
||||
"""Check the status of a video creation job."""
|
||||
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().get(url, headers=headers)
|
||||
@@ -243,9 +220,9 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
pid: str,
|
||||
webhook_token: str,
|
||||
max_wait_time: int = 1000,
|
||||
) -> str:
|
||||
"""Wait for video creation to complete and return the video URL."""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
status = await self.check_video_status(api_key, pid)
|
||||
@@ -266,6 +243,40 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
logger.error("Video creation timed out")
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="361697fb-0c4f-4feb-aed3-8320c88c771b",
|
||||
description="Creates a shortform video using revid.ai",
|
||||
categories={BlockCategory.SOCIAL, BlockCategory.AI},
|
||||
input_schema=AIShortformVideoCreatorBlock.Input,
|
||||
output_schema=AIShortformVideoCreatorBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"script": "[close-up of a cat] Meow!",
|
||||
"ratio": "9 / 16",
|
||||
"resolution": "720p",
|
||||
"frame_rate": 60,
|
||||
"generation_preset": GenerationPreset.LEONARDO,
|
||||
"background_music": AudioTrack.HIGHWAY_NOCTURNE,
|
||||
"voice": Voice.LILY,
|
||||
"video_style": VisualMediaType.STOCK_VIDEOS,
|
||||
},
|
||||
test_output=("video_url", "https://example.com/video.mp4"),
|
||||
test_mock={
|
||||
"create_webhook": lambda *args, **kwargs: (
|
||||
"test_uuid",
|
||||
"https://webhook.site/test_uuid",
|
||||
),
|
||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||
"check_video_status": lambda *args, **kwargs: {
|
||||
"status": "ready",
|
||||
"videoUrl": "https://example.com/video.mp4",
|
||||
},
|
||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/video.mp4",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
@@ -273,20 +284,18 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
webhook_token, webhook_url = await self.create_webhook()
|
||||
logger.debug(f"Webhook URL: {webhook_url}")
|
||||
|
||||
audio_url = input_data.background_music.audio_url
|
||||
|
||||
payload = {
|
||||
"frameRate": input_data.frame_rate,
|
||||
"resolution": input_data.resolution,
|
||||
"frameDurationMultiplier": 18,
|
||||
"webhook": webhook_url,
|
||||
"webhook": None,
|
||||
"creationParams": {
|
||||
"mediaType": input_data.video_style,
|
||||
"captionPresetName": "Wrap 1",
|
||||
"selectedVoice": input_data.voice.voice_id,
|
||||
"hasEnhancedGeneration": True,
|
||||
"generationPreset": input_data.generation_preset.name,
|
||||
"selectedAudio": input_data.background_music,
|
||||
"selectedAudio": input_data.background_music.value,
|
||||
"origin": "/create",
|
||||
"inputText": input_data.script,
|
||||
"flowType": "text-to-video",
|
||||
@@ -302,7 +311,7 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
"selectedStoryStyle": {"value": "custom", "label": "Custom"},
|
||||
"hasToGenerateVideos": input_data.video_style
|
||||
!= VisualMediaType.STOCK_VIDEOS,
|
||||
"audioUrl": audio_url,
|
||||
"audioUrl": input_data.background_music.audio_url,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -319,8 +328,370 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
logger.debug(
|
||||
f"Video created with project ID: {pid}. Waiting for completion..."
|
||||
)
|
||||
video_url = await self.wait_for_video(
|
||||
credentials.api_key, pid, webhook_token
|
||||
)
|
||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||
logger.debug(f"Video ready: {video_url}")
|
||||
yield "video_url", video_url
|
||||
|
||||
|
||||
class AIAdMakerVideoCreatorBlock(Block):
|
||||
"""Generates a 30‑second vertical AI advert using optional user‑supplied imagery."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REVID], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Credentials for Revid.ai API access.",
|
||||
)
|
||||
script: str = SchemaField(
|
||||
description="Short advertising copy. Line breaks create new scenes.",
|
||||
placeholder="Introducing Foobar – [show product photo] the gadget that does it all.",
|
||||
)
|
||||
ratio: str = SchemaField(description="Aspect ratio", default="9 / 16")
|
||||
target_duration: int = SchemaField(
|
||||
description="Desired length of the ad in seconds.", default=30
|
||||
)
|
||||
voice: Voice = SchemaField(
|
||||
description="Narration voice", default=Voice.EVA, placeholder=Voice.EVA
|
||||
)
|
||||
background_music: AudioTrack = SchemaField(
|
||||
description="Background track",
|
||||
default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS,
|
||||
)
|
||||
input_media_urls: list[str] = SchemaField(
|
||||
description="List of image URLs to feature in the advert.", default=[]
|
||||
)
|
||||
use_only_provided_media: bool = SchemaField(
|
||||
description="Restrict visuals to supplied images only.", default=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
video_url: str = SchemaField(description="URL of the finished advert")
|
||||
error: str = SchemaField(description="Error message on failure")
|
||||
|
||||
async def create_webhook(self) -> tuple[str, str]:
|
||||
"""Create a new webhook URL for receiving notifications."""
|
||||
url = "https://webhook.site/token"
|
||||
headers = {"Accept": "application/json", "Content-Type": "application/json"}
|
||||
response = await Requests().post(url, headers=headers)
|
||||
webhook_data = response.json()
|
||||
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
|
||||
|
||||
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
|
||||
"""Create a video using the Revid API."""
|
||||
url = "https://www.revid.ai/api/public/v2/render"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().post(url, json=payload, headers=headers)
|
||||
logger.debug(
|
||||
f"API Response Status Code: {response.status}, Content: {response.text}"
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
|
||||
"""Check the status of a video creation job."""
|
||||
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().get(url, headers=headers)
|
||||
return response.json()
|
||||
|
||||
async def wait_for_video(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
pid: str,
|
||||
max_wait_time: int = 1000,
|
||||
) -> str:
|
||||
"""Wait for video creation to complete and return the video URL."""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
status = await self.check_video_status(api_key, pid)
|
||||
logger.debug(f"Video status: {status}")
|
||||
|
||||
if status.get("status") == "ready" and "videoUrl" in status:
|
||||
return status["videoUrl"]
|
||||
elif status.get("status") == "error":
|
||||
error_message = status.get("error", "Unknown error occurred")
|
||||
logger.error(f"Video creation failed: {error_message}")
|
||||
raise ValueError(f"Video creation failed: {error_message}")
|
||||
elif status.get("status") in ["FAILED", "CANCELED"]:
|
||||
logger.error(f"Video creation failed: {status.get('message')}")
|
||||
raise ValueError(f"Video creation failed: {status.get('message')}")
|
||||
|
||||
await asyncio.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="58bd2a19-115d-4fd1-8ca4-13b9e37fa6a0",
|
||||
description="Creates an AI‑generated 30‑second advert (text + images)",
|
||||
categories={BlockCategory.MARKETING, BlockCategory.AI},
|
||||
input_schema=AIAdMakerVideoCreatorBlock.Input,
|
||||
output_schema=AIAdMakerVideoCreatorBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"script": "Test product launch!",
|
||||
"input_media_urls": [
|
||||
"https://cdn.revid.ai/uploads/1747076315114-image.png",
|
||||
],
|
||||
},
|
||||
test_output=("video_url", "https://example.com/ad.mp4"),
|
||||
test_mock={
|
||||
"create_webhook": lambda *args, **kwargs: (
|
||||
"test_uuid",
|
||||
"https://webhook.site/test_uuid",
|
||||
),
|
||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||
"check_video_status": lambda *args, **kwargs: {
|
||||
"status": "ready",
|
||||
"videoUrl": "https://example.com/ad.mp4",
|
||||
},
|
||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/ad.mp4",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
||||
webhook_token, webhook_url = await self.create_webhook()
|
||||
|
||||
payload = {
|
||||
"webhook": webhook_url,
|
||||
"creationParams": {
|
||||
"targetDuration": input_data.target_duration,
|
||||
"ratio": input_data.ratio,
|
||||
"mediaType": "aiVideo",
|
||||
"inputText": input_data.script,
|
||||
"flowType": "text-to-video",
|
||||
"slug": "ai-ad-generator",
|
||||
"slugNew": "",
|
||||
"isCopiedFrom": False,
|
||||
"hasToGenerateVoice": True,
|
||||
"hasToTranscript": False,
|
||||
"hasToSearchMedia": True,
|
||||
"hasAvatar": False,
|
||||
"hasWebsiteRecorder": False,
|
||||
"hasTextSmallAtBottom": False,
|
||||
"selectedAudio": input_data.background_music.value,
|
||||
"selectedVoice": input_data.voice.voice_id,
|
||||
"selectedAvatar": "https://cdn.revid.ai/avatars/young-woman.mp4",
|
||||
"selectedAvatarType": "video/mp4",
|
||||
"websiteToRecord": "",
|
||||
"hasToGenerateCover": True,
|
||||
"nbGenerations": 1,
|
||||
"disableCaptions": False,
|
||||
"mediaMultiplier": "medium",
|
||||
"characters": [],
|
||||
"captionPresetName": "Revid",
|
||||
"sourceType": "contentScraping",
|
||||
"selectedStoryStyle": {"value": "custom", "label": "General"},
|
||||
"generationPreset": "DEFAULT",
|
||||
"hasToGenerateMusic": False,
|
||||
"isOptimizedForChinese": False,
|
||||
"generationUserPrompt": "",
|
||||
"enableNsfwFilter": False,
|
||||
"addStickers": False,
|
||||
"typeMovingImageAnim": "dynamic",
|
||||
"hasToGenerateSoundEffects": False,
|
||||
"forceModelType": "gpt-image-1",
|
||||
"selectedCharacters": [],
|
||||
"lang": "",
|
||||
"voiceSpeed": 1,
|
||||
"disableAudio": False,
|
||||
"disableVoice": False,
|
||||
"useOnlyProvidedMedia": input_data.use_only_provided_media,
|
||||
"imageGenerationModel": "ultra",
|
||||
"videoGenerationModel": "pro",
|
||||
"hasEnhancedGeneration": True,
|
||||
"hasEnhancedGenerationPro": True,
|
||||
"inputMedias": [
|
||||
{"url": url, "title": "", "type": "image"}
|
||||
for url in input_data.input_media_urls
|
||||
],
|
||||
"hasToGenerateVideos": True,
|
||||
"audioUrl": input_data.background_music.audio_url,
|
||||
"watermark": None,
|
||||
},
|
||||
}
|
||||
|
||||
response = await self.create_video(credentials.api_key, payload)
|
||||
pid = response.get("pid")
|
||||
if not pid:
|
||||
raise RuntimeError("Failed to create video: No project ID returned")
|
||||
|
||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||
yield "video_url", video_url
|
||||
|
||||
|
||||
class AIScreenshotToVideoAdBlock(Block):
|
||||
"""Creates an advert where the supplied screenshot is narrated by an AI avatar."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REVID], Literal["api_key"]
|
||||
] = CredentialsField(description="Revid.ai API key")
|
||||
script: str = SchemaField(
|
||||
description="Narration that will accompany the screenshot.",
|
||||
placeholder="Check out these amazing stats!",
|
||||
)
|
||||
screenshot_url: str = SchemaField(
|
||||
description="Screenshot or image URL to showcase."
|
||||
)
|
||||
ratio: str = SchemaField(default="9 / 16")
|
||||
target_duration: int = SchemaField(default=30)
|
||||
voice: Voice = SchemaField(default=Voice.EVA)
|
||||
background_music: AudioTrack = SchemaField(
|
||||
default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
video_url: str = SchemaField(description="Rendered video URL")
|
||||
error: str = SchemaField(description="Error, if encountered")
|
||||
|
||||
async def create_webhook(self) -> tuple[str, str]:
|
||||
"""Create a new webhook URL for receiving notifications."""
|
||||
url = "https://webhook.site/token"
|
||||
headers = {"Accept": "application/json", "Content-Type": "application/json"}
|
||||
response = await Requests().post(url, headers=headers)
|
||||
webhook_data = response.json()
|
||||
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
|
||||
|
||||
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
|
||||
"""Create a video using the Revid API."""
|
||||
url = "https://www.revid.ai/api/public/v2/render"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().post(url, json=payload, headers=headers)
|
||||
logger.debug(
|
||||
f"API Response Status Code: {response.status}, Content: {response.text}"
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
|
||||
"""Check the status of a video creation job."""
|
||||
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().get(url, headers=headers)
|
||||
return response.json()
|
||||
|
||||
async def wait_for_video(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
pid: str,
|
||||
max_wait_time: int = 1000,
|
||||
) -> str:
|
||||
"""Wait for video creation to complete and return the video URL."""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
status = await self.check_video_status(api_key, pid)
|
||||
logger.debug(f"Video status: {status}")
|
||||
|
||||
if status.get("status") == "ready" and "videoUrl" in status:
|
||||
return status["videoUrl"]
|
||||
elif status.get("status") == "error":
|
||||
error_message = status.get("error", "Unknown error occurred")
|
||||
logger.error(f"Video creation failed: {error_message}")
|
||||
raise ValueError(f"Video creation failed: {error_message}")
|
||||
elif status.get("status") in ["FAILED", "CANCELED"]:
|
||||
logger.error(f"Video creation failed: {status.get('message')}")
|
||||
raise ValueError(f"Video creation failed: {status.get('message')}")
|
||||
|
||||
await asyncio.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0f3e4635-e810-43d9-9e81-49e6f4e83b7c",
|
||||
description="Turns a screenshot into an engaging, avatar‑narrated video advert.",
|
||||
categories={BlockCategory.AI, BlockCategory.MARKETING},
|
||||
input_schema=AIScreenshotToVideoAdBlock.Input,
|
||||
output_schema=AIScreenshotToVideoAdBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"script": "Amazing numbers!",
|
||||
"screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
|
||||
},
|
||||
test_output=("video_url", "https://example.com/screenshot.mp4"),
|
||||
test_mock={
|
||||
"create_webhook": lambda *args, **kwargs: (
|
||||
"test_uuid",
|
||||
"https://webhook.site/test_uuid",
|
||||
),
|
||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||
"check_video_status": lambda *args, **kwargs: {
|
||||
"status": "ready",
|
||||
"videoUrl": "https://example.com/screenshot.mp4",
|
||||
},
|
||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/screenshot.mp4",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
||||
webhook_token, webhook_url = await self.create_webhook()
|
||||
|
||||
payload = {
|
||||
"webhook": webhook_url,
|
||||
"creationParams": {
|
||||
"targetDuration": input_data.target_duration,
|
||||
"ratio": input_data.ratio,
|
||||
"mediaType": "aiVideo",
|
||||
"hasAvatar": True,
|
||||
"removeAvatarBackground": True,
|
||||
"inputText": input_data.script,
|
||||
"flowType": "text-to-video",
|
||||
"slug": "ai-ad-generator",
|
||||
"slugNew": "screenshot-to-video-ad",
|
||||
"isCopiedFrom": "ai-ad-generator",
|
||||
"hasToGenerateVoice": True,
|
||||
"hasToTranscript": False,
|
||||
"hasToSearchMedia": True,
|
||||
"hasWebsiteRecorder": False,
|
||||
"hasTextSmallAtBottom": False,
|
||||
"selectedAudio": input_data.background_music.value,
|
||||
"selectedVoice": input_data.voice.voice_id,
|
||||
"selectedAvatar": "https://cdn.revid.ai/avatars/young-woman.mp4",
|
||||
"selectedAvatarType": "video/mp4",
|
||||
"websiteToRecord": "",
|
||||
"hasToGenerateCover": True,
|
||||
"nbGenerations": 1,
|
||||
"disableCaptions": False,
|
||||
"mediaMultiplier": "medium",
|
||||
"characters": [],
|
||||
"captionPresetName": "Revid",
|
||||
"sourceType": "contentScraping",
|
||||
"selectedStoryStyle": {"value": "custom", "label": "General"},
|
||||
"generationPreset": "DEFAULT",
|
||||
"hasToGenerateMusic": False,
|
||||
"isOptimizedForChinese": False,
|
||||
"generationUserPrompt": "",
|
||||
"enableNsfwFilter": False,
|
||||
"addStickers": False,
|
||||
"typeMovingImageAnim": "dynamic",
|
||||
"hasToGenerateSoundEffects": False,
|
||||
"forceModelType": "gpt-image-1",
|
||||
"selectedCharacters": [],
|
||||
"lang": "",
|
||||
"voiceSpeed": 1,
|
||||
"disableAudio": False,
|
||||
"disableVoice": False,
|
||||
"useOnlyProvidedMedia": True,
|
||||
"imageGenerationModel": "ultra",
|
||||
"videoGenerationModel": "ultra",
|
||||
"hasEnhancedGeneration": True,
|
||||
"hasEnhancedGenerationPro": True,
|
||||
"inputMedias": [
|
||||
{"url": input_data.screenshot_url, "title": "", "type": "image"}
|
||||
],
|
||||
"hasToGenerateVideos": True,
|
||||
"audioUrl": input_data.background_music.audio_url,
|
||||
"watermark": None,
|
||||
},
|
||||
}
|
||||
|
||||
response = await self.create_video(credentials.api_key, payload)
|
||||
pid = response.get("pid")
|
||||
if not pid:
|
||||
raise RuntimeError("Failed to create video: No project ID returned")
|
||||
|
||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||
yield "video_url", video_url
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import List
|
||||
from backend.blocks.apollo._auth import ApolloCredentials
|
||||
from backend.blocks.apollo.models import (
|
||||
Contact,
|
||||
EnrichPersonRequest,
|
||||
Organization,
|
||||
SearchOrganizationsRequest,
|
||||
SearchOrganizationsResponse,
|
||||
@@ -110,3 +111,21 @@ class ApolloClient:
|
||||
return (
|
||||
organizations[: query.max_results] if query.max_results else organizations
|
||||
)
|
||||
|
||||
async def enrich_person(self, query: EnrichPersonRequest) -> Contact:
|
||||
"""Enrich a person's data including email & phone reveal"""
|
||||
response = await self.requests.post(
|
||||
f"{self.API_URL}/people/match",
|
||||
headers=self._get_headers(),
|
||||
json=query.model_dump(),
|
||||
params={
|
||||
"reveal_personal_emails": "true",
|
||||
},
|
||||
)
|
||||
data = response.json()
|
||||
if "person" not in data:
|
||||
raise ValueError(f"Person not found or enrichment failed: {data}")
|
||||
|
||||
contact = Contact(**data["person"])
|
||||
contact.email = contact.email or "-"
|
||||
return contact
|
||||
|
||||
@@ -23,9 +23,9 @@ class BaseModel(OriginalBaseModel):
|
||||
class PrimaryPhone(BaseModel):
|
||||
"""A primary phone in Apollo"""
|
||||
|
||||
number: str = ""
|
||||
source: str = ""
|
||||
sanitized_number: str = ""
|
||||
number: Optional[str] = ""
|
||||
source: Optional[str] = ""
|
||||
sanitized_number: Optional[str] = ""
|
||||
|
||||
|
||||
class SenorityLevels(str, Enum):
|
||||
@@ -56,102 +56,102 @@ class ContactEmailStatuses(str, Enum):
|
||||
class RuleConfigStatus(BaseModel):
|
||||
"""A rule config status in Apollo"""
|
||||
|
||||
_id: str = ""
|
||||
created_at: str = ""
|
||||
rule_action_config_id: str = ""
|
||||
rule_config_id: str = ""
|
||||
status_cd: str = ""
|
||||
updated_at: str = ""
|
||||
id: str = ""
|
||||
key: str = ""
|
||||
_id: Optional[str] = ""
|
||||
created_at: Optional[str] = ""
|
||||
rule_action_config_id: Optional[str] = ""
|
||||
rule_config_id: Optional[str] = ""
|
||||
status_cd: Optional[str] = ""
|
||||
updated_at: Optional[str] = ""
|
||||
id: Optional[str] = ""
|
||||
key: Optional[str] = ""
|
||||
|
||||
|
||||
class ContactCampaignStatus(BaseModel):
|
||||
"""A contact campaign status in Apollo"""
|
||||
|
||||
id: str = ""
|
||||
emailer_campaign_id: str = ""
|
||||
send_email_from_user_id: str = ""
|
||||
inactive_reason: str = ""
|
||||
status: str = ""
|
||||
added_at: str = ""
|
||||
added_by_user_id: str = ""
|
||||
finished_at: str = ""
|
||||
paused_at: str = ""
|
||||
auto_unpause_at: str = ""
|
||||
send_email_from_email_address: str = ""
|
||||
send_email_from_email_account_id: str = ""
|
||||
manually_set_unpause: str = ""
|
||||
failure_reason: str = ""
|
||||
current_step_id: str = ""
|
||||
in_response_to_emailer_message_id: str = ""
|
||||
cc_emails: str = ""
|
||||
bcc_emails: str = ""
|
||||
to_emails: str = ""
|
||||
id: Optional[str] = ""
|
||||
emailer_campaign_id: Optional[str] = ""
|
||||
send_email_from_user_id: Optional[str] = ""
|
||||
inactive_reason: Optional[str] = ""
|
||||
status: Optional[str] = ""
|
||||
added_at: Optional[str] = ""
|
||||
added_by_user_id: Optional[str] = ""
|
||||
finished_at: Optional[str] = ""
|
||||
paused_at: Optional[str] = ""
|
||||
auto_unpause_at: Optional[str] = ""
|
||||
send_email_from_email_address: Optional[str] = ""
|
||||
send_email_from_email_account_id: Optional[str] = ""
|
||||
manually_set_unpause: Optional[str] = ""
|
||||
failure_reason: Optional[str] = ""
|
||||
current_step_id: Optional[str] = ""
|
||||
in_response_to_emailer_message_id: Optional[str] = ""
|
||||
cc_emails: Optional[str] = ""
|
||||
bcc_emails: Optional[str] = ""
|
||||
to_emails: Optional[str] = ""
|
||||
|
||||
|
||||
class Account(BaseModel):
|
||||
"""An account in Apollo"""
|
||||
|
||||
id: str = ""
|
||||
name: str = ""
|
||||
website_url: str = ""
|
||||
blog_url: str = ""
|
||||
angellist_url: str = ""
|
||||
linkedin_url: str = ""
|
||||
twitter_url: str = ""
|
||||
facebook_url: str = ""
|
||||
primary_phone: PrimaryPhone = PrimaryPhone()
|
||||
languages: list[str]
|
||||
alexa_ranking: int = 0
|
||||
phone: str = ""
|
||||
linkedin_uid: str = ""
|
||||
founded_year: int = 0
|
||||
publicly_traded_symbol: str = ""
|
||||
publicly_traded_exchange: str = ""
|
||||
logo_url: str = ""
|
||||
chrunchbase_url: str = ""
|
||||
primary_domain: str = ""
|
||||
domain: str = ""
|
||||
team_id: str = ""
|
||||
organization_id: str = ""
|
||||
account_stage_id: str = ""
|
||||
source: str = ""
|
||||
original_source: str = ""
|
||||
creator_id: str = ""
|
||||
owner_id: str = ""
|
||||
created_at: str = ""
|
||||
phone_status: str = ""
|
||||
hubspot_id: str = ""
|
||||
salesforce_id: str = ""
|
||||
crm_owner_id: str = ""
|
||||
parent_account_id: str = ""
|
||||
sanitized_phone: str = ""
|
||||
id: Optional[str] = ""
|
||||
name: Optional[str] = ""
|
||||
website_url: Optional[str] = ""
|
||||
blog_url: Optional[str] = ""
|
||||
angellist_url: Optional[str] = ""
|
||||
linkedin_url: Optional[str] = ""
|
||||
twitter_url: Optional[str] = ""
|
||||
facebook_url: Optional[str] = ""
|
||||
primary_phone: Optional[PrimaryPhone] = PrimaryPhone()
|
||||
languages: Optional[list[str]] = []
|
||||
alexa_ranking: Optional[int] = 0
|
||||
phone: Optional[str] = ""
|
||||
linkedin_uid: Optional[str] = ""
|
||||
founded_year: Optional[int] = 0
|
||||
publicly_traded_symbol: Optional[str] = ""
|
||||
publicly_traded_exchange: Optional[str] = ""
|
||||
logo_url: Optional[str] = ""
|
||||
chrunchbase_url: Optional[str] = ""
|
||||
primary_domain: Optional[str] = ""
|
||||
domain: Optional[str] = ""
|
||||
team_id: Optional[str] = ""
|
||||
organization_id: Optional[str] = ""
|
||||
account_stage_id: Optional[str] = ""
|
||||
source: Optional[str] = ""
|
||||
original_source: Optional[str] = ""
|
||||
creator_id: Optional[str] = ""
|
||||
owner_id: Optional[str] = ""
|
||||
created_at: Optional[str] = ""
|
||||
phone_status: Optional[str] = ""
|
||||
hubspot_id: Optional[str] = ""
|
||||
salesforce_id: Optional[str] = ""
|
||||
crm_owner_id: Optional[str] = ""
|
||||
parent_account_id: Optional[str] = ""
|
||||
sanitized_phone: Optional[str] = ""
|
||||
# no listed type on the API docs
|
||||
account_playbook_statues: list[Any] = []
|
||||
account_rule_config_statuses: list[RuleConfigStatus] = []
|
||||
existence_level: str = ""
|
||||
label_ids: list[str] = []
|
||||
typed_custom_fields: Any
|
||||
custom_field_errors: Any
|
||||
modality: str = ""
|
||||
source_display_name: str = ""
|
||||
salesforce_record_id: str = ""
|
||||
crm_record_url: str = ""
|
||||
account_playbook_statues: Optional[list[Any]] = []
|
||||
account_rule_config_statuses: Optional[list[RuleConfigStatus]] = []
|
||||
existence_level: Optional[str] = ""
|
||||
label_ids: Optional[list[str]] = []
|
||||
typed_custom_fields: Optional[Any] = {}
|
||||
custom_field_errors: Optional[Any] = {}
|
||||
modality: Optional[str] = ""
|
||||
source_display_name: Optional[str] = ""
|
||||
salesforce_record_id: Optional[str] = ""
|
||||
crm_record_url: Optional[str] = ""
|
||||
|
||||
|
||||
class ContactEmail(BaseModel):
|
||||
"""A contact email in Apollo"""
|
||||
|
||||
email: str = ""
|
||||
email_md5: str = ""
|
||||
email_sha256: str = ""
|
||||
email_status: str = ""
|
||||
email_source: str = ""
|
||||
extrapolated_email_confidence: str = ""
|
||||
position: int = 0
|
||||
email_from_customer: str = ""
|
||||
free_domain: bool = True
|
||||
email: Optional[str] = ""
|
||||
email_md5: Optional[str] = ""
|
||||
email_sha256: Optional[str] = ""
|
||||
email_status: Optional[str] = ""
|
||||
email_source: Optional[str] = ""
|
||||
extrapolated_email_confidence: Optional[str] = ""
|
||||
position: Optional[int] = 0
|
||||
email_from_customer: Optional[str] = ""
|
||||
free_domain: Optional[bool] = True
|
||||
|
||||
|
||||
class EmploymentHistory(BaseModel):
|
||||
@@ -164,40 +164,40 @@ class EmploymentHistory(BaseModel):
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
_id: Optional[str] = None
|
||||
created_at: Optional[str] = None
|
||||
current: Optional[bool] = None
|
||||
degree: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
emails: Optional[str] = None
|
||||
end_date: Optional[str] = None
|
||||
grade_level: Optional[str] = None
|
||||
kind: Optional[str] = None
|
||||
major: Optional[str] = None
|
||||
organization_id: Optional[str] = None
|
||||
organization_name: Optional[str] = None
|
||||
raw_address: Optional[str] = None
|
||||
start_date: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
updated_at: Optional[str] = None
|
||||
id: Optional[str] = None
|
||||
key: Optional[str] = None
|
||||
_id: Optional[str] = ""
|
||||
created_at: Optional[str] = ""
|
||||
current: Optional[bool] = False
|
||||
degree: Optional[str] = ""
|
||||
description: Optional[str] = ""
|
||||
emails: Optional[str] = ""
|
||||
end_date: Optional[str] = ""
|
||||
grade_level: Optional[str] = ""
|
||||
kind: Optional[str] = ""
|
||||
major: Optional[str] = ""
|
||||
organization_id: Optional[str] = ""
|
||||
organization_name: Optional[str] = ""
|
||||
raw_address: Optional[str] = ""
|
||||
start_date: Optional[str] = ""
|
||||
title: Optional[str] = ""
|
||||
updated_at: Optional[str] = ""
|
||||
id: Optional[str] = ""
|
||||
key: Optional[str] = ""
|
||||
|
||||
|
||||
class Breadcrumb(BaseModel):
|
||||
"""A breadcrumb in Apollo"""
|
||||
|
||||
label: Optional[str] = "N/A"
|
||||
signal_field_name: Optional[str] = "N/A"
|
||||
value: str | list | None = "N/A"
|
||||
display_name: Optional[str] = "N/A"
|
||||
label: Optional[str] = ""
|
||||
signal_field_name: Optional[str] = ""
|
||||
value: str | list | None = ""
|
||||
display_name: Optional[str] = ""
|
||||
|
||||
|
||||
class TypedCustomField(BaseModel):
|
||||
"""A typed custom field in Apollo"""
|
||||
|
||||
id: Optional[str] = "N/A"
|
||||
value: Optional[str] = "N/A"
|
||||
id: Optional[str] = ""
|
||||
value: Optional[str] = ""
|
||||
|
||||
|
||||
class Pagination(BaseModel):
|
||||
@@ -219,23 +219,23 @@ class Pagination(BaseModel):
|
||||
class DialerFlags(BaseModel):
|
||||
"""A dialer flags in Apollo"""
|
||||
|
||||
country_name: str = ""
|
||||
country_enabled: bool
|
||||
high_risk_calling_enabled: bool
|
||||
potential_high_risk_number: bool
|
||||
country_name: Optional[str] = ""
|
||||
country_enabled: Optional[bool] = True
|
||||
high_risk_calling_enabled: Optional[bool] = True
|
||||
potential_high_risk_number: Optional[bool] = True
|
||||
|
||||
|
||||
class PhoneNumber(BaseModel):
|
||||
"""A phone number in Apollo"""
|
||||
|
||||
raw_number: str = ""
|
||||
sanitized_number: str = ""
|
||||
type: str = ""
|
||||
position: int = 0
|
||||
status: str = ""
|
||||
dnc_status: str = ""
|
||||
dnc_other_info: str = ""
|
||||
dailer_flags: DialerFlags = DialerFlags(
|
||||
raw_number: Optional[str] = ""
|
||||
sanitized_number: Optional[str] = ""
|
||||
type: Optional[str] = ""
|
||||
position: Optional[int] = 0
|
||||
status: Optional[str] = ""
|
||||
dnc_status: Optional[str] = ""
|
||||
dnc_other_info: Optional[str] = ""
|
||||
dailer_flags: Optional[DialerFlags] = DialerFlags(
|
||||
country_name="",
|
||||
country_enabled=True,
|
||||
high_risk_calling_enabled=True,
|
||||
@@ -253,33 +253,31 @@ class Organization(BaseModel):
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
id: Optional[str] = "N/A"
|
||||
name: Optional[str] = "N/A"
|
||||
website_url: Optional[str] = "N/A"
|
||||
blog_url: Optional[str] = "N/A"
|
||||
angellist_url: Optional[str] = "N/A"
|
||||
linkedin_url: Optional[str] = "N/A"
|
||||
twitter_url: Optional[str] = "N/A"
|
||||
facebook_url: Optional[str] = "N/A"
|
||||
primary_phone: Optional[PrimaryPhone] = PrimaryPhone(
|
||||
number="N/A", source="N/A", sanitized_number="N/A"
|
||||
)
|
||||
languages: list[str] = []
|
||||
id: Optional[str] = ""
|
||||
name: Optional[str] = ""
|
||||
website_url: Optional[str] = ""
|
||||
blog_url: Optional[str] = ""
|
||||
angellist_url: Optional[str] = ""
|
||||
linkedin_url: Optional[str] = ""
|
||||
twitter_url: Optional[str] = ""
|
||||
facebook_url: Optional[str] = ""
|
||||
primary_phone: Optional[PrimaryPhone] = PrimaryPhone()
|
||||
languages: Optional[list[str]] = []
|
||||
alexa_ranking: Optional[int] = 0
|
||||
phone: Optional[str] = "N/A"
|
||||
linkedin_uid: Optional[str] = "N/A"
|
||||
phone: Optional[str] = ""
|
||||
linkedin_uid: Optional[str] = ""
|
||||
founded_year: Optional[int] = 0
|
||||
publicly_traded_symbol: Optional[str] = "N/A"
|
||||
publicly_traded_exchange: Optional[str] = "N/A"
|
||||
logo_url: Optional[str] = "N/A"
|
||||
chrunchbase_url: Optional[str] = "N/A"
|
||||
primary_domain: Optional[str] = "N/A"
|
||||
sanitized_phone: Optional[str] = "N/A"
|
||||
owned_by_organization_id: Optional[str] = "N/A"
|
||||
intent_strength: Optional[str] = "N/A"
|
||||
show_intent: bool = True
|
||||
publicly_traded_symbol: Optional[str] = ""
|
||||
publicly_traded_exchange: Optional[str] = ""
|
||||
logo_url: Optional[str] = ""
|
||||
chrunchbase_url: Optional[str] = ""
|
||||
primary_domain: Optional[str] = ""
|
||||
sanitized_phone: Optional[str] = ""
|
||||
owned_by_organization_id: Optional[str] = ""
|
||||
intent_strength: Optional[str] = ""
|
||||
show_intent: Optional[bool] = True
|
||||
has_intent_signal_account: Optional[bool] = True
|
||||
intent_signal_account: Optional[str] = "N/A"
|
||||
intent_signal_account: Optional[str] = ""
|
||||
|
||||
|
||||
class Contact(BaseModel):
|
||||
@@ -292,95 +290,95 @@ class Contact(BaseModel):
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
contact_roles: list[Any] = []
|
||||
id: Optional[str] = None
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
linkedin_url: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
contact_stage_id: Optional[str] = None
|
||||
owner_id: Optional[str] = None
|
||||
creator_id: Optional[str] = None
|
||||
person_id: Optional[str] = None
|
||||
email_needs_tickling: bool = True
|
||||
organization_name: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
original_source: Optional[str] = None
|
||||
organization_id: Optional[str] = None
|
||||
headline: Optional[str] = None
|
||||
photo_url: Optional[str] = None
|
||||
present_raw_address: Optional[str] = None
|
||||
linkededin_uid: Optional[str] = None
|
||||
extrapolated_email_confidence: Optional[float] = None
|
||||
salesforce_id: Optional[str] = None
|
||||
salesforce_lead_id: Optional[str] = None
|
||||
salesforce_contact_id: Optional[str] = None
|
||||
saleforce_account_id: Optional[str] = None
|
||||
crm_owner_id: Optional[str] = None
|
||||
created_at: Optional[str] = None
|
||||
emailer_campaign_ids: list[str] = []
|
||||
direct_dial_status: Optional[str] = None
|
||||
direct_dial_enrichment_failed_at: Optional[str] = None
|
||||
email_status: Optional[str] = None
|
||||
email_source: Optional[str] = None
|
||||
account_id: Optional[str] = None
|
||||
last_activity_date: Optional[str] = None
|
||||
hubspot_vid: Optional[str] = None
|
||||
hubspot_company_id: Optional[str] = None
|
||||
crm_id: Optional[str] = None
|
||||
sanitized_phone: Optional[str] = None
|
||||
merged_crm_ids: Optional[str] = None
|
||||
updated_at: Optional[str] = None
|
||||
queued_for_crm_push: bool = True
|
||||
suggested_from_rule_engine_config_id: Optional[str] = None
|
||||
email_unsubscribed: Optional[str] = None
|
||||
label_ids: list[Any] = []
|
||||
has_pending_email_arcgate_request: bool = True
|
||||
has_email_arcgate_request: bool = True
|
||||
existence_level: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
email_from_customer: Optional[str] = None
|
||||
typed_custom_fields: list[TypedCustomField] = []
|
||||
custom_field_errors: Any = None
|
||||
salesforce_record_id: Optional[str] = None
|
||||
crm_record_url: Optional[str] = None
|
||||
email_status_unavailable_reason: Optional[str] = None
|
||||
email_true_status: Optional[str] = None
|
||||
updated_email_true_status: bool = True
|
||||
contact_rule_config_statuses: list[RuleConfigStatus] = []
|
||||
source_display_name: Optional[str] = None
|
||||
twitter_url: Optional[str] = None
|
||||
contact_campaign_statuses: list[ContactCampaignStatus] = []
|
||||
state: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
country: Optional[str] = None
|
||||
account: Optional[Account] = None
|
||||
contact_emails: list[ContactEmail] = []
|
||||
organization: Optional[Organization] = None
|
||||
employment_history: list[EmploymentHistory] = []
|
||||
time_zone: Optional[str] = None
|
||||
intent_strength: Optional[str] = None
|
||||
show_intent: bool = True
|
||||
phone_numbers: list[PhoneNumber] = []
|
||||
account_phone_note: Optional[str] = None
|
||||
free_domain: bool = True
|
||||
is_likely_to_engage: bool = True
|
||||
email_domain_catchall: bool = True
|
||||
contact_job_change_event: Optional[str] = None
|
||||
contact_roles: Optional[list[Any]] = []
|
||||
id: Optional[str] = ""
|
||||
first_name: Optional[str] = ""
|
||||
last_name: Optional[str] = ""
|
||||
name: Optional[str] = ""
|
||||
linkedin_url: Optional[str] = ""
|
||||
title: Optional[str] = ""
|
||||
contact_stage_id: Optional[str] = ""
|
||||
owner_id: Optional[str] = ""
|
||||
creator_id: Optional[str] = ""
|
||||
person_id: Optional[str] = ""
|
||||
email_needs_tickling: Optional[bool] = True
|
||||
organization_name: Optional[str] = ""
|
||||
source: Optional[str] = ""
|
||||
original_source: Optional[str] = ""
|
||||
organization_id: Optional[str] = ""
|
||||
headline: Optional[str] = ""
|
||||
photo_url: Optional[str] = ""
|
||||
present_raw_address: Optional[str] = ""
|
||||
linkededin_uid: Optional[str] = ""
|
||||
extrapolated_email_confidence: Optional[float] = 0.0
|
||||
salesforce_id: Optional[str] = ""
|
||||
salesforce_lead_id: Optional[str] = ""
|
||||
salesforce_contact_id: Optional[str] = ""
|
||||
saleforce_account_id: Optional[str] = ""
|
||||
crm_owner_id: Optional[str] = ""
|
||||
created_at: Optional[str] = ""
|
||||
emailer_campaign_ids: Optional[list[str]] = []
|
||||
direct_dial_status: Optional[str] = ""
|
||||
direct_dial_enrichment_failed_at: Optional[str] = ""
|
||||
email_status: Optional[str] = ""
|
||||
email_source: Optional[str] = ""
|
||||
account_id: Optional[str] = ""
|
||||
last_activity_date: Optional[str] = ""
|
||||
hubspot_vid: Optional[str] = ""
|
||||
hubspot_company_id: Optional[str] = ""
|
||||
crm_id: Optional[str] = ""
|
||||
sanitized_phone: Optional[str] = ""
|
||||
merged_crm_ids: Optional[str] = ""
|
||||
updated_at: Optional[str] = ""
|
||||
queued_for_crm_push: Optional[bool] = True
|
||||
suggested_from_rule_engine_config_id: Optional[str] = ""
|
||||
email_unsubscribed: Optional[str] = ""
|
||||
label_ids: Optional[list[Any]] = []
|
||||
has_pending_email_arcgate_request: Optional[bool] = True
|
||||
has_email_arcgate_request: Optional[bool] = True
|
||||
existence_level: Optional[str] = ""
|
||||
email: Optional[str] = ""
|
||||
email_from_customer: Optional[str] = ""
|
||||
typed_custom_fields: Optional[list[TypedCustomField]] = []
|
||||
custom_field_errors: Optional[Any] = {}
|
||||
salesforce_record_id: Optional[str] = ""
|
||||
crm_record_url: Optional[str] = ""
|
||||
email_status_unavailable_reason: Optional[str] = ""
|
||||
email_true_status: Optional[str] = ""
|
||||
updated_email_true_status: Optional[bool] = True
|
||||
contact_rule_config_statuses: Optional[list[RuleConfigStatus]] = []
|
||||
source_display_name: Optional[str] = ""
|
||||
twitter_url: Optional[str] = ""
|
||||
contact_campaign_statuses: Optional[list[ContactCampaignStatus]] = []
|
||||
state: Optional[str] = ""
|
||||
city: Optional[str] = ""
|
||||
country: Optional[str] = ""
|
||||
account: Optional[Account] = Account()
|
||||
contact_emails: Optional[list[ContactEmail]] = []
|
||||
organization: Optional[Organization] = Organization()
|
||||
employment_history: Optional[list[EmploymentHistory]] = []
|
||||
time_zone: Optional[str] = ""
|
||||
intent_strength: Optional[str] = ""
|
||||
show_intent: Optional[bool] = True
|
||||
phone_numbers: Optional[list[PhoneNumber]] = []
|
||||
account_phone_note: Optional[str] = ""
|
||||
free_domain: Optional[bool] = True
|
||||
is_likely_to_engage: Optional[bool] = True
|
||||
email_domain_catchall: Optional[bool] = True
|
||||
contact_job_change_event: Optional[str] = ""
|
||||
|
||||
|
||||
class SearchOrganizationsRequest(BaseModel):
|
||||
"""Request for Apollo's search organizations API"""
|
||||
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
organization_num_employees_range: Optional[list[int]] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
default=[0, 1000000],
|
||||
)
|
||||
|
||||
organization_locations: list[str] = SchemaField(
|
||||
organization_locations: Optional[list[str]] = SchemaField(
|
||||
description="""The location of the company headquarters. You can search across cities, US states, and countries.
|
||||
|
||||
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, any Boston-based companies will not appearch in your search results, even if they match other parameters.
|
||||
@@ -389,28 +387,30 @@ To exclude companies based on location, use the organization_not_locations param
|
||||
""",
|
||||
default_factory=list,
|
||||
)
|
||||
organizations_not_locations: list[str] = SchemaField(
|
||||
organizations_not_locations: Optional[list[str]] = SchemaField(
|
||||
description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.
|
||||
|
||||
This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
|
||||
""",
|
||||
default_factory=list,
|
||||
)
|
||||
q_organization_keyword_tags: list[str] = SchemaField(
|
||||
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry."""
|
||||
q_organization_keyword_tags: Optional[list[str]] = SchemaField(
|
||||
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry.""",
|
||||
default_factory=list,
|
||||
)
|
||||
q_organization_name: str = SchemaField(
|
||||
q_organization_name: Optional[str] = SchemaField(
|
||||
description="""Filter search results to include a specific company name.
|
||||
|
||||
If the value you enter for this parameter does not match with a company's name, the company will not appear in search results, even if it matches other parameters. Partial matches are accepted. For example, if you filter by the value marketing, a company called NY Marketing Unlimited would still be eligible as a search result, but NY Market Analysis would not be eligible."""
|
||||
If the value you enter for this parameter does not match with a company's name, the company will not appear in search results, even if it matches other parameters. Partial matches are accepted. For example, if you filter by the value marketing, a company called NY Marketing Unlimited would still be eligible as a search result, but NY Market Analysis would not be eligible.""",
|
||||
default="",
|
||||
)
|
||||
organization_ids: list[str] = SchemaField(
|
||||
organization_ids: Optional[list[str]] = SchemaField(
|
||||
description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
|
||||
|
||||
To find IDs, identify the values for organization_id when you call this endpoint.""",
|
||||
default_factory=list,
|
||||
)
|
||||
max_results: int = SchemaField(
|
||||
max_results: Optional[int] = SchemaField(
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
|
||||
default=100,
|
||||
ge=1,
|
||||
@@ -435,11 +435,11 @@ Use the page parameter to search the different pages of data.""",
|
||||
class SearchOrganizationsResponse(BaseModel):
|
||||
"""Response from Apollo's search organizations API"""
|
||||
|
||||
breadcrumbs: list[Breadcrumb] = []
|
||||
partial_results_only: bool = True
|
||||
has_join: bool = True
|
||||
disable_eu_prospecting: bool = True
|
||||
partial_results_limit: int = 0
|
||||
breadcrumbs: Optional[list[Breadcrumb]] = []
|
||||
partial_results_only: Optional[bool] = True
|
||||
has_join: Optional[bool] = True
|
||||
disable_eu_prospecting: Optional[bool] = True
|
||||
partial_results_limit: Optional[int] = 0
|
||||
pagination: Pagination = Pagination(
|
||||
page=0, per_page=0, total_entries=0, total_pages=0
|
||||
)
|
||||
@@ -447,14 +447,14 @@ class SearchOrganizationsResponse(BaseModel):
|
||||
accounts: list[Any] = []
|
||||
organizations: list[Organization] = []
|
||||
models_ids: list[str] = []
|
||||
num_fetch_result: Optional[str] = "N/A"
|
||||
derived_params: Optional[str] = "N/A"
|
||||
num_fetch_result: Optional[str] = ""
|
||||
derived_params: Optional[str] = ""
|
||||
|
||||
|
||||
class SearchPeopleRequest(BaseModel):
|
||||
"""Request for Apollo's search people API"""
|
||||
|
||||
person_titles: list[str] = SchemaField(
|
||||
person_titles: Optional[list[str]] = SchemaField(
|
||||
description="""Job titles held by the people you want to find. For a person to be included in search results, they only need to match 1 of the job titles you add. Adding more job titles expands your search results.
|
||||
|
||||
Results also include job titles with the same terms, even if they are not exact matches. For example, searching for marketing manager might return people with the job title content marketing manager.
|
||||
@@ -464,13 +464,13 @@ Use this parameter in combination with the person_seniorities[] parameter to fin
|
||||
default_factory=list,
|
||||
placeholder="marketing manager",
|
||||
)
|
||||
person_locations: list[str] = SchemaField(
|
||||
person_locations: Optional[list[str]] = SchemaField(
|
||||
description="""The location where people live. You can search across cities, US states, and countries.
|
||||
|
||||
To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
|
||||
default_factory=list,
|
||||
)
|
||||
person_seniorities: list[SenorityLevels] = SchemaField(
|
||||
person_seniorities: Optional[list[SenorityLevels]] = SchemaField(
|
||||
description="""The job seniority that people hold within their current employer. This enables you to find people that currently hold positions at certain reporting levels, such as Director level or senior IC level.
|
||||
|
||||
For a person to be included in search results, they only need to match 1 of the seniorities you add. Adding more seniorities expands your search results.
|
||||
@@ -480,7 +480,7 @@ Searches only return results based on their current job title, so searching for
|
||||
Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
|
||||
default_factory=list,
|
||||
)
|
||||
organization_locations: list[str] = SchemaField(
|
||||
organization_locations: Optional[list[str]] = SchemaField(
|
||||
description="""The location of the company headquarters for a person's current employer. You can search across cities, US states, and countries.
|
||||
|
||||
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.
|
||||
@@ -488,7 +488,7 @@ If a company has several office locations, results are still based on the headqu
|
||||
To find people based on their personal location, use the person_locations parameter.""",
|
||||
default_factory=list,
|
||||
)
|
||||
q_organization_domains: list[str] = SchemaField(
|
||||
q_organization_domains: Optional[list[str]] = SchemaField(
|
||||
description="""The domain name for the person's employer. This can be the current employer or a previous employer. Do not include www., the @ symbol, or similar.
|
||||
|
||||
You can add multiple domains to search across companies.
|
||||
@@ -496,23 +496,23 @@ You can add multiple domains to search across companies.
|
||||
Examples: apollo.io and microsoft.com""",
|
||||
default_factory=list,
|
||||
)
|
||||
contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
|
||||
contact_email_statuses: Optional[list[ContactEmailStatuses]] = SchemaField(
|
||||
description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
|
||||
default_factory=list,
|
||||
)
|
||||
organization_ids: list[str] = SchemaField(
|
||||
organization_ids: Optional[list[str]] = SchemaField(
|
||||
description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
|
||||
|
||||
To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
|
||||
default_factory=list,
|
||||
)
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
organization_num_employees_range: Optional[list[int]] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
default_factory=list,
|
||||
)
|
||||
q_keywords: str = SchemaField(
|
||||
q_keywords: Optional[str] = SchemaField(
|
||||
description="""A string of words over which we want to filter the results""",
|
||||
default="",
|
||||
)
|
||||
@@ -528,7 +528,7 @@ Use this parameter in combination with the per_page parameter to make search res
|
||||
Use the page parameter to search the different pages of data.""",
|
||||
default=100,
|
||||
)
|
||||
max_results: int = SchemaField(
|
||||
max_results: Optional[int] = SchemaField(
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
|
||||
default=100,
|
||||
ge=1,
|
||||
@@ -547,16 +547,61 @@ class SearchPeopleResponse(BaseModel):
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
breadcrumbs: list[Breadcrumb] = []
|
||||
partial_results_only: bool = True
|
||||
has_join: bool = True
|
||||
disable_eu_prospecting: bool = True
|
||||
partial_results_limit: int = 0
|
||||
breadcrumbs: Optional[list[Breadcrumb]] = []
|
||||
partial_results_only: Optional[bool] = True
|
||||
has_join: Optional[bool] = True
|
||||
disable_eu_prospecting: Optional[bool] = True
|
||||
partial_results_limit: Optional[int] = 0
|
||||
pagination: Pagination = Pagination(
|
||||
page=0, per_page=0, total_entries=0, total_pages=0
|
||||
)
|
||||
contacts: list[Contact] = []
|
||||
people: list[Contact] = []
|
||||
model_ids: list[str] = []
|
||||
num_fetch_result: Optional[str] = "N/A"
|
||||
derived_params: Optional[str] = "N/A"
|
||||
num_fetch_result: Optional[str] = ""
|
||||
derived_params: Optional[str] = ""
|
||||
|
||||
|
||||
class EnrichPersonRequest(BaseModel):
|
||||
"""Request for Apollo's person enrichment API"""
|
||||
|
||||
person_id: Optional[str] = SchemaField(
|
||||
description="Apollo person ID to enrich (most accurate method)",
|
||||
default="",
|
||||
)
|
||||
first_name: Optional[str] = SchemaField(
|
||||
description="First name of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
last_name: Optional[str] = SchemaField(
|
||||
description="Last name of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
name: Optional[str] = SchemaField(
|
||||
description="Full name of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
email: Optional[str] = SchemaField(
|
||||
description="Email address of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
domain: Optional[str] = SchemaField(
|
||||
description="Company domain of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
company: Optional[str] = SchemaField(
|
||||
description="Company name of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
linkedin_url: Optional[str] = SchemaField(
|
||||
description="LinkedIn URL of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
organization_id: Optional[str] = SchemaField(
|
||||
description="Apollo organization ID of the person's company",
|
||||
default="",
|
||||
)
|
||||
title: Optional[str] = SchemaField(
|
||||
description="Job title of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
|
||||
@@ -11,14 +11,14 @@ from backend.blocks.apollo.models import (
|
||||
SearchOrganizationsRequest,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class SearchOrganizationsBlock(Block):
|
||||
"""Search for organizations in Apollo"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
organization_num_employees_range: list[int] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
@@ -65,7 +65,7 @@ To find IDs, identify the values for organization_id when you call this endpoint
|
||||
le=50000,
|
||||
advanced=True,
|
||||
)
|
||||
credentials: ApolloCredentialsInput = SchemaField(
|
||||
credentials: ApolloCredentialsInput = CredentialsField(
|
||||
description="Apollo credentials",
|
||||
)
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import asyncio
|
||||
|
||||
from backend.blocks.apollo._api import ApolloClient
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
@@ -8,11 +10,12 @@ from backend.blocks.apollo._auth import (
|
||||
from backend.blocks.apollo.models import (
|
||||
Contact,
|
||||
ContactEmailStatuses,
|
||||
EnrichPersonRequest,
|
||||
SearchPeopleRequest,
|
||||
SenorityLevels,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class SearchPeopleBlock(Block):
|
||||
@@ -77,7 +80,7 @@ class SearchPeopleBlock(Block):
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
organization_num_employees_range: list[int] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
@@ -90,14 +93,19 @@ class SearchPeopleBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
max_results: int = SchemaField(
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
|
||||
default=100,
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 25. Limited to 500 to prevent overspending.""",
|
||||
default=25,
|
||||
ge=1,
|
||||
le=50000,
|
||||
le=500,
|
||||
advanced=True,
|
||||
)
|
||||
enrich_info: bool = SchemaField(
|
||||
description="""Whether to enrich contacts with detailed information including real email addresses. This will double the search cost.""",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
credentials: ApolloCredentialsInput = SchemaField(
|
||||
credentials: ApolloCredentialsInput = CredentialsField(
|
||||
description="Apollo credentials",
|
||||
)
|
||||
|
||||
@@ -106,10 +114,6 @@ class SearchPeopleBlock(Block):
|
||||
description="List of people found",
|
||||
default_factory=list,
|
||||
)
|
||||
person: Contact = SchemaField(
|
||||
title="Person",
|
||||
description="Each found person, one at a time",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the search failed",
|
||||
default="",
|
||||
@@ -125,87 +129,6 @@ class SearchPeopleBlock(Block):
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
(
|
||||
"person",
|
||||
Contact(
|
||||
contact_roles=[],
|
||||
id="1",
|
||||
name="John Doe",
|
||||
first_name="John",
|
||||
last_name="Doe",
|
||||
linkedin_url="https://www.linkedin.com/in/johndoe",
|
||||
title="Software Engineer",
|
||||
organization_name="Google",
|
||||
organization_id="123456",
|
||||
contact_stage_id="1",
|
||||
owner_id="1",
|
||||
creator_id="1",
|
||||
person_id="1",
|
||||
email_needs_tickling=True,
|
||||
source="apollo",
|
||||
original_source="apollo",
|
||||
headline="Software Engineer",
|
||||
photo_url="https://www.linkedin.com/in/johndoe",
|
||||
present_raw_address="123 Main St, Anytown, USA",
|
||||
linkededin_uid="123456",
|
||||
extrapolated_email_confidence=0.8,
|
||||
salesforce_id="123456",
|
||||
salesforce_lead_id="123456",
|
||||
salesforce_contact_id="123456",
|
||||
saleforce_account_id="123456",
|
||||
crm_owner_id="123456",
|
||||
created_at="2021-01-01",
|
||||
emailer_campaign_ids=[],
|
||||
direct_dial_status="active",
|
||||
direct_dial_enrichment_failed_at="2021-01-01",
|
||||
email_status="active",
|
||||
email_source="apollo",
|
||||
account_id="123456",
|
||||
last_activity_date="2021-01-01",
|
||||
hubspot_vid="123456",
|
||||
hubspot_company_id="123456",
|
||||
crm_id="123456",
|
||||
sanitized_phone="123456",
|
||||
merged_crm_ids="123456",
|
||||
updated_at="2021-01-01",
|
||||
queued_for_crm_push=True,
|
||||
suggested_from_rule_engine_config_id="123456",
|
||||
email_unsubscribed=None,
|
||||
label_ids=[],
|
||||
has_pending_email_arcgate_request=True,
|
||||
has_email_arcgate_request=True,
|
||||
existence_level=None,
|
||||
email=None,
|
||||
email_from_customer=None,
|
||||
typed_custom_fields=[],
|
||||
custom_field_errors=None,
|
||||
salesforce_record_id=None,
|
||||
crm_record_url=None,
|
||||
email_status_unavailable_reason=None,
|
||||
email_true_status=None,
|
||||
updated_email_true_status=True,
|
||||
contact_rule_config_statuses=[],
|
||||
source_display_name=None,
|
||||
twitter_url=None,
|
||||
contact_campaign_statuses=[],
|
||||
state=None,
|
||||
city=None,
|
||||
country=None,
|
||||
account=None,
|
||||
contact_emails=[],
|
||||
organization=None,
|
||||
employment_history=[],
|
||||
time_zone=None,
|
||||
intent_strength=None,
|
||||
show_intent=True,
|
||||
phone_numbers=[],
|
||||
account_phone_note=None,
|
||||
free_domain=True,
|
||||
is_likely_to_engage=True,
|
||||
email_domain_catchall=True,
|
||||
contact_job_change_event=None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"people",
|
||||
[
|
||||
@@ -380,6 +303,34 @@ class SearchPeopleBlock(Block):
|
||||
client = ApolloClient(credentials)
|
||||
return await client.search_people(query)
|
||||
|
||||
@staticmethod
|
||||
async def enrich_person(
|
||||
query: EnrichPersonRequest, credentials: ApolloCredentials
|
||||
) -> Contact:
|
||||
client = ApolloClient(credentials)
|
||||
return await client.enrich_person(query)
|
||||
|
||||
@staticmethod
|
||||
def merge_contact_data(original: Contact, enriched: Contact) -> Contact:
|
||||
"""
|
||||
Merge contact data from original search with enriched data.
|
||||
Enriched data complements original data, only filling in missing values.
|
||||
"""
|
||||
merged_data = original.model_dump()
|
||||
enriched_data = enriched.model_dump()
|
||||
|
||||
# Only update fields that are None, empty string, empty list, or default values in original
|
||||
for key, enriched_value in enriched_data.items():
|
||||
# Skip if enriched value is None, empty string, or empty list
|
||||
if enriched_value is None or enriched_value == "" or enriched_value == []:
|
||||
continue
|
||||
|
||||
# Update if original value is None, empty string, empty list, or zero
|
||||
if enriched_value:
|
||||
merged_data[key] = enriched_value
|
||||
|
||||
return Contact(**merged_data)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
@@ -390,6 +341,23 @@ class SearchPeopleBlock(Block):
|
||||
|
||||
query = SearchPeopleRequest(**input_data.model_dump())
|
||||
people = await self.search_people(query, credentials)
|
||||
for person in people:
|
||||
yield "person", person
|
||||
|
||||
# Enrich with detailed info if requested
|
||||
if input_data.enrich_info:
|
||||
|
||||
async def enrich_or_fallback(person: Contact):
|
||||
try:
|
||||
enrich_query = EnrichPersonRequest(person_id=person.id)
|
||||
enriched_person = await self.enrich_person(
|
||||
enrich_query, credentials
|
||||
)
|
||||
# Merge enriched data with original data, complementing instead of replacing
|
||||
return self.merge_contact_data(person, enriched_person)
|
||||
except Exception:
|
||||
return person # If enrichment fails, use original person data
|
||||
|
||||
people = await asyncio.gather(
|
||||
*(enrich_or_fallback(person) for person in people)
|
||||
)
|
||||
|
||||
yield "people", people
|
||||
|
||||
138
autogpt_platform/backend/backend/blocks/apollo/person.py
Normal file
138
autogpt_platform/backend/backend/blocks/apollo/person.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from backend.blocks.apollo._api import ApolloClient
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
ApolloCredentials,
|
||||
ApolloCredentialsInput,
|
||||
)
|
||||
from backend.blocks.apollo.models import Contact, EnrichPersonRequest
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class GetPersonDetailBlock(Block):
|
||||
"""Get detailed person data with Apollo API, including email reveal"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
person_id: str = SchemaField(
|
||||
description="Apollo person ID to enrich (most accurate method)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
first_name: str = SchemaField(
|
||||
description="First name of the person to enrich",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
last_name: str = SchemaField(
|
||||
description="Last name of the person to enrich",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
name: str = SchemaField(
|
||||
description="Full name of the person to enrich (alternative to first_name + last_name)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
email: str = SchemaField(
|
||||
description="Known email address of the person (helps with matching)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
domain: str = SchemaField(
|
||||
description="Company domain of the person (e.g., 'google.com')",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
company: str = SchemaField(
|
||||
description="Company name of the person",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
linkedin_url: str = SchemaField(
|
||||
description="LinkedIn URL of the person",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
organization_id: str = SchemaField(
|
||||
description="Apollo organization ID of the person's company",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
title: str = SchemaField(
|
||||
description="Job title of the person to enrich",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
credentials: ApolloCredentialsInput = CredentialsField(
|
||||
description="Apollo credentials",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
contact: Contact = SchemaField(
|
||||
description="Enriched contact information",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if enrichment failed",
|
||||
default="",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3b18d46c-3db6-42ae-a228-0ba441bdd176",
|
||||
description="Get detailed person data with Apollo API, including email reveal",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=GetPersonDetailBlock.Input,
|
||||
output_schema=GetPersonDetailBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"first_name": "John",
|
||||
"last_name": "Doe",
|
||||
"company": "Google",
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"contact",
|
||||
Contact(
|
||||
id="1",
|
||||
name="John Doe",
|
||||
first_name="John",
|
||||
last_name="Doe",
|
||||
email="john.doe@gmail.com",
|
||||
title="Software Engineer",
|
||||
organization_name="Google",
|
||||
linkedin_url="https://www.linkedin.com/in/johndoe",
|
||||
),
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"enrich_person": lambda query, credentials: Contact(
|
||||
id="1",
|
||||
name="John Doe",
|
||||
first_name="John",
|
||||
last_name="Doe",
|
||||
email="john.doe@gmail.com",
|
||||
title="Software Engineer",
|
||||
organization_name="Google",
|
||||
linkedin_url="https://www.linkedin.com/in/johndoe",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def enrich_person(
|
||||
query: EnrichPersonRequest, credentials: ApolloCredentials
|
||||
) -> Contact:
|
||||
client = ApolloClient(credentials)
|
||||
return await client.enrich_person(query)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: ApolloCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
query = EnrichPersonRequest(**input_data.model_dump())
|
||||
yield "contact", await self.enrich_person(query, credentials)
|
||||
@@ -3,6 +3,7 @@ from typing import Any
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.type import convert
|
||||
|
||||
|
||||
class ComparisonOperator(Enum):
|
||||
@@ -181,7 +182,23 @@ class IfInputMatchesBlock(Block):
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
if input_data.input == input_data.value or input_data.input is input_data.value:
|
||||
|
||||
# If input_data.value is not matching input_data.input, convert value to type of input
|
||||
if (
|
||||
input_data.input != input_data.value
|
||||
and input_data.input is not input_data.value
|
||||
):
|
||||
try:
|
||||
# Only attempt conversion if input is not None and value is not None
|
||||
if input_data.input is not None and input_data.value is not None:
|
||||
input_type = type(input_data.input)
|
||||
# Avoid converting if input_type is Any or object
|
||||
if input_type not in (Any, object):
|
||||
input_data.value = convert(input_data.value, input_type)
|
||||
except Exception:
|
||||
pass # If conversion fails, just leave value as is
|
||||
|
||||
if input_data.input == input_data.value:
|
||||
yield "result", True
|
||||
yield "yes_output", input_data.yes_value
|
||||
else:
|
||||
|
||||
@@ -3,11 +3,19 @@ import logging
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import aiofiles
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
HostScopedCredentials,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.file import (
|
||||
MediaFileType,
|
||||
get_exec_file_path,
|
||||
@@ -19,6 +27,30 @@ from backend.util.request import Requests
|
||||
logger = logging.getLogger(name=__name__)
|
||||
|
||||
|
||||
# Host-scoped credentials for HTTP requests
|
||||
HttpCredentials = CredentialsMetaInput[
|
||||
Literal[ProviderName.HTTP], Literal["host_scoped"]
|
||||
]
|
||||
|
||||
|
||||
TEST_CREDENTIALS = HostScopedCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="http",
|
||||
host="api.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer test-token"),
|
||||
"X-API-Key": SecretStr("test-api-key"),
|
||||
},
|
||||
title="Mock HTTP Host-Scoped Credentials",
|
||||
)
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
class HttpMethod(Enum):
|
||||
GET = "GET"
|
||||
POST = "POST"
|
||||
@@ -169,3 +201,62 @@ class SendWebRequestBlock(Block):
|
||||
yield "client_error", result
|
||||
else:
|
||||
yield "server_error", result
|
||||
|
||||
|
||||
class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
|
||||
class Input(SendWebRequestBlock.Input):
|
||||
credentials: HttpCredentials = CredentialsField(
|
||||
description="HTTP host-scoped credentials for automatic header injection",
|
||||
discriminator="url",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
Block.__init__(
|
||||
self,
|
||||
id="fff86bcd-e001-4bad-a7f6-2eae4720c8dc",
|
||||
description="Make an authenticated HTTP request with host-scoped credentials (JSON / form / multipart).",
|
||||
categories={BlockCategory.OUTPUT},
|
||||
input_schema=SendAuthenticatedWebRequestBlock.Input,
|
||||
output_schema=SendWebRequestBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run( # type: ignore[override]
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
graph_exec_id: str,
|
||||
credentials: HostScopedCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Create SendWebRequestBlock.Input from our input (removing credentials field)
|
||||
base_input = SendWebRequestBlock.Input(
|
||||
url=input_data.url,
|
||||
method=input_data.method,
|
||||
headers=input_data.headers,
|
||||
json_format=input_data.json_format,
|
||||
body=input_data.body,
|
||||
files_name=input_data.files_name,
|
||||
files=input_data.files,
|
||||
)
|
||||
|
||||
# Apply host-scoped credentials to headers
|
||||
extra_headers = {}
|
||||
if credentials.matches_url(input_data.url):
|
||||
logger.debug(
|
||||
f"Applying host-scoped credentials {credentials.id} for URL {input_data.url}"
|
||||
)
|
||||
extra_headers.update(credentials.get_headers_dict())
|
||||
else:
|
||||
logger.warning(
|
||||
f"Host-scoped credentials {credentials.id} do not match URL {input_data.url}"
|
||||
)
|
||||
|
||||
# Merge with user-provided headers (user headers take precedence)
|
||||
base_input.headers = {**extra_headers, **input_data.headers}
|
||||
|
||||
# Use parent class run method
|
||||
async for output_name, output_data in super().run(
|
||||
base_input, graph_exec_id=graph_exec_id, **kwargs
|
||||
):
|
||||
yield output_name, output_data
|
||||
|
||||
@@ -40,7 +40,7 @@ LLMProviderName = Literal[
|
||||
AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]]
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
|
||||
id="769f6af7-820b-4d5d-9b7a-ab82bbc165f",
|
||||
provider="openai",
|
||||
api_key=SecretStr("mock-openai-api-key"),
|
||||
title="Mock OpenAI API key",
|
||||
|
||||
@@ -13,7 +13,7 @@ from backend.data.model import (
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
|
||||
id="8cc8b2c5-d3e4-4b1c-84ad-e1e9fe2a0122",
|
||||
provider="mem0",
|
||||
api_key=SecretStr("mock-mem0-api-key"),
|
||||
title="Mock Mem0 API key",
|
||||
|
||||
@@ -17,7 +17,7 @@ from backend.blocks.smartlead.models import (
|
||||
Sequence,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class CreateCampaignBlock(Block):
|
||||
@@ -27,7 +27,7 @@ class CreateCampaignBlock(Block):
|
||||
name: str = SchemaField(
|
||||
description="The name of the campaign",
|
||||
)
|
||||
credentials: SmartLeadCredentialsInput = SchemaField(
|
||||
credentials: SmartLeadCredentialsInput = CredentialsField(
|
||||
description="SmartLead credentials",
|
||||
)
|
||||
|
||||
@@ -119,7 +119,7 @@ class AddLeadToCampaignBlock(Block):
|
||||
description="Settings for lead upload",
|
||||
default=LeadUploadSettings(),
|
||||
)
|
||||
credentials: SmartLeadCredentialsInput = SchemaField(
|
||||
credentials: SmartLeadCredentialsInput = CredentialsField(
|
||||
description="SmartLead credentials",
|
||||
)
|
||||
|
||||
@@ -251,7 +251,7 @@ class SaveCampaignSequencesBlock(Block):
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
credentials: SmartLeadCredentialsInput = SchemaField(
|
||||
credentials: SmartLeadCredentialsInput = CredentialsField(
|
||||
description="SmartLead credentials",
|
||||
)
|
||||
|
||||
|
||||
485
autogpt_platform/backend/backend/blocks/test/test_http.py
Normal file
485
autogpt_platform/backend/backend/blocks/test/test_http.py
Normal file
@@ -0,0 +1,485 @@
|
||||
"""Comprehensive tests for HTTP block with HostScopedCredentials functionality."""
|
||||
|
||||
from typing import cast
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.http import (
|
||||
HttpCredentials,
|
||||
HttpMethod,
|
||||
SendAuthenticatedWebRequestBlock,
|
||||
)
|
||||
from backend.data.model import HostScopedCredentials
|
||||
from backend.util.request import Response
|
||||
|
||||
|
||||
class TestHttpBlockWithHostScopedCredentials:
|
||||
"""Test suite for HTTP block integration with HostScopedCredentials."""
|
||||
|
||||
@pytest.fixture
|
||||
def http_block(self):
|
||||
"""Create an HTTP block instance."""
|
||||
return SendAuthenticatedWebRequestBlock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response(self):
|
||||
"""Mock a successful HTTP response."""
|
||||
response = MagicMock(spec=Response)
|
||||
response.status = 200
|
||||
response.headers = {"content-type": "application/json"}
|
||||
response.json.return_value = {"success": True, "data": "test"}
|
||||
return response
|
||||
|
||||
@pytest.fixture
|
||||
def exact_match_credentials(self):
|
||||
"""Create host-scoped credentials for exact domain matching."""
|
||||
return HostScopedCredentials(
|
||||
provider="http",
|
||||
host="api.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer exact-match-token"),
|
||||
"X-API-Key": SecretStr("api-key-123"),
|
||||
},
|
||||
title="Exact Match API Credentials",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def wildcard_credentials(self):
|
||||
"""Create host-scoped credentials with wildcard pattern."""
|
||||
return HostScopedCredentials(
|
||||
provider="http",
|
||||
host="*.github.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("token ghp_wildcard123"),
|
||||
},
|
||||
title="GitHub Wildcard Credentials",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def non_matching_credentials(self):
|
||||
"""Create credentials that don't match test URLs."""
|
||||
return HostScopedCredentials(
|
||||
provider="http",
|
||||
host="different.api.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer non-matching-token"),
|
||||
},
|
||||
title="Non-matching Credentials",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_http_block_with_exact_host_match(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
exact_match_credentials,
|
||||
mock_response,
|
||||
):
|
||||
"""Test HTTP block with exact host matching credentials."""
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
# Prepare input data
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url="https://api.example.com/data",
|
||||
method=HttpMethod.GET,
|
||||
headers={"User-Agent": "test-agent"},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": exact_match_credentials.id,
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": exact_match_credentials.title,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with credentials provided by execution manager
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=exact_match_credentials,
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify request headers include both credential and user headers
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
expected_headers = {
|
||||
"Authorization": "Bearer exact-match-token",
|
||||
"X-API-Key": "api-key-123",
|
||||
"User-Agent": "test-agent",
|
||||
}
|
||||
assert call_args.kwargs["headers"] == expected_headers
|
||||
|
||||
# Verify response handling
|
||||
assert len(result) == 1
|
||||
assert result[0][0] == "response"
|
||||
assert result[0][1] == {"success": True, "data": "test"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_http_block_with_wildcard_host_match(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
wildcard_credentials,
|
||||
mock_response,
|
||||
):
|
||||
"""Test HTTP block with wildcard host pattern matching."""
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
# Test with subdomain that should match *.github.com
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url="https://api.github.com/user",
|
||||
method=HttpMethod.GET,
|
||||
headers={},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": wildcard_credentials.id,
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": wildcard_credentials.title,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with wildcard credentials
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=wildcard_credentials,
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify wildcard matching works
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
expected_headers = {"Authorization": "token ghp_wildcard123"}
|
||||
assert call_args.kwargs["headers"] == expected_headers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_http_block_with_non_matching_credentials(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
non_matching_credentials,
|
||||
mock_response,
|
||||
):
|
||||
"""Test HTTP block when credentials don't match the target URL."""
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
# Test with URL that doesn't match the credentials
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url="https://api.example.com/data",
|
||||
method=HttpMethod.GET,
|
||||
headers={"User-Agent": "test-agent"},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": non_matching_credentials.id,
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": non_matching_credentials.title,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with non-matching credentials
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=non_matching_credentials,
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify only user headers are included (no credential headers)
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
expected_headers = {"User-Agent": "test-agent"}
|
||||
assert call_args.kwargs["headers"] == expected_headers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_user_headers_override_credential_headers(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
exact_match_credentials,
|
||||
mock_response,
|
||||
):
|
||||
"""Test that user-provided headers take precedence over credential headers."""
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
# Test with user header that conflicts with credential header
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url="https://api.example.com/data",
|
||||
method=HttpMethod.POST,
|
||||
headers={
|
||||
"Authorization": "Bearer user-override-token", # Should override
|
||||
"Content-Type": "application/json", # Additional user header
|
||||
},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": exact_match_credentials.id,
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": exact_match_credentials.title,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with conflicting headers
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=exact_match_credentials,
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify user headers take precedence
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
expected_headers = {
|
||||
"X-API-Key": "api-key-123", # From credentials
|
||||
"Authorization": "Bearer user-override-token", # User override
|
||||
"Content-Type": "application/json", # User header
|
||||
}
|
||||
assert call_args.kwargs["headers"] == expected_headers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_auto_discovered_credentials_flow(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
mock_response,
|
||||
):
|
||||
"""Test the auto-discovery flow where execution manager provides matching credentials."""
|
||||
# Create auto-discovered credentials
|
||||
auto_discovered_creds = HostScopedCredentials(
|
||||
provider="http",
|
||||
host="*.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer auto-discovered-token"),
|
||||
},
|
||||
title="Auto-discovered Credentials",
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
# Test with empty credentials field (triggers auto-discovery)
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url="https://api.example.com/data",
|
||||
method=HttpMethod.GET,
|
||||
headers={},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": "", # Empty ID triggers auto-discovery in execution manager
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": "",
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with auto-discovered credentials provided by execution manager
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=auto_discovered_creds, # Execution manager found these
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify auto-discovered credentials were applied
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
expected_headers = {"Authorization": "Bearer auto-discovered-token"}
|
||||
assert call_args.kwargs["headers"] == expected_headers
|
||||
|
||||
# Verify response handling
|
||||
assert len(result) == 1
|
||||
assert result[0][0] == "response"
|
||||
assert result[0][1] == {"success": True, "data": "test"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_multiple_header_credentials(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
mock_response,
|
||||
):
|
||||
"""Test credentials with multiple headers are all applied."""
|
||||
# Create credentials with multiple headers
|
||||
multi_header_creds = HostScopedCredentials(
|
||||
provider="http",
|
||||
host="api.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer multi-token"),
|
||||
"X-API-Key": SecretStr("api-key-456"),
|
||||
"X-Client-ID": SecretStr("client-789"),
|
||||
"X-Custom-Header": SecretStr("custom-value"),
|
||||
},
|
||||
title="Multi-Header Credentials",
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
# Test with credentials containing multiple headers
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url="https://api.example.com/data",
|
||||
method=HttpMethod.GET,
|
||||
headers={"User-Agent": "test-agent"},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": multi_header_creds.id,
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": multi_header_creds.title,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with multi-header credentials
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=multi_header_creds,
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify all headers are included
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
expected_headers = {
|
||||
"Authorization": "Bearer multi-token",
|
||||
"X-API-Key": "api-key-456",
|
||||
"X-Client-ID": "client-789",
|
||||
"X-Custom-Header": "custom-value",
|
||||
"User-Agent": "test-agent",
|
||||
}
|
||||
assert call_args.kwargs["headers"] == expected_headers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_credentials_with_complex_url_patterns(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
mock_response,
|
||||
):
|
||||
"""Test credentials matching various URL patterns."""
|
||||
# Test cases for different URL patterns
|
||||
test_cases = [
|
||||
{
|
||||
"host_pattern": "api.example.com",
|
||||
"test_url": "https://api.example.com/v1/users",
|
||||
"should_match": True,
|
||||
},
|
||||
{
|
||||
"host_pattern": "*.example.com",
|
||||
"test_url": "https://api.example.com/v1/users",
|
||||
"should_match": True,
|
||||
},
|
||||
{
|
||||
"host_pattern": "*.example.com",
|
||||
"test_url": "https://subdomain.example.com/data",
|
||||
"should_match": True,
|
||||
},
|
||||
{
|
||||
"host_pattern": "api.example.com",
|
||||
"test_url": "https://api.different.com/data",
|
||||
"should_match": False,
|
||||
},
|
||||
]
|
||||
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
for case in test_cases:
|
||||
# Reset mock for each test case
|
||||
mock_requests.reset_mock()
|
||||
|
||||
# Create credentials for this test case
|
||||
test_creds = HostScopedCredentials(
|
||||
provider="http",
|
||||
host=case["host_pattern"],
|
||||
headers={
|
||||
"Authorization": SecretStr(f"Bearer {case['host_pattern']}-token"),
|
||||
},
|
||||
title=f"Credentials for {case['host_pattern']}",
|
||||
)
|
||||
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url=case["test_url"],
|
||||
method=HttpMethod.GET,
|
||||
headers={"User-Agent": "test-agent"},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": test_creds.id,
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": test_creds.title,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with test credentials
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=test_creds,
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify headers based on whether pattern should match
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
headers = call_args.kwargs["headers"]
|
||||
|
||||
if case["should_match"]:
|
||||
# Should include both user and credential headers
|
||||
expected_auth = f"Bearer {case['host_pattern']}-token"
|
||||
assert headers["Authorization"] == expected_auth
|
||||
assert headers["User-Agent"] == "test-agent"
|
||||
else:
|
||||
# Should only include user headers
|
||||
assert "Authorization" not in headers
|
||||
assert headers["User-Agent"] == "test-agent"
|
||||
@@ -25,12 +25,7 @@ async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Grap
|
||||
async def create_credentials(s: SpinTestServer, u: User):
|
||||
provider = ProviderName.OPENAI
|
||||
credentials = llm.TEST_CREDENTIALS
|
||||
try:
|
||||
await s.agent_server.test_create_credentials(u.id, provider, credentials)
|
||||
except Exception:
|
||||
# ValueErrors is raised trying to recreate the same credentials
|
||||
# so hidding the error
|
||||
pass
|
||||
return await s.agent_server.test_create_credentials(u.id, provider, credentials)
|
||||
|
||||
|
||||
async def execute_graph(
|
||||
@@ -60,19 +55,18 @@ async def execute_graph(
|
||||
return graph_exec_id
|
||||
|
||||
|
||||
@pytest.mark.skip()
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
|
||||
test_user = await create_test_user()
|
||||
test_tool_graph = await create_graph(server, create_test_graph(), test_user)
|
||||
await create_credentials(server, test_user)
|
||||
creds = await create_credentials(server, test_user)
|
||||
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=SmartDecisionMakerBlock().id,
|
||||
input_default={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": llm.TEST_CREDENTIALS_INPUT,
|
||||
"credentials": creds,
|
||||
},
|
||||
),
|
||||
graph.Node(
|
||||
@@ -110,80 +104,18 @@ async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
|
||||
@pytest.mark.skip()
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_graph_validation_with_tool_nodes_raises_error(server: SpinTestServer):
|
||||
|
||||
test_user = await create_test_user()
|
||||
test_tool_graph = await create_graph(server, create_test_graph(), test_user)
|
||||
await create_credentials(server, test_user)
|
||||
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=SmartDecisionMakerBlock().id,
|
||||
input_default={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": llm.TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
),
|
||||
graph.Node(
|
||||
block_id=AgentExecutorBlock().id,
|
||||
input_default={
|
||||
"graph_id": test_tool_graph.id,
|
||||
"graph_version": test_tool_graph.version,
|
||||
"input_schema": test_tool_graph.input_schema,
|
||||
"output_schema": test_tool_graph.output_schema,
|
||||
},
|
||||
),
|
||||
graph.Node(
|
||||
block_id=StoreValueBlock().id,
|
||||
),
|
||||
]
|
||||
|
||||
links = [
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[1].id,
|
||||
source_name="tools_^_sample_tool_input_1",
|
||||
sink_name="input_1",
|
||||
),
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[1].id,
|
||||
source_name="tools_^_sample_tool_input_2",
|
||||
sink_name="input_2",
|
||||
),
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[2].id,
|
||||
source_name="tools_^_store_value_input",
|
||||
sink_name="input",
|
||||
),
|
||||
]
|
||||
|
||||
test_graph = graph.Graph(
|
||||
name="TestGraph",
|
||||
description="Test graph",
|
||||
nodes=nodes,
|
||||
links=links,
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
|
||||
@pytest.mark.skip()
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
test_user = await create_test_user()
|
||||
test_tool_graph = await create_graph(server, create_test_graph(), test_user)
|
||||
await create_credentials(server, test_user)
|
||||
creds = await create_credentials(server, test_user)
|
||||
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=SmartDecisionMakerBlock().id,
|
||||
input_default={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": llm.TEST_CREDENTIALS_INPUT,
|
||||
"credentials": creds,
|
||||
},
|
||||
),
|
||||
graph.Node(
|
||||
@@ -15,7 +15,7 @@ from backend.blocks.zerobounce._auth import (
|
||||
ZeroBounceCredentialsInput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class Response(BaseModel):
|
||||
@@ -90,7 +90,7 @@ class ValidateEmailsBlock(Block):
|
||||
description="IP address to validate",
|
||||
default="",
|
||||
)
|
||||
credentials: ZeroBounceCredentialsInput = SchemaField(
|
||||
credentials: ZeroBounceCredentialsInput = CredentialsField(
|
||||
description="ZeroBounce credentials",
|
||||
)
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ from dotenv import load_dotenv
|
||||
|
||||
from backend.util.logging import configure_logging
|
||||
|
||||
os.environ["ENABLE_AUTH"] = "false"
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# NOTE: You can run tests like with the --log-cli-level=INFO to see the logs
|
||||
@@ -78,6 +78,7 @@ class BlockCategory(Enum):
|
||||
PRODUCTIVITY = "Block that helps with productivity"
|
||||
ISSUE_TRACKING = "Block that helps with issue tracking"
|
||||
MULTIMEDIA = "Block that interacts with multimedia content"
|
||||
MARKETING = "Block that helps with marketing"
|
||||
|
||||
def dict(self) -> dict[str, str]:
|
||||
return {"category": self.name, "description": self.value}
|
||||
|
||||
@@ -4,6 +4,7 @@ from backend.blocks.ai_music_generator import AIMusicGeneratorBlock
|
||||
from backend.blocks.ai_shortform_video_block import AIShortformVideoCreatorBlock
|
||||
from backend.blocks.apollo.organization import SearchOrganizationsBlock
|
||||
from backend.blocks.apollo.people import SearchPeopleBlock
|
||||
from backend.blocks.apollo.person import GetPersonDetailBlock
|
||||
from backend.blocks.flux_kontext import AIImageEditorBlock, FluxKontextModelName
|
||||
from backend.blocks.ideogram import IdeogramModelBlock
|
||||
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
|
||||
@@ -362,7 +363,31 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
||||
],
|
||||
SearchPeopleBlock: [
|
||||
BlockCost(
|
||||
cost_amount=2,
|
||||
cost_amount=10,
|
||||
cost_filter={
|
||||
"enrich_info": False,
|
||||
"credentials": {
|
||||
"id": apollo_credentials.id,
|
||||
"provider": apollo_credentials.provider,
|
||||
"type": apollo_credentials.type,
|
||||
},
|
||||
},
|
||||
),
|
||||
BlockCost(
|
||||
cost_amount=20,
|
||||
cost_filter={
|
||||
"enrich_info": True,
|
||||
"credentials": {
|
||||
"id": apollo_credentials.id,
|
||||
"provider": apollo_credentials.provider,
|
||||
"type": apollo_credentials.type,
|
||||
},
|
||||
},
|
||||
),
|
||||
],
|
||||
GetPersonDetailBlock: [
|
||||
BlockCost(
|
||||
cost_amount=1,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": apollo_credentials.id,
|
||||
|
||||
@@ -7,7 +7,7 @@ from pydantic import BaseModel
|
||||
from redis.asyncio.client import PubSub as AsyncPubSub
|
||||
from redis.client import PubSub
|
||||
|
||||
from backend.data import redis
|
||||
from backend.data import redis_client as redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -48,6 +48,7 @@ from .block import (
|
||||
get_webhook_block_ids,
|
||||
)
|
||||
from .db import BaseDbModel
|
||||
from .event_bus import AsyncRedisEventBus, RedisEventBus
|
||||
from .includes import (
|
||||
EXECUTION_RESULT_INCLUDE,
|
||||
EXECUTION_RESULT_ORDER,
|
||||
@@ -55,7 +56,6 @@ from .includes import (
|
||||
graph_execution_include,
|
||||
)
|
||||
from .model import GraphExecutionStats, NodeExecutionStats
|
||||
from .queue import AsyncRedisEventBus, RedisEventBus
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from backend.data.model import (
|
||||
CredentialsMetaInput,
|
||||
is_credentials_field_name,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import type as type_utils
|
||||
|
||||
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
|
||||
@@ -243,6 +244,8 @@ class Graph(BaseGraph):
|
||||
for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
|
||||
if field.provider != other_field.provider:
|
||||
continue
|
||||
if ProviderName.HTTP in field.provider:
|
||||
continue
|
||||
|
||||
# If this happens, that means a block implementation probably needs
|
||||
# to be updated.
|
||||
@@ -264,6 +267,7 @@ class Graph(BaseGraph):
|
||||
required_scopes=set(field_info.required_scopes or []),
|
||||
discriminator=field_info.discriminator,
|
||||
discriminator_mapping=field_info.discriminator_mapping,
|
||||
discriminator_values=field_info.discriminator_values,
|
||||
),
|
||||
)
|
||||
for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
|
||||
@@ -282,37 +286,40 @@ class Graph(BaseGraph):
|
||||
Returns:
|
||||
dict[aggregated_field_key, tuple(
|
||||
CredentialsFieldInfo: A spec for one aggregated credentials field
|
||||
(now includes discriminator_values from matching nodes)
|
||||
set[(node_id, field_name)]: Node credentials fields that are
|
||||
compatible with this aggregated field spec
|
||||
)]
|
||||
"""
|
||||
return {
|
||||
"_".join(sorted(agg_field_info.provider))
|
||||
+ "_"
|
||||
+ "_".join(sorted(agg_field_info.supported_types))
|
||||
+ "_credentials": (agg_field_info, node_fields)
|
||||
for agg_field_info, node_fields in CredentialsFieldInfo.combine(
|
||||
*(
|
||||
(
|
||||
# Apply discrimination before aggregating credentials inputs
|
||||
(
|
||||
field_info.discriminate(
|
||||
node.input_default[field_info.discriminator]
|
||||
)
|
||||
if (
|
||||
field_info.discriminator
|
||||
and node.input_default.get(field_info.discriminator)
|
||||
)
|
||||
else field_info
|
||||
),
|
||||
(node.id, field_name),
|
||||
# First collect all credential field data with input defaults
|
||||
node_credential_data = []
|
||||
|
||||
for graph in [self] + self.sub_graphs:
|
||||
for node in graph.nodes:
|
||||
for (
|
||||
field_name,
|
||||
field_info,
|
||||
) in node.block.input_schema.get_credentials_fields_info().items():
|
||||
|
||||
discriminator = field_info.discriminator
|
||||
if not discriminator:
|
||||
node_credential_data.append((field_info, (node.id, field_name)))
|
||||
continue
|
||||
|
||||
discriminator_value = node.input_default.get(discriminator)
|
||||
if discriminator_value is None:
|
||||
node_credential_data.append((field_info, (node.id, field_name)))
|
||||
continue
|
||||
|
||||
discriminated_info = field_info.discriminate(discriminator_value)
|
||||
discriminated_info.discriminator_values.add(discriminator_value)
|
||||
|
||||
node_credential_data.append(
|
||||
(discriminated_info, (node.id, field_name))
|
||||
)
|
||||
for graph in [self] + self.sub_graphs
|
||||
for node in graph.nodes
|
||||
for field_name, field_info in node.block.input_schema.get_credentials_fields_info().items()
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
# Combine credential field info (this will merge discriminator_values automatically)
|
||||
return CredentialsFieldInfo.combine(*node_credential_data)
|
||||
|
||||
|
||||
class GraphModel(Graph):
|
||||
@@ -391,7 +398,9 @@ class GraphModel(Graph):
|
||||
continue
|
||||
node.input_default["user_id"] = user_id
|
||||
node.input_default.setdefault("inputs", {})
|
||||
if (graph_id := node.input_default.get("graph_id")) in graph_id_map:
|
||||
if (
|
||||
graph_id := node.input_default.get("graph_id")
|
||||
) and graph_id in graph_id_map:
|
||||
node.input_default["graph_id"] = graph_id_map[graph_id]
|
||||
|
||||
def validate_graph(
|
||||
|
||||
@@ -11,8 +11,8 @@ from prisma.types import (
|
||||
)
|
||||
from pydantic import Field, computed_field
|
||||
|
||||
from backend.data.event_bus import AsyncRedisEventBus
|
||||
from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
|
||||
from backend.data.queue import AsyncRedisEventBus
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks.utils import webhook_ingress_url
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
|
||||
@@ -14,11 +14,12 @@ from typing import (
|
||||
Generic,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
cast,
|
||||
get_args,
|
||||
)
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
from prisma.enums import CreditTransactionType
|
||||
@@ -240,13 +241,65 @@ class UserPasswordCredentials(_BaseCredentials):
|
||||
return f"Basic {base64.b64encode(f'{self.username.get_secret_value()}:{self.password.get_secret_value()}'.encode()).decode()}"
|
||||
|
||||
|
||||
class HostScopedCredentials(_BaseCredentials):
|
||||
type: Literal["host_scoped"] = "host_scoped"
|
||||
host: str = Field(description="The host/URI pattern to match against request URLs")
|
||||
headers: dict[str, SecretStr] = Field(
|
||||
description="Key-value header map to add to matching requests",
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
def _extract_headers(self, headers: dict[str, SecretStr]) -> dict[str, str]:
|
||||
"""Helper to extract secret values from headers."""
|
||||
return {key: value.get_secret_value() for key, value in headers.items()}
|
||||
|
||||
@field_serializer("headers")
|
||||
def serialize_headers(self, headers: dict[str, SecretStr]) -> dict[str, str]:
|
||||
"""Serialize headers by extracting secret values."""
|
||||
return self._extract_headers(headers)
|
||||
|
||||
def get_headers_dict(self) -> dict[str, str]:
|
||||
"""Get headers with secret values extracted."""
|
||||
return self._extract_headers(self.headers)
|
||||
|
||||
def auth_header(self) -> str:
|
||||
"""Get authorization header for backward compatibility."""
|
||||
auth_headers = self.get_headers_dict()
|
||||
if "Authorization" in auth_headers:
|
||||
return auth_headers["Authorization"]
|
||||
return ""
|
||||
|
||||
def matches_url(self, url: str) -> bool:
|
||||
"""Check if this credential should be applied to the given URL."""
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
# Extract hostname without port
|
||||
request_host = parsed_url.hostname
|
||||
if not request_host:
|
||||
return False
|
||||
|
||||
# Simple host matching - exact match or wildcard subdomain match
|
||||
if self.host == request_host:
|
||||
return True
|
||||
|
||||
# Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
|
||||
if self.host.startswith("*."):
|
||||
domain = self.host[2:] # Remove "*."
|
||||
return request_host.endswith(f".{domain}") or request_host == domain
|
||||
|
||||
return False
|
||||
|
||||
|
||||
Credentials = Annotated[
|
||||
OAuth2Credentials | APIKeyCredentials | UserPasswordCredentials,
|
||||
OAuth2Credentials
|
||||
| APIKeyCredentials
|
||||
| UserPasswordCredentials
|
||||
| HostScopedCredentials,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
CredentialsType = Literal["api_key", "oauth2", "user_password"]
|
||||
CredentialsType = Literal["api_key", "oauth2", "user_password", "host_scoped"]
|
||||
|
||||
|
||||
class OAuthState(BaseModel):
|
||||
@@ -320,15 +373,29 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _add_json_schema_extra(schema, cls: CredentialsMetaInput):
|
||||
schema["credentials_provider"] = cls.allowed_providers()
|
||||
schema["credentials_types"] = cls.allowed_cred_types()
|
||||
def _add_json_schema_extra(schema: dict, model_class: type):
|
||||
# Use model_class for allowed_providers/cred_types
|
||||
if hasattr(model_class, "allowed_providers") and hasattr(
|
||||
model_class, "allowed_cred_types"
|
||||
):
|
||||
schema["credentials_provider"] = model_class.allowed_providers()
|
||||
schema["credentials_types"] = model_class.allowed_cred_types()
|
||||
# Do not return anything, just mutate schema in place
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra=_add_json_schema_extra, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
def _extract_host_from_url(url: str) -> str:
|
||||
"""Extract host from URL for grouping host-scoped credentials."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
return parsed.hostname or url
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
# TODO: move discrimination mechanism out of CredentialsField (frontend + backend)
|
||||
provider: frozenset[CP] = Field(..., alias="credentials_provider")
|
||||
@@ -336,11 +403,12 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
required_scopes: Optional[frozenset[str]] = Field(None, alias="credentials_scopes")
|
||||
discriminator: Optional[str] = None
|
||||
discriminator_mapping: Optional[dict[str, CP]] = None
|
||||
discriminator_values: set[Any] = Field(default_factory=set)
|
||||
|
||||
@classmethod
|
||||
def combine(
|
||||
cls, *fields: tuple[CredentialsFieldInfo[CP, CT], T]
|
||||
) -> Sequence[tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
|
||||
) -> dict[str, tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
|
||||
"""
|
||||
Combines multiple CredentialsFieldInfo objects into as few as possible.
|
||||
|
||||
@@ -358,22 +426,36 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
the set of keys of the respective original items that were grouped together.
|
||||
"""
|
||||
if not fields:
|
||||
return []
|
||||
return {}
|
||||
|
||||
# Group fields by their provider and supported_types
|
||||
# For HTTP host-scoped credentials, also group by host
|
||||
grouped_fields: defaultdict[
|
||||
tuple[frozenset[CP], frozenset[CT]],
|
||||
list[tuple[T, CredentialsFieldInfo[CP, CT]]],
|
||||
] = defaultdict(list)
|
||||
|
||||
for field, key in fields:
|
||||
group_key = (frozenset(field.provider), frozenset(field.supported_types))
|
||||
if field.provider == frozenset([ProviderName.HTTP]):
|
||||
# HTTP host-scoped credentials can have different hosts that reqires different credential sets.
|
||||
# Group by host extracted from the URL
|
||||
providers = frozenset(
|
||||
[cast(CP, "http")]
|
||||
+ [
|
||||
cast(CP, _extract_host_from_url(str(value)))
|
||||
for value in field.discriminator_values
|
||||
]
|
||||
)
|
||||
else:
|
||||
providers = frozenset(field.provider)
|
||||
|
||||
group_key = (providers, frozenset(field.supported_types))
|
||||
grouped_fields[group_key].append((key, field))
|
||||
|
||||
# Combine fields within each group
|
||||
result: list[tuple[CredentialsFieldInfo[CP, CT], set[T]]] = []
|
||||
result: dict[str, tuple[CredentialsFieldInfo[CP, CT], set[T]]] = {}
|
||||
|
||||
for group in grouped_fields.values():
|
||||
for key, group in grouped_fields.items():
|
||||
# Start with the first field in the group
|
||||
_, combined = group[0]
|
||||
|
||||
@@ -386,18 +468,32 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
if field.required_scopes:
|
||||
all_scopes.update(field.required_scopes)
|
||||
|
||||
# Create a new combined field
|
||||
result.append(
|
||||
(
|
||||
CredentialsFieldInfo[CP, CT](
|
||||
credentials_provider=combined.provider,
|
||||
credentials_types=combined.supported_types,
|
||||
credentials_scopes=frozenset(all_scopes) or None,
|
||||
discriminator=combined.discriminator,
|
||||
discriminator_mapping=combined.discriminator_mapping,
|
||||
),
|
||||
combined_keys,
|
||||
)
|
||||
# Combine discriminator_values from all fields in the group (removing duplicates)
|
||||
all_discriminator_values = []
|
||||
for _, field in group:
|
||||
for value in field.discriminator_values:
|
||||
if value not in all_discriminator_values:
|
||||
all_discriminator_values.append(value)
|
||||
|
||||
# Generate the key for the combined result
|
||||
providers_key, supported_types_key = key
|
||||
group_key = (
|
||||
"-".join(sorted(providers_key))
|
||||
+ "_"
|
||||
+ "-".join(sorted(supported_types_key))
|
||||
+ "_credentials"
|
||||
)
|
||||
|
||||
result[group_key] = (
|
||||
CredentialsFieldInfo[CP, CT](
|
||||
credentials_provider=combined.provider,
|
||||
credentials_types=combined.supported_types,
|
||||
credentials_scopes=frozenset(all_scopes) or None,
|
||||
discriminator=combined.discriminator,
|
||||
discriminator_mapping=combined.discriminator_mapping,
|
||||
discriminator_values=set(all_discriminator_values),
|
||||
),
|
||||
combined_keys,
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -406,11 +502,15 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
if not (self.discriminator and self.discriminator_mapping):
|
||||
return self
|
||||
|
||||
discriminator_value = self.discriminator_mapping[discriminator_value]
|
||||
return CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([discriminator_value]),
|
||||
credentials_provider=frozenset(
|
||||
[self.discriminator_mapping[discriminator_value]]
|
||||
),
|
||||
credentials_types=self.supported_types,
|
||||
credentials_scopes=self.required_scopes,
|
||||
discriminator=self.discriminator,
|
||||
discriminator_mapping=self.discriminator_mapping,
|
||||
discriminator_values=self.discriminator_values,
|
||||
)
|
||||
|
||||
|
||||
@@ -419,6 +519,7 @@ def CredentialsField(
|
||||
*,
|
||||
discriminator: Optional[str] = None,
|
||||
discriminator_mapping: Optional[dict[str, Any]] = None,
|
||||
discriminator_values: Optional[set[Any]] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
**kwargs,
|
||||
@@ -434,6 +535,7 @@ def CredentialsField(
|
||||
"credentials_scopes": list(required_scopes) or None,
|
||||
"discriminator": discriminator,
|
||||
"discriminator_mapping": discriminator_mapping,
|
||||
"discriminator_values": discriminator_values,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
143
autogpt_platform/backend/backend/data/model_test.py
Normal file
143
autogpt_platform/backend/backend/data/model_test.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import HostScopedCredentials
|
||||
|
||||
|
||||
class TestHostScopedCredentials:
|
||||
def test_host_scoped_credentials_creation(self):
|
||||
"""Test creating HostScopedCredentials with required fields."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="api.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer secret-token"),
|
||||
"X-API-Key": SecretStr("api-key-123"),
|
||||
},
|
||||
title="Example API Credentials",
|
||||
)
|
||||
|
||||
assert creds.type == "host_scoped"
|
||||
assert creds.provider == "custom"
|
||||
assert creds.host == "api.example.com"
|
||||
assert creds.title == "Example API Credentials"
|
||||
assert len(creds.headers) == 2
|
||||
assert "Authorization" in creds.headers
|
||||
assert "X-API-Key" in creds.headers
|
||||
|
||||
def test_get_headers_dict(self):
|
||||
"""Test getting headers with secret values extracted."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="api.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer secret-token"),
|
||||
"X-Custom-Header": SecretStr("custom-value"),
|
||||
},
|
||||
)
|
||||
|
||||
headers_dict = creds.get_headers_dict()
|
||||
|
||||
assert headers_dict == {
|
||||
"Authorization": "Bearer secret-token",
|
||||
"X-Custom-Header": "custom-value",
|
||||
}
|
||||
|
||||
def test_matches_url_exact_host(self):
|
||||
"""Test URL matching with exact host match."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="api.example.com",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
)
|
||||
|
||||
assert creds.matches_url("https://api.example.com/v1/data")
|
||||
assert creds.matches_url("http://api.example.com/endpoint")
|
||||
assert not creds.matches_url("https://other.example.com/v1/data")
|
||||
assert not creds.matches_url("https://subdomain.api.example.com/v1/data")
|
||||
|
||||
def test_matches_url_wildcard_subdomain(self):
|
||||
"""Test URL matching with wildcard subdomain pattern."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="*.example.com",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
)
|
||||
|
||||
assert creds.matches_url("https://api.example.com/v1/data")
|
||||
assert creds.matches_url("https://subdomain.example.com/endpoint")
|
||||
assert creds.matches_url("https://deep.nested.example.com/path")
|
||||
assert creds.matches_url("https://example.com/path") # Base domain should match
|
||||
assert not creds.matches_url("https://example.org/v1/data")
|
||||
assert not creds.matches_url("https://notexample.com/v1/data")
|
||||
|
||||
def test_matches_url_with_port_and_path(self):
|
||||
"""Test URL matching with ports and paths."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="localhost",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
)
|
||||
|
||||
assert creds.matches_url("http://localhost:8080/api/v1")
|
||||
assert creds.matches_url("https://localhost:443/secure/endpoint")
|
||||
assert creds.matches_url("http://localhost/simple")
|
||||
|
||||
def test_empty_headers_dict(self):
|
||||
"""Test HostScopedCredentials with empty headers."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom", host="api.example.com", headers={}
|
||||
)
|
||||
|
||||
assert creds.get_headers_dict() == {}
|
||||
assert creds.matches_url("https://api.example.com/test")
|
||||
|
||||
def test_credential_serialization(self):
|
||||
"""Test that credentials can be serialized/deserialized properly."""
|
||||
original_creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="api.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer secret-token"),
|
||||
"X-API-Key": SecretStr("api-key-123"),
|
||||
},
|
||||
title="Test Credentials",
|
||||
)
|
||||
|
||||
# Serialize to dict (simulating storage)
|
||||
serialized = original_creds.model_dump()
|
||||
|
||||
# Deserialize back
|
||||
restored_creds = HostScopedCredentials.model_validate(serialized)
|
||||
|
||||
assert restored_creds.id == original_creds.id
|
||||
assert restored_creds.provider == original_creds.provider
|
||||
assert restored_creds.host == original_creds.host
|
||||
assert restored_creds.title == original_creds.title
|
||||
assert restored_creds.type == "host_scoped"
|
||||
|
||||
# Check that headers are properly restored
|
||||
assert restored_creds.get_headers_dict() == original_creds.get_headers_dict()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"host,test_url,expected",
|
||||
[
|
||||
("api.example.com", "https://api.example.com/test", True),
|
||||
("api.example.com", "https://different.example.com/test", False),
|
||||
("*.example.com", "https://api.example.com/test", True),
|
||||
("*.example.com", "https://sub.api.example.com/test", True),
|
||||
("*.example.com", "https://example.com/test", True),
|
||||
("*.example.com", "https://example.org/test", False),
|
||||
("localhost", "http://localhost:3000/test", True),
|
||||
("localhost", "http://127.0.0.1:3000/test", False),
|
||||
],
|
||||
)
|
||||
def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool):
|
||||
"""Parametrized test for various URL matching scenarios."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="test",
|
||||
host=host,
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
)
|
||||
|
||||
assert creds.matches_url(test_url) == expected
|
||||
@@ -35,7 +35,7 @@ from autogpt_libs.utils.cache import thread_cached
|
||||
from prometheus_client import Gauge, start_http_server
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.data import redis
|
||||
from backend.data import redis_client as redis
|
||||
from backend.data.block import (
|
||||
BlockData,
|
||||
BlockInput,
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.redis import get_redis_async
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor.database import DatabaseManagerAsyncClient
|
||||
|
||||
@@ -7,7 +7,7 @@ from autogpt_libs.utils.synchronize import AsyncRedisKeyedMutex
|
||||
from redis.asyncio.lock import Lock as AsyncRedisLock
|
||||
|
||||
from backend.data.model import Credentials, OAuth2Credentials
|
||||
from backend.data.redis import get_redis_async
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
from backend.integrations.oauth import HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
@@ -17,6 +17,7 @@ class ProviderName(str, Enum):
|
||||
GOOGLE = "google"
|
||||
GOOGLE_MAPS = "google_maps"
|
||||
GROQ = "groq"
|
||||
HTTP = "http"
|
||||
HUBSPOT = "hubspot"
|
||||
IDEOGRAM = "ideogram"
|
||||
JINA = "jina"
|
||||
|
||||
@@ -22,7 +22,12 @@ from backend.data.integrations import (
|
||||
publish_webhook_event,
|
||||
wait_for_webhook_event,
|
||||
)
|
||||
from backend.data.model import Credentials, CredentialsType, OAuth2Credentials
|
||||
from backend.data.model import (
|
||||
Credentials,
|
||||
CredentialsType,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
)
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import HANDLERS_BY_NAME
|
||||
@@ -82,6 +87,9 @@ class CredentialsMetaResponse(BaseModel):
|
||||
title: str | None
|
||||
scopes: list[str] | None
|
||||
username: str | None
|
||||
host: str | None = Field(
|
||||
default=None, description="Host pattern for host-scoped credentials"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{provider}/callback")
|
||||
@@ -156,6 +164,9 @@ async def callback(
|
||||
title=credentials.title,
|
||||
scopes=credentials.scopes,
|
||||
username=credentials.username,
|
||||
host=(
|
||||
credentials.host if isinstance(credentials, HostScopedCredentials) else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -172,6 +183,7 @@ async def list_credentials(
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
)
|
||||
for cred in credentials
|
||||
]
|
||||
@@ -193,6 +205,7 @@ async def list_credentials_by_provider(
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
)
|
||||
for cred in credentials
|
||||
]
|
||||
|
||||
@@ -357,11 +357,22 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
provider: ProviderName,
|
||||
credentials: Credentials,
|
||||
) -> Credentials:
|
||||
from backend.server.integrations.router import create_credentials
|
||||
|
||||
return await create_credentials(
|
||||
user_id=user_id, provider=provider, credentials=credentials
|
||||
from backend.server.integrations.router import (
|
||||
create_credentials,
|
||||
get_credential,
|
||||
)
|
||||
|
||||
try:
|
||||
return await create_credentials(
|
||||
user_id=user_id, provider=provider, credentials=credentials
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating credentials: {e}")
|
||||
return await get_credential(
|
||||
provider=provider,
|
||||
user_id=user_id,
|
||||
cred_id=credentials.id,
|
||||
)
|
||||
|
||||
def set_test_dependency_overrides(self, overrides: dict):
|
||||
app.dependency_overrides.update(overrides)
|
||||
|
||||
@@ -4,6 +4,7 @@ import logging
|
||||
from typing import Annotated
|
||||
|
||||
import fastapi
|
||||
import pydantic
|
||||
|
||||
import backend.data.analytics
|
||||
from backend.server.utils import get_user_id
|
||||
@@ -12,24 +13,28 @@ router = fastapi.APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LogRawMetricRequest(pydantic.BaseModel):
|
||||
metric_name: str = pydantic.Field(..., min_length=1)
|
||||
metric_value: float = pydantic.Field(..., allow_inf_nan=False)
|
||||
data_string: str = pydantic.Field(..., min_length=1)
|
||||
|
||||
|
||||
@router.post(path="/log_raw_metric")
|
||||
async def log_raw_metric(
|
||||
user_id: Annotated[str, fastapi.Depends(get_user_id)],
|
||||
metric_name: Annotated[str, fastapi.Body(..., embed=True)],
|
||||
metric_value: Annotated[float, fastapi.Body(..., embed=True)],
|
||||
data_string: Annotated[str, fastapi.Body(..., embed=True)],
|
||||
request: LogRawMetricRequest,
|
||||
):
|
||||
try:
|
||||
result = await backend.data.analytics.log_raw_metric(
|
||||
user_id=user_id,
|
||||
metric_name=metric_name,
|
||||
metric_value=metric_value,
|
||||
data_string=data_string,
|
||||
metric_name=request.metric_name,
|
||||
metric_value=request.metric_value,
|
||||
data_string=request.data_string,
|
||||
)
|
||||
return result.id
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to log metric %s for user %s: %s", metric_name, user_id, e
|
||||
"Failed to log metric %s for user %s: %s", request.metric_name, user_id, e
|
||||
)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
|
||||
@@ -97,8 +97,17 @@ def test_log_raw_metric_invalid_request_improved() -> None:
|
||||
assert "data_string" in error_fields, "Should report missing data_string"
|
||||
|
||||
|
||||
def test_log_raw_metric_type_validation_improved() -> None:
|
||||
def test_log_raw_metric_type_validation_improved(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Test metric type validation with improved assertions."""
|
||||
# Mock the analytics function to avoid event loop issues
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=Mock(id="test-id"),
|
||||
)
|
||||
|
||||
invalid_requests = [
|
||||
{
|
||||
"data": {
|
||||
@@ -119,10 +128,10 @@ def test_log_raw_metric_type_validation_improved() -> None:
|
||||
{
|
||||
"data": {
|
||||
"metric_name": "test",
|
||||
"metric_value": float("inf"), # Infinity
|
||||
"data_string": "test",
|
||||
"metric_value": 123, # Valid number
|
||||
"data_string": "", # Empty data_string
|
||||
},
|
||||
"expected_error": "ensure this value is finite",
|
||||
"expected_error": "String should have at least 1 character",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@@ -93,10 +93,18 @@ def test_log_raw_metric_values_parametrized(
|
||||
],
|
||||
)
|
||||
def test_log_raw_metric_invalid_requests_parametrized(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
invalid_data: dict,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test invalid metric requests with parametrize."""
|
||||
# Mock the analytics function to avoid event loop issues
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=Mock(id="test-id"),
|
||||
)
|
||||
|
||||
response = client.post("/log_raw_metric", json=invalid_data)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
@@ -108,11 +108,16 @@ class TestDatabaseIsolation:
|
||||
where={"email": {"contains": "@test.example"}}
|
||||
)
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def test_create_user(self, test_db_connection):
|
||||
"""Test that demonstrates proper isolation."""
|
||||
# This test has access to a clean database
|
||||
user = await test_db_connection.user.create(
|
||||
data={"email": "test@test.example", "name": "Test User"}
|
||||
data={
|
||||
"id": "test-user-id",
|
||||
"email": "test@test.example",
|
||||
"name": "Test User",
|
||||
}
|
||||
)
|
||||
assert user.email == "test@test.example"
|
||||
# User will be cleaned up automatically
|
||||
|
||||
@@ -143,7 +143,7 @@ async def test_add_agent_to_library(mocker):
|
||||
)
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.create = mocker.AsyncMock(
|
||||
return_value=mock_library_agent_data
|
||||
)
|
||||
@@ -159,21 +159,24 @@ async def test_add_agent_to_library(mocker):
|
||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||
where={"id": "version123"}, include={"AgentGraph": True}
|
||||
)
|
||||
mock_library_agent.return_value.find_first.assert_called_once_with(
|
||||
mock_library_agent.return_value.find_unique.assert_called_once_with(
|
||||
where={
|
||||
"userId": "test-user",
|
||||
"agentGraphId": "agent1",
|
||||
"agentGraphVersion": 1,
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": "test-user",
|
||||
"agentGraphId": "agent1",
|
||||
"agentGraphVersion": 1,
|
||||
}
|
||||
},
|
||||
include=library_agent_include("test-user"),
|
||||
include={"AgentGraph": True},
|
||||
)
|
||||
mock_library_agent.return_value.create.assert_called_once_with(
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
userId="test-user",
|
||||
agentGraphId="agent1",
|
||||
agentGraphVersion=1,
|
||||
isCreatedByUser=False,
|
||||
),
|
||||
data={
|
||||
"User": {"connect": {"id": "test-user"}},
|
||||
"AgentGraph": {
|
||||
"connect": {"graphVersionId": {"id": "agent1", "version": 1}}
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
},
|
||||
include=library_agent_include("test-user"),
|
||||
)
|
||||
|
||||
|
||||
@@ -121,26 +121,57 @@ def test_get_library_agents_error(mocker: pytest_mock.MockFixture):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Mocker Not implemented")
|
||||
def test_add_agent_to_library_success(mocker: pytest_mock.MockFixture):
|
||||
mock_db_call = mocker.patch("backend.server.v2.library.db.add_agent_to_library")
|
||||
mock_db_call.return_value = None
|
||||
mock_library_agent = library_model.LibraryAgent(
|
||||
id="test-library-agent-id",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
creator_name="Test Creator",
|
||||
creator_image_url="",
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
credentials_input_schema={"type": "object", "properties": {}},
|
||||
has_external_trigger=False,
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
new_output=False,
|
||||
can_access_graph=True,
|
||||
is_latest_version=True,
|
||||
updated_at=FIXED_NOW,
|
||||
)
|
||||
|
||||
response = client.post("/agents/test-version-id")
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.add_store_agent_to_library"
|
||||
)
|
||||
mock_db_call.return_value = mock_library_agent
|
||||
|
||||
response = client.post(
|
||||
"/agents", json={"store_listing_version_id": "test-version-id"}
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
# Verify the response contains the library agent data
|
||||
data = library_model.LibraryAgent.model_validate(response.json())
|
||||
assert data.id == "test-library-agent-id"
|
||||
assert data.graph_id == "test-agent-1"
|
||||
|
||||
mock_db_call.assert_called_once_with(
|
||||
store_listing_version_id="test-version-id", user_id="test-user-id"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Mocker Not implemented")
|
||||
def test_add_agent_to_library_error(mocker: pytest_mock.MockFixture):
|
||||
mock_db_call = mocker.patch("backend.server.v2.library.db.add_agent_to_library")
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.add_store_agent_to_library"
|
||||
)
|
||||
mock_db_call.side_effect = Exception("Test error")
|
||||
|
||||
response = client.post("/agents/test-version-id")
|
||||
response = client.post(
|
||||
"/agents", json={"store_listing_version_id": "test-version-id"}
|
||||
)
|
||||
assert response.status_code == 500
|
||||
assert response.json()["detail"] == "Failed to add agent to library"
|
||||
assert "detail" in response.json() # Verify error response structure
|
||||
mock_db_call.assert_called_once_with(
|
||||
store_listing_version_id="test-version-id", user_id="test-user-id"
|
||||
)
|
||||
|
||||
@@ -259,8 +259,8 @@ def test_ask_otto_unauthenticated(mocker: pytest_mock.MockFixture) -> None:
|
||||
}
|
||||
|
||||
response = client.post("/ask", json=request_data)
|
||||
# When auth is disabled and Otto API URL is not configured, we get 503
|
||||
assert response.status_code == 503
|
||||
# When auth is disabled and Otto API URL is not configured, we get 502 (wrapped from 503)
|
||||
assert response.status_code == 502
|
||||
|
||||
# Restore the override
|
||||
app.dependency_overrides[autogpt_libs.auth.middleware.auth_middleware] = (
|
||||
|
||||
@@ -93,6 +93,14 @@ async def test_get_store_agent_details(mocker):
|
||||
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
|
||||
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
|
||||
|
||||
# Mock Profile prisma call
|
||||
mock_profile = mocker.MagicMock()
|
||||
mock_profile.userId = "user-id-123"
|
||||
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
|
||||
mock_profile_db.return_value.find_first = mocker.AsyncMock(
|
||||
return_value=mock_profile
|
||||
)
|
||||
|
||||
# Mock StoreListing prisma call - this is what was missing
|
||||
mock_store_listing_db = mocker.patch("prisma.models.StoreListing.prisma")
|
||||
mock_store_listing_db.return_value.find_first = mocker.AsyncMock(
|
||||
|
||||
@@ -19,7 +19,9 @@ from backend.server.ws_api import (
|
||||
|
||||
@pytest.fixture
|
||||
def mock_websocket() -> AsyncMock:
|
||||
return AsyncMock(spec=WebSocket)
|
||||
mock = AsyncMock(spec=WebSocket)
|
||||
mock.query_params = {} # Add query_params attribute for authentication
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -29,8 +31,13 @@ def mock_manager() -> AsyncMock:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_router_subscribe(
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot, mocker
|
||||
) -> None:
|
||||
# Mock the authenticate_websocket function to ensure it returns a valid user_id
|
||||
mocker.patch(
|
||||
"backend.server.ws_api.authenticate_websocket", return_value=DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
mock_websocket.receive_text.side_effect = [
|
||||
WSMessage(
|
||||
method=WSMethod.SUBSCRIBE_GRAPH_EXEC,
|
||||
@@ -70,8 +77,13 @@ async def test_websocket_router_subscribe(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_router_unsubscribe(
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot, mocker
|
||||
) -> None:
|
||||
# Mock the authenticate_websocket function to ensure it returns a valid user_id
|
||||
mocker.patch(
|
||||
"backend.server.ws_api.authenticate_websocket", return_value=DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
mock_websocket.receive_text.side_effect = [
|
||||
WSMessage(
|
||||
method=WSMethod.UNSUBSCRIBE,
|
||||
@@ -108,8 +120,13 @@ async def test_websocket_router_unsubscribe(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_router_invalid_method(
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock, mocker
|
||||
) -> None:
|
||||
# Mock the authenticate_websocket function to ensure it returns a valid user_id
|
||||
mocker.patch(
|
||||
"backend.server.ws_api.authenticate_websocket", return_value=DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
mock_websocket.receive_text.side_effect = [
|
||||
WSMessage(method=WSMethod.GRAPH_EXECUTION_EVENT).model_dump_json(),
|
||||
WebSocketDisconnect(),
|
||||
@@ -113,6 +113,11 @@ ignore_patterns = []
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "session"
|
||||
filterwarnings = [
|
||||
"ignore:'audioop' is deprecated:DeprecationWarning:discord.player",
|
||||
"ignore:invalid escape sequence:DeprecationWarning:tweepy.api",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py310"
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"metric_id": "metric-123-uuid"
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"metric_id": "metric-float_precision-uuid",
|
||||
"test_case": "float_precision"
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"metric_id": "metric-integer_value-uuid",
|
||||
"test_case": "integer_value"
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"metric_id": "metric-large_number-uuid",
|
||||
"test_case": "large_number"
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"metric_id": "metric-negative_value-uuid",
|
||||
"test_case": "negative_value"
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"metric_id": "metric-tiny_number-uuid",
|
||||
"test_case": "tiny_number"
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"metric_id": "metric-zero_value-uuid",
|
||||
"test_case": "zero_value"
|
||||
}
|
||||
@@ -15,6 +15,12 @@
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
},
|
||||
"credentials_input_schema": {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
},
|
||||
"has_external_trigger": false,
|
||||
"trigger_setup_info": null,
|
||||
"new_output": false,
|
||||
"can_access_graph": true,
|
||||
"is_latest_version": true
|
||||
@@ -34,6 +40,12 @@
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
},
|
||||
"credentials_input_schema": {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
},
|
||||
"has_external_trigger": false,
|
||||
"trigger_setup_info": null,
|
||||
"new_output": false,
|
||||
"can_access_graph": false,
|
||||
"is_latest_version": true
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
import os
|
||||
|
||||
os.environ["ENABLE_AUTH"] = "false"
|
||||
@@ -4,7 +4,6 @@
|
||||
"private": true,
|
||||
"scripts": {
|
||||
"dev": "next dev --turbo",
|
||||
"dev:test": "NODE_ENV=test && next dev --turbo",
|
||||
"build": "pnpm run generate:api-client && SKIP_STORYBOOK_TESTS=true next build",
|
||||
"start": "next start",
|
||||
"start:standalone": "cd .next/standalone && node server.js",
|
||||
@@ -28,7 +27,7 @@
|
||||
"dependencies": {
|
||||
"@faker-js/faker": "9.8.0",
|
||||
"@hookform/resolvers": "5.1.1",
|
||||
"@next/third-parties": "15.3.3",
|
||||
"@next/third-parties": "15.3.4",
|
||||
"@phosphor-icons/react": "2.1.10",
|
||||
"@radix-ui/react-alert-dialog": "1.1.14",
|
||||
"@radix-ui/react-avatar": "1.1.10",
|
||||
@@ -49,13 +48,13 @@
|
||||
"@radix-ui/react-tabs": "1.1.12",
|
||||
"@radix-ui/react-toast": "1.2.14",
|
||||
"@radix-ui/react-tooltip": "1.2.7",
|
||||
"@sentry/nextjs": "9.27.0",
|
||||
"@sentry/nextjs": "9.33.0",
|
||||
"@supabase/ssr": "0.6.1",
|
||||
"@supabase/supabase-js": "2.50.0",
|
||||
"@tanstack/react-query": "5.80.7",
|
||||
"@supabase/supabase-js": "2.50.2",
|
||||
"@tanstack/react-query": "5.81.2",
|
||||
"@tanstack/react-table": "8.21.3",
|
||||
"@types/jaro-winkler": "0.2.4",
|
||||
"@xyflow/react": "12.6.4",
|
||||
"@xyflow/react": "12.8.0",
|
||||
"ajv": "8.17.1",
|
||||
"boring-avatars": "1.11.2",
|
||||
"class-variance-authority": "0.7.1",
|
||||
@@ -63,24 +62,24 @@
|
||||
"cmdk": "1.1.1",
|
||||
"cookie": "1.0.2",
|
||||
"date-fns": "4.1.0",
|
||||
"dotenv": "16.5.0",
|
||||
"dotenv": "16.6.0",
|
||||
"elliptic": "6.6.1",
|
||||
"embla-carousel-react": "8.6.0",
|
||||
"framer-motion": "12.16.0",
|
||||
"framer-motion": "12.19.2",
|
||||
"geist": "1.4.2",
|
||||
"jaro-winkler": "0.2.8",
|
||||
"launchdarkly-react-client-sdk": "3.8.1",
|
||||
"lodash": "4.17.21",
|
||||
"lucide-react": "0.513.0",
|
||||
"lucide-react": "0.524.0",
|
||||
"moment": "2.30.1",
|
||||
"next": "15.3.3",
|
||||
"next": "15.3.4",
|
||||
"next-themes": "0.4.6",
|
||||
"party-js": "2.2.0",
|
||||
"react": "18.3.1",
|
||||
"react-day-picker": "9.7.0",
|
||||
"react-dom": "18.3.1",
|
||||
"react-drag-drop-files": "2.4.0",
|
||||
"react-hook-form": "7.57.0",
|
||||
"react-hook-form": "7.58.1",
|
||||
"react-icons": "5.5.0",
|
||||
"react-markdown": "9.0.3",
|
||||
"react-modal": "3.16.3",
|
||||
@@ -90,23 +89,23 @@
|
||||
"tailwind-merge": "2.6.0",
|
||||
"tailwindcss-animate": "1.0.7",
|
||||
"uuid": "11.1.0",
|
||||
"zod": "3.25.56"
|
||||
"zod": "3.25.67"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@chromatic-com/storybook": "4.0.0",
|
||||
"@chromatic-com/storybook": "4.0.1",
|
||||
"@playwright/test": "1.53.1",
|
||||
"@storybook/addon-a11y": "9.0.12",
|
||||
"@storybook/addon-docs": "9.0.12",
|
||||
"@storybook/addon-links": "9.0.12",
|
||||
"@storybook/addon-onboarding": "9.0.12",
|
||||
"@storybook/nextjs": "9.0.12",
|
||||
"@tanstack/eslint-plugin-query": "5.78.0",
|
||||
"@tanstack/react-query-devtools": "5.80.10",
|
||||
"@storybook/addon-a11y": "9.0.13",
|
||||
"@storybook/addon-docs": "9.0.13",
|
||||
"@storybook/addon-links": "9.0.13",
|
||||
"@storybook/addon-onboarding": "9.0.13",
|
||||
"@storybook/nextjs": "9.0.13",
|
||||
"@tanstack/eslint-plugin-query": "5.81.2",
|
||||
"@tanstack/react-query-devtools": "5.81.2",
|
||||
"@testing-library/jest-dom": "6.6.3",
|
||||
"@testing-library/react": "16.3.0",
|
||||
"@testing-library/user-event": "14.6.1",
|
||||
"@types/canvas-confetti": "1.9.0",
|
||||
"@types/lodash": "4.17.18",
|
||||
"@types/lodash": "4.17.19",
|
||||
"@types/negotiator": "0.6.4",
|
||||
"@types/node": "22.15.30",
|
||||
"@types/react": "18.3.17",
|
||||
@@ -115,20 +114,20 @@
|
||||
"@vitest/browser": "3.2.4",
|
||||
"axe-playwright": "2.1.0",
|
||||
"chromatic": "11.25.2",
|
||||
"concurrently": "9.1.2",
|
||||
"concurrently": "9.2.0",
|
||||
"eslint": "8.57.1",
|
||||
"eslint-config-next": "15.3.4",
|
||||
"eslint-plugin-storybook": "9.0.12",
|
||||
"eslint-plugin-storybook": "9.0.13",
|
||||
"import-in-the-middle": "1.14.2",
|
||||
"jsdom": "26.1.0",
|
||||
"msw": "2.10.2",
|
||||
"msw-storybook-addon": "2.0.5",
|
||||
"orval": "7.10.0",
|
||||
"postcss": "8.5.6",
|
||||
"prettier": "3.5.3",
|
||||
"prettier-plugin-tailwindcss": "0.6.12",
|
||||
"prettier": "3.6.1",
|
||||
"prettier-plugin-tailwindcss": "0.6.13",
|
||||
"require-in-the-middle": "7.5.2",
|
||||
"storybook": "9.0.12",
|
||||
"storybook": "9.0.13",
|
||||
"tailwindcss": "3.4.17",
|
||||
"typescript": "5.8.3",
|
||||
"vite": "7.0.0",
|
||||
|
||||
@@ -34,7 +34,15 @@ export default defineConfig({
|
||||
bypassCSP: true,
|
||||
},
|
||||
/* Maximum time one test can run for */
|
||||
timeout: 60000,
|
||||
timeout: 30000,
|
||||
|
||||
/* Configure web server to start automatically */
|
||||
webServer: {
|
||||
command: "NEXT_PUBLIC_PW_TEST=true pnpm start",
|
||||
url: "http://localhost:3000",
|
||||
reuseExistingServer: !process.env.CI,
|
||||
timeout: 120 * 1000,
|
||||
},
|
||||
|
||||
/* Configure projects for major browsers */
|
||||
projects: [
|
||||
@@ -73,15 +81,4 @@ export default defineConfig({
|
||||
// use: { ...devices['Desktop Chrome'], channel: 'chrome' },
|
||||
// },
|
||||
],
|
||||
|
||||
/* Run your local server before starting the tests */
|
||||
webServer: {
|
||||
command: "pnpm start",
|
||||
url: "http://localhost:3000/",
|
||||
reuseExistingServer: !process.env.CI,
|
||||
timeout: 10 * 1000,
|
||||
env: {
|
||||
NODE_ENV: "test",
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
818
autogpt_platform/frontend/pnpm-lock.yaml
generated
818
autogpt_platform/frontend/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
@@ -75,7 +75,7 @@ export const customMutator = async <T = any>(
|
||||
|
||||
return {
|
||||
status: response.status,
|
||||
response_data,
|
||||
data: response_data,
|
||||
headers: response.headers,
|
||||
} as T;
|
||||
};
|
||||
|
||||
@@ -2931,18 +2931,6 @@
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Preset Id" }
|
||||
},
|
||||
{
|
||||
"name": "graph_id",
|
||||
"in": "query",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Graph Id" }
|
||||
},
|
||||
{
|
||||
"name": "graph_version",
|
||||
"in": "query",
|
||||
"required": true,
|
||||
"schema": { "type": "integer", "title": "Graph Version" }
|
||||
}
|
||||
],
|
||||
"requestBody": {
|
||||
@@ -3128,11 +3116,11 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"put": {
|
||||
"patch": {
|
||||
"tags": ["v2", "library", "private"],
|
||||
"summary": "Update Library Agent",
|
||||
"description": "Update the library agent with the given fields.\n\nArgs:\n library_agent_id: ID of the library agent to update.\n payload: Fields to update (auto_update_version, is_favorite, etc.).\n user_id: ID of the authenticated user.\n\nReturns:\n 204 (No Content) on success.\n\nRaises:\n HTTPException(500): If a server/database error occurs.",
|
||||
"operationId": "putV2Update library agent",
|
||||
"description": "Update the library agent with the given fields.\n\nArgs:\n library_agent_id: ID of the library agent to update.\n payload: Fields to update (auto_update_version, is_favorite, etc.).\n user_id: ID of the authenticated user.\n\nRaises:\n HTTPException(500): If a server/database error occurs.",
|
||||
"operationId": "patchV2Update library agent",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "library_agent_id",
|
||||
@@ -3152,7 +3140,45 @@
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"204": { "description": "Agent updated successfully" },
|
||||
"200": {
|
||||
"description": "Agent updated successfully",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/LibraryAgent" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"500": { "description": "Server error" },
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"delete": {
|
||||
"tags": ["v2", "library", "private"],
|
||||
"summary": "Delete Library Agent",
|
||||
"description": "Soft-delete the specified library agent.\n\nArgs:\n library_agent_id: ID of the library agent to delete.\n user_id: ID of the authenticated user.\n\nReturns:\n 204 No Content if successful.\n\nRaises:\n HTTPException(404): If the agent does not exist.\n HTTPException(500): If a server/database error occurs.",
|
||||
"operationId": "deleteV2Delete library agent",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "library_agent_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Library Agent Id" }
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": { "application/json": { "schema": {} } }
|
||||
},
|
||||
"204": { "description": "Agent deleted successfully" },
|
||||
"404": { "description": "Agent not found" },
|
||||
"500": { "description": "Server error" },
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
@@ -3238,6 +3264,55 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/library/agents/{library_agent_id}/setup-trigger": {
|
||||
"post": {
|
||||
"tags": ["v2", "library", "private"],
|
||||
"summary": "Setup Trigger",
|
||||
"description": "Sets up a webhook-triggered `LibraryAgentPreset` for a `LibraryAgent`.\nReturns the correspondingly created `LibraryAgentPreset` with `webhook_id` set.",
|
||||
"operationId": "postV2SetupTrigger",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "library_agent_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"description": "ID of the library agent",
|
||||
"title": "Library Agent Id"
|
||||
},
|
||||
"description": "ID of the library agent"
|
||||
}
|
||||
],
|
||||
"requestBody": {
|
||||
"required": true,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/TriggeredPresetSetupParams"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/LibraryAgentPreset" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/otto/ask": {
|
||||
"post": {
|
||||
"tags": ["v2", "otto"],
|
||||
@@ -3713,10 +3788,10 @@
|
||||
},
|
||||
"Body_postV2Execute_a_preset": {
|
||||
"properties": {
|
||||
"node_input": {
|
||||
"inputs": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Node Input"
|
||||
"title": "Inputs"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
@@ -4303,6 +4378,23 @@
|
||||
"type": "object",
|
||||
"title": "Input Schema"
|
||||
},
|
||||
"credentials_input_schema": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Credentials Input Schema",
|
||||
"description": "Input schema for credentials required by the agent"
|
||||
},
|
||||
"has_external_trigger": {
|
||||
"type": "boolean",
|
||||
"title": "Has External Trigger",
|
||||
"description": "Whether the agent has an external trigger (e.g. webhook) node"
|
||||
},
|
||||
"trigger_setup_info": {
|
||||
"anyOf": [
|
||||
{ "$ref": "#/components/schemas/LibraryAgentTriggerInfo" },
|
||||
{ "type": "null" }
|
||||
]
|
||||
},
|
||||
"new_output": { "type": "boolean", "title": "New Output" },
|
||||
"can_access_graph": {
|
||||
"type": "boolean",
|
||||
@@ -4326,6 +4418,8 @@
|
||||
"name",
|
||||
"description",
|
||||
"input_schema",
|
||||
"credentials_input_schema",
|
||||
"has_external_trigger",
|
||||
"new_output",
|
||||
"can_access_graph",
|
||||
"is_latest_version"
|
||||
@@ -4342,6 +4436,13 @@
|
||||
"type": "object",
|
||||
"title": "Inputs"
|
||||
},
|
||||
"credentials": {
|
||||
"additionalProperties": {
|
||||
"$ref": "#/components/schemas/CredentialsMetaInput"
|
||||
},
|
||||
"type": "object",
|
||||
"title": "Credentials"
|
||||
},
|
||||
"name": { "type": "string", "title": "Name" },
|
||||
"description": { "type": "string", "title": "Description" },
|
||||
"is_active": {
|
||||
@@ -4349,7 +4450,12 @@
|
||||
"title": "Is Active",
|
||||
"default": true
|
||||
},
|
||||
"webhook_id": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Webhook Id"
|
||||
},
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
"user_id": { "type": "string", "title": "User Id" },
|
||||
"updated_at": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
@@ -4361,9 +4467,11 @@
|
||||
"graph_id",
|
||||
"graph_version",
|
||||
"inputs",
|
||||
"credentials",
|
||||
"name",
|
||||
"description",
|
||||
"id",
|
||||
"user_id",
|
||||
"updated_at"
|
||||
],
|
||||
"title": "LibraryAgentPreset",
|
||||
@@ -4378,12 +4486,23 @@
|
||||
"type": "object",
|
||||
"title": "Inputs"
|
||||
},
|
||||
"credentials": {
|
||||
"additionalProperties": {
|
||||
"$ref": "#/components/schemas/CredentialsMetaInput"
|
||||
},
|
||||
"type": "object",
|
||||
"title": "Credentials"
|
||||
},
|
||||
"name": { "type": "string", "title": "Name" },
|
||||
"description": { "type": "string", "title": "Description" },
|
||||
"is_active": {
|
||||
"type": "boolean",
|
||||
"title": "Is Active",
|
||||
"default": true
|
||||
},
|
||||
"webhook_id": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Webhook Id"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
@@ -4391,6 +4510,7 @@
|
||||
"graph_id",
|
||||
"graph_version",
|
||||
"inputs",
|
||||
"credentials",
|
||||
"name",
|
||||
"description"
|
||||
],
|
||||
@@ -4439,6 +4559,18 @@
|
||||
],
|
||||
"title": "Inputs"
|
||||
},
|
||||
"credentials": {
|
||||
"anyOf": [
|
||||
{
|
||||
"additionalProperties": {
|
||||
"$ref": "#/components/schemas/CredentialsMetaInput"
|
||||
},
|
||||
"type": "object"
|
||||
},
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "Credentials"
|
||||
},
|
||||
"name": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Name"
|
||||
@@ -4481,6 +4613,24 @@
|
||||
"enum": ["COMPLETED", "HEALTHY", "WAITING", "ERROR"],
|
||||
"title": "LibraryAgentStatus"
|
||||
},
|
||||
"LibraryAgentTriggerInfo": {
|
||||
"properties": {
|
||||
"provider": { "$ref": "#/components/schemas/ProviderName" },
|
||||
"config_schema": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Config Schema",
|
||||
"description": "Input schema for the trigger block"
|
||||
},
|
||||
"credentials_input_name": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Credentials Input Name"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["provider", "config_schema", "credentials_input_name"],
|
||||
"title": "LibraryAgentTriggerInfo"
|
||||
},
|
||||
"LibraryAgentUpdateRequest": {
|
||||
"properties": {
|
||||
"auto_update_version": {
|
||||
@@ -4497,11 +4647,6 @@
|
||||
"anyOf": [{ "type": "boolean" }, { "type": "null" }],
|
||||
"title": "Is Archived",
|
||||
"description": "Archive the agent"
|
||||
},
|
||||
"is_deleted": {
|
||||
"anyOf": [{ "type": "boolean" }, { "type": "null" }],
|
||||
"title": "Is Deleted",
|
||||
"description": "Delete the agent"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
@@ -5773,6 +5918,31 @@
|
||||
"required": ["transactions", "next_transaction_time"],
|
||||
"title": "TransactionHistory"
|
||||
},
|
||||
"TriggeredPresetSetupParams": {
|
||||
"properties": {
|
||||
"name": { "type": "string", "title": "Name" },
|
||||
"description": {
|
||||
"type": "string",
|
||||
"title": "Description",
|
||||
"default": ""
|
||||
},
|
||||
"trigger_config": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Trigger Config"
|
||||
},
|
||||
"agent_credentials": {
|
||||
"additionalProperties": {
|
||||
"$ref": "#/components/schemas/CredentialsMetaInput"
|
||||
},
|
||||
"type": "object",
|
||||
"title": "Agent Credentials"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["name", "trigger_config"],
|
||||
"title": "TriggeredPresetSetupParams"
|
||||
},
|
||||
"TurnstileVerifyRequest": {
|
||||
"properties": {
|
||||
"token": {
|
||||
@@ -6055,16 +6225,6 @@
|
||||
"type": "string",
|
||||
"title": "Provider Webhook Id"
|
||||
},
|
||||
"attached_nodes": {
|
||||
"anyOf": [
|
||||
{
|
||||
"items": { "$ref": "#/components/schemas/NodeModel" },
|
||||
"type": "array"
|
||||
},
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "Attached Nodes"
|
||||
},
|
||||
"url": { "type": "string", "title": "Url", "readOnly": true }
|
||||
},
|
||||
"type": "object",
|
||||
|
||||
@@ -1,4 +1,14 @@
|
||||
"use client";
|
||||
import {
|
||||
AuthBottomText,
|
||||
AuthButton,
|
||||
AuthCard,
|
||||
AuthFeedback,
|
||||
AuthHeader,
|
||||
GoogleOAuthButton,
|
||||
PasswordInput,
|
||||
Turnstile,
|
||||
} from "@/components/auth";
|
||||
import {
|
||||
Form,
|
||||
FormControl,
|
||||
@@ -8,19 +18,9 @@ import {
|
||||
FormMessage,
|
||||
} from "@/components/ui/form";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import Link from "next/link";
|
||||
import LoadingBox from "@/components/ui/loading";
|
||||
import {
|
||||
AuthCard,
|
||||
AuthHeader,
|
||||
AuthButton,
|
||||
AuthFeedback,
|
||||
AuthBottomText,
|
||||
GoogleOAuthButton,
|
||||
PasswordInput,
|
||||
Turnstile,
|
||||
} from "@/components/auth";
|
||||
import { getBehaveAs } from "@/lib/utils";
|
||||
import Link from "next/link";
|
||||
import { useLoginPage } from "./useLoginPage";
|
||||
|
||||
export default function LoginPage() {
|
||||
@@ -100,7 +100,7 @@ export default function LoginPage() {
|
||||
<FormLabel className="flex w-full items-center justify-between">
|
||||
<span>Password</span>
|
||||
<Link
|
||||
href="/reset_password"
|
||||
href="/reset-password"
|
||||
className="text-sm font-normal leading-normal text-black underline"
|
||||
>
|
||||
Forgot your password?
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
"use client";
|
||||
|
||||
import { Loader2, MoreVertical } from "lucide-react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
import { useAPISection } from "./useAPISection";
|
||||
|
||||
export function APIKeysSection() {
|
||||
const { apiKeys, isLoading, isDeleting, handleRevokeKey } = useAPISection();
|
||||
|
||||
return (
|
||||
<>
|
||||
{isLoading ? (
|
||||
<div className="flex justify-center p-4">
|
||||
<Loader2 className="h-6 w-6 animate-spin" />
|
||||
</div>
|
||||
) : (
|
||||
apiKeys &&
|
||||
apiKeys.length > 0 && (
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Name</TableHead>
|
||||
<TableHead>API Key</TableHead>
|
||||
<TableHead>Status</TableHead>
|
||||
<TableHead>Created</TableHead>
|
||||
<TableHead>Last Used</TableHead>
|
||||
<TableHead></TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{apiKeys.map((key) => (
|
||||
<TableRow key={key.id}>
|
||||
<TableCell>{key.name}</TableCell>
|
||||
<TableCell>
|
||||
<div className="rounded-md border p-1 px-2 text-xs">
|
||||
{`${key.prefix}******************${key.postfix}`}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<Badge
|
||||
variant={
|
||||
key.status === "ACTIVE" ? "default" : "destructive"
|
||||
}
|
||||
className={
|
||||
key.status === "ACTIVE"
|
||||
? "border-green-600 bg-green-100 text-green-800"
|
||||
: "border-red-600 bg-red-100 text-red-800"
|
||||
}
|
||||
>
|
||||
{key.status}
|
||||
</Badge>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{new Date(key.created_at).toLocaleDateString()}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{key.last_used_at
|
||||
? new Date(key.last_used_at).toLocaleDateString()
|
||||
: "Never"}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button variant="ghost" size="sm">
|
||||
<MoreVertical className="h-4 w-4" />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem
|
||||
className="text-destructive"
|
||||
onClick={() => handleRevokeKey(key.id)}
|
||||
disabled={isDeleting}
|
||||
>
|
||||
Revoke
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
)
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
"use client";
|
||||
import {
|
||||
getGetV1ListUserApiKeysQueryKey,
|
||||
useDeleteV1RevokeApiKey,
|
||||
useGetV1ListUserApiKeys,
|
||||
} from "@/api/__generated__/endpoints/api-keys/api-keys";
|
||||
import { APIKeyWithoutHash } from "@/api/__generated__/models/aPIKeyWithoutHash";
|
||||
import { useToast } from "@/components/ui/use-toast";
|
||||
import { getQueryClient } from "@/lib/react-query/queryClient";
|
||||
|
||||
export const useAPISection = () => {
|
||||
const queryClient = getQueryClient();
|
||||
const { toast } = useToast();
|
||||
|
||||
const { data: apiKeys, isLoading } = useGetV1ListUserApiKeys({
|
||||
query: {
|
||||
select: (res) => {
|
||||
return (res.data as APIKeyWithoutHash[]).filter(
|
||||
(key) => key.status === "ACTIVE",
|
||||
);
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const { mutateAsync: revokeAPIKey, isPending: isDeleting } =
|
||||
useDeleteV1RevokeApiKey({
|
||||
mutation: {
|
||||
onSettled: () => {
|
||||
return queryClient.invalidateQueries({
|
||||
queryKey: getGetV1ListUserApiKeysQueryKey(),
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const handleRevokeKey = async (keyId: string) => {
|
||||
try {
|
||||
await revokeAPIKey({
|
||||
keyId: keyId,
|
||||
});
|
||||
|
||||
toast({
|
||||
title: "Success",
|
||||
description: "AutoGPT Platform API key revoked successfully",
|
||||
});
|
||||
} catch {
|
||||
toast({
|
||||
title: "Error",
|
||||
description: "Failed to revoke AutoGPT Platform API key",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
return {
|
||||
apiKeys,
|
||||
isLoading,
|
||||
isDeleting,
|
||||
handleRevokeKey,
|
||||
};
|
||||
};
|
||||
@@ -0,0 +1,133 @@
|
||||
"use client";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogTrigger,
|
||||
} from "@/components/ui/dialog";
|
||||
import { LuCopy } from "react-icons/lu";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Checkbox } from "@/components/ui/checkbox";
|
||||
import { Button } from "@/components/ui/button";
|
||||
|
||||
import { useAPIkeysModals } from "./useAPIkeysModals";
|
||||
import { APIKeyPermission } from "@/api/__generated__/models/aPIKeyPermission";
|
||||
|
||||
export const APIKeysModals = () => {
|
||||
const {
|
||||
isCreating,
|
||||
handleCreateKey,
|
||||
handleCopyKey,
|
||||
setIsCreateOpen,
|
||||
setIsKeyDialogOpen,
|
||||
isCreateOpen,
|
||||
isKeyDialogOpen,
|
||||
keyState,
|
||||
setKeyState,
|
||||
} = useAPIkeysModals();
|
||||
|
||||
return (
|
||||
<div className="mb-4 flex justify-end">
|
||||
<Dialog open={isCreateOpen} onOpenChange={setIsCreateOpen}>
|
||||
<DialogTrigger asChild>
|
||||
<Button>Create Key</Button>
|
||||
</DialogTrigger>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>Create New API Key</DialogTitle>
|
||||
<DialogDescription>
|
||||
Create a new AutoGPT Platform API key
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<div className="grid gap-4 py-4">
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="name">Name</Label>
|
||||
<Input
|
||||
id="name"
|
||||
value={keyState.newKeyName}
|
||||
onChange={(e) =>
|
||||
setKeyState((prev) => ({
|
||||
...prev,
|
||||
newKeyName: e.target.value,
|
||||
}))
|
||||
}
|
||||
placeholder="My AutoGPT Platform API Key"
|
||||
/>
|
||||
</div>
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="description">Description (Optional)</Label>
|
||||
<Input
|
||||
id="description"
|
||||
value={keyState.newKeyDescription}
|
||||
onChange={(e) =>
|
||||
setKeyState((prev) => ({
|
||||
...prev,
|
||||
newKeyDescription: e.target.value,
|
||||
}))
|
||||
}
|
||||
placeholder="Used for..."
|
||||
/>
|
||||
</div>
|
||||
<div className="grid gap-2">
|
||||
<Label>Permissions</Label>
|
||||
{Object.values(APIKeyPermission).map((permission) => (
|
||||
<div className="flex items-center space-x-2" key={permission}>
|
||||
<Checkbox
|
||||
id={permission}
|
||||
checked={keyState.selectedPermissions.includes(permission)}
|
||||
onCheckedChange={(checked: boolean) => {
|
||||
setKeyState((prev) => ({
|
||||
...prev,
|
||||
selectedPermissions: checked
|
||||
? [...prev.selectedPermissions, permission]
|
||||
: prev.selectedPermissions.filter(
|
||||
(p) => p !== permission,
|
||||
),
|
||||
}));
|
||||
}}
|
||||
/>
|
||||
<Label htmlFor={permission}>{permission}</Label>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
<DialogFooter>
|
||||
<Button variant="outline" onClick={() => setIsCreateOpen(false)}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button onClick={handleCreateKey} disabled={isCreating}>
|
||||
Create
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
|
||||
<Dialog open={isKeyDialogOpen} onOpenChange={setIsKeyDialogOpen}>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>AutoGPT Platform API Key Created</DialogTitle>
|
||||
<DialogDescription>
|
||||
Please copy your AutoGPT API key now. You won't be able to
|
||||
see it again!
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<div className="flex items-center space-x-2">
|
||||
<code className="flex-1 rounded-md bg-secondary p-2 text-sm">
|
||||
{keyState.newApiKey}
|
||||
</code>
|
||||
<Button size="icon" variant="outline" onClick={handleCopyKey}>
|
||||
<LuCopy className="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
<DialogFooter>
|
||||
<Button onClick={() => setIsKeyDialogOpen(false)}>Close</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,78 @@
|
||||
"use client";
|
||||
import {
|
||||
getGetV1ListUserApiKeysQueryKey,
|
||||
usePostV1CreateNewApiKey,
|
||||
} from "@/api/__generated__/endpoints/api-keys/api-keys";
|
||||
import { APIKeyPermission } from "@/api/__generated__/models/aPIKeyPermission";
|
||||
import { CreateAPIKeyResponse } from "@/api/__generated__/models/createAPIKeyResponse";
|
||||
import { useToast } from "@/components/ui/use-toast";
|
||||
import { getQueryClient } from "@/lib/react-query/queryClient";
|
||||
import { useState } from "react";
|
||||
|
||||
export const useAPIkeysModals = () => {
|
||||
const [isCreateOpen, setIsCreateOpen] = useState(false);
|
||||
const [isKeyDialogOpen, setIsKeyDialogOpen] = useState(false);
|
||||
const [keyState, setKeyState] = useState({
|
||||
newKeyName: "",
|
||||
newKeyDescription: "",
|
||||
newApiKey: "",
|
||||
selectedPermissions: [] as APIKeyPermission[],
|
||||
});
|
||||
const queryClient = getQueryClient();
|
||||
const { toast } = useToast();
|
||||
|
||||
const { mutateAsync: createAPIKey, isPending: isCreating } =
|
||||
usePostV1CreateNewApiKey({
|
||||
mutation: {
|
||||
onSettled: () => {
|
||||
return queryClient.invalidateQueries({
|
||||
queryKey: getGetV1ListUserApiKeysQueryKey(),
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const handleCreateKey = async () => {
|
||||
try {
|
||||
const response = await createAPIKey({
|
||||
data: {
|
||||
name: keyState.newKeyName,
|
||||
permissions: keyState.selectedPermissions,
|
||||
description: keyState.newKeyDescription,
|
||||
},
|
||||
});
|
||||
setKeyState((prev) => ({
|
||||
...prev,
|
||||
newApiKey: (response.data as CreateAPIKeyResponse).plain_text_key,
|
||||
}));
|
||||
setIsCreateOpen(false);
|
||||
setIsKeyDialogOpen(true);
|
||||
} catch {
|
||||
toast({
|
||||
title: "Error",
|
||||
description: "Failed to create AutoGPT Platform API key",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const handleCopyKey = () => {
|
||||
navigator.clipboard.writeText(keyState.newApiKey);
|
||||
toast({
|
||||
title: "Copied",
|
||||
description: "AutoGPT Platform API key copied to clipboard",
|
||||
});
|
||||
};
|
||||
|
||||
return {
|
||||
isCreating,
|
||||
handleCreateKey,
|
||||
handleCopyKey,
|
||||
setIsCreateOpen,
|
||||
setIsKeyDialogOpen,
|
||||
isCreateOpen,
|
||||
isKeyDialogOpen,
|
||||
keyState,
|
||||
setKeyState,
|
||||
};
|
||||
};
|
||||
@@ -1,12 +1,31 @@
|
||||
import { Metadata } from "next/types";
|
||||
import { APIKeysSection } from "@/components/agptui/composite/APIKeySection";
|
||||
import { APIKeysSection } from "@/app/(platform)/profile/(user)/api_keys/components/APIKeySection/APIKeySection";
|
||||
import {
|
||||
Card,
|
||||
CardContent,
|
||||
CardDescription,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
} from "@/components/ui/card";
|
||||
import { APIKeysModals } from "./components/APIKeysModals/APIKeysModals";
|
||||
|
||||
export const metadata: Metadata = { title: "API Keys - AutoGPT Platform" };
|
||||
|
||||
const ApiKeysPage = () => {
|
||||
return (
|
||||
<div className="w-full pr-4 pt-24 md:pt-0">
|
||||
<APIKeysSection />
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>AutoGPT Platform API Keys</CardTitle>
|
||||
<CardDescription>
|
||||
Manage your AutoGPT Platform API keys for programmatic access
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<APIKeysModals />
|
||||
<APIKeysSection />
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -5,6 +5,7 @@ import { useCallback, useContext, useEffect, useMemo, useState } from "react";
|
||||
import { useToast } from "@/components/ui/use-toast";
|
||||
import { IconKey, IconUser } from "@/components/ui/icons";
|
||||
import { Trash2Icon } from "lucide-react";
|
||||
import { KeyIcon } from "@phosphor-icons/react/dist/ssr";
|
||||
import { providerIcons } from "@/components/integrations/credentials-input";
|
||||
import { CredentialsProvidersContext } from "@/components/integrations/credentials-provider";
|
||||
import {
|
||||
@@ -140,11 +141,12 @@ export default function UserIntegrationsPage() {
|
||||
...credentials,
|
||||
provider: provider.provider,
|
||||
providerName: provider.providerName,
|
||||
ProviderIcon: providerIcons[provider.provider],
|
||||
ProviderIcon: providerIcons[provider.provider] || KeyIcon,
|
||||
TypeIcon: {
|
||||
oauth2: IconUser,
|
||||
api_key: IconKey,
|
||||
user_password: IconKey,
|
||||
host_scoped: IconKey,
|
||||
}[credentials.type],
|
||||
})),
|
||||
)
|
||||
@@ -181,6 +183,7 @@ export default function UserIntegrationsPage() {
|
||||
oauth2: "OAuth2 credentials",
|
||||
api_key: "API key",
|
||||
user_password: "Username & password",
|
||||
host_scoped: "Host-scoped credentials",
|
||||
}[cred.type]
|
||||
}{" "}
|
||||
- <code>{cred.id}</code>
|
||||
|
||||
@@ -2,8 +2,11 @@
|
||||
|
||||
import { revalidatePath } from "next/cache";
|
||||
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
||||
import BackendApi from "@/lib/autogpt-server-api";
|
||||
import { NotificationPreferenceDTO } from "@/lib/autogpt-server-api/types";
|
||||
import {
|
||||
postV1UpdateNotificationPreferences,
|
||||
postV1UpdateUserEmail,
|
||||
} from "@/api/__generated__/endpoints/auth/auth";
|
||||
|
||||
export async function updateSettings(formData: FormData) {
|
||||
const supabase = await getServerSupabase();
|
||||
@@ -29,8 +32,7 @@ export async function updateSettings(formData: FormData) {
|
||||
const { error: emailError } = await supabase.auth.updateUser({
|
||||
email,
|
||||
});
|
||||
const api = new BackendApi();
|
||||
await api.updateUserEmail(email);
|
||||
await postV1UpdateUserEmail(email);
|
||||
|
||||
if (emailError) {
|
||||
throw new Error(`${emailError.message}`);
|
||||
@@ -38,7 +40,6 @@ export async function updateSettings(formData: FormData) {
|
||||
}
|
||||
|
||||
try {
|
||||
const api = new BackendApi();
|
||||
const preferences: NotificationPreferenceDTO = {
|
||||
email: user?.email || "",
|
||||
preferences: {
|
||||
@@ -55,7 +56,7 @@ export async function updateSettings(formData: FormData) {
|
||||
},
|
||||
daily_limit: 0,
|
||||
};
|
||||
await api.updateUserPreferences(preferences);
|
||||
await postV1UpdateNotificationPreferences(preferences);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
throw new Error(`Failed to update preferences: ${error}`);
|
||||
@@ -64,9 +65,3 @@ export async function updateSettings(formData: FormData) {
|
||||
revalidatePath("/profile/settings");
|
||||
return { success: true };
|
||||
}
|
||||
|
||||
export async function getUserPreferences(): Promise<NotificationPreferenceDTO> {
|
||||
const api = new BackendApi();
|
||||
const preferences = await api.getUserPreferences();
|
||||
return preferences;
|
||||
}
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
"use client";
|
||||
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { useForm } from "react-hook-form";
|
||||
import * as z from "zod";
|
||||
import { User } from "@supabase/supabase-js";
|
||||
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Form,
|
||||
@@ -18,100 +13,22 @@ import {
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { updateSettings } from "@/app/(platform)/profile/(user)/settings/actions";
|
||||
import { toast } from "@/components/ui/use-toast";
|
||||
import { NotificationPreferenceDTO } from "@/lib/autogpt-server-api";
|
||||
import { NotificationPreference } from "@/api/__generated__/models/notificationPreference";
|
||||
import { User } from "@supabase/supabase-js";
|
||||
import { useSettingsForm } from "./useSettingsForm";
|
||||
|
||||
const formSchema = z
|
||||
.object({
|
||||
email: z.string().email(),
|
||||
password: z
|
||||
.string()
|
||||
.optional()
|
||||
.refine((val) => {
|
||||
// If password is provided, it must be at least 8 characters
|
||||
if (val) return val.length >= 12;
|
||||
return true;
|
||||
}, "String must contain at least 12 character(s)"),
|
||||
confirmPassword: z.string().optional(),
|
||||
notifyOnAgentRun: z.boolean(),
|
||||
notifyOnZeroBalance: z.boolean(),
|
||||
notifyOnLowBalance: z.boolean(),
|
||||
notifyOnBlockExecutionFailed: z.boolean(),
|
||||
notifyOnContinuousAgentError: z.boolean(),
|
||||
notifyOnDailySummary: z.boolean(),
|
||||
notifyOnWeeklySummary: z.boolean(),
|
||||
notifyOnMonthlySummary: z.boolean(),
|
||||
})
|
||||
.refine(
|
||||
(data) => {
|
||||
if (data.password || data.confirmPassword) {
|
||||
return data.password === data.confirmPassword;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
{
|
||||
message: "Passwords do not match",
|
||||
path: ["confirmPassword"],
|
||||
},
|
||||
);
|
||||
|
||||
interface SettingsFormProps {
|
||||
export const SettingsForm = ({
|
||||
preferences,
|
||||
user,
|
||||
}: {
|
||||
preferences: NotificationPreference;
|
||||
user: User;
|
||||
preferences: NotificationPreferenceDTO;
|
||||
}
|
||||
|
||||
export default function SettingsForm({ user, preferences }: SettingsFormProps) {
|
||||
const defaultValues = {
|
||||
email: user.email || "",
|
||||
password: "",
|
||||
confirmPassword: "",
|
||||
notifyOnAgentRun: preferences.preferences.AGENT_RUN,
|
||||
notifyOnZeroBalance: preferences.preferences.ZERO_BALANCE,
|
||||
notifyOnLowBalance: preferences.preferences.LOW_BALANCE,
|
||||
notifyOnBlockExecutionFailed:
|
||||
preferences.preferences.BLOCK_EXECUTION_FAILED,
|
||||
notifyOnContinuousAgentError:
|
||||
preferences.preferences.CONTINUOUS_AGENT_ERROR,
|
||||
notifyOnDailySummary: preferences.preferences.DAILY_SUMMARY,
|
||||
notifyOnWeeklySummary: preferences.preferences.WEEKLY_SUMMARY,
|
||||
notifyOnMonthlySummary: preferences.preferences.MONTHLY_SUMMARY,
|
||||
};
|
||||
|
||||
const form = useForm<z.infer<typeof formSchema>>({
|
||||
resolver: zodResolver(formSchema),
|
||||
defaultValues,
|
||||
}) => {
|
||||
const { form, onSubmit, onCancel } = useSettingsForm({
|
||||
preferences,
|
||||
user,
|
||||
});
|
||||
|
||||
async function onSubmit(values: z.infer<typeof formSchema>) {
|
||||
try {
|
||||
const formData = new FormData();
|
||||
|
||||
Object.entries(values).forEach(([key, value]) => {
|
||||
if (key !== "confirmPassword") {
|
||||
formData.append(key, value.toString());
|
||||
}
|
||||
});
|
||||
|
||||
await updateSettings(formData);
|
||||
|
||||
toast({
|
||||
title: "Successfully updated settings",
|
||||
});
|
||||
} catch (error) {
|
||||
toast({
|
||||
title: "Error",
|
||||
description:
|
||||
error instanceof Error ? error.message : "Something went wrong",
|
||||
variant: "destructive",
|
||||
});
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
function onCancel() {
|
||||
form.reset(defaultValues);
|
||||
}
|
||||
return (
|
||||
<Form {...form}>
|
||||
<form
|
||||
@@ -396,4 +313,4 @@ export default function SettingsForm({ user, preferences }: SettingsFormProps) {
|
||||
</form>
|
||||
</Form>
|
||||
);
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,51 @@
|
||||
import { z } from "zod";
|
||||
|
||||
export const formSchema = z
|
||||
.object({
|
||||
email: z.string().email(),
|
||||
password: z
|
||||
.string()
|
||||
.optional()
|
||||
.refine((val) => {
|
||||
if (val) return val.length >= 12;
|
||||
return true;
|
||||
}, "String must contain at least 12 character(s)"),
|
||||
confirmPassword: z.string().optional(),
|
||||
notifyOnAgentRun: z.boolean(),
|
||||
notifyOnZeroBalance: z.boolean(),
|
||||
notifyOnLowBalance: z.boolean(),
|
||||
notifyOnBlockExecutionFailed: z.boolean(),
|
||||
notifyOnContinuousAgentError: z.boolean(),
|
||||
notifyOnDailySummary: z.boolean(),
|
||||
notifyOnWeeklySummary: z.boolean(),
|
||||
notifyOnMonthlySummary: z.boolean(),
|
||||
})
|
||||
.refine((data) => {
|
||||
if (data.password || data.confirmPassword) {
|
||||
return data.password === data.confirmPassword;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
export const createDefaultValues = (
|
||||
user: { email?: string },
|
||||
preferences: { preferences?: Record<string, boolean> },
|
||||
) => {
|
||||
const defaultValues = {
|
||||
email: user.email || "",
|
||||
password: "",
|
||||
confirmPassword: "",
|
||||
notifyOnAgentRun: preferences.preferences?.AGENT_RUN,
|
||||
notifyOnZeroBalance: preferences.preferences?.ZERO_BALANCE,
|
||||
notifyOnLowBalance: preferences.preferences?.LOW_BALANCE,
|
||||
notifyOnBlockExecutionFailed:
|
||||
preferences.preferences?.BLOCK_EXECUTION_FAILED,
|
||||
notifyOnContinuousAgentError:
|
||||
preferences.preferences?.CONTINUOUS_AGENT_ERROR,
|
||||
notifyOnDailySummary: preferences.preferences?.DAILY_SUMMARY,
|
||||
notifyOnWeeklySummary: preferences.preferences?.WEEKLY_SUMMARY,
|
||||
notifyOnMonthlySummary: preferences.preferences?.MONTHLY_SUMMARY,
|
||||
};
|
||||
|
||||
return defaultValues;
|
||||
};
|
||||
@@ -0,0 +1,57 @@
|
||||
"use client";
|
||||
import { useForm } from "react-hook-form";
|
||||
import { createDefaultValues, formSchema } from "./helper";
|
||||
import { z } from "zod";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { updateSettings } from "../../actions";
|
||||
import { useToast } from "@/components/ui/use-toast";
|
||||
import { NotificationPreference } from "@/api/__generated__/models/notificationPreference";
|
||||
import { User } from "@supabase/supabase-js";
|
||||
|
||||
export const useSettingsForm = ({
|
||||
preferences,
|
||||
user,
|
||||
}: {
|
||||
preferences: NotificationPreference;
|
||||
user: User;
|
||||
}) => {
|
||||
const { toast } = useToast();
|
||||
const defaultValues = createDefaultValues(user, preferences);
|
||||
|
||||
const form = useForm<z.infer<typeof formSchema>>({
|
||||
resolver: zodResolver(formSchema),
|
||||
defaultValues,
|
||||
});
|
||||
|
||||
async function onSubmit(values: z.infer<typeof formSchema>) {
|
||||
try {
|
||||
const formData = new FormData();
|
||||
|
||||
Object.entries(values).forEach(([key, value]) => {
|
||||
if (key !== "confirmPassword") {
|
||||
formData.append(key, value.toString());
|
||||
}
|
||||
});
|
||||
|
||||
await updateSettings(formData);
|
||||
|
||||
toast({
|
||||
title: "Successfully updated settings",
|
||||
});
|
||||
} catch (error) {
|
||||
toast({
|
||||
title: "Error",
|
||||
description:
|
||||
error instanceof Error ? error.message : "Something went wrong",
|
||||
variant: "destructive",
|
||||
});
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
function onCancel() {
|
||||
form.reset(defaultValues);
|
||||
}
|
||||
|
||||
return { form, onSubmit, onCancel };
|
||||
};
|
||||
@@ -1,23 +1,37 @@
|
||||
"use client";
|
||||
import { useGetV1GetNotificationPreferences } from "@/api/__generated__/endpoints/auth/auth";
|
||||
import { SettingsForm } from "@/app/(platform)/profile/(user)/settings/components/SettingsForm/SettingsForm";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import * as React from "react";
|
||||
import { Metadata } from "next";
|
||||
import SettingsForm from "@/components/profile/settings/SettingsForm";
|
||||
import { getServerUser } from "@/lib/supabase/server/getServerUser";
|
||||
import SettingsLoading from "./loading";
|
||||
import { redirect } from "next/navigation";
|
||||
import { getUserPreferences } from "./actions";
|
||||
|
||||
export const metadata: Metadata = {
|
||||
title: "Settings - AutoGPT Platform",
|
||||
description: "Manage your account settings and preferences.",
|
||||
};
|
||||
export default function SettingsPage() {
|
||||
const {
|
||||
data: preferences,
|
||||
isError,
|
||||
isLoading,
|
||||
} = useGetV1GetNotificationPreferences({
|
||||
query: {
|
||||
select: (res) => {
|
||||
return res.data;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export default async function SettingsPage() {
|
||||
const { user, error } = await getServerUser();
|
||||
const { user, isUserLoading } = useSupabase();
|
||||
|
||||
if (error || !user) {
|
||||
if (isLoading || isUserLoading) {
|
||||
return <SettingsLoading />;
|
||||
}
|
||||
|
||||
if (!user) {
|
||||
redirect("/login");
|
||||
}
|
||||
|
||||
const preferences = await getUserPreferences();
|
||||
if (isError || !preferences || !preferences.preferences) {
|
||||
return "Errror..."; // TODO: Will use a Error reusable components from Block Menu redesign
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="container max-w-2xl space-y-6 py-10">
|
||||
@@ -27,7 +41,7 @@ export default async function SettingsPage() {
|
||||
Manage your account settings and preferences.
|
||||
</p>
|
||||
</div>
|
||||
<SettingsForm user={user} preferences={preferences} />
|
||||
<SettingsForm preferences={preferences} user={user} />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"use server";
|
||||
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
||||
import { redirect } from "next/navigation";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
import { verifyTurnstileToken } from "@/lib/turnstile";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
import { redirect } from "next/navigation";
|
||||
|
||||
export async function sendResetEmail(email: string, turnstileToken: string) {
|
||||
return await Sentry.withServerActionInstrumentation(
|
||||
@@ -19,14 +19,14 @@ export async function sendResetEmail(email: string, turnstileToken: string) {
|
||||
// Verify Turnstile token if provided
|
||||
const success = await verifyTurnstileToken(
|
||||
turnstileToken,
|
||||
"reset_password",
|
||||
"reset-password",
|
||||
);
|
||||
if (!success) {
|
||||
return "CAPTCHA verification failed. Please try again.";
|
||||
}
|
||||
|
||||
const { error } = await supabase.auth.resetPasswordForEmail(email, {
|
||||
redirectTo: `${origin}/reset_password`,
|
||||
redirectTo: `${origin}/reset-password`,
|
||||
});
|
||||
|
||||
if (error) {
|
||||
@@ -1,9 +1,9 @@
|
||||
"use client";
|
||||
import {
|
||||
AuthCard,
|
||||
AuthHeader,
|
||||
AuthButton,
|
||||
AuthCard,
|
||||
AuthFeedback,
|
||||
AuthHeader,
|
||||
PasswordInput,
|
||||
Turnstile,
|
||||
} from "@/components/auth";
|
||||
@@ -17,16 +17,16 @@ import {
|
||||
FormMessage,
|
||||
} from "@/components/ui/form";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import LoadingBox from "@/components/ui/loading";
|
||||
import { useTurnstile } from "@/hooks/useTurnstile";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { sendEmailFormSchema, changePasswordFormSchema } from "@/types/auth";
|
||||
import { getBehaveAs } from "@/lib/utils";
|
||||
import { changePasswordFormSchema, sendEmailFormSchema } from "@/types/auth";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { useCallback, useState } from "react";
|
||||
import { useForm } from "react-hook-form";
|
||||
import { z } from "zod";
|
||||
import { changePassword, sendResetEmail } from "./actions";
|
||||
import LoadingBox from "@/components/ui/loading";
|
||||
import { getBehaveAs } from "@/lib/utils";
|
||||
import { useTurnstile } from "@/hooks/useTurnstile";
|
||||
|
||||
export default function ResetPasswordPage() {
|
||||
const { supabase, user, isUserLoading } = useSupabase();
|
||||
@@ -1,24 +1,21 @@
|
||||
"use client";
|
||||
|
||||
import * as React from "react";
|
||||
import { useState } from "react";
|
||||
|
||||
import Image from "next/image";
|
||||
|
||||
import { Button } from "./Button";
|
||||
import { IconPersonFill } from "@/components/ui/icons";
|
||||
import { ProfileDetails } from "@/lib/autogpt-server-api/types";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||
import { ProfileDetails } from "@/lib/autogpt-server-api/types";
|
||||
import { Button } from "./Button";
|
||||
|
||||
export const ProfileInfoForm = ({ profile }: { profile: ProfileDetails }) => {
|
||||
export function ProfileInfoForm({ profile }: { profile: ProfileDetails }) {
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [profileData, setProfileData] = useState<ProfileDetails>(profile);
|
||||
const { supabase } = useSupabase();
|
||||
const api = useBackendAPI();
|
||||
|
||||
const submitForm = async () => {
|
||||
async function submitForm() {
|
||||
try {
|
||||
setIsSubmitting(true);
|
||||
|
||||
@@ -39,48 +36,12 @@ export const ProfileInfoForm = ({ profile }: { profile: ProfileDetails }) => {
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
const handleImageUpload = async (file: File) => {
|
||||
async function handleImageUpload(file: File) {
|
||||
try {
|
||||
// Create FormData and append file
|
||||
const formData = new FormData();
|
||||
formData.append("file", file);
|
||||
const mediaUrl = await api.uploadStoreSubmissionMedia(file);
|
||||
|
||||
// Get auth token
|
||||
if (!supabase) {
|
||||
throw new Error("Supabase client not initialized");
|
||||
}
|
||||
|
||||
const {
|
||||
data: { session },
|
||||
} = await supabase.auth.getSession();
|
||||
const token = session?.access_token;
|
||||
|
||||
if (!token) {
|
||||
throw new Error("No authentication token found");
|
||||
}
|
||||
|
||||
// Make upload request
|
||||
const response = await fetch(
|
||||
`${process.env.NEXT_PUBLIC_AGPT_SERVER_URL}/store/submissions/media`,
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
body: formData,
|
||||
},
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Upload failed: ${response.statusText}`);
|
||||
}
|
||||
|
||||
// Get media URL from response
|
||||
const mediaUrl = await response.json();
|
||||
|
||||
// Update profile with new avatar URL
|
||||
const updatedProfile = {
|
||||
...profileData,
|
||||
avatar_url: mediaUrl,
|
||||
@@ -91,7 +52,7 @@ export const ProfileInfoForm = ({ profile }: { profile: ProfileDetails }) => {
|
||||
} catch (error) {
|
||||
console.error("Error uploading image:", error);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="w-full min-w-[800px] px-4 sm:px-8">
|
||||
@@ -261,4 +222,4 @@ export const ProfileInfoForm = ({ profile }: { profile: ProfileDetails }) => {
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,296 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
|
||||
import { APIKey, APIKeyPermission } from "@/lib/autogpt-server-api/types";
|
||||
|
||||
import { LuCopy } from "react-icons/lu";
|
||||
import { Loader2, MoreVertical } from "lucide-react";
|
||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||
import { useToast } from "@/components/ui/use-toast";
|
||||
import {
|
||||
Card,
|
||||
CardContent,
|
||||
CardDescription,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
} from "@/components/ui/card";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogTrigger,
|
||||
} from "@/components/ui/dialog";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Checkbox } from "@/components/ui/checkbox";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
|
||||
export function APIKeysSection() {
|
||||
const [apiKeys, setApiKeys] = useState<APIKey[]>([]);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [isCreateOpen, setIsCreateOpen] = useState(false);
|
||||
const [isKeyDialogOpen, setIsKeyDialogOpen] = useState(false);
|
||||
const [newKeyName, setNewKeyName] = useState("");
|
||||
const [newKeyDescription, setNewKeyDescription] = useState("");
|
||||
const [newApiKey, setNewApiKey] = useState("");
|
||||
const [selectedPermissions, setSelectedPermissions] = useState<
|
||||
APIKeyPermission[]
|
||||
>([]);
|
||||
const { toast } = useToast();
|
||||
const api = useBackendAPI();
|
||||
|
||||
useEffect(() => {
|
||||
loadAPIKeys();
|
||||
}, []);
|
||||
|
||||
const loadAPIKeys = async () => {
|
||||
setIsLoading(true);
|
||||
try {
|
||||
const keys = await api.listAPIKeys();
|
||||
setApiKeys(keys.filter((key) => key.status === "ACTIVE"));
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleCreateKey = async () => {
|
||||
try {
|
||||
const response = await api.createAPIKey(
|
||||
newKeyName,
|
||||
selectedPermissions,
|
||||
newKeyDescription,
|
||||
);
|
||||
|
||||
setNewApiKey(response.plain_text_key);
|
||||
setIsCreateOpen(false);
|
||||
setIsKeyDialogOpen(true);
|
||||
loadAPIKeys();
|
||||
} catch {
|
||||
toast({
|
||||
title: "Error",
|
||||
description: "Failed to create AutoGPT Platform API key",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const handleCopyKey = () => {
|
||||
navigator.clipboard.writeText(newApiKey);
|
||||
toast({
|
||||
title: "Copied",
|
||||
description: "AutoGPT Platform API key copied to clipboard",
|
||||
});
|
||||
};
|
||||
|
||||
const handleRevokeKey = async (keyId: string) => {
|
||||
try {
|
||||
await api.revokeAPIKey(keyId);
|
||||
toast({
|
||||
title: "Success",
|
||||
description: "AutoGPT Platform API key revoked successfully",
|
||||
});
|
||||
loadAPIKeys();
|
||||
} catch {
|
||||
toast({
|
||||
title: "Error",
|
||||
description: "Failed to revoke AutoGPT Platform API key",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>AutoGPT Platform API Keys</CardTitle>
|
||||
<CardDescription>
|
||||
Manage your AutoGPT Platform API keys for programmatic access
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="mb-4 flex justify-end">
|
||||
<Dialog open={isCreateOpen} onOpenChange={setIsCreateOpen}>
|
||||
<DialogTrigger asChild>
|
||||
<Button>Create Key</Button>
|
||||
</DialogTrigger>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>Create New API Key</DialogTitle>
|
||||
<DialogDescription>
|
||||
Create a new AutoGPT Platform API key
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<div className="grid gap-4 py-4">
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="name">Name</Label>
|
||||
<Input
|
||||
id="name"
|
||||
value={newKeyName}
|
||||
onChange={(e) => setNewKeyName(e.target.value)}
|
||||
placeholder="My AutoGPT Platform API Key"
|
||||
/>
|
||||
</div>
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="description">Description (Optional)</Label>
|
||||
<Input
|
||||
id="description"
|
||||
value={newKeyDescription}
|
||||
onChange={(e) => setNewKeyDescription(e.target.value)}
|
||||
placeholder="Used for..."
|
||||
/>
|
||||
</div>
|
||||
<div className="grid gap-2">
|
||||
<Label>Permissions</Label>
|
||||
{Object.values(APIKeyPermission).map((permission) => (
|
||||
<div
|
||||
className="flex items-center space-x-2"
|
||||
key={permission}
|
||||
>
|
||||
<Checkbox
|
||||
id={permission}
|
||||
checked={selectedPermissions.includes(permission)}
|
||||
onCheckedChange={(checked: boolean) => {
|
||||
setSelectedPermissions(
|
||||
checked
|
||||
? [...selectedPermissions, permission]
|
||||
: selectedPermissions.filter(
|
||||
(p) => p !== permission,
|
||||
),
|
||||
);
|
||||
}}
|
||||
/>
|
||||
<Label htmlFor={permission}>{permission}</Label>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
<DialogFooter>
|
||||
<Button
|
||||
variant="outline"
|
||||
onClick={() => setIsCreateOpen(false)}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button onClick={handleCreateKey}>Create</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
|
||||
<Dialog open={isKeyDialogOpen} onOpenChange={setIsKeyDialogOpen}>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>AutoGPT Platform API Key Created</DialogTitle>
|
||||
<DialogDescription>
|
||||
Please copy your AutoGPT API key now. You won't be able
|
||||
to see it again!
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<div className="flex items-center space-x-2">
|
||||
<code className="flex-1 rounded-md bg-secondary p-2 text-sm">
|
||||
{newApiKey}
|
||||
</code>
|
||||
<Button size="icon" variant="outline" onClick={handleCopyKey}>
|
||||
<LuCopy className="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
<DialogFooter>
|
||||
<Button onClick={() => setIsKeyDialogOpen(false)}>Close</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
</div>
|
||||
|
||||
{isLoading ? (
|
||||
<div className="flex justify-center p-4">
|
||||
<Loader2 className="h-6 w-6 animate-spin" />
|
||||
</div>
|
||||
) : (
|
||||
apiKeys.length > 0 && (
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Name</TableHead>
|
||||
<TableHead>API Key</TableHead>
|
||||
<TableHead>Status</TableHead>
|
||||
<TableHead>Created</TableHead>
|
||||
<TableHead>Last Used</TableHead>
|
||||
<TableHead></TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{apiKeys.map((key) => (
|
||||
<TableRow key={key.id}>
|
||||
<TableCell>{key.name}</TableCell>
|
||||
<TableCell>
|
||||
<div className="rounded-md border p-1 px-2 text-xs">
|
||||
{`${key.prefix}******************${key.postfix}`}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<Badge
|
||||
variant={
|
||||
key.status === "ACTIVE" ? "default" : "destructive"
|
||||
}
|
||||
className={
|
||||
key.status === "ACTIVE"
|
||||
? "border-green-600 bg-green-100 text-green-800"
|
||||
: "border-red-600 bg-red-100 text-red-800"
|
||||
}
|
||||
>
|
||||
{key.status}
|
||||
</Badge>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{new Date(key.created_at).toLocaleDateString()}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{key.last_used_at
|
||||
? new Date(key.last_used_at).toLocaleDateString()
|
||||
: "Never"}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button variant="ghost" size="sm">
|
||||
<MoreVertical className="h-4 w-4" />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem
|
||||
className="text-destructive"
|
||||
onClick={() => handleRevokeKey(key.id)}
|
||||
>
|
||||
Revoke
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
)
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,163 @@
|
||||
import { FC } from "react";
|
||||
import { z } from "zod";
|
||||
import { useForm } from "react-hook-form";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
import {
|
||||
Form,
|
||||
FormControl,
|
||||
FormDescription,
|
||||
FormField,
|
||||
FormItem,
|
||||
FormLabel,
|
||||
FormMessage,
|
||||
} from "@/components/ui/form";
|
||||
import useCredentials from "@/hooks/useCredentials";
|
||||
import {
|
||||
BlockIOCredentialsSubSchema,
|
||||
CredentialsMetaInput,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
|
||||
export const APIKeyCredentialsModal: FC<{
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
|
||||
siblingInputs?: Record<string, any>;
|
||||
}> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
|
||||
const credentials = useCredentials(schema, siblingInputs);
|
||||
|
||||
const formSchema = z.object({
|
||||
apiKey: z.string().min(1, "API Key is required"),
|
||||
title: z.string().min(1, "Name is required"),
|
||||
expiresAt: z.string().optional(),
|
||||
});
|
||||
|
||||
const form = useForm<z.infer<typeof formSchema>>({
|
||||
resolver: zodResolver(formSchema),
|
||||
defaultValues: {
|
||||
apiKey: "",
|
||||
title: "",
|
||||
expiresAt: "",
|
||||
},
|
||||
});
|
||||
|
||||
if (!credentials || credentials.isLoading || !credentials.supportsApiKey) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const { provider, providerName, createAPIKeyCredentials } = credentials;
|
||||
|
||||
async function onSubmit(values: z.infer<typeof formSchema>) {
|
||||
const expiresAt = values.expiresAt
|
||||
? new Date(values.expiresAt).getTime() / 1000
|
||||
: undefined;
|
||||
const newCredentials = await createAPIKeyCredentials({
|
||||
api_key: values.apiKey,
|
||||
title: values.title,
|
||||
expires_at: expiresAt,
|
||||
});
|
||||
onCredentialsCreate({
|
||||
provider,
|
||||
id: newCredentials.id,
|
||||
type: "api_key",
|
||||
title: newCredentials.title,
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) onClose();
|
||||
}}
|
||||
>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>Add new API key for {providerName}</DialogTitle>
|
||||
{schema.description && (
|
||||
<DialogDescription>{schema.description}</DialogDescription>
|
||||
)}
|
||||
</DialogHeader>
|
||||
|
||||
<Form {...form}>
|
||||
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="apiKey"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>API Key</FormLabel>
|
||||
{schema.credentials_scopes && (
|
||||
<FormDescription>
|
||||
Required scope(s) for this block:{" "}
|
||||
{schema.credentials_scopes?.map((s, i, a) => (
|
||||
<span key={i}>
|
||||
<code>{s}</code>
|
||||
{i < a.length - 1 && ", "}
|
||||
</span>
|
||||
))}
|
||||
</FormDescription>
|
||||
)}
|
||||
<FormControl>
|
||||
<Input
|
||||
type="password"
|
||||
placeholder="Enter API key..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="title"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Name</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="text"
|
||||
placeholder="Enter a name for this API key..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="expiresAt"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Expiration Date (Optional)</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="datetime-local"
|
||||
placeholder="Select expiration date..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<Button type="submit" className="w-full">
|
||||
Save & use this API key
|
||||
</Button>
|
||||
</form>
|
||||
</Form>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
@@ -1,12 +1,8 @@
|
||||
import { FC, useEffect, useMemo, useState } from "react";
|
||||
import { z } from "zod";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useForm } from "react-hook-form";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import SchemaTooltip from "@/components/SchemaTooltip";
|
||||
import useCredentials from "@/hooks/useCredentials";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { NotionLogoIcon } from "@radix-ui/react-icons";
|
||||
import {
|
||||
FaDiscord,
|
||||
@@ -23,22 +19,6 @@ import {
|
||||
CredentialsProviderName,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
import { IconKey, IconKeyPlus, IconUserPlus } from "@/components/ui/icons";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
import {
|
||||
Form,
|
||||
FormControl,
|
||||
FormDescription,
|
||||
FormField,
|
||||
FormItem,
|
||||
FormLabel,
|
||||
FormMessage,
|
||||
} from "@/components/ui/form";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
@@ -48,6 +28,11 @@ import {
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||
import { APIKeyCredentialsModal } from "./api-key-credentials-modal";
|
||||
import { UserPasswordCredentialsModal } from "./user-password-credentials-modal";
|
||||
import { HostScopedCredentialsModal } from "./host-scoped-credentials-modal";
|
||||
import { OAuth2FlowWaitingModal } from "./oauth2-flow-waiting-modal";
|
||||
import { getHostFromUrl } from "@/lib/utils/url";
|
||||
|
||||
const fallbackIcon = FaKey;
|
||||
|
||||
@@ -63,6 +48,7 @@ export const providerIcons: Record<
|
||||
github: FaGithub,
|
||||
google: FaGoogle,
|
||||
groq: fallbackIcon,
|
||||
http: fallbackIcon,
|
||||
notion: NotionLogoIcon,
|
||||
nvidia: fallbackIcon,
|
||||
discord: FaDiscord,
|
||||
@@ -129,6 +115,8 @@ export const CredentialsInput: FC<{
|
||||
isUserPasswordCredentialsModalOpen,
|
||||
setUserPasswordCredentialsModalOpen,
|
||||
] = useState(false);
|
||||
const [isHostScopedCredentialsModalOpen, setHostScopedCredentialsModalOpen] =
|
||||
useState(false);
|
||||
const [isOAuth2FlowInProgress, setOAuth2FlowInProgress] = useState(false);
|
||||
const [oAuthPopupController, setOAuthPopupController] =
|
||||
useState<AbortController | null>(null);
|
||||
@@ -148,13 +136,27 @@ export const CredentialsInput: FC<{
|
||||
}
|
||||
}, [credentials, selectedCredentials, onSelectCredentials]);
|
||||
|
||||
const singleCredential = useMemo(() => {
|
||||
if (!credentials || !("savedCredentials" in credentials)) return null;
|
||||
const { hasRelevantCredentials, singleCredential } = useMemo(() => {
|
||||
if (!credentials || !("savedCredentials" in credentials)) {
|
||||
return {
|
||||
hasRelevantCredentials: false,
|
||||
singleCredential: null,
|
||||
};
|
||||
}
|
||||
|
||||
if (credentials.savedCredentials.length === 1)
|
||||
return credentials.savedCredentials[0];
|
||||
// Simple logic: if we have any saved credentials, we have relevant credentials
|
||||
const hasRelevant = credentials.savedCredentials.length > 0;
|
||||
|
||||
return null;
|
||||
// Auto-select single credential if only one exists
|
||||
const single =
|
||||
credentials.savedCredentials.length === 1
|
||||
? credentials.savedCredentials[0]
|
||||
: null;
|
||||
|
||||
return {
|
||||
hasRelevantCredentials: hasRelevant,
|
||||
singleCredential: single,
|
||||
};
|
||||
}, [credentials]);
|
||||
|
||||
// If only 1 credential is available, auto-select it and hide this input
|
||||
@@ -178,6 +180,7 @@ export const CredentialsInput: FC<{
|
||||
supportsApiKey,
|
||||
supportsOAuth2,
|
||||
supportsUserPassword,
|
||||
supportsHostScoped,
|
||||
savedCredentials,
|
||||
oAuthCallback,
|
||||
} = credentials;
|
||||
@@ -271,7 +274,7 @@ export const CredentialsInput: FC<{
|
||||
);
|
||||
}
|
||||
|
||||
const ProviderIcon = providerIcons[provider];
|
||||
const ProviderIcon = providerIcons[provider] || fallbackIcon;
|
||||
const modals = (
|
||||
<>
|
||||
{supportsApiKey && (
|
||||
@@ -305,6 +308,18 @@ export const CredentialsInput: FC<{
|
||||
siblingInputs={siblingInputs}
|
||||
/>
|
||||
)}
|
||||
{supportsHostScoped && (
|
||||
<HostScopedCredentialsModal
|
||||
schema={schema}
|
||||
open={isHostScopedCredentialsModalOpen}
|
||||
onClose={() => setHostScopedCredentialsModalOpen(false)}
|
||||
onCredentialsCreate={(creds) => {
|
||||
onSelectCredentials(creds);
|
||||
setHostScopedCredentialsModalOpen(false);
|
||||
}}
|
||||
siblingInputs={siblingInputs}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -317,8 +332,8 @@ export const CredentialsInput: FC<{
|
||||
</div>
|
||||
);
|
||||
|
||||
// No saved credentials yet
|
||||
if (savedCredentials.length === 0) {
|
||||
// Show credentials creation UI when no relevant credentials exist
|
||||
if (!hasRelevantCredentials) {
|
||||
return (
|
||||
<div>
|
||||
{fieldHeader}
|
||||
@@ -342,6 +357,12 @@ export const CredentialsInput: FC<{
|
||||
Enter username and password
|
||||
</Button>
|
||||
)}
|
||||
{supportsHostScoped && credentials.discriminatorValue && (
|
||||
<Button onClick={() => setHostScopedCredentialsModalOpen(true)}>
|
||||
<ProviderIcon className="mr-2 h-4 w-4" />
|
||||
{`Enter sensitive headers for ${getHostFromUrl(credentials.discriminatorValue)}`}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
{modals}
|
||||
{oAuthError && (
|
||||
@@ -358,6 +379,12 @@ export const CredentialsInput: FC<{
|
||||
} else if (newValue === "add-api-key") {
|
||||
// Open API key dialog
|
||||
setAPICredentialsModalOpen(true);
|
||||
} else if (newValue === "add-user-password") {
|
||||
// Open user password dialog
|
||||
setUserPasswordCredentialsModalOpen(true);
|
||||
} else if (newValue === "add-host-scoped") {
|
||||
// Open host-scoped credentials dialog
|
||||
setHostScopedCredentialsModalOpen(true);
|
||||
} else {
|
||||
const selectedCreds = savedCredentials.find((c) => c.id == newValue)!;
|
||||
|
||||
@@ -406,6 +433,15 @@ export const CredentialsInput: FC<{
|
||||
{credentials.title}
|
||||
</SelectItem>
|
||||
))}
|
||||
{savedCredentials
|
||||
.filter((c) => c.type == "host_scoped")
|
||||
.map((credentials, index) => (
|
||||
<SelectItem key={index} value={credentials.id}>
|
||||
<ProviderIcon className="mr-2 inline h-4 w-4" />
|
||||
<IconKey className="mr-1.5 inline" />
|
||||
{credentials.title}
|
||||
</SelectItem>
|
||||
))}
|
||||
<SelectSeparator />
|
||||
{supportsOAuth2 && (
|
||||
<SelectItem value="sign-in">
|
||||
@@ -425,6 +461,12 @@ export const CredentialsInput: FC<{
|
||||
Add new user password
|
||||
</SelectItem>
|
||||
)}
|
||||
{supportsHostScoped && (
|
||||
<SelectItem value="add-host-scoped">
|
||||
<IconKey className="mr-1.5 inline" />
|
||||
Add host-scoped headers
|
||||
</SelectItem>
|
||||
)}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
{modals}
|
||||
@@ -434,291 +476,3 @@ export const CredentialsInput: FC<{
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export const APIKeyCredentialsModal: FC<{
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
|
||||
siblingInputs?: Record<string, any>;
|
||||
}> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
|
||||
const credentials = useCredentials(schema, siblingInputs);
|
||||
|
||||
const formSchema = z.object({
|
||||
apiKey: z.string().min(1, "API Key is required"),
|
||||
title: z.string().min(1, "Name is required"),
|
||||
expiresAt: z.string().optional(),
|
||||
});
|
||||
|
||||
const form = useForm<z.infer<typeof formSchema>>({
|
||||
resolver: zodResolver(formSchema),
|
||||
defaultValues: {
|
||||
apiKey: "",
|
||||
title: "",
|
||||
expiresAt: "",
|
||||
},
|
||||
});
|
||||
|
||||
if (!credentials || credentials.isLoading || !credentials.supportsApiKey) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const { provider, providerName, createAPIKeyCredentials } = credentials;
|
||||
|
||||
async function onSubmit(values: z.infer<typeof formSchema>) {
|
||||
const expiresAt = values.expiresAt
|
||||
? new Date(values.expiresAt).getTime() / 1000
|
||||
: undefined;
|
||||
const newCredentials = await createAPIKeyCredentials({
|
||||
api_key: values.apiKey,
|
||||
title: values.title,
|
||||
expires_at: expiresAt,
|
||||
});
|
||||
onCredentialsCreate({
|
||||
provider,
|
||||
id: newCredentials.id,
|
||||
type: "api_key",
|
||||
title: newCredentials.title,
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) onClose();
|
||||
}}
|
||||
>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>Add new API key for {providerName}</DialogTitle>
|
||||
{schema.description && (
|
||||
<DialogDescription>{schema.description}</DialogDescription>
|
||||
)}
|
||||
</DialogHeader>
|
||||
|
||||
<Form {...form}>
|
||||
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="apiKey"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>API Key</FormLabel>
|
||||
{schema.credentials_scopes && (
|
||||
<FormDescription>
|
||||
Required scope(s) for this block:{" "}
|
||||
{schema.credentials_scopes?.map((s, i, a) => (
|
||||
<span key={i}>
|
||||
<code>{s}</code>
|
||||
{i < a.length - 1 && ", "}
|
||||
</span>
|
||||
))}
|
||||
</FormDescription>
|
||||
)}
|
||||
<FormControl>
|
||||
<Input
|
||||
type="password"
|
||||
placeholder="Enter API key..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="title"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Name</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="text"
|
||||
placeholder="Enter a name for this API key..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="expiresAt"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Expiration Date (Optional)</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="datetime-local"
|
||||
placeholder="Select expiration date..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<Button type="submit" className="w-full">
|
||||
Save & use this API key
|
||||
</Button>
|
||||
</form>
|
||||
</Form>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
|
||||
export const UserPasswordCredentialsModal: FC<{
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
|
||||
siblingInputs?: Record<string, any>;
|
||||
}> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
|
||||
const credentials = useCredentials(schema, siblingInputs);
|
||||
|
||||
const formSchema = z.object({
|
||||
username: z.string().min(1, "Username is required"),
|
||||
password: z.string().min(1, "Password is required"),
|
||||
title: z.string().min(1, "Name is required"),
|
||||
});
|
||||
|
||||
const form = useForm<z.infer<typeof formSchema>>({
|
||||
resolver: zodResolver(formSchema),
|
||||
defaultValues: {
|
||||
username: "",
|
||||
password: "",
|
||||
title: "",
|
||||
},
|
||||
});
|
||||
|
||||
if (
|
||||
!credentials ||
|
||||
credentials.isLoading ||
|
||||
!credentials.supportsUserPassword
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const { provider, providerName, createUserPasswordCredentials } = credentials;
|
||||
|
||||
async function onSubmit(values: z.infer<typeof formSchema>) {
|
||||
const newCredentials = await createUserPasswordCredentials({
|
||||
username: values.username,
|
||||
password: values.password,
|
||||
title: values.title,
|
||||
});
|
||||
onCredentialsCreate({
|
||||
provider,
|
||||
id: newCredentials.id,
|
||||
type: "user_password",
|
||||
title: newCredentials.title,
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) onClose();
|
||||
}}
|
||||
>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>
|
||||
Add new username & password for {providerName}
|
||||
</DialogTitle>
|
||||
</DialogHeader>
|
||||
<Form {...form}>
|
||||
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="username"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Username</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="text"
|
||||
placeholder="Enter username..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="password"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Password</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="password"
|
||||
placeholder="Enter password..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="title"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Name</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="text"
|
||||
placeholder="Enter a name for this user login..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<Button type="submit" className="w-full">
|
||||
Save & use this user login
|
||||
</Button>
|
||||
</form>
|
||||
</Form>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
|
||||
export const OAuth2FlowWaitingModal: FC<{
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
providerName: string;
|
||||
}> = ({ open, onClose, providerName }) => {
|
||||
return (
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) onClose();
|
||||
}}
|
||||
>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>
|
||||
Waiting on {providerName} sign-in process...
|
||||
</DialogTitle>
|
||||
<DialogDescription>
|
||||
Complete the sign-in process in the pop-up window.
|
||||
<br />
|
||||
Closing this dialog will cancel the sign-in process.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -6,10 +6,12 @@ import {
|
||||
CredentialsDeleteResponse,
|
||||
CredentialsMetaResponse,
|
||||
CredentialsProviderName,
|
||||
HostScopedCredentials,
|
||||
PROVIDER_NAMES,
|
||||
UserPasswordCredentials,
|
||||
} from "@/lib/autogpt-server-api";
|
||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||
import { useToastOnFail } from "@/components/ui/use-toast";
|
||||
|
||||
// Get keys from CredentialsProviderName type
|
||||
const CREDENTIALS_PROVIDER_NAMES = Object.values(
|
||||
@@ -30,6 +32,7 @@ const providerDisplayNames: Record<CredentialsProviderName, string> = {
|
||||
google: "Google",
|
||||
google_maps: "Google Maps",
|
||||
groq: "Groq",
|
||||
http: "HTTP",
|
||||
hubspot: "Hubspot",
|
||||
ideogram: "Ideogram",
|
||||
jina: "Jina",
|
||||
@@ -68,6 +71,11 @@ type UserPasswordCredentialsCreatable = Omit<
|
||||
"id" | "provider" | "type"
|
||||
>;
|
||||
|
||||
type HostScopedCredentialsCreatable = Omit<
|
||||
HostScopedCredentials,
|
||||
"id" | "provider" | "type"
|
||||
>;
|
||||
|
||||
export type CredentialsProviderData = {
|
||||
provider: CredentialsProviderName;
|
||||
providerName: string;
|
||||
@@ -82,6 +90,9 @@ export type CredentialsProviderData = {
|
||||
createUserPasswordCredentials: (
|
||||
credentials: UserPasswordCredentialsCreatable,
|
||||
) => Promise<CredentialsMetaResponse>;
|
||||
createHostScopedCredentials: (
|
||||
credentials: HostScopedCredentialsCreatable,
|
||||
) => Promise<CredentialsMetaResponse>;
|
||||
deleteCredentials: (
|
||||
id: string,
|
||||
force?: boolean,
|
||||
@@ -106,6 +117,7 @@ export default function CredentialsProvider({
|
||||
useState<CredentialsProvidersContextType | null>(null);
|
||||
const { isLoggedIn } = useSupabase();
|
||||
const api = useBackendAPI();
|
||||
const onFailToast = useToastOnFail();
|
||||
|
||||
const addCredentials = useCallback(
|
||||
(
|
||||
@@ -134,11 +146,16 @@ export default function CredentialsProvider({
|
||||
code: string,
|
||||
state_token: string,
|
||||
): Promise<CredentialsMetaResponse> => {
|
||||
const credsMeta = await api.oAuthCallback(provider, code, state_token);
|
||||
addCredentials(provider, credsMeta);
|
||||
return credsMeta;
|
||||
try {
|
||||
const credsMeta = await api.oAuthCallback(provider, code, state_token);
|
||||
addCredentials(provider, credsMeta);
|
||||
return credsMeta;
|
||||
} catch (error) {
|
||||
onFailToast("complete OAuth authentication")(error);
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[api, addCredentials],
|
||||
[api, addCredentials, onFailToast],
|
||||
);
|
||||
|
||||
/** Wraps `BackendAPI.createAPIKeyCredentials`, and adds the result to the internal credentials store. */
|
||||
@@ -147,14 +164,19 @@ export default function CredentialsProvider({
|
||||
provider: CredentialsProviderName,
|
||||
credentials: APIKeyCredentialsCreatable,
|
||||
): Promise<CredentialsMetaResponse> => {
|
||||
const credsMeta = await api.createAPIKeyCredentials({
|
||||
provider,
|
||||
...credentials,
|
||||
});
|
||||
addCredentials(provider, credsMeta);
|
||||
return credsMeta;
|
||||
try {
|
||||
const credsMeta = await api.createAPIKeyCredentials({
|
||||
provider,
|
||||
...credentials,
|
||||
});
|
||||
addCredentials(provider, credsMeta);
|
||||
return credsMeta;
|
||||
} catch (error) {
|
||||
onFailToast("create API key credentials")(error);
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[api, addCredentials],
|
||||
[api, addCredentials, onFailToast],
|
||||
);
|
||||
|
||||
/** Wraps `BackendAPI.createUserPasswordCredentials`, and adds the result to the internal credentials store. */
|
||||
@@ -163,14 +185,40 @@ export default function CredentialsProvider({
|
||||
provider: CredentialsProviderName,
|
||||
credentials: UserPasswordCredentialsCreatable,
|
||||
): Promise<CredentialsMetaResponse> => {
|
||||
const credsMeta = await api.createUserPasswordCredentials({
|
||||
provider,
|
||||
...credentials,
|
||||
});
|
||||
addCredentials(provider, credsMeta);
|
||||
return credsMeta;
|
||||
try {
|
||||
const credsMeta = await api.createUserPasswordCredentials({
|
||||
provider,
|
||||
...credentials,
|
||||
});
|
||||
addCredentials(provider, credsMeta);
|
||||
return credsMeta;
|
||||
} catch (error) {
|
||||
onFailToast("create user/password credentials")(error);
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[api, addCredentials],
|
||||
[api, addCredentials, onFailToast],
|
||||
);
|
||||
|
||||
/** Wraps `BackendAPI.createHostScopedCredentials`, and adds the result to the internal credentials store. */
|
||||
const createHostScopedCredentials = useCallback(
|
||||
async (
|
||||
provider: CredentialsProviderName,
|
||||
credentials: HostScopedCredentialsCreatable,
|
||||
): Promise<CredentialsMetaResponse> => {
|
||||
try {
|
||||
const credsMeta = await api.createHostScopedCredentials({
|
||||
provider,
|
||||
...credentials,
|
||||
});
|
||||
addCredentials(provider, credsMeta);
|
||||
return credsMeta;
|
||||
} catch (error) {
|
||||
onFailToast("create host-scoped credentials")(error);
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[api, addCredentials, onFailToast],
|
||||
);
|
||||
|
||||
/** Wraps `BackendAPI.deleteCredentials`, and removes the credentials from the internal store. */
|
||||
@@ -182,26 +230,31 @@ export default function CredentialsProvider({
|
||||
): Promise<
|
||||
CredentialsDeleteResponse | CredentialsDeleteNeedConfirmationResponse
|
||||
> => {
|
||||
const result = await api.deleteCredentials(provider, id, force);
|
||||
if (!result.deleted) {
|
||||
return result;
|
||||
}
|
||||
setProviders((prev) => {
|
||||
if (!prev || !prev[provider]) return prev;
|
||||
try {
|
||||
const result = await api.deleteCredentials(provider, id, force);
|
||||
if (!result.deleted) {
|
||||
return result;
|
||||
}
|
||||
setProviders((prev) => {
|
||||
if (!prev || !prev[provider]) return prev;
|
||||
|
||||
return {
|
||||
...prev,
|
||||
[provider]: {
|
||||
...prev[provider],
|
||||
savedCredentials: prev[provider].savedCredentials.filter(
|
||||
(cred) => cred.id !== id,
|
||||
),
|
||||
},
|
||||
};
|
||||
});
|
||||
return result;
|
||||
return {
|
||||
...prev,
|
||||
[provider]: {
|
||||
...prev[provider],
|
||||
savedCredentials: prev[provider].savedCredentials.filter(
|
||||
(cred) => cred.id !== id,
|
||||
),
|
||||
},
|
||||
};
|
||||
});
|
||||
return result;
|
||||
} catch (error) {
|
||||
onFailToast("delete credentials")(error);
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[api],
|
||||
[api, onFailToast],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
@@ -210,47 +263,54 @@ export default function CredentialsProvider({
|
||||
return;
|
||||
}
|
||||
|
||||
api.listCredentials().then((response) => {
|
||||
const credentialsByProvider = response.reduce(
|
||||
(acc, cred) => {
|
||||
if (!acc[cred.provider]) {
|
||||
acc[cred.provider] = [];
|
||||
}
|
||||
acc[cred.provider].push(cred);
|
||||
return acc;
|
||||
},
|
||||
{} as Record<CredentialsProviderName, CredentialsMetaResponse[]>,
|
||||
);
|
||||
api
|
||||
.listCredentials()
|
||||
.then((response) => {
|
||||
const credentialsByProvider = response.reduce(
|
||||
(acc, cred) => {
|
||||
if (!acc[cred.provider]) {
|
||||
acc[cred.provider] = [];
|
||||
}
|
||||
acc[cred.provider].push(cred);
|
||||
return acc;
|
||||
},
|
||||
{} as Record<CredentialsProviderName, CredentialsMetaResponse[]>,
|
||||
);
|
||||
|
||||
setProviders((prev) => ({
|
||||
...prev,
|
||||
...Object.fromEntries(
|
||||
CREDENTIALS_PROVIDER_NAMES.map((provider) => [
|
||||
provider,
|
||||
{
|
||||
setProviders((prev) => ({
|
||||
...prev,
|
||||
...Object.fromEntries(
|
||||
CREDENTIALS_PROVIDER_NAMES.map((provider) => [
|
||||
provider,
|
||||
providerName: providerDisplayNames[provider],
|
||||
savedCredentials: credentialsByProvider[provider] ?? [],
|
||||
oAuthCallback: (code: string, state_token: string) =>
|
||||
oAuthCallback(provider, code, state_token),
|
||||
createAPIKeyCredentials: (
|
||||
credentials: APIKeyCredentialsCreatable,
|
||||
) => createAPIKeyCredentials(provider, credentials),
|
||||
createUserPasswordCredentials: (
|
||||
credentials: UserPasswordCredentialsCreatable,
|
||||
) => createUserPasswordCredentials(provider, credentials),
|
||||
deleteCredentials: (id: string, force: boolean = false) =>
|
||||
deleteCredentials(provider, id, force),
|
||||
} satisfies CredentialsProviderData,
|
||||
]),
|
||||
),
|
||||
}));
|
||||
});
|
||||
{
|
||||
provider,
|
||||
providerName: providerDisplayNames[provider],
|
||||
savedCredentials: credentialsByProvider[provider] ?? [],
|
||||
oAuthCallback: (code: string, state_token: string) =>
|
||||
oAuthCallback(provider, code, state_token),
|
||||
createAPIKeyCredentials: (
|
||||
credentials: APIKeyCredentialsCreatable,
|
||||
) => createAPIKeyCredentials(provider, credentials),
|
||||
createUserPasswordCredentials: (
|
||||
credentials: UserPasswordCredentialsCreatable,
|
||||
) => createUserPasswordCredentials(provider, credentials),
|
||||
createHostScopedCredentials: (
|
||||
credentials: HostScopedCredentialsCreatable,
|
||||
) => createHostScopedCredentials(provider, credentials),
|
||||
deleteCredentials: (id: string, force: boolean = false) =>
|
||||
deleteCredentials(provider, id, force),
|
||||
} satisfies CredentialsProviderData,
|
||||
]),
|
||||
),
|
||||
}));
|
||||
})
|
||||
.catch(onFailToast("load credentials"));
|
||||
}, [
|
||||
api,
|
||||
isLoggedIn,
|
||||
createAPIKeyCredentials,
|
||||
createUserPasswordCredentials,
|
||||
createHostScopedCredentials,
|
||||
deleteCredentials,
|
||||
oAuthCallback,
|
||||
]);
|
||||
|
||||
@@ -0,0 +1,235 @@
|
||||
import { FC, useEffect, useState } from "react";
|
||||
import { z } from "zod";
|
||||
import { useForm } from "react-hook-form";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
import {
|
||||
Form,
|
||||
FormControl,
|
||||
FormDescription,
|
||||
FormField,
|
||||
FormItem,
|
||||
FormLabel,
|
||||
FormMessage,
|
||||
} from "@/components/ui/form";
|
||||
import useCredentials from "@/hooks/useCredentials";
|
||||
import {
|
||||
BlockIOCredentialsSubSchema,
|
||||
CredentialsMetaInput,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
import { getHostFromUrl } from "@/lib/utils/url";
|
||||
|
||||
export const HostScopedCredentialsModal: FC<{
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
|
||||
siblingInputs?: Record<string, any>;
|
||||
}> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
|
||||
const credentials = useCredentials(schema, siblingInputs);
|
||||
|
||||
// Get current host from siblingInputs or discriminator_values
|
||||
const currentUrl = credentials?.discriminatorValue;
|
||||
const currentHost = currentUrl ? getHostFromUrl(currentUrl) : "";
|
||||
|
||||
const formSchema = z.object({
|
||||
host: z.string().min(1, "Host is required"),
|
||||
title: z.string().optional(),
|
||||
headers: z.record(z.string()).optional(),
|
||||
});
|
||||
|
||||
const form = useForm<z.infer<typeof formSchema>>({
|
||||
resolver: zodResolver(formSchema),
|
||||
defaultValues: {
|
||||
host: currentHost || "",
|
||||
title: currentHost || "Manual Entry",
|
||||
headers: {},
|
||||
},
|
||||
});
|
||||
|
||||
const [headerPairs, setHeaderPairs] = useState<
|
||||
Array<{ key: string; value: string }>
|
||||
>([{ key: "", value: "" }]);
|
||||
|
||||
// Update form values when siblingInputs change
|
||||
useEffect(() => {
|
||||
if (currentHost) {
|
||||
form.setValue("host", currentHost);
|
||||
form.setValue("title", currentHost);
|
||||
} else {
|
||||
// Reset to empty when no current host
|
||||
form.setValue("host", "");
|
||||
form.setValue("title", "Manual Entry");
|
||||
}
|
||||
}, [currentHost, form]);
|
||||
|
||||
if (
|
||||
!credentials ||
|
||||
credentials.isLoading ||
|
||||
!credentials.supportsHostScoped
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const { provider, providerName, createHostScopedCredentials } = credentials;
|
||||
|
||||
const addHeaderPair = () => {
|
||||
setHeaderPairs([...headerPairs, { key: "", value: "" }]);
|
||||
};
|
||||
|
||||
const removeHeaderPair = (index: number) => {
|
||||
if (headerPairs.length > 1) {
|
||||
setHeaderPairs(headerPairs.filter((_, i) => i !== index));
|
||||
}
|
||||
};
|
||||
|
||||
const updateHeaderPair = (
|
||||
index: number,
|
||||
field: "key" | "value",
|
||||
value: string,
|
||||
) => {
|
||||
const newPairs = [...headerPairs];
|
||||
newPairs[index][field] = value;
|
||||
setHeaderPairs(newPairs);
|
||||
};
|
||||
|
||||
async function onSubmit(values: z.infer<typeof formSchema>) {
|
||||
// Convert header pairs to object, filtering out empty pairs
|
||||
const headers = headerPairs.reduce(
|
||||
(acc, pair) => {
|
||||
if (pair.key.trim() && pair.value.trim()) {
|
||||
acc[pair.key.trim()] = pair.value.trim();
|
||||
}
|
||||
return acc;
|
||||
},
|
||||
{} as Record<string, string>,
|
||||
);
|
||||
|
||||
const newCredentials = await createHostScopedCredentials({
|
||||
host: values.host,
|
||||
title: currentHost || values.host,
|
||||
headers,
|
||||
});
|
||||
|
||||
onCredentialsCreate({
|
||||
provider,
|
||||
id: newCredentials.id,
|
||||
type: "host_scoped",
|
||||
title: newCredentials.title,
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) onClose();
|
||||
}}
|
||||
>
|
||||
<DialogContent className="max-h-[90vh] max-w-2xl overflow-y-auto">
|
||||
<DialogHeader>
|
||||
<DialogTitle>Add sensitive headers for {providerName}</DialogTitle>
|
||||
{schema.description && (
|
||||
<DialogDescription>{schema.description}</DialogDescription>
|
||||
)}
|
||||
</DialogHeader>
|
||||
|
||||
<Form {...form}>
|
||||
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="host"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Host Pattern</FormLabel>
|
||||
<FormDescription>
|
||||
{currentHost
|
||||
? "Auto-populated from the URL field. Headers will be applied to requests to this host."
|
||||
: "Enter the host/domain to match against request URLs (e.g., api.example.com)."}
|
||||
</FormDescription>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="text"
|
||||
readOnly={!!currentHost}
|
||||
placeholder={
|
||||
currentHost
|
||||
? undefined
|
||||
: "Enter host (e.g., api.example.com)"
|
||||
}
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
<div className="space-y-2">
|
||||
<FormLabel>Headers</FormLabel>
|
||||
<FormDescription>
|
||||
Add sensitive headers (like Authorization, X-API-Key) that
|
||||
should be automatically included in requests to the specified
|
||||
host.
|
||||
</FormDescription>
|
||||
|
||||
{headerPairs.map((pair, index) => (
|
||||
<div key={index} className="flex items-end gap-2">
|
||||
<div className="flex-1">
|
||||
<Input
|
||||
placeholder="Header name (e.g., Authorization)"
|
||||
value={pair.key}
|
||||
onChange={(e) =>
|
||||
updateHeaderPair(index, "key", e.target.value)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<Input
|
||||
type="password"
|
||||
placeholder="Header value (e.g., Bearer token123)"
|
||||
value={pair.value}
|
||||
onChange={(e) =>
|
||||
updateHeaderPair(index, "value", e.target.value)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => removeHeaderPair(index)}
|
||||
disabled={headerPairs.length === 1}
|
||||
>
|
||||
Remove
|
||||
</Button>
|
||||
</div>
|
||||
))}
|
||||
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={addHeaderPair}
|
||||
className="w-full"
|
||||
>
|
||||
Add Another Header
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<Button type="submit" className="w-full">
|
||||
Save & use these credentials
|
||||
</Button>
|
||||
</form>
|
||||
</Form>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,36 @@
|
||||
import { FC } from "react";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
|
||||
export const OAuth2FlowWaitingModal: FC<{
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
providerName: string;
|
||||
}> = ({ open, onClose, providerName }) => {
|
||||
return (
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) onClose();
|
||||
}}
|
||||
>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>
|
||||
Waiting on {providerName} sign-in process...
|
||||
</DialogTitle>
|
||||
<DialogDescription>
|
||||
Complete the sign-in process in the pop-up window.
|
||||
<br />
|
||||
Closing this dialog will cancel the sign-in process.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,149 @@
|
||||
import { FC } from "react";
|
||||
import { z } from "zod";
|
||||
import { useForm } from "react-hook-form";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
import {
|
||||
Form,
|
||||
FormControl,
|
||||
FormField,
|
||||
FormItem,
|
||||
FormLabel,
|
||||
FormMessage,
|
||||
} from "@/components/ui/form";
|
||||
import useCredentials from "@/hooks/useCredentials";
|
||||
import {
|
||||
BlockIOCredentialsSubSchema,
|
||||
CredentialsMetaInput,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
|
||||
export const UserPasswordCredentialsModal: FC<{
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
|
||||
siblingInputs?: Record<string, any>;
|
||||
}> = ({ schema, open, onClose, onCredentialsCreate, siblingInputs }) => {
|
||||
const credentials = useCredentials(schema, siblingInputs);
|
||||
|
||||
const formSchema = z.object({
|
||||
username: z.string().min(1, "Username is required"),
|
||||
password: z.string().min(1, "Password is required"),
|
||||
title: z.string().min(1, "Name is required"),
|
||||
});
|
||||
|
||||
const form = useForm<z.infer<typeof formSchema>>({
|
||||
resolver: zodResolver(formSchema),
|
||||
defaultValues: {
|
||||
username: "",
|
||||
password: "",
|
||||
title: "",
|
||||
},
|
||||
});
|
||||
|
||||
if (
|
||||
!credentials ||
|
||||
credentials.isLoading ||
|
||||
!credentials.supportsUserPassword
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const { provider, providerName, createUserPasswordCredentials } = credentials;
|
||||
|
||||
async function onSubmit(values: z.infer<typeof formSchema>) {
|
||||
const newCredentials = await createUserPasswordCredentials({
|
||||
username: values.username,
|
||||
password: values.password,
|
||||
title: values.title,
|
||||
});
|
||||
onCredentialsCreate({
|
||||
provider,
|
||||
id: newCredentials.id,
|
||||
type: "user_password",
|
||||
title: newCredentials.title,
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) onClose();
|
||||
}}
|
||||
>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>
|
||||
Add new username & password for {providerName}
|
||||
</DialogTitle>
|
||||
</DialogHeader>
|
||||
<Form {...form}>
|
||||
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="username"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Username</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="text"
|
||||
placeholder="Enter username..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="password"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Password</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="password"
|
||||
placeholder="Enter password..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="title"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Name</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="text"
|
||||
placeholder="Enter a name for this user login..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<Button type="submit" className="w-full">
|
||||
Save & use this user login
|
||||
</Button>
|
||||
</form>
|
||||
</Form>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
@@ -1,7 +1,17 @@
|
||||
"use client";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
import { OnboardingStep, UserOnboarding } from "@/lib/autogpt-server-api";
|
||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import Link from "next/link";
|
||||
import { usePathname, useRouter } from "next/navigation";
|
||||
import {
|
||||
createContext,
|
||||
@@ -11,16 +21,6 @@ import {
|
||||
useEffect,
|
||||
useState,
|
||||
} from "react";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogFooter,
|
||||
} from "@/components/ui/dialog";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import Link from "next/link";
|
||||
|
||||
const OnboardingContext = createContext<
|
||||
| {
|
||||
@@ -106,8 +106,6 @@ export default function OnboardingProvider({
|
||||
const updateState = useCallback(
|
||||
(newState: Omit<Partial<UserOnboarding>, "rewardedFor">) => {
|
||||
setState((prev) => {
|
||||
api.updateUserOnboarding(newState);
|
||||
|
||||
if (!prev) {
|
||||
// Handle initial state
|
||||
return {
|
||||
@@ -127,8 +125,15 @@ export default function OnboardingProvider({
|
||||
}
|
||||
return { ...prev, ...newState };
|
||||
});
|
||||
|
||||
// Make the API call asynchronously to not block render
|
||||
setTimeout(() => {
|
||||
api.updateUserOnboarding(newState).catch((error) => {
|
||||
console.error("Failed to update user onboarding:", error);
|
||||
});
|
||||
}, 0);
|
||||
},
|
||||
[api, setState],
|
||||
[api],
|
||||
);
|
||||
|
||||
const completeStep = useCallback(
|
||||
@@ -153,7 +158,7 @@ export default function OnboardingProvider({
|
||||
completedSteps: [...state.completedSteps, "RUN_AGENTS"],
|
||||
}),
|
||||
});
|
||||
}, [api, state]);
|
||||
}, [state, updateState]);
|
||||
|
||||
return (
|
||||
<OnboardingContext.Provider
|
||||
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
BlockIOCredentialsSubSchema,
|
||||
CredentialsProviderName,
|
||||
} from "@/lib/autogpt-server-api";
|
||||
import { getHostFromUrl } from "@/lib/utils/url";
|
||||
|
||||
export type CredentialsData =
|
||||
| {
|
||||
@@ -17,14 +18,18 @@ export type CredentialsData =
|
||||
supportsApiKey: boolean;
|
||||
supportsOAuth2: boolean;
|
||||
supportsUserPassword: boolean;
|
||||
supportsHostScoped: boolean;
|
||||
isLoading: true;
|
||||
discriminatorValue?: string;
|
||||
}
|
||||
| (CredentialsProviderData & {
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
supportsApiKey: boolean;
|
||||
supportsOAuth2: boolean;
|
||||
supportsUserPassword: boolean;
|
||||
supportsHostScoped: boolean;
|
||||
isLoading: false;
|
||||
discriminatorValue?: string;
|
||||
});
|
||||
|
||||
export default function useCredentials(
|
||||
@@ -33,12 +38,16 @@ export default function useCredentials(
|
||||
): CredentialsData | null {
|
||||
const allProviders = useContext(CredentialsProvidersContext);
|
||||
|
||||
const discriminatorValue: CredentialsProviderName | null =
|
||||
(credsInputSchema.discriminator &&
|
||||
credsInputSchema.discriminator_mapping![
|
||||
getValue(credsInputSchema.discriminator, nodeInputValues)
|
||||
]) ||
|
||||
null;
|
||||
const discriminatorValue = [
|
||||
credsInputSchema.discriminator
|
||||
? getValue(credsInputSchema.discriminator, nodeInputValues)
|
||||
: null,
|
||||
...(credsInputSchema.discriminator_values || []),
|
||||
].find(Boolean);
|
||||
|
||||
const discriminatedProvider = credsInputSchema.discriminator_mapping
|
||||
? credsInputSchema.discriminator_mapping[discriminatorValue]
|
||||
: null;
|
||||
|
||||
let providerName: CredentialsProviderName;
|
||||
if (credsInputSchema.credentials_provider.length > 1) {
|
||||
@@ -47,14 +56,14 @@ export default function useCredentials(
|
||||
"Multi-provider credential input requires discriminator!",
|
||||
);
|
||||
}
|
||||
if (!discriminatorValue) {
|
||||
console.log(
|
||||
if (!discriminatedProvider) {
|
||||
console.warn(
|
||||
`Missing discriminator value from '${credsInputSchema.discriminator}': ` +
|
||||
"hiding credentials input until it is set.",
|
||||
);
|
||||
return null;
|
||||
}
|
||||
providerName = discriminatorValue;
|
||||
providerName = discriminatedProvider;
|
||||
} else {
|
||||
providerName = credsInputSchema.credentials_provider[0];
|
||||
}
|
||||
@@ -69,6 +78,8 @@ export default function useCredentials(
|
||||
const supportsOAuth2 = credsInputSchema.credentials_types.includes("oauth2");
|
||||
const supportsUserPassword =
|
||||
credsInputSchema.credentials_types.includes("user_password");
|
||||
const supportsHostScoped =
|
||||
credsInputSchema.credentials_types.includes("host_scoped");
|
||||
|
||||
// No provider means maybe it's still loading
|
||||
if (!provider) {
|
||||
@@ -82,15 +93,24 @@ export default function useCredentials(
|
||||
return null;
|
||||
}
|
||||
|
||||
// Filter by OAuth credentials that have sufficient scopes for this block
|
||||
const requiredScopes = credsInputSchema.credentials_scopes;
|
||||
const savedCredentials = requiredScopes
|
||||
? provider.savedCredentials.filter(
|
||||
(c) =>
|
||||
c.type != "oauth2" ||
|
||||
new Set(c.scopes).isSupersetOf(new Set(requiredScopes)),
|
||||
)
|
||||
: provider.savedCredentials;
|
||||
const savedCredentials = provider.savedCredentials.filter((c) => {
|
||||
// Filter by OAuth credentials that have sufficient scopes for this block
|
||||
if (c.type === "oauth2") {
|
||||
const requiredScopes = credsInputSchema.credentials_scopes;
|
||||
return (
|
||||
!requiredScopes ||
|
||||
new Set(c.scopes).isSupersetOf(new Set(requiredScopes))
|
||||
);
|
||||
}
|
||||
|
||||
// Filter host_scoped credentials by host matching
|
||||
if (c.type === "host_scoped") {
|
||||
return discriminatorValue && getHostFromUrl(discriminatorValue) == c.host;
|
||||
}
|
||||
|
||||
// Include all other credential types
|
||||
return true;
|
||||
});
|
||||
|
||||
return {
|
||||
...provider,
|
||||
@@ -99,7 +119,9 @@ export default function useCredentials(
|
||||
supportsApiKey,
|
||||
supportsOAuth2,
|
||||
supportsUserPassword,
|
||||
supportsHostScoped,
|
||||
savedCredentials,
|
||||
discriminatorValue,
|
||||
isLoading: false,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import { getWebSocketToken } from "@/lib/supabase/actions";
|
||||
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
||||
import { createBrowserClient } from "@supabase/ssr";
|
||||
import type { SupabaseClient } from "@supabase/supabase-js";
|
||||
import { proxyApiRequest, proxyFileUpload } from "./proxy-action";
|
||||
import type {
|
||||
AddUserCreditsResponse,
|
||||
AnalyticsDetails,
|
||||
@@ -60,6 +62,7 @@ import type {
|
||||
User,
|
||||
UserOnboarding,
|
||||
UserPasswordCredentials,
|
||||
HostScopedCredentials,
|
||||
UsersBalanceHistoryResponse,
|
||||
} from "./types";
|
||||
|
||||
@@ -345,6 +348,16 @@ export default class BackendAPI {
|
||||
);
|
||||
}
|
||||
|
||||
createHostScopedCredentials(
|
||||
credentials: Omit<HostScopedCredentials, "id" | "type">,
|
||||
): Promise<HostScopedCredentials> {
|
||||
return this._request(
|
||||
"POST",
|
||||
`/integrations/${credentials.provider}/credentials`,
|
||||
{ ...credentials, type: "host_scoped" },
|
||||
);
|
||||
}
|
||||
|
||||
listCredentials(provider?: string): Promise<CredentialsMetaResponse[]> {
|
||||
return this._get(
|
||||
provider
|
||||
@@ -761,50 +774,25 @@ export default class BackendAPI {
|
||||
return this._request("GET", path, query);
|
||||
}
|
||||
|
||||
private async getAuthToken(): Promise<string> {
|
||||
// Only try client-side session (for WebSocket connections)
|
||||
// This will return "no-token-found" with httpOnly cookies, which is expected
|
||||
const supabaseClient = await this.getSupabaseClient();
|
||||
const {
|
||||
data: { session },
|
||||
} = (await supabaseClient?.auth.getSession()) || {
|
||||
data: { session: null },
|
||||
};
|
||||
|
||||
return session?.access_token || "no-token-found";
|
||||
}
|
||||
|
||||
private async _uploadFile(path: string, file: File): Promise<string> {
|
||||
// Get session with retry logic
|
||||
let token = "no-token-found";
|
||||
let retryCount = 0;
|
||||
const maxRetries = 3;
|
||||
|
||||
while (retryCount < maxRetries) {
|
||||
const supabaseClient = await this.getSupabaseClient();
|
||||
const {
|
||||
data: { session },
|
||||
} = (await supabaseClient?.auth.getSession()) || {
|
||||
data: { session: null },
|
||||
};
|
||||
|
||||
if (session?.access_token) {
|
||||
token = session.access_token;
|
||||
break;
|
||||
}
|
||||
|
||||
retryCount++;
|
||||
if (retryCount < maxRetries) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 100 * retryCount));
|
||||
}
|
||||
}
|
||||
|
||||
// Create a FormData object and append the file
|
||||
const formData = new FormData();
|
||||
formData.append("file", file);
|
||||
|
||||
const response = await fetch(this.baseUrl + path, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
...(token && { Authorization: `Bearer ${token}` }),
|
||||
},
|
||||
body: formData,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Error uploading file: ${response.statusText}`);
|
||||
}
|
||||
|
||||
// Parse the response appropriately
|
||||
const media_url = await response.text();
|
||||
return media_url;
|
||||
// Use proxy server action for secure file upload
|
||||
return await proxyFileUpload(path, formData, this.baseUrl);
|
||||
}
|
||||
|
||||
private async _request(
|
||||
@@ -816,103 +804,13 @@ export default class BackendAPI {
|
||||
console.debug(`${method} ${path} payload:`, payload);
|
||||
}
|
||||
|
||||
// Get session with retry logic
|
||||
let token = "no-token-found";
|
||||
let retryCount = 0;
|
||||
const maxRetries = 3;
|
||||
|
||||
while (retryCount < maxRetries) {
|
||||
const supabaseClient = await this.getSupabaseClient();
|
||||
const {
|
||||
data: { session },
|
||||
} = (await supabaseClient?.auth.getSession()) || {
|
||||
data: { session: null },
|
||||
};
|
||||
|
||||
if (session?.access_token) {
|
||||
token = session.access_token;
|
||||
break;
|
||||
}
|
||||
|
||||
retryCount++;
|
||||
if (retryCount < maxRetries) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 100 * retryCount));
|
||||
}
|
||||
}
|
||||
|
||||
let url = this.baseUrl + path;
|
||||
const payloadAsQuery = ["GET", "DELETE"].includes(method);
|
||||
if (payloadAsQuery && payload) {
|
||||
// For GET requests, use payload as query
|
||||
const queryParams = new URLSearchParams(payload);
|
||||
url += `?${queryParams.toString()}`;
|
||||
}
|
||||
|
||||
const hasRequestBody = !payloadAsQuery && payload !== undefined;
|
||||
const response = await fetch(url, {
|
||||
// Always use proxy server action to not expose any auth tokens to the browser
|
||||
return await proxyApiRequest({
|
||||
method,
|
||||
headers: {
|
||||
...(hasRequestBody && { "Content-Type": "application/json" }),
|
||||
...(token && { Authorization: `Bearer ${token}` }),
|
||||
},
|
||||
body: hasRequestBody ? JSON.stringify(payload) : undefined,
|
||||
path,
|
||||
payload,
|
||||
baseUrl: this.baseUrl,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
console.warn(`${method} ${path} returned non-OK response:`, response);
|
||||
|
||||
// console.warn("baseClient is attempting to redirect by changing window location")
|
||||
// if (
|
||||
// response.status === 403 &&
|
||||
// response.statusText === "Not authenticated" &&
|
||||
// typeof window !== "undefined" // Check if in browser environment
|
||||
// ) {
|
||||
// window.location.href = "/login";
|
||||
// }
|
||||
|
||||
let errorDetail;
|
||||
try {
|
||||
const errorData = await response.json();
|
||||
if (
|
||||
Array.isArray(errorData.detail) &&
|
||||
errorData.detail.length > 0 &&
|
||||
errorData.detail[0].loc
|
||||
) {
|
||||
// This appears to be a Pydantic validation error
|
||||
const errors = errorData.detail.map(
|
||||
(err: _PydanticValidationError) => {
|
||||
const location = err.loc.join(" -> ");
|
||||
return `${location}: ${err.msg}`;
|
||||
},
|
||||
);
|
||||
errorDetail = errors.join("\n");
|
||||
} else {
|
||||
errorDetail = errorData.detail || response.statusText;
|
||||
}
|
||||
} catch {
|
||||
errorDetail = response.statusText;
|
||||
}
|
||||
|
||||
throw new Error(errorDetail);
|
||||
}
|
||||
|
||||
// Handle responses with no content (like DELETE requests)
|
||||
if (
|
||||
response.status === 204 ||
|
||||
response.headers.get("Content-Length") === "0"
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
return await response.json();
|
||||
} catch (e) {
|
||||
if (e instanceof SyntaxError) {
|
||||
console.warn(`${method} ${path} returned invalid JSON:`, e);
|
||||
return null;
|
||||
}
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////
|
||||
@@ -1000,10 +898,19 @@ export default class BackendAPI {
|
||||
async connectWebSocket(): Promise<void> {
|
||||
return (this.wsConnecting ??= new Promise(async (resolve, reject) => {
|
||||
try {
|
||||
const supabaseClient = await this.getSupabaseClient();
|
||||
const token =
|
||||
(await supabaseClient?.auth.getSession())?.data.session
|
||||
?.access_token || "";
|
||||
let token = "";
|
||||
try {
|
||||
const { token: serverToken, error } = await getWebSocketToken();
|
||||
if (serverToken && !error) {
|
||||
token = serverToken;
|
||||
} else if (error) {
|
||||
console.warn("Failed to get WebSocket token from server:", error);
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn("Failed to get token for WebSocket connection:", error);
|
||||
// Continue with empty token, connection might still work
|
||||
}
|
||||
|
||||
const wsUrlWithToken = `${this.wsUrl}?token=${token}`;
|
||||
this.webSocket = new WebSocket(wsUrlWithToken);
|
||||
this.webSocket.state = "connecting";
|
||||
|
||||
236
autogpt_platform/frontend/src/lib/autogpt-server-api/helpers.ts
Normal file
236
autogpt_platform/frontend/src/lib/autogpt-server-api/helpers.ts
Normal file
@@ -0,0 +1,236 @@
|
||||
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
||||
|
||||
export function buildRequestUrl(
|
||||
baseUrl: string,
|
||||
path: string,
|
||||
method: string,
|
||||
payload?: Record<string, any>,
|
||||
): string {
|
||||
let url = baseUrl + path;
|
||||
const payloadAsQuery = ["GET", "DELETE"].includes(method);
|
||||
|
||||
if (payloadAsQuery && payload) {
|
||||
const queryParams = new URLSearchParams(payload);
|
||||
url += `?${queryParams.toString()}`;
|
||||
}
|
||||
|
||||
return url;
|
||||
}
|
||||
|
||||
export async function getServerAuthToken(): Promise<string> {
|
||||
const supabase = await getServerSupabase();
|
||||
|
||||
if (!supabase) {
|
||||
throw new Error("Supabase client not available");
|
||||
}
|
||||
|
||||
try {
|
||||
const {
|
||||
data: { session },
|
||||
error,
|
||||
} = await supabase.auth.getSession();
|
||||
|
||||
if (error || !session?.access_token) {
|
||||
return "no-token-found";
|
||||
}
|
||||
|
||||
return session.access_token;
|
||||
} catch (error) {
|
||||
console.error("Failed to get auth token:", error);
|
||||
return "no-token-found";
|
||||
}
|
||||
}
|
||||
|
||||
export function createRequestHeaders(
|
||||
token: string,
|
||||
hasRequestBody: boolean,
|
||||
contentType: string = "application/json",
|
||||
): Record<string, string> {
|
||||
const headers: Record<string, string> = {};
|
||||
|
||||
if (hasRequestBody) {
|
||||
headers["Content-Type"] = contentType;
|
||||
}
|
||||
|
||||
if (token && token !== "no-token-found") {
|
||||
headers["Authorization"] = `Bearer ${token}`;
|
||||
}
|
||||
|
||||
return headers;
|
||||
}
|
||||
|
||||
export function serializeRequestBody(
|
||||
payload: any,
|
||||
contentType: string = "application/json",
|
||||
): string {
|
||||
switch (contentType) {
|
||||
case "application/json":
|
||||
return JSON.stringify(payload);
|
||||
case "application/x-www-form-urlencoded":
|
||||
return new URLSearchParams(payload).toString();
|
||||
default:
|
||||
// For custom content types, assume payload is already properly formatted
|
||||
return typeof payload === "string" ? payload : JSON.stringify(payload);
|
||||
}
|
||||
}
|
||||
|
||||
export async function parseApiError(response: Response): Promise<string> {
|
||||
try {
|
||||
const errorData = await response.json();
|
||||
|
||||
if (
|
||||
Array.isArray(errorData.detail) &&
|
||||
errorData.detail.length > 0 &&
|
||||
errorData.detail[0].loc
|
||||
) {
|
||||
// Pydantic validation error
|
||||
const errors = errorData.detail.map((err: any) => {
|
||||
const location = err.loc.join(" -> ");
|
||||
return `${location}: ${err.msg}`;
|
||||
});
|
||||
return errors.join("\n");
|
||||
}
|
||||
|
||||
return errorData.detail || response.statusText;
|
||||
} catch {
|
||||
return response.statusText;
|
||||
}
|
||||
}
|
||||
|
||||
export async function parseApiResponse(response: Response): Promise<any> {
|
||||
// Handle responses with no content
|
||||
if (
|
||||
response.status === 204 ||
|
||||
response.headers.get("Content-Length") === "0"
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
return await response.json();
|
||||
} catch (e) {
|
||||
if (e instanceof SyntaxError) {
|
||||
return null;
|
||||
}
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
function isAuthenticationError(
|
||||
response: Response,
|
||||
errorDetail: string,
|
||||
): boolean {
|
||||
return (
|
||||
response.status === 401 ||
|
||||
response.status === 403 ||
|
||||
errorDetail.toLowerCase().includes("not authenticated") ||
|
||||
errorDetail.toLowerCase().includes("unauthorized") ||
|
||||
errorDetail.toLowerCase().includes("authentication failed")
|
||||
);
|
||||
}
|
||||
|
||||
function isLogoutInProgress(): boolean {
|
||||
if (typeof window === "undefined") return false;
|
||||
|
||||
try {
|
||||
// Check if logout was recently triggered
|
||||
const logoutTimestamp = window.localStorage.getItem("supabase-logout");
|
||||
if (logoutTimestamp) {
|
||||
const timeDiff = Date.now() - parseInt(logoutTimestamp);
|
||||
// Consider logout in progress for 5 seconds after trigger
|
||||
return timeDiff < 5000;
|
||||
}
|
||||
|
||||
// Check if we're being redirected to login
|
||||
return (
|
||||
window.location.pathname.includes("/login") ||
|
||||
window.location.pathname.includes("/logout")
|
||||
);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export async function makeAuthenticatedRequest(
|
||||
method: string,
|
||||
url: string,
|
||||
payload?: Record<string, any>,
|
||||
contentType: string = "application/json",
|
||||
): Promise<any> {
|
||||
const token = await getServerAuthToken();
|
||||
const payloadAsQuery = ["GET", "DELETE"].includes(method);
|
||||
const hasRequestBody = !payloadAsQuery && payload !== undefined;
|
||||
|
||||
const response = await fetch(url, {
|
||||
method,
|
||||
headers: createRequestHeaders(token, hasRequestBody, contentType),
|
||||
body: hasRequestBody
|
||||
? serializeRequestBody(payload, contentType)
|
||||
: undefined,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorDetail = await parseApiError(response);
|
||||
|
||||
// Handle authentication errors gracefully during logout
|
||||
if (isAuthenticationError(response, errorDetail)) {
|
||||
if (isLogoutInProgress()) {
|
||||
// Silently return null during logout to prevent error noise
|
||||
console.debug(
|
||||
"Authentication request failed during logout, ignoring:",
|
||||
errorDetail,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
// For authentication errors outside logout, log but don't throw
|
||||
// This prevents crashes when session expires naturally
|
||||
console.warn("Authentication failed:", errorDetail);
|
||||
return null;
|
||||
}
|
||||
|
||||
// For other errors, throw as normal
|
||||
throw new Error(errorDetail);
|
||||
}
|
||||
|
||||
return parseApiResponse(response);
|
||||
}
|
||||
|
||||
export async function makeAuthenticatedFileUpload(
|
||||
url: string,
|
||||
formData: FormData,
|
||||
): Promise<string> {
|
||||
const token = await getServerAuthToken();
|
||||
|
||||
const headers: Record<string, string> = {};
|
||||
if (token && token !== "no-token-found") {
|
||||
headers["Authorization"] = `Bearer ${token}`;
|
||||
}
|
||||
|
||||
// Don't set Content-Type for FormData - let the browser set it with boundary
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers,
|
||||
body: formData,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
// Handle authentication errors gracefully for file uploads too
|
||||
const errorMessage = `Error uploading file: ${response.statusText}`;
|
||||
|
||||
if (response.status === 401 || response.status === 403) {
|
||||
if (isLogoutInProgress()) {
|
||||
console.debug(
|
||||
"File upload authentication failed during logout, ignoring",
|
||||
);
|
||||
return "";
|
||||
}
|
||||
console.warn("File upload authentication failed:", errorMessage);
|
||||
return "";
|
||||
}
|
||||
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
|
||||
return await response.text();
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
"use server";
|
||||
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
import {
|
||||
buildRequestUrl,
|
||||
makeAuthenticatedFileUpload,
|
||||
makeAuthenticatedRequest,
|
||||
} from "./helpers";
|
||||
|
||||
const DEFAULT_BASE_URL = "http://localhost:8006/api";
|
||||
|
||||
export interface ProxyRequestOptions {
|
||||
method: "GET" | "POST" | "PUT" | "PATCH" | "DELETE";
|
||||
path: string;
|
||||
payload?: Record<string, any>;
|
||||
baseUrl?: string;
|
||||
contentType?: string;
|
||||
}
|
||||
|
||||
export async function proxyApiRequest({
|
||||
method,
|
||||
path,
|
||||
payload,
|
||||
baseUrl = process.env.NEXT_PUBLIC_AGPT_SERVER_URL || DEFAULT_BASE_URL,
|
||||
contentType = "application/json",
|
||||
}: ProxyRequestOptions) {
|
||||
return await Sentry.withServerActionInstrumentation(
|
||||
"proxyApiRequest",
|
||||
{},
|
||||
async () => {
|
||||
const url = buildRequestUrl(baseUrl, path, method, payload);
|
||||
return makeAuthenticatedRequest(method, url, payload, contentType);
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
export async function proxyFileUpload(
|
||||
path: string,
|
||||
formData: FormData,
|
||||
baseUrl = process.env.NEXT_PUBLIC_AGPT_SERVER_URL ||
|
||||
"http://localhost:8006/api",
|
||||
): Promise<string> {
|
||||
return await Sentry.withServerActionInstrumentation(
|
||||
"proxyFileUpload",
|
||||
{},
|
||||
async () => {
|
||||
const url = baseUrl + path;
|
||||
return makeAuthenticatedFileUpload(url, formData);
|
||||
},
|
||||
);
|
||||
}
|
||||
@@ -140,12 +140,17 @@ export type BlockIOBooleanSubSchema = BlockIOSubSchemaMeta & {
|
||||
secret?: boolean;
|
||||
};
|
||||
|
||||
export type CredentialsType = "api_key" | "oauth2" | "user_password";
|
||||
export type CredentialsType =
|
||||
| "api_key"
|
||||
| "oauth2"
|
||||
| "user_password"
|
||||
| "host_scoped";
|
||||
|
||||
export type Credentials =
|
||||
| APIKeyCredentials
|
||||
| OAuth2Credentials
|
||||
| UserPasswordCredentials;
|
||||
| UserPasswordCredentials
|
||||
| HostScopedCredentials;
|
||||
|
||||
// --8<-- [start:BlockIOCredentialsSubSchema]
|
||||
export const PROVIDER_NAMES = {
|
||||
@@ -161,6 +166,7 @@ export const PROVIDER_NAMES = {
|
||||
GOOGLE: "google",
|
||||
GOOGLE_MAPS: "google_maps",
|
||||
GROQ: "groq",
|
||||
HTTP: "http",
|
||||
HUBSPOT: "hubspot",
|
||||
IDEOGRAM: "ideogram",
|
||||
JINA: "jina",
|
||||
@@ -199,6 +205,7 @@ export type BlockIOCredentialsSubSchema = BlockIOObjectSubSchema & {
|
||||
credentials_types: Array<CredentialsType>;
|
||||
discriminator?: string;
|
||||
discriminator_mapping?: { [key: string]: CredentialsProviderName };
|
||||
discriminator_values?: any[];
|
||||
secret?: boolean;
|
||||
};
|
||||
|
||||
@@ -501,6 +508,7 @@ export type CredentialsMetaResponse = {
|
||||
title?: string;
|
||||
scopes?: Array<string>;
|
||||
username?: string;
|
||||
host?: string;
|
||||
};
|
||||
|
||||
/* Mirror of backend/server/integrations/router.py:CredentialsDeletionResponse */
|
||||
@@ -559,6 +567,14 @@ export type UserPasswordCredentials = BaseCredentials & {
|
||||
password: string;
|
||||
};
|
||||
|
||||
/* Mirror of backend/backend/data/model.py:HostScopedCredentials */
|
||||
export type HostScopedCredentials = BaseCredentials & {
|
||||
type: "host_scoped";
|
||||
title: string;
|
||||
host: string;
|
||||
headers: Record<string, string>;
|
||||
};
|
||||
|
||||
// Mirror of backend/backend/data/notifications.py:NotificationType
|
||||
export type NotificationType =
|
||||
| "AGENT_RUN"
|
||||
|
||||
180
autogpt_platform/frontend/src/lib/supabase/SESSION_VALIDATION.md
Normal file
180
autogpt_platform/frontend/src/lib/supabase/SESSION_VALIDATION.md
Normal file
@@ -0,0 +1,180 @@
|
||||
# Server-Side Session Validation with httpOnly Cookies
|
||||
|
||||
This implementation ensures that Supabase session validation is always performed on the server side using httpOnly cookies for improved security.
|
||||
|
||||
## Key Features
|
||||
|
||||
- **httpOnly Cookies**: Session cookies are inaccessible to client-side JavaScript, preventing XSS attacks
|
||||
- **Server-Side Authentication**: All API requests are authenticated on the server using httpOnly cookies
|
||||
- **Automatic Request Proxying**: All BackendAPI requests are automatically proxied through server actions
|
||||
- **File Upload Support**: File uploads work seamlessly with httpOnly cookie authentication
|
||||
- **Zero Code Changes**: Existing BackendAPI usage continues to work without modifications
|
||||
- **Cross-Tab Logout**: Logout events are still synchronized across browser tabs
|
||||
|
||||
## How It Works
|
||||
|
||||
All API requests made through `BackendAPI` are automatically proxied through server actions that:
|
||||
|
||||
1. Retrieve the JWT token from server-side httpOnly cookies
|
||||
2. Make the authenticated request to the backend API
|
||||
3. Return the response to the client
|
||||
|
||||
This includes both regular API calls and file uploads, all handled transparently!
|
||||
|
||||
## Usage
|
||||
|
||||
### Client Components
|
||||
|
||||
No changes needed! The existing `useSupabase` hook and `useBackendAPI` continue to work:
|
||||
|
||||
```tsx
|
||||
"use client";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||
|
||||
function MyComponent() {
|
||||
const { user, isLoggedIn, isUserLoading, logOut } = useSupabase();
|
||||
const api = useBackendAPI();
|
||||
|
||||
if (isUserLoading) return <div>Loading...</div>;
|
||||
if (!isLoggedIn) return <div>Please log in</div>;
|
||||
|
||||
// Regular API calls use secure server-side authentication
|
||||
const handleGetGraphs = async () => {
|
||||
const graphs = await api.listGraphs();
|
||||
console.log(graphs);
|
||||
};
|
||||
|
||||
// File uploads also work with secure authentication
|
||||
const handleFileUpload = async (file: File) => {
|
||||
try {
|
||||
const mediaUrl = await api.uploadStoreSubmissionMedia(file);
|
||||
console.log("Uploaded:", mediaUrl);
|
||||
} catch (error) {
|
||||
console.error("Upload failed:", error);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div>
|
||||
<p>Welcome, {user?.email}!</p>
|
||||
<button onClick={handleGetGraphs}>Get Graphs</button>
|
||||
<input
|
||||
type="file"
|
||||
onChange={(e) =>
|
||||
e.target.files?.[0] && handleFileUpload(e.target.files[0])
|
||||
}
|
||||
/>
|
||||
<button onClick={logOut}>Log Out</button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
### Server Components
|
||||
|
||||
No changes needed! Server components continue to work as before:
|
||||
|
||||
```tsx
|
||||
import { validateSession, getCurrentUser } from "@/lib/supabase/actions";
|
||||
import { redirect } from "next/navigation";
|
||||
|
||||
async function MyServerComponent() {
|
||||
const { user, error } = await getCurrentUser();
|
||||
|
||||
if (error || !user) {
|
||||
redirect("/login");
|
||||
}
|
||||
|
||||
return <div>Welcome, {user.email}!</div>;
|
||||
}
|
||||
```
|
||||
|
||||
### Server Actions
|
||||
|
||||
No changes needed! Server actions continue to work as before:
|
||||
|
||||
```tsx
|
||||
"use server";
|
||||
import { validateSession } from "@/lib/supabase/actions";
|
||||
import BackendAPI from "@/lib/autogpt-server-api";
|
||||
import { redirect } from "next/navigation";
|
||||
|
||||
export async function myServerAction() {
|
||||
const { user, isValid } = await validateSession("/current-path");
|
||||
|
||||
if (!isValid || !user) {
|
||||
redirect("/login");
|
||||
return;
|
||||
}
|
||||
|
||||
// This automatically uses secure server-side authentication
|
||||
const api = new BackendAPI();
|
||||
const graphs = await api.listGraphs();
|
||||
|
||||
return graphs;
|
||||
}
|
||||
```
|
||||
|
||||
### API Calls and File Uploads
|
||||
|
||||
All operations use the same simple code everywhere:
|
||||
|
||||
```tsx
|
||||
// Works the same in both client and server contexts
|
||||
const api = new BackendAPI();
|
||||
|
||||
// Regular API requests
|
||||
const graphs = await api.listGraphs();
|
||||
const user = await api.createUser();
|
||||
const onboarding = await api.getUserOnboarding();
|
||||
|
||||
// File uploads
|
||||
const file = new File(["content"], "example.txt", { type: "text/plain" });
|
||||
const mediaUrl = await api.uploadStoreSubmissionMedia(file);
|
||||
```
|
||||
|
||||
## Available Server Actions
|
||||
|
||||
- `validateSession(currentPath)` - Validates the current session and returns user data
|
||||
- `getCurrentUser()` - Gets the current user without path validation
|
||||
- `serverLogout()` - Logs out the user server-side
|
||||
- `refreshSession()` - Refreshes the current session
|
||||
|
||||
## Internal Architecture
|
||||
|
||||
### Request Flow
|
||||
|
||||
All API requests (including file uploads) follow this flow:
|
||||
|
||||
1. **Any API call**: `api.listGraphs()` or `api.uploadStoreSubmissionMedia(file)`
|
||||
2. **Proxy server action**: `proxyApiRequest()` or `proxyFileUpload()` handles the request
|
||||
3. **Server authentication**: Gets JWT from httpOnly cookies
|
||||
4. **Backend request**: Makes authenticated request to backend API
|
||||
5. **Response**: Returns data to the calling code
|
||||
|
||||
### File Upload Implementation
|
||||
|
||||
File uploads are handled through a dedicated `proxyFileUpload` server action that:
|
||||
|
||||
- Receives the file data as FormData on the server
|
||||
- Retrieves authentication tokens from httpOnly cookies
|
||||
- Forwards the authenticated upload request to the backend
|
||||
- Returns the upload result to the client
|
||||
|
||||
## Security Benefits
|
||||
|
||||
1. **XSS Protection**: httpOnly cookies can't be accessed by malicious scripts
|
||||
2. **CSRF Protection**: Combined with SameSite cookie settings
|
||||
3. **Server-Side Validation**: Session validation always happens on the trusted server
|
||||
4. **Zero Token Exposure**: JWT tokens never exposed to client-side JavaScript
|
||||
5. **Zero Attack Surface**: No client-side session manipulation possible
|
||||
6. **Secure File Uploads**: File uploads maintain the same security model
|
||||
|
||||
## Migration Notes
|
||||
|
||||
- **No code changes required** - all existing BackendAPI usage continues to work
|
||||
- File uploads now work seamlessly with httpOnly cookies
|
||||
- Cross-tab logout functionality is preserved
|
||||
- WebSocket connections may need reconnection after session changes
|
||||
- All requests now have consistent security behavior
|
||||
223
autogpt_platform/frontend/src/lib/supabase/actions.ts
Normal file
223
autogpt_platform/frontend/src/lib/supabase/actions.ts
Normal file
@@ -0,0 +1,223 @@
|
||||
"use server";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
import type { User } from "@supabase/supabase-js";
|
||||
import { revalidatePath } from "next/cache";
|
||||
import { redirect } from "next/navigation";
|
||||
import { getRedirectPath } from "./helpers";
|
||||
import { getServerSupabase } from "./server/getServerSupabase";
|
||||
|
||||
export interface SessionValidationResult {
|
||||
user: User | null;
|
||||
isValid: boolean;
|
||||
redirectPath?: string;
|
||||
}
|
||||
|
||||
export async function validateSession(
|
||||
currentPath: string,
|
||||
): Promise<SessionValidationResult> {
|
||||
return await Sentry.withServerActionInstrumentation(
|
||||
"validateSession",
|
||||
{},
|
||||
async () => {
|
||||
const supabase = await getServerSupabase();
|
||||
|
||||
if (!supabase) {
|
||||
return {
|
||||
user: null,
|
||||
isValid: false,
|
||||
redirectPath: getRedirectPath(currentPath) || undefined,
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const {
|
||||
data: { user },
|
||||
error,
|
||||
} = await supabase.auth.getUser();
|
||||
|
||||
if (error || !user) {
|
||||
const redirectPath = getRedirectPath(currentPath);
|
||||
return {
|
||||
user: null,
|
||||
isValid: false,
|
||||
redirectPath: redirectPath || undefined,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
user,
|
||||
isValid: true,
|
||||
};
|
||||
} catch (error) {
|
||||
console.error("Session validation error:", error);
|
||||
const redirectPath = getRedirectPath(currentPath);
|
||||
return {
|
||||
user: null,
|
||||
isValid: false,
|
||||
redirectPath: redirectPath || undefined,
|
||||
};
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
export async function getCurrentUser(): Promise<{
|
||||
user: User | null;
|
||||
error?: string;
|
||||
}> {
|
||||
return await Sentry.withServerActionInstrumentation(
|
||||
"getCurrentUser",
|
||||
{},
|
||||
async () => {
|
||||
const supabase = await getServerSupabase();
|
||||
|
||||
if (!supabase) {
|
||||
return {
|
||||
user: null,
|
||||
error: "Supabase client not available",
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const {
|
||||
data: { user },
|
||||
error,
|
||||
} = await supabase.auth.getUser();
|
||||
|
||||
if (error) {
|
||||
return {
|
||||
user: null,
|
||||
error: error.message,
|
||||
};
|
||||
}
|
||||
|
||||
return { user };
|
||||
} catch (error) {
|
||||
console.error("Get current user error:", error);
|
||||
return {
|
||||
user: null,
|
||||
error: error instanceof Error ? error.message : "Unknown error",
|
||||
};
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
export async function getWebSocketToken(): Promise<{
|
||||
token: string | null;
|
||||
error?: string;
|
||||
}> {
|
||||
return await Sentry.withServerActionInstrumentation(
|
||||
"getWebSocketToken",
|
||||
{},
|
||||
async () => {
|
||||
const supabase = await getServerSupabase();
|
||||
|
||||
if (!supabase) {
|
||||
return {
|
||||
token: null,
|
||||
error: "Supabase client not available",
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const {
|
||||
data: { session },
|
||||
error,
|
||||
} = await supabase.auth.getSession();
|
||||
|
||||
if (error) {
|
||||
return {
|
||||
token: null,
|
||||
error: error.message,
|
||||
};
|
||||
}
|
||||
|
||||
return { token: session?.access_token || null };
|
||||
} catch (error) {
|
||||
console.error("Get WebSocket token error:", error);
|
||||
return {
|
||||
token: null,
|
||||
error: error instanceof Error ? error.message : "Unknown error",
|
||||
};
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
export type ServerLogoutOptions = {
|
||||
globalLogout?: boolean;
|
||||
};
|
||||
|
||||
export async function serverLogout(options: ServerLogoutOptions = {}) {
|
||||
return await Sentry.withServerActionInstrumentation(
|
||||
"serverLogout",
|
||||
{},
|
||||
async () => {
|
||||
const supabase = await getServerSupabase();
|
||||
|
||||
if (!supabase) {
|
||||
redirect("/login");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const { error } = await supabase.auth.signOut({
|
||||
scope: options.globalLogout ? "global" : "local",
|
||||
});
|
||||
|
||||
if (error) {
|
||||
console.error("Error logging out:", error);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Logout error:", error);
|
||||
}
|
||||
|
||||
// Clear all cached data and redirect
|
||||
revalidatePath("/", "layout");
|
||||
redirect("/login");
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
export async function refreshSession() {
|
||||
return await Sentry.withServerActionInstrumentation(
|
||||
"refreshSession",
|
||||
{},
|
||||
async () => {
|
||||
const supabase = await getServerSupabase();
|
||||
|
||||
if (!supabase) {
|
||||
return {
|
||||
user: null,
|
||||
error: "Supabase client not available",
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const {
|
||||
data: { user },
|
||||
error,
|
||||
} = await supabase.auth.refreshSession();
|
||||
|
||||
if (error) {
|
||||
return {
|
||||
user: null,
|
||||
error: error.message,
|
||||
};
|
||||
}
|
||||
|
||||
// Revalidate the layout to update server components
|
||||
revalidatePath("/", "layout");
|
||||
|
||||
return { user };
|
||||
} catch (error) {
|
||||
console.error("Refresh session error:", error);
|
||||
return {
|
||||
user: null,
|
||||
error: error instanceof Error ? error.message : "Unknown error",
|
||||
};
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user