Merge branch 'dev' into codex/fix-400-error-with-non-default-voices-in-unreal-tts

This commit is contained in:
Toran Bruce Richards
2025-07-01 11:02:28 +01:00
committed by GitHub
564 changed files with 32322 additions and 7257 deletions

View File

@@ -50,6 +50,23 @@ jobs:
env:
RABBITMQ_DEFAULT_USER: ${{ env.RABBITMQ_DEFAULT_USER }}
RABBITMQ_DEFAULT_PASS: ${{ env.RABBITMQ_DEFAULT_PASS }}
# ClamAV service container; the job's "Wait for ClamAV to be ready" step
# probes it on localhost:3310.
clamav:
  image: clamav/clamav-debian:latest
  ports:
    - 3310:3310
  env:
    # Environment variables are always strings at runtime. Quote boolean- and
    # number-looking values so the YAML parser does not retype them first.
    CLAMAV_NO_FRESHCLAMD: "false"
    CLAMD_CONF_StreamMaxLength: 50M
    CLAMD_CONF_MaxFileSize: 100M
    CLAMD_CONF_MaxScanSize: 100M
    CLAMD_CONF_MaxThreads: "4"
    CLAMD_CONF_ReadTimeout: "300"
  # Docker health check: up to 5 retries, 30s apart, after a 180s start period.
  options: >-
    --health-cmd "clamdscan --version || exit 1"
    --health-interval 30s
    --health-timeout 10s
    --health-retries 5
    --health-start-period 180s
steps:
- name: Checkout repository
@@ -131,6 +148,35 @@ jobs:
# outputs:
# DB_URL, API_URL, GRAPHQL_URL, ANON_KEY, SERVICE_ROLE_KEY, JWT_SECRET
- name: Wait for ClamAV to be ready
  run: |
    # Print the tail of the ClamAV container log, or a notice when no running
    # container matches. The ID is checked explicitly because the exit status
    # of `docker logs ... | tail` is tail's, so a `|| echo` fallback on the
    # pipeline would never fire.
    dump_clamav_logs() {
      container_id=$(docker ps -q --filter "ancestor=clamav/clamav-debian:latest" | head -n 1)
      if [ -n "$container_id" ]; then
        docker logs "$container_id" 2>&1 | tail -50 || true
      else
        echo "No ClamAV container found"
      fi
    }

    echo "Waiting for ClamAV daemon to start..."
    max_attempts=60
    attempt=0
    # Probe first and only then check the retry budget, so a probe that
    # succeeds on the very last check is reported as success, not a timeout.
    until nc -z localhost 3310; do
      if [ "$attempt" -ge "$max_attempts" ]; then
        echo "ClamAV failed to start after $((max_attempts*5)) seconds"
        echo "Checking ClamAV service logs..."
        dump_clamav_logs
        exit 1
      fi
      echo "ClamAV is unavailable - sleeping (attempt $((attempt+1))/$max_attempts)"
      sleep 5
      attempt=$((attempt+1))
    done
    echo "ClamAV is ready!"

    # Verify ClamAV is responsive
    echo "Testing ClamAV connection..."
    timeout 10 bash -c 'echo "PING" | nc localhost 3310' || {
      echo "ClamAV is not responding to PING"
      dump_clamav_logs
      exit 1
    }
- name: Run Database Migrations
run: poetry run prisma migrate dev --name updates
env:
@@ -144,9 +190,9 @@ jobs:
- name: Run pytest with coverage
run: |
if [[ "${{ runner.debug }}" == "1" ]]; then
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG test
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
else
poetry run pytest -s -vv test
poetry run pytest -s -vv
fi
if: success() || (failure() && steps.lint.outcome == 'failure')
env:
@@ -159,6 +205,7 @@ jobs:
REDIS_HOST: "localhost"
REDIS_PORT: "6379"
REDIS_PASSWORD: "testpassword"
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
env:
CI: true

View File

@@ -55,12 +55,37 @@ jobs:
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Generate API client
run: pnpm generate:api-client
- name: Run tsc check
run: pnpm type-check
# Runs the Chromatic action (chromaui/action) against autogpt_platform/frontend.
chromatic:
runs-on: ubuntu-latest
# Only run on dev branch pushes or PRs targeting dev
if: github.ref == 'refs/heads/dev' || github.base_ref == 'dev'
steps:
- uses: actions/checkout@v4
with:
# Fetch the full git history instead of a shallow clone.
# Presumably Chromatic needs it to locate baseline builds - TODO confirm.
fetch-depth: 0
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "21"
- name: Enable corepack
run: corepack enable
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Run Chromatic
uses: chromaui/action@latest
with:
# NOTE(review): the project token is committed in plain text. This may be
# deliberate so that pull requests from forks (which cannot read repository
# secrets) can still run; otherwise move it to a repository secret.
projectToken: chpt_9e7c1a76478c9c8
onlyChanged: true
workingDir: autogpt_platform/frontend
token: ${{ secrets.GITHUB_TOKEN }}
test:
runs-on: ubuntu-latest
strategy:

3
.gitignore vendored
View File

@@ -177,6 +177,3 @@ autogpt_platform/backend/settings.py
*.ign.*
.test-contents
.claude/settings.local.json
# Auto generated client
autogpt_platform/frontend/src/api/__generated__

View File

@@ -19,7 +19,7 @@ cd backend && poetry install
# Run database migrations
poetry run prisma migrate dev
# Start all services (database, redis, rabbitmq)
# Start all services (database, redis, rabbitmq, clamav)
docker compose up -d
# Run the backend server
@@ -92,6 +92,7 @@ npm run type-check
2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks
3. **Integrations**: OAuth and API connections stored per user
4. **Store**: Marketplace for sharing agent templates
5. **Virus Scanning**: ClamAV integration for file upload security
### Testing Approach
- Backend uses pytest with snapshot testing for API responses

View File

@@ -55,9 +55,9 @@ RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
## GCS bucket is required for marketplace and library functionality
MEDIA_GCS_BUCKET_NAME=
## For local development, you may need to set FRONTEND_BASE_URL for the OAuth flow
## For local development, you may need to set NEXT_PUBLIC_FRONTEND_BASE_URL for the OAuth flow
## for integrations to work. Defaults to the value of PLATFORM_BASE_URL if not set.
# FRONTEND_BASE_URL=http://localhost:3000
# NEXT_PUBLIC_FRONTEND_BASE_URL=http://localhost:3000
## PLATFORM_BASE_URL must be set to a *publicly accessible* URL pointing to your backend
## to use the platform's webhook-related functionality.

View File

@@ -20,7 +20,7 @@ def load_all_blocks() -> dict[str, type["Block"]]:
modules = [
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
for f in current_dir.rglob("*.py")
if f.is_file() and f.name != "__init__.py"
if f.is_file() and f.name != "__init__.py" and not f.name.startswith("test_")
]
for module in modules:
if not re.match("^[a-z0-9_.]+$", module):

View File

@@ -2,6 +2,8 @@ import asyncio
import logging
from typing import Any, Optional
from pydantic import JsonValue
from backend.data.block import (
Block,
BlockCategory,
@@ -12,7 +14,7 @@ from backend.data.block import (
get_block,
)
from backend.data.execution import ExecutionStatus
from backend.data.model import CredentialsMetaInput, SchemaField
from backend.data.model import SchemaField
from backend.util import json
logger = logging.getLogger(__name__)
@@ -28,9 +30,9 @@ class AgentExecutorBlock(Block):
input_schema: dict = SchemaField(description="Input schema for the graph")
output_schema: dict = SchemaField(description="Output schema for the graph")
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = SchemaField(default=None, hidden=True)
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = SchemaField(
default=None, hidden=True
)
@classmethod
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
@@ -71,7 +73,7 @@ class AgentExecutorBlock(Block):
graph_version=input_data.graph_version,
user_id=input_data.user_id,
inputs=input_data.inputs,
node_credentials_input_map=input_data.node_credentials_input_map,
nodes_input_masks=input_data.nodes_input_masks,
use_db_query=False,
)

View File

@@ -53,6 +53,7 @@ class AudioTrack(str, Enum):
REFRESHER = ("Refresher",)
TOURIST = ("Tourist",)
TWIN_TYCHES = ("Twin Tyches",)
DONT_STOP_ME_ABSTRACT_FUTURE_BASS = ("Dont Stop Me Abstract Future Bass",)
@property
def audio_url(self):
@@ -78,6 +79,7 @@ class AudioTrack(str, Enum):
AudioTrack.REFRESHER: "https://cdn.tfrv.xyz/audio/refresher.mp3",
AudioTrack.TOURIST: "https://cdn.tfrv.xyz/audio/tourist.mp3",
AudioTrack.TWIN_TYCHES: "https://cdn.tfrv.xyz/audio/twin-tynches.mp3",
AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS: "https://cdn.revid.ai/audio/_dont-stop-me-abstract-future-bass.mp3",
}
return audio_urls[self]
@@ -105,6 +107,7 @@ class GenerationPreset(str, Enum):
MOVIE = ("Movie",)
STYLIZED_ILLUSTRATION = ("Stylized Illustration",)
MANGA = ("Manga",)
DEFAULT = ("DEFAULT",)
class Voice(str, Enum):
@@ -114,6 +117,7 @@ class Voice(str, Enum):
JESSICA = "Jessica"
CHARLOTTE = "Charlotte"
CALLUM = "Callum"
EVA = "Eva"
@property
def voice_id(self):
@@ -124,6 +128,7 @@ class Voice(str, Enum):
Voice.JESSICA: "cgSgspJ2msm6clMCkdW9",
Voice.CHARLOTTE: "XB0fDUnXU5powFXDhCwa",
Voice.CALLUM: "N2lVS1w4EtoT3dr4eOWO",
Voice.EVA: "FGY2WhTYpPnrIDTdsKH5",
}
return voice_id_map[self]
@@ -141,6 +146,8 @@ logger = logging.getLogger(__name__)
class AIShortformVideoCreatorBlock(Block):
"""Creates a shortform texttovideo clip using stock or AI imagery."""
class Input(BlockSchema):
credentials: CredentialsMetaInput[
Literal[ProviderName.REVID], Literal["api_key"]
@@ -184,40 +191,8 @@ class AIShortformVideoCreatorBlock(Block):
video_url: str = SchemaField(description="The URL of the created video")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="361697fb-0c4f-4feb-aed3-8320c88c771b",
description="Creates a shortform video using revid.ai",
categories={BlockCategory.SOCIAL, BlockCategory.AI},
input_schema=AIShortformVideoCreatorBlock.Input,
output_schema=AIShortformVideoCreatorBlock.Output,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"script": "[close-up of a cat] Meow!",
"ratio": "9 / 16",
"resolution": "720p",
"frame_rate": 60,
"generation_preset": GenerationPreset.LEONARDO,
"background_music": AudioTrack.HIGHWAY_NOCTURNE,
"voice": Voice.LILY,
"video_style": VisualMediaType.STOCK_VIDEOS,
},
test_output=(
"video_url",
"https://example.com/video.mp4",
),
test_mock={
"create_webhook": lambda: (
"test_uuid",
"https://webhook.site/test_uuid",
),
"create_video": lambda api_key, payload: {"pid": "test_pid"},
"wait_for_video": lambda api_key, pid, webhook_token, max_wait_time=1000: "https://example.com/video.mp4",
},
test_credentials=TEST_CREDENTIALS,
)
async def create_webhook(self):
async def create_webhook(self) -> tuple[str, str]:
"""Create a new webhook URL for receiving notifications."""
url = "https://webhook.site/token"
headers = {"Accept": "application/json", "Content-Type": "application/json"}
response = await Requests().post(url, headers=headers)
@@ -225,6 +200,7 @@ class AIShortformVideoCreatorBlock(Block):
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
"""Create a video using the Revid API."""
url = "https://www.revid.ai/api/public/v2/render"
headers = {"key": api_key.get_secret_value()}
response = await Requests().post(url, json=payload, headers=headers)
@@ -234,6 +210,7 @@ class AIShortformVideoCreatorBlock(Block):
return response.json()
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
"""Check the status of a video creation job."""
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
headers = {"key": api_key.get_secret_value()}
response = await Requests().get(url, headers=headers)
@@ -243,9 +220,9 @@ class AIShortformVideoCreatorBlock(Block):
self,
api_key: SecretStr,
pid: str,
webhook_token: str,
max_wait_time: int = 1000,
) -> str:
"""Wait for video creation to complete and return the video URL."""
start_time = time.time()
while time.time() - start_time < max_wait_time:
status = await self.check_video_status(api_key, pid)
@@ -266,6 +243,40 @@ class AIShortformVideoCreatorBlock(Block):
logger.error("Video creation timed out")
raise TimeoutError("Video creation timed out")
def __init__(self):
super().__init__(
id="361697fb-0c4f-4feb-aed3-8320c88c771b",
description="Creates a shortform video using revid.ai",
categories={BlockCategory.SOCIAL, BlockCategory.AI},
input_schema=AIShortformVideoCreatorBlock.Input,
output_schema=AIShortformVideoCreatorBlock.Output,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"script": "[close-up of a cat] Meow!",
"ratio": "9 / 16",
"resolution": "720p",
"frame_rate": 60,
"generation_preset": GenerationPreset.LEONARDO,
"background_music": AudioTrack.HIGHWAY_NOCTURNE,
"voice": Voice.LILY,
"video_style": VisualMediaType.STOCK_VIDEOS,
},
test_output=("video_url", "https://example.com/video.mp4"),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"test_uuid",
"https://webhook.site/test_uuid",
),
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "https://example.com/video.mp4",
},
"wait_for_video": lambda *args, **kwargs: "https://example.com/video.mp4",
},
test_credentials=TEST_CREDENTIALS,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
@@ -273,20 +284,18 @@ class AIShortformVideoCreatorBlock(Block):
webhook_token, webhook_url = await self.create_webhook()
logger.debug(f"Webhook URL: {webhook_url}")
audio_url = input_data.background_music.audio_url
payload = {
"frameRate": input_data.frame_rate,
"resolution": input_data.resolution,
"frameDurationMultiplier": 18,
"webhook": webhook_url,
"webhook": None,
"creationParams": {
"mediaType": input_data.video_style,
"captionPresetName": "Wrap 1",
"selectedVoice": input_data.voice.voice_id,
"hasEnhancedGeneration": True,
"generationPreset": input_data.generation_preset.name,
"selectedAudio": input_data.background_music,
"selectedAudio": input_data.background_music.value,
"origin": "/create",
"inputText": input_data.script,
"flowType": "text-to-video",
@@ -302,7 +311,7 @@ class AIShortformVideoCreatorBlock(Block):
"selectedStoryStyle": {"value": "custom", "label": "Custom"},
"hasToGenerateVideos": input_data.video_style
!= VisualMediaType.STOCK_VIDEOS,
"audioUrl": audio_url,
"audioUrl": input_data.background_music.audio_url,
},
}
@@ -319,8 +328,370 @@ class AIShortformVideoCreatorBlock(Block):
logger.debug(
f"Video created with project ID: {pid}. Waiting for completion..."
)
video_url = await self.wait_for_video(
credentials.api_key, pid, webhook_token
)
video_url = await self.wait_for_video(credentials.api_key, pid)
logger.debug(f"Video ready: {video_url}")
yield "video_url", video_url
class AIAdMakerVideoCreatorBlock(Block):
    """Generates a 30-second vertical AI advert using optional user-supplied imagery.

    The block submits a render job to the Revid API and then polls the job
    status until a video URL is available, the job reports a failure, or the
    wait times out.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput[
            Literal[ProviderName.REVID], Literal["api_key"]
        ] = CredentialsField(
            description="Credentials for Revid.ai API access.",
        )
        script: str = SchemaField(
            description="Short advertising copy. Line breaks create new scenes.",
            placeholder="Introducing Foobar [show product photo] the gadget that does it all.",
        )
        ratio: str = SchemaField(description="Aspect ratio", default="9 / 16")
        target_duration: int = SchemaField(
            description="Desired length of the ad in seconds.", default=30
        )
        voice: Voice = SchemaField(
            description="Narration voice", default=Voice.EVA, placeholder=Voice.EVA
        )
        background_music: AudioTrack = SchemaField(
            description="Background track",
            default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS,
        )
        input_media_urls: list[str] = SchemaField(
            description="List of image URLs to feature in the advert.", default=[]
        )
        use_only_provided_media: bool = SchemaField(
            description="Restrict visuals to supplied images only.", default=True
        )

    class Output(BlockSchema):
        video_url: str = SchemaField(description="URL of the finished advert")
        error: str = SchemaField(description="Error message on failure")

    async def create_webhook(self) -> tuple[str, str]:
        """Create a new webhook URL for receiving notifications.

        Returns:
            A ``(uuid, url)`` pair, where ``url`` is the webhook.site address
            built from the returned ``uuid``.
        """
        url = "https://webhook.site/token"
        headers = {"Accept": "application/json", "Content-Type": "application/json"}
        response = await Requests().post(url, headers=headers)
        webhook_data = response.json()
        return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"

    async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
        """Create a video using the Revid API.

        Args:
            api_key: Revid API key, sent in the ``key`` request header.
            payload: JSON body for the render request.

        Returns:
            The decoded JSON response of the render endpoint.
        """
        url = "https://www.revid.ai/api/public/v2/render"
        headers = {"key": api_key.get_secret_value()}
        response = await Requests().post(url, json=payload, headers=headers)
        logger.debug(
            f"API Response Status Code: {response.status}, Content: {response.text}"
        )
        return response.json()

    async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
        """Check the status of a video creation job.

        Args:
            api_key: Revid API key, sent in the ``key`` request header.
            pid: Project ID returned by the render request.

        Returns:
            The decoded JSON response of the status endpoint.
        """
        url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
        headers = {"key": api_key.get_secret_value()}
        response = await Requests().get(url, headers=headers)
        return response.json()

    async def wait_for_video(
        self,
        api_key: SecretStr,
        pid: str,
        max_wait_time: int = 1000,
    ) -> str:
        """Wait for video creation to complete and return the video URL.

        Polls ``check_video_status`` every 10 seconds.

        Args:
            api_key: Revid API key.
            pid: Project ID of the render job.
            max_wait_time: Maximum number of seconds to keep polling.

        Returns:
            The ``videoUrl`` reported once the job status is ``"ready"``.

        Raises:
            ValueError: The job reported ``"error"``, ``"FAILED"`` or ``"CANCELED"``.
            TimeoutError: No terminal status was seen within ``max_wait_time``.
        """
        start_time = time.time()
        while time.time() - start_time < max_wait_time:
            status = await self.check_video_status(api_key, pid)
            logger.debug(f"Video status: {status}")

            if status.get("status") == "ready" and "videoUrl" in status:
                return status["videoUrl"]
            elif status.get("status") == "error":
                error_message = status.get("error", "Unknown error occurred")
                logger.error(f"Video creation failed: {error_message}")
                raise ValueError(f"Video creation failed: {error_message}")
            elif status.get("status") in ["FAILED", "CANCELED"]:
                logger.error(f"Video creation failed: {status.get('message')}")
                raise ValueError(f"Video creation failed: {status.get('message')}")

            await asyncio.sleep(10)

        logger.error("Video creation timed out")
        raise TimeoutError("Video creation timed out")

    def __init__(self):
        super().__init__(
            id="58bd2a19-115d-4fd1-8ca4-13b9e37fa6a0",
            description="Creates an AI-generated 30-second advert (text + images)",
            categories={BlockCategory.MARKETING, BlockCategory.AI},
            input_schema=AIAdMakerVideoCreatorBlock.Input,
            output_schema=AIAdMakerVideoCreatorBlock.Output,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
                "script": "Test product launch!",
                "input_media_urls": [
                    "https://cdn.revid.ai/uploads/1747076315114-image.png",
                ],
            },
            test_output=("video_url", "https://example.com/ad.mp4"),
            # Every network-facing helper is replaced by a canned response.
            test_mock={
                "create_webhook": lambda *args, **kwargs: (
                    "test_uuid",
                    "https://webhook.site/test_uuid",
                ),
                "create_video": lambda *args, **kwargs: {"pid": "test_pid"},
                "check_video_status": lambda *args, **kwargs: {
                    "status": "ready",
                    "videoUrl": "https://example.com/ad.mp4",
                },
                "wait_for_video": lambda *args, **kwargs: "https://example.com/ad.mp4",
            },
            test_credentials=TEST_CREDENTIALS,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Submit the advert render job and yield the finished video's URL.

        Raises:
            RuntimeError: The render response contained no ``pid``.
        """
        # Only the webhook URL is forwarded to Revid; completion is detected by
        # polling, so the webhook token itself is not used.
        _, webhook_url = await self.create_webhook()

        payload = {
            "webhook": webhook_url,
            "creationParams": {
                "targetDuration": input_data.target_duration,
                "ratio": input_data.ratio,
                "mediaType": "aiVideo",
                "inputText": input_data.script,
                "flowType": "text-to-video",
                "slug": "ai-ad-generator",
                "slugNew": "",
                "isCopiedFrom": False,
                "hasToGenerateVoice": True,
                "hasToTranscript": False,
                "hasToSearchMedia": True,
                "hasAvatar": False,
                "hasWebsiteRecorder": False,
                "hasTextSmallAtBottom": False,
                "selectedAudio": input_data.background_music.value,
                "selectedVoice": input_data.voice.voice_id,
                "selectedAvatar": "https://cdn.revid.ai/avatars/young-woman.mp4",
                "selectedAvatarType": "video/mp4",
                "websiteToRecord": "",
                "hasToGenerateCover": True,
                "nbGenerations": 1,
                "disableCaptions": False,
                "mediaMultiplier": "medium",
                "characters": [],
                "captionPresetName": "Revid",
                "sourceType": "contentScraping",
                "selectedStoryStyle": {"value": "custom", "label": "General"},
                "generationPreset": "DEFAULT",
                "hasToGenerateMusic": False,
                "isOptimizedForChinese": False,
                "generationUserPrompt": "",
                "enableNsfwFilter": False,
                "addStickers": False,
                "typeMovingImageAnim": "dynamic",
                "hasToGenerateSoundEffects": False,
                "forceModelType": "gpt-image-1",
                "selectedCharacters": [],
                "lang": "",
                "voiceSpeed": 1,
                "disableAudio": False,
                "disableVoice": False,
                "useOnlyProvidedMedia": input_data.use_only_provided_media,
                "imageGenerationModel": "ultra",
                "videoGenerationModel": "pro",
                "hasEnhancedGeneration": True,
                "hasEnhancedGenerationPro": True,
                "inputMedias": [
                    {"url": url, "title": "", "type": "image"}
                    for url in input_data.input_media_urls
                ],
                "hasToGenerateVideos": True,
                "audioUrl": input_data.background_music.audio_url,
                "watermark": None,
            },
        }

        response = await self.create_video(credentials.api_key, payload)
        pid = response.get("pid")
        if not pid:
            raise RuntimeError("Failed to create video: No project ID returned")

        video_url = await self.wait_for_video(credentials.api_key, pid)
        yield "video_url", video_url
class AIScreenshotToVideoAdBlock(Block):
    """Creates an advert where the supplied screenshot is narrated by an AI avatar.

    The block submits a render job to the Revid API and then polls the job
    status until a video URL is available, the job reports a failure, or the
    wait times out.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput[
            Literal[ProviderName.REVID], Literal["api_key"]
        ] = CredentialsField(description="Revid.ai API key")
        script: str = SchemaField(
            description="Narration that will accompany the screenshot.",
            placeholder="Check out these amazing stats!",
        )
        screenshot_url: str = SchemaField(
            description="Screenshot or image URL to showcase."
        )
        ratio: str = SchemaField(default="9 / 16")
        target_duration: int = SchemaField(default=30)
        voice: Voice = SchemaField(default=Voice.EVA)
        background_music: AudioTrack = SchemaField(
            default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS
        )

    class Output(BlockSchema):
        video_url: str = SchemaField(description="Rendered video URL")
        error: str = SchemaField(description="Error, if encountered")

    async def create_webhook(self) -> tuple[str, str]:
        """Create a new webhook URL for receiving notifications.

        Returns:
            A ``(uuid, url)`` pair, where ``url`` is the webhook.site address
            built from the returned ``uuid``.
        """
        url = "https://webhook.site/token"
        headers = {"Accept": "application/json", "Content-Type": "application/json"}
        response = await Requests().post(url, headers=headers)
        webhook_data = response.json()
        return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"

    async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
        """Create a video using the Revid API.

        Args:
            api_key: Revid API key, sent in the ``key`` request header.
            payload: JSON body for the render request.

        Returns:
            The decoded JSON response of the render endpoint.
        """
        url = "https://www.revid.ai/api/public/v2/render"
        headers = {"key": api_key.get_secret_value()}
        response = await Requests().post(url, json=payload, headers=headers)
        logger.debug(
            f"API Response Status Code: {response.status}, Content: {response.text}"
        )
        return response.json()

    async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
        """Check the status of a video creation job.

        Args:
            api_key: Revid API key, sent in the ``key`` request header.
            pid: Project ID returned by the render request.

        Returns:
            The decoded JSON response of the status endpoint.
        """
        url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
        headers = {"key": api_key.get_secret_value()}
        response = await Requests().get(url, headers=headers)
        return response.json()

    async def wait_for_video(
        self,
        api_key: SecretStr,
        pid: str,
        max_wait_time: int = 1000,
    ) -> str:
        """Wait for video creation to complete and return the video URL.

        Polls ``check_video_status`` every 10 seconds.

        Args:
            api_key: Revid API key.
            pid: Project ID of the render job.
            max_wait_time: Maximum number of seconds to keep polling.

        Returns:
            The ``videoUrl`` reported once the job status is ``"ready"``.

        Raises:
            ValueError: The job reported ``"error"``, ``"FAILED"`` or ``"CANCELED"``.
            TimeoutError: No terminal status was seen within ``max_wait_time``.
        """
        start_time = time.time()
        while time.time() - start_time < max_wait_time:
            status = await self.check_video_status(api_key, pid)
            logger.debug(f"Video status: {status}")

            if status.get("status") == "ready" and "videoUrl" in status:
                return status["videoUrl"]
            elif status.get("status") == "error":
                error_message = status.get("error", "Unknown error occurred")
                logger.error(f"Video creation failed: {error_message}")
                raise ValueError(f"Video creation failed: {error_message}")
            elif status.get("status") in ["FAILED", "CANCELED"]:
                logger.error(f"Video creation failed: {status.get('message')}")
                raise ValueError(f"Video creation failed: {status.get('message')}")

            await asyncio.sleep(10)

        logger.error("Video creation timed out")
        raise TimeoutError("Video creation timed out")

    def __init__(self):
        super().__init__(
            id="0f3e4635-e810-43d9-9e81-49e6f4e83b7c",
            description="Turns a screenshot into an engaging, avatar-narrated video advert.",
            categories={BlockCategory.AI, BlockCategory.MARKETING},
            input_schema=AIScreenshotToVideoAdBlock.Input,
            output_schema=AIScreenshotToVideoAdBlock.Output,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
                "script": "Amazing numbers!",
                "screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
            },
            test_output=("video_url", "https://example.com/screenshot.mp4"),
            # Every network-facing helper is replaced by a canned response.
            test_mock={
                "create_webhook": lambda *args, **kwargs: (
                    "test_uuid",
                    "https://webhook.site/test_uuid",
                ),
                "create_video": lambda *args, **kwargs: {"pid": "test_pid"},
                "check_video_status": lambda *args, **kwargs: {
                    "status": "ready",
                    "videoUrl": "https://example.com/screenshot.mp4",
                },
                "wait_for_video": lambda *args, **kwargs: "https://example.com/screenshot.mp4",
            },
            test_credentials=TEST_CREDENTIALS,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Submit the screenshot advert render job and yield the video's URL.

        Raises:
            RuntimeError: The render response contained no ``pid``.
        """
        # Only the webhook URL is forwarded to Revid; completion is detected by
        # polling, so the webhook token itself is not used.
        _, webhook_url = await self.create_webhook()

        payload = {
            "webhook": webhook_url,
            "creationParams": {
                "targetDuration": input_data.target_duration,
                "ratio": input_data.ratio,
                "mediaType": "aiVideo",
                "hasAvatar": True,
                "removeAvatarBackground": True,
                "inputText": input_data.script,
                "flowType": "text-to-video",
                "slug": "ai-ad-generator",
                "slugNew": "screenshot-to-video-ad",
                "isCopiedFrom": "ai-ad-generator",
                "hasToGenerateVoice": True,
                "hasToTranscript": False,
                "hasToSearchMedia": True,
                "hasWebsiteRecorder": False,
                "hasTextSmallAtBottom": False,
                "selectedAudio": input_data.background_music.value,
                "selectedVoice": input_data.voice.voice_id,
                "selectedAvatar": "https://cdn.revid.ai/avatars/young-woman.mp4",
                "selectedAvatarType": "video/mp4",
                "websiteToRecord": "",
                "hasToGenerateCover": True,
                "nbGenerations": 1,
                "disableCaptions": False,
                "mediaMultiplier": "medium",
                "characters": [],
                "captionPresetName": "Revid",
                "sourceType": "contentScraping",
                "selectedStoryStyle": {"value": "custom", "label": "General"},
                "generationPreset": "DEFAULT",
                "hasToGenerateMusic": False,
                "isOptimizedForChinese": False,
                "generationUserPrompt": "",
                "enableNsfwFilter": False,
                "addStickers": False,
                "typeMovingImageAnim": "dynamic",
                "hasToGenerateSoundEffects": False,
                "forceModelType": "gpt-image-1",
                "selectedCharacters": [],
                "lang": "",
                "voiceSpeed": 1,
                "disableAudio": False,
                "disableVoice": False,
                "useOnlyProvidedMedia": True,
                "imageGenerationModel": "ultra",
                "videoGenerationModel": "ultra",
                "hasEnhancedGeneration": True,
                "hasEnhancedGenerationPro": True,
                "inputMedias": [
                    {"url": input_data.screenshot_url, "title": "", "type": "image"}
                ],
                "hasToGenerateVideos": True,
                "audioUrl": input_data.background_music.audio_url,
                "watermark": None,
            },
        }

        response = await self.create_video(credentials.api_key, payload)
        pid = response.get("pid")
        if not pid:
            raise RuntimeError("Failed to create video: No project ID returned")

        video_url = await self.wait_for_video(credentials.api_key, pid)
        yield "video_url", video_url

View File

@@ -4,6 +4,7 @@ from typing import List
from backend.blocks.apollo._auth import ApolloCredentials
from backend.blocks.apollo.models import (
Contact,
EnrichPersonRequest,
Organization,
SearchOrganizationsRequest,
SearchOrganizationsResponse,
@@ -110,3 +111,21 @@ class ApolloClient:
return (
organizations[: query.max_results] if query.max_results else organizations
)
async def enrich_person(self, query: EnrichPersonRequest) -> Contact:
    """Enrich a person's data via the ``/people/match`` endpoint.

    The request asks for personal emails to be revealed
    (``reveal_personal_emails=true``); no phone-reveal parameter is sent.

    Args:
        query: Identifying details of the person; sent as the JSON body.

    Returns:
        The matched ``Contact``. A missing email is replaced with ``"-"``.

    Raises:
        ValueError: The response has no usable ``person`` entry (key absent,
            or present but null/empty).
    """
    response = await self.requests.post(
        f"{self.API_URL}/people/match",
        headers=self._get_headers(),
        json=query.model_dump(),
        params={
            "reveal_personal_emails": "true",
        },
    )
    data = response.json()

    # Treat a null/empty "person" the same as a missing key; otherwise
    # Contact(**None) would surface as an unrelated TypeError.
    person = data.get("person") if isinstance(data, dict) else None
    if not person:
        raise ValueError(f"Person not found or enrichment failed: {data}")

    contact = Contact(**person)
    contact.email = contact.email or "-"
    return contact

View File

@@ -23,9 +23,9 @@ class BaseModel(OriginalBaseModel):
class PrimaryPhone(BaseModel):
"""A primary phone in Apollo"""
number: str = ""
source: str = ""
sanitized_number: str = ""
number: Optional[str] = ""
source: Optional[str] = ""
sanitized_number: Optional[str] = ""
class SenorityLevels(str, Enum):
@@ -56,102 +56,102 @@ class ContactEmailStatuses(str, Enum):
class RuleConfigStatus(BaseModel):
"""A rule config status in Apollo"""
_id: str = ""
created_at: str = ""
rule_action_config_id: str = ""
rule_config_id: str = ""
status_cd: str = ""
updated_at: str = ""
id: str = ""
key: str = ""
_id: Optional[str] = ""
created_at: Optional[str] = ""
rule_action_config_id: Optional[str] = ""
rule_config_id: Optional[str] = ""
status_cd: Optional[str] = ""
updated_at: Optional[str] = ""
id: Optional[str] = ""
key: Optional[str] = ""
class ContactCampaignStatus(BaseModel):
"""A contact campaign status in Apollo"""
id: str = ""
emailer_campaign_id: str = ""
send_email_from_user_id: str = ""
inactive_reason: str = ""
status: str = ""
added_at: str = ""
added_by_user_id: str = ""
finished_at: str = ""
paused_at: str = ""
auto_unpause_at: str = ""
send_email_from_email_address: str = ""
send_email_from_email_account_id: str = ""
manually_set_unpause: str = ""
failure_reason: str = ""
current_step_id: str = ""
in_response_to_emailer_message_id: str = ""
cc_emails: str = ""
bcc_emails: str = ""
to_emails: str = ""
id: Optional[str] = ""
emailer_campaign_id: Optional[str] = ""
send_email_from_user_id: Optional[str] = ""
inactive_reason: Optional[str] = ""
status: Optional[str] = ""
added_at: Optional[str] = ""
added_by_user_id: Optional[str] = ""
finished_at: Optional[str] = ""
paused_at: Optional[str] = ""
auto_unpause_at: Optional[str] = ""
send_email_from_email_address: Optional[str] = ""
send_email_from_email_account_id: Optional[str] = ""
manually_set_unpause: Optional[str] = ""
failure_reason: Optional[str] = ""
current_step_id: Optional[str] = ""
in_response_to_emailer_message_id: Optional[str] = ""
cc_emails: Optional[str] = ""
bcc_emails: Optional[str] = ""
to_emails: Optional[str] = ""
class Account(BaseModel):
"""An account in Apollo"""
id: str = ""
name: str = ""
website_url: str = ""
blog_url: str = ""
angellist_url: str = ""
linkedin_url: str = ""
twitter_url: str = ""
facebook_url: str = ""
primary_phone: PrimaryPhone = PrimaryPhone()
languages: list[str]
alexa_ranking: int = 0
phone: str = ""
linkedin_uid: str = ""
founded_year: int = 0
publicly_traded_symbol: str = ""
publicly_traded_exchange: str = ""
logo_url: str = ""
chrunchbase_url: str = ""
primary_domain: str = ""
domain: str = ""
team_id: str = ""
organization_id: str = ""
account_stage_id: str = ""
source: str = ""
original_source: str = ""
creator_id: str = ""
owner_id: str = ""
created_at: str = ""
phone_status: str = ""
hubspot_id: str = ""
salesforce_id: str = ""
crm_owner_id: str = ""
parent_account_id: str = ""
sanitized_phone: str = ""
id: Optional[str] = ""
name: Optional[str] = ""
website_url: Optional[str] = ""
blog_url: Optional[str] = ""
angellist_url: Optional[str] = ""
linkedin_url: Optional[str] = ""
twitter_url: Optional[str] = ""
facebook_url: Optional[str] = ""
primary_phone: Optional[PrimaryPhone] = PrimaryPhone()
languages: Optional[list[str]] = []
alexa_ranking: Optional[int] = 0
phone: Optional[str] = ""
linkedin_uid: Optional[str] = ""
founded_year: Optional[int] = 0
publicly_traded_symbol: Optional[str] = ""
publicly_traded_exchange: Optional[str] = ""
logo_url: Optional[str] = ""
chrunchbase_url: Optional[str] = ""
primary_domain: Optional[str] = ""
domain: Optional[str] = ""
team_id: Optional[str] = ""
organization_id: Optional[str] = ""
account_stage_id: Optional[str] = ""
source: Optional[str] = ""
original_source: Optional[str] = ""
creator_id: Optional[str] = ""
owner_id: Optional[str] = ""
created_at: Optional[str] = ""
phone_status: Optional[str] = ""
hubspot_id: Optional[str] = ""
salesforce_id: Optional[str] = ""
crm_owner_id: Optional[str] = ""
parent_account_id: Optional[str] = ""
sanitized_phone: Optional[str] = ""
# no listed type on the API docs
account_playbook_statues: list[Any] = []
account_rule_config_statuses: list[RuleConfigStatus] = []
existence_level: str = ""
label_ids: list[str] = []
typed_custom_fields: Any
custom_field_errors: Any
modality: str = ""
source_display_name: str = ""
salesforce_record_id: str = ""
crm_record_url: str = ""
account_playbook_statues: Optional[list[Any]] = []
account_rule_config_statuses: Optional[list[RuleConfigStatus]] = []
existence_level: Optional[str] = ""
label_ids: Optional[list[str]] = []
typed_custom_fields: Optional[Any] = {}
custom_field_errors: Optional[Any] = {}
modality: Optional[str] = ""
source_display_name: Optional[str] = ""
salesforce_record_id: Optional[str] = ""
crm_record_url: Optional[str] = ""
class ContactEmail(BaseModel):
"""A contact email in Apollo"""
email: str = ""
email_md5: str = ""
email_sha256: str = ""
email_status: str = ""
email_source: str = ""
extrapolated_email_confidence: str = ""
position: int = 0
email_from_customer: str = ""
free_domain: bool = True
email: Optional[str] = ""
email_md5: Optional[str] = ""
email_sha256: Optional[str] = ""
email_status: Optional[str] = ""
email_source: Optional[str] = ""
extrapolated_email_confidence: Optional[str] = ""
position: Optional[int] = 0
email_from_customer: Optional[str] = ""
free_domain: Optional[bool] = True
class EmploymentHistory(BaseModel):
@@ -164,40 +164,40 @@ class EmploymentHistory(BaseModel):
populate_by_name=True,
)
_id: Optional[str] = None
created_at: Optional[str] = None
current: Optional[bool] = None
degree: Optional[str] = None
description: Optional[str] = None
emails: Optional[str] = None
end_date: Optional[str] = None
grade_level: Optional[str] = None
kind: Optional[str] = None
major: Optional[str] = None
organization_id: Optional[str] = None
organization_name: Optional[str] = None
raw_address: Optional[str] = None
start_date: Optional[str] = None
title: Optional[str] = None
updated_at: Optional[str] = None
id: Optional[str] = None
key: Optional[str] = None
_id: Optional[str] = ""
created_at: Optional[str] = ""
current: Optional[bool] = False
degree: Optional[str] = ""
description: Optional[str] = ""
emails: Optional[str] = ""
end_date: Optional[str] = ""
grade_level: Optional[str] = ""
kind: Optional[str] = ""
major: Optional[str] = ""
organization_id: Optional[str] = ""
organization_name: Optional[str] = ""
raw_address: Optional[str] = ""
start_date: Optional[str] = ""
title: Optional[str] = ""
updated_at: Optional[str] = ""
id: Optional[str] = ""
key: Optional[str] = ""
class Breadcrumb(BaseModel):
    """A breadcrumb in Apollo.

    All fields are optional and default to an empty string.
    """

    # Declared once each: an earlier duplicate with the "N/A" default would be
    # overridden by these declarations anyway, so it is not repeated here.
    label: Optional[str] = ""
    signal_field_name: Optional[str] = ""
    # The value may be a single string, a list, or null.
    value: str | list | None = ""
    display_name: Optional[str] = ""
