mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-09 06:15:41 -05:00
Merge branch 'dev' into codex/fix-400-error-with-non-default-voices-in-unreal-tts
This commit is contained in:
51
.github/workflows/platform-backend-ci.yml
vendored
51
.github/workflows/platform-backend-ci.yml
vendored
@@ -50,6 +50,23 @@ jobs:
|
||||
env:
|
||||
RABBITMQ_DEFAULT_USER: ${{ env.RABBITMQ_DEFAULT_USER }}
|
||||
RABBITMQ_DEFAULT_PASS: ${{ env.RABBITMQ_DEFAULT_PASS }}
|
||||
clamav:
|
||||
image: clamav/clamav-debian:latest
|
||||
ports:
|
||||
- 3310:3310
|
||||
env:
|
||||
CLAMAV_NO_FRESHCLAMD: false
|
||||
CLAMD_CONF_StreamMaxLength: 50M
|
||||
CLAMD_CONF_MaxFileSize: 100M
|
||||
CLAMD_CONF_MaxScanSize: 100M
|
||||
CLAMD_CONF_MaxThreads: 4
|
||||
CLAMD_CONF_ReadTimeout: 300
|
||||
options: >-
|
||||
--health-cmd "clamdscan --version || exit 1"
|
||||
--health-interval 30s
|
||||
--health-timeout 10s
|
||||
--health-retries 5
|
||||
--health-start-period 180s
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -131,6 +148,35 @@ jobs:
|
||||
# outputs:
|
||||
# DB_URL, API_URL, GRAPHQL_URL, ANON_KEY, SERVICE_ROLE_KEY, JWT_SECRET
|
||||
|
||||
- name: Wait for ClamAV to be ready
|
||||
run: |
|
||||
echo "Waiting for ClamAV daemon to start..."
|
||||
max_attempts=60
|
||||
attempt=0
|
||||
|
||||
until nc -z localhost 3310 || [ $attempt -eq $max_attempts ]; do
|
||||
echo "ClamAV is unavailable - sleeping (attempt $((attempt+1))/$max_attempts)"
|
||||
sleep 5
|
||||
attempt=$((attempt+1))
|
||||
done
|
||||
|
||||
if [ $attempt -eq $max_attempts ]; then
|
||||
echo "ClamAV failed to start after $((max_attempts*5)) seconds"
|
||||
echo "Checking ClamAV service logs..."
|
||||
docker logs $(docker ps -q --filter "ancestor=clamav/clamav-debian:latest") 2>&1 | tail -50 || echo "No ClamAV container found"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "ClamAV is ready!"
|
||||
|
||||
# Verify ClamAV is responsive
|
||||
echo "Testing ClamAV connection..."
|
||||
timeout 10 bash -c 'echo "PING" | nc localhost 3310' || {
|
||||
echo "ClamAV is not responding to PING"
|
||||
docker logs $(docker ps -q --filter "ancestor=clamav/clamav-debian:latest") 2>&1 | tail -50 || echo "No ClamAV container found"
|
||||
exit 1
|
||||
}
|
||||
|
||||
- name: Run Database Migrations
|
||||
run: poetry run prisma migrate dev --name updates
|
||||
env:
|
||||
@@ -144,9 +190,9 @@ jobs:
|
||||
- name: Run pytest with coverage
|
||||
run: |
|
||||
if [[ "${{ runner.debug }}" == "1" ]]; then
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG test
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
|
||||
else
|
||||
poetry run pytest -s -vv test
|
||||
poetry run pytest -s -vv
|
||||
fi
|
||||
if: success() || (failure() && steps.lint.outcome == 'failure')
|
||||
env:
|
||||
@@ -159,6 +205,7 @@ jobs:
|
||||
REDIS_HOST: "localhost"
|
||||
REDIS_PORT: "6379"
|
||||
REDIS_PASSWORD: "testpassword"
|
||||
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
|
||||
|
||||
env:
|
||||
CI: true
|
||||
|
||||
31
.github/workflows/platform-frontend-ci.yml
vendored
31
.github/workflows/platform-frontend-ci.yml
vendored
@@ -55,12 +55,37 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Generate API client
|
||||
run: pnpm generate:api-client
|
||||
|
||||
- name: Run tsc check
|
||||
run: pnpm type-check
|
||||
|
||||
chromatic:
|
||||
runs-on: ubuntu-latest
|
||||
# Only run on dev branch pushes or PRs targeting dev
|
||||
if: github.ref == 'refs/heads/dev' || github.base_ref == 'dev'
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Run Chromatic
|
||||
uses: chromaui/action@latest
|
||||
with:
|
||||
projectToken: chpt_9e7c1a76478c9c8
|
||||
onlyChanged: true
|
||||
workingDir: autogpt_platform/frontend
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -177,6 +177,3 @@ autogpt_platform/backend/settings.py
|
||||
*.ign.*
|
||||
.test-contents
|
||||
.claude/settings.local.json
|
||||
|
||||
# Auto generated client
|
||||
autogpt_platform/frontend/src/api/__generated__
|
||||
|
||||
@@ -19,7 +19,7 @@ cd backend && poetry install
|
||||
# Run database migrations
|
||||
poetry run prisma migrate dev
|
||||
|
||||
# Start all services (database, redis, rabbitmq)
|
||||
# Start all services (database, redis, rabbitmq, clamav)
|
||||
docker compose up -d
|
||||
|
||||
# Run the backend server
|
||||
@@ -92,6 +92,7 @@ npm run type-check
|
||||
2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks
|
||||
3. **Integrations**: OAuth and API connections stored per user
|
||||
4. **Store**: Marketplace for sharing agent templates
|
||||
5. **Virus Scanning**: ClamAV integration for file upload security
|
||||
|
||||
### Testing Approach
|
||||
- Backend uses pytest with snapshot testing for API responses
|
||||
|
||||
@@ -55,9 +55,9 @@ RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
|
||||
## GCS bucket is required for marketplace and library functionality
|
||||
MEDIA_GCS_BUCKET_NAME=
|
||||
|
||||
## For local development, you may need to set FRONTEND_BASE_URL for the OAuth flow
|
||||
## For local development, you may need to set NEXT_PUBLIC_FRONTEND_BASE_URL for the OAuth flow
|
||||
## for integrations to work. Defaults to the value of PLATFORM_BASE_URL if not set.
|
||||
# FRONTEND_BASE_URL=http://localhost:3000
|
||||
# NEXT_PUBLIC_FRONTEND_BASE_URL=http://localhost:3000
|
||||
|
||||
## PLATFORM_BASE_URL must be set to a *publicly accessible* URL pointing to your backend
|
||||
## to use the platform's webhook-related functionality.
|
||||
|
||||
@@ -20,7 +20,7 @@ def load_all_blocks() -> dict[str, type["Block"]]:
|
||||
modules = [
|
||||
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
|
||||
for f in current_dir.rglob("*.py")
|
||||
if f.is_file() and f.name != "__init__.py"
|
||||
if f.is_file() and f.name != "__init__.py" and not f.name.startswith("test_")
|
||||
]
|
||||
for module in modules:
|
||||
if not re.match("^[a-z0-9_.]+$", module):
|
||||
|
||||
@@ -2,6 +2,8 @@ import asyncio
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
@@ -12,7 +14,7 @@ from backend.data.block import (
|
||||
get_block,
|
||||
)
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.model import CredentialsMetaInput, SchemaField
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -28,9 +30,9 @@ class AgentExecutorBlock(Block):
|
||||
input_schema: dict = SchemaField(description="Input schema for the graph")
|
||||
output_schema: dict = SchemaField(description="Output schema for the graph")
|
||||
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = SchemaField(default=None, hidden=True)
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = SchemaField(
|
||||
default=None, hidden=True
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
|
||||
@@ -71,7 +73,7 @@ class AgentExecutorBlock(Block):
|
||||
graph_version=input_data.graph_version,
|
||||
user_id=input_data.user_id,
|
||||
inputs=input_data.inputs,
|
||||
node_credentials_input_map=input_data.node_credentials_input_map,
|
||||
nodes_input_masks=input_data.nodes_input_masks,
|
||||
use_db_query=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -53,6 +53,7 @@ class AudioTrack(str, Enum):
|
||||
REFRESHER = ("Refresher",)
|
||||
TOURIST = ("Tourist",)
|
||||
TWIN_TYCHES = ("Twin Tyches",)
|
||||
DONT_STOP_ME_ABSTRACT_FUTURE_BASS = ("Dont Stop Me Abstract Future Bass",)
|
||||
|
||||
@property
|
||||
def audio_url(self):
|
||||
@@ -78,6 +79,7 @@ class AudioTrack(str, Enum):
|
||||
AudioTrack.REFRESHER: "https://cdn.tfrv.xyz/audio/refresher.mp3",
|
||||
AudioTrack.TOURIST: "https://cdn.tfrv.xyz/audio/tourist.mp3",
|
||||
AudioTrack.TWIN_TYCHES: "https://cdn.tfrv.xyz/audio/twin-tynches.mp3",
|
||||
AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS: "https://cdn.revid.ai/audio/_dont-stop-me-abstract-future-bass.mp3",
|
||||
}
|
||||
return audio_urls[self]
|
||||
|
||||
@@ -105,6 +107,7 @@ class GenerationPreset(str, Enum):
|
||||
MOVIE = ("Movie",)
|
||||
STYLIZED_ILLUSTRATION = ("Stylized Illustration",)
|
||||
MANGA = ("Manga",)
|
||||
DEFAULT = ("DEFAULT",)
|
||||
|
||||
|
||||
class Voice(str, Enum):
|
||||
@@ -114,6 +117,7 @@ class Voice(str, Enum):
|
||||
JESSICA = "Jessica"
|
||||
CHARLOTTE = "Charlotte"
|
||||
CALLUM = "Callum"
|
||||
EVA = "Eva"
|
||||
|
||||
@property
|
||||
def voice_id(self):
|
||||
@@ -124,6 +128,7 @@ class Voice(str, Enum):
|
||||
Voice.JESSICA: "cgSgspJ2msm6clMCkdW9",
|
||||
Voice.CHARLOTTE: "XB0fDUnXU5powFXDhCwa",
|
||||
Voice.CALLUM: "N2lVS1w4EtoT3dr4eOWO",
|
||||
Voice.EVA: "FGY2WhTYpPnrIDTdsKH5",
|
||||
}
|
||||
return voice_id_map[self]
|
||||
|
||||
@@ -141,6 +146,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AIShortformVideoCreatorBlock(Block):
|
||||
"""Creates a short‑form text‑to‑video clip using stock or AI imagery."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REVID], Literal["api_key"]
|
||||
@@ -184,40 +191,8 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
video_url: str = SchemaField(description="The URL of the created video")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="361697fb-0c4f-4feb-aed3-8320c88c771b",
|
||||
description="Creates a shortform video using revid.ai",
|
||||
categories={BlockCategory.SOCIAL, BlockCategory.AI},
|
||||
input_schema=AIShortformVideoCreatorBlock.Input,
|
||||
output_schema=AIShortformVideoCreatorBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"script": "[close-up of a cat] Meow!",
|
||||
"ratio": "9 / 16",
|
||||
"resolution": "720p",
|
||||
"frame_rate": 60,
|
||||
"generation_preset": GenerationPreset.LEONARDO,
|
||||
"background_music": AudioTrack.HIGHWAY_NOCTURNE,
|
||||
"voice": Voice.LILY,
|
||||
"video_style": VisualMediaType.STOCK_VIDEOS,
|
||||
},
|
||||
test_output=(
|
||||
"video_url",
|
||||
"https://example.com/video.mp4",
|
||||
),
|
||||
test_mock={
|
||||
"create_webhook": lambda: (
|
||||
"test_uuid",
|
||||
"https://webhook.site/test_uuid",
|
||||
),
|
||||
"create_video": lambda api_key, payload: {"pid": "test_pid"},
|
||||
"wait_for_video": lambda api_key, pid, webhook_token, max_wait_time=1000: "https://example.com/video.mp4",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def create_webhook(self):
|
||||
async def create_webhook(self) -> tuple[str, str]:
|
||||
"""Create a new webhook URL for receiving notifications."""
|
||||
url = "https://webhook.site/token"
|
||||
headers = {"Accept": "application/json", "Content-Type": "application/json"}
|
||||
response = await Requests().post(url, headers=headers)
|
||||
@@ -225,6 +200,7 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
|
||||
|
||||
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
|
||||
"""Create a video using the Revid API."""
|
||||
url = "https://www.revid.ai/api/public/v2/render"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().post(url, json=payload, headers=headers)
|
||||
@@ -234,6 +210,7 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
return response.json()
|
||||
|
||||
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
|
||||
"""Check the status of a video creation job."""
|
||||
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().get(url, headers=headers)
|
||||
@@ -243,9 +220,9 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
pid: str,
|
||||
webhook_token: str,
|
||||
max_wait_time: int = 1000,
|
||||
) -> str:
|
||||
"""Wait for video creation to complete and return the video URL."""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
status = await self.check_video_status(api_key, pid)
|
||||
@@ -266,6 +243,40 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
logger.error("Video creation timed out")
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="361697fb-0c4f-4feb-aed3-8320c88c771b",
|
||||
description="Creates a shortform video using revid.ai",
|
||||
categories={BlockCategory.SOCIAL, BlockCategory.AI},
|
||||
input_schema=AIShortformVideoCreatorBlock.Input,
|
||||
output_schema=AIShortformVideoCreatorBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"script": "[close-up of a cat] Meow!",
|
||||
"ratio": "9 / 16",
|
||||
"resolution": "720p",
|
||||
"frame_rate": 60,
|
||||
"generation_preset": GenerationPreset.LEONARDO,
|
||||
"background_music": AudioTrack.HIGHWAY_NOCTURNE,
|
||||
"voice": Voice.LILY,
|
||||
"video_style": VisualMediaType.STOCK_VIDEOS,
|
||||
},
|
||||
test_output=("video_url", "https://example.com/video.mp4"),
|
||||
test_mock={
|
||||
"create_webhook": lambda *args, **kwargs: (
|
||||
"test_uuid",
|
||||
"https://webhook.site/test_uuid",
|
||||
),
|
||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||
"check_video_status": lambda *args, **kwargs: {
|
||||
"status": "ready",
|
||||
"videoUrl": "https://example.com/video.mp4",
|
||||
},
|
||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/video.mp4",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
@@ -273,20 +284,18 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
webhook_token, webhook_url = await self.create_webhook()
|
||||
logger.debug(f"Webhook URL: {webhook_url}")
|
||||
|
||||
audio_url = input_data.background_music.audio_url
|
||||
|
||||
payload = {
|
||||
"frameRate": input_data.frame_rate,
|
||||
"resolution": input_data.resolution,
|
||||
"frameDurationMultiplier": 18,
|
||||
"webhook": webhook_url,
|
||||
"webhook": None,
|
||||
"creationParams": {
|
||||
"mediaType": input_data.video_style,
|
||||
"captionPresetName": "Wrap 1",
|
||||
"selectedVoice": input_data.voice.voice_id,
|
||||
"hasEnhancedGeneration": True,
|
||||
"generationPreset": input_data.generation_preset.name,
|
||||
"selectedAudio": input_data.background_music,
|
||||
"selectedAudio": input_data.background_music.value,
|
||||
"origin": "/create",
|
||||
"inputText": input_data.script,
|
||||
"flowType": "text-to-video",
|
||||
@@ -302,7 +311,7 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
"selectedStoryStyle": {"value": "custom", "label": "Custom"},
|
||||
"hasToGenerateVideos": input_data.video_style
|
||||
!= VisualMediaType.STOCK_VIDEOS,
|
||||
"audioUrl": audio_url,
|
||||
"audioUrl": input_data.background_music.audio_url,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -319,8 +328,370 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
logger.debug(
|
||||
f"Video created with project ID: {pid}. Waiting for completion..."
|
||||
)
|
||||
video_url = await self.wait_for_video(
|
||||
credentials.api_key, pid, webhook_token
|
||||
)
|
||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||
logger.debug(f"Video ready: {video_url}")
|
||||
yield "video_url", video_url
|
||||
|
||||
|
||||
class AIAdMakerVideoCreatorBlock(Block):
|
||||
"""Generates a 30‑second vertical AI advert using optional user‑supplied imagery."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REVID], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Credentials for Revid.ai API access.",
|
||||
)
|
||||
script: str = SchemaField(
|
||||
description="Short advertising copy. Line breaks create new scenes.",
|
||||
placeholder="Introducing Foobar – [show product photo] the gadget that does it all.",
|
||||
)
|
||||
ratio: str = SchemaField(description="Aspect ratio", default="9 / 16")
|
||||
target_duration: int = SchemaField(
|
||||
description="Desired length of the ad in seconds.", default=30
|
||||
)
|
||||
voice: Voice = SchemaField(
|
||||
description="Narration voice", default=Voice.EVA, placeholder=Voice.EVA
|
||||
)
|
||||
background_music: AudioTrack = SchemaField(
|
||||
description="Background track",
|
||||
default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS,
|
||||
)
|
||||
input_media_urls: list[str] = SchemaField(
|
||||
description="List of image URLs to feature in the advert.", default=[]
|
||||
)
|
||||
use_only_provided_media: bool = SchemaField(
|
||||
description="Restrict visuals to supplied images only.", default=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
video_url: str = SchemaField(description="URL of the finished advert")
|
||||
error: str = SchemaField(description="Error message on failure")
|
||||
|
||||
async def create_webhook(self) -> tuple[str, str]:
|
||||
"""Create a new webhook URL for receiving notifications."""
|
||||
url = "https://webhook.site/token"
|
||||
headers = {"Accept": "application/json", "Content-Type": "application/json"}
|
||||
response = await Requests().post(url, headers=headers)
|
||||
webhook_data = response.json()
|
||||
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
|
||||
|
||||
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
|
||||
"""Create a video using the Revid API."""
|
||||
url = "https://www.revid.ai/api/public/v2/render"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().post(url, json=payload, headers=headers)
|
||||
logger.debug(
|
||||
f"API Response Status Code: {response.status}, Content: {response.text}"
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
|
||||
"""Check the status of a video creation job."""
|
||||
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().get(url, headers=headers)
|
||||
return response.json()
|
||||
|
||||
async def wait_for_video(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
pid: str,
|
||||
max_wait_time: int = 1000,
|
||||
) -> str:
|
||||
"""Wait for video creation to complete and return the video URL."""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
status = await self.check_video_status(api_key, pid)
|
||||
logger.debug(f"Video status: {status}")
|
||||
|
||||
if status.get("status") == "ready" and "videoUrl" in status:
|
||||
return status["videoUrl"]
|
||||
elif status.get("status") == "error":
|
||||
error_message = status.get("error", "Unknown error occurred")
|
||||
logger.error(f"Video creation failed: {error_message}")
|
||||
raise ValueError(f"Video creation failed: {error_message}")
|
||||
elif status.get("status") in ["FAILED", "CANCELED"]:
|
||||
logger.error(f"Video creation failed: {status.get('message')}")
|
||||
raise ValueError(f"Video creation failed: {status.get('message')}")
|
||||
|
||||
await asyncio.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="58bd2a19-115d-4fd1-8ca4-13b9e37fa6a0",
|
||||
description="Creates an AI‑generated 30‑second advert (text + images)",
|
||||
categories={BlockCategory.MARKETING, BlockCategory.AI},
|
||||
input_schema=AIAdMakerVideoCreatorBlock.Input,
|
||||
output_schema=AIAdMakerVideoCreatorBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"script": "Test product launch!",
|
||||
"input_media_urls": [
|
||||
"https://cdn.revid.ai/uploads/1747076315114-image.png",
|
||||
],
|
||||
},
|
||||
test_output=("video_url", "https://example.com/ad.mp4"),
|
||||
test_mock={
|
||||
"create_webhook": lambda *args, **kwargs: (
|
||||
"test_uuid",
|
||||
"https://webhook.site/test_uuid",
|
||||
),
|
||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||
"check_video_status": lambda *args, **kwargs: {
|
||||
"status": "ready",
|
||||
"videoUrl": "https://example.com/ad.mp4",
|
||||
},
|
||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/ad.mp4",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
||||
webhook_token, webhook_url = await self.create_webhook()
|
||||
|
||||
payload = {
|
||||
"webhook": webhook_url,
|
||||
"creationParams": {
|
||||
"targetDuration": input_data.target_duration,
|
||||
"ratio": input_data.ratio,
|
||||
"mediaType": "aiVideo",
|
||||
"inputText": input_data.script,
|
||||
"flowType": "text-to-video",
|
||||
"slug": "ai-ad-generator",
|
||||
"slugNew": "",
|
||||
"isCopiedFrom": False,
|
||||
"hasToGenerateVoice": True,
|
||||
"hasToTranscript": False,
|
||||
"hasToSearchMedia": True,
|
||||
"hasAvatar": False,
|
||||
"hasWebsiteRecorder": False,
|
||||
"hasTextSmallAtBottom": False,
|
||||
"selectedAudio": input_data.background_music.value,
|
||||
"selectedVoice": input_data.voice.voice_id,
|
||||
"selectedAvatar": "https://cdn.revid.ai/avatars/young-woman.mp4",
|
||||
"selectedAvatarType": "video/mp4",
|
||||
"websiteToRecord": "",
|
||||
"hasToGenerateCover": True,
|
||||
"nbGenerations": 1,
|
||||
"disableCaptions": False,
|
||||
"mediaMultiplier": "medium",
|
||||
"characters": [],
|
||||
"captionPresetName": "Revid",
|
||||
"sourceType": "contentScraping",
|
||||
"selectedStoryStyle": {"value": "custom", "label": "General"},
|
||||
"generationPreset": "DEFAULT",
|
||||
"hasToGenerateMusic": False,
|
||||
"isOptimizedForChinese": False,
|
||||
"generationUserPrompt": "",
|
||||
"enableNsfwFilter": False,
|
||||
"addStickers": False,
|
||||
"typeMovingImageAnim": "dynamic",
|
||||
"hasToGenerateSoundEffects": False,
|
||||
"forceModelType": "gpt-image-1",
|
||||
"selectedCharacters": [],
|
||||
"lang": "",
|
||||
"voiceSpeed": 1,
|
||||
"disableAudio": False,
|
||||
"disableVoice": False,
|
||||
"useOnlyProvidedMedia": input_data.use_only_provided_media,
|
||||
"imageGenerationModel": "ultra",
|
||||
"videoGenerationModel": "pro",
|
||||
"hasEnhancedGeneration": True,
|
||||
"hasEnhancedGenerationPro": True,
|
||||
"inputMedias": [
|
||||
{"url": url, "title": "", "type": "image"}
|
||||
for url in input_data.input_media_urls
|
||||
],
|
||||
"hasToGenerateVideos": True,
|
||||
"audioUrl": input_data.background_music.audio_url,
|
||||
"watermark": None,
|
||||
},
|
||||
}
|
||||
|
||||
response = await self.create_video(credentials.api_key, payload)
|
||||
pid = response.get("pid")
|
||||
if not pid:
|
||||
raise RuntimeError("Failed to create video: No project ID returned")
|
||||
|
||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||
yield "video_url", video_url
|
||||
|
||||
|
||||
class AIScreenshotToVideoAdBlock(Block):
|
||||
"""Creates an advert where the supplied screenshot is narrated by an AI avatar."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REVID], Literal["api_key"]
|
||||
] = CredentialsField(description="Revid.ai API key")
|
||||
script: str = SchemaField(
|
||||
description="Narration that will accompany the screenshot.",
|
||||
placeholder="Check out these amazing stats!",
|
||||
)
|
||||
screenshot_url: str = SchemaField(
|
||||
description="Screenshot or image URL to showcase."
|
||||
)
|
||||
ratio: str = SchemaField(default="9 / 16")
|
||||
target_duration: int = SchemaField(default=30)
|
||||
voice: Voice = SchemaField(default=Voice.EVA)
|
||||
background_music: AudioTrack = SchemaField(
|
||||
default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
video_url: str = SchemaField(description="Rendered video URL")
|
||||
error: str = SchemaField(description="Error, if encountered")
|
||||
|
||||
async def create_webhook(self) -> tuple[str, str]:
|
||||
"""Create a new webhook URL for receiving notifications."""
|
||||
url = "https://webhook.site/token"
|
||||
headers = {"Accept": "application/json", "Content-Type": "application/json"}
|
||||
response = await Requests().post(url, headers=headers)
|
||||
webhook_data = response.json()
|
||||
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
|
||||
|
||||
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
|
||||
"""Create a video using the Revid API."""
|
||||
url = "https://www.revid.ai/api/public/v2/render"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().post(url, json=payload, headers=headers)
|
||||
logger.debug(
|
||||
f"API Response Status Code: {response.status}, Content: {response.text}"
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
|
||||
"""Check the status of a video creation job."""
|
||||
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().get(url, headers=headers)
|
||||
return response.json()
|
||||
|
||||
async def wait_for_video(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
pid: str,
|
||||
max_wait_time: int = 1000,
|
||||
) -> str:
|
||||
"""Wait for video creation to complete and return the video URL."""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
status = await self.check_video_status(api_key, pid)
|
||||
logger.debug(f"Video status: {status}")
|
||||
|
||||
if status.get("status") == "ready" and "videoUrl" in status:
|
||||
return status["videoUrl"]
|
||||
elif status.get("status") == "error":
|
||||
error_message = status.get("error", "Unknown error occurred")
|
||||
logger.error(f"Video creation failed: {error_message}")
|
||||
raise ValueError(f"Video creation failed: {error_message}")
|
||||
elif status.get("status") in ["FAILED", "CANCELED"]:
|
||||
logger.error(f"Video creation failed: {status.get('message')}")
|
||||
raise ValueError(f"Video creation failed: {status.get('message')}")
|
||||
|
||||
await asyncio.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0f3e4635-e810-43d9-9e81-49e6f4e83b7c",
|
||||
description="Turns a screenshot into an engaging, avatar‑narrated video advert.",
|
||||
categories={BlockCategory.AI, BlockCategory.MARKETING},
|
||||
input_schema=AIScreenshotToVideoAdBlock.Input,
|
||||
output_schema=AIScreenshotToVideoAdBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"script": "Amazing numbers!",
|
||||
"screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
|
||||
},
|
||||
test_output=("video_url", "https://example.com/screenshot.mp4"),
|
||||
test_mock={
|
||||
"create_webhook": lambda *args, **kwargs: (
|
||||
"test_uuid",
|
||||
"https://webhook.site/test_uuid",
|
||||
),
|
||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||
"check_video_status": lambda *args, **kwargs: {
|
||||
"status": "ready",
|
||||
"videoUrl": "https://example.com/screenshot.mp4",
|
||||
},
|
||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/screenshot.mp4",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
||||
webhook_token, webhook_url = await self.create_webhook()
|
||||
|
||||
payload = {
|
||||
"webhook": webhook_url,
|
||||
"creationParams": {
|
||||
"targetDuration": input_data.target_duration,
|
||||
"ratio": input_data.ratio,
|
||||
"mediaType": "aiVideo",
|
||||
"hasAvatar": True,
|
||||
"removeAvatarBackground": True,
|
||||
"inputText": input_data.script,
|
||||
"flowType": "text-to-video",
|
||||
"slug": "ai-ad-generator",
|
||||
"slugNew": "screenshot-to-video-ad",
|
||||
"isCopiedFrom": "ai-ad-generator",
|
||||
"hasToGenerateVoice": True,
|
||||
"hasToTranscript": False,
|
||||
"hasToSearchMedia": True,
|
||||
"hasWebsiteRecorder": False,
|
||||
"hasTextSmallAtBottom": False,
|
||||
"selectedAudio": input_data.background_music.value,
|
||||
"selectedVoice": input_data.voice.voice_id,
|
||||
"selectedAvatar": "https://cdn.revid.ai/avatars/young-woman.mp4",
|
||||
"selectedAvatarType": "video/mp4",
|
||||
"websiteToRecord": "",
|
||||
"hasToGenerateCover": True,
|
||||
"nbGenerations": 1,
|
||||
"disableCaptions": False,
|
||||
"mediaMultiplier": "medium",
|
||||
"characters": [],
|
||||
"captionPresetName": "Revid",
|
||||
"sourceType": "contentScraping",
|
||||
"selectedStoryStyle": {"value": "custom", "label": "General"},
|
||||
"generationPreset": "DEFAULT",
|
||||
"hasToGenerateMusic": False,
|
||||
"isOptimizedForChinese": False,
|
||||
"generationUserPrompt": "",
|
||||
"enableNsfwFilter": False,
|
||||
"addStickers": False,
|
||||
"typeMovingImageAnim": "dynamic",
|
||||
"hasToGenerateSoundEffects": False,
|
||||
"forceModelType": "gpt-image-1",
|
||||
"selectedCharacters": [],
|
||||
"lang": "",
|
||||
"voiceSpeed": 1,
|
||||
"disableAudio": False,
|
||||
"disableVoice": False,
|
||||
"useOnlyProvidedMedia": True,
|
||||
"imageGenerationModel": "ultra",
|
||||
"videoGenerationModel": "ultra",
|
||||
"hasEnhancedGeneration": True,
|
||||
"hasEnhancedGenerationPro": True,
|
||||
"inputMedias": [
|
||||
{"url": input_data.screenshot_url, "title": "", "type": "image"}
|
||||
],
|
||||
"hasToGenerateVideos": True,
|
||||
"audioUrl": input_data.background_music.audio_url,
|
||||
"watermark": None,
|
||||
},
|
||||
}
|
||||
|
||||
response = await self.create_video(credentials.api_key, payload)
|
||||
pid = response.get("pid")
|
||||
if not pid:
|
||||
raise RuntimeError("Failed to create video: No project ID returned")
|
||||
|
||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||
yield "video_url", video_url
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import List
|
||||
from backend.blocks.apollo._auth import ApolloCredentials
|
||||
from backend.blocks.apollo.models import (
|
||||
Contact,
|
||||
EnrichPersonRequest,
|
||||
Organization,
|
||||
SearchOrganizationsRequest,
|
||||
SearchOrganizationsResponse,
|
||||
@@ -110,3 +111,21 @@ class ApolloClient:
|
||||
return (
|
||||
organizations[: query.max_results] if query.max_results else organizations
|
||||
)
|
||||
|
||||
async def enrich_person(self, query: EnrichPersonRequest) -> Contact:
|
||||
"""Enrich a person's data including email & phone reveal"""
|
||||
response = await self.requests.post(
|
||||
f"{self.API_URL}/people/match",
|
||||
headers=self._get_headers(),
|
||||
json=query.model_dump(),
|
||||
params={
|
||||
"reveal_personal_emails": "true",
|
||||
},
|
||||
)
|
||||
data = response.json()
|
||||
if "person" not in data:
|
||||
raise ValueError(f"Person not found or enrichment failed: {data}")
|
||||
|
||||
contact = Contact(**data["person"])
|
||||
contact.email = contact.email or "-"
|
||||
return contact
|
||||
|
||||
@@ -23,9 +23,9 @@ class BaseModel(OriginalBaseModel):
|
||||
class PrimaryPhone(BaseModel):
|
||||
"""A primary phone in Apollo"""
|
||||
|
||||
number: str = ""
|
||||
source: str = ""
|
||||
sanitized_number: str = ""
|
||||
number: Optional[str] = ""
|
||||
source: Optional[str] = ""
|
||||
sanitized_number: Optional[str] = ""
|
||||
|
||||
|
||||
class SenorityLevels(str, Enum):
|
||||
@@ -56,102 +56,102 @@ class ContactEmailStatuses(str, Enum):
|
||||
class RuleConfigStatus(BaseModel):
|
||||
"""A rule config status in Apollo"""
|
||||
|
||||
_id: str = ""
|
||||
created_at: str = ""
|
||||
rule_action_config_id: str = ""
|
||||
rule_config_id: str = ""
|
||||
status_cd: str = ""
|
||||
updated_at: str = ""
|
||||
id: str = ""
|
||||
key: str = ""
|
||||
_id: Optional[str] = ""
|
||||
created_at: Optional[str] = ""
|
||||
rule_action_config_id: Optional[str] = ""
|
||||
rule_config_id: Optional[str] = ""
|
||||
status_cd: Optional[str] = ""
|
||||
updated_at: Optional[str] = ""
|
||||
id: Optional[str] = ""
|
||||
key: Optional[str] = ""
|
||||
|
||||
|
||||
class ContactCampaignStatus(BaseModel):
|
||||
"""A contact campaign status in Apollo"""
|
||||
|
||||
id: str = ""
|
||||
emailer_campaign_id: str = ""
|
||||
send_email_from_user_id: str = ""
|
||||
inactive_reason: str = ""
|
||||
status: str = ""
|
||||
added_at: str = ""
|
||||
added_by_user_id: str = ""
|
||||
finished_at: str = ""
|
||||
paused_at: str = ""
|
||||
auto_unpause_at: str = ""
|
||||
send_email_from_email_address: str = ""
|
||||
send_email_from_email_account_id: str = ""
|
||||
manually_set_unpause: str = ""
|
||||
failure_reason: str = ""
|
||||
current_step_id: str = ""
|
||||
in_response_to_emailer_message_id: str = ""
|
||||
cc_emails: str = ""
|
||||
bcc_emails: str = ""
|
||||
to_emails: str = ""
|
||||
id: Optional[str] = ""
|
||||
emailer_campaign_id: Optional[str] = ""
|
||||
send_email_from_user_id: Optional[str] = ""
|
||||
inactive_reason: Optional[str] = ""
|
||||
status: Optional[str] = ""
|
||||
added_at: Optional[str] = ""
|
||||
added_by_user_id: Optional[str] = ""
|
||||
finished_at: Optional[str] = ""
|
||||
paused_at: Optional[str] = ""
|
||||
auto_unpause_at: Optional[str] = ""
|
||||
send_email_from_email_address: Optional[str] = ""
|
||||
send_email_from_email_account_id: Optional[str] = ""
|
||||
manually_set_unpause: Optional[str] = ""
|
||||
failure_reason: Optional[str] = ""
|
||||
current_step_id: Optional[str] = ""
|
||||
in_response_to_emailer_message_id: Optional[str] = ""
|
||||
cc_emails: Optional[str] = ""
|
||||
bcc_emails: Optional[str] = ""
|
||||
to_emails: Optional[str] = ""
|
||||
|
||||
|
||||
class Account(BaseModel):
|
||||
"""An account in Apollo"""
|
||||
|
||||
id: str = ""
|
||||
name: str = ""
|
||||
website_url: str = ""
|
||||
blog_url: str = ""
|
||||
angellist_url: str = ""
|
||||
linkedin_url: str = ""
|
||||
twitter_url: str = ""
|
||||
facebook_url: str = ""
|
||||
primary_phone: PrimaryPhone = PrimaryPhone()
|
||||
languages: list[str]
|
||||
alexa_ranking: int = 0
|
||||
phone: str = ""
|
||||
linkedin_uid: str = ""
|
||||
founded_year: int = 0
|
||||
publicly_traded_symbol: str = ""
|
||||
publicly_traded_exchange: str = ""
|
||||
logo_url: str = ""
|
||||
chrunchbase_url: str = ""
|
||||
primary_domain: str = ""
|
||||
domain: str = ""
|
||||
team_id: str = ""
|
||||
organization_id: str = ""
|
||||
account_stage_id: str = ""
|
||||
source: str = ""
|
||||
original_source: str = ""
|
||||
creator_id: str = ""
|
||||
owner_id: str = ""
|
||||
created_at: str = ""
|
||||
phone_status: str = ""
|
||||
hubspot_id: str = ""
|
||||
salesforce_id: str = ""
|
||||
crm_owner_id: str = ""
|
||||
parent_account_id: str = ""
|
||||
sanitized_phone: str = ""
|
||||
id: Optional[str] = ""
|
||||
name: Optional[str] = ""
|
||||
website_url: Optional[str] = ""
|
||||
blog_url: Optional[str] = ""
|
||||
angellist_url: Optional[str] = ""
|
||||
linkedin_url: Optional[str] = ""
|
||||
twitter_url: Optional[str] = ""
|
||||
facebook_url: Optional[str] = ""
|
||||
primary_phone: Optional[PrimaryPhone] = PrimaryPhone()
|
||||
languages: Optional[list[str]] = []
|
||||
alexa_ranking: Optional[int] = 0
|
||||
phone: Optional[str] = ""
|
||||
linkedin_uid: Optional[str] = ""
|
||||
founded_year: Optional[int] = 0
|
||||
publicly_traded_symbol: Optional[str] = ""
|
||||
publicly_traded_exchange: Optional[str] = ""
|
||||
logo_url: Optional[str] = ""
|
||||
chrunchbase_url: Optional[str] = ""
|
||||
primary_domain: Optional[str] = ""
|
||||
domain: Optional[str] = ""
|
||||
team_id: Optional[str] = ""
|
||||
organization_id: Optional[str] = ""
|
||||
account_stage_id: Optional[str] = ""
|
||||
source: Optional[str] = ""
|
||||
original_source: Optional[str] = ""
|
||||
creator_id: Optional[str] = ""
|
||||
owner_id: Optional[str] = ""
|
||||
created_at: Optional[str] = ""
|
||||
phone_status: Optional[str] = ""
|
||||
hubspot_id: Optional[str] = ""
|
||||
salesforce_id: Optional[str] = ""
|
||||
crm_owner_id: Optional[str] = ""
|
||||
parent_account_id: Optional[str] = ""
|
||||
sanitized_phone: Optional[str] = ""
|
||||
# no listed type on the API docs
|
||||
account_playbook_statues: list[Any] = []
|
||||
account_rule_config_statuses: list[RuleConfigStatus] = []
|
||||
existence_level: str = ""
|
||||
label_ids: list[str] = []
|
||||
typed_custom_fields: Any
|
||||
custom_field_errors: Any
|
||||
modality: str = ""
|
||||
source_display_name: str = ""
|
||||
salesforce_record_id: str = ""
|
||||
crm_record_url: str = ""
|
||||
account_playbook_statues: Optional[list[Any]] = []
|
||||
account_rule_config_statuses: Optional[list[RuleConfigStatus]] = []
|
||||
existence_level: Optional[str] = ""
|
||||
label_ids: Optional[list[str]] = []
|
||||
typed_custom_fields: Optional[Any] = {}
|
||||
custom_field_errors: Optional[Any] = {}
|
||||
modality: Optional[str] = ""
|
||||
source_display_name: Optional[str] = ""
|
||||
salesforce_record_id: Optional[str] = ""
|
||||
crm_record_url: Optional[str] = ""
|
||||
|
||||
|
||||
class ContactEmail(BaseModel):
|
||||
"""A contact email in Apollo"""
|
||||
|
||||
email: str = ""
|
||||
email_md5: str = ""
|
||||
email_sha256: str = ""
|
||||
email_status: str = ""
|
||||
email_source: str = ""
|
||||
extrapolated_email_confidence: str = ""
|
||||
position: int = 0
|
||||
email_from_customer: str = ""
|
||||
free_domain: bool = True
|
||||
email: Optional[str] = ""
|
||||
email_md5: Optional[str] = ""
|
||||
email_sha256: Optional[str] = ""
|
||||
email_status: Optional[str] = ""
|
||||
email_source: Optional[str] = ""
|
||||
extrapolated_email_confidence: Optional[str] = ""
|
||||
position: Optional[int] = 0
|
||||
email_from_customer: Optional[str] = ""
|
||||
free_domain: Optional[bool] = True
|
||||
|
||||
|
||||
class EmploymentHistory(BaseModel):
|
||||
@@ -164,40 +164,40 @@ class EmploymentHistory(BaseModel):
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
_id: Optional[str] = None
|
||||
created_at: Optional[str] = None
|
||||
current: Optional[bool] = None
|
||||
degree: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
emails: Optional[str] = None
|
||||
end_date: Optional[str] = None
|
||||
grade_level: Optional[str] = None
|
||||
kind: Optional[str] = None
|
||||
major: Optional[str] = None
|
||||
organization_id: Optional[str] = None
|
||||
organization_name: Optional[str] = None
|
||||
raw_address: Optional[str] = None
|
||||
start_date: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
updated_at: Optional[str] = None
|
||||
id: Optional[str] = None
|
||||
key: Optional[str] = None
|
||||
_id: Optional[str] = ""
|
||||
created_at: Optional[str] = ""
|
||||
current: Optional[bool] = False
|
||||
degree: Optional[str] = ""
|
||||
description: Optional[str] = ""
|
||||
emails: Optional[str] = ""
|
||||
end_date: Optional[str] = ""
|
||||
grade_level: Optional[str] = ""
|
||||
kind: Optional[str] = ""
|
||||
major: Optional[str] = ""
|
||||
organization_id: Optional[str] = ""
|
||||
organization_name: Optional[str] = ""
|
||||
raw_address: Optional[str] = ""
|
||||
start_date: Optional[str] = ""
|
||||
title: Optional[str] = ""
|
||||
updated_at: Optional[str] = ""
|
||||
id: Optional[str] = ""
|
||||
key: Optional[str] = ""
|
||||
|
||||
|
||||
class Breadcrumb(BaseModel):
|
||||
"""A breadcrumb in Apollo"""
|
||||
|
||||
label: Optional[str] = "N/A"
|
||||
signal_field_name: Optional[str] = "N/A"
|
||||
value: str | list | None = "N/A"
|
||||
display_name: Optional[str] = "N/A"
|
||||
label: Optional[str] = ""
|
||||
signal_field_name: Optional[str] = ""
|
||||
value: str | list | None = ""
|
||||
display_name: Optional[str] = ""
|
||||
|
||||
|
||||
class TypedCustomField(BaseModel):
|
||||
"""A typed custom field in Apollo"""
|
||||
|
||||
id: Optional[str] = "N/A"
|
||||
value: Optional[str] = "N/A"
|
||||
id: Optional[str] = ""
|
||||
value: Optional[str] = ""
|
||||
|
||||
|
||||
class Pagination(BaseModel):
|
||||
@@ -219,23 +219,23 @@ class Pagination(BaseModel):
|
||||
class DialerFlags(BaseModel):
|
||||
"""A dialer flags in Apollo"""
|
||||
|
||||
country_name: str = ""
|
||||
country_enabled: bool
|
||||
high_risk_calling_enabled: bool
|
||||
potential_high_risk_number: bool
|
||||
country_name: Optional[str] = ""
|
||||
country_enabled: Optional[bool] = True
|
||||
high_risk_calling_enabled: Optional[bool] = True
|
||||
potential_high_risk_number: Optional[bool] = True
|
||||
|
||||
|
||||
class PhoneNumber(BaseModel):
|
||||
"""A phone number in Apollo"""
|
||||
|
||||
raw_number: str = ""
|
||||
sanitized_number: str = ""
|
||||
type: str = ""
|
||||
position: int = 0
|
||||
status: str = ""
|
||||
dnc_status: str = ""
|
||||
dnc_other_info: str = ""
|
||||
dailer_flags: DialerFlags = DialerFlags(
|
||||
raw_number: Optional[str] = ""
|
||||
sanitized_number: Optional[str] = ""
|
||||
type: Optional[str] = ""
|
||||
position: Optional[int] = 0
|
||||
status: Optional[str] = ""
|
||||
dnc_status: Optional[str] = ""
|
||||
dnc_other_info: Optional[str] = ""
|
||||
dailer_flags: Optional[DialerFlags] = DialerFlags(
|
||||
country_name="",
|
||||
country_enabled=True,
|
||||
high_risk_calling_enabled=True,
|
||||
@@ -253,33 +253,31 @@ class Organization(BaseModel):
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
id: Optional[str] = "N/A"
|
||||
name: Optional[str] = "N/A"
|
||||
website_url: Optional[str] = "N/A"
|
||||
blog_url: Optional[str] = "N/A"
|
||||
angellist_url: Optional[str] = "N/A"
|
||||
linkedin_url: Optional[str] = "N/A"
|
||||
twitter_url: Optional[str] = "N/A"
|
||||
facebook_url: Optional[str] = "N/A"
|
||||
primary_phone: Optional[PrimaryPhone] = PrimaryPhone(
|
||||
number="N/A", source="N/A", sanitized_number="N/A"
|
||||
)
|
||||
languages: list[str] = []
|
||||
id: Optional[str] = ""
|
||||
name: Optional[str] = ""
|
||||
website_url: Optional[str] = ""
|
||||
blog_url: Optional[str] = ""
|
||||
angellist_url: Optional[str] = ""
|
||||
linkedin_url: Optional[str] = ""
|
||||
twitter_url: Optional[str] = ""
|
||||
facebook_url: Optional[str] = ""
|
||||
primary_phone: Optional[PrimaryPhone] = PrimaryPhone()
|
||||
languages: Optional[list[str]] = []
|
||||
alexa_ranking: Optional[int] = 0
|
||||
phone: Optional[str] = "N/A"
|
||||
linkedin_uid: Optional[str] = "N/A"
|
||||
phone: Optional[str] = ""
|
||||
linkedin_uid: Optional[str] = ""
|
||||
founded_year: Optional[int] = 0
|
||||
publicly_traded_symbol: Optional[str] = "N/A"
|
||||
publicly_traded_exchange: Optional[str] = "N/A"
|
||||
logo_url: Optional[str] = "N/A"
|
||||
chrunchbase_url: Optional[str] = "N/A"
|
||||
primary_domain: Optional[str] = "N/A"
|
||||
sanitized_phone: Optional[str] = "N/A"
|
||||
owned_by_organization_id: Optional[str] = "N/A"
|
||||
intent_strength: Optional[str] = "N/A"
|
||||
show_intent: bool = True
|
||||
publicly_traded_symbol: Optional[str] = ""
|
||||
publicly_traded_exchange: Optional[str] = ""
|
||||
logo_url: Optional[str] = ""
|
||||
chrunchbase_url: Optional[str] = ""
|
||||
primary_domain: Optional[str] = ""
|
||||
sanitized_phone: Optional[str] = ""
|
||||
owned_by_organization_id: Optional[str] = ""
|
||||
intent_strength: Optional[str] = ""
|
||||
show_intent: Optional[bool] = True
|
||||
has_intent_signal_account: Optional[bool] = True
|
||||
intent_signal_account: Optional[str] = "N/A"
|
||||
intent_signal_account: Optional[str] = ""
|
||||
|
||||
|
||||
class Contact(BaseModel):
|
||||
@@ -292,95 +290,95 @@ class Contact(BaseModel):
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
contact_roles: list[Any] = []
|
||||
id: Optional[str] = None
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
linkedin_url: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
contact_stage_id: Optional[str] = None
|
||||
owner_id: Optional[str] = None
|
||||
creator_id: Optional[str] = None
|
||||
person_id: Optional[str] = None
|
||||
email_needs_tickling: bool = True
|
||||
organization_name: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
original_source: Optional[str] = None
|
||||
organization_id: Optional[str] = None
|
||||
headline: Optional[str] = None
|
||||
photo_url: Optional[str] = None
|
||||
present_raw_address: Optional[str] = None
|
||||
linkededin_uid: Optional[str] = None
|
||||
extrapolated_email_confidence: Optional[float] = None
|
||||
salesforce_id: Optional[str] = None
|
||||
salesforce_lead_id: Optional[str] = None
|
||||
salesforce_contact_id: Optional[str] = None
|
||||
saleforce_account_id: Optional[str] = None
|
||||
crm_owner_id: Optional[str] = None
|
||||
created_at: Optional[str] = None
|
||||
emailer_campaign_ids: list[str] = []
|
||||
direct_dial_status: Optional[str] = None
|
||||
direct_dial_enrichment_failed_at: Optional[str] = None
|
||||
email_status: Optional[str] = None
|
||||
email_source: Optional[str] = None
|
||||
account_id: Optional[str] = None
|
||||
last_activity_date: Optional[str] = None
|
||||
hubspot_vid: Optional[str] = None
|
||||
hubspot_company_id: Optional[str] = None
|
||||
crm_id: Optional[str] = None
|
||||
sanitized_phone: Optional[str] = None
|
||||
merged_crm_ids: Optional[str] = None
|
||||
updated_at: Optional[str] = None
|
||||
queued_for_crm_push: bool = True
|
||||
suggested_from_rule_engine_config_id: Optional[str] = None
|
||||
email_unsubscribed: Optional[str] = None
|
||||
label_ids: list[Any] = []
|
||||
has_pending_email_arcgate_request: bool = True
|
||||
has_email_arcgate_request: bool = True
|
||||
existence_level: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
email_from_customer: Optional[str] = None
|
||||
typed_custom_fields: list[TypedCustomField] = []
|
||||
custom_field_errors: Any = None
|
||||
salesforce_record_id: Optional[str] = None
|
||||
crm_record_url: Optional[str] = None
|
||||
email_status_unavailable_reason: Optional[str] = None
|
||||
email_true_status: Optional[str] = None
|
||||
updated_email_true_status: bool = True
|
||||
contact_rule_config_statuses: list[RuleConfigStatus] = []
|
||||
source_display_name: Optional[str] = None
|
||||
twitter_url: Optional[str] = None
|
||||
contact_campaign_statuses: list[ContactCampaignStatus] = []
|
||||
state: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
country: Optional[str] = None
|
||||
account: Optional[Account] = None
|
||||
contact_emails: list[ContactEmail] = []
|
||||
organization: Optional[Organization] = None
|
||||
employment_history: list[EmploymentHistory] = []
|
||||
time_zone: Optional[str] = None
|
||||
intent_strength: Optional[str] = None
|
||||
show_intent: bool = True
|
||||
phone_numbers: list[PhoneNumber] = []
|
||||
account_phone_note: Optional[str] = None
|
||||
free_domain: bool = True
|
||||
is_likely_to_engage: bool = True
|
||||
email_domain_catchall: bool = True
|
||||
contact_job_change_event: Optional[str] = None
|
||||
contact_roles: Optional[list[Any]] = []
|
||||
id: Optional[str] = ""
|
||||
first_name: Optional[str] = ""
|
||||
last_name: Optional[str] = ""
|
||||
name: Optional[str] = ""
|
||||
linkedin_url: Optional[str] = ""
|
||||
title: Optional[str] = ""
|
||||
contact_stage_id: Optional[str] = ""
|
||||
owner_id: Optional[str] = ""
|
||||
creator_id: Optional[str] = ""
|
||||
person_id: Optional[str] = ""
|
||||
email_needs_tickling: Optional[bool] = True
|
||||
organization_name: Optional[str] = ""
|
||||
source: Optional[str] = ""
|
||||
original_source: Optional[str] = ""
|
||||
organization_id: Optional[str] = ""
|
||||
headline: Optional[str] = ""
|
||||
photo_url: Optional[str] = ""
|
||||
present_raw_address: Optional[str] = ""
|
||||
linkededin_uid: Optional[str] = ""
|
||||
extrapolated_email_confidence: Optional[float] = 0.0
|
||||
salesforce_id: Optional[str] = ""
|
||||
salesforce_lead_id: Optional[str] = ""
|
||||
salesforce_contact_id: Optional[str] = ""
|
||||
saleforce_account_id: Optional[str] = ""
|
||||
crm_owner_id: Optional[str] = ""
|
||||
created_at: Optional[str] = ""
|
||||
emailer_campaign_ids: Optional[list[str]] = []
|
||||
direct_dial_status: Optional[str] = ""
|
||||
direct_dial_enrichment_failed_at: Optional[str] = ""
|
||||
email_status: Optional[str] = ""
|
||||
email_source: Optional[str] = ""
|
||||
account_id: Optional[str] = ""
|
||||
last_activity_date: Optional[str] = ""
|
||||
hubspot_vid: Optional[str] = ""
|
||||
hubspot_company_id: Optional[str] = ""
|
||||
crm_id: Optional[str] = ""
|
||||
sanitized_phone: Optional[str] = ""
|
||||
merged_crm_ids: Optional[str] = ""
|
||||
updated_at: Optional[str] = ""
|
||||
queued_for_crm_push: Optional[bool] = True
|
||||
suggested_from_rule_engine_config_id: Optional[str] = ""
|
||||
email_unsubscribed: Optional[str] = ""
|
||||
label_ids: Optional[list[Any]] = []
|
||||
has_pending_email_arcgate_request: Optional[bool] = True
|
||||
has_email_arcgate_request: Optional[bool] = True
|
||||
existence_level: Optional[str] = ""
|
||||
email: Optional[str] = ""
|
||||
email_from_customer: Optional[str] = ""
|
||||
typed_custom_fields: Optional[list[TypedCustomField]] = []
|
||||
custom_field_errors: Optional[Any] = {}
|
||||
salesforce_record_id: Optional[str] = ""
|
||||
crm_record_url: Optional[str] = ""
|
||||
email_status_unavailable_reason: Optional[str] = ""
|
||||
email_true_status: Optional[str] = ""
|
||||
updated_email_true_status: Optional[bool] = True
|
||||
contact_rule_config_statuses: Optional[list[RuleConfigStatus]] = []
|
||||
source_display_name: Optional[str] = ""
|
||||
twitter_url: Optional[str] = ""
|
||||
contact_campaign_statuses: Optional[list[ContactCampaignStatus]] = []
|
||||
state: Optional[str] = ""
|
||||
city: Optional[str] = ""
|
||||
country: Optional[str] = ""
|
||||
account: Optional[Account] = Account()
|
||||
contact_emails: Optional[list[ContactEmail]] = []
|
||||
organization: Optional[Organization] = Organization()
|
||||
employment_history: Optional[list[EmploymentHistory]] = []
|
||||
time_zone: Optional[str] = ""
|
||||
intent_strength: Optional[str] = ""
|
||||
show_intent: Optional[bool] = True
|
||||
phone_numbers: Optional[list[PhoneNumber]] = []
|
||||
account_phone_note: Optional[str] = ""
|
||||
free_domain: Optional[bool] = True
|
||||
is_likely_to_engage: Optional[bool] = True
|
||||
email_domain_catchall: Optional[bool] = True
|
||||
contact_job_change_event: Optional[str] = ""
|
||||
|
||||
|
||||
class SearchOrganizationsRequest(BaseModel):
|
||||
"""Request for Apollo's search organizations API"""
|
||||
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
organization_num_employees_range: Optional[list[int]] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
default=[0, 1000000],
|
||||
)
|
||||
|
||||
organization_locations: list[str] = SchemaField(
|
||||
organization_locations: Optional[list[str]] = SchemaField(
|
||||
description="""The location of the company headquarters. You can search across cities, US states, and countries.
|
||||
|
||||
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, any Boston-based companies will not appearch in your search results, even if they match other parameters.
|
||||
@@ -389,28 +387,30 @@ To exclude companies based on location, use the organization_not_locations param
|
||||
""",
|
||||
default_factory=list,
|
||||
)
|
||||
organizations_not_locations: list[str] = SchemaField(
|
||||
organizations_not_locations: Optional[list[str]] = SchemaField(
|
||||
description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.
|
||||
|
||||
This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
|
||||
""",
|
||||
default_factory=list,
|
||||
)
|
||||
q_organization_keyword_tags: list[str] = SchemaField(
|
||||
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry."""
|
||||
q_organization_keyword_tags: Optional[list[str]] = SchemaField(
|
||||
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry.""",
|
||||
default_factory=list,
|
||||
)
|
||||
q_organization_name: str = SchemaField(
|
||||
q_organization_name: Optional[str] = SchemaField(
|
||||
description="""Filter search results to include a specific company name.
|
||||
|
||||
If the value you enter for this parameter does not match with a company's name, the company will not appear in search results, even if it matches other parameters. Partial matches are accepted. For example, if you filter by the value marketing, a company called NY Marketing Unlimited would still be eligible as a search result, but NY Market Analysis would not be eligible."""
|
||||
If the value you enter for this parameter does not match with a company's name, the company will not appear in search results, even if it matches other parameters. Partial matches are accepted. For example, if you filter by the value marketing, a company called NY Marketing Unlimited would still be eligible as a search result, but NY Market Analysis would not be eligible.""",
|
||||
default="",
|
||||
)
|
||||
organization_ids: list[str] = SchemaField(
|
||||
organization_ids: Optional[list[str]] = SchemaField(
|
||||
description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
|
||||
|
||||
To find IDs, identify the values for organization_id when you call this endpoint.""",
|
||||
default_factory=list,
|
||||
)
|
||||
max_results: int = SchemaField(
|
||||
max_results: Optional[int] = SchemaField(
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
|
||||
default=100,
|
||||
ge=1,
|
||||
@@ -435,11 +435,11 @@ Use the page parameter to search the different pages of data.""",
|
||||
class SearchOrganizationsResponse(BaseModel):
|
||||
"""Response from Apollo's search organizations API"""
|
||||
|
||||
breadcrumbs: list[Breadcrumb] = []
|
||||
partial_results_only: bool = True
|
||||
has_join: bool = True
|
||||
disable_eu_prospecting: bool = True
|
||||
partial_results_limit: int = 0
|
||||
breadcrumbs: Optional[list[Breadcrumb]] = []
|
||||
partial_results_only: Optional[bool] = True
|
||||
has_join: Optional[bool] = True
|
||||
disable_eu_prospecting: Optional[bool] = True
|
||||
partial_results_limit: Optional[int] = 0
|
||||
pagination: Pagination = Pagination(
|
||||
page=0, per_page=0, total_entries=0, total_pages=0
|
||||
)
|
||||
@@ -447,14 +447,14 @@ class SearchOrganizationsResponse(BaseModel):
|
||||
accounts: list[Any] = []
|
||||
organizations: list[Organization] = []
|
||||
models_ids: list[str] = []
|
||||
num_fetch_result: Optional[str] = "N/A"
|
||||
derived_params: Optional[str] = "N/A"
|
||||
num_fetch_result: Optional[str] = ""
|
||||
derived_params: Optional[str] = ""
|
||||
|
||||
|
||||
class SearchPeopleRequest(BaseModel):
|
||||
"""Request for Apollo's search people API"""
|
||||
|
||||
person_titles: list[str] = SchemaField(
|
||||
person_titles: Optional[list[str]] = SchemaField(
|
||||
description="""Job titles held by the people you want to find. For a person to be included in search results, they only need to match 1 of the job titles you add. Adding more job titles expands your search results.
|
||||
|
||||
Results also include job titles with the same terms, even if they are not exact matches. For example, searching for marketing manager might return people with the job title content marketing manager.
|
||||
@@ -464,13 +464,13 @@ Use this parameter in combination with the person_seniorities[] parameter to fin
|
||||
default_factory=list,
|
||||
placeholder="marketing manager",
|
||||
)
|
||||
person_locations: list[str] = SchemaField(
|
||||
person_locations: Optional[list[str]] = SchemaField(
|
||||
description="""The location where people live. You can search across cities, US states, and countries.
|
||||
|
||||
To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
|
||||
default_factory=list,
|
||||
)
|
||||
person_seniorities: list[SenorityLevels] = SchemaField(
|
||||
person_seniorities: Optional[list[SenorityLevels]] = SchemaField(
|
||||
description="""The job seniority that people hold within their current employer. This enables you to find people that currently hold positions at certain reporting levels, such as Director level or senior IC level.
|
||||
|
||||
For a person to be included in search results, they only need to match 1 of the seniorities you add. Adding more seniorities expands your search results.
|
||||
@@ -480,7 +480,7 @@ Searches only return results based on their current job title, so searching for
|
||||
Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
|
||||
default_factory=list,
|
||||
)
|
||||
organization_locations: list[str] = SchemaField(
|
||||
organization_locations: Optional[list[str]] = SchemaField(
|
||||
description="""The location of the company headquarters for a person's current employer. You can search across cities, US states, and countries.
|
||||
|
||||
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.
|
||||
@@ -488,7 +488,7 @@ If a company has several office locations, results are still based on the headqu
|
||||
To find people based on their personal location, use the person_locations parameter.""",
|
||||
default_factory=list,
|
||||
)
|
||||
q_organization_domains: list[str] = SchemaField(
|
||||
q_organization_domains: Optional[list[str]] = SchemaField(
|
||||
description="""The domain name for the person's employer. This can be the current employer or a previous employer. Do not include www., the @ symbol, or similar.
|
||||
|
||||
You can add multiple domains to search across companies.
|
||||
@@ -496,23 +496,23 @@ You can add multiple domains to search across companies.
|
||||
Examples: apollo.io and microsoft.com""",
|
||||
default_factory=list,
|
||||
)
|
||||
contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
|
||||
contact_email_statuses: Optional[list[ContactEmailStatuses]] = SchemaField(
|
||||
description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
|
||||
default_factory=list,
|
||||
)
|
||||
organization_ids: list[str] = SchemaField(
|
||||
organization_ids: Optional[list[str]] = SchemaField(
|
||||
description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
|
||||
|
||||
To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
|
||||
default_factory=list,
|
||||
)
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
organization_num_employees_range: Optional[list[int]] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
default_factory=list,
|
||||
)
|
||||
q_keywords: str = SchemaField(
|
||||
q_keywords: Optional[str] = SchemaField(
|
||||
description="""A string of words over which we want to filter the results""",
|
||||
default="",
|
||||
)
|
||||
@@ -528,7 +528,7 @@ Use this parameter in combination with the per_page parameter to make search res
|
||||
Use the page parameter to search the different pages of data.""",
|
||||
default=100,
|
||||
)
|
||||
max_results: int = SchemaField(
|
||||
max_results: Optional[int] = SchemaField(
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
|
||||
default=100,
|
||||
ge=1,
|
||||
@@ -547,16 +547,61 @@ class SearchPeopleResponse(BaseModel):
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
breadcrumbs: list[Breadcrumb] = []
|
||||
partial_results_only: bool = True
|
||||
has_join: bool = True
|
||||
disable_eu_prospecting: bool = True
|
||||
partial_results_limit: int = 0
|
||||
breadcrumbs: Optional[list[Breadcrumb]] = []
|
||||
partial_results_only: Optional[bool] = True
|
||||
has_join: Optional[bool] = True
|
||||
disable_eu_prospecting: Optional[bool] = True
|
||||
partial_results_limit: Optional[int] = 0
|
||||
pagination: Pagination = Pagination(
|
||||
page=0, per_page=0, total_entries=0, total_pages=0
|
||||
)
|
||||
contacts: list[Contact] = []
|
||||
people: list[Contact] = []
|
||||
model_ids: list[str] = []
|
||||
num_fetch_result: Optional[str] = "N/A"
|
||||
derived_params: Optional[str] = "N/A"
|
||||
num_fetch_result: Optional[str] = ""
|
||||
derived_params: Optional[str] = ""
|
||||
|
||||
|
||||
class EnrichPersonRequest(BaseModel):
|
||||
"""Request for Apollo's person enrichment API"""
|
||||
|
||||
person_id: Optional[str] = SchemaField(
|
||||
description="Apollo person ID to enrich (most accurate method)",
|
||||
default="",
|
||||
)
|
||||
first_name: Optional[str] = SchemaField(
|
||||
description="First name of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
last_name: Optional[str] = SchemaField(
|
||||
description="Last name of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
name: Optional[str] = SchemaField(
|
||||
description="Full name of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
email: Optional[str] = SchemaField(
|
||||
description="Email address of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
domain: Optional[str] = SchemaField(
|
||||
description="Company domain of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
company: Optional[str] = SchemaField(
|
||||
description="Company name of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
linkedin_url: Optional[str] = SchemaField(
|
||||
description="LinkedIn URL of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
organization_id: Optional[str] = SchemaField(
|
||||
description="Apollo organization ID of the person's company",
|
||||
default="",
|
||||
)
|
||||
title: Optional[str] = SchemaField(
|
||||
description="Job title of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
|
||||
@@ -11,14 +11,14 @@ from backend.blocks.apollo.models import (
|
||||
SearchOrganizationsRequest,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class SearchOrganizationsBlock(Block):
|
||||
"""Search for organizations in Apollo"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
organization_num_employees_range: list[int] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
@@ -65,7 +65,7 @@ To find IDs, identify the values for organization_id when you call this endpoint
|
||||
le=50000,
|
||||
advanced=True,
|
||||
)
|
||||
credentials: ApolloCredentialsInput = SchemaField(
|
||||
credentials: ApolloCredentialsInput = CredentialsField(
|
||||
description="Apollo credentials",
|
||||
)
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import asyncio
|
||||
|
||||
from backend.blocks.apollo._api import ApolloClient
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
@@ -8,11 +10,12 @@ from backend.blocks.apollo._auth import (
|
||||
from backend.blocks.apollo.models import (
|
||||
Contact,
|
||||
ContactEmailStatuses,
|
||||
EnrichPersonRequest,
|
||||
SearchPeopleRequest,
|
||||
SenorityLevels,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class SearchPeopleBlock(Block):
|
||||
@@ -77,7 +80,7 @@ class SearchPeopleBlock(Block):
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
organization_num_employees_range: list[int] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
@@ -90,14 +93,19 @@ class SearchPeopleBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
max_results: int = SchemaField(
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
|
||||
default=100,
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 25. Limited to 500 to prevent overspending.""",
|
||||
default=25,
|
||||
ge=1,
|
||||
le=50000,
|
||||
le=500,
|
||||
advanced=True,
|
||||
)
|
||||
enrich_info: bool = SchemaField(
|
||||
description="""Whether to enrich contacts with detailed information including real email addresses. This will double the search cost.""",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
credentials: ApolloCredentialsInput = SchemaField(
|
||||
credentials: ApolloCredentialsInput = CredentialsField(
|
||||
description="Apollo credentials",
|
||||
)
|
||||
|
||||
@@ -106,10 +114,6 @@ class SearchPeopleBlock(Block):
|
||||
description="List of people found",
|
||||
default_factory=list,
|
||||
)
|
||||
person: Contact = SchemaField(
|
||||
title="Person",
|
||||
description="Each found person, one at a time",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the search failed",
|
||||
default="",
|
||||
@@ -125,87 +129,6 @@ class SearchPeopleBlock(Block):
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
(
|
||||
"person",
|
||||
Contact(
|
||||
contact_roles=[],
|
||||
id="1",
|
||||
name="John Doe",
|
||||
first_name="John",
|
||||
last_name="Doe",
|
||||
linkedin_url="https://www.linkedin.com/in/johndoe",
|
||||
title="Software Engineer",
|
||||
organization_name="Google",
|
||||
organization_id="123456",
|
||||
contact_stage_id="1",
|
||||
owner_id="1",
|
||||
creator_id="1",
|
||||
person_id="1",
|
||||
email_needs_tickling=True,
|
||||
source="apollo",
|
||||
original_source="apollo",
|
||||
headline="Software Engineer",
|
||||
photo_url="https://www.linkedin.com/in/johndoe",
|
||||
present_raw_address="123 Main St, Anytown, USA",
|
||||
linkededin_uid="123456",
|
||||
extrapolated_email_confidence=0.8,
|
||||
salesforce_id="123456",
|
||||
salesforce_lead_id="123456",
|
||||
salesforce_contact_id="123456",
|
||||
saleforce_account_id="123456",
|
||||
crm_owner_id="123456",
|
||||
created_at="2021-01-01",
|
||||
emailer_campaign_ids=[],
|
||||
direct_dial_status="active",
|
||||
direct_dial_enrichment_failed_at="2021-01-01",
|
||||
email_status="active",
|
||||
email_source="apollo",
|
||||
account_id="123456",
|
||||
last_activity_date="2021-01-01",
|
||||
hubspot_vid="123456",
|
||||
hubspot_company_id="123456",
|
||||
crm_id="123456",
|
||||
sanitized_phone="123456",
|
||||
merged_crm_ids="123456",
|
||||
updated_at="2021-01-01",
|
||||
queued_for_crm_push=True,
|
||||
suggested_from_rule_engine_config_id="123456",
|
||||
email_unsubscribed=None,
|
||||
label_ids=[],
|
||||
has_pending_email_arcgate_request=True,
|
||||
has_email_arcgate_request=True,
|
||||
existence_level=None,
|
||||
email=None,
|
||||
email_from_customer=None,
|
||||
typed_custom_fields=[],
|
||||
custom_field_errors=None,
|
||||
salesforce_record_id=None,
|
||||
crm_record_url=None,
|
||||
email_status_unavailable_reason=None,
|
||||
email_true_status=None,
|
||||
updated_email_true_status=True,
|
||||
contact_rule_config_statuses=[],
|
||||
source_display_name=None,
|
||||
twitter_url=None,
|
||||
contact_campaign_statuses=[],
|
||||
state=None,
|
||||
city=None,
|
||||
country=None,
|
||||
account=None,
|
||||
contact_emails=[],
|
||||
organization=None,
|
||||
employment_history=[],
|
||||
time_zone=None,
|
||||
intent_strength=None,
|
||||
show_intent=True,
|
||||
phone_numbers=[],
|
||||
account_phone_note=None,
|
||||
free_domain=True,
|
||||
is_likely_to_engage=True,
|
||||
email_domain_catchall=True,
|
||||
contact_job_change_event=None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"people",
|
||||
[
|
||||
@@ -380,6 +303,34 @@ class SearchPeopleBlock(Block):
|
||||
client = ApolloClient(credentials)
|
||||
return await client.search_people(query)
|
||||
|
||||
@staticmethod
|
||||
async def enrich_person(
|
||||
query: EnrichPersonRequest, credentials: ApolloCredentials
|
||||
) -> Contact:
|
||||
client = ApolloClient(credentials)
|
||||
return await client.enrich_person(query)
|
||||
|
||||
@staticmethod
|
||||
def merge_contact_data(original: Contact, enriched: Contact) -> Contact:
|
||||
"""
|
||||
Merge contact data from original search with enriched data.
|
||||
Enriched data complements original data, only filling in missing values.
|
||||
"""
|
||||
merged_data = original.model_dump()
|
||||
enriched_data = enriched.model_dump()
|
||||
|
||||
# Only update fields that are None, empty string, empty list, or default values in original
|
||||
for key, enriched_value in enriched_data.items():
|
||||
# Skip if enriched value is None, empty string, or empty list
|
||||
if enriched_value is None or enriched_value == "" or enriched_value == []:
|
||||
continue
|
||||
|
||||
# Update if original value is None, empty string, empty list, or zero
|
||||
if enriched_value:
|
||||
merged_data[key] = enriched_value
|
||||
|
||||
return Contact(**merged_data)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
@@ -390,6 +341,23 @@ class SearchPeopleBlock(Block):
|
||||
|
||||
query = SearchPeopleRequest(**input_data.model_dump())
|
||||
people = await self.search_people(query, credentials)
|
||||
for person in people:
|
||||
yield "person", person
|
||||
|
||||
# Enrich with detailed info if requested
|
||||
if input_data.enrich_info:
|
||||
|
||||
async def enrich_or_fallback(person: Contact):
|
||||
try:
|
||||
enrich_query = EnrichPersonRequest(person_id=person.id)
|
||||
enriched_person = await self.enrich_person(
|
||||
enrich_query, credentials
|
||||
)
|
||||
# Merge enriched data with original data, complementing instead of replacing
|
||||
return self.merge_contact_data(person, enriched_person)
|
||||
except Exception:
|
||||
return person # If enrichment fails, use original person data
|
||||
|
||||
people = await asyncio.gather(
|
||||
*(enrich_or_fallback(person) for person in people)
|
||||
)
|
||||
|
||||
yield "people", people
|
||||
|
||||
138
autogpt_platform/backend/backend/blocks/apollo/person.py
Normal file
138
autogpt_platform/backend/backend/blocks/apollo/person.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from backend.blocks.apollo._api import ApolloClient
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
ApolloCredentials,
|
||||
ApolloCredentialsInput,
|
||||
)
|
||||
from backend.blocks.apollo.models import Contact, EnrichPersonRequest
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class GetPersonDetailBlock(Block):
|
||||
"""Get detailed person data with Apollo API, including email reveal"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
person_id: str = SchemaField(
|
||||
description="Apollo person ID to enrich (most accurate method)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
first_name: str = SchemaField(
|
||||
description="First name of the person to enrich",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
last_name: str = SchemaField(
|
||||
description="Last name of the person to enrich",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
name: str = SchemaField(
|
||||
description="Full name of the person to enrich (alternative to first_name + last_name)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
email: str = SchemaField(
|
||||
description="Known email address of the person (helps with matching)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
domain: str = SchemaField(
|
||||
description="Company domain of the person (e.g., 'google.com')",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
company: str = SchemaField(
|
||||
description="Company name of the person",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
linkedin_url: str = SchemaField(
|
||||
description="LinkedIn URL of the person",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
organization_id: str = SchemaField(
|
||||
description="Apollo organization ID of the person's company",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
title: str = SchemaField(
|
||||
description="Job title of the person to enrich",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
credentials: ApolloCredentialsInput = CredentialsField(
|
||||
description="Apollo credentials",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
contact: Contact = SchemaField(
|
||||
description="Enriched contact information",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if enrichment failed",
|
||||
default="",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3b18d46c-3db6-42ae-a228-0ba441bdd176",
|
||||
description="Get detailed person data with Apollo API, including email reveal",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=GetPersonDetailBlock.Input,
|
||||
output_schema=GetPersonDetailBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"first_name": "John",
|
||||
"last_name": "Doe",
|
||||
"company": "Google",
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"contact",
|
||||
Contact(
|
||||
id="1",
|
||||
name="John Doe",
|
||||
first_name="John",
|
||||
last_name="Doe",
|
||||
email="john.doe@gmail.com",
|
||||
title="Software Engineer",
|
||||
organization_name="Google",
|
||||
linkedin_url="https://www.linkedin.com/in/johndoe",
|
||||
),
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"enrich_person": lambda query, credentials: Contact(
|
||||
id="1",
|
||||
name="John Doe",
|
||||
first_name="John",
|
||||
last_name="Doe",
|
||||
email="john.doe@gmail.com",
|
||||
title="Software Engineer",
|
||||
organization_name="Google",
|
||||
linkedin_url="https://www.linkedin.com/in/johndoe",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def enrich_person(
|
||||
query: EnrichPersonRequest, credentials: ApolloCredentials
|
||||
) -> Contact:
|
||||
client = ApolloClient(credentials)
|
||||
return await client.enrich_person(query)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: ApolloCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
query = EnrichPersonRequest(**input_data.model_dump())
|
||||
yield "contact", await self.enrich_person(query, credentials)
|
||||
@@ -6,6 +6,7 @@ from backend.data.model import SchemaField
|
||||
from backend.util import json
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.mock import MockObject
|
||||
from backend.util.prompt import estimate_token_count_str
|
||||
from backend.util.type import MediaFileType, convert
|
||||
|
||||
|
||||
@@ -14,6 +15,12 @@ class FileStoreBlock(Block):
|
||||
file_in: MediaFileType = SchemaField(
|
||||
description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
|
||||
)
|
||||
base_64: bool = SchemaField(
|
||||
description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).",
|
||||
default=False,
|
||||
advanced=True,
|
||||
title="Produce Base64 Output",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
file_out: MediaFileType = SchemaField(
|
||||
@@ -37,12 +44,11 @@ class FileStoreBlock(Block):
|
||||
graph_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
file_path = await store_media_file(
|
||||
yield "file_out", await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.file_in,
|
||||
return_content=False,
|
||||
return_content=input_data.base_64,
|
||||
)
|
||||
yield "file_out", file_path
|
||||
|
||||
|
||||
class StoreValueBlock(Block):
|
||||
@@ -461,6 +467,11 @@ class CreateListBlock(Block):
|
||||
description="Maximum size of the list. If provided, the list will be yielded in chunks of this size.",
|
||||
advanced=True,
|
||||
)
|
||||
max_tokens: int | None = SchemaField(
|
||||
default=None,
|
||||
description="Maximum tokens for the list. If provided, the list will be yielded in chunks that fit within this token limit.",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
list: List[Any] = SchemaField(
|
||||
@@ -471,7 +482,7 @@ class CreateListBlock(Block):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a912d5c7-6e00-4542-b2a9-8034136930e4",
|
||||
description="Creates a list with the specified values. Use this when you know all the values you want to add upfront.",
|
||||
description="Creates a list with the specified values. Use this when you know all the values you want to add upfront. This block can also yield the list in batches based on a maximum size or token limit.",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=CreateListBlock.Input,
|
||||
output_schema=CreateListBlock.Output,
|
||||
@@ -496,12 +507,30 @@ class CreateListBlock(Block):
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
max_size = input_data.max_size or len(input_data.values)
|
||||
for i in range(0, len(input_data.values), max_size):
|
||||
yield "list", input_data.values[i : i + max_size]
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to create list: {str(e)}"
|
||||
chunk = []
|
||||
cur_tokens, max_tokens = 0, input_data.max_tokens
|
||||
cur_size, max_size = 0, input_data.max_size
|
||||
|
||||
for value in input_data.values:
|
||||
if max_tokens:
|
||||
tokens = estimate_token_count_str(value)
|
||||
else:
|
||||
tokens = 0
|
||||
|
||||
# Check if adding this value would exceed either limit
|
||||
if (max_tokens and (cur_tokens + tokens > max_tokens)) or (
|
||||
max_size and (cur_size + 1 > max_size)
|
||||
):
|
||||
yield "list", chunk
|
||||
chunk = [value]
|
||||
cur_size, cur_tokens = 1, tokens
|
||||
else:
|
||||
chunk.append(value)
|
||||
cur_size, cur_tokens = cur_size + 1, cur_tokens + tokens
|
||||
|
||||
# Yield final chunk if any
|
||||
if chunk:
|
||||
yield "list", chunk
|
||||
|
||||
|
||||
class TypeOptions(enum.Enum):
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.type import convert
|
||||
|
||||
|
||||
class ComparisonOperator(Enum):
|
||||
@@ -181,7 +182,23 @@ class IfInputMatchesBlock(Block):
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
if input_data.input == input_data.value or input_data.input is input_data.value:
|
||||
|
||||
# If input_data.value is not matching input_data.input, convert value to type of input
|
||||
if (
|
||||
input_data.input != input_data.value
|
||||
and input_data.input is not input_data.value
|
||||
):
|
||||
try:
|
||||
# Only attempt conversion if input is not None and value is not None
|
||||
if input_data.input is not None and input_data.value is not None:
|
||||
input_type = type(input_data.input)
|
||||
# Avoid converting if input_type is Any or object
|
||||
if input_type not in (Any, object):
|
||||
input_data.value = convert(input_data.value, input_type)
|
||||
except Exception:
|
||||
pass # If conversion fails, just leave value as is
|
||||
|
||||
if input_data.input == input_data.value:
|
||||
yield "result", True
|
||||
yield "yes_output", input_data.yes_value
|
||||
else:
|
||||
|
||||
@@ -13,7 +13,7 @@ from backend.data.model import (
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.file import MediaFileType
|
||||
from backend.util.file import MediaFileType, store_media_file
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
@@ -108,7 +108,7 @@ class AIImageEditorBlock(Block):
|
||||
output_schema=AIImageEditorBlock.Output,
|
||||
test_input={
|
||||
"prompt": "Add a hat to the cat",
|
||||
"input_image": "https://example.com/cat.png",
|
||||
"input_image": "",
|
||||
"aspect_ratio": AspectRatio.MATCH_INPUT_IMAGE,
|
||||
"seed": None,
|
||||
"model": FluxKontextModelName.PRO,
|
||||
@@ -128,13 +128,22 @@ class AIImageEditorBlock(Block):
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
graph_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
result = await self.run_model(
|
||||
api_key=credentials.api_key,
|
||||
model_name=input_data.model.api_name,
|
||||
prompt=input_data.prompt,
|
||||
input_image=input_data.input_image,
|
||||
input_image_b64=(
|
||||
await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.input_image,
|
||||
return_content=True,
|
||||
)
|
||||
if input_data.input_image
|
||||
else None
|
||||
),
|
||||
aspect_ratio=input_data.aspect_ratio.value,
|
||||
seed=input_data.seed,
|
||||
)
|
||||
@@ -145,14 +154,14 @@ class AIImageEditorBlock(Block):
|
||||
api_key: SecretStr,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
input_image: Optional[MediaFileType],
|
||||
input_image_b64: Optional[str],
|
||||
aspect_ratio: str,
|
||||
seed: Optional[int],
|
||||
) -> MediaFileType:
|
||||
client = ReplicateClient(api_token=api_key.get_secret_value())
|
||||
input_params = {
|
||||
"prompt": prompt,
|
||||
"input_image": input_image,
|
||||
"input_image": input_image_b64,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
**({"seed": seed} if seed is not None else {}),
|
||||
}
|
||||
|
||||
@@ -265,10 +265,26 @@ class GithubReadPullRequestBlock(Block):
|
||||
files = response.json()
|
||||
changes = []
|
||||
for file in files:
|
||||
filename = file.get("filename", "")
|
||||
status = file.get("status", "")
|
||||
changes.append(f"{filename}: {status}")
|
||||
return "\n".join(changes)
|
||||
status: str = file.get("status", "")
|
||||
diff: str = file.get("patch", "")
|
||||
if status != "removed":
|
||||
is_filename: str = file.get("filename", "")
|
||||
was_filename: str = (
|
||||
file.get("previous_filename", is_filename)
|
||||
if status != "added"
|
||||
else ""
|
||||
)
|
||||
else:
|
||||
is_filename = ""
|
||||
was_filename: str = file.get("filename", "")
|
||||
|
||||
patch_header = ""
|
||||
if was_filename:
|
||||
patch_header += f"--- {was_filename}\n"
|
||||
if is_filename:
|
||||
patch_header += f"+++ {is_filename}\n"
|
||||
changes.append(patch_header + diff)
|
||||
return "\n\n".join(changes)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
|
||||
@@ -3,11 +3,19 @@ import logging
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import aiofiles
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
HostScopedCredentials,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.file import (
|
||||
MediaFileType,
|
||||
get_exec_file_path,
|
||||
@@ -19,6 +27,30 @@ from backend.util.request import Requests
|
||||
logger = logging.getLogger(name=__name__)
|
||||
|
||||
|
||||
# Host-scoped credentials for HTTP requests
|
||||
HttpCredentials = CredentialsMetaInput[
|
||||
Literal[ProviderName.HTTP], Literal["host_scoped"]
|
||||
]
|
||||
|
||||
|
||||
TEST_CREDENTIALS = HostScopedCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="http",
|
||||
host="api.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer test-token"),
|
||||
"X-API-Key": SecretStr("test-api-key"),
|
||||
},
|
||||
title="Mock HTTP Host-Scoped Credentials",
|
||||
)
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
class HttpMethod(Enum):
|
||||
GET = "GET"
|
||||
POST = "POST"
|
||||
@@ -169,3 +201,62 @@ class SendWebRequestBlock(Block):
|
||||
yield "client_error", result
|
||||
else:
|
||||
yield "server_error", result
|
||||
|
||||
|
||||
class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
|
||||
class Input(SendWebRequestBlock.Input):
|
||||
credentials: HttpCredentials = CredentialsField(
|
||||
description="HTTP host-scoped credentials for automatic header injection",
|
||||
discriminator="url",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
Block.__init__(
|
||||
self,
|
||||
id="fff86bcd-e001-4bad-a7f6-2eae4720c8dc",
|
||||
description="Make an authenticated HTTP request with host-scoped credentials (JSON / form / multipart).",
|
||||
categories={BlockCategory.OUTPUT},
|
||||
input_schema=SendAuthenticatedWebRequestBlock.Input,
|
||||
output_schema=SendWebRequestBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run( # type: ignore[override]
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
graph_exec_id: str,
|
||||
credentials: HostScopedCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Create SendWebRequestBlock.Input from our input (removing credentials field)
|
||||
base_input = SendWebRequestBlock.Input(
|
||||
url=input_data.url,
|
||||
method=input_data.method,
|
||||
headers=input_data.headers,
|
||||
json_format=input_data.json_format,
|
||||
body=input_data.body,
|
||||
files_name=input_data.files_name,
|
||||
files=input_data.files,
|
||||
)
|
||||
|
||||
# Apply host-scoped credentials to headers
|
||||
extra_headers = {}
|
||||
if credentials.matches_url(input_data.url):
|
||||
logger.debug(
|
||||
f"Applying host-scoped credentials {credentials.id} for URL {input_data.url}"
|
||||
)
|
||||
extra_headers.update(credentials.get_headers_dict())
|
||||
else:
|
||||
logger.warning(
|
||||
f"Host-scoped credentials {credentials.id} do not match URL {input_data.url}"
|
||||
)
|
||||
|
||||
# Merge with user-provided headers (user headers take precedence)
|
||||
base_input.headers = {**extra_headers, **input_data.headers}
|
||||
|
||||
# Use parent class run method
|
||||
async for output_name, output_data in super().run(
|
||||
base_input, graph_exec_id=graph_exec_id, **kwargs
|
||||
):
|
||||
yield output_name, output_data
|
||||
|
||||
@@ -413,6 +413,12 @@ class AgentFileInputBlock(AgentInputBlock):
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
base_64: bool = SchemaField(
|
||||
description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).",
|
||||
default=False,
|
||||
advanced=True,
|
||||
title="Produce Base64 Output",
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: str = SchemaField(description="File reference/path result.")
|
||||
@@ -446,12 +452,11 @@ class AgentFileInputBlock(AgentInputBlock):
|
||||
if not input_data.value:
|
||||
return
|
||||
|
||||
file_path = await store_media_file(
|
||||
yield "result", await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.value,
|
||||
return_content=False,
|
||||
return_content=input_data.base_64,
|
||||
)
|
||||
yield "result", file_path
|
||||
|
||||
|
||||
class AgentDropdownInputBlock(AgentInputBlock):
|
||||
|
||||
@@ -23,6 +23,7 @@ from backend.data.model import (
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import json
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.prompt import compress_prompt, estimate_token_count
|
||||
from backend.util.text import TextFormatter
|
||||
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
|
||||
@@ -40,7 +41,7 @@ LLMProviderName = Literal[
|
||||
AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]]
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
|
||||
id="769f6af7-820b-4d5d-9b7a-ab82bbc165f",
|
||||
provider="openai",
|
||||
api_key=SecretStr("mock-openai-api-key"),
|
||||
title="Mock OpenAI API key",
|
||||
@@ -306,13 +307,6 @@ def convert_openai_tool_fmt_to_anthropic(
|
||||
return anthropic_tools
|
||||
|
||||
|
||||
def estimate_token_count(prompt_messages: list[dict]) -> int:
|
||||
char_count = sum(len(str(msg.get("content", ""))) for msg in prompt_messages)
|
||||
message_overhead = len(prompt_messages) * 4
|
||||
estimated_tokens = (char_count // 4) + message_overhead
|
||||
return int(estimated_tokens * 1.2)
|
||||
|
||||
|
||||
async def llm_call(
|
||||
credentials: APIKeyCredentials,
|
||||
llm_model: LlmModel,
|
||||
@@ -321,7 +315,8 @@ async def llm_call(
|
||||
max_tokens: int | None,
|
||||
tools: list[dict] | None = None,
|
||||
ollama_host: str = "localhost:11434",
|
||||
parallel_tool_calls: bool | None = None,
|
||||
parallel_tool_calls=None,
|
||||
compress_prompt_to_fit: bool = True,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Make a call to a language model.
|
||||
@@ -344,10 +339,17 @@ async def llm_call(
|
||||
- completion_tokens: The number of tokens used in the completion.
|
||||
"""
|
||||
provider = llm_model.metadata.provider
|
||||
context_window = llm_model.context_window
|
||||
|
||||
if compress_prompt_to_fit:
|
||||
prompt = compress_prompt(
|
||||
messages=prompt,
|
||||
target_tokens=llm_model.context_window // 2,
|
||||
lossy_ok=True,
|
||||
)
|
||||
|
||||
# Calculate available tokens based on context window and input length
|
||||
estimated_input_tokens = estimate_token_count(prompt)
|
||||
context_window = llm_model.context_window
|
||||
model_max_output = llm_model.max_output_tokens or int(2**15)
|
||||
user_max = max_tokens or model_max_output
|
||||
available_tokens = max(context_window - estimated_input_tokens, 0)
|
||||
@@ -358,14 +360,10 @@ async def llm_call(
|
||||
oai_client = openai.AsyncOpenAI(api_key=credentials.api_key.get_secret_value())
|
||||
response_format = None
|
||||
|
||||
if llm_model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
|
||||
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
||||
prompt = [
|
||||
{"role": "user", "content": "\n".join(sys_messages)},
|
||||
{"role": "user", "content": "\n".join(usr_messages)},
|
||||
]
|
||||
elif json_format:
|
||||
if llm_model.startswith("o") or parallel_tool_calls is None:
|
||||
parallel_tool_calls = openai.NOT_GIVEN
|
||||
|
||||
if json_format:
|
||||
response_format = {"type": "json_object"}
|
||||
|
||||
response = await oai_client.chat.completions.create(
|
||||
@@ -374,9 +372,7 @@ async def llm_call(
|
||||
response_format=response_format, # type: ignore
|
||||
max_completion_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
parallel_tool_calls=(
|
||||
openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
|
||||
),
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
)
|
||||
|
||||
if response.choices[0].message.tool_calls:
|
||||
@@ -699,7 +695,11 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
default=None,
|
||||
description="The maximum number of tokens to generate in the chat completion.",
|
||||
)
|
||||
|
||||
compress_prompt_to_fit: bool = SchemaField(
|
||||
advanced=True,
|
||||
default=True,
|
||||
description="Whether to compress the prompt to fit within the model's context window.",
|
||||
)
|
||||
ollama_host: str = SchemaField(
|
||||
advanced=True,
|
||||
default="localhost:11434",
|
||||
@@ -757,6 +757,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
llm_model: LlmModel,
|
||||
prompt: list[dict],
|
||||
json_format: bool,
|
||||
compress_prompt_to_fit: bool,
|
||||
max_tokens: int | None,
|
||||
tools: list[dict] | None = None,
|
||||
ollama_host: str = "localhost:11434",
|
||||
@@ -774,6 +775,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
ollama_host=ollama_host,
|
||||
compress_prompt_to_fit=compress_prompt_to_fit,
|
||||
)
|
||||
|
||||
async def run(
|
||||
@@ -832,7 +834,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
except JSONDecodeError as e:
|
||||
return f"JSON decode error: {e}"
|
||||
|
||||
logger.info(f"LLM request: {prompt}")
|
||||
logger.debug(f"LLM request: {prompt}")
|
||||
retry_prompt = ""
|
||||
llm_model = input_data.model
|
||||
|
||||
@@ -842,6 +844,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
credentials=credentials,
|
||||
llm_model=llm_model,
|
||||
prompt=prompt,
|
||||
compress_prompt_to_fit=input_data.compress_prompt_to_fit,
|
||||
json_format=bool(input_data.expected_format),
|
||||
ollama_host=input_data.ollama_host,
|
||||
max_tokens=input_data.max_tokens,
|
||||
@@ -853,7 +856,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
output_token_count=llm_response.completion_tokens,
|
||||
)
|
||||
)
|
||||
logger.info(f"LLM attempt-{retry_count} response: {response_text}")
|
||||
logger.debug(f"LLM attempt-{retry_count} response: {response_text}")
|
||||
|
||||
if input_data.expected_format:
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from backend.data.model import (
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
|
||||
id="8cc8b2c5-d3e4-4b1c-84ad-e1e9fe2a0122",
|
||||
provider="mem0",
|
||||
api_key=SecretStr("mock-mem0-api-key"),
|
||||
title="Mock Mem0 API key",
|
||||
|
||||
@@ -85,7 +85,7 @@ def _get_tool_responses(entry: dict[str, Any]) -> list[str]:
|
||||
return tool_call_ids
|
||||
|
||||
|
||||
def _create_tool_response(call_id: str, output: dict[str, Any]) -> dict[str, Any]:
|
||||
def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
|
||||
"""
|
||||
Create a tool response message for either OpenAI or Anthropics,
|
||||
based on the tool_id format.
|
||||
@@ -212,6 +212,15 @@ class SmartDecisionMakerBlock(Block):
|
||||
"link like the output of `StoreValue` or `AgentInput` block"
|
||||
)
|
||||
|
||||
# Check that both conversation_history and last_tool_output are connected together
|
||||
if any(link.sink_name == "conversation_history" for link in links) != any(
|
||||
link.sink_name == "last_tool_output" for link in links
|
||||
):
|
||||
raise ValueError(
|
||||
"Last Tool Output is needed when Conversation History is used, "
|
||||
"and vice versa. Please connect both inputs together."
|
||||
)
|
||||
|
||||
return missing_links
|
||||
|
||||
@classmethod
|
||||
@@ -222,8 +231,15 @@ class SmartDecisionMakerBlock(Block):
|
||||
conversation_history = data.get("conversation_history", [])
|
||||
pending_tool_calls = get_pending_tool_calls(conversation_history)
|
||||
last_tool_output = data.get("last_tool_output")
|
||||
if not last_tool_output and pending_tool_calls:
|
||||
|
||||
# Tool call is pending, wait for the tool output to be provided.
|
||||
if last_tool_output is None and pending_tool_calls:
|
||||
return {"last_tool_output"}
|
||||
|
||||
# No tool call is pending, wait for the conversation history to be updated.
|
||||
if last_tool_output is not None and not pending_tool_calls:
|
||||
return {"conversation_history"}
|
||||
|
||||
return set()
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -433,7 +449,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
prompt = [json.to_dict(p) for p in input_data.conversation_history if p]
|
||||
|
||||
pending_tool_calls = get_pending_tool_calls(input_data.conversation_history)
|
||||
if pending_tool_calls and not input_data.last_tool_output:
|
||||
if pending_tool_calls and input_data.last_tool_output is None:
|
||||
raise ValueError(f"Tool call requires an output for {pending_tool_calls}")
|
||||
|
||||
# Prefill all missing tool calls with the last tool output/
|
||||
@@ -497,7 +513,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
max_tokens=input_data.max_tokens,
|
||||
tools=tool_functions,
|
||||
ollama_host=input_data.ollama_host,
|
||||
parallel_tool_calls=True if input_data.multiple_tool_calls else None,
|
||||
parallel_tool_calls=input_data.multiple_tool_calls,
|
||||
)
|
||||
|
||||
if not response.tool_calls:
|
||||
|
||||
@@ -17,7 +17,7 @@ from backend.blocks.smartlead.models import (
|
||||
Sequence,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class CreateCampaignBlock(Block):
|
||||
@@ -27,7 +27,7 @@ class CreateCampaignBlock(Block):
|
||||
name: str = SchemaField(
|
||||
description="The name of the campaign",
|
||||
)
|
||||
credentials: SmartLeadCredentialsInput = SchemaField(
|
||||
credentials: SmartLeadCredentialsInput = CredentialsField(
|
||||
description="SmartLead credentials",
|
||||
)
|
||||
|
||||
@@ -119,7 +119,7 @@ class AddLeadToCampaignBlock(Block):
|
||||
description="Settings for lead upload",
|
||||
default=LeadUploadSettings(),
|
||||
)
|
||||
credentials: SmartLeadCredentialsInput = SchemaField(
|
||||
credentials: SmartLeadCredentialsInput = CredentialsField(
|
||||
description="SmartLead credentials",
|
||||
)
|
||||
|
||||
@@ -251,7 +251,7 @@ class SaveCampaignSequencesBlock(Block):
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
credentials: SmartLeadCredentialsInput = SchemaField(
|
||||
credentials: SmartLeadCredentialsInput = CredentialsField(
|
||||
description="SmartLead credentials",
|
||||
)
|
||||
|
||||
|
||||
485
autogpt_platform/backend/backend/blocks/test/test_http.py
Normal file
485
autogpt_platform/backend/backend/blocks/test/test_http.py
Normal file
@@ -0,0 +1,485 @@
|
||||
"""Comprehensive tests for HTTP block with HostScopedCredentials functionality."""
|
||||
|
||||
from typing import cast
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.http import (
|
||||
HttpCredentials,
|
||||
HttpMethod,
|
||||
SendAuthenticatedWebRequestBlock,
|
||||
)
|
||||
from backend.data.model import HostScopedCredentials
|
||||
from backend.util.request import Response
|
||||
|
||||
|
||||
class TestHttpBlockWithHostScopedCredentials:
|
||||
"""Test suite for HTTP block integration with HostScopedCredentials."""
|
||||
|
||||
@pytest.fixture
|
||||
def http_block(self):
|
||||
"""Create an HTTP block instance."""
|
||||
return SendAuthenticatedWebRequestBlock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response(self):
|
||||
"""Mock a successful HTTP response."""
|
||||
response = MagicMock(spec=Response)
|
||||
response.status = 200
|
||||
response.headers = {"content-type": "application/json"}
|
||||
response.json.return_value = {"success": True, "data": "test"}
|
||||
return response
|
||||
|
||||
@pytest.fixture
|
||||
def exact_match_credentials(self):
|
||||
"""Create host-scoped credentials for exact domain matching."""
|
||||
return HostScopedCredentials(
|
||||
provider="http",
|
||||
host="api.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer exact-match-token"),
|
||||
"X-API-Key": SecretStr("api-key-123"),
|
||||
},
|
||||
title="Exact Match API Credentials",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def wildcard_credentials(self):
|
||||
"""Create host-scoped credentials with wildcard pattern."""
|
||||
return HostScopedCredentials(
|
||||
provider="http",
|
||||
host="*.github.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("token ghp_wildcard123"),
|
||||
},
|
||||
title="GitHub Wildcard Credentials",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def non_matching_credentials(self):
|
||||
"""Create credentials that don't match test URLs."""
|
||||
return HostScopedCredentials(
|
||||
provider="http",
|
||||
host="different.api.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer non-matching-token"),
|
||||
},
|
||||
title="Non-matching Credentials",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_http_block_with_exact_host_match(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
exact_match_credentials,
|
||||
mock_response,
|
||||
):
|
||||
"""Test HTTP block with exact host matching credentials."""
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
# Prepare input data
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url="https://api.example.com/data",
|
||||
method=HttpMethod.GET,
|
||||
headers={"User-Agent": "test-agent"},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": exact_match_credentials.id,
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": exact_match_credentials.title,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with credentials provided by execution manager
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=exact_match_credentials,
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify request headers include both credential and user headers
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
expected_headers = {
|
||||
"Authorization": "Bearer exact-match-token",
|
||||
"X-API-Key": "api-key-123",
|
||||
"User-Agent": "test-agent",
|
||||
}
|
||||
assert call_args.kwargs["headers"] == expected_headers
|
||||
|
||||
# Verify response handling
|
||||
assert len(result) == 1
|
||||
assert result[0][0] == "response"
|
||||
assert result[0][1] == {"success": True, "data": "test"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_http_block_with_wildcard_host_match(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
wildcard_credentials,
|
||||
mock_response,
|
||||
):
|
||||
"""Test HTTP block with wildcard host pattern matching."""
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
# Test with subdomain that should match *.github.com
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url="https://api.github.com/user",
|
||||
method=HttpMethod.GET,
|
||||
headers={},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": wildcard_credentials.id,
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": wildcard_credentials.title,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with wildcard credentials
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=wildcard_credentials,
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify wildcard matching works
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
expected_headers = {"Authorization": "token ghp_wildcard123"}
|
||||
assert call_args.kwargs["headers"] == expected_headers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_http_block_with_non_matching_credentials(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
non_matching_credentials,
|
||||
mock_response,
|
||||
):
|
||||
"""Test HTTP block when credentials don't match the target URL."""
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
# Test with URL that doesn't match the credentials
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url="https://api.example.com/data",
|
||||
method=HttpMethod.GET,
|
||||
headers={"User-Agent": "test-agent"},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": non_matching_credentials.id,
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": non_matching_credentials.title,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with non-matching credentials
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=non_matching_credentials,
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify only user headers are included (no credential headers)
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
expected_headers = {"User-Agent": "test-agent"}
|
||||
assert call_args.kwargs["headers"] == expected_headers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_user_headers_override_credential_headers(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
exact_match_credentials,
|
||||
mock_response,
|
||||
):
|
||||
"""Test that user-provided headers take precedence over credential headers."""
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
# Test with user header that conflicts with credential header
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url="https://api.example.com/data",
|
||||
method=HttpMethod.POST,
|
||||
headers={
|
||||
"Authorization": "Bearer user-override-token", # Should override
|
||||
"Content-Type": "application/json", # Additional user header
|
||||
},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": exact_match_credentials.id,
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": exact_match_credentials.title,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with conflicting headers
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=exact_match_credentials,
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify user headers take precedence
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
expected_headers = {
|
||||
"X-API-Key": "api-key-123", # From credentials
|
||||
"Authorization": "Bearer user-override-token", # User override
|
||||
"Content-Type": "application/json", # User header
|
||||
}
|
||||
assert call_args.kwargs["headers"] == expected_headers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_auto_discovered_credentials_flow(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
mock_response,
|
||||
):
|
||||
"""Test the auto-discovery flow where execution manager provides matching credentials."""
|
||||
# Create auto-discovered credentials
|
||||
auto_discovered_creds = HostScopedCredentials(
|
||||
provider="http",
|
||||
host="*.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer auto-discovered-token"),
|
||||
},
|
||||
title="Auto-discovered Credentials",
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
# Test with empty credentials field (triggers auto-discovery)
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url="https://api.example.com/data",
|
||||
method=HttpMethod.GET,
|
||||
headers={},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": "", # Empty ID triggers auto-discovery in execution manager
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": "",
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with auto-discovered credentials provided by execution manager
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=auto_discovered_creds, # Execution manager found these
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify auto-discovered credentials were applied
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
expected_headers = {"Authorization": "Bearer auto-discovered-token"}
|
||||
assert call_args.kwargs["headers"] == expected_headers
|
||||
|
||||
# Verify response handling
|
||||
assert len(result) == 1
|
||||
assert result[0][0] == "response"
|
||||
assert result[0][1] == {"success": True, "data": "test"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_multiple_header_credentials(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
mock_response,
|
||||
):
|
||||
"""Test credentials with multiple headers are all applied."""
|
||||
# Create credentials with multiple headers
|
||||
multi_header_creds = HostScopedCredentials(
|
||||
provider="http",
|
||||
host="api.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer multi-token"),
|
||||
"X-API-Key": SecretStr("api-key-456"),
|
||||
"X-Client-ID": SecretStr("client-789"),
|
||||
"X-Custom-Header": SecretStr("custom-value"),
|
||||
},
|
||||
title="Multi-Header Credentials",
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
# Test with credentials containing multiple headers
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url="https://api.example.com/data",
|
||||
method=HttpMethod.GET,
|
||||
headers={"User-Agent": "test-agent"},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": multi_header_creds.id,
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": multi_header_creds.title,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with multi-header credentials
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=multi_header_creds,
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify all headers are included
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
expected_headers = {
|
||||
"Authorization": "Bearer multi-token",
|
||||
"X-API-Key": "api-key-456",
|
||||
"X-Client-ID": "client-789",
|
||||
"X-Custom-Header": "custom-value",
|
||||
"User-Agent": "test-agent",
|
||||
}
|
||||
assert call_args.kwargs["headers"] == expected_headers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_credentials_with_complex_url_patterns(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
mock_response,
|
||||
):
|
||||
"""Test credentials matching various URL patterns."""
|
||||
# Test cases for different URL patterns
|
||||
test_cases = [
|
||||
{
|
||||
"host_pattern": "api.example.com",
|
||||
"test_url": "https://api.example.com/v1/users",
|
||||
"should_match": True,
|
||||
},
|
||||
{
|
||||
"host_pattern": "*.example.com",
|
||||
"test_url": "https://api.example.com/v1/users",
|
||||
"should_match": True,
|
||||
},
|
||||
{
|
||||
"host_pattern": "*.example.com",
|
||||
"test_url": "https://subdomain.example.com/data",
|
||||
"should_match": True,
|
||||
},
|
||||
{
|
||||
"host_pattern": "api.example.com",
|
||||
"test_url": "https://api.different.com/data",
|
||||
"should_match": False,
|
||||
},
|
||||
]
|
||||
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
for case in test_cases:
|
||||
# Reset mock for each test case
|
||||
mock_requests.reset_mock()
|
||||
|
||||
# Create credentials for this test case
|
||||
test_creds = HostScopedCredentials(
|
||||
provider="http",
|
||||
host=case["host_pattern"],
|
||||
headers={
|
||||
"Authorization": SecretStr(f"Bearer {case['host_pattern']}-token"),
|
||||
},
|
||||
title=f"Credentials for {case['host_pattern']}",
|
||||
)
|
||||
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url=case["test_url"],
|
||||
method=HttpMethod.GET,
|
||||
headers={"User-Agent": "test-agent"},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": test_creds.id,
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": test_creds.title,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with test credentials
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=test_creds,
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify headers based on whether pattern should match
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
headers = call_args.kwargs["headers"]
|
||||
|
||||
if case["should_match"]:
|
||||
# Should include both user and credential headers
|
||||
expected_auth = f"Bearer {case['host_pattern']}-token"
|
||||
assert headers["Authorization"] == expected_auth
|
||||
assert headers["User-Agent"] == "test-agent"
|
||||
else:
|
||||
# Should only include user headers
|
||||
assert "Authorization" not in headers
|
||||
assert headers["User-Agent"] == "test-agent"
|
||||
@@ -25,12 +25,7 @@ async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Grap
|
||||
async def create_credentials(s: SpinTestServer, u: User):
|
||||
provider = ProviderName.OPENAI
|
||||
credentials = llm.TEST_CREDENTIALS
|
||||
try:
|
||||
await s.agent_server.test_create_credentials(u.id, provider, credentials)
|
||||
except Exception:
|
||||
# ValueErrors is raised trying to recreate the same credentials
|
||||
# so hidding the error
|
||||
pass
|
||||
return await s.agent_server.test_create_credentials(u.id, provider, credentials)
|
||||
|
||||
|
||||
async def execute_graph(
|
||||
@@ -60,19 +55,18 @@ async def execute_graph(
|
||||
return graph_exec_id
|
||||
|
||||
|
||||
@pytest.mark.skip()
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
|
||||
test_user = await create_test_user()
|
||||
test_tool_graph = await create_graph(server, create_test_graph(), test_user)
|
||||
await create_credentials(server, test_user)
|
||||
creds = await create_credentials(server, test_user)
|
||||
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=SmartDecisionMakerBlock().id,
|
||||
input_default={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": llm.TEST_CREDENTIALS_INPUT,
|
||||
"credentials": creds,
|
||||
},
|
||||
),
|
||||
graph.Node(
|
||||
@@ -110,80 +104,18 @@ async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
|
||||
@pytest.mark.skip()
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_graph_validation_with_tool_nodes_raises_error(server: SpinTestServer):
|
||||
|
||||
test_user = await create_test_user()
|
||||
test_tool_graph = await create_graph(server, create_test_graph(), test_user)
|
||||
await create_credentials(server, test_user)
|
||||
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=SmartDecisionMakerBlock().id,
|
||||
input_default={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": llm.TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
),
|
||||
graph.Node(
|
||||
block_id=AgentExecutorBlock().id,
|
||||
input_default={
|
||||
"graph_id": test_tool_graph.id,
|
||||
"graph_version": test_tool_graph.version,
|
||||
"input_schema": test_tool_graph.input_schema,
|
||||
"output_schema": test_tool_graph.output_schema,
|
||||
},
|
||||
),
|
||||
graph.Node(
|
||||
block_id=StoreValueBlock().id,
|
||||
),
|
||||
]
|
||||
|
||||
links = [
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[1].id,
|
||||
source_name="tools_^_sample_tool_input_1",
|
||||
sink_name="input_1",
|
||||
),
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[1].id,
|
||||
source_name="tools_^_sample_tool_input_2",
|
||||
sink_name="input_2",
|
||||
),
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[2].id,
|
||||
source_name="tools_^_store_value_input",
|
||||
sink_name="input",
|
||||
),
|
||||
]
|
||||
|
||||
test_graph = graph.Graph(
|
||||
name="TestGraph",
|
||||
description="Test graph",
|
||||
nodes=nodes,
|
||||
links=links,
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
|
||||
@pytest.mark.skip()
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
test_user = await create_test_user()
|
||||
test_tool_graph = await create_graph(server, create_test_graph(), test_user)
|
||||
await create_credentials(server, test_user)
|
||||
creds = await create_credentials(server, test_user)
|
||||
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=SmartDecisionMakerBlock().id,
|
||||
input_default={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": llm.TEST_CREDENTIALS_INPUT,
|
||||
"credentials": creds,
|
||||
},
|
||||
),
|
||||
graph.Node(
|
||||
@@ -15,7 +15,7 @@ from backend.blocks.zerobounce._auth import (
|
||||
ZeroBounceCredentialsInput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class Response(BaseModel):
|
||||
@@ -90,7 +90,7 @@ class ValidateEmailsBlock(Block):
|
||||
description="IP address to validate",
|
||||
default="",
|
||||
)
|
||||
credentials: ZeroBounceCredentialsInput = SchemaField(
|
||||
credentials: ZeroBounceCredentialsInput = CredentialsField(
|
||||
description="ZeroBounce credentials",
|
||||
)
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ from dotenv import load_dotenv
|
||||
|
||||
from backend.util.logging import configure_logging
|
||||
|
||||
os.environ["ENABLE_AUTH"] = "false"
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# NOTE: You can run tests like with the --log-cli-level=INFO to see the logs
|
||||
5
autogpt_platform/backend/backend/data/__init__.py
Normal file
5
autogpt_platform/backend/backend/data/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .graph import NodeModel
|
||||
from .integrations import Webhook # noqa: F401
|
||||
|
||||
# Resolve Webhook <- NodeModel forward reference
|
||||
NodeModel.model_rebuild()
|
||||
@@ -78,6 +78,7 @@ class BlockCategory(Enum):
|
||||
PRODUCTIVITY = "Block that helps with productivity"
|
||||
ISSUE_TRACKING = "Block that helps with issue tracking"
|
||||
MULTIMEDIA = "Block that interacts with multimedia content"
|
||||
MARKETING = "Block that helps with marketing"
|
||||
|
||||
def dict(self) -> dict[str, str]:
|
||||
return {"category": self.name, "description": self.value}
|
||||
@@ -485,6 +486,22 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
raise ValueError(f"Block produced an invalid output data: {error}")
|
||||
yield output_name, output_data
|
||||
|
||||
def is_triggered_by_event_type(
|
||||
self, trigger_config: dict[str, Any], event_type: str
|
||||
) -> bool:
|
||||
if not self.webhook_config:
|
||||
raise TypeError("This method can't be used on non-trigger blocks")
|
||||
if not self.webhook_config.event_filter_input:
|
||||
return True
|
||||
event_filter = trigger_config.get(self.webhook_config.event_filter_input)
|
||||
if not event_filter:
|
||||
raise ValueError("Event filter is not configured on trigger")
|
||||
return event_type in [
|
||||
self.webhook_config.event_format.format(event=k)
|
||||
for k in event_filter
|
||||
if event_filter[k] is True
|
||||
]
|
||||
|
||||
|
||||
# ======================= Block Helper Functions ======================= #
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from backend.blocks.ai_music_generator import AIMusicGeneratorBlock
|
||||
from backend.blocks.ai_shortform_video_block import AIShortformVideoCreatorBlock
|
||||
from backend.blocks.apollo.organization import SearchOrganizationsBlock
|
||||
from backend.blocks.apollo.people import SearchPeopleBlock
|
||||
from backend.blocks.apollo.person import GetPersonDetailBlock
|
||||
from backend.blocks.flux_kontext import AIImageEditorBlock, FluxKontextModelName
|
||||
from backend.blocks.ideogram import IdeogramModelBlock
|
||||
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
|
||||
@@ -362,7 +363,31 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
||||
],
|
||||
SearchPeopleBlock: [
|
||||
BlockCost(
|
||||
cost_amount=2,
|
||||
cost_amount=10,
|
||||
cost_filter={
|
||||
"enrich_info": False,
|
||||
"credentials": {
|
||||
"id": apollo_credentials.id,
|
||||
"provider": apollo_credentials.provider,
|
||||
"type": apollo_credentials.type,
|
||||
},
|
||||
},
|
||||
),
|
||||
BlockCost(
|
||||
cost_amount=20,
|
||||
cost_filter={
|
||||
"enrich_info": True,
|
||||
"credentials": {
|
||||
"id": apollo_credentials.id,
|
||||
"provider": apollo_credentials.provider,
|
||||
"type": apollo_credentials.type,
|
||||
},
|
||||
},
|
||||
),
|
||||
],
|
||||
GetPersonDetailBlock: [
|
||||
BlockCost(
|
||||
cost_amount=1,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": apollo_credentials.id,
|
||||
|
||||
@@ -7,7 +7,7 @@ from pydantic import BaseModel
|
||||
from redis.asyncio.client import PubSub as AsyncPubSub
|
||||
from redis.client import PubSub
|
||||
|
||||
from backend.data import redis
|
||||
from backend.data import redis_client as redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -32,7 +32,7 @@ from prisma.types import (
|
||||
AgentNodeExecutionUpdateInput,
|
||||
AgentNodeExecutionWhereInput,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, JsonValue
|
||||
from pydantic.fields import Field
|
||||
|
||||
from backend.server.v2.store.exceptions import DatabaseError
|
||||
@@ -48,14 +48,14 @@ from .block import (
|
||||
get_webhook_block_ids,
|
||||
)
|
||||
from .db import BaseDbModel
|
||||
from .event_bus import AsyncRedisEventBus, RedisEventBus
|
||||
from .includes import (
|
||||
EXECUTION_RESULT_INCLUDE,
|
||||
EXECUTION_RESULT_ORDER,
|
||||
GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
graph_execution_include,
|
||||
)
|
||||
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
|
||||
from .queue import AsyncRedisEventBus, RedisEventBus
|
||||
from .model import GraphExecutionStats, NodeExecutionStats
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -271,7 +271,7 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
graph_id=self.graph_id,
|
||||
graph_version=self.graph_version or 0,
|
||||
graph_exec_id=self.id,
|
||||
node_credentials_input_map={}, # FIXME
|
||||
nodes_input_masks={}, # FIXME: store credentials on AgentGraphExecution
|
||||
)
|
||||
|
||||
|
||||
@@ -588,12 +588,10 @@ async def update_graph_execution_start_time(
|
||||
|
||||
async def update_graph_execution_stats(
|
||||
graph_exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
status: ExecutionStatus | None = None,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
) -> GraphExecution | None:
|
||||
update_data: AgentGraphExecutionUpdateManyMutationInput = {
|
||||
"executionStatus": status
|
||||
}
|
||||
update_data: AgentGraphExecutionUpdateManyMutationInput = {}
|
||||
|
||||
if stats:
|
||||
stats_dict = stats.model_dump()
|
||||
@@ -601,6 +599,9 @@ async def update_graph_execution_stats(
|
||||
stats_dict["error"] = str(stats_dict["error"])
|
||||
update_data["stats"] = Json(stats_dict)
|
||||
|
||||
if status:
|
||||
update_data["executionStatus"] = status
|
||||
|
||||
updated_count = await AgentGraphExecution.prisma().update_many(
|
||||
where={
|
||||
"id": graph_exec_id,
|
||||
@@ -783,7 +784,7 @@ class GraphExecutionEntry(BaseModel):
|
||||
graph_exec_id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
node_credentials_input_map: Optional[dict[str, dict[str, CredentialsMetaInput]]]
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None
|
||||
|
||||
|
||||
class NodeExecutionEntry(BaseModel):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Any, Literal, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
|
||||
import prisma
|
||||
from prisma import Json
|
||||
@@ -14,7 +14,7 @@ from prisma.types import (
|
||||
AgentNodeLinkCreateInput,
|
||||
StoreListingVersionWhereInput,
|
||||
)
|
||||
from pydantic import create_model
|
||||
from pydantic import JsonValue, create_model
|
||||
from pydantic.fields import computed_field
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
@@ -27,12 +27,15 @@ from backend.data.model import (
|
||||
CredentialsMetaInput,
|
||||
is_credentials_field_name,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import type as type_utils
|
||||
|
||||
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
|
||||
from .db import BaseDbModel, transaction
|
||||
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
|
||||
from .integrations import Webhook
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .integrations import Webhook
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -81,10 +84,12 @@ class NodeModel(Node):
|
||||
graph_version: int
|
||||
|
||||
webhook_id: Optional[str] = None
|
||||
webhook: Optional[Webhook] = None
|
||||
webhook: Optional["Webhook"] = None
|
||||
|
||||
@staticmethod
|
||||
def from_db(node: AgentNode, for_export: bool = False) -> "NodeModel":
|
||||
from .integrations import Webhook
|
||||
|
||||
obj = NodeModel(
|
||||
id=node.id,
|
||||
block_id=node.agentBlockId,
|
||||
@@ -102,19 +107,7 @@ class NodeModel(Node):
|
||||
return obj
|
||||
|
||||
def is_triggered_by_event_type(self, event_type: str) -> bool:
|
||||
block = self.block
|
||||
if not block.webhook_config:
|
||||
raise TypeError("This method can't be used on non-webhook blocks")
|
||||
if not block.webhook_config.event_filter_input:
|
||||
return True
|
||||
event_filter = self.input_default.get(block.webhook_config.event_filter_input)
|
||||
if not event_filter:
|
||||
raise ValueError(f"Event filter is not configured on node #{self.id}")
|
||||
return event_type in [
|
||||
block.webhook_config.event_format.format(event=k)
|
||||
for k in event_filter
|
||||
if event_filter[k] is True
|
||||
]
|
||||
return self.block.is_triggered_by_event_type(self.input_default, event_type)
|
||||
|
||||
def stripped_for_export(self) -> "NodeModel":
|
||||
"""
|
||||
@@ -162,10 +155,6 @@ class NodeModel(Node):
|
||||
return result
|
||||
|
||||
|
||||
# Fix 2-way reference Node <-> Webhook
|
||||
Webhook.model_rebuild()
|
||||
|
||||
|
||||
class BaseGraph(BaseDbModel):
|
||||
version: int = 1
|
||||
is_active: bool = True
|
||||
@@ -255,6 +244,8 @@ class Graph(BaseGraph):
|
||||
for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
|
||||
if field.provider != other_field.provider:
|
||||
continue
|
||||
if ProviderName.HTTP in field.provider:
|
||||
continue
|
||||
|
||||
# If this happens, that means a block implementation probably needs
|
||||
# to be updated.
|
||||
@@ -276,6 +267,7 @@ class Graph(BaseGraph):
|
||||
required_scopes=set(field_info.required_scopes or []),
|
||||
discriminator=field_info.discriminator,
|
||||
discriminator_mapping=field_info.discriminator_mapping,
|
||||
discriminator_values=field_info.discriminator_values,
|
||||
),
|
||||
)
|
||||
for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
|
||||
@@ -294,37 +286,40 @@ class Graph(BaseGraph):
|
||||
Returns:
|
||||
dict[aggregated_field_key, tuple(
|
||||
CredentialsFieldInfo: A spec for one aggregated credentials field
|
||||
(now includes discriminator_values from matching nodes)
|
||||
set[(node_id, field_name)]: Node credentials fields that are
|
||||
compatible with this aggregated field spec
|
||||
)]
|
||||
"""
|
||||
return {
|
||||
"_".join(sorted(agg_field_info.provider))
|
||||
+ "_"
|
||||
+ "_".join(sorted(agg_field_info.supported_types))
|
||||
+ "_credentials": (agg_field_info, node_fields)
|
||||
for agg_field_info, node_fields in CredentialsFieldInfo.combine(
|
||||
*(
|
||||
(
|
||||
# Apply discrimination before aggregating credentials inputs
|
||||
(
|
||||
field_info.discriminate(
|
||||
node.input_default[field_info.discriminator]
|
||||
)
|
||||
if (
|
||||
field_info.discriminator
|
||||
and node.input_default.get(field_info.discriminator)
|
||||
)
|
||||
else field_info
|
||||
),
|
||||
(node.id, field_name),
|
||||
# First collect all credential field data with input defaults
|
||||
node_credential_data = []
|
||||
|
||||
for graph in [self] + self.sub_graphs:
|
||||
for node in graph.nodes:
|
||||
for (
|
||||
field_name,
|
||||
field_info,
|
||||
) in node.block.input_schema.get_credentials_fields_info().items():
|
||||
|
||||
discriminator = field_info.discriminator
|
||||
if not discriminator:
|
||||
node_credential_data.append((field_info, (node.id, field_name)))
|
||||
continue
|
||||
|
||||
discriminator_value = node.input_default.get(discriminator)
|
||||
if discriminator_value is None:
|
||||
node_credential_data.append((field_info, (node.id, field_name)))
|
||||
continue
|
||||
|
||||
discriminated_info = field_info.discriminate(discriminator_value)
|
||||
discriminated_info.discriminator_values.add(discriminator_value)
|
||||
|
||||
node_credential_data.append(
|
||||
(discriminated_info, (node.id, field_name))
|
||||
)
|
||||
for graph in [self] + self.sub_graphs
|
||||
for node in graph.nodes
|
||||
for field_name, field_info in node.block.input_schema.get_credentials_fields_info().items()
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
# Combine credential field info (this will merge discriminator_values automatically)
|
||||
return CredentialsFieldInfo.combine(*node_credential_data)
|
||||
|
||||
|
||||
class GraphModel(Graph):
|
||||
@@ -403,16 +398,26 @@ class GraphModel(Graph):
|
||||
continue
|
||||
node.input_default["user_id"] = user_id
|
||||
node.input_default.setdefault("inputs", {})
|
||||
if (graph_id := node.input_default.get("graph_id")) in graph_id_map:
|
||||
if (
|
||||
graph_id := node.input_default.get("graph_id")
|
||||
) and graph_id in graph_id_map:
|
||||
node.input_default["graph_id"] = graph_id_map[graph_id]
|
||||
|
||||
def validate_graph(self, for_run: bool = False):
|
||||
self._validate_graph(self, for_run)
|
||||
def validate_graph(
|
||||
self,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
):
|
||||
self._validate_graph(self, for_run, nodes_input_masks)
|
||||
for sub_graph in self.sub_graphs:
|
||||
self._validate_graph(sub_graph, for_run)
|
||||
self._validate_graph(sub_graph, for_run, nodes_input_masks)
|
||||
|
||||
@staticmethod
|
||||
def _validate_graph(graph: BaseGraph, for_run: bool = False):
|
||||
def _validate_graph(
|
||||
graph: BaseGraph,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
):
|
||||
def is_tool_pin(name: str) -> bool:
|
||||
return name.startswith("tools_^_")
|
||||
|
||||
@@ -439,20 +444,18 @@ class GraphModel(Graph):
|
||||
if (block := nodes_block.get(node.id)) is None:
|
||||
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
|
||||
|
||||
node_input_mask = (
|
||||
nodes_input_masks.get(node.id, {}) if nodes_input_masks else {}
|
||||
)
|
||||
provided_inputs = set(
|
||||
[sanitize(name) for name in node.input_default]
|
||||
+ [sanitize(link.sink_name) for link in input_links.get(node.id, [])]
|
||||
+ ([name for name in node_input_mask] if node_input_mask else [])
|
||||
)
|
||||
InputSchema = block.input_schema
|
||||
for name in (required_fields := InputSchema.get_required_fields()):
|
||||
if (
|
||||
name not in provided_inputs
|
||||
# Webhook payload is passed in by ExecutionManager
|
||||
and not (
|
||||
name == "payload"
|
||||
and block.block_type
|
||||
in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
|
||||
)
|
||||
# Checking availability of credentials is done by ExecutionManager
|
||||
and name not in InputSchema.get_credentials_fields()
|
||||
# Validate only I/O nodes, or validate everything when executing
|
||||
@@ -485,10 +488,18 @@ class GraphModel(Graph):
|
||||
|
||||
def has_value(node: Node, name: str):
|
||||
return (
|
||||
name in node.input_default
|
||||
and node.input_default[name] is not None
|
||||
and str(node.input_default[name]).strip() != ""
|
||||
) or (name in input_fields and input_fields[name].default is not None)
|
||||
(
|
||||
name in node.input_default
|
||||
and node.input_default[name] is not None
|
||||
and str(node.input_default[name]).strip() != ""
|
||||
)
|
||||
or (name in input_fields and input_fields[name].default is not None)
|
||||
or (
|
||||
name in node_input_mask
|
||||
and node_input_mask[name] is not None
|
||||
and str(node_input_mask[name]).strip() != ""
|
||||
)
|
||||
)
|
||||
|
||||
# Validate dependencies between fields
|
||||
for field_name in input_fields.keys():
|
||||
@@ -574,7 +585,7 @@ class GraphModel(Graph):
|
||||
graph: AgentGraph,
|
||||
for_export: bool = False,
|
||||
sub_graphs: list[AgentGraph] | None = None,
|
||||
):
|
||||
) -> "GraphModel":
|
||||
return GraphModel(
|
||||
id=graph.id,
|
||||
user_id=graph.userId if not for_export else "",
|
||||
@@ -603,6 +614,7 @@ class GraphModel(Graph):
|
||||
|
||||
|
||||
async def get_node(node_id: str) -> NodeModel:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
node = await AgentNode.prisma().find_unique_or_raise(
|
||||
where={"id": node_id},
|
||||
include=AGENT_NODE_INCLUDE,
|
||||
@@ -611,6 +623,7 @@ async def get_node(node_id: str) -> NodeModel:
|
||||
|
||||
|
||||
async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
node = await AgentNode.prisma().update(
|
||||
where={"id": node_id},
|
||||
data=(
|
||||
|
||||
@@ -60,7 +60,8 @@ def graph_execution_include(
|
||||
|
||||
|
||||
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
|
||||
"AgentNodes": {"include": AGENT_NODE_INCLUDE}
|
||||
"AgentNodes": {"include": AGENT_NODE_INCLUDE},
|
||||
"AgentPresets": {"include": {"InputPresets": True}},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,21 +1,25 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, Optional
|
||||
from typing import AsyncGenerator, Literal, Optional, overload
|
||||
|
||||
from prisma import Json
|
||||
from prisma.models import IntegrationWebhook
|
||||
from prisma.types import IntegrationWebhookCreateInput
|
||||
from prisma.types import (
|
||||
IntegrationWebhookCreateInput,
|
||||
IntegrationWebhookUpdateInput,
|
||||
IntegrationWebhookWhereInput,
|
||||
Serializable,
|
||||
)
|
||||
from pydantic import Field, computed_field
|
||||
|
||||
from backend.data.event_bus import AsyncRedisEventBus
|
||||
from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
|
||||
from backend.data.queue import AsyncRedisEventBus
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks.utils import webhook_ingress_url
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from .db import BaseDbModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .graph import NodeModel
|
||||
from .graph import NodeModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -32,8 +36,6 @@ class Webhook(BaseDbModel):
|
||||
|
||||
provider_webhook_id: str
|
||||
|
||||
attached_nodes: Optional[list["NodeModel"]] = None
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def url(self) -> str:
|
||||
@@ -41,8 +43,6 @@ class Webhook(BaseDbModel):
|
||||
|
||||
@staticmethod
|
||||
def from_db(webhook: IntegrationWebhook):
|
||||
from .graph import NodeModel
|
||||
|
||||
return Webhook(
|
||||
id=webhook.id,
|
||||
user_id=webhook.userId,
|
||||
@@ -54,11 +54,26 @@ class Webhook(BaseDbModel):
|
||||
config=dict(webhook.config),
|
||||
secret=webhook.secret,
|
||||
provider_webhook_id=webhook.providerWebhookId,
|
||||
attached_nodes=(
|
||||
[NodeModel.from_db(node) for node in webhook.AgentNodes]
|
||||
if webhook.AgentNodes is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class WebhookWithRelations(Webhook):
|
||||
triggered_nodes: list[NodeModel]
|
||||
triggered_presets: list[LibraryAgentPreset]
|
||||
|
||||
@staticmethod
|
||||
def from_db(webhook: IntegrationWebhook):
|
||||
if webhook.AgentNodes is None or webhook.AgentPresets is None:
|
||||
raise ValueError(
|
||||
"AgentNodes and AgentPresets must be included in "
|
||||
"IntegrationWebhook query with relations"
|
||||
)
|
||||
return WebhookWithRelations(
|
||||
**Webhook.from_db(webhook).model_dump(),
|
||||
triggered_nodes=[NodeModel.from_db(node) for node in webhook.AgentNodes],
|
||||
triggered_presets=[
|
||||
LibraryAgentPreset.from_db(preset) for preset in webhook.AgentPresets
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -83,7 +98,19 @@ async def create_webhook(webhook: Webhook) -> Webhook:
|
||||
return Webhook.from_db(created_webhook)
|
||||
|
||||
|
||||
async def get_webhook(webhook_id: str) -> Webhook:
|
||||
@overload
|
||||
async def get_webhook(
|
||||
webhook_id: str, *, include_relations: Literal[True]
|
||||
) -> WebhookWithRelations: ...
|
||||
@overload
|
||||
async def get_webhook(
|
||||
webhook_id: str, *, include_relations: Literal[False] = False
|
||||
) -> Webhook: ...
|
||||
|
||||
|
||||
async def get_webhook(
|
||||
webhook_id: str, *, include_relations: bool = False
|
||||
) -> Webhook | WebhookWithRelations:
|
||||
"""
|
||||
⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints.
|
||||
|
||||
@@ -92,73 +119,113 @@ async def get_webhook(webhook_id: str) -> Webhook:
|
||||
"""
|
||||
webhook = await IntegrationWebhook.prisma().find_unique(
|
||||
where={"id": webhook_id},
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE,
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE if include_relations else None,
|
||||
)
|
||||
if not webhook:
|
||||
raise NotFoundError(f"Webhook #{webhook_id} not found")
|
||||
return Webhook.from_db(webhook)
|
||||
return (WebhookWithRelations if include_relations else Webhook).from_db(webhook)
|
||||
|
||||
|
||||
async def get_all_webhooks_by_creds(credentials_id: str) -> list[Webhook]:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
@overload
|
||||
async def get_all_webhooks_by_creds(
|
||||
user_id: str, credentials_id: str, *, include_relations: Literal[True]
|
||||
) -> list[WebhookWithRelations]: ...
|
||||
@overload
|
||||
async def get_all_webhooks_by_creds(
|
||||
user_id: str, credentials_id: str, *, include_relations: Literal[False] = False
|
||||
) -> list[Webhook]: ...
|
||||
|
||||
|
||||
async def get_all_webhooks_by_creds(
|
||||
user_id: str, credentials_id: str, *, include_relations: bool = False
|
||||
) -> list[Webhook] | list[WebhookWithRelations]:
|
||||
if not credentials_id:
|
||||
raise ValueError("credentials_id must not be empty")
|
||||
webhooks = await IntegrationWebhook.prisma().find_many(
|
||||
where={"credentialsId": credentials_id},
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE,
|
||||
where={"userId": user_id, "credentialsId": credentials_id},
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE if include_relations else None,
|
||||
)
|
||||
return [Webhook.from_db(webhook) for webhook in webhooks]
|
||||
return [
|
||||
(WebhookWithRelations if include_relations else Webhook).from_db(webhook)
|
||||
for webhook in webhooks
|
||||
]
|
||||
|
||||
|
||||
async def find_webhook_by_credentials_and_props(
|
||||
credentials_id: str, webhook_type: str, resource: str, events: list[str]
|
||||
user_id: str,
|
||||
credentials_id: str,
|
||||
webhook_type: str,
|
||||
resource: str,
|
||||
events: list[str],
|
||||
) -> Webhook | None:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
webhook = await IntegrationWebhook.prisma().find_first(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"credentialsId": credentials_id,
|
||||
"webhookType": webhook_type,
|
||||
"resource": resource,
|
||||
"events": {"has_every": events},
|
||||
},
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE,
|
||||
)
|
||||
return Webhook.from_db(webhook) if webhook else None
|
||||
|
||||
|
||||
async def find_webhook_by_graph_and_props(
|
||||
graph_id: str, provider: str, webhook_type: str, events: list[str]
|
||||
user_id: str,
|
||||
provider: str,
|
||||
webhook_type: str,
|
||||
graph_id: Optional[str] = None,
|
||||
preset_id: Optional[str] = None,
|
||||
) -> Webhook | None:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
"""Either `graph_id` or `preset_id` must be provided."""
|
||||
where_clause: IntegrationWebhookWhereInput = {
|
||||
"userId": user_id,
|
||||
"provider": provider,
|
||||
"webhookType": webhook_type,
|
||||
}
|
||||
|
||||
if preset_id:
|
||||
where_clause["AgentPresets"] = {"some": {"id": preset_id}}
|
||||
elif graph_id:
|
||||
where_clause["AgentNodes"] = {"some": {"agentGraphId": graph_id}}
|
||||
else:
|
||||
raise ValueError("Either graph_id or preset_id must be provided")
|
||||
|
||||
webhook = await IntegrationWebhook.prisma().find_first(
|
||||
where={
|
||||
"provider": provider,
|
||||
"webhookType": webhook_type,
|
||||
"events": {"has_every": events},
|
||||
"AgentNodes": {"some": {"agentGraphId": graph_id}},
|
||||
},
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE,
|
||||
where=where_clause,
|
||||
)
|
||||
return Webhook.from_db(webhook) if webhook else None
|
||||
|
||||
|
||||
async def update_webhook_config(webhook_id: str, updated_config: dict) -> Webhook:
|
||||
async def update_webhook(
|
||||
webhook_id: str,
|
||||
config: Optional[dict[str, Serializable]] = None,
|
||||
events: Optional[list[str]] = None,
|
||||
) -> Webhook:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
data: IntegrationWebhookUpdateInput = {}
|
||||
if config is not None:
|
||||
data["config"] = Json(config)
|
||||
if events is not None:
|
||||
data["events"] = events
|
||||
if not data:
|
||||
raise ValueError("Empty update query")
|
||||
|
||||
_updated_webhook = await IntegrationWebhook.prisma().update(
|
||||
where={"id": webhook_id},
|
||||
data={"config": Json(updated_config)},
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE,
|
||||
data=data,
|
||||
)
|
||||
if _updated_webhook is None:
|
||||
raise ValueError(f"Webhook #{webhook_id} not found")
|
||||
raise NotFoundError(f"Webhook #{webhook_id} not found")
|
||||
return Webhook.from_db(_updated_webhook)
|
||||
|
||||
|
||||
async def delete_webhook(webhook_id: str) -> None:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
deleted = await IntegrationWebhook.prisma().delete(where={"id": webhook_id})
|
||||
if not deleted:
|
||||
raise ValueError(f"Webhook #{webhook_id} not found")
|
||||
async def delete_webhook(user_id: str, webhook_id: str) -> None:
|
||||
deleted = await IntegrationWebhook.prisma().delete_many(
|
||||
where={"id": webhook_id, "userId": user_id}
|
||||
)
|
||||
if deleted < 1:
|
||||
raise NotFoundError(f"Webhook #{webhook_id} not found")
|
||||
|
||||
|
||||
# --------------------- WEBHOOK EVENTS --------------------- #
|
||||
|
||||
@@ -14,11 +14,12 @@ from typing import (
|
||||
Generic,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
cast,
|
||||
get_args,
|
||||
)
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
from prisma.enums import CreditTransactionType
|
||||
@@ -240,13 +241,65 @@ class UserPasswordCredentials(_BaseCredentials):
|
||||
return f"Basic {base64.b64encode(f'{self.username.get_secret_value()}:{self.password.get_secret_value()}'.encode()).decode()}"
|
||||
|
||||
|
||||
class HostScopedCredentials(_BaseCredentials):
|
||||
type: Literal["host_scoped"] = "host_scoped"
|
||||
host: str = Field(description="The host/URI pattern to match against request URLs")
|
||||
headers: dict[str, SecretStr] = Field(
|
||||
description="Key-value header map to add to matching requests",
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
def _extract_headers(self, headers: dict[str, SecretStr]) -> dict[str, str]:
|
||||
"""Helper to extract secret values from headers."""
|
||||
return {key: value.get_secret_value() for key, value in headers.items()}
|
||||
|
||||
@field_serializer("headers")
|
||||
def serialize_headers(self, headers: dict[str, SecretStr]) -> dict[str, str]:
|
||||
"""Serialize headers by extracting secret values."""
|
||||
return self._extract_headers(headers)
|
||||
|
||||
def get_headers_dict(self) -> dict[str, str]:
|
||||
"""Get headers with secret values extracted."""
|
||||
return self._extract_headers(self.headers)
|
||||
|
||||
def auth_header(self) -> str:
|
||||
"""Get authorization header for backward compatibility."""
|
||||
auth_headers = self.get_headers_dict()
|
||||
if "Authorization" in auth_headers:
|
||||
return auth_headers["Authorization"]
|
||||
return ""
|
||||
|
||||
def matches_url(self, url: str) -> bool:
|
||||
"""Check if this credential should be applied to the given URL."""
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
# Extract hostname without port
|
||||
request_host = parsed_url.hostname
|
||||
if not request_host:
|
||||
return False
|
||||
|
||||
# Simple host matching - exact match or wildcard subdomain match
|
||||
if self.host == request_host:
|
||||
return True
|
||||
|
||||
# Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
|
||||
if self.host.startswith("*."):
|
||||
domain = self.host[2:] # Remove "*."
|
||||
return request_host.endswith(f".{domain}") or request_host == domain
|
||||
|
||||
return False
|
||||
|
||||
|
||||
Credentials = Annotated[
|
||||
OAuth2Credentials | APIKeyCredentials | UserPasswordCredentials,
|
||||
OAuth2Credentials
|
||||
| APIKeyCredentials
|
||||
| UserPasswordCredentials
|
||||
| HostScopedCredentials,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
CredentialsType = Literal["api_key", "oauth2", "user_password"]
|
||||
CredentialsType = Literal["api_key", "oauth2", "user_password", "host_scoped"]
|
||||
|
||||
|
||||
class OAuthState(BaseModel):
|
||||
@@ -320,15 +373,29 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _add_json_schema_extra(schema, cls: CredentialsMetaInput):
|
||||
schema["credentials_provider"] = cls.allowed_providers()
|
||||
schema["credentials_types"] = cls.allowed_cred_types()
|
||||
def _add_json_schema_extra(schema: dict, model_class: type):
|
||||
# Use model_class for allowed_providers/cred_types
|
||||
if hasattr(model_class, "allowed_providers") and hasattr(
|
||||
model_class, "allowed_cred_types"
|
||||
):
|
||||
schema["credentials_provider"] = model_class.allowed_providers()
|
||||
schema["credentials_types"] = model_class.allowed_cred_types()
|
||||
# Do not return anything, just mutate schema in place
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra=_add_json_schema_extra, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
def _extract_host_from_url(url: str) -> str:
|
||||
"""Extract host from URL for grouping host-scoped credentials."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
return parsed.hostname or url
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
# TODO: move discrimination mechanism out of CredentialsField (frontend + backend)
|
||||
provider: frozenset[CP] = Field(..., alias="credentials_provider")
|
||||
@@ -336,11 +403,12 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
required_scopes: Optional[frozenset[str]] = Field(None, alias="credentials_scopes")
|
||||
discriminator: Optional[str] = None
|
||||
discriminator_mapping: Optional[dict[str, CP]] = None
|
||||
discriminator_values: set[Any] = Field(default_factory=set)
|
||||
|
||||
@classmethod
|
||||
def combine(
|
||||
cls, *fields: tuple[CredentialsFieldInfo[CP, CT], T]
|
||||
) -> Sequence[tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
|
||||
) -> dict[str, tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
|
||||
"""
|
||||
Combines multiple CredentialsFieldInfo objects into as few as possible.
|
||||
|
||||
@@ -358,22 +426,36 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
the set of keys of the respective original items that were grouped together.
|
||||
"""
|
||||
if not fields:
|
||||
return []
|
||||
return {}
|
||||
|
||||
# Group fields by their provider and supported_types
|
||||
# For HTTP host-scoped credentials, also group by host
|
||||
grouped_fields: defaultdict[
|
||||
tuple[frozenset[CP], frozenset[CT]],
|
||||
list[tuple[T, CredentialsFieldInfo[CP, CT]]],
|
||||
] = defaultdict(list)
|
||||
|
||||
for field, key in fields:
|
||||
group_key = (frozenset(field.provider), frozenset(field.supported_types))
|
||||
if field.provider == frozenset([ProviderName.HTTP]):
|
||||
# HTTP host-scoped credentials can have different hosts that reqires different credential sets.
|
||||
# Group by host extracted from the URL
|
||||
providers = frozenset(
|
||||
[cast(CP, "http")]
|
||||
+ [
|
||||
cast(CP, _extract_host_from_url(str(value)))
|
||||
for value in field.discriminator_values
|
||||
]
|
||||
)
|
||||
else:
|
||||
providers = frozenset(field.provider)
|
||||
|
||||
group_key = (providers, frozenset(field.supported_types))
|
||||
grouped_fields[group_key].append((key, field))
|
||||
|
||||
# Combine fields within each group
|
||||
result: list[tuple[CredentialsFieldInfo[CP, CT], set[T]]] = []
|
||||
result: dict[str, tuple[CredentialsFieldInfo[CP, CT], set[T]]] = {}
|
||||
|
||||
for group in grouped_fields.values():
|
||||
for key, group in grouped_fields.items():
|
||||
# Start with the first field in the group
|
||||
_, combined = group[0]
|
||||
|
||||
@@ -386,18 +468,32 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
if field.required_scopes:
|
||||
all_scopes.update(field.required_scopes)
|
||||
|
||||
# Create a new combined field
|
||||
result.append(
|
||||
(
|
||||
CredentialsFieldInfo[CP, CT](
|
||||
credentials_provider=combined.provider,
|
||||
credentials_types=combined.supported_types,
|
||||
credentials_scopes=frozenset(all_scopes) or None,
|
||||
discriminator=combined.discriminator,
|
||||
discriminator_mapping=combined.discriminator_mapping,
|
||||
),
|
||||
combined_keys,
|
||||
)
|
||||
# Combine discriminator_values from all fields in the group (removing duplicates)
|
||||
all_discriminator_values = []
|
||||
for _, field in group:
|
||||
for value in field.discriminator_values:
|
||||
if value not in all_discriminator_values:
|
||||
all_discriminator_values.append(value)
|
||||
|
||||
# Generate the key for the combined result
|
||||
providers_key, supported_types_key = key
|
||||
group_key = (
|
||||
"-".join(sorted(providers_key))
|
||||
+ "_"
|
||||
+ "-".join(sorted(supported_types_key))
|
||||
+ "_credentials"
|
||||
)
|
||||
|
||||
result[group_key] = (
|
||||
CredentialsFieldInfo[CP, CT](
|
||||
credentials_provider=combined.provider,
|
||||
credentials_types=combined.supported_types,
|
||||
credentials_scopes=frozenset(all_scopes) or None,
|
||||
discriminator=combined.discriminator,
|
||||
discriminator_mapping=combined.discriminator_mapping,
|
||||
discriminator_values=set(all_discriminator_values),
|
||||
),
|
||||
combined_keys,
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -406,11 +502,15 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
if not (self.discriminator and self.discriminator_mapping):
|
||||
return self
|
||||
|
||||
discriminator_value = self.discriminator_mapping[discriminator_value]
|
||||
return CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([discriminator_value]),
|
||||
credentials_provider=frozenset(
|
||||
[self.discriminator_mapping[discriminator_value]]
|
||||
),
|
||||
credentials_types=self.supported_types,
|
||||
credentials_scopes=self.required_scopes,
|
||||
discriminator=self.discriminator,
|
||||
discriminator_mapping=self.discriminator_mapping,
|
||||
discriminator_values=self.discriminator_values,
|
||||
)
|
||||
|
||||
|
||||
@@ -419,6 +519,7 @@ def CredentialsField(
|
||||
*,
|
||||
discriminator: Optional[str] = None,
|
||||
discriminator_mapping: Optional[dict[str, Any]] = None,
|
||||
discriminator_values: Optional[set[Any]] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
**kwargs,
|
||||
@@ -434,6 +535,7 @@ def CredentialsField(
|
||||
"credentials_scopes": list(required_scopes) or None,
|
||||
"discriminator": discriminator,
|
||||
"discriminator_mapping": discriminator_mapping,
|
||||
"discriminator_values": discriminator_values,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
143
autogpt_platform/backend/backend/data/model_test.py
Normal file
143
autogpt_platform/backend/backend/data/model_test.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import HostScopedCredentials
|
||||
|
||||
|
||||
class TestHostScopedCredentials:
|
||||
def test_host_scoped_credentials_creation(self):
|
||||
"""Test creating HostScopedCredentials with required fields."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="api.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer secret-token"),
|
||||
"X-API-Key": SecretStr("api-key-123"),
|
||||
},
|
||||
title="Example API Credentials",
|
||||
)
|
||||
|
||||
assert creds.type == "host_scoped"
|
||||
assert creds.provider == "custom"
|
||||
assert creds.host == "api.example.com"
|
||||
assert creds.title == "Example API Credentials"
|
||||
assert len(creds.headers) == 2
|
||||
assert "Authorization" in creds.headers
|
||||
assert "X-API-Key" in creds.headers
|
||||
|
||||
def test_get_headers_dict(self):
|
||||
"""Test getting headers with secret values extracted."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="api.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer secret-token"),
|
||||
"X-Custom-Header": SecretStr("custom-value"),
|
||||
},
|
||||
)
|
||||
|
||||
headers_dict = creds.get_headers_dict()
|
||||
|
||||
assert headers_dict == {
|
||||
"Authorization": "Bearer secret-token",
|
||||
"X-Custom-Header": "custom-value",
|
||||
}
|
||||
|
||||
def test_matches_url_exact_host(self):
|
||||
"""Test URL matching with exact host match."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="api.example.com",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
)
|
||||
|
||||
assert creds.matches_url("https://api.example.com/v1/data")
|
||||
assert creds.matches_url("http://api.example.com/endpoint")
|
||||
assert not creds.matches_url("https://other.example.com/v1/data")
|
||||
assert not creds.matches_url("https://subdomain.api.example.com/v1/data")
|
||||
|
||||
def test_matches_url_wildcard_subdomain(self):
|
||||
"""Test URL matching with wildcard subdomain pattern."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="*.example.com",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
)
|
||||
|
||||
assert creds.matches_url("https://api.example.com/v1/data")
|
||||
assert creds.matches_url("https://subdomain.example.com/endpoint")
|
||||
assert creds.matches_url("https://deep.nested.example.com/path")
|
||||
assert creds.matches_url("https://example.com/path") # Base domain should match
|
||||
assert not creds.matches_url("https://example.org/v1/data")
|
||||
assert not creds.matches_url("https://notexample.com/v1/data")
|
||||
|
||||
def test_matches_url_with_port_and_path(self):
|
||||
"""Test URL matching with ports and paths."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="localhost",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
)
|
||||
|
||||
assert creds.matches_url("http://localhost:8080/api/v1")
|
||||
assert creds.matches_url("https://localhost:443/secure/endpoint")
|
||||
assert creds.matches_url("http://localhost/simple")
|
||||
|
||||
def test_empty_headers_dict(self):
|
||||
"""Test HostScopedCredentials with empty headers."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom", host="api.example.com", headers={}
|
||||
)
|
||||
|
||||
assert creds.get_headers_dict() == {}
|
||||
assert creds.matches_url("https://api.example.com/test")
|
||||
|
||||
def test_credential_serialization(self):
|
||||
"""Test that credentials can be serialized/deserialized properly."""
|
||||
original_creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="api.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer secret-token"),
|
||||
"X-API-Key": SecretStr("api-key-123"),
|
||||
},
|
||||
title="Test Credentials",
|
||||
)
|
||||
|
||||
# Serialize to dict (simulating storage)
|
||||
serialized = original_creds.model_dump()
|
||||
|
||||
# Deserialize back
|
||||
restored_creds = HostScopedCredentials.model_validate(serialized)
|
||||
|
||||
assert restored_creds.id == original_creds.id
|
||||
assert restored_creds.provider == original_creds.provider
|
||||
assert restored_creds.host == original_creds.host
|
||||
assert restored_creds.title == original_creds.title
|
||||
assert restored_creds.type == "host_scoped"
|
||||
|
||||
# Check that headers are properly restored
|
||||
assert restored_creds.get_headers_dict() == original_creds.get_headers_dict()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"host,test_url,expected",
|
||||
[
|
||||
("api.example.com", "https://api.example.com/test", True),
|
||||
("api.example.com", "https://different.example.com/test", False),
|
||||
("*.example.com", "https://api.example.com/test", True),
|
||||
("*.example.com", "https://sub.api.example.com/test", True),
|
||||
("*.example.com", "https://example.com/test", True),
|
||||
("*.example.com", "https://example.org/test", False),
|
||||
("localhost", "http://localhost:3000/test", True),
|
||||
("localhost", "http://127.0.0.1:3000/test", False),
|
||||
],
|
||||
)
|
||||
def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool):
|
||||
"""Parametrized test for various URL matching scenarios."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="test",
|
||||
host=host,
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
)
|
||||
|
||||
assert creds.matches_url(test_url) == expected
|
||||
@@ -12,14 +12,11 @@ from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
|
||||
|
||||
from pika.adapters.blocking_connection import BlockingChannel
|
||||
from pika.spec import Basic, BasicProperties
|
||||
from pydantic import JsonValue
|
||||
from redis.asyncio.lock import Lock as RedisLock
|
||||
|
||||
from backend.blocks.io import AgentOutputBlock
|
||||
from backend.data.model import (
|
||||
CredentialsMetaInput,
|
||||
GraphExecutionStats,
|
||||
NodeExecutionStats,
|
||||
)
|
||||
from backend.data.model import GraphExecutionStats, NodeExecutionStats
|
||||
from backend.data.notifications import (
|
||||
AgentRunData,
|
||||
LowBalanceData,
|
||||
@@ -38,7 +35,7 @@ from autogpt_libs.utils.cache import thread_cached
|
||||
from prometheus_client import Gauge, start_http_server
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.data import redis
|
||||
from backend.data import redis_client as redis
|
||||
from backend.data.block import (
|
||||
BlockData,
|
||||
BlockInput,
|
||||
@@ -138,9 +135,7 @@ async def execute_node(
|
||||
creds_manager: IntegrationCredentialsManager,
|
||||
data: NodeExecutionEntry,
|
||||
execution_stats: NodeExecutionStats | None = None,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> BlockOutput:
|
||||
"""
|
||||
Execute a node in the graph. This will trigger a block execution on a node,
|
||||
@@ -183,8 +178,8 @@ async def execute_node(
|
||||
if isinstance(node_block, AgentExecutorBlock):
|
||||
_input_data = AgentExecutorBlock.Input(**node.input_default)
|
||||
_input_data.inputs = input_data
|
||||
if node_credentials_input_map:
|
||||
_input_data.node_credentials_input_map = node_credentials_input_map
|
||||
if nodes_input_masks:
|
||||
_input_data.nodes_input_masks = nodes_input_masks
|
||||
input_data = _input_data.model_dump()
|
||||
data.inputs = input_data
|
||||
|
||||
@@ -255,7 +250,7 @@ async def _enqueue_next_nodes(
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
log_metadata: LogMetadata,
|
||||
node_credentials_input_map: Optional[dict[str, dict[str, CredentialsMetaInput]]],
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
|
||||
) -> list[NodeExecutionEntry]:
|
||||
async def add_enqueued_execution(
|
||||
node_exec_id: str, node_id: str, block_id: str, data: BlockInput
|
||||
@@ -326,14 +321,12 @@ async def _enqueue_next_nodes(
|
||||
for name in static_link_names:
|
||||
next_node_input[name] = latest_execution.input_data.get(name)
|
||||
|
||||
# Apply node credentials overrides
|
||||
node_credentials = None
|
||||
if node_credentials_input_map and (
|
||||
node_credentials := node_credentials_input_map.get(next_node.id)
|
||||
# Apply node input overrides
|
||||
node_input_mask = None
|
||||
if nodes_input_masks and (
|
||||
node_input_mask := nodes_input_masks.get(next_node.id)
|
||||
):
|
||||
next_node_input.update(
|
||||
{k: v.model_dump() for k, v in node_credentials.items()}
|
||||
)
|
||||
next_node_input.update(node_input_mask)
|
||||
|
||||
# Validate the input data for the next node.
|
||||
next_node_input, validation_msg = validate_exec(next_node, next_node_input)
|
||||
@@ -377,11 +370,9 @@ async def _enqueue_next_nodes(
|
||||
for input_name in static_link_names:
|
||||
idata[input_name] = next_node_input[input_name]
|
||||
|
||||
# Apply node credentials overrides
|
||||
if node_credentials:
|
||||
idata.update(
|
||||
{k: v.model_dump() for k, v in node_credentials.items()}
|
||||
)
|
||||
# Apply node input overrides
|
||||
if node_input_mask:
|
||||
idata.update(node_input_mask)
|
||||
|
||||
idata, msg = validate_exec(next_node, idata)
|
||||
suffix = f"{next_output_name}>{next_input_name}~{ineid}:{msg}"
|
||||
@@ -430,14 +421,12 @@ class Executor:
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@async_error_logged
|
||||
@async_error_logged(swallow=True)
|
||||
async def on_node_execution(
|
||||
cls,
|
||||
node_exec: NodeExecutionEntry,
|
||||
node_exec_progress: NodeExecutionProgress,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> NodeExecutionStats:
|
||||
log_metadata = LogMetadata(
|
||||
user_id=node_exec.user_id,
|
||||
@@ -458,7 +447,7 @@ class Executor:
|
||||
db_client=db_client,
|
||||
log_metadata=log_metadata,
|
||||
stats=execution_stats,
|
||||
node_credentials_input_map=node_credentials_input_map,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
)
|
||||
execution_stats.walltime = timing_info.wall_time
|
||||
execution_stats.cputime = timing_info.cpu_time
|
||||
@@ -481,9 +470,7 @@ class Executor:
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
log_metadata: LogMetadata,
|
||||
stats: NodeExecutionStats | None = None,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
):
|
||||
try:
|
||||
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
|
||||
@@ -498,7 +485,7 @@ class Executor:
|
||||
creds_manager=cls.creds_manager,
|
||||
data=node_exec,
|
||||
execution_stats=stats,
|
||||
node_credentials_input_map=node_credentials_input_map,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
):
|
||||
node_exec_progress.add_output(
|
||||
ExecutionOutputEntry(
|
||||
@@ -542,7 +529,7 @@ class Executor:
|
||||
logger.info(f"[GraphExecutor] {cls.pid} started")
|
||||
|
||||
@classmethod
|
||||
@error_logged
|
||||
@error_logged(swallow=False)
|
||||
def on_graph_execution(
|
||||
cls, graph_exec: GraphExecutionEntry, cancel: threading.Event
|
||||
):
|
||||
@@ -594,6 +581,15 @@ class Executor:
|
||||
exec_stats.cputime += timing_info.cpu_time
|
||||
exec_stats.error = str(error) if error else exec_stats.error
|
||||
|
||||
if status not in {
|
||||
ExecutionStatus.COMPLETED,
|
||||
ExecutionStatus.TERMINATED,
|
||||
ExecutionStatus.FAILED,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Graph Execution #{graph_exec.graph_exec_id} ended with unexpected status {status}"
|
||||
)
|
||||
|
||||
if graph_exec_result := db_client.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
status=status,
|
||||
@@ -697,7 +693,6 @@ class Executor:
|
||||
|
||||
if _graph_exec := db_client.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
status=execution_status,
|
||||
stats=execution_stats,
|
||||
):
|
||||
send_execution_update(_graph_exec)
|
||||
@@ -779,24 +774,19 @@ class Executor:
|
||||
)
|
||||
raise
|
||||
|
||||
# Add credential overrides -----------------------------
|
||||
# Add input overrides -----------------------------
|
||||
node_id = queued_node_exec.node_id
|
||||
if (node_creds_map := graph_exec.node_credentials_input_map) and (
|
||||
node_field_creds_map := node_creds_map.get(node_id)
|
||||
if (nodes_input_masks := graph_exec.nodes_input_masks) and (
|
||||
node_input_mask := nodes_input_masks.get(node_id)
|
||||
):
|
||||
queued_node_exec.inputs.update(
|
||||
{
|
||||
field_name: creds_meta.model_dump()
|
||||
for field_name, creds_meta in node_field_creds_map.items()
|
||||
}
|
||||
)
|
||||
queued_node_exec.inputs.update(node_input_mask)
|
||||
|
||||
# Kick off async node execution -------------------------
|
||||
node_execution_task = asyncio.run_coroutine_threadsafe(
|
||||
cls.on_node_execution(
|
||||
node_exec=queued_node_exec,
|
||||
node_exec_progress=running_node_execution[node_id],
|
||||
node_credentials_input_map=node_creds_map,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
),
|
||||
cls.node_execution_loop,
|
||||
)
|
||||
@@ -840,7 +830,7 @@ class Executor:
|
||||
node_id=node_id,
|
||||
graph_exec=graph_exec,
|
||||
log_metadata=log_metadata,
|
||||
node_creds_map=node_creds_map,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
execution_queue=execution_queue,
|
||||
),
|
||||
cls.node_evaluation_loop,
|
||||
@@ -910,7 +900,7 @@ class Executor:
|
||||
node_id: str,
|
||||
graph_exec: GraphExecutionEntry,
|
||||
log_metadata: LogMetadata,
|
||||
node_creds_map: Optional[dict[str, dict[str, CredentialsMetaInput]]],
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
|
||||
execution_queue: ExecutionQueue[NodeExecutionEntry],
|
||||
) -> None:
|
||||
"""Process a node's output, update its status, and enqueue next nodes.
|
||||
@@ -920,7 +910,7 @@ class Executor:
|
||||
node_id: The ID of the node that produced the output
|
||||
graph_exec: The graph execution entry
|
||||
log_metadata: Logger metadata for consistent logging
|
||||
node_creds_map: Optional map of node credentials
|
||||
nodes_input_masks: Optional map of node input overrides
|
||||
execution_queue: Queue to add next executions to
|
||||
"""
|
||||
db_client = get_db_async_client()
|
||||
@@ -944,7 +934,7 @@ class Executor:
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
graph_id=graph_exec.graph_id,
|
||||
log_metadata=log_metadata,
|
||||
node_credentials_input_map=node_creds_map,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
):
|
||||
execution_queue.add(next_execution)
|
||||
except Exception as e:
|
||||
|
||||
@@ -366,14 +366,13 @@ async def test_execute_preset(server: SpinTestServer):
|
||||
"dictionary": {"key1": "Hello", "key2": "World"},
|
||||
"selected_value": "key2",
|
||||
},
|
||||
credentials={},
|
||||
is_active=True,
|
||||
)
|
||||
created_preset = await server.agent_server.test_create_preset(preset, test_user.id)
|
||||
|
||||
# Execute preset with overriding values
|
||||
result = await server.agent_server.test_execute_preset(
|
||||
graph_id=test_graph.id,
|
||||
graph_version=test_graph.version,
|
||||
preset_id=created_preset.id,
|
||||
user_id=test_user.id,
|
||||
)
|
||||
@@ -455,16 +454,15 @@ async def test_execute_preset_with_clash(server: SpinTestServer):
|
||||
"dictionary": {"key1": "Hello", "key2": "World"},
|
||||
"selected_value": "key2",
|
||||
},
|
||||
credentials={},
|
||||
is_active=True,
|
||||
)
|
||||
created_preset = await server.agent_server.test_create_preset(preset, test_user.id)
|
||||
|
||||
# Execute preset with overriding values
|
||||
result = await server.agent_server.test_execute_preset(
|
||||
graph_id=test_graph.id,
|
||||
graph_version=test_graph.version,
|
||||
preset_id=created_preset.id,
|
||||
node_input={"selected_value": "key1"},
|
||||
inputs={"selected_value": "key1"},
|
||||
user_id=test_user.id,
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
|
||||
from apscheduler.events import EVENT_JOB_ERROR, EVENT_JOB_EXECUTED
|
||||
@@ -14,13 +15,16 @@ from apscheduler.triggers.cron import CronTrigger
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from dotenv import load_dotenv
|
||||
from prisma.enums import NotificationType
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import MetaData, create_engine
|
||||
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.notifications.notifications import NotificationManagerClient
|
||||
from backend.util.exceptions import NotAuthorizedError, NotFoundError
|
||||
from backend.util.logging import PrefixFilter
|
||||
from backend.util.metrics import sentry_capture_error
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
@@ -52,19 +56,19 @@ def _extract_schema_from_url(database_url) -> tuple[str, str]:
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.addFilter(PrefixFilter("[Scheduler]"))
|
||||
apscheduler_logger = logger.getChild("apscheduler")
|
||||
apscheduler_logger.addFilter(PrefixFilter("[Scheduler] [APScheduler]"))
|
||||
|
||||
config = Config()
|
||||
|
||||
|
||||
def log(msg, **kwargs):
|
||||
logger.info("[Scheduler] " + msg, **kwargs)
|
||||
|
||||
|
||||
def job_listener(event):
|
||||
"""Logs job execution outcomes for better monitoring."""
|
||||
if event.exception:
|
||||
log(f"Job {event.job_id} failed.")
|
||||
logger.error(f"Job {event.job_id} failed.")
|
||||
else:
|
||||
log(f"Job {event.job_id} completed successfully.")
|
||||
logger.info(f"Job {event.job_id} completed successfully.")
|
||||
|
||||
|
||||
@thread_cached
|
||||
@@ -84,16 +88,17 @@ def execute_graph(**kwargs):
|
||||
async def _execute_graph(**kwargs):
|
||||
args = GraphExecutionJobArgs(**kwargs)
|
||||
try:
|
||||
log(f"Executing recurring job for graph #{args.graph_id}")
|
||||
logger.info(f"Executing recurring job for graph #{args.graph_id}")
|
||||
await execution_utils.add_graph_execution(
|
||||
graph_id=args.graph_id,
|
||||
inputs=args.input_data,
|
||||
user_id=args.user_id,
|
||||
graph_id=args.graph_id,
|
||||
graph_version=args.graph_version,
|
||||
inputs=args.input_data,
|
||||
graph_credentials_inputs=args.input_credentials,
|
||||
use_db_query=False,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error executing graph {args.graph_id}: {e}")
|
||||
logger.error(f"Error executing graph {args.graph_id}: {e}")
|
||||
|
||||
|
||||
class LateExecutionException(Exception):
|
||||
@@ -137,20 +142,20 @@ def report_late_executions() -> str:
|
||||
def process_existing_batches(**kwargs):
|
||||
args = NotificationJobArgs(**kwargs)
|
||||
try:
|
||||
log(
|
||||
logger.info(
|
||||
f"Processing existing batches for notification type {args.notification_types}"
|
||||
)
|
||||
get_notification_client().process_existing_batches(args.notification_types)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing existing batches: {e}")
|
||||
logger.error(f"Error processing existing batches: {e}")
|
||||
|
||||
|
||||
def process_weekly_summary(**kwargs):
|
||||
try:
|
||||
log("Processing weekly summary")
|
||||
logger.info("Processing weekly summary")
|
||||
get_notification_client().queue_weekly_summary()
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing weekly summary: {e}")
|
||||
logger.error(f"Error processing weekly summary: {e}")
|
||||
|
||||
|
||||
class Jobstores(Enum):
|
||||
@@ -160,11 +165,12 @@ class Jobstores(Enum):
|
||||
|
||||
|
||||
class GraphExecutionJobArgs(BaseModel):
|
||||
graph_id: str
|
||||
input_data: BlockInput
|
||||
user_id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
cron: str
|
||||
input_data: BlockInput
|
||||
input_credentials: dict[str, CredentialsMetaInput] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class GraphExecutionJobInfo(GraphExecutionJobArgs):
|
||||
@@ -247,7 +253,8 @@ class Scheduler(AppService):
|
||||
),
|
||||
# These don't really need persistence
|
||||
Jobstores.WEEKLY_NOTIFICATIONS.value: MemoryJobStore(),
|
||||
}
|
||||
},
|
||||
logger=apscheduler_logger,
|
||||
)
|
||||
|
||||
if self.register_system_tasks:
|
||||
@@ -285,34 +292,40 @@ class Scheduler(AppService):
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
logger.info(f"[{self.service_name}] ⏳ Shutting down scheduler...")
|
||||
logger.info("⏳ Shutting down scheduler...")
|
||||
if self.scheduler:
|
||||
self.scheduler.shutdown(wait=False)
|
||||
|
||||
@expose
|
||||
def add_graph_execution_schedule(
|
||||
self,
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
cron: str,
|
||||
input_data: BlockInput,
|
||||
user_id: str,
|
||||
input_credentials: dict[str, CredentialsMetaInput],
|
||||
name: Optional[str] = None,
|
||||
) -> GraphExecutionJobInfo:
|
||||
job_args = GraphExecutionJobArgs(
|
||||
graph_id=graph_id,
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
cron=cron,
|
||||
input_data=input_data,
|
||||
input_credentials=input_credentials,
|
||||
)
|
||||
job = self.scheduler.add_job(
|
||||
execute_graph,
|
||||
CronTrigger.from_crontab(cron),
|
||||
kwargs=job_args.model_dump(),
|
||||
replace_existing=True,
|
||||
name=name,
|
||||
trigger=CronTrigger.from_crontab(cron),
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
replace_existing=True,
|
||||
)
|
||||
logger.info(
|
||||
f"Added job {job.id} with cron schedule '{cron}' input data: {input_data}"
|
||||
)
|
||||
log(f"Added job {job.id} with cron schedule '{cron}' input data: {input_data}")
|
||||
return GraphExecutionJobInfo.from_db(job_args, job)
|
||||
|
||||
@expose
|
||||
@@ -321,14 +334,13 @@ class Scheduler(AppService):
|
||||
) -> GraphExecutionJobInfo:
|
||||
job = self.scheduler.get_job(schedule_id, jobstore=Jobstores.EXECUTION.value)
|
||||
if not job:
|
||||
log(f"Job {schedule_id} not found.")
|
||||
raise ValueError(f"Job #{schedule_id} not found.")
|
||||
raise NotFoundError(f"Job #{schedule_id} not found.")
|
||||
|
||||
job_args = GraphExecutionJobArgs(**job.kwargs)
|
||||
if job_args.user_id != user_id:
|
||||
raise ValueError("User ID does not match the job's user ID.")
|
||||
raise NotAuthorizedError("User ID does not match the job's user ID")
|
||||
|
||||
log(f"Deleting job {schedule_id}")
|
||||
logger.info(f"Deleting job {schedule_id}")
|
||||
job.remove()
|
||||
|
||||
return GraphExecutionJobInfo.from_db(job_args, job)
|
||||
|
||||
@@ -27,6 +27,7 @@ async def test_agent_schedule(server: SpinTestServer):
|
||||
graph_version=1,
|
||||
cron="0 0 * * *",
|
||||
input_data={"input": "data"},
|
||||
input_credentials={},
|
||||
)
|
||||
assert schedule
|
||||
|
||||
@@ -5,7 +5,7 @@ from concurrent.futures import Future
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, JsonValue
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
@@ -435,9 +435,7 @@ def validate_exec(
|
||||
async def _validate_node_input_credentials(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
):
|
||||
"""Checks all credentials for all nodes of the graph"""
|
||||
|
||||
@@ -453,11 +451,13 @@ async def _validate_node_input_credentials(
|
||||
|
||||
for field_name, credentials_meta_type in credentials_fields.items():
|
||||
if (
|
||||
node_credentials_input_map
|
||||
and (node_credentials_inputs := node_credentials_input_map.get(node.id))
|
||||
and field_name in node_credentials_inputs
|
||||
nodes_input_masks
|
||||
and (node_input_mask := nodes_input_masks.get(node.id))
|
||||
and field_name in node_input_mask
|
||||
):
|
||||
credentials_meta = node_credentials_input_map[node.id][field_name]
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node_input_mask[field_name]
|
||||
)
|
||||
elif field_name in node.input_default:
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node.input_default[field_name]
|
||||
@@ -496,7 +496,7 @@ async def _validate_node_input_credentials(
|
||||
def make_node_credentials_input_map(
|
||||
graph: GraphModel,
|
||||
graph_credentials_input: dict[str, CredentialsMetaInput],
|
||||
) -> dict[str, dict[str, CredentialsMetaInput]]:
|
||||
) -> dict[str, dict[str, JsonValue]]:
|
||||
"""
|
||||
Maps credentials for an execution to the correct nodes.
|
||||
|
||||
@@ -505,9 +505,9 @@ def make_node_credentials_input_map(
|
||||
graph_credentials_input: A (graph_input_name, credentials_meta) map.
|
||||
|
||||
Returns:
|
||||
dict[node_id, dict[field_name, CredentialsMetaInput]]: Node credentials input map.
|
||||
dict[node_id, dict[field_name, CredentialsMetaRaw]]: Node credentials input map.
|
||||
"""
|
||||
result: dict[str, dict[str, CredentialsMetaInput]] = {}
|
||||
result: dict[str, dict[str, JsonValue]] = {}
|
||||
|
||||
# Get aggregated credentials fields for the graph
|
||||
graph_cred_inputs = graph.aggregate_credentials_inputs()
|
||||
@@ -521,7 +521,9 @@ def make_node_credentials_input_map(
|
||||
for node_id, node_field_name in compatible_node_fields:
|
||||
if node_id not in result:
|
||||
result[node_id] = {}
|
||||
result[node_id][node_field_name] = graph_credentials_input[graph_input_name]
|
||||
result[node_id][node_field_name] = graph_credentials_input[
|
||||
graph_input_name
|
||||
].model_dump(exclude_none=True)
|
||||
|
||||
return result
|
||||
|
||||
@@ -530,9 +532,7 @@ async def construct_node_execution_input(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
graph_inputs: BlockInput,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> list[tuple[str, BlockInput]]:
|
||||
"""
|
||||
Validates and prepares the input data for executing a graph.
|
||||
@@ -550,8 +550,8 @@ async def construct_node_execution_input(
|
||||
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
|
||||
the corresponding input data for that node.
|
||||
"""
|
||||
graph.validate_graph(for_run=True)
|
||||
await _validate_node_input_credentials(graph, user_id, node_credentials_input_map)
|
||||
graph.validate_graph(for_run=True, nodes_input_masks=nodes_input_masks)
|
||||
await _validate_node_input_credentials(graph, user_id, nodes_input_masks)
|
||||
|
||||
nodes_input = []
|
||||
for node in graph.starting_nodes:
|
||||
@@ -568,23 +568,9 @@ async def construct_node_execution_input(
|
||||
if input_name and input_name in graph_inputs:
|
||||
input_data = {"value": graph_inputs[input_name]}
|
||||
|
||||
# Extract webhook payload, and assign it to the input pin
|
||||
webhook_payload_key = f"webhook_{node.webhook_id}_payload"
|
||||
if (
|
||||
block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
|
||||
and node.webhook_id
|
||||
):
|
||||
if webhook_payload_key not in graph_inputs:
|
||||
raise ValueError(
|
||||
f"Node {block.name} #{node.id} webhook payload is missing"
|
||||
)
|
||||
input_data = {"payload": graph_inputs[webhook_payload_key]}
|
||||
|
||||
# Apply node credentials overrides
|
||||
if node_credentials_input_map and (
|
||||
node_credentials := node_credentials_input_map.get(node.id)
|
||||
):
|
||||
input_data.update({k: v.model_dump() for k, v in node_credentials.items()})
|
||||
# Apply node input overrides
|
||||
if nodes_input_masks and (node_input_mask := nodes_input_masks.get(node.id)):
|
||||
input_data.update(node_input_mask)
|
||||
|
||||
input_data, error = validate_exec(node, input_data)
|
||||
if input_data is None:
|
||||
@@ -600,6 +586,20 @@ async def construct_node_execution_input(
|
||||
return nodes_input
|
||||
|
||||
|
||||
def _merge_nodes_input_masks(
|
||||
overrides_map_1: dict[str, dict[str, JsonValue]],
|
||||
overrides_map_2: dict[str, dict[str, JsonValue]],
|
||||
) -> dict[str, dict[str, JsonValue]]:
|
||||
"""Perform a per-node merge of input overrides"""
|
||||
result = overrides_map_1.copy()
|
||||
for node_id, overrides2 in overrides_map_2.items():
|
||||
if node_id in result:
|
||||
result[node_id] = {**result[node_id], **overrides2}
|
||||
else:
|
||||
result[node_id] = overrides2
|
||||
return result
|
||||
|
||||
|
||||
# ============ Execution Queue Helpers ============ #
|
||||
|
||||
|
||||
@@ -730,13 +730,11 @@ async def stop_graph_execution(
|
||||
async def add_graph_execution(
|
||||
graph_id: str,
|
||||
user_id: str,
|
||||
inputs: BlockInput,
|
||||
inputs: Optional[BlockInput] = None,
|
||||
preset_id: Optional[str] = None,
|
||||
graph_version: Optional[int] = None,
|
||||
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
use_db_query: bool = True,
|
||||
) -> GraphExecutionWithNodes:
|
||||
"""
|
||||
@@ -750,7 +748,7 @@ async def add_graph_execution(
|
||||
graph_version: The version of the graph to execute.
|
||||
graph_credentials_inputs: Credentials inputs to use in the execution.
|
||||
Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
|
||||
node_credentials_input_map: Credentials inputs to use in the execution, mapped to specific nodes.
|
||||
nodes_input_masks: Node inputs to use in the execution.
|
||||
Returns:
|
||||
GraphExecutionEntry: The entry for the graph execution.
|
||||
Raises:
|
||||
@@ -774,10 +772,19 @@ async def add_graph_execution(
|
||||
if not graph:
|
||||
raise NotFoundError(f"Graph #{graph_id} not found.")
|
||||
|
||||
node_credentials_input_map = node_credentials_input_map or (
|
||||
make_node_credentials_input_map(graph, graph_credentials_inputs)
|
||||
if graph_credentials_inputs
|
||||
else None
|
||||
nodes_input_masks = _merge_nodes_input_masks(
|
||||
(
|
||||
make_node_credentials_input_map(graph, graph_credentials_inputs)
|
||||
if graph_credentials_inputs
|
||||
else {}
|
||||
),
|
||||
nodes_input_masks or {},
|
||||
)
|
||||
starting_nodes_input = await construct_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=inputs or {},
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
)
|
||||
|
||||
if use_db_query:
|
||||
@@ -785,12 +792,7 @@ async def add_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
starting_nodes_input=await construct_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=inputs,
|
||||
node_credentials_input_map=node_credentials_input_map,
|
||||
),
|
||||
starting_nodes_input=starting_nodes_input,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
else:
|
||||
@@ -798,20 +800,15 @@ async def add_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
starting_nodes_input=await construct_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=inputs,
|
||||
node_credentials_input_map=node_credentials_input_map,
|
||||
),
|
||||
starting_nodes_input=starting_nodes_input,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
|
||||
try:
|
||||
queue = await get_async_execution_queue()
|
||||
graph_exec_entry = graph_exec.to_graph_execution_entry()
|
||||
if node_credentials_input_map:
|
||||
graph_exec_entry.node_credentials_input_map = node_credentials_input_map
|
||||
if nodes_input_masks:
|
||||
graph_exec_entry.nodes_input_masks = nodes_input_masks
|
||||
await queue.publish_message(
|
||||
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
||||
message=graph_exec_entry.model_dump_json(),
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.redis import get_redis_async
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor.database import DatabaseManagerAsyncClient
|
||||
|
||||
@@ -7,7 +7,7 @@ from autogpt_libs.utils.synchronize import AsyncRedisKeyedMutex
|
||||
from redis.asyncio.lock import Lock as AsyncRedisLock
|
||||
|
||||
from backend.data.model import Credentials, OAuth2Credentials
|
||||
from backend.data.redis import get_redis_async
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
from backend.integrations.oauth import HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
@@ -17,6 +17,7 @@ class ProviderName(str, Enum):
|
||||
GOOGLE = "google"
|
||||
GOOGLE_MAPS = "google_maps"
|
||||
GROQ = "groq"
|
||||
HTTP = "http"
|
||||
HUBSPOT = "hubspot"
|
||||
IDEOGRAM = "ideogram"
|
||||
JINA = "jina"
|
||||
|
||||
@@ -1,23 +1,22 @@
|
||||
import functools
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..providers import ProviderName
|
||||
from ._base import BaseWebhooksManager
|
||||
|
||||
_WEBHOOK_MANAGERS: dict["ProviderName", type["BaseWebhooksManager"]] = {}
|
||||
|
||||
|
||||
# --8<-- [start:load_webhook_managers]
|
||||
@functools.cache
|
||||
def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]:
|
||||
if _WEBHOOK_MANAGERS:
|
||||
return _WEBHOOK_MANAGERS
|
||||
webhook_managers = {}
|
||||
|
||||
from .compass import CompassWebhookManager
|
||||
from .generic import GenericWebhooksManager
|
||||
from .github import GithubWebhooksManager
|
||||
from .slant3d import Slant3DWebhooksManager
|
||||
|
||||
_WEBHOOK_MANAGERS.update(
|
||||
webhook_managers.update(
|
||||
{
|
||||
handler.PROVIDER_NAME: handler
|
||||
for handler in [
|
||||
@@ -28,7 +27,7 @@ def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]
|
||||
]
|
||||
}
|
||||
)
|
||||
return _WEBHOOK_MANAGERS
|
||||
return webhook_managers
|
||||
|
||||
|
||||
# --8<-- [end:load_webhook_managers]
|
||||
|
||||
@@ -7,13 +7,14 @@ from uuid import uuid4
|
||||
from fastapi import Request
|
||||
from strenum import StrEnum
|
||||
|
||||
from backend.data import integrations
|
||||
import backend.data.integrations as integrations
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks.utils import webhook_ingress_url
|
||||
from backend.util.exceptions import MissingConfigError
|
||||
from backend.util.settings import Config
|
||||
|
||||
from .utils import webhook_ingress_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
app_config = Config()
|
||||
|
||||
@@ -41,44 +42,74 @@ class BaseWebhooksManager(ABC, Generic[WT]):
|
||||
)
|
||||
|
||||
if webhook := await integrations.find_webhook_by_credentials_and_props(
|
||||
credentials.id, webhook_type, resource, events
|
||||
user_id=user_id,
|
||||
credentials_id=credentials.id,
|
||||
webhook_type=webhook_type,
|
||||
resource=resource,
|
||||
events=events,
|
||||
):
|
||||
return webhook
|
||||
|
||||
return await self._create_webhook(
|
||||
user_id, webhook_type, events, resource, credentials
|
||||
user_id=user_id,
|
||||
webhook_type=webhook_type,
|
||||
events=events,
|
||||
resource=resource,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
async def get_manual_webhook(
|
||||
self,
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
webhook_type: WT,
|
||||
events: list[str],
|
||||
):
|
||||
if current_webhook := await integrations.find_webhook_by_graph_and_props(
|
||||
graph_id, self.PROVIDER_NAME, webhook_type, events
|
||||
graph_id: Optional[str] = None,
|
||||
preset_id: Optional[str] = None,
|
||||
) -> integrations.Webhook:
|
||||
"""
|
||||
Tries to find an existing webhook tied to `graph_id`/`preset_id`,
|
||||
or creates a new webhook if none exists.
|
||||
|
||||
Existing webhooks are matched by `user_id`, `webhook_type`,
|
||||
and `graph_id`/`preset_id`.
|
||||
|
||||
If an existing webhook is found, we check if the events match and update them
|
||||
if necessary. We do this rather than creating a new webhook
|
||||
to avoid changing the webhook URL for existing manual webhooks.
|
||||
"""
|
||||
if (graph_id or preset_id) and (
|
||||
current_webhook := await integrations.find_webhook_by_graph_and_props(
|
||||
user_id=user_id,
|
||||
provider=self.PROVIDER_NAME.value,
|
||||
webhook_type=webhook_type,
|
||||
graph_id=graph_id,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
):
|
||||
if set(current_webhook.events) != set(events):
|
||||
current_webhook = await integrations.update_webhook(
|
||||
current_webhook.id, events=events
|
||||
)
|
||||
return current_webhook
|
||||
|
||||
return await self._create_webhook(
|
||||
user_id,
|
||||
webhook_type,
|
||||
events,
|
||||
user_id=user_id,
|
||||
webhook_type=webhook_type,
|
||||
events=events,
|
||||
register=False,
|
||||
)
|
||||
|
||||
async def prune_webhook_if_dangling(
|
||||
self, webhook_id: str, credentials: Optional[Credentials]
|
||||
self, user_id: str, webhook_id: str, credentials: Optional[Credentials]
|
||||
) -> bool:
|
||||
webhook = await integrations.get_webhook(webhook_id)
|
||||
if webhook.attached_nodes is None:
|
||||
raise ValueError("Error retrieving webhook including attached nodes")
|
||||
if webhook.attached_nodes:
|
||||
webhook = await integrations.get_webhook(webhook_id, include_relations=True)
|
||||
if webhook.triggered_nodes or webhook.triggered_presets:
|
||||
# Don't prune webhook if in use
|
||||
return False
|
||||
|
||||
if credentials:
|
||||
await self._deregister_webhook(webhook, credentials)
|
||||
await integrations.delete_webhook(webhook.id)
|
||||
await integrations.delete_webhook(user_id, webhook.id)
|
||||
return True
|
||||
|
||||
# --8<-- [start:BaseWebhooksManager3]
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional, cast
|
||||
|
||||
from backend.data.block import BlockSchema, BlockWebhookConfig
|
||||
from backend.data.block import BlockSchema
|
||||
from backend.data.graph import set_node_webhook
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks import get_webhook_manager, supports_webhooks
|
||||
|
||||
from . import get_webhook_manager, supports_webhooks
|
||||
from .utils import setup_webhook_for_block
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.graph import GraphModel, NodeModel
|
||||
@@ -81,7 +83,9 @@ async def on_graph_deactivate(graph: "GraphModel", user_id: str):
|
||||
f"credentials #{creds_meta['id']}"
|
||||
)
|
||||
|
||||
updated_node = await on_node_deactivate(node, credentials=node_credentials)
|
||||
updated_node = await on_node_deactivate(
|
||||
user_id, node, credentials=node_credentials
|
||||
)
|
||||
updated_nodes.append(updated_node)
|
||||
|
||||
graph.nodes = updated_nodes
|
||||
@@ -96,105 +100,25 @@ async def on_node_activate(
|
||||
) -> "NodeModel":
|
||||
"""Hook to be called when the node is activated/created"""
|
||||
|
||||
block = node.block
|
||||
|
||||
if not block.webhook_config:
|
||||
return node
|
||||
|
||||
provider = block.webhook_config.provider
|
||||
if not supports_webhooks(provider):
|
||||
raise ValueError(
|
||||
f"Block #{block.id} has webhook_config for provider {provider} "
|
||||
"which does not support webhooks"
|
||||
if node.block.webhook_config:
|
||||
new_webhook, feedback = await setup_webhook_for_block(
|
||||
user_id=user_id,
|
||||
trigger_block=node.block,
|
||||
trigger_config=node.input_default,
|
||||
for_graph_id=node.graph_id,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Activating webhook node #{node.id} with config {block.webhook_config}"
|
||||
)
|
||||
|
||||
webhooks_manager = get_webhook_manager(provider)
|
||||
|
||||
if auto_setup_webhook := isinstance(block.webhook_config, BlockWebhookConfig):
|
||||
try:
|
||||
resource = block.webhook_config.resource_format.format(**node.input_default)
|
||||
except KeyError:
|
||||
resource = None
|
||||
logger.debug(
|
||||
f"Constructed resource string {resource} from input {node.input_default}"
|
||||
)
|
||||
else:
|
||||
resource = "" # not relevant for manual webhooks
|
||||
|
||||
block_input_schema = cast(BlockSchema, block.input_schema)
|
||||
credentials_field_name = next(iter(block_input_schema.get_credentials_fields()), "")
|
||||
credentials_meta = (
|
||||
node.input_default.get(credentials_field_name)
|
||||
if credentials_field_name
|
||||
else None
|
||||
)
|
||||
event_filter_input_name = block.webhook_config.event_filter_input
|
||||
has_everything_for_webhook = (
|
||||
resource is not None
|
||||
and (credentials_meta or not credentials_field_name)
|
||||
and (
|
||||
not event_filter_input_name
|
||||
or (
|
||||
event_filter_input_name in node.input_default
|
||||
and any(
|
||||
is_on
|
||||
for is_on in node.input_default[event_filter_input_name].values()
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if has_everything_for_webhook and resource is not None:
|
||||
logger.debug(f"Node #{node} has everything for a webhook!")
|
||||
if credentials_meta and not credentials:
|
||||
raise ValueError(
|
||||
f"Cannot set up webhook for node #{node.id}: "
|
||||
f"credentials #{credentials_meta['id']} not available"
|
||||
)
|
||||
|
||||
if event_filter_input_name:
|
||||
# Shape of the event filter is enforced in Block.__init__
|
||||
event_filter = cast(dict, node.input_default[event_filter_input_name])
|
||||
events = [
|
||||
block.webhook_config.event_format.format(event=event)
|
||||
for event, enabled in event_filter.items()
|
||||
if enabled is True
|
||||
]
|
||||
logger.debug(f"Webhook events to subscribe to: {', '.join(events)}")
|
||||
if new_webhook:
|
||||
node = await set_node_webhook(node.id, new_webhook.id)
|
||||
else:
|
||||
events = []
|
||||
|
||||
# Find/make and attach a suitable webhook to the node
|
||||
if auto_setup_webhook:
|
||||
assert credentials is not None
|
||||
new_webhook = await webhooks_manager.get_suitable_auto_webhook(
|
||||
user_id,
|
||||
credentials,
|
||||
block.webhook_config.webhook_type,
|
||||
resource,
|
||||
events,
|
||||
logger.debug(
|
||||
f"Node #{node.id} does not have everything for a webhook: {feedback}"
|
||||
)
|
||||
else:
|
||||
# Manual webhook -> no credentials -> don't register but do create
|
||||
new_webhook = await webhooks_manager.get_manual_webhook(
|
||||
user_id,
|
||||
node.graph_id,
|
||||
block.webhook_config.webhook_type,
|
||||
events,
|
||||
)
|
||||
logger.debug(f"Acquired webhook: {new_webhook}")
|
||||
return await set_node_webhook(node.id, new_webhook.id)
|
||||
else:
|
||||
logger.debug(f"Node #{node.id} does not have everything for a webhook")
|
||||
|
||||
return node
|
||||
|
||||
|
||||
async def on_node_deactivate(
|
||||
user_id: str,
|
||||
node: "NodeModel",
|
||||
*,
|
||||
credentials: Optional["Credentials"] = None,
|
||||
@@ -233,7 +157,9 @@ async def on_node_deactivate(
|
||||
f"Pruning{' and deregistering' if credentials else ''} "
|
||||
f"webhook #{webhook.id}"
|
||||
)
|
||||
await webhooks_manager.prune_webhook_if_dangling(webhook.id, credentials)
|
||||
await webhooks_manager.prune_webhook_if_dangling(
|
||||
user_id, webhook.id, credentials
|
||||
)
|
||||
if (
|
||||
cast(BlockSchema, block.input_schema).get_credentials_fields()
|
||||
and not credentials
|
||||
|
||||
@@ -1,7 +1,22 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional, cast
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.settings import Config
|
||||
|
||||
from . import get_webhook_manager, supports_webhooks
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.block import Block, BlockSchema
|
||||
from backend.data.integrations import Webhook
|
||||
from backend.data.model import Credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
app_config = Config()
|
||||
credentials_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
# TODO: add test to assert this matches the actual API route
|
||||
@@ -10,3 +25,122 @@ def webhook_ingress_url(provider_name: ProviderName, webhook_id: str) -> str:
|
||||
f"{app_config.platform_base_url}/api/integrations/{provider_name.value}"
|
||||
f"/webhooks/{webhook_id}/ingress"
|
||||
)
|
||||
|
||||
|
||||
async def setup_webhook_for_block(
|
||||
user_id: str,
|
||||
trigger_block: "Block[BlockSchema, BlockSchema]",
|
||||
trigger_config: dict[str, JsonValue], # = Trigger block inputs
|
||||
for_graph_id: Optional[str] = None,
|
||||
for_preset_id: Optional[str] = None,
|
||||
credentials: Optional["Credentials"] = None,
|
||||
) -> tuple["Webhook", None] | tuple[None, str]:
|
||||
"""
|
||||
Utility function to create (and auto-setup if possible) a webhook for a given provider.
|
||||
|
||||
Returns:
|
||||
Webhook: The created or found webhook object, if successful.
|
||||
str: A feedback message, if any required inputs are missing.
|
||||
"""
|
||||
from backend.data.block import BlockWebhookConfig
|
||||
|
||||
if not (trigger_base_config := trigger_block.webhook_config):
|
||||
raise ValueError(f"Block #{trigger_block.id} does not have a webhook_config")
|
||||
|
||||
provider = trigger_base_config.provider
|
||||
if not supports_webhooks(provider):
|
||||
raise NotImplementedError(
|
||||
f"Block #{trigger_block.id} has webhook_config for provider {provider} "
|
||||
"for which we do not have a WebhooksManager"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Setting up webhook for block #{trigger_block.id} with config {trigger_config}"
|
||||
)
|
||||
|
||||
# Check & parse the event filter input, if any
|
||||
events: list[str] = []
|
||||
if event_filter_input_name := trigger_base_config.event_filter_input:
|
||||
if not (event_filter := trigger_config.get(event_filter_input_name)):
|
||||
return None, (
|
||||
f"Cannot set up {provider.value} webhook without event filter input: "
|
||||
f"missing input for '{event_filter_input_name}'"
|
||||
)
|
||||
elif not (
|
||||
# Shape of the event filter is enforced in Block.__init__
|
||||
any((event_filter := cast(dict[str, bool], event_filter)).values())
|
||||
):
|
||||
return None, (
|
||||
f"Cannot set up {provider.value} webhook without any enabled events "
|
||||
f"in event filter input '{event_filter_input_name}'"
|
||||
)
|
||||
|
||||
events = [
|
||||
trigger_base_config.event_format.format(event=event)
|
||||
for event, enabled in event_filter.items()
|
||||
if enabled is True
|
||||
]
|
||||
logger.debug(f"Webhook events to subscribe to: {', '.join(events)}")
|
||||
|
||||
# Check & process prerequisites for auto-setup webhooks
|
||||
if auto_setup_webhook := isinstance(trigger_base_config, BlockWebhookConfig):
|
||||
try:
|
||||
resource = trigger_base_config.resource_format.format(**trigger_config)
|
||||
except KeyError as missing_key:
|
||||
return None, (
|
||||
f"Cannot auto-setup {provider.value} webhook without resource: "
|
||||
f"missing input for '{missing_key}'"
|
||||
)
|
||||
logger.debug(
|
||||
f"Constructed resource string {resource} from input {trigger_config}"
|
||||
)
|
||||
|
||||
creds_field_name = next(
|
||||
# presence of this field is enforced in Block.__init__
|
||||
iter(trigger_block.input_schema.get_credentials_fields())
|
||||
)
|
||||
|
||||
if not (
|
||||
credentials_meta := cast(dict, trigger_config.get(creds_field_name, None))
|
||||
):
|
||||
return None, f"Cannot set up {provider.value} webhook without credentials"
|
||||
elif not (
|
||||
credentials := credentials
|
||||
or await credentials_manager.get(user_id, credentials_meta["id"])
|
||||
):
|
||||
raise ValueError(
|
||||
f"Cannot set up {provider.value} webhook without credentials: "
|
||||
f"credentials #{credentials_meta['id']} not found for user #{user_id}"
|
||||
)
|
||||
elif credentials.provider != provider:
|
||||
raise ValueError(
|
||||
f"Credentials #{credentials.id} do not match provider {provider.value}"
|
||||
)
|
||||
else:
|
||||
# not relevant for manual webhooks:
|
||||
resource = ""
|
||||
credentials = None
|
||||
|
||||
webhooks_manager = get_webhook_manager(provider)
|
||||
|
||||
# Find/make and attach a suitable webhook to the node
|
||||
if auto_setup_webhook:
|
||||
assert credentials is not None
|
||||
webhook = await webhooks_manager.get_suitable_auto_webhook(
|
||||
user_id=user_id,
|
||||
credentials=credentials,
|
||||
webhook_type=trigger_base_config.webhook_type,
|
||||
resource=resource,
|
||||
events=events,
|
||||
)
|
||||
else:
|
||||
# Manual webhook -> no credentials -> don't register but do create
|
||||
webhook = await webhooks_manager.get_manual_webhook(
|
||||
user_id=user_id,
|
||||
webhook_type=trigger_base_config.webhook_type,
|
||||
events=events,
|
||||
graph_id=for_graph_id,
|
||||
preset_id=for_preset_id,
|
||||
)
|
||||
logger.debug(f"Acquired webhook: {webhook}")
|
||||
return webhook, None
|
||||
|
||||
@@ -2,11 +2,19 @@ import asyncio
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Annotated, Awaitable, Literal
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
Depends,
|
||||
HTTPException,
|
||||
Path,
|
||||
Query,
|
||||
Request,
|
||||
status,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
from starlette.status import HTTP_404_NOT_FOUND
|
||||
|
||||
from backend.data.graph import set_node_webhook
|
||||
from backend.data.graph import get_graph, set_node_webhook
|
||||
from backend.data.integrations import (
|
||||
WebhookEvent,
|
||||
get_all_webhooks_by_creds,
|
||||
@@ -14,12 +22,18 @@ from backend.data.integrations import (
|
||||
publish_webhook_event,
|
||||
wait_for_webhook_event,
|
||||
)
|
||||
from backend.data.model import Credentials, CredentialsType, OAuth2Credentials
|
||||
from backend.data.model import (
|
||||
Credentials,
|
||||
CredentialsType,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
)
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
from backend.server.v2.library.db import set_preset_webhook, update_preset
|
||||
from backend.util.exceptions import NeedConfirmation, NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
@@ -73,6 +87,9 @@ class CredentialsMetaResponse(BaseModel):
|
||||
title: str | None
|
||||
scopes: list[str] | None
|
||||
username: str | None
|
||||
host: str | None = Field(
|
||||
default=None, description="Host pattern for host-scoped credentials"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{provider}/callback")
|
||||
@@ -95,7 +112,10 @@ async def callback(
|
||||
|
||||
if not valid_state:
|
||||
logger.warning(f"Invalid or expired state token for user {user_id}")
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired state token")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid or expired state token",
|
||||
)
|
||||
try:
|
||||
scopes = valid_state.scopes
|
||||
logger.debug(f"Retrieved scopes from state token: {scopes}")
|
||||
@@ -122,17 +142,12 @@ async def callback(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"OAuth callback for provider %s failed during code exchange: %s. Confirm provider credentials.",
|
||||
provider.value,
|
||||
e,
|
||||
logger.error(
|
||||
f"OAuth2 Code->Token exchange failed for provider {provider.value}: {e}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"message": str(e),
|
||||
"hint": "Verify OAuth configuration and try again.",
|
||||
},
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"OAuth2 callback failed to exchange code for tokens: {str(e)}",
|
||||
)
|
||||
|
||||
# TODO: Allow specifying `title` to set on `credentials`
|
||||
@@ -149,6 +164,9 @@ async def callback(
|
||||
title=credentials.title,
|
||||
scopes=credentials.scopes,
|
||||
username=credentials.username,
|
||||
host=(
|
||||
credentials.host if isinstance(credentials, HostScopedCredentials) else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -165,6 +183,7 @@ async def list_credentials(
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
)
|
||||
for cred in credentials
|
||||
]
|
||||
@@ -186,6 +205,7 @@ async def list_credentials_by_provider(
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
)
|
||||
for cred in credentials
|
||||
]
|
||||
@@ -201,10 +221,13 @@ async def get_credential(
|
||||
) -> Credentials:
|
||||
credential = await creds_manager.get(user_id, cred_id)
|
||||
if not credential:
|
||||
raise HTTPException(status_code=404, detail="Credentials not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if credential.provider != provider:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Credentials do not match the specified provider"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credentials do not match the specified provider",
|
||||
)
|
||||
return credential
|
||||
|
||||
@@ -222,7 +245,8 @@ async def create_credentials(
|
||||
await creds_manager.create(user_id, credentials)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to store credentials: {str(e)}"
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to store credentials: {str(e)}",
|
||||
)
|
||||
return credentials
|
||||
|
||||
@@ -256,14 +280,17 @@ async def delete_credentials(
|
||||
) -> CredentialsDeletionResponse | CredentialsDeletionNeedsConfirmationResponse:
|
||||
creds = await creds_manager.store.get_creds_by_id(user_id, cred_id)
|
||||
if not creds:
|
||||
raise HTTPException(status_code=404, detail="Credentials not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if creds.provider != provider:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Credentials do not match the specified provider"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credentials do not match the specified provider",
|
||||
)
|
||||
|
||||
try:
|
||||
await remove_all_webhooks_for_credentials(creds, force)
|
||||
await remove_all_webhooks_for_credentials(user_id, creds, force)
|
||||
except NeedConfirmation as e:
|
||||
return CredentialsDeletionNeedsConfirmationResponse(message=str(e))
|
||||
|
||||
@@ -294,16 +321,10 @@ async def webhook_ingress_generic(
|
||||
logger.debug(f"Received {provider.value} webhook ingress for ID {webhook_id}")
|
||||
webhook_manager = get_webhook_manager(provider)
|
||||
try:
|
||||
webhook = await get_webhook(webhook_id)
|
||||
webhook = await get_webhook(webhook_id, include_relations=True)
|
||||
except NotFoundError as e:
|
||||
logger.warning(
|
||||
"Webhook payload received for unknown webhook %s. Confirm the webhook ID.",
|
||||
webhook_id,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_404_NOT_FOUND,
|
||||
detail={"message": str(e), "hint": "Check if the webhook ID is correct."},
|
||||
) from e
|
||||
logger.warning(f"Webhook payload received for unknown webhook #{webhook_id}")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
||||
logger.debug(f"Webhook #{webhook_id}: {webhook}")
|
||||
payload, event_type = await webhook_manager.validate_payload(webhook, request)
|
||||
logger.debug(
|
||||
@@ -320,11 +341,11 @@ async def webhook_ingress_generic(
|
||||
await publish_webhook_event(webhook_event)
|
||||
logger.debug(f"Webhook event published: {webhook_event}")
|
||||
|
||||
if not webhook.attached_nodes:
|
||||
if not (webhook.triggered_nodes or webhook.triggered_presets):
|
||||
return
|
||||
|
||||
executions: list[Awaitable] = []
|
||||
for node in webhook.attached_nodes:
|
||||
for node in webhook.triggered_nodes:
|
||||
logger.debug(f"Webhook-attached node: {node}")
|
||||
if not node.is_triggered_by_event_type(event_type):
|
||||
logger.debug(f"Node #{node.id} doesn't trigger on event {event_type}")
|
||||
@@ -335,7 +356,48 @@ async def webhook_ingress_generic(
|
||||
user_id=webhook.user_id,
|
||||
graph_id=node.graph_id,
|
||||
graph_version=node.graph_version,
|
||||
inputs={f"webhook_{webhook_id}_payload": payload},
|
||||
nodes_input_masks={node.id: {"payload": payload}},
|
||||
)
|
||||
)
|
||||
for preset in webhook.triggered_presets:
|
||||
logger.debug(f"Webhook-attached preset: {preset}")
|
||||
if not preset.is_active:
|
||||
logger.debug(f"Preset #{preset.id} is inactive")
|
||||
continue
|
||||
|
||||
graph = await get_graph(preset.graph_id, preset.graph_version, webhook.user_id)
|
||||
if not graph:
|
||||
logger.error(
|
||||
f"User #{webhook.user_id} has preset #{preset.id} for graph "
|
||||
f"#{preset.graph_id} v{preset.graph_version}, "
|
||||
"but no access to the graph itself."
|
||||
)
|
||||
logger.info(f"Automatically deactivating broken preset #{preset.id}")
|
||||
await update_preset(preset.user_id, preset.id, is_active=False)
|
||||
continue
|
||||
if not (trigger_node := graph.webhook_input_node):
|
||||
# NOTE: this should NEVER happen, but we log and handle it gracefully
|
||||
logger.error(
|
||||
f"Preset #{preset.id} is triggered by webhook #{webhook.id}, but graph "
|
||||
f"#{preset.graph_id} v{preset.graph_version} has no webhook input node"
|
||||
)
|
||||
await set_preset_webhook(preset.user_id, preset.id, None)
|
||||
continue
|
||||
if not trigger_node.block.is_triggered_by_event_type(preset.inputs, event_type):
|
||||
logger.debug(f"Preset #{preset.id} doesn't trigger on event {event_type}")
|
||||
continue
|
||||
logger.debug(f"Executing preset #{preset.id} for webhook #{webhook.id}")
|
||||
|
||||
executions.append(
|
||||
add_graph_execution(
|
||||
user_id=webhook.user_id,
|
||||
graph_id=preset.graph_id,
|
||||
preset_id=preset.id,
|
||||
graph_version=preset.graph_version,
|
||||
graph_credentials_inputs=preset.credentials,
|
||||
nodes_input_masks={
|
||||
trigger_node.id: {**preset.inputs, "payload": payload}
|
||||
},
|
||||
)
|
||||
)
|
||||
asyncio.gather(*executions)
|
||||
@@ -360,7 +422,9 @@ async def webhook_ping(
|
||||
return False
|
||||
|
||||
if not await wait_for_webhook_event(webhook_id, event_type="ping", timeout=10):
|
||||
raise HTTPException(status_code=504, detail="Webhook ping timed out")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="Webhook ping timed out"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@@ -369,32 +433,37 @@ async def webhook_ping(
|
||||
|
||||
|
||||
async def remove_all_webhooks_for_credentials(
|
||||
credentials: Credentials, force: bool = False
|
||||
user_id: str, credentials: Credentials, force: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Remove and deregister all webhooks that were registered using the given credentials.
|
||||
|
||||
Params:
|
||||
user_id: The ID of the user who owns the credentials and webhooks.
|
||||
credentials: The credentials for which to remove the associated webhooks.
|
||||
force: Whether to proceed if any of the webhooks are still in use.
|
||||
|
||||
Raises:
|
||||
NeedConfirmation: If any of the webhooks are still in use and `force` is `False`
|
||||
"""
|
||||
webhooks = await get_all_webhooks_by_creds(credentials.id)
|
||||
if any(w.attached_nodes for w in webhooks) and not force:
|
||||
webhooks = await get_all_webhooks_by_creds(
|
||||
user_id, credentials.id, include_relations=True
|
||||
)
|
||||
if any(w.triggered_nodes or w.triggered_presets for w in webhooks) and not force:
|
||||
raise NeedConfirmation(
|
||||
"Some webhooks linked to these credentials are still in use by an agent"
|
||||
)
|
||||
for webhook in webhooks:
|
||||
# Unlink all nodes
|
||||
for node in webhook.attached_nodes or []:
|
||||
# Unlink all nodes & presets
|
||||
for node in webhook.triggered_nodes:
|
||||
await set_node_webhook(node.id, None)
|
||||
for preset in webhook.triggered_presets:
|
||||
await set_preset_webhook(user_id, preset.id, None)
|
||||
|
||||
# Prune the webhook
|
||||
webhook_manager = get_webhook_manager(ProviderName(credentials.provider))
|
||||
success = await webhook_manager.prune_webhook_if_dangling(
|
||||
webhook.id, credentials
|
||||
user_id, webhook.id, credentials
|
||||
)
|
||||
if not success:
|
||||
logger.warning(f"Webhook #{webhook.id} failed to prune")
|
||||
@@ -405,7 +474,7 @@ def _get_provider_oauth_handler(
|
||||
) -> "BaseOAuthHandler":
|
||||
if provider_name not in HANDLERS_BY_NAME:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Provider '{provider_name.value}' does not support OAuth",
|
||||
)
|
||||
|
||||
@@ -413,14 +482,13 @@ def _get_provider_oauth_handler(
|
||||
client_secret = getattr(settings.secrets, f"{provider_name.value}_client_secret")
|
||||
if not (client_id and client_secret):
|
||||
logger.error(
|
||||
"OAuth credentials for provider %s are missing. Check environment configuration.",
|
||||
provider_name.value,
|
||||
f"Attempt to use unconfigured {provider_name.value} OAuth integration"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail={
|
||||
"message": f"Integration with provider '{provider_name.value}' is not configured",
|
||||
"hint": "Set client ID and secret in the environment.",
|
||||
"message": f"Integration with provider '{provider_name.value}' is not configured.",
|
||||
"hint": "Set client ID and secret in the application's deployment environment",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -279,6 +279,7 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
|
||||
@staticmethod
|
||||
async def test_delete_graph(graph_id: str, user_id: str):
|
||||
"""Used for clean-up after a test run"""
|
||||
await backend.server.v2.library.db.delete_library_agent_by_graph_id(
|
||||
graph_id=graph_id, user_id=user_id
|
||||
)
|
||||
@@ -323,18 +324,14 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
|
||||
@staticmethod
|
||||
async def test_execute_preset(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
preset_id: str,
|
||||
user_id: str,
|
||||
node_input: Optional[dict[str, Any]] = None,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
return await backend.server.v2.library.routes.presets.execute_preset(
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
preset_id=preset_id,
|
||||
node_input=node_input or {},
|
||||
user_id=user_id,
|
||||
inputs=inputs or {},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -360,11 +357,22 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
provider: ProviderName,
|
||||
credentials: Credentials,
|
||||
) -> Credentials:
|
||||
from backend.server.integrations.router import create_credentials
|
||||
|
||||
return await create_credentials(
|
||||
user_id=user_id, provider=provider, credentials=credentials
|
||||
from backend.server.integrations.router import (
|
||||
create_credentials,
|
||||
get_credential,
|
||||
)
|
||||
|
||||
try:
|
||||
return await create_credentials(
|
||||
user_id=user_id, provider=provider, credentials=credentials
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating credentials: {e}")
|
||||
return await get_credential(
|
||||
provider=provider,
|
||||
user_id=user_id,
|
||||
cred_id=credentials.id,
|
||||
)
|
||||
|
||||
def set_test_dependency_overrides(self, overrides: dict):
|
||||
app.dependency_overrides.update(overrides)
|
||||
|
||||
@@ -4,6 +4,7 @@ import logging
|
||||
from typing import Annotated
|
||||
|
||||
import fastapi
|
||||
import pydantic
|
||||
|
||||
import backend.data.analytics
|
||||
from backend.server.utils import get_user_id
|
||||
@@ -12,24 +13,28 @@ router = fastapi.APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LogRawMetricRequest(pydantic.BaseModel):
|
||||
metric_name: str = pydantic.Field(..., min_length=1)
|
||||
metric_value: float = pydantic.Field(..., allow_inf_nan=False)
|
||||
data_string: str = pydantic.Field(..., min_length=1)
|
||||
|
||||
|
||||
@router.post(path="/log_raw_metric")
|
||||
async def log_raw_metric(
|
||||
user_id: Annotated[str, fastapi.Depends(get_user_id)],
|
||||
metric_name: Annotated[str, fastapi.Body(..., embed=True)],
|
||||
metric_value: Annotated[float, fastapi.Body(..., embed=True)],
|
||||
data_string: Annotated[str, fastapi.Body(..., embed=True)],
|
||||
request: LogRawMetricRequest,
|
||||
):
|
||||
try:
|
||||
result = await backend.data.analytics.log_raw_metric(
|
||||
user_id=user_id,
|
||||
metric_name=metric_name,
|
||||
metric_value=metric_value,
|
||||
data_string=data_string,
|
||||
metric_name=request.metric_name,
|
||||
metric_value=request.metric_value,
|
||||
data_string=request.data_string,
|
||||
)
|
||||
return result.id
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to log metric %s for user %s: %s", metric_name, user_id, e
|
||||
"Failed to log metric %s for user %s: %s", request.metric_name, user_id, e
|
||||
)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
|
||||
@@ -97,8 +97,17 @@ def test_log_raw_metric_invalid_request_improved() -> None:
|
||||
assert "data_string" in error_fields, "Should report missing data_string"
|
||||
|
||||
|
||||
def test_log_raw_metric_type_validation_improved() -> None:
|
||||
def test_log_raw_metric_type_validation_improved(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Test metric type validation with improved assertions."""
|
||||
# Mock the analytics function to avoid event loop issues
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=Mock(id="test-id"),
|
||||
)
|
||||
|
||||
invalid_requests = [
|
||||
{
|
||||
"data": {
|
||||
@@ -119,10 +128,10 @@ def test_log_raw_metric_type_validation_improved() -> None:
|
||||
{
|
||||
"data": {
|
||||
"metric_name": "test",
|
||||
"metric_value": float("inf"), # Infinity
|
||||
"data_string": "test",
|
||||
"metric_value": 123, # Valid number
|
||||
"data_string": "", # Empty data_string
|
||||
},
|
||||
"expected_error": "ensure this value is finite",
|
||||
"expected_error": "String should have at least 1 character",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@@ -93,10 +93,18 @@ def test_log_raw_metric_values_parametrized(
|
||||
],
|
||||
)
|
||||
def test_log_raw_metric_invalid_requests_parametrized(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
invalid_data: dict,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test invalid metric requests with parametrize."""
|
||||
# Mock the analytics function to avoid event loop issues
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=Mock(id="test-id"),
|
||||
)
|
||||
|
||||
response = client.post("/log_raw_metric", json=invalid_data)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
@@ -9,7 +9,7 @@ import stripe
|
||||
from autogpt_libs.auth.middleware import auth_middleware
|
||||
from autogpt_libs.feature_flag.client import feature_flag
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Request, Response
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
|
||||
from typing_extensions import Optional, TypedDict
|
||||
|
||||
@@ -72,6 +72,7 @@ from backend.server.model import (
|
||||
UpdatePermissionsRequest,
|
||||
)
|
||||
from backend.server.utils import get_user_id
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Settings
|
||||
|
||||
@@ -765,70 +766,94 @@ async def delete_graph_execution(
|
||||
|
||||
|
||||
class ScheduleCreationRequest(pydantic.BaseModel):
|
||||
graph_version: Optional[int] = None
|
||||
name: str
|
||||
cron: str
|
||||
input_data: dict[Any, Any]
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
inputs: dict[str, Any]
|
||||
credentials: dict[str, CredentialsMetaInput] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/schedules",
|
||||
path="/graphs/{graph_id}/schedules",
|
||||
summary="Create execution schedule",
|
||||
tags=["schedules"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def create_schedule(
|
||||
async def create_graph_execution_schedule(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
schedule: ScheduleCreationRequest,
|
||||
graph_id: str = Path(..., description="ID of the graph to schedule"),
|
||||
schedule_params: ScheduleCreationRequest = Body(),
|
||||
) -> scheduler.GraphExecutionJobInfo:
|
||||
graph = await graph_db.get_graph(
|
||||
schedule.graph_id, schedule.graph_version, user_id=user_id
|
||||
graph_id=graph_id,
|
||||
version=schedule_params.graph_version,
|
||||
user_id=user_id,
|
||||
)
|
||||
if not graph:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Graph #{schedule.graph_id} v.{schedule.graph_version} not found.",
|
||||
detail=f"Graph #{graph_id} v{schedule_params.graph_version} not found.",
|
||||
)
|
||||
|
||||
return await execution_scheduler_client().add_execution_schedule(
|
||||
graph_id=schedule.graph_id,
|
||||
graph_version=graph.version,
|
||||
cron=schedule.cron,
|
||||
input_data=schedule.input_data,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
name=schedule_params.name,
|
||||
cron=schedule_params.cron,
|
||||
input_data=schedule_params.inputs,
|
||||
input_credentials=schedule_params.credentials,
|
||||
)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/schedules",
|
||||
summary="List execution schedules for a graph",
|
||||
tags=["schedules"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def list_graph_execution_schedules(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
graph_id: str = Path(),
|
||||
) -> list[scheduler.GraphExecutionJobInfo]:
|
||||
return await execution_scheduler_client().get_execution_schedules(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/schedules",
|
||||
summary="List execution schedules for a user",
|
||||
tags=["schedules"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def list_all_graphs_execution_schedules(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> list[scheduler.GraphExecutionJobInfo]:
|
||||
return await execution_scheduler_client().get_execution_schedules(user_id=user_id)
|
||||
|
||||
|
||||
@v1_router.delete(
|
||||
path="/schedules/{schedule_id}",
|
||||
summary="Delete execution schedule",
|
||||
tags=["schedules"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def delete_schedule(
|
||||
schedule_id: str,
|
||||
async def delete_graph_execution_schedule(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> dict[Any, Any]:
|
||||
await execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
|
||||
schedule_id: str = Path(..., description="ID of the schedule to delete"),
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
await execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
|
||||
except NotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_404_NOT_FOUND,
|
||||
detail=f"Schedule #{schedule_id} not found",
|
||||
)
|
||||
return {"id": schedule_id}
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/schedules",
|
||||
summary="List execution schedules",
|
||||
tags=["schedules"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_execution_schedules(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
graph_id: str | None = None,
|
||||
) -> list[scheduler.GraphExecutionJobInfo]:
|
||||
return await execution_scheduler_client().get_execution_schedules(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
)
|
||||
|
||||
|
||||
########################################################
|
||||
##################### API KEY ##############################
|
||||
########################################################
|
||||
|
||||
@@ -108,11 +108,16 @@ class TestDatabaseIsolation:
|
||||
where={"email": {"contains": "@test.example"}}
|
||||
)
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def test_create_user(self, test_db_connection):
|
||||
"""Test that demonstrates proper isolation."""
|
||||
# This test has access to a clean database
|
||||
user = await test_db_connection.user.create(
|
||||
data={"email": "test@test.example", "name": "Test User"}
|
||||
data={
|
||||
"id": "test-user-id",
|
||||
"email": "test@test.example",
|
||||
"name": "Test User",
|
||||
}
|
||||
)
|
||||
assert user.email == "test@test.example"
|
||||
# User will be cleaned up automatically
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import Literal, Optional
|
||||
|
||||
import fastapi
|
||||
import prisma.errors
|
||||
@@ -7,17 +7,17 @@ import prisma.fields
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
import backend.data.graph
|
||||
import backend.data.graph as graph_db
|
||||
import backend.server.model
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.exceptions as store_exceptions
|
||||
import backend.server.v2.store.image_gen as store_image_gen
|
||||
import backend.server.v2.store.media as store_media
|
||||
from backend.data import db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.db import locked_transaction
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.db import locked_transaction, transaction
|
||||
from backend.data.execution import get_graph_execution
|
||||
from backend.data.includes import library_agent_include
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
|
||||
from backend.util.exceptions import NotFoundError
|
||||
@@ -122,7 +122,7 @@ async def list_library_agents(
|
||||
except Exception as e:
|
||||
# Skip this agent if there was an error
|
||||
logger.error(
|
||||
f"Error parsing LibraryAgent when getting library agents from db: {e}"
|
||||
f"Error parsing LibraryAgent #{agent.id} from DB item: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -168,7 +168,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
|
||||
)
|
||||
|
||||
if not library_agent:
|
||||
raise store_exceptions.AgentNotFoundError(f"Library agent #{id} not found")
|
||||
raise NotFoundError(f"Library agent #{id} not found")
|
||||
|
||||
return library_model.LibraryAgent.from_db(library_agent)
|
||||
|
||||
@@ -215,8 +215,34 @@ async def get_library_agent_by_store_version_id(
|
||||
return None
|
||||
|
||||
|
||||
async def get_library_agent_by_graph_id(
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
graph_version: Optional[int] = None,
|
||||
) -> library_model.LibraryAgent | None:
|
||||
try:
|
||||
filter: prisma.types.LibraryAgentWhereInput = {
|
||||
"agentGraphId": graph_id,
|
||||
"userId": user_id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
if graph_version is not None:
|
||||
filter["agentGraphVersion"] = graph_version
|
||||
|
||||
agent = await prisma.models.LibraryAgent.prisma().find_first(
|
||||
where=filter,
|
||||
include=library_agent_include(user_id),
|
||||
)
|
||||
if not agent:
|
||||
return None
|
||||
return library_model.LibraryAgent.from_db(agent)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error fetching library agent by graph ID: {e}")
|
||||
raise store_exceptions.DatabaseError("Failed to fetch library agent") from e
|
||||
|
||||
|
||||
async def add_generated_agent_image(
|
||||
graph: backend.data.graph.GraphModel,
|
||||
graph: graph_db.GraphModel,
|
||||
library_agent_id: str,
|
||||
) -> Optional[prisma.models.LibraryAgent]:
|
||||
"""
|
||||
@@ -249,7 +275,7 @@ async def add_generated_agent_image(
|
||||
|
||||
|
||||
async def create_library_agent(
|
||||
graph: backend.data.graph.GraphModel,
|
||||
graph: graph_db.GraphModel,
|
||||
user_id: str,
|
||||
) -> library_model.LibraryAgent:
|
||||
"""
|
||||
@@ -346,8 +372,8 @@ async def update_library_agent(
|
||||
auto_update_version: Optional[bool] = None,
|
||||
is_favorite: Optional[bool] = None,
|
||||
is_archived: Optional[bool] = None,
|
||||
is_deleted: Optional[bool] = None,
|
||||
) -> None:
|
||||
is_deleted: Optional[Literal[False]] = None,
|
||||
) -> library_model.LibraryAgent:
|
||||
"""
|
||||
Updates the specified LibraryAgent record.
|
||||
|
||||
@@ -357,15 +383,18 @@ async def update_library_agent(
|
||||
auto_update_version: Whether the agent should auto-update to active version.
|
||||
is_favorite: Whether this agent is marked as a favorite.
|
||||
is_archived: Whether this agent is archived.
|
||||
is_deleted: Whether this agent is deleted.
|
||||
|
||||
Returns:
|
||||
The updated LibraryAgent.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the specified LibraryAgent does not exist.
|
||||
DatabaseError: If there's an error in the update operation.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Updating library agent {library_agent_id} for user {user_id} with "
|
||||
f"auto_update_version={auto_update_version}, is_favorite={is_favorite}, "
|
||||
f"is_archived={is_archived}, is_deleted={is_deleted}"
|
||||
f"is_archived={is_archived}"
|
||||
)
|
||||
update_fields: prisma.types.LibraryAgentUpdateManyMutationInput = {}
|
||||
if auto_update_version is not None:
|
||||
@@ -375,17 +404,46 @@ async def update_library_agent(
|
||||
if is_archived is not None:
|
||||
update_fields["isArchived"] = is_archived
|
||||
if is_deleted is not None:
|
||||
if is_deleted is True:
|
||||
raise RuntimeError(
|
||||
"Use delete_library_agent() to (soft-)delete library agents"
|
||||
)
|
||||
update_fields["isDeleted"] = is_deleted
|
||||
if not update_fields:
|
||||
raise ValueError("No values were passed to update")
|
||||
|
||||
try:
|
||||
await prisma.models.LibraryAgent.prisma().update_many(
|
||||
where={"id": library_agent_id, "userId": user_id}, data=update_fields
|
||||
n_updated = await prisma.models.LibraryAgent.prisma().update_many(
|
||||
where={"id": library_agent_id, "userId": user_id},
|
||||
data=update_fields,
|
||||
)
|
||||
if n_updated < 1:
|
||||
raise NotFoundError(f"Library agent {library_agent_id} not found")
|
||||
|
||||
return await get_library_agent(
|
||||
id=library_agent_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error updating library agent: {str(e)}")
|
||||
raise store_exceptions.DatabaseError("Failed to update library agent") from e
|
||||
|
||||
|
||||
async def delete_library_agent(
|
||||
library_agent_id: str, user_id: str, soft_delete: bool = True
|
||||
) -> None:
|
||||
if soft_delete:
|
||||
deleted_count = await prisma.models.LibraryAgent.prisma().update_many(
|
||||
where={"id": library_agent_id, "userId": user_id}, data={"isDeleted": True}
|
||||
)
|
||||
else:
|
||||
deleted_count = await prisma.models.LibraryAgent.prisma().delete_many(
|
||||
where={"id": library_agent_id, "userId": user_id}
|
||||
)
|
||||
if deleted_count < 1:
|
||||
raise NotFoundError(f"Library agent #{library_agent_id} not found")
|
||||
|
||||
|
||||
async def delete_library_agent_by_graph_id(graph_id: str, user_id: str) -> None:
|
||||
"""
|
||||
Deletes a library agent for the given user
|
||||
@@ -525,7 +583,10 @@ async def list_presets(
|
||||
)
|
||||
raise store_exceptions.DatabaseError("Invalid pagination parameters")
|
||||
|
||||
query_filter: prisma.types.AgentPresetWhereInput = {"userId": user_id}
|
||||
query_filter: prisma.types.AgentPresetWhereInput = {
|
||||
"userId": user_id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
if graph_id:
|
||||
query_filter["agentGraphId"] = graph_id
|
||||
|
||||
@@ -581,7 +642,7 @@ async def get_preset(
|
||||
where={"id": preset_id},
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
if not preset or preset.userId != user_id:
|
||||
if not preset or preset.userId != user_id or preset.isDeleted:
|
||||
return None
|
||||
return library_model.LibraryAgentPreset.from_db(preset)
|
||||
except prisma.errors.PrismaError as e:
|
||||
@@ -618,12 +679,19 @@ async def create_preset(
|
||||
agentGraphId=preset.graph_id,
|
||||
agentGraphVersion=preset.graph_version,
|
||||
isActive=preset.is_active,
|
||||
webhookId=preset.webhook_id,
|
||||
InputPresets={
|
||||
"create": [
|
||||
prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput( # noqa
|
||||
name=name, data=prisma.fields.Json(data)
|
||||
)
|
||||
for name, data in preset.inputs.items()
|
||||
for name, data in {
|
||||
**preset.inputs,
|
||||
**{
|
||||
key: creds_meta.model_dump(exclude_none=True)
|
||||
for key, creds_meta in preset.credentials.items()
|
||||
},
|
||||
}.items()
|
||||
]
|
||||
},
|
||||
),
|
||||
@@ -664,6 +732,7 @@ async def create_preset_from_graph_execution(
|
||||
user_id=user_id,
|
||||
preset=library_model.LibraryAgentPresetCreatable(
|
||||
inputs=graph_execution.inputs,
|
||||
credentials={}, # FIXME
|
||||
graph_id=graph_execution.graph_id,
|
||||
graph_version=graph_execution.graph_version,
|
||||
name=create_request.name,
|
||||
@@ -676,7 +745,11 @@ async def create_preset_from_graph_execution(
|
||||
async def update_preset(
|
||||
user_id: str,
|
||||
preset_id: str,
|
||||
preset: library_model.LibraryAgentPresetUpdatable,
|
||||
inputs: Optional[BlockInput] = None,
|
||||
credentials: Optional[dict[str, CredentialsMetaInput]] = None,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
) -> library_model.LibraryAgentPreset:
|
||||
"""
|
||||
Updates an existing AgentPreset for a user.
|
||||
@@ -684,49 +757,95 @@ async def update_preset(
|
||||
Args:
|
||||
user_id: The ID of the user updating the preset.
|
||||
preset_id: The ID of the preset to update.
|
||||
preset: The preset data used for the update.
|
||||
inputs: New inputs object to set on the preset.
|
||||
credentials: New credentials to set on the preset.
|
||||
name: New name for the preset.
|
||||
description: New description for the preset.
|
||||
is_active: New active status for the preset.
|
||||
|
||||
Returns:
|
||||
The updated LibraryAgentPreset.
|
||||
|
||||
Raises:
|
||||
DatabaseError: If there's a database error in updating the preset.
|
||||
ValueError: If attempting to update a non-existent preset.
|
||||
NotFoundError: If attempting to update a non-existent preset.
|
||||
"""
|
||||
current = await get_preset(user_id, preset_id) # assert ownership
|
||||
if not current:
|
||||
raise NotFoundError(f"Preset #{preset_id} not found for user #{user_id}")
|
||||
logger.debug(
|
||||
f"Updating preset #{preset_id} ({repr(preset.name)}) for user #{user_id}",
|
||||
f"Updating preset #{preset_id} ({repr(current.name)}) for user #{user_id}",
|
||||
)
|
||||
try:
|
||||
update_data: prisma.types.AgentPresetUpdateInput = {}
|
||||
if preset.name:
|
||||
update_data["name"] = preset.name
|
||||
if preset.description:
|
||||
update_data["description"] = preset.description
|
||||
if preset.inputs:
|
||||
update_data["InputPresets"] = {
|
||||
"create": [
|
||||
prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput( # noqa
|
||||
name=name, data=prisma.fields.Json(data)
|
||||
async with transaction() as tx:
|
||||
update_data: prisma.types.AgentPresetUpdateInput = {}
|
||||
if name:
|
||||
update_data["name"] = name
|
||||
if description:
|
||||
update_data["description"] = description
|
||||
if is_active is not None:
|
||||
update_data["isActive"] = is_active
|
||||
if inputs or credentials:
|
||||
if not (inputs and credentials):
|
||||
raise ValueError(
|
||||
"Preset inputs and credentials must be provided together"
|
||||
)
|
||||
for name, data in preset.inputs.items()
|
||||
]
|
||||
}
|
||||
if preset.is_active:
|
||||
update_data["isActive"] = preset.is_active
|
||||
update_data["InputPresets"] = {
|
||||
"create": [
|
||||
prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput( # noqa
|
||||
name=name, data=prisma.fields.Json(data)
|
||||
)
|
||||
for name, data in {
|
||||
**inputs,
|
||||
**{
|
||||
key: creds_meta.model_dump(exclude_none=True)
|
||||
for key, creds_meta in credentials.items()
|
||||
},
|
||||
}.items()
|
||||
],
|
||||
}
|
||||
# Existing InputPresets must be deleted, in a separate query
|
||||
await prisma.models.AgentNodeExecutionInputOutput.prisma(
|
||||
tx
|
||||
).delete_many(where={"agentPresetId": preset_id})
|
||||
|
||||
updated = await prisma.models.AgentPreset.prisma().update(
|
||||
where={"id": preset_id},
|
||||
data=update_data,
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
updated = await prisma.models.AgentPreset.prisma(tx).update(
|
||||
where={"id": preset_id},
|
||||
data=update_data,
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
if not updated:
|
||||
raise ValueError(f"AgentPreset #{preset_id} not found")
|
||||
raise RuntimeError(f"AgentPreset #{preset_id} vanished while updating")
|
||||
return library_model.LibraryAgentPreset.from_db(updated)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error updating preset: {e}")
|
||||
raise store_exceptions.DatabaseError("Failed to update preset") from e
|
||||
|
||||
|
||||
async def set_preset_webhook(
|
||||
user_id: str, preset_id: str, webhook_id: str | None
|
||||
) -> library_model.LibraryAgentPreset:
|
||||
current = await prisma.models.AgentPreset.prisma().find_unique(
|
||||
where={"id": preset_id},
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
if not current or current.userId != user_id:
|
||||
raise NotFoundError(f"Preset #{preset_id} not found")
|
||||
|
||||
updated = await prisma.models.AgentPreset.prisma().update(
|
||||
where={"id": preset_id},
|
||||
data=(
|
||||
{"Webhook": {"connect": {"id": webhook_id}}}
|
||||
if webhook_id
|
||||
else {"Webhook": {"disconnect": True}}
|
||||
),
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
if not updated:
|
||||
raise RuntimeError(f"AgentPreset #{preset_id} vanished while updating")
|
||||
return library_model.LibraryAgentPreset.from_db(updated)
|
||||
|
||||
|
||||
async def delete_preset(user_id: str, preset_id: str) -> None:
|
||||
"""
|
||||
Soft-deletes a preset by marking it as isDeleted = True.
|
||||
@@ -738,7 +857,7 @@ async def delete_preset(user_id: str, preset_id: str) -> None:
|
||||
Raises:
|
||||
DatabaseError: If there's a database error during deletion.
|
||||
"""
|
||||
logger.info(f"Deleting preset {preset_id} for user {user_id}")
|
||||
logger.debug(f"Setting preset #{preset_id} for user #{user_id} to deleted")
|
||||
try:
|
||||
await prisma.models.AgentPreset.prisma().update_many(
|
||||
where={"id": preset_id, "userId": user_id},
|
||||
@@ -765,7 +884,7 @@ async def fork_library_agent(library_agent_id: str, user_id: str):
|
||||
"""
|
||||
logger.debug(f"Forking library agent {library_agent_id} for user {user_id}")
|
||||
try:
|
||||
async with db.locked_transaction(f"usr_trx_{user_id}-fork_agent"):
|
||||
async with locked_transaction(f"usr_trx_{user_id}-fork_agent"):
|
||||
# Fetch the original agent
|
||||
original_agent = await get_library_agent(library_agent_id, user_id)
|
||||
|
||||
|
||||
@@ -143,7 +143,7 @@ async def test_add_agent_to_library(mocker):
|
||||
)
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.create = mocker.AsyncMock(
|
||||
return_value=mock_library_agent_data
|
||||
)
|
||||
@@ -159,21 +159,24 @@ async def test_add_agent_to_library(mocker):
|
||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||
where={"id": "version123"}, include={"AgentGraph": True}
|
||||
)
|
||||
mock_library_agent.return_value.find_first.assert_called_once_with(
|
||||
mock_library_agent.return_value.find_unique.assert_called_once_with(
|
||||
where={
|
||||
"userId": "test-user",
|
||||
"agentGraphId": "agent1",
|
||||
"agentGraphVersion": 1,
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": "test-user",
|
||||
"agentGraphId": "agent1",
|
||||
"agentGraphVersion": 1,
|
||||
}
|
||||
},
|
||||
include=library_agent_include("test-user"),
|
||||
include={"AgentGraph": True},
|
||||
)
|
||||
mock_library_agent.return_value.create.assert_called_once_with(
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
userId="test-user",
|
||||
agentGraphId="agent1",
|
||||
agentGraphVersion=1,
|
||||
isCreatedByUser=False,
|
||||
),
|
||||
data={
|
||||
"User": {"connect": {"id": "test-user"}},
|
||||
"AgentGraph": {
|
||||
"connect": {"graphVersionId": {"id": "agent1", "version": 1}}
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
},
|
||||
include=library_agent_include("test-user"),
|
||||
)
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@ import pydantic
|
||||
import backend.data.block as block_model
|
||||
import backend.data.graph as graph_model
|
||||
import backend.server.model as server_model
|
||||
from backend.data.model import CredentialsMetaInput, is_credentials_field_name
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
|
||||
class LibraryAgentStatus(str, Enum):
|
||||
@@ -18,6 +20,14 @@ class LibraryAgentStatus(str, Enum):
|
||||
ERROR = "ERROR" # Agent is in an error state
|
||||
|
||||
|
||||
class LibraryAgentTriggerInfo(pydantic.BaseModel):
|
||||
provider: ProviderName
|
||||
config_schema: dict[str, Any] = pydantic.Field(
|
||||
description="Input schema for the trigger block"
|
||||
)
|
||||
credentials_input_name: Optional[str]
|
||||
|
||||
|
||||
class LibraryAgent(pydantic.BaseModel):
|
||||
"""
|
||||
Represents an agent in the library, including metadata for display and
|
||||
@@ -40,8 +50,15 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
|
||||
# Made input_schema and output_schema match GraphMeta's type
|
||||
input_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend
|
||||
credentials_input_schema: dict[str, Any] = pydantic.Field(
|
||||
description="Input schema for credentials required by the agent",
|
||||
)
|
||||
|
||||
has_external_trigger: bool = pydantic.Field(
|
||||
description="Whether the agent has an external trigger (e.g. webhook) node"
|
||||
)
|
||||
trigger_setup_info: Optional[LibraryAgentTriggerInfo] = None
|
||||
|
||||
# Indicates whether there's a new output (based on recent runs)
|
||||
new_output: bool
|
||||
@@ -106,6 +123,32 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
input_schema=graph.input_schema,
|
||||
credentials_input_schema=graph.credentials_input_schema,
|
||||
has_external_trigger=graph.has_webhook_trigger,
|
||||
trigger_setup_info=(
|
||||
LibraryAgentTriggerInfo(
|
||||
provider=trigger_block.webhook_config.provider,
|
||||
config_schema={
|
||||
**(json_schema := trigger_block.input_schema.jsonschema()),
|
||||
"properties": {
|
||||
pn: sub_schema
|
||||
for pn, sub_schema in json_schema["properties"].items()
|
||||
if not is_credentials_field_name(pn)
|
||||
},
|
||||
"required": [
|
||||
pn
|
||||
for pn in json_schema.get("required", [])
|
||||
if not is_credentials_field_name(pn)
|
||||
],
|
||||
},
|
||||
credentials_input_name=next(
|
||||
iter(trigger_block.input_schema.get_credentials_fields()), None
|
||||
),
|
||||
)
|
||||
if graph.webhook_input_node
|
||||
and (trigger_block := graph.webhook_input_node.block).webhook_config
|
||||
else None
|
||||
),
|
||||
new_output=new_output,
|
||||
can_access_graph=can_access_graph,
|
||||
is_latest_version=is_latest_version,
|
||||
@@ -177,12 +220,15 @@ class LibraryAgentPresetCreatable(pydantic.BaseModel):
|
||||
graph_version: int
|
||||
|
||||
inputs: block_model.BlockInput
|
||||
credentials: dict[str, CredentialsMetaInput]
|
||||
|
||||
name: str
|
||||
description: str
|
||||
|
||||
is_active: bool = True
|
||||
|
||||
webhook_id: Optional[str] = None
|
||||
|
||||
|
||||
class LibraryAgentPresetCreatableFromGraphExecution(pydantic.BaseModel):
|
||||
"""
|
||||
@@ -203,6 +249,7 @@ class LibraryAgentPresetUpdatable(pydantic.BaseModel):
|
||||
"""
|
||||
|
||||
inputs: Optional[block_model.BlockInput] = None
|
||||
credentials: Optional[dict[str, CredentialsMetaInput]] = None
|
||||
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
@@ -214,20 +261,28 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
|
||||
"""Represents a preset configuration for a library agent."""
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
updated_at: datetime.datetime
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, preset: prisma.models.AgentPreset) -> "LibraryAgentPreset":
|
||||
if preset.InputPresets is None:
|
||||
raise ValueError("Input values must be included in object")
|
||||
raise ValueError("InputPresets must be included in AgentPreset query")
|
||||
|
||||
input_data: block_model.BlockInput = {}
|
||||
input_credentials: dict[str, CredentialsMetaInput] = {}
|
||||
|
||||
for preset_input in preset.InputPresets:
|
||||
input_data[preset_input.name] = preset_input.data
|
||||
if not is_credentials_field_name(preset_input.name):
|
||||
input_data[preset_input.name] = preset_input.data
|
||||
else:
|
||||
input_credentials[preset_input.name] = (
|
||||
CredentialsMetaInput.model_validate(preset_input.data)
|
||||
)
|
||||
|
||||
return cls(
|
||||
id=preset.id,
|
||||
user_id=preset.userId,
|
||||
updated_at=preset.updatedAt,
|
||||
graph_id=preset.agentGraphId,
|
||||
graph_version=preset.agentGraphVersion,
|
||||
@@ -235,6 +290,8 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
|
||||
description=preset.description,
|
||||
is_active=preset.isActive,
|
||||
inputs=input_data,
|
||||
credentials=input_credentials,
|
||||
webhook_id=preset.webhookId,
|
||||
)
|
||||
|
||||
|
||||
@@ -276,6 +333,3 @@ class LibraryAgentUpdateRequest(pydantic.BaseModel):
|
||||
is_archived: Optional[bool] = pydantic.Field(
|
||||
default=None, description="Archive the agent"
|
||||
)
|
||||
is_deleted: Optional[bool] = pydantic.Field(
|
||||
default=None, description="Delete the agent"
|
||||
)
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import autogpt_libs.auth as autogpt_auth_lib
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, status
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import backend.server.v2.library.db as library_db
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.exceptions as store_exceptions
|
||||
from backend.data.graph import get_graph
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.executor.utils import make_node_credentials_input_map
|
||||
from backend.integrations.webhooks.utils import setup_webhook_for_block
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -71,10 +77,10 @@ async def list_library_agents(
|
||||
page_size=page_size,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Listing library agents failed for user %s: %s", user_id, e)
|
||||
logger.error(f"Could not list library agents for user #{user_id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Inspect database connectivity."},
|
||||
detail=str(e),
|
||||
) from e
|
||||
|
||||
|
||||
@@ -86,6 +92,23 @@ async def get_library_agent(
|
||||
return await library_db.get_library_agent(id=library_agent_id, user_id=user_id)
|
||||
|
||||
|
||||
@router.get("/by-graph/{graph_id}")
|
||||
async def get_library_agent_by_graph_id(
|
||||
graph_id: str,
|
||||
version: Optional[int] = Query(default=None),
|
||||
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
|
||||
) -> library_model.LibraryAgent:
|
||||
library_agent = await library_db.get_library_agent_by_graph_id(
|
||||
user_id, graph_id, version
|
||||
)
|
||||
if not library_agent:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Library agent for graph #{graph_id} and user #{user_id} not found",
|
||||
)
|
||||
return library_agent
|
||||
|
||||
|
||||
@router.get(
|
||||
"/marketplace/{store_listing_version_id}",
|
||||
summary="Get Agent By Store ID",
|
||||
@@ -103,18 +126,16 @@ async def get_library_agent_by_store_listing_version_id(
|
||||
return await library_db.get_library_agent_by_store_version_id(
|
||||
store_listing_version_id, user_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Retrieving library agent by store version failed for user %s: %s",
|
||||
user_id,
|
||||
e,
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not fetch library agent from store version ID: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={
|
||||
"message": str(e),
|
||||
"hint": "Check if the store listing ID is valid.",
|
||||
},
|
||||
detail=str(e),
|
||||
) from e
|
||||
|
||||
|
||||
@@ -152,26 +173,20 @@ async def add_marketplace_agent_to_library(
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
except store_exceptions.AgentNotFoundError:
|
||||
except store_exceptions.AgentNotFoundError as e:
|
||||
logger.warning(
|
||||
"Store listing version %s not found when adding to library",
|
||||
store_listing_version_id,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"message": f"Store listing version {store_listing_version_id} not found",
|
||||
"hint": "Confirm the ID provided.",
|
||||
},
|
||||
f"Could not find store listing version {store_listing_version_id} "
|
||||
"to add to library"
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
||||
except store_exceptions.DatabaseError as e:
|
||||
logger.exception("Database error whilst adding agent to library: %s", e)
|
||||
logger.error(f"Database error while adding agent to library: {e}", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Inspect DB logs for details."},
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.exception("Unexpected error while adding agent to library: %s", e)
|
||||
logger.error(f"Unexpected error while adding agent to library: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={
|
||||
@@ -181,12 +196,11 @@ async def add_marketplace_agent_to_library(
|
||||
) from e
|
||||
|
||||
|
||||
@router.put(
|
||||
@router.patch(
|
||||
"/{library_agent_id}",
|
||||
summary="Update Library Agent",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
responses={
|
||||
204: {"description": "Agent updated successfully"},
|
||||
200: {"description": "Agent updated successfully"},
|
||||
500: {"description": "Server error"},
|
||||
},
|
||||
)
|
||||
@@ -194,7 +208,7 @@ async def update_library_agent(
|
||||
library_agent_id: str,
|
||||
payload: library_model.LibraryAgentUpdateRequest,
|
||||
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
|
||||
) -> JSONResponse:
|
||||
) -> library_model.LibraryAgent:
|
||||
"""
|
||||
Update the library agent with the given fields.
|
||||
|
||||
@@ -203,39 +217,75 @@ async def update_library_agent(
|
||||
payload: Fields to update (auto_update_version, is_favorite, etc.).
|
||||
user_id: ID of the authenticated user.
|
||||
|
||||
Returns:
|
||||
204 (No Content) on success.
|
||||
|
||||
Raises:
|
||||
HTTPException(500): If a server/database error occurs.
|
||||
"""
|
||||
try:
|
||||
await library_db.update_library_agent(
|
||||
return await library_db.update_library_agent(
|
||||
library_agent_id=library_agent_id,
|
||||
user_id=user_id,
|
||||
auto_update_version=payload.auto_update_version,
|
||||
is_favorite=payload.is_favorite,
|
||||
is_archived=payload.is_archived,
|
||||
is_deleted=payload.is_deleted,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
content={"message": "Agent updated successfully"},
|
||||
)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
) from e
|
||||
except store_exceptions.DatabaseError as e:
|
||||
logger.exception("Database error while updating library agent: %s", e)
|
||||
logger.error(f"Database error while updating library agent: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Verify DB connection."},
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.exception("Unexpected error while updating library agent: %s", e)
|
||||
logger.error(f"Unexpected error while updating library agent: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Check server logs."},
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{library_agent_id}",
|
||||
summary="Delete Library Agent",
|
||||
responses={
|
||||
204: {"description": "Agent deleted successfully"},
|
||||
404: {"description": "Agent not found"},
|
||||
500: {"description": "Server error"},
|
||||
},
|
||||
)
|
||||
async def delete_library_agent(
|
||||
library_agent_id: str,
|
||||
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
|
||||
) -> Response:
|
||||
"""
|
||||
Soft-delete the specified library agent.
|
||||
|
||||
Args:
|
||||
library_agent_id: ID of the library agent to delete.
|
||||
user_id: ID of the authenticated user.
|
||||
|
||||
Returns:
|
||||
204 No Content if successful.
|
||||
|
||||
Raises:
|
||||
HTTPException(404): If the agent does not exist.
|
||||
HTTPException(500): If a server/database error occurs.
|
||||
"""
|
||||
try:
|
||||
await library_db.delete_library_agent(
|
||||
library_agent_id=library_agent_id, user_id=user_id
|
||||
)
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/{library_agent_id}/fork", summary="Fork Library Agent")
|
||||
async def fork_library_agent(
|
||||
library_agent_id: str,
|
||||
@@ -245,3 +295,81 @@ async def fork_library_agent(
|
||||
library_agent_id=library_agent_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
class TriggeredPresetSetupParams(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
|
||||
trigger_config: dict[str, Any]
|
||||
agent_credentials: dict[str, CredentialsMetaInput] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@router.post("/{library_agent_id}/setup-trigger")
|
||||
async def setup_trigger(
|
||||
library_agent_id: str = Path(..., description="ID of the library agent"),
|
||||
params: TriggeredPresetSetupParams = Body(),
|
||||
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
|
||||
) -> library_model.LibraryAgentPreset:
|
||||
"""
|
||||
Sets up a webhook-triggered `LibraryAgentPreset` for a `LibraryAgent`.
|
||||
Returns the correspondingly created `LibraryAgentPreset` with `webhook_id` set.
|
||||
"""
|
||||
library_agent = await library_db.get_library_agent(
|
||||
id=library_agent_id, user_id=user_id
|
||||
)
|
||||
if not library_agent:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Library agent #{library_agent_id} not found",
|
||||
)
|
||||
|
||||
graph = await get_graph(
|
||||
library_agent.graph_id, version=library_agent.graph_version, user_id=user_id
|
||||
)
|
||||
if not graph:
|
||||
raise HTTPException(
|
||||
status.HTTP_410_GONE,
|
||||
f"Graph #{library_agent.graph_id} not accessible (anymore)",
|
||||
)
|
||||
if not (trigger_node := graph.webhook_input_node):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Graph #{library_agent.graph_id} does not have a webhook node",
|
||||
)
|
||||
|
||||
trigger_config_with_credentials = {
|
||||
**params.trigger_config,
|
||||
**(
|
||||
make_node_credentials_input_map(graph, params.agent_credentials).get(
|
||||
trigger_node.id
|
||||
)
|
||||
or {}
|
||||
),
|
||||
}
|
||||
|
||||
new_webhook, feedback = await setup_webhook_for_block(
|
||||
user_id=user_id,
|
||||
trigger_block=trigger_node.block,
|
||||
trigger_config=trigger_config_with_credentials,
|
||||
)
|
||||
if not new_webhook:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Could not set up webhook: {feedback}",
|
||||
)
|
||||
|
||||
new_preset = await library_db.create_preset(
|
||||
user_id=user_id,
|
||||
preset=library_model.LibraryAgentPresetCreatable(
|
||||
graph_id=library_agent.graph_id,
|
||||
graph_version=library_agent.graph_version,
|
||||
name=params.name,
|
||||
description=params.description,
|
||||
inputs=trigger_config_with_credentials,
|
||||
credentials=params.agent_credentials,
|
||||
webhook_id=new_webhook.id,
|
||||
is_active=True,
|
||||
),
|
||||
)
|
||||
return new_preset
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
import logging
|
||||
from typing import Annotated, Any, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import autogpt_libs.auth as autogpt_auth_lib
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
|
||||
|
||||
import backend.server.v2.library.db as db
|
||||
import backend.server.v2.library.model as models
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.data.graph import get_graph
|
||||
from backend.data.integrations import get_webhook
|
||||
from backend.executor.utils import add_graph_execution, make_node_credentials_input_map
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
from backend.integrations.webhooks.utils import setup_webhook_for_block
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
tags=["presets"],
|
||||
)
|
||||
credentials_manager = IntegrationCredentialsManager()
|
||||
router = APIRouter(tags=["presets"])
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -51,11 +55,7 @@ async def list_presets(
|
||||
except Exception as e:
|
||||
logger.exception("Failed to list presets for user %s: %s", user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={
|
||||
"message": str(e),
|
||||
"hint": "Ensure the presets DB table is accessible.",
|
||||
},
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
|
||||
)
|
||||
|
||||
|
||||
@@ -83,21 +83,21 @@ async def get_preset(
|
||||
"""
|
||||
try:
|
||||
preset = await db.get_preset(user_id, preset_id)
|
||||
if not preset:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Preset {preset_id} not found",
|
||||
)
|
||||
return preset
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Error retrieving preset %s for user %s: %s", preset_id, user_id, e
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Validate preset ID and retry."},
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
|
||||
)
|
||||
|
||||
if not preset:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Preset #{preset_id} not found",
|
||||
)
|
||||
return preset
|
||||
|
||||
|
||||
@router.post(
|
||||
"/presets",
|
||||
@@ -134,8 +134,7 @@ async def create_preset(
|
||||
except Exception as e:
|
||||
logger.exception("Preset creation failed for user %s: %s", user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Check preset payload format."},
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
|
||||
)
|
||||
|
||||
|
||||
@@ -163,17 +162,85 @@ async def update_preset(
|
||||
Raises:
|
||||
HTTPException: If an error occurs while updating the preset.
|
||||
"""
|
||||
current = await get_preset(preset_id, user_id=user_id)
|
||||
if not current:
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Preset #{preset_id} not found")
|
||||
|
||||
graph = await get_graph(
|
||||
current.graph_id,
|
||||
current.graph_version,
|
||||
user_id=user_id,
|
||||
)
|
||||
if not graph:
|
||||
raise HTTPException(
|
||||
status.HTTP_410_GONE,
|
||||
f"Graph #{current.graph_id} not accessible (anymore)",
|
||||
)
|
||||
|
||||
trigger_inputs_updated, new_webhook, feedback = False, None, None
|
||||
if (trigger_node := graph.webhook_input_node) and (
|
||||
preset.inputs is not None and preset.credentials is not None
|
||||
):
|
||||
trigger_config_with_credentials = {
|
||||
**preset.inputs,
|
||||
**(
|
||||
make_node_credentials_input_map(graph, preset.credentials).get(
|
||||
trigger_node.id
|
||||
)
|
||||
or {}
|
||||
),
|
||||
}
|
||||
new_webhook, feedback = await setup_webhook_for_block(
|
||||
user_id=user_id,
|
||||
trigger_block=graph.webhook_input_node.block,
|
||||
trigger_config=trigger_config_with_credentials,
|
||||
for_preset_id=preset_id,
|
||||
)
|
||||
trigger_inputs_updated = True
|
||||
if not new_webhook:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Could not update trigger configuration: {feedback}",
|
||||
)
|
||||
|
||||
try:
|
||||
return await db.update_preset(
|
||||
user_id=user_id, preset_id=preset_id, preset=preset
|
||||
updated = await db.update_preset(
|
||||
user_id=user_id,
|
||||
preset_id=preset_id,
|
||||
inputs=preset.inputs,
|
||||
credentials=preset.credentials,
|
||||
name=preset.name,
|
||||
description=preset.description,
|
||||
is_active=preset.is_active,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Preset update failed for user %s: %s", user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Check preset data and try again."},
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
|
||||
)
|
||||
|
||||
# Update the webhook as well, if necessary
|
||||
if trigger_inputs_updated:
|
||||
updated = await db.set_preset_webhook(
|
||||
user_id, preset_id, new_webhook.id if new_webhook else None
|
||||
)
|
||||
|
||||
# Clean up webhook if it is now unused
|
||||
if current.webhook_id and (
|
||||
current.webhook_id != (new_webhook.id if new_webhook else None)
|
||||
):
|
||||
current_webhook = await get_webhook(current.webhook_id)
|
||||
credentials = (
|
||||
await credentials_manager.get(user_id, current_webhook.credentials_id)
|
||||
if current_webhook.credentials_id
|
||||
else None
|
||||
)
|
||||
await get_webhook_manager(
|
||||
current_webhook.provider
|
||||
).prune_webhook_if_dangling(user_id, current_webhook.id, credentials)
|
||||
|
||||
return updated
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/presets/{preset_id}",
|
||||
@@ -195,6 +262,28 @@ async def delete_preset(
|
||||
Raises:
|
||||
HTTPException: If an error occurs while deleting the preset.
|
||||
"""
|
||||
preset = await db.get_preset(user_id, preset_id)
|
||||
if not preset:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Preset #{preset_id} not found for user #{user_id}",
|
||||
)
|
||||
|
||||
# Detach and clean up the attached webhook, if any
|
||||
if preset.webhook_id:
|
||||
webhook = await get_webhook(preset.webhook_id)
|
||||
await db.set_preset_webhook(user_id, preset_id, None)
|
||||
|
||||
# Clean up webhook if it is now unused
|
||||
credentials = (
|
||||
await credentials_manager.get(user_id, webhook.credentials_id)
|
||||
if webhook.credentials_id
|
||||
else None
|
||||
)
|
||||
await get_webhook_manager(webhook.provider).prune_webhook_if_dangling(
|
||||
user_id, webhook.id, credentials
|
||||
)
|
||||
|
||||
try:
|
||||
await db.delete_preset(user_id, preset_id)
|
||||
except Exception as e:
|
||||
@@ -203,7 +292,7 @@ async def delete_preset(
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Ensure preset exists before deleting."},
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
|
||||
@@ -214,24 +303,20 @@ async def delete_preset(
|
||||
description="Execute a preset with the given graph and node input for the current user.",
|
||||
)
|
||||
async def execute_preset(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
preset_id: str,
|
||||
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
|
||||
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
|
||||
inputs: dict[str, Any] = Body(..., embed=True, default_factory=dict),
|
||||
) -> dict[str, Any]: # FIXME: add proper return type
|
||||
"""
|
||||
Execute a preset given graph parameters, returning the execution ID on success.
|
||||
|
||||
Args:
|
||||
graph_id (str): ID of the graph to execute.
|
||||
graph_version (int): Version of the graph to execute.
|
||||
preset_id (str): ID of the preset to execute.
|
||||
node_input (Dict[Any, Any]): Input data for the node.
|
||||
user_id (str): ID of the authenticated user.
|
||||
inputs (dict[str, Any]): Optionally, additional input data for the graph execution.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A response containing the execution ID.
|
||||
{id: graph_exec_id}: A response containing the execution ID.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the preset is not found or an error occurs while executing the preset.
|
||||
@@ -241,18 +326,18 @@ async def execute_preset(
|
||||
if not preset:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Preset not found",
|
||||
detail=f"Preset #{preset_id} not found",
|
||||
)
|
||||
|
||||
# Merge input overrides with preset inputs
|
||||
merged_node_input = preset.inputs | node_input
|
||||
merged_node_input = preset.inputs | inputs
|
||||
|
||||
execution = await add_graph_execution(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
inputs=merged_node_input,
|
||||
graph_id=preset.graph_id,
|
||||
graph_version=preset.graph_version,
|
||||
preset_id=preset_id,
|
||||
graph_version=graph_version,
|
||||
inputs=merged_node_input,
|
||||
)
|
||||
|
||||
logger.debug(f"Execution added: {execution} with input: {merged_node_input}")
|
||||
@@ -263,9 +348,6 @@ async def execute_preset(
|
||||
except Exception as e:
|
||||
logger.exception("Preset execution failed for user %s: %s", user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"message": str(e),
|
||||
"hint": "Review preset configuration and graph ID.",
|
||||
},
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
@@ -50,6 +50,8 @@ async def test_get_library_agents_success(
|
||||
creator_name="Test Creator",
|
||||
creator_image_url="",
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
credentials_input_schema={"type": "object", "properties": {}},
|
||||
has_external_trigger=False,
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
new_output=False,
|
||||
can_access_graph=True,
|
||||
@@ -66,6 +68,8 @@ async def test_get_library_agents_success(
|
||||
creator_name="Test Creator",
|
||||
creator_image_url="",
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
credentials_input_schema={"type": "object", "properties": {}},
|
||||
has_external_trigger=False,
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
new_output=False,
|
||||
can_access_graph=False,
|
||||
@@ -117,26 +121,57 @@ def test_get_library_agents_error(mocker: pytest_mock.MockFixture):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Mocker Not implemented")
|
||||
def test_add_agent_to_library_success(mocker: pytest_mock.MockFixture):
|
||||
mock_db_call = mocker.patch("backend.server.v2.library.db.add_agent_to_library")
|
||||
mock_db_call.return_value = None
|
||||
mock_library_agent = library_model.LibraryAgent(
|
||||
id="test-library-agent-id",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
creator_name="Test Creator",
|
||||
creator_image_url="",
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
credentials_input_schema={"type": "object", "properties": {}},
|
||||
has_external_trigger=False,
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
new_output=False,
|
||||
can_access_graph=True,
|
||||
is_latest_version=True,
|
||||
updated_at=FIXED_NOW,
|
||||
)
|
||||
|
||||
response = client.post("/agents/test-version-id")
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.add_store_agent_to_library"
|
||||
)
|
||||
mock_db_call.return_value = mock_library_agent
|
||||
|
||||
response = client.post(
|
||||
"/agents", json={"store_listing_version_id": "test-version-id"}
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
# Verify the response contains the library agent data
|
||||
data = library_model.LibraryAgent.model_validate(response.json())
|
||||
assert data.id == "test-library-agent-id"
|
||||
assert data.graph_id == "test-agent-1"
|
||||
|
||||
mock_db_call.assert_called_once_with(
|
||||
store_listing_version_id="test-version-id", user_id="test-user-id"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Mocker Not implemented")
|
||||
def test_add_agent_to_library_error(mocker: pytest_mock.MockFixture):
|
||||
mock_db_call = mocker.patch("backend.server.v2.library.db.add_agent_to_library")
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.add_store_agent_to_library"
|
||||
)
|
||||
mock_db_call.side_effect = Exception("Test error")
|
||||
|
||||
response = client.post("/agents/test-version-id")
|
||||
response = client.post(
|
||||
"/agents", json={"store_listing_version_id": "test-version-id"}
|
||||
)
|
||||
assert response.status_code == 500
|
||||
assert response.json()["detail"] == "Failed to add agent to library"
|
||||
assert "detail" in response.json() # Verify error response structure
|
||||
mock_db_call.assert_called_once_with(
|
||||
store_listing_version_id="test-version-id", user_id="test-user-id"
|
||||
)
|
||||
|
||||
@@ -259,8 +259,8 @@ def test_ask_otto_unauthenticated(mocker: pytest_mock.MockFixture) -> None:
|
||||
}
|
||||
|
||||
response = client.post("/ask", json=request_data)
|
||||
# When auth is disabled and Otto API URL is not configured, we get 503
|
||||
assert response.status_code == 503
|
||||
# When auth is disabled and Otto API URL is not configured, we get 502 (wrapped from 503)
|
||||
assert response.status_code == 502
|
||||
|
||||
# Restore the override
|
||||
app.dependency_overrides[autogpt_libs.auth.middleware.auth_middleware] = (
|
||||
|
||||
@@ -93,6 +93,14 @@ async def test_get_store_agent_details(mocker):
|
||||
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
|
||||
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
|
||||
|
||||
# Mock Profile prisma call
|
||||
mock_profile = mocker.MagicMock()
|
||||
mock_profile.userId = "user-id-123"
|
||||
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
|
||||
mock_profile_db.return_value.find_first = mocker.AsyncMock(
|
||||
return_value=mock_profile
|
||||
)
|
||||
|
||||
# Mock StoreListing prisma call - this is what was missing
|
||||
mock_store_listing_db = mocker.patch("prisma.models.StoreListing.prisma")
|
||||
mock_store_listing_db.return_value.find_first = mocker.AsyncMock(
|
||||
|
||||
@@ -34,6 +34,20 @@ class StorageUploadError(MediaUploadError):
|
||||
pass
|
||||
|
||||
|
||||
class VirusDetectedError(MediaUploadError):
|
||||
"""Raised when a virus is detected in uploaded file"""
|
||||
|
||||
def __init__(self, threat_name: str, message: str | None = None):
|
||||
self.threat_name = threat_name
|
||||
super().__init__(message or f"Virus detected: {threat_name}")
|
||||
|
||||
|
||||
class VirusScanError(MediaUploadError):
|
||||
"""Raised when virus scanning fails"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class StoreError(Exception):
|
||||
"""Base exception for store-related errors"""
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from google.cloud import storage
|
||||
import backend.server.v2.store.exceptions
|
||||
from backend.util.exceptions import MissingConfigError
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -67,7 +68,7 @@ async def upload_media(
|
||||
# Validate file signature/magic bytes
|
||||
if file.content_type in ALLOWED_IMAGE_TYPES:
|
||||
# Check image file signatures
|
||||
if content.startswith(b"\xFF\xD8\xFF"): # JPEG
|
||||
if content.startswith(b"\xff\xd8\xff"): # JPEG
|
||||
if file.content_type != "image/jpeg":
|
||||
raise backend.server.v2.store.exceptions.InvalidFileTypeError(
|
||||
"File signature does not match content type"
|
||||
@@ -175,6 +176,7 @@ async def upload_media(
|
||||
blob.content_type = content_type
|
||||
|
||||
file_bytes = await file.read()
|
||||
await scan_content_safe(file_bytes, filename=unique_filename)
|
||||
blob.upload_from_string(file_bytes, content_type=content_type)
|
||||
|
||||
public_url = blob.public_url
|
||||
|
||||
@@ -12,6 +12,7 @@ from autogpt_libs.auth.depends import auth_middleware, get_user_id
|
||||
import backend.data.block
|
||||
import backend.data.graph
|
||||
import backend.server.v2.store.db
|
||||
import backend.server.v2.store.exceptions
|
||||
import backend.server.v2.store.image_gen
|
||||
import backend.server.v2.store.media
|
||||
import backend.server.v2.store.model
|
||||
@@ -589,6 +590,25 @@ async def upload_submission_media(
|
||||
user_id=user_id, file=file
|
||||
)
|
||||
return media_url
|
||||
except backend.server.v2.store.exceptions.VirusDetectedError as e:
|
||||
logger.warning(f"Virus detected in uploaded file: {e.threat_name}")
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"detail": f"File rejected due to virus detection: {e.threat_name}",
|
||||
"error_type": "virus_detected",
|
||||
"threat_name": e.threat_name,
|
||||
},
|
||||
)
|
||||
except backend.server.v2.store.exceptions.VirusScanError as e:
|
||||
logger.error(f"Virus scanning failed: {str(e)}")
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=503,
|
||||
content={
|
||||
"detail": "Virus scanning service unavailable. Please try again later.",
|
||||
"error_type": "virus_scan_failed",
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst uploading submission media")
|
||||
return fastapi.responses.JSONResponse(
|
||||
|
||||
@@ -19,7 +19,9 @@ from backend.server.ws_api import (
|
||||
|
||||
@pytest.fixture
|
||||
def mock_websocket() -> AsyncMock:
|
||||
return AsyncMock(spec=WebSocket)
|
||||
mock = AsyncMock(spec=WebSocket)
|
||||
mock.query_params = {} # Add query_params attribute for authentication
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -29,8 +31,13 @@ def mock_manager() -> AsyncMock:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_router_subscribe(
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot, mocker
|
||||
) -> None:
|
||||
# Mock the authenticate_websocket function to ensure it returns a valid user_id
|
||||
mocker.patch(
|
||||
"backend.server.ws_api.authenticate_websocket", return_value=DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
mock_websocket.receive_text.side_effect = [
|
||||
WSMessage(
|
||||
method=WSMethod.SUBSCRIBE_GRAPH_EXEC,
|
||||
@@ -70,8 +77,13 @@ async def test_websocket_router_subscribe(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_router_unsubscribe(
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot, mocker
|
||||
) -> None:
|
||||
# Mock the authenticate_websocket function to ensure it returns a valid user_id
|
||||
mocker.patch(
|
||||
"backend.server.ws_api.authenticate_websocket", return_value=DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
mock_websocket.receive_text.side_effect = [
|
||||
WSMessage(
|
||||
method=WSMethod.UNSUBSCRIBE,
|
||||
@@ -108,8 +120,13 @@ async def test_websocket_router_unsubscribe(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_router_invalid_method(
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock, mocker
|
||||
) -> None:
|
||||
# Mock the authenticate_websocket function to ensure it returns a valid user_id
|
||||
mocker.patch(
|
||||
"backend.server.ws_api.authenticate_websocket", return_value=DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
mock_websocket.receive_text.side_effect = [
|
||||
WSMessage(method=WSMethod.GRAPH_EXECUTION_EVENT).model_dump_json(),
|
||||
WebSocketDisconnect(),
|
||||
@@ -2,7 +2,17 @@ import functools
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Awaitable, Callable, Coroutine, ParamSpec, Tuple, TypeVar
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Literal,
|
||||
ParamSpec,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
overload,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -72,37 +82,115 @@ def async_time_measured(
|
||||
return async_wrapper
|
||||
|
||||
|
||||
def error_logged(func: Callable[P, T]) -> Callable[P, T | None]:
|
||||
@overload
|
||||
def error_logged(
|
||||
*, swallow: Literal[True]
|
||||
) -> Callable[[Callable[P, T]], Callable[P, T | None]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def error_logged(
|
||||
*, swallow: Literal[False]
|
||||
) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def error_logged() -> Callable[[Callable[P, T]], Callable[P, T | None]]: ...
|
||||
|
||||
|
||||
def error_logged(
|
||||
*, swallow: bool = True
|
||||
) -> (
|
||||
Callable[[Callable[P, T]], Callable[P, T | None]]
|
||||
| Callable[[Callable[P, T]], Callable[P, T]]
|
||||
):
|
||||
"""
|
||||
Decorator to suppress and log any exceptions raised by a function.
|
||||
Decorator to log any exceptions raised by a function, with optional suppression.
|
||||
|
||||
Args:
|
||||
swallow: Whether to suppress the exception (True) or re-raise it (False). Default is True.
|
||||
|
||||
Usage:
|
||||
@error_logged() # Default behavior (swallow errors)
|
||||
@error_logged(swallow=False) # Log and re-raise errors
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | None:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error when calling function {func.__name__} with arguments {args} {kwargs}: {e}"
|
||||
)
|
||||
def decorator(f: Callable[P, T]) -> Callable[P, T | None]:
|
||||
@functools.wraps(f)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | None:
|
||||
try:
|
||||
return f(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error when calling function {f.__name__} with arguments {args} {kwargs}: {e}"
|
||||
)
|
||||
if not swallow:
|
||||
raise
|
||||
return None
|
||||
|
||||
return wrapper
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@overload
|
||||
def async_error_logged(
|
||||
func: Callable[P, Coroutine[Any, Any, T]],
|
||||
) -> Callable[P, Coroutine[Any, Any, T | None]]:
|
||||
*, swallow: Literal[True]
|
||||
) -> Callable[
|
||||
[Callable[P, Coroutine[Any, Any, T]]], Callable[P, Coroutine[Any, Any, T | None]]
|
||||
]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def async_error_logged(
|
||||
*, swallow: Literal[False]
|
||||
) -> Callable[
|
||||
[Callable[P, Coroutine[Any, Any, T]]], Callable[P, Coroutine[Any, Any, T]]
|
||||
]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def async_error_logged() -> Callable[
|
||||
[Callable[P, Coroutine[Any, Any, T]]],
|
||||
Callable[P, Coroutine[Any, Any, T | None]],
|
||||
]: ...
|
||||
|
||||
|
||||
def async_error_logged(*, swallow: bool = True) -> (
|
||||
Callable[
|
||||
[Callable[P, Coroutine[Any, Any, T]]],
|
||||
Callable[P, Coroutine[Any, Any, T | None]],
|
||||
]
|
||||
| Callable[
|
||||
[Callable[P, Coroutine[Any, Any, T]]], Callable[P, Coroutine[Any, Any, T]]
|
||||
]
|
||||
):
|
||||
"""
|
||||
Decorator to suppress and log any exceptions raised by an async function.
|
||||
Decorator to log any exceptions raised by an async function, with optional suppression.
|
||||
|
||||
Args:
|
||||
swallow: Whether to suppress the exception (True) or re-raise it (False). Default is True.
|
||||
|
||||
Usage:
|
||||
@async_error_logged() # Default behavior (swallow errors)
|
||||
@async_error_logged(swallow=False) # Log and re-raise errors
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | None:
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error when calling async function {func.__name__} with arguments {args} {kwargs}: {e}"
|
||||
)
|
||||
def decorator(
|
||||
f: Callable[P, Coroutine[Any, Any, T]]
|
||||
) -> Callable[P, Coroutine[Any, Any, T | None]]:
|
||||
@functools.wraps(f)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | None:
|
||||
try:
|
||||
return await f(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error when calling async function {f.__name__} with arguments {args} {kwargs}: {e}"
|
||||
)
|
||||
if not swallow:
|
||||
raise
|
||||
return None
|
||||
|
||||
return wrapper
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
74
autogpt_platform/backend/backend/util/decorator_test.py
Normal file
74
autogpt_platform/backend/backend/util/decorator_test.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util.decorator import async_error_logged, error_logged, time_measured
|
||||
|
||||
|
||||
@time_measured
|
||||
def example_function(a: int, b: int, c: int) -> int:
|
||||
time.sleep(0.5)
|
||||
return a + b + c
|
||||
|
||||
|
||||
@error_logged(swallow=True)
|
||||
def example_function_with_error_swallowed(a: int, b: int, c: int) -> int:
|
||||
raise ValueError("This error should be swallowed")
|
||||
|
||||
|
||||
@error_logged(swallow=False)
|
||||
def example_function_with_error_not_swallowed(a: int, b: int, c: int) -> int:
|
||||
raise ValueError("This error should NOT be swallowed")
|
||||
|
||||
|
||||
@async_error_logged(swallow=True)
|
||||
async def async_function_with_error_swallowed() -> int:
|
||||
raise ValueError("This async error should be swallowed")
|
||||
|
||||
|
||||
@async_error_logged(swallow=False)
|
||||
async def async_function_with_error_not_swallowed() -> int:
|
||||
raise ValueError("This async error should NOT be swallowed")
|
||||
|
||||
|
||||
def test_timer_decorator():
|
||||
"""Test that the time_measured decorator correctly measures execution time."""
|
||||
info, res = example_function(1, 2, 3)
|
||||
assert info.cpu_time >= 0
|
||||
assert info.wall_time >= 0.4
|
||||
assert res == 6
|
||||
|
||||
|
||||
def test_error_decorator_swallow_true():
|
||||
"""Test that error_logged(swallow=True) logs and swallows errors."""
|
||||
res = example_function_with_error_swallowed(1, 2, 3)
|
||||
assert res is None
|
||||
|
||||
|
||||
def test_error_decorator_swallow_false():
|
||||
"""Test that error_logged(swallow=False) logs errors but re-raises them."""
|
||||
with pytest.raises(ValueError, match="This error should NOT be swallowed"):
|
||||
example_function_with_error_not_swallowed(1, 2, 3)
|
||||
|
||||
|
||||
def test_async_error_decorator_swallow_true():
|
||||
"""Test that async_error_logged(swallow=True) logs and swallows errors."""
|
||||
import asyncio
|
||||
|
||||
async def run_test():
|
||||
res = await async_function_with_error_swallowed()
|
||||
return res
|
||||
|
||||
res = asyncio.run(run_test())
|
||||
assert res is None
|
||||
|
||||
|
||||
def test_async_error_decorator_swallow_false():
|
||||
"""Test that async_error_logged(swallow=False) logs errors but re-raises them."""
|
||||
import asyncio
|
||||
|
||||
async def run_test():
|
||||
await async_function_with_error_not_swallowed()
|
||||
|
||||
with pytest.raises(ValueError, match="This async error should NOT be swallowed"):
|
||||
asyncio.run(run_test())
|
||||
@@ -10,6 +10,10 @@ class NeedConfirmation(Exception):
|
||||
"""The user must explicitly confirm that they want to proceed"""
|
||||
|
||||
|
||||
class NotAuthorizedError(ValueError):
|
||||
"""The user is not authorized to perform the requested operation"""
|
||||
|
||||
|
||||
class InsufficientBalanceError(ValueError):
|
||||
user_id: str
|
||||
message: str
|
||||
|
||||
@@ -9,6 +9,7 @@ from urllib.parse import urlparse
|
||||
|
||||
from backend.util.request import Requests
|
||||
from backend.util.type import MediaFileType
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
|
||||
TEMP_DIR = Path(tempfile.gettempdir()).resolve()
|
||||
|
||||
@@ -105,7 +106,11 @@ async def store_media_file(
|
||||
extension = _extension_from_mime(mime_type)
|
||||
filename = f"{uuid.uuid4()}{extension}"
|
||||
target_path = _ensure_inside_base(base_path / filename, base_path)
|
||||
target_path.write_bytes(base64.b64decode(b64_content))
|
||||
content = base64.b64decode(b64_content)
|
||||
|
||||
# Virus scan the base64 content before writing
|
||||
await scan_content_safe(content, filename=filename)
|
||||
target_path.write_bytes(content)
|
||||
|
||||
elif file.startswith(("http://", "https://")):
|
||||
# URL
|
||||
@@ -115,6 +120,9 @@ async def store_media_file(
|
||||
|
||||
# Download and save
|
||||
resp = await Requests().get(file)
|
||||
|
||||
# Virus scan the downloaded content before writing
|
||||
await scan_content_safe(resp.content, filename=filename)
|
||||
target_path.write_bytes(resp.content)
|
||||
|
||||
else:
|
||||
|
||||
@@ -14,8 +14,37 @@ def to_dict(data) -> dict:
|
||||
return jsonable_encoder(data)
|
||||
|
||||
|
||||
def dumps(data) -> str:
|
||||
return json.dumps(to_dict(data))
|
||||
def dumps(data: Any, *args: Any, **kwargs: Any) -> str:
|
||||
"""
|
||||
Serialize data to JSON string with automatic conversion of Pydantic models and complex types.
|
||||
|
||||
This function converts the input data to a JSON-serializable format using FastAPI's
|
||||
jsonable_encoder before dumping to JSON. It handles Pydantic models, complex types,
|
||||
and ensures proper serialization.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : Any
|
||||
The data to serialize. Can be any type including Pydantic models, dicts, lists, etc.
|
||||
*args : Any
|
||||
Additional positional arguments passed to json.dumps()
|
||||
**kwargs : Any
|
||||
Additional keyword arguments passed to json.dumps() (e.g., indent, separators)
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
JSON string representation of the data
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> dumps({"name": "Alice", "age": 30})
|
||||
'{"name": "Alice", "age": 30}'
|
||||
|
||||
>>> dumps(pydantic_model_instance, indent=2)
|
||||
'{\n "field1": "value1",\n "field2": "value2"\n}'
|
||||
"""
|
||||
return json.dumps(to_dict(data), *args, **kwargs)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from logging import Logger
|
||||
import logging
|
||||
|
||||
from backend.util.settings import AppEnvironment, BehaveAs, Settings
|
||||
|
||||
@@ -6,8 +6,6 @@ settings = Settings()
|
||||
|
||||
|
||||
def configure_logging():
|
||||
import logging
|
||||
|
||||
import autogpt_libs.logging.config
|
||||
|
||||
if (
|
||||
@@ -25,7 +23,7 @@ def configure_logging():
|
||||
class TruncatedLogger:
|
||||
def __init__(
|
||||
self,
|
||||
logger: Logger,
|
||||
logger: logging.Logger,
|
||||
prefix: str = "",
|
||||
metadata: dict | None = None,
|
||||
max_length: int = 1000,
|
||||
@@ -65,3 +63,13 @@ class TruncatedLogger:
|
||||
if len(text) > self.max_length:
|
||||
text = text[: self.max_length] + "..."
|
||||
return text
|
||||
|
||||
|
||||
class PrefixFilter(logging.Filter):
|
||||
def __init__(self, prefix: str):
|
||||
super().__init__()
|
||||
self.prefix = prefix
|
||||
|
||||
def filter(self, record):
|
||||
record.msg = f"{self.prefix} {record.msg}"
|
||||
return True
|
||||
|
||||
206
autogpt_platform/backend/backend/util/prompt.py
Normal file
206
autogpt_platform/backend/backend/util/prompt.py
Normal file
@@ -0,0 +1,206 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from tiktoken import encoding_for_model
|
||||
|
||||
from backend.util import json
|
||||
|
||||
# ---------------------------------------------------------------------------#
|
||||
# INTERNAL UTILITIES #
|
||||
# ---------------------------------------------------------------------------#
|
||||
|
||||
|
||||
def _tok_len(text: str, enc) -> int:
|
||||
"""True token length of *text* in tokenizer *enc* (no wrapper cost)."""
|
||||
return len(enc.encode(text))
|
||||
|
||||
|
||||
def _msg_tokens(msg: dict, enc) -> int:
|
||||
"""
|
||||
OpenAI counts ≈3 wrapper tokens per chat message, plus 1 if "name"
|
||||
is present, plus the tokenised content length.
|
||||
"""
|
||||
WRAPPER = 3 + (1 if "name" in msg else 0)
|
||||
return WRAPPER + _tok_len(msg.get("content") or "", enc)
|
||||
|
||||
|
||||
def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
|
||||
"""
|
||||
Return *text* shortened to ≈max_tok tokens by keeping the head & tail
|
||||
and inserting an ellipsis token in the middle.
|
||||
"""
|
||||
ids = enc.encode(text)
|
||||
if len(ids) <= max_tok:
|
||||
return text # nothing to do
|
||||
|
||||
# Split the allowance between the two ends:
|
||||
head = max_tok // 2 - 1 # -1 for the ellipsis
|
||||
tail = max_tok - head - 1
|
||||
mid = enc.encode(" … ")
|
||||
return enc.decode(ids[:head] + mid + ids[-tail:])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------#
|
||||
# PUBLIC API #
|
||||
# ---------------------------------------------------------------------------#
|
||||
|
||||
|
||||
def compress_prompt(
|
||||
messages: list[dict],
|
||||
target_tokens: int,
|
||||
*,
|
||||
model: str = "gpt-4o",
|
||||
reserve: int = 2_048,
|
||||
start_cap: int = 8_192,
|
||||
floor_cap: int = 128,
|
||||
lossy_ok: bool = True,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Shrink *messages* so that::
|
||||
|
||||
token_count(prompt) + reserve ≤ target_tokens
|
||||
|
||||
Strategy
|
||||
--------
|
||||
1. **Token-aware truncation** – progressively halve a per-message cap
|
||||
(`start_cap`, `start_cap/2`, … `floor_cap`) and apply it to the
|
||||
*content* of every message except the first and last. Tool shells
|
||||
are included: we keep the envelope but shorten huge payloads.
|
||||
2. **Middle-out deletion** – if still over the limit, delete whole
|
||||
messages working outward from the centre, **skipping** any message
|
||||
that contains ``tool_calls`` or has ``role == "tool"``.
|
||||
3. **Last-chance trim** – if still too big, truncate the *first* and
|
||||
*last* message bodies down to `floor_cap` tokens.
|
||||
4. If the prompt is *still* too large:
|
||||
• raise ``ValueError`` when ``lossy_ok == False`` (default)
|
||||
• return the partially-trimmed prompt when ``lossy_ok == True``
|
||||
|
||||
Parameters
|
||||
----------
|
||||
messages Complete chat history (will be deep-copied).
|
||||
model Model name; passed to tiktoken to pick the right
|
||||
tokenizer (gpt-4o → 'o200k_base', others fallback).
|
||||
target_tokens Hard ceiling for prompt size **excluding** the model's
|
||||
forthcoming answer.
|
||||
reserve How many tokens you want to leave available for that
|
||||
answer (`max_tokens` in your subsequent completion call).
|
||||
start_cap Initial per-message truncation ceiling (tokens).
|
||||
floor_cap Lowest cap we'll accept before moving to deletions.
|
||||
lossy_ok If *True* return best-effort prompt instead of raising
|
||||
after all trim passes have been exhausted.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[dict] – A *new* messages list that abides by the rules above.
|
||||
"""
|
||||
enc = encoding_for_model(model) # best-match tokenizer
|
||||
msgs = deepcopy(messages) # never mutate caller
|
||||
|
||||
def total_tokens() -> int:
|
||||
"""Current size of *msgs* in tokens."""
|
||||
return sum(_msg_tokens(m, enc) for m in msgs)
|
||||
|
||||
original_token_count = total_tokens()
|
||||
if original_token_count + reserve <= target_tokens:
|
||||
return msgs
|
||||
|
||||
# ---- STEP 0 : normalise content --------------------------------------
|
||||
# Convert non-string payloads to strings so token counting is coherent.
|
||||
for m in msgs[1:-1]: # keep the first & last intact
|
||||
if not isinstance(m.get("content"), str) and m.get("content") is not None:
|
||||
# Reasonable 20k-char ceiling prevents pathological blobs
|
||||
content_str = json.dumps(m["content"], separators=(",", ":"))
|
||||
if len(content_str) > 20_000:
|
||||
content_str = _truncate_middle_tokens(content_str, enc, 20_000)
|
||||
m["content"] = content_str
|
||||
|
||||
# ---- STEP 1 : token-aware truncation ---------------------------------
|
||||
cap = start_cap
|
||||
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
|
||||
for m in msgs[1:-1]: # keep first & last intact
|
||||
if _tok_len(m.get("content") or "", enc) > cap:
|
||||
m["content"] = _truncate_middle_tokens(m["content"], enc, cap)
|
||||
cap //= 2 # tighten the screw
|
||||
|
||||
# ---- STEP 2 : middle-out deletion -----------------------------------
|
||||
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
|
||||
centre = len(msgs) // 2
|
||||
# Build a symmetrical centre-out index walk: centre, centre+1, centre-1, ...
|
||||
order = [centre] + [
|
||||
i
|
||||
for pair in zip(range(centre + 1, len(msgs) - 1), range(centre - 1, 0, -1))
|
||||
for i in pair
|
||||
]
|
||||
removed = False
|
||||
for i in order:
|
||||
msg = msgs[i]
|
||||
if "tool_calls" in msg or msg.get("role") == "tool":
|
||||
continue # protect tool shells
|
||||
del msgs[i]
|
||||
removed = True
|
||||
break
|
||||
if not removed: # nothing more we can drop
|
||||
break
|
||||
|
||||
# ---- STEP 3 : final safety-net trim on first & last ------------------
|
||||
cap = start_cap
|
||||
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
|
||||
for idx in (0, -1): # first and last
|
||||
text = msgs[idx].get("content") or ""
|
||||
if _tok_len(text, enc) > cap:
|
||||
msgs[idx]["content"] = _truncate_middle_tokens(text, enc, cap)
|
||||
cap //= 2 # tighten the screw
|
||||
|
||||
# ---- STEP 4 : success or fail-gracefully -----------------------------
|
||||
if total_tokens() + reserve > target_tokens and not lossy_ok:
|
||||
raise ValueError(
|
||||
"compress_prompt: prompt still exceeds budget "
|
||||
f"({total_tokens() + reserve} > {target_tokens})."
|
||||
)
|
||||
|
||||
return msgs
|
||||
|
||||
|
||||
def estimate_token_count(
|
||||
messages: list[dict],
|
||||
*,
|
||||
model: str = "gpt-4o",
|
||||
) -> int:
|
||||
"""
|
||||
Return the true token count of *messages* when encoded for *model*.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
messages Complete chat history.
|
||||
model Model name; passed to tiktoken to pick the right
|
||||
tokenizer (gpt-4o → 'o200k_base', others fallback).
|
||||
|
||||
Returns
|
||||
-------
|
||||
int – Token count.
|
||||
"""
|
||||
enc = encoding_for_model(model) # best-match tokenizer
|
||||
return sum(_msg_tokens(m, enc) for m in messages)
|
||||
|
||||
|
||||
def estimate_token_count_str(
|
||||
text: Any,
|
||||
*,
|
||||
model: str = "gpt-4o",
|
||||
) -> int:
|
||||
"""
|
||||
Return the true token count of *text* when encoded for *model*.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text Input text.
|
||||
model Model name; passed to tiktoken to pick the right
|
||||
tokenizer (gpt-4o → 'o200k_base', others fallback).
|
||||
|
||||
Returns
|
||||
-------
|
||||
int – Token count.
|
||||
"""
|
||||
enc = encoding_for_model(model) # best-match tokenizer
|
||||
text = json.dumps(text) if not isinstance(text, str) else text
|
||||
return _tok_len(text, enc)
|
||||
@@ -430,7 +430,13 @@ class Requests:
|
||||
) as response:
|
||||
|
||||
if self.raise_for_status:
|
||||
response.raise_for_status()
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except ClientResponseError as e:
|
||||
body = await response.read()
|
||||
raise Exception(
|
||||
f"HTTP {response.status} Error: {response.reason}, Body: {body.decode(errors='replace')}"
|
||||
) from e
|
||||
|
||||
# If allowed and a redirect is received, follow the redirect manually
|
||||
if allow_redirects and response.status in (301, 302, 303, 307, 308):
|
||||
|
||||
@@ -31,7 +31,7 @@ from tenacity import (
|
||||
wait_exponential_jitter,
|
||||
)
|
||||
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
import backend.util.exceptions as exceptions
|
||||
from backend.util.json import to_dict
|
||||
from backend.util.metrics import sentry_init
|
||||
from backend.util.process import AppProcess, get_service_name
|
||||
@@ -106,7 +106,13 @@ EXCEPTION_MAPPING = {
|
||||
ValueError,
|
||||
TimeoutError,
|
||||
ConnectionError,
|
||||
InsufficientBalanceError,
|
||||
*[
|
||||
ErrorType
|
||||
for _, ErrorType in inspect.getmembers(exceptions)
|
||||
if inspect.isclass(ErrorType)
|
||||
and issubclass(ErrorType, Exception)
|
||||
and ErrorType.__module__ == exceptions.__name__
|
||||
],
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ class ServiceTestClient(AppServiceClient):
|
||||
subtract_async = endpoint_to_async(ServiceTest.subtract)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_creation(server):
|
||||
with ServiceTest():
|
||||
client = get_service_client(ServiceTestClient)
|
||||
@@ -238,6 +238,31 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="The Discord channel for the platform",
|
||||
)
|
||||
|
||||
clamav_service_host: str = Field(
|
||||
default="localhost",
|
||||
description="The host for the ClamAV daemon",
|
||||
)
|
||||
clamav_service_port: int = Field(
|
||||
default=3310,
|
||||
description="The port for the ClamAV daemon",
|
||||
)
|
||||
clamav_service_timeout: int = Field(
|
||||
default=60,
|
||||
description="The timeout in seconds for the ClamAV daemon",
|
||||
)
|
||||
clamav_service_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Whether virus scanning is enabled or not",
|
||||
)
|
||||
clamav_max_concurrency: int = Field(
|
||||
default=10,
|
||||
description="The maximum number of concurrent scans to perform",
|
||||
)
|
||||
clamav_mark_failed_scans_as_clean: bool = Field(
|
||||
default=False,
|
||||
description="Whether to mark failed scans as clean or not",
|
||||
)
|
||||
|
||||
@field_validator("platform_base_url", "frontend_base_url")
|
||||
@classmethod
|
||||
def validate_platform_base_url(cls, v: str, info: ValidationInfo) -> str:
|
||||
|
||||
@@ -152,9 +152,16 @@ async def main():
|
||||
print(f"Inserting {NUM_USERS * MAX_AGENTS_PER_USER} user agents")
|
||||
for user in users:
|
||||
num_agents = random.randint(MIN_AGENTS_PER_USER, MAX_AGENTS_PER_USER)
|
||||
for _ in range(num_agents): # Create 1 LibraryAgent per user
|
||||
graph = random.choice(agent_graphs)
|
||||
preset = random.choice(agent_presets)
|
||||
|
||||
# Get a shuffled list of graphs to ensure uniqueness per user
|
||||
available_graphs = agent_graphs.copy()
|
||||
random.shuffle(available_graphs)
|
||||
|
||||
# Limit to available unique graphs
|
||||
num_agents = min(num_agents, len(available_graphs))
|
||||
|
||||
for i in range(num_agents):
|
||||
graph = available_graphs[i] # Use unique graph for each library agent
|
||||
user_agent = await db.libraryagent.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
@@ -180,7 +187,7 @@ async def main():
|
||||
MIN_EXECUTIONS_PER_GRAPH, MAX_EXECUTIONS_PER_GRAPH
|
||||
)
|
||||
for _ in range(num_executions):
|
||||
matching_presets = [p for p in agent_presets if p.agentId == graph.id]
|
||||
matching_presets = [p for p in agent_presets if p.agentGraphId == graph.id]
|
||||
preset = (
|
||||
random.choice(matching_presets)
|
||||
if matching_presets and random.random() < 0.5
|
||||
@@ -355,7 +362,7 @@ async def main():
|
||||
store_listing_versions = []
|
||||
print(f"Inserting {NUM_USERS} store listing versions")
|
||||
for listing in store_listings:
|
||||
graph = [g for g in agent_graphs if g.id == listing.agentId][0]
|
||||
graph = [g for g in agent_graphs if g.id == listing.agentGraphId][0]
|
||||
version = await db.storelistingversion.create(
|
||||
data={
|
||||
"agentGraphId": graph.id,
|
||||
209
autogpt_platform/backend/backend/util/virus_scanner.py
Normal file
209
autogpt_platform/backend/backend/util/virus_scanner.py
Normal file
@@ -0,0 +1,209 @@
|
||||
import asyncio
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import aioclamd
|
||||
from pydantic import BaseModel
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
|
||||
class VirusScanResult(BaseModel):
|
||||
is_clean: bool
|
||||
scan_time_ms: int
|
||||
file_size: int
|
||||
threat_name: Optional[str] = None
|
||||
|
||||
|
||||
class VirusScannerSettings(BaseSettings):
|
||||
# Tunables for the scanner layer (NOT the ClamAV daemon).
|
||||
clamav_service_host: str = "localhost"
|
||||
clamav_service_port: int = 3310
|
||||
clamav_service_timeout: int = 60
|
||||
clamav_service_enabled: bool = True
|
||||
# If the service is disabled, all files are considered clean.
|
||||
mark_failed_scans_as_clean: bool = False
|
||||
# Client-side protective limits
|
||||
max_scan_size: int = 2 * 1024 * 1024 * 1024 # 2 GB guard-rail in memory
|
||||
min_chunk_size: int = 128 * 1024 # 128 KB hard floor
|
||||
max_retries: int = 8 # halve ≤ max_retries times
|
||||
# Concurrency throttle toward the ClamAV daemon. Do *NOT* simply turn this
|
||||
# up to the number of CPU cores; keep it ≤ (MaxThreads / pods) – 1.
|
||||
max_concurrency: int = 5
|
||||
|
||||
|
||||
class VirusScannerService:
|
||||
"""Fully-async ClamAV wrapper using **aioclamd**.
|
||||
|
||||
• Reuses a single `ClamdAsyncClient` connection (aioclamd keeps the socket open).
|
||||
• Throttles concurrent `INSTREAM` calls with an `asyncio.Semaphore` so we don't exhaust daemon worker threads or file descriptors.
|
||||
• Falls back to progressively smaller chunk sizes when the daemon rejects a stream as too large.
|
||||
"""
|
||||
|
||||
def __init__(self, settings: VirusScannerSettings) -> None:
|
||||
self.settings = settings
|
||||
self._client = aioclamd.ClamdAsyncClient(
|
||||
host=settings.clamav_service_host,
|
||||
port=settings.clamav_service_port,
|
||||
timeout=settings.clamav_service_timeout,
|
||||
)
|
||||
self._sem = asyncio.Semaphore(settings.max_concurrency)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
@staticmethod
|
||||
def _parse_raw(raw: Optional[dict]) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
Convert aioclamd output to (infected?, threat_name).
|
||||
Returns (False, None) for clean.
|
||||
"""
|
||||
if not raw:
|
||||
return False, None
|
||||
status, threat = next(iter(raw.values()))
|
||||
return status == "FOUND", threat
|
||||
|
||||
async def _instream(self, chunk: bytes) -> Tuple[bool, Optional[str]]:
|
||||
"""Scan **one** chunk with concurrency control."""
|
||||
async with self._sem:
|
||||
try:
|
||||
raw = await self._client.instream(io.BytesIO(chunk))
|
||||
return self._parse_raw(raw)
|
||||
except (BrokenPipeError, ConnectionResetError) as exc:
|
||||
raise RuntimeError("size-limit") from exc
|
||||
except Exception as exc:
|
||||
if "INSTREAM size limit exceeded" in str(exc):
|
||||
raise RuntimeError("size-limit") from exc
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Public API
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
async def scan_file(
|
||||
self, content: bytes, *, filename: str = "unknown"
|
||||
) -> VirusScanResult:
|
||||
"""
|
||||
Scan `content`. Returns a result object or raises on infrastructure
|
||||
failure (unreachable daemon, etc.).
|
||||
The algorithm always tries whole-file first. If the daemon refuses
|
||||
on size grounds, it falls back to chunked parallel scanning.
|
||||
"""
|
||||
if not self.settings.clamav_service_enabled:
|
||||
logger.warning(f"Virus scanning disabled – accepting {filename}")
|
||||
return VirusScanResult(
|
||||
is_clean=True, scan_time_ms=0, file_size=len(content)
|
||||
)
|
||||
if len(content) > self.settings.max_scan_size:
|
||||
logger.warning(
|
||||
f"File {filename} ({len(content)} bytes) exceeds client max scan size ({self.settings.max_scan_size}); Stopping virus scan"
|
||||
)
|
||||
return VirusScanResult(
|
||||
is_clean=self.settings.mark_failed_scans_as_clean,
|
||||
file_size=len(content),
|
||||
scan_time_ms=0,
|
||||
)
|
||||
|
||||
# Ensure daemon is reachable (small RTT check)
|
||||
if not await self._client.ping():
|
||||
raise RuntimeError("ClamAV service is unreachable")
|
||||
|
||||
start = time.monotonic()
|
||||
chunk_size = len(content) # Start with full content length
|
||||
for retry in range(self.settings.max_retries):
|
||||
# For small files, don't check min_chunk_size limit
|
||||
if chunk_size < self.settings.min_chunk_size and chunk_size < len(content):
|
||||
break
|
||||
logger.debug(
|
||||
f"Scanning {filename} with chunk size: {chunk_size // 1_048_576} MB (retry {retry + 1}/{self.settings.max_retries})"
|
||||
)
|
||||
try:
|
||||
tasks = [
|
||||
asyncio.create_task(self._instream(content[o : o + chunk_size]))
|
||||
for o in range(0, len(content), chunk_size)
|
||||
]
|
||||
for coro in asyncio.as_completed(tasks):
|
||||
infected, threat = await coro
|
||||
if infected:
|
||||
for t in tasks:
|
||||
if not t.done():
|
||||
t.cancel()
|
||||
return VirusScanResult(
|
||||
is_clean=False,
|
||||
threat_name=threat,
|
||||
file_size=len(content),
|
||||
scan_time_ms=int((time.monotonic() - start) * 1000),
|
||||
)
|
||||
# All chunks clean
|
||||
return VirusScanResult(
|
||||
is_clean=True,
|
||||
file_size=len(content),
|
||||
scan_time_ms=int((time.monotonic() - start) * 1000),
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
if str(exc) == "size-limit":
|
||||
chunk_size //= 2
|
||||
continue
|
||||
logger.error(f"Cannot scan {filename}: {exc}")
|
||||
raise
|
||||
# Phase 3 – give up but warn
|
||||
logger.warning(
|
||||
f"Unable to virus scan {filename} ({len(content)} bytes) even with minimum chunk size ({self.settings.min_chunk_size} bytes). Recommend manual review."
|
||||
)
|
||||
return VirusScanResult(
|
||||
is_clean=self.settings.mark_failed_scans_as_clean,
|
||||
file_size=len(content),
|
||||
scan_time_ms=int((time.monotonic() - start) * 1000),
|
||||
)
|
||||
|
||||
|
||||
_scanner: Optional[VirusScannerService] = None
|
||||
|
||||
|
||||
def get_virus_scanner() -> VirusScannerService:
|
||||
global _scanner
|
||||
if _scanner is None:
|
||||
_settings = VirusScannerSettings(
|
||||
clamav_service_host=settings.config.clamav_service_host,
|
||||
clamav_service_port=settings.config.clamav_service_port,
|
||||
clamav_service_enabled=settings.config.clamav_service_enabled,
|
||||
max_concurrency=settings.config.clamav_max_concurrency,
|
||||
mark_failed_scans_as_clean=settings.config.clamav_mark_failed_scans_as_clean,
|
||||
)
|
||||
_scanner = VirusScannerService(_settings)
|
||||
return _scanner
|
||||
|
||||
|
||||
async def scan_content_safe(content: bytes, *, filename: str = "unknown") -> None:
|
||||
"""
|
||||
Helper function to scan content and raise appropriate exceptions.
|
||||
|
||||
Raises:
|
||||
VirusDetectedError: If virus is found
|
||||
VirusScanError: If scanning fails
|
||||
"""
|
||||
from backend.server.v2.store.exceptions import VirusDetectedError, VirusScanError
|
||||
|
||||
try:
|
||||
result = await get_virus_scanner().scan_file(content, filename=filename)
|
||||
if not result.is_clean:
|
||||
threat_name = result.threat_name or "Unknown threat"
|
||||
logger.warning(f"Virus detected in file {filename}: {threat_name}")
|
||||
raise VirusDetectedError(
|
||||
threat_name, f"File rejected due to virus detection: {threat_name}"
|
||||
)
|
||||
|
||||
logger.info(f"File {filename} passed virus scan in {result.scan_time_ms}ms")
|
||||
|
||||
except VirusDetectedError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Virus scanning failed for {filename}: {str(e)}")
|
||||
raise VirusScanError(f"Virus scanning failed: {str(e)}") from e
|
||||
253
autogpt_platform/backend/backend/util/virus_scanner_test.py
Normal file
253
autogpt_platform/backend/backend/util/virus_scanner_test.py
Normal file
@@ -0,0 +1,253 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.store.exceptions import VirusDetectedError, VirusScanError
|
||||
from backend.util.virus_scanner import (
|
||||
VirusScannerService,
|
||||
VirusScannerSettings,
|
||||
VirusScanResult,
|
||||
get_virus_scanner,
|
||||
scan_content_safe,
|
||||
)
|
||||
|
||||
|
||||
class TestVirusScannerService:
|
||||
@pytest.fixture
|
||||
def scanner_settings(self):
|
||||
return VirusScannerSettings(
|
||||
clamav_service_host="localhost",
|
||||
clamav_service_port=3310,
|
||||
clamav_service_enabled=True,
|
||||
max_scan_size=10 * 1024 * 1024, # 10MB for testing
|
||||
mark_failed_scans_as_clean=False, # For testing, failed scans should be clean
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def scanner(self, scanner_settings):
|
||||
return VirusScannerService(scanner_settings)
|
||||
|
||||
@pytest.fixture
|
||||
def disabled_scanner(self):
|
||||
settings = VirusScannerSettings(clamav_service_enabled=False)
|
||||
return VirusScannerService(settings)
|
||||
|
||||
def test_scanner_initialization(self, scanner_settings):
|
||||
scanner = VirusScannerService(scanner_settings)
|
||||
assert scanner.settings.clamav_service_host == "localhost"
|
||||
assert scanner.settings.clamav_service_port == 3310
|
||||
assert scanner.settings.clamav_service_enabled is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_disabled_returns_clean(self, disabled_scanner):
|
||||
content = b"test file content"
|
||||
result = await disabled_scanner.scan_file(content, filename="test.txt")
|
||||
|
||||
assert result.is_clean is True
|
||||
assert result.threat_name is None
|
||||
assert result.file_size == len(content)
|
||||
assert result.scan_time_ms == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_file_too_large(self, scanner):
|
||||
# Create content larger than max_scan_size
|
||||
large_content = b"x" * (scanner.settings.max_scan_size + 1)
|
||||
|
||||
# Large files behavior depends on mark_failed_scans_as_clean setting
|
||||
result = await scanner.scan_file(large_content, filename="large_file.txt")
|
||||
assert result.is_clean == scanner.settings.mark_failed_scans_as_clean
|
||||
assert result.file_size == len(large_content)
|
||||
assert result.scan_time_ms == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_file_too_large_both_configurations(self):
|
||||
"""Test large file handling with both mark_failed_scans_as_clean configurations"""
|
||||
large_content = b"x" * (10 * 1024 * 1024 + 1) # Larger than 10MB
|
||||
|
||||
# Test with mark_failed_scans_as_clean=True
|
||||
settings_clean = VirusScannerSettings(
|
||||
max_scan_size=10 * 1024 * 1024, mark_failed_scans_as_clean=True
|
||||
)
|
||||
scanner_clean = VirusScannerService(settings_clean)
|
||||
result_clean = await scanner_clean.scan_file(
|
||||
large_content, filename="large_file.txt"
|
||||
)
|
||||
assert result_clean.is_clean is True
|
||||
|
||||
# Test with mark_failed_scans_as_clean=False
|
||||
settings_dirty = VirusScannerSettings(
|
||||
max_scan_size=10 * 1024 * 1024, mark_failed_scans_as_clean=False
|
||||
)
|
||||
scanner_dirty = VirusScannerService(settings_dirty)
|
||||
result_dirty = await scanner_dirty.scan_file(
|
||||
large_content, filename="large_file.txt"
|
||||
)
|
||||
assert result_dirty.is_clean is False
|
||||
|
||||
# Note: ping method was removed from current implementation
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_clean_file(self, scanner):
|
||||
async def mock_instream(_):
|
||||
await asyncio.sleep(0.001) # Small delay to ensure timing > 0
|
||||
return None # No virus detected
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.ping = AsyncMock(return_value=True)
|
||||
mock_client.instream = AsyncMock(side_effect=mock_instream)
|
||||
|
||||
# Replace the client instance that was created in the constructor
|
||||
scanner._client = mock_client
|
||||
|
||||
content = b"clean file content"
|
||||
result = await scanner.scan_file(content, filename="clean.txt")
|
||||
|
||||
assert result.is_clean is True
|
||||
assert result.threat_name is None
|
||||
assert result.file_size == len(content)
|
||||
assert result.scan_time_ms > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_infected_file(self, scanner):
|
||||
async def mock_instream(_):
|
||||
await asyncio.sleep(0.001) # Small delay to ensure timing > 0
|
||||
return {"stream": ("FOUND", "Win.Test.EICAR_HDB-1")}
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.ping = AsyncMock(return_value=True)
|
||||
mock_client.instream = AsyncMock(side_effect=mock_instream)
|
||||
|
||||
# Replace the client instance that was created in the constructor
|
||||
scanner._client = mock_client
|
||||
|
||||
content = b"infected file content"
|
||||
result = await scanner.scan_file(content, filename="infected.txt")
|
||||
|
||||
assert result.is_clean is False
|
||||
assert result.threat_name == "Win.Test.EICAR_HDB-1"
|
||||
assert result.file_size == len(content)
|
||||
assert result.scan_time_ms > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_clamav_unavailable_fail_safe(self, scanner):
|
||||
mock_client = Mock()
|
||||
mock_client.ping = AsyncMock(return_value=False)
|
||||
|
||||
# Replace the client instance that was created in the constructor
|
||||
scanner._client = mock_client
|
||||
|
||||
content = b"test content"
|
||||
|
||||
with pytest.raises(RuntimeError, match="ClamAV service is unreachable"):
|
||||
await scanner.scan_file(content, filename="test.txt")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_error_fail_safe(self, scanner):
|
||||
mock_client = Mock()
|
||||
mock_client.ping = AsyncMock(return_value=True)
|
||||
mock_client.instream = AsyncMock(side_effect=Exception("Scanning error"))
|
||||
|
||||
# Replace the client instance that was created in the constructor
|
||||
scanner._client = mock_client
|
||||
|
||||
content = b"test content"
|
||||
|
||||
with pytest.raises(Exception, match="Scanning error"):
|
||||
await scanner.scan_file(content, filename="test.txt")
|
||||
|
||||
# Note: scan_file_method and scan_upload_file tests removed as these APIs don't exist in current implementation
|
||||
|
||||
def test_get_virus_scanner_singleton(self):
|
||||
scanner1 = get_virus_scanner()
|
||||
scanner2 = get_virus_scanner()
|
||||
|
||||
# Should return the same instance
|
||||
assert scanner1 is scanner2
|
||||
|
||||
# Note: client_reuse test removed as _get_client method doesn't exist in current implementation
|
||||
|
||||
def test_scan_result_model(self):
|
||||
# Test VirusScanResult model
|
||||
result = VirusScanResult(
|
||||
is_clean=False, threat_name="Test.Virus", scan_time_ms=150, file_size=1024
|
||||
)
|
||||
|
||||
assert result.is_clean is False
|
||||
assert result.threat_name == "Test.Virus"
|
||||
assert result.scan_time_ms == 150
|
||||
assert result.file_size == 1024
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_scans(self, scanner):
|
||||
async def mock_instream(_):
|
||||
await asyncio.sleep(0.001) # Small delay to ensure timing > 0
|
||||
return None
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.ping = AsyncMock(return_value=True)
|
||||
mock_client.instream = AsyncMock(side_effect=mock_instream)
|
||||
|
||||
# Replace the client instance that was created in the constructor
|
||||
scanner._client = mock_client
|
||||
|
||||
content1 = b"file1 content"
|
||||
content2 = b"file2 content"
|
||||
|
||||
# Run concurrent scans
|
||||
results = await asyncio.gather(
|
||||
scanner.scan_file(content1, filename="file1.txt"),
|
||||
scanner.scan_file(content2, filename="file2.txt"),
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert all(result.is_clean for result in results)
|
||||
assert results[0].file_size == len(content1)
|
||||
assert results[1].file_size == len(content2)
|
||||
assert all(result.scan_time_ms > 0 for result in results)
|
||||
|
||||
|
||||
class TestHelperFunctions:
|
||||
"""Test the helper functions scan_content_safe"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_content_safe_clean(self):
|
||||
"""Test scan_content_safe with clean content"""
|
||||
with patch("backend.util.virus_scanner.get_virus_scanner") as mock_get_scanner:
|
||||
mock_scanner = Mock()
|
||||
mock_scanner.scan_file = AsyncMock()
|
||||
mock_scanner.scan_file.return_value = Mock(
|
||||
is_clean=True, threat_name=None, scan_time_ms=50, file_size=100
|
||||
)
|
||||
mock_get_scanner.return_value = mock_scanner
|
||||
|
||||
# Should not raise any exception
|
||||
await scan_content_safe(b"clean content", filename="test.txt")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_content_safe_infected(self):
|
||||
"""Test scan_content_safe with infected content"""
|
||||
with patch("backend.util.virus_scanner.get_virus_scanner") as mock_get_scanner:
|
||||
mock_scanner = Mock()
|
||||
mock_scanner.scan_file = AsyncMock()
|
||||
mock_scanner.scan_file.return_value = Mock(
|
||||
is_clean=False, threat_name="Test.Virus", scan_time_ms=50, file_size=100
|
||||
)
|
||||
mock_get_scanner.return_value = mock_scanner
|
||||
|
||||
with pytest.raises(VirusDetectedError) as exc_info:
|
||||
await scan_content_safe(b"infected content", filename="virus.txt")
|
||||
|
||||
assert exc_info.value.threat_name == "Test.Virus"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_content_safe_scan_error(self):
|
||||
"""Test scan_content_safe when scanning fails"""
|
||||
with patch("backend.util.virus_scanner.get_virus_scanner") as mock_get_scanner:
|
||||
mock_scanner = Mock()
|
||||
mock_scanner.scan_file = AsyncMock()
|
||||
mock_scanner.scan_file.side_effect = Exception("Scan failed")
|
||||
mock_get_scanner.return_value = mock_scanner
|
||||
|
||||
with pytest.raises(VirusScanError, match="Virus scanning failed"):
|
||||
await scan_content_safe(b"test content", filename="test.txt")
|
||||
@@ -0,0 +1,5 @@
|
||||
-- Add webhookId column
|
||||
ALTER TABLE "AgentPreset" ADD COLUMN "webhookId" TEXT;
|
||||
|
||||
-- Add AgentPreset<->IntegrationWebhook relation
|
||||
ALTER TABLE "AgentPreset" ADD CONSTRAINT "AgentPreset_webhookId_fkey" FOREIGN KEY ("webhookId") REFERENCES "IntegrationWebhook"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
35
autogpt_platform/backend/poetry.lock
generated
35
autogpt_platform/backend/poetry.lock
generated
@@ -17,6 +17,18 @@ aiormq = ">=6.8,<6.9"
|
||||
exceptiongroup = ">=1,<2"
|
||||
yarl = "*"
|
||||
|
||||
[[package]]
|
||||
name = "aioclamd"
|
||||
version = "1.0.0"
|
||||
description = "Asynchronous client for virus scanning with ClamAV"
|
||||
optional = false
|
||||
python-versions = ">=3.7,<4.0"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "aioclamd-1.0.0-py3-none-any.whl", hash = "sha256:4727da3953a4b38be4c2de1acb6b3bb3c94c1c171dcac780b80234ee6253f3d9"},
|
||||
{file = "aioclamd-1.0.0.tar.gz", hash = "sha256:7b14e94e3a2285cc89e2f4d434e2a01f322d3cb95476ce2dda015a7980876047"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aiodns"
|
||||
version = "3.4.0"
|
||||
@@ -5006,6 +5018,27 @@ statsig = ["statsig (>=0.55.3)"]
|
||||
tornado = ["tornado (>=6)"]
|
||||
unleash = ["UnleashClient (>=6.0.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "setuptools"
|
||||
version = "80.9.0"
|
||||
description = "Easily download, build, install, upgrade, and uninstall Python packages"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922"},
|
||||
{file = "setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""]
|
||||
core = ["importlib_metadata (>=6) ; python_version < \"3.10\"", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1) ; python_version < \"3.11\"", "wheel (>=0.43.0)"]
|
||||
cover = ["pytest-cov"]
|
||||
doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
|
||||
enabler = ["pytest-enabler (>=2.2)"]
|
||||
test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21) ; python_version >= \"3.9\" and sys_platform != \"cygwin\"", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf ; sys_platform != \"cygwin\"", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
|
||||
type = ["importlib_metadata (>=7.0.2) ; python_version < \"3.10\"", "jaraco.develop (>=7.21) ; sys_platform != \"cygwin\"", "mypy (==1.14.*)", "pytest-mypy"]
|
||||
|
||||
[[package]]
|
||||
name = "sgmllib3k"
|
||||
version = "1.0.0"
|
||||
@@ -6369,4 +6402,4 @@ cffi = ["cffi (>=1.11)"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.13"
|
||||
content-hash = "6c93e51cf22c06548aa6d0e23ca8ceb4450f5e02d4142715e941aabc1a2cbd6a"
|
||||
content-hash = "b5c1201f27ee8d05d5d8c89702123df4293f124301d1aef7451591a351872260"
|
||||
|
||||
@@ -68,6 +68,9 @@ zerobouncesdk = "^1.1.1"
|
||||
# NOTE: please insert new dependencies in their alphabetical location
|
||||
pytest-snapshot = "^0.9.0"
|
||||
aiofiles = "^24.1.0"
|
||||
tiktoken = "^0.9.0"
|
||||
aioclamd = "^1.0.0"
|
||||
setuptools = "^80.9.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
aiohappyeyeballs = "^2.6.1"
|
||||
@@ -112,6 +115,11 @@ ignore_patterns = []
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "session"
|
||||
filterwarnings = [
|
||||
"ignore:'audioop' is deprecated:DeprecationWarning:discord.player",
|
||||
"ignore:invalid escape sequence:DeprecationWarning:tweepy.api",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py310"
|
||||
|
||||
@@ -169,6 +169,10 @@ model AgentPreset {
|
||||
InputPresets AgentNodeExecutionInputOutput[] @relation("AgentPresetsInputData")
|
||||
Executions AgentGraphExecution[]
|
||||
|
||||
// For webhook-triggered agents: reference to the webhook that triggers the agent
|
||||
webhookId String?
|
||||
Webhook IntegrationWebhook? @relation(fields: [webhookId], references: [id])
|
||||
|
||||
isDeleted Boolean @default(false)
|
||||
|
||||
@@index([userId])
|
||||
@@ -428,7 +432,8 @@ model IntegrationWebhook {
|
||||
|
||||
providerWebhookId String // Webhook ID assigned by the provider
|
||||
|
||||
AgentNodes AgentNode[]
|
||||
AgentNodes AgentNode[]
|
||||
AgentPresets AgentPreset[]
|
||||
|
||||
@@index([userId])
|
||||
}
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"metric_id": "metric-123-uuid"
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"metric_id": "metric-float_precision-uuid",
|
||||
"test_case": "float_precision"
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"metric_id": "metric-integer_value-uuid",
|
||||
"test_case": "integer_value"
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"metric_id": "metric-large_number-uuid",
|
||||
"test_case": "large_number"
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"metric_id": "metric-negative_value-uuid",
|
||||
"test_case": "negative_value"
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user