mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Compare commits
11 Commits
aryshare-r
...
hosjdl-cod
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dc981b52a3 | ||
|
|
61643e6a47 | ||
|
|
21b4d272ce | ||
|
|
b8ba572629 | ||
|
|
47deeb53c3 | ||
|
|
1b81a7c755 | ||
|
|
8f1b3eb8ba | ||
|
|
73ee6e272a | ||
|
|
f466b010e4 | ||
|
|
f8965e530f | ||
|
|
701d283f69 |
50
AGENTS.md
Normal file
50
AGENTS.md
Normal file
@@ -0,0 +1,50 @@
|
||||
# AutoGPT Platform Contribution Guide
|
||||
|
||||
This guide provides context for Codex when updating the **autogpt_platform** folder.
|
||||
|
||||
## Directory overview
|
||||
- `autogpt_platform/backend` – FastAPI based backend service.
|
||||
- `autogpt_platform/autogpt_libs` – Shared Python libraries.
|
||||
- `autogpt_platform/frontend` – Next.js + Typescript frontend.
|
||||
- `autogpt_platform/docker-compose.yml` – development stack.
|
||||
|
||||
See `docs/content/platform/getting-started.md` for setup instructions.
|
||||
|
||||
## Code style
|
||||
- Format Python code with `poetry run format`.
|
||||
- Format frontend code using `yarn format`.
|
||||
|
||||
## Testing
|
||||
- Backend: `poetry run test` (runs pytest with a docker based postgres + prisma).
|
||||
- Frontend: `yarn test` or `yarn test-ui` for Playwright tests. See `docs/content/platform/contributing/tests.md` for tips.
|
||||
|
||||
Always run the relevant linters and tests before committing.
|
||||
Use conventional commit messages for all commits (e.g. `feat(backend): add API`).
|
||||
Types:
|
||||
- feat
|
||||
- fix
|
||||
- refactor
|
||||
- ci
|
||||
- dx (developer experience)
|
||||
Scopes:
|
||||
- platform
|
||||
- platform/library
|
||||
- platform/marketplace
|
||||
- backend
|
||||
- backend/executor
|
||||
- frontend
|
||||
- frontend/library
|
||||
- frontend/marketplace
|
||||
- blocks
|
||||
|
||||
## Pull requests
|
||||
- Use the template in `.github/PULL_REQUEST_TEMPLATE.md`.
|
||||
- Rely on the pre-commit checks for linting and formatting
|
||||
- Fill out the **Changes** section and the checklist.
|
||||
- Use conventional commit titles with a scope (e.g. `feat(frontend): add feature`).
|
||||
- Keep out-of-scope changes under 20% of the PR.
|
||||
- Ensure PR descriptions are complete.
|
||||
- For changes touching `data/*.py`, validate user ID checks or explain why not needed.
|
||||
- If adding protected frontend routes, update `frontend/lib/supabase/middleware.ts`.
|
||||
- Use the linear ticket branch structure if given codex/open-1668-resume-dropped-runs
|
||||
|
||||
@@ -197,10 +197,6 @@ SMARTLEAD_API_KEY=
|
||||
# ZeroBounce
|
||||
ZEROBOUNCE_API_KEY=
|
||||
|
||||
# Ayrshare
|
||||
AYRSHARE_API_KEY=
|
||||
AYRSHARE_JWT_KEY=
|
||||
|
||||
## ===== OPTIONAL API KEYS END ===== ##
|
||||
|
||||
# Logging Configuration
|
||||
|
||||
@@ -52,7 +52,6 @@ class AudioTrack(str, Enum):
|
||||
REFRESHER = ("Refresher",)
|
||||
TOURIST = ("Tourist",)
|
||||
TWIN_TYCHES = ("Twin Tyches",)
|
||||
DONT_STOP_ME_ABSTRACT_FUTURE_BASS = ("Dont Stop Me Abstract Future Bass",)
|
||||
|
||||
@property
|
||||
def audio_url(self):
|
||||
@@ -78,7 +77,6 @@ class AudioTrack(str, Enum):
|
||||
AudioTrack.REFRESHER: "https://cdn.tfrv.xyz/audio/refresher.mp3",
|
||||
AudioTrack.TOURIST: "https://cdn.tfrv.xyz/audio/tourist.mp3",
|
||||
AudioTrack.TWIN_TYCHES: "https://cdn.tfrv.xyz/audio/twin-tynches.mp3",
|
||||
AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS: "https://cdn.revid.ai/audio/_dont-stop-me-abstract-future-bass.mp3",
|
||||
}
|
||||
return audio_urls[self]
|
||||
|
||||
@@ -106,7 +104,6 @@ class GenerationPreset(str, Enum):
|
||||
MOVIE = ("Movie",)
|
||||
STYLIZED_ILLUSTRATION = ("Stylized Illustration",)
|
||||
MANGA = ("Manga",)
|
||||
DEFAULT = ("DEFAULT",)
|
||||
|
||||
|
||||
class Voice(str, Enum):
|
||||
@@ -116,7 +113,6 @@ class Voice(str, Enum):
|
||||
JESSICA = "Jessica"
|
||||
CHARLOTTE = "Charlotte"
|
||||
CALLUM = "Callum"
|
||||
EVA = "Eva"
|
||||
|
||||
@property
|
||||
def voice_id(self):
|
||||
@@ -127,7 +123,6 @@ class Voice(str, Enum):
|
||||
Voice.JESSICA: "cgSgspJ2msm6clMCkdW9",
|
||||
Voice.CHARLOTTE: "XB0fDUnXU5powFXDhCwa",
|
||||
Voice.CALLUM: "N2lVS1w4EtoT3dr4eOWO",
|
||||
Voice.EVA: "FGY2WhTYpPnrIDTdsKH5",
|
||||
}
|
||||
return voice_id_map[self]
|
||||
|
||||
@@ -144,54 +139,7 @@ class VisualMediaType(str, Enum):
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _RevidMixin:
|
||||
"""Utility mix‑in that bundles the shared webhook / polling helpers."""
|
||||
|
||||
def create_video(self, api_key: SecretStr, payload: dict) -> dict:
|
||||
url = "https://www.revid.ai/api/public/v2/render"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = requests.post(url, json=payload, headers=headers)
|
||||
logger.debug(
|
||||
f"API Response Status Code: {response.status_code}, Content: {response.text}"
|
||||
)
|
||||
return response.json()
|
||||
|
||||
def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
|
||||
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = requests.get(url, headers=headers)
|
||||
return response.json()
|
||||
|
||||
def wait_for_video(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
pid: str,
|
||||
max_wait_time: int = 3600,
|
||||
) -> str:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
status = self.check_video_status(api_key, pid)
|
||||
logger.debug(f"Video status: {status}")
|
||||
|
||||
if status.get("status") == "ready" and "videoUrl" in status:
|
||||
return status["videoUrl"]
|
||||
elif status.get("status") == "error":
|
||||
error_message = status.get("error", "Unknown error occurred")
|
||||
logger.error(f"Video creation failed: {error_message}")
|
||||
raise ValueError(f"Video creation failed: {error_message}")
|
||||
elif status.get("status") in ["FAILED", "CANCELED"]:
|
||||
logger.error(f"Video creation failed: {status.get('message')}")
|
||||
raise ValueError(f"Video creation failed: {status.get('message')}")
|
||||
|
||||
time.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
|
||||
class AIShortformVideoCreatorBlock(Block, _RevidMixin):
|
||||
"""Creates a short‑form text‑to‑video clip using stock or AI imagery."""
|
||||
|
||||
class AIShortformVideoCreatorBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REVID], Literal["api_key"]
|
||||
@@ -258,28 +206,86 @@ class AIShortformVideoCreatorBlock(Block, _RevidMixin):
|
||||
"https://example.com/video.mp4",
|
||||
),
|
||||
test_mock={
|
||||
"create_webhook": lambda: (
|
||||
"test_uuid",
|
||||
"https://webhook.site/test_uuid",
|
||||
),
|
||||
"create_video": lambda api_key, payload: {"pid": "test_pid"},
|
||||
"wait_for_video": lambda api_key, pid, max_wait_time=3600: "https://example.com/video.mp4",
|
||||
"wait_for_video": lambda api_key, pid, webhook_token, max_wait_time=1000: "https://example.com/video.mp4",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
def create_webhook(self):
|
||||
url = "https://webhook.site/token"
|
||||
headers = {"Accept": "application/json", "Content-Type": "application/json"}
|
||||
response = requests.post(url, headers=headers)
|
||||
webhook_data = response.json()
|
||||
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
|
||||
|
||||
def create_video(self, api_key: SecretStr, payload: dict) -> dict:
|
||||
url = "https://www.revid.ai/api/public/v2/render"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = requests.post(url, json=payload, headers=headers)
|
||||
logger.debug(
|
||||
f"API Response Status Code: {response.status_code}, Content: {response.text}"
|
||||
)
|
||||
return response.json()
|
||||
|
||||
def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
|
||||
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = requests.get(url, headers=headers)
|
||||
return response.json()
|
||||
|
||||
def wait_for_video(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
pid: str,
|
||||
webhook_token: str,
|
||||
max_wait_time: int = 1000,
|
||||
) -> str:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
status = self.check_video_status(api_key, pid)
|
||||
logger.debug(f"Video status: {status}")
|
||||
|
||||
if status.get("status") == "ready" and "videoUrl" in status:
|
||||
return status["videoUrl"]
|
||||
elif status.get("status") == "error":
|
||||
error_message = status.get("error", "Unknown error occurred")
|
||||
logger.error(f"Video creation failed: {error_message}")
|
||||
raise ValueError(f"Video creation failed: {error_message}")
|
||||
elif status.get("status") in ["FAILED", "CANCELED"]:
|
||||
logger.error(f"Video creation failed: {status.get('message')}")
|
||||
raise ValueError(f"Video creation failed: {status.get('message')}")
|
||||
|
||||
time.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Create a new Webhook.site URL
|
||||
webhook_token, webhook_url = self.create_webhook()
|
||||
logger.debug(f"Webhook URL: {webhook_url}")
|
||||
|
||||
audio_url = input_data.background_music.audio_url
|
||||
|
||||
payload = {
|
||||
"frameRate": input_data.frame_rate,
|
||||
"resolution": input_data.resolution,
|
||||
"frameDurationMultiplier": 18,
|
||||
"webhook": None,
|
||||
"webhook": webhook_url,
|
||||
"creationParams": {
|
||||
"mediaType": input_data.video_style,
|
||||
"captionPresetName": "Wrap 1",
|
||||
"selectedVoice": input_data.voice.voice_id,
|
||||
"hasEnhancedGeneration": True,
|
||||
"generationPreset": input_data.generation_preset.name,
|
||||
"selectedAudio": input_data.background_music.value,
|
||||
"selectedAudio": input_data.background_music,
|
||||
"origin": "/create",
|
||||
"inputText": input_data.script,
|
||||
"flowType": "text-to-video",
|
||||
@@ -295,7 +301,7 @@ class AIShortformVideoCreatorBlock(Block, _RevidMixin):
|
||||
"selectedStoryStyle": {"value": "custom", "label": "Custom"},
|
||||
"hasToGenerateVideos": input_data.video_style
|
||||
!= VisualMediaType.STOCK_VIDEOS,
|
||||
"audioUrl": input_data.background_music.audio_url,
|
||||
"audioUrl": audio_url,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -308,354 +314,10 @@ class AIShortformVideoCreatorBlock(Block, _RevidMixin):
|
||||
f"Failed to create video: No project ID returned. API Response: {response}"
|
||||
)
|
||||
raise RuntimeError("Failed to create video: No project ID returned")
|
||||
|
||||
logger.debug(f"Video created with project ID: {pid}. Waiting for completion...")
|
||||
video_url = self.wait_for_video(credentials.api_key, pid)
|
||||
logger.debug(f"Video ready: {video_url}")
|
||||
yield "video_url", video_url
|
||||
|
||||
|
||||
class AIAdMakerVideoCreatorBlock(Block, _RevidMixin):
|
||||
"""Generates a 30‑second vertical AI advert using optional user‑supplied imagery."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REVID], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Credentials for Revid.ai API access.",
|
||||
)
|
||||
script: str = SchemaField(
|
||||
description="Short advertising copy. Line breaks create new scenes.",
|
||||
placeholder="Introducing Foobar – [show product photo] the gadget that does it all.",
|
||||
)
|
||||
ratio: str = SchemaField(description="Aspect ratio", default="9 / 16")
|
||||
target_duration: int = SchemaField(
|
||||
description="Desired length of the ad in seconds.", default=30
|
||||
)
|
||||
voice: Voice = SchemaField(
|
||||
description="Narration voice", default=Voice.EVA, placeholder=Voice.EVA
|
||||
)
|
||||
background_music: AudioTrack = SchemaField(
|
||||
description="Background track",
|
||||
default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS,
|
||||
)
|
||||
input_media_urls: list[str] = SchemaField(
|
||||
description="List of image URLs to feature in the advert.", default=[]
|
||||
)
|
||||
use_only_provided_media: bool = SchemaField(
|
||||
description="Restrict visuals to supplied images only.", default=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
video_url: str = SchemaField(description="URL of the finished advert")
|
||||
error: str = SchemaField(description="Error message on failure")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3e3fd845-000e-457f-9f50-9f2f9e278bbd",
|
||||
description="Creates an AI‑generated 30‑second advert (text + images)",
|
||||
categories={BlockCategory.MARKETING, BlockCategory.AI},
|
||||
input_schema=AIAdMakerVideoCreatorBlock.Input,
|
||||
output_schema=AIAdMakerVideoCreatorBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"script": "Test product launch!",
|
||||
"input_media_urls": [
|
||||
"https://cdn.revid.ai/uploads/1747076315114-image.png",
|
||||
],
|
||||
},
|
||||
test_output=("video_url", "https://example.com/ad.mp4"),
|
||||
test_mock={
|
||||
"create_video": lambda api_key, payload: {"pid": "test_pid"},
|
||||
"wait_for_video": lambda api_key, pid, max_wait_time=3600: "https://example.com/ad.mp4",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
||||
|
||||
payload = {
|
||||
"webhook": None,
|
||||
"creationParams": {
|
||||
"targetDuration": input_data.target_duration,
|
||||
"ratio": input_data.ratio,
|
||||
"mediaType": "aiVideo",
|
||||
"inputText": input_data.script,
|
||||
"flowType": "text-to-video",
|
||||
"slug": "ai-ad-generator",
|
||||
"slugNew": "",
|
||||
"isCopiedFrom": False,
|
||||
"hasToGenerateVoice": True,
|
||||
"hasToTranscript": False,
|
||||
"hasToSearchMedia": True,
|
||||
"hasAvatar": False,
|
||||
"hasWebsiteRecorder": False,
|
||||
"hasTextSmallAtBottom": False,
|
||||
"selectedAudio": input_data.background_music.value,
|
||||
"selectedVoice": input_data.voice.voice_id,
|
||||
"selectedAvatar": "https://cdn.revid.ai/avatars/young-woman.mp4",
|
||||
"selectedAvatarType": "video/mp4",
|
||||
"websiteToRecord": "",
|
||||
"hasToGenerateCover": True,
|
||||
"nbGenerations": 1,
|
||||
"disableCaptions": False,
|
||||
"mediaMultiplier": "medium",
|
||||
"characters": [],
|
||||
"captionPresetName": "Revid",
|
||||
"sourceType": "contentScraping",
|
||||
"selectedStoryStyle": {"value": "custom", "label": "General"},
|
||||
"generationPreset": "DEFAULT",
|
||||
"hasToGenerateMusic": False,
|
||||
"isOptimizedForChinese": False,
|
||||
"generationUserPrompt": "",
|
||||
"enableNsfwFilter": False,
|
||||
"addStickers": False,
|
||||
"typeMovingImageAnim": "dynamic",
|
||||
"hasToGenerateSoundEffects": False,
|
||||
"forceModelType": "gpt-image-1",
|
||||
"selectedCharacters": [],
|
||||
"lang": "",
|
||||
"voiceSpeed": 1,
|
||||
"disableAudio": False,
|
||||
"disableVoice": False,
|
||||
"useOnlyProvidedMedia": input_data.use_only_provided_media,
|
||||
"imageGenerationModel": "ultra",
|
||||
"videoGenerationModel": "base",
|
||||
"hasEnhancedGeneration": True,
|
||||
"hasEnhancedGenerationPro": True,
|
||||
"inputMedias": [
|
||||
{"url": url, "title": "", "type": "image"}
|
||||
for url in input_data.input_media_urls
|
||||
],
|
||||
"hasToGenerateVideos": True,
|
||||
"audioUrl": input_data.background_music.audio_url,
|
||||
"watermark": None,
|
||||
},
|
||||
}
|
||||
|
||||
response = self.create_video(credentials.api_key, payload)
|
||||
pid = response.get("pid")
|
||||
if not pid:
|
||||
raise RuntimeError("Failed to create video: No project ID returned")
|
||||
|
||||
video_url = self.wait_for_video(credentials.api_key, pid)
|
||||
yield "video_url", video_url
|
||||
|
||||
|
||||
class AIPromptToVideoCreatorBlock(Block, _RevidMixin):
|
||||
"""Turns a single creative prompt into a fully AI‑generated video."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REVID], Literal["api_key"]
|
||||
] = CredentialsField(description="Revid.ai API credentials")
|
||||
prompt: str = SchemaField(
|
||||
description="Imaginative prompt describing the desired video.",
|
||||
placeholder="A neon‑lit cyberpunk alley with rain‑soaked pavements.",
|
||||
)
|
||||
ratio: str = SchemaField(default="9 / 16")
|
||||
prompt_target_duration: int = SchemaField(default=30)
|
||||
voice: Voice = SchemaField(default=Voice.EVA)
|
||||
background_music: AudioTrack = SchemaField(
|
||||
default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
video_url: str = SchemaField(description="Rendered video URL")
|
||||
error: str = SchemaField(description="Error message if any")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="46f4099c-ad01-4d79-874c-37a24c937ba3",
|
||||
description="Creates an AI video from a single prompt (no line‑breaking script).",
|
||||
categories={BlockCategory.AI, BlockCategory.SOCIAL},
|
||||
input_schema=AIPromptToVideoCreatorBlock.Input,
|
||||
output_schema=AIPromptToVideoCreatorBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"prompt": "Epic time‑lapse of a city skyline from day to night",
|
||||
},
|
||||
test_output=("video_url", "https://example.com/prompt.mp4"),
|
||||
test_mock={
|
||||
"create_video": lambda api_key, payload: {"pid": "test_pid"},
|
||||
"wait_for_video": lambda api_key, pid, max_wait_time=3600: "https://example.com/prompt.mp4",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
||||
|
||||
payload = {
|
||||
"webhook": None,
|
||||
"creationParams": {
|
||||
"mediaType": "aiVideo",
|
||||
"flowType": "prompt-to-video",
|
||||
"slug": "prompt-to-video",
|
||||
"slugNew": "",
|
||||
"isCopiedFrom": False,
|
||||
"hasToGenerateVoice": True,
|
||||
"hasToTranscript": False,
|
||||
"hasToSearchMedia": True,
|
||||
"hasAvatar": False,
|
||||
"hasWebsiteRecorder": False,
|
||||
"hasTextSmallAtBottom": False,
|
||||
"ratio": input_data.ratio,
|
||||
"selectedAudio": input_data.background_music.value,
|
||||
"selectedVoice": input_data.voice.voice_id,
|
||||
"selectedAvatar": "https://cdn.revid.ai/avatars/young-woman.mp4",
|
||||
"selectedAvatarType": "video/mp4",
|
||||
"websiteToRecord": "",
|
||||
"hasToGenerateCover": True,
|
||||
"nbGenerations": 1,
|
||||
"disableCaptions": False,
|
||||
"characters": [],
|
||||
"captionPresetName": "Revid",
|
||||
"sourceType": "contentScraping",
|
||||
"selectedStoryStyle": {"value": "custom", "label": "General"},
|
||||
"generationPreset": "DEFAULT",
|
||||
"hasToGenerateMusic": False,
|
||||
"isOptimizedForChinese": False,
|
||||
"generationUserPrompt": input_data.prompt,
|
||||
"enableNsfwFilter": False,
|
||||
"addStickers": False,
|
||||
"typeMovingImageAnim": "dynamic",
|
||||
"hasToGenerateSoundEffects": False,
|
||||
"promptTargetDuration": input_data.prompt_target_duration,
|
||||
"selectedCharacters": [],
|
||||
"lang": "",
|
||||
"voiceSpeed": 1,
|
||||
"disableAudio": False,
|
||||
"disableVoice": False,
|
||||
"imageGenerationModel": "good",
|
||||
"videoGenerationModel": "base",
|
||||
"hasEnhancedGeneration": True,
|
||||
"hasEnhancedGenerationPro": True,
|
||||
"inputMedias": [],
|
||||
"hasToGenerateVideos": True,
|
||||
"audioUrl": input_data.background_music.audio_url,
|
||||
"watermark": None,
|
||||
},
|
||||
}
|
||||
|
||||
response = self.create_video(credentials.api_key, payload)
|
||||
pid = response.get("pid")
|
||||
if not pid:
|
||||
raise RuntimeError("Failed to create video: No project ID returned")
|
||||
|
||||
video_url = self.wait_for_video(credentials.api_key, pid)
|
||||
yield "video_url", video_url
|
||||
|
||||
|
||||
class AIScreenshotToVideoAdBlock(Block, _RevidMixin):
|
||||
"""Creates an advert where the supplied screenshot is narrated by an AI avatar."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REVID], Literal["api_key"]
|
||||
] = CredentialsField(description="Revid.ai API key")
|
||||
script: str = SchemaField(
|
||||
description="Narration that will accompany the screenshot.",
|
||||
placeholder="Check out these amazing stats!",
|
||||
)
|
||||
screenshot_url: str = SchemaField(
|
||||
description="Screenshot or image URL to showcase."
|
||||
)
|
||||
ratio: str = SchemaField(default="9 / 16")
|
||||
target_duration: int = SchemaField(default=30)
|
||||
voice: Voice = SchemaField(default=Voice.EVA)
|
||||
background_music: AudioTrack = SchemaField(
|
||||
default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
video_url: str = SchemaField(description="Rendered video URL")
|
||||
error: str = SchemaField(description="Error, if encountered")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9f68982c-3af6-4923-9a97-b50a8c8d2234",
|
||||
description="Turns a screenshot into an engaging, avatar‑narrated video advert.",
|
||||
categories={BlockCategory.AI, BlockCategory.MARKETING},
|
||||
input_schema=AIScreenshotToVideoAdBlock.Input,
|
||||
output_schema=AIScreenshotToVideoAdBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"script": "Amazing numbers!",
|
||||
"screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
|
||||
},
|
||||
test_output=("video_url", "https://example.com/screenshot.mp4"),
|
||||
test_mock={
|
||||
"create_video": lambda api_key, payload: {"pid": "test_pid"},
|
||||
"wait_for_video": lambda api_key, pid, max_wait_time=3600: "https://example.com/screenshot.mp4",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
||||
|
||||
payload = {
|
||||
"webhook": None,
|
||||
"creationParams": {
|
||||
"targetDuration": input_data.target_duration,
|
||||
"ratio": input_data.ratio,
|
||||
"mediaType": "aiVideo",
|
||||
"hasAvatar": True,
|
||||
"removeAvatarBackground": True,
|
||||
"inputText": input_data.script,
|
||||
"flowType": "text-to-video",
|
||||
"slug": "ai-ad-generator",
|
||||
"slugNew": "screenshot-to-video-ad",
|
||||
"isCopiedFrom": "ai-ad-generator",
|
||||
"hasToGenerateVoice": True,
|
||||
"hasToTranscript": False,
|
||||
"hasToSearchMedia": True,
|
||||
"hasWebsiteRecorder": False,
|
||||
"hasTextSmallAtBottom": False,
|
||||
"selectedAudio": input_data.background_music.value,
|
||||
"selectedVoice": input_data.voice.voice_id,
|
||||
"selectedAvatar": "https://cdn.revid.ai/avatars/young-woman.mp4",
|
||||
"selectedAvatarType": "video/mp4",
|
||||
"websiteToRecord": "",
|
||||
"hasToGenerateCover": True,
|
||||
"nbGenerations": 1,
|
||||
"disableCaptions": False,
|
||||
"mediaMultiplier": "medium",
|
||||
"characters": [],
|
||||
"captionPresetName": "Revid",
|
||||
"sourceType": "contentScraping",
|
||||
"selectedStoryStyle": {"value": "custom", "label": "General"},
|
||||
"generationPreset": "DEFAULT",
|
||||
"hasToGenerateMusic": False,
|
||||
"isOptimizedForChinese": False,
|
||||
"generationUserPrompt": "",
|
||||
"enableNsfwFilter": False,
|
||||
"addStickers": False,
|
||||
"typeMovingImageAnim": "dynamic",
|
||||
"hasToGenerateSoundEffects": False,
|
||||
"forceModelType": "gpt-image-1",
|
||||
"selectedCharacters": [],
|
||||
"lang": "",
|
||||
"voiceSpeed": 1,
|
||||
"disableAudio": False,
|
||||
"disableVoice": False,
|
||||
"useOnlyProvidedMedia": True,
|
||||
"imageGenerationModel": "ultra",
|
||||
"videoGenerationModel": "base",
|
||||
"hasEnhancedGeneration": True,
|
||||
"hasEnhancedGenerationPro": True,
|
||||
"inputMedias": [
|
||||
{"url": input_data.screenshot_url, "title": "", "type": "image"}
|
||||
],
|
||||
"hasToGenerateVideos": True,
|
||||
"audioUrl": input_data.background_music.audio_url,
|
||||
"watermark": None,
|
||||
},
|
||||
}
|
||||
|
||||
response = self.create_video(credentials.api_key, payload)
|
||||
pid = response.get("pid")
|
||||
if not pid:
|
||||
raise RuntimeError("Failed to create video: No project ID returned")
|
||||
|
||||
video_url = self.wait_for_video(credentials.api_key, pid)
|
||||
yield "video_url", video_url
|
||||
else:
|
||||
logger.debug(
|
||||
f"Video created with project ID: {pid}. Waiting for completion..."
|
||||
)
|
||||
video_url = self.wait_for_video(credentials.api_key, pid, webhook_token)
|
||||
logger.debug(f"Video ready: {video_url}")
|
||||
yield "video_url", video_url
|
||||
|
||||
@@ -1,482 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from backend.util.request import Requests
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
class AyrshareAPIException(Exception):
|
||||
def __init__(self, message: str, status_code: int):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class SocialPlatform(str, Enum):
|
||||
BLUESKY = "bluesky"
|
||||
FACEBOOK = "facebook"
|
||||
TWITTER = "twitter"
|
||||
LINKEDIN = "linkedin"
|
||||
INSTAGRAM = "instagram"
|
||||
YOUTUBE = "youtube"
|
||||
REDDIT = "reddit"
|
||||
TELEGRAM = "telegram"
|
||||
GMB = "gmb"
|
||||
PINTEREST = "pinterest"
|
||||
TIKTOK = "tiktok"
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmailConfig:
|
||||
to: str
|
||||
subject: Optional[str] = None
|
||||
body: Optional[str] = None
|
||||
from_name: Optional[str] = None
|
||||
from_email: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class JWTResponse:
|
||||
status: str
|
||||
title: str
|
||||
token: str
|
||||
url: str
|
||||
emailSent: Optional[bool] = None
|
||||
expiresIn: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProfileResponse:
|
||||
status: str
|
||||
title: str
|
||||
refId: str
|
||||
profileKey: str
|
||||
messagingActive: Optional[bool] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PostResponse:
|
||||
status: str
|
||||
id: str
|
||||
refId: str
|
||||
profileTitle: str
|
||||
post: str
|
||||
postIds: Optional[List[Dict[str, Any]]] = None
|
||||
scheduleDate: Optional[str] = None
|
||||
errors: Optional[List[str]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoHashtag:
|
||||
max: Optional[int] = None
|
||||
position: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FirstComment:
|
||||
text: str
|
||||
platforms: Optional[List[SocialPlatform]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoSchedule:
|
||||
interval: str
|
||||
platforms: Optional[List[SocialPlatform]] = None
|
||||
startDate: Optional[str] = None
|
||||
endDate: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoRepost:
|
||||
interval: str
|
||||
platforms: Optional[List[SocialPlatform]] = None
|
||||
startDate: Optional[str] = None
|
||||
endDate: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PostError:
|
||||
code: int
|
||||
message: str
|
||||
details: str
|
||||
|
||||
|
||||
class AyrshareClient:
|
||||
"""Client for the Ayrshare Social Media Post API"""
|
||||
|
||||
API_URL = "https://api.ayrshare.com/api"
|
||||
POST_ENDPOINT = f"{API_URL}/post"
|
||||
PROFILES_ENDPOINT = f"{API_URL}/profiles"
|
||||
JWT_ENDPOINT = f"{PROFILES_ENDPOINT}/generateJWT"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
custom_requests: Optional[Requests] = None,
|
||||
):
|
||||
headers: Dict[str, str] = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {settings.secrets.ayrshare_api_key}",
|
||||
}
|
||||
self.headers = headers
|
||||
|
||||
if custom_requests:
|
||||
self._requests = custom_requests
|
||||
else:
|
||||
self._requests = Requests(
|
||||
extra_headers=headers,
|
||||
trusted_origins=["https://api.ayrshare.com"],
|
||||
raise_for_status=False,
|
||||
)
|
||||
|
||||
def generate_jwt(
|
||||
self,
|
||||
private_key: str,
|
||||
profile_key: str,
|
||||
logout: Optional[bool] = None,
|
||||
redirect: Optional[str] = None,
|
||||
allowed_social: Optional[List[SocialPlatform]] = None,
|
||||
verify: Optional[bool] = None,
|
||||
base64: Optional[bool] = None,
|
||||
expires_in: Optional[int] = None,
|
||||
email: Optional[EmailConfig] = None,
|
||||
) -> JWTResponse:
|
||||
"""
|
||||
Generate a JSON Web Token (JWT) for use with single sign on.
|
||||
|
||||
Args:
|
||||
domain: Domain of app. Must match the domain given during onboarding.
|
||||
private_key: Private Key used for encryption.
|
||||
profile_key: User Profile Key (not the API Key).
|
||||
logout: Automatically logout the current session.
|
||||
redirect: URL to redirect to when the "Done" button or logo is clicked.
|
||||
allowed_social: List of social networks to display in the linking page.
|
||||
verify: Verify that the generated token is valid (recommended for non-production).
|
||||
base64: Whether the private key is base64 encoded.
|
||||
expires_in: Token longevity in minutes (1-2880).
|
||||
email: Configuration for sending Connect Accounts email.
|
||||
|
||||
Returns:
|
||||
JWTResponse object containing the JWT token and URL.
|
||||
|
||||
Raises:
|
||||
AyrshareAPIException: If the API request fails or private key is invalid.
|
||||
"""
|
||||
payload: Dict[str, Any] = {
|
||||
"domain": "id-pojeg",
|
||||
"privateKey": private_key,
|
||||
"profileKey": profile_key,
|
||||
}
|
||||
|
||||
headers = self.headers
|
||||
headers["Profile-Key"] = profile_key
|
||||
if logout is not None:
|
||||
payload["logout"] = logout
|
||||
if redirect is not None:
|
||||
payload["redirect"] = redirect
|
||||
if allowed_social is not None:
|
||||
payload["allowedSocial"] = [p.value for p in allowed_social]
|
||||
if verify is not None:
|
||||
payload["verify"] = verify
|
||||
if base64 is not None:
|
||||
payload["base64"] = base64
|
||||
if expires_in is not None:
|
||||
payload["expiresIn"] = expires_in
|
||||
if email is not None:
|
||||
payload["email"] = email.__dict__
|
||||
|
||||
response = self._requests.post(self.JWT_ENDPOINT, json=payload, headers=headers)
|
||||
|
||||
if not response.ok:
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_message = error_data.get("message", response.text)
|
||||
except json.JSONDecodeError:
|
||||
error_message = response.text
|
||||
|
||||
raise AyrshareAPIException(
|
||||
f"Ayrshare API request failed ({response.status_code}): {error_message}",
|
||||
response.status_code,
|
||||
)
|
||||
|
||||
response_data = response.json()
|
||||
if response_data.get("status") != "success":
|
||||
raise AyrshareAPIException(
|
||||
f"Ayrshare API returned error: {response_data.get('message', 'Unknown error')}",
|
||||
response.status_code,
|
||||
)
|
||||
|
||||
return JWTResponse(**response_data)
|
||||
|
||||
def create_profile(
|
||||
self,
|
||||
title: str,
|
||||
messaging_active: Optional[bool] = None,
|
||||
hide_top_header: Optional[bool] = None,
|
||||
top_header: Optional[str] = None,
|
||||
disable_social: Optional[List[SocialPlatform]] = None,
|
||||
team: Optional[bool] = None,
|
||||
email: Optional[str] = None,
|
||||
sub_header: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
) -> ProfileResponse | PostError:
|
||||
"""
|
||||
Create a new User Profile under your Primary Profile.
|
||||
|
||||
Args:
|
||||
title: Title of the new profile. Must be unique.
|
||||
messaging_active: Set to true to activate messaging for this user profile.
|
||||
hide_top_header: Hide the top header on the social accounts linkage page.
|
||||
top_header: Change the header on the social accounts linkage page.
|
||||
disable_social: Array of social networks that are disabled for this user's profile.
|
||||
team: Create a new user profile as a team member.
|
||||
email: Email address for team member invite (required if team is true).
|
||||
sub_header: Change the sub header on the social accounts linkage page.
|
||||
tags: Array of strings to tag user profiles.
|
||||
|
||||
Returns:
|
||||
ProfileResponse object containing the profile details and profile key.
|
||||
|
||||
Raises:
|
||||
AyrshareAPIException: If the API request fails or profile title already exists.
|
||||
"""
|
||||
payload: Dict[str, Any] = {
|
||||
"title": title,
|
||||
}
|
||||
|
||||
if messaging_active is not None:
|
||||
payload["messagingActive"] = messaging_active
|
||||
if hide_top_header is not None:
|
||||
payload["hideTopHeader"] = hide_top_header
|
||||
if top_header is not None:
|
||||
payload["topHeader"] = top_header
|
||||
if disable_social is not None:
|
||||
payload["disableSocial"] = [p.value for p in disable_social]
|
||||
if team is not None:
|
||||
payload["team"] = team
|
||||
if email is not None:
|
||||
payload["email"] = email
|
||||
if sub_header is not None:
|
||||
payload["subHeader"] = sub_header
|
||||
if tags is not None:
|
||||
payload["tags"] = tags
|
||||
|
||||
response = self._requests.post(self.PROFILES_ENDPOINT, json=payload)
|
||||
|
||||
if not response.ok:
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_message = error_data.get("message", response.text)
|
||||
except json.JSONDecodeError:
|
||||
error_message = response.text
|
||||
|
||||
raise AyrshareAPIException(
|
||||
f"Ayrshare API request failed ({response.status_code}): {error_message}",
|
||||
response.status_code,
|
||||
)
|
||||
|
||||
response_data = response.json()
|
||||
if response_data.get("status") != "success":
|
||||
raise AyrshareAPIException(
|
||||
f"Ayrshare API returned error: {response_data.get('message', 'Unknown error')}",
|
||||
response.status_code,
|
||||
)
|
||||
|
||||
return ProfileResponse(**response_data)
|
||||
|
||||
def create_post(
|
||||
self,
|
||||
post: str,
|
||||
platforms: List[SocialPlatform],
|
||||
media_urls: Optional[List[str]] = None,
|
||||
is_video: Optional[bool] = None,
|
||||
schedule_date: Optional[str] = None,
|
||||
first_comment: Optional[FirstComment] = None,
|
||||
disable_comments: Optional[bool] = None,
|
||||
shorten_links: Optional[bool] = None,
|
||||
auto_schedule: Optional[AutoSchedule] = None,
|
||||
auto_repost: Optional[AutoRepost] = None,
|
||||
auto_hashtag: Optional[Union[AutoHashtag, bool]] = None,
|
||||
unsplash: Optional[str] = None,
|
||||
bluesky_options: Optional[Dict[str, Any]] = None,
|
||||
facebook_options: Optional[Dict[str, Any]] = None,
|
||||
gmb_options: Optional[Dict[str, Any]] = None,
|
||||
instagram_options: Optional[Dict[str, Any]] = None,
|
||||
linkedin_options: Optional[Dict[str, Any]] = None,
|
||||
pinterest_options: Optional[Dict[str, Any]] = None,
|
||||
reddit_options: Optional[Dict[str, Any]] = None,
|
||||
telegram_options: Optional[Dict[str, Any]] = None,
|
||||
threads_options: Optional[Dict[str, Any]] = None,
|
||||
tiktok_options: Optional[Dict[str, Any]] = None,
|
||||
twitter_options: Optional[Dict[str, Any]] = None,
|
||||
youtube_options: Optional[Dict[str, Any]] = None,
|
||||
requires_approval: Optional[bool] = None,
|
||||
random_post: Optional[bool] = None,
|
||||
random_media_url: Optional[bool] = None,
|
||||
idempotency_key: Optional[str] = None,
|
||||
notes: Optional[str] = None,
|
||||
profile_key: Optional[str] = None,
|
||||
) -> PostResponse | PostError:
|
||||
"""
|
||||
Create a post across multiple social media platforms.
|
||||
|
||||
Args:
|
||||
post: The post text to be published
|
||||
platforms: List of platforms to post to (e.g. [SocialPlatform.TWITTER, SocialPlatform.FACEBOOK])
|
||||
media_urls: Optional list of media URLs to include
|
||||
is_video: Whether the media is a video
|
||||
schedule_date: UTC datetime for scheduling (YYYY-MM-DDThh:mm:ssZ)
|
||||
first_comment: Configuration for first comment
|
||||
disable_comments: Whether to disable comments
|
||||
shorten_links: Whether to shorten links
|
||||
auto_schedule: Configuration for automatic scheduling
|
||||
auto_repost: Configuration for automatic reposting
|
||||
auto_hashtag: Configuration for automatic hashtags
|
||||
unsplash: Unsplash image configuration
|
||||
bluesky_options: Bluesky-specific options
|
||||
facebook_options: Facebook-specific options
|
||||
gmb_options: Google Business Profile options
|
||||
instagram_options: Instagram-specific options
|
||||
linkedin_options: LinkedIn-specific options
|
||||
pinterest_options: Pinterest-specific options
|
||||
reddit_options: Reddit-specific options
|
||||
telegram_options: Telegram-specific options
|
||||
threads_options: Threads-specific options
|
||||
tiktok_options: TikTok-specific options
|
||||
twitter_options: Twitter-specific options
|
||||
youtube_options: YouTube-specific options
|
||||
requires_approval: Whether to enable approval workflow
|
||||
random_post: Whether to generate random post text
|
||||
random_media_url: Whether to generate random media
|
||||
idempotency_key: Unique ID for the post
|
||||
notes: Additional notes for the post
|
||||
|
||||
Returns:
|
||||
PostResponse object containing the post details and status
|
||||
|
||||
Raises:
|
||||
AyrshareAPIException: If the API request fails
|
||||
"""
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"post": post,
|
||||
"platforms": [p.value for p in platforms],
|
||||
}
|
||||
|
||||
# Add optional parameters if provided
|
||||
if media_urls:
|
||||
payload["mediaUrls"] = media_urls
|
||||
if is_video is not None:
|
||||
payload["isVideo"] = is_video
|
||||
if schedule_date:
|
||||
payload["scheduleDate"] = schedule_date
|
||||
if first_comment:
|
||||
first_comment_dict = first_comment.__dict__.copy()
|
||||
if first_comment.platforms:
|
||||
first_comment_dict["platforms"] = [
|
||||
p.value for p in first_comment.platforms
|
||||
]
|
||||
payload["firstComment"] = first_comment_dict
|
||||
if disable_comments is not None:
|
||||
payload["disableComments"] = disable_comments
|
||||
if shorten_links is not None:
|
||||
payload["shortenLinks"] = shorten_links
|
||||
if auto_schedule:
|
||||
auto_schedule_dict = auto_schedule.__dict__.copy()
|
||||
if auto_schedule.platforms:
|
||||
auto_schedule_dict["platforms"] = [
|
||||
p.value for p in auto_schedule.platforms
|
||||
]
|
||||
payload["autoSchedule"] = auto_schedule_dict
|
||||
if auto_repost:
|
||||
auto_repost_dict = auto_repost.__dict__.copy()
|
||||
if auto_repost.platforms:
|
||||
auto_repost_dict["platforms"] = [p.value for p in auto_repost.platforms]
|
||||
payload["autoRepost"] = auto_repost_dict
|
||||
if auto_hashtag:
|
||||
payload["autoHashtag"] = (
|
||||
auto_hashtag.__dict__
|
||||
if isinstance(auto_hashtag, AutoHashtag)
|
||||
else auto_hashtag
|
||||
)
|
||||
if unsplash:
|
||||
payload["unsplash"] = unsplash
|
||||
if bluesky_options:
|
||||
payload["blueskyOptions"] = bluesky_options
|
||||
if facebook_options:
|
||||
payload["faceBookOptions"] = facebook_options
|
||||
if gmb_options:
|
||||
payload["gmbOptions"] = gmb_options
|
||||
if instagram_options:
|
||||
payload["instagramOptions"] = instagram_options
|
||||
if linkedin_options:
|
||||
payload["linkedInOptions"] = linkedin_options
|
||||
if pinterest_options:
|
||||
payload["pinterestOptions"] = pinterest_options
|
||||
if reddit_options:
|
||||
payload["redditOptions"] = reddit_options
|
||||
if telegram_options:
|
||||
payload["telegramOptions"] = telegram_options
|
||||
if threads_options:
|
||||
payload["threadsOptions"] = threads_options
|
||||
if tiktok_options:
|
||||
payload["tikTokOptions"] = tiktok_options
|
||||
if twitter_options:
|
||||
payload["twitterOptions"] = twitter_options
|
||||
if youtube_options:
|
||||
payload["youTubeOptions"] = youtube_options
|
||||
if requires_approval is not None:
|
||||
payload["requiresApproval"] = requires_approval
|
||||
if random_post is not None:
|
||||
payload["randomPost"] = random_post
|
||||
if random_media_url is not None:
|
||||
payload["randomMediaUrl"] = random_media_url
|
||||
if idempotency_key:
|
||||
payload["idempotencyKey"] = idempotency_key
|
||||
if notes:
|
||||
payload["notes"] = notes
|
||||
|
||||
headers = self.headers
|
||||
if profile_key:
|
||||
headers["Profile-Key"] = profile_key
|
||||
|
||||
response = self._requests.post(
|
||||
self.POST_ENDPOINT, json=payload, headers=headers
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_message = error_data.get("message", response.text)
|
||||
error_code = error_data.get("code", response.status_code)
|
||||
error_details = error_data.get("details", {})
|
||||
logger.error(error_data)
|
||||
return PostError(
|
||||
code=error_code,
|
||||
message=error_message,
|
||||
details=error_details,
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
error_message = response.text
|
||||
|
||||
raise AyrshareAPIException(
|
||||
f"Ayrshare API request failed ({response.status_code}): {error_message}",
|
||||
response.status_code,
|
||||
)
|
||||
|
||||
response_data = response.json()
|
||||
if response_data.get("status") != "success":
|
||||
raise AyrshareAPIException(
|
||||
f"Ayrshare API returned error: {response_data.get('message', 'Unknown error')}",
|
||||
response.status_code,
|
||||
)
|
||||
|
||||
# Return the first post from the response
|
||||
return PostResponse(**response_data["posts"][0])
|
||||
@@ -1,531 +0,0 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from backend.blocks.ayrshare._api import (
|
||||
AyrshareClient,
|
||||
PostError,
|
||||
PostResponse,
|
||||
SocialPlatform,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
|
||||
from backend.data.model import SchemaField
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
creads_store = IntegrationCredentialsStore()
|
||||
|
||||
|
||||
class RequestOutput(BaseModel):
|
||||
"""Base output model for Ayrshare social media posts."""
|
||||
|
||||
status: str = Field(..., description="Status of the post")
|
||||
id: str = Field(..., description="ID of the post")
|
||||
refId: str = Field(..., description="Reference ID of the post")
|
||||
profileTitle: str = Field(..., description="Title of the profile")
|
||||
post: str = Field(..., description="The post text")
|
||||
postIds: Optional[List[dict]] = Field(
|
||||
description="IDs of the posts on each platform"
|
||||
)
|
||||
scheduleDate: Optional[str] = Field(description="Scheduled date of the post")
|
||||
errors: Optional[List[str]] = Field(description="Any errors that occurred")
|
||||
|
||||
|
||||
class AyrsharePostBlockBase(Block):
|
||||
"""Base class for Ayrshare social media posting blocks."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
"""Base input model for Ayrshare social media posts."""
|
||||
|
||||
post: str = SchemaField(
|
||||
description="The post text to be published", default="", advanced=False
|
||||
)
|
||||
media_urls: List[str] = SchemaField(
|
||||
description="Optional list of media URLs to include. Set is_video in advanced settings to true if you want to upload videos.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
is_video: bool = SchemaField(
|
||||
description="Whether the media is a video", default=False, advanced=True
|
||||
)
|
||||
schedule_date: Optional[datetime] = SchemaField(
|
||||
description="UTC datetime for scheduling (YYYY-MM-DDThh:mm:ssZ)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
disable_comments: bool = SchemaField(
|
||||
description="Whether to disable comments", default=False, advanced=True
|
||||
)
|
||||
shorten_links: bool = SchemaField(
|
||||
description="Whether to shorten links", default=False, advanced=True
|
||||
)
|
||||
|
||||
unsplash: Optional[str] = SchemaField(
|
||||
description="Unsplash image configuration", default=None, advanced=True
|
||||
)
|
||||
requires_approval: bool = SchemaField(
|
||||
description="Whether to enable approval workflow",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
random_post: bool = SchemaField(
|
||||
description="Whether to generate random post text",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
random_media_url: bool = SchemaField(
|
||||
description="Whether to generate random media", default=False, advanced=True
|
||||
)
|
||||
notes: Optional[str] = SchemaField(
|
||||
description="Additional notes for the post", default=None, advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
post_result: RequestOutput = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id="b3a7b3b9-5169-410a-9d5c-fd625460fb14",
|
||||
description="Ayrshare Post",
|
||||
):
|
||||
super().__init__(
|
||||
# The unique identifier for the block, this value will be persisted in the DB.
|
||||
# It should be unique and constant across the application run.
|
||||
# Use the UUID format for the ID.
|
||||
id=id,
|
||||
# The description of the block, explaining what the block does.
|
||||
description=description,
|
||||
# The set of categories that the block belongs to.
|
||||
# Each category is an instance of BlockCategory Enum.
|
||||
categories={BlockCategory.SOCIAL},
|
||||
# The type of block, this is used to determine the block type in the UI.
|
||||
block_type=BlockType.AYRSHARE,
|
||||
# The schema, defined as a Pydantic model, for the input data.
|
||||
input_schema=AyrsharePostBlockBase.Input,
|
||||
# The schema, defined as a Pydantic model, for the output data.
|
||||
output_schema=AyrsharePostBlockBase.Output,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_client():
|
||||
return AyrshareClient()
|
||||
|
||||
def _create_post(
|
||||
self,
|
||||
input_data: "AyrsharePostBlockBase.Input",
|
||||
platforms: List[SocialPlatform],
|
||||
profile_key: Optional[str] = None,
|
||||
) -> PostResponse | PostError:
|
||||
client = self.create_client()
|
||||
"""Create a post on the specified platforms."""
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
response = client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=platforms,
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=input_data.is_video,
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
profile_key=profile_key,
|
||||
)
|
||||
return response
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: "AyrsharePostBlockBase.Input",
|
||||
*,
|
||||
profile_key: SecretStr,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Run the block."""
|
||||
platforms = [SocialPlatform.FACEBOOK]
|
||||
|
||||
if not profile_key:
|
||||
yield "error", "Please Link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
post_result = self._create_post(
|
||||
input_data, platforms=platforms, profile_key=profile_key.get_secret_value()
|
||||
)
|
||||
if isinstance(post_result, PostError):
|
||||
yield "error", post_result.message
|
||||
return
|
||||
yield "post_result", post_result
|
||||
|
||||
|
||||
class PostToFacebookBlock(AyrsharePostBlockBase):
|
||||
"""Block for posting to Facebook."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3352f512-3524-49ed-a08f-003042da2fc1",
|
||||
description="Post to Facebook using Ayrshare",
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: AyrsharePostBlockBase.Input,
|
||||
*,
|
||||
profile_key: SecretStr,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Facebook."""
|
||||
if not profile_key:
|
||||
yield "error", "Please Link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
post_result = self._create_post(
|
||||
input_data,
|
||||
[SocialPlatform.FACEBOOK],
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
if isinstance(post_result, PostError):
|
||||
yield "error", post_result.message
|
||||
return
|
||||
yield "post_result", post_result
|
||||
|
||||
|
||||
class PostToXBlock(AyrsharePostBlockBase):
|
||||
"""Block for posting to X / Twitter."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9e8f844e-b4a5-4b25-80f2-9e1dd7d67625",
|
||||
description="Post to X / Twitter using Ayrshare",
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: AyrsharePostBlockBase.Input,
|
||||
*,
|
||||
profile_key: SecretStr,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Twitter."""
|
||||
if not profile_key:
|
||||
yield "error", "Please Link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
post_result = self._create_post(
|
||||
input_data,
|
||||
[SocialPlatform.TWITTER],
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
if isinstance(post_result, PostError):
|
||||
yield "error", post_result.message
|
||||
return
|
||||
yield "post_result", post_result
|
||||
|
||||
|
||||
class PostToLinkedInBlock(AyrsharePostBlockBase):
|
||||
"""Block for posting to LinkedIn."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="589af4e4-507f-42fd-b9ac-a67ecef25811",
|
||||
description="Post to LinkedIn using Ayrshare",
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: AyrsharePostBlockBase.Input,
|
||||
*,
|
||||
profile_key: SecretStr,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to LinkedIn."""
|
||||
if not profile_key:
|
||||
yield "error", "Please Link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
post_result = self._create_post(
|
||||
input_data,
|
||||
[SocialPlatform.LINKEDIN],
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
if isinstance(post_result, PostError):
|
||||
yield "error", post_result.message
|
||||
return
|
||||
yield "post_result", post_result
|
||||
|
||||
|
||||
class PostToInstagramBlock(AyrsharePostBlockBase):
|
||||
"""Block for posting to Instagram."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="89b02b96-a7cb-46f4-9900-c48b32fe1552",
|
||||
description="Post to Instagram using Ayrshare",
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: AyrsharePostBlockBase.Input,
|
||||
*,
|
||||
profile_key: SecretStr,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Instagram."""
|
||||
if not profile_key:
|
||||
yield "error", "Please Link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
post_result = self._create_post(
|
||||
input_data,
|
||||
[SocialPlatform.INSTAGRAM],
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
if isinstance(post_result, PostError):
|
||||
yield "error", post_result.message
|
||||
return
|
||||
yield "post_result", post_result
|
||||
|
||||
|
||||
class PostToYouTubeBlock(AyrsharePostBlockBase):
|
||||
"""Block for posting to YouTube."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0082d712-ff1b-4c3d-8a8d-6c7721883b83",
|
||||
description="Post to YouTube using Ayrshare",
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: AyrsharePostBlockBase.Input,
|
||||
*,
|
||||
profile_key: SecretStr,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to YouTube."""
|
||||
if not profile_key:
|
||||
yield "error", "Please Link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
post_result = self._create_post(
|
||||
input_data,
|
||||
[SocialPlatform.YOUTUBE],
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
if isinstance(post_result, PostError):
|
||||
yield "error", post_result.message
|
||||
return
|
||||
yield "post_result", post_result
|
||||
|
||||
|
||||
class PostToRedditBlock(AyrsharePostBlockBase):
|
||||
"""Block for posting to Reddit."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c7733580-3c72-483e-8e47-a8d58754d853",
|
||||
description="Post to Reddit using Ayrshare",
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: AyrsharePostBlockBase.Input,
|
||||
*,
|
||||
profile_key: SecretStr,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Reddit."""
|
||||
if not profile_key:
|
||||
yield "error", "Please Link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
post_result = self._create_post(
|
||||
input_data,
|
||||
[SocialPlatform.REDDIT],
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
if isinstance(post_result, PostError):
|
||||
yield "error", post_result.message
|
||||
return
|
||||
yield "post_result", post_result
|
||||
|
||||
|
||||
class PostToTelegramBlock(AyrsharePostBlockBase):
|
||||
"""Block for posting to Telegram."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="47bc74eb-4af2-452c-b933-af377c7287df",
|
||||
description="Post to Telegram using Ayrshare",
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: AyrsharePostBlockBase.Input,
|
||||
*,
|
||||
profile_key: SecretStr,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Telegram."""
|
||||
if not profile_key:
|
||||
yield "error", "Please Link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
post_result = self._create_post(
|
||||
input_data,
|
||||
[SocialPlatform.TELEGRAM],
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
if isinstance(post_result, PostError):
|
||||
yield "error", post_result.message
|
||||
return
|
||||
yield "post_result", post_result
|
||||
|
||||
|
||||
class PostToGMBBlock(AyrsharePostBlockBase):
|
||||
"""Block for posting to Google My Business."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2c38c783-c484-4503-9280-ef5d1d345a7e",
|
||||
description="Post to Google My Business using Ayrshare",
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: AyrsharePostBlockBase.Input,
|
||||
*,
|
||||
profile_key: SecretStr,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Google My Business."""
|
||||
if not profile_key:
|
||||
yield "error", "Please Link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
post_result = self._create_post(
|
||||
input_data,
|
||||
[SocialPlatform.GMB],
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
if isinstance(post_result, PostError):
|
||||
yield "error", post_result.message
|
||||
return
|
||||
yield "post_result", post_result
|
||||
|
||||
|
||||
class PostToPinterestBlock(AyrsharePostBlockBase):
|
||||
"""Block for posting to Pinterest."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3ca46e05-dbaa-4afb-9e95-5a429c4177e6",
|
||||
description="Post to Pinterest using Ayrshare",
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: AyrsharePostBlockBase.Input,
|
||||
*,
|
||||
profile_key: SecretStr,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Pinterest."""
|
||||
if not profile_key:
|
||||
yield "error", "Please Link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
post_result = self._create_post(
|
||||
input_data,
|
||||
[SocialPlatform.PINTEREST],
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
if isinstance(post_result, PostError):
|
||||
yield "error", post_result.message
|
||||
return
|
||||
yield "post_result", post_result
|
||||
|
||||
|
||||
class PostToTikTokBlock(AyrsharePostBlockBase):
|
||||
"""Block for posting to TikTok."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7faf4b27-96b0-4f05-bf64-e0de54ae74e1",
|
||||
description="Post to TikTok using Ayrshare",
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: AyrsharePostBlockBase.Input,
|
||||
*,
|
||||
profile_key: SecretStr,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to TikTok."""
|
||||
if not profile_key:
|
||||
yield "error", "Please Link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
post_result = self._create_post(
|
||||
input_data,
|
||||
[SocialPlatform.TIKTOK],
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
if isinstance(post_result, PostError):
|
||||
yield "error", post_result.message
|
||||
return
|
||||
yield "post_result", post_result
|
||||
|
||||
|
||||
class PostToBlueskyBlock(AyrsharePostBlockBase):
|
||||
"""Block for posting to Bluesky."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="cbd52c2a-06d2-43ed-9560-6576cc163283",
|
||||
description="Post to Bluesky using Ayrshare",
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: AyrsharePostBlockBase.Input,
|
||||
*,
|
||||
profile_key: SecretStr,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Bluesky."""
|
||||
if not profile_key:
|
||||
yield "error", "Please Link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
post_result = self._create_post(
|
||||
input_data,
|
||||
[SocialPlatform.BLUESKY],
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
if isinstance(post_result, PostError):
|
||||
yield "error", post_result.message
|
||||
return
|
||||
yield "post_result", post_result
|
||||
|
||||
|
||||
AYRSHARE_NODE_IDS = [
|
||||
PostToBlueskyBlock().id,
|
||||
PostToFacebookBlock().id,
|
||||
PostToXBlock().id,
|
||||
PostToLinkedInBlock().id,
|
||||
PostToInstagramBlock().id,
|
||||
PostToYouTubeBlock().id,
|
||||
PostToRedditBlock().id,
|
||||
PostToTelegramBlock().id,
|
||||
PostToGMBBlock().id,
|
||||
PostToPinterestBlock().id,
|
||||
PostToTikTokBlock().id,
|
||||
]
|
||||
@@ -1,19 +1,30 @@
|
||||
from typing import overload
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from backend.blocks.github._auth import (
|
||||
GithubCredentials,
|
||||
GithubFineGrainedAPICredentials,
|
||||
)
|
||||
from backend.util.request import Requests
|
||||
from backend.util.request import URL, Requests
|
||||
|
||||
|
||||
def _convert_to_api_url(url: str) -> str:
|
||||
@overload
|
||||
def _convert_to_api_url(url: str) -> str: ...
|
||||
|
||||
|
||||
@overload
|
||||
def _convert_to_api_url(url: URL) -> URL: ...
|
||||
|
||||
|
||||
def _convert_to_api_url(url: str | URL) -> str | URL:
|
||||
"""
|
||||
Converts a standard GitHub URL to the corresponding GitHub API URL.
|
||||
Handles repository URLs, issue URLs, pull request URLs, and more.
|
||||
"""
|
||||
parsed_url = urlparse(url)
|
||||
path_parts = parsed_url.path.strip("/").split("/")
|
||||
if url_as_str := isinstance(url, str):
|
||||
url = urlparse(url)
|
||||
|
||||
path_parts = url.path.strip("/").split("/")
|
||||
|
||||
if len(path_parts) >= 2:
|
||||
owner, repo = path_parts[0], path_parts[1]
|
||||
@@ -28,7 +39,7 @@ def _convert_to_api_url(url: str) -> str:
|
||||
else:
|
||||
raise ValueError("Invalid GitHub URL format.")
|
||||
|
||||
return api_url
|
||||
return api_url if url_as_str else urlparse(api_url)
|
||||
|
||||
|
||||
def _get_headers(credentials: GithubCredentials) -> dict[str, str]:
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from io import BufferedReader
|
||||
from typing import Any
|
||||
|
||||
from requests.exceptions import HTTPError, RequestException
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
from backend.util.request import requests
|
||||
|
||||
logger = logging.getLogger(name=__name__)
|
||||
@@ -47,10 +45,6 @@ class SendWebRequestBlock(Block):
|
||||
description="The body of the request",
|
||||
default=None,
|
||||
)
|
||||
files: dict[str, MediaFileType] = SchemaField(
|
||||
description="File fields mapping to MediaFileType for multipart upload",
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
response: object = SchemaField(description="The response from the server")
|
||||
@@ -67,7 +61,7 @@ class SendWebRequestBlock(Block):
|
||||
output_schema=SendWebRequestBlock.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, *, graph_exec_id: str, **kwargs) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
body = input_data.body
|
||||
|
||||
if input_data.json_format:
|
||||
@@ -80,31 +74,11 @@ class SendWebRequestBlock(Block):
|
||||
# we should send it as plain text instead
|
||||
input_data.json_format = False
|
||||
|
||||
# Prepare files for multipart upload using store_media_file
|
||||
files: dict[str, BufferedReader] = {}
|
||||
if input_data.files:
|
||||
for field_name, media in input_data.files.items():
|
||||
try:
|
||||
rel_path = store_media_file(
|
||||
graph_exec_id, media, return_content=False
|
||||
)
|
||||
abs_path = get_exec_file_path(graph_exec_id, rel_path)
|
||||
files[field_name] = open(abs_path, "rb")
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to prepare file '{field_name}': {e}"
|
||||
for f in files.values():
|
||||
try:
|
||||
f.close()
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
|
||||
try:
|
||||
response = requests.request(
|
||||
input_data.method.value,
|
||||
input_data.url,
|
||||
headers=input_data.headers,
|
||||
files=files if files else None,
|
||||
json=body if input_data.json_format else None,
|
||||
data=body if not input_data.json_format else None,
|
||||
)
|
||||
@@ -145,11 +119,3 @@ class SendWebRequestBlock(Block):
|
||||
except Exception as e:
|
||||
# Catch any other unexpected exceptions
|
||||
yield "error", str(e)
|
||||
|
||||
finally:
|
||||
# ensure cleanup of file handles
|
||||
for f in files.values():
|
||||
try:
|
||||
f.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -101,6 +101,8 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
GPT4_TURBO = "gpt-4-turbo"
|
||||
GPT3_5_TURBO = "gpt-3.5-turbo"
|
||||
# Anthropic models
|
||||
CLAUDE_4_OPUS = "claude-opus-4-20250514"
|
||||
CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
|
||||
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
|
||||
CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest"
|
||||
CLAUDE_3_5_HAIKU = "claude-3-5-haiku-latest"
|
||||
@@ -184,6 +186,12 @@ MODEL_METADATA = {
|
||||
), # gpt-4-turbo-2024-04-09
|
||||
LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, 4096), # gpt-3.5-turbo-0125
|
||||
# https://docs.anthropic.com/en/docs/about-claude/models
|
||||
LlmModel.CLAUDE_4_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 8192
|
||||
), # claude-4-opus-20250514
|
||||
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 8192
|
||||
), # claude-4-sonnet-20250514
|
||||
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 8192
|
||||
), # claude-3-7-sonnet-20250219
|
||||
|
||||
@@ -124,8 +124,10 @@ class AddMemoryBlock(Block, Mem0Base):
|
||||
|
||||
if isinstance(input_data.content, Conversation):
|
||||
messages = input_data.content.messages
|
||||
elif isinstance(input_data.content, Content):
|
||||
messages = [{"role": "user", "content": input_data.content.content}]
|
||||
else:
|
||||
messages = [{"role": "user", "content": input_data.content}]
|
||||
messages = [{"role": "user", "content": str(input_data.content)}]
|
||||
|
||||
params = {
|
||||
"user_id": user_id,
|
||||
@@ -152,7 +154,7 @@ class AddMemoryBlock(Block, Mem0Base):
|
||||
yield "action", "NO_CHANGE"
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(object=e)
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class SearchMemoryBlock(Block, Mem0Base):
|
||||
|
||||
@@ -53,7 +53,6 @@ class BlockType(Enum):
|
||||
WEBHOOK_MANUAL = "Webhook (manual)"
|
||||
AGENT = "Agent"
|
||||
AI = "AI"
|
||||
AYRSHARE = "Ayrshare"
|
||||
|
||||
|
||||
class BlockCategory(Enum):
|
||||
@@ -77,7 +76,6 @@ class BlockCategory(Enum):
|
||||
PRODUCTIVITY = "Block that helps with productivity"
|
||||
ISSUE_TRACKING = "Block that helps with issue tracking"
|
||||
MULTIMEDIA = "Block that interacts with multimedia content"
|
||||
MARKETING = "Block that helps with marketing"
|
||||
|
||||
def dict(self) -> dict[str, str]:
|
||||
return {"category": self.name, "description": self.value}
|
||||
|
||||
@@ -47,6 +47,8 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.GPT4O: 3,
|
||||
LlmModel.GPT4_TURBO: 10,
|
||||
LlmModel.GPT3_5_TURBO: 1,
|
||||
LlmModel.CLAUDE_4_OPUS: 21,
|
||||
LlmModel.CLAUDE_4_SONNET: 5,
|
||||
LlmModel.CLAUDE_3_7_SONNET: 5,
|
||||
LlmModel.CLAUDE_3_5_SONNET: 4,
|
||||
LlmModel.CLAUDE_3_5_HAIKU: 1, # $0.80 / $4.00
|
||||
|
||||
@@ -15,6 +15,7 @@ from typing import (
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
get_args,
|
||||
)
|
||||
@@ -36,7 +37,6 @@ from pydantic_core import (
|
||||
ValidationError,
|
||||
core_schema,
|
||||
)
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.settings import Secrets
|
||||
@@ -260,32 +260,15 @@ class OAuthState(BaseModel):
|
||||
|
||||
class UserMetadata(BaseModel):
|
||||
integration_credentials: list[Credentials] = Field(default_factory=list)
|
||||
"""⚠️ Deprecated; use `UserIntegrations.credentials` instead"""
|
||||
integration_oauth_states: list[OAuthState] = Field(default_factory=list)
|
||||
"""⚠️ Deprecated; use `UserIntegrations.oauth_states` instead"""
|
||||
|
||||
|
||||
class UserMetadataRaw(TypedDict, total=False):
|
||||
integration_credentials: list[dict]
|
||||
"""⚠️ Deprecated; use `UserIntegrations.credentials` instead"""
|
||||
integration_oauth_states: list[dict]
|
||||
"""⚠️ Deprecated; use `UserIntegrations.oauth_states` instead"""
|
||||
|
||||
|
||||
class UserIntegrations(BaseModel):
|
||||
|
||||
class ManagedCredentials(BaseModel):
|
||||
"""Integration credentials managed by us, rather than by the user"""
|
||||
|
||||
ayrshare_profile_key: Optional[SecretStr] = None
|
||||
|
||||
@field_serializer("*")
|
||||
def dump_secret_strings(value: Any, _info):
|
||||
if isinstance(value, SecretStr):
|
||||
return value.get_secret_value()
|
||||
return value
|
||||
|
||||
managed_credentials: ManagedCredentials = Field(default_factory=ManagedCredentials)
|
||||
credentials: list[Credentials] = Field(default_factory=list)
|
||||
oauth_states: list[OAuthState] = Field(default_factory=list)
|
||||
|
||||
|
||||
@@ -38,7 +38,6 @@ from autogpt_libs.utils.cache import thread_cached
|
||||
from prometheus_client import Gauge, start_http_server
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.ayrshare.post import AYRSHARE_NODE_IDS
|
||||
from backend.data import redis
|
||||
from backend.data.block import BlockData, BlockInput, BlockSchema, get_block
|
||||
from backend.data.credit import UsageTransactionMetadata
|
||||
@@ -68,7 +67,7 @@ from backend.util.decorator import error_logged, time_measured
|
||||
from backend.util.file import clean_exec_files
|
||||
from backend.util.logging import TruncatedLogger, configure_logging
|
||||
from backend.util.process import AppProcess, set_service_name
|
||||
from backend.util.retry import func_retry
|
||||
from backend.util.retry import continuous_retry, func_retry
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Settings
|
||||
|
||||
@@ -218,10 +217,6 @@ def execute_node(
|
||||
credentials, creds_lock = creds_manager.acquire(user_id, credentials_meta.id)
|
||||
extra_exec_kwargs[field_name] = credentials
|
||||
|
||||
if node_block.id in AYRSHARE_NODE_IDS:
|
||||
profile_key = creds_manager.store.get_ayrshare_profile_key(user_id)
|
||||
extra_exec_kwargs["profile_key"] = profile_key
|
||||
|
||||
output_size = 0
|
||||
try:
|
||||
outputs: dict[str, Any] = {}
|
||||
@@ -943,8 +938,6 @@ class ExecutionManager(AppProcess):
|
||||
self.pool_size = settings.config.num_graph_workers
|
||||
self.running = True
|
||||
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
|
||||
signal.signal(signal.SIGTERM, lambda sig, frame: self._on_sigterm())
|
||||
signal.signal(signal.SIGINT, lambda sig, frame: self._on_sigterm())
|
||||
|
||||
def run(self):
|
||||
pool_size_gauge.set(self.pool_size)
|
||||
@@ -970,22 +963,29 @@ class ExecutionManager(AppProcess):
|
||||
logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...")
|
||||
redis.connect()
|
||||
|
||||
threading.Thread(
|
||||
target=lambda: self._consume_execution_cancel(),
|
||||
daemon=True,
|
||||
).start()
|
||||
|
||||
self._consume_execution_run()
|
||||
|
||||
@continuous_retry()
|
||||
def _consume_execution_cancel(self):
|
||||
cancel_client = SyncRabbitMQ(create_execution_queue_config())
|
||||
cancel_client.connect()
|
||||
cancel_channel = cancel_client.get_channel()
|
||||
logger.info(f"[{self.service_name}] ⏳ Starting cancel message consumer...")
|
||||
threading.Thread(
|
||||
target=lambda: (
|
||||
cancel_channel.basic_consume(
|
||||
queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
on_message_callback=self._handle_cancel_message,
|
||||
auto_ack=True,
|
||||
),
|
||||
cancel_channel.start_consuming(),
|
||||
),
|
||||
daemon=True,
|
||||
).start()
|
||||
cancel_channel.basic_consume(
|
||||
queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
on_message_callback=self._handle_cancel_message,
|
||||
auto_ack=True,
|
||||
)
|
||||
cancel_channel.start_consuming()
|
||||
raise RuntimeError(f"❌ cancel message consumer is stopped: {cancel_channel}")
|
||||
|
||||
@continuous_retry()
|
||||
def _consume_execution_run(self):
|
||||
run_client = SyncRabbitMQ(create_execution_queue_config())
|
||||
run_client.connect()
|
||||
run_channel = run_client.get_channel()
|
||||
@@ -997,6 +997,7 @@ class ExecutionManager(AppProcess):
|
||||
)
|
||||
logger.info(f"[{self.service_name}] ⏳ Starting to consume run messages...")
|
||||
run_channel.start_consuming()
|
||||
raise RuntimeError(f"❌ run message consumer is stopped: {run_channel}")
|
||||
|
||||
def _handle_cancel_message(
|
||||
self,
|
||||
@@ -1095,10 +1096,6 @@ class ExecutionManager(AppProcess):
|
||||
super().cleanup()
|
||||
self._on_cleanup()
|
||||
|
||||
def _on_sigterm(self):
|
||||
llprint(f"[{self.service_name}] ⚠️ GraphExec SIGTERM received")
|
||||
self._on_cleanup(log=llprint)
|
||||
|
||||
def _on_cleanup(self, log=logger.info):
|
||||
prefix = f"[{self.service_name}][on_graph_executor_stop {os.getpid()}]"
|
||||
log(f"{prefix} ⏳ Shutting down service loop...")
|
||||
@@ -1115,7 +1112,7 @@ class ExecutionManager(AppProcess):
|
||||
redis.disconnect()
|
||||
|
||||
log(f"{prefix} ✅ Finished GraphExec cleanup")
|
||||
exit(0)
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
# ------- UTILITIES ------- #
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
@@ -178,7 +177,6 @@ zerobounce_credentials = APIKeyCredentials(
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
|
||||
llama_api_credentials = APIKeyCredentials(
|
||||
id="d44045af-1c33-4833-9e19-752313214de2",
|
||||
provider="llama_api",
|
||||
@@ -226,8 +224,6 @@ class IntegrationCredentialsStore:
|
||||
|
||||
return get_service_client(DatabaseManagerClient)
|
||||
|
||||
# =============== USER-MANAGED CREDENTIALS =============== #
|
||||
|
||||
def add_creds(self, user_id: str, credentials: Credentials) -> None:
|
||||
with self.locked_user_integrations(user_id):
|
||||
if self.get_creds_by_id(user_id, credentials.id):
|
||||
@@ -286,8 +282,6 @@ class IntegrationCredentialsStore:
|
||||
all_credentials.append(zerobounce_credentials)
|
||||
if settings.secrets.google_maps_api_key:
|
||||
all_credentials.append(google_maps_credentials)
|
||||
if settings.secrets.llama_api_key:
|
||||
all_credentials.append(llama_api_credentials)
|
||||
return all_credentials
|
||||
|
||||
def get_creds_by_id(self, user_id: str, credentials_id: str) -> Credentials | None:
|
||||
@@ -343,19 +337,6 @@ class IntegrationCredentialsStore:
|
||||
]
|
||||
self._set_user_integration_creds(user_id, filtered_credentials)
|
||||
|
||||
# ============== SYSTEM-MANAGED CREDENTIALS ============== #
|
||||
|
||||
def get_ayrshare_profile_key(self, user_id: str) -> SecretStr | None:
|
||||
managed_user_creds = self._get_user_integrations(user_id).managed_credentials
|
||||
return managed_user_creds.ayrshare_profile_key
|
||||
|
||||
def set_ayrshare_profile_key(self, user_id: str, profile_key: str) -> None:
|
||||
_profile_key = SecretStr(profile_key)
|
||||
with self.edit_user_integrations(user_id) as user_integrations:
|
||||
user_integrations.managed_credentials.ayrshare_profile_key = _profile_key
|
||||
|
||||
# ===================== OAUTH STATES ===================== #
|
||||
|
||||
def store_state_token(
|
||||
self, user_id: str, provider: str, scopes: list[str], use_pkce: bool = False
|
||||
) -> tuple[str, str]:
|
||||
@@ -372,8 +353,16 @@ class IntegrationCredentialsStore:
|
||||
scopes=scopes,
|
||||
)
|
||||
|
||||
with self.edit_user_integrations(user_id) as user_integrations:
|
||||
user_integrations.oauth_states.append(state)
|
||||
with self.locked_user_integrations(user_id):
|
||||
|
||||
user_integrations = self._get_user_integrations(user_id)
|
||||
oauth_states = user_integrations.oauth_states
|
||||
oauth_states.append(state)
|
||||
user_integrations.oauth_states = oauth_states
|
||||
|
||||
self.db_manager.update_user_integrations(
|
||||
user_id=user_id, data=user_integrations
|
||||
)
|
||||
|
||||
return token, code_challenge
|
||||
|
||||
@@ -415,17 +404,6 @@ class IntegrationCredentialsStore:
|
||||
|
||||
return None
|
||||
|
||||
# =================== GET/SET HELPERS =================== #
|
||||
|
||||
@contextmanager
|
||||
def edit_user_integrations(self, user_id: str):
|
||||
with self.locked_user_integrations(user_id):
|
||||
user_integrations = self._get_user_integrations(user_id)
|
||||
yield user_integrations # yield to allow edits
|
||||
self.db_manager.update_user_integrations(
|
||||
user_id=user_id, data=user_integrations
|
||||
)
|
||||
|
||||
def _set_user_integration_creds(
|
||||
self, user_id: str, credentials: list[Credentials]
|
||||
) -> None:
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Annotated, Awaitable, Literal
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import BaseModel, Field
|
||||
from starlette.status import HTTP_404_NOT_FOUND
|
||||
|
||||
from backend.blocks.ayrshare._api import AyrshareClient, PostError, SocialPlatform
|
||||
from backend.data.graph import set_node_webhook
|
||||
from backend.data.integrations import (
|
||||
WebhookEvent,
|
||||
@@ -418,72 +416,3 @@ def _get_provider_oauth_handler(
|
||||
client_secret=client_secret,
|
||||
redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/ayrshare/sso_url")
|
||||
async def get_ayrshare_sso_url(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Generate an SSO URL for Ayrshare social media integration.
|
||||
|
||||
Returns:
|
||||
dict: Contains the SSO URL for Ayrshare integration
|
||||
"""
|
||||
# Generate JWT and get SSO URL
|
||||
client = AyrshareClient()
|
||||
|
||||
# Get or create profile key
|
||||
profile_key = creds_manager.store.get_ayrshare_profile_key(user_id)
|
||||
if not profile_key:
|
||||
logger.info(f"Creating new Ayrshare profile for user {user_id}")
|
||||
# Create new profile if none exists
|
||||
profile = client.create_profile(title=f"User {user_id}", messaging_active=True)
|
||||
if isinstance(profile, PostError):
|
||||
logger.error(
|
||||
f"Error creating Ayrshare profile for user {user_id}: {profile}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to create Ayrshare profile"
|
||||
)
|
||||
profile_key = profile.profileKey
|
||||
creds_manager.store.set_ayrshare_profile_key(user_id, profile_key)
|
||||
else:
|
||||
logger.info(f"Using existing Ayrshare profile for user {user_id}")
|
||||
|
||||
# Convert SecretStr to string if needed
|
||||
profile_key_str = (
|
||||
profile_key.get_secret_value()
|
||||
if isinstance(profile_key, SecretStr)
|
||||
else str(profile_key)
|
||||
)
|
||||
|
||||
private_key = settings.secrets.ayrshare_jwt_key
|
||||
|
||||
try:
|
||||
logger.info(f"Generating JWT for user {user_id}")
|
||||
jwt_response = client.generate_jwt(
|
||||
private_key=private_key,
|
||||
profile_key=profile_key_str,
|
||||
allowed_social=[
|
||||
SocialPlatform.FACEBOOK,
|
||||
SocialPlatform.TWITTER,
|
||||
SocialPlatform.LINKEDIN,
|
||||
SocialPlatform.INSTAGRAM,
|
||||
SocialPlatform.YOUTUBE,
|
||||
SocialPlatform.REDDIT,
|
||||
SocialPlatform.TELEGRAM,
|
||||
SocialPlatform.GMB,
|
||||
SocialPlatform.PINTEREST,
|
||||
SocialPlatform.TIKTOK,
|
||||
SocialPlatform.BLUESKY,
|
||||
],
|
||||
expires_in=2880,
|
||||
verify=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating JWT for user {user_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to generate JWT")
|
||||
|
||||
expire_at = datetime.now(timezone.utc) + timedelta(minutes=2880)
|
||||
return {"sso_url": jwt_response.url, "expire_at": expire_at.isoformat()}
|
||||
|
||||
@@ -84,7 +84,7 @@ if TYPE_CHECKING:
|
||||
|
||||
@thread_cached
|
||||
def execution_scheduler_client() -> scheduler.SchedulerClient:
|
||||
return get_service_client(scheduler.SchedulerClient)
|
||||
return get_service_client(scheduler.SchedulerClient, health_check=False)
|
||||
|
||||
|
||||
@thread_cached
|
||||
|
||||
@@ -82,7 +82,6 @@ async def test_get_library_agents(mocker):
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.skip(reason="Test not implemented yet")
|
||||
async def test_add_agent_to_library(mocker):
|
||||
await connect()
|
||||
# Mock data
|
||||
|
||||
@@ -2,8 +2,9 @@ import ipaddress
|
||||
import re
|
||||
import socket
|
||||
import ssl
|
||||
from typing import Callable
|
||||
from urllib.parse import quote, urljoin, urlparse, urlunparse
|
||||
from typing import Callable, Optional
|
||||
from urllib.parse import ParseResult as URL
|
||||
from urllib.parse import quote, urljoin, urlparse
|
||||
|
||||
import idna
|
||||
import requests as req
|
||||
@@ -44,17 +45,15 @@ def _is_ip_blocked(ip: str) -> bool:
|
||||
return any(ip_addr in network for network in BLOCKED_IP_NETWORKS)
|
||||
|
||||
|
||||
def _remove_insecure_headers(headers: dict, old_url: str, new_url: str) -> dict:
|
||||
def _remove_insecure_headers(headers: dict, old_url: URL, new_url: URL) -> dict:
|
||||
"""
|
||||
Removes sensitive headers (Authorization, Proxy-Authorization, Cookie)
|
||||
if the scheme/host/port of new_url differ from old_url.
|
||||
"""
|
||||
old_parsed = urlparse(old_url)
|
||||
new_parsed = urlparse(new_url)
|
||||
if (
|
||||
(old_parsed.scheme != new_parsed.scheme)
|
||||
or (old_parsed.hostname != new_parsed.hostname)
|
||||
or (old_parsed.port != new_parsed.port)
|
||||
(old_url.scheme != new_url.scheme)
|
||||
or (old_url.hostname != new_url.hostname)
|
||||
or (old_url.port != new_url.port)
|
||||
):
|
||||
headers.pop("Authorization", None)
|
||||
headers.pop("Proxy-Authorization", None)
|
||||
@@ -81,19 +80,16 @@ class HostSSLAdapter(HTTPAdapter):
|
||||
)
|
||||
|
||||
|
||||
def validate_url(
|
||||
url: str,
|
||||
trusted_origins: list[str],
|
||||
enable_dns_rebinding: bool = True,
|
||||
) -> tuple[str, str]:
|
||||
def validate_url(url: str, trusted_origins: list[str]) -> tuple[URL, bool, list[str]]:
|
||||
"""
|
||||
Validates the URL to prevent SSRF attacks by ensuring it does not point
|
||||
to a private, link-local, or otherwise blocked IP address — unless
|
||||
the hostname is explicitly trusted.
|
||||
|
||||
Returns a tuple of:
|
||||
- pinned_url: a URL that has the netloc replaced with the validated IP
|
||||
- ascii_hostname: the original ASCII hostname (IDNA-decoded) for use in the Host header
|
||||
Returns:
|
||||
str: The validated, canonicalized, parsed URL
|
||||
is_trusted: Boolean indicating if the hostname is in trusted_origins
|
||||
ip_addresses: List of IP addresses for the host; empty if the host is trusted
|
||||
"""
|
||||
# Canonicalize URL
|
||||
url = url.strip("/ ").replace("\\", "/")
|
||||
@@ -122,45 +118,56 @@ def validate_url(
|
||||
if not HOSTNAME_REGEX.match(ascii_hostname):
|
||||
raise ValueError("Hostname contains invalid characters.")
|
||||
|
||||
# If hostname is trusted, skip IP-based checks but still return pinned URL
|
||||
if ascii_hostname in trusted_origins:
|
||||
pinned_netloc = ascii_hostname
|
||||
if parsed.port:
|
||||
pinned_netloc += f":{parsed.port}"
|
||||
# Check if hostname is trusted
|
||||
is_trusted = ascii_hostname in trusted_origins
|
||||
|
||||
pinned_url = urlunparse(
|
||||
(
|
||||
parsed.scheme,
|
||||
pinned_netloc,
|
||||
quote(parsed.path, safe="/%:@"),
|
||||
parsed.params,
|
||||
parsed.query,
|
||||
parsed.fragment,
|
||||
)
|
||||
)
|
||||
return pinned_url, ascii_hostname
|
||||
# If not trusted, validate IP addresses
|
||||
ip_addresses: list[str] = []
|
||||
if not is_trusted:
|
||||
# Resolve all IP addresses for the hostname
|
||||
ip_addresses = _resolve_host(ascii_hostname)
|
||||
|
||||
# Resolve all IP addresses for the hostname
|
||||
try:
|
||||
ip_list = [str(res[4][0]) for res in socket.getaddrinfo(ascii_hostname, None)]
|
||||
ipv4 = [ip for ip in ip_list if ":" not in ip]
|
||||
ipv6 = [ip for ip in ip_list if ":" in ip]
|
||||
ip_addresses = ipv4 + ipv6 # Prefer IPv4 over IPv6
|
||||
except socket.gaierror:
|
||||
raise ValueError(f"Unable to resolve IP address for hostname {ascii_hostname}")
|
||||
# Block any IP address that belongs to a blocked range
|
||||
for ip_str in ip_addresses:
|
||||
if _is_ip_blocked(ip_str):
|
||||
raise ValueError(
|
||||
f"Access to blocked or private IP address {ip_str} "
|
||||
f"for hostname {ascii_hostname} is not allowed."
|
||||
)
|
||||
|
||||
return (
|
||||
URL(
|
||||
parsed.scheme,
|
||||
ascii_hostname,
|
||||
quote(parsed.path, safe="/%:@"),
|
||||
parsed.params,
|
||||
parsed.query,
|
||||
parsed.fragment,
|
||||
),
|
||||
is_trusted,
|
||||
ip_addresses,
|
||||
)
|
||||
|
||||
|
||||
def pin_url(url: URL, ip_addresses: Optional[list[str]] = None) -> URL:
|
||||
"""
|
||||
Pins a URL to a specific IP address to prevent DNS rebinding attacks.
|
||||
|
||||
Args:
|
||||
url: The original URL
|
||||
ip_addresses: List of IP addresses corresponding to the URL's host
|
||||
|
||||
Returns:
|
||||
pinned_url: The URL with hostname replaced with IP address
|
||||
"""
|
||||
if not url.hostname:
|
||||
raise ValueError(f"URL has no hostname: {url}")
|
||||
|
||||
if not ip_addresses:
|
||||
raise ValueError(f"No IP addresses found for {ascii_hostname}")
|
||||
# Resolve all IP addresses for the hostname
|
||||
ip_addresses = _resolve_host(url.hostname)
|
||||
|
||||
# Block any IP address that belongs to a blocked range
|
||||
for ip_str in ip_addresses:
|
||||
if _is_ip_blocked(ip_str):
|
||||
raise ValueError(
|
||||
f"Access to blocked or private IP address {ip_str} "
|
||||
f"for hostname {ascii_hostname} is not allowed."
|
||||
)
|
||||
|
||||
# Pin to the first valid IP (for SSRF defense).
|
||||
# Pin to the first valid IP (for SSRF defense)
|
||||
pinned_ip = ip_addresses[0]
|
||||
|
||||
# If it's IPv6, bracket it
|
||||
@@ -169,24 +176,31 @@ def validate_url(
|
||||
else:
|
||||
pinned_netloc = pinned_ip
|
||||
|
||||
if parsed.port:
|
||||
pinned_netloc += f":{parsed.port}"
|
||||
if url.port:
|
||||
pinned_netloc += f":{url.port}"
|
||||
|
||||
if not enable_dns_rebinding:
|
||||
pinned_netloc = ascii_hostname
|
||||
|
||||
pinned_url = urlunparse(
|
||||
(
|
||||
parsed.scheme,
|
||||
pinned_netloc,
|
||||
quote(parsed.path, safe="/%:@"),
|
||||
parsed.params,
|
||||
parsed.query,
|
||||
parsed.fragment,
|
||||
)
|
||||
return URL(
|
||||
url.scheme,
|
||||
pinned_netloc,
|
||||
url.path,
|
||||
url.params,
|
||||
url.query,
|
||||
url.fragment,
|
||||
)
|
||||
|
||||
return pinned_url, ascii_hostname # (pinned_url, original_hostname)
|
||||
|
||||
def _resolve_host(hostname: str) -> list[str]:
|
||||
try:
|
||||
ip_list = [str(res[4][0]) for res in socket.getaddrinfo(hostname, None)]
|
||||
ipv4 = [ip for ip in ip_list if ":" not in ip]
|
||||
ipv6 = [ip for ip in ip_list if ":" in ip]
|
||||
ip_addresses = ipv4 + ipv6 # Prefer IPv4 over IPv6
|
||||
except socket.gaierror:
|
||||
raise ValueError(f"Unable to resolve IP address for hostname {hostname}")
|
||||
|
||||
if not ip_addresses:
|
||||
raise ValueError(f"No IP addresses found for {hostname}")
|
||||
return ip_addresses
|
||||
|
||||
|
||||
class Requests:
|
||||
@@ -200,7 +214,7 @@ class Requests:
|
||||
self,
|
||||
trusted_origins: list[str] | None = None,
|
||||
raise_for_status: bool = True,
|
||||
extra_url_validator: Callable[[str], str] | None = None,
|
||||
extra_url_validator: Callable[[URL], URL] | None = None,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
):
|
||||
self.trusted_origins = []
|
||||
@@ -224,12 +238,18 @@ class Requests:
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> req.Response:
|
||||
# Validate URL and get trust status
|
||||
url, is_trusted, ip_addresses = validate_url(url, self.trusted_origins)
|
||||
|
||||
# Apply any extra user-defined validation/transformation
|
||||
if self.extra_url_validator is not None:
|
||||
url = self.extra_url_validator(url)
|
||||
|
||||
# Validate URL and get pinned URL + hostname
|
||||
pinned_url, hostname = validate_url(url, self.trusted_origins)
|
||||
# Pin the URL if untrusted
|
||||
hostname = url.hostname
|
||||
original_url = url.geturl()
|
||||
if not is_trusted:
|
||||
url = pin_url(url, ip_addresses)
|
||||
|
||||
# Merge any extra headers
|
||||
headers = dict(headers) if headers else {}
|
||||
@@ -240,27 +260,30 @@ class Requests:
|
||||
|
||||
# If untrusted, the hostname in the URL is replaced with the corresponding
|
||||
# IP address, and we need to override the Host header with the actual hostname.
|
||||
if (pinned := urlparse(pinned_url)).hostname != hostname:
|
||||
if url.hostname != hostname:
|
||||
headers["Host"] = hostname
|
||||
|
||||
# If hostname was untrusted and we replaced it by (pinned it to) its IP,
|
||||
# we also need to attach a custom SNI adapter to make SSL work:
|
||||
mount_prefix = f"{pinned.scheme}://{pinned.hostname}"
|
||||
if pinned.port:
|
||||
mount_prefix += f":{pinned.port}"
|
||||
adapter = HostSSLAdapter(ssl_hostname=hostname)
|
||||
session.mount("https://", adapter)
|
||||
|
||||
# Perform the request with redirects disabled for manual handling
|
||||
response = session.request(
|
||||
method,
|
||||
pinned_url,
|
||||
url.geturl(),
|
||||
headers=headers,
|
||||
allow_redirects=False,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Replace response URLs with the original host for clearer error messages
|
||||
if url.hostname != hostname:
|
||||
response.url = original_url
|
||||
if response.request is not None:
|
||||
response.request.url = original_url
|
||||
|
||||
if self.raise_for_status:
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -275,13 +298,13 @@ class Requests:
|
||||
|
||||
# The base URL is the pinned_url we just used
|
||||
# so that relative redirects resolve correctly.
|
||||
new_url = urljoin(pinned_url, location)
|
||||
redirect_url = urlparse(urljoin(url.geturl(), location))
|
||||
# Carry forward the same headers but update Host
|
||||
new_headers = _remove_insecure_headers(dict(headers), url, new_url)
|
||||
new_headers = _remove_insecure_headers(headers, url, redirect_url)
|
||||
|
||||
return self.request(
|
||||
method,
|
||||
new_url,
|
||||
redirect_url.geturl(),
|
||||
headers=new_headers,
|
||||
allow_redirects=allow_redirects,
|
||||
max_redirects=max_redirects - 1,
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from functools import wraps
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -80,3 +81,24 @@ func_retry = retry(
|
||||
stop=stop_after_attempt(5),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=30),
|
||||
)
|
||||
|
||||
|
||||
def continuous_retry(*, retry_delay: float = 1.0):
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
while True:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"%s failed with %s — retrying in %.2f s",
|
||||
func.__name__,
|
||||
exc,
|
||||
retry_delay,
|
||||
)
|
||||
time.sleep(retry_delay)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -438,8 +438,7 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
||||
apollo_api_key: str = Field(default="", description="Apollo API Key")
|
||||
smartlead_api_key: str = Field(default="", description="SmartLead API Key")
|
||||
zerobounce_api_key: str = Field(default="", description="ZeroBounce API Key")
|
||||
ayrshare_api_key: str = Field(default="", description="Ayrshare API Key")
|
||||
ayrshare_jwt_key: str = Field(default="", description="Ayrshare private Key")
|
||||
|
||||
# Add more secret fields as needed
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import pytest
|
||||
|
||||
from backend.util.request import validate_url
|
||||
from backend.util.request import pin_url, validate_url
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"url, trusted_origins, expected_value, should_raise",
|
||||
"raw_url, trusted_origins, expected_value, should_raise",
|
||||
[
|
||||
# Rejected IP ranges
|
||||
("localhost", [], None, True),
|
||||
@@ -55,14 +55,14 @@ from backend.util.request import validate_url
|
||||
],
|
||||
)
|
||||
def test_validate_url_no_dns_rebinding(
|
||||
url, trusted_origins, expected_value, should_raise
|
||||
raw_url: str, trusted_origins: list[str], expected_value: str, should_raise: bool
|
||||
):
|
||||
if should_raise:
|
||||
with pytest.raises(ValueError):
|
||||
validate_url(url, trusted_origins, enable_dns_rebinding=False)
|
||||
validate_url(raw_url, trusted_origins)
|
||||
else:
|
||||
url, host = validate_url(url, trusted_origins, enable_dns_rebinding=False)
|
||||
assert url == expected_value
|
||||
validated_url, _, _ = validate_url(raw_url, trusted_origins)
|
||||
assert validated_url.geturl() == expected_value
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -79,7 +79,11 @@ def test_validate_url_no_dns_rebinding(
|
||||
],
|
||||
)
|
||||
def test_dns_rebinding_fix(
|
||||
monkeypatch, hostname, resolved_ips, expect_error, expected_ip
|
||||
monkeypatch,
|
||||
hostname: str,
|
||||
resolved_ips: list[str],
|
||||
expect_error: bool,
|
||||
expected_ip: str,
|
||||
):
|
||||
"""
|
||||
Tests that validate_url pins the first valid public IP address, and rejects
|
||||
@@ -96,11 +100,13 @@ def test_dns_rebinding_fix(
|
||||
if expect_error:
|
||||
# If any IP is blocked, we expect a ValueError
|
||||
with pytest.raises(ValueError):
|
||||
validate_url(hostname, [])
|
||||
url, _, ip_addresses = validate_url(hostname, [])
|
||||
pin_url(url, ip_addresses)
|
||||
else:
|
||||
pinned_url, ascii_hostname = validate_url(hostname, [])
|
||||
url, _, ip_addresses = validate_url(hostname, [])
|
||||
pinned_url = pin_url(url, ip_addresses).geturl()
|
||||
# The pinned_url should contain the first valid IP
|
||||
assert pinned_url.startswith("http://") or pinned_url.startswith("https://")
|
||||
assert expected_ip in pinned_url
|
||||
# The ascii_hostname should match our original hostname after IDNA encoding
|
||||
assert ascii_hostname == hostname
|
||||
# The unpinned URL's hostname should match our original IDNA encoded hostname
|
||||
assert url.hostname == hostname
|
||||
|
||||
@@ -39,7 +39,7 @@ import {
|
||||
} from "@/components/ui/dialog";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useToast } from "@/components/ui/use-toast";
|
||||
import { LoadingSpinner } from "@/components/ui/loading";
|
||||
import LoadingBox, { LoadingSpinner } from "@/components/ui/loading";
|
||||
|
||||
export default function AgentRunsPage(): React.ReactElement {
|
||||
const { id: agentID }: { id: LibraryAgentID } = useParams();
|
||||
@@ -357,8 +357,7 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
);
|
||||
|
||||
if (!agent || !graph) {
|
||||
/* TODO: implement loading indicators / skeleton page */
|
||||
return <span>Loading...</span>;
|
||||
return <LoadingBox className="h-[90vh]" />;
|
||||
}
|
||||
|
||||
return (
|
||||
@@ -416,7 +415,7 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
agentActions={agentActions}
|
||||
/>
|
||||
)
|
||||
) : null) || <p>Loading...</p>}
|
||||
) : null) || <LoadingBox className="h-[70vh]" />}
|
||||
|
||||
<DeleteConfirmDialog
|
||||
entityType="agent"
|
||||
|
||||
@@ -53,7 +53,7 @@ import {
|
||||
CopyIcon,
|
||||
ExitIcon,
|
||||
} from "@radix-ui/react-icons";
|
||||
import { FaKey } from "react-icons/fa";
|
||||
|
||||
import useCredits from "@/hooks/useCredits";
|
||||
|
||||
export type ConnectionData = Array<{
|
||||
@@ -116,8 +116,6 @@ export const CustomNode = React.memo(
|
||||
const flowContext = useContext(FlowContext);
|
||||
const api = useBackendAPI();
|
||||
const { formatCredits } = useCredits();
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
|
||||
let nodeFlowId = "";
|
||||
|
||||
if (data.uiType === BlockUIType.AGENT) {
|
||||
@@ -251,55 +249,6 @@ export const CustomNode = React.memo(
|
||||
return renderHandles(schema.properties);
|
||||
};
|
||||
|
||||
const generateAyrshareSSOHandles = (
|
||||
api: ReturnType<typeof useBackendAPI>,
|
||||
) => {
|
||||
const handleSSOLogin = async () => {
|
||||
setIsLoading(true);
|
||||
try {
|
||||
const { sso_url } = await api.getAyrshareSSOUrl();
|
||||
const popup = window.open(sso_url, "_blank", "popup=true");
|
||||
if (!popup) {
|
||||
throw new Error(
|
||||
"Failed to open popup window. Please allow popups for this site.",
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error getting SSO URL:", error);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-2">
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
className="w-full"
|
||||
onClick={handleSSOLogin}
|
||||
disabled={isLoading}
|
||||
>
|
||||
{isLoading ? (
|
||||
"Loading..."
|
||||
) : (
|
||||
<>
|
||||
<FaKey className="mr-2 h-4 w-4" />
|
||||
Connect Social Media Accounts
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
<NodeHandle
|
||||
title="SSO Token"
|
||||
keyName="sso_token"
|
||||
isConnected={false}
|
||||
schema={{ type: "string" }}
|
||||
side="right"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const generateInputHandles = (
|
||||
schema: BlockIORootSchema,
|
||||
nodeType: BlockUIType,
|
||||
@@ -877,18 +826,8 @@ export const CustomNode = React.memo(
|
||||
(A Webhook URL will be generated when you save the agent)
|
||||
</p>
|
||||
))}
|
||||
{data.uiType === BlockUIType.AYRSHARE ? (
|
||||
<>
|
||||
{generateAyrshareSSOHandles(api)}
|
||||
{generateInputHandles(
|
||||
data.inputSchema,
|
||||
BlockUIType.STANDARD,
|
||||
)}
|
||||
</>
|
||||
) : (
|
||||
data.inputSchema &&
|
||||
generateInputHandles(data.inputSchema, data.uiType)
|
||||
)}
|
||||
{data.inputSchema &&
|
||||
generateInputHandles(data.inputSchema, data.uiType)}
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
|
||||
@@ -14,7 +14,7 @@ import { Label } from "@/components/ui/label";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { addDollars } from "@/app/admin/spending/actions";
|
||||
import { addDollars } from "@/app/(platform)/admin/spending/actions";
|
||||
import useCredits from "@/hooks/useCredits";
|
||||
|
||||
export function AdminAddMoneyButton({
|
||||
@@ -99,7 +99,6 @@ export function AdminAddMoneyButton({
|
||||
id="dollarAmount"
|
||||
type="number"
|
||||
step="0.01"
|
||||
min="0"
|
||||
className="rounded-l-none"
|
||||
value={dollarAmount}
|
||||
onChange={(e) => setDollarAmount(e.target.value)}
|
||||
|
||||
@@ -9,7 +9,7 @@ import {
|
||||
|
||||
import { PaginationControls } from "../../ui/pagination-controls";
|
||||
import { SearchAndFilterAdminSpending } from "./search-filter-form";
|
||||
import { getUsersTransactionHistory } from "@/app/admin/spending/actions";
|
||||
import { getUsersTransactionHistory } from "@/app/(platform)/admin/spending/actions";
|
||||
import { AdminAddMoneyButton } from "./add-money-button";
|
||||
import { CreditTransactionType } from "@/lib/autogpt-server-api";
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import { IconRefresh, IconSquare } from "@/components/ui/icons";
|
||||
import { useToastOnFail } from "@/components/ui/use-toast";
|
||||
import ActionButtonGroup from "@/components/agptui/action-button-group";
|
||||
import LoadingBox from "@/components/ui/loading";
|
||||
import { Input } from "@/components/ui/input";
|
||||
|
||||
import {
|
||||
@@ -252,7 +253,7 @@ export default function AgentRunDetailsView({
|
||||
),
|
||||
)
|
||||
) : (
|
||||
<p>Loading...</p>
|
||||
<LoadingBox spinnerSize={12} className="h-24" />
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
@@ -271,7 +272,7 @@ export default function AgentRunDetailsView({
|
||||
</div>
|
||||
))
|
||||
) : (
|
||||
<p>Loading...</p>
|
||||
<LoadingBox spinnerSize={12} className="h-24" />
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
@@ -13,6 +13,7 @@ import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import { AgentRunStatus } from "@/components/agents/agent-run-status-chip";
|
||||
import { useToastOnFail } from "@/components/ui/use-toast";
|
||||
import ActionButtonGroup from "@/components/agptui/action-button-group";
|
||||
import LoadingBox from "@/components/ui/loading";
|
||||
import { Input } from "@/components/ui/input";
|
||||
|
||||
export default function AgentScheduleDetailsView({
|
||||
@@ -113,7 +114,7 @@ export default function AgentScheduleDetailsView({
|
||||
</div>
|
||||
))
|
||||
) : (
|
||||
<p>Loading...</p>
|
||||
<LoadingBox spinnerSize={12} className="h-24" />
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
@@ -32,7 +32,7 @@ export const AgentsSection: React.FC<AgentsSectionProps> = ({
|
||||
sectionTitle,
|
||||
agents: allAgents,
|
||||
hideAvatars = false,
|
||||
margin = "37px",
|
||||
margin = "24px",
|
||||
}) => {
|
||||
const router = useRouter();
|
||||
|
||||
@@ -48,11 +48,12 @@ export const AgentsSection: React.FC<AgentsSectionProps> = ({
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center">
|
||||
<div className="w-full max-w-[1360px]">
|
||||
<div
|
||||
className={`mb-[${margin}] font-poppins text-lg font-semibold text-[#282828] dark:text-neutral-200`}
|
||||
<h2
|
||||
style={{ marginBottom: margin }}
|
||||
className="font-poppins text-lg font-semibold text-[#282828] dark:text-neutral-200"
|
||||
>
|
||||
{sectionTitle}
|
||||
</div>
|
||||
</h2>
|
||||
{!displayedAgents || displayedAgents.length === 0 ? (
|
||||
<div className="text-center text-gray-500 dark:text-gray-400">
|
||||
No agents found
|
||||
|
||||
@@ -100,9 +100,8 @@ export default function CredentialsProvider({
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
const [providers, setProviders] = useState<CredentialsProvidersContextType>(
|
||||
{},
|
||||
);
|
||||
const [providers, setProviders] =
|
||||
useState<CredentialsProvidersContextType | null>(null);
|
||||
const api = useBackendAPI();
|
||||
|
||||
const addCredentials = useCallback(
|
||||
|
||||
@@ -669,10 +669,6 @@ export default class BackendAPI {
|
||||
await this._request("DELETE", `/library/presets/${presetId}`);
|
||||
}
|
||||
|
||||
getAyrshareSSOUrl(): Promise<{ sso_url: string; expire_at: string }> {
|
||||
return this._get("/integrations/ayrshare/sso_url");
|
||||
}
|
||||
|
||||
executeLibraryAgentPreset(
|
||||
presetId: string,
|
||||
graphId: GraphID,
|
||||
|
||||
@@ -580,7 +580,6 @@ export enum BlockUIType {
|
||||
WEBHOOK_MANUAL = "Webhook (manual)",
|
||||
AGENT = "Agent",
|
||||
AI = "AI",
|
||||
AYRSHARE = "Ayrshare",
|
||||
}
|
||||
|
||||
export enum SpecialBlockID {
|
||||
|
||||
Reference in New Issue
Block a user