class TypedCustomField(BaseModel):
    """A typed custom field in Apollo.

    Both fields are optional and default to an empty string.
    """

    # Declared once each; a duplicate declaration with the "N/A" default would
    # be shadowed by these and is therefore dropped.
    id: Optional[str] = ""
    value: Optional[str] = ""
class Pagination(BaseModel):
@@ -219,23 +219,23 @@ class Pagination(BaseModel):
class DialerFlags(BaseModel):
    """A dialer flags in Apollo.

    All fields are optional; the boolean flags default to ``True`` and the
    country name defaults to an empty string.
    """

    # Declared once each. A preceding declaration of the same names as
    # required ``bool`` fields would be overridden by these Optional
    # declarations, so only the effective ones are kept.
    country_name: Optional[str] = ""
    country_enabled: Optional[bool] = True
    high_risk_calling_enabled: Optional[bool] = True
    potential_high_risk_number: Optional[bool] = True
class PhoneNumber(BaseModel):
"""A phone number in Apollo"""
raw_number: str = ""
sanitized_number: str = ""
type: str = ""
position: int = 0
status: str = ""
dnc_status: str = ""
dnc_other_info: str = ""
dailer_flags: DialerFlags = DialerFlags(
raw_number: Optional[str] = ""
sanitized_number: Optional[str] = ""
type: Optional[str] = ""
position: Optional[int] = 0
status: Optional[str] = ""
dnc_status: Optional[str] = ""
dnc_other_info: Optional[str] = ""
dailer_flags: Optional[DialerFlags] = DialerFlags(
country_name="",
country_enabled=True,
high_risk_calling_enabled=True,
@@ -253,33 +253,31 @@ class Organization(BaseModel):
populate_by_name=True,
)
id: Optional[str] = "N/A"
name: Optional[str] = "N/A"
website_url: Optional[str] = "N/A"
blog_url: Optional[str] = "N/A"
angellist_url: Optional[str] = "N/A"
linkedin_url: Optional[str] = "N/A"
twitter_url: Optional[str] = "N/A"
facebook_url: Optional[str] = "N/A"
primary_phone: Optional[PrimaryPhone] = PrimaryPhone(
number="N/A", source="N/A", sanitized_number="N/A"
)
languages: list[str] = []
id: Optional[str] = ""
name: Optional[str] = ""
website_url: Optional[str] = ""
blog_url: Optional[str] = ""
angellist_url: Optional[str] = ""
linkedin_url: Optional[str] = ""
twitter_url: Optional[str] = ""
facebook_url: Optional[str] = ""
primary_phone: Optional[PrimaryPhone] = PrimaryPhone()
languages: Optional[list[str]] = []
alexa_ranking: Optional[int] = 0
phone: Optional[str] = "N/A"
linkedin_uid: Optional[str] = "N/A"
phone: Optional[str] = ""
linkedin_uid: Optional[str] = ""
founded_year: Optional[int] = 0
publicly_traded_symbol: Optional[str] = "N/A"
publicly_traded_exchange: Optional[str] = "N/A"
logo_url: Optional[str] = "N/A"
chrunchbase_url: Optional[str] = "N/A"
primary_domain: Optional[str] = "N/A"
sanitized_phone: Optional[str] = "N/A"
owned_by_organization_id: Optional[str] = "N/A"
intent_strength: Optional[str] = "N/A"
show_intent: bool = True
publicly_traded_symbol: Optional[str] = ""
publicly_traded_exchange: Optional[str] = ""
logo_url: Optional[str] = ""
chrunchbase_url: Optional[str] = ""
primary_domain: Optional[str] = ""
sanitized_phone: Optional[str] = ""
owned_by_organization_id: Optional[str] = ""
intent_strength: Optional[str] = ""
show_intent: Optional[bool] = True
has_intent_signal_account: Optional[bool] = True
intent_signal_account: Optional[str] = "N/A"
intent_signal_account: Optional[str] = ""
class Contact(BaseModel):
@@ -292,95 +290,95 @@ class Contact(BaseModel):
populate_by_name=True,
)
contact_roles: list[Any] = []
id: Optional[str] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
name: Optional[str] = None
linkedin_url: Optional[str] = None
title: Optional[str] = None
contact_stage_id: Optional[str] = None
owner_id: Optional[str] = None
creator_id: Optional[str] = None
person_id: Optional[str] = None
email_needs_tickling: bool = True
organization_name: Optional[str] = None
source: Optional[str] = None
original_source: Optional[str] = None
organization_id: Optional[str] = None
headline: Optional[str] = None
photo_url: Optional[str] = None
present_raw_address: Optional[str] = None
linkededin_uid: Optional[str] = None
extrapolated_email_confidence: Optional[float] = None
salesforce_id: Optional[str] = None
salesforce_lead_id: Optional[str] = None
salesforce_contact_id: Optional[str] = None
saleforce_account_id: Optional[str] = None
crm_owner_id: Optional[str] = None
created_at: Optional[str] = None
emailer_campaign_ids: list[str] = []
direct_dial_status: Optional[str] = None
direct_dial_enrichment_failed_at: Optional[str] = None
email_status: Optional[str] = None
email_source: Optional[str] = None
account_id: Optional[str] = None
last_activity_date: Optional[str] = None
hubspot_vid: Optional[str] = None
hubspot_company_id: Optional[str] = None
crm_id: Optional[str] = None
sanitized_phone: Optional[str] = None
merged_crm_ids: Optional[str] = None
updated_at: Optional[str] = None
queued_for_crm_push: bool = True
suggested_from_rule_engine_config_id: Optional[str] = None
email_unsubscribed: Optional[str] = None
label_ids: list[Any] = []
has_pending_email_arcgate_request: bool = True
has_email_arcgate_request: bool = True
existence_level: Optional[str] = None
email: Optional[str] = None
email_from_customer: Optional[str] = None
typed_custom_fields: list[TypedCustomField] = []
custom_field_errors: Any = None
salesforce_record_id: Optional[str] = None
crm_record_url: Optional[str] = None
email_status_unavailable_reason: Optional[str] = None
email_true_status: Optional[str] = None
updated_email_true_status: bool = True
contact_rule_config_statuses: list[RuleConfigStatus] = []
source_display_name: Optional[str] = None
twitter_url: Optional[str] = None
contact_campaign_statuses: list[ContactCampaignStatus] = []
state: Optional[str] = None
city: Optional[str] = None
country: Optional[str] = None
account: Optional[Account] = None
contact_emails: list[ContactEmail] = []
organization: Optional[Organization] = None
employment_history: list[EmploymentHistory] = []
time_zone: Optional[str] = None
intent_strength: Optional[str] = None
show_intent: bool = True
phone_numbers: list[PhoneNumber] = []
account_phone_note: Optional[str] = None
free_domain: bool = True
is_likely_to_engage: bool = True
email_domain_catchall: bool = True
contact_job_change_event: Optional[str] = None
contact_roles: Optional[list[Any]] = []
id: Optional[str] = ""
first_name: Optional[str] = ""
last_name: Optional[str] = ""
name: Optional[str] = ""
linkedin_url: Optional[str] = ""
title: Optional[str] = ""
contact_stage_id: Optional[str] = ""
owner_id: Optional[str] = ""
creator_id: Optional[str] = ""
person_id: Optional[str] = ""
email_needs_tickling: Optional[bool] = True
organization_name: Optional[str] = ""
source: Optional[str] = ""
original_source: Optional[str] = ""
organization_id: Optional[str] = ""
headline: Optional[str] = ""
photo_url: Optional[str] = ""
present_raw_address: Optional[str] = ""
linkededin_uid: Optional[str] = ""
extrapolated_email_confidence: Optional[float] = 0.0
salesforce_id: Optional[str] = ""
salesforce_lead_id: Optional[str] = ""
salesforce_contact_id: Optional[str] = ""
saleforce_account_id: Optional[str] = ""
crm_owner_id: Optional[str] = ""
created_at: Optional[str] = ""
emailer_campaign_ids: Optional[list[str]] = []
direct_dial_status: Optional[str] = ""
direct_dial_enrichment_failed_at: Optional[str] = ""
email_status: Optional[str] = ""
email_source: Optional[str] = ""
account_id: Optional[str] = ""
last_activity_date: Optional[str] = ""
hubspot_vid: Optional[str] = ""
hubspot_company_id: Optional[str] = ""
crm_id: Optional[str] = ""
sanitized_phone: Optional[str] = ""
merged_crm_ids: Optional[str] = ""
updated_at: Optional[str] = ""
queued_for_crm_push: Optional[bool] = True
suggested_from_rule_engine_config_id: Optional[str] = ""
email_unsubscribed: Optional[str] = ""
label_ids: Optional[list[Any]] = []
has_pending_email_arcgate_request: Optional[bool] = True
has_email_arcgate_request: Optional[bool] = True
existence_level: Optional[str] = ""
email: Optional[str] = ""
email_from_customer: Optional[str] = ""
typed_custom_fields: Optional[list[TypedCustomField]] = []
custom_field_errors: Optional[Any] = {}
salesforce_record_id: Optional[str] = ""
crm_record_url: Optional[str] = ""
email_status_unavailable_reason: Optional[str] = ""
email_true_status: Optional[str] = ""
updated_email_true_status: Optional[bool] = True
contact_rule_config_statuses: Optional[list[RuleConfigStatus]] = []
source_display_name: Optional[str] = ""
twitter_url: Optional[str] = ""
contact_campaign_statuses: Optional[list[ContactCampaignStatus]] = []
state: Optional[str] = ""
city: Optional[str] = ""
country: Optional[str] = ""
account: Optional[Account] = Account()
contact_emails: Optional[list[ContactEmail]] = []
organization: Optional[Organization] = Organization()
employment_history: Optional[list[EmploymentHistory]] = []
time_zone: Optional[str] = ""
intent_strength: Optional[str] = ""
show_intent: Optional[bool] = True
phone_numbers: Optional[list[PhoneNumber]] = []
account_phone_note: Optional[str] = ""
free_domain: Optional[bool] = True
is_likely_to_engage: Optional[bool] = True
email_domain_catchall: Optional[bool] = True
contact_job_change_event: Optional[str] = ""
class SearchOrganizationsRequest(BaseModel):
"""Request for Apollo's search organizations API"""
organization_num_empoloyees_range: list[int] = SchemaField(
organization_num_employees_range: Optional[list[int]] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
default=[0, 1000000],
)
organization_locations: list[str] = SchemaField(
organization_locations: Optional[list[str]] = SchemaField(
description="""The location of the company headquarters. You can search across cities, US states, and countries.
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, any Boston-based companies will not appearch in your search results, even if they match other parameters.
@@ -389,28 +387,30 @@ To exclude companies based on location, use the organization_not_locations param
""",
default_factory=list,
)
organizations_not_locations: list[str] = SchemaField(
organizations_not_locations: Optional[list[str]] = SchemaField(
description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.
This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
""",
default_factory=list,
)
q_organization_keyword_tags: list[str] = SchemaField(
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry."""
q_organization_keyword_tags: Optional[list[str]] = SchemaField(
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry.""",
default_factory=list,
)
q_organization_name: str = SchemaField(
q_organization_name: Optional[str] = SchemaField(
description="""Filter search results to include a specific company name.
If the value you enter for this parameter does not match with a company's name, the company will not appear in search results, even if it matches other parameters. Partial matches are accepted. For example, if you filter by the value marketing, a company called NY Marketing Unlimited would still be eligible as a search result, but NY Market Analysis would not be eligible."""
If the value you enter for this parameter does not match with a company's name, the company will not appear in search results, even if it matches other parameters. Partial matches are accepted. For example, if you filter by the value marketing, a company called NY Marketing Unlimited would still be eligible as a search result, but NY Market Analysis would not be eligible.""",
default="",
)
organization_ids: list[str] = SchemaField(
organization_ids: Optional[list[str]] = SchemaField(
description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
To find IDs, identify the values for organization_id when you call this endpoint.""",
default_factory=list,
)
max_results: int = SchemaField(
max_results: Optional[int] = SchemaField(
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
default=100,
ge=1,
@@ -435,11 +435,11 @@ Use the page parameter to search the different pages of data.""",
class SearchOrganizationsResponse(BaseModel):
"""Response from Apollo's search organizations API"""
breadcrumbs: list[Breadcrumb] = []
partial_results_only: bool = True
has_join: bool = True
disable_eu_prospecting: bool = True
partial_results_limit: int = 0
breadcrumbs: Optional[list[Breadcrumb]] = []
partial_results_only: Optional[bool] = True
has_join: Optional[bool] = True
disable_eu_prospecting: Optional[bool] = True
partial_results_limit: Optional[int] = 0
pagination: Pagination = Pagination(
page=0, per_page=0, total_entries=0, total_pages=0
)
@@ -447,14 +447,14 @@ class SearchOrganizationsResponse(BaseModel):
accounts: list[Any] = []
organizations: list[Organization] = []
models_ids: list[str] = []
num_fetch_result: Optional[str] = "N/A"
derived_params: Optional[str] = "N/A"
num_fetch_result: Optional[str] = ""
derived_params: Optional[str] = ""
class SearchPeopleRequest(BaseModel):
"""Request for Apollo's search people API"""
person_titles: list[str] = SchemaField(
person_titles: Optional[list[str]] = SchemaField(
description="""Job titles held by the people you want to find. For a person to be included in search results, they only need to match 1 of the job titles you add. Adding more job titles expands your search results.
Results also include job titles with the same terms, even if they are not exact matches. For example, searching for marketing manager might return people with the job title content marketing manager.
@@ -464,13 +464,13 @@ Use this parameter in combination with the person_seniorities[] parameter to fin
default_factory=list,
placeholder="marketing manager",
)
person_locations: list[str] = SchemaField(
person_locations: Optional[list[str]] = SchemaField(
description="""The location where people live. You can search across cities, US states, and countries.
To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
default_factory=list,
)
person_seniorities: list[SenorityLevels] = SchemaField(
person_seniorities: Optional[list[SenorityLevels]] = SchemaField(
description="""The job seniority that people hold within their current employer. This enables you to find people that currently hold positions at certain reporting levels, such as Director level or senior IC level.
For a person to be included in search results, they only need to match 1 of the seniorities you add. Adding more seniorities expands your search results.
@@ -480,7 +480,7 @@ Searches only return results based on their current job title, so searching for
Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
default_factory=list,
)
organization_locations: list[str] = SchemaField(
organization_locations: Optional[list[str]] = SchemaField(
description="""The location of the company headquarters for a person's current employer. You can search across cities, US states, and countries.
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.
@@ -488,7 +488,7 @@ If a company has several office locations, results are still based on the headqu
To find people based on their personal location, use the person_locations parameter.""",
default_factory=list,
)
q_organization_domains: list[str] = SchemaField(
q_organization_domains: Optional[list[str]] = SchemaField(
description="""The domain name for the person's employer. This can be the current employer or a previous employer. Do not include www., the @ symbol, or similar.
You can add multiple domains to search across companies.
@@ -496,23 +496,23 @@ You can add multiple domains to search across companies.
Examples: apollo.io and microsoft.com""",
default_factory=list,
)
contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
contact_email_statuses: Optional[list[ContactEmailStatuses]] = SchemaField(
description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
default_factory=list,
)
organization_ids: list[str] = SchemaField(
organization_ids: Optional[list[str]] = SchemaField(
description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
default_factory=list,
)
organization_num_empoloyees_range: list[int] = SchemaField(
organization_num_employees_range: Optional[list[int]] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
default_factory=list,
)
q_keywords: str = SchemaField(
q_keywords: Optional[str] = SchemaField(
description="""A string of words over which we want to filter the results""",
default="",
)
@@ -528,7 +528,7 @@ Use this parameter in combination with the per_page parameter to make search res
Use the page parameter to search the different pages of data.""",
default=100,
)
max_results: int = SchemaField(
max_results: Optional[int] = SchemaField(
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
default=100,
ge=1,
@@ -547,16 +547,61 @@ class SearchPeopleResponse(BaseModel):
populate_by_name=True,
)
breadcrumbs: list[Breadcrumb] = []
partial_results_only: bool = True
has_join: bool = True
disable_eu_prospecting: bool = True
partial_results_limit: int = 0
breadcrumbs: Optional[list[Breadcrumb]] = []
partial_results_only: Optional[bool] = True
has_join: Optional[bool] = True
disable_eu_prospecting: Optional[bool] = True
partial_results_limit: Optional[int] = 0
pagination: Pagination = Pagination(
page=0, per_page=0, total_entries=0, total_pages=0
)
contacts: list[Contact] = []
people: list[Contact] = []
model_ids: list[str] = []
num_fetch_result: Optional[str] = "N/A"
derived_params: Optional[str] = "N/A"
num_fetch_result: Optional[str] = ""
derived_params: Optional[str] = ""
class EnrichPersonRequest(BaseModel):
    """Request for Apollo's person enrichment API"""

    # Every field below is optional and defaults to an empty string, so a
    # request can be built from any subset of identifying attributes.
    # NOTE(review): nothing here requires at least one field to be non-empty;
    # presumably the API client or Apollo itself rejects an all-empty request
    # -- confirm against the caller.

    # Direct identifier; its description marks it as the most accurate method.
    person_id: Optional[str] = SchemaField(
        description="Apollo person ID to enrich (most accurate method)",
        default="",
    )
    # Name-based matching inputs: either the split name or the full name.
    first_name: Optional[str] = SchemaField(
        description="First name of the person to enrich",
        default="",
    )
    last_name: Optional[str] = SchemaField(
        description="Last name of the person to enrich",
        default="",
    )
    name: Optional[str] = SchemaField(
        description="Full name of the person to enrich",
        default="",
    )
    email: Optional[str] = SchemaField(
        description="Email address of the person to enrich",
        default="",
    )
    # Employer-related matching inputs.
    domain: Optional[str] = SchemaField(
        description="Company domain of the person to enrich",
        default="",
    )
    company: Optional[str] = SchemaField(
        description="Company name of the person to enrich",
        default="",
    )
    linkedin_url: Optional[str] = SchemaField(
        description="LinkedIn URL of the person to enrich",
        default="",
    )
    organization_id: Optional[str] = SchemaField(
        description="Apollo organization ID of the person's company",
        default="",
    )
    title: Optional[str] = SchemaField(
        description="Job title of the person to enrich",
        default="",
    )

View File

@@ -11,14 +11,14 @@ from backend.blocks.apollo.models import (
SearchOrganizationsRequest,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.data.model import CredentialsField, SchemaField
class SearchOrganizationsBlock(Block):
"""Search for organizations in Apollo"""
class Input(BlockSchema):
organization_num_empoloyees_range: list[int] = SchemaField(
organization_num_employees_range: list[int] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
@@ -65,7 +65,7 @@ To find IDs, identify the values for organization_id when you call this endpoint
le=50000,
advanced=True,
)
credentials: ApolloCredentialsInput = SchemaField(
credentials: ApolloCredentialsInput = CredentialsField(
description="Apollo credentials",
)

View File

@@ -1,3 +1,5 @@
import asyncio
from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS,
@@ -8,11 +10,12 @@ from backend.blocks.apollo._auth import (
from backend.blocks.apollo.models import (
Contact,
ContactEmailStatuses,
EnrichPersonRequest,
SearchPeopleRequest,
SenorityLevels,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.data.model import CredentialsField, SchemaField
class SearchPeopleBlock(Block):
@@ -77,7 +80,7 @@ class SearchPeopleBlock(Block):
default_factory=list,
advanced=False,
)
organization_num_empoloyees_range: list[int] = SchemaField(
organization_num_employees_range: list[int] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
@@ -90,14 +93,19 @@ class SearchPeopleBlock(Block):
advanced=False,
)
max_results: int = SchemaField(
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
default=100,
description="""The maximum number of results to return. If you don't specify this parameter, the default is 25. Limited to 500 to prevent overspending.""",
default=25,
ge=1,
le=50000,
le=500,
advanced=True,
)
enrich_info: bool = SchemaField(
description="""Whether to enrich contacts with detailed information including real email addresses. This will double the search cost.""",
default=False,
advanced=True,
)
credentials: ApolloCredentialsInput = SchemaField(
credentials: ApolloCredentialsInput = CredentialsField(
description="Apollo credentials",
)
@@ -106,10 +114,6 @@ class SearchPeopleBlock(Block):
description="List of people found",
default_factory=list,
)
person: Contact = SchemaField(
title="Person",
description="Each found person, one at a time",
)
error: str = SchemaField(
description="Error message if the search failed",
default="",
@@ -125,87 +129,6 @@ class SearchPeopleBlock(Block):
test_credentials=TEST_CREDENTIALS,
test_input={"credentials": TEST_CREDENTIALS_INPUT},
test_output=[
(
"person",
Contact(
contact_roles=[],
id="1",
name="John Doe",
first_name="John",
last_name="Doe",
linkedin_url="https://www.linkedin.com/in/johndoe",
title="Software Engineer",
organization_name="Google",
organization_id="123456",
contact_stage_id="1",
owner_id="1",
creator_id="1",
person_id="1",
email_needs_tickling=True,
source="apollo",
original_source="apollo",
headline="Software Engineer",
photo_url="https://www.linkedin.com/in/johndoe",
present_raw_address="123 Main St, Anytown, USA",
linkededin_uid="123456",
extrapolated_email_confidence=0.8,
salesforce_id="123456",
salesforce_lead_id="123456",
salesforce_contact_id="123456",
saleforce_account_id="123456",
crm_owner_id="123456",
created_at="2021-01-01",
emailer_campaign_ids=[],
direct_dial_status="active",
direct_dial_enrichment_failed_at="2021-01-01",
email_status="active",
email_source="apollo",
account_id="123456",
last_activity_date="2021-01-01",
hubspot_vid="123456",
hubspot_company_id="123456",
crm_id="123456",
sanitized_phone="123456",
merged_crm_ids="123456",
updated_at="2021-01-01",
queued_for_crm_push=True,
suggested_from_rule_engine_config_id="123456",
email_unsubscribed=None,
label_ids=[],
has_pending_email_arcgate_request=True,
has_email_arcgate_request=True,
existence_level=None,
email=None,
email_from_customer=None,
typed_custom_fields=[],
custom_field_errors=None,
salesforce_record_id=None,
crm_record_url=None,
email_status_unavailable_reason=None,
email_true_status=None,
updated_email_true_status=True,
contact_rule_config_statuses=[],
source_display_name=None,
twitter_url=None,
contact_campaign_statuses=[],
state=None,
city=None,
country=None,
account=None,
contact_emails=[],
organization=None,
employment_history=[],
time_zone=None,
intent_strength=None,
show_intent=True,
phone_numbers=[],
account_phone_note=None,
free_domain=True,
is_likely_to_engage=True,
email_domain_catchall=True,
contact_job_change_event=None,
),
),
(
"people",
[
@@ -380,6 +303,34 @@ class SearchPeopleBlock(Block):
client = ApolloClient(credentials)
return await client.search_people(query)
@staticmethod
async def enrich_person(
    query: EnrichPersonRequest, credentials: ApolloCredentials
) -> Contact:
    """Delegate a person-enrichment request to an ``ApolloClient``.

    Args:
        query: The enrichment request to forward.
        credentials: Apollo credentials used to construct the client.

    Returns:
        The ``Contact`` produced by ``ApolloClient.enrich_person``.
    """
    return await ApolloClient(credentials).enrich_person(query)
@staticmethod
def merge_contact_data(original: Contact, enriched: Contact) -> Contact:
    """
    Merge enriched contact data into the contact returned by the search.

    Enriched values take precedence: every field for which ``enriched``
    carries a truthy value overwrites the same field of ``original``.
    Falsy enriched values (``None``, ``""``, ``[]``, ``{}``, ``0``,
    ``False``) never overwrite, so whatever ``original`` already holds is
    kept when enrichment has nothing to offer for that field.

    Args:
        original: Contact as returned by the people search.
        enriched: Contact as returned by the enrichment call.

    Returns:
        A new ``Contact`` built from the merged field values.
    """
    merged_data = original.model_dump()

    # A single truthiness test is sufficient: None, "" and [] are all falsy,
    # so no separate "skip empty" check is needed before it.
    # NOTE(review): the comparison is per top-level field, so a nested value
    # that dumps to a non-empty dict replaces the original nested value as a
    # whole rather than being merged key by key -- confirm this is intended.
    merged_data.update(
        {key: value for key, value in enriched.model_dump().items() if value}
    )

    return Contact(**merged_data)
async def run(
self,
input_data: Input,
@@ -390,6 +341,23 @@ class SearchPeopleBlock(Block):
query = SearchPeopleRequest(**input_data.model_dump())
people = await self.search_people(query, credentials)
for person in people:
yield "person", person
# Enrich with detailed info if requested
if input_data.enrich_info:
async def enrich_or_fallback(person: Contact):
    """Enrich one contact, falling back to the search result on any failure.

    Args:
        person: Contact returned by the people search.

    Returns:
        The search contact merged with its enriched data, or ``person``
        unchanged when enrichment or merging raises.
    """
    # Function-scope import keeps this helper self-contained.
    import logging

    try:
        enrich_query = EnrichPersonRequest(person_id=person.id)
        enriched_person = await self.enrich_person(enrich_query, credentials)
        # Merge enriched data with the original search data.
        return self.merge_contact_data(person, enriched_person)
    except Exception as exc:
        # Enrichment is best-effort: one failing contact must not abort the
        # whole batch, but the failure is logged instead of vanishing.
        logging.getLogger(__name__).warning(
            "Apollo enrichment failed for person %s: %s", person.id, exc
        )
        return person  # If enrichment fails, use original person data
people = await asyncio.gather(
*(enrich_or_fallback(person) for person in people)
)
yield "people", people

View File

@@ -0,0 +1,138 @@
from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
ApolloCredentials,
ApolloCredentialsInput,
)
from backend.blocks.apollo.models import Contact, EnrichPersonRequest
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import CredentialsField, SchemaField
class GetPersonDetailBlock(Block):
"""Get detailed person data with Apollo API, including email reveal"""
class Input(BlockSchema):
person_id: str = SchemaField(
description="Apollo person ID to enrich (most accurate method)",
default="",
advanced=False,
)
first_name: str = SchemaField(
description="First name of the person to enrich",
default="",
advanced=False,
)
last_name: str = SchemaField(
description="Last name of the person to enrich",
default="",
advanced=False,
)
name: str = SchemaField(
description="Full name of the person to enrich (alternative to first_name + last_name)",
default="",
advanced=False,
)
email: str = SchemaField(
description="Known email address of the person (helps with matching)",
default="",
advanced=False,
)
domain: str = SchemaField(
description="Company domain of the person (e.g., 'google.com')",
default="",
advanced=False,
)
company: str = SchemaField(
description="Company name of the person",
default="",
advanced=False,
)
linkedin_url: str = SchemaField(
description="LinkedIn URL of the person",
default="",
advanced=False,
)
organization_id: str = SchemaField(
description="Apollo organization ID of the person's company",
default="",
advanced=True,
)
title: str = SchemaField(
description="Job title of the person to enrich",
default="",
advanced=True,
)
credentials: ApolloCredentialsInput = CredentialsField(
description="Apollo credentials",
)
class Output(BlockSchema):
contact: Contact = SchemaField(
description="Enriched contact information",
)
error: str = SchemaField(
description="Error message if enrichment failed",
default="",
)
def __init__(self):
super().__init__(
id="3b18d46c-3db6-42ae-a228-0ba441bdd176",
description="Get detailed person data with Apollo API, including email reveal",
categories={BlockCategory.SEARCH},
input_schema=GetPersonDetailBlock.Input,
output_schema=GetPersonDetailBlock.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"first_name": "John",
"last_name": "Doe",
"company": "Google",
},
test_output=[
(
"contact",
Contact(
id="1",
name="John Doe",
first_name="John",
last_name="Doe",
email="john.doe@gmail.com",
title="Software Engineer",
organization_name="Google",
linkedin_url="https://www.linkedin.com/in/johndoe",
),
),
],
test_mock={
"enrich_person": lambda query, credentials: Contact(
id="1",
name="John Doe",
first_name="John",
last_name="Doe",
email="john.doe@gmail.com",
title="Software Engineer",
organization_name="Google",
linkedin_url="https://www.linkedin.com/in/johndoe",
)
},
)
@staticmethod
async def enrich_person(
    query: EnrichPersonRequest, credentials: ApolloCredentials
) -> Contact:
    """
    Send one enrichment request to Apollo.

    Kept as a separate static method so the block's `test_mock` can replace
    it by name ("enrich_person").

    Args:
        query: The enrichment request to forward to the client.
        credentials: Apollo credentials used to construct the client.

    Returns:
        Whatever `ApolloClient.enrich_person` returns for the query.
    """
    return await ApolloClient(credentials).enrich_person(query)
async def run(
    self,
    input_data: Input,
    *,
    credentials: ApolloCredentials,
    **kwargs,
) -> BlockOutput:
    """
    Enrich a single person and emit the result on the "contact" pin.

    Every field of `input_data` (as produced by `model_dump()`) is passed
    to `EnrichPersonRequest` as a keyword argument.
    """
    request_fields = input_data.model_dump()
    request = EnrichPersonRequest(**request_fields)
    enriched_contact = await self.enrich_person(request, credentials)
    yield "contact", enriched_contact

View File

@@ -6,6 +6,7 @@ from backend.data.model import SchemaField
from backend.util import json
from backend.util.file import store_media_file
from backend.util.mock import MockObject
from backend.util.prompt import estimate_token_count_str
from backend.util.type import MediaFileType, convert
@@ -14,6 +15,12 @@ class FileStoreBlock(Block):
file_in: MediaFileType = SchemaField(
description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
)
base_64: bool = SchemaField(
description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).",
default=False,
advanced=True,
title="Produce Base64 Output",
)
class Output(BlockSchema):
file_out: MediaFileType = SchemaField(
@@ -37,12 +44,11 @@ class FileStoreBlock(Block):
graph_exec_id: str,
**kwargs,
) -> BlockOutput:
file_path = await store_media_file(
yield "file_out", await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.file_in,
return_content=False,
return_content=input_data.base_64,
)
yield "file_out", file_path
class StoreValueBlock(Block):
@@ -461,6 +467,11 @@ class CreateListBlock(Block):
description="Maximum size of the list. If provided, the list will be yielded in chunks of this size.",
advanced=True,
)
max_tokens: int | None = SchemaField(
default=None,
description="Maximum tokens for the list. If provided, the list will be yielded in chunks that fit within this token limit.",
advanced=True,
)
class Output(BlockSchema):
list: List[Any] = SchemaField(
@@ -471,7 +482,7 @@ class CreateListBlock(Block):
def __init__(self):
super().__init__(
id="a912d5c7-6e00-4542-b2a9-8034136930e4",
description="Creates a list with the specified values. Use this when you know all the values you want to add upfront.",
description="Creates a list with the specified values. Use this when you know all the values you want to add upfront. This block can also yield the list in batches based on a maximum size or token limit.",
categories={BlockCategory.DATA},
input_schema=CreateListBlock.Input,
output_schema=CreateListBlock.Output,
@@ -496,12 +507,30 @@ class CreateListBlock(Block):
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
    """
    Yield `input_data.values` as one or more "list" chunks.

    A new chunk is started whenever adding the next value would push the
    current chunk past `max_size` items or past `max_tokens` estimated
    tokens. A falsy limit (None or 0) disables that limit. With no limits
    set, all values are yielded as a single chunk; an empty `values`
    yields nothing.

    A single value whose own token estimate exceeds `max_tokens` is still
    emitted, alone in its chunk, rather than being dropped.
    """
    max_tokens = input_data.max_tokens
    max_size = input_data.max_size

    chunk = []
    cur_tokens = 0

    for value in input_data.values:
        # Token estimation is only paid for when a token limit is active.
        tokens = estimate_token_count_str(value) if max_tokens else 0

        over_tokens = bool(max_tokens) and cur_tokens + tokens > max_tokens
        over_size = bool(max_size) and len(chunk) + 1 > max_size

        # Flush only a non-empty chunk. Without the `chunk` guard, a value
        # that alone exceeds max_tokens would trigger a flush while the
        # chunk is still empty and emit a spurious `[]` downstream.
        if chunk and (over_tokens or over_size):
            yield "list", chunk
            chunk = []
            cur_tokens = 0

        chunk.append(value)
        cur_tokens += tokens

    # Yield the final, partially filled chunk if any.
    if chunk:
        yield "list", chunk
class TypeOptions(enum.Enum):

View File

@@ -3,6 +3,7 @@ from typing import Any
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.type import convert
class ComparisonOperator(Enum):
@@ -181,7 +182,23 @@ class IfInputMatchesBlock(Block):
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
if input_data.input == input_data.value or input_data.input is input_data.value:
# If input_data.value is not matching input_data.input, convert value to type of input
if (
input_data.input != input_data.value
and input_data.input is not input_data.value
):
try:
# Only attempt conversion if input is not None and value is not None
if input_data.input is not None and input_data.value is not None:
input_type = type(input_data.input)
# Avoid converting if input_type is Any or object
if input_type not in (Any, object):
input_data.value = convert(input_data.value, input_type)
except Exception:
pass # If conversion fails, just leave value as is
if input_data.input == input_data.value:
yield "result", True
yield "yes_output", input_data.yes_value
else:

View File

@@ -13,7 +13,7 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import MediaFileType
from backend.util.file import MediaFileType, store_media_file
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
@@ -108,7 +108,7 @@ class AIImageEditorBlock(Block):
output_schema=AIImageEditorBlock.Output,
test_input={
"prompt": "Add a hat to the cat",
"input_image": "https://example.com/cat.png",
"input_image": "",
"aspect_ratio": AspectRatio.MATCH_INPUT_IMAGE,
"seed": None,
"model": FluxKontextModelName.PRO,
@@ -128,13 +128,22 @@ class AIImageEditorBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
graph_exec_id: str,
**kwargs,
) -> BlockOutput:
result = await self.run_model(
api_key=credentials.api_key,
model_name=input_data.model.api_name,
prompt=input_data.prompt,
input_image=input_data.input_image,
input_image_b64=(
await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.input_image,
return_content=True,
)
if input_data.input_image
else None
),
aspect_ratio=input_data.aspect_ratio.value,
seed=input_data.seed,
)
@@ -145,14 +154,14 @@ class AIImageEditorBlock(Block):
api_key: SecretStr,
model_name: str,
prompt: str,
input_image: Optional[MediaFileType],
input_image_b64: Optional[str],
aspect_ratio: str,
seed: Optional[int],
) -> MediaFileType:
client = ReplicateClient(api_token=api_key.get_secret_value())
input_params = {
"prompt": prompt,
"input_image": input_image,
"input_image": input_image_b64,
"aspect_ratio": aspect_ratio,
**({"seed": seed} if seed is not None else {}),
}

View File

@@ -265,10 +265,26 @@ class GithubReadPullRequestBlock(Block):
files = response.json()
changes = []
for file in files:
filename = file.get("filename", "")
status = file.get("status", "")
changes.append(f"{filename}: {status}")
return "\n".join(changes)
status: str = file.get("status", "")
diff: str = file.get("patch", "")
if status != "removed":
is_filename: str = file.get("filename", "")
was_filename: str = (
file.get("previous_filename", is_filename)
if status != "added"
else ""
)
else:
is_filename = ""
was_filename: str = file.get("filename", "")
patch_header = ""
if was_filename:
patch_header += f"--- {was_filename}\n"
if is_filename:
patch_header += f"+++ {is_filename}\n"
changes.append(patch_header + diff)
return "\n\n".join(changes)
async def run(
self,

View File

@@ -3,11 +3,19 @@ import logging
from enum import Enum
from io import BytesIO
from pathlib import Path
from typing import Literal
import aiofiles
from pydantic import SecretStr
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.data.model import (
CredentialsField,
CredentialsMetaInput,
HostScopedCredentials,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import (
MediaFileType,
get_exec_file_path,
@@ -19,6 +27,30 @@ from backend.util.request import Requests
logger = logging.getLogger(name=__name__)
# Host-scoped credentials for HTTP requests.
# Credentials-meta input type restricted to the HTTP provider and the
# "host_scoped" credential type; used as the type of the `credentials` field
# of SendAuthenticatedWebRequestBlock.Input.
HttpCredentials = CredentialsMetaInput[
Literal[ProviderName.HTTP], Literal["host_scoped"]
]
# Mock host-scoped credentials (placeholder header values, not real secrets).
# Passed as `test_credentials` when SendAuthenticatedWebRequestBlock registers itself.
TEST_CREDENTIALS = HostScopedCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="http",
host="api.example.com",
headers={
"Authorization": SecretStr("Bearer test-token"),
"X-API-Key": SecretStr("test-api-key"),
},
title="Mock HTTP Host-Scoped Credentials",
)
# Metadata-only view of TEST_CREDENTIALS (provider, id, type, title) — the
# header values themselves are not included.
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
class HttpMethod(Enum):
GET = "GET"
POST = "POST"
@@ -169,3 +201,62 @@ class SendWebRequestBlock(Block):
yield "client_error", result
else:
yield "server_error", result
class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
    """
    Variant of SendWebRequestBlock that injects headers from host-scoped
    credentials before delegating the actual request to the parent block.

    Credential headers are only applied when the credentials match the
    request URL; user-supplied headers always win over credential headers.
    """

    class Input(SendWebRequestBlock.Input):
        credentials: HttpCredentials = CredentialsField(
            description="HTTP host-scoped credentials for automatic header injection",
            discriminator="url",
        )

    def __init__(self):
        # Block.__init__ is called directly instead of super().__init__();
        # presumably so that SendWebRequestBlock's own registration (id,
        # schemas) is not applied to this block — NOTE(review): confirm.
        Block.__init__(
            self,
            id="fff86bcd-e001-4bad-a7f6-2eae4720c8dc",
            description="Make an authenticated HTTP request with host-scoped credentials (JSON / form / multipart).",
            categories={BlockCategory.OUTPUT},
            input_schema=SendAuthenticatedWebRequestBlock.Input,
            output_schema=SendWebRequestBlock.Output,
            test_credentials=TEST_CREDENTIALS,
        )

    async def run(  # type: ignore[override]
        self,
        input_data: Input,
        *,
        graph_exec_id: str,
        credentials: HostScopedCredentials,
        **kwargs,
    ) -> BlockOutput:
        """
        Send the request described by `input_data`, adding credential headers.

        Args:
            input_data: Request description plus the credentials meta field.
            graph_exec_id: Forwarded unchanged to SendWebRequestBlock.run.
            credentials: Resolved host-scoped credentials. Their headers are
                added only if `credentials.matches_url(input_data.url)`;
                otherwise a warning is logged and the request is sent with
                the user's headers only.
            **kwargs: Forwarded unchanged to SendWebRequestBlock.run.

        Yields:
            The (output_name, output_data) pairs produced by the parent run.
        """
        # Collect credential headers, but only for a matching host.
        credential_headers = {}
        if credentials.matches_url(input_data.url):
            logger.debug(
                f"Applying host-scoped credentials {credentials.id} for URL {input_data.url}"
            )
            credential_headers.update(credentials.get_headers_dict())
        else:
            logger.warning(
                f"Host-scoped credentials {credentials.id} do not match URL {input_data.url}"
            )

        # Merge with user-provided headers (user headers take precedence).
        # HTTP field names are case-insensitive, so a credential header is
        # dropped when the user supplied the same name in any casing;
        # otherwise both "Authorization" and "authorization" would be sent.
        user_header_names = {name.lower() for name in input_data.headers}
        merged_headers = {
            name: value
            for name, value in credential_headers.items()
            if name.lower() not in user_header_names
        }
        merged_headers.update(input_data.headers)

        # Create SendWebRequestBlock.Input from our input (removing credentials field)
        base_input = SendWebRequestBlock.Input(
            url=input_data.url,
            method=input_data.method,
            headers=merged_headers,
            json_format=input_data.json_format,
            body=input_data.body,
            files_name=input_data.files_name,
            files=input_data.files,
        )

        # Use parent class run method
        async for output_name, output_data in super().run(
            base_input, graph_exec_id=graph_exec_id, **kwargs
        ):
            yield output_name, output_data

View File

@@ -413,6 +413,12 @@ class AgentFileInputBlock(AgentInputBlock):
advanced=False,
title="Default Value",
)
base_64: bool = SchemaField(
description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).",
default=False,
advanced=True,
title="Produce Base64 Output",
)
class Output(AgentInputBlock.Output):
result: str = SchemaField(description="File reference/path result.")
@@ -446,12 +452,11 @@ class AgentFileInputBlock(AgentInputBlock):
if not input_data.value:
return
file_path = await store_media_file(
yield "result", await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.value,
return_content=False,
return_content=input_data.base_64,
)
yield "result", file_path
class AgentDropdownInputBlock(AgentInputBlock):

View File

@@ -23,6 +23,7 @@ from backend.data.model import (
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.logging import TruncatedLogger
from backend.util.prompt import compress_prompt, estimate_token_count
from backend.util.text import TextFormatter
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
@@ -40,7 +41,7 @@ LLMProviderName = Literal[
AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]]
TEST_CREDENTIALS = APIKeyCredentials(
id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
id="769f6af7-820b-4d5d-9b7a-ab82bbc165f",
provider="openai",
api_key=SecretStr("mock-openai-api-key"),
title="Mock OpenAI API key",
@@ -306,13 +307,6 @@ def convert_openai_tool_fmt_to_anthropic(
return anthropic_tools
def estimate_token_count(prompt_messages: list[dict]) -> int:
char_count = sum(len(str(msg.get("content", ""))) for msg in prompt_messages)
message_overhead = len(prompt_messages) * 4
estimated_tokens = (char_count // 4) + message_overhead
return int(estimated_tokens * 1.2)
async def llm_call(
credentials: APIKeyCredentials,
llm_model: LlmModel,
@@ -321,7 +315,8 @@ async def llm_call(
max_tokens: int | None,
tools: list[dict] | None = None,
ollama_host: str = "localhost:11434",
parallel_tool_calls: bool | None = None,
parallel_tool_calls=None,
compress_prompt_to_fit: bool = True,
) -> LLMResponse:
"""
Make a call to a language model.
@@ -344,10 +339,17 @@ async def llm_call(
- completion_tokens: The number of tokens used in the completion.
"""
provider = llm_model.metadata.provider
context_window = llm_model.context_window
if compress_prompt_to_fit:
prompt = compress_prompt(
messages=prompt,
target_tokens=llm_model.context_window // 2,
lossy_ok=True,
)
# Calculate available tokens based on context window and input length
estimated_input_tokens = estimate_token_count(prompt)
context_window = llm_model.context_window
model_max_output = llm_model.max_output_tokens or int(2**15)
user_max = max_tokens or model_max_output
available_tokens = max(context_window - estimated_input_tokens, 0)
@@ -358,14 +360,10 @@ async def llm_call(
oai_client = openai.AsyncOpenAI(api_key=credentials.api_key.get_secret_value())
response_format = None
if llm_model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
prompt = [
{"role": "user", "content": "\n".join(sys_messages)},
{"role": "user", "content": "\n".join(usr_messages)},
]
elif json_format:
if llm_model.startswith("o") or parallel_tool_calls is None:
parallel_tool_calls = openai.NOT_GIVEN
if json_format:
response_format = {"type": "json_object"}
response = await oai_client.chat.completions.create(
@@ -374,9 +372,7 @@ async def llm_call(
response_format=response_format, # type: ignore
max_completion_tokens=max_tokens,
tools=tools_param, # type: ignore
parallel_tool_calls=(
openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
),
parallel_tool_calls=parallel_tool_calls,
)
if response.choices[0].message.tool_calls:
@@ -699,7 +695,11 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
default=None,
description="The maximum number of tokens to generate in the chat completion.",
)
compress_prompt_to_fit: bool = SchemaField(
advanced=True,
default=True,
description="Whether to compress the prompt to fit within the model's context window.",
)
ollama_host: str = SchemaField(
advanced=True,
default="localhost:11434",
@@ -757,6 +757,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
llm_model: LlmModel,
prompt: list[dict],
json_format: bool,
compress_prompt_to_fit: bool,
max_tokens: int | None,
tools: list[dict] | None = None,
ollama_host: str = "localhost:11434",
@@ -774,6 +775,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
max_tokens=max_tokens,
tools=tools,
ollama_host=ollama_host,
compress_prompt_to_fit=compress_prompt_to_fit,
)
async def run(
@@ -832,7 +834,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
except JSONDecodeError as e:
return f"JSON decode error: {e}"
logger.info(f"LLM request: {prompt}")
logger.debug(f"LLM request: {prompt}")
retry_prompt = ""
llm_model = input_data.model
@@ -842,6 +844,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
credentials=credentials,
llm_model=llm_model,
prompt=prompt,
compress_prompt_to_fit=input_data.compress_prompt_to_fit,
json_format=bool(input_data.expected_format),
ollama_host=input_data.ollama_host,
max_tokens=input_data.max_tokens,
@@ -853,7 +856,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
output_token_count=llm_response.completion_tokens,
)
)
logger.info(f"LLM attempt-{retry_count} response: {response_text}")
logger.debug(f"LLM attempt-{retry_count} response: {response_text}")
if input_data.expected_format:

View File

@@ -13,7 +13,7 @@ from backend.data.model import (
from backend.integrations.providers import ProviderName
TEST_CREDENTIALS = APIKeyCredentials(
id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
id="8cc8b2c5-d3e4-4b1c-84ad-e1e9fe2a0122",
provider="mem0",
api_key=SecretStr("mock-mem0-api-key"),
title="Mock Mem0 API key",

View File

@@ -85,7 +85,7 @@ def _get_tool_responses(entry: dict[str, Any]) -> list[str]:
return tool_call_ids
def _create_tool_response(call_id: str, output: dict[str, Any]) -> dict[str, Any]:
def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
"""
Create a tool response message for either OpenAI or Anthropics,
based on the tool_id format.
@@ -212,6 +212,15 @@ class SmartDecisionMakerBlock(Block):
"link like the output of `StoreValue` or `AgentInput` block"
)
# Check that both conversation_history and last_tool_output are connected together
if any(link.sink_name == "conversation_history" for link in links) != any(
link.sink_name == "last_tool_output" for link in links
):
raise ValueError(
"Last Tool Output is needed when Conversation History is used, "
"and vice versa. Please connect both inputs together."
)
return missing_links
@classmethod
@@ -222,8 +231,15 @@ class SmartDecisionMakerBlock(Block):
conversation_history = data.get("conversation_history", [])
pending_tool_calls = get_pending_tool_calls(conversation_history)
last_tool_output = data.get("last_tool_output")
if not last_tool_output and pending_tool_calls:
# Tool call is pending, wait for the tool output to be provided.
if last_tool_output is None and pending_tool_calls:
return {"last_tool_output"}
# No tool call is pending, wait for the conversation history to be updated.
if last_tool_output is not None and not pending_tool_calls:
return {"conversation_history"}
return set()
class Output(BlockSchema):
@@ -433,7 +449,7 @@ class SmartDecisionMakerBlock(Block):
prompt = [json.to_dict(p) for p in input_data.conversation_history if p]
pending_tool_calls = get_pending_tool_calls(input_data.conversation_history)
if pending_tool_calls and not input_data.last_tool_output:
if pending_tool_calls and input_data.last_tool_output is None:
raise ValueError(f"Tool call requires an output for {pending_tool_calls}")
# Prefill all missing tool calls with the last tool output/
@@ -497,7 +513,7 @@ class SmartDecisionMakerBlock(Block):
max_tokens=input_data.max_tokens,
tools=tool_functions,
ollama_host=input_data.ollama_host,
parallel_tool_calls=True if input_data.multiple_tool_calls else None,
parallel_tool_calls=input_data.multiple_tool_calls,
)
if not response.tool_calls:

View File

@@ -17,7 +17,7 @@ from backend.blocks.smartlead.models import (
Sequence,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.data.model import CredentialsField, SchemaField
class CreateCampaignBlock(Block):
@@ -27,7 +27,7 @@ class CreateCampaignBlock(Block):
name: str = SchemaField(
description="The name of the campaign",
)
credentials: SmartLeadCredentialsInput = SchemaField(
credentials: SmartLeadCredentialsInput = CredentialsField(
description="SmartLead credentials",
)
@@ -119,7 +119,7 @@ class AddLeadToCampaignBlock(Block):
description="Settings for lead upload",
default=LeadUploadSettings(),
)
credentials: SmartLeadCredentialsInput = SchemaField(
credentials: SmartLeadCredentialsInput = CredentialsField(
description="SmartLead credentials",
)
@@ -251,7 +251,7 @@ class SaveCampaignSequencesBlock(Block):
default_factory=list,
advanced=False,
)
credentials: SmartLeadCredentialsInput = SchemaField(
credentials: SmartLeadCredentialsInput = CredentialsField(
description="SmartLead credentials",
)

View File

@@ -0,0 +1,485 @@
"""Comprehensive tests for HTTP block with HostScopedCredentials functionality."""
from typing import cast
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import SecretStr
from backend.blocks.http import (
HttpCredentials,
HttpMethod,
SendAuthenticatedWebRequestBlock,
)
from backend.data.model import HostScopedCredentials
from backend.util.request import Response
class TestHttpBlockWithHostScopedCredentials:
    """Test suite for HTTP block integration with HostScopedCredentials.

    Every test patches ``backend.blocks.http.Requests``, runs
    ``SendAuthenticatedWebRequestBlock`` once, and inspects the headers that
    reached the mocked ``request`` call. The shared steps live in the
    private helpers below (names start with ``_`` so pytest does not
    collect them).
    """

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _install_mock_requests(mock_requests_class, mock_response) -> AsyncMock:
        """Make the patched ``Requests`` class hand out a client returning ``mock_response``."""
        mock_requests = AsyncMock()
        mock_requests.request.return_value = mock_response
        mock_requests_class.return_value = mock_requests
        return mock_requests

    @staticmethod
    def _credentials_meta(credentials: HostScopedCredentials) -> HttpCredentials:
        """Build the credentials meta input that refers to ``credentials``."""
        return cast(
            HttpCredentials,
            {
                "id": credentials.id,
                "provider": "http",
                "type": "host_scoped",
                "title": credentials.title,
            },
        )

    @staticmethod
    async def _run_block(http_block, input_data, credentials) -> list:
        """Run the block once and collect its (output_name, output_data) pairs."""
        outputs = []
        async for output_name, output_data in http_block.run(
            input_data,
            credentials=credentials,
            graph_exec_id="test-exec-id",
        ):
            outputs.append((output_name, output_data))
        return outputs

    @staticmethod
    def _sent_headers(mock_requests) -> dict:
        """Assert exactly one request was made and return its ``headers`` kwarg."""
        mock_requests.request.assert_called_once()
        return mock_requests.request.call_args.kwargs["headers"]

    # ------------------------------------------------------------------ #
    # Fixtures
    # ------------------------------------------------------------------ #

    @pytest.fixture
    def http_block(self):
        """Create an HTTP block instance."""
        return SendAuthenticatedWebRequestBlock()

    @pytest.fixture
    def mock_response(self):
        """Mock a successful HTTP response."""
        response = MagicMock(spec=Response)
        response.status = 200
        response.headers = {"content-type": "application/json"}
        response.json.return_value = {"success": True, "data": "test"}
        return response

    @pytest.fixture
    def exact_match_credentials(self):
        """Create host-scoped credentials for exact domain matching."""
        return HostScopedCredentials(
            provider="http",
            host="api.example.com",
            headers={
                "Authorization": SecretStr("Bearer exact-match-token"),
                "X-API-Key": SecretStr("api-key-123"),
            },
            title="Exact Match API Credentials",
        )

    @pytest.fixture
    def wildcard_credentials(self):
        """Create host-scoped credentials with wildcard pattern."""
        return HostScopedCredentials(
            provider="http",
            host="*.github.com",
            headers={
                "Authorization": SecretStr("token ghp_wildcard123"),
            },
            title="GitHub Wildcard Credentials",
        )

    @pytest.fixture
    def non_matching_credentials(self):
        """Create credentials that don't match test URLs."""
        return HostScopedCredentials(
            provider="http",
            host="different.api.com",
            headers={
                "Authorization": SecretStr("Bearer non-matching-token"),
            },
            title="Non-matching Credentials",
        )

    # ------------------------------------------------------------------ #
    # Tests
    # ------------------------------------------------------------------ #

    @pytest.mark.asyncio
    @patch("backend.blocks.http.Requests")
    async def test_http_block_with_exact_host_match(
        self,
        mock_requests_class,
        http_block,
        exact_match_credentials,
        mock_response,
    ):
        """Test HTTP block with exact host matching credentials."""
        mock_requests = self._install_mock_requests(mock_requests_class, mock_response)

        input_data = SendAuthenticatedWebRequestBlock.Input(
            url="https://api.example.com/data",
            method=HttpMethod.GET,
            headers={"User-Agent": "test-agent"},
            credentials=self._credentials_meta(exact_match_credentials),
        )

        # Execute with credentials provided by execution manager
        result = await self._run_block(http_block, input_data, exact_match_credentials)

        # Verify request headers include both credential and user headers
        assert self._sent_headers(mock_requests) == {
            "Authorization": "Bearer exact-match-token",
            "X-API-Key": "api-key-123",
            "User-Agent": "test-agent",
        }

        # Verify response handling
        assert len(result) == 1
        assert result[0][0] == "response"
        assert result[0][1] == {"success": True, "data": "test"}

    @pytest.mark.asyncio
    @patch("backend.blocks.http.Requests")
    async def test_http_block_with_wildcard_host_match(
        self,
        mock_requests_class,
        http_block,
        wildcard_credentials,
        mock_response,
    ):
        """Test HTTP block with wildcard host pattern matching."""
        mock_requests = self._install_mock_requests(mock_requests_class, mock_response)

        # Test with subdomain that should match *.github.com
        input_data = SendAuthenticatedWebRequestBlock.Input(
            url="https://api.github.com/user",
            method=HttpMethod.GET,
            headers={},
            credentials=self._credentials_meta(wildcard_credentials),
        )

        await self._run_block(http_block, input_data, wildcard_credentials)

        # Verify wildcard matching works
        assert self._sent_headers(mock_requests) == {
            "Authorization": "token ghp_wildcard123"
        }

    @pytest.mark.asyncio
    @patch("backend.blocks.http.Requests")
    async def test_http_block_with_non_matching_credentials(
        self,
        mock_requests_class,
        http_block,
        non_matching_credentials,
        mock_response,
    ):
        """Test HTTP block when credentials don't match the target URL."""
        mock_requests = self._install_mock_requests(mock_requests_class, mock_response)

        # Test with URL that doesn't match the credentials
        input_data = SendAuthenticatedWebRequestBlock.Input(
            url="https://api.example.com/data",
            method=HttpMethod.GET,
            headers={"User-Agent": "test-agent"},
            credentials=self._credentials_meta(non_matching_credentials),
        )

        await self._run_block(http_block, input_data, non_matching_credentials)

        # Verify only user headers are included (no credential headers)
        assert self._sent_headers(mock_requests) == {"User-Agent": "test-agent"}

    @pytest.mark.asyncio
    @patch("backend.blocks.http.Requests")
    async def test_user_headers_override_credential_headers(
        self,
        mock_requests_class,
        http_block,
        exact_match_credentials,
        mock_response,
    ):
        """Test that user-provided headers take precedence over credential headers."""
        mock_requests = self._install_mock_requests(mock_requests_class, mock_response)

        # Test with user header that conflicts with credential header
        input_data = SendAuthenticatedWebRequestBlock.Input(
            url="https://api.example.com/data",
            method=HttpMethod.POST,
            headers={
                "Authorization": "Bearer user-override-token",  # Should override
                "Content-Type": "application/json",  # Additional user header
            },
            credentials=self._credentials_meta(exact_match_credentials),
        )

        await self._run_block(http_block, input_data, exact_match_credentials)

        # Verify user headers take precedence
        assert self._sent_headers(mock_requests) == {
            "X-API-Key": "api-key-123",  # From credentials
            "Authorization": "Bearer user-override-token",  # User override
            "Content-Type": "application/json",  # User header
        }

    @pytest.mark.asyncio
    @patch("backend.blocks.http.Requests")
    async def test_auto_discovered_credentials_flow(
        self,
        mock_requests_class,
        http_block,
        mock_response,
    ):
        """Test the auto-discovery flow where execution manager provides matching credentials."""
        auto_discovered_creds = HostScopedCredentials(
            provider="http",
            host="*.example.com",
            headers={
                "Authorization": SecretStr("Bearer auto-discovered-token"),
            },
            title="Auto-discovered Credentials",
        )
        mock_requests = self._install_mock_requests(mock_requests_class, mock_response)

        # Test with empty credentials field (triggers auto-discovery).
        # The meta dict is built inline because it deliberately does NOT
        # mirror the credentials object that is passed to run().
        input_data = SendAuthenticatedWebRequestBlock.Input(
            url="https://api.example.com/data",
            method=HttpMethod.GET,
            headers={},
            credentials=cast(
                HttpCredentials,
                {
                    "id": "",  # Empty ID triggers auto-discovery in execution manager
                    "provider": "http",
                    "type": "host_scoped",
                    "title": "",
                },
            ),
        )

        # Execute with auto-discovered credentials provided by execution manager
        result = await self._run_block(http_block, input_data, auto_discovered_creds)

        # Verify auto-discovered credentials were applied
        assert self._sent_headers(mock_requests) == {
            "Authorization": "Bearer auto-discovered-token"
        }

        # Verify response handling
        assert len(result) == 1
        assert result[0][0] == "response"
        assert result[0][1] == {"success": True, "data": "test"}

    @pytest.mark.asyncio
    @patch("backend.blocks.http.Requests")
    async def test_multiple_header_credentials(
        self,
        mock_requests_class,
        http_block,
        mock_response,
    ):
        """Test credentials with multiple headers are all applied."""
        multi_header_creds = HostScopedCredentials(
            provider="http",
            host="api.example.com",
            headers={
                "Authorization": SecretStr("Bearer multi-token"),
                "X-API-Key": SecretStr("api-key-456"),
                "X-Client-ID": SecretStr("client-789"),
                "X-Custom-Header": SecretStr("custom-value"),
            },
            title="Multi-Header Credentials",
        )
        mock_requests = self._install_mock_requests(mock_requests_class, mock_response)

        input_data = SendAuthenticatedWebRequestBlock.Input(
            url="https://api.example.com/data",
            method=HttpMethod.GET,
            headers={"User-Agent": "test-agent"},
            credentials=self._credentials_meta(multi_header_creds),
        )

        await self._run_block(http_block, input_data, multi_header_creds)

        # Verify all headers are included
        assert self._sent_headers(mock_requests) == {
            "Authorization": "Bearer multi-token",
            "X-API-Key": "api-key-456",
            "X-Client-ID": "client-789",
            "X-Custom-Header": "custom-value",
            "User-Agent": "test-agent",
        }

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        ("host_pattern", "test_url", "should_match"),
        [
            ("api.example.com", "https://api.example.com/v1/users", True),
            ("*.example.com", "https://api.example.com/v1/users", True),
            ("*.example.com", "https://subdomain.example.com/data", True),
            ("api.example.com", "https://api.different.com/data", False),
        ],
    )
    @patch("backend.blocks.http.Requests")
    async def test_credentials_with_complex_url_patterns(
        self,
        mock_requests_class,
        http_block,
        mock_response,
        host_pattern,
        test_url,
        should_match,
    ):
        """Test credentials matching various URL patterns.

        Parametrized so each host-pattern/URL pair is reported as its own
        test case and gets a fresh mock, instead of sharing one loop.
        """
        mock_requests = self._install_mock_requests(mock_requests_class, mock_response)

        test_creds = HostScopedCredentials(
            provider="http",
            host=host_pattern,
            headers={
                "Authorization": SecretStr(f"Bearer {host_pattern}-token"),
            },
            title=f"Credentials for {host_pattern}",
        )

        input_data = SendAuthenticatedWebRequestBlock.Input(
            url=test_url,
            method=HttpMethod.GET,
            headers={"User-Agent": "test-agent"},
            credentials=self._credentials_meta(test_creds),
        )

        await self._run_block(http_block, input_data, test_creds)

        # Verify headers based on whether pattern should match
        headers = self._sent_headers(mock_requests)
        if should_match:
            # Should include both user and credential headers
            assert headers["Authorization"] == f"Bearer {host_pattern}-token"
        else:
            # Should only include user headers
            assert "Authorization" not in headers
        assert headers["User-Agent"] == "test-agent"

View File

@@ -25,12 +25,7 @@ async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Grap
async def create_credentials(s: SpinTestServer, u: User):
provider = ProviderName.OPENAI
credentials = llm.TEST_CREDENTIALS
try:
await s.agent_server.test_create_credentials(u.id, provider, credentials)
except Exception:
# A ValueError is raised when trying to recreate the same credentials,
# so the error is deliberately hidden
pass
return await s.agent_server.test_create_credentials(u.id, provider, credentials)
async def execute_graph(
@@ -60,19 +55,18 @@ async def execute_graph(
return graph_exec_id
@pytest.mark.skip()
@pytest.mark.asyncio(loop_scope="session")
async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
test_user = await create_test_user()
test_tool_graph = await create_graph(server, create_test_graph(), test_user)
await create_credentials(server, test_user)
creds = await create_credentials(server, test_user)
nodes = [
graph.Node(
block_id=SmartDecisionMakerBlock().id,
input_default={
"prompt": "Hello, World!",
"credentials": llm.TEST_CREDENTIALS_INPUT,
"credentials": creds,
},
),
graph.Node(
@@ -110,80 +104,18 @@ async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
test_graph = await create_graph(server, test_graph, test_user)
@pytest.mark.skip()
@pytest.mark.asyncio(loop_scope="session")
async def test_graph_validation_with_tool_nodes_raises_error(server: SpinTestServer):
test_user = await create_test_user()
test_tool_graph = await create_graph(server, create_test_graph(), test_user)
await create_credentials(server, test_user)
nodes = [
graph.Node(
block_id=SmartDecisionMakerBlock().id,
input_default={
"prompt": "Hello, World!",
"credentials": llm.TEST_CREDENTIALS_INPUT,
},
),
graph.Node(
block_id=AgentExecutorBlock().id,
input_default={
"graph_id": test_tool_graph.id,
"graph_version": test_tool_graph.version,
"input_schema": test_tool_graph.input_schema,
"output_schema": test_tool_graph.output_schema,
},
),
graph.Node(
block_id=StoreValueBlock().id,
),
]
links = [
graph.Link(
source_id=nodes[0].id,
sink_id=nodes[1].id,
source_name="tools_^_sample_tool_input_1",
sink_name="input_1",
),
graph.Link(
source_id=nodes[0].id,
sink_id=nodes[1].id,
source_name="tools_^_sample_tool_input_2",
sink_name="input_2",
),
graph.Link(
source_id=nodes[0].id,
sink_id=nodes[2].id,
source_name="tools_^_store_value_input",
sink_name="input",
),
]
test_graph = graph.Graph(
name="TestGraph",
description="Test graph",
nodes=nodes,
links=links,
)
with pytest.raises(ValueError):
test_graph = await create_graph(server, test_graph, test_user)
@pytest.mark.skip()
@pytest.mark.asyncio(loop_scope="session")
async def test_smart_decision_maker_function_signature(server: SpinTestServer):
test_user = await create_test_user()
test_tool_graph = await create_graph(server, create_test_graph(), test_user)
await create_credentials(server, test_user)
creds = await create_credentials(server, test_user)
nodes = [
graph.Node(
block_id=SmartDecisionMakerBlock().id,
input_default={
"prompt": "Hello, World!",
"credentials": llm.TEST_CREDENTIALS_INPUT,
"credentials": creds,
},
),
graph.Node(

View File

@@ -15,7 +15,7 @@ from backend.blocks.zerobounce._auth import (
ZeroBounceCredentialsInput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.data.model import CredentialsField, SchemaField
class Response(BaseModel):
@@ -90,7 +90,7 @@ class ValidateEmailsBlock(Block):
description="IP address to validate",
default="",
)
credentials: ZeroBounceCredentialsInput = SchemaField(
credentials: ZeroBounceCredentialsInput = CredentialsField(
description="ZeroBounce credentials",
)

View File

@@ -6,6 +6,8 @@ from dotenv import load_dotenv
from backend.util.logging import configure_logging
os.environ["ENABLE_AUTH"] = "false"
load_dotenv()
# NOTE: You can run tests with --log-cli-level=INFO to see the logs

View File

@@ -0,0 +1,5 @@
from .graph import NodeModel
from .integrations import Webhook # noqa: F401
# Resolve Webhook <- NodeModel forward reference
NodeModel.model_rebuild()

View File

@@ -78,6 +78,7 @@ class BlockCategory(Enum):
PRODUCTIVITY = "Block that helps with productivity"
ISSUE_TRACKING = "Block that helps with issue tracking"
MULTIMEDIA = "Block that interacts with multimedia content"
MARKETING = "Block that helps with marketing"
def dict(self) -> dict[str, str]:
return {"category": self.name, "description": self.value}
@@ -485,6 +486,22 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
raise ValueError(f"Block produced an invalid output data: {error}")
yield output_name, output_data
def is_triggered_by_event_type(
    self, trigger_config: dict[str, Any], event_type: str
) -> bool:
    """
    Tell whether an incoming event of `event_type` should trigger this block.

    Args:
        trigger_config: Input values of the trigger; the event filter is looked
            up in it under the name given by `webhook_config.event_filter_input`.
        event_type: The event type to check against the enabled filter entries.

    Returns:
        True if the block has no event filter input at all, or if `event_type`
        equals one of the enabled filter entries formatted through
        `webhook_config.event_format`.

    Raises:
        TypeError: If the block has no `webhook_config`.
        ValueError: If the event filter is missing or empty in `trigger_config`.
    """
    webhook_config = self.webhook_config
    if not webhook_config:
        raise TypeError("This method can't be used on non-trigger blocks")

    filter_input_name = webhook_config.event_filter_input
    if not filter_input_name:
        # No filter input on this block: every event type triggers it
        return True

    event_filter = trigger_config.get(filter_input_name)
    if not event_filter:
        raise ValueError("Event filter is not configured on trigger")

    # Only entries whose value is exactly `True` count as enabled
    enabled_event_types = [
        webhook_config.event_format.format(event=name)
        for name in event_filter
        if event_filter[name] is True
    ]
    return event_type in enabled_event_types
# ======================= Block Helper Functions ======================= #

View File

@@ -4,6 +4,7 @@ from backend.blocks.ai_music_generator import AIMusicGeneratorBlock
from backend.blocks.ai_shortform_video_block import AIShortformVideoCreatorBlock
from backend.blocks.apollo.organization import SearchOrganizationsBlock
from backend.blocks.apollo.people import SearchPeopleBlock
from backend.blocks.apollo.person import GetPersonDetailBlock
from backend.blocks.flux_kontext import AIImageEditorBlock, FluxKontextModelName
from backend.blocks.ideogram import IdeogramModelBlock
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
@@ -362,7 +363,31 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
],
SearchPeopleBlock: [
BlockCost(
cost_amount=2,
cost_amount=10,
cost_filter={
"enrich_info": False,
"credentials": {
"id": apollo_credentials.id,
"provider": apollo_credentials.provider,
"type": apollo_credentials.type,
},
},
),
BlockCost(
cost_amount=20,
cost_filter={
"enrich_info": True,
"credentials": {
"id": apollo_credentials.id,
"provider": apollo_credentials.provider,
"type": apollo_credentials.type,
},
},
),
],
GetPersonDetailBlock: [
BlockCost(
cost_amount=1,
cost_filter={
"credentials": {
"id": apollo_credentials.id,

View File

@@ -7,7 +7,7 @@ from pydantic import BaseModel
from redis.asyncio.client import PubSub as AsyncPubSub
from redis.client import PubSub
from backend.data import redis
from backend.data import redis_client as redis
logger = logging.getLogger(__name__)

View File

@@ -32,7 +32,7 @@ from prisma.types import (
AgentNodeExecutionUpdateInput,
AgentNodeExecutionWhereInput,
)
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, JsonValue
from pydantic.fields import Field
from backend.server.v2.store.exceptions import DatabaseError
@@ -48,14 +48,14 @@ from .block import (
get_webhook_block_ids,
)
from .db import BaseDbModel
from .event_bus import AsyncRedisEventBus, RedisEventBus
from .includes import (
EXECUTION_RESULT_INCLUDE,
EXECUTION_RESULT_ORDER,
GRAPH_EXECUTION_INCLUDE_WITH_NODES,
graph_execution_include,
)
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
from .queue import AsyncRedisEventBus, RedisEventBus
from .model import GraphExecutionStats, NodeExecutionStats
T = TypeVar("T")
@@ -271,7 +271,7 @@ class GraphExecutionWithNodes(GraphExecution):
graph_id=self.graph_id,
graph_version=self.graph_version or 0,
graph_exec_id=self.id,
node_credentials_input_map={}, # FIXME
nodes_input_masks={}, # FIXME: store credentials on AgentGraphExecution
)
@@ -588,12 +588,10 @@ async def update_graph_execution_start_time(
async def update_graph_execution_stats(
graph_exec_id: str,
status: ExecutionStatus,
status: ExecutionStatus | None = None,
stats: GraphExecutionStats | None = None,
) -> GraphExecution | None:
update_data: AgentGraphExecutionUpdateManyMutationInput = {
"executionStatus": status
}
update_data: AgentGraphExecutionUpdateManyMutationInput = {}
if stats:
stats_dict = stats.model_dump()
@@ -601,6 +599,9 @@ async def update_graph_execution_stats(
stats_dict["error"] = str(stats_dict["error"])
update_data["stats"] = Json(stats_dict)
if status:
update_data["executionStatus"] = status
updated_count = await AgentGraphExecution.prisma().update_many(
where={
"id": graph_exec_id,
@@ -783,7 +784,7 @@ class GraphExecutionEntry(BaseModel):
graph_exec_id: str
graph_id: str
graph_version: int
node_credentials_input_map: Optional[dict[str, dict[str, CredentialsMetaInput]]]
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None
class NodeExecutionEntry(BaseModel):

View File

@@ -1,7 +1,7 @@
import logging
import uuid
from collections import defaultdict
from typing import Any, Literal, Optional, cast
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import prisma
from prisma import Json
@@ -14,7 +14,7 @@ from prisma.types import (
AgentNodeLinkCreateInput,
StoreListingVersionWhereInput,
)
from pydantic import create_model
from pydantic import JsonValue, create_model
from pydantic.fields import computed_field
from backend.blocks.agent import AgentExecutorBlock
@@ -27,12 +27,15 @@ from backend.data.model import (
CredentialsMetaInput,
is_credentials_field_name,
)
from backend.integrations.providers import ProviderName
from backend.util import type as type_utils
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
from .db import BaseDbModel, transaction
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
from .integrations import Webhook
if TYPE_CHECKING:
from .integrations import Webhook
logger = logging.getLogger(__name__)
@@ -81,10 +84,12 @@ class NodeModel(Node):
graph_version: int
webhook_id: Optional[str] = None
webhook: Optional[Webhook] = None
webhook: Optional["Webhook"] = None
@staticmethod
def from_db(node: AgentNode, for_export: bool = False) -> "NodeModel":
from .integrations import Webhook
obj = NodeModel(
id=node.id,
block_id=node.agentBlockId,
@@ -102,19 +107,7 @@ class NodeModel(Node):
return obj
def is_triggered_by_event_type(self, event_type: str) -> bool:
block = self.block
if not block.webhook_config:
raise TypeError("This method can't be used on non-webhook blocks")
if not block.webhook_config.event_filter_input:
return True
event_filter = self.input_default.get(block.webhook_config.event_filter_input)
if not event_filter:
raise ValueError(f"Event filter is not configured on node #{self.id}")
return event_type in [
block.webhook_config.event_format.format(event=k)
for k in event_filter
if event_filter[k] is True
]
return self.block.is_triggered_by_event_type(self.input_default, event_type)
def stripped_for_export(self) -> "NodeModel":
"""
@@ -162,10 +155,6 @@ class NodeModel(Node):
return result
# Fix 2-way reference Node <-> Webhook
Webhook.model_rebuild()
class BaseGraph(BaseDbModel):
version: int = 1
is_active: bool = True
@@ -255,6 +244,8 @@ class Graph(BaseGraph):
for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
if field.provider != other_field.provider:
continue
if ProviderName.HTTP in field.provider:
continue
# If this happens, that means a block implementation probably needs
# to be updated.
@@ -276,6 +267,7 @@ class Graph(BaseGraph):
required_scopes=set(field_info.required_scopes or []),
discriminator=field_info.discriminator,
discriminator_mapping=field_info.discriminator_mapping,
discriminator_values=field_info.discriminator_values,
),
)
for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
@@ -294,37 +286,40 @@ class Graph(BaseGraph):
Returns:
dict[aggregated_field_key, tuple(
CredentialsFieldInfo: A spec for one aggregated credentials field
(now includes discriminator_values from matching nodes)
set[(node_id, field_name)]: Node credentials fields that are
compatible with this aggregated field spec
)]
"""
return {
"_".join(sorted(agg_field_info.provider))
+ "_"
+ "_".join(sorted(agg_field_info.supported_types))
+ "_credentials": (agg_field_info, node_fields)
for agg_field_info, node_fields in CredentialsFieldInfo.combine(
*(
(
# Apply discrimination before aggregating credentials inputs
(
field_info.discriminate(
node.input_default[field_info.discriminator]
)
if (
field_info.discriminator
and node.input_default.get(field_info.discriminator)
)
else field_info
),
(node.id, field_name),
# First collect all credential field data with input defaults
node_credential_data = []
for graph in [self] + self.sub_graphs:
for node in graph.nodes:
for (
field_name,
field_info,
) in node.block.input_schema.get_credentials_fields_info().items():
discriminator = field_info.discriminator
if not discriminator:
node_credential_data.append((field_info, (node.id, field_name)))
continue
discriminator_value = node.input_default.get(discriminator)
if discriminator_value is None:
node_credential_data.append((field_info, (node.id, field_name)))
continue
discriminated_info = field_info.discriminate(discriminator_value)
discriminated_info.discriminator_values.add(discriminator_value)
node_credential_data.append(
(discriminated_info, (node.id, field_name))
)
for graph in [self] + self.sub_graphs
for node in graph.nodes
for field_name, field_info in node.block.input_schema.get_credentials_fields_info().items()
)
)
}
# Combine credential field info (this will merge discriminator_values automatically)
return CredentialsFieldInfo.combine(*node_credential_data)
class GraphModel(Graph):
@@ -403,16 +398,26 @@ class GraphModel(Graph):
continue
node.input_default["user_id"] = user_id
node.input_default.setdefault("inputs", {})
if (graph_id := node.input_default.get("graph_id")) in graph_id_map:
if (
graph_id := node.input_default.get("graph_id")
) and graph_id in graph_id_map:
node.input_default["graph_id"] = graph_id_map[graph_id]
def validate_graph(self, for_run: bool = False):
self._validate_graph(self, for_run)
def validate_graph(
self,
for_run: bool = False,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
):
self._validate_graph(self, for_run, nodes_input_masks)
for sub_graph in self.sub_graphs:
self._validate_graph(sub_graph, for_run)
self._validate_graph(sub_graph, for_run, nodes_input_masks)
@staticmethod
def _validate_graph(graph: BaseGraph, for_run: bool = False):
def _validate_graph(
graph: BaseGraph,
for_run: bool = False,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
):
def is_tool_pin(name: str) -> bool:
return name.startswith("tools_^_")
@@ -439,20 +444,18 @@ class GraphModel(Graph):
if (block := nodes_block.get(node.id)) is None:
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
node_input_mask = (
nodes_input_masks.get(node.id, {}) if nodes_input_masks else {}
)
provided_inputs = set(
[sanitize(name) for name in node.input_default]
+ [sanitize(link.sink_name) for link in input_links.get(node.id, [])]
+ ([name for name in node_input_mask] if node_input_mask else [])
)
InputSchema = block.input_schema
for name in (required_fields := InputSchema.get_required_fields()):
if (
name not in provided_inputs
# Webhook payload is passed in by ExecutionManager
and not (
name == "payload"
and block.block_type
in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
)
# Checking availability of credentials is done by ExecutionManager
and name not in InputSchema.get_credentials_fields()
# Validate only I/O nodes, or validate everything when executing
@@ -485,10 +488,18 @@ class GraphModel(Graph):
def has_value(node: Node, name: str):
return (
name in node.input_default
and node.input_default[name] is not None
and str(node.input_default[name]).strip() != ""
) or (name in input_fields and input_fields[name].default is not None)
(
name in node.input_default
and node.input_default[name] is not None
and str(node.input_default[name]).strip() != ""
)
or (name in input_fields and input_fields[name].default is not None)
or (
name in node_input_mask
and node_input_mask[name] is not None
and str(node_input_mask[name]).strip() != ""
)
)
# Validate dependencies between fields
for field_name in input_fields.keys():
@@ -574,7 +585,7 @@ class GraphModel(Graph):
graph: AgentGraph,
for_export: bool = False,
sub_graphs: list[AgentGraph] | None = None,
):
) -> "GraphModel":
return GraphModel(
id=graph.id,
user_id=graph.userId if not for_export else "",
@@ -603,6 +614,7 @@ class GraphModel(Graph):
async def get_node(node_id: str) -> NodeModel:
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
node = await AgentNode.prisma().find_unique_or_raise(
where={"id": node_id},
include=AGENT_NODE_INCLUDE,
@@ -611,6 +623,7 @@ async def get_node(node_id: str) -> NodeModel:
async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
node = await AgentNode.prisma().update(
where={"id": node_id},
data=(

View File

@@ -60,7 +60,8 @@ def graph_execution_include(
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
"AgentNodes": {"include": AGENT_NODE_INCLUDE}
"AgentNodes": {"include": AGENT_NODE_INCLUDE},
"AgentPresets": {"include": {"InputPresets": True}},
}

View File

@@ -1,21 +1,25 @@
import logging
from typing import TYPE_CHECKING, AsyncGenerator, Optional
from typing import AsyncGenerator, Literal, Optional, overload
from prisma import Json
from prisma.models import IntegrationWebhook
from prisma.types import IntegrationWebhookCreateInput
from prisma.types import (
IntegrationWebhookCreateInput,
IntegrationWebhookUpdateInput,
IntegrationWebhookWhereInput,
Serializable,
)
from pydantic import Field, computed_field
from backend.data.event_bus import AsyncRedisEventBus
from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
from backend.data.queue import AsyncRedisEventBus
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.utils import webhook_ingress_url
from backend.server.v2.library.model import LibraryAgentPreset
from backend.util.exceptions import NotFoundError
from .db import BaseDbModel
if TYPE_CHECKING:
from .graph import NodeModel
from .graph import NodeModel
logger = logging.getLogger(__name__)
@@ -32,8 +36,6 @@ class Webhook(BaseDbModel):
provider_webhook_id: str
attached_nodes: Optional[list["NodeModel"]] = None
@computed_field
@property
def url(self) -> str:
@@ -41,8 +43,6 @@ class Webhook(BaseDbModel):
@staticmethod
def from_db(webhook: IntegrationWebhook):
from .graph import NodeModel
return Webhook(
id=webhook.id,
user_id=webhook.userId,
@@ -54,11 +54,26 @@ class Webhook(BaseDbModel):
config=dict(webhook.config),
secret=webhook.secret,
provider_webhook_id=webhook.providerWebhookId,
attached_nodes=(
[NodeModel.from_db(node) for node in webhook.AgentNodes]
if webhook.AgentNodes is not None
else None
),
)
# A Webhook together with the nodes and presets loaded from its
# `AgentNodes` and `AgentPresets` relations.
class WebhookWithRelations(Webhook):
    triggered_nodes: list[NodeModel]
    triggered_presets: list[LibraryAgentPreset]

    @staticmethod
    def from_db(webhook: IntegrationWebhook):
        """
        Build a `WebhookWithRelations` from a DB row queried with its relations.

        Raises:
            ValueError: If `AgentNodes` or `AgentPresets` is `None` on the row.
        """
        db_nodes = webhook.AgentNodes
        db_presets = webhook.AgentPresets
        if db_nodes is None or db_presets is None:
            raise ValueError(
                "AgentNodes and AgentPresets must be included in "
                "IntegrationWebhook query with relations"
            )

        # Reuse the base conversion for all plain Webhook fields
        base_fields = Webhook.from_db(webhook).model_dump()
        return WebhookWithRelations(
            **base_fields,
            triggered_nodes=[NodeModel.from_db(db_node) for db_node in db_nodes],
            triggered_presets=[
                LibraryAgentPreset.from_db(db_preset) for db_preset in db_presets
            ],
        )
@@ -83,7 +98,19 @@ async def create_webhook(webhook: Webhook) -> Webhook:
return Webhook.from_db(created_webhook)
async def get_webhook(webhook_id: str) -> Webhook:
@overload
async def get_webhook(
webhook_id: str, *, include_relations: Literal[True]
) -> WebhookWithRelations: ...
@overload
async def get_webhook(
webhook_id: str, *, include_relations: Literal[False] = False
) -> Webhook: ...
async def get_webhook(
webhook_id: str, *, include_relations: bool = False
) -> Webhook | WebhookWithRelations:
"""
⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints.
@@ -92,73 +119,113 @@ async def get_webhook(webhook_id: str) -> Webhook:
"""
webhook = await IntegrationWebhook.prisma().find_unique(
where={"id": webhook_id},
include=INTEGRATION_WEBHOOK_INCLUDE,
include=INTEGRATION_WEBHOOK_INCLUDE if include_relations else None,
)
if not webhook:
raise NotFoundError(f"Webhook #{webhook_id} not found")
return Webhook.from_db(webhook)
return (WebhookWithRelations if include_relations else Webhook).from_db(webhook)
async def get_all_webhooks_by_creds(credentials_id: str) -> list[Webhook]:
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
@overload
async def get_all_webhooks_by_creds(
user_id: str, credentials_id: str, *, include_relations: Literal[True]
) -> list[WebhookWithRelations]: ...
@overload
async def get_all_webhooks_by_creds(
user_id: str, credentials_id: str, *, include_relations: Literal[False] = False
) -> list[Webhook]: ...
async def get_all_webhooks_by_creds(
user_id: str, credentials_id: str, *, include_relations: bool = False
) -> list[Webhook] | list[WebhookWithRelations]:
if not credentials_id:
raise ValueError("credentials_id must not be empty")
webhooks = await IntegrationWebhook.prisma().find_many(
where={"credentialsId": credentials_id},
include=INTEGRATION_WEBHOOK_INCLUDE,
where={"userId": user_id, "credentialsId": credentials_id},
include=INTEGRATION_WEBHOOK_INCLUDE if include_relations else None,
)
return [Webhook.from_db(webhook) for webhook in webhooks]
return [
(WebhookWithRelations if include_relations else Webhook).from_db(webhook)
for webhook in webhooks
]
async def find_webhook_by_credentials_and_props(
credentials_id: str, webhook_type: str, resource: str, events: list[str]
user_id: str,
credentials_id: str,
webhook_type: str,
resource: str,
events: list[str],
) -> Webhook | None:
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
webhook = await IntegrationWebhook.prisma().find_first(
where={
"userId": user_id,
"credentialsId": credentials_id,
"webhookType": webhook_type,
"resource": resource,
"events": {"has_every": events},
},
include=INTEGRATION_WEBHOOK_INCLUDE,
)
return Webhook.from_db(webhook) if webhook else None
async def find_webhook_by_graph_and_props(
graph_id: str, provider: str, webhook_type: str, events: list[str]
user_id: str,
provider: str,
webhook_type: str,
graph_id: Optional[str] = None,
preset_id: Optional[str] = None,
) -> Webhook | None:
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
"""Either `graph_id` or `preset_id` must be provided."""
where_clause: IntegrationWebhookWhereInput = {
"userId": user_id,
"provider": provider,
"webhookType": webhook_type,
}
if preset_id:
where_clause["AgentPresets"] = {"some": {"id": preset_id}}
elif graph_id:
where_clause["AgentNodes"] = {"some": {"agentGraphId": graph_id}}
else:
raise ValueError("Either graph_id or preset_id must be provided")
webhook = await IntegrationWebhook.prisma().find_first(
where={
"provider": provider,
"webhookType": webhook_type,
"events": {"has_every": events},
"AgentNodes": {"some": {"agentGraphId": graph_id}},
},
include=INTEGRATION_WEBHOOK_INCLUDE,
where=where_clause,
)
return Webhook.from_db(webhook) if webhook else None
async def update_webhook_config(webhook_id: str, updated_config: dict) -> Webhook:
async def update_webhook(
webhook_id: str,
config: Optional[dict[str, Serializable]] = None,
events: Optional[list[str]] = None,
) -> Webhook:
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
data: IntegrationWebhookUpdateInput = {}
if config is not None:
data["config"] = Json(config)
if events is not None:
data["events"] = events
if not data:
raise ValueError("Empty update query")
_updated_webhook = await IntegrationWebhook.prisma().update(
where={"id": webhook_id},
data={"config": Json(updated_config)},
include=INTEGRATION_WEBHOOK_INCLUDE,
data=data,
)
if _updated_webhook is None:
raise ValueError(f"Webhook #{webhook_id} not found")
raise NotFoundError(f"Webhook #{webhook_id} not found")
return Webhook.from_db(_updated_webhook)
async def delete_webhook(webhook_id: str) -> None:
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
deleted = await IntegrationWebhook.prisma().delete(where={"id": webhook_id})
if not deleted:
raise ValueError(f"Webhook #{webhook_id} not found")
async def delete_webhook(user_id: str, webhook_id: str) -> None:
deleted = await IntegrationWebhook.prisma().delete_many(
where={"id": webhook_id, "userId": user_id}
)
if deleted < 1:
raise NotFoundError(f"Webhook #{webhook_id} not found")
# --------------------- WEBHOOK EVENTS --------------------- #

View File

@@ -14,11 +14,12 @@ from typing import (
Generic,
Literal,
Optional,
Sequence,
TypedDict,
TypeVar,
cast,
get_args,
)
from urllib.parse import urlparse
from uuid import uuid4
from prisma.enums import CreditTransactionType
@@ -240,13 +241,65 @@ class UserPasswordCredentials(_BaseCredentials):
return f"Basic {base64.b64encode(f'{self.username.get_secret_value()}:{self.password.get_secret_value()}'.encode()).decode()}"
# Credentials consisting of secret HTTP headers that are only applied to
# requests whose URL host matches `host` (exact host or "*.domain" wildcard).
# NOTE: documented with comments rather than a class docstring so that the
# generated model schema is left untouched.
class HostScopedCredentials(_BaseCredentials):
    type: Literal["host_scoped"] = "host_scoped"
    host: str = Field(description="The host/URI pattern to match against request URLs")
    headers: dict[str, SecretStr] = Field(
        description="Key-value header map to add to matching requests",
        default_factory=dict,
    )

    def _extract_headers(self, headers: dict[str, SecretStr]) -> dict[str, str]:
        """Helper to extract secret values from headers."""
        return {key: value.get_secret_value() for key, value in headers.items()}

    @field_serializer("headers")
    def serialize_headers(self, headers: dict[str, SecretStr]) -> dict[str, str]:
        """Serialize headers by extracting secret values."""
        return self._extract_headers(headers)

    def get_headers_dict(self) -> dict[str, str]:
        """Get headers with secret values extracted."""
        return self._extract_headers(self.headers)

    def auth_header(self) -> str:
        """
        Get authorization header for backward compatibility.

        Returns the value stored under the "Authorization" key,
        or an empty string if there is no such header.
        """
        auth_headers = self.get_headers_dict()
        if "Authorization" in auth_headers:
            return auth_headers["Authorization"]
        return ""

    def matches_url(self, url: str) -> bool:
        """
        Check if this credential should be applied to the given URL.

        The URL's hostname (without port) is compared case-insensitively to
        `self.host`, which is either an exact hostname or a wildcard pattern
        such as "*.example.com". A wildcard pattern matches any subdomain as
        well as the bare domain itself.

        Returns:
            True if the URL's host matches; False if it does not match, the URL
            has no hostname, or the URL cannot be parsed.
        """
        try:
            # Extract hostname without port
            request_host = urlparse(url).hostname
        except ValueError:
            # urlparse rejects some malformed URLs (e.g. unbalanced IPv6
            # brackets); an unparseable URL simply does not match.
            return False
        if not request_host:
            return False

        # `hostname` is always returned lower-cased, so the configured pattern
        # must be normalized the same way or mixed-case patterns never match.
        host_pattern = self.host.strip().lower()

        # Simple host matching - exact match or wildcard subdomain match
        if host_pattern == request_host:
            return True

        # Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
        if host_pattern.startswith("*."):
            domain = host_pattern[2:]  # Remove "*."
            return request_host.endswith(f".{domain}") or request_host == domain

        return False
Credentials = Annotated[
OAuth2Credentials | APIKeyCredentials | UserPasswordCredentials,
OAuth2Credentials
| APIKeyCredentials
| UserPasswordCredentials
| HostScopedCredentials,
Field(discriminator="type"),
]
CredentialsType = Literal["api_key", "oauth2", "user_password"]
CredentialsType = Literal["api_key", "oauth2", "user_password", "host_scoped"]
class OAuthState(BaseModel):
@@ -320,15 +373,29 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
)
@staticmethod
def _add_json_schema_extra(schema, cls: CredentialsMetaInput):
schema["credentials_provider"] = cls.allowed_providers()
schema["credentials_types"] = cls.allowed_cred_types()
def _add_json_schema_extra(schema: dict, model_class: type):
# Use model_class for allowed_providers/cred_types
if hasattr(model_class, "allowed_providers") and hasattr(
model_class, "allowed_cred_types"
):
schema["credentials_provider"] = model_class.allowed_providers()
schema["credentials_types"] = model_class.allowed_cred_types()
# Do not return anything, just mutate schema in place
model_config = ConfigDict(
json_schema_extra=_add_json_schema_extra, # type: ignore
)
def _extract_host_from_url(url: str) -> str:
"""Extract host from URL for grouping host-scoped credentials."""
try:
parsed = urlparse(url)
return parsed.hostname or url
except Exception:
return ""
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
# TODO: move discrimination mechanism out of CredentialsField (frontend + backend)
provider: frozenset[CP] = Field(..., alias="credentials_provider")
@@ -336,11 +403,12 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
required_scopes: Optional[frozenset[str]] = Field(None, alias="credentials_scopes")
discriminator: Optional[str] = None
discriminator_mapping: Optional[dict[str, CP]] = None
discriminator_values: set[Any] = Field(default_factory=set)
@classmethod
def combine(
cls, *fields: tuple[CredentialsFieldInfo[CP, CT], T]
) -> Sequence[tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
) -> dict[str, tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
"""
Combines multiple CredentialsFieldInfo objects into as few as possible.
@@ -358,22 +426,36 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
the set of keys of the respective original items that were grouped together.
"""
if not fields:
return []
return {}
# Group fields by their provider and supported_types
# For HTTP host-scoped credentials, also group by host
grouped_fields: defaultdict[
tuple[frozenset[CP], frozenset[CT]],
list[tuple[T, CredentialsFieldInfo[CP, CT]]],
] = defaultdict(list)
for field, key in fields:
group_key = (frozenset(field.provider), frozenset(field.supported_types))
if field.provider == frozenset([ProviderName.HTTP]):
# HTTP host-scoped credentials can target different hosts that require different credential sets.
# Group by host extracted from the URL
providers = frozenset(
[cast(CP, "http")]
+ [
cast(CP, _extract_host_from_url(str(value)))
for value in field.discriminator_values
]
)
else:
providers = frozenset(field.provider)
group_key = (providers, frozenset(field.supported_types))
grouped_fields[group_key].append((key, field))
# Combine fields within each group
result: list[tuple[CredentialsFieldInfo[CP, CT], set[T]]] = []
result: dict[str, tuple[CredentialsFieldInfo[CP, CT], set[T]]] = {}
for group in grouped_fields.values():
for key, group in grouped_fields.items():
# Start with the first field in the group
_, combined = group[0]
@@ -386,18 +468,32 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
if field.required_scopes:
all_scopes.update(field.required_scopes)
# Create a new combined field
result.append(
(
CredentialsFieldInfo[CP, CT](
credentials_provider=combined.provider,
credentials_types=combined.supported_types,
credentials_scopes=frozenset(all_scopes) or None,
discriminator=combined.discriminator,
discriminator_mapping=combined.discriminator_mapping,
),
combined_keys,
)
# Combine discriminator_values from all fields in the group (removing duplicates)
all_discriminator_values = []
for _, field in group:
for value in field.discriminator_values:
if value not in all_discriminator_values:
all_discriminator_values.append(value)
# Generate the key for the combined result
providers_key, supported_types_key = key
group_key = (
"-".join(sorted(providers_key))
+ "_"
+ "-".join(sorted(supported_types_key))
+ "_credentials"
)
result[group_key] = (
CredentialsFieldInfo[CP, CT](
credentials_provider=combined.provider,
credentials_types=combined.supported_types,
credentials_scopes=frozenset(all_scopes) or None,
discriminator=combined.discriminator,
discriminator_mapping=combined.discriminator_mapping,
discriminator_values=set(all_discriminator_values),
),
combined_keys,
)
return result
@@ -406,11 +502,15 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
if not (self.discriminator and self.discriminator_mapping):
return self
discriminator_value = self.discriminator_mapping[discriminator_value]
return CredentialsFieldInfo(
credentials_provider=frozenset([discriminator_value]),
credentials_provider=frozenset(
[self.discriminator_mapping[discriminator_value]]
),
credentials_types=self.supported_types,
credentials_scopes=self.required_scopes,
discriminator=self.discriminator,
discriminator_mapping=self.discriminator_mapping,
discriminator_values=self.discriminator_values,
)
@@ -419,6 +519,7 @@ def CredentialsField(
*,
discriminator: Optional[str] = None,
discriminator_mapping: Optional[dict[str, Any]] = None,
discriminator_values: Optional[set[Any]] = None,
title: Optional[str] = None,
description: Optional[str] = None,
**kwargs,
@@ -434,6 +535,7 @@ def CredentialsField(
"credentials_scopes": list(required_scopes) or None,
"discriminator": discriminator,
"discriminator_mapping": discriminator_mapping,
"discriminator_values": discriminator_values,
}.items()
if v is not None
}

View File

@@ -0,0 +1,143 @@
import pytest
from pydantic import SecretStr
from backend.data.model import HostScopedCredentials
class TestHostScopedCredentials:
    """Tests for HostScopedCredentials.

    Covers construction, plain-text header extraction, host/URL matching
    (exact hosts, wildcard hosts, ports) and a dump/validate round trip.
    """

    def test_host_scoped_credentials_creation(self):
        """Fields passed to the constructor are readable on the instance."""
        secret_headers = {
            "Authorization": SecretStr("Bearer secret-token"),
            "X-API-Key": SecretStr("api-key-123"),
        }
        credentials = HostScopedCredentials(
            provider="custom",
            host="api.example.com",
            headers=secret_headers,
            title="Example API Credentials",
        )

        assert credentials.type == "host_scoped"
        assert credentials.provider == "custom"
        assert credentials.host == "api.example.com"
        assert credentials.title == "Example API Credentials"
        assert len(credentials.headers) == 2
        for header_name in ("Authorization", "X-API-Key"):
            assert header_name in credentials.headers

    def test_get_headers_dict(self):
        """get_headers_dict() is expected to return the unwrapped header values."""
        credentials = HostScopedCredentials(
            provider="custom",
            host="api.example.com",
            headers={
                "Authorization": SecretStr("Bearer secret-token"),
                "X-Custom-Header": SecretStr("custom-value"),
            },
        )

        expected_headers = {
            "Authorization": "Bearer secret-token",
            "X-Custom-Header": "custom-value",
        }
        assert credentials.get_headers_dict() == expected_headers

    def test_matches_url_exact_host(self):
        """A literal host matches only URLs on exactly that host."""
        credentials = HostScopedCredentials(
            provider="custom",
            host="api.example.com",
            headers={"Authorization": SecretStr("Bearer token")},
        )

        for matching_url in (
            "https://api.example.com/v1/data",
            "http://api.example.com/endpoint",
        ):
            assert credentials.matches_url(matching_url)

        for other_url in (
            "https://other.example.com/v1/data",
            "https://subdomain.api.example.com/v1/data",
        ):
            assert not credentials.matches_url(other_url)

    def test_matches_url_wildcard_subdomain(self):
        """A '*.domain' host matches the domain and any depth of subdomain."""
        credentials = HostScopedCredentials(
            provider="custom",
            host="*.example.com",
            headers={"Authorization": SecretStr("Bearer token")},
        )

        for matching_url in (
            "https://api.example.com/v1/data",
            "https://subdomain.example.com/endpoint",
            "https://deep.nested.example.com/path",
            "https://example.com/path",  # Base domain should match
        ):
            assert credentials.matches_url(matching_url)

        for other_url in (
            "https://example.org/v1/data",
            "https://notexample.com/v1/data",
        ):
            assert not credentials.matches_url(other_url)

    def test_matches_url_with_port_and_path(self):
        """An explicit port or a path in the URL does not prevent a host match."""
        credentials = HostScopedCredentials(
            provider="custom",
            host="localhost",
            headers={"Authorization": SecretStr("Bearer token")},
        )

        for matching_url in (
            "http://localhost:8080/api/v1",
            "https://localhost:443/secure/endpoint",
            "http://localhost/simple",
        ):
            assert credentials.matches_url(matching_url)

    def test_empty_headers_dict(self):
        """Credentials without headers still construct and still match by host."""
        credentials = HostScopedCredentials(
            provider="custom", host="api.example.com", headers={}
        )

        assert credentials.get_headers_dict() == {}
        assert credentials.matches_url("https://api.example.com/test")

    def test_credential_serialization(self):
        """model_dump() followed by model_validate() reproduces the credentials."""
        original = HostScopedCredentials(
            provider="custom",
            host="api.example.com",
            headers={
                "Authorization": SecretStr("Bearer secret-token"),
                "X-API-Key": SecretStr("api-key-123"),
            },
            title="Test Credentials",
        )

        # Round trip through a plain dict (simulating storage)
        dumped = original.model_dump()
        restored = HostScopedCredentials.model_validate(dumped)

        for attribute in ("id", "provider", "host", "title"):
            assert getattr(restored, attribute) == getattr(original, attribute)
        assert restored.type == "host_scoped"

        # Header secrets must survive the round trip too
        assert restored.get_headers_dict() == original.get_headers_dict()

    @pytest.mark.parametrize(
        "host,test_url,expected",
        [
            ("api.example.com", "https://api.example.com/test", True),
            ("api.example.com", "https://different.example.com/test", False),
            ("*.example.com", "https://api.example.com/test", True),
            ("*.example.com", "https://sub.api.example.com/test", True),
            ("*.example.com", "https://example.com/test", True),
            ("*.example.com", "https://example.org/test", False),
            ("localhost", "http://localhost:3000/test", True),
            ("localhost", "http://127.0.0.1:3000/test", False),
        ],
    )
    def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool):
        """Table-driven check of matches_url() for each (host pattern, URL) pair."""
        credentials = HostScopedCredentials(
            provider="test",
            host=host,
            headers={"Authorization": SecretStr("Bearer token")},
        )

        actual = credentials.matches_url(test_url)
        assert actual == expected

View File

@@ -12,14 +12,11 @@ from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic, BasicProperties
from pydantic import JsonValue
from redis.asyncio.lock import Lock as RedisLock
from backend.blocks.io import AgentOutputBlock
from backend.data.model import (
CredentialsMetaInput,
GraphExecutionStats,
NodeExecutionStats,
)
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
AgentRunData,
LowBalanceData,
@@ -38,7 +35,7 @@ from autogpt_libs.utils.cache import thread_cached
from prometheus_client import Gauge, start_http_server
from backend.blocks.agent import AgentExecutorBlock
from backend.data import redis
from backend.data import redis_client as redis
from backend.data.block import (
BlockData,
BlockInput,
@@ -138,9 +135,7 @@ async def execute_node(
creds_manager: IntegrationCredentialsManager,
data: NodeExecutionEntry,
execution_stats: NodeExecutionStats | None = None,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
) -> BlockOutput:
"""
Execute a node in the graph. This will trigger a block execution on a node,
@@ -183,8 +178,8 @@ async def execute_node(
if isinstance(node_block, AgentExecutorBlock):
_input_data = AgentExecutorBlock.Input(**node.input_default)
_input_data.inputs = input_data
if node_credentials_input_map:
_input_data.node_credentials_input_map = node_credentials_input_map
if nodes_input_masks:
_input_data.nodes_input_masks = nodes_input_masks
input_data = _input_data.model_dump()
data.inputs = input_data
@@ -255,7 +250,7 @@ async def _enqueue_next_nodes(
graph_exec_id: str,
graph_id: str,
log_metadata: LogMetadata,
node_credentials_input_map: Optional[dict[str, dict[str, CredentialsMetaInput]]],
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
) -> list[NodeExecutionEntry]:
async def add_enqueued_execution(
node_exec_id: str, node_id: str, block_id: str, data: BlockInput
@@ -326,14 +321,12 @@ async def _enqueue_next_nodes(
for name in static_link_names:
next_node_input[name] = latest_execution.input_data.get(name)
# Apply node credentials overrides
node_credentials = None
if node_credentials_input_map and (
node_credentials := node_credentials_input_map.get(next_node.id)
# Apply node input overrides
node_input_mask = None
if nodes_input_masks and (
node_input_mask := nodes_input_masks.get(next_node.id)
):
next_node_input.update(
{k: v.model_dump() for k, v in node_credentials.items()}
)
next_node_input.update(node_input_mask)
# Validate the input data for the next node.
next_node_input, validation_msg = validate_exec(next_node, next_node_input)
@@ -377,11 +370,9 @@ async def _enqueue_next_nodes(
for input_name in static_link_names:
idata[input_name] = next_node_input[input_name]
# Apply node credentials overrides
if node_credentials:
idata.update(
{k: v.model_dump() for k, v in node_credentials.items()}
)
# Apply node input overrides
if node_input_mask:
idata.update(node_input_mask)
idata, msg = validate_exec(next_node, idata)
suffix = f"{next_output_name}>{next_input_name}~{ineid}:{msg}"
@@ -430,14 +421,12 @@ class Executor:
"""
@classmethod
@async_error_logged
@async_error_logged(swallow=True)
async def on_node_execution(
cls,
node_exec: NodeExecutionEntry,
node_exec_progress: NodeExecutionProgress,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
) -> NodeExecutionStats:
log_metadata = LogMetadata(
user_id=node_exec.user_id,
@@ -458,7 +447,7 @@ class Executor:
db_client=db_client,
log_metadata=log_metadata,
stats=execution_stats,
node_credentials_input_map=node_credentials_input_map,
nodes_input_masks=nodes_input_masks,
)
execution_stats.walltime = timing_info.wall_time
execution_stats.cputime = timing_info.cpu_time
@@ -481,9 +470,7 @@ class Executor:
db_client: "DatabaseManagerAsyncClient",
log_metadata: LogMetadata,
stats: NodeExecutionStats | None = None,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
):
try:
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
@@ -498,7 +485,7 @@ class Executor:
creds_manager=cls.creds_manager,
data=node_exec,
execution_stats=stats,
node_credentials_input_map=node_credentials_input_map,
nodes_input_masks=nodes_input_masks,
):
node_exec_progress.add_output(
ExecutionOutputEntry(
@@ -542,7 +529,7 @@ class Executor:
logger.info(f"[GraphExecutor] {cls.pid} started")
@classmethod
@error_logged
@error_logged(swallow=False)
def on_graph_execution(
cls, graph_exec: GraphExecutionEntry, cancel: threading.Event
):
@@ -594,6 +581,15 @@ class Executor:
exec_stats.cputime += timing_info.cpu_time
exec_stats.error = str(error) if error else exec_stats.error
if status not in {
ExecutionStatus.COMPLETED,
ExecutionStatus.TERMINATED,
ExecutionStatus.FAILED,
}:
raise RuntimeError(
f"Graph Execution #{graph_exec.graph_exec_id} ended with unexpected status {status}"
)
if graph_exec_result := db_client.update_graph_execution_stats(
graph_exec_id=graph_exec.graph_exec_id,
status=status,
@@ -697,7 +693,6 @@ class Executor:
if _graph_exec := db_client.update_graph_execution_stats(
graph_exec_id=graph_exec.graph_exec_id,
status=execution_status,
stats=execution_stats,
):
send_execution_update(_graph_exec)
@@ -779,24 +774,19 @@ class Executor:
)
raise
# Add credential overrides -----------------------------
# Add input overrides -----------------------------
node_id = queued_node_exec.node_id
if (node_creds_map := graph_exec.node_credentials_input_map) and (
node_field_creds_map := node_creds_map.get(node_id)
if (nodes_input_masks := graph_exec.nodes_input_masks) and (
node_input_mask := nodes_input_masks.get(node_id)
):
queued_node_exec.inputs.update(
{
field_name: creds_meta.model_dump()
for field_name, creds_meta in node_field_creds_map.items()
}
)
queued_node_exec.inputs.update(node_input_mask)
# Kick off async node execution -------------------------
node_execution_task = asyncio.run_coroutine_threadsafe(
cls.on_node_execution(
node_exec=queued_node_exec,
node_exec_progress=running_node_execution[node_id],
node_credentials_input_map=node_creds_map,
nodes_input_masks=nodes_input_masks,
),
cls.node_execution_loop,
)
@@ -840,7 +830,7 @@ class Executor:
node_id=node_id,
graph_exec=graph_exec,
log_metadata=log_metadata,
node_creds_map=node_creds_map,
nodes_input_masks=nodes_input_masks,
execution_queue=execution_queue,
),
cls.node_evaluation_loop,
@@ -910,7 +900,7 @@ class Executor:
node_id: str,
graph_exec: GraphExecutionEntry,
log_metadata: LogMetadata,
node_creds_map: Optional[dict[str, dict[str, CredentialsMetaInput]]],
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
execution_queue: ExecutionQueue[NodeExecutionEntry],
) -> None:
"""Process a node's output, update its status, and enqueue next nodes.
@@ -920,7 +910,7 @@ class Executor:
node_id: The ID of the node that produced the output
graph_exec: The graph execution entry
log_metadata: Logger metadata for consistent logging
node_creds_map: Optional map of node credentials
nodes_input_masks: Optional map of node input overrides
execution_queue: Queue to add next executions to
"""
db_client = get_db_async_client()
@@ -944,7 +934,7 @@ class Executor:
graph_exec_id=graph_exec.graph_exec_id,
graph_id=graph_exec.graph_id,
log_metadata=log_metadata,
node_credentials_input_map=node_creds_map,
nodes_input_masks=nodes_input_masks,
):
execution_queue.add(next_execution)
except Exception as e:

View File

@@ -366,14 +366,13 @@ async def test_execute_preset(server: SpinTestServer):
"dictionary": {"key1": "Hello", "key2": "World"},
"selected_value": "key2",
},
credentials={},
is_active=True,
)
created_preset = await server.agent_server.test_create_preset(preset, test_user.id)
# Execute preset with overriding values
result = await server.agent_server.test_execute_preset(
graph_id=test_graph.id,
graph_version=test_graph.version,
preset_id=created_preset.id,
user_id=test_user.id,
)
@@ -455,16 +454,15 @@ async def test_execute_preset_with_clash(server: SpinTestServer):
"dictionary": {"key1": "Hello", "key2": "World"},
"selected_value": "key2",
},
credentials={},
is_active=True,
)
created_preset = await server.agent_server.test_create_preset(preset, test_user.id)
# Execute preset with overriding values
result = await server.agent_server.test_execute_preset(
graph_id=test_graph.id,
graph_version=test_graph.version,
preset_id=created_preset.id,
node_input={"selected_value": "key1"},
inputs={"selected_value": "key1"},
user_id=test_user.id,
)

View File

@@ -3,6 +3,7 @@ import logging
import os
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Optional
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from apscheduler.events import EVENT_JOB_ERROR, EVENT_JOB_EXECUTED
@@ -14,13 +15,16 @@ from apscheduler.triggers.cron import CronTrigger
from autogpt_libs.utils.cache import thread_cached
from dotenv import load_dotenv
from prisma.enums import NotificationType
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import MetaData, create_engine
from backend.data.block import BlockInput
from backend.data.execution import ExecutionStatus
from backend.data.model import CredentialsMetaInput
from backend.executor import utils as execution_utils
from backend.notifications.notifications import NotificationManagerClient
from backend.util.exceptions import NotAuthorizedError, NotFoundError
from backend.util.logging import PrefixFilter
from backend.util.metrics import sentry_capture_error
from backend.util.service import (
AppService,
@@ -52,19 +56,19 @@ def _extract_schema_from_url(database_url) -> tuple[str, str]:
logger = logging.getLogger(__name__)
logger.addFilter(PrefixFilter("[Scheduler]"))
apscheduler_logger = logger.getChild("apscheduler")
apscheduler_logger.addFilter(PrefixFilter("[Scheduler] [APScheduler]"))
config = Config()
def log(msg, **kwargs):
logger.info("[Scheduler] " + msg, **kwargs)
def job_listener(event):
"""Logs job execution outcomes for better monitoring."""
if event.exception:
log(f"Job {event.job_id} failed.")
logger.error(f"Job {event.job_id} failed.")
else:
log(f"Job {event.job_id} completed successfully.")
logger.info(f"Job {event.job_id} completed successfully.")
@thread_cached
@@ -84,16 +88,17 @@ def execute_graph(**kwargs):
async def _execute_graph(**kwargs):
args = GraphExecutionJobArgs(**kwargs)
try:
log(f"Executing recurring job for graph #{args.graph_id}")
logger.info(f"Executing recurring job for graph #{args.graph_id}")
await execution_utils.add_graph_execution(
graph_id=args.graph_id,
inputs=args.input_data,
user_id=args.user_id,
graph_id=args.graph_id,
graph_version=args.graph_version,
inputs=args.input_data,
graph_credentials_inputs=args.input_credentials,
use_db_query=False,
)
except Exception as e:
logger.exception(f"Error executing graph {args.graph_id}: {e}")
logger.error(f"Error executing graph {args.graph_id}: {e}")
class LateExecutionException(Exception):
@@ -137,20 +142,20 @@ def report_late_executions() -> str:
def process_existing_batches(**kwargs):
args = NotificationJobArgs(**kwargs)
try:
log(
logger.info(
f"Processing existing batches for notification type {args.notification_types}"
)
get_notification_client().process_existing_batches(args.notification_types)
except Exception as e:
logger.exception(f"Error processing existing batches: {e}")
logger.error(f"Error processing existing batches: {e}")
def process_weekly_summary(**kwargs):
try:
log("Processing weekly summary")
logger.info("Processing weekly summary")
get_notification_client().queue_weekly_summary()
except Exception as e:
logger.exception(f"Error processing weekly summary: {e}")
logger.error(f"Error processing weekly summary: {e}")
class Jobstores(Enum):
@@ -160,11 +165,12 @@ class Jobstores(Enum):
class GraphExecutionJobArgs(BaseModel):
graph_id: str
input_data: BlockInput
user_id: str
graph_id: str
graph_version: int
cron: str
input_data: BlockInput
input_credentials: dict[str, CredentialsMetaInput] = Field(default_factory=dict)
class GraphExecutionJobInfo(GraphExecutionJobArgs):
@@ -247,7 +253,8 @@ class Scheduler(AppService):
),
# These don't really need persistence
Jobstores.WEEKLY_NOTIFICATIONS.value: MemoryJobStore(),
}
},
logger=apscheduler_logger,
)
if self.register_system_tasks:
@@ -285,34 +292,40 @@ class Scheduler(AppService):
def cleanup(self):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Shutting down scheduler...")
logger.info("⏳ Shutting down scheduler...")
if self.scheduler:
self.scheduler.shutdown(wait=False)
@expose
def add_graph_execution_schedule(
self,
user_id: str,
graph_id: str,
graph_version: int,
cron: str,
input_data: BlockInput,
user_id: str,
input_credentials: dict[str, CredentialsMetaInput],
name: Optional[str] = None,
) -> GraphExecutionJobInfo:
job_args = GraphExecutionJobArgs(
graph_id=graph_id,
input_data=input_data,
user_id=user_id,
graph_id=graph_id,
graph_version=graph_version,
cron=cron,
input_data=input_data,
input_credentials=input_credentials,
)
job = self.scheduler.add_job(
execute_graph,
CronTrigger.from_crontab(cron),
kwargs=job_args.model_dump(),
replace_existing=True,
name=name,
trigger=CronTrigger.from_crontab(cron),
jobstore=Jobstores.EXECUTION.value,
replace_existing=True,
)
logger.info(
f"Added job {job.id} with cron schedule '{cron}' input data: {input_data}"
)
log(f"Added job {job.id} with cron schedule '{cron}' input data: {input_data}")
return GraphExecutionJobInfo.from_db(job_args, job)
@expose
@@ -321,14 +334,13 @@ class Scheduler(AppService):
) -> GraphExecutionJobInfo:
job = self.scheduler.get_job(schedule_id, jobstore=Jobstores.EXECUTION.value)
if not job:
log(f"Job {schedule_id} not found.")
raise ValueError(f"Job #{schedule_id} not found.")
raise NotFoundError(f"Job #{schedule_id} not found.")
job_args = GraphExecutionJobArgs(**job.kwargs)
if job_args.user_id != user_id:
raise ValueError("User ID does not match the job's user ID.")
raise NotAuthorizedError("User ID does not match the job's user ID")
log(f"Deleting job {schedule_id}")
logger.info(f"Deleting job {schedule_id}")
job.remove()
return GraphExecutionJobInfo.from_db(job_args, job)

View File

@@ -27,6 +27,7 @@ async def test_agent_schedule(server: SpinTestServer):
graph_version=1,
cron="0 0 * * *",
input_data={"input": "data"},
input_credentials={},
)
assert schedule

View File

@@ -5,7 +5,7 @@ from concurrent.futures import Future
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
from autogpt_libs.utils.cache import thread_cached
from pydantic import BaseModel
from pydantic import BaseModel, JsonValue
from backend.data.block import (
Block,
@@ -435,9 +435,7 @@ def validate_exec(
async def _validate_node_input_credentials(
graph: GraphModel,
user_id: str,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
):
"""Checks all credentials for all nodes of the graph"""
@@ -453,11 +451,13 @@ async def _validate_node_input_credentials(
for field_name, credentials_meta_type in credentials_fields.items():
if (
node_credentials_input_map
and (node_credentials_inputs := node_credentials_input_map.get(node.id))
and field_name in node_credentials_inputs
nodes_input_masks
and (node_input_mask := nodes_input_masks.get(node.id))
and field_name in node_input_mask
):
credentials_meta = node_credentials_input_map[node.id][field_name]
credentials_meta = credentials_meta_type.model_validate(
node_input_mask[field_name]
)
elif field_name in node.input_default:
credentials_meta = credentials_meta_type.model_validate(
node.input_default[field_name]
@@ -496,7 +496,7 @@ async def _validate_node_input_credentials(
def make_node_credentials_input_map(
graph: GraphModel,
graph_credentials_input: dict[str, CredentialsMetaInput],
) -> dict[str, dict[str, CredentialsMetaInput]]:
) -> dict[str, dict[str, JsonValue]]:
"""
Maps credentials for an execution to the correct nodes.
@@ -505,9 +505,9 @@ def make_node_credentials_input_map(
graph_credentials_input: A (graph_input_name, credentials_meta) map.
Returns:
dict[node_id, dict[field_name, CredentialsMetaInput]]: Node credentials input map.
dict[node_id, dict[field_name, CredentialsMetaRaw]]: Node credentials input map.
"""
result: dict[str, dict[str, CredentialsMetaInput]] = {}
result: dict[str, dict[str, JsonValue]] = {}
# Get aggregated credentials fields for the graph
graph_cred_inputs = graph.aggregate_credentials_inputs()
@@ -521,7 +521,9 @@ def make_node_credentials_input_map(
for node_id, node_field_name in compatible_node_fields:
if node_id not in result:
result[node_id] = {}
result[node_id][node_field_name] = graph_credentials_input[graph_input_name]
result[node_id][node_field_name] = graph_credentials_input[
graph_input_name
].model_dump(exclude_none=True)
return result
@@ -530,9 +532,7 @@ async def construct_node_execution_input(
graph: GraphModel,
user_id: str,
graph_inputs: BlockInput,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
) -> list[tuple[str, BlockInput]]:
"""
Validates and prepares the input data for executing a graph.
@@ -550,8 +550,8 @@ async def construct_node_execution_input(
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
the corresponding input data for that node.
"""
graph.validate_graph(for_run=True)
await _validate_node_input_credentials(graph, user_id, node_credentials_input_map)
graph.validate_graph(for_run=True, nodes_input_masks=nodes_input_masks)
await _validate_node_input_credentials(graph, user_id, nodes_input_masks)
nodes_input = []
for node in graph.starting_nodes:
@@ -568,23 +568,9 @@ async def construct_node_execution_input(
if input_name and input_name in graph_inputs:
input_data = {"value": graph_inputs[input_name]}
# Extract webhook payload, and assign it to the input pin
webhook_payload_key = f"webhook_{node.webhook_id}_payload"
if (
block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
and node.webhook_id
):
if webhook_payload_key not in graph_inputs:
raise ValueError(
f"Node {block.name} #{node.id} webhook payload is missing"
)
input_data = {"payload": graph_inputs[webhook_payload_key]}
# Apply node credentials overrides
if node_credentials_input_map and (
node_credentials := node_credentials_input_map.get(node.id)
):
input_data.update({k: v.model_dump() for k, v in node_credentials.items()})
# Apply node input overrides
if nodes_input_masks and (node_input_mask := nodes_input_masks.get(node.id)):
input_data.update(node_input_mask)
input_data, error = validate_exec(node, input_data)
if input_data is None:
@@ -600,6 +586,20 @@ async def construct_node_execution_input(
return nodes_input
def _merge_nodes_input_masks(
overrides_map_1: dict[str, dict[str, JsonValue]],
overrides_map_2: dict[str, dict[str, JsonValue]],
) -> dict[str, dict[str, JsonValue]]:
"""Perform a per-node merge of input overrides"""
result = overrides_map_1.copy()
for node_id, overrides2 in overrides_map_2.items():
if node_id in result:
result[node_id] = {**result[node_id], **overrides2}
else:
result[node_id] = overrides2
return result
# ============ Execution Queue Helpers ============ #
@@ -730,13 +730,11 @@ async def stop_graph_execution(
async def add_graph_execution(
graph_id: str,
user_id: str,
inputs: BlockInput,
inputs: Optional[BlockInput] = None,
preset_id: Optional[str] = None,
graph_version: Optional[int] = None,
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
use_db_query: bool = True,
) -> GraphExecutionWithNodes:
"""
@@ -750,7 +748,7 @@ async def add_graph_execution(
graph_version: The version of the graph to execute.
graph_credentials_inputs: Credentials inputs to use in the execution.
Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
node_credentials_input_map: Credentials inputs to use in the execution, mapped to specific nodes.
nodes_input_masks: Node inputs to use in the execution.
Returns:
GraphExecutionEntry: The entry for the graph execution.
Raises:
@@ -774,10 +772,19 @@ async def add_graph_execution(
if not graph:
raise NotFoundError(f"Graph #{graph_id} not found.")
node_credentials_input_map = node_credentials_input_map or (
make_node_credentials_input_map(graph, graph_credentials_inputs)
if graph_credentials_inputs
else None
nodes_input_masks = _merge_nodes_input_masks(
(
make_node_credentials_input_map(graph, graph_credentials_inputs)
if graph_credentials_inputs
else {}
),
nodes_input_masks or {},
)
starting_nodes_input = await construct_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=inputs or {},
nodes_input_masks=nodes_input_masks,
)
if use_db_query:
@@ -785,12 +792,7 @@ async def add_graph_execution(
user_id=user_id,
graph_id=graph_id,
graph_version=graph.version,
starting_nodes_input=await construct_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=inputs,
node_credentials_input_map=node_credentials_input_map,
),
starting_nodes_input=starting_nodes_input,
preset_id=preset_id,
)
else:
@@ -798,20 +800,15 @@ async def add_graph_execution(
user_id=user_id,
graph_id=graph_id,
graph_version=graph.version,
starting_nodes_input=await construct_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=inputs,
node_credentials_input_map=node_credentials_input_map,
),
starting_nodes_input=starting_nodes_input,
preset_id=preset_id,
)
try:
queue = await get_async_execution_queue()
graph_exec_entry = graph_exec.to_graph_execution_entry()
if node_credentials_input_map:
graph_exec_entry.node_credentials_input_map = node_credentials_input_map
if nodes_input_masks:
graph_exec_entry.nodes_input_masks = nodes_input_masks
await queue.publish_message(
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
message=graph_exec_entry.model_dump_json(),

View File

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
from pydantic import SecretStr
from backend.data.redis import get_redis_async
from backend.data.redis_client import get_redis_async
if TYPE_CHECKING:
from backend.executor.database import DatabaseManagerAsyncClient

View File

@@ -7,7 +7,7 @@ from autogpt_libs.utils.synchronize import AsyncRedisKeyedMutex
from redis.asyncio.lock import Lock as AsyncRedisLock
from backend.data.model import Credentials, OAuth2Credentials
from backend.data.redis import get_redis_async
from backend.data.redis_client import get_redis_async
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName

View File

@@ -17,6 +17,7 @@ class ProviderName(str, Enum):
GOOGLE = "google"
GOOGLE_MAPS = "google_maps"
GROQ = "groq"
HTTP = "http"
HUBSPOT = "hubspot"
IDEOGRAM = "ideogram"
JINA = "jina"

View File

@@ -1,23 +1,22 @@
import functools
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ..providers import ProviderName
from ._base import BaseWebhooksManager
_WEBHOOK_MANAGERS: dict["ProviderName", type["BaseWebhooksManager"]] = {}
# --8<-- [start:load_webhook_managers]
@functools.cache
def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]:
if _WEBHOOK_MANAGERS:
return _WEBHOOK_MANAGERS
webhook_managers = {}
from .compass import CompassWebhookManager
from .generic import GenericWebhooksManager
from .github import GithubWebhooksManager
from .slant3d import Slant3DWebhooksManager
_WEBHOOK_MANAGERS.update(
webhook_managers.update(
{
handler.PROVIDER_NAME: handler
for handler in [
@@ -28,7 +27,7 @@ def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]
]
}
)
return _WEBHOOK_MANAGERS
return webhook_managers
# --8<-- [end:load_webhook_managers]

View File

@@ -7,13 +7,14 @@ from uuid import uuid4
from fastapi import Request
from strenum import StrEnum
from backend.data import integrations
import backend.data.integrations as integrations
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.utils import webhook_ingress_url
from backend.util.exceptions import MissingConfigError
from backend.util.settings import Config
from .utils import webhook_ingress_url
logger = logging.getLogger(__name__)
app_config = Config()
@@ -41,44 +42,74 @@ class BaseWebhooksManager(ABC, Generic[WT]):
)
if webhook := await integrations.find_webhook_by_credentials_and_props(
credentials.id, webhook_type, resource, events
user_id=user_id,
credentials_id=credentials.id,
webhook_type=webhook_type,
resource=resource,
events=events,
):
return webhook
return await self._create_webhook(
user_id, webhook_type, events, resource, credentials
user_id=user_id,
webhook_type=webhook_type,
events=events,
resource=resource,
credentials=credentials,
)
async def get_manual_webhook(
self,
user_id: str,
graph_id: str,
webhook_type: WT,
events: list[str],
):
if current_webhook := await integrations.find_webhook_by_graph_and_props(
graph_id, self.PROVIDER_NAME, webhook_type, events
graph_id: Optional[str] = None,
preset_id: Optional[str] = None,
) -> integrations.Webhook:
"""
Tries to find an existing webhook tied to `graph_id`/`preset_id`,
or creates a new webhook if none exists.
Existing webhooks are matched by `user_id`, `webhook_type`,
and `graph_id`/`preset_id`.
If an existing webhook is found, we check if the events match and update them
if necessary. We do this rather than creating a new webhook
to avoid changing the webhook URL for existing manual webhooks.
"""
if (graph_id or preset_id) and (
current_webhook := await integrations.find_webhook_by_graph_and_props(
user_id=user_id,
provider=self.PROVIDER_NAME.value,
webhook_type=webhook_type,
graph_id=graph_id,
preset_id=preset_id,
)
):
if set(current_webhook.events) != set(events):
current_webhook = await integrations.update_webhook(
current_webhook.id, events=events
)
return current_webhook
return await self._create_webhook(
user_id,
webhook_type,
events,
user_id=user_id,
webhook_type=webhook_type,
events=events,
register=False,
)
async def prune_webhook_if_dangling(
self, webhook_id: str, credentials: Optional[Credentials]
self, user_id: str, webhook_id: str, credentials: Optional[Credentials]
) -> bool:
webhook = await integrations.get_webhook(webhook_id)
if webhook.attached_nodes is None:
raise ValueError("Error retrieving webhook including attached nodes")
if webhook.attached_nodes:
webhook = await integrations.get_webhook(webhook_id, include_relations=True)
if webhook.triggered_nodes or webhook.triggered_presets:
# Don't prune webhook if in use
return False
if credentials:
await self._deregister_webhook(webhook, credentials)
await integrations.delete_webhook(webhook.id)
await integrations.delete_webhook(user_id, webhook.id)
return True
# --8<-- [start:BaseWebhooksManager3]

View File

@@ -1,10 +1,12 @@
import logging
from typing import TYPE_CHECKING, Optional, cast
from backend.data.block import BlockSchema, BlockWebhookConfig
from backend.data.block import BlockSchema
from backend.data.graph import set_node_webhook
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks import get_webhook_manager, supports_webhooks
from . import get_webhook_manager, supports_webhooks
from .utils import setup_webhook_for_block
if TYPE_CHECKING:
from backend.data.graph import GraphModel, NodeModel
@@ -81,7 +83,9 @@ async def on_graph_deactivate(graph: "GraphModel", user_id: str):
f"credentials #{creds_meta['id']}"
)
updated_node = await on_node_deactivate(node, credentials=node_credentials)
updated_node = await on_node_deactivate(
user_id, node, credentials=node_credentials
)
updated_nodes.append(updated_node)
graph.nodes = updated_nodes
@@ -96,105 +100,25 @@ async def on_node_activate(
) -> "NodeModel":
"""Hook to be called when the node is activated/created"""
block = node.block
if not block.webhook_config:
return node
provider = block.webhook_config.provider
if not supports_webhooks(provider):
raise ValueError(
f"Block #{block.id} has webhook_config for provider {provider} "
"which does not support webhooks"
if node.block.webhook_config:
new_webhook, feedback = await setup_webhook_for_block(
user_id=user_id,
trigger_block=node.block,
trigger_config=node.input_default,
for_graph_id=node.graph_id,
)
logger.debug(
f"Activating webhook node #{node.id} with config {block.webhook_config}"
)
webhooks_manager = get_webhook_manager(provider)
if auto_setup_webhook := isinstance(block.webhook_config, BlockWebhookConfig):
try:
resource = block.webhook_config.resource_format.format(**node.input_default)
except KeyError:
resource = None
logger.debug(
f"Constructed resource string {resource} from input {node.input_default}"
)
else:
resource = "" # not relevant for manual webhooks
block_input_schema = cast(BlockSchema, block.input_schema)
credentials_field_name = next(iter(block_input_schema.get_credentials_fields()), "")
credentials_meta = (
node.input_default.get(credentials_field_name)
if credentials_field_name
else None
)
event_filter_input_name = block.webhook_config.event_filter_input
has_everything_for_webhook = (
resource is not None
and (credentials_meta or not credentials_field_name)
and (
not event_filter_input_name
or (
event_filter_input_name in node.input_default
and any(
is_on
for is_on in node.input_default[event_filter_input_name].values()
)
)
)
)
if has_everything_for_webhook and resource is not None:
logger.debug(f"Node #{node} has everything for a webhook!")
if credentials_meta and not credentials:
raise ValueError(
f"Cannot set up webhook for node #{node.id}: "
f"credentials #{credentials_meta['id']} not available"
)
if event_filter_input_name:
# Shape of the event filter is enforced in Block.__init__
event_filter = cast(dict, node.input_default[event_filter_input_name])
events = [
block.webhook_config.event_format.format(event=event)
for event, enabled in event_filter.items()
if enabled is True
]
logger.debug(f"Webhook events to subscribe to: {', '.join(events)}")
if new_webhook:
node = await set_node_webhook(node.id, new_webhook.id)
else:
events = []
# Find/make and attach a suitable webhook to the node
if auto_setup_webhook:
assert credentials is not None
new_webhook = await webhooks_manager.get_suitable_auto_webhook(
user_id,
credentials,
block.webhook_config.webhook_type,
resource,
events,
logger.debug(
f"Node #{node.id} does not have everything for a webhook: {feedback}"
)
else:
# Manual webhook -> no credentials -> don't register but do create
new_webhook = await webhooks_manager.get_manual_webhook(
user_id,
node.graph_id,
block.webhook_config.webhook_type,
events,
)
logger.debug(f"Acquired webhook: {new_webhook}")
return await set_node_webhook(node.id, new_webhook.id)
else:
logger.debug(f"Node #{node.id} does not have everything for a webhook")
return node
async def on_node_deactivate(
user_id: str,
node: "NodeModel",
*,
credentials: Optional["Credentials"] = None,
@@ -233,7 +157,9 @@ async def on_node_deactivate(
f"Pruning{' and deregistering' if credentials else ''} "
f"webhook #{webhook.id}"
)
await webhooks_manager.prune_webhook_if_dangling(webhook.id, credentials)
await webhooks_manager.prune_webhook_if_dangling(
user_id, webhook.id, credentials
)
if (
cast(BlockSchema, block.input_schema).get_credentials_fields()
and not credentials

View File

@@ -1,7 +1,22 @@
import logging
from typing import TYPE_CHECKING, Optional, cast
from pydantic import JsonValue
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.settings import Config
from . import get_webhook_manager, supports_webhooks
if TYPE_CHECKING:
from backend.data.block import Block, BlockSchema
from backend.data.integrations import Webhook
from backend.data.model import Credentials
logger = logging.getLogger(__name__)
app_config = Config()
credentials_manager = IntegrationCredentialsManager()
# TODO: add test to assert this matches the actual API route
@@ -10,3 +25,122 @@ def webhook_ingress_url(provider_name: ProviderName, webhook_id: str) -> str:
f"{app_config.platform_base_url}/api/integrations/{provider_name.value}"
f"/webhooks/{webhook_id}/ingress"
)
async def setup_webhook_for_block(
    user_id: str,
    trigger_block: "Block[BlockSchema, BlockSchema]",
    trigger_config: dict[str, JsonValue],  # = Trigger block inputs
    for_graph_id: Optional[str] = None,
    for_preset_id: Optional[str] = None,
    credentials: Optional["Credentials"] = None,
) -> tuple["Webhook", None] | tuple[None, str]:
    """
    Utility function to create (and auto-setup if possible) a webhook for a given provider.

    Args:
        user_id: ID of the user for whom the webhook is found or created.
        trigger_block: The block whose `webhook_config` describes the webhook.
        trigger_config: The trigger block's inputs; the event filter, the values
            for the resource format string, and the credentials reference
            are read from here.
        for_graph_id: Passed as `graph_id` to `get_manual_webhook`
            (only used for manual webhooks).
        for_preset_id: Passed as `preset_id` to `get_manual_webhook`
            (only used for manual webhooks).
        credentials: Credentials to use for auto-setup. If omitted, they are
            looked up through the credentials manager, using the ID
            referenced in `trigger_config`.

    Returns:
        Webhook: The created or found webhook object, if successful.
        str: A feedback message, if any required inputs are missing.

    Raises:
        ValueError: If the block has no `webhook_config`, if the referenced
            credentials can't be found, or if the credentials belong to
            a different provider.
        NotImplementedError: If `supports_webhooks` reports no webhook support
            for the block's provider.
    """
    # Imported at call time rather than at module level
    from backend.data.block import BlockWebhookConfig

    if not (trigger_base_config := trigger_block.webhook_config):
        raise ValueError(f"Block #{trigger_block.id} does not have a webhook_config")

    provider = trigger_base_config.provider
    if not supports_webhooks(provider):
        raise NotImplementedError(
            f"Block #{trigger_block.id} has webhook_config for provider {provider} "
            "for which we do not have a WebhooksManager"
        )

    logger.debug(
        f"Setting up webhook for block #{trigger_block.id} with config {trigger_config}"
    )

    # Check & parse the event filter input, if any
    events: list[str] = []
    if event_filter_input_name := trigger_base_config.event_filter_input:
        if not (event_filter := trigger_config.get(event_filter_input_name)):
            return None, (
                f"Cannot set up {provider.value} webhook without event filter input: "
                f"missing input for '{event_filter_input_name}'"
            )
        elif not (
            # Shape of the event filter is enforced in Block.__init__
            any((event_filter := cast(dict[str, bool], event_filter)).values())
        ):
            return None, (
                f"Cannot set up {provider.value} webhook without any enabled events "
                f"in event filter input '{event_filter_input_name}'"
            )

        # Only events that are explicitly switched on (`True`) are subscribed to
        events = [
            trigger_base_config.event_format.format(event=event)
            for event, enabled in event_filter.items()
            if enabled is True
        ]
        logger.debug(f"Webhook events to subscribe to: {', '.join(events)}")

    # Check & process prerequisites for auto-setup webhooks
    if auto_setup_webhook := isinstance(trigger_base_config, BlockWebhookConfig):
        try:
            resource = trigger_base_config.resource_format.format(**trigger_config)
        except KeyError as missing_key:
            # str(KeyError) is the repr of the key (incl. quotes), so take
            # the bare key name from `args` to avoid doubled quotes in the message
            return None, (
                f"Cannot auto-setup {provider.value} webhook without resource: "
                f"missing input for '{missing_key.args[0]}'"
            )
        logger.debug(
            f"Constructed resource string {resource} from input {trigger_config}"
        )

        creds_field_name = next(
            # presence of this field is enforced in Block.__init__
            iter(trigger_block.input_schema.get_credentials_fields())
        )

        if not (
            credentials_meta := cast(dict, trigger_config.get(creds_field_name, None))
        ):
            return None, f"Cannot set up {provider.value} webhook without credentials"
        elif not (
            # Prefer the credentials passed in by the caller; otherwise look them up
            credentials := credentials
            or await credentials_manager.get(user_id, credentials_meta["id"])
        ):
            raise ValueError(
                f"Cannot set up {provider.value} webhook without credentials: "
                f"credentials #{credentials_meta['id']} not found for user #{user_id}"
            )
        elif credentials.provider != provider:
            raise ValueError(
                f"Credentials #{credentials.id} do not match provider {provider.value}"
            )
    else:
        # not relevant for manual webhooks:
        resource = ""
        credentials = None

    webhooks_manager = get_webhook_manager(provider)

    # Find or create a suitable webhook; linking it to a node or preset
    # is left to the caller
    if auto_setup_webhook:
        assert credentials is not None
        webhook = await webhooks_manager.get_suitable_auto_webhook(
            user_id=user_id,
            credentials=credentials,
            webhook_type=trigger_base_config.webhook_type,
            resource=resource,
            events=events,
        )
    else:
        # Manual webhook -> no credentials -> don't register but do create
        webhook = await webhooks_manager.get_manual_webhook(
            user_id=user_id,
            webhook_type=trigger_base_config.webhook_type,
            events=events,
            graph_id=for_graph_id,
            preset_id=for_preset_id,
        )
    logger.debug(f"Acquired webhook: {webhook}")
    return webhook, None

View File

@@ -2,11 +2,19 @@ import asyncio
import logging
from typing import TYPE_CHECKING, Annotated, Awaitable, Literal
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
from fastapi import (
APIRouter,
Body,
Depends,
HTTPException,
Path,
Query,
Request,
status,
)
from pydantic import BaseModel, Field
from starlette.status import HTTP_404_NOT_FOUND
from backend.data.graph import set_node_webhook
from backend.data.graph import get_graph, set_node_webhook
from backend.data.integrations import (
WebhookEvent,
get_all_webhooks_by_creds,
@@ -14,12 +22,18 @@ from backend.data.integrations import (
publish_webhook_event,
wait_for_webhook_event,
)
from backend.data.model import Credentials, CredentialsType, OAuth2Credentials
from backend.data.model import (
Credentials,
CredentialsType,
HostScopedCredentials,
OAuth2Credentials,
)
from backend.executor.utils import add_graph_execution
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
from backend.server.v2.library.db import set_preset_webhook, update_preset
from backend.util.exceptions import NeedConfirmation, NotFoundError
from backend.util.settings import Settings
@@ -73,6 +87,9 @@ class CredentialsMetaResponse(BaseModel):
title: str | None
scopes: list[str] | None
username: str | None
host: str | None = Field(
default=None, description="Host pattern for host-scoped credentials"
)
@router.post("/{provider}/callback")
@@ -95,7 +112,10 @@ async def callback(
if not valid_state:
logger.warning(f"Invalid or expired state token for user {user_id}")
raise HTTPException(status_code=400, detail="Invalid or expired state token")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid or expired state token",
)
try:
scopes = valid_state.scopes
logger.debug(f"Retrieved scopes from state token: {scopes}")
@@ -122,17 +142,12 @@ async def callback(
)
except Exception as e:
logger.exception(
"OAuth callback for provider %s failed during code exchange: %s. Confirm provider credentials.",
provider.value,
e,
logger.error(
f"OAuth2 Code->Token exchange failed for provider {provider.value}: {e}"
)
raise HTTPException(
status_code=400,
detail={
"message": str(e),
"hint": "Verify OAuth configuration and try again.",
},
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"OAuth2 callback failed to exchange code for tokens: {str(e)}",
)
# TODO: Allow specifying `title` to set on `credentials`
@@ -149,6 +164,9 @@ async def callback(
title=credentials.title,
scopes=credentials.scopes,
username=credentials.username,
host=(
credentials.host if isinstance(credentials, HostScopedCredentials) else None
),
)
@@ -165,6 +183,7 @@ async def list_credentials(
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
)
for cred in credentials
]
@@ -186,6 +205,7 @@ async def list_credentials_by_provider(
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
)
for cred in credentials
]
@@ -201,10 +221,13 @@ async def get_credential(
) -> Credentials:
credential = await creds_manager.get(user_id, cred_id)
if not credential:
raise HTTPException(status_code=404, detail="Credentials not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if credential.provider != provider:
raise HTTPException(
status_code=404, detail="Credentials do not match the specified provider"
status_code=status.HTTP_404_NOT_FOUND,
detail="Credentials do not match the specified provider",
)
return credential
@@ -222,7 +245,8 @@ async def create_credentials(
await creds_manager.create(user_id, credentials)
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Failed to store credentials: {str(e)}"
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to store credentials: {str(e)}",
)
return credentials
@@ -256,14 +280,17 @@ async def delete_credentials(
) -> CredentialsDeletionResponse | CredentialsDeletionNeedsConfirmationResponse:
creds = await creds_manager.store.get_creds_by_id(user_id, cred_id)
if not creds:
raise HTTPException(status_code=404, detail="Credentials not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if creds.provider != provider:
raise HTTPException(
status_code=404, detail="Credentials do not match the specified provider"
status_code=status.HTTP_404_NOT_FOUND,
detail="Credentials do not match the specified provider",
)
try:
await remove_all_webhooks_for_credentials(creds, force)
await remove_all_webhooks_for_credentials(user_id, creds, force)
except NeedConfirmation as e:
return CredentialsDeletionNeedsConfirmationResponse(message=str(e))
@@ -294,16 +321,10 @@ async def webhook_ingress_generic(
logger.debug(f"Received {provider.value} webhook ingress for ID {webhook_id}")
webhook_manager = get_webhook_manager(provider)
try:
webhook = await get_webhook(webhook_id)
webhook = await get_webhook(webhook_id, include_relations=True)
except NotFoundError as e:
logger.warning(
"Webhook payload received for unknown webhook %s. Confirm the webhook ID.",
webhook_id,
)
raise HTTPException(
status_code=HTTP_404_NOT_FOUND,
detail={"message": str(e), "hint": "Check if the webhook ID is correct."},
) from e
logger.warning(f"Webhook payload received for unknown webhook #{webhook_id}")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
logger.debug(f"Webhook #{webhook_id}: {webhook}")
payload, event_type = await webhook_manager.validate_payload(webhook, request)
logger.debug(
@@ -320,11 +341,11 @@ async def webhook_ingress_generic(
await publish_webhook_event(webhook_event)
logger.debug(f"Webhook event published: {webhook_event}")
if not webhook.attached_nodes:
if not (webhook.triggered_nodes or webhook.triggered_presets):
return
executions: list[Awaitable] = []
for node in webhook.attached_nodes:
for node in webhook.triggered_nodes:
logger.debug(f"Webhook-attached node: {node}")
if not node.is_triggered_by_event_type(event_type):
logger.debug(f"Node #{node.id} doesn't trigger on event {event_type}")
@@ -335,7 +356,48 @@ async def webhook_ingress_generic(
user_id=webhook.user_id,
graph_id=node.graph_id,
graph_version=node.graph_version,
inputs={f"webhook_{webhook_id}_payload": payload},
nodes_input_masks={node.id: {"payload": payload}},
)
)
for preset in webhook.triggered_presets:
logger.debug(f"Webhook-attached preset: {preset}")
if not preset.is_active:
logger.debug(f"Preset #{preset.id} is inactive")
continue
graph = await get_graph(preset.graph_id, preset.graph_version, webhook.user_id)
if not graph:
logger.error(
f"User #{webhook.user_id} has preset #{preset.id} for graph "
f"#{preset.graph_id} v{preset.graph_version}, "
"but no access to the graph itself."
)
logger.info(f"Automatically deactivating broken preset #{preset.id}")
await update_preset(preset.user_id, preset.id, is_active=False)
continue
if not (trigger_node := graph.webhook_input_node):
# NOTE: this should NEVER happen, but we log and handle it gracefully
logger.error(
f"Preset #{preset.id} is triggered by webhook #{webhook.id}, but graph "
f"#{preset.graph_id} v{preset.graph_version} has no webhook input node"
)
await set_preset_webhook(preset.user_id, preset.id, None)
continue
if not trigger_node.block.is_triggered_by_event_type(preset.inputs, event_type):
logger.debug(f"Preset #{preset.id} doesn't trigger on event {event_type}")
continue
logger.debug(f"Executing preset #{preset.id} for webhook #{webhook.id}")
executions.append(
add_graph_execution(
user_id=webhook.user_id,
graph_id=preset.graph_id,
preset_id=preset.id,
graph_version=preset.graph_version,
graph_credentials_inputs=preset.credentials,
nodes_input_masks={
trigger_node.id: {**preset.inputs, "payload": payload}
},
)
)
asyncio.gather(*executions)
@@ -360,7 +422,9 @@ async def webhook_ping(
return False
if not await wait_for_webhook_event(webhook_id, event_type="ping", timeout=10):
raise HTTPException(status_code=504, detail="Webhook ping timed out")
raise HTTPException(
status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="Webhook ping timed out"
)
return True
@@ -369,32 +433,37 @@ async def webhook_ping(
async def remove_all_webhooks_for_credentials(
credentials: Credentials, force: bool = False
user_id: str, credentials: Credentials, force: bool = False
) -> None:
"""
Remove and deregister all webhooks that were registered using the given credentials.
Params:
user_id: The ID of the user who owns the credentials and webhooks.
credentials: The credentials for which to remove the associated webhooks.
force: Whether to proceed if any of the webhooks are still in use.
Raises:
NeedConfirmation: If any of the webhooks are still in use and `force` is `False`
"""
webhooks = await get_all_webhooks_by_creds(credentials.id)
if any(w.attached_nodes for w in webhooks) and not force:
webhooks = await get_all_webhooks_by_creds(
user_id, credentials.id, include_relations=True
)
if any(w.triggered_nodes or w.triggered_presets for w in webhooks) and not force:
raise NeedConfirmation(
"Some webhooks linked to these credentials are still in use by an agent"
)
for webhook in webhooks:
# Unlink all nodes
for node in webhook.attached_nodes or []:
# Unlink all nodes & presets
for node in webhook.triggered_nodes:
await set_node_webhook(node.id, None)
for preset in webhook.triggered_presets:
await set_preset_webhook(user_id, preset.id, None)
# Prune the webhook
webhook_manager = get_webhook_manager(ProviderName(credentials.provider))
success = await webhook_manager.prune_webhook_if_dangling(
webhook.id, credentials
user_id, webhook.id, credentials
)
if not success:
logger.warning(f"Webhook #{webhook.id} failed to prune")
@@ -405,7 +474,7 @@ def _get_provider_oauth_handler(
) -> "BaseOAuthHandler":
if provider_name not in HANDLERS_BY_NAME:
raise HTTPException(
status_code=404,
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Provider '{provider_name.value}' does not support OAuth",
)
@@ -413,14 +482,13 @@ def _get_provider_oauth_handler(
client_secret = getattr(settings.secrets, f"{provider_name.value}_client_secret")
if not (client_id and client_secret):
logger.error(
"OAuth credentials for provider %s are missing. Check environment configuration.",
provider_name.value,
f"Attempt to use unconfigured {provider_name.value} OAuth integration"
)
raise HTTPException(
status_code=501,
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail={
"message": f"Integration with provider '{provider_name.value}' is not configured",
"hint": "Set client ID and secret in the environment.",
"message": f"Integration with provider '{provider_name.value}' is not configured.",
"hint": "Set client ID and secret in the application's deployment environment",
},
)

View File

@@ -279,6 +279,7 @@ class AgentServer(backend.util.service.AppProcess):
@staticmethod
async def test_delete_graph(graph_id: str, user_id: str):
"""Used for clean-up after a test run"""
await backend.server.v2.library.db.delete_library_agent_by_graph_id(
graph_id=graph_id, user_id=user_id
)
@@ -323,18 +324,14 @@ class AgentServer(backend.util.service.AppProcess):
@staticmethod
async def test_execute_preset(
graph_id: str,
graph_version: int,
preset_id: str,
user_id: str,
node_input: Optional[dict[str, Any]] = None,
inputs: Optional[dict[str, Any]] = None,
):
return await backend.server.v2.library.routes.presets.execute_preset(
graph_id=graph_id,
graph_version=graph_version,
preset_id=preset_id,
node_input=node_input or {},
user_id=user_id,
inputs=inputs or {},
)
@staticmethod
@@ -360,11 +357,22 @@ class AgentServer(backend.util.service.AppProcess):
provider: ProviderName,
credentials: Credentials,
) -> Credentials:
from backend.server.integrations.router import create_credentials
return await create_credentials(
user_id=user_id, provider=provider, credentials=credentials
from backend.server.integrations.router import (
create_credentials,
get_credential,
)
try:
return await create_credentials(
user_id=user_id, provider=provider, credentials=credentials
)
except Exception as e:
logger.error(f"Error creating credentials: {e}")
return await get_credential(
provider=provider,
user_id=user_id,
cred_id=credentials.id,
)
def set_test_dependency_overrides(self, overrides: dict):
app.dependency_overrides.update(overrides)

View File

@@ -4,6 +4,7 @@ import logging
from typing import Annotated
import fastapi
import pydantic
import backend.data.analytics
from backend.server.utils import get_user_id
@@ -12,24 +13,28 @@ router = fastapi.APIRouter()
logger = logging.getLogger(__name__)
# Request body model for the `/log_raw_metric` endpoint.
class LogRawMetricRequest(pydantic.BaseModel):
    # Required; must be at least 1 character long
    metric_name: str = pydantic.Field(..., min_length=1)
    # Required; `allow_inf_nan=False` disallows infinite and NaN values
    metric_value: float = pydantic.Field(..., allow_inf_nan=False)
    # Required; must be at least 1 character long
    data_string: str = pydantic.Field(..., min_length=1)
@router.post(path="/log_raw_metric")
async def log_raw_metric(
user_id: Annotated[str, fastapi.Depends(get_user_id)],
metric_name: Annotated[str, fastapi.Body(..., embed=True)],
metric_value: Annotated[float, fastapi.Body(..., embed=True)],
data_string: Annotated[str, fastapi.Body(..., embed=True)],
request: LogRawMetricRequest,
):
try:
result = await backend.data.analytics.log_raw_metric(
user_id=user_id,
metric_name=metric_name,
metric_value=metric_value,
data_string=data_string,
metric_name=request.metric_name,
metric_value=request.metric_value,
data_string=request.data_string,
)
return result.id
except Exception as e:
logger.exception(
"Failed to log metric %s for user %s: %s", metric_name, user_id, e
"Failed to log metric %s for user %s: %s", request.metric_name, user_id, e
)
raise fastapi.HTTPException(
status_code=500,

View File

@@ -97,8 +97,17 @@ def test_log_raw_metric_invalid_request_improved() -> None:
assert "data_string" in error_fields, "Should report missing data_string"
def test_log_raw_metric_type_validation_improved() -> None:
def test_log_raw_metric_type_validation_improved(
mocker: pytest_mock.MockFixture,
) -> None:
"""Test metric type validation with improved assertions."""
# Mock the analytics function to avoid event loop issues
mocker.patch(
"backend.data.analytics.log_raw_metric",
new_callable=AsyncMock,
return_value=Mock(id="test-id"),
)
invalid_requests = [
{
"data": {
@@ -119,10 +128,10 @@ def test_log_raw_metric_type_validation_improved() -> None:
{
"data": {
"metric_name": "test",
"metric_value": float("inf"), # Infinity
"data_string": "test",
"metric_value": 123, # Valid number
"data_string": "", # Empty data_string
},
"expected_error": "ensure this value is finite",
"expected_error": "String should have at least 1 character",
},
]

View File

@@ -93,10 +93,18 @@ def test_log_raw_metric_values_parametrized(
],
)
def test_log_raw_metric_invalid_requests_parametrized(
mocker: pytest_mock.MockFixture,
invalid_data: dict,
expected_error: str,
) -> None:
"""Test invalid metric requests with parametrize."""
# Mock the analytics function to avoid event loop issues
mocker.patch(
"backend.data.analytics.log_raw_metric",
new_callable=AsyncMock,
return_value=Mock(id="test-id"),
)
response = client.post("/log_raw_metric", json=invalid_data)
assert response.status_code == 422

View File

@@ -9,7 +9,7 @@ import stripe
from autogpt_libs.auth.middleware import auth_middleware
from autogpt_libs.feature_flag.client import feature_flag
from autogpt_libs.utils.cache import thread_cached
from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Request, Response
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
from typing_extensions import Optional, TypedDict
@@ -72,6 +72,7 @@ from backend.server.model import (
UpdatePermissionsRequest,
)
from backend.server.utils import get_user_id
from backend.util.exceptions import NotFoundError
from backend.util.service import get_service_client
from backend.util.settings import Settings
@@ -765,70 +766,94 @@ async def delete_graph_execution(
class ScheduleCreationRequest(pydantic.BaseModel):
graph_version: Optional[int] = None
name: str
cron: str
input_data: dict[Any, Any]
graph_id: str
graph_version: int
inputs: dict[str, Any]
credentials: dict[str, CredentialsMetaInput] = pydantic.Field(default_factory=dict)
@v1_router.post(
path="/schedules",
path="/graphs/{graph_id}/schedules",
summary="Create execution schedule",
tags=["schedules"],
dependencies=[Depends(auth_middleware)],
)
async def create_schedule(
async def create_graph_execution_schedule(
user_id: Annotated[str, Depends(get_user_id)],
schedule: ScheduleCreationRequest,
graph_id: str = Path(..., description="ID of the graph to schedule"),
schedule_params: ScheduleCreationRequest = Body(),
) -> scheduler.GraphExecutionJobInfo:
graph = await graph_db.get_graph(
schedule.graph_id, schedule.graph_version, user_id=user_id
graph_id=graph_id,
version=schedule_params.graph_version,
user_id=user_id,
)
if not graph:
raise HTTPException(
status_code=404,
detail=f"Graph #{schedule.graph_id} v.{schedule.graph_version} not found.",
detail=f"Graph #{graph_id} v{schedule_params.graph_version} not found.",
)
return await execution_scheduler_client().add_execution_schedule(
graph_id=schedule.graph_id,
graph_version=graph.version,
cron=schedule.cron,
input_data=schedule.input_data,
user_id=user_id,
graph_id=graph_id,
graph_version=graph.version,
name=schedule_params.name,
cron=schedule_params.cron,
input_data=schedule_params.inputs,
input_credentials=schedule_params.credentials,
)
@v1_router.get(
    path="/graphs/{graph_id}/schedules",
    summary="List execution schedules for a graph",
    tags=["schedules"],
    dependencies=[Depends(auth_middleware)],
)
async def list_graph_execution_schedules(
    user_id: Annotated[str, Depends(get_user_id)],
    graph_id: str = Path(),
) -> list[scheduler.GraphExecutionJobInfo]:
    # Ask the scheduler for the requesting user's schedules,
    # narrowed down to the graph given in the URL path.
    scheduler_client = execution_scheduler_client()
    graph_schedules = await scheduler_client.get_execution_schedules(
        user_id=user_id, graph_id=graph_id
    )
    return graph_schedules
@v1_router.get(
    path="/schedules",
    summary="List execution schedules for a user",
    tags=["schedules"],
    dependencies=[Depends(auth_middleware)],
)
async def list_all_graphs_execution_schedules(
    user_id: Annotated[str, Depends(get_user_id)],
) -> list[scheduler.GraphExecutionJobInfo]:
    # No graph filter is passed here: the scheduler is queried by user ID only.
    scheduler_client = execution_scheduler_client()
    user_schedules = await scheduler_client.get_execution_schedules(user_id=user_id)
    return user_schedules
@v1_router.delete(
path="/schedules/{schedule_id}",
summary="Delete execution schedule",
tags=["schedules"],
dependencies=[Depends(auth_middleware)],
)
async def delete_schedule(
schedule_id: str,
async def delete_graph_execution_schedule(
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[Any, Any]:
await execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
schedule_id: str = Path(..., description="ID of the schedule to delete"),
) -> dict[str, Any]:
try:
await execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
except NotFoundError:
raise HTTPException(
status_code=HTTP_404_NOT_FOUND,
detail=f"Schedule #{schedule_id} not found",
)
return {"id": schedule_id}
@v1_router.get(
path="/schedules",
summary="List execution schedules",
tags=["schedules"],
dependencies=[Depends(auth_middleware)],
)
async def get_execution_schedules(
user_id: Annotated[str, Depends(get_user_id)],
graph_id: str | None = None,
) -> list[scheduler.GraphExecutionJobInfo]:
return await execution_scheduler_client().get_execution_schedules(
user_id=user_id,
graph_id=graph_id,
)
########################################################
##################### API KEY ##############################
########################################################

View File

@@ -108,11 +108,16 @@ class TestDatabaseIsolation:
where={"email": {"contains": "@test.example"}}
)
@pytest.fixture(scope="session")
async def test_create_user(self, test_db_connection):
"""Test that demonstrates proper isolation."""
# This test has access to a clean database
user = await test_db_connection.user.create(
data={"email": "test@test.example", "name": "Test User"}
data={
"id": "test-user-id",
"email": "test@test.example",
"name": "Test User",
}
)
assert user.email == "test@test.example"
# User will be cleaned up automatically

View File

@@ -1,5 +1,5 @@
import logging
from typing import Optional
from typing import Literal, Optional
import fastapi
import prisma.errors
@@ -7,17 +7,17 @@ import prisma.fields
import prisma.models
import prisma.types
import backend.data.graph
import backend.data.graph as graph_db
import backend.server.model
import backend.server.v2.library.model as library_model
import backend.server.v2.store.exceptions as store_exceptions
import backend.server.v2.store.image_gen as store_image_gen
import backend.server.v2.store.media as store_media
from backend.data import db
from backend.data import graph as graph_db
from backend.data.db import locked_transaction
from backend.data.block import BlockInput
from backend.data.db import locked_transaction, transaction
from backend.data.execution import get_graph_execution
from backend.data.includes import library_agent_include
from backend.data.model import CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
from backend.util.exceptions import NotFoundError
@@ -122,7 +122,7 @@ async def list_library_agents(
except Exception as e:
# Skip this agent if there was an error
logger.error(
f"Error parsing LibraryAgent when getting library agents from db: {e}"
f"Error parsing LibraryAgent #{agent.id} from DB item: {e}"
)
continue
@@ -168,7 +168,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
)
if not library_agent:
raise store_exceptions.AgentNotFoundError(f"Library agent #{id} not found")
raise NotFoundError(f"Library agent #{id} not found")
return library_model.LibraryAgent.from_db(library_agent)
@@ -215,8 +215,34 @@ async def get_library_agent_by_store_version_id(
return None
async def get_library_agent_by_graph_id(
    user_id: str,
    graph_id: str,
    graph_version: Optional[int] = None,
) -> library_model.LibraryAgent | None:
    """
    Looks up the non-deleted LibraryAgent that a user holds for a given graph.

    Args:
        user_id: ID of the user owning the library agent.
        graph_id: ID of the graph the library agent points to.
        graph_version: If given, only an agent for this exact version matches.

    Returns:
        The matching LibraryAgent, or None if there is no match.

    Raises:
        DatabaseError: If the database query fails.
    """
    # Building the query cannot raise a PrismaError, so it lives outside the `try`
    where_clause: prisma.types.LibraryAgentWhereInput = {
        "agentGraphId": graph_id,
        "userId": user_id,
        "isDeleted": False,
    }
    if graph_version is not None:
        where_clause["agentGraphVersion"] = graph_version

    try:
        db_agent = await prisma.models.LibraryAgent.prisma().find_first(
            where=where_clause,
            include=library_agent_include(user_id),
        )
        return library_model.LibraryAgent.from_db(db_agent) if db_agent else None
    except prisma.errors.PrismaError as e:
        logger.error(f"Database error fetching library agent by graph ID: {e}")
        raise store_exceptions.DatabaseError("Failed to fetch library agent") from e
async def add_generated_agent_image(
graph: backend.data.graph.GraphModel,
graph: graph_db.GraphModel,
library_agent_id: str,
) -> Optional[prisma.models.LibraryAgent]:
"""
@@ -249,7 +275,7 @@ async def add_generated_agent_image(
async def create_library_agent(
graph: backend.data.graph.GraphModel,
graph: graph_db.GraphModel,
user_id: str,
) -> library_model.LibraryAgent:
"""
@@ -346,8 +372,8 @@ async def update_library_agent(
auto_update_version: Optional[bool] = None,
is_favorite: Optional[bool] = None,
is_archived: Optional[bool] = None,
is_deleted: Optional[bool] = None,
) -> None:
is_deleted: Optional[Literal[False]] = None,
) -> library_model.LibraryAgent:
"""
Updates the specified LibraryAgent record.
@@ -357,15 +383,18 @@ async def update_library_agent(
auto_update_version: Whether the agent should auto-update to active version.
is_favorite: Whether this agent is marked as a favorite.
is_archived: Whether this agent is archived.
is_deleted: Whether this agent is deleted.
Returns:
The updated LibraryAgent.
Raises:
NotFoundError: If the specified LibraryAgent does not exist.
DatabaseError: If there's an error in the update operation.
"""
logger.debug(
f"Updating library agent {library_agent_id} for user {user_id} with "
f"auto_update_version={auto_update_version}, is_favorite={is_favorite}, "
f"is_archived={is_archived}, is_deleted={is_deleted}"
f"is_archived={is_archived}"
)
update_fields: prisma.types.LibraryAgentUpdateManyMutationInput = {}
if auto_update_version is not None:
@@ -375,17 +404,46 @@ async def update_library_agent(
if is_archived is not None:
update_fields["isArchived"] = is_archived
if is_deleted is not None:
if is_deleted is True:
raise RuntimeError(
"Use delete_library_agent() to (soft-)delete library agents"
)
update_fields["isDeleted"] = is_deleted
if not update_fields:
raise ValueError("No values were passed to update")
try:
await prisma.models.LibraryAgent.prisma().update_many(
where={"id": library_agent_id, "userId": user_id}, data=update_fields
n_updated = await prisma.models.LibraryAgent.prisma().update_many(
where={"id": library_agent_id, "userId": user_id},
data=update_fields,
)
if n_updated < 1:
raise NotFoundError(f"Library agent {library_agent_id} not found")
return await get_library_agent(
id=library_agent_id,
user_id=user_id,
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error updating library agent: {str(e)}")
raise store_exceptions.DatabaseError("Failed to update library agent") from e
async def delete_library_agent(
    library_agent_id: str, user_id: str, soft_delete: bool = True
) -> None:
    """
    Deletes a library agent belonging to the given user.

    Args:
        library_agent_id: ID of the library agent to delete.
        user_id: ID of the user that must own the library agent.
        soft_delete: If True (default), only sets `isDeleted` on the record;
            if False, removes the record from the database.

    Raises:
        NotFoundError: If no record matched the given ID and user.
    """
    # Matching on both ID and owner prevents deleting other users' agents
    owned_agent: prisma.types.LibraryAgentWhereInput = {
        "id": library_agent_id,
        "userId": user_id,
    }
    agents = prisma.models.LibraryAgent.prisma()

    if not soft_delete:
        n_affected = await agents.delete_many(where=owned_agent)
    else:
        n_affected = await agents.update_many(
            where=owned_agent, data={"isDeleted": True}
        )

    if n_affected < 1:
        raise NotFoundError(f"Library agent #{library_agent_id} not found")
async def delete_library_agent_by_graph_id(graph_id: str, user_id: str) -> None:
"""
Deletes a library agent for the given user
@@ -525,7 +583,10 @@ async def list_presets(
)
raise store_exceptions.DatabaseError("Invalid pagination parameters")
query_filter: prisma.types.AgentPresetWhereInput = {"userId": user_id}
query_filter: prisma.types.AgentPresetWhereInput = {
"userId": user_id,
"isDeleted": False,
}
if graph_id:
query_filter["agentGraphId"] = graph_id
@@ -581,7 +642,7 @@ async def get_preset(
where={"id": preset_id},
include={"InputPresets": True},
)
if not preset or preset.userId != user_id:
if not preset or preset.userId != user_id or preset.isDeleted:
return None
return library_model.LibraryAgentPreset.from_db(preset)
except prisma.errors.PrismaError as e:
@@ -618,12 +679,19 @@ async def create_preset(
agentGraphId=preset.graph_id,
agentGraphVersion=preset.graph_version,
isActive=preset.is_active,
webhookId=preset.webhook_id,
InputPresets={
"create": [
prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput( # noqa
name=name, data=prisma.fields.Json(data)
)
for name, data in preset.inputs.items()
for name, data in {
**preset.inputs,
**{
key: creds_meta.model_dump(exclude_none=True)
for key, creds_meta in preset.credentials.items()
},
}.items()
]
},
),
@@ -664,6 +732,7 @@ async def create_preset_from_graph_execution(
user_id=user_id,
preset=library_model.LibraryAgentPresetCreatable(
inputs=graph_execution.inputs,
credentials={}, # FIXME
graph_id=graph_execution.graph_id,
graph_version=graph_execution.graph_version,
name=create_request.name,
@@ -676,7 +745,11 @@ async def create_preset_from_graph_execution(
async def update_preset(
user_id: str,
preset_id: str,
preset: library_model.LibraryAgentPresetUpdatable,
inputs: Optional[BlockInput] = None,
credentials: Optional[dict[str, CredentialsMetaInput]] = None,
name: Optional[str] = None,
description: Optional[str] = None,
is_active: Optional[bool] = None,
) -> library_model.LibraryAgentPreset:
"""
Updates an existing AgentPreset for a user.
@@ -684,49 +757,95 @@ async def update_preset(
Args:
user_id: The ID of the user updating the preset.
preset_id: The ID of the preset to update.
preset: The preset data used for the update.
inputs: New inputs object to set on the preset.
credentials: New credentials to set on the preset.
name: New name for the preset.
description: New description for the preset.
is_active: New active status for the preset.
Returns:
The updated LibraryAgentPreset.
Raises:
DatabaseError: If there's a database error in updating the preset.
ValueError: If attempting to update a non-existent preset.
NotFoundError: If attempting to update a non-existent preset.
"""
current = await get_preset(user_id, preset_id) # assert ownership
if not current:
raise NotFoundError(f"Preset #{preset_id} not found for user #{user_id}")
logger.debug(
f"Updating preset #{preset_id} ({repr(preset.name)}) for user #{user_id}",
f"Updating preset #{preset_id} ({repr(current.name)}) for user #{user_id}",
)
try:
update_data: prisma.types.AgentPresetUpdateInput = {}
if preset.name:
update_data["name"] = preset.name
if preset.description:
update_data["description"] = preset.description
if preset.inputs:
update_data["InputPresets"] = {
"create": [
prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput( # noqa
name=name, data=prisma.fields.Json(data)
async with transaction() as tx:
update_data: prisma.types.AgentPresetUpdateInput = {}
if name:
update_data["name"] = name
if description:
update_data["description"] = description
if is_active is not None:
update_data["isActive"] = is_active
if inputs or credentials:
if not (inputs and credentials):
raise ValueError(
"Preset inputs and credentials must be provided together"
)
for name, data in preset.inputs.items()
]
}
if preset.is_active:
update_data["isActive"] = preset.is_active
update_data["InputPresets"] = {
"create": [
prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput( # noqa
name=name, data=prisma.fields.Json(data)
)
for name, data in {
**inputs,
**{
key: creds_meta.model_dump(exclude_none=True)
for key, creds_meta in credentials.items()
},
}.items()
],
}
# Existing InputPresets must be deleted, in a separate query
await prisma.models.AgentNodeExecutionInputOutput.prisma(
tx
).delete_many(where={"agentPresetId": preset_id})
updated = await prisma.models.AgentPreset.prisma().update(
where={"id": preset_id},
data=update_data,
include={"InputPresets": True},
)
updated = await prisma.models.AgentPreset.prisma(tx).update(
where={"id": preset_id},
data=update_data,
include={"InputPresets": True},
)
if not updated:
raise ValueError(f"AgentPreset #{preset_id} not found")
raise RuntimeError(f"AgentPreset #{preset_id} vanished while updating")
return library_model.LibraryAgentPreset.from_db(updated)
except prisma.errors.PrismaError as e:
logger.error(f"Database error updating preset: {e}")
raise store_exceptions.DatabaseError("Failed to update preset") from e
async def set_preset_webhook(
    user_id: str, preset_id: str, webhook_id: str | None
) -> library_model.LibraryAgentPreset:
    """
    Attaches a webhook to a preset, or detaches the current one.

    Args:
        user_id: ID of the user that must own the preset.
        preset_id: ID of the preset to update.
        webhook_id: ID of the webhook to connect; a falsy value disconnects
            the preset's webhook instead.

    Returns:
        The updated LibraryAgentPreset.

    Raises:
        NotFoundError: If the preset doesn't exist or belongs to another user.
        RuntimeError: If the update query returns no record.
    """
    existing = await prisma.models.AgentPreset.prisma().find_unique(
        where={"id": preset_id},
        include={"InputPresets": True},
    )
    # Presets of other users are reported as missing
    if not (existing and existing.userId == user_id):
        raise NotFoundError(f"Preset #{preset_id} not found")

    webhook_link: prisma.types.AgentPresetUpdateInput
    if webhook_id:
        webhook_link = {"Webhook": {"connect": {"id": webhook_id}}}
    else:
        webhook_link = {"Webhook": {"disconnect": True}}

    refreshed = await prisma.models.AgentPreset.prisma().update(
        where={"id": preset_id},
        data=webhook_link,
        include={"InputPresets": True},
    )
    if not refreshed:
        raise RuntimeError(f"AgentPreset #{preset_id} vanished while updating")

    return library_model.LibraryAgentPreset.from_db(refreshed)
async def delete_preset(user_id: str, preset_id: str) -> None:
"""
Soft-deletes a preset by marking it as isDeleted = True.
@@ -738,7 +857,7 @@ async def delete_preset(user_id: str, preset_id: str) -> None:
Raises:
DatabaseError: If there's a database error during deletion.
"""
logger.info(f"Deleting preset {preset_id} for user {user_id}")
logger.debug(f"Setting preset #{preset_id} for user #{user_id} to deleted")
try:
await prisma.models.AgentPreset.prisma().update_many(
where={"id": preset_id, "userId": user_id},
@@ -765,7 +884,7 @@ async def fork_library_agent(library_agent_id: str, user_id: str):
"""
logger.debug(f"Forking library agent {library_agent_id} for user {user_id}")
try:
async with db.locked_transaction(f"usr_trx_{user_id}-fork_agent"):
async with locked_transaction(f"usr_trx_{user_id}-fork_agent"):
# Fetch the original agent
original_agent = await get_library_agent(library_agent_id, user_id)

View File

@@ -143,7 +143,7 @@ async def test_add_agent_to_library(mocker):
)
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.create = mocker.AsyncMock(
return_value=mock_library_agent_data
)
@@ -159,21 +159,24 @@ async def test_add_agent_to_library(mocker):
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
where={"id": "version123"}, include={"AgentGraph": True}
)
mock_library_agent.return_value.find_first.assert_called_once_with(
mock_library_agent.return_value.find_unique.assert_called_once_with(
where={
"userId": "test-user",
"agentGraphId": "agent1",
"agentGraphVersion": 1,
"userId_agentGraphId_agentGraphVersion": {
"userId": "test-user",
"agentGraphId": "agent1",
"agentGraphVersion": 1,
}
},
include=library_agent_include("test-user"),
include={"AgentGraph": True},
)
mock_library_agent.return_value.create.assert_called_once_with(
data=prisma.types.LibraryAgentCreateInput(
userId="test-user",
agentGraphId="agent1",
agentGraphVersion=1,
isCreatedByUser=False,
),
data={
"User": {"connect": {"id": "test-user"}},
"AgentGraph": {
"connect": {"graphVersionId": {"id": "agent1", "version": 1}}
},
"isCreatedByUser": False,
},
include=library_agent_include("test-user"),
)

View File

@@ -9,6 +9,8 @@ import pydantic
import backend.data.block as block_model
import backend.data.graph as graph_model
import backend.server.model as server_model
from backend.data.model import CredentialsMetaInput, is_credentials_field_name
from backend.integrations.providers import ProviderName
class LibraryAgentStatus(str, Enum):
@@ -18,6 +20,14 @@ class LibraryAgentStatus(str, Enum):
ERROR = "ERROR" # Agent is in an error state
class LibraryAgentTriggerInfo(pydantic.BaseModel):
    # NOTE: documented with comments rather than a docstring, so that the
    # generated JSON schema for this model stays unchanged.

    # Provider taken from the trigger block's webhook config
    provider: ProviderName

    # JSON schema of the trigger block's inputs
    config_schema: dict[str, Any] = pydantic.Field(
        description="Input schema for the trigger block"
    )

    # Name of a credentials field on the trigger block, or None if it has none.
    # Required (no default), so it must always be passed explicitly.
    credentials_input_name: Optional[str]
class LibraryAgent(pydantic.BaseModel):
"""
Represents an agent in the library, including metadata for display and
@@ -40,8 +50,15 @@ class LibraryAgent(pydantic.BaseModel):
name: str
description: str
# Made input_schema and output_schema match GraphMeta's type
input_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend
credentials_input_schema: dict[str, Any] = pydantic.Field(
description="Input schema for credentials required by the agent",
)
has_external_trigger: bool = pydantic.Field(
description="Whether the agent has an external trigger (e.g. webhook) node"
)
trigger_setup_info: Optional[LibraryAgentTriggerInfo] = None
# Indicates whether there's a new output (based on recent runs)
new_output: bool
@@ -106,6 +123,32 @@ class LibraryAgent(pydantic.BaseModel):
name=graph.name,
description=graph.description,
input_schema=graph.input_schema,
credentials_input_schema=graph.credentials_input_schema,
has_external_trigger=graph.has_webhook_trigger,
trigger_setup_info=(
LibraryAgentTriggerInfo(
provider=trigger_block.webhook_config.provider,
config_schema={
**(json_schema := trigger_block.input_schema.jsonschema()),
"properties": {
pn: sub_schema
for pn, sub_schema in json_schema["properties"].items()
if not is_credentials_field_name(pn)
},
"required": [
pn
for pn in json_schema.get("required", [])
if not is_credentials_field_name(pn)
],
},
credentials_input_name=next(
iter(trigger_block.input_schema.get_credentials_fields()), None
),
)
if graph.webhook_input_node
and (trigger_block := graph.webhook_input_node.block).webhook_config
else None
),
new_output=new_output,
can_access_graph=can_access_graph,
is_latest_version=is_latest_version,
@@ -177,12 +220,15 @@ class LibraryAgentPresetCreatable(pydantic.BaseModel):
graph_version: int
inputs: block_model.BlockInput
credentials: dict[str, CredentialsMetaInput]
name: str
description: str
is_active: bool = True
webhook_id: Optional[str] = None
class LibraryAgentPresetCreatableFromGraphExecution(pydantic.BaseModel):
"""
@@ -203,6 +249,7 @@ class LibraryAgentPresetUpdatable(pydantic.BaseModel):
"""
inputs: Optional[block_model.BlockInput] = None
credentials: Optional[dict[str, CredentialsMetaInput]] = None
name: Optional[str] = None
description: Optional[str] = None
@@ -214,20 +261,28 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
"""Represents a preset configuration for a library agent."""
id: str
user_id: str
updated_at: datetime.datetime
@classmethod
def from_db(cls, preset: prisma.models.AgentPreset) -> "LibraryAgentPreset":
if preset.InputPresets is None:
raise ValueError("Input values must be included in object")
raise ValueError("InputPresets must be included in AgentPreset query")
input_data: block_model.BlockInput = {}
input_credentials: dict[str, CredentialsMetaInput] = {}
for preset_input in preset.InputPresets:
input_data[preset_input.name] = preset_input.data
if not is_credentials_field_name(preset_input.name):
input_data[preset_input.name] = preset_input.data
else:
input_credentials[preset_input.name] = (
CredentialsMetaInput.model_validate(preset_input.data)
)
return cls(
id=preset.id,
user_id=preset.userId,
updated_at=preset.updatedAt,
graph_id=preset.agentGraphId,
graph_version=preset.agentGraphVersion,
@@ -235,6 +290,8 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
description=preset.description,
is_active=preset.isActive,
inputs=input_data,
credentials=input_credentials,
webhook_id=preset.webhookId,
)
@@ -276,6 +333,3 @@ class LibraryAgentUpdateRequest(pydantic.BaseModel):
is_archived: Optional[bool] = pydantic.Field(
default=None, description="Archive the agent"
)
is_deleted: Optional[bool] = pydantic.Field(
default=None, description="Delete the agent"
)

View File

@@ -1,13 +1,19 @@
import logging
from typing import Optional
from typing import Any, Optional
import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
from fastapi.responses import JSONResponse
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, status
from fastapi.responses import Response
from pydantic import BaseModel, Field
import backend.server.v2.library.db as library_db
import backend.server.v2.library.model as library_model
import backend.server.v2.store.exceptions as store_exceptions
from backend.data.graph import get_graph
from backend.data.model import CredentialsMetaInput
from backend.executor.utils import make_node_credentials_input_map
from backend.integrations.webhooks.utils import setup_webhook_for_block
from backend.util.exceptions import NotFoundError
logger = logging.getLogger(__name__)
@@ -71,10 +77,10 @@ async def list_library_agents(
page_size=page_size,
)
except Exception as e:
logger.exception("Listing library agents failed for user %s: %s", user_id, e)
logger.error(f"Could not list library agents for user #{user_id}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": str(e), "hint": "Inspect database connectivity."},
detail=str(e),
) from e
@@ -86,6 +92,23 @@ async def get_library_agent(
return await library_db.get_library_agent(id=library_agent_id, user_id=user_id)
@router.get("/by-graph/{graph_id}")
async def get_library_agent_by_graph_id(
    graph_id: str,
    version: Optional[int] = Query(default=None),
    user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> library_model.LibraryAgent:
    # Returns the authenticated user's library agent for the given graph,
    # optionally narrowed to one graph version; responds with 404 if the user
    # has no such agent.
    # (Kept as comments: a docstring would alter the generated OpenAPI docs.)
    found = await library_db.get_library_agent_by_graph_id(
        user_id, graph_id, version
    )
    if found:
        return found

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Library agent for graph #{graph_id} and user #{user_id} not found",
    )
@router.get(
"/marketplace/{store_listing_version_id}",
summary="Get Agent By Store ID",
@@ -103,18 +126,16 @@ async def get_library_agent_by_store_listing_version_id(
return await library_db.get_library_agent_by_store_version_id(
store_listing_version_id, user_id
)
except Exception as e:
logger.exception(
"Retrieving library agent by store version failed for user %s: %s",
user_id,
e,
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
)
except Exception as e:
logger.error(f"Could not fetch library agent from store version ID: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"message": str(e),
"hint": "Check if the store listing ID is valid.",
},
detail=str(e),
) from e
@@ -152,26 +173,20 @@ async def add_marketplace_agent_to_library(
user_id=user_id,
)
except store_exceptions.AgentNotFoundError:
except store_exceptions.AgentNotFoundError as e:
logger.warning(
"Store listing version %s not found when adding to library",
store_listing_version_id,
)
raise HTTPException(
status_code=404,
detail={
"message": f"Store listing version {store_listing_version_id} not found",
"hint": "Confirm the ID provided.",
},
f"Could not find store listing version {store_listing_version_id} "
"to add to library"
)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
except store_exceptions.DatabaseError as e:
logger.exception("Database error whilst adding agent to library: %s", e)
logger.error(f"Database error while adding agent to library: {e}", e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": str(e), "hint": "Inspect DB logs for details."},
) from e
except Exception as e:
logger.exception("Unexpected error while adding agent to library: %s", e)
logger.error(f"Unexpected error while adding agent to library: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
@@ -181,12 +196,11 @@ async def add_marketplace_agent_to_library(
) from e
@router.put(
@router.patch(
"/{library_agent_id}",
summary="Update Library Agent",
status_code=status.HTTP_204_NO_CONTENT,
responses={
204: {"description": "Agent updated successfully"},
200: {"description": "Agent updated successfully"},
500: {"description": "Server error"},
},
)
@@ -194,7 +208,7 @@ async def update_library_agent(
library_agent_id: str,
payload: library_model.LibraryAgentUpdateRequest,
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> JSONResponse:
) -> library_model.LibraryAgent:
"""
Update the library agent with the given fields.
@@ -203,39 +217,75 @@ async def update_library_agent(
payload: Fields to update (auto_update_version, is_favorite, etc.).
user_id: ID of the authenticated user.
Returns:
204 (No Content) on success.
Raises:
HTTPException(500): If a server/database error occurs.
"""
try:
await library_db.update_library_agent(
return await library_db.update_library_agent(
library_agent_id=library_agent_id,
user_id=user_id,
auto_update_version=payload.auto_update_version,
is_favorite=payload.is_favorite,
is_archived=payload.is_archived,
is_deleted=payload.is_deleted,
)
return JSONResponse(
status_code=status.HTTP_204_NO_CONTENT,
content={"message": "Agent updated successfully"},
)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
) from e
except store_exceptions.DatabaseError as e:
logger.exception("Database error while updating library agent: %s", e)
logger.error(f"Database error while updating library agent: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": str(e), "hint": "Verify DB connection."},
) from e
except Exception as e:
logger.exception("Unexpected error while updating library agent: %s", e)
logger.error(f"Unexpected error while updating library agent: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": str(e), "hint": "Check server logs."},
) from e
@router.delete(
    "/{library_agent_id}",
    summary="Delete Library Agent",
    responses={
        204: {"description": "Agent deleted successfully"},
        404: {"description": "Agent not found"},
        500: {"description": "Server error"},
    },
)
async def delete_library_agent(
    library_agent_id: str,
    user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> Response:
    """
    Soft-delete the specified library agent.

    Args:
        library_agent_id: ID of the library agent to delete.
        user_id: ID of the authenticated user.

    Returns:
        204 No Content if successful.

    Raises:
        HTTPException(404): If the agent does not exist.
        HTTPException(500): If a server/database error occurs.
    """
    try:
        # `soft_delete` is left at its default, so the record is only flagged
        await library_db.delete_library_agent(
            library_agent_id=library_agent_id, user_id=user_id
        )
    except NotFoundError as e:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(e),
        ) from e

    # Nothing to report back beyond the status code
    return Response(status_code=status.HTTP_204_NO_CONTENT)
@router.post("/{library_agent_id}/fork", summary="Fork Library Agent")
async def fork_library_agent(
library_agent_id: str,
@@ -245,3 +295,81 @@ async def fork_library_agent(
library_agent_id=library_agent_id,
user_id=user_id,
)
class TriggeredPresetSetupParams(BaseModel):
    # Request body of the `setup-trigger` endpoint.
    # NOTE: documented with comments rather than a docstring, so that the
    # generated JSON schema for this model stays unchanged.

    # Name and (optional) description for the preset that will be created
    name: str
    description: str = ""

    # Input values for the graph's webhook trigger node
    trigger_config: dict[str, Any]

    # Credentials to use for the agent; defaults to none
    agent_credentials: dict[str, CredentialsMetaInput] = Field(default_factory=dict)
@router.post("/{library_agent_id}/setup-trigger")
async def setup_trigger(
    library_agent_id: str = Path(..., description="ID of the library agent"),
    params: TriggeredPresetSetupParams = Body(),
    user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> library_model.LibraryAgentPreset:
    """
    Sets up a webhook-triggered `LibraryAgentPreset` for a `LibraryAgent`.
    Returns the correspondingly created `LibraryAgentPreset` with `webhook_id` set.
    """
    # `get_library_agent` signals a missing agent by raising NotFoundError
    # (it never returns None), so that is what has to be mapped to a 404.
    try:
        library_agent = await library_db.get_library_agent(
            id=library_agent_id, user_id=user_id
        )
    except NotFoundError as e:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Library agent #{library_agent_id} not found",
        ) from e

    graph = await get_graph(
        library_agent.graph_id, version=library_agent.graph_version, user_id=user_id
    )
    if not graph:
        # The library agent exists, but its graph can't be loaded for this user
        raise HTTPException(
            status.HTTP_410_GONE,
            f"Graph #{library_agent.graph_id} not accessible (anymore)",
        )
    if not (trigger_node := graph.webhook_input_node):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Graph #{library_agent.graph_id} does not have a webhook node",
        )

    # Combine the submitted trigger config with the credentials inputs that
    # map to the trigger node, if any
    trigger_config_with_credentials = {
        **params.trigger_config,
        **(
            make_node_credentials_input_map(graph, params.agent_credentials).get(
                trigger_node.id
            )
            or {}
        ),
    }

    new_webhook, feedback = await setup_webhook_for_block(
        user_id=user_id,
        trigger_block=trigger_node.block,
        trigger_config=trigger_config_with_credentials,
    )
    if not new_webhook:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Could not set up webhook: {feedback}",
        )

    new_preset = await library_db.create_preset(
        user_id=user_id,
        preset=library_model.LibraryAgentPresetCreatable(
            graph_id=library_agent.graph_id,
            graph_version=library_agent.graph_version,
            name=params.name,
            description=params.description,
            inputs=trigger_config_with_credentials,
            credentials=params.agent_credentials,
            webhook_id=new_webhook.id,
            is_active=True,
        ),
    )
    return new_preset

View File

@@ -1,19 +1,23 @@
import logging
from typing import Annotated, Any, Optional
from typing import Any, Optional
import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
import backend.server.v2.library.db as db
import backend.server.v2.library.model as models
from backend.executor.utils import add_graph_execution
from backend.data.graph import get_graph
from backend.data.integrations import get_webhook
from backend.executor.utils import add_graph_execution, make_node_credentials_input_map
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks import get_webhook_manager
from backend.integrations.webhooks.utils import setup_webhook_for_block
from backend.util.exceptions import NotFoundError
logger = logging.getLogger(__name__)
router = APIRouter(
tags=["presets"],
)
credentials_manager = IntegrationCredentialsManager()
router = APIRouter(tags=["presets"])
@router.get(
@@ -51,11 +55,7 @@ async def list_presets(
except Exception as e:
logger.exception("Failed to list presets for user %s: %s", user_id, e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"message": str(e),
"hint": "Ensure the presets DB table is accessible.",
},
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
@@ -83,21 +83,21 @@ async def get_preset(
"""
try:
preset = await db.get_preset(user_id, preset_id)
if not preset:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Preset {preset_id} not found",
)
return preset
except Exception as e:
logger.exception(
"Error retrieving preset %s for user %s: %s", preset_id, user_id, e
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": str(e), "hint": "Validate preset ID and retry."},
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
if not preset:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Preset #{preset_id} not found",
)
return preset
@router.post(
"/presets",
@@ -134,8 +134,7 @@ async def create_preset(
except Exception as e:
logger.exception("Preset creation failed for user %s: %s", user_id, e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": str(e), "hint": "Check preset payload format."},
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
@@ -163,17 +162,85 @@ async def update_preset(
Raises:
HTTPException: If an error occurs while updating the preset.
"""
current = await get_preset(preset_id, user_id=user_id)
if not current:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Preset #{preset_id} not found")
graph = await get_graph(
current.graph_id,
current.graph_version,
user_id=user_id,
)
if not graph:
raise HTTPException(
status.HTTP_410_GONE,
f"Graph #{current.graph_id} not accessible (anymore)",
)
trigger_inputs_updated, new_webhook, feedback = False, None, None
if (trigger_node := graph.webhook_input_node) and (
preset.inputs is not None and preset.credentials is not None
):
trigger_config_with_credentials = {
**preset.inputs,
**(
make_node_credentials_input_map(graph, preset.credentials).get(
trigger_node.id
)
or {}
),
}
new_webhook, feedback = await setup_webhook_for_block(
user_id=user_id,
trigger_block=graph.webhook_input_node.block,
trigger_config=trigger_config_with_credentials,
for_preset_id=preset_id,
)
trigger_inputs_updated = True
if not new_webhook:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Could not update trigger configuration: {feedback}",
)
try:
return await db.update_preset(
user_id=user_id, preset_id=preset_id, preset=preset
updated = await db.update_preset(
user_id=user_id,
preset_id=preset_id,
inputs=preset.inputs,
credentials=preset.credentials,
name=preset.name,
description=preset.description,
is_active=preset.is_active,
)
except Exception as e:
logger.exception("Preset update failed for user %s: %s", user_id, e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": str(e), "hint": "Check preset data and try again."},
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
# Update the webhook as well, if necessary
if trigger_inputs_updated:
updated = await db.set_preset_webhook(
user_id, preset_id, new_webhook.id if new_webhook else None
)
# Clean up webhook if it is now unused
if current.webhook_id and (
current.webhook_id != (new_webhook.id if new_webhook else None)
):
current_webhook = await get_webhook(current.webhook_id)
credentials = (
await credentials_manager.get(user_id, current_webhook.credentials_id)
if current_webhook.credentials_id
else None
)
await get_webhook_manager(
current_webhook.provider
).prune_webhook_if_dangling(user_id, current_webhook.id, credentials)
return updated
@router.delete(
"/presets/{preset_id}",
@@ -195,6 +262,28 @@ async def delete_preset(
Raises:
HTTPException: If an error occurs while deleting the preset.
"""
preset = await db.get_preset(user_id, preset_id)
if not preset:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Preset #{preset_id} not found for user #{user_id}",
)
# Detach and clean up the attached webhook, if any
if preset.webhook_id:
webhook = await get_webhook(preset.webhook_id)
await db.set_preset_webhook(user_id, preset_id, None)
# Clean up webhook if it is now unused
credentials = (
await credentials_manager.get(user_id, webhook.credentials_id)
if webhook.credentials_id
else None
)
await get_webhook_manager(webhook.provider).prune_webhook_if_dangling(
user_id, webhook.id, credentials
)
try:
await db.delete_preset(user_id, preset_id)
except Exception as e:
@@ -203,7 +292,7 @@ async def delete_preset(
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": str(e), "hint": "Ensure preset exists before deleting."},
detail=str(e),
)
@@ -214,24 +303,20 @@ async def delete_preset(
description="Execute a preset with the given graph and node input for the current user.",
)
async def execute_preset(
graph_id: str,
graph_version: int,
preset_id: str,
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
inputs: dict[str, Any] = Body(..., embed=True, default_factory=dict),
) -> dict[str, Any]: # FIXME: add proper return type
"""
Execute a preset given graph parameters, returning the execution ID on success.
Args:
graph_id (str): ID of the graph to execute.
graph_version (int): Version of the graph to execute.
preset_id (str): ID of the preset to execute.
node_input (Dict[Any, Any]): Input data for the node.
user_id (str): ID of the authenticated user.
inputs (dict[str, Any]): Optionally, additional input data for the graph execution.
Returns:
Dict[str, Any]: A response containing the execution ID.
{id: graph_exec_id}: A response containing the execution ID.
Raises:
HTTPException: If the preset is not found or an error occurs while executing the preset.
@@ -241,18 +326,18 @@ async def execute_preset(
if not preset:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Preset not found",
detail=f"Preset #{preset_id} not found",
)
# Merge input overrides with preset inputs
merged_node_input = preset.inputs | node_input
merged_node_input = preset.inputs | inputs
execution = await add_graph_execution(
graph_id=graph_id,
user_id=user_id,
inputs=merged_node_input,
graph_id=preset.graph_id,
graph_version=preset.graph_version,
preset_id=preset_id,
graph_version=graph_version,
inputs=merged_node_input,
)
logger.debug(f"Execution added: {execution} with input: {merged_node_input}")
@@ -263,9 +348,6 @@ async def execute_preset(
except Exception as e:
logger.exception("Preset execution failed for user %s: %s", user_id, e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"message": str(e),
"hint": "Review preset configuration and graph ID.",
},
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
)

View File

@@ -50,6 +50,8 @@ async def test_get_library_agents_success(
creator_name="Test Creator",
creator_image_url="",
input_schema={"type": "object", "properties": {}},
credentials_input_schema={"type": "object", "properties": {}},
has_external_trigger=False,
status=library_model.LibraryAgentStatus.COMPLETED,
new_output=False,
can_access_graph=True,
@@ -66,6 +68,8 @@ async def test_get_library_agents_success(
creator_name="Test Creator",
creator_image_url="",
input_schema={"type": "object", "properties": {}},
credentials_input_schema={"type": "object", "properties": {}},
has_external_trigger=False,
status=library_model.LibraryAgentStatus.COMPLETED,
new_output=False,
can_access_graph=False,
@@ -117,26 +121,57 @@ def test_get_library_agents_error(mocker: pytest_mock.MockFixture):
)
@pytest.mark.skip(reason="Mocker Not implemented")
def test_add_agent_to_library_success(mocker: pytest_mock.MockFixture):
mock_db_call = mocker.patch("backend.server.v2.library.db.add_agent_to_library")
mock_db_call.return_value = None
mock_library_agent = library_model.LibraryAgent(
id="test-library-agent-id",
graph_id="test-agent-1",
graph_version=1,
name="Test Agent 1",
description="Test Description 1",
image_url=None,
creator_name="Test Creator",
creator_image_url="",
input_schema={"type": "object", "properties": {}},
credentials_input_schema={"type": "object", "properties": {}},
has_external_trigger=False,
status=library_model.LibraryAgentStatus.COMPLETED,
new_output=False,
can_access_graph=True,
is_latest_version=True,
updated_at=FIXED_NOW,
)
response = client.post("/agents/test-version-id")
mock_db_call = mocker.patch(
"backend.server.v2.library.db.add_store_agent_to_library"
)
mock_db_call.return_value = mock_library_agent
response = client.post(
"/agents", json={"store_listing_version_id": "test-version-id"}
)
assert response.status_code == 201
# Verify the response contains the library agent data
data = library_model.LibraryAgent.model_validate(response.json())
assert data.id == "test-library-agent-id"
assert data.graph_id == "test-agent-1"
mock_db_call.assert_called_once_with(
store_listing_version_id="test-version-id", user_id="test-user-id"
)
@pytest.mark.skip(reason="Mocker Not implemented")
def test_add_agent_to_library_error(mocker: pytest_mock.MockFixture):
mock_db_call = mocker.patch("backend.server.v2.library.db.add_agent_to_library")
mock_db_call = mocker.patch(
"backend.server.v2.library.db.add_store_agent_to_library"
)
mock_db_call.side_effect = Exception("Test error")
response = client.post("/agents/test-version-id")
response = client.post(
"/agents", json={"store_listing_version_id": "test-version-id"}
)
assert response.status_code == 500
assert response.json()["detail"] == "Failed to add agent to library"
assert "detail" in response.json() # Verify error response structure
mock_db_call.assert_called_once_with(
store_listing_version_id="test-version-id", user_id="test-user-id"
)

View File

@@ -259,8 +259,8 @@ def test_ask_otto_unauthenticated(mocker: pytest_mock.MockFixture) -> None:
}
response = client.post("/ask", json=request_data)
# When auth is disabled and Otto API URL is not configured, we get 503
assert response.status_code == 503
# When auth is disabled and Otto API URL is not configured, we get 502 (wrapped from 503)
assert response.status_code == 502
# Restore the override
app.dependency_overrides[autogpt_libs.auth.middleware.auth_middleware] = (

View File

@@ -93,6 +93,14 @@ async def test_get_store_agent_details(mocker):
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
# Mock Profile prisma call
mock_profile = mocker.MagicMock()
mock_profile.userId = "user-id-123"
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
mock_profile_db.return_value.find_first = mocker.AsyncMock(
return_value=mock_profile
)
# Mock StoreListing prisma call - this is what was missing
mock_store_listing_db = mocker.patch("prisma.models.StoreListing.prisma")
mock_store_listing_db.return_value.find_first = mocker.AsyncMock(

View File

@@ -34,6 +34,20 @@ class StorageUploadError(MediaUploadError):
pass
class VirusDetectedError(MediaUploadError):
    """Signals that a virus was found in an uploaded file.

    The detected threat's name is stored on ``threat_name``.
    """

    def __init__(self, threat_name: str, message: str | None = None):
        # A missing or empty message falls back to a generated one.
        if not message:
            message = f"Virus detected: {threat_name}"
        self.threat_name = threat_name
        super().__init__(message)
class VirusScanError(MediaUploadError):
"""Raised when virus scanning fails"""
pass
class StoreError(Exception):
"""Base exception for store-related errors"""

View File

@@ -8,6 +8,7 @@ from google.cloud import storage
import backend.server.v2.store.exceptions
from backend.util.exceptions import MissingConfigError
from backend.util.settings import Settings
from backend.util.virus_scanner import scan_content_safe
logger = logging.getLogger(__name__)
@@ -67,7 +68,7 @@ async def upload_media(
# Validate file signature/magic bytes
if file.content_type in ALLOWED_IMAGE_TYPES:
# Check image file signatures
if content.startswith(b"\xFF\xD8\xFF"): # JPEG
if content.startswith(b"\xff\xd8\xff"): # JPEG
if file.content_type != "image/jpeg":
raise backend.server.v2.store.exceptions.InvalidFileTypeError(
"File signature does not match content type"
@@ -175,6 +176,7 @@ async def upload_media(
blob.content_type = content_type
file_bytes = await file.read()
await scan_content_safe(file_bytes, filename=unique_filename)
blob.upload_from_string(file_bytes, content_type=content_type)
public_url = blob.public_url

View File

@@ -12,6 +12,7 @@ from autogpt_libs.auth.depends import auth_middleware, get_user_id
import backend.data.block
import backend.data.graph
import backend.server.v2.store.db
import backend.server.v2.store.exceptions
import backend.server.v2.store.image_gen
import backend.server.v2.store.media
import backend.server.v2.store.model
@@ -589,6 +590,25 @@ async def upload_submission_media(
user_id=user_id, file=file
)
return media_url
except backend.server.v2.store.exceptions.VirusDetectedError as e:
logger.warning(f"Virus detected in uploaded file: {e.threat_name}")
return fastapi.responses.JSONResponse(
status_code=400,
content={
"detail": f"File rejected due to virus detection: {e.threat_name}",
"error_type": "virus_detected",
"threat_name": e.threat_name,
},
)
except backend.server.v2.store.exceptions.VirusScanError as e:
logger.error(f"Virus scanning failed: {str(e)}")
return fastapi.responses.JSONResponse(
status_code=503,
content={
"detail": "Virus scanning service unavailable. Please try again later.",
"error_type": "virus_scan_failed",
},
)
except Exception:
logger.exception("Exception occurred whilst uploading submission media")
return fastapi.responses.JSONResponse(

View File

@@ -19,7 +19,9 @@ from backend.server.ws_api import (
@pytest.fixture
def mock_websocket() -> AsyncMock:
return AsyncMock(spec=WebSocket)
mock = AsyncMock(spec=WebSocket)
mock.query_params = {} # Add query_params attribute for authentication
return mock
@pytest.fixture
@@ -29,8 +31,13 @@ def mock_manager() -> AsyncMock:
@pytest.mark.asyncio
async def test_websocket_router_subscribe(
mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot
mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot, mocker
) -> None:
# Mock the authenticate_websocket function to ensure it returns a valid user_id
mocker.patch(
"backend.server.ws_api.authenticate_websocket", return_value=DEFAULT_USER_ID
)
mock_websocket.receive_text.side_effect = [
WSMessage(
method=WSMethod.SUBSCRIBE_GRAPH_EXEC,
@@ -70,8 +77,13 @@ async def test_websocket_router_subscribe(
@pytest.mark.asyncio
async def test_websocket_router_unsubscribe(
mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot
mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot, mocker
) -> None:
# Mock the authenticate_websocket function to ensure it returns a valid user_id
mocker.patch(
"backend.server.ws_api.authenticate_websocket", return_value=DEFAULT_USER_ID
)
mock_websocket.receive_text.side_effect = [
WSMessage(
method=WSMethod.UNSUBSCRIBE,
@@ -108,8 +120,13 @@ async def test_websocket_router_unsubscribe(
@pytest.mark.asyncio
async def test_websocket_router_invalid_method(
mock_websocket: AsyncMock, mock_manager: AsyncMock
mock_websocket: AsyncMock, mock_manager: AsyncMock, mocker
) -> None:
# Mock the authenticate_websocket function to ensure it returns a valid user_id
mocker.patch(
"backend.server.ws_api.authenticate_websocket", return_value=DEFAULT_USER_ID
)
mock_websocket.receive_text.side_effect = [
WSMessage(method=WSMethod.GRAPH_EXECUTION_EVENT).model_dump_json(),
WebSocketDisconnect(),

View File

@@ -2,7 +2,17 @@ import functools
import logging
import os
import time
from typing import Any, Awaitable, Callable, Coroutine, ParamSpec, Tuple, TypeVar
from typing import (
Any,
Awaitable,
Callable,
Coroutine,
Literal,
ParamSpec,
Tuple,
TypeVar,
overload,
)
from pydantic import BaseModel
@@ -72,37 +82,115 @@ def async_time_measured(
return async_wrapper
def error_logged(func: Callable[P, T]) -> Callable[P, T | None]:
@overload
def error_logged(
*, swallow: Literal[True]
) -> Callable[[Callable[P, T]], Callable[P, T | None]]: ...
@overload
def error_logged(
*, swallow: Literal[False]
) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
@overload
def error_logged() -> Callable[[Callable[P, T]], Callable[P, T | None]]: ...
def error_logged(
*, swallow: bool = True
) -> (
Callable[[Callable[P, T]], Callable[P, T | None]]
| Callable[[Callable[P, T]], Callable[P, T]]
):
"""
Decorator to suppress and log any exceptions raised by a function.
Decorator to log any exceptions raised by a function, with optional suppression.
Args:
swallow: Whether to suppress the exception (True) or re-raise it (False). Default is True.
Usage:
@error_logged() # Default behavior (swallow errors)
@error_logged(swallow=False) # Log and re-raise errors
"""
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | None:
try:
return func(*args, **kwargs)
except Exception as e:
logger.exception(
f"Error when calling function {func.__name__} with arguments {args} {kwargs}: {e}"
)
def decorator(f: Callable[P, T]) -> Callable[P, T | None]:
@functools.wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | None:
try:
return f(*args, **kwargs)
except Exception as e:
logger.exception(
f"Error when calling function {f.__name__} with arguments {args} {kwargs}: {e}"
)
if not swallow:
raise
return None
return wrapper
return wrapper
return decorator
@overload
def async_error_logged(
func: Callable[P, Coroutine[Any, Any, T]],
) -> Callable[P, Coroutine[Any, Any, T | None]]:
*, swallow: Literal[True]
) -> Callable[
[Callable[P, Coroutine[Any, Any, T]]], Callable[P, Coroutine[Any, Any, T | None]]
]: ...
@overload
def async_error_logged(
*, swallow: Literal[False]
) -> Callable[
[Callable[P, Coroutine[Any, Any, T]]], Callable[P, Coroutine[Any, Any, T]]
]: ...
@overload
def async_error_logged() -> Callable[
[Callable[P, Coroutine[Any, Any, T]]],
Callable[P, Coroutine[Any, Any, T | None]],
]: ...
def async_error_logged(*, swallow: bool = True) -> (
Callable[
[Callable[P, Coroutine[Any, Any, T]]],
Callable[P, Coroutine[Any, Any, T | None]],
]
| Callable[
[Callable[P, Coroutine[Any, Any, T]]], Callable[P, Coroutine[Any, Any, T]]
]
):
"""
Decorator to suppress and log any exceptions raised by an async function.
Decorator to log any exceptions raised by an async function, with optional suppression.
Args:
swallow: Whether to suppress the exception (True) or re-raise it (False). Default is True.
Usage:
@async_error_logged() # Default behavior (swallow errors)
@async_error_logged(swallow=False) # Log and re-raise errors
"""
@functools.wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | None:
try:
return await func(*args, **kwargs)
except Exception as e:
logger.exception(
f"Error when calling async function {func.__name__} with arguments {args} {kwargs}: {e}"
)
def decorator(
f: Callable[P, Coroutine[Any, Any, T]]
) -> Callable[P, Coroutine[Any, Any, T | None]]:
@functools.wraps(f)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | None:
try:
return await f(*args, **kwargs)
except Exception as e:
logger.exception(
f"Error when calling async function {f.__name__} with arguments {args} {kwargs}: {e}"
)
if not swallow:
raise
return None
return wrapper
return wrapper
return decorator

View File

@@ -0,0 +1,74 @@
import time
import pytest
from backend.util.decorator import async_error_logged, error_logged, time_measured
@time_measured
def example_function(a: int, b: int, c: int) -> int:
time.sleep(0.5)
return a + b + c
@error_logged(swallow=True)
def example_function_with_error_swallowed(a: int, b: int, c: int) -> int:
raise ValueError("This error should be swallowed")
@error_logged(swallow=False)
def example_function_with_error_not_swallowed(a: int, b: int, c: int) -> int:
raise ValueError("This error should NOT be swallowed")
@async_error_logged(swallow=True)
async def async_function_with_error_swallowed() -> int:
raise ValueError("This async error should be swallowed")
@async_error_logged(swallow=False)
async def async_function_with_error_not_swallowed() -> int:
raise ValueError("This async error should NOT be swallowed")
def test_timer_decorator():
"""Test that the time_measured decorator correctly measures execution time."""
info, res = example_function(1, 2, 3)
assert info.cpu_time >= 0
assert info.wall_time >= 0.4
assert res == 6
def test_error_decorator_swallow_true():
"""Test that error_logged(swallow=True) logs and swallows errors."""
res = example_function_with_error_swallowed(1, 2, 3)
assert res is None
def test_error_decorator_swallow_false():
"""Test that error_logged(swallow=False) logs errors but re-raises them."""
with pytest.raises(ValueError, match="This error should NOT be swallowed"):
example_function_with_error_not_swallowed(1, 2, 3)
def test_async_error_decorator_swallow_true():
    """async_error_logged(swallow=True) must log the error and yield None instead of raising."""
    import asyncio

    # The decorated function already returns a coroutine, so it can be
    # handed to asyncio.run directly.
    outcome = asyncio.run(async_function_with_error_swallowed())
    assert outcome is None
def test_async_error_decorator_swallow_false():
    """async_error_logged(swallow=False) must log the error and then re-raise it."""
    import asyncio

    expected_message = "This async error should NOT be swallowed"
    with pytest.raises(ValueError, match=expected_message):
        asyncio.run(async_function_with_error_not_swallowed())

View File

@@ -10,6 +10,10 @@ class NeedConfirmation(Exception):
"""The user must explicitly confirm that they want to proceed"""
class NotAuthorizedError(ValueError):
"""The user is not authorized to perform the requested operation"""
class InsufficientBalanceError(ValueError):
user_id: str
message: str

View File

@@ -9,6 +9,7 @@ from urllib.parse import urlparse
from backend.util.request import Requests
from backend.util.type import MediaFileType
from backend.util.virus_scanner import scan_content_safe
TEMP_DIR = Path(tempfile.gettempdir()).resolve()
@@ -105,7 +106,11 @@ async def store_media_file(
extension = _extension_from_mime(mime_type)
filename = f"{uuid.uuid4()}{extension}"
target_path = _ensure_inside_base(base_path / filename, base_path)
target_path.write_bytes(base64.b64decode(b64_content))
content = base64.b64decode(b64_content)
# Virus scan the base64 content before writing
await scan_content_safe(content, filename=filename)
target_path.write_bytes(content)
elif file.startswith(("http://", "https://")):
# URL
@@ -115,6 +120,9 @@ async def store_media_file(
# Download and save
resp = await Requests().get(file)
# Virus scan the downloaded content before writing
await scan_content_safe(resp.content, filename=filename)
target_path.write_bytes(resp.content)
else:

View File

@@ -14,8 +14,37 @@ def to_dict(data) -> dict:
return jsonable_encoder(data)
def dumps(data) -> str:
return json.dumps(to_dict(data))
def dumps(data: Any, *args: Any, **kwargs: Any) -> str:
"""
Serialize data to JSON string with automatic conversion of Pydantic models and complex types.
This function converts the input data to a JSON-serializable format using FastAPI's
jsonable_encoder before dumping to JSON. It handles Pydantic models, complex types,
and ensures proper serialization.
Parameters
----------
data : Any
The data to serialize. Can be any type including Pydantic models, dicts, lists, etc.
*args : Any
Additional positional arguments passed to json.dumps()
**kwargs : Any
Additional keyword arguments passed to json.dumps() (e.g., indent, separators)
Returns
-------
str
JSON string representation of the data
Examples
--------
>>> dumps({"name": "Alice", "age": 30})
'{"name": "Alice", "age": 30}'
>>> dumps(pydantic_model_instance, indent=2)
'{\n "field1": "value1",\n "field2": "value2"\n}'
"""
return json.dumps(to_dict(data), *args, **kwargs)
T = TypeVar("T")

View File

@@ -1,4 +1,4 @@
from logging import Logger
import logging
from backend.util.settings import AppEnvironment, BehaveAs, Settings
@@ -6,8 +6,6 @@ settings = Settings()
def configure_logging():
import logging
import autogpt_libs.logging.config
if (
@@ -25,7 +23,7 @@ def configure_logging():
class TruncatedLogger:
def __init__(
self,
logger: Logger,
logger: logging.Logger,
prefix: str = "",
metadata: dict | None = None,
max_length: int = 1000,
@@ -65,3 +63,13 @@ class TruncatedLogger:
if len(text) > self.max_length:
text = text[: self.max_length] + "..."
return text
class PrefixFilter(logging.Filter):
    """Logging filter that prepends a fixed prefix to every record's message.

    The filter never drops a record: ``filter`` always returns ``True``.
    """

    def __init__(self, prefix: str):
        super().__init__()
        self.prefix = prefix

    def filter(self, record):
        # Rewrites record.msg in place as "<prefix> <original msg>".
        record.msg = "{} {}".format(self.prefix, record.msg)
        return True

View File

@@ -0,0 +1,206 @@
from copy import deepcopy
from typing import Any
from tiktoken import encoding_for_model
from backend.util import json
# ---------------------------------------------------------------------------#
# INTERNAL UTILITIES #
# ---------------------------------------------------------------------------#
def _tok_len(text: str, enc) -> int:
"""True token length of *text* in tokenizer *enc* (no wrapper cost)."""
return len(enc.encode(text))
def _msg_tokens(msg: dict, enc) -> int:
    """
    Estimate the token cost of one chat message: 3 wrapper tokens, plus 1
    more when the message has a "name" key, plus the tokenised length of its
    content (missing or falsy content counts as the empty string).
    """
    overhead = 4 if "name" in msg else 3
    body = msg.get("content") or ""
    return overhead + _tok_len(body, enc)
def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
"""
Return *text* shortened to ≈max_tok tokens by keeping the head & tail
and inserting an ellipsis token in the middle.
"""
ids = enc.encode(text)
if len(ids) <= max_tok:
return text # nothing to do
# Split the allowance between the two ends:
head = max_tok // 2 - 1 # -1 for the ellipsis
tail = max_tok - head - 1
mid = enc.encode("")
return enc.decode(ids[:head] + mid + ids[-tail:])
# ---------------------------------------------------------------------------#
# PUBLIC API #
# ---------------------------------------------------------------------------#
def compress_prompt(
    messages: list[dict],
    target_tokens: int,
    *,
    model: str = "gpt-4o",
    reserve: int = 2_048,
    start_cap: int = 8_192,
    floor_cap: int = 128,
    lossy_ok: bool = True,
) -> list[dict]:
    """
    Shrink *messages* so that::

        token_count(prompt) + reserve ≤ target_tokens

    Strategy
    --------
    1. **Token-aware truncation** – progressively halve a per-message cap
       (`start_cap`, `start_cap/2`, … `floor_cap`) and apply it to the
       *content* of every message except the first and last. Tool shells
       are included: we keep the envelope but shorten huge payloads.
    2. **Middle-out deletion** – if still over the limit, delete whole
       messages working outward from the centre, **skipping** any message
       that contains ``tool_calls`` or has ``role == "tool"``.
    3. **Last-chance trim** – if still too big, truncate the *first* and
       *last* message bodies down to `floor_cap` tokens.
    4. If the prompt is *still* too large:
         • raise ``ValueError`` when ``lossy_ok == False``
         • return the partially-trimmed prompt when ``lossy_ok == True``
           (default)

    Parameters
    ----------
    messages        Complete chat history (will be deep-copied).
    target_tokens   Hard ceiling for prompt size **excluding** the model's
                    forthcoming answer.
    model           Model name; passed to tiktoken to pick the right
                    tokenizer (gpt-4o → 'o200k_base', others fallback).
    reserve         How many tokens you want to leave available for that
                    answer (`max_tokens` in your subsequent completion call).
    start_cap       Initial per-message truncation ceiling (tokens).
    floor_cap       Lowest cap we'll accept before moving to deletions.
    lossy_ok        If *True* return best-effort prompt instead of raising
                    after all trim passes have been exhausted.

    Returns
    -------
    list[dict]  A *new* messages list that abides by the rules above.

    Raises
    ------
    ValueError  When the prompt still exceeds the budget after every pass
                and ``lossy_ok`` is False.
    """
    enc = encoding_for_model(model)  # best-match tokenizer
    msgs = deepcopy(messages)  # never mutate caller

    def total_tokens() -> int:
        """Current size of *msgs* in tokens."""
        return sum(_msg_tokens(m, enc) for m in msgs)

    original_token_count = total_tokens()
    if original_token_count + reserve <= target_tokens:
        return msgs

    # ---- STEP 0 : normalise content --------------------------------------
    # Convert non-string payloads to strings so token counting is coherent.
    for m in msgs[1:-1]:  # keep the first & last intact
        if not isinstance(m.get("content"), str) and m.get("content") is not None:
            # Guard against pathological blobs. NOTE: the guard compares
            # characters while the truncation cap below is in tokens.
            content_str = json.dumps(m["content"], separators=(",", ":"))
            if len(content_str) > 20_000:
                content_str = _truncate_middle_tokens(content_str, enc, 20_000)
            m["content"] = content_str

    # ---- STEP 1 : token-aware truncation ---------------------------------
    cap = start_cap
    while total_tokens() + reserve > target_tokens and cap >= floor_cap:
        for m in msgs[1:-1]:  # keep first & last intact
            if _tok_len(m.get("content") or "", enc) > cap:
                m["content"] = _truncate_middle_tokens(m["content"], enc, cap)
        cap //= 2  # tighten the screw

    # ---- STEP 2 : middle-out deletion -----------------------------------
    while total_tokens() + reserve > target_tokens and len(msgs) > 2:
        centre = len(msgs) // 2
        # Build a centre-out index walk: centre, centre+1, centre-1, ...
        # The two sides can differ in length, so walk up to the longer one;
        # pairing them with zip() would silently drop the tail of the longer
        # side and leave deletable messages unreachable.
        right = range(centre + 1, len(msgs) - 1)
        left = range(centre - 1, 0, -1)
        order = [centre]
        for step in range(max(len(right), len(left))):
            if step < len(right):
                order.append(right[step])
            if step < len(left):
                order.append(left[step])

        removed = False
        for i in order:
            msg = msgs[i]
            if "tool_calls" in msg or msg.get("role") == "tool":
                continue  # protect tool shells
            del msgs[i]
            removed = True
            break
        if not removed:  # nothing more we can drop
            break

    # ---- STEP 3 : final safety-net trim on first & last ------------------
    cap = start_cap
    while total_tokens() + reserve > target_tokens and cap >= floor_cap:
        for idx in (0, -1):  # first and last
            text = msgs[idx].get("content") or ""
            if _tok_len(text, enc) > cap:
                msgs[idx]["content"] = _truncate_middle_tokens(text, enc, cap)
        cap //= 2  # tighten the screw

    # ---- STEP 4 : success or fail-gracefully -----------------------------
    if total_tokens() + reserve > target_tokens and not lossy_ok:
        raise ValueError(
            "compress_prompt: prompt still exceeds budget "
            f"({total_tokens() + reserve} > {target_tokens})."
        )

    return msgs
def estimate_token_count(
    messages: list[dict],
    *,
    model: str = "gpt-4o",
) -> int:
    """
    Sum the per-message token estimates of *messages* for *model*.

    Each message is costed by ``_msg_tokens`` (content tokens plus wrapper
    overhead) using the tokenizer tiktoken selects for *model*.

    Parameters
    ----------
    messages    Complete chat history.
    model       Model name handed to tiktoken to choose the tokenizer.

    Returns
    -------
    int         Total token count; 0 for an empty history.
    """
    tokenizer = encoding_for_model(model)
    count = 0
    for message in messages:
        count += _msg_tokens(message, tokenizer)
    return count
def estimate_token_count_str(
    text: Any,
    *,
    model: str = "gpt-4o",
) -> int:
    """
    Count the tokens of *text* under the tokenizer chosen for *model*.

    Strings are counted as-is; any other value is first serialised with
    ``json.dumps`` and the resulting string is counted.

    Parameters
    ----------
    text        Input text, or any value ``json.dumps`` can serialise.
    model       Model name handed to tiktoken to choose the tokenizer.

    Returns
    -------
    int         Token count.
    """
    tokenizer = encoding_for_model(model)
    if isinstance(text, str):
        serialised = text
    else:
        serialised = json.dumps(text)
    return _tok_len(serialised, tokenizer)

View File

@@ -430,7 +430,13 @@ class Requests:
) as response:
if self.raise_for_status:
response.raise_for_status()
try:
response.raise_for_status()
except ClientResponseError as e:
body = await response.read()
raise Exception(
f"HTTP {response.status} Error: {response.reason}, Body: {body.decode(errors='replace')}"
) from e
# If allowed and a redirect is received, follow the redirect manually
if allow_redirects and response.status in (301, 302, 303, 307, 308):

View File

@@ -31,7 +31,7 @@ from tenacity import (
wait_exponential_jitter,
)
from backend.util.exceptions import InsufficientBalanceError
import backend.util.exceptions as exceptions
from backend.util.json import to_dict
from backend.util.metrics import sentry_init
from backend.util.process import AppProcess, get_service_name
@@ -106,7 +106,13 @@ EXCEPTION_MAPPING = {
ValueError,
TimeoutError,
ConnectionError,
InsufficientBalanceError,
*[
ErrorType
for _, ErrorType in inspect.getmembers(exceptions)
if inspect.isclass(ErrorType)
and issubclass(ErrorType, Exception)
and ErrorType.__module__ == exceptions.__name__
],
]
}

View File

@@ -51,7 +51,7 @@ class ServiceTestClient(AppServiceClient):
subtract_async = endpoint_to_async(ServiceTest.subtract)
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio
async def test_service_creation(server):
with ServiceTest():
client = get_service_client(ServiceTestClient)

View File

@@ -238,6 +238,31 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="The Discord channel for the platform",
)
clamav_service_host: str = Field(
default="localhost",
description="The host for the ClamAV daemon",
)
clamav_service_port: int = Field(
default=3310,
description="The port for the ClamAV daemon",
)
clamav_service_timeout: int = Field(
default=60,
description="The timeout in seconds for the ClamAV daemon",
)
clamav_service_enabled: bool = Field(
default=True,
description="Whether virus scanning is enabled or not",
)
clamav_max_concurrency: int = Field(
default=10,
description="The maximum number of concurrent scans to perform",
)
clamav_mark_failed_scans_as_clean: bool = Field(
default=False,
description="Whether to mark failed scans as clean or not",
)
@field_validator("platform_base_url", "frontend_base_url")
@classmethod
def validate_platform_base_url(cls, v: str, info: ValidationInfo) -> str:

View File

@@ -152,9 +152,16 @@ async def main():
print(f"Inserting {NUM_USERS * MAX_AGENTS_PER_USER} user agents")
for user in users:
num_agents = random.randint(MIN_AGENTS_PER_USER, MAX_AGENTS_PER_USER)
for _ in range(num_agents): # Create 1 LibraryAgent per user
graph = random.choice(agent_graphs)
preset = random.choice(agent_presets)
# Get a shuffled list of graphs to ensure uniqueness per user
available_graphs = agent_graphs.copy()
random.shuffle(available_graphs)
# Limit to available unique graphs
num_agents = min(num_agents, len(available_graphs))
for i in range(num_agents):
graph = available_graphs[i] # Use unique graph for each library agent
user_agent = await db.libraryagent.create(
data={
"userId": user.id,
@@ -180,7 +187,7 @@ async def main():
MIN_EXECUTIONS_PER_GRAPH, MAX_EXECUTIONS_PER_GRAPH
)
for _ in range(num_executions):
matching_presets = [p for p in agent_presets if p.agentId == graph.id]
matching_presets = [p for p in agent_presets if p.agentGraphId == graph.id]
preset = (
random.choice(matching_presets)
if matching_presets and random.random() < 0.5
@@ -355,7 +362,7 @@ async def main():
store_listing_versions = []
print(f"Inserting {NUM_USERS} store listing versions")
for listing in store_listings:
graph = [g for g in agent_graphs if g.id == listing.agentId][0]
graph = [g for g in agent_graphs if g.id == listing.agentGraphId][0]
version = await db.storelistingversion.create(
data={
"agentGraphId": graph.id,

View File

@@ -0,0 +1,209 @@
import asyncio
import io
import logging
import time
from typing import Optional, Tuple
import aioclamd
from pydantic import BaseModel
from pydantic_settings import BaseSettings
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
class VirusScanResult(BaseModel):
is_clean: bool
scan_time_ms: int
file_size: int
threat_name: Optional[str] = None
class VirusScannerSettings(BaseSettings):
    """Tunables for the scanner layer (NOT the ClamAV daemon)."""

    # Connection parameters; VirusScannerService passes these to the
    # aioclamd client as host, port and timeout.
    clamav_service_host: str = "localhost"
    clamav_service_port: int = 3310
    clamav_service_timeout: int = 60
    # If the service is disabled, all files are considered clean.
    clamav_service_enabled: bool = True
    # Verdict reported for content that is not actually scanned; in the
    # visible code this is used for content larger than max_scan_size.
    mark_failed_scans_as_clean: bool = False

    # Client-side protective limits
    max_scan_size: int = 2 * 1024 * 1024 * 1024  # 2 GB guard-rail in memory
    min_chunk_size: int = 128 * 1024  # 128 KB hard floor
    max_retries: int = 8  # halve ≤ max_retries times

    # Concurrency throttle toward the ClamAV daemon. Do *NOT* simply turn this
    # up to the number of CPU cores; keep it ≤ (MaxThreads / pods) - 1.
    max_concurrency: int = 5
class VirusScannerService:
"""Fully-async ClamAV wrapper using **aioclamd**.
• Reuses a single `ClamdAsyncClient` connection (aioclamd keeps the socket open).
• Throttles concurrent `INSTREAM` calls with an `asyncio.Semaphore` so we don't exhaust daemon worker threads or file descriptors.
• Falls back to progressively smaller chunk sizes when the daemon rejects a stream as too large.
"""
def __init__(self, settings: VirusScannerSettings) -> None:
self.settings = settings
self._client = aioclamd.ClamdAsyncClient(
host=settings.clamav_service_host,
port=settings.clamav_service_port,
timeout=settings.clamav_service_timeout,
)
self._sem = asyncio.Semaphore(settings.max_concurrency)
# ------------------------------------------------------------------ #
# Helpers
# ------------------------------------------------------------------ #
@staticmethod
def _parse_raw(raw: Optional[dict]) -> Tuple[bool, Optional[str]]:
"""
Convert aioclamd output to (infected?, threat_name).
Returns (False, None) for clean.
"""
if not raw:
return False, None
status, threat = next(iter(raw.values()))
return status == "FOUND", threat
async def _instream(self, chunk: bytes) -> Tuple[bool, Optional[str]]:
    """Scan **one** chunk with concurrency control.

    The chunk is sent through the shared client's ``instream`` call while
    holding ``self._sem``, and the reply is converted by ``_parse_raw``.

    Returns:
        The ``(infected, threat_name)`` pair produced by ``_parse_raw``.

    Raises:
        RuntimeError: with the exact message ``"size-limit"`` when the call
            fails with ``BrokenPipeError``/``ConnectionResetError`` or with
            an error whose text contains "INSTREAM size limit exceeded".
            ``scan_file`` matches on this message to retry with smaller
            chunks.
        Exception: any other error from the client is re-raised unchanged.
    """
    async with self._sem:
        try:
            raw = await self._client.instream(io.BytesIO(chunk))
            return self._parse_raw(raw)
        # NOTE(review): a dropped connection is presumably how the daemon
        # rejects an oversized stream, hence the same sentinel - confirm.
        except (BrokenPipeError, ConnectionResetError) as exc:
            raise RuntimeError("size-limit") from exc
        except Exception as exc:
            if "INSTREAM size limit exceeded" in str(exc):
                raise RuntimeError("size-limit") from exc
            raise
# ------------------------------------------------------------------ #
# Public API
# ------------------------------------------------------------------ #
async def scan_file(
self, content: bytes, *, filename: str = "unknown"
) -> VirusScanResult:
"""
Scan `content`. Returns a result object or raises on infrastructure
failure (unreachable daemon, etc.).
The algorithm always tries whole-file first. If the daemon refuses
on size grounds, it falls back to chunked parallel scanning.
"""
if not self.settings.clamav_service_enabled:
logger.warning(f"Virus scanning disabled accepting {filename}")
return VirusScanResult(
is_clean=True, scan_time_ms=0, file_size=len(content)
)
if len(content) > self.settings.max_scan_size:
logger.warning(
f"File {filename} ({len(content)} bytes) exceeds client max scan size ({self.settings.max_scan_size}); Stopping virus scan"
)
return VirusScanResult(
is_clean=self.settings.mark_failed_scans_as_clean,
file_size=len(content),
scan_time_ms=0,
)
# Ensure daemon is reachable (small RTT check)
if not await self._client.ping():
raise RuntimeError("ClamAV service is unreachable")
start = time.monotonic()
chunk_size = len(content) # Start with full content length
for retry in range(self.settings.max_retries):
# For small files, don't check min_chunk_size limit
if chunk_size < self.settings.min_chunk_size and chunk_size < len(content):
break
logger.debug(
f"Scanning {filename} with chunk size: {chunk_size // 1_048_576} MB (retry {retry + 1}/{self.settings.max_retries})"
)
try:
tasks = [
asyncio.create_task(self._instream(content[o : o + chunk_size]))
for o in range(0, len(content), chunk_size)
]
for coro in asyncio.as_completed(tasks):
infected, threat = await coro
if infected:
for t in tasks:
if not t.done():
t.cancel()
return VirusScanResult(
is_clean=False,
threat_name=threat,
file_size=len(content),
scan_time_ms=int((time.monotonic() - start) * 1000),
)
# All chunks clean
return VirusScanResult(
is_clean=True,
file_size=len(content),
scan_time_ms=int((time.monotonic() - start) * 1000),
)
except RuntimeError as exc:
if str(exc) == "size-limit":
chunk_size //= 2
continue
logger.error(f"Cannot scan {filename}: {exc}")
raise
# Phase 3 give up but warn
logger.warning(
f"Unable to virus scan {filename} ({len(content)} bytes) even with minimum chunk size ({self.settings.min_chunk_size} bytes). Recommend manual review."
)
return VirusScanResult(
is_clean=self.settings.mark_failed_scans_as_clean,
file_size=len(content),
scan_time_ms=int((time.monotonic() - start) * 1000),
)
# Lazily-created, module-wide scanner instance (see get_virus_scanner).
_scanner: Optional[VirusScannerService] = None


def get_virus_scanner() -> VirusScannerService:
    """Return the shared `VirusScannerService`, creating it on first use.

    The instance is configured from `settings.config` and cached in the
    module-level `_scanner`, so later calls return the same object.
    """
    global _scanner
    if _scanner is not None:
        return _scanner

    cfg = settings.config
    _scanner = VirusScannerService(
        VirusScannerSettings(
            clamav_service_host=cfg.clamav_service_host,
            clamav_service_port=cfg.clamav_service_port,
            clamav_service_enabled=cfg.clamav_service_enabled,
            max_concurrency=cfg.clamav_max_concurrency,
            mark_failed_scans_as_clean=cfg.clamav_mark_failed_scans_as_clean,
        )
    )
    return _scanner
async def scan_content_safe(content: bytes, *, filename: str = "unknown") -> None:
    """
    Scan `content` with the shared scanner and translate the outcome into exceptions.

    Returns normally (None) when the scan reports the content as clean.

    Raises:
        VirusDetectedError: If the scan result is not clean
        VirusScanError: If anything else goes wrong while scanning
    """
    from backend.server.v2.store.exceptions import VirusDetectedError, VirusScanError

    try:
        outcome = await get_virus_scanner().scan_file(content, filename=filename)
        if outcome.is_clean:
            logger.info(f"File {filename} passed virus scan in {outcome.scan_time_ms}ms")
            return

        threat_name = outcome.threat_name or "Unknown threat"
        logger.warning(f"Virus detected in file {filename}: {threat_name}")
        raise VirusDetectedError(
            threat_name, f"File rejected due to virus detection: {threat_name}"
        )
    except VirusDetectedError:
        # A detection is a verdict, not a scan failure: pass it through unwrapped.
        raise
    except Exception as e:
        reason = str(e)
        logger.error(f"Virus scanning failed for {filename}: {reason}")
        raise VirusScanError(f"Virus scanning failed: {reason}") from e

View File

@@ -0,0 +1,253 @@
import asyncio
from unittest.mock import AsyncMock, Mock, patch
import pytest
from backend.server.v2.store.exceptions import VirusDetectedError, VirusScanError
from backend.util.virus_scanner import (
VirusScannerService,
VirusScannerSettings,
VirusScanResult,
get_virus_scanner,
scan_content_safe,
)
class TestVirusScannerService:
    """Unit tests for `VirusScannerService`; the ClamAV client is always a mock."""

    # ------------------------------------------------------------------ #
    # Test doubles
    # ------------------------------------------------------------------ #

    @staticmethod
    def _delayed_reply(reply):
        """Build an `instream` coroutine that answers `reply` after a tiny pause."""

        async def _instream(_):
            await asyncio.sleep(0.001)  # Small delay to ensure timing > 0
            return reply

        return _instream

    @staticmethod
    def _fake_client(*, ping, instream=None):
        """Build a mock client whose `ping()` resolves to `ping`.

        `instream`, when given, becomes the `side_effect` of the mocked
        `instream` coroutine (a callable or an exception instance).
        """
        client = Mock()
        client.ping = AsyncMock(return_value=ping)
        if instream is not None:
            client.instream = AsyncMock(side_effect=instream)
        return client

    # ------------------------------------------------------------------ #
    # Fixtures
    # ------------------------------------------------------------------ #

    @pytest.fixture
    def scanner_settings(self):
        return VirusScannerSettings(
            clamav_service_host="localhost",
            clamav_service_port=3310,
            clamav_service_enabled=True,
            max_scan_size=10 * 1024 * 1024,  # 10MB for testing
            mark_failed_scans_as_clean=False,  # Unscannable files are reported as not clean
        )

    @pytest.fixture
    def scanner(self, scanner_settings):
        return VirusScannerService(scanner_settings)

    @pytest.fixture
    def disabled_scanner(self):
        return VirusScannerService(VirusScannerSettings(clamav_service_enabled=False))

    # ------------------------------------------------------------------ #
    # Construction and short-circuit paths
    # ------------------------------------------------------------------ #

    def test_scanner_initialization(self, scanner_settings):
        cfg = VirusScannerService(scanner_settings).settings
        assert cfg.clamav_service_host == "localhost"
        assert cfg.clamav_service_port == 3310
        assert cfg.clamav_service_enabled is True

    @pytest.mark.asyncio
    async def test_scan_disabled_returns_clean(self, disabled_scanner):
        payload = b"test file content"
        outcome = await disabled_scanner.scan_file(payload, filename="test.txt")
        assert outcome.is_clean is True
        assert outcome.threat_name is None
        assert outcome.file_size == len(payload)
        assert outcome.scan_time_ms == 0

    @pytest.mark.asyncio
    async def test_scan_file_too_large(self, scanner):
        # One byte over the configured max_scan_size
        oversized = b"x" * (scanner.settings.max_scan_size + 1)
        outcome = await scanner.scan_file(oversized, filename="large_file.txt")
        # The verdict for an oversized file mirrors mark_failed_scans_as_clean
        assert outcome.is_clean == scanner.settings.mark_failed_scans_as_clean
        assert outcome.file_size == len(oversized)
        assert outcome.scan_time_ms == 0

    @pytest.mark.asyncio
    async def test_scan_file_too_large_both_configurations(self):
        """Test large file handling with both mark_failed_scans_as_clean configurations"""
        limit = 10 * 1024 * 1024
        oversized = b"x" * (limit + 1)  # Larger than 10MB
        for flag in (True, False):
            service = VirusScannerService(
                VirusScannerSettings(
                    max_scan_size=limit, mark_failed_scans_as_clean=flag
                )
            )
            outcome = await service.scan_file(oversized, filename="large_file.txt")
            assert outcome.is_clean is flag

    # ------------------------------------------------------------------ #
    # Scans against a mocked daemon
    # ------------------------------------------------------------------ #

    @pytest.mark.asyncio
    async def test_scan_clean_file(self, scanner):
        # Swap in a mock for the client built by the constructor; None == no virus
        scanner._client = self._fake_client(
            ping=True, instream=self._delayed_reply(None)
        )
        payload = b"clean file content"
        outcome = await scanner.scan_file(payload, filename="clean.txt")
        assert outcome.is_clean is True
        assert outcome.threat_name is None
        assert outcome.file_size == len(payload)
        assert outcome.scan_time_ms > 0

    @pytest.mark.asyncio
    async def test_scan_infected_file(self, scanner):
        verdict = {"stream": ("FOUND", "Win.Test.EICAR_HDB-1")}
        scanner._client = self._fake_client(
            ping=True, instream=self._delayed_reply(verdict)
        )
        payload = b"infected file content"
        outcome = await scanner.scan_file(payload, filename="infected.txt")
        assert outcome.is_clean is False
        assert outcome.threat_name == "Win.Test.EICAR_HDB-1"
        assert outcome.file_size == len(payload)
        assert outcome.scan_time_ms > 0

    @pytest.mark.asyncio
    async def test_scan_clamav_unavailable_fail_safe(self, scanner):
        # A failed ping must surface as an error rather than a verdict
        scanner._client = self._fake_client(ping=False)
        with pytest.raises(RuntimeError, match="ClamAV service is unreachable"):
            await scanner.scan_file(b"test content", filename="test.txt")

    @pytest.mark.asyncio
    async def test_scan_error_fail_safe(self, scanner):
        scanner._client = self._fake_client(
            ping=True, instream=Exception("Scanning error")
        )
        with pytest.raises(Exception, match="Scanning error"):
            await scanner.scan_file(b"test content", filename="test.txt")

    @pytest.mark.asyncio
    async def test_concurrent_scans(self, scanner):
        scanner._client = self._fake_client(
            ping=True, instream=self._delayed_reply(None)
        )
        first = b"file1 content"
        second = b"file2 content"
        # Two scans in flight at the same time on the same service
        outcomes = await asyncio.gather(
            scanner.scan_file(first, filename="file1.txt"),
            scanner.scan_file(second, filename="file2.txt"),
        )
        assert len(outcomes) == 2
        assert all(item.is_clean for item in outcomes)
        assert outcomes[0].file_size == len(first)
        assert outcomes[1].file_size == len(second)
        assert all(item.scan_time_ms > 0 for item in outcomes)

    # ------------------------------------------------------------------ #
    # Module-level pieces
    # ------------------------------------------------------------------ #

    def test_get_virus_scanner_singleton(self):
        # Repeated calls should hand back the very same instance
        assert get_virus_scanner() is get_virus_scanner()

    def test_scan_result_model(self):
        # Field round-trip on the VirusScanResult model
        model = VirusScanResult(
            is_clean=False, threat_name="Test.Virus", scan_time_ms=150, file_size=1024
        )
        assert model.is_clean is False
        assert model.threat_name == "Test.Virus"
        assert model.scan_time_ms == 150
        assert model.file_size == 1024
class TestHelperFunctions:
    """Test the helper functions scan_content_safe"""

    # Dotted path of the factory that scan_content_safe resolves at call time.
    _FACTORY = "backend.util.virus_scanner.get_virus_scanner"

    @staticmethod
    def _scanner_stub(**scan_file_behaviour):
        """Build a stand-in scanner; kwargs configure its async `scan_file` mock."""
        stub = Mock()
        stub.scan_file = AsyncMock(**scan_file_behaviour)
        return stub

    @pytest.mark.asyncio
    async def test_scan_content_safe_clean(self):
        """Test scan_content_safe with clean content"""
        verdict = Mock(is_clean=True, threat_name=None, scan_time_ms=50, file_size=100)
        stub = self._scanner_stub(return_value=verdict)
        with patch(self._FACTORY, return_value=stub):
            # Completing without an exception is the assertion here
            await scan_content_safe(b"clean content", filename="test.txt")

    @pytest.mark.asyncio
    async def test_scan_content_safe_infected(self):
        """Test scan_content_safe with infected content"""
        verdict = Mock(
            is_clean=False, threat_name="Test.Virus", scan_time_ms=50, file_size=100
        )
        stub = self._scanner_stub(return_value=verdict)
        with patch(self._FACTORY, return_value=stub):
            with pytest.raises(VirusDetectedError) as exc_info:
                await scan_content_safe(b"infected content", filename="virus.txt")
            assert exc_info.value.threat_name == "Test.Virus"

    @pytest.mark.asyncio
    async def test_scan_content_safe_scan_error(self):
        """Test scan_content_safe when scanning fails"""
        stub = self._scanner_stub(side_effect=Exception("Scan failed"))
        with patch(self._FACTORY, return_value=stub):
            with pytest.raises(VirusScanError, match="Virus scanning failed"):
                await scan_content_safe(b"test content", filename="test.txt")

View File

@@ -0,0 +1,5 @@
-- Add a nullable "webhookId" text column to "AgentPreset"
ALTER TABLE "AgentPreset" ADD COLUMN "webhookId" TEXT;
-- Add AgentPreset<->IntegrationWebhook relation: "webhookId" references "IntegrationWebhook"."id";
-- deleting the webhook sets the reference to NULL, and an update of its id cascades
ALTER TABLE "AgentPreset" ADD CONSTRAINT "AgentPreset_webhookId_fkey" FOREIGN KEY ("webhookId") REFERENCES "IntegrationWebhook"("id") ON DELETE SET NULL ON UPDATE CASCADE;

View File

@@ -17,6 +17,18 @@ aiormq = ">=6.8,<6.9"
exceptiongroup = ">=1,<2"
yarl = "*"
[[package]]
name = "aioclamd"
version = "1.0.0"
description = "Asynchronous client for virus scanning with ClamAV"
optional = false
python-versions = ">=3.7,<4.0"
groups = ["main"]
files = [
{file = "aioclamd-1.0.0-py3-none-any.whl", hash = "sha256:4727da3953a4b38be4c2de1acb6b3bb3c94c1c171dcac780b80234ee6253f3d9"},
{file = "aioclamd-1.0.0.tar.gz", hash = "sha256:7b14e94e3a2285cc89e2f4d434e2a01f322d3cb95476ce2dda015a7980876047"},
]
[[package]]
name = "aiodns"
version = "3.4.0"
@@ -5006,6 +5018,27 @@ statsig = ["statsig (>=0.55.3)"]
tornado = ["tornado (>=6)"]
unleash = ["UnleashClient (>=6.0.1)"]
[[package]]
name = "setuptools"
version = "80.9.0"
description = "Easily download, build, install, upgrade, and uninstall Python packages"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922"},
{file = "setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c"},
]
[package.extras]
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""]
core = ["importlib_metadata (>=6) ; python_version < \"3.10\"", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1) ; python_version < \"3.11\"", "wheel (>=0.43.0)"]
cover = ["pytest-cov"]
doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
enabler = ["pytest-enabler (>=2.2)"]
test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21) ; python_version >= \"3.9\" and sys_platform != \"cygwin\"", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf ; sys_platform != \"cygwin\"", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
type = ["importlib_metadata (>=7.0.2) ; python_version < \"3.10\"", "jaraco.develop (>=7.21) ; sys_platform != \"cygwin\"", "mypy (==1.14.*)", "pytest-mypy"]
[[package]]
name = "sgmllib3k"
version = "1.0.0"
@@ -6369,4 +6402,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.13"
content-hash = "6c93e51cf22c06548aa6d0e23ca8ceb4450f5e02d4142715e941aabc1a2cbd6a"
content-hash = "b5c1201f27ee8d05d5d8c89702123df4293f124301d1aef7451591a351872260"

View File

@@ -68,6 +68,9 @@ zerobouncesdk = "^1.1.1"
# NOTE: please insert new dependencies in their alphabetical location
pytest-snapshot = "^0.9.0"
aiofiles = "^24.1.0"
tiktoken = "^0.9.0"
aioclamd = "^1.0.0"
setuptools = "^80.9.0"
[tool.poetry.group.dev.dependencies]
aiohappyeyeballs = "^2.6.1"
@@ -112,6 +115,11 @@ ignore_patterns = []
[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "session"
filterwarnings = [
"ignore:'audioop' is deprecated:DeprecationWarning:discord.player",
"ignore:invalid escape sequence:DeprecationWarning:tweepy.api",
]
[tool.ruff]
target-version = "py310"

View File

@@ -169,6 +169,10 @@ model AgentPreset {
InputPresets AgentNodeExecutionInputOutput[] @relation("AgentPresetsInputData")
Executions AgentGraphExecution[]
// For webhook-triggered agents: reference to the webhook that triggers the agent
webhookId String?
Webhook IntegrationWebhook? @relation(fields: [webhookId], references: [id])
isDeleted Boolean @default(false)
@@index([userId])
@@ -428,7 +432,8 @@ model IntegrationWebhook {
providerWebhookId String // Webhook ID assigned by the provider
AgentNodes AgentNode[]
AgentNodes AgentNode[]
AgentPresets AgentPreset[]
@@index([userId])
}

View File

@@ -0,0 +1,3 @@
{
"metric_id": "metric-123-uuid"
}

View File

@@ -0,0 +1,4 @@
{
"metric_id": "metric-float_precision-uuid",
"test_case": "float_precision"
}

View File

@@ -0,0 +1,4 @@
{
"metric_id": "metric-integer_value-uuid",
"test_case": "integer_value"
}

View File

@@ -0,0 +1,4 @@
{
"metric_id": "metric-large_number-uuid",
"test_case": "large_number"
}

View File

@@ -0,0 +1,4 @@
{
"metric_id": "metric-negative_value-uuid",
"test_case": "negative_value"
}

Some files were not shown because too many files have changed in this diff Show More