Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-13 00:58:16 -05:00)

Compare commits: fix/blocks ... feat/agent
38 Commits
| Author | SHA1 | Date |
|---|---|---|
| | ee11623735 | |
| | 0bb160e930 | |
| | 81a09738dc | |
| | 6feedafd7d | |
| | 547da633c4 | |
| | fde3533943 | |
| | a789f87734 | |
| | 0b6e46d363 | |
| | 6ffe57c3df | |
| | 3ca0d04ea0 | |
| | c2eea593c0 | |
| | 6d13dfc688 | |
| | 36f5f24333 | |
| | d0d498fa66 | |
| | c843dee317 | |
| | db969c1bf8 | |
| | 690fac91e4 | |
| | 309114a727 | |
| | 5368fdc998 | |
| | b9d293f181 | |
| | acbcef77b2 | |
| | 4ffb99bfb0 | |
| | e902848e04 | |
| | cd917ec919 | |
| | 5741331250 | |
| | 2fda8dfd32 | |
| | 22c76eab61 | |
| | 7688a9701e | |
| | 8ae37491e4 | |
| | 243400e128 | |
| | c77cb1fcfb | |
| | b3b5eefe2c | |
| | f45e5e0d59 | |
| | 1231236d87 | |
| | 4db0792ade | |
| | 81cb6fb1e6 | |
| | c16598eed6 | |
| | 7706740308 | |
.github/workflows/platform-frontend-ci.yml (vendored, 1 change)
@@ -148,6 +148,7 @@ jobs:
          onlyChanged: true
          workingDir: autogpt_platform/frontend
          token: ${{ secrets.GITHUB_TOKEN }}
          exitOnceUploaded: true

  test:
    runs-on: ubuntu-latest
@@ -1,8 +1,7 @@
# AutoGPT: Build, Deploy, and Run AI Agents

[Discord](https://discord.gg/autogpt)
[Twitter](https://twitter.com/Auto_GPT)
[MIT License](https://opensource.org/licenses/MIT)

**AutoGPT** is a powerful platform that allows you to create, deploy, and manage continuous AI agents that automate complex workflows.
@@ -1,7 +1,6 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Repository Overview

AutoGPT Platform is a monorepo containing:

@@ -144,4 +143,4 @@ Key models (defined in `/backend/schema.prisma`):
- Cacheable paths include: static assets (`/static/*`, `/_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware (see the sketch after this section)
- Applied to both main API server and external API applications
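For context on the `CACHEABLE_PATHS` note above, here is a minimal sketch of an allow-list cache-control middleware. It is illustrative only, not the platform's actual implementation; the class name, path prefixes, and the Starlette/FastAPI wiring are assumptions.

```python
# Illustrative sketch only -- not AutoGPT's actual middleware.
# Assumes a FastAPI/Starlette app; names and prefixes are made up for the example.
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

CACHEABLE_PATHS = ("/static/", "/_next/static/", "/health", "/docs")


class CacheControlMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        if any(request.url.path.startswith(p) for p in CACHEABLE_PATHS):
            # Static assets and other public paths may be cached.
            response.headers.setdefault("Cache-Control", "public, max-age=3600")
        else:
            # Everything else (auth tokens, API keys, user data) must not be cached.
            response.headers["Cache-Control"] = "no-store"
        return response
```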
@@ -199,9 +199,18 @@ ZEROBOUNCE_API_KEY=

## ===== OPTIONAL API KEYS END ===== ##

# Block Error Rate Monitoring
BLOCK_ERROR_RATE_THRESHOLD=0.5
BLOCK_ERROR_RATE_CHECK_INTERVAL_SECS=86400

# Logging Configuration
LOG_LEVEL=INFO
ENABLE_CLOUD_LOGGING=false
ENABLE_FILE_LOGGING=false
# Use to manually set the log directory
# LOG_DIR=./logs

# Example Blocks Configuration
# Set to true to enable example blocks in development
# These blocks are disabled by default in production
ENABLE_EXAMPLE_BLOCKS=false
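A minimal sketch of how the new settings above could be read in Python, assuming plain environment variables. The platform actually loads them through its own Config/Settings classes, so treat this only as an illustration of the value types and defaults shown in the diff.

```python
# Illustration only: parse the new .env entries with plain os.getenv.
# The platform's own Config/Settings classes are the real source of truth.
import os

block_error_rate_threshold = float(os.getenv("BLOCK_ERROR_RATE_THRESHOLD", "0.5"))
block_error_rate_check_interval_secs = int(
    os.getenv("BLOCK_ERROR_RATE_CHECK_INTERVAL_SECS", "86400")
)
enable_example_blocks = os.getenv("ENABLE_EXAMPLE_BLOCKS", "false").lower() == "true"

print(block_error_rate_threshold, block_error_rate_check_interval_secs, enable_example_blocks)
```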
autogpt_platform/backend/backend/TEST_DATA_README.md (new file, 150 lines)
@@ -0,0 +1,150 @@
# Test Data Scripts

This directory contains scripts for creating and updating test data in the AutoGPT Platform database, specifically designed to test the materialized views for the store functionality.

## Scripts

### test_data_creator.py
Creates a comprehensive set of test data including:
- Users with profiles
- Agent graphs, nodes, and executions
- Store listings with multiple versions
- Reviews and ratings
- Library agents
- Integration webhooks
- Onboarding data
- Credit transactions

**Image/Video Domains Used:**
- Images: `picsum.photos` (for all image URLs)
- Videos: `youtube.com` (for store listing videos)

### test_data_updater.py
Updates existing test data to simulate real-world changes:
- Adds new agent graph executions
- Creates new store listing reviews
- Updates store listing versions
- Adds credit transactions
- Refreshes materialized views

### check_db.py
Tests and verifies materialized views functionality:
- Checks pg_cron job status (for automatic refresh)
- Displays current materialized view counts
- Adds test data (executions and reviews)
- Creates store listings if none exist
- Manually refreshes materialized views
- Compares before/after counts to verify updates
- Provides a summary of test results

## Materialized Views

The scripts test three key database views:

1. **mv_agent_run_counts**: Tracks execution counts by agent
2. **mv_review_stats**: Tracks review statistics (count, average rating) by store listing
3. **StoreAgent**: A view that combines store listing data with execution counts and ratings for display

The materialized views (mv_agent_run_counts and mv_review_stats) are automatically refreshed every 15 minutes via pg_cron, or can be manually refreshed using the `refresh_store_materialized_views()` function.

## Usage

### Prerequisites

1. Ensure the database is running:
   ```bash
   docker compose up -d
   # or for test database:
   docker compose -f docker-compose.test.yaml --env-file ../.env up -d
   ```

2. Run database migrations:
   ```bash
   poetry run prisma migrate deploy
   ```

### Running the Scripts

#### Option 1: Use the helper script (from backend directory)
```bash
poetry run python run_test_data.py
```

#### Option 2: Run individually
```bash
# From backend/test directory:
# Create initial test data
poetry run python test_data_creator.py

# Update data to test materialized view changes
poetry run python test_data_updater.py

# From backend directory:
# Test materialized views functionality
poetry run python check_db.py

# Check store data status
poetry run python check_store_data.py
```

#### Option 3: Use the shell script (from backend directory)
```bash
./run_test_data_scripts.sh
```

### Manual Materialized View Refresh

To manually refresh the materialized views:
```sql
SELECT refresh_store_materialized_views();
```

## Configuration

The scripts use the database configuration from your `.env` file:
- `DATABASE_URL`: PostgreSQL connection string
- Database should have the platform schema

## Data Generation Limits

Configured in `test_data_creator.py`:
- 100 users
- 100 agent blocks
- 1-5 graphs per user
- 2-5 nodes per graph
- 1-5 presets per user
- 1-10 library agents per user
- 1-20 executions per graph
- 1-5 reviews per store listing version

## Notes

- All image URLs use `picsum.photos` for consistency with Next.js image configuration
- The scripts create realistic relationships between entities
- Materialized views are refreshed at the end of each script
- Data is designed to test both happy paths and edge cases

## Troubleshooting

### Reviews and StoreAgent view showing 0

If `check_db.py` shows that reviews remain at 0 and StoreAgent view shows 0 store agents:

1. **No store listings exist**: The script will automatically create test store listings if none exist
2. **No approved versions**: Store listings need approved versions to appear in the StoreAgent view
3. **Check with `check_store_data.py`**: This script provides detailed information about:
   - Total store listings
   - Store listing versions by status
   - Existing reviews
   - StoreAgent view contents
   - Agent graph executions

### pg_cron not installed

The warning "pg_cron extension is not installed" is normal in local development environments. The materialized views can still be refreshed manually using the `refresh_store_materialized_views()` function, which all scripts do automatically.

### Common Issues
- **Type errors with None values**: Fixed in the latest version of check_db.py by using `or 0` for nullable numeric fields (see the short sketch after this list)
- **Missing relations**: Ensure you're using the correct field names (e.g., `StoreListing` not `storeListing` in includes)
- **Column name mismatches**: The database uses camelCase for column names (e.g., `agentGraphId` not `agent_graph_id`)
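To make the troubleshooting notes above concrete, here is a small sketch of a manual refresh-and-recount check. It connects with asyncpg straight to `DATABASE_URL` and uses the view and function names from this README; it is illustrative only and not one of the scripts shipped in this change (the real scripts use the platform's own database client, and a `DATABASE_URL` with extra query parameters may need trimming for asyncpg).

```python
# Illustrative only: refresh the store materialized views and re-read a count,
# guarding nullable values with `or 0` as described in Common Issues above.
import asyncio
import os

import asyncpg


async def main() -> None:
    conn = await asyncpg.connect(os.environ["DATABASE_URL"])
    try:
        # `or 0` keeps a NULL count from breaking numeric comparisons.
        before = await conn.fetchval("SELECT COUNT(*) FROM mv_agent_run_counts") or 0
        await conn.execute("SELECT refresh_store_materialized_views()")
        after = await conn.fetchval("SELECT COUNT(*) FROM mv_agent_run_counts") or 0
        print(f"mv_agent_run_counts rows: {before} -> {after}")
    finally:
        await conn.close()


if __name__ == "__main__":
    asyncio.run(main())
```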
@@ -14,14 +14,27 @@ T = TypeVar("T")
|
||||
@functools.cache
|
||||
def load_all_blocks() -> dict[str, type["Block"]]:
|
||||
from backend.data.block import Block
|
||||
from backend.util.settings import Config
|
||||
|
||||
# Check if example blocks should be loaded from settings
|
||||
config = Config()
|
||||
load_examples = config.enable_example_blocks
|
||||
|
||||
# Dynamically load all modules under backend.blocks
|
||||
current_dir = Path(__file__).parent
|
||||
modules = [
|
||||
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
|
||||
for f in current_dir.rglob("*.py")
|
||||
if f.is_file() and f.name != "__init__.py" and not f.name.startswith("test_")
|
||||
]
|
||||
modules = []
|
||||
for f in current_dir.rglob("*.py"):
|
||||
if not f.is_file() or f.name == "__init__.py" or f.name.startswith("test_"):
|
||||
continue
|
||||
|
||||
# Skip examples directory if not enabled
|
||||
relative_path = f.relative_to(current_dir)
|
||||
if not load_examples and relative_path.parts[0] == "examples":
|
||||
continue
|
||||
|
||||
module_path = str(relative_path)[:-3].replace(os.path.sep, ".")
|
||||
modules.append(module_path)
|
||||
|
||||
for module in modules:
|
||||
if not re.match("^[a-z0-9_.]+$", module):
|
||||
raise ValueError(
|
||||
|
||||
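The loader change above boils down to a module-discovery rule: turn every `.py` file under the blocks package into a dotted module path, skipping test modules and, unless example blocks are enabled, the `examples/` subdirectory. The following standalone sketch restates that rule outside the diff; the function name is made up and the platform's `Config` is replaced by a plain boolean parameter.

```python
# Minimal sketch of the module-discovery rule in the diff above.
# Standalone illustration; the real loader lives in backend/blocks/__init__.py.
import os
from pathlib import Path


def discover_modules(package_dir: Path, load_examples: bool) -> list[str]:
    modules: list[str] = []
    for f in package_dir.rglob("*.py"):
        if not f.is_file() or f.name == "__init__.py" or f.name.startswith("test_"):
            continue
        relative = f.relative_to(package_dir)
        if not load_examples and relative.parts[0] == "examples":
            continue  # example blocks stay disabled unless explicitly enabled
        modules.append(str(relative)[:-3].replace(os.path.sep, "."))
    return modules
```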
@@ -14,7 +14,7 @@ from backend.data.block import (
    get_block,
)
from backend.data.execution import ExecutionStatus
from backend.data.model import SchemaField
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util import json, retry

_logger = logging.getLogger(__name__)

@@ -151,6 +151,12 @@ class AgentExecutorBlock(Block):
        if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
            # If the graph execution is COMPLETED, TERMINATED, or FAILED,
            # we can stop listening for further events.
            self.merge_stats(
                NodeExecutionStats(
                    extra_cost=event.stats.cost if event.stats else 0,
                    extra_steps=event.stats.node_exec_count if event.stats else 0,
                )
            )
            break

        logger.debug(
@@ -1,32 +0,0 @@
from typing import Literal

from pydantic import SecretStr

from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName

ExaCredentials = APIKeyCredentials
ExaCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.EXA],
    Literal["api_key"],
]

TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="exa",
    api_key=SecretStr("mock-exa-api-key"),
    title="Mock Exa API key",
    expires_at=None,
)

TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.title,
}


def ExaCredentialsField() -> ExaCredentialsInput:
    """Creates an Exa credentials input on a block."""
    return CredentialsField(description="The Exa integration requires an API Key.")
autogpt_platform/backend/backend/blocks/exa/_config.py (new file, 16 lines)
@@ -0,0 +1,16 @@
"""
Shared configuration for all Exa blocks using the new SDK pattern.
"""

from backend.sdk import BlockCostType, ProviderBuilder

from ._webhook import ExaWebhookManager

# Configure the Exa provider once for all blocks
exa = (
    ProviderBuilder("exa")
    .with_api_key("EXA_API_KEY", "Exa API Key")
    .with_webhook_manager(ExaWebhookManager)
    .with_base_cost(1, BlockCostType.RUN)
    .build()
)
autogpt_platform/backend/backend/blocks/exa/_webhook.py (new file, 134 lines)
@@ -0,0 +1,134 @@
"""
Exa Webhook Manager implementation.
"""

import hashlib
import hmac
from enum import Enum

from backend.data.model import Credentials
from backend.sdk import (
    APIKeyCredentials,
    BaseWebhooksManager,
    ProviderName,
    Requests,
    Webhook,
)


class ExaWebhookType(str, Enum):
    """Available webhook types for Exa."""

    WEBSET = "webset"


class ExaEventType(str, Enum):
    """Available event types for Exa webhooks."""

    WEBSET_CREATED = "webset.created"
    WEBSET_DELETED = "webset.deleted"
    WEBSET_PAUSED = "webset.paused"
    WEBSET_IDLE = "webset.idle"
    WEBSET_SEARCH_CREATED = "webset.search.created"
    WEBSET_SEARCH_CANCELED = "webset.search.canceled"
    WEBSET_SEARCH_COMPLETED = "webset.search.completed"
    WEBSET_SEARCH_UPDATED = "webset.search.updated"
    IMPORT_CREATED = "import.created"
    IMPORT_COMPLETED = "import.completed"
    IMPORT_PROCESSING = "import.processing"
    WEBSET_ITEM_CREATED = "webset.item.created"
    WEBSET_ITEM_ENRICHED = "webset.item.enriched"
    WEBSET_EXPORT_CREATED = "webset.export.created"
    WEBSET_EXPORT_COMPLETED = "webset.export.completed"


class ExaWebhookManager(BaseWebhooksManager):
    """Webhook manager for Exa API."""

    PROVIDER_NAME = ProviderName("exa")

    class WebhookType(str, Enum):
        WEBSET = "webset"

    @classmethod
    async def validate_payload(cls, webhook: Webhook, request) -> tuple[dict, str]:
        """Validate incoming webhook payload and signature."""
        payload = await request.json()

        # Get event type from payload
        event_type = payload.get("eventType", "unknown")

        # Verify webhook signature if secret is available
        if webhook.secret:
            signature = request.headers.get("X-Exa-Signature")
            if signature:
                # Compute expected signature
                body = await request.body()
                expected_signature = hmac.new(
                    webhook.secret.encode(), body, hashlib.sha256
                ).hexdigest()

                # Compare signatures
                if not hmac.compare_digest(signature, expected_signature):
                    raise ValueError("Invalid webhook signature")

        return payload, event_type

    async def _register_webhook(
        self,
        credentials: Credentials,
        webhook_type: str,
        resource: str,
        events: list[str],
        ingress_url: str,
        secret: str,
    ) -> tuple[str, dict]:
        """Register webhook with Exa API."""
        if not isinstance(credentials, APIKeyCredentials):
            raise ValueError("Exa webhooks require API key credentials")
        api_key = credentials.api_key.get_secret_value()

        # Create webhook via Exa API
        response = await Requests().post(
            "https://api.exa.ai/v0/webhooks",
            headers={"x-api-key": api_key},
            json={
                "url": ingress_url,
                "events": events,
                "metadata": {
                    "resource": resource,
                    "webhook_type": webhook_type,
                },
            },
        )

        if not response.ok:
            error_data = response.json()
            raise Exception(f"Failed to create Exa webhook: {error_data}")

        webhook_data = response.json()

        # Store the secret returned by Exa
        return webhook_data["id"], {
            "events": events,
            "resource": resource,
            "exa_secret": webhook_data.get("secret"),
        }

    async def _deregister_webhook(
        self, webhook: Webhook, credentials: Credentials
    ) -> None:
        """Deregister webhook from Exa API."""
        if not isinstance(credentials, APIKeyCredentials):
            raise ValueError("Exa webhooks require API key credentials")
        api_key = credentials.api_key.get_secret_value()

        # Delete webhook via Exa API
        response = await Requests().delete(
            f"https://api.exa.ai/v0/webhooks/{webhook.provider_webhook_id}",
            headers={"x-api-key": api_key},
        )

        if not response.ok and response.status != 404:
            error_data = response.json()
            raise Exception(f"Failed to delete Exa webhook: {error_data}")
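The signature check in `validate_payload` above is a plain HMAC-SHA256 of the raw request body. The helper below shows how a caller (for example, a local test) could compute a matching `X-Exa-Signature` value for a given secret; it reuses only what the code above already establishes (header name, SHA-256, hex digest), and the sample payload and secret are made up.

```python
# Illustrative helper for exercising the HMAC check in ExaWebhookManager.validate_payload.
import hashlib
import hmac


def sign_exa_payload(secret: str, body: bytes) -> str:
    """Return the hex digest expected in the X-Exa-Signature header."""
    return hmac.new(secret.encode(), body, hashlib.sha256).hexdigest()


if __name__ == "__main__":
    body = b'{"eventType": "webset.created"}'
    print(sign_exa_payload("test-secret", body))
```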
autogpt_platform/backend/backend/blocks/exa/answers.py (new file, 124 lines)
@@ -0,0 +1,124 @@
from backend.sdk import (
    APIKeyCredentials,
    BaseModel,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    CredentialsMetaInput,
    Requests,
    SchemaField,
)

from ._config import exa


class CostBreakdown(BaseModel):
    keywordSearch: float
    neuralSearch: float
    contentText: float
    contentHighlight: float
    contentSummary: float


class SearchBreakdown(BaseModel):
    search: float
    contents: float
    breakdown: CostBreakdown


class PerRequestPrices(BaseModel):
    neuralSearch_1_25_results: float
    neuralSearch_26_100_results: float
    neuralSearch_100_plus_results: float
    keywordSearch_1_100_results: float
    keywordSearch_100_plus_results: float


class PerPagePrices(BaseModel):
    contentText: float
    contentHighlight: float
    contentSummary: float


class CostDollars(BaseModel):
    total: float
    breakDown: list[SearchBreakdown]
    perRequestPrices: PerRequestPrices
    perPagePrices: PerPagePrices


class ExaAnswerBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        query: str = SchemaField(
            description="The question or query to answer",
            placeholder="What is the latest valuation of SpaceX?",
        )
        text: bool = SchemaField(
            default=False,
            description="If true, the response includes full text content in the search results",
            advanced=True,
        )
        model: str = SchemaField(
            default="exa",
            description="The search model to use (exa or exa-pro)",
            placeholder="exa",
            advanced=True,
        )

    class Output(BlockSchema):
        answer: str = SchemaField(
            description="The generated answer based on search results"
        )
        citations: list[dict] = SchemaField(
            description="Search results used to generate the answer",
            default_factory=list,
        )
        cost_dollars: CostDollars = SchemaField(
            description="Cost breakdown of the request"
        )
        error: str = SchemaField(
            description="Error message if the request failed", default=""
        )

    def __init__(self):
        super().__init__(
            id="b79ca4cc-9d5e-47d1-9d4f-e3a2d7f28df5",
            description="Get an LLM answer to a question informed by Exa search results",
            categories={BlockCategory.SEARCH, BlockCategory.AI},
            input_schema=ExaAnswerBlock.Input,
            output_schema=ExaAnswerBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://api.exa.ai/answer"
        headers = {
            "Content-Type": "application/json",
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        # Build the payload
        payload = {
            "query": input_data.query,
            "text": input_data.text,
            "model": input_data.model,
        }

        try:
            response = await Requests().post(url, headers=headers, json=payload)
            data = response.json()

            yield "answer", data.get("answer", "")
            yield "citations", data.get("citations", [])
            yield "cost_dollars", data.get("costDollars", {})

        except Exception as e:
            yield "error", str(e)
            yield "answer", ""
            yield "citations", []
            yield "cost_dollars", {}
@@ -1,57 +1,39 @@
from typing import List

from pydantic import BaseModel

from backend.blocks.exa._auth import (
    ExaCredentials,
    ExaCredentialsField,
    ExaCredentialsInput,
from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    CredentialsMetaInput,
    Requests,
    SchemaField,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import Requests


class ContentRetrievalSettings(BaseModel):
    text: dict = SchemaField(
        description="Text content settings",
        default={"maxCharacters": 1000, "includeHtmlTags": False},
        advanced=True,
    )
    highlights: dict = SchemaField(
        description="Highlight settings",
        default={
            "numSentences": 3,
            "highlightsPerUrl": 3,
            "query": "",
        },
        advanced=True,
    )
    summary: dict = SchemaField(
        description="Summary settings",
        default={"query": ""},
        advanced=True,
    )
from ._config import exa
from .helpers import ContentSettings


class ExaContentsBlock(Block):
    class Input(BlockSchema):
        credentials: ExaCredentialsInput = ExaCredentialsField()
        ids: List[str] = SchemaField(
            description="Array of document IDs obtained from searches",
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        contents: ContentRetrievalSettings = SchemaField(
        ids: list[str] = SchemaField(
            description="Array of document IDs obtained from searches"
        )
        contents: ContentSettings = SchemaField(
            description="Content retrieval settings",
            default=ContentRetrievalSettings(),
            default=ContentSettings(),
            advanced=True,
        )

    class Output(BlockSchema):
        results: list = SchemaField(
            description="List of document contents",
            default_factory=list,
            description="List of document contents", default_factory=list
        )
        error: str = SchemaField(
            description="Error message if the request failed", default=""
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
@@ -63,7 +45,7 @@ class ExaContentsBlock(Block):
        )

    async def run(
        self, input_data: Input, *, credentials: ExaCredentials, **kwargs
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://api.exa.ai/contents"
        headers = {
@@ -71,6 +53,7 @@ class ExaContentsBlock(Block):
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        # Convert ContentSettings to API format
        payload = {
            "ids": input_data.ids,
            "text": input_data.contents.text,
@@ -1,8 +1,6 @@
from typing import Optional

from pydantic import BaseModel

from backend.data.model import SchemaField
from backend.sdk import BaseModel, SchemaField


class TextSettings(BaseModel):
@@ -42,13 +40,90 @@ class SummarySettings(BaseModel):
class ContentSettings(BaseModel):
    text: TextSettings = SchemaField(
        default=TextSettings(),
        description="Text content settings",
    )
    highlights: HighlightSettings = SchemaField(
        default=HighlightSettings(),
        description="Highlight settings",
    )
    summary: SummarySettings = SchemaField(
        default=SummarySettings(),
        description="Summary settings",
    )


# Websets Models
class WebsetEntitySettings(BaseModel):
    type: Optional[str] = SchemaField(
        default=None,
        description="Entity type (e.g., 'company', 'person')",
        placeholder="company",
    )


class WebsetCriterion(BaseModel):
    description: str = SchemaField(
        description="Description of the criterion",
        placeholder="Must be based in the US",
    )
    success_rate: Optional[int] = SchemaField(
        default=None,
        description="Success rate percentage",
        ge=0,
        le=100,
    )


class WebsetSearchConfig(BaseModel):
    query: str = SchemaField(
        description="Search query",
        placeholder="Marketing agencies based in the US",
    )
    count: int = SchemaField(
        default=10,
        description="Number of results to return",
        ge=1,
        le=100,
    )
    entity: Optional[WebsetEntitySettings] = SchemaField(
        default=None,
        description="Entity settings for the search",
    )
    criteria: Optional[list[WebsetCriterion]] = SchemaField(
        default=None,
        description="Search criteria",
    )
    behavior: Optional[str] = SchemaField(
        default="override",
        description="Behavior when updating results ('override' or 'append')",
        placeholder="override",
    )


class EnrichmentOption(BaseModel):
    label: str = SchemaField(
        description="Label for the enrichment option",
        placeholder="Option 1",
    )


class WebsetEnrichmentConfig(BaseModel):
    title: str = SchemaField(
        description="Title of the enrichment",
        placeholder="Company Details",
    )
    description: str = SchemaField(
        description="Description of what this enrichment does",
        placeholder="Extract company information",
    )
    format: str = SchemaField(
        default="text",
        description="Format of the enrichment result",
        placeholder="text",
    )
    instructions: Optional[str] = SchemaField(
        default=None,
        description="Instructions for the enrichment",
        placeholder="Extract key company metrics",
    )
    options: Optional[list[EnrichmentOption]] = SchemaField(
        default=None,
        description="Options for the enrichment",
    )
@@ -1,71 +1,61 @@
from datetime import datetime
from typing import List

from backend.blocks.exa._auth import (
    ExaCredentials,
    ExaCredentialsField,
    ExaCredentialsInput,
from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    CredentialsMetaInput,
    Requests,
    SchemaField,
)
from backend.blocks.exa.helpers import ContentSettings
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import Requests

from ._config import exa
from .helpers import ContentSettings


class ExaSearchBlock(Block):
    class Input(BlockSchema):
        credentials: ExaCredentialsInput = ExaCredentialsField()
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        query: str = SchemaField(description="The search query")
        use_auto_prompt: bool = SchemaField(
            description="Whether to use autoprompt",
            default=True,
            advanced=True,
        )
        type: str = SchemaField(
            description="Type of search",
            default="",
            advanced=True,
            description="Whether to use autoprompt", default=True, advanced=True
        )
        type: str = SchemaField(description="Type of search", default="", advanced=True)
        category: str = SchemaField(
            description="Category to search within",
            default="",
            advanced=True,
            description="Category to search within", default="", advanced=True
        )
        number_of_results: int = SchemaField(
            description="Number of results to return",
            default=10,
            advanced=True,
            description="Number of results to return", default=10, advanced=True
        )
        include_domains: List[str] = SchemaField(
            description="Domains to include in search",
            default_factory=list,
        include_domains: list[str] = SchemaField(
            description="Domains to include in search", default_factory=list
        )
        exclude_domains: List[str] = SchemaField(
        exclude_domains: list[str] = SchemaField(
            description="Domains to exclude from search",
            default_factory=list,
            advanced=True,
        )
        start_crawl_date: datetime = SchemaField(
            description="Start date for crawled content",
            description="Start date for crawled content"
        )
        end_crawl_date: datetime = SchemaField(
            description="End date for crawled content",
            description="End date for crawled content"
        )
        start_published_date: datetime = SchemaField(
            description="Start date for published content",
            description="Start date for published content"
        )
        end_published_date: datetime = SchemaField(
            description="End date for published content",
            description="End date for published content"
        )
        include_text: List[str] = SchemaField(
            description="Text patterns to include",
            default_factory=list,
            advanced=True,
        include_text: list[str] = SchemaField(
            description="Text patterns to include", default_factory=list, advanced=True
        )
        exclude_text: List[str] = SchemaField(
            description="Text patterns to exclude",
            default_factory=list,
            advanced=True,
        exclude_text: list[str] = SchemaField(
            description="Text patterns to exclude", default_factory=list, advanced=True
        )
        contents: ContentSettings = SchemaField(
            description="Content retrieval settings",
@@ -75,8 +65,7 @@ class ExaSearchBlock(Block):

    class Output(BlockSchema):
        results: list = SchemaField(
            description="List of search results",
            default_factory=list,
            description="List of search results", default_factory=list
        )
        error: str = SchemaField(
            description="Error message if the request failed",
@@ -92,7 +81,7 @@ class ExaSearchBlock(Block):
        )

    async def run(
        self, input_data: Input, *, credentials: ExaCredentials, **kwargs
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://api.exa.ai/search"
        headers = {
@@ -104,7 +93,7 @@ class ExaSearchBlock(Block):
            "query": input_data.query,
            "useAutoprompt": input_data.use_auto_prompt,
            "numResults": input_data.number_of_results,
            "contents": input_data.contents.dict(),
            "contents": input_data.contents.model_dump(),
        }

        date_field_mapping = {
@@ -1,57 +1,60 @@
from datetime import datetime
from typing import Any, List
from typing import Any

from backend.blocks.exa._auth import (
    ExaCredentials,
    ExaCredentialsField,
    ExaCredentialsInput,
from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    CredentialsMetaInput,
    Requests,
    SchemaField,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import Requests

from ._config import exa
from .helpers import ContentSettings


class ExaFindSimilarBlock(Block):
    class Input(BlockSchema):
        credentials: ExaCredentialsInput = ExaCredentialsField()
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        url: str = SchemaField(
            description="The url for which you would like to find similar links"
        )
        number_of_results: int = SchemaField(
            description="Number of results to return",
            default=10,
            advanced=True,
            description="Number of results to return", default=10, advanced=True
        )
        include_domains: List[str] = SchemaField(
        include_domains: list[str] = SchemaField(
            description="Domains to include in search",
            default_factory=list,
            advanced=True,
        )
        exclude_domains: List[str] = SchemaField(
        exclude_domains: list[str] = SchemaField(
            description="Domains to exclude from search",
            default_factory=list,
            advanced=True,
        )
        start_crawl_date: datetime = SchemaField(
            description="Start date for crawled content",
            description="Start date for crawled content"
        )
        end_crawl_date: datetime = SchemaField(
            description="End date for crawled content",
            description="End date for crawled content"
        )
        start_published_date: datetime = SchemaField(
            description="Start date for published content",
            description="Start date for published content"
        )
        end_published_date: datetime = SchemaField(
            description="End date for published content",
            description="End date for published content"
        )
        include_text: List[str] = SchemaField(
        include_text: list[str] = SchemaField(
            description="Text patterns to include (max 1 string, up to 5 words)",
            default_factory=list,
            advanced=True,
        )
        exclude_text: List[str] = SchemaField(
        exclude_text: list[str] = SchemaField(
            description="Text patterns to exclude (max 1 string, up to 5 words)",
            default_factory=list,
            advanced=True,
@@ -63,11 +66,13 @@ class ExaFindSimilarBlock(Block):
        )

    class Output(BlockSchema):
        results: List[Any] = SchemaField(
        results: list[Any] = SchemaField(
            description="List of similar documents with title, URL, published date, author, and score",
            default_factory=list,
        )
        error: str = SchemaField(description="Error message if the request failed")
        error: str = SchemaField(
            description="Error message if the request failed", default=""
        )

    def __init__(self):
        super().__init__(
@@ -79,7 +84,7 @@ class ExaFindSimilarBlock(Block):
        )

    async def run(
        self, input_data: Input, *, credentials: ExaCredentials, **kwargs
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://api.exa.ai/findSimilar"
        headers = {
@@ -90,7 +95,7 @@ class ExaFindSimilarBlock(Block):
        payload = {
            "url": input_data.url,
            "numResults": input_data.number_of_results,
            "contents": input_data.contents.dict(),
            "contents": input_data.contents.model_dump(),
        }

        optional_field_mapping = {
autogpt_platform/backend/backend/blocks/exa/webhook_blocks.py (new file, 201 lines)
@@ -0,0 +1,201 @@
"""
Exa Webhook Blocks

These blocks handle webhook events from Exa's API for websets and other events.
"""

from backend.sdk import (
    BaseModel,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    BlockType,
    BlockWebhookConfig,
    CredentialsMetaInput,
    Field,
    ProviderName,
    SchemaField,
)

from ._config import exa
from ._webhook import ExaEventType


class WebsetEventFilter(BaseModel):
    """Filter configuration for Exa webset events."""

    webset_created: bool = Field(
        default=True, description="Receive notifications when websets are created"
    )
    webset_deleted: bool = Field(
        default=False, description="Receive notifications when websets are deleted"
    )
    webset_paused: bool = Field(
        default=False, description="Receive notifications when websets are paused"
    )
    webset_idle: bool = Field(
        default=False, description="Receive notifications when websets become idle"
    )
    search_created: bool = Field(
        default=True,
        description="Receive notifications when webset searches are created",
    )
    search_completed: bool = Field(
        default=True, description="Receive notifications when webset searches complete"
    )
    search_canceled: bool = Field(
        default=False,
        description="Receive notifications when webset searches are canceled",
    )
    search_updated: bool = Field(
        default=False,
        description="Receive notifications when webset searches are updated",
    )
    item_created: bool = Field(
        default=True, description="Receive notifications when webset items are created"
    )
    item_enriched: bool = Field(
        default=True, description="Receive notifications when webset items are enriched"
    )
    export_created: bool = Field(
        default=False,
        description="Receive notifications when webset exports are created",
    )
    export_completed: bool = Field(
        default=True, description="Receive notifications when webset exports complete"
    )
    import_created: bool = Field(
        default=False, description="Receive notifications when imports are created"
    )
    import_completed: bool = Field(
        default=True, description="Receive notifications when imports complete"
    )
    import_processing: bool = Field(
        default=False, description="Receive notifications when imports are processing"
    )


class ExaWebsetWebhookBlock(Block):
    """
    Receives webhook notifications for Exa webset events.

    This block allows you to monitor various events related to Exa websets,
    including creation, updates, searches, and exports.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="Exa API credentials for webhook management"
        )
        webhook_url: str = SchemaField(
            description="URL to receive webhooks (auto-generated)",
            default="",
            hidden=True,
        )
        webset_id: str = SchemaField(
            description="The webset ID to monitor (optional, monitors all if empty)",
            default="",
        )
        event_filter: WebsetEventFilter = SchemaField(
            description="Configure which events to receive", default=WebsetEventFilter()
        )
        payload: dict = SchemaField(
            description="Webhook payload data", default={}, hidden=True
        )

    class Output(BlockSchema):
        event_type: str = SchemaField(description="Type of event that occurred")
        event_id: str = SchemaField(description="Unique identifier for this event")
        webset_id: str = SchemaField(description="ID of the affected webset")
        data: dict = SchemaField(description="Event-specific data")
        timestamp: str = SchemaField(description="When the event occurred")
        metadata: dict = SchemaField(description="Additional event metadata")

    def __init__(self):
        super().__init__(
            id="d0204ed8-8b81-408d-8b8d-ed087a546228",
            description="Receive webhook notifications for Exa webset events",
            categories={BlockCategory.INPUT},
            input_schema=ExaWebsetWebhookBlock.Input,
            output_schema=ExaWebsetWebhookBlock.Output,
            block_type=BlockType.WEBHOOK,
            webhook_config=BlockWebhookConfig(
                provider=ProviderName("exa"),
                webhook_type="webset",
                event_filter_input="event_filter",
                resource_format="{webset_id}",
            ),
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        """Process incoming Exa webhook payload."""
        try:
            payload = input_data.payload

            # Extract event details
            event_type = payload.get("eventType", "unknown")
            event_id = payload.get("eventId", "")

            # Get webset ID from payload or input
            webset_id = payload.get("websetId", input_data.webset_id)

            # Check if we should process this event based on filter
            should_process = self._should_process_event(
                event_type, input_data.event_filter
            )

            if not should_process:
                # Skip events that don't match our filter
                return

            # Extract event data
            event_data = payload.get("data", {})
            timestamp = payload.get("occurredAt", payload.get("createdAt", ""))
            metadata = payload.get("metadata", {})

            yield "event_type", event_type
            yield "event_id", event_id
            yield "webset_id", webset_id
            yield "data", event_data
            yield "timestamp", timestamp
            yield "metadata", metadata

        except Exception as e:
            # Handle errors gracefully
            yield "event_type", "error"
            yield "event_id", ""
            yield "webset_id", input_data.webset_id
            yield "data", {"error": str(e)}
            yield "timestamp", ""
            yield "metadata", {}

    def _should_process_event(
        self, event_type: str, event_filter: WebsetEventFilter
    ) -> bool:
        """Check if an event should be processed based on the filter."""
        filter_mapping = {
            ExaEventType.WEBSET_CREATED: event_filter.webset_created,
            ExaEventType.WEBSET_DELETED: event_filter.webset_deleted,
            ExaEventType.WEBSET_PAUSED: event_filter.webset_paused,
            ExaEventType.WEBSET_IDLE: event_filter.webset_idle,
            ExaEventType.WEBSET_SEARCH_CREATED: event_filter.search_created,
            ExaEventType.WEBSET_SEARCH_COMPLETED: event_filter.search_completed,
            ExaEventType.WEBSET_SEARCH_CANCELED: event_filter.search_canceled,
            ExaEventType.WEBSET_SEARCH_UPDATED: event_filter.search_updated,
            ExaEventType.WEBSET_ITEM_CREATED: event_filter.item_created,
            ExaEventType.WEBSET_ITEM_ENRICHED: event_filter.item_enriched,
            ExaEventType.WEBSET_EXPORT_CREATED: event_filter.export_created,
            ExaEventType.WEBSET_EXPORT_COMPLETED: event_filter.export_completed,
            ExaEventType.IMPORT_CREATED: event_filter.import_created,
            ExaEventType.IMPORT_COMPLETED: event_filter.import_completed,
            ExaEventType.IMPORT_PROCESSING: event_filter.import_processing,
        }

        # Try to convert string to ExaEventType enum
        try:
            event_type_enum = ExaEventType(event_type)
            return filter_mapping.get(event_type_enum, True)
        except ValueError:
            # If event_type is not a valid enum value, process it by default
            return True
autogpt_platform/backend/backend/blocks/exa/websets.py (new file, 456 lines)
@@ -0,0 +1,456 @@
from typing import Any, Optional

from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    CredentialsMetaInput,
    Requests,
    SchemaField,
)

from ._config import exa
from .helpers import WebsetEnrichmentConfig, WebsetSearchConfig


class ExaCreateWebsetBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        search: WebsetSearchConfig = SchemaField(
            description="Initial search configuration for the Webset"
        )
        enrichments: Optional[list[WebsetEnrichmentConfig]] = SchemaField(
            default=None,
            description="Enrichments to apply to Webset items",
            advanced=True,
        )
        external_id: Optional[str] = SchemaField(
            default=None,
            description="External identifier for the webset",
            placeholder="my-webset-123",
            advanced=True,
        )
        metadata: Optional[dict] = SchemaField(
            default=None,
            description="Key-value pairs to associate with this webset",
            advanced=True,
        )

    class Output(BlockSchema):
        webset_id: str = SchemaField(
            description="The unique identifier for the created webset"
        )
        status: str = SchemaField(description="The status of the webset")
        external_id: Optional[str] = SchemaField(
            description="The external identifier for the webset", default=None
        )
        created_at: str = SchemaField(
            description="The date and time the webset was created"
        )
        error: str = SchemaField(
            description="Error message if the request failed", default=""
        )

    def __init__(self):
        super().__init__(
            id="0cda29ff-c549-4a19-8805-c982b7d4ec34",
            description="Create a new Exa Webset for persistent web search collections",
            categories={BlockCategory.SEARCH},
            input_schema=ExaCreateWebsetBlock.Input,
            output_schema=ExaCreateWebsetBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://api.exa.ai/websets/v0/websets"
        headers = {
            "Content-Type": "application/json",
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        # Build the payload
        payload: dict[str, Any] = {
            "search": input_data.search.model_dump(exclude_none=True),
        }

        # Convert enrichments to API format
        if input_data.enrichments:
            enrichments_data = []
            for enrichment in input_data.enrichments:
                enrichments_data.append(enrichment.model_dump(exclude_none=True))
            payload["enrichments"] = enrichments_data

        if input_data.external_id:
            payload["externalId"] = input_data.external_id

        if input_data.metadata:
            payload["metadata"] = input_data.metadata

        try:
            response = await Requests().post(url, headers=headers, json=payload)
            data = response.json()

            yield "webset_id", data.get("id", "")
            yield "status", data.get("status", "")
            yield "external_id", data.get("externalId")
            yield "created_at", data.get("createdAt", "")

        except Exception as e:
            yield "error", str(e)
            yield "webset_id", ""
            yield "status", ""
            yield "created_at", ""


class ExaUpdateWebsetBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset to update",
            placeholder="webset-id-or-external-id",
        )
        metadata: Optional[dict] = SchemaField(
            default=None,
            description="Key-value pairs to associate with this webset (set to null to clear)",
        )

    class Output(BlockSchema):
        webset_id: str = SchemaField(description="The unique identifier for the webset")
        status: str = SchemaField(description="The status of the webset")
        external_id: Optional[str] = SchemaField(
            description="The external identifier for the webset", default=None
        )
        metadata: dict = SchemaField(
            description="Updated metadata for the webset", default_factory=dict
        )
        updated_at: str = SchemaField(
            description="The date and time the webset was updated"
        )
        error: str = SchemaField(
            description="Error message if the request failed", default=""
        )

    def __init__(self):
        super().__init__(
            id="89ccd99a-3c2b-4fbf-9e25-0ffa398d0314",
            description="Update metadata for an existing Webset",
            categories={BlockCategory.SEARCH},
            input_schema=ExaUpdateWebsetBlock.Input,
            output_schema=ExaUpdateWebsetBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}"
        headers = {
            "Content-Type": "application/json",
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        # Build the payload
        payload = {}
        if input_data.metadata is not None:
            payload["metadata"] = input_data.metadata

        try:
            response = await Requests().post(url, headers=headers, json=payload)
            data = response.json()

            yield "webset_id", data.get("id", "")
            yield "status", data.get("status", "")
            yield "external_id", data.get("externalId")
            yield "metadata", data.get("metadata", {})
            yield "updated_at", data.get("updatedAt", "")

        except Exception as e:
            yield "error", str(e)
            yield "webset_id", ""
            yield "status", ""
            yield "metadata", {}
            yield "updated_at", ""


class ExaListWebsetsBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        cursor: Optional[str] = SchemaField(
            default=None,
            description="Cursor for pagination through results",
            advanced=True,
        )
        limit: int = SchemaField(
            default=25,
            description="Number of websets to return (1-100)",
            ge=1,
            le=100,
            advanced=True,
        )

    class Output(BlockSchema):
        websets: list = SchemaField(description="List of websets", default_factory=list)
        has_more: bool = SchemaField(
            description="Whether there are more results to paginate through",
            default=False,
        )
        next_cursor: Optional[str] = SchemaField(
            description="Cursor for the next page of results", default=None
        )
        error: str = SchemaField(
            description="Error message if the request failed", default=""
        )

    def __init__(self):
        super().__init__(
            id="1dcd8fd6-c13f-4e6f-bd4c-654428fa4757",
            description="List all Websets with pagination support",
            categories={BlockCategory.SEARCH},
            input_schema=ExaListWebsetsBlock.Input,
            output_schema=ExaListWebsetsBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://api.exa.ai/websets/v0/websets"
        headers = {
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        params: dict[str, Any] = {
            "limit": input_data.limit,
        }
        if input_data.cursor:
            params["cursor"] = input_data.cursor

        try:
            response = await Requests().get(url, headers=headers, params=params)
            data = response.json()

            yield "websets", data.get("data", [])
            yield "has_more", data.get("hasMore", False)
            yield "next_cursor", data.get("nextCursor")

        except Exception as e:
            yield "error", str(e)
            yield "websets", []
            yield "has_more", False


class ExaGetWebsetBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset to retrieve",
            placeholder="webset-id-or-external-id",
        )
        expand_items: bool = SchemaField(
            default=False, description="Include items in the response", advanced=True
        )

    class Output(BlockSchema):
        webset_id: str = SchemaField(description="The unique identifier for the webset")
        status: str = SchemaField(description="The status of the webset")
        external_id: Optional[str] = SchemaField(
            description="The external identifier for the webset", default=None
        )
        searches: list[dict] = SchemaField(
            description="The searches performed on the webset", default_factory=list
        )
        enrichments: list[dict] = SchemaField(
            description="The enrichments applied to the webset", default_factory=list
        )
        monitors: list[dict] = SchemaField(
            description="The monitors for the webset", default_factory=list
        )
        items: Optional[list[dict]] = SchemaField(
            description="The items in the webset (if expand_items is true)",
            default=None,
        )
        metadata: dict = SchemaField(
            description="Key-value pairs associated with the webset",
            default_factory=dict,
        )
        created_at: str = SchemaField(
            description="The date and time the webset was created"
        )
        updated_at: str = SchemaField(
            description="The date and time the webset was last updated"
        )
        error: str = SchemaField(
            description="Error message if the request failed", default=""
        )

    def __init__(self):
        super().__init__(
            id="6ab8e12a-132c-41bf-b5f3-d662620fa832",
            description="Retrieve a Webset by ID or external ID",
            categories={BlockCategory.SEARCH},
            input_schema=ExaGetWebsetBlock.Input,
            output_schema=ExaGetWebsetBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}"
        headers = {
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        params = {}
        if input_data.expand_items:
            params["expand[]"] = "items"

        try:
            response = await Requests().get(url, headers=headers, params=params)
            data = response.json()

            yield "webset_id", data.get("id", "")
            yield "status", data.get("status", "")
            yield "external_id", data.get("externalId")
            yield "searches", data.get("searches", [])
            yield "enrichments", data.get("enrichments", [])
            yield "monitors", data.get("monitors", [])
            yield "items", data.get("items")
            yield "metadata", data.get("metadata", {})
            yield "created_at", data.get("createdAt", "")
            yield "updated_at", data.get("updatedAt", "")

        except Exception as e:
            yield "error", str(e)
            yield "webset_id", ""
            yield "status", ""
            yield "searches", []
            yield "enrichments", []
            yield "monitors", []
            yield "metadata", {}
            yield "created_at", ""
            yield "updated_at", ""


class ExaDeleteWebsetBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset to delete",
            placeholder="webset-id-or-external-id",
        )

    class Output(BlockSchema):
        webset_id: str = SchemaField(
            description="The unique identifier for the deleted webset"
        )
        external_id: Optional[str] = SchemaField(
            description="The external identifier for the deleted webset", default=None
        )
        status: str = SchemaField(description="The status of the deleted webset")
        success: str = SchemaField(
            description="Whether the deletion was successful", default="true"
        )
        error: str = SchemaField(
            description="Error message if the request failed", default=""
        )

    def __init__(self):
        super().__init__(
            id="aa6994a2-e986-421f-8d4c-7671d3be7b7e",
            description="Delete a Webset and all its items",
            categories={BlockCategory.SEARCH},
            input_schema=ExaDeleteWebsetBlock.Input,
            output_schema=ExaDeleteWebsetBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}"
        headers = {
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        try:
            response = await Requests().delete(url, headers=headers)
            data = response.json()

            yield "webset_id", data.get("id", "")
            yield "external_id", data.get("externalId")
            yield "status", data.get("status", "")
            yield "success", "true"

        except Exception as e:
            yield "error", str(e)
            yield "webset_id", ""
            yield "status", ""
            yield "success", "false"


class ExaCancelWebsetBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset to cancel",
            placeholder="webset-id-or-external-id",
        )

    class Output(BlockSchema):
        webset_id: str = SchemaField(description="The unique identifier for the webset")
        status: str = SchemaField(
            description="The status of the webset after cancellation"
        )
        external_id: Optional[str] = SchemaField(
            description="The external identifier for the webset", default=None
        )
        success: str = SchemaField(
            description="Whether the cancellation was successful", default="true"
        )
        error: str = SchemaField(
            description="Error message if the request failed", default=""
        )

    def __init__(self):
        super().__init__(
            id="e40a6420-1db8-47bb-b00a-0e6aecd74176",
            description="Cancel all operations being performed on a Webset",
            categories={BlockCategory.SEARCH},
            input_schema=ExaCancelWebsetBlock.Input,
            output_schema=ExaCancelWebsetBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}/cancel"
        headers = {
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        try:
            response = await Requests().post(url, headers=headers)
            data = response.json()

            yield "webset_id", data.get("id", "")
            yield "status", data.get("status", "")
            yield "external_id", data.get("externalId")
            yield "success", "true"

        except Exception as e:
|
||||
yield "error", str(e)
|
||||
yield "webset_id", ""
|
||||
yield "status", ""
|
||||
yield "success", "false"
|
||||
@@ -0,0 +1,9 @@
|
||||
# Import the provider builder to ensure it's registered
|
||||
from backend.sdk.registry import AutoRegistry
|
||||
|
||||
from .triggers import GenericWebhookTriggerBlock, generic_webhook
|
||||
|
||||
# Ensure the SDK registry is patched to include our webhook manager
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
__all__ = ["GenericWebhookTriggerBlock", "generic_webhook"]
|
||||
@@ -3,10 +3,7 @@ import logging
|
||||
from fastapi import Request
|
||||
from strenum import StrEnum
|
||||
|
||||
from backend.data import integrations
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
from ._manual_base import ManualWebhookManagerBase
|
||||
from backend.sdk import ManualWebhookManagerBase, Webhook
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -16,12 +13,11 @@ class GenericWebhookType(StrEnum):
|
||||
|
||||
|
||||
class GenericWebhooksManager(ManualWebhookManagerBase):
|
||||
PROVIDER_NAME = ProviderName.GENERIC_WEBHOOK
|
||||
WebhookType = GenericWebhookType
|
||||
|
||||
@classmethod
|
||||
async def validate_payload(
|
||||
cls, webhook: integrations.Webhook, request: Request
|
||||
cls, webhook: Webhook, request: Request
|
||||
) -> tuple[dict, str]:
|
||||
payload = await request.json()
|
||||
event_type = GenericWebhookType.PLAIN
|
||||
@@ -1,13 +1,21 @@
|
||||
from backend.data.block import (
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockManualWebhookConfig,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
ProviderBuilder,
|
||||
ProviderName,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._webhook import GenericWebhooksManager, GenericWebhookType
|
||||
|
||||
generic_webhook = (
|
||||
ProviderBuilder("generic_webhook")
|
||||
.with_webhook_manager(GenericWebhooksManager)
|
||||
.build()
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks.generic import GenericWebhookType
|
||||
|
||||
|
||||
class GenericWebhookTriggerBlock(Block):
|
||||
@@ -36,7 +44,7 @@ class GenericWebhookTriggerBlock(Block):
|
||||
input_schema=GenericWebhookTriggerBlock.Input,
|
||||
output_schema=GenericWebhookTriggerBlock.Output,
|
||||
webhook_config=BlockManualWebhookConfig(
|
||||
provider=ProviderName.GENERIC_WEBHOOK,
|
||||
provider=ProviderName(generic_webhook.name),
|
||||
webhook_type=GenericWebhookType.PLAIN,
|
||||
),
|
||||
test_input={"constants": {"key": "value"}, "payload": self.example_payload},
|
||||
|
||||
14
autogpt_platform/backend/backend/blocks/linear/__init__.py
Normal file
14
autogpt_platform/backend/backend/blocks/linear/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Linear integration blocks for AutoGPT Platform.
|
||||
"""
|
||||
|
||||
from .comment import LinearCreateCommentBlock
|
||||
from .issues import LinearCreateIssueBlock, LinearSearchIssuesBlock
|
||||
from .projects import LinearSearchProjectsBlock
|
||||
|
||||
__all__ = [
|
||||
"LinearCreateCommentBlock",
|
||||
"LinearCreateIssueBlock",
|
||||
"LinearSearchIssuesBlock",
|
||||
"LinearSearchProjectsBlock",
|
||||
]
|
||||
@@ -1,16 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from backend.blocks.linear._auth import LinearCredentials
|
||||
from backend.blocks.linear.models import (
|
||||
CreateCommentResponse,
|
||||
CreateIssueResponse,
|
||||
Issue,
|
||||
Project,
|
||||
)
|
||||
from backend.util.request import Requests
|
||||
from backend.sdk import APIKeyCredentials, OAuth2Credentials, Requests
|
||||
|
||||
from .models import CreateCommentResponse, CreateIssueResponse, Issue, Project
|
||||
|
||||
|
||||
class LinearAPIException(Exception):
|
||||
@@ -29,13 +24,12 @@ class LinearClient:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
credentials: LinearCredentials | None = None,
|
||||
credentials: Union[OAuth2Credentials, APIKeyCredentials, None] = None,
|
||||
custom_requests: Optional[Requests] = None,
|
||||
):
|
||||
if custom_requests:
|
||||
self._requests = custom_requests
|
||||
else:
|
||||
|
||||
headers: Dict[str, str] = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
@@ -1,31 +1,19 @@
|
||||
"""
|
||||
Shared configuration for all Linear blocks using the new SDK pattern.
|
||||
"""
|
||||
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import (
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
BlockCostType,
|
||||
OAuth2Credentials,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.settings import Secrets
|
||||
|
||||
secrets = Secrets()
|
||||
LINEAR_OAUTH_IS_CONFIGURED = bool(
|
||||
secrets.linear_client_id and secrets.linear_client_secret
|
||||
ProviderBuilder,
|
||||
SecretStr,
|
||||
)
|
||||
|
||||
LinearCredentials = OAuth2Credentials | APIKeyCredentials
|
||||
# LinearCredentialsInput = CredentialsMetaInput[
|
||||
# Literal[ProviderName.LINEAR],
|
||||
# Literal["oauth2", "api_key"] if LINEAR_OAUTH_IS_CONFIGURED else Literal["oauth2"],
|
||||
# ]
|
||||
LinearCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.LINEAR], Literal["oauth2"]
|
||||
]
|
||||
|
||||
from ._oauth import LinearOAuthHandler
|
||||
|
||||
# (required) Comma separated list of scopes:
|
||||
|
||||
@@ -50,21 +38,35 @@ class LinearScope(str, Enum):
|
||||
ADMIN = "admin"
|
||||
|
||||
|
||||
def LinearCredentialsField(scopes: list[LinearScope]) -> LinearCredentialsInput:
|
||||
"""
|
||||
Creates a Linear credentials input on a block.
|
||||
# Check if Linear OAuth is configured
|
||||
client_id = os.getenv("LINEAR_CLIENT_ID")
|
||||
client_secret = os.getenv("LINEAR_CLIENT_SECRET")
|
||||
LINEAR_OAUTH_IS_CONFIGURED = bool(client_id and client_secret)
|
||||
|
||||
Params:
|
||||
scope: The authorization scope needed for the block to work. ([list of available scopes](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/scopes-for-oauth-apps#available-scopes))
|
||||
""" # noqa
|
||||
return CredentialsField(
|
||||
required_scopes=set([LinearScope.READ.value]).union(
|
||||
set([scope.value for scope in scopes])
|
||||
),
|
||||
description="The Linear integration can be used with OAuth, "
|
||||
"or any API key with sufficient permissions for the blocks it is used on.",
|
||||
# Build the Linear provider
|
||||
builder = (
|
||||
ProviderBuilder("linear")
|
||||
.with_api_key(env_var_name="LINEAR_API_KEY", title="Linear API Key")
|
||||
.with_base_cost(1, BlockCostType.RUN)
|
||||
)
|
||||
|
||||
# Linear only supports OAuth authentication
|
||||
if LINEAR_OAUTH_IS_CONFIGURED:
|
||||
builder = builder.with_oauth(
|
||||
LinearOAuthHandler,
|
||||
scopes=[
|
||||
LinearScope.READ,
|
||||
LinearScope.WRITE,
|
||||
LinearScope.ISSUES_CREATE,
|
||||
LinearScope.COMMENTS_CREATE,
|
||||
],
|
||||
client_id_env_var="LINEAR_CLIENT_ID",
|
||||
client_secret_env_var="LINEAR_CLIENT_SECRET",
|
||||
)
|
||||
|
||||
# Build the provider
|
||||
linear = builder.build()
|
||||
|
||||
|
||||
TEST_CREDENTIALS_OAUTH = OAuth2Credentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
@@ -1,15 +1,27 @@
|
||||
"""
|
||||
Linear OAuth handler implementation.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from pydantic import SecretStr
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
BaseOAuthHandler,
|
||||
OAuth2Credentials,
|
||||
ProviderName,
|
||||
Requests,
|
||||
SecretStr,
|
||||
)
|
||||
|
||||
from backend.blocks.linear._api import LinearAPIException
|
||||
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import Requests
|
||||
|
||||
from .base import BaseOAuthHandler
|
||||
class LinearAPIException(Exception):
|
||||
"""Exception for Linear API errors."""
|
||||
|
||||
def __init__(self, message: str, status_code: int):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class LinearOAuthHandler(BaseOAuthHandler):
|
||||
@@ -17,7 +29,9 @@ class LinearOAuthHandler(BaseOAuthHandler):
|
||||
OAuth2 handler for Linear.
|
||||
"""
|
||||
|
||||
PROVIDER_NAME = ProviderName.LINEAR
|
||||
# Provider name will be set dynamically by the SDK when registered
|
||||
# We use a placeholder that will be replaced by AutoRegistry.register_provider()
|
||||
PROVIDER_NAME = ProviderName("linear")
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
@@ -30,7 +44,6 @@ class LinearOAuthHandler(BaseOAuthHandler):
|
||||
def get_login_url(
|
||||
self, scopes: list[str], state: str, code_challenge: Optional[str]
|
||||
) -> str:
|
||||
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
@@ -139,9 +152,10 @@ class LinearOAuthHandler(BaseOAuthHandler):
|
||||
|
||||
async def _request_username(self, access_token: str) -> Optional[str]:
|
||||
# Use the LinearClient to fetch user details using GraphQL
|
||||
from backend.blocks.linear._api import LinearClient
|
||||
from ._api import LinearClient
|
||||
|
||||
try:
|
||||
# Create a temporary OAuth2Credentials object for the LinearClient
|
||||
linear_client = LinearClient(
|
||||
APIKeyCredentials(
|
||||
api_key=SecretStr(access_token),
|
||||
@@ -1,24 +1,32 @@
|
||||
from backend.blocks.linear._api import LinearAPIException, LinearClient
|
||||
from backend.blocks.linear._auth import (
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
OAuth2Credentials,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import LinearAPIException, LinearClient
|
||||
from ._config import (
|
||||
LINEAR_OAUTH_IS_CONFIGURED,
|
||||
TEST_CREDENTIALS_INPUT_OAUTH,
|
||||
TEST_CREDENTIALS_OAUTH,
|
||||
LinearCredentials,
|
||||
LinearCredentialsField,
|
||||
LinearCredentialsInput,
|
||||
LinearScope,
|
||||
linear,
|
||||
)
|
||||
from backend.blocks.linear.models import CreateCommentResponse
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from .models import CreateCommentResponse
|
||||
|
||||
|
||||
class LinearCreateCommentBlock(Block):
|
||||
"""Block for creating comments on Linear issues"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: LinearCredentialsInput = LinearCredentialsField(
|
||||
scopes=[LinearScope.COMMENTS_CREATE],
|
||||
credentials: CredentialsMetaInput = linear.credentials_field(
|
||||
description="Linear credentials with comment creation permissions",
|
||||
required_scopes={LinearScope.COMMENTS_CREATE},
|
||||
)
|
||||
issue_id: str = SchemaField(description="ID of the issue to comment on")
|
||||
comment: str = SchemaField(description="Comment text to add to the issue")
|
||||
@@ -55,7 +63,7 @@ class LinearCreateCommentBlock(Block):
|
||||
|
||||
@staticmethod
|
||||
async def create_comment(
|
||||
credentials: LinearCredentials, issue_id: str, comment: str
|
||||
credentials: OAuth2Credentials | APIKeyCredentials, issue_id: str, comment: str
|
||||
) -> tuple[str, str]:
|
||||
client = LinearClient(credentials=credentials)
|
||||
response: CreateCommentResponse = await client.try_create_comment(
|
||||
@@ -64,7 +72,11 @@ class LinearCreateCommentBlock(Block):
|
||||
return response.comment.id, response.comment.body
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: LinearCredentials, **kwargs
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: OAuth2Credentials | APIKeyCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Execute the comment creation"""
|
||||
try:
|
||||
|
||||
@@ -1,24 +1,32 @@
|
||||
from backend.blocks.linear._api import LinearAPIException, LinearClient
|
||||
from backend.blocks.linear._auth import (
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
OAuth2Credentials,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import LinearAPIException, LinearClient
|
||||
from ._config import (
|
||||
LINEAR_OAUTH_IS_CONFIGURED,
|
||||
TEST_CREDENTIALS_INPUT_OAUTH,
|
||||
TEST_CREDENTIALS_OAUTH,
|
||||
LinearCredentials,
|
||||
LinearCredentialsField,
|
||||
LinearCredentialsInput,
|
||||
LinearScope,
|
||||
linear,
|
||||
)
|
||||
from backend.blocks.linear.models import CreateIssueResponse, Issue
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from .models import CreateIssueResponse, Issue
|
||||
|
||||
|
||||
class LinearCreateIssueBlock(Block):
|
||||
"""Block for creating issues on Linear"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: LinearCredentialsInput = LinearCredentialsField(
|
||||
scopes=[LinearScope.ISSUES_CREATE],
|
||||
credentials: CredentialsMetaInput = linear.credentials_field(
|
||||
description="Linear credentials with issue creation permissions",
|
||||
required_scopes={LinearScope.ISSUES_CREATE},
|
||||
)
|
||||
title: str = SchemaField(description="Title of the issue")
|
||||
description: str | None = SchemaField(description="Description of the issue")
|
||||
@@ -68,7 +76,7 @@ class LinearCreateIssueBlock(Block):
|
||||
|
||||
@staticmethod
|
||||
async def create_issue(
|
||||
credentials: LinearCredentials,
|
||||
credentials: OAuth2Credentials | APIKeyCredentials,
|
||||
team_name: str,
|
||||
title: str,
|
||||
description: str | None = None,
|
||||
@@ -94,7 +102,11 @@ class LinearCreateIssueBlock(Block):
|
||||
return response.issue.identifier, response.issue.title
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: LinearCredentials, **kwargs
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: OAuth2Credentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Execute the issue creation"""
|
||||
try:
|
||||
@@ -121,8 +133,9 @@ class LinearSearchIssuesBlock(Block):
|
||||
|
||||
class Input(BlockSchema):
|
||||
term: str = SchemaField(description="Term to search for issues")
|
||||
credentials: LinearCredentialsInput = LinearCredentialsField(
|
||||
scopes=[LinearScope.READ],
|
||||
credentials: CredentialsMetaInput = linear.credentials_field(
|
||||
description="Linear credentials with read permissions",
|
||||
required_scopes={LinearScope.READ},
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -169,7 +182,7 @@ class LinearSearchIssuesBlock(Block):
|
||||
|
||||
@staticmethod
|
||||
async def search_issues(
|
||||
credentials: LinearCredentials,
|
||||
credentials: OAuth2Credentials | APIKeyCredentials,
|
||||
term: str,
|
||||
) -> list[Issue]:
|
||||
client = LinearClient(credentials=credentials)
|
||||
@@ -177,7 +190,11 @@ class LinearSearchIssuesBlock(Block):
|
||||
return response
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: LinearCredentials, **kwargs
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: OAuth2Credentials | APIKeyCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Execute the issue search"""
|
||||
try:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from pydantic import BaseModel
|
||||
from backend.sdk import BaseModel
|
||||
|
||||
|
||||
class Comment(BaseModel):
|
||||
|
||||
@@ -1,24 +1,32 @@
|
||||
from backend.blocks.linear._api import LinearAPIException, LinearClient
|
||||
from backend.blocks.linear._auth import (
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
OAuth2Credentials,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import LinearAPIException, LinearClient
|
||||
from ._config import (
|
||||
LINEAR_OAUTH_IS_CONFIGURED,
|
||||
TEST_CREDENTIALS_INPUT_OAUTH,
|
||||
TEST_CREDENTIALS_OAUTH,
|
||||
LinearCredentials,
|
||||
LinearCredentialsField,
|
||||
LinearCredentialsInput,
|
||||
LinearScope,
|
||||
linear,
|
||||
)
|
||||
from backend.blocks.linear.models import Project
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from .models import Project
|
||||
|
||||
|
||||
class LinearSearchProjectsBlock(Block):
|
||||
"""Block for searching projects on Linear"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: LinearCredentialsInput = LinearCredentialsField(
|
||||
scopes=[LinearScope.READ],
|
||||
credentials: CredentialsMetaInput = linear.credentials_field(
|
||||
description="Linear credentials with read permissions",
|
||||
required_scopes={LinearScope.READ},
|
||||
)
|
||||
term: str = SchemaField(description="Term to search for projects")
|
||||
|
||||
@@ -70,7 +78,7 @@ class LinearSearchProjectsBlock(Block):
|
||||
|
||||
@staticmethod
|
||||
async def search_projects(
|
||||
credentials: LinearCredentials,
|
||||
credentials: OAuth2Credentials | APIKeyCredentials,
|
||||
term: str,
|
||||
) -> list[Project]:
|
||||
client = LinearClient(credentials=credentials)
|
||||
@@ -78,7 +86,11 @@ class LinearSearchProjectsBlock(Block):
|
||||
return response
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: LinearCredentials, **kwargs
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: OAuth2Credentials | APIKeyCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Execute the project search"""
|
||||
try:
|
||||
|
||||
@@ -9,3 +9,117 @@ from backend.util.test import execute_block_test
|
||||
@pytest.mark.parametrize("block", get_blocks().values(), ids=lambda b: b.name)
|
||||
async def test_available_blocks(block: Type[Block]):
|
||||
await execute_block_test(block())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block", get_blocks().values(), ids=lambda b: b.name)
|
||||
async def test_block_ids_valid(block: Type[Block]):
|
||||
# add the tests here to check they are uuid4
|
||||
import uuid
|
||||
|
||||
# Skip list for blocks with known invalid UUIDs
|
||||
skip_blocks = {
|
||||
"GetWeatherInformationBlock",
|
||||
"CodeExecutionBlock",
|
||||
"CountdownTimerBlock",
|
||||
"TwitterGetListTweetsBlock",
|
||||
"TwitterRemoveListMemberBlock",
|
||||
"TwitterAddListMemberBlock",
|
||||
"TwitterGetListMembersBlock",
|
||||
"TwitterGetListMembershipsBlock",
|
||||
"TwitterUnfollowListBlock",
|
||||
"TwitterFollowListBlock",
|
||||
"TwitterUnpinListBlock",
|
||||
"TwitterPinListBlock",
|
||||
"TwitterGetPinnedListsBlock",
|
||||
"TwitterDeleteListBlock",
|
||||
"TwitterUpdateListBlock",
|
||||
"TwitterCreateListBlock",
|
||||
"TwitterGetListBlock",
|
||||
"TwitterGetOwnedListsBlock",
|
||||
"TwitterGetSpacesBlock",
|
||||
"TwitterGetSpaceByIdBlock",
|
||||
"TwitterGetSpaceBuyersBlock",
|
||||
"TwitterGetSpaceTweetsBlock",
|
||||
"TwitterSearchSpacesBlock",
|
||||
"TwitterGetUserMentionsBlock",
|
||||
"TwitterGetHomeTimelineBlock",
|
||||
"TwitterGetUserTweetsBlock",
|
||||
"TwitterGetTweetBlock",
|
||||
"TwitterGetTweetsBlock",
|
||||
"TwitterGetQuoteTweetsBlock",
|
||||
"TwitterLikeTweetBlock",
|
||||
"TwitterGetLikingUsersBlock",
|
||||
"TwitterGetLikedTweetsBlock",
|
||||
"TwitterUnlikeTweetBlock",
|
||||
"TwitterBookmarkTweetBlock",
|
||||
"TwitterGetBookmarkedTweetsBlock",
|
||||
"TwitterRemoveBookmarkTweetBlock",
|
||||
"TwitterRetweetBlock",
|
||||
"TwitterRemoveRetweetBlock",
|
||||
"TwitterGetRetweetersBlock",
|
||||
"TwitterHideReplyBlock",
|
||||
"TwitterUnhideReplyBlock",
|
||||
"TwitterPostTweetBlock",
|
||||
"TwitterDeleteTweetBlock",
|
||||
"TwitterSearchRecentTweetsBlock",
|
||||
"TwitterUnfollowUserBlock",
|
||||
"TwitterFollowUserBlock",
|
||||
"TwitterGetFollowersBlock",
|
||||
"TwitterGetFollowingBlock",
|
||||
"TwitterUnmuteUserBlock",
|
||||
"TwitterGetMutedUsersBlock",
|
||||
"TwitterMuteUserBlock",
|
||||
"TwitterGetBlockedUsersBlock",
|
||||
"TwitterGetUserBlock",
|
||||
"TwitterGetUsersBlock",
|
||||
"TodoistCreateLabelBlock",
|
||||
"TodoistListLabelsBlock",
|
||||
"TodoistGetLabelBlock",
|
||||
"TodoistUpdateLabelBlock",
|
||||
"TodoistDeleteLabelBlock",
|
||||
"TodoistGetSharedLabelsBlock",
|
||||
"TodoistRenameSharedLabelsBlock",
|
||||
"TodoistRemoveSharedLabelsBlock",
|
||||
"TodoistCreateTaskBlock",
|
||||
"TodoistGetTasksBlock",
|
||||
"TodoistGetTaskBlock",
|
||||
"TodoistUpdateTaskBlock",
|
||||
"TodoistCloseTaskBlock",
|
||||
"TodoistReopenTaskBlock",
|
||||
"TodoistDeleteTaskBlock",
|
||||
"TodoistListSectionsBlock",
|
||||
"TodoistGetSectionBlock",
|
||||
"TodoistDeleteSectionBlock",
|
||||
"TodoistCreateProjectBlock",
|
||||
"TodoistGetProjectBlock",
|
||||
"TodoistUpdateProjectBlock",
|
||||
"TodoistDeleteProjectBlock",
|
||||
"TodoistListCollaboratorsBlock",
|
||||
"TodoistGetCommentsBlock",
|
||||
"TodoistGetCommentBlock",
|
||||
"TodoistUpdateCommentBlock",
|
||||
"TodoistDeleteCommentBlock",
|
||||
"GithubListStargazersBlock",
|
||||
"Slant3DSlicerBlock",
|
||||
}
|
||||
|
||||
block_instance = block()
|
||||
|
||||
# Skip blocks with known invalid UUIDs
|
||||
if block_instance.__class__.__name__ in skip_blocks:
|
||||
pytest.skip(
|
||||
f"Skipping UUID check for {block_instance.__class__.__name__} - known invalid UUID"
|
||||
)
|
||||
|
||||
# Check that the ID is not empty
|
||||
assert block_instance.id, f"Block {block.name} has empty ID"
|
||||
|
||||
# Check that the ID is a valid UUID4
|
||||
try:
|
||||
parsed_uuid = uuid.UUID(block_instance.id)
|
||||
# Verify it's specifically UUID version 4
|
||||
assert (
|
||||
parsed_uuid.version == 4
|
||||
), f"Block {block.name} ID is UUID version {parsed_uuid.version}, expected version 4"
|
||||
except ValueError:
|
||||
pytest.fail(f"Block {block.name} has invalid UUID format: {block_instance.id}")
|
||||
|
||||
359
autogpt_platform/backend/backend/check_db.py
Normal file
359
autogpt_platform/backend/backend/check_db.py
Normal file
@@ -0,0 +1,359 @@
|
||||
import asyncio
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
from faker import Faker
|
||||
from prisma import Prisma
|
||||
|
||||
faker = Faker()
|
||||
|
||||
|
||||
async def check_cron_job(db):
|
||||
"""Check if the pg_cron job for refreshing materialized views exists."""
|
||||
print("\n1. Checking pg_cron job...")
|
||||
print("-" * 40)
|
||||
|
||||
try:
|
||||
# Check if pg_cron extension exists
|
||||
extension_check = await db.query_raw("CREATE EXTENSION pg_cron;")
|
||||
print(extension_check)
|
||||
extension_check = await db.query_raw(
|
||||
"SELECT COUNT(*) as count FROM pg_extension WHERE extname = 'pg_cron'"
|
||||
)
|
||||
if extension_check[0]["count"] == 0:
|
||||
print("⚠️ pg_cron extension is not installed")
|
||||
return False
|
||||
|
||||
# Check if the refresh job exists
|
||||
job_check = await db.query_raw(
|
||||
"""
|
||||
SELECT jobname, schedule, command
|
||||
FROM cron.job
|
||||
WHERE jobname = 'refresh-store-views'
|
||||
"""
|
||||
)
|
||||
|
||||
if job_check:
|
||||
job = job_check[0]
|
||||
print("✅ pg_cron job found:")
|
||||
print(f" Name: {job['jobname']}")
|
||||
print(f" Schedule: {job['schedule']} (every 15 minutes)")
|
||||
print(f" Command: {job['command']}")
|
||||
return True
|
||||
else:
|
||||
print("⚠️ pg_cron job 'refresh-store-views' not found")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error checking pg_cron: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_materialized_view_counts(db):
|
||||
"""Get current counts from materialized views."""
|
||||
print("\n2. Getting current materialized view data...")
|
||||
print("-" * 40)
|
||||
|
||||
# Get counts from mv_agent_run_counts
|
||||
agent_runs = await db.query_raw(
|
||||
"""
|
||||
SELECT COUNT(*) as total_agents,
|
||||
SUM(run_count) as total_runs,
|
||||
MAX(run_count) as max_runs,
|
||||
MIN(run_count) as min_runs
|
||||
FROM mv_agent_run_counts
|
||||
"""
|
||||
)
|
||||
|
||||
# Get counts from mv_review_stats
|
||||
review_stats = await db.query_raw(
|
||||
"""
|
||||
SELECT COUNT(*) as total_listings,
|
||||
SUM(review_count) as total_reviews,
|
||||
AVG(avg_rating) as overall_avg_rating
|
||||
FROM mv_review_stats
|
||||
"""
|
||||
)
|
||||
|
||||
# Get sample data from StoreAgent view
|
||||
store_agents = await db.query_raw(
|
||||
"""
|
||||
SELECT COUNT(*) as total_store_agents,
|
||||
AVG(runs) as avg_runs,
|
||||
AVG(rating) as avg_rating
|
||||
FROM "StoreAgent"
|
||||
"""
|
||||
)
|
||||
|
||||
agent_run_data = agent_runs[0] if agent_runs else {}
|
||||
review_data = review_stats[0] if review_stats else {}
|
||||
store_data = store_agents[0] if store_agents else {}
|
||||
|
||||
print("📊 mv_agent_run_counts:")
|
||||
print(f" Total agents: {agent_run_data.get('total_agents', 0)}")
|
||||
print(f" Total runs: {agent_run_data.get('total_runs', 0)}")
|
||||
print(f" Max runs per agent: {agent_run_data.get('max_runs', 0)}")
|
||||
print(f" Min runs per agent: {agent_run_data.get('min_runs', 0)}")
|
||||
|
||||
print("\n📊 mv_review_stats:")
|
||||
print(f" Total listings: {review_data.get('total_listings', 0)}")
|
||||
print(f" Total reviews: {review_data.get('total_reviews', 0)}")
|
||||
print(f" Overall avg rating: {review_data.get('overall_avg_rating') or 0:.2f}")
|
||||
|
||||
print("\n📊 StoreAgent view:")
|
||||
print(f" Total store agents: {store_data.get('total_store_agents', 0)}")
|
||||
print(f" Average runs: {store_data.get('avg_runs') or 0:.2f}")
|
||||
print(f" Average rating: {store_data.get('avg_rating') or 0:.2f}")
|
||||
|
||||
return {
|
||||
"agent_runs": agent_run_data,
|
||||
"reviews": review_data,
|
||||
"store_agents": store_data,
|
||||
}
|
||||
|
||||
|
||||
async def add_test_data(db):
|
||||
"""Add some test data to verify materialized view updates."""
|
||||
print("\n3. Adding test data...")
|
||||
print("-" * 40)
|
||||
|
||||
# Get some existing data
|
||||
users = await db.user.find_many(take=5)
|
||||
graphs = await db.agentgraph.find_many(take=5)
|
||||
|
||||
if not users or not graphs:
|
||||
print("❌ No existing users or graphs found. Run test_data_creator.py first.")
|
||||
return False
|
||||
|
||||
# Add new executions
|
||||
print("Adding new agent graph executions...")
|
||||
new_executions = 0
|
||||
for graph in graphs:
|
||||
for _ in range(random.randint(2, 5)):
|
||||
await db.agentgraphexecution.create(
|
||||
data={
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
"userId": random.choice(users).id,
|
||||
"executionStatus": "COMPLETED",
|
||||
"startedAt": datetime.now(),
|
||||
}
|
||||
)
|
||||
new_executions += 1
|
||||
|
||||
print(f"✅ Added {new_executions} new executions")
|
||||
|
||||
# Check if we need to create store listings first
|
||||
store_versions = await db.storelistingversion.find_many(
|
||||
where={"submissionStatus": "APPROVED"}, take=5
|
||||
)
|
||||
|
||||
if not store_versions:
|
||||
print("\nNo approved store listings found. Creating test store listings...")
|
||||
|
||||
# Create store listings for existing agent graphs
|
||||
for i, graph in enumerate(graphs[:3]): # Create up to 3 store listings
|
||||
# Create a store listing
|
||||
listing = await db.storelisting.create(
|
||||
data={
|
||||
"slug": f"test-agent-{graph.id[:8]}",
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
"hasApprovedVersion": True,
|
||||
"owningUserId": graph.userId,
|
||||
}
|
||||
)
|
||||
|
||||
# Create an approved version
|
||||
version = await db.storelistingversion.create(
|
||||
data={
|
||||
"storeListingId": listing.id,
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
"name": f"Test Agent {i+1}",
|
||||
"subHeading": faker.catch_phrase(),
|
||||
"description": faker.paragraph(nb_sentences=5),
|
||||
"imageUrls": [faker.image_url()],
|
||||
"categories": ["productivity", "automation"],
|
||||
"submissionStatus": "APPROVED",
|
||||
"submittedAt": datetime.now(),
|
||||
}
|
||||
)
|
||||
|
||||
# Update listing with active version
|
||||
await db.storelisting.update(
|
||||
where={"id": listing.id}, data={"activeVersionId": version.id}
|
||||
)
|
||||
|
||||
print("✅ Created test store listings")
|
||||
|
||||
# Re-fetch approved versions
|
||||
store_versions = await db.storelistingversion.find_many(
|
||||
where={"submissionStatus": "APPROVED"}, take=5
|
||||
)
|
||||
|
||||
# Add new reviews
|
||||
print("\nAdding new store listing reviews...")
|
||||
new_reviews = 0
|
||||
for version in store_versions:
|
||||
# Find users who haven't reviewed this version
|
||||
existing_reviews = await db.storelistingreview.find_many(
|
||||
where={"storeListingVersionId": version.id}
|
||||
)
|
||||
reviewed_user_ids = {r.reviewByUserId for r in existing_reviews}
|
||||
available_users = [u for u in users if u.id not in reviewed_user_ids]
|
||||
|
||||
if available_users:
|
||||
user = random.choice(available_users)
|
||||
await db.storelistingreview.create(
|
||||
data={
|
||||
"storeListingVersionId": version.id,
|
||||
"reviewByUserId": user.id,
|
||||
"score": random.randint(3, 5),
|
||||
"comments": faker.text(max_nb_chars=100),
|
||||
}
|
||||
)
|
||||
new_reviews += 1
|
||||
|
||||
print(f"✅ Added {new_reviews} new reviews")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def refresh_materialized_views(db):
|
||||
"""Manually refresh the materialized views."""
|
||||
print("\n4. Manually refreshing materialized views...")
|
||||
print("-" * 40)
|
||||
|
||||
try:
|
||||
await db.execute_raw("SELECT refresh_store_materialized_views();")
|
||||
print("✅ Materialized views refreshed successfully")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"❌ Error refreshing views: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def compare_counts(before, after):
|
||||
"""Compare counts before and after refresh."""
|
||||
print("\n5. Comparing counts before and after refresh...")
|
||||
print("-" * 40)
|
||||
|
||||
# Compare agent runs
|
||||
print("🔍 Agent run changes:")
|
||||
before_runs = before["agent_runs"].get("total_runs") or 0
|
||||
after_runs = after["agent_runs"].get("total_runs") or 0
|
||||
print(
|
||||
f" Total runs: {before_runs} → {after_runs} " f"(+{after_runs - before_runs})"
|
||||
)
|
||||
|
||||
# Compare reviews
|
||||
print("\n🔍 Review changes:")
|
||||
before_reviews = before["reviews"].get("total_reviews") or 0
|
||||
after_reviews = after["reviews"].get("total_reviews") or 0
|
||||
print(
|
||||
f" Total reviews: {before_reviews} → {after_reviews} "
|
||||
f"(+{after_reviews - before_reviews})"
|
||||
)
|
||||
|
||||
# Compare store agents
|
||||
print("\n🔍 StoreAgent view changes:")
|
||||
before_avg_runs = before["store_agents"].get("avg_runs", 0) or 0
|
||||
after_avg_runs = after["store_agents"].get("avg_runs", 0) or 0
|
||||
print(
|
||||
f" Average runs: {before_avg_runs:.2f} → {after_avg_runs:.2f} "
|
||||
f"(+{after_avg_runs - before_avg_runs:.2f})"
|
||||
)
|
||||
|
||||
# Verify changes occurred
|
||||
runs_changed = (after["agent_runs"].get("total_runs") or 0) > (
|
||||
before["agent_runs"].get("total_runs") or 0
|
||||
)
|
||||
reviews_changed = (after["reviews"].get("total_reviews") or 0) > (
|
||||
before["reviews"].get("total_reviews") or 0
|
||||
)
|
||||
|
||||
if runs_changed and reviews_changed:
|
||||
print("\n✅ Materialized views are updating correctly!")
|
||||
return True
|
||||
else:
|
||||
print("\n⚠️ Some materialized views may not have updated:")
|
||||
if not runs_changed:
|
||||
print(" - Agent run counts did not increase")
|
||||
if not reviews_changed:
|
||||
print(" - Review counts did not increase")
|
||||
return False
|
||||
|
||||
|
||||
async def main():
|
||||
db = Prisma()
|
||||
await db.connect()
|
||||
|
||||
print("=" * 60)
|
||||
print("Materialized Views Test")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
# Check if data exists
|
||||
user_count = await db.user.count()
|
||||
if user_count == 0:
|
||||
print("❌ No data in database. Please run test_data_creator.py first.")
|
||||
await db.disconnect()
|
||||
return
|
||||
|
||||
# 1. Check cron job
|
||||
cron_exists = await check_cron_job(db)
|
||||
|
||||
# 2. Get initial counts
|
||||
counts_before = await get_materialized_view_counts(db)
|
||||
|
||||
# 3. Add test data
|
||||
data_added = await add_test_data(db)
|
||||
refresh_success = False
|
||||
|
||||
if data_added:
|
||||
# Wait a moment for data to be committed
|
||||
print("\nWaiting for data to be committed...")
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# 4. Manually refresh views
|
||||
refresh_success = await refresh_materialized_views(db)
|
||||
|
||||
if refresh_success:
|
||||
# 5. Get counts after refresh
|
||||
counts_after = await get_materialized_view_counts(db)
|
||||
|
||||
# 6. Compare results
|
||||
await compare_counts(counts_before, counts_after)
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 60)
|
||||
print("Test Summary")
|
||||
print("=" * 60)
|
||||
print(f"✓ pg_cron job exists: {'Yes' if cron_exists else 'No'}")
|
||||
print(f"✓ Test data added: {'Yes' if data_added else 'No'}")
|
||||
print(f"✓ Manual refresh worked: {'Yes' if refresh_success else 'No'}")
|
||||
print(
|
||||
f"✓ Views updated correctly: {'Yes' if data_added and refresh_success else 'Cannot verify'}"
|
||||
)
|
||||
|
||||
if cron_exists:
|
||||
print(
|
||||
"\n💡 The materialized views will also refresh automatically every 15 minutes via pg_cron."
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"\n⚠️ Automatic refresh is not configured. Views must be refreshed manually."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ Test failed with error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
await db.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
159
autogpt_platform/backend/backend/check_store_data.py
Normal file
159
autogpt_platform/backend/backend/check_store_data.py
Normal file
@@ -0,0 +1,159 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Check store-related data in the database."""
|
||||
|
||||
import asyncio
|
||||
|
||||
from prisma import Prisma
|
||||
|
||||
|
||||
async def check_store_data(db):
|
||||
"""Check what store data exists in the database."""
|
||||
|
||||
print("============================================================")
|
||||
print("Store Data Check")
|
||||
print("============================================================")
|
||||
|
||||
# Check store listings
|
||||
print("\n1. Store Listings:")
|
||||
print("-" * 40)
|
||||
listings = await db.storelisting.find_many()
|
||||
print(f"Total store listings: {len(listings)}")
|
||||
|
||||
if listings:
|
||||
for listing in listings[:5]:
|
||||
print(f"\nListing ID: {listing.id}")
|
||||
print(f" Name: {listing.name}")
|
||||
print(f" Status: {listing.status}")
|
||||
print(f" Slug: {listing.slug}")
|
||||
|
||||
# Check store listing versions
|
||||
print("\n\n2. Store Listing Versions:")
|
||||
print("-" * 40)
|
||||
versions = await db.storelistingversion.find_many(include={"StoreListing": True})
|
||||
print(f"Total store listing versions: {len(versions)}")
|
||||
|
||||
# Group by submission status
|
||||
status_counts = {}
|
||||
for version in versions:
|
||||
status = version.submissionStatus
|
||||
status_counts[status] = status_counts.get(status, 0) + 1
|
||||
|
||||
print("\nVersions by status:")
|
||||
for status, count in status_counts.items():
|
||||
print(f" {status}: {count}")
|
||||
|
||||
# Show approved versions
|
||||
approved_versions = [v for v in versions if v.submissionStatus == "APPROVED"]
|
||||
print(f"\nApproved versions: {len(approved_versions)}")
|
||||
if approved_versions:
|
||||
for version in approved_versions[:5]:
|
||||
print(f"\n Version ID: {version.id}")
|
||||
print(f" Listing: {version.StoreListing.name}")
|
||||
print(f" Version: {version.version}")
|
||||
|
||||
# Check store listing reviews
|
||||
print("\n\n3. Store Listing Reviews:")
|
||||
print("-" * 40)
|
||||
reviews = await db.storelistingreview.find_many(
|
||||
include={"StoreListingVersion": {"include": {"StoreListing": True}}}
|
||||
)
|
||||
print(f"Total reviews: {len(reviews)}")
|
||||
|
||||
if reviews:
|
||||
# Calculate average rating
|
||||
total_score = sum(r.score for r in reviews)
|
||||
avg_score = total_score / len(reviews) if reviews else 0
|
||||
print(f"Average rating: {avg_score:.2f}")
|
||||
|
||||
# Show sample reviews
|
||||
print("\nSample reviews:")
|
||||
for review in reviews[:3]:
|
||||
print(f"\n Review for: {review.StoreListingVersion.StoreListing.name}")
|
||||
print(f" Score: {review.score}")
|
||||
print(f" Comments: {review.comments[:100]}...")
|
||||
|
||||
# Check StoreAgent view data
|
||||
print("\n\n4. StoreAgent View Data:")
|
||||
print("-" * 40)
|
||||
|
||||
# Query the StoreAgent view
|
||||
query = """
|
||||
SELECT
|
||||
sa.listing_id,
|
||||
sa.slug,
|
||||
sa.agent_name,
|
||||
sa.description,
|
||||
sa.featured,
|
||||
sa.runs,
|
||||
sa.rating,
|
||||
sa.creator_username,
|
||||
sa.categories,
|
||||
sa.updated_at
|
||||
FROM "StoreAgent" sa
|
||||
LIMIT 10;
|
||||
"""
|
||||
|
||||
store_agents = await db.query_raw(query)
|
||||
print(f"Total store agents in view: {len(store_agents)}")
|
||||
|
||||
if store_agents:
|
||||
for agent in store_agents[:5]:
|
||||
print(f"\nStore Agent: {agent['agent_name']}")
|
||||
print(f" Slug: {agent['slug']}")
|
||||
print(f" Runs: {agent['runs']}")
|
||||
print(f" Rating: {agent['rating']}")
|
||||
print(f" Creator: {agent['creator_username']}")
|
||||
|
||||
# Check the underlying data that should populate StoreAgent
|
||||
print("\n\n5. Data that should populate StoreAgent view:")
|
||||
print("-" * 40)
|
||||
|
||||
# Check for any APPROVED store listing versions
|
||||
query = """
|
||||
SELECT COUNT(*) as count
|
||||
FROM "StoreListingVersion"
|
||||
WHERE "submissionStatus" = 'APPROVED'
|
||||
"""
|
||||
|
||||
result = await db.query_raw(query)
|
||||
approved_count = result[0]["count"] if result else 0
|
||||
print(f"Approved store listing versions: {approved_count}")
|
||||
|
||||
# Check for store listings with hasApprovedVersion = true
|
||||
query = """
|
||||
SELECT COUNT(*) as count
|
||||
FROM "StoreListing"
|
||||
WHERE "hasApprovedVersion" = true AND "isDeleted" = false
|
||||
"""
|
||||
|
||||
result = await db.query_raw(query)
|
||||
has_approved_count = result[0]["count"] if result else 0
|
||||
print(f"Store listings with approved versions: {has_approved_count}")
|
||||
|
||||
# Check agent graph executions
|
||||
query = """
|
||||
SELECT COUNT(DISTINCT "agentGraphId") as unique_agents,
|
||||
COUNT(*) as total_executions
|
||||
FROM "AgentGraphExecution"
|
||||
"""
|
||||
|
||||
result = await db.query_raw(query)
|
||||
if result:
|
||||
print("\nAgent Graph Executions:")
|
||||
print(f" Unique agents with executions: {result[0]['unique_agents']}")
|
||||
print(f" Total executions: {result[0]['total_executions']}")
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main function."""
|
||||
db = Prisma()
|
||||
await db.connect()
|
||||
|
||||
try:
|
||||
await check_store_data(db)
|
||||
finally:
|
||||
await db.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -425,28 +425,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
raise ValueError(f"{self.name} did not produce any output for {output}")
|
||||
|
||||
def merge_stats(self, stats: NodeExecutionStats) -> NodeExecutionStats:
|
||||
stats_dict = stats.model_dump()
|
||||
current_stats = self.execution_stats.model_dump()
|
||||
|
||||
for key, value in stats_dict.items():
|
||||
if key not in current_stats:
|
||||
# Field doesn't exist yet, just set it, but this will probably
|
||||
# not happen, just in case though so we throw for invalid when
|
||||
# converting back in
|
||||
current_stats[key] = value
|
||||
elif isinstance(value, dict) and isinstance(current_stats[key], dict):
|
||||
current_stats[key].update(value)
|
||||
elif isinstance(value, (int, float)) and isinstance(
|
||||
current_stats[key], (int, float)
|
||||
):
|
||||
current_stats[key] += value
|
||||
elif isinstance(value, list) and isinstance(current_stats[key], list):
|
||||
current_stats[key].extend(value)
|
||||
else:
|
||||
current_stats[key] = value
|
||||
|
||||
self.execution_stats = NodeExecutionStats(**current_stats)
|
||||
|
||||
self.execution_stats += stats
|
||||
return self.execution_stats
|
||||
|
||||
@property
|
||||
@@ -513,6 +492,12 @@ def get_blocks() -> dict[str, Type[Block]]:
|
||||
|
||||
|
||||
async def initialize_blocks() -> None:
|
||||
# First, sync all provider costs to blocks
|
||||
# Imported here to avoid circular import
|
||||
from backend.sdk.cost_integration import sync_all_provider_costs
|
||||
|
||||
sync_all_provider_costs()
|
||||
|
||||
for cls in get_blocks().values():
|
||||
block = cls()
|
||||
existing_block = await AgentBlock.prisma().find_first(
|
||||
|
||||
@@ -93,6 +93,28 @@ async def locked_transaction(key: str):
|
||||
yield tx
|
||||
|
||||
|
||||
def get_database_schema() -> str:
|
||||
"""Extract database schema from DATABASE_URL."""
|
||||
parsed_url = urlparse(DATABASE_URL)
|
||||
query_params = dict(parse_qsl(parsed_url.query))
|
||||
return query_params.get("schema", "public")
|
||||
|
||||
|
||||
async def query_raw_with_schema(query_template: str, *args) -> list[dict]:
|
||||
"""Execute raw SQL query with proper schema handling."""
|
||||
schema = get_database_schema()
|
||||
schema_prefix = f"{schema}." if schema != "public" else ""
|
||||
formatted_query = query_template.format(schema_prefix=schema_prefix)
|
||||
|
||||
import prisma as prisma_module
|
||||
|
||||
result = await prisma_module.get_client().query_raw(
|
||||
formatted_query, *args # type: ignore
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class BaseDbModel(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ from .block import (
|
||||
get_io_block_ids,
|
||||
get_webhook_block_ids,
|
||||
)
|
||||
from .db import BaseDbModel
|
||||
from .db import BaseDbModel, query_raw_with_schema
|
||||
from .event_bus import AsyncRedisEventBus, RedisEventBus
|
||||
from .includes import (
|
||||
EXECUTION_RESULT_INCLUDE,
|
||||
@@ -68,6 +68,21 @@ config = Config()
|
||||
# -------------------------- Models -------------------------- #
|
||||
|
||||
|
||||
class BlockErrorStats(BaseModel):
|
||||
"""Typed data structure for block error statistics."""
|
||||
|
||||
block_id: str
|
||||
total_executions: int
|
||||
failed_executions: int
|
||||
|
||||
@property
|
||||
def error_rate(self) -> float:
|
||||
"""Calculate error rate as a percentage."""
|
||||
if self.total_executions == 0:
|
||||
return 0.0
|
||||
return (self.failed_executions / self.total_executions) * 100
|
||||
|
||||
|
||||
ExecutionStatus = AgentExecutionStatus
|
||||
|
||||
|
||||
@@ -357,6 +372,7 @@ async def get_graph_executions(
|
||||
created_time_lte: datetime | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[GraphExecutionMeta]:
|
||||
"""⚠️ **Optional `user_id` check**: MUST USE check in user-facing endpoints."""
|
||||
where_filter: AgentGraphExecutionWhereInput = {
|
||||
"isDeleted": False,
|
||||
}
|
||||
@@ -722,6 +738,7 @@ async def delete_graph_execution(
|
||||
|
||||
|
||||
async def get_node_execution(node_exec_id: str) -> NodeExecutionResult | None:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
execution = await AgentNodeExecution.prisma().find_first(
|
||||
where={"id": node_exec_id},
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
@@ -732,15 +749,19 @@ async def get_node_execution(node_exec_id: str) -> NodeExecutionResult | None:
|
||||
|
||||
|
||||
async def get_node_executions(
|
||||
graph_exec_id: str,
|
||||
graph_exec_id: str | None = None,
|
||||
node_id: str | None = None,
|
||||
block_ids: list[str] | None = None,
|
||||
statuses: list[ExecutionStatus] | None = None,
|
||||
limit: int | None = None,
|
||||
created_time_gte: datetime | None = None,
|
||||
created_time_lte: datetime | None = None,
|
||||
include_exec_data: bool = True,
|
||||
) -> list[NodeExecutionResult]:
|
||||
where_clause: AgentNodeExecutionWhereInput = {
|
||||
"agentGraphExecutionId": graph_exec_id,
|
||||
}
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
where_clause: AgentNodeExecutionWhereInput = {}
|
||||
if graph_exec_id:
|
||||
where_clause["agentGraphExecutionId"] = graph_exec_id
|
||||
if node_id:
|
||||
where_clause["agentNodeId"] = node_id
|
||||
if block_ids:
|
||||
@@ -748,9 +769,19 @@ async def get_node_executions(
|
||||
if statuses:
|
||||
where_clause["OR"] = [{"executionStatus": status} for status in statuses]
|
||||
|
||||
if created_time_gte or created_time_lte:
|
||||
where_clause["addedTime"] = {
|
||||
"gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc),
|
||||
"lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc),
|
||||
}
|
||||
|
||||
executions = await AgentNodeExecution.prisma().find_many(
|
||||
where=where_clause,
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
include=(
|
||||
EXECUTION_RESULT_INCLUDE
|
||||
if include_exec_data
|
||||
else {"Node": True, "GraphExecution": True}
|
||||
),
|
||||
order=EXECUTION_RESULT_ORDER,
|
||||
take=limit,
|
||||
)
|
||||
@@ -761,6 +792,7 @@ async def get_node_executions(
|
||||
async def get_latest_node_execution(
|
||||
node_id: str, graph_eid: str
|
||||
) -> NodeExecutionResult | None:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
execution = await AgentNodeExecution.prisma().find_first(
|
||||
where={
|
||||
"agentGraphExecutionId": graph_eid,
|
||||
@@ -963,3 +995,33 @@ async def set_execution_kv_data(
|
||||
},
|
||||
)
|
||||
return type_utils.convert(resp.data, type[Any]) if resp and resp.data else None
|
||||
|
||||
|
||||
async def get_block_error_stats(
|
||||
start_time: datetime, end_time: datetime
|
||||
) -> list[BlockErrorStats]:
|
||||
"""Get block execution stats using efficient SQL aggregation."""
|
||||
|
||||
query_template = """
|
||||
SELECT
|
||||
n."agentBlockId" as block_id,
|
||||
COUNT(*) as total_executions,
|
||||
SUM(CASE WHEN ne."executionStatus" = 'FAILED' THEN 1 ELSE 0 END) as failed_executions
|
||||
FROM {schema_prefix}"AgentNodeExecution" ne
|
||||
JOIN {schema_prefix}"AgentNode" n ON ne."agentNodeId" = n.id
|
||||
WHERE ne."addedTime" >= $1::timestamp AND ne."addedTime" <= $2::timestamp
|
||||
GROUP BY n."agentBlockId"
|
||||
HAVING COUNT(*) >= 10
|
||||
"""
|
||||
|
||||
result = await query_raw_with_schema(query_template, start_time, end_time)
|
||||
|
||||
# Convert to typed data structures
|
||||
return [
|
||||
BlockErrorStats(
|
||||
block_id=row["block_id"],
|
||||
total_executions=int(row["total_executions"]),
|
||||
failed_executions=int(row["failed_executions"]),
|
||||
)
|
||||
for row in result
|
||||
]
|
||||
|
||||
@@ -3,7 +3,6 @@ import uuid
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
|
||||
import prisma
|
||||
from prisma import Json
|
||||
from prisma.enums import SubmissionStatus
|
||||
from prisma.models import AgentGraph, AgentNode, AgentNodeLink, StoreListingVersion
|
||||
@@ -14,7 +13,7 @@ from prisma.types import (
|
||||
AgentNodeLinkCreateInput,
|
||||
StoreListingVersionWhereInput,
|
||||
)
|
||||
from pydantic import JsonValue, create_model
|
||||
from pydantic import Field, JsonValue, create_model
|
||||
from pydantic.fields import computed_field
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
@@ -31,7 +30,7 @@ from backend.integrations.providers import ProviderName
|
||||
from backend.util import type as type_utils
|
||||
|
||||
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
|
||||
from .db import BaseDbModel, transaction
|
||||
from .db import BaseDbModel, query_raw_with_schema, transaction
|
||||
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -189,6 +188,23 @@ class BaseGraph(BaseDbModel):
|
||||
)
|
||||
)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def has_external_trigger(self) -> bool:
|
||||
return self.webhook_input_node is not None
|
||||
|
||||
@property
|
||||
def webhook_input_node(self) -> Node | None:
|
||||
return next(
|
||||
(
|
||||
node
|
||||
for node in self.nodes
|
||||
if node.block.block_type
|
||||
in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _generate_schema(
|
||||
*props: tuple[type[AgentInputBlock.Input] | type[AgentOutputBlock.Input], dict],
|
||||
@@ -326,11 +342,6 @@ class GraphModel(Graph):
|
||||
user_id: str
|
||||
nodes: list[NodeModel] = [] # type: ignore
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def has_webhook_trigger(self) -> bool:
|
||||
return self.webhook_input_node is not None
|
||||
|
||||
@property
|
||||
def starting_nodes(self) -> list[NodeModel]:
|
||||
outbound_nodes = {link.sink_id for link in self.links}
|
||||
@@ -343,17 +354,12 @@ class GraphModel(Graph):
|
||||
if node.id not in outbound_nodes or node.id in input_nodes
|
||||
]
|
||||
|
||||
@property
|
||||
def webhook_input_node(self) -> NodeModel | None:
|
||||
return next(
|
||||
(
|
||||
node
|
||||
for node in self.nodes
|
||||
if node.block.block_type
|
||||
in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
|
||||
),
|
||||
None,
|
||||
)
|
||||
def meta(self) -> "GraphMeta":
|
||||
"""
|
||||
Returns a GraphMeta object with metadata about the graph.
|
||||
This is used to return metadata about the graph without exposing nodes and links.
|
||||
"""
|
||||
return GraphMeta.from_graph(self)
|
||||
|
||||
def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
|
||||
"""
|
||||
@@ -612,6 +618,18 @@ class GraphModel(Graph):
|
||||
)
|
||||
|
||||
|
||||
class GraphMeta(Graph):
|
||||
user_id: str
|
||||
|
||||
# Easy work-around to prevent exposing nodes and links in the API response
|
||||
nodes: list[NodeModel] = Field(default=[], exclude=True) # type: ignore
|
||||
links: list[Link] = Field(default=[], exclude=True)
|
||||
|
||||
@staticmethod
|
||||
def from_graph(graph: GraphModel) -> "GraphMeta":
|
||||
return GraphMeta(**graph.model_dump())
|
||||
|
||||
|
||||
# --------------------- CRUD functions --------------------- #
|
||||
|
||||
|
||||
@@ -640,10 +658,10 @@ async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:
|
||||
return NodeModel.from_db(node)
|
||||
|
||||
|
||||
async def get_graphs(
|
||||
async def list_graphs(
|
||||
user_id: str,
|
||||
filter_by: Literal["active"] | None = "active",
|
||||
) -> list[GraphModel]:
|
||||
) -> list[GraphMeta]:
|
||||
"""
|
||||
Retrieves graph metadata objects.
|
||||
Default behaviour is to get all currently active graphs.
|
||||
@@ -653,7 +671,7 @@ async def get_graphs(
|
||||
user_id: The ID of the user that owns the graph.
|
||||
|
||||
Returns:
|
||||
list[GraphModel]: A list of objects representing the retrieved graphs.
|
||||
list[GraphMeta]: A list of objects representing the retrieved graphs.
|
||||
"""
|
||||
where_clause: AgentGraphWhereInput = {"userId": user_id}
|
||||
|
||||
@@ -667,13 +685,13 @@ async def get_graphs(
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
)
|
||||
|
||||
graph_models = []
|
||||
graph_models: list[GraphMeta] = []
|
||||
for graph in graphs:
|
||||
try:
|
||||
graph_model = GraphModel.from_db(graph)
|
||||
# Trigger serialization to validate that the graph is well formed.
|
||||
graph_model.model_dump()
|
||||
graph_models.append(graph_model)
|
||||
graph_meta = GraphModel.from_db(graph).meta()
|
||||
# Trigger serialization to validate that the graph is well formed
|
||||
graph_meta.model_dump()
|
||||
graph_models.append(graph_meta)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing graph {graph.id}: {e}")
|
||||
continue
|
||||
@@ -1040,13 +1058,13 @@ async def fix_llm_provider_credentials():
|
||||
|
||||
broken_nodes = []
|
||||
try:
|
||||
broken_nodes = await prisma.get_client().query_raw(
|
||||
broken_nodes = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT graph."userId" user_id,
|
||||
node.id node_id,
|
||||
node."constantInput" node_preset_input
|
||||
FROM platform."AgentNode" node
|
||||
LEFT JOIN platform."AgentGraph" graph
|
||||
FROM {schema_prefix}"AgentNode" node
|
||||
LEFT JOIN {schema_prefix}"AgentGraph" graph
|
||||
ON node."agentGraphId" = graph.id
|
||||
WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm'
|
||||
ORDER BY graph."userId";
|
||||
|
||||
@@ -42,6 +42,9 @@ from pydantic_core import (
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.settings import Secrets
|
||||
|
||||
# Type alias for any provider name (including custom ones)
|
||||
AnyProviderName = str # Will be validated as ProviderName at runtime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.block import BlockSchema
|
||||
|
||||
@@ -341,7 +344,7 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
type: CT
|
||||
|
||||
@classmethod
|
||||
def allowed_providers(cls) -> tuple[ProviderName, ...]:
|
||||
def allowed_providers(cls) -> tuple[ProviderName, ...] | None:
|
||||
return get_args(cls.model_fields["provider"].annotation)
|
||||
|
||||
@classmethod
|
||||
@@ -366,7 +369,12 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
f"{field_schema}"
|
||||
) from e
|
||||
|
||||
if len(cls.allowed_providers()) > 1 and not schema_extra.discriminator:
|
||||
providers = cls.allowed_providers()
|
||||
if (
|
||||
providers is not None
|
||||
and len(providers) > 1
|
||||
and not schema_extra.discriminator
|
||||
):
|
||||
raise TypeError(
|
||||
f"Multi-provider CredentialsField '{field_name}' "
|
||||
"requires discriminator!"
|
||||
@@ -378,7 +386,12 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
if hasattr(model_class, "allowed_providers") and hasattr(
|
||||
model_class, "allowed_cred_types"
|
||||
):
|
||||
schema["credentials_provider"] = model_class.allowed_providers()
|
||||
allowed_providers = model_class.allowed_providers()
|
||||
# If no specific providers (None), allow any string
|
||||
if allowed_providers is None:
|
||||
schema["credentials_provider"] = ["string"] # Allow any string provider
|
||||
else:
|
||||
schema["credentials_provider"] = allowed_providers
|
||||
schema["credentials_types"] = model_class.allowed_cred_types()
|
||||
# Do not return anything, just mutate schema in place
|
||||
|
||||
@@ -540,6 +553,11 @@ def CredentialsField(
|
||||
if v is not None
|
||||
}
|
||||
|
||||
# Merge any json_schema_extra passed in kwargs
|
||||
if "json_schema_extra" in kwargs:
|
||||
extra_schema = kwargs.pop("json_schema_extra")
|
||||
field_schema_extra.update(extra_schema)
|
||||
|
||||
return Field(
|
||||
title=title,
|
||||
description=description,
|
||||
@@ -618,6 +636,35 @@ class NodeExecutionStats(BaseModel):
|
||||
llm_retry_count: int = 0
|
||||
input_token_count: int = 0
|
||||
output_token_count: int = 0
|
||||
extra_cost: int = 0
|
||||
extra_steps: int = 0
|
||||
|
||||
def __iadd__(self, other: "NodeExecutionStats") -> "NodeExecutionStats":
|
||||
"""Mutate this instance by adding another NodeExecutionStats."""
|
||||
if not isinstance(other, NodeExecutionStats):
|
||||
return NotImplemented
|
||||
|
||||
stats_dict = other.model_dump()
|
||||
current_stats = self.model_dump()
|
||||
|
||||
for key, value in stats_dict.items():
|
||||
if key not in current_stats:
|
||||
# Field doesn't exist yet, just set it
|
||||
setattr(self, key, value)
|
||||
elif isinstance(value, dict) and isinstance(current_stats[key], dict):
|
||||
current_stats[key].update(value)
|
||||
setattr(self, key, current_stats[key])
|
||||
elif isinstance(value, (int, float)) and isinstance(
|
||||
current_stats[key], (int, float)
|
||||
):
|
||||
setattr(self, key, current_stats[key] + value)
|
||||
elif isinstance(value, list) and isinstance(current_stats[key], list):
|
||||
current_stats[key].extend(value)
|
||||
setattr(self, key, current_stats[key])
|
||||
else:
|
||||
setattr(self, key, value)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class GraphExecutionStats(BaseModel):
|
||||
|
||||
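As a rough illustration (not part of the diff) of how the `__iadd__` merge above accumulates per-node stats; the field values are invented, and the remaining fields are assumed to have defaults:

# Illustrative only: numeric fields are summed, dict/list fields are merged.
total = NodeExecutionStats(input_token_count=120, output_token_count=40, extra_cost=2)
total += NodeExecutionStats(input_token_count=80, output_token_count=30, extra_steps=1)
assert total.input_token_count == 200
assert total.output_token_count == 70
assert total.extra_cost == 2 and total.extra_steps == 1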
@@ -5,6 +5,7 @@ from backend.data import db
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
from backend.data.execution import (
|
||||
create_graph_execution,
|
||||
get_block_error_stats,
|
||||
get_execution_kv_data,
|
||||
get_graph_execution,
|
||||
get_graph_execution_meta,
|
||||
@@ -105,6 +106,7 @@ class DatabaseManager(AppService):
|
||||
upsert_execution_output = _(upsert_execution_output)
|
||||
get_execution_kv_data = _(get_execution_kv_data)
|
||||
set_execution_kv_data = _(set_execution_kv_data)
|
||||
get_block_error_stats = _(get_block_error_stats)
|
||||
|
||||
# Graphs
|
||||
get_node = _(get_node)
|
||||
@@ -199,6 +201,9 @@ class DatabaseManagerClient(AppServiceClient):
|
||||
d.get_user_notification_oldest_message_in_batch
|
||||
)
|
||||
|
||||
# Block error monitoring
|
||||
get_block_error_stats = _(d.get_block_error_stats)
|
||||
|
||||
|
||||
class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
d = DatabaseManager
|
||||
@@ -226,3 +231,4 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
update_user_integrations = d.update_user_integrations
|
||||
get_execution_kv_data = d.get_execution_kv_data
|
||||
set_execution_kv_data = d.set_execution_kv_data
|
||||
get_block_error_stats = d.get_block_error_stats
|
||||
|
||||
@@ -207,9 +207,7 @@ async def execute_node(
|
||||
|
||||
# Update execution stats
|
||||
if execution_stats is not None:
|
||||
execution_stats = execution_stats.model_copy(
|
||||
update=node_block.execution_stats.model_dump()
|
||||
)
|
||||
execution_stats += node_block.execution_stats
|
||||
execution_stats.input_size = input_size
|
||||
execution_stats.output_size = output_size
|
||||
|
||||
@@ -648,9 +646,10 @@ class Executor:
|
||||
return
|
||||
|
||||
nonlocal execution_stats
|
||||
execution_stats.node_count += 1
|
||||
execution_stats.node_count += 1 + result.extra_steps
|
||||
execution_stats.nodes_cputime += result.cputime
|
||||
execution_stats.nodes_walltime += result.walltime
|
||||
execution_stats.cost += result.extra_cost
|
||||
if (err := result.error) and isinstance(err, Exception):
|
||||
execution_stats.node_error_count += 1
|
||||
update_node_execution_status(
|
||||
@@ -877,6 +876,7 @@ class Executor:
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.RUNNING,
|
||||
],
|
||||
include_exec_data=False,
|
||||
)
|
||||
db_client.update_node_execution_status_batch(
|
||||
[node_exec.node_exec_id for node_exec in inflight_executions],
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
@@ -14,25 +13,23 @@ from apscheduler.schedulers.blocking import BlockingScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from dotenv import load_dotenv
|
||||
from prisma.enums import NotificationType
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import MetaData, create_engine
|
||||
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.execution import GraphExecutionWithNodes
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.notifications.notifications import NotificationManagerClient
|
||||
from backend.monitoring import (
|
||||
NotificationJobArgs,
|
||||
process_existing_batches,
|
||||
process_weekly_summary,
|
||||
report_block_error_rates,
|
||||
report_late_executions,
|
||||
)
|
||||
from backend.util.exceptions import NotAuthorizedError, NotFoundError
|
||||
from backend.util.logging import PrefixFilter
|
||||
from backend.util.metrics import sentry_capture_error
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
endpoint_to_async,
|
||||
expose,
|
||||
get_service_client,
|
||||
)
|
||||
from backend.util.service import AppService, AppServiceClient, endpoint_to_async, expose
|
||||
from backend.util.settings import Config
|
||||
|
||||
|
||||
@@ -71,11 +68,6 @@ def job_listener(event):
|
||||
logger.info(f"Job {event.job_id} completed successfully.")
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_notification_client():
|
||||
return get_service_client(NotificationManagerClient)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_event_loop():
|
||||
return asyncio.new_event_loop()
|
||||
@@ -89,7 +81,7 @@ async def _execute_graph(**kwargs):
|
||||
args = GraphExecutionJobArgs(**kwargs)
|
||||
try:
|
||||
logger.info(f"Executing recurring job for graph #{args.graph_id}")
|
||||
await execution_utils.add_graph_execution(
|
||||
graph_exec: GraphExecutionWithNodes = await execution_utils.add_graph_execution(
|
||||
user_id=args.user_id,
|
||||
graph_id=args.graph_id,
|
||||
graph_version=args.graph_version,
|
||||
@@ -97,65 +89,14 @@ async def _execute_graph(**kwargs):
|
||||
graph_credentials_inputs=args.input_credentials,
|
||||
use_db_query=False,
|
||||
)
|
||||
logger.info(
|
||||
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing graph {args.graph_id}: {e}")
|
||||
|
||||
|
||||
class LateExecutionException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def report_late_executions() -> str:
|
||||
late_executions = execution_utils.get_db_client().get_graph_executions(
|
||||
statuses=[ExecutionStatus.QUEUED],
|
||||
created_time_gte=datetime.now(timezone.utc)
|
||||
- timedelta(seconds=config.execution_late_notification_checkrange_secs),
|
||||
created_time_lte=datetime.now(timezone.utc)
|
||||
- timedelta(seconds=config.execution_late_notification_threshold_secs),
|
||||
limit=1000,
|
||||
)
|
||||
|
||||
if not late_executions:
|
||||
return "No late executions detected."
|
||||
|
||||
num_late_executions = len(late_executions)
|
||||
num_users = len(set([r.user_id for r in late_executions]))
|
||||
|
||||
late_execution_details = [
|
||||
f"* `Execution ID: {exec.id}, Graph ID: {exec.graph_id}v{exec.graph_version}, User ID: {exec.user_id}, Created At: {exec.started_at.isoformat()}`"
|
||||
for exec in late_executions
|
||||
]
|
||||
|
||||
error = LateExecutionException(
|
||||
f"Late executions detected: {num_late_executions} late executions from {num_users} users "
|
||||
f"in the last {config.execution_late_notification_checkrange_secs} seconds. "
|
||||
f"Graph has been queued for more than {config.execution_late_notification_threshold_secs} seconds. "
|
||||
"Please check the executor status. Details:\n"
|
||||
+ "\n".join(late_execution_details)
|
||||
)
|
||||
msg = str(error)
|
||||
sentry_capture_error(error)
|
||||
get_notification_client().discord_system_alert(msg)
|
||||
return msg
|
||||
|
||||
|
||||
def process_existing_batches(**kwargs):
|
||||
args = NotificationJobArgs(**kwargs)
|
||||
try:
|
||||
logger.info(
|
||||
f"Processing existing batches for notification type {args.notification_types}"
|
||||
)
|
||||
get_notification_client().process_existing_batches(args.notification_types)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing existing batches: {e}")
|
||||
|
||||
|
||||
def process_weekly_summary(**kwargs):
|
||||
try:
|
||||
logger.info("Processing weekly summary")
|
||||
get_notification_client().queue_weekly_summary()
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing weekly summary: {e}")
|
||||
# Monitoring functions are now imported from monitoring module
|
||||
|
||||
|
||||
class Jobstores(Enum):
|
||||
@@ -190,11 +131,6 @@ class GraphExecutionJobInfo(GraphExecutionJobArgs):
|
||||
)
|
||||
|
||||
|
||||
class NotificationJobArgs(BaseModel):
|
||||
notification_types: list[NotificationType]
|
||||
cron: str
|
||||
|
||||
|
||||
class NotificationJobInfo(NotificationJobArgs):
|
||||
id: str
|
||||
name: str
|
||||
@@ -287,6 +223,16 @@ class Scheduler(AppService):
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
# Block Error Rate Monitoring
|
||||
self.scheduler.add_job(
|
||||
report_block_error_rates,
|
||||
id="report_block_error_rates",
|
||||
trigger="interval",
|
||||
replace_existing=True,
|
||||
seconds=config.block_error_rate_check_interval_secs,
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
|
||||
self.scheduler.start()
|
||||
|
||||
@@ -379,6 +325,10 @@ class Scheduler(AppService):
|
||||
def execute_report_late_executions(self):
|
||||
return report_late_executions()
|
||||
|
||||
@expose
|
||||
def execute_report_block_error_rates(self):
|
||||
return report_block_error_rates()
|
||||
|
||||
|
||||
class SchedulerClient(AppServiceClient):
|
||||
@classmethod
|
||||
|
||||
@@ -731,6 +731,7 @@ async def stop_graph_execution(
|
||||
node_execs = await db.get_node_executions(
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=[ExecutionStatus.QUEUED, ExecutionStatus.INCOMPLETE],
|
||||
include_exec_data=False,
|
||||
)
|
||||
await db.update_node_execution_status_batch(
|
||||
[node_exec.node_exec_id for node_exec in node_execs],
|
||||
|
||||
@@ -1,29 +1,226 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.integrations.oauth.todoist import TodoistOAuthHandler
|
||||
|
||||
from .github import GitHubOAuthHandler
|
||||
from .google import GoogleOAuthHandler
|
||||
from .linear import LinearOAuthHandler
|
||||
from .notion import NotionOAuthHandler
|
||||
from .twitter import TwitterOAuthHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..providers import ProviderName
|
||||
from .base import BaseOAuthHandler
|
||||
|
||||
# --8<-- [start:HANDLERS_BY_NAMEExample]
|
||||
HANDLERS_BY_NAME: dict["ProviderName", type["BaseOAuthHandler"]] = {
|
||||
handler.PROVIDER_NAME: handler
|
||||
for handler in [
|
||||
GitHubOAuthHandler,
|
||||
GoogleOAuthHandler,
|
||||
NotionOAuthHandler,
|
||||
TwitterOAuthHandler,
|
||||
LinearOAuthHandler,
|
||||
TodoistOAuthHandler,
|
||||
]
|
||||
# Build handlers dict with string keys for compatibility with SDK auto-registration
|
||||
_ORIGINAL_HANDLERS = [
|
||||
GitHubOAuthHandler,
|
||||
GoogleOAuthHandler,
|
||||
NotionOAuthHandler,
|
||||
TwitterOAuthHandler,
|
||||
TodoistOAuthHandler,
|
||||
]
|
||||
|
||||
# Start with original handlers
|
||||
_handlers_dict = {
|
||||
(
|
||||
handler.PROVIDER_NAME.value
|
||||
if hasattr(handler.PROVIDER_NAME, "value")
|
||||
else str(handler.PROVIDER_NAME)
|
||||
): handler
|
||||
for handler in _ORIGINAL_HANDLERS
|
||||
}
|
||||
|
||||
|
||||
class SDKAwareCredentials(BaseModel):
|
||||
"""OAuth credentials configuration."""
|
||||
|
||||
use_secrets: bool = True
|
||||
client_id_env_var: Optional[str] = None
|
||||
client_secret_env_var: Optional[str] = None
|
||||
|
||||
|
||||
_credentials_by_provider = {}
|
||||
# Add default credentials for original handlers
|
||||
for handler in _ORIGINAL_HANDLERS:
|
||||
provider_name = (
|
||||
handler.PROVIDER_NAME.value
|
||||
if hasattr(handler.PROVIDER_NAME, "value")
|
||||
else str(handler.PROVIDER_NAME)
|
||||
)
|
||||
_credentials_by_provider[provider_name] = SDKAwareCredentials(
|
||||
use_secrets=True, client_id_env_var=None, client_secret_env_var=None
|
||||
)
|
||||
|
||||
|
||||
# Create a custom dict class that includes SDK handlers
|
||||
class SDKAwareHandlersDict(dict):
|
||||
"""Dictionary that automatically includes SDK-registered OAuth handlers."""
|
||||
|
||||
def __getitem__(self, key):
|
||||
# First try the original handlers
|
||||
if key in _handlers_dict:
|
||||
return _handlers_dict[key]
|
||||
|
||||
# Then try SDK handlers
|
||||
try:
|
||||
from backend.sdk import AutoRegistry
|
||||
|
||||
sdk_handlers = AutoRegistry.get_oauth_handlers()
|
||||
if key in sdk_handlers:
|
||||
return sdk_handlers[key]
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# If not found, raise KeyError
|
||||
raise KeyError(key)
|
||||
|
||||
def get(self, key, default=None):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def __contains__(self, key):
|
||||
if key in _handlers_dict:
|
||||
return True
|
||||
try:
|
||||
from backend.sdk import AutoRegistry
|
||||
|
||||
sdk_handlers = AutoRegistry.get_oauth_handlers()
|
||||
return key in sdk_handlers
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
def keys(self):
|
||||
# Combine all keys into a single dict and return its keys view
|
||||
combined = dict(_handlers_dict)
|
||||
try:
|
||||
from backend.sdk import AutoRegistry
|
||||
|
||||
sdk_handlers = AutoRegistry.get_oauth_handlers()
|
||||
combined.update(sdk_handlers)
|
||||
except ImportError:
|
||||
pass
|
||||
return combined.keys()
|
||||
|
||||
def values(self):
|
||||
combined = dict(_handlers_dict)
|
||||
try:
|
||||
from backend.sdk import AutoRegistry
|
||||
|
||||
sdk_handlers = AutoRegistry.get_oauth_handlers()
|
||||
combined.update(sdk_handlers)
|
||||
except ImportError:
|
||||
pass
|
||||
return combined.values()
|
||||
|
||||
def items(self):
|
||||
combined = dict(_handlers_dict)
|
||||
try:
|
||||
from backend.sdk import AutoRegistry
|
||||
|
||||
sdk_handlers = AutoRegistry.get_oauth_handlers()
|
||||
combined.update(sdk_handlers)
|
||||
except ImportError:
|
||||
pass
|
||||
return combined.items()
|
||||
|
||||
|
||||
class SDKAwareCredentialsDict(dict):
|
||||
"""Dictionary that automatically includes SDK-registered OAuth credentials."""
|
||||
|
||||
def __getitem__(self, key):
|
||||
# First try the original handlers
|
||||
if key in _credentials_by_provider:
|
||||
return _credentials_by_provider[key]
|
||||
|
||||
# Then try SDK credentials
|
||||
try:
|
||||
from backend.sdk import AutoRegistry
|
||||
|
||||
sdk_credentials = AutoRegistry.get_oauth_credentials()
|
||||
if key in sdk_credentials:
|
||||
# Convert from SDKOAuthCredentials to SDKAwareCredentials
|
||||
sdk_cred = sdk_credentials[key]
|
||||
return SDKAwareCredentials(
|
||||
use_secrets=sdk_cred.use_secrets,
|
||||
client_id_env_var=sdk_cred.client_id_env_var,
|
||||
client_secret_env_var=sdk_cred.client_secret_env_var,
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# If not found, raise KeyError
|
||||
raise KeyError(key)
|
||||
|
||||
def get(self, key, default=None):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def __contains__(self, key):
|
||||
if key in _credentials_by_provider:
|
||||
return True
|
||||
try:
|
||||
from backend.sdk import AutoRegistry
|
||||
|
||||
sdk_credentials = AutoRegistry.get_oauth_credentials()
|
||||
return key in sdk_credentials
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
def keys(self):
|
||||
# Combine all keys into a single dict and return its keys view
|
||||
combined = dict(_credentials_by_provider)
|
||||
try:
|
||||
from backend.sdk import AutoRegistry
|
||||
|
||||
sdk_credentials = AutoRegistry.get_oauth_credentials()
|
||||
combined.update(sdk_credentials)
|
||||
except ImportError:
|
||||
pass
|
||||
return combined.keys()
|
||||
|
||||
def values(self):
|
||||
combined = dict(_credentials_by_provider)
|
||||
try:
|
||||
from backend.sdk import AutoRegistry
|
||||
|
||||
sdk_credentials = AutoRegistry.get_oauth_credentials()
|
||||
# Convert SDK credentials to SDKAwareCredentials
|
||||
for key, sdk_cred in sdk_credentials.items():
|
||||
combined[key] = SDKAwareCredentials(
|
||||
use_secrets=sdk_cred.use_secrets,
|
||||
client_id_env_var=sdk_cred.client_id_env_var,
|
||||
client_secret_env_var=sdk_cred.client_secret_env_var,
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
return combined.values()
|
||||
|
||||
def items(self):
|
||||
combined = dict(_credentials_by_provider)
|
||||
try:
|
||||
from backend.sdk import AutoRegistry
|
||||
|
||||
sdk_credentials = AutoRegistry.get_oauth_credentials()
|
||||
# Convert SDK credentials to SDKAwareCredentials
|
||||
for key, sdk_cred in sdk_credentials.items():
|
||||
combined[key] = SDKAwareCredentials(
|
||||
use_secrets=sdk_cred.use_secrets,
|
||||
client_id_env_var=sdk_cred.client_id_env_var,
|
||||
client_secret_env_var=sdk_cred.client_secret_env_var,
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
return combined.items()
|
||||
|
||||
|
||||
HANDLERS_BY_NAME: dict[str, type["BaseOAuthHandler"]] = SDKAwareHandlersDict()
|
||||
CREDENTIALS_BY_PROVIDER: dict[str, SDKAwareCredentials] = SDKAwareCredentialsDict()
|
||||
# --8<-- [end:HANDLERS_BY_NAMEExample]
|
||||
|
||||
__all__ = ["HANDLERS_BY_NAME"]
|
||||
|
||||
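A brief sketch (not part of the diff) of how the fallback lookup behaves; "my_service" is an invented provider that would only exist if an SDK block registered it:

# Illustrative only: built-in handlers win, then AutoRegistry-registered ones are consulted.
github_handler = HANDLERS_BY_NAME["github"]           # resolved from the original handlers
custom_handler = HANDLERS_BY_NAME.get("my_service")   # falls back to AutoRegistry.get_oauth_handlers()
if custom_handler is None:
    print("no OAuth handler registered for my_service")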
@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseOAuthHandler(ABC):
|
||||
# --8<-- [start:BaseOAuthHandler1]
|
||||
PROVIDER_NAME: ClassVar[ProviderName]
|
||||
PROVIDER_NAME: ClassVar[ProviderName | str]
|
||||
DEFAULT_SCOPES: ClassVar[list[str]] = []
|
||||
# --8<-- [end:BaseOAuthHandler1]
|
||||
|
||||
@@ -81,8 +81,6 @@ class BaseOAuthHandler(ABC):
|
||||
"""Handles the default scopes for the provider"""
|
||||
# If scopes are empty, use the default scopes for the provider
|
||||
if not scopes:
|
||||
logger.debug(
|
||||
f"Using default scopes for provider {self.PROVIDER_NAME.value}"
|
||||
)
|
||||
logger.debug(f"Using default scopes for provider {str(self.PROVIDER_NAME)}")
|
||||
scopes = self.DEFAULT_SCOPES
|
||||
return scopes
|
||||
|
||||
@@ -1,8 +1,16 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
# --8<-- [start:ProviderName]
|
||||
class ProviderName(str, Enum):
|
||||
"""
|
||||
Provider names for integrations.
|
||||
|
||||
This enum extends str to accept any string value while maintaining
|
||||
backward compatibility with existing provider constants.
|
||||
"""
|
||||
|
||||
AIML_API = "aiml_api"
|
||||
ANTHROPIC = "anthropic"
|
||||
APOLLO = "apollo"
|
||||
@@ -10,9 +18,7 @@ class ProviderName(str, Enum):
|
||||
DISCORD = "discord"
|
||||
D_ID = "d_id"
|
||||
E2B = "e2b"
|
||||
EXA = "exa"
|
||||
FAL = "fal"
|
||||
GENERIC_WEBHOOK = "generic_webhook"
|
||||
GITHUB = "github"
|
||||
GOOGLE = "google"
|
||||
GOOGLE_MAPS = "google_maps"
|
||||
@@ -21,7 +27,6 @@ class ProviderName(str, Enum):
|
||||
HUBSPOT = "hubspot"
|
||||
IDEOGRAM = "ideogram"
|
||||
JINA = "jina"
|
||||
LINEAR = "linear"
|
||||
LLAMA_API = "llama_api"
|
||||
MEDIUM = "medium"
|
||||
MEM0 = "mem0"
|
||||
@@ -43,4 +48,57 @@ class ProviderName(str, Enum):
|
||||
TODOIST = "todoist"
|
||||
UNREAL_SPEECH = "unreal_speech"
|
||||
ZEROBOUNCE = "zerobounce"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value: Any) -> "ProviderName":
|
||||
"""
|
||||
Allow any string value to be used as a ProviderName.
|
||||
This enables SDK users to define custom providers without
|
||||
modifying the enum.
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
# Create a pseudo-member that behaves like an enum member
|
||||
pseudo_member = str.__new__(cls, value)
|
||||
pseudo_member._name_ = value.upper()
|
||||
pseudo_member._value_ = value
|
||||
return pseudo_member
|
||||
return None # type: ignore
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(cls, schema, handler):
|
||||
"""
|
||||
Custom JSON schema generation that allows any string value,
|
||||
not just the predefined enum values.
|
||||
"""
|
||||
# Get the default schema
|
||||
json_schema = handler(schema)
|
||||
|
||||
# Remove the enum constraint to allow any string
|
||||
if "enum" in json_schema:
|
||||
del json_schema["enum"]
|
||||
|
||||
# Keep the type as string
|
||||
json_schema["type"] = "string"
|
||||
|
||||
# Update description to indicate custom providers are allowed
|
||||
json_schema["description"] = (
|
||||
"Provider name for integrations. "
|
||||
"Can be any string value, including custom provider names."
|
||||
)
|
||||
|
||||
return json_schema
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type, handler):
|
||||
"""
|
||||
Pydantic v2 core schema that allows any string value.
|
||||
"""
|
||||
from pydantic_core import core_schema
|
||||
|
||||
# Create a string schema that validates any string
|
||||
return core_schema.no_info_after_validator_function(
|
||||
cls,
|
||||
core_schema.str_schema(),
|
||||
)
|
||||
|
||||
# --8<-- [end:ProviderName]
|
||||
|
||||
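A short sketch (not part of the diff) of what the `_missing_` hook and relaxed pydantic schema enable; "my_custom_service" is a made-up name:

# Illustrative only: unknown strings become pseudo enum members instead of raising ValueError.
custom = ProviderName("my_custom_service")
assert custom.value == "my_custom_service"
assert ProviderName.GITHUB.value == "github"  # predefined members keep working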
@@ -12,7 +12,6 @@ def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]
|
||||
webhook_managers = {}
|
||||
|
||||
from .compass import CompassWebhookManager
|
||||
from .generic import GenericWebhooksManager
|
||||
from .github import GithubWebhooksManager
|
||||
from .slant3d import Slant3DWebhooksManager
|
||||
|
||||
@@ -23,7 +22,6 @@ def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]
|
||||
CompassWebhookManager,
|
||||
GithubWebhooksManager,
|
||||
Slant3DWebhooksManager,
|
||||
GenericWebhooksManager,
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
autogpt_platform/backend/backend/monitoring/__init__.py (new file, 24 lines)
@@ -0,0 +1,24 @@
"""Monitoring module for platform health and alerting."""

from .block_error_monitor import BlockErrorMonitor, report_block_error_rates
from .late_execution_monitor import (
    LateExecutionException,
    LateExecutionMonitor,
    report_late_executions,
)
from .notification_monitor import (
    NotificationJobArgs,
    process_existing_batches,
    process_weekly_summary,
)

__all__ = [
    "BlockErrorMonitor",
    "LateExecutionMonitor",
    "LateExecutionException",
    "NotificationJobArgs",
    "report_block_error_rates",
    "report_late_executions",
    "process_existing_batches",
    "process_weekly_summary",
]
autogpt_platform/backend/backend/monitoring/block_error_monitor.py (new file, 291 lines)
@@ -0,0 +1,291 @@
|
||||
"""Block error rate monitoring module."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import get_block
|
||||
from backend.data.execution import ExecutionStatus, NodeExecutionResult
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.notifications.notifications import NotificationManagerClient
|
||||
from backend.util.metrics import sentry_capture_error
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = Config()
|
||||
|
||||
|
||||
class BlockStatsWithSamples(BaseModel):
|
||||
"""Enhanced block stats with error samples."""
|
||||
|
||||
block_id: str
|
||||
block_name: str
|
||||
total_executions: int
|
||||
failed_executions: int
|
||||
error_samples: list[str] = []
|
||||
|
||||
@property
|
||||
def error_rate(self) -> float:
|
||||
"""Calculate error rate as a percentage."""
|
||||
if self.total_executions == 0:
|
||||
return 0.0
|
||||
return (self.failed_executions / self.total_executions) * 100
|
||||
|
||||
|
||||
class BlockErrorMonitor:
|
||||
"""Monitor block error rates and send alerts when thresholds are exceeded."""
|
||||
|
||||
def __init__(self, include_top_blocks: int | None = None):
|
||||
self.config = config
|
||||
self.notification_client = get_service_client(NotificationManagerClient)
|
||||
self.include_top_blocks = (
|
||||
include_top_blocks
|
||||
if include_top_blocks is not None
|
||||
else config.block_error_include_top_blocks
|
||||
)
|
||||
|
||||
def check_block_error_rates(self) -> str:
|
||||
"""Check block error rates and send Discord alerts if thresholds are exceeded."""
|
||||
try:
|
||||
logger.info("Checking block error rates")
|
||||
|
||||
# Get executions from the last 24 hours
|
||||
end_time = datetime.now(timezone.utc)
|
||||
start_time = end_time - timedelta(hours=24)
|
||||
|
||||
# Use SQL aggregation to efficiently count totals and failures by block
|
||||
block_stats = self._get_block_stats_from_db(start_time, end_time)
|
||||
|
||||
# For blocks with high error rates, fetch error samples
|
||||
threshold = self.config.block_error_rate_threshold
|
||||
for block_name, stats in block_stats.items():
|
||||
if stats.total_executions >= 10 and stats.error_rate >= threshold * 100:
|
||||
# Only fetch error samples for blocks that exceed threshold
|
||||
error_samples = self._get_error_samples_for_block(
|
||||
stats.block_id, start_time, end_time, limit=3
|
||||
)
|
||||
stats.error_samples = error_samples
|
||||
|
||||
# Check thresholds and send alerts
|
||||
critical_alerts = self._generate_critical_alerts(block_stats, threshold)
|
||||
|
||||
if critical_alerts:
|
||||
msg = "Block Error Rate Alert:\n\n" + "\n\n".join(critical_alerts)
|
||||
self.notification_client.discord_system_alert(msg)
|
||||
logger.info(
|
||||
f"Sent block error rate alert for {len(critical_alerts)} blocks"
|
||||
)
|
||||
return f"Alert sent for {len(critical_alerts)} blocks with high error rates"
|
||||
|
||||
# If no critical alerts, check if we should show top blocks
|
||||
if self.include_top_blocks > 0:
|
||||
top_blocks_msg = self._generate_top_blocks_alert(
|
||||
block_stats, start_time, end_time
|
||||
)
|
||||
if top_blocks_msg:
|
||||
self.notification_client.discord_system_alert(top_blocks_msg)
|
||||
logger.info("Sent top blocks summary")
|
||||
return "Sent top blocks summary"
|
||||
|
||||
logger.info("No blocks exceeded error rate threshold")
|
||||
return "No errors reported for today"
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error checking block error rates: {e}")
|
||||
|
||||
error = Exception(f"Error checking block error rates: {e}")
|
||||
msg = str(error)
|
||||
sentry_capture_error(error)
|
||||
self.notification_client.discord_system_alert(msg)
|
||||
return msg
|
||||
|
||||
def _get_block_stats_from_db(
|
||||
self, start_time: datetime, end_time: datetime
|
||||
) -> dict[str, BlockStatsWithSamples]:
|
||||
"""Get block execution stats using efficient SQL aggregation."""
|
||||
|
||||
result = execution_utils.get_db_client().get_block_error_stats(
|
||||
start_time, end_time
|
||||
)
|
||||
|
||||
block_stats = {}
|
||||
for stats in result:
|
||||
block_name = b.name if (b := get_block(stats.block_id)) else "Unknown"
|
||||
|
||||
block_stats[block_name] = BlockStatsWithSamples(
|
||||
block_id=stats.block_id,
|
||||
block_name=block_name,
|
||||
total_executions=stats.total_executions,
|
||||
failed_executions=stats.failed_executions,
|
||||
error_samples=[],
|
||||
)
|
||||
|
||||
return block_stats
|
||||
|
||||
def _generate_critical_alerts(
|
||||
self, block_stats: dict[str, BlockStatsWithSamples], threshold: float
|
||||
) -> list[str]:
|
||||
"""Generate alerts for blocks that exceed the error rate threshold."""
|
||||
alerts = []
|
||||
|
||||
for block_name, stats in block_stats.items():
|
||||
if stats.total_executions >= 10 and stats.error_rate >= threshold * 100:
|
||||
error_groups = self._group_similar_errors(stats.error_samples)
|
||||
|
||||
alert_msg = (
|
||||
f"🚨 Block '{block_name}' has {stats.error_rate:.1f}% error rate "
|
||||
f"({stats.failed_executions}/{stats.total_executions}) in the last 24 hours"
|
||||
)
|
||||
|
||||
if error_groups:
|
||||
alert_msg += "\n\n📊 Error Types:"
|
||||
for error_pattern, count in error_groups.items():
|
||||
alert_msg += f"\n• {error_pattern} ({count}x)"
|
||||
|
||||
alerts.append(alert_msg)
|
||||
|
||||
return alerts
|
||||
|
||||
def _generate_top_blocks_alert(
|
||||
self,
|
||||
block_stats: dict[str, BlockStatsWithSamples],
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
) -> str | None:
|
||||
"""Generate top blocks summary when no critical alerts exist."""
|
||||
top_error_blocks = sorted(
|
||||
[
|
||||
(name, stats)
|
||||
for name, stats in block_stats.items()
|
||||
if stats.total_executions >= 10 and stats.failed_executions > 0
|
||||
],
|
||||
key=lambda x: x[1].failed_executions,
|
||||
reverse=True,
|
||||
)[: self.include_top_blocks]
|
||||
|
||||
if not top_error_blocks:
|
||||
return "✅ No errors reported for today - all blocks are running smoothly!"
|
||||
|
||||
# Get error samples for top blocks
|
||||
for block_name, stats in top_error_blocks:
|
||||
if not stats.error_samples:
|
||||
stats.error_samples = self._get_error_samples_for_block(
|
||||
stats.block_id, start_time, end_time, limit=2
|
||||
)
|
||||
|
||||
count_text = (
|
||||
f"top {self.include_top_blocks}" if self.include_top_blocks > 1 else "top"
|
||||
)
|
||||
alert_msg = f"📊 Daily Error Summary - {count_text} blocks with most errors:"
|
||||
for block_name, stats in top_error_blocks:
|
||||
alert_msg += f"\n• {block_name}: {stats.failed_executions} errors ({stats.error_rate:.1f}% of {stats.total_executions})"
|
||||
|
||||
if stats.error_samples:
|
||||
error_groups = self._group_similar_errors(stats.error_samples)
|
||||
if error_groups:
|
||||
# Show most common error
|
||||
most_common_error = next(iter(error_groups.items()))
|
||||
alert_msg += f"\n └ Most common: {most_common_error[0]}"
|
||||
|
||||
return alert_msg
|
||||
|
||||
def _get_error_samples_for_block(
|
||||
self, block_id: str, start_time: datetime, end_time: datetime, limit: int = 3
|
||||
) -> list[str]:
|
||||
"""Get error samples for a specific block - just a few recent ones."""
|
||||
# Only fetch a small number of recent failed executions for this specific block
|
||||
executions = execution_utils.get_db_client().get_node_executions(
|
||||
block_ids=[block_id],
|
||||
statuses=[ExecutionStatus.FAILED],
|
||||
created_time_gte=start_time,
|
||||
created_time_lte=end_time,
|
||||
limit=limit, # Just get the limit we need
|
||||
)
|
||||
|
||||
error_samples = []
|
||||
for execution in executions:
|
||||
if error_message := self._extract_error_message(execution):
|
||||
masked_error = self._mask_sensitive_data(error_message)
|
||||
error_samples.append(masked_error)
|
||||
|
||||
if len(error_samples) >= limit: # Stop once we have enough samples
|
||||
break
|
||||
|
||||
return error_samples
|
||||
|
||||
def _extract_error_message(self, execution: NodeExecutionResult) -> str | None:
|
||||
"""Extract error message from execution output."""
|
||||
try:
|
||||
if execution.output_data and (
|
||||
error_msg := execution.output_data.get("error")
|
||||
):
|
||||
return str(error_msg[0])
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _mask_sensitive_data(self, error_message):
|
||||
"""Mask sensitive data in error messages to enable grouping."""
|
||||
if not error_message:
|
||||
return ""
|
||||
|
||||
# Convert to string if not already
|
||||
error_str = str(error_message)
|
||||
|
||||
# Mask numbers (replace with X)
|
||||
error_str = re.sub(r"\d+", "X", error_str)
|
||||
|
||||
# Mask all caps words (likely constants/IDs)
|
||||
error_str = re.sub(r"\b[A-Z_]{3,}\b", "MASKED", error_str)
|
||||
|
||||
# Mask words with underscores (likely internal variables)
|
||||
error_str = re.sub(r"\b\w*_\w*\b", "MASKED", error_str)
|
||||
|
||||
# Mask UUIDs and long alphanumeric strings
|
||||
error_str = re.sub(
|
||||
r"\b[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}\b",
|
||||
"UUID",
|
||||
error_str,
|
||||
)
|
||||
error_str = re.sub(r"\b[a-f0-9]{20,}\b", "HASH", error_str)
|
||||
|
||||
# Mask file paths
|
||||
error_str = re.sub(r"(/[^/\s]+)+", "/MASKED/path", error_str)
|
||||
|
||||
# Mask URLs
|
||||
error_str = re.sub(r"https?://[^\s]+", "URL", error_str)
|
||||
|
||||
# Mask email addresses
|
||||
error_str = re.sub(
|
||||
r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", "EMAIL", error_str
|
||||
)
|
||||
|
||||
# Truncate if too long
|
||||
if len(error_str) > 100:
|
||||
error_str = error_str[:97] + "..."
|
||||
|
||||
return error_str.strip()
|
||||
|
||||
def _group_similar_errors(self, error_samples):
|
||||
"""Group similar error messages and return counts."""
|
||||
if not error_samples:
|
||||
return {}
|
||||
|
||||
error_groups = {}
|
||||
for error in error_samples:
|
||||
if error in error_groups:
|
||||
error_groups[error] += 1
|
||||
else:
|
||||
error_groups[error] = 1
|
||||
|
||||
# Sort by frequency, most common first
|
||||
return dict(sorted(error_groups.items(), key=lambda x: x[1], reverse=True))
|
||||
|
||||
|
||||
def report_block_error_rates(include_top_blocks: int | None = None):
|
||||
"""Check block error rates and send Discord alerts if thresholds are exceeded."""
|
||||
monitor = BlockErrorMonitor(include_top_blocks=include_top_blocks)
|
||||
return monitor.check_block_error_rates()
|
||||
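As a rough illustration (not part of the diff) of how masking lets similar failures group together; the error strings are invented, and constructing the monitor assumes the notification service client is reachable:

# Illustrative only: digits become "X", ALL-CAPS tokens become "MASKED".
monitor = BlockErrorMonitor(include_top_blocks=3)
masked = monitor._mask_sensitive_data("API returned 404 for item 12345")
# masked is roughly "MASKED returned X for item X"
groups = monitor._group_similar_errors([masked, masked, "timeout after X seconds"])
# groups is roughly {"MASKED returned X for item X": 2, "timeout after X seconds": 1}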
autogpt_platform/backend/backend/monitoring/late_execution_monitor.py (new file, 71 lines)
@@ -0,0 +1,71 @@

|
||||
"""Late execution monitoring module."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.notifications.notifications import NotificationManagerClient
|
||||
from backend.util.metrics import sentry_capture_error
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = Config()
|
||||
|
||||
|
||||
class LateExecutionException(Exception):
|
||||
"""Exception raised when late executions are detected."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LateExecutionMonitor:
|
||||
"""Monitor late executions and send alerts when thresholds are exceeded."""
|
||||
|
||||
def __init__(self):
|
||||
self.config = config
|
||||
self.notification_client = get_service_client(NotificationManagerClient)
|
||||
|
||||
def check_late_executions(self) -> str:
|
||||
"""Check for late executions and send alerts if found."""
|
||||
late_executions = execution_utils.get_db_client().get_graph_executions(
|
||||
statuses=[ExecutionStatus.QUEUED],
|
||||
created_time_gte=datetime.now(timezone.utc)
|
||||
- timedelta(
|
||||
seconds=self.config.execution_late_notification_checkrange_secs
|
||||
),
|
||||
created_time_lte=datetime.now(timezone.utc)
|
||||
- timedelta(seconds=self.config.execution_late_notification_threshold_secs),
|
||||
limit=1000,
|
||||
)
|
||||
|
||||
if not late_executions:
|
||||
return "No late executions detected."
|
||||
|
||||
num_late_executions = len(late_executions)
|
||||
num_users = len(set([r.user_id for r in late_executions]))
|
||||
|
||||
late_execution_details = [
|
||||
f"* `Execution ID: {exec.id}, Graph ID: {exec.graph_id}v{exec.graph_version}, User ID: {exec.user_id}, Created At: {exec.started_at.isoformat()}`"
|
||||
for exec in late_executions
|
||||
]
|
||||
|
||||
error = LateExecutionException(
|
||||
f"Late executions detected: {num_late_executions} late executions from {num_users} users "
|
||||
f"in the last {self.config.execution_late_notification_checkrange_secs} seconds. "
|
||||
f"Graph has been queued for more than {self.config.execution_late_notification_threshold_secs} seconds. "
|
||||
"Please check the executor status. Details:\n"
|
||||
+ "\n".join(late_execution_details)
|
||||
)
|
||||
msg = str(error)
|
||||
|
||||
sentry_capture_error(error)
|
||||
self.notification_client.discord_system_alert(msg)
|
||||
return msg
|
||||
|
||||
|
||||
def report_late_executions() -> str:
|
||||
"""Check for late executions and send Discord alerts if found."""
|
||||
monitor = LateExecutionMonitor()
|
||||
return monitor.check_late_executions()
|
||||
autogpt_platform/backend/backend/monitoring/notification_monitor.py (new file, 39 lines)
@@ -0,0 +1,39 @@
"""Notification processing monitoring module."""

import logging

from prisma.enums import NotificationType
from pydantic import BaseModel

from backend.notifications.notifications import NotificationManagerClient
from backend.util.service import get_service_client

logger = logging.getLogger(__name__)


class NotificationJobArgs(BaseModel):
    notification_types: list[NotificationType]
    cron: str


def process_existing_batches(**kwargs):
    """Process existing notification batches."""
    args = NotificationJobArgs(**kwargs)
    try:
        logging.info(
            f"Processing existing batches for notification type {args.notification_types}"
        )
        get_service_client(NotificationManagerClient).process_existing_batches(
            args.notification_types
        )
    except Exception as e:
        logger.exception(f"Error processing existing batches: {e}")


def process_weekly_summary(**kwargs):
    """Process weekly summary notifications."""
    try:
        logging.info("Processing weekly summary")
        get_service_client(NotificationManagerClient).queue_weekly_summary()
    except Exception as e:
        logger.exception(f"Error processing weekly summary: {e}")
autogpt_platform/backend/backend/sdk/__init__.py (new file, 169 lines)
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
AutoGPT Platform Block Development SDK
|
||||
|
||||
Complete re-export of all dependencies needed for block development.
|
||||
Usage: from backend.sdk import *
|
||||
|
||||
This module provides:
|
||||
- All block base classes and types
|
||||
- All credential and authentication components
|
||||
- All cost tracking components
|
||||
- All webhook components
|
||||
- All utility functions
|
||||
- Auto-registration decorators
|
||||
"""
|
||||
|
||||
# Third-party imports
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
# === CORE BLOCK SYSTEM ===
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockManualWebhookConfig,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
BlockWebhookConfig,
|
||||
)
|
||||
from backend.data.integrations import Webhook
|
||||
from backend.data.model import APIKeyCredentials, Credentials, CredentialsField
|
||||
from backend.data.model import CredentialsMetaInput as _CredentialsMetaInput
|
||||
from backend.data.model import (
|
||||
NodeExecutionStats,
|
||||
OAuth2Credentials,
|
||||
SchemaField,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
|
||||
# === INTEGRATIONS ===
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.sdk.builder import ProviderBuilder
|
||||
from backend.sdk.cost_integration import cost
|
||||
from backend.sdk.provider import Provider
|
||||
|
||||
# === NEW SDK COMPONENTS (imported early for patches) ===
|
||||
from backend.sdk.registry import AutoRegistry, BlockConfiguration
|
||||
|
||||
# === UTILITIES ===
|
||||
from backend.util import json
|
||||
from backend.util.request import Requests
|
||||
|
||||
# === OPTIONAL IMPORTS WITH TRY/EXCEPT ===
|
||||
# Webhooks
|
||||
try:
|
||||
from backend.integrations.webhooks._base import BaseWebhooksManager
|
||||
except ImportError:
|
||||
BaseWebhooksManager = None
|
||||
|
||||
try:
|
||||
from backend.integrations.webhooks._manual_base import ManualWebhookManagerBase
|
||||
except ImportError:
|
||||
ManualWebhookManagerBase = None
|
||||
|
||||
# Cost System
|
||||
try:
|
||||
from backend.data.cost import BlockCost, BlockCostType
|
||||
except ImportError:
|
||||
from backend.data.block_cost_config import BlockCost, BlockCostType
|
||||
|
||||
try:
|
||||
from backend.data.credit import UsageTransactionMetadata
|
||||
except ImportError:
|
||||
UsageTransactionMetadata = None
|
||||
|
||||
try:
|
||||
from backend.executor.utils import block_usage_cost
|
||||
except ImportError:
|
||||
block_usage_cost = None
|
||||
|
||||
# Utilities
|
||||
try:
|
||||
from backend.util.file import store_media_file
|
||||
except ImportError:
|
||||
store_media_file = None
|
||||
|
||||
try:
|
||||
from backend.util.type import MediaFileType, convert
|
||||
except ImportError:
|
||||
MediaFileType = None
|
||||
convert = None
|
||||
|
||||
try:
|
||||
from backend.util.text import TextFormatter
|
||||
except ImportError:
|
||||
TextFormatter = None
|
||||
|
||||
try:
|
||||
from backend.util.logging import TruncatedLogger
|
||||
except ImportError:
|
||||
TruncatedLogger = None
|
||||
|
||||
|
||||
# OAuth handlers
|
||||
try:
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
except ImportError:
|
||||
BaseOAuthHandler = None
|
||||
|
||||
|
||||
# Credential type with proper provider name
|
||||
from typing import Literal as _Literal
|
||||
|
||||
CredentialsMetaInput = _CredentialsMetaInput[
|
||||
ProviderName, _Literal["api_key", "oauth2", "user_password"]
|
||||
]
|
||||
|
||||
|
||||
# === COMPREHENSIVE __all__ EXPORT ===
|
||||
__all__ = [
|
||||
# Core Block System
|
||||
"Block",
|
||||
"BlockCategory",
|
||||
"BlockOutput",
|
||||
"BlockSchema",
|
||||
"BlockType",
|
||||
"BlockWebhookConfig",
|
||||
"BlockManualWebhookConfig",
|
||||
# Schema and Model Components
|
||||
"SchemaField",
|
||||
"Credentials",
|
||||
"CredentialsField",
|
||||
"CredentialsMetaInput",
|
||||
"APIKeyCredentials",
|
||||
"OAuth2Credentials",
|
||||
"UserPasswordCredentials",
|
||||
"NodeExecutionStats",
|
||||
# Cost System
|
||||
"BlockCost",
|
||||
"BlockCostType",
|
||||
"UsageTransactionMetadata",
|
||||
"block_usage_cost",
|
||||
# Integrations
|
||||
"ProviderName",
|
||||
"BaseWebhooksManager",
|
||||
"ManualWebhookManagerBase",
|
||||
"Webhook",
|
||||
# Provider-Specific (when available)
|
||||
"BaseOAuthHandler",
|
||||
# Utilities
|
||||
"json",
|
||||
"store_media_file",
|
||||
"MediaFileType",
|
||||
"convert",
|
||||
"TextFormatter",
|
||||
"TruncatedLogger",
|
||||
"BaseModel",
|
||||
"Field",
|
||||
"SecretStr",
|
||||
"Requests",
|
||||
# SDK Components
|
||||
"AutoRegistry",
|
||||
"BlockConfiguration",
|
||||
"Provider",
|
||||
"ProviderBuilder",
|
||||
"cost",
|
||||
]
|
||||
|
||||
# Remove None values from __all__
|
||||
__all__ = [name for name in __all__ if globals().get(name) is not None]
|
||||
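As a rough sketch of what these re-exports are meant for (a hypothetical block, not part of the diff; the id, names, and the exact Block constructor/run signature are approximations, so check an existing block for the canonical form):

# Illustrative only: a minimal block written against the SDK surface above.
from backend.sdk import Block, BlockCategory, BlockOutput, BlockSchema, SchemaField


class EchoBlock(Block):
    class Input(BlockSchema):
        text: str = SchemaField(description="Text to echo back")

    class Output(BlockSchema):
        text: str = SchemaField(description="The same text, unchanged")

    def __init__(self):
        super().__init__(
            id="00000000-0000-0000-0000-000000000000",  # placeholder UUID
            description="Echoes its input",
            categories={BlockCategory.BASIC},
            input_schema=EchoBlock.Input,
            output_schema=EchoBlock.Output,
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        yield "text", input_data.text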
autogpt_platform/backend/backend/sdk/builder.py (new file, 161 lines)
@@ -0,0 +1,161 @@
|
||||
"""
|
||||
Builder class for creating provider configurations with a fluent API.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Callable, List, Optional, Type
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.cost import BlockCost, BlockCostType
|
||||
from backend.data.model import APIKeyCredentials, Credentials, UserPasswordCredentials
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
from backend.integrations.webhooks._base import BaseWebhooksManager
|
||||
from backend.sdk.provider import OAuthConfig, Provider
|
||||
from backend.sdk.registry import AutoRegistry
|
||||
from backend.util.settings import Settings
|
||||
|
||||
|
||||
class ProviderBuilder:
|
||||
"""Builder for creating provider configurations."""
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self._oauth_config: Optional[OAuthConfig] = None
|
||||
self._webhook_manager: Optional[Type[BaseWebhooksManager]] = None
|
||||
self._default_credentials: List[Credentials] = []
|
||||
self._base_costs: List[BlockCost] = []
|
||||
self._supported_auth_types: set = set()
|
||||
self._api_client_factory: Optional[Callable] = None
|
||||
self._error_handler: Optional[Callable[[Exception], str]] = None
|
||||
self._default_scopes: Optional[List[str]] = None
|
||||
self._client_id_env_var: Optional[str] = None
|
||||
self._client_secret_env_var: Optional[str] = None
|
||||
self._extra_config: dict = {}
|
||||
|
||||
def with_oauth(
|
||||
self,
|
||||
handler_class: Type[BaseOAuthHandler],
|
||||
scopes: Optional[List[str]] = None,
|
||||
client_id_env_var: Optional[str] = None,
|
||||
client_secret_env_var: Optional[str] = None,
|
||||
) -> "ProviderBuilder":
|
||||
"""Add OAuth support."""
|
||||
self._oauth_config = OAuthConfig(
|
||||
oauth_handler=handler_class,
|
||||
scopes=scopes,
|
||||
client_id_env_var=client_id_env_var,
|
||||
client_secret_env_var=client_secret_env_var,
|
||||
)
|
||||
self._supported_auth_types.add("oauth2")
|
||||
return self
|
||||
|
||||
def with_api_key(self, env_var_name: str, title: str) -> "ProviderBuilder":
|
||||
"""Add API key support with environment variable name."""
|
||||
self._supported_auth_types.add("api_key")
|
||||
|
||||
# Register the API key mapping
|
||||
AutoRegistry.register_api_key(self.name, env_var_name)
|
||||
|
||||
# Check if API key exists in environment
|
||||
api_key = os.getenv(env_var_name)
|
||||
if api_key:
|
||||
self._default_credentials.append(
|
||||
APIKeyCredentials(
|
||||
id=f"{self.name}-default",
|
||||
provider=self.name,
|
||||
api_key=SecretStr(api_key),
|
||||
title=title,
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
def with_api_key_from_settings(
|
||||
self, settings_attr: str, title: str
|
||||
) -> "ProviderBuilder":
|
||||
"""Use existing API key from settings."""
|
||||
self._supported_auth_types.add("api_key")
|
||||
|
||||
# Try to get the API key from settings
|
||||
settings = Settings()
|
||||
api_key = getattr(settings.secrets, settings_attr, None)
|
||||
if api_key:
|
||||
self._default_credentials.append(
|
||||
APIKeyCredentials(
|
||||
id=f"{self.name}-default",
|
||||
provider=self.name,
|
||||
api_key=api_key,
|
||||
title=title,
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
def with_user_password(
|
||||
self, username_env_var: str, password_env_var: str, title: str
|
||||
) -> "ProviderBuilder":
|
||||
"""Add username/password support with environment variable names."""
|
||||
self._supported_auth_types.add("user_password")
|
||||
|
||||
# Check if credentials exist in environment
|
||||
username = os.getenv(username_env_var)
|
||||
password = os.getenv(password_env_var)
|
||||
if username and password:
|
||||
self._default_credentials.append(
|
||||
UserPasswordCredentials(
|
||||
id=f"{self.name}-default",
|
||||
provider=self.name,
|
||||
username=SecretStr(username),
|
||||
password=SecretStr(password),
|
||||
title=title,
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
def with_webhook_manager(
|
||||
self, manager_class: Type[BaseWebhooksManager]
|
||||
) -> "ProviderBuilder":
|
||||
"""Register webhook manager for this provider."""
|
||||
self._webhook_manager = manager_class
|
||||
return self
|
||||
|
||||
def with_base_cost(
|
||||
self, amount: int, cost_type: BlockCostType
|
||||
) -> "ProviderBuilder":
|
||||
"""Set base cost for all blocks using this provider."""
|
||||
self._base_costs.append(BlockCost(cost_amount=amount, cost_type=cost_type))
|
||||
return self
|
||||
|
||||
def with_api_client(self, factory: Callable) -> "ProviderBuilder":
|
||||
"""Register API client factory."""
|
||||
self._api_client_factory = factory
|
||||
return self
|
||||
|
||||
def with_error_handler(
|
||||
self, handler: Callable[[Exception], str]
|
||||
) -> "ProviderBuilder":
|
||||
"""Register error handler for provider-specific errors."""
|
||||
self._error_handler = handler
|
||||
return self
|
||||
|
||||
def with_config(self, **kwargs) -> "ProviderBuilder":
|
||||
"""Add additional configuration options."""
|
||||
self._extra_config.update(kwargs)
|
||||
return self
|
||||
|
||||
def build(self) -> Provider:
|
||||
"""Build and register the provider configuration."""
|
||||
provider = Provider(
|
||||
name=self.name,
|
||||
oauth_config=self._oauth_config,
|
||||
webhook_manager=self._webhook_manager,
|
||||
default_credentials=self._default_credentials,
|
||||
base_costs=self._base_costs,
|
||||
supported_auth_types=self._supported_auth_types,
|
||||
api_client_factory=self._api_client_factory,
|
||||
error_handler=self._error_handler,
|
||||
**self._extra_config,
|
||||
)
|
||||
|
||||
# Auto-registration happens here
|
||||
AutoRegistry.register_provider(provider)
|
||||
return provider
|
||||
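A hedged usage sketch (not part of the diff; "my_service" and the env var name are invented) of the fluent builder defined above:

# Illustrative only: register a provider with API-key auth and a flat per-run cost.
from backend.sdk import BlockCostType, ProviderBuilder

my_service = (
    ProviderBuilder("my_service")
    .with_api_key("MY_SERVICE_API_KEY", "My Service API Key")
    .with_base_cost(1, BlockCostType.RUN)
    .build()
)
# build() calls AutoRegistry.register_provider(), so blocks whose credentials field
# names "my_service" pick up the base cost once provider costs are synced.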
autogpt_platform/backend/backend/sdk/cost_integration.py (new file, 163 lines)
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
Integration between SDK provider costs and the execution cost system.
|
||||
|
||||
This module provides the glue between provider-defined base costs and the
|
||||
BLOCK_COSTS configuration used by the execution system.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Type
|
||||
|
||||
from backend.data.block import Block
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.cost import BlockCost
|
||||
from backend.sdk.registry import AutoRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def register_provider_costs_for_block(block_class: Type[Block]) -> None:
|
||||
"""
|
||||
Register provider base costs for a specific block in BLOCK_COSTS.
|
||||
|
||||
This function checks if the block uses credentials from a provider that has
|
||||
base costs defined, and automatically registers those costs for the block.
|
||||
|
||||
Args:
|
||||
block_class: The block class to register costs for
|
||||
"""
|
||||
# Skip if block already has custom costs defined
|
||||
if block_class in BLOCK_COSTS:
|
||||
logger.debug(
|
||||
f"Block {block_class.__name__} already has costs defined, skipping provider costs"
|
||||
)
|
||||
return
|
||||
|
||||
# Get the block's input schema
|
||||
# We need to instantiate the block to get its input schema
|
||||
try:
|
||||
block_instance = block_class()
|
||||
input_schema = block_instance.input_schema
|
||||
except Exception as e:
|
||||
logger.debug(f"Block {block_class.__name__} cannot be instantiated: {e}")
|
||||
return
|
||||
|
||||
# Look for credentials fields
|
||||
# The cost system works by filtering on credentials fields;
# without credentials fields, we cannot apply costs.
# TODO: Improve cost system to allow for costs without a provider
|
||||
credentials_fields = input_schema.get_credentials_fields()
|
||||
if not credentials_fields:
|
||||
logger.debug(f"Block {block_class.__name__} has no credentials fields")
|
||||
return
|
||||
|
||||
# Get provider information from credentials fields
|
||||
for field_name, field_info in credentials_fields.items():
|
||||
# Get the field schema to extract provider information
|
||||
field_schema = input_schema.get_field_schema(field_name)
|
||||
|
||||
# Extract provider names from json_schema_extra
|
||||
providers = field_schema.get("credentials_provider", [])
|
||||
if not providers:
|
||||
continue
|
||||
|
||||
# For each provider, check if it has base costs
|
||||
block_costs: List[BlockCost] = []
|
||||
for provider_name in providers:
|
||||
provider = AutoRegistry.get_provider(provider_name)
|
||||
if not provider:
|
||||
logger.debug(f"Provider {provider_name} not found in registry")
|
||||
continue
|
||||
|
||||
# Add provider's base costs to the block
|
||||
if provider.base_costs:
|
||||
logger.info(
|
||||
f"Registering {len(provider.base_costs)} base costs from provider {provider_name} for block {block_class.__name__}"
|
||||
)
|
||||
block_costs.extend(provider.base_costs)
|
||||
|
||||
# Register costs if any were found
|
||||
if block_costs:
|
||||
BLOCK_COSTS[block_class] = block_costs
|
||||
logger.info(
|
||||
f"Registered {len(block_costs)} total costs for block {block_class.__name__}"
|
||||
)
|
||||
|
||||
|
||||
def sync_all_provider_costs() -> None:
|
||||
"""
|
||||
Sync all provider base costs to blocks that use them.
|
||||
|
||||
This should be called after all providers and blocks are registered,
|
||||
typically during application startup.
|
||||
"""
|
||||
from backend.blocks import load_all_blocks
|
||||
|
||||
logger.info("Syncing provider costs to blocks...")
|
||||
|
||||
blocks_with_costs = 0
|
||||
total_costs = 0
|
||||
|
||||
for block_id, block_class in load_all_blocks().items():
|
||||
initial_count = len(BLOCK_COSTS.get(block_class, []))
|
||||
register_provider_costs_for_block(block_class)
|
||||
final_count = len(BLOCK_COSTS.get(block_class, []))
|
||||
|
||||
if final_count > initial_count:
|
||||
blocks_with_costs += 1
|
||||
total_costs += final_count - initial_count
|
||||
|
||||
logger.info(f"Synced {total_costs} costs to {blocks_with_costs} blocks")
|
||||
|
||||
|
||||
def get_block_costs(block_class: Type[Block]) -> List[BlockCost]:
|
||||
"""
|
||||
Get all costs for a block, including both explicit and provider costs.
|
||||
|
||||
Args:
|
||||
block_class: The block class to get costs for
|
||||
|
||||
Returns:
|
||||
List of BlockCost objects for the block
|
||||
"""
|
||||
# First ensure provider costs are registered
|
||||
register_provider_costs_for_block(block_class)
|
||||
|
||||
# Return all costs for the block
|
||||
return BLOCK_COSTS.get(block_class, [])
|
||||
|
||||
|
||||
def cost(*costs: BlockCost):
|
||||
"""
|
||||
Decorator to set custom costs for a block.
|
||||
|
||||
This decorator allows blocks to define their own costs, which will override
|
||||
any provider base costs. Multiple costs can be specified with different
|
||||
filters for different pricing tiers (e.g., different models).
|
||||
|
||||
Example:
|
||||
@cost(
|
||||
BlockCost(cost_type=BlockCostType.RUN, cost_amount=10),
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_amount=20,
|
||||
cost_filter={"model": "premium"}
|
||||
)
|
||||
)
|
||||
class MyBlock(Block):
|
||||
...
|
||||
|
||||
Args:
|
||||
*costs: Variable number of BlockCost objects
|
||||
"""
|
||||
|
||||
def decorator(block_class: Type[Block]) -> Type[Block]:
|
||||
# Register the costs for this block
|
||||
if costs:
|
||||
BLOCK_COSTS[block_class] = list(costs)
|
||||
logger.info(
|
||||
f"Registered {len(costs)} custom costs for block {block_class.__name__}"
|
||||
)
|
||||
return block_class
|
||||
|
||||
return decorator
|
||||
114
autogpt_platform/backend/backend/sdk/provider.py
Normal file
114
autogpt_platform/backend/backend/sdk/provider.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""
|
||||
Provider configuration class that holds all provider-related settings.
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, List, Optional, Set, Type
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.cost import BlockCost
|
||||
from backend.data.model import Credentials, CredentialsField, CredentialsMetaInput
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
from backend.integrations.webhooks._base import BaseWebhooksManager
|
||||
|
||||
|
||||
class OAuthConfig(BaseModel):
|
||||
"""Configuration for OAuth authentication."""
|
||||
|
||||
oauth_handler: Type[BaseOAuthHandler]
|
||||
scopes: Optional[List[str]] = None
|
||||
client_id_env_var: Optional[str] = None
|
||||
client_secret_env_var: Optional[str] = None
|
||||
|
||||
|
||||
class Provider:
|
||||
"""A configured provider that blocks can use.
|
||||
|
||||
A Provider represents a service or platform that blocks can integrate with, like Linear, OpenAI, etc.
|
||||
It contains configuration for:
|
||||
- Authentication (OAuth, API keys)
|
||||
- Default credentials
|
||||
- Base costs for using the provider
|
||||
- Webhook handling
|
||||
- Error handling
|
||||
- API client factory
|
||||
|
||||
Blocks use Provider instances to handle authentication, make API calls, and manage service-specific logic.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
oauth_config: Optional[OAuthConfig] = None,
|
||||
webhook_manager: Optional[Type[BaseWebhooksManager]] = None,
|
||||
default_credentials: Optional[List[Credentials]] = None,
|
||||
base_costs: Optional[List[BlockCost]] = None,
|
||||
supported_auth_types: Optional[Set[str]] = None,
|
||||
api_client_factory: Optional[Callable] = None,
|
||||
error_handler: Optional[Callable[[Exception], str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.name = name
|
||||
self.oauth_config = oauth_config
|
||||
self.webhook_manager = webhook_manager
|
||||
self.default_credentials = default_credentials or []
|
||||
self.base_costs = base_costs or []
|
||||
self.supported_auth_types = supported_auth_types or set()
|
||||
self._api_client_factory = api_client_factory
|
||||
self._error_handler = error_handler
|
||||
|
||||
# Store any additional configuration
|
||||
self._extra_config = kwargs
|
||||
|
||||
def credentials_field(self, **kwargs) -> CredentialsMetaInput:
|
||||
"""Return a CredentialsField configured for this provider."""
|
||||
# Extract known CredentialsField parameters
|
||||
title = kwargs.pop("title", None)
|
||||
description = kwargs.pop("description", f"{self.name.title()} credentials")
|
||||
required_scopes = kwargs.pop("required_scopes", set())
|
||||
discriminator = kwargs.pop("discriminator", None)
|
||||
discriminator_mapping = kwargs.pop("discriminator_mapping", None)
|
||||
discriminator_values = kwargs.pop("discriminator_values", None)
|
||||
|
||||
# Create json_schema_extra with provider information
|
||||
json_schema_extra = {
|
||||
"credentials_provider": [self.name],
|
||||
"credentials_types": (
|
||||
list(self.supported_auth_types)
|
||||
if self.supported_auth_types
|
||||
else ["api_key"]
|
||||
),
|
||||
}
|
||||
|
||||
# Merge any existing json_schema_extra
|
||||
if "json_schema_extra" in kwargs:
|
||||
json_schema_extra.update(kwargs.pop("json_schema_extra"))
|
||||
|
||||
# Add json_schema_extra to kwargs
|
||||
kwargs["json_schema_extra"] = json_schema_extra
|
||||
|
||||
return CredentialsField(
|
||||
required_scopes=required_scopes,
|
||||
discriminator=discriminator,
|
||||
discriminator_mapping=discriminator_mapping,
|
||||
discriminator_values=discriminator_values,
|
||||
title=title,
|
||||
description=description,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_api(self, credentials: Credentials) -> Any:
|
||||
"""Get API client instance for the given credentials."""
|
||||
if self._api_client_factory:
|
||||
return self._api_client_factory(credentials)
|
||||
raise NotImplementedError(f"No API client factory registered for {self.name}")
|
||||
|
||||
def handle_error(self, error: Exception) -> str:
|
||||
"""Handle provider-specific errors."""
|
||||
if self._error_handler:
|
||||
return self._error_handler(error)
|
||||
return str(error)
|
||||
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""Get additional configuration value."""
|
||||
return self._extra_config.get(key, default)
|
||||
220
autogpt_platform/backend/backend/sdk/registry.py
Normal file
220
autogpt_platform/backend/backend/sdk/registry.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
Auto-registration system for blocks, providers, and their configurations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from backend.blocks.basic import Block
|
||||
from backend.data.model import APIKeyCredentials, Credentials
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
from backend.integrations.webhooks._base import BaseWebhooksManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.sdk.provider import Provider
|
||||
|
||||
|
||||
class SDKOAuthCredentials(BaseModel):
|
||||
"""OAuth credentials configuration for SDK providers."""
|
||||
|
||||
use_secrets: bool = False
|
||||
client_id_env_var: Optional[str] = None
|
||||
client_secret_env_var: Optional[str] = None
|
||||
|
||||
|
||||
class BlockConfiguration:
|
||||
"""Configuration associated with a block."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: str,
|
||||
costs: List[Any],
|
||||
default_credentials: List[Credentials],
|
||||
webhook_manager: Optional[Type[BaseWebhooksManager]] = None,
|
||||
oauth_handler: Optional[Type[BaseOAuthHandler]] = None,
|
||||
):
|
||||
self.provider = provider
|
||||
self.costs = costs
|
||||
self.default_credentials = default_credentials
|
||||
self.webhook_manager = webhook_manager
|
||||
self.oauth_handler = oauth_handler
|
||||
|
||||
|
||||
class AutoRegistry:
|
||||
"""Central registry for all block-related configurations."""
|
||||
|
||||
_lock = threading.Lock()
|
||||
_providers: Dict[str, "Provider"] = {}
|
||||
_default_credentials: List[Credentials] = []
|
||||
_oauth_handlers: Dict[str, Type[BaseOAuthHandler]] = {}
|
||||
_oauth_credentials: Dict[str, SDKOAuthCredentials] = {}
|
||||
_webhook_managers: Dict[str, Type[BaseWebhooksManager]] = {}
|
||||
_block_configurations: Dict[Type[Block], BlockConfiguration] = {}
|
||||
_api_key_mappings: Dict[str, str] = {} # provider -> env_var_name
|
||||
|
||||
@classmethod
|
||||
def register_provider(cls, provider: "Provider") -> None:
|
||||
"""Auto-register provider and all its configurations."""
|
||||
with cls._lock:
|
||||
cls._providers[provider.name] = provider
|
||||
|
||||
# Register OAuth handler if provided
|
||||
if provider.oauth_config:
|
||||
# Dynamically set PROVIDER_NAME if not already set
|
||||
if (
|
||||
not hasattr(provider.oauth_config.oauth_handler, "PROVIDER_NAME")
|
||||
or provider.oauth_config.oauth_handler.PROVIDER_NAME is None
|
||||
):
|
||||
# Import ProviderName to create dynamic enum value
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
# This works because ProviderName has _missing_ method
|
||||
provider.oauth_config.oauth_handler.PROVIDER_NAME = ProviderName(
|
||||
provider.name
|
||||
)
|
||||
cls._oauth_handlers[provider.name] = provider.oauth_config.oauth_handler
|
||||
|
||||
# Register OAuth credentials configuration
|
||||
oauth_creds = SDKOAuthCredentials(
|
||||
use_secrets=False, # SDK providers use custom env vars
|
||||
client_id_env_var=provider.oauth_config.client_id_env_var,
|
||||
client_secret_env_var=provider.oauth_config.client_secret_env_var,
|
||||
)
|
||||
cls._oauth_credentials[provider.name] = oauth_creds
|
||||
|
||||
# Register webhook manager if provided
|
||||
if provider.webhook_manager:
|
||||
# Dynamically set PROVIDER_NAME if not already set
|
||||
if (
|
||||
not hasattr(provider.webhook_manager, "PROVIDER_NAME")
|
||||
or provider.webhook_manager.PROVIDER_NAME is None
|
||||
):
|
||||
# Import ProviderName to create dynamic enum value
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
# This works because ProviderName has _missing_ method
|
||||
provider.webhook_manager.PROVIDER_NAME = ProviderName(provider.name)
|
||||
cls._webhook_managers[provider.name] = provider.webhook_manager
|
||||
|
||||
# Register default credentials
|
||||
cls._default_credentials.extend(provider.default_credentials)
|
||||
|
||||
@classmethod
|
||||
def register_api_key(cls, provider: str, env_var_name: str) -> None:
|
||||
"""Register an environment variable as an API key for a provider."""
|
||||
with cls._lock:
|
||||
cls._api_key_mappings[provider] = env_var_name
|
||||
|
||||
# Dynamically check if the env var exists and create credential
|
||||
import os
|
||||
|
||||
api_key = os.getenv(env_var_name)
|
||||
if api_key:
|
||||
credential = APIKeyCredentials(
|
||||
id=f"{provider}-default",
|
||||
provider=provider,
|
||||
api_key=SecretStr(api_key),
|
||||
title=f"Default {provider} credentials",
|
||||
)
|
||||
# Check if credential already exists to avoid duplicates
|
||||
if not any(c.id == credential.id for c in cls._default_credentials):
|
||||
cls._default_credentials.append(credential)
|
||||
|
||||
@classmethod
|
||||
def get_all_credentials(cls) -> List[Credentials]:
|
||||
"""Replace hardcoded get_all_creds() in credentials_store.py."""
|
||||
with cls._lock:
|
||||
return cls._default_credentials.copy()
|
||||
|
||||
@classmethod
|
||||
def get_oauth_handlers(cls) -> Dict[str, Type[BaseOAuthHandler]]:
|
||||
"""Replace HANDLERS_BY_NAME in oauth/__init__.py."""
|
||||
with cls._lock:
|
||||
return cls._oauth_handlers.copy()
|
||||
|
||||
@classmethod
|
||||
def get_oauth_credentials(cls) -> Dict[str, SDKOAuthCredentials]:
|
||||
"""Get OAuth credentials configuration for SDK providers."""
|
||||
with cls._lock:
|
||||
return cls._oauth_credentials.copy()
|
||||
|
||||
@classmethod
|
||||
def get_webhook_managers(cls) -> Dict[str, Type[BaseWebhooksManager]]:
|
||||
"""Replace load_webhook_managers() in webhooks/__init__.py."""
|
||||
with cls._lock:
|
||||
return cls._webhook_managers.copy()
|
||||
|
||||
@classmethod
|
||||
def register_block_configuration(
|
||||
cls, block_class: Type[Block], config: BlockConfiguration
|
||||
) -> None:
|
||||
"""Register configuration for a specific block class."""
|
||||
with cls._lock:
|
||||
cls._block_configurations[block_class] = config
|
||||
|
||||
@classmethod
|
||||
def get_provider(cls, name: str) -> Optional["Provider"]:
|
||||
"""Get a registered provider by name."""
|
||||
with cls._lock:
|
||||
return cls._providers.get(name)
|
||||
|
||||
@classmethod
|
||||
def get_all_provider_names(cls) -> List[str]:
|
||||
"""Get all registered provider names."""
|
||||
with cls._lock:
|
||||
return list(cls._providers.keys())
|
||||
|
||||
@classmethod
|
||||
def clear(cls) -> None:
|
||||
"""Clear all registrations (useful for testing)."""
|
||||
with cls._lock:
|
||||
cls._providers.clear()
|
||||
cls._default_credentials.clear()
|
||||
cls._oauth_handlers.clear()
|
||||
cls._webhook_managers.clear()
|
||||
cls._block_configurations.clear()
|
||||
cls._api_key_mappings.clear()
|
||||
|
||||
@classmethod
|
||||
def patch_integrations(cls) -> None:
|
||||
"""Patch existing integration points to use AutoRegistry."""
|
||||
# OAuth handlers are handled by SDKAwareHandlersDict in oauth/__init__.py
|
||||
# No patching needed for OAuth handlers
|
||||
|
||||
# Patch webhook managers
|
||||
try:
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
# Get the module from sys.modules to respect mocking
|
||||
if "backend.integrations.webhooks" in sys.modules:
|
||||
webhooks: Any = sys.modules["backend.integrations.webhooks"]
|
||||
else:
|
||||
import backend.integrations.webhooks
|
||||
|
||||
webhooks: Any = backend.integrations.webhooks
|
||||
|
||||
if hasattr(webhooks, "load_webhook_managers"):
|
||||
original_load = webhooks.load_webhook_managers
|
||||
|
||||
def patched_load():
|
||||
# Get original managers
|
||||
managers = original_load()
|
||||
# Add SDK-registered managers
|
||||
sdk_managers = cls.get_webhook_managers()
|
||||
if isinstance(sdk_managers, dict):
|
||||
# Import ProviderName for conversion
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
# Convert string keys to ProviderName for consistency
|
||||
for provider_str, manager in sdk_managers.items():
|
||||
provider_name = ProviderName(provider_str)
|
||||
managers[provider_name] = manager
|
||||
return managers
|
||||
|
||||
webhooks.load_webhook_managers = patched_load
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to patch webhook managers: {e}")
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Annotated, Any, Dict, List, Optional, Sequence
|
||||
from typing import Annotated, Any, Optional, Sequence
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
from prisma.enums import AgentExecutionStatus, APIKeyPermission
|
||||
@@ -11,7 +11,6 @@ from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.api_key import APIKey
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.data.execution import NodeExecutionResult
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.server.external.middleware import require_permission
|
||||
from backend.util.settings import Settings
|
||||
@@ -30,30 +29,19 @@ class NodeOutput(TypedDict):
|
||||
class ExecutionNode(TypedDict):
|
||||
node_id: str
|
||||
input: Any
|
||||
output: Dict[str, Any]
|
||||
output: dict[str, Any]
|
||||
|
||||
|
||||
class ExecutionNodeOutput(TypedDict):
|
||||
node_id: str
|
||||
outputs: List[NodeOutput]
|
||||
outputs: list[NodeOutput]
|
||||
|
||||
|
||||
class GraphExecutionResult(TypedDict):
|
||||
execution_id: str
|
||||
status: str
|
||||
nodes: List[ExecutionNode]
|
||||
output: Optional[List[Dict[str, str]]]
|
||||
|
||||
|
||||
def get_outputs_with_names(results: list[NodeExecutionResult]) -> list[dict[str, str]]:
|
||||
outputs = []
|
||||
for result in results:
|
||||
if "output" in result.output_data:
|
||||
output_value = result.output_data["output"][0]
|
||||
name = result.output_data.get("name", [None])[0]
|
||||
if output_value and name:
|
||||
outputs.append({name: output_value})
|
||||
return outputs
|
||||
nodes: list[ExecutionNode]
|
||||
output: Optional[list[dict[str, str]]]
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
@@ -122,23 +110,34 @@ async def get_graph_execution_results(
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
results = await execution_db.get_node_executions(graph_exec_id)
|
||||
last_result = results[-1] if results else None
|
||||
execution_status = (
|
||||
last_result.status if last_result else AgentExecutionStatus.INCOMPLETE
|
||||
graph_exec = await execution_db.get_graph_execution(
|
||||
user_id=api_key.user_id,
|
||||
execution_id=graph_exec_id,
|
||||
include_node_executions=True,
|
||||
)
|
||||
outputs = get_outputs_with_names(results)
|
||||
if not graph_exec:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Graph execution #{graph_exec_id} not found."
|
||||
)
|
||||
|
||||
return GraphExecutionResult(
|
||||
execution_id=graph_exec_id,
|
||||
status=execution_status,
|
||||
status=graph_exec.status.value,
|
||||
nodes=[
|
||||
ExecutionNode(
|
||||
node_id=result.node_id,
|
||||
input=result.input_data.get("value", result.input_data),
|
||||
output={k: v for k, v in result.output_data.items()},
|
||||
node_id=node_exec.node_id,
|
||||
input=node_exec.input_data.get("value", node_exec.input_data),
|
||||
output={k: v for k, v in node_exec.output_data.items()},
|
||||
)
|
||||
for result in results
|
||||
for node_exec in graph_exec.node_executions
|
||||
],
|
||||
output=outputs if execution_status == AgentExecutionStatus.COMPLETED else None,
|
||||
output=(
|
||||
[
|
||||
{name: value}
|
||||
for name, values in graph_exec.outputs.items()
|
||||
for value in values
|
||||
]
|
||||
if graph_exec.status == AgentExecutionStatus.COMPLETED
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
Models for integration-related data structures that need to be exposed in the OpenAPI schema.
|
||||
|
||||
This module provides models that will be included in the OpenAPI schema generation,
|
||||
allowing frontend code generators like Orval to create corresponding TypeScript types.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.sdk.registry import AutoRegistry
|
||||
|
||||
|
||||
def get_all_provider_names() -> list[str]:
|
||||
"""
|
||||
Collect all provider names from both ProviderName enum and AutoRegistry.
|
||||
|
||||
This function should be called at runtime to ensure we get all
|
||||
dynamically registered providers.
|
||||
|
||||
Returns:
|
||||
A sorted list of unique provider names.
|
||||
"""
|
||||
# Get static providers from enum
|
||||
static_providers = [member.value for member in ProviderName]
|
||||
|
||||
# Get dynamic providers from registry
|
||||
dynamic_providers = AutoRegistry.get_all_provider_names()
|
||||
|
||||
# Combine and deduplicate
|
||||
all_providers = list(set(static_providers + dynamic_providers))
|
||||
all_providers.sort()
|
||||
|
||||
return all_providers
|
||||
|
||||
|
||||
# Note: We don't create a static enum here because providers are registered dynamically.
|
||||
# Instead, we expose provider names through API endpoints that can be fetched at runtime.
|
||||
|
||||
|
||||
class ProviderNamesResponse(BaseModel):
|
||||
"""Response containing list of all provider names."""
|
||||
|
||||
providers: list[str] = Field(
|
||||
description="List of all available provider names",
|
||||
default_factory=get_all_provider_names,
|
||||
)
|
||||
|
||||
|
||||
class ProviderConstants(BaseModel):
|
||||
"""
|
||||
Model that exposes all provider names as a constant in the OpenAPI schema.
|
||||
This is designed to be converted by Orval into a TypeScript constant.
|
||||
"""
|
||||
|
||||
PROVIDER_NAMES: dict[str, str] = Field(
|
||||
description="All available provider names as a constant mapping",
|
||||
default_factory=lambda: {
|
||||
name.upper().replace("-", "_"): name for name in get_all_provider_names()
|
||||
},
|
||||
)
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"PROVIDER_NAMES": {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"EXA": "exa",
|
||||
"GEM": "gem",
|
||||
"EXAMPLE_SERVICE": "example-service",
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Annotated, Awaitable, Literal
|
||||
from typing import TYPE_CHECKING, Annotated, Awaitable, List, Literal
|
||||
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
@@ -30,9 +30,14 @@ from backend.data.model import (
|
||||
)
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import HANDLERS_BY_NAME
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
from backend.server.integrations.models import (
|
||||
ProviderConstants,
|
||||
ProviderNamesResponse,
|
||||
get_all_provider_names,
|
||||
)
|
||||
from backend.server.v2.library.db import set_preset_webhook, update_preset
|
||||
from backend.util.exceptions import NeedConfirmation, NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
@@ -472,14 +477,49 @@ async def remove_all_webhooks_for_credentials(
|
||||
def _get_provider_oauth_handler(
|
||||
req: Request, provider_name: ProviderName
|
||||
) -> "BaseOAuthHandler":
|
||||
if provider_name not in HANDLERS_BY_NAME:
|
||||
# Ensure blocks are loaded so SDK providers are available
|
||||
try:
|
||||
from backend.blocks import load_all_blocks
|
||||
|
||||
load_all_blocks() # This is cached, so it only runs once
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load blocks: {e}")
|
||||
|
||||
# Convert provider_name to string for lookup
|
||||
provider_key = (
|
||||
provider_name.value if hasattr(provider_name, "value") else str(provider_name)
|
||||
)
|
||||
|
||||
if provider_key not in HANDLERS_BY_NAME:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Provider '{provider_name.value}' does not support OAuth",
|
||||
detail=f"Provider '{provider_key}' does not support OAuth",
|
||||
)
|
||||
|
||||
# Check if this provider has custom OAuth credentials
|
||||
oauth_credentials = CREDENTIALS_BY_PROVIDER.get(provider_key)
|
||||
|
||||
if oauth_credentials and not oauth_credentials.use_secrets:
|
||||
# SDK provider with custom env vars
|
||||
import os
|
||||
|
||||
client_id = (
|
||||
os.getenv(oauth_credentials.client_id_env_var)
|
||||
if oauth_credentials.client_id_env_var
|
||||
else None
|
||||
)
|
||||
client_secret = (
|
||||
os.getenv(oauth_credentials.client_secret_env_var)
|
||||
if oauth_credentials.client_secret_env_var
|
||||
else None
|
||||
)
|
||||
else:
|
||||
# Original provider using settings.secrets
|
||||
client_id = getattr(settings.secrets, f"{provider_name.value}_client_id", None)
|
||||
client_secret = getattr(
|
||||
settings.secrets, f"{provider_name.value}_client_secret", None
|
||||
)
|
||||
|
||||
client_id = getattr(settings.secrets, f"{provider_name.value}_client_id")
|
||||
client_secret = getattr(settings.secrets, f"{provider_name.value}_client_secret")
|
||||
if not (client_id and client_secret):
|
||||
logger.error(
|
||||
f"Attempt to use unconfigured {provider_name.value} OAuth integration"
|
||||
@@ -492,14 +532,84 @@ def _get_provider_oauth_handler(
|
||||
},
|
||||
)
|
||||
|
||||
handler_class = HANDLERS_BY_NAME[provider_name]
|
||||
frontend_base_url = (
|
||||
settings.config.frontend_base_url
|
||||
or settings.config.platform_base_url
|
||||
or str(req.base_url)
|
||||
)
|
||||
handler_class = HANDLERS_BY_NAME[provider_key]
|
||||
frontend_base_url = settings.config.frontend_base_url
|
||||
|
||||
if not frontend_base_url:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Frontend base URL is not configured",
|
||||
)
|
||||
|
||||
return handler_class(
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
|
||||
)
|
||||
|
||||
|
||||
# === PROVIDER DISCOVERY ENDPOINTS ===
|
||||
|
||||
|
||||
@router.get("/providers", response_model=List[str])
|
||||
async def list_providers() -> List[str]:
|
||||
"""
|
||||
Get a list of all available provider names.
|
||||
|
||||
Returns both statically defined providers (from ProviderName enum)
|
||||
and dynamically registered providers (from SDK decorators).
|
||||
|
||||
Note: The complete list of provider names is also available as a constant
|
||||
in the generated TypeScript client via PROVIDER_NAMES.
|
||||
"""
|
||||
# Get all providers at runtime
|
||||
all_providers = get_all_provider_names()
|
||||
return all_providers
|
||||
|
||||
|
||||
@router.get("/providers/names", response_model=ProviderNamesResponse)
|
||||
async def get_provider_names() -> ProviderNamesResponse:
|
||||
"""
|
||||
Get all provider names in a structured format.
|
||||
|
||||
This endpoint is specifically designed to expose the provider names
|
||||
in the OpenAPI schema so that code generators like Orval can create
|
||||
appropriate TypeScript constants.
|
||||
"""
|
||||
return ProviderNamesResponse()
|
||||
|
||||
|
||||
@router.get("/providers/constants", response_model=ProviderConstants)
|
||||
async def get_provider_constants() -> ProviderConstants:
|
||||
"""
|
||||
Get provider names as constants.
|
||||
|
||||
This endpoint returns a model with provider names as constants,
|
||||
specifically designed for OpenAPI code generation tools to create
|
||||
TypeScript constants.
|
||||
"""
|
||||
return ProviderConstants()
|
||||
|
||||
|
||||
class ProviderEnumResponse(BaseModel):
|
||||
"""Response containing a provider from the enum."""
|
||||
|
||||
provider: str = Field(
|
||||
description="A provider name from the complete list of providers"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/providers/enum-example", response_model=ProviderEnumResponse)
|
||||
async def get_provider_enum_example() -> ProviderEnumResponse:
|
||||
"""
|
||||
Example endpoint that uses the CompleteProviderNames enum.
|
||||
|
||||
This endpoint exists to ensure that the CompleteProviderNames enum is included
|
||||
in the OpenAPI schema, which will cause Orval to generate it as a
|
||||
TypeScript enum/constant.
|
||||
"""
|
||||
# Return the first provider as an example
|
||||
all_providers = get_all_provider_names()
|
||||
return ProviderEnumResponse(
|
||||
provider=all_providers[0] if all_providers else "openai"
|
||||
)
|
||||
|
||||
@@ -62,6 +62,10 @@ def launch_darkly_context():
|
||||
async def lifespan_context(app: fastapi.FastAPI):
|
||||
await backend.data.db.connect()
|
||||
await backend.data.block.initialize_blocks()
|
||||
|
||||
# SDK auto-registration is now handled by AutoRegistry.patch_integrations()
|
||||
# which is called when the SDK module is imported
|
||||
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
await backend.data.graph.fix_llm_provider_credentials()
|
||||
await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
|
||||
|
||||
@@ -448,10 +448,10 @@ class DeleteGraphResponse(TypedDict):
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_graphs(
|
||||
async def list_graphs(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> Sequence[graph_db.GraphModel]:
|
||||
return await graph_db.get_graphs(filter_by="active", user_id=user_id)
|
||||
) -> Sequence[graph_db.GraphMeta]:
|
||||
return await graph_db.list_graphs(filter_by="active", user_id=user_id)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
@@ -680,22 +680,6 @@ async def stop_graph_run(
|
||||
return res[0]
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/executions",
|
||||
summary="Stop graph executions",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def stop_graph_runs(
|
||||
graph_id: str, graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> list[execution_db.GraphExecutionMeta]:
|
||||
return await _stop_graph_run(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
)
|
||||
|
||||
|
||||
async def _stop_graph_run(
|
||||
user_id: str,
|
||||
graph_id: Optional[str] = None,
|
||||
|
||||
@@ -270,7 +270,7 @@ def test_get_graphs(
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.graph_db.get_graphs",
|
||||
"backend.server.routers.v1.graph_db.list_graphs",
|
||||
return_value=[mock_graph],
|
||||
)
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
|
||||
async def get_library_agent_by_store_version_id(
|
||||
store_listing_version_id: str,
|
||||
user_id: str,
|
||||
):
|
||||
) -> library_model.LibraryAgent | None:
|
||||
"""
|
||||
Get the library agent metadata for a given store listing version ID and user ID.
|
||||
"""
|
||||
@@ -202,7 +202,7 @@ async def get_library_agent_by_store_version_id(
|
||||
)
|
||||
if not store_listing_version:
|
||||
logger.warning(f"Store listing version not found: {store_listing_version_id}")
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
raise NotFoundError(
|
||||
f"Store listing version {store_listing_version_id} not found or invalid"
|
||||
)
|
||||
|
||||
@@ -214,12 +214,9 @@ async def get_library_agent_by_store_version_id(
|
||||
"agentGraphVersion": store_listing_version.agentGraphVersion,
|
||||
"isDeleted": False,
|
||||
},
|
||||
include={"AgentGraph": True},
|
||||
include=library_agent_include(user_id),
|
||||
)
|
||||
if agent:
|
||||
return library_model.LibraryAgent.from_db(agent)
|
||||
else:
|
||||
return None
|
||||
return library_model.LibraryAgent.from_db(agent) if agent else None
|
||||
|
||||
|
||||
async def get_library_agent_by_graph_id(
|
||||
|
||||
@@ -127,9 +127,9 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
description=graph.description,
|
||||
input_schema=graph.input_schema,
|
||||
credentials_input_schema=(
|
||||
graph.credentials_input_schema if sub_graphs else None
|
||||
graph.credentials_input_schema if sub_graphs is not None else None
|
||||
),
|
||||
has_external_trigger=graph.has_webhook_trigger,
|
||||
has_external_trigger=graph.has_external_trigger,
|
||||
trigger_setup_info=(
|
||||
LibraryAgentTriggerInfo(
|
||||
provider=trigger_block.webhook_config.provider,
|
||||
@@ -262,6 +262,19 @@ class LibraryAgentPresetUpdatable(pydantic.BaseModel):
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class TriggeredPresetSetupRequest(pydantic.BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
|
||||
trigger_config: dict[str, Any]
|
||||
agent_credentials: dict[str, CredentialsMetaInput] = pydantic.Field(
|
||||
default_factory=dict
|
||||
)
|
||||
|
||||
|
||||
class LibraryAgentPreset(LibraryAgentPresetCreatable):
|
||||
"""Represents a preset configuration for a library agent."""
|
||||
|
||||
|
||||
@@ -1,18 +1,13 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import autogpt_libs.auth as autogpt_auth_lib
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, status
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import backend.server.v2.library.db as library_db
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.exceptions as store_exceptions
|
||||
from backend.data.graph import get_graph
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.executor.utils import make_node_credentials_input_map
|
||||
from backend.integrations.webhooks.utils import setup_webhook_for_block
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -113,12 +108,11 @@ async def get_library_agent_by_graph_id(
|
||||
"/marketplace/{store_listing_version_id}",
|
||||
summary="Get Agent By Store ID",
|
||||
tags=["store, library"],
|
||||
response_model=library_model.LibraryAgent | None,
|
||||
)
|
||||
async def get_library_agent_by_store_listing_version_id(
|
||||
store_listing_version_id: str,
|
||||
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
|
||||
):
|
||||
) -> library_model.LibraryAgent | None:
|
||||
"""
|
||||
Get Library Agent from Store Listing Version ID.
|
||||
"""
|
||||
@@ -295,81 +289,3 @@ async def fork_library_agent(
|
||||
library_agent_id=library_agent_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
class TriggeredPresetSetupParams(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
|
||||
trigger_config: dict[str, Any]
|
||||
agent_credentials: dict[str, CredentialsMetaInput] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@router.post("/{library_agent_id}/setup-trigger")
|
||||
async def setup_trigger(
|
||||
library_agent_id: str = Path(..., description="ID of the library agent"),
|
||||
params: TriggeredPresetSetupParams = Body(),
|
||||
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
|
||||
) -> library_model.LibraryAgentPreset:
|
||||
"""
|
||||
Sets up a webhook-triggered `LibraryAgentPreset` for a `LibraryAgent`.
|
||||
Returns the correspondingly created `LibraryAgentPreset` with `webhook_id` set.
|
||||
"""
|
||||
library_agent = await library_db.get_library_agent(
|
||||
id=library_agent_id, user_id=user_id
|
||||
)
|
||||
if not library_agent:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Library agent #{library_agent_id} not found",
|
||||
)
|
||||
|
||||
graph = await get_graph(
|
||||
library_agent.graph_id, version=library_agent.graph_version, user_id=user_id
|
||||
)
|
||||
if not graph:
|
||||
raise HTTPException(
|
||||
status.HTTP_410_GONE,
|
||||
f"Graph #{library_agent.graph_id} not accessible (anymore)",
|
||||
)
|
||||
if not (trigger_node := graph.webhook_input_node):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Graph #{library_agent.graph_id} does not have a webhook node",
|
||||
)
|
||||
|
||||
trigger_config_with_credentials = {
|
||||
**params.trigger_config,
|
||||
**(
|
||||
make_node_credentials_input_map(graph, params.agent_credentials).get(
|
||||
trigger_node.id
|
||||
)
|
||||
or {}
|
||||
),
|
||||
}
|
||||
|
||||
new_webhook, feedback = await setup_webhook_for_block(
|
||||
user_id=user_id,
|
||||
trigger_block=trigger_node.block,
|
||||
trigger_config=trigger_config_with_credentials,
|
||||
)
|
||||
if not new_webhook:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Could not set up webhook: {feedback}",
|
||||
)
|
||||
|
||||
new_preset = await library_db.create_preset(
|
||||
user_id=user_id,
|
||||
preset=library_model.LibraryAgentPresetCreatable(
|
||||
graph_id=library_agent.graph_id,
|
||||
graph_version=library_agent.graph_version,
|
||||
name=params.name,
|
||||
description=params.description,
|
||||
inputs=trigger_config_with_credentials,
|
||||
credentials=params.agent_credentials,
|
||||
webhook_id=new_webhook.id,
|
||||
is_active=True,
|
||||
),
|
||||
)
|
||||
return new_preset
|
||||
|
||||
@@ -138,6 +138,66 @@ async def create_preset(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/presets/setup-trigger")
|
||||
async def setup_trigger(
|
||||
params: models.TriggeredPresetSetupRequest = Body(),
|
||||
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
|
||||
) -> models.LibraryAgentPreset:
|
||||
"""
|
||||
Sets up a webhook-triggered `LibraryAgentPreset` for a `LibraryAgent`.
|
||||
Returns the correspondingly created `LibraryAgentPreset` with `webhook_id` set.
|
||||
"""
|
||||
graph = await get_graph(
|
||||
params.graph_id, version=params.graph_version, user_id=user_id
|
||||
)
|
||||
if not graph:
|
||||
raise HTTPException(
|
||||
status.HTTP_410_GONE,
|
||||
f"Graph #{params.graph_id} not accessible (anymore)",
|
||||
)
|
||||
if not (trigger_node := graph.webhook_input_node):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Graph #{params.graph_id} does not have a webhook node",
|
||||
)
|
||||
|
||||
trigger_config_with_credentials = {
|
||||
**params.trigger_config,
|
||||
**(
|
||||
make_node_credentials_input_map(graph, params.agent_credentials).get(
|
||||
trigger_node.id
|
||||
)
|
||||
or {}
|
||||
),
|
||||
}
|
||||
|
||||
new_webhook, feedback = await setup_webhook_for_block(
|
||||
user_id=user_id,
|
||||
trigger_block=trigger_node.block,
|
||||
trigger_config=trigger_config_with_credentials,
|
||||
)
|
||||
if not new_webhook:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Could not set up webhook: {feedback}",
|
||||
)
|
||||
|
||||
new_preset = await db.create_preset(
|
||||
user_id=user_id,
|
||||
preset=models.LibraryAgentPresetCreatable(
|
||||
graph_id=params.graph_id,
|
||||
graph_version=params.graph_version,
|
||||
name=params.name,
|
||||
description=params.description,
|
||||
inputs=trigger_config_with_credentials,
|
||||
credentials=params.agent_credentials,
|
||||
webhook_id=new_webhook.id,
|
||||
is_active=True,
|
||||
),
|
||||
)
|
||||
return new_preset
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/presets/{preset_id}",
|
||||
summary="Update an existing preset",
|
||||
|
||||
@@ -7,10 +7,15 @@ import prisma.errors
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
import backend.data.graph
|
||||
import backend.server.v2.store.exceptions
|
||||
import backend.server.v2.store.model
|
||||
from backend.data.graph import GraphModel, get_sub_graphs
|
||||
from backend.data.graph import (
|
||||
GraphMeta,
|
||||
GraphModel,
|
||||
get_graph,
|
||||
get_graph_as_admin,
|
||||
get_sub_graphs,
|
||||
)
|
||||
from backend.data.includes import AGENT_GRAPH_INCLUDE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -193,9 +198,7 @@ async def get_store_agent_details(
|
||||
) from e
|
||||
|
||||
|
||||
async def get_available_graph(
|
||||
store_listing_version_id: str,
|
||||
):
|
||||
async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
|
||||
try:
|
||||
# Get avaialble, non-deleted store listing version
|
||||
store_listing_version = (
|
||||
@@ -215,18 +218,7 @@ async def get_available_graph(
|
||||
detail=f"Store listing version {store_listing_version_id} not found",
|
||||
)
|
||||
|
||||
graph = GraphModel.from_db(store_listing_version.AgentGraph)
|
||||
# We return graph meta, without nodes, they cannot be just removed
|
||||
# because then input_schema would be empty
|
||||
return {
|
||||
"id": graph.id,
|
||||
"version": graph.version,
|
||||
"is_active": graph.is_active,
|
||||
"name": graph.name,
|
||||
"description": graph.description,
|
||||
"input_schema": graph.input_schema,
|
||||
"output_schema": graph.output_schema,
|
||||
}
|
||||
return GraphModel.from_db(store_listing_version.AgentGraph).meta()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting agent: {e}")
|
||||
@@ -1024,7 +1016,7 @@ async def get_agent(
|
||||
if not store_listing_version:
|
||||
raise ValueError(f"Store listing version {store_listing_version_id} not found")
|
||||
|
||||
graph = await backend.data.graph.get_graph(
|
||||
graph = await get_graph(
|
||||
user_id=user_id,
|
||||
graph_id=store_listing_version.agentGraphId,
|
||||
version=store_listing_version.agentGraphVersion,
|
||||
@@ -1383,7 +1375,7 @@ async def get_agent_as_admin(
|
||||
if not store_listing_version:
|
||||
raise ValueError(f"Store listing version {store_listing_version_id} not found")
|
||||
|
||||
graph = await backend.data.graph.get_graph_as_admin(
|
||||
graph = await get_graph_as_admin(
|
||||
user_id=user_id,
|
||||
graph_id=store_listing_version.agentGraphId,
|
||||
version=store_listing_version.agentGraphVersion,
|
||||
|
||||
@@ -124,6 +124,19 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="Time in seconds for how far back to check for the late executions.",
|
||||
)
|
||||
|
||||
block_error_rate_threshold: float = Field(
|
||||
default=0.5,
|
||||
description="Error rate threshold (0.0-1.0) for triggering block error alerts.",
|
||||
)
|
||||
block_error_rate_check_interval_secs: int = Field(
|
||||
default=24 * 60 * 60, # 24 hours
|
||||
description="Interval in seconds between block error rate checks.",
|
||||
)
|
||||
block_error_include_top_blocks: int = Field(
|
||||
default=3,
|
||||
description="Number of top blocks with most errors to show when no blocks exceed threshold (0 to disable).",
|
||||
)
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
extra="allow",
|
||||
@@ -263,6 +276,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="Whether to mark failed scans as clean or not",
|
||||
)
|
||||
|
||||
enable_example_blocks: bool = Field(
|
||||
default=False,
|
||||
description="Whether to enable example blocks in production",
|
||||
)
|
||||
|
||||
@field_validator("platform_base_url", "frontend_base_url")
|
||||
@classmethod
|
||||
def validate_platform_base_url(cls, v: str, info: ValidationInfo) -> str:
|
||||
|
||||
101
autogpt_platform/backend/clean_test_db.py
Normal file
101
autogpt_platform/backend/clean_test_db.py
Normal file
@@ -0,0 +1,101 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Clean the test database by removing all data while preserving the schema.
|
||||
|
||||
Usage:
|
||||
poetry run python clean_test_db.py [--yes]
|
||||
|
||||
Options:
|
||||
--yes Skip confirmation prompt
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
from prisma import Prisma
|
||||
|
||||
|
||||
async def main():
|
||||
db = Prisma()
|
||||
await db.connect()
|
||||
|
||||
print("=" * 60)
|
||||
print("Cleaning Test Database")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
# Get initial counts
|
||||
user_count = await db.user.count()
|
||||
agent_count = await db.agentgraph.count()
|
||||
|
||||
print(f"Current data: {user_count} users, {agent_count} agent graphs")
|
||||
|
||||
if user_count == 0 and agent_count == 0:
|
||||
print("Database is already clean!")
|
||||
await db.disconnect()
|
||||
return
|
||||
|
||||
# Check for --yes flag
|
||||
skip_confirm = "--yes" in sys.argv
|
||||
|
||||
if not skip_confirm:
|
||||
response = input("\nDo you want to clean all data? (yes/no): ")
|
||||
if response.lower() != "yes":
|
||||
print("Aborted.")
|
||||
await db.disconnect()
|
||||
return
|
||||
|
||||
print("\nCleaning database...")
|
||||
|
||||
# Delete in reverse order of dependencies
|
||||
tables = [
|
||||
("UserNotificationBatch", db.usernotificationbatch),
|
||||
("NotificationEvent", db.notificationevent),
|
||||
("CreditRefundRequest", db.creditrefundrequest),
|
||||
("StoreListingReview", db.storelistingreview),
|
||||
("StoreListingVersion", db.storelistingversion),
|
||||
("StoreListing", db.storelisting),
|
||||
("AgentNodeExecutionInputOutput", db.agentnodeexecutioninputoutput),
|
||||
("AgentNodeExecution", db.agentnodeexecution),
|
||||
("AgentGraphExecution", db.agentgraphexecution),
|
||||
("AgentNodeLink", db.agentnodelink),
|
||||
("LibraryAgent", db.libraryagent),
|
||||
("AgentPreset", db.agentpreset),
|
||||
("IntegrationWebhook", db.integrationwebhook),
|
||||
("AgentNode", db.agentnode),
|
||||
("AgentGraph", db.agentgraph),
|
||||
("AgentBlock", db.agentblock),
|
||||
("APIKey", db.apikey),
|
||||
("CreditTransaction", db.credittransaction),
|
||||
("AnalyticsMetrics", db.analyticsmetrics),
|
||||
("AnalyticsDetails", db.analyticsdetails),
|
||||
("Profile", db.profile),
|
||||
("UserOnboarding", db.useronboarding),
|
||||
("User", db.user),
|
||||
]
|
||||
|
||||
for table_name, table in tables:
|
||||
try:
|
||||
count = await table.count()
|
||||
if count > 0:
|
||||
await table.delete_many()
|
||||
print(f"✓ Deleted {count} records from {table_name}")
|
||||
except Exception as e:
|
||||
print(f"⚠ Error cleaning {table_name}: {e}")
|
||||
|
||||
# Refresh materialized views (they should be empty now)
|
||||
try:
|
||||
await db.execute_raw("SELECT refresh_store_materialized_views();")
|
||||
print("\n✓ Refreshed materialized views")
|
||||
except Exception as e:
|
||||
print(f"\n⚠ Could not refresh materialized views: {e}")
|
||||
|
||||
await db.disconnect()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Database cleaned successfully!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,35 +1,60 @@
|
||||
networks:
|
||||
app-network:
|
||||
name: app-network
|
||||
shared-network:
|
||||
name: shared-network
|
||||
|
||||
volumes:
|
||||
supabase-config:
|
||||
|
||||
x-agpt-services:
|
||||
&agpt-services
|
||||
networks:
|
||||
- app-network
|
||||
- shared-network
|
||||
|
||||
x-supabase-services:
|
||||
&supabase-services
|
||||
networks:
|
||||
- app-network
|
||||
- shared-network
|
||||
|
||||
|
||||
volumes:
|
||||
clamav-data:
|
||||
|
||||
services:
|
||||
postgres-test:
|
||||
image: ankane/pgvector:latest
|
||||
environment:
|
||||
- POSTGRES_USER=${DB_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${DB_PASS:-postgres}
|
||||
- POSTGRES_DB=${DB_NAME:-postgres}
|
||||
- POSTGRES_PORT=${DB_PORT:-5432}
|
||||
healthcheck:
|
||||
test: pg_isready -U $$POSTGRES_USER -d $$POSTGRES_DB
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
db:
|
||||
<<: *supabase-services
|
||||
extends:
|
||||
file: ../db/docker/docker-compose.yml
|
||||
service: db
|
||||
ports:
|
||||
- "${DB_PORT:-5432}:5432"
|
||||
networks:
|
||||
- app-network-test
|
||||
redis-test:
|
||||
- ${POSTGRES_PORT}:5432 # We don't use Supavisor locally, so we expose the db directly.
|
||||
|
||||
vector:
|
||||
<<: *supabase-services
|
||||
extends:
|
||||
file: ../db/docker/docker-compose.yml
|
||||
service: vector
|
||||
|
||||
redis:
|
||||
<<: *agpt-services
|
||||
image: redis:latest
|
||||
command: redis-server --requirepass password
|
||||
ports:
|
||||
- "6379:6379"
|
||||
networks:
|
||||
- app-network-test
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "ping"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
rabbitmq-test:
|
||||
|
||||
rabbitmq:
|
||||
<<: *agpt-services
|
||||
image: rabbitmq:management
|
||||
container_name: rabbitmq-test
|
||||
container_name: rabbitmq
|
||||
healthcheck:
|
||||
test: rabbitmq-diagnostics -q ping
|
||||
interval: 30s
|
||||
@@ -38,11 +63,28 @@ services:
|
||||
start_period: 10s
|
||||
environment:
|
||||
- RABBITMQ_DEFAULT_USER=rabbitmq_user_default
|
||||
- RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7 # CHANGE THIS TO A RANDOM PASSWORD IN PRODUCTION -- everywhere lol
|
||||
- RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
|
||||
ports:
|
||||
- "5672:5672"
|
||||
- "15672:15672"
|
||||
|
||||
clamav:
|
||||
image: clamav/clamav-debian:latest
|
||||
ports:
|
||||
- "3310:3310"
|
||||
volumes:
|
||||
- clamav-data:/var/lib/clamav
|
||||
environment:
|
||||
- CLAMAV_NO_FRESHCLAMD=false
|
||||
- CLAMD_CONF_StreamMaxLength=50M
|
||||
- CLAMD_CONF_MaxFileSize=100M
|
||||
- CLAMD_CONF_MaxScanSize=100M
|
||||
- CLAMD_CONF_MaxThreads=12
|
||||
- CLAMD_CONF_ReadTimeout=300
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "clamdscan --version || exit 1"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
networks:
|
||||
app-network-test:
|
||||
driver: bridge
|
||||
|
||||
@@ -0,0 +1,254 @@
|
||||
-- This migration creates materialized views for performance optimization
|
||||
--
|
||||
-- IMPORTANT: For production environments, pg_cron is REQUIRED for automatic refresh
|
||||
-- Prerequisites for production:
|
||||
-- 1. pg_cron extension must be installed: CREATE EXTENSION pg_cron;
|
||||
-- 2. pg_cron must be configured in postgresql.conf:
|
||||
-- shared_preload_libraries = 'pg_cron'
|
||||
-- cron.database_name = 'your_database_name'
|
||||
--
|
||||
-- For development environments without pg_cron:
|
||||
-- The migration will succeed but you must manually refresh views with:
|
||||
-- SELECT refresh_store_materialized_views();
|
||||
|
||||
-- Check if pg_cron extension is installed and set a flag
|
||||
DO $$
|
||||
DECLARE
|
||||
has_pg_cron BOOLEAN;
|
||||
BEGIN
|
||||
SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pg_cron') INTO has_pg_cron;
|
||||
|
||||
IF NOT has_pg_cron THEN
|
||||
RAISE WARNING 'pg_cron extension is not installed!';
|
||||
RAISE WARNING 'Materialized views will be created but WILL NOT refresh automatically.';
|
||||
RAISE WARNING 'For production use, install pg_cron with: CREATE EXTENSION pg_cron;';
|
||||
RAISE WARNING 'For development, manually refresh with: SELECT refresh_store_materialized_views();';
|
||||
|
||||
-- For production deployments, uncomment the following line to make pg_cron mandatory:
|
||||
-- RAISE EXCEPTION 'pg_cron is required for production deployments';
|
||||
END IF;
|
||||
|
||||
-- Store the flag for later use in the migration
|
||||
PERFORM set_config('migration.has_pg_cron', has_pg_cron::text, false);
|
||||
END
|
||||
$$;
|
||||
|
||||
-- CreateIndex
|
||||
-- Optimized: Only include owningUserId in index columns since isDeleted and hasApprovedVersion are in WHERE clause
|
||||
CREATE INDEX IF NOT EXISTS "idx_store_listing_approved" ON "StoreListing"("owningUserId") WHERE "isDeleted" = false AND "hasApprovedVersion" = true;
|
||||
|
||||
-- CreateIndex
|
||||
-- Optimized: Only include storeListingId since submissionStatus is in WHERE clause
|
||||
CREATE INDEX IF NOT EXISTS "idx_store_listing_version_status" ON "StoreListingVersion"("storeListingId") WHERE "submissionStatus" = 'APPROVED';
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX IF NOT EXISTS "idx_slv_categories_gin" ON "StoreListingVersion" USING GIN ("categories") WHERE "submissionStatus" = 'APPROVED';
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX IF NOT EXISTS "idx_slv_agent" ON "StoreListingVersion"("agentGraphId", "agentGraphVersion") WHERE "submissionStatus" = 'APPROVED';
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX IF NOT EXISTS "idx_store_listing_review_version" ON "StoreListingReview"("storeListingVersionId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX IF NOT EXISTS "idx_agent_graph_execution_agent" ON "AgentGraphExecution"("agentGraphId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX IF NOT EXISTS "idx_profile_user" ON "Profile"("userId");
|
||||
|
||||
-- Additional performance indexes
|
||||
CREATE INDEX IF NOT EXISTS "idx_store_listing_version_approved_listing" ON "StoreListingVersion"("storeListingId", "version") WHERE "submissionStatus" = 'APPROVED';
|
||||
|
||||
-- Create materialized view for agent run counts
|
||||
CREATE MATERIALIZED VIEW IF NOT EXISTS "mv_agent_run_counts" AS
|
||||
SELECT
|
||||
"agentGraphId",
|
||||
COUNT(*) AS run_count
|
||||
FROM "AgentGraphExecution"
|
||||
GROUP BY "agentGraphId";
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS "idx_mv_agent_run_counts" ON "mv_agent_run_counts"("agentGraphId");
|
||||
|
||||
-- Create materialized view for review statistics
|
||||
CREATE MATERIALIZED VIEW IF NOT EXISTS "mv_review_stats" AS
|
||||
SELECT
|
||||
sl.id AS "storeListingId",
|
||||
COUNT(sr.id) AS review_count,
|
||||
AVG(sr.score::numeric) AS avg_rating
|
||||
FROM "StoreListing" sl
|
||||
JOIN "StoreListingVersion" slv ON slv."storeListingId" = sl.id
|
||||
LEFT JOIN "StoreListingReview" sr ON sr."storeListingVersionId" = slv.id
|
||||
WHERE sl."isDeleted" = false
|
||||
AND slv."submissionStatus" = 'APPROVED'
|
||||
GROUP BY sl.id;
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS "idx_mv_review_stats" ON "mv_review_stats"("storeListingId");
|
||||
|
||||
-- DropForeignKey (if any exist on the views)
|
||||
-- None needed as views don't have foreign keys
|
||||
|
||||
-- DropView
|
||||
DROP VIEW IF EXISTS "Creator";
|
||||
|
||||
-- DropView
|
||||
DROP VIEW IF EXISTS "StoreAgent";
|
||||
|
||||
-- CreateView
|
||||
CREATE OR REPLACE VIEW "StoreAgent" AS
|
||||
WITH agent_versions AS (
|
||||
SELECT
|
||||
"storeListingId",
|
||||
array_agg(DISTINCT version::text ORDER BY version::text) AS versions
|
||||
FROM "StoreListingVersion"
|
||||
WHERE "submissionStatus" = 'APPROVED'
|
||||
GROUP BY "storeListingId"
|
||||
)
|
||||
SELECT
|
||||
sl.id AS listing_id,
|
||||
slv.id AS "storeListingVersionId",
|
||||
slv."createdAt" AS updated_at,
|
||||
sl.slug,
|
||||
COALESCE(slv.name, '') AS agent_name,
|
||||
slv."videoUrl" AS agent_video,
|
||||
COALESCE(slv."imageUrls", ARRAY[]::text[]) AS agent_image,
|
||||
slv."isFeatured" AS featured,
|
||||
p.username AS creator_username,
|
||||
p."avatarUrl" AS creator_avatar,
|
||||
slv."subHeading" AS sub_heading,
|
||||
slv.description,
|
||||
slv.categories,
|
||||
COALESCE(ar.run_count, 0::bigint) AS runs,
|
||||
COALESCE(rs.avg_rating, 0.0)::double precision AS rating,
|
||||
COALESCE(av.versions, ARRAY[slv.version::text]) AS versions
|
||||
FROM "StoreListing" sl
|
||||
INNER JOIN "StoreListingVersion" slv
|
||||
ON slv."storeListingId" = sl.id
|
||||
AND slv."submissionStatus" = 'APPROVED'
|
||||
JOIN "AgentGraph" a
|
||||
ON slv."agentGraphId" = a.id
|
||||
AND slv."agentGraphVersion" = a.version
|
||||
LEFT JOIN "Profile" p
|
||||
ON sl."owningUserId" = p."userId"
|
||||
LEFT JOIN "mv_review_stats" rs
|
||||
ON sl.id = rs."storeListingId"
|
||||
LEFT JOIN "mv_agent_run_counts" ar
|
||||
ON a.id = ar."agentGraphId"
|
||||
LEFT JOIN agent_versions av
|
||||
ON sl.id = av."storeListingId"
|
||||
WHERE sl."isDeleted" = false
|
||||
AND sl."hasApprovedVersion" = true;
|
||||
|
||||
-- CreateView
|
||||
CREATE OR REPLACE VIEW "Creator" AS
|
||||
WITH creator_listings AS (
|
||||
SELECT
|
||||
sl."owningUserId",
|
||||
sl.id AS listing_id,
|
||||
slv."agentGraphId",
|
||||
slv.categories,
|
||||
sr.score,
|
||||
ar.run_count
|
||||
FROM "StoreListing" sl
|
||||
INNER JOIN "StoreListingVersion" slv
|
||||
ON slv."storeListingId" = sl.id
|
||||
AND slv."submissionStatus" = 'APPROVED'
|
||||
LEFT JOIN "StoreListingReview" sr
|
||||
ON sr."storeListingVersionId" = slv.id
|
||||
LEFT JOIN "mv_agent_run_counts" ar
|
||||
ON ar."agentGraphId" = slv."agentGraphId"
|
||||
WHERE sl."isDeleted" = false
|
||||
AND sl."hasApprovedVersion" = true
|
||||
),
|
||||
creator_stats AS (
|
||||
SELECT
|
||||
cl."owningUserId",
|
||||
COUNT(DISTINCT cl.listing_id) AS num_agents,
|
||||
AVG(COALESCE(cl.score, 0)::numeric) AS agent_rating,
|
||||
SUM(DISTINCT COALESCE(cl.run_count, 0)) AS agent_runs,
|
||||
array_agg(DISTINCT cat ORDER BY cat) FILTER (WHERE cat IS NOT NULL) AS all_categories
|
||||
FROM creator_listings cl
|
||||
LEFT JOIN LATERAL unnest(COALESCE(cl.categories, ARRAY[]::text[])) AS cat ON true
|
||||
GROUP BY cl."owningUserId"
|
||||
)
|
||||
SELECT
|
||||
p.username,
|
||||
p.name,
|
||||
p."avatarUrl" AS avatar_url,
|
||||
p.description,
|
||||
cs.all_categories AS top_categories,
|
||||
p.links,
|
||||
p."isFeatured" AS is_featured,
|
||||
COALESCE(cs.num_agents, 0::bigint) AS num_agents,
|
||||
COALESCE(cs.agent_rating, 0.0) AS agent_rating,
|
||||
COALESCE(cs.agent_runs, 0::numeric) AS agent_runs
|
||||
FROM "Profile" p
|
||||
LEFT JOIN creator_stats cs ON cs."owningUserId" = p."userId";
|
||||
|
||||
-- Create refresh function that works with the current schema
CREATE OR REPLACE FUNCTION refresh_store_materialized_views()
RETURNS void
LANGUAGE plpgsql
AS $$
DECLARE
    current_schema_name text;
BEGIN
    -- Get the current schema
    current_schema_name := current_schema();

    -- Use CONCURRENTLY for better performance during refresh
    EXECUTE format('REFRESH MATERIALIZED VIEW CONCURRENTLY %I."mv_agent_run_counts"', current_schema_name);
    EXECUTE format('REFRESH MATERIALIZED VIEW CONCURRENTLY %I."mv_review_stats"', current_schema_name);
    RAISE NOTICE 'Materialized views refreshed in schema % at %', current_schema_name, NOW();
EXCEPTION
    WHEN OTHERS THEN
        -- Fallback to non-concurrent refresh if concurrent fails
        EXECUTE format('REFRESH MATERIALIZED VIEW %I."mv_agent_run_counts"', current_schema_name);
        EXECUTE format('REFRESH MATERIALIZED VIEW %I."mv_review_stats"', current_schema_name);
        RAISE NOTICE 'Materialized views refreshed (non-concurrent) in schema % at % due to: %', current_schema_name, NOW(), SQLERRM;
END;
$$;

-- Initial refresh of materialized views
SELECT refresh_store_materialized_views();

-- Schedule automatic refresh every 15 minutes (only if pg_cron is available)
DO $$
DECLARE
    has_pg_cron BOOLEAN;
    current_schema_name text;
    job_name text;
BEGIN
    -- Get the flag we set earlier
    has_pg_cron := current_setting('migration.has_pg_cron', true)::boolean;

    -- Get current schema name
    current_schema_name := current_schema();

    -- Create a unique job name for this schema
    job_name := format('refresh-store-views-%s', current_schema_name);

    IF has_pg_cron THEN
        -- Try to unschedule existing job (ignore errors if it doesn't exist)
        BEGIN
            PERFORM cron.unschedule(job_name);
        EXCEPTION WHEN OTHERS THEN
            -- Job doesn't exist, that's fine
            NULL;
        END;

        -- Schedule the refresh job with schema-specific command
        PERFORM cron.schedule(
            job_name,
            '*/15 * * * *',
            format('SELECT %I.refresh_store_materialized_views();', current_schema_name)
        );
        RAISE NOTICE 'Scheduled automatic refresh of materialized views every 15 minutes for schema %', current_schema_name;
    ELSE
        RAISE WARNING '⚠️ Automatic refresh NOT configured - pg_cron is not available';
        RAISE WARNING '⚠️ You must manually refresh views with: SELECT refresh_store_materialized_views();';
        RAISE WARNING '⚠️ Or install pg_cron for automatic refresh in production';
    END IF;
END;
$$;
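Where pg_cron is unavailable, the warnings above leave the refresh to the operator. A minimal, hedged sketch of doing that from the backend with prisma-client-py follows; it assumes only the function created by this migration and a configured DATABASE_URL.

# Hedged sketch: manually refresh the store materialized views without pg_cron.
import asyncio

from prisma import Prisma


async def refresh_store_views() -> None:
    db = Prisma()
    await db.connect()
    try:
        # Calls the function defined in the migration above.
        await db.query_raw("SELECT refresh_store_materialized_views();")
    finally:
        await db.disconnect()


if __name__ == "__main__":
    asyncio.run(refresh_store_views())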
@@ -0,0 +1,155 @@
-- Unschedule cron job (if it exists)
DO $$
BEGIN
    IF EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pg_cron') THEN
        PERFORM cron.unschedule('refresh-store-views');
        RAISE NOTICE 'Unscheduled automatic refresh of materialized views';
    END IF;
EXCEPTION
    WHEN OTHERS THEN
        RAISE NOTICE 'Could not unschedule cron job (may not exist): %', SQLERRM;
END;
$$;
|
||||
-- DropView
|
||||
DROP VIEW IF EXISTS "Creator";
|
||||
|
||||
-- DropView
|
||||
DROP VIEW IF EXISTS "StoreAgent";
|
||||
|
||||
-- CreateView (restore original StoreAgent)
|
||||
CREATE VIEW "StoreAgent" AS
|
||||
WITH reviewstats AS (
|
||||
SELECT sl_1.id AS "storeListingId",
|
||||
count(sr.id) AS review_count,
|
||||
avg(sr.score::numeric) AS avg_rating
|
||||
FROM "StoreListing" sl_1
|
||||
JOIN "StoreListingVersion" slv_1
|
||||
ON slv_1."storeListingId" = sl_1.id
|
||||
JOIN "StoreListingReview" sr
|
||||
ON sr."storeListingVersionId" = slv_1.id
|
||||
WHERE sl_1."isDeleted" = false
|
||||
GROUP BY sl_1.id
|
||||
), agentruns AS (
|
||||
SELECT "AgentGraphExecution"."agentGraphId",
|
||||
count(*) AS run_count
|
||||
FROM "AgentGraphExecution"
|
||||
GROUP BY "AgentGraphExecution"."agentGraphId"
|
||||
)
|
||||
SELECT sl.id AS listing_id,
|
||||
slv.id AS "storeListingVersionId",
|
||||
slv."createdAt" AS updated_at,
|
||||
sl.slug,
|
||||
COALESCE(slv.name, '') AS agent_name,
|
||||
slv."videoUrl" AS agent_video,
|
||||
COALESCE(slv."imageUrls", ARRAY[]::text[]) AS agent_image,
|
||||
slv."isFeatured" AS featured,
|
||||
p.username AS creator_username,
|
||||
p."avatarUrl" AS creator_avatar,
|
||||
slv."subHeading" AS sub_heading,
|
||||
slv.description,
|
||||
slv.categories,
|
||||
COALESCE(ar.run_count, 0::bigint) AS runs,
|
||||
COALESCE(rs.avg_rating, 0.0)::double precision AS rating,
|
||||
array_agg(DISTINCT slv.version::text) AS versions
|
||||
FROM "StoreListing" sl
|
||||
JOIN "StoreListingVersion" slv
|
||||
ON slv."storeListingId" = sl.id
|
||||
JOIN "AgentGraph" a
|
||||
ON slv."agentGraphId" = a.id
|
||||
AND slv."agentGraphVersion" = a.version
|
||||
LEFT JOIN "Profile" p
|
||||
ON sl."owningUserId" = p."userId"
|
||||
LEFT JOIN reviewstats rs
|
||||
ON sl.id = rs."storeListingId"
|
||||
LEFT JOIN agentruns ar
|
||||
ON a.id = ar."agentGraphId"
|
||||
WHERE sl."isDeleted" = false
|
||||
AND sl."hasApprovedVersion" = true
|
||||
AND slv."submissionStatus" = 'APPROVED'
|
||||
GROUP BY sl.id, slv.id, sl.slug, slv."createdAt", slv.name, slv."videoUrl",
|
||||
slv."imageUrls", slv."isFeatured", p.username, p."avatarUrl",
|
||||
slv."subHeading", slv.description, slv.categories, ar.run_count,
|
||||
rs.avg_rating;
|
||||
|
||||
-- CreateView (restore original Creator)
|
||||
CREATE VIEW "Creator" AS
|
||||
WITH agentstats AS (
|
||||
SELECT p_1.username,
|
||||
count(DISTINCT sl.id) AS num_agents,
|
||||
avg(COALESCE(sr.score, 0)::numeric) AS agent_rating,
|
||||
sum(COALESCE(age.run_count, 0::bigint)) AS agent_runs
|
||||
FROM "Profile" p_1
|
||||
LEFT JOIN "StoreListing" sl
|
||||
ON sl."owningUserId" = p_1."userId"
|
||||
LEFT JOIN "StoreListingVersion" slv
|
||||
ON slv."storeListingId" = sl.id
|
||||
LEFT JOIN "StoreListingReview" sr
|
||||
ON sr."storeListingVersionId" = slv.id
|
||||
LEFT JOIN (
|
||||
SELECT "AgentGraphExecution"."agentGraphId",
|
||||
count(*) AS run_count
|
||||
FROM "AgentGraphExecution"
|
||||
GROUP BY "AgentGraphExecution"."agentGraphId"
|
||||
) age ON age."agentGraphId" = slv."agentGraphId"
|
||||
WHERE sl."isDeleted" = false
|
||||
AND sl."hasApprovedVersion" = true
|
||||
AND slv."submissionStatus" = 'APPROVED'
|
||||
GROUP BY p_1.username
|
||||
)
|
||||
SELECT p.username,
|
||||
p.name,
|
||||
p."avatarUrl" AS avatar_url,
|
||||
p.description,
|
||||
array_agg(DISTINCT cats.c) FILTER (WHERE cats.c IS NOT NULL) AS top_categories,
|
||||
p.links,
|
||||
p."isFeatured" AS is_featured,
|
||||
COALESCE(ast.num_agents, 0::bigint) AS num_agents,
|
||||
COALESCE(ast.agent_rating, 0.0) AS agent_rating,
|
||||
COALESCE(ast.agent_runs, 0::numeric) AS agent_runs
|
||||
FROM "Profile" p
|
||||
LEFT JOIN agentstats ast
|
||||
ON ast.username = p.username
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT unnest(slv.categories) AS c
|
||||
FROM "StoreListing" sl
|
||||
JOIN "StoreListingVersion" slv
|
||||
ON slv."storeListingId" = sl.id
|
||||
WHERE sl."owningUserId" = p."userId"
|
||||
AND sl."isDeleted" = false
|
||||
AND sl."hasApprovedVersion" = true
|
||||
AND slv."submissionStatus" = 'APPROVED'
|
||||
) cats ON true
|
||||
GROUP BY p.username, p.name, p."avatarUrl", p.description, p.links,
|
||||
p."isFeatured", ast.num_agents, ast.agent_rating, ast.agent_runs;
|
||||
|
||||
-- Drop function
|
||||
DROP FUNCTION IF EXISTS platform.refresh_store_materialized_views();
|
||||
|
||||
-- Drop materialized views
|
||||
DROP MATERIALIZED VIEW IF EXISTS "mv_review_stats";
|
||||
DROP MATERIALIZED VIEW IF EXISTS "mv_agent_run_counts";
|
||||
|
||||
-- DropIndex
|
||||
DROP INDEX IF EXISTS "idx_profile_user";
|
||||
|
||||
-- DropIndex
|
||||
DROP INDEX IF EXISTS "idx_agent_graph_execution_agent";
|
||||
|
||||
-- DropIndex
|
||||
DROP INDEX IF EXISTS "idx_store_listing_review_version";
|
||||
|
||||
-- DropIndex
|
||||
DROP INDEX IF EXISTS "idx_slv_agent";
|
||||
|
||||
-- DropIndex
|
||||
DROP INDEX IF EXISTS "idx_slv_categories_gin";
|
||||
|
||||
-- DropIndex
|
||||
DROP INDEX IF EXISTS "idx_store_listing_version_status";
|
||||
|
||||
-- DropIndex
|
||||
DROP INDEX IF EXISTS "idx_store_listing_approved";
|
||||
|
||||
-- DropIndex
|
||||
DROP INDEX IF EXISTS "idx_store_listing_version_approved_listing";
|
||||
@@ -123,3 +123,4 @@ filterwarnings = [
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py310"
|
||||
|
||||
|
||||
110
autogpt_platform/backend/run_test_data.py
Normal file
@@ -0,0 +1,110 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Run test data creation and update scripts in sequence.
|
||||
|
||||
Usage:
|
||||
poetry run python run_test_data.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def run_command(cmd: list[str], cwd: Path | None = None) -> bool:
|
||||
"""Run a command and return True if successful."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd, check=True, capture_output=True, text=True, cwd=cwd
|
||||
)
|
||||
if result.stdout:
|
||||
print(result.stdout)
|
||||
return True
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error running command: {' '.join(cmd)}")
|
||||
print(f"Error: {e.stderr}")
|
||||
return False
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main function to run test data scripts."""
|
||||
print("=" * 60)
|
||||
print("Running Test Data Scripts for AutoGPT Platform")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
# Get the backend directory
|
||||
backend_dir = Path(__file__).parent
|
||||
test_dir = backend_dir / "test"
|
||||
|
||||
# Check if we're in the right directory
|
||||
if not (backend_dir / "pyproject.toml").exists():
|
||||
print("ERROR: This script must be run from the backend directory")
|
||||
sys.exit(1)
|
||||
|
||||
print("1. Checking database connection...")
|
||||
print("-" * 40)
|
||||
|
||||
# Import here to ensure proper environment setup
|
||||
try:
|
||||
from prisma import Prisma
|
||||
|
||||
db = Prisma()
|
||||
await db.connect()
|
||||
print("✓ Database connection successful")
|
||||
await db.disconnect()
|
||||
except Exception as e:
|
||||
print(f"✗ Database connection failed: {e}")
|
||||
print("\nPlease ensure:")
|
||||
print("1. The database services are running (docker compose up -d)")
|
||||
print("2. The DATABASE_URL in .env is correct")
|
||||
print("3. Migrations have been run (poetry run prisma migrate deploy)")
|
||||
sys.exit(1)
|
||||
|
||||
print()
|
||||
print("2. Running test data creator...")
|
||||
print("-" * 40)
|
||||
|
||||
# Run test_data_creator.py
|
||||
if run_command(["poetry", "run", "python", "test_data_creator.py"], cwd=test_dir):
|
||||
print()
|
||||
print("✅ Test data created successfully!")
|
||||
|
||||
print()
|
||||
print("3. Running test data updater...")
|
||||
print("-" * 40)
|
||||
|
||||
# Run test_data_updater.py
|
||||
if run_command(
|
||||
["poetry", "run", "python", "test_data_updater.py"], cwd=test_dir
|
||||
):
|
||||
print()
|
||||
print("✅ Test data updated successfully!")
|
||||
else:
|
||||
print()
|
||||
print("❌ Test data updater failed!")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print()
|
||||
print("❌ Test data creator failed!")
|
||||
sys.exit(1)
|
||||
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("Test data setup completed successfully!")
|
||||
print("=" * 60)
|
||||
print()
|
||||
print("The materialized views have been populated with test data:")
|
||||
print("- mv_agent_run_counts: Agent execution statistics")
|
||||
print("- mv_review_stats: Store listing review statistics")
|
||||
print()
|
||||
print("You can now:")
|
||||
print("1. Run tests: poetry run test")
|
||||
print("2. Start the backend: poetry run serve")
|
||||
print("3. View data in the database")
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
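As a quick, hedged check that the scripts above actually populated the materialized views, the following sketch reads them back with raw SQL via prisma-client-py; the helper name is illustrative and not part of the repo.

# Hedged sketch: verify the materialized views contain rows after test data setup.
import asyncio

from prisma import Prisma


async def check_views() -> None:
    db = Prisma()
    await db.connect()
    try:
        runs = await db.query_raw('SELECT COUNT(*) AS n FROM "mv_agent_run_counts";')
        reviews = await db.query_raw('SELECT COUNT(*) AS n FROM "mv_review_stats";')
        print("mv_agent_run_counts rows:", runs[0]["n"])
        print("mv_review_stats rows:", reviews[0]["n"])
    finally:
        await db.disconnect()


if __name__ == "__main__":
    asyncio.run(check_views())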
@@ -13,8 +13,10 @@ def wait_for_postgres(max_retries=5, delay=5):
|
||||
"compose",
|
||||
"-f",
|
||||
"docker-compose.test.yaml",
|
||||
"--env-file",
|
||||
"../.env",
|
||||
"exec",
|
||||
"postgres-test",
|
||||
"db",
|
||||
"pg_isready",
|
||||
"-U",
|
||||
"postgres",
|
||||
@@ -51,6 +53,8 @@ def test():
|
||||
"compose",
|
||||
"-f",
|
||||
"docker-compose.test.yaml",
|
||||
"--env-file",
|
||||
"../.env",
|
||||
"up",
|
||||
"-d",
|
||||
]
|
||||
@@ -74,11 +78,20 @@ def test():
|
||||
# to their development database, running tests would wipe their local data!
|
||||
test_env = os.environ.copy()
|
||||
|
||||
# Use environment variables if set, otherwise use defaults that match docker-compose.test.yaml
|
||||
db_user = os.getenv("DB_USER", "postgres")
|
||||
db_pass = os.getenv("DB_PASS", "postgres")
|
||||
db_name = os.getenv("DB_NAME", "postgres")
|
||||
db_port = os.getenv("DB_PORT", "5432")
|
||||
# Load database configuration from .env file
|
||||
dotenv_path = os.path.join(os.path.dirname(__file__), "../.env")
|
||||
if os.path.exists(dotenv_path):
|
||||
with open(dotenv_path) as f:
|
||||
for line in f:
|
||||
if line.strip() and not line.startswith("#"):
|
||||
key, value = line.strip().split("=", 1)
|
||||
os.environ[key] = value
|
||||
|
||||
# Get database config from environment (now populated from .env)
|
||||
db_user = os.getenv("POSTGRES_USER", "postgres")
|
||||
db_pass = os.getenv("POSTGRES_PASSWORD", "postgres")
|
||||
db_name = os.getenv("POSTGRES_DB", "postgres")
|
||||
db_port = os.getenv("POSTGRES_PORT", "5432")
|
||||
|
||||
# Construct the test database URL - this ensures we're always pointing to the test container
|
||||
test_env["DATABASE_URL"] = (
|
||||
|
||||
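The right-hand side of that DATABASE_URL assignment is truncated in this hunk and stays as-is above. As a hedged sketch only (not the repo's actual code), a Postgres URL assembled from those variables generally takes this shape; the host name is an assumption.

# Hedged sketch: typical shape of the test DATABASE_URL built from the env values.
db_user = "postgres"
db_pass = "postgres"
db_name = "postgres"
db_port = "5432"
database_url = f"postgresql://{db_user}:{db_pass}@localhost:{db_port}/{db_name}"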
@@ -599,7 +599,23 @@ view Creator {
|
||||
agent_runs Int
|
||||
is_featured Boolean
|
||||
|
||||
// Index or unique are not applied to views
|
||||
// Note: Prisma doesn't support indexes on views, but the following indexes exist in the database:
|
||||
//
|
||||
// Optimized indexes (partial indexes to reduce size and improve performance):
|
||||
// - idx_profile_user on Profile(userId)
|
||||
// - idx_store_listing_approved on StoreListing(owningUserId) WHERE isDeleted = false AND hasApprovedVersion = true
|
||||
// - idx_store_listing_version_status on StoreListingVersion(storeListingId) WHERE submissionStatus = 'APPROVED'
|
||||
// - idx_slv_categories_gin - GIN index on StoreListingVersion(categories) WHERE submissionStatus = 'APPROVED'
|
||||
// - idx_slv_agent on StoreListingVersion(agentGraphId, agentGraphVersion) WHERE submissionStatus = 'APPROVED'
|
||||
// - idx_store_listing_review_version on StoreListingReview(storeListingVersionId)
|
||||
// - idx_store_listing_version_approved_listing on StoreListingVersion(storeListingId, version) WHERE submissionStatus = 'APPROVED'
|
||||
// - idx_agent_graph_execution_agent on AgentGraphExecution(agentGraphId)
|
||||
//
|
||||
// Materialized views used (refreshed every 15 minutes via pg_cron):
|
||||
// - mv_agent_run_counts - Pre-aggregated agent execution counts by agentGraphId
|
||||
// - mv_review_stats - Pre-aggregated review statistics (count, avg rating) by storeListingId
|
||||
//
|
||||
// Query strategy: Uses CTEs to efficiently aggregate creator statistics leveraging materialized views
|
||||
}
|
||||
|
||||
view StoreAgent {
|
||||
@@ -622,7 +638,30 @@ view StoreAgent {
|
||||
rating Float
|
||||
versions String[]
|
||||
|
||||
// Index or unique are not applied to views
|
||||
// Note: Prisma doesn't support indexes on views, but the following indexes exist in the database:
|
||||
//
|
||||
// Optimized indexes (partial indexes to reduce size and improve performance):
|
||||
// - idx_store_listing_approved on StoreListing(owningUserId) WHERE isDeleted = false AND hasApprovedVersion = true
|
||||
// - idx_store_listing_version_status on StoreListingVersion(storeListingId) WHERE submissionStatus = 'APPROVED'
|
||||
// - idx_slv_categories_gin - GIN index on StoreListingVersion(categories) WHERE submissionStatus = 'APPROVED' for array searches
|
||||
// - idx_slv_agent on StoreListingVersion(agentGraphId, agentGraphVersion) WHERE submissionStatus = 'APPROVED'
|
||||
// - idx_store_listing_review_version on StoreListingReview(storeListingVersionId)
|
||||
// - idx_store_listing_version_approved_listing on StoreListingVersion(storeListingId, version) WHERE submissionStatus = 'APPROVED'
|
||||
// - idx_agent_graph_execution_agent on AgentGraphExecution(agentGraphId)
|
||||
// - idx_profile_user on Profile(userId)
|
||||
//
|
||||
// Additional indexes from earlier migrations:
|
||||
// - StoreListing_agentId_owningUserId_idx
|
||||
// - StoreListing_isDeleted_isApproved_idx (replaced by idx_store_listing_approved)
|
||||
// - StoreListing_isDeleted_idx
|
||||
// - StoreListing_agentId_key (unique on agentGraphId)
|
||||
// - StoreListingVersion_agentId_agentVersion_isDeleted_idx
|
||||
//
|
||||
// Materialized views used (refreshed every 15 minutes via pg_cron):
|
||||
// - mv_agent_run_counts - Pre-aggregated agent execution counts by agentGraphId
|
||||
// - mv_review_stats - Pre-aggregated review statistics (count, avg rating) by storeListingId
|
||||
//
|
||||
// Query strategy: Uses CTE for version aggregation and joins with materialized views for performance
|
||||
}
|
||||
|
||||
view StoreSubmission {
|
||||
@@ -649,6 +688,33 @@ view StoreSubmission {
|
||||
// Index or unique are not applied to views
|
||||
}
|
||||
|
||||
// Note: This is actually a MATERIALIZED VIEW in the database
|
||||
// Refreshed automatically every 15 minutes via pg_cron (with fallback to manual refresh)
|
||||
view mv_agent_run_counts {
|
||||
agentGraphId String @unique
|
||||
run_count Int
|
||||
|
||||
// Pre-aggregated count of AgentGraphExecution records by agentGraphId
|
||||
// Used by StoreAgent and Creator views for performance optimization
|
||||
// Unique index created automatically on agentGraphId for fast lookups
|
||||
// Refresh uses CONCURRENTLY to avoid blocking reads
|
||||
}
|
||||
|
||||
// Note: This is actually a MATERIALIZED VIEW in the database
|
||||
// Refreshed automatically every 15 minutes via pg_cron (with fallback to manual refresh)
|
||||
view mv_review_stats {
|
||||
storeListingId String @unique
|
||||
review_count Int
|
||||
avg_rating Float
|
||||
|
||||
// Pre-aggregated review statistics from StoreListingReview
|
||||
// Includes count of reviews and average rating per StoreListing
|
||||
// Only includes approved versions (submissionStatus = 'APPROVED') and non-deleted listings
|
||||
// Used by StoreAgent view for performance optimization
|
||||
// Unique index created automatically on storeListingId for fast lookups
|
||||
// Refresh uses CONCURRENTLY to avoid blocking reads
|
||||
}
|
||||
|
||||
model StoreListing {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
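Since Prisma does not allow indexes or relations on these views, reads typically go through raw SQL; a hedged sketch of pulling the pre-aggregated run counts described in the comments above (the helper name is illustrative):

# Hedged sketch: read pre-aggregated run counts from the materialized view.
from prisma import Prisma


async def top_run_counts(db: Prisma, limit: int = 5) -> list[dict]:
    return await db.query_raw(
        'SELECT "agentGraphId", run_count'
        ' FROM "mv_agent_run_counts" ORDER BY run_count DESC LIMIT $1;',
        limit,
    )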
@@ -7,7 +7,7 @@
|
||||
"description": "A test graph",
|
||||
"forked_from_id": null,
|
||||
"forked_from_version": null,
|
||||
"has_webhook_trigger": false,
|
||||
"has_external_trigger": false,
|
||||
"id": "graph-123",
|
||||
"input_schema": {
|
||||
"properties": {},
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
"description": "A test graph",
|
||||
"forked_from_id": null,
|
||||
"forked_from_version": null,
|
||||
"has_webhook_trigger": false,
|
||||
"has_external_trigger": false,
|
||||
"id": "graph-123",
|
||||
"input_schema": {
|
||||
"properties": {},
|
||||
@@ -16,9 +16,7 @@
|
||||
"type": "object"
|
||||
},
|
||||
"is_active": true,
|
||||
"links": [],
|
||||
"name": "Test Graph",
|
||||
"nodes": [],
|
||||
"output_schema": {
|
||||
"properties": {},
|
||||
"required": [],
|
||||
|
||||
1
autogpt_platform/backend/test/sdk/__init__.py
Normal file
@@ -0,0 +1 @@
"""SDK test module."""
20
autogpt_platform/backend/test/sdk/_config.py
Normal file
@@ -0,0 +1,20 @@
"""
Shared configuration for SDK test providers using the SDK pattern.
"""

from backend.sdk import BlockCostType, ProviderBuilder

# Configure test providers
test_api = (
    ProviderBuilder("test_api")
    .with_api_key("TEST_API_KEY", "Test API Key")
    .with_base_cost(5, BlockCostType.RUN)
    .build()
)

test_service = (
    ProviderBuilder("test_service")
    .with_api_key("TEST_SERVICE_API_KEY", "Test Service API Key")
    .with_base_cost(10, BlockCostType.RUN)
    .build()
)
29
autogpt_platform/backend/test/sdk/conftest.py
Normal file
@@ -0,0 +1,29 @@
"""
Configuration for SDK tests.

This conftest.py file provides basic test setup for SDK unit tests
without requiring the full server infrastructure.
"""

from unittest.mock import MagicMock

import pytest


@pytest.fixture(scope="session")
def server():
    """Mock server fixture for SDK tests."""
    mock_server = MagicMock()
    mock_server.agent_server = MagicMock()
    mock_server.agent_server.test_create_graph = MagicMock()
    return mock_server


@pytest.fixture(autouse=True)
def reset_registry():
    """Reset the AutoRegistry before each test."""
    from backend.sdk.registry import AutoRegistry

    AutoRegistry.clear()
    yield
    AutoRegistry.clear()
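The autouse reset above matters because building a provider registers it globally; a hedged illustration follows (the provider name is made up for the example).

# Hedged sketch: without AutoRegistry.clear(), a provider built in one test
# would still be visible to the next test via the global registry.
from backend.sdk import AutoRegistry, ProviderBuilder

demo = ProviderBuilder("demo").with_api_key("DEMO_KEY", "Demo Key").build()
assert AutoRegistry.get_provider("demo") is not None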
914
autogpt_platform/backend/test/sdk/test_sdk_block_creation.py
Normal file
@@ -0,0 +1,914 @@
|
||||
"""
|
||||
Tests for creating blocks using the SDK.
|
||||
|
||||
This test suite verifies that blocks can be created using only SDK imports
|
||||
and that they work correctly without decorators.
|
||||
"""
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockCostType,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
OAuth2Credentials,
|
||||
ProviderBuilder,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
)
|
||||
|
||||
from ._config import test_api, test_service
|
||||
|
||||
|
||||
class TestBasicBlockCreation:
|
||||
"""Test creating basic blocks using the SDK."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_block(self):
|
||||
"""Test creating a simple block without any decorators."""
|
||||
|
||||
class SimpleBlock(Block):
|
||||
"""A simple test block."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
text: str = SchemaField(description="Input text")
|
||||
count: int = SchemaField(description="Repeat count", default=1)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: str = SchemaField(description="Output result")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="simple-test-block",
|
||||
description="A simple test block",
|
||||
categories={BlockCategory.TEXT},
|
||||
input_schema=SimpleBlock.Input,
|
||||
output_schema=SimpleBlock.Output,
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
result = input_data.text * input_data.count
|
||||
yield "result", result
|
||||
|
||||
# Create and test the block
|
||||
block = SimpleBlock()
|
||||
assert block.id == "simple-test-block"
|
||||
assert BlockCategory.TEXT in block.categories
|
||||
|
||||
# Test execution
|
||||
outputs = []
|
||||
async for name, value in block.run(
|
||||
SimpleBlock.Input(text="Hello ", count=3),
|
||||
):
|
||||
outputs.append((name, value))
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0] == ("result", "Hello Hello Hello ")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_block_with_credentials(self):
|
||||
"""Test creating a block that requires credentials."""
|
||||
|
||||
class APIBlock(Block):
|
||||
"""A block that requires API credentials."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = test_api.credentials_field(
|
||||
description="API credentials for test service",
|
||||
)
|
||||
query: str = SchemaField(description="API query")
|
||||
|
||||
class Output(BlockSchema):
|
||||
response: str = SchemaField(description="API response")
|
||||
authenticated: bool = SchemaField(description="Was authenticated")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="api-test-block",
|
||||
description="Test block with API credentials",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=APIBlock.Input,
|
||||
output_schema=APIBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Simulate API call
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
authenticated = bool(api_key)
|
||||
|
||||
yield "response", f"API response for: {input_data.query}"
|
||||
yield "authenticated", authenticated
|
||||
|
||||
# Create test credentials
|
||||
test_creds = APIKeyCredentials(
|
||||
id="test-creds",
|
||||
provider="test_api",
|
||||
api_key=SecretStr("test-api-key"),
|
||||
title="Test API Key",
|
||||
)
|
||||
|
||||
# Create and test the block
|
||||
block = APIBlock()
|
||||
outputs = []
|
||||
async for name, value in block.run(
|
||||
APIBlock.Input(
|
||||
credentials={ # type: ignore
|
||||
"provider": "test_api",
|
||||
"id": "test-creds",
|
||||
"type": "api_key",
|
||||
},
|
||||
query="test query",
|
||||
),
|
||||
credentials=test_creds,
|
||||
):
|
||||
outputs.append((name, value))
|
||||
|
||||
assert len(outputs) == 2
|
||||
assert outputs[0] == ("response", "API response for: test query")
|
||||
assert outputs[1] == ("authenticated", True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_block_with_multiple_outputs(self):
|
||||
"""Test block that yields multiple outputs."""
|
||||
|
||||
class MultiOutputBlock(Block):
|
||||
"""Block with multiple outputs."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
text: str = SchemaField(description="Input text")
|
||||
|
||||
class Output(BlockSchema):
|
||||
uppercase: str = SchemaField(description="Uppercase version")
|
||||
lowercase: str = SchemaField(description="Lowercase version")
|
||||
length: int = SchemaField(description="Text length")
|
||||
is_empty: bool = SchemaField(description="Is text empty")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="multi-output-block",
|
||||
description="Block with multiple outputs",
|
||||
categories={BlockCategory.TEXT},
|
||||
input_schema=MultiOutputBlock.Input,
|
||||
output_schema=MultiOutputBlock.Output,
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
text = input_data.text
|
||||
yield "uppercase", text.upper()
|
||||
yield "lowercase", text.lower()
|
||||
yield "length", len(text)
|
||||
yield "is_empty", len(text) == 0
|
||||
|
||||
# Test the block
|
||||
block = MultiOutputBlock()
|
||||
outputs = []
|
||||
async for name, value in block.run(MultiOutputBlock.Input(text="Hello World")):
|
||||
outputs.append((name, value))
|
||||
|
||||
assert len(outputs) == 4
|
||||
assert ("uppercase", "HELLO WORLD") in outputs
|
||||
assert ("lowercase", "hello world") in outputs
|
||||
assert ("length", 11) in outputs
|
||||
assert ("is_empty", False) in outputs
|
||||
|
||||
|
||||
class TestBlockWithProvider:
|
||||
"""Test creating blocks associated with providers."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_block_using_provider(self):
|
||||
"""Test block that uses a registered provider."""
|
||||
|
||||
class TestServiceBlock(Block):
|
||||
"""Block for test service."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = test_service.credentials_field(
|
||||
description="Test service credentials",
|
||||
)
|
||||
action: str = SchemaField(description="Action to perform")
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: str = SchemaField(description="Action result")
|
||||
provider_name: str = SchemaField(description="Provider used")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="test-service-block",
|
||||
description="Block using test service provider",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=TestServiceBlock.Input,
|
||||
output_schema=TestServiceBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# The provider name should match
|
||||
yield "result", f"Performed: {input_data.action}"
|
||||
yield "provider_name", credentials.provider
|
||||
|
||||
# Create credentials for our provider
|
||||
creds = APIKeyCredentials(
|
||||
id="test-service-creds",
|
||||
provider="test_service",
|
||||
api_key=SecretStr("test-key"),
|
||||
title="Test Service Key",
|
||||
)
|
||||
|
||||
# Test the block
|
||||
block = TestServiceBlock()
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
TestServiceBlock.Input(
|
||||
credentials={ # type: ignore
|
||||
"provider": "test_service",
|
||||
"id": "test-service-creds",
|
||||
"type": "api_key",
|
||||
},
|
||||
action="test action",
|
||||
),
|
||||
credentials=creds,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["result"] == "Performed: test action"
|
||||
assert outputs["provider_name"] == "test_service"
|
||||
|
||||
|
||||
class TestComplexBlockScenarios:
|
||||
"""Test more complex block scenarios."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_block_with_optional_fields(self):
|
||||
"""Test block with optional input fields."""
|
||||
# Optional is already imported at the module level
|
||||
|
||||
class OptionalFieldBlock(Block):
|
||||
"""Block with optional fields."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
required_field: str = SchemaField(description="Required field")
|
||||
optional_field: Optional[str] = SchemaField(
|
||||
description="Optional field",
|
||||
default=None,
|
||||
)
|
||||
optional_with_default: str = SchemaField(
|
||||
description="Optional with default",
|
||||
default="default value",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
has_optional: bool = SchemaField(description="Has optional value")
|
||||
optional_value: Optional[str] = SchemaField(
|
||||
description="Optional value"
|
||||
)
|
||||
default_value: str = SchemaField(description="Default value")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="optional-field-block",
|
||||
description="Block with optional fields",
|
||||
categories={BlockCategory.TEXT},
|
||||
input_schema=OptionalFieldBlock.Input,
|
||||
output_schema=OptionalFieldBlock.Output,
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "has_optional", input_data.optional_field is not None
|
||||
yield "optional_value", input_data.optional_field
|
||||
yield "default_value", input_data.optional_with_default
|
||||
|
||||
# Test with optional field provided
|
||||
block = OptionalFieldBlock()
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
OptionalFieldBlock.Input(
|
||||
required_field="test",
|
||||
optional_field="provided",
|
||||
)
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["has_optional"] is True
|
||||
assert outputs["optional_value"] == "provided"
|
||||
assert outputs["default_value"] == "default value"
|
||||
|
||||
# Test without optional field
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
OptionalFieldBlock.Input(
|
||||
required_field="test",
|
||||
)
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["has_optional"] is False
|
||||
assert outputs["optional_value"] is None
|
||||
assert outputs["default_value"] == "default value"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_block_with_complex_types(self):
|
||||
"""Test block with complex input/output types."""
|
||||
|
||||
class ComplexBlock(Block):
|
||||
"""Block with complex types."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
items: list[str] = SchemaField(description="List of items")
|
||||
mapping: dict[str, int] = SchemaField(
|
||||
description="String to int mapping"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
item_count: int = SchemaField(description="Number of items")
|
||||
total_value: int = SchemaField(description="Sum of mapping values")
|
||||
combined: list[str] = SchemaField(description="Combined results")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="complex-types-block",
|
||||
description="Block with complex types",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=ComplexBlock.Input,
|
||||
output_schema=ComplexBlock.Output,
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "item_count", len(input_data.items)
|
||||
yield "total_value", sum(input_data.mapping.values())
|
||||
|
||||
# Combine items with their mapping values
|
||||
combined = []
|
||||
for item in input_data.items:
|
||||
value = input_data.mapping.get(item, 0)
|
||||
combined.append(f"{item}: {value}")
|
||||
|
||||
yield "combined", combined
|
||||
|
||||
# Test the block
|
||||
block = ComplexBlock()
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
ComplexBlock.Input(
|
||||
items=["apple", "banana", "orange"],
|
||||
mapping={"apple": 5, "banana": 3, "orange": 4},
|
||||
)
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["item_count"] == 3
|
||||
assert outputs["total_value"] == 12
|
||||
assert outputs["combined"] == ["apple: 5", "banana: 3", "orange: 4"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_block_error_handling(self):
|
||||
"""Test block error handling."""
|
||||
|
||||
class ErrorHandlingBlock(Block):
|
||||
"""Block that demonstrates error handling."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
value: int = SchemaField(description="Input value")
|
||||
should_error: bool = SchemaField(
|
||||
description="Whether to trigger an error",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: int = SchemaField(description="Result")
|
||||
error_message: Optional[str] = SchemaField(
|
||||
description="Error if any", default=None
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="error-handling-block",
|
||||
description="Block with error handling",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=ErrorHandlingBlock.Input,
|
||||
output_schema=ErrorHandlingBlock.Output,
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
if input_data.should_error:
|
||||
raise ValueError("Intentional error triggered")
|
||||
|
||||
if input_data.value < 0:
|
||||
yield "error_message", "Value must be non-negative"
|
||||
yield "result", 0
|
||||
else:
|
||||
yield "result", input_data.value * 2
|
||||
yield "error_message", None
|
||||
|
||||
# Test normal operation
|
||||
block = ErrorHandlingBlock()
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
ErrorHandlingBlock.Input(value=5, should_error=False)
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["result"] == 10
|
||||
assert outputs["error_message"] is None
|
||||
|
||||
# Test with negative value
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
ErrorHandlingBlock.Input(value=-5, should_error=False)
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["result"] == 0
|
||||
assert outputs["error_message"] == "Value must be non-negative"
|
||||
|
||||
# Test with error
|
||||
with pytest.raises(ValueError, match="Intentional error triggered"):
|
||||
async for _ in block.run(
|
||||
ErrorHandlingBlock.Input(value=5, should_error=True)
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class TestAuthenticationVariants:
|
||||
"""Test complex authentication scenarios including OAuth, API keys, and scopes."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_block_with_scopes(self):
|
||||
"""Test creating a block that uses OAuth2 with scopes."""
|
||||
from backend.sdk import OAuth2Credentials, ProviderBuilder
|
||||
|
||||
# Create a test OAuth provider with scopes
|
||||
# For testing, we don't need an actual OAuth handler
|
||||
# In real usage, you would provide a proper OAuth handler class
|
||||
oauth_provider = (
|
||||
ProviderBuilder("test_oauth_provider")
|
||||
.with_api_key("TEST_OAUTH_API", "Test OAuth API")
|
||||
.with_base_cost(5, BlockCostType.RUN)
|
||||
.build()
|
||||
)
|
||||
|
||||
class OAuthScopedBlock(Block):
|
||||
"""Block requiring OAuth2 with specific scopes."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = oauth_provider.credentials_field(
|
||||
description="OAuth2 credentials with scopes",
|
||||
scopes=["read:user", "write:data"],
|
||||
)
|
||||
resource: str = SchemaField(description="Resource to access")
|
||||
|
||||
class Output(BlockSchema):
|
||||
data: str = SchemaField(description="Retrieved data")
|
||||
scopes_used: list[str] = SchemaField(
|
||||
description="Scopes that were used"
|
||||
)
|
||||
token_info: dict[str, Any] = SchemaField(
|
||||
description="Token information"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="oauth-scoped-block",
|
||||
description="Test OAuth2 with scopes",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=OAuthScopedBlock.Input,
|
||||
output_schema=OAuthScopedBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: OAuth2Credentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Simulate OAuth API call with scopes
|
||||
token = credentials.access_token.get_secret_value()
|
||||
|
||||
yield "data", f"OAuth data for {input_data.resource}"
|
||||
yield "scopes_used", credentials.scopes or []
|
||||
yield "token_info", {
|
||||
"has_token": bool(token),
|
||||
"has_refresh": credentials.refresh_token is not None,
|
||||
"provider": credentials.provider,
|
||||
"expires_at": credentials.access_token_expires_at,
|
||||
}
|
||||
|
||||
# Create test OAuth credentials
|
||||
test_oauth_creds = OAuth2Credentials(
|
||||
id="test-oauth-creds",
|
||||
provider="test_oauth_provider",
|
||||
access_token=SecretStr("test-access-token"),
|
||||
refresh_token=SecretStr("test-refresh-token"),
|
||||
scopes=["read:user", "write:data"],
|
||||
title="Test OAuth Credentials",
|
||||
)
|
||||
|
||||
# Test the block
|
||||
block = OAuthScopedBlock()
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
OAuthScopedBlock.Input(
|
||||
credentials={ # type: ignore
|
||||
"provider": "test_oauth_provider",
|
||||
"id": "test-oauth-creds",
|
||||
"type": "oauth2",
|
||||
},
|
||||
resource="user/profile",
|
||||
),
|
||||
credentials=test_oauth_creds,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["data"] == "OAuth data for user/profile"
|
||||
assert set(outputs["scopes_used"]) == {"read:user", "write:data"}
|
||||
assert outputs["token_info"]["has_token"] is True
|
||||
assert outputs["token_info"]["expires_at"] is None
|
||||
assert outputs["token_info"]["has_refresh"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_auth_block(self):
|
||||
"""Test block that supports both OAuth2 and API key authentication."""
|
||||
# No need to import these again, already imported at top
|
||||
|
||||
# Create a provider supporting API key auth
# In real usage, you would add OAuth support with .with_oauth()
|
||||
mixed_provider = (
|
||||
ProviderBuilder("mixed_auth_provider")
|
||||
.with_api_key("MIXED_API_KEY", "Mixed Provider API Key")
|
||||
.with_base_cost(8, BlockCostType.RUN)
|
||||
.build()
|
||||
)
|
||||
|
||||
class MixedAuthBlock(Block):
|
||||
"""Block supporting multiple authentication methods."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = mixed_provider.credentials_field(
|
||||
description="API key or OAuth2 credentials",
|
||||
supported_credential_types=["api_key", "oauth2"],
|
||||
)
|
||||
operation: str = SchemaField(description="Operation to perform")
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: str = SchemaField(description="Operation result")
|
||||
auth_type: str = SchemaField(description="Authentication type used")
|
||||
auth_details: dict[str, Any] = SchemaField(description="Auth details")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="mixed-auth-block",
|
||||
description="Block supporting OAuth2 and API key",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=MixedAuthBlock.Input,
|
||||
output_schema=MixedAuthBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: Union[APIKeyCredentials, OAuth2Credentials],
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Handle different credential types
|
||||
if isinstance(credentials, APIKeyCredentials):
|
||||
auth_type = "api_key"
|
||||
auth_details = {
|
||||
"has_key": bool(credentials.api_key.get_secret_value()),
|
||||
"key_prefix": credentials.api_key.get_secret_value()[:5]
|
||||
+ "...",
|
||||
}
|
||||
elif isinstance(credentials, OAuth2Credentials):
|
||||
auth_type = "oauth2"
|
||||
auth_details = {
|
||||
"has_token": bool(credentials.access_token.get_secret_value()),
|
||||
"scopes": credentials.scopes or [],
|
||||
}
|
||||
else:
|
||||
auth_type = "unknown"
|
||||
auth_details = {}
|
||||
|
||||
yield "result", f"Performed {input_data.operation} with {auth_type}"
|
||||
yield "auth_type", auth_type
|
||||
yield "auth_details", auth_details
|
||||
|
||||
# Test with API key
|
||||
api_creds = APIKeyCredentials(
|
||||
id="mixed-api-creds",
|
||||
provider="mixed_auth_provider",
|
||||
api_key=SecretStr("sk-1234567890"),
|
||||
title="Mixed API Key",
|
||||
)
|
||||
|
||||
block = MixedAuthBlock()
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
MixedAuthBlock.Input(
|
||||
credentials={ # type: ignore
|
||||
"provider": "mixed_auth_provider",
|
||||
"id": "mixed-api-creds",
|
||||
"type": "api_key",
|
||||
},
|
||||
operation="fetch_data",
|
||||
),
|
||||
credentials=api_creds,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["auth_type"] == "api_key"
|
||||
assert outputs["result"] == "Performed fetch_data with api_key"
|
||||
assert outputs["auth_details"]["key_prefix"] == "sk-12..."
|
||||
|
||||
# Test with OAuth2
|
||||
oauth_creds = OAuth2Credentials(
|
||||
id="mixed-oauth-creds",
|
||||
provider="mixed_auth_provider",
|
||||
access_token=SecretStr("oauth-token-123"),
|
||||
scopes=["full_access"],
|
||||
title="Mixed OAuth",
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
MixedAuthBlock.Input(
|
||||
credentials={ # type: ignore
|
||||
"provider": "mixed_auth_provider",
|
||||
"id": "mixed-oauth-creds",
|
||||
"type": "oauth2",
|
||||
},
|
||||
operation="update_data",
|
||||
),
|
||||
credentials=oauth_creds,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["auth_type"] == "oauth2"
|
||||
assert outputs["result"] == "Performed update_data with oauth2"
|
||||
assert outputs["auth_details"]["scopes"] == ["full_access"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_credentials_block(self):
|
||||
"""Test block requiring multiple different credentials."""
|
||||
from backend.sdk import ProviderBuilder
|
||||
|
||||
# Create multiple providers
|
||||
primary_provider = (
|
||||
ProviderBuilder("primary_service")
|
||||
.with_api_key("PRIMARY_API_KEY", "Primary Service Key")
|
||||
.build()
|
||||
)
|
||||
|
||||
# For testing purposes, using API key instead of OAuth handler
|
||||
secondary_provider = (
|
||||
ProviderBuilder("secondary_service")
|
||||
.with_api_key("SECONDARY_API_KEY", "Secondary Service Key")
|
||||
.build()
|
||||
)
|
||||
|
||||
class MultiCredentialBlock(Block):
|
||||
"""Block requiring credentials from multiple services."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
primary_credentials: CredentialsMetaInput = (
|
||||
primary_provider.credentials_field(
|
||||
description="Primary service API key"
|
||||
)
|
||||
)
|
||||
secondary_credentials: CredentialsMetaInput = (
|
||||
secondary_provider.credentials_field(
|
||||
description="Secondary service OAuth"
|
||||
)
|
||||
)
|
||||
merge_data: bool = SchemaField(
|
||||
description="Whether to merge data from both services",
|
||||
default=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
primary_data: str = SchemaField(description="Data from primary service")
|
||||
secondary_data: str = SchemaField(
|
||||
description="Data from secondary service"
|
||||
)
|
||||
merged_result: Optional[str] = SchemaField(
|
||||
description="Merged data if requested"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="multi-credential-block",
|
||||
description="Block using multiple credentials",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=MultiCredentialBlock.Input,
|
||||
output_schema=MultiCredentialBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
primary_credentials: APIKeyCredentials,
|
||||
secondary_credentials: OAuth2Credentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Simulate fetching data with primary API key
|
||||
primary_data = f"Primary data using {primary_credentials.provider}"
|
||||
yield "primary_data", primary_data
|
||||
|
||||
# Simulate fetching data with secondary OAuth
|
||||
secondary_data = f"Secondary data with {len(secondary_credentials.scopes or [])} scopes"
|
||||
yield "secondary_data", secondary_data
|
||||
|
||||
# Merge if requested
|
||||
if input_data.merge_data:
|
||||
merged = f"{primary_data} + {secondary_data}"
|
||||
yield "merged_result", merged
|
||||
else:
|
||||
yield "merged_result", None
|
||||
|
||||
# Create test credentials
|
||||
primary_creds = APIKeyCredentials(
|
||||
id="primary-creds",
|
||||
provider="primary_service",
|
||||
api_key=SecretStr("primary-key-123"),
|
||||
title="Primary Key",
|
||||
)
|
||||
|
||||
secondary_creds = OAuth2Credentials(
|
||||
id="secondary-creds",
|
||||
provider="secondary_service",
|
||||
access_token=SecretStr("secondary-token"),
|
||||
scopes=["read", "write"],
|
||||
title="Secondary OAuth",
|
||||
)
|
||||
|
||||
# Test the block
|
||||
block = MultiCredentialBlock()
|
||||
outputs = {}
|
||||
|
||||
# Note: In real usage, the framework would inject the correct credentials
|
||||
# based on the field names. Here we simulate that behavior.
|
||||
async for name, value in block.run(
|
||||
MultiCredentialBlock.Input(
|
||||
primary_credentials={ # type: ignore
|
||||
"provider": "primary_service",
|
||||
"id": "primary-creds",
|
||||
"type": "api_key",
|
||||
},
|
||||
secondary_credentials={ # type: ignore
|
||||
"provider": "secondary_service",
|
||||
"id": "secondary-creds",
|
||||
"type": "oauth2",
|
||||
},
|
||||
merge_data=True,
|
||||
),
|
||||
primary_credentials=primary_creds,
|
||||
secondary_credentials=secondary_creds,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["primary_data"] == "Primary data using primary_service"
|
||||
assert outputs["secondary_data"] == "Secondary data with 2 scopes"
|
||||
assert "Primary data" in outputs["merged_result"]
|
||||
assert "Secondary data" in outputs["merged_result"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_scope_validation(self):
|
||||
"""Test OAuth scope validation and handling."""
|
||||
from backend.sdk import OAuth2Credentials, ProviderBuilder
|
||||
|
||||
# Provider with specific required scopes
|
||||
# For testing OAuth scope validation
|
||||
scoped_provider = (
|
||||
ProviderBuilder("scoped_oauth_service")
|
||||
.with_api_key("SCOPED_OAUTH_KEY", "Scoped OAuth Service")
|
||||
.build()
|
||||
)
|
||||
|
||||
class ScopeValidationBlock(Block):
|
||||
"""Block that validates OAuth scopes."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = scoped_provider.credentials_field(
|
||||
description="OAuth credentials with specific scopes",
|
||||
scopes=["user:read", "user:write"], # Required scopes
|
||||
)
|
||||
require_admin: bool = SchemaField(
|
||||
description="Whether admin scopes are required",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
allowed_operations: list[str] = SchemaField(
|
||||
description="Operations allowed with current scopes"
|
||||
)
|
||||
missing_scopes: list[str] = SchemaField(
|
||||
description="Scopes that are missing for full access"
|
||||
)
|
||||
has_required_scopes: bool = SchemaField(
|
||||
description="Whether all required scopes are present"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="scope-validation-block",
|
||||
description="Block that validates OAuth scopes",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=ScopeValidationBlock.Input,
|
||||
output_schema=ScopeValidationBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: OAuth2Credentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
current_scopes = set(credentials.scopes or [])
|
||||
required_scopes = {"user:read", "user:write"}
|
||||
|
||||
if input_data.require_admin:
|
||||
required_scopes.update({"admin:read", "admin:write"})
|
||||
|
||||
# Determine allowed operations based on scopes
|
||||
allowed_ops = []
|
||||
if "user:read" in current_scopes:
|
||||
allowed_ops.append("read_user_data")
|
||||
if "user:write" in current_scopes:
|
||||
allowed_ops.append("update_user_data")
|
||||
if "admin:read" in current_scopes:
|
||||
allowed_ops.append("read_admin_data")
|
||||
if "admin:write" in current_scopes:
|
||||
allowed_ops.append("update_admin_data")
|
||||
|
||||
missing = list(required_scopes - current_scopes)
|
||||
has_required = len(missing) == 0
|
||||
|
||||
yield "allowed_operations", allowed_ops
|
||||
yield "missing_scopes", missing
|
||||
yield "has_required_scopes", has_required
|
||||
|
||||
# Test with partial scopes
|
||||
partial_creds = OAuth2Credentials(
|
||||
id="partial-oauth",
|
||||
provider="scoped_oauth_service",
|
||||
access_token=SecretStr("partial-token"),
|
||||
scopes=["user:read"], # Only one of the required scopes
|
||||
title="Partial OAuth",
|
||||
)
|
||||
|
||||
block = ScopeValidationBlock()
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
ScopeValidationBlock.Input(
|
||||
credentials={ # type: ignore
|
||||
"provider": "scoped_oauth_service",
|
||||
"id": "partial-oauth",
|
||||
"type": "oauth2",
|
||||
},
|
||||
require_admin=False,
|
||||
),
|
||||
credentials=partial_creds,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["allowed_operations"] == ["read_user_data"]
|
||||
assert "user:write" in outputs["missing_scopes"]
|
||||
assert outputs["has_required_scopes"] is False
|
||||
|
||||
# Test with all required scopes
|
||||
full_creds = OAuth2Credentials(
|
||||
id="full-oauth",
|
||||
provider="scoped_oauth_service",
|
||||
access_token=SecretStr("full-token"),
|
||||
scopes=["user:read", "user:write", "admin:read"],
|
||||
title="Full OAuth",
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
ScopeValidationBlock.Input(
|
||||
credentials={ # type: ignore
|
||||
"provider": "scoped_oauth_service",
|
||||
"id": "full-oauth",
|
||||
"type": "oauth2",
|
||||
},
|
||||
require_admin=False,
|
||||
),
|
||||
credentials=full_creds,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
assert set(outputs["allowed_operations"]) == {
|
||||
"read_user_data",
|
||||
"update_user_data",
|
||||
"read_admin_data",
|
||||
}
|
||||
assert outputs["missing_scopes"] == []
|
||||
assert outputs["has_required_scopes"] is True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
150
autogpt_platform/backend/test/sdk/test_sdk_patching.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Tests for the SDK's integration patching mechanism.
|
||||
|
||||
This test suite verifies that the AutoRegistry correctly patches
|
||||
existing integration points to include SDK-registered components.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.sdk import (
|
||||
AutoRegistry,
|
||||
BaseOAuthHandler,
|
||||
BaseWebhooksManager,
|
||||
ProviderBuilder,
|
||||
)
|
||||
|
||||
|
||||
class MockOAuthHandler(BaseOAuthHandler):
|
||||
"""Mock OAuth handler for testing."""
|
||||
|
||||
PROVIDER_NAME = ProviderName.GITHUB
|
||||
|
||||
@classmethod
|
||||
async def authorize(cls, *args, **kwargs):
|
||||
return "mock_auth"
|
||||
|
||||
|
||||
class MockWebhookManager(BaseWebhooksManager):
|
||||
"""Mock webhook manager for testing."""
|
||||
|
||||
PROVIDER_NAME = ProviderName.GITHUB
|
||||
|
||||
@classmethod
|
||||
async def validate_payload(cls, webhook, request):
|
||||
return {}, "test_event"
|
||||
|
||||
async def _register_webhook(self, *args, **kwargs):
|
||||
return "mock_webhook_id", {}
|
||||
|
||||
async def _deregister_webhook(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class TestWebhookPatching:
|
||||
"""Test webhook manager patching functionality."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear registry."""
|
||||
AutoRegistry.clear()
|
||||
|
||||
def test_webhook_manager_patching(self):
|
||||
"""Test that webhook managers are correctly patched."""
|
||||
|
||||
# Mock the original load_webhook_managers function
|
||||
def mock_load_webhook_managers():
|
||||
return {
|
||||
"existing_webhook": Mock(spec=BaseWebhooksManager),
|
||||
}
|
||||
|
||||
# Register a provider with webhooks
|
||||
(
|
||||
ProviderBuilder("webhook_provider")
|
||||
.with_webhook_manager(MockWebhookManager)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Mock the webhooks module
|
||||
mock_webhooks_module = MagicMock()
|
||||
mock_webhooks_module.load_webhook_managers = mock_load_webhook_managers
|
||||
|
||||
with patch.dict(
|
||||
"sys.modules", {"backend.integrations.webhooks": mock_webhooks_module}
|
||||
):
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Call the patched function
|
||||
result = mock_webhooks_module.load_webhook_managers()
|
||||
|
||||
# Original webhook should still exist
|
||||
assert "existing_webhook" in result
|
||||
|
||||
# New webhook should be added
|
||||
assert "webhook_provider" in result
|
||||
assert result["webhook_provider"] == MockWebhookManager
|
||||
|
||||
def test_webhook_patching_no_original_function(self):
|
||||
"""Test webhook patching when load_webhook_managers doesn't exist."""
|
||||
# Mock webhooks module without load_webhook_managers
|
||||
mock_webhooks_module = MagicMock(spec=[])
|
||||
|
||||
# Register a provider
|
||||
(
|
||||
ProviderBuilder("test_provider")
|
||||
.with_webhook_manager(MockWebhookManager)
|
||||
.build()
|
||||
)
|
||||
|
||||
with patch.dict(
|
||||
"sys.modules", {"backend.integrations.webhooks": mock_webhooks_module}
|
||||
):
|
||||
# Should not raise an error
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Function should not be added if it didn't exist
|
||||
assert not hasattr(mock_webhooks_module, "load_webhook_managers")
|
||||
|
||||
|
||||
class TestPatchingIntegration:
|
||||
"""Test the complete patching integration flow."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear registry."""
|
||||
AutoRegistry.clear()
|
||||
|
||||
def test_complete_provider_registration_and_patching(self):
|
||||
"""Test the complete flow from provider registration to patching."""
|
||||
# Mock webhooks module
|
||||
mock_webhooks = MagicMock()
|
||||
mock_webhooks.load_webhook_managers = lambda: {"original": Mock()}
|
||||
|
||||
# Create a fully featured provider
|
||||
(
|
||||
ProviderBuilder("complete_provider")
|
||||
.with_api_key("COMPLETE_KEY", "Complete API Key")
|
||||
.with_oauth(MockOAuthHandler, scopes=["read", "write"])
|
||||
.with_webhook_manager(MockWebhookManager)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Apply patches
|
||||
with patch.dict(
|
||||
"sys.modules",
|
||||
{
|
||||
"backend.integrations.webhooks": mock_webhooks,
|
||||
},
|
||||
):
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Verify webhook patching
|
||||
webhook_result = mock_webhooks.load_webhook_managers()
|
||||
assert "complete_provider" in webhook_result
|
||||
assert webhook_result["complete_provider"] == MockWebhookManager
|
||||
assert "original" in webhook_result # Original preserved
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
482
autogpt_platform/backend/test/sdk/test_sdk_registry.py
Normal file
@@ -0,0 +1,482 @@
|
||||
"""
|
||||
Tests for the SDK auto-registration system via AutoRegistry.
|
||||
|
||||
This test suite verifies:
|
||||
1. Provider registration and retrieval
|
||||
2. OAuth handler registration via patches
|
||||
3. Webhook manager registration via patches
|
||||
4. Credential registration and management
|
||||
5. Block configuration association
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
AutoRegistry,
|
||||
BaseOAuthHandler,
|
||||
BaseWebhooksManager,
|
||||
Block,
|
||||
BlockConfiguration,
|
||||
Provider,
|
||||
ProviderBuilder,
|
||||
)
|
||||
|
||||
|
||||
class TestAutoRegistry:
|
||||
"""Test the AutoRegistry functionality."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear registry before each test."""
|
||||
AutoRegistry.clear()
|
||||
|
||||
def test_provider_registration(self):
|
||||
"""Test that providers can be registered and retrieved."""
|
||||
# Create a test provider
|
||||
provider = Provider(
|
||||
name="test_provider",
|
||||
oauth_handler=None,
|
||||
webhook_manager=None,
|
||||
default_credentials=[],
|
||||
base_costs=[],
|
||||
supported_auth_types={"api_key"},
|
||||
)
|
||||
|
||||
# Register it
|
||||
AutoRegistry.register_provider(provider)
|
||||
|
||||
# Verify it's registered
|
||||
assert "test_provider" in AutoRegistry._providers
|
||||
assert AutoRegistry.get_provider("test_provider") == provider
|
||||
|
||||
def test_provider_with_oauth(self):
|
||||
"""Test provider registration with OAuth handler."""
|
||||
|
||||
# Create a mock OAuth handler
|
||||
class TestOAuthHandler(BaseOAuthHandler):
|
||||
PROVIDER_NAME = ProviderName.GITHUB
|
||||
|
||||
from backend.sdk.provider import OAuthConfig
|
||||
|
||||
provider = Provider(
|
||||
name="oauth_provider",
|
||||
oauth_config=OAuthConfig(oauth_handler=TestOAuthHandler),
|
||||
webhook_manager=None,
|
||||
default_credentials=[],
|
||||
base_costs=[],
|
||||
supported_auth_types={"oauth2"},
|
||||
)
|
||||
|
||||
AutoRegistry.register_provider(provider)
|
||||
|
||||
# Verify OAuth handler is registered
|
||||
assert "oauth_provider" in AutoRegistry._oauth_handlers
|
||||
assert AutoRegistry._oauth_handlers["oauth_provider"] == TestOAuthHandler
|
||||
|
||||
def test_provider_with_webhook_manager(self):
|
||||
"""Test provider registration with webhook manager."""
|
||||
|
||||
# Create a mock webhook manager
|
||||
class TestWebhookManager(BaseWebhooksManager):
|
||||
PROVIDER_NAME = ProviderName.GITHUB
|
||||
|
||||
provider = Provider(
|
||||
name="webhook_provider",
|
||||
oauth_handler=None,
|
||||
webhook_manager=TestWebhookManager,
|
||||
default_credentials=[],
|
||||
base_costs=[],
|
||||
supported_auth_types={"api_key"},
|
||||
)
|
||||
|
||||
AutoRegistry.register_provider(provider)
|
||||
|
||||
# Verify webhook manager is registered
|
||||
assert "webhook_provider" in AutoRegistry._webhook_managers
|
||||
assert AutoRegistry._webhook_managers["webhook_provider"] == TestWebhookManager
|
||||
|
||||
def test_default_credentials_registration(self):
|
||||
"""Test that default credentials are registered."""
|
||||
# Create test credentials
|
||||
from backend.sdk import SecretStr
|
||||
|
||||
cred1 = APIKeyCredentials(
|
||||
id="test-cred-1",
|
||||
provider="test_provider",
|
||||
api_key=SecretStr("test-key-1"),
|
||||
title="Test Credential 1",
|
||||
)
|
||||
cred2 = APIKeyCredentials(
|
||||
id="test-cred-2",
|
||||
provider="test_provider",
|
||||
api_key=SecretStr("test-key-2"),
|
||||
title="Test Credential 2",
|
||||
)
|
||||
|
||||
provider = Provider(
|
||||
name="test_provider",
|
||||
oauth_handler=None,
|
||||
webhook_manager=None,
|
||||
default_credentials=[cred1, cred2],
|
||||
base_costs=[],
|
||||
supported_auth_types={"api_key"},
|
||||
)
|
||||
|
||||
AutoRegistry.register_provider(provider)
|
||||
|
||||
# Verify credentials are registered
|
||||
all_creds = AutoRegistry.get_all_credentials()
|
||||
assert cred1 in all_creds
|
||||
assert cred2 in all_creds
|
||||
|
||||
def test_api_key_registration(self):
|
||||
"""Test API key environment variable registration."""
|
||||
import os
|
||||
|
||||
# Set up a test environment variable
|
||||
os.environ["TEST_API_KEY"] = "test-api-key-value"
|
||||
|
||||
try:
|
||||
AutoRegistry.register_api_key("test_provider", "TEST_API_KEY")
|
||||
|
||||
# Verify the mapping is stored
|
||||
assert AutoRegistry._api_key_mappings["test_provider"] == "TEST_API_KEY"
|
||||
|
||||
# Verify a credential was created
|
||||
all_creds = AutoRegistry.get_all_credentials()
|
||||
test_cred = next(
|
||||
(c for c in all_creds if c.id == "test_provider-default"), None
|
||||
)
|
||||
assert test_cred is not None
|
||||
assert test_cred.provider == "test_provider"
|
||||
assert test_cred.api_key.get_secret_value() == "test-api-key-value" # type: ignore
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
del os.environ["TEST_API_KEY"]
|
||||
|
||||
def test_get_oauth_handlers(self):
|
||||
"""Test retrieving all OAuth handlers."""
|
||||
|
||||
# Register multiple providers with OAuth
|
||||
class TestOAuth1(BaseOAuthHandler):
|
||||
PROVIDER_NAME = ProviderName.GITHUB
|
||||
|
||||
class TestOAuth2(BaseOAuthHandler):
|
||||
PROVIDER_NAME = ProviderName.GOOGLE
|
||||
|
||||
from backend.sdk.provider import OAuthConfig
|
||||
|
||||
provider1 = Provider(
|
||||
name="provider1",
|
||||
oauth_config=OAuthConfig(oauth_handler=TestOAuth1),
|
||||
webhook_manager=None,
|
||||
default_credentials=[],
|
||||
base_costs=[],
|
||||
supported_auth_types={"oauth2"},
|
||||
)
|
||||
|
||||
provider2 = Provider(
|
||||
name="provider2",
|
||||
oauth_config=OAuthConfig(oauth_handler=TestOAuth2),
|
||||
webhook_manager=None,
|
||||
default_credentials=[],
|
||||
base_costs=[],
|
||||
supported_auth_types={"oauth2"},
|
||||
)
|
||||
|
||||
AutoRegistry.register_provider(provider1)
|
||||
AutoRegistry.register_provider(provider2)
|
||||
|
||||
handlers = AutoRegistry.get_oauth_handlers()
|
||||
assert "provider1" in handlers
|
||||
assert "provider2" in handlers
|
||||
assert handlers["provider1"] == TestOAuth1
|
||||
assert handlers["provider2"] == TestOAuth2
|
||||
|
||||
def test_block_configuration_registration(self):
|
||||
"""Test registering block configuration."""
|
||||
|
||||
# Create a test block class
|
||||
class TestBlock(Block):
|
||||
pass
|
||||
|
||||
config = BlockConfiguration(
|
||||
provider="test_provider",
|
||||
costs=[],
|
||||
default_credentials=[],
|
||||
webhook_manager=None,
|
||||
oauth_handler=None,
|
||||
)
|
||||
|
||||
AutoRegistry.register_block_configuration(TestBlock, config)
|
||||
|
||||
# Verify it's registered
|
||||
assert TestBlock in AutoRegistry._block_configurations
|
||||
assert AutoRegistry._block_configurations[TestBlock] == config
|
||||
|
||||
def test_clear_registry(self):
|
||||
"""Test clearing all registrations."""
|
||||
# Add some registrations
|
||||
provider = Provider(
|
||||
name="test_provider",
|
||||
oauth_handler=None,
|
||||
webhook_manager=None,
|
||||
default_credentials=[],
|
||||
base_costs=[],
|
||||
supported_auth_types={"api_key"},
|
||||
)
|
||||
AutoRegistry.register_provider(provider)
|
||||
AutoRegistry.register_api_key("test", "TEST_KEY")
|
||||
|
||||
# Clear everything
|
||||
AutoRegistry.clear()
|
||||
|
||||
# Verify everything is cleared
|
||||
assert len(AutoRegistry._providers) == 0
|
||||
assert len(AutoRegistry._default_credentials) == 0
|
||||
assert len(AutoRegistry._oauth_handlers) == 0
|
||||
assert len(AutoRegistry._webhook_managers) == 0
|
||||
assert len(AutoRegistry._block_configurations) == 0
|
||||
assert len(AutoRegistry._api_key_mappings) == 0
|
||||
|
||||
|
||||
class TestAutoRegistryPatching:
|
||||
"""Test the integration patching functionality."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear registry before each test."""
|
||||
AutoRegistry.clear()
|
||||
|
||||
@patch("backend.integrations.webhooks.load_webhook_managers")
|
||||
def test_webhook_manager_patching(self, mock_load_managers):
|
||||
"""Test that webhook managers are patched into the system."""
|
||||
# Set up the mock to return an empty dict
|
||||
mock_load_managers.return_value = {}
|
||||
|
||||
# Create a test webhook manager
|
||||
class TestWebhookManager(BaseWebhooksManager):
|
||||
PROVIDER_NAME = ProviderName.GITHUB
|
||||
|
||||
# Register a provider with webhooks
|
||||
provider = Provider(
|
||||
name="webhook_provider",
|
||||
oauth_handler=None,
|
||||
webhook_manager=TestWebhookManager,
|
||||
default_credentials=[],
|
||||
base_costs=[],
|
||||
supported_auth_types={"api_key"},
|
||||
)
|
||||
|
||||
AutoRegistry.register_provider(provider)
|
||||
|
||||
# Mock the webhooks module
|
||||
mock_webhooks = MagicMock()
|
||||
mock_webhooks.load_webhook_managers = mock_load_managers
|
||||
|
||||
with patch.dict(
|
||||
"sys.modules", {"backend.integrations.webhooks": mock_webhooks}
|
||||
):
|
||||
# Apply patches
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Call the patched function
|
||||
result = mock_webhooks.load_webhook_managers()
|
||||
|
||||
# Verify our webhook manager is included
|
||||
assert "webhook_provider" in result
|
||||
assert result["webhook_provider"] == TestWebhookManager
|
||||
|
||||
|
||||
class TestProviderBuilder:
|
||||
"""Test the ProviderBuilder fluent API."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear registry before each test."""
|
||||
AutoRegistry.clear()
|
||||
|
||||
def test_basic_provider_builder(self):
|
||||
"""Test building a basic provider."""
|
||||
provider = (
|
||||
ProviderBuilder("test_provider")
|
||||
.with_api_key("TEST_API_KEY", "Test API Key")
|
||||
.build()
|
||||
)
|
||||
|
||||
assert provider.name == "test_provider"
|
||||
assert "api_key" in provider.supported_auth_types
|
||||
assert AutoRegistry.get_provider("test_provider") == provider
|
||||
|
||||
def test_provider_builder_with_oauth(self):
|
||||
"""Test building a provider with OAuth."""
|
||||
|
||||
class TestOAuth(BaseOAuthHandler):
|
||||
PROVIDER_NAME = ProviderName.GITHUB
|
||||
|
||||
provider = (
|
||||
ProviderBuilder("oauth_test")
|
||||
.with_oauth(TestOAuth, scopes=["read", "write"])
|
||||
.build()
|
||||
)
|
||||
|
||||
assert provider.oauth_config is not None
|
||||
assert provider.oauth_config.oauth_handler == TestOAuth
|
||||
assert "oauth2" in provider.supported_auth_types
|
||||
|
||||
def test_provider_builder_with_webhook(self):
|
||||
"""Test building a provider with webhook manager."""
|
||||
|
||||
class TestWebhook(BaseWebhooksManager):
|
||||
PROVIDER_NAME = ProviderName.GITHUB
|
||||
|
||||
provider = (
|
||||
ProviderBuilder("webhook_test").with_webhook_manager(TestWebhook).build()
|
||||
)
|
||||
|
||||
assert provider.webhook_manager == TestWebhook
|
||||
|
||||
def test_provider_builder_with_base_cost(self):
|
||||
"""Test building a provider with base costs."""
|
||||
from backend.data.cost import BlockCostType
|
||||
|
||||
provider = (
|
||||
ProviderBuilder("cost_test")
|
||||
.with_base_cost(10, BlockCostType.RUN)
|
||||
.with_base_cost(5, BlockCostType.BYTE)
|
||||
.build()
|
||||
)
|
||||
|
||||
assert len(provider.base_costs) == 2
|
||||
assert provider.base_costs[0].cost_amount == 10
|
||||
assert provider.base_costs[0].cost_type == BlockCostType.RUN
|
||||
assert provider.base_costs[1].cost_amount == 5
|
||||
assert provider.base_costs[1].cost_type == BlockCostType.BYTE
|
||||
|
||||
def test_provider_builder_with_api_client(self):
|
||||
"""Test building a provider with API client factory."""
|
||||
|
||||
def mock_client_factory():
|
||||
return Mock()
|
||||
|
||||
provider = (
|
||||
ProviderBuilder("client_test").with_api_client(mock_client_factory).build()
|
||||
)
|
||||
|
||||
assert provider._api_client_factory == mock_client_factory
|
||||
|
||||
def test_provider_builder_with_error_handler(self):
|
||||
"""Test building a provider with error handler."""
|
||||
|
||||
def mock_error_handler(exc: Exception) -> str:
|
||||
return f"Error: {str(exc)}"
|
||||
|
||||
provider = (
|
||||
ProviderBuilder("error_test").with_error_handler(mock_error_handler).build()
|
||||
)
|
||||
|
||||
assert provider._error_handler == mock_error_handler
|
||||
|
||||
def test_provider_builder_complete_example(self):
|
||||
"""Test building a complete provider with all features."""
|
||||
from backend.data.cost import BlockCostType
|
||||
|
||||
class TestOAuth(BaseOAuthHandler):
|
||||
PROVIDER_NAME = ProviderName.GITHUB
|
||||
|
||||
class TestWebhook(BaseWebhooksManager):
|
||||
PROVIDER_NAME = ProviderName.GITHUB
|
||||
|
||||
def client_factory():
|
||||
return Mock()
|
||||
|
||||
def error_handler(exc):
|
||||
return str(exc)
|
||||
|
||||
provider = (
|
||||
ProviderBuilder("complete_test")
|
||||
.with_api_key("COMPLETE_API_KEY", "Complete API Key")
|
||||
.with_oauth(TestOAuth, scopes=["read"])
|
||||
.with_webhook_manager(TestWebhook)
|
||||
.with_base_cost(100, BlockCostType.RUN)
|
||||
.with_api_client(client_factory)
|
||||
.with_error_handler(error_handler)
|
||||
.with_config(custom_setting="value")
|
||||
.build()
|
||||
)
|
||||
|
||||
# Verify all settings
|
||||
assert provider.name == "complete_test"
|
||||
assert "api_key" in provider.supported_auth_types
|
||||
assert "oauth2" in provider.supported_auth_types
|
||||
assert provider.oauth_config is not None
|
||||
assert provider.oauth_config.oauth_handler == TestOAuth
|
||||
assert provider.webhook_manager == TestWebhook
|
||||
assert len(provider.base_costs) == 1
|
||||
assert provider._api_client_factory == client_factory
|
||||
assert provider._error_handler == error_handler
|
||||
assert provider.get_config("custom_setting") == "value" # from with_config
|
||||
|
||||
# Verify it's registered
|
||||
assert AutoRegistry.get_provider("complete_test") == provider
|
||||
assert "complete_test" in AutoRegistry._oauth_handlers
|
||||
assert "complete_test" in AutoRegistry._webhook_managers
|
||||
|
||||
|
||||
class TestSDKImports:
|
||||
"""Test that all expected exports are available from the SDK."""
|
||||
|
||||
def test_core_block_imports(self):
|
||||
"""Test core block system imports."""
|
||||
from backend.sdk import Block, BlockCategory
|
||||
|
||||
# Just verify they're importable
|
||||
assert Block is not None
|
||||
assert BlockCategory is not None
|
||||
|
||||
def test_schema_imports(self):
|
||||
"""Test schema and model imports."""
|
||||
from backend.sdk import APIKeyCredentials, SchemaField
|
||||
|
||||
assert SchemaField is not None
|
||||
assert APIKeyCredentials is not None
|
||||
|
||||
def test_type_alias_imports(self):
|
||||
"""Test type alias imports are removed."""
|
||||
# Type aliases have been removed from SDK
|
||||
# Users should import from typing or use built-in types directly
|
||||
pass
|
||||
|
||||
def test_cost_system_imports(self):
|
||||
"""Test cost system imports."""
|
||||
from backend.sdk import BlockCost, BlockCostType
|
||||
|
||||
assert BlockCost is not None
|
||||
assert BlockCostType is not None
|
||||
|
||||
def test_utility_imports(self):
|
||||
"""Test utility imports."""
|
||||
from backend.sdk import BaseModel, Requests, json
|
||||
|
||||
assert json is not None
|
||||
assert BaseModel is not None
|
||||
assert Requests is not None
|
||||
|
||||
def test_integration_imports(self):
|
||||
"""Test integration imports."""
|
||||
from backend.sdk import ProviderName
|
||||
|
||||
assert ProviderName is not None
|
||||
|
||||
def test_sdk_component_imports(self):
|
||||
"""Test SDK-specific component imports."""
|
||||
from backend.sdk import AutoRegistry, ProviderBuilder
|
||||
|
||||
assert AutoRegistry is not None
|
||||
assert ProviderBuilder is not None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
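
Read together, these registry tests describe the intended one-liner for block packages: declare the provider once with `ProviderBuilder` at import time and let `AutoRegistry` hold the result. A minimal usage sketch, with a made-up provider name and environment variable purely for illustration (the builder methods themselves are the ones exercised above):

```python
from backend.data.cost import BlockCostType
from backend.sdk import AutoRegistry, ProviderBuilder

# "my_service" and MY_SERVICE_API_KEY are hypothetical placeholders.
my_service = (
    ProviderBuilder("my_service")
    .with_api_key("MY_SERVICE_API_KEY", "My Service API Key")
    .with_base_cost(1, BlockCostType.RUN)
    .build()
)

# The provider is now discoverable exactly as the tests assert:
assert AutoRegistry.get_provider("my_service") == my_service
```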
autogpt_platform/backend/test/sdk/test_sdk_webhooks.py (new file, 506 lines)
@@ -0,0 +1,506 @@
|
||||
"""
|
||||
Tests for SDK webhook functionality.
|
||||
|
||||
This test suite verifies webhook blocks and webhook manager integration.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
AutoRegistry,
|
||||
BaseModel,
|
||||
BaseWebhooksManager,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockWebhookConfig,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
Field,
|
||||
ProviderBuilder,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
)
|
||||
|
||||
|
||||
class TestWebhookTypes(str, Enum):
|
||||
"""Test webhook event types."""
|
||||
|
||||
CREATED = "created"
|
||||
UPDATED = "updated"
|
||||
DELETED = "deleted"
|
||||
|
||||
|
||||
class TestWebhooksManager(BaseWebhooksManager):
|
||||
"""Test webhook manager implementation."""
|
||||
|
||||
PROVIDER_NAME = ProviderName.GITHUB # Reuse for testing
|
||||
|
||||
class WebhookType(str, Enum):
|
||||
TEST = "test"
|
||||
|
||||
@classmethod
|
||||
async def validate_payload(cls, webhook, request):
|
||||
"""Validate incoming webhook payload."""
|
||||
# Mock implementation
|
||||
payload = {"test": "data"}
|
||||
event_type = "test_event"
|
||||
return payload, event_type
|
||||
|
||||
async def _register_webhook(
|
||||
self,
|
||||
credentials,
|
||||
webhook_type: str,
|
||||
resource: str,
|
||||
events: list[str],
|
||||
ingress_url: str,
|
||||
secret: str,
|
||||
) -> tuple[str, dict]:
|
||||
"""Register webhook with external service."""
|
||||
# Mock implementation
|
||||
webhook_id = f"test_webhook_{resource}"
|
||||
config = {
|
||||
"webhook_type": webhook_type,
|
||||
"resource": resource,
|
||||
"events": events,
|
||||
"url": ingress_url,
|
||||
}
|
||||
return webhook_id, config
|
||||
|
||||
async def _deregister_webhook(self, webhook, credentials) -> None:
|
||||
"""Deregister webhook from external service."""
|
||||
# Mock implementation
|
||||
pass
|
||||
|
||||
|
||||
class TestWebhookBlock(Block):
|
||||
"""Test webhook block implementation."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="test_webhooks",
|
||||
supported_credential_types={"api_key"},
|
||||
description="Webhook service credentials",
|
||||
)
|
||||
webhook_url: str = SchemaField(
|
||||
description="URL to receive webhooks",
|
||||
)
|
||||
resource_id: str = SchemaField(
|
||||
description="Resource to monitor",
|
||||
)
|
||||
events: list[TestWebhookTypes] = SchemaField(
|
||||
description="Events to listen for",
|
||||
default=[TestWebhookTypes.CREATED],
|
||||
)
|
||||
payload: dict = SchemaField(
|
||||
description="Webhook payload",
|
||||
default={},
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
webhook_id: str = SchemaField(description="Registered webhook ID")
|
||||
is_active: bool = SchemaField(description="Webhook is active")
|
||||
event_count: int = SchemaField(description="Number of events configured")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="test-webhook-block",
|
||||
description="Test webhook block",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=TestWebhookBlock.Input,
|
||||
output_schema=TestWebhookBlock.Output,
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider="test_webhooks", # type: ignore
|
||||
webhook_type="test",
|
||||
resource_format="{resource_id}",
|
||||
),
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Simulate webhook registration
|
||||
webhook_id = f"webhook_{input_data.resource_id}"
|
||||
|
||||
yield "webhook_id", webhook_id
|
||||
yield "is_active", True
|
||||
yield "event_count", len(input_data.events)
|
||||
|
||||
|
||||
class TestWebhookBlockCreation:
|
||||
"""Test creating webhook blocks with the SDK."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test environment."""
|
||||
AutoRegistry.clear()
|
||||
|
||||
# Register a provider with webhook support
|
||||
self.provider = (
|
||||
ProviderBuilder("test_webhooks")
|
||||
.with_api_key("TEST_WEBHOOK_KEY", "Test Webhook API Key")
|
||||
.with_webhook_manager(TestWebhooksManager)
|
||||
.build()
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_webhook_block(self):
|
||||
"""Test creating a basic webhook block."""
|
||||
block = TestWebhookBlock()
|
||||
|
||||
# Verify block configuration
|
||||
assert block.webhook_config is not None
|
||||
assert block.webhook_config.provider == "test_webhooks"
|
||||
assert block.webhook_config.webhook_type == "test"
|
||||
assert "{resource_id}" in block.webhook_config.resource_format # type: ignore
|
||||
|
||||
# Test block execution
|
||||
test_creds = APIKeyCredentials(
|
||||
id="test-webhook-creds",
|
||||
provider="test_webhooks",
|
||||
api_key=SecretStr("test-key"),
|
||||
title="Test Webhook Key",
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
TestWebhookBlock.Input(
|
||||
credentials={ # type: ignore
|
||||
"provider": "test_webhooks",
|
||||
"id": "test-webhook-creds",
|
||||
"type": "api_key",
|
||||
},
|
||||
webhook_url="https://example.com/webhook",
|
||||
resource_id="resource_123",
|
||||
events=[TestWebhookTypes.CREATED, TestWebhookTypes.UPDATED],
|
||||
),
|
||||
credentials=test_creds,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["webhook_id"] == "webhook_resource_123"
|
||||
assert outputs["is_active"] is True
|
||||
assert outputs["event_count"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_block_with_filters(self):
|
||||
"""Test webhook block with event filters."""
|
||||
|
||||
class EventFilterModel(BaseModel):
|
||||
include_system: bool = Field(default=False)
|
||||
severity_levels: list[str] = Field(
|
||||
default_factory=lambda: ["info", "warning"]
|
||||
)
|
||||
|
||||
class FilteredWebhookBlock(Block):
|
||||
"""Webhook block with filtering."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="test_webhooks",
|
||||
supported_credential_types={"api_key"},
|
||||
)
|
||||
resource: str = SchemaField(description="Resource to monitor")
|
||||
filters: EventFilterModel = SchemaField(
|
||||
description="Event filters",
|
||||
default_factory=EventFilterModel,
|
||||
)
|
||||
payload: dict = SchemaField(
|
||||
description="Webhook payload",
|
||||
default={},
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
webhook_active: bool = SchemaField(description="Webhook active")
|
||||
filter_summary: str = SchemaField(description="Active filters")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="filtered-webhook-block",
|
||||
description="Webhook with filters",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=FilteredWebhookBlock.Input,
|
||||
output_schema=FilteredWebhookBlock.Output,
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider="test_webhooks", # type: ignore
|
||||
webhook_type="filtered",
|
||||
resource_format="{resource}",
|
||||
),
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
filters = input_data.filters
|
||||
filter_parts = []
|
||||
|
||||
if filters.include_system:
|
||||
filter_parts.append("system events")
|
||||
|
||||
filter_parts.append(f"{len(filters.severity_levels)} severity levels")
|
||||
|
||||
yield "webhook_active", True
|
||||
yield "filter_summary", ", ".join(filter_parts)
|
||||
|
||||
# Test the block
|
||||
block = FilteredWebhookBlock()
|
||||
|
||||
test_creds = APIKeyCredentials(
|
||||
id="test-creds",
|
||||
provider="test_webhooks",
|
||||
api_key=SecretStr("key"),
|
||||
title="Test Key",
|
||||
)
|
||||
|
||||
# Test with default filters
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
FilteredWebhookBlock.Input(
|
||||
credentials={ # type: ignore
|
||||
"provider": "test_webhooks",
|
||||
"id": "test-creds",
|
||||
"type": "api_key",
|
||||
},
|
||||
resource="test_resource",
|
||||
),
|
||||
credentials=test_creds,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["webhook_active"] is True
|
||||
assert "2 severity levels" in outputs["filter_summary"]
|
||||
|
||||
# Test with custom filters
|
||||
custom_filters = EventFilterModel(
|
||||
include_system=True,
|
||||
severity_levels=["error", "critical"],
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
FilteredWebhookBlock.Input(
|
||||
credentials={ # type: ignore
|
||||
"provider": "test_webhooks",
|
||||
"id": "test-creds",
|
||||
"type": "api_key",
|
||||
},
|
||||
resource="test_resource",
|
||||
filters=custom_filters,
|
||||
),
|
||||
credentials=test_creds,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
assert "system events" in outputs["filter_summary"]
|
||||
assert "2 severity levels" in outputs["filter_summary"]
|
||||
|
||||
|
||||
class TestWebhookManagerIntegration:
|
||||
"""Test webhook manager integration with AutoRegistry."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear registry."""
|
||||
AutoRegistry.clear()
|
||||
|
||||
def test_webhook_manager_registration(self):
|
||||
"""Test that webhook managers are properly registered."""
|
||||
|
||||
# Create multiple webhook managers
|
||||
class WebhookManager1(BaseWebhooksManager):
|
||||
PROVIDER_NAME = ProviderName.GITHUB
|
||||
|
||||
class WebhookManager2(BaseWebhooksManager):
|
||||
PROVIDER_NAME = ProviderName.GOOGLE
|
||||
|
||||
# Register providers with webhook managers
|
||||
(
|
||||
ProviderBuilder("webhook_service_1")
|
||||
.with_webhook_manager(WebhookManager1)
|
||||
.build()
|
||||
)
|
||||
|
||||
(
|
||||
ProviderBuilder("webhook_service_2")
|
||||
.with_webhook_manager(WebhookManager2)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Verify registration
|
||||
managers = AutoRegistry.get_webhook_managers()
|
||||
assert "webhook_service_1" in managers
|
||||
assert "webhook_service_2" in managers
|
||||
assert managers["webhook_service_1"] == WebhookManager1
|
||||
assert managers["webhook_service_2"] == WebhookManager2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_block_with_provider_manager(self):
|
||||
"""Test webhook block using a provider's webhook manager."""
|
||||
# Register provider with webhook manager
|
||||
(
|
||||
ProviderBuilder("integrated_webhooks")
|
||||
.with_api_key("INTEGRATED_KEY", "Integrated Webhook Key")
|
||||
.with_webhook_manager(TestWebhooksManager)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Create a block that uses this provider
|
||||
class IntegratedWebhookBlock(Block):
|
||||
"""Block using integrated webhook manager."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="integrated_webhooks",
|
||||
supported_credential_types={"api_key"},
|
||||
)
|
||||
target: str = SchemaField(description="Webhook target")
|
||||
payload: dict = SchemaField(
|
||||
description="Webhook payload",
|
||||
default={},
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
status: str = SchemaField(description="Webhook status")
|
||||
manager_type: str = SchemaField(description="Manager type used")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="integrated-webhook-block",
|
||||
description="Uses integrated webhook manager",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=IntegratedWebhookBlock.Input,
|
||||
output_schema=IntegratedWebhookBlock.Output,
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider="integrated_webhooks", # type: ignore
|
||||
webhook_type=TestWebhooksManager.WebhookType.TEST,
|
||||
resource_format="{target}",
|
||||
),
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
# Get the webhook manager for this provider
|
||||
managers = AutoRegistry.get_webhook_managers()
|
||||
manager_class = managers.get("integrated_webhooks")
|
||||
|
||||
yield "status", "configured"
|
||||
yield "manager_type", (
|
||||
manager_class.__name__ if manager_class else "none"
|
||||
)
|
||||
|
||||
# Test the block
|
||||
block = IntegratedWebhookBlock()
|
||||
|
||||
test_creds = APIKeyCredentials(
|
||||
id="integrated-creds",
|
||||
provider="integrated_webhooks",
|
||||
api_key=SecretStr("key"),
|
||||
title="Integrated Key",
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
IntegratedWebhookBlock.Input(
|
||||
credentials={ # type: ignore
|
||||
"provider": "integrated_webhooks",
|
||||
"id": "integrated-creds",
|
||||
"type": "api_key",
|
||||
},
|
||||
target="test_target",
|
||||
),
|
||||
credentials=test_creds,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["status"] == "configured"
|
||||
assert outputs["manager_type"] == "TestWebhooksManager"
|
||||
|
||||
|
||||
class TestWebhookEventHandling:
|
||||
"""Test webhook event handling in blocks."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_event_processing_block(self):
|
||||
"""Test a block that processes webhook events."""
|
||||
|
||||
class WebhookEventBlock(Block):
|
||||
"""Block that processes webhook events."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
event_type: str = SchemaField(description="Type of webhook event")
|
||||
payload: dict = SchemaField(description="Webhook payload")
|
||||
verify_signature: bool = SchemaField(
|
||||
description="Whether to verify webhook signature",
|
||||
default=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
processed: bool = SchemaField(description="Event was processed")
|
||||
event_summary: str = SchemaField(description="Summary of event")
|
||||
action_required: bool = SchemaField(description="Action required")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="webhook-event-processor",
|
||||
description="Processes incoming webhook events",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=WebhookEventBlock.Input,
|
||||
output_schema=WebhookEventBlock.Output,
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
# Process based on event type
|
||||
event_type = input_data.event_type
|
||||
payload = input_data.payload
|
||||
|
||||
if event_type == "created":
|
||||
summary = f"New item created: {payload.get('id', 'unknown')}"
|
||||
action_required = True
|
||||
elif event_type == "updated":
|
||||
summary = f"Item updated: {payload.get('id', 'unknown')}"
|
||||
action_required = False
|
||||
elif event_type == "deleted":
|
||||
summary = f"Item deleted: {payload.get('id', 'unknown')}"
|
||||
action_required = True
|
||||
else:
|
||||
summary = f"Unknown event: {event_type}"
|
||||
action_required = False
|
||||
|
||||
yield "processed", True
|
||||
yield "event_summary", summary
|
||||
yield "action_required", action_required
|
||||
|
||||
# Test the block with different events
|
||||
block = WebhookEventBlock()
|
||||
|
||||
# Test created event
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
WebhookEventBlock.Input(
|
||||
event_type="created",
|
||||
payload={"id": "123", "name": "Test Item"},
|
||||
)
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["processed"] is True
|
||||
assert "New item created: 123" in outputs["event_summary"]
|
||||
assert outputs["action_required"] is True
|
||||
|
||||
# Test updated event
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
WebhookEventBlock.Input(
|
||||
event_type="updated",
|
||||
payload={"id": "456", "changes": ["name", "status"]},
|
||||
)
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["processed"] is True
|
||||
assert "Item updated: 456" in outputs["event_summary"]
|
||||
assert outputs["action_required"] is False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
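
The webhook tests above rely on two pieces working together: a provider registers a `BaseWebhooksManager` subclass, and the platform (or a block) resolves it by provider name via `AutoRegistry.get_webhook_managers()`. A short sketch of that lookup-and-validate flow, reusing the `TestWebhooksManager` defined in the test file above with mock webhook/request objects; how the platform actually dispatches incoming requests may differ:

```python
import asyncio
from unittest.mock import Mock

from backend.sdk import AutoRegistry, ProviderBuilder

ProviderBuilder("test_webhooks").with_webhook_manager(TestWebhooksManager).build()

manager_cls = AutoRegistry.get_webhook_managers()["test_webhooks"]


async def handle_incoming(webhook, request):
    # validate_payload is the classmethod every webhook manager must implement.
    payload, event_type = await manager_cls.validate_payload(webhook, request)
    return payload, event_type


print(asyncio.run(handle_incoming(Mock(), Mock())))  # ({'test': 'data'}, 'test_event')
```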
test_data_creator.py
@@ -1,3 +1,21 @@
+"""
+Test Data Creator for AutoGPT Platform
+
+This script creates test data for the AutoGPT platform database.
+
+Image/Video URL Domains Used:
+- Images: picsum.photos (for all image URLs - avatars, store listing images, etc.)
+- Videos: youtube.com (for store listing video URLs)
+
+Add these domains to your Next.js config:
+```javascript
+// next.config.js
+images: {
+  domains: ['picsum.photos'],
+}
+```
+"""
+
 import asyncio
 import random
 from datetime import datetime
@@ -14,6 +32,7 @@ from prisma.types import (
     AnalyticsMetricsCreateInput,
     APIKeyCreateInput,
     CreditTransactionCreateInput,
+    IntegrationWebhookCreateInput,
     ProfileCreateInput,
     StoreListingReviewCreateInput,
     UserCreateInput,
@@ -53,10 +72,26 @@ MAX_REVIEWS_PER_VERSION = 5  # Total reviews depends on number of versions created
 
 
 def get_image():
-    url = faker.image_url()
-    while "placekitten.com" in url:
-        url = faker.image_url()
-    return url
+    """Generate a consistent image URL using picsum.photos service."""
+    width = random.choice([200, 300, 400, 500, 600, 800])
+    height = random.choice([200, 300, 400, 500, 600, 800])
+    # Use a random seed to get different images
+    seed = random.randint(1, 1000)
+    return f"https://picsum.photos/seed/{seed}/{width}/{height}"
+
+
+def get_video_url():
+    """Generate a consistent video URL using a placeholder service."""
+    # Using YouTube as a consistent source for video URLs
+    video_ids = [
+        "dQw4w9WgXcQ",  # Example video IDs
+        "9bZkp7q19f0",
+        "kJQP7kiw5Fk",
+        "RgKAFK5djSk",
+        "L_jWHffIx5E",
+    ]
+    video_id = random.choice(video_ids)
+    return f"https://www.youtube.com/watch?v={video_id}"
 
 
 async def main():
@@ -147,12 +182,27 @@
         )
         agent_presets.append(preset)
 
-    # Insert UserAgents
-    user_agents = []
-    print(f"Inserting {NUM_USERS * MAX_AGENTS_PER_USER} user agents")
+    # Insert Profiles first (before LibraryAgents)
+    profiles = []
+    print(f"Inserting {NUM_USERS} profiles")
+    for user in users:
+        profile = await db.profile.create(
+            data=ProfileCreateInput(
+                userId=user.id,
+                name=user.name or faker.name(),
+                username=faker.unique.user_name(),
+                description=faker.text(),
+                links=[faker.url() for _ in range(3)],
+                avatarUrl=get_image(),
+            )
+        )
+        profiles.append(profile)
+
+    # Insert LibraryAgents
+    library_agents = []
+    print("Inserting library agents")
     for user in users:
         num_agents = random.randint(MIN_AGENTS_PER_USER, MAX_AGENTS_PER_USER)
 
         # Get a shuffled list of graphs to ensure uniqueness per user
         available_graphs = agent_graphs.copy()
         random.shuffle(available_graphs)
@@ -162,18 +212,27 @@
 
         for i in range(num_agents):
             graph = available_graphs[i]  # Use unique graph for each library agent
-            user_agent = await db.libraryagent.create(
+
+            # Get creator profile for this graph's owner
+            creator_profile = next(
+                (p for p in profiles if p.userId == graph.userId), None
+            )
+
+            library_agent = await db.libraryagent.create(
                 data={
                     "userId": user.id,
                     "agentGraphId": graph.id,
                     "agentGraphVersion": graph.version,
+                    "creatorId": creator_profile.id if creator_profile else None,
+                    "imageUrl": get_image() if random.random() < 0.5 else None,
                     "useGraphIsActiveVersion": random.choice([True, False]),
                     "isFavorite": random.choice([True, False]),
                     "isCreatedByUser": random.choice([True, False]),
                     "isArchived": random.choice([True, False]),
                     "isDeleted": random.choice([True, False]),
                 }
             )
-            user_agents.append(user_agent)
+            library_agents.append(library_agent)
 
     # Insert AgentGraphExecutions
     agent_graph_executions = []
@@ -325,25 +384,9 @@
             )
         )
 
-    # Insert Profiles
-    profiles = []
-    print(f"Inserting {NUM_USERS} profiles")
-    for user in users:
-        profile = await db.profile.create(
-            data=ProfileCreateInput(
-                userId=user.id,
-                name=user.name or faker.name(),
-                username=faker.unique.user_name(),
-                description=faker.text(),
-                links=[faker.url() for _ in range(3)],
-                avatarUrl=get_image(),
-            )
-        )
-        profiles.append(profile)
-
     # Insert StoreListings
     store_listings = []
-    print(f"Inserting {NUM_USERS} store listings")
+    print("Inserting store listings")
     for graph in agent_graphs:
         user = random.choice(users)
         slug = faker.slug()
@@ -360,7 +403,7 @@
 
     # Insert StoreListingVersions
     store_listing_versions = []
-    print(f"Inserting {NUM_USERS} store listing versions")
+    print("Inserting store listing versions")
     for listing in store_listings:
         graph = [g for g in agent_graphs if g.id == listing.agentGraphId][0]
         version = await db.storelistingversion.create(
@@ -369,7 +412,7 @@
                 "agentGraphVersion": graph.version,
                 "name": graph.name or faker.sentence(nb_words=3),
                 "subHeading": faker.sentence(),
-                "videoUrl": faker.url(),
+                "videoUrl": get_video_url() if random.random() < 0.3 else None,
                 "imageUrls": [get_image() for _ in range(3)],
                 "description": faker.text(),
                 "categories": [faker.word() for _ in range(3)],
@@ -388,7 +431,7 @@
         store_listing_versions.append(version)
 
     # Insert StoreListingReviews
-    print(f"Inserting {NUM_USERS * MAX_REVIEWS_PER_VERSION} store listing reviews")
+    print("Inserting store listing reviews")
    for version in store_listing_versions:
         # Create a copy of users list and shuffle it to avoid duplicates
         available_reviewers = users.copy()
@@ -411,26 +454,92 @@
             )
         )
 
     # Update StoreListingVersions with submission status (StoreListingSubmissions table no longer exists)
     print(f"Updating {NUM_USERS} store listing versions with submission status")
     for version in store_listing_versions:
         reviewer = random.choice(users)
         status: prisma.enums.SubmissionStatus = random.choice(
             [
                 prisma.enums.SubmissionStatus.PENDING,
                 prisma.enums.SubmissionStatus.APPROVED,
                 prisma.enums.SubmissionStatus.REJECTED,
             ]
         )
         await db.storelistingversion.update(
             where={"id": version.id},
             data={
                 "submissionStatus": status,
                 "Reviewer": {"connect": {"id": reviewer.id}},
                 "reviewComments": faker.text(),
                 "reviewedAt": datetime.now(),
             },
         )
+
+    # Insert UserOnboarding for some users
+    print("Inserting user onboarding data")
+    for user in random.sample(
+        users, k=int(NUM_USERS * 0.7)
+    ):  # 70% of users have onboarding data
+        completed_steps = []
+        possible_steps = list(prisma.enums.OnboardingStep)
+        # Randomly complete some steps
+        if random.random() < 0.8:
+            num_steps = random.randint(1, len(possible_steps))
+            completed_steps = random.sample(possible_steps, k=num_steps)
+
+        try:
+            await db.useronboarding.create(
+                data={
+                    "userId": user.id,
+                    "completedSteps": completed_steps,
+                    "notificationDot": random.choice([True, False]),
+                    "notified": (
+                        random.sample(completed_steps, k=min(3, len(completed_steps)))
+                        if completed_steps
+                        else []
+                    ),
+                    "rewardedFor": (
+                        random.sample(completed_steps, k=min(2, len(completed_steps)))
+                        if completed_steps
+                        else []
+                    ),
+                    "usageReason": (
+                        random.choice(["personal", "business", "research", "learning"])
+                        if random.random() < 0.7
+                        else None
+                    ),
+                    "integrations": random.sample(
+                        ["github", "google", "discord", "slack"], k=random.randint(0, 2)
+                    ),
+                    "otherIntegrations": (
+                        faker.word() if random.random() < 0.2 else None
+                    ),
+                    "selectedStoreListingVersionId": (
+                        random.choice(store_listing_versions).id
+                        if store_listing_versions and random.random() < 0.5
+                        else None
+                    ),
+                    "agentInput": (
+                        Json({"test": "data"}) if random.random() < 0.3 else None
+                    ),
+                    "onboardingAgentExecutionId": (
+                        random.choice(agent_graph_executions).id
+                        if agent_graph_executions and random.random() < 0.3
+                        else None
+                    ),
+                    "agentRuns": random.randint(0, 10),
+                }
+            )
+        except Exception as e:
+            print(f"Error creating onboarding for user {user.id}: {e}")
+            # Try simpler version
+            await db.useronboarding.create(
+                data={
+                    "userId": user.id,
+                }
+            )
+
+    # Insert IntegrationWebhooks for some users
+    print("Inserting integration webhooks")
+    for user in random.sample(
+        users, k=int(NUM_USERS * 0.3)
+    ):  # 30% of users have webhooks
+        for _ in range(random.randint(1, 3)):
+            await db.integrationwebhook.create(
+                data=IntegrationWebhookCreateInput(
+                    userId=user.id,
+                    provider=random.choice(["github", "slack", "discord"]),
+                    credentialsId=str(faker.uuid4()),
+                    webhookType=random.choice(["repo", "channel", "server"]),
+                    resource=faker.slug(),
+                    events=[
+                        random.choice(["created", "updated", "deleted"])
+                        for _ in range(random.randint(1, 3))
+                    ],
+                    config=prisma.Json({"url": faker.url()}),
+                    secret=str(faker.sha256()),
+                    providerWebhookId=str(faker.uuid4()),
+                )
+            )
 
     # Insert APIKeys
     print(f"Inserting {NUM_USERS} api keys")
@@ -451,7 +560,12 @@
             )
         )
 
+    # Refresh materialized views
+    print("Refreshing materialized views...")
+    await db.execute_raw("SELECT refresh_store_materialized_views();")
+
     await db.disconnect()
     print("Test data creation completed successfully!")
 
 
 if __name__ == "__main__":
autogpt_platform/backend/test/test_data_updater.py (new executable file, 323 lines)
@@ -0,0 +1,323 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Data Updater for Store Materialized Views
|
||||
|
||||
This script updates existing test data to trigger changes in the materialized views:
|
||||
- mv_agent_run_counts: Updated by creating new AgentGraphExecution records
|
||||
- mv_review_stats: Updated by creating new StoreListingReview records
|
||||
|
||||
Run this after test_data_creator.py to test that materialized views update correctly.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import prisma.enums
|
||||
from faker import Faker
|
||||
from prisma import Json, Prisma
|
||||
|
||||
faker = Faker()
|
||||
|
||||
|
||||
async def main():
|
||||
db = Prisma()
|
||||
await db.connect()
|
||||
|
||||
print("Starting test data updates for materialized views...")
|
||||
print("=" * 60)
|
||||
|
||||
# Get existing data
|
||||
users = await db.user.find_many(take=50)
|
||||
agent_graphs = await db.agentgraph.find_many(where={"isActive": True}, take=50)
|
||||
store_listings = await db.storelisting.find_many(
|
||||
where={"hasApprovedVersion": True}, include={"Versions": True}, take=30
|
||||
)
|
||||
agent_nodes = await db.agentnode.find_many(take=100)
|
||||
|
||||
if not all([users, agent_graphs, store_listings]):
|
||||
print(
|
||||
"ERROR: Not enough test data found. Please run test_data_creator.py first."
|
||||
)
|
||||
await db.disconnect()
|
||||
return
|
||||
|
||||
print(
|
||||
f"Found {len(users)} users, {len(agent_graphs)} graphs, {len(store_listings)} store listings"
|
||||
)
|
||||
print()
|
||||
|
||||
# 1. Add new AgentGraphExecutions to update mv_agent_run_counts
|
||||
print("1. Adding new agent graph executions...")
|
||||
print("-" * 40)
|
||||
|
||||
new_executions_count = 0
|
||||
execution_data = []
|
||||
|
||||
for graph in random.sample(agent_graphs, min(20, len(agent_graphs))):
|
||||
# Add 5-15 new executions per selected graph
|
||||
num_new_executions = random.randint(5, 15)
|
||||
for _ in range(num_new_executions):
|
||||
user = random.choice(users)
|
||||
execution_data.append(
|
||||
{
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
"userId": user.id,
|
||||
"executionStatus": random.choice(
|
||||
[
|
||||
prisma.enums.AgentExecutionStatus.COMPLETED,
|
||||
prisma.enums.AgentExecutionStatus.FAILED,
|
||||
prisma.enums.AgentExecutionStatus.RUNNING,
|
||||
]
|
||||
),
|
||||
"startedAt": faker.date_time_between(
|
||||
start_date="-7d", end_date="now"
|
||||
),
|
||||
"stats": Json(
|
||||
{
|
||||
"duration": random.randint(100, 5000),
|
||||
"blocks_executed": random.randint(1, 10),
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
new_executions_count += 1
|
||||
|
||||
# Batch create executions
|
||||
await db.agentgraphexecution.create_many(data=execution_data)
|
||||
print(f"✓ Created {new_executions_count} new executions")
|
||||
|
||||
# Get the created executions for node executions
|
||||
recent_executions = await db.agentgraphexecution.find_many(
|
||||
take=new_executions_count, order={"createdAt": "desc"}
|
||||
)
|
||||
|
||||
# 2. Add corresponding AgentNodeExecutions
|
||||
print("\n2. Adding agent node executions...")
|
||||
print("-" * 40)
|
||||
|
||||
node_execution_data = []
|
||||
for execution in recent_executions:
|
||||
# Get nodes for this graph
|
||||
graph_nodes = [
|
||||
n for n in agent_nodes if n.agentGraphId == execution.agentGraphId
|
||||
]
|
||||
if graph_nodes:
|
||||
for node in random.sample(graph_nodes, min(3, len(graph_nodes))):
|
||||
node_execution_data.append(
|
||||
{
|
||||
"agentGraphExecutionId": execution.id,
|
||||
"agentNodeId": node.id,
|
||||
"executionStatus": execution.executionStatus,
|
||||
"addedTime": datetime.now(),
|
||||
"startedTime": datetime.now()
|
||||
- timedelta(minutes=random.randint(1, 10)),
|
||||
"endedTime": (
|
||||
datetime.now()
|
||||
if execution.executionStatus
|
||||
== prisma.enums.AgentExecutionStatus.COMPLETED
|
||||
else None
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
await db.agentnodeexecution.create_many(data=node_execution_data)
|
||||
print(f"✓ Created {len(node_execution_data)} node executions")
|
||||
|
||||
# 3. Add new StoreListingReviews to update mv_review_stats
|
||||
print("\n3. Adding new store listing reviews...")
|
||||
print("-" * 40)
|
||||
|
||||
new_reviews_count = 0
|
||||
|
||||
for listing in store_listings:
|
||||
if not listing.Versions:
|
||||
continue
|
||||
|
||||
# Get approved versions
|
||||
approved_versions = [
|
||||
v
|
||||
for v in listing.Versions
|
||||
if v.submissionStatus == prisma.enums.SubmissionStatus.APPROVED
|
||||
]
|
||||
if not approved_versions:
|
||||
continue
|
||||
|
||||
# Pick a version to add reviews to
|
||||
version = random.choice(approved_versions)
|
||||
|
||||
# Get existing reviews for this version to avoid duplicates
|
||||
existing_reviews = await db.storelistingreview.find_many(
|
||||
where={"storeListingVersionId": version.id}
|
||||
)
|
||||
existing_reviewer_ids = {r.reviewByUserId for r in existing_reviews}
|
||||
|
||||
# Find users who haven't reviewed this version yet
|
||||
available_reviewers = [u for u in users if u.id not in existing_reviewer_ids]
|
||||
|
||||
if available_reviewers:
|
||||
# Add 2-5 new reviews
|
||||
num_new_reviews = min(random.randint(2, 5), len(available_reviewers))
|
||||
selected_reviewers = random.sample(available_reviewers, num_new_reviews)
|
||||
|
||||
for reviewer in selected_reviewers:
|
||||
# Bias towards positive reviews (4-5 stars)
|
||||
score = random.choices([1, 2, 3, 4, 5], weights=[5, 10, 20, 40, 25])[0]
|
||||
|
||||
await db.storelistingreview.create(
|
||||
data={
|
||||
"storeListingVersionId": version.id,
|
||||
"reviewByUserId": reviewer.id,
|
||||
"score": score,
|
||||
"comments": (
|
||||
faker.text(max_nb_chars=200)
|
||||
if random.random() < 0.7
|
||||
else None
|
||||
),
|
||||
}
|
||||
)
|
||||
new_reviews_count += 1
|
||||
|
||||
print(f"✓ Created {new_reviews_count} new reviews")
|
||||
|
||||
# 4. Update some store listing versions (change categories, featured status)
|
||||
print("\n4. Updating store listing versions...")
|
||||
print("-" * 40)
|
||||
|
||||
updates_count = 0
|
||||
for listing in random.sample(store_listings, min(10, len(store_listings))):
|
||||
if listing.Versions:
|
||||
version = random.choice(listing.Versions)
|
||||
if version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED:
|
||||
# Toggle featured status or update categories
|
||||
new_categories = random.sample(
|
||||
[
|
||||
"productivity",
|
||||
"ai",
|
||||
"automation",
|
||||
"data",
|
||||
"social",
|
||||
"marketing",
|
||||
"development",
|
||||
"analytics",
|
||||
],
|
||||
k=random.randint(2, 4),
|
||||
)
|
||||
|
||||
await db.storelistingversion.update(
|
||||
where={"id": version.id},
|
||||
data={
|
||||
"isFeatured": (
|
||||
not version.isFeatured
|
||||
if random.random() < 0.3
|
||||
else version.isFeatured
|
||||
),
|
||||
"categories": new_categories,
|
||||
"updatedAt": datetime.now(),
|
||||
},
|
||||
)
|
||||
updates_count += 1
|
||||
|
||||
print(f"✓ Updated {updates_count} store listing versions")
|
||||
|
||||
# 5. Create some new credit transactions
|
||||
print("\n5. Adding credit transactions...")
|
||||
print("-" * 40)
|
||||
|
||||
transaction_count = 0
|
||||
for user in random.sample(users, min(30, len(users))):
|
||||
# Add 1-3 transactions per user
|
||||
for _ in range(random.randint(1, 3)):
|
||||
transaction_type = random.choice(
|
||||
[
|
||||
prisma.enums.CreditTransactionType.USAGE,
|
||||
prisma.enums.CreditTransactionType.TOP_UP,
|
||||
prisma.enums.CreditTransactionType.GRANT,
|
||||
]
|
||||
)
|
||||
|
||||
amount = (
|
||||
random.randint(10, 500)
|
||||
if transaction_type == prisma.enums.CreditTransactionType.TOP_UP
|
||||
else -random.randint(1, 50)
|
||||
)
|
||||
|
||||
await db.credittransaction.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"amount": amount,
|
||||
"type": transaction_type,
|
||||
"metadata": Json(
|
||||
{
|
||||
"source": "test_updater",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
transaction_count += 1
|
||||
|
||||
print(f"✓ Created {transaction_count} credit transactions")
|
||||
|
||||
# 6. Refresh materialized views
|
||||
print("\n6. Refreshing materialized views...")
|
||||
print("-" * 40)
|
||||
|
||||
try:
|
||||
await db.execute_raw("SELECT refresh_store_materialized_views();")
|
||||
print("✓ Materialized views refreshed successfully")
|
||||
except Exception as e:
|
||||
print(f"⚠ Warning: Could not refresh materialized views: {e}")
|
||||
print(
|
||||
" You may need to refresh them manually with: SELECT refresh_store_materialized_views();"
|
||||
)
|
||||
|
||||
# 7. Verify the updates
|
||||
print("\n7. Verifying updates...")
|
||||
print("-" * 40)
|
||||
|
||||
# Check agent run counts
|
||||
run_counts = await db.query_raw(
|
||||
"SELECT COUNT(*) as view_count FROM mv_agent_run_counts"
|
||||
)
|
||||
print(f"✓ mv_agent_run_counts has {run_counts[0]['view_count']} entries")
|
||||
|
||||
# Check review stats
|
||||
review_stats = await db.query_raw(
|
||||
"SELECT COUNT(*) as view_count FROM mv_review_stats"
|
||||
)
|
||||
print(f"✓ mv_review_stats has {review_stats[0]['view_count']} entries")
|
||||
|
||||
# Sample some data from the views
|
||||
print("\nSample data from materialized views:")
|
||||
|
||||
sample_runs = await db.query_raw(
|
||||
"SELECT * FROM mv_agent_run_counts ORDER BY run_count DESC LIMIT 5"
|
||||
)
|
||||
print("\nTop 5 agents by run count:")
|
||||
for row in sample_runs:
|
||||
print(f" - Agent {row['agentGraphId'][:8]}...: {row['run_count']} runs")
|
||||
|
||||
sample_reviews = await db.query_raw(
|
||||
"SELECT * FROM mv_review_stats ORDER BY avg_rating DESC NULLS LAST LIMIT 5"
|
||||
)
|
||||
print("\nTop 5 store listings by rating:")
|
||||
for row in sample_reviews:
|
||||
avg_rating = row["avg_rating"] if row["avg_rating"] is not None else 0.0
|
||||
print(
|
||||
f" - Listing {row['storeListingId'][:8]}...: {avg_rating:.2f} ⭐ ({row['review_count']} reviews)"
|
||||
)
|
||||
|
||||
await db.disconnect()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Test data update completed successfully!")
|
||||
print("The materialized views should now reflect the updated data.")
|
||||
print(
|
||||
"\nTo manually refresh views, run: SELECT refresh_store_materialized_views();"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
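
Both test-data scripts finish by calling the `refresh_store_materialized_views()` SQL function so that `mv_agent_run_counts` and `mv_review_stats` pick up the new rows. If you need the same refresh outside these scripts (for example from a Python shell while developing), a minimal sketch using the same Prisma client is:

```python
import asyncio

from prisma import Prisma


async def refresh_views() -> None:
    db = Prisma()
    await db.connect()
    # Same call the creator/updater scripts make; the SQL function itself is
    # assumed to be defined by the platform's database migrations.
    await db.execute_raw("SELECT refresh_store_materialized_views();")
    await db.disconnect()


asyncio.run(refresh_views())
```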
@@ -11,8 +11,6 @@ const nextConfig = {
 
       "ideogram.ai", // for generated images
       "picsum.photos", // for placeholder images
-      "dummyimage.com", // for placeholder images
-      "placekitten.com", // for placeholder images
     ],
   },
   output: "standalone",
@@ -30,6 +28,11 @@ export default isDevelopmentBuild
       org: "significant-gravitas",
       project: "builder",
 
+      // Expose Vercel env to the client
+      env: {
+        NEXT_PUBLIC_VERCEL_ENV: process.env.VERCEL_ENV,
+      },
+
       // Only print logs for uploading source maps in CI
       silent: !process.env.CI,
 
autogpt_platform/frontend/package.json
@@ -54,7 +54,6 @@
     "@supabase/supabase-js": "2.50.3",
     "@tanstack/react-query": "5.81.5",
     "@tanstack/react-table": "8.21.3",
-    "@tanstack/react-virtual": "3.13.12",
     "@types/jaro-winkler": "0.2.4",
     "@xyflow/react": "12.8.1",
     "ajv": "8.17.1",
@@ -76,6 +75,7 @@
     "moment": "2.30.1",
     "next": "15.3.5",
     "next-themes": "0.4.6",
+    "nuqs": "2.4.3",
     "party-js": "2.2.0",
     "react": "18.3.1",
     "react-day-picker": "9.8.0",
@@ -88,6 +88,7 @@
     "react-shepherd": "6.1.8",
     "recharts": "2.15.3",
     "shepherd.js": "14.5.0",
+    "sonner": "2.0.6",
     "tailwind-merge": "2.6.0",
     "tailwindcss-animate": "1.0.7",
     "uuid": "11.1.0",
autogpt_platform/frontend/pnpm-lock.yaml (generated, 83 lines changed)
@@ -92,9 +92,6 @@ importers:
|
||||
'@tanstack/react-table':
|
||||
specifier: 8.21.3
|
||||
version: 8.21.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
'@tanstack/react-virtual':
|
||||
specifier: 3.13.12
|
||||
version: 3.13.12(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
'@types/jaro-winkler':
|
||||
specifier: 0.2.4
|
||||
version: 0.2.4
|
||||
@@ -158,6 +155,9 @@ importers:
|
||||
next-themes:
|
||||
specifier: 0.4.6
|
||||
version: 0.4.6(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
nuqs:
|
||||
specifier: 2.4.3
|
||||
version: 2.4.3(next@15.3.5(@babel/core@7.28.0)(@opentelemetry/api@1.9.0)(@playwright/test@1.53.2)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
|
||||
party-js:
|
||||
specifier: 2.2.0
|
||||
version: 2.2.0
|
||||
@@ -194,6 +194,9 @@ importers:
|
||||
shepherd.js:
|
||||
specifier: 14.5.0
|
||||
version: 14.5.0
|
||||
sonner:
|
||||
specifier: 2.0.6
|
||||
version: 2.0.6(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
tailwind-merge:
|
||||
specifier: 2.6.0
|
||||
version: 2.6.0
|
||||
@@ -2751,19 +2754,10 @@ packages:
|
||||
react: '>=16.8'
|
||||
react-dom: '>=16.8'
|
||||
|
||||
'@tanstack/react-virtual@3.13.12':
|
||||
resolution: {integrity: sha512-Gd13QdxPSukP8ZrkbgS2RwoZseTTbQPLnQEn7HY/rqtM+8Zt95f7xKC7N0EsKs7aoz0WzZ+fditZux+F8EzYxA==}
|
||||
peerDependencies:
|
||||
react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
|
||||
react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
|
||||
|
||||
'@tanstack/table-core@8.21.3':
|
||||
resolution: {integrity: sha512-ldZXEhOBb8Is7xLs01fR3YEc3DERiz5silj8tnGkFZytt1abEvl/GhUmCE0PMLaMPTa3Jk4HbKmRlHmu+gCftg==}
|
||||
engines: {node: '>=12'}
|
||||
|
||||
'@tanstack/virtual-core@3.13.12':
|
||||
resolution: {integrity: sha512-1YBOJfRHV4sXUmWsFSf5rQor4Ss82G8dQWLRbnk3GA4jeP8hQt1hxXh0tmflpC0dz3VgEv/1+qwPyLeWkQuPFA==}
|
||||
|
||||
'@testing-library/dom@10.4.0':
|
||||
resolution: {integrity: sha512-pemlzrSESWbdAloYml3bAJMEfNh1Z7EduzqPKprCH5S341frlpYnUEW0H72dLxa6IsYr+mPno20GiSm+h9dEdQ==}
|
||||
engines: {node: '>=18'}
|
||||
@@ -5338,6 +5332,9 @@ packages:
|
||||
resolution: {integrity: sha512-qOOzS1cBTWYF4BH8fVePDBOO9iptMnGUEZwNc/cMWnTV2nVLZ7VoNWEPHkYczZA0pdoA7dl6e7FL659nX9S2aw==}
|
||||
engines: {node: '>=16 || 14 >=14.17'}
|
||||
|
||||
mitt@3.0.1:
|
||||
resolution: {integrity: sha512-vKivATfr97l2/QBCYAkXYDbrIWPM2IIKEl7YPhjCvKlG3kE2gm+uBo6nEXK3M5/Ffh/FLpKExzOQ3JJoJGFKBw==}
|
||||
|
||||
module-details-from-path@1.0.4:
|
||||
resolution: {integrity: sha512-EGWKgxALGMgzvxYF1UyGTy0HXX/2vHLkw6+NvDKW2jypWbHpjQuj4UMcqQWXHERJhVGKikolT06G3bcKe4fi7w==}
|
||||
|
||||
@@ -5468,6 +5465,24 @@ packages:
|
||||
nth-check@2.1.1:
|
||||
resolution: {integrity: sha512-lqjrjmaOoAnWfMmBPL+XNnynZh2+swxiX3WUE0s4yEHI6m+AwrK2UZOimIRl3X/4QctVqS8AiZjFqyOGrMXb/w==}
|
||||
|
||||
nuqs@2.4.3:
|
||||
resolution: {integrity: sha512-BgtlYpvRwLYiJuWzxt34q2bXu/AIS66sLU1QePIMr2LWkb+XH0vKXdbLSgn9t6p7QKzwI7f38rX3Wl9llTXQ8Q==}
|
||||
peerDependencies:
|
||||
'@remix-run/react': '>=2'
|
||||
next: '>=14.2.0'
|
||||
react: '>=18.2.0 || ^19.0.0-0'
|
||||
react-router: ^6 || ^7
|
||||
react-router-dom: ^6 || ^7
|
||||
peerDependenciesMeta:
|
||||
'@remix-run/react':
|
||||
optional: true
|
||||
next:
|
||||
optional: true
|
||||
react-router:
|
||||
optional: true
|
||||
react-router-dom:
|
||||
optional: true
|
||||
|
||||
oas-kit-common@1.0.8:
|
||||
resolution: {integrity: sha512-pJTS2+T0oGIwgjGpw7sIRU8RQMcUoKCDWFLdBqKB2BNmGpbBMH2sdqAaOXUg8OzonZHU0L7vfJu1mJFEiYDWOQ==}
|
||||
|
||||
@@ -6388,6 +6403,12 @@ packages:
|
||||
resolution: {integrity: sha512-g9Q1haeby36OSStwb4ntCGGGaKsaVSjQ68fBxoQcutl5fS1vuY18H3wSt3jFyFtrkx+Kz0V1G85A4MyAdDMi2Q==}
|
||||
engines: {node: '>=8'}
|
||||
|
||||
sonner@2.0.6:
|
||||
resolution: {integrity: sha512-yHFhk8T/DK3YxjFQXIrcHT1rGEeTLliVzWbO0xN8GberVun2RiBnxAjXAYpZrqwEVHBG9asI/Li8TAAhN9m59Q==}
|
||||
peerDependencies:
|
||||
react: ^18.0.0 || ^19.0.0 || ^19.0.0-rc
|
||||
react-dom: ^18.0.0 || ^19.0.0 || ^19.0.0-rc
|
||||
|
||||
source-map-js@1.2.1:
|
||||
resolution: {integrity: sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==}
|
||||
engines: {node: '>=0.10.0'}
|
||||
@@ -9994,16 +10015,8 @@ snapshots:
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
|
||||
'@tanstack/react-virtual@3.13.12(react-dom@18.3.1(react@18.3.1))(react@18.3.1)':
|
||||
dependencies:
|
||||
'@tanstack/virtual-core': 3.13.12
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
|
||||
'@tanstack/table-core@8.21.3': {}
|
||||
|
||||
'@tanstack/virtual-core@3.13.12': {}
|
||||
|
||||
'@testing-library/dom@10.4.0':
|
||||
dependencies:
|
||||
'@babel/code-frame': 7.27.1
|
||||
@@ -11653,8 +11666,8 @@ snapshots:
|
||||
'@typescript-eslint/parser': 8.36.0(eslint@8.57.1)(typescript@5.8.3)
|
||||
eslint: 8.57.1
|
||||
eslint-import-resolver-node: 0.3.9
|
||||
eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint@8.57.1))(eslint@8.57.1)
|
||||
eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1)
|
||||
eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1)
|
||||
eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1)
|
||||
eslint-plugin-jsx-a11y: 6.10.2(eslint@8.57.1)
|
||||
eslint-plugin-react: 7.37.5(eslint@8.57.1)
|
||||
eslint-plugin-react-hooks: 5.2.0(eslint@8.57.1)
|
||||
@@ -11673,7 +11686,7 @@ snapshots:
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint@8.57.1))(eslint@8.57.1):
|
||||
eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1):
|
||||
dependencies:
|
||||
'@nolyfill/is-core-module': 1.0.39
|
||||
debug: 4.4.1
|
||||
@@ -11684,22 +11697,22 @@ snapshots:
|
||||
tinyglobby: 0.2.14
|
||||
unrs-resolver: 1.11.0
|
||||
optionalDependencies:
|
||||
eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1)
|
||||
eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
eslint-module-utils@2.12.1(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1):
|
||||
eslint-module-utils@2.12.1(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1):
|
||||
dependencies:
|
||||
debug: 3.2.7
|
||||
optionalDependencies:
|
||||
'@typescript-eslint/parser': 8.36.0(eslint@8.57.1)(typescript@5.8.3)
|
||||
eslint: 8.57.1
|
||||
eslint-import-resolver-node: 0.3.9
|
||||
eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint@8.57.1))(eslint@8.57.1)
|
||||
eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1):
|
||||
eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1):
|
||||
dependencies:
|
||||
'@rtsao/scc': 1.1.0
|
||||
array-includes: 3.1.9
|
||||
@@ -11710,7 +11723,7 @@ snapshots:
|
||||
doctrine: 2.1.0
|
||||
eslint: 8.57.1
|
||||
eslint-import-resolver-node: 0.3.9
|
||||
eslint-module-utils: 2.12.1(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1)
|
||||
eslint-module-utils: 2.12.1(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1)
|
||||
hasown: 2.0.2
|
||||
is-core-module: 2.16.1
|
||||
is-glob: 4.0.3
|
||||
@@ -13009,6 +13022,8 @@ snapshots:
|
||||
|
||||
minipass@7.1.2: {}
|
||||
|
||||
mitt@3.0.1: {}
|
||||
|
||||
module-details-from-path@1.0.4: {}
|
||||
|
||||
moment@2.30.1: {}
|
||||
@@ -13171,6 +13186,13 @@ snapshots:
|
||||
dependencies:
|
||||
boolbase: 1.0.0
|
||||
|
||||
nuqs@2.4.3(next@15.3.5(@babel/core@7.28.0)(@opentelemetry/api@1.9.0)(@playwright/test@1.53.2)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1):
|
||||
dependencies:
|
||||
mitt: 3.0.1
|
||||
react: 18.3.1
|
||||
optionalDependencies:
|
||||
next: 15.3.5(@babel/core@7.28.0)(@opentelemetry/api@1.9.0)(@playwright/test@1.53.2)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
|
||||
oas-kit-common@1.0.8:
|
||||
dependencies:
|
||||
fast-safe-stringify: 2.1.1
|
||||
@@ -14214,6 +14236,11 @@ snapshots:
|
||||
|
||||
slash@3.0.0: {}
|
||||
|
||||
sonner@2.0.6(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
|
||||
dependencies:
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
|
||||
source-map-js@1.2.1: {}
|
||||
|
||||
source-map-support@0.5.21:
|
||||
|
||||
6
autogpt_platform/frontend/public/google-logo.svg
Normal file
@@ -0,0 +1,6 @@
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M19.99 10.1871C19.99 9.36767 19.9246 8.76973 19.7831 8.14966H10.1943V11.8493H15.8207C15.7062 12.7676 15.0943 14.1618 13.7567 15.0492L13.7398 15.1632L16.7444 17.4429L16.9637 17.4648C18.8825 15.7291 19.99 13.2042 19.99 10.1871Z" fill="#4285F4"/>
<path d="M10.1943 19.9313C12.9592 19.9313 15.2429 19.0454 16.9637 17.4648L13.7567 15.0492C12.8697 15.6438 11.7348 16.0244 10.1943 16.0244C7.50242 16.0244 5.25023 14.2886 4.39644 11.9036L4.28823 11.9125L1.17021 14.2775L1.13477 14.3808C2.84508 17.8028 6.1992 19.9313 10.1943 19.9313Z" fill="#34A853"/>
<path d="M4.39644 11.9036C4.1758 11.2746 4.04876 10.6013 4.04876 9.90569C4.04876 9.21011 4.1758 8.53684 4.38177 7.90781L4.37563 7.7883L1.20776 5.3801L1.13477 5.41253C0.436264 6.80439 0.0390625 8.35202 0.0390625 9.90569C0.0390625 11.4594 0.436264 13.007 1.13477 14.3808L4.39644 11.9036Z" fill="#FBBC05"/>
<path d="M10.1943 3.78682C12.1168 3.78682 13.397 4.66154 14.1236 5.33481L17.0194 2.59768C15.2373 0.953818 12.9592 0 10.1943 0C6.1992 0 2.84508 2.12847 1.13477 5.41253L4.38177 7.90781C5.25023 5.52278 7.50242 3.78682 10.1943 3.78682Z" fill="#EB4335"/>
</svg>

@@ -11,7 +11,7 @@ import StarRating from "@/components/onboarding/StarRating";
import SchemaTooltip from "@/components/SchemaTooltip";
import { TypeBasedInput } from "@/components/type-based-input";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { useToast } from "@/components/ui/use-toast";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { GraphMeta, StoreAgentDetails } from "@/lib/autogpt-server-api";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import { cn } from "@/lib/utils";
@@ -46,13 +46,13 @@ export default function Page() {
setStoreAgent(storeAgent);
});
api
.getAgentMetaByStoreListingVersionId(state?.selectedStoreListingVersionId)
.getGraphMetaByStoreListingVersionID(state.selectedStoreListingVersionId)
.then((agent) => {
setAgent(agent);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const update: { [key: string]: any } = {};
// Set default values from schema
Object.entries(agent.input_schema?.properties || {}).forEach(
Object.entries(agent.input_schema.properties).forEach(
([key, value]) => {
// Skip if already set
if (state.agentInput && state.agentInput[key]) {
@@ -224,7 +224,7 @@ export default function Page() {
<CardTitle className="font-poppins text-lg">Input</CardTitle>
</CardHeader>
<CardContent className="flex flex-col gap-4">
{Object.entries(agent?.input_schema?.properties || {}).map(
{Object.entries(agent?.input_schema.properties || {}).map(
([key, inputSubSchema]) => (
<div key={key} className="flex flex-col space-y-2">
<label className="flex items-center gap-1 text-sm font-medium">

@@ -1,12 +1,13 @@
"use client";

import { useSearchParams } from "next/navigation";
import { GraphID } from "@/lib/autogpt-server-api/types";
import FlowEditor from "@/components/Flow";
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
import { useEffect } from "react";
import LoadingBox from "@/components/ui/loading";
import { GraphID } from "@/lib/autogpt-server-api/types";
import { useSearchParams } from "next/navigation";
import { Suspense, useEffect } from "react";

export default function BuilderPage() {
function BuilderContent() {
const query = useSearchParams();
const { completeStep } = useOnboarding();

@@ -15,12 +16,20 @@ export default function BuilderPage() {
}, [completeStep]);

const _graphVersion = query.get("flowVersion");
const graphVersion = _graphVersion ? parseInt(_graphVersion) : undefined
const graphVersion = _graphVersion ? parseInt(_graphVersion) : undefined;
return (
<FlowEditor
className="flow-container"
flowID={query.get("flowID") as GraphID | null ?? undefined}
flowID={(query.get("flowID") as GraphID | null) ?? undefined}
flowVersion={graphVersion}
/>
);
}

export default function BuilderPage() {
return (
<Suspense fallback={<LoadingBox className="h-[80vh]" />}>
<BuilderContent />
</Suspense>
);
}

@@ -1,67 +1,10 @@
import { Navbar } from "@/components/layout/Navbar/Navbar";
import { ReactNode } from "react";
import { Navbar } from "@/components/agptui/Navbar";
import { IconType } from "@/components/ui/icons";

export default function PlatformLayout({ children }: { children: ReactNode }) {
return (
<>
<Navbar
links={[
{
name: "Marketplace",
href: "/marketplace",
},
{
name: "Library",
href: "/library",
},
{
name: "Build",
href: "/build",
},
]}
menuItemGroups={[
{
items: [
{
icon: IconType.Edit,
text: "Edit profile",
href: "/profile",
},
],
},
{
items: [
{
icon: IconType.LayoutDashboard,
text: "Creator Dashboard",
href: "/profile/dashboard",
},
{
icon: IconType.UploadCloud,
text: "Publish an agent",
},
],
},
{
items: [
{
icon: IconType.Settings,
text: "Settings",
href: "/profile/settings",
},
],
},
{
items: [
{
icon: IconType.LogOut,
text: "Log out",
},
],
},
]}
/>
<Navbar />
<main>{children}</main>
</>
);

@@ -1,5 +1,6 @@
"use client";
import { useParams, useRouter } from "next/navigation";
import { useQueryState } from "nuqs";
import React, {
useCallback,
useEffect,
@@ -41,10 +42,11 @@ import {
DialogTitle,
} from "@/components/ui/dialog";
import LoadingBox, { LoadingSpinner } from "@/components/ui/loading";
import { useToast } from "@/components/ui/use-toast";
import { useToast } from "@/components/molecules/Toast/use-toast";

export default function AgentRunsPage(): React.ReactElement {
const { id: agentID }: { id: LibraryAgentID } = useParams();
const [executionId, setExecutionId] = useQueryState("executionId");
const { toast } = useToast();
const router = useRouter();
const api = useBackendAPI();
@@ -202,6 +204,13 @@ export default function AgentRunsPage(): React.ReactElement {
selectPreset,
]);

useEffect(() => {
if (executionId) {
selectRun(executionId as GraphExecutionID);
setExecutionId(null);
}
}, [executionId, selectRun, setExecutionId]);

// Initial load
useEffect(() => {
refreshPageData();
@@ -468,7 +477,7 @@ export default function AgentRunsPage(): React.ReactElement {
}

return (
<div className="container justify-stretch p-0 lg:flex">
<div className="container justify-stretch p-0 pt-16 lg:flex">
{/* Sidebar w/ list of runs */}
{/* TODO: render this below header in sm and md layouts */}
<AgentRunsSelectorList
@@ -512,7 +521,8 @@ export default function AgentRunsPage(): React.ReactElement {
) : selectedView.type == "run" ? (
/* Draft new runs / Create new presets */
<AgentRunDraftView
agent={agent}
graph={graph}
triggerSetupInfo={agent.trigger_setup_info}
onRun={selectRun}
onCreateSchedule={onCreateSchedule}
onCreatePreset={onCreatePreset}
@@ -521,7 +531,8 @@ export default function AgentRunsPage(): React.ReactElement {
) : selectedView.type == "preset" ? (
/* Edit & update presets */
<AgentRunDraftView
agent={agent}
graph={graph}
triggerSetupInfo={agent.trigger_setup_info}
agentPreset={
agentPresets.find((preset) => preset.id == selectedView.id)!
}

@@ -6,6 +6,7 @@ import { useLibraryAgentList } from "./useLibraryAgentList";
export default function LibraryAgentList() {
const {
agentLoading,
agentCount,
allAgents: agents,
isFetchingNextPage,
isSearching,
@@ -18,7 +19,7 @@ export default function LibraryAgentList() {
return (
<>
{/* TODO: We need a new endpoint on backend that returns total number of agents */}
<LibraryActionSubHeader agentCount={agents.length} />
<LibraryActionSubHeader agentCount={agentCount} />
<div className="px-2">
{agentLoading ? (
<div className="flex h-[200px] items-center justify-center">

@@ -56,11 +56,16 @@ export const useLibraryAgentList = () => {
return data.agents;
}) ?? [];

const agentCount = agents?.pages[0]
? (agents.pages[0].data as LibraryAgentResponse).pagination.total_items
: 0;

return {
allAgents,
agentLoading,
isFetchingNextPage,
hasNextPage,
agentCount,
isSearching: isFetching && !isFetchingNextPage,
};
};

@@ -4,7 +4,7 @@ import { z } from "zod";
import { uploadAgentFormSchema } from "./LibraryUploadAgentDialog";
import { usePostV1CreateNewGraph } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { GraphModel } from "@/app/api/__generated__/models/graphModel";
import { useToast } from "@/components/ui/use-toast";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { useState } from "react";
import { Graph } from "@/app/api/__generated__/models/graph";
import { sanitizeImportedGraph } from "@/lib/autogpt-server-api";

@@ -1,15 +1,15 @@
"use client";
import Link from "next/link";

import { Alert, AlertDescription } from "@/components/ui/alert";
import {
ArrowBottomRightIcon,
QuestionMarkCircledIcon,
} from "@radix-ui/react-icons";
import { Alert, AlertDescription } from "@/components/ui/alert";

import { LibraryPageStateProvider } from "./components/state-provider";
import LibraryActionHeader from "./components/LibraryActionHeader/LibraryActionHeader";
import LibraryAgentList from "./components/LibraryAgentList/LibraryAgentList";
import { LibraryPageStateProvider } from "./components/state-provider";

/**
* LibraryPage Component
@@ -17,7 +17,7 @@ import LibraryAgentList from "./components/LibraryAgentList/LibraryAgentList";
*/
export default function LibraryPage() {
return (
<main className="container min-h-screen space-y-4 pb-20 sm:px-8 md:px-12">
<main className="pt-160 container min-h-screen space-y-4 pb-20 pt-16 sm:px-8 md:px-12">
<LibraryPageStateProvider>
<LibraryActionHeader />
<LibraryAgentList />

@@ -1,12 +1,12 @@
"use server";
import BackendAPI from "@/lib/autogpt-server-api";
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
import { verifyTurnstileToken } from "@/lib/turnstile";
import { loginFormSchema, LoginProvider } from "@/types/auth";
import * as Sentry from "@sentry/nextjs";
import { revalidatePath } from "next/cache";
import { redirect } from "next/navigation";
import { z } from "zod";
import * as Sentry from "@sentry/nextjs";
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
import BackendAPI from "@/lib/autogpt-server-api";
import { loginFormSchema, LoginProvider } from "@/types/auth";
import { verifyTurnstileToken } from "@/lib/turnstile";

async function shouldShowOnboarding() {
const api = new BackendAPI();
@@ -23,6 +23,7 @@ export async function login(
return await Sentry.withServerActionInstrumentation("login", {}, async () => {
const supabase = await getServerSupabase();
const api = new BackendAPI();
const isVercelPreview = process.env.VERCEL_ENV === "preview";

if (!supabase) {
redirect("/error");
@@ -30,7 +31,7 @@

// Verify Turnstile token if provided
const success = await verifyTurnstileToken(turnstileToken, "login");
if (!success) {
if (!success && !isVercelPreview) {
return "CAPTCHA verification failed. Please try again.";
}

@@ -38,7 +39,6 @@
const { error } = await supabase.auth.signInWithPassword(values);

if (error) {
console.error("Error logging in:", error);
return error.message;
}

@@ -76,6 +76,11 @@ export async function providerLogin(provider: LoginProvider) {
});

if (error) {
// FIXME: supabase doesn't return the correct error message for this case
if (error.message.includes("P0001")) {
return "not_allowed";
}

console.error("Error logging in", error);
return error.message;
}

@@ -0,0 +1,34 @@
import { AuthCard } from "@/components/auth/AuthCard";
import { Skeleton } from "@/components/ui/skeleton";

export function LoadingLogin() {
return (
<div className="flex h-full min-h-[85vh] flex-col items-center justify-center">
<AuthCard title="">
<div className="w-full space-y-6">
<Skeleton className="mx-auto h-8 w-48" />
<Skeleton className="h-12 w-full rounded-md" />
<div className="flex w-full items-center">
<Skeleton className="h-px flex-1" />
<Skeleton className="mx-3 h-4 w-6" />
<Skeleton className="h-px flex-1" />
</div>
<div className="space-y-2">
<Skeleton className="h-4 w-12" />
<Skeleton className="h-12 w-full rounded-md" />
</div>
<div className="space-y-2">
<Skeleton className="h-4 w-16" />
<Skeleton className="h-12 w-full rounded-md" />
</div>
<Skeleton className="h-16 w-full rounded-md" />
<Skeleton className="h-12 w-full rounded-md" />
<div className="flex justify-center space-x-1">
<Skeleton className="h-4 w-32" />
<Skeleton className="h-4 w-12" />
</div>
</div>
</AuthCard>
</div>
);
}

@@ -1,26 +1,15 @@
"use client";
import {
AuthBottomText,
AuthButton,
AuthCard,
AuthFeedback,
AuthHeader,
GoogleOAuthButton,
PasswordInput,
Turnstile,
} from "@/components/auth";
import {
Form,
FormControl,
FormField,
FormItem,
FormLabel,
FormMessage,
} from "@/components/ui/form";
import { Input } from "@/components/ui/input";
import LoadingBox from "@/components/ui/loading";
import { Button } from "@/components/atoms/Button/Button";
import { Input } from "@/components/atoms/Input/Input";
import { Link } from "@/components/atoms/Link/Link";
import { AuthCard } from "@/components/auth/AuthCard";
import AuthFeedback from "@/components/auth/AuthFeedback";
import { EmailNotAllowedModal } from "@/components/auth/EmailNotAllowedModal";
import { GoogleOAuthButton } from "@/components/auth/GoogleOAuthButton";
import Turnstile from "@/components/auth/Turnstile";
import { Form, FormField } from "@/components/ui/form";
import { getBehaveAs } from "@/lib/utils";
import Link from "next/link";
import { LoadingLogin } from "./components/LoadingLogin";
import { useLoginPage } from "./useLoginPage";

export default function LoginPage() {
@@ -30,17 +19,20 @@ export default function LoginPage() {
turnstile,
captchaKey,
isLoading,
isCloudEnv,
isLoggedIn,
isCloudEnv,
shouldNotRenderCaptcha,
isUserLoading,
isGoogleLoading,
showNotAllowedModal,
isSupabaseAvailable,
handleSubmit,
handleProviderLogin,
handleCloseNotAllowedModal,
} = useLoginPage();

if (isUserLoading || isLoggedIn) {
return <LoadingBox className="h-[80vh]" />;
return <LoadingLogin />;
}

if (!isSupabaseAvailable) {
@@ -52,99 +44,93 @@
}

return (
<AuthCard className="mx-auto">
<AuthHeader>Login to your account</AuthHeader>
<div className="flex h-full min-h-[85vh] flex-col items-center justify-center py-10">
<AuthCard title="Login to your account">
<Form {...form}>
<form onSubmit={handleSubmit} className="flex w-full flex-col gap-1">
<FormField
control={form.control}
name="email"
render={({ field }) => (
<Input
id={field.name}
label="Email"
placeholder="m@example.com"
type="email"
autoComplete="username"
className="w-full"
error={form.formState.errors.email?.message}
{...field}
/>
)}
/>
<FormField
control={form.control}
name="password"
render={({ field }) => (
<Input
id={field.name}
label="Password"
placeholder="•••••••••••••••••••••"
type="password"
autoComplete="current-password"
error={form.formState.errors.password?.message}
hint={
<Link variant="secondary" href="/reset-password">
Forgot password?
</Link>
}
{...field}
/>
)}
/>

{isCloudEnv ? (
<>
<div className="mb-6">
{/* Turnstile CAPTCHA Component */}
{shouldNotRenderCaptcha ? null : (
<Turnstile
key={captchaKey}
siteKey={turnstile.siteKey}
onVerify={turnstile.handleVerify}
onExpire={turnstile.handleExpire}
onError={turnstile.handleError}
setWidgetId={turnstile.setWidgetId}
action="login"
shouldRender={turnstile.shouldRender}
/>
)}

<Button
variant="primary"
loading={isLoading}
type="submit"
className="mt-6 w-full"
>
{isLoading ? "Logging in..." : "Login"}
</Button>
</form>
{isCloudEnv ? (
<GoogleOAuthButton
onClick={() => handleProviderLogin("google")}
isLoading={isGoogleLoading}
disabled={isLoading}
/>
</div>
<div className="mb-6 flex items-center">
<div className="flex-1 border-t border-gray-300"></div>
<span className="mx-3 text-sm text-gray-500">or</span>
<div className="flex-1 border-t border-gray-300"></div>
</div>
</>
) : null}

<Form {...form}>
<form onSubmit={handleSubmit}>
<FormField
control={form.control}
name="email"
render={({ field }) => (
<FormItem className="mb-6">
<FormLabel>Email</FormLabel>
<FormControl>
<Input
placeholder="m@example.com"
{...field}
type="email" // Explicitly specify email type
autoComplete="username" // Added for password managers
/>
</FormControl>
<FormMessage />
</FormItem>
)}
) : null}
<AuthFeedback
type="login"
message={feedback}
isError={!!feedback}
behaveAs={getBehaveAs()}
/>
<FormField
control={form.control}
name="password"
render={({ field }) => (
<FormItem className="mb-6">
<FormLabel className="flex w-full items-center justify-between">
<span>Password</span>
<Link
href="/reset-password"
className="text-sm font-normal leading-normal text-black underline"
>
Forgot your password?
</Link>
</FormLabel>
<FormControl>
<PasswordInput
{...field}
autoComplete="current-password" // Added for password managers
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>

{/* Turnstile CAPTCHA Component */}
<Turnstile
key={captchaKey}
siteKey={turnstile.siteKey}
onVerify={turnstile.handleVerify}
onExpire={turnstile.handleExpire}
onError={turnstile.handleError}
setWidgetId={turnstile.setWidgetId}
action="login"
shouldRender={turnstile.shouldRender}
/>

<AuthButton isLoading={isLoading} type="submit">
Login
</AuthButton>
</form>
<AuthFeedback
type="login"
message={feedback}
isError={!!feedback}
behaveAs={getBehaveAs()}
</Form>
<AuthCard.BottomText
text="Don't have an account?"
link={{ text: "Sign up", href: "/signup" }}
/>
</Form>
<AuthBottomText
text="Don't have an account?"
linkText="Sign up"
href="/signup"
</AuthCard>
<EmailNotAllowedModal
isOpen={showNotAllowedModal}
onClose={handleCloseNotAllowedModal}
/>
</AuthCard>
</div>
);
}

@@ -1,23 +1,26 @@
import { useTurnstile } from "@/hooks/useTurnstile";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { BehaveAs, getBehaveAs } from "@/lib/utils";
import { loginFormSchema, LoginProvider } from "@/types/auth";
import { zodResolver } from "@hookform/resolvers/zod";
import { useRouter } from "next/navigation";
import { useCallback, useEffect, useState } from "react";
import { useForm } from "react-hook-form";
import { login, providerLogin } from "./actions";
import z from "zod";
import { BehaveAs } from "@/lib/utils";
import { getBehaveAs } from "@/lib/utils";
import { login, providerLogin } from "./actions";
import { useToast } from "@/components/molecules/Toast/use-toast";

export function useLoginPage() {
const { supabase, user, isUserLoading } = useSupabase();
const [feedback, setFeedback] = useState<string | null>(null);
const [captchaKey, setCaptchaKey] = useState(0);
const router = useRouter();
const { toast } = useToast();
const [isLoading, setIsLoading] = useState(false);
const [isGoogleLoading, setIsGoogleLoading] = useState(false);
const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
const isCloudEnv = getBehaveAs() === BehaveAs.CLOUD;
const isVercelPreview = process.env.NEXT_PUBLIC_VERCEL_ENV === "preview";

const turnstile = useTurnstile({
action: "login",
@@ -25,6 +28,8 @@ export function useLoginPage() {
resetOnError: true,
});

const shouldNotRenderCaptcha = isVercelPreview || turnstile.verified;

const form = useForm<z.infer<typeof loginFormSchema>>({
resolver: zodResolver(loginFormSchema),
defaultValues: {
@@ -44,29 +49,53 @@ export function useLoginPage() {

async function handleProviderLogin(provider: LoginProvider) {
setIsGoogleLoading(true);

if (!turnstile.verified && !isVercelPreview) {
toast({
title: "Please complete the CAPTCHA challenge.",
variant: "info",
});

setIsGoogleLoading(false);
resetCaptcha();
return;
}

try {
const error = await providerLogin(provider);
if (error) throw error;
setFeedback(null);
} catch (error) {
resetCaptcha();
setFeedback(JSON.stringify(error));
} finally {
setIsGoogleLoading(false);
const errorString = JSON.stringify(error);
if (errorString.includes("not_allowed")) {
setShowNotAllowedModal(true);
} else {
setFeedback(errorString);
}
}
}

async function handleLogin(data: z.infer<typeof loginFormSchema>) {
setIsLoading(true);
if (!turnstile.verified) {
setFeedback("Please complete the CAPTCHA challenge.");
if (!turnstile.verified && !isVercelPreview) {
toast({
title: "Please complete the CAPTCHA challenge.",
variant: "info",
});

setIsLoading(false);
resetCaptcha();
return;
}

if (data.email.includes("@agpt.co")) {
setFeedback("Please use Google SSO to login using an AutoGPT email.");
toast({
title: "Please use Google SSO to login using an AutoGPT email.",
variant: "default",
});

setIsLoading(false);
resetCaptcha();
return;
@@ -76,7 +105,11 @@ export function useLoginPage() {
await supabase?.auth.refreshSession();
setIsLoading(false);
if (error) {
setFeedback(error);
toast({
title: error,
variant: "destructive",
});

resetCaptcha();
// Always reset the turnstile on any error
turnstile.reset();
@@ -94,9 +127,12 @@ export function useLoginPage() {
isLoading,
isCloudEnv,
isUserLoading,
shouldNotRenderCaptcha,
isGoogleLoading,
showNotAllowedModal,
isSupabaseAvailable: !!supabase,
handleSubmit: form.handleSubmit(handleLogin),
handleProviderLogin,
handleCloseNotAllowedModal: () => setShowNotAllowedModal(false),
};
}

Some files were not shown because too many files have changed in this diff.