mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Compare commits
87 Commits
feat/eleve
...
swiftyos/i
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f850ba033e | ||
|
|
13d7b53991 | ||
|
|
ab4cf9d557 | ||
|
|
9af79f750e | ||
|
|
c3c1ac9845 | ||
|
|
a225b8ab72 | ||
|
|
c432d14db9 | ||
|
|
6435cd340c | ||
|
|
a2f3c322dc | ||
|
|
38c167ff87 | ||
|
|
31ae7e2838 | ||
|
|
1885f88a6f | ||
|
|
c5aa147fd1 | ||
|
|
7790672d9f | ||
|
|
a633c440a9 | ||
|
|
dc9a2f84e7 | ||
|
|
e3115dbe08 | ||
|
|
126498b8d0 | ||
|
|
c5dec20e0c | ||
|
|
922150c7fa | ||
|
|
3aa04d4b96 | ||
|
|
03ca3f9179 | ||
|
|
f9e0b08e19 | ||
|
|
8882768bbf | ||
|
|
249249bdcc | ||
|
|
163713df1a | ||
|
|
ee91540b1a | ||
|
|
a7503ac716 | ||
|
|
df2ef41213 | ||
|
|
a0da6dd09f | ||
|
|
ec73331c79 | ||
|
|
39758a7ee0 | ||
|
|
30cebab17e | ||
|
|
bc7ab15951 | ||
|
|
3fbd3d79af | ||
|
|
c5539c8699 | ||
|
|
dfbeb10342 | ||
|
|
9daf6fb765 | ||
|
|
b3ceceda17 | ||
|
|
002b951c88 | ||
|
|
7a5c5db56f | ||
|
|
5fd15c74bf | ||
|
|
467219323a | ||
|
|
e148063a33 | ||
|
|
3ccecb7f8e | ||
|
|
eecf8c2020 | ||
|
|
35c50e2d4c | ||
|
|
b478ae51c1 | ||
|
|
e564e15701 | ||
|
|
748600d069 | ||
|
|
31aaabc1eb | ||
|
|
4f057c5b72 | ||
|
|
75309047cf | ||
|
|
e58a4599c8 | ||
|
|
848990411d | ||
|
|
ae500cd9c6 | ||
|
|
7f062545ba | ||
|
|
b75967a9a1 | ||
|
|
7c4c9fda0c | ||
|
|
03289f7a84 | ||
|
|
088613c64b | ||
|
|
0aaaf55452 | ||
|
|
aa66188a9a | ||
|
|
31bcdb97a7 | ||
|
|
d1b8dcd298 | ||
|
|
5e27cb3147 | ||
|
|
a09ecab7f1 | ||
|
|
864f76f904 | ||
|
|
19b979ea7f | ||
|
|
213f9aaa90 | ||
|
|
7f10fe9d70 | ||
|
|
31b31e00d9 | ||
|
|
f054d2642b | ||
|
|
0d469bb094 | ||
|
|
bfdc387e02 | ||
|
|
31b99c9572 | ||
|
|
617533fa1d | ||
|
|
f99c974ea8 | ||
|
|
12d43fb2fe | ||
|
|
b49b627a14 | ||
|
|
8073f41804 | ||
|
|
fcf91a0721 | ||
|
|
bce9a6ff46 | ||
|
|
87c802898d | ||
|
|
e353e1e25f | ||
|
|
ea06aed1e1 | ||
|
|
ef9814457c |
10
.github/workflows/platform-frontend-ci.yml
vendored
10
.github/workflows/platform-frontend-ci.yml
vendored
@@ -155,6 +155,8 @@ jobs:
|
||||
needs: setup
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
browser: [chromium, webkit]
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -206,11 +208,13 @@ jobs:
|
||||
run: pnpm build --turbo
|
||||
# uses Turbopack, much faster and safe enough for a test pipeline
|
||||
|
||||
- name: Install Browser 'chromium'
|
||||
run: pnpm playwright install --with-deps chromium
|
||||
- name: Install Browser '${{ matrix.browser }}'
|
||||
run: pnpm playwright install --with-deps ${{ matrix.browser }}
|
||||
|
||||
- name: Run Playwright tests
|
||||
run: pnpm test:no-build
|
||||
run: pnpm test:no-build --project=${{ matrix.browser }}
|
||||
env:
|
||||
BROWSER_TYPE: ${{ matrix.browser }}
|
||||
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -177,3 +177,6 @@ autogpt_platform/backend/settings.py
|
||||
*.ign.*
|
||||
.test-contents
|
||||
.claude/settings.local.json
|
||||
|
||||
api.md
|
||||
blocks.md
|
||||
6
.vscode/launch.json
vendored
6
.vscode/launch.json
vendored
@@ -6,7 +6,7 @@
|
||||
"type": "node-terminal",
|
||||
"request": "launch",
|
||||
"cwd": "${workspaceFolder}/autogpt_platform/frontend",
|
||||
"command": "pnpm dev"
|
||||
"command": "yarn dev"
|
||||
},
|
||||
{
|
||||
"name": "Frontend: Client Side",
|
||||
@@ -19,12 +19,12 @@
|
||||
"type": "node-terminal",
|
||||
|
||||
"request": "launch",
|
||||
"command": "pnpm dev",
|
||||
"command": "yarn dev",
|
||||
"cwd": "${workspaceFolder}/autogpt_platform/frontend",
|
||||
"serverReadyAction": {
|
||||
"pattern": "- Local:.+(https?://.+)",
|
||||
"uriFormat": "%s",
|
||||
"action": "debugWithChrome"
|
||||
"action": "debugWithEdge"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
23
README.md
23
README.md
@@ -1,7 +1,8 @@
|
||||
# AutoGPT: Build, Deploy, and Run AI Agents
|
||||
|
||||
[](https://discord.gg/autogpt)  
|
||||
[](https://discord.gg/autogpt)  
|
||||
[](https://twitter.com/Auto_GPT)  
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
|
||||
**AutoGPT** is a powerful platform that allows you to create, deploy, and manage continuous AI agents that automate complex workflows.
|
||||
|
||||
@@ -49,24 +50,6 @@ We've moved to a fully maintained and regularly updated documentation site.
|
||||
|
||||
This tutorial assumes you have Docker, VSCode, git and npm installed.
|
||||
|
||||
---
|
||||
|
||||
#### ⚡ Quick Setup with One-Line Script (Recommended for Local Hosting)
|
||||
|
||||
Skip the manual steps and get started in minutes using our automatic setup script.
|
||||
|
||||
For macOS/Linux:
|
||||
```
|
||||
curl -fsSL https://setup.agpt.co/install.sh -o install.sh && bash install.sh
|
||||
```
|
||||
|
||||
For Windows (PowerShell):
|
||||
```
|
||||
powershell -c "iwr https://setup.agpt.co/install.bat -o install.bat; ./install.bat"
|
||||
```
|
||||
|
||||
This will install dependencies, configure Docker, and launch your local instance — all in one go.
|
||||
|
||||
### 🧱 AutoGPT Frontend
|
||||
|
||||
The AutoGPT frontend is where users interact with our powerful AI automation platform. It offers multiple ways to engage with and leverage our AI agents. This is the interface where you'll bring your AI automation ideas to life:
|
||||
@@ -223,4 +206,4 @@ To maintain a uniform standard and ensure seamless compatibility with many curre
|
||||
|
||||
<a href="https://github.com/Significant-Gravitas/AutoGPT/graphs/contributors" alt="View Contributors">
|
||||
<img src="https://contrib.rocks/image?repo=Significant-Gravitas/AutoGPT&max=1000&columns=10" alt="Contributors" />
|
||||
</a>
|
||||
</a>
|
||||
1767
autogpt_platform/autogpt_libs/poetry.lock
generated
1767
autogpt_platform/autogpt_libs/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -11,19 +11,19 @@ python = ">=3.10,<4.0"
|
||||
colorama = "^0.4.6"
|
||||
expiringdict = "^1.2.2"
|
||||
google-cloud-logging = "^3.12.1"
|
||||
pydantic = "^2.11.7"
|
||||
pydantic-settings = "^2.10.1"
|
||||
pydantic = "^2.11.4"
|
||||
pydantic-settings = "^2.9.1"
|
||||
pyjwt = "^2.10.1"
|
||||
pytest-asyncio = "^1.1.0"
|
||||
pytest-mock = "^3.14.1"
|
||||
supabase = "^2.16.0"
|
||||
launchdarkly-server-sdk = "^9.12.0"
|
||||
fastapi = "^0.116.1"
|
||||
uvicorn = "^0.35.0"
|
||||
pytest-asyncio = "^0.26.0"
|
||||
pytest-mock = "^3.14.0"
|
||||
supabase = "^2.15.1"
|
||||
launchdarkly-server-sdk = "^9.11.1"
|
||||
fastapi = "^0.115.12"
|
||||
uvicorn = "^0.34.3"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
redis = "^5.2.1"
|
||||
ruff = "^0.12.3"
|
||||
ruff = "^0.12.2"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -199,10 +199,6 @@ ZEROBOUNCE_API_KEY=
|
||||
|
||||
## ===== OPTIONAL API KEYS END ===== ##
|
||||
|
||||
# Block Error Rate Monitoring
|
||||
BLOCK_ERROR_RATE_THRESHOLD=0.5
|
||||
BLOCK_ERROR_RATE_CHECK_INTERVAL_SECS=86400
|
||||
|
||||
# Logging Configuration
|
||||
LOG_LEVEL=INFO
|
||||
ENABLE_CLOUD_LOGGING=false
|
||||
@@ -214,7 +210,3 @@ ENABLE_FILE_LOGGING=false
|
||||
# Set to true to enable example blocks in development
|
||||
# These blocks are disabled by default in production
|
||||
ENABLE_EXAMPLE_BLOCKS=false
|
||||
|
||||
# Cloud Storage Configuration
|
||||
# Cleanup interval for expired files (hours between cleanup runs, 1-24 hours)
|
||||
CLOUD_STORAGE_CLEANUP_INTERVAL_HOURS=6
|
||||
|
||||
@@ -1,150 +0,0 @@
|
||||
# Test Data Scripts
|
||||
|
||||
This directory contains scripts for creating and updating test data in the AutoGPT Platform database, specifically designed to test the materialized views for the store functionality.
|
||||
|
||||
## Scripts
|
||||
|
||||
### test_data_creator.py
|
||||
Creates a comprehensive set of test data including:
|
||||
- Users with profiles
|
||||
- Agent graphs, nodes, and executions
|
||||
- Store listings with multiple versions
|
||||
- Reviews and ratings
|
||||
- Library agents
|
||||
- Integration webhooks
|
||||
- Onboarding data
|
||||
- Credit transactions
|
||||
|
||||
**Image/Video Domains Used:**
|
||||
- Images: `picsum.photos` (for all image URLs)
|
||||
- Videos: `youtube.com` (for store listing videos)
|
||||
|
||||
### test_data_updater.py
|
||||
Updates existing test data to simulate real-world changes:
|
||||
- Adds new agent graph executions
|
||||
- Creates new store listing reviews
|
||||
- Updates store listing versions
|
||||
- Adds credit transactions
|
||||
- Refreshes materialized views
|
||||
|
||||
### check_db.py
|
||||
Tests and verifies materialized views functionality:
|
||||
- Checks pg_cron job status (for automatic refresh)
|
||||
- Displays current materialized view counts
|
||||
- Adds test data (executions and reviews)
|
||||
- Creates store listings if none exist
|
||||
- Manually refreshes materialized views
|
||||
- Compares before/after counts to verify updates
|
||||
- Provides a summary of test results
|
||||
|
||||
## Materialized Views
|
||||
|
||||
The scripts test three key database views:
|
||||
|
||||
1. **mv_agent_run_counts**: Tracks execution counts by agent
|
||||
2. **mv_review_stats**: Tracks review statistics (count, average rating) by store listing
|
||||
3. **StoreAgent**: A view that combines store listing data with execution counts and ratings for display
|
||||
|
||||
The materialized views (mv_agent_run_counts and mv_review_stats) are automatically refreshed every 15 minutes via pg_cron, or can be manually refreshed using the `refresh_store_materialized_views()` function.
|
||||
|
||||
## Usage
|
||||
|
||||
### Prerequisites
|
||||
|
||||
1. Ensure the database is running:
|
||||
```bash
|
||||
docker compose up -d
|
||||
# or for test database:
|
||||
docker compose -f docker-compose.test.yaml --env-file ../.env up -d
|
||||
```
|
||||
|
||||
2. Run database migrations:
|
||||
```bash
|
||||
poetry run prisma migrate deploy
|
||||
```
|
||||
|
||||
### Running the Scripts
|
||||
|
||||
#### Option 1: Use the helper script (from backend directory)
|
||||
```bash
|
||||
poetry run python run_test_data.py
|
||||
```
|
||||
|
||||
#### Option 2: Run individually
|
||||
```bash
|
||||
# From backend/test directory:
|
||||
# Create initial test data
|
||||
poetry run python test_data_creator.py
|
||||
|
||||
# Update data to test materialized view changes
|
||||
poetry run python test_data_updater.py
|
||||
|
||||
# From backend directory:
|
||||
# Test materialized views functionality
|
||||
poetry run python check_db.py
|
||||
|
||||
# Check store data status
|
||||
poetry run python check_store_data.py
|
||||
```
|
||||
|
||||
#### Option 3: Use the shell script (from backend directory)
|
||||
```bash
|
||||
./run_test_data_scripts.sh
|
||||
```
|
||||
|
||||
### Manual Materialized View Refresh
|
||||
|
||||
To manually refresh the materialized views:
|
||||
```sql
|
||||
SELECT refresh_store_materialized_views();
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
The scripts use the database configuration from your `.env` file:
|
||||
- `DATABASE_URL`: PostgreSQL connection string
|
||||
- Database should have the platform schema
|
||||
|
||||
## Data Generation Limits
|
||||
|
||||
Configured in `test_data_creator.py`:
|
||||
- 100 users
|
||||
- 100 agent blocks
|
||||
- 1-5 graphs per user
|
||||
- 2-5 nodes per graph
|
||||
- 1-5 presets per user
|
||||
- 1-10 library agents per user
|
||||
- 1-20 executions per graph
|
||||
- 1-5 reviews per store listing version
|
||||
|
||||
## Notes
|
||||
|
||||
- All image URLs use `picsum.photos` for consistency with Next.js image configuration
|
||||
- The scripts create realistic relationships between entities
|
||||
- Materialized views are refreshed at the end of each script
|
||||
- Data is designed to test both happy paths and edge cases
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Reviews and StoreAgent view showing 0
|
||||
|
||||
If `check_db.py` shows that reviews remain at 0 and StoreAgent view shows 0 store agents:
|
||||
|
||||
1. **No store listings exist**: The script will automatically create test store listings if none exist
|
||||
2. **No approved versions**: Store listings need approved versions to appear in the StoreAgent view
|
||||
3. **Check with `check_store_data.py`**: This script provides detailed information about:
|
||||
- Total store listings
|
||||
- Store listing versions by status
|
||||
- Existing reviews
|
||||
- StoreAgent view contents
|
||||
- Agent graph executions
|
||||
|
||||
### pg_cron not installed
|
||||
|
||||
The warning "pg_cron extension is not installed" is normal in local development environments. The materialized views can still be refreshed manually using the `refresh_store_materialized_views()` function, which all scripts do automatically.
|
||||
|
||||
### Common Issues
|
||||
|
||||
- **Type errors with None values**: Fixed in the latest version of check_db.py by using `or 0` for nullable numeric fields
|
||||
- **Missing relations**: Ensure you're using the correct field names (e.g., `StoreListing` not `storeListing` in includes)
|
||||
- **Column name mismatches**: The database uses camelCase for column names (e.g., `agentGraphId` not `agent_graph_id`)
|
||||
@@ -14,7 +14,7 @@ from backend.data.block import (
|
||||
get_block,
|
||||
)
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util import json, retry
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
@@ -151,12 +151,6 @@ class AgentExecutorBlock(Block):
|
||||
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
|
||||
# If the graph execution is COMPLETED, TERMINATED, or FAILED,
|
||||
# we can stop listening for further events.
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
extra_cost=event.stats.cost if event.stats else 0,
|
||||
extra_steps=event.stats.node_exec_count if event.stats else 0,
|
||||
)
|
||||
)
|
||||
break
|
||||
|
||||
logger.debug(
|
||||
|
||||
84
autogpt_platform/backend/backend/blocks/airtable/__init__.py
Normal file
84
autogpt_platform/backend/backend/blocks/airtable/__init__.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
Airtable integration for AutoGPT Platform.
|
||||
|
||||
This integration provides comprehensive access to the Airtable Web API,
|
||||
including:
|
||||
- Webhook triggers and management
|
||||
- Record CRUD operations
|
||||
- Attachment uploads
|
||||
- Schema and table management
|
||||
- Metadata operations
|
||||
"""
|
||||
|
||||
# Attachments
|
||||
from .attachments import AirtableUploadAttachmentBlock
|
||||
|
||||
# Metadata
|
||||
from .metadata import (
|
||||
AirtableGetViewBlock,
|
||||
AirtableListBasesBlock,
|
||||
AirtableListViewsBlock,
|
||||
)
|
||||
|
||||
# Record Operations
|
||||
from .records import (
|
||||
AirtableCreateRecordsBlock,
|
||||
AirtableDeleteRecordsBlock,
|
||||
AirtableGetRecordBlock,
|
||||
AirtableListRecordsBlock,
|
||||
AirtableUpdateRecordsBlock,
|
||||
AirtableUpsertRecordsBlock,
|
||||
)
|
||||
|
||||
# Schema & Table Management
|
||||
from .schema import (
|
||||
AirtableAddFieldBlock,
|
||||
AirtableCreateTableBlock,
|
||||
AirtableDeleteFieldBlock,
|
||||
AirtableDeleteTableBlock,
|
||||
AirtableListSchemaBlock,
|
||||
AirtableUpdateFieldBlock,
|
||||
AirtableUpdateTableBlock,
|
||||
)
|
||||
|
||||
# Webhook Triggers
|
||||
from .triggers import AirtableWebhookTriggerBlock
|
||||
|
||||
# Webhook Management
|
||||
from .webhooks import (
|
||||
AirtableCreateWebhookBlock,
|
||||
AirtableDeleteWebhookBlock,
|
||||
AirtableFetchWebhookPayloadsBlock,
|
||||
AirtableRefreshWebhookBlock,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Webhook Triggers
|
||||
"AirtableWebhookTriggerBlock",
|
||||
# Webhook Management
|
||||
"AirtableCreateWebhookBlock",
|
||||
"AirtableDeleteWebhookBlock",
|
||||
"AirtableFetchWebhookPayloadsBlock",
|
||||
"AirtableRefreshWebhookBlock",
|
||||
# Record Operations
|
||||
"AirtableCreateRecordsBlock",
|
||||
"AirtableDeleteRecordsBlock",
|
||||
"AirtableGetRecordBlock",
|
||||
"AirtableListRecordsBlock",
|
||||
"AirtableUpdateRecordsBlock",
|
||||
"AirtableUpsertRecordsBlock",
|
||||
# Attachments
|
||||
"AirtableUploadAttachmentBlock",
|
||||
# Schema & Table Management
|
||||
"AirtableAddFieldBlock",
|
||||
"AirtableCreateTableBlock",
|
||||
"AirtableDeleteFieldBlock",
|
||||
"AirtableDeleteTableBlock",
|
||||
"AirtableListSchemaBlock",
|
||||
"AirtableUpdateFieldBlock",
|
||||
"AirtableUpdateTableBlock",
|
||||
# Metadata
|
||||
"AirtableGetViewBlock",
|
||||
"AirtableListBasesBlock",
|
||||
"AirtableListViewsBlock",
|
||||
]
|
||||
16
autogpt_platform/backend/backend/blocks/airtable/_config.py
Normal file
16
autogpt_platform/backend/backend/blocks/airtable/_config.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
Shared configuration for all Airtable blocks using the SDK pattern.
|
||||
"""
|
||||
|
||||
from backend.sdk import BlockCostType, ProviderBuilder
|
||||
|
||||
from ._webhook import AirtableWebhookManager
|
||||
|
||||
# Configure the Airtable provider with API key authentication
|
||||
def _build_airtable_provider():
    """Assemble the provider object shared by every Airtable block.

    The provider is named "airtable", authenticates with an API key read from
    ``AIRTABLE_API_KEY``, uses ``AirtableWebhookManager`` for webhooks and
    carries a base cost of 1 per run.
    """
    builder = ProviderBuilder("airtable")
    builder = builder.with_api_key(
        "AIRTABLE_API_KEY", "Airtable Personal Access Token"
    )
    builder = builder.with_webhook_manager(AirtableWebhookManager)
    builder = builder.with_base_cost(1, BlockCostType.RUN)
    return builder.build()


# Module-level provider instance imported by the Airtable block modules.
airtable = _build_airtable_provider()
|
||||
125
autogpt_platform/backend/backend/blocks/airtable/_webhook.py
Normal file
125
autogpt_platform/backend/backend/blocks/airtable/_webhook.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
Webhook management for Airtable blocks.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
from enum import Enum
|
||||
from typing import Tuple
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
BaseWebhooksManager,
|
||||
Credentials,
|
||||
ProviderName,
|
||||
Requests,
|
||||
Webhook,
|
||||
)
|
||||
|
||||
|
||||
class AirtableWebhookManager(BaseWebhooksManager):
    """Webhook manager for the Airtable Web API.

    Creates and deletes webhooks on an Airtable base and validates the
    notification pings Airtable sends to the ingress URL.
    """

    PROVIDER_NAME = ProviderName("airtable")

    class WebhookType(str, Enum):
        TABLE_CHANGE = "table_change"

    @staticmethod
    def _expected_content_mac(mac_secret_b64: str, body: bytes) -> str:
        """Return the ``X-Airtable-Content-MAC`` value expected for *body*.

        Args:
            mac_secret_b64: The ``macSecretBase64`` value returned by Airtable
                when the webhook was created.
            body: The raw request body, exactly as received.

        Returns:
            ``"hmac-sha256="`` followed by the hex HMAC-SHA256 digest of
            *body*, keyed with the base64-decoded secret. This is the header
            format described in Airtable's webhook documentation.

        Raises:
            ValueError: If *mac_secret_b64* is not valid base64
                (``binascii.Error`` is a ``ValueError`` subclass).
        """
        import base64

        # The secret is delivered base64-encoded; the HMAC key is the decoded
        # bytes, not the base64 text itself.
        key = base64.b64decode(mac_secret_b64)
        digest = hmac.new(key, body, hashlib.sha256).hexdigest()
        return f"hmac-sha256={digest}"

    @classmethod
    async def validate_payload(cls, webhook: Webhook, request) -> Tuple[dict, str]:
        """Validate an incoming webhook request and return its payload.

        When the webhook has a secret and a stored ``mac_secret``, the
        ``X-Airtable-Content-MAC`` header is required and must match the
        HMAC-SHA256 of the raw request body.

        Args:
            webhook: The stored webhook; ``webhook.config["mac_secret"]`` holds
                the base64 MAC secret saved at registration time.
            request: The incoming HTTP request (must provide ``json()``,
                ``body()`` and ``headers``).

        Returns:
            A ``(payload, event_type)`` tuple; ``event_type`` is always
            ``"notification"``.

        Raises:
            ValueError: If the signature header is missing or does not match.
        """
        payload = await request.json()

        # NOTE(review): verification is only attempted when `webhook.secret`
        # is set, as in the original implementation — confirm that gate is
        # intentional, since the MAC secret lives in `webhook.config`.
        mac_secret = webhook.config.get("mac_secret") if webhook.secret else None
        if mac_secret:
            # The MAC is computed over the raw body, not the parsed JSON.
            body = await request.body()
            expected_mac = cls._expected_content_mac(mac_secret, body)

            signature = request.headers.get("X-Airtable-Content-MAC")
            if not signature:
                # A request without the header must not bypass verification.
                raise ValueError("Missing webhook signature")

            # Compare as bytes: compare_digest rejects non-ASCII str input.
            if not hmac.compare_digest(signature.encode(), expected_mac.encode()):
                raise ValueError("Invalid webhook signature")

        # Every delivery is reported under a single event type.
        event_type = "notification"
        return payload, event_type

    async def _register_webhook(
        self,
        credentials: Credentials,
        webhook_type: str,
        resource: str,
        events: list[str],
        ingress_url: str,
        secret: str,
    ) -> Tuple[str, dict]:
        """Register a webhook with the Airtable API.

        Args:
            credentials: Must be ``APIKeyCredentials``.
            webhook_type: Unused by this implementation.
            resource: ``"{base_id}/{table_id_or_name}"``; a table part of
                ``"*"`` means no table filter is added.
            events: Airtable data types to watch; when empty, all of
                ``tableData``, ``tableFields`` and ``tableMetadata`` are used.
            ingress_url: URL Airtable should notify.
            secret: Unused by this implementation.

        Returns:
            ``(webhook_id, config)`` where ``config`` stores the base id,
            table, events, MAC secret, starting cursor and expiration time.

        Raises:
            ValueError: If the credentials are not API-key credentials or
                *resource* is not in the expected format.
        """
        if not isinstance(credentials, APIKeyCredentials):
            raise ValueError("Airtable webhooks require API key credentials")

        api_key = credentials.api_key.get_secret_value()

        # Parse resource to get base_id and table_id/name
        # Resource format: "{base_id}/{table_id_or_name}"
        parts = resource.split("/", 1)
        if len(parts) != 2:
            raise ValueError("Resource must be in format: {base_id}/{table_id_or_name}")

        base_id, table_id_or_name = parts

        # Prepare webhook specification
        # NOTE(review): Airtable's create-webhook reference appears to nest
        # filters under `specification.options` and to document
        # `recordChangeScope` as a single string — confirm this payload shape
        # against the API before relying on it.
        specification = {
            "filters": {
                "dataTypes": events or ["tableData", "tableFields", "tableMetadata"]
            }
        }

        # If specific table is provided, add to specification
        if table_id_or_name and table_id_or_name != "*":
            specification["filters"]["recordChangeScope"] = [table_id_or_name]

        # Create webhook
        response = await Requests().post(
            f"https://api.airtable.com/v0/bases/{base_id}/webhooks",
            headers={"Authorization": f"Bearer {api_key}"},
            json={"notificationUrl": ingress_url, "specification": specification},
        )

        webhook_data = response.json()
        webhook_id = webhook_data["id"]
        mac_secret = webhook_data.get("macSecretBase64")

        return webhook_id, {
            "base_id": base_id,
            "table_id_or_name": table_id_or_name,
            "events": events,
            # Stored base64-encoded, exactly as returned by Airtable.
            "mac_secret": mac_secret,
            "cursor": 1,  # Start from cursor 1
            "expiration_time": webhook_data.get("expirationTime"),
        }

    async def _deregister_webhook(
        self, webhook: Webhook, credentials: Credentials
    ) -> None:
        """Delete the webhook from the Airtable base it was registered on.

        Raises:
            ValueError: If the credentials are not API-key credentials or the
                stored webhook config has no ``base_id``.
        """
        if not isinstance(credentials, APIKeyCredentials):
            raise ValueError("Airtable webhooks require API key credentials")

        api_key = credentials.api_key.get_secret_value()
        base_id = webhook.config.get("base_id")

        if not base_id:
            raise ValueError("Missing base_id in webhook metadata")

        await Requests().delete(
            f"https://api.airtable.com/v0/bases/{base_id}/webhooks/{webhook.provider_webhook_id}",
            headers={"Authorization": f"Bearer {api_key}"},
        )
|
||||
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Airtable attachment blocks.
|
||||
"""
|
||||
|
||||
import base64
|
||||
from typing import Union
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
class AirtableUploadAttachmentBlock(Block):
    """
    Block that posts a file to Airtable's attachment upload endpoint.

    The file may be given as raw bytes or as a base64 string; the decoded
    content must not exceed 5 * 1024 * 1024 bytes. The response object is
    emitted whole, along with its ``id``, ``url`` and ``size`` entries.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID", default="")
        filename: str = SchemaField(description="Name of the file")
        file: Union[bytes, str] = SchemaField(
            description="File content (binary data or base64 string)"
        )
        content_type: str = SchemaField(
            description="MIME type of the file", default="application/octet-stream"
        )

    class Output(BlockSchema):
        attachment: dict = SchemaField(
            description="Attachment object with id, url, size, and type"
        )
        attachment_id: str = SchemaField(description="ID of the uploaded attachment")
        url: str = SchemaField(description="URL of the uploaded attachment")
        size: int = SchemaField(description="Size of the file in bytes")

    def __init__(self):
        super().__init__(
            id="962e801b-5a6f-4c56-a929-83e816343a41",
            description="Upload a file to Airtable for use as an attachment",
            categories={BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Upload the file and yield the attachment plus its id, url and size.

        Raises:
            ValueError: If the decoded file is larger than the 5MB limit.
        """
        token = credentials.api_key.get_secret_value()

        # Raw bytes are base64-encoded; a str is taken to be base64 already.
        supplied = input_data.file
        encoded = (
            base64.b64encode(supplied).decode("utf-8")
            if isinstance(supplied, bytes)
            else supplied
        )

        # The limit applies to the decoded content, not the base64 text.
        size_limit = 5 * 1024 * 1024
        if len(base64.b64decode(encoded)) > size_limit:
            raise ValueError(
                "File size exceeds 5MB limit. Use URL upload for larger files."
            )

        endpoint = (
            f"https://api.airtable.com/v0/bases/{input_data.base_id}/attachments/upload"
        )
        request_headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }
        request_body = {
            "content": encoded,
            "filename": input_data.filename,
            "type": input_data.content_type,
        }
        response = await Requests().post(
            endpoint, headers=request_headers, json=request_body
        )
        uploaded = response.json()

        yield "attachment", uploaded
        # (output name, response key, fallback when the key is absent)
        for output_name, key, fallback in (
            ("attachment_id", "id", ""),
            ("url", "url", ""),
            ("size", "size", 0),
        ):
            yield output_name, uploaded.get(key, fallback)
|
||||
145
autogpt_platform/backend/backend/blocks/airtable/metadata.py
Normal file
145
autogpt_platform/backend/backend/blocks/airtable/metadata.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""
|
||||
Airtable metadata blocks for bases and views.
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
class AirtableListBasesBlock(Block):
    """
    Block that fetches the bases returned by Airtable's ``meta/bases``
    endpoint for the supplied API token.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )

    class Output(BlockSchema):
        bases: list[dict] = SchemaField(
            description="Array of base objects with id and name"
        )

    def __init__(self):
        super().__init__(
            id="613f9907-bef8-468a-be6d-2dd7a53f96e7",
            description="List all accessible Airtable bases",
            categories={BlockCategory.SEARCH},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Request the base list and yield it under ``bases``.

        Yields an empty list when the response has no ``bases`` key.
        """
        token = credentials.api_key.get_secret_value()
        auth_headers = {"Authorization": f"Bearer {token}"}

        response = await Requests().get(
            "https://api.airtable.com/v0/meta/bases",
            headers=auth_headers,
        )
        body = response.json()

        yield "bases", body.get("bases", [])
|
||||
|
||||
|
||||
class AirtableListViewsBlock(Block):
    """
    Block that collects every view of every table in an Airtable base,
    tagging each view with the id of the table it belongs to.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID", default="")

    class Output(BlockSchema):
        views: list[dict] = SchemaField(
            description="Array of view objects with tableId"
        )

    def __init__(self):
        super().__init__(
            id="3878cf82-d384-40c2-aace-097042233f6a",
            description="List all views in an Airtable base",
            categories={BlockCategory.SEARCH},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Fetch the base's table schema and yield its flattened views.

        Each yielded view is a copy of the view object from the response
        with an added ``tableId`` key.
        """
        token = credentials.api_key.get_secret_value()

        # The tables endpoint returns each table together with its views.
        schema_url = (
            f"https://api.airtable.com/v0/meta/bases/{input_data.base_id}/tables"
        )
        response = await Requests().get(
            schema_url,
            headers={"Authorization": f"Bearer {token}"},
        )
        schema = response.json()

        # Flatten table -> views into one list, in response order.
        flattened = [
            {**view, "tableId": table.get("id")}
            for table in schema.get("tables", [])
            for view in table.get("views", [])
        ]

        yield "views", flattened
|
||||
|
||||
|
||||
class AirtableGetViewBlock(Block):
    """
    Block that retrieves a single view of an Airtable base by its view ID.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID", default="")
        view_id: str = SchemaField(description="The view ID to retrieve")

    class Output(BlockSchema):
        view: dict = SchemaField(description="Full view object with configuration")

    def __init__(self):
        super().__init__(
            id="ad0dd9f3-b3f4-446b-8142-e81a566797c4",
            description="Get details of a specific Airtable view",
            categories={BlockCategory.SEARCH},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Request the view and yield the decoded response under ``view``."""
        token = credentials.api_key.get_secret_value()

        view_url = f"https://api.airtable.com/v0/meta/bases/{input_data.base_id}/views/{input_data.view_id}"
        response = await Requests().get(
            view_url,
            headers={"Authorization": f"Bearer {token}"},
        )

        yield "view", response.json()
|
||||
395
autogpt_platform/backend/backend/blocks/airtable/records.py
Normal file
395
autogpt_platform/backend/backend/blocks/airtable/records.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
Airtable record operation blocks.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
class AirtableListRecordsBlock(Block):
    """
    Block that reads one page of records from an Airtable table.

    Filtering, view selection, sorting, field selection and paging inputs are
    passed to the API as query parameters when they are set.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID", default="")
        table_id_or_name: str = SchemaField(description="Table ID or name", default="")
        filter_formula: str = SchemaField(
            description="Airtable formula to filter records", default=""
        )
        view: str = SchemaField(description="View ID or name to use", default="")
        sort: list[dict] = SchemaField(
            description="Sort configuration (array of {field, direction})", default=[]
        )
        max_records: int = SchemaField(
            description="Maximum number of records to return", default=100
        )
        page_size: int = SchemaField(
            description="Number of records per page (max 100)", default=100
        )
        offset: str = SchemaField(
            description="Pagination offset from previous request", default=""
        )
        return_fields: list[str] = SchemaField(
            description="Specific fields to return (comma-separated)", default=[]
        )

    class Output(BlockSchema):
        records: list[dict] = SchemaField(description="Array of record objects")
        offset: Optional[str] = SchemaField(
            description="Offset for next page (null if no more records)", default=None
        )

    def __init__(self):
        super().__init__(
            id="588a9fde-5733-4da7-b03c-35f5671e960f",
            description="List records from an Airtable table",
            categories={BlockCategory.SEARCH},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Query the table and yield ``records`` and the next-page ``offset``.

        ``offset`` is yielded as ``None`` when the response does not contain
        one.
        """
        token = credentials.api_key.get_secret_value()

        # Only inputs with a truthy value become query parameters; insertion
        # order is kept the same as the order of the checks below.
        query = {}
        if input_data.filter_formula:
            query["filterByFormula"] = input_data.filter_formula
        if input_data.view:
            query["view"] = input_data.view
        # Each sort entry expands to an indexed field/direction pair.
        for position, sort_spec in enumerate(input_data.sort or []):
            query[f"sort[{position}][field]"] = sort_spec.get("field", "")
            query[f"sort[{position}][direction]"] = sort_spec.get("direction", "asc")
        if input_data.max_records:
            query["maxRecords"] = input_data.max_records
        if input_data.page_size:
            # Requested page size is capped at 100.
            query["pageSize"] = min(input_data.page_size, 100)
        if input_data.offset:
            query["offset"] = input_data.offset
        query.update(
            {
                f"fields[{position}]": field_name
                for position, field_name in enumerate(input_data.return_fields or [])
            }
        )

        table_url = f"https://api.airtable.com/v0/{input_data.base_id}/{input_data.table_id_or_name}"
        response = await Requests().get(
            table_url,
            headers={"Authorization": f"Bearer {token}"},
            params=query,
        )
        page = response.json()

        yield "records", page.get("records", [])
        yield "offset", page.get("offset")
|
||||
|
||||
|
||||
class AirtableGetRecordBlock(Block):
    """
    Look up one record in an Airtable table using its record ID.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID", default="")
        table_id_or_name: str = SchemaField(description="Table ID or name", default="")
        record_id: str = SchemaField(description="The record ID to retrieve")
        return_fields: list[str] = SchemaField(
            description="Specific fields to return", default=[]
        )

    class Output(BlockSchema):
        record: dict = SchemaField(description="The record object")

    def __init__(self):
        super().__init__(
            id="c29c5cbf-0aff-40f9-bbb5-f26061792d2b",
            description="Get a single record from Airtable",
            categories={BlockCategory.SEARCH},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        GET the record and emit the decoded JSON response as "record".
        """
        token = credentials.api_key.get_secret_value()

        # Requested field names are sent as indexed `fields[i]` query parameters.
        query = {
            f"fields[{position}]": field_name
            for position, field_name in enumerate(input_data.return_fields or [])
        }

        result = await Requests().get(
            f"https://api.airtable.com/v0/{input_data.base_id}/{input_data.table_id_or_name}/{input_data.record_id}",
            headers={"Authorization": f"Bearer {token}"},
            params=query,
        )

        yield "record", result.json()
|
||||
|
||||
|
||||
class AirtableCreateRecordsBlock(Block):
    """
    Insert one or more new records into an Airtable table.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID", default="")
        table_id_or_name: str = SchemaField(description="Table ID or name", default="")
        records: list[dict] = SchemaField(
            description="Array of records to create (each with 'fields' object)"
        )
        typecast: bool = SchemaField(
            description="Automatically convert string values to appropriate types",
            default=False,
        )
        return_fields: list[str] = SchemaField(
            description="Specific fields to return in created records", default=[]
        )

    class Output(BlockSchema):
        records: list[dict] = SchemaField(description="Array of created record objects")

    def __init__(self):
        super().__init__(
            id="42527e98-47b6-44ce-ac0e-86b4883721d3",
            description="Create records in an Airtable table",
            categories={BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        POST the records to the table and emit the response's "records" list.
        """
        token = credentials.api_key.get_secret_value()

        # JSON body: the records themselves plus the typecast flag.
        payload = {"records": input_data.records, "typecast": input_data.typecast}

        # Field selection for the response travels in the query string.
        query = {
            f"fields[{position}]": field_name
            for position, field_name in enumerate(input_data.return_fields or [])
        }

        result = await Requests().post(
            f"https://api.airtable.com/v0/{input_data.base_id}/{input_data.table_id_or_name}",
            headers={"Authorization": f"Bearer {token}"},
            json=payload,
            params=query,
        )
        body = result.json()

        yield "records", body.get("records", [])
|
||||
|
||||
|
||||
class AirtableUpdateRecordsBlock(Block):
    """
    Modify one or more records that already exist in an Airtable table.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID", default="")
        table_id_or_name: str = SchemaField(description="Table ID or name", default="")
        records: list[dict] = SchemaField(
            description="Array of records to update (each with 'id' and 'fields')"
        )
        typecast: bool = SchemaField(
            description="Automatically convert string values to appropriate types",
            default=False,
        )
        return_fields: list[str] = SchemaField(
            description="Specific fields to return in updated records", default=[]
        )

    class Output(BlockSchema):
        records: list[dict] = SchemaField(description="Array of updated record objects")

    def __init__(self):
        super().__init__(
            id="6e7d2590-ac2b-4b5d-b08c-fc039cd77e1f",
            description="Update records in an Airtable table",
            categories={BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        PATCH the records on the table and emit the response's "records" list.
        """
        token = credentials.api_key.get_secret_value()

        # JSON body: the record updates plus the typecast flag.
        payload = {"records": input_data.records, "typecast": input_data.typecast}

        # Field selection for the response travels in the query string.
        query = {
            f"fields[{position}]": field_name
            for position, field_name in enumerate(input_data.return_fields or [])
        }

        result = await Requests().patch(
            f"https://api.airtable.com/v0/{input_data.base_id}/{input_data.table_id_or_name}",
            headers={"Authorization": f"Bearer {token}"},
            json=payload,
            params=query,
        )
        body = result.json()

        yield "records", body.get("records", [])
|
||||
|
||||
|
||||
class AirtableUpsertRecordsBlock(Block):
    """
    Creates or updates records in an Airtable table based on a merge field.

    If a record with the same value in the merge field exists, it will be updated.
    Otherwise, a new record will be created.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID", default="")
        table_id_or_name: str = SchemaField(description="Table ID or name", default="")
        records: list[dict] = SchemaField(
            description="Array of records to upsert (each with 'fields' object)"
        )
        merge_field: str = SchemaField(
            description="Field to use for matching existing records"
        )
        typecast: bool = SchemaField(
            description="Automatically convert string values to appropriate types",
            default=False,
        )
        return_fields: list[str] = SchemaField(
            description="Specific fields to return in upserted records", default=[]
        )

    class Output(BlockSchema):
        records: list[dict] = SchemaField(
            description="Array of created/updated record objects"
        )

    def __init__(self):
        super().__init__(
            id="99f78a9d-3418-429f-a6fb-9d2166638e99",
            description="Create or update records based on a merge field",
            categories={BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        Send the upsert request and emit the response's "records" list.

        Args:
            input_data: Validated block input (base, table, records, merge field).
            credentials: API key credentials used for the Bearer token.

        Yields:
            ("records", list): the "records" entry of the response, or [].
        """
        api_key = credentials.api_key.get_secret_value()

        # Build request body; `performUpsert` names the field used for matching.
        body = {
            "performUpsert": {"fieldsToMergeOn": [input_data.merge_field]},
            "records": input_data.records,
            "typecast": input_data.typecast,
        }

        # Build query parameters for return fields
        params = {
            f"fields[{i}]": field
            for i, field in enumerate(input_data.return_fields or [])
        }

        # Airtable documents `performUpsert` on the "update multiple records"
        # endpoint (PATCH/PUT), not on record creation (POST), so the upsert
        # must be sent as PATCH.
        response = await Requests().patch(
            f"https://api.airtable.com/v0/{input_data.base_id}/{input_data.table_id_or_name}",
            headers={"Authorization": f"Bearer {api_key}"},
            json=body,
            params=params,
        )

        data = response.json()

        yield "records", data.get("records", [])
|
||||
|
||||
|
||||
class AirtableDeleteRecordsBlock(Block):
    """
    Remove one or more records from an Airtable table by record ID.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID", default="")
        table_id_or_name: str = SchemaField(description="Table ID or name", default="")
        record_ids: list[str] = SchemaField(description="Array of record IDs to delete")

    class Output(BlockSchema):
        records: list[dict] = SchemaField(description="Array of deletion results")

    def __init__(self):
        super().__init__(
            id="93e22b8b-3642-4477-aefb-1c0929a4a3a6",
            description="Delete records from an Airtable table",
            categories={BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        Send the DELETE request and emit the response's "records" list.
        """
        token = credentials.api_key.get_secret_value()

        # The IDs to delete are passed as indexed `records[i]` query parameters.
        query = {
            f"records[{position}]": rec_id
            for position, rec_id in enumerate(input_data.record_ids)
        }

        result = await Requests().delete(
            f"https://api.airtable.com/v0/{input_data.base_id}/{input_data.table_id_or_name}",
            headers={"Authorization": f"Bearer {token}"},
            params=query,
        )
        body = result.json()

        yield "records", body.get("records", [])
|
||||
328
autogpt_platform/backend/backend/blocks/airtable/schema.py
Normal file
328
autogpt_platform/backend/backend/blocks/airtable/schema.py
Normal file
@@ -0,0 +1,328 @@
|
||||
"""
|
||||
Airtable schema and table management blocks.
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
class AirtableListSchemaBlock(Block):
    """
    Read the schema of an Airtable base from the metadata API.

    Emits both the whole response and its "tables" list.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID", default="")

    class Output(BlockSchema):
        base_schema: dict = SchemaField(
            description="Complete base schema with tables, fields, and views"
        )
        tables: list[dict] = SchemaField(description="Array of table objects")

    def __init__(self):
        super().__init__(
            id="64291d3c-99b5-47b7-a976-6d94293cdb2d",
            description="Get the complete schema of an Airtable base",
            categories={BlockCategory.SEARCH},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        GET the base's tables endpoint and emit the decoded response.
        """
        token = credentials.api_key.get_secret_value()
        auth = {"Authorization": f"Bearer {token}"}

        result = await Requests().get(
            f"https://api.airtable.com/v0/meta/bases/{input_data.base_id}/tables",
            headers=auth,
        )
        schema = result.json()

        yield "base_schema", schema
        yield "tables", schema.get("tables", [])
|
||||
|
||||
|
||||
class AirtableCreateTableBlock(Block):
    """
    Add a new table to an Airtable base from a caller-supplied definition.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID", default="")
        table_definition: dict = SchemaField(
            description="Table definition with name, description, fields, and views",
            default={
                "name": "New Table",
                "fields": [{"name": "Name", "type": "singleLineText"}],
            },
        )

    class Output(BlockSchema):
        table: dict = SchemaField(description="Created table object")
        table_id: str = SchemaField(description="ID of the created table")

    def __init__(self):
        super().__init__(
            id="fcc20ced-d817-42ea-9b40-c35e7bf34b4f",
            description="Create a new table in an Airtable base",
            categories={BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        POST the table definition and emit the response plus its "id".
        """
        token = credentials.api_key.get_secret_value()
        auth = {"Authorization": f"Bearer {token}"}

        # The definition dict is sent unchanged as the JSON body.
        result = await Requests().post(
            f"https://api.airtable.com/v0/meta/bases/{input_data.base_id}/tables",
            headers=auth,
            json=input_data.table_definition,
        )
        created = result.json()

        yield "table", created
        yield "table_id", created.get("id", "")
|
||||
|
||||
|
||||
class AirtableUpdateTableBlock(Block):
    """
    Apply a patch of properties (e.g. name, description) to an existing table.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID", default="")
        table_id: str = SchemaField(description="The table ID to update")
        patch: dict = SchemaField(
            description="Properties to update (name, description)", default={}
        )

    class Output(BlockSchema):
        table: dict = SchemaField(description="Updated table object")

    def __init__(self):
        super().__init__(
            id="34077c5f-f962-49f2-9ec6-97c67077013a",
            description="Update table properties",
            categories={BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        PATCH the table with the given properties and emit the response.
        """
        token = credentials.api_key.get_secret_value()
        auth = {"Authorization": f"Bearer {token}"}

        # The patch dict is sent unchanged as the JSON body.
        result = await Requests().patch(
            f"https://api.airtable.com/v0/meta/bases/{input_data.base_id}/tables/{input_data.table_id}",
            headers=auth,
            json=input_data.patch,
        )

        yield "table", result.json()
|
||||
|
||||
|
||||
class AirtableDeleteTableBlock(Block):
    """
    Request deletion of a table from an Airtable base.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID", default="")
        table_id: str = SchemaField(description="The table ID to delete")

    class Output(BlockSchema):
        deleted: bool = SchemaField(
            description="Confirmation that the table was deleted"
        )

    def __init__(self):
        super().__init__(
            id="6b96c196-d0ad-4fb2-981f-7a330549bc22",
            description="Delete a table from an Airtable base",
            categories={BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        Send the DELETE request and emit whether the status was 200 or 204.
        """
        token = credentials.api_key.get_secret_value()
        auth = {"Authorization": f"Bearer {token}"}

        result = await Requests().delete(
            f"https://api.airtable.com/v0/meta/bases/{input_data.base_id}/tables/{input_data.table_id}",
            headers=auth,
        )

        # Only these two status codes are reported as a successful deletion.
        yield "deleted", result.status in (200, 204)
|
||||
|
||||
|
||||
class AirtableAddFieldBlock(Block):
    """
    Create a new field (column) on an existing Airtable table.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID", default="")
        table_id: str = SchemaField(description="The table ID to add field to")
        field_definition: dict = SchemaField(
            description="Field definition with name, type, and options",
            default={"name": "New Field", "type": "singleLineText"},
        )

    class Output(BlockSchema):
        field: dict = SchemaField(description="Created field object")
        field_id: str = SchemaField(description="ID of the created field")

    def __init__(self):
        super().__init__(
            id="6c98a32f-dbf9-45d8-a2a8-5e97e8326351",
            description="Add a new field to an Airtable table",
            categories={BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        POST the field definition and emit the response plus its "id".
        """
        token = credentials.api_key.get_secret_value()
        auth = {"Authorization": f"Bearer {token}"}

        # The definition dict is sent unchanged as the JSON body.
        result = await Requests().post(
            f"https://api.airtable.com/v0/meta/bases/{input_data.base_id}/tables/{input_data.table_id}/fields",
            headers=auth,
            json=input_data.field_definition,
        )
        created = result.json()

        yield "field", created
        yield "field_id", created.get("id", "")
|
||||
|
||||
|
||||
class AirtableUpdateFieldBlock(Block):
    """
    Apply a patch of properties to an existing field of an Airtable table.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID", default="")
        table_id: str = SchemaField(description="The table ID containing the field")
        field_id: str = SchemaField(description="The field ID to update")
        patch: dict = SchemaField(description="Field properties to update", default={})

    class Output(BlockSchema):
        field: dict = SchemaField(description="Updated field object")

    def __init__(self):
        super().__init__(
            id="f46ac716-3b18-4da1-92e4-34ca9a464d48",
            description="Update field properties in an Airtable table",
            categories={BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        PATCH the field with the given properties and emit the response.
        """
        token = credentials.api_key.get_secret_value()
        auth = {"Authorization": f"Bearer {token}"}

        # The patch dict is sent unchanged as the JSON body.
        result = await Requests().patch(
            f"https://api.airtable.com/v0/meta/bases/{input_data.base_id}/tables/{input_data.table_id}/fields/{input_data.field_id}",
            headers=auth,
            json=input_data.patch,
        )

        yield "field", result.json()
|
||||
|
||||
|
||||
class AirtableDeleteFieldBlock(Block):
    """
    Request deletion of a field from an Airtable table.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID", default="")
        table_id: str = SchemaField(description="The table ID containing the field")
        field_id: str = SchemaField(description="The field ID to delete")

    class Output(BlockSchema):
        deleted: bool = SchemaField(
            description="Confirmation that the field was deleted"
        )

    def __init__(self):
        super().__init__(
            id="ca6ebacb-be8b-4c54-80a3-1fb519ad51c6",
            description="Delete a field from an Airtable table",
            categories={BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        Send the DELETE request and emit whether the status was 200 or 204.
        """
        token = credentials.api_key.get_secret_value()
        auth = {"Authorization": f"Bearer {token}"}

        result = await Requests().delete(
            f"https://api.airtable.com/v0/meta/bases/{input_data.base_id}/tables/{input_data.table_id}/fields/{input_data.field_id}",
            headers=auth,
        )

        # Only these two status codes are reported as a successful deletion.
        yield "deleted", result.status in (200, 204)
|
||||
149
autogpt_platform/backend/backend/blocks/airtable/triggers.py
Normal file
149
autogpt_platform/backend/backend/blocks/airtable/triggers.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
Airtable webhook trigger blocks.
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
BlockWebhookConfig,
|
||||
CredentialsMetaInput,
|
||||
ProviderName,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
class AirtableWebhookTriggerBlock(Block):
    """
    Starts a flow whenever Airtable pings your webhook URL.

    If auto-fetch is enabled, it automatically fetches the full payloads
    after receiving the notification.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        webhook_url: str = SchemaField(
            description="URL to receive webhooks (auto-generated)",
            default="",
            hidden=True,
        )
        base_id: str = SchemaField(
            description="The Airtable base ID to monitor",
            default="",
        )
        table_id_or_name: str = SchemaField(
            description="Table ID or name to monitor (leave empty for all tables)",
            default="",
        )
        event_types: list[str] = SchemaField(
            description="Event types to listen for",
            default=["tableData", "tableFields", "tableMetadata"],
        )
        auto_fetch: bool = SchemaField(
            description="Automatically fetch full payloads after notification",
            default=True,
        )
        payload: dict = SchemaField(
            description="Webhook payload data",
            default={},
            hidden=True,
        )

    class Output(BlockSchema):
        ping: dict = SchemaField(description="Raw webhook notification body")
        headers: dict = SchemaField(description="Webhook request headers")
        verified: bool = SchemaField(
            description="Whether the webhook signature was verified"
        )
        # Fields populated when auto_fetch is True
        payloads: list[dict] = SchemaField(
            description="Array of change payloads (when auto-fetch is enabled)",
            default=[],
        )
        next_cursor: int = SchemaField(
            description="Next cursor for pagination (when auto-fetch is enabled)",
            default=0,
        )
        might_have_more: bool = SchemaField(
            description="Whether there might be more payloads (when auto-fetch is enabled)",
            default=False,
        )

    def __init__(self):
        super().__init__(
            id="d0180ce6-ccb9-48c7-8256-b39e93e62801",
            description="Starts a flow whenever Airtable pings your webhook URL",
            categories={BlockCategory.INPUT},
            input_schema=self.Input,
            output_schema=self.Output,
            block_type=BlockType.WEBHOOK,
            webhook_config=BlockWebhookConfig(
                provider=ProviderName("airtable"),
                webhook_type="table_change",
                # event_filter_input="event_types",
                resource_format="{base_id}/{table_id_or_name}",
            ),
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        """
        Emit the webhook notification and, optionally, the fetched payloads.

        Always yields "ping", "headers" and "verified", followed by exactly one
        "payloads" / "next_cursor" / "might_have_more" triple. The triple
        carries fetched data only when auto-fetch is enabled, the notification
        names a base and a webhook, credentials are present in ``kwargs``, and
        the payloads request returns status 200; otherwise it carries empty
        defaults.

        Args:
            input_data: Validated block input; ``payload`` is the notification body.
            **kwargs: May contain ``webhook_headers`` and ``credentials``.
        """
        payload = input_data.payload

        # Extract headers from the webhook request (passed through kwargs)
        headers = kwargs.get("webhook_headers", {})

        # Hard-coded: per the original note, the webhook manager is expected to
        # raise before this block runs if signature verification fails.
        verified = True

        # Output basic webhook data
        yield "ping", payload
        yield "headers", headers
        yield "verified", verified

        # Defaults used whenever no payloads can be fetched.
        payloads: list[dict] = []
        next_cursor = 0
        might_have_more = False

        base_id = payload.get("base", {}).get("id")
        if input_data.auto_fetch and base_id:
            webhook_id = payload.get("webhook", {}).get("id", "")
            cursor = payload.get("cursor", 1)

            if webhook_id and cursor:
                # From here on, failures report the cursor we attempted to read.
                next_cursor = cursor
                credentials = kwargs.get("credentials")

                if credentials:
                    # Fetch payloads using the Airtable API
                    api_key = credentials.api_key.get_secret_value()

                    from backend.sdk import Requests

                    response = await Requests().get(
                        f"https://api.airtable.com/v0/bases/{base_id}/webhooks/{webhook_id}/payloads",
                        headers={"Authorization": f"Bearer {api_key}"},
                        params={"cursor": cursor},
                    )

                    if response.status == 200:
                        data = response.json()
                        payloads = data.get("payloads", [])
                        next_cursor = data.get("cursor", cursor)
                        might_have_more = data.get("mightHaveMore", False)

        # Single exit point so downstream nodes always receive all three
        # outputs, including when the webhook id or cursor is missing.
        yield "payloads", payloads
        yield "next_cursor", next_cursor
        yield "might_have_more", might_have_more
|
||||
229
autogpt_platform/backend/backend/blocks/airtable/webhooks.py
Normal file
229
autogpt_platform/backend/backend/blocks/airtable/webhooks.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""
|
||||
Airtable webhook management blocks.
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
class AirtableFetchWebhookPayloadsBlock(Block):
    """
    Pull the event payloads that have accumulated for a webhook.

    Intended to be run after a webhook notification arrives, or on a
    schedule as a way to poll for changes.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID")
        webhook_id: str = SchemaField(
            description="The webhook ID to fetch payloads for"
        )
        cursor: int = SchemaField(
            description="Cursor position (0 = all payloads)", default=0
        )

    class Output(BlockSchema):
        payloads: list[dict] = SchemaField(description="Array of webhook payloads")
        next_cursor: int = SchemaField(description="Next cursor for pagination")
        might_have_more: bool = SchemaField(
            description="Whether there might be more payloads"
        )

    def __init__(self):
        super().__init__(
            id="7172db38-e338-4561-836f-9fa282c99949",
            description="Fetch webhook payloads from Airtable",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        GET the webhook's payloads and emit them with the paging information.
        """
        token = credentials.api_key.get_secret_value()

        # A cursor is only sent when it is positive; otherwise no cursor
        # parameter is included in the request.
        query = {"cursor": input_data.cursor} if input_data.cursor > 0 else {}

        result = await Requests().get(
            f"https://api.airtable.com/v0/bases/{input_data.base_id}/webhooks/{input_data.webhook_id}/payloads",
            headers={"Authorization": f"Bearer {token}"},
            params=query,
        )
        body = result.json()

        yield "payloads", body.get("payloads", [])
        # Falls back to the input cursor when the response carries none.
        yield "next_cursor", body.get("cursor", input_data.cursor)
        yield "might_have_more", body.get("mightHaveMore", False)
|
||||
|
||||
|
||||
class AirtableRefreshWebhookBlock(Block):
    """
    Call the webhook refresh endpoint to push back a webhook's expiration.

    Meant for periodic use (for example from a daily scheduled job) so that
    long-lived webhooks do not expire.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID")
        webhook_id: str = SchemaField(description="The webhook ID to refresh")

    class Output(BlockSchema):
        expiration_time: str = SchemaField(
            description="New expiration time (ISO format)"
        )
        webhook: dict = SchemaField(description="Full webhook object")

    def __init__(self):
        super().__init__(
            id="5e82d957-02b8-47eb-8974-7bdaf8caff78",
            description="Refresh a webhook to extend its expiration",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        POST to the refresh endpoint and emit the new expiration and response.
        """
        token = credentials.api_key.get_secret_value()
        auth = {"Authorization": f"Bearer {token}"}

        result = await Requests().post(
            f"https://api.airtable.com/v0/bases/{input_data.base_id}/webhooks/{input_data.webhook_id}/refresh",
            headers=auth,
        )
        refreshed = result.json()

        # An empty string is emitted when the response has no "expirationTime".
        yield "expiration_time", refreshed.get("expirationTime", "")
        yield "webhook", refreshed
|
||||
|
||||
|
||||
class AirtableCreateWebhookBlock(Block):
    """
    Register a new webhook on an Airtable base.

    The notification URL and webhook specification are sent to Airtable; the
    created webhook's ID, MAC secret and expiration time are exposed as
    outputs alongside the full response.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID to monitor")
        notification_url: str = SchemaField(
            description="URL to receive webhook notifications"
        )
        # By default all three data types are watched.
        specification: dict = SchemaField(
            description="Webhook specification (filters, options)",
            default={
                "filters": {"dataTypes": ["tableData", "tableFields", "tableMetadata"]}
            },
        )

    class Output(BlockSchema):
        webhook: dict = SchemaField(description="Created webhook object")
        webhook_id: str = SchemaField(description="ID of the created webhook")
        mac_secret: str = SchemaField(
            description="MAC secret for signature verification"
        )
        expiration_time: str = SchemaField(description="Webhook expiration time")

    def __init__(self):
        super().__init__(
            id="b9f1f4ec-f4d1-4fbd-ab0b-b219c0e4da9a",
            description="Create a new Airtable webhook",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        POST the webhook definition to the base's ``/webhooks`` endpoint.

        Yields the decoded response as ``webhook`` followed by ``webhook_id``,
        ``mac_secret`` and ``expiration_time``, each ``""`` when the matching
        response key is missing.
        """
        token = credentials.api_key.get_secret_value()
        request_body = {
            "notificationUrl": input_data.notification_url,
            "specification": input_data.specification,
        }

        response = await Requests().post(
            f"https://api.airtable.com/v0/bases/{input_data.base_id}/webhooks",
            headers={"Authorization": f"Bearer {token}"},
            json=request_body,
        )
        created = response.json()

        yield "webhook", created
        yield "webhook_id", created.get("id", "")
        yield "mac_secret", created.get("macSecretBase64", "")
        yield "expiration_time", created.get("expirationTime", "")
|
||||
|
||||
|
||||
class AirtableDeleteWebhookBlock(Block):
    """
    Remove a webhook from an Airtable base.

    Once deleted, the webhook stops sending notifications.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID")
        webhook_id: str = SchemaField(description="The webhook ID to delete")

    class Output(BlockSchema):
        deleted: bool = SchemaField(
            description="Whether the webhook was successfully deleted"
        )

    def __init__(self):
        super().__init__(
            id="e4ded448-1515-4fe2-b93e-3e4db527df83",
            description="Delete an Airtable webhook",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        Issue a DELETE for the webhook and report the outcome.

        Yields ``deleted``: ``True`` when the response status is 200 or 204.
        """
        token = credentials.api_key.get_secret_value()
        webhook_url = f"https://api.airtable.com/v0/bases/{input_data.base_id}/webhooks/{input_data.webhook_id}"

        response = await Requests().delete(
            webhook_url,
            headers={"Authorization": f"Bearer {token}"},
        )

        # Both "OK" and "No Content" count as a successful deletion.
        yield "deleted", response.status in (200, 204)
|
||||
66
autogpt_platform/backend/backend/blocks/baas/__init__.py
Normal file
66
autogpt_platform/backend/backend/blocks/baas/__init__.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
Meeting BaaS integration for AutoGPT Platform.
|
||||
|
||||
This integration provides comprehensive access to the Meeting BaaS API,
|
||||
including:
|
||||
- Bot management for meeting recordings
|
||||
- Calendar integration (Google/Microsoft)
|
||||
- Event management and scheduling
|
||||
- Webhook triggers for real-time events
|
||||
"""
|
||||
|
||||
# Bot (Recording) Blocks
|
||||
from .bots import (
|
||||
BaasBotDeleteRecordingBlock,
|
||||
BaasBotFetchMeetingDataBlock,
|
||||
BaasBotFetchScreenshotsBlock,
|
||||
BaasBotJoinMeetingBlock,
|
||||
BaasBotLeaveMeetingBlock,
|
||||
BaasBotRetranscribeBlock,
|
||||
)
|
||||
|
||||
# Calendar Blocks
|
||||
from .calendars import (
|
||||
BaasCalendarConnectBlock,
|
||||
BaasCalendarDeleteBlock,
|
||||
BaasCalendarListAllBlock,
|
||||
BaasCalendarResyncAllBlock,
|
||||
BaasCalendarUpdateCredsBlock,
|
||||
)
|
||||
|
||||
# Event Blocks
|
||||
from .events import (
|
||||
BaasEventGetDetailsBlock,
|
||||
BaasEventListBlock,
|
||||
BaasEventPatchBotBlock,
|
||||
BaasEventScheduleBotBlock,
|
||||
BaasEventUnscheduleBotBlock,
|
||||
)
|
||||
|
||||
# Webhook Triggers
|
||||
from .triggers import BaasOnCalendarEventBlock, BaasOnMeetingEventBlock
|
||||
|
||||
# Public API of the Meeting BaaS block package, grouped by feature area.
__all__ = [
    # --- bots (recording) ---
    "BaasBotJoinMeetingBlock",
    "BaasBotLeaveMeetingBlock",
    "BaasBotFetchMeetingDataBlock",
    "BaasBotFetchScreenshotsBlock",
    "BaasBotDeleteRecordingBlock",
    "BaasBotRetranscribeBlock",
    # --- calendars ---
    "BaasCalendarConnectBlock",
    "BaasCalendarListAllBlock",
    "BaasCalendarUpdateCredsBlock",
    "BaasCalendarDeleteBlock",
    "BaasCalendarResyncAllBlock",
    # --- calendar events ---
    "BaasEventListBlock",
    "BaasEventGetDetailsBlock",
    "BaasEventScheduleBotBlock",
    "BaasEventUnscheduleBotBlock",
    "BaasEventPatchBotBlock",
    # --- webhook triggers ---
    "BaasOnMeetingEventBlock",
    "BaasOnCalendarEventBlock",
]
|
||||
16
autogpt_platform/backend/backend/blocks/baas/_config.py
Normal file
16
autogpt_platform/backend/backend/blocks/baas/_config.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
Shared configuration for all Meeting BaaS blocks using the SDK pattern.
|
||||
"""
|
||||
|
||||
from backend.sdk import BlockCostType, ProviderBuilder
|
||||
|
||||
from ._webhook import BaasWebhookManager
|
||||
|
||||
# Configure the Meeting BaaS provider with API key authentication
|
||||
# Assemble the provider step by step: API-key auth, the BaaS webhook manager,
# and a base cost of 5 per run (higher cost for meeting recording service).
_builder = ProviderBuilder("baas")
_builder = _builder.with_api_key("MEETING_BAAS_API_KEY", "Meeting BaaS API Key")
_builder = _builder.with_webhook_manager(BaasWebhookManager)
_builder = _builder.with_base_cost(5, BlockCostType.RUN)
baas = _builder.build()
del _builder  # keep the module namespace limited to `baas`
|
||||
83
autogpt_platform/backend/backend/blocks/baas/_webhook.py
Normal file
83
autogpt_platform/backend/backend/blocks/baas/_webhook.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
Webhook management for Meeting BaaS blocks.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Tuple
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
BaseWebhooksManager,
|
||||
Credentials,
|
||||
ProviderName,
|
||||
Webhook,
|
||||
)
|
||||
|
||||
|
||||
class BaasWebhookManager(BaseWebhooksManager):
    """Webhook manager for Meeting BaaS API."""

    PROVIDER_NAME = ProviderName("baas")

    class WebhookType(str, Enum):
        MEETING_EVENT = "meeting_event"
        CALENDAR_EVENT = "calendar_event"

    @classmethod
    async def validate_payload(cls, webhook: Webhook, request) -> Tuple[dict, str]:
        """
        Authenticate an incoming webhook request and extract its event type.

        When ``webhook.secret`` is set, the ``x-meeting-baas-api-key`` header
        must match it; the check runs before the body is parsed so that
        unauthenticated requests are rejected without further work.

        Returns:
            ``(payload, event_type)`` where ``event_type`` is the payload's
            ``"event"`` value, or ``"unknown"`` when absent.

        Raises:
            ValueError: if a secret is configured and the header is missing
                or does not match.
        """
        import hmac

        if webhook.secret:
            supplied_key = request.headers.get("x-meeting-baas-api-key")
            # Constant-time comparison avoids leaking the secret through
            # response timing; bytes are compared so non-ASCII values are
            # handled without a TypeError.
            if supplied_key is None or not hmac.compare_digest(
                supplied_key.encode("utf-8"), webhook.secret.encode("utf-8")
            ):
                raise ValueError("Invalid webhook API key")

        payload = await request.json()

        # Extract event type from payload
        event_type = payload.get("event", "unknown")

        return payload, event_type

    async def _register_webhook(
        self,
        credentials: Credentials,
        webhook_type: str,
        resource: str,
        events: list[str],
        ingress_url: str,
        secret: str,
    ) -> Tuple[str, dict]:
        """
        Register webhook with Meeting BaaS.

        Note: Meeting BaaS doesn't have a webhook registration API.
        Webhooks are configured per-bot or as account defaults.
        This returns a synthetic webhook ID.

        Returns:
            ``(webhook_id, config)`` where ``webhook_id`` is a freshly
            generated UUID4 string and ``config`` echoes the registration
            parameters plus the API key.

        Raises:
            ValueError: if ``credentials`` is not an ``APIKeyCredentials``.
        """
        if not isinstance(credentials, APIKeyCredentials):
            raise ValueError("Meeting BaaS webhooks require API key credentials")

        # Generate a synthetic webhook ID since BaaS doesn't provide one
        import uuid

        webhook_id = str(uuid.uuid4())

        # NOTE(review): the plaintext API key is placed in the returned config;
        # confirm that wherever this config is persisted is an acceptable
        # place to hold a secret.
        return webhook_id, {
            "webhook_type": webhook_type,
            "resource": resource,
            "events": events,
            "ingress_url": ingress_url,
            "api_key": credentials.api_key.get_secret_value(),
        }

    async def _deregister_webhook(
        self, webhook: Webhook, credentials: Credentials
    ) -> None:
        """
        Deregister webhook from Meeting BaaS.

        Note: Meeting BaaS doesn't have a webhook deregistration API.
        Webhooks are removed by updating bot/calendar configurations.
        """
        # No-op since BaaS doesn't have webhook deregistration
        pass
|
||||
367
autogpt_platform/backend/backend/blocks/baas/bots.py
Normal file
367
autogpt_platform/backend/backend/blocks/baas/bots.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""
|
||||
Meeting BaaS bot (recording) blocks.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import baas
|
||||
|
||||
|
||||
class BaasBotJoinMeetingBlock(Block):
    """
    Send a recording bot into a meeting, either right away or at start_time.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = baas.credentials_field(
            description="Meeting BaaS API credentials"
        )
        meeting_url: str = SchemaField(
            description="The URL of the meeting the bot should join"
        )
        bot_name: str = SchemaField(
            description="Display name for the bot in the meeting"
        )
        bot_image: str = SchemaField(
            description="URL to an image for the bot's avatar (16:9 ratio recommended)",
            default="",
        )
        entry_message: str = SchemaField(
            description="Chat message the bot will post upon entry", default=""
        )
        reserved: bool = SchemaField(
            description="Use a reserved bot slot (joins 4 min before meeting)",
            default=False,
        )
        start_time: Optional[int] = SchemaField(
            description="Unix timestamp (ms) when bot should join", default=None
        )
        speech_to_text: dict = SchemaField(
            description="Speech-to-text configuration", default={"provider": "Gladia"}
        )
        webhook_url: str = SchemaField(
            description="URL to receive webhook events for this bot", default=""
        )
        timeouts: dict = SchemaField(
            description="Automatic leave timeouts configuration", default={}
        )
        extra: dict = SchemaField(
            description="Custom metadata to attach to the bot", default={}
        )

    class Output(BlockSchema):
        bot_id: str = SchemaField(description="UUID of the deployed bot")
        join_response: dict = SchemaField(
            description="Full response from join operation"
        )

    def __init__(self):
        super().__init__(
            id="7f8e9d0c-1b2a-3c4d-5e6f-7a8b9c0d1e2f",
            description="Deploy a bot to join and record a meeting",
            categories={BlockCategory.COMMUNICATION},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        POST a join request to the ``/bots`` endpoint.

        Always-sent fields are the meeting URL, bot name, reserved flag and
        speech-to-text config; the remaining fields are sent only when set.
        Yields ``bot_id`` (``""`` if absent) and ``join_response``.
        """
        secret = credentials.api_key.get_secret_value()

        payload = {
            "meeting_url": input_data.meeting_url,
            "bot_name": input_data.bot_name,
            "reserved": input_data.reserved,
            "speech_to_text": input_data.speech_to_text,
        }

        # (request key, value, include?) — start_time is the only field
        # tested against None, so a timestamp of 0 would still be sent.
        optional_fields = (
            ("bot_image", input_data.bot_image, bool(input_data.bot_image)),
            ("entry_message", input_data.entry_message, bool(input_data.entry_message)),
            ("start_time", input_data.start_time, input_data.start_time is not None),
            ("webhook_url", input_data.webhook_url, bool(input_data.webhook_url)),
            ("automatic_leave", input_data.timeouts, bool(input_data.timeouts)),
            ("extra", input_data.extra, bool(input_data.extra)),
        )
        for key, value, include in optional_fields:
            if include:
                payload[key] = value

        response = await Requests().post(
            "https://api.meetingbaas.com/bots",
            headers={"x-meeting-baas-api-key": secret},
            json=payload,
        )
        joined = response.json()

        yield "bot_id", joined.get("bot_id", "")
        yield "join_response", joined
|
||||
|
||||
|
||||
class BaasBotLeaveMeetingBlock(Block):
    """
    Make a bot exit the call it is currently in.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = baas.credentials_field(
            description="Meeting BaaS API credentials"
        )
        bot_id: str = SchemaField(description="UUID of the bot to remove from meeting")

    class Output(BlockSchema):
        left: bool = SchemaField(description="Whether the bot successfully left")

    def __init__(self):
        super().__init__(
            id="8a9b0c1d-2e3f-4a5b-6c7d-8e9f0a1b2c3d",
            description="Remove a bot from an ongoing meeting",
            categories={BlockCategory.COMMUNICATION},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        Issue a DELETE for the bot.

        Yields ``left``: ``True`` when the response status is 200 or 204.
        """
        secret = credentials.api_key.get_secret_value()
        bot_url = f"https://api.meetingbaas.com/bots/{input_data.bot_id}"

        response = await Requests().delete(
            bot_url,
            headers={"x-meeting-baas-api-key": secret},
        )

        yield "left", response.status in (200, 204)
|
||||
|
||||
|
||||
class BaasBotFetchMeetingDataBlock(Block):
    """
    Pull MP4 URL, transcript & metadata for a completed meeting.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = baas.credentials_field(
            description="Meeting BaaS API credentials"
        )
        bot_id: str = SchemaField(description="UUID of the bot whose data to fetch")
        include_transcripts: bool = SchemaField(
            description="Include transcript data in response", default=True
        )

    class Output(BlockSchema):
        mp4_url: str = SchemaField(
            description="URL to download the meeting recording (time-limited)"
        )
        transcript: list = SchemaField(description="Meeting transcript data")
        metadata: dict = SchemaField(description="Meeting metadata and bot information")

    def __init__(self):
        super().__init__(
            id="9b0c1d2e-3f4a-5b6c-7d8e-9f0a1b2c3d4e",
            description="Retrieve recorded meeting data",
            categories={BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        GET ``/bots/meeting_data`` for the given bot.

        Yields ``mp4_url`` (``""`` if absent), ``transcript`` (``[]`` if
        absent) and ``metadata`` (``{}`` if absent). The last two are read
        from the response's ``bot_data`` object.
        """
        api_key = credentials.api_key.get_secret_value()

        # The boolean is sent as a lowercase "true"/"false" query string.
        params = {
            "bot_id": input_data.bot_id,
            "include_transcripts": str(input_data.include_transcripts).lower(),
        }

        response = await Requests().get(
            "https://api.meetingbaas.com/bots/meeting_data",
            headers={"x-meeting-baas-api-key": api_key},
            params=params,
        )

        data = response.json()

        # `or {}` also covers an explicit `"bot_data": null`, which a plain
        # `.get("bot_data", {})` would pass through as None and then crash on.
        bot_data = data.get("bot_data") or {}

        yield "mp4_url", data.get("mp4", "")
        yield "transcript", bot_data.get("transcripts", [])
        yield "metadata", bot_data.get("bot", {})
|
||||
|
||||
|
||||
class BaasBotFetchScreenshotsBlock(Block):
    """
    Retrieve the list of screenshots a bot captured during its call.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = baas.credentials_field(
            description="Meeting BaaS API credentials"
        )
        bot_id: str = SchemaField(
            description="UUID of the bot whose screenshots to fetch"
        )

    class Output(BlockSchema):
        screenshots: list[dict] = SchemaField(
            description="Array of screenshot objects with date and url"
        )

    def __init__(self):
        super().__init__(
            id="0c1d2e3f-4a5b-6c7d-8e9f-0a1b2c3d4e5f",
            description="Retrieve screenshots captured during a meeting",
            categories={BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        GET the bot's ``/screenshots`` endpoint.

        Yields ``screenshots``: the decoded JSON response, passed through
        unchanged.
        """
        secret = credentials.api_key.get_secret_value()
        screenshots_url = (
            f"https://api.meetingbaas.com/bots/{input_data.bot_id}/screenshots"
        )

        response = await Requests().get(
            screenshots_url,
            headers={"x-meeting-baas-api-key": secret},
        )

        yield "screenshots", response.json()
|
||||
|
||||
|
||||
class BaasBotDeleteRecordingBlock(Block):
    """
    Erase a bot's MP4 and transcript data (privacy / storage management).
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = baas.credentials_field(
            description="Meeting BaaS API credentials"
        )
        bot_id: str = SchemaField(description="UUID of the bot whose data to delete")

    class Output(BlockSchema):
        deleted: bool = SchemaField(
            description="Whether the data was successfully deleted"
        )

    def __init__(self):
        super().__init__(
            id="1d2e3f4a-5b6c-7d8e-9f0a-1b2c3d4e5f6a",
            description="Permanently delete a meeting's recorded data",
            categories={BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        POST to the bot's ``/delete_data`` endpoint.

        Yields ``deleted``: ``True`` only when the response status is
        exactly 200.
        """
        secret = credentials.api_key.get_secret_value()
        delete_url = f"https://api.meetingbaas.com/bots/{input_data.bot_id}/delete_data"

        response = await Requests().post(
            delete_url,
            headers={"x-meeting-baas-api-key": secret},
        )

        was_deleted = response.status == 200
        yield "deleted", was_deleted
|
||||
|
||||
|
||||
class BaasBotRetranscribeBlock(Block):
    """
    Run speech-to-text again on a past recording, with another provider or
    different settings.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = baas.credentials_field(
            description="Meeting BaaS API credentials"
        )
        bot_id: str = SchemaField(
            description="UUID of the bot whose audio to retranscribe"
        )
        provider: str = SchemaField(
            description="Speech-to-text provider to use (e.g., Gladia, Deepgram)"
        )
        webhook_url: str = SchemaField(
            description="URL to receive transcription complete event", default=""
        )
        custom_options: dict = SchemaField(
            description="Provider-specific options", default={}
        )

    class Output(BlockSchema):
        job_id: Optional[str] = SchemaField(
            description="Transcription job ID if available"
        )
        accepted: bool = SchemaField(
            description="Whether the retranscription request was accepted"
        )

    def __init__(self):
        super().__init__(
            id="2e3f4a5b-6c7d-8e9f-0a1b-2c3d4e5f6a7b",
            description="Re-run transcription on a meeting's audio",
            categories={BlockCategory.AI},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        POST a retranscription request to ``/bots/retranscribe``.

        Yields ``job_id`` (read from the body only on a 200 response,
        otherwise ``None``) and ``accepted`` (``True`` for status 200 or 202).
        """
        secret = credentials.api_key.get_secret_value()

        payload = {"bot_uuid": input_data.bot_id, "provider": input_data.provider}
        if input_data.webhook_url:
            payload["webhook_url"] = input_data.webhook_url
        if input_data.custom_options:
            # Options are merged at the top level of the request body, so a
            # key named like one above replaces that value.
            payload.update(input_data.custom_options)

        response = await Requests().post(
            "https://api.meetingbaas.com/bots/retranscribe",
            headers={"x-meeting-baas-api-key": secret},
            json=payload,
        )

        status = response.status
        # Only a 200 response is parsed for a job ID; a 202 leaves it as None.
        job_id = response.json().get("job_id") if status == 200 else None

        yield "job_id", job_id
        yield "accepted", status in (200, 202)
|
||||
265
autogpt_platform/backend/backend/blocks/baas/calendars.py
Normal file
265
autogpt_platform/backend/backend/blocks/baas/calendars.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
Meeting BaaS calendar blocks.
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import baas
|
||||
|
||||
|
||||
class BaasCalendarConnectBlock(Block):
    """
    Connect a Google or Microsoft calendar to Meeting BaaS (one-time setup).
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = baas.credentials_field(
            description="Meeting BaaS API credentials"
        )
        oauth_client_id: str = SchemaField(description="OAuth client ID from provider")
        oauth_client_secret: str = SchemaField(description="OAuth client secret")
        oauth_refresh_token: str = SchemaField(
            description="OAuth refresh token with calendar access"
        )
        platform: str = SchemaField(
            description="Calendar platform (Google or Microsoft)"
        )
        calendar_email_or_id: str = SchemaField(
            description="Specific calendar email/ID to connect", default=""
        )

    class Output(BlockSchema):
        calendar_id: str = SchemaField(description="UUID of the connected calendar")
        calendar_obj: dict = SchemaField(description="Full calendar object")

    def __init__(self):
        super().__init__(
            id="3f4a5b6c-7d8e-9f0a-1b2c-3d4e5f6a7b8c",
            description="Connect a Google or Microsoft calendar for integration",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        POST the OAuth details to the ``/calendars`` endpoint.

        Yields ``calendar_id`` (the response's ``uuid``, ``""`` if absent)
        and ``calendar_obj`` (the decoded response).
        """
        secret = credentials.api_key.get_secret_value()

        payload = {
            "oauth_client_id": input_data.oauth_client_id,
            "oauth_client_secret": input_data.oauth_client_secret,
            "oauth_refresh_token": input_data.oauth_refresh_token,
            "platform": input_data.platform,
        }
        # The specific calendar is sent under "calendar_email", only when set.
        if input_data.calendar_email_or_id:
            payload["calendar_email"] = input_data.calendar_email_or_id

        response = await Requests().post(
            "https://api.meetingbaas.com/calendars",
            headers={"x-meeting-baas-api-key": secret},
            json=payload,
        )
        connected = response.json()

        yield "calendar_id", connected.get("uuid", "")
        yield "calendar_obj", connected
|
||||
|
||||
|
||||
class BaasCalendarListAllBlock(Block):
    """
    List every calendar connected to the account.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = baas.credentials_field(
            description="Meeting BaaS API credentials"
        )

    class Output(BlockSchema):
        calendars: list[dict] = SchemaField(
            description="Array of connected calendar objects"
        )

    def __init__(self):
        super().__init__(
            id="4a5b6c7d-8e9f-0a1b-2c3d-4e5f6a7b8c9d",
            description="List all integrated calendars",
            categories={BlockCategory.SEARCH},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        GET the ``/calendars`` endpoint.

        Yields ``calendars``: the decoded JSON response, passed through
        unchanged.
        """
        secret = credentials.api_key.get_secret_value()
        auth_headers = {"x-meeting-baas-api-key": secret}

        response = await Requests().get(
            "https://api.meetingbaas.com/calendars",
            headers=auth_headers,
        )

        yield "calendars", response.json()
|
||||
|
||||
|
||||
class BaasCalendarUpdateCredsBlock(Block):
    """
    Replace the OAuth credentials of a connected calendar, or switch its
    platform.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = baas.credentials_field(
            description="Meeting BaaS API credentials"
        )
        calendar_id: str = SchemaField(description="UUID of the calendar to update")
        oauth_client_id: str = SchemaField(
            description="New OAuth client ID", default=""
        )
        oauth_client_secret: str = SchemaField(
            description="New OAuth client secret", default=""
        )
        oauth_refresh_token: str = SchemaField(
            description="New OAuth refresh token", default=""
        )
        platform: str = SchemaField(description="New platform if switching", default="")

    class Output(BlockSchema):
        calendar_obj: dict = SchemaField(description="Updated calendar object")

    def __init__(self):
        super().__init__(
            id="5b6c7d8e-9f0a-1b2c-3d4e-5f6a7b8c9d0e",
            description="Update calendar credentials or platform",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        PATCH the calendar with whichever fields were supplied.

        Empty inputs are left out of the request body entirely.
        Yields ``calendar_obj``: the decoded JSON response.
        """
        secret = credentials.api_key.get_secret_value()

        candidate_fields = {
            "oauth_client_id": input_data.oauth_client_id,
            "oauth_client_secret": input_data.oauth_client_secret,
            "oauth_refresh_token": input_data.oauth_refresh_token,
            "platform": input_data.platform,
        }
        # Keep only the fields the caller actually filled in.
        changes = {name: value for name, value in candidate_fields.items() if value}

        response = await Requests().patch(
            f"https://api.meetingbaas.com/calendars/{input_data.calendar_id}",
            headers={"x-meeting-baas-api-key": secret},
            json=changes,
        )

        yield "calendar_obj", response.json()
|
||||
|
||||
|
||||
class BaasCalendarDeleteBlock(Block):
    """
    Disconnect a calendar; future bots for it are unscheduled.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = baas.credentials_field(
            description="Meeting BaaS API credentials"
        )
        calendar_id: str = SchemaField(description="UUID of the calendar to delete")

    class Output(BlockSchema):
        deleted: bool = SchemaField(
            description="Whether the calendar was successfully deleted"
        )

    def __init__(self):
        super().__init__(
            id="6c7d8e9f-0a1b-2c3d-4e5f-6a7b8c9d0e1f",
            description="Remove a calendar integration",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        Issue a DELETE for the calendar.

        Yields ``deleted``: ``True`` when the response status is 200 or 204.
        """
        secret = credentials.api_key.get_secret_value()
        calendar_url = f"https://api.meetingbaas.com/calendars/{input_data.calendar_id}"

        response = await Requests().delete(
            calendar_url,
            headers={"x-meeting-baas-api-key": secret},
        )

        yield "deleted", response.status in (200, 204)
|
||||
|
||||
|
||||
class BaasCalendarResyncAllBlock(Block):
    """
    Trigger an immediate full sync of all calendars (maintenance operation).
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = baas.credentials_field(
            description="Meeting BaaS API credentials"
        )

    class Output(BlockSchema):
        synced_ids: list[str] = SchemaField(
            description="Calendar UUIDs that synced successfully"
        )
        errors: list[list] = SchemaField(
            description="Array of [calendar_id, error_message] tuples"
        )

    def __init__(self):
        super().__init__(
            id="7d8e9f0a-1b2c-3d4e-5f6a-7b8c9d0e1f2a",
            description="Force immediate re-sync of all connected calendars",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        POST to the resync-all endpoint.

        Yields ``synced_ids`` (the response's ``synced_calendars``) and
        ``errors``; each defaults to ``[]`` when missing from the response.
        """
        secret = credentials.api_key.get_secret_value()

        response = await Requests().post(
            "https://api.meetingbaas.com/internal/calendar/resync_all",
            headers={"x-meeting-baas-api-key": secret},
        )
        outcome = response.json()

        yield "synced_ids", outcome.get("synced_calendars", [])
        yield "errors", outcome.get("errors", [])
|
||||
276
autogpt_platform/backend/backend/blocks/baas/events.py
Normal file
276
autogpt_platform/backend/backend/blocks/baas/events.py
Normal file
@@ -0,0 +1,276 @@
|
||||
"""
|
||||
Meeting BaaS calendar event blocks.
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import baas
|
||||
|
||||
|
||||
class BaasEventListBlock(Block):
    """
    List the events of one calendar, optionally bounded by start date.

    Empty-string inputs (date bounds, cursor) are treated as "not supplied"
    and are left out of the query string.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = baas.credentials_field(
            description="Meeting BaaS API credentials"
        )
        calendar_id: str = SchemaField(
            description="UUID of the calendar to list events from"
        )
        start_date_gte: str = SchemaField(
            description="ISO date string for start date (greater than or equal)",
            default="",
        )
        start_date_lte: str = SchemaField(
            description="ISO date string for start date (less than or equal)",
            default="",
        )
        cursor: str = SchemaField(
            description="Pagination cursor from previous request", default=""
        )

    class Output(BlockSchema):
        events: list[dict] = SchemaField(description="Array of calendar events")
        next_cursor: str = SchemaField(description="Cursor for next page of results")

    def __init__(self):
        super().__init__(
            id="8e9f0a1b-2c3d-4e5f-6a7b-8c9d0e1f2a3b",
            description="List calendar events with optional date filtering",
            categories={BlockCategory.SEARCH},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Fetch one page of events and yield ``events`` then ``next_cursor``."""
        secret = credentials.api_key.get_secret_value()

        # calendar_id is always sent; the remaining filters only when non-empty.
        optional_filters = {
            "start_date_gte": input_data.start_date_gte,
            "start_date_lte": input_data.start_date_lte,
            "cursor": input_data.cursor,
        }
        query = {"calendar_id": input_data.calendar_id}
        query.update(
            {name: value for name, value in optional_filters.items() if value}
        )

        listing = await Requests().get(
            "https://api.meetingbaas.com/calendar_events",
            headers={"x-meeting-baas-api-key": secret},
            params=query,
        )
        page = listing.json()

        yield "events", page.get("events", [])
        # The API's "next" field is surfaced as the block's next_cursor output.
        yield "next_cursor", page.get("next", "")
|
||||
|
||||
|
||||
class BaasEventGetDetailsBlock(Block):
    """
    Retrieve the complete object for a single calendar event.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = baas.credentials_field(
            description="Meeting BaaS API credentials"
        )
        event_id: str = SchemaField(description="UUID of the event to retrieve")

    class Output(BlockSchema):
        event: dict = SchemaField(description="Full event object with all details")

    def __init__(self):
        super().__init__(
            id="9f0a1b2c-3d4e-5f6a-7b8c-9d0e1f2a3b4c",
            description="Get detailed information for a specific calendar event",
            categories={BlockCategory.SEARCH},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """GET the event by id and yield the decoded JSON body as ``event``."""
        auth_headers = {
            "x-meeting-baas-api-key": credentials.api_key.get_secret_value()
        }

        details = await Requests().get(
            f"https://api.meetingbaas.com/calendar_events/{input_data.event_id}",
            headers=auth_headers,
        )

        yield "event", details.json()
|
||||
|
||||
|
||||
class BaasEventScheduleBotBlock(Block):
    """
    Attach a bot configuration to a calendar event so it is recorded
    automatically.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = baas.credentials_field(
            description="Meeting BaaS API credentials"
        )
        event_id: str = SchemaField(description="UUID of the event to schedule bot for")
        all_occurrences: bool = SchemaField(
            description="Apply to all occurrences of recurring event", default=False
        )
        bot_config: dict = SchemaField(
            description="Bot configuration (same as Bot → Join Meeting)"
        )

    class Output(BlockSchema):
        events: Union[dict, list[dict]] = SchemaField(
            description="Updated event(s) with bot scheduled"
        )

    def __init__(self):
        super().__init__(
            id="0a1b2c3d-4e5f-6a7b-8c9d-0e1f2a3b4c5d",
            description="Schedule a recording bot for a calendar event",
            categories={BlockCategory.COMMUNICATION},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """POST ``bot_config`` to the event's bot endpoint and yield ``events``."""
        secret = credentials.api_key.get_secret_value()

        # The flag is sent in the query string as lowercase "true"/"false".
        occurrence_flag = str(input_data.all_occurrences).lower()

        scheduled = await Requests().post(
            f"https://api.meetingbaas.com/calendar_events/{input_data.event_id}/bot",
            headers={"x-meeting-baas-api-key": secret},
            params={"all_occurrences": occurrence_flag},
            json=input_data.bot_config,
        )

        yield "events", scheduled.json()
|
||||
|
||||
|
||||
class BaasEventUnscheduleBotBlock(Block):
    """
    Remove the scheduled bot from an event, or from the whole series.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = baas.credentials_field(
            description="Meeting BaaS API credentials"
        )
        event_id: str = SchemaField(
            description="UUID of the event to unschedule bot from"
        )
        all_occurrences: bool = SchemaField(
            description="Apply to all occurrences of recurring event", default=False
        )

    class Output(BlockSchema):
        events: Union[dict, list[dict]] = SchemaField(
            description="Updated event(s) with bot removed"
        )

    def __init__(self):
        super().__init__(
            id="1b2c3d4e-5f6a-7b8c-9d0e-1f2a3b4c5d6e",
            description="Cancel a scheduled recording for an event",
            categories={BlockCategory.COMMUNICATION},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """DELETE the event's bot and yield the decoded reply as ``events``."""
        secret = credentials.api_key.get_secret_value()

        # The flag is sent in the query string as lowercase "true"/"false".
        occurrence_flag = str(input_data.all_occurrences).lower()

        removal = await Requests().delete(
            f"https://api.meetingbaas.com/calendar_events/{input_data.event_id}/bot",
            headers={"x-meeting-baas-api-key": secret},
            params={"all_occurrences": occurrence_flag},
        )

        yield "events", removal.json()
|
||||
|
||||
|
||||
class BaasEventPatchBotBlock(Block):
    """
    Modify an already-scheduled bot configuration.

    Sends ``bot_patch`` as the JSON body of a PATCH request to the event's
    bot endpoint and emits the decoded JSON reply.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = baas.credentials_field(
            description="Meeting BaaS API credentials"
        )
        event_id: str = SchemaField(description="UUID of the event with scheduled bot")
        all_occurrences: bool = SchemaField(
            description="Apply to all occurrences of recurring event", default=False
        )
        bot_patch: dict = SchemaField(description="Bot configuration fields to update")

    class Output(BlockSchema):
        events: Union[dict, list[dict]] = SchemaField(
            description="Updated event(s) with modified bot config"
        )

    def __init__(self):
        super().__init__(
            id="2c3d4e5f-6a7b-8c9d-0e1f-2a3b4c5d6e7f",
            description="Update configuration of a scheduled bot",
            categories={BlockCategory.COMMUNICATION},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """PATCH the scheduled bot and yield the decoded reply as ``events``."""
        api_key = credentials.api_key.get_secret_value()

        # all_occurrences is a bool with a default, so it is never None; the
        # flag is always sent, exactly as the schedule/unschedule blocks do.
        params = {"all_occurrences": str(input_data.all_occurrences).lower()}

        response = await Requests().patch(
            f"https://api.meetingbaas.com/calendar_events/{input_data.event_id}/bot",
            headers={"x-meeting-baas-api-key": api_key},
            params=params,
            json=input_data.bot_patch,
        )

        yield "events", response.json()
|
||||
185
autogpt_platform/backend/backend/blocks/baas/triggers.py
Normal file
185
autogpt_platform/backend/backend/blocks/baas/triggers.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
Meeting BaaS webhook trigger blocks.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
BlockWebhookConfig,
|
||||
CredentialsMetaInput,
|
||||
ProviderName,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import baas
|
||||
|
||||
|
||||
class BaasOnMeetingEventBlock(Block):
    """
    Webhook trigger for meeting-related Meeting BaaS events:
    bot.status_change, complete, failed, transcription_complete.

    Events whose type is not enabled in the ``events`` filter, or is not one
    of the four known types, produce no output.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = baas.credentials_field(
            description="Meeting BaaS API credentials"
        )
        webhook_url: str = SchemaField(
            description="URL to receive webhooks (auto-generated)",
            default="",
            hidden=True,
        )

        class EventsFilter(BaseModel):
            """Meeting event types to subscribe to"""

            bot_status_change: bool = SchemaField(
                description="Bot status changes", default=True
            )
            complete: bool = SchemaField(description="Meeting completed", default=True)
            failed: bool = SchemaField(description="Meeting failed", default=True)
            transcription_complete: bool = SchemaField(
                description="Transcription completed", default=True
            )

        events: EventsFilter = SchemaField(
            title="Events", description="The events to subscribe to"
        )

        payload: dict = SchemaField(
            description="Webhook payload data",
            default={},
            hidden=True,
        )

    class Output(BlockSchema):
        event_type: str = SchemaField(description="Type of event received")
        data: dict = SchemaField(description="Event data payload")

    def __init__(self):
        super().__init__(
            id="3d4e5f6a-7b8c-9d0e-1f2a-3b4c5d6e7f8a",
            description="Receive meeting events from Meeting BaaS webhooks",
            categories={BlockCategory.INPUT},
            input_schema=self.Input,
            output_schema=self.Output,
            block_type=BlockType.WEBHOOK,
            webhook_config=BlockWebhookConfig(
                provider=ProviderName("baas"),
                webhook_type="meeting_event",
                event_filter_input="events",
                resource_format="meeting",
            ),
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        """Yield ``event_type`` and ``data`` when the event type is enabled."""
        incoming = input_data.payload
        subscriptions = input_data.events
        kind = incoming.get("event", "unknown")

        # Wire-format event name -> whether the user subscribed to it.
        enabled_by_kind = {
            "bot.status_change": subscriptions.bot_status_change,
            "complete": subscriptions.complete,
            "failed": subscriptions.failed,
            "transcription_complete": subscriptions.transcription_complete,
        }

        # Unknown event names are treated as disabled.
        if enabled_by_kind.get(kind, False):
            yield "event_type", kind
            yield "data", incoming.get("data", {})
|
||||
|
||||
|
||||
class BaasOnCalendarEventBlock(Block):
    """
    Webhook trigger for calendar-related Meeting BaaS events:
    event.added, event.updated, event.deleted, calendar.synced.

    Events whose type is not enabled in the ``events`` filter, or is not one
    of the four known types, produce no output.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = baas.credentials_field(
            description="Meeting BaaS API credentials"
        )
        webhook_url: str = SchemaField(
            description="URL to receive webhooks (auto-generated)",
            default="",
            hidden=True,
        )

        class EventsFilter(BaseModel):
            """Calendar event types to subscribe to"""

            event_added: bool = SchemaField(
                description="Calendar event added", default=True
            )
            event_updated: bool = SchemaField(
                description="Calendar event updated", default=True
            )
            event_deleted: bool = SchemaField(
                description="Calendar event deleted", default=True
            )
            calendar_synced: bool = SchemaField(
                description="Calendar synced", default=True
            )

        events: EventsFilter = SchemaField(
            title="Events", description="The events to subscribe to"
        )

        payload: dict = SchemaField(
            description="Webhook payload data",
            default={},
            hidden=True,
        )

    class Output(BlockSchema):
        event_type: str = SchemaField(description="Type of event received")
        data: dict = SchemaField(description="Event data payload")

    def __init__(self):
        super().__init__(
            id="4e5f6a7b-8c9d-0e1f-2a3b-4c5d6e7f8a9b",
            description="Receive calendar events from Meeting BaaS webhooks",
            categories={BlockCategory.INPUT},
            input_schema=self.Input,
            output_schema=self.Output,
            block_type=BlockType.WEBHOOK,
            webhook_config=BlockWebhookConfig(
                provider=ProviderName("baas"),
                webhook_type="calendar_event",
                event_filter_input="events",
                resource_format="calendar",
            ),
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        """Yield ``event_type`` and ``data`` when the event type is enabled."""
        incoming = input_data.payload
        subscriptions = input_data.events
        kind = incoming.get("event", "unknown")

        # Wire-format event name -> whether the user subscribed to it.
        enabled_by_kind = {
            "event.added": subscriptions.event_added,
            "event.updated": subscriptions.event_updated,
            "event.deleted": subscriptions.event_deleted,
            "calendar.synced": subscriptions.calendar_synced,
        }

        # Unknown event names are treated as disabled.
        if enabled_by_kind.get(kind, False):
            yield "event_type", kind
            yield "data", incoming.get("data", {})
|
||||
@@ -39,13 +39,11 @@ class FileStoreBlock(Block):
|
||||
input_data: Input,
|
||||
*,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
yield "file_out", await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.file_in,
|
||||
user_id=user_id,
|
||||
return_content=input_data.base_64,
|
||||
)
|
||||
|
||||
@@ -188,3 +186,31 @@ class UniversalTypeConverterBlock(Block):
|
||||
yield "value", converted_value
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to convert value: {str(e)}"
|
||||
|
||||
|
||||
class ReverseListOrderBlock(Block):
    """
    Emit the elements of the input list in the opposite order.

    The input list itself is left untouched; a new list is produced.
    """

    class Input(BlockSchema):
        input_list: list[Any] = SchemaField(description="The list to reverse")

    class Output(BlockSchema):
        reversed_list: list[Any] = SchemaField(description="The list in reversed order")

    def __init__(self):
        super().__init__(
            id="422cb708-3109-4277-bfe3-bc2ae5812777",
            description="Reverses the order of elements in a list",
            categories={BlockCategory.BASIC},
            input_schema=ReverseListOrderBlock.Input,
            output_schema=ReverseListOrderBlock.Output,
            test_input={"input_list": [1, 2, 3, 4, 5]},
            test_output=[("reversed_list", [5, 4, 3, 2, 1])],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        """Yield a reversed copy of ``input_list`` as ``reversed_list``."""
        yield "reversed_list", list(reversed(input_data.input_list))
|
||||
|
||||
109
autogpt_platform/backend/backend/blocks/csv.py
Normal file
109
autogpt_platform/backend/backend/blocks/csv.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import ContributorDetails, SchemaField
|
||||
|
||||
|
||||
class ReadCsvBlock(Block):
    """
    Parse CSV text and emit each data row plus the full list of rows.

    Rows are emitted as dicts keyed by header name when ``has_header`` is
    true, otherwise by the zero-based column index as a string. Columns of a
    row that extend past the header also fall back to the index key.
    """

    class Input(BlockSchema):
        contents: str = SchemaField(
            description="The contents of the CSV file to read",
            placeholder="a, b, c\n1,2,3\n4,5,6",
        )
        delimiter: str = SchemaField(
            description="The delimiter used in the CSV file",
            default=",",
        )
        quotechar: str = SchemaField(
            description="The character used to quote fields",
            default='"',
        )
        escapechar: str = SchemaField(
            description="The character used to escape the delimiter",
            default="\\",
        )
        has_header: bool = SchemaField(
            description="Whether the CSV file has a header row",
            default=True,
        )
        skip_rows: int = SchemaField(
            description="The number of rows to skip from the start of the file",
            default=0,
        )
        strip: bool = SchemaField(
            description="Whether to strip whitespace from the values",
            default=True,
        )
        # Entries are matched against the column's header name and against
        # its zero-based index written as a string (e.g. "0").
        skip_columns: list[str] = SchemaField(
            description="The columns to skip from the start of the row",
            default_factory=list,
        )

    class Output(BlockSchema):
        row: dict[str, str] = SchemaField(
            description="The data produced from each row in the CSV file"
        )
        all_data: list[dict[str, str]] = SchemaField(
            description="All the data in the CSV file as a list of rows"
        )

    def __init__(self):
        super().__init__(
            id="acf7625e-d2cb-4941-bfeb-2819fc6fc015",
            input_schema=ReadCsvBlock.Input,
            output_schema=ReadCsvBlock.Output,
            description="Reads a CSV file and outputs the data as a list of dictionaries and individual rows via rows.",
            contributors=[ContributorDetails(name="Nicholas Tindle")],
            categories={BlockCategory.TEXT, BlockCategory.DATA},
            test_input={
                "contents": "a, b, c\n1,2,3\n4,5,6",
            },
            test_output=[
                ("row", {"a": "1", "b": "2", "c": "3"}),
                ("row", {"a": "4", "b": "5", "c": "6"}),
                (
                    "all_data",
                    [
                        {"a": "1", "b": "2", "c": "3"},
                        {"a": "4", "b": "5", "c": "6"},
                    ],
                ),
            ],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        """
        Yield one ``row`` per data line, then ``all_data`` with every row.

        The header (if any) is consumed first, then ``skip_rows`` further
        lines are discarded. Empty or too-short input yields no rows and an
        empty ``all_data`` instead of raising.
        """
        import csv
        from io import StringIO

        reader = csv.reader(
            StringIO(input_data.contents),
            delimiter=input_data.delimiter,
            quotechar=input_data.quotechar,
            escapechar=input_data.escapechar,
        )

        # next(..., None) rather than a bare next(): a StopIteration escaping
        # an async generator body is converted into RuntimeError.
        header = None
        if input_data.has_header:
            header = next(reader, None)
            if header is not None and input_data.strip:
                header = [h.strip() for h in header]

        for _ in range(input_data.skip_rows):
            if next(reader, None) is None:
                break  # fewer lines than skip_rows: nothing left to read

        # skip_columns holds strings, so compare against string keys; the
        # previous integer-index comparison could never match.
        skipped = set(input_data.skip_columns)

        def process_row(row):
            data = {}
            for i, value in enumerate(row):
                index_key = str(i)
                # Use the header name when one exists for this column.
                key = header[i] if header and i < len(header) else index_key
                if key in skipped or index_key in skipped:
                    continue
                data[key] = value.strip() if input_data.strip else value
            return data

        all_data = []
        for row in reader:
            processed_row = process_row(row)
            all_data.append(processed_row)
            yield "row", processed_row

        yield "all_data", all_data
|
||||
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
ElevenLabs integration blocks for AutoGPT Platform.
|
||||
"""
|
||||
|
||||
# Speech generation blocks
|
||||
from .speech import (
|
||||
ElevenLabsGenerateSpeechBlock,
|
||||
ElevenLabsGenerateSpeechWithTimestampsBlock,
|
||||
)
|
||||
|
||||
# Speech-to-text blocks
|
||||
from .transcription import (
|
||||
ElevenLabsTranscribeAudioAsyncBlock,
|
||||
ElevenLabsTranscribeAudioSyncBlock,
|
||||
)
|
||||
|
||||
# Webhook trigger blocks
|
||||
from .triggers import ElevenLabsWebhookTriggerBlock
|
||||
|
||||
# Utility blocks
|
||||
from .utility import ElevenLabsGetUsageStatsBlock, ElevenLabsListModelsBlock
|
||||
|
||||
# Voice management blocks
|
||||
from .voices import (
|
||||
ElevenLabsCreateVoiceCloneBlock,
|
||||
ElevenLabsDeleteVoiceBlock,
|
||||
ElevenLabsGetVoiceDetailsBlock,
|
||||
ElevenLabsListVoicesBlock,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Voice management
|
||||
"ElevenLabsListVoicesBlock",
|
||||
"ElevenLabsGetVoiceDetailsBlock",
|
||||
"ElevenLabsCreateVoiceCloneBlock",
|
||||
"ElevenLabsDeleteVoiceBlock",
|
||||
# Speech generation
|
||||
"ElevenLabsGenerateSpeechBlock",
|
||||
"ElevenLabsGenerateSpeechWithTimestampsBlock",
|
||||
# Speech-to-text
|
||||
"ElevenLabsTranscribeAudioSyncBlock",
|
||||
"ElevenLabsTranscribeAudioAsyncBlock",
|
||||
# Utility
|
||||
"ElevenLabsListModelsBlock",
|
||||
"ElevenLabsGetUsageStatsBlock",
|
||||
# Webhook triggers
|
||||
"ElevenLabsWebhookTriggerBlock",
|
||||
]
|
||||
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
Shared configuration for all ElevenLabs blocks using the SDK pattern.
|
||||
"""
|
||||
|
||||
from backend.sdk import BlockCostType, ProviderBuilder
|
||||
|
||||
from ._webhook import ElevenLabsWebhookManager
|
||||
|
||||
# Shared ElevenLabs provider object, built once at import time: API-key
# authentication (registered under ELEVENLABS_API_KEY), the ElevenLabs
# webhook manager, and a base cost of 2 of type BlockCostType.RUN.
elevenlabs = (
    ProviderBuilder("elevenlabs")
    .with_api_key("ELEVENLABS_API_KEY", "ElevenLabs API Key")
    .with_webhook_manager(ElevenLabsWebhookManager)
    .with_base_cost(2, BlockCostType.RUN)  # Base cost for API calls
    .build()
)
|
||||
@@ -0,0 +1,82 @@
|
||||
"""
|
||||
ElevenLabs webhook manager for handling webhook events.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
from typing import Tuple
|
||||
|
||||
from backend.data.model import Credentials
|
||||
from backend.sdk import BaseWebhooksManager, ProviderName, Webhook
|
||||
|
||||
|
||||
class ElevenLabsWebhookManager(BaseWebhooksManager):
    """Manages ElevenLabs webhook events."""

    PROVIDER_NAME = ProviderName("elevenlabs")

    @classmethod
    async def validate_payload(cls, webhook: Webhook, request) -> Tuple[dict, str]:
        """
        Validate incoming webhook payload and signature.

        When the webhook has a secret and its config carries a
        ``webhook_secret``, the request must include an
        ``x-elevenlabs-signature`` header equal to the hex HMAC-SHA256 of the
        raw body. A missing or mismatching signature is rejected.

        Returns:
            The decoded JSON payload and its ``type`` field ("unknown" when
            the field is absent).

        Raises:
            ValueError: If signature verification is required and the
                signature is missing or does not match.
        """
        payload = await request.json()

        # Verify webhook signature if configured
        if webhook.secret:
            webhook_secret = webhook.config.get("webhook_secret")
            if webhook_secret:
                # Get the raw body for signature verification
                body = await request.body()

                expected_signature = hmac.new(
                    webhook_secret.encode(), body, hashlib.sha256
                ).hexdigest()

                # NOTE(review): confirm the header name and signature encoding
                # against the ElevenLabs webhook documentation.
                signature = request.headers.get("x-elevenlabs-signature")

                # Fail closed: an absent header must not bypass verification.
                if not signature:
                    raise ValueError("Missing webhook signature")

                # Compare as bytes so a non-ASCII header value yields a
                # mismatch instead of a TypeError from compare_digest.
                if not hmac.compare_digest(
                    signature.encode(), expected_signature.encode()
                ):
                    raise ValueError("Invalid webhook signature")

        # Extract event type from payload
        event_type = payload.get("type", "unknown")
        return payload, event_type

    async def _register_webhook(
        self,
        credentials: Credentials,
        webhook_type: str,
        resource: str,
        events: list[str],
        ingress_url: str,
        secret: str,
    ) -> tuple[str, dict]:
        """
        Register a webhook with ElevenLabs.

        Note: ElevenLabs webhook registration is done through their dashboard,
        not via API. This is a placeholder implementation: no request is made,
        and the returned config records the secret plus setup instructions.

        Returns:
            An empty provider webhook ID and the config dict.
        """
        config = {
            "manual_setup_required": True,
            # Read back by validate_payload to verify incoming signatures.
            "webhook_secret": secret,
            "instructions": "Please configure webhook URL in ElevenLabs dashboard",
        }
        return "", config

    async def _deregister_webhook(
        self, webhook: Webhook, credentials: Credentials
    ) -> None:
        """
        Deregister a webhook with ElevenLabs.

        Note: ElevenLabs webhook removal is done through their dashboard, so
        this method intentionally does nothing.
        """
        pass
|
||||
179
autogpt_platform/backend/backend/blocks/elevenlabs/speech.py
Normal file
179
autogpt_platform/backend/backend/blocks/elevenlabs/speech.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
ElevenLabs speech generation (text-to-speech) blocks.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import elevenlabs
|
||||
|
||||
|
||||
class ElevenLabsGenerateSpeechBlock(Block):
    """
    Convert text to speech and emit the audio bytes as a base64 string.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = elevenlabs.credentials_field(
            description="ElevenLabs API credentials"
        )
        voice_id: str = SchemaField(description="ID of the voice to use")
        text: str = SchemaField(description="Text to convert to speech")
        model_id: str = SchemaField(
            description="Model ID to use for generation",
            default="eleven_multilingual_v2",
        )
        output_format: str = SchemaField(
            description="Audio format (e.g., mp3_44100_128)",
            default="mp3_44100_128",
        )
        voice_settings: Optional[dict] = SchemaField(
            description="Override voice settings (stability, similarity_boost, etc.)",
            default=None,
        )
        language_code: Optional[str] = SchemaField(
            description="Language code to enforce output language", default=None
        )
        seed: Optional[int] = SchemaField(
            description="Seed for reproducible output", default=None
        )

    class Output(BlockSchema):
        audio: str = SchemaField(description="Base64-encoded audio data")

    def __init__(self):
        super().__init__(
            id="c5d6e7f8-a9b0-c1d2-e3f4-a5b6c7d8e9f0",
            description="Generate speech audio from text using a specified voice",
            categories={BlockCategory.AI},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """POST the text to the text-to-speech endpoint and yield ``audio``."""
        import base64

        secret = credentials.api_key.get_secret_value()

        request_body: dict[str, str | int | dict] = {
            "text": input_data.text,
            "model_id": input_data.model_id,
        }
        # Optional fields are only sent when supplied. seed is checked against
        # None so that a seed of 0 is still forwarded.
        if input_data.voice_settings:
            request_body["voice_settings"] = input_data.voice_settings
        if input_data.language_code:
            request_body["language_code"] = input_data.language_code
        if input_data.seed is not None:
            request_body["seed"] = input_data.seed

        tts_response = await Requests().post(
            f"https://api.elevenlabs.io/v1/text-to-speech/{input_data.voice_id}",
            headers={
                "xi-api-key": secret,
                "Content-Type": "application/json",
            },
            json=request_body,
            params={"output_format": input_data.output_format},
        )

        # The response body is raw audio; base64 makes it a plain string output.
        yield "audio", base64.b64encode(tts_response.content).decode("utf-8")
|
||||
|
||||
|
||||
class ElevenLabsGenerateSpeechWithTimestampsBlock(Block):
    """
    Convert text to speech and also return per-character timing data.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = elevenlabs.credentials_field(
            description="ElevenLabs API credentials"
        )
        voice_id: str = SchemaField(description="ID of the voice to use")
        text: str = SchemaField(description="Text to convert to speech")
        model_id: str = SchemaField(
            description="Model ID to use for generation",
            default="eleven_multilingual_v2",
        )
        output_format: str = SchemaField(
            description="Audio format (e.g., mp3_44100_128)",
            default="mp3_44100_128",
        )
        voice_settings: Optional[dict] = SchemaField(
            description="Override voice settings (stability, similarity_boost, etc.)",
            default=None,
        )
        language_code: Optional[str] = SchemaField(
            description="Language code to enforce output language", default=None
        )

    class Output(BlockSchema):
        audio_base64: str = SchemaField(description="Base64-encoded audio data")
        alignment: dict = SchemaField(
            description="Character-level timing alignment data"
        )
        normalized_alignment: dict = SchemaField(
            description="Normalized text alignment data"
        )

    def __init__(self):
        super().__init__(
            id="d6e7f8a9-b0c1-d2e3-f4a5-b6c7d8e9f0a1",
            description="Generate speech with character-level timestamp information",
            categories={BlockCategory.AI},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """POST to the with-timestamps endpoint and yield audio and alignments."""
        secret = credentials.api_key.get_secret_value()

        request_body: dict[str, str | dict] = {
            "text": input_data.text,
            "model_id": input_data.model_id,
        }
        # Optional fields are only sent when supplied.
        if input_data.voice_settings:
            request_body["voice_settings"] = input_data.voice_settings
        if input_data.language_code:
            request_body["language_code"] = input_data.language_code

        tts_response = await Requests().post(
            f"https://api.elevenlabs.io/v1/text-to-speech/{input_data.voice_id}/with-timestamps",
            headers={
                "xi-api-key": secret,
                "Content-Type": "application/json",
            },
            json=request_body,
            params={"output_format": input_data.output_format},
        )
        result = tts_response.json()

        # Missing keys in the reply fall back to an empty string / empty dicts.
        yield "audio_base64", result.get("audio_base64", "")
        yield "alignment", result.get("alignment", {})
        yield "normalized_alignment", result.get("normalized_alignment", {})
|
||||
@@ -0,0 +1,232 @@
|
||||
"""
|
||||
ElevenLabs speech-to-text (transcription) blocks.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import elevenlabs
|
||||
|
||||
|
||||
class ElevenLabsTranscribeAudioSyncBlock(Block):
    """
    Transcribe audio in a single blocking request and return the text,
    per-word entries and detected language reported by ElevenLabs.
    """

    class Input(BlockSchema):
        # Exactly one of `file` / `cloud_storage_url` must be set; `run` enforces it.
        credentials: CredentialsMetaInput = elevenlabs.credentials_field(
            description="ElevenLabs API credentials"
        )
        model_id: str = SchemaField(
            default="scribe_v1", description="Model ID for transcription"
        )
        file: Optional[str] = SchemaField(
            default=None, description="Base64-encoded audio file"
        )
        cloud_storage_url: Optional[str] = SchemaField(
            default=None, description="URL to audio file in cloud storage"
        )
        language_code: Optional[str] = SchemaField(
            default=None,
            description="Language code (ISO 639-1 or -3) to improve accuracy",
        )
        diarize: bool = SchemaField(
            default=False, description="Enable speaker diarization"
        )
        num_speakers: Optional[int] = SchemaField(
            default=None, description="Expected number of speakers (max 32)"
        )
        timestamps_granularity: str = SchemaField(
            default="word",
            description="Timestamp detail level: word, character, or none",
        )
        tag_audio_events: bool = SchemaField(
            default=True, description="Tag non-speech sounds (laughter, noise)"
        )

    class Output(BlockSchema):
        text: str = SchemaField(description="Full transcribed text")
        words: list[dict] = SchemaField(
            description="Array with word timing and speaker info"
        )
        language_code: str = SchemaField(description="Detected language code")
        language_probability: float = SchemaField(
            description="Confidence in language detection"
        )

    def __init__(self):
        super().__init__(
            id="e7f8a9b0-c1d2-e3f4-a5b6-c7d8e9f0a1b2",
            description="Transcribe audio to text with timing and speaker information",
            categories={BlockCategory.AI},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        POST the audio (uploaded bytes or a storage URL) to the
        speech-to-text endpoint and yield the parsed response fields.

        Raises:
            ValueError: if neither or both of ``file`` and
                ``cloud_storage_url`` are provided.
        """
        import base64
        from io import BytesIO

        secret = credentials.api_key.get_secret_value()

        # Exactly one audio source is allowed.
        has_upload = bool(input_data.file)
        has_url = bool(input_data.cloud_storage_url)
        if not (has_upload or has_url):
            raise ValueError("Either 'file' or 'cloud_storage_url' must be provided")
        if has_upload and has_url:
            raise ValueError(
                "Only one of 'file' or 'cloud_storage_url' should be provided"
            )

        # Multipart form fields; booleans are sent as lowercase strings.
        fields = {
            "model_id": input_data.model_id,
            "diarize": str(input_data.diarize).lower(),
            "timestamps_granularity": input_data.timestamps_granularity,
            "tag_audio_events": str(input_data.tag_audio_events).lower(),
        }
        if input_data.language_code:
            fields["language_code"] = input_data.language_code
        if input_data.num_speakers is not None:
            fields["num_speakers"] = str(input_data.num_speakers)

        # Attach the audio either as an uploaded part or as a URL field.
        upload = None
        if has_upload:
            audio_stream = BytesIO(base64.b64decode(input_data.file))
            upload = [("file", ("audio.wav", audio_stream, "audio/wav"))]
        else:
            fields["cloud_storage_url"] = input_data.cloud_storage_url

        response = await Requests().post(
            "https://api.elevenlabs.io/v1/speech-to-text",
            headers={"xi-api-key": secret},
            data=fields,
            files=upload,
        )
        result = response.json()

        # Each output falls back to an empty value when the key is absent.
        for output_name, fallback in (
            ("text", ""),
            ("words", []),
            ("language_code", ""),
            ("language_probability", 0.0),
        ):
            yield output_name, result.get(output_name, fallback)
|
||||
|
||||
|
||||
class ElevenLabsTranscribeAudioAsyncBlock(Block):
    """
    Kick off transcription that returns quickly; result arrives via webhook.
    """

    class Input(BlockSchema):
        # Exactly one of `file` / `cloud_storage_url` must be set; `run` enforces it.
        credentials: CredentialsMetaInput = elevenlabs.credentials_field(
            description="ElevenLabs API credentials"
        )
        model_id: str = SchemaField(
            description="Model ID for transcription", default="scribe_v1"
        )
        file: Optional[str] = SchemaField(
            description="Base64-encoded audio file", default=None
        )
        cloud_storage_url: Optional[str] = SchemaField(
            description="URL to audio file in cloud storage", default=None
        )
        language_code: Optional[str] = SchemaField(
            description="Language code (ISO 639-1 or -3) to improve accuracy",
            default=None,
        )
        diarize: bool = SchemaField(
            description="Enable speaker diarization", default=False
        )
        num_speakers: Optional[int] = SchemaField(
            description="Expected number of speakers (max 32)", default=None
        )
        timestamps_granularity: str = SchemaField(
            description="Timestamp detail level: word, character, or none",
            default="word",
        )
        webhook_url: str = SchemaField(
            description="URL to receive transcription result",
            default="",
        )

    class Output(BlockSchema):
        tracking_id: str = SchemaField(description="ID to track the transcription job")

    def __init__(self):
        super().__init__(
            id="f8a9b0c1-d2e3-f4a5-b6c7-d8e9f0a1b2c3",
            description="Start async transcription with webhook callback",
            categories={BlockCategory.AI},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        Submit an audio file or storage URL for webhook-mode transcription
        and yield an identifier for the submitted job.

        Yields:
            ``tracking_id``: the first non-empty value among the response's
            ``tracking_id``, ``request_id`` and ``transcription_id`` keys;
            a locally generated UUID only if the response has none of them.

        Raises:
            ValueError: if neither or both of ``file`` and
                ``cloud_storage_url`` are provided.
        """
        import base64
        import uuid
        from io import BytesIO

        api_key = credentials.api_key.get_secret_value()

        # Validate input: exactly one audio source is allowed.
        if not input_data.file and not input_data.cloud_storage_url:
            raise ValueError("Either 'file' or 'cloud_storage_url' must be provided")
        if input_data.file and input_data.cloud_storage_url:
            raise ValueError(
                "Only one of 'file' or 'cloud_storage_url' should be provided"
            )

        # Build form data; booleans are sent as lowercase strings.
        form_data = {
            "model_id": input_data.model_id,
            "diarize": str(input_data.diarize).lower(),
            "timestamps_granularity": input_data.timestamps_granularity,
            "webhook": "true",  # Enable async mode
        }

        if input_data.language_code:
            form_data["language_code"] = input_data.language_code
        if input_data.num_speakers is not None:
            form_data["num_speakers"] = str(input_data.num_speakers)
        if input_data.webhook_url:
            form_data["webhook_url"] = input_data.webhook_url

        # Attach the audio either as an uploaded part or as a URL field.
        files = None
        if input_data.file:
            file_data = base64.b64decode(input_data.file)
            files = [("file", ("audio.wav", BytesIO(file_data), "audio/wav"))]
        elif input_data.cloud_storage_url:
            form_data["cloud_storage_url"] = input_data.cloud_storage_url

        # Start async transcription
        response = await Requests().post(
            "https://api.elevenlabs.io/v1/speech-to-text",
            headers={"xi-api-key": api_key},
            data=form_data,
            files=files,
        )

        data = response.json()

        # Prefer an identifier issued by the API so the value can actually be
        # correlated with the later webhook delivery. A random UUID is only a
        # last resort because it refers to nothing on the ElevenLabs side.
        # NOTE(review): `request_id` / `transcription_id` are the keys the
        # webhook-mode response is documented to carry -- confirm against the
        # current API reference.
        tracking_id = (
            data.get("tracking_id")
            or data.get("request_id")
            or data.get("transcription_id")
            or str(uuid.uuid4())
        )

        yield "tracking_id", tracking_id
|
||||
160
autogpt_platform/backend/backend/blocks/elevenlabs/triggers.py
Normal file
160
autogpt_platform/backend/backend/blocks/elevenlabs/triggers.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
ElevenLabs webhook trigger blocks.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
BlockWebhookConfig,
|
||||
CredentialsMetaInput,
|
||||
ProviderName,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import elevenlabs
|
||||
|
||||
|
||||
class ElevenLabsWebhookTriggerBlock(Block):
    """
    Starts a flow when ElevenLabs POSTs an event (STT finished, voice removal, etc.).
    """

    # Keys copied from the payload's "data" object, per handled event type.
    # The key set doubles as the list of event types this block knows about.
    _EVENT_FIELDS = {
        "speech_to_text_completed": (
            "transcription_id",
            "text",
            "words",
            "language_code",
            "language_probability",
        ),
        "post_call_transcription": (
            "agent_id",
            "conversation_id",
            "transcript",
            "metadata",
        ),
        "voice_removal_notice": (
            "voice_id",
            "voice_name",
            "removal_date",
            "reason",
        ),
        "voice_removal_notice_withdrawn": ("voice_id", "voice_name"),
        "voice_removed": ("voice_id", "voice_name", "removed_at"),
    }

    # Fields whose missing value is an empty container instead of None.
    _FIELD_DEFAULT_FACTORIES = {"words": list, "metadata": dict}

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = elevenlabs.credentials_field(
            description="ElevenLabs API credentials"
        )
        webhook_url: str = SchemaField(
            description="URL to receive webhooks (auto-generated)",
            default="",
            hidden=True,
        )

        class EventsFilter(BaseModel):
            """ElevenLabs event types to subscribe to"""

            speech_to_text_completed: bool = SchemaField(
                description="Speech-to-text transcription completed", default=True
            )
            post_call_transcription: bool = SchemaField(
                description="Conversational AI call transcription completed",
                default=True,
            )
            voice_removal_notice: bool = SchemaField(
                description="Voice scheduled for removal", default=True
            )
            voice_removed: bool = SchemaField(
                description="Voice has been removed", default=True
            )
            voice_removal_notice_withdrawn: bool = SchemaField(
                description="Voice removal cancelled", default=True
            )

        events: EventsFilter = SchemaField(
            title="Events", description="The events to subscribe to"
        )

        # Webhook payload - populated by the system
        payload: dict = SchemaField(
            description="Webhook payload data",
            default={},
            hidden=True,
        )

    class Output(BlockSchema):
        type: str = SchemaField(description="Event type")
        event_timestamp: int = SchemaField(description="Unix timestamp of the event")
        data: dict = SchemaField(description="Event-specific data payload")

    def __init__(self):
        super().__init__(
            id="c1d2e3f4-a5b6-c7d8-e9f0-a1b2c3d4e5f6",
            description="Receive webhook events from ElevenLabs",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=self.Input,
            output_schema=self.Output,
            block_type=BlockType.WEBHOOK,
            webhook_config=BlockWebhookConfig(
                provider=ProviderName("elevenlabs"),
                webhook_type="notification",
                event_filter_input="events",
                resource_format="",
            ),
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        """
        Emit the event type, timestamp and a trimmed data dict for a webhook
        payload, provided the event type is known and enabled in
        ``input_data.events``.

        Unknown or disabled event types produce no output at all.
        """
        payload = input_data.payload
        event_type = payload.get("type", "unknown")

        # Unknown event types have no filter flag and are skipped, exactly like
        # a known type whose flag is switched off. The getattr is safe because
        # it only runs for names listed in _EVENT_FIELDS, all of which are
        # fields of EventsFilter.
        if event_type not in self._EVENT_FIELDS or not getattr(
            input_data.events, event_type
        ):
            return

        yield "type", event_type
        yield "event_timestamp", payload.get("event_timestamp", 0)

        # A payload may carry an explicit `"data": null`; treat it as empty
        # rather than failing on `None.get`.
        data = payload.get("data") or {}

        processed_data = {}
        for field_name in self._EVENT_FIELDS[event_type]:
            make_default = self._FIELD_DEFAULT_FACTORIES.get(field_name)
            fallback = make_default() if make_default is not None else None
            processed_data[field_name] = data.get(field_name, fallback)

        yield "data", processed_data
|
||||
116
autogpt_platform/backend/backend/blocks/elevenlabs/utility.py
Normal file
116
autogpt_platform/backend/backend/blocks/elevenlabs/utility.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
ElevenLabs utility blocks for models and usage stats.
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import elevenlabs
|
||||
|
||||
|
||||
class ElevenLabsListModelsBlock(Block):
    """
    Fetch the model list from the ElevenLabs ``/v1/models`` endpoint.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = elevenlabs.credentials_field(
            description="ElevenLabs API credentials"
        )

    class Output(BlockSchema):
        models: list[dict] = SchemaField(
            description="Array of model objects with capabilities"
        )

    def __init__(self):
        super().__init__(
            id="a9b0c1d2-e3f4-a5b6-c7d8-e9f0a1b2c3d4",
            description="List all available voice models and their capabilities",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """GET the models endpoint and yield the decoded JSON body as ``models``."""
        secret = credentials.api_key.get_secret_value()
        auth_headers = {"xi-api-key": secret}

        response = await Requests().get(
            "https://api.elevenlabs.io/v1/models",
            headers=auth_headers,
        )

        # The response body is passed through unchanged.
        yield "models", response.json()
|
||||
|
||||
|
||||
class ElevenLabsGetUsageStatsBlock(Block):
    """
    Query the ElevenLabs character-stats usage endpoint for a time window.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = elevenlabs.credentials_field(
            description="ElevenLabs API credentials"
        )
        start_unix: int = SchemaField(
            description="Start timestamp in Unix epoch seconds"
        )
        end_unix: int = SchemaField(description="End timestamp in Unix epoch seconds")
        aggregation_interval: str = SchemaField(
            default="daily",
            description="Aggregation interval: daily or monthly",
        )

    class Output(BlockSchema):
        usage: list[dict] = SchemaField(description="Array of usage data per interval")
        total_character_count: int = SchemaField(
            description="Total characters used in period"
        )
        total_requests: int = SchemaField(description="Total API requests in period")

    def __init__(self):
        super().__init__(
            id="b0c1d2e3-f4a5-b6c7-d8e9-f0a1b2c3d4e5",
            description="Get character and credit usage statistics",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        Forward the three input fields as query parameters and yield
        ``usage``, ``total_character_count`` and ``total_requests`` from the
        JSON response, defaulting to an empty list / zero when absent.
        """
        secret = credentials.api_key.get_secret_value()

        # The query string mirrors the input fields one-to-one.
        query = {
            field_name: getattr(input_data, field_name)
            for field_name in ("start_unix", "end_unix", "aggregation_interval")
        }

        response = await Requests().get(
            "https://api.elevenlabs.io/v1/usage/character-stats",
            headers={"xi-api-key": secret},
            params=query,
        )
        stats = response.json()

        for output_name, fallback in (
            ("usage", []),
            ("total_character_count", 0),
            ("total_requests", 0),
        ):
            yield output_name, stats.get(output_name, fallback)
|
||||
249
autogpt_platform/backend/backend/blocks/elevenlabs/voices.py
Normal file
249
autogpt_platform/backend/backend/blocks/elevenlabs/voices.py
Normal file
@@ -0,0 +1,249 @@
|
||||
"""
|
||||
ElevenLabs voice management blocks.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import elevenlabs
|
||||
|
||||
|
||||
class ElevenLabsListVoicesBlock(Block):
    """
    Page through the voices returned by the ElevenLabs ``/v2/voices``
    endpoint, with optional search and voice-type filters.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = elevenlabs.credentials_field(
            description="ElevenLabs API credentials"
        )
        search: str = SchemaField(
            default="", description="Search term to filter voices"
        )
        voice_type: Optional[str] = SchemaField(
            default=None,
            description="Filter by voice type: premade, cloned, or professional",
        )
        page_size: int = SchemaField(
            default=10, description="Number of voices per page (max 100)"
        )
        next_page_token: str = SchemaField(
            default="", description="Token for fetching next page"
        )

    class Output(BlockSchema):
        voices: list[dict] = SchemaField(
            description="Array of voice objects with id, name, category, etc."
        )
        next_page_token: Optional[str] = SchemaField(
            description="Token for fetching next page, null if no more pages"
        )

    def __init__(self):
        super().__init__(
            id="e1a2b3c4-d5e6-f7a8-b9c0-d1e2f3a4b5c6",
            description="List all available voices with filtering and pagination",
            categories={BlockCategory.AI},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        GET one page of voices and yield ``voices`` (empty list if missing)
        and ``next_page_token`` (``None`` if missing) from the response.
        """
        secret = credentials.api_key.get_secret_value()

        # `page_size` is always sent; the remaining filters only when truthy.
        query: dict[str, str | int] = {"page_size": input_data.page_size}
        for filter_name in ("search", "voice_type", "next_page_token"):
            filter_value = getattr(input_data, filter_name)
            if filter_value:
                query[filter_name] = filter_value

        response = await Requests().get(
            "https://api.elevenlabs.io/v2/voices",
            headers={"xi-api-key": secret},
            params=query,
        )
        page = response.json()

        yield "voices", page.get("voices", [])
        yield "next_page_token", page.get("next_page_token")
|
||||
|
||||
|
||||
class ElevenLabsGetVoiceDetailsBlock(Block):
    """
    Look up one voice by ID via the ElevenLabs ``/v1/voices/{voice_id}`` endpoint.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = elevenlabs.credentials_field(
            description="ElevenLabs API credentials"
        )
        voice_id: str = SchemaField(description="The ID of the voice to retrieve")

    class Output(BlockSchema):
        voice: dict = SchemaField(
            description="Voice object with name, labels, settings, etc."
        )

    def __init__(self):
        super().__init__(
            id="f2a3b4c5-d6e7-f8a9-b0c1-d2e3f4a5b6c7",
            description="Get detailed information about a specific voice",
            categories={BlockCategory.AI},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """GET the voice resource and yield the decoded JSON body as ``voice``."""
        secret = credentials.api_key.get_secret_value()
        auth_headers = {"xi-api-key": secret}

        response = await Requests().get(
            f"https://api.elevenlabs.io/v1/voices/{input_data.voice_id}",
            headers=auth_headers,
        )

        # The response body is passed through unchanged.
        yield "voice", response.json()
|
||||
|
||||
|
||||
class ElevenLabsCreateVoiceCloneBlock(Block):
    """
    Create a custom voice by uploading base64-encoded sample clips to the
    ElevenLabs ``/v1/voices/add`` endpoint.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = elevenlabs.credentials_field(
            description="ElevenLabs API credentials"
        )
        name: str = SchemaField(description="Name for the new voice")
        files: list[str] = SchemaField(
            description="Base64-encoded audio files (1-10 files, max 25MB each)"
        )
        description: str = SchemaField(
            default="", description="Description of the voice"
        )
        labels: dict = SchemaField(
            default={}, description="Metadata labels (e.g., accent, age)"
        )
        remove_background_noise: bool = SchemaField(
            default=False, description="Whether to remove background noise from samples"
        )

    class Output(BlockSchema):
        voice_id: str = SchemaField(description="ID of the newly created voice")
        requires_verification: bool = SchemaField(
            description="Whether the voice requires verification"
        )

    def __init__(self):
        super().__init__(
            id="a3b4c5d6-e7f8-a9b0-c1d2-e3f4a5b6c7d8",
            description="Create a new voice clone from audio samples",
            categories={BlockCategory.AI},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        Upload the decoded samples as a multipart form and yield ``voice_id``
        and ``requires_verification`` from the JSON response (empty string /
        ``False`` when absent).
        """
        import base64
        import json
        from io import BytesIO

        secret = credentials.api_key.get_secret_value()

        # Text fields of the multipart form; optional ones only when set.
        fields = {"name": input_data.name}
        if input_data.description:
            fields["description"] = input_data.description
        if input_data.labels:
            # Labels travel as a JSON-encoded string inside the form.
            fields["labels"] = json.dumps(input_data.labels)
        if input_data.remove_background_noise:
            fields["remove_background_noise"] = "true"

        # Every sample becomes one "files" part named by its position.
        uploads = [
            (
                "files",
                (
                    f"sample_{index}.mp3",
                    BytesIO(base64.b64decode(encoded_sample)),
                    "audio/mpeg",
                ),
            )
            for index, encoded_sample in enumerate(input_data.files)
        ]

        response = await Requests().post(
            "https://api.elevenlabs.io/v1/voices/add",
            headers={"xi-api-key": secret},
            data=fields,
            files=uploads,
        )
        created = response.json()

        yield "voice_id", created.get("voice_id", "")
        yield "requires_verification", created.get("requires_verification", False)
|
||||
|
||||
|
||||
class ElevenLabsDeleteVoiceBlock(Block):
    """
    Delete a voice by ID via the ElevenLabs ``/v1/voices/{voice_id}`` endpoint.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = elevenlabs.credentials_field(
            description="ElevenLabs API credentials"
        )
        voice_id: str = SchemaField(description="The ID of the voice to delete")

    class Output(BlockSchema):
        status: str = SchemaField(description="Deletion status (ok or error)")

    def __init__(self):
        super().__init__(
            id="b4c5d6e7-f8a9-b0c1-d2e3-f4a5b6c7d8e9",
            description="Delete a custom voice from your account",
            categories={BlockCategory.AI},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        Send the DELETE request and yield ``status``: ``"ok"`` for HTTP 200
        or 204, ``"error"`` for any other status code.
        """
        secret = credentials.api_key.get_secret_value()

        response = await Requests().delete(
            f"https://api.elevenlabs.io/v1/voices/{input_data.voice_id}",
            headers={"xi-api-key": secret},
        )

        # Only these two status codes count as a successful deletion.
        outcome = "ok" if response.status in (200, 204) else "error"
        yield "status", outcome
|
||||
@@ -6,10 +6,10 @@ import hashlib
|
||||
import hmac
|
||||
from enum import Enum
|
||||
|
||||
from backend.data.model import Credentials
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
BaseWebhooksManager,
|
||||
Credentials,
|
||||
ProviderName,
|
||||
Requests,
|
||||
Webhook,
|
||||
|
||||
190
autogpt_platform/backend/backend/blocks/exa/answers.md
Normal file
190
autogpt_platform/backend/backend/blocks/exa/answers.md
Normal file
@@ -0,0 +1,190 @@
|
||||
|
||||
|
||||
Exa home pagelight logo
|
||||
|
||||
Search or ask...
|
||||
⌘K
|
||||
Exa Search
|
||||
Log In
|
||||
API Dashboard
|
||||
Documentation
|
||||
Examples
|
||||
Integrations
|
||||
SDKs
|
||||
Websets
|
||||
Changelog
|
||||
Discord
|
||||
Blog
|
||||
Getting Started
|
||||
|
||||
Overview
|
||||
Quickstart
|
||||
API Reference
|
||||
|
||||
POST
|
||||
Search
|
||||
POST
|
||||
Get contents
|
||||
POST
|
||||
Find similar links
|
||||
POST
|
||||
Answer
|
||||
OpenAPI Specification
|
||||
RAG Quick Start Guide
|
||||
|
||||
RAG with Exa and OpenAI
|
||||
RAG with LangChain
|
||||
OpenAI Exa Wrapper
|
||||
CrewAI agents with Exa
|
||||
RAG with LlamaIndex
|
||||
Tool calling with GPT
|
||||
Tool calling with Claude
|
||||
OpenAI Chat Completions
|
||||
OpenAI Responses API
|
||||
Concepts
|
||||
|
||||
How Exa Search Works
|
||||
The Exa Index
|
||||
Contents retrieval with Exa API
|
||||
Exa's Capabilities Explained
|
||||
FAQs
|
||||
Crawling Subpages with Exa
|
||||
Exa LiveCrawl
|
||||
Admin
|
||||
|
||||
Setting Up and Managing Your Team
|
||||
Rate Limits
|
||||
Enterprise Documentation & Security
|
||||
API Reference
|
||||
Answer
|
||||
Get an LLM answer to a question informed by Exa search results. Fully compatible with OpenAI’s chat completions endpoint - docs here. /answer performs an Exa search and uses an LLM to generate either:
|
||||
|
||||
A direct answer for specific queries. (i.e. “What is the capital of France?” would return “Paris”)
|
||||
A detailed summary with citations for open-ended queries (i.e. “What is the state of ai in healthcare?” would return a summary with citations to relevant sources)
|
||||
The response includes both the generated answer and the sources used to create it. The endpoint also supports streaming (as stream=True), which returns tokens as they are generated.
|
||||
POST
|
||||
/
|
||||
answer
|
||||
|
||||
Try it
|
||||
Get your Exa API key
|
||||
|
||||
Authorizations
|
||||
|
||||
x-api-key
|
||||
stringheaderrequired
|
||||
API key can be provided either via x-api-key header or Authorization header with Bearer scheme
|
||||
Body
|
||||
application/json
|
||||
|
||||
query
|
||||
stringrequired
|
||||
The question or query to answer.
|
||||
Minimum length: 1
|
||||
Example:
|
||||
"What is the latest valuation of SpaceX?"
|
||||
|
||||
stream
|
||||
booleandefault:false
|
||||
If true, the response is returned as a server-sent events (SSE) stream.
|
||||
|
||||
text
|
||||
booleandefault:false
|
||||
If true, the response includes full text content in the search results
|
||||
|
||||
model
|
||||
enum<string>default:exa
|
||||
The search model to use for the answer. The exa model passes only one query to the Exa search, while exa-pro also passes 2 expanded queries to the search model.
|
||||
Available options: exa, exa-pro
|
||||
Response
|
||||
200
|
||||
application/json
|
||||
|
||||
OK
|
||||
|
||||
answer
|
||||
string
|
||||
The generated answer based on search results.
|
||||
Example:
|
||||
"$350 billion."
|
||||
|
||||
citations
|
||||
object[]
|
||||
Search results used to generate the answer.
|
||||
|
||||
Show child attributes
|
||||
|
||||
costDollars
|
||||
object
|
||||
|
||||
Show child attributes
|
||||
Find similar links
|
||||
OpenAPI Specification
|
||||
x
|
||||
discord
|
||||
Powered by Mintlify
|
||||
|
||||
cURL
|
||||
|
||||
Python
|
||||
|
||||
JavaScript
|
||||
|
||||
Copy
|
||||
# pip install exa-py
|
||||
from exa_py import Exa
|
||||
exa = Exa('YOUR_EXA_API_KEY')
|
||||
|
||||
result = exa.answer(
|
||||
"What is the latest valuation of SpaceX?",
|
||||
text=True
|
||||
)
|
||||
|
||||
print(result)
|
||||
|
||||
200
|
||||
|
||||
Copy
|
||||
{
|
||||
"answer": "$350 billion.",
|
||||
"citations": [
|
||||
{
|
||||
"id": "https://www.theguardian.com/science/2024/dec/11/spacex-valued-at-350bn-as-company-agrees-to-buy-shares-from-employees",
|
||||
"url": "https://www.theguardian.com/science/2024/dec/11/spacex-valued-at-350bn-as-company-agrees-to-buy-shares-from-employees",
|
||||
"title": "SpaceX valued at $350bn as company agrees to buy shares from ...",
|
||||
"author": "Dan Milmon",
|
||||
"publishedDate": "2023-11-16T01:36:32.547Z",
|
||||
"text": "SpaceX valued at $350bn as company agrees to buy shares from ...",
|
||||
"image": "https://i.guim.co.uk/img/media/7cfee7e84b24b73c97a079c402642a333ad31e77/0_380_6176_3706/master/6176.jpg?width=1200&height=630&quality=85&auto=format&fit=crop&overlay-align=bottom%2Cleft&overlay-width=100p&overlay-base64=L2ltZy9zdGF0aWMvb3ZlcmxheXMvdGctZGVmYXVsdC5wbmc&enable=upscale&s=71ebb2fbf458c185229d02d380c01530",
|
||||
"favicon": "https://assets.guim.co.uk/static/frontend/icons/homescreen/apple-touch-icon.svg"
|
||||
}
|
||||
],
|
||||
"costDollars": {
|
||||
"total": 0.005,
|
||||
"breakDown": [
|
||||
{
|
||||
"search": 0.005,
|
||||
"contents": 0,
|
||||
"breakdown": {
|
||||
"keywordSearch": 0,
|
||||
"neuralSearch": 0.005,
|
||||
"contentText": 0,
|
||||
"contentHighlight": 0,
|
||||
"contentSummary": 0
|
||||
}
|
||||
}
|
||||
],
|
||||
"perRequestPrices": {
|
||||
"neuralSearch_1_25_results": 0.005,
|
||||
"neuralSearch_26_100_results": 0.025,
|
||||
"neuralSearch_100_plus_results": 1,
|
||||
"keywordSearch_1_100_results": 0.0025,
|
||||
"keywordSearch_100_plus_results": 3
|
||||
},
|
||||
"perPagePrices": {
|
||||
"contentText": 0.001,
|
||||
"contentHighlight": 0.001,
|
||||
"contentSummary": 0.001
|
||||
}
|
||||
}
|
||||
}
|
||||
1004
autogpt_platform/backend/backend/blocks/exa/webset_webhook.md
Normal file
1004
autogpt_platform/backend/backend/blocks/exa/webset_webhook.md
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,81 @@
|
||||
# Example Blocks Deployment Guide
|
||||
|
||||
## Overview
|
||||
|
||||
Example blocks are disabled by default in production environments to keep the production block list clean and focused on real functionality. This guide explains how to control the visibility of example blocks.
|
||||
|
||||
## Configuration
|
||||
|
||||
Example blocks are controlled by the `ENABLE_EXAMPLE_BLOCKS` setting:
|
||||
|
||||
- **Default**: `false` (example blocks are hidden)
|
||||
- **Development**: Set to `true` to show example blocks
|
||||
|
||||
## How to Enable/Disable
|
||||
|
||||
### Method 1: Environment Variable (Recommended)
|
||||
|
||||
Add one of the following lines to your `.env` file (set the variable only once):
|
||||
|
||||
```bash
|
||||
# Enable example blocks in development
|
||||
ENABLE_EXAMPLE_BLOCKS=true
|
||||
|
||||
# Disable example blocks in production (default)
|
||||
ENABLE_EXAMPLE_BLOCKS=false
|
||||
```
|
||||
|
||||
### Method 2: Configuration File
|
||||
|
||||
If you're using a `config.json` file:
|
||||
|
||||
```json
|
||||
{
|
||||
"enable_example_blocks": true
|
||||
}
|
||||
```
|
||||
|
||||
## Implementation Details
|
||||
|
||||
The setting is checked in `backend/blocks/__init__.py` during the block loading process:
|
||||
|
||||
1. The `load_all_blocks()` function reads the `enable_example_blocks` setting from `Config`
|
||||
2. If disabled (default), any Python files in the `examples/` directory are skipped
|
||||
3. If enabled, example blocks are loaded normally
|
||||
|
||||
## Production Deployment
|
||||
|
||||
For production deployments:
|
||||
|
||||
1. **Do not set** `ENABLE_EXAMPLE_BLOCKS` in your production `.env` file (it defaults to `false`)
|
||||
2. Or explicitly set `ENABLE_EXAMPLE_BLOCKS=false` for clarity
|
||||
3. Example blocks will not appear in the block list or be available for use
|
||||
|
||||
## Development Environment
|
||||
|
||||
For local development:
|
||||
|
||||
1. Set `ENABLE_EXAMPLE_BLOCKS=true` in your `.env` file
|
||||
2. Restart your backend server
|
||||
3. Example blocks will be available for testing and demonstration
|
||||
|
||||
## Verification
|
||||
|
||||
To verify the setting is working:
|
||||
|
||||
```python
|
||||
# Check current setting
|
||||
from backend.util.settings import Config
|
||||
config = Config()
|
||||
print(f"Example blocks enabled: {config.enable_example_blocks}")
|
||||
|
||||
# Check loaded blocks
|
||||
from backend.blocks import load_all_blocks
|
||||
blocks = load_all_blocks()
|
||||
example_blocks = [b for b in blocks.values() if 'examples' in b.__module__]
|
||||
print(f"Example blocks loaded: {len(example_blocks)}")
|
||||
```
|
||||
|
||||
## Security Note
|
||||
|
||||
Example blocks are for demonstration purposes only and may not follow production security standards. Always keep them disabled in production environments.
|
||||
@@ -129,7 +129,6 @@ class AIImageEditorBlock(Block):
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
result = await self.run_model(
|
||||
@@ -140,7 +139,6 @@ class AIImageEditorBlock(Block):
|
||||
await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.input_image,
|
||||
user_id=user_id,
|
||||
return_content=True,
|
||||
)
|
||||
if input_data.input_image
|
||||
|
||||
13
autogpt_platform/backend/backend/blocks/gem/_config.py
Normal file
13
autogpt_platform/backend/backend/blocks/gem/_config.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Shared configuration for all GEM blocks using the new SDK pattern.
|
||||
"""
|
||||
|
||||
from backend.sdk import BlockCostType, ProviderBuilder
|
||||
|
||||
def _build_gem_provider():
    """Assemble the provider object shared by all GEM blocks.

    The provider is named ``"gem"``, authenticates with an API key read from
    ``GEM_API_KEY`` and charges a base cost of 1 per block run.
    """
    builder = ProviderBuilder("gem")
    builder = builder.with_api_key("GEM_API_KEY", "GEM API Key")
    builder = builder.with_base_cost(1, BlockCostType.RUN)
    return builder.build()


# Configure the GEM provider once for all blocks
gem = _build_gem_provider()
|
||||
1617
autogpt_platform/backend/backend/blocks/gem/blocks.py
Normal file
1617
autogpt_platform/backend/backend/blocks/gem/blocks.py
Normal file
File diff suppressed because it is too large
Load Diff
2751
autogpt_platform/backend/backend/blocks/gem/gem.md
Normal file
2751
autogpt_platform/backend/backend/blocks/gem/gem.md
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import base64
|
||||
from email.utils import getaddresses, parseaddr
|
||||
from pathlib import Path
|
||||
from email.utils import parseaddr
|
||||
from typing import List
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
@@ -10,7 +9,6 @@ from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from ._auth import (
|
||||
@@ -31,7 +29,6 @@ class Attachment(BaseModel):
|
||||
|
||||
|
||||
class Email(BaseModel):
|
||||
threadId: str
|
||||
id: str
|
||||
subject: str
|
||||
snippet: str
|
||||
@@ -43,12 +40,6 @@ class Email(BaseModel):
|
||||
attachments: List[Attachment]
|
||||
|
||||
|
||||
class Thread(BaseModel):
    # Model for a Gmail thread as emitted by the thread-reading block.
    # (Documented with comments rather than a docstring so the generated
    # schema is left untouched.)

    # Thread identifier
    id: str
    # Messages belonging to the thread, each shaped like ``Email``
    messages: list[Email]
    # History ID carried over from the thread payload
    historyId: str
|
||||
|
||||
|
||||
class GmailReadBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GoogleCredentialsInput = GoogleCredentialsField(
|
||||
@@ -92,7 +83,6 @@ class GmailReadBlock(Block):
|
||||
(
|
||||
"email",
|
||||
{
|
||||
"threadId": "t1",
|
||||
"id": "1",
|
||||
"subject": "Test Email",
|
||||
"snippet": "This is a test email",
|
||||
@@ -108,7 +98,6 @@ class GmailReadBlock(Block):
|
||||
"emails",
|
||||
[
|
||||
{
|
||||
"threadId": "t1",
|
||||
"id": "1",
|
||||
"subject": "Test Email",
|
||||
"snippet": "This is a test email",
|
||||
@@ -125,7 +114,6 @@ class GmailReadBlock(Block):
|
||||
test_mock={
|
||||
"_read_emails": lambda *args, **kwargs: [
|
||||
{
|
||||
"threadId": "t1",
|
||||
"id": "1",
|
||||
"subject": "Test Email",
|
||||
"snippet": "This is a test email",
|
||||
@@ -146,11 +134,7 @@ class GmailReadBlock(Block):
|
||||
) -> BlockOutput:
|
||||
service = GmailReadBlock._build_service(credentials, **kwargs)
|
||||
messages = await asyncio.to_thread(
|
||||
self._read_emails,
|
||||
service,
|
||||
input_data.query,
|
||||
input_data.max_results,
|
||||
credentials.scopes,
|
||||
self._read_emails, service, input_data.query, input_data.max_results
|
||||
)
|
||||
for email in messages:
|
||||
yield "email", email
|
||||
@@ -177,31 +161,22 @@ class GmailReadBlock(Block):
|
||||
return build("gmail", "v1", credentials=creds)
|
||||
|
||||
def _read_emails(
|
||||
self,
|
||||
service,
|
||||
query: str | None,
|
||||
max_results: int | None,
|
||||
scopes: list[str] | None,
|
||||
self, service, query: str | None, max_results: int | None
|
||||
) -> list[Email]:
|
||||
scopes = [s.lower() for s in (scopes or [])]
|
||||
list_kwargs = {"userId": "me", "maxResults": max_results or 10}
|
||||
if query and "https://www.googleapis.com/auth/gmail.metadata" not in scopes:
|
||||
list_kwargs["q"] = query
|
||||
|
||||
results = service.users().messages().list(**list_kwargs).execute()
|
||||
results = (
|
||||
service.users()
|
||||
.messages()
|
||||
.list(userId="me", q=query or "", maxResults=max_results or 10)
|
||||
.execute()
|
||||
)
|
||||
messages = results.get("messages", [])
|
||||
|
||||
email_data = []
|
||||
for message in messages:
|
||||
format_type = (
|
||||
"metadata"
|
||||
if "https://www.googleapis.com/auth/gmail.metadata" in scopes
|
||||
else "full"
|
||||
)
|
||||
msg = (
|
||||
service.users()
|
||||
.messages()
|
||||
.get(userId="me", id=message["id"], format=format_type)
|
||||
.get(userId="me", id=message["id"], format="full")
|
||||
.execute()
|
||||
)
|
||||
|
||||
@@ -213,14 +188,13 @@ class GmailReadBlock(Block):
|
||||
attachments = self._get_attachments(service, msg)
|
||||
|
||||
email = Email(
|
||||
threadId=msg["threadId"],
|
||||
id=msg["id"],
|
||||
subject=headers.get("subject", "No Subject"),
|
||||
snippet=msg["snippet"],
|
||||
from_=parseaddr(headers.get("from", ""))[1],
|
||||
to=parseaddr(headers.get("to", ""))[1],
|
||||
date=headers.get("date", ""),
|
||||
body=self._get_email_body(msg, service),
|
||||
body=self._get_email_body(msg),
|
||||
sizeEstimate=msg["sizeEstimate"],
|
||||
attachments=attachments,
|
||||
)
|
||||
@@ -228,81 +202,19 @@ class GmailReadBlock(Block):
|
||||
|
||||
return email_data
|
||||
|
||||
def _get_email_body(self, msg, service):
|
||||
"""Extract email body content with support for multipart messages and HTML conversion."""
|
||||
text = self._walk_for_body(msg["payload"], msg["id"], service)
|
||||
return text or "This email does not contain a readable body."
|
||||
|
||||
def _walk_for_body(self, part, msg_id, service, depth=0):
|
||||
"""Recursively walk through email parts to find readable body content."""
|
||||
# Prevent infinite recursion by limiting depth
|
||||
if depth > 10:
|
||||
return None
|
||||
|
||||
mime_type = part.get("mimeType", "")
|
||||
body = part.get("body", {})
|
||||
|
||||
# Handle text/plain content
|
||||
if mime_type == "text/plain" and body.get("data"):
|
||||
return self._decode_base64(body["data"])
|
||||
|
||||
# Handle text/html content (convert to plain text)
|
||||
if mime_type == "text/html" and body.get("data"):
|
||||
html_content = self._decode_base64(body["data"])
|
||||
if html_content:
|
||||
try:
|
||||
import html2text
|
||||
|
||||
h = html2text.HTML2Text()
|
||||
h.ignore_links = False
|
||||
h.ignore_images = True
|
||||
return h.handle(html_content)
|
||||
except ImportError:
|
||||
# Fallback: return raw HTML if html2text is not available
|
||||
return html_content
|
||||
|
||||
# Handle content stored as attachment
|
||||
if body.get("attachmentId"):
|
||||
attachment_data = self._download_attachment_body(
|
||||
body["attachmentId"], msg_id, service
|
||||
def _get_email_body(self, msg):
|
||||
if "parts" in msg["payload"]:
|
||||
for part in msg["payload"]["parts"]:
|
||||
if part["mimeType"] == "text/plain":
|
||||
return base64.urlsafe_b64decode(part["body"]["data"]).decode(
|
||||
"utf-8"
|
||||
)
|
||||
elif msg["payload"]["mimeType"] == "text/plain":
|
||||
return base64.urlsafe_b64decode(msg["payload"]["body"]["data"]).decode(
|
||||
"utf-8"
|
||||
)
|
||||
if attachment_data:
|
||||
return self._decode_base64(attachment_data)
|
||||
|
||||
# Recursively search in parts
|
||||
for sub_part in part.get("parts", []):
|
||||
text = self._walk_for_body(sub_part, msg_id, service, depth + 1)
|
||||
if text:
|
||||
return text
|
||||
|
||||
return None
|
||||
|
||||
def _decode_base64(self, data):
|
||||
"""Safely decode base64 URL-safe data with proper padding."""
|
||||
if not data:
|
||||
return None
|
||||
try:
|
||||
# Add padding if necessary
|
||||
missing_padding = len(data) % 4
|
||||
if missing_padding:
|
||||
data += "=" * (4 - missing_padding)
|
||||
return base64.urlsafe_b64decode(data).decode("utf-8")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _download_attachment_body(self, attachment_id, msg_id, service):
    """Fetch the ``data`` field of one attachment of message *msg_id*.

    Used when a message part stores its body as an attachment. Returns the
    value of ``data`` from the API response, or ``None`` when the request
    (or reading the response) fails or the field is missing.
    """
    try:
        attachments_api = service.users().messages().attachments()
        request = attachments_api.get(
            userId="me", messageId=msg_id, id=attachment_id
        )
        response = request.execute()
        return response.get("data")
    except Exception:
        # Best-effort: callers treat None as "no body available"
        return None
|
||||
return "This email does not contain a text body."
|
||||
|
||||
def _get_attachments(self, service, message):
|
||||
attachments = []
|
||||
@@ -400,6 +312,7 @@ class GmailSendBlock(Block):
|
||||
return {"id": sent_message["id"], "status": "sent"}
|
||||
|
||||
def _create_message(self, to: str, subject: str, body: str) -> dict:
|
||||
import base64
|
||||
from email.mime.text import MIMEText
|
||||
|
||||
message = MIMEText(body)
|
||||
@@ -624,336 +537,3 @@ class GmailRemoveLabelBlock(Block):
|
||||
if label["name"] == label_name:
|
||||
return label["id"]
|
||||
return None
|
||||
|
||||
|
||||
class GmailGetThreadBlock(Block):
    # Fetches a single Gmail thread by ID and yields it with every message
    # parsed into the same field layout as the ``Email`` model.

    class Input(BlockSchema):
        credentials: GoogleCredentialsInput = GoogleCredentialsField(
            ["https://www.googleapis.com/auth/gmail.readonly"]
        )
        threadId: str = SchemaField(description="Gmail thread ID")

    class Output(BlockSchema):
        thread: Thread = SchemaField(
            description="Gmail thread with decoded message bodies"
        )
        error: str = SchemaField(description="Error message if any")

    def __init__(self):
        # One sample thread shared by the expected output and the mock.
        sample_thread = {
            "id": "188199feff9dc907",
            "messages": [
                {
                    "id": "188199feff9dc907",
                    "to": "nick@example.co",
                    "body": "This email does not contain a text body.",
                    "date": "Thu, 17 Jul 2025 19:22:36 +0100",
                    "from_": "bent@example.co",
                    "snippet": "have a funny looking car -- Bently, Community Administrator For AutoGPT",
                    "subject": "car",
                    "threadId": "188199feff9dc907",
                    "attachments": [
                        {
                            "size": 5694,
                            "filename": "frog.jpg",
                            "content_type": "image/jpeg",
                            "attachment_id": "ANGjdJ_f777CvJ37TdHYSPIPPqJ0HVNgze1uM8alw5iiqTqAVXjsmBWxOWXrY3Z4W4rEJHfAcHVx54_TbtcZIVJJEqJfAD5LoUOK9_zKCRwwcTJ5TGgjsXcZNSnOJNazM-m4E6buo2-p0WNcA_hqQvuA36nzS31Olx3m2x7BaG1ILOkBcjlKJl4KCcR0AvnfK0S02k8i-bZVqII7XXrNp21f1BDolxH7tiEhkz3d5p-5Lbro24olgOWQwQk0SCJsTWWBMCVgbxU7oLt1QmPcjANxfpvh69Qfap3htvQxFa9P08NDI2YqQkry9yPxVR7ZBJQWrqO35EWmhNySEiX5pfG8SDRmfP9O_BqxTH35nEXmSOvZH9zb214iM-zfSoPSU1F5Fo71",
                        }
                    ],
                    "sizeEstimate": 14099,
                }
            ],
            "historyId": "645006",
        }
        super().__init__(
            id="21a79166-9df7-4b5f-9f36-96f639d86112",
            description="Get a full Gmail thread by ID",
            categories={BlockCategory.COMMUNICATION},
            input_schema=GmailGetThreadBlock.Input,
            output_schema=GmailGetThreadBlock.Output,
            disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
            test_input={"threadId": "t1", "credentials": TEST_CREDENTIALS_INPUT},
            test_credentials=TEST_CREDENTIALS,
            test_output=[("thread", sample_thread)],
            test_mock={"_get_thread": lambda *args, **kwargs: sample_thread},
        )

    async def run(
        self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
    ) -> BlockOutput:
        """Fetch the requested thread and yield it as ``"thread"``."""
        service = GmailReadBlock._build_service(credentials, **kwargs)
        # The API client is blocking; run it in a worker thread (as
        # GmailReadBlock.run does) so the event loop is not stalled.
        thread = await asyncio.to_thread(
            self._get_thread, service, input_data.threadId, credentials.scopes
        )
        yield "thread", thread

    def _get_thread(self, service, thread_id: str, scopes: list[str] | None) -> dict:
        """Download thread *thread_id* and parse each of its messages.

        Args:
            service: Gmail API client as returned by ``_build_service``.
            thread_id: ID of the thread to fetch.
            scopes: Granted OAuth scopes; when the ``gmail.metadata`` scope is
                among them the thread is requested in ``metadata`` format,
                otherwise in ``full`` format.

        Returns:
            The API's thread dict with ``messages`` replaced by a list of
            dumped ``Email`` models.
        """
        scopes = [s.lower() for s in (scopes or [])]
        format_type = (
            "metadata"
            if "https://www.googleapis.com/auth/gmail.metadata" in scopes
            else "full"
        )
        thread = (
            service.users()
            .threads()
            .get(userId="me", id=thread_id, format=format_type)
            .execute()
        )

        parsed_messages = []
        for msg in thread.get("messages", []):
            # Header names are case-insensitive; normalise them for lookup
            headers = {
                h["name"].lower(): h["value"]
                for h in msg.get("payload", {}).get("headers", [])
            }
            email = Email(
                threadId=msg.get("threadId", thread_id),
                id=msg["id"],
                subject=headers.get("subject", "No Subject"),
                snippet=msg.get("snippet", ""),
                from_=parseaddr(headers.get("from", ""))[1],
                to=parseaddr(headers.get("to", ""))[1],
                date=headers.get("date", ""),
                body=self._get_email_body(msg),
                sizeEstimate=msg.get("sizeEstimate", 0),
                attachments=self._get_attachments(service, msg),
            )
            parsed_messages.append(email.model_dump())

        thread["messages"] = parsed_messages
        return thread

    def _get_email_body(self, msg):
        """Return the first ``text/plain`` body of *msg*, or a placeholder.

        The payload is searched recursively, so a plain-text part nested in
        a multipart container (for example a message that also carries
        attachments) is found as well.
        """
        text = self._find_plain_text(msg.get("payload") or {})
        if text is None:
            return "This email does not contain a text body."
        return text

    def _find_plain_text(self, part, depth=0):
        """Depth-first search of *part* for decodable ``text/plain`` data."""
        # Guard against pathologically deep payloads
        if depth > 10:
            return None

        data = part.get("body", {}).get("data")
        if part.get("mimeType") == "text/plain" and data is not None:
            try:
                # Restore any stripped "=" padding before decoding
                padded = data + "=" * (-len(data) % 4)
                return base64.urlsafe_b64decode(padded).decode(
                    "utf-8", errors="replace"
                )
            except (ValueError, TypeError):
                # Undecodable part: keep looking in the remaining parts
                pass

        for child in part.get("parts", []):
            text = self._find_plain_text(child, depth + 1)
            if text is not None:
                return text
        return None

    def _get_attachments(self, service, message):
        """List the top-level parts of *message* that are fetchable attachments.

        ``service`` is unused here; it is kept so the signature stays
        unchanged for existing callers.
        """
        attachments = []
        for part in message.get("payload", {}).get("parts", []):
            body = part.get("body", {})
            attachment_id = body.get("attachmentId")
            # A part with a filename but no attachment ID has nothing to
            # download by ID, and Attachment requires one; skip it.
            if not part.get("filename") or not attachment_id:
                continue
            attachments.append(
                Attachment(
                    filename=part["filename"],
                    content_type=part.get("mimeType", ""),
                    size=int(body.get("size", 0)),
                    attachment_id=attachment_id,
                )
            )
        return attachments
|
||||
|
||||
|
||||
class GmailReplyBlock(Block):
    # Sends a reply inside an existing Gmail thread, threading it to the
    # parent message via the In-Reply-To / References headers.

    class Input(BlockSchema):
        credentials: GoogleCredentialsInput = GoogleCredentialsField(
            ["https://www.googleapis.com/auth/gmail.send"]
        )
        threadId: str = SchemaField(description="Thread ID to reply in")
        parentMessageId: str = SchemaField(
            description="ID of the message being replied to"
        )
        to: list[str] = SchemaField(description="To recipients", default_factory=list)
        cc: list[str] = SchemaField(description="CC recipients", default_factory=list)
        bcc: list[str] = SchemaField(description="BCC recipients", default_factory=list)
        replyAll: bool = SchemaField(
            description="Reply to all original recipients", default=False
        )
        subject: str = SchemaField(description="Email subject", default="")
        body: str = SchemaField(description="Email body")
        attachments: list[MediaFileType] = SchemaField(
            description="Files to attach", default_factory=list, advanced=True
        )

    class Output(BlockSchema):
        messageId: str = SchemaField(description="Sent message ID")
        threadId: str = SchemaField(description="Thread ID")
        message: dict = SchemaField(description="Raw Gmail message object")
        error: str = SchemaField(description="Error message if any")

    def __init__(self):
        super().__init__(
            id="12bf5a24-9b90-4f40-9090-4e86e6995e60",
            description="Reply to a Gmail thread",
            categories={BlockCategory.COMMUNICATION},
            input_schema=GmailReplyBlock.Input,
            output_schema=GmailReplyBlock.Output,
            disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
            test_input={
                "threadId": "t1",
                "parentMessageId": "m1",
                "body": "Thanks",
                "replyAll": False,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("messageId", "m2"),
                ("threadId", "t1"),
                ("message", {"id": "m2", "threadId": "t1"}),
            ],
            test_mock={
                "_reply": lambda *args, **kwargs: {
                    "id": "m2",
                    "threadId": "t1",
                }
            },
        )

    async def run(
        self,
        input_data: Input,
        *,
        credentials: GoogleCredentials,
        graph_exec_id: str,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        """Send the reply and yield the new message's ID, thread ID and payload."""
        service = GmailReadBlock._build_service(credentials, **kwargs)
        message = await self._reply(
            service,
            input_data,
            graph_exec_id,
            user_id,
        )
        yield "messageId", message["id"]
        yield "threadId", message.get("threadId", input_data.threadId)
        yield "message", message

    async def _reply(
        self, service, input_data: Input, graph_exec_id: str, user_id: str
    ) -> dict:
        """Compose and send a reply to ``input_data.parentMessageId``.

        When no explicit To/Cc/Bcc recipients are given they are derived from
        the parent message: every From/To/Cc address for ``replyAll``,
        otherwise the Reply-To address (falling back to From).

        Args:
            service: Gmail API client.
            input_data: Block input describing the reply; it is not modified.
            graph_exec_id: Execution ID used to store and locate attachments.
            user_id: Owner of the execution, passed to ``store_media_file``.

        Returns:
            The message dict returned by the Gmail ``messages.send`` call.
        """
        from email import encoders
        from email.mime.base import MIMEBase
        from email.mime.multipart import MIMEMultipart
        from email.mime.text import MIMEText

        def _fetch_parent() -> dict:
            # Only the headers needed for addressing and threading are requested
            return (
                service.users()
                .messages()
                .get(
                    userId="me",
                    id=input_data.parentMessageId,
                    format="metadata",
                    metadataHeaders=[
                        "Subject",
                        "References",
                        "Message-ID",
                        "From",
                        "To",
                        "Cc",
                        "Reply-To",
                    ],
                )
                .execute()
            )

        # The API client is blocking; keep it off the event loop.
        parent = await asyncio.to_thread(_fetch_parent)
        headers = {
            h["name"].lower(): h["value"]
            for h in parent.get("payload", {}).get("headers", [])
        }

        # Work on a copy so the caller's input object is left untouched.
        to_addrs = list(input_data.to)
        if not (to_addrs or input_data.cc or input_data.bcc):
            if input_data.replyAll:
                candidates = [parseaddr(headers.get("from", ""))[1]]
                candidates += [
                    addr for _, addr in getaddresses([headers.get("to", "")])
                ]
                candidates += [
                    addr for _, addr in getaddresses([headers.get("cc", "")])
                ]
                # Drop empties and duplicates while keeping first-seen order
                to_addrs = list(dict.fromkeys(addr for addr in candidates if addr))
            else:
                sender = parseaddr(headers.get("reply-to", headers.get("from", "")))[1]
                to_addrs = [sender] if sender else []

        if input_data.subject:
            subject = input_data.subject
        else:
            parent_subject = headers.get("subject", "").strip()
            # Avoid stacking prefixes ("Re: Re: ...") on already-replied threads
            if parent_subject.lower().startswith("re:"):
                subject = parent_subject
            else:
                subject = f"Re: {parent_subject}".strip()

        references = headers.get("references", "").split()
        if headers.get("message-id"):
            references.append(headers["message-id"])

        msg = MIMEMultipart()
        if to_addrs:
            msg["To"] = ", ".join(to_addrs)
        if input_data.cc:
            msg["Cc"] = ", ".join(input_data.cc)
        if input_data.bcc:
            msg["Bcc"] = ", ".join(input_data.bcc)
        msg["Subject"] = subject
        if headers.get("message-id"):
            msg["In-Reply-To"] = headers["message-id"]
        if references:
            msg["References"] = " ".join(references)
        # Heuristic kept from the original: any "<" marks the body as HTML
        msg.attach(
            MIMEText(input_data.body, "html" if "<" in input_data.body else "plain")
        )

        for attach in input_data.attachments:
            local_path = await store_media_file(
                user_id=user_id,
                graph_exec_id=graph_exec_id,
                file=attach,
                return_content=False,
            )
            abs_path = Path(get_exec_file_path(graph_exec_id, local_path))
            part = MIMEBase("application", "octet-stream")
            part.set_payload(await asyncio.to_thread(abs_path.read_bytes))
            encoders.encode_base64(part)
            # Passing the filename as a parameter lets the email package quote
            # or encode names containing spaces or non-ASCII characters.
            part.add_header("Content-Disposition", "attachment", filename=abs_path.name)
            msg.attach(part)

        raw = base64.urlsafe_b64encode(msg.as_bytes()).decode("utf-8")

        def _send() -> dict:
            return (
                service.users()
                .messages()
                .send(userId="me", body={"threadId": input_data.threadId, "raw": raw})
                .execute()
            )

        return await asyncio.to_thread(_send)
|
||||
|
||||
@@ -113,7 +113,6 @@ class SendWebRequestBlock(Block):
|
||||
graph_exec_id: str,
|
||||
files_name: str,
|
||||
files: list[MediaFileType],
|
||||
user_id: str,
|
||||
) -> list[tuple[str, tuple[str, BytesIO, str]]]:
|
||||
"""
|
||||
Prepare files for the request by storing them and reading their content.
|
||||
@@ -125,7 +124,7 @@ class SendWebRequestBlock(Block):
|
||||
for media in files:
|
||||
# Normalise to a list so we can repeat the same key
|
||||
rel_path = await store_media_file(
|
||||
graph_exec_id, media, user_id, return_content=False
|
||||
graph_exec_id, media, return_content=False
|
||||
)
|
||||
abs_path = get_exec_file_path(graph_exec_id, rel_path)
|
||||
async with aiofiles.open(abs_path, "rb") as f:
|
||||
@@ -137,7 +136,7 @@ class SendWebRequestBlock(Block):
|
||||
return files_payload
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, graph_exec_id: str, user_id: str, **kwargs
|
||||
self, input_data: Input, *, graph_exec_id: str, **kwargs
|
||||
) -> BlockOutput:
|
||||
# ─── Parse/normalise body ────────────────────────────────────
|
||||
body = input_data.body
|
||||
@@ -168,7 +167,7 @@ class SendWebRequestBlock(Block):
|
||||
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
|
||||
if use_files:
|
||||
files_payload = await self._prepare_files(
|
||||
graph_exec_id, input_data.files_name, input_data.files, user_id
|
||||
graph_exec_id, input_data.files_name, input_data.files
|
||||
)
|
||||
|
||||
# Enforce body format rules
|
||||
@@ -228,7 +227,6 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
|
||||
*,
|
||||
graph_exec_id: str,
|
||||
credentials: HostScopedCredentials,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Create SendWebRequestBlock.Input from our input (removing credentials field)
|
||||
@@ -259,6 +257,6 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
|
||||
|
||||
# Use parent class run method
|
||||
async for output_name, output_data in super().run(
|
||||
base_input, graph_exec_id=graph_exec_id, user_id=user_id, **kwargs
|
||||
base_input, graph_exec_id=graph_exec_id, **kwargs
|
||||
):
|
||||
yield output_name, output_data
|
||||
|
||||
@@ -447,7 +447,6 @@ class AgentFileInputBlock(AgentInputBlock):
|
||||
input_data: Input,
|
||||
*,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
if not input_data.value:
|
||||
@@ -456,7 +455,6 @@ class AgentFileInputBlock(AgentInputBlock):
|
||||
yield "result", await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.value,
|
||||
user_id=user_id,
|
||||
return_content=input_data.base_64,
|
||||
)
|
||||
|
||||
|
||||
@@ -44,14 +44,12 @@ class MediaDurationBlock(Block):
|
||||
input_data: Input,
|
||||
*,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# 1) Store the input media locally
|
||||
local_media_path = await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.media_in,
|
||||
user_id=user_id,
|
||||
return_content=False,
|
||||
)
|
||||
media_abspath = get_exec_file_path(graph_exec_id, local_media_path)
|
||||
@@ -113,14 +111,12 @@ class LoopVideoBlock(Block):
|
||||
*,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# 1) Store the input video locally
|
||||
local_video_path = await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.video_in,
|
||||
user_id=user_id,
|
||||
return_content=False,
|
||||
)
|
||||
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
||||
@@ -153,7 +149,6 @@ class LoopVideoBlock(Block):
|
||||
video_out = await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=output_filename,
|
||||
user_id=user_id,
|
||||
return_content=input_data.output_return_type == "data_uri",
|
||||
)
|
||||
|
||||
@@ -205,20 +200,17 @@ class AddAudioToVideoBlock(Block):
|
||||
*,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# 1) Store the inputs locally
|
||||
local_video_path = await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.video_in,
|
||||
user_id=user_id,
|
||||
return_content=False,
|
||||
)
|
||||
local_audio_path = await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.audio_in,
|
||||
user_id=user_id,
|
||||
return_content=False,
|
||||
)
|
||||
|
||||
@@ -247,7 +239,6 @@ class AddAudioToVideoBlock(Block):
|
||||
video_out = await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=output_filename,
|
||||
user_id=user_id,
|
||||
return_content=input_data.output_return_type == "data_uri",
|
||||
)
|
||||
|
||||
|
||||
25
autogpt_platform/backend/backend/blocks/oxylabs/__init__.py
Normal file
25
autogpt_platform/backend/backend/blocks/oxylabs/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
Oxylabs Web Scraper API integration blocks.
|
||||
"""
|
||||
|
||||
from .blocks import (
|
||||
OxylabsCallbackerIPListBlock,
|
||||
OxylabsCheckJobStatusBlock,
|
||||
OxylabsGetJobResultsBlock,
|
||||
OxylabsProcessWebhookBlock,
|
||||
OxylabsProxyFetchBlock,
|
||||
OxylabsSubmitBatchBlock,
|
||||
OxylabsSubmitJobAsyncBlock,
|
||||
OxylabsSubmitJobRealtimeBlock,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OxylabsSubmitJobAsyncBlock",
|
||||
"OxylabsSubmitJobRealtimeBlock",
|
||||
"OxylabsSubmitBatchBlock",
|
||||
"OxylabsCheckJobStatusBlock",
|
||||
"OxylabsGetJobResultsBlock",
|
||||
"OxylabsProxyFetchBlock",
|
||||
"OxylabsProcessWebhookBlock",
|
||||
"OxylabsCallbackerIPListBlock",
|
||||
]
|
||||
15
autogpt_platform/backend/backend/blocks/oxylabs/_config.py
Normal file
15
autogpt_platform/backend/backend/blocks/oxylabs/_config.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Shared configuration for all Oxylabs blocks using the SDK pattern.
|
||||
"""
|
||||
|
||||
from backend.sdk import BlockCostType, ProviderBuilder
|
||||
|
||||
def _build_oxylabs_provider():
    """Assemble the provider object shared by all Oxylabs blocks.

    The provider is named ``"oxylabs"``, authenticates with a
    username/password pair read from ``OXYLABS_USERNAME`` and
    ``OXYLABS_PASSWORD``, and charges a base cost of 10 per block run.
    """
    builder = ProviderBuilder("oxylabs")
    builder = builder.with_user_password(
        "OXYLABS_USERNAME", "OXYLABS_PASSWORD", "Oxylabs API Credentials"
    )
    # Higher cost for web scraping service
    builder = builder.with_base_cost(10, BlockCostType.RUN)
    return builder.build()


# Configure the Oxylabs provider with username/password authentication
oxylabs = _build_oxylabs_provider()
|
||||
811
autogpt_platform/backend/backend/blocks/oxylabs/blocks.py
Normal file
811
autogpt_platform/backend/backend/blocks/oxylabs/blocks.py
Normal file
@@ -0,0 +1,811 @@
|
||||
"""
|
||||
Oxylabs Web Scraper API Blocks
|
||||
|
||||
This module implements blocks for interacting with the Oxylabs Web Scraper API.
|
||||
Oxylabs provides powerful web scraping capabilities with anti-blocking measures,
|
||||
JavaScript rendering, and built-in parsers for various sources.
|
||||
"""
|
||||
|
||||
import base64
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
|
||||
from ._config import oxylabs
|
||||
|
||||
|
||||
# Enums for Oxylabs API
|
||||
class OxylabsSource(str, Enum):
    """Available scraping sources"""

    # The ``str`` mixin makes every member compare equal to its raw value.

    # Amazon sources
    AMAZON_PRODUCT = "amazon_product"
    AMAZON_SEARCH = "amazon_search"

    # Google sources
    GOOGLE_SEARCH = "google_search"
    GOOGLE_SHOPPING = "google_shopping"

    UNIVERSAL = "universal"
    # Add more sources as needed
|
||||
|
||||
|
||||
class UserAgentType(str, Enum):
    """User agent types for scraping"""

    # The ``str`` mixin makes every member compare equal to its raw value.

    # Desktop user agents
    DESKTOP_CHROME = "desktop_chrome"
    DESKTOP_FIREFOX = "desktop_firefox"
    DESKTOP_SAFARI = "desktop_safari"
    DESKTOP_EDGE = "desktop_edge"

    # Mobile user agents
    MOBILE_ANDROID = "mobile_android"
    MOBILE_IOS = "mobile_ios"
|
||||
|
||||
|
||||
class RenderType(str, Enum):
    """Choices for the ``render`` request option.

    ``NONE`` is the blocks' default; they omit ``render`` from the payload
    when it is selected.
    """

    NONE = "none"
    HTML = "html"
    PNG = "png"
|
||||
|
||||
|
||||
class ResultType(str, Enum):
    """Formats that can be requested when downloading job results.

    ``DEFAULT`` means no explicit ``type`` parameter is sent with the request.
    """

    DEFAULT = "default"
    RAW = "raw"
    PARSED = "parsed"
    PNG = "png"
|
||||
|
||||
|
||||
class JobStatus(str, Enum):
    """Lifecycle states of an Oxylabs job as handled by these blocks."""

    PENDING = "pending"  # not finished yet
    DONE = "done"  # finished; results can be fetched
    FAULTED = "faulted"  # finished with an error
|
||||
|
||||
|
||||
# Base class for Oxylabs blocks
|
||||
class OxylabsBlockBase(Block):
    """Shared helpers for all Oxylabs blocks: Basic-auth header and HTTP calls."""

    @staticmethod
    def get_auth_header(credentials: UserPasswordCredentials) -> str:
        """Build the HTTP Basic ``Authorization`` header value.

        Args:
            credentials: Oxylabs username/password credentials.

        Returns:
            The string ``"Basic <base64 of 'username:password'>"``.
        """
        username = credentials.username
        # The password is a secret wrapper, and the username may be one too
        # (NOTE(review): confirm against UserPasswordCredentials). Unwrap it
        # when possible so a masked placeholder is never encoded in the header.
        reveal = getattr(username, "get_secret_value", None)
        if callable(reveal):
            username = reveal()
        password = credentials.password.get_secret_value()

        token = base64.b64encode(f"{username}:{password}".encode()).decode()
        return f"Basic {token}"

    @staticmethod
    async def make_request(
        method: str,
        url: str,
        credentials: UserPasswordCredentials,
        json_data: Optional[dict] = None,
        params: Optional[dict] = None,
        timeout: int = 300,  # 5 minutes default for scraping
    ) -> dict:
        """Send an authenticated JSON request to the Oxylabs API.

        Args:
            method: HTTP method name.
            url: Full endpoint URL.
            credentials: Credentials used for the Basic-auth header.
            json_data: Optional JSON request body.
            params: Optional query-string parameters.
            timeout: Request timeout passed to the HTTP client.

        Returns:
            The decoded JSON body, or ``{}`` for a 204 response.

        Raises:
            Exception: If the response status is outside the 2xx range.
        """
        response = await Requests().request(
            method=method,
            url=url,
            headers={
                "Authorization": OxylabsBlockBase.get_auth_header(credentials),
                "Content-Type": "application/json",
            },
            json=json_data,
            params=params,
            timeout=timeout,
        )

        if not 200 <= response.status < 300:
            # Prefer the structured error body; fall back to the raw text.
            try:
                error_data = response.json()
            except Exception:
                error_data = {"message": response.text()}
            raise Exception(f"Oxylabs API error ({response.status}): {error_data}")

        # Handle empty responses (204 No Content)
        if response.status == 204:
            return {}

        return response.json()
|
||||
|
||||
|
||||
# 1. Submit Job (Async)
|
||||
class OxylabsSubmitJobAsyncBlock(OxylabsBlockBase):
    """
    Queue a scraping job on Oxylabs without waiting for it to finish.

    Yields the job ID plus the URLs for checking status and fetching results.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = oxylabs.credentials_field(
            description="Oxylabs username and password"
        )
        source: OxylabsSource = SchemaField(description="The source/site to scrape")
        url: Optional[str] = SchemaField(
            description="URL to scrape (for URL-based sources)", default=None
        )
        query: Optional[str] = SchemaField(
            description="Query/keyword/ID to search (for query-based sources)",
            default=None,
        )
        geo_location: Optional[str] = SchemaField(
            description="Geographical location (e.g., 'United States', '90210')",
            default=None,
        )
        parse: bool = SchemaField(
            description="Return structured JSON output", default=False
        )
        render: RenderType = SchemaField(
            description="Enable JS rendering or screenshots", default=RenderType.NONE
        )
        user_agent_type: Optional[UserAgentType] = SchemaField(
            description="User agent type for the request", default=None
        )
        callback_url: Optional[str] = SchemaField(
            description="Webhook URL for job completion notification", default=None
        )
        advanced_options: Optional[Dict[str, Any]] = SchemaField(
            description="Additional parameters (e.g., storage_type, context)",
            default=None,
        )

    class Output(BlockSchema):
        job_id: str = SchemaField(description="The Oxylabs job ID")
        status: str = SchemaField(description="Job status (usually 'pending')")
        self_url: str = SchemaField(description="URL to check job status")
        results_url: str = SchemaField(description="URL to get results (when done)")

    def __init__(self):
        super().__init__(
            id="a7c3b5d9-8e2f-4a1b-9c6d-3f7e8b9a0d5c",
            description="Submit an asynchronous scraping job to Oxylabs",
            categories={BlockCategory.SEARCH},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: UserPasswordCredentials, **kwargs
    ) -> BlockOutput:
        """Submit the job and yield ``job_id``, ``status``, ``self_url``, ``results_url``.

        Raises:
            ValueError: If neither ``url`` nor ``query`` is set.
        """
        # Exactly one target is sent; a URL wins over a query when both are set.
        if input_data.url:
            target_key, target_value = "url", input_data.url
        elif input_data.query:
            target_key, target_value = "query", input_data.query
        else:
            raise ValueError("Either 'url' or 'query' must be provided")

        body: Dict[str, Any] = {
            "source": input_data.source,
            target_key: target_value,
        }

        # (payload key, include?, value) for each optional setting.
        optional_settings = (
            ("geo_location", bool(input_data.geo_location), input_data.geo_location),
            ("parse", bool(input_data.parse), True),
            ("render", input_data.render != RenderType.NONE, input_data.render),
            (
                "user_agent_type",
                bool(input_data.user_agent_type),
                input_data.user_agent_type,
            ),
            ("callback_url", bool(input_data.callback_url), input_data.callback_url),
        )
        for key, wanted, value in optional_settings:
            if wanted:
                body[key] = value

        # Advanced options are merged last, so they can override anything above.
        if input_data.advanced_options:
            body.update(input_data.advanced_options)

        response = await self.make_request(
            method="POST",
            url="https://data.oxylabs.io/v1/queries",
            credentials=credentials,
            json_data=body,
        )

        job_id = response.get("id", "")
        job_status = response.get("status", "pending")
        job_url = f"https://data.oxylabs.io/v1/queries/{job_id}"

        yield "job_id", job_id
        yield "status", job_status
        yield "self_url", job_url
        yield "results_url", f"https://data.oxylabs.io/v1/queries/{job_id}/results"
|
||||
|
||||
|
||||
# 2. Submit Job (Realtime)
|
||||
class OxylabsSubmitJobRealtimeBlock(OxylabsBlockBase):
    """
    Run a scraping job through the Oxylabs realtime endpoint.

    The request stays open until Oxylabs answers, so the result is yielded
    directly instead of a job ID.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = oxylabs.credentials_field(
            description="Oxylabs username and password"
        )
        source: OxylabsSource = SchemaField(description="The source/site to scrape")
        url: Optional[str] = SchemaField(
            description="URL to scrape (for URL-based sources)", default=None
        )
        query: Optional[str] = SchemaField(
            description="Query/keyword/ID to search (for query-based sources)",
            default=None,
        )
        geo_location: Optional[str] = SchemaField(
            description="Geographical location (e.g., 'United States', '90210')",
            default=None,
        )
        parse: bool = SchemaField(
            description="Return structured JSON output", default=False
        )
        render: RenderType = SchemaField(
            description="Enable JS rendering or screenshots", default=RenderType.NONE
        )
        user_agent_type: Optional[UserAgentType] = SchemaField(
            description="User agent type for the request", default=None
        )
        advanced_options: Optional[Dict[str, Any]] = SchemaField(
            description="Additional parameters", default=None
        )

    class Output(BlockSchema):
        status: Literal["done", "faulted"] = SchemaField(
            description="Job completion status"
        )
        result: Union[str, dict, bytes] = SchemaField(
            description="Scraped content (HTML, JSON, or image)"
        )
        meta: Dict[str, Any] = SchemaField(description="Job metadata")

    def __init__(self):
        super().__init__(
            id="b8d4c6e0-9f3a-5b2c-0d7e-4a8f9c0b1e6d",
            description="Submit a synchronous scraping job to Oxylabs",
            categories={BlockCategory.SEARCH},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: UserPasswordCredentials, **kwargs
    ) -> BlockOutput:
        """Submit the job, wait for the response, and yield ``status``, ``result``, ``meta``.

        Raises:
            ValueError: If neither ``url`` nor ``query`` is set.
        """
        # Exactly one target is sent; a URL wins over a query when both are set.
        if input_data.url:
            target_key, target_value = "url", input_data.url
        elif input_data.query:
            target_key, target_value = "query", input_data.query
        else:
            raise ValueError("Either 'url' or 'query' must be provided")

        body: Dict[str, Any] = {
            "source": input_data.source,
            target_key: target_value,
        }

        # (payload key, include?, value) for each optional setting.
        optional_settings = (
            ("geo_location", bool(input_data.geo_location), input_data.geo_location),
            ("parse", bool(input_data.parse), True),
            ("render", input_data.render != RenderType.NONE, input_data.render),
            (
                "user_agent_type",
                bool(input_data.user_agent_type),
                input_data.user_agent_type,
            ),
        )
        for key, wanted, value in optional_settings:
            if wanted:
                body[key] = value

        # Advanced options are merged last, so they can override anything above.
        if input_data.advanced_options:
            body.update(input_data.advanced_options)

        response = await self.make_request(
            method="POST",
            url="https://realtime.oxylabs.io/v1/queries",
            credentials=credentials,
            json_data=body,
            timeout=600,  # 10 minutes for realtime
        )

        # An empty response body is reported as a faulted job.
        job_status = "faulted" if not response else "done"

        # Parsed jobs yield the "results" entry, otherwise "content" when present;
        # anything else yields the whole response.
        if input_data.parse and "results" in response:
            scraped = response["results"]
        elif "content" in response:
            scraped = response["content"]
        else:
            scraped = response

        job_meta = {
            "source": input_data.source,
            "timestamp": datetime.utcnow().isoformat(),
        }

        yield "status", job_status
        yield "result", scraped
        yield "meta", job_meta
|
||||
|
||||
|
||||
# 3. Submit Batch
|
||||
class OxylabsSubmitBatchBlock(OxylabsBlockBase):
    """
    Queue many scraping jobs with a single request (at most 5,000 items).

    Yields the IDs of the created jobs and how many there are.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = oxylabs.credentials_field(
            description="Oxylabs username and password"
        )
        source: OxylabsSource = SchemaField(
            description="The source/site to scrape (applies to all)"
        )
        url_list: Optional[List[str]] = SchemaField(
            description="List of URLs to scrape", default=None
        )
        query_list: Optional[List[str]] = SchemaField(
            description="List of queries/keywords to search", default=None
        )
        geo_location: Optional[str] = SchemaField(
            description="Geographical location (applies to all)", default=None
        )
        parse: bool = SchemaField(
            description="Return structured JSON output", default=False
        )
        render: RenderType = SchemaField(
            description="Enable JS rendering or screenshots", default=RenderType.NONE
        )
        user_agent_type: Optional[UserAgentType] = SchemaField(
            description="User agent type for the requests", default=None
        )
        callback_url: Optional[str] = SchemaField(
            description="Webhook URL for job completion notifications", default=None
        )
        advanced_options: Optional[Dict[str, Any]] = SchemaField(
            description="Additional parameters", default=None
        )

    class Output(BlockSchema):
        job_ids: List[str] = SchemaField(description="List of job IDs")
        count: int = SchemaField(description="Number of jobs created")

    def __init__(self):
        super().__init__(
            id="c9e5d7f1-0a4b-6c3d-1e8f-5b9a0c2d3f7e",
            description="Submit batch scraping jobs to Oxylabs",
            categories={BlockCategory.SEARCH},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: UserPasswordCredentials, **kwargs
    ) -> BlockOutput:
        """Submit the batch and yield ``job_ids`` and ``count``.

        Raises:
            ValueError: If no list is given or the chosen list has more than
                5,000 entries.
        """
        # A non-empty URL list wins over the query list when both are given.
        if input_data.url_list:
            if len(input_data.url_list) > 5000:
                raise ValueError("Batch size cannot exceed 5,000 URLs")
            target_key, targets = "url", input_data.url_list
        elif input_data.query_list:
            if len(input_data.query_list) > 5000:
                raise ValueError("Batch size cannot exceed 5,000 queries")
            target_key, targets = "query", input_data.query_list
        else:
            raise ValueError("Either 'url_list' or 'query_list' must be provided")

        body: Dict[str, Any] = {
            "source": input_data.source,
            target_key: targets,
        }

        # (payload key, include?, value); these settings apply to every item.
        optional_settings = (
            ("geo_location", bool(input_data.geo_location), input_data.geo_location),
            ("parse", bool(input_data.parse), True),
            ("render", input_data.render != RenderType.NONE, input_data.render),
            (
                "user_agent_type",
                bool(input_data.user_agent_type),
                input_data.user_agent_type,
            ),
            ("callback_url", bool(input_data.callback_url), input_data.callback_url),
        )
        for key, wanted, value in optional_settings:
            if wanted:
                body[key] = value

        # Advanced options are merged last, so they can override anything above.
        if input_data.advanced_options:
            body.update(input_data.advanced_options)

        response = await self.make_request(
            method="POST",
            url="https://data.oxylabs.io/v1/queries/batch",
            credentials=credentials,
            json_data=body,
        )

        # Entries without a (truthy) ID are dropped.
        created_ids = []
        for entry in response.get("queries", []):
            if entry.get("id"):
                created_ids.append(entry.get("id", ""))

        yield "job_ids", created_ids
        yield "count", len(created_ids)
|
||||
|
||||
|
||||
# 4. Check Job Status
|
||||
class OxylabsCheckJobStatusBlock(OxylabsBlockBase):
    """
    Look up the current state of an Oxylabs job.

    With ``wait_for_completion`` the job is polled until it is no longer
    pending or the attempt limit is reached.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = oxylabs.credentials_field(
            description="Oxylabs username and password"
        )
        job_id: str = SchemaField(description="Job ID to check")
        wait_for_completion: bool = SchemaField(
            description="Poll until job leaves 'pending' status", default=False
        )

    class Output(BlockSchema):
        status: JobStatus = SchemaField(description="Current job status")
        updated_at: Optional[str] = SchemaField(
            description="Last update timestamp", default=None
        )
        results_url: Optional[str] = SchemaField(
            description="URL to get results (when done)", default=None
        )
        raw_status: Dict[str, Any] = SchemaField(description="Full status response")

    def __init__(self):
        super().__init__(
            id="d0f6e8a2-1b5c-7d4e-2f9a-6c0b1d3e4a8f",
            description="Check the status of an Oxylabs scraping job",
            categories={BlockCategory.SEARCH},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: UserPasswordCredentials, **kwargs
    ) -> BlockOutput:
        """Fetch the job status and yield ``status``, ``updated_at``, ``results_url``, ``raw_status``.

        When polling, at most 60 requests are made, 5 seconds apart. If the
        job is still pending afterwards, the last pending response is yielded.
        """
        import asyncio

        status_url = f"https://data.oxylabs.io/v1/queries/{input_data.job_id}"
        polling = input_data.wait_for_completion
        attempts_allowed = 60 if polling else 1
        pause_seconds = 5  # seconds between polls

        # Defaults in case the loop body never assigns them.
        job_info = {}
        job_state = "pending"

        for attempt_no in range(attempts_allowed):
            job_info = await self.make_request(
                method="GET",
                url=status_url,
                credentials=credentials,
            )
            job_state = job_info.get("status", "pending")

            keep_polling = polling and job_state == "pending"
            if not keep_polling:
                break

            # No need to sleep after the final attempt.
            if attempt_no != attempts_allowed - 1:
                await asyncio.sleep(pause_seconds)

        # The results link is only looked up for finished jobs.
        results_href = None
        if job_state == "done":
            results_href = next(
                (
                    link.get("href")
                    for link in job_info.get("_links", [])
                    if link.get("rel") == "results"
                ),
                None,
            )

        yield "status", JobStatus(job_state)
        yield "updated_at", job_info.get("updated_at")
        yield "results_url", results_href
        yield "raw_status", job_info
|
||||
|
||||
|
||||
# 5. Get Job Results
|
||||
class OxylabsGetJobResultsBlock(OxylabsBlockBase):
    """
    Fetch the output of an Oxylabs job by ID.

    PNG results are yielded as the raw response content; every other result
    type is yielded as decoded JSON.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = oxylabs.credentials_field(
            description="Oxylabs username and password"
        )
        job_id: str = SchemaField(description="Job ID to get results for")
        result_type: ResultType = SchemaField(
            description="Type of result to retrieve", default=ResultType.DEFAULT
        )

    class Output(BlockSchema):
        content: Union[str, dict, bytes] = SchemaField(description="The scraped data")
        content_type: str = SchemaField(description="MIME type of the content")
        meta: Dict[str, Any] = SchemaField(description="Result metadata")

    def __init__(self):
        super().__init__(
            id="e1a7f9b3-2c6d-8e5f-3a0b-7d1c2e4f5b9a",
            description="Get results from a completed Oxylabs job",
            categories={BlockCategory.SEARCH},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: UserPasswordCredentials, **kwargs
    ) -> BlockOutput:
        """Download the results and yield ``content``, ``content_type``, ``meta``.

        Raises:
            Exception: If the results request returns a non-2xx status.
        """
        wanted_type = input_data.result_type
        results_url = f"https://data.oxylabs.io/v1/queries/{input_data.job_id}/results"

        # The "type" query parameter is only sent for a non-default result type.
        query = {"type": wanted_type} if wanted_type != ResultType.DEFAULT else {}

        auth_headers = {
            "Authorization": self.get_auth_header(credentials),
        }

        if wanted_type != ResultType.PNG:
            # JSON or text results
            body = await self.make_request(
                method="GET",
                url=results_url,
                credentials=credentials,
                params=query,
            )
            mime_type = "application/json"
        else:
            # Screenshots are binary, so the JSON helper cannot be used here.
            png_response = await Requests().request(
                method="GET",
                url=results_url,
                headers=auth_headers,
                params=query,
            )
            if not 200 <= png_response.status < 300:
                raise Exception(f"Failed to get results: {png_response.status}")

            body = png_response.content  # Binary content
            mime_type = png_response.headers.get("Content-Type", "image/png")

        result_meta = {
            "job_id": input_data.job_id,
            "result_type": input_data.result_type,
            "retrieved_at": datetime.utcnow().isoformat(),
        }

        yield "content", body
        yield "content_type", mime_type
        yield "meta", result_meta
|
||||
|
||||
|
||||
# 6. Proxy Fetch URL
|
||||
class OxylabsProxyFetchBlock(OxylabsBlockBase):
    """
    Fetch a single URL through Oxylabs without job management.

    NOTE: this does not use the HTTPS proxy endpoint. It submits a
    ``universal`` source job to the realtime API and yields the page data
    from its response.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = oxylabs.credentials_field(
            description="Oxylabs username and password"
        )
        target_url: str = SchemaField(
            description="URL to fetch (must include https://)"
        )
        geo_location: Optional[str] = SchemaField(
            description="Geographical location", default=None
        )
        user_agent_type: Optional[UserAgentType] = SchemaField(
            description="User agent type", default=None
        )
        render: Literal["none", "html"] = SchemaField(
            description="Enable JavaScript rendering", default="none"
        )

    class Output(BlockSchema):
        html: str = SchemaField(description="Page HTML content")
        status_code: int = SchemaField(description="HTTP status code")
        headers: Dict[str, str] = SchemaField(description="Response headers")

    def __init__(self):
        super().__init__(
            id="f2b8a0c4-3d7e-9f6a-4b1c-8e2d3f5a6c0b",
            description="Fetch a URL through Oxylabs proxy",
            categories={BlockCategory.SEARCH},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: UserPasswordCredentials, **kwargs
    ) -> BlockOutput:
        """Fetch ``target_url`` and yield ``html``, ``status_code``, ``headers``.

        Missing fields in the response fall back to ``""``, ``200`` and ``{}``.
        """
        # A true proxy integration would route the HTTP client through
        # realtime.oxylabs.io:60000 and send x-oxylabs-* headers. The realtime
        # API takes the same options as payload fields, so they are sent there.
        payload: Dict[str, Any] = {
            "source": "universal",
            "url": input_data.target_url,
        }
        if input_data.geo_location:
            payload["geo_location"] = input_data.geo_location
        if input_data.user_agent_type:
            payload["user_agent_type"] = input_data.user_agent_type
        if input_data.render != "none":
            payload["render"] = input_data.render

        result = await self.make_request(
            method="POST",
            url="https://realtime.oxylabs.io/v1/queries",
            credentials=credentials,
            json_data=payload,
            timeout=300,
        )

        # The realtime API nests page data inside a "results" list. Use its
        # first entry when present, and fall back to the top level otherwise.
        results = result.get("results")
        if isinstance(results, list) and results and isinstance(results[0], dict):
            page = results[0]
        else:
            page = result

        yield "html", page.get("content", "")
        yield "status_code", page.get("status_code", 200)
        yield "headers", page.get("headers", {})
|
||||
|
||||
|
||||
# 7. Callback Trigger (Webhook) - This would be handled by the platform's webhook system
|
||||
# We'll create a block to process webhook data instead
|
||||
class OxylabsProcessWebhookBlock(OxylabsBlockBase):
    """
    Process incoming Oxylabs webhook callback data.

    Extracts the job ID, status and results URL from the webhook payload.
    No request is made to Oxylabs.
    """

    class Input(BlockSchema):
        webhook_payload: Dict[str, Any] = SchemaField(
            description="Raw webhook payload from Oxylabs"
        )
        verify_ip: bool = SchemaField(
            description="Verify the request came from Oxylabs IPs", default=True
        )
        source_ip: Optional[str] = SchemaField(
            description="IP address of the webhook sender", default=None
        )

    class Output(BlockSchema):
        job_id: str = SchemaField(description="Job ID from callback")
        status: JobStatus = SchemaField(description="Job completion status")
        results_url: Optional[str] = SchemaField(
            description="URL to fetch the results", default=None
        )
        raw_callback: Dict[str, Any] = SchemaField(description="Full callback payload")

    def __init__(self):
        super().__init__(
            id="a3c9b1d5-4e8f-0b2d-5c6e-9f0a1d3f7b8c",
            description="Process Oxylabs webhook callback data",
            categories={BlockCategory.SEARCH},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self,
        input_data: Input,
        *,
        credentials: Optional[UserPasswordCredentials] = None,
        **kwargs,
    ) -> BlockOutput:
        """Parse the webhook payload and yield ``job_id``, ``status``, ``results_url``, ``raw_callback``.

        Args:
            input_data: The webhook payload and verification settings.
            credentials: Unused. Optional because this block's input schema
                has no credentials field, so callers may not supply one.

        Raises:
            ValueError: If the payload's ``status`` is not a known JobStatus.
        """
        payload = input_data.webhook_payload

        job_id = payload.get("id", "")
        status = JobStatus(payload.get("status", "pending"))

        # First link whose rel is "results", if any.
        results_url = next(
            (
                link.get("href")
                for link in payload.get("_links", [])
                if link.get("rel") == "results"
            ),
            None,
        )

        # NOTE(review): `verify_ip` / `source_ip` are accepted but no check is
        # performed yet. A real implementation would compare `source_ip`
        # against the Oxylabs callbacker IP list. Do not rely on this block
        # for sender verification.

        yield "job_id", job_id
        yield "status", status
        yield "results_url", results_url
        yield "raw_callback", payload
|
||||
|
||||
|
||||
# 8. Callbacker IP List
|
||||
class OxylabsCallbackerIPListBlock(OxylabsBlockBase):
    """
    Get the list of IP addresses used by Oxylabs for callbacks.

    Use this for firewall whitelisting.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = oxylabs.credentials_field(
            description="Oxylabs username and password"
        )

    class Output(BlockSchema):
        ip_list: List[str] = SchemaField(description="List of Oxylabs callback IPs")
        updated_at: str = SchemaField(description="Timestamp of retrieval")

    def __init__(self):
        super().__init__(
            id="b4d0c2e6-5f9a-1c3e-6d7f-0a1b2d4e8c9d",
            description="Get Oxylabs callback IP addresses",
            categories={BlockCategory.SEARCH},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: UserPasswordCredentials, **kwargs
    ) -> BlockOutput:
        """Fetch the callback IPs and yield ``ip_list`` and ``updated_at``.

        ``updated_at`` is the local retrieval time (naive UTC, ISO format),
        not a timestamp reported by Oxylabs.
        """
        result = await self.make_request(
            method="GET",
            url="https://data.oxylabs.io/v1/info/callbacker_ips",
            credentials=credentials,
        )

        # The documented response key is "ips"; "callbacker_ips" is kept as a
        # fallback for compatibility with the previous behaviour.
        ip_list = result.get("ips") or result.get("callbacker_ips", [])
        updated_at = datetime.utcnow().isoformat()

        yield "ip_list", ip_list
        yield "updated_at", updated_at
|
||||
@@ -1,24 +0,0 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import APIKeyCredentials, CredentialsMetaInput, ProviderName
|
||||
|
||||
# Mock Replicate credentials for block tests; the key is a placeholder.
TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="replicate",
    api_key=SecretStr("mock-replicate-api-key"),
    title="Mock Replicate API key",
    expires_at=None,
)

# Credentials-field input that refers to TEST_CREDENTIALS.
TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    # Was `.type`, which put the credential type in the title slot.
    "title": TEST_CREDENTIALS.title,
}

# Replicate authenticates with a plain API key.
ReplicateCredentials = APIKeyCredentials
ReplicateCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.REPLICATE], Literal["api_key"]
]
|
||||
@@ -1,39 +0,0 @@
|
||||
import logging
|
||||
|
||||
from replicate.helpers import FileOutput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ReplicateOutputs = FileOutput | list[FileOutput] | list[str] | str | list[dict]
|
||||
|
||||
|
||||
def extract_result(output: ReplicateOutputs) -> str:
    """Reduce a Replicate output to a single string.

    A ``FileOutput`` becomes its URL and a string is returned unchanged. For a
    non-empty list the first element decides: ``FileOutput`` gives the first
    URL, ``str`` gives all elements joined, ``dict`` gives ``str()`` of the
    first element. Anything else is logged as an error and a fallback message
    is returned.
    """
    if isinstance(output, list) and output:
        head = output[0]
        # Only the first element is inspected; checking every element would be slower.
        if isinstance(head, FileOutput):
            return head.url
        if isinstance(head, str):
            # Text output arrives as string chunks; concatenate them.
            return "".join(output)  # type: ignore
        if isinstance(head, dict):
            return str(head)
        logger.error(
            "Replicate generated a new output type that's not a file output or a str in a replicate block"
        )
        return (
            "Unable to process result. Please contact us with the models and inputs used"
        )

    if isinstance(output, FileOutput):
        return output.url

    if isinstance(output, str):
        return output

    # Empty list or an unexpected type.
    logger.error(
        "We somehow didn't get an output from a replicate block. This is almost certainly an error"
    )
    return "No output received"
|
||||
@@ -1,133 +0,0 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import SecretStr
|
||||
from replicate.client import Client as ReplicateClient
|
||||
|
||||
from backend.blocks.replicate._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
ReplicateCredentialsInput,
|
||||
)
|
||||
from backend.blocks.replicate._helper import ReplicateOutputs, extract_result
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReplicateModelBlock(Block):
|
||||
"""
|
||||
Block for running any Replicate model with custom inputs.
|
||||
|
||||
This block allows you to:
|
||||
- Use any public Replicate model
|
||||
- Pass custom inputs as a dictionary
|
||||
- Specify model versions
|
||||
- Get structured outputs with prediction metadata
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: ReplicateCredentialsInput = CredentialsField(
|
||||
description="Enter your Replicate API key to access the model API. You can obtain an API key from https://replicate.com/account/api-tokens.",
|
||||
)
|
||||
model_name: str = SchemaField(
|
||||
description="The Replicate model name (format: 'owner/model-name')",
|
||||
placeholder="stability-ai/stable-diffusion",
|
||||
advanced=False,
|
||||
)
|
||||
model_inputs: dict[str, str | int] = SchemaField(
|
||||
default={},
|
||||
description="Dictionary of inputs to pass to the model",
|
||||
placeholder='{"prompt": "a beautiful landscape", "num_outputs": 1}',
|
||||
advanced=False,
|
||||
)
|
||||
version: Optional[str] = SchemaField(
|
||||
default=None,
|
||||
description="Specific version hash of the model (optional)",
|
||||
placeholder="db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: str = SchemaField(description="The output from the Replicate model")
|
||||
status: str = SchemaField(description="Status of the prediction")
|
||||
model_name: str = SchemaField(description="Name of the model used")
|
||||
error: str = SchemaField(description="Error message if any", default="")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c40d75a2-d0ea-44c9-a4f6-634bb3bdab1a",
|
||||
description="Run Replicate models synchronously",
|
||||
categories={BlockCategory.AI},
|
||||
input_schema=ReplicateModelBlock.Input,
|
||||
output_schema=ReplicateModelBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"model_name": "meta/llama-2-7b-chat",
|
||||
"model_inputs": {"prompt": "Hello, world!", "max_new_tokens": 50},
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("result", str),
|
||||
("status", str),
|
||||
("model_name", str),
|
||||
],
|
||||
test_mock={
|
||||
"run_model": lambda model_ref, model_inputs, api_key: (
|
||||
"Mock response from Replicate model"
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
async def run(
    self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
    """
    Execute the Replicate model with the provided inputs.

    Args:
        input_data: The input data containing model name, model inputs and
            an optional version hash
        credentials: The API credentials

    Yields:
        ("result", <model output>), ("status", "succeeded") and
        ("model_name", <input model name>) when the call succeeds

    Raises:
        RuntimeError: If any step fails. The triggering exception is chained
            as ``__cause__`` so the original traceback is preserved.
    """
    try:
        # Append the version hash only when the caller supplied one.
        if input_data.version:
            model_ref = f"{input_data.model_name}:{input_data.version}"
        else:
            model_ref = input_data.model_name
        logger.debug(f"Running Replicate model: {model_ref}")
        result = await self.run_model(
            model_ref, input_data.model_inputs, credentials.api_key
        )
        yield "result", result
        yield "status", "succeeded"
        yield "model_name", input_data.model_name
    except Exception as e:
        error_msg = f"Unexpected error running Replicate model: {str(e)}"
        logger.error(error_msg)
        # Chain explicitly so callers can inspect the root cause.
        raise RuntimeError(error_msg) from e
|
||||
|
||||
async def run_model(self, model_ref: str, model_inputs: dict, api_key: SecretStr):
    """
    Call a Replicate model and return its extracted result.

    Kept as its own method so it can be replaced in tests.

    Args:
        model_ref: The model reference (e.g., "owner/model-name:version")
        model_inputs: The inputs to pass to the model
        api_key: The Replicate API key as SecretStr

    Returns:
        The value produced by extract_result for the model output
    """
    client = ReplicateClient(api_token=api_key.get_secret_value())
    output: ReplicateOutputs = await client.async_run(
        model_ref, input=model_inputs, wait=False
    )  # type: ignore they suck at typing
    return extract_result(output)
|
||||
@@ -1,17 +1,33 @@
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
from replicate.client import Client as ReplicateClient
|
||||
from replicate.helpers import FileOutput
|
||||
|
||||
from backend.blocks.replicate._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
ReplicateCredentialsInput,
|
||||
)
|
||||
from backend.blocks.replicate._helper import ReplicateOutputs, extract_result
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="replicate",
|
||||
api_key=SecretStr("mock-replicate-api-key"),
|
||||
title="Mock Replicate API key",
|
||||
expires_at=None,
|
||||
)
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.type,
|
||||
}
|
||||
|
||||
|
||||
# Model name enum
|
||||
@@ -39,7 +55,9 @@ class ImageType(str, Enum):
|
||||
|
||||
class ReplicateFluxAdvancedModelBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: ReplicateCredentialsInput = CredentialsField(
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REPLICATE], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="The Replicate integration can be used with "
|
||||
"any API key with sufficient permissions for the blocks it is used on.",
|
||||
)
|
||||
@@ -183,7 +201,7 @@ class ReplicateFluxAdvancedModelBlock(Block):
|
||||
client = ReplicateClient(api_token=api_key.get_secret_value())
|
||||
|
||||
# Run the model with additional parameters
|
||||
output: ReplicateOutputs = await client.async_run( # type: ignore This is because they changed the return type, and didn't update the type hint! It should be overloaded depending on the value of `use_file_output` to `FileOutput | list[FileOutput]` but it's `Any | Iterator[Any]`
|
||||
output: FileOutput | list[FileOutput] = await client.async_run( # type: ignore This is because they changed the return type, and didn't update the type hint! It should be overloaded depending on the value of `use_file_output` to `FileOutput | list[FileOutput]` but it's `Any | Iterator[Any]`
|
||||
f"{model_name}",
|
||||
input={
|
||||
"prompt": prompt,
|
||||
@@ -199,6 +217,21 @@ class ReplicateFluxAdvancedModelBlock(Block):
|
||||
wait=False, # don't arbitrarily return data:octect/stream or sometimes url depending on the model???? what is this api
|
||||
)
|
||||
|
||||
result = extract_result(output)
|
||||
# Check if output is a list or a string and extract accordingly; otherwise, assign a default message
|
||||
if isinstance(output, list) and len(output) > 0:
|
||||
if isinstance(output[0], FileOutput):
|
||||
result_url = output[0].url # If output is a list, get the first element
|
||||
else:
|
||||
result_url = output[
|
||||
0
|
||||
] # If output is a list and not a FileOutput, get the first element. Should never happen, but just in case.
|
||||
elif isinstance(output, FileOutput):
|
||||
result_url = output.url # If output is a FileOutput, use the url
|
||||
elif isinstance(output, str):
|
||||
result_url = output # If output is a string (for some reason due to their janky type hinting), use it directly
|
||||
else:
|
||||
result_url = (
|
||||
"No output received" # Fallback message if output is not as expected
|
||||
)
|
||||
|
||||
return result
|
||||
return result_url
|
||||
@@ -108,7 +108,6 @@ class ScreenshotWebPageBlock(Block):
|
||||
async def take_screenshot(
|
||||
credentials: APIKeyCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
url: str,
|
||||
viewport_width: int,
|
||||
viewport_height: int,
|
||||
@@ -154,7 +153,6 @@ class ScreenshotWebPageBlock(Block):
|
||||
file=MediaFileType(
|
||||
f"data:image/{format.value};base64,{b64encode(content).decode('utf-8')}"
|
||||
),
|
||||
user_id=user_id,
|
||||
return_content=True,
|
||||
)
|
||||
}
|
||||
@@ -165,14 +163,12 @@ class ScreenshotWebPageBlock(Block):
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
screenshot_data = await self.take_screenshot(
|
||||
credentials=credentials,
|
||||
graph_exec_id=graph_exec_id,
|
||||
user_id=user_id,
|
||||
url=input_data.url,
|
||||
viewport_width=input_data.viewport_width,
|
||||
viewport_height=input_data.viewport_height,
|
||||
|
||||
@@ -1,180 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import ContributorDetails, SchemaField
|
||||
from backend.util.file import get_exec_file_path, store_media_file
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
|
||||
class ReadSpreadsheetBlock(Block):
    # Parses CSV text (or a CSV/Excel file) into row dictionaries and yields
    # them either one at a time ("row") or all together ("rows").

    class Input(BlockSchema):
        contents: str | None = SchemaField(
            description="The contents of the CSV/spreadsheet data to read",
            placeholder="a, b, c\n1,2,3\n4,5,6",
            default=None,
            advanced=False,
        )
        file_input: MediaFileType | None = SchemaField(
            description="CSV or Excel file to read from (URL, data URI, or local path). Excel files are automatically converted to CSV",
            default=None,
            advanced=False,
        )
        delimiter: str = SchemaField(
            description="The delimiter used in the CSV/spreadsheet data",
            default=",",
        )
        quotechar: str = SchemaField(
            description="The character used to quote fields",
            default='"',
        )
        escapechar: str = SchemaField(
            description="The character used to escape the delimiter",
            default="\\",
        )
        has_header: bool = SchemaField(
            description="Whether the CSV file has a header row",
            default=True,
        )
        skip_rows: int = SchemaField(
            description="The number of rows to skip from the start of the file",
            default=0,
        )
        strip: bool = SchemaField(
            description="Whether to strip whitespace from the values",
            default=True,
        )
        # Entries are matched against the column's header name and against
        # its zero-based index written as a string (e.g. "0").
        skip_columns: list[str] = SchemaField(
            description="The columns to skip from the start of the row",
            default_factory=list,
        )
        produce_singular_result: bool = SchemaField(
            description="If True, yield individual 'row' outputs only (can be slow). If False, yield both 'rows' (all data)",
            default=False,
        )

    class Output(BlockSchema):
        row: dict[str, str] = SchemaField(
            description="The data produced from each row in the spreadsheet"
        )
        rows: list[dict[str, str]] = SchemaField(
            description="All the data in the spreadsheet as a list of rows"
        )

    def __init__(self):
        super().__init__(
            id="acf7625e-d2cb-4941-bfeb-2819fc6fc015",
            input_schema=ReadSpreadsheetBlock.Input,
            output_schema=ReadSpreadsheetBlock.Output,
            description="Reads CSV and Excel files and outputs the data as a list of dictionaries and individual rows. Excel files are automatically converted to CSV format.",
            contributors=[ContributorDetails(name="Nicholas Tindle")],
            categories={BlockCategory.TEXT, BlockCategory.DATA},
            test_input=[
                {
                    "contents": "a, b, c\n1,2,3\n4,5,6",
                    "produce_singular_result": False,
                },
                {
                    "contents": "a, b, c\n1,2,3\n4,5,6",
                    "produce_singular_result": True,
                },
            ],
            test_output=[
                (
                    "rows",
                    [
                        {"a": "1", "b": "2", "c": "3"},
                        {"a": "4", "b": "5", "c": "6"},
                    ],
                ),
                ("row", {"a": "1", "b": "2", "c": "3"}),
                ("row", {"a": "4", "b": "5", "c": "6"}),
            ],
        )

    async def run(
        self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
    ) -> BlockOutput:
        """
        Parse the spreadsheet data and yield it as dictionaries.

        ``file_input`` takes precedence over ``contents``. Files ending in
        .xlsx/.xls are converted to CSV text through pandas before parsing.

        Yields:
            One ("row", dict) pair per data row when
            ``produce_singular_result`` is set, otherwise a single
            ("rows", list[dict]) pair.

        Raises:
            ValueError: If neither input is provided, the stored file does
                not exist, or an Excel file cannot be read.
        """
        import csv
        from io import StringIO

        # Determine data source - prefer file_input if provided, otherwise use contents
        if input_data.file_input:
            stored_file_path = await store_media_file(
                user_id=user_id,
                graph_exec_id=graph_exec_id,
                file=input_data.file_input,
                return_content=False,
            )

            # Get full file path
            file_path = get_exec_file_path(graph_exec_id, stored_file_path)
            if not Path(file_path).exists():
                raise ValueError(f"File does not exist: {file_path}")

            # Check if file is an Excel file and convert to CSV
            file_extension = Path(file_path).suffix.lower()

            if file_extension in (".xlsx", ".xls"):
                try:
                    import pandas as pd

                    df = pd.read_excel(file_path)

                    # Convert to CSV string
                    csv_buffer = StringIO()
                    df.to_csv(csv_buffer, index=False)
                    csv_content = csv_buffer.getvalue()
                except ImportError as e:
                    raise ValueError(
                        "pandas library is required to read Excel files. Please install it."
                    ) from e
                except Exception as e:
                    raise ValueError(f"Unable to read Excel file: {e}") from e
            else:
                # Handle CSV/text files
                csv_content = Path(file_path).read_text(encoding="utf-8")
        elif input_data.contents:
            # Use direct string content
            csv_content = input_data.contents
        else:
            raise ValueError("Either 'contents' or 'file_input' must be provided")

        reader = csv.reader(
            StringIO(csv_content),
            delimiter=input_data.delimiter,
            quotechar=input_data.quotechar,
            escapechar=input_data.escapechar,
        )

        header = None
        if input_data.has_header:
            # A default avoids StopIteration on empty input; inside an async
            # generator that would surface as an opaque RuntimeError.
            header = next(reader, None) or []
            if input_data.strip:
                header = [h.strip() for h in header]

        for _ in range(input_data.skip_rows):
            # Stop quietly when there are fewer rows than skip_rows.
            if next(reader, None) is None:
                break

        # skip_columns holds strings, so compare against string keys; the
        # previous int-vs-str membership test never matched anything.
        skipped = set(input_data.skip_columns)

        def process_row(row):
            data = {}
            for i, value in enumerate(row):
                index_key = str(i)
                # Fall back to the index for cells beyond the header width
                # instead of raising IndexError.
                key = header[i] if header and i < len(header) else index_key
                if key in skipped or index_key in skipped:
                    continue
                data[key] = value.strip() if input_data.strip else value
            return data

        rows = [process_row(row) for row in reader]

        if input_data.produce_singular_result:
            for processed_row in rows:
                yield "row", processed_row
        else:
            yield "rows", rows
|
||||
@@ -106,7 +106,6 @@ class TestHttpBlockWithHostScopedCredentials:
|
||||
input_data,
|
||||
credentials=exact_match_credentials,
|
||||
graph_exec_id="test-exec-id",
|
||||
user_id="test-user-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
@@ -162,7 +161,6 @@ class TestHttpBlockWithHostScopedCredentials:
|
||||
input_data,
|
||||
credentials=wildcard_credentials,
|
||||
graph_exec_id="test-exec-id",
|
||||
user_id="test-user-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
@@ -209,7 +207,6 @@ class TestHttpBlockWithHostScopedCredentials:
|
||||
input_data,
|
||||
credentials=non_matching_credentials,
|
||||
graph_exec_id="test-exec-id",
|
||||
user_id="test-user-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
@@ -259,7 +256,6 @@ class TestHttpBlockWithHostScopedCredentials:
|
||||
input_data,
|
||||
credentials=exact_match_credentials,
|
||||
graph_exec_id="test-exec-id",
|
||||
user_id="test-user-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
@@ -319,7 +315,6 @@ class TestHttpBlockWithHostScopedCredentials:
|
||||
input_data,
|
||||
credentials=auto_discovered_creds, # Execution manager found these
|
||||
graph_exec_id="test-exec-id",
|
||||
user_id="test-user-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
@@ -383,7 +378,6 @@ class TestHttpBlockWithHostScopedCredentials:
|
||||
input_data,
|
||||
credentials=multi_header_creds,
|
||||
graph_exec_id="test-exec-id",
|
||||
user_id="test-user-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
@@ -472,7 +466,6 @@ class TestHttpBlockWithHostScopedCredentials:
|
||||
input_data,
|
||||
credentials=test_creds,
|
||||
graph_exec_id="test-exec-id",
|
||||
user_id="test-user-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util import json, text
|
||||
from backend.util.file import get_exec_file_path, store_media_file
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
formatter = text.TextFormatter()
|
||||
|
||||
@@ -306,129 +303,3 @@ class TextReplaceBlock(Block):
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
    """Yield the input text with each occurrence of `old` replaced by `new`."""
    replaced = input_data.text.replace(input_data.old, input_data.new)
    yield "output", replaced
|
||||
|
||||
|
||||
class FileReadBlock(Block):
    # Reads a stored media file as text and yields it whole or in chunks.

    class Input(BlockSchema):
        file_input: MediaFileType = SchemaField(
            description="The file to read from (URL, data URI, or local path)"
        )
        delimiter: str = SchemaField(
            description="Delimiter to split the content into rows/chunks (e.g., '\\n' for lines)",
            default="",
            advanced=True,
        )
        size_limit: int = SchemaField(
            description="Maximum size in bytes per chunk to yield (0 for no limit)",
            default=0,
            advanced=True,
        )
        row_limit: int = SchemaField(
            description="Maximum number of rows to process (0 for no limit, requires delimiter)",
            default=0,
            advanced=True,
        )
        skip_size: int = SchemaField(
            description="Number of characters to skip from the beginning of the file",
            default=0,
            advanced=True,
        )
        skip_rows: int = SchemaField(
            description="Number of rows to skip from the beginning (requires delimiter)",
            default=0,
            advanced=True,
        )

    class Output(BlockSchema):
        content: str = SchemaField(
            description="File content, yielded as individual chunks when delimiter or size limits are applied"
        )

    def __init__(self):
        sample_input = {
            "file_input": "data:text/plain;base64,SGVsbG8gV29ybGQ=",
        }
        super().__init__(
            id="3735a31f-7e18-4aca-9e90-08a7120674bc",
            input_schema=FileReadBlock.Input,
            output_schema=FileReadBlock.Output,
            description="Reads a file and returns its content as a string, with optional chunking by delimiter and size limits",
            categories={BlockCategory.TEXT, BlockCategory.DATA},
            test_input=sample_input,
            test_output=[
                ("content", "Hello World"),
            ],
        )

    async def run(
        self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
    ) -> BlockOutput:
        """
        Store the input file, read it as text and yield "content" chunks.

        The text is trimmed by ``skip_size``, split on ``delimiter`` (when
        set), reduced by ``skip_rows`` / ``row_limit``, joined back together
        and finally sliced into pieces of at most ``size_limit`` characters.

        Raises:
            ValueError: If the stored file is missing or cannot be read.
        """
        # Store the media file properly (handles URLs, data URIs, etc.)
        stored_file_path = await store_media_file(
            user_id=user_id,
            graph_exec_id=graph_exec_id,
            file=input_data.file_input,
            return_content=False,
        )

        file_path = get_exec_file_path(graph_exec_id, stored_file_path)
        if not Path(file_path).exists():
            raise ValueError(f"File does not exist: {file_path}")

        # UTF-8 first; latin-1 is the fallback when decoding fails.
        try:
            with open(file_path, "r", encoding="utf-8") as handle:
                content = handle.read()
        except UnicodeDecodeError:
            try:
                with open(file_path, "r", encoding="latin-1") as handle:
                    content = handle.read()
            except Exception as e:
                raise ValueError(f"Unable to read file: {e}")

        separator = input_data.delimiter
        limit = input_data.size_limit

        # Character-level skip
        if input_data.skip_size > 0:
            content = content[input_data.skip_size :]

        # Without a delimiter the whole text is treated as a single item.
        pieces = content.split(separator) if separator else [content]

        # Item-level skip, then item-level limit
        if input_data.skip_rows > 0:
            pieces = pieces[input_data.skip_rows :]
        if input_data.row_limit > 0:
            pieces = pieces[: input_data.row_limit]

        if not pieces:
            yield "content", ""
            return

        # An empty separator joins exactly like "".join, so one call covers both cases.
        joined = separator.join(pieces)

        if limit <= 0:
            # No size limit: emit the text once, and nothing at all if it is empty.
            if joined:
                yield "content", joined
            return

        for start in range(0, len(joined), limit):
            yield "content", joined[start : start + limit]
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from youtube_transcript_api._api import YouTubeTranscriptApi
|
||||
from youtube_transcript_api._transcripts import FetchedTranscript
|
||||
from youtube_transcript_api import YouTubeTranscriptApi
|
||||
from youtube_transcript_api.formatters import TextFormatter
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
@@ -43,7 +42,6 @@ class TranscribeYoutubeVideoBlock(Block):
|
||||
{"text": "Never gonna give you up"},
|
||||
{"text": "Never gonna let you down"},
|
||||
],
|
||||
"format_transcript": lambda transcript: "Never gonna give you up\nNever gonna let you down",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -63,20 +61,30 @@ class TranscribeYoutubeVideoBlock(Block):
|
||||
raise ValueError(f"Invalid YouTube URL: {url}")
|
||||
|
||||
@staticmethod
|
||||
def get_transcript(video_id: str) -> FetchedTranscript:
|
||||
return YouTubeTranscriptApi().fetch(video_id=video_id)
|
||||
def get_transcript(video_id: str):
|
||||
try:
|
||||
transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
|
||||
|
||||
@staticmethod
|
||||
def format_transcript(transcript: FetchedTranscript) -> str:
|
||||
formatter = TextFormatter()
|
||||
transcript_text = formatter.format_transcript(transcript)
|
||||
return transcript_text
|
||||
if not transcript_list:
|
||||
raise ValueError(f"No transcripts found for the video: {video_id}")
|
||||
|
||||
for transcript in transcript_list:
|
||||
first_transcript = transcript_list.find_transcript(
|
||||
[transcript.language_code]
|
||||
)
|
||||
return YouTubeTranscriptApi.get_transcript(
|
||||
video_id, languages=[first_transcript.language_code]
|
||||
)
|
||||
|
||||
except Exception:
|
||||
raise ValueError(f"No transcripts found for the video: {video_id}")
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
video_id = self.extract_video_id(input_data.youtube_url)
|
||||
yield "video_id", video_id
|
||||
|
||||
transcript = self.get_transcript(video_id)
|
||||
transcript_text = self.format_transcript(transcript=transcript)
|
||||
formatter = TextFormatter()
|
||||
transcript_text = formatter.format_transcript(transcript)
|
||||
|
||||
yield "transcript", transcript_text
|
||||
|
||||
@@ -1,359 +0,0 @@
|
||||
import asyncio
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
from faker import Faker
|
||||
from prisma import Prisma
|
||||
|
||||
faker = Faker()
|
||||
|
||||
|
||||
async def check_cron_job(db):
    """Check if the pg_cron job for refreshing materialized views exists.

    Args:
        db: Connected client exposing an awaitable ``query_raw(sql)``.

    Returns:
        True when the pg_cron extension is registered and a job named
        'refresh-store-views' exists; False otherwise, including when any
        query raises.
    """
    print("\n1. Checking pg_cron job...")
    print("-" * 40)

    try:
        # IF NOT EXISTS is required: a plain CREATE EXTENSION errors when the
        # extension is already installed, which made this check report
        # failure in exactly the case it is meant to confirm.
        extension_check = await db.query_raw(
            "CREATE EXTENSION IF NOT EXISTS pg_cron;"
        )
        print(extension_check)
        extension_check = await db.query_raw(
            "SELECT COUNT(*) as count FROM pg_extension WHERE extname = 'pg_cron'"
        )
        if extension_check[0]["count"] == 0:
            print("⚠️ pg_cron extension is not installed")
            return False

        # Check if the refresh job exists
        job_check = await db.query_raw(
            """
            SELECT jobname, schedule, command
            FROM cron.job
            WHERE jobname = 'refresh-store-views'
            """
        )

        if not job_check:
            print("⚠️ pg_cron job 'refresh-store-views' not found")
            return False

        job = job_check[0]
        print("✅ pg_cron job found:")
        print(f" Name: {job['jobname']}")
        print(f" Schedule: {job['schedule']} (every 15 minutes)")
        print(f" Command: {job['command']}")
        return True

    except Exception as e:
        # Best-effort check: any database error is reported, not raised.
        print(f"❌ Error checking pg_cron: {e}")
        return False
|
||||
|
||||
|
||||
async def get_materialized_view_counts(db):
    """Query, print and return aggregate figures from the store views.

    Args:
        db: Connected client exposing an awaitable ``query_raw(sql)``.

    Returns:
        Dict with keys "agent_runs", "reviews" and "store_agents"; each value
        is the first row of the matching query, or {} when it returned none.
    """
    print("\n2. Getting current materialized view data...")
    print("-" * 40)

    def first_or_empty(result):
        # Each aggregate query yields at most one interesting row.
        return result[0] if result else {}

    # Get counts from mv_agent_run_counts
    run_rows = await db.query_raw(
        """
        SELECT COUNT(*) as total_agents,
               SUM(run_count) as total_runs,
               MAX(run_count) as max_runs,
               MIN(run_count) as min_runs
        FROM mv_agent_run_counts
        """
    )

    # Get counts from mv_review_stats
    review_rows = await db.query_raw(
        """
        SELECT COUNT(*) as total_listings,
               SUM(review_count) as total_reviews,
               AVG(avg_rating) as overall_avg_rating
        FROM mv_review_stats
        """
    )

    # Get sample data from StoreAgent view
    store_rows = await db.query_raw(
        """
        SELECT COUNT(*) as total_store_agents,
               AVG(runs) as avg_runs,
               AVG(rating) as avg_rating
        FROM "StoreAgent"
        """
    )

    agent_run_data = first_or_empty(run_rows)
    review_data = first_or_empty(review_rows)
    store_data = first_or_empty(store_rows)

    # Averages can come back as None; fall back to 0 before formatting.
    overall_rating = review_data.get("overall_avg_rating") or 0
    mean_runs = store_data.get("avg_runs") or 0
    mean_rating = store_data.get("avg_rating") or 0

    print("📊 mv_agent_run_counts:")
    print(f" Total agents: {agent_run_data.get('total_agents', 0)}")
    print(f" Total runs: {agent_run_data.get('total_runs', 0)}")
    print(f" Max runs per agent: {agent_run_data.get('max_runs', 0)}")
    print(f" Min runs per agent: {agent_run_data.get('min_runs', 0)}")

    print("\n📊 mv_review_stats:")
    print(f" Total listings: {review_data.get('total_listings', 0)}")
    print(f" Total reviews: {review_data.get('total_reviews', 0)}")
    print(f" Overall avg rating: {overall_rating:.2f}")

    print("\n📊 StoreAgent view:")
    print(f" Total store agents: {store_data.get('total_store_agents', 0)}")
    print(f" Average runs: {mean_runs:.2f}")
    print(f" Average rating: {mean_rating:.2f}")

    return {
        "agent_runs": agent_run_data,
        "reviews": review_data,
        "store_agents": store_data,
    }
|
||||
|
||||
|
||||
async def add_test_data(db):
    """Insert executions and reviews so a view refresh has something to pick up.

    Creates between 2 and 5 completed executions per fetched graph, creates
    up to three approved store listings when none exist, then adds at most
    one review per approved listing version from a user who has not yet
    reviewed it.

    Args:
        db: Connected Prisma-style client.

    Returns:
        False when no users or graphs are available, True otherwise.
    """
    print("\n3. Adding test data...")
    print("-" * 40)

    # Get some existing data
    users = await db.user.find_many(take=5)
    graphs = await db.agentgraph.find_many(take=5)

    if not (users and graphs):
        print("❌ No existing users or graphs found. Run test_data_creator.py first.")
        return False

    # Add new executions
    print("Adding new agent graph executions...")
    execution_total = 0
    for graph in graphs:
        batch_size = random.randint(2, 5)
        for _ in range(batch_size):
            execution = {
                "agentGraphId": graph.id,
                "agentGraphVersion": graph.version,
                "userId": random.choice(users).id,
                "executionStatus": "COMPLETED",
                "startedAt": datetime.now(),
            }
            await db.agentgraphexecution.create(data=execution)
            execution_total += 1

    print(f"✅ Added {execution_total} new executions")

    # Check if we need to create store listings first
    approved_versions = await db.storelistingversion.find_many(
        where={"submissionStatus": "APPROVED"}, take=5
    )

    if not approved_versions:
        print("\nNo approved store listings found. Creating test store listings...")

        # Create up to 3 store listings for existing agent graphs
        for number, graph in enumerate(graphs[:3], start=1):
            listing_data = {
                "slug": f"test-agent-{graph.id[:8]}",
                "agentGraphId": graph.id,
                "agentGraphVersion": graph.version,
                "hasApprovedVersion": True,
                "owningUserId": graph.userId,
            }
            listing = await db.storelisting.create(data=listing_data)

            # Create an approved version
            version_data = {
                "storeListingId": listing.id,
                "agentGraphId": graph.id,
                "agentGraphVersion": graph.version,
                "name": f"Test Agent {number}",
                "subHeading": faker.catch_phrase(),
                "description": faker.paragraph(nb_sentences=5),
                "imageUrls": [faker.image_url()],
                "categories": ["productivity", "automation"],
                "submissionStatus": "APPROVED",
                "submittedAt": datetime.now(),
            }
            created_version = await db.storelistingversion.create(data=version_data)

            # Update listing with active version
            await db.storelisting.update(
                where={"id": listing.id},
                data={"activeVersionId": created_version.id},
            )

        print("✅ Created test store listings")

        # Re-fetch approved versions
        approved_versions = await db.storelistingversion.find_many(
            where={"submissionStatus": "APPROVED"}, take=5
        )

    # Add new reviews
    print("\nAdding new store listing reviews...")
    review_total = 0
    for listing_version in approved_versions:
        # Find users who haven't reviewed this version
        prior_reviews = await db.storelistingreview.find_many(
            where={"storeListingVersionId": listing_version.id}
        )
        already_reviewed = {review.reviewByUserId for review in prior_reviews}
        candidates = [u for u in users if u.id not in already_reviewed]
        if not candidates:
            continue

        reviewer = random.choice(candidates)
        review_data = {
            "storeListingVersionId": listing_version.id,
            "reviewByUserId": reviewer.id,
            "score": random.randint(3, 5),
            "comments": faker.text(max_nb_chars=100),
        }
        await db.storelistingreview.create(data=review_data)
        review_total += 1

    print(f"✅ Added {review_total} new reviews")

    return True
|
||||
|
||||
|
||||
async def refresh_materialized_views(db):
    """Trigger a manual refresh of the store materialized views.

    Args:
        db: Connected client exposing an awaitable ``execute_raw(sql)``.

    Returns:
        True when the refresh statement ran, False when it raised.
    """
    print("\n4. Manually refreshing materialized views...")
    print("-" * 40)

    try:
        await db.execute_raw("SELECT refresh_store_materialized_views();")
    except Exception as e:
        # Failures are reported to the caller through the return value.
        print(f"❌ Error refreshing views: {e}")
        return False

    print("✅ Materialized views refreshed successfully")
    return True
|
||||
|
||||
|
||||
async def compare_counts(before, after):
    """Compare counts before and after refresh.

    Args:
        before: Snapshot dict with "agent_runs", "reviews" and "store_agents"
            sub-dicts, taken before the refresh.
        after: Snapshot of the same shape taken after the refresh.

    Returns:
        True only when both the total run count and the total review count
        increased; False otherwise. Missing or None values count as 0.
    """
    print("\n5. Comparing counts before and after refresh...")
    print("-" * 40)

    # Compare agent runs
    print("🔍 Agent run changes:")
    before_runs = before["agent_runs"].get("total_runs") or 0
    after_runs = after["agent_runs"].get("total_runs") or 0
    # The "+" format flag renders the sign itself, so a decrease prints as
    # "(-3)" rather than the malformed "(+-3)".
    print(
        f" Total runs: {before_runs} → {after_runs} "
        f"({after_runs - before_runs:+})"
    )

    # Compare reviews
    print("\n🔍 Review changes:")
    before_reviews = before["reviews"].get("total_reviews") or 0
    after_reviews = after["reviews"].get("total_reviews") or 0
    print(
        f" Total reviews: {before_reviews} → {after_reviews} "
        f"({after_reviews - before_reviews:+})"
    )

    # Compare store agents
    print("\n🔍 StoreAgent view changes:")
    before_avg_runs = before["store_agents"].get("avg_runs", 0) or 0
    after_avg_runs = after["store_agents"].get("avg_runs", 0) or 0
    print(
        f" Average runs: {before_avg_runs:.2f} → {after_avg_runs:.2f} "
        f"({after_avg_runs - before_avg_runs:+.2f})"
    )

    # Verify changes occurred, reusing the normalized totals from above
    runs_changed = after_runs > before_runs
    reviews_changed = after_reviews > before_reviews

    if runs_changed and reviews_changed:
        print("\n✅ Materialized views are updating correctly!")
        return True

    print("\n⚠️ Some materialized views may not have updated:")
    if not runs_changed:
        print(" - Agent run counts did not increase")
    if not reviews_changed:
        print(" - Review counts did not increase")
    return False
|
||||
|
||||
|
||||
async def main():
    """Run the materialized-view check end to end and print a summary.

    Steps: verify the database has users, check the pg_cron job, snapshot the
    view counts, add test data, refresh the views manually, snapshot again
    and compare. Any exception raised by a step is printed with its
    traceback rather than propagated. The database connection is always
    closed on the way out.
    """
    db = Prisma()
    await db.connect()

    print("=" * 60)
    print("Materialized Views Test")
    print("=" * 60)

    try:
        # Check if data exists
        user_count = await db.user.count()
        if user_count == 0:
            print("❌ No data in database. Please run test_data_creator.py first.")
            return  # the finally clause disconnects

        # 1. Check cron job
        cron_exists = await check_cron_job(db)

        # 2. Get initial counts
        counts_before = await get_materialized_view_counts(db)

        # 3. Add test data
        data_added = await add_test_data(db)
        refresh_success = False
        # None means the comparison never ran, so nothing can be claimed.
        views_updated = None

        if data_added:
            # Wait a moment for data to be committed
            print("\nWaiting for data to be committed...")
            await asyncio.sleep(2)

            # 4. Manually refresh views
            refresh_success = await refresh_materialized_views(db)

            if refresh_success:
                # 5. Get counts after refresh
                counts_after = await get_materialized_view_counts(db)

                # 6. Compare results, keeping the verdict for the summary
                views_updated = await compare_counts(counts_before, counts_after)

        if views_updated is None:
            views_verdict = "Cannot verify"
        else:
            views_verdict = "Yes" if views_updated else "No"

        # Summary
        print("\n" + "=" * 60)
        print("Test Summary")
        print("=" * 60)
        print(f"✓ pg_cron job exists: {'Yes' if cron_exists else 'No'}")
        print(f"✓ Test data added: {'Yes' if data_added else 'No'}")
        print(f"✓ Manual refresh worked: {'Yes' if refresh_success else 'No'}")
        print(f"✓ Views updated correctly: {views_verdict}")

        if cron_exists:
            print(
                "\n💡 The materialized views will also refresh automatically every 15 minutes via pg_cron."
            )
        else:
            print(
                "\n⚠️ Automatic refresh is not configured. Views must be refreshed manually."
            )

    except Exception as e:
        print(f"\n❌ Test failed with error: {e}")
        import traceback

        traceback.print_exc()
    finally:
        # Runs on success, early return, handled errors and cancellation.
        await db.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,159 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Check store-related data in the database."""
|
||||
|
||||
import asyncio
|
||||
|
||||
from prisma import Prisma
|
||||
|
||||
|
||||
async def check_store_data(db):
    """Print a diagnostic report of the store-related data in the database.

    Sections printed, in order: store listings, store listing versions
    (grouped by submission status), store listing reviews, rows from the
    "StoreAgent" view, and the underlying counts that feed that view.

    Args:
        db: A connected client exposing the `storelisting`,
            `storelistingversion` and `storelistingreview` model accessors
            and an awaitable `query_raw(sql)` method.
    """
    # Local import so the block does not depend on the file's import header.
    from collections import Counter

    print("============================================================")
    print("Store Data Check")
    print("============================================================")

    # Check store listings
    print("\n1. Store Listings:")
    print("-" * 40)
    listings = await db.storelisting.find_many()
    print(f"Total store listings: {len(listings)}")

    for listing in listings[:5]:
        print(f"\nListing ID: {listing.id}")
        print(f" Name: {listing.name}")
        print(f" Status: {listing.status}")
        print(f" Slug: {listing.slug}")

    # Check store listing versions
    print("\n\n2. Store Listing Versions:")
    print("-" * 40)
    versions = await db.storelistingversion.find_many(include={"StoreListing": True})
    print(f"Total store listing versions: {len(versions)}")

    # Group by submission status; Counter keeps first-seen order.
    status_counts = Counter(version.submissionStatus for version in versions)

    print("\nVersions by status:")
    for status, count in status_counts.items():
        print(f" {status}: {count}")

    # Show approved versions
    approved_versions = [v for v in versions if v.submissionStatus == "APPROVED"]
    print(f"\nApproved versions: {len(approved_versions)}")
    for version in approved_versions[:5]:
        print(f"\n Version ID: {version.id}")
        print(f" Listing: {version.StoreListing.name}")
        print(f" Version: {version.version}")

    # Check store listing reviews
    print("\n\n3. Store Listing Reviews:")
    print("-" * 40)
    reviews = await db.storelistingreview.find_many(
        include={"StoreListingVersion": {"include": {"StoreListing": True}}}
    )
    print(f"Total reviews: {len(reviews)}")

    if reviews:
        # Calculate average rating (reviews is non-empty here)
        avg_score = sum(r.score for r in reviews) / len(reviews)
        print(f"Average rating: {avg_score:.2f}")

        # Show sample reviews
        print("\nSample reviews:")
        for review in reviews[:3]:
            print(f"\n Review for: {review.StoreListingVersion.StoreListing.name}")
            print(f" Score: {review.score}")
            # NOTE(review): comments is presumably nullable; guard so a
            # review without text does not abort the report. Confirm schema.
            print(f" Comments: {(review.comments or '')[:100]}...")

    # Check StoreAgent view data
    print("\n\n4. StoreAgent View Data:")
    print("-" * 40)

    # The sample query below is capped by LIMIT 10, so its length is not the
    # view's size; count the rows separately for the "Total" line.
    count_rows = await db.query_raw('SELECT COUNT(*) as count FROM "StoreAgent"')
    total_store_agents = count_rows[0]["count"] if count_rows else 0

    # Query the StoreAgent view
    query = """
        SELECT
            sa.listing_id,
            sa.slug,
            sa.agent_name,
            sa.description,
            sa.featured,
            sa.runs,
            sa.rating,
            sa.creator_username,
            sa.categories,
            sa.updated_at
        FROM "StoreAgent" sa
        LIMIT 10;
    """

    store_agents = await db.query_raw(query)
    print(f"Total store agents in view: {total_store_agents}")

    for agent in store_agents[:5]:
        print(f"\nStore Agent: {agent['agent_name']}")
        print(f" Slug: {agent['slug']}")
        print(f" Runs: {agent['runs']}")
        print(f" Rating: {agent['rating']}")
        print(f" Creator: {agent['creator_username']}")

    # Check the underlying data that should populate StoreAgent
    print("\n\n5. Data that should populate StoreAgent view:")
    print("-" * 40)

    # Check for any APPROVED store listing versions
    query = """
        SELECT COUNT(*) as count
        FROM "StoreListingVersion"
        WHERE "submissionStatus" = 'APPROVED'
    """

    result = await db.query_raw(query)
    approved_count = result[0]["count"] if result else 0
    print(f"Approved store listing versions: {approved_count}")

    # Check for store listings with hasApprovedVersion = true
    query = """
        SELECT COUNT(*) as count
        FROM "StoreListing"
        WHERE "hasApprovedVersion" = true AND "isDeleted" = false
    """

    result = await db.query_raw(query)
    has_approved_count = result[0]["count"] if result else 0
    print(f"Store listings with approved versions: {has_approved_count}")

    # Check agent graph executions
    query = """
        SELECT COUNT(DISTINCT "agentGraphId") as unique_agents,
               COUNT(*) as total_executions
        FROM "AgentGraphExecution"
    """

    result = await db.query_raw(query)
    if result:
        print("\nAgent Graph Executions:")
        print(f" Unique agents with executions: {result[0]['unique_agents']}")
        print(f" Total executions: {result[0]['total_executions']}")
|
||||
|
||||
|
||||
async def main():
    """Connect to the database, run the store data check, then disconnect.

    The disconnect happens even if the check raises.
    """
    client = Prisma()
    await client.connect()

    try:
        await check_store_data(client)
    finally:
        await client.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -425,7 +425,28 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
raise ValueError(f"{self.name} did not produce any output for {output}")
|
||||
|
||||
def merge_stats(self, stats: NodeExecutionStats) -> NodeExecutionStats:
|
||||
self.execution_stats += stats
|
||||
stats_dict = stats.model_dump()
|
||||
current_stats = self.execution_stats.model_dump()
|
||||
|
||||
for key, value in stats_dict.items():
|
||||
if key not in current_stats:
|
||||
# Field doesn't exist yet, just set it, but this will probably
|
||||
# not happen, just in case though so we throw for invalid when
|
||||
# converting back in
|
||||
current_stats[key] = value
|
||||
elif isinstance(value, dict) and isinstance(current_stats[key], dict):
|
||||
current_stats[key].update(value)
|
||||
elif isinstance(value, (int, float)) and isinstance(
|
||||
current_stats[key], (int, float)
|
||||
):
|
||||
current_stats[key] += value
|
||||
elif isinstance(value, list) and isinstance(current_stats[key], list):
|
||||
current_stats[key].extend(value)
|
||||
else:
|
||||
current_stats[key] = value
|
||||
|
||||
self.execution_stats = NodeExecutionStats(**current_stats)
|
||||
|
||||
return self.execution_stats
|
||||
|
||||
@property
|
||||
|
||||
@@ -18,8 +18,7 @@ from backend.blocks.llm import (
|
||||
AITextSummarizerBlock,
|
||||
LlmModel,
|
||||
)
|
||||
from backend.blocks.replicate.flux_advanced import ReplicateFluxAdvancedModelBlock
|
||||
from backend.blocks.replicate.replicate_block import ReplicateModelBlock
|
||||
from backend.blocks.replicate_flux_advanced import ReplicateFluxAdvancedModelBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||
@@ -292,18 +291,6 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
||||
},
|
||||
)
|
||||
],
|
||||
ReplicateModelBlock: [
|
||||
BlockCost(
|
||||
cost_amount=10,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": replicate_credentials.id,
|
||||
"provider": replicate_credentials.provider,
|
||||
"type": replicate_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
AIImageEditorBlock: [
|
||||
BlockCost(
|
||||
cost_amount=10,
|
||||
|
||||
@@ -93,28 +93,6 @@ async def locked_transaction(key: str):
|
||||
yield tx
|
||||
|
||||
|
||||
def get_database_schema() -> str:
    """Return the `schema` query parameter of DATABASE_URL.

    Falls back to "public" when the URL carries no `schema` parameter.
    """
    url_query = urlparse(DATABASE_URL).query
    params = dict(parse_qsl(url_query))
    if "schema" in params:
        return params["schema"]
    return "public"
|
||||
|
||||
|
||||
async def query_raw_with_schema(query_template: str, *args) -> list[dict]:
    """Run a raw SQL query after filling in its `{schema_prefix}` placeholder.

    The prefix is "<schema>." for any schema other than "public", and empty
    for "public".

    Args:
        query_template: SQL text passed through `str.format`, so it may
            reference `{schema_prefix}`.
        *args: Positional query parameters forwarded to `query_raw`.

    Returns:
        Whatever the Prisma client's `query_raw` returns for the query.
    """
    schema_name = get_database_schema()
    if schema_name == "public":
        prefix = ""
    else:
        prefix = f"{schema_name}."
    sql = query_template.format(schema_prefix=prefix)

    import prisma as prisma_module

    client = prisma_module.get_client()
    return await client.query_raw(sql, *args)  # type: ignore
|
||||
|
||||
|
||||
class BaseDbModel(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
|
||||
|
||||
@@ -8,11 +8,8 @@ from redis.asyncio.client import PubSub as AsyncPubSub
|
||||
from redis.client import PubSub
|
||||
|
||||
from backend.data import redis_client as redis
|
||||
from backend.util import json
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = Settings().config
|
||||
|
||||
|
||||
M = TypeVar("M", bound=BaseModel)
|
||||
@@ -31,41 +28,7 @@ class BaseRedisEventBus(Generic[M], ABC):
|
||||
return _EventPayloadWrapper[self.Model]
|
||||
|
||||
def _serialize_message(self, item: M, channel_key: str) -> tuple[str, str]:
|
||||
MAX_MESSAGE_SIZE = config.max_message_size_limit
|
||||
|
||||
try:
|
||||
# Use backend.util.json.dumps which handles datetime and other complex types
|
||||
message = json.dumps(
|
||||
self.Message(payload=item), ensure_ascii=False, separators=(",", ":")
|
||||
)
|
||||
except UnicodeError:
|
||||
# Fallback to ASCII encoding if Unicode causes issues
|
||||
message = json.dumps(
|
||||
self.Message(payload=item), ensure_ascii=True, separators=(",", ":")
|
||||
)
|
||||
logger.warning(
|
||||
f"Unicode serialization failed, falling back to ASCII for channel {channel_key}"
|
||||
)
|
||||
|
||||
# Check message size and truncate if necessary
|
||||
message_size = len(message.encode("utf-8"))
|
||||
if message_size > MAX_MESSAGE_SIZE:
|
||||
logger.warning(
|
||||
f"Message size {message_size} bytes exceeds limit {MAX_MESSAGE_SIZE} bytes for channel {channel_key}. "
|
||||
"Truncating payload to prevent Redis connection issues."
|
||||
)
|
||||
error_payload = {
|
||||
"payload": {
|
||||
"event_type": "error_comms_update",
|
||||
"error": "Payload too large for Redis transmission",
|
||||
"original_size_bytes": message_size,
|
||||
"max_size_bytes": MAX_MESSAGE_SIZE,
|
||||
}
|
||||
}
|
||||
message = json.dumps(
|
||||
error_payload, ensure_ascii=False, separators=(",", ":")
|
||||
)
|
||||
|
||||
message = self.Message(payload=item).model_dump_json()
|
||||
channel_name = f"{self.event_bus_name}/{channel_key}"
|
||||
logger.debug(f"[{channel_name}] Publishing an event to Redis {message}")
|
||||
return message, channel_name
|
||||
|
||||
@@ -40,7 +40,6 @@ from pydantic.fields import Field
|
||||
from backend.server.v2.store.exceptions import DatabaseError
|
||||
from backend.util import type as type_utils
|
||||
from backend.util.settings import Config
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
from .block import (
|
||||
BlockInput,
|
||||
@@ -50,7 +49,7 @@ from .block import (
|
||||
get_io_block_ids,
|
||||
get_webhook_block_ids,
|
||||
)
|
||||
from .db import BaseDbModel, query_raw_with_schema
|
||||
from .db import BaseDbModel
|
||||
from .event_bus import AsyncRedisEventBus, RedisEventBus
|
||||
from .includes import (
|
||||
EXECUTION_RESULT_INCLUDE,
|
||||
@@ -69,21 +68,6 @@ config = Config()
|
||||
# -------------------------- Models -------------------------- #
|
||||
|
||||
|
||||
class BlockErrorStats(BaseModel):
    """Typed data structure for block error statistics."""

    block_id: str
    total_executions: int
    failed_executions: int

    @property
    def error_rate(self) -> float:
        """Failed executions as a percentage of all executions.

        Returns 0.0 when there were no executions, avoiding a division by
        zero.
        """
        total = self.total_executions
        return (self.failed_executions / total) * 100 if total != 0 else 0.0
|
||||
|
||||
|
||||
ExecutionStatus = AgentExecutionStatus
|
||||
|
||||
|
||||
@@ -373,7 +357,6 @@ async def get_graph_executions(
|
||||
created_time_lte: datetime | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[GraphExecutionMeta]:
|
||||
"""⚠️ **Optional `user_id` check**: MUST USE check in user-facing endpoints."""
|
||||
where_filter: AgentGraphExecutionWhereInput = {
|
||||
"isDeleted": False,
|
||||
}
|
||||
@@ -739,7 +722,6 @@ async def delete_graph_execution(
|
||||
|
||||
|
||||
async def get_node_execution(node_exec_id: str) -> NodeExecutionResult | None:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
execution = await AgentNodeExecution.prisma().find_first(
|
||||
where={"id": node_exec_id},
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
@@ -750,19 +732,15 @@ async def get_node_execution(node_exec_id: str) -> NodeExecutionResult | None:
|
||||
|
||||
|
||||
async def get_node_executions(
|
||||
graph_exec_id: str | None = None,
|
||||
graph_exec_id: str,
|
||||
node_id: str | None = None,
|
||||
block_ids: list[str] | None = None,
|
||||
statuses: list[ExecutionStatus] | None = None,
|
||||
limit: int | None = None,
|
||||
created_time_gte: datetime | None = None,
|
||||
created_time_lte: datetime | None = None,
|
||||
include_exec_data: bool = True,
|
||||
) -> list[NodeExecutionResult]:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
where_clause: AgentNodeExecutionWhereInput = {}
|
||||
if graph_exec_id:
|
||||
where_clause["agentGraphExecutionId"] = graph_exec_id
|
||||
where_clause: AgentNodeExecutionWhereInput = {
|
||||
"agentGraphExecutionId": graph_exec_id,
|
||||
}
|
||||
if node_id:
|
||||
where_clause["agentNodeId"] = node_id
|
||||
if block_ids:
|
||||
@@ -770,19 +748,9 @@ async def get_node_executions(
|
||||
if statuses:
|
||||
where_clause["OR"] = [{"executionStatus": status} for status in statuses]
|
||||
|
||||
if created_time_gte or created_time_lte:
|
||||
where_clause["addedTime"] = {
|
||||
"gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc),
|
||||
"lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc),
|
||||
}
|
||||
|
||||
executions = await AgentNodeExecution.prisma().find_many(
|
||||
where=where_clause,
|
||||
include=(
|
||||
EXECUTION_RESULT_INCLUDE
|
||||
if include_exec_data
|
||||
else {"Node": True, "GraphExecution": True}
|
||||
),
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
order=EXECUTION_RESULT_ORDER,
|
||||
take=limit,
|
||||
)
|
||||
@@ -793,7 +761,6 @@ async def get_node_executions(
|
||||
async def get_latest_node_execution(
|
||||
node_id: str, graph_eid: str
|
||||
) -> NodeExecutionResult | None:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
execution = await AgentNodeExecution.prisma().find_first(
|
||||
where={
|
||||
"agentGraphExecutionId": graph_eid,
|
||||
@@ -867,7 +834,6 @@ class ExecutionQueue(Generic[T]):
|
||||
class ExecutionEventType(str, Enum):
|
||||
GRAPH_EXEC_UPDATE = "graph_execution_update"
|
||||
NODE_EXEC_UPDATE = "node_execution_update"
|
||||
ERROR_COMMS_UPDATE = "error_comms_update"
|
||||
|
||||
|
||||
class GraphExecutionEvent(GraphExecution):
|
||||
@@ -902,25 +868,11 @@ class RedisExecutionEventBus(RedisEventBus[ExecutionEvent]):
|
||||
|
||||
def publish_node_exec_update(self, res: NodeExecutionResult):
|
||||
event = NodeExecutionEvent.model_validate(res.model_dump())
|
||||
self._publish(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")
|
||||
self.publish_event(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")
|
||||
|
||||
def publish_graph_exec_update(self, res: GraphExecution):
|
||||
event = GraphExecutionEvent.model_validate(res.model_dump())
|
||||
self._publish(event, f"{res.user_id}/{res.graph_id}/{res.id}")
|
||||
|
||||
def _publish(self, event: ExecutionEvent, channel: str):
    """Truncate the event's input/output payloads, then publish it.

    Each payload field is truncated to half of
    `config.max_message_size_limit` to avoid large payloads. Events that are
    neither graph nor node execution events are published untouched.
    """
    size_cap = config.max_message_size_limit // 2

    if isinstance(event, GraphExecutionEvent):
        payload_fields = ("inputs", "outputs")
    elif isinstance(event, NodeExecutionEvent):
        payload_fields = ("input_data", "output_data")
    else:
        payload_fields = ()

    for field_name in payload_fields:
        setattr(event, field_name, truncate(getattr(event, field_name), size_cap))

    super().publish_event(event, channel)
|
||||
self.publish_event(event, f"{res.user_id}/{res.graph_id}/{res.id}")
|
||||
|
||||
def listen(
|
||||
self, user_id: str, graph_id: str = "*", graph_exec_id: str = "*"
|
||||
@@ -944,30 +896,13 @@ class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionEvent]):
|
||||
|
||||
async def publish_node_exec_update(self, res: NodeExecutionResult):
|
||||
event = NodeExecutionEvent.model_validate(res.model_dump())
|
||||
await self._publish(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")
|
||||
await self.publish_event(
|
||||
event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}"
|
||||
)
|
||||
|
||||
async def publish_graph_exec_update(self, res: GraphExecutionMeta):
|
||||
# GraphExecutionEvent requires inputs and outputs fields that GraphExecutionMeta doesn't have
|
||||
# Add default empty values for compatibility
|
||||
event_data = res.model_dump()
|
||||
event_data.setdefault("inputs", {})
|
||||
event_data.setdefault("outputs", {})
|
||||
event = GraphExecutionEvent.model_validate(event_data)
|
||||
await self._publish(event, f"{res.user_id}/{res.graph_id}/{res.id}")
|
||||
|
||||
async def _publish(self, event: ExecutionEvent, channel: str):
    """Truncate the event's input/output payloads, then publish it.

    Each payload field is truncated to half of
    `config.max_message_size_limit` to avoid large payloads. Events that are
    neither graph nor node execution events are published untouched.
    """
    size_cap = config.max_message_size_limit // 2

    if isinstance(event, GraphExecutionEvent):
        payload_fields = ("inputs", "outputs")
    elif isinstance(event, NodeExecutionEvent):
        payload_fields = ("input_data", "output_data")
    else:
        payload_fields = ()

    for field_name in payload_fields:
        setattr(event, field_name, truncate(getattr(event, field_name), size_cap))

    await super().publish_event(event, channel)
|
||||
event = GraphExecutionEvent.model_validate(res.model_dump())
|
||||
await self.publish_event(event, f"{res.user_id}/{res.graph_id}/{res.id}")
|
||||
|
||||
async def listen(
|
||||
self, user_id: str, graph_id: str = "*", graph_exec_id: str = "*"
|
||||
@@ -1028,33 +963,3 @@ async def set_execution_kv_data(
|
||||
},
|
||||
)
|
||||
return type_utils.convert(resp.data, type[Any]) if resp and resp.data else None
|
||||
|
||||
|
||||
async def get_block_error_stats(
    start_time: datetime, end_time: datetime
) -> list[BlockErrorStats]:
    """Aggregate per-block execution and failure counts for a time window.

    The aggregation runs in SQL over node executions whose `addedTime` lies
    in [start_time, end_time]. Blocks with fewer than 10 executions in the
    window are dropped by the HAVING clause.

    Args:
        start_time: Inclusive lower bound on `addedTime`.
        end_time: Inclusive upper bound on `addedTime`.

    Returns:
        One BlockErrorStats per block that met the execution threshold.
    """
    query_template = """
    SELECT
        n."agentBlockId" as block_id,
        COUNT(*) as total_executions,
        SUM(CASE WHEN ne."executionStatus" = 'FAILED' THEN 1 ELSE 0 END) as failed_executions
    FROM {schema_prefix}"AgentNodeExecution" ne
    JOIN {schema_prefix}"AgentNode" n ON ne."agentNodeId" = n.id
    WHERE ne."addedTime" >= $1::timestamp AND ne."addedTime" <= $2::timestamp
    GROUP BY n."agentBlockId"
    HAVING COUNT(*) >= 10
    """

    rows = await query_raw_with_schema(query_template, start_time, end_time)

    # Convert raw rows to typed data structures
    stats: list[BlockErrorStats] = []
    for row in rows:
        stats.append(
            BlockErrorStats(
                block_id=row["block_id"],
                total_executions=int(row["total_executions"]),
                failed_executions=int(row["failed_executions"]),
            )
        )
    return stats
|
||||
|
||||
@@ -3,6 +3,7 @@ import uuid
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
|
||||
import prisma
|
||||
from prisma import Json
|
||||
from prisma.enums import SubmissionStatus
|
||||
from prisma.models import AgentGraph, AgentNode, AgentNodeLink, StoreListingVersion
|
||||
@@ -13,7 +14,7 @@ from prisma.types import (
|
||||
AgentNodeLinkCreateInput,
|
||||
StoreListingVersionWhereInput,
|
||||
)
|
||||
from pydantic import Field, JsonValue, create_model
|
||||
from pydantic import JsonValue, create_model
|
||||
from pydantic.fields import computed_field
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
@@ -30,7 +31,7 @@ from backend.integrations.providers import ProviderName
|
||||
from backend.util import type as type_utils
|
||||
|
||||
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
|
||||
from .db import BaseDbModel, query_raw_with_schema, transaction
|
||||
from .db import BaseDbModel, transaction
|
||||
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -188,23 +189,6 @@ class BaseGraph(BaseDbModel):
|
||||
)
|
||||
)
|
||||
|
||||
@computed_field
@property
def has_external_trigger(self) -> bool:
    """Whether this graph has a webhook input node."""
    trigger_node = self.webhook_input_node
    return trigger_node is not None
|
||||
|
||||
@property
def webhook_input_node(self) -> Node | None:
    """First node whose block type is WEBHOOK or WEBHOOK_MANUAL, else None."""
    webhook_types = (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
    for candidate in self.nodes:
        if candidate.block.block_type in webhook_types:
            return candidate
    return None
|
||||
|
||||
@staticmethod
|
||||
def _generate_schema(
|
||||
*props: tuple[type[AgentInputBlock.Input] | type[AgentOutputBlock.Input], dict],
|
||||
@@ -342,6 +326,11 @@ class GraphModel(Graph):
|
||||
user_id: str
|
||||
nodes: list[NodeModel] = [] # type: ignore
|
||||
|
||||
@computed_field
@property
def has_webhook_trigger(self) -> bool:
    """Whether this graph has a webhook input node."""
    trigger_node = self.webhook_input_node
    return trigger_node is not None
|
||||
|
||||
@property
|
||||
def starting_nodes(self) -> list[NodeModel]:
|
||||
outbound_nodes = {link.sink_id for link in self.links}
|
||||
@@ -354,12 +343,17 @@ class GraphModel(Graph):
|
||||
if node.id not in outbound_nodes or node.id in input_nodes
|
||||
]
|
||||
|
||||
def meta(self) -> "GraphMeta":
    """
    Build a GraphMeta object carrying this graph's metadata.

    Used to return metadata about the graph without exposing nodes and links.
    """
    graph_meta = GraphMeta.from_graph(self)
    return graph_meta
|
||||
@property
def webhook_input_node(self) -> NodeModel | None:
    """First node whose block type is WEBHOOK or WEBHOOK_MANUAL, else None."""
    webhook_types = (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
    for candidate in self.nodes:
        if candidate.block.block_type in webhook_types:
            return candidate
    return None
|
||||
|
||||
def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
|
||||
"""
|
||||
@@ -618,18 +612,6 @@ class GraphModel(Graph):
|
||||
)
|
||||
|
||||
|
||||
class GraphMeta(Graph):
    """Graph variant whose nodes and links are excluded from serialization."""

    user_id: str

    # Easy work-around to prevent exposing nodes and links in the API response
    nodes: list[NodeModel] = Field(default=[], exclude=True)  # type: ignore
    links: list[Link] = Field(default=[], exclude=True)

    @staticmethod
    def from_graph(graph: GraphModel) -> "GraphMeta":
        """Build a GraphMeta from the dumped fields of a GraphModel."""
        graph_fields = graph.model_dump()
        return GraphMeta(**graph_fields)
|
||||
|
||||
|
||||
# --------------------- CRUD functions --------------------- #
|
||||
|
||||
|
||||
@@ -658,10 +640,10 @@ async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:
|
||||
return NodeModel.from_db(node)
|
||||
|
||||
|
||||
async def list_graphs(
|
||||
async def get_graphs(
|
||||
user_id: str,
|
||||
filter_by: Literal["active"] | None = "active",
|
||||
) -> list[GraphMeta]:
|
||||
) -> list[GraphModel]:
|
||||
"""
|
||||
Retrieves graph metadata objects.
|
||||
Default behaviour is to get all currently active graphs.
|
||||
@@ -671,7 +653,7 @@ async def list_graphs(
|
||||
user_id: The ID of the user that owns the graph.
|
||||
|
||||
Returns:
|
||||
list[GraphMeta]: A list of objects representing the retrieved graphs.
|
||||
list[GraphModel]: A list of objects representing the retrieved graphs.
|
||||
"""
|
||||
where_clause: AgentGraphWhereInput = {"userId": user_id}
|
||||
|
||||
@@ -685,13 +667,13 @@ async def list_graphs(
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
)
|
||||
|
||||
graph_models: list[GraphMeta] = []
|
||||
graph_models = []
|
||||
for graph in graphs:
|
||||
try:
|
||||
graph_meta = GraphModel.from_db(graph).meta()
|
||||
# Trigger serialization to validate that the graph is well formed
|
||||
graph_meta.model_dump()
|
||||
graph_models.append(graph_meta)
|
||||
graph_model = GraphModel.from_db(graph)
|
||||
# Trigger serialization to validate that the graph is well formed.
|
||||
graph_model.model_dump()
|
||||
graph_models.append(graph_model)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing graph {graph.id}: {e}")
|
||||
continue
|
||||
@@ -1058,13 +1040,13 @@ async def fix_llm_provider_credentials():
|
||||
|
||||
broken_nodes = []
|
||||
try:
|
||||
broken_nodes = await query_raw_with_schema(
|
||||
broken_nodes = await prisma.get_client().query_raw(
|
||||
"""
|
||||
SELECT graph."userId" user_id,
|
||||
node.id node_id,
|
||||
node."constantInput" node_preset_input
|
||||
FROM {schema_prefix}"AgentNode" node
|
||||
LEFT JOIN {schema_prefix}"AgentGraph" graph
|
||||
FROM platform."AgentNode" node
|
||||
LEFT JOIN platform."AgentGraph" graph
|
||||
ON node."agentGraphId" = graph.id
|
||||
WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm'
|
||||
ORDER BY graph."userId";
|
||||
|
||||
@@ -636,35 +636,6 @@ class NodeExecutionStats(BaseModel):
|
||||
llm_retry_count: int = 0
|
||||
input_token_count: int = 0
|
||||
output_token_count: int = 0
|
||||
extra_cost: int = 0
|
||||
extra_steps: int = 0
|
||||
|
||||
def __iadd__(self, other: "NodeExecutionStats") -> "NodeExecutionStats":
    """Merge another NodeExecutionStats into this instance in place.

    Per field: dicts are updated with the other's entries, numbers are
    summed, lists are extended, and any other combination takes the other
    instance's value. Returns NotImplemented for a non-NodeExecutionStats
    operand.
    """
    if not isinstance(other, NodeExecutionStats):
        return NotImplemented

    incoming = other.model_dump()
    existing = self.model_dump()

    for name, new_value in incoming.items():
        if name not in existing:
            # Field doesn't exist yet, just set it
            merged = new_value
        else:
            old_value = existing[name]
            if isinstance(new_value, dict) and isinstance(old_value, dict):
                old_value.update(new_value)
                merged = old_value
            elif isinstance(new_value, (int, float)) and isinstance(
                old_value, (int, float)
            ):
                merged = old_value + new_value
            elif isinstance(new_value, list) and isinstance(old_value, list):
                old_value.extend(new_value)
                merged = old_value
            else:
                # Mismatched or non-mergeable types: the other side wins.
                merged = new_value
        setattr(self, name, merged)

    return self
|
||||
|
||||
|
||||
class GraphExecutionStats(BaseModel):
|
||||
|
||||
@@ -5,7 +5,6 @@ from backend.data import db
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
from backend.data.execution import (
|
||||
create_graph_execution,
|
||||
get_block_error_stats,
|
||||
get_execution_kv_data,
|
||||
get_graph_execution,
|
||||
get_graph_execution_meta,
|
||||
@@ -106,7 +105,6 @@ class DatabaseManager(AppService):
|
||||
upsert_execution_output = _(upsert_execution_output)
|
||||
get_execution_kv_data = _(get_execution_kv_data)
|
||||
set_execution_kv_data = _(set_execution_kv_data)
|
||||
get_block_error_stats = _(get_block_error_stats)
|
||||
|
||||
# Graphs
|
||||
get_node = _(get_node)
|
||||
@@ -201,9 +199,6 @@ class DatabaseManagerClient(AppServiceClient):
|
||||
d.get_user_notification_oldest_message_in_batch
|
||||
)
|
||||
|
||||
# Block error monitoring
|
||||
get_block_error_stats = _(d.get_block_error_stats)
|
||||
|
||||
|
||||
class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
d = DatabaseManager
|
||||
@@ -231,4 +226,3 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
update_user_integrations = d.update_user_integrations
|
||||
get_execution_kv_data = d.get_execution_kv_data
|
||||
set_execution_kv_data = d.set_execution_kv_data
|
||||
get_block_error_stats = d.get_block_error_stats
|
||||
|
||||
@@ -207,7 +207,9 @@ async def execute_node(
|
||||
|
||||
# Update execution stats
|
||||
if execution_stats is not None:
|
||||
execution_stats += node_block.execution_stats
|
||||
execution_stats = execution_stats.model_copy(
|
||||
update=node_block.execution_stats.model_dump()
|
||||
)
|
||||
execution_stats.input_size = input_size
|
||||
execution_stats.output_size = output_size
|
||||
|
||||
@@ -646,10 +648,9 @@ class Executor:
|
||||
return
|
||||
|
||||
nonlocal execution_stats
|
||||
execution_stats.node_count += 1 + result.extra_steps
|
||||
execution_stats.node_count += 1
|
||||
execution_stats.nodes_cputime += result.cputime
|
||||
execution_stats.nodes_walltime += result.walltime
|
||||
execution_stats.cost += result.extra_cost
|
||||
if (err := result.error) and isinstance(err, Exception):
|
||||
execution_stats.node_error_count += 1
|
||||
update_node_execution_status(
|
||||
@@ -876,7 +877,6 @@ class Executor:
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.RUNNING,
|
||||
],
|
||||
include_exec_data=False,
|
||||
)
|
||||
db_client.update_node_execution_status_batch(
|
||||
[node_exec.node_exec_id for node_exec in inflight_executions],
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
@@ -13,24 +14,25 @@ from apscheduler.schedulers.blocking import BlockingScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from dotenv import load_dotenv
|
||||
from prisma.enums import NotificationType
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import MetaData, create_engine
|
||||
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.execution import GraphExecutionWithNodes
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.monitoring import (
|
||||
NotificationJobArgs,
|
||||
process_existing_batches,
|
||||
process_weekly_summary,
|
||||
report_block_error_rates,
|
||||
report_late_executions,
|
||||
)
|
||||
from backend.util.cloud_storage import cleanup_expired_files_async
|
||||
from backend.notifications.notifications import NotificationManagerClient
|
||||
from backend.util.exceptions import NotAuthorizedError, NotFoundError
|
||||
from backend.util.logging import PrefixFilter
|
||||
from backend.util.service import AppService, AppServiceClient, endpoint_to_async, expose
|
||||
from backend.util.metrics import sentry_capture_error
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
endpoint_to_async,
|
||||
expose,
|
||||
get_service_client,
|
||||
)
|
||||
from backend.util.settings import Config
|
||||
|
||||
|
||||
@@ -69,6 +71,11 @@ def job_listener(event):
|
||||
logger.info(f"Job {event.job_id} completed successfully.")
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_notification_client():
|
||||
return get_service_client(NotificationManagerClient)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_event_loop():
|
||||
return asyncio.new_event_loop()
|
||||
@@ -82,7 +89,7 @@ async def _execute_graph(**kwargs):
|
||||
args = GraphExecutionJobArgs(**kwargs)
|
||||
try:
|
||||
logger.info(f"Executing recurring job for graph #{args.graph_id}")
|
||||
graph_exec: GraphExecutionWithNodes = await execution_utils.add_graph_execution(
|
||||
await execution_utils.add_graph_execution(
|
||||
user_id=args.user_id,
|
||||
graph_id=args.graph_id,
|
||||
graph_version=args.graph_version,
|
||||
@@ -90,19 +97,65 @@ async def _execute_graph(**kwargs):
|
||||
graph_credentials_inputs=args.input_credentials,
|
||||
use_db_query=False,
|
||||
)
|
||||
logger.info(
|
||||
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing graph {args.graph_id}: {e}")
|
||||
|
||||
|
||||
def cleanup_expired_files():
|
||||
"""Clean up expired files from cloud storage."""
|
||||
get_event_loop().run_until_complete(cleanup_expired_files_async())
|
||||
class LateExecutionException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# Monitoring functions are now imported from monitoring module
|
||||
def report_late_executions() -> str:
|
||||
late_executions = execution_utils.get_db_client().get_graph_executions(
|
||||
statuses=[ExecutionStatus.QUEUED],
|
||||
created_time_gte=datetime.now(timezone.utc)
|
||||
- timedelta(seconds=config.execution_late_notification_checkrange_secs),
|
||||
created_time_lte=datetime.now(timezone.utc)
|
||||
- timedelta(seconds=config.execution_late_notification_threshold_secs),
|
||||
limit=1000,
|
||||
)
|
||||
|
||||
if not late_executions:
|
||||
return "No late executions detected."
|
||||
|
||||
num_late_executions = len(late_executions)
|
||||
num_users = len(set([r.user_id for r in late_executions]))
|
||||
|
||||
late_execution_details = [
|
||||
f"* `Execution ID: {exec.id}, Graph ID: {exec.graph_id}v{exec.graph_version}, User ID: {exec.user_id}, Created At: {exec.started_at.isoformat()}`"
|
||||
for exec in late_executions
|
||||
]
|
||||
|
||||
error = LateExecutionException(
|
||||
f"Late executions detected: {num_late_executions} late executions from {num_users} users "
|
||||
f"in the last {config.execution_late_notification_checkrange_secs} seconds. "
|
||||
f"Graph has been queued for more than {config.execution_late_notification_threshold_secs} seconds. "
|
||||
"Please check the executor status. Details:\n"
|
||||
+ "\n".join(late_execution_details)
|
||||
)
|
||||
msg = str(error)
|
||||
sentry_capture_error(error)
|
||||
get_notification_client().discord_system_alert(msg)
|
||||
return msg
|
||||
|
||||
|
||||
def process_existing_batches(**kwargs):
|
||||
args = NotificationJobArgs(**kwargs)
|
||||
try:
|
||||
logger.info(
|
||||
f"Processing existing batches for notification type {args.notification_types}"
|
||||
)
|
||||
get_notification_client().process_existing_batches(args.notification_types)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing existing batches: {e}")
|
||||
|
||||
|
||||
def process_weekly_summary(**kwargs):
|
||||
try:
|
||||
logger.info("Processing weekly summary")
|
||||
get_notification_client().queue_weekly_summary()
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing weekly summary: {e}")
|
||||
|
||||
|
||||
class Jobstores(Enum):
|
||||
@@ -137,6 +190,11 @@ class GraphExecutionJobInfo(GraphExecutionJobArgs):
|
||||
)
|
||||
|
||||
|
||||
class NotificationJobArgs(BaseModel):
|
||||
notification_types: list[NotificationType]
|
||||
cron: str
|
||||
|
||||
|
||||
class NotificationJobInfo(NotificationJobArgs):
|
||||
id: str
|
||||
name: str
|
||||
@@ -229,27 +287,6 @@ class Scheduler(AppService):
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
# Block Error Rate Monitoring
|
||||
self.scheduler.add_job(
|
||||
report_block_error_rates,
|
||||
id="report_block_error_rates",
|
||||
trigger="interval",
|
||||
replace_existing=True,
|
||||
seconds=config.block_error_rate_check_interval_secs,
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
# Cloud Storage Cleanup - configurable interval
|
||||
self.scheduler.add_job(
|
||||
cleanup_expired_files,
|
||||
id="cleanup_expired_files",
|
||||
trigger="interval",
|
||||
replace_existing=True,
|
||||
seconds=config.cloud_storage_cleanup_interval_hours
|
||||
* 3600, # Convert hours to seconds
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
|
||||
self.scheduler.start()
|
||||
|
||||
@@ -342,15 +379,6 @@ class Scheduler(AppService):
|
||||
def execute_report_late_executions(self):
|
||||
return report_late_executions()
|
||||
|
||||
@expose
|
||||
def execute_report_block_error_rates(self):
|
||||
return report_block_error_rates()
|
||||
|
||||
@expose
|
||||
def execute_cleanup_expired_files(self):
|
||||
"""Manually trigger cleanup of expired cloud storage files."""
|
||||
return cleanup_expired_files()
|
||||
|
||||
|
||||
class SchedulerClient(AppServiceClient):
|
||||
@classmethod
|
||||
|
||||
@@ -685,7 +685,7 @@ async def stop_graph_execution(
|
||||
user_id: str,
|
||||
graph_exec_id: str,
|
||||
use_db_query: bool = True,
|
||||
wait_timeout: float = 15.0,
|
||||
wait_timeout: float = 60.0,
|
||||
):
|
||||
"""
|
||||
Mechanism:
|
||||
@@ -720,58 +720,32 @@ async def stop_graph_execution(
|
||||
ExecutionStatus.FAILED,
|
||||
]:
|
||||
# If graph execution is terminated/completed/failed, cancellation is complete
|
||||
await get_async_execution_event_bus().publish(graph_exec)
|
||||
return
|
||||
|
||||
if graph_exec.status in [
|
||||
elif graph_exec.status in [
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
]:
|
||||
break
|
||||
|
||||
if graph_exec.status == ExecutionStatus.RUNNING:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Set the termination status if the graph is not stopped after the timeout.
|
||||
if graph_exec := await db.get_graph_execution_meta(
|
||||
execution_id=graph_exec_id, user_id=user_id
|
||||
):
|
||||
# If the graph is still on the queue, we can prevent them from being executed
|
||||
# by setting the status to TERMINATED.
|
||||
node_execs = await db.get_node_executions(
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=[
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.RUNNING,
|
||||
],
|
||||
include_exec_data=False,
|
||||
)
|
||||
|
||||
graph_exec.status = ExecutionStatus.TERMINATED
|
||||
for node_exec in node_execs:
|
||||
node_exec.status = ExecutionStatus.TERMINATED
|
||||
|
||||
await asyncio.gather(
|
||||
# Update node execution statuses
|
||||
db.update_node_execution_status_batch(
|
||||
# If the graph is still on the queue, we can prevent them from being executed
|
||||
# by setting the status to TERMINATED.
|
||||
node_execs = await db.get_node_executions(
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=[ExecutionStatus.QUEUED, ExecutionStatus.INCOMPLETE],
|
||||
)
|
||||
await db.update_node_execution_status_batch(
|
||||
[node_exec.node_exec_id for node_exec in node_execs],
|
||||
ExecutionStatus.TERMINATED,
|
||||
),
|
||||
# Publish node execution events
|
||||
*[
|
||||
get_async_execution_event_bus().publish(node_exec)
|
||||
for node_exec in node_execs
|
||||
],
|
||||
)
|
||||
await asyncio.gather(
|
||||
# Update graph execution status
|
||||
db.update_graph_execution_stats(
|
||||
)
|
||||
await db.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec_id,
|
||||
status=ExecutionStatus.TERMINATED,
|
||||
),
|
||||
# Publish graph execution event
|
||||
get_async_execution_event_bus().publish(graph_exec),
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
raise TimeoutError(
|
||||
f"Timed out waiting for graph execution #{graph_exec_id} to terminate."
|
||||
)
|
||||
|
||||
|
||||
async def add_graph_execution(
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
"""Monitoring module for platform health and alerting."""
|
||||
|
||||
from .block_error_monitor import BlockErrorMonitor, report_block_error_rates
|
||||
from .late_execution_monitor import (
|
||||
LateExecutionException,
|
||||
LateExecutionMonitor,
|
||||
report_late_executions,
|
||||
)
|
||||
from .notification_monitor import (
|
||||
NotificationJobArgs,
|
||||
process_existing_batches,
|
||||
process_weekly_summary,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BlockErrorMonitor",
|
||||
"LateExecutionMonitor",
|
||||
"LateExecutionException",
|
||||
"NotificationJobArgs",
|
||||
"report_block_error_rates",
|
||||
"report_late_executions",
|
||||
"process_existing_batches",
|
||||
"process_weekly_summary",
|
||||
]
|
||||
@@ -1,291 +0,0 @@
|
||||
"""Block error rate monitoring module."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import get_block
|
||||
from backend.data.execution import ExecutionStatus, NodeExecutionResult
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.notifications.notifications import NotificationManagerClient
|
||||
from backend.util.metrics import sentry_capture_error
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = Config()
|
||||
|
||||
|
||||
class BlockStatsWithSamples(BaseModel):
|
||||
"""Enhanced block stats with error samples."""
|
||||
|
||||
block_id: str
|
||||
block_name: str
|
||||
total_executions: int
|
||||
failed_executions: int
|
||||
error_samples: list[str] = []
|
||||
|
||||
@property
|
||||
def error_rate(self) -> float:
|
||||
"""Calculate error rate as a percentage."""
|
||||
if self.total_executions == 0:
|
||||
return 0.0
|
||||
return (self.failed_executions / self.total_executions) * 100
|
||||
|
||||
|
||||
class BlockErrorMonitor:
|
||||
"""Monitor block error rates and send alerts when thresholds are exceeded."""
|
||||
|
||||
def __init__(self, include_top_blocks: int | None = None):
|
||||
self.config = config
|
||||
self.notification_client = get_service_client(NotificationManagerClient)
|
||||
self.include_top_blocks = (
|
||||
include_top_blocks
|
||||
if include_top_blocks is not None
|
||||
else config.block_error_include_top_blocks
|
||||
)
|
||||
|
||||
def check_block_error_rates(self) -> str:
|
||||
"""Check block error rates and send Discord alerts if thresholds are exceeded."""
|
||||
try:
|
||||
logger.info("Checking block error rates")
|
||||
|
||||
# Get executions from the last 24 hours
|
||||
end_time = datetime.now(timezone.utc)
|
||||
start_time = end_time - timedelta(hours=24)
|
||||
|
||||
# Use SQL aggregation to efficiently count totals and failures by block
|
||||
block_stats = self._get_block_stats_from_db(start_time, end_time)
|
||||
|
||||
# For blocks with high error rates, fetch error samples
|
||||
threshold = self.config.block_error_rate_threshold
|
||||
for block_name, stats in block_stats.items():
|
||||
if stats.total_executions >= 10 and stats.error_rate >= threshold * 100:
|
||||
# Only fetch error samples for blocks that exceed threshold
|
||||
error_samples = self._get_error_samples_for_block(
|
||||
stats.block_id, start_time, end_time, limit=3
|
||||
)
|
||||
stats.error_samples = error_samples
|
||||
|
||||
# Check thresholds and send alerts
|
||||
critical_alerts = self._generate_critical_alerts(block_stats, threshold)
|
||||
|
||||
if critical_alerts:
|
||||
msg = "Block Error Rate Alert:\n\n" + "\n\n".join(critical_alerts)
|
||||
self.notification_client.discord_system_alert(msg)
|
||||
logger.info(
|
||||
f"Sent block error rate alert for {len(critical_alerts)} blocks"
|
||||
)
|
||||
return f"Alert sent for {len(critical_alerts)} blocks with high error rates"
|
||||
|
||||
# If no critical alerts, check if we should show top blocks
|
||||
if self.include_top_blocks > 0:
|
||||
top_blocks_msg = self._generate_top_blocks_alert(
|
||||
block_stats, start_time, end_time
|
||||
)
|
||||
if top_blocks_msg:
|
||||
self.notification_client.discord_system_alert(top_blocks_msg)
|
||||
logger.info("Sent top blocks summary")
|
||||
return "Sent top blocks summary"
|
||||
|
||||
logger.info("No blocks exceeded error rate threshold")
|
||||
return "No errors reported for today"
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error checking block error rates: {e}")
|
||||
|
||||
error = Exception(f"Error checking block error rates: {e}")
|
||||
msg = str(error)
|
||||
sentry_capture_error(error)
|
||||
self.notification_client.discord_system_alert(msg)
|
||||
return msg
|
||||
|
||||
def _get_block_stats_from_db(
|
||||
self, start_time: datetime, end_time: datetime
|
||||
) -> dict[str, BlockStatsWithSamples]:
|
||||
"""Get block execution stats using efficient SQL aggregation."""
|
||||
|
||||
result = execution_utils.get_db_client().get_block_error_stats(
|
||||
start_time, end_time
|
||||
)
|
||||
|
||||
block_stats = {}
|
||||
for stats in result:
|
||||
block_name = b.name if (b := get_block(stats.block_id)) else "Unknown"
|
||||
|
||||
block_stats[block_name] = BlockStatsWithSamples(
|
||||
block_id=stats.block_id,
|
||||
block_name=block_name,
|
||||
total_executions=stats.total_executions,
|
||||
failed_executions=stats.failed_executions,
|
||||
error_samples=[],
|
||||
)
|
||||
|
||||
return block_stats
|
||||
|
||||
def _generate_critical_alerts(
|
||||
self, block_stats: dict[str, BlockStatsWithSamples], threshold: float
|
||||
) -> list[str]:
|
||||
"""Generate alerts for blocks that exceed the error rate threshold."""
|
||||
alerts = []
|
||||
|
||||
for block_name, stats in block_stats.items():
|
||||
if stats.total_executions >= 10 and stats.error_rate >= threshold * 100:
|
||||
error_groups = self._group_similar_errors(stats.error_samples)
|
||||
|
||||
alert_msg = (
|
||||
f"🚨 Block '{block_name}' has {stats.error_rate:.1f}% error rate "
|
||||
f"({stats.failed_executions}/{stats.total_executions}) in the last 24 hours"
|
||||
)
|
||||
|
||||
if error_groups:
|
||||
alert_msg += "\n\n📊 Error Types:"
|
||||
for error_pattern, count in error_groups.items():
|
||||
alert_msg += f"\n• {error_pattern} ({count}x)"
|
||||
|
||||
alerts.append(alert_msg)
|
||||
|
||||
return alerts
|
||||
|
||||
def _generate_top_blocks_alert(
|
||||
self,
|
||||
block_stats: dict[str, BlockStatsWithSamples],
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
) -> str | None:
|
||||
"""Generate top blocks summary when no critical alerts exist."""
|
||||
top_error_blocks = sorted(
|
||||
[
|
||||
(name, stats)
|
||||
for name, stats in block_stats.items()
|
||||
if stats.total_executions >= 10 and stats.failed_executions > 0
|
||||
],
|
||||
key=lambda x: x[1].failed_executions,
|
||||
reverse=True,
|
||||
)[: self.include_top_blocks]
|
||||
|
||||
if not top_error_blocks:
|
||||
return "✅ No errors reported for today - all blocks are running smoothly!"
|
||||
|
||||
# Get error samples for top blocks
|
||||
for block_name, stats in top_error_blocks:
|
||||
if not stats.error_samples:
|
||||
stats.error_samples = self._get_error_samples_for_block(
|
||||
stats.block_id, start_time, end_time, limit=2
|
||||
)
|
||||
|
||||
count_text = (
|
||||
f"top {self.include_top_blocks}" if self.include_top_blocks > 1 else "top"
|
||||
)
|
||||
alert_msg = f"📊 Daily Error Summary - {count_text} blocks with most errors:"
|
||||
for block_name, stats in top_error_blocks:
|
||||
alert_msg += f"\n• {block_name}: {stats.failed_executions} errors ({stats.error_rate:.1f}% of {stats.total_executions})"
|
||||
|
||||
if stats.error_samples:
|
||||
error_groups = self._group_similar_errors(stats.error_samples)
|
||||
if error_groups:
|
||||
# Show most common error
|
||||
most_common_error = next(iter(error_groups.items()))
|
||||
alert_msg += f"\n └ Most common: {most_common_error[0]}"
|
||||
|
||||
return alert_msg
|
||||
|
||||
def _get_error_samples_for_block(
|
||||
self, block_id: str, start_time: datetime, end_time: datetime, limit: int = 3
|
||||
) -> list[str]:
|
||||
"""Get error samples for a specific block - just a few recent ones."""
|
||||
# Only fetch a small number of recent failed executions for this specific block
|
||||
executions = execution_utils.get_db_client().get_node_executions(
|
||||
block_ids=[block_id],
|
||||
statuses=[ExecutionStatus.FAILED],
|
||||
created_time_gte=start_time,
|
||||
created_time_lte=end_time,
|
||||
limit=limit, # Just get the limit we need
|
||||
)
|
||||
|
||||
error_samples = []
|
||||
for execution in executions:
|
||||
if error_message := self._extract_error_message(execution):
|
||||
masked_error = self._mask_sensitive_data(error_message)
|
||||
error_samples.append(masked_error)
|
||||
|
||||
if len(error_samples) >= limit: # Stop once we have enough samples
|
||||
break
|
||||
|
||||
return error_samples
|
||||
|
||||
def _extract_error_message(self, execution: NodeExecutionResult) -> str | None:
|
||||
"""Extract error message from execution output."""
|
||||
try:
|
||||
if execution.output_data and (
|
||||
error_msg := execution.output_data.get("error")
|
||||
):
|
||||
return str(error_msg[0])
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _mask_sensitive_data(self, error_message):
|
||||
"""Mask sensitive data in error messages to enable grouping."""
|
||||
if not error_message:
|
||||
return ""
|
||||
|
||||
# Convert to string if not already
|
||||
error_str = str(error_message)
|
||||
|
||||
# Mask numbers (replace with X)
|
||||
error_str = re.sub(r"\d+", "X", error_str)
|
||||
|
||||
# Mask all caps words (likely constants/IDs)
|
||||
error_str = re.sub(r"\b[A-Z_]{3,}\b", "MASKED", error_str)
|
||||
|
||||
# Mask words with underscores (likely internal variables)
|
||||
error_str = re.sub(r"\b\w*_\w*\b", "MASKED", error_str)
|
||||
|
||||
# Mask UUIDs and long alphanumeric strings
|
||||
error_str = re.sub(
|
||||
r"\b[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}\b",
|
||||
"UUID",
|
||||
error_str,
|
||||
)
|
||||
error_str = re.sub(r"\b[a-f0-9]{20,}\b", "HASH", error_str)
|
||||
|
||||
# Mask file paths
|
||||
error_str = re.sub(r"(/[^/\s]+)+", "/MASKED/path", error_str)
|
||||
|
||||
# Mask URLs
|
||||
error_str = re.sub(r"https?://[^\s]+", "URL", error_str)
|
||||
|
||||
# Mask email addresses
|
||||
error_str = re.sub(
|
||||
r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", "EMAIL", error_str
|
||||
)
|
||||
|
||||
# Truncate if too long
|
||||
if len(error_str) > 100:
|
||||
error_str = error_str[:97] + "..."
|
||||
|
||||
return error_str.strip()
|
||||
|
||||
def _group_similar_errors(self, error_samples):
|
||||
"""Group similar error messages and return counts."""
|
||||
if not error_samples:
|
||||
return {}
|
||||
|
||||
error_groups = {}
|
||||
for error in error_samples:
|
||||
if error in error_groups:
|
||||
error_groups[error] += 1
|
||||
else:
|
||||
error_groups[error] = 1
|
||||
|
||||
# Sort by frequency, most common first
|
||||
return dict(sorted(error_groups.items(), key=lambda x: x[1], reverse=True))
|
||||
|
||||
|
||||
def report_block_error_rates(include_top_blocks: int | None = None):
|
||||
"""Check block error rates and send Discord alerts if thresholds are exceeded."""
|
||||
monitor = BlockErrorMonitor(include_top_blocks=include_top_blocks)
|
||||
return monitor.check_block_error_rates()
|
||||
@@ -1,71 +0,0 @@
|
||||
"""Late execution monitoring module."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.notifications.notifications import NotificationManagerClient
|
||||
from backend.util.metrics import sentry_capture_error
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = Config()
|
||||
|
||||
|
||||
class LateExecutionException(Exception):
|
||||
"""Exception raised when late executions are detected."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LateExecutionMonitor:
|
||||
"""Monitor late executions and send alerts when thresholds are exceeded."""
|
||||
|
||||
def __init__(self):
|
||||
self.config = config
|
||||
self.notification_client = get_service_client(NotificationManagerClient)
|
||||
|
||||
def check_late_executions(self) -> str:
|
||||
"""Check for late executions and send alerts if found."""
|
||||
late_executions = execution_utils.get_db_client().get_graph_executions(
|
||||
statuses=[ExecutionStatus.QUEUED],
|
||||
created_time_gte=datetime.now(timezone.utc)
|
||||
- timedelta(
|
||||
seconds=self.config.execution_late_notification_checkrange_secs
|
||||
),
|
||||
created_time_lte=datetime.now(timezone.utc)
|
||||
- timedelta(seconds=self.config.execution_late_notification_threshold_secs),
|
||||
limit=1000,
|
||||
)
|
||||
|
||||
if not late_executions:
|
||||
return "No late executions detected."
|
||||
|
||||
num_late_executions = len(late_executions)
|
||||
num_users = len(set([r.user_id for r in late_executions]))
|
||||
|
||||
late_execution_details = [
|
||||
f"* `Execution ID: {exec.id}, Graph ID: {exec.graph_id}v{exec.graph_version}, User ID: {exec.user_id}, Created At: {exec.started_at.isoformat()}`"
|
||||
for exec in late_executions
|
||||
]
|
||||
|
||||
error = LateExecutionException(
|
||||
f"Late executions detected: {num_late_executions} late executions from {num_users} users "
|
||||
f"in the last {self.config.execution_late_notification_checkrange_secs} seconds. "
|
||||
f"Graph has been queued for more than {self.config.execution_late_notification_threshold_secs} seconds. "
|
||||
"Please check the executor status. Details:\n"
|
||||
+ "\n".join(late_execution_details)
|
||||
)
|
||||
msg = str(error)
|
||||
|
||||
sentry_capture_error(error)
|
||||
self.notification_client.discord_system_alert(msg)
|
||||
return msg
|
||||
|
||||
|
||||
def report_late_executions() -> str:
|
||||
"""Check for late executions and send Discord alerts if found."""
|
||||
monitor = LateExecutionMonitor()
|
||||
return monitor.check_late_executions()
|
||||
@@ -1,45 +0,0 @@
|
||||
"""Notification processing monitoring module."""
|
||||
|
||||
import logging
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from prisma.enums import NotificationType
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.notifications.notifications import NotificationManagerClient
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NotificationJobArgs(BaseModel):
|
||||
notification_types: list[NotificationType]
|
||||
cron: str
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_notification_manager_client():
|
||||
return get_service_client(NotificationManagerClient)
|
||||
|
||||
|
||||
def process_existing_batches(**kwargs):
|
||||
"""Process existing notification batches."""
|
||||
args = NotificationJobArgs(**kwargs)
|
||||
try:
|
||||
logging.info(
|
||||
f"Processing existing batches for notification type {args.notification_types}"
|
||||
)
|
||||
get_notification_manager_client().process_existing_batches(
|
||||
args.notification_types
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing existing batches: {e}")
|
||||
|
||||
|
||||
def process_weekly_summary(**kwargs):
|
||||
"""Process weekly summary notifications."""
|
||||
try:
|
||||
logging.info("Processing weekly summary")
|
||||
get_notification_manager_client().queue_weekly_summary()
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing weekly summary: {e}")
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Annotated, Any, Optional, Sequence
|
||||
from typing import Annotated, Any, Dict, List, Optional, Sequence
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
from prisma.enums import AgentExecutionStatus, APIKeyPermission
|
||||
@@ -11,6 +11,7 @@ from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.api_key import APIKey
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.data.execution import NodeExecutionResult
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.server.external.middleware import require_permission
|
||||
from backend.util.settings import Settings
|
||||
@@ -29,19 +30,30 @@ class NodeOutput(TypedDict):
|
||||
class ExecutionNode(TypedDict):
|
||||
node_id: str
|
||||
input: Any
|
||||
output: dict[str, Any]
|
||||
output: Dict[str, Any]
|
||||
|
||||
|
||||
class ExecutionNodeOutput(TypedDict):
|
||||
node_id: str
|
||||
outputs: list[NodeOutput]
|
||||
outputs: List[NodeOutput]
|
||||
|
||||
|
||||
class GraphExecutionResult(TypedDict):
|
||||
execution_id: str
|
||||
status: str
|
||||
nodes: list[ExecutionNode]
|
||||
output: Optional[list[dict[str, str]]]
|
||||
nodes: List[ExecutionNode]
|
||||
output: Optional[List[Dict[str, str]]]
|
||||
|
||||
|
||||
def get_outputs_with_names(results: list[NodeExecutionResult]) -> list[dict[str, str]]:
|
||||
outputs = []
|
||||
for result in results:
|
||||
if "output" in result.output_data:
|
||||
output_value = result.output_data["output"][0]
|
||||
name = result.output_data.get("name", [None])[0]
|
||||
if output_value and name:
|
||||
outputs.append({name: output_value})
|
||||
return outputs
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
@@ -110,34 +122,23 @@ async def get_graph_execution_results(
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
graph_exec = await execution_db.get_graph_execution(
|
||||
user_id=api_key.user_id,
|
||||
execution_id=graph_exec_id,
|
||||
include_node_executions=True,
|
||||
results = await execution_db.get_node_executions(graph_exec_id)
|
||||
last_result = results[-1] if results else None
|
||||
execution_status = (
|
||||
last_result.status if last_result else AgentExecutionStatus.INCOMPLETE
|
||||
)
|
||||
if not graph_exec:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Graph execution #{graph_exec_id} not found."
|
||||
)
|
||||
outputs = get_outputs_with_names(results)
|
||||
|
||||
return GraphExecutionResult(
|
||||
execution_id=graph_exec_id,
|
||||
status=graph_exec.status.value,
|
||||
status=execution_status,
|
||||
nodes=[
|
||||
ExecutionNode(
|
||||
node_id=node_exec.node_id,
|
||||
input=node_exec.input_data.get("value", node_exec.input_data),
|
||||
output={k: v for k, v in node_exec.output_data.items()},
|
||||
node_id=result.node_id,
|
||||
input=result.input_data.get("value", result.input_data),
|
||||
output={k: v for k, v in result.output_data.items()},
|
||||
)
|
||||
for node_exec in graph_exec.node_executions
|
||||
for result in results
|
||||
],
|
||||
output=(
|
||||
[
|
||||
{name: value}
|
||||
for name, values in graph_exec.outputs.items()
|
||||
for value in values
|
||||
]
|
||||
if graph_exec.status == AgentExecutionStatus.COMPLETED
|
||||
else None
|
||||
),
|
||||
output=outputs if execution_status == AgentExecutionStatus.COMPLETED else None,
|
||||
)
|
||||
|
||||
@@ -77,11 +77,3 @@ class Pagination(pydantic.BaseModel):
|
||||
|
||||
class RequestTopUp(pydantic.BaseModel):
|
||||
credit_amount: int
|
||||
|
||||
|
||||
class UploadFileResponse(pydantic.BaseModel):
|
||||
file_uri: str
|
||||
file_name: str
|
||||
size: int
|
||||
content_type: str
|
||||
expires_in_hours: int
|
||||
|
||||
@@ -14,7 +14,6 @@ from autogpt_libs.feature_flag.client import (
|
||||
shutdown_launchdarkly,
|
||||
)
|
||||
from autogpt_libs.logging.utils import generate_uvicorn_config
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.routing import APIRoute
|
||||
|
||||
@@ -40,7 +39,6 @@ from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.server.external.api import external_app
|
||||
from backend.server.middleware.security import SecurityHeadersMiddleware
|
||||
from backend.util import json
|
||||
|
||||
settings = backend.util.settings.Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -153,7 +151,7 @@ def handle_internal_http_error(status_code: int = 500, log_error: bool = True):
|
||||
|
||||
async def validation_error_handler(
|
||||
request: fastapi.Request, exc: Exception
|
||||
) -> fastapi.responses.Response:
|
||||
) -> fastapi.responses.JSONResponse:
|
||||
logger.error(
|
||||
"Validation failed for %s %s: %s. Fix the request payload and try again.",
|
||||
request.method,
|
||||
@@ -165,19 +163,13 @@ async def validation_error_handler(
|
||||
errors = exc.errors() # type: ignore[call-arg]
|
||||
else:
|
||||
errors = str(exc)
|
||||
|
||||
response_content = {
|
||||
"message": f"Invalid data for {request.method} {request.url.path}",
|
||||
"detail": errors,
|
||||
"hint": "Ensure the request matches the API schema.",
|
||||
}
|
||||
|
||||
content_json = json.dumps(response_content)
|
||||
|
||||
return fastapi.responses.Response(
|
||||
content=content_json,
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=422,
|
||||
media_type="application/json",
|
||||
content={
|
||||
"message": f"Invalid data for {request.method} {request.url.path}",
|
||||
"detail": errors,
|
||||
"hint": "Ensure the request matches the API schema.",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -220,19 +212,8 @@ app.include_router(
|
||||
app.mount("/external-api", external_app)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_db_async_client():
|
||||
from backend.executor import DatabaseManagerAsyncClient
|
||||
|
||||
return backend.util.service.get_service_client(
|
||||
DatabaseManagerAsyncClient,
|
||||
health_check=False,
|
||||
)
|
||||
|
||||
|
||||
@app.get(path="/health", tags=["health"], dependencies=[])
async def health():
    # Health probe: await the database client's health check first; there is
    # no try/except here, so any exception it raises propagates to the caller.
    db_client = get_db_async_client()
    await db_client.health_check_async()

    # Only reached when the health check above completed without raising.
    payload = {"status": "healthy"}
    return payload
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
@@ -10,17 +9,7 @@ import stripe
|
||||
from autogpt_libs.auth.middleware import auth_middleware
|
||||
from autogpt_libs.feature_flag.client import feature_flag
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
Depends,
|
||||
File,
|
||||
HTTPException,
|
||||
Path,
|
||||
Request,
|
||||
Response,
|
||||
UploadFile,
|
||||
)
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Request, Response
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
|
||||
from typing_extensions import Optional, TypedDict
|
||||
|
||||
@@ -81,14 +70,11 @@ from backend.server.model import (
|
||||
RequestTopUp,
|
||||
SetGraphActiveVersion,
|
||||
UpdatePermissionsRequest,
|
||||
UploadFileResponse,
|
||||
)
|
||||
from backend.server.utils import get_user_id
|
||||
from backend.util.cloud_storage import get_cloud_storage_handler
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
|
||||
|
||||
@thread_cached
|
||||
@@ -96,14 +82,6 @@ def execution_scheduler_client() -> scheduler.SchedulerClient:
|
||||
return get_service_client(scheduler.SchedulerClient, health_check=False)
|
||||
|
||||
|
||||
def _create_file_size_error(size_bytes: int, max_size_mb: int) -> HTTPException:
    """Build the HTTP 400 error returned for uploads above the size limit.

    Args:
        size_bytes: Size reported in the error message, in bytes.
        max_size_mb: Limit reported in the error message, in megabytes.

    Returns:
        An ``HTTPException`` with status 400; the caller is expected to raise it.
    """
    message = (
        f"File size ({size_bytes} bytes) exceeds the maximum allowed size "
        f"of {max_size_mb}MB"
    )
    return HTTPException(status_code=400, detail=message)
|
||||
|
||||
|
||||
@thread_cached
def execution_event_bus() -> AsyncRedisExecutionEventBus:
    """Create an ``AsyncRedisExecutionEventBus`` (wrapped by ``thread_cached``)."""
    event_bus = AsyncRedisExecutionEventBus()
    return event_bus
|
||||
@@ -273,92 +251,6 @@ async def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlock
|
||||
return output
|
||||
|
||||
|
||||
@v1_router.post(
    path="/files/upload",
    summary="Upload file to cloud storage",
    tags=["files"],
    dependencies=[Depends(auth_middleware)],
)
async def upload_file(
    user_id: Annotated[str, Depends(get_user_id)],
    file: UploadFile = File(...),
    provider: str = "gcs",
    expiration_hours: int = 24,
) -> UploadFileResponse:
    """
    Upload a file to cloud storage and return a storage key that can be used
    with FileStoreBlock and AgentFileInputBlock.

    Args:
        file: The file to upload
        user_id: The user ID
        provider: Cloud storage provider ("gcs", "s3", "azure")
        expiration_hours: Hours until file expires (1-48)

    Returns:
        UploadFileResponse with the file URI (a cloud storage path, or a
        base64 data URI when no GCS bucket is configured), the file name,
        size in bytes, content type and expiration in hours.

    Raises:
        HTTPException: 400 when `expiration_hours` is outside 1-48 or the
            file is larger than the configured upload limit.
    """
    if not 1 <= expiration_hours <= 48:
        raise HTTPException(
            status_code=400, detail="Expiration hours must be between 1 and 48"
        )

    max_size_mb = settings.config.upload_file_size_limit_mb
    max_size_bytes = max_size_mb * 1024 * 1024

    # Reject early when the upload already reports a size above the limit.
    declared_size = getattr(file, "size", None)
    if declared_size is not None and declared_size > max_size_bytes:
        raise _create_file_size_error(declared_size, max_size_mb)

    # The declared size may be missing or wrong, so never pull more than one
    # byte past the limit into memory: reading `limit + 1` bytes is enough to
    # tell that the file is too large without buffering the whole upload.
    content = await file.read(max_size_bytes + 1)
    content_size = len(content)
    if content_size > max_size_bytes:
        # NOTE: content_size is the number of bytes read here, which is a
        # lower bound on the real file size when the read was truncated.
        raise _create_file_size_error(content_size, max_size_mb)

    file_name = file.filename or "uploaded_file"
    content_type = file.content_type or "application/octet-stream"

    # Scan before the content is stored or echoed back to the client.
    await scan_content_safe(content, filename=file_name)

    cloud_storage = await get_cloud_storage_handler()
    # NOTE(review): the fallback is keyed on the GCS bucket only, regardless
    # of `provider` - confirm this is intended for non-GCS providers.
    if cloud_storage.config.gcs_bucket_name:
        file_uri = await cloud_storage.store_file(
            content=content,
            filename=file_name,
            provider=provider,
            expiration_hours=expiration_hours,
            user_id=user_id,
        )
    else:
        # Fallback to base64 data URI when GCS is not configured
        encoded = base64.b64encode(content).decode("utf-8")
        file_uri = f"data:{content_type};base64,{encoded}"

    return UploadFileResponse(
        file_uri=file_uri,
        file_name=file_name,
        size=content_size,
        content_type=content_type,
        expires_in_hours=expiration_hours,
    )
|
||||
|
||||
|
||||
########################################################
|
||||
##################### Credits ##########################
|
||||
########################################################
|
||||
@@ -556,10 +448,10 @@ class DeleteGraphResponse(TypedDict):
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def list_graphs(
|
||||
async def get_graphs(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> Sequence[graph_db.GraphMeta]:
|
||||
return await graph_db.list_graphs(filter_by="active", user_id=user_id)
|
||||
) -> Sequence[graph_db.GraphModel]:
|
||||
return await graph_db.get_graphs(filter_by="active", user_id=user_id)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
@@ -788,6 +680,22 @@ async def stop_graph_run(
|
||||
return res[0]
|
||||
|
||||
|
||||
@v1_router.post(
    path="/executions",
    summary="Stop graph executions",
    tags=["graphs"],
    dependencies=[Depends(auth_middleware)],
)
async def stop_graph_runs(
    graph_id: str, graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> list[execution_db.GraphExecutionMeta]:
    # Thin route wrapper: forwards the authenticated user and both identifiers
    # to _stop_graph_run and returns its result unchanged.
    stopped_executions = await _stop_graph_run(
        graph_exec_id=graph_exec_id,
        graph_id=graph_id,
        user_id=user_id,
    )
    return stopped_executions
|
||||
|
||||
|
||||
async def _stop_graph_run(
|
||||
user_id: str,
|
||||
graph_id: Optional[str] = None,
|
||||
|
||||
@@ -1,21 +1,16 @@
|
||||
import json
|
||||
from io import BytesIO
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import autogpt_libs.auth.depends
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
import starlette.datastructures
|
||||
from fastapi import HTTPException, UploadFile
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.server.routers.v1 as v1_routes
|
||||
from backend.data.credit import AutoTopUpConfig
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.server.conftest import TEST_USER_ID
|
||||
from backend.server.routers.v1 import upload_file
|
||||
from backend.server.utils import get_user_id
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
@@ -275,7 +270,7 @@ def test_get_graphs(
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.graph_db.list_graphs",
|
||||
"backend.server.routers.v1.graph_db.get_graphs",
|
||||
return_value=[mock_graph],
|
||||
)
|
||||
|
||||
@@ -396,226 +391,3 @@ def test_missing_required_field() -> None:
|
||||
"""Test endpoint with missing required field"""
|
||||
response = client.post("/credits", json={}) # Missing credit_amount
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_upload_file_success():
    """A regular upload is scanned, stored, and described in the response."""
    payload = b"test file content"
    stored_uri = "gcs://test-bucket/uploads/123/test.txt"

    upload = UploadFile(
        filename="test.txt",
        file=BytesIO(payload),
        headers=starlette.datastructures.Headers({"content-type": "text/plain"}),
    )
    # Replace read() so the route receives exactly `payload`.
    upload.read = AsyncMock(return_value=payload)

    storage = AsyncMock()
    storage.store_file.return_value = stored_uri

    with patch("backend.server.routers.v1.scan_content_safe") as scan_mock, patch(
        "backend.server.routers.v1.get_cloud_storage_handler"
    ) as storage_getter:
        scan_mock.return_value = None
        storage_getter.return_value = storage

        response = await upload_file(
            file=upload,
            user_id="test-user-123",
            provider="gcs",
            expiration_hours=24,
        )

    # The response mirrors the upload and the URI returned by the storage mock.
    assert response.file_uri == stored_uri
    assert response.file_name == "test.txt"
    assert response.size == len(payload)
    assert response.content_type == "text/plain"
    assert response.expires_in_hours == 24

    # The content was scanned exactly once under its real filename.
    scan_mock.assert_called_once_with(payload, filename="test.txt")

    # Storage received the content together with the caller's parameters.
    storage.store_file.assert_called_once_with(
        content=payload,
        filename="test.txt",
        provider="gcs",
        expiration_hours=24,
        user_id="test-user-123",
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_upload_file_no_filename():
    """An upload without a filename falls back to the default name."""
    payload = b"test content"

    upload = UploadFile(
        filename=None,
        file=BytesIO(payload),
        headers=starlette.datastructures.Headers(
            {"content-type": "application/octet-stream"}
        ),
    )
    upload.read = AsyncMock(return_value=payload)

    storage = AsyncMock()
    storage.store_file.return_value = "gcs://test-bucket/uploads/123/uploaded_file"

    with patch("backend.server.routers.v1.scan_content_safe") as scan_mock, patch(
        "backend.server.routers.v1.get_cloud_storage_handler"
    ) as storage_getter:
        scan_mock.return_value = None
        storage_getter.return_value = storage

        response = await upload_file(file=upload, user_id="test-user-123")

    assert response.file_name == "uploaded_file"
    assert response.content_type == "application/octet-stream"

    # The scanner must also see the default filename.
    scan_mock.assert_called_once_with(payload, filename="uploaded_file")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_upload_file_invalid_expiration():
    """Expiration values outside 1-48 hours are rejected with HTTP 400."""
    upload = UploadFile(
        filename="test.txt",
        file=BytesIO(b"content"),
        headers=starlette.datastructures.Headers({"content-type": "text/plain"}),
    )

    # 0 is just below the allowed range, 49 is just above it.
    for bad_hours in (0, 49):
        with pytest.raises(HTTPException) as raised:
            await upload_file(
                file=upload, user_id="test-user-123", expiration_hours=bad_hours
            )
        assert raised.value.status_code == 400
        assert "between 1 and 48" in raised.value.detail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_upload_file_virus_scan_failure():
    """An exception raised by the virus scan propagates out of the upload."""
    payload = b"malicious content"

    upload = UploadFile(
        filename="virus.txt",
        file=BytesIO(payload),
        headers=starlette.datastructures.Headers({"content-type": "text/plain"}),
    )
    upload.read = AsyncMock(return_value=payload)

    # Make the scanner fail as soon as the route calls it.
    scan_failure = RuntimeError("Virus detected!")
    with patch(
        "backend.server.routers.v1.scan_content_safe", side_effect=scan_failure
    ), pytest.raises(RuntimeError, match="Virus detected!"):
        await upload_file(file=upload, user_id="test-user-123")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_upload_file_cloud_storage_failure():
    """An exception raised while storing the file propagates to the caller."""
    payload = b"test content"

    upload = UploadFile(
        filename="test.txt",
        file=BytesIO(payload),
        headers=starlette.datastructures.Headers({"content-type": "text/plain"}),
    )
    upload.read = AsyncMock(return_value=payload)

    # The scan passes; only the storage backend fails.
    storage = AsyncMock()
    storage.store_file.side_effect = RuntimeError("Storage error!")

    with patch("backend.server.routers.v1.scan_content_safe") as scan_mock, patch(
        "backend.server.routers.v1.get_cloud_storage_handler"
    ) as storage_getter:
        scan_mock.return_value = None
        storage_getter.return_value = storage

        with pytest.raises(RuntimeError, match="Storage error!"):
            await upload_file(file=upload, user_id="test-user-123")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_upload_file_size_limit_exceeded():
    """Test file upload when file size exceeds the limit."""
    # One megabyte above the default 256MB limit. bytes(n) yields a
    # zero-filled buffer, so the payload does not have to be filled in.
    oversized_payload = bytes(257 * 1024 * 1024)  # 257MB

    # read() is mocked below, so the backing stream can stay empty instead of
    # wrapping the large payload a second time.
    upload = UploadFile(
        filename="large_file.txt",
        file=BytesIO(),
        headers=starlette.datastructures.Headers({"content-type": "text/plain"}),
    )
    upload.read = AsyncMock(return_value=oversized_payload)

    with pytest.raises(HTTPException) as raised:
        await upload_file(file=upload, user_id="test-user-123")

    assert raised.value.status_code == 400
    assert "exceeds the maximum allowed size of 256MB" in raised.value.detail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_upload_file_gcs_not_configured_fallback():
    """Without a GCS bucket the upload is returned as a base64 data URI."""
    payload = b"test file content"

    upload = UploadFile(
        filename="test.txt",
        file=BytesIO(payload),
        headers=starlette.datastructures.Headers({"content-type": "text/plain"}),
    )
    upload.read = AsyncMock(return_value=payload)

    storage = AsyncMock()
    storage.config.gcs_bucket_name = ""  # Simulate no GCS bucket configured

    with patch("backend.server.routers.v1.scan_content_safe") as scan_mock, patch(
        "backend.server.routers.v1.get_cloud_storage_handler"
    ) as storage_getter:
        scan_mock.return_value = None
        storage_getter.return_value = storage

        response = await upload_file(file=upload, user_id="test-user-123")

    # Metadata is reported exactly as for a stored file.
    assert response.file_name == "test.txt"
    assert response.size == len(payload)
    assert response.content_type == "text/plain"
    assert response.expires_in_hours == 24

    # The URI carries the base64-encoded payload inline.
    assert response.file_uri == "data:text/plain;base64,dGVzdCBmaWxlIGNvbnRlbnQ="

    # The scan still runs, but nothing is written to cloud storage.
    scan_mock.assert_called_once_with(payload, filename="test.txt")
    storage.store_file.assert_not_called()
|
||||
|
||||
@@ -187,7 +187,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
|
||||
async def get_library_agent_by_store_version_id(
|
||||
store_listing_version_id: str,
|
||||
user_id: str,
|
||||
) -> library_model.LibraryAgent | None:
|
||||
):
|
||||
"""
|
||||
Get the library agent metadata for a given store listing version ID and user ID.
|
||||
"""
|
||||
@@ -202,7 +202,7 @@ async def get_library_agent_by_store_version_id(
|
||||
)
|
||||
if not store_listing_version:
|
||||
logger.warning(f"Store listing version not found: {store_listing_version_id}")
|
||||
raise NotFoundError(
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
f"Store listing version {store_listing_version_id} not found or invalid"
|
||||
)
|
||||
|
||||
@@ -214,9 +214,12 @@ async def get_library_agent_by_store_version_id(
|
||||
"agentGraphVersion": store_listing_version.agentGraphVersion,
|
||||
"isDeleted": False,
|
||||
},
|
||||
include=library_agent_include(user_id),
|
||||
include={"AgentGraph": True},
|
||||
)
|
||||
return library_model.LibraryAgent.from_db(agent) if agent else None
|
||||
if agent:
|
||||
return library_model.LibraryAgent.from_db(agent)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
async def get_library_agent_by_graph_id(
|
||||
|
||||
@@ -129,7 +129,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
credentials_input_schema=(
|
||||
graph.credentials_input_schema if sub_graphs is not None else None
|
||||
),
|
||||
has_external_trigger=graph.has_external_trigger,
|
||||
has_external_trigger=graph.has_webhook_trigger,
|
||||
trigger_setup_info=(
|
||||
LibraryAgentTriggerInfo(
|
||||
provider=trigger_block.webhook_config.provider,
|
||||
@@ -262,19 +262,6 @@ class LibraryAgentPresetUpdatable(pydantic.BaseModel):
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class TriggeredPresetSetupRequest(pydantic.BaseModel):
    # Request model carrying the name/description, target graph and trigger
    # settings for a triggered preset.

    # Name and optional description.
    name: str
    description: str = ""

    # Graph ID and version.
    graph_id: str
    graph_version: int

    # Trigger configuration; free-form mapping.
    trigger_config: dict[str, Any]
    # Credentials keyed by string (key meaning not visible here); defaults
    # to an empty dict.
    agent_credentials: dict[str, CredentialsMetaInput] = pydantic.Field(
        default_factory=dict
    )
|
||||
|
||||
|
||||
class LibraryAgentPreset(LibraryAgentPresetCreatable):
|
||||
"""Represents a preset configuration for a library agent."""
|
||||
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import autogpt_libs.auth as autogpt_auth_lib
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, status
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import backend.server.v2.library.db as library_db
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.exceptions as store_exceptions
|
||||
from backend.data.graph import get_graph
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.executor.utils import make_node_credentials_input_map
|
||||
from backend.integrations.webhooks.utils import setup_webhook_for_block
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -108,11 +113,12 @@ async def get_library_agent_by_graph_id(
|
||||
"/marketplace/{store_listing_version_id}",
|
||||
summary="Get Agent By Store ID",
|
||||
tags=["store, library"],
|
||||
response_model=library_model.LibraryAgent | None,
|
||||
)
|
||||
async def get_library_agent_by_store_listing_version_id(
|
||||
store_listing_version_id: str,
|
||||
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
|
||||
) -> library_model.LibraryAgent | None:
|
||||
):
|
||||
"""
|
||||
Get Library Agent from Store Listing Version ID.
|
||||
"""
|
||||
@@ -289,3 +295,81 @@ async def fork_library_agent(
|
||||
library_agent_id=library_agent_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
class TriggeredPresetSetupParams(BaseModel):
    # Body parameters for setting up a webhook-triggered preset; the library
    # agent itself is identified by the route path, not by this model.

    # Name and optional description given to the created preset.
    name: str
    description: str = ""

    # Trigger configuration; free-form mapping.
    trigger_config: dict[str, Any]
    # Credentials keyed by string (key meaning not visible here); defaults
    # to an empty dict.
    agent_credentials: dict[str, CredentialsMetaInput] = Field(
        default_factory=dict
    )
|
||||
|
||||
|
||||
@router.post("/{library_agent_id}/setup-trigger")
async def setup_trigger(
    library_agent_id: str = Path(..., description="ID of the library agent"),
    params: TriggeredPresetSetupParams = Body(),
    user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> library_model.LibraryAgentPreset:
    """
    Sets up a webhook-triggered `LibraryAgentPreset` for a `LibraryAgent`.
    Returns the correspondingly created `LibraryAgentPreset` with `webhook_id` set.
    """
    # Step 1: resolve the library agent for this user (404 when falsy).
    library_agent = await library_db.get_library_agent(
        id=library_agent_id, user_id=user_id
    )
    if not library_agent:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Library agent #{library_agent_id} not found",
        )

    # Step 2: load the exact graph version the agent points at (410 when falsy).
    graph_id = library_agent.graph_id
    graph_version = library_agent.graph_version
    graph = await get_graph(graph_id, version=graph_version, user_id=user_id)
    if not graph:
        raise HTTPException(
            status_code=status.HTTP_410_GONE,
            detail=f"Graph #{library_agent.graph_id} not accessible (anymore)",
        )

    # Step 3: the graph must have a webhook input node (400 otherwise).
    trigger_node = graph.webhook_input_node
    if not trigger_node:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Graph #{library_agent.graph_id} does not have a webhook node",
        )

    # Step 4: overlay the trigger node's credentials on the trigger config;
    # credential entries win on key collisions because they are unpacked last.
    credentials_by_node = make_node_credentials_input_map(
        graph, params.agent_credentials
    )
    trigger_node_credentials = credentials_by_node.get(trigger_node.id) or {}
    trigger_config_with_credentials = {
        **params.trigger_config,
        **trigger_node_credentials,
    }

    # Step 5: set up the webhook; `feedback` is reported back when no webhook
    # is returned.
    new_webhook, feedback = await setup_webhook_for_block(
        user_id=user_id,
        trigger_block=trigger_node.block,
        trigger_config=trigger_config_with_credentials,
    )
    if not new_webhook:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Could not set up webhook: {feedback}",
        )

    # Step 6: create the preset with the new webhook's ID and is_active=True.
    preset_spec = library_model.LibraryAgentPresetCreatable(
        graph_id=library_agent.graph_id,
        graph_version=library_agent.graph_version,
        name=params.name,
        description=params.description,
        inputs=trigger_config_with_credentials,
        credentials=params.agent_credentials,
        webhook_id=new_webhook.id,
        is_active=True,
    )
    return await library_db.create_preset(user_id=user_id, preset=preset_spec)
|
||||
|
||||
@@ -138,66 +138,6 @@ async def create_preset(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/presets/setup-trigger")
async def setup_trigger(
    params: models.TriggeredPresetSetupRequest = Body(),
    user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> models.LibraryAgentPreset:
    """
    Sets up a webhook-triggered `LibraryAgentPreset` for a `LibraryAgent`.
    Returns the correspondingly created `LibraryAgentPreset` with `webhook_id` set.
    """
    # Step 1: load the requested graph version for this user (410 when falsy).
    graph = await get_graph(
        params.graph_id, version=params.graph_version, user_id=user_id
    )
    if not graph:
        raise HTTPException(
            status_code=status.HTTP_410_GONE,
            detail=f"Graph #{params.graph_id} not accessible (anymore)",
        )

    # Step 2: the graph must have a webhook input node (400 otherwise).
    trigger_node = graph.webhook_input_node
    if not trigger_node:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Graph #{params.graph_id} does not have a webhook node",
        )

    # Step 3: overlay the trigger node's credentials on the trigger config;
    # credential entries win on key collisions because they are unpacked last.
    credentials_by_node = make_node_credentials_input_map(
        graph, params.agent_credentials
    )
    trigger_node_credentials = credentials_by_node.get(trigger_node.id) or {}
    trigger_config_with_credentials = {
        **params.trigger_config,
        **trigger_node_credentials,
    }

    # Step 4: set up the webhook; `feedback` is reported back when no webhook
    # is returned.
    new_webhook, feedback = await setup_webhook_for_block(
        user_id=user_id,
        trigger_block=trigger_node.block,
        trigger_config=trigger_config_with_credentials,
    )
    if not new_webhook:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Could not set up webhook: {feedback}",
        )

    # Step 5: create the preset with the new webhook's ID and is_active=True.
    preset_spec = models.LibraryAgentPresetCreatable(
        graph_id=params.graph_id,
        graph_version=params.graph_version,
        name=params.name,
        description=params.description,
        inputs=trigger_config_with_credentials,
        credentials=params.agent_credentials,
        webhook_id=new_webhook.id,
        is_active=True,
    )
    return await db.create_preset(user_id=user_id, preset=preset_spec)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/presets/{preset_id}",
|
||||
summary="Update an existing preset",
|
||||
|
||||
@@ -7,15 +7,10 @@ import prisma.errors
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
import backend.data.graph
|
||||
import backend.server.v2.store.exceptions
|
||||
import backend.server.v2.store.model
|
||||
from backend.data.graph import (
|
||||
GraphMeta,
|
||||
GraphModel,
|
||||
get_graph,
|
||||
get_graph_as_admin,
|
||||
get_sub_graphs,
|
||||
)
|
||||
from backend.data.graph import GraphModel, get_sub_graphs
|
||||
from backend.data.includes import AGENT_GRAPH_INCLUDE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -198,7 +193,9 @@ async def get_store_agent_details(
|
||||
) from e
|
||||
|
||||
|
||||
async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
|
||||
async def get_available_graph(
|
||||
store_listing_version_id: str,
|
||||
):
|
||||
try:
|
||||
# Get available, non-deleted store listing version
|
||||
store_listing_version = (
|
||||
@@ -218,7 +215,18 @@ async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
|
||||
detail=f"Store listing version {store_listing_version_id} not found",
|
||||
)
|
||||
|
||||
return GraphModel.from_db(store_listing_version.AgentGraph).meta()
|
||||
graph = GraphModel.from_db(store_listing_version.AgentGraph)
|
||||
# We return graph meta, without nodes, they cannot be just removed
|
||||
# because then input_schema would be empty
|
||||
return {
|
||||
"id": graph.id,
|
||||
"version": graph.version,
|
||||
"is_active": graph.is_active,
|
||||
"name": graph.name,
|
||||
"description": graph.description,
|
||||
"input_schema": graph.input_schema,
|
||||
"output_schema": graph.output_schema,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting agent: {e}")
|
||||
@@ -1016,7 +1024,7 @@ async def get_agent(
|
||||
if not store_listing_version:
|
||||
raise ValueError(f"Store listing version {store_listing_version_id} not found")
|
||||
|
||||
graph = await get_graph(
|
||||
graph = await backend.data.graph.get_graph(
|
||||
user_id=user_id,
|
||||
graph_id=store_listing_version.agentGraphId,
|
||||
version=store_listing_version.agentGraphVersion,
|
||||
@@ -1375,7 +1383,7 @@ async def get_agent_as_admin(
|
||||
if not store_listing_version:
|
||||
raise ValueError(f"Store listing version {store_listing_version_id} not found")
|
||||
|
||||
graph = await get_graph_as_admin(
|
||||
graph = await backend.data.graph.get_graph_as_admin(
|
||||
user_id=user_id,
|
||||
graph_id=store_listing_version.agentGraphId,
|
||||
version=store_listing_version.agentGraphVersion,
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import uuid
|
||||
|
||||
import fastapi
|
||||
from gcloud.aio import storage as async_storage
|
||||
from google.cloud import storage
|
||||
|
||||
import backend.server.v2.store.exceptions
|
||||
from backend.util.exceptions import MissingConfigError
|
||||
@@ -33,28 +33,21 @@ async def check_media_exists(user_id: str, filename: str) -> str | None:
|
||||
if not settings.config.media_gcs_bucket_name:
|
||||
raise MissingConfigError("GCS media bucket is not configured")
|
||||
|
||||
async_client = async_storage.Storage()
|
||||
bucket_name = settings.config.media_gcs_bucket_name
|
||||
storage_client = storage.Client()
|
||||
bucket = storage_client.bucket(settings.config.media_gcs_bucket_name)
|
||||
|
||||
# Check images
|
||||
image_path = f"users/{user_id}/images/{filename}"
|
||||
try:
|
||||
await async_client.download_metadata(bucket_name, image_path)
|
||||
# If we get here, the file exists - construct public URL
|
||||
return f"https://storage.googleapis.com/{bucket_name}/{image_path}"
|
||||
except Exception:
|
||||
# File doesn't exist, continue to check videos
|
||||
pass
|
||||
image_blob = bucket.blob(image_path)
|
||||
if image_blob.exists():
|
||||
return image_blob.public_url
|
||||
|
||||
# Check videos
|
||||
video_path = f"users/{user_id}/videos/{filename}"
|
||||
try:
|
||||
await async_client.download_metadata(bucket_name, video_path)
|
||||
# If we get here, the file exists - construct public URL
|
||||
return f"https://storage.googleapis.com/{bucket_name}/{video_path}"
|
||||
except Exception:
|
||||
# File doesn't exist
|
||||
pass
|
||||
|
||||
video_blob = bucket.blob(video_path)
|
||||
if video_blob.exists():
|
||||
return video_blob.public_url
|
||||
|
||||
return None
|
||||
|
||||
@@ -177,19 +170,16 @@ async def upload_media(
|
||||
storage_path = f"users/{user_id}/{media_type}/{unique_filename}"
|
||||
|
||||
try:
|
||||
async_client = async_storage.Storage()
|
||||
bucket_name = settings.config.media_gcs_bucket_name
|
||||
storage_client = storage.Client()
|
||||
bucket = storage_client.bucket(settings.config.media_gcs_bucket_name)
|
||||
blob = bucket.blob(storage_path)
|
||||
blob.content_type = content_type
|
||||
|
||||
file_bytes = await file.read()
|
||||
await scan_content_safe(file_bytes, filename=unique_filename)
|
||||
blob.upload_from_string(file_bytes, content_type=content_type)
|
||||
|
||||
# Upload using pure async client
|
||||
await async_client.upload(
|
||||
bucket_name, storage_path, file_bytes, content_type=content_type
|
||||
)
|
||||
|
||||
# Construct public URL
|
||||
public_url = f"https://storage.googleapis.com/{bucket_name}/{storage_path}"
|
||||
public_url = blob.public_url
|
||||
|
||||
logger.info(f"Successfully uploaded file to: {storage_path}")
|
||||
return public_url
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import io
|
||||
import unittest.mock
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import pytest
|
||||
@@ -22,19 +21,15 @@ def mock_settings(monkeypatch):
|
||||
|
||||
@pytest.fixture
|
||||
def mock_storage_client(mocker):
|
||||
# Mock the async gcloud.aio.storage.Storage client
|
||||
mock_client = AsyncMock()
|
||||
mock_client.upload = AsyncMock()
|
||||
mock_client = unittest.mock.MagicMock()
|
||||
mock_bucket = unittest.mock.MagicMock()
|
||||
mock_blob = unittest.mock.MagicMock()
|
||||
|
||||
# Mock the constructor to return our mock client
|
||||
mocker.patch(
|
||||
"backend.server.v2.store.media.async_storage.Storage", return_value=mock_client
|
||||
)
|
||||
mock_client.bucket.return_value = mock_bucket
|
||||
mock_bucket.blob.return_value = mock_blob
|
||||
mock_blob.public_url = "http://test-url/media/laptop.jpeg"
|
||||
|
||||
# Mock virus scanner to avoid actual scanning
|
||||
mocker.patch(
|
||||
"backend.server.v2.store.media.scan_content_safe", new_callable=AsyncMock
|
||||
)
|
||||
mocker.patch("google.cloud.storage.Client", return_value=mock_client)
|
||||
|
||||
return mock_client
|
||||
|
||||
@@ -51,11 +46,10 @@ async def test_upload_media_success(mock_settings, mock_storage_client):
|
||||
|
||||
result = await backend.server.v2.store.media.upload_media("test-user", test_file)
|
||||
|
||||
assert result.startswith(
|
||||
"https://storage.googleapis.com/test-bucket/users/test-user/images/"
|
||||
)
|
||||
assert result.endswith(".jpeg")
|
||||
mock_storage_client.upload.assert_called_once()
|
||||
assert result == "http://test-url/media/laptop.jpeg"
|
||||
mock_bucket = mock_storage_client.bucket.return_value
|
||||
mock_blob = mock_bucket.blob.return_value
|
||||
mock_blob.upload_from_string.assert_called_once()
|
||||
|
||||
|
||||
async def test_upload_media_invalid_type(mock_settings, mock_storage_client):
|
||||
@@ -68,7 +62,9 @@ async def test_upload_media_invalid_type(mock_settings, mock_storage_client):
|
||||
with pytest.raises(backend.server.v2.store.exceptions.InvalidFileTypeError):
|
||||
await backend.server.v2.store.media.upload_media("test-user", test_file)
|
||||
|
||||
mock_storage_client.upload.assert_not_called()
|
||||
mock_bucket = mock_storage_client.bucket.return_value
|
||||
mock_blob = mock_bucket.blob.return_value
|
||||
mock_blob.upload_from_string.assert_not_called()
|
||||
|
||||
|
||||
async def test_upload_media_missing_credentials(monkeypatch):
|
||||
@@ -96,11 +92,10 @@ async def test_upload_media_video_type(mock_settings, mock_storage_client):
|
||||
|
||||
result = await backend.server.v2.store.media.upload_media("test-user", test_file)
|
||||
|
||||
assert result.startswith(
|
||||
"https://storage.googleapis.com/test-bucket/users/test-user/videos/"
|
||||
)
|
||||
assert result.endswith(".mp4")
|
||||
mock_storage_client.upload.assert_called_once()
|
||||
assert result == "http://test-url/media/laptop.jpeg"
|
||||
mock_bucket = mock_storage_client.bucket.return_value
|
||||
mock_blob = mock_bucket.blob.return_value
|
||||
mock_blob.upload_from_string.assert_called_once()
|
||||
|
||||
|
||||
async def test_upload_media_file_too_large(mock_settings, mock_storage_client):
|
||||
@@ -137,10 +132,7 @@ async def test_upload_media_png_success(mock_settings, mock_storage_client):
|
||||
)
|
||||
|
||||
result = await backend.server.v2.store.media.upload_media("test-user", test_file)
|
||||
assert result.startswith(
|
||||
"https://storage.googleapis.com/test-bucket/users/test-user/images/"
|
||||
)
|
||||
assert result.endswith(".png")
|
||||
assert result == "http://test-url/media/laptop.jpeg"
|
||||
|
||||
|
||||
async def test_upload_media_gif_success(mock_settings, mock_storage_client):
|
||||
@@ -151,10 +143,7 @@ async def test_upload_media_gif_success(mock_settings, mock_storage_client):
|
||||
)
|
||||
|
||||
result = await backend.server.v2.store.media.upload_media("test-user", test_file)
|
||||
assert result.startswith(
|
||||
"https://storage.googleapis.com/test-bucket/users/test-user/images/"
|
||||
)
|
||||
assert result.endswith(".gif")
|
||||
assert result == "http://test-url/media/laptop.jpeg"
|
||||
|
||||
|
||||
async def test_upload_media_webp_success(mock_settings, mock_storage_client):
|
||||
@@ -165,10 +154,7 @@ async def test_upload_media_webp_success(mock_settings, mock_storage_client):
|
||||
)
|
||||
|
||||
result = await backend.server.v2.store.media.upload_media("test-user", test_file)
|
||||
assert result.startswith(
|
||||
"https://storage.googleapis.com/test-bucket/users/test-user/images/"
|
||||
)
|
||||
assert result.endswith(".webp")
|
||||
assert result == "http://test-url/media/laptop.jpeg"
|
||||
|
||||
|
||||
async def test_upload_media_webm_success(mock_settings, mock_storage_client):
|
||||
@@ -179,10 +165,7 @@ async def test_upload_media_webm_success(mock_settings, mock_storage_client):
|
||||
)
|
||||
|
||||
result = await backend.server.v2.store.media.upload_media("test-user", test_file)
|
||||
assert result.startswith(
|
||||
"https://storage.googleapis.com/test-bucket/users/test-user/videos/"
|
||||
)
|
||||
assert result.endswith(".webm")
|
||||
assert result == "http://test-url/media/laptop.jpeg"
|
||||
|
||||
|
||||
async def test_upload_media_mismatched_signature(mock_settings, mock_storage_client):
|
||||
|
||||
@@ -19,7 +19,6 @@ from backend.server.model import (
|
||||
WSSubscribeGraphExecutionRequest,
|
||||
WSSubscribeGraphExecutionsRequest,
|
||||
)
|
||||
from backend.util.retry import continuous_retry
|
||||
from backend.util.service import AppProcess
|
||||
from backend.util.settings import AppEnvironment, Config, Settings
|
||||
|
||||
@@ -47,11 +46,18 @@ def get_connection_manager():
|
||||
return _connection_manager
|
||||
|
||||
|
||||
@continuous_retry()
|
||||
async def event_broadcaster(manager: ConnectionManager):
|
||||
event_queue = AsyncRedisExecutionEventBus()
|
||||
async for event in event_queue.listen("*"):
|
||||
await manager.send_execution_update(event)
|
||||
try:
|
||||
event_queue = AsyncRedisExecutionEventBus()
|
||||
async for event in event_queue.listen("*"):
|
||||
await manager.send_execution_update(event)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Event broadcaster stopped due to error: %s. "
|
||||
"Verify the Redis connection and restart the service.",
|
||||
e,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
async def authenticate_websocket(websocket: WebSocket) -> str:
|
||||
|
||||
@@ -1,529 +0,0 @@
|
||||
"""
|
||||
Cloud storage utilities for handling various cloud storage providers.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os.path
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Tuple
|
||||
|
||||
from gcloud.aio import storage as async_gcs_storage
|
||||
from google.cloud import storage as gcs_storage
|
||||
|
||||
from backend.util.settings import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CloudStorageConfig:
    """Settings holder for the cloud storage backends used by the handler."""

    def __init__(self):
        # The GCS bucket name is read from application settings; the clients
        # built from this config rely on Application Default Credentials.
        self.gcs_bucket_name = Config().media_gcs_bucket_name

        # Future providers can be added here
        # self.aws_bucket_name = config.aws_bucket_name
        # self.azure_container_name = config.azure_container_name
|
||||
|
||||
|
||||
class CloudStorageHandler:
|
||||
"""Generic cloud storage handler that can work with multiple providers."""
|
||||
|
||||
def __init__(self, config: CloudStorageConfig):
    """Keep the configuration; both GCS clients are created on first use."""
    self._sync_gcs_client = None  # Only for signed URLs
    self._async_gcs_client = None
    self.config = config
|
||||
|
||||
def _get_async_gcs_client(self):
    """Return the cached async GCS client, building it on the first call."""
    client = self._async_gcs_client
    if client is None:
        # No explicit credentials: Application Default Credentials (ADC) apply.
        client = self._async_gcs_client = async_gcs_storage.Storage()
    return client
|
||||
|
||||
def _get_sync_gcs_client(self):
    """Return the cached sync GCS client (used only for signed URLs)."""
    client = self._sync_gcs_client
    if client is None:
        # Use Application Default Credentials (ADC) - same as media.py
        client = self._sync_gcs_client = gcs_storage.Client()
    return client
|
||||
|
||||
def parse_cloud_path(self, path: str) -> Tuple[str, str]:
    """
    Split a cloud storage URI into its provider and the provider-local path.

    Args:
        path: Cloud storage path (e.g., "gcs://bucket/path/to/file")

    Returns:
        Tuple of (provider, actual_path)

    Raises:
        ValueError: If the path does not start with a supported scheme
    """
    scheme, separator, remainder = path.partition("://")

    # Only GCS is understood today. Future providers:
    #   "s3://"    -> ("s3", remainder)
    #   "azure://" -> ("azure", remainder)
    if separator and scheme == "gcs":
        return "gcs", remainder

    raise ValueError(f"Unsupported cloud storage path: {path}")
|
||||
|
||||
def is_cloud_path(self, path: str) -> bool:
    """Tell whether *path* starts with one of the recognised cloud URI schemes."""
    cloud_schemes = ("gcs://", "s3://", "azure://")
    return any(path.startswith(scheme) for scheme in cloud_schemes)
|
||||
|
||||
async def store_file(
|
||||
self,
|
||||
content: bytes,
|
||||
filename: str,
|
||||
provider: str = "gcs",
|
||||
expiration_hours: int = 48,
|
||||
user_id: str | None = None,
|
||||
graph_exec_id: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Store file content in cloud storage.
|
||||
|
||||
Args:
|
||||
content: File content as bytes
|
||||
filename: Desired filename
|
||||
provider: Cloud storage provider ("gcs", "s3", "azure")
|
||||
expiration_hours: Hours until expiration (1-48, default: 48)
|
||||
user_id: User ID for user-scoped files (optional)
|
||||
graph_exec_id: Graph execution ID for execution-scoped files (optional)
|
||||
|
||||
Note:
|
||||
Provide either user_id OR graph_exec_id, not both. If neither is provided,
|
||||
files will be stored as system uploads.
|
||||
|
||||
Returns:
|
||||
Cloud storage path (e.g., "gcs://bucket/path/to/file")
|
||||
"""
|
||||
if provider == "gcs":
|
||||
return await self._store_file_gcs(
|
||||
content, filename, expiration_hours, user_id, graph_exec_id
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported cloud storage provider: {provider}")
|
||||
|
||||
async def _store_file_gcs(
|
||||
self,
|
||||
content: bytes,
|
||||
filename: str,
|
||||
expiration_hours: int,
|
||||
user_id: str | None = None,
|
||||
graph_exec_id: str | None = None,
|
||||
) -> str:
|
||||
"""Store file in Google Cloud Storage."""
|
||||
if not self.config.gcs_bucket_name:
|
||||
raise ValueError("GCS_BUCKET_NAME not configured")
|
||||
|
||||
# Validate that only one scope is provided
|
||||
if user_id and graph_exec_id:
|
||||
raise ValueError("Provide either user_id OR graph_exec_id, not both")
|
||||
|
||||
async_client = self._get_async_gcs_client()
|
||||
|
||||
# Generate unique path with appropriate scope
|
||||
unique_id = str(uuid.uuid4())
|
||||
if user_id:
|
||||
# User-scoped uploads
|
||||
blob_name = f"uploads/users/{user_id}/{unique_id}/{filename}"
|
||||
elif graph_exec_id:
|
||||
# Execution-scoped uploads
|
||||
blob_name = f"uploads/executions/{graph_exec_id}/{unique_id}/{filename}"
|
||||
else:
|
||||
# System uploads (for backwards compatibility)
|
||||
blob_name = f"uploads/system/{unique_id}/{filename}"
|
||||
|
||||
# Upload content with metadata using pure async client
|
||||
upload_time = datetime.now(timezone.utc)
|
||||
expiration_time = upload_time + timedelta(hours=expiration_hours)
|
||||
|
||||
await async_client.upload(
|
||||
self.config.gcs_bucket_name,
|
||||
blob_name,
|
||||
content,
|
||||
metadata={
|
||||
"uploaded_at": upload_time.isoformat(),
|
||||
"expires_at": expiration_time.isoformat(),
|
||||
"expiration_hours": str(expiration_hours),
|
||||
},
|
||||
)
|
||||
|
||||
return f"gcs://{self.config.gcs_bucket_name}/{blob_name}"
|
||||
|
||||
async def retrieve_file(
|
||||
self,
|
||||
cloud_path: str,
|
||||
user_id: str | None = None,
|
||||
graph_exec_id: str | None = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Retrieve file content from cloud storage.
|
||||
|
||||
Args:
|
||||
cloud_path: Cloud storage path (e.g., "gcs://bucket/path/to/file")
|
||||
user_id: User ID for authorization of user-scoped files (optional)
|
||||
graph_exec_id: Graph execution ID for authorization of execution-scoped files (optional)
|
||||
|
||||
Returns:
|
||||
File content as bytes
|
||||
|
||||
Raises:
|
||||
PermissionError: If user tries to access files they don't own
|
||||
"""
|
||||
provider, path = self.parse_cloud_path(cloud_path)
|
||||
|
||||
if provider == "gcs":
|
||||
return await self._retrieve_file_gcs(path, user_id, graph_exec_id)
|
||||
else:
|
||||
raise ValueError(f"Unsupported cloud storage provider: {provider}")
|
||||
|
||||
async def _retrieve_file_gcs(
|
||||
self, path: str, user_id: str | None = None, graph_exec_id: str | None = None
|
||||
) -> bytes:
|
||||
"""Retrieve file from Google Cloud Storage with authorization."""
|
||||
# Parse bucket and blob name from path
|
||||
parts = path.split("/", 1)
|
||||
if len(parts) != 2:
|
||||
raise ValueError(f"Invalid GCS path: {path}")
|
||||
|
||||
bucket_name, blob_name = parts
|
||||
|
||||
# Authorization check
|
||||
self._validate_file_access(blob_name, user_id, graph_exec_id)
|
||||
|
||||
async_client = self._get_async_gcs_client()
|
||||
|
||||
try:
|
||||
# Download content using pure async client
|
||||
content = await async_client.download(bucket_name, blob_name)
|
||||
return content
|
||||
except Exception as e:
|
||||
# Convert gcloud-aio exceptions to standard ones
|
||||
if "404" in str(e) or "Not Found" in str(e):
|
||||
raise FileNotFoundError(f"File not found: gcs://{path}")
|
||||
raise
|
||||
|
||||
def _validate_file_access(
|
||||
self,
|
||||
blob_name: str,
|
||||
user_id: str | None = None,
|
||||
graph_exec_id: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Validate that a user can access a specific file path.
|
||||
|
||||
Args:
|
||||
blob_name: The blob path in GCS
|
||||
user_id: The requesting user ID (optional)
|
||||
graph_exec_id: The requesting graph execution ID (optional)
|
||||
|
||||
Raises:
|
||||
PermissionError: If access is denied
|
||||
"""
|
||||
|
||||
# Normalize the path to prevent path traversal attacks
|
||||
normalized_path = os.path.normpath(blob_name)
|
||||
|
||||
# Ensure the normalized path doesn't contain any path traversal attempts
|
||||
if ".." in normalized_path or normalized_path.startswith("/"):
|
||||
raise PermissionError("Invalid file path: path traversal detected")
|
||||
|
||||
# Split into components and validate each part
|
||||
path_parts = normalized_path.split("/")
|
||||
|
||||
# Validate path structure: must start with "uploads/"
|
||||
if not path_parts or path_parts[0] != "uploads":
|
||||
raise PermissionError("Invalid file path: must be under uploads/")
|
||||
|
||||
# System uploads (uploads/system/*) can be accessed by anyone for backwards compatibility
|
||||
if len(path_parts) >= 2 and path_parts[1] == "system":
|
||||
return
|
||||
|
||||
# User-specific uploads (uploads/users/{user_id}/*) require matching user_id
|
||||
if len(path_parts) >= 2 and path_parts[1] == "users":
|
||||
if not user_id or len(path_parts) < 3:
|
||||
raise PermissionError(
|
||||
"User ID required to access user files"
|
||||
if not user_id
|
||||
else "Invalid user file path format"
|
||||
)
|
||||
|
||||
file_owner_id = path_parts[2]
|
||||
# Validate user_id format (basic validation) - no need to check ".." again since we already did
|
||||
if not file_owner_id or "/" in file_owner_id:
|
||||
raise PermissionError("Invalid user ID in path")
|
||||
|
||||
if file_owner_id != user_id:
|
||||
raise PermissionError(
|
||||
f"Access denied: file belongs to user {file_owner_id}"
|
||||
)
|
||||
return
|
||||
|
||||
# Execution-specific uploads (uploads/executions/{graph_exec_id}/*) require matching graph_exec_id
|
||||
if len(path_parts) >= 2 and path_parts[1] == "executions":
|
||||
if not graph_exec_id or len(path_parts) < 3:
|
||||
raise PermissionError(
|
||||
"Graph execution ID required to access execution files"
|
||||
if not graph_exec_id
|
||||
else "Invalid execution file path format"
|
||||
)
|
||||
|
||||
file_exec_id = path_parts[2]
|
||||
# Validate execution_id format (basic validation) - no need to check ".." again since we already did
|
||||
if not file_exec_id or "/" in file_exec_id:
|
||||
raise PermissionError("Invalid execution ID in path")
|
||||
|
||||
if file_exec_id != graph_exec_id:
|
||||
raise PermissionError(
|
||||
f"Access denied: file belongs to execution {file_exec_id}"
|
||||
)
|
||||
return
|
||||
|
||||
# Legacy uploads directory (uploads/*) - allow for backwards compatibility with warning
|
||||
# Note: We already validated it starts with "uploads/" above, so this is guaranteed to match
|
||||
logger.warning(f"Accessing legacy upload path: {blob_name}")
|
||||
return
|
||||
|
||||
async def generate_signed_url(
|
||||
self,
|
||||
cloud_path: str,
|
||||
expiration_hours: int = 1,
|
||||
user_id: str | None = None,
|
||||
graph_exec_id: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a signed URL for temporary access to a cloud storage file.
|
||||
|
||||
Args:
|
||||
cloud_path: Cloud storage path
|
||||
expiration_hours: URL expiration in hours
|
||||
user_id: User ID for authorization (required for user files)
|
||||
graph_exec_id: Graph execution ID for authorization (required for execution files)
|
||||
|
||||
Returns:
|
||||
Signed URL string
|
||||
|
||||
Raises:
|
||||
PermissionError: If user tries to access files they don't own
|
||||
"""
|
||||
provider, path = self.parse_cloud_path(cloud_path)
|
||||
|
||||
if provider == "gcs":
|
||||
return await self._generate_signed_url_gcs(
|
||||
path, expiration_hours, user_id, graph_exec_id
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported cloud storage provider: {provider}")
|
||||
|
||||
async def _generate_signed_url_gcs(
|
||||
self,
|
||||
path: str,
|
||||
expiration_hours: int,
|
||||
user_id: str | None = None,
|
||||
graph_exec_id: str | None = None,
|
||||
) -> str:
|
||||
"""Generate signed URL for GCS with authorization."""
|
||||
|
||||
# Parse bucket and blob name from path
|
||||
parts = path.split("/", 1)
|
||||
if len(parts) != 2:
|
||||
raise ValueError(f"Invalid GCS path: {path}")
|
||||
|
||||
bucket_name, blob_name = parts
|
||||
|
||||
# Authorization check
|
||||
self._validate_file_access(blob_name, user_id, graph_exec_id)
|
||||
|
||||
# Use sync client for signed URLs since gcloud-aio doesn't support them
|
||||
sync_client = self._get_sync_gcs_client()
|
||||
bucket = sync_client.bucket(bucket_name)
|
||||
blob = bucket.blob(blob_name)
|
||||
|
||||
# Generate signed URL asynchronously using sync client
|
||||
url = await asyncio.to_thread(
|
||||
blob.generate_signed_url,
|
||||
version="v4",
|
||||
expiration=datetime.now(timezone.utc) + timedelta(hours=expiration_hours),
|
||||
method="GET",
|
||||
)
|
||||
|
||||
return url
|
||||
|
||||
async def delete_expired_files(self, provider: str = "gcs") -> int:
    """
    Remove stored files whose expiration time has passed.

    Args:
        provider: Cloud storage provider

    Returns:
        Number of files deleted

    Raises:
        ValueError: If the provider is not supported
    """
    if provider != "gcs":
        raise ValueError(f"Unsupported cloud storage provider: {provider}")

    removed = await self._delete_expired_files_gcs()
    return removed
|
||||
|
||||
async def _delete_expired_files_gcs(self) -> int:
    """
    Delete expired files from GCS based on their ``expires_at`` metadata.

    Every listing page under the ``uploads/`` prefix is processed, following
    ``nextPageToken`` until the listing is exhausted. Objects without
    ``expires_at`` metadata are left alone, and a failure on a single object
    is logged and skipped.

    Returns:
        Number of files deleted. If listing fails part-way, the number
        deleted so far (0 when the first listing call fails).

    Raises:
        ValueError: If no GCS bucket is configured
    """
    if not self.config.gcs_bucket_name:
        raise ValueError("GCS_BUCKET_NAME not configured")

    async_client = self._get_async_gcs_client()
    bucket_name = self.config.gcs_bucket_name
    current_time = datetime.now(timezone.utc)
    deleted_count = 0

    # Process deletions in parallel with limited concurrency
    semaphore = asyncio.Semaphore(10)  # Limit to 10 concurrent deletions

    async def delete_if_expired(blob_info) -> int:
        """Delete one listed object if expired; return 1 if deleted, else 0."""
        async with semaphore:
            blob_name = blob_info.get("name", "")
            if not blob_name:
                return 0
            try:
                # The listing entry is not used for metadata; fetch it per object.
                metadata_response = await async_client.download_metadata(
                    bucket_name, blob_name
                )
                metadata = metadata_response.get("metadata", {})

                if metadata and "expires_at" in metadata:
                    expires_at = datetime.fromisoformat(metadata["expires_at"])
                    if current_time > expires_at:
                        await async_client.delete(bucket_name, blob_name)
                        return 1
            except Exception as e:
                # Skip files with invalid metadata or delete errors, but leave
                # a trace for debugging.
                logger.warning(
                    f"Failed to process file {blob_name} during cleanup: {e}"
                )
            return 0

    try:
        page_token = None
        while True:
            params = {"prefix": "uploads/"}
            if page_token:
                params["pageToken"] = page_token

            list_response = await async_client.list_objects(
                bucket_name, params=params
            )

            items = list_response.get("items", [])
            if items:
                results = await asyncio.gather(
                    *[delete_if_expired(blob) for blob in items]
                )
                deleted_count += sum(results)

            # A single listing call returns one page only; keep going while
            # the service hands back a new continuation token. Stopping on a
            # repeated token guards against looping forever.
            next_token = list_response.get("nextPageToken")
            if not next_token or next_token == page_token:
                break
            page_token = next_token

        return deleted_count

    except Exception as e:
        # Log the error for debugging but continue operation; the remaining
        # files will be retried on the next cleanup cycle.
        logger.error(f"Cleanup operation failed: {e}")
        return deleted_count
|
||||
|
||||
async def check_file_expired(self, cloud_path: str) -> bool:
    """
    Report whether the file at a cloud storage path has expired.

    Args:
        cloud_path: Cloud storage path

    Returns:
        True if file has expired, False otherwise

    Raises:
        ValueError: If the path or its provider is not supported
    """
    provider, provider_path = self.parse_cloud_path(cloud_path)

    if provider != "gcs":
        raise ValueError(f"Unsupported cloud storage provider: {provider}")

    expired = await self._check_file_expired_gcs(provider_path)
    return expired
|
||||
|
||||
async def _check_file_expired_gcs(self, path: str) -> bool:
    """
    Check if a GCS file has expired.

    Args:
        path: "<bucket>/<blob name>" (the cloud path without its "gcs://" prefix)

    Returns:
        True when the object's ``expires_at`` metadata lies in the past or the
        object does not exist. False when it has not expired, carries no
        ``expires_at`` metadata, or its metadata could not be read.

    Raises:
        ValueError: If *path* has no "/" separating bucket and blob name
    """
    parts = path.split("/", 1)
    if len(parts) != 2:
        raise ValueError(f"Invalid GCS path: {path}")

    bucket_name, blob_name = parts

    async_client = self._get_async_gcs_client()

    try:
        metadata_info = await async_client.download_metadata(bucket_name, blob_name)
        metadata = metadata_info.get("metadata", {})

        if metadata and "expires_at" in metadata:
            expires_at = datetime.fromisoformat(metadata["expires_at"])
            return datetime.now(timezone.utc) > expires_at

    except Exception as e:
        # Prefer the structured HTTP status when the exception carries one:
        # matching "404" in the message alone would report a live file as
        # expired whenever the error text (e.g. the object URL) contains
        # those digits.
        status = getattr(e, "status", None)
        if status is not None:
            not_found = status == 404
        else:
            message = str(e)
            not_found = "404" in message or "Not Found" in message

        if not_found:
            logger.debug(f"File not found during expiration check: {blob_name}")
            return True  # File doesn't exist, consider it expired

        # If we can't read metadata for other reasons, assume not expired
        logger.warning(f"Failed to check expiration for {blob_name}: {e}")
        return False

    return False
|
||||
|
||||
|
||||
# Global instance with thread safety
|
||||
_cloud_storage_handler = None
|
||||
_handler_lock = asyncio.Lock()
|
||||
_cleanup_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_cloud_storage_handler() -> CloudStorageHandler:
    """Return the module-wide CloudStorageHandler, creating it on first use."""
    global _cloud_storage_handler

    if _cloud_storage_handler is not None:
        return _cloud_storage_handler

    async with _handler_lock:
        # Check again once the lock is held: another task may have created
        # the handler while this one was waiting.
        if _cloud_storage_handler is None:
            _cloud_storage_handler = CloudStorageHandler(CloudStorageConfig())

    return _cloud_storage_handler
|
||||
|
||||
|
||||
async def cleanup_expired_files_async() -> int:
    """
    Remove expired files from cloud storage.

    The whole run happens while holding the module's cleanup lock, so
    overlapping calls execute one after another.

    Returns:
        Number of files deleted, or 0 if the cleanup raised an error
    """
    async with _cleanup_lock:
        try:
            logger.info("Starting cleanup of expired cloud storage files")
            storage = await get_cloud_storage_handler()
            removed = await storage.delete_expired_files()
            logger.info(f"Cleaned up {removed} expired files from cloud storage")
        except Exception as e:
            logger.error(f"Error during cloud storage cleanup: {e}")
            return 0
        return removed
|
||||
@@ -1,472 +0,0 @@
|
||||
"""
|
||||
Tests for cloud storage utilities.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util.cloud_storage import CloudStorageConfig, CloudStorageHandler
|
||||
|
||||
|
||||
class TestCloudStorageHandler:
|
||||
"""Test cases for CloudStorageHandler."""
|
||||
|
||||
@pytest.fixture
def config(self):
    """Provide a CloudStorageConfig that points at a fake bucket."""
    test_config = CloudStorageConfig()
    test_config.gcs_bucket_name = "test-bucket"
    return test_config
|
||||
|
||||
@pytest.fixture
def handler(self, config):
    """Build the handler under test from the ``config`` fixture."""
    storage_handler = CloudStorageHandler(config)
    return storage_handler
|
||||
|
||||
def test_parse_cloud_path_gcs(self, handler):
    """A gcs:// URI is split into the provider and the bucket-relative path."""
    scheme, remainder = handler.parse_cloud_path("gcs://bucket/path/to/file.txt")
    assert (scheme, remainder) == ("gcs", "bucket/path/to/file.txt")
|
||||
|
||||
def test_parse_cloud_path_invalid(self, handler):
    """An unknown URI scheme is rejected with ValueError."""
    unsupported_path = "invalid://path"
    with pytest.raises(ValueError, match="Unsupported cloud storage path"):
        handler.parse_cloud_path(unsupported_path)
|
||||
|
||||
def test_is_cloud_path(self, handler):
    """Only gcs://, s3:// and azure:// URIs are reported as cloud paths."""
    cloud_paths = (
        "gcs://bucket/file.txt",
        "s3://bucket/file.txt",
        "azure://container/file.txt",
    )
    other_paths = (
        "http://example.com/file.txt",
        "/local/path/file.txt",
        "data:text/plain;base64,SGVsbG8=",
    )

    for cloud_path in cloud_paths:
        assert handler.is_cloud_path(cloud_path)
    for other_path in other_paths:
        assert not handler.is_cloud_path(other_path)
|
||||
|
||||
@patch.object(CloudStorageHandler, "_get_async_gcs_client")
@pytest.mark.asyncio
async def test_store_file_gcs(self, mock_get_async_client, handler):
    """Storing without a scope uploads under uploads/system/ with metadata."""
    # Stand-in for the async GCS client returned by the handler
    gcs = AsyncMock()
    gcs.upload = AsyncMock()
    mock_get_async_client.return_value = gcs

    payload = b"test file content"

    stored_path = await handler.store_file(
        payload, "test.txt", "gcs", expiration_hours=24
    )

    # Shape of the returned cloud path
    assert stored_path.startswith("gcs://test-bucket/uploads/")
    assert stored_path.endswith("/test.txt")

    # Arguments handed to the upload call
    gcs.upload.assert_called_once()
    upload_call = gcs.upload.call_args
    bucket_arg, blob_arg, content_arg = upload_call.args[:3]
    assert bucket_arg == "test-bucket"  # bucket name
    assert blob_arg.startswith("uploads/system/")  # blob name
    assert content_arg == payload  # file content
    assert "metadata" in upload_call.kwargs  # metadata argument
|
||||
|
||||
@patch.object(CloudStorageHandler, "_get_async_gcs_client")
@pytest.mark.asyncio
async def test_retrieve_file_gcs(self, mock_get_async_client, handler):
    """A system-scoped object is downloaded and returned unchanged."""
    expected = b"test content"

    # Stand-in for the async GCS client, with a canned download result
    gcs = AsyncMock()
    gcs.download = AsyncMock(return_value=expected)
    mock_get_async_client.return_value = gcs

    fetched = await handler.retrieve_file(
        "gcs://test-bucket/uploads/system/uuid123/file.txt"
    )

    assert fetched == expected
    gcs.download.assert_called_once_with(
        "test-bucket", "uploads/system/uuid123/file.txt"
    )
|
||||
|
||||
@patch.object(CloudStorageHandler, "_get_async_gcs_client")
@pytest.mark.asyncio
async def test_retrieve_file_not_found(self, mock_get_async_client, handler):
    """A 404 from the download call surfaces as FileNotFoundError."""
    # Stand-in for the async GCS client whose download fails with a 404
    gcs = AsyncMock()
    gcs.download = AsyncMock(side_effect=Exception("404 Not Found"))
    mock_get_async_client.return_value = gcs

    missing_path = "gcs://test-bucket/uploads/system/uuid123/nonexistent.txt"
    with pytest.raises(FileNotFoundError):
        await handler.retrieve_file(missing_path)
|
||||
|
||||
@patch.object(CloudStorageHandler, "_get_sync_gcs_client")
@pytest.mark.asyncio
async def test_generate_signed_url_gcs(self, mock_get_sync_client, handler):
    """The URL produced by the blob's signing call is returned as-is."""
    expected_url = "https://signed-url.example.com"

    # Walk the sync client -> bucket -> blob chain on the patched getter
    blob = mock_get_sync_client.return_value.bucket.return_value.blob.return_value
    blob.generate_signed_url.return_value = expected_url

    signed_url = await handler.generate_signed_url(
        "gcs://test-bucket/uploads/system/uuid123/file.txt", 1
    )

    assert signed_url == expected_url
    blob.generate_signed_url.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
async def test_unsupported_provider(self, handler):
    """Unknown providers and URI schemes are rejected with ValueError."""
    with pytest.raises(ValueError, match="Unsupported cloud storage provider"):
        await handler.store_file(b"content", "file.txt", "unsupported")

    # Both path-based entry points reject the unknown scheme the same way.
    for path_operation in (handler.retrieve_file, handler.generate_signed_url):
        with pytest.raises(ValueError, match="Unsupported cloud storage path"):
            await path_operation("unsupported://bucket/file.txt")
|
||||
|
||||
@patch.object(CloudStorageHandler, "_get_async_gcs_client")
@pytest.mark.asyncio
async def test_delete_expired_files_gcs(self, mock_get_async_client, handler):
    """Only the object whose expires_at lies in the past is deleted."""
    from datetime import datetime, timedelta, timezone

    # Stand-in for the async GCS client
    gcs = AsyncMock()
    mock_get_async_client.return_value = gcs

    # One object that expired an hour ago and one that expires in an hour
    expiry_by_blob = {
        "uploads/expired-file.txt": (
            datetime.now(timezone.utc) - timedelta(hours=1)
        ).isoformat(),
        "uploads/valid-file.txt": (
            datetime.now(timezone.utc) + timedelta(hours=1)
        ).isoformat(),
    }
    listing = {"items": [{"name": blob_name} for blob_name in expiry_by_blob]}
    gcs.list_objects = AsyncMock(return_value=listing)

    async def fake_download_metadata(bucket, blob_name):
        return {"metadata": {"expires_at": expiry_by_blob[blob_name]}}

    gcs.download_metadata = AsyncMock(side_effect=fake_download_metadata)
    gcs.delete = AsyncMock()

    deleted = await handler.delete_expired_files("gcs")

    assert deleted == 1  # Only one file should be deleted
    # The delete call happened exactly once (for the expired file)
    assert gcs.delete.call_count == 1
|
||||
|
||||
@patch.object(CloudStorageHandler, "_get_async_gcs_client")
@pytest.mark.asyncio
async def test_check_file_expired_gcs(self, mock_get_async_client, handler):
    """Expiry is decided by comparing expires_at metadata with the current time."""
    from datetime import datetime, timedelta, timezone

    # Stand-in for the async GCS client
    gcs = AsyncMock()
    mock_get_async_client.return_value = gcs

    scenarios = (
        # (offset from now, cloud path, expected result)
        (timedelta(hours=-1), "gcs://test-bucket/expired-file.txt", True),
        (timedelta(hours=1), "gcs://test-bucket/valid-file.txt", False),
    )
    for offset, cloud_path, expected in scenarios:
        expires_at = (datetime.now(timezone.utc) + offset).isoformat()
        gcs.download_metadata = AsyncMock(
            return_value={"metadata": {"expires_at": expires_at}}
        )

        outcome = await handler.check_file_expired(cloud_path)
        assert outcome is expected
|
||||
|
||||
@patch("backend.util.cloud_storage.get_cloud_storage_handler")
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_files_async(self, mock_get_handler):
|
||||
"""Test the async cleanup function."""
|
||||
from backend.util.cloud_storage import cleanup_expired_files_async
|
||||
|
||||
# Mock the handler
|
||||
mock_handler = mock_get_handler.return_value
|
||||
mock_handler.delete_expired_files = AsyncMock(return_value=3)
|
||||
|
||||
result = await cleanup_expired_files_async()
|
||||
|
||||
assert result == 3
|
||||
mock_get_handler.assert_called_once()
|
||||
mock_handler.delete_expired_files.assert_called_once()
|
||||
|
||||
@patch("backend.util.cloud_storage.get_cloud_storage_handler")
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_files_async_error(self, mock_get_handler):
|
||||
"""Test the async cleanup function with error."""
|
||||
from backend.util.cloud_storage import cleanup_expired_files_async
|
||||
|
||||
# Mock the handler to raise an exception
|
||||
mock_handler = mock_get_handler.return_value
|
||||
mock_handler.delete_expired_files = AsyncMock(
|
||||
side_effect=Exception("GCS error")
|
||||
)
|
||||
|
||||
result = await cleanup_expired_files_async()
|
||||
|
||||
assert result == 0 # Should return 0 on error
|
||||
mock_get_handler.assert_called_once()
|
||||
mock_handler.delete_expired_files.assert_called_once()
|
||||
|
||||
def test_validate_file_access_system_files(self, handler):
|
||||
"""Test access validation for system files."""
|
||||
# System files should be accessible by anyone
|
||||
handler._validate_file_access("uploads/system/uuid123/file.txt", None)
|
||||
handler._validate_file_access("uploads/system/uuid123/file.txt", "user123")
|
||||
|
||||
def test_validate_file_access_user_files_success(self, handler):
|
||||
"""Test successful access validation for user files."""
|
||||
# User should be able to access their own files
|
||||
handler._validate_file_access(
|
||||
"uploads/users/user123/uuid456/file.txt", "user123"
|
||||
)
|
||||
|
||||
def test_validate_file_access_user_files_no_user_id(self, handler):
|
||||
"""Test access validation failure when no user_id provided for user files."""
|
||||
with pytest.raises(
|
||||
PermissionError, match="User ID required to access user files"
|
||||
):
|
||||
handler._validate_file_access(
|
||||
"uploads/users/user123/uuid456/file.txt", None
|
||||
)
|
||||
|
||||
def test_validate_file_access_user_files_wrong_user(self, handler):
|
||||
"""Test access validation failure when accessing another user's files."""
|
||||
with pytest.raises(
|
||||
PermissionError, match="Access denied: file belongs to user user123"
|
||||
):
|
||||
handler._validate_file_access(
|
||||
"uploads/users/user123/uuid456/file.txt", "user456"
|
||||
)
|
||||
|
||||
def test_validate_file_access_legacy_files(self, handler):
|
||||
"""Test access validation for legacy files."""
|
||||
# Legacy files should be accessible with a warning
|
||||
handler._validate_file_access("uploads/uuid789/file.txt", None)
|
||||
handler._validate_file_access("uploads/uuid789/file.txt", "user123")
|
||||
|
||||
def test_validate_file_access_invalid_path(self, handler):
|
||||
"""Test access validation failure for invalid paths."""
|
||||
with pytest.raises(
|
||||
PermissionError, match="Invalid file path: must be under uploads/"
|
||||
):
|
||||
handler._validate_file_access("invalid/path/file.txt", "user123")
|
||||
|
||||
@patch.object(CloudStorageHandler, "_get_async_gcs_client")
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_file_with_authorization(self, mock_get_client, handler):
|
||||
"""Test file retrieval with authorization."""
|
||||
# Mock async GCS client
|
||||
mock_client = AsyncMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_client.download = AsyncMock(return_value=b"test content")
|
||||
|
||||
# Test successful retrieval of user's own file
|
||||
result = await handler.retrieve_file(
|
||||
"gcs://test-bucket/uploads/users/user123/uuid456/file.txt",
|
||||
user_id="user123",
|
||||
)
|
||||
assert result == b"test content"
|
||||
mock_client.download.assert_called_once_with(
|
||||
"test-bucket", "uploads/users/user123/uuid456/file.txt"
|
||||
)
|
||||
|
||||
# Test authorization failure
|
||||
with pytest.raises(PermissionError):
|
||||
await handler.retrieve_file(
|
||||
"gcs://test-bucket/uploads/users/user123/uuid456/file.txt",
|
||||
user_id="user456",
|
||||
)
|
||||
|
||||
@patch.object(CloudStorageHandler, "_get_async_gcs_client")
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_file_with_user_id(self, mock_get_client, handler):
|
||||
"""Test file storage with user ID."""
|
||||
# Mock async GCS client
|
||||
mock_client = AsyncMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_client.upload = AsyncMock()
|
||||
|
||||
content = b"test file content"
|
||||
filename = "test.txt"
|
||||
|
||||
# Test with user_id
|
||||
result = await handler.store_file(
|
||||
content, filename, "gcs", expiration_hours=24, user_id="user123"
|
||||
)
|
||||
|
||||
# Verify the result format includes user path
|
||||
assert result.startswith("gcs://test-bucket/uploads/users/user123/")
|
||||
assert result.endswith("/test.txt")
|
||||
mock_client.upload.assert_called()
|
||||
|
||||
# Test without user_id (system upload)
|
||||
result = await handler.store_file(
|
||||
content, filename, "gcs", expiration_hours=24, user_id=None
|
||||
)
|
||||
|
||||
# Verify the result format includes system path
|
||||
assert result.startswith("gcs://test-bucket/uploads/system/")
|
||||
assert result.endswith("/test.txt")
|
||||
assert mock_client.upload.call_count == 2
|
||||
|
||||
@patch.object(CloudStorageHandler, "_get_async_gcs_client")
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_file_with_graph_exec_id(self, mock_get_async_client, handler):
|
||||
"""Test file storage with graph execution ID."""
|
||||
# Mock async GCS client
|
||||
mock_async_client = AsyncMock()
|
||||
mock_get_async_client.return_value = mock_async_client
|
||||
|
||||
# Mock the upload method
|
||||
mock_async_client.upload = AsyncMock()
|
||||
|
||||
content = b"test file content"
|
||||
filename = "test.txt"
|
||||
|
||||
# Test with graph_exec_id
|
||||
result = await handler.store_file(
|
||||
content, filename, "gcs", expiration_hours=24, graph_exec_id="exec123"
|
||||
)
|
||||
|
||||
# Verify the result format includes execution path
|
||||
assert result.startswith("gcs://test-bucket/uploads/executions/exec123/")
|
||||
assert result.endswith("/test.txt")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_file_with_both_user_and_exec_id(self, handler):
|
||||
"""Test file storage fails when both user_id and graph_exec_id are provided."""
|
||||
content = b"test file content"
|
||||
filename = "test.txt"
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Provide either user_id OR graph_exec_id, not both"
|
||||
):
|
||||
await handler.store_file(
|
||||
content,
|
||||
filename,
|
||||
"gcs",
|
||||
expiration_hours=24,
|
||||
user_id="user123",
|
||||
graph_exec_id="exec123",
|
||||
)
|
||||
|
||||
def test_validate_file_access_execution_files_success(self, handler):
|
||||
"""Test successful access validation for execution files."""
|
||||
# Graph execution should be able to access their own files
|
||||
handler._validate_file_access(
|
||||
"uploads/executions/exec123/uuid456/file.txt", graph_exec_id="exec123"
|
||||
)
|
||||
|
||||
def test_validate_file_access_execution_files_no_exec_id(self, handler):
|
||||
"""Test access validation failure when no graph_exec_id provided for execution files."""
|
||||
with pytest.raises(
|
||||
PermissionError,
|
||||
match="Graph execution ID required to access execution files",
|
||||
):
|
||||
handler._validate_file_access(
|
||||
"uploads/executions/exec123/uuid456/file.txt", user_id="user123"
|
||||
)
|
||||
|
||||
def test_validate_file_access_execution_files_wrong_exec_id(self, handler):
|
||||
"""Test access validation failure when accessing another execution's files."""
|
||||
with pytest.raises(
|
||||
PermissionError, match="Access denied: file belongs to execution exec123"
|
||||
):
|
||||
handler._validate_file_access(
|
||||
"uploads/executions/exec123/uuid456/file.txt", graph_exec_id="exec456"
|
||||
)
|
||||
|
||||
@patch.object(CloudStorageHandler, "_get_async_gcs_client")
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_file_with_exec_authorization(
|
||||
self, mock_get_async_client, handler
|
||||
):
|
||||
"""Test file retrieval with execution authorization."""
|
||||
# Mock async GCS client
|
||||
mock_async_client = AsyncMock()
|
||||
mock_get_async_client.return_value = mock_async_client
|
||||
|
||||
# Mock the download method
|
||||
mock_async_client.download = AsyncMock(return_value=b"test content")
|
||||
|
||||
# Test successful retrieval of execution's own file
|
||||
result = await handler.retrieve_file(
|
||||
"gcs://test-bucket/uploads/executions/exec123/uuid456/file.txt",
|
||||
graph_exec_id="exec123",
|
||||
)
|
||||
assert result == b"test content"
|
||||
|
||||
# Test authorization failure
|
||||
with pytest.raises(PermissionError):
|
||||
await handler.retrieve_file(
|
||||
"gcs://test-bucket/uploads/executions/exec123/uuid456/file.txt",
|
||||
graph_exec_id="exec456",
|
||||
)
|
||||
|
||||
@patch.object(CloudStorageHandler, "_get_sync_gcs_client")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_signed_url_with_exec_authorization(
|
||||
self, mock_get_sync_client, handler
|
||||
):
|
||||
"""Test signed URL generation with execution authorization."""
|
||||
# Mock sync GCS client for signed URLs
|
||||
mock_sync_client = MagicMock()
|
||||
mock_bucket = MagicMock()
|
||||
mock_blob = MagicMock()
|
||||
|
||||
mock_get_sync_client.return_value = mock_sync_client
|
||||
mock_sync_client.bucket.return_value = mock_bucket
|
||||
mock_bucket.blob.return_value = mock_blob
|
||||
mock_blob.generate_signed_url.return_value = "https://signed-url.example.com"
|
||||
|
||||
# Test successful signed URL generation for execution's own file
|
||||
result = await handler.generate_signed_url(
|
||||
"gcs://test-bucket/uploads/executions/exec123/uuid456/file.txt",
|
||||
1,
|
||||
graph_exec_id="exec123",
|
||||
)
|
||||
assert result == "https://signed-url.example.com"
|
||||
|
||||
# Test authorization failure
|
||||
with pytest.raises(PermissionError):
|
||||
await handler.generate_signed_url(
|
||||
"gcs://test-bucket/uploads/executions/exec123/uuid456/file.txt",
|
||||
1,
|
||||
graph_exec_id="exec456",
|
||||
)
|
||||
@@ -7,7 +7,6 @@ import uuid
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from backend.util.cloud_storage import get_cloud_storage_handler
|
||||
from backend.util.request import Requests
|
||||
from backend.util.type import MediaFileType
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
@@ -32,10 +31,7 @@ def clean_exec_files(graph_exec_id: str, file: str = "") -> None:
|
||||
|
||||
|
||||
async def store_media_file(
|
||||
graph_exec_id: str,
|
||||
file: MediaFileType,
|
||||
user_id: str,
|
||||
return_content: bool = False,
|
||||
graph_exec_id: str, file: MediaFileType, return_content: bool = False
|
||||
) -> MediaFileType:
|
||||
"""
|
||||
Safely handle 'file' (a data URI, a URL, or a local path relative to {temp}/exec_file/{exec_id}),
|
||||
@@ -95,25 +91,8 @@ async def store_media_file(
|
||||
"""
|
||||
return str(absolute_path.relative_to(base))
|
||||
|
||||
# Check if this is a cloud storage path
|
||||
cloud_storage = await get_cloud_storage_handler()
|
||||
if cloud_storage.is_cloud_path(file):
|
||||
# Download from cloud storage and store locally
|
||||
cloud_content = await cloud_storage.retrieve_file(
|
||||
file, user_id=user_id, graph_exec_id=graph_exec_id
|
||||
)
|
||||
|
||||
# Generate filename from cloud path
|
||||
_, path_part = cloud_storage.parse_cloud_path(file)
|
||||
filename = Path(path_part).name or f"{uuid.uuid4()}.bin"
|
||||
target_path = _ensure_inside_base(base_path / filename, base_path)
|
||||
|
||||
# Virus scan the cloud content before writing locally
|
||||
await scan_content_safe(cloud_content, filename=filename)
|
||||
target_path.write_bytes(cloud_content)
|
||||
|
||||
# Process file
|
||||
elif file.startswith("data:"):
|
||||
if file.startswith("data:"):
|
||||
# Data URI
|
||||
match = re.match(r"^data:([^;]+);base64,(.*)$", file, re.DOTALL)
|
||||
if not match:
|
||||
|
||||
@@ -1,238 +0,0 @@
|
||||
"""
|
||||
Tests for cloud storage integration in file utilities.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
|
||||
class TestFileCloudIntegration:
|
||||
"""Test cases for cloud storage integration in file utilities."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_media_file_cloud_path(self):
|
||||
"""Test storing a file from cloud storage path."""
|
||||
graph_exec_id = "test-exec-123"
|
||||
cloud_path = "gcs://test-bucket/uploads/456/source.txt"
|
||||
cloud_content = b"cloud file content"
|
||||
|
||||
with patch(
|
||||
"backend.util.file.get_cloud_storage_handler"
|
||||
) as mock_handler_getter, patch(
|
||||
"backend.util.file.scan_content_safe"
|
||||
) as mock_scan, patch(
|
||||
"backend.util.file.Path"
|
||||
) as mock_path_class:
|
||||
|
||||
# Mock cloud storage handler
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.is_cloud_path.return_value = True
|
||||
mock_handler.parse_cloud_path.return_value = (
|
||||
"gcs",
|
||||
"test-bucket/uploads/456/source.txt",
|
||||
)
|
||||
mock_handler.retrieve_file = AsyncMock(return_value=cloud_content)
|
||||
mock_handler_getter.return_value = mock_handler
|
||||
|
||||
# Mock virus scanner
|
||||
mock_scan.return_value = None
|
||||
|
||||
# Mock file system operations
|
||||
mock_base_path = MagicMock()
|
||||
mock_target_path = MagicMock()
|
||||
mock_resolved_path = MagicMock()
|
||||
|
||||
mock_path_class.return_value = mock_base_path
|
||||
mock_base_path.mkdir = MagicMock()
|
||||
mock_base_path.__truediv__ = MagicMock(return_value=mock_target_path)
|
||||
mock_target_path.resolve.return_value = mock_resolved_path
|
||||
mock_resolved_path.is_relative_to.return_value = True
|
||||
mock_resolved_path.write_bytes = MagicMock()
|
||||
mock_resolved_path.relative_to.return_value = Path("source.txt")
|
||||
|
||||
# Configure the main Path mock to handle filename extraction
|
||||
# When Path(path_part) is called, it should return a mock with .name = "source.txt"
|
||||
mock_path_for_filename = MagicMock()
|
||||
mock_path_for_filename.name = "source.txt"
|
||||
|
||||
# The Path constructor should return different mocks for different calls
|
||||
def path_constructor(*args, **kwargs):
|
||||
if len(args) == 1 and "source.txt" in str(args[0]):
|
||||
return mock_path_for_filename
|
||||
else:
|
||||
return mock_base_path
|
||||
|
||||
mock_path_class.side_effect = path_constructor
|
||||
|
||||
result = await store_media_file(
|
||||
graph_exec_id,
|
||||
MediaFileType(cloud_path),
|
||||
"test-user-123",
|
||||
return_content=False,
|
||||
)
|
||||
|
||||
# Verify cloud storage operations
|
||||
mock_handler.is_cloud_path.assert_called_once_with(cloud_path)
|
||||
mock_handler.parse_cloud_path.assert_called_once_with(cloud_path)
|
||||
mock_handler.retrieve_file.assert_called_once_with(
|
||||
cloud_path, user_id="test-user-123", graph_exec_id=graph_exec_id
|
||||
)
|
||||
|
||||
# Verify virus scan
|
||||
mock_scan.assert_called_once_with(cloud_content, filename="source.txt")
|
||||
|
||||
# Verify file operations
|
||||
mock_resolved_path.write_bytes.assert_called_once_with(cloud_content)
|
||||
|
||||
# Result should be the relative path
|
||||
assert str(result) == "source.txt"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_media_file_cloud_path_return_content(self):
|
||||
"""Test storing a file from cloud storage and returning content."""
|
||||
graph_exec_id = "test-exec-123"
|
||||
cloud_path = "gcs://test-bucket/uploads/456/image.png"
|
||||
cloud_content = b"\\x89PNG\\r\\n\\x1a\\n\\x00\\x00\\x00\\rIHDR" # PNG header
|
||||
|
||||
with patch(
|
||||
"backend.util.file.get_cloud_storage_handler"
|
||||
) as mock_handler_getter, patch(
|
||||
"backend.util.file.scan_content_safe"
|
||||
) as mock_scan, patch(
|
||||
"backend.util.file.get_mime_type"
|
||||
) as mock_mime, patch(
|
||||
"backend.util.file.base64.b64encode"
|
||||
) as mock_b64, patch(
|
||||
"backend.util.file.Path"
|
||||
) as mock_path_class:
|
||||
|
||||
# Mock cloud storage handler
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.is_cloud_path.return_value = True
|
||||
mock_handler.parse_cloud_path.return_value = (
|
||||
"gcs",
|
||||
"test-bucket/uploads/456/image.png",
|
||||
)
|
||||
mock_handler.retrieve_file = AsyncMock(return_value=cloud_content)
|
||||
mock_handler_getter.return_value = mock_handler
|
||||
|
||||
# Mock other operations
|
||||
mock_scan.return_value = None
|
||||
mock_mime.return_value = "image/png"
|
||||
mock_b64.return_value.decode.return_value = "iVBORw0KGgoAAAANSUhEUgA="
|
||||
|
||||
# Mock file system operations
|
||||
mock_base_path = MagicMock()
|
||||
mock_target_path = MagicMock()
|
||||
mock_resolved_path = MagicMock()
|
||||
|
||||
mock_path_class.return_value = mock_base_path
|
||||
mock_base_path.mkdir = MagicMock()
|
||||
mock_base_path.__truediv__ = MagicMock(return_value=mock_target_path)
|
||||
mock_target_path.resolve.return_value = mock_resolved_path
|
||||
mock_resolved_path.is_relative_to.return_value = True
|
||||
mock_resolved_path.write_bytes = MagicMock()
|
||||
mock_resolved_path.read_bytes.return_value = cloud_content
|
||||
|
||||
# Mock Path constructor for filename extraction
|
||||
mock_path_obj = MagicMock()
|
||||
mock_path_obj.name = "image.png"
|
||||
with patch("backend.util.file.Path", return_value=mock_path_obj):
|
||||
result = await store_media_file(
|
||||
graph_exec_id,
|
||||
MediaFileType(cloud_path),
|
||||
"test-user-123",
|
||||
return_content=True,
|
||||
)
|
||||
|
||||
# Verify result is a data URI
|
||||
assert str(result).startswith("data:image/png;base64,")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_media_file_non_cloud_path(self):
|
||||
"""Test that non-cloud paths are handled normally."""
|
||||
graph_exec_id = "test-exec-123"
|
||||
data_uri = "data:text/plain;base64,SGVsbG8gd29ybGQ="
|
||||
|
||||
with patch(
|
||||
"backend.util.file.get_cloud_storage_handler"
|
||||
) as mock_handler_getter, patch(
|
||||
"backend.util.file.scan_content_safe"
|
||||
) as mock_scan, patch(
|
||||
"backend.util.file.base64.b64decode"
|
||||
) as mock_b64decode, patch(
|
||||
"backend.util.file.uuid.uuid4"
|
||||
) as mock_uuid, patch(
|
||||
"backend.util.file.Path"
|
||||
) as mock_path_class:
|
||||
|
||||
# Mock cloud storage handler
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.is_cloud_path.return_value = False
|
||||
mock_handler.retrieve_file = (
|
||||
AsyncMock()
|
||||
) # Add this even though it won't be called
|
||||
mock_handler_getter.return_value = mock_handler
|
||||
|
||||
# Mock other operations
|
||||
mock_scan.return_value = None
|
||||
mock_b64decode.return_value = b"Hello world"
|
||||
mock_uuid.return_value = "test-uuid-789"
|
||||
|
||||
# Mock file system operations
|
||||
mock_base_path = MagicMock()
|
||||
mock_target_path = MagicMock()
|
||||
mock_resolved_path = MagicMock()
|
||||
|
||||
mock_path_class.return_value = mock_base_path
|
||||
mock_base_path.mkdir = MagicMock()
|
||||
mock_base_path.__truediv__ = MagicMock(return_value=mock_target_path)
|
||||
mock_target_path.resolve.return_value = mock_resolved_path
|
||||
mock_resolved_path.is_relative_to.return_value = True
|
||||
mock_resolved_path.write_bytes = MagicMock()
|
||||
mock_resolved_path.relative_to.return_value = Path("test-uuid-789.txt")
|
||||
|
||||
await store_media_file(
|
||||
graph_exec_id,
|
||||
MediaFileType(data_uri),
|
||||
"test-user-123",
|
||||
return_content=False,
|
||||
)
|
||||
|
||||
# Verify cloud handler was checked but not used for retrieval
|
||||
mock_handler.is_cloud_path.assert_called_once_with(data_uri)
|
||||
mock_handler.retrieve_file.assert_not_called()
|
||||
|
||||
# Verify normal data URI processing occurred
|
||||
mock_b64decode.assert_called_once()
|
||||
mock_resolved_path.write_bytes.assert_called_once_with(b"Hello world")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_media_file_cloud_retrieval_error(self):
|
||||
"""Test error handling when cloud retrieval fails."""
|
||||
graph_exec_id = "test-exec-123"
|
||||
cloud_path = "gcs://test-bucket/nonexistent.txt"
|
||||
|
||||
with patch(
|
||||
"backend.util.file.get_cloud_storage_handler"
|
||||
) as mock_handler_getter:
|
||||
|
||||
# Mock cloud storage handler to raise error
|
||||
mock_handler = AsyncMock()
|
||||
mock_handler.is_cloud_path.return_value = True
|
||||
mock_handler.retrieve_file.side_effect = FileNotFoundError(
|
||||
"File not found in cloud storage"
|
||||
)
|
||||
mock_handler_getter.return_value = mock_handler
|
||||
|
||||
with pytest.raises(
|
||||
FileNotFoundError, match="File not found in cloud storage"
|
||||
):
|
||||
await store_media_file(
|
||||
graph_exec_id, MediaFileType(cloud_path), "test-user-123"
|
||||
)
|
||||
@@ -61,8 +61,7 @@ class TruncatedLogger:
|
||||
extra_msg = str(extra or "")
|
||||
text = f"{self.prefix} {msg} {extra_msg}"
|
||||
if len(text) > self.max_length:
|
||||
half = (self.max_length - 3) // 2
|
||||
text = text[:half] + "..." + text[-half:]
|
||||
text = text[: self.max_length] + "..."
|
||||
return text
|
||||
|
||||
|
||||
|
||||
@@ -85,10 +85,8 @@ func_retry = retry(
|
||||
|
||||
def continuous_retry(*, retry_delay: float = 1.0):
|
||||
def decorator(func):
|
||||
is_coroutine = asyncio.iscoroutinefunction(func)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
def wrapper(*args, **kwargs):
|
||||
while True:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
@@ -101,20 +99,6 @@ def continuous_retry(*, retry_delay: float = 1.0):
|
||||
)
|
||||
time.sleep(retry_delay)
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
while True:
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"%s failed with %s — retrying in %.2f s",
|
||||
func.__name__,
|
||||
exc,
|
||||
retry_delay,
|
||||
)
|
||||
await asyncio.sleep(retry_delay)
|
||||
|
||||
return async_wrapper if is_coroutine else sync_wrapper
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -216,10 +216,7 @@ class AppService(BaseAppService, ABC):
|
||||
methods=["POST"],
|
||||
)
|
||||
self.fastapi_app.add_api_route(
|
||||
"/health_check", self.health_check, methods=["POST", "GET"]
|
||||
)
|
||||
self.fastapi_app.add_api_route(
|
||||
"/health_check_async", self.health_check, methods=["POST", "GET"]
|
||||
"/health_check", self.health_check, methods=["POST"]
|
||||
)
|
||||
self.fastapi_app.add_exception_handler(
|
||||
ValueError, self._handle_internal_http_error(400)
|
||||
@@ -251,9 +248,6 @@ class AppServiceClient(ABC):
|
||||
def health_check(self):
|
||||
pass
|
||||
|
||||
async def health_check_async(self):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
|
||||
@@ -124,19 +124,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="Time in seconds for how far back to check for the late executions.",
|
||||
)
|
||||
|
||||
block_error_rate_threshold: float = Field(
|
||||
default=0.5,
|
||||
description="Error rate threshold (0.0-1.0) for triggering block error alerts.",
|
||||
)
|
||||
block_error_rate_check_interval_secs: int = Field(
|
||||
default=24 * 60 * 60, # 24 hours
|
||||
description="Interval in seconds between block error rate checks.",
|
||||
)
|
||||
block_error_include_top_blocks: int = Field(
|
||||
default=3,
|
||||
description="Number of top blocks with most errors to show when no blocks exceed threshold (0 to disable).",
|
||||
)
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
extra="allow",
|
||||
@@ -281,20 +268,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="Whether to enable example blocks in production",
|
||||
)
|
||||
|
||||
cloud_storage_cleanup_interval_hours: int = Field(
|
||||
default=6,
|
||||
ge=1,
|
||||
le=24,
|
||||
description="Hours between cloud storage cleanup runs (1-24 hours)",
|
||||
)
|
||||
|
||||
upload_file_size_limit_mb: int = Field(
|
||||
default=256,
|
||||
ge=1,
|
||||
le=1024,
|
||||
description="Maximum file size in MB for file uploads (1-1024 MB)",
|
||||
)
|
||||
|
||||
@field_validator("platform_base_url", "frontend_base_url")
|
||||
@classmethod
|
||||
def validate_platform_base_url(cls, v: str, info: ValidationInfo) -> str:
|
||||
@@ -329,11 +302,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="A whitelist of trusted internal endpoints for the backend to make requests to.",
|
||||
)
|
||||
|
||||
max_message_size_limit: int = Field(
|
||||
default=16 * 1024 * 1024, # 16 MB
|
||||
description="Maximum message size limit for communication with the message bus",
|
||||
)
|
||||
|
||||
backend_cors_allow_origins: List[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("backend_cors_allow_origins")
|
||||
|
||||
@@ -1,21 +1,3 @@
|
||||
"""
|
||||
Test Data Creator for AutoGPT Platform
|
||||
|
||||
This script creates test data for the AutoGPT platform database.
|
||||
|
||||
Image/Video URL Domains Used:
|
||||
- Images: picsum.photos (for all image URLs - avatars, store listing images, etc.)
|
||||
- Videos: youtube.com (for store listing video URLs)
|
||||
|
||||
Add these domains to your Next.js config:
|
||||
```javascript
|
||||
// next.config.js
|
||||
images: {
|
||||
domains: ['picsum.photos'],
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
from datetime import datetime
|
||||
@@ -32,7 +14,6 @@ from prisma.types import (
|
||||
AnalyticsMetricsCreateInput,
|
||||
APIKeyCreateInput,
|
||||
CreditTransactionCreateInput,
|
||||
IntegrationWebhookCreateInput,
|
||||
ProfileCreateInput,
|
||||
StoreListingReviewCreateInput,
|
||||
UserCreateInput,
|
||||
@@ -72,26 +53,10 @@ MAX_REVIEWS_PER_VERSION = 5 # Total reviews depends on number of versions creat
|
||||
|
||||
|
||||
def get_image():
|
||||
"""Generate a consistent image URL using picsum.photos service."""
|
||||
width = random.choice([200, 300, 400, 500, 600, 800])
|
||||
height = random.choice([200, 300, 400, 500, 600, 800])
|
||||
# Use a random seed to get different images
|
||||
seed = random.randint(1, 1000)
|
||||
return f"https://picsum.photos/seed/{seed}/{width}/{height}"
|
||||
|
||||
|
||||
def get_video_url():
|
||||
"""Generate a consistent video URL using a placeholder service."""
|
||||
# Using YouTube as a consistent source for video URLs
|
||||
video_ids = [
|
||||
"dQw4w9WgXcQ", # Example video IDs
|
||||
"9bZkp7q19f0",
|
||||
"kJQP7kiw5Fk",
|
||||
"RgKAFK5djSk",
|
||||
"L_jWHffIx5E",
|
||||
]
|
||||
video_id = random.choice(video_ids)
|
||||
return f"https://www.youtube.com/watch?v={video_id}"
|
||||
url = faker.image_url()
|
||||
while "placekitten.com" in url:
|
||||
url = faker.image_url()
|
||||
return url
|
||||
|
||||
|
||||
async def main():
|
||||
@@ -182,27 +147,12 @@ async def main():
|
||||
)
|
||||
agent_presets.append(preset)
|
||||
|
||||
# Insert Profiles first (before LibraryAgents)
|
||||
profiles = []
|
||||
print(f"Inserting {NUM_USERS} profiles")
|
||||
for user in users:
|
||||
profile = await db.profile.create(
|
||||
data=ProfileCreateInput(
|
||||
userId=user.id,
|
||||
name=user.name or faker.name(),
|
||||
username=faker.unique.user_name(),
|
||||
description=faker.text(),
|
||||
links=[faker.url() for _ in range(3)],
|
||||
avatarUrl=get_image(),
|
||||
)
|
||||
)
|
||||
profiles.append(profile)
|
||||
|
||||
# Insert LibraryAgents
|
||||
library_agents = []
|
||||
print("Inserting library agents")
|
||||
# Insert UserAgents
|
||||
user_agents = []
|
||||
print(f"Inserting {NUM_USERS * MAX_AGENTS_PER_USER} user agents")
|
||||
for user in users:
|
||||
num_agents = random.randint(MIN_AGENTS_PER_USER, MAX_AGENTS_PER_USER)
|
||||
|
||||
# Get a shuffled list of graphs to ensure uniqueness per user
|
||||
available_graphs = agent_graphs.copy()
|
||||
random.shuffle(available_graphs)
|
||||
@@ -212,27 +162,18 @@ async def main():
|
||||
|
||||
for i in range(num_agents):
|
||||
graph = available_graphs[i] # Use unique graph for each library agent
|
||||
|
||||
# Get creator profile for this graph's owner
|
||||
creator_profile = next(
|
||||
(p for p in profiles if p.userId == graph.userId), None
|
||||
)
|
||||
|
||||
library_agent = await db.libraryagent.create(
|
||||
user_agent = await db.libraryagent.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
"creatorId": creator_profile.id if creator_profile else None,
|
||||
"imageUrl": get_image() if random.random() < 0.5 else None,
|
||||
"useGraphIsActiveVersion": random.choice([True, False]),
|
||||
"isFavorite": random.choice([True, False]),
|
||||
"isCreatedByUser": random.choice([True, False]),
|
||||
"isArchived": random.choice([True, False]),
|
||||
"isDeleted": random.choice([True, False]),
|
||||
}
|
||||
)
|
||||
library_agents.append(library_agent)
|
||||
user_agents.append(user_agent)
|
||||
|
||||
# Insert AgentGraphExecutions
|
||||
agent_graph_executions = []
|
||||
@@ -384,9 +325,25 @@ async def main():
|
||||
)
|
||||
)
|
||||
|
||||
# Insert Profiles
|
||||
profiles = []
|
||||
print(f"Inserting {NUM_USERS} profiles")
|
||||
for user in users:
|
||||
profile = await db.profile.create(
|
||||
data=ProfileCreateInput(
|
||||
userId=user.id,
|
||||
name=user.name or faker.name(),
|
||||
username=faker.unique.user_name(),
|
||||
description=faker.text(),
|
||||
links=[faker.url() for _ in range(3)],
|
||||
avatarUrl=get_image(),
|
||||
)
|
||||
)
|
||||
profiles.append(profile)
|
||||
|
||||
# Insert StoreListings
|
||||
store_listings = []
|
||||
print("Inserting store listings")
|
||||
print(f"Inserting {NUM_USERS} store listings")
|
||||
for graph in agent_graphs:
|
||||
user = random.choice(users)
|
||||
slug = faker.slug()
|
||||
@@ -403,7 +360,7 @@ async def main():
|
||||
|
||||
# Insert StoreListingVersions
|
||||
store_listing_versions = []
|
||||
print("Inserting store listing versions")
|
||||
print(f"Inserting {NUM_USERS} store listing versions")
|
||||
for listing in store_listings:
|
||||
graph = [g for g in agent_graphs if g.id == listing.agentGraphId][0]
|
||||
version = await db.storelistingversion.create(
|
||||
@@ -412,7 +369,7 @@ async def main():
|
||||
"agentGraphVersion": graph.version,
|
||||
"name": graph.name or faker.sentence(nb_words=3),
|
||||
"subHeading": faker.sentence(),
|
||||
"videoUrl": get_video_url() if random.random() < 0.3 else None,
|
||||
"videoUrl": faker.url(),
|
||||
"imageUrls": [get_image() for _ in range(3)],
|
||||
"description": faker.text(),
|
||||
"categories": [faker.word() for _ in range(3)],
|
||||
@@ -431,7 +388,7 @@ async def main():
|
||||
store_listing_versions.append(version)
|
||||
|
||||
# Insert StoreListingReviews
|
||||
print("Inserting store listing reviews")
|
||||
print(f"Inserting {NUM_USERS * MAX_REVIEWS_PER_VERSION} store listing reviews")
|
||||
for version in store_listing_versions:
|
||||
# Create a copy of users list and shuffle it to avoid duplicates
|
||||
available_reviewers = users.copy()
|
||||
@@ -454,92 +411,26 @@ async def main():
|
||||
)
|
||||
)
|
||||
|
||||
# Insert UserOnboarding for some users
|
||||
print("Inserting user onboarding data")
|
||||
for user in random.sample(
|
||||
users, k=int(NUM_USERS * 0.7)
|
||||
): # 70% of users have onboarding data
|
||||
completed_steps = []
|
||||
possible_steps = list(prisma.enums.OnboardingStep)
|
||||
# Randomly complete some steps
|
||||
if random.random() < 0.8:
|
||||
num_steps = random.randint(1, len(possible_steps))
|
||||
completed_steps = random.sample(possible_steps, k=num_steps)
|
||||
|
||||
try:
|
||||
await db.useronboarding.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"completedSteps": completed_steps,
|
||||
"notificationDot": random.choice([True, False]),
|
||||
"notified": (
|
||||
random.sample(completed_steps, k=min(3, len(completed_steps)))
|
||||
if completed_steps
|
||||
else []
|
||||
),
|
||||
"rewardedFor": (
|
||||
random.sample(completed_steps, k=min(2, len(completed_steps)))
|
||||
if completed_steps
|
||||
else []
|
||||
),
|
||||
"usageReason": (
|
||||
random.choice(["personal", "business", "research", "learning"])
|
||||
if random.random() < 0.7
|
||||
else None
|
||||
),
|
||||
"integrations": random.sample(
|
||||
["github", "google", "discord", "slack"], k=random.randint(0, 2)
|
||||
),
|
||||
"otherIntegrations": (
|
||||
faker.word() if random.random() < 0.2 else None
|
||||
),
|
||||
"selectedStoreListingVersionId": (
|
||||
random.choice(store_listing_versions).id
|
||||
if store_listing_versions and random.random() < 0.5
|
||||
else None
|
||||
),
|
||||
"agentInput": (
|
||||
Json({"test": "data"}) if random.random() < 0.3 else None
|
||||
),
|
||||
"onboardingAgentExecutionId": (
|
||||
random.choice(agent_graph_executions).id
|
||||
if agent_graph_executions and random.random() < 0.3
|
||||
else None
|
||||
),
|
||||
"agentRuns": random.randint(0, 10),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error creating onboarding for user {user.id}: {e}")
|
||||
# Try simpler version
|
||||
await db.useronboarding.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
}
|
||||
)
|
||||
|
||||
# Insert IntegrationWebhooks for some users
|
||||
print("Inserting integration webhooks")
|
||||
for user in random.sample(
|
||||
users, k=int(NUM_USERS * 0.3)
|
||||
): # 30% of users have webhooks
|
||||
for _ in range(random.randint(1, 3)):
|
||||
await db.integrationwebhook.create(
|
||||
data=IntegrationWebhookCreateInput(
|
||||
userId=user.id,
|
||||
provider=random.choice(["github", "slack", "discord"]),
|
||||
credentialsId=str(faker.uuid4()),
|
||||
webhookType=random.choice(["repo", "channel", "server"]),
|
||||
resource=faker.slug(),
|
||||
events=[
|
||||
random.choice(["created", "updated", "deleted"])
|
||||
for _ in range(random.randint(1, 3))
|
||||
],
|
||||
config=prisma.Json({"url": faker.url()}),
|
||||
secret=str(faker.sha256()),
|
||||
providerWebhookId=str(faker.uuid4()),
|
||||
)
|
||||
)
|
||||
# Update StoreListingVersions with submission status (StoreListingSubmissions table no longer exists)
|
||||
print(f"Updating {NUM_USERS} store listing versions with submission status")
|
||||
for version in store_listing_versions:
|
||||
reviewer = random.choice(users)
|
||||
status: prisma.enums.SubmissionStatus = random.choice(
|
||||
[
|
||||
prisma.enums.SubmissionStatus.PENDING,
|
||||
prisma.enums.SubmissionStatus.APPROVED,
|
||||
prisma.enums.SubmissionStatus.REJECTED,
|
||||
]
|
||||
)
|
||||
await db.storelistingversion.update(
|
||||
where={"id": version.id},
|
||||
data={
|
||||
"submissionStatus": status,
|
||||
"Reviewer": {"connect": {"id": reviewer.id}},
|
||||
"reviewComments": faker.text(),
|
||||
"reviewedAt": datetime.now(),
|
||||
},
|
||||
)
|
||||
|
||||
# Insert APIKeys
|
||||
print(f"Inserting {NUM_USERS} api keys")
|
||||
@@ -560,12 +451,7 @@ async def main():
|
||||
)
|
||||
)
|
||||
|
||||
# Refresh materialized views
|
||||
print("Refreshing materialized views...")
|
||||
await db.execute_raw("SELECT refresh_store_materialized_views();")
|
||||
|
||||
await db.disconnect()
|
||||
print("Test data creation completed successfully!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -1,128 +0,0 @@
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# String helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _truncate_string_middle(value: str, limit: int) -> str:
    """Shorten *value* by replacing its middle with an omission marker.

    Strings no longer than *limit* are returned unchanged. Otherwise roughly
    ``limit // 2`` leading characters (always at least one) and the remaining
    ``limit - head`` trailing characters are kept, joined by a marker of the
    form ``… (omitted N chars)…``. The marker is added on top of the kept
    characters, so the returned string is longer than *limit*.

    Args:
        value: The string to shorten.
        limit: Number of original characters to keep; negative values are
            treated as 0.

    Returns:
        *value* itself when it already fits, otherwise head + marker + tail.
    """
    limit = max(limit, 0)
    if len(value) <= limit:
        return value

    head_len = max(1, limit // 2)
    # With a limit of 0 or 1 the single forced head char uses up the budget.
    tail_len = max(0, limit - head_len)
    omitted = len(value) - (head_len + tail_len)
    # Slice from an explicit start index: ``value[-0:]`` would be the whole
    # string rather than an empty tail.
    tail = value[len(value) - tail_len:]
    return f"{value[:head_len]}… (omitted {omitted} chars)…{tail}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# List helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _truncate_list_middle(lst: list[Any], str_lim: int, list_lim: int) -> list[Any]:
    """Keep at most *list_lim* items of *lst*, dropping items from the middle.

    Every retained item is passed through :func:`_truncate_value` with the
    same limits, so nested strings and lists are shortened too. When items
    are dropped, a string marker ``… (omitted N items)…`` is inserted where
    they were, so the result holds up to ``list_lim + 1`` entries.
    """

    def shrink(items):
        # Apply the recursive truncation to each retained element.
        return [_truncate_value(item, str_lim, list_lim) for item in items]

    total = len(lst)
    if total <= list_lim:
        return shrink(lst)

    # For very small limits (<3) a head/tail split is degenerate: keep only
    # the head and put the marker at the end.
    if list_lim < 3:
        return shrink(lst[:list_lim]) + [f"… (omitted {total - list_lim} items)…"]

    front = list_lim // 2
    back = list_lim - front
    marker = f"… (omitted {total - (front + back)} items)…"
    return [*shrink(lst[:front]), marker, *shrink(lst[-back:])]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Recursive truncation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _truncate_value(value: Any, str_limit: int, list_limit: int) -> Any:
    """Dispatch *value* to the truncation helper matching its type.

    Strings go to :func:`_truncate_string_middle`, lists to
    :func:`_truncate_list_middle`, and dicts are rebuilt with the same keys
    and recursively truncated values. Anything else is returned unchanged.
    """
    if isinstance(value, str):
        result = _truncate_string_middle(value, str_limit)
    elif isinstance(value, list):
        result = _truncate_list_middle(value, str_limit, list_limit)
    elif isinstance(value, dict):
        result = {}
        for key, item in value.items():
            result[key] = _truncate_value(item, str_limit, list_limit)
    else:
        result = value
    return result
|
||||
|
||||
|
||||
def truncate(value: Any, size_limit: int) -> Any:
    """Recursively truncate *value* so its ``str()`` form fits *size_limit*.

    An outer binary search over the list limit wraps an inner binary search
    over the string limit; the result for the largest pair of limits found
    to fit is returned. If no combination fits, the value truncated with the
    minimum limits is returned even though it may still exceed *size_limit*.
    """
    # Search bounds for the per-string and per-list limits.
    min_str, max_str = 8, 2**16
    min_items, max_items = 1, 2**12

    def rendered_size(candidate):
        # Fall back to the object's memory size if it cannot be stringified.
        try:
            return len(str(candidate))
        except Exception:
            return sys.getsizeof(candidate)

    def widest_fit(item_cap):
        # Binary-search the string limit for a fixed list limit; returns the
        # truncated value for the largest fitting limit, or None if none fit.
        fit = None
        low, high = min_str, max_str
        while low <= high:
            str_cap = (low + high) // 2
            candidate = _truncate_value(value, str_cap, item_cap)
            if rendered_size(candidate) <= size_limit:
                fit = candidate
                low = str_cap + 1  # fits: try a larger string limit
            else:
                high = str_cap - 1  # too big: shrink the string limit
        return fit

    result = None
    low, high = min_items, max_items
    while low <= high:
        item_cap = (low + high) // 2
        fit = widest_fit(item_cap)
        if fit is None:
            high = item_cap - 1  # nothing fits: allow fewer list items
        else:
            result = fit
            low = item_cap + 1  # fits: try allowing more list items

    # If nothing fits, fall back to the most aggressive truncation.
    if result is None:
        return _truncate_value(value, min_str, min_items)
    return result
|
||||
@@ -1,101 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Clean the test database by removing all data while preserving the schema.
|
||||
|
||||
Usage:
|
||||
poetry run python clean_test_db.py [--yes]
|
||||
|
||||
Options:
|
||||
--yes Skip confirmation prompt
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
from prisma import Prisma
|
||||
|
||||
|
||||
async def main():
    """Delete all rows from the test database while leaving the schema intact.

    Connects through Prisma, prints the current user and agent-graph counts,
    and returns early if both are zero. Unless ``--yes`` is present in
    ``sys.argv``, the user must answer ``yes`` at an interactive prompt.
    Tables are then emptied one by one in the listed order; a failure on one
    table is reported and does not stop the remaining tables. Finally the
    store materialized views are refreshed on a best-effort basis.

    The database connection is always closed, including when an unexpected
    error (or an interrupted prompt) aborts the run.
    """
    db = Prisma()
    await db.connect()

    failed_tables = []
    try:
        print("=" * 60)
        print("Cleaning Test Database")
        print("=" * 60)
        print()

        # Get initial counts
        user_count = await db.user.count()
        agent_count = await db.agentgraph.count()

        print(f"Current data: {user_count} users, {agent_count} agent graphs")

        if user_count == 0 and agent_count == 0:
            print("Database is already clean!")
            return

        # Check for --yes flag
        skip_confirm = "--yes" in sys.argv

        if not skip_confirm:
            response = input("\nDo you want to clean all data? (yes/no): ")
            if response.lower() != "yes":
                print("Aborted.")
                return

        print("\nCleaning database...")

        # Delete in reverse order of dependencies
        tables = [
            ("UserNotificationBatch", db.usernotificationbatch),
            ("NotificationEvent", db.notificationevent),
            ("CreditRefundRequest", db.creditrefundrequest),
            ("StoreListingReview", db.storelistingreview),
            ("StoreListingVersion", db.storelistingversion),
            ("StoreListing", db.storelisting),
            ("AgentNodeExecutionInputOutput", db.agentnodeexecutioninputoutput),
            ("AgentNodeExecution", db.agentnodeexecution),
            ("AgentGraphExecution", db.agentgraphexecution),
            ("AgentNodeLink", db.agentnodelink),
            ("LibraryAgent", db.libraryagent),
            ("AgentPreset", db.agentpreset),
            ("IntegrationWebhook", db.integrationwebhook),
            ("AgentNode", db.agentnode),
            ("AgentGraph", db.agentgraph),
            ("AgentBlock", db.agentblock),
            ("APIKey", db.apikey),
            ("CreditTransaction", db.credittransaction),
            ("AnalyticsMetrics", db.analyticsmetrics),
            ("AnalyticsDetails", db.analyticsdetails),
            ("Profile", db.profile),
            ("UserOnboarding", db.useronboarding),
            ("User", db.user),
        ]

        for table_name, table in tables:
            # Best effort per table: report the failure and keep going so one
            # broken table does not block cleaning the rest.
            try:
                count = await table.count()
                if count > 0:
                    await table.delete_many()
                    print(f"✓ Deleted {count} records from {table_name}")
            except Exception as e:
                failed_tables.append(table_name)
                print(f"⚠ Error cleaning {table_name}: {e}")

        # Refresh materialized views (they should be empty now)
        try:
            await db.execute_raw("SELECT refresh_store_materialized_views();")
            print("\n✓ Refreshed materialized views")
        except Exception as e:
            print(f"\n⚠ Could not refresh materialized views: {e}")
    finally:
        # Always release the connection, even on early return or error.
        await db.disconnect()

    print("\n" + "=" * 60)
    if failed_tables:
        # Do not claim success when some tables could not be cleaned.
        print(f"Database cleaned with errors in: {', '.join(failed_tables)}")
    else:
        print("Database cleaned successfully!")
    print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